Format code

This commit is contained in:
Ske 2018-07-12 00:49:02 +02:00
parent b81b768b04
commit 902c14ef65
6 changed files with 157 additions and 65 deletions

View File

@ -1 +1 @@
from . import commands, db, proxy from . import commands, db, proxy

View File

@ -12,23 +12,28 @@ logger.setLevel(logging.DEBUG)
client = discord.Client() client = discord.Client()
@client.event @client.event
async def on_error(evt, *args, **kwargs): async def on_error(evt, *args, **kwargs):
logger.exception("Error while handling event {} with arguments {}:".format(evt, args)) logger.exception(
"Error while handling event {} with arguments {}:".format(evt, args))
@client.event @client.event
async def on_ready(): async def on_ready():
# Print status info # Print status info
logger.info("Connected to Discord.") logger.info("Connected to Discord.")
logger.info("Account: {}#{}".format(client.user.name, client.user.discriminator)) logger.info("Account: {}#{}".format(
client.user.name, client.user.discriminator))
logger.info("User ID: {}".format(client.user.id)) logger.info("User ID: {}".format(client.user.id))
@client.event @client.event
async def on_message(message): async def on_message(message):
# Ignore bot messages # Ignore bot messages
if message.author.bot: if message.author.bot:
return return
# Split into args. shlex sucks so we don't bother with quotes # Split into args. shlex sucks so we don't bother with quotes
args = message.content.split(" ") args = message.content.split(" ")
@ -54,6 +59,7 @@ async def on_message(message):
async with client.pool.acquire() as conn: async with client.pool.acquire() as conn:
await proxy.handle_proxying(conn, message) await proxy.handle_proxying(conn, message)
@client.event @client.event
async def on_reaction_add(reaction, user): async def on_reaction_add(reaction, user):
from pluralkit import proxy from pluralkit import proxy
@ -62,6 +68,7 @@ async def on_reaction_add(reaction, user):
async with client.pool.acquire() as conn: async with client.pool.acquire() as conn:
await proxy.handle_reaction(conn, reaction, user) await proxy.handle_reaction(conn, reaction, user)
async def run(): async def run():
from pluralkit import db from pluralkit import db
try: try:
@ -79,4 +86,4 @@ async def run():
await client.start(os.environ["TOKEN"]) await client.start(os.environ["TOKEN"])
finally: finally:
logger.info("Logging out from Discord...") logger.info("Logging out from Discord...")
await client.logout() await client.logout()

View File

@ -8,6 +8,7 @@ from pluralkit import db
from pluralkit.bot import client, logger from pluralkit.bot import client, logger
from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map
@command(cmd="pk;system", subcommand=None, description="Shows information about your system.") @command(cmd="pk;system", subcommand=None, description="Shows information about your system.")
async def this_system_info(conn, message, args): async def this_system_info(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id) system = await db.get_system_by_account(conn, message.author.id)
@ -18,6 +19,7 @@ async def this_system_info(conn, message, args):
await client.send_message(message.channel, embed=await generate_system_info_card(conn, system)) await client.send_message(message.channel, embed=await generate_system_info_card(conn, system))
return True return True
@command(cmd="pk;system", subcommand="new", usage="[name]", description="Registers a new system to this account.") @command(cmd="pk;system", subcommand="new", usage="[name]", description="Registers a new system to this account.")
async def new_system(conn, message, args): async def new_system(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id) system = await db.get_system_by_account(conn, message.author.id)
@ -39,12 +41,13 @@ async def new_system(conn, message, args):
await db.link_account(conn, system_id=system["id"], account_id=message.author.id) await db.link_account(conn, system_id=system["id"], account_id=message.author.id)
return True, "System registered! To begin adding members, use `pk;member new <name>`." return True, "System registered! To begin adding members, use `pk;member new <name>`."
@command(cmd="pk;system", subcommand="info", usage="[system]", description="Shows information about a system.") @command(cmd="pk;system", subcommand="info", usage="[system]", description="Shows information about a system.")
async def system_info(conn, message, args): async def system_info(conn, message, args):
if len(args) == 0: if len(args) == 0:
# Use sender's system # Use sender's system
system = await db.get_system_by_account(conn, message.author.id) system = await db.get_system_by_account(conn, message.author.id)
if system is None: if system is None:
return False, "No system is registered to this account." return False, "No system is registered to this account."
else: else:
@ -53,14 +56,15 @@ async def system_info(conn, message, args):
if system is None: if system is None:
return False, "Unable to find system \"{}\".".format(args[0]) return False, "Unable to find system \"{}\".".format(args[0])
await client.send_message(message.channel, embed=await generate_system_info_card(conn, system)) await client.send_message(message.channel, embed=await generate_system_info_card(conn, system))
return True return True
@command(cmd="pk;system", subcommand="name", usage="[name]", description="Renames your system. Leave blank to clear.") @command(cmd="pk;system", subcommand="name", usage="[name]", description="Renames your system. Leave blank to clear.")
async def system_name(conn, message, args): async def system_name(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id) system = await db.get_system_by_account(conn, message.author.id)
if system is None: if system is None:
return False, "No system is registered to this account." return False, "No system is registered to this account."
@ -73,10 +77,11 @@ async def system_name(conn, message, args):
await db.update_system_field(conn, system_id=system["id"], field="name", value=new_name) await db.update_system_field(conn, system_id=system["id"], field="name", value=new_name)
return True, "Name updated to {}.".format(new_name) if new_name else "Name cleared." return True, "Name updated to {}.".format(new_name) if new_name else "Name cleared."
@command(cmd="pk;system", subcommand="description", usage="[clear]", description="Updates your system description. Add \"clear\" to clear.") @command(cmd="pk;system", subcommand="description", usage="[clear]", description="Updates your system description. Add \"clear\" to clear.")
async def system_description(conn, message, args): async def system_description(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id) system = await db.get_system_by_account(conn, message.author.id)
if system is None: if system is None:
return False, "No system is registered to this account." return False, "No system is registered to this account."
@ -93,10 +98,11 @@ async def system_description(conn, message, args):
await db.update_system_field(conn, system_id=system["id"], field="description", value=new_description) await db.update_system_field(conn, system_id=system["id"], field="description", value=new_description)
return True, "Description set." if new_description else "Description cleared." return True, "Description set." if new_description else "Description cleared."
@command(cmd="pk;system", subcommand="tag", usage="[tag]", description="Updates your system tag. Leave blank to clear.") @command(cmd="pk;system", subcommand="tag", usage="[tag]", description="Updates your system tag. Leave blank to clear.")
async def system_tag(conn, message, args): async def system_tag(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id) system = await db.get_system_by_account(conn, message.author.id)
if system is None: if system is None:
return False, "No system is registered to this account." return False, "No system is registered to this account."
@ -110,24 +116,27 @@ async def system_tag(conn, message, args):
members_exceeding = await db.get_members_exceeding(conn, system_id=system["id"], length=max_length - len(tag)) members_exceeding = await db.get_members_exceeding(conn, system_id=system["id"], length=max_length - len(tag))
if len(members_exceeding) > 0: if len(members_exceeding) > 0:
# If so, error out and warn # If so, error out and warn
member_names = ", ".join([member["name"] for member in members_exceeding]) member_names = ", ".join([member["name"]
logger.debug("Members exceeding combined length with tag '{}': {}".format(tag, member_names)) for member in members_exceeding])
logger.debug("Members exceeding combined length with tag '{}': {}".format(
tag, member_names))
return False, "The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names) return False, "The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names)
async with conn.transaction(): async with conn.transaction():
await db.update_system_field(conn, system_id=system["id"], field="tag", value=tag) await db.update_system_field(conn, system_id=system["id"], field="tag", value=tag)
return True, "Tag updated to {}.".format(tag) if tag else "Tag cleared." return True, "Tag updated to {}.".format(tag) if tag else "Tag cleared."
@command(cmd="pk;system", subcommand="remove", description="Removes your system ***permanently***.") @command(cmd="pk;system", subcommand="remove", description="Removes your system ***permanently***.")
async def system_remove(conn, message, args): async def system_remove(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id) system = await db.get_system_by_account(conn, message.author.id)
if system is None: if system is None:
return False, "No system is registered to this account." return False, "No system is registered to this account."
await client.send_message(message.channel, "Are you sure you want to remove your system? If so, reply to this message with the system's ID (`{}`).".format(system["hid"])) await client.send_message(message.channel, "Are you sure you want to remove your system? If so, reply to this message with the system's ID (`{}`).".format(system["hid"]))
msg = await client.wait_for_message(author=message.author, channel=message.channel) msg = await client.wait_for_message(author=message.author, channel=message.channel)
if msg.content == system["hid"]: if msg.content == system["hid"]:
await db.remove_system(conn, system_id=system["id"]) await db.remove_system(conn, system_id=system["id"])
@ -139,7 +148,7 @@ async def system_remove(conn, message, args):
@command(cmd="pk;system", subcommand="link", usage="<account>", description="Links another account to your system.") @command(cmd="pk;system", subcommand="link", usage="<account>", description="Links another account to your system.")
async def system_link(conn, message, args): async def system_link(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id) system = await db.get_system_by_account(conn, message.author.id)
if system is None: if system is None:
return False, "No system is registered to this account." return False, "No system is registered to this account."
@ -176,7 +185,7 @@ async def system_link(conn, message, args):
@command(cmd="pk;system", subcommand="unlink", description="Unlinks your system from this account. There must be at least one other account linked.") @command(cmd="pk;system", subcommand="unlink", description="Unlinks your system from this account. There must be at least one other account linked.")
async def system_unlink(conn, message, args): async def system_unlink(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id) system = await db.get_system_by_account(conn, message.author.id)
if system is None: if system is None:
return False, "No system is registered to this account." return False, "No system is registered to this account."
@ -184,15 +193,16 @@ async def system_unlink(conn, message, args):
linked_accounts = await db.get_linked_accounts(conn, system_id=system["id"]) linked_accounts = await db.get_linked_accounts(conn, system_id=system["id"])
if len(linked_accounts) == 1: if len(linked_accounts) == 1:
return False, "This is the only account on your system, so you can't unlink it." return False, "This is the only account on your system, so you can't unlink it."
async with conn.transaction(): async with conn.transaction():
await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id) await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id)
return True, "Account unlinked." return True, "Account unlinked."
@command(cmd="pk;member", subcommand="new", usage="<name>", description="Adds a new member to your system.") @command(cmd="pk;member", subcommand="new", usage="<name>", description="Adds a new member to your system.")
async def new_member(conn, message, args): async def new_member(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id) system = await db.get_system_by_account(conn, message.author.id)
if system is None: if system is None:
return False, "No system is registered to this account." return False, "No system is registered to this account."
@ -208,11 +218,13 @@ async def new_member(conn, message, args):
await db.create_member(conn, system_id=system["id"], member_name=name, member_hid=hid) await db.create_member(conn, system_id=system["id"], member_name=name, member_hid=hid)
return True, "Member \"{}\" (`{}`) registered!".format(name, hid) return True, "Member \"{}\" (`{}`) registered!".format(name, hid)
@member_command(cmd="pk;member", subcommand="info", description="Shows information about a system member.", system_only=False) @member_command(cmd="pk;member", subcommand="info", description="Shows information about a system member.", system_only=False)
async def member_info(conn, message, member, args): async def member_info(conn, message, member, args):
await client.send_message(message.channel, embed=await generate_member_info_card(conn, member)) await client.send_message(message.channel, embed=await generate_member_info_card(conn, member))
return True return True
@member_command(cmd="pk;member", subcommand="color", usage="[color]", description="Updates a member's associated color. Leave blank to clear.") @member_command(cmd="pk;member", subcommand="color", usage="[color]", description="Updates a member's associated color. Leave blank to clear.")
async def member_color(conn, message, member, args): async def member_color(conn, message, member, args):
if len(args) == 0: if len(args) == 0:
@ -227,7 +239,8 @@ async def member_color(conn, message, member, args):
async with conn.transaction(): async with conn.transaction():
await db.update_member_field(conn, member_id=member["id"], field="color", value=color) await db.update_member_field(conn, member_id=member["id"], field="color", value=color)
return True, "Color updated to #{}.".format(color) if color else "Color cleared." return True, "Color updated to #{}.".format(color) if color else "Color cleared."
@member_command(cmd="pk;member", subcommand="pronouns", usage="[pronouns]", description="Updates a member's pronouns. Leave blank to clear.") @member_command(cmd="pk;member", subcommand="pronouns", usage="[pronouns]", description="Updates a member's pronouns. Leave blank to clear.")
async def member_pronouns(conn, message, member, args): async def member_pronouns(conn, message, member, args):
if len(args) == 0: if len(args) == 0:
@ -239,12 +252,13 @@ async def member_pronouns(conn, message, member, args):
await db.update_member_field(conn, member_id=member["id"], field="pronouns", value=pronouns) await db.update_member_field(conn, member_id=member["id"], field="pronouns", value=pronouns)
return True, "Pronouns set to {}".format(pronouns) if pronouns else "Pronouns cleared." return True, "Pronouns set to {}".format(pronouns) if pronouns else "Pronouns cleared."
@member_command(cmd="pk;member", subcommand="birthdate", usage="[birthdate]", description="Updates a member's birthdate. Must be in ISO-8601 format (eg. 1999-07-25). Leave blank to clear.") @member_command(cmd="pk;member", subcommand="birthdate", usage="[birthdate]", description="Updates a member's birthdate. Must be in ISO-8601 format (eg. 1999-07-25). Leave blank to clear.")
async def member_birthday(conn, message, member, args): async def member_birthday(conn, message, member, args):
if len(args) == 0: if len(args) == 0:
new_date = None new_date = None
else: else:
# Parse date # Parse date
try: try:
new_date = datetime.strptime(args[0], "%Y-%m-%d").date() new_date = datetime.strptime(args[0], "%Y-%m-%d").date()
except ValueError: except ValueError:
@ -254,6 +268,7 @@ async def member_birthday(conn, message, member, args):
await db.update_member_field(conn, member_id=member["id"], field="birthday", value=new_date) await db.update_member_field(conn, member_id=member["id"], field="birthday", value=new_date)
return True, "Birthdate set to {}".format(new_date) if new_date else "Birthdate cleared." return True, "Birthdate set to {}".format(new_date) if new_date else "Birthdate cleared."
@member_command(cmd="pk;member", subcommand="description", description="Updates a member's description. Add \"clear\" to clear.") @member_command(cmd="pk;member", subcommand="description", description="Updates a member's description. Add \"clear\" to clear.")
async def member_description(conn, message, member, args): async def member_description(conn, message, member, args):
if len(args) > 0 and args[0] == "clear": if len(args) > 0 and args[0] == "clear":
@ -267,11 +282,12 @@ async def member_description(conn, message, member, args):
async with conn.transaction(): async with conn.transaction():
await db.update_member_field(conn, member_id=member["id"], field="description", value=new_description) await db.update_member_field(conn, member_id=member["id"], field="description", value=new_description)
return True, "Description set." if new_description else "Description cleared." return True, "Description set." if new_description else "Description cleared."
@member_command(cmd="pk;member", subcommand="remove", description="Removes a member from your system.") @member_command(cmd="pk;member", subcommand="remove", description="Removes a member from your system.")
async def member_remove(conn, message, member, args): async def member_remove(conn, message, member, args):
await client.send_message(message.channel, "Are you sure you want to remove {}? If so, reply to this message with the member's name.".format(member["name"])) await client.send_message(message.channel, "Are you sure you want to remove {}? If so, reply to this message with the member's name.".format(member["name"]))
msg = await client.wait_for_message(author=message.author, channel=message.channel) msg = await client.wait_for_message(author=message.author, channel=message.channel)
if msg.content == member["name"]: if msg.content == member["name"]:
await db.delete_member(conn, member_id=member["id"]) await db.delete_member(conn, member_id=member["id"])
@ -279,6 +295,7 @@ async def member_remove(conn, message, member, args):
else: else:
return True, "Member removal cancelled." return True, "Member removal cancelled."
@member_command(cmd="pk;member", subcommand="avatar", usage="[user|url]", description="Updates a member's avatar. Can be an account mention (which will use that account's avatar), or a link to an image. Leave blank to clear.") @member_command(cmd="pk;member", subcommand="avatar", usage="[user|url]", description="Updates a member's avatar. Can be an account mention (which will use that account's avatar), or a link to an image. Leave blank to clear.")
async def member_avatar(conn, message, member, args): async def member_avatar(conn, message, member, args):
if len(args) == 0: if len(args) == 0:
@ -296,11 +313,11 @@ async def member_avatar(conn, message, member, args):
avatar_url = args[0] avatar_url = args[0]
else: else:
return False, "Invalid URL." return False, "Invalid URL."
async with conn.transaction(): async with conn.transaction():
await db.update_member_field(conn, member_id=member["id"], field="avatar_url", value=avatar_url) await db.update_member_field(conn, member_id=member["id"], field="avatar_url", value=avatar_url)
return True, "Avatar set." if avatar_url else "Avatar cleared." return True, "Avatar set." if avatar_url else "Avatar cleared."
@member_command(cmd="pk;member", subcommand="proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).") @member_command(cmd="pk;member", subcommand="proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).")
async def member_proxy(conn, message, member, args): async def member_proxy(conn, message, member, args):
@ -318,7 +335,8 @@ async def member_proxy(conn, message, member, args):
# Extract prefix and suffix # Extract prefix and suffix
prefix = example[:example.index("text")].strip() prefix = example[:example.index("text")].strip()
suffix = example[example.index("text")+4:].strip() suffix = example[example.index("text")+4:].strip()
logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix)) logger.debug(
"Matched prefix '{}' and suffix '{}'".format(prefix, suffix))
# DB stores empty strings as None, make that work # DB stores empty strings as None, make that work
if not prefix: if not prefix:
@ -331,6 +349,7 @@ async def member_proxy(conn, message, member, args):
await db.update_member_field(conn, member_id=member["id"], field="suffix", value=suffix) await db.update_member_field(conn, member_id=member["id"], field="suffix", value=suffix)
return True, "Proxy settings updated." if prefix or suffix else "Proxy settings cleared." return True, "Proxy settings updated." if prefix or suffix else "Proxy settings cleared."
@command(cmd="pk;message", subcommand=None, usage="<id>", description="Shows information about a proxied message. Requires the message ID.") @command(cmd="pk;message", subcommand=None, usage="<id>", description="Shows information about a proxied message. Requires the message ID.")
async def message_info(conn, message, args): async def message_info(conn, message, args):
try: try:
@ -357,14 +376,16 @@ async def message_info(conn, message, args):
embed = discord.Embed() embed = discord.Embed()
embed.timestamp = message.timestamp embed.timestamp = message.timestamp
embed.colour = discord.Colour.blue() embed.colour = discord.Colour.blue()
if system["name"]: if system["name"]:
system_value = "`{}`: {}".format(system["hid"], system["name"]) system_value = "`{}`: {}".format(system["hid"], system["name"])
else: else:
system_value = "`{}`".format(system["hid"]) system_value = "`{}`".format(system["hid"])
embed.add_field(name="System", value=system_value) embed.add_field(name="System", value=system_value)
embed.add_field(name="Member", value="`{}`: {}".format(member["hid"], member["name"])) embed.add_field(name="Member", value="`{}`: {}".format(
embed.add_field(name="Sent by", value="{}#{}".format(original_sender.name, original_sender.discriminator)) member["hid"], member["name"]))
embed.add_field(name="Sent by", value="{}#{}".format(
original_sender.name, original_sender.discriminator))
embed.add_field(name="Content", value=message.clean_content, inline=False) embed.add_field(name="Content", value=message.clean_content, inline=False)
embed.set_author(name=member["name"], url=member["avatar_url"]) embed.set_author(name=member["name"], url=member["avatar_url"])
@ -372,12 +393,14 @@ async def message_info(conn, message, args):
await client.send_message(message.channel, embed=embed) await client.send_message(message.channel, embed=embed)
return True return True
@command(cmd="pk;help", subcommand=None, usage="[system|member|message]", description="Shows this help message.") @command(cmd="pk;help", subcommand=None, usage="[system|member|message]", description="Shows this help message.")
async def show_help(conn, message, args): async def show_help(conn, message, args):
embed = discord.Embed() embed = discord.Embed()
embed.colour = discord.Colour.blue() embed.colour = discord.Colour.blue()
embed.title = "PluralKit Help" embed.title = "PluralKit Help"
embed.set_footer(text="<> denotes mandatory arguments, [] denotes optional arguments") embed.set_footer(
text="<> denotes mandatory arguments, [] denotes optional arguments")
if len(args) > 0 and ("pk;" + args[0]) in command_map: if len(args) > 0 and ("pk;" + args[0]) in command_map:
cmds = ["", ("pk;" + args[0], command_map["pk;" + args[0]])] cmds = ["", ("pk;" + args[0], command_map["pk;" + args[0]])]
@ -386,7 +409,8 @@ async def show_help(conn, message, args):
for cmd, subcommands in cmds: for cmd, subcommands in cmds:
for subcmd, (_, usage, description) in subcommands.items(): for subcmd, (_, usage, description) in subcommands.items():
embed.add_field(name="{} {} {}".format(cmd, subcmd or "", usage or ""), value=description, inline=False) embed.add_field(name="{} {} {}".format(
cmd, subcmd or "", usage or ""), value=description, inline=False)
await client.send_message(message.channel, embed=embed) await client.send_message(message.channel, embed=embed)
return True return True

View File

@ -5,6 +5,7 @@ import asyncpg.exceptions
from pluralkit.bot import logger from pluralkit.bot import logger
async def connect(): async def connect():
while True: while True:
try: try:
@ -12,32 +13,40 @@ async def connect():
except (ConnectionError, asyncpg.exceptions.CannotConnectNowError): except (ConnectionError, asyncpg.exceptions.CannotConnectNowError):
pass pass
def db_wrap(func): def db_wrap(func):
async def inner(*args, **kwargs): async def inner(*args, **kwargs):
before = time.perf_counter() before = time.perf_counter()
res = await func(*args, **kwargs) res = await func(*args, **kwargs)
after = time.perf_counter() after = time.perf_counter()
logger.debug(" - DB took {:.2f} ms".format((after - before) * 1000)) logger.debug(" - DB took {:.2f} ms".format((after - before) * 1000))
return res return res
return inner return inner
webhook_cache = {} webhook_cache = {}
@db_wrap @db_wrap
async def create_system(conn, system_name: str, system_hid: str): async def create_system(conn, system_name: str, system_hid: str):
logger.debug("Creating system (name={}, hid={})".format(system_name, system_hid)) logger.debug("Creating system (name={}, hid={})".format(
system_name, system_hid))
return await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid) return await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid)
@db_wrap @db_wrap
async def remove_system(conn, system_id: int): async def remove_system(conn, system_id: int):
logger.debug("Deleting system (id={})".format(system_id)) logger.debug("Deleting system (id={})".format(system_id))
await conn.execute("delete from systems where id = $1", system_id) await conn.execute("delete from systems where id = $1", system_id)
@db_wrap @db_wrap
async def create_member(conn, system_id: int, member_name: str, member_hid: str): async def create_member(conn, system_id: int, member_name: str, member_hid: str):
logger.debug("Creating member (system={}, name={}, hid={})".format(system_id, member_name, member_hid)) logger.debug("Creating member (system={}, name={}, hid={})".format(
return await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid) system_id, member_name, member_hid))
return await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid)
@db_wrap @db_wrap
async def delete_member(conn, member_id: int): async def delete_member(conn, member_id: int):
@ -45,70 +54,90 @@ async def delete_member(conn, member_id: int):
await conn.execute("update switches set member = null, member_del = true where member = $1", member_id) await conn.execute("update switches set member = null, member_del = true where member = $1", member_id)
await conn.execute("delete from members where id = $1", member_id) await conn.execute("delete from members where id = $1", member_id)
@db_wrap @db_wrap
async def link_account(conn, system_id: int, account_id: str): async def link_account(conn, system_id: int, account_id: str):
logger.debug("Linking account (account_id={}, system_id={})".format(account_id, system_id)) logger.debug("Linking account (account_id={}, system_id={})".format(
account_id, system_id))
await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id) await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id)
@db_wrap @db_wrap
async def unlink_account(conn, system_id: int, account_id: str): async def unlink_account(conn, system_id: int, account_id: str):
logger.debug("Unlinking account (account_id={}, system_id={})".format(account_id, system_id)) logger.debug("Unlinking account (account_id={}, system_id={})".format(
account_id, system_id))
await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id) await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id)
@db_wrap @db_wrap
async def get_linked_accounts(conn, system_id: int): async def get_linked_accounts(conn, system_id: int):
return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)] return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)]
@db_wrap @db_wrap
async def get_system_by_account(conn, account_id: str): async def get_system_by_account(conn, account_id: str):
return await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id)) return await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id))
@db_wrap @db_wrap
async def get_system_by_hid(conn, system_hid: str): async def get_system_by_hid(conn, system_hid: str):
return await conn.fetchrow("select * from systems where hid = $1", system_hid) return await conn.fetchrow("select * from systems where hid = $1", system_hid)
@db_wrap @db_wrap
async def get_system(conn, system_id: int): async def get_system(conn, system_id: int):
return await conn.fetchrow("select * from systems where id = $1", system_id) return await conn.fetchrow("select * from systems where id = $1", system_id)
@db_wrap @db_wrap
async def get_member_by_name(conn, system_id: int, member_name: str): async def get_member_by_name(conn, system_id: int, member_name: str):
return await conn.fetchrow("select * from members where system = $1 and name = $2", system_id, member_name) return await conn.fetchrow("select * from members where system = $1 and name = $2", system_id, member_name)
@db_wrap @db_wrap
async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str): async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str):
return await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid) return await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid)
@db_wrap @db_wrap
async def get_member_by_hid(conn, member_hid: str): async def get_member_by_hid(conn, member_hid: str):
return await conn.fetchrow("select * from members where hid = $1", member_hid) return await conn.fetchrow("select * from members where hid = $1", member_hid)
@db_wrap @db_wrap
async def get_member(conn, member_id: int): async def get_member(conn, member_id: int):
return await conn.fetchrow("select * from members where id = $1", member_id) return await conn.fetchrow("select * from members where id = $1", member_id)
@db_wrap @db_wrap
async def get_message(conn, message_id: str): async def get_message(conn, message_id: str):
return await conn.fetchrow("select * from messages where mid = $1", message_id) return await conn.fetchrow("select * from messages where mid = $1", message_id)
@db_wrap @db_wrap
async def update_system_field(conn, system_id: int, field: str, value): async def update_system_field(conn, system_id: int, field: str, value):
logger.debug("Updating system field (id={}, {}={})".format(system_id, field, value)) logger.debug("Updating system field (id={}, {}={})".format(
system_id, field, value))
await conn.execute("update systems set {} = $1 where id = $2".format(field), value, system_id) await conn.execute("update systems set {} = $1 where id = $2".format(field), value, system_id)
@db_wrap @db_wrap
async def update_member_field(conn, member_id: int, field: str, value): async def update_member_field(conn, member_id: int, field: str, value):
logger.debug("Updating member field (id={}, {}={})".format(member_id, field, value)) logger.debug("Updating member field (id={}, {}={})".format(
member_id, field, value))
await conn.execute("update members set {} = $1 where id = $2".format(field), value, member_id) await conn.execute("update members set {} = $1 where id = $2".format(field), value, member_id)
@db_wrap @db_wrap
async def get_all_members(conn, system_id: int): async def get_all_members(conn, system_id: int):
return await conn.fetch("select * from members where system = $1", system_id) return await conn.fetch("select * from members where system = $1", system_id)
@db_wrap @db_wrap
async def get_members_exceeding(conn, system_id: int, length: int): async def get_members_exceeding(conn, system_id: int, length: int):
return await conn.fetch("select * from members where system = $1 and length(name) >= $2", system_id, length) return await conn.fetch("select * from members where system = $1 and length(name) >= $2", system_id, length)
@db_wrap @db_wrap
async def get_webhook(conn, channel_id: str): async def get_webhook(conn, channel_id: str):
if channel_id in webhook_cache: if channel_id in webhook_cache:
@ -117,30 +146,38 @@ async def get_webhook(conn, channel_id: str):
webhook_cache[channel_id] = res webhook_cache[channel_id] = res
return res return res
@db_wrap @db_wrap
async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str): async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str):
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(channel_id, webhook_id, webhook_token)) logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(
channel_id, webhook_id, webhook_token))
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token) await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token)
@db_wrap @db_wrap
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str): async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str):
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(message_id, channel_id, member_id, sender_id)) logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(
message_id, channel_id, member_id, sender_id))
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id)) await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id))
@db_wrap @db_wrap
async def get_members_by_account(conn, account_id: str): async def get_members_by_account(conn, account_id: str):
# Returns a "chimera" object # Returns a "chimera" object
return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id)) return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id))
@db_wrap @db_wrap
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str): async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str):
await conn.fetchrow("select * from messages where mid = $1 and sender = $2", int(message_id), int(sender_id)) await conn.fetchrow("select * from messages where mid = $1 and sender = $2", int(message_id), int(sender_id))
@db_wrap @db_wrap
async def delete_message(conn, message_id: str): async def delete_message(conn, message_id: str):
logger.debug("Deleting message (id={})".format(message_id)) logger.debug("Deleting message (id={})".format(message_id))
await conn.execute("delete from messages where mid = $1", int(message_id)) await conn.execute("delete from messages where mid = $1", int(message_id))
async def create_tables(conn): async def create_tables(conn):
await conn.execute("""create table if not exists systems ( await conn.execute("""create table if not exists systems (
id serial primary key, id serial primary key,
@ -190,4 +227,4 @@ async def create_tables(conn):
id bigint primary key, id bigint primary key,
cmd_chans bigint[], cmd_chans bigint[],
proxy_chans bigint[] proxy_chans bigint[]
)""") )""")

View File

@ -6,6 +6,7 @@ import aiohttp
from pluralkit import db from pluralkit import db
from pluralkit.bot import client, logger from pluralkit.bot import client, logger
async def get_webhook(conn, channel): async def get_webhook(conn, channel):
async with conn.transaction(): async with conn.transaction():
# Try to find an existing webhook # Try to find an existing webhook
@ -14,7 +15,8 @@ async def get_webhook(conn, channel):
if not hook_row: if not hook_row:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
req_data = {"name": "PluralKit Proxy Webhook"} req_data = {"name": "PluralKit Proxy Webhook"}
req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])} req_headers = {
"Authorization": "Bot {}".format(os.environ["TOKEN"])}
async with session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id), json=req_data, headers=req_headers) as resp: async with session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id), json=req_data, headers=req_headers) as resp:
data = await resp.json() data = await resp.json()
@ -24,17 +26,19 @@ async def get_webhook(conn, channel):
# Insert new hook into DB # Insert new hook into DB
await db.add_webhook(conn, channel_id=channel.id, webhook_id=hook_id, webhook_token=token) await db.add_webhook(conn, channel_id=channel.id, webhook_id=hook_id, webhook_token=token)
return hook_id, token return hook_id, token
return hook_row["webhook"], hook_row["token"] return hook_row["webhook"], hook_row["token"]
async def proxy_message(conn, member, message, inner): async def proxy_message(conn, member, message, inner):
logger.debug("Proxying message '{}' for member {}".format(inner, member["hid"])) logger.debug("Proxying message '{}' for member {}".format(
inner, member["hid"]))
# Delete the original message # Delete the original message
await client.delete_message(message) await client.delete_message(message)
# Get the webhook details # Get the webhook details
hook_id, hook_token = await get_webhook(conn, message.channel) hook_id, hook_token = await get_webhook(conn, message.channel)
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
req_data = { req_data = {
"username": "{} {}".format(member["name"], member["tag"] or "").strip(), "username": "{} {}".format(member["name"], member["tag"] or "").strip(),
"avatar_url": member["avatar_url"], "avatar_url": member["avatar_url"],
@ -49,13 +53,15 @@ async def proxy_message(conn, member, message, inner):
# Insert new message details into the DB # Insert new message details into the DB
await db.add_message(conn, message_id=resp_data["id"], channel_id=message.channel.id, member_id=member["id"], sender_id=message.author.id) await db.add_message(conn, message_id=resp_data["id"], channel_id=message.channel.id, member_id=member["id"], sender_id=message.author.id)
async def handle_proxying(conn, message): async def handle_proxying(conn, message):
# Big fat query to find every member associated with this account # Big fat query to find every member associated with this account
# Returned member object has a few more keys (system tag, for example) # Returned member object has a few more keys (system tag, for example)
members = await db.get_members_by_account(conn, account_id=message.author.id) members = await db.get_members_by_account(conn, account_id=message.author.id)
# Sort by specificity (members with both prefix and suffix go higher) # Sort by specificity (members with both prefix and suffix go higher)
members = sorted(members, key=lambda x: int(bool(x["prefix"])) + int(bool(x["suffix"])), reverse=True) members = sorted(members, key=lambda x: int(
bool(x["prefix"])) + int(bool(x["suffix"])), reverse=True)
msg = message.content msg = message.content
for member in members: for member in members:
@ -71,14 +77,14 @@ async def handle_proxying(conn, message):
if msg.startswith(prefix) and msg.endswith(suffix): if msg.startswith(prefix) and msg.endswith(suffix):
# Extract the actual message contents sans tags # Extract the actual message contents sans tags
if suffix: if suffix:
inner_message = message.content[len(prefix):-len(suffix)].strip() inner_message = message.content[len(
prefix):-len(suffix)].strip()
else: else:
# Slicing to -0 breaks, don't do that # Slicing to -0 breaks, don't do that
inner_message = message.content[len(prefix):].strip() inner_message = message.content[len(prefix):].strip()
await proxy_message(conn, member, message, inner_message) await proxy_message(conn, member, message, inner_message)
break break
async def handle_reaction(conn, reaction, user): async def handle_reaction(conn, reaction, user):
@ -90,4 +96,4 @@ async def handle_reaction(conn, reaction, user):
if message: if message:
# If so, delete the message and remove it from the DB # If so, delete the message and remove it from the DB
await db.delete_message(conn, message["mid"]) await db.delete_message(conn, message["mid"])
await client.delete_message(reaction.message) await client.delete_message(reaction.message)

View File

@ -9,9 +9,11 @@ import discord
from pluralkit import db from pluralkit import db
from pluralkit.bot import client, logger from pluralkit.bot import client, logger
def generate_hid() -> str: def generate_hid() -> str:
return "".join(random.choices(string.ascii_lowercase, k=5)) return "".join(random.choices(string.ascii_lowercase, k=5))
async def parse_mention(mention: str) -> discord.User: async def parse_mention(mention: str) -> discord.User:
# First try matching mention format # First try matching mention format
match = re.fullmatch("<@!?(\\d+)>", mention) match = re.fullmatch("<@!?(\\d+)>", mention)
@ -27,10 +29,11 @@ async def parse_mention(mention: str) -> discord.User:
except (ValueError, discord.NotFound): except (ValueError, discord.NotFound):
return None return None
async def get_system_fuzzy(conn, key) -> asyncpg.Record: async def get_system_fuzzy(conn, key) -> asyncpg.Record:
if isinstance(key, discord.User): if isinstance(key, discord.User):
return await db.get_system_by_account(conn, account_id=key.id) return await db.get_system_by_account(conn, account_id=key.id)
if isinstance(key, str) and len(key) == 5: if isinstance(key, str) and len(key) == 5:
return await db.get_system_by_hid(conn, system_hid=key) return await db.get_system_by_hid(conn, system_hid=key)
@ -39,7 +42,8 @@ async def get_system_fuzzy(conn, key) -> asyncpg.Record:
if system: if system:
return system return system
return None return None
async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> asyncpg.Record: async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> asyncpg.Record:
# First search by hid # First search by hid
if system_only: if system_only:
@ -60,6 +64,8 @@ command_map = {}
# Command wrapper # Command wrapper
# Return True for success, return False for failure # Return True for success, return False for failure
# Second parameter is the message it'll send. If just False, will print usage # Second parameter is the message it'll send. If just False, will print usage
def command(cmd, subcommand, usage=None, description=None): def command(cmd, subcommand, usage=None, description=None):
def wrap(func): def wrap(func):
async def wrapper(conn, message, args): async def wrapper(conn, message, args):
@ -70,12 +76,13 @@ def command(cmd, subcommand, usage=None, description=None):
success, msg = res, None success, msg = res, None
else: else:
success, msg = res success, msg = res
if not success and not msg: if not success and not msg:
# Failure, no message, print usage # Failure, no message, print usage
usage_embed = discord.Embed() usage_embed = discord.Embed()
usage_embed.colour = discord.Colour.blue() usage_embed.colour = discord.Colour.blue()
usage_embed.add_field(name="Usage", value=usage, inline=False) usage_embed.add_field(
name="Usage", value=usage, inline=False)
await client.send_message(message.channel, embed=usage_embed) await client.send_message(message.channel, embed=usage_embed)
elif not success: elif not success:
@ -103,6 +110,8 @@ def command(cmd, subcommand, usage=None, description=None):
# Member command wrapper # Member command wrapper
# Tries to find member by first argument # Tries to find member by first argument
# If system_only=False, allows members from other systems by hid # If system_only=False, allows members from other systems by hid
def member_command(cmd, subcommand, usage=None, description=None, system_only=True): def member_command(cmd, subcommand, usage=None, description=None, system_only=True):
def wrap(func): def wrap(func):
async def wrapper(conn, message, args): async def wrapper(conn, message, args):
@ -122,11 +131,12 @@ def member_command(cmd, subcommand, usage=None, description=None, system_only=Tr
if member is None: if member is None:
return False, "Can't find member \"{}\".".format(args[0]) return False, "Can't find member \"{}\".".format(args[0])
return await func(conn, message, member, args[1:]) return await func(conn, message, member, args[1:])
return command(cmd=cmd, subcommand=subcommand, usage=usage, description=description)(wrapper) return command(cmd=cmd, subcommand=subcommand, usage=usage, description=description)(wrapper)
return wrap return wrap
async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Embed: async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Embed:
card = discord.Embed() card = discord.Embed()
@ -134,7 +144,8 @@ async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Emb
card.title = system["name"] card.title = system["name"]
if system["description"]: if system["description"]:
card.add_field(name="Description", value=system["description"], inline=False) card.add_field(name="Description",
value=system["description"], inline=False)
if system["tag"]: if system["tag"]:
card.add_field(name="Tag", value=system["tag"]) card.add_field(name="Tag", value=system["tag"])
@ -158,11 +169,13 @@ async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Emb
member_texts.append("`{}`: {}".format(member["hid"], member["name"])) member_texts.append("`{}`: {}".format(member["hid"], member["name"]))
if len(member_texts) > 0: if len(member_texts) > 0:
card.add_field(name="Members", value="\n".join(member_texts), inline=False) card.add_field(name="Members", value="\n".join(
member_texts), inline=False)
card.set_footer(text="System ID: {}".format(system["hid"])) card.set_footer(text="System ID: {}".format(system["hid"]))
return card return card
async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Embed: async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Embed:
card = discord.Embed() card = discord.Embed()
card.set_author(name=member["name"], icon_url=member["avatar_url"]) card.set_author(name=member["name"], icon_url=member["avatar_url"])
@ -171,18 +184,21 @@ async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Emb
card.colour = int(member["color"], 16) card.colour = int(member["color"], 16)
if member["birthday"]: if member["birthday"]:
card.add_field(name="Birthdate", value=member["birthday"].strftime("%b %d, %Y")) card.add_field(name="Birthdate",
value=member["birthday"].strftime("%b %d, %Y"))
if member["pronouns"]: if member["pronouns"]:
card.add_field(name="Pronouns", value=member["pronouns"]) card.add_field(name="Pronouns", value=member["pronouns"])
if member["prefix"] or member["suffix"]: if member["prefix"] or member["suffix"]:
prefix = member["prefix"] or "" prefix = member["prefix"] or ""
suffix = member["suffix"] or "" suffix = member["suffix"] or ""
card.add_field(name="Proxy Tags", value="{}text{}".format(prefix, suffix)) card.add_field(name="Proxy Tags",
value="{}text{}".format(prefix, suffix))
if member["description"]: if member["description"]:
card.add_field(name="Description", value=member["description"], inline=False) card.add_field(name="Description",
value=member["description"], inline=False)
# Get system name and hid # Get system name and hid
system = await db.get_system(conn, system_id=member["system"]) system = await db.get_system(conn, system_id=member["system"])
@ -192,9 +208,11 @@ async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Emb
system_value = "`{}`".format(system["hid"]) system_value = "`{}`".format(system["hid"])
card.add_field(name="System", value=system_value, inline=False) card.add_field(name="System", value=system_value, inline=False)
card.set_footer(text="System ID: {} | Member ID: {}".format(system["hid"], member["hid"])) card.set_footer(text="System ID: {} | Member ID: {}".format(
system["hid"], member["hid"]))
return card return card
async def text_input(message, subject): async def text_input(message, subject):
await client.send_message(message.channel, "Reply in this channel with the new description you want to set for {}.".format(subject)) await client.send_message(message.channel, "Reply in this channel with the new description you want to set for {}.".format(subject))
msg = await client.wait_for_message(author=message.author, channel=message.channel) msg = await client.wait_for_message(author=message.author, channel=message.channel)
@ -208,4 +226,4 @@ async def text_input(message, subject):
return msg.content return msg.content
else: else:
await client.clear_reactions(msg) await client.clear_reactions(msg)
return None return None