diff --git a/src/pluralkit/bot/commands/__init__.py b/src/pluralkit/bot/commands/__init__.py index e0b10cae..e439b590 100644 --- a/src/pluralkit/bot/commands/__init__.py +++ b/src/pluralkit/bot/commands/__init__.py @@ -27,26 +27,13 @@ def next_arg(arg_string: str) -> Tuple[str, Optional[str]]: return arg_string.strip(), None -class CommandResponse: - def to_embed(self): - pass - - -class CommandSuccess(CommandResponse): - def __init__(self, text): - self.text = text - - def to_embed(self): - return embeds.success("\u2705 " + self.text) - - -class CommandError(Exception, CommandResponse): +class CommandError(Exception): def __init__(self, text: str, help: Tuple[str, str] = None): self.text = text self.help = help - def to_embed(self): - return embeds.error("\u274c " + self.text, self.help) + def format(self): + return "\u274c " + self.text, embeds.error("", self.help) if self.help else None class CommandContext: @@ -108,13 +95,19 @@ class CommandContext: async def reply(self, content=None, embed=None): return await self.message.channel.send(content=content, embed=embed) + async def reply_ok(self, content=None, embed=None): + return await self.reply(content="\u2705 {}".format(content or ""), embed=embed) + async def confirm_react(self, user: Union[discord.Member, discord.User], message: str): message = await self.reply(message) await message.add_reaction("\u2705") # Checkmark await message.add_reaction("\u274c") # Red X try: - reaction, _ = await self.client.wait_for("reaction_add", check=lambda r, u: u.id == user.id and r.emoji in ["\u2705", "\u274c"], timeout=60.0*5) + reaction, _ = await self.client.wait_for("reaction_add", + check=lambda r, u: u.id == user.id and r.emoji in ["\u2705", + "\u274c"], + timeout=60.0 * 5) return reaction.emoji == "\u2705" except asyncio.TimeoutError: raise CommandError("Timed out - try again.") @@ -123,7 +116,9 @@ class CommandContext: await self.reply(message) try: - message = await self.client.wait_for("message", check=lambda m: m.channel.id == channel.id and m.author.id == user.id, timeout=60.0*5) + message = await self.client.wait_for("message", + check=lambda m: m.channel.id == channel.id and m.author.id == user.id, + timeout=60.0 * 5) return message.content.lower() == confirm_text.lower() except asyncio.TimeoutError: raise CommandError("Timed out - try again.") @@ -142,10 +137,9 @@ import pluralkit.bot.commands.system_commands async def run_command(ctx: CommandContext, func): try: result = await func(ctx) - if isinstance(result, CommandResponse): - await ctx.reply(embed=result.to_embed()) except CommandError as e: - await ctx.reply(embed=e.to_embed()) + content, embed = e.format() + await ctx.reply(content=content, embed=embed) async def command_dispatch(client: discord.Client, message: discord.Message, conn) -> bool: diff --git a/src/pluralkit/bot/commands/api_commands.py b/src/pluralkit/bot/commands/api_commands.py index 5670acd3..a3597635 100644 --- a/src/pluralkit/bot/commands/api_commands.py +++ b/src/pluralkit/bot/commands/api_commands.py @@ -1,7 +1,7 @@ import logging from discord import DMChannel -from pluralkit.bot.commands import CommandContext, CommandSuccess +from pluralkit.bot.commands import CommandContext logger = logging.getLogger("pluralkit.commands") disclaimer = "Please note that this grants access to modify (and delete!) all your system data, so keep it safe and secure. If it leaks or you need a new one, you can invalidate this one with `pk;token refresh`." @@ -10,7 +10,7 @@ async def reply_dm(ctx: CommandContext, message: str): await ctx.message.author.send(message) if not isinstance(ctx.message.channel, DMChannel): - return CommandSuccess("DM'd!") + await ctx.reply_ok("DM'd!") async def get_token(ctx: CommandContext): system = await ctx.ensure_system() diff --git a/src/pluralkit/bot/commands/import_commands.py b/src/pluralkit/bot/commands/import_commands.py index 6c6fa2d5..c714701b 100644 --- a/src/pluralkit/bot/commands/import_commands.py +++ b/src/pluralkit/bot/commands/import_commands.py @@ -12,13 +12,13 @@ async def import_tupperware(ctx: CommandContext): # Check if there's a Tupperware bot on the server if not tupperware_member: - return CommandError("This command only works in a server where the Tupperware bot is also present.") + raise CommandError("This command only works in a server where the Tupperware bot is also present.") # Make sure at the bot has send/read permissions here channel_permissions = ctx.message.channel.permissions_for(tupperware_member) if not (channel_permissions.read_messages and channel_permissions.send_messages): # If it doesn't, throw error - return CommandError("This command only works in a channel where the Tupperware bot has read/send access.") + raise CommandError("This command only works in a channel where the Tupperware bot has read/send access.") await ctx.reply( embed=embeds.status("Please reply to this message with `tul!list` (or the server equivalent).")) @@ -44,7 +44,7 @@ async def import_tupperware(ctx: CommandContext): tw_msg: discord.Message = await ctx.client.wait_for("message", check=ensure_account, timeout=60.0 * 5) if not tw_msg: - return CommandError("Tupperware import timed out.") + raise CommandError("Tupperware import timed out.") tupperware_page_embeds.append(tw_msg.embeds[0].to_dict()) # Handle Tupperware pagination @@ -82,7 +82,7 @@ async def import_tupperware(ctx: CommandContext): # Make sure it doesn't spin here for too long, time out after 30 seconds since last new page if (datetime.utcnow() - last_found_time).seconds > 30: - return CommandError("Pagination scan timed out.") + raise CommandError("Pagination scan timed out.") # Now that we've got all the pages, put them in the embeds list # Make sure to erase the original one we put in above too @@ -152,5 +152,5 @@ async def import_tupperware(ctx: CommandContext): await db.update_member_field(ctx.conn, member_id=existing_member.id, field="description", value=member_description) - return CommandSuccess( + await ctx.reply_ok( "System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting.") diff --git a/src/pluralkit/bot/commands/member_commands.py b/src/pluralkit/bot/commands/member_commands.py index 091b6cda..4cf0fa09 100644 --- a/src/pluralkit/bot/commands/member_commands.py +++ b/src/pluralkit/bot/commands/member_commands.py @@ -17,16 +17,16 @@ async def member_info(ctx: CommandContext): async def new_member(ctx: CommandContext): system = await ctx.ensure_system() if not ctx.has_next(): - return CommandError("You must pass a name for the new member.", help=help.add_member) + raise CommandError("You must pass a name for the new member.", help=help.add_member) new_name = ctx.remaining() try: member = await system.create_member(ctx.conn, new_name) except PluralKitError as e: - return CommandError(e.message) + raise CommandError(e.message) - return CommandSuccess( + await ctx.reply_ok( "Member \"{}\" (`{}`) registered! To register their proxy tags, use `pk;member proxy`.".format(new_name, member.hid)) @@ -78,7 +78,7 @@ async def member_set(ctx: CommandContext): } if property_name not in properties: - return CommandError( + raise CommandError( "Unknown property {}. Allowed properties are {}.".format(property_name, ", ".join(properties.keys())), help=help.edit_system) @@ -87,14 +87,13 @@ async def member_set(ctx: CommandContext): try: await properties[property_name](ctx.conn, value) except PluralKitError as e: - return CommandError(e.message) + raise CommandError(e.message) - response = CommandSuccess("{} member {}.".format("Updated" if value else "Cleared", property_name)) # if prop == "avatar" and value: # response.set_image(url=value) # if prop == "color" and value: # response.colour = int(value, 16) - return response + await ctx.reply_ok("{} member {}.".format("Updated" if value else "Cleared", property_name)) async def member_proxy(ctx: CommandContext): @@ -107,10 +106,10 @@ async def member_proxy(ctx: CommandContext): # Sanity checking example = ctx.remaining() if "text" not in example: - return CommandError("Example proxy message must contain the string 'text'.", help=help.member_proxy) + raise CommandError("Example proxy message must contain the string 'text'.", help=help.member_proxy) if example.count("text") != 1: - return CommandError("Example proxy message must contain the string 'text' exactly once.", + raise CommandError("Example proxy message must contain the string 'text' exactly once.", help=help.member_proxy) # Extract prefix and suffix @@ -126,7 +125,7 @@ async def member_proxy(ctx: CommandContext): async with ctx.conn.transaction(): await member.set_proxy_tags(ctx.conn, prefix, suffix) - return CommandSuccess("Proxy settings updated." if prefix or suffix else "Proxy settings cleared.") + await ctx.reply_ok("Proxy settings updated." if prefix or suffix else "Proxy settings cleared.") async def member_delete(ctx: CommandContext): @@ -135,7 +134,7 @@ async def member_delete(ctx: CommandContext): delete_confirm_msg = "Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(member.name, member.hid) if not await ctx.confirm_text(ctx.message.author, ctx.message.channel, member.hid, delete_confirm_msg): - return CommandError("Member deletion cancelled.") + raise CommandError("Member deletion cancelled.") await member.delete(ctx.conn) - return CommandSuccess("Member deleted.") + await ctx.reply_ok("Member deleted.") diff --git a/src/pluralkit/bot/commands/message_commands.py b/src/pluralkit/bot/commands/message_commands.py index 0fa1aafd..7f7c20f3 100644 --- a/src/pluralkit/bot/commands/message_commands.py +++ b/src/pluralkit/bot/commands/message_commands.py @@ -21,7 +21,7 @@ async def message_info(ctx: CommandContext): try: mid = int(mid_str) except ValueError: - return CommandError("You must pass a valid number as a message ID.", help=help.message_lookup) + raise CommandError("You must pass a valid number as a message ID.", help=help.message_lookup) # Find the message in the DB message = await db.get_message(ctx.conn, mid) @@ -59,4 +59,4 @@ async def message_info(ctx: CommandContext): embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty) - await ctx.reply(embed=embed) + await ctx.reply_ok(embed=embed) diff --git a/src/pluralkit/bot/commands/misc_commands.py b/src/pluralkit/bot/commands/misc_commands.py index 803532c4..bdf6c898 100644 --- a/src/pluralkit/bot/commands/misc_commands.py +++ b/src/pluralkit/bot/commands/misc_commands.py @@ -28,7 +28,7 @@ async def show_help(ctx: CommandContext): else: embed.description = text else: - return CommandError("Unknown help page '{}'.".format(category)) + raise CommandError("Unknown help page '{}'.".format(category)) await ctx.reply(embed=embed) @@ -47,7 +47,7 @@ async def invite_link(ctx: CommandContext): url = oauth_url(client_id, permissions) logger.debug("Sending invite URL: {}".format(url)) - return CommandSuccess("Use this link to add PluralKit to your server: {}".format(url)) + await ctx.reply_ok("Use this link to add PluralKit to your server: {}".format(url)) async def export(ctx: CommandContext): diff --git a/src/pluralkit/bot/commands/mod_commands.py b/src/pluralkit/bot/commands/mod_commands.py index 8d532825..1d9ba30a 100644 --- a/src/pluralkit/bot/commands/mod_commands.py +++ b/src/pluralkit/bot/commands/mod_commands.py @@ -5,19 +5,19 @@ logger = logging.getLogger("pluralkit.commands") async def set_log(ctx: CommandContext): if not ctx.message.author.guild_permissions.administrator: - return CommandError("You must be a server administrator to use this command.") + raise CommandError("You must be a server administrator to use this command.") server = ctx.message.guild if not server: - return CommandError("This command can not be run in a DM.") + raise CommandError("This command can not be run in a DM.") if not ctx.has_next(): channel_id = None else: channel = utils.parse_channel_mention(ctx.pop_str(), server=server) if not channel: - return CommandError("Channel not found.") + raise CommandError("Channel not found.") channel_id = channel.id await db.update_server(ctx.conn, server.id, logging_channel_id=channel_id) - return CommandSuccess("Updated logging channel." if channel_id else "Cleared logging channel.") + await ctx.reply_ok("Updated logging channel." if channel_id else "Cleared logging channel.") diff --git a/src/pluralkit/bot/commands/switch_commands.py b/src/pluralkit/bot/commands/switch_commands.py index 75496013..85175777 100644 --- a/src/pluralkit/bot/commands/switch_commands.py +++ b/src/pluralkit/bot/commands/switch_commands.py @@ -15,7 +15,7 @@ async def switch_member(ctx: CommandContext): system = await ctx.ensure_system() if not ctx.has_next(): - return CommandError("You must pass at least one member name or ID to register a switch to.", + raise CommandError("You must pass at least one member name or ID to register a switch to.", help=help.switch_register) members: List[Member] = [] @@ -23,7 +23,7 @@ async def switch_member(ctx: CommandContext): # Find the member member = await utils.get_member_fuzzy(ctx.conn, system.id, member_name) if not member: - return CommandError("Couldn't find member \"{}\".".format(member_name)) + raise CommandError("Couldn't find member \"{}\".".format(member_name)) members.append(member) # Compare requested switch IDs and existing fronter IDs to check for existing switches @@ -32,12 +32,12 @@ async def switch_member(ctx: CommandContext): fronter_ids = (await pluralkit.utils.get_fronter_ids(ctx.conn, system.id))[0] if member_ids == fronter_ids: if len(members) == 1: - return CommandError("{} is already fronting.".format(members[0].name)) - return CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members]))) + raise CommandError("{} is already fronting.".format(members[0].name)) + raise CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members]))) # Also make sure there aren't any duplicates if len(set(member_ids)) != len(member_ids): - return CommandError("Duplicate members in member list.") + raise CommandError("Duplicate members in member list.") # Log the switch async with ctx.conn.transaction(): @@ -46,9 +46,9 @@ async def switch_member(ctx: CommandContext): await db.add_switch_member(ctx.conn, switch_id=switch_id, member_id=member.id) if len(members) == 1: - return CommandSuccess("Switch registered. Current fronter is now {}.".format(members[0].name)) + await ctx.reply_ok("Switch registered. Current fronter is now {}.".format(members[0].name)) else: - return CommandSuccess( + await ctx.reply_ok( "Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members]))) @@ -58,17 +58,17 @@ async def switch_out(ctx: CommandContext): # Get current fronters fronters, _ = await pluralkit.utils.get_fronter_ids(ctx.conn, system_id=system.id) if not fronters: - return CommandError("There's already no one in front.") + raise CommandError("There's already no one in front.") # Log it, and don't log any members await db.add_switch(ctx.conn, system_id=system.id) - return CommandSuccess("Switch-out registered.") + await ctx.reply_ok("Switch-out registered.") async def switch_move(ctx: CommandContext): system = await ctx.ensure_system() if not ctx.has_next(): - return CommandError("You must pass a time to move the switch to.", help=help.switch_move) + raise CommandError("You must pass a time to move the switch to.", help=help.switch_move) # Parse the time to move to new_time = dateparser.parse(ctx.remaining(), languages=["en"], settings={ @@ -76,18 +76,18 @@ async def switch_move(ctx: CommandContext): "RETURN_AS_TIMEZONE_AWARE": False }) if not new_time: - return CommandError("'{}' can't be parsed as a valid time.".format(ctx.remaining()), help=help.switch_move) + raise CommandError("'{}' can't be parsed as a valid time.".format(ctx.remaining()), help=help.switch_move) # Make sure the time isn't in the future if new_time > datetime.utcnow(): - return CommandError("Can't move switch to a time in the future.", help=help.switch_move) + raise CommandError("Can't move switch to a time in the future.", help=help.switch_move) # Make sure it all runs in a big transaction for atomicity async with ctx.conn.transaction(): # Get the last two switches to make sure the switch to move isn't before the second-last switch last_two_switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, count=2) if len(last_two_switches) == 0: - return CommandError("There are no registered switches for this system.") + raise CommandError("There are no registered switches for this system.") last_timestamp, last_fronters = last_two_switches[0] if len(last_two_switches) > 1: @@ -95,7 +95,7 @@ async def switch_move(ctx: CommandContext): if new_time < second_last_timestamp: time_str = humanize.naturaltime(pluralkit.utils.fix_time(second_last_timestamp)) - return CommandError( + raise CommandError( "Can't move switch to before last switch time ({}), as it would cause conflicts.".format(time_str)) # Display the confirmation message w/ humanized times @@ -108,14 +108,14 @@ async def switch_move(ctx: CommandContext): # Confirm with user switch_confirm_message = "This will move the latest switch ({}) from {} ({}) to {} ({}). Is this OK?".format(members, last_absolute, last_relative, new_absolute, new_relative) if not await ctx.confirm_react(ctx.message.author, switch_confirm_message): - return CommandError("Switch move cancelled.") + raise CommandError("Switch move cancelled.") # DB requires the actual switch ID which our utility method above doesn't return, do this manually switch_id = (await db.front_history(ctx.conn, system.id, count=1))[0]["id"] # Change the switch in the DB await db.move_last_switch(ctx.conn, system.id, switch_id, new_time) - return CommandSuccess("Switch moved.") + await ctx.reply_ok("Switch moved.") diff --git a/src/pluralkit/bot/commands/system_commands.py b/src/pluralkit/bot/commands/system_commands.py index 76be807a..04e7f6d6 100644 --- a/src/pluralkit/bot/commands/system_commands.py +++ b/src/pluralkit/bot/commands/system_commands.py @@ -26,9 +26,9 @@ async def new_system(ctx: CommandContext): try: await System.create_system(ctx.conn, ctx.message.author.id, system_name) except ExistingSystemError as e: - return CommandError(e.message) + raise CommandError(e.message) - return CommandSuccess("System registered! To begin adding members, use `pk;member new `.") + await ctx.reply_ok("System registered! To begin adding members, use `pk;member new `.") async def system_set(ctx: CommandContext): @@ -54,7 +54,7 @@ async def system_set(ctx: CommandContext): } if property_name not in properties: - return CommandError( + raise CommandError( "Unknown property {}. Allowed properties are {}.".format(property_name, ", ".join(properties.keys())), help=help.edit_system) @@ -63,12 +63,11 @@ async def system_set(ctx: CommandContext): try: await properties[property_name](ctx.conn, value) except PluralKitError as e: - return CommandError(e.message) + raise CommandError(e.message) - response = CommandSuccess("{} system {}.".format("Updated" if value else "Cleared", property_name)) + await ctx.reply_ok("{} system {}.".format("Updated" if value else "Cleared", property_name)) # if prop == "avatar" and value: # response.set_image(url=value) - return response async def system_link(ctx: CommandContext): @@ -78,18 +77,18 @@ async def system_link(ctx: CommandContext): # Find account to link linkee = await utils.parse_mention(ctx.client, account_name) if not linkee: - return CommandError("Account not found.") + raise CommandError("Account not found.") # Make sure account doesn't already have a system account_system = await System.get_by_account(ctx.conn, linkee.id) if account_system: - return CommandError(AccountAlreadyLinkedError(account_system).message) + raise CommandError(AccountAlreadyLinkedError(account_system).message) if not await ctx.confirm_react(linkee, "{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention)): - return CommandError("Account link cancelled.") + raise CommandError("Account link cancelled.") await system.link_account(ctx.conn, linkee.id) - return CommandSuccess("Account linked to system.") + await ctx.reply_ok("Account linked to system.") async def system_unlink(ctx: CommandContext): @@ -98,9 +97,9 @@ async def system_unlink(ctx: CommandContext): try: await system.unlink_account(ctx.conn, ctx.message.author.id) except UnlinkingLastAccountError as e: - return CommandError(e.message) + raise CommandError(e.message) - return CommandSuccess("Account unlinked.") + await ctx.reply_ok("Account unlinked.") async def system_fronter(ctx: CommandContext): @@ -149,10 +148,10 @@ async def system_delete(ctx: CommandContext): delete_confirm_msg = "Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format( system.hid) if not await ctx.confirm_text(ctx.message.author, ctx.message.channel, system.hid, delete_confirm_msg): - return CommandError("System deletion cancelled.") + raise CommandError("System deletion cancelled.") await system.delete(ctx.conn) - return CommandSuccess("System deleted.") + await ctx.reply_ok("System deleted.") async def system_frontpercent(ctx: CommandContext): @@ -166,7 +165,7 @@ async def system_frontpercent(ctx: CommandContext): }) if not before: - return CommandError("Could not parse '{}' as a valid time.".format(ctx.remaining())) + raise CommandError("Could not parse '{}' as a valid time.".format(ctx.remaining())) # If time is in the future, just kinda discard if before and before > datetime.utcnow(): @@ -177,7 +176,7 @@ async def system_frontpercent(ctx: CommandContext): # Fetch list of switches all_switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, 99999) if not all_switches: - return CommandError("No switches registered to this system.") + raise CommandError("No switches registered to this system.") # Cull the switches *ending* before the limit, if given # We'll need to find the first switch starting before the limit, then cut off every switch *before* that diff --git a/src/pluralkit/bot/embeds.py b/src/pluralkit/bot/embeds.py index 4a5e1aa1..2b3d0811 100644 --- a/src/pluralkit/bot/embeds.py +++ b/src/pluralkit/bot/embeds.py @@ -95,7 +95,7 @@ async def system_card(conn, client: discord.Client, system: System) -> discord.E pages = [""] for member in member_texts: last_page = pages[-1] - new_page = last_page + "\n" + member + new_page = last_page + "\n" + member if last_page else member if len(new_page) >= 1024: pages.append(member) diff --git a/src/pluralkit/errors.py b/src/pluralkit/errors.py index 0eed7d16..952d799d 100644 --- a/src/pluralkit/errors.py +++ b/src/pluralkit/errors.py @@ -64,6 +64,7 @@ class MemberNameTooLongError(PluralKitError): else: super().__init__("The maximum length of a member name is 32 characters.") + class InvalidColorError(PluralKitError): def __init__(self): super().__init__("Color must be a valid hex color. (eg. #ff0000)") \ No newline at end of file