Refactor multifronter system
This commit is contained in:
		@@ -8,7 +8,7 @@ import humanize
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from pluralkit import db
 | 
					from pluralkit import db
 | 
				
			||||||
from pluralkit.bot import client, logger
 | 
					from pluralkit.bot import client, logger
 | 
				
			||||||
from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed, parse_channel_mention, bounds_check_member_name
 | 
					from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map, make_default_embed, parse_channel_mention, bounds_check_member_name, get_fronters, get_fronter_ids, get_front_history
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@command(cmd="system", usage="[system]", description="Shows information about a system.", category="System commands")
 | 
					@command(cmd="system", usage="[system]", description="Shows information about a system.", category="System commands")
 | 
				
			||||||
async def system_info(conn, message, args):
 | 
					async def system_info(conn, message, args):
 | 
				
			||||||
@@ -159,31 +159,20 @@ async def system_fronter(conn, message, args):
 | 
				
			|||||||
        if system is None:
 | 
					        if system is None:
 | 
				
			||||||
            return False, "Can't find system \"{}\".".format(args[0])
 | 
					            return False, "Can't find system \"{}\".".format(args[0])
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    # Get latest switch from DB
 | 
					    fronters, timestamp = await get_fronters(conn, system_id=system["id"])
 | 
				
			||||||
    switches = await db.front_history(conn, system_id=system["id"], count=1)
 | 
					    fronter_names = [member["name"] for member in fronters]
 | 
				
			||||||
    if len(switches) == 0:
 | 
					 | 
				
			||||||
        # Special case if empty
 | 
					 | 
				
			||||||
        return True, make_default_embed(None).add_field(name="Current fronter", value="*(nobody)*")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    switch = switches[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    fronter_names = []
 | 
					 | 
				
			||||||
    if len(switch["members"]) > 0:
 | 
					 | 
				
			||||||
        # Fetch member data from DB
 | 
					 | 
				
			||||||
        members = await db.get_members(conn, switch["members"])
 | 
					 | 
				
			||||||
        fronter_names = [member["name"] for member in members]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    embed = make_default_embed(None)
 | 
					    embed = make_default_embed(None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if len(fronter_names) == 0:
 | 
					    if len(fronter_names) == 0:
 | 
				
			||||||
        embed.add_field(name="Current fronter", value="*nobody*")
 | 
					        embed.add_field(name="Current fronter", value="*nobody*")
 | 
				
			||||||
    elif len(fronter_names) == 1:
 | 
					    elif len(fronter_names) == 1:
 | 
				
			||||||
        embed.add_field(name="Current fronters", value=fronter_names[0])
 | 
					        embed.add_field(name="Current fronter", value=fronter_names[0])
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        embed.add_field(name="Current fronter", value=", ".join(fronter_names))
 | 
					        embed.add_field(name="Current fronters", value=", ".join(fronter_names))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    since = switch["timestamp"]
 | 
					    if timestamp:
 | 
				
			||||||
    embed.add_field(name="Since", value="{} ({})".format(since.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(since)))
 | 
					        embed.add_field(name="Since", value="{} ({})".format(timestamp.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(timestamp)))
 | 
				
			||||||
    return True, embed
 | 
					    return True, embed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@command(cmd="system fronthistory", usage="[system]", description="Shows the past 10 switches in the system.", category="Switching commands")
 | 
					@command(cmd="system fronthistory", usage="[system]", description="Shows the past 10 switches in the system.", category="Switching commands")
 | 
				
			||||||
@@ -199,31 +188,21 @@ async def system_fronthistory(conn, message, args):
 | 
				
			|||||||
        if system is None:
 | 
					        if system is None:
 | 
				
			||||||
            return False, "Can't find system \"{}\".".format(args[0])
 | 
					            return False, "Can't find system \"{}\".".format(args[0])
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    # Get list of past switches from DB
 | 
					 | 
				
			||||||
    switches = await db.front_history(conn, system_id=system["id"], count=10)
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    # Get all unique IDs referenced
 | 
					 | 
				
			||||||
    all_member_ids = {id for switch in switches for id in switch["members"]}
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    # And look them up in the database into a dict
 | 
					 | 
				
			||||||
    all_members = {member["id"]: member for member in await db.get_members(conn, list(all_member_ids))}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    lines = []
 | 
					    lines = []
 | 
				
			||||||
    for switch in switches:
 | 
					    for timestamp, members in await get_front_history(conn, system["id"], count=10):
 | 
				
			||||||
        # Special case when no one's fronting
 | 
					        # Special case when no one's fronting
 | 
				
			||||||
        if len(switch["members"]) == 0:
 | 
					        if len(members) == 0:
 | 
				
			||||||
            name = "*nobody*"
 | 
					            name = "*nobody*"
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            name = ", ".join([all_members[id]["name"] for id in switch["members"]])
 | 
					            name = ", ".join([member["name"] for member in members])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Make proper date string
 | 
					        # Make proper date string
 | 
				
			||||||
        since = switch["timestamp"]
 | 
					        time_text = timestamp.isoformat(sep=" ", timespec="seconds")
 | 
				
			||||||
        time_text = since.isoformat(sep=" ", timespec="seconds")
 | 
					        rel_text = humanize.naturaltime(timestamp)
 | 
				
			||||||
        rel_text = humanize.naturaltime(since)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        lines.append("**{}** ({}, {})".format(name, time_text, rel_text))
 | 
					        lines.append("**{}** ({}, {})".format(name, time_text, rel_text))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    embed = make_default_embed("\n".join(lines))
 | 
					    embed = make_default_embed("\n".join(lines) or "(none)")
 | 
				
			||||||
    embed.title = "Past switches"
 | 
					    embed.title = "Past switches"
 | 
				
			||||||
    return True, embed
 | 
					    return True, embed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -449,11 +428,7 @@ async def switch_member(conn, message, args):
 | 
				
			|||||||
    
 | 
					    
 | 
				
			||||||
    member_ids = {member["id"] for member in members}
 | 
					    member_ids = {member["id"] for member in members}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    switches = await db.front_history(conn, system_id=system["id"], count=1)
 | 
					    fronter_ids = set((await get_fronter_ids(conn, system["id"]))[0])
 | 
				
			||||||
    fronter_ids = {}
 | 
					 | 
				
			||||||
    if switches:
 | 
					 | 
				
			||||||
        fronter_ids = set(switches[0]["members"])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if member_ids == fronter_ids:
 | 
					    if member_ids == fronter_ids:
 | 
				
			||||||
        if len(members) == 1:
 | 
					        if len(members) == 1:
 | 
				
			||||||
            return False, "{} is already fronting.".format(members[0]["name"])
 | 
					            return False, "{} is already fronting.".format(members[0]["name"])
 | 
				
			||||||
@@ -478,8 +453,8 @@ async def switch_out(conn, message, args):
 | 
				
			|||||||
        return False, "No system is registered to this account."
 | 
					        return False, "No system is registered to this account."
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Get current fronters
 | 
					    # Get current fronters
 | 
				
			||||||
    switches = await db.front_history(conn, system_id=system["id"], count=1)
 | 
					    fronters, _ = await get_fronter_ids(conn, system_id=system["id"])
 | 
				
			||||||
    if not switches or not switches[0]["members"]:
 | 
					    if not fronters:
 | 
				
			||||||
        return False, "There's already no one in front."
 | 
					        return False, "There's already no one in front."
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Log it, and don't log any members
 | 
					    # Log it, and don't log any members
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -47,7 +47,40 @@ def parse_channel_mention(mention: str, server: discord.Server) -> discord.Chann
 | 
				
			|||||||
    except ValueError:
 | 
					    except ValueError:
 | 
				
			||||||
        return None
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def get_fronter_ids(conn, system_id):
 | 
				
			||||||
 | 
					    switches = await db.front_history(conn, system_id=system_id, count=1)
 | 
				
			||||||
 | 
					    if not switches:
 | 
				
			||||||
 | 
					        return [], None
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
 | 
					    if not switches[0]["members"]:
 | 
				
			||||||
 | 
					        return [], switches[0]["timestamp"]
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    return switches[0]["members"], switches[0]["timestamp"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def get_fronters(conn, system_id):
 | 
				
			||||||
 | 
					    member_ids, timestamp = await get_fronter_ids(conn, system_id)
 | 
				
			||||||
 | 
					    members = await db.get_members(conn, member_ids)
 | 
				
			||||||
 | 
					    return members, timestamp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def get_front_history(conn, system_id, count):
 | 
				
			||||||
 | 
					    # Get history from DB
 | 
				
			||||||
 | 
					    switches = await db.front_history(conn, system_id=system_id, count=count)
 | 
				
			||||||
 | 
					    if not switches:
 | 
				
			||||||
 | 
					        return []
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # Get all unique IDs referenced
 | 
				
			||||||
 | 
					    all_member_ids = {id for switch in switches for id in switch["members"]}
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # And look them up in the database into a dict
 | 
				
			||||||
 | 
					    all_members = {member["id"]: member for member in await db.get_members(conn, list(all_member_ids))}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Collect in array and return
 | 
				
			||||||
 | 
					    out = []
 | 
				
			||||||
 | 
					    for switch in switches:
 | 
				
			||||||
 | 
					        timestamp = switch["timestamp"]
 | 
				
			||||||
 | 
					        members = [all_members[id] for id in switch["members"]]
 | 
				
			||||||
 | 
					        out.append((timestamp, members))
 | 
				
			||||||
 | 
					    return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def get_system_fuzzy(conn, key) -> asyncpg.Record:
 | 
					async def get_system_fuzzy(conn, key) -> asyncpg.Record:
 | 
				
			||||||
    if isinstance(key, discord.User):
 | 
					    if isinstance(key, discord.User):
 | 
				
			||||||
@@ -170,13 +203,11 @@ async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Emb
 | 
				
			|||||||
    if system["tag"]:
 | 
					    if system["tag"]:
 | 
				
			||||||
        card.add_field(name="Tag", value=system["tag"])
 | 
					        card.add_field(name="Tag", value=system["tag"])
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    switches = await db.front_history(conn, system_id=system["id"], count=1)
 | 
					    fronters, switch_time = await get_fronters(conn, system["id"])
 | 
				
			||||||
    if switches and switches[0]["members"]:
 | 
					    if fronters:
 | 
				
			||||||
        members = await db.get_members(conn, switches[0]["members"])
 | 
					        names = ", ".join([member["name"] for member in fronters])
 | 
				
			||||||
        names = ", ".join([member["name"] for member in members])
 | 
					        fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switch_time))
 | 
				
			||||||
        
 | 
					        card.add_field(name="Current fronter" if len(fronters) == 1 else "Current fronters", value=fronter_val)
 | 
				
			||||||
        fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switches[0]["timestamp"]))
 | 
					 | 
				
			||||||
        card.add_field(name="Current fronter" if len(members) == 1 else "Current fronters", value=fronter_val)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    account_names = []
 | 
					    account_names = []
 | 
				
			||||||
    for account_id in await db.get_linked_accounts(conn, system_id=system["id"]):
 | 
					    for account_id in await db.get_linked_accounts(conn, system_id=system["id"]):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user