From 799ea01f03f8984a85a3f9a030eeecd7c8f43d47 Mon Sep 17 00:00:00 2001 From: Ske Date: Sat, 14 Jul 2018 02:46:16 +0200 Subject: [PATCH] Refactor multifronter system --- bot/pluralkit/commands.py | 57 +++++++++++---------------------------- bot/pluralkit/utils.py | 47 ++++++++++++++++++++++++++------ 2 files changed, 55 insertions(+), 49 deletions(-) diff --git a/bot/pluralkit/commands.py b/bot/pluralkit/commands.py index b94cf1fe..89a6da7d 100644 --- a/bot/pluralkit/commands.py +++ b/bot/pluralkit/commands.py @@ -8,7 +8,7 @@ import humanize from pluralkit import db from pluralkit.bot import client, logger -from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed, parse_channel_mention, bounds_check_member_name +from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed, parse_channel_mention, bounds_check_member_name, get_fronters, get_fronter_ids, get_front_history @command(cmd="system", usage="[system]", description="Shows information about a system.", category="System commands") async def system_info(conn, message, args): @@ -159,31 +159,20 @@ async def system_fronter(conn, message, args): if system is None: return False, "Can't find system \"{}\".".format(args[0]) - # Get latest switch from DB - switches = await db.front_history(conn, system_id=system["id"], count=1) - if len(switches) == 0: - # Special case if empty - return True, make_default_embed(None).add_field(name="Current fronter", value="*(nobody)*") - - switch = switches[0] - - fronter_names = [] - if len(switch["members"]) > 0: - # Fetch member data from DB - members = await db.get_members(conn, switch["members"]) - fronter_names = [member["name"] for member in members] + fronters, timestamp = await get_fronters(conn, system_id=system["id"]) + fronter_names = [member["name"] for member in fronters] embed = make_default_embed(None) if len(fronter_names) == 0: embed.add_field(name="Current fronter", value="*nobody*") elif len(fronter_names) == 1: - embed.add_field(name="Current fronters", value=fronter_names[0]) + embed.add_field(name="Current fronter", value=fronter_names[0]) else: - embed.add_field(name="Current fronter", value=", ".join(fronter_names)) + embed.add_field(name="Current fronters", value=", ".join(fronter_names)) - since = switch["timestamp"] - embed.add_field(name="Since", value="{} ({})".format(since.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(since))) + if timestamp: + embed.add_field(name="Since", value="{} ({})".format(timestamp.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(timestamp))) return True, embed @command(cmd="system fronthistory", usage="[system]", description="Shows the past 10 switches in the system.", category="Switching commands") @@ -199,31 +188,21 @@ async def system_fronthistory(conn, message, args): if system is None: return False, "Can't find system \"{}\".".format(args[0]) - # Get list of past switches from DB - switches = await db.front_history(conn, system_id=system["id"], count=10) - - # Get all unique IDs referenced - all_member_ids = {id for switch in switches for id in switch["members"]} - - # And look them up in the database into a dict - all_members = {member["id"]: member for member in await db.get_members(conn, list(all_member_ids))} - lines = [] - for switch in switches: + for timestamp, members in await get_front_history(conn, system["id"], count=10): # Special case when no one's fronting - if len(switch["members"]) == 0: + if len(members) == 0: name = "*nobody*" else: - name = ", ".join([all_members[id]["name"] for id in switch["members"]]) + name = ", ".join([member["name"] for member in members]) # Make proper date string - since = switch["timestamp"] - time_text = since.isoformat(sep=" ", timespec="seconds") - rel_text = humanize.naturaltime(since) + time_text = timestamp.isoformat(sep=" ", timespec="seconds") + rel_text = humanize.naturaltime(timestamp) lines.append("**{}** ({}, {})".format(name, time_text, rel_text)) - embed = make_default_embed("\n".join(lines)) + embed = make_default_embed("\n".join(lines) or "(none)") embed.title = "Past switches" return True, embed @@ -449,11 +428,7 @@ async def switch_member(conn, message, args): member_ids = {member["id"] for member in members} - switches = await db.front_history(conn, system_id=system["id"], count=1) - fronter_ids = {} - if switches: - fronter_ids = set(switches[0]["members"]) - + fronter_ids = set((await get_fronter_ids(conn, system["id"]))[0]) if member_ids == fronter_ids: if len(members) == 1: return False, "{} is already fronting.".format(members[0]["name"]) @@ -478,8 +453,8 @@ async def switch_out(conn, message, args): return False, "No system is registered to this account." # Get current fronters - switches = await db.front_history(conn, system_id=system["id"], count=1) - if not switches or not switches[0]["members"]: + fronters, _ = await get_fronter_ids(conn, system_id=system["id"]) + if not fronters: return False, "There's already no one in front." # Log it, and don't log any members diff --git a/bot/pluralkit/utils.py b/bot/pluralkit/utils.py index 1499ae7d..5e04fe42 100644 --- a/bot/pluralkit/utils.py +++ b/bot/pluralkit/utils.py @@ -47,7 +47,40 @@ def parse_channel_mention(mention: str, server: discord.Server) -> discord.Chann except ValueError: return None +async def get_fronter_ids(conn, system_id): + switches = await db.front_history(conn, system_id=system_id, count=1) + if not switches: + return [], None + + if not switches[0]["members"]: + return [], switches[0]["timestamp"] + + return switches[0]["members"], switches[0]["timestamp"] +async def get_fronters(conn, system_id): + member_ids, timestamp = await get_fronter_ids(conn, system_id) + members = await db.get_members(conn, member_ids) + return members, timestamp + +async def get_front_history(conn, system_id, count): + # Get history from DB + switches = await db.front_history(conn, system_id=system_id, count=count) + if not switches: + return [] + + # Get all unique IDs referenced + all_member_ids = {id for switch in switches for id in switch["members"]} + + # And look them up in the database into a dict + all_members = {member["id"]: member for member in await db.get_members(conn, list(all_member_ids))} + + # Collect in array and return + out = [] + for switch in switches: + timestamp = switch["timestamp"] + members = [all_members[id] for id in switch["members"]] + out.append((timestamp, members)) + return out async def get_system_fuzzy(conn, key) -> asyncpg.Record: if isinstance(key, discord.User): @@ -169,14 +202,12 @@ async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Emb if system["tag"]: card.add_field(name="Tag", value=system["tag"]) - - switches = await db.front_history(conn, system_id=system["id"], count=1) - if switches and switches[0]["members"]: - members = await db.get_members(conn, switches[0]["members"]) - names = ", ".join([member["name"] for member in members]) - - fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switches[0]["timestamp"])) - card.add_field(name="Current fronter" if len(members) == 1 else "Current fronters", value=fronter_val) + + fronters, switch_time = await get_fronters(conn, system["id"]) + if fronters: + names = ", ".join([member["name"] for member in fronters]) + fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switch_time)) + card.add_field(name="Current fronter" if len(fronters) == 1 else "Current fronters", value=fronter_val) account_names = [] for account_id in await db.get_linked_accounts(conn, system_id=system["id"]):