diff --git a/bot/pluralkit/commands.py b/bot/pluralkit/commands.py index 13affd9a..443b4098 100644 --- a/bot/pluralkit/commands.py +++ b/bot/pluralkit/commands.py @@ -7,7 +7,7 @@ import humanize from pluralkit import db from pluralkit.bot import client, logger -from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed +from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed, parse_channel_mention @command(cmd="pk;system", subcommand=None, description="Shows information about your system.") async def this_system_info(conn, message, args): @@ -495,6 +495,23 @@ async def switch_out(conn, message, args): await db.add_switch(conn, system_id=system["id"], member_id=None) return True, "Switch-out registered." +@command(cmd="pk;mod", subcommand="log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.") +async def set_log(conn, message, args): + if not message.author.server_permissions.administrator: + return False, "You must be a server administrator to use this command." + + server = message.server + if len(args) == 0: + channel_id = None + else: + channel = parse_channel_mention(args[0], server=server) + if not channel: + return False, "Channel not found." + channel_id = channel.id + + await db.update_server(conn, server.id, logging_channel_id=channel_id) + return True, "Updated logging channel." if channel_id else "Cleared logging channel." + def make_help(cmds): embed = discord.Embed() embed.colour = discord.Colour.blue() diff --git a/bot/pluralkit/db.py b/bot/pluralkit/db.py index 2c590991..2f48d95e 100644 --- a/bot/pluralkit/db.py +++ b/bot/pluralkit/db.py @@ -24,10 +24,6 @@ def db_wrap(func): return res return inner - -webhook_cache = {} - - @db_wrap async def create_system(conn, system_name: str, system_hid: str): logger.debug("Creating system (name={}, hid={})".format( @@ -140,11 +136,7 @@ async def get_members_exceeding(conn, system_id: int, length: int): @db_wrap async def get_webhook(conn, channel_id: str): - if channel_id in webhook_cache: - return webhook_cache[channel_id] - res = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id)) - webhook_cache[channel_id] = res - return res + return await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id)) @db_wrap @@ -164,7 +156,7 @@ async def add_message(conn, message_id: str, channel_id: str, member_id: int, se @db_wrap async def get_members_by_account(conn, account_id: str): # Returns a "chimera" object - return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id)) + return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag, systems.name as system_name, systems.hid as system_hid from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id)) @db_wrap @@ -190,6 +182,16 @@ async def add_switch(conn, system_id: int, member_id: int): logger.debug("Adding switch (system={}, member={})".format(system_id, member_id)) return await conn.execute("insert into switches (system, member) values ($1, $2)", system_id, member_id) +@db_wrap +async def get_server_info(conn, server_id: str): + return await conn.fetchrow("select * from servers where id = $1", int(server_id)) + +@db_wrap +async def update_server(conn, server_id: str, logging_channel_id: str): + logging_channel_id = int(logging_channel_id) if logging_channel_id else None + logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id)) + await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id) + async def create_tables(conn): await conn.execute("""create table if not exists systems ( id serial primary key, @@ -237,6 +239,5 @@ async def create_tables(conn): )""") await conn.execute("""create table if not exists servers ( id bigint primary key, - cmd_chans bigint[], - proxy_chans bigint[] + log_channel bigint )""") diff --git a/bot/pluralkit/proxy.py b/bot/pluralkit/proxy.py index 5a87b690..587971e9 100644 --- a/bot/pluralkit/proxy.py +++ b/bot/pluralkit/proxy.py @@ -2,10 +2,31 @@ import os import time import aiohttp +import discord from pluralkit import db from pluralkit.bot import client, logger +async def log_message(original_message, hook_message, member, log_channel): + author_name = member["name"] + if member["system_name"]: + author_name += " ({})".format(member["system_name"]) + + embed = discord.Embed() + embed.colour = discord.Colour.blue() + embed.set_author(name=author_name, icon_url=member["avatar_url"] or discord.Embed.Empty) + embed.add_field(name="Member", value=member["name"]) + embed.add_field(name="Sender", value="{}#{}".format(original_message.author.name, original_message.author.discriminator)) + if member["system_name"]: + embed.add_field(name="System", value=member["system_name"]) + embed.add_field(name="Content", value=hook_message.clean_content) + embed.timestamp = hook_message.timestamp + embed.set_footer(text="System ID: {} | Member ID: {} | Sender ID: {} | Message ID: {}".format(member["system_hid"], member["hid"], original_message.author.id, hook_message.id)) + + if member["avatar_url"]: + embed.set_thumbnail(url=member["avatar_url"]) + + await client.send_message(log_channel, embed=embed) async def get_webhook(conn, channel): async with conn.transaction(): @@ -29,29 +50,48 @@ async def get_webhook(conn, channel): return hook_row["webhook"], hook_row["token"] - -async def proxy_message(conn, member, message, inner): - logger.debug("Proxying message '{}' for member {}".format( - inner, member["hid"])) - # Delete the original message - await client.delete_message(message) - - # Get the webhook details - hook_id, hook_token = await get_webhook(conn, message.channel) +async def send_hook_message(member, text, hook_id, hook_token): async with aiohttp.ClientSession() as session: + # Set up parameters req_data = { "username": "{} {}".format(member["name"], member["tag"] or "").strip(), "avatar_url": member["avatar_url"], - "content": inner + "content": text } req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])} - # And send the message - async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), json=req_data, headers=req_headers) as resp: - resp_data = await resp.json() - logger.debug("Discord webhook response: {}".format(resp_data)) - # Insert new message details into the DB - await db.add_message(conn, message_id=resp_data["id"], channel_id=message.channel.id, member_id=member["id"], sender_id=message.author.id) + # Send request + async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), json=req_data, headers=req_headers) as resp: + if resp.status == 200: + resp_data = await resp.json() + return discord.Message(reactions=[], **resp_data) + else: + # Fake a Discord exception, also because #yolo + raise discord.HTTPException(resp, await resp.text()) + + +async def proxy_message(conn, member, trigger_message, inner): + logger.debug("Proxying message '{}' for member {}".format( + inner, member["hid"])) + # Delete the original message + await client.delete_message(trigger_message) + + # Get the webhook details + hook_id, hook_token = await get_webhook(conn, trigger_message.channel) + + # And send the message + hook_message = await send_hook_message(member, inner, hook_id, hook_token) + + # Insert new message details into the DB + await db.add_message(conn, message_id=hook_message.id, channel_id=trigger_message.channel.id, member_id=member["id"], sender_id=trigger_message.author.id) + + # Check server info for a log channel + server_info = await db.get_server_info(conn, trigger_message.server.id) + if server_info and server_info["log_channel"]: + channel = trigger_message.server.get_channel(str(server_info["log_channel"])) + if channel: + # Log the message + await log_message(trigger_message, hook_message, member, channel) async def handle_proxying(conn, message): diff --git a/bot/pluralkit/utils.py b/bot/pluralkit/utils.py index 6e871423..8675af5c 100644 --- a/bot/pluralkit/utils.py +++ b/bot/pluralkit/utils.py @@ -30,6 +30,17 @@ async def parse_mention(mention: str) -> discord.User: except (ValueError, discord.NotFound): return None +def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel: + match = re.fullmatch("<#(\\d+)>", mention) + if match: + return server.get_channel(match.group(1)) + + try: + return server.get_channel(str(int(mention))) + except ValueError: + return None + + async def get_system_fuzzy(conn, key) -> asyncpg.Record: if isinstance(key, discord.User): @@ -196,8 +207,6 @@ async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Emb # Get system name and hid system = await db.get_system(conn, system_id=member["system"]) - if system["name"]: - system_value = "{}".format(system["name"]) if member["color"]: card.colour = int(member["color"], 16)