Massive refactor of pretty much everything in the bot

This commit is contained in:
Ske 2018-07-24 22:47:57 +02:00
parent 086fa84b4b
commit 8936029dc8
27 changed files with 1799 additions and 1450 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
.env .env
.vscode/ .vscode/
.idea/

View File

@ -1 +0,0 @@
from . import commands, db, proxy, stats

View File

@ -1,122 +0,0 @@
import asyncio
from datetime import datetime
import logging
import json
import os
import time
import discord
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
logging.getLogger("discord").setLevel(logging.INFO)
logging.getLogger("websockets").setLevel(logging.INFO)
logger = logging.getLogger("pluralkit.bot")
logger.setLevel(logging.INFO)
client = discord.Client()
@client.event
async def on_error(evt, *args, **kwargs):
logger.exception(
"Error while handling event {} with arguments {}:".format(evt, args))
@client.event
async def on_ready():
# Print status info
logger.info("Connected to Discord.")
logger.info("Account: {}#{}".format(
client.user.name, client.user.discriminator))
logger.info("User ID: {}".format(client.user.id))
@client.event
async def on_message(message):
# Ignore bot messages
if message.author.bot:
return
# Split into args. shlex sucks so we don't bother with quotes
args = message.content.split(" ")
from pluralkit import proxy, utils, stats
command_items = utils.command_map.items()
command_items = sorted(command_items, key=lambda x: len(x[0]), reverse=True)
prefix = "pk;"
for command, (func, _, _, _) in command_items:
if message.content.lower().startswith(prefix + command):
args_str = message.content[len(prefix + command):].strip()
args = args_str.split(" ")
# Splitting on empty string yields one-element array, remove that
if len(args) == 1 and not args[0]:
args = []
async with client.pool.acquire() as conn:
time_before = time.perf_counter()
await func(conn, message, args)
time_after = time.perf_counter()
# Report command time stats
execution_time = time_after - time_before
response_time = (datetime.now() - message.timestamp).total_seconds()
await stats.report_command(command, execution_time, response_time)
return
# Try doing proxy parsing
async with client.pool.acquire() as conn:
await proxy.handle_proxying(conn, message)
@client.event
async def on_socket_raw_receive(msg):
# Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
# we parse socket data manually for the reaction add event
if isinstance(msg, str):
try:
msg_data = json.loads(msg)
if msg_data.get("t") == "MESSAGE_REACTION_ADD":
evt_data = msg_data.get("d")
if evt_data:
user_id = evt_data["user_id"]
message_id = evt_data["message_id"]
emoji = evt_data["emoji"]["name"]
async with client.pool.acquire() as conn:
from pluralkit import proxy
await proxy.handle_reaction(conn, user_id, message_id, emoji)
except ValueError:
pass
async def periodical_stat_timer(pool):
async with pool.acquire() as conn:
while True:
from pluralkit import stats
await stats.report_periodical_stats(conn)
await asyncio.sleep(30)
async def run():
from pluralkit import db, stats
try:
logger.info("Connecting to database...")
pool = await db.connect()
logger.info("Attempting to create tables...")
async with pool.acquire() as conn:
await db.create_tables(conn)
logger.info("Connecting to InfluxDB...")
await stats.connect()
logger.info("Starting periodical stat reporting...")
asyncio.get_event_loop().create_task(periodical_stat_timer(pool))
client.pool = pool
logger.info("Connecting to Discord...")
await client.start(os.environ["TOKEN"])
finally:
logger.info("Logging out from Discord...")
await client.logout()

View File

@ -1,752 +0,0 @@
from datetime import datetime, timezone
import io
import itertools
import json
import os
import re
from urllib.parse import urlparse
import dateparser
import discord
from discord.utils import oauth_url
import humanize
from pluralkit import db
from pluralkit.bot import client, logger
from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed, parse_channel_mention, bounds_check_member_name, get_fronters, get_fronter_ids, get_front_history
@command(cmd="system", usage="[system]", description="Shows information about a system.", category="System commands")
async def system_info(conn, message, args):
if len(args) == 0:
# Use sender's system
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
else:
# Look one up
system = await get_system_fuzzy(conn, args[0])
if system is None:
return False, "Unable to find system \"{}\".".format(args[0])
await client.send_message(message.channel, embed=await generate_system_info_card(conn, system))
return True
@command(cmd="system new", usage="[name]", description="Registers a new system to this account.", category="System commands")
async def new_system(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is not None:
return False, "You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`."
system_name = None
if len(args) > 0:
system_name = " ".join(args)
async with conn.transaction():
# TODO: figure out what to do if this errors out on collision on generate_hid
hid = generate_hid()
system = await db.create_system(conn, system_name=system_name, system_hid=hid)
# Link account
await db.link_account(conn, system_id=system["id"], account_id=message.author.id)
return True, "System registered! To begin adding members, use `pk;member new <name>`."
@command(cmd="system set", usage="<name|description|tag|avatar> [value]", description="Edits a system property. Leave [value] blank to clear.", category="System commands")
async def system_set(conn, message, args):
if len(args) == 0:
return False
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
allowed_properties = ["name", "description", "tag", "avatar"]
db_properties = {
"name": "name",
"description": "description",
"tag": "tag",
"avatar": "avatar_url"
}
prop = args[0]
if prop not in allowed_properties:
return False, "Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties))
if len(args) >= 2:
value = " ".join(args[1:])
# Sanity checking
if prop == "tag":
# Make sure there are no members which would make the combined length exceed 32
members_exceeding = await db.get_members_exceeding(conn, system_id=system["id"], length=32 - len(value))
if len(members_exceeding) > 0:
# If so, error out and warn
member_names = ", ".join([member["name"]
for member in members_exceeding])
logger.debug("Members exceeding combined length with tag '{}': {}".format(value, member_names))
return False, "The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names)
if prop == "avatar":
user = await parse_mention(value)
if user:
# Set the avatar to the mentioned user's avatar
# Discord doesn't like webp, but also hosts png alternatives
value = user.avatar_url.replace(".webp", ".png")
else:
# Validate URL
u = urlparse(value)
if u.scheme in ["http", "https"] and u.netloc and u.path:
value = value
else:
return False, "Invalid URL."
else:
# Clear from DB
value = None
db_prop = db_properties[prop]
await db.update_system_field(conn, system_id=system["id"], field=db_prop, value=value)
response = make_default_embed("{} system {}.".format("Updated" if value else "Cleared", prop))
if prop == "avatar" and value:
response.set_image(url=value)
return True, response
@command(cmd="system link", usage="<account>", description="Links another account to your system.", category="System commands")
async def system_link(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
if len(args) == 0:
return False
# Find account to link
linkee = await parse_mention(args[0])
if not linkee:
return False, "Account not found."
# Make sure account doesn't already have a system
account_system = await db.get_system_by_account(conn, linkee.id)
if account_system:
return False, "Account is already linked to a system (`{}`)".format(account_system["hid"])
# Send confirmation message
msg = await client.send_message(message.channel, "{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention))
await client.add_reaction(msg, "")
await client.add_reaction(msg, "")
reaction = await client.wait_for_reaction(emoji=["", ""], message=msg, user=linkee)
# If account to be linked confirms...
if reaction.reaction.emoji == "":
async with conn.transaction():
# Execute the link
await db.link_account(conn, system_id=system["id"], account_id=linkee.id)
return True, "Account linked to system."
else:
await client.clear_reactions(msg)
return False, "Account link cancelled."
@command(cmd="system unlink", description="Unlinks your system from this account. There must be at least one other account linked.", category="System commands")
async def system_unlink(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
# Make sure you can't unlink every account
linked_accounts = await db.get_linked_accounts(conn, system_id=system["id"])
if len(linked_accounts) == 1:
return False, "This is the only account on your system, so you can't unlink it."
async with conn.transaction():
await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id)
return True, "Account unlinked."
@command(cmd="system fronter", usage="[system]", description="Gets the current fronter(s) in the system.", category="Switching commands")
async def system_fronter(conn, message, args):
if len(args) == 0:
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
else:
system = await get_system_fuzzy(conn, args[0])
if system is None:
return False, "Can't find system \"{}\".".format(args[0])
fronters, timestamp = await get_fronters(conn, system_id=system["id"])
fronter_names = [member["name"] for member in fronters]
embed = make_default_embed(None)
if len(fronter_names) == 0:
embed.add_field(name="Current fronter", value="(no fronter)")
elif len(fronter_names) == 1:
embed.add_field(name="Current fronter", value=fronter_names[0])
else:
embed.add_field(name="Current fronters", value=", ".join(fronter_names))
if timestamp:
embed.add_field(name="Since", value="{} ({})".format(timestamp.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(timestamp)))
return True, embed
@command(cmd="system fronthistory", usage="[system]", description="Shows the past 10 switches in the system.", category="Switching commands")
async def system_fronthistory(conn, message, args):
if len(args) == 0:
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
else:
system = await get_system_fuzzy(conn, args[0])
if system is None:
return False, "Can't find system \"{}\".".format(args[0])
lines = []
front_history = await get_front_history(conn, system["id"], count=10)
for i, (timestamp, members) in enumerate(front_history):
# Special case when no one's fronting
if len(members) == 0:
name = "(no fronter)"
else:
name = ", ".join([member["name"] for member in members])
# Make proper date string
time_text = timestamp.isoformat(sep=" ", timespec="seconds")
rel_text = humanize.naturaltime(timestamp)
delta_text = ""
if i > 0:
last_switch_time = front_history[i-1][0]
delta_text = ", for {}".format(humanize.naturaldelta(timestamp - last_switch_time))
lines.append("**{}** ({}, {}{})".format(name, time_text, rel_text, delta_text))
embed = make_default_embed("\n".join(lines) or "(none)")
embed.title = "Past switches"
return True, embed
@command(cmd="system delete", description="Deletes your system from the database ***permanently***.", category="System commands")
async def system_delete(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
await client.send_message(message.channel, "Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(system["hid"]))
msg = await client.wait_for_message(author=message.author, channel=message.channel, timeout=60.0)
if msg and msg.content == system["hid"]:
await db.remove_system(conn, system_id=system["id"])
return True, "System deleted."
else:
return True, "System deletion cancelled."
@member_command(cmd="member", description="Shows information about a system member.", system_only=False, category="Member commands")
async def member_info(conn, message, member, args):
await client.send_message(message.channel, embed=await generate_member_info_card(conn, member))
return True
@command(cmd="member new", usage="<name>", description="Adds a new member to your system.", category="Member commands")
async def new_member(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
if len(args) == 0:
return False
name = " ".join(args)
bounds_error = bounds_check_member_name(name, system["tag"])
if bounds_error:
return False, bounds_error
async with conn.transaction():
# TODO: figure out what to do if this errors out on collision on generate_hid
hid = generate_hid()
# Insert member row
await db.create_member(conn, system_id=system["id"], member_name=name, member_hid=hid)
return True, "Member \"{}\" (`{}`) registered!".format(name, hid)
@member_command(cmd="member set", usage="<name|description|color|pronouns|birthdate|avatar> [value]", description="Edits a member property. Leave [value] blank to clear.", category="Member commands")
async def member_set(conn, message, member, args):
if len(args) == 0:
return False
allowed_properties = ["name", "description", "color", "pronouns", "birthdate", "avatar"]
db_properties = {
"name": "name",
"description": "description",
"color": "color",
"pronouns": "pronouns",
"birthdate": "birthday",
"avatar": "avatar_url"
}
prop = args[0]
if prop not in allowed_properties:
return False, "Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties))
if len(args) >= 2:
value = " ".join(args[1:])
# Sanity/validity checks and type conversions
if prop == "name":
system = await db.get_system(conn, member["system"])
bounds_error = bounds_check_member_name(value, system["tag"])
if bounds_error:
return False, bounds_error
if prop == "color":
match = re.fullmatch("#?([0-9A-Fa-f]{6})", value)
if not match:
return False, "Color must be a valid hex color (eg. #ff0000)"
value = match.group(1).lower()
if prop == "birthdate":
try:
value = datetime.strptime(value, "%Y-%m-%d").date()
except ValueError:
try:
# Try again, adding 0001 as a placeholder year
# This is considered a "null year" and will be omitted from the info card
# Useful if you want your birthday to be displayed yearless.
value = value = datetime.strptime("0001-" + value, "%Y-%m-%d").date()
except ValueError:
return False, "Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25)."
if prop == "avatar":
user = await parse_mention(value)
if user:
# Set the avatar to the mentioned user's avatar
# Discord doesn't like webp, but also hosts png alternatives
value = user.avatar_url.replace(".webp", ".png")
else:
# Validate URL
u = urlparse(value)
if u.scheme in ["http", "https"] and u.netloc and u.path:
value = value
else:
return False, "Invalid URL."
else:
# Can't clear member name
if prop == "name":
return False, "Can't clear member name."
# Clear from DB
value = None
db_prop = db_properties[prop]
await db.update_member_field(conn, member_id=member["id"], field=db_prop, value=value)
response = make_default_embed("{} {}'s {}.".format("Updated" if value else "Cleared", member["name"], prop))
if prop == "avatar" and value:
response.set_image(url=value)
if prop == "color" and value:
response.colour = int(value, 16)
return True, response
@member_command(cmd="member proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).", category="Member commands")
async def member_proxy(conn, message, member, args):
if len(args) == 0:
prefix, suffix = None, None
else:
# Sanity checking
example = " ".join(args)
if "text" not in example:
return False, "Example proxy message must contain the string 'text'."
if example.count("text") != 1:
return False, "Example proxy message must contain the string 'text' exactly once."
# Extract prefix and suffix
prefix = example[:example.index("text")].strip()
suffix = example[example.index("text")+4:].strip()
logger.debug(
"Matched prefix '{}' and suffix '{}'".format(prefix, suffix))
# DB stores empty strings as None, make that work
if not prefix:
prefix = None
if not suffix:
suffix = None
async with conn.transaction():
await db.update_member_field(conn, member_id=member["id"], field="prefix", value=prefix)
await db.update_member_field(conn, member_id=member["id"], field="suffix", value=suffix)
return True, "Proxy settings updated." if prefix or suffix else "Proxy settings cleared."
@member_command("member delete", description="Deletes a member from your system ***permanently***.", category="Member commands")
async def member_delete(conn, message, member, args):
await client.send_message(message.channel, "Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(member["name"], member["hid"]))
msg = await client.wait_for_message(author=message.author, channel=message.channel, timeout=60.0)
if msg and msg.content == member["hid"]:
await db.delete_member(conn, member_id=member["id"])
return True, "Member deleted."
else:
return True, "Member deletion cancelled."
@command(cmd="message", usage="<id>", description="Shows information about a proxied message. Requires the message ID.", category="Message commands")
async def message_info(conn, message, args):
try:
mid = int(args[0])
except ValueError:
return False
# Find the message in the DB
message_row = await db.get_message(conn, mid)
if not message_row:
return False, "Message not found."
# Get the original sender of the message
original_sender = await client.get_user_info(str(message_row["sender"]))
# Get sender member and system
member = await db.get_member(conn, message_row["member"])
system = await db.get_system(conn, member["system"])
embed = discord.Embed()
embed.timestamp = discord.utils.snowflake_time(str(mid))
embed.colour = discord.Colour.blue()
if system["name"]:
system_value = "{} (`{}`)".format(system["name"], system["hid"])
else:
system_value = "`{}`".format(system["hid"])
embed.add_field(name="System", value=system_value)
embed.add_field(name="Member", value="{}: (`{}`)".format(
member["name"], member["hid"]))
embed.add_field(name="Sent by", value="{}#{}".format(
original_sender.name, original_sender.discriminator))
embed.add_field(name="Content", value=message_row["content"], inline=False)
embed.set_author(name=member["name"], url=member["avatar_url"])
await client.send_message(message.channel, embed=embed)
return True
@command(cmd="switch", usage="<name|id> [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands")
async def switch_member(conn, message, args):
if len(args) == 0:
return False
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
members = []
for member_name in args:
# Find the member
member = await get_member_fuzzy(conn, system["id"], member_name)
if not member:
return False, "Couldn't find member \"{}\".".format(member_name)
members.append(member)
# Compare requested switch IDs and existing fronter IDs to check for existing switches
# Lists, because order matters, it makes sense to just swap fronters
member_ids = [member["id"] for member in members]
fronter_ids = (await get_fronter_ids(conn, system["id"]))[0]
if member_ids == fronter_ids:
if len(members) == 1:
return False, "{} is already fronting.".format(members[0]["name"])
return False, "Members {} are already fronting.".format(", ".join([m["name"] for m in members]))
# Log the switch
async with conn.transaction():
switch_id = await db.add_switch(conn, system_id=system["id"])
for member in members:
await db.add_switch_member(conn, switch_id=switch_id, member_id=member["id"])
if len(members) == 1:
return True, "Switch registered. Current fronter is now {}.".format(member["name"])
else:
return True, "Switch registered. Current fronters are now {}.".format(", ".join([m["name"] for m in members]))
@command(cmd="switch out", description="Registers a switch with no one in front.", category="Switching commands")
async def switch_out(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
# Get current fronters
fronters, _ = await get_fronter_ids(conn, system_id=system["id"])
if not fronters:
return False, "There's already no one in front."
# Log it, and don't log any members
await db.add_switch(conn, system_id=system["id"])
return True, "Switch-out registered."
@command(cmd="switch move", usage="<time>", description="Moves the most recent switch to a different point in time.", category="Switching commands")
async def switch_move(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
if len(args) == 0:
return False
# Parse the time to move to
new_time = dateparser.parse(" ".join(args), languages=["en"], settings={
"TO_TIMEZONE": "UTC",
"RETURN_AS_TIMEZONE_AWARE": False
})
if not new_time:
return False, "{} can't be parsed as a valid time.".format(" ".join(args))
# Make sure the time isn't in the future
if new_time > datetime.now():
return False, "Can't move switch to a time in the future."
# Make sure it all runs in a big transaction for atomicity
async with conn.transaction():
# Get the last two switches to make sure the switch to move isn't before the second-last switch
last_two_switches = await get_front_history(conn, system["id"], count=2)
if len(last_two_switches) == 0:
return False, "There are no registered switches for this system."
last_timestamp, last_fronters = last_two_switches[0]
if len(last_two_switches) > 1:
second_last_timestamp, _ = last_two_switches[1]
if new_time < second_last_timestamp:
time_str = humanize.naturaltime(second_last_timestamp)
return False, "Can't move switch to before last switch time ({}), as it would cause conflicts.".format(time_str)
# Display the confirmation message w/ humanized times
members = ", ".join([member["name"] for member in last_fronters])
last_absolute = last_timestamp.isoformat(sep=" ", timespec="seconds")
last_relative = humanize.naturaltime(last_timestamp)
new_absolute = new_time.isoformat(sep=" ", timespec="seconds")
new_relative = humanize.naturaltime(new_time)
embed = make_default_embed("This will move the latest switch ({}) from {} ({}) to {} ({}). Is this OK?".format(members, last_absolute, last_relative, new_absolute, new_relative))
# Await and handle confirmation reactions
confirm_msg = await client.send_message(message.channel, embed=embed)
await client.add_reaction(confirm_msg, "")
await client.add_reaction(confirm_msg, "")
reaction = await client.wait_for_reaction(emoji=["", ""], message=confirm_msg, user=message.author, timeout=60.0)
if not reaction:
return False, "Switch move timed out."
if reaction.reaction.emoji == "":
return False, "Switch move cancelled."
# DB requires the actual switch ID which our utility method above doesn't return, do this manually
switch_id = (await db.front_history(conn, system["id"], count=1))[0]["id"]
# Change the switch in the DB
await db.move_last_switch(conn, system["id"], switch_id, new_time)
return True, "Switch moved."
@command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands")
async def set_log(conn, message, args):
if not message.author.server_permissions.administrator:
return False, "You must be a server administrator to use this command."
server = message.server
if len(args) == 0:
channel_id = None
else:
channel = parse_channel_mention(args[0], server=server)
if not channel:
return False, "Channel not found."
channel_id = channel.id
await db.update_server(conn, server.id, logging_channel_id=channel_id)
return True, "Updated logging channel." if channel_id else "Cleared logging channel."
@command(cmd="help", usage="[system|member|proxy|switch|mod]", description="Shows help messages.")
async def show_help(conn, message, args):
embed = make_default_embed("")
embed.title = "PluralKit Help"
category = args[0] if len(args) > 0 else None
from pluralkit.help import help_pages
if category in help_pages:
for name, text in help_pages[category]:
if name:
embed.add_field(name=name, value=text)
else:
embed.description = text
else:
return False
return True, embed
@command(cmd="import tupperware", description="Import data from Tupperware.")
async def import_tupperware(conn, message, args):
tupperware_member = message.server.get_member("431544605209788416") or message.server.get_member("433916057053560832")
if not tupperware_member:
return False, "This command only works in a server where the Tupperware bot is also present."
channel_permissions = message.channel.permissions_for(tupperware_member)
if not (channel_permissions.read_messages and channel_permissions.send_messages):
return False, "This command only works in a channel where the Tupperware bot has read/send access."
await client.send_message(message.channel, embed=make_default_embed("Please reply to this message with `tul!list` (or the server equivalent)."))
# Check to make sure the Tupperware response actually belongs to the correct user
def ensure_account(tw_msg):
if not tw_msg.embeds:
return False
if not tw_msg.embeds[0]["title"]:
return False
return tw_msg.embeds[0]["title"].startswith("{}#{}".format(message.author.name, message.author.discriminator))
embeds = []
tw_msg = await client.wait_for_message(author=tupperware_member, channel=message.channel, timeout=60.0, check=ensure_account)
if not tw_msg:
return False, "Tupperware import timed out."
embeds.append(tw_msg.embeds[0])
# Handle Tupperware pagination
if tw_msg.embeds[0]["title"].endswith("(page 1)"):
while True:
# Wait for a new message (within 1 second)
tw_msg = await client.wait_for_message(author=tupperware_member, channel=message.channel, timeout=1.0, check=ensure_account)
if not tw_msg:
# If no message, then it's probably done, so we break
break
# Otherwise add this next message to the list
embeds.append(tw_msg.embeds[0])
logger.debug("Importing from Tupperware...")
# Create new (nameless) system if there isn't any registered
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
hid = generate_hid()
logger.debug("Creating new system (hid={})...".format(hid))
system = await db.create_system(conn, system_name=None, system_hid=hid)
await db.link_account(conn, system_id=system["id"], account_id=message.author.id)
for embed in embeds:
for field in embed["fields"]:
name = field["name"]
lines = field["value"].split("\n")
member_prefix = None
member_suffix = None
member_avatar = None
member_birthdate = None
member_description = None
for line in lines:
if line.startswith("Brackets:"):
brackets = line[len("Brackets: "):]
member_prefix = brackets[:brackets.index("text")].strip() or None
member_suffix = brackets[brackets.index("text")+4:].strip() or None
elif line.startswith("Avatar URL: "):
url = line[len("Avatar URL: "):]
member_avatar = url
elif line.startswith("Birthday: "):
bday_str = line[len("Birthday: "):]
bday = datetime.strptime(bday_str, "%a %b %d %Y")
if bday:
member_birthdate = bday.date()
elif line.startswith("Total messages sent: ") or line.startswith("Tag: "):
# Ignore this, just so it doesn't catch as the description
pass
else:
member_description = line
existing_member = await db.get_member_by_name(conn, system_id=system["id"], member_name=name)
if not existing_member:
hid = generate_hid()
logger.debug("Creating new member {} (hid={})...".format(name, hid))
existing_member = await db.create_member(conn, system_id=system["id"], member_name=name, member_hid=hid)
logger.debug("Updating fields...")
await db.update_member_field(conn, member_id=existing_member["id"], field="prefix", value=member_prefix)
await db.update_member_field(conn, member_id=existing_member["id"], field="suffix", value=member_suffix)
await db.update_member_field(conn, member_id=existing_member["id"], field="avatar_url", value=member_avatar)
await db.update_member_field(conn, member_id=existing_member["id"], field="birthday", value=member_birthdate)
await db.update_member_field(conn, member_id=existing_member["id"], field="description", value=member_description)
return True, "System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting."
@command(cmd="invite", description="Generates an invite link for this bot.")
async def invite_link(conn, message, args):
client_id = os.environ["CLIENT_ID"]
permissions = discord.Permissions()
permissions.manage_webhooks = True
permissions.send_messages = True
permissions.manage_messages = True
permissions.embed_links = True
permissions.attach_files = True
permissions.read_message_history = True
permissions.add_reactions = True
url = oauth_url(client_id, permissions)
logger.debug("Sending invite URL: {}".format(url))
return True, url
@command(cmd="export", description="Exports system data to a machine-readable format.")
async def export(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
members = await db.get_all_members(conn, system["id"])
accounts = await db.get_linked_accounts(conn, system["id"])
switches = await get_front_history(conn, system["id"], 999999)
data = {
"name": system["name"],
"id": system["hid"],
"description": system["description"],
"tag": system["tag"],
"avatar_url": system["avatar_url"],
"created": system["created"].isoformat(),
"members": [
{
"name": member["name"],
"id": member["hid"],
"color": member["color"],
"avatar_url": member["avatar_url"],
"birthday": member["birthday"].isoformat() if member["birthday"] else None,
"pronouns": member["pronouns"],
"description": member["description"],
"prefix": member["prefix"],
"suffix": member["suffix"],
"created": member["created"].isoformat()
} for member in members
],
"accounts": [str(uid) for uid in accounts],
"switches": [
{
"timestamp": timestamp.isoformat(),
"members": [member["hid"] for member in members]
} for timestamp, members in switches
]
}
f = io.BytesIO(json.dumps(data).encode("utf-8"))
await client.send_file(message.channel, f, filename="system.json")

View File

@ -1,213 +0,0 @@
import json
import os
import re
import time
import aiohttp
import discord
from pluralkit import db, stats
from pluralkit.bot import client, logger
def make_log_embed(hook_message, member, channel_name):
author_name = "#{}: {}".format(channel_name, member["name"])
if member["system_name"]:
author_name += " ({})".format(member["system_name"])
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.description = hook_message.clean_content
embed.timestamp = hook_message.timestamp
embed.set_author(name=author_name, icon_url=member["avatar_url"] or discord.Embed.Empty)
if len(hook_message.attachments) > 0:
embed.set_image(url=hook_message.attachments[0]["url"])
return embed
async def log_message(original_message, hook_message, member, log_channel):
# hook_message is kinda broken, and doesn't include details from server or channel
# We rely on the fact that original_message must be in the same channel, this'll break if that changes
embed = make_log_embed(hook_message, member, channel_name=original_message.channel.name)
embed.set_footer(text="System ID: {} | Member ID: {} | Sender: {}#{} | Message ID: {}".format(member["system_hid"], member["hid"], original_message.author.name, original_message.author.discriminator, hook_message.id))
message_link = "https://discordapp.com/channels/{}/{}/{}".format(original_message.server.id, original_message.channel.id, hook_message.id)
embed.author.url = message_link
try:
await client.send_message(log_channel, embed=embed)
except discord.errors.Forbidden:
# Ignore logging permission errors, perhaps make it spam a big nasty error instead
pass
async def log_delete(hook_message, member, log_channel):
embed = make_log_embed(hook_message, member, channel_name=hook_message.channel.name)
embed.set_footer(text="System ID: {} | Member ID: {} | Message ID: {}".format(member["system_hid"], member["hid"], hook_message.id))
embed.colour = discord.Colour.dark_red()
await client.send_message(log_channel, embed=embed)
async def get_log_channel(conn, server):
# Check server info for a log channel
server_info = await db.get_server_info(conn, server.id)
if server_info and server_info["log_channel"]:
channel = server.get_channel(str(server_info["log_channel"]))
return channel
async def get_webhook(conn, channel):
async with conn.transaction():
# Try to find an existing webhook
hook_row = await db.get_webhook(conn, channel_id=channel.id)
# There's none, we'll make one
if not hook_row:
async with aiohttp.ClientSession() as session:
req_data = {"name": "PluralKit Proxy Webhook"}
req_headers = {
"Authorization": "Bot {}".format(os.environ["TOKEN"])}
async with session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id), json=req_data, headers=req_headers) as resp:
data = await resp.json()
hook_id = data["id"]
token = data["token"]
# Insert new hook into DB
await db.add_webhook(conn, channel_id=channel.id, webhook_id=hook_id, webhook_token=token)
return hook_id, token
return hook_row["webhook"], hook_row["token"]
async def send_hook_message(member, hook_id, hook_token, text=None, image_url=None):
async with aiohttp.ClientSession() as session:
# Set up headers
req_headers = {
"Authorization": "Bot {}".format(os.environ["TOKEN"])
}
# Set up parameters
# Use FormData because the API doesn't like JSON requests with file data
fd = aiohttp.FormData()
fd.add_field("username", "{} {}".format(member["name"], member["tag"] or "").strip())
if member["avatar_url"]:
fd.add_field("avatar_url", member["avatar_url"])
if text:
fd.add_field("content", text)
if image_url:
# Fetch the image URL and proxy it directly into the file data (async streaming!)
image_resp = await session.get(image_url)
fd.add_field("file", image_resp.content, content_type=image_resp.content_type, filename=image_resp.url.name)
# Send the actual webhook request, and wait for a response
time_before = time.perf_counter()
try:
async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token),
data=fd,
headers=req_headers) as resp:
if resp.status == 200:
resp_data = await resp.json()
# Make a fake message object for passing on - this is slightly broken but works for most things
msg = discord.Message(reactions=[], **resp_data)
# Report to stats
await stats.report_webhook(time.perf_counter() - time_before, True)
return msg
else:
await stats.report_webhook(time.perf_counter() - time_before, False)
# Fake a Discord exception, also because #yolo
raise discord.HTTPException(resp, await resp.text())
except aiohttp.ClientResponseError:
await stats.report_webhook(time.perf_counter() - time_before, False)
logger.exception("Error while sending webhook message")
async def proxy_message(conn, member, trigger_message, inner):
logger.debug("Proxying message '{}' for member {}".format(inner, member["hid"]))
logger.info("[{}#{}] {}".format(member["system_hid"], member["hid"], inner))
# Get the webhook details
hook_id, hook_token = await get_webhook(conn, trigger_message.channel)
# Get attachment image URL if present (only works for one...)
image_urls = [a["url"] for a in trigger_message.attachments if "url" in a]
image_url = image_urls[0] if len(image_urls) > 0 else None
# Send the hook message
hook_message = await send_hook_message(member, hook_id, hook_token, text=inner, image_url=image_url)
# Insert new message details into the DB
await db.add_message(conn, message_id=hook_message.id, channel_id=trigger_message.channel.id, member_id=member["id"], sender_id=trigger_message.author.id, content=inner)
# Log message to logging channel if necessary
log_channel = await get_log_channel(conn, trigger_message.server)
if log_channel:
await log_message(trigger_message, hook_message, member, log_channel)
# Delete the original message
await client.delete_message(trigger_message)
async def handle_proxying(conn, message):
# Can't proxy in DMs, webhook creation will explode
if message.channel.is_private:
return
# Big fat query to find every member associated with this account
# Returned member object has a few more keys (system tag, for example)
members = await db.get_members_by_account(conn, account_id=message.author.id)
# Sort by specificity (members with both prefix and suffix go higher)
members = sorted(members, key=lambda x: int(
bool(x["prefix"])) + int(bool(x["suffix"])), reverse=True)
msg = message.content
for member in members:
# If no proxy details are configured, skip
if not member["prefix"] and not member["suffix"]:
continue
# Database stores empty strings as null, fix that here
prefix = member["prefix"] or ""
suffix = member["suffix"] or ""
# Avoid matching a prefix of "<" on a mention
if prefix == "<":
if re.match(r"^<(?:@|@!|#|@&|:\w+:|a:\w+:)\d+>", msg):
continue
# If we have a match, proxy the message
if msg.startswith(prefix) and msg.endswith(suffix):
# Extract the actual message contents sans tags
if suffix:
inner_message = msg[len(prefix):-len(suffix)].strip()
else:
# Slicing to -0 breaks, don't do that
inner_message = msg[len(prefix):].strip()
# Make sure the message isn't blank (but only if it has no attachments)
if inner_message or message.attachments:
await proxy_message(conn, member, message, inner_message)
break
async def handle_reaction(conn, user_id, message_id, emoji):
if emoji == "":
async with conn.transaction():
# Find the message in the DB, and make sure it's sent by the user who reacted
db_message = await db.get_message_by_sender_and_id(conn, message_id=message_id, sender_id=user_id)
if db_message:
logger.debug("Deleting message {} by reaction from {}".format(message_id, user_id))
# If so, remove it from the DB
await db.delete_message(conn, message_id)
# And look up the message and then delete it
channel = client.get_channel(str(db_message["channel"]))
message = await client.get_message(channel, message_id)
await client.delete_message(message)
# Log deletion to logging channel if necessary
log_channel = await get_log_channel(conn, message.server)
if log_channel:
# db_message contains enough member data for the things to work
await log_delete(message, db_message, log_channel)

View File

@ -1,304 +0,0 @@
import random
import re
import string
import asyncio
import asyncpg
import discord
import humanize
from pluralkit import db
from pluralkit.bot import client, logger
def escape(s):
return s.replace("`", "\\`")
def generate_hid() -> str:
return "".join(random.choices(string.ascii_lowercase, k=5))
def bounds_check_member_name(new_name, system_tag):
if len(new_name) > 32:
return "Name cannot be longer than 32 characters."
if system_tag:
if len("{} {}".format(new_name, system_tag)) > 32:
return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(system_tag)
async def parse_mention(mention: str) -> discord.User:
# First try matching mention format
match = re.fullmatch("<@!?(\\d+)>", mention)
if match:
try:
return await client.get_user_info(match.group(1))
except discord.NotFound:
return None
# Then try with just ID
try:
return await client.get_user_info(str(int(mention)))
except (ValueError, discord.NotFound):
return None
def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel:
match = re.fullmatch("<#(\\d+)>", mention)
if match:
return server.get_channel(match.group(1))
try:
return server.get_channel(str(int(mention)))
except ValueError:
return None
async def get_fronter_ids(conn, system_id):
switches = await db.front_history(conn, system_id=system_id, count=1)
if not switches:
return [], None
if not switches[0]["members"]:
return [], switches[0]["timestamp"]
return switches[0]["members"], switches[0]["timestamp"]
async def get_fronters(conn, system_id):
member_ids, timestamp = await get_fronter_ids(conn, system_id)
# Collect in dict and then look up as list, to preserve return order
members = {member["id"]: member for member in await db.get_members(conn, member_ids)}
return [members[member_id] for member_id in member_ids], timestamp
async def get_front_history(conn, system_id, count):
# Get history from DB
switches = await db.front_history(conn, system_id=system_id, count=count)
if not switches:
return []
# Get all unique IDs referenced
all_member_ids = {id for switch in switches for id in switch["members"]}
# And look them up in the database into a dict
all_members = {member["id"]: member for member in await db.get_members(conn, list(all_member_ids))}
# Collect in array and return
out = []
for switch in switches:
timestamp = switch["timestamp"]
members = [all_members[id] for id in switch["members"]]
out.append((timestamp, members))
return out
async def get_system_fuzzy(conn, key) -> asyncpg.Record:
if isinstance(key, discord.User):
return await db.get_system_by_account(conn, account_id=key.id)
if isinstance(key, str) and len(key) == 5:
return await db.get_system_by_hid(conn, system_hid=key)
account = await parse_mention(key)
if account:
system = await db.get_system_by_account(conn, account_id=account.id)
if system:
return system
return None
async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> asyncpg.Record:
# First search by hid
if system_only:
member = await db.get_member_by_hid_in_system(conn, system_id=system_id, member_hid=key)
else:
member = await db.get_member_by_hid(conn, member_hid=key)
if member is not None:
return member
# Then search by name, if we have a system
if system_id:
member = await db.get_member_by_name(conn, system_id=system_id, member_name=key)
if member is not None:
return member
def make_default_embed(message):
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.description = message
return embed
def make_error_embed(message):
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.description = message
return embed
command_map = {}
# Command wrapper
# Return True for success, return False for failure
# Second parameter is the message it'll send. If just False, will print usage
def command(cmd, usage=None, description=None, category=None):
def wrap(func):
async def wrapper(conn, message, args):
res = await func(conn, message, args)
if res is not None:
if not isinstance(res, tuple):
success, msg = res, None
else:
success, msg = res
if not success and not msg:
# Failure, no message, print usage
usage_str = "**Usage:** pk;{} {}".format(cmd, usage or "")
await client.send_message(message.channel, embed=make_default_embed(usage_str))
elif not success:
# Failure, print message
embed = msg if isinstance(msg, discord.Embed) else make_error_embed(msg)
# embed.set_footer(text="{:.02f} ms".format(time_ms))
await client.send_message(message.channel, embed=embed)
elif msg:
# Success, print message
embed = msg if isinstance(msg, discord.Embed) else make_default_embed(msg)
# embed.set_footer(text="{:.02f} ms".format(time_ms))
await client.send_message(message.channel, embed=embed)
# Success, don't print anything
# Put command in map
command_map[cmd] = (wrapper, usage, description, category)
return wrapper
return wrap
# Member command wrapper
# Tries to find member by first argument
# If system_only=False, allows members from other systems by hid
def member_command(cmd, usage=None, description=None, category=None, system_only=True):
def wrap(func):
async def wrapper(conn, message, args):
# Return if no member param
if len(args) == 0:
return False
# If system_only, we need a system to check
system = await db.get_system_by_account(conn, message.author.id)
if system_only and system is None:
return False, "No system is registered to this account."
# System is allowed to be none if not system_only
system_id = system["id"] if system else None
# And find member by key
member = await get_member_fuzzy(conn, system_id=system_id, key=args[0], system_only=system_only)
if member is None:
return False, "Can't find member \"{}\".".format(args[0])
return await func(conn, message, member, args[1:])
return command(cmd=cmd, usage="<name|id> {}".format(usage or ""), description=description, category=category)(wrapper)
return wrap
async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Embed:
card = discord.Embed()
card.colour = discord.Colour.blue()
if system["name"]:
card.title = system["name"]
if system["avatar_url"]:
card.set_thumbnail(url=system["avatar_url"])
if system["tag"]:
card.add_field(name="Tag", value=system["tag"])
fronters, switch_time = await get_fronters(conn, system["id"])
if fronters:
names = ", ".join([member["name"] for member in fronters])
fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switch_time))
card.add_field(name="Current fronter" if len(fronters) == 1 else "Current fronters", value=fronter_val)
account_names = []
for account_id in await db.get_linked_accounts(conn, system_id=system["id"]):
account = await client.get_user_info(account_id)
account_names.append("{}#{}".format(account.name, account.discriminator))
card.add_field(name="Linked accounts", value="\n".join(account_names))
if system["description"]:
card.add_field(name="Description",
value=system["description"], inline=False)
# Get names of all members
member_texts = []
for member in await db.get_all_members(conn, system_id=system["id"]):
member_texts.append("{} (`{}`)".format(escape(member["name"]), member["hid"]))
if len(member_texts) > 0:
card.add_field(name="Members", value="\n".join(
member_texts), inline=False)
card.set_footer(text="System ID: {}".format(system["hid"]))
return card
async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Embed:
system = await db.get_system(conn, system_id=member["system"])
card = discord.Embed()
card.colour = discord.Colour.blue()
name_and_system = member["name"]
if system["name"]:
name_and_system += " ({})".format(system["name"])
card.set_author(name=name_and_system, icon_url=member["avatar_url"] or discord.Embed.Empty)
if member["avatar_url"]:
card.set_thumbnail(url=member["avatar_url"])
# Get system name and hid
system = await db.get_system(conn, system_id=member["system"])
if member["color"]:
card.colour = int(member["color"], 16)
if member["birthday"]:
bday_val = member["birthday"].strftime("%b %d, %Y")
if member["birthday"].year == 1:
bday_val = member["birthday"].strftime("%b %d")
card.add_field(name="Birthdate", value=bday_val)
if member["pronouns"]:
card.add_field(name="Pronouns", value=member["pronouns"])
if member["prefix"] or member["suffix"]:
prefix = member["prefix"] or ""
suffix = member["suffix"] or ""
card.add_field(name="Proxy Tags",
value="{}text{}".format(prefix, suffix))
if member["description"]:
card.add_field(name="Description",
value=member["description"], inline=False)
card.set_footer(text="System ID: {} | Member ID: {}".format(
system["hid"], member["hid"]))
return card
async def text_input(message, subject):
embed = make_default_embed("")
embed.description = "Reply in this channel with the new description you want to set for {}.".format(subject)
status_msg = await client.send_message(message.channel, embed=embed)
reply_msg = await client.wait_for_message(author=message.author, channel=message.channel)
embed.description = "Alright. When you're happy with the new description, click the ✅ reaction. To cancel, click the ❌ reaction."
await client.edit_message(status_msg, embed=embed)
await client.add_reaction(reply_msg, "")
await client.add_reaction(reply_msg, "")
reaction = await client.wait_for_reaction(emoji=["", ""], message=reply_msg, user=message.author)
if reaction.reaction.emoji == "":
await client.clear_reactions(reply_msg)
return reply_msg.content
else:
await client.clear_reactions(reply_msg)
return None

View File

@ -1,7 +1,9 @@
version: '3' version: '3'
services: services:
bot: bot:
build: bot build:
context: src/
dockerfile: bot.Dockerfile
depends_on: depends_on:
- db - db
- influx - influx

View File

@ -7,5 +7,5 @@ ADD requirements.txt /app
RUN pip install --trusted-host pypi.python.org -r requirements.txt RUN pip install --trusted-host pypi.python.org -r requirements.txt
ADD . /app ADD . /app
ENTRYPOINT ["python", "main.py"] ENTRYPOINT ["python", "bot_main.py"]

View File

@ -1,9 +1,11 @@
import asyncio import asyncio
import os
import uvloop import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from pluralkit import bot from pluralkit import bot
pk = bot.PluralKitBot(os.environ["TOKEN"])
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(bot.run()) loop.run_until_complete(pk.run())

26
src/pluralkit/__init__.py Normal file
View File

@ -0,0 +1,26 @@
from collections import namedtuple
from datetime import date, datetime
class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "created"])):
id: int
hid: str
name: str
description: str
tag: str
avatar_url: str
created: datetime
class Member(namedtuple("Member", ["id", "hid", "system", "color", "avatar_url", "name", "birthday", "pronouns", "description", "prefix", "suffix", "created"])):
id: int
hid: str
system: int
color: str
avatar_url: str
name: str
birthday: date
pronouns: str
description: str
prefix: str
suffix: str
created: datetime

View File

@ -0,0 +1,131 @@
import asyncio
from datetime import datetime
import logging
import json
import os
import time
import discord
from pluralkit import db, stats
from pluralkit.bot import channel_logger, commands, proxy
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
logging.getLogger("pluralkit").setLevel(logging.DEBUG)
class PluralKitBot:
def __init__(self, token):
self.token = token
self.logger = logging.getLogger("pluralkit.bot")
self.client = discord.Client()
self.client.event(self.on_error)
self.client.event(self.on_ready)
self.client.event(self.on_message)
self.client.event(self.on_socket_raw_receive)
self.channel_logger = channel_logger.ChannelLogger(self.client)
self.proxy = proxy.Proxy(self.client, token, self.channel_logger)
async def on_error(self, evt, *args, **kwargs):
self.logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
async def on_ready(self):
self.logger.info("Connected to Discord.")
self.logger.info("- Account: {}#{}".format(self.client.user.name, self.client.user.discriminator))
self.logger.info("- User ID: {}".format(self.client.user.id))
self.logger.info("- {} servers".format(len(self.client.servers)))
async def on_message(self, message):
# Ignore bot messages
if message.author.bot:
return
if await self.handle_command_dispatch(message):
return
if await self.handle_proxy_dispatch(message):
return
async def on_socket_raw_receive(self, msg):
# Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
# we parse socket data manually for the reaction add event
if isinstance(msg, str):
try:
msg_data = json.loads(msg)
if msg_data.get("t") == "MESSAGE_REACTION_ADD":
evt_data = msg_data.get("d")
if evt_data:
user_id = evt_data["user_id"]
message_id = evt_data["message_id"]
emoji = evt_data["emoji"]["name"]
async with self.pool.acquire() as conn:
await self.proxy.handle_reaction(conn, user_id, message_id, emoji)
elif msg_data.get("t") == "MESSAGE_DELETE":
evt_data = msg_data.get("d")
if evt_data:
message_id = evt_data["id"]
async with self.pool.acquire() as conn:
await self.proxy.handle_deletion(conn, message_id)
except ValueError:
pass
async def handle_command_dispatch(self, message):
command_items = commands.command_list.items()
command_items = sorted(command_items, key=lambda x: len(x[0]), reverse=True)
prefix = "pk;"
for command_name, command in command_items:
if message.content.lower().startswith(prefix + command_name):
args_str = message.content[len(prefix + command_name):].strip()
args = args_str.split(" ")
# Splitting on empty string yields one-element array, remove that
if len(args) == 1 and not args[0]:
args = []
async with self.pool.acquire() as conn:
time_before = time.perf_counter()
await command.function(self.client, conn, message, args)
time_after = time.perf_counter()
# Report command time stats
execution_time = time_after - time_before
response_time = (datetime.now() - message.timestamp).total_seconds()
await stats.report_command(command_name, execution_time, response_time)
return True
async def handle_proxy_dispatch(self, message):
# Try doing proxy parsing
async with self.pool.acquire() as conn:
return await self.proxy.try_proxy_message(conn, message)
async def periodical_stat_timer(self, pool):
async with pool.acquire() as conn:
while True:
from pluralkit import stats
await stats.report_periodical_stats(conn)
await asyncio.sleep(30)
async def run(self):
try:
self.logger.info("Connecting to database...")
self.pool = await db.connect()
self.logger.info("Attempting to create tables...")
async with self.pool.acquire() as conn:
await db.create_tables(conn)
self.logger.info("Connecting to InfluxDB...")
await stats.connect()
self.logger.info("Starting periodical stat reporting...")
asyncio.get_event_loop().create_task(self.periodical_stat_timer(self.pool))
self.logger.info("Connecting to Discord...")
await self.client.start(self.token)
finally:
self.logger.info("Logging out from Discord...")
await self.client.logout()

View File

@ -0,0 +1,109 @@
import logging
from datetime import datetime
import discord
from pluralkit import db
def embed_set_author_name(embed: discord.Embed, channel_name: str, member_name: str, system_name: str, avatar_url: str):
name = "#{}: {}".format(channel_name, member_name)
if system_name:
name += " ({})".format(system_name)
embed.set_author(name=name, icon_url=avatar_url or discord.Embed.Empty)
class ChannelLogger:
def __init__(self, client: discord.Client):
self.logger = logging.getLogger("pluralkit.bot.channel_logger")
self.client = client
async def get_log_channel(self, conn, server_id: str):
server_info = await db.get_server_info(conn, server_id)
if not server_info:
return None
log_channel = server_info["log_channel"]
if not log_channel:
return None
return self.client.get_channel(str(log_channel))
async def send_to_log_channel(self, log_channel: discord.Channel, embed: discord.Embed):
try:
await self.client.send_message(log_channel, embed=embed)
except discord.Forbidden:
# TODO: spew big error
self.logger.warning(
"Did not have permission to send message to logging channel (server={}, channel={})".format(
log_channel.server.id, log_channel.id))
async def log_message_proxied(self, conn,
server_id: str,
channel_name: str,
channel_id: str,
sender_name: str,
sender_disc: int,
member_name: str,
member_hid: str,
member_avatar_url: str,
system_name: str,
system_hid: str,
message_text: str,
message_image: str,
message_timestamp: datetime,
message_id: str):
log_channel = await self.get_log_channel(conn, server_id)
if not log_channel:
return
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.description = message_text
embed.timestamp = message_timestamp
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
embed.set_footer(
text="System ID: {} | Member ID: {} | Sender: {}#{} | Message ID: {}".format(system_hid, member_hid,
sender_name, sender_disc,
message_id))
if message_image:
embed.set_thumbnail(url=message_image)
message_link = "https://discordapp.com/channels/{}/{}/{}".format(server_id, channel_id, message_id)
embed.author.url = message_link
await self.send_to_log_channel(log_channel, embed)
async def log_message_deleted(self, conn,
server_id: str,
channel_name: str,
member_name: str,
member_hid: str,
member_avatar_url: str,
system_name: str,
system_hid: str,
message_text: str,
message_id: str,
deleted_by_moderator: bool):
log_channel = await self.get_log_channel(conn, server_id)
if not log_channel:
return
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.description = message_text
embed.timestamp = datetime.utcnow()
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
embed.set_footer(
text="System ID: {} | Member ID: {} | Message ID: {} | Deleted by moderator? {}".format(system_hid,
member_hid,
message_id,
"Yes" if deleted_by_moderator else "No"))
await self.send_to_log_channel(log_channel, embed)

View File

@ -0,0 +1,98 @@
import logging
from collections import namedtuple
import asyncpg
import discord
import pluralkit
from pluralkit import db
from pluralkit.bot import utils
command_list = {}
class InvalidCommandSyntax(Exception):
pass
class NoSystemRegistered(Exception):
pass
class CommandError(Exception):
def __init__(self, message):
self.message = message
class CommandContext(namedtuple("CommandContext", ["client", "conn", "message", "system"])):
client: discord.Client
conn: asyncpg.Connection
message: discord.Message
system: pluralkit.System
async def reply(self, message=None, embed=None):
return await self.client.send_message(self.message.channel, message, embed=embed)
class MemberCommandContext(namedtuple("MemberCommandContext", CommandContext._fields + ("member",)), CommandContext):
client: discord.Client
conn: asyncpg.Connection
message: discord.Message
system: pluralkit.System
member: pluralkit.Member
class CommandEntry(namedtuple("CommandEntry", ["command", "function", "usage", "description", "category"])):
pass
def command(cmd, usage=None, description=None, category=None, system_required=True):
def wrap(func):
async def wrapper(client, conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system_required and system is None:
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account"))
return
ctx = CommandContext(client=client, conn=conn, message=message, system=system)
try:
res = await func(ctx, args)
if res:
embed = res if isinstance(res, discord.Embed) else utils.make_default_embed(res)
await client.send_message(message.channel, embed=embed)
except NoSystemRegistered:
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account"))
except InvalidCommandSyntax:
usage_str = "**Usage:** pk;{} {}".format(cmd, usage or "")
await client.send_message(message.channel, embed=utils.make_default_embed(usage_str))
except CommandError as e:
embed = e.message if isinstance(e.message, discord.Embed) else utils.make_error_embed(e.message)
await client.send_message(message.channel, embed=embed)
# Put command in map
command_list[cmd] = CommandEntry(command=cmd, function=wrapper, usage=usage, description=description, category=category)
return wrapper
return wrap
def member_command(cmd, usage=None, description=None, category=None, system_only=True):
def wrap(func):
async def wrapper(ctx: CommandContext, args):
# Return if no member param
if len(args) == 0:
raise InvalidCommandSyntax()
# System is allowed to be none if not system_only
system_id = ctx.system.id if ctx.system else None
# And find member by key
member = await utils.get_member_fuzzy(ctx.conn, system_id=system_id, key=args[0], system_only=system_only)
if member is None:
raise CommandError("Can't find member \"{}\".".format(args[0]))
ctx = MemberCommandContext(client=ctx.client, conn=ctx.conn, message=ctx.message, system=ctx.system, member=member)
return await func(ctx, args[1:])
return command(cmd=cmd, usage="<name|id> {}".format(usage or ""), description=description, category=category, system_required=False)(wrapper)
return wrap
import pluralkit.bot.commands.import_commands
import pluralkit.bot.commands.member_commands
import pluralkit.bot.commands.message_commands
import pluralkit.bot.commands.misc_commands
import pluralkit.bot.commands.mod_commands
import pluralkit.bot.commands.switch_commands
import pluralkit.bot.commands.system_commands

View File

@ -0,0 +1,143 @@
import asyncio
import re
from datetime import datetime
import logging
from typing import List
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="import tupperware", description="Import data from Tupperware.")
async def import_tupperware(ctx: CommandContext, args: List[str]):
tupperware_member = ctx.message.server.get_member("431544605209788416") or ctx.message.server.get_member("433916057053560832")
if not tupperware_member:
raise CommandError("This command only works in a server where the Tupperware bot is also present.")
channel_permissions = ctx.message.channel.permissions_for(tupperware_member)
if not (channel_permissions.read_messages and channel_permissions.send_messages):
raise CommandError("This command only works in a channel where the Tupperware bot has read/send access.")
await ctx.reply(embed=utils.make_default_embed("Please reply to this message with `tul!list` (or the server equivalent)."))
# Check to make sure the Tupperware response actually belongs to the correct user
def ensure_account(tw_msg):
if not tw_msg.embeds:
return False
if not tw_msg.embeds[0]["title"]:
return False
return tw_msg.embeds[0]["title"].startswith("{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator))
embeds = []
tw_msg: discord.Message = await ctx.client.wait_for_message(author=tupperware_member, channel=ctx.message.channel, timeout=60.0, check=ensure_account)
if not tw_msg:
raise CommandError("Tupperware import timed out.")
embeds.append(tw_msg.embeds[0])
# Handle Tupperware pagination
def match_pagination():
pagination_match = re.search(r"\(page (\d+)/(\d+), \d+ total\)", tw_msg.embeds[0]["title"])
if not pagination_match:
return None
return int(pagination_match.group(1)), int(pagination_match.group(2))
pagination_match = match_pagination()
if pagination_match:
status_msg = await ctx.reply("Multi-page member list found. Please manually scroll through all the pages.")
current_page = 0
total_pages = 1
pages_found = {}
# Keep trying to read the embed with new pages
last_found_time = datetime.utcnow()
while len(pages_found) < total_pages:
new_page, total_pages = match_pagination()
# Put the found page in the pages dict
pages_found[new_page] = dict(tw_msg.embeds[0])
# If this isn't the same page as last check, edit the status message
if new_page != current_page:
last_found_time = datetime.utcnow()
await ctx.client.edit_message(status_msg, "Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(len(pages_found), total_pages))
current_page = new_page
# And sleep a bit to prevent spamming the CPU
await asyncio.sleep(0.25)
# Make sure it doesn't spin here for too long, time out after 30 seconds since last new page
if (datetime.utcnow() - last_found_time).seconds > 30:
raise CommandError("Pagination scan timed out.")
# Now that we've got all the pages, put them in the embeds list
# Make sure to erase the original one we put in above too
embeds = list([embed for page, embed in sorted(pages_found.items(), key=lambda x: x[0])])
# Also edit the status message to indicate we're now importing, and it may take a while because there's probably a lot of members
await ctx.client.edit_message(status_msg, "All pages read. Now importing...")
logger.debug("Importing from Tupperware...")
# Create new (nameless) system if there isn't any registered
system = ctx.system
if system is None:
hid = utils.generate_hid()
logger.debug("Creating new system (hid={})...".format(hid))
system = await db.create_system(ctx.conn, system_name=None, system_hid=hid)
await db.link_account(ctx.conn, system_id=system["id"], account_id=ctx.message.author.id)
for embed in embeds:
for field in embed["fields"]:
name = field["name"]
lines = field["value"].split("\n")
member_prefix = None
member_suffix = None
member_avatar = None
member_birthdate = None
member_description = None
# Read the message format line by line
for line in lines:
if line.startswith("Brackets:"):
brackets = line[len("Brackets: "):]
member_prefix = brackets[:brackets.index("text")].strip() or None
member_suffix = brackets[brackets.index("text")+4:].strip() or None
elif line.startswith("Avatar URL: "):
url = line[len("Avatar URL: "):]
member_avatar = url
elif line.startswith("Birthday: "):
bday_str = line[len("Birthday: "):]
bday = datetime.strptime(bday_str, "%a %b %d %Y")
if bday:
member_birthdate = bday.date()
elif line.startswith("Total messages sent: ") or line.startswith("Tag: "):
# Ignore this, just so it doesn't catch as the description
pass
else:
member_description = line
# Read by name - TW doesn't allow name collisions so we're safe here (prevents dupes)
existing_member = await db.get_member_by_name(ctx.conn, system_id=system.id, member_name=name)
if not existing_member:
# Or create a new member
hid = utils.generate_hid()
logger.debug("Creating new member {} (hid={})...".format(name, hid))
existing_member = await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid)
# Save the new stuff in the DB
logger.debug("Updating fields...")
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="prefix", value=member_prefix)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="suffix", value=member_suffix)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="avatar_url", value=member_avatar)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="birthday", value=member_birthdate)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="description", value=member_description)
return "System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting."

View File

@ -0,0 +1,150 @@
import logging
import re
from datetime import datetime
from typing import List
from urllib.parse import urlparse
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@member_command(cmd="member", description="Shows information about a system member.", system_only=False, category="Member commands")
async def member_info(ctx: MemberCommandContext, args: List[str]):
await ctx.reply(embed=await utils.generate_member_info_card(ctx.conn, ctx.member))
@command(cmd="member new", usage="<name>", description="Adds a new member to your system.", category="Member commands")
async def new_member(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
name = " ".join(args)
bounds_error = utils.bounds_check_member_name(name, ctx.system.tag)
if bounds_error:
raise CommandError(bounds_error)
# TODO: figure out what to do if this errors out on collision on generate_hid
hid = utils.generate_hid()
# Insert member row
await db.create_member(ctx.conn, system_id=ctx.system.id, member_name=name, member_hid=hid)
return "Member \"{}\" (`{}`) registered!".format(name, hid)
@member_command(cmd="member set", usage="<name|description|color|pronouns|birthdate|avatar> [value]", description="Edits a member property. Leave [value] blank to clear.", category="Member commands")
async def member_set(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
allowed_properties = ["name", "description", "color", "pronouns", "birthdate", "avatar"]
db_properties = {
"name": "name",
"description": "description",
"color": "color",
"pronouns": "pronouns",
"birthdate": "birthday",
"avatar": "avatar_url"
}
prop = args[0]
if prop not in allowed_properties:
raise CommandError("Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)))
if len(args) >= 2:
value = " ".join(args[1:])
# Sanity/validity checks and type conversions
if prop == "name":
bounds_error = utils.bounds_check_member_name(value, ctx.system.tag)
if bounds_error:
raise CommandError(bounds_error)
if prop == "color":
match = re.fullmatch("#?([0-9A-Fa-f]{6})", value)
if not match:
raise CommandError("Color must be a valid hex color (eg. #ff0000)")
value = match.group(1).lower()
if prop == "birthdate":
try:
value = datetime.strptime(value, "%Y-%m-%d").date()
except ValueError:
try:
# Try again, adding 0001 as a placeholder year
# This is considered a "null year" and will be omitted from the info card
# Useful if you want your birthday to be displayed yearless.
value = datetime.strptime("0001-" + value, "%Y-%m-%d").date()
except ValueError:
raise CommandError("Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25).")
if prop == "avatar":
user = await utils.parse_mention(ctx.client, value)
if user:
# Set the avatar to the mentioned user's avatar
# Discord doesn't like webp, but also hosts png alternatives
value = user.avatar_url.replace(".webp", ".png")
else:
# Validate URL
u = urlparse(value)
if u.scheme in ["http", "https"] and u.netloc and u.path:
value = value
else:
raise CommandError("Invalid URL.")
else:
# Can't clear member name
if prop == "name":
raise CommandError("Can't clear member name.")
# Clear from DB
value = None
db_prop = db_properties[prop]
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field=db_prop, value=value)
response = utils.make_default_embed("{} {}'s {}.".format("Updated" if value else "Cleared", ctx.member.name, prop))
if prop == "avatar" and value:
response.set_image(url=value)
if prop == "color" and value:
response.colour = int(value, 16)
return response
@member_command(cmd="member proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).", category="Member commands")
async def member_proxy(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
prefix, suffix = None, None
else:
# Sanity checking
example = " ".join(args)
if "text" not in example:
raise CommandError("Example proxy message must contain the string 'text'.")
if example.count("text") != 1:
raise CommandError("Example proxy message must contain the string 'text' exactly once.")
# Extract prefix and suffix
prefix = example[:example.index("text")].strip()
suffix = example[example.index("text")+4:].strip()
logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix))
# DB stores empty strings as None, make that work
if not prefix:
prefix = None
if not suffix:
suffix = None
async with ctx.conn.transaction():
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="prefix", value=prefix)
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="suffix", value=suffix)
return "Proxy settings updated." if prefix or suffix else "Proxy settings cleared."
@member_command("member delete", description="Deletes a member from your system ***permanently***.", category="Member commands")
async def member_delete(ctx: MemberCommandContext, args: List[str]):
await ctx.reply("Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(ctx.member.name, ctx.member.hid))
msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0)
if msg and msg.content == ctx.member.hid:
await db.delete_member(ctx.conn, member_id=ctx.member.id)
return "Member deleted."
else:
return "Member deletion cancelled."

View File

@ -0,0 +1,57 @@
import logging
from typing import List
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="message", usage="<id>", description="Shows information about a proxied message. Requires the message ID.",
category="Message commands")
async def message_info(ctx: CommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
try:
mid = int(args[0])
except ValueError:
raise InvalidCommandSyntax()
# Find the message in the DB
message = await db.get_message(ctx.conn, str(mid))
if not message:
raise CommandError("Message not found.")
# Get the original sender of the messages
try:
original_sender = await ctx.client.get_user_info(str(message.sender))
except discord.NotFound:
# Account was since deleted - rare but we're handling it anyway
original_sender = None
embed = discord.Embed()
embed.timestamp = discord.utils.snowflake_time(str(mid))
embed.colour = discord.Colour.blue()
if message.system_name:
system_value = "{} (`{}`)".format(message.system_name, message.system_hid)
else:
system_value = "`{}`".format(message.system_hid)
embed.add_field(name="System", value=system_value)
embed.add_field(name="Member", value="{} (`{}`)".format(message.name, message.hid))
if original_sender:
sender_name = "{}#{}".format(original_sender.name, original_sender.discriminator)
else:
sender_name = "(deleted account {})".format(message.sender)
embed.add_field(name="Sent by", value=sender_name)
if message.content: # Content can be empty string if there's an attachment
embed.add_field(name="Content", value=message.content, inline=False)
embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty)
return embed

View File

@ -0,0 +1,89 @@
import io
import json
import logging
import os
from typing import List
from discord.utils import oauth_url
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="help", usage="[system|member|proxy|switch|mod]", description="Shows help messages.")
async def show_help(ctx: CommandContext, args: List[str]):
embed = utils.make_default_embed("")
embed.title = "PluralKit Help"
embed.set_footer(text="By Astrid (Ske#6201, or 'qoxvy' on PK) | GitHub: https://github.com/xSke/PluralKit/")
category = args[0] if len(args) > 0 else None
from pluralkit.bot.help import help_pages
if category in help_pages:
for name, text in help_pages[category]:
if name:
embed.add_field(name=name, value=text)
else:
embed.description = text
else:
raise InvalidCommandSyntax()
return embed
@command(cmd="invite", description="Generates an invite link for this bot.")
async def invite_link(ctx: CommandContext, args: List[str]):
client_id = os.environ["CLIENT_ID"]
permissions = discord.Permissions()
permissions.manage_webhooks = True
permissions.send_messages = True
permissions.manage_messages = True
permissions.embed_links = True
permissions.attach_files = True
permissions.read_message_history = True
permissions.add_reactions = True
url = oauth_url(client_id, permissions)
logger.debug("Sending invite URL: {}".format(url))
return url
@command(cmd="export", description="Exports system data to a machine-readable format.")
async def export(ctx: CommandContext, args: List[str]):
members = await db.get_all_members(ctx.conn, ctx.system.id)
accounts = await db.get_linked_accounts(ctx.conn, ctx.system.id)
switches = await utils.get_front_history(ctx.conn, ctx.system.id, 999999)
system = ctx.system
data = {
"name": system.name,
"id": system.hid,
"description": system.description,
"tag": system.tag,
"avatar_url": system.avatar_url,
"created": system.created.isoformat(),
"members": [
{
"name": member.name,
"id": member.hid,
"color": member.color,
"avatar_url": member.avatar_url,
"birthday": member.birthday.isoformat() if member.birthday else None,
"pronouns": member.pronouns,
"description": member.description,
"prefix": member.prefix,
"suffix": member.suffix,
"created": member.created.isoformat()
} for member in members
],
"accounts": [str(uid) for uid in accounts],
"switches": [
{
"timestamp": timestamp.isoformat(),
"members": [member.hid for member in members]
} for timestamp, members in switches
]
}
f = io.BytesIO(json.dumps(data).encode("utf-8"))
await ctx.client.send_file(ctx.message.channel, f, filename="system.json")

View File

@ -0,0 +1,24 @@
import logging
from typing import List
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands")
async def set_log(ctx: CommandContext, args: List[str]):
if not ctx.message.author.server_permissions.administrator:
raise CommandError("You must be a server administrator to use this command.")
server = ctx.message.server
if len(args) == 0:
channel_id = None
else:
channel = utils.parse_channel_mention(args[0], server=server)
if not channel:
raise CommandError("Channel not found.")
channel_id = channel.id
await db.update_server(ctx.conn, server.id, logging_channel_id=channel_id)
return "Updated logging channel." if channel_id else "Cleared logging channel."

View File

@ -0,0 +1,119 @@
from datetime import datetime
import logging
from typing import List
import dateparser
import humanize
from pluralkit import Member
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="switch", usage="<name|id> [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands")
async def switch_member(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
members: List[Member] = []
for member_name in args:
# Find the member
member = await utils.get_member_fuzzy(ctx.conn, ctx.system.id, member_name)
if not member:
raise CommandError("Couldn't find member \"{}\".".format(member_name))
members.append(member)
# Compare requested switch IDs and existing fronter IDs to check for existing switches
# Lists, because order matters, it makes sense to just swap fronters
member_ids = [member.id for member in members]
fronter_ids = (await utils.get_fronter_ids(ctx.conn, ctx.system.id))[0]
if member_ids == fronter_ids:
if len(members) == 1:
raise CommandError("{} is already fronting.".format(members[0].name))
raise CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members])))
# Also make sure there aren't any duplicates
if len(set(member_ids)) != len(member_ids):
raise CommandError("Duplicate members in switch list.")
# Log the switch
async with ctx.conn.transaction():
switch_id = await db.add_switch(ctx.conn, system_id=ctx.system.id)
for member in members:
await db.add_switch_member(ctx.conn, switch_id=switch_id, member_id=member.id)
if len(members) == 1:
return "Switch registered. Current fronter is now {}.".format(members[0].name)
else:
return "Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members]))
@command(cmd="switch out", description="Registers a switch with no one in front.", category="Switching commands")
async def switch_out(ctx: MemberCommandContext, args: List[str]):
# Get current fronters
fronters, _ = await utils.get_fronter_ids(ctx.conn, system_id=ctx.system.id)
if not fronters:
raise CommandError("There's already no one in front.")
# Log it, and don't log any members
await db.add_switch(ctx.conn, system_id=ctx.system.id)
return "Switch-out registered."
@command(cmd="switch move", usage="<time>", description="Moves the most recent switch to a different point in time.", category="Switching commands")
async def switch_move(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
# Parse the time to move to
new_time = dateparser.parse(" ".join(args), languages=["en"], settings={
"TO_TIMEZONE": "UTC",
"RETURN_AS_TIMEZONE_AWARE": False
})
if not new_time:
raise CommandError("{} can't be parsed as a valid time.".format(" ".join(args)))
# Make sure the time isn't in the future
if new_time > datetime.now():
raise CommandError("Can't move switch to a time in the future.")
# Make sure it all runs in a big transaction for atomicity
async with ctx.conn.transaction():
# Get the last two switches to make sure the switch to move isn't before the second-last switch
last_two_switches = await utils.get_front_history(ctx.conn, ctx.system.id, count=2)
if len(last_two_switches) == 0:
raise CommandError("There are no registered switches for this system.")
last_timestamp, last_fronters = last_two_switches[0]
if len(last_two_switches) > 1:
second_last_timestamp, _ = last_two_switches[1]
if new_time < second_last_timestamp:
time_str = humanize.naturaltime(second_last_timestamp)
raise CommandError("Can't move switch to before last switch time ({}), as it would cause conflicts.".format(time_str))
# Display the confirmation message w/ humanized times
members = ", ".join([member.name for member in last_fronters]) or "nobody"
last_absolute = last_timestamp.isoformat(sep=" ", timespec="seconds")
last_relative = humanize.naturaltime(last_timestamp)
new_absolute = new_time.isoformat(sep=" ", timespec="seconds")
new_relative = humanize.naturaltime(new_time)
embed = utils.make_default_embed("This will move the latest switch ({}) from {} ({}) to {} ({}). Is this OK?".format(members, last_absolute, last_relative, new_absolute, new_relative))
# Await and handle confirmation reactions
confirm_msg = await ctx.reply(embed=embed)
await ctx.client.add_reaction(confirm_msg, "")
await ctx.client.add_reaction(confirm_msg, "")
reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=confirm_msg, user=ctx.message.author, timeout=60.0)
if not reaction:
raise CommandError("Switch move timed out.")
if reaction.reaction.emoji == "":
raise CommandError("Switch move cancelled.")
# DB requires the actual switch ID which our utility method above doesn't return, do this manually
switch_id = (await db.front_history(ctx.conn, ctx.system.id, count=1))[0]["id"]
# Change the switch in the DB
await db.move_last_switch(ctx.conn, ctx.system.id, switch_id, new_time)
return "Switch moved."

View File

@ -0,0 +1,215 @@
import logging
from typing import List
from urllib.parse import urlparse
import humanize
from pluralkit.bot import utils
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="system", usage="[system]", description="Shows information about a system.", category="System commands", system_required=False)
async def system_info(ctx: CommandContext, args: List[str]):
if len(args) == 0:
if not ctx.system:
raise NoSystemRegistered()
system = ctx.system
else:
# Look one up
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
if system is None:
raise CommandError("Unable to find system \"{}\".".format(args[0]))
await ctx.reply(embed=await utils.generate_system_info_card(ctx.conn, ctx.client, system))
@command(cmd="system new", usage="[name]", description="Registers a new system to this account.", category="System commands", system_required=False)
async def new_system(ctx: CommandContext, args: List[str]):
if ctx.system:
raise CommandError("You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
system_name = None
if len(args) > 0:
system_name = " ".join(args)
async with ctx.conn.transaction():
# TODO: figure out what to do if this errors out on collision on generate_hid
hid = utils.generate_hid()
system = await db.create_system(ctx.conn, system_name=system_name, system_hid=hid)
# Link account
await db.link_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id)
return "System registered! To begin adding members, use `pk;member new <name>`."
@command(cmd="system set", usage="<name|description|tag|avatar> [value]", description="Edits a system property. Leave [value] blank to clear.", category="System commands")
async def system_set(ctx: CommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
allowed_properties = ["name", "description", "tag", "avatar"]
db_properties = {
"name": "name",
"description": "description",
"tag": "tag",
"avatar": "avatar_url"
}
prop = args[0]
if prop not in allowed_properties:
raise CommandError("Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)))
if len(args) >= 2:
value = " ".join(args[1:])
# Sanity checking
if prop == "tag":
if len(value) > 32:
raise CommandError("Can't have system tag longer than 32 characters.")
# Make sure there are no members which would make the combined length exceed 32
members_exceeding = await db.get_members_exceeding(ctx.conn, system_id=ctx.system.id, length=32 - len(value) - 1)
if len(members_exceeding) > 0:
# If so, error out and warn
member_names = ", ".join([member.name
for member in members_exceeding])
logger.debug("Members exceeding combined length with tag '{}': {}".format(value, member_names))
raise CommandError("The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names))
if prop == "avatar":
user = await utils.parse_mention(ctx.client, value)
if user:
# Set the avatar to the mentioned user's avatar
# Discord doesn't like webp, but also hosts png alternatives
value = user.avatar_url.replace(".webp", ".png")
else:
# Validate URL
u = urlparse(value)
if u.scheme in ["http", "https"] and u.netloc and u.path:
value = value
else:
raise CommandError("Invalid URL.")
else:
# Clear from DB
value = None
db_prop = db_properties[prop]
await db.update_system_field(ctx.conn, system_id=ctx.system.id, field=db_prop, value=value)
response = utils.make_default_embed("{} system {}.".format("Updated" if value else "Cleared", prop))
if prop == "avatar" and value:
response.set_image(url=value)
return response
@command(cmd="system link", usage="<account>", description="Links another account to your system.", category="System commands")
async def system_link(ctx: CommandContext, args: List[str]):
if len(args) == 0:
raise InvalidCommandSyntax()
# Find account to link
linkee = await utils.parse_mention(ctx.client, args[0])
if not linkee:
raise CommandError("Account not found.")
# Make sure account doesn't already have a system
account_system = await db.get_system_by_account(ctx.conn, linkee.id)
if account_system:
raise CommandError("Account is already linked to a system (`{}`)".format(account_system.hid))
# Send confirmation message
msg = await ctx.reply("{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention))
await ctx.client.add_reaction(msg, "")
await ctx.client.add_reaction(msg, "")
reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=msg, user=linkee, timeout=60.0)
# If account to be linked confirms...
if not reaction:
raise CommandError("Account link timed out.")
if not reaction.reaction.emoji == "":
raise CommandError("Account link cancelled.")
await db.link_account(ctx.conn, system_id=ctx.system.id, account_id=linkee.id)
return "Account linked to system."
@command(cmd="system unlink", description="Unlinks your system from this account. There must be at least one other account linked.", category="System commands")
async def system_unlink(ctx: CommandContext, args: List[str]):
# Make sure you can't unlink every account
linked_accounts = await db.get_linked_accounts(ctx.conn, system_id=ctx.system.id)
if len(linked_accounts) == 1:
raise CommandError("This is the only account on your system, so you can't unlink it.")
await db.unlink_account(ctx.conn, system_id=ctx.system.id, account_id=ctx.message.author.id)
return "Account unlinked."
@command(cmd="system fronter", usage="[system]", description="Gets the current fronter(s) in the system.", category="Switching commands", system_required=False)
async def system_fronter(ctx: CommandContext, args: List[str]):
if len(args) == 0:
if not ctx.system:
raise NoSystemRegistered()
else:
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
if system is None:
raise CommandError("Can't find system \"{}\".".format(args[0]))
fronters, timestamp = await utils.get_fronters(ctx.conn, system_id=ctx.system.id)
fronter_names = [member.name for member in fronters]
embed = utils.make_default_embed(None)
if len(fronter_names) == 0:
embed.add_field(name="Current fronter", value="(no fronter)")
elif len(fronter_names) == 1:
embed.add_field(name="Current fronter", value=fronter_names[0])
else:
embed.add_field(name="Current fronters", value=", ".join(fronter_names))
if timestamp:
embed.add_field(name="Since", value="{} ({})".format(timestamp.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(timestamp)))
return embed
@command(cmd="system fronthistory", usage="[system]", description="Shows the past 10 switches in the system.", category="Switching commands", system_required=False)
async def system_fronthistory(ctx: CommandContext, args: List[str]):
if len(args) == 0:
if not ctx.system:
raise NoSystemRegistered()
else:
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
if system is None:
raise CommandError("Can't find system \"{}\".".format(args[0]))
lines = []
front_history = await utils.get_front_history(ctx.conn, ctx.system.id, count=10)
for i, (timestamp, members) in enumerate(front_history):
# Special case when no one's fronting
if len(members) == 0:
name = "(no fronter)"
else:
name = ", ".join([member.name for member in members])
# Make proper date string
time_text = timestamp.isoformat(sep=" ", timespec="seconds")
rel_text = humanize.naturaltime(timestamp)
delta_text = ""
if i > 0:
last_switch_time = front_history[i-1][0]
delta_text = ", for {}".format(humanize.naturaldelta(timestamp - last_switch_time))
lines.append("**{}** ({}, {}{})".format(name, time_text, rel_text, delta_text))
embed = utils.make_default_embed("\n".join(lines) or "(none)")
embed.title = "Past switches"
return embed
@command(cmd="system delete", description="Deletes your system from the database ***permanently***.", category="System commands")
async def system_delete(ctx: CommandContext, args: List[str]):
await ctx.reply("Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(ctx.system.hid))
msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0)
if msg and msg.content == ctx.system.hid:
await db.remove_system(ctx.conn, system_id=ctx.system.id)
return "System deleted."
else:
return "System deletion cancelled."

View File

@ -10,7 +10,9 @@ help_pages = {
`pk;help proxy` - Details on message proxying. `pk;help proxy` - Details on message proxying.
`pk;help switch` - Details on switch logging. `pk;help switch` - Details on switch logging.
`pk;help mod` - Details on moderator operations. `pk;help mod` - Details on moderator operations.
`pk;help import` - Details on data import from other services.""") `pk;help import` - Details on data import from other services."""),
("Discord",
"""For feedback, bug reports, suggestions, or just chatting, join our Discord: https://discord.gg/PczBt78""")
], ],
"system": [ "system": [
("Registering a new system", ("Registering a new system",

297
src/pluralkit/bot/proxy.py Normal file
View File

@ -0,0 +1,297 @@
import ciso8601
import logging
import re
from typing import List, Optional
import aiohttp
import discord
from pluralkit import db
from pluralkit.bot import channel_logger, utils
logger = logging.getLogger("pluralkit.bot.proxy")
def extract_leading_mentions(message_text):
# This regex matches one or more mentions at the start of a message, separated by any amount of spaces
match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message_text)
if not match:
return message_text, ""
# Return the text after the mentions, and the mentions themselves
return message_text[match.span(0)[1]:].strip(), match.group(0)
def match_member_proxy_tags(member: db.ProxyMember, message_text: str):
# Skip members with no defined proxy tags
if not member.prefix and not member.suffix:
return None
# DB defines empty prefix/suffixes as None, replace with empty strings to prevent errors
prefix = member.prefix or ""
suffix = member.suffix or ""
# Ignore mentions at the very start of the message, and match proxy tags after those
message_text, leading_mentions = extract_leading_mentions(message_text)
logger.debug("Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions, prefix, suffix))
if message_text.startswith(member.prefix or "") and message_text.endswith(member.suffix or ""):
prefix_length = len(prefix)
suffix_length = len(suffix)
# If suffix_length is 0, the last bit of the slice will be "-0", and the slice will fail
if suffix_length > 0:
inner_string = message_text[prefix_length:-suffix_length]
else:
inner_string = message_text[prefix_length:]
# Add the mentions we stripped back
inner_string = leading_mentions + inner_string
return inner_string
def match_proxy_tags(members: List[db.ProxyMember], message_text: str):
# Sort by specificity (members with both prefix and suffix go higher)
# This will make sure more "precise" proxy tags get tried first
members: List[db.ProxyMember] = sorted(members, key=lambda x: int(
bool(x.prefix)) + int(bool(x.suffix)), reverse=True)
for member in members:
match = match_member_proxy_tags(member, message_text)
if match is not None: # Using "is not None" because an empty string is OK here too
logger.debug("Matched member {} with inner text '{}'".format(member.hid, match))
return member, match
def get_message_attachment_url(message: discord.Message):
if not message.attachments:
return None
attachment = message.attachments[0]
if "proxy_url" in attachment:
return attachment["proxy_url"]
if "url" in attachment:
return attachment["url"]
# TODO: possibly move this to bot __init__ so commands can access it too
class WebhookPermissionError(Exception):
pass
class DeletionPermissionError(Exception):
pass
class Proxy:
def __init__(self, client: discord.Client, token: str, logger: channel_logger.ChannelLogger):
self.logger = logging.getLogger("pluralkit.bot.proxy")
self.session = aiohttp.ClientSession()
self.client = client
self.token = token
self.channel_logger = logger
async def save_channel_webhook(self, conn, channel: discord.Channel, id: str, token: str) -> (str, str):
await db.add_webhook(conn, channel.id, id, token)
return id, token
async def create_and_add_channel_webhook(self, conn, channel: discord.Channel) -> (str, str):
# This method is only called if there's no webhook found in the DB (and hopefully within a transaction)
# No need to worry about error handling if there's a DB conflict (which will throw an exception because DB constraints)
req_headers = {"Authorization": "Bot {}".format(self.token)}
# First, check if there's already a webhook belonging to the bot
async with self.session.get("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
headers=req_headers) as resp:
if resp.status == 200:
webhooks = await resp.json()
for webhook in webhooks:
if webhook["user"]["id"] == self.client.user.id:
# This webhook belongs to us, we can use that, return it and save it
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
elif resp.status == 403:
self.logger.warning(
"Did not have permission to fetch webhook list (server={}, channel={})".format(channel.server.id,
channel.id))
raise WebhookPermissionError()
else:
raise discord.HTTPException(resp, await resp.text())
# Then, try submitting a new one
req_data = {"name": "PluralKit Proxy Webhook"}
async with self.session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
json=req_data, headers=req_headers) as resp:
if resp.status == 200:
webhook = await resp.json()
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
elif resp.status == 403:
self.logger.warning(
"Did not have permission to create webhook (server={}, channel={})".format(channel.server.id,
channel.id))
raise WebhookPermissionError()
else:
raise discord.HTTPException(resp, await resp.text())
# Should not be reached without an exception being thrown
async def get_webhook_for_channel(self, conn, channel: discord.Channel):
async with conn.transaction():
hook_match = await db.get_webhook(conn, channel.id)
if not hook_match:
# We don't have a webhook, create/add one
return await self.create_and_add_channel_webhook(conn, channel)
else:
return hook_match
async def do_proxy_message(self, conn, member: db.ProxyMember, original_message: discord.Message, text: str,
attachment_url: str, has_already_retried=False):
hook_id, hook_token = await self.get_webhook_for_channel(conn, original_message.channel)
form_data = aiohttp.FormData()
form_data.add_field("username", "{} {}".format(member.name, member.tag or "").strip())
if text:
form_data.add_field("content", text)
if attachment_url:
attachment_resp = await self.session.get(attachment_url)
form_data.add_field("file", attachment_resp.content, content_type=attachment_resp.content_type,
filename=attachment_resp.url.name)
if member.avatar_url:
form_data.add_field("avatar_url", member.avatar_url)
async with self.session.post(
"https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token),
data=form_data) as resp:
if resp.status == 200:
message = await resp.json()
await db.add_message(conn, message["id"], message["channel_id"], member.id, original_message.author.id,
text or "")
try:
await self.client.delete_message(original_message)
except discord.Forbidden:
self.logger.warning(
"Did not have permission to delete original message (server={}, channel={})".format(
original_message.server.id, original_message.channel.id))
raise DeletionPermissionError()
except discord.NotFound:
self.logger.warning("Tried to delete message when proxying, but message was already gone (server={}, channel={})".format(original_message.server.id, original_message.channel.id))
message_image = None
if message["attachments"]:
first_attachment = message["attachments"][0]
if "width" in first_attachment and "height" in first_attachment:
# Only log attachments that are actually images
message_image = first_attachment["url"]
await self.channel_logger.log_message_proxied(conn,
server_id=original_message.server.id,
channel_name=original_message.channel.name,
channel_id=original_message.channel.id,
sender_name=original_message.author.name,
sender_disc=original_message.author.discriminator,
member_name=member.name,
member_hid=member.hid,
member_avatar_url=member.avatar_url,
system_name=member.system_name,
system_hid=member.system_hid,
message_text=text,
message_image=message_image,
message_timestamp=ciso8601.parse_datetime(
message["timestamp"]),
message_id=message["id"])
elif resp.status == 404 and not has_already_retried:
# Webhook doesn't exist. Delete it from the DB, create, and add a new one
self.logger.warning("Webhook registered in DB doesn't exist, deleting hook from DB, re-adding, and trying again (channel={}, hook={})".format(original_message.channel.id, hook_id))
await db.delete_webhook(conn, original_message.channel.id)
await self.create_and_add_channel_webhook(conn, original_message.channel)
# Then try again all over, making sure to not retry again and go in a loop should it continually fail
return await self.do_proxy_message(conn, member, original_message, text, attachment_url, has_already_retried=True)
else:
raise discord.HTTPException(resp, await resp.text())
async def try_proxy_message(self, conn, message: discord.Message):
# Can't proxy in DMs, webhook creation will explode
if message.channel.is_private:
return False
# Big fat query to find every member associated with this account
# Returned member object has a few more keys (system tag, for example)
members = await db.get_members_by_account(conn, account_id=message.author.id)
match = match_proxy_tags(members, message.content)
if not match:
return False
member, text = match
attachment_url = get_message_attachment_url(message)
# Can't proxy a message with no text AND no attachment
if not text and not attachment_url:
self.logger.debug("Skipping message because of no text and no attachment")
return False
try:
async with conn.transaction():
await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url)
except WebhookPermissionError:
embed = utils.make_error_embed("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed)
except DeletionPermissionError:
embed = utils.make_error_embed("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed)
return True
async def try_delete_message(self, conn, message_id: str, check_user_id: Optional[str], delete_message: bool, deleted_by_moderator: bool):
async with conn.transaction():
# Find the message in the DB, and make sure it's sent by the user (if we need to check)
if check_user_id:
db_message = await db.get_message_by_sender_and_id(conn, message_id=message_id, sender_id=check_user_id)
else:
db_message = await db.get_message(conn, message_id=message_id)
if db_message:
self.logger.debug("Deleting message {}".format(message_id))
channel = self.client.get_channel(str(db_message.channel))
# If we should also delete the actual message, do that
if delete_message:
message = await self.client.get_message(channel, message_id)
try:
await self.client.delete_message(message)
except discord.Forbidden:
self.logger.warning(
"Did not have permission to remove message, aborting deletion (server={}, channel={})".format(
channel.server.id, channel.id))
return
# Remove it from the DB
await db.delete_message(conn, message_id)
# Then log deletion to logging channel
await self.channel_logger.log_message_deleted(conn,
server_id=channel.server.id,
channel_name=channel.name,
member_name=db_message.name,
member_hid=db_message.hid,
member_avatar_url=db_message.avatar_url,
system_name=db_message.system_name,
system_hid=db_message.system_hid,
message_text=db_message.content,
message_id=message_id,
deleted_by_moderator=deleted_by_moderator)
async def handle_reaction(self, conn, user_id: str, message_id: str, emoji: str):
if emoji == "":
await self.try_delete_message(conn, message_id, check_user_id=user_id, delete_message=True, deleted_by_moderator=False)
async def handle_deletion(self, conn, message_id: str):
# Don't delete the message, it's already gone at this point, just handle DB deletion and logging
await self.try_delete_message(conn, message_id, check_user_id=None, delete_message=False, deleted_by_moderator=True)

219
src/pluralkit/bot/utils.py Normal file
View File

@ -0,0 +1,219 @@
from datetime import datetime
import logging
import random
import re
from typing import List, Tuple
import string
import asyncio
import asyncpg
import discord
import humanize
from pluralkit import System, Member, db
logger = logging.getLogger("pluralkit.utils")
def escape(s):
return s.replace("`", "\\`")
def generate_hid() -> str:
return "".join(random.choices(string.ascii_lowercase, k=5))
def bounds_check_member_name(new_name, system_tag):
if len(new_name) > 32:
return "Name cannot be longer than 32 characters."
if system_tag:
if len("{} {}".format(new_name, system_tag)) > 32:
return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(system_tag)
async def parse_mention(client: discord.Client, mention: str) -> discord.User:
# First try matching mention format
match = re.fullmatch("<@!?(\\d+)>", mention)
if match:
try:
return await client.get_user_info(match.group(1))
except discord.NotFound:
return None
# Then try with just ID
try:
return await client.get_user_info(str(int(mention)))
except (ValueError, discord.NotFound):
return None
def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel:
match = re.fullmatch("<#(\\d+)>", mention)
if match:
return server.get_channel(match.group(1))
try:
return server.get_channel(str(int(mention)))
except ValueError:
return None
async def get_fronter_ids(conn, system_id) -> (List[int], datetime):
switches = await db.front_history(conn, system_id=system_id, count=1)
if not switches:
return [], None
if not switches[0]["members"]:
return [], switches[0]["timestamp"]
return switches[0]["members"], switches[0]["timestamp"]
async def get_fronters(conn, system_id) -> (List[Member], datetime):
member_ids, timestamp = await get_fronter_ids(conn, system_id)
# Collect in dict and then look up as list, to preserve return order
members = {member.id: member for member in await db.get_members(conn, member_ids)}
return [members[member_id] for member_id in member_ids], timestamp
async def get_front_history(conn, system_id, count) -> List[Tuple[datetime, List[Member]]]:
# Get history from DB
switches = await db.front_history(conn, system_id=system_id, count=count)
if not switches:
return []
# Get all unique IDs referenced
all_member_ids = {id for switch in switches for id in switch["members"]}
# And look them up in the database into a dict
all_members = {member.id: member for member in await db.get_members(conn, list(all_member_ids))}
# Collect in array and return
out = []
for switch in switches:
timestamp = switch["timestamp"]
members = [all_members[id] for id in switch["members"]]
out.append((timestamp, members))
return out
async def get_system_fuzzy(conn, client: discord.Client, key) -> System:
if isinstance(key, discord.User):
return await db.get_system_by_account(conn, account_id=key.id)
if isinstance(key, str) and len(key) == 5:
return await db.get_system_by_hid(conn, system_hid=key)
account = await parse_mention(client, key)
if account:
system = await db.get_system_by_account(conn, account_id=account.id)
if system:
return system
return None
async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> Member:
# First search by hid
if system_only:
member = await db.get_member_by_hid_in_system(conn, system_id=system_id, member_hid=key)
else:
member = await db.get_member_by_hid(conn, member_hid=key)
if member is not None:
return member
# Then search by name, if we have a system
if system_id:
member = await db.get_member_by_name(conn, system_id=system_id, member_name=key)
if member is not None:
return member
def make_default_embed(message):
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.description = message
return embed
def make_error_embed(message):
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.description = message
return embed
async def generate_system_info_card(conn, client: discord.Client, system: System) -> discord.Embed:
card = discord.Embed()
card.colour = discord.Colour.blue()
if system.name:
card.title = system.name
if system.avatar_url:
card.set_thumbnail(url=system.avatar_url)
if system.tag:
card.add_field(name="Tag", value=system.tag)
fronters, switch_time = await get_fronters(conn, system.id)
if fronters:
names = ", ".join([member.name for member in fronters])
fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switch_time))
card.add_field(name="Current fronter" if len(fronters) == 1 else "Current fronters", value=fronter_val)
account_names = []
for account_id in await db.get_linked_accounts(conn, system_id=system.id):
account = await client.get_user_info(account_id)
account_names.append("{}#{}".format(account.name, account.discriminator))
card.add_field(name="Linked accounts", value="\n".join(account_names))
if system.description:
card.add_field(name="Description",
value=system.description, inline=False)
# Get names of all members
member_texts = []
for member in await db.get_all_members(conn, system_id=system.id):
member_texts.append("{} (`{}`)".format(escape(member.name), member.hid))
if len(member_texts) > 0:
card.add_field(name="Members", value="\n".join(
member_texts), inline=False)
card.set_footer(text="System ID: {}".format(system.hid))
return card
async def generate_member_info_card(conn, member: Member) -> discord.Embed:
system = await db.get_system(conn, system_id=member.system)
card = discord.Embed()
card.colour = discord.Colour.blue()
name_and_system = member.name
if system.name:
name_and_system += " ({})".format(system.name)
card.set_author(name=name_and_system, icon_url=member.avatar_url or discord.Embed.Empty)
if member.avatar_url:
card.set_thumbnail(url=member.avatar_url)
# Get system name and hid
system = await db.get_system(conn, system_id=member.system)
if member.color:
card.colour = int(member.color, 16)
if member.birthday:
bday_val = member.birthday.strftime("%b %d, %Y")
if member.birthday.year == 1:
bday_val = member.birthday.strftime("%b %d")
card.add_field(name="Birthdate", value=bday_val)
if member.pronouns:
card.add_field(name="Pronouns", value=member.pronouns)
if member.prefix or member.suffix:
prefix = member.prefix or ""
suffix = member.suffix or ""
card.add_field(name="Proxy Tags",
value="{}text{}".format(prefix, suffix))
if member.description:
card.add_field(name="Description",
value=member.description, inline=False)
card.set_footer(text="System ID: {} | Member ID: {}".format(
system.hid, member.hid))
return card

View File

@ -1,11 +1,15 @@
from collections import namedtuple
from datetime import datetime
import logging
from typing import List
import time import time
import asyncpg import asyncpg
import asyncpg.exceptions import asyncpg.exceptions
from pluralkit import stats from pluralkit import System, Member, stats
from pluralkit.bot import logger
logger = logging.getLogger("pluralkit.db")
async def connect(): async def connect():
while True: while True:
try: try:
@ -13,7 +17,6 @@ async def connect():
except (ConnectionError, asyncpg.exceptions.CannotConnectNowError): except (ConnectionError, asyncpg.exceptions.CannotConnectNowError):
pass pass
def db_wrap(func): def db_wrap(func):
async def inner(*args, **kwargs): async def inner(*args, **kwargs):
before = time.perf_counter() before = time.perf_counter()
@ -31,10 +34,11 @@ def db_wrap(func):
return inner return inner
@db_wrap @db_wrap
async def create_system(conn, system_name: str, system_hid: str): async def create_system(conn, system_name: str, system_hid: str) -> System:
logger.debug("Creating system (name={}, hid={})".format( logger.debug("Creating system (name={}, hid={})".format(
system_name, system_hid)) system_name, system_hid))
return await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid) row = await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid)
return System(**row) if row else None
@db_wrap @db_wrap
@ -44,10 +48,11 @@ async def remove_system(conn, system_id: int):
@db_wrap @db_wrap
async def create_member(conn, system_id: int, member_name: str, member_hid: str): async def create_member(conn, system_id: int, member_name: str, member_hid: str) -> Member:
logger.debug("Creating member (system={}, name={}, hid={})".format( logger.debug("Creating member (system={}, name={}, hid={})".format(
system_id, member_name, member_hid)) system_id, member_name, member_hid))
return await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid) row = await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid)
return Member(**row) if row else None
@db_wrap @db_wrap
@ -71,52 +76,54 @@ async def unlink_account(conn, system_id: int, account_id: str):
@db_wrap @db_wrap
async def get_linked_accounts(conn, system_id: int): async def get_linked_accounts(conn, system_id: int) -> List[int]:
return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)] return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)]
@db_wrap @db_wrap
async def get_system_by_account(conn, account_id: str): async def get_system_by_account(conn, account_id: str) -> System:
return await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id)) row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id))
return System(**row) if row else None
@db_wrap
async def get_system_by_hid(conn, system_hid: str) -> System:
row = await conn.fetchrow("select * from systems where hid = $1", system_hid)
return System(**row) if row else None
@db_wrap @db_wrap
async def get_system_by_hid(conn, system_hid: str): async def get_system(conn, system_id: int) -> System:
return await conn.fetchrow("select * from systems where hid = $1", system_hid) row = await conn.fetchrow("select * from systems where id = $1", system_id)
return System(**row) if row else None
@db_wrap @db_wrap
async def get_system(conn, system_id: int): async def get_member_by_name(conn, system_id: int, member_name: str) -> Member:
return await conn.fetchrow("select * from systems where id = $1", system_id) row = await conn.fetchrow("select * from members where system = $1 and lower(name) = lower($2)", system_id, member_name)
return Member(**row) if row else None
@db_wrap @db_wrap
async def get_member_by_name(conn, system_id: int, member_name: str): async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str) -> Member:
return await conn.fetchrow("select * from members where system = $1 and lower(name) = lower($2)", system_id, member_name) row = await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid)
return Member(**row) if row else None
@db_wrap @db_wrap
async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str): async def get_member_by_hid(conn, member_hid: str) -> Member:
return await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid) row = await conn.fetchrow("select * from members where hid = $1", member_hid)
return Member(**row) if row else None
@db_wrap @db_wrap
async def get_member_by_hid(conn, member_hid: str): async def get_member(conn, member_id: int) -> Member:
return await conn.fetchrow("select * from members where hid = $1", member_hid) row = await conn.fetchrow("select * from members where id = $1", member_id)
return Member(**row) if row else None
@db_wrap @db_wrap
async def get_member(conn, member_id: int): async def get_members(conn, members: list) -> List[Member]:
return await conn.fetchrow("select * from members where id = $1", member_id) rows = await conn.fetch("select * from members where id = any($1)", members)
return [Member(**row) for row in rows]
@db_wrap
async def get_members(conn, members: list):
return await conn.fetch("select * from members where id = any($1)", members)
@db_wrap
async def get_message(conn, message_id: str):
return await conn.fetchrow("select * from messages where mid = $1", message_id)
@db_wrap @db_wrap
async def update_system_field(conn, system_id: int, field: str, value): async def update_system_field(conn, system_id: int, field: str, value):
@ -133,18 +140,20 @@ async def update_member_field(conn, member_id: int, field: str, value):
@db_wrap @db_wrap
async def get_all_members(conn, system_id: int): async def get_all_members(conn, system_id: int) -> List[Member]:
return await conn.fetch("select * from members where system = $1", system_id) rows = await conn.fetch("select * from members where system = $1", system_id)
return [Member(**row) for row in rows]
@db_wrap
async def get_members_exceeding(conn, system_id: int, length: int) -> List[Member]:
rows = await conn.fetch("select * from members where system = $1 and length(name) > $2", system_id, length)
return [Member(**row) for row in rows]
@db_wrap @db_wrap
async def get_members_exceeding(conn, system_id: int, length: int): async def get_webhook(conn, channel_id: str) -> (str, str):
return await conn.fetch("select * from members where system = $1 and length(name) > $2", system_id, length) row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
return (str(row["webhook"]), row["token"]) if row else None
@db_wrap
async def get_webhook(conn, channel_id: str):
return await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
@db_wrap @db_wrap
@ -153,6 +162,9 @@ async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str
channel_id, webhook_id, webhook_token)) channel_id, webhook_id, webhook_token))
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token) await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token)
@db_wrap
async def delete_webhook(conn, channel_id: str):
await conn.execute("delete from webhooks where channel = $1", int(channel_id))
@db_wrap @db_wrap
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str, content: str): async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str, content: str):
@ -160,11 +172,22 @@ async def add_message(conn, message_id: str, channel_id: str, member_id: int, se
message_id, channel_id, member_id, sender_id)) message_id, channel_id, member_id, sender_id))
await conn.execute("insert into messages (mid, channel, member, sender, content) values ($1, $2, $3, $4, $5)", int(message_id), int(channel_id), member_id, int(sender_id), content) await conn.execute("insert into messages (mid, channel, member, sender, content) values ($1, $2, $3, $4, $5)", int(message_id), int(channel_id), member_id, int(sender_id), content)
class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "color", "name", "avatar_url", "tag", "system_name", "system_hid"])):
id: int
hid: str
prefix: str
suffix: str
color: str
name: str
avatar_url: str
tag: str
system_name: str
system_hid: str
@db_wrap @db_wrap
async def get_members_by_account(conn, account_id: str): async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]:
# Returns a "chimera" object # Returns a "chimera" object
return await conn.fetch("""select rows = await conn.fetch("""select
members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url, members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url,
systems.tag, systems.name as system_name, systems.hid as system_hid systems.tag, systems.name as system_name, systems.hid as system_hid
from from
@ -173,11 +196,23 @@ async def get_members_by_account(conn, account_id: str):
accounts.uid = $1 accounts.uid = $1
and systems.id = accounts.system and systems.id = accounts.system
and members.system = systems.id""", int(account_id)) and members.system = systems.id""", int(account_id))
return [ProxyMember(**row) for row in rows]
class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "content", "sender", "name", "hid", "avatar_url", "system_name", "system_hid"])):
mid: int
channel: int
member: int
content: str
sender: int
name: str
hid: str
avatar_url: str
system_name: str
system_hid: str
@db_wrap @db_wrap
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str): async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) -> MessageInfo:
return await conn.fetchrow("""select row = await conn.fetchrow("""select
messages.*, messages.*,
members.name, members.hid, members.avatar_url, members.name, members.hid, members.avatar_url,
systems.name as system_name, systems.hid as system_hid systems.name as system_name, systems.hid as system_hid
@ -186,7 +221,24 @@ async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str):
where where
messages.member = members.id messages.member = members.id
and members.system = systems.id and members.system = systems.id
and mid = $1 and sender = $2""", int(message_id), int(sender_id)) and mid = $1
and sender = $2""", int(message_id), int(sender_id))
return MessageInfo(**row) if row else None
@db_wrap
async def get_message(conn, message_id: str) -> MessageInfo:
row = await conn.fetchrow("""select
messages.*,
members.name, members.hid, members.avatar_url,
systems.name as system_name, systems.hid as system_hid
from
messages, members, systems
where
messages.member = members.id
and members.system = systems.id
and mid = $1""", int(message_id))
return MessageInfo(**row) if row else None
@db_wrap @db_wrap
@ -215,7 +267,7 @@ async def add_switch(conn, system_id: int):
return res["id"] return res["id"]
@db_wrap @db_wrap
async def move_last_switch(conn, system_id: int, switch_id: int, new_time): async def move_last_switch(conn, system_id: int, switch_id: int, new_time: datetime):
logger.debug("Moving latest switch (system={}, id={}, new_time={})".format(system_id, switch_id, new_time)) logger.debug("Moving latest switch (system={}, id={}, new_time={})".format(system_id, switch_id, new_time))
await conn.execute("update switches set timestamp = $1 where system = $2 and id = $3", new_time, system_id, switch_id) await conn.execute("update switches set timestamp = $1 where system = $2 and id = $3", new_time, system_id, switch_id)
@ -235,19 +287,19 @@ async def update_server(conn, server_id: str, logging_channel_id: str):
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id) await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id)
@db_wrap @db_wrap
async def member_count(conn): async def member_count(conn) -> int:
return await conn.fetchval("select count(*) from members") return await conn.fetchval("select count(*) from members")
@db_wrap @db_wrap
async def system_count(conn): async def system_count(conn) -> int:
return await conn.fetchval("select count(*) from systems") return await conn.fetchval("select count(*) from systems")
@db_wrap @db_wrap
async def message_count(conn): async def message_count(conn) -> int:
return await conn.fetchval("select count(*) from messages") return await conn.fetchval("select count(*) from messages")
@db_wrap @db_wrap
async def account_count(conn): async def account_count(conn) -> int:
return await conn.fetchval("select count(*) from accounts") return await conn.fetchval("select count(*) from accounts")
async def create_tables(conn): async def create_tables(conn):
@ -283,7 +335,7 @@ async def create_tables(conn):
channel bigint not null, channel bigint not null,
member serial not null references members(id) on delete cascade, member serial not null references members(id) on delete cascade,
content text not null, content text not null,
sender bigint not null references accounts(uid) sender bigint not null
)""") )""")
await conn.execute("""create table if not exists switches ( await conn.execute("""create table if not exists switches (
id serial primary key, id serial primary key,

View File

@ -1,7 +1,5 @@
from aioinflux import InfluxDBClient from aioinflux import InfluxDBClient
from pluralkit.bot import logger
client = None client = None
async def connect(): async def connect():
global client global client

7
src/requirements.txt Normal file
View File

@ -0,0 +1,7 @@
aiohttp
aioinflux
asyncpg
dateparser
discord.py
humanize
uvloop