Massive refactor of pretty much everything in the bot

This commit is contained in:
Ske
2018-07-24 22:47:57 +02:00
parent 086fa84b4b
commit 8936029dc8
27 changed files with 1799 additions and 1450 deletions

11
src/bot.Dockerfile Normal file
View File

@@ -0,0 +1,11 @@
FROM python:3.6-alpine
RUN apk --no-cache add build-base
WORKDIR /app
ADD requirements.txt /app
RUN pip install --trusted-host pypi.python.org -r requirements.txt
ADD . /app
ENTRYPOINT ["python", "bot_main.py"]

11
src/bot_main.py Normal file
View File

@@ -0,0 +1,11 @@
import asyncio
import os
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from pluralkit import bot
pk = bot.PluralKitBot(os.environ["TOKEN"])
loop = asyncio.get_event_loop()
loop.run_until_complete(pk.run())

26
src/pluralkit/__init__.py Normal file
View File

@@ -0,0 +1,26 @@
from collections import namedtuple
from datetime import date, datetime
class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "created"])):
id: int
hid: str
name: str
description: str
tag: str
avatar_url: str
created: datetime
class Member(namedtuple("Member", ["id", "hid", "system", "color", "avatar_url", "name", "birthday", "pronouns", "description", "prefix", "suffix", "created"])):
id: int
hid: str
system: int
color: str
avatar_url: str
name: str
birthday: date
pronouns: str
description: str
prefix: str
suffix: str
created: datetime

View File

@@ -0,0 +1,131 @@
import asyncio
from datetime import datetime
import logging
import json
import os
import time
import discord
from pluralkit import db, stats
from pluralkit.bot import channel_logger, commands, proxy
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
logging.getLogger("pluralkit").setLevel(logging.DEBUG)
class PluralKitBot:
def __init__(self, token):
self.token = token
self.logger = logging.getLogger("pluralkit.bot")
self.client = discord.Client()
self.client.event(self.on_error)
self.client.event(self.on_ready)
self.client.event(self.on_message)
self.client.event(self.on_socket_raw_receive)
self.channel_logger = channel_logger.ChannelLogger(self.client)
self.proxy = proxy.Proxy(self.client, token, self.channel_logger)
async def on_error(self, evt, *args, **kwargs):
self.logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
async def on_ready(self):
self.logger.info("Connected to Discord.")
self.logger.info("- Account: {}#{}".format(self.client.user.name, self.client.user.discriminator))
self.logger.info("- User ID: {}".format(self.client.user.id))
self.logger.info("- {} servers".format(len(self.client.servers)))
async def on_message(self, message):
# Ignore bot messages
if message.author.bot:
return
if await self.handle_command_dispatch(message):
return
if await self.handle_proxy_dispatch(message):
return
async def on_socket_raw_receive(self, msg):
# Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
# we parse socket data manually for the reaction add event
if isinstance(msg, str):
try:
msg_data = json.loads(msg)
if msg_data.get("t") == "MESSAGE_REACTION_ADD":
evt_data = msg_data.get("d")
if evt_data:
user_id = evt_data["user_id"]
message_id = evt_data["message_id"]
emoji = evt_data["emoji"]["name"]
async with self.pool.acquire() as conn:
await self.proxy.handle_reaction(conn, user_id, message_id, emoji)
elif msg_data.get("t") == "MESSAGE_DELETE":
evt_data = msg_data.get("d")
if evt_data:
message_id = evt_data["id"]
async with self.pool.acquire() as conn:
await self.proxy.handle_deletion(conn, message_id)
except ValueError:
pass
async def handle_command_dispatch(self, message):
command_items = commands.command_list.items()
command_items = sorted(command_items, key=lambda x: len(x[0]), reverse=True)
prefix = "pk;"
for command_name, command in command_items:
if message.content.lower().startswith(prefix + command_name):
args_str = message.content[len(prefix + command_name):].strip()
args = args_str.split(" ")
# Splitting on empty string yields one-element array, remove that
if len(args) == 1 and not args[0]:
args = []
async with self.pool.acquire() as conn:
time_before = time.perf_counter()
await command.function(self.client, conn, message, args)
time_after = time.perf_counter()
# Report command time stats
execution_time = time_after - time_before
response_time = (datetime.now() - message.timestamp).total_seconds()
await stats.report_command(command_name, execution_time, response_time)
return True
async def handle_proxy_dispatch(self, message):
# Try doing proxy parsing
async with self.pool.acquire() as conn:
return await self.proxy.try_proxy_message(conn, message)
async def periodical_stat_timer(self, pool):
async with pool.acquire() as conn:
while True:
from pluralkit import stats
await stats.report_periodical_stats(conn)
await asyncio.sleep(30)
async def run(self):
try:
self.logger.info("Connecting to database...")
self.pool = await db.connect()
self.logger.info("Attempting to create tables...")
async with self.pool.acquire() as conn:
await db.create_tables(conn)
self.logger.info("Connecting to InfluxDB...")
await stats.connect()
self.logger.info("Starting periodical stat reporting...")
asyncio.get_event_loop().create_task(self.periodical_stat_timer(self.pool))
self.logger.info("Connecting to Discord...")
await self.client.start(self.token)
finally:
self.logger.info("Logging out from Discord...")
await self.client.logout()

View File

@@ -0,0 +1,109 @@
import logging
from datetime import datetime
import discord
from pluralkit import db
def embed_set_author_name(embed: discord.Embed, channel_name: str, member_name: str, system_name: str, avatar_url: str):
name = "#{}: {}".format(channel_name, member_name)
if system_name:
name += " ({})".format(system_name)
embed.set_author(name=name, icon_url=avatar_url or discord.Embed.Empty)
class ChannelLogger:
def __init__(self, client: discord.Client):
self.logger = logging.getLogger("pluralkit.bot.channel_logger")
self.client = client
async def get_log_channel(self, conn, server_id: str):
server_info = await db.get_server_info(conn, server_id)
if not server_info:
return None
log_channel = server_info["log_channel"]
if not log_channel:
return None
return self.client.get_channel(str(log_channel))
async def send_to_log_channel(self, log_channel: discord.Channel, embed: discord.Embed):
try:
await self.client.send_message(log_channel, embed=embed)
except discord.Forbidden:
# TODO: spew big error
self.logger.warning(
"Did not have permission to send message to logging channel (server={}, channel={})".format(
log_channel.server.id, log_channel.id))
async def log_message_proxied(self, conn,
server_id: str,
channel_name: str,
channel_id: str,
sender_name: str,
sender_disc: int,
member_name: str,
member_hid: str,
member_avatar_url: str,
system_name: str,
system_hid: str,
message_text: str,
message_image: str,
message_timestamp: datetime,
message_id: str):
log_channel = await self.get_log_channel(conn, server_id)
if not log_channel:
return
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.description = message_text
embed.timestamp = message_timestamp
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
embed.set_footer(
text="System ID: {} | Member ID: {} | Sender: {}#{} | Message ID: {}".format(system_hid, member_hid,
sender_name, sender_disc,
message_id))
if message_image:
embed.set_thumbnail(url=message_image)
message_link = "https://discordapp.com/channels/{}/{}/{}".format(server_id, channel_id, message_id)
embed.author.url = message_link
await self.send_to_log_channel(log_channel, embed)
async def log_message_deleted(self, conn,
server_id: str,
channel_name: str,
member_name: str,
member_hid: str,
member_avatar_url: str,
system_name: str,
system_hid: str,
message_text: str,
message_id: str,
deleted_by_moderator: bool):
log_channel = await self.get_log_channel(conn, server_id)
if not log_channel:
return
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.description = message_text
embed.timestamp = datetime.utcnow()
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
embed.set_footer(
text="System ID: {} | Member ID: {} | Message ID: {} | Deleted by moderator? {}".format(system_hid,
member_hid,
message_id,
"Yes" if deleted_by_moderator else "No"))
await self.send_to_log_channel(log_channel, embed)

View File

@@ -0,0 +1,98 @@
import logging
from collections import namedtuple
import asyncpg
import discord
import pluralkit
from pluralkit import db
from pluralkit.bot import utils
command_list = {}
class InvalidCommandSyntax(Exception):
pass
class NoSystemRegistered(Exception):
pass
class CommandError(Exception):
def __init__(self, message):
self.message = message
class CommandContext(namedtuple("CommandContext", ["client", "conn", "message", "system"])):
client: discord.Client
conn: asyncpg.Connection
message: discord.Message
system: pluralkit.System
async def reply(self, message=None, embed=None):
return await self.client.send_message(self.message.channel, message, embed=embed)
class MemberCommandContext(namedtuple("MemberCommandContext", CommandContext._fields + ("member",)), CommandContext):
client: discord.Client
conn: asyncpg.Connection
message: discord.Message
system: pluralkit.System
member: pluralkit.Member
class CommandEntry(namedtuple("CommandEntry", ["command", "function", "usage", "description", "category"])):
pass
def command(cmd, usage=None, description=None, category=None, system_required=True):
def wrap(func):
async def wrapper(client, conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system_required and system is None:
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account"))
return
ctx = CommandContext(client=client, conn=conn, message=message, system=system)
try:
res = await func(ctx, args)
if res:
embed = res if isinstance(res, discord.Embed) else utils.make_default_embed(res)
await client.send_message(message.channel, embed=embed)
except NoSystemRegistered:
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account"))
except InvalidCommandSyntax:
usage_str = "**Usage:** pk;{} {}".format(cmd, usage or "")
await client.send_message(message.channel, embed=utils.make_default_embed(usage_str))
except CommandError as e:
embed = e.message if isinstance(e.message, discord.Embed) else utils.make_error_embed(e.message)
await client.send_message(message.channel, embed=embed)
# Put command in map
command_list[cmd] = CommandEntry(command=cmd, function=wrapper, usage=usage, description=description, category=category)
return wrapper
return wrap
def member_command(cmd, usage=None, description=None, category=None, system_only=True):
def wrap(func):
async def wrapper(ctx: CommandContext, args):
# Return if no member param
if len(args) == 0:
raise InvalidCommandSyntax()
# System is allowed to be none if not system_only
system_id = ctx.system.id if ctx.system else None
# And find member by key
member = await utils.get_member_fuzzy(ctx.conn, system_id=system_id, key=args[0], system_only=system_only)
if member is None:
raise CommandError("Can't find member \"{}\".".format(args[0]))
ctx = MemberCommandContext(client=ctx.client, conn=ctx.conn, message=ctx.message, system=ctx.system, member=member)
return await func(ctx, args[1:])
return command(cmd=cmd, usage="<name|id> {}".format(usage or ""), description=description, category=category, system_required=False)(wrapper)
return wrap
import pluralkit.bot.commands.import_commands
import pluralkit.bot.commands.member_commands
import pluralkit.bot.commands.message_commands
import pluralkit.bot.commands.misc_commands
import pluralkit.bot.commands.mod_commands
import pluralkit.bot.commands.switch_commands
import pluralkit.bot.commands.system_commands

View File

@@ -0,0 +1,143 @@
import asyncio
import re
from datetime import datetime
import logging
from typing import List
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="import tupperware", description="Import data from Tupperware.")
async def import_tupperware(ctx: CommandContext, args: List[str]):
tupperware_member = ctx.message.server.get_member("431544605209788416") or ctx.message.server.get_member("433916057053560832")
if not tupperware_member:
raise CommandError("This command only works in a server where the Tupperware bot is also present.")
channel_permissions = ctx.message.channel.permissions_for(tupperware_member)
if not (channel_permissions.read_messages and channel_permissions.send_messages):
raise CommandError("This command only works in a channel where the Tupperware bot has read/send access.")
await ctx.reply(embed=utils.make_default_embed("Please reply to this message with `tul!list` (or the server equivalent)."))
# Check to make sure the Tupperware response actually belongs to the correct user
def ensure_account(tw_msg):
if not tw_msg.embeds:
return False
if not tw_msg.embeds[0]["title"]:
return False
return tw_msg.embeds[0]["title"].startswith("{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator))
embeds = []
tw_msg: discord.Message = await ctx.client.wait_for_message(author=tupperware_member, channel=ctx.message.channel, timeout=60.0, check=ensure_account)
if not tw_msg:
raise CommandError("Tupperware import timed out.")
embeds.append(tw_msg.embeds[0])
# Handle Tupperware pagination
def match_pagination():
pagination_match = re.search(r"\(page (\d+)/(\d+), \d+ total\)", tw_msg.embeds[0]["title"])
if not pagination_match:
return None
return int(pagination_match.group(1)), int(pagination_match.group(2))
pagination_match = match_pagination()
if pagination_match:
status_msg = await ctx.reply("Multi-page member list found. Please manually scroll through all the pages.")
current_page = 0
total_pages = 1
pages_found = {}
# Keep trying to read the embed with new pages
last_found_time = datetime.utcnow()
while len(pages_found) < total_pages:
new_page, total_pages = match_pagination()
# Put the found page in the pages dict
pages_found[new_page] = dict(tw_msg.embeds[0])
# If this isn't the same page as last check, edit the status message
if new_page != current_page:
last_found_time = datetime.utcnow()
await ctx.client.edit_message(status_msg, "Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(len(pages_found), total_pages))
current_page = new_page
# And sleep a bit to prevent spamming the CPU
await asyncio.sleep(0.25)
# Make sure it doesn't spin here for too long, time out after 30 seconds since last new page
if (datetime.utcnow() - last_found_time).seconds > 30:
raise CommandError("Pagination scan timed out.")
# Now that we've got all the pages, put them in the embeds list
# Make sure to erase the original one we put in above too
embeds = list([embed for page, embed in sorted(pages_found.items(), key=lambda x: x[0])])
# Also edit the status message to indicate we're now importing, and it may take a while because there's probably a lot of members
await ctx.client.edit_message(status_msg, "All pages read. Now importing...")
logger.debug("Importing from Tupperware...")
# Create new (nameless) system if there isn't any registered
system = ctx.system
if system is None:
hid = utils.generate_hid()
logger.debug("Creating new system (hid={})...".format(hid))
system = await db.create_system(ctx.conn, system_name=None, system_hid=hid)
await db.link_account(ctx.conn, system_id=system["id"], account_id=ctx.message.author.id)
for embed in embeds:
for field in embed["fields"]:
name = field["name"]
lines = field["value"].split("\n")
member_prefix = None
member_suffix = None
member_avatar = None
member_birthdate = None
member_description = None
# Read the message format line by line
for line in lines:
if line.startswith("Brackets:"):
brackets = line[len("Brackets: "):]
member_prefix = brackets[:brackets.index("text")].strip() or None
member_suffix = brackets[brackets.index("text")+4:].strip() or None
elif line.startswith("Avatar URL: "):
url = line[len("Avatar URL: "):]
member_avatar = url
elif line.startswith("Birthday: "):
bday_str = line[len("Birthday: "):]
bday = datetime.strptime(bday_str, "%a %b %d %Y")
if bday:
member_birthdate = bday.date()
elif line.startswith("Total messages sent: ") or line.startswith("Tag: "):
# Ignore this, just so it doesn't catch as the description
pass
else:
member_description = line
# Read by name - TW doesn't allow name collisions so we're safe here (prevents dupes)
existing_member = await db.get_member_by_name(ctx.conn, system_id=system.id, member_name=name)
if not existing_member:
# Or create a new member
hid = utils.generate_hid()
logger.debug("Creating new member {} (hid={})...".format(name, hid))
existing_member = await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid)
# Save the new stuff in the DB
logger.debug("Updating fields...")
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="prefix", value=member_prefix)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="suffix", value=member_suffix)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="avatar_url", value=member_avatar)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="birthday", value=member_birthdate)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="description", value=member_description)
return "System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting."

View File

@@ -0,0 +1,150 @@
import logging
import re
from datetime import datetime
from typing import List
from urllib.parse import urlparse
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@member_command(cmd="member", description="Shows information about a system member.", system_only=False, category="Member commands")
async def member_info(ctx: MemberCommandContext, args: List[str]):
await ctx.reply(embed=await utils.generate_member_info_card(ctx.conn, ctx.member))
@command(cmd="member new", usage="<name>", description="Adds a new member to your system.", category="Member commands")
async def new_member(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
name = " ".join(args)
bounds_error = utils.bounds_check_member_name(name, ctx.system.tag)
if bounds_error:
raise CommandError(bounds_error)
# TODO: figure out what to do if this errors out on collision on generate_hid
hid = utils.generate_hid()
# Insert member row
await db.create_member(ctx.conn, system_id=ctx.system.id, member_name=name, member_hid=hid)
return "Member \"{}\" (`{}`) registered!".format(name, hid)
@member_command(cmd="member set", usage="<name|description|color|pronouns|birthdate|avatar> [value]", description="Edits a member property. Leave [value] blank to clear.", category="Member commands")
async def member_set(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
allowed_properties = ["name", "description", "color", "pronouns", "birthdate", "avatar"]
db_properties = {
"name": "name",
"description": "description",
"color": "color",
"pronouns": "pronouns",
"birthdate": "birthday",
"avatar": "avatar_url"
}
prop = args[0]
if prop not in allowed_properties:
raise CommandError("Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)))
if len(args) >= 2:
value = " ".join(args[1:])
# Sanity/validity checks and type conversions
if prop == "name":
bounds_error = utils.bounds_check_member_name(value, ctx.system.tag)
if bounds_error:
raise CommandError(bounds_error)
if prop == "color":
match = re.fullmatch("#?([0-9A-Fa-f]{6})", value)
if not match:
raise CommandError("Color must be a valid hex color (eg. #ff0000)")
value = match.group(1).lower()
if prop == "birthdate":
try:
value = datetime.strptime(value, "%Y-%m-%d").date()
except ValueError:
try:
# Try again, adding 0001 as a placeholder year
# This is considered a "null year" and will be omitted from the info card
# Useful if you want your birthday to be displayed yearless.
value = datetime.strptime("0001-" + value, "%Y-%m-%d").date()
except ValueError:
raise CommandError("Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25).")
if prop == "avatar":
user = await utils.parse_mention(ctx.client, value)
if user:
# Set the avatar to the mentioned user's avatar
# Discord doesn't like webp, but also hosts png alternatives
value = user.avatar_url.replace(".webp", ".png")
else:
# Validate URL
u = urlparse(value)
if u.scheme in ["http", "https"] and u.netloc and u.path:
value = value
else:
raise CommandError("Invalid URL.")
else:
# Can't clear member name
if prop == "name":
raise CommandError("Can't clear member name.")
# Clear from DB
value = None
db_prop = db_properties[prop]
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field=db_prop, value=value)
response = utils.make_default_embed("{} {}'s {}.".format("Updated" if value else "Cleared", ctx.member.name, prop))
if prop == "avatar" and value:
response.set_image(url=value)
if prop == "color" and value:
response.colour = int(value, 16)
return response
@member_command(cmd="member proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).", category="Member commands")
async def member_proxy(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
prefix, suffix = None, None
else:
# Sanity checking
example = " ".join(args)
if "text" not in example:
raise CommandError("Example proxy message must contain the string 'text'.")
if example.count("text") != 1:
raise CommandError("Example proxy message must contain the string 'text' exactly once.")
# Extract prefix and suffix
prefix = example[:example.index("text")].strip()
suffix = example[example.index("text")+4:].strip()
logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix))
# DB stores empty strings as None, make that work
if not prefix:
prefix = None
if not suffix:
suffix = None
async with ctx.conn.transaction():
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="prefix", value=prefix)
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="suffix", value=suffix)
return "Proxy settings updated." if prefix or suffix else "Proxy settings cleared."
@member_command("member delete", description="Deletes a member from your system ***permanently***.", category="Member commands")
async def member_delete(ctx: MemberCommandContext, args: List[str]):
await ctx.reply("Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(ctx.member.name, ctx.member.hid))
msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0)
if msg and msg.content == ctx.member.hid:
await db.delete_member(ctx.conn, member_id=ctx.member.id)
return "Member deleted."
else:
return "Member deletion cancelled."

View File

@@ -0,0 +1,57 @@
import logging
from typing import List
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="message", usage="<id>", description="Shows information about a proxied message. Requires the message ID.",
category="Message commands")
async def message_info(ctx: CommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
try:
mid = int(args[0])
except ValueError:
raise InvalidCommandSyntax()
# Find the message in the DB
message = await db.get_message(ctx.conn, str(mid))
if not message:
raise CommandError("Message not found.")
# Get the original sender of the messages
try:
original_sender = await ctx.client.get_user_info(str(message.sender))
except discord.NotFound:
# Account was since deleted - rare but we're handling it anyway
original_sender = None
embed = discord.Embed()
embed.timestamp = discord.utils.snowflake_time(str(mid))
embed.colour = discord.Colour.blue()
if message.system_name:
system_value = "{} (`{}`)".format(message.system_name, message.system_hid)
else:
system_value = "`{}`".format(message.system_hid)
embed.add_field(name="System", value=system_value)
embed.add_field(name="Member", value="{} (`{}`)".format(message.name, message.hid))
if original_sender:
sender_name = "{}#{}".format(original_sender.name, original_sender.discriminator)
else:
sender_name = "(deleted account {})".format(message.sender)
embed.add_field(name="Sent by", value=sender_name)
if message.content: # Content can be empty string if there's an attachment
embed.add_field(name="Content", value=message.content, inline=False)
embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty)
return embed

View File

@@ -0,0 +1,89 @@
import io
import json
import logging
import os
from typing import List
from discord.utils import oauth_url
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="help", usage="[system|member|proxy|switch|mod]", description="Shows help messages.")
async def show_help(ctx: CommandContext, args: List[str]):
embed = utils.make_default_embed("")
embed.title = "PluralKit Help"
embed.set_footer(text="By Astrid (Ske#6201, or 'qoxvy' on PK) | GitHub: https://github.com/xSke/PluralKit/")
category = args[0] if len(args) > 0 else None
from pluralkit.bot.help import help_pages
if category in help_pages:
for name, text in help_pages[category]:
if name:
embed.add_field(name=name, value=text)
else:
embed.description = text
else:
raise InvalidCommandSyntax()
return embed
@command(cmd="invite", description="Generates an invite link for this bot.")
async def invite_link(ctx: CommandContext, args: List[str]):
client_id = os.environ["CLIENT_ID"]
permissions = discord.Permissions()
permissions.manage_webhooks = True
permissions.send_messages = True
permissions.manage_messages = True
permissions.embed_links = True
permissions.attach_files = True
permissions.read_message_history = True
permissions.add_reactions = True
url = oauth_url(client_id, permissions)
logger.debug("Sending invite URL: {}".format(url))
return url
@command(cmd="export", description="Exports system data to a machine-readable format.")
async def export(ctx: CommandContext, args: List[str]):
members = await db.get_all_members(ctx.conn, ctx.system.id)
accounts = await db.get_linked_accounts(ctx.conn, ctx.system.id)
switches = await utils.get_front_history(ctx.conn, ctx.system.id, 999999)
system = ctx.system
data = {
"name": system.name,
"id": system.hid,
"description": system.description,
"tag": system.tag,
"avatar_url": system.avatar_url,
"created": system.created.isoformat(),
"members": [
{
"name": member.name,
"id": member.hid,
"color": member.color,
"avatar_url": member.avatar_url,
"birthday": member.birthday.isoformat() if member.birthday else None,
"pronouns": member.pronouns,
"description": member.description,
"prefix": member.prefix,
"suffix": member.suffix,
"created": member.created.isoformat()
} for member in members
],
"accounts": [str(uid) for uid in accounts],
"switches": [
{
"timestamp": timestamp.isoformat(),
"members": [member.hid for member in members]
} for timestamp, members in switches
]
}
f = io.BytesIO(json.dumps(data).encode("utf-8"))
await ctx.client.send_file(ctx.message.channel, f, filename="system.json")

View File

@@ -0,0 +1,24 @@
import logging
from typing import List
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands")
async def set_log(ctx: CommandContext, args: List[str]):
if not ctx.message.author.server_permissions.administrator:
raise CommandError("You must be a server administrator to use this command.")
server = ctx.message.server
if len(args) == 0:
channel_id = None
else:
channel = utils.parse_channel_mention(args[0], server=server)
if not channel:
raise CommandError("Channel not found.")
channel_id = channel.id
await db.update_server(ctx.conn, server.id, logging_channel_id=channel_id)
return "Updated logging channel." if channel_id else "Cleared logging channel."

View File

@@ -0,0 +1,119 @@
from datetime import datetime
import logging
from typing import List
import dateparser
import humanize
from pluralkit import Member
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="switch", usage="<name|id> [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands")
async def switch_member(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
members: List[Member] = []
for member_name in args:
# Find the member
member = await utils.get_member_fuzzy(ctx.conn, ctx.system.id, member_name)
if not member:
raise CommandError("Couldn't find member \"{}\".".format(member_name))
members.append(member)
# Compare requested switch IDs and existing fronter IDs to check for existing switches
# Lists, because order matters, it makes sense to just swap fronters
member_ids = [member.id for member in members]
fronter_ids = (await utils.get_fronter_ids(ctx.conn, ctx.system.id))[0]
if member_ids == fronter_ids:
if len(members) == 1:
raise CommandError("{} is already fronting.".format(members[0].name))
raise CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members])))
# Also make sure there aren't any duplicates
if len(set(member_ids)) != len(member_ids):
raise CommandError("Duplicate members in switch list.")
# Log the switch
async with ctx.conn.transaction():
switch_id = await db.add_switch(ctx.conn, system_id=ctx.system.id)
for member in members:
await db.add_switch_member(ctx.conn, switch_id=switch_id, member_id=member.id)
if len(members) == 1:
return "Switch registered. Current fronter is now {}.".format(members[0].name)
else:
return "Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members]))
@command(cmd="switch out", description="Registers a switch with no one in front.", category="Switching commands")
async def switch_out(ctx: MemberCommandContext, args: List[str]):
# Get current fronters
fronters, _ = await utils.get_fronter_ids(ctx.conn, system_id=ctx.system.id)
if not fronters:
raise CommandError("There's already no one in front.")
# Log it, and don't log any members
await db.add_switch(ctx.conn, system_id=ctx.system.id)
return "Switch-out registered."
@command(cmd="switch move", usage="<time>", description="Moves the most recent switch to a different point in time.", category="Switching commands")
async def switch_move(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
# Parse the time to move to
new_time = dateparser.parse(" ".join(args), languages=["en"], settings={
"TO_TIMEZONE": "UTC",
"RETURN_AS_TIMEZONE_AWARE": False
})
if not new_time:
raise CommandError("{} can't be parsed as a valid time.".format(" ".join(args)))
# Make sure the time isn't in the future
if new_time > datetime.now():
raise CommandError("Can't move switch to a time in the future.")
# Make sure it all runs in a big transaction for atomicity
async with ctx.conn.transaction():
# Get the last two switches to make sure the switch to move isn't before the second-last switch
last_two_switches = await utils.get_front_history(ctx.conn, ctx.system.id, count=2)
if len(last_two_switches) == 0:
raise CommandError("There are no registered switches for this system.")
last_timestamp, last_fronters = last_two_switches[0]
if len(last_two_switches) > 1:
second_last_timestamp, _ = last_two_switches[1]
if new_time < second_last_timestamp:
time_str = humanize.naturaltime(second_last_timestamp)
raise CommandError("Can't move switch to before last switch time ({}), as it would cause conflicts.".format(time_str))
# Display the confirmation message w/ humanized times
members = ", ".join([member.name for member in last_fronters]) or "nobody"
last_absolute = last_timestamp.isoformat(sep=" ", timespec="seconds")
last_relative = humanize.naturaltime(last_timestamp)
new_absolute = new_time.isoformat(sep=" ", timespec="seconds")
new_relative = humanize.naturaltime(new_time)
embed = utils.make_default_embed("This will move the latest switch ({}) from {} ({}) to {} ({}). Is this OK?".format(members, last_absolute, last_relative, new_absolute, new_relative))
# Await and handle confirmation reactions
confirm_msg = await ctx.reply(embed=embed)
await ctx.client.add_reaction(confirm_msg, "")
await ctx.client.add_reaction(confirm_msg, "")
reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=confirm_msg, user=ctx.message.author, timeout=60.0)
if not reaction:
raise CommandError("Switch move timed out.")
if reaction.reaction.emoji == "":
raise CommandError("Switch move cancelled.")
# DB requires the actual switch ID which our utility method above doesn't return, do this manually
switch_id = (await db.front_history(ctx.conn, ctx.system.id, count=1))[0]["id"]
# Change the switch in the DB
await db.move_last_switch(ctx.conn, ctx.system.id, switch_id, new_time)
return "Switch moved."

View File

@@ -0,0 +1,215 @@
import logging
from typing import List
from urllib.parse import urlparse
import humanize
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="system", usage="[system]", description="Shows information about a system.", category="System commands", system_required=False)
async def system_info(ctx: CommandContext, args: List[str]):
if len(args) == 0:
if not ctx.system:
raise NoSystemRegistered()
system = ctx.system
else:
# Look one up
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
if system is None:
raise CommandError("Unable to find system \"{}\".".format(args[0]))
await ctx.reply(embed=await utils.generate_system_info_card(ctx.conn, ctx.client, system))
@command(cmd="system new", usage="[name]", description="Registers a new system to this account.", category="System commands", system_required=False)
async def new_system(ctx: CommandContext, args: List[str]):
if ctx.system:
raise CommandError("You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
system_name = None
if len(args) > 0:
system_name = " ".join(args)
async with ctx.conn.transaction():
# TODO: figure out what to do if this errors out on collision on generate_hid
hid = utils.generate_hid()
system = await db.create_system(ctx.conn, system_name=system_name, system_hid=hid)
# Link account
await db.link_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id)
return "System registered! To begin adding members, use `pk;member new <name>`."
@command(cmd="system set", usage="<name|description|tag|avatar> [value]", description="Edits a system property. Leave [value] blank to clear.", category="System commands")
async def system_set(ctx: CommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
allowed_properties = ["name", "description", "tag", "avatar"]
db_properties = {
"name": "name",
"description": "description",
"tag": "tag",
"avatar": "avatar_url"
}
prop = args[0]
if prop not in allowed_properties:
raise CommandError("Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)))
if len(args) >= 2:
value = " ".join(args[1:])
# Sanity checking
if prop == "tag":
if len(value) > 32:
raise CommandError("Can't have system tag longer than 32 characters.")
# Make sure there are no members which would make the combined length exceed 32
members_exceeding = await db.get_members_exceeding(ctx.conn, system_id=ctx.system.id, length=32 - len(value) - 1)
if len(members_exceeding) > 0:
# If so, error out and warn
member_names = ", ".join([member.name
for member in members_exceeding])
logger.debug("Members exceeding combined length with tag '{}': {}".format(value, member_names))
raise CommandError("The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names))
if prop == "avatar":
user = await utils.parse_mention(ctx.client, value)
if user:
# Set the avatar to the mentioned user's avatar
# Discord doesn't like webp, but also hosts png alternatives
value = user.avatar_url.replace(".webp", ".png")
else:
# Validate URL
u = urlparse(value)
if u.scheme in ["http", "https"] and u.netloc and u.path:
value = value
else:
raise CommandError("Invalid URL.")
else:
# Clear from DB
value = None
db_prop = db_properties[prop]
await db.update_system_field(ctx.conn, system_id=ctx.system.id, field=db_prop, value=value)
response = utils.make_default_embed("{} system {}.".format("Updated" if value else "Cleared", prop))
if prop == "avatar" and value:
response.set_image(url=value)
return response
@command(cmd="system link", usage="<account>", description="Links another account to your system.", category="System commands")
async def system_link(ctx: CommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
# Find account to link
linkee = await utils.parse_mention(ctx.client, args[0])
if not linkee:
raise CommandError("Account not found.")
# Make sure account doesn't already have a system
account_system = await db.get_system_by_account(ctx.conn, linkee.id)
if account_system:
raise CommandError("Account is already linked to a system (`{}`)".format(account_system.hid))
# Send confirmation message
msg = await ctx.reply("{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention))
await ctx.client.add_reaction(msg, "")
await ctx.client.add_reaction(msg, "")
reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=msg, user=linkee, timeout=60.0)
# If account to be linked confirms...
if not reaction:
raise CommandError("Account link timed out.")
if not reaction.reaction.emoji == "":
raise CommandError("Account link cancelled.")
await db.link_account(ctx.conn, system_id=ctx.system.id, account_id=linkee.id)
return "Account linked to system."
@command(cmd="system unlink", description="Unlinks your system from this account. There must be at least one other account linked.", category="System commands")
async def system_unlink(ctx: CommandContext, args: List[str]):
# Make sure you can't unlink every account
linked_accounts = await db.get_linked_accounts(ctx.conn, system_id=ctx.system.id)
if len(linked_accounts) == 1:
raise CommandError("This is the only account on your system, so you can't unlink it.")
await db.unlink_account(ctx.conn, system_id=ctx.system.id, account_id=ctx.message.author.id)
return "Account unlinked."
@command(cmd="system fronter", usage="[system]", description="Gets the current fronter(s) in the system.", category="Switching commands", system_required=False)
async def system_fronter(ctx: CommandContext, args: List[str]):
if len(args) == 0:
if not ctx.system:
raise NoSystemRegistered()
else:
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
if system is None:
raise CommandError("Can't find system \"{}\".".format(args[0]))
fronters, timestamp = await utils.get_fronters(ctx.conn, system_id=ctx.system.id)
fronter_names = [member.name for member in fronters]
embed = utils.make_default_embed(None)
if len(fronter_names) == 0:
embed.add_field(name="Current fronter", value="(no fronter)")
elif len(fronter_names) == 1:
embed.add_field(name="Current fronter", value=fronter_names[0])
else:
embed.add_field(name="Current fronters", value=", ".join(fronter_names))
if timestamp:
embed.add_field(name="Since", value="{} ({})".format(timestamp.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(timestamp)))
return embed
@command(cmd="system fronthistory", usage="[system]", description="Shows the past 10 switches in the system.", category="Switching commands", system_required=False)
async def system_fronthistory(ctx: CommandContext, args: List[str]):
if len(args) == 0:
if not ctx.system:
raise NoSystemRegistered()
else:
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
if system is None:
raise CommandError("Can't find system \"{}\".".format(args[0]))
lines = []
front_history = await utils.get_front_history(ctx.conn, ctx.system.id, count=10)
for i, (timestamp, members) in enumerate(front_history):
# Special case when no one's fronting
if len(members) == 0:
name = "(no fronter)"
else:
name = ", ".join([member.name for member in members])
# Make proper date string
time_text = timestamp.isoformat(sep=" ", timespec="seconds")
rel_text = humanize.naturaltime(timestamp)
delta_text = ""
if i > 0:
last_switch_time = front_history[i-1][0]
delta_text = ", for {}".format(humanize.naturaldelta(timestamp - last_switch_time))
lines.append("**{}** ({}, {}{})".format(name, time_text, rel_text, delta_text))
embed = utils.make_default_embed("\n".join(lines) or "(none)")
embed.title = "Past switches"
return embed
@command(cmd="system delete", description="Deletes your system from the database ***permanently***.", category="System commands")
async def system_delete(ctx: CommandContext, args: List[str]):
await ctx.reply("Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(ctx.system.hid))
msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0)
if msg and msg.content == ctx.system.hid:
await db.remove_system(ctx.conn, system_id=ctx.system.id)
return "System deleted."
else:
return "System deletion cancelled."

156
src/pluralkit/bot/help.py Normal file
View File

@@ -0,0 +1,156 @@
help_pages = {
None: [
(None,
"""PluralKit is a bot designed for plural communities on Discord. It allows you to register systems, maintain system information, set up message proxying, log switches, and more."""),
("Getting started",
"""To get started, set up a system with `pk;system new`. Then, inspect the other help pages for further instructions."""),
("Help categories",
"""`pk;help system` - Details on system configuration.
`pk;help member` - Details on member configuration.
`pk;help proxy` - Details on message proxying.
`pk;help switch` - Details on switch logging.
`pk;help mod` - Details on moderator operations.
`pk;help import` - Details on data import from other services."""),
("Discord",
"""For feedback, bug reports, suggestions, or just chatting, join our Discord: https://discord.gg/PczBt78""")
],
"system": [
("Registering a new system",
"""To use PluralKit, you must register a system for your account. You can use the `pk;system new` command for this. You can optionally add a system name after the command."""),
("Looking up a system",
"""To look up a system's details, you can use the `pk;system` command.
For example:
`pk;system` - Shows details of your own system.
`pk;system abcde` - Shows details of the system with the ID `abcde`.
`pk;system @JohnsAccount` - Shows details of the system linked to @JohnsAccount."""),
("Editing system properties",
"""You can use the `pk;system set` command to change your system properties. The properties you can change are name, description, and tag.
For example:
`pk;system set name My System` - sets your system name to "My System".
`pk;system set description A really cool system.` - sets your system description.
`pk;system set tag [MS]` - Sets the tag (which will be displayed after member names in messages) to "[MS]".
`pk;system set avatar https://placekitten.com/400/400` - Changes your system's avatar to a linked image.
If you don't specify any value, the property will be cleared."""),
("Linking accounts",
"""If your system has multiple accounts, you can link all of them to your system, and you can use the bot from all of those accounts.
For example:
`pk;system link @MyOtherAccount` - Links @MyOtherAccount to your system.
You'll need to confirm the link from the other account."""),
("Unlinking accounts",
"""If you need to unlink an account, you can do that with the `pk;system unlink` command.""")
],
"member": [
("Adding a new member",
"""To add a new member to your system, use the `pk;member new` command. You'll need to add a member name.
For example:
`pk;member new John`"""),
("Looking up a member",
"""To look up a member's details, you can use the `pk;member` command.
For example:
`pk;member John` - Shows details of the member in your system named John.
`pk;member abcde` - Shows details of the member with the ID `abcde`.
You can use member IDs to look up members in other systems."""),
("Editing member properties",
"""You can use the `pk;member set` command to change a member's properties. The properties you can change are name, description, color, pronouns, birthdate and avatar.
For example:
`pk;member set John name Joe` - Changes John's name to Joe.
`pk;member set John description Pretty cool dude.` - Changes John's description.
`pk;member set John color #ff0000` - Changes John's color to red.
`pk;member set John pronouns he/him` - Changes John's pronouns.
`pk;member set John birthdate 1996-02-27` - Changes John's birthdate to Feb 27, 1996. (Must be YYYY-MM-DD format).
`pk;member set John birthdate 02-27` - Changes John's birthdate to February 27th, with no year.
`pk;member set John avatar https://placekitten.com/400/400` - Changes John's avatar to a linked image.
`pk;member set John avatar @JohnsAccount` - Changes John's avatar to the avatar of the mentioned account.
If you don't specify any value, the property will be cleared."""),
("Removing a member",
"""If you want to delete a member, you can use the `pk;member delete` command.
For example:
`pk;member delete John`
You will need to confirm the deletion.""")
],
"proxy": [
("Setting up member proxying",
"""To register a member for proxying, use the `pk;member proxy` command.
For example:
`pk;member proxy John [text]` - Registers John to use [square brackets] as tags.
`pk;member proxy John J:text` - Registers John to use the prefix "J:".
After setting proxy tags, you can use them in any message, and they'll be interpreted by the bot and proxied appropriately."""),
("Setting your system tag",
"""To set your system tag, use the `pk;system set tag` command.
The tag is appended to the name of all proxied messages.
For example:
`pk;system set tag [MS]` - Sets your system tag to "[MS]".
`pk;system set tag :heart:` - Sets your system tag to the heart emoji.
Note you can only use default Discord emojis, not custom server emojis."""),
("Looking up a message",
"""You can look up a message by its ID using the `pk;message` command.
For example:
`pk;message 467638937402212352` - Shows information about the message by that ID.
To get a message ID, turn on Developer Mode in your client's Appearance settings, right click, and press "Copy ID"."""),
("Deleting messages",
"""You can delete your own messages by reacting with the ❌ emoji on it. Note that this only works on messages sent from your account.""")
],
"switch": [
("Registering a switch",
"""To log a switch in your system, use the `pk;switch` command.
For example:
`pk;switch John` - Registers a switch with John as fronter.
`pk;switch John Jill` - Registers a switch John and Jill as co-fronters."""),
("Switching out",
"""You can use the `pk;switch out` command to register a switch with no one in front."""),
("Moving a switch",
"""You can move the latest switch you have registered using the `pk;switch move` command.
This is useful if you log the switch a while after it happened, and you want to properly backdate it in the history.
For example:
`pk;switch move 10 minutes ago` - Moves the latest switch to 10 minutes ago
`pk;switch move 11pm EST` - Moves the latest switch to 11pm EST
Note that you can't move the switch further back than the second-last logged switch, and you can't move a switch to a time in the future.
The default time zone for absolute times is UTC, but you can specify other time zones in the command itself, as given in the example."""),
("Viewing fronting history",
"""To view front history, you can use the `pk;system fronter` and `pk;system fronthistory` commands.
For example:
`pk;system fronter` - Shows the current fronter(s) in your own system.
`pk;system fronter abcde` - Shows the current fronter in the system with the ID `abcde`.
`pk;system fronthistory` - Shows the past 10 switches in your own system.
`pk;system fronthistory @JohnsAccount` - Shows the past 10 switches in the system linked to @JohnsAccount.""")
],
"mod": [
(None, "Note that all moderation commands require you to have administrator privileges on the server they're used on."),
("Setting up a logging channel",
"""To designate a channel for the bot to log posted messages to, use the `pk;mod log` command.
For example:
`pk;mod log #message-log` - Configures the bot to log to #message-log.""")
],
"import": [
("Importing from Tupperware",
"""If you already have a registered system on Tupperware, you can use the `pk;import tupperware` command to import it into PluralKit.
Note the command only works on a server and channel where the Tupperware bot is already present.""")
]
}

297
src/pluralkit/bot/proxy.py Normal file
View File

@@ -0,0 +1,297 @@
import ciso8601
import logging
import re
from typing import List, Optional
import aiohttp
import discord
from pluralkit import db
from pluralkit.bot import channel_logger, utils
logger = logging.getLogger("pluralkit.bot.proxy")
def extract_leading_mentions(message_text):
# This regex matches one or more mentions at the start of a message, separated by any amount of spaces
match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message_text)
if not match:
return message_text, ""
# Return the text after the mentions, and the mentions themselves
return message_text[match.span(0)[1]:].strip(), match.group(0)
def match_member_proxy_tags(member: db.ProxyMember, message_text: str):
# Skip members with no defined proxy tags
if not member.prefix and not member.suffix:
return None
# DB defines empty prefix/suffixes as None, replace with empty strings to prevent errors
prefix = member.prefix or ""
suffix = member.suffix or ""
# Ignore mentions at the very start of the message, and match proxy tags after those
message_text, leading_mentions = extract_leading_mentions(message_text)
logger.debug("Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions, prefix, suffix))
if message_text.startswith(member.prefix or "") and message_text.endswith(member.suffix or ""):
prefix_length = len(prefix)
suffix_length = len(suffix)
# If suffix_length is 0, the last bit of the slice will be "-0", and the slice will fail
if suffix_length > 0:
inner_string = message_text[prefix_length:-suffix_length]
else:
inner_string = message_text[prefix_length:]
# Add the mentions we stripped back
inner_string = leading_mentions + inner_string
return inner_string
def match_proxy_tags(members: List[db.ProxyMember], message_text: str):
# Sort by specificity (members with both prefix and suffix go higher)
# This will make sure more "precise" proxy tags get tried first
members: List[db.ProxyMember] = sorted(members, key=lambda x: int(
bool(x.prefix)) + int(bool(x.suffix)), reverse=True)
for member in members:
match = match_member_proxy_tags(member, message_text)
if match is not None: # Using "is not None" because an empty string is OK here too
logger.debug("Matched member {} with inner text '{}'".format(member.hid, match))
return member, match
def get_message_attachment_url(message: discord.Message):
if not message.attachments:
return None
attachment = message.attachments[0]
if "proxy_url" in attachment:
return attachment["proxy_url"]
if "url" in attachment:
return attachment["url"]
# TODO: possibly move this to bot __init__ so commands can access it too
class WebhookPermissionError(Exception):
pass
class DeletionPermissionError(Exception):
pass
class Proxy:
def __init__(self, client: discord.Client, token: str, logger: channel_logger.ChannelLogger):
self.logger = logging.getLogger("pluralkit.bot.proxy")
self.session = aiohttp.ClientSession()
self.client = client
self.token = token
self.channel_logger = logger
async def save_channel_webhook(self, conn, channel: discord.Channel, id: str, token: str) -> (str, str):
await db.add_webhook(conn, channel.id, id, token)
return id, token
async def create_and_add_channel_webhook(self, conn, channel: discord.Channel) -> (str, str):
# This method is only called if there's no webhook found in the DB (and hopefully within a transaction)
# No need to worry about error handling if there's a DB conflict (which will throw an exception because DB constraints)
req_headers = {"Authorization": "Bot {}".format(self.token)}
# First, check if there's already a webhook belonging to the bot
async with self.session.get("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
headers=req_headers) as resp:
if resp.status == 200:
webhooks = await resp.json()
for webhook in webhooks:
if webhook["user"]["id"] == self.client.user.id:
# This webhook belongs to us, we can use that, return it and save it
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
elif resp.status == 403:
self.logger.warning(
"Did not have permission to fetch webhook list (server={}, channel={})".format(channel.server.id,
channel.id))
raise WebhookPermissionError()
else:
raise discord.HTTPException(resp, await resp.text())
# Then, try submitting a new one
req_data = {"name": "PluralKit Proxy Webhook"}
async with self.session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
json=req_data, headers=req_headers) as resp:
if resp.status == 200:
webhook = await resp.json()
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
elif resp.status == 403:
self.logger.warning(
"Did not have permission to create webhook (server={}, channel={})".format(channel.server.id,
channel.id))
raise WebhookPermissionError()
else:
raise discord.HTTPException(resp, await resp.text())
# Should not be reached without an exception being thrown
async def get_webhook_for_channel(self, conn, channel: discord.Channel):
async with conn.transaction():
hook_match = await db.get_webhook(conn, channel.id)
if not hook_match:
# We don't have a webhook, create/add one
return await self.create_and_add_channel_webhook(conn, channel)
else:
return hook_match
async def do_proxy_message(self, conn, member: db.ProxyMember, original_message: discord.Message, text: str,
attachment_url: str, has_already_retried=False):
hook_id, hook_token = await self.get_webhook_for_channel(conn, original_message.channel)
form_data = aiohttp.FormData()
form_data.add_field("username", "{} {}".format(member.name, member.tag or "").strip())
if text:
form_data.add_field("content", text)
if attachment_url:
attachment_resp = await self.session.get(attachment_url)
form_data.add_field("file", attachment_resp.content, content_type=attachment_resp.content_type,
filename=attachment_resp.url.name)
if member.avatar_url:
form_data.add_field("avatar_url", member.avatar_url)
async with self.session.post(
"https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token),
data=form_data) as resp:
if resp.status == 200:
message = await resp.json()
await db.add_message(conn, message["id"], message["channel_id"], member.id, original_message.author.id,
text or "")
try:
await self.client.delete_message(original_message)
except discord.Forbidden:
self.logger.warning(
"Did not have permission to delete original message (server={}, channel={})".format(
original_message.server.id, original_message.channel.id))
raise DeletionPermissionError()
except discord.NotFound:
self.logger.warning("Tried to delete message when proxying, but message was already gone (server={}, channel={})".format(original_message.server.id, original_message.channel.id))
message_image = None
if message["attachments"]:
first_attachment = message["attachments"][0]
if "width" in first_attachment and "height" in first_attachment:
# Only log attachments that are actually images
message_image = first_attachment["url"]
await self.channel_logger.log_message_proxied(conn,
server_id=original_message.server.id,
channel_name=original_message.channel.name,
channel_id=original_message.channel.id,
sender_name=original_message.author.name,
sender_disc=original_message.author.discriminator,
member_name=member.name,
member_hid=member.hid,
member_avatar_url=member.avatar_url,
system_name=member.system_name,
system_hid=member.system_hid,
message_text=text,
message_image=message_image,
message_timestamp=ciso8601.parse_datetime(
message["timestamp"]),
message_id=message["id"])
elif resp.status == 404 and not has_already_retried:
# Webhook doesn't exist. Delete it from the DB, create, and add a new one
self.logger.warning("Webhook registered in DB doesn't exist, deleting hook from DB, re-adding, and trying again (channel={}, hook={})".format(original_message.channel.id, hook_id))
await db.delete_webhook(conn, original_message.channel.id)
await self.create_and_add_channel_webhook(conn, original_message.channel)
# Then try again all over, making sure to not retry again and go in a loop should it continually fail
return await self.do_proxy_message(conn, member, original_message, text, attachment_url, has_already_retried=True)
else:
raise discord.HTTPException(resp, await resp.text())
async def try_proxy_message(self, conn, message: discord.Message):
# Can't proxy in DMs, webhook creation will explode
if message.channel.is_private:
return False
# Big fat query to find every member associated with this account
# Returned member object has a few more keys (system tag, for example)
members = await db.get_members_by_account(conn, account_id=message.author.id)
match = match_proxy_tags(members, message.content)
if not match:
return False
member, text = match
attachment_url = get_message_attachment_url(message)
# Can't proxy a message with no text AND no attachment
if not text and not attachment_url:
self.logger.debug("Skipping message because of no text and no attachment")
return False
try:
async with conn.transaction():
await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url)
except WebhookPermissionError:
embed = utils.make_error_embed("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed)
except DeletionPermissionError:
embed = utils.make_error_embed("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed)
return True
async def try_delete_message(self, conn, message_id: str, check_user_id: Optional[str], delete_message: bool, deleted_by_moderator: bool):
async with conn.transaction():
# Find the message in the DB, and make sure it's sent by the user (if we need to check)
if check_user_id:
db_message = await db.get_message_by_sender_and_id(conn, message_id=message_id, sender_id=check_user_id)
else:
db_message = await db.get_message(conn, message_id=message_id)
if db_message:
self.logger.debug("Deleting message {}".format(message_id))
channel = self.client.get_channel(str(db_message.channel))
# If we should also delete the actual message, do that
if delete_message:
message = await self.client.get_message(channel, message_id)
try:
await self.client.delete_message(message)
except discord.Forbidden:
self.logger.warning(
"Did not have permission to remove message, aborting deletion (server={}, channel={})".format(
channel.server.id, channel.id))
return
# Remove it from the DB
await db.delete_message(conn, message_id)
# Then log deletion to logging channel
await self.channel_logger.log_message_deleted(conn,
server_id=channel.server.id,
channel_name=channel.name,
member_name=db_message.name,
member_hid=db_message.hid,
member_avatar_url=db_message.avatar_url,
system_name=db_message.system_name,
system_hid=db_message.system_hid,
message_text=db_message.content,
message_id=message_id,
deleted_by_moderator=deleted_by_moderator)
async def handle_reaction(self, conn, user_id: str, message_id: str, emoji: str):
if emoji == "":
await self.try_delete_message(conn, message_id, check_user_id=user_id, delete_message=True, deleted_by_moderator=False)
async def handle_deletion(self, conn, message_id: str):
# Don't delete the message, it's already gone at this point, just handle DB deletion and logging
await self.try_delete_message(conn, message_id, check_user_id=None, delete_message=False, deleted_by_moderator=True)

219
src/pluralkit/bot/utils.py Normal file
View File

@@ -0,0 +1,219 @@
from datetime import datetime
import logging
import random
import re
from typing import List, Tuple
import string
import asyncio
import asyncpg
import discord
import humanize
from pluralkit import System, Member, db
logger = logging.getLogger("pluralkit.utils")
def escape(s):
return s.replace("`", "\\`")
def generate_hid() -> str:
return "".join(random.choices(string.ascii_lowercase, k=5))
def bounds_check_member_name(new_name, system_tag):
if len(new_name) > 32:
return "Name cannot be longer than 32 characters."
if system_tag:
if len("{} {}".format(new_name, system_tag)) > 32:
return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(system_tag)
async def parse_mention(client: discord.Client, mention: str) -> discord.User:
# First try matching mention format
match = re.fullmatch("<@!?(\\d+)>", mention)
if match:
try:
return await client.get_user_info(match.group(1))
except discord.NotFound:
return None
# Then try with just ID
try:
return await client.get_user_info(str(int(mention)))
except (ValueError, discord.NotFound):
return None
def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel:
match = re.fullmatch("<#(\\d+)>", mention)
if match:
return server.get_channel(match.group(1))
try:
return server.get_channel(str(int(mention)))
except ValueError:
return None
async def get_fronter_ids(conn, system_id) -> (List[int], datetime):
switches = await db.front_history(conn, system_id=system_id, count=1)
if not switches:
return [], None
if not switches[0]["members"]:
return [], switches[0]["timestamp"]
return switches[0]["members"], switches[0]["timestamp"]
async def get_fronters(conn, system_id) -> (List[Member], datetime):
member_ids, timestamp = await get_fronter_ids(conn, system_id)
# Collect in dict and then look up as list, to preserve return order
members = {member.id: member for member in await db.get_members(conn, member_ids)}
return [members[member_id] for member_id in member_ids], timestamp
async def get_front_history(conn, system_id, count) -> List[Tuple[datetime, List[Member]]]:
# Get history from DB
switches = await db.front_history(conn, system_id=system_id, count=count)
if not switches:
return []
# Get all unique IDs referenced
all_member_ids = {id for switch in switches for id in switch["members"]}
# And look them up in the database into a dict
all_members = {member.id: member for member in await db.get_members(conn, list(all_member_ids))}
# Collect in array and return
out = []
for switch in switches:
timestamp = switch["timestamp"]
members = [all_members[id] for id in switch["members"]]
out.append((timestamp, members))
return out
async def get_system_fuzzy(conn, client: discord.Client, key) -> System:
if isinstance(key, discord.User):
return await db.get_system_by_account(conn, account_id=key.id)
if isinstance(key, str) and len(key) == 5:
return await db.get_system_by_hid(conn, system_hid=key)
account = await parse_mention(client, key)
if account:
system = await db.get_system_by_account(conn, account_id=account.id)
if system:
return system
return None
async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> Member:
# First search by hid
if system_only:
member = await db.get_member_by_hid_in_system(conn, system_id=system_id, member_hid=key)
else:
member = await db.get_member_by_hid(conn, member_hid=key)
if member is not None:
return member
# Then search by name, if we have a system
if system_id:
member = await db.get_member_by_name(conn, system_id=system_id, member_name=key)
if member is not None:
return member
def make_default_embed(message):
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.description = message
return embed
def make_error_embed(message):
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.description = message
return embed
async def generate_system_info_card(conn, client: discord.Client, system: System) -> discord.Embed:
card = discord.Embed()
card.colour = discord.Colour.blue()
if system.name:
card.title = system.name
if system.avatar_url:
card.set_thumbnail(url=system.avatar_url)
if system.tag:
card.add_field(name="Tag", value=system.tag)
fronters, switch_time = await get_fronters(conn, system.id)
if fronters:
names = ", ".join([member.name for member in fronters])
fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switch_time))
card.add_field(name="Current fronter" if len(fronters) == 1 else "Current fronters", value=fronter_val)
account_names = []
for account_id in await db.get_linked_accounts(conn, system_id=system.id):
account = await client.get_user_info(account_id)
account_names.append("{}#{}".format(account.name, account.discriminator))
card.add_field(name="Linked accounts", value="\n".join(account_names))
if system.description:
card.add_field(name="Description",
value=system.description, inline=False)
# Get names of all members
member_texts = []
for member in await db.get_all_members(conn, system_id=system.id):
member_texts.append("{} (`{}`)".format(escape(member.name), member.hid))
if len(member_texts) > 0:
card.add_field(name="Members", value="\n".join(
member_texts), inline=False)
card.set_footer(text="System ID: {}".format(system.hid))
return card
async def generate_member_info_card(conn, member: Member) -> discord.Embed:
system = await db.get_system(conn, system_id=member.system)
card = discord.Embed()
card.colour = discord.Colour.blue()
name_and_system = member.name
if system.name:
name_and_system += " ({})".format(system.name)
card.set_author(name=name_and_system, icon_url=member.avatar_url or discord.Embed.Empty)
if member.avatar_url:
card.set_thumbnail(url=member.avatar_url)
# Get system name and hid
system = await db.get_system(conn, system_id=member.system)
if member.color:
card.colour = int(member.color, 16)
if member.birthday:
bday_val = member.birthday.strftime("%b %d, %Y")
if member.birthday.year == 1:
bday_val = member.birthday.strftime("%b %d")
card.add_field(name="Birthdate", value=bday_val)
if member.pronouns:
card.add_field(name="Pronouns", value=member.pronouns)
if member.prefix or member.suffix:
prefix = member.prefix or ""
suffix = member.suffix or ""
card.add_field(name="Proxy Tags",
value="{}text{}".format(prefix, suffix))
if member.description:
card.add_field(name="Description",
value=member.description, inline=False)
card.set_footer(text="System ID: {} | Member ID: {}".format(
system.hid, member.hid))
return card

358
src/pluralkit/db.py Normal file
View File

@@ -0,0 +1,358 @@
from collections import namedtuple
from datetime import datetime
import logging
from typing import List
import time
import asyncpg
import asyncpg.exceptions
from pluralkit import System, Member, stats
logger = logging.getLogger("pluralkit.db")
async def connect():
while True:
try:
return await asyncpg.create_pool(user="postgres", password="postgres", database="postgres", host="db")
except (ConnectionError, asyncpg.exceptions.CannotConnectNowError):
pass
def db_wrap(func):
async def inner(*args, **kwargs):
before = time.perf_counter()
try:
res = await func(*args, **kwargs)
after = time.perf_counter()
logger.debug(" - DB call {} took {:.2f} ms".format(func.__name__, (after - before) * 1000))
await stats.report_db_query(func.__name__, after - before, True)
return res
except asyncpg.exceptions.PostgresError:
await stats.report_db_query(func.__name__, time.perf_counter() - before, False)
logger.exception("Error from database query {}".format(func.__name__))
return inner
@db_wrap
async def create_system(conn, system_name: str, system_hid: str) -> System:
logger.debug("Creating system (name={}, hid={})".format(
system_name, system_hid))
row = await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid)
return System(**row) if row else None
@db_wrap
async def remove_system(conn, system_id: int):
logger.debug("Deleting system (id={})".format(system_id))
await conn.execute("delete from systems where id = $1", system_id)
@db_wrap
async def create_member(conn, system_id: int, member_name: str, member_hid: str) -> Member:
logger.debug("Creating member (system={}, name={}, hid={})".format(
system_id, member_name, member_hid))
row = await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid)
return Member(**row) if row else None
@db_wrap
async def delete_member(conn, member_id: int):
logger.debug("Deleting member (id={})".format(member_id))
await conn.execute("delete from members where id = $1", member_id)
@db_wrap
async def link_account(conn, system_id: int, account_id: str):
logger.debug("Linking account (account_id={}, system_id={})".format(
account_id, system_id))
await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id)
@db_wrap
async def unlink_account(conn, system_id: int, account_id: str):
logger.debug("Unlinking account (account_id={}, system_id={})".format(
account_id, system_id))
await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id)
@db_wrap
async def get_linked_accounts(conn, system_id: int) -> List[int]:
return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)]
@db_wrap
async def get_system_by_account(conn, account_id: str) -> System:
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id))
return System(**row) if row else None
@db_wrap
async def get_system_by_hid(conn, system_hid: str) -> System:
row = await conn.fetchrow("select * from systems where hid = $1", system_hid)
return System(**row) if row else None
@db_wrap
async def get_system(conn, system_id: int) -> System:
row = await conn.fetchrow("select * from systems where id = $1", system_id)
return System(**row) if row else None
@db_wrap
async def get_member_by_name(conn, system_id: int, member_name: str) -> Member:
row = await conn.fetchrow("select * from members where system = $1 and lower(name) = lower($2)", system_id, member_name)
return Member(**row) if row else None
@db_wrap
async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str) -> Member:
row = await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid)
return Member(**row) if row else None
@db_wrap
async def get_member_by_hid(conn, member_hid: str) -> Member:
row = await conn.fetchrow("select * from members where hid = $1", member_hid)
return Member(**row) if row else None
@db_wrap
async def get_member(conn, member_id: int) -> Member:
row = await conn.fetchrow("select * from members where id = $1", member_id)
return Member(**row) if row else None
@db_wrap
async def get_members(conn, members: list) -> List[Member]:
rows = await conn.fetch("select * from members where id = any($1)", members)
return [Member(**row) for row in rows]
@db_wrap
async def update_system_field(conn, system_id: int, field: str, value):
logger.debug("Updating system field (id={}, {}={})".format(
system_id, field, value))
await conn.execute("update systems set {} = $1 where id = $2".format(field), value, system_id)
@db_wrap
async def update_member_field(conn, member_id: int, field: str, value):
logger.debug("Updating member field (id={}, {}={})".format(
member_id, field, value))
await conn.execute("update members set {} = $1 where id = $2".format(field), value, member_id)
@db_wrap
async def get_all_members(conn, system_id: int) -> List[Member]:
rows = await conn.fetch("select * from members where system = $1", system_id)
return [Member(**row) for row in rows]
@db_wrap
async def get_members_exceeding(conn, system_id: int, length: int) -> List[Member]:
rows = await conn.fetch("select * from members where system = $1 and length(name) > $2", system_id, length)
return [Member(**row) for row in rows]
@db_wrap
async def get_webhook(conn, channel_id: str) -> (str, str):
row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
return (str(row["webhook"]), row["token"]) if row else None
@db_wrap
async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str):
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(
channel_id, webhook_id, webhook_token))
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token)
@db_wrap
async def delete_webhook(conn, channel_id: str):
await conn.execute("delete from webhooks where channel = $1", int(channel_id))
@db_wrap
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str, content: str):
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(
message_id, channel_id, member_id, sender_id))
await conn.execute("insert into messages (mid, channel, member, sender, content) values ($1, $2, $3, $4, $5)", int(message_id), int(channel_id), member_id, int(sender_id), content)
class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "color", "name", "avatar_url", "tag", "system_name", "system_hid"])):
id: int
hid: str
prefix: str
suffix: str
color: str
name: str
avatar_url: str
tag: str
system_name: str
system_hid: str
@db_wrap
async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]:
# Returns a "chimera" object
rows = await conn.fetch("""select
members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url,
systems.tag, systems.name as system_name, systems.hid as system_hid
from
systems, members, accounts
where
accounts.uid = $1
and systems.id = accounts.system
and members.system = systems.id""", int(account_id))
return [ProxyMember(**row) for row in rows]
class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "content", "sender", "name", "hid", "avatar_url", "system_name", "system_hid"])):
mid: int
channel: int
member: int
content: str
sender: int
name: str
hid: str
avatar_url: str
system_name: str
system_hid: str
@db_wrap
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) -> MessageInfo:
row = await conn.fetchrow("""select
messages.*,
members.name, members.hid, members.avatar_url,
systems.name as system_name, systems.hid as system_hid
from
messages, members, systems
where
messages.member = members.id
and members.system = systems.id
and mid = $1
and sender = $2""", int(message_id), int(sender_id))
return MessageInfo(**row) if row else None
@db_wrap
async def get_message(conn, message_id: str) -> MessageInfo:
row = await conn.fetchrow("""select
messages.*,
members.name, members.hid, members.avatar_url,
systems.name as system_name, systems.hid as system_hid
from
messages, members, systems
where
messages.member = members.id
and members.system = systems.id
and mid = $1""", int(message_id))
return MessageInfo(**row) if row else None
@db_wrap
async def delete_message(conn, message_id: str):
logger.debug("Deleting message (id={})".format(message_id))
await conn.execute("delete from messages where mid = $1", int(message_id))
@db_wrap
async def front_history(conn, system_id: int, count: int):
return await conn.fetch("""select
switches.*,
array(
select member from switch_members
where switch_members.switch = switches.id
order by switch_members.id asc
) as members
from switches
where switches.system = $1
order by switches.timestamp desc
limit $2""", system_id, count)
@db_wrap
async def add_switch(conn, system_id: int):
logger.debug("Adding switch (system={})".format(system_id))
res = await conn.fetchrow("insert into switches (system) values ($1) returning *", system_id)
return res["id"]
@db_wrap
async def move_last_switch(conn, system_id: int, switch_id: int, new_time: datetime):
logger.debug("Moving latest switch (system={}, id={}, new_time={})".format(system_id, switch_id, new_time))
await conn.execute("update switches set timestamp = $1 where system = $2 and id = $3", new_time, system_id, switch_id)
@db_wrap
async def add_switch_member(conn, switch_id: int, member_id: int):
logger.debug("Adding switch member (switch={}, member={})".format(switch_id, member_id))
await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id)
@db_wrap
async def get_server_info(conn, server_id: str):
return await conn.fetchrow("select * from servers where id = $1", int(server_id))
@db_wrap
async def update_server(conn, server_id: str, logging_channel_id: str):
logging_channel_id = int(logging_channel_id) if logging_channel_id else None
logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id))
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id)
@db_wrap
async def member_count(conn) -> int:
return await conn.fetchval("select count(*) from members")
@db_wrap
async def system_count(conn) -> int:
return await conn.fetchval("select count(*) from systems")
@db_wrap
async def message_count(conn) -> int:
return await conn.fetchval("select count(*) from messages")
@db_wrap
async def account_count(conn) -> int:
return await conn.fetchval("select count(*) from accounts")
async def create_tables(conn):
await conn.execute("""create table if not exists systems (
id serial primary key,
hid char(5) unique not null,
name text,
description text,
tag text,
avatar_url text,
created timestamp not null default current_timestamp
)""")
await conn.execute("""create table if not exists members (
id serial primary key,
hid char(5) unique not null,
system serial not null references systems(id) on delete cascade,
color char(6),
avatar_url text,
name text not null,
birthday date,
pronouns text,
description text,
prefix text,
suffix text,
created timestamp not null default current_timestamp
)""")
await conn.execute("""create table if not exists accounts (
uid bigint primary key,
system serial not null references systems(id) on delete cascade
)""")
await conn.execute("""create table if not exists messages (
mid bigint primary key,
channel bigint not null,
member serial not null references members(id) on delete cascade,
content text not null,
sender bigint not null
)""")
await conn.execute("""create table if not exists switches (
id serial primary key,
system serial not null references systems(id) on delete cascade,
timestamp timestamp not null default current_timestamp
)""")
await conn.execute("""create table if not exists switch_members (
id serial primary key,
switch serial not null references switches(id) on delete cascade,
member serial not null references members(id) on delete cascade
)""")
await conn.execute("""create table if not exists webhooks (
channel bigint primary key,
webhook bigint not null,
token text not null
)""")
await conn.execute("""create table if not exists servers (
id bigint primary key,
log_channel bigint
)""")

45
src/pluralkit/stats.py Normal file
View File

@@ -0,0 +1,45 @@
from aioinflux import InfluxDBClient
client = None
async def connect():
global client
client = InfluxDBClient(host="influx", db="pluralkit")
await client.create_database(db="pluralkit")
async def report_db_query(query_name, time, success):
await client.write({
"measurement": "database_query",
"tags": {"query": query_name},
"fields": {"response_time": time, "success": int(success)}
})
async def report_command(command_name, execution_time, response_time):
await client.write({
"measurement": "command",
"tags": {"command": command_name},
"fields": {"execution_time": execution_time, "response_time": response_time}
})
async def report_webhook(time, success):
await client.write({
"measurement": "webhook",
"fields": {"response_time": time, "success": int(success)}
})
async def report_periodical_stats(conn):
from pluralkit import db
systems = await db.system_count(conn)
members = await db.member_count(conn)
messages = await db.message_count(conn)
accounts = await db.account_count(conn)
await client.write({
"measurement": "stats",
"fields": {
"systems": systems,
"members": members,
"messages": messages,
"accounts": accounts
}
})

7
src/requirements.txt Normal file
View File

@@ -0,0 +1,7 @@
aiohttp
aioinflux
asyncpg
dateparser
discord.py
humanize
uvloop