Added multi-fronter support
This commit is contained in:
		| @@ -1,4 +1,5 @@ | ||||
| from datetime import datetime | ||||
| import itertools | ||||
| import re | ||||
| from urllib.parse import urlparse | ||||
|  | ||||
| @@ -145,7 +146,7 @@ async def system_unlink(conn, message, args): | ||||
|         await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id) | ||||
|         return True, "Account unlinked." | ||||
|  | ||||
| @command(cmd="system fronter", usage="[system]", description="Gets the current fronter in the system.", category="Switching commands") | ||||
| @command(cmd="system fronter", usage="[system]", description="Gets the current fronter(s) in the system.", category="Switching commands") | ||||
| async def system_fronter(conn, message, args): | ||||
|     if len(args) == 0: | ||||
|         system = await db.get_system_by_account(conn, message.author.id) | ||||
| @@ -158,21 +159,30 @@ async def system_fronter(conn, message, args): | ||||
|         if system is None: | ||||
|             return False, "Can't find system \"{}\".".format(args[0]) | ||||
|      | ||||
|     current_fronter = await db.current_fronter(conn, system_id=system["id"]) | ||||
|     if not current_fronter: | ||||
|     # Get latest switch from DB | ||||
|     switches = await db.front_history(conn, system_id=system["id"], count=1) | ||||
|     if len(switches) == 0: | ||||
|         # Special case if empty | ||||
|         return True, make_default_embed(None).add_field(name="Current fronter", value="*(nobody)*") | ||||
|  | ||||
|     fronter_name = "*(nobody)*" | ||||
|     if current_fronter["member"]: | ||||
|         member = await db.get_member(conn, member_id=current_fronter["member"]) | ||||
|         fronter_name = member["name"] | ||||
|     if current_fronter["member_del"]: | ||||
|         fronter_name = "*(deleted member)*" | ||||
|     switch = switches[0] | ||||
|  | ||||
|     since = current_fronter["timestamp"] | ||||
|     fronter_names = [] | ||||
|     if len(switch["members"]) > 0: | ||||
|         # Fetch member data from DB | ||||
|         members = await db.get_members(conn, switch["members"]) | ||||
|         fronter_names = [member["name"] for member in members] | ||||
|  | ||||
|     embed = make_default_embed(None) | ||||
|     embed.add_field(name="Current fronter", value=fronter_name) | ||||
|  | ||||
|     if len(fronter_names) == 0: | ||||
|         embed.add_field(name="Current fronter", value="*nobody*") | ||||
|     elif len(fronter_names) == 1: | ||||
|         embed.add_field(name="Current fronters", value=fronter_names[0]) | ||||
|     else: | ||||
|         embed.add_field(name="Current fronter", value=", ".join(fronter_names)) | ||||
|  | ||||
|     since = switch["timestamp"] | ||||
|     embed.add_field(name="Since", value="{} ({})".format(since.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(since))) | ||||
|     return True, embed | ||||
|  | ||||
| @@ -189,15 +199,29 @@ async def system_fronthistory(conn, message, args): | ||||
|         if system is None: | ||||
|             return False, "Can't find system \"{}\".".format(args[0]) | ||||
|      | ||||
|     switches = await db.past_fronters(conn, system_id=system["id"], amount=10) | ||||
|     # Get list of past switches from DB | ||||
|     switches = await db.front_history(conn, system_id=system["id"], count=10) | ||||
|      | ||||
|     # Get all unique IDs referenced | ||||
|     all_member_ids = {id for switch in switches for id in switch["members"]} | ||||
|      | ||||
|     # And look them up in the database into a dict | ||||
|     all_members = {member["id"]: member for member in await db.get_members(conn, list(all_member_ids))} | ||||
|  | ||||
|     lines = [] | ||||
|     for switch in switches: | ||||
|         # Special case when no one's fronting | ||||
|         if len(switch["members"]) == 0: | ||||
|             name = "*nobody*" | ||||
|         else: | ||||
|             name = ", ".join([all_members[id]["name"] for id in switch["members"]]) | ||||
|  | ||||
|         # Make proper date string | ||||
|         since = switch["timestamp"] | ||||
|         time_text = since.isoformat(sep=" ", timespec="seconds") | ||||
|         rel_text = humanize.naturaltime(since) | ||||
|  | ||||
|         lines.append("**{}** ({}, at {})".format(switch["name"], time_text, rel_text)) | ||||
|         lines.append("**{}** ({}, {})".format(name, time_text, rel_text)) | ||||
|  | ||||
|     embed = make_default_embed("\n".join(lines)) | ||||
|     embed.title = "Past switches" | ||||
| @@ -292,16 +316,16 @@ async def member_set(conn, message, member, args): | ||||
|                 return False, "Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25)." | ||||
|  | ||||
|         if prop == "avatar": | ||||
|             user = await parse_mention(args[0]) | ||||
|             user = await parse_mention(value) | ||||
|             if user: | ||||
|                 # Set the avatar to the mentioned user's avatar | ||||
|                 # Discord doesn't like webp, but also hosts png alternatives | ||||
|                 value = user.avatar_url.replace(".webp", ".png") | ||||
|             else: | ||||
|                 # Validate URL | ||||
|                 u = urlparse(args[0]) | ||||
|                 u = urlparse(value) | ||||
|                 if u.scheme in ["http", "https"] and u.netloc and u.path: | ||||
|                     value = args[0] | ||||
|                     value = value | ||||
|                 else: | ||||
|                     return False, "Invalid URL." | ||||
|     else: | ||||
| @@ -405,7 +429,7 @@ async def message_info(conn, message, args): | ||||
|     await client.send_message(message.channel, embed=embed) | ||||
|     return True | ||||
|  | ||||
| @command(cmd="switch", usage="<name|id>", description="Registers a switch and changes the current fronter.", category="Switching commands") | ||||
| @command(cmd="switch", usage="<name|id> [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands") | ||||
| async def switch_member(conn, message, args): | ||||
|     if len(args) == 0: | ||||
|         return False | ||||
| @@ -415,19 +439,36 @@ async def switch_member(conn, message, args): | ||||
|     if system is None: | ||||
|         return False, "No system is registered to this account." | ||||
|  | ||||
|     # Find the member | ||||
|     member = await get_member_fuzzy(conn, system["id"], " ".join(args)) | ||||
|     if not member: | ||||
|         return False, "Couldn't find member \"{}\".".format(args[0]) | ||||
|     members = [] | ||||
|     for member_name in args: | ||||
|         # Find the member | ||||
|         member = await get_member_fuzzy(conn, system["id"], member_name) | ||||
|         if not member: | ||||
|             return False, "Couldn't find member \"{}\".".format(member_name) | ||||
|         members.append(member) | ||||
|      | ||||
|     member_ids = {member["id"] for member in members} | ||||
|  | ||||
|     # Get current fronter | ||||
|     current_fronter = await db.current_fronter(conn, system_id=system["id"]) | ||||
|     if current_fronter and current_fronter["member"] == member["id"]: | ||||
|         return False, "Member \"{}\" is already fronting.".format(member["name"]) | ||||
|     switches = await db.front_history(conn, system_id=system["id"], count=1) | ||||
|     fronter_ids = {} | ||||
|     if switches: | ||||
|         fronter_ids = set(switches[0]["members"]) | ||||
|  | ||||
|     if member_ids == fronter_ids: | ||||
|         if len(members) == 1: | ||||
|             return False, "{} is already fronting.".format(members[0]["name"]) | ||||
|         return False, "Members {} are already fronting.".format(", ".join([m["name"] for m in members])) | ||||
|      | ||||
|     # Log the switch | ||||
|     await db.add_switch(conn, system_id=system["id"], member_id=member["id"]) | ||||
|     return True, "Switch registered. Current fronter is now {}.".format(member["name"]) | ||||
|     async with conn.transaction(): | ||||
|         switch_id = await db.add_switch(conn, system_id=system["id"]) | ||||
|         for member in members: | ||||
|             await db.add_switch_member(conn, switch_id=switch_id, member_id=member["id"]) | ||||
|  | ||||
|     if len(members) == 1: | ||||
|         return True, "Switch registered. Current fronter is now {}.".format(member["name"]) | ||||
|     else: | ||||
|         return True, "Switch registered. Current fronters are now {}.".format(", ".join([m["name"] for m in members])) | ||||
|  | ||||
| @command(cmd="switch out", description="Registers a switch out, and leaves current fronter blank.", category="Switching commands") | ||||
| async def switch_out(conn, message, args): | ||||
| @@ -436,13 +477,13 @@ async def switch_out(conn, message, args): | ||||
|     if system is None: | ||||
|         return False, "No system is registered to this account." | ||||
|  | ||||
|     # Get current fronter | ||||
|     current_fronter = await db.current_fronter(conn, system_id=system["id"]) | ||||
|     if not current_fronter or not current_fronter["member"]: | ||||
|     # Get current fronters | ||||
|     switches = await db.front_history(conn, system_id=system["id"], count=1) | ||||
|     if not switches or not switches[0]["members"]: | ||||
|         return False, "There's already no one in front." | ||||
|  | ||||
|     # Log it | ||||
|     await db.add_switch(conn, system_id=system["id"], member_id=None) | ||||
|     # Log it, and don't log any members | ||||
|     await db.add_switch(conn, system_id=system["id"]) | ||||
|     return True, "Switch-out registered." | ||||
|  | ||||
| @command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands") | ||||
|   | ||||
| @@ -104,6 +104,9 @@ async def get_member_by_hid(conn, member_hid: str): | ||||
| async def get_member(conn, member_id: int): | ||||
|     return await conn.fetchrow("select * from members where id = $1", member_id) | ||||
|  | ||||
| @db_wrap | ||||
| async def get_members(conn, members: list): | ||||
|     return await conn.fetch("select * from members where id = any($1)", members) | ||||
|  | ||||
| @db_wrap | ||||
| async def get_message(conn, message_id: str): | ||||
| @@ -186,27 +189,43 @@ async def delete_message(conn, message_id: str): | ||||
|     logger.debug("Deleting message (id={})".format(message_id)) | ||||
|     await conn.execute("delete from messages where mid = $1", int(message_id)) | ||||
|  | ||||
| @db_wrap | ||||
| async def current_fronter(conn, system_id: int): | ||||
|     return await conn.fetchrow("""select *, members.name | ||||
|     from switches | ||||
|     left outer join members on (members.id = switches.member) -- Left outer join instead of normal join - makes name = null instead of just ignoring the row | ||||
|     where switches.system = $1 | ||||
|     order by timestamp desc""", system_id) | ||||
| # @db_wrap | ||||
| # async def front_history(conn, system_id: int, count: int): | ||||
| #     return await conn.fetch("""select | ||||
| #         switches.timestamp, members.name, members.id, switches.id as switch_id | ||||
| #     from | ||||
| #         ( | ||||
| #             select * from switches where system = $1 order by timestamp desc limit $2 | ||||
| #         ) as switches | ||||
| #     left outer join switch_members | ||||
| #         on switch_members.switch = switches.id | ||||
| #     left outer join members | ||||
| #         on switch_members.member = members.id | ||||
| #     order by switches.timestamp desc""", system_id, count) | ||||
|  | ||||
| @db_wrap | ||||
| async def past_fronters(conn, system_id: int, amount: int): | ||||
|     return await conn.fetch("""select *, members.name | ||||
| async def front_history(conn, system_id: int, count: int): | ||||
|     return await conn.fetch("""select | ||||
|         switches.*, | ||||
|         array( | ||||
|             select member from switch_members | ||||
|             where switch_members.switch = switches.id | ||||
|         ) as members | ||||
|     from switches | ||||
|     left outer join members on (members.id = switches.member) -- (see above) | ||||
|     where switches.system = $1 | ||||
|     order by timestamp | ||||
|     desc limit $2""", system_id, amount) | ||||
|     order by switches.timestamp desc | ||||
|     limit $2""", system_id, count) | ||||
|  | ||||
| @db_wrap | ||||
| async def add_switch(conn, system_id: int, member_id: int): | ||||
|     logger.debug("Adding switch (system={}, member={})".format(system_id, member_id)) | ||||
|     return await conn.execute("insert into switches (system, member) values ($1, $2)", system_id, member_id) | ||||
| async def add_switch(conn, system_id: int): | ||||
|     logger.debug("Adding switch (system={})".format(system_id)) | ||||
|     res = await conn.fetchrow("insert into switches (system) values ($1) returning *", system_id) | ||||
|     return res["id"] | ||||
|  | ||||
| @db_wrap | ||||
| async def add_switch_member(conn, switch_id: int, member_id: int): | ||||
|     logger.debug("Adding switch member (switch={}, member={})".format(switch_id, member_id)) | ||||
|     await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id) | ||||
|  | ||||
| @db_wrap | ||||
| async def get_server_info(conn, server_id: str): | ||||
| @@ -254,9 +273,12 @@ async def create_tables(conn): | ||||
|     await conn.execute("""create table if not exists switches ( | ||||
|         id          serial primary key, | ||||
|         system      serial not null references systems(id) on delete cascade, | ||||
|         member      serial references members(id) on delete restrict, | ||||
|         timestamp   timestamp not null default current_timestamp, | ||||
|         member_del  bool not null default false | ||||
|         timestamp   timestamp not null default current_timestamp | ||||
|     )""") | ||||
|     await conn.execute("""create table if not exists switch_members ( | ||||
|         id          serial primary key, | ||||
|         switch      serial not null references switches(id) on delete cascade, | ||||
|         member      serial not null references members(id) on delete cascade | ||||
|     )""") | ||||
|     await conn.execute("""create table if not exists webhooks ( | ||||
|         channel     bigint primary key, | ||||
|   | ||||
| @@ -170,10 +170,13 @@ async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Emb | ||||
|     if system["tag"]: | ||||
|         card.add_field(name="Tag", value=system["tag"]) | ||||
|  | ||||
|     current_fronter = await db.current_fronter(conn, system_id=system["id"]) | ||||
|     if current_fronter and current_fronter["member"]: | ||||
|         fronter_val = "{} (for {})".format(current_fronter["name"], humanize.naturaldelta(current_fronter["timestamp"])) | ||||
|         card.add_field(name="Current fronter", value=fronter_val) | ||||
|     switches = await db.front_history(conn, system_id=system["id"], count=1) | ||||
|     if switches and switches[0]["members"]: | ||||
|         members = await db.get_members(conn, switches[0]["members"]) | ||||
|         names = ", ".join([member["name"] for member in members]) | ||||
|          | ||||
|         fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switches[0]["timestamp"])) | ||||
|         card.add_field(name="Current fronter" if len(members) == 1 else "Current fronters", value=fronter_val) | ||||
|  | ||||
|     account_names = [] | ||||
|     for account_id in await db.get_linked_accounts(conn, system_id=system["id"]): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user