Refactor error handling slightly and don't use embeds for basic status/error messages. Closes #28.

This commit is contained in:
Ske 2018-11-15 21:05:13 +01:00
parent 869f686bd5
commit 8e504fa879
11 changed files with 74 additions and 81 deletions

View File

@ -27,26 +27,13 @@ def next_arg(arg_string: str) -> Tuple[str, Optional[str]]:
return arg_string.strip(), None
class CommandResponse:
def to_embed(self):
pass
class CommandSuccess(CommandResponse):
def __init__(self, text):
self.text = text
def to_embed(self):
return embeds.success("\u2705 " + self.text)
class CommandError(Exception, CommandResponse):
class CommandError(Exception):
def __init__(self, text: str, help: Tuple[str, str] = None):
self.text = text
self.help = help
def to_embed(self):
return embeds.error("\u274c " + self.text, self.help)
def format(self):
return "\u274c " + self.text, embeds.error("", self.help) if self.help else None
class CommandContext:
@ -108,13 +95,19 @@ class CommandContext:
async def reply(self, content=None, embed=None):
return await self.message.channel.send(content=content, embed=embed)
async def reply_ok(self, content=None, embed=None):
return await self.reply(content="\u2705 {}".format(content or ""), embed=embed)
async def confirm_react(self, user: Union[discord.Member, discord.User], message: str):
message = await self.reply(message)
await message.add_reaction("\u2705") # Checkmark
await message.add_reaction("\u274c") # Red X
try:
reaction, _ = await self.client.wait_for("reaction_add", check=lambda r, u: u.id == user.id and r.emoji in ["\u2705", "\u274c"], timeout=60.0*5)
reaction, _ = await self.client.wait_for("reaction_add",
check=lambda r, u: u.id == user.id and r.emoji in ["\u2705",
"\u274c"],
timeout=60.0 * 5)
return reaction.emoji == "\u2705"
except asyncio.TimeoutError:
raise CommandError("Timed out - try again.")
@ -123,7 +116,9 @@ class CommandContext:
await self.reply(message)
try:
message = await self.client.wait_for("message", check=lambda m: m.channel.id == channel.id and m.author.id == user.id, timeout=60.0*5)
message = await self.client.wait_for("message",
check=lambda m: m.channel.id == channel.id and m.author.id == user.id,
timeout=60.0 * 5)
return message.content.lower() == confirm_text.lower()
except asyncio.TimeoutError:
raise CommandError("Timed out - try again.")
@ -142,10 +137,9 @@ import pluralkit.bot.commands.system_commands
async def run_command(ctx: CommandContext, func):
try:
result = await func(ctx)
if isinstance(result, CommandResponse):
await ctx.reply(embed=result.to_embed())
except CommandError as e:
await ctx.reply(embed=e.to_embed())
content, embed = e.format()
await ctx.reply(content=content, embed=embed)
async def command_dispatch(client: discord.Client, message: discord.Message, conn) -> bool:

View File

@ -1,7 +1,7 @@
import logging
from discord import DMChannel
from pluralkit.bot.commands import CommandContext, CommandSuccess
from pluralkit.bot.commands import CommandContext
logger = logging.getLogger("pluralkit.commands")
disclaimer = "Please note that this grants access to modify (and delete!) all your system data, so keep it safe and secure. If it leaks or you need a new one, you can invalidate this one with `pk;token refresh`."
@ -10,7 +10,7 @@ async def reply_dm(ctx: CommandContext, message: str):
await ctx.message.author.send(message)
if not isinstance(ctx.message.channel, DMChannel):
return CommandSuccess("DM'd!")
await ctx.reply_ok("DM'd!")
async def get_token(ctx: CommandContext):
system = await ctx.ensure_system()

View File

@ -12,13 +12,13 @@ async def import_tupperware(ctx: CommandContext):
# Check if there's a Tupperware bot on the server
if not tupperware_member:
return CommandError("This command only works in a server where the Tupperware bot is also present.")
raise CommandError("This command only works in a server where the Tupperware bot is also present.")
# Make sure at the bot has send/read permissions here
channel_permissions = ctx.message.channel.permissions_for(tupperware_member)
if not (channel_permissions.read_messages and channel_permissions.send_messages):
# If it doesn't, throw error
return CommandError("This command only works in a channel where the Tupperware bot has read/send access.")
raise CommandError("This command only works in a channel where the Tupperware bot has read/send access.")
await ctx.reply(
embed=embeds.status("Please reply to this message with `tul!list` (or the server equivalent)."))
@ -44,7 +44,7 @@ async def import_tupperware(ctx: CommandContext):
tw_msg: discord.Message = await ctx.client.wait_for("message", check=ensure_account, timeout=60.0 * 5)
if not tw_msg:
return CommandError("Tupperware import timed out.")
raise CommandError("Tupperware import timed out.")
tupperware_page_embeds.append(tw_msg.embeds[0].to_dict())
# Handle Tupperware pagination
@ -82,7 +82,7 @@ async def import_tupperware(ctx: CommandContext):
# Make sure it doesn't spin here for too long, time out after 30 seconds since last new page
if (datetime.utcnow() - last_found_time).seconds > 30:
return CommandError("Pagination scan timed out.")
raise CommandError("Pagination scan timed out.")
# Now that we've got all the pages, put them in the embeds list
# Make sure to erase the original one we put in above too
@ -152,5 +152,5 @@ async def import_tupperware(ctx: CommandContext):
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="description",
value=member_description)
return CommandSuccess(
await ctx.reply_ok(
"System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting.")

View File

@ -17,16 +17,16 @@ async def member_info(ctx: CommandContext):
async def new_member(ctx: CommandContext):
system = await ctx.ensure_system()
if not ctx.has_next():
return CommandError("You must pass a name for the new member.", help=help.add_member)
raise CommandError("You must pass a name for the new member.", help=help.add_member)
new_name = ctx.remaining()
try:
member = await system.create_member(ctx.conn, new_name)
except PluralKitError as e:
return CommandError(e.message)
raise CommandError(e.message)
return CommandSuccess(
await ctx.reply_ok(
"Member \"{}\" (`{}`) registered! To register their proxy tags, use `pk;member proxy`.".format(new_name, member.hid))
@ -78,7 +78,7 @@ async def member_set(ctx: CommandContext):
}
if property_name not in properties:
return CommandError(
raise CommandError(
"Unknown property {}. Allowed properties are {}.".format(property_name, ", ".join(properties.keys())),
help=help.edit_system)
@ -87,14 +87,13 @@ async def member_set(ctx: CommandContext):
try:
await properties[property_name](ctx.conn, value)
except PluralKitError as e:
return CommandError(e.message)
raise CommandError(e.message)
response = CommandSuccess("{} member {}.".format("Updated" if value else "Cleared", property_name))
# if prop == "avatar" and value:
# response.set_image(url=value)
# if prop == "color" and value:
# response.colour = int(value, 16)
return response
await ctx.reply_ok("{} member {}.".format("Updated" if value else "Cleared", property_name))
async def member_proxy(ctx: CommandContext):
@ -107,10 +106,10 @@ async def member_proxy(ctx: CommandContext):
# Sanity checking
example = ctx.remaining()
if "text" not in example:
return CommandError("Example proxy message must contain the string 'text'.", help=help.member_proxy)
raise CommandError("Example proxy message must contain the string 'text'.", help=help.member_proxy)
if example.count("text") != 1:
return CommandError("Example proxy message must contain the string 'text' exactly once.",
raise CommandError("Example proxy message must contain the string 'text' exactly once.",
help=help.member_proxy)
# Extract prefix and suffix
@ -126,7 +125,7 @@ async def member_proxy(ctx: CommandContext):
async with ctx.conn.transaction():
await member.set_proxy_tags(ctx.conn, prefix, suffix)
return CommandSuccess("Proxy settings updated." if prefix or suffix else "Proxy settings cleared.")
await ctx.reply_ok("Proxy settings updated." if prefix or suffix else "Proxy settings cleared.")
async def member_delete(ctx: CommandContext):
@ -135,7 +134,7 @@ async def member_delete(ctx: CommandContext):
delete_confirm_msg = "Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(member.name, member.hid)
if not await ctx.confirm_text(ctx.message.author, ctx.message.channel, member.hid, delete_confirm_msg):
return CommandError("Member deletion cancelled.")
raise CommandError("Member deletion cancelled.")
await member.delete(ctx.conn)
return CommandSuccess("Member deleted.")
await ctx.reply_ok("Member deleted.")

View File

@ -21,7 +21,7 @@ async def message_info(ctx: CommandContext):
try:
mid = int(mid_str)
except ValueError:
return CommandError("You must pass a valid number as a message ID.", help=help.message_lookup)
raise CommandError("You must pass a valid number as a message ID.", help=help.message_lookup)
# Find the message in the DB
message = await db.get_message(ctx.conn, mid)
@ -59,4 +59,4 @@ async def message_info(ctx: CommandContext):
embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty)
await ctx.reply(embed=embed)
await ctx.reply_ok(embed=embed)

View File

@ -28,7 +28,7 @@ async def show_help(ctx: CommandContext):
else:
embed.description = text
else:
return CommandError("Unknown help page '{}'.".format(category))
raise CommandError("Unknown help page '{}'.".format(category))
await ctx.reply(embed=embed)
@ -47,7 +47,7 @@ async def invite_link(ctx: CommandContext):
url = oauth_url(client_id, permissions)
logger.debug("Sending invite URL: {}".format(url))
return CommandSuccess("Use this link to add PluralKit to your server: {}".format(url))
await ctx.reply_ok("Use this link to add PluralKit to your server: {}".format(url))
async def export(ctx: CommandContext):

View File

@ -5,19 +5,19 @@ logger = logging.getLogger("pluralkit.commands")
async def set_log(ctx: CommandContext):
if not ctx.message.author.guild_permissions.administrator:
return CommandError("You must be a server administrator to use this command.")
raise CommandError("You must be a server administrator to use this command.")
server = ctx.message.guild
if not server:
return CommandError("This command can not be run in a DM.")
raise CommandError("This command can not be run in a DM.")
if not ctx.has_next():
channel_id = None
else:
channel = utils.parse_channel_mention(ctx.pop_str(), server=server)
if not channel:
return CommandError("Channel not found.")
raise CommandError("Channel not found.")
channel_id = channel.id
await db.update_server(ctx.conn, server.id, logging_channel_id=channel_id)
return CommandSuccess("Updated logging channel." if channel_id else "Cleared logging channel.")
await ctx.reply_ok("Updated logging channel." if channel_id else "Cleared logging channel.")

View File

@ -15,7 +15,7 @@ async def switch_member(ctx: CommandContext):
system = await ctx.ensure_system()
if not ctx.has_next():
return CommandError("You must pass at least one member name or ID to register a switch to.",
raise CommandError("You must pass at least one member name or ID to register a switch to.",
help=help.switch_register)
members: List[Member] = []
@ -23,7 +23,7 @@ async def switch_member(ctx: CommandContext):
# Find the member
member = await utils.get_member_fuzzy(ctx.conn, system.id, member_name)
if not member:
return CommandError("Couldn't find member \"{}\".".format(member_name))
raise CommandError("Couldn't find member \"{}\".".format(member_name))
members.append(member)
# Compare requested switch IDs and existing fronter IDs to check for existing switches
@ -32,12 +32,12 @@ async def switch_member(ctx: CommandContext):
fronter_ids = (await pluralkit.utils.get_fronter_ids(ctx.conn, system.id))[0]
if member_ids == fronter_ids:
if len(members) == 1:
return CommandError("{} is already fronting.".format(members[0].name))
return CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members])))
raise CommandError("{} is already fronting.".format(members[0].name))
raise CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members])))
# Also make sure there aren't any duplicates
if len(set(member_ids)) != len(member_ids):
return CommandError("Duplicate members in member list.")
raise CommandError("Duplicate members in member list.")
# Log the switch
async with ctx.conn.transaction():
@ -46,9 +46,9 @@ async def switch_member(ctx: CommandContext):
await db.add_switch_member(ctx.conn, switch_id=switch_id, member_id=member.id)
if len(members) == 1:
return CommandSuccess("Switch registered. Current fronter is now {}.".format(members[0].name))
await ctx.reply_ok("Switch registered. Current fronter is now {}.".format(members[0].name))
else:
return CommandSuccess(
await ctx.reply_ok(
"Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members])))
@ -58,17 +58,17 @@ async def switch_out(ctx: CommandContext):
# Get current fronters
fronters, _ = await pluralkit.utils.get_fronter_ids(ctx.conn, system_id=system.id)
if not fronters:
return CommandError("There's already no one in front.")
raise CommandError("There's already no one in front.")
# Log it, and don't log any members
await db.add_switch(ctx.conn, system_id=system.id)
return CommandSuccess("Switch-out registered.")
await ctx.reply_ok("Switch-out registered.")
async def switch_move(ctx: CommandContext):
system = await ctx.ensure_system()
if not ctx.has_next():
return CommandError("You must pass a time to move the switch to.", help=help.switch_move)
raise CommandError("You must pass a time to move the switch to.", help=help.switch_move)
# Parse the time to move to
new_time = dateparser.parse(ctx.remaining(), languages=["en"], settings={
@ -76,18 +76,18 @@ async def switch_move(ctx: CommandContext):
"RETURN_AS_TIMEZONE_AWARE": False
})
if not new_time:
return CommandError("'{}' can't be parsed as a valid time.".format(ctx.remaining()), help=help.switch_move)
raise CommandError("'{}' can't be parsed as a valid time.".format(ctx.remaining()), help=help.switch_move)
# Make sure the time isn't in the future
if new_time > datetime.utcnow():
return CommandError("Can't move switch to a time in the future.", help=help.switch_move)
raise CommandError("Can't move switch to a time in the future.", help=help.switch_move)
# Make sure it all runs in a big transaction for atomicity
async with ctx.conn.transaction():
# Get the last two switches to make sure the switch to move isn't before the second-last switch
last_two_switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, count=2)
if len(last_two_switches) == 0:
return CommandError("There are no registered switches for this system.")
raise CommandError("There are no registered switches for this system.")
last_timestamp, last_fronters = last_two_switches[0]
if len(last_two_switches) > 1:
@ -95,7 +95,7 @@ async def switch_move(ctx: CommandContext):
if new_time < second_last_timestamp:
time_str = humanize.naturaltime(pluralkit.utils.fix_time(second_last_timestamp))
return CommandError(
raise CommandError(
"Can't move switch to before last switch time ({}), as it would cause conflicts.".format(time_str))
# Display the confirmation message w/ humanized times
@ -108,14 +108,14 @@ async def switch_move(ctx: CommandContext):
# Confirm with user
switch_confirm_message = "This will move the latest switch ({}) from {} ({}) to {} ({}). Is this OK?".format(members, last_absolute, last_relative, new_absolute, new_relative)
if not await ctx.confirm_react(ctx.message.author, switch_confirm_message):
return CommandError("Switch move cancelled.")
raise CommandError("Switch move cancelled.")
# DB requires the actual switch ID which our utility method above doesn't return, do this manually
switch_id = (await db.front_history(ctx.conn, system.id, count=1))[0]["id"]
# Change the switch in the DB
await db.move_last_switch(ctx.conn, system.id, switch_id, new_time)
return CommandSuccess("Switch moved.")
await ctx.reply_ok("Switch moved.")

View File

@ -26,9 +26,9 @@ async def new_system(ctx: CommandContext):
try:
await System.create_system(ctx.conn, ctx.message.author.id, system_name)
except ExistingSystemError as e:
return CommandError(e.message)
raise CommandError(e.message)
return CommandSuccess("System registered! To begin adding members, use `pk;member new <name>`.")
await ctx.reply_ok("System registered! To begin adding members, use `pk;member new <name>`.")
async def system_set(ctx: CommandContext):
@ -54,7 +54,7 @@ async def system_set(ctx: CommandContext):
}
if property_name not in properties:
return CommandError(
raise CommandError(
"Unknown property {}. Allowed properties are {}.".format(property_name, ", ".join(properties.keys())),
help=help.edit_system)
@ -63,12 +63,11 @@ async def system_set(ctx: CommandContext):
try:
await properties[property_name](ctx.conn, value)
except PluralKitError as e:
return CommandError(e.message)
raise CommandError(e.message)
response = CommandSuccess("{} system {}.".format("Updated" if value else "Cleared", property_name))
await ctx.reply_ok("{} system {}.".format("Updated" if value else "Cleared", property_name))
# if prop == "avatar" and value:
# response.set_image(url=value)
return response
async def system_link(ctx: CommandContext):
@ -78,18 +77,18 @@ async def system_link(ctx: CommandContext):
# Find account to link
linkee = await utils.parse_mention(ctx.client, account_name)
if not linkee:
return CommandError("Account not found.")
raise CommandError("Account not found.")
# Make sure account doesn't already have a system
account_system = await System.get_by_account(ctx.conn, linkee.id)
if account_system:
return CommandError(AccountAlreadyLinkedError(account_system).message)
raise CommandError(AccountAlreadyLinkedError(account_system).message)
if not await ctx.confirm_react(linkee, "{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention)):
return CommandError("Account link cancelled.")
raise CommandError("Account link cancelled.")
await system.link_account(ctx.conn, linkee.id)
return CommandSuccess("Account linked to system.")
await ctx.reply_ok("Account linked to system.")
async def system_unlink(ctx: CommandContext):
@ -98,9 +97,9 @@ async def system_unlink(ctx: CommandContext):
try:
await system.unlink_account(ctx.conn, ctx.message.author.id)
except UnlinkingLastAccountError as e:
return CommandError(e.message)
raise CommandError(e.message)
return CommandSuccess("Account unlinked.")
await ctx.reply_ok("Account unlinked.")
async def system_fronter(ctx: CommandContext):
@ -149,10 +148,10 @@ async def system_delete(ctx: CommandContext):
delete_confirm_msg = "Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(
system.hid)
if not await ctx.confirm_text(ctx.message.author, ctx.message.channel, system.hid, delete_confirm_msg):
return CommandError("System deletion cancelled.")
raise CommandError("System deletion cancelled.")
await system.delete(ctx.conn)
return CommandSuccess("System deleted.")
await ctx.reply_ok("System deleted.")
async def system_frontpercent(ctx: CommandContext):
@ -166,7 +165,7 @@ async def system_frontpercent(ctx: CommandContext):
})
if not before:
return CommandError("Could not parse '{}' as a valid time.".format(ctx.remaining()))
raise CommandError("Could not parse '{}' as a valid time.".format(ctx.remaining()))
# If time is in the future, just kinda discard
if before and before > datetime.utcnow():
@ -177,7 +176,7 @@ async def system_frontpercent(ctx: CommandContext):
# Fetch list of switches
all_switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, 99999)
if not all_switches:
return CommandError("No switches registered to this system.")
raise CommandError("No switches registered to this system.")
# Cull the switches *ending* before the limit, if given
# We'll need to find the first switch starting before the limit, then cut off every switch *before* that

View File

@ -95,7 +95,7 @@ async def system_card(conn, client: discord.Client, system: System) -> discord.E
pages = [""]
for member in member_texts:
last_page = pages[-1]
new_page = last_page + "\n" + member
new_page = last_page + "\n" + member if last_page else member
if len(new_page) >= 1024:
pages.append(member)

View File

@ -64,6 +64,7 @@ class MemberNameTooLongError(PluralKitError):
else:
super().__init__("The maximum length of a member name is 32 characters.")
class InvalidColorError(PluralKitError):
def __init__(self):
super().__init__("Color must be a valid hex color. (eg. #ff0000)")