Format code
This commit is contained in:
		| @@ -1 +1 @@ | |||||||
| from . import commands, db, proxy | from . import commands, db, proxy | ||||||
|   | |||||||
| @@ -12,23 +12,28 @@ logger.setLevel(logging.DEBUG) | |||||||
|  |  | ||||||
| client = discord.Client() | client = discord.Client() | ||||||
|  |  | ||||||
|  |  | ||||||
| @client.event | @client.event | ||||||
| async def on_error(evt, *args, **kwargs): | async def on_error(evt, *args, **kwargs): | ||||||
|     logger.exception("Error while handling event {} with arguments {}:".format(evt, args)) |     logger.exception( | ||||||
|  |         "Error while handling event {} with arguments {}:".format(evt, args)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @client.event | @client.event | ||||||
| async def on_ready(): | async def on_ready(): | ||||||
|     # Print status info |     # Print status info | ||||||
|     logger.info("Connected to Discord.") |     logger.info("Connected to Discord.") | ||||||
|     logger.info("Account: {}#{}".format(client.user.name, client.user.discriminator)) |     logger.info("Account: {}#{}".format( | ||||||
|  |         client.user.name, client.user.discriminator)) | ||||||
|     logger.info("User ID: {}".format(client.user.id)) |     logger.info("User ID: {}".format(client.user.id)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @client.event | @client.event | ||||||
| async def on_message(message): | async def on_message(message): | ||||||
|     # Ignore bot messages |     # Ignore bot messages | ||||||
|     if message.author.bot: |     if message.author.bot: | ||||||
|         return |         return | ||||||
|      |  | ||||||
|     # Split into args. shlex sucks so we don't bother with quotes |     # Split into args. shlex sucks so we don't bother with quotes | ||||||
|     args = message.content.split(" ") |     args = message.content.split(" ") | ||||||
|  |  | ||||||
| @@ -54,6 +59,7 @@ async def on_message(message): | |||||||
|         async with client.pool.acquire() as conn: |         async with client.pool.acquire() as conn: | ||||||
|             await proxy.handle_proxying(conn, message) |             await proxy.handle_proxying(conn, message) | ||||||
|  |  | ||||||
|  |  | ||||||
| @client.event | @client.event | ||||||
| async def on_reaction_add(reaction, user): | async def on_reaction_add(reaction, user): | ||||||
|     from pluralkit import proxy |     from pluralkit import proxy | ||||||
| @@ -62,6 +68,7 @@ async def on_reaction_add(reaction, user): | |||||||
|     async with client.pool.acquire() as conn: |     async with client.pool.acquire() as conn: | ||||||
|         await proxy.handle_reaction(conn, reaction, user) |         await proxy.handle_reaction(conn, reaction, user) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def run(): | async def run(): | ||||||
|     from pluralkit import db |     from pluralkit import db | ||||||
|     try: |     try: | ||||||
| @@ -79,4 +86,4 @@ async def run(): | |||||||
|         await client.start(os.environ["TOKEN"]) |         await client.start(os.environ["TOKEN"]) | ||||||
|     finally: |     finally: | ||||||
|         logger.info("Logging out from Discord...") |         logger.info("Logging out from Discord...") | ||||||
|         await client.logout() |         await client.logout() | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ from pluralkit import db | |||||||
| from pluralkit.bot import client, logger | from pluralkit.bot import client, logger | ||||||
| from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map | from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map | ||||||
|  |  | ||||||
|  |  | ||||||
| @command(cmd="pk;system", subcommand=None, description="Shows information about your system.") | @command(cmd="pk;system", subcommand=None, description="Shows information about your system.") | ||||||
| async def this_system_info(conn, message, args): | async def this_system_info(conn, message, args): | ||||||
|     system = await db.get_system_by_account(conn, message.author.id) |     system = await db.get_system_by_account(conn, message.author.id) | ||||||
| @@ -18,6 +19,7 @@ async def this_system_info(conn, message, args): | |||||||
|     await client.send_message(message.channel, embed=await generate_system_info_card(conn, system)) |     await client.send_message(message.channel, embed=await generate_system_info_card(conn, system)) | ||||||
|     return True |     return True | ||||||
|  |  | ||||||
|  |  | ||||||
| @command(cmd="pk;system", subcommand="new", usage="[name]", description="Registers a new system to this account.") | @command(cmd="pk;system", subcommand="new", usage="[name]", description="Registers a new system to this account.") | ||||||
| async def new_system(conn, message, args): | async def new_system(conn, message, args): | ||||||
|     system = await db.get_system_by_account(conn, message.author.id) |     system = await db.get_system_by_account(conn, message.author.id) | ||||||
| @@ -39,12 +41,13 @@ async def new_system(conn, message, args): | |||||||
|         await db.link_account(conn, system_id=system["id"], account_id=message.author.id) |         await db.link_account(conn, system_id=system["id"], account_id=message.author.id) | ||||||
|         return True, "System registered! To begin adding members, use `pk;member new <name>`." |         return True, "System registered! To begin adding members, use `pk;member new <name>`." | ||||||
|  |  | ||||||
|  |  | ||||||
| @command(cmd="pk;system", subcommand="info", usage="[system]", description="Shows information about a system.") | @command(cmd="pk;system", subcommand="info", usage="[system]", description="Shows information about a system.") | ||||||
| async def system_info(conn, message, args): | async def system_info(conn, message, args): | ||||||
|     if len(args) == 0: |     if len(args) == 0: | ||||||
|         # Use sender's system |         # Use sender's system | ||||||
|         system = await db.get_system_by_account(conn, message.author.id) |         system = await db.get_system_by_account(conn, message.author.id) | ||||||
|      |  | ||||||
|         if system is None: |         if system is None: | ||||||
|             return False, "No system is registered to this account." |             return False, "No system is registered to this account." | ||||||
|     else: |     else: | ||||||
| @@ -53,14 +56,15 @@ async def system_info(conn, message, args): | |||||||
|  |  | ||||||
|         if system is None: |         if system is None: | ||||||
|             return False, "Unable to find system \"{}\".".format(args[0]) |             return False, "Unable to find system \"{}\".".format(args[0]) | ||||||
|      |  | ||||||
|     await client.send_message(message.channel, embed=await generate_system_info_card(conn, system)) |     await client.send_message(message.channel, embed=await generate_system_info_card(conn, system)) | ||||||
|     return True |     return True | ||||||
|  |  | ||||||
|  |  | ||||||
| @command(cmd="pk;system", subcommand="name", usage="[name]", description="Renames your system. Leave blank to clear.") | @command(cmd="pk;system", subcommand="name", usage="[name]", description="Renames your system. Leave blank to clear.") | ||||||
| async def system_name(conn, message, args): | async def system_name(conn, message, args): | ||||||
|     system = await db.get_system_by_account(conn, message.author.id) |     system = await db.get_system_by_account(conn, message.author.id) | ||||||
|      |  | ||||||
|     if system is None: |     if system is None: | ||||||
|         return False, "No system is registered to this account." |         return False, "No system is registered to this account." | ||||||
|  |  | ||||||
| @@ -73,10 +77,11 @@ async def system_name(conn, message, args): | |||||||
|         await db.update_system_field(conn, system_id=system["id"], field="name", value=new_name) |         await db.update_system_field(conn, system_id=system["id"], field="name", value=new_name) | ||||||
|         return True, "Name updated to {}.".format(new_name) if new_name else "Name cleared." |         return True, "Name updated to {}.".format(new_name) if new_name else "Name cleared." | ||||||
|  |  | ||||||
|  |  | ||||||
| @command(cmd="pk;system", subcommand="description", usage="[clear]", description="Updates your system description. Add \"clear\" to clear.") | @command(cmd="pk;system", subcommand="description", usage="[clear]", description="Updates your system description. Add \"clear\" to clear.") | ||||||
| async def system_description(conn, message, args): | async def system_description(conn, message, args): | ||||||
|     system = await db.get_system_by_account(conn, message.author.id) |     system = await db.get_system_by_account(conn, message.author.id) | ||||||
|      |  | ||||||
|     if system is None: |     if system is None: | ||||||
|         return False, "No system is registered to this account." |         return False, "No system is registered to this account." | ||||||
|  |  | ||||||
| @@ -93,10 +98,11 @@ async def system_description(conn, message, args): | |||||||
|         await db.update_system_field(conn, system_id=system["id"], field="description", value=new_description) |         await db.update_system_field(conn, system_id=system["id"], field="description", value=new_description) | ||||||
|         return True, "Description set." if new_description else "Description cleared." |         return True, "Description set." if new_description else "Description cleared." | ||||||
|  |  | ||||||
|  |  | ||||||
| @command(cmd="pk;system", subcommand="tag", usage="[tag]", description="Updates your system tag. Leave blank to clear.") | @command(cmd="pk;system", subcommand="tag", usage="[tag]", description="Updates your system tag. Leave blank to clear.") | ||||||
| async def system_tag(conn, message, args): | async def system_tag(conn, message, args): | ||||||
|     system = await db.get_system_by_account(conn, message.author.id) |     system = await db.get_system_by_account(conn, message.author.id) | ||||||
|      |  | ||||||
|     if system is None: |     if system is None: | ||||||
|         return False, "No system is registered to this account." |         return False, "No system is registered to this account." | ||||||
|  |  | ||||||
| @@ -110,24 +116,27 @@ async def system_tag(conn, message, args): | |||||||
|         members_exceeding = await db.get_members_exceeding(conn, system_id=system["id"], length=max_length - len(tag)) |         members_exceeding = await db.get_members_exceeding(conn, system_id=system["id"], length=max_length - len(tag)) | ||||||
|         if len(members_exceeding) > 0: |         if len(members_exceeding) > 0: | ||||||
|             # If so, error out and warn |             # If so, error out and warn | ||||||
|             member_names = ", ".join([member["name"] for member in members_exceeding]) |             member_names = ", ".join([member["name"] | ||||||
|             logger.debug("Members exceeding combined length with tag '{}': {}".format(tag, member_names)) |                                       for member in members_exceeding]) | ||||||
|  |             logger.debug("Members exceeding combined length with tag '{}': {}".format( | ||||||
|  |                 tag, member_names)) | ||||||
|             return False, "The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names) |             return False, "The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names) | ||||||
|  |  | ||||||
|     async with conn.transaction(): |     async with conn.transaction(): | ||||||
|         await db.update_system_field(conn, system_id=system["id"], field="tag", value=tag) |         await db.update_system_field(conn, system_id=system["id"], field="tag", value=tag) | ||||||
|      |  | ||||||
|     return True, "Tag updated to {}.".format(tag) if tag else "Tag cleared." |     return True, "Tag updated to {}.".format(tag) if tag else "Tag cleared." | ||||||
|  |  | ||||||
|  |  | ||||||
| @command(cmd="pk;system", subcommand="remove", description="Removes your system ***permanently***.") | @command(cmd="pk;system", subcommand="remove", description="Removes your system ***permanently***.") | ||||||
| async def system_remove(conn, message, args): | async def system_remove(conn, message, args): | ||||||
|     system = await db.get_system_by_account(conn, message.author.id) |     system = await db.get_system_by_account(conn, message.author.id) | ||||||
|      |  | ||||||
|     if system is None: |     if system is None: | ||||||
|         return False, "No system is registered to this account." |         return False, "No system is registered to this account." | ||||||
|  |  | ||||||
|     await client.send_message(message.channel, "Are you sure you want to remove your system? If so, reply to this message with the system's ID (`{}`).".format(system["hid"])) |     await client.send_message(message.channel, "Are you sure you want to remove your system? If so, reply to this message with the system's ID (`{}`).".format(system["hid"])) | ||||||
|      |  | ||||||
|     msg = await client.wait_for_message(author=message.author, channel=message.channel) |     msg = await client.wait_for_message(author=message.author, channel=message.channel) | ||||||
|     if msg.content == system["hid"]: |     if msg.content == system["hid"]: | ||||||
|         await db.remove_system(conn, system_id=system["id"]) |         await db.remove_system(conn, system_id=system["id"]) | ||||||
| @@ -139,7 +148,7 @@ async def system_remove(conn, message, args): | |||||||
| @command(cmd="pk;system", subcommand="link", usage="<account>", description="Links another account to your system.") | @command(cmd="pk;system", subcommand="link", usage="<account>", description="Links another account to your system.") | ||||||
| async def system_link(conn, message, args): | async def system_link(conn, message, args): | ||||||
|     system = await db.get_system_by_account(conn, message.author.id) |     system = await db.get_system_by_account(conn, message.author.id) | ||||||
|      |  | ||||||
|     if system is None: |     if system is None: | ||||||
|         return False, "No system is registered to this account." |         return False, "No system is registered to this account." | ||||||
|  |  | ||||||
| @@ -176,7 +185,7 @@ async def system_link(conn, message, args): | |||||||
| @command(cmd="pk;system", subcommand="unlink", description="Unlinks your system from this account. There must be at least one other account linked.") | @command(cmd="pk;system", subcommand="unlink", description="Unlinks your system from this account. There must be at least one other account linked.") | ||||||
| async def system_unlink(conn, message, args): | async def system_unlink(conn, message, args): | ||||||
|     system = await db.get_system_by_account(conn, message.author.id) |     system = await db.get_system_by_account(conn, message.author.id) | ||||||
|      |  | ||||||
|     if system is None: |     if system is None: | ||||||
|         return False, "No system is registered to this account." |         return False, "No system is registered to this account." | ||||||
|  |  | ||||||
| @@ -184,15 +193,16 @@ async def system_unlink(conn, message, args): | |||||||
|     linked_accounts = await db.get_linked_accounts(conn, system_id=system["id"]) |     linked_accounts = await db.get_linked_accounts(conn, system_id=system["id"]) | ||||||
|     if len(linked_accounts) == 1: |     if len(linked_accounts) == 1: | ||||||
|         return False, "This is the only account on your system, so you can't unlink it." |         return False, "This is the only account on your system, so you can't unlink it." | ||||||
|      |  | ||||||
|     async with conn.transaction(): |     async with conn.transaction(): | ||||||
|         await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id) |         await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id) | ||||||
|         return True, "Account unlinked." |         return True, "Account unlinked." | ||||||
|  |  | ||||||
|  |  | ||||||
| @command(cmd="pk;member", subcommand="new", usage="<name>", description="Adds a new member to your system.") | @command(cmd="pk;member", subcommand="new", usage="<name>", description="Adds a new member to your system.") | ||||||
| async def new_member(conn, message, args): | async def new_member(conn, message, args): | ||||||
|     system = await db.get_system_by_account(conn, message.author.id) |     system = await db.get_system_by_account(conn, message.author.id) | ||||||
|      |  | ||||||
|     if system is None: |     if system is None: | ||||||
|         return False, "No system is registered to this account." |         return False, "No system is registered to this account." | ||||||
|  |  | ||||||
| @@ -208,11 +218,13 @@ async def new_member(conn, message, args): | |||||||
|         await db.create_member(conn, system_id=system["id"], member_name=name, member_hid=hid) |         await db.create_member(conn, system_id=system["id"], member_name=name, member_hid=hid) | ||||||
|         return True, "Member \"{}\" (`{}`) registered!".format(name, hid) |         return True, "Member \"{}\" (`{}`) registered!".format(name, hid) | ||||||
|  |  | ||||||
|  |  | ||||||
| @member_command(cmd="pk;member", subcommand="info", description="Shows information about a system member.", system_only=False) | @member_command(cmd="pk;member", subcommand="info", description="Shows information about a system member.", system_only=False) | ||||||
| async def member_info(conn, message, member, args): | async def member_info(conn, message, member, args): | ||||||
|     await client.send_message(message.channel, embed=await generate_member_info_card(conn, member)) |     await client.send_message(message.channel, embed=await generate_member_info_card(conn, member)) | ||||||
|     return True |     return True | ||||||
|  |  | ||||||
|  |  | ||||||
| @member_command(cmd="pk;member", subcommand="color", usage="[color]", description="Updates a member's associated color. Leave blank to clear.") | @member_command(cmd="pk;member", subcommand="color", usage="[color]", description="Updates a member's associated color. Leave blank to clear.") | ||||||
| async def member_color(conn, message, member, args): | async def member_color(conn, message, member, args): | ||||||
|     if len(args) == 0: |     if len(args) == 0: | ||||||
| @@ -227,7 +239,8 @@ async def member_color(conn, message, member, args): | |||||||
|     async with conn.transaction(): |     async with conn.transaction(): | ||||||
|         await db.update_member_field(conn, member_id=member["id"], field="color", value=color) |         await db.update_member_field(conn, member_id=member["id"], field="color", value=color) | ||||||
|         return True, "Color updated to #{}.".format(color) if color else "Color cleared." |         return True, "Color updated to #{}.".format(color) if color else "Color cleared." | ||||||
|          |  | ||||||
|  |  | ||||||
| @member_command(cmd="pk;member", subcommand="pronouns", usage="[pronouns]", description="Updates a member's pronouns. Leave blank to clear.") | @member_command(cmd="pk;member", subcommand="pronouns", usage="[pronouns]", description="Updates a member's pronouns. Leave blank to clear.") | ||||||
| async def member_pronouns(conn, message, member, args): | async def member_pronouns(conn, message, member, args): | ||||||
|     if len(args) == 0: |     if len(args) == 0: | ||||||
| @@ -239,12 +252,13 @@ async def member_pronouns(conn, message, member, args): | |||||||
|         await db.update_member_field(conn, member_id=member["id"], field="pronouns", value=pronouns) |         await db.update_member_field(conn, member_id=member["id"], field="pronouns", value=pronouns) | ||||||
|         return True, "Pronouns set to {}".format(pronouns) if pronouns else "Pronouns cleared." |         return True, "Pronouns set to {}".format(pronouns) if pronouns else "Pronouns cleared." | ||||||
|  |  | ||||||
|  |  | ||||||
| @member_command(cmd="pk;member", subcommand="birthdate", usage="[birthdate]", description="Updates a member's birthdate. Must be in ISO-8601 format (eg. 1999-07-25). Leave blank to clear.") | @member_command(cmd="pk;member", subcommand="birthdate", usage="[birthdate]", description="Updates a member's birthdate. Must be in ISO-8601 format (eg. 1999-07-25). Leave blank to clear.") | ||||||
| async def member_birthday(conn, message, member, args): | async def member_birthday(conn, message, member, args): | ||||||
|     if len(args) == 0: |     if len(args) == 0: | ||||||
|         new_date = None |         new_date = None | ||||||
|     else: |     else: | ||||||
|         # Parse date  |         # Parse date | ||||||
|         try: |         try: | ||||||
|             new_date = datetime.strptime(args[0], "%Y-%m-%d").date() |             new_date = datetime.strptime(args[0], "%Y-%m-%d").date() | ||||||
|         except ValueError: |         except ValueError: | ||||||
| @@ -254,6 +268,7 @@ async def member_birthday(conn, message, member, args): | |||||||
|         await db.update_member_field(conn, member_id=member["id"], field="birthday", value=new_date) |         await db.update_member_field(conn, member_id=member["id"], field="birthday", value=new_date) | ||||||
|         return True, "Birthdate set to {}".format(new_date) if new_date else "Birthdate cleared." |         return True, "Birthdate set to {}".format(new_date) if new_date else "Birthdate cleared." | ||||||
|  |  | ||||||
|  |  | ||||||
| @member_command(cmd="pk;member", subcommand="description", description="Updates a member's description. Add \"clear\" to clear.") | @member_command(cmd="pk;member", subcommand="description", description="Updates a member's description. Add \"clear\" to clear.") | ||||||
| async def member_description(conn, message, member, args): | async def member_description(conn, message, member, args): | ||||||
|     if len(args) > 0 and args[0] == "clear": |     if len(args) > 0 and args[0] == "clear": | ||||||
| @@ -267,11 +282,12 @@ async def member_description(conn, message, member, args): | |||||||
|     async with conn.transaction(): |     async with conn.transaction(): | ||||||
|         await db.update_member_field(conn, member_id=member["id"], field="description", value=new_description) |         await db.update_member_field(conn, member_id=member["id"], field="description", value=new_description) | ||||||
|         return True, "Description set." if new_description else "Description cleared." |         return True, "Description set." if new_description else "Description cleared." | ||||||
|          |  | ||||||
|  |  | ||||||
| @member_command(cmd="pk;member", subcommand="remove", description="Removes a member from your system.") | @member_command(cmd="pk;member", subcommand="remove", description="Removes a member from your system.") | ||||||
| async def member_remove(conn, message, member, args): | async def member_remove(conn, message, member, args): | ||||||
|     await client.send_message(message.channel, "Are you sure you want to remove {}? If so, reply to this message with the member's name.".format(member["name"])) |     await client.send_message(message.channel, "Are you sure you want to remove {}? If so, reply to this message with the member's name.".format(member["name"])) | ||||||
|      |  | ||||||
|     msg = await client.wait_for_message(author=message.author, channel=message.channel) |     msg = await client.wait_for_message(author=message.author, channel=message.channel) | ||||||
|     if msg.content == member["name"]: |     if msg.content == member["name"]: | ||||||
|         await db.delete_member(conn, member_id=member["id"]) |         await db.delete_member(conn, member_id=member["id"]) | ||||||
| @@ -279,6 +295,7 @@ async def member_remove(conn, message, member, args): | |||||||
|     else: |     else: | ||||||
|         return True, "Member removal cancelled." |         return True, "Member removal cancelled." | ||||||
|  |  | ||||||
|  |  | ||||||
| @member_command(cmd="pk;member", subcommand="avatar", usage="[user|url]", description="Updates a member's avatar. Can be an account mention (which will use that account's avatar), or a link to an image. Leave blank to clear.") | @member_command(cmd="pk;member", subcommand="avatar", usage="[user|url]", description="Updates a member's avatar. Can be an account mention (which will use that account's avatar), or a link to an image. Leave blank to clear.") | ||||||
| async def member_avatar(conn, message, member, args): | async def member_avatar(conn, message, member, args): | ||||||
|     if len(args) == 0: |     if len(args) == 0: | ||||||
| @@ -296,11 +313,11 @@ async def member_avatar(conn, message, member, args): | |||||||
|                 avatar_url = args[0] |                 avatar_url = args[0] | ||||||
|             else: |             else: | ||||||
|                 return False, "Invalid URL." |                 return False, "Invalid URL." | ||||||
|      |  | ||||||
|     async with conn.transaction(): |     async with conn.transaction(): | ||||||
|         await db.update_member_field(conn, member_id=member["id"], field="avatar_url", value=avatar_url) |         await db.update_member_field(conn, member_id=member["id"], field="avatar_url", value=avatar_url) | ||||||
|         return True, "Avatar set." if avatar_url else "Avatar cleared." |         return True, "Avatar set." if avatar_url else "Avatar cleared." | ||||||
|          |  | ||||||
|  |  | ||||||
| @member_command(cmd="pk;member", subcommand="proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).") | @member_command(cmd="pk;member", subcommand="proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).") | ||||||
| async def member_proxy(conn, message, member, args): | async def member_proxy(conn, message, member, args): | ||||||
| @@ -318,7 +335,8 @@ async def member_proxy(conn, message, member, args): | |||||||
|         # Extract prefix and suffix |         # Extract prefix and suffix | ||||||
|         prefix = example[:example.index("text")].strip() |         prefix = example[:example.index("text")].strip() | ||||||
|         suffix = example[example.index("text")+4:].strip() |         suffix = example[example.index("text")+4:].strip() | ||||||
|         logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix)) |         logger.debug( | ||||||
|  |             "Matched prefix '{}' and suffix '{}'".format(prefix, suffix)) | ||||||
|  |  | ||||||
|         # DB stores empty strings as None, make that work |         # DB stores empty strings as None, make that work | ||||||
|         if not prefix: |         if not prefix: | ||||||
| @@ -331,6 +349,7 @@ async def member_proxy(conn, message, member, args): | |||||||
|         await db.update_member_field(conn, member_id=member["id"], field="suffix", value=suffix) |         await db.update_member_field(conn, member_id=member["id"], field="suffix", value=suffix) | ||||||
|         return True, "Proxy settings updated." if prefix or suffix else "Proxy settings cleared." |         return True, "Proxy settings updated." if prefix or suffix else "Proxy settings cleared." | ||||||
|  |  | ||||||
|  |  | ||||||
| @command(cmd="pk;message", subcommand=None, usage="<id>", description="Shows information about a proxied message. Requires the message ID.") | @command(cmd="pk;message", subcommand=None, usage="<id>", description="Shows information about a proxied message. Requires the message ID.") | ||||||
| async def message_info(conn, message, args): | async def message_info(conn, message, args): | ||||||
|     try: |     try: | ||||||
| @@ -357,14 +376,16 @@ async def message_info(conn, message, args): | |||||||
|     embed = discord.Embed() |     embed = discord.Embed() | ||||||
|     embed.timestamp = message.timestamp |     embed.timestamp = message.timestamp | ||||||
|     embed.colour = discord.Colour.blue() |     embed.colour = discord.Colour.blue() | ||||||
|      |  | ||||||
|     if system["name"]: |     if system["name"]: | ||||||
|         system_value = "`{}`: {}".format(system["hid"], system["name"]) |         system_value = "`{}`: {}".format(system["hid"], system["name"]) | ||||||
|     else: |     else: | ||||||
|         system_value = "`{}`".format(system["hid"]) |         system_value = "`{}`".format(system["hid"]) | ||||||
|     embed.add_field(name="System", value=system_value) |     embed.add_field(name="System", value=system_value) | ||||||
|     embed.add_field(name="Member", value="`{}`: {}".format(member["hid"], member["name"])) |     embed.add_field(name="Member", value="`{}`: {}".format( | ||||||
|     embed.add_field(name="Sent by", value="{}#{}".format(original_sender.name, original_sender.discriminator)) |         member["hid"], member["name"])) | ||||||
|  |     embed.add_field(name="Sent by", value="{}#{}".format( | ||||||
|  |         original_sender.name, original_sender.discriminator)) | ||||||
|     embed.add_field(name="Content", value=message.clean_content, inline=False) |     embed.add_field(name="Content", value=message.clean_content, inline=False) | ||||||
|  |  | ||||||
|     embed.set_author(name=member["name"], url=member["avatar_url"]) |     embed.set_author(name=member["name"], url=member["avatar_url"]) | ||||||
| @@ -372,12 +393,14 @@ async def message_info(conn, message, args): | |||||||
|     await client.send_message(message.channel, embed=embed) |     await client.send_message(message.channel, embed=embed) | ||||||
|     return True |     return True | ||||||
|  |  | ||||||
|  |  | ||||||
| @command(cmd="pk;help", subcommand=None, usage="[system|member|message]", description="Shows this help message.") | @command(cmd="pk;help", subcommand=None, usage="[system|member|message]", description="Shows this help message.") | ||||||
| async def show_help(conn, message, args): | async def show_help(conn, message, args): | ||||||
|     embed = discord.Embed() |     embed = discord.Embed() | ||||||
|     embed.colour = discord.Colour.blue() |     embed.colour = discord.Colour.blue() | ||||||
|     embed.title = "PluralKit Help" |     embed.title = "PluralKit Help" | ||||||
|     embed.set_footer(text="<> denotes mandatory arguments, [] denotes optional arguments") |     embed.set_footer( | ||||||
|  |         text="<> denotes mandatory arguments, [] denotes optional arguments") | ||||||
|  |  | ||||||
|     if len(args) > 0 and ("pk;" + args[0]) in command_map: |     if len(args) > 0 and ("pk;" + args[0]) in command_map: | ||||||
|         cmds = ["", ("pk;" + args[0], command_map["pk;" + args[0]])] |         cmds = ["", ("pk;" + args[0], command_map["pk;" + args[0]])] | ||||||
| @@ -386,7 +409,8 @@ async def show_help(conn, message, args): | |||||||
|  |  | ||||||
|     for cmd, subcommands in cmds: |     for cmd, subcommands in cmds: | ||||||
|         for subcmd, (_, usage, description) in subcommands.items(): |         for subcmd, (_, usage, description) in subcommands.items(): | ||||||
|             embed.add_field(name="{} {} {}".format(cmd, subcmd or "", usage or ""), value=description, inline=False) |             embed.add_field(name="{} {} {}".format( | ||||||
|      |                 cmd, subcmd or "", usage or ""), value=description, inline=False) | ||||||
|  |  | ||||||
|     await client.send_message(message.channel, embed=embed) |     await client.send_message(message.channel, embed=embed) | ||||||
|     return True |     return True | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ import asyncpg.exceptions | |||||||
|  |  | ||||||
| from pluralkit.bot import logger | from pluralkit.bot import logger | ||||||
|  |  | ||||||
|  |  | ||||||
| async def connect(): | async def connect(): | ||||||
|     while True: |     while True: | ||||||
|         try: |         try: | ||||||
| @@ -12,32 +13,40 @@ async def connect(): | |||||||
|         except (ConnectionError, asyncpg.exceptions.CannotConnectNowError): |         except (ConnectionError, asyncpg.exceptions.CannotConnectNowError): | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|  |  | ||||||
| def db_wrap(func): | def db_wrap(func): | ||||||
|     async def inner(*args, **kwargs): |     async def inner(*args, **kwargs): | ||||||
|         before = time.perf_counter() |         before = time.perf_counter() | ||||||
|         res = await func(*args, **kwargs) |         res = await func(*args, **kwargs) | ||||||
|         after = time.perf_counter() |         after = time.perf_counter() | ||||||
|          |  | ||||||
|         logger.debug(" - DB took {:.2f} ms".format((after - before) * 1000)) |         logger.debug(" - DB took {:.2f} ms".format((after - before) * 1000)) | ||||||
|         return res |         return res | ||||||
|     return inner |     return inner | ||||||
|  |  | ||||||
|  |  | ||||||
| webhook_cache = {} | webhook_cache = {} | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def create_system(conn, system_name: str, system_hid: str): | async def create_system(conn, system_name: str, system_hid: str): | ||||||
|     logger.debug("Creating system (name={}, hid={})".format(system_name, system_hid)) |     logger.debug("Creating system (name={}, hid={})".format( | ||||||
|  |         system_name, system_hid)) | ||||||
|     return await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid) |     return await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def remove_system(conn, system_id: int): | async def remove_system(conn, system_id: int): | ||||||
|     logger.debug("Deleting system (id={})".format(system_id)) |     logger.debug("Deleting system (id={})".format(system_id)) | ||||||
|     await conn.execute("delete from systems where id = $1", system_id) |     await conn.execute("delete from systems where id = $1", system_id) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def create_member(conn, system_id: int, member_name: str, member_hid: str): | async def create_member(conn, system_id: int, member_name: str, member_hid: str): | ||||||
|     logger.debug("Creating member (system={}, name={}, hid={})".format(system_id, member_name, member_hid)) |     logger.debug("Creating member (system={}, name={}, hid={})".format( | ||||||
|     return await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid)  |         system_id, member_name, member_hid)) | ||||||
|  |     return await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def delete_member(conn, member_id: int): | async def delete_member(conn, member_id: int): | ||||||
| @@ -45,70 +54,90 @@ async def delete_member(conn, member_id: int): | |||||||
|     await conn.execute("update switches set member = null, member_del = true where member = $1", member_id) |     await conn.execute("update switches set member = null, member_del = true where member = $1", member_id) | ||||||
|     await conn.execute("delete from members where id = $1", member_id) |     await conn.execute("delete from members where id = $1", member_id) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def link_account(conn, system_id: int, account_id: str): | async def link_account(conn, system_id: int, account_id: str): | ||||||
|     logger.debug("Linking account (account_id={}, system_id={})".format(account_id, system_id)) |     logger.debug("Linking account (account_id={}, system_id={})".format( | ||||||
|  |         account_id, system_id)) | ||||||
|     await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id) |     await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def unlink_account(conn, system_id: int, account_id: str): | async def unlink_account(conn, system_id: int, account_id: str): | ||||||
|     logger.debug("Unlinking account (account_id={}, system_id={})".format(account_id, system_id)) |     logger.debug("Unlinking account (account_id={}, system_id={})".format( | ||||||
|  |         account_id, system_id)) | ||||||
|     await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id) |     await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_linked_accounts(conn, system_id: int): | async def get_linked_accounts(conn, system_id: int): | ||||||
|     return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)] |     return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)] | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_system_by_account(conn, account_id: str): | async def get_system_by_account(conn, account_id: str): | ||||||
|     return await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id)) |     return await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_system_by_hid(conn, system_hid: str): | async def get_system_by_hid(conn, system_hid: str): | ||||||
|     return await conn.fetchrow("select * from systems where hid = $1", system_hid) |     return await conn.fetchrow("select * from systems where hid = $1", system_hid) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_system(conn, system_id: int): | async def get_system(conn, system_id: int): | ||||||
|     return await conn.fetchrow("select * from systems where id = $1", system_id) |     return await conn.fetchrow("select * from systems where id = $1", system_id) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_member_by_name(conn, system_id: int, member_name: str): | async def get_member_by_name(conn, system_id: int, member_name: str): | ||||||
|     return await conn.fetchrow("select * from members where system = $1 and name = $2", system_id, member_name) |     return await conn.fetchrow("select * from members where system = $1 and name = $2", system_id, member_name) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str): | async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str): | ||||||
|     return await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid) |     return await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_member_by_hid(conn, member_hid: str): | async def get_member_by_hid(conn, member_hid: str): | ||||||
|     return await conn.fetchrow("select * from members where hid = $1", member_hid) |     return await conn.fetchrow("select * from members where hid = $1", member_hid) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_member(conn, member_id: int): | async def get_member(conn, member_id: int): | ||||||
|     return await conn.fetchrow("select * from members where id = $1", member_id) |     return await conn.fetchrow("select * from members where id = $1", member_id) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_message(conn, message_id: str): | async def get_message(conn, message_id: str): | ||||||
|     return await conn.fetchrow("select * from messages where mid = $1", message_id) |     return await conn.fetchrow("select * from messages where mid = $1", message_id) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def update_system_field(conn, system_id: int, field: str, value): | async def update_system_field(conn, system_id: int, field: str, value): | ||||||
|     logger.debug("Updating system field (id={}, {}={})".format(system_id, field, value)) |     logger.debug("Updating system field (id={}, {}={})".format( | ||||||
|  |         system_id, field, value)) | ||||||
|     await conn.execute("update systems set {} = $1 where id = $2".format(field), value, system_id) |     await conn.execute("update systems set {} = $1 where id = $2".format(field), value, system_id) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def update_member_field(conn, member_id: int, field: str, value): | async def update_member_field(conn, member_id: int, field: str, value): | ||||||
|     logger.debug("Updating member field (id={}, {}={})".format(member_id, field, value)) |     logger.debug("Updating member field (id={}, {}={})".format( | ||||||
|  |         member_id, field, value)) | ||||||
|     await conn.execute("update members set {} = $1 where id = $2".format(field), value, member_id) |     await conn.execute("update members set {} = $1 where id = $2".format(field), value, member_id) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_all_members(conn, system_id: int): | async def get_all_members(conn, system_id: int): | ||||||
|     return await conn.fetch("select * from members where system = $1", system_id) |     return await conn.fetch("select * from members where system = $1", system_id) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_members_exceeding(conn, system_id: int, length: int): | async def get_members_exceeding(conn, system_id: int, length: int): | ||||||
|     return await conn.fetch("select * from members where system = $1 and length(name) >= $2", system_id, length) |     return await conn.fetch("select * from members where system = $1 and length(name) >= $2", system_id, length) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_webhook(conn, channel_id: str): | async def get_webhook(conn, channel_id: str): | ||||||
|     if channel_id in webhook_cache: |     if channel_id in webhook_cache: | ||||||
| @@ -117,30 +146,38 @@ async def get_webhook(conn, channel_id: str): | |||||||
|     webhook_cache[channel_id] = res |     webhook_cache[channel_id] = res | ||||||
|     return res |     return res | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str): | async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str): | ||||||
|     logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(channel_id, webhook_id, webhook_token)) |     logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format( | ||||||
|  |         channel_id, webhook_id, webhook_token)) | ||||||
|     await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token) |     await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str): | async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str): | ||||||
|     logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(message_id, channel_id, member_id, sender_id)) |     logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format( | ||||||
|  |         message_id, channel_id, member_id, sender_id)) | ||||||
|     await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id)) |     await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_members_by_account(conn, account_id: str): | async def get_members_by_account(conn, account_id: str): | ||||||
|     # Returns a "chimera" object |     # Returns a "chimera" object | ||||||
|     return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id)) |     return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str): | async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str): | ||||||
|     await conn.fetchrow("select * from messages where mid = $1 and sender = $2", int(message_id), int(sender_id)) |     await conn.fetchrow("select * from messages where mid = $1 and sender = $2", int(message_id), int(sender_id)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_wrap | @db_wrap | ||||||
| async def delete_message(conn, message_id: str): | async def delete_message(conn, message_id: str): | ||||||
|     logger.debug("Deleting message (id={})".format(message_id)) |     logger.debug("Deleting message (id={})".format(message_id)) | ||||||
|     await conn.execute("delete from messages where mid = $1", int(message_id)) |     await conn.execute("delete from messages where mid = $1", int(message_id)) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def create_tables(conn): | async def create_tables(conn): | ||||||
|     await conn.execute("""create table if not exists systems ( |     await conn.execute("""create table if not exists systems ( | ||||||
|         id          serial primary key, |         id          serial primary key, | ||||||
| @@ -190,4 +227,4 @@ async def create_tables(conn): | |||||||
|         id          bigint primary key, |         id          bigint primary key, | ||||||
|         cmd_chans   bigint[], |         cmd_chans   bigint[], | ||||||
|         proxy_chans bigint[] |         proxy_chans bigint[] | ||||||
|     )""") |     )""") | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ import aiohttp | |||||||
| from pluralkit import db | from pluralkit import db | ||||||
| from pluralkit.bot import client, logger | from pluralkit.bot import client, logger | ||||||
|  |  | ||||||
|  |  | ||||||
| async def get_webhook(conn, channel): | async def get_webhook(conn, channel): | ||||||
|     async with conn.transaction(): |     async with conn.transaction(): | ||||||
|         # Try to find an existing webhook |         # Try to find an existing webhook | ||||||
| @@ -14,7 +15,8 @@ async def get_webhook(conn, channel): | |||||||
|         if not hook_row: |         if not hook_row: | ||||||
|             async with aiohttp.ClientSession() as session: |             async with aiohttp.ClientSession() as session: | ||||||
|                 req_data = {"name": "PluralKit Proxy Webhook"} |                 req_data = {"name": "PluralKit Proxy Webhook"} | ||||||
|                 req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])} |                 req_headers = { | ||||||
|  |                     "Authorization": "Bot {}".format(os.environ["TOKEN"])} | ||||||
|  |  | ||||||
|                 async with session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id), json=req_data, headers=req_headers) as resp: |                 async with session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id), json=req_data, headers=req_headers) as resp: | ||||||
|                     data = await resp.json() |                     data = await resp.json() | ||||||
| @@ -24,17 +26,19 @@ async def get_webhook(conn, channel): | |||||||
|                     # Insert new hook into DB |                     # Insert new hook into DB | ||||||
|                     await db.add_webhook(conn, channel_id=channel.id, webhook_id=hook_id, webhook_token=token) |                     await db.add_webhook(conn, channel_id=channel.id, webhook_id=hook_id, webhook_token=token) | ||||||
|                     return hook_id, token |                     return hook_id, token | ||||||
|                      |  | ||||||
|         return hook_row["webhook"], hook_row["token"] |         return hook_row["webhook"], hook_row["token"] | ||||||
|  |  | ||||||
|  |  | ||||||
| async def proxy_message(conn, member, message, inner): | async def proxy_message(conn, member, message, inner): | ||||||
|     logger.debug("Proxying message '{}' for member {}".format(inner, member["hid"])) |     logger.debug("Proxying message '{}' for member {}".format( | ||||||
|  |         inner, member["hid"])) | ||||||
|     # Delete the original message |     # Delete the original message | ||||||
|     await client.delete_message(message) |     await client.delete_message(message) | ||||||
|  |  | ||||||
|     # Get the webhook details |     # Get the webhook details | ||||||
|     hook_id, hook_token = await get_webhook(conn, message.channel) |     hook_id, hook_token = await get_webhook(conn, message.channel) | ||||||
|     async with aiohttp.ClientSession() as session:         |     async with aiohttp.ClientSession() as session: | ||||||
|         req_data = { |         req_data = { | ||||||
|             "username": "{} {}".format(member["name"], member["tag"] or "").strip(), |             "username": "{} {}".format(member["name"], member["tag"] or "").strip(), | ||||||
|             "avatar_url": member["avatar_url"], |             "avatar_url": member["avatar_url"], | ||||||
| @@ -49,13 +53,15 @@ async def proxy_message(conn, member, message, inner): | |||||||
|             # Insert new message details into the DB |             # Insert new message details into the DB | ||||||
|             await db.add_message(conn, message_id=resp_data["id"], channel_id=message.channel.id, member_id=member["id"], sender_id=message.author.id) |             await db.add_message(conn, message_id=resp_data["id"], channel_id=message.channel.id, member_id=member["id"], sender_id=message.author.id) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def handle_proxying(conn, message): | async def handle_proxying(conn, message): | ||||||
|     # Big fat query to find every member associated with this account |     # Big fat query to find every member associated with this account | ||||||
|     # Returned member object has a few more keys (system tag, for example) |     # Returned member object has a few more keys (system tag, for example) | ||||||
|     members = await db.get_members_by_account(conn, account_id=message.author.id) |     members = await db.get_members_by_account(conn, account_id=message.author.id) | ||||||
|  |  | ||||||
|     # Sort by specificity (members with both prefix and suffix go higher) |     # Sort by specificity (members with both prefix and suffix go higher) | ||||||
|     members = sorted(members, key=lambda x: int(bool(x["prefix"])) + int(bool(x["suffix"])), reverse=True) |     members = sorted(members, key=lambda x: int( | ||||||
|  |         bool(x["prefix"])) + int(bool(x["suffix"])), reverse=True) | ||||||
|  |  | ||||||
|     msg = message.content |     msg = message.content | ||||||
|     for member in members: |     for member in members: | ||||||
| @@ -71,14 +77,14 @@ async def handle_proxying(conn, message): | |||||||
|         if msg.startswith(prefix) and msg.endswith(suffix): |         if msg.startswith(prefix) and msg.endswith(suffix): | ||||||
|             # Extract the actual message contents sans tags |             # Extract the actual message contents sans tags | ||||||
|             if suffix: |             if suffix: | ||||||
|                 inner_message = message.content[len(prefix):-len(suffix)].strip() |                 inner_message = message.content[len( | ||||||
|  |                     prefix):-len(suffix)].strip() | ||||||
|             else: |             else: | ||||||
|                 # Slicing to -0 breaks, don't do that |                 # Slicing to -0 breaks, don't do that | ||||||
|                 inner_message = message.content[len(prefix):].strip() |                 inner_message = message.content[len(prefix):].strip() | ||||||
|  |  | ||||||
|             await proxy_message(conn, member, message, inner_message) |             await proxy_message(conn, member, message, inner_message) | ||||||
|             break |             break | ||||||
|      |  | ||||||
|  |  | ||||||
|  |  | ||||||
| async def handle_reaction(conn, reaction, user): | async def handle_reaction(conn, reaction, user): | ||||||
| @@ -90,4 +96,4 @@ async def handle_reaction(conn, reaction, user): | |||||||
|             if message: |             if message: | ||||||
|                 # If so, delete the message and remove it from the DB |                 # If so, delete the message and remove it from the DB | ||||||
|                 await db.delete_message(conn, message["mid"]) |                 await db.delete_message(conn, message["mid"]) | ||||||
|                 await client.delete_message(reaction.message) |                 await client.delete_message(reaction.message) | ||||||
|   | |||||||
| @@ -9,9 +9,11 @@ import discord | |||||||
| from pluralkit import db | from pluralkit import db | ||||||
| from pluralkit.bot import client, logger | from pluralkit.bot import client, logger | ||||||
|  |  | ||||||
|  |  | ||||||
| def generate_hid() -> str: | def generate_hid() -> str: | ||||||
|     return "".join(random.choices(string.ascii_lowercase, k=5)) |     return "".join(random.choices(string.ascii_lowercase, k=5)) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def parse_mention(mention: str) -> discord.User: | async def parse_mention(mention: str) -> discord.User: | ||||||
|     # First try matching mention format |     # First try matching mention format | ||||||
|     match = re.fullmatch("<@!?(\\d+)>", mention) |     match = re.fullmatch("<@!?(\\d+)>", mention) | ||||||
| @@ -27,10 +29,11 @@ async def parse_mention(mention: str) -> discord.User: | |||||||
|     except (ValueError, discord.NotFound): |     except (ValueError, discord.NotFound): | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
| async def get_system_fuzzy(conn, key) -> asyncpg.Record: | async def get_system_fuzzy(conn, key) -> asyncpg.Record: | ||||||
|     if isinstance(key, discord.User): |     if isinstance(key, discord.User): | ||||||
|         return await db.get_system_by_account(conn, account_id=key.id) |         return await db.get_system_by_account(conn, account_id=key.id) | ||||||
|      |  | ||||||
|     if isinstance(key, str) and len(key) == 5: |     if isinstance(key, str) and len(key) == 5: | ||||||
|         return await db.get_system_by_hid(conn, system_hid=key) |         return await db.get_system_by_hid(conn, system_hid=key) | ||||||
|  |  | ||||||
| @@ -39,7 +42,8 @@ async def get_system_fuzzy(conn, key) -> asyncpg.Record: | |||||||
|     if system: |     if system: | ||||||
|         return system |         return system | ||||||
|     return None |     return None | ||||||
|      |  | ||||||
|  |  | ||||||
| async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> asyncpg.Record: | async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> asyncpg.Record: | ||||||
|     # First search by hid |     # First search by hid | ||||||
|     if system_only: |     if system_only: | ||||||
| @@ -60,6 +64,8 @@ command_map = {} | |||||||
| # Command wrapper | # Command wrapper | ||||||
| # Return True for success, return False for failure | # Return True for success, return False for failure | ||||||
| # Second parameter is the message it'll send. If just False, will print usage | # Second parameter is the message it'll send. If just False, will print usage | ||||||
|  |  | ||||||
|  |  | ||||||
| def command(cmd, subcommand, usage=None, description=None): | def command(cmd, subcommand, usage=None, description=None): | ||||||
|     def wrap(func): |     def wrap(func): | ||||||
|         async def wrapper(conn, message, args): |         async def wrapper(conn, message, args): | ||||||
| @@ -70,12 +76,13 @@ def command(cmd, subcommand, usage=None, description=None): | |||||||
|                     success, msg = res, None |                     success, msg = res, None | ||||||
|                 else: |                 else: | ||||||
|                     success, msg = res |                     success, msg = res | ||||||
|                      |  | ||||||
|                 if not success and not msg: |                 if not success and not msg: | ||||||
|                     # Failure, no message, print usage |                     # Failure, no message, print usage | ||||||
|                     usage_embed = discord.Embed() |                     usage_embed = discord.Embed() | ||||||
|                     usage_embed.colour = discord.Colour.blue() |                     usage_embed.colour = discord.Colour.blue() | ||||||
|                     usage_embed.add_field(name="Usage", value=usage, inline=False) |                     usage_embed.add_field( | ||||||
|  |                         name="Usage", value=usage, inline=False) | ||||||
|  |  | ||||||
|                     await client.send_message(message.channel, embed=usage_embed) |                     await client.send_message(message.channel, embed=usage_embed) | ||||||
|                 elif not success: |                 elif not success: | ||||||
| @@ -103,6 +110,8 @@ def command(cmd, subcommand, usage=None, description=None): | |||||||
| # Member command wrapper | # Member command wrapper | ||||||
| # Tries to find member by first argument | # Tries to find member by first argument | ||||||
| # If system_only=False, allows members from other systems by hid | # If system_only=False, allows members from other systems by hid | ||||||
|  |  | ||||||
|  |  | ||||||
| def member_command(cmd, subcommand, usage=None, description=None, system_only=True): | def member_command(cmd, subcommand, usage=None, description=None, system_only=True): | ||||||
|     def wrap(func): |     def wrap(func): | ||||||
|         async def wrapper(conn, message, args): |         async def wrapper(conn, message, args): | ||||||
| @@ -122,11 +131,12 @@ def member_command(cmd, subcommand, usage=None, description=None, system_only=Tr | |||||||
|  |  | ||||||
|             if member is None: |             if member is None: | ||||||
|                 return False, "Can't find member \"{}\".".format(args[0]) |                 return False, "Can't find member \"{}\".".format(args[0]) | ||||||
|              |  | ||||||
|             return await func(conn, message, member, args[1:]) |             return await func(conn, message, member, args[1:]) | ||||||
|         return command(cmd=cmd, subcommand=subcommand, usage=usage, description=description)(wrapper) |         return command(cmd=cmd, subcommand=subcommand, usage=usage, description=description)(wrapper) | ||||||
|     return wrap |     return wrap | ||||||
|  |  | ||||||
|  |  | ||||||
| async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Embed: | async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Embed: | ||||||
|     card = discord.Embed() |     card = discord.Embed() | ||||||
|  |  | ||||||
| @@ -134,7 +144,8 @@ async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Emb | |||||||
|         card.title = system["name"] |         card.title = system["name"] | ||||||
|  |  | ||||||
|     if system["description"]: |     if system["description"]: | ||||||
|         card.add_field(name="Description", value=system["description"], inline=False) |         card.add_field(name="Description", | ||||||
|  |                        value=system["description"], inline=False) | ||||||
|  |  | ||||||
|     if system["tag"]: |     if system["tag"]: | ||||||
|         card.add_field(name="Tag", value=system["tag"]) |         card.add_field(name="Tag", value=system["tag"]) | ||||||
| @@ -158,11 +169,13 @@ async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Emb | |||||||
|         member_texts.append("`{}`: {}".format(member["hid"], member["name"])) |         member_texts.append("`{}`: {}".format(member["hid"], member["name"])) | ||||||
|  |  | ||||||
|     if len(member_texts) > 0: |     if len(member_texts) > 0: | ||||||
|         card.add_field(name="Members", value="\n".join(member_texts), inline=False) |         card.add_field(name="Members", value="\n".join( | ||||||
|  |             member_texts), inline=False) | ||||||
|  |  | ||||||
|     card.set_footer(text="System ID: {}".format(system["hid"])) |     card.set_footer(text="System ID: {}".format(system["hid"])) | ||||||
|     return card |     return card | ||||||
|  |  | ||||||
|  |  | ||||||
| async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Embed: | async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Embed: | ||||||
|     card = discord.Embed() |     card = discord.Embed() | ||||||
|     card.set_author(name=member["name"], icon_url=member["avatar_url"]) |     card.set_author(name=member["name"], icon_url=member["avatar_url"]) | ||||||
| @@ -171,18 +184,21 @@ async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Emb | |||||||
|         card.colour = int(member["color"], 16) |         card.colour = int(member["color"], 16) | ||||||
|  |  | ||||||
|     if member["birthday"]: |     if member["birthday"]: | ||||||
|         card.add_field(name="Birthdate", value=member["birthday"].strftime("%b %d, %Y")) |         card.add_field(name="Birthdate", | ||||||
|      |                        value=member["birthday"].strftime("%b %d, %Y")) | ||||||
|  |  | ||||||
|     if member["pronouns"]: |     if member["pronouns"]: | ||||||
|         card.add_field(name="Pronouns", value=member["pronouns"]) |         card.add_field(name="Pronouns", value=member["pronouns"]) | ||||||
|  |  | ||||||
|     if member["prefix"] or member["suffix"]: |     if member["prefix"] or member["suffix"]: | ||||||
|         prefix = member["prefix"] or "" |         prefix = member["prefix"] or "" | ||||||
|         suffix = member["suffix"] or "" |         suffix = member["suffix"] or "" | ||||||
|         card.add_field(name="Proxy Tags", value="{}text{}".format(prefix, suffix)) |         card.add_field(name="Proxy Tags", | ||||||
|  |                        value="{}text{}".format(prefix, suffix)) | ||||||
|  |  | ||||||
|     if member["description"]: |     if member["description"]: | ||||||
|         card.add_field(name="Description", value=member["description"], inline=False) |         card.add_field(name="Description", | ||||||
|  |                        value=member["description"], inline=False) | ||||||
|  |  | ||||||
|     # Get system name and hid |     # Get system name and hid | ||||||
|     system = await db.get_system(conn, system_id=member["system"]) |     system = await db.get_system(conn, system_id=member["system"]) | ||||||
| @@ -192,9 +208,11 @@ async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Emb | |||||||
|         system_value = "`{}`".format(system["hid"]) |         system_value = "`{}`".format(system["hid"]) | ||||||
|     card.add_field(name="System", value=system_value, inline=False) |     card.add_field(name="System", value=system_value, inline=False) | ||||||
|  |  | ||||||
|     card.set_footer(text="System ID: {} | Member ID: {}".format(system["hid"], member["hid"])) |     card.set_footer(text="System ID: {} | Member ID: {}".format( | ||||||
|  |         system["hid"], member["hid"])) | ||||||
|     return card |     return card | ||||||
|  |  | ||||||
|  |  | ||||||
| async def text_input(message, subject): | async def text_input(message, subject): | ||||||
|     await client.send_message(message.channel, "Reply in this channel with the new description you want to set for {}.".format(subject)) |     await client.send_message(message.channel, "Reply in this channel with the new description you want to set for {}.".format(subject)) | ||||||
|     msg = await client.wait_for_message(author=message.author, channel=message.channel) |     msg = await client.wait_for_message(author=message.author, channel=message.channel) | ||||||
| @@ -208,4 +226,4 @@ async def text_input(message, subject): | |||||||
|         return msg.content |         return msg.content | ||||||
|     else: |     else: | ||||||
|         await client.clear_reactions(msg) |         await client.clear_reactions(msg) | ||||||
|         return None |         return None | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user