Move various system functionality to system.py

This commit is contained in:
Ske 2018-09-09 20:38:57 +02:00
parent 49b4e4c1ef
commit a079db8be0
8 changed files with 182 additions and 102 deletions

View File

@ -4,7 +4,7 @@ import discord
import logging import logging
import re import re
import traceback import traceback
from typing import Tuple, Optional from typing import Tuple, Optional, Union
from pluralkit import db from pluralkit import db
from pluralkit.system import System from pluralkit.system import System
@ -110,7 +110,7 @@ class CommandContext:
async def reply(self, content=None, embed=None): async def reply(self, content=None, embed=None):
return await self.client.send_message(self.message.channel, content=content, embed=embed) return await self.client.send_message(self.message.channel, content=content, embed=embed)
async def confirm_react(self, user: discord.Member, message: str): async def confirm_react(self, user: Union[discord.Member, discord.User], message: str):
message = await self.reply(message) message = await self.reply(message)
await self.client.add_reaction(message, "") await self.client.add_reaction(message, "")

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
from datetime import datetime from datetime import datetime
import pluralkit.utils
from pluralkit.bot.commands import * from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands") logger = logging.getLogger("pluralkit.commands")
@ -100,7 +101,7 @@ async def import_tupperware(ctx: CommandContext):
# Create new (nameless) system if there isn't any registered # Create new (nameless) system if there isn't any registered
system = await ctx.get_system() system = await ctx.get_system()
if system is None: if system is None:
hid = utils.generate_hid() hid = pluralkit.utils.generate_hid()
logger.debug("Creating new system (hid={})...".format(hid)) logger.debug("Creating new system (hid={})...".format(hid))
system = await db.create_system(ctx.conn, system_name=None, system_hid=hid) system = await db.create_system(ctx.conn, system_name=None, system_hid=hid)
await db.link_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id) await db.link_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id)
@ -140,7 +141,7 @@ async def import_tupperware(ctx: CommandContext):
existing_member = await db.get_member_by_name(ctx.conn, system_id=system.id, member_name=name) existing_member = await db.get_member_by_name(ctx.conn, system_id=system.id, member_name=name)
if not existing_member: if not existing_member:
# Or create a new member # Or create a new member
hid = utils.generate_hid() hid = pluralkit.utils.generate_hid()
logger.debug("Creating new member {} (hid={})...".format(name, hid)) logger.debug("Creating new member {} (hid={})...".format(name, hid))
existing_member = await db.create_member(ctx.conn, system_id=system.id, member_name=name, existing_member = await db.create_member(ctx.conn, system_id=system.id, member_name=name,
member_hid=hid) member_hid=hid)

View File

@ -1,10 +1,9 @@
import re
from datetime import datetime from datetime import datetime
from typing import List
from urllib.parse import urlparse from urllib.parse import urlparse
from pluralkit.bot.commands import * import pluralkit.utils
from pluralkit.bot import help from pluralkit.bot import help
from pluralkit.bot.commands import *
logger = logging.getLogger("pluralkit.commands") logger = logging.getLogger("pluralkit.commands")
@ -26,7 +25,7 @@ async def new_member(ctx: CommandContext):
return CommandError(bounds_error) return CommandError(bounds_error)
# TODO: figure out what to do if this errors out on collision on generate_hid # TODO: figure out what to do if this errors out on collision on generate_hid
hid = utils.generate_hid() hid = pluralkit.utils.generate_hid()
# Insert member row # Insert member row
await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid) await db.create_member(ctx.conn, system_id=system.id, member_name=name, member_hid=hid)

View File

@ -1,11 +1,12 @@
import dateparser import dateparser
import humanize import humanize
from datetime import datetime from datetime import datetime
from urllib.parse import urlparse
import pluralkit.utils import pluralkit.utils
from pluralkit.bot import help from pluralkit.bot import help
from pluralkit.bot.commands import * from pluralkit.bot.commands import *
from pluralkit.errors import ExistingSystemError, DescriptionTooLongError, TagTooLongError, TagTooLongWithMembersError, \
InvalidAvatarURLError, UnlinkingLastAccountError
logger = logging.getLogger("pluralkit.commands") logger = logging.getLogger("pluralkit.commands")
@ -20,90 +21,58 @@ async def system_info(ctx: CommandContext):
async def new_system(ctx: CommandContext): async def new_system(ctx: CommandContext):
system = await ctx.get_system() system_name = ctx.remaining() or None
if system:
try:
await System.create_system(ctx.conn, ctx.message.author.id, system_name)
except ExistingSystemError:
return CommandError( return CommandError(
"You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.") "You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
system_name = ctx.remaining() or None return CommandSuccess("System registered! To begin adding members, use `pk;member new <name>`.")
async with ctx.conn.transaction():
# TODO: figure out what to do if this errors out on collision on generate_hid
hid = utils.generate_hid()
system = await db.create_system(ctx.conn, system_name=system_name, system_hid=hid)
# Link account
await db.link_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id)
return CommandSuccess("System registered! To begin adding members, use `pk;member new <name>`.")
async def system_set(ctx: CommandContext): async def system_set(ctx: CommandContext):
system = await ctx.ensure_system() system = await ctx.ensure_system()
prop = ctx.pop_str(CommandError("You must pass a property name to set.", help=help.edit_system)) property_name = ctx.pop_str(CommandError("You must pass a property name to set.", help=help.edit_system))
allowed_properties = ["name", "description", "tag", "avatar"] async def avatar_setter(conn, url):
db_properties = { user = await utils.parse_mention(ctx.client, url)
"name": "name", if user:
"description": "description", # Set the avatar to the mentioned user's avatar
"tag": "tag", # Discord pushes webp by default, which isn't supported by webhooks, but also hosts png alternatives
"avatar": "avatar_url" url = user.avatar_url.replace(".webp", ".png")
await system.set_avatar(conn, url)
properties = {
"name": system.set_name,
"description": system.set_description,
"tag": system.set_tag,
"avatar": avatar_setter
} }
if prop not in allowed_properties: if property_name not in properties:
return CommandError( return CommandError(
"Unknown property {}. Allowed properties are {}.".format(prop, ", ".join(allowed_properties)), "Unknown property {}. Allowed properties are {}.".format(property_name, ", ".join(allowed_properties)),
help=help.edit_system) help=help.edit_system)
if ctx.has_next(): value = ctx.remaining() or None
value = ctx.remaining()
# Sanity checking
if prop == "description":
if len(value) > 1024:
return CommandError("You can't have a description longer than 1024 characters.")
if prop == "tag": try:
if len(value) > 32: await properties[property_name](ctx.conn, value)
return CommandError("You can't have a system tag longer than 32 characters.") except DescriptionTooLongError:
return CommandError("You can't have a description longer than 1024 characters.")
except TagTooLongError:
return CommandError("You can't have a system tag longer than 32 characters.")
except TagTooLongWithMembersError as e:
return CommandError("The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(", ".join(e.member_names)))
except InvalidAvatarURLError:
return CommandError("Invalid image URL.")
if re.search("<a?:\w+:\d+>", value): response = CommandSuccess("{} system {}.".format("Updated" if value else "Cleared", property_name))
return CommandError("Due to a Discord limitation, custom emojis aren't supported. Please use a standard emoji instead.") # if prop == "avatar" and value:
# Make sure there are no members which would make the combined length exceed 32
members_exceeding = await db.get_members_exceeding(ctx.conn, system_id=system.id,
length=32 - len(value) - 1)
if len(members_exceeding) > 0:
# If so, error out and warn
member_names = ", ".join([member.name
for member in members_exceeding])
logger.debug("Members exceeding combined length with tag '{}': {}".format(value, member_names))
return CommandError(
"The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(
member_names))
if prop == "avatar":
user = await utils.parse_mention(ctx.client, value)
if user:
# Set the avatar to the mentioned user's avatar
# Discord doesn't like webp, but also hosts png alternatives
value = user.avatar_url.replace(".webp", ".png")
else:
# Validate URL
u = urlparse(value)
if u.scheme in ["http", "https"] and u.netloc and u.path:
value = value
else:
return CommandError("Invalid image URL.")
else:
# Clear from DB
value = None
db_prop = db_properties[prop]
await db.update_system_field(ctx.conn, system_id=system.id, field=db_prop, value=value)
response = CommandSuccess("{} system {}.".format("Updated" if value else "Cleared", prop))
#if prop == "avatar" and value:
# response.set_image(url=value) # response.set_image(url=value)
return response return response
@ -118,36 +87,25 @@ async def system_link(ctx: CommandContext):
return CommandError("Account not found.") return CommandError("Account not found.")
# Make sure account doesn't already have a system # Make sure account doesn't already have a system
account_system = await db.get_system_by_account(ctx.conn, linkee.id) account_system = await System.get_by_account(ctx.conn, linkee.id)
if account_system: if account_system:
return CommandError("The mentioned account is already linked to a system (`{}`)".format(account_system.hid)) return CommandError("The mentioned account is already linked to a system (`{}`)".format(account_system.hid))
# Send confirmation message if not await ctx.confirm_react(linkee, "{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention)):
msg = await ctx.reply(
"{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention))
await ctx.client.add_reaction(msg, "")
await ctx.client.add_reaction(msg, "")
reaction = await ctx.client.wait_for_reaction(emoji=["", ""], message=msg, user=linkee, timeout=60.0 * 5)
# If account to be linked confirms...
if not reaction:
return CommandError("Account link timed out.")
if not reaction.reaction.emoji == "":
return CommandError("Account link cancelled.") return CommandError("Account link cancelled.")
await db.link_account(ctx.conn, system_id=system.id, account_id=linkee.id) await system.link_account(ctx.conn, linkee.id)
return CommandSuccess("Account linked to system.") return CommandSuccess("Account linked to system.")
async def system_unlink(ctx: CommandContext): async def system_unlink(ctx: CommandContext):
system = await ctx.ensure_system() system = await ctx.ensure_system()
# Make sure you can't unlink every account try:
linked_accounts = await db.get_linked_accounts(ctx.conn, system_id=system.id) await system.unlink_account(ctx.conn, ctx.message.author.id)
if len(linked_accounts) == 1: except UnlinkingLastAccountError:
return CommandError("This is the only account on your system, so you can't unlink it.") return CommandError("This is the only account on your system, so you can't unlink it.")
await db.unlink_account(ctx.conn, system_id=system.id, account_id=ctx.message.author.id)
return CommandSuccess("Account unlinked.") return CommandSuccess("Account unlinked.")
@ -208,11 +166,12 @@ async def system_fronthistory(ctx: CommandContext):
async def system_delete(ctx: CommandContext): async def system_delete(ctx: CommandContext):
system = await ctx.ensure_system() system = await ctx.ensure_system()
delete_confirm_msg = "Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(system.hid) delete_confirm_msg = "Are you sure you want to delete your system? If so, reply to this message with the system's ID (`{}`).".format(
system.hid)
if not await ctx.confirm_text(ctx.message.author, ctx.message.channel, system.hid, delete_confirm_msg): if not await ctx.confirm_text(ctx.message.author, ctx.message.channel, system.hid, delete_confirm_msg):
return CommandError("System deletion cancelled.") return CommandError("System deletion cancelled.")
await db.remove_system(ctx.conn, system_id=system.id) await system.delete(ctx.conn)
return CommandSuccess("System deleted.") return CommandSuccess("System deleted.")

View File

@ -1,7 +1,5 @@
import logging import logging
import random
import re import re
import string
import discord import discord
import humanize import humanize
@ -16,9 +14,6 @@ logger = logging.getLogger("pluralkit.utils")
def escape(s): def escape(s):
return s.replace("`", "\\`") return s.replace("`", "\\`")
def generate_hid() -> str:
return "".join(random.choices(string.ascii_lowercase, k=5))
def bounds_check_member_name(new_name, system_tag): def bounds_check_member_name(new_name, system_tag):
if len(new_name) > 32: if len(new_name) > 32:
return "Name cannot be longer than 32 characters." return "Name cannot be longer than 32 characters."

36
src/pluralkit/errors.py Normal file
View File

@ -0,0 +1,36 @@
class PluralKitError(Exception):
pass
class ExistingSystemError(PluralKitError):
pass
class DescriptionTooLongError(PluralKitError):
pass
class TagTooLongError(PluralKitError):
pass
class TagTooLongWithMembersError(PluralKitError):
def __init__(self, member_names):
self.member_names = member_names
class CustomEmojiError(PluralKitError):
pass
class InvalidAvatarURLError(PluralKitError):
pass
class AccountAlreadyLinkedError(PluralKitError):
def __init__(self, existing_system):
self.existing_system = existing_system
class UnlinkingLastAccountError(PluralKitError):
pass

View File

@ -1,6 +1,10 @@
from datetime import datetime from datetime import datetime
from collections.__init__ import namedtuple from collections.__init__ import namedtuple
from typing import Optional
from pluralkit import db, errors
from pluralkit.utils import generate_hid, contains_custom_emoji, validate_avatar_url_or_raise
class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "created"])): class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "created"])):
@ -12,6 +16,70 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
avatar_url: str avatar_url: str
created: datetime created: datetime
@staticmethod
async def get_by_account(conn, account_id: str) -> "System":
return await db.get_system_by_account(conn, account_id)
@staticmethod
async def create_system(conn, account_id: str, system_name: Optional[str] = None) -> "System":
existing_system = await System.get_by_account(conn, account_id)
if existing_system:
raise errors.ExistingSystemError()
new_hid = generate_hid()
async with conn.transaction():
new_system = await db.create_system(conn, system_name, new_hid)
await db.link_account(conn, new_system.id, account_id)
return new_system
async def set_name(self, conn, new_name: Optional[str]):
await db.update_system_field(conn, self.id, "name", new_name)
async def set_description(self, conn, new_description: Optional[str]):
if new_description and len(new_description) > 1024:
raise errors.DescriptionTooLongError()
await db.update_system_field(conn, self.id, "description", new_description)
async def set_tag(self, conn, new_tag: Optional[str]):
if new_tag:
if len(new_tag) > 32:
raise errors.TagTooLongError()
if contains_custom_emoji(new_tag):
raise errors.CustomEmojiError()
members_exceeding = await db.get_members_exceeding(conn, system_id=self.id, length=32 - len(new_tag) - 1)
if len(members_exceeding) > 0:
raise errors.TagTooLongWithMembersError([member.name for member in members_exceeding])
await db.update_system_field(conn, self.id, "tag", new_tag)
async def set_avatar(self, conn, new_avatar_url: Optional[str]):
if new_avatar_url:
validate_avatar_url_or_raise(new_avatar_url)
await db.update_system_field(conn, self.id, "avatar", new_avatar_url)
async def link_account(self, conn, new_account_id: str):
existing_system = await System.get_by_account(conn, new_account_id)
if existing_system:
raise errors.AccountAlreadyLinkedError(existing_system)
await db.link_account(conn, self.id, new_account_id)
async def unlink_account(self, conn, account_id: str):
linked_accounts = await db.get_linked_accounts(conn, self.id)
if len(linked_accounts) == 1:
raise errors.UnlinkingLastAccountError()
await db.unlink_account(conn, self.id, account_id)
async def delete(self, conn):
await db.remove_system(conn, self.id)
def to_json(self): def to_json(self):
return { return {
"id": self.hid, "id": self.hid,
@ -19,4 +87,4 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
"description": self.description, "description": self.description,
"tag": self.tag, "tag": self.tag,
"avatar_url": self.avatar_url "avatar_url": self.avatar_url
} }

View File

@ -1,7 +1,13 @@
import re
import random
import string
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Tuple from typing import List, Tuple
from urllib.parse import urlparse
from pluralkit import db from pluralkit import db
from pluralkit.errors import InvalidAvatarURLError
from pluralkit.member import Member from pluralkit.member import Member
@ -48,3 +54,19 @@ async def get_front_history(conn, system_id, count) -> List[Tuple[datetime, List
members = [all_members[id] for id in switch["members"]] members = [all_members[id] for id in switch["members"]]
out.append((timestamp, members)) out.append((timestamp, members))
return out return out
def generate_hid() -> str:
return "".join(random.choices(string.ascii_lowercase, k=5))
def contains_custom_emoji(value):
return bool(re.search("<a?:\w+:\d+>", value))
def validate_avatar_url_or_raise(url):
u = urlparse(url)
if not (u.scheme in ["http", "https"] and u.netloc and u.path):
raise InvalidAvatarURLError()
# TODO: check file type and size of image