Major command handling refactor

This commit is contained in:
Ske 2018-09-07 17:34:38 +02:00
parent 0869f94cdf
commit f067485e88
15 changed files with 463 additions and 355 deletions

2
.gitignore vendored
View File

@ -2,3 +2,5 @@
.vscode/
.idea/
venv/
*.pyc

View File

@ -75,7 +75,11 @@ class PluralKitBot:
pass
async def handle_command_dispatch(self, message):
command_items = commands.command_list.items()
async with self.pool.acquire() as conn:
result = await commands.command_dispatch(self.client, message, conn)
return result
"""command_items = commands.command_list.items()
command_items = sorted(command_items, key=lambda x: len(x[0]), reverse=True)
prefix = "pk;"
@ -98,7 +102,7 @@ class PluralKitBot:
response_time = (datetime.now() - message.timestamp).total_seconds()
await self.stats.report_command(command_name, execution_time, response_time)
return True
return True"""
async def handle_proxy_dispatch(self, message):
# Try doing proxy parsing

View File

@ -1,84 +1,111 @@
import logging
from collections import namedtuple
import asyncpg
import discord
import logging
import re
from typing import Tuple, Optional
import pluralkit
from pluralkit import db
from pluralkit.bot import utils, embeds
from pluralkit import db, System, Member
from pluralkit.bot import embeds, utils
logger = logging.getLogger("pluralkit.bot.commands")
command_list = {}
class NoSystemRegistered(Exception):
def next_arg(arg_string: str) -> Tuple[str, Optional[str]]:
if arg_string.startswith("\""):
end_quote = arg_string.find("\"", start=1)
if end_quote > 0:
return arg_string[1:end_quote], arg_string[end_quote + 1:].strip()
else:
return arg_string[1:], None
next_space = arg_string.find(" ")
if next_space >= 0:
return arg_string[:next_space].strip(), arg_string[next_space:].strip()
else:
return arg_string.strip(), None
class CommandResponse:
def to_embed(self):
pass
class CommandContext(namedtuple("CommandContext", ["client", "conn", "message", "system"])):
client: discord.Client
conn: asyncpg.Connection
message: discord.Message
system: pluralkit.System
async def reply(self, message=None, embed=None):
return await self.client.send_message(self.message.channel, message, embed=embed)
class CommandSuccess(CommandResponse):
def __init__(self, text):
self.text = text
class MemberCommandContext(namedtuple("MemberCommandContext", CommandContext._fields + ("member",)), CommandContext):
client: discord.Client
conn: asyncpg.Connection
message: discord.Message
system: pluralkit.System
member: pluralkit.Member
def to_embed(self):
return embeds.success("\u2705 " + self.text)
class CommandEntry(namedtuple("CommandEntry", ["command", "function", "usage", "description", "category"])):
pass
def command(cmd, usage=None, description=None, category=None, system_required=True):
def wrap(func):
async def wrapper(client, conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
class CommandError(Exception, CommandResponse):
def __init__(self, embed: str, help: Tuple[str, str] = None):
self.text = embed
self.help = help
if system_required and system is None:
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account. Use `pk;system new` to register one."))
return
def to_embed(self):
return embeds.error("\u274c " + self.text, self.help)
ctx = CommandContext(client=client, conn=conn, message=message, system=system)
try:
res = await func(ctx, args)
if res:
embed = res if isinstance(res, discord.Embed) else utils.make_default_embed(res)
await client.send_message(message.channel, embed=embed)
except NoSystemRegistered:
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account. Use `pk;system new` to register one."))
except Exception:
logger.exception("Exception while handling command {} (args={}, system={})".format(cmd, args, system.hid if system else "(none)"))
class CommandContext:
def __init__(self, client: discord.Client, message: discord.Message, conn, args: str):
self.client = client
self.message = message
self.conn = conn
self.args = args
# Put command in map
command_list[cmd] = CommandEntry(command=cmd, function=wrapper, usage=usage, description=description, category=category)
return wrapper
return wrap
async def get_system(self) -> Optional[System]:
return await db.get_system_by_account(self.conn, self.message.author.id)
def member_command(cmd, usage=None, description=None, category=None, system_only=True):
def wrap(func):
async def wrapper(ctx: CommandContext, args):
# Return if no member param
if len(args) == 0:
return embeds.error("You must pass a member name or ID.")
async def ensure_system(self) -> System:
system = await self.get_system()
# System is allowed to be none if not system_only
system_id = ctx.system.id if ctx.system else None
# And find member by key
member = await utils.get_member_fuzzy(ctx.conn, system_id=system_id, key=args[0], system_only=system_only)
if not system:
raise CommandError(
embeds.error("No system registered to this account. Use `pk;system new` to register one."))
if member is None:
return embeds.error("Can't find member \"{}\".".format(args[0]))
return system
def has_next(self) -> bool:
return bool(self.args)
def pop_str(self, error: CommandError = None) -> str:
if not self.args:
if error:
raise error
return None
popped, self.args = next_arg(self.args)
return popped
async def pop_system(self, error: CommandError = None) -> System:
name = self.pop_str(error)
system = await utils.get_system_fuzzy(self.conn, self.client, name)
if not system:
raise CommandError("Unable to find system '{}'.".format(name))
return system
async def pop_member(self, error: CommandError = None, system_only: bool = True) -> Member:
name = self.pop_str(error)
if system_only:
system = await self.ensure_system()
else:
system = await self.get_system()
member = await utils.get_member_fuzzy(self.conn, system.id if system else None, name, system_only)
if not member:
raise CommandError("Unable to find member '{}'{}.".format(name, " in your system" if system_only else ""))
return member
def remaining(self):
return self.args
async def reply(self, content=None, embed=None):
return await self.client.send_message(self.message.channel, content=content, embed=embed)
ctx = MemberCommandContext(client=ctx.client, conn=ctx.conn, message=ctx.message, system=ctx.system, member=member)
return await func(ctx, args[1:])
return command(cmd=cmd, usage="<name|id> {}".format(usage or ""), description=description, category=category, system_required=False)(wrapper)
return wrap
import pluralkit.bot.commands.import_commands
import pluralkit.bot.commands.member_commands
@ -87,3 +114,69 @@ import pluralkit.bot.commands.misc_commands
import pluralkit.bot.commands.mod_commands
import pluralkit.bot.commands.switch_commands
import pluralkit.bot.commands.system_commands
async def run_command(ctx: CommandContext, func):
try:
result = await func(ctx)
if isinstance(result, CommandResponse):
await ctx.reply(embed=result.to_embed())
except CommandError as e:
await ctx.reply(embed=e.to_embed())
except Exception:
logger.exception("Exception while dispatching command")
async def command_dispatch(client: discord.Client, message: discord.Message, conn) -> bool:
prefix = "^pk(;|!)"
commands = [
(r"system (new|register|create|init)", system_commands.new_system),
(r"system set", system_commands.system_set),
(r"system link", system_commands.system_link),
(r"system unlink", system_commands.system_unlink),
(r"system fronter", system_commands.system_fronter),
(r"system fronthistory", system_commands.system_fronthistory),
(r"system (delete|remove|destroy|erase)", system_commands.system_delete),
(r"system frontpercent(age)?", system_commands.system_frontpercent),
(r"system", system_commands.system_info),
(r"import tupperware", import_commands.import_tupperware),
(r"member (new|create|add|register)", member_commands.new_member),
(r"member set", member_commands.member_set),
(r"member proxy", member_commands.member_proxy),
(r"member (delete|remove|destroy|erase)", member_commands.member_delete),
(r"member", member_commands.member_info),
(r"message", message_commands.message_info),
(r"mod log", mod_commands.set_log),
(r"invite", misc_commands.invite_link),
(r"export", misc_commands.export),
(r"help", misc_commands.show_help),
(r"switch move", switch_commands.switch_move),
(r"switch out", switch_commands.switch_out),
(r"switch", switch_commands.switch_member)
]
for pattern, func in commands:
regex = re.compile(prefix + pattern, re.IGNORECASE)
cmd = message.content
match = regex.match(cmd)
if match:
remaining_string = cmd[match.span()[1]:].strip()
ctx = CommandContext(
client=client,
message=message,
conn=conn,
args=remaining_string
)
await run_command(ctx, func)
return True
return False

View File

@ -1,20 +1,19 @@
import asyncio
import re
from datetime import datetime
from typing import List
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="import tupperware", description="Import data from Tupperware.", system_required=False)
async def import_tupperware(ctx: CommandContext, args: List[str]):
async def import_tupperware(ctx: CommandContext):
tupperware_ids = ["431544605209788416", "433916057053560832"] # Main bot instance and Multi-Pals-specific fork
tupperware_members = [ctx.message.server.get_member(bot_id) for bot_id in tupperware_ids if ctx.message.server.get_member(bot_id)]
tupperware_members = [ctx.message.server.get_member(bot_id) for bot_id in tupperware_ids if
ctx.message.server.get_member(bot_id)]
# Check if there's any Tupperware bot on the server
if not tupperware_members:
return embeds.error("This command only works in a server where the Tupperware bot is also present.")
return CommandError("This command only works in a server where the Tupperware bot is also present.")
# Make sure at least one of the bts have send/read permissions here
for bot_member in tupperware_members:
@ -24,9 +23,10 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
break
else:
# If no bots have permission (ie. loop doesn't break), throw error
return embeds.error("This command only works in a channel where the Tupperware bot has read/send access.")
return CommandError("This command only works in a channel where the Tupperware bot has read/send access.")
await ctx.reply(embed=utils.make_default_embed("Please reply to this message with `tul!list` (or the server equivalent)."))
await ctx.reply(
embed=embeds.status("Please reply to this message with `tul!list` (or the server equivalent)."))
# Check to make sure the message is sent by Tupperware, and that the Tupperware response actually belongs to the correct user
def ensure_account(tw_msg):
@ -39,13 +39,15 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
if not tw_msg.embeds[0]["title"]:
return False
return tw_msg.embeds[0]["title"].startswith("{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator))
return tw_msg.embeds[0]["title"].startswith(
"{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator))
tupperware_page_embeds = []
tw_msg: discord.Message = await ctx.client.wait_for_message(channel=ctx.message.channel, timeout=60.0, check=ensure_account)
tw_msg: discord.Message = await ctx.client.wait_for_message(channel=ctx.message.channel, timeout=60.0,
check=ensure_account)
if not tw_msg:
return embeds.error("Tupperware import timed out.")
return CommandError("Tupperware import timed out.")
tupperware_page_embeds.append(tw_msg.embeds[0])
# Handle Tupperware pagination
@ -74,7 +76,9 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
# If this isn't the same page as last check, edit the status message
if new_page != current_page:
last_found_time = datetime.utcnow()
await ctx.client.edit_message(status_msg, "Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(len(pages_found), total_pages))
await ctx.client.edit_message(status_msg,
"Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(
len(pages_found), total_pages))
current_page = new_page
# And sleep a bit to prevent spamming the CPU
@ -82,7 +86,7 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
# Make sure it doesn't spin here for too long, time out after 30 seconds since last new page
if (datetime.utcnow() - last_found_time).seconds > 30:
return embeds.error("Pagination scan timed out.")
return CommandError("Pagination scan timed out.")
# Now that we've got all the pages, put them in the embeds list
# Make sure to erase the original one we put in above too
@ -94,7 +98,7 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
logger.debug("Importing from Tupperware...")
# Create new (nameless) system if there isn't any registered
system = ctx.system
system = ctx.get_system()
if system is None:
hid = utils.generate_hid()
logger.debug("Creating new system (hid={})...".format(hid))
@ -117,7 +121,7 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
if line.startswith("Brackets:"):
brackets = line[len("Brackets: "):]
member_prefix = brackets[:brackets.index("text")].strip() or None
member_suffix = brackets[brackets.index("text")+4:].strip() or None
member_suffix = brackets[brackets.index("text") + 4:].strip() or None
elif line.startswith("Avatar URL: "):
url = line[len("Avatar URL: "):]
member_avatar = url
@ -138,14 +142,19 @@ async def import_tupperware(ctx: CommandContext, args: List[str]):
# Or create a new member
hid = utils.generate_hid()
logger.debug("Creating new member {} (hid={})...".format(name, hid))
existing_member = await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid)
existing_member = await db.create_member(ctx.conn, system_id=system.id, member_name=name,
member_hid=hid)
# Save the new stuff in the DB
logger.debug("Updating fields...")
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="prefix", value=member_prefix)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="suffix", value=member_suffix)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="avatar_url", value=member_avatar)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="birthday", value=member_birthdate)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="description", value=member_description)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="avatar_url",
value=member_avatar)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="birthday",
value=member_birthdate)
await db.update_member_field(ctx.conn, member_id=existing_member.id, field="description",
value=member_description)
return embeds.success("System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting.")
return CommandSuccess(
"System information imported. Try using `pk;system` now.\nYou should probably remove your members from Tupperware to avoid double-posting.")

View File

@ -8,32 +8,36 @@ from pluralkit.bot import help
logger = logging.getLogger("pluralkit.commands")
@member_command(cmd="member", description="Shows information about a system member.", system_only=False, category="Member commands")
async def member_info(ctx: MemberCommandContext, args: List[str]):
await ctx.reply(embed=await utils.generate_member_info_card(ctx.conn, ctx.member))
@command(cmd="member new", usage="<name>", description="Adds a new member to your system.", category="Member commands")
async def new_member(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
return embeds.error("You must pass a member name or ID.", help=help.add_member)
async def member_info(ctx: CommandContext):
member = await ctx.pop_member(
error=CommandError("You must pass a member name or ID.", help=help.lookup_member), system_only=False)
await ctx.reply(embed=await utils.generate_member_info_card(ctx.conn, member))
name = " ".join(args)
bounds_error = utils.bounds_check_member_name(name, ctx.system.tag)
async def new_member(ctx: CommandContext):
system = await ctx.ensure_system()
if not ctx.has_next():
return CommandError("You must pass a name for the new member.", help=help.add_member)
name = ctx.remaining()
bounds_error = utils.bounds_check_member_name(name, system.tag)
if bounds_error:
return embeds.error(bounds_error)
return CommandError(bounds_error)
# TODO: figure out what to do if this errors out on collision on generate_hid
hid = utils.generate_hid()
# Insert member row
await db.create_member(ctx.conn, system_id=ctx.system.id, member_name=name, member_hid=hid)
return embeds.success("Member \"{}\" (`{}`) registered!".format(name, hid))
await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid)
return CommandSuccess(
"Member \"{}\" (`{}`) registered! To register their proxy tags, use `pk;member proxy`.".format(name, hid))
@member_command(cmd="member set", usage="<name|description|color|pronouns|birthdate|avatar> [value]", description="Edits a member property. Leave [value] blank to clear.", category="Member commands")
async def member_set(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
return embeds.error("You must pass a property name to set.", help=help.edit_member)
async def member_set(ctx: CommandContext):
system = await ctx.ensure_system()
member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member))
prop = ctx.pop_str(CommandError("You must pass a property name to set.", help=help.edit_member))
allowed_properties = ["name", "description", "color", "pronouns", "birthdate", "avatar"]
db_properties = {
@ -45,23 +49,24 @@ async def member_set(ctx: MemberCommandContext, args: List[str]):
"avatar": "avatar_url"
}
prop = args[0]
if prop not in allowed_properties:
return embeds.error("Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)), help=help.edit_member)
return CommandError(
"Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)),
help=help.edit_member)
if len(args) >= 2:
value = " ".join(args[1:])
if ctx.has_next():
value = " ".join(ctx.remaining())
# Sanity/validity checks and type conversions
if prop == "name":
bounds_error = utils.bounds_check_member_name(value, ctx.system.tag)
bounds_error = utils.bounds_check_member_name(value, system.tag)
if bounds_error:
return embeds.error(bounds_error)
return CommandError(bounds_error)
if prop == "color":
match = re.fullmatch("#?([0-9A-Fa-f]{6})", value)
if not match:
return embeds.error("Color must be a valid hex color (eg. #ff0000)")
return CommandError("Color must be a valid hex color (eg. #ff0000)")
value = match.group(1).lower()
@ -75,7 +80,7 @@ async def member_set(ctx: MemberCommandContext, args: List[str]):
# Useful if you want your birthday to be displayed yearless.
value = datetime.strptime("0001-" + value, "%Y-%m-%d").date()
except ValueError:
return embeds.error("Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25).")
return CommandError("Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25).")
if prop == "avatar":
user = await utils.parse_mention(ctx.client, value)
@ -89,41 +94,45 @@ async def member_set(ctx: MemberCommandContext, args: List[str]):
if u.scheme in ["http", "https"] and u.netloc and u.path:
value = value
else:
return embeds.error("Invalid image URL.")
return CommandError("Invalid image URL.")
else:
# Can't clear member name
if prop == "name":
return embeds.error("You can't clear the member name.")
return CommandError("You can't clear the member name.")
# Clear from DB
value = None
db_prop = db_properties[prop]
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field=db_prop, value=value)
await db.update_member_field(ctx.conn, member_id=member.id, field=db_prop, value=value)
response = embeds.success("{} {}'s {}.".format("Updated" if value else "Cleared", ctx.member.name, prop))
response = CommandSuccess("{} {}'s {}.".format("Updated" if value else "Cleared", member.name, prop))
if prop == "avatar" and value:
response.set_image(url=value)
if prop == "color" and value:
response.colour = int(value, 16)
return response
@member_command(cmd="member proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).", category="Member commands")
async def member_proxy(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
async def member_proxy(ctx: CommandContext):
await ctx.ensure_system()
member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.member_proxy))
if not ctx.has_next():
prefix, suffix = None, None
else:
# Sanity checking
example = " ".join(args)
example = ctx.remaining()
if "text" not in example:
return embeds.error("Example proxy message must contain the string 'text'.", help=help.member_proxy)
return CommandError("Example proxy message must contain the string 'text'.", help=help.member_proxy)
if example.count("text") != 1:
return embeds.error("Example proxy message must contain the string 'text' exactly once.", help=help.member_proxy)
return CommandError("Example proxy message must contain the string 'text' exactly once.",
help=help.member_proxy)
# Extract prefix and suffix
prefix = example[:example.index("text")].strip()
suffix = example[example.index("text")+4:].strip()
suffix = example[example.index("text") + 4:].strip()
logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix))
# DB stores empty strings as None, make that work
@ -133,17 +142,22 @@ async def member_proxy(ctx: MemberCommandContext, args: List[str]):
suffix = None
async with ctx.conn.transaction():
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="prefix", value=prefix)
await db.update_member_field(ctx.conn, member_id=ctx.member.id, field="suffix", value=suffix)
return embeds.success("Proxy settings updated." if prefix or suffix else "Proxy settings cleared.")
await db.update_member_field(ctx.conn, member_id=member.id, field="prefix", value=prefix)
await db.update_member_field(ctx.conn, member_id=member.id, field="suffix", value=suffix)
return CommandSuccess("Proxy settings updated." if prefix or suffix else "Proxy settings cleared.")
@member_command("member delete", description="Deletes a member from your system ***permanently***.", category="Member commands")
async def member_delete(ctx: MemberCommandContext, args: List[str]):
await ctx.reply("Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(ctx.member.name, ctx.member.hid))
async def member_delete(ctx: CommandContext):
await ctx.ensure_system()
member = await ctx.pop_member(CommandError("You must pass a member name.", help=help.edit_member))
await ctx.reply(
"Are you sure you want to delete {}? If so, reply to this message with the member's ID (`{}`).".format(
member.name, member.hid))
msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0)
if msg and msg.content.lower() == ctx.member.hid.lower():
await db.delete_member(ctx.conn, member_id=ctx.member.id)
return embeds.success("Member deleted.")
if msg and msg.content.lower() == member.hid.lower():
await db.delete_member(ctx.conn, member_id=member.id)
return CommandSuccess("Member deleted.")
else:
return embeds.error("Member deletion cancelled.")
return CommandError("Member deletion cancelled.")

View File

@ -1,27 +1,21 @@
import logging
from typing import List
from pluralkit.bot import utils, embeds, help
from pluralkit.bot import help
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="message", usage="<id>", description="Shows information about a proxied message. Requires the message ID.",
category="Message commands", system_required=False)
async def message_info(ctx: CommandContext, args: List[str]):
if len(args) == 0:
return embeds.error("You must pass a message ID.", help=help.message_lookup)
async def message_info(ctx: CommandContext):
mid_str = ctx.pop_str(CommandError("You must pass a message ID.", help=help.message_lookup))
try:
mid = int(args[0])
mid = int(mid_str)
except ValueError:
return embeds.error("You must pass a valid number as a message ID.", help=help.message_lookup)
return CommandError("You must pass a valid number as a message ID.", help=help.message_lookup)
# Find the message in the DB
message = await db.get_message(ctx.conn, str(mid))
if not message:
raise embeds.error("Message with ID '{}' not found.".format(args[0]))
raise CommandError("Message with ID '{}' not found.".format(mid))
# Get the original sender of the messages
try:
@ -54,4 +48,4 @@ async def message_info(ctx: CommandContext, args: List[str]):
embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty)
return embed
await ctx.reply(embed=embed)

View File

@ -12,13 +12,13 @@ from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="help", usage="[system|member|proxy|switch|mod]", description="Shows help messages.", system_required=False)
async def show_help(ctx: CommandContext, args: List[str]):
embed = utils.make_default_embed("")
embed.title = "PluralKit Help"
embed.set_footer(text="By Astrid (Ske#6201, or 'qoxvy' on PK) | GitHub: https://github.com/xSke/PluralKit/")
category = args[0] if len(args) > 0 else None
async def show_help(ctx: CommandContext):
embed = embeds.status("")
embed.title = "PluralKit Help"
embed.set_footer(text="By Astrid (Ske#6201; pk;member qoxvy) | GitHub: https://github.com/xSke/PluralKit/")
category = ctx.pop_str() if ctx.has_next() else None
from pluralkit.bot.help import help_pages
if category in help_pages:
@ -28,12 +28,12 @@ async def show_help(ctx: CommandContext, args: List[str]):
else:
embed.description = text
else:
return embeds.error("Unknown help page '{}'.".format(category))
return CommandError("Unknown help page '{}'.".format(category))
return embed
await ctx.reply(embed=embed)
@command(cmd="invite", description="Generates an invite link for this bot.", system_required=False)
async def invite_link(ctx: CommandContext, args: List[str]):
async def invite_link(ctx: CommandContext):
client_id = os.environ["CLIENT_ID"]
permissions = discord.Permissions()
@ -47,15 +47,16 @@ async def invite_link(ctx: CommandContext, args: List[str]):
url = oauth_url(client_id, permissions)
logger.debug("Sending invite URL: {}".format(url))
return embeds.success("Use this link to add PluralKit to your server: {}".format(url))
return CommandSuccess("Use this link to add PluralKit to your server: {}".format(url))
@command(cmd="export", description="Exports system data to a machine-readable format.")
async def export(ctx: CommandContext, args: List[str]):
members = await db.get_all_members(ctx.conn, ctx.system.id)
accounts = await db.get_linked_accounts(ctx.conn, ctx.system.id)
switches = await pluralkit.utils.get_front_history(ctx.conn, ctx.system.id, 999999)
system = ctx.system
async def export(ctx: CommandContext):
system = await ctx.ensure_system()
members = await db.get_all_members(ctx.conn, system.id)
accounts = await db.get_linked_accounts(ctx.conn, system.id)
switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, 999999)
data = {
"name": system.name,
"id": system.hid,

View File

@ -1,24 +1,20 @@
import logging
from typing import List
from pluralkit.bot import utils, embeds
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="mod log", usage="[channel]", description="Sets the bot to log events to a specified channel. Leave blank to disable.", category="Moderation commands", system_required=False)
async def set_log(ctx: CommandContext, args: List[str]):
async def set_log(ctx: CommandContext):
if not ctx.message.author.server_permissions.administrator:
return embeds.error("You must be a server administrator to use this command.")
return CommandError("You must be a server administrator to use this command.")
server = ctx.message.server
if len(args) == 0:
if not ctx.has_next():
channel_id = None
else:
channel = utils.parse_channel_mention(args[0], server=server)
channel = utils.parse_channel_mention(ctx.pop_str(), server=server)
if not channel:
return embeds.error("Channel not found.")
return CommandError("Channel not found.")
channel_id = channel.id
await db.update_server(ctx.conn, server.id, logging_channel_id=channel_id)
return embeds.success("Updated logging channel." if channel_id else "Cleared logging channel.")
return CommandSuccess("Updated logging channel." if channel_id else "Cleared logging channel.")

View File

@ -1,120 +1,130 @@
from datetime import datetime
import logging
from typing import List
import dateparser
import humanize
from datetime import datetime, timezone
from typing import List
import pluralkit.utils
from pluralkit import Member
from pluralkit.bot import utils, embeds, help
from pluralkit.bot import help
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="switch", usage="<name|id> [name|id]...", description="Registers a switch and changes the current fronter.", category="Switching commands")
async def switch_member(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
return embeds.error("You must pass at least one member name or ID to register a switch to.", help=help.switch_register)
async def switch_member(ctx: CommandContext):
system = await ctx.ensure_system()
if not ctx.has_next():
return CommandError("You must pass at least one member name or ID to register a switch to.",
help=help.switch_register)
members: List[Member] = []
for member_name in args:
for member_name in ctx.remaining().split(" "):
# Find the member
member = await utils.get_member_fuzzy(ctx.conn, ctx.system.id, member_name)
member = await utils.get_member_fuzzy(ctx.conn, system.id, member_name)
if not member:
return embeds.error("Couldn't find member \"{}\".".format(member_name))
return CommandError("Couldn't find member \"{}\".".format(member_name))
members.append(member)
# Compare requested switch IDs and existing fronter IDs to check for existing switches
# Lists, because order matters, it makes sense to just swap fronters
member_ids = [member.id for member in members]
fronter_ids = (await pluralkit.utils.get_fronter_ids(ctx.conn, ctx.system.id))[0]
fronter_ids = (await pluralkit.utils.get_fronter_ids(ctx.conn, system.id))[0]
if member_ids == fronter_ids:
if len(members) == 1:
return embeds.error("{} is already fronting.".format(members[0].name))
return embeds.error("Members {} are already fronting.".format(", ".join([m.name for m in members])))
return CommandError("{} is already fronting.".format(members[0].name))
return CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members])))
# Also make sure there aren't any duplicates
if len(set(member_ids)) != len(member_ids):
return embeds.error("Duplicate members in member list.")
return CommandError("Duplicate members in member list.")
# Log the switch
async with ctx.conn.transaction():
switch_id = await db.add_switch(ctx.conn, system_id=ctx.system.id)
switch_id = await db.add_switch(ctx.conn, system_id=system.id)
for member in members:
await db.add_switch_member(ctx.conn, switch_id=switch_id, member_id=member.id)
if len(members) == 1:
return embeds.success("Switch registered. Current fronter is now {}.".format(members[0].name))
return CommandSuccess("Switch registered. Current fronter is now {}.".format(members[0].name))
else:
return embeds.success("Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members])))
return CommandSuccess(
"Switch registered. Current fronters are now {}.".format(", ".join([m.name for m in members])))
async def switch_out(ctx: CommandContext):
system = await ctx.ensure_system()
@command(cmd="switch out", description="Registers a switch with no one in front.", category="Switching commands")
async def switch_out(ctx: MemberCommandContext, args: List[str]):
# Get current fronters
fronters, _ = await pluralkit.utils.get_fronter_ids(ctx.conn, system_id=ctx.system.id)
fronters, _ = await pluralkit.utils.get_fronter_ids(ctx.conn, system_id=system.id)
if not fronters:
return embeds.error("There's already no one in front.")
return CommandError("There's already no one in front.")
# Log it, and don't log any members
await db.add_switch(ctx.conn, system_id=ctx.system.id)
return embeds.success("Switch-out registered.")
await db.add_switch(ctx.conn, system_id=system.id)
return CommandSuccess("Switch-out registered.")
@command(cmd="switch move", usage="<time>", description="Moves the most recent switch to a different point in time.", category="Switching commands")
async def switch_move(ctx: MemberCommandContext, args: List[str]):
if len(args) == 0:
return embeds.error("You must pass a time to move the switch to.", help=help.switch_move)
async def switch_move(ctx: CommandContext):
system = await ctx.ensure_system()
if not ctx.has_next():
return CommandError("You must pass a time to move the switch to.", help=help.switch_move)
# Parse the time to move to
new_time = dateparser.parse(" ".join(args), languages=["en"], settings={
new_time = dateparser.parse(ctx.remaining(), languages=["en"], settings={
"TO_TIMEZONE": "UTC",
"RETURN_AS_TIMEZONE_AWARE": False
})
if not new_time:
return embeds.error("{} can't be parsed as a valid time.".format(" ".join(args)))
return CommandError("'{}' can't be parsed as a valid time.".format(ctx.remaining()), help=help.switch_move)
# Make sure the time isn't in the future
if new_time > datetime.now():
return embeds.error("Can't move switch to a time in the future.")
if new_time > datetime.utcnow():
return CommandError("Can't move switch to a time in the future.", help=help.switch_move)
# Make sure it all runs in a big transaction for atomicity
async with ctx.conn.transaction():
# Get the last two switches to make sure the switch to move isn't before the second-last switch
last_two_switches = await pluralkit.utils.get_front_history(ctx.conn, ctx.system.id, count=2)
last_two_switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, count=2)
if len(last_two_switches) == 0:
return embeds.error("There are no registered switches for this system.")
return CommandError("There are no registered switches for this system.")
last_timestamp, last_fronters = last_two_switches[0]
if len(last_two_switches) > 1:
second_last_timestamp, _ = last_two_switches[1]
if new_time < second_last_timestamp:
time_str = humanize.naturaltime(second_last_timestamp)
return embeds.error("Can't move switch to before last switch time ({}), as it would cause conflicts.".format(time_str))
time_str = humanize.naturaltime(pluralkit.utils.fix_time(second_last_timestamp))
return CommandError(
"Can't move switch to before last switch time ({}), as it would cause conflicts.".format(time_str))
# Display the confirmation message w/ humanized times
members = ", ".join([member.name for member in last_fronters]) or "nobody"
last_absolute = last_timestamp.isoformat(sep=" ", timespec="seconds")
last_relative = humanize.naturaltime(last_timestamp)
last_relative = humanize.naturaltime(pluralkit.utils.fix_time(last_timestamp))
new_absolute = new_time.isoformat(sep=" ", timespec="seconds")
new_relative = humanize.naturaltime(new_time)
embed = utils.make_default_embed("This will move the latest switch ({}) from {} ({}) to {} ({}). Is this OK?".format(members, last_absolute, last_relative, new_absolute, new_relative))
new_relative = humanize.naturaltime(pluralkit.utils.fix_time(new_time))
embed = embeds.status(
"This will move the latest switch ({}) from {} ({}) to {} ({}). Is this OK?".format(members, last_absolute,
last_relative,
new_absolute,
new_relative))
# Await and handle confirmation reactions
confirm_msg = await ctx.reply(embed=embed)
await ctx.client.add_reaction(confirm_msg, "")
await ctx.client.add_reaction(confirm_msg, "")
reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=confirm_msg, user=ctx.message.author, timeout=60.0)
reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=confirm_msg, user=ctx.message.author,
timeout=60.0)
if not reaction:
return embeds.error("Switch move timed out.")
return CommandError("Switch move timed out.")
if reaction.reaction.emoji == "":
return embeds.error("Switch move cancelled.")
return CommandError("Switch move cancelled.")
# DB requires the actual switch ID which our utility method above doesn't return, do this manually
switch_id = (await db.front_history(ctx.conn, ctx.system.id, count=1))[0]["id"]
switch_id = (await db.front_history(ctx.conn, system.id, count=1))[0]["id"]
# Change the switch in the DB
await db.move_last_switch(ctx.conn, ctx.system.id, switch_id, new_time)
return embeds.success("Switch moved.")
await db.move_last_switch(ctx.conn, system.id, switch_id, new_time)
return CommandSuccess("Switch moved.")

View File

@ -1,39 +1,31 @@
from datetime import datetime
from typing import List
from urllib.parse import urlparse
import dateparser
import humanize
from datetime import datetime
from urllib.parse import urlparse
import pluralkit.utils
from pluralkit.bot import embeds, help
from pluralkit.bot import help
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands")
@command(cmd="system", usage="[system]", description="Shows information about a system.", category="System commands", system_required=False)
async def system_info(ctx: CommandContext, args: List[str]):
if len(args) == 0:
if not ctx.system:
raise NoSystemRegistered()
system = ctx.system
else:
# Look one up
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
if system is None:
return embeds.error("Unable to find system \"{}\".".format(args[0]))
async def system_info(ctx: CommandContext):
if ctx.has_next():
system = await ctx.pop_system()
else:
system = await ctx.ensure_system()
await ctx.reply(embed=await utils.generate_system_info_card(ctx.conn, ctx.client, system))
@command(cmd="system new", usage="[name]", description="Registers a new system to this account.", category="System commands", system_required=False)
async def new_system(ctx: CommandContext, args: List[str]):
if ctx.system:
return embeds.error("You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
system_name = None
if len(args) > 0:
system_name = " ".join(args)
async def new_system(ctx: CommandContext):
system = await ctx.get_system()
if system:
return CommandError(
"You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
system_name = ctx.remaining() or None
async with ctx.conn.transaction():
# TODO: figure out what to do if this errors out on collision on generate_hid
@ -43,12 +35,13 @@ async def new_system(ctx: CommandContext, args: List[str]):
# Link account
await db.link_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id)
return embeds.success("System registered! To begin adding members, use `pk;member new <name>`.")
return CommandSuccess("System registered! To begin adding members, use `pk;member new <name>`.")
@command(cmd="system set", usage="<name|description|tag|avatar> [value]", description="Edits a system property. Leave [value] blank to clear.", category="System commands")
async def system_set(ctx: CommandContext, args: List[str]):
if len(args) == 0:
return embeds.error("You must pass a property name to set.", help=help.edit_system)
async def system_set(ctx: CommandContext):
system = await ctx.ensure_system()
prop = ctx.pop_str(CommandError("You must pass a property name to set.", help=help.edit_system))
allowed_properties = ["name", "description", "tag", "avatar"]
db_properties = {
@ -58,25 +51,29 @@ async def system_set(ctx: CommandContext, args: List[str]):
"avatar": "avatar_url"
}
prop = args[0]
if prop not in allowed_properties:
raise embeds.error("Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)), help=help.edit_system)
return CommandError(
"Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)),
help=help.edit_system)
if len(args) >= 2:
value = " ".join(args[1:])
if ctx.has_next():
value = ctx.remaining()
# Sanity checking
if prop == "tag":
if len(value) > 32:
raise embeds.error("You can't have a system tag longer than 32 characters.")
return CommandError("You can't have a system tag longer than 32 characters.")
# Make sure there are no members which would make the combined length exceed 32
members_exceeding = await db.get_members_exceeding(ctx.conn, system_id=ctx.system.id, length=32 - len(value) - 1)
members_exceeding = await db.get_members_exceeding(ctx.conn, system_id=system.id,
length=32 - len(value) - 1)
if len(members_exceeding) > 0:
# If so, error out and warn
member_names = ", ".join([member.name
for member in members_exceeding])
logger.debug("Members exceeding combined length with tag '{}': {}".format(value, member_names))
raise embeds.error("The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names))
return CommandError(
"The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(
member_names))
if prop == "avatar":
user = await utils.parse_mention(ctx.client, value)
@ -90,75 +87,73 @@ async def system_set(ctx: CommandContext, args: List[str]):
if u.scheme in ["http", "https"] and u.netloc and u.path:
value = value
else:
raise embeds.error("Invalid image URL.")
return CommandError("Invalid image URL.")
else:
# Clear from DB
value = None
db_prop = db_properties[prop]
await db.update_system_field(ctx.conn, system_id=ctx.system.id, field=db_prop, value=value)
await db.update_system_field(ctx.conn, system_id=system.id, field=db_prop, value=value)
response = embeds.success("{} system {}.".format("Updated" if value else "Cleared", prop))
response = CommandSuccess("{} system {}.".format("Updated" if value else "Cleared", prop))
if prop == "avatar" and value:
response.set_image(url=value)
return response
@command(cmd="system link", usage="<account>", description="Links another account to your system.", category="System commands")
async def system_link(ctx: CommandContext, args: List[str]):
if len(args) == 0:
return embeds.error("You must pass an account to link this system to.", help=help.link_account)
async def system_link(ctx: CommandContext):
system = await ctx.ensure_system()
account_name = ctx.pop_str(CommandError("You must pass an account to link this system to.", help=help.link_account))
# Find account to link
linkee = await utils.parse_mention(ctx.client, args[0])
linkee = await utils.parse_mention(ctx.client, account_name)
if not linkee:
return embeds.error("Account not found.")
return CommandError("Account not found.")
# Make sure account doesn't already have a system
account_system = await db.get_system_by_account(ctx.conn, linkee.id)
if account_system:
return embeds.error("The mentioned account is already linked to a system (`{}`)".format(account_system.hid))
return CommandError("The mentioned account is already linked to a system (`{}`)".format(account_system.hid))
# Send confirmation message
msg = await ctx.reply("{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention))
msg = await ctx.reply(
"{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention))
await ctx.client.add_reaction(msg, "")
await ctx.client.add_reaction(msg, "")
reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=msg, user=linkee, timeout=60.0)
# If account to be linked confirms...
if not reaction:
return embeds.error("Account link timed out.")
return CommandError("Account link timed out.")
if not reaction.reaction.emoji == "":
return embeds.error("Account link cancelled.")
return CommandError("Account link cancelled.")
await db.link_account(ctx.conn, system_id=ctx.system.id, account_id=linkee.id)
return embeds.success("Account linked to system.")
await db.link_account(ctx.conn, system_id=system.id, account_id=linkee.id)
return CommandSuccess("Account linked to system.")
async def system_unlink(ctx: CommandContext):
system = await ctx.ensure_system()
@command(cmd="system unlink", description="Unlinks your system from this account. There must be at least one other account linked.", category="System commands")
async def system_unlink(ctx: CommandContext, args: List[str]):
# Make sure you can't unlink every account
linked_accounts = await db.get_linked_accounts(ctx.conn, system_id=ctx.system.id)
linked_accounts = await db.get_linked_accounts(ctx.conn, system_id=system.id)
if len(linked_accounts) == 1:
return embeds.error("This is the only account on your system, so you can't unlink it.")
return CommandError("This is the only account on your system, so you can't unlink it.")
await db.unlink_account(ctx.conn, system_id=ctx.system.id, account_id=ctx.message.author.id)
return embeds.success("Account unlinked.")
await db.unlink_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id)
return CommandSuccess("Account unlinked.")
@command(cmd="system fronter", usage="[system]", description="Gets the current fronter(s) in the system.", category="Switching commands", system_required=False)
async def system_fronter(ctx: CommandContext, args: List[str]):
if len(args) == 0:
if not ctx.system:
raise NoSystemRegistered()
system = ctx.system
async def system_fronter(ctx: CommandContext):
if ctx.has_next():
system = await ctx.pop_system()
else:
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
if system is None:
return embeds.error("Can't find system \"{}\".".format(args[0]))
system = await ctx.ensure_system()
fronters, timestamp = await pluralkit.utils.get_fronters(ctx.conn, system_id=system.id)
fronter_names = [member.name for member in fronters]
embed = utils.make_default_embed(None)
embed = embeds.status("")
if len(fronter_names) == 0:
embed.add_field(name="Current fronter", value="(no fronter)")
@ -168,20 +163,16 @@ async def system_fronter(ctx: CommandContext, args: List[str]):
embed.add_field(name="Current fronters", value=", ".join(fronter_names))
if timestamp:
embed.add_field(name="Since", value="{} ({})".format(timestamp.isoformat(sep=" ", timespec="seconds"), humanize.naturaltime(timestamp)))
return embed
embed.add_field(name="Since", value="{} ({})".format(timestamp.isoformat(sep=" ", timespec="seconds"),
humanize.naturaltime(pluralkit.utils.fix_time(timestamp))))
await ctx.reply(embed=embed)
@command(cmd="system fronthistory", usage="[system]", description="Shows the past 10 switches in the system.", category="Switching commands", system_required=False)
async def system_fronthistory(ctx: CommandContext, args: List[str]):
if len(args) == 0:
if not ctx.system:
raise NoSystemRegistered()
system = ctx.system
async def system_fronthistory(ctx: CommandContext):
if ctx.has_next():
system = await ctx.pop_system()
else:
system = await utils.get_system_fuzzy(ctx.conn, ctx.client, args[0])
if system is None:
raise embeds.error("Can't find system \"{}\".".format(args[0]))
system = await ctx.ensure_system()
lines = []
front_history = await pluralkit.utils.get_front_history(ctx.conn, system.id, count=10)
@ -194,37 +185,39 @@ async def system_fronthistory(ctx: CommandContext, args: List[str]):
# Make proper date string
time_text = timestamp.isoformat(sep=" ", timespec="seconds")
rel_text = humanize.naturaltime(timestamp)
rel_text = humanize.naturaltime(pluralkit.utils.fix_time(timestamp))
delta_text = ""
if i > 0:
last_switch_time = front_history[i-1][0]
last_switch_time = front_history[i - 1][0]
delta_text = ", for {}".format(humanize.naturaldelta(timestamp - last_switch_time))
lines.append("**{}** ({}, {}{})".format(name, time_text, rel_text, delta_text))
embed = utils.make_default_embed("\n".join(lines) or "(none)")
embed = embeds.status("\n".join(lines) or "(none)")
embed.title = "Past switches"
return embed
await ctx.reply(embed=embed)
@command(cmd="system delete", description="Deletes your system from the database ***permanently***.", category="System commands")
async def system_delete(ctx: CommandContext, args: List[str]):
await ctx.reply("Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(ctx.system.hid))
async def system_delete(ctx: CommandContext):
system = await ctx.ensure_system()
await ctx.reply(
"Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(
system.hid))
msg = await ctx.client.wait_for_message(author=ctx.message.author, channel=ctx.message.channel, timeout=60.0)
if msg and msg.content.lower() == ctx.system.hid.lower():
await db.remove_system(ctx.conn, system_id=ctx.system.id)
return embeds.success("System deleted.")
if msg and msg.content.lower() == system.hid.lower():
await db.remove_system(ctx.conn, system_id=system.id)
return CommandSuccess("System deleted.")
else:
return embeds.error("System deletion cancelled.")
return CommandError("System deletion cancelled.")
@command(cmd="system frontpercent", usage="[time]",
description="Shows the fronting percentage of every member, averaged over the given time",
category="System commands")
async def system_frontpercent(ctx: CommandContext, args: List[str]):
async def system_frontpercent(ctx: CommandContext):
system = await ctx.ensure_system()
# Parse the time limit (will go this far back)
before = dateparser.parse(" ".join(args), languages=["en"], settings={
before = dateparser.parse(ctx.remaining(), languages=["en"], settings={
"TO_TIMEZONE": "UTC",
"RETURN_AS_TIMEZONE_AWARE": False
})
@ -234,9 +227,9 @@ async def system_frontpercent(ctx: CommandContext, args: List[str]):
before = None
# Fetch list of switches
all_switches = await pluralkit.utils.get_front_history(ctx.conn, ctx.system.id, 99999)
all_switches = await pluralkit.utils.get_front_history(ctx.conn, system.id, 99999)
if not all_switches:
return embeds.error("No switches registered to this system.")
return CommandError("No switches registered to this system.")
# Cull the switches *ending* before the limit, if given
# We'll need to find the first switch starting before the limit, then cut off every switch *before* that
@ -264,11 +257,11 @@ async def system_frontpercent(ctx: CommandContext, args: List[str]):
# Calculate length of the switch
switch_length = end_time - start_time
def add_switch(member_id, length):
if member_id not in member_times:
member_times[member_id] = length
def add_switch(id, length):
if id not in member_times:
member_times[id] = length
else:
member_times[member_id] += length
member_times[id] += length
for member in members:
# Add the switch length to the currently registered time for that member
@ -297,4 +290,4 @@ async def system_frontpercent(ctx: CommandContext, args: List[str]):
value="{}% ({})".format(percent, humanize.naturaldelta(front_time)))
embed.set_footer(text="Since {}".format(span_start.isoformat(sep=" ", timespec="seconds")))
return embed
await ctx.reply(embed=embed)

View File

@ -8,7 +8,7 @@ import aiohttp
import discord
from pluralkit import db
from pluralkit.bot import channel_logger, utils
from pluralkit.bot import channel_logger, utils, embeds
from pluralkit.stats import StatCollector
logger = logging.getLogger("pluralkit.bot.proxy")
@ -253,10 +253,10 @@ class Proxy:
async with conn.transaction():
await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url)
except WebhookPermissionError:
embed = utils.make_error_embed("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.")
embed = embeds.error("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed)
except DeletionPermissionError:
embed = utils.make_error_embed("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.")
embed = embeds.error("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed)
return True

View File

@ -81,19 +81,6 @@ async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) ->
if member is not None:
return member
def make_default_embed(message):
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.description = message
return embed
def make_error_embed(message):
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.description = message
return embed
async def generate_system_info_card(conn, client: discord.Client, system: System) -> discord.Embed:
card = discord.Embed()
card.colour = discord.Colour.blue()

View File

@ -323,7 +323,7 @@ async def create_tables(conn):
description text,
tag text,
avatar_url text,
created timestamp not null default current_timestamp
created timestamp not null default (current_timestamp at time zone 'utc')
)""")
await conn.execute("""create table if not exists members (
id serial primary key,
@ -337,7 +337,7 @@ async def create_tables(conn):
description text,
prefix text,
suffix text,
created timestamp not null default current_timestamp
created timestamp not null default (current_timestamp at time zone 'utc')
)""")
await conn.execute("""create table if not exists accounts (
uid bigint primary key,
@ -353,7 +353,7 @@ async def create_tables(conn):
await conn.execute("""create table if not exists switches (
id serial primary key,
system serial not null references systems(id) on delete cascade,
timestamp timestamp not null default current_timestamp
timestamp timestamp not null default (current_timestamp at time zone 'utc')
)""")
await conn.execute("""create table if not exists switch_members (
id serial primary key,

View File

@ -1,9 +1,14 @@
from datetime import datetime
from datetime import datetime, timezone
from typing import List, Tuple
from pluralkit import db, Member
def fix_time(time: datetime):
# Assume we're receiving a naive datetime set to UTC, returns naive time zone set to local
return time.replace(tzinfo=timezone.utc).astimezone().replace(tzinfo=None)
async def get_fronter_ids(conn, system_id) -> (List[int], datetime):
switches = await db.front_history(conn, system_id=system_id, count=1)
if not switches: