diff --git a/bot/pluralkit/commands.py b/bot/pluralkit/commands.py index 3dd384fc..b94cf1fe 100644 --- a/bot/pluralkit/commands.py +++ b/bot/pluralkit/commands.py @@ -1,4 +1,5 @@ from datetime import datetime +import itertools import re from urllib.parse import urlparse @@ -145,7 +146,7 @@ async def system_unlink(conn, message, args): await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id) return True, "Account unlinked." -@command(cmd="system fronter", usage="[system]", description="Gets the current fronter in the system.", category="Switching commands") +@command(cmd="system fronter", usage="[system]", description="Gets the current fronter(s) in the system.", category="Switching commands") async def system_fronter(conn, message, args): if len(args) == 0: system = await db.get_system_by_account(conn, message.author.id) @@ -158,21 +159,30 @@ async def system_fronter(conn, message, args): if system is None: return False, "Can't find system \"{}\".".format(args[0]) - current_fronter = await db.current_fronter(conn, system_id=system["id"]) - if not current_fronter: + # Get latest switch from DB + switches = await db.front_history(conn, system_id=system["id"], count=1) + if len(switches) == 0: + # Special case if empty return True, make_default_embed(None).add_field(name="Current fronter", value="*(nobody)*") - fronter_name = "*(nobody)*" - if current_fronter["member"]: - member = await db.get_member(conn, member_id=current_fronter["member"]) - fronter_name = member["name"] - if current_fronter["member_del"]: - fronter_name = "*(deleted member)*" + switch = switches[0] - since = current_fronter["timestamp"] + fronter_names = [] + if len(switch["members"]) > 0: + # Fetch member data from DB + members = await db.get_members(conn, switch["members"]) + fronter_names = [member["name"] for member in members] embed = make_default_embed(None) - embed.add_field(name="Current fronter", value=fronter_name) + + if len(fronter_names) == 0: + embed.add_field(name="Current fronter", value="*nobody*") + elif len(fronter_names) == 1: + embed.add_field(name="Current fronters", value=fronter_names[0]) + else: + embed.add_field(name="Current fronter", value=", ".join(fronter_names)) + + since = switch["timestamp"] embed.add_field(name="Since", value="{} ({})".format(since.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(since))) return True, embed @@ -189,15 +199,29 @@ async def system_fronthistory(conn, message, args): if system is None: return False, "Can't find system \"{}\".".format(args[0]) - switches = await db.past_fronters(conn, system_id=system["id"], amount=10) + # Get list of past switches from DB + switches = await db.front_history(conn, system_id=system["id"], count=10) + # Get all unique IDs referenced + all_member_ids = {id for switch in switches for id in switch["members"]} + + # And look them up in the database into a dict + all_members = {member["id"]: member for member in await db.get_members(conn, list(all_member_ids))} + lines = [] for switch in switches: + # Special case when no one's fronting + if len(switch["members"]) == 0: + name = "*nobody*" + else: + name = ", ".join([all_members[id]["name"] for id in switch["members"]]) + + # Make proper date string since = switch["timestamp"] time_text = since.isoformat(sep=" ", timespec="seconds") rel_text = humanize.naturaltime(since) - lines.append("**{}** ({}, at {})".format(switch["name"], time_text, rel_text)) + lines.append("**{}** ({}, {})".format(name, time_text, rel_text)) embed = make_default_embed("\n".join(lines)) embed.title = "Past switches" @@ -292,16 +316,16 @@ async def member_set(conn, message, member, args): return False, "Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25)." if prop == "avatar": - user = await parse_mention(args[0]) + user = await parse_mention(value) if user: # Set the avatar to the mentioned user's avatar # Discord doesn't like webp, but also hosts png alternatives value = user.avatar_url.replace(".webp", ".png") else: # Validate URL - u = urlparse(args[0]) + u = urlparse(value) if u.scheme in ["http", "https"] and u.netloc and u.path: - value = args[0] + value = value else: return False, "Invalid URL." else: @@ -405,7 +429,7 @@ async def message_info(conn, message, args): await client.send_message(message.channel, embed=embed) return True -@command(cmd="switch", usage="", description="Registers a switch and changes the current fronter.", category="Switching commands") +@command(cmd="switch", usage=" [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands") async def switch_member(conn, message, args): if len(args) == 0: return False @@ -415,19 +439,36 @@ async def switch_member(conn, message, args): if system is None: return False, "No system is registered to this account." - # Find the member - member = await get_member_fuzzy(conn, system["id"], " ".join(args)) - if not member: - return False, "Couldn't find member \"{}\".".format(args[0]) + members = [] + for member_name in args: + # Find the member + member = await get_member_fuzzy(conn, system["id"], member_name) + if not member: + return False, "Couldn't find member \"{}\".".format(member_name) + members.append(member) + + member_ids = {member["id"] for member in members} - # Get current fronter - current_fronter = await db.current_fronter(conn, system_id=system["id"]) - if current_fronter and current_fronter["member"] == member["id"]: - return False, "Member \"{}\" is already fronting.".format(member["name"]) + switches = await db.front_history(conn, system_id=system["id"], count=1) + fronter_ids = {} + if switches: + fronter_ids = set(switches[0]["members"]) + + if member_ids == fronter_ids: + if len(members) == 1: + return False, "{} is already fronting.".format(members[0]["name"]) + return False, "Members {} are already fronting.".format(", ".join([m["name"] for m in members])) # Log the switch - await db.add_switch(conn, system_id=system["id"], member_id=member["id"]) - return True, "Switch registered. Current fronter is now {}.".format(member["name"]) + async with conn.transaction(): + switch_id = await db.add_switch(conn, system_id=system["id"]) + for member in members: + await db.add_switch_member(conn, switch_id=switch_id, member_id=member["id"]) + + if len(members) == 1: + return True, "Switch registered. Current fronter is now {}.".format(member["name"]) + else: + return True, "Switch registered. Current fronters are now {}.".format(", ".join([m["name"] for m in members])) @command(cmd="switch out", description="Registers a switch out, and leaves current fronter blank.", category="Switching commands") async def switch_out(conn, message, args): @@ -436,13 +477,13 @@ async def switch_out(conn, message, args): if system is None: return False, "No system is registered to this account." - # Get current fronter - current_fronter = await db.current_fronter(conn, system_id=system["id"]) - if not current_fronter or not current_fronter["member"]: + # Get current fronters + switches = await db.front_history(conn, system_id=system["id"], count=1) + if not switches or not switches[0]["members"]: return False, "There's already no one in front." - # Log it - await db.add_switch(conn, system_id=system["id"], member_id=None) + # Log it, and don't log any members + await db.add_switch(conn, system_id=system["id"]) return True, "Switch-out registered." @command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands") diff --git a/bot/pluralkit/db.py b/bot/pluralkit/db.py index a91ec296..dce54e0f 100644 --- a/bot/pluralkit/db.py +++ b/bot/pluralkit/db.py @@ -104,6 +104,9 @@ async def get_member_by_hid(conn, member_hid: str): async def get_member(conn, member_id: int): return await conn.fetchrow("select * from members where id = $1", member_id) +@db_wrap +async def get_members(conn, members: list): + return await conn.fetch("select * from members where id = any($1)", members) @db_wrap async def get_message(conn, message_id: str): @@ -186,27 +189,43 @@ async def delete_message(conn, message_id: str): logger.debug("Deleting message (id={})".format(message_id)) await conn.execute("delete from messages where mid = $1", int(message_id)) -@db_wrap -async def current_fronter(conn, system_id: int): - return await conn.fetchrow("""select *, members.name - from switches - left outer join members on (members.id = switches.member) -- Left outer join instead of normal join - makes name = null instead of just ignoring the row - where switches.system = $1 - order by timestamp desc""", system_id) +# @db_wrap +# async def front_history(conn, system_id: int, count: int): +# return await conn.fetch("""select +# switches.timestamp, members.name, members.id, switches.id as switch_id +# from +# ( +# select * from switches where system = $1 order by timestamp desc limit $2 +# ) as switches +# left outer join switch_members +# on switch_members.switch = switches.id +# left outer join members +# on switch_members.member = members.id +# order by switches.timestamp desc""", system_id, count) @db_wrap -async def past_fronters(conn, system_id: int, amount: int): - return await conn.fetch("""select *, members.name +async def front_history(conn, system_id: int, count: int): + return await conn.fetch("""select + switches.*, + array( + select member from switch_members + where switch_members.switch = switches.id + ) as members from switches - left outer join members on (members.id = switches.member) -- (see above) where switches.system = $1 - order by timestamp - desc limit $2""", system_id, amount) + order by switches.timestamp desc + limit $2""", system_id, count) @db_wrap -async def add_switch(conn, system_id: int, member_id: int): - logger.debug("Adding switch (system={}, member={})".format(system_id, member_id)) - return await conn.execute("insert into switches (system, member) values ($1, $2)", system_id, member_id) +async def add_switch(conn, system_id: int): + logger.debug("Adding switch (system={})".format(system_id)) + res = await conn.fetchrow("insert into switches (system) values ($1) returning *", system_id) + return res["id"] + +@db_wrap +async def add_switch_member(conn, switch_id: int, member_id: int): + logger.debug("Adding switch member (switch={}, member={})".format(switch_id, member_id)) + await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id) @db_wrap async def get_server_info(conn, server_id: str): @@ -254,9 +273,12 @@ async def create_tables(conn): await conn.execute("""create table if not exists switches ( id serial primary key, system serial not null references systems(id) on delete cascade, - member serial references members(id) on delete restrict, - timestamp timestamp not null default current_timestamp, - member_del bool not null default false + timestamp timestamp not null default current_timestamp + )""") + await conn.execute("""create table if not exists switch_members ( + id serial primary key, + switch serial not null references switches(id) on delete cascade, + member serial not null references members(id) on delete cascade )""") await conn.execute("""create table if not exists webhooks ( channel bigint primary key, diff --git a/bot/pluralkit/utils.py b/bot/pluralkit/utils.py index cc4f4c94..1499ae7d 100644 --- a/bot/pluralkit/utils.py +++ b/bot/pluralkit/utils.py @@ -170,10 +170,13 @@ async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Emb if system["tag"]: card.add_field(name="Tag", value=system["tag"]) - current_fronter = await db.current_fronter(conn, system_id=system["id"]) - if current_fronter and current_fronter["member"]: - fronter_val = "{} (for {})".format(current_fronter["name"], humanize.naturaldelta(current_fronter["timestamp"])) - card.add_field(name="Current fronter", value=fronter_val) + switches = await db.front_history(conn, system_id=system["id"], count=1) + if switches and switches[0]["members"]: + members = await db.get_members(conn, switches[0]["members"]) + names = ", ".join([member["name"] for member in members]) + + fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switches[0]["timestamp"])) + card.add_field(name="Current fronter" if len(members) == 1 else "Current fronters", value=fronter_val) account_names = [] for account_id in await db.get_linked_accounts(conn, system_id=system["id"]):