Massive refactor of pretty much everything in the bot
This commit is contained in:
parent
086fa84b4b
commit
8936029dc8
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
.env
|
||||
.vscode/
|
||||
.idea/
|
@ -1 +0,0 @@
|
||||
from . import commands, db, proxy, stats
|
@ -1,122 +0,0 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import discord
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
||||
logging.getLogger("discord").setLevel(logging.INFO)
|
||||
logging.getLogger("websockets").setLevel(logging.INFO)
|
||||
|
||||
logger = logging.getLogger("pluralkit.bot")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
client = discord.Client()
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_error(evt, *args, **kwargs):
|
||||
logger.exception(
|
||||
"Error while handling event {} with arguments {}:".format(evt, args))
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
# Print status info
|
||||
logger.info("Connected to Discord.")
|
||||
logger.info("Account: {}#{}".format(
|
||||
client.user.name, client.user.discriminator))
|
||||
logger.info("User ID: {}".format(client.user.id))
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_message(message):
|
||||
# Ignore bot messages
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
# Split into args. shlex sucks so we don't bother with quotes
|
||||
args = message.content.split(" ")
|
||||
|
||||
from pluralkit import proxy, utils, stats
|
||||
|
||||
command_items = utils.command_map.items()
|
||||
command_items = sorted(command_items, key=lambda x: len(x[0]), reverse=True)
|
||||
|
||||
prefix = "pk;"
|
||||
for command, (func, _, _, _) in command_items:
|
||||
if message.content.lower().startswith(prefix + command):
|
||||
args_str = message.content[len(prefix + command):].strip()
|
||||
args = args_str.split(" ")
|
||||
|
||||
# Splitting on empty string yields one-element array, remove that
|
||||
if len(args) == 1 and not args[0]:
|
||||
args = []
|
||||
|
||||
async with client.pool.acquire() as conn:
|
||||
time_before = time.perf_counter()
|
||||
await func(conn, message, args)
|
||||
time_after = time.perf_counter()
|
||||
|
||||
# Report command time stats
|
||||
execution_time = time_after - time_before
|
||||
response_time = (datetime.now() - message.timestamp).total_seconds()
|
||||
await stats.report_command(command, execution_time, response_time)
|
||||
return
|
||||
|
||||
# Try doing proxy parsing
|
||||
async with client.pool.acquire() as conn:
|
||||
await proxy.handle_proxying(conn, message)
|
||||
|
||||
@client.event
|
||||
async def on_socket_raw_receive(msg):
|
||||
# Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
|
||||
# we parse socket data manually for the reaction add event
|
||||
if isinstance(msg, str):
|
||||
try:
|
||||
msg_data = json.loads(msg)
|
||||
if msg_data.get("t") == "MESSAGE_REACTION_ADD":
|
||||
evt_data = msg_data.get("d")
|
||||
if evt_data:
|
||||
user_id = evt_data["user_id"]
|
||||
message_id = evt_data["message_id"]
|
||||
emoji = evt_data["emoji"]["name"]
|
||||
|
||||
async with client.pool.acquire() as conn:
|
||||
from pluralkit import proxy
|
||||
await proxy.handle_reaction(conn, user_id, message_id, emoji)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async def periodical_stat_timer(pool):
|
||||
async with pool.acquire() as conn:
|
||||
while True:
|
||||
from pluralkit import stats
|
||||
await stats.report_periodical_stats(conn)
|
||||
await asyncio.sleep(30)
|
||||
|
||||
async def run():
|
||||
from pluralkit import db, stats
|
||||
try:
|
||||
logger.info("Connecting to database...")
|
||||
pool = await db.connect()
|
||||
|
||||
logger.info("Attempting to create tables...")
|
||||
async with pool.acquire() as conn:
|
||||
await db.create_tables(conn)
|
||||
|
||||
logger.info("Connecting to InfluxDB...")
|
||||
await stats.connect()
|
||||
|
||||
logger.info("Starting periodical stat reporting...")
|
||||
asyncio.get_event_loop().create_task(periodical_stat_timer(pool))
|
||||
|
||||
client.pool = pool
|
||||
logger.info("Connecting to Discord...")
|
||||
await client.start(os.environ["TOKEN"])
|
||||
finally:
|
||||
logger.info("Logging out from Discord...")
|
||||
await client.logout()
|
@ -1,752 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
import io
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import dateparser
|
||||
import discord
|
||||
from discord.utils import oauth_url
|
||||
import humanize
|
||||
|
||||
from pluralkit import db
|
||||
from pluralkit.bot import client, logger
|
||||
from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed, parse_channel_mention, bounds_check_member_name, get_fronters, get_fronter_ids, get_front_history
|
||||
|
||||
@command(cmd="system", usage="[system]", description="Shows information about a system.", category="System commands")
|
||||
async def system_info(conn, message, args):
|
||||
if len(args) == 0:
|
||||
# Use sender's system
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
else:
|
||||
# Look one up
|
||||
system = await get_system_fuzzy(conn, args[0])
|
||||
|
||||
if system is None:
|
||||
return False, "Unable to find system \"{}\".".format(args[0])
|
||||
|
||||
await client.send_message(message.channel, embed=await generate_system_info_card(conn, system))
|
||||
return True
|
||||
|
||||
|
||||
@command(cmd="system new", usage="[name]", description="Registers a new system to this account.", category="System commands")
|
||||
async def new_system(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is not None:
|
||||
return False, "You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`."
|
||||
|
||||
system_name = None
|
||||
if len(args) > 0:
|
||||
system_name = " ".join(args)
|
||||
|
||||
async with conn.transaction():
|
||||
# TODO: figure out what to do if this errors out on collision on generate_hid
|
||||
hid = generate_hid()
|
||||
|
||||
system = await db.create_system(conn, system_name=system_name, system_hid=hid)
|
||||
|
||||
# Link account
|
||||
await db.link_account(conn, system_id=system["id"], account_id=message.author.id)
|
||||
return True, "System registered! To begin adding members, use `pk;member new <name>`."
|
||||
|
||||
@command(cmd="system set", usage="<name|description|tag|avatar> [value]", description="Edits a system property. Leave [value] blank to clear.", category="System commands")
|
||||
async def system_set(conn, message, args):
|
||||
if len(args) == 0:
|
||||
return False
|
||||
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
allowed_properties = ["name", "description", "tag", "avatar"]
|
||||
db_properties = {
|
||||
"name": "name",
|
||||
"description": "description",
|
||||
"tag": "tag",
|
||||
"avatar": "avatar_url"
|
||||
}
|
||||
|
||||
prop = args[0]
|
||||
if prop not in allowed_properties:
|
||||
return False, "Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties))
|
||||
|
||||
if len(args) >= 2:
|
||||
value = " ".join(args[1:])
|
||||
|
||||
# Sanity checking
|
||||
if prop == "tag":
|
||||
# Make sure there are no members which would make the combined length exceed 32
|
||||
members_exceeding = await db.get_members_exceeding(conn, system_id=system["id"], length=32 - len(value))
|
||||
if len(members_exceeding) > 0:
|
||||
# If so, error out and warn
|
||||
member_names = ", ".join([member["name"]
|
||||
for member in members_exceeding])
|
||||
logger.debug("Members exceeding combined length with tag '{}': {}".format(value, member_names))
|
||||
return False, "The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names)
|
||||
|
||||
if prop == "avatar":
|
||||
user = await parse_mention(value)
|
||||
if user:
|
||||
# Set the avatar to the mentioned user's avatar
|
||||
# Discord doesn't like webp, but also hosts png alternatives
|
||||
value = user.avatar_url.replace(".webp", ".png")
|
||||
else:
|
||||
# Validate URL
|
||||
u = urlparse(value)
|
||||
if u.scheme in ["http", "https"] and u.netloc and u.path:
|
||||
value = value
|
||||
else:
|
||||
return False, "Invalid URL."
|
||||
else:
|
||||
# Clear from DB
|
||||
value = None
|
||||
|
||||
db_prop = db_properties[prop]
|
||||
await db.update_system_field(conn, system_id=system["id"], field=db_prop, value=value)
|
||||
|
||||
response = make_default_embed("{} system {}.".format("Updated" if value else "Cleared", prop))
|
||||
if prop == "avatar" and value:
|
||||
response.set_image(url=value)
|
||||
return True, response
|
||||
|
||||
@command(cmd="system link", usage="<account>", description="Links another account to your system.", category="System commands")
|
||||
async def system_link(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
if len(args) == 0:
|
||||
return False
|
||||
|
||||
# Find account to link
|
||||
linkee = await parse_mention(args[0])
|
||||
if not linkee:
|
||||
return False, "Account not found."
|
||||
|
||||
# Make sure account doesn't already have a system
|
||||
account_system = await db.get_system_by_account(conn, linkee.id)
|
||||
if account_system:
|
||||
return False, "Account is already linked to a system (`{}`)".format(account_system["hid"])
|
||||
|
||||
# Send confirmation message
|
||||
msg = await client.send_message(message.channel, "{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention))
|
||||
await client.add_reaction(msg, "✅")
|
||||
await client.add_reaction(msg, "❌")
|
||||
|
||||
reaction = await client.wait_for_reaction(emoji=["✅", "❌"], message=msg, user=linkee)
|
||||
# If account to be linked confirms...
|
||||
if reaction.reaction.emoji == "✅":
|
||||
async with conn.transaction():
|
||||
# Execute the link
|
||||
await db.link_account(conn, system_id=system["id"], account_id=linkee.id)
|
||||
return True, "Account linked to system."
|
||||
else:
|
||||
await client.clear_reactions(msg)
|
||||
return False, "Account link cancelled."
|
||||
|
||||
|
||||
@command(cmd="system unlink", description="Unlinks your system from this account. There must be at least one other account linked.", category="System commands")
|
||||
async def system_unlink(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
# Make sure you can't unlink every account
|
||||
linked_accounts = await db.get_linked_accounts(conn, system_id=system["id"])
|
||||
if len(linked_accounts) == 1:
|
||||
return False, "This is the only account on your system, so you can't unlink it."
|
||||
|
||||
async with conn.transaction():
|
||||
await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id)
|
||||
return True, "Account unlinked."
|
||||
|
||||
@command(cmd="system fronter", usage="[system]", description="Gets the current fronter(s) in the system.", category="Switching commands")
|
||||
async def system_fronter(conn, message, args):
|
||||
if len(args) == 0:
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
else:
|
||||
system = await get_system_fuzzy(conn, args[0])
|
||||
|
||||
if system is None:
|
||||
return False, "Can't find system \"{}\".".format(args[0])
|
||||
|
||||
fronters, timestamp = await get_fronters(conn, system_id=system["id"])
|
||||
fronter_names = [member["name"] for member in fronters]
|
||||
|
||||
embed = make_default_embed(None)
|
||||
|
||||
if len(fronter_names) == 0:
|
||||
embed.add_field(name="Current fronter", value="(no fronter)")
|
||||
elif len(fronter_names) == 1:
|
||||
embed.add_field(name="Current fronter", value=fronter_names[0])
|
||||
else:
|
||||
embed.add_field(name="Current fronters", value=", ".join(fronter_names))
|
||||
|
||||
if timestamp:
|
||||
embed.add_field(name="Since", value="{} ({})".format(timestamp.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(timestamp)))
|
||||
return True, embed
|
||||
|
||||
@command(cmd="system fronthistory", usage="[system]", description="Shows the past 10 switches in the system.", category="Switching commands")
|
||||
async def system_fronthistory(conn, message, args):
|
||||
if len(args) == 0:
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
else:
|
||||
system = await get_system_fuzzy(conn, args[0])
|
||||
|
||||
if system is None:
|
||||
return False, "Can't find system \"{}\".".format(args[0])
|
||||
|
||||
lines = []
|
||||
front_history = await get_front_history(conn, system["id"], count=10)
|
||||
for i, (timestamp, members) in enumerate(front_history):
|
||||
# Special case when no one's fronting
|
||||
if len(members) == 0:
|
||||
name = "(no fronter)"
|
||||
else:
|
||||
name = ", ".join([member["name"] for member in members])
|
||||
|
||||
# Make proper date string
|
||||
time_text = timestamp.isoformat(sep=" ", timespec="seconds")
|
||||
rel_text = humanize.naturaltime(timestamp)
|
||||
|
||||
delta_text = ""
|
||||
if i > 0:
|
||||
last_switch_time = front_history[i-1][0]
|
||||
delta_text = ", for {}".format(humanize.naturaldelta(timestamp - last_switch_time))
|
||||
lines.append("**{}** ({}, {}{})".format(name, time_text, rel_text, delta_text))
|
||||
|
||||
embed = make_default_embed("\n".join(lines) or "(none)")
|
||||
embed.title = "Past switches"
|
||||
return True, embed
|
||||
|
||||
|
||||
@command(cmd="system delete", description="Deletes your system from the database ***permanently***.", category="System commands")
|
||||
async def system_delete(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
await client.send_message(message.channel, "Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(system["hid"]))
|
||||
|
||||
msg = await client.wait_for_message(author=message.author, channel=message.channel, timeout=60.0)
|
||||
if msg and msg.content == system["hid"]:
|
||||
await db.remove_system(conn, system_id=system["id"])
|
||||
return True, "System deleted."
|
||||
else:
|
||||
return True, "System deletion cancelled."
|
||||
|
||||
@member_command(cmd="member", description="Shows information about a system member.", system_only=False, category="Member commands")
|
||||
async def member_info(conn, message, member, args):
|
||||
await client.send_message(message.channel, embed=await generate_member_info_card(conn, member))
|
||||
return True
|
||||
|
||||
@command(cmd="member new", usage="<name>", description="Adds a new member to your system.", category="Member commands")
|
||||
async def new_member(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
if len(args) == 0:
|
||||
return False
|
||||
|
||||
name = " ".join(args)
|
||||
bounds_error = bounds_check_member_name(name, system["tag"])
|
||||
if bounds_error:
|
||||
return False, bounds_error
|
||||
|
||||
async with conn.transaction():
|
||||
# TODO: figure out what to do if this errors out on collision on generate_hid
|
||||
hid = generate_hid()
|
||||
|
||||
# Insert member row
|
||||
await db.create_member(conn, system_id=system["id"], member_name=name, member_hid=hid)
|
||||
return True, "Member \"{}\" (`{}`) registered!".format(name, hid)
|
||||
|
||||
|
||||
@member_command(cmd="member set", usage="<name|description|color|pronouns|birthdate|avatar> [value]", description="Edits a member property. Leave [value] blank to clear.", category="Member commands")
|
||||
async def member_set(conn, message, member, args):
|
||||
if len(args) == 0:
|
||||
return False
|
||||
|
||||
allowed_properties = ["name", "description", "color", "pronouns", "birthdate", "avatar"]
|
||||
db_properties = {
|
||||
"name": "name",
|
||||
"description": "description",
|
||||
"color": "color",
|
||||
"pronouns": "pronouns",
|
||||
"birthdate": "birthday",
|
||||
"avatar": "avatar_url"
|
||||
}
|
||||
|
||||
prop = args[0]
|
||||
if prop not in allowed_properties:
|
||||
return False, "Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties))
|
||||
|
||||
if len(args) >= 2:
|
||||
value = " ".join(args[1:])
|
||||
|
||||
# Sanity/validity checks and type conversions
|
||||
if prop == "name":
|
||||
system = await db.get_system(conn, member["system"])
|
||||
bounds_error = bounds_check_member_name(value, system["tag"])
|
||||
if bounds_error:
|
||||
return False, bounds_error
|
||||
|
||||
if prop == "color":
|
||||
match = re.fullmatch("#?([0-9A-Fa-f]{6})", value)
|
||||
if not match:
|
||||
return False, "Color must be a valid hex color (eg. #ff0000)"
|
||||
|
||||
value = match.group(1).lower()
|
||||
|
||||
if prop == "birthdate":
|
||||
try:
|
||||
value = datetime.strptime(value, "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
try:
|
||||
# Try again, adding 0001 as a placeholder year
|
||||
# This is considered a "null year" and will be omitted from the info card
|
||||
# Useful if you want your birthday to be displayed yearless.
|
||||
value = value = datetime.strptime("0001-" + value, "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
return False, "Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25)."
|
||||
|
||||
if prop == "avatar":
|
||||
user = await parse_mention(value)
|
||||
if user:
|
||||
# Set the avatar to the mentioned user's avatar
|
||||
# Discord doesn't like webp, but also hosts png alternatives
|
||||
value = user.avatar_url.replace(".webp", ".png")
|
||||
else:
|
||||
# Validate URL
|
||||
u = urlparse(value)
|
||||
if u.scheme in ["http", "https"] and u.netloc and u.path:
|
||||
value = value
|
||||
else:
|
||||
return False, "Invalid URL."
|
||||
else:
|
||||
# Can't clear member name
|
||||
if prop == "name":
|
||||
return False, "Can't clear member name."
|
||||
|
||||
# Clear from DB
|
||||
value = None
|
||||
|
||||
db_prop = db_properties[prop]
|
||||
await db.update_member_field(conn, member_id=member["id"], field=db_prop, value=value)
|
||||
|
||||
response = make_default_embed("{} {}'s {}.".format("Updated" if value else "Cleared", member["name"], prop))
|
||||
if prop == "avatar" and value:
|
||||
response.set_image(url=value)
|
||||
if prop == "color" and value:
|
||||
response.colour = int(value, 16)
|
||||
return True, response
|
||||
|
||||
@member_command(cmd="member proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).", category="Member commands")
|
||||
async def member_proxy(conn, message, member, args):
|
||||
if len(args) == 0:
|
||||
prefix, suffix = None, None
|
||||
else:
|
||||
# Sanity checking
|
||||
example = " ".join(args)
|
||||
if "text" not in example:
|
||||
return False, "Example proxy message must contain the string 'text'."
|
||||
|
||||
if example.count("text") != 1:
|
||||
return False, "Example proxy message must contain the string 'text' exactly once."
|
||||
|
||||
# Extract prefix and suffix
|
||||
prefix = example[:example.index("text")].strip()
|
||||
suffix = example[example.index("text")+4:].strip()
|
||||
logger.debug(
|
||||
"Matched prefix '{}' and suffix '{}'".format(prefix, suffix))
|
||||
|
||||
# DB stores empty strings as None, make that work
|
||||
if not prefix:
|
||||
prefix = None
|
||||
if not suffix:
|
||||
suffix = None
|
||||
|
||||
async with conn.transaction():
|
||||
await db.update_member_field(conn, member_id=member["id"], field="prefix", value=prefix)
|
||||
await db.update_member_field(conn, member_id=member["id"], field="suffix", value=suffix)
|
||||
return True, "Proxy settings updated." if prefix or suffix else "Proxy settings cleared."
|
||||
|
||||
@member_command("member delete", description="Deletes a member from your system ***permanently***.", category="Member commands")
|
||||
async def member_delete(conn, message, member, args):
|
||||
await client.send_message(message.channel, "Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(member["name"], member["hid"]))
|
||||
|
||||
msg = await client.wait_for_message(author=message.author, channel=message.channel, timeout=60.0)
|
||||
if msg and msg.content == member["hid"]:
|
||||
await db.delete_member(conn, member_id=member["id"])
|
||||
return True, "Member deleted."
|
||||
else:
|
||||
return True, "Member deletion cancelled."
|
||||
|
||||
@command(cmd="message", usage="<id>", description="Shows information about a proxied message. Requires the message ID.", category="Message commands")
|
||||
async def message_info(conn, message, args):
|
||||
try:
|
||||
mid = int(args[0])
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# Find the message in the DB
|
||||
message_row = await db.get_message(conn, mid)
|
||||
if not message_row:
|
||||
return False, "Message not found."
|
||||
|
||||
# Get the original sender of the message
|
||||
original_sender = await client.get_user_info(str(message_row["sender"]))
|
||||
|
||||
# Get sender member and system
|
||||
member = await db.get_member(conn, message_row["member"])
|
||||
system = await db.get_system(conn, member["system"])
|
||||
|
||||
embed = discord.Embed()
|
||||
embed.timestamp = discord.utils.snowflake_time(str(mid))
|
||||
embed.colour = discord.Colour.blue()
|
||||
|
||||
if system["name"]:
|
||||
system_value = "{} (`{}`)".format(system["name"], system["hid"])
|
||||
else:
|
||||
system_value = "`{}`".format(system["hid"])
|
||||
embed.add_field(name="System", value=system_value)
|
||||
embed.add_field(name="Member", value="{}: (`{}`)".format(
|
||||
member["name"], member["hid"]))
|
||||
embed.add_field(name="Sent by", value="{}#{}".format(
|
||||
original_sender.name, original_sender.discriminator))
|
||||
embed.add_field(name="Content", value=message_row["content"], inline=False)
|
||||
|
||||
embed.set_author(name=member["name"], url=member["avatar_url"])
|
||||
|
||||
await client.send_message(message.channel, embed=embed)
|
||||
return True
|
||||
|
||||
@command(cmd="switch", usage="<name|id> [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands")
|
||||
async def switch_member(conn, message, args):
|
||||
if len(args) == 0:
|
||||
return False
|
||||
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
members = []
|
||||
for member_name in args:
|
||||
# Find the member
|
||||
member = await get_member_fuzzy(conn, system["id"], member_name)
|
||||
if not member:
|
||||
return False, "Couldn't find member \"{}\".".format(member_name)
|
||||
members.append(member)
|
||||
|
||||
# Compare requested switch IDs and existing fronter IDs to check for existing switches
|
||||
# Lists, because order matters, it makes sense to just swap fronters
|
||||
member_ids = [member["id"] for member in members]
|
||||
fronter_ids = (await get_fronter_ids(conn, system["id"]))[0]
|
||||
if member_ids == fronter_ids:
|
||||
if len(members) == 1:
|
||||
return False, "{} is already fronting.".format(members[0]["name"])
|
||||
return False, "Members {} are already fronting.".format(", ".join([m["name"] for m in members]))
|
||||
|
||||
# Log the switch
|
||||
async with conn.transaction():
|
||||
switch_id = await db.add_switch(conn, system_id=system["id"])
|
||||
for member in members:
|
||||
await db.add_switch_member(conn, switch_id=switch_id, member_id=member["id"])
|
||||
|
||||
if len(members) == 1:
|
||||
return True, "Switch registered. Current fronter is now {}.".format(member["name"])
|
||||
else:
|
||||
return True, "Switch registered. Current fronters are now {}.".format(", ".join([m["name"] for m in members]))
|
||||
|
||||
@command(cmd="switch out", description="Registers a switch with no one in front.", category="Switching commands")
|
||||
async def switch_out(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
# Get current fronters
|
||||
fronters, _ = await get_fronter_ids(conn, system_id=system["id"])
|
||||
if not fronters:
|
||||
return False, "There's already no one in front."
|
||||
|
||||
# Log it, and don't log any members
|
||||
await db.add_switch(conn, system_id=system["id"])
|
||||
return True, "Switch-out registered."
|
||||
|
||||
@command(cmd="switch move", usage="<time>", description="Moves the most recent switch to a different point in time.", category="Switching commands")
|
||||
async def switch_move(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
if len(args) == 0:
|
||||
return False
|
||||
|
||||
# Parse the time to move to
|
||||
new_time = dateparser.parse(" ".join(args), languages=["en"], settings={
|
||||
"TO_TIMEZONE": "UTC",
|
||||
"RETURN_AS_TIMEZONE_AWARE": False
|
||||
})
|
||||
if not new_time:
|
||||
return False, "{} can't be parsed as a valid time.".format(" ".join(args))
|
||||
|
||||
# Make sure the time isn't in the future
|
||||
if new_time > datetime.now():
|
||||
return False, "Can't move switch to a time in the future."
|
||||
|
||||
# Make sure it all runs in a big transaction for atomicity
|
||||
async with conn.transaction():
|
||||
# Get the last two switches to make sure the switch to move isn't before the second-last switch
|
||||
last_two_switches = await get_front_history(conn, system["id"], count=2)
|
||||
if len(last_two_switches) == 0:
|
||||
return False, "There are no registered switches for this system."
|
||||
|
||||
last_timestamp, last_fronters = last_two_switches[0]
|
||||
if len(last_two_switches) > 1:
|
||||
second_last_timestamp, _ = last_two_switches[1]
|
||||
|
||||
if new_time < second_last_timestamp:
|
||||
time_str = humanize.naturaltime(second_last_timestamp)
|
||||
return False, "Can't move switch to before last switch time ({}), as it would cause conflicts.".format(time_str)
|
||||
|
||||
# Display the confirmation message w/ humanized times
|
||||
members = ", ".join([member["name"] for member in last_fronters])
|
||||
last_absolute = last_timestamp.isoformat(sep=" ", timespec="seconds")
|
||||
last_relative = humanize.naturaltime(last_timestamp)
|
||||
new_absolute = new_time.isoformat(sep=" ", timespec="seconds")
|
||||
new_relative = humanize.naturaltime(new_time)
|
||||
embed = make_default_embed("This will move the latest switch ({}) from {} ({}) to {} ({}). Is this OK?".format(members, last_absolute, last_relative, new_absolute, new_relative))
|
||||
|
||||
# Await and handle confirmation reactions
|
||||
confirm_msg = await client.send_message(message.channel, embed=embed)
|
||||
await client.add_reaction(confirm_msg, "✅")
|
||||
await client.add_reaction(confirm_msg, "❌")
|
||||
|
||||
reaction = await client.wait_for_reaction(emoji=["✅", "❌"], message=confirm_msg, user=message.author, timeout=60.0)
|
||||
if not reaction:
|
||||
return False, "Switch move timed out."
|
||||
|
||||
if reaction.reaction.emoji == "❌":
|
||||
return False, "Switch move cancelled."
|
||||
|
||||
# DB requires the actual switch ID which our utility method above doesn't return, do this manually
|
||||
switch_id = (await db.front_history(conn, system["id"], count=1))[0]["id"]
|
||||
|
||||
# Change the switch in the DB
|
||||
await db.move_last_switch(conn, system["id"], switch_id, new_time)
|
||||
return True, "Switch moved."
|
||||
|
||||
@command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands")
|
||||
async def set_log(conn, message, args):
|
||||
if not message.author.server_permissions.administrator:
|
||||
return False, "You must be a server administrator to use this command."
|
||||
|
||||
server = message.server
|
||||
if len(args) == 0:
|
||||
channel_id = None
|
||||
else:
|
||||
channel = parse_channel_mention(args[0], server=server)
|
||||
if not channel:
|
||||
return False, "Channel not found."
|
||||
channel_id = channel.id
|
||||
|
||||
await db.update_server(conn, server.id, logging_channel_id=channel_id)
|
||||
return True, "Updated logging channel." if channel_id else "Cleared logging channel."
|
||||
|
||||
@command(cmd="help", usage="[system|member|proxy|switch|mod]", description="Shows help messages.")
|
||||
async def show_help(conn, message, args):
|
||||
embed = make_default_embed("")
|
||||
embed.title = "PluralKit Help"
|
||||
|
||||
category = args[0] if len(args) > 0 else None
|
||||
|
||||
from pluralkit.help import help_pages
|
||||
if category in help_pages:
|
||||
for name, text in help_pages[category]:
|
||||
if name:
|
||||
embed.add_field(name=name, value=text)
|
||||
else:
|
||||
embed.description = text
|
||||
else:
|
||||
return False
|
||||
|
||||
return True, embed
|
||||
|
||||
@command(cmd="import tupperware", description="Import data from Tupperware.")
|
||||
async def import_tupperware(conn, message, args):
|
||||
tupperware_member = message.server.get_member("431544605209788416") or message.server.get_member("433916057053560832")
|
||||
|
||||
if not tupperware_member:
|
||||
return False, "This command only works in a server where the Tupperware bot is also present."
|
||||
|
||||
channel_permissions = message.channel.permissions_for(tupperware_member)
|
||||
if not (channel_permissions.read_messages and channel_permissions.send_messages):
|
||||
return False, "This command only works in a channel where the Tupperware bot has read/send access."
|
||||
|
||||
await client.send_message(message.channel, embed=make_default_embed("Please reply to this message with `tul!list` (or the server equivalent)."))
|
||||
|
||||
|
||||
# Check to make sure the Tupperware response actually belongs to the correct user
|
||||
def ensure_account(tw_msg):
|
||||
if not tw_msg.embeds:
|
||||
return False
|
||||
|
||||
if not tw_msg.embeds[0]["title"]:
|
||||
return False
|
||||
|
||||
return tw_msg.embeds[0]["title"].startswith("{}#{}".format(message.author.name, message.author.discriminator))
|
||||
|
||||
embeds = []
|
||||
|
||||
tw_msg = await client.wait_for_message(author=tupperware_member, channel=message.channel, timeout=60.0, check=ensure_account)
|
||||
if not tw_msg:
|
||||
return False, "Tupperware import timed out."
|
||||
embeds.append(tw_msg.embeds[0])
|
||||
|
||||
# Handle Tupperware pagination
|
||||
if tw_msg.embeds[0]["title"].endswith("(page 1)"):
|
||||
while True:
|
||||
# Wait for a new message (within 1 second)
|
||||
tw_msg = await client.wait_for_message(author=tupperware_member, channel=message.channel, timeout=1.0, check=ensure_account)
|
||||
if not tw_msg:
|
||||
# If no message, then it's probably done, so we break
|
||||
break
|
||||
# Otherwise add this next message to the list
|
||||
embeds.append(tw_msg.embeds[0])
|
||||
|
||||
logger.debug("Importing from Tupperware...")
|
||||
|
||||
# Create new (nameless) system if there isn't any registered
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
if system is None:
|
||||
hid = generate_hid()
|
||||
logger.debug("Creating new system (hid={})...".format(hid))
|
||||
system = await db.create_system(conn, system_name=None, system_hid=hid)
|
||||
await db.link_account(conn, system_id=system["id"], account_id=message.author.id)
|
||||
|
||||
for embed in embeds:
|
||||
for field in embed["fields"]:
|
||||
name = field["name"]
|
||||
lines = field["value"].split("\n")
|
||||
|
||||
member_prefix = None
|
||||
member_suffix = None
|
||||
member_avatar = None
|
||||
member_birthdate = None
|
||||
member_description = None
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("Brackets:"):
|
||||
brackets = line[len("Brackets: "):]
|
||||
member_prefix = brackets[:brackets.index("text")].strip() or None
|
||||
member_suffix = brackets[brackets.index("text")+4:].strip() or None
|
||||
elif line.startswith("Avatar URL: "):
|
||||
url = line[len("Avatar URL: "):]
|
||||
member_avatar = url
|
||||
elif line.startswith("Birthday: "):
|
||||
bday_str = line[len("Birthday: "):]
|
||||
bday = datetime.strptime(bday_str, "%a %b %d %Y")
|
||||
if bday:
|
||||
member_birthdate = bday.date()
|
||||
elif line.startswith("Total messages sent: ") or line.startswith("Tag: "):
|
||||
# Ignore this, just so it doesn't catch as the description
|
||||
pass
|
||||
else:
|
||||
member_description = line
|
||||
|
||||
existing_member = await db.get_member_by_name(conn, system_id=system["id"], member_name=name)
|
||||
if not existing_member:
|
||||
hid = generate_hid()
|
||||
logger.debug("Creating new member {} (hid={})...".format(name, hid))
|
||||
existing_member = await db.create_member(conn, system_id=system["id"], member_name=name, member_hid=hid)
|
||||
|
||||
logger.debug("Updating fields...")
|
||||
await db.update_member_field(conn, member_id=existing_member["id"], field="prefix", value=member_prefix)
|
||||
await db.update_member_field(conn, member_id=existing_member["id"], field="suffix", value=member_suffix)
|
||||
await db.update_member_field(conn, member_id=existing_member["id"], field="avatar_url", value=member_avatar)
|
||||
await db.update_member_field(conn, member_id=existing_member["id"], field="birthday", value=member_birthdate)
|
||||
await db.update_member_field(conn, member_id=existing_member["id"], field="description", value=member_description)
|
||||
|
||||
return True, "System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting."
|
||||
|
||||
@command(cmd="invite", description="Generates an invite link for this bot.")
|
||||
async def invite_link(conn, message, args):
|
||||
client_id = os.environ["CLIENT_ID"]
|
||||
|
||||
permissions = discord.Permissions()
|
||||
permissions.manage_webhooks = True
|
||||
permissions.send_messages = True
|
||||
permissions.manage_messages = True
|
||||
permissions.embed_links = True
|
||||
permissions.attach_files = True
|
||||
permissions.read_message_history = True
|
||||
permissions.add_reactions = True
|
||||
|
||||
url = oauth_url(client_id, permissions)
|
||||
logger.debug("Sending invite URL: {}".format(url))
|
||||
return True, url
|
||||
|
||||
@command(cmd="export", description="Exports system data to a machine-readable format.")
|
||||
async def export(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
members = await db.get_all_members(conn, system["id"])
|
||||
accounts = await db.get_linked_accounts(conn, system["id"])
|
||||
switches = await get_front_history(conn, system["id"], 999999)
|
||||
|
||||
data = {
|
||||
"name": system["name"],
|
||||
"id": system["hid"],
|
||||
"description": system["description"],
|
||||
"tag": system["tag"],
|
||||
"avatar_url": system["avatar_url"],
|
||||
"created": system["created"].isoformat(),
|
||||
"members": [
|
||||
{
|
||||
"name": member["name"],
|
||||
"id": member["hid"],
|
||||
"color": member["color"],
|
||||
"avatar_url": member["avatar_url"],
|
||||
"birthday": member["birthday"].isoformat() if member["birthday"] else None,
|
||||
"pronouns": member["pronouns"],
|
||||
"description": member["description"],
|
||||
"prefix": member["prefix"],
|
||||
"suffix": member["suffix"],
|
||||
"created": member["created"].isoformat()
|
||||
} for member in members
|
||||
],
|
||||
"accounts": [str(uid) for uid in accounts],
|
||||
"switches": [
|
||||
{
|
||||
"timestamp": timestamp.isoformat(),
|
||||
"members": [member["hid"] for member in members]
|
||||
} for timestamp, members in switches
|
||||
]
|
||||
}
|
||||
|
||||
f = io.BytesIO(json.dumps(data).encode("utf-8"))
|
||||
await client.send_file(message.channel, f, filename="system.json")
|
@ -1,213 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
|
||||
from pluralkit import db, stats
|
||||
from pluralkit.bot import client, logger
|
||||
|
||||
def make_log_embed(hook_message, member, channel_name):
|
||||
author_name = "#{}: {}".format(channel_name, member["name"])
|
||||
if member["system_name"]:
|
||||
author_name += " ({})".format(member["system_name"])
|
||||
|
||||
embed = discord.Embed()
|
||||
embed.colour = discord.Colour.blue()
|
||||
embed.description = hook_message.clean_content
|
||||
embed.timestamp = hook_message.timestamp
|
||||
embed.set_author(name=author_name, icon_url=member["avatar_url"] or discord.Embed.Empty)
|
||||
|
||||
if len(hook_message.attachments) > 0:
|
||||
embed.set_image(url=hook_message.attachments[0]["url"])
|
||||
return embed
|
||||
|
||||
async def log_message(original_message, hook_message, member, log_channel):
|
||||
# hook_message is kinda broken, and doesn't include details from server or channel
|
||||
# We rely on the fact that original_message must be in the same channel, this'll break if that changes
|
||||
embed = make_log_embed(hook_message, member, channel_name=original_message.channel.name)
|
||||
embed.set_footer(text="System ID: {} | Member ID: {} | Sender: {}#{} | Message ID: {}".format(member["system_hid"], member["hid"], original_message.author.name, original_message.author.discriminator, hook_message.id))
|
||||
|
||||
message_link = "https://discordapp.com/channels/{}/{}/{}".format(original_message.server.id, original_message.channel.id, hook_message.id)
|
||||
embed.author.url = message_link
|
||||
|
||||
try:
|
||||
await client.send_message(log_channel, embed=embed)
|
||||
except discord.errors.Forbidden:
|
||||
# Ignore logging permission errors, perhaps make it spam a big nasty error instead
|
||||
pass
|
||||
|
||||
async def log_delete(hook_message, member, log_channel):
|
||||
embed = make_log_embed(hook_message, member, channel_name=hook_message.channel.name)
|
||||
embed.set_footer(text="System ID: {} | Member ID: {} | Message ID: {}".format(member["system_hid"], member["hid"], hook_message.id))
|
||||
embed.colour = discord.Colour.dark_red()
|
||||
|
||||
await client.send_message(log_channel, embed=embed)
|
||||
|
||||
async def get_log_channel(conn, server):
|
||||
# Check server info for a log channel
|
||||
server_info = await db.get_server_info(conn, server.id)
|
||||
if server_info and server_info["log_channel"]:
|
||||
channel = server.get_channel(str(server_info["log_channel"]))
|
||||
return channel
|
||||
|
||||
async def get_webhook(conn, channel):
|
||||
async with conn.transaction():
|
||||
# Try to find an existing webhook
|
||||
hook_row = await db.get_webhook(conn, channel_id=channel.id)
|
||||
# There's none, we'll make one
|
||||
if not hook_row:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
req_data = {"name": "PluralKit Proxy Webhook"}
|
||||
req_headers = {
|
||||
"Authorization": "Bot {}".format(os.environ["TOKEN"])}
|
||||
|
||||
async with session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id), json=req_data, headers=req_headers) as resp:
|
||||
data = await resp.json()
|
||||
hook_id = data["id"]
|
||||
token = data["token"]
|
||||
|
||||
# Insert new hook into DB
|
||||
await db.add_webhook(conn, channel_id=channel.id, webhook_id=hook_id, webhook_token=token)
|
||||
return hook_id, token
|
||||
|
||||
return hook_row["webhook"], hook_row["token"]
|
||||
|
||||
async def send_hook_message(member, hook_id, hook_token, text=None, image_url=None):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Set up headers
|
||||
req_headers = {
|
||||
"Authorization": "Bot {}".format(os.environ["TOKEN"])
|
||||
}
|
||||
|
||||
# Set up parameters
|
||||
# Use FormData because the API doesn't like JSON requests with file data
|
||||
fd = aiohttp.FormData()
|
||||
fd.add_field("username", "{} {}".format(member["name"], member["tag"] or "").strip())
|
||||
if member["avatar_url"]:
|
||||
fd.add_field("avatar_url", member["avatar_url"])
|
||||
|
||||
if text:
|
||||
fd.add_field("content", text)
|
||||
|
||||
if image_url:
|
||||
# Fetch the image URL and proxy it directly into the file data (async streaming!)
|
||||
image_resp = await session.get(image_url)
|
||||
fd.add_field("file", image_resp.content, content_type=image_resp.content_type, filename=image_resp.url.name)
|
||||
|
||||
# Send the actual webhook request, and wait for a response
|
||||
time_before = time.perf_counter()
|
||||
try:
|
||||
async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token),
|
||||
data=fd,
|
||||
headers=req_headers) as resp:
|
||||
if resp.status == 200:
|
||||
resp_data = await resp.json()
|
||||
|
||||
# Make a fake message object for passing on - this is slightly broken but works for most things
|
||||
msg = discord.Message(reactions=[], **resp_data)
|
||||
|
||||
# Report to stats
|
||||
await stats.report_webhook(time.perf_counter() - time_before, True)
|
||||
return msg
|
||||
else:
|
||||
await stats.report_webhook(time.perf_counter() - time_before, False)
|
||||
|
||||
# Fake a Discord exception, also because #yolo
|
||||
raise discord.HTTPException(resp, await resp.text())
|
||||
except aiohttp.ClientResponseError:
|
||||
await stats.report_webhook(time.perf_counter() - time_before, False)
|
||||
logger.exception("Error while sending webhook message")
|
||||
|
||||
|
||||
async def proxy_message(conn, member, trigger_message, inner):
|
||||
logger.debug("Proxying message '{}' for member {}".format(inner, member["hid"]))
|
||||
logger.info("[{}#{}] {}".format(member["system_hid"], member["hid"], inner))
|
||||
|
||||
# Get the webhook details
|
||||
hook_id, hook_token = await get_webhook(conn, trigger_message.channel)
|
||||
|
||||
# Get attachment image URL if present (only works for one...)
|
||||
image_urls = [a["url"] for a in trigger_message.attachments if "url" in a]
|
||||
image_url = image_urls[0] if len(image_urls) > 0 else None
|
||||
|
||||
# Send the hook message
|
||||
hook_message = await send_hook_message(member, hook_id, hook_token, text=inner, image_url=image_url)
|
||||
|
||||
# Insert new message details into the DB
|
||||
await db.add_message(conn, message_id=hook_message.id, channel_id=trigger_message.channel.id, member_id=member["id"], sender_id=trigger_message.author.id, content=inner)
|
||||
|
||||
# Log message to logging channel if necessary
|
||||
log_channel = await get_log_channel(conn, trigger_message.server)
|
||||
if log_channel:
|
||||
await log_message(trigger_message, hook_message, member, log_channel)
|
||||
|
||||
# Delete the original message
|
||||
await client.delete_message(trigger_message)
|
||||
|
||||
async def handle_proxying(conn, message):
|
||||
# Can't proxy in DMs, webhook creation will explode
|
||||
if message.channel.is_private:
|
||||
return
|
||||
|
||||
# Big fat query to find every member associated with this account
|
||||
# Returned member object has a few more keys (system tag, for example)
|
||||
members = await db.get_members_by_account(conn, account_id=message.author.id)
|
||||
|
||||
# Sort by specificity (members with both prefix and suffix go higher)
|
||||
members = sorted(members, key=lambda x: int(
|
||||
bool(x["prefix"])) + int(bool(x["suffix"])), reverse=True)
|
||||
|
||||
msg = message.content
|
||||
for member in members:
|
||||
# If no proxy details are configured, skip
|
||||
if not member["prefix"] and not member["suffix"]:
|
||||
continue
|
||||
|
||||
# Database stores empty strings as null, fix that here
|
||||
prefix = member["prefix"] or ""
|
||||
suffix = member["suffix"] or ""
|
||||
|
||||
# Avoid matching a prefix of "<" on a mention
|
||||
if prefix == "<":
|
||||
if re.match(r"^<(?:@|@!|#|@&|:\w+:|a:\w+:)\d+>", msg):
|
||||
continue
|
||||
|
||||
# If we have a match, proxy the message
|
||||
if msg.startswith(prefix) and msg.endswith(suffix):
|
||||
# Extract the actual message contents sans tags
|
||||
if suffix:
|
||||
inner_message = msg[len(prefix):-len(suffix)].strip()
|
||||
else:
|
||||
# Slicing to -0 breaks, don't do that
|
||||
inner_message = msg[len(prefix):].strip()
|
||||
|
||||
# Make sure the message isn't blank (but only if it has no attachments)
|
||||
if inner_message or message.attachments:
|
||||
await proxy_message(conn, member, message, inner_message)
|
||||
break
|
||||
|
||||
|
||||
async def handle_reaction(conn, user_id, message_id, emoji):
|
||||
if emoji == "❌":
|
||||
async with conn.transaction():
|
||||
# Find the message in the DB, and make sure it's sent by the user who reacted
|
||||
db_message = await db.get_message_by_sender_and_id(conn, message_id=message_id, sender_id=user_id)
|
||||
if db_message:
|
||||
logger.debug("Deleting message {} by reaction from {}".format(message_id, user_id))
|
||||
|
||||
# If so, remove it from the DB
|
||||
await db.delete_message(conn, message_id)
|
||||
|
||||
# And look up the message and then delete it
|
||||
channel = client.get_channel(str(db_message["channel"]))
|
||||
message = await client.get_message(channel, message_id)
|
||||
await client.delete_message(message)
|
||||
|
||||
# Log deletion to logging channel if necessary
|
||||
log_channel = await get_log_channel(conn, message.server)
|
||||
if log_channel:
|
||||
# db_message contains enough member data for the things to work
|
||||
await log_delete(message, db_message, log_channel)
|
@ -1,304 +0,0 @@
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
|
||||
import asyncio
|
||||
import asyncpg
|
||||
import discord
|
||||
import humanize
|
||||
|
||||
from pluralkit import db
|
||||
from pluralkit.bot import client, logger
|
||||
|
||||
def escape(s):
|
||||
return s.replace("`", "\\`")
|
||||
|
||||
def generate_hid() -> str:
|
||||
return "".join(random.choices(string.ascii_lowercase, k=5))
|
||||
|
||||
def bounds_check_member_name(new_name, system_tag):
|
||||
if len(new_name) > 32:
|
||||
return "Name cannot be longer than 32 characters."
|
||||
|
||||
if system_tag:
|
||||
if len("{} {}".format(new_name, system_tag)) > 32:
|
||||
return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(system_tag)
|
||||
|
||||
async def parse_mention(mention: str) -> discord.User:
|
||||
# First try matching mention format
|
||||
match = re.fullmatch("<@!?(\\d+)>", mention)
|
||||
if match:
|
||||
try:
|
||||
return await client.get_user_info(match.group(1))
|
||||
except discord.NotFound:
|
||||
return None
|
||||
|
||||
# Then try with just ID
|
||||
try:
|
||||
return await client.get_user_info(str(int(mention)))
|
||||
except (ValueError, discord.NotFound):
|
||||
return None
|
||||
|
||||
def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel:
|
||||
match = re.fullmatch("<#(\\d+)>", mention)
|
||||
if match:
|
||||
return server.get_channel(match.group(1))
|
||||
|
||||
try:
|
||||
return server.get_channel(str(int(mention)))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async def get_fronter_ids(conn, system_id):
|
||||
switches = await db.front_history(conn, system_id=system_id, count=1)
|
||||
if not switches:
|
||||
return [], None
|
||||
|
||||
if not switches[0]["members"]:
|
||||
return [], switches[0]["timestamp"]
|
||||
|
||||
return switches[0]["members"], switches[0]["timestamp"]
|
||||
|
||||
async def get_fronters(conn, system_id):
|
||||
member_ids, timestamp = await get_fronter_ids(conn, system_id)
|
||||
|
||||
# Collect in dict and then look up as list, to preserve return order
|
||||
members = {member["id"]: member for member in await db.get_members(conn, member_ids)}
|
||||
return [members[member_id] for member_id in member_ids], timestamp
|
||||
|
||||
async def get_front_history(conn, system_id, count):
|
||||
# Get history from DB
|
||||
switches = await db.front_history(conn, system_id=system_id, count=count)
|
||||
if not switches:
|
||||
return []
|
||||
|
||||
# Get all unique IDs referenced
|
||||
all_member_ids = {id for switch in switches for id in switch["members"]}
|
||||
|
||||
# And look them up in the database into a dict
|
||||
all_members = {member["id"]: member for member in await db.get_members(conn, list(all_member_ids))}
|
||||
|
||||
# Collect in array and return
|
||||
out = []
|
||||
for switch in switches:
|
||||
timestamp = switch["timestamp"]
|
||||
members = [all_members[id] for id in switch["members"]]
|
||||
out.append((timestamp, members))
|
||||
return out
|
||||
|
||||
async def get_system_fuzzy(conn, key) -> asyncpg.Record:
|
||||
if isinstance(key, discord.User):
|
||||
return await db.get_system_by_account(conn, account_id=key.id)
|
||||
|
||||
if isinstance(key, str) and len(key) == 5:
|
||||
return await db.get_system_by_hid(conn, system_hid=key)
|
||||
|
||||
account = await parse_mention(key)
|
||||
if account:
|
||||
system = await db.get_system_by_account(conn, account_id=account.id)
|
||||
if system:
|
||||
return system
|
||||
return None
|
||||
|
||||
|
||||
async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> asyncpg.Record:
|
||||
# First search by hid
|
||||
if system_only:
|
||||
member = await db.get_member_by_hid_in_system(conn, system_id=system_id, member_hid=key)
|
||||
else:
|
||||
member = await db.get_member_by_hid(conn, member_hid=key)
|
||||
if member is not None:
|
||||
return member
|
||||
|
||||
# Then search by name, if we have a system
|
||||
if system_id:
|
||||
member = await db.get_member_by_name(conn, system_id=system_id, member_name=key)
|
||||
if member is not None:
|
||||
return member
|
||||
|
||||
def make_default_embed(message):
|
||||
embed = discord.Embed()
|
||||
embed.colour = discord.Colour.blue()
|
||||
embed.description = message
|
||||
return embed
|
||||
|
||||
def make_error_embed(message):
|
||||
embed = discord.Embed()
|
||||
embed.colour = discord.Colour.dark_red()
|
||||
embed.description = message
|
||||
return embed
|
||||
|
||||
command_map = {}
|
||||
|
||||
# Command wrapper
|
||||
# Return True for success, return False for failure
|
||||
# Second parameter is the message it'll send. If just False, will print usage
|
||||
|
||||
|
||||
def command(cmd, usage=None, description=None, category=None):
|
||||
def wrap(func):
|
||||
async def wrapper(conn, message, args):
|
||||
res = await func(conn, message, args)
|
||||
|
||||
if res is not None:
|
||||
if not isinstance(res, tuple):
|
||||
success, msg = res, None
|
||||
else:
|
||||
success, msg = res
|
||||
|
||||
if not success and not msg:
|
||||
# Failure, no message, print usage
|
||||
usage_str = "**Usage:** pk;{} {}".format(cmd, usage or "")
|
||||
await client.send_message(message.channel, embed=make_default_embed(usage_str))
|
||||
elif not success:
|
||||
# Failure, print message
|
||||
embed = msg if isinstance(msg, discord.Embed) else make_error_embed(msg)
|
||||
# embed.set_footer(text="{:.02f} ms".format(time_ms))
|
||||
await client.send_message(message.channel, embed=embed)
|
||||
elif msg:
|
||||
# Success, print message
|
||||
embed = msg if isinstance(msg, discord.Embed) else make_default_embed(msg)
|
||||
# embed.set_footer(text="{:.02f} ms".format(time_ms))
|
||||
await client.send_message(message.channel, embed=embed)
|
||||
# Success, don't print anything
|
||||
|
||||
# Put command in map
|
||||
command_map[cmd] = (wrapper, usage, description, category)
|
||||
return wrapper
|
||||
return wrap
|
||||
|
||||
# Member command wrapper
|
||||
# Tries to find member by first argument
|
||||
# If system_only=False, allows members from other systems by hid
|
||||
|
||||
|
||||
def member_command(cmd, usage=None, description=None, category=None, system_only=True):
|
||||
def wrap(func):
|
||||
async def wrapper(conn, message, args):
|
||||
# Return if no member param
|
||||
if len(args) == 0:
|
||||
return False
|
||||
|
||||
# If system_only, we need a system to check
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
if system_only and system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
# System is allowed to be none if not system_only
|
||||
system_id = system["id"] if system else None
|
||||
# And find member by key
|
||||
member = await get_member_fuzzy(conn, system_id=system_id, key=args[0], system_only=system_only)
|
||||
|
||||
if member is None:
|
||||
return False, "Can't find member \"{}\".".format(args[0])
|
||||
|
||||
return await func(conn, message, member, args[1:])
|
||||
return command(cmd=cmd, usage="<name|id> {}".format(usage or ""), description=description, category=category)(wrapper)
|
||||
return wrap
|
||||
|
||||
|
||||
async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Embed:
|
||||
card = discord.Embed()
|
||||
card.colour = discord.Colour.blue()
|
||||
|
||||
if system["name"]:
|
||||
card.title = system["name"]
|
||||
|
||||
if system["avatar_url"]:
|
||||
card.set_thumbnail(url=system["avatar_url"])
|
||||
|
||||
if system["tag"]:
|
||||
card.add_field(name="Tag", value=system["tag"])
|
||||
|
||||
fronters, switch_time = await get_fronters(conn, system["id"])
|
||||
if fronters:
|
||||
names = ", ".join([member["name"] for member in fronters])
|
||||
fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switch_time))
|
||||
card.add_field(name="Current fronter" if len(fronters) == 1 else "Current fronters", value=fronter_val)
|
||||
|
||||
account_names = []
|
||||
for account_id in await db.get_linked_accounts(conn, system_id=system["id"]):
|
||||
account = await client.get_user_info(account_id)
|
||||
account_names.append("{}#{}".format(account.name, account.discriminator))
|
||||
card.add_field(name="Linked accounts", value="\n".join(account_names))
|
||||
|
||||
if system["description"]:
|
||||
card.add_field(name="Description",
|
||||
value=system["description"], inline=False)
|
||||
|
||||
# Get names of all members
|
||||
member_texts = []
|
||||
for member in await db.get_all_members(conn, system_id=system["id"]):
|
||||
member_texts.append("{} (`{}`)".format(escape(member["name"]), member["hid"]))
|
||||
|
||||
if len(member_texts) > 0:
|
||||
card.add_field(name="Members", value="\n".join(
|
||||
member_texts), inline=False)
|
||||
|
||||
card.set_footer(text="System ID: {}".format(system["hid"]))
|
||||
return card
|
||||
|
||||
|
||||
async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Embed:
|
||||
system = await db.get_system(conn, system_id=member["system"])
|
||||
|
||||
card = discord.Embed()
|
||||
card.colour = discord.Colour.blue()
|
||||
|
||||
name_and_system = member["name"]
|
||||
if system["name"]:
|
||||
name_and_system += " ({})".format(system["name"])
|
||||
|
||||
card.set_author(name=name_and_system, icon_url=member["avatar_url"] or discord.Embed.Empty)
|
||||
if member["avatar_url"]:
|
||||
card.set_thumbnail(url=member["avatar_url"])
|
||||
|
||||
# Get system name and hid
|
||||
system = await db.get_system(conn, system_id=member["system"])
|
||||
|
||||
if member["color"]:
|
||||
card.colour = int(member["color"], 16)
|
||||
|
||||
if member["birthday"]:
|
||||
bday_val = member["birthday"].strftime("%b %d, %Y")
|
||||
if member["birthday"].year == 1:
|
||||
bday_val = member["birthday"].strftime("%b %d")
|
||||
card.add_field(name="Birthdate", value=bday_val)
|
||||
|
||||
if member["pronouns"]:
|
||||
card.add_field(name="Pronouns", value=member["pronouns"])
|
||||
|
||||
if member["prefix"] or member["suffix"]:
|
||||
prefix = member["prefix"] or ""
|
||||
suffix = member["suffix"] or ""
|
||||
card.add_field(name="Proxy Tags",
|
||||
value="{}text{}".format(prefix, suffix))
|
||||
|
||||
if member["description"]:
|
||||
card.add_field(name="Description",
|
||||
value=member["description"], inline=False)
|
||||
|
||||
card.set_footer(text="System ID: {} | Member ID: {}".format(
|
||||
system["hid"], member["hid"]))
|
||||
return card
|
||||
|
||||
|
||||
async def text_input(message, subject):
|
||||
embed = make_default_embed("")
|
||||
embed.description = "Reply in this channel with the new description you want to set for {}.".format(subject)
|
||||
|
||||
status_msg = await client.send_message(message.channel, embed=embed)
|
||||
reply_msg = await client.wait_for_message(author=message.author, channel=message.channel)
|
||||
|
||||
embed.description = "Alright. When you're happy with the new description, click the ✅ reaction. To cancel, click the ❌ reaction."
|
||||
await client.edit_message(status_msg, embed=embed)
|
||||
await client.add_reaction(reply_msg, "✅")
|
||||
await client.add_reaction(reply_msg, "❌")
|
||||
|
||||
reaction = await client.wait_for_reaction(emoji=["✅", "❌"], message=reply_msg, user=message.author)
|
||||
if reaction.reaction.emoji == "✅":
|
||||
await client.clear_reactions(reply_msg)
|
||||
return reply_msg.content
|
||||
else:
|
||||
await client.clear_reactions(reply_msg)
|
||||
return None
|
@ -1,7 +1,9 @@
|
||||
version: '3'
|
||||
services:
|
||||
bot:
|
||||
build: bot
|
||||
build:
|
||||
context: src/
|
||||
dockerfile: bot.Dockerfile
|
||||
depends_on:
|
||||
- db
|
||||
- influx
|
||||
|
@ -7,5 +7,5 @@ ADD requirements.txt /app
|
||||
RUN pip install --trusted-host pypi.python.org -r requirements.txt
|
||||
|
||||
ADD . /app
|
||||
ENTRYPOINT ["python", "main.py"]
|
||||
ENTRYPOINT ["python", "bot_main.py"]
|
||||
|
@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uvloop
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
from pluralkit import bot
|
||||
|
||||
pk = bot.PluralKitBot(os.environ["TOKEN"])
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(bot.run())
|
||||
loop.run_until_complete(pk.run())
|
26
src/pluralkit/__init__.py
Normal file
26
src/pluralkit/__init__.py
Normal file
@ -0,0 +1,26 @@
|
||||
from collections import namedtuple
|
||||
from datetime import date, datetime
|
||||
|
||||
|
||||
class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "created"])):
|
||||
id: int
|
||||
hid: str
|
||||
name: str
|
||||
description: str
|
||||
tag: str
|
||||
avatar_url: str
|
||||
created: datetime
|
||||
|
||||
class Member(namedtuple("Member", ["id", "hid", "system", "color", "avatar_url", "name", "birthday", "pronouns", "description", "prefix", "suffix", "created"])):
|
||||
id: int
|
||||
hid: str
|
||||
system: int
|
||||
color: str
|
||||
avatar_url: str
|
||||
name: str
|
||||
birthday: date
|
||||
pronouns: str
|
||||
description: str
|
||||
prefix: str
|
||||
suffix: str
|
||||
created: datetime
|
131
src/pluralkit/bot/__init__.py
Normal file
131
src/pluralkit/bot/__init__.py
Normal file
@ -0,0 +1,131 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import discord
|
||||
|
||||
from pluralkit import db, stats
|
||||
from pluralkit.bot import channel_logger, commands, proxy
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
||||
logging.getLogger("pluralkit").setLevel(logging.DEBUG)
|
||||
|
||||
class PluralKitBot:
|
||||
def __init__(self, token):
|
||||
self.token = token
|
||||
self.logger = logging.getLogger("pluralkit.bot")
|
||||
|
||||
self.client = discord.Client()
|
||||
self.client.event(self.on_error)
|
||||
self.client.event(self.on_ready)
|
||||
self.client.event(self.on_message)
|
||||
self.client.event(self.on_socket_raw_receive)
|
||||
|
||||
self.channel_logger = channel_logger.ChannelLogger(self.client)
|
||||
self.proxy = proxy.Proxy(self.client, token, self.channel_logger)
|
||||
|
||||
async def on_error(self, evt, *args, **kwargs):
|
||||
self.logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
|
||||
|
||||
async def on_ready(self):
|
||||
self.logger.info("Connected to Discord.")
|
||||
self.logger.info("- Account: {}#{}".format(self.client.user.name, self.client.user.discriminator))
|
||||
self.logger.info("- User ID: {}".format(self.client.user.id))
|
||||
self.logger.info("- {} servers".format(len(self.client.servers)))
|
||||
|
||||
async def on_message(self, message):
|
||||
# Ignore bot messages
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
if await self.handle_command_dispatch(message):
|
||||
return
|
||||
|
||||
if await self.handle_proxy_dispatch(message):
|
||||
return
|
||||
|
||||
async def on_socket_raw_receive(self, msg):
|
||||
# Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
|
||||
# we parse socket data manually for the reaction add event
|
||||
if isinstance(msg, str):
|
||||
try:
|
||||
msg_data = json.loads(msg)
|
||||
if msg_data.get("t") == "MESSAGE_REACTION_ADD":
|
||||
evt_data = msg_data.get("d")
|
||||
if evt_data:
|
||||
user_id = evt_data["user_id"]
|
||||
message_id = evt_data["message_id"]
|
||||
emoji = evt_data["emoji"]["name"]
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
await self.proxy.handle_reaction(conn, user_id, message_id, emoji)
|
||||
elif msg_data.get("t") == "MESSAGE_DELETE":
|
||||
evt_data = msg_data.get("d")
|
||||
if evt_data:
|
||||
message_id = evt_data["id"]
|
||||
async with self.pool.acquire() as conn:
|
||||
await self.proxy.handle_deletion(conn, message_id)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async def handle_command_dispatch(self, message):
|
||||
command_items = commands.command_list.items()
|
||||
command_items = sorted(command_items, key=lambda x: len(x[0]), reverse=True)
|
||||
|
||||
prefix = "pk;"
|
||||
for command_name, command in command_items:
|
||||
if message.content.lower().startswith(prefix + command_name):
|
||||
args_str = message.content[len(prefix + command_name):].strip()
|
||||
args = args_str.split(" ")
|
||||
|
||||
# Splitting on empty string yields one-element array, remove that
|
||||
if len(args) == 1 and not args[0]:
|
||||
args = []
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
time_before = time.perf_counter()
|
||||
await command.function(self.client, conn, message, args)
|
||||
time_after = time.perf_counter()
|
||||
|
||||
# Report command time stats
|
||||
execution_time = time_after - time_before
|
||||
response_time = (datetime.now() - message.timestamp).total_seconds()
|
||||
await stats.report_command(command_name, execution_time, response_time)
|
||||
|
||||
return True
|
||||
|
||||
async def handle_proxy_dispatch(self, message):
|
||||
# Try doing proxy parsing
|
||||
async with self.pool.acquire() as conn:
|
||||
return await self.proxy.try_proxy_message(conn, message)
|
||||
|
||||
async def periodical_stat_timer(self, pool):
|
||||
async with pool.acquire() as conn:
|
||||
while True:
|
||||
from pluralkit import stats
|
||||
await stats.report_periodical_stats(conn)
|
||||
await asyncio.sleep(30)
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
self.logger.info("Connecting to database...")
|
||||
self.pool = await db.connect()
|
||||
|
||||
self.logger.info("Attempting to create tables...")
|
||||
async with self.pool.acquire() as conn:
|
||||
await db.create_tables(conn)
|
||||
|
||||
self.logger.info("Connecting to InfluxDB...")
|
||||
await stats.connect()
|
||||
|
||||
self.logger.info("Starting periodical stat reporting...")
|
||||
asyncio.get_event_loop().create_task(self.periodical_stat_timer(self.pool))
|
||||
|
||||
self.logger.info("Connecting to Discord...")
|
||||
await self.client.start(self.token)
|
||||
finally:
|
||||
self.logger.info("Logging out from Discord...")
|
||||
await self.client.logout()
|
109
src/pluralkit/bot/channel_logger.py
Normal file
109
src/pluralkit/bot/channel_logger.py
Normal file
@ -0,0 +1,109 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import discord
|
||||
|
||||
from pluralkit import db
|
||||
|
||||
|
||||
def embed_set_author_name(embed: discord.Embed, channel_name: str, member_name: str, system_name: str, avatar_url: str):
|
||||
name = "#{}: {}".format(channel_name, member_name)
|
||||
if system_name:
|
||||
name += " ({})".format(system_name)
|
||||
|
||||
embed.set_author(name=name, icon_url=avatar_url or discord.Embed.Empty)
|
||||
|
||||
|
||||
class ChannelLogger:
|
||||
def __init__(self, client: discord.Client):
|
||||
self.logger = logging.getLogger("pluralkit.bot.channel_logger")
|
||||
self.client = client
|
||||
|
||||
async def get_log_channel(self, conn, server_id: str):
|
||||
server_info = await db.get_server_info(conn, server_id)
|
||||
|
||||
if not server_info:
|
||||
return None
|
||||
|
||||
log_channel = server_info["log_channel"]
|
||||
|
||||
if not log_channel:
|
||||
return None
|
||||
|
||||
return self.client.get_channel(str(log_channel))
|
||||
|
||||
async def send_to_log_channel(self, log_channel: discord.Channel, embed: discord.Embed):
|
||||
try:
|
||||
await self.client.send_message(log_channel, embed=embed)
|
||||
except discord.Forbidden:
|
||||
# TODO: spew big error
|
||||
self.logger.warning(
|
||||
"Did not have permission to send message to logging channel (server={}, channel={})".format(
|
||||
log_channel.server.id, log_channel.id))
|
||||
|
||||
async def log_message_proxied(self, conn,
|
||||
server_id: str,
|
||||
channel_name: str,
|
||||
channel_id: str,
|
||||
sender_name: str,
|
||||
sender_disc: int,
|
||||
member_name: str,
|
||||
member_hid: str,
|
||||
member_avatar_url: str,
|
||||
system_name: str,
|
||||
system_hid: str,
|
||||
message_text: str,
|
||||
message_image: str,
|
||||
message_timestamp: datetime,
|
||||
message_id: str):
|
||||
log_channel = await self.get_log_channel(conn, server_id)
|
||||
if not log_channel:
|
||||
return
|
||||
|
||||
embed = discord.Embed()
|
||||
embed.colour = discord.Colour.blue()
|
||||
embed.description = message_text
|
||||
embed.timestamp = message_timestamp
|
||||
|
||||
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
|
||||
embed.set_footer(
|
||||
text="System ID: {} | Member ID: {} | Sender: {}#{} | Message ID: {}".format(system_hid, member_hid,
|
||||
sender_name, sender_disc,
|
||||
message_id))
|
||||
|
||||
if message_image:
|
||||
embed.set_thumbnail(url=message_image)
|
||||
|
||||
message_link = "https://discordapp.com/channels/{}/{}/{}".format(server_id, channel_id, message_id)
|
||||
embed.author.url = message_link
|
||||
|
||||
await self.send_to_log_channel(log_channel, embed)
|
||||
|
||||
async def log_message_deleted(self, conn,
|
||||
server_id: str,
|
||||
channel_name: str,
|
||||
member_name: str,
|
||||
member_hid: str,
|
||||
member_avatar_url: str,
|
||||
system_name: str,
|
||||
system_hid: str,
|
||||
message_text: str,
|
||||
message_id: str,
|
||||
deleted_by_moderator: bool):
|
||||
log_channel = await self.get_log_channel(conn, server_id)
|
||||
if not log_channel:
|
||||
return
|
||||
|
||||
embed = discord.Embed()
|
||||
embed.colour = discord.Colour.dark_red()
|
||||
embed.description = message_text
|
||||
embed.timestamp = datetime.utcnow()
|
||||
|
||||
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
|
||||
embed.set_footer(
|
||||
text="System ID: {} | Member ID: {} | Message ID: {} | Deleted by moderator? {}".format(system_hid,
|
||||
member_hid,
|
||||
message_id,
|
||||
"Yes" if deleted_by_moderator else "No"))
|
||||
|
||||
await self.send_to_log_channel(log_channel, embed)
|
98
src/pluralkit/bot/commands/__init__.py
Normal file
98
src/pluralkit/bot/commands/__init__.py
Normal file
@ -0,0 +1,98 @@
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
import asyncpg
|
||||
import discord
|
||||
|
||||
import pluralkit
|
||||
from pluralkit import db
|
||||
from pluralkit.bot import utils
|
||||
|
||||
command_list = {}
|
||||
|
||||
class InvalidCommandSyntax(Exception):
|
||||
pass
|
||||
|
||||
class NoSystemRegistered(Exception):
|
||||
pass
|
||||
|
||||
class CommandError(Exception):
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
class CommandContext(namedtuple("CommandContext", ["client", "conn", "message", "system"])):
|
||||
client: discord.Client
|
||||
conn: asyncpg.Connection
|
||||
message: discord.Message
|
||||
system: pluralkit.System
|
||||
|
||||
async def reply(self, message=None, embed=None):
|
||||
return await self.client.send_message(self.message.channel, message, embed=embed)
|
||||
|
||||
class MemberCommandContext(namedtuple("MemberCommandContext", CommandContext._fields + ("member",)), CommandContext):
|
||||
client: discord.Client
|
||||
conn: asyncpg.Connection
|
||||
message: discord.Message
|
||||
system: pluralkit.System
|
||||
member: pluralkit.Member
|
||||
|
||||
class CommandEntry(namedtuple("CommandEntry", ["command", "function", "usage", "description", "category"])):
|
||||
pass
|
||||
|
||||
def command(cmd, usage=None, description=None, category=None, system_required=True):
|
||||
def wrap(func):
|
||||
async def wrapper(client, conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system_required and system is None:
|
||||
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account"))
|
||||
return
|
||||
|
||||
ctx = CommandContext(client=client, conn=conn, message=message, system=system)
|
||||
try:
|
||||
res = await func(ctx, args)
|
||||
|
||||
if res:
|
||||
embed = res if isinstance(res, discord.Embed) else utils.make_default_embed(res)
|
||||
await client.send_message(message.channel, embed=embed)
|
||||
except NoSystemRegistered:
|
||||
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account"))
|
||||
except InvalidCommandSyntax:
|
||||
usage_str = "**Usage:** pk;{} {}".format(cmd, usage or "")
|
||||
await client.send_message(message.channel, embed=utils.make_default_embed(usage_str))
|
||||
except CommandError as e:
|
||||
embed = e.message if isinstance(e.message, discord.Embed) else utils.make_error_embed(e.message)
|
||||
await client.send_message(message.channel, embed=embed)
|
||||
|
||||
# Put command in map
|
||||
command_list[cmd] = CommandEntry(command=cmd, function=wrapper, usage=usage, description=description, category=category)
|
||||
return wrapper
|
||||
return wrap
|
||||
|
||||
def member_command(cmd, usage=None, description=None, category=None, system_only=True):
|
||||
def wrap(func):
|
||||
async def wrapper(ctx: CommandContext, args):
|
||||
# Return if no member param
|
||||
if len(args) == 0:
|
||||
raise InvalidCommandSyntax()
|
||||
|
||||
# System is allowed to be none if not system_only
|
||||
system_id = ctx.system.id if ctx.system else None
|
||||
# And find member by key
|
||||
member = await utils.get_member_fuzzy(ctx.conn, system_id=system_id, key=args[0], system_only=system_only)
|
||||
|
||||
if member is None:
|
||||
raise CommandError("Can't find member \"{}\".".format(args[0]))
|
||||
|
||||
ctx = MemberCommandContext(client=ctx.client, conn=ctx.conn, message=ctx.message, system=ctx.system, member=member)
|
||||
return await func(ctx, args[1:])
|
||||
return command(cmd=cmd, usage="<name|id> {}".format(usage or ""), description=description, category=category, system_required=False)(wrapper)
|
||||
return wrap
|
||||
|
||||
import pluralkit.bot.commands.import_commands
|
||||
import pluralkit.bot.commands.member_commands
|
||||
import pluralkit.bot.commands.message_commands
|
||||
import pluralkit.bot.commands.misc_commands
|
||||
import pluralkit.bot.commands.mod_commands
|
||||
import pluralkit.bot.commands.switch_commands
|
||||
import pluralkit.bot.commands.system_commands
|
143
src/pluralkit/bot/commands/import_commands.py
Normal file
143
src/pluralkit/bot/commands/import_commands.py
Normal file
@ -0,0 +1,143 @@
|
||||
import asyncio
|
||||
import re
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from pluralkit.bot import utils
|
||||
from pluralkit.bot.commands import *
|
||||
|
||||
logger = logging.getLogger("pluralkit.commands")
|
||||
|
||||
@command(cmd="import tupperware", description="Import data from Tupperware.")
|
||||
async def import_tupperware(ctx: CommandContext, args: List[str]):
|
||||
tupperware_member = ctx.message.server.get_member("431544605209788416") or ctx.message.server.get_member("433916057053560832")
|
||||
|
||||
if not tupperware_member:
|
||||
raise CommandError("This command only works in a server where the Tupperware bot is also present.")
|
||||
|
||||
channel_permissions = ctx.message.channel.permissions_for(tupperware_member)
|
||||
if not (channel_permissions.read_messages and channel_permissions.send_messages):
|
||||
raise CommandError("This command only works in a channel where the Tupperware bot has read/send access.")
|
||||
|
||||
await ctx.reply(embed=utils.make_default_embed("Please reply to this message with `tul!list` (or the server equivalent)."))
|
||||
|
||||
|
||||
# Check to make sure the Tupperware response actually belongs to the correct user
|
||||
def ensure_account(tw_msg):
|
||||
if not tw_msg.embeds:
|
||||
return False
|
||||
|
||||
if not tw_msg.embeds[0]["title"]:
|
||||
return False
|
||||
|
||||
return tw_msg.embeds[0]["title"].startswith("{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator))
|
||||
|
||||
embeds = []
|
||||
|
||||
tw_msg: discord.Message = await ctx.client.wait_for_message(author=tupperware_member, channel=ctx.message.channel, timeout=60.0, check=ensure_account)
|
||||
if not tw_msg:
|
||||
raise CommandError("Tupperware import timed out.")
|
||||
embeds.append(tw_msg.embeds[0])
|
||||
|
||||
# Handle Tupperware pagination
|
||||
def match_pagination():
|
||||
pagination_match = re.search(r"\(page (\d+)/(\d+), \d+ total\)", tw_msg.embeds[0]["title"])
|
||||
if not pagination_match:
|
||||
return None
|
||||
return int(pagination_match.group(1)), int(pagination_match.group(2))
|
||||
|
||||
pagination_match = match_pagination()
|
||||
if pagination_match:
|
||||
status_msg = await ctx.reply("Multi-page member list found. Please manually scroll through all the pages.")
|
||||
current_page = 0
|
||||
total_pages = 1
|
||||
|
||||
pages_found = {}
|
||||
|
||||
# Keep trying to read the embed with new pages
|
||||
last_found_time = datetime.utcnow()
|
||||
while len(pages_found) < total_pages:
|
||||
new_page, total_pages = match_pagination()
|
||||
|
||||
# Put the found page in the pages dict
|
||||
pages_found[new_page] = dict(tw_msg.embeds[0])
|
||||
|
||||
# If this isn't the same page as last check, edit the status message
|
||||
if new_page != current_page:
|
||||
last_found_time = datetime.utcnow()
|
||||
await ctx.client.edit_message(status_msg, "Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(len(pages_found), total_pages))
|
||||
current_page = new_page
|
||||
|
||||
# And sleep a bit to prevent spamming the CPU
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
# Make sure it doesn't spin here for too long, time out after 30 seconds since last new page
|
||||
if (datetime.utcnow() - last_found_time).seconds > 30:
|
||||
raise CommandError("Pagination scan timed out.")
|
||||
|
||||
# Now that we've got all the pages, put them in the embeds list
|
||||
# Make sure to erase the original one we put in above too
|
||||
embeds = list([embed for page, embed in sorted(pages_found.items(), key=lambda x: x[0])])
|
||||
|
||||
# Also edit the status message to indicate we're now importing, and it may take a while because there's probably a lot of members
|
||||
await ctx.client.edit_message(status_msg, "All pages read. Now importing...")
|
||||
|
||||
logger.debug("Importing from Tupperware...")
|
||||
|
||||
# Create new (nameless) system if there isn't any registered
|
||||
system = ctx.system
|
||||
if system is None:
|
||||
hid = utils.generate_hid()
|
||||
logger.debug("Creating new system (hid={})...".format(hid))
|
||||
system = await db.create_system(ctx.conn, system_name=None, system_hid=hid)
|
||||
await db.link_account(ctx.conn, system_id=system["id"], account_id=ctx.message.author.id)
|
||||
|
||||
for embed in embeds:
|
||||
for field in embed["fields"]:
|
||||
name = field["name"]
|
||||
lines = field["value"].split("\n")
|
||||
|
||||
member_prefix = None
|
||||
member_suffix = None
|
||||
member_avatar = None
|
||||
member_birthdate = None
|
||||
member_description = None
|
||||
|
||||
# Read the message format line by line
|
||||
for line in lines:
|
||||
if line.startswith("Brackets:"):
|
||||
brackets = line[len("Brackets: "):]
|
||||
member_prefix = brackets[:brackets.index("text")].strip() or None
|
||||
member_suffix = brackets[brackets.index("text")+4:].strip() or None
|
||||
elif line.startswith("Avatar URL: "):
|
||||
url = line[len("Avatar URL: "):]
|
||||
member_avatar = url
|
||||
elif line.startswith("Birthday: "):
|
||||
bday_str = line[len("Birthday: "):]
|
||||
bday = datetime.strptime(bday_str, "%a %b %d %Y")
|
||||
if bday:
|
||||
member_birthdate = bday.date()
|
||||
elif line.startswith("Total messages sent: ") or line.startswith("Tag: "):
|
||||
# Ignore this, just so it doesn't catch as the description
|
||||
pass
|
||||
else:
|
||||
member_description = line
|
||||
|
||||
# Read by name - TW doesn't allow name collisions so we're safe here (prevents dupes)
|
||||
existing_member = await db.get_member_by_name(ctx.conn, system_id=system.id, member_name=name)
|
||||
if not existing_member:
|
||||
# Or create a new member
|
||||
hid = utils.generate_hid()
|
||||
logger.debug("Creating new member {} (hid={})...".format(name, hid))
|
||||
existing_member = await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid)
|
||||
|
||||
# Save the new stuff in the DB
|
||||
logger.debug("Updating fields...")
|
||||
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="prefix", value=member_prefix)
|
||||
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="suffix", value=member_suffix)
|
||||
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="avatar_url", value=member_avatar)
|
||||
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="birthday", value=member_birthdate)
|
||||
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="description", value=member_description)
|
||||
|
||||
return "System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting."
|
150
src/pluralkit/bot/commands/member_commands.py
Normal file
150
src/pluralkit/bot/commands/member_commands.py
Normal file
@ -0,0 +1,150 @@
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pluralkit.bot import utils
|
||||
from pluralkit.bot.commands import *
|
||||
|
||||
logger = logging.getLogger("pluralkit.commands")
|
||||
|
||||
@member_command(cmd="member", description="Shows information about a system member.", system_only=False, category="Member commands")
|
||||
async def member_info(ctx: MemberCommandContext, args: List[str]):
|
||||
await ctx.reply(embed=await utils.generate_member_info_card(ctx.conn, ctx.member))
|
||||
|
||||
@command(cmd="member new", usage="<name>", description="Adds a new member to your system.", category="Member commands")
|
||||
async def new_member(ctx: MemberCommandContext, args: List[str]):
|
||||
if len(args) == 0:
|
||||
raise InvalidCommandSyntax()
|
||||
|
||||
name = " ".join(args)
|
||||
bounds_error = utils.bounds_check_member_name(name, ctx.system.tag)
|
||||
if bounds_error:
|
||||
raise CommandError(bounds_error)
|
||||
|
||||
# TODO: figure out what to do if this errors out on collision on generate_hid
|
||||
hid = utils.generate_hid()
|
||||
|
||||
# Insert member row
|
||||
await db.create_member(ctx.conn, system_id=ctx.system.id, member_name=name, member_hid=hid)
|
||||
return "Member \"{}\" (`{}`) registered!".format(name, hid)
|
||||
|
||||
|
||||
@member_command(cmd="member set", usage="<name|description|color|pronouns|birthdate|avatar> [value]", description="Edits a member property. Leave [value] blank to clear.", category="Member commands")
|
||||
async def member_set(ctx: MemberCommandContext, args: List[str]):
|
||||
if len(args) == 0:
|
||||
raise InvalidCommandSyntax()
|
||||
|
||||
allowed_properties = ["name", "description", "color", "pronouns", "birthdate", "avatar"]
|
||||
db_properties = {
|
||||
"name": "name",
|
||||
"description": "description",
|
||||
"color": "color",
|
||||
"pronouns": "pronouns",
|
||||
"birthdate": "birthday",
|
||||
"avatar": "avatar_url"
|
||||
}
|
||||
|
||||
prop = args[0]
|
||||
if prop not in allowed_properties:
|
||||
raise CommandError("Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)))
|
||||
|
||||
if len(args) >= 2:
|
||||
value = " ".join(args[1:])
|
||||
|
||||
# Sanity/validity checks and type conversions
|
||||
if prop == "name":
|
||||
bounds_error = utils.bounds_check_member_name(value, ctx.system.tag)
|
||||
if bounds_error:
|
||||
raise CommandError(bounds_error)
|
||||
|
||||
if prop == "color":
|
||||
match = re.fullmatch("#?([0-9A-Fa-f]{6})", value)
|
||||
if not match:
|
||||
raise CommandError("Color must be a valid hex color (eg. #ff0000)")
|
||||
|
||||
value = match.group(1).lower()
|
||||
|
||||
if prop == "birthdate":
|
||||
try:
|
||||
value = datetime.strptime(value, "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
try:
|
||||
# Try again, adding 0001 as a placeholder year
|
||||
# This is considered a "null year" and will be omitted from the info card
|
||||
# Useful if you want your birthday to be displayed yearless.
|
||||
value = datetime.strptime("0001-" + value, "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
raise CommandError("Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25).")
|
||||
|
||||
if prop == "avatar":
|
||||
user = await utils.parse_mention(ctx.client, value)
|
||||
if user:
|
||||
# Set the avatar to the mentioned user's avatar
|
||||
# Discord doesn't like webp, but also hosts png alternatives
|
||||
value = user.avatar_url.replace(".webp", ".png")
|
||||
else:
|
||||
# Validate URL
|
||||
u = urlparse(value)
|
||||
if u.scheme in ["http", "https"] and u.netloc and u.path:
|
||||
value = value
|
||||
else:
|
||||
raise CommandError("Invalid URL.")
|
||||
else:
|
||||
# Can't clear member name
|
||||
if prop == "name":
|
||||
raise CommandError("Can't clear member name.")
|
||||
|
||||
# Clear from DB
|
||||
value = None
|
||||
|
||||
db_prop = db_properties[prop]
|
||||
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field=db_prop, value=value)
|
||||
|
||||
response = utils.make_default_embed("{} {}'s {}.".format("Updated" if value else "Cleared", ctx.member.name, prop))
|
||||
if prop == "avatar" and value:
|
||||
response.set_image(url=value)
|
||||
if prop == "color" and value:
|
||||
response.colour = int(value, 16)
|
||||
return response
|
||||
|
||||
@member_command(cmd="member proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).", category="Member commands")
|
||||
async def member_proxy(ctx: MemberCommandContext, args: List[str]):
|
||||
if len(args) == 0:
|
||||
prefix, suffix = None, None
|
||||
else:
|
||||
# Sanity checking
|
||||
example = " ".join(args)
|
||||
if "text" not in example:
|
||||
raise CommandError("Example proxy message must contain the string 'text'.")
|
||||
|
||||
if example.count("text") != 1:
|
||||
raise CommandError("Example proxy message must contain the string 'text' exactly once.")
|
||||
|
||||
# Extract prefix and suffix
|
||||
prefix = example[:example.index("text")].strip()
|
||||
suffix = example[example.index("text")+4:].strip()
|
||||
logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix))
|
||||
|
||||
# DB stores empty strings as None, make that work
|
||||
if not prefix:
|
||||
prefix = None
|
||||
if not suffix:
|
||||
suffix = None
|
||||
|
||||
async with ctx.conn.transaction():
|
||||
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="prefix", value=prefix)
|
||||
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="suffix", value=suffix)
|
||||
return "Proxy settings updated." if prefix or suffix else "Proxy settings cleared."
|
||||
|
||||
@member_command("member delete", description="Deletes a member from your system ***permanently***.", category="Member commands")
|
||||
async def member_delete(ctx: MemberCommandContext, args: List[str]):
|
||||
await ctx.reply("Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(ctx.member.name, ctx.member.hid))
|
||||
|
||||
msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0)
|
||||
if msg and msg.content == ctx.member.hid:
|
||||
await db.delete_member(ctx.conn, member_id=ctx.member.id)
|
||||
return "Member deleted."
|
||||
else:
|
||||
return "Member deletion cancelled."
|
57
src/pluralkit/bot/commands/message_commands.py
Normal file
57
src/pluralkit/bot/commands/message_commands.py
Normal file
@ -0,0 +1,57 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from pluralkit.bot import utils
|
||||
from pluralkit.bot.commands import *
|
||||
|
||||
logger = logging.getLogger("pluralkit.commands")
|
||||
|
||||
|
||||
@command(cmd="message", usage="<id>", description="Shows information about a proxied message. Requires the message ID.",
|
||||
category="Message commands")
|
||||
async def message_info(ctx: CommandContext, args: List[str]):
|
||||
if len(args) == 0:
|
||||
raise InvalidCommandSyntax()
|
||||
|
||||
try:
|
||||
mid = int(args[0])
|
||||
except ValueError:
|
||||
raise InvalidCommandSyntax()
|
||||
|
||||
# Find the message in the DB
|
||||
message = await db.get_message(ctx.conn, str(mid))
|
||||
if not message:
|
||||
raise CommandError("Message not found.")
|
||||
|
||||
# Get the original sender of the messages
|
||||
try:
|
||||
original_sender = await ctx.client.get_user_info(str(message.sender))
|
||||
except discord.NotFound:
|
||||
# Account was since deleted - rare but we're handling it anyway
|
||||
original_sender = None
|
||||
|
||||
embed = discord.Embed()
|
||||
embed.timestamp = discord.utils.snowflake_time(str(mid))
|
||||
embed.colour = discord.Colour.blue()
|
||||
|
||||
if message.system_name:
|
||||
system_value = "{} (`{}`)".format(message.system_name, message.system_hid)
|
||||
else:
|
||||
system_value = "`{}`".format(message.system_hid)
|
||||
embed.add_field(name="System", value=system_value)
|
||||
|
||||
embed.add_field(name="Member", value="{} (`{}`)".format(message.name, message.hid))
|
||||
|
||||
if original_sender:
|
||||
sender_name = "{}#{}".format(original_sender.name, original_sender.discriminator)
|
||||
else:
|
||||
sender_name = "(deleted account {})".format(message.sender)
|
||||
|
||||
embed.add_field(name="Sent by", value=sender_name)
|
||||
|
||||
if message.content: # Content can be empty string if there's an attachment
|
||||
embed.add_field(name="Content", value=message.content, inline=False)
|
||||
|
||||
embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty)
|
||||
|
||||
return embed
|
89
src/pluralkit/bot/commands/misc_commands.py
Normal file
89
src/pluralkit/bot/commands/misc_commands.py
Normal file
@ -0,0 +1,89 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from discord.utils import oauth_url
|
||||
|
||||
from pluralkit.bot import utils
|
||||
from pluralkit.bot.commands import *
|
||||
|
||||
logger = logging.getLogger("pluralkit.commands")
|
||||
|
||||
@command(cmd="help", usage="[system|member|proxy|switch|mod]", description="Shows help messages.")
|
||||
async def show_help(ctx: CommandContext, args: List[str]):
|
||||
embed = utils.make_default_embed("")
|
||||
embed.title = "PluralKit Help"
|
||||
embed.set_footer(text="By Astrid (Ske#6201, or 'qoxvy' on PK) | GitHub: https://github.com/xSke/PluralKit/")
|
||||
|
||||
category = args[0] if len(args) > 0 else None
|
||||
|
||||
from pluralkit.bot.help import help_pages
|
||||
if category in help_pages:
|
||||
for name, text in help_pages[category]:
|
||||
if name:
|
||||
embed.add_field(name=name, value=text)
|
||||
else:
|
||||
embed.description = text
|
||||
else:
|
||||
raise InvalidCommandSyntax()
|
||||
|
||||
return embed
|
||||
|
||||
@command(cmd="invite", description="Generates an invite link for this bot.")
|
||||
async def invite_link(ctx: CommandContext, args: List[str]):
|
||||
client_id = os.environ["CLIENT_ID"]
|
||||
|
||||
permissions = discord.Permissions()
|
||||
permissions.manage_webhooks = True
|
||||
permissions.send_messages = True
|
||||
permissions.manage_messages = True
|
||||
permissions.embed_links = True
|
||||
permissions.attach_files = True
|
||||
permissions.read_message_history = True
|
||||
permissions.add_reactions = True
|
||||
|
||||
url = oauth_url(client_id, permissions)
|
||||
logger.debug("Sending invite URL: {}".format(url))
|
||||
return url
|
||||
|
||||
@command(cmd="export", description="Exports system data to a machine-readable format.")
|
||||
async def export(ctx: CommandContext, args: List[str]):
|
||||
members = await db.get_all_members(ctx.conn, ctx.system.id)
|
||||
accounts = await db.get_linked_accounts(ctx.conn, ctx.system.id)
|
||||
switches = await utils.get_front_history(ctx.conn, ctx.system.id, 999999)
|
||||
|
||||
system = ctx.system
|
||||
data = {
|
||||
"name": system.name,
|
||||
"id": system.hid,
|
||||
"description": system.description,
|
||||
"tag": system.tag,
|
||||
"avatar_url": system.avatar_url,
|
||||
"created": system.created.isoformat(),
|
||||
"members": [
|
||||
{
|
||||
"name": member.name,
|
||||
"id": member.hid,
|
||||
"color": member.color,
|
||||
"avatar_url": member.avatar_url,
|
||||
"birthday": member.birthday.isoformat() if member.birthday else None,
|
||||
"pronouns": member.pronouns,
|
||||
"description": member.description,
|
||||
"prefix": member.prefix,
|
||||
"suffix": member.suffix,
|
||||
"created": member.created.isoformat()
|
||||
} for member in members
|
||||
],
|
||||
"accounts": [str(uid) for uid in accounts],
|
||||
"switches": [
|
||||
{
|
||||
"timestamp": timestamp.isoformat(),
|
||||
"members": [member.hid for member in members]
|
||||
} for timestamp, members in switches
|
||||
]
|
||||
}
|
||||
|
||||
f = io.BytesIO(json.dumps(data).encode("utf-8"))
|
||||
await ctx.client.send_file(ctx.message.channel, f, filename="system.json")
|
24
src/pluralkit/bot/commands/mod_commands.py
Normal file
24
src/pluralkit/bot/commands/mod_commands.py
Normal file
@ -0,0 +1,24 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from pluralkit.bot import utils
|
||||
from pluralkit.bot.commands import *
|
||||
|
||||
logger = logging.getLogger("pluralkit.commands")
|
||||
|
||||
@command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands")
|
||||
async def set_log(ctx: CommandContext, args: List[str]):
|
||||
if not ctx.message.author.server_permissions.administrator:
|
||||
raise CommandError("You must be a server administrator to use this command.")
|
||||
|
||||
server = ctx.message.server
|
||||
if len(args) == 0:
|
||||
channel_id = None
|
||||
else:
|
||||
channel = utils.parse_channel_mention(args[0], server=server)
|
||||
if not channel:
|
||||
raise CommandError("Channel not found.")
|
||||
channel_id = channel.id
|
||||
|
||||
await db.update_server(ctx.conn, server.id, logging_channel_id=channel_id)
|
||||
return "Updated logging channel." if channel_id else "Cleared logging channel."
|
119
src/pluralkit/bot/commands/switch_commands.py
Normal file
119
src/pluralkit/bot/commands/switch_commands.py
Normal file
@ -0,0 +1,119 @@
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import dateparser
|
||||
import humanize
|
||||
|
||||
from pluralkit import Member
|
||||
from pluralkit.bot import utils
|
||||
from pluralkit.bot.commands import *
|
||||
|
||||
logger = logging.getLogger("pluralkit.commands")
|
||||
|
||||
@command(cmd="switch", usage="<name|id> [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands")
|
||||
async def switch_member(ctx: MemberCommandContext, args: List[str]):
|
||||
if len(args) == 0:
|
||||
raise InvalidCommandSyntax()
|
||||
|
||||
members: List[Member] = []
|
||||
for member_name in args:
|
||||
# Find the member
|
||||
member = await utils.get_member_fuzzy(ctx.conn, ctx.system.id, member_name)
|
||||
if not member:
|
||||
raise CommandError("Couldn't find member \"{}\".".format(member_name))
|
||||
members.append(member)
|
||||
|
||||
# Compare requested switch IDs and existing fronter IDs to check for existing switches
|
||||
# Lists, because order matters, it makes sense to just swap fronters
|
||||
member_ids = [member.id for member in members]
|
||||
fronter_ids = (await utils.get_fronter_ids(ctx.conn, ctx.system.id))[0]
|
||||
if member_ids == fronter_ids:
|
||||
if len(members) == 1:
|
||||
raise CommandError("{} is already fronting.".format(members[0].name))
|
||||
raise CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members])))
|
||||
|
||||
# Also make sure there aren't any duplicates
|
||||
if len(set(member_ids)) != len(member_ids):
|
||||
raise CommandError("Duplicate members in switch list.")
|
||||
|
||||
# Log the switch
|
||||
async with ctx.conn.transaction():
|
||||
switch_id = await db.add_switch(ctx.conn, system_id=ctx.system.id)
|
||||
for member in members:
|
||||
await db.add_switch_member(ctx.conn, switch_id=switch_id, member_id=member.id)
|
||||
|
||||
if len(members) == 1:
|
||||
return "Switch registered. Current fronter is now {}.".format(members[0].name)
|
||||
else:
|
||||
return "Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members]))
|
||||
|
||||
@command(cmd="switch out", description="Registers a switch with no one in front.", category="Switching commands")
|
||||
async def switch_out(ctx: MemberCommandContext, args: List[str]):
|
||||
# Get current fronters
|
||||
fronters, _ = await utils.get_fronter_ids(ctx.conn, system_id=ctx.system.id)
|
||||
if not fronters:
|
||||
raise CommandError("There's already no one in front.")
|
||||
|
||||
# Log it, and don't log any members
|
||||
await db.add_switch(ctx.conn, system_id=ctx.system.id)
|
||||
return "Switch-out registered."
|
||||
|
||||
@command(cmd="switch move", usage="<time>", description="Moves the most recent switch to a different point in time.", category="Switching commands")
|
||||
async def switch_move(ctx: MemberCommandContext, args: List[str]):
|
||||
if len(args) == 0:
|
||||
raise InvalidCommandSyntax()
|
||||
|
||||
# Parse the time to move to
|
||||
new_time = dateparser.parse(" ".join(args), languages=["en"], settings={
|
||||
"TO_TIMEZONE": "UTC",
|
||||
"RETURN_AS_TIMEZONE_AWARE": False
|
||||
})
|
||||
if not new_time:
|
||||
raise CommandError("{} can't be parsed as a valid time.".format(" ".join(args)))
|
||||
|
||||
# Make sure the time isn't in the future
|
||||
if new_time > datetime.now():
|
||||
raise CommandError("Can't move switch to a time in the future.")
|
||||
|
||||
# Make sure it all runs in a big transaction for atomicity
|
||||
async with ctx.conn.transaction():
|
||||
# Get the last two switches to make sure the switch to move isn't before the second-last switch
|
||||
last_two_switches = await utils.get_front_history(ctx.conn, ctx.system.id, count=2)
|
||||
if len(last_two_switches) == 0:
|
||||
raise CommandError("There are no registered switches for this system.")
|
||||
|
||||
last_timestamp, last_fronters = last_two_switches[0]
|
||||
if len(last_two_switches) > 1:
|
||||
second_last_timestamp, _ = last_two_switches[1]
|
||||
|
||||
if new_time < second_last_timestamp:
|
||||
time_str = humanize.naturaltime(second_last_timestamp)
|
||||
raise CommandError("Can't move switch to before last switch time ({}), as it would cause conflicts.".format(time_str))
|
||||
|
||||
# Display the confirmation message w/ humanized times
|
||||
members = ", ".join([member.name for member in last_fronters]) or "nobody"
|
||||
last_absolute = last_timestamp.isoformat(sep=" ", timespec="seconds")
|
||||
last_relative = humanize.naturaltime(last_timestamp)
|
||||
new_absolute = new_time.isoformat(sep=" ", timespec="seconds")
|
||||
new_relative = humanize.naturaltime(new_time)
|
||||
embed = utils.make_default_embed("This will move the latest switch ({}) from {} ({}) to {} ({}). Is this OK?".format(members, last_absolute, last_relative, new_absolute, new_relative))
|
||||
|
||||
# Await and handle confirmation reactions
|
||||
confirm_msg = await ctx.reply(embed=embed)
|
||||
await ctx.client.add_reaction(confirm_msg, "✅")
|
||||
await ctx.client.add_reaction(confirm_msg, "❌")
|
||||
|
||||
reaction = await ctx.client.wait_for_reaction(emoji=["✅", "❌"], message=confirm_msg, user=ctx.message.author, timeout=60.0)
|
||||
if not reaction:
|
||||
raise CommandError("Switch move timed out.")
|
||||
|
||||
if reaction.reaction.emoji == "❌":
|
||||
raise CommandError("Switch move cancelled.")
|
||||
|
||||
# DB requires the actual switch ID which our utility method above doesn't return, do this manually
|
||||
switch_id = (await db.front_history(ctx.conn, ctx.system.id, count=1))[0]["id"]
|
||||
|
||||
# Change the switch in the DB
|
||||
await db.move_last_switch(ctx.conn, ctx.system.id, switch_id, new_time)
|
||||
return "Switch moved."
|
215
src/pluralkit/bot/commands/system_commands.py
Normal file
215
src/pluralkit/bot/commands/system_commands.py
Normal file
@ -0,0 +1,215 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import humanize
|
||||
|
||||
from pluralkit.bot import utils
|
||||
from pluralkit.bot.commands import *
|
||||
|
||||
logger = logging.getLogger("pluralkit.commands")
|
||||
|
||||
@command(cmd="system", usage="[system]", description="Shows information about a system.", category="System commands", system_required=False)
|
||||
async def system_info(ctx: CommandContext, args: List[str]):
|
||||
if len(args) == 0:
|
||||
if not ctx.system:
|
||||
raise NoSystemRegistered()
|
||||
system = ctx.system
|
||||
else:
|
||||
# Look one up
|
||||
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
|
||||
|
||||
if system is None:
|
||||
raise CommandError("Unable to find system \"{}\".".format(args[0]))
|
||||
|
||||
await ctx.reply(embed=await utils.generate_system_info_card(ctx.conn, ctx.client, system))
|
||||
|
||||
@command(cmd="system new", usage="[name]", description="Registers a new system to this account.", category="System commands", system_required=False)
|
||||
async def new_system(ctx: CommandContext, args: List[str]):
|
||||
if ctx.system:
|
||||
raise CommandError("You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
|
||||
|
||||
system_name = None
|
||||
if len(args) > 0:
|
||||
system_name = " ".join(args)
|
||||
|
||||
async with ctx.conn.transaction():
|
||||
# TODO: figure out what to do if this errors out on collision on generate_hid
|
||||
hid = utils.generate_hid()
|
||||
|
||||
system = await db.create_system(ctx.conn, system_name=system_name, system_hid=hid)
|
||||
|
||||
# Link account
|
||||
await db.link_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id)
|
||||
return "System registered! To begin adding members, use `pk;member new <name>`."
|
||||
|
||||
@command(cmd="system set", usage="<name|description|tag|avatar> [value]", description="Edits a system property. Leave [value] blank to clear.", category="System commands")
|
||||
async def system_set(ctx: CommandContext, args: List[str]):
|
||||
if len(args) == 0:
|
||||
raise InvalidCommandSyntax()
|
||||
|
||||
allowed_properties = ["name", "description", "tag", "avatar"]
|
||||
db_properties = {
|
||||
"name": "name",
|
||||
"description": "description",
|
||||
"tag": "tag",
|
||||
"avatar": "avatar_url"
|
||||
}
|
||||
|
||||
prop = args[0]
|
||||
if prop not in allowed_properties:
|
||||
raise CommandError("Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)))
|
||||
|
||||
if len(args) >= 2:
|
||||
value = " ".join(args[1:])
|
||||
# Sanity checking
|
||||
if prop == "tag":
|
||||
if len(value) > 32:
|
||||
raise CommandError("Can't have system tag longer than 32 characters.")
|
||||
|
||||
# Make sure there are no members which would make the combined length exceed 32
|
||||
members_exceeding = await db.get_members_exceeding(ctx.conn, system_id=ctx.system.id, length=32 - len(value) - 1)
|
||||
if len(members_exceeding) > 0:
|
||||
# If so, error out and warn
|
||||
member_names = ", ".join([member.name
|
||||
for member in members_exceeding])
|
||||
logger.debug("Members exceeding combined length with tag '{}': {}".format(value, member_names))
|
||||
raise CommandError("The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names))
|
||||
|
||||
if prop == "avatar":
|
||||
user = await utils.parse_mention(ctx.client, value)
|
||||
if user:
|
||||
# Set the avatar to the mentioned user's avatar
|
||||
# Discord doesn't like webp, but also hosts png alternatives
|
||||
value = user.avatar_url.replace(".webp", ".png")
|
||||
else:
|
||||
# Validate URL
|
||||
u = urlparse(value)
|
||||
if u.scheme in ["http", "https"] and u.netloc and u.path:
|
||||
value = value
|
||||
else:
|
||||
raise CommandError("Invalid URL.")
|
||||
else:
|
||||
# Clear from DB
|
||||
value = None
|
||||
|
||||
db_prop = db_properties[prop]
|
||||
await db.update_system_field(ctx.conn, system_id=ctx.system.id, field=db_prop, value=value)
|
||||
|
||||
response = utils.make_default_embed("{} system {}.".format("Updated" if value else "Cleared", prop))
|
||||
if prop == "avatar" and value:
|
||||
response.set_image(url=value)
|
||||
return response
|
||||
|
||||
@command(cmd="system link", usage="<account>", description="Links another account to your system.", category="System commands")
|
||||
async def system_link(ctx: CommandContext, args: List[str]):
|
||||
if len(args) == 0:
|
||||
raise InvalidCommandSyntax()
|
||||
|
||||
# Find account to link
|
||||
linkee = await utils.parse_mention(ctx.client, args[0])
|
||||
if not linkee:
|
||||
raise CommandError("Account not found.")
|
||||
|
||||
# Make sure account doesn't already have a system
|
||||
account_system = await db.get_system_by_account(ctx.conn, linkee.id)
|
||||
if account_system:
|
||||
raise CommandError("Account is already linked to a system (`{}`)".format(account_system.hid))
|
||||
|
||||
# Send confirmation message
|
||||
msg = await ctx.reply("{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention))
|
||||
await ctx.client.add_reaction(msg, "✅")
|
||||
await ctx.client.add_reaction(msg, "❌")
|
||||
|
||||
reaction = await ctx.client.wait_for_reaction(emoji=["✅", "❌"], message=msg, user=linkee, timeout=60.0)
|
||||
# If account to be linked confirms...
|
||||
if not reaction:
|
||||
raise CommandError("Account link timed out.")
|
||||
if not reaction.reaction.emoji == "✅":
|
||||
raise CommandError("Account link cancelled.")
|
||||
|
||||
await db.link_account(ctx.conn, system_id=ctx.system.id, account_id=linkee.id)
|
||||
return "Account linked to system."
|
||||
|
||||
@command(cmd="system unlink", description="Unlinks your system from this account. There must be at least one other account linked.", category="System commands")
|
||||
async def system_unlink(ctx: CommandContext, args: List[str]):
|
||||
# Make sure you can't unlink every account
|
||||
linked_accounts = await db.get_linked_accounts(ctx.conn, system_id=ctx.system.id)
|
||||
if len(linked_accounts) == 1:
|
||||
raise CommandError("This is the only account on your system, so you can't unlink it.")
|
||||
|
||||
await db.unlink_account(ctx.conn, system_id=ctx.system.id, account_id=ctx.message.author.id)
|
||||
return "Account unlinked."
|
||||
|
||||
@command(cmd="system fronter", usage="[system]", description="Gets the current fronter(s) in the system.", category="Switching commands", system_required=False)
|
||||
async def system_fronter(ctx: CommandContext, args: List[str]):
|
||||
if len(args) == 0:
|
||||
if not ctx.system:
|
||||
raise NoSystemRegistered()
|
||||
else:
|
||||
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
|
||||
|
||||
if system is None:
|
||||
raise CommandError("Can't find system \"{}\".".format(args[0]))
|
||||
|
||||
fronters, timestamp = await utils.get_fronters(ctx.conn, system_id=ctx.system.id)
|
||||
fronter_names = [member.name for member in fronters]
|
||||
|
||||
embed = utils.make_default_embed(None)
|
||||
|
||||
if len(fronter_names) == 0:
|
||||
embed.add_field(name="Current fronter", value="(no fronter)")
|
||||
elif len(fronter_names) == 1:
|
||||
embed.add_field(name="Current fronter", value=fronter_names[0])
|
||||
else:
|
||||
embed.add_field(name="Current fronters", value=", ".join(fronter_names))
|
||||
|
||||
if timestamp:
|
||||
embed.add_field(name="Since", value="{} ({})".format(timestamp.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(timestamp)))
|
||||
return embed
|
||||
|
||||
@command(cmd="system fronthistory", usage="[system]", description="Shows the past 10 switches in the system.", category="Switching commands", system_required=False)
|
||||
async def system_fronthistory(ctx: CommandContext, args: List[str]):
|
||||
if len(args) == 0:
|
||||
if not ctx.system:
|
||||
raise NoSystemRegistered()
|
||||
else:
|
||||
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
|
||||
|
||||
if system is None:
|
||||
raise CommandError("Can't find system \"{}\".".format(args[0]))
|
||||
|
||||
lines = []
|
||||
front_history = await utils.get_front_history(ctx.conn, ctx.system.id, count=10)
|
||||
for i, (timestamp, members) in enumerate(front_history):
|
||||
# Special case when no one's fronting
|
||||
if len(members) == 0:
|
||||
name = "(no fronter)"
|
||||
else:
|
||||
name = ", ".join([member.name for member in members])
|
||||
|
||||
# Make proper date string
|
||||
time_text = timestamp.isoformat(sep=" ", timespec="seconds")
|
||||
rel_text = humanize.naturaltime(timestamp)
|
||||
|
||||
delta_text = ""
|
||||
if i > 0:
|
||||
last_switch_time = front_history[i-1][0]
|
||||
delta_text = ", for {}".format(humanize.naturaldelta(timestamp - last_switch_time))
|
||||
lines.append("**{}** ({}, {}{})".format(name, time_text, rel_text, delta_text))
|
||||
|
||||
embed = utils.make_default_embed("\n".join(lines) or "(none)")
|
||||
embed.title = "Past switches"
|
||||
return embed
|
||||
|
||||
|
||||
@command(cmd="system delete", description="Deletes your system from the database ***permanently***.", category="System commands")
|
||||
async def system_delete(ctx: CommandContext, args: List[str]):
|
||||
await ctx.reply("Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(ctx.system.hid))
|
||||
|
||||
msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0)
|
||||
if msg and msg.content == ctx.system.hid:
|
||||
await db.remove_system(ctx.conn, system_id=ctx.system.id)
|
||||
return "System deleted."
|
||||
else:
|
||||
return "System deletion cancelled."
|
@ -10,7 +10,9 @@ help_pages = {
|
||||
`pk;help proxy` - Details on message proxying.
|
||||
`pk;help switch` - Details on switch logging.
|
||||
`pk;help mod` - Details on moderator operations.
|
||||
`pk;help import` - Details on data import from other services.""")
|
||||
`pk;help import` - Details on data import from other services."""),
|
||||
("Discord",
|
||||
"""For feedback, bug reports, suggestions, or just chatting, join our Discord: https://discord.gg/PczBt78""")
|
||||
],
|
||||
"system": [
|
||||
("Registering a new system",
|
297
src/pluralkit/bot/proxy.py
Normal file
297
src/pluralkit/bot/proxy.py
Normal file
@ -0,0 +1,297 @@
|
||||
import ciso8601
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
|
||||
from pluralkit import db
|
||||
from pluralkit.bot import channel_logger, utils
|
||||
|
||||
logger = logging.getLogger("pluralkit.bot.proxy")
|
||||
|
||||
def extract_leading_mentions(message_text):
|
||||
# This regex matches one or more mentions at the start of a message, separated by any amount of spaces
|
||||
match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message_text)
|
||||
if not match:
|
||||
return message_text, ""
|
||||
|
||||
# Return the text after the mentions, and the mentions themselves
|
||||
return message_text[match.span(0)[1]:].strip(), match.group(0)
|
||||
|
||||
|
||||
def match_member_proxy_tags(member: db.ProxyMember, message_text: str):
|
||||
# Skip members with no defined proxy tags
|
||||
if not member.prefix and not member.suffix:
|
||||
return None
|
||||
|
||||
# DB defines empty prefix/suffixes as None, replace with empty strings to prevent errors
|
||||
prefix = member.prefix or ""
|
||||
suffix = member.suffix or ""
|
||||
|
||||
# Ignore mentions at the very start of the message, and match proxy tags after those
|
||||
message_text, leading_mentions = extract_leading_mentions(message_text)
|
||||
|
||||
logger.debug("Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions, prefix, suffix))
|
||||
|
||||
if message_text.startswith(member.prefix or "") and message_text.endswith(member.suffix or ""):
|
||||
prefix_length = len(prefix)
|
||||
suffix_length = len(suffix)
|
||||
|
||||
# If suffix_length is 0, the last bit of the slice will be "-0", and the slice will fail
|
||||
if suffix_length > 0:
|
||||
inner_string = message_text[prefix_length:-suffix_length]
|
||||
else:
|
||||
inner_string = message_text[prefix_length:]
|
||||
|
||||
# Add the mentions we stripped back
|
||||
inner_string = leading_mentions + inner_string
|
||||
return inner_string
|
||||
|
||||
|
||||
def match_proxy_tags(members: List[db.ProxyMember], message_text: str):
|
||||
# Sort by specificity (members with both prefix and suffix go higher)
|
||||
# This will make sure more "precise" proxy tags get tried first
|
||||
members: List[db.ProxyMember] = sorted(members, key=lambda x: int(
|
||||
bool(x.prefix)) + int(bool(x.suffix)), reverse=True)
|
||||
|
||||
for member in members:
|
||||
match = match_member_proxy_tags(member, message_text)
|
||||
if match is not None: # Using "is not None" because an empty string is OK here too
|
||||
logger.debug("Matched member {} with inner text '{}'".format(member.hid, match))
|
||||
return member, match
|
||||
|
||||
|
||||
def get_message_attachment_url(message: discord.Message):
|
||||
if not message.attachments:
|
||||
return None
|
||||
|
||||
attachment = message.attachments[0]
|
||||
if "proxy_url" in attachment:
|
||||
return attachment["proxy_url"]
|
||||
|
||||
if "url" in attachment:
|
||||
return attachment["url"]
|
||||
|
||||
|
||||
# TODO: possibly move this to bot __init__ so commands can access it too
|
||||
class WebhookPermissionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DeletionPermissionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Proxy:
|
||||
def __init__(self, client: discord.Client, token: str, logger: channel_logger.ChannelLogger):
|
||||
self.logger = logging.getLogger("pluralkit.bot.proxy")
|
||||
self.session = aiohttp.ClientSession()
|
||||
self.client = client
|
||||
self.token = token
|
||||
self.channel_logger = logger
|
||||
|
||||
async def save_channel_webhook(self, conn, channel: discord.Channel, id: str, token: str) -> (str, str):
|
||||
await db.add_webhook(conn, channel.id, id, token)
|
||||
return id, token
|
||||
|
||||
async def create_and_add_channel_webhook(self, conn, channel: discord.Channel) -> (str, str):
|
||||
# This method is only called if there's no webhook found in the DB (and hopefully within a transaction)
|
||||
# No need to worry about error handling if there's a DB conflict (which will throw an exception because DB constraints)
|
||||
req_headers = {"Authorization": "Bot {}".format(self.token)}
|
||||
|
||||
# First, check if there's already a webhook belonging to the bot
|
||||
async with self.session.get("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
|
||||
headers=req_headers) as resp:
|
||||
if resp.status == 200:
|
||||
webhooks = await resp.json()
|
||||
for webhook in webhooks:
|
||||
if webhook["user"]["id"] == self.client.user.id:
|
||||
# This webhook belongs to us, we can use that, return it and save it
|
||||
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
|
||||
elif resp.status == 403:
|
||||
self.logger.warning(
|
||||
"Did not have permission to fetch webhook list (server={}, channel={})".format(channel.server.id,
|
||||
channel.id))
|
||||
raise WebhookPermissionError()
|
||||
else:
|
||||
raise discord.HTTPException(resp, await resp.text())
|
||||
|
||||
# Then, try submitting a new one
|
||||
req_data = {"name": "PluralKit Proxy Webhook"}
|
||||
async with self.session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
|
||||
json=req_data, headers=req_headers) as resp:
|
||||
if resp.status == 200:
|
||||
webhook = await resp.json()
|
||||
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
|
||||
elif resp.status == 403:
|
||||
self.logger.warning(
|
||||
"Did not have permission to create webhook (server={}, channel={})".format(channel.server.id,
|
||||
channel.id))
|
||||
raise WebhookPermissionError()
|
||||
else:
|
||||
raise discord.HTTPException(resp, await resp.text())
|
||||
|
||||
# Should not be reached without an exception being thrown
|
||||
|
||||
async def get_webhook_for_channel(self, conn, channel: discord.Channel):
|
||||
async with conn.transaction():
|
||||
hook_match = await db.get_webhook(conn, channel.id)
|
||||
if not hook_match:
|
||||
# We don't have a webhook, create/add one
|
||||
return await self.create_and_add_channel_webhook(conn, channel)
|
||||
else:
|
||||
return hook_match
|
||||
|
||||
async def do_proxy_message(self, conn, member: db.ProxyMember, original_message: discord.Message, text: str,
|
||||
attachment_url: str, has_already_retried=False):
|
||||
hook_id, hook_token = await self.get_webhook_for_channel(conn, original_message.channel)
|
||||
|
||||
form_data = aiohttp.FormData()
|
||||
form_data.add_field("username", "{} {}".format(member.name, member.tag or "").strip())
|
||||
|
||||
if text:
|
||||
form_data.add_field("content", text)
|
||||
|
||||
if attachment_url:
|
||||
attachment_resp = await self.session.get(attachment_url)
|
||||
form_data.add_field("file", attachment_resp.content, content_type=attachment_resp.content_type,
|
||||
filename=attachment_resp.url.name)
|
||||
|
||||
if member.avatar_url:
|
||||
form_data.add_field("avatar_url", member.avatar_url)
|
||||
|
||||
async with self.session.post(
|
||||
"https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token),
|
||||
data=form_data) as resp:
|
||||
if resp.status == 200:
|
||||
message = await resp.json()
|
||||
|
||||
await db.add_message(conn, message["id"], message["channel_id"], member.id, original_message.author.id,
|
||||
text or "")
|
||||
|
||||
try:
|
||||
await self.client.delete_message(original_message)
|
||||
except discord.Forbidden:
|
||||
self.logger.warning(
|
||||
"Did not have permission to delete original message (server={}, channel={})".format(
|
||||
original_message.server.id, original_message.channel.id))
|
||||
raise DeletionPermissionError()
|
||||
except discord.NotFound:
|
||||
self.logger.warning("Tried to delete message when proxying, but message was already gone (server={}, channel={})".format(original_message.server.id, original_message.channel.id))
|
||||
|
||||
message_image = None
|
||||
if message["attachments"]:
|
||||
first_attachment = message["attachments"][0]
|
||||
if "width" in first_attachment and "height" in first_attachment:
|
||||
# Only log attachments that are actually images
|
||||
message_image = first_attachment["url"]
|
||||
|
||||
await self.channel_logger.log_message_proxied(conn,
|
||||
server_id=original_message.server.id,
|
||||
channel_name=original_message.channel.name,
|
||||
channel_id=original_message.channel.id,
|
||||
sender_name=original_message.author.name,
|
||||
sender_disc=original_message.author.discriminator,
|
||||
member_name=member.name,
|
||||
member_hid=member.hid,
|
||||
member_avatar_url=member.avatar_url,
|
||||
system_name=member.system_name,
|
||||
system_hid=member.system_hid,
|
||||
message_text=text,
|
||||
message_image=message_image,
|
||||
message_timestamp=ciso8601.parse_datetime(
|
||||
message["timestamp"]),
|
||||
message_id=message["id"])
|
||||
elif resp.status == 404 and not has_already_retried:
|
||||
# Webhook doesn't exist. Delete it from the DB, create, and add a new one
|
||||
self.logger.warning("Webhook registered in DB doesn't exist, deleting hook from DB, re-adding, and trying again (channel={}, hook={})".format(original_message.channel.id, hook_id))
|
||||
await db.delete_webhook(conn, original_message.channel.id)
|
||||
await self.create_and_add_channel_webhook(conn, original_message.channel)
|
||||
|
||||
# Then try again all over, making sure to not retry again and go in a loop should it continually fail
|
||||
return await self.do_proxy_message(conn, member, original_message, text, attachment_url, has_already_retried=True)
|
||||
else:
|
||||
raise discord.HTTPException(resp, await resp.text())
|
||||
|
||||
async def try_proxy_message(self, conn, message: discord.Message):
|
||||
# Can't proxy in DMs, webhook creation will explode
|
||||
if message.channel.is_private:
|
||||
return False
|
||||
|
||||
# Big fat query to find every member associated with this account
|
||||
# Returned member object has a few more keys (system tag, for example)
|
||||
members = await db.get_members_by_account(conn, account_id=message.author.id)
|
||||
|
||||
match = match_proxy_tags(members, message.content)
|
||||
if not match:
|
||||
return False
|
||||
|
||||
member, text = match
|
||||
attachment_url = get_message_attachment_url(message)
|
||||
|
||||
# Can't proxy a message with no text AND no attachment
|
||||
if not text and not attachment_url:
|
||||
self.logger.debug("Skipping message because of no text and no attachment")
|
||||
return False
|
||||
|
||||
try:
|
||||
async with conn.transaction():
|
||||
await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url)
|
||||
except WebhookPermissionError:
|
||||
embed = utils.make_error_embed("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.")
|
||||
await self.client.send_message(message.channel, embed=embed)
|
||||
except DeletionPermissionError:
|
||||
embed = utils.make_error_embed("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.")
|
||||
await self.client.send_message(message.channel, embed=embed)
|
||||
|
||||
return True
|
||||
|
||||
async def try_delete_message(self, conn, message_id: str, check_user_id: Optional[str], delete_message: bool, deleted_by_moderator: bool):
|
||||
async with conn.transaction():
|
||||
# Find the message in the DB, and make sure it's sent by the user (if we need to check)
|
||||
if check_user_id:
|
||||
db_message = await db.get_message_by_sender_and_id(conn, message_id=message_id, sender_id=check_user_id)
|
||||
else:
|
||||
db_message = await db.get_message(conn, message_id=message_id)
|
||||
|
||||
if db_message:
|
||||
self.logger.debug("Deleting message {}".format(message_id))
|
||||
channel = self.client.get_channel(str(db_message.channel))
|
||||
|
||||
# If we should also delete the actual message, do that
|
||||
if delete_message:
|
||||
message = await self.client.get_message(channel, message_id)
|
||||
|
||||
try:
|
||||
await self.client.delete_message(message)
|
||||
except discord.Forbidden:
|
||||
self.logger.warning(
|
||||
"Did not have permission to remove message, aborting deletion (server={}, channel={})".format(
|
||||
channel.server.id, channel.id))
|
||||
return
|
||||
|
||||
# Remove it from the DB
|
||||
await db.delete_message(conn, message_id)
|
||||
|
||||
# Then log deletion to logging channel
|
||||
await self.channel_logger.log_message_deleted(conn,
|
||||
server_id=channel.server.id,
|
||||
channel_name=channel.name,
|
||||
member_name=db_message.name,
|
||||
member_hid=db_message.hid,
|
||||
member_avatar_url=db_message.avatar_url,
|
||||
system_name=db_message.system_name,
|
||||
system_hid=db_message.system_hid,
|
||||
message_text=db_message.content,
|
||||
message_id=message_id,
|
||||
deleted_by_moderator=deleted_by_moderator)
|
||||
|
||||
async def handle_reaction(self, conn, user_id: str, message_id: str, emoji: str):
|
||||
if emoji == "❌":
|
||||
await self.try_delete_message(conn, message_id, check_user_id=user_id, delete_message=True, deleted_by_moderator=False)
|
||||
|
||||
async def handle_deletion(self, conn, message_id: str):
|
||||
# Don't delete the message, it's already gone at this point, just handle DB deletion and logging
|
||||
await self.try_delete_message(conn, message_id, check_user_id=None, delete_message=False, deleted_by_moderator=True)
|
219
src/pluralkit/bot/utils.py
Normal file
219
src/pluralkit/bot/utils.py
Normal file
@ -0,0 +1,219 @@
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
import string
|
||||
|
||||
import asyncio
|
||||
import asyncpg
|
||||
import discord
|
||||
import humanize
|
||||
|
||||
from pluralkit import System, Member, db
|
||||
|
||||
logger = logging.getLogger("pluralkit.utils")
|
||||
|
||||
def escape(s):
|
||||
return s.replace("`", "\\`")
|
||||
|
||||
def generate_hid() -> str:
|
||||
return "".join(random.choices(string.ascii_lowercase, k=5))
|
||||
|
||||
def bounds_check_member_name(new_name, system_tag):
|
||||
if len(new_name) > 32:
|
||||
return "Name cannot be longer than 32 characters."
|
||||
|
||||
if system_tag:
|
||||
if len("{} {}".format(new_name, system_tag)) > 32:
|
||||
return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(system_tag)
|
||||
|
||||
async def parse_mention(client: discord.Client, mention: str) -> discord.User:
|
||||
# First try matching mention format
|
||||
match = re.fullmatch("<@!?(\\d+)>", mention)
|
||||
if match:
|
||||
try:
|
||||
return await client.get_user_info(match.group(1))
|
||||
except discord.NotFound:
|
||||
return None
|
||||
|
||||
# Then try with just ID
|
||||
try:
|
||||
return await client.get_user_info(str(int(mention)))
|
||||
except (ValueError, discord.NotFound):
|
||||
return None
|
||||
|
||||
def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel:
|
||||
match = re.fullmatch("<#(\\d+)>", mention)
|
||||
if match:
|
||||
return server.get_channel(match.group(1))
|
||||
|
||||
try:
|
||||
return server.get_channel(str(int(mention)))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async def get_fronter_ids(conn, system_id) -> (List[int], datetime):
|
||||
switches = await db.front_history(conn, system_id=system_id, count=1)
|
||||
if not switches:
|
||||
return [], None
|
||||
|
||||
if not switches[0]["members"]:
|
||||
return [], switches[0]["timestamp"]
|
||||
|
||||
return switches[0]["members"], switches[0]["timestamp"]
|
||||
|
||||
async def get_fronters(conn, system_id) -> (List[Member], datetime):
|
||||
member_ids, timestamp = await get_fronter_ids(conn, system_id)
|
||||
|
||||
# Collect in dict and then look up as list, to preserve return order
|
||||
members = {member.id: member for member in await db.get_members(conn, member_ids)}
|
||||
return [members[member_id] for member_id in member_ids], timestamp
|
||||
|
||||
async def get_front_history(conn, system_id, count) -> List[Tuple[datetime, List[Member]]]:
|
||||
# Get history from DB
|
||||
switches = await db.front_history(conn, system_id=system_id, count=count)
|
||||
if not switches:
|
||||
return []
|
||||
|
||||
# Get all unique IDs referenced
|
||||
all_member_ids = {id for switch in switches for id in switch["members"]}
|
||||
|
||||
# And look them up in the database into a dict
|
||||
all_members = {member.id: member for member in await db.get_members(conn, list(all_member_ids))}
|
||||
|
||||
# Collect in array and return
|
||||
out = []
|
||||
for switch in switches:
|
||||
timestamp = switch["timestamp"]
|
||||
members = [all_members[id] for id in switch["members"]]
|
||||
out.append((timestamp, members))
|
||||
return out
|
||||
|
||||
async def get_system_fuzzy(conn, client: discord.Client, key) -> System:
|
||||
if isinstance(key, discord.User):
|
||||
return await db.get_system_by_account(conn, account_id=key.id)
|
||||
|
||||
if isinstance(key, str) and len(key) == 5:
|
||||
return await db.get_system_by_hid(conn, system_hid=key)
|
||||
|
||||
account = await parse_mention(client, key)
|
||||
if account:
|
||||
system = await db.get_system_by_account(conn, account_id=account.id)
|
||||
if system:
|
||||
return system
|
||||
return None
|
||||
|
||||
|
||||
async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> Member:
|
||||
# First search by hid
|
||||
if system_only:
|
||||
member = await db.get_member_by_hid_in_system(conn, system_id=system_id, member_hid=key)
|
||||
else:
|
||||
member = await db.get_member_by_hid(conn, member_hid=key)
|
||||
if member is not None:
|
||||
return member
|
||||
|
||||
# Then search by name, if we have a system
|
||||
if system_id:
|
||||
member = await db.get_member_by_name(conn, system_id=system_id, member_name=key)
|
||||
if member is not None:
|
||||
return member
|
||||
|
||||
def make_default_embed(message):
|
||||
embed = discord.Embed()
|
||||
embed.colour = discord.Colour.blue()
|
||||
embed.description = message
|
||||
return embed
|
||||
|
||||
def make_error_embed(message):
|
||||
embed = discord.Embed()
|
||||
embed.colour = discord.Colour.dark_red()
|
||||
embed.description = message
|
||||
return embed
|
||||
|
||||
|
||||
async def generate_system_info_card(conn, client: discord.Client, system: System) -> discord.Embed:
|
||||
card = discord.Embed()
|
||||
card.colour = discord.Colour.blue()
|
||||
|
||||
if system.name:
|
||||
card.title = system.name
|
||||
|
||||
if system.avatar_url:
|
||||
card.set_thumbnail(url=system.avatar_url)
|
||||
|
||||
if system.tag:
|
||||
card.add_field(name="Tag", value=system.tag)
|
||||
|
||||
fronters, switch_time = await get_fronters(conn, system.id)
|
||||
if fronters:
|
||||
names = ", ".join([member.name for member in fronters])
|
||||
fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switch_time))
|
||||
card.add_field(name="Current fronter" if len(fronters) == 1 else "Current fronters", value=fronter_val)
|
||||
|
||||
account_names = []
|
||||
for account_id in await db.get_linked_accounts(conn, system_id=system.id):
|
||||
account = await client.get_user_info(account_id)
|
||||
account_names.append("{}#{}".format(account.name, account.discriminator))
|
||||
card.add_field(name="Linked accounts", value="\n".join(account_names))
|
||||
|
||||
if system.description:
|
||||
card.add_field(name="Description",
|
||||
value=system.description, inline=False)
|
||||
|
||||
# Get names of all members
|
||||
member_texts = []
|
||||
for member in await db.get_all_members(conn, system_id=system.id):
|
||||
member_texts.append("{} (`{}`)".format(escape(member.name), member.hid))
|
||||
|
||||
if len(member_texts) > 0:
|
||||
card.add_field(name="Members", value="\n".join(
|
||||
member_texts), inline=False)
|
||||
|
||||
card.set_footer(text="System ID: {}".format(system.hid))
|
||||
return card
|
||||
|
||||
|
||||
async def generate_member_info_card(conn, member: Member) -> discord.Embed:
|
||||
system = await db.get_system(conn, system_id=member.system)
|
||||
|
||||
card = discord.Embed()
|
||||
card.colour = discord.Colour.blue()
|
||||
|
||||
name_and_system = member.name
|
||||
if system.name:
|
||||
name_and_system += " ({})".format(system.name)
|
||||
|
||||
card.set_author(name=name_and_system, icon_url=member.avatar_url or discord.Embed.Empty)
|
||||
if member.avatar_url:
|
||||
card.set_thumbnail(url=member.avatar_url)
|
||||
|
||||
# Get system name and hid
|
||||
system = await db.get_system(conn, system_id=member.system)
|
||||
|
||||
if member.color:
|
||||
card.colour = int(member.color, 16)
|
||||
|
||||
if member.birthday:
|
||||
bday_val = member.birthday.strftime("%b %d, %Y")
|
||||
if member.birthday.year == 1:
|
||||
bday_val = member.birthday.strftime("%b %d")
|
||||
card.add_field(name="Birthdate", value=bday_val)
|
||||
|
||||
if member.pronouns:
|
||||
card.add_field(name="Pronouns", value=member.pronouns)
|
||||
|
||||
if member.prefix or member.suffix:
|
||||
prefix = member.prefix or ""
|
||||
suffix = member.suffix or ""
|
||||
card.add_field(name="Proxy Tags",
|
||||
value="{}text{}".format(prefix, suffix))
|
||||
|
||||
if member.description:
|
||||
card.add_field(name="Description",
|
||||
value=member.description, inline=False)
|
||||
|
||||
card.set_footer(text="System ID: {} | Member ID: {}".format(
|
||||
system.hid, member.hid))
|
||||
return card
|
@ -1,11 +1,15 @@
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import List
|
||||
import time
|
||||
|
||||
import asyncpg
|
||||
import asyncpg.exceptions
|
||||
|
||||
from pluralkit import stats
|
||||
from pluralkit.bot import logger
|
||||
from pluralkit import System, Member, stats
|
||||
|
||||
logger = logging.getLogger("pluralkit.db")
|
||||
async def connect():
|
||||
while True:
|
||||
try:
|
||||
@ -13,7 +17,6 @@ async def connect():
|
||||
except (ConnectionError, asyncpg.exceptions.CannotConnectNowError):
|
||||
pass
|
||||
|
||||
|
||||
def db_wrap(func):
|
||||
async def inner(*args, **kwargs):
|
||||
before = time.perf_counter()
|
||||
@ -31,10 +34,11 @@ def db_wrap(func):
|
||||
return inner
|
||||
|
||||
@db_wrap
|
||||
async def create_system(conn, system_name: str, system_hid: str):
|
||||
async def create_system(conn, system_name: str, system_hid: str) -> System:
|
||||
logger.debug("Creating system (name={}, hid={})".format(
|
||||
system_name, system_hid))
|
||||
return await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid)
|
||||
row = await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid)
|
||||
return System(**row) if row else None
|
||||
|
||||
|
||||
@db_wrap
|
||||
@ -44,10 +48,11 @@ async def remove_system(conn, system_id: int):
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def create_member(conn, system_id: int, member_name: str, member_hid: str):
|
||||
async def create_member(conn, system_id: int, member_name: str, member_hid: str) -> Member:
|
||||
logger.debug("Creating member (system={}, name={}, hid={})".format(
|
||||
system_id, member_name, member_hid))
|
||||
return await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid)
|
||||
row = await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid)
|
||||
return Member(**row) if row else None
|
||||
|
||||
|
||||
@db_wrap
|
||||
@ -71,52 +76,54 @@ async def unlink_account(conn, system_id: int, account_id: str):
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_linked_accounts(conn, system_id: int):
|
||||
async def get_linked_accounts(conn, system_id: int) -> List[int]:
|
||||
return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)]
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_system_by_account(conn, account_id: str):
|
||||
return await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id))
|
||||
async def get_system_by_account(conn, account_id: str) -> System:
|
||||
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id))
|
||||
return System(**row) if row else None
|
||||
|
||||
@db_wrap
|
||||
async def get_system_by_hid(conn, system_hid: str) -> System:
|
||||
row = await conn.fetchrow("select * from systems where hid = $1", system_hid)
|
||||
return System(**row) if row else None
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_system_by_hid(conn, system_hid: str):
|
||||
return await conn.fetchrow("select * from systems where hid = $1", system_hid)
|
||||
async def get_system(conn, system_id: int) -> System:
|
||||
row = await conn.fetchrow("select * from systems where id = $1", system_id)
|
||||
return System(**row) if row else None
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_system(conn, system_id: int):
|
||||
return await conn.fetchrow("select * from systems where id = $1", system_id)
|
||||
async def get_member_by_name(conn, system_id: int, member_name: str) -> Member:
|
||||
row = await conn.fetchrow("select * from members where system = $1 and lower(name) = lower($2)", system_id, member_name)
|
||||
return Member(**row) if row else None
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_member_by_name(conn, system_id: int, member_name: str):
|
||||
return await conn.fetchrow("select * from members where system = $1 and lower(name) = lower($2)", system_id, member_name)
|
||||
async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str) -> Member:
|
||||
row = await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid)
|
||||
return Member(**row) if row else None
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str):
|
||||
return await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid)
|
||||
async def get_member_by_hid(conn, member_hid: str) -> Member:
|
||||
row = await conn.fetchrow("select * from members where hid = $1", member_hid)
|
||||
return Member(**row) if row else None
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_member_by_hid(conn, member_hid: str):
|
||||
return await conn.fetchrow("select * from members where hid = $1", member_hid)
|
||||
|
||||
async def get_member(conn, member_id: int) -> Member:
|
||||
row = await conn.fetchrow("select * from members where id = $1", member_id)
|
||||
return Member(**row) if row else None
|
||||
|
||||
@db_wrap
|
||||
async def get_member(conn, member_id: int):
|
||||
return await conn.fetchrow("select * from members where id = $1", member_id)
|
||||
|
||||
@db_wrap
|
||||
async def get_members(conn, members: list):
|
||||
return await conn.fetch("select * from members where id = any($1)", members)
|
||||
|
||||
@db_wrap
|
||||
async def get_message(conn, message_id: str):
|
||||
return await conn.fetchrow("select * from messages where mid = $1", message_id)
|
||||
|
||||
async def get_members(conn, members: list) -> List[Member]:
|
||||
rows = await conn.fetch("select * from members where id = any($1)", members)
|
||||
return [Member(**row) for row in rows]
|
||||
|
||||
@db_wrap
|
||||
async def update_system_field(conn, system_id: int, field: str, value):
|
||||
@ -133,18 +140,20 @@ async def update_member_field(conn, member_id: int, field: str, value):
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_all_members(conn, system_id: int):
|
||||
return await conn.fetch("select * from members where system = $1", system_id)
|
||||
async def get_all_members(conn, system_id: int) -> List[Member]:
|
||||
rows = await conn.fetch("select * from members where system = $1", system_id)
|
||||
return [Member(**row) for row in rows]
|
||||
|
||||
@db_wrap
|
||||
async def get_members_exceeding(conn, system_id: int, length: int) -> List[Member]:
|
||||
rows = await conn.fetch("select * from members where system = $1 and length(name) > $2", system_id, length)
|
||||
return [Member(**row) for row in rows]
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_members_exceeding(conn, system_id: int, length: int):
|
||||
return await conn.fetch("select * from members where system = $1 and length(name) > $2", system_id, length)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_webhook(conn, channel_id: str):
|
||||
return await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
|
||||
async def get_webhook(conn, channel_id: str) -> (str, str):
|
||||
row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
|
||||
return (str(row["webhook"]), row["token"]) if row else None
|
||||
|
||||
|
||||
@db_wrap
|
||||
@ -153,6 +162,9 @@ async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str
|
||||
channel_id, webhook_id, webhook_token))
|
||||
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token)
|
||||
|
||||
@db_wrap
|
||||
async def delete_webhook(conn, channel_id: str):
|
||||
await conn.execute("delete from webhooks where channel = $1", int(channel_id))
|
||||
|
||||
@db_wrap
|
||||
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str, content: str):
|
||||
@ -160,11 +172,22 @@ async def add_message(conn, message_id: str, channel_id: str, member_id: int, se
|
||||
message_id, channel_id, member_id, sender_id))
|
||||
await conn.execute("insert into messages (mid, channel, member, sender, content) values ($1, $2, $3, $4, $5)", int(message_id), int(channel_id), member_id, int(sender_id), content)
|
||||
|
||||
class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "color", "name", "avatar_url", "tag", "system_name", "system_hid"])):
|
||||
id: int
|
||||
hid: str
|
||||
prefix: str
|
||||
suffix: str
|
||||
color: str
|
||||
name: str
|
||||
avatar_url: str
|
||||
tag: str
|
||||
system_name: str
|
||||
system_hid: str
|
||||
|
||||
@db_wrap
|
||||
async def get_members_by_account(conn, account_id: str):
|
||||
async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]:
|
||||
# Returns a "chimera" object
|
||||
return await conn.fetch("""select
|
||||
rows = await conn.fetch("""select
|
||||
members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url,
|
||||
systems.tag, systems.name as system_name, systems.hid as system_hid
|
||||
from
|
||||
@ -173,11 +196,23 @@ async def get_members_by_account(conn, account_id: str):
|
||||
accounts.uid = $1
|
||||
and systems.id = accounts.system
|
||||
and members.system = systems.id""", int(account_id))
|
||||
return [ProxyMember(**row) for row in rows]
|
||||
|
||||
class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "content", "sender", "name", "hid", "avatar_url", "system_name", "system_hid"])):
|
||||
mid: int
|
||||
channel: int
|
||||
member: int
|
||||
content: str
|
||||
sender: int
|
||||
name: str
|
||||
hid: str
|
||||
avatar_url: str
|
||||
system_name: str
|
||||
system_hid: str
|
||||
|
||||
@db_wrap
|
||||
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str):
|
||||
return await conn.fetchrow("""select
|
||||
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) -> MessageInfo:
|
||||
row = await conn.fetchrow("""select
|
||||
messages.*,
|
||||
members.name, members.hid, members.avatar_url,
|
||||
systems.name as system_name, systems.hid as system_hid
|
||||
@ -186,7 +221,24 @@ async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str):
|
||||
where
|
||||
messages.member = members.id
|
||||
and members.system = systems.id
|
||||
and mid = $1 and sender = $2""", int(message_id), int(sender_id))
|
||||
and mid = $1
|
||||
and sender = $2""", int(message_id), int(sender_id))
|
||||
return MessageInfo(**row) if row else None
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_message(conn, message_id: str) -> MessageInfo:
|
||||
row = await conn.fetchrow("""select
|
||||
messages.*,
|
||||
members.name, members.hid, members.avatar_url,
|
||||
systems.name as system_name, systems.hid as system_hid
|
||||
from
|
||||
messages, members, systems
|
||||
where
|
||||
messages.member = members.id
|
||||
and members.system = systems.id
|
||||
and mid = $1""", int(message_id))
|
||||
return MessageInfo(**row) if row else None
|
||||
|
||||
|
||||
@db_wrap
|
||||
@ -215,7 +267,7 @@ async def add_switch(conn, system_id: int):
|
||||
return res["id"]
|
||||
|
||||
@db_wrap
|
||||
async def move_last_switch(conn, system_id: int, switch_id: int, new_time):
|
||||
async def move_last_switch(conn, system_id: int, switch_id: int, new_time: datetime):
|
||||
logger.debug("Moving latest switch (system={}, id={}, new_time={})".format(system_id, switch_id, new_time))
|
||||
await conn.execute("update switches set timestamp = $1 where system = $2 and id = $3", new_time, system_id, switch_id)
|
||||
|
||||
@ -235,19 +287,19 @@ async def update_server(conn, server_id: str, logging_channel_id: str):
|
||||
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id)
|
||||
|
||||
@db_wrap
|
||||
async def member_count(conn):
|
||||
async def member_count(conn) -> int:
|
||||
return await conn.fetchval("select count(*) from members")
|
||||
|
||||
@db_wrap
|
||||
async def system_count(conn):
|
||||
async def system_count(conn) -> int:
|
||||
return await conn.fetchval("select count(*) from systems")
|
||||
|
||||
@db_wrap
|
||||
async def message_count(conn):
|
||||
async def message_count(conn) -> int:
|
||||
return await conn.fetchval("select count(*) from messages")
|
||||
|
||||
@db_wrap
|
||||
async def account_count(conn):
|
||||
async def account_count(conn) -> int:
|
||||
return await conn.fetchval("select count(*) from accounts")
|
||||
|
||||
async def create_tables(conn):
|
||||
@ -283,7 +335,7 @@ async def create_tables(conn):
|
||||
channel bigint not null,
|
||||
member serial not null references members(id) on delete cascade,
|
||||
content text not null,
|
||||
sender bigint not null references accounts(uid)
|
||||
sender bigint not null
|
||||
)""")
|
||||
await conn.execute("""create table if not exists switches (
|
||||
id serial primary key,
|
@ -1,7 +1,5 @@
|
||||
from aioinflux import InfluxDBClient
|
||||
|
||||
from pluralkit.bot import logger
|
||||
|
||||
client = None
|
||||
async def connect():
|
||||
global client
|
7
src/requirements.txt
Normal file
7
src/requirements.txt
Normal file
@ -0,0 +1,7 @@
|
||||
aiohttp
|
||||
aioinflux
|
||||
asyncpg
|
||||
dateparser
|
||||
discord.py
|
||||
humanize
|
||||
uvloop
|
Loading…
Reference in New Issue
Block a user