Refactor command handling
This commit is contained in:
		| @@ -39,26 +39,23 @@ async def on_message(message): | ||||
|     args = message.content.split(" ") | ||||
|  | ||||
|     from pluralkit import proxy, utils | ||||
|      | ||||
|     command_items = utils.command_map.items() | ||||
|     command_items = sorted(command_items, key=lambda x: len(x[0]), reverse=True) | ||||
|  | ||||
|     cmd = None | ||||
|     # Look up commands with subcommands | ||||
|     if len(args) >= 2: | ||||
|         lookup = utils.command_map.get((args[0], args[1]), None) | ||||
|         if lookup: | ||||
|             # Curry with arg slice | ||||
|             cmd = lambda c, m, a: lookup[0](conn, message, args[2:]) | ||||
|     # Look up root commands | ||||
|     if not cmd and len(args) >= 1: | ||||
|         lookup = utils.command_map.get((args[0], None), None) | ||||
|         if lookup: | ||||
|             # Curry with arg slice | ||||
|             cmd = lambda c, m, a: lookup[0](conn, message, args[1:]) | ||||
|     prefix = "pk;" | ||||
|     for command, (func, _, _) in command_items: | ||||
|         if message.content.startswith(prefix + command): | ||||
|             args_str = message.content[len(prefix + command):].strip() | ||||
|             args = args_str.split(" ") | ||||
|              | ||||
|             # Splitting on empty string yields one-element array, remove that | ||||
|             if len(args) == 1 and not args[0]: | ||||
|                 args = [] | ||||
|  | ||||
|     # Found anything? run it | ||||
|     if cmd: | ||||
|         async with client.pool.acquire() as conn: | ||||
|             await cmd(conn, message, args) | ||||
|             return | ||||
|             async with client.pool.acquire() as conn: | ||||
|                 await func(conn, message, args) | ||||
|                 return | ||||
|  | ||||
|     # Try doing proxy parsing | ||||
|     async with client.pool.acquire() as conn: | ||||
|   | ||||
| @@ -9,18 +9,7 @@ from pluralkit import db | ||||
| from pluralkit.bot import client, logger | ||||
| from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed, parse_channel_mention | ||||
|  | ||||
| @command(cmd="pk;system", subcommand=None, description="Shows information about your system.") | ||||
| async def this_system_info(conn, message, args): | ||||
|     system = await db.get_system_by_account(conn, message.author.id) | ||||
|  | ||||
|     if system is None: | ||||
|         return False, "No system is registered to this account." | ||||
|  | ||||
|     await client.send_message(message.channel, embed=await generate_system_info_card(conn, system)) | ||||
|     return True | ||||
|  | ||||
|  | ||||
| @command(cmd="pk;system", subcommand="new", usage="[name]", description="Registers a new system to this account.", basic=True) | ||||
| @command(cmd="system new", usage="[name]", description="Registers a new system to this account.") | ||||
| async def new_system(conn, message, args): | ||||
|     system = await db.get_system_by_account(conn, message.author.id) | ||||
|  | ||||
| @@ -42,7 +31,7 @@ async def new_system(conn, message, args): | ||||
|         return True, "System registered! To begin adding members, use `pk;member new <name>`." | ||||
|  | ||||
|  | ||||
| @command(cmd="pk;system", subcommand="info", usage="[system]", description="Shows information about a system.", basic=True) | ||||
| @command(cmd="system", usage="[system]", description="Shows information about a system.") | ||||
| async def system_info(conn, message, args): | ||||
|     if len(args) == 0: | ||||
|         # Use sender's system | ||||
| @@ -61,7 +50,7 @@ async def system_info(conn, message, args): | ||||
|     return True | ||||
|  | ||||
|  | ||||
| @command(cmd="pk;system", subcommand="name", usage="[name]", description="Renames your system. Leave blank to clear.") | ||||
| @command(cmd="system name", usage="[name]", description="Renames your system. Leave blank to clear.") | ||||
| async def system_name(conn, message, args): | ||||
|     system = await db.get_system_by_account(conn, message.author.id) | ||||
|  | ||||
| @@ -78,7 +67,7 @@ async def system_name(conn, message, args): | ||||
|         return True, "Name updated to {}.".format(new_name) if new_name else "Name cleared." | ||||
|  | ||||
|  | ||||
| @command(cmd="pk;system", subcommand="description", usage="[clear]", description="Updates your system description. Add \"clear\" to clear.") | ||||
| @command(cmd="system description", usage="[clear]", description="Updates your system description. Add \"clear\" to clear.") | ||||
| async def system_description(conn, message, args): | ||||
|     system = await db.get_system_by_account(conn, message.author.id) | ||||
|  | ||||
| @@ -99,7 +88,7 @@ async def system_description(conn, message, args): | ||||
|         return True, "Description set." if new_description else "Description cleared." | ||||
|  | ||||
|  | ||||
| @command(cmd="pk;system", subcommand="tag", usage="[tag]", description="Updates your system tag. Leave blank to clear.") | ||||
| @command(cmd="system tag", usage="[tag]", description="Updates your system tag. Leave blank to clear.") | ||||
| async def system_tag(conn, message, args): | ||||
|     system = await db.get_system_by_account(conn, message.author.id) | ||||
|  | ||||
| @@ -128,24 +117,24 @@ async def system_tag(conn, message, args): | ||||
|     return True, "Tag updated to {}.".format(tag) if tag else "Tag cleared." | ||||
|  | ||||
|  | ||||
| @command(cmd="pk;system", subcommand="remove", description="Removes your system ***permanently***.") | ||||
| async def system_remove(conn, message, args): | ||||
| @command(cmd="system delete", description="Deletes your system from the database ***permanently***.") | ||||
| async def system_delete(conn, message, args): | ||||
|     system = await db.get_system_by_account(conn, message.author.id) | ||||
|  | ||||
|     if system is None: | ||||
|         return False, "No system is registered to this account." | ||||
|  | ||||
|     await client.send_message(message.channel, "Are you sure you want to remove your system? If so, reply to this message with the system's ID (`{}`).".format(system["hid"])) | ||||
|     await client.send_message(message.channel, "Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(system["hid"])) | ||||
|  | ||||
|     msg = await client.wait_for_message(author=message.author, channel=message.channel) | ||||
|     if msg.content == system["hid"]: | ||||
|         await db.remove_system(conn, system_id=system["id"]) | ||||
|         return True, "System removed." | ||||
|         return True, "System deleted." | ||||
|     else: | ||||
|         return True, "System removal cancelled." | ||||
|         return True, "System deletion cancelled." | ||||
|  | ||||
|  | ||||
| @command(cmd="pk;system", subcommand="link", usage="<account>", description="Links another account to your system.") | ||||
| @command(cmd="system link", usage="<account>", description="Links another account to your system.") | ||||
| async def system_link(conn, message, args): | ||||
|     system = await db.get_system_by_account(conn, message.author.id) | ||||
|  | ||||
| @@ -182,7 +171,7 @@ async def system_link(conn, message, args): | ||||
|         return False, "Account link cancelled." | ||||
|  | ||||
|  | ||||
| @command(cmd="pk;system", subcommand="unlink", description="Unlinks your system from this account. There must be at least one other account linked.") | ||||
| @command(cmd="system unlink", description="Unlinks your system from this account. There must be at least one other account linked.") | ||||
| async def system_unlink(conn, message, args): | ||||
|     system = await db.get_system_by_account(conn, message.author.id) | ||||
|  | ||||
| @@ -198,7 +187,7 @@ async def system_unlink(conn, message, args): | ||||
|         await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id) | ||||
|         return True, "Account unlinked." | ||||
|  | ||||
| @command(cmd="pk;system", subcommand="fronter", usage="[system]", description="Gets the current fronter in the system.") | ||||
| @command(cmd="system fronter", usage="[system]", description="Gets the current fronter in the system.") | ||||
| async def system_fronter(conn, message, args): | ||||
|     if len(args) == 0: | ||||
|         system = await db.get_system_by_account(conn, message.author.id) | ||||
| @@ -229,7 +218,7 @@ async def system_fronter(conn, message, args): | ||||
|     embed.add_field(name="Since", value="{} ({})".format(since.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(since))) | ||||
|     return True, embed | ||||
|  | ||||
| @command(cmd="pk;system", subcommand="fronthistory", usage="[system]", description="Shows the past 10 switches in the system.") | ||||
| @command(cmd="system fronthistory", usage="[system]", description="Shows the past 10 switches in the system.") | ||||
| async def system_fronthistory(conn, message, args): | ||||
|     if len(args) == 0: | ||||
|         system = await db.get_system_by_account(conn, message.author.id) | ||||
| @@ -256,7 +245,7 @@ async def system_fronthistory(conn, message, args): | ||||
|     embed.title = "Past switches" | ||||
|     return True, embed | ||||
|  | ||||
| @command(cmd="pk;member", subcommand="new", usage="<name>", description="Adds a new member to your system.", basic=True) | ||||
| @command(cmd="member new", usage="<name>", description="Adds a new member to your system.") | ||||
| async def new_member(conn, message, args): | ||||
|     system = await db.get_system_by_account(conn, message.author.id) | ||||
|  | ||||
| @@ -276,13 +265,13 @@ async def new_member(conn, message, args): | ||||
|         return True, "Member \"{}\" (`{}`) registered!".format(name, hid) | ||||
|  | ||||
|  | ||||
| @member_command(cmd="pk;member", subcommand="info", description="Shows information about a system member.", system_only=False, basic=True) | ||||
| @member_command(cmd="member info", description="Shows information about a system member.", system_only=False) | ||||
| async def member_info(conn, message, member, args): | ||||
|     await client.send_message(message.channel, embed=await generate_member_info_card(conn, member)) | ||||
|     return True | ||||
|  | ||||
|  | ||||
| @member_command(cmd="pk;member", subcommand="color", usage="[color]", description="Updates a member's associated color. Leave blank to clear.") | ||||
| @member_command(cmd="member color", usage="[color]", description="Updates a member's associated color. Leave blank to clear.") | ||||
| async def member_color(conn, message, member, args): | ||||
|     if len(args) == 0: | ||||
|         color = None | ||||
| @@ -298,7 +287,7 @@ async def member_color(conn, message, member, args): | ||||
|         return True, "Color updated to #{}.".format(color) if color else "Color cleared." | ||||
|  | ||||
|  | ||||
| @member_command(cmd="pk;member", subcommand="pronouns", usage="[pronouns]", description="Updates a member's pronouns. Leave blank to clear.") | ||||
| @member_command(cmd="member pronouns", usage="[pronouns]", description="Updates a member's pronouns. Leave blank to clear.") | ||||
| async def member_pronouns(conn, message, member, args): | ||||
|     if len(args) == 0: | ||||
|         pronouns = None | ||||
| @@ -310,7 +299,7 @@ async def member_pronouns(conn, message, member, args): | ||||
|         return True, "Pronouns set to {}".format(pronouns) if pronouns else "Pronouns cleared." | ||||
|  | ||||
|  | ||||
| @member_command(cmd="pk;member", subcommand="birthdate", usage="[birthdate]", description="Updates a member's birthdate. Must be in ISO-8601 format (eg. 1999-07-25). Leave blank to clear.") | ||||
| @member_command(cmd="member birthdate", usage="[birthdate]", description="Updates a member's birthdate. Must be in ISO-8601 format (eg. 1999-07-25). Leave blank to clear.") | ||||
| async def member_birthday(conn, message, member, args): | ||||
|     if len(args) == 0: | ||||
|         new_date = None | ||||
| @@ -326,7 +315,7 @@ async def member_birthday(conn, message, member, args): | ||||
|         return True, "Birthdate set to {}".format(new_date) if new_date else "Birthdate cleared." | ||||
|  | ||||
|  | ||||
| @member_command(cmd="pk;member", subcommand="description", description="Updates a member's description. Add \"clear\" to clear.") | ||||
| @member_command(cmd="member description", description="Updates a member's description. Add \"clear\" to clear.") | ||||
| async def member_description(conn, message, member, args): | ||||
|     if len(args) > 0 and args[0] == "clear": | ||||
|         new_description = None | ||||
| @@ -341,7 +330,7 @@ async def member_description(conn, message, member, args): | ||||
|         return True, "Description set." if new_description else "Description cleared." | ||||
|  | ||||
|  | ||||
| @member_command(cmd="pk;member", subcommand="remove", description="Removes a member from your system.") | ||||
| @member_command(cmd="member remove", description="Removes a member from your system.") | ||||
| async def member_remove(conn, message, member, args): | ||||
|     await client.send_message(message.channel, "Are you sure you want to remove {}? If so, reply to this message with the member's name.".format(member["name"])) | ||||
|  | ||||
| @@ -353,7 +342,7 @@ async def member_remove(conn, message, member, args): | ||||
|         return True, "Member removal cancelled." | ||||
|  | ||||
|  | ||||
| @member_command(cmd="pk;member", subcommand="avatar", usage="[user|url]", description="Updates a member's avatar. Can be an account mention (which will use that account's avatar), or a link to an image. Leave blank to clear.", basic=True) | ||||
| @member_command(cmd="member avatar", usage="[user|url]", description="Updates a member's avatar. Can be an account mention (which will use that account's avatar), or a link to an image. Leave blank to clear.") | ||||
| async def member_avatar(conn, message, member, args): | ||||
|     if len(args) == 0: | ||||
|         avatar_url = None | ||||
| @@ -381,7 +370,7 @@ async def member_avatar(conn, message, member, args): | ||||
|             return True, make_default_embed("Avatar set.").set_image(url=avatar_url) | ||||
|  | ||||
|  | ||||
| @member_command(cmd="pk;member", subcommand="proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).", basic=True) | ||||
| @member_command(cmd="member proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).") | ||||
| async def member_proxy(conn, message, member, args): | ||||
|     if len(args) == 0: | ||||
|         prefix, suffix = None, None | ||||
| @@ -412,7 +401,7 @@ async def member_proxy(conn, message, member, args): | ||||
|         return True, "Proxy settings updated." if prefix or suffix else "Proxy settings cleared." | ||||
|  | ||||
|  | ||||
| @command(cmd="pk;message", subcommand=None, usage="<id>", description="Shows information about a proxied message. Requires the message ID.") | ||||
| @command(cmd="message", usage="<id>", description="Shows information about a proxied message. Requires the message ID.") | ||||
| async def message_info(conn, message, args): | ||||
|     try: | ||||
|         mid = int(args[0]) | ||||
| @@ -455,7 +444,7 @@ async def message_info(conn, message, args): | ||||
|     await client.send_message(message.channel, embed=embed) | ||||
|     return True | ||||
|  | ||||
| @command(cmd="pk;switch", subcommand=None, usage="<name|id>", description="Registers a switch and changes the current fronter.", basic=True) | ||||
| @command(cmd="switch", usage="<name|id>", description="Registers a switch and changes the current fronter.") | ||||
| async def switch_member(conn, message, args): | ||||
|     if len(args) == 0: | ||||
|         return False | ||||
| @@ -479,7 +468,7 @@ async def switch_member(conn, message, args): | ||||
|     await db.add_switch(conn, system_id=system["id"], member_id=member["id"]) | ||||
|     return True, "Switch registered. Current fronter is now {}.".format(member["name"]) | ||||
|  | ||||
| @command(cmd="pk;switch", subcommand="out", description="Registers a switch out, and leaves current fronter blank.") | ||||
| @command(cmd="switch out", description="Registers a switch out, and leaves current fronter blank.") | ||||
| async def switch_out(conn, message, args): | ||||
|     system = await db.get_system_by_account(conn, message.author.id) | ||||
|  | ||||
| @@ -495,7 +484,7 @@ async def switch_out(conn, message, args): | ||||
|     await db.add_switch(conn, system_id=system["id"], member_id=None) | ||||
|     return True, "Switch-out registered." | ||||
|  | ||||
| @command(cmd="pk;mod", subcommand="log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.") | ||||
| @command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.") | ||||
| async def set_log(conn, message, args): | ||||
|     if not message.author.server_permissions.administrator: | ||||
|         return False, "You must be a server administrator to use this command." | ||||
| @@ -525,7 +514,7 @@ def make_help(cmds): | ||||
|                 cmd, subcmd or "", usage or ""), value=description, inline=False) | ||||
|     return embed | ||||
|  | ||||
| @command(cmd="pk;help", subcommand=None, usage="[category]", description="Shows this help message.") | ||||
| @command(cmd="help", usage="[category]", description="Shows this help message.") | ||||
| async def show_help(conn, message, args): | ||||
|     embed = make_default_embed(None) | ||||
|     embed.title = "PluralKit Help" | ||||
|   | ||||
| @@ -91,7 +91,7 @@ command_map = {} | ||||
| # Second parameter is the message it'll send. If just False, will print usage | ||||
|  | ||||
|  | ||||
| def command(cmd, subcommand, usage=None, description=None, basic=False): | ||||
| def command(cmd, usage=None, description=None): | ||||
|     def wrap(func): | ||||
|         async def wrapper(conn, message, args): | ||||
|             res = await func(conn, message, args) | ||||
| @@ -104,7 +104,7 @@ def command(cmd, subcommand, usage=None, description=None, basic=False): | ||||
|  | ||||
|                 if not success and not msg: | ||||
|                     # Failure, no message, print usage | ||||
|                     usage_str = "**Usage:** {} {} {}".format(cmd, subcommand or "", usage or "") | ||||
|                     usage_str = "**Usage:** {} {}".format(cmd, usage or "") | ||||
|                     await client.send_message(message.channel, embed=make_default_embed(usage_str)) | ||||
|                 elif not success: | ||||
|                     # Failure, print message | ||||
| @@ -119,7 +119,7 @@ def command(cmd, subcommand, usage=None, description=None, basic=False): | ||||
|                 # Success, don't print anything | ||||
|  | ||||
|         # Put command in map | ||||
|         command_map[(cmd, subcommand)] = (wrapper, usage, description, basic) | ||||
|         command_map[cmd] = (wrapper, usage, description) | ||||
|         return wrapper | ||||
|     return wrap | ||||
|  | ||||
| @@ -128,7 +128,7 @@ def command(cmd, subcommand, usage=None, description=None, basic=False): | ||||
| # If system_only=False, allows members from other systems by hid | ||||
|  | ||||
|  | ||||
| def member_command(cmd, subcommand, usage=None, description=None, system_only=True, basic=False): | ||||
| def member_command(cmd, usage=None, description=None, system_only=True): | ||||
|     def wrap(func): | ||||
|         async def wrapper(conn, message, args): | ||||
|             # Return if no member param | ||||
| @@ -149,7 +149,7 @@ def member_command(cmd, subcommand, usage=None, description=None, system_only=Tr | ||||
|                 return False, "Can't find member \"{}\".".format(args[0]) | ||||
|  | ||||
|             return await func(conn, message, member, args[1:]) | ||||
|         return command(cmd=cmd, subcommand=subcommand, usage="<name|id> {}".format(usage or ""), description=description, basic=basic)(wrapper) | ||||
|         return command(cmd=cmd, usage="<name|id> {}".format(usage or ""), description=description)(wrapper) | ||||
|     return wrap | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user