Add message logging feature
This commit is contained in:
		| @@ -7,7 +7,7 @@ import humanize | |||||||
|  |  | ||||||
| from pluralkit import db | from pluralkit import db | ||||||
| from pluralkit.bot import client, logger | from pluralkit.bot import client, logger | ||||||
| from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed | from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed, parse_channel_mention | ||||||
|  |  | ||||||
| @command(cmd="pk;system", subcommand=None, description="Shows information about your system.") | @command(cmd="pk;system", subcommand=None, description="Shows information about your system.") | ||||||
| async def this_system_info(conn, message, args): | async def this_system_info(conn, message, args): | ||||||
| @@ -495,6 +495,23 @@ async def switch_out(conn, message, args): | |||||||
|     await db.add_switch(conn, system_id=system["id"], member_id=None) |     await db.add_switch(conn, system_id=system["id"], member_id=None) | ||||||
|     return True, "Switch-out registered." |     return True, "Switch-out registered." | ||||||
|  |  | ||||||
|  | @command(cmd="pk;mod", subcommand="log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.") | ||||||
|  | async def set_log(conn, message, args): | ||||||
|  |     if not message.author.server_permissions.administrator: | ||||||
|  |         return False, "You must be a server administrator to use this command." | ||||||
|  |      | ||||||
|  |     server = message.server | ||||||
|  |     if len(args) == 0: | ||||||
|  |         channel_id = None | ||||||
|  |     else: | ||||||
|  |         channel = parse_channel_mention(args[0], server=server) | ||||||
|  |         if not channel: | ||||||
|  |             return False, "Channel not found." | ||||||
|  |         channel_id = channel.id | ||||||
|  |  | ||||||
|  |     await db.update_server(conn, server.id, logging_channel_id=channel_id) | ||||||
|  |     return True, "Updated logging channel." if channel_id else "Cleared logging channel." | ||||||
|  |  | ||||||
| def make_help(cmds): | def make_help(cmds): | ||||||
|     embed = discord.Embed() |     embed = discord.Embed() | ||||||
|     embed.colour = discord.Colour.blue() |     embed.colour = discord.Colour.blue() | ||||||
|   | |||||||
| @@ -24,10 +24,6 @@ def db_wrap(func): | |||||||
|         return res |         return res | ||||||
|     return inner |     return inner | ||||||
|  |  | ||||||
|  |  | ||||||
| webhook_cache = {} |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def create_system(conn, system_name: str, system_hid: str): | async def create_system(conn, system_name: str, system_hid: str): | ||||||
|     logger.debug("Creating system (name={}, hid={})".format( |     logger.debug("Creating system (name={}, hid={})".format( | ||||||
| @@ -140,11 +136,7 @@ async def get_members_exceeding(conn, system_id: int, length: int): | |||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_webhook(conn, channel_id: str): | async def get_webhook(conn, channel_id: str): | ||||||
|     if channel_id in webhook_cache: |     return await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id)) | ||||||
|         return webhook_cache[channel_id] |  | ||||||
|     res = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id)) |  | ||||||
|     webhook_cache[channel_id] = res |  | ||||||
|     return res |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| @@ -164,7 +156,7 @@ async def add_message(conn, message_id: str, channel_id: str, member_id: int, se | |||||||
| @db_wrap | @db_wrap | ||||||
| async def get_members_by_account(conn, account_id: str): | async def get_members_by_account(conn, account_id: str): | ||||||
|     # Returns a "chimera" object |     # Returns a "chimera" object | ||||||
|     return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id)) |     return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag, systems.name as system_name, systems.hid as system_hid from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| @@ -190,6 +182,16 @@ async def add_switch(conn, system_id: int, member_id: int): | |||||||
|     logger.debug("Adding switch (system={}, member={})".format(system_id, member_id)) |     logger.debug("Adding switch (system={}, member={})".format(system_id, member_id)) | ||||||
|     return await conn.execute("insert into switches (system, member) values ($1, $2)", system_id, member_id) |     return await conn.execute("insert into switches (system, member) values ($1, $2)", system_id, member_id) | ||||||
|  |  | ||||||
|  | @db_wrap | ||||||
|  | async def get_server_info(conn, server_id: str): | ||||||
|  |     return await conn.fetchrow("select * from servers where id = $1", int(server_id)) | ||||||
|  |  | ||||||
|  | @db_wrap | ||||||
|  | async def update_server(conn, server_id: str, logging_channel_id: str): | ||||||
|  |     logging_channel_id = int(logging_channel_id) if logging_channel_id else None | ||||||
|  |     logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id)) | ||||||
|  |     await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id) | ||||||
|  |  | ||||||
| async def create_tables(conn): | async def create_tables(conn): | ||||||
|     await conn.execute("""create table if not exists systems ( |     await conn.execute("""create table if not exists systems ( | ||||||
|         id          serial primary key, |         id          serial primary key, | ||||||
| @@ -237,6 +239,5 @@ async def create_tables(conn): | |||||||
|     )""") |     )""") | ||||||
|     await conn.execute("""create table if not exists servers ( |     await conn.execute("""create table if not exists servers ( | ||||||
|         id          bigint primary key, |         id          bigint primary key, | ||||||
|         cmd_chans   bigint[], |         log_channel bigint | ||||||
|         proxy_chans bigint[] |  | ||||||
|     )""") |     )""") | ||||||
|   | |||||||
| @@ -2,10 +2,31 @@ import os | |||||||
| import time | import time | ||||||
|  |  | ||||||
| import aiohttp | import aiohttp | ||||||
|  | import discord | ||||||
|  |  | ||||||
| from pluralkit import db | from pluralkit import db | ||||||
| from pluralkit.bot import client, logger | from pluralkit.bot import client, logger | ||||||
|  |  | ||||||
|  | async def log_message(original_message, hook_message, member, log_channel): | ||||||
|  |     author_name = member["name"] | ||||||
|  |     if member["system_name"]: | ||||||
|  |         author_name += " ({})".format(member["system_name"]) | ||||||
|  |  | ||||||
|  |     embed = discord.Embed() | ||||||
|  |     embed.colour = discord.Colour.blue() | ||||||
|  |     embed.set_author(name=author_name, icon_url=member["avatar_url"] or discord.Embed.Empty) | ||||||
|  |     embed.add_field(name="Member", value=member["name"]) | ||||||
|  |     embed.add_field(name="Sender", value="{}#{}".format(original_message.author.name, original_message.author.discriminator)) | ||||||
|  |     if member["system_name"]: | ||||||
|  |         embed.add_field(name="System", value=member["system_name"]) | ||||||
|  |     embed.add_field(name="Content", value=hook_message.clean_content) | ||||||
|  |     embed.timestamp = hook_message.timestamp | ||||||
|  |     embed.set_footer(text="System ID: {} | Member ID: {} | Sender ID: {} | Message ID: {}".format(member["system_hid"], member["hid"], original_message.author.id, hook_message.id)) | ||||||
|  |  | ||||||
|  |     if member["avatar_url"]: | ||||||
|  |         embed.set_thumbnail(url=member["avatar_url"]) | ||||||
|  |      | ||||||
|  |     await client.send_message(log_channel, embed=embed) | ||||||
|  |  | ||||||
| async def get_webhook(conn, channel): | async def get_webhook(conn, channel): | ||||||
|     async with conn.transaction(): |     async with conn.transaction(): | ||||||
| @@ -29,29 +50,48 @@ async def get_webhook(conn, channel): | |||||||
|  |  | ||||||
|         return hook_row["webhook"], hook_row["token"] |         return hook_row["webhook"], hook_row["token"] | ||||||
|  |  | ||||||
|  | async def send_hook_message(member, text, hook_id, hook_token): | ||||||
| async def proxy_message(conn, member, message, inner): |  | ||||||
|     logger.debug("Proxying message '{}' for member {}".format( |  | ||||||
|         inner, member["hid"])) |  | ||||||
|     # Delete the original message |  | ||||||
|     await client.delete_message(message) |  | ||||||
|  |  | ||||||
|     # Get the webhook details |  | ||||||
|     hook_id, hook_token = await get_webhook(conn, message.channel) |  | ||||||
|     async with aiohttp.ClientSession() as session: |     async with aiohttp.ClientSession() as session: | ||||||
|  |         # Set up parameters | ||||||
|         req_data = { |         req_data = { | ||||||
|             "username": "{} {}".format(member["name"], member["tag"] or "").strip(), |             "username": "{} {}".format(member["name"], member["tag"] or "").strip(), | ||||||
|             "avatar_url": member["avatar_url"], |             "avatar_url": member["avatar_url"], | ||||||
|             "content": inner |             "content": text | ||||||
|         } |         } | ||||||
|         req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])} |         req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])} | ||||||
|         # And send the message |  | ||||||
|         async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), json=req_data, headers=req_headers) as resp: |  | ||||||
|             resp_data = await resp.json() |  | ||||||
|             logger.debug("Discord webhook response: {}".format(resp_data)) |  | ||||||
|  |  | ||||||
|             # Insert new message details into the DB |         # Send request | ||||||
|             await db.add_message(conn, message_id=resp_data["id"], channel_id=message.channel.id, member_id=member["id"], sender_id=message.author.id) |         async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), json=req_data, headers=req_headers) as resp: | ||||||
|  |             if resp.status == 200: | ||||||
|  |                 resp_data = await resp.json() | ||||||
|  |                 return discord.Message(reactions=[], **resp_data) | ||||||
|  |             else: | ||||||
|  |                 # Fake a Discord exception, also because #yolo | ||||||
|  |                 raise discord.HTTPException(resp, await resp.text()) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def proxy_message(conn, member, trigger_message, inner): | ||||||
|  |     logger.debug("Proxying message '{}' for member {}".format( | ||||||
|  |         inner, member["hid"])) | ||||||
|  |     # Delete the original message | ||||||
|  |     await client.delete_message(trigger_message) | ||||||
|  |  | ||||||
|  |     # Get the webhook details | ||||||
|  |     hook_id, hook_token = await get_webhook(conn, trigger_message.channel) | ||||||
|  |  | ||||||
|  |     # And send the message | ||||||
|  |     hook_message = await send_hook_message(member, inner, hook_id, hook_token) | ||||||
|  |  | ||||||
|  |     # Insert new message details into the DB | ||||||
|  |     await db.add_message(conn, message_id=hook_message.id, channel_id=trigger_message.channel.id, member_id=member["id"], sender_id=trigger_message.author.id) | ||||||
|  |  | ||||||
|  |     # Check server info for a log channel | ||||||
|  |     server_info = await db.get_server_info(conn, trigger_message.server.id) | ||||||
|  |     if server_info and server_info["log_channel"]: | ||||||
|  |         channel = trigger_message.server.get_channel(str(server_info["log_channel"])) | ||||||
|  |         if channel: | ||||||
|  |             # Log the message | ||||||
|  |             await log_message(trigger_message, hook_message, member, channel) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def handle_proxying(conn, message): | async def handle_proxying(conn, message): | ||||||
|   | |||||||
| @@ -30,6 +30,17 @@ async def parse_mention(mention: str) -> discord.User: | |||||||
|     except (ValueError, discord.NotFound): |     except (ValueError, discord.NotFound): | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|  | def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel: | ||||||
|  |     match = re.fullmatch("<#(\\d+)>", mention) | ||||||
|  |     if match: | ||||||
|  |         return server.get_channel(match.group(1)) | ||||||
|  |      | ||||||
|  |     try: | ||||||
|  |         return server.get_channel(str(int(mention))) | ||||||
|  |     except ValueError: | ||||||
|  |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| async def get_system_fuzzy(conn, key) -> asyncpg.Record: | async def get_system_fuzzy(conn, key) -> asyncpg.Record: | ||||||
|     if isinstance(key, discord.User): |     if isinstance(key, discord.User): | ||||||
| @@ -196,8 +207,6 @@ async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Emb | |||||||
|  |  | ||||||
|     # Get system name and hid |     # Get system name and hid | ||||||
|     system = await db.get_system(conn, system_id=member["system"]) |     system = await db.get_system(conn, system_id=member["system"]) | ||||||
|     if system["name"]: |  | ||||||
|         system_value = "{}".format(system["name"]) |  | ||||||
|  |  | ||||||
|     if member["color"]: |     if member["color"]: | ||||||
|         card.colour = int(member["color"], 16) |         card.colour = int(member["color"], 16) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user