Added multi-fronter support
This commit is contained in:
parent
1542a8dd40
commit
8599ee3fd0
@ -1,4 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import itertools
|
||||||
import re
|
import re
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@ -145,7 +146,7 @@ async def system_unlink(conn, message, args):
|
|||||||
await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id)
|
await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id)
|
||||||
return True, "Account unlinked."
|
return True, "Account unlinked."
|
||||||
|
|
||||||
@command(cmd="system fronter", usage="[system]", description="Gets the current fronter in the system.", category="Switching commands")
|
@command(cmd="system fronter", usage="[system]", description="Gets the current fronter(s) in the system.", category="Switching commands")
|
||||||
async def system_fronter(conn, message, args):
|
async def system_fronter(conn, message, args):
|
||||||
if len(args) == 0:
|
if len(args) == 0:
|
||||||
system = await db.get_system_by_account(conn, message.author.id)
|
system = await db.get_system_by_account(conn, message.author.id)
|
||||||
@ -158,21 +159,30 @@ async def system_fronter(conn, message, args):
|
|||||||
if system is None:
|
if system is None:
|
||||||
return False, "Can't find system \"{}\".".format(args[0])
|
return False, "Can't find system \"{}\".".format(args[0])
|
||||||
|
|
||||||
current_fronter = await db.current_fronter(conn, system_id=system["id"])
|
# Get latest switch from DB
|
||||||
if not current_fronter:
|
switches = await db.front_history(conn, system_id=system["id"], count=1)
|
||||||
|
if len(switches) == 0:
|
||||||
|
# Special case if empty
|
||||||
return True, make_default_embed(None).add_field(name="Current fronter", value="*(nobody)*")
|
return True, make_default_embed(None).add_field(name="Current fronter", value="*(nobody)*")
|
||||||
|
|
||||||
fronter_name = "*(nobody)*"
|
switch = switches[0]
|
||||||
if current_fronter["member"]:
|
|
||||||
member = await db.get_member(conn, member_id=current_fronter["member"])
|
|
||||||
fronter_name = member["name"]
|
|
||||||
if current_fronter["member_del"]:
|
|
||||||
fronter_name = "*(deleted member)*"
|
|
||||||
|
|
||||||
since = current_fronter["timestamp"]
|
fronter_names = []
|
||||||
|
if len(switch["members"]) > 0:
|
||||||
|
# Fetch member data from DB
|
||||||
|
members = await db.get_members(conn, switch["members"])
|
||||||
|
fronter_names = [member["name"] for member in members]
|
||||||
|
|
||||||
embed = make_default_embed(None)
|
embed = make_default_embed(None)
|
||||||
embed.add_field(name="Current fronter", value=fronter_name)
|
|
||||||
|
if len(fronter_names) == 0:
|
||||||
|
embed.add_field(name="Current fronter", value="*nobody*")
|
||||||
|
elif len(fronter_names) == 1:
|
||||||
|
embed.add_field(name="Current fronters", value=fronter_names[0])
|
||||||
|
else:
|
||||||
|
embed.add_field(name="Current fronter", value=", ".join(fronter_names))
|
||||||
|
|
||||||
|
since = switch["timestamp"]
|
||||||
embed.add_field(name="Since", value="{} ({})".format(since.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(since)))
|
embed.add_field(name="Since", value="{} ({})".format(since.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(since)))
|
||||||
return True, embed
|
return True, embed
|
||||||
|
|
||||||
@ -189,15 +199,29 @@ async def system_fronthistory(conn, message, args):
|
|||||||
if system is None:
|
if system is None:
|
||||||
return False, "Can't find system \"{}\".".format(args[0])
|
return False, "Can't find system \"{}\".".format(args[0])
|
||||||
|
|
||||||
switches = await db.past_fronters(conn, system_id=system["id"], amount=10)
|
# Get list of past switches from DB
|
||||||
|
switches = await db.front_history(conn, system_id=system["id"], count=10)
|
||||||
|
|
||||||
|
# Get all unique IDs referenced
|
||||||
|
all_member_ids = {id for switch in switches for id in switch["members"]}
|
||||||
|
|
||||||
|
# And look them up in the database into a dict
|
||||||
|
all_members = {member["id"]: member for member in await db.get_members(conn, list(all_member_ids))}
|
||||||
|
|
||||||
lines = []
|
lines = []
|
||||||
for switch in switches:
|
for switch in switches:
|
||||||
|
# Special case when no one's fronting
|
||||||
|
if len(switch["members"]) == 0:
|
||||||
|
name = "*nobody*"
|
||||||
|
else:
|
||||||
|
name = ", ".join([all_members[id]["name"] for id in switch["members"]])
|
||||||
|
|
||||||
|
# Make proper date string
|
||||||
since = switch["timestamp"]
|
since = switch["timestamp"]
|
||||||
time_text = since.isoformat(sep=" ", timespec="seconds")
|
time_text = since.isoformat(sep=" ", timespec="seconds")
|
||||||
rel_text = humanize.naturaltime(since)
|
rel_text = humanize.naturaltime(since)
|
||||||
|
|
||||||
lines.append("**{}** ({}, at {})".format(switch["name"], time_text, rel_text))
|
lines.append("**{}** ({}, {})".format(name, time_text, rel_text))
|
||||||
|
|
||||||
embed = make_default_embed("\n".join(lines))
|
embed = make_default_embed("\n".join(lines))
|
||||||
embed.title = "Past switches"
|
embed.title = "Past switches"
|
||||||
@ -292,16 +316,16 @@ async def member_set(conn, message, member, args):
|
|||||||
return False, "Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25)."
|
return False, "Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25)."
|
||||||
|
|
||||||
if prop == "avatar":
|
if prop == "avatar":
|
||||||
user = await parse_mention(args[0])
|
user = await parse_mention(value)
|
||||||
if user:
|
if user:
|
||||||
# Set the avatar to the mentioned user's avatar
|
# Set the avatar to the mentioned user's avatar
|
||||||
# Discord doesn't like webp, but also hosts png alternatives
|
# Discord doesn't like webp, but also hosts png alternatives
|
||||||
value = user.avatar_url.replace(".webp", ".png")
|
value = user.avatar_url.replace(".webp", ".png")
|
||||||
else:
|
else:
|
||||||
# Validate URL
|
# Validate URL
|
||||||
u = urlparse(args[0])
|
u = urlparse(value)
|
||||||
if u.scheme in ["http", "https"] and u.netloc and u.path:
|
if u.scheme in ["http", "https"] and u.netloc and u.path:
|
||||||
value = args[0]
|
value = value
|
||||||
else:
|
else:
|
||||||
return False, "Invalid URL."
|
return False, "Invalid URL."
|
||||||
else:
|
else:
|
||||||
@ -405,7 +429,7 @@ async def message_info(conn, message, args):
|
|||||||
await client.send_message(message.channel, embed=embed)
|
await client.send_message(message.channel, embed=embed)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@command(cmd="switch", usage="<name|id>", description="Registers a switch and changes the current fronter.", category="Switching commands")
|
@command(cmd="switch", usage="<name|id> [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands")
|
||||||
async def switch_member(conn, message, args):
|
async def switch_member(conn, message, args):
|
||||||
if len(args) == 0:
|
if len(args) == 0:
|
||||||
return False
|
return False
|
||||||
@ -415,19 +439,36 @@ async def switch_member(conn, message, args):
|
|||||||
if system is None:
|
if system is None:
|
||||||
return False, "No system is registered to this account."
|
return False, "No system is registered to this account."
|
||||||
|
|
||||||
# Find the member
|
members = []
|
||||||
member = await get_member_fuzzy(conn, system["id"], " ".join(args))
|
for member_name in args:
|
||||||
if not member:
|
# Find the member
|
||||||
return False, "Couldn't find member \"{}\".".format(args[0])
|
member = await get_member_fuzzy(conn, system["id"], member_name)
|
||||||
|
if not member:
|
||||||
|
return False, "Couldn't find member \"{}\".".format(member_name)
|
||||||
|
members.append(member)
|
||||||
|
|
||||||
|
member_ids = {member["id"] for member in members}
|
||||||
|
|
||||||
# Get current fronter
|
switches = await db.front_history(conn, system_id=system["id"], count=1)
|
||||||
current_fronter = await db.current_fronter(conn, system_id=system["id"])
|
fronter_ids = {}
|
||||||
if current_fronter and current_fronter["member"] == member["id"]:
|
if switches:
|
||||||
return False, "Member \"{}\" is already fronting.".format(member["name"])
|
fronter_ids = set(switches[0]["members"])
|
||||||
|
|
||||||
|
if member_ids == fronter_ids:
|
||||||
|
if len(members) == 1:
|
||||||
|
return False, "{} is already fronting.".format(members[0]["name"])
|
||||||
|
return False, "Members {} are already fronting.".format(", ".join([m["name"] for m in members]))
|
||||||
|
|
||||||
# Log the switch
|
# Log the switch
|
||||||
await db.add_switch(conn, system_id=system["id"], member_id=member["id"])
|
async with conn.transaction():
|
||||||
return True, "Switch registered. Current fronter is now {}.".format(member["name"])
|
switch_id = await db.add_switch(conn, system_id=system["id"])
|
||||||
|
for member in members:
|
||||||
|
await db.add_switch_member(conn, switch_id=switch_id, member_id=member["id"])
|
||||||
|
|
||||||
|
if len(members) == 1:
|
||||||
|
return True, "Switch registered. Current fronter is now {}.".format(member["name"])
|
||||||
|
else:
|
||||||
|
return True, "Switch registered. Current fronters are now {}.".format(", ".join([m["name"] for m in members]))
|
||||||
|
|
||||||
@command(cmd="switch out", description="Registers a switch out, and leaves current fronter blank.", category="Switching commands")
|
@command(cmd="switch out", description="Registers a switch out, and leaves current fronter blank.", category="Switching commands")
|
||||||
async def switch_out(conn, message, args):
|
async def switch_out(conn, message, args):
|
||||||
@ -436,13 +477,13 @@ async def switch_out(conn, message, args):
|
|||||||
if system is None:
|
if system is None:
|
||||||
return False, "No system is registered to this account."
|
return False, "No system is registered to this account."
|
||||||
|
|
||||||
# Get current fronter
|
# Get current fronters
|
||||||
current_fronter = await db.current_fronter(conn, system_id=system["id"])
|
switches = await db.front_history(conn, system_id=system["id"], count=1)
|
||||||
if not current_fronter or not current_fronter["member"]:
|
if not switches or not switches[0]["members"]:
|
||||||
return False, "There's already no one in front."
|
return False, "There's already no one in front."
|
||||||
|
|
||||||
# Log it
|
# Log it, and don't log any members
|
||||||
await db.add_switch(conn, system_id=system["id"], member_id=None)
|
await db.add_switch(conn, system_id=system["id"])
|
||||||
return True, "Switch-out registered."
|
return True, "Switch-out registered."
|
||||||
|
|
||||||
@command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands")
|
@command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands")
|
||||||
|
@ -104,6 +104,9 @@ async def get_member_by_hid(conn, member_hid: str):
|
|||||||
async def get_member(conn, member_id: int):
|
async def get_member(conn, member_id: int):
|
||||||
return await conn.fetchrow("select * from members where id = $1", member_id)
|
return await conn.fetchrow("select * from members where id = $1", member_id)
|
||||||
|
|
||||||
|
@db_wrap
|
||||||
|
async def get_members(conn, members: list):
|
||||||
|
return await conn.fetch("select * from members where id = any($1)", members)
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_message(conn, message_id: str):
|
async def get_message(conn, message_id: str):
|
||||||
@ -186,27 +189,43 @@ async def delete_message(conn, message_id: str):
|
|||||||
logger.debug("Deleting message (id={})".format(message_id))
|
logger.debug("Deleting message (id={})".format(message_id))
|
||||||
await conn.execute("delete from messages where mid = $1", int(message_id))
|
await conn.execute("delete from messages where mid = $1", int(message_id))
|
||||||
|
|
||||||
@db_wrap
|
# @db_wrap
|
||||||
async def current_fronter(conn, system_id: int):
|
# async def front_history(conn, system_id: int, count: int):
|
||||||
return await conn.fetchrow("""select *, members.name
|
# return await conn.fetch("""select
|
||||||
from switches
|
# switches.timestamp, members.name, members.id, switches.id as switch_id
|
||||||
left outer join members on (members.id = switches.member) -- Left outer join instead of normal join - makes name = null instead of just ignoring the row
|
# from
|
||||||
where switches.system = $1
|
# (
|
||||||
order by timestamp desc""", system_id)
|
# select * from switches where system = $1 order by timestamp desc limit $2
|
||||||
|
# ) as switches
|
||||||
|
# left outer join switch_members
|
||||||
|
# on switch_members.switch = switches.id
|
||||||
|
# left outer join members
|
||||||
|
# on switch_members.member = members.id
|
||||||
|
# order by switches.timestamp desc""", system_id, count)
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def past_fronters(conn, system_id: int, amount: int):
|
async def front_history(conn, system_id: int, count: int):
|
||||||
return await conn.fetch("""select *, members.name
|
return await conn.fetch("""select
|
||||||
|
switches.*,
|
||||||
|
array(
|
||||||
|
select member from switch_members
|
||||||
|
where switch_members.switch = switches.id
|
||||||
|
) as members
|
||||||
from switches
|
from switches
|
||||||
left outer join members on (members.id = switches.member) -- (see above)
|
|
||||||
where switches.system = $1
|
where switches.system = $1
|
||||||
order by timestamp
|
order by switches.timestamp desc
|
||||||
desc limit $2""", system_id, amount)
|
limit $2""", system_id, count)
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def add_switch(conn, system_id: int, member_id: int):
|
async def add_switch(conn, system_id: int):
|
||||||
logger.debug("Adding switch (system={}, member={})".format(system_id, member_id))
|
logger.debug("Adding switch (system={})".format(system_id))
|
||||||
return await conn.execute("insert into switches (system, member) values ($1, $2)", system_id, member_id)
|
res = await conn.fetchrow("insert into switches (system) values ($1) returning *", system_id)
|
||||||
|
return res["id"]
|
||||||
|
|
||||||
|
@db_wrap
|
||||||
|
async def add_switch_member(conn, switch_id: int, member_id: int):
|
||||||
|
logger.debug("Adding switch member (switch={}, member={})".format(switch_id, member_id))
|
||||||
|
await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id)
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_server_info(conn, server_id: str):
|
async def get_server_info(conn, server_id: str):
|
||||||
@ -254,9 +273,12 @@ async def create_tables(conn):
|
|||||||
await conn.execute("""create table if not exists switches (
|
await conn.execute("""create table if not exists switches (
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
system serial not null references systems(id) on delete cascade,
|
system serial not null references systems(id) on delete cascade,
|
||||||
member serial references members(id) on delete restrict,
|
timestamp timestamp not null default current_timestamp
|
||||||
timestamp timestamp not null default current_timestamp,
|
)""")
|
||||||
member_del bool not null default false
|
await conn.execute("""create table if not exists switch_members (
|
||||||
|
id serial primary key,
|
||||||
|
switch serial not null references switches(id) on delete cascade,
|
||||||
|
member serial not null references members(id) on delete cascade
|
||||||
)""")
|
)""")
|
||||||
await conn.execute("""create table if not exists webhooks (
|
await conn.execute("""create table if not exists webhooks (
|
||||||
channel bigint primary key,
|
channel bigint primary key,
|
||||||
|
@ -170,10 +170,13 @@ async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Emb
|
|||||||
if system["tag"]:
|
if system["tag"]:
|
||||||
card.add_field(name="Tag", value=system["tag"])
|
card.add_field(name="Tag", value=system["tag"])
|
||||||
|
|
||||||
current_fronter = await db.current_fronter(conn, system_id=system["id"])
|
switches = await db.front_history(conn, system_id=system["id"], count=1)
|
||||||
if current_fronter and current_fronter["member"]:
|
if switches and switches[0]["members"]:
|
||||||
fronter_val = "{} (for {})".format(current_fronter["name"], humanize.naturaldelta(current_fronter["timestamp"]))
|
members = await db.get_members(conn, switches[0]["members"])
|
||||||
card.add_field(name="Current fronter", value=fronter_val)
|
names = ", ".join([member["name"] for member in members])
|
||||||
|
|
||||||
|
fronter_val = "{} (for {})".format(names, humanize.naturaldelta(switches[0]["timestamp"]))
|
||||||
|
card.add_field(name="Current fronter" if len(members) == 1 else "Current fronters", value=fronter_val)
|
||||||
|
|
||||||
account_names = []
|
account_names = []
|
||||||
for account_id in await db.get_linked_accounts(conn, system_id=system["id"]):
|
for account_id in await db.get_linked_accounts(conn, system_id=system["id"]):
|
||||||
|
Loading…
Reference in New Issue
Block a user