Major command handling refactor

This commit is contained in:
Ske 2018-09-07 17:34:38 +02:00
parent 0869f94cdf
commit f067485e88
15 changed files with 463 additions and 355 deletions

2
.gitignore vendored
View File

@ -2,3 +2,5 @@
.vscode/ .vscode/
.idea/ .idea/
venv/ venv/
*.pyc

View File

@ -75,7 +75,11 @@ class PluralKitBot:
pass pass
async def handle_command_dispatch(self, message): async def handle_command_dispatch(self, message):
command_items = commands.command_list.items() async with self.pool.acquire() as conn:
result = await commands.command_dispatch(self.client, message, conn)
return result
"""command_items = commands.command_list.items()
command_items = sorted(command_items, key=lambda x: len(x[0]), reverse=True) command_items = sorted(command_items, key=lambda x: len(x[0]), reverse=True)
prefix = "pk;" prefix = "pk;"
@ -98,7 +102,7 @@ class PluralKitBot:
response_time = (datetime.now() - message.timestamp).total_seconds() response_time = (datetime.now() - message.timestamp).total_seconds()
await self.stats.report_command(command_name, execution_time, response_time) await self.stats.report_command(command_name, execution_time, response_time)
return True return True"""
async def handle_proxy_dispatch(self, message): async def handle_proxy_dispatch(self, message):
# Try doing proxy parsing # Try doing proxy parsing

View File

@ -1,84 +1,111 @@
import logging
from collections import namedtuple
import asyncpg
import discord import discord
import logging
import re
from typing import Tuple, Optional
import pluralkit from pluralkit import db, System, Member
from pluralkit import db from pluralkit.bot import embeds, utils
from pluralkit.bot import utils, embeds
logger = logging.getLogger("pluralkit.bot.commands") logger = logging.getLogger("pluralkit.bot.commands")
command_list = {}
class NoSystemRegistered(Exception): def next_arg(arg_string: str) -> Tuple[str, Optional[str]]:
pass if arg_string.startswith("\""):
end_quote = arg_string.find("\"", start=1)
if end_quote > 0:
return arg_string[1:end_quote], arg_string[end_quote + 1:].strip()
else:
return arg_string[1:], None
class CommandContext(namedtuple("CommandContext", ["client", "conn", "message", "system"])): next_space = arg_string.find(" ")
client: discord.Client if next_space >= 0:
conn: asyncpg.Connection return arg_string[:next_space].strip(), arg_string[next_space:].strip()
message: discord.Message else:
system: pluralkit.System return arg_string.strip(), None
async def reply(self, message=None, embed=None):
return await self.client.send_message(self.message.channel, message, embed=embed)
class MemberCommandContext(namedtuple("MemberCommandContext", CommandContext._fields + ("member",)), CommandContext): class CommandResponse:
client: discord.Client def to_embed(self):
conn: asyncpg.Connection pass
message: discord.Message
system: pluralkit.System
member: pluralkit.Member
class CommandEntry(namedtuple("CommandEntry", ["command", "function", "usage", "description", "category"])):
pass
def command(cmd, usage=None, description=None, category=None, system_required=True): class CommandSuccess(CommandResponse):
def wrap(func): def __init__(self, text):
async def wrapper(client, conn, message, args): self.text = text
system = await db.get_system_by_account(conn, message.author.id)
if system_required and system is None: def to_embed(self):
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account. Use `pk;system new` to register one.")) return embeds.success("\u2705 " + self.text)
return
ctx = CommandContext(client=client, conn=conn, message=message, system=system)
try:
res = await func(ctx, args)
if res: class CommandError(Exception, CommandResponse):
embed = res if isinstance(res, discord.Embed) else utils.make_default_embed(res) def __init__(self, embed: str, help: Tuple[str, str] = None):
await client.send_message(message.channel, embed=embed) self.text = embed
except NoSystemRegistered: self.help = help
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account. Use `pk;system new` to register one."))
except Exception:
logger.exception("Exception while handling command {} (args={}, system={})".format(cmd, args, system.hid if system else "(none)"))
# Put command in map def to_embed(self):
command_list[cmd] = CommandEntry(command=cmd, function=wrapper, usage=usage, description=description, category=category) return embeds.error("\u274c " + self.text, self.help)
return wrapper
return wrap
def member_command(cmd, usage=None, description=None, category=None, system_only=True):
def wrap(func):
async def wrapper(ctx: CommandContext, args):
# Return if no member param
if len(args) == 0:
return embeds.error("You must pass a member name or ID.")
# System is allowed to be none if not system_only class CommandContext:
system_id = ctx.system.id if ctx.system else None def __init__(self, client: discord.Client, message: discord.Message, conn, args: str):
# And find member by key self.client = client
member = await utils.get_member_fuzzy(ctx.conn, system_id=system_id, key=args[0], system_only=system_only) self.message = message
self.conn = conn
self.args = args
if member is None: async def get_system(self) -> Optional[System]:
return embeds.error("Can't find member \"{}\".".format(args[0])) return await db.get_system_by_account(self.conn, self.message.author.id)
async def ensure_system(self) -> System:
system = await self.get_system()
if not system:
raise CommandError(
embeds.error("No system registered to this account. Use `pk;system new` to register one."))
return system
def has_next(self) -> bool:
return bool(self.args)
def pop_str(self, error: CommandError = None) -> str:
if not self.args:
if error:
raise error
return None
popped, self.args = next_arg(self.args)
return popped
async def pop_system(self, error: CommandError = None) -> System:
name = self.pop_str(error)
system = await utils.get_system_fuzzy(self.conn, self.client, name)
if not system:
raise CommandError("Unable to find system '{}'.".format(name))
return system
async def pop_member(self, error: CommandError = None, system_only: bool = True) -> Member:
name = self.pop_str(error)
if system_only:
system = await self.ensure_system()
else:
system = await self.get_system()
member = await utils.get_member_fuzzy(self.conn, system.id if system else None, name, system_only)
if not member:
raise CommandError("Unable to find member '{}'{}.".format(name, " in your system" if system_only else ""))
return member
def remaining(self):
return self.args
async def reply(self, content=None, embed=None):
return await self.client.send_message(self.message.channel, content=content, embed=embed)
ctx = MemberCommandContext(client=ctx.client, conn=ctx.conn, message=ctx.message, system=ctx.system, member=member)
return await func(ctx, args[1:])
return command(cmd=cmd, usage="<name|id> {}".format(usage or ""), description=description, category=category, system_required=False)(wrapper)
return wrap
import pluralkit.bot.commands.import_commands import pluralkit.bot.commands.import_commands
import pluralkit.bot.commands.member_commands import pluralkit.bot.commands.member_commands
@ -87,3 +114,69 @@ import pluralkit.bot.commands.misc_commands
import pluralkit.bot.commands.mod_commands import pluralkit.bot.commands.mod_commands
import pluralkit.bot.commands.switch_commands import pluralkit.bot.commands.switch_commands
import pluralkit.bot.commands.system_commands import pluralkit.bot.commands.system_commands
async def run_command(ctx: CommandContext, func):
try:
result = await func(ctx)
if isinstance(result, CommandResponse):
await ctx.reply(embed=result.to_embed())
except CommandError as e:
await ctx.reply(embed=e.to_embed())
except Exception:
logger.exception("Exception while dispatching command")
async def command_dispatch(client: discord.Client, message: discord.Message, conn) -> bool:
prefix = "^pk(;|!)"
commands = [
(r"system (new|register|create|init)", system_commands.new_system),
(r"system set", system_commands.system_set),
(r"system link", system_commands.system_link),
(r"system unlink", system_commands.system_unlink),
(r"system fronter", system_commands.system_fronter),
(r"system fronthistory", system_commands.system_fronthistory),
(r"system (delete|remove|destroy|erase)", system_commands.system_delete),
(r"system frontpercent(age)?", system_commands.system_frontpercent),
(r"system", system_commands.system_info),
(r"import tupperware", import_commands.import_tupperware),
(r"member (new|create|add|register)", member_commands.new_member),
(r"member set", member_commands.member_set),
(r"member proxy", member_commands.member_proxy),
(r"member (delete|remove|destroy|erase)", member_commands.member_delete),
(r"member", member_commands.member_info),
(r"message", message_commands.message_info),
(r"mod log", mod_commands.set_log),
(r"invite", misc_commands.invite_link),
(r"export", misc_commands.export),
(r"help", misc_commands.show_help),
(r"switch move", switch_commands.switch_move),
(r"switch out", switch_commands.switch_out),
(r"switch", switch_commands.switch_member)
]
for pattern, func in commands:
regex = re.compile(prefix + pattern, re.IGNORECASE)
cmd = message.content
match = regex.match(cmd)
if match:
remaining_string = cmd[match.span()[1]:].strip()
ctx = CommandContext(
client=client,
message=message,
conn=conn,
args=remaining_string
)
await run_command(ctx, func)
return True
return False

View File

@ -1,20 +1,19 @@
import asyncio import asyncio
import re
from datetime import datetime from datetime import datetime
from typing import List
from pluralkit.bot.commands import * from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands") logger = logging.getLogger("pluralkit.commands")
@command(cmd="import tupperware", description="Import data from Tupperware.", system_required=False)
async def import_tupperware(ctx: CommandContext, args: List[str]): async def import_tupperware(ctx: CommandContext):
tupperware_ids = ["431544605209788416", "433916057053560832"] # Main bot instance and Multi-Pals-specific fork tupperware_ids = ["431544605209788416", "433916057053560832"] # Main bot instance and Multi-Pals-specific fork
tupperware_members = [ctx.message.server.get_member(bot_id) for bot_id in tupperware_ids if ctx.message.server.get_member(bot_id)] tupperware_members = [ctx.message.server.get_member(bot_id) for bot_id in tupperware_ids if
ctx.message.server.get_member(bot_id)]
# Check if there's any Tupperware bot on the server # Check if there's any Tupperware bot on the server
if not tupperware_members: if not tupperware_members:
return embeds.error("This command only works in a server where the Tupperware bot is also present.") return CommandError("This command only works in a server where the Tupperware bot is also present.")
# Make sure at least one of the bts have send/read permissions here # Make sure at least one of the bts have send/read permissions here
for bot_member in tupperware_members: for bot_member in tupperware_members:
@ -24,9 +23,10 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
break break
else: else:
# If no bots have permission (ie. loop doesn't break), throw error # If no bots have permission (ie. loop doesn't break), throw error
return embeds.error("This command only works in a channel where the Tupperware bot has read/send access.") return CommandError("This command only works in a channel where the Tupperware bot has read/send access.")
await ctx.reply(embed=utils.make_default_embed("Please reply to this message with `tul!list` (or the server equivalent).")) await ctx.reply(
embed=embeds.status("Please reply to this message with `tul!list` (or the server equivalent)."))
# Check to make sure the message is sent by Tupperware, and that the Tupperware response actually belongs to the correct user # Check to make sure the message is sent by Tupperware, and that the Tupperware response actually belongs to the correct user
def ensure_account(tw_msg): def ensure_account(tw_msg):
@ -39,13 +39,15 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
if not tw_msg.embeds[0]["title"]: if not tw_msg.embeds[0]["title"]:
return False return False
return tw_msg.embeds[0]["title"].startswith("{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator)) return tw_msg.embeds[0]["title"].startswith(
"{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator))
tupperware_page_embeds = [] tupperware_page_embeds = []
tw_msg: discord.Message = await ctx.client.wait_for_message(channel=ctx.message.channel, timeout=60.0, check=ensure_account) tw_msg: discord.Message = await ctx.client.wait_for_message(channel=ctx.message.channel, timeout=60.0,
check=ensure_account)
if not tw_msg: if not tw_msg:
return embeds.error("Tupperware import timed out.") return CommandError("Tupperware import timed out.")
tupperware_page_embeds.append(tw_msg.embeds[0]) tupperware_page_embeds.append(tw_msg.embeds[0])
# Handle Tupperware pagination # Handle Tupperware pagination
@ -74,7 +76,9 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
# If this isn't the same page as last check, edit the status message # If this isn't the same page as last check, edit the status message
if new_page != current_page: if new_page != current_page:
last_found_time = datetime.utcnow() last_found_time = datetime.utcnow()
await ctx.client.edit_message(status_msg, "Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(len(pages_found), total_pages)) await ctx.client.edit_message(status_msg,
"Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(
len(pages_found), total_pages))
current_page = new_page current_page = new_page
# And sleep a bit to prevent spamming the CPU # And sleep a bit to prevent spamming the CPU
@ -82,7 +86,7 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
# Make sure it doesn't spin here for too long, time out after 30 seconds since last new page # Make sure it doesn't spin here for too long, time out after 30 seconds since last new page
if (datetime.utcnow() - last_found_time).seconds > 30: if (datetime.utcnow() - last_found_time).seconds > 30:
return embeds.error("Pagination scan timed out.") return CommandError("Pagination scan timed out.")
# Now that we've got all the pages, put them in the embeds list # Now that we've got all the pages, put them in the embeds list
# Make sure to erase the original one we put in above too # Make sure to erase the original one we put in above too
@ -94,7 +98,7 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
logger.debug("Importing from Tupperware...") logger.debug("Importing from Tupperware...")
# Create new (nameless) system if there isn't any registered # Create new (nameless) system if there isn't any registered
system = ctx.system system = ctx.get_system()
if system is None: if system is None:
hid = utils.generate_hid() hid = utils.generate_hid()
logger.debug("Creating new system (hid={})...".format(hid)) logger.debug("Creating new system (hid={})...".format(hid))
@ -117,7 +121,7 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
if line.startswith("Brackets:"): if line.startswith("Brackets:"):
brackets = line[len("Brackets: "):] brackets = line[len("Brackets: "):]
member_prefix = brackets[:brackets.index("text")].strip() or None member_prefix = brackets[:brackets.index("text")].strip() or None
member_suffix = brackets[brackets.index("text")+4:].strip() or None member_suffix = brackets[brackets.index("text") + 4:].strip() or None
elif line.startswith("Avatar URL: "): elif line.startswith("Avatar URL: "):
url = line[len("Avatar URL: "):] url = line[len("Avatar URL: "):]
member_avatar = url member_avatar = url
@ -138,14 +142,19 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
# Or create a new member # Or create a new member
hid = utils.generate_hid() hid = utils.generate_hid()
logger.debug("Creating new member {} (hid={})...".format(name, hid)) logger.debug("Creating new member {} (hid={})...".format(name, hid))
existing_member = await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid) existing_member = await db.create_member(ctx.conn, system_id=system.id, member_name=name,
member_hid=hid)
# Save the new stuff in the DB # Save the new stuff in the DB
logger.debug("Updating fields...") logger.debug("Updating fields...")
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="prefix", value=member_prefix) await db.update_member_field(ctx.conn, member_id=existing_member.id, field="prefix", value=member_prefix)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="suffix", value=member_suffix) await db.update_member_field(ctx.conn, member_id=existing_member.id, field="suffix", value=member_suffix)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="avatar_url", value=member_avatar) await db.update_member_field(ctx.conn, member_id=existing_member.id, field="avatar_url",
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="birthday", value=member_birthdate) value=member_avatar)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="description", value=member_description) await db.update_member_field(ctx.conn, member_id=existing_member.id, field="birthday",
value=member_birthdate)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="description",
value=member_description)
return embeds.success("System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting.") return CommandSuccess(
"System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting.")

View File

@ -8,32 +8,36 @@ from pluralkit.bot import help
logger = logging.getLogger("pluralkit.commands") logger = logging.getLogger("pluralkit.commands")
@member_command(cmd="member", description="Shows information about a system member.", system_only=False, category="Member commands")
async def member_info(ctx: MemberCommandContext, args: List[str]):
await ctx.reply(embed=await utils.generate_member_info_card(ctx.conn, ctx.member))
@command(cmd="member new", usage="<name>", description="Adds a new member to your system.", category="Member commands") async def member_info(ctx: CommandContext):
async def new_member(ctx: MemberCommandContext, args: List[str]): member = await ctx.pop_member(
if len(args) == 0: error=CommandError("You must pass a member name or ID.", help=help.lookup_member), system_only=False)
return embeds.error("You must pass a member name or ID.", help=help.add_member) await ctx.reply(embed=await utils.generate_member_info_card(ctx.conn, member))
name = " ".join(args)
bounds_error = utils.bounds_check_member_name(name, ctx.system.tag) async def new_member(ctx: CommandContext):
system = await ctx.ensure_system()
if not ctx.has_next():
return CommandError("You must pass a name for the new member.", help=help.add_member)
name = ctx.remaining()
bounds_error = utils.bounds_check_member_name(name, system.tag)
if bounds_error: if bounds_error:
return embeds.error(bounds_error) return CommandError(bounds_error)
# TODO: figure out what to do if this errors out on collision on generate_hid # TODO: figure out what to do if this errors out on collision on generate_hid
hid = utils.generate_hid() hid = utils.generate_hid()
# Insert member row # Insert member row
await db.create_member(ctx.conn, system_id=ctx.system.id, member_name=name, member_hid=hid) await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid)
return embeds.success("Member \"{}\" (`{}`) registered!".format(name, hid)) return CommandSuccess(
"Member \"{}\" (`{}`) registered! To register their proxy tags, use `pk;member proxy`.".format(name, hid))
@member_command(cmd="member set", usage="<name|description|color|pronouns|birthdate|avatar> [value]", description="Edits a member property. Leave [value] blank to clear.", category="Member commands") async def member_set(ctx: CommandContext):
async def member_set(ctx: MemberCommandContext, args: List[str]): system = await ctx.ensure_system()
if len(args) == 0: member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member))
return embeds.error("You must pass a property name to set.", help=help.edit_member) prop = ctx.pop_str(CommandError("You must pass a property name to set.", help=help.edit_member))
allowed_properties = ["name", "description", "color", "pronouns", "birthdate", "avatar"] allowed_properties = ["name", "description", "color", "pronouns", "birthdate", "avatar"]
db_properties = { db_properties = {
@ -45,23 +49,24 @@ async def member_set(ctx: MemberCommandContext, args: List[str]):
"avatar": "avatar_url" "avatar": "avatar_url"
} }
prop = args[0]
if prop not in allowed_properties: if prop not in allowed_properties:
return embeds.error("Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)), help=help.edit_member) return CommandError(
"Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)),
help=help.edit_member)
if len(args) >= 2: if ctx.has_next():
value = " ".join(args[1:]) value = " ".join(ctx.remaining())
# Sanity/validity checks and type conversions # Sanity/validity checks and type conversions
if prop == "name": if prop == "name":
bounds_error = utils.bounds_check_member_name(value, ctx.system.tag) bounds_error = utils.bounds_check_member_name(value, system.tag)
if bounds_error: if bounds_error:
return embeds.error(bounds_error) return CommandError(bounds_error)
if prop == "color": if prop == "color":
match = re.fullmatch("#?([0-9A-Fa-f]{6})", value) match = re.fullmatch("#?([0-9A-Fa-f]{6})", value)
if not match: if not match:
return embeds.error("Color must be a valid hex color (eg. #ff0000)") return CommandError("Color must be a valid hex color (eg. #ff0000)")
value = match.group(1).lower() value = match.group(1).lower()
@ -75,7 +80,7 @@ async def member_set(ctx: MemberCommandContext, args: List[str]):
# Useful if you want your birthday to be displayed yearless. # Useful if you want your birthday to be displayed yearless.
value = datetime.strptime("0001-" + value, "%Y-%m-%d").date() value = datetime.strptime("0001-" + value, "%Y-%m-%d").date()
except ValueError: except ValueError:
return embeds.error("Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25).") return CommandError("Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25).")
if prop == "avatar": if prop == "avatar":
user = await utils.parse_mention(ctx.client, value) user = await utils.parse_mention(ctx.client, value)
@ -89,41 +94,45 @@ async def member_set(ctx: MemberCommandContext, args: List[str]):
if u.scheme in ["http", "https"] and u.netloc and u.path: if u.scheme in ["http", "https"] and u.netloc and u.path:
value = value value = value
else: else:
return embeds.error("Invalid image URL.") return CommandError("Invalid image URL.")
else: else:
# Can't clear member name # Can't clear member name
if prop == "name": if prop == "name":
return embeds.error("You can't clear the member name.") return CommandError("You can't clear the member name.")
# Clear from DB # Clear from DB
value = None value = None
db_prop = db_properties[prop] db_prop = db_properties[prop]
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field=db_prop, value=value) await db.update_member_field(ctx.conn, member_id=member.id, field=db_prop, value=value)
response = embeds.success("{} {}'s {}.".format("Updated" if value else "Cleared", ctx.member.name, prop)) response = CommandSuccess("{} {}'s {}.".format("Updated" if value else "Cleared", member.name, prop))
if prop == "avatar" and value: if prop == "avatar" and value:
response.set_image(url=value) response.set_image(url=value)
if prop == "color" and value: if prop == "color" and value:
response.colour = int(value, 16) response.colour = int(value, 16)
return response return response
@member_command(cmd="member proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).", category="Member commands")
async def member_proxy(ctx: MemberCommandContext, args: List[str]): async def member_proxy(ctx: CommandContext):
if len(args) == 0: await ctx.ensure_system()
member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.member_proxy))
if not ctx.has_next():
prefix, suffix = None, None prefix, suffix = None, None
else: else:
# Sanity checking # Sanity checking
example = " ".join(args) example = ctx.remaining()
if "text" not in example: if "text" not in example:
return embeds.error("Example proxy message must contain the string 'text'.", help=help.member_proxy) return CommandError("Example proxy message must contain the string 'text'.", help=help.member_proxy)
if example.count("text") != 1: if example.count("text") != 1:
return embeds.error("Example proxy message must contain the string 'text' exactly once.", help=help.member_proxy) return CommandError("Example proxy message must contain the string 'text' exactly once.",
help=help.member_proxy)
# Extract prefix and suffix # Extract prefix and suffix
prefix = example[:example.index("text")].strip() prefix = example[:example.index("text")].strip()
suffix = example[example.index("text")+4:].strip() suffix = example[example.index("text") + 4:].strip()
logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix)) logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix))
# DB stores empty strings as None, make that work # DB stores empty strings as None, make that work
@ -133,17 +142,22 @@ async def member_proxy(ctx: MemberCommandContext, args: List[str]):
suffix = None suffix = None
async with ctx.conn.transaction(): async with ctx.conn.transaction():
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="prefix", value=prefix) await db.update_member_field(ctx.conn, member_id=member.id, field="prefix", value=prefix)
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="suffix", value=suffix) await db.update_member_field(ctx.conn, member_id=member.id, field="suffix", value=suffix)
return embeds.success("Proxy settings updated." if prefix or suffix else "Proxy settings cleared.") return CommandSuccess("Proxy settings updated." if prefix or suffix else "Proxy settings cleared.")
@member_command("member delete", description="Deletes a member from your system ***permanently***.", category="Member commands")
async def member_delete(ctx: MemberCommandContext, args: List[str]): async def member_delete(ctx: CommandContext):
await ctx.reply("Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(ctx.member.name, ctx.member.hid)) await ctx.ensure_system()
member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member))
await ctx.reply(
"Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(
member.name, member.hid))
msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0) msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0)
if msg and msg.content.lower() == ctx.member.hid.lower(): if msg and msg.content.lower() == member.hid.lower():
await db.delete_member(ctx.conn, member_id=ctx.member.id) await db.delete_member(ctx.conn, member_id=member.id)
return embeds.success("Member deleted.") return CommandSuccess("Member deleted.")
else: else:
return embeds.error("Member deletion cancelled.") return CommandError("Member deletion cancelled.")

View File

@ -1,27 +1,21 @@
import logging from pluralkit.bot import help
from typing import List
from pluralkit.bot import utils, embeds, help
from pluralkit.bot.commands import * from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands") logger = logging.getLogger("pluralkit.commands")
@command(cmd="message", usage="<id>", description="Shows information about a proxied message. Requires the message ID.", async def message_info(ctx: CommandContext):
category="Message commands", system_required=False) mid_str = ctx.pop_str(CommandError("You must pass a message ID.", help=help.message_lookup))
async def message_info(ctx: CommandContext, args: List[str]):
if len(args) == 0:
return embeds.error("You must pass a message ID.", help=help.message_lookup)
try: try:
mid = int(args[0]) mid = int(mid_str)
except ValueError: except ValueError:
return embeds.error("You must pass a valid number as a message ID.", help=help.message_lookup) return CommandError("You must pass a valid number as a message ID.", help=help.message_lookup)
# Find the message in the DB # Find the message in the DB
message = await db.get_message(ctx.conn, str(mid)) message = await db.get_message(ctx.conn, str(mid))
if not message: if not message:
raise embeds.error("Message with ID '{}' not found.".format(args[0])) raise CommandError("Message with ID '{}' not found.".format(mid))
# Get the original sender of the messages # Get the original sender of the messages
try: try:
@ -49,9 +43,9 @@ async def message_info(ctx: CommandContext, args: List[str]):
embed.add_field(name="Sent by", value=sender_name) embed.add_field(name="Sent by", value=sender_name)
if message.content: # Content can be empty string if there's an attachment if message.content: # Content can be empty string if there's an attachment
embed.add_field(name="Content", value=message.content, inline=False) embed.add_field(name="Content", value=message.content, inline=False)
embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty) embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty)
return embed await ctx.reply(embed=embed)

View File

@ -12,13 +12,13 @@ from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands") logger = logging.getLogger("pluralkit.commands")
@command(cmd="help", usage="[system|member|proxy|switch|mod]", description="Shows help messages.", system_required=False)
async def show_help(ctx: CommandContext, args: List[str]):
embed = utils.make_default_embed("")
embed.title = "PluralKit Help"
embed.set_footer(text="By Astrid (Ske#6201, or 'qoxvy' on PK) | GitHub: https://github.com/xSke/PluralKit/")
category = args[0] if len(args) > 0 else None async def show_help(ctx: CommandContext):
embed = embeds.status("")
embed.title = "PluralKit Help"
embed.set_footer(text="By Astrid (Ske#6201; pk;member qoxvy) | GitHub: https://github.com/xSke/PluralKit/")
category = ctx.pop_str() if ctx.has_next() else None
from pluralkit.bot.help import help_pages from pluralkit.bot.help import help_pages
if category in help_pages: if category in help_pages:
@ -28,12 +28,12 @@ async def show_help(ctx: CommandContext, args: List[str]):
else: else:
embed.description = text embed.description = text
else: else:
return embeds.error("Unknown help page '{}'.".format(category)) return CommandError("Unknown help page '{}'.".format(category))
return embed await ctx.reply(embed=embed)
@command(cmd="invite", description="Generates an invite link for this bot.", system_required=False)
async def invite_link(ctx: CommandContext, args: List[str]): async def invite_link(ctx: CommandContext):
client_id = os.environ["CLIENT_ID"] client_id = os.environ["CLIENT_ID"]
permissions = discord.Permissions() permissions = discord.Permissions()
@ -47,15 +47,16 @@ async def invite_link(ctx: CommandContext, args: List[str]):
url = oauth_url(client_id, permissions) url = oauth_url(client_id, permissions)
logger.debug("Sending invite URL: {}".format(url)) logger.debug("Sending invite URL: {}".format(url))
return embeds.success("Use this link to add PluralKit to your server: {}".format(url)) return CommandSuccess("Use this link to add PluralKit to your server: {}".format(url))
@command(cmd="export", description="Exports system data to a machine-readable format.")
async def export(ctx: CommandContext, args: List[str]):
members = await db.get_all_members(ctx.conn, ctx.system.id)
accounts = await db.get_linked_accounts(ctx.conn, ctx.system.id)
switches = await pluralkit.utils.get_front_history(ctx.conn, ctx.system.id, 999999)
system = ctx.system async def export(ctx: CommandContext):
system = await ctx.ensure_system()
members = await db.get_all_members(ctx.conn, system.id)
accounts = await db.get_linked_accounts(ctx.conn, system.id)
switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, 999999)
data = { data = {
"name": system.name, "name": system.name,
"id": system.hid, "id": system.hid,

View File

@ -1,24 +1,20 @@
import logging
from typing import List
from pluralkit.bot import utils, embeds
from pluralkit.bot.commands import * from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands") logger = logging.getLogger("pluralkit.commands")
@command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands", system_required=False)
async def set_log(ctx: CommandContext, args: List[str]): async def set_log(ctx: CommandContext):
if not ctx.message.author.server_permissions.administrator: if not ctx.message.author.server_permissions.administrator:
return embeds.error("You must be a server administrator to use this command.") return CommandError("You must be a server administrator to use this command.")
server = ctx.message.server server = ctx.message.server
if len(args) == 0: if not ctx.has_next():
channel_id = None channel_id = None
else: else:
channel = utils.parse_channel_mention(args[0], server=server) channel = utils.parse_channel_mention(ctx.pop_str(), server=server)
if not channel: if not channel:
return embeds.error("Channel not found.") return CommandError("Channel not found.")
channel_id = channel.id channel_id = channel.id
await db.update_server(ctx.conn, server.id, logging_channel_id=channel_id) await db.update_server(ctx.conn, server.id, logging_channel_id=channel_id)
return embeds.success("Updated logging channel." if channel_id else "Cleared logging channel.") return CommandSuccess("Updated logging channel." if channel_id else "Cleared logging channel.")

View File

@ -1,120 +1,130 @@
from datetime import datetime
import logging
from typing import List
import dateparser import dateparser
import humanize import humanize
from datetime import datetime, timezone
from typing import List
import pluralkit.utils import pluralkit.utils
from pluralkit import Member from pluralkit.bot import help
from pluralkit.bot import utils, embeds, help
from pluralkit.bot.commands import * from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands") logger = logging.getLogger("pluralkit.commands")
@command(cmd="switch", usage="<name|id> [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands")
async def switch_member(ctx: MemberCommandContext, args: List[str]): async def switch_member(ctx: CommandContext):
if len(args) == 0: system = await ctx.ensure_system()
return embeds.error("You must pass at least one member name or ID to register a switch to.", help=help.switch_register)
if not ctx.has_next():
return CommandError("You must pass at least one member name or ID to register a switch to.",
help=help.switch_register)
members: List[Member] = [] members: List[Member] = []
for member_name in args: for member_name in ctx.remaining().split(" "):
# Find the member # Find the member
member = await utils.get_member_fuzzy(ctx.conn, ctx.system.id, member_name) member = await utils.get_member_fuzzy(ctx.conn, system.id, member_name)
if not member: if not member:
return embeds.error("Couldn't find member \"{}\".".format(member_name)) return CommandError("Couldn't find member \"{}\".".format(member_name))
members.append(member) members.append(member)
# Compare requested switch IDs and existing fronter IDs to check for existing switches # Compare requested switch IDs and existing fronter IDs to check for existing switches
# Lists, because order matters, it makes sense to just swap fronters # Lists, because order matters, it makes sense to just swap fronters
member_ids = [member.id for member in members] member_ids = [member.id for member in members]
fronter_ids = (await pluralkit.utils.get_fronter_ids(ctx.conn, ctx.system.id))[0] fronter_ids = (await pluralkit.utils.get_fronter_ids(ctx.conn, system.id))[0]
if member_ids == fronter_ids: if member_ids == fronter_ids:
if len(members) == 1: if len(members) == 1:
return embeds.error("{} is already fronting.".format(members[0].name)) return CommandError("{} is already fronting.".format(members[0].name))
return embeds.error("Members {} are already fronting.".format(", ".join([m.name for m in members]))) return CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members])))
# Also make sure there aren't any duplicates # Also make sure there aren't any duplicates
if len(set(member_ids)) != len(member_ids): if len(set(member_ids)) != len(member_ids):
return embeds.error("Duplicate members in member list.") return CommandError("Duplicate members in member list.")
# Log the switch # Log the switch
async with ctx.conn.transaction(): async with ctx.conn.transaction():
switch_id = await db.add_switch(ctx.conn, system_id=ctx.system.id) switch_id = await db.add_switch(ctx.conn, system_id=system.id)
for member in members: for member in members:
await db.add_switch_member(ctx.conn, switch_id=switch_id, member_id=member.id) await db.add_switch_member(ctx.conn, switch_id=switch_id, member_id=member.id)
if len(members) == 1: if len(members) == 1:
return embeds.success("Switch registered. Current fronter is now {}.".format(members[0].name)) return CommandSuccess("Switch registered. Current fronter is now {}.".format(members[0].name))
else: else:
return embeds.success("Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members]))) return CommandSuccess(
"Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members])))
async def switch_out(ctx: CommandContext):
system = await ctx.ensure_system()
@command(cmd="switch out", description="Registers a switch with no one in front.", category="Switching commands")
async def switch_out(ctx: MemberCommandContext, args: List[str]):
# Get current fronters # Get current fronters
fronters, _ = await pluralkit.utils.get_fronter_ids(ctx.conn, system_id=ctx.system.id) fronters, _ = await pluralkit.utils.get_fronter_ids(ctx.conn, system_id=system.id)
if not fronters: if not fronters:
return embeds.error("There's already no one in front.") return CommandError("There's already no one in front.")
# Log it, and don't log any members # Log it, and don't log any members
await db.add_switch(ctx.conn, system_id=ctx.system.id) await db.add_switch(ctx.conn, system_id=system.id)
return embeds.success("Switch-out registered.") return CommandSuccess("Switch-out registered.")
@command(cmd="switch move", usage="<time>", description="Moves the most recent switch to a different point in time.", category="Switching commands")
async def switch_move(ctx: MemberCommandContext, args: List[str]): async def switch_move(ctx: CommandContext):
if len(args) == 0: system = await ctx.ensure_system()
return embeds.error("You must pass a time to move the switch to.", help=help.switch_move) if not ctx.has_next():
return CommandError("You must pass a time to move the switch to.", help=help.switch_move)
# Parse the time to move to # Parse the time to move to
new_time = dateparser.parse(" ".join(args), languages=["en"], settings={ new_time = dateparser.parse(ctx.remaining(), languages=["en"], settings={
"TO_TIMEZONE": "UTC", "TO_TIMEZONE": "UTC",
"RETURN_AS_TIMEZONE_AWARE": False "RETURN_AS_TIMEZONE_AWARE": False
}) })
if not new_time: if not new_time:
return embeds.error("{} can't be parsed as a valid time.".format(" ".join(args))) return CommandError("'{}' can't be parsed as a valid time.".format(ctx.remaining()), help=help.switch_move)
# Make sure the time isn't in the future # Make sure the time isn't in the future
if new_time > datetime.now(): if new_time > datetime.utcnow():
return embeds.error("Can't move switch to a time in the future.") return CommandError("Can't move switch to a time in the future.", help=help.switch_move)
# Make sure it all runs in a big transaction for atomicity # Make sure it all runs in a big transaction for atomicity
async with ctx.conn.transaction(): async with ctx.conn.transaction():
# Get the last two switches to make sure the switch to move isn't before the second-last switch # Get the last two switches to make sure the switch to move isn't before the second-last switch
last_two_switches = await pluralkit.utils.get_front_history(ctx.conn, ctx.system.id, count=2) last_two_switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, count=2)
if len(last_two_switches) == 0: if len(last_two_switches) == 0:
return embeds.error("There are no registered switches for this system.") return CommandError("There are no registered switches for this system.")
last_timestamp, last_fronters = last_two_switches[0] last_timestamp, last_fronters = last_two_switches[0]
if len(last_two_switches) > 1: if len(last_two_switches) > 1:
second_last_timestamp, _ = last_two_switches[1] second_last_timestamp, _ = last_two_switches[1]
if new_time < second_last_timestamp: if new_time < second_last_timestamp:
time_str = humanize.naturaltime(second_last_timestamp) time_str = humanize.naturaltime(pluralkit.utils.fix_time(second_last_timestamp))
return embeds.error("Can't move switch to before last switch time ({}), as it would cause conflicts.".format(time_str)) return CommandError(
"Can't move switch to before last switch time ({}), as it would cause conflicts.".format(time_str))
# Display the confirmation message w/ humanized times # Display the confirmation message w/ humanized times
members = ", ".join([member.name for member in last_fronters]) or "nobody" members = ", ".join([member.name for member in last_fronters]) or "nobody"
last_absolute = last_timestamp.isoformat(sep=" ", timespec="seconds") last_absolute = last_timestamp.isoformat(sep=" ", timespec="seconds")
last_relative = humanize.naturaltime(last_timestamp) last_relative = humanize.naturaltime(pluralkit.utils.fix_time(last_timestamp))
new_absolute = new_time.isoformat(sep=" ", timespec="seconds") new_absolute = new_time.isoformat(sep=" ", timespec="seconds")
new_relative = humanize.naturaltime(new_time) new_relative = humanize.naturaltime(pluralkit.utils.fix_time(new_time))
embed = utils.make_default_embed("This will move the latest switch ({}) from {} ({}) to {} ({}). Is this OK?".format(members, last_absolute, last_relative, new_absolute, new_relative)) embed = embeds.status(
"This will move the latest switch ({}) from {} ({}) to {} ({}). Is this OK?".format(members, last_absolute,
last_relative,
new_absolute,
new_relative))
# Await and handle confirmation reactions # Await and handle confirmation reactions
confirm_msg = await ctx.reply(embed=embed) confirm_msg = await ctx.reply(embed=embed)
await ctx.client.add_reaction(confirm_msg, "") await ctx.client.add_reaction(confirm_msg, "")
await ctx.client.add_reaction(confirm_msg, "") await ctx.client.add_reaction(confirm_msg, "")
reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=confirm_msg, user=ctx.message.author, timeout=60.0) reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=confirm_msg, user=ctx.message.author,
timeout=60.0)
if not reaction: if not reaction:
return embeds.error("Switch move timed out.") return CommandError("Switch move timed out.")
if reaction.reaction.emoji == "": if reaction.reaction.emoji == "":
return embeds.error("Switch move cancelled.") return CommandError("Switch move cancelled.")
# DB requires the actual switch ID which our utility method above doesn't return, do this manually # DB requires the actual switch ID which our utility method above doesn't return, do this manually
switch_id = (await db.front_history(ctx.conn, ctx.system.id, count=1))[0]["id"] switch_id = (await db.front_history(ctx.conn, system.id, count=1))[0]["id"]
# Change the switch in the DB # Change the switch in the DB
await db.move_last_switch(ctx.conn, ctx.system.id, switch_id, new_time) await db.move_last_switch(ctx.conn, system.id, switch_id, new_time)
return embeds.success("Switch moved.") return CommandSuccess("Switch moved.")

View File

@ -1,39 +1,31 @@
from datetime import datetime
from typing import List
from urllib.parse import urlparse
import dateparser import dateparser
import humanize import humanize
from datetime import datetime
from urllib.parse import urlparse
import pluralkit.utils import pluralkit.utils
from pluralkit.bot import embeds, help from pluralkit.bot import help
from pluralkit.bot.commands import * from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands") logger = logging.getLogger("pluralkit.commands")
@command(cmd="system", usage="[system]", description="Shows information about a system.", category="System commands", system_required=False)
async def system_info(ctx: CommandContext, args: List[str]):
if len(args) == 0:
if not ctx.system:
raise NoSystemRegistered()
system = ctx.system
else:
# Look one up
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
if system is None: async def system_info(ctx: CommandContext):
return embeds.error("Unable to find system \"{}\".".format(args[0])) if ctx.has_next():
system = await ctx.pop_system()
else:
system = await ctx.ensure_system()
await ctx.reply(embed=await utils.generate_system_info_card(ctx.conn, ctx.client, system)) await ctx.reply(embed=await utils.generate_system_info_card(ctx.conn, ctx.client, system))
@command(cmd="system new", usage="[name]", description="Registers a new system to this account.", category="System commands", system_required=False)
async def new_system(ctx: CommandContext, args: List[str]):
if ctx.system:
return embeds.error("You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
system_name = None async def new_system(ctx: CommandContext):
if len(args) > 0: system = await ctx.get_system()
system_name = " ".join(args) if system:
return CommandError(
"You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
system_name = ctx.remaining() or None
async with ctx.conn.transaction(): async with ctx.conn.transaction():
# TODO: figure out what to do if this errors out on collision on generate_hid # TODO: figure out what to do if this errors out on collision on generate_hid
@ -43,12 +35,13 @@ async def new_system(ctx: CommandContext, args: List[str]):
# Link account # Link account
await db.link_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id) await db.link_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id)
return embeds.success("System registered! To begin adding members, use `pk;member new <name>`.") return CommandSuccess("System registered! To begin adding members, use `pk;member new <name>`.")
@command(cmd="system set", usage="<name|description|tag|avatar> [value]", description="Edits a system property. Leave [value] blank to clear.", category="System commands")
async def system_set(ctx: CommandContext, args: List[str]): async def system_set(ctx: CommandContext):
if len(args) == 0: system = await ctx.ensure_system()
return embeds.error("You must pass a property name to set.", help=help.edit_system)
prop = ctx.pop_str(CommandError("You must pass a property name to set.", help=help.edit_system))
allowed_properties = ["name", "description", "tag", "avatar"] allowed_properties = ["name", "description", "tag", "avatar"]
db_properties = { db_properties = {
@ -58,25 +51,29 @@ async def system_set(ctx: CommandContext, args: List[str]):
"avatar": "avatar_url" "avatar": "avatar_url"
} }
prop = args[0]
if prop not in allowed_properties: if prop not in allowed_properties:
raise embeds.error("Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)), help=help.edit_system) return CommandError(
"Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)),
help=help.edit_system)
if len(args) >= 2: if ctx.has_next():
value = " ".join(args[1:]) value = ctx.remaining()
# Sanity checking # Sanity checking
if prop == "tag": if prop == "tag":
if len(value) > 32: if len(value) > 32:
raise embeds.error("You can't have a system tag longer than 32 characters.") return CommandError("You can't have a system tag longer than 32 characters.")
# Make sure there are no members which would make the combined length exceed 32 # Make sure there are no members which would make the combined length exceed 32
members_exceeding = await db.get_members_exceeding(ctx.conn, system_id=ctx.system.id, length=32 - len(value) - 1) members_exceeding = await db.get_members_exceeding(ctx.conn, system_id=system.id,
length=32 - len(value) - 1)
if len(members_exceeding) > 0: if len(members_exceeding) > 0:
# If so, error out and warn # If so, error out and warn
member_names = ", ".join([member.name member_names = ", ".join([member.name
for member in members_exceeding]) for member in members_exceeding])
logger.debug("Members exceeding combined length with tag '{}': {}".format(value, member_names)) logger.debug("Members exceeding combined length with tag '{}': {}".format(value, member_names))
raise embeds.error("The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names)) return CommandError(
"The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(
member_names))
if prop == "avatar": if prop == "avatar":
user = await utils.parse_mention(ctx.client, value) user = await utils.parse_mention(ctx.client, value)
@ -90,75 +87,73 @@ async def system_set(ctx: CommandContext, args: List[str]):
if u.scheme in ["http", "https"] and u.netloc and u.path: if u.scheme in ["http", "https"] and u.netloc and u.path:
value = value value = value
else: else:
raise embeds.error("Invalid image URL.") return CommandError("Invalid image URL.")
else: else:
# Clear from DB # Clear from DB
value = None value = None
db_prop = db_properties[prop] db_prop = db_properties[prop]
await db.update_system_field(ctx.conn, system_id=ctx.system.id, field=db_prop, value=value) await db.update_system_field(ctx.conn, system_id=system.id, field=db_prop, value=value)
response = embeds.success("{} system {}.".format("Updated" if value else "Cleared", prop)) response = CommandSuccess("{} system {}.".format("Updated" if value else "Cleared", prop))
if prop == "avatar" and value: if prop == "avatar" and value:
response.set_image(url=value) response.set_image(url=value)
return response return response
@command(cmd="system link", usage="<account>", description="Links another account to your system.", category="System commands")
async def system_link(ctx: CommandContext, args: List[str]): async def system_link(ctx: CommandContext):
if len(args) == 0: system = await ctx.ensure_system()
return embeds.error("You must pass an account to link this system to.", help=help.link_account) account_name = ctx.pop_str(CommandError("You must pass an account to link this system to.", help=help.link_account))
# Find account to link # Find account to link
linkee = await utils.parse_mention(ctx.client, args[0]) linkee = await utils.parse_mention(ctx.client, account_name)
if not linkee: if not linkee:
return embeds.error("Account not found.") return CommandError("Account not found.")
# Make sure account doesn't already have a system # Make sure account doesn't already have a system
account_system = await db.get_system_by_account(ctx.conn, linkee.id) account_system = await db.get_system_by_account(ctx.conn, linkee.id)
if account_system: if account_system:
return embeds.error("The mentioned account is already linked to a system (`{}`)".format(account_system.hid)) return CommandError("The mentioned account is already linked to a system (`{}`)".format(account_system.hid))
# Send confirmation message # Send confirmation message
msg = await ctx.reply("{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention)) msg = await ctx.reply(
"{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention))
await ctx.client.add_reaction(msg, "") await ctx.client.add_reaction(msg, "")
await ctx.client.add_reaction(msg, "") await ctx.client.add_reaction(msg, "")
reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=msg, user=linkee, timeout=60.0) reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=msg, user=linkee, timeout=60.0)
# If account to be linked confirms... # If account to be linked confirms...
if not reaction: if not reaction:
return embeds.error("Account link timed out.") return CommandError("Account link timed out.")
if not reaction.reaction.emoji == "": if not reaction.reaction.emoji == "":
return embeds.error("Account link cancelled.") return CommandError("Account link cancelled.")
await db.link_account(ctx.conn, system_id=ctx.system.id, account_id=linkee.id) await db.link_account(ctx.conn, system_id=system.id, account_id=linkee.id)
return embeds.success("Account linked to system.") return CommandSuccess("Account linked to system.")
async def system_unlink(ctx: CommandContext):
system = await ctx.ensure_system()
@command(cmd="system unlink", description="Unlinks your system from this account. There must be at least one other account linked.", category="System commands")
async def system_unlink(ctx: CommandContext, args: List[str]):
# Make sure you can't unlink every account # Make sure you can't unlink every account
linked_accounts = await db.get_linked_accounts(ctx.conn, system_id=ctx.system.id) linked_accounts = await db.get_linked_accounts(ctx.conn, system_id=system.id)
if len(linked_accounts) == 1: if len(linked_accounts) == 1:
return embeds.error("This is the only account on your system, so you can't unlink it.") return CommandError("This is the only account on your system, so you can't unlink it.")
await db.unlink_account(ctx.conn, system_id=ctx.system.id, account_id=ctx.message.author.id) await db.unlink_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id)
return embeds.success("Account unlinked.") return CommandSuccess("Account unlinked.")
@command(cmd="system fronter", usage="[system]", description="Gets the current fronter(s) in the system.", category="Switching commands", system_required=False)
async def system_fronter(ctx: CommandContext, args: List[str]): async def system_fronter(ctx: CommandContext):
if len(args) == 0: if ctx.has_next():
if not ctx.system: system = await ctx.pop_system()
raise NoSystemRegistered()
system = ctx.system
else: else:
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0]) system = await ctx.ensure_system()
if system is None:
return embeds.error("Can't find system \"{}\".".format(args[0]))
fronters, timestamp = await pluralkit.utils.get_fronters(ctx.conn, system_id=system.id) fronters, timestamp = await pluralkit.utils.get_fronters(ctx.conn, system_id=system.id)
fronter_names = [member.name for member in fronters] fronter_names = [member.name for member in fronters]
embed = utils.make_default_embed(None) embed = embeds.status("")
if len(fronter_names) == 0: if len(fronter_names) == 0:
embed.add_field(name="Current fronter", value="(no fronter)") embed.add_field(name="Current fronter", value="(no fronter)")
@ -168,20 +163,16 @@ async def system_fronter(ctx: CommandContext, args: List[str]):
embed.add_field(name="Current fronters", value=", ".join(fronter_names)) embed.add_field(name="Current fronters", value=", ".join(fronter_names))
if timestamp: if timestamp:
embed.add_field(name="Since", value="{} ({})".format(timestamp.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(timestamp))) embed.add_field(name="Since", value="{} ({})".format(timestamp.isoformat(sep=" ", timespec="seconds"),
return embed humanize.naturaltime(pluralkit.utils.fix_time(timestamp))))
await ctx.reply(embed=embed)
@command(cmd="system fronthistory", usage="[system]", description="Shows the past 10 switches in the system.", category="Switching commands", system_required=False)
async def system_fronthistory(ctx: CommandContext, args: List[str]): async def system_fronthistory(ctx: CommandContext):
if len(args) == 0: if ctx.has_next():
if not ctx.system: system = await ctx.pop_system()
raise NoSystemRegistered()
system = ctx.system
else: else:
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0]) system = await ctx.ensure_system()
if system is None:
raise embeds.error("Can't find system \"{}\".".format(args[0]))
lines = [] lines = []
front_history = await pluralkit.utils.get_front_history(ctx.conn, system.id, count=10) front_history = await pluralkit.utils.get_front_history(ctx.conn, system.id, count=10)
@ -194,37 +185,39 @@ async def system_fronthistory(ctx: CommandContext, args: List[str]):
# Make proper date string # Make proper date string
time_text = timestamp.isoformat(sep=" ", timespec="seconds") time_text = timestamp.isoformat(sep=" ", timespec="seconds")
rel_text = humanize.naturaltime(timestamp) rel_text = humanize.naturaltime(pluralkit.utils.fix_time(timestamp))
delta_text = "" delta_text = ""
if i > 0: if i > 0:
last_switch_time = front_history[i-1][0] last_switch_time = front_history[i - 1][0]
delta_text = ", for {}".format(humanize.naturaldelta(timestamp - last_switch_time)) delta_text = ", for {}".format(humanize.naturaldelta(timestamp - last_switch_time))
lines.append("**{}** ({}, {}{})".format(name, time_text, rel_text, delta_text)) lines.append("**{}** ({}, {}{})".format(name, time_text, rel_text, delta_text))
embed = utils.make_default_embed("\n".join(lines) or "(none)") embed = embeds.status("\n".join(lines) or "(none)")
embed.title = "Past switches" embed.title = "Past switches"
return embed await ctx.reply(embed=embed)
@command(cmd="system delete", description="Deletes your system from the database ***permanently***.", category="System commands") async def system_delete(ctx: CommandContext):
async def system_delete(ctx: CommandContext, args: List[str]): system = await ctx.ensure_system()
await ctx.reply("Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(ctx.system.hid))
await ctx.reply(
"Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(
system.hid))
msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0) msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0)
if msg and msg.content.lower() == ctx.system.hid.lower(): if msg and msg.content.lower() == system.hid.lower():
await db.remove_system(ctx.conn, system_id=ctx.system.id) await db.remove_system(ctx.conn, system_id=system.id)
return embeds.success("System deleted.") return CommandSuccess("System deleted.")
else: else:
return embeds.error("System deletion cancelled.") return CommandError("System deletion cancelled.")
@command(cmd="system frontpercent", usage="[time]", async def system_frontpercent(ctx: CommandContext):
description="Shows the fronting percentage of every member, averaged over the given time", system = await ctx.ensure_system()
category="System commands")
async def system_frontpercent(ctx: CommandContext, args: List[str]):
# Parse the time limit (will go this far back) # Parse the time limit (will go this far back)
before = dateparser.parse(" ".join(args), languages=["en"], settings={ before = dateparser.parse(ctx.remaining(), languages=["en"], settings={
"TO_TIMEZONE": "UTC", "TO_TIMEZONE": "UTC",
"RETURN_AS_TIMEZONE_AWARE": False "RETURN_AS_TIMEZONE_AWARE": False
}) })
@ -234,9 +227,9 @@ async def system_frontpercent(ctx: CommandContext, args: List[str]):
before = None before = None
# Fetch list of switches # Fetch list of switches
all_switches = await pluralkit.utils.get_front_history(ctx.conn, ctx.system.id, 99999) all_switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, 99999)
if not all_switches: if not all_switches:
return embeds.error("No switches registered to this system.") return CommandError("No switches registered to this system.")
# Cull the switches *ending* before the limit, if given # Cull the switches *ending* before the limit, if given
# We'll need to find the first switch starting before the limit, then cut off every switch *before* that # We'll need to find the first switch starting before the limit, then cut off every switch *before* that
@ -264,11 +257,11 @@ async def system_frontpercent(ctx: CommandContext, args: List[str]):
# Calculate length of the switch # Calculate length of the switch
switch_length = end_time - start_time switch_length = end_time - start_time
def add_switch(member_id, length): def add_switch(id, length):
if member_id not in member_times: if id not in member_times:
member_times[member_id] = length member_times[id] = length
else: else:
member_times[member_id] += length member_times[id] += length
for member in members: for member in members:
# Add the switch length to the currently registered time for that member # Add the switch length to the currently registered time for that member
@ -297,4 +290,4 @@ async def system_frontpercent(ctx: CommandContext, args: List[str]):
value="{}% ({})".format(percent, humanize.naturaldelta(front_time))) value="{}% ({})".format(percent, humanize.naturaldelta(front_time)))
embed.set_footer(text="Since {}".format(span_start.isoformat(sep=" ", timespec="seconds"))) embed.set_footer(text="Since {}".format(span_start.isoformat(sep=" ", timespec="seconds")))
return embed await ctx.reply(embed=embed)

View File

@ -8,7 +8,7 @@ import aiohttp
import discord import discord
from pluralkit import db from pluralkit import db
from pluralkit.bot import channel_logger, utils from pluralkit.bot import channel_logger, utils, embeds
from pluralkit.stats import StatCollector from pluralkit.stats import StatCollector
logger = logging.getLogger("pluralkit.bot.proxy") logger = logging.getLogger("pluralkit.bot.proxy")
@ -253,10 +253,10 @@ class Proxy:
async with conn.transaction(): async with conn.transaction():
await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url) await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url)
except WebhookPermissionError: except WebhookPermissionError:
embed = utils.make_error_embed("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.") embed = embeds.error("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed) await self.client.send_message(message.channel, embed=embed)
except DeletionPermissionError: except DeletionPermissionError:
embed = utils.make_error_embed("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.") embed = embeds.error("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed) await self.client.send_message(message.channel, embed=embed)
return True return True

View File

@ -81,19 +81,6 @@ async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) ->
if member is not None: if member is not None:
return member return member
def make_default_embed(message):
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.description = message
return embed
def make_error_embed(message):
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.description = message
return embed
async def generate_system_info_card(conn, client: discord.Client, system: System) -> discord.Embed: async def generate_system_info_card(conn, client: discord.Client, system: System) -> discord.Embed:
card = discord.Embed() card = discord.Embed()
card.colour = discord.Colour.blue() card.colour = discord.Colour.blue()

View File

@ -323,7 +323,7 @@ async def create_tables(conn):
description text, description text,
tag text, tag text,
avatar_url text, avatar_url text,
created timestamp not null default current_timestamp created timestamp not null default (current_timestamp at time zone 'utc')
)""") )""")
await conn.execute("""create table if not exists members ( await conn.execute("""create table if not exists members (
id serial primary key, id serial primary key,
@ -337,7 +337,7 @@ async def create_tables(conn):
description text, description text,
prefix text, prefix text,
suffix text, suffix text,
created timestamp not null default current_timestamp created timestamp not null default (current_timestamp at time zone 'utc')
)""") )""")
await conn.execute("""create table if not exists accounts ( await conn.execute("""create table if not exists accounts (
uid bigint primary key, uid bigint primary key,
@ -353,7 +353,7 @@ async def create_tables(conn):
await conn.execute("""create table if not exists switches ( await conn.execute("""create table if not exists switches (
id serial primary key, id serial primary key,
system serial not null references systems(id) on delete cascade, system serial not null references systems(id) on delete cascade,
timestamp timestamp not null default current_timestamp timestamp timestamp not null default (current_timestamp at time zone 'utc')
)""") )""")
await conn.execute("""create table if not exists switch_members ( await conn.execute("""create table if not exists switch_members (
id serial primary key, id serial primary key,

View File

@ -1,9 +1,14 @@
from datetime import datetime from datetime import datetime, timezone
from typing import List, Tuple from typing import List, Tuple
from pluralkit import db, Member from pluralkit import db, Member
def fix_time(time: datetime):
# Assume we're receiving a naive datetime set to UTC, returns naive time zone set to local
return time.replace(tzinfo=timezone.utc).astimezone().replace(tzinfo=None)
async def get_fronter_ids(conn, system_id) -> (List[int], datetime): async def get_fronter_ids(conn, system_id) -> (List[int], datetime):
switches = await db.front_history(conn, system_id=system_id, count=1) switches = await db.front_history(conn, system_id=system_id, count=1)
if not switches: if not switches: