From b81b768b04730ef2f76f1650d2f3e7b958c7c4e9 Mon Sep 17 00:00:00 2001 From: Ske Date: Thu, 12 Jul 2018 00:47:44 +0200 Subject: [PATCH] Initial commit --- .gitignore | 2 + bot/Dockerfile | 11 ++ bot/main.py | 6 + bot/pluralkit/__init__.py | 1 + bot/pluralkit/bot.py | 82 ++++++++ bot/pluralkit/commands.py | 392 ++++++++++++++++++++++++++++++++++++++ bot/pluralkit/db.py | 193 +++++++++++++++++++ bot/pluralkit/proxy.py | 93 +++++++++ bot/pluralkit/utils.py | 211 ++++++++++++++++++++ bot/requirements.txt | 4 + docker-compose.yml | 27 +++ 11 files changed, 1022 insertions(+) create mode 100644 .gitignore create mode 100644 bot/Dockerfile create mode 100644 bot/main.py create mode 100644 bot/pluralkit/__init__.py create mode 100644 bot/pluralkit/bot.py create mode 100644 bot/pluralkit/commands.py create mode 100644 bot/pluralkit/db.py create mode 100644 bot/pluralkit/proxy.py create mode 100644 bot/pluralkit/utils.py create mode 100644 bot/requirements.txt create mode 100644 docker-compose.yml diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..39a35e82 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.env +.vscode/ diff --git a/bot/Dockerfile b/bot/Dockerfile new file mode 100644 index 00000000..7c2b6851 --- /dev/null +++ b/bot/Dockerfile @@ -0,0 +1,11 @@ +FROM python:3.6-alpine + +RUN apk --no-cache add build-base + +WORKDIR /app +ADD requirements.txt /app +RUN pip install --trusted-host pypi.python.org -r requirements.txt + +ADD . /app +ENTRYPOINT ["python", "main.py"] + diff --git a/bot/main.py b/bot/main.py new file mode 100644 index 00000000..7e3233f6 --- /dev/null +++ b/bot/main.py @@ -0,0 +1,6 @@ +import asyncio + +from pluralkit import bot + +loop = asyncio.get_event_loop() +loop.run_until_complete(bot.run()) \ No newline at end of file diff --git a/bot/pluralkit/__init__.py b/bot/pluralkit/__init__.py new file mode 100644 index 00000000..37d28ec9 --- /dev/null +++ b/bot/pluralkit/__init__.py @@ -0,0 +1 @@ +from . import commands, db, proxy \ No newline at end of file diff --git a/bot/pluralkit/bot.py b/bot/pluralkit/bot.py new file mode 100644 index 00000000..fcc1ecf6 --- /dev/null +++ b/bot/pluralkit/bot.py @@ -0,0 +1,82 @@ +import logging +import os + +import discord + +logging.basicConfig(level=logging.DEBUG) +logging.getLogger("discord").setLevel(logging.INFO) +logging.getLogger("websockets").setLevel(logging.INFO) + +logger = logging.getLogger("pluralkit.bot") +logger.setLevel(logging.DEBUG) + +client = discord.Client() + +@client.event +async def on_error(evt, *args, **kwargs): + logger.exception("Error while handling event {} with arguments {}:".format(evt, args)) + +@client.event +async def on_ready(): + # Print status info + logger.info("Connected to Discord.") + logger.info("Account: {}#{}".format(client.user.name, client.user.discriminator)) + logger.info("User ID: {}".format(client.user.id)) + +@client.event +async def on_message(message): + # Ignore bot messages + if message.author.bot: + return + + # Split into args. shlex sucks so we don't bother with quotes + args = message.content.split(" ") + + from pluralkit import proxy, utils + + # Find and execute command in map + if len(args) > 0 and args[0] in utils.command_map: + subcommand_map = utils.command_map[args[0]] + + if len(args) >= 2 and args[1] in subcommand_map: + async with client.pool.acquire() as conn: + await subcommand_map[args[1]][0](conn, message, args[2:]) + elif None in subcommand_map: + async with client.pool.acquire() as conn: + await subcommand_map[None][0](conn, message, args[1:]) + elif len(args) >= 2: + embed = discord.Embed() + embed.colour = discord.Colour.dark_red() + embed.description = "Subcommand \"{}\" not found.".format(args[1]) + await client.send_message(message.channel, embed=embed) + else: + # Try doing proxy parsing + async with client.pool.acquire() as conn: + await proxy.handle_proxying(conn, message) + +@client.event +async def on_reaction_add(reaction, user): + from pluralkit import proxy + + # Pass reactions to proxy system + async with client.pool.acquire() as conn: + await proxy.handle_reaction(conn, reaction, user) + +async def run(): + from pluralkit import db + try: + logger.info("Connecting to database...") + pool = await db.connect() + + logger.info("Attempting to create tables...") + async with pool.acquire() as conn: + await db.create_tables(conn) + + logger.info("Connecting to InfluxDB...") + + client.pool = pool + logger.info("Connecting to Discord...") + await client.start(os.environ["TOKEN"]) + finally: + logger.info("Logging out from Discord...") + await client.logout() \ No newline at end of file diff --git a/bot/pluralkit/commands.py b/bot/pluralkit/commands.py new file mode 100644 index 00000000..a7fc9af2 --- /dev/null +++ b/bot/pluralkit/commands.py @@ -0,0 +1,392 @@ +from datetime import datetime +import re +from urllib.parse import urlparse + +import discord + +from pluralkit import db +from pluralkit.bot import client, logger +from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map + +@command(cmd="pk;system", subcommand=None, description="Shows information about your system.") +async def this_system_info(conn, message, args): + system = await db.get_system_by_account(conn, message.author.id) + + if system is None: + return False, "No system is registered to this account." + + await client.send_message(message.channel, embed=await generate_system_info_card(conn, system)) + return True + +@command(cmd="pk;system", subcommand="new", usage="[name]", description="Registers a new system to this account.") +async def new_system(conn, message, args): + system = await db.get_system_by_account(conn, message.author.id) + + if system is not None: + return False, "You already have a system registered. To remove your system, use `pk;system remove`, or to unlink your system from this account, use `pk;system unlink`." + + system_name = None + if len(args) > 2: + system_name = " ".join(args[2:]) + + async with conn.transaction(): + # TODO: figure out what to do if this errors out on collision on generate_hid + hid = generate_hid() + + system = await db.create_system(conn, system_name=system_name, system_hid=hid) + + # Link account + await db.link_account(conn, system_id=system["id"], account_id=message.author.id) + return True, "System registered! To begin adding members, use `pk;member new `." + +@command(cmd="pk;system", subcommand="info", usage="[system]", description="Shows information about a system.") +async def system_info(conn, message, args): + if len(args) == 0: + # Use sender's system + system = await db.get_system_by_account(conn, message.author.id) + + if system is None: + return False, "No system is registered to this account." + else: + # Look one up + system = await get_system_fuzzy(conn, args[0]) + + if system is None: + return False, "Unable to find system \"{}\".".format(args[0]) + + await client.send_message(message.channel, embed=await generate_system_info_card(conn, system)) + return True + +@command(cmd="pk;system", subcommand="name", usage="[name]", description="Renames your system. Leave blank to clear.") +async def system_name(conn, message, args): + system = await db.get_system_by_account(conn, message.author.id) + + if system is None: + return False, "No system is registered to this account." + + if len(args) == 0: + new_name = None + else: + new_name = " ".join(args) + + async with conn.transaction(): + await db.update_system_field(conn, system_id=system["id"], field="name", value=new_name) + return True, "Name updated to {}.".format(new_name) if new_name else "Name cleared." + +@command(cmd="pk;system", subcommand="description", usage="[clear]", description="Updates your system description. Add \"clear\" to clear.") +async def system_description(conn, message, args): + system = await db.get_system_by_account(conn, message.author.id) + + if system is None: + return False, "No system is registered to this account." + + # If "clear" in args, clear + if len(args) > 0 and args[0] == "clear": + new_description = None + else: + new_description = await text_input(message, "your system") + + if not new_description: + return True, "Description update cancelled." + + async with conn.transaction(): + await db.update_system_field(conn, system_id=system["id"], field="description", value=new_description) + return True, "Description set." if new_description else "Description cleared." + +@command(cmd="pk;system", subcommand="tag", usage="[tag]", description="Updates your system tag. Leave blank to clear.") +async def system_tag(conn, message, args): + system = await db.get_system_by_account(conn, message.author.id) + + if system is None: + return False, "No system is registered to this account." + + if len(args) == 0: + tag = None + else: + tag = " ".join(args) + max_length = 32 + + # Make sure there are no members which would make the combined length exceed 32 + members_exceeding = await db.get_members_exceeding(conn, system_id=system["id"], length=max_length - len(tag)) + if len(members_exceeding) > 0: + # If so, error out and warn + member_names = ", ".join([member["name"] for member in members_exceeding]) + logger.debug("Members exceeding combined length with tag '{}': {}".format(tag, member_names)) + return False, "The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names) + + async with conn.transaction(): + await db.update_system_field(conn, system_id=system["id"], field="tag", value=tag) + + return True, "Tag updated to {}.".format(tag) if tag else "Tag cleared." + +@command(cmd="pk;system", subcommand="remove", description="Removes your system ***permanently***.") +async def system_remove(conn, message, args): + system = await db.get_system_by_account(conn, message.author.id) + + if system is None: + return False, "No system is registered to this account." + + await client.send_message(message.channel, "Are you sure you want to remove your system? If so, reply to this message with the system's ID (`{}`).".format(system["hid"])) + + msg = await client.wait_for_message(author=message.author, channel=message.channel) + if msg.content == system["hid"]: + await db.remove_system(conn, system_id=system["id"]) + return True, "System removed." + else: + return True, "System removal cancelled." + + +@command(cmd="pk;system", subcommand="link", usage="", description="Links another account to your system.") +async def system_link(conn, message, args): + system = await db.get_system_by_account(conn, message.author.id) + + if system is None: + return False, "No system is registered to this account." + + if len(args) == 0: + return False + + # Find account to link + linkee = await parse_mention(args[0]) + if not linkee: + return False, "Account not found." + + # Make sure account doesn't already have a system + account_system = await db.get_system_by_account(conn, linkee.id) + if account_system: + return False, "Account is already linked to a system (`{}`)".format(account_system["hid"]) + + # Send confirmation message + msg = await client.send_message(message.channel, "{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention)) + await client.add_reaction(msg, "✅") + await client.add_reaction(msg, "❌") + + reaction = await client.wait_for_reaction(emoji=["✅", "❌"], message=msg, user=linkee) + # If account to be linked confirms... + if reaction.reaction.emoji == "✅": + async with conn.transaction(): + # Execute the link + await db.link_account(conn, system_id=system["id"], account_id=linkee.id) + return True, "Account linked to system." + else: + await client.clear_reactions(msg) + return False, "Account link cancelled." + + +@command(cmd="pk;system", subcommand="unlink", description="Unlinks your system from this account. There must be at least one other account linked.") +async def system_unlink(conn, message, args): + system = await db.get_system_by_account(conn, message.author.id) + + if system is None: + return False, "No system is registered to this account." + + # Make sure you can't unlink every account + linked_accounts = await db.get_linked_accounts(conn, system_id=system["id"]) + if len(linked_accounts) == 1: + return False, "This is the only account on your system, so you can't unlink it." + + async with conn.transaction(): + await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id) + return True, "Account unlinked." + +@command(cmd="pk;member", subcommand="new", usage="", description="Adds a new member to your system.") +async def new_member(conn, message, args): + system = await db.get_system_by_account(conn, message.author.id) + + if system is None: + return False, "No system is registered to this account." + + if len(args) == 0: + return False + + name = " ".join(args) + async with conn.transaction(): + # TODO: figure out what to do if this errors out on collision on generate_hid + hid = generate_hid() + + # Insert member row + await db.create_member(conn, system_id=system["id"], member_name=name, member_hid=hid) + return True, "Member \"{}\" (`{}`) registered!".format(name, hid) + +@member_command(cmd="pk;member", subcommand="info", description="Shows information about a system member.", system_only=False) +async def member_info(conn, message, member, args): + await client.send_message(message.channel, embed=await generate_member_info_card(conn, member)) + return True + +@member_command(cmd="pk;member", subcommand="color", usage="[color]", description="Updates a member's associated color. Leave blank to clear.") +async def member_color(conn, message, member, args): + if len(args) == 0: + color = None + else: + match = re.fullmatch("#?([0-9a-f]{6})", args[0]) + if not match: + return False, "Color must be a valid hex color (eg. #ff0000)" + + color = match.group(1) + + async with conn.transaction(): + await db.update_member_field(conn, member_id=member["id"], field="color", value=color) + return True, "Color updated to #{}.".format(color) if color else "Color cleared." + +@member_command(cmd="pk;member", subcommand="pronouns", usage="[pronouns]", description="Updates a member's pronouns. Leave blank to clear.") +async def member_pronouns(conn, message, member, args): + if len(args) == 0: + pronouns = None + else: + pronouns = " ".join(args) + + async with conn.transaction(): + await db.update_member_field(conn, member_id=member["id"], field="pronouns", value=pronouns) + return True, "Pronouns set to {}".format(pronouns) if pronouns else "Pronouns cleared." + +@member_command(cmd="pk;member", subcommand="birthdate", usage="[birthdate]", description="Updates a member's birthdate. Must be in ISO-8601 format (eg. 1999-07-25). Leave blank to clear.") +async def member_birthday(conn, message, member, args): + if len(args) == 0: + new_date = None + else: + # Parse date + try: + new_date = datetime.strptime(args[0], "%Y-%m-%d").date() + except ValueError: + return False, "Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25)." + + async with conn.transaction(): + await db.update_member_field(conn, member_id=member["id"], field="birthday", value=new_date) + return True, "Birthdate set to {}".format(new_date) if new_date else "Birthdate cleared." + +@member_command(cmd="pk;member", subcommand="description", description="Updates a member's description. Add \"clear\" to clear.") +async def member_description(conn, message, member, args): + if len(args) > 0 and args[0] == "clear": + new_description = None + else: + new_description = await text_input(message, member["name"]) + + if not new_description: + return True, "Description update cancelled." + + async with conn.transaction(): + await db.update_member_field(conn, member_id=member["id"], field="description", value=new_description) + return True, "Description set." if new_description else "Description cleared." + +@member_command(cmd="pk;member", subcommand="remove", description="Removes a member from your system.") +async def member_remove(conn, message, member, args): + await client.send_message(message.channel, "Are you sure you want to remove {}? If so, reply to this message with the member's name.".format(member["name"])) + + msg = await client.wait_for_message(author=message.author, channel=message.channel) + if msg.content == member["name"]: + await db.delete_member(conn, member_id=member["id"]) + return True, "Member removed." + else: + return True, "Member removal cancelled." + +@member_command(cmd="pk;member", subcommand="avatar", usage="[user|url]", description="Updates a member's avatar. Can be an account mention (which will use that account's avatar), or a link to an image. Leave blank to clear.") +async def member_avatar(conn, message, member, args): + if len(args) == 0: + avatar_url = None + else: + user = await parse_mention(args[0]) + if user: + # Set the avatar to the mentioned user's avatar + # Discord doesn't like webp, but also hosts png alternatives + avatar_url = user.avatar_url.replace(".webp", ".png") + else: + # Validate URL + u = urlparse(args[0]) + if u.scheme in ["http", "https"] and u.netloc and u.path: + avatar_url = args[0] + else: + return False, "Invalid URL." + + async with conn.transaction(): + await db.update_member_field(conn, member_id=member["id"], field="avatar_url", value=avatar_url) + return True, "Avatar set." if avatar_url else "Avatar cleared." + + +@member_command(cmd="pk;member", subcommand="proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).") +async def member_proxy(conn, message, member, args): + if len(args) == 0: + prefix, suffix = None, None + else: + # Sanity checking + example = " ".join(args) + if "text" not in example: + return False, "Example proxy message must contain the string 'text'." + + if example.count("text") != 1: + return False, "Example proxy message must contain the string 'text' exactly once." + + # Extract prefix and suffix + prefix = example[:example.index("text")].strip() + suffix = example[example.index("text")+4:].strip() + logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix)) + + # DB stores empty strings as None, make that work + if not prefix: + prefix = None + if not suffix: + suffix = None + + async with conn.transaction(): + await db.update_member_field(conn, member_id=member["id"], field="prefix", value=prefix) + await db.update_member_field(conn, member_id=member["id"], field="suffix", value=suffix) + return True, "Proxy settings updated." if prefix or suffix else "Proxy settings cleared." + +@command(cmd="pk;message", subcommand=None, usage="", description="Shows information about a proxied message. Requires the message ID.") +async def message_info(conn, message, args): + try: + mid = int(args[0]) + except ValueError: + return False + + # Find the message in the DB + message_row = await db.get_message(conn, mid) + if not message_row: + return False, "Message not found." + + # Find the actual message object + channel = client.get_channel(str(message_row["channel"])) + message = await client.get_message(channel, str(message_row["mid"])) + + # Get the original sender of the message + original_sender = await client.get_user_info(str(message_row["sender"])) + + # Get sender member and system + member = await db.get_member(conn, message_row["member"]) + system = await db.get_system(conn, member["system"]) + + embed = discord.Embed() + embed.timestamp = message.timestamp + embed.colour = discord.Colour.blue() + + if system["name"]: + system_value = "`{}`: {}".format(system["hid"], system["name"]) + else: + system_value = "`{}`".format(system["hid"]) + embed.add_field(name="System", value=system_value) + embed.add_field(name="Member", value="`{}`: {}".format(member["hid"], member["name"])) + embed.add_field(name="Sent by", value="{}#{}".format(original_sender.name, original_sender.discriminator)) + embed.add_field(name="Content", value=message.clean_content, inline=False) + + embed.set_author(name=member["name"], url=member["avatar_url"]) + + await client.send_message(message.channel, embed=embed) + return True + +@command(cmd="pk;help", subcommand=None, usage="[system|member|message]", description="Shows this help message.") +async def show_help(conn, message, args): + embed = discord.Embed() + embed.colour = discord.Colour.blue() + embed.title = "PluralKit Help" + embed.set_footer(text="<> denotes mandatory arguments, [] denotes optional arguments") + + if len(args) > 0 and ("pk;" + args[0]) in command_map: + cmds = ["", ("pk;" + args[0], command_map["pk;" + args[0]])] + else: + cmds = command_map.items() + + for cmd, subcommands in cmds: + for subcmd, (_, usage, description) in subcommands.items(): + embed.add_field(name="{} {} {}".format(cmd, subcmd or "", usage or ""), value=description, inline=False) + + await client.send_message(message.channel, embed=embed) + return True \ No newline at end of file diff --git a/bot/pluralkit/db.py b/bot/pluralkit/db.py new file mode 100644 index 00000000..eeda07c2 --- /dev/null +++ b/bot/pluralkit/db.py @@ -0,0 +1,193 @@ +import time + +import asyncpg +import asyncpg.exceptions + +from pluralkit.bot import logger + +async def connect(): + while True: + try: + return await asyncpg.create_pool(user="postgres", password="postgres", database="postgres", host="db") + except (ConnectionError, asyncpg.exceptions.CannotConnectNowError): + pass + +def db_wrap(func): + async def inner(*args, **kwargs): + before = time.perf_counter() + res = await func(*args, **kwargs) + after = time.perf_counter() + + logger.debug(" - DB took {:.2f} ms".format((after - before) * 1000)) + return res + return inner + +webhook_cache = {} + +@db_wrap +async def create_system(conn, system_name: str, system_hid: str): + logger.debug("Creating system (name={}, hid={})".format(system_name, system_hid)) + return await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid) + +@db_wrap +async def remove_system(conn, system_id: int): + logger.debug("Deleting system (id={})".format(system_id)) + await conn.execute("delete from systems where id = $1", system_id) + +@db_wrap +async def create_member(conn, system_id: int, member_name: str, member_hid: str): + logger.debug("Creating member (system={}, name={}, hid={})".format(system_id, member_name, member_hid)) + return await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid) + +@db_wrap +async def delete_member(conn, member_id: int): + logger.debug("Deleting member (id={})".format(member_id)) + await conn.execute("update switches set member = null, member_del = true where member = $1", member_id) + await conn.execute("delete from members where id = $1", member_id) + +@db_wrap +async def link_account(conn, system_id: int, account_id: str): + logger.debug("Linking account (account_id={}, system_id={})".format(account_id, system_id)) + await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id) + +@db_wrap +async def unlink_account(conn, system_id: int, account_id: str): + logger.debug("Unlinking account (account_id={}, system_id={})".format(account_id, system_id)) + await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id) + +@db_wrap +async def get_linked_accounts(conn, system_id: int): + return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)] + +@db_wrap +async def get_system_by_account(conn, account_id: str): + return await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id)) + +@db_wrap +async def get_system_by_hid(conn, system_hid: str): + return await conn.fetchrow("select * from systems where hid = $1", system_hid) + +@db_wrap +async def get_system(conn, system_id: int): + return await conn.fetchrow("select * from systems where id = $1", system_id) + +@db_wrap +async def get_member_by_name(conn, system_id: int, member_name: str): + return await conn.fetchrow("select * from members where system = $1 and name = $2", system_id, member_name) + +@db_wrap +async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str): + return await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid) + +@db_wrap +async def get_member_by_hid(conn, member_hid: str): + return await conn.fetchrow("select * from members where hid = $1", member_hid) + +@db_wrap +async def get_member(conn, member_id: int): + return await conn.fetchrow("select * from members where id = $1", member_id) + +@db_wrap +async def get_message(conn, message_id: str): + return await conn.fetchrow("select * from messages where mid = $1", message_id) + +@db_wrap +async def update_system_field(conn, system_id: int, field: str, value): + logger.debug("Updating system field (id={}, {}={})".format(system_id, field, value)) + await conn.execute("update systems set {} = $1 where id = $2".format(field), value, system_id) + +@db_wrap +async def update_member_field(conn, member_id: int, field: str, value): + logger.debug("Updating member field (id={}, {}={})".format(member_id, field, value)) + await conn.execute("update members set {} = $1 where id = $2".format(field), value, member_id) + +@db_wrap +async def get_all_members(conn, system_id: int): + return await conn.fetch("select * from members where system = $1", system_id) + +@db_wrap +async def get_members_exceeding(conn, system_id: int, length: int): + return await conn.fetch("select * from members where system = $1 and length(name) >= $2", system_id, length) + +@db_wrap +async def get_webhook(conn, channel_id: str): + if channel_id in webhook_cache: + return webhook_cache[channel_id] + res = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id)) + webhook_cache[channel_id] = res + return res + +@db_wrap +async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str): + logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(channel_id, webhook_id, webhook_token)) + await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token) + +@db_wrap +async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str): + logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(message_id, channel_id, member_id, sender_id)) + await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id)) + +@db_wrap +async def get_members_by_account(conn, account_id: str): + # Returns a "chimera" object + return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id)) + +@db_wrap +async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str): + await conn.fetchrow("select * from messages where mid = $1 and sender = $2", int(message_id), int(sender_id)) + +@db_wrap +async def delete_message(conn, message_id: str): + logger.debug("Deleting message (id={})".format(message_id)) + await conn.execute("delete from messages where mid = $1", int(message_id)) + +async def create_tables(conn): + await conn.execute("""create table if not exists systems ( + id serial primary key, + hid char(5) unique not null, + name text, + description text, + tag text, + created timestamp not null default current_timestamp + )""") + await conn.execute("""create table if not exists members ( + id serial primary key, + hid char(5) unique not null, + system serial not null references systems(id) on delete cascade, + color char(6), + avatar_url text, + name text not null, + birthday date, + pronouns text, + description text, + prefix text, + suffix text, + created timestamp not null default current_timestamp + )""") + await conn.execute("""create table if not exists accounts ( + uid bigint primary key, + system serial not null references systems(id) on delete cascade + )""") + await conn.execute("""create table if not exists messages ( + mid bigint primary key, + channel bigint not null, + member serial not null references members(id) on delete cascade, + sender bigint not null references accounts(uid) + )""") + await conn.execute("""create table if not exists switches ( + id serial primary key, + system serial not null references systems(id) on delete cascade, + member serial references members(id) on delete restrict, + timestamp timestamp not null default current_timestamp, + member_del bool default false + )""") + await conn.execute("""create table if not exists webhooks ( + channel bigint primary key, + webhook bigint not null, + token text not null + )""") + await conn.execute("""create table if not exists servers ( + id bigint primary key, + cmd_chans bigint[], + proxy_chans bigint[] + )""") \ No newline at end of file diff --git a/bot/pluralkit/proxy.py b/bot/pluralkit/proxy.py new file mode 100644 index 00000000..f99e6b9d --- /dev/null +++ b/bot/pluralkit/proxy.py @@ -0,0 +1,93 @@ +import os +import time + +import aiohttp + +from pluralkit import db +from pluralkit.bot import client, logger + +async def get_webhook(conn, channel): + async with conn.transaction(): + # Try to find an existing webhook + hook_row = await db.get_webhook(conn, channel_id=channel.id) + # There's none, we'll make one + if not hook_row: + async with aiohttp.ClientSession() as session: + req_data = {"name": "PluralKit Proxy Webhook"} + req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])} + + async with session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id), json=req_data, headers=req_headers) as resp: + data = await resp.json() + hook_id = data["id"] + token = data["token"] + + # Insert new hook into DB + await db.add_webhook(conn, channel_id=channel.id, webhook_id=hook_id, webhook_token=token) + return hook_id, token + + return hook_row["webhook"], hook_row["token"] + +async def proxy_message(conn, member, message, inner): + logger.debug("Proxying message '{}' for member {}".format(inner, member["hid"])) + # Delete the original message + await client.delete_message(message) + + # Get the webhook details + hook_id, hook_token = await get_webhook(conn, message.channel) + async with aiohttp.ClientSession() as session: + req_data = { + "username": "{} {}".format(member["name"], member["tag"] or "").strip(), + "avatar_url": member["avatar_url"], + "content": inner + } + req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])} + # And send the message + async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), json=req_data, headers=req_headers) as resp: + resp_data = await resp.json() + logger.debug("Discord webhook response: {}".format(resp_data)) + + # Insert new message details into the DB + await db.add_message(conn, message_id=resp_data["id"], channel_id=message.channel.id, member_id=member["id"], sender_id=message.author.id) + +async def handle_proxying(conn, message): + # Big fat query to find every member associated with this account + # Returned member object has a few more keys (system tag, for example) + members = await db.get_members_by_account(conn, account_id=message.author.id) + + # Sort by specificity (members with both prefix and suffix go higher) + members = sorted(members, key=lambda x: int(bool(x["prefix"])) + int(bool(x["suffix"])), reverse=True) + + msg = message.content + for member in members: + # If no proxy details are configured, skip + if not member["prefix"] and not member["suffix"]: + continue + + # Database stores empty strings as null, fix that here + prefix = member["prefix"] or "" + suffix = member["suffix"] or "" + + # If we have a match, proxy the message + if msg.startswith(prefix) and msg.endswith(suffix): + # Extract the actual message contents sans tags + if suffix: + inner_message = message.content[len(prefix):-len(suffix)].strip() + else: + # Slicing to -0 breaks, don't do that + inner_message = message.content[len(prefix):].strip() + + await proxy_message(conn, member, message, inner_message) + break + + + +async def handle_reaction(conn, reaction, user): + if reaction.emoji == "❌": + async with conn.transaction(): + # Find the message in the DB, and make sure it's sent by the user who reacted + message = await db.get_message_by_sender_and_id(conn, message_id=reaction.message.id, sender_id=user.id) + + if message: + # If so, delete the message and remove it from the DB + await db.delete_message(conn, message["mid"]) + await client.delete_message(reaction.message) \ No newline at end of file diff --git a/bot/pluralkit/utils.py b/bot/pluralkit/utils.py new file mode 100644 index 00000000..b0aa7175 --- /dev/null +++ b/bot/pluralkit/utils.py @@ -0,0 +1,211 @@ +import random +import re +import string + +import asyncio +import asyncpg +import discord + +from pluralkit import db +from pluralkit.bot import client, logger + +def generate_hid() -> str: + return "".join(random.choices(string.ascii_lowercase, k=5)) + +async def parse_mention(mention: str) -> discord.User: + # First try matching mention format + match = re.fullmatch("<@!?(\\d+)>", mention) + if match: + try: + return await client.get_user_info(match.group(1)) + except discord.NotFound: + return None + + # Then try with just ID + try: + return await client.get_user_info(str(int(mention))) + except (ValueError, discord.NotFound): + return None + +async def get_system_fuzzy(conn, key) -> asyncpg.Record: + if isinstance(key, discord.User): + return await db.get_system_by_account(conn, account_id=key.id) + + if isinstance(key, str) and len(key) == 5: + return await db.get_system_by_hid(conn, system_hid=key) + + system = parse_mention(key) + + if system: + return system + return None + +async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> asyncpg.Record: + # First search by hid + if system_only: + member = await db.get_member_by_hid_in_system(conn, system_id=system_id, member_hid=key) + else: + member = await db.get_member_by_hid(conn, member_hid=key) + if member is not None: + return member + + # Then search by name, if we have a system + if system_id: + member = await db.get_member_by_name(conn, system_id=system_id, member_name=key) + if member is not None: + return member + +command_map = {} + +# Command wrapper +# Return True for success, return False for failure +# Second parameter is the message it'll send. If just False, will print usage +def command(cmd, subcommand, usage=None, description=None): + def wrap(func): + async def wrapper(conn, message, args): + res = await func(conn, message, args) + + if res is not None: + if not isinstance(res, tuple): + success, msg = res, None + else: + success, msg = res + + if not success and not msg: + # Failure, no message, print usage + usage_embed = discord.Embed() + usage_embed.colour = discord.Colour.blue() + usage_embed.add_field(name="Usage", value=usage, inline=False) + + await client.send_message(message.channel, embed=usage_embed) + elif not success: + # Failure, print message + error_embed = discord.Embed() + error_embed.colour = discord.Colour.dark_red() + error_embed.description = msg + await client.send_message(message.channel, embed=error_embed) + elif msg: + # Success, print message + success_embed = discord.Embed() + success_embed.colour = discord.Colour.blue() + success_embed.description = msg + await client.send_message(message.channel, embed=success_embed) + # Success, don't print anything + if cmd not in command_map: + command_map[cmd] = {} + if subcommand not in command_map[cmd]: + command_map[cmd][subcommand] = {} + + command_map[cmd][subcommand] = (wrapper, usage, description) + return wrapper + return wrap + +# Member command wrapper +# Tries to find member by first argument +# If system_only=False, allows members from other systems by hid +def member_command(cmd, subcommand, usage=None, description=None, system_only=True): + def wrap(func): + async def wrapper(conn, message, args): + # Return if no member param + if len(args) == 0: + return False + + # If system_only, we need a system to check + system = await db.get_system_by_account(conn, message.author.id) + if system_only and system is None: + return False, "No system is registered to this account." + + # System is allowed to be none if not system_only + system_id = system["id"] if system else None + # And find member by key + member = await get_member_fuzzy(conn, system_id=system_id, key=args[0], system_only=system_only) + + if member is None: + return False, "Can't find member \"{}\".".format(args[0]) + + return await func(conn, message, member, args[1:]) + return command(cmd=cmd, subcommand=subcommand, usage=usage, description=description)(wrapper) + return wrap + +async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Embed: + card = discord.Embed() + + if system["name"]: + card.title = system["name"] + + if system["description"]: + card.add_field(name="Description", value=system["description"], inline=False) + + if system["tag"]: + card.add_field(name="Tag", value=system["tag"]) + + # Get names of all linked accounts + async def get_name(account_id): + account = await client.get_user_info(account_id) + return "{}#{}".format(account.name, account.discriminator) + + account_name_futures = [] + for account_id in await db.get_linked_accounts(conn, system_id=system["id"]): + account_name_futures.append(get_name(account_id)) + # Run in parallel + account_names = await asyncio.gather(*account_name_futures) + + card.add_field(name="Linked accounts", value=", ".join(account_names)) + + # Get names of all members + member_texts = [] + for member in await db.get_all_members(conn, system_id=system["id"]): + member_texts.append("`{}`: {}".format(member["hid"], member["name"])) + + if len(member_texts) > 0: + card.add_field(name="Members", value="\n".join(member_texts), inline=False) + + card.set_footer(text="System ID: {}".format(system["hid"])) + return card + +async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Embed: + card = discord.Embed() + card.set_author(name=member["name"], icon_url=member["avatar_url"]) + + if member["color"]: + card.colour = int(member["color"], 16) + + if member["birthday"]: + card.add_field(name="Birthdate", value=member["birthday"].strftime("%b %d, %Y")) + + if member["pronouns"]: + card.add_field(name="Pronouns", value=member["pronouns"]) + + if member["prefix"] or member["suffix"]: + prefix = member["prefix"] or "" + suffix = member["suffix"] or "" + card.add_field(name="Proxy Tags", value="{}text{}".format(prefix, suffix)) + + if member["description"]: + card.add_field(name="Description", value=member["description"], inline=False) + + # Get system name and hid + system = await db.get_system(conn, system_id=member["system"]) + if system["name"]: + system_value = "`{}`: {}".format(system["hid"], system["name"]) + else: + system_value = "`{}`".format(system["hid"]) + card.add_field(name="System", value=system_value, inline=False) + + card.set_footer(text="System ID: {} | Member ID: {}".format(system["hid"], member["hid"])) + return card + +async def text_input(message, subject): + await client.send_message(message.channel, "Reply in this channel with the new description you want to set for {}.".format(subject)) + msg = await client.wait_for_message(author=message.author, channel=message.channel) + + await client.send_message(message.channel, "Alright. When you're happy with the new description, click the ✅ reaction. To cancel, click the ❌ reaction.") + await client.add_reaction(msg, "✅") + await client.add_reaction(msg, "❌") + + reaction = await client.wait_for_reaction(emoji=["✅", "❌"], message=msg, user=message.author) + if reaction.reaction.emoji == "✅": + return msg.content + else: + await client.clear_reactions(msg) + return None \ No newline at end of file diff --git a/bot/requirements.txt b/bot/requirements.txt new file mode 100644 index 00000000..cb088c81 --- /dev/null +++ b/bot/requirements.txt @@ -0,0 +1,4 @@ +discord.py +asyncpg +aiohttp +aioinflux \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..846d36e7 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,27 @@ +version: '3' +services: + bot: + build: bot + depends_on: + - db + environment: + - TOKEN + db: + image: postgres:alpine + volumes: + - db_data:/var/lib/postgres + restart: always + influx: + image: influxdb:alpine + volumes: + - /var/lib/influxdb + environment: + - INFLUXDB_GRAPHITE_ENABLED=true + restart: always + grafana: + image: grafana/grafana + ports: + - "3000:3000" + restart: always +volumes: + db_data: \ No newline at end of file