Added multi-fronter support

This commit is contained in:
Ske 2018-07-14 02:28:15 +02:00
parent 1542a8dd40
commit 8599ee3fd0
3 changed files with 120 additions and 54 deletions

View File

@ -1,4 +1,5 @@
from datetime import datetime from datetime import datetime
import itertools
import re import re
from urllib.parse import urlparse from urllib.parse import urlparse
@ -145,7 +146,7 @@ async def system_unlink(conn, message, args):
await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id) await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id)
return True, "Account unlinked." return True, "Account unlinked."
@command(cmd="system fronter", usage="[system]", description="Gets the current fronter in the system.", category="Switching commands") @command(cmd="system fronter", usage="[system]", description="Gets the current fronter(s) in the system.", category="Switching commands")
async def system_fronter(conn, message, args): async def system_fronter(conn, message, args):
if len(args) == 0: if len(args) == 0:
system = await db.get_system_by_account(conn, message.author.id) system = await db.get_system_by_account(conn, message.author.id)
@ -158,21 +159,30 @@ async def system_fronter(conn, message, args):
if system is None: if system is None:
return False, "Can't find system \"{}\".".format(args[0]) return False, "Can't find system \"{}\".".format(args[0])
current_fronter = await db.current_fronter(conn, system_id=system["id"]) # Get latest switch from DB
if not current_fronter: switches = await db.front_history(conn, system_id=system["id"], count=1)
if len(switches) == 0:
# Special case if empty
return True, make_default_embed(None).add_field(name="Current fronter", value="*(nobody)*") return True, make_default_embed(None).add_field(name="Current fronter", value="*(nobody)*")
fronter_name = "*(nobody)*" switch = switches[0]
if current_fronter["member"]:
member = await db.get_member(conn, member_id=current_fronter["member"])
fronter_name = member["name"]
if current_fronter["member_del"]:
fronter_name = "*(deleted member)*"
since = current_fronter["timestamp"] fronter_names = []
if len(switch["members"]) > 0:
# Fetch member data from DB
members = await db.get_members(conn, switch["members"])
fronter_names = [member["name"] for member in members]
embed = make_default_embed(None) embed = make_default_embed(None)
embed.add_field(name="Current fronter", value=fronter_name)
if len(fronter_names) == 0:
embed.add_field(name="Current fronter", value="*nobody*")
elif len(fronter_names) == 1:
embed.add_field(name="Current fronters", value=fronter_names[0])
else:
embed.add_field(name="Current fronter", value=", ".join(fronter_names))
since = switch["timestamp"]
embed.add_field(name="Since", value="{} ({})".format(since.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(since))) embed.add_field(name="Since", value="{} ({})".format(since.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(since)))
return True, embed return True, embed
@ -189,15 +199,29 @@ async def system_fronthistory(conn, message, args):
if system is None: if system is None:
return False, "Can't find system \"{}\".".format(args[0]) return False, "Can't find system \"{}\".".format(args[0])
switches = await db.past_fronters(conn, system_id=system["id"], amount=10) # Get list of past switches from DB
switches = await db.front_history(conn, system_id=system["id"], count=10)
# Get all unique IDs referenced
all_member_ids = {id for switch in switches for id in switch["members"]}
# And look them up in the database into a dict
all_members = {member["id"]: member for member in await db.get_members(conn, list(all_member_ids))}
lines = [] lines = []
for switch in switches: for switch in switches:
# Special case when no one's fronting
if len(switch["members"]) == 0:
name = "*nobody*"
else:
name = ", ".join([all_members[id]["name"] for id in switch["members"]])
# Make proper date string
since = switch["timestamp"] since = switch["timestamp"]
time_text = since.isoformat(sep=" ", timespec="seconds") time_text = since.isoformat(sep=" ", timespec="seconds")
rel_text = humanize.naturaltime(since) rel_text = humanize.naturaltime(since)
lines.append("**{}** ({}, at {})".format(switch["name"], time_text, rel_text)) lines.append("**{}** ({}, {})".format(name, time_text, rel_text))
embed = make_default_embed("\n".join(lines)) embed = make_default_embed("\n".join(lines))
embed.title = "Past switches" embed.title = "Past switches"
@ -292,16 +316,16 @@ async def member_set(conn, message, member, args):
return False, "Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25)." return False, "Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25)."
if prop == "avatar": if prop == "avatar":
user = await parse_mention(args[0]) user = await parse_mention(value)
if user: if user:
# Set the avatar to the mentioned user's avatar # Set the avatar to the mentioned user's avatar
# Discord doesn't like webp, but also hosts png alternatives # Discord doesn't like webp, but also hosts png alternatives
value = user.avatar_url.replace(".webp", ".png") value = user.avatar_url.replace(".webp", ".png")
else: else:
# Validate URL # Validate URL
u = urlparse(args[0]) u = urlparse(value)
if u.scheme in ["http", "https"] and u.netloc and u.path: if u.scheme in ["http", "https"] and u.netloc and u.path:
value = args[0] value = value
else: else:
return False, "Invalid URL." return False, "Invalid URL."
else: else:
@ -405,7 +429,7 @@ async def message_info(conn, message, args):
await client.send_message(message.channel, embed=embed) await client.send_message(message.channel, embed=embed)
return True return True
@command(cmd="switch", usage="<name|id>", description="Registers a switch and changes the current fronter.", category="Switching commands") @command(cmd="switch", usage="<name|id> [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands")
async def switch_member(conn, message, args): async def switch_member(conn, message, args):
if len(args) == 0: if len(args) == 0:
return False return False
@ -415,19 +439,36 @@ async def switch_member(conn, message, args):
if system is None: if system is None:
return False, "No system is registered to this account." return False, "No system is registered to this account."
# Find the member members = []
member = await get_member_fuzzy(conn, system["id"], " ".join(args)) for member_name in args:
if not member: # Find the member
return False, "Couldn't find member \"{}\".".format(args[0]) member = await get_member_fuzzy(conn, system["id"], member_name)
if not member:
return False, "Couldn't find member \"{}\".".format(member_name)
members.append(member)
# Get current fronter member_ids = {member["id"] for member in members}
current_fronter = await db.current_fronter(conn, system_id=system["id"])
if current_fronter and current_fronter["member"] == member["id"]: switches = await db.front_history(conn, system_id=system["id"], count=1)
return False, "Member \"{}\" is already fronting.".format(member["name"]) fronter_ids = {}
if switches:
fronter_ids = set(switches[0]["members"])
if member_ids == fronter_ids:
if len(members) == 1:
return False, "{} is already fronting.".format(members[0]["name"])
return False, "Members {} are already fronting.".format(", ".join([m["name"] for m in members]))
# Log the switch # Log the switch
await db.add_switch(conn, system_id=system["id"], member_id=member["id"]) async with conn.transaction():
return True, "Switch registered. Current fronter is now {}.".format(member["name"]) switch_id = await db.add_switch(conn, system_id=system["id"])
for member in members:
await db.add_switch_member(conn, switch_id=switch_id, member_id=member["id"])
if len(members) == 1:
return True, "Switch registered. Current fronter is now {}.".format(member["name"])
else:
return True, "Switch registered. Current fronters are now {}.".format(", ".join([m["name"] for m in members]))
@command(cmd="switch out", description="Registers a switch out, and leaves current fronter blank.", category="Switching commands") @command(cmd="switch out", description="Registers a switch out, and leaves current fronter blank.", category="Switching commands")
async def switch_out(conn, message, args): async def switch_out(conn, message, args):
@ -436,13 +477,13 @@ async def switch_out(conn, message, args):
if system is None: if system is None:
return False, "No system is registered to this account." return False, "No system is registered to this account."
# Get current fronter # Get current fronters
current_fronter = await db.current_fronter(conn, system_id=system["id"]) switches = await db.front_history(conn, system_id=system["id"], count=1)
if not current_fronter or not current_fronter["member"]: if not switches or not switches[0]["members"]:
return False, "There's already no one in front." return False, "There's already no one in front."
# Log it # Log it, and don't log any members
await db.add_switch(conn, system_id=system["id"], member_id=None) await db.add_switch(conn, system_id=system["id"])
return True, "Switch-out registered." return True, "Switch-out registered."
@command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands") @command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands")

View File

@ -104,6 +104,9 @@ async def get_member_by_hid(conn, member_hid: str):
async def get_member(conn, member_id: int): async def get_member(conn, member_id: int):
return await conn.fetchrow("select * from members where id = $1", member_id) return await conn.fetchrow("select * from members where id = $1", member_id)
@db_wrap
async def get_members(conn, members: list):
return await conn.fetch("select * from members where id = any($1)", members)
@db_wrap @db_wrap
async def get_message(conn, message_id: str): async def get_message(conn, message_id: str):
@ -186,27 +189,43 @@ async def delete_message(conn, message_id: str):
logger.debug("Deleting message (id={})".format(message_id)) logger.debug("Deleting message (id={})".format(message_id))
await conn.execute("delete from messages where mid = $1", int(message_id)) await conn.execute("delete from messages where mid = $1", int(message_id))
@db_wrap # @db_wrap
async def current_fronter(conn, system_id: int): # async def front_history(conn, system_id: int, count: int):
return await conn.fetchrow("""select *, members.name # return await conn.fetch("""select
from switches # switches.timestamp, members.name, members.id, switches.id as switch_id
left outer join members on (members.id = switches.member) -- Left outer join instead of normal join - makes name = null instead of just ignoring the row # from
where switches.system = $1 # (
order by timestamp desc""", system_id) # select * from switches where system = $1 order by timestamp desc limit $2
# ) as switches
# left outer join switch_members
# on switch_members.switch = switches.id
# left outer join members
# on switch_members.member = members.id
# order by switches.timestamp desc""", system_id, count)
@db_wrap @db_wrap
async def past_fronters(conn, system_id: int, amount: int): async def front_history(conn, system_id: int, count: int):
return await conn.fetch("""select *, members.name return await conn.fetch("""select
switches.*,
array(
select member from switch_members
where switch_members.switch = switches.id
) as members
from switches from switches
left outer join members on (members.id = switches.member) -- (see above)
where switches.system = $1 where switches.system = $1
order by timestamp order by switches.timestamp desc
desc limit $2""", system_id, amount) limit $2""", system_id, count)
@db_wrap @db_wrap
async def add_switch(conn, system_id: int, member_id: int): async def add_switch(conn, system_id: int):
logger.debug("Adding switch (system={}, member={})".format(system_id, member_id)) logger.debug("Adding switch (system={})".format(system_id))
return await conn.execute("insert into switches (system, member) values ($1, $2)", system_id, member_id) res = await conn.fetchrow("insert into switches (system) values ($1) returning *", system_id)
return res["id"]
@db_wrap
async def add_switch_member(conn, switch_id: int, member_id: int):
logger.debug("Adding switch member (switch={}, member={})".format(switch_id, member_id))
await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id)
@db_wrap @db_wrap
async def get_server_info(conn, server_id: str): async def get_server_info(conn, server_id: str):
@ -254,9 +273,12 @@ async def create_tables(conn):
await conn.execute("""create table if not exists switches ( await conn.execute("""create table if not exists switches (
id serial primary key, id serial primary key,
system serial not null references systems(id) on delete cascade, system serial not null references systems(id) on delete cascade,
member serial references members(id) on delete restrict, timestamp timestamp not null default current_timestamp
timestamp timestamp not null default current_timestamp, )""")
member_del bool not null default false await conn.execute("""create table if not exists switch_members (
id serial primary key,
switch serial not null references switches(id) on delete cascade,
member serial not null references members(id) on delete cascade
)""") )""")
await conn.execute("""create table if not exists webhooks ( await conn.execute("""create table if not exists webhooks (
channel bigint primary key, channel bigint primary key,

View File

@ -170,10 +170,13 @@ async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Emb
if system["tag"]: if system["tag"]:
card.add_field(name="Tag", value=system["tag"]) card.add_field(name="Tag", value=system["tag"])
current_fronter = await db.current_fronter(conn, system_id=system["id"]) switches = await db.front_history(conn, system_id=system["id"], count=1)
if current_fronter and current_fronter["member"]: if switches and switches[0]["members"]:
fronter_val = "{} (for {})".format(current_fronter["name"], humanize.naturaldelta(current_fronter["timestamp"])) members = await db.get_members(conn, switches[0]["members"])
card.add_field(name="Current fronter", value=fronter_val) names = ", ".join([member["name"] for member in members])
fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switches[0]["timestamp"]))
card.add_field(name="Current fronter" if len(members) == 1 else "Current fronters", value=fronter_val)
account_names = [] account_names = []
for account_id in await db.get_linked_accounts(conn, system_id=system["id"]): for account_id in await db.get_linked_accounts(conn, system_id=system["id"]):