Add message logging feature

This commit is contained in:
Ske 2018-07-12 15:03:34 +02:00
parent 95eab5a4a3
commit e4eee8cb09
4 changed files with 98 additions and 31 deletions

View File

@ -7,7 +7,7 @@ import humanize
from pluralkit import db from pluralkit import db
from pluralkit.bot import client, logger from pluralkit.bot import client, logger
from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed, parse_channel_mention
@command(cmd="pk;system", subcommand=None, description="Shows information about your system.") @command(cmd="pk;system", subcommand=None, description="Shows information about your system.")
async def this_system_info(conn, message, args): async def this_system_info(conn, message, args):
@ -495,6 +495,23 @@ async def switch_out(conn, message, args):
await db.add_switch(conn, system_id=system["id"], member_id=None) await db.add_switch(conn, system_id=system["id"], member_id=None)
return True, "Switch-out registered." return True, "Switch-out registered."
@command(cmd="pk;mod", subcommand="log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.")
async def set_log(conn, message, args):
if not message.author.server_permissions.administrator:
return False, "You must be a server administrator to use this command."
server = message.server
if len(args) == 0:
channel_id = None
else:
channel = parse_channel_mention(args[0], server=server)
if not channel:
return False, "Channel not found."
channel_id = channel.id
await db.update_server(conn, server.id, logging_channel_id=channel_id)
return True, "Updated logging channel." if channel_id else "Cleared logging channel."
def make_help(cmds): def make_help(cmds):
embed = discord.Embed() embed = discord.Embed()
embed.colour = discord.Colour.blue() embed.colour = discord.Colour.blue()

View File

@ -24,10 +24,6 @@ def db_wrap(func):
return res return res
return inner return inner
webhook_cache = {}
@db_wrap @db_wrap
async def create_system(conn, system_name: str, system_hid: str): async def create_system(conn, system_name: str, system_hid: str):
logger.debug("Creating system (name={}, hid={})".format( logger.debug("Creating system (name={}, hid={})".format(
@ -140,11 +136,7 @@ async def get_members_exceeding(conn, system_id: int, length: int):
@db_wrap @db_wrap
async def get_webhook(conn, channel_id: str): async def get_webhook(conn, channel_id: str):
if channel_id in webhook_cache: return await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
return webhook_cache[channel_id]
res = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
webhook_cache[channel_id] = res
return res
@db_wrap @db_wrap
@ -164,7 +156,7 @@ async def add_message(conn, message_id: str, channel_id: str, member_id: int, se
@db_wrap @db_wrap
async def get_members_by_account(conn, account_id: str): async def get_members_by_account(conn, account_id: str):
# Returns a "chimera" object # Returns a "chimera" object
return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id)) return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag, systems.name as system_name, systems.hid as system_hid from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id))
@db_wrap @db_wrap
@ -190,6 +182,16 @@ async def add_switch(conn, system_id: int, member_id: int):
logger.debug("Adding switch (system={}, member={})".format(system_id, member_id)) logger.debug("Adding switch (system={}, member={})".format(system_id, member_id))
return await conn.execute("insert into switches (system, member) values ($1, $2)", system_id, member_id) return await conn.execute("insert into switches (system, member) values ($1, $2)", system_id, member_id)
@db_wrap
async def get_server_info(conn, server_id: str):
return await conn.fetchrow("select * from servers where id = $1", int(server_id))
@db_wrap
async def update_server(conn, server_id: str, logging_channel_id: str):
logging_channel_id = int(logging_channel_id) if logging_channel_id else None
logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id))
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id)
async def create_tables(conn): async def create_tables(conn):
await conn.execute("""create table if not exists systems ( await conn.execute("""create table if not exists systems (
id serial primary key, id serial primary key,
@ -237,6 +239,5 @@ async def create_tables(conn):
)""") )""")
await conn.execute("""create table if not exists servers ( await conn.execute("""create table if not exists servers (
id bigint primary key, id bigint primary key,
cmd_chans bigint[], log_channel bigint
proxy_chans bigint[]
)""") )""")

View File

@ -2,10 +2,31 @@ import os
import time import time
import aiohttp import aiohttp
import discord
from pluralkit import db from pluralkit import db
from pluralkit.bot import client, logger from pluralkit.bot import client, logger
async def log_message(original_message, hook_message, member, log_channel):
author_name = member["name"]
if member["system_name"]:
author_name += " ({})".format(member["system_name"])
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.set_author(name=author_name, icon_url=member["avatar_url"] or discord.Embed.Empty)
embed.add_field(name="Member", value=member["name"])
embed.add_field(name="Sender", value="{}#{}".format(original_message.author.name, original_message.author.discriminator))
if member["system_name"]:
embed.add_field(name="System", value=member["system_name"])
embed.add_field(name="Content", value=hook_message.clean_content)
embed.timestamp = hook_message.timestamp
embed.set_footer(text="System ID: {} | Member ID: {} | Sender ID: {} | Message ID: {}".format(member["system_hid"], member["hid"], original_message.author.id, hook_message.id))
if member["avatar_url"]:
embed.set_thumbnail(url=member["avatar_url"])
await client.send_message(log_channel, embed=embed)
async def get_webhook(conn, channel): async def get_webhook(conn, channel):
async with conn.transaction(): async with conn.transaction():
@ -29,29 +50,48 @@ async def get_webhook(conn, channel):
return hook_row["webhook"], hook_row["token"] return hook_row["webhook"], hook_row["token"]
async def send_hook_message(member, text, hook_id, hook_token):
async def proxy_message(conn, member, message, inner):
logger.debug("Proxying message '{}' for member {}".format(
inner, member["hid"]))
# Delete the original message
await client.delete_message(message)
# Get the webhook details
hook_id, hook_token = await get_webhook(conn, message.channel)
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
# Set up parameters
req_data = { req_data = {
"username": "{} {}".format(member["name"], member["tag"] or "").strip(), "username": "{} {}".format(member["name"], member["tag"] or "").strip(),
"avatar_url": member["avatar_url"], "avatar_url": member["avatar_url"],
"content": inner "content": text
} }
req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])} req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])}
# And send the message
async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), json=req_data, headers=req_headers) as resp:
resp_data = await resp.json()
logger.debug("Discord webhook response: {}".format(resp_data))
# Insert new message details into the DB # Send request
await db.add_message(conn, message_id=resp_data["id"], channel_id=message.channel.id, member_id=member["id"], sender_id=message.author.id) async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), json=req_data, headers=req_headers) as resp:
if resp.status == 200:
resp_data = await resp.json()
return discord.Message(reactions=[], **resp_data)
else:
# Fake a Discord exception, also because #yolo
raise discord.HTTPException(resp, await resp.text())
async def proxy_message(conn, member, trigger_message, inner):
logger.debug("Proxying message '{}' for member {}".format(
inner, member["hid"]))
# Delete the original message
await client.delete_message(trigger_message)
# Get the webhook details
hook_id, hook_token = await get_webhook(conn, trigger_message.channel)
# And send the message
hook_message = await send_hook_message(member, inner, hook_id, hook_token)
# Insert new message details into the DB
await db.add_message(conn, message_id=hook_message.id, channel_id=trigger_message.channel.id, member_id=member["id"], sender_id=trigger_message.author.id)
# Check server info for a log channel
server_info = await db.get_server_info(conn, trigger_message.server.id)
if server_info and server_info["log_channel"]:
channel = trigger_message.server.get_channel(str(server_info["log_channel"]))
if channel:
# Log the message
await log_message(trigger_message, hook_message, member, channel)
async def handle_proxying(conn, message): async def handle_proxying(conn, message):

View File

@ -30,6 +30,17 @@ async def parse_mention(mention: str) -> discord.User:
except (ValueError, discord.NotFound): except (ValueError, discord.NotFound):
return None return None
def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel:
match = re.fullmatch("<#(\\d+)>", mention)
if match:
return server.get_channel(match.group(1))
try:
return server.get_channel(str(int(mention)))
except ValueError:
return None
async def get_system_fuzzy(conn, key) -> asyncpg.Record: async def get_system_fuzzy(conn, key) -> asyncpg.Record:
if isinstance(key, discord.User): if isinstance(key, discord.User):
@ -196,8 +207,6 @@ async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Emb
# Get system name and hid # Get system name and hid
system = await db.get_system(conn, system_id=member["system"]) system = await db.get_system(conn, system_id=member["system"])
if system["name"]:
system_value = "{}".format(system["name"])
if member["color"]: if member["color"]:
card.colour = int(member["color"], 16) card.colour = int(member["color"], 16)