diff --git a/src/pluralkit/bot/commands/import_commands.py b/src/pluralkit/bot/commands/import_commands.py index cc467457..8d025278 100644 --- a/src/pluralkit/bot/commands/import_commands.py +++ b/src/pluralkit/bot/commands/import_commands.py @@ -1,165 +1,49 @@ +import aiohttp +import asyncio +import io +import json import os from datetime import datetime +from pluralkit.errors import TupperboxImportError from pluralkit.bot.commands import * - -def default_tupperware_id(): - if "TUPPERWARE_ID" in os.environ: - return int(os.environ["TUPPERWARE_ID"]) - return 431544605209788416 - - async def import_root(ctx: CommandContext): - # Only one import method rn, so why not default to Tupperware? - await import_tupperware(ctx) + # Only one import method rn, so why not default to Tupperbox? + await import_tupperbox(ctx) -async def import_tupperware(ctx: CommandContext): - # Check if there's a Tupperware bot on the server - # Main instance of TW has that ID, at least - tupperware_id = default_tupperware_id() - if ctx.has_next(): - try: - id_str = ctx.pop_str() - tupperware_id = int(id_str) - except ValueError: - raise CommandError("'{}' is not a valid ID.".format(id_str)) - - tupperware_member = ctx.message.guild.get_member(tupperware_id) - if not tupperware_member: - raise CommandError( - """This command only works in a server where the Tupperware bot is also present. - -If you're trying to import from a Tupperware instance other than the main one (which has the ID {}), pass the ID of that instance as a parameter.""".format( - default_tupperware_id())) - - # Make sure at the bot has send/read permissions here - channel_permissions = ctx.message.channel.permissions_for(tupperware_member) - if not (channel_permissions.read_messages and channel_permissions.send_messages): - # If it doesn't, throw error - raise CommandError("This command only works in a channel where the Tupperware bot has read/send access.") - - await ctx.reply("Please reply to this message with `tul!list` (or the server equivalent).") - - # Check to make sure the message is sent by Tupperware, and that the Tupperware response actually belongs to the correct user - def ensure_account(tw_msg): - if tw_msg.channel.id != ctx.message.channel.id: +async def import_tupperbox(ctx: CommandContext): + await ctx.reply("To import from Tupperbox, reply to this message with a `tuppers.json` file imported from Tupperbox.\n\nTo obtain such a file, type `tul!export` (or your server's equivalent).") + + def predicate(msg): + if msg.author.id != ctx.message.author.id: return False - - if tw_msg.author.id != tupperware_member.id: - return False - - if not tw_msg.embeds: - return False - - if not tw_msg.embeds[0].title: - return False - - return tw_msg.embeds[0].title.startswith( - "{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator)) - - tupperware_page_embeds = [] + if msg.attachments: + if msg.attachments[0].filename.endswith(".json"): + return True + return False try: - tw_msg: discord.Message = await ctx.client.wait_for("message", check=ensure_account, timeout=60.0 * 5) + message = await ctx.client.wait_for("message", check=predicate, timeout=60*5) except asyncio.TimeoutError: - raise CommandError("Tupperware import timed out.") + raise CommandError("Timed out. Try running `pk;import` again.") - tupperware_page_embeds.append(tw_msg.embeds[0].to_dict()) - - # Handle Tupperware pagination - def match_pagination(): - pagination_match = re.search(r"\(page (\d+)/(\d+), \d+ total\)", tw_msg.embeds[0].title) - if not pagination_match: - return None - return int(pagination_match.group(1)), int(pagination_match.group(2)) - - pagination_match = match_pagination() - if pagination_match: - status_msg = await ctx.reply("Multi-page member list found. Please manually scroll through all the pages.") - current_page = 0 - total_pages = 1 - - pages_found = {} - - # Keep trying to read the embed with new pages - last_found_time = datetime.utcnow() - while len(pages_found) < total_pages: - new_page, total_pages = match_pagination() - - # Put the found page in the pages dict - pages_found[new_page] = tw_msg.embeds[0].to_dict() - - # If this isn't the same page as last check, edit the status message - if new_page != current_page: - last_found_time = datetime.utcnow() - await status_msg.edit( - content="Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format( - len(pages_found), total_pages)) - current_page = new_page - - # And sleep a bit to prevent spamming the CPU - await asyncio.sleep(0.25) - - # Make sure it doesn't spin here for too long, time out after 30 seconds since last new page - if (datetime.utcnow() - last_found_time).seconds > 30: - raise CommandError("Pagination scan timed out.") - - # Now that we've got all the pages, put them in the embeds list - # Make sure to erase the original one we put in above too - tupperware_page_embeds = list([embed for page, embed in sorted(pages_found.items(), key=lambda x: x[0])]) - - # Also edit the status message to indicate we're now importing, and it may take a while because there's probably a lot of members - await status_msg.edit(content="All pages read. Now importing...") - - # Create new (nameless) system if there isn't any registered + s = io.BytesIO() + await message.attachments[0].save(s) + data = json.load(s) + system = await ctx.get_system() - if system is None: - system = await System.create_system(ctx.conn, ctx.message.author.id) - - for embed in tupperware_page_embeds: - for field in embed["fields"]: - name = field["name"] - lines = field["value"].split("\n") - - member_prefix = None - member_suffix = None - member_avatar = None - member_birthdate = None - member_description = None - - # Read the message format line by line - for line in lines: - if line.startswith("Brackets:"): - brackets = line[len("Brackets: "):] - member_prefix = brackets[:brackets.index("text")].strip() or None - member_suffix = brackets[brackets.index("text") + 4:].strip() or None - elif line.startswith("Avatar URL: "): - url = line[len("Avatar URL: "):] - member_avatar = url - elif line.startswith("Birthday: "): - bday_str = line[len("Birthday: "):] - bday = datetime.strptime(bday_str, "%a %b %d %Y") - if bday: - member_birthdate = bday.date() - elif line.startswith("Total messages sent: ") or line.startswith("Tag: "): - # Ignore this, just so it doesn't catch as the description - pass - else: - member_description = line - - # Read by name - TW doesn't allow name collisions so we're safe here (prevents dupes) - existing_member = await Member.get_member_by_name(ctx.conn, system.id, name) - if not existing_member: - # Or create a new member - existing_member = await system.create_member(ctx.conn, name) - - # Save the new stuff in the DB - await existing_member.set_proxy_tags(ctx.conn, member_prefix, member_suffix) - await existing_member.set_avatar(ctx.conn, member_avatar) - await existing_member.set_birthdate(ctx.conn, member_birthdate) - await existing_member.set_description(ctx.conn, member_description) - - await ctx.reply_ok( - "System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting.") + if not system: + system = await System.create_system(ctx.conn, account_id=ctx.author.id) + + result = await system.import_from_tupperbox(ctx.conn, data) + tag_note = "" + if len(result.tags) > 1: + tag_note = "\n\nPluralKit's tags work on a per-system basis. Since your Tupperbox members have more than one unique tag, PluralKit has not imported the tags. Set your system tag manually with `pk;system tag `." + + await ctx.reply_ok("Updated {} member{}, created {} member{}. Type `pk;system` to check!{}".format( + len(result.updated), "s" if len(result.updated) != 1 else "", + len(result.created), "s" if len(result.created) != 1 else "", + tag_note + )) \ No newline at end of file diff --git a/src/pluralkit/errors.py b/src/pluralkit/errors.py index b65f0fe8..ae084adb 100644 --- a/src/pluralkit/errors.py +++ b/src/pluralkit/errors.py @@ -98,3 +98,7 @@ class DuplicateSwitchMembersError(PluralKitError): class InvalidTimeZoneError(PluralKitError): def __init__(self, tz_name: str): super().__init__("Invalid time zone designation \"{}\".\n\nFor a list of valid time zone designations, see the `TZ database name` column here: .".format(tz_name)) + +class TupperboxImportError(PluralKitError): + def __init__(self): + super().__init__("Invalid Tupperbox file.") \ No newline at end of file diff --git a/src/pluralkit/system.py b/src/pluralkit/system.py index 7f1777c7..1ca7a015 100644 --- a/src/pluralkit/system.py +++ b/src/pluralkit/system.py @@ -26,6 +26,8 @@ def canonicalize_tz_name(name: str) -> Optional[str]: if name in name_map: return name_map[name] +class TupperboxImportResult(namedtuple("TupperboxImportResult", ["updated", "created", "tags"])): + pass class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "token", "created", "ui_tz"])): id: int @@ -254,6 +256,66 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a await db.update_system_field(conn, self.id, "ui_tz", tz.zone) return tz + async def import_from_tupperbox(self, conn, data: dict): + """ + Imports from a Tupperbox JSON data file. + :raises: TupperboxImportError + """ + if not "tuppers" in data: + raise errors.TupperboxImportError() + if not isinstance(data["tuppers"], list): + raise errors.TupperboxImportError() + + all_tags = set() + created_members = set() + updated_members = set() + for tupper in data["tuppers"]: + # Sanity check tupper fields + for field in ["name", "avatar_url", "brackets", "birthday", "description", "tag"]: + if field not in tupper: + raise errors.TupperboxImportError() + + # Find member by name, create if not exists + member_name = str(tupper["name"]) + member = await Member.get_member_by_name(conn, self.id, member_name) + if not member: + # And keep track of created members + created_members.add(member_name) + member = await self.create_member(conn, member_name) + else: + # Keep track of updated members + updated_members.add(member_name) + + # Set avatar + await member.set_avatar(conn, str(tupper["avatar_url"])) + + # Set proxy tags + if not (isinstance(tupper["brackets"], list) and len(tupper["brackets"]) >= 2): + raise errors.TupperboxImportError() + await member.set_proxy_tags(conn, str(tupper["brackets"][0]), str(tupper["brackets"][1])) + + # Set birthdate (input is in ISO-8601, first 10 characters is the date) + if tupper["birthday"]: + try: + await member.set_birthdate(conn, str(tupper["birthday"][:10])) + except errors.InvalidDateStringError: + pass + + # Set description + await member.set_description(conn, tupper["description"]) + + # Keep track of tag + all_tags.add(tupper["tag"]) + + # Since Tupperbox does tags on a per-member basis, we only apply a system tag if + # every member has the same tag (surprisingly common) + # If not, we just do nothing. (This will be reported in the caller function through the returned result) + if len(all_tags) == 1: + tag = list(all_tags)[0] + await self.set_tag(ctx.conn, tag) + + return TupperboxImportResult(updated=updated_members, created=created_members, tags=all_tags) + def to_json(self): return { "id": self.hid,