diff --git a/.gitignore b/.gitignore index 6f642d60..efee9a79 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ .env .vscode/ .idea/ -venv/ \ No newline at end of file +venv/ + +*.pyc \ No newline at end of file diff --git a/src/pluralkit/bot/__init__.py b/src/pluralkit/bot/__init__.py index e797f178..9fa931a4 100644 --- a/src/pluralkit/bot/__init__.py +++ b/src/pluralkit/bot/__init__.py @@ -75,7 +75,11 @@ class PluralKitBot: pass async def handle_command_dispatch(self, message): - command_items = commands.command_list.items() + async with self.pool.acquire() as conn: + result = await commands.command_dispatch(self.client, message, conn) + return result + + """command_items = commands.command_list.items() command_items = sorted(command_items, key=lambda x: len(x[0]), reverse=True) prefix = "pk;" @@ -98,7 +102,7 @@ class PluralKitBot: response_time = (datetime.now() - message.timestamp).total_seconds() await self.stats.report_command(command_name, execution_time, response_time) - return True + return True""" async def handle_proxy_dispatch(self, message): # Try doing proxy parsing diff --git a/src/pluralkit/bot/commands/__init__.py b/src/pluralkit/bot/commands/__init__.py index 0fb141ad..6d7302eb 100644 --- a/src/pluralkit/bot/commands/__init__.py +++ b/src/pluralkit/bot/commands/__init__.py @@ -1,84 +1,111 @@ -import logging -from collections import namedtuple - -import asyncpg import discord +import logging +import re +from typing import Tuple, Optional -import pluralkit -from pluralkit import db -from pluralkit.bot import utils, embeds +from pluralkit import db, System, Member +from pluralkit.bot import embeds, utils logger = logging.getLogger("pluralkit.bot.commands") -command_list = {} -class NoSystemRegistered(Exception): - pass +def next_arg(arg_string: str) -> Tuple[str, Optional[str]]: + if arg_string.startswith("\""): + end_quote = arg_string.find("\"", start=1) + if end_quote > 0: + return arg_string[1:end_quote], arg_string[end_quote + 1:].strip() + else: + return arg_string[1:], None -class CommandContext(namedtuple("CommandContext", ["client", "conn", "message", "system"])): - client: discord.Client - conn: asyncpg.Connection - message: discord.Message - system: pluralkit.System + next_space = arg_string.find(" ") + if next_space >= 0: + return arg_string[:next_space].strip(), arg_string[next_space:].strip() + else: + return arg_string.strip(), None - async def reply(self, message=None, embed=None): - return await self.client.send_message(self.message.channel, message, embed=embed) -class MemberCommandContext(namedtuple("MemberCommandContext", CommandContext._fields + ("member",)), CommandContext): - client: discord.Client - conn: asyncpg.Connection - message: discord.Message - system: pluralkit.System - member: pluralkit.Member +class CommandResponse: + def to_embed(self): + pass -class CommandEntry(namedtuple("CommandEntry", ["command", "function", "usage", "description", "category"])): - pass -def command(cmd, usage=None, description=None, category=None, system_required=True): - def wrap(func): - async def wrapper(client, conn, message, args): - system = await db.get_system_by_account(conn, message.author.id) +class CommandSuccess(CommandResponse): + def __init__(self, text): + self.text = text - if system_required and system is None: - await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account. Use `pk;system new` to register one.")) - return - - ctx = CommandContext(client=client, conn=conn, message=message, system=system) - try: - res = await func(ctx, args) + def to_embed(self): + return embeds.success("\u2705 " + self.text) - if res: - embed = res if isinstance(res, discord.Embed) else utils.make_default_embed(res) - await client.send_message(message.channel, embed=embed) - except NoSystemRegistered: - await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account. Use `pk;system new` to register one.")) - except Exception: - logger.exception("Exception while handling command {} (args={}, system={})".format(cmd, args, system.hid if system else "(none)")) - # Put command in map - command_list[cmd] = CommandEntry(command=cmd, function=wrapper, usage=usage, description=description, category=category) - return wrapper - return wrap +class CommandError(Exception, CommandResponse): + def __init__(self, embed: str, help: Tuple[str, str] = None): + self.text = embed + self.help = help -def member_command(cmd, usage=None, description=None, category=None, system_only=True): - def wrap(func): - async def wrapper(ctx: CommandContext, args): - # Return if no member param - if len(args) == 0: - return embeds.error("You must pass a member name or ID.") + def to_embed(self): + return embeds.error("\u274c " + self.text, self.help) - # System is allowed to be none if not system_only - system_id = ctx.system.id if ctx.system else None - # And find member by key - member = await utils.get_member_fuzzy(ctx.conn, system_id=system_id, key=args[0], system_only=system_only) - if member is None: - return embeds.error("Can't find member \"{}\".".format(args[0])) +class CommandContext: + def __init__(self, client: discord.Client, message: discord.Message, conn, args: str): + self.client = client + self.message = message + self.conn = conn + self.args = args + + async def get_system(self) -> Optional[System]: + return await db.get_system_by_account(self.conn, self.message.author.id) + + async def ensure_system(self) -> System: + system = await self.get_system() + + if not system: + raise CommandError( + embeds.error("No system registered to this account. Use `pk;system new` to register one.")) + + return system + + def has_next(self) -> bool: + return bool(self.args) + + def pop_str(self, error: CommandError = None) -> str: + if not self.args: + if error: + raise error + return None + + popped, self.args = next_arg(self.args) + return popped + + async def pop_system(self, error: CommandError = None) -> System: + name = self.pop_str(error) + system = await utils.get_system_fuzzy(self.conn, self.client, name) + + if not system: + raise CommandError("Unable to find system '{}'.".format(name)) + + return system + + async def pop_member(self, error: CommandError = None, system_only: bool = True) -> Member: + name = self.pop_str(error) + + if system_only: + system = await self.ensure_system() + else: + system = await self.get_system() + + member = await utils.get_member_fuzzy(self.conn, system.id if system else None, name, system_only) + if not member: + raise CommandError("Unable to find member '{}'{}.".format(name, " in your system" if system_only else "")) + + return member + + def remaining(self): + return self.args + + async def reply(self, content=None, embed=None): + return await self.client.send_message(self.message.channel, content=content, embed=embed) - ctx = MemberCommandContext(client=ctx.client, conn=ctx.conn, message=ctx.message, system=ctx.system, member=member) - return await func(ctx, args[1:]) - return command(cmd=cmd, usage=" {}".format(usage or ""), description=description, category=category, system_required=False)(wrapper) - return wrap import pluralkit.bot.commands.import_commands import pluralkit.bot.commands.member_commands @@ -87,3 +114,69 @@ import pluralkit.bot.commands.misc_commands import pluralkit.bot.commands.mod_commands import pluralkit.bot.commands.switch_commands import pluralkit.bot.commands.system_commands + + +async def run_command(ctx: CommandContext, func): + try: + result = await func(ctx) + if isinstance(result, CommandResponse): + await ctx.reply(embed=result.to_embed()) + except CommandError as e: + await ctx.reply(embed=e.to_embed()) + except Exception: + logger.exception("Exception while dispatching command") + + +async def command_dispatch(client: discord.Client, message: discord.Message, conn) -> bool: + prefix = "^pk(;|!)" + commands = [ + (r"system (new|register|create|init)", system_commands.new_system), + (r"system set", system_commands.system_set), + (r"system link", system_commands.system_link), + (r"system unlink", system_commands.system_unlink), + (r"system fronter", system_commands.system_fronter), + (r"system fronthistory", system_commands.system_fronthistory), + (r"system (delete|remove|destroy|erase)", system_commands.system_delete), + (r"system frontpercent(age)?", system_commands.system_frontpercent), + (r"system", system_commands.system_info), + + (r"import tupperware", import_commands.import_tupperware), + + (r"member (new|create|add|register)", member_commands.new_member), + (r"member set", member_commands.member_set), + (r"member proxy", member_commands.member_proxy), + (r"member (delete|remove|destroy|erase)", member_commands.member_delete), + (r"member", member_commands.member_info), + + (r"message", message_commands.message_info), + + (r"mod log", mod_commands.set_log), + + (r"invite", misc_commands.invite_link), + (r"export", misc_commands.export), + + (r"help", misc_commands.show_help), + + (r"switch move", switch_commands.switch_move), + (r"switch out", switch_commands.switch_out), + (r"switch", switch_commands.switch_member) + ] + + for pattern, func in commands: + regex = re.compile(prefix + pattern, re.IGNORECASE) + + cmd = message.content + match = regex.match(cmd) + if match: + remaining_string = cmd[match.span()[1]:].strip() + + ctx = CommandContext( + client=client, + message=message, + conn=conn, + args=remaining_string + ) + + await run_command(ctx, func) + return True + return False diff --git a/src/pluralkit/bot/commands/import_commands.py b/src/pluralkit/bot/commands/import_commands.py index aa18a0d7..8211b50c 100644 --- a/src/pluralkit/bot/commands/import_commands.py +++ b/src/pluralkit/bot/commands/import_commands.py @@ -1,20 +1,19 @@ import asyncio -import re from datetime import datetime -from typing import List from pluralkit.bot.commands import * logger = logging.getLogger("pluralkit.commands") -@command(cmd="import tupperware", description="Import data from Tupperware.", system_required=False) -async def import_tupperware(ctx: CommandContext, args: List[str]): + +async def import_tupperware(ctx: CommandContext): tupperware_ids = ["431544605209788416", "433916057053560832"] # Main bot instance and Multi-Pals-specific fork - tupperware_members = [ctx.message.server.get_member(bot_id) for bot_id in tupperware_ids if ctx.message.server.get_member(bot_id)] + tupperware_members = [ctx.message.server.get_member(bot_id) for bot_id in tupperware_ids if + ctx.message.server.get_member(bot_id)] # Check if there's any Tupperware bot on the server if not tupperware_members: - return embeds.error("This command only works in a server where the Tupperware bot is also present.") + return CommandError("This command only works in a server where the Tupperware bot is also present.") # Make sure at least one of the bts have send/read permissions here for bot_member in tupperware_members: @@ -24,10 +23,11 @@ async def import_tupperware(ctx: CommandContext, args: List[str]): break else: # If no bots have permission (ie. loop doesn't break), throw error - return embeds.error("This command only works in a channel where the Tupperware bot has read/send access.") + return CommandError("This command only works in a channel where the Tupperware bot has read/send access.") + + await ctx.reply( + embed=embeds.status("Please reply to this message with `tul!list` (or the server equivalent).")) - await ctx.reply(embed=utils.make_default_embed("Please reply to this message with `tul!list` (or the server equivalent).")) - # Check to make sure the message is sent by Tupperware, and that the Tupperware response actually belongs to the correct user def ensure_account(tw_msg): if tw_msg.author not in tupperware_members: @@ -38,14 +38,16 @@ async def import_tupperware(ctx: CommandContext, args: List[str]): if not tw_msg.embeds[0]["title"]: return False - - return tw_msg.embeds[0]["title"].startswith("{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator)) + + return tw_msg.embeds[0]["title"].startswith( + "{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator)) tupperware_page_embeds = [] - - tw_msg: discord.Message = await ctx.client.wait_for_message(channel=ctx.message.channel, timeout=60.0, check=ensure_account) + + tw_msg: discord.Message = await ctx.client.wait_for_message(channel=ctx.message.channel, timeout=60.0, + check=ensure_account) if not tw_msg: - return embeds.error("Tupperware import timed out.") + return CommandError("Tupperware import timed out.") tupperware_page_embeds.append(tw_msg.embeds[0]) # Handle Tupperware pagination @@ -74,7 +76,9 @@ async def import_tupperware(ctx: CommandContext, args: List[str]): # If this isn't the same page as last check, edit the status message if new_page != current_page: last_found_time = datetime.utcnow() - await ctx.client.edit_message(status_msg, "Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(len(pages_found), total_pages)) + await ctx.client.edit_message(status_msg, + "Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format( + len(pages_found), total_pages)) current_page = new_page # And sleep a bit to prevent spamming the CPU @@ -82,7 +86,7 @@ async def import_tupperware(ctx: CommandContext, args: List[str]): # Make sure it doesn't spin here for too long, time out after 30 seconds since last new page if (datetime.utcnow() - last_found_time).seconds > 30: - return embeds.error("Pagination scan timed out.") + return CommandError("Pagination scan timed out.") # Now that we've got all the pages, put them in the embeds list # Make sure to erase the original one we put in above too @@ -94,7 +98,7 @@ async def import_tupperware(ctx: CommandContext, args: List[str]): logger.debug("Importing from Tupperware...") # Create new (nameless) system if there isn't any registered - system = ctx.system + system = ctx.get_system() if system is None: hid = utils.generate_hid() logger.debug("Creating new system (hid={})...".format(hid)) @@ -117,7 +121,7 @@ async def import_tupperware(ctx: CommandContext, args: List[str]): if line.startswith("Brackets:"): brackets = line[len("Brackets: "):] member_prefix = brackets[:brackets.index("text")].strip() or None - member_suffix = brackets[brackets.index("text")+4:].strip() or None + member_suffix = brackets[brackets.index("text") + 4:].strip() or None elif line.startswith("Avatar URL: "): url = line[len("Avatar URL: "):] member_avatar = url @@ -138,14 +142,19 @@ async def import_tupperware(ctx: CommandContext, args: List[str]): # Or create a new member hid = utils.generate_hid() logger.debug("Creating new member {} (hid={})...".format(name, hid)) - existing_member = await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid) + existing_member = await db.create_member(ctx.conn, system_id=system.id, member_name=name, + member_hid=hid) # Save the new stuff in the DB logger.debug("Updating fields...") await db.update_member_field(ctx.conn, member_id=existing_member.id, field="prefix", value=member_prefix) await db.update_member_field(ctx.conn, member_id=existing_member.id, field="suffix", value=member_suffix) - await db.update_member_field(ctx.conn, member_id=existing_member.id, field="avatar_url", value=member_avatar) - await db.update_member_field(ctx.conn, member_id=existing_member.id, field="birthday", value=member_birthdate) - await db.update_member_field(ctx.conn, member_id=existing_member.id, field="description", value=member_description) - - return embeds.success("System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting.") + await db.update_member_field(ctx.conn, member_id=existing_member.id, field="avatar_url", + value=member_avatar) + await db.update_member_field(ctx.conn, member_id=existing_member.id, field="birthday", + value=member_birthdate) + await db.update_member_field(ctx.conn, member_id=existing_member.id, field="description", + value=member_description) + + return CommandSuccess( + "System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting.") diff --git a/src/pluralkit/bot/commands/member_commands.py b/src/pluralkit/bot/commands/member_commands.py index 867e7906..7bd622cd 100644 --- a/src/pluralkit/bot/commands/member_commands.py +++ b/src/pluralkit/bot/commands/member_commands.py @@ -8,32 +8,36 @@ from pluralkit.bot import help logger = logging.getLogger("pluralkit.commands") -@member_command(cmd="member", description="Shows information about a system member.", system_only=False, category="Member commands") -async def member_info(ctx: MemberCommandContext, args: List[str]): - await ctx.reply(embed=await utils.generate_member_info_card(ctx.conn, ctx.member)) -@command(cmd="member new", usage="", description="Adds a new member to your system.", category="Member commands") -async def new_member(ctx: MemberCommandContext, args: List[str]): - if len(args) == 0: - return embeds.error("You must pass a member name or ID.", help=help.add_member) +async def member_info(ctx: CommandContext): + member = await ctx.pop_member( + error=CommandError("You must pass a member name or ID.", help=help.lookup_member), system_only=False) + await ctx.reply(embed=await utils.generate_member_info_card(ctx.conn, member)) - name = " ".join(args) - bounds_error = utils.bounds_check_member_name(name, ctx.system.tag) + +async def new_member(ctx: CommandContext): + system = await ctx.ensure_system() + if not ctx.has_next(): + return CommandError("You must pass a name for the new member.", help=help.add_member) + + name = ctx.remaining() + bounds_error = utils.bounds_check_member_name(name, system.tag) if bounds_error: - return embeds.error(bounds_error) + return CommandError(bounds_error) # TODO: figure out what to do if this errors out on collision on generate_hid hid = utils.generate_hid() # Insert member row - await db.create_member(ctx.conn, system_id=ctx.system.id, member_name=name, member_hid=hid) - return embeds.success("Member \"{}\" (`{}`) registered!".format(name, hid)) + await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid) + return CommandSuccess( + "Member \"{}\" (`{}`) registered! To register their proxy tags, use `pk;member proxy`.".format(name, hid)) -@member_command(cmd="member set", usage=" [value]", description="Edits a member property. Leave [value] blank to clear.", category="Member commands") -async def member_set(ctx: MemberCommandContext, args: List[str]): - if len(args) == 0: - return embeds.error("You must pass a property name to set.", help=help.edit_member) +async def member_set(ctx: CommandContext): + system = await ctx.ensure_system() + member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member)) + prop = ctx.pop_str(CommandError("You must pass a property name to set.", help=help.edit_member)) allowed_properties = ["name", "description", "color", "pronouns", "birthdate", "avatar"] db_properties = { @@ -45,26 +49,27 @@ async def member_set(ctx: MemberCommandContext, args: List[str]): "avatar": "avatar_url" } - prop = args[0] if prop not in allowed_properties: - return embeds.error("Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)), help=help.edit_member) + return CommandError( + "Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)), + help=help.edit_member) - if len(args) >= 2: - value = " ".join(args[1:]) + if ctx.has_next(): + value = " ".join(ctx.remaining()) # Sanity/validity checks and type conversions if prop == "name": - bounds_error = utils.bounds_check_member_name(value, ctx.system.tag) + bounds_error = utils.bounds_check_member_name(value, system.tag) if bounds_error: - return embeds.error(bounds_error) + return CommandError(bounds_error) if prop == "color": match = re.fullmatch("#?([0-9A-Fa-f]{6})", value) if not match: - return embeds.error("Color must be a valid hex color (eg. #ff0000)") + return CommandError("Color must be a valid hex color (eg. #ff0000)") value = match.group(1).lower() - + if prop == "birthdate": try: value = datetime.strptime(value, "%Y-%m-%d").date() @@ -75,7 +80,7 @@ async def member_set(ctx: MemberCommandContext, args: List[str]): # Useful if you want your birthday to be displayed yearless. value = datetime.strptime("0001-" + value, "%Y-%m-%d").date() except ValueError: - return embeds.error("Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25).") + return CommandError("Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25).") if prop == "avatar": user = await utils.parse_mention(ctx.client, value) @@ -89,41 +94,45 @@ async def member_set(ctx: MemberCommandContext, args: List[str]): if u.scheme in ["http", "https"] and u.netloc and u.path: value = value else: - return embeds.error("Invalid image URL.") + return CommandError("Invalid image URL.") else: # Can't clear member name if prop == "name": - return embeds.error("You can't clear the member name.") + return CommandError("You can't clear the member name.") # Clear from DB value = None db_prop = db_properties[prop] - await db.update_member_field(ctx.conn, member_id=ctx.member.id, field=db_prop, value=value) - - response = embeds.success("{} {}'s {}.".format("Updated" if value else "Cleared", ctx.member.name, prop)) + await db.update_member_field(ctx.conn, member_id=member.id, field=db_prop, value=value) + + response = CommandSuccess("{} {}'s {}.".format("Updated" if value else "Cleared", member.name, prop)) if prop == "avatar" and value: response.set_image(url=value) if prop == "color" and value: response.colour = int(value, 16) return response -@member_command(cmd="member proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).", category="Member commands") -async def member_proxy(ctx: MemberCommandContext, args: List[str]): - if len(args) == 0: + +async def member_proxy(ctx: CommandContext): + await ctx.ensure_system() + member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.member_proxy)) + + if not ctx.has_next(): prefix, suffix = None, None else: # Sanity checking - example = " ".join(args) + example = ctx.remaining() if "text" not in example: - return embeds.error("Example proxy message must contain the string 'text'.", help=help.member_proxy) + return CommandError("Example proxy message must contain the string 'text'.", help=help.member_proxy) if example.count("text") != 1: - return embeds.error("Example proxy message must contain the string 'text' exactly once.", help=help.member_proxy) + return CommandError("Example proxy message must contain the string 'text' exactly once.", + help=help.member_proxy) # Extract prefix and suffix prefix = example[:example.index("text")].strip() - suffix = example[example.index("text")+4:].strip() + suffix = example[example.index("text") + 4:].strip() logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix)) # DB stores empty strings as None, make that work @@ -133,17 +142,22 @@ async def member_proxy(ctx: MemberCommandContext, args: List[str]): suffix = None async with ctx.conn.transaction(): - await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="prefix", value=prefix) - await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="suffix", value=suffix) - return embeds.success("Proxy settings updated." if prefix or suffix else "Proxy settings cleared.") + await db.update_member_field(ctx.conn, member_id=member.id, field="prefix", value=prefix) + await db.update_member_field(ctx.conn, member_id=member.id, field="suffix", value=suffix) + return CommandSuccess("Proxy settings updated." if prefix or suffix else "Proxy settings cleared.") -@member_command("member delete", description="Deletes a member from your system ***permanently***.", category="Member commands") -async def member_delete(ctx: MemberCommandContext, args: List[str]): - await ctx.reply("Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(ctx.member.name, ctx.member.hid)) + +async def member_delete(ctx: CommandContext): + await ctx.ensure_system() + member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member)) + + await ctx.reply( + "Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format( + member.name, member.hid)) msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0) - if msg and msg.content.lower() == ctx.member.hid.lower(): - await db.delete_member(ctx.conn, member_id=ctx.member.id) - return embeds.success("Member deleted.") + if msg and msg.content.lower() == member.hid.lower(): + await db.delete_member(ctx.conn, member_id=member.id) + return CommandSuccess("Member deleted.") else: - return embeds.error("Member deletion cancelled.") \ No newline at end of file + return CommandError("Member deletion cancelled.") diff --git a/src/pluralkit/bot/commands/message_commands.py b/src/pluralkit/bot/commands/message_commands.py index 9570cc07..5421ca5f 100644 --- a/src/pluralkit/bot/commands/message_commands.py +++ b/src/pluralkit/bot/commands/message_commands.py @@ -1,27 +1,21 @@ -import logging -from typing import List - -from pluralkit.bot import utils, embeds, help +from pluralkit.bot import help from pluralkit.bot.commands import * logger = logging.getLogger("pluralkit.commands") -@command(cmd="message", usage="", description="Shows information about a proxied message. Requires the message ID.", - category="Message commands", system_required=False) -async def message_info(ctx: CommandContext, args: List[str]): - if len(args) == 0: - return embeds.error("You must pass a message ID.", help=help.message_lookup) +async def message_info(ctx: CommandContext): + mid_str = ctx.pop_str(CommandError("You must pass a message ID.", help=help.message_lookup)) try: - mid = int(args[0]) + mid = int(mid_str) except ValueError: - return embeds.error("You must pass a valid number as a message ID.", help=help.message_lookup) + return CommandError("You must pass a valid number as a message ID.", help=help.message_lookup) # Find the message in the DB message = await db.get_message(ctx.conn, str(mid)) if not message: - raise embeds.error("Message with ID '{}' not found.".format(args[0])) + raise CommandError("Message with ID '{}' not found.".format(mid)) # Get the original sender of the messages try: @@ -49,9 +43,9 @@ async def message_info(ctx: CommandContext, args: List[str]): embed.add_field(name="Sent by", value=sender_name) - if message.content: # Content can be empty string if there's an attachment + if message.content: # Content can be empty string if there's an attachment embed.add_field(name="Content", value=message.content, inline=False) embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty) - return embed + await ctx.reply(embed=embed) diff --git a/src/pluralkit/bot/commands/misc_commands.py b/src/pluralkit/bot/commands/misc_commands.py index 847b87ab..fde72a0a 100644 --- a/src/pluralkit/bot/commands/misc_commands.py +++ b/src/pluralkit/bot/commands/misc_commands.py @@ -12,13 +12,13 @@ from pluralkit.bot.commands import * logger = logging.getLogger("pluralkit.commands") -@command(cmd="help", usage="[system|member|proxy|switch|mod]", description="Shows help messages.", system_required=False) -async def show_help(ctx: CommandContext, args: List[str]): - embed = utils.make_default_embed("") - embed.title = "PluralKit Help" - embed.set_footer(text="By Astrid (Ske#6201, or 'qoxvy' on PK) | GitHub: https://github.com/xSke/PluralKit/") - category = args[0] if len(args) > 0 else None +async def show_help(ctx: CommandContext): + embed = embeds.status("") + embed.title = "PluralKit Help" + embed.set_footer(text="By Astrid (Ske#6201; pk;member qoxvy) | GitHub: https://github.com/xSke/PluralKit/") + + category = ctx.pop_str() if ctx.has_next() else None from pluralkit.bot.help import help_pages if category in help_pages: @@ -28,12 +28,12 @@ async def show_help(ctx: CommandContext, args: List[str]): else: embed.description = text else: - return embeds.error("Unknown help page '{}'.".format(category)) + return CommandError("Unknown help page '{}'.".format(category)) - return embed + await ctx.reply(embed=embed) -@command(cmd="invite", description="Generates an invite link for this bot.", system_required=False) -async def invite_link(ctx: CommandContext, args: List[str]): + +async def invite_link(ctx: CommandContext): client_id = os.environ["CLIENT_ID"] permissions = discord.Permissions() @@ -47,15 +47,16 @@ async def invite_link(ctx: CommandContext, args: List[str]): url = oauth_url(client_id, permissions) logger.debug("Sending invite URL: {}".format(url)) - return embeds.success("Use this link to add PluralKit to your server: {}".format(url)) + return CommandSuccess("Use this link to add PluralKit to your server: {}".format(url)) -@command(cmd="export", description="Exports system data to a machine-readable format.") -async def export(ctx: CommandContext, args: List[str]): - members = await db.get_all_members(ctx.conn, ctx.system.id) - accounts = await db.get_linked_accounts(ctx.conn, ctx.system.id) - switches = await pluralkit.utils.get_front_history(ctx.conn, ctx.system.id, 999999) - system = ctx.system +async def export(ctx: CommandContext): + system = await ctx.ensure_system() + + members = await db.get_all_members(ctx.conn, system.id) + accounts = await db.get_linked_accounts(ctx.conn, system.id) + switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, 999999) + data = { "name": system.name, "id": system.hid, @@ -87,4 +88,4 @@ async def export(ctx: CommandContext, args: List[str]): } f = io.BytesIO(json.dumps(data).encode("utf-8")) - await ctx.client.send_file(ctx.message.channel, f, filename="system.json") \ No newline at end of file + await ctx.client.send_file(ctx.message.channel, f, filename="system.json") diff --git a/src/pluralkit/bot/commands/mod_commands.py b/src/pluralkit/bot/commands/mod_commands.py index a1c928cd..7804113e 100644 --- a/src/pluralkit/bot/commands/mod_commands.py +++ b/src/pluralkit/bot/commands/mod_commands.py @@ -1,24 +1,20 @@ -import logging -from typing import List - -from pluralkit.bot import utils, embeds from pluralkit.bot.commands import * logger = logging.getLogger("pluralkit.commands") -@command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands", system_required=False) -async def set_log(ctx: CommandContext, args: List[str]): + +async def set_log(ctx: CommandContext): if not ctx.message.author.server_permissions.administrator: - return embeds.error("You must be a server administrator to use this command.") - + return CommandError("You must be a server administrator to use this command.") + server = ctx.message.server - if len(args) == 0: + if not ctx.has_next(): channel_id = None else: - channel = utils.parse_channel_mention(args[0], server=server) + channel = utils.parse_channel_mention(ctx.pop_str(), server=server) if not channel: - return embeds.error("Channel not found.") + return CommandError("Channel not found.") channel_id = channel.id await db.update_server(ctx.conn, server.id, logging_channel_id=channel_id) - return embeds.success("Updated logging channel." if channel_id else "Cleared logging channel.") + return CommandSuccess("Updated logging channel." if channel_id else "Cleared logging channel.") diff --git a/src/pluralkit/bot/commands/switch_commands.py b/src/pluralkit/bot/commands/switch_commands.py index 1b77af14..c58ecc9e 100644 --- a/src/pluralkit/bot/commands/switch_commands.py +++ b/src/pluralkit/bot/commands/switch_commands.py @@ -1,120 +1,130 @@ -from datetime import datetime -import logging -from typing import List - import dateparser import humanize +from datetime import datetime, timezone +from typing import List import pluralkit.utils -from pluralkit import Member -from pluralkit.bot import utils, embeds, help +from pluralkit.bot import help from pluralkit.bot.commands import * logger = logging.getLogger("pluralkit.commands") -@command(cmd="switch", usage=" [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands") -async def switch_member(ctx: MemberCommandContext, args: List[str]): - if len(args) == 0: - return embeds.error("You must pass at least one member name or ID to register a switch to.", help=help.switch_register) + +async def switch_member(ctx: CommandContext): + system = await ctx.ensure_system() + + if not ctx.has_next(): + return CommandError("You must pass at least one member name or ID to register a switch to.", + help=help.switch_register) members: List[Member] = [] - for member_name in args: + for member_name in ctx.remaining().split(" "): # Find the member - member = await utils.get_member_fuzzy(ctx.conn, ctx.system.id, member_name) + member = await utils.get_member_fuzzy(ctx.conn, system.id, member_name) if not member: - return embeds.error("Couldn't find member \"{}\".".format(member_name)) + return CommandError("Couldn't find member \"{}\".".format(member_name)) members.append(member) # Compare requested switch IDs and existing fronter IDs to check for existing switches # Lists, because order matters, it makes sense to just swap fronters member_ids = [member.id for member in members] - fronter_ids = (await pluralkit.utils.get_fronter_ids(ctx.conn, ctx.system.id))[0] + fronter_ids = (await pluralkit.utils.get_fronter_ids(ctx.conn, system.id))[0] if member_ids == fronter_ids: if len(members) == 1: - return embeds.error("{} is already fronting.".format(members[0].name)) - return embeds.error("Members {} are already fronting.".format(", ".join([m.name for m in members]))) + return CommandError("{} is already fronting.".format(members[0].name)) + return CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members]))) # Also make sure there aren't any duplicates if len(set(member_ids)) != len(member_ids): - return embeds.error("Duplicate members in member list.") + return CommandError("Duplicate members in member list.") # Log the switch async with ctx.conn.transaction(): - switch_id = await db.add_switch(ctx.conn, system_id=ctx.system.id) + switch_id = await db.add_switch(ctx.conn, system_id=system.id) for member in members: await db.add_switch_member(ctx.conn, switch_id=switch_id, member_id=member.id) if len(members) == 1: - return embeds.success("Switch registered. Current fronter is now {}.".format(members[0].name)) + return CommandSuccess("Switch registered. Current fronter is now {}.".format(members[0].name)) else: - return embeds.success("Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members]))) + return CommandSuccess( + "Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members]))) + + +async def switch_out(ctx: CommandContext): + system = await ctx.ensure_system() -@command(cmd="switch out", description="Registers a switch with no one in front.", category="Switching commands") -async def switch_out(ctx: MemberCommandContext, args: List[str]): # Get current fronters - fronters, _ = await pluralkit.utils.get_fronter_ids(ctx.conn, system_id=ctx.system.id) + fronters, _ = await pluralkit.utils.get_fronter_ids(ctx.conn, system_id=system.id) if not fronters: - return embeds.error("There's already no one in front.") + return CommandError("There's already no one in front.") # Log it, and don't log any members - await db.add_switch(ctx.conn, system_id=ctx.system.id) - return embeds.success("Switch-out registered.") + await db.add_switch(ctx.conn, system_id=system.id) + return CommandSuccess("Switch-out registered.") -@command(cmd="switch move", usage="