diff --git a/src/api_main.py b/src/api_main.py index 87d96f2b..7fe9484c 100644 --- a/src/api_main.py +++ b/src/api_main.py @@ -1,19 +1,40 @@ -import os - +import json import logging +import os from aiohttp import web from pluralkit import db, utils +from pluralkit.errors import PluralKitError +from pluralkit.member import Member +from pluralkit.system import System logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") logger = logging.getLogger("pluralkit.api") def db_handler(f): - async def inner(request): + async def inner(request, *args, **kwargs): async with request.app["pool"].acquire() as conn: - return await f(request, conn) + return await f(request, conn=conn, *args, **kwargs) + + return inner + + +def system_auth(f): + async def inner(request: web.Request, conn, *args, **kwargs): + token = request.headers.get("X-Token") + if not token: + token = request.query.get("token") + + if not token: + raise web.HTTPUnauthorized() + + system = await System.get_by_token(conn, token) + if not system: + raise web.HTTPUnauthorized() + + return await f(request, conn=conn, system=system, *args, **kwargs) return inner @@ -58,6 +79,7 @@ async def get_switches(request: web.Request, conn): return web.json_response(data) + @db_handler async def get_message(request: web.Request, conn): message = await db.get_message(conn, request.match_info["id"]) @@ -66,6 +88,7 @@ async def get_message(request: web.Request, conn): return web.json_response(message.to_json()) + @db_handler async def get_switch(request: web.Request, conn): system = await db.get_system_by_hid(conn, request.match_info["id"]) @@ -84,6 +107,7 @@ async def get_switch(request: web.Request, conn): } return web.json_response(data) + @db_handler async def get_switch_name(request: web.Request, conn): system = await db.get_system_by_hid(conn, request.match_info["id"]) @@ -94,6 +118,7 @@ async def get_switch_name(request: web.Request, conn): members, stamp = await utils.get_fronters(conn, system.id) return web.Response(text=members[0].name if members else "(nobody)") + @db_handler async def get_switch_color(request: web.Request, conn): system = await db.get_system_by_hid(conn, request.match_info["id"]) @@ -104,6 +129,34 @@ async def get_switch_color(request: web.Request, conn): members, stamp = await utils.get_fronters(conn, system.id) return web.Response(text=members[0].color if members else "#ffffff") + +@db_handler +@system_auth +async def put_switch(request: web.Request, system: System, conn): + try: + req = await request.json() + except json.JSONDecodeError: + raise web.HTTPBadRequest(body="Invalid JSON") + + if isinstance(req, str): + req = [req] + elif not isinstance(req, list): + raise web.HTTPBadRequest(body="Body must be JSON string or list") + + members = [] + for member_name in req: + if not isinstance(member_name, str): + raise web.HTTPBadRequest(body="List value must be string") + + member = await Member.get_member_fuzzy(conn, system.id, member_name) + if not member: + raise web.HTTPBadRequest(body="Member '{}' not found".format(member_name)) + members.append(member) + + switch = await system.add_switch(conn, members) + return web.json_response(await switch.to_json(conn)) + + @db_handler async def get_stats(request: web.Request, conn): system_count = await db.system_count(conn) @@ -116,11 +169,21 @@ async def get_stats(request: web.Request, conn): "messages": message_count }) -app = web.Application() + +@web.middleware +async def render_pk_errors(request, handler): + try: + return await handler(request) + except PluralKitError as e: + raise web.HTTPBadRequest(body=e.message) + + +app = web.Application(middlewares=[render_pk_errors]) app.add_routes([ web.get("/systems/{id}", get_system), web.get("/systems/{id}/switches", get_switches), web.get("/systems/{id}/switch", get_switch), + web.put("/systems/{id}/switch", put_switch), web.get("/systems/{id}/switch/name", get_switch_name), web.get("/systems/{id}/switch/color", get_switch_color), web.get("/members/{id}", get_member), diff --git a/src/pluralkit/bot/commands/member_commands.py b/src/pluralkit/bot/commands/member_commands.py index b9406ceb..53f997c8 100644 --- a/src/pluralkit/bot/commands/member_commands.py +++ b/src/pluralkit/bot/commands/member_commands.py @@ -12,11 +12,10 @@ async def member_root(ctx: CommandContext): elif ctx.match("set"): await member_set(ctx) # TODO "pk;member list" - - if not ctx.has_next(): + elif not ctx.has_next(): raise CommandError("Must pass a subcommand. For a list of subcommands, type `pk;member help`.") - - await specific_member_root(ctx) + else: + await specific_member_root(ctx) async def specific_member_root(ctx: CommandContext): diff --git a/src/pluralkit/bot/commands/switch_commands.py b/src/pluralkit/bot/commands/switch_commands.py index f80755d1..ebf15ec3 100644 --- a/src/pluralkit/bot/commands/switch_commands.py +++ b/src/pluralkit/bot/commands/switch_commands.py @@ -33,19 +33,6 @@ async def switch_member(ctx: CommandContext): while ctx.has_next(): members.append(await ctx.pop_member()) - # Compare requested switch IDs and existing fronter IDs to check for existing switches - # Lists, because order matters, it makes sense to just swap fronters - member_ids = [member.id for member in members] - fronter_ids = (await pluralkit.utils.get_fronter_ids(ctx.conn, system.id))[0] - if member_ids == fronter_ids: - if len(members) == 1: - raise CommandError("{} is already fronting.".format(members[0].name)) - raise CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members]))) - - # Also make sure there aren't any duplicates - if len(set(member_ids)) != len(member_ids): - raise CommandError("Duplicate members in member list.") - # Log the switch await system.add_switch(ctx.conn, members) diff --git a/src/pluralkit/errors.py b/src/pluralkit/errors.py index b9506d75..cb421e91 100644 --- a/src/pluralkit/errors.py +++ b/src/pluralkit/errors.py @@ -12,7 +12,8 @@ class PluralKitError(Exception): class ExistingSystemError(PluralKitError): def __init__(self): - super().__init__("You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.") + super().__init__( + "You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.") class DescriptionTooLongError(PluralKitError): @@ -27,13 +28,16 @@ class TagTooLongError(PluralKitError): class TagTooLongWithMembersError(PluralKitError): def __init__(self, member_names): - super().__init__("The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(", ".join(member_names))) + super().__init__( + "The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format( + ", ".join(member_names))) self.member_names = member_names class CustomEmojiError(PluralKitError): def __init__(self): - super().__init__("Due to a Discord limitation, custom emojis aren't supported. Please use a standard emoji instead.") + super().__init__( + "Due to a Discord limitation, custom emojis aren't supported. Please use a standard emoji instead.") class InvalidAvatarURLError(PluralKitError): @@ -60,7 +64,8 @@ class UnlinkingLastAccountError(PluralKitError): class MemberNameTooLongError(PluralKitError): def __init__(self, tag_present: bool): if tag_present: - super().__init__("The maximum length of a name plus the system tag is 32 characters. Please reduce the length of the tag, or choose a shorter member name.") + super().__init__( + "The maximum length of a name plus the system tag is 32 characters. Please reduce the length of the tag, or choose a shorter member name.") else: super().__init__("The maximum length of a member name is 32 characters.") @@ -69,6 +74,22 @@ class InvalidColorError(PluralKitError): def __init__(self): super().__init__("Color must be a valid hex color. (eg. #ff0000)") + class InvalidDateStringError(PluralKitError): def __init__(self): - super().__init__("Invalid date string. Date must be in ISO-8601 format (YYYY-MM-DD, eg. 1999-07-25).") \ No newline at end of file + super().__init__("Invalid date string. Date must be in ISO-8601 format (YYYY-MM-DD, eg. 1999-07-25).") + + +class MembersAlreadyFrontingError(PluralKitError): + def __init__(self, members: "List[Member]"): + if len(members) == 0: + super().__init__("There are already no members fronting.") + elif len(members) == 1: + super().__init__("Member {} is already fronting.".format(members[0].name)) + else: + super().__init__("Members {} are already fronting.".format(", ".join([member.name for member in members]))) + + +class DuplicateSwitchMembersError(PluralKitError): + def __init__(self): + super().__init__("Duplicate members in member list.") diff --git a/src/pluralkit/member.py b/src/pluralkit/member.py index 28a617ac..0f1a8341 100644 --- a/src/pluralkit/member.py +++ b/src/pluralkit/member.py @@ -54,6 +54,16 @@ class Member(namedtuple("Member", return member + @staticmethod + async def get_member_fuzzy(conn, system_id: int, name: str) -> "Optional[Member]": + by_hid = await Member.get_member_by_hid(conn, system_id, name) + if by_hid: + return by_hid + + by_name = await Member.get_member_by_name(conn, system_id, name) + return by_name + + async def set_name(self, conn, new_name: str): """ Set the name of a member. diff --git a/src/pluralkit/switch.py b/src/pluralkit/switch.py index 784d1f9e..d2c86757 100644 --- a/src/pluralkit/switch.py +++ b/src/pluralkit/switch.py @@ -1,6 +1,5 @@ from collections import namedtuple from datetime import datetime - from typing import List from pluralkit import db @@ -17,4 +16,10 @@ class Switch(namedtuple("Switch", ["id", "system", "timestamp", "members"])): return await db.get_members(conn, self.members) async def delete(self, conn): - await db.delete_switch(conn, self.id) \ No newline at end of file + await db.delete_switch(conn, self.id) + + async def to_json(self, conn): + return { + "timestamp": self.timestamp.isoformat(), + "members": [member.hid for member in await self.fetch_members(conn)] + } diff --git a/src/pluralkit/system.py b/src/pluralkit/system.py index c121b973..7b03bcec 100644 --- a/src/pluralkit/system.py +++ b/src/pluralkit/system.py @@ -1,9 +1,8 @@ import random import re import string -from datetime import datetime - from collections.__init__ import namedtuple +from datetime import datetime from typing import Optional, List, Tuple from pluralkit import db, errors @@ -22,6 +21,10 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a token: str created: datetime + @staticmethod + async def get_by_id(conn, system_id: int) -> Optional["System"]: + return await db.get_system(conn, system_id) + @staticmethod async def get_by_account(conn, account_id: int) -> Optional["System"]: return await db.get_system_by_account(conn, account_id) @@ -128,7 +131,29 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a else: return None - async def add_switch(self, conn, members: List[Member]): + async def add_switch(self, conn, members: List[Member]) -> Switch: + """ + Logs a new switch for a system. + + :raises: MembersAlreadyFrontingError, DuplicateSwitchMembersError + """ + new_ids = [member.id for member in members] + + last_switch = await self.get_latest_switch(conn) + + # If we have a switch logged before, make sure this isn't a dupe switch + if last_switch: + last_switch_members = await last_switch.fetch_members(conn) + last_ids = [member.id for member in last_switch_members] + + # We don't compare by set() here because swapping multiple is a valid operation + if last_ids == new_ids: + raise errors.MembersAlreadyFrontingError(members) + + # Check for dupes + if len(set(new_ids)) != len(new_ids): + raise errors.DuplicateSwitchMembersError() + async with conn.transaction(): switch_id = await db.add_switch(conn, self.id) @@ -136,6 +161,8 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a for member in members: await db.add_switch_member(conn, switch_id, member.id) + return await self.get_latest_switch(conn) + def get_member_name_limit(self) -> int: """Returns the maximum length a member's name or nickname is allowed to be in order for the member to be proxied. Depends on the system tag.""" if self.tag: