From cae394b4e858ea81de4cdeceb8bec5649a459363 Mon Sep 17 00:00:00 2001 From: Ske Date: Sun, 16 Sep 2018 13:46:22 +0200 Subject: [PATCH] Refactor member actions into members.py --- src/pluralkit/bot/commands/member_commands.py | 131 +++++++----------- src/pluralkit/bot/commands/system_commands.py | 13 +- src/pluralkit/bot/embeds.py | 1 + src/pluralkit/errors.py | 12 ++ src/pluralkit/member.py | 100 ++++++++++++- src/pluralkit/system.py | 61 +++++--- src/pluralkit/utils.py | 5 +- 7 files changed, 214 insertions(+), 109 deletions(-) diff --git a/src/pluralkit/bot/commands/member_commands.py b/src/pluralkit/bot/commands/member_commands.py index 7ad7a9e2..b8577300 100644 --- a/src/pluralkit/bot/commands/member_commands.py +++ b/src/pluralkit/bot/commands/member_commands.py @@ -1,9 +1,8 @@ from datetime import datetime -from urllib.parse import urlparse -import pluralkit.utils from pluralkit.bot import help from pluralkit.bot.commands import * +from pluralkit.errors import PluralKitError logger = logging.getLogger("pluralkit.commands") @@ -19,103 +18,80 @@ async def new_member(ctx: CommandContext): if not ctx.has_next(): return CommandError("You must pass a name for the new member.", help=help.add_member) - name = ctx.remaining() - bounds_error = utils.bounds_check_member_name(name, system.tag) - if bounds_error: - return CommandError(bounds_error) + new_name = ctx.remaining() - # TODO: figure out what to do if this errors out on collision on generate_hid - hid = pluralkit.utils.generate_hid() + try: + member = await system.create_member(ctx.conn, new_name) + except PluralKitError as e: + return CommandError(e.message) - # Insert member row - await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid) return CommandSuccess( - "Member \"{}\" (`{}`) registered! To register their proxy tags, use `pk;member proxy`.".format(name, hid)) + "Member \"{}\" (`{}`) registered! To register their proxy tags, use `pk;member proxy`.".format(new_name, member.hid)) async def member_set(ctx: CommandContext): system = await ctx.ensure_system() member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member)) - prop = ctx.pop_str(CommandError("You must pass a property name to set.", help=help.edit_member)) - allowed_properties = ["name", "description", "color", "pronouns", "birthdate", "avatar"] - db_properties = { - "name": "name", - "description": "description", - "color": "color", - "pronouns": "pronouns", - "birthdate": "birthday", - "avatar": "avatar_url" - } + property_name = ctx.pop_str(CommandError("You must pass a property name to set.", help=help.edit_member)) - if prop not in allowed_properties: - return CommandError( - "Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)), - help=help.edit_member) + async def name_setter(conn, new_name): + if not new_name: + raise CommandError("You can't clear the member name.") + await member.set_name(conn, system, new_name) - if ctx.has_next(): - value = ctx.remaining() + async def avatar_setter(conn, url): + if url: + user = await utils.parse_mention(ctx.client, url) + if user: + # Set the avatar to the mentioned user's avatar + # Discord pushes webp by default, which isn't supported by webhooks, but also hosts png alternatives + url = user.avatar_url.replace(".webp", ".png") - # Sanity/validity checks and type conversions - if prop == "name": - if re.search("", value): - return CommandError("Due to a Discord limitation, custom emojis aren't supported. Please use a standard emoji instead.") + await member.set_avatar(conn, url) - bounds_error = utils.bounds_check_member_name(value, system.tag) - if bounds_error: - return CommandError(bounds_error) - - if prop == "description": - if len(value) > 1024: - return CommandError("You can't have a description longer than 1024 characters.") - - if prop == "color": - match = re.fullmatch("#?([0-9A-Fa-f]{6})", value) - if not match: - return CommandError("Color must be a valid hex color (eg. #ff0000)") - - value = match.group(1).lower() - - if prop == "birthdate": + async def birthdate_setter(conn, date_str): + if date_str: try: - value = datetime.strptime(value, "%Y-%m-%d").date() + date = datetime.strptime(date_str, "%Y-%m-%d").date() except ValueError: try: # Try again, adding 0001 as a placeholder year # This is considered a "null year" and will be omitted from the info card # Useful if you want your birthday to be displayed yearless. - value = datetime.strptime("0001-" + value, "%Y-%m-%d").date() + date = datetime.strptime("0001-" + date_str, "%Y-%m-%d").date() except ValueError: - return CommandError("Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25).") + raise CommandError("Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25).") + else: + date = None - if prop == "avatar": - user = await utils.parse_mention(ctx.client, value) - if user: - # Set the avatar to the mentioned user's avatar - # Discord doesn't like webp, but also hosts png alternatives - value = user.avatar_url.replace(".webp", ".png") - else: - # Validate URL - u = urlparse(value) - if u.scheme in ["http", "https"] and u.netloc and u.path: - value = value - else: - return CommandError("Invalid image URL.") - else: - # Can't clear member name - if prop == "name": - return CommandError("You can't clear the member name.") + await member.set_birthdate(conn, date) - # Clear from DB - value = None + properties = { + "name": name_setter, + "description": member.set_description, + "avatar": avatar_setter, + "color": member.set_color, + "pronouns": member.set_pronouns, + "birthdate": birthdate_setter, + } - db_prop = db_properties[prop] - await db.update_member_field(ctx.conn, member_id=member.id, field=db_prop, value=value) + if property_name not in properties: + return CommandError( + "Unknown property {}. Allowed properties are {}.".format(property_name, ", ".join(properties.keys())), + help=help.edit_system) - response = CommandSuccess("{} {}'s {}.".format("Updated" if value else "Cleared", member.name, prop)) - #if prop == "avatar" and value: + value = ctx.remaining() or None + + try: + await properties[property_name](ctx.conn, value) + except PluralKitError as e: + return CommandError(e.message) + + response = CommandSuccess("{} member {}.".format("Updated" if value else "Cleared", property_name)) + # if prop == "avatar" and value: # response.set_image(url=value) - #if prop == "color" and value: + # if prop == "color" and value: # response.colour = int(value, 16) return response @@ -148,18 +124,17 @@ async def member_proxy(ctx: CommandContext): suffix = None async with ctx.conn.transaction(): - await db.update_member_field(ctx.conn, member_id=member.id, field="prefix", value=prefix) - await db.update_member_field(ctx.conn, member_id=member.id, field="suffix", value=suffix) + await member.set_proxy_tags(ctx.conn, prefix, suffix) return CommandSuccess("Proxy settings updated." if prefix or suffix else "Proxy settings cleared.") async def member_delete(ctx: CommandContext): await ctx.ensure_system() - member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member)) + member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.remove_member)) delete_confirm_msg = "Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(member.name, member.hid) if not await ctx.confirm_text(ctx.message.author, ctx.message.channel, member.hid, delete_confirm_msg): return CommandError("Member deletion cancelled.") - await db.delete_member(ctx.conn, member_id=member.id) + await member.delete(ctx.conn) return CommandSuccess("Member deleted.") diff --git a/src/pluralkit/bot/commands/system_commands.py b/src/pluralkit/bot/commands/system_commands.py index bcd60437..3cdd8438 100644 --- a/src/pluralkit/bot/commands/system_commands.py +++ b/src/pluralkit/bot/commands/system_commands.py @@ -36,13 +36,14 @@ async def system_set(ctx: CommandContext): property_name = ctx.pop_str(CommandError("You must pass a property name to set.", help=help.edit_system)) async def avatar_setter(conn, url): - user = await utils.parse_mention(ctx.client, url) - if user: - # Set the avatar to the mentioned user's avatar - # Discord pushes webp by default, which isn't supported by webhooks, but also hosts png alternatives - url = user.avatar_url.replace(".webp", ".png") + if url: + user = await utils.parse_mention(ctx.client, url) + if user: + # Set the avatar to the mentioned user's avatar + # Discord pushes webp by default, which isn't supported by webhooks, but also hosts png alternatives + url = user.avatar_url.replace(".webp", ".png") - await system.set_avatar(conn, url) + await system.set_avatar(conn, url) properties = { "name": system.set_name, diff --git a/src/pluralkit/bot/embeds.py b/src/pluralkit/bot/embeds.py index 6aee9797..9c48df8e 100644 --- a/src/pluralkit/bot/embeds.py +++ b/src/pluralkit/bot/embeds.py @@ -28,6 +28,7 @@ def status(text: str) -> discord.Embed: embed.colour = discord.Colour.blue() return embed + def exception_log(message_content, author_name, author_discriminator, server_id, channel_id) -> discord.Embed: embed = discord.Embed() embed.colour = discord.Colour.dark_red() diff --git a/src/pluralkit/errors.py b/src/pluralkit/errors.py index f63cff7f..0eed7d16 100644 --- a/src/pluralkit/errors.py +++ b/src/pluralkit/errors.py @@ -55,3 +55,15 @@ class AccountAlreadyLinkedError(PluralKitError): class UnlinkingLastAccountError(PluralKitError): def __init__(self): super().__init__("This is the only account on your system, so you can't unlink it.") + + +class MemberNameTooLongError(PluralKitError): + def __init__(self, tag_present: bool): + if tag_present: + super().__init__("The maximum length of a name plus the system tag is 32 characters. Please reduce the length of the tag, or choose a shorter member name.") + else: + super().__init__("The maximum length of a member name is 32 characters.") + +class InvalidColorError(PluralKitError): + def __init__(self): + super().__init__("Color must be a valid hex color. (eg. #ff0000)") \ No newline at end of file diff --git a/src/pluralkit/member.py b/src/pluralkit/member.py index 626623b2..677aca99 100644 --- a/src/pluralkit/member.py +++ b/src/pluralkit/member.py @@ -1,9 +1,17 @@ +import re from datetime import date, datetime from collections.__init__ import namedtuple +from typing import Optional + +from pluralkit import db, errors +from pluralkit.utils import validate_avatar_url_or_raise, contains_custom_emoji -class Member(namedtuple("Member", ["id", "hid", "system", "color", "avatar_url", "name", "birthday", "pronouns", "description", "prefix", "suffix", "created"])): +class Member(namedtuple("Member", + ["id", "hid", "system", "color", "avatar_url", "name", "birthday", "pronouns", "description", + "prefix", "suffix", "created"])): + """An immutable representation of a system member fetched from the database.""" id: int hid: str system: int @@ -28,4 +36,92 @@ class Member(namedtuple("Member", ["id", "hid", "system", "color", "avatar_url", "description": self.description, "prefix": self.prefix, "suffix": self.suffix - } \ No newline at end of file + } + + @staticmethod + async def get_member_by_name(conn, system_id: int, member_name: str) -> "Optional[Member]": + """Fetch a member by the given name in the given system from the database.""" + member = await db.get_member_by_name(conn, system_id, member_name) + return member + + @staticmethod + async def get_member_by_hid(conn, system_id: Optional[int], member_hid: str) -> "Optional[Member]": + """Fetch a member by the given hid from the database. If @`system_id` is present, will only return members from that system.""" + if system_id: + member = await db.get_member_by_hid_in_system(conn, system_id, member_hid) + else: + member = await db.get_member_by_hid(conn, member_hid) + + return member + + async def set_name(self, conn, system: "System", new_name: str): + """ + Set the name of a member. Requires the system to be passed in order to bounds check with the system tag. + :raises: MemberNameTooLongError, CustomEmojiError + """ + if contains_custom_emoji(new_name): + raise errors.CustomEmojiError() + + if len(new_name) > system.get_member_name_limit(): + raise errors.MemberNameTooLongError(tag_present=bool(system.tag)) + + await db.update_member_field(conn, self.id, "name", new_name) + + async def set_description(self, conn, new_description: Optional[str]): + """ + Set or clear the description of a member. + :raises: DescriptionTooLongError + """ + if new_description and len(new_description) > 1024: + raise errors.DescriptionTooLongError() + + await db.update_member_field(conn, self.id, "description", new_description) + + async def set_avatar(self, conn, new_avatar_url: Optional[str]): + """ + Set or clear the avatar of a member. + :raises: InvalidAvatarURLError + """ + if new_avatar_url: + validate_avatar_url_or_raise(new_avatar_url) + + await db.update_member_field(conn, self.id, "avatar_url", new_avatar_url) + + async def set_color(self, conn, new_color: Optional[str]): + """ + Set or clear the associated color of a member. + :raises: InvalidColorError + """ + cleaned_color = None + if new_color: + match = re.fullmatch("#?([0-9A-Fa-f]{6})", new_color) + if not match: + return errors.InvalidColorError() + + cleaned_color = match.group(1).lower() + + await db.update_member_field(conn, self.id, "color", cleaned_color) + + async def set_birthdate(self, conn, new_date: date): + """Set or clear the birthdate of a member. To hide the birth year, pass a year of 0001.""" + await db.update_member_field(conn, self.id, "birthday", new_date) + + async def set_pronouns(self, conn, new_pronouns: str): + """Set or clear the associated pronouns with a member.""" + await db.update_member_field(conn, self.id, "pronouns", new_pronouns) + + async def set_proxy_tags(self, conn, prefix: Optional[str], suffix: Optional[str]): + """ + Set the proxy tags for a member. Having no prefix *and* no suffix will disable proxying. + """ + # Make sure empty strings or other falsey values are actually None + prefix = prefix or None + suffix = suffix or None + + async with conn.transaction(): + await db.update_member_field(conn, member_id=self.id, field="prefix", value=prefix) + await db.update_member_field(conn, member_id=self.id, field="suffix", value=suffix) + + async def delete(self, conn): + """Delete this member from the database.""" + await db.delete_member(conn, self.id) \ No newline at end of file diff --git a/src/pluralkit/system.py b/src/pluralkit/system.py index 60bf140b..e1d5d231 100644 --- a/src/pluralkit/system.py +++ b/src/pluralkit/system.py @@ -4,6 +4,7 @@ from collections.__init__ import namedtuple from typing import Optional from pluralkit import db, errors +from pluralkit.member import Member from pluralkit.utils import generate_hid, contains_custom_emoji, validate_avatar_url_or_raise @@ -22,17 +23,18 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a @staticmethod async def create_system(conn, account_id: str, system_name: Optional[str] = None) -> "System": - existing_system = await System.get_by_account(conn, account_id) - if existing_system: - raise errors.ExistingSystemError() - - new_hid = generate_hid() - async with conn.transaction(): - new_system = await db.create_system(conn, system_name, new_hid) - await db.link_account(conn, new_system.id, account_id) + existing_system = await System.get_by_account(conn, account_id) + if existing_system: + raise errors.ExistingSystemError() - return new_system + new_hid = generate_hid() + + async with conn.transaction(): + new_system = await db.create_system(conn, system_name, new_hid) + await db.link_account(conn, new_system.id, account_id) + + return new_system async def set_name(self, conn, new_name: Optional[str]): await db.update_system_field(conn, self.id, "name", new_name) @@ -61,29 +63,48 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a if new_avatar_url: validate_avatar_url_or_raise(new_avatar_url) - await db.update_system_field(conn, self.id, "avatar", new_avatar_url) + await db.update_system_field(conn, self.id, "avatar_url", new_avatar_url) async def link_account(self, conn, new_account_id: str): - existing_system = await System.get_by_account(conn, new_account_id) + async with conn.transaction(): + existing_system = await System.get_by_account(conn, new_account_id) - if existing_system: - if existing_system.id == self.id: - raise errors.AccountInOwnSystemError() + if existing_system: + if existing_system.id == self.id: + raise errors.AccountInOwnSystemError() - raise errors.AccountAlreadyLinkedError(existing_system) + raise errors.AccountAlreadyLinkedError(existing_system) - await db.link_account(conn, self.id, new_account_id) + await db.link_account(conn, self.id, new_account_id) async def unlink_account(self, conn, account_id: str): - linked_accounts = await db.get_linked_accounts(conn, self.id) - if len(linked_accounts) == 1: - raise errors.UnlinkingLastAccountError() + async with conn.transaction(): + linked_accounts = await db.get_linked_accounts(conn, self.id) + if len(linked_accounts) == 1: + raise errors.UnlinkingLastAccountError() - await db.unlink_account(conn, self.id, account_id) + await db.unlink_account(conn, self.id, account_id) async def delete(self, conn): await db.remove_system(conn, self.id) + async def create_member(self, conn, member_name: str) -> Member: + # TODO: figure out what to do if this errors out on collision on generate_hid + new_hid = generate_hid() + + if len(member_name) > self.get_member_name_limit(): + raise errors.MemberNameTooLongError(tag_present=bool(self.tag)) + + member = await db.create_member(conn, self.id, member_name, new_hid) + return member + + def get_member_name_limit(self) -> int: + """Returns the maximum length a member's name or nickname is allowed to be. Depends on the system tag.""" + if self.tag: + return 32 - len(self.tag) - 1 + else: + return 32 + def to_json(self): return { "id": self.hid, diff --git a/src/pluralkit/utils.py b/src/pluralkit/utils.py index 4e33ca86..04dbce4b 100644 --- a/src/pluralkit/utils.py +++ b/src/pluralkit/utils.py @@ -8,7 +8,6 @@ from urllib.parse import urlparse from pluralkit import db from pluralkit.errors import InvalidAvatarURLError -from pluralkit.member import Member def fix_time(time: datetime): @@ -27,7 +26,7 @@ async def get_fronter_ids(conn, system_id) -> (List[int], datetime): return switches[0]["members"], switches[0]["timestamp"] -async def get_fronters(conn, system_id) -> (List[Member], datetime): +async def get_fronters(conn, system_id) -> (List["Member"], datetime): member_ids, timestamp = await get_fronter_ids(conn, system_id) # Collect in dict and then look up as list, to preserve return order @@ -35,7 +34,7 @@ async def get_fronters(conn, system_id) -> (List[Member], datetime): return [members[member_id] for member_id in member_ids], timestamp -async def get_front_history(conn, system_id, count) -> List[Tuple[datetime, List[Member]]]: +async def get_front_history(conn, system_id, count) -> List[Tuple[datetime, List["pluMember"]]]: # Get history from DB switches = await db.front_history(conn, system_id=system_id, count=count) if not switches: