Add message logging feature
This commit is contained in:
parent
95eab5a4a3
commit
e4eee8cb09
@ -7,7 +7,7 @@ import humanize
|
|||||||
|
|
||||||
from pluralkit import db
|
from pluralkit import db
|
||||||
from pluralkit.bot import client, logger
|
from pluralkit.bot import client, logger
|
||||||
from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed
|
from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed, parse_channel_mention
|
||||||
|
|
||||||
@command(cmd="pk;system", subcommand=None, description="Shows information about your system.")
|
@command(cmd="pk;system", subcommand=None, description="Shows information about your system.")
|
||||||
async def this_system_info(conn, message, args):
|
async def this_system_info(conn, message, args):
|
||||||
@ -495,6 +495,23 @@ async def switch_out(conn, message, args):
|
|||||||
await db.add_switch(conn, system_id=system["id"], member_id=None)
|
await db.add_switch(conn, system_id=system["id"], member_id=None)
|
||||||
return True, "Switch-out registered."
|
return True, "Switch-out registered."
|
||||||
|
|
||||||
|
@command(cmd="pk;mod", subcommand="log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.")
|
||||||
|
async def set_log(conn, message, args):
|
||||||
|
if not message.author.server_permissions.administrator:
|
||||||
|
return False, "You must be a server administrator to use this command."
|
||||||
|
|
||||||
|
server = message.server
|
||||||
|
if len(args) == 0:
|
||||||
|
channel_id = None
|
||||||
|
else:
|
||||||
|
channel = parse_channel_mention(args[0], server=server)
|
||||||
|
if not channel:
|
||||||
|
return False, "Channel not found."
|
||||||
|
channel_id = channel.id
|
||||||
|
|
||||||
|
await db.update_server(conn, server.id, logging_channel_id=channel_id)
|
||||||
|
return True, "Updated logging channel." if channel_id else "Cleared logging channel."
|
||||||
|
|
||||||
def make_help(cmds):
|
def make_help(cmds):
|
||||||
embed = discord.Embed()
|
embed = discord.Embed()
|
||||||
embed.colour = discord.Colour.blue()
|
embed.colour = discord.Colour.blue()
|
||||||
|
@ -24,10 +24,6 @@ def db_wrap(func):
|
|||||||
return res
|
return res
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
|
||||||
webhook_cache = {}
|
|
||||||
|
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def create_system(conn, system_name: str, system_hid: str):
|
async def create_system(conn, system_name: str, system_hid: str):
|
||||||
logger.debug("Creating system (name={}, hid={})".format(
|
logger.debug("Creating system (name={}, hid={})".format(
|
||||||
@ -140,11 +136,7 @@ async def get_members_exceeding(conn, system_id: int, length: int):
|
|||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_webhook(conn, channel_id: str):
|
async def get_webhook(conn, channel_id: str):
|
||||||
if channel_id in webhook_cache:
|
return await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
|
||||||
return webhook_cache[channel_id]
|
|
||||||
res = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
|
|
||||||
webhook_cache[channel_id] = res
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
@ -164,7 +156,7 @@ async def add_message(conn, message_id: str, channel_id: str, member_id: int, se
|
|||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_members_by_account(conn, account_id: str):
|
async def get_members_by_account(conn, account_id: str):
|
||||||
# Returns a "chimera" object
|
# Returns a "chimera" object
|
||||||
return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id))
|
return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag, systems.name as system_name, systems.hid as system_hid from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id))
|
||||||
|
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
@ -190,6 +182,16 @@ async def add_switch(conn, system_id: int, member_id: int):
|
|||||||
logger.debug("Adding switch (system={}, member={})".format(system_id, member_id))
|
logger.debug("Adding switch (system={}, member={})".format(system_id, member_id))
|
||||||
return await conn.execute("insert into switches (system, member) values ($1, $2)", system_id, member_id)
|
return await conn.execute("insert into switches (system, member) values ($1, $2)", system_id, member_id)
|
||||||
|
|
||||||
|
@db_wrap
|
||||||
|
async def get_server_info(conn, server_id: str):
|
||||||
|
return await conn.fetchrow("select * from servers where id = $1", int(server_id))
|
||||||
|
|
||||||
|
@db_wrap
|
||||||
|
async def update_server(conn, server_id: str, logging_channel_id: str):
|
||||||
|
logging_channel_id = int(logging_channel_id) if logging_channel_id else None
|
||||||
|
logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id))
|
||||||
|
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id)
|
||||||
|
|
||||||
async def create_tables(conn):
|
async def create_tables(conn):
|
||||||
await conn.execute("""create table if not exists systems (
|
await conn.execute("""create table if not exists systems (
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
@ -237,6 +239,5 @@ async def create_tables(conn):
|
|||||||
)""")
|
)""")
|
||||||
await conn.execute("""create table if not exists servers (
|
await conn.execute("""create table if not exists servers (
|
||||||
id bigint primary key,
|
id bigint primary key,
|
||||||
cmd_chans bigint[],
|
log_channel bigint
|
||||||
proxy_chans bigint[]
|
|
||||||
)""")
|
)""")
|
||||||
|
@ -2,10 +2,31 @@ import os
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import discord
|
||||||
|
|
||||||
from pluralkit import db
|
from pluralkit import db
|
||||||
from pluralkit.bot import client, logger
|
from pluralkit.bot import client, logger
|
||||||
|
|
||||||
|
async def log_message(original_message, hook_message, member, log_channel):
|
||||||
|
author_name = member["name"]
|
||||||
|
if member["system_name"]:
|
||||||
|
author_name += " ({})".format(member["system_name"])
|
||||||
|
|
||||||
|
embed = discord.Embed()
|
||||||
|
embed.colour = discord.Colour.blue()
|
||||||
|
embed.set_author(name=author_name, icon_url=member["avatar_url"] or discord.Embed.Empty)
|
||||||
|
embed.add_field(name="Member", value=member["name"])
|
||||||
|
embed.add_field(name="Sender", value="{}#{}".format(original_message.author.name, original_message.author.discriminator))
|
||||||
|
if member["system_name"]:
|
||||||
|
embed.add_field(name="System", value=member["system_name"])
|
||||||
|
embed.add_field(name="Content", value=hook_message.clean_content)
|
||||||
|
embed.timestamp = hook_message.timestamp
|
||||||
|
embed.set_footer(text="System ID: {} | Member ID: {} | Sender ID: {} | Message ID: {}".format(member["system_hid"], member["hid"], original_message.author.id, hook_message.id))
|
||||||
|
|
||||||
|
if member["avatar_url"]:
|
||||||
|
embed.set_thumbnail(url=member["avatar_url"])
|
||||||
|
|
||||||
|
await client.send_message(log_channel, embed=embed)
|
||||||
|
|
||||||
async def get_webhook(conn, channel):
|
async def get_webhook(conn, channel):
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
@ -29,29 +50,48 @@ async def get_webhook(conn, channel):
|
|||||||
|
|
||||||
return hook_row["webhook"], hook_row["token"]
|
return hook_row["webhook"], hook_row["token"]
|
||||||
|
|
||||||
|
async def send_hook_message(member, text, hook_id, hook_token):
|
||||||
async def proxy_message(conn, member, message, inner):
|
|
||||||
logger.debug("Proxying message '{}' for member {}".format(
|
|
||||||
inner, member["hid"]))
|
|
||||||
# Delete the original message
|
|
||||||
await client.delete_message(message)
|
|
||||||
|
|
||||||
# Get the webhook details
|
|
||||||
hook_id, hook_token = await get_webhook(conn, message.channel)
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Set up parameters
|
||||||
req_data = {
|
req_data = {
|
||||||
"username": "{} {}".format(member["name"], member["tag"] or "").strip(),
|
"username": "{} {}".format(member["name"], member["tag"] or "").strip(),
|
||||||
"avatar_url": member["avatar_url"],
|
"avatar_url": member["avatar_url"],
|
||||||
"content": inner
|
"content": text
|
||||||
}
|
}
|
||||||
req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])}
|
req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])}
|
||||||
# And send the message
|
|
||||||
async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), json=req_data, headers=req_headers) as resp:
|
|
||||||
resp_data = await resp.json()
|
|
||||||
logger.debug("Discord webhook response: {}".format(resp_data))
|
|
||||||
|
|
||||||
# Insert new message details into the DB
|
# Send request
|
||||||
await db.add_message(conn, message_id=resp_data["id"], channel_id=message.channel.id, member_id=member["id"], sender_id=message.author.id)
|
async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), json=req_data, headers=req_headers) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
resp_data = await resp.json()
|
||||||
|
return discord.Message(reactions=[], **resp_data)
|
||||||
|
else:
|
||||||
|
# Fake a Discord exception, also because #yolo
|
||||||
|
raise discord.HTTPException(resp, await resp.text())
|
||||||
|
|
||||||
|
|
||||||
|
async def proxy_message(conn, member, trigger_message, inner):
|
||||||
|
logger.debug("Proxying message '{}' for member {}".format(
|
||||||
|
inner, member["hid"]))
|
||||||
|
# Delete the original message
|
||||||
|
await client.delete_message(trigger_message)
|
||||||
|
|
||||||
|
# Get the webhook details
|
||||||
|
hook_id, hook_token = await get_webhook(conn, trigger_message.channel)
|
||||||
|
|
||||||
|
# And send the message
|
||||||
|
hook_message = await send_hook_message(member, inner, hook_id, hook_token)
|
||||||
|
|
||||||
|
# Insert new message details into the DB
|
||||||
|
await db.add_message(conn, message_id=hook_message.id, channel_id=trigger_message.channel.id, member_id=member["id"], sender_id=trigger_message.author.id)
|
||||||
|
|
||||||
|
# Check server info for a log channel
|
||||||
|
server_info = await db.get_server_info(conn, trigger_message.server.id)
|
||||||
|
if server_info and server_info["log_channel"]:
|
||||||
|
channel = trigger_message.server.get_channel(str(server_info["log_channel"]))
|
||||||
|
if channel:
|
||||||
|
# Log the message
|
||||||
|
await log_message(trigger_message, hook_message, member, channel)
|
||||||
|
|
||||||
|
|
||||||
async def handle_proxying(conn, message):
|
async def handle_proxying(conn, message):
|
||||||
|
@ -30,6 +30,17 @@ async def parse_mention(mention: str) -> discord.User:
|
|||||||
except (ValueError, discord.NotFound):
|
except (ValueError, discord.NotFound):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel:
|
||||||
|
match = re.fullmatch("<#(\\d+)>", mention)
|
||||||
|
if match:
|
||||||
|
return server.get_channel(match.group(1))
|
||||||
|
|
||||||
|
try:
|
||||||
|
return server.get_channel(str(int(mention)))
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def get_system_fuzzy(conn, key) -> asyncpg.Record:
|
async def get_system_fuzzy(conn, key) -> asyncpg.Record:
|
||||||
if isinstance(key, discord.User):
|
if isinstance(key, discord.User):
|
||||||
@ -196,8 +207,6 @@ async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Emb
|
|||||||
|
|
||||||
# Get system name and hid
|
# Get system name and hid
|
||||||
system = await db.get_system(conn, system_id=member["system"])
|
system = await db.get_system(conn, system_id=member["system"])
|
||||||
if system["name"]:
|
|
||||||
system_value = "{}".format(system["name"])
|
|
||||||
|
|
||||||
if member["color"]:
|
if member["color"]:
|
||||||
card.colour = int(member["color"], 16)
|
card.colour = int(member["color"], 16)
|
||||||
|
Loading…
Reference in New Issue
Block a user