Add API endpoint for logging new switches (+ refactor)

This commit is contained in:
Ske 2018-12-10 22:00:34 +01:00
parent 121f8ab8c3
commit 8ccee1d6fa
7 changed files with 144 additions and 32 deletions

View File

@ -1,19 +1,40 @@
import os import json
import logging import logging
import os
from aiohttp import web from aiohttp import web
from pluralkit import db, utils from pluralkit import db, utils
from pluralkit.errors import PluralKitError
from pluralkit.member import Member
from pluralkit.system import System
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
logger = logging.getLogger("pluralkit.api") logger = logging.getLogger("pluralkit.api")
def db_handler(f): def db_handler(f):
async def inner(request): async def inner(request, *args, **kwargs):
async with request.app["pool"].acquire() as conn: async with request.app["pool"].acquire() as conn:
return await f(request, conn) return await f(request, conn=conn, *args, **kwargs)
return inner
def system_auth(f):
async def inner(request: web.Request, conn, *args, **kwargs):
token = request.headers.get("X-Token")
if not token:
token = request.query.get("token")
if not token:
raise web.HTTPUnauthorized()
system = await System.get_by_token(conn, token)
if not system:
raise web.HTTPUnauthorized()
return await f(request, conn=conn, system=system, *args, **kwargs)
return inner return inner
@ -58,6 +79,7 @@ async def get_switches(request: web.Request, conn):
return web.json_response(data) return web.json_response(data)
@db_handler @db_handler
async def get_message(request: web.Request, conn): async def get_message(request: web.Request, conn):
message = await db.get_message(conn, request.match_info["id"]) message = await db.get_message(conn, request.match_info["id"])
@ -66,6 +88,7 @@ async def get_message(request: web.Request, conn):
return web.json_response(message.to_json()) return web.json_response(message.to_json())
@db_handler @db_handler
async def get_switch(request: web.Request, conn): async def get_switch(request: web.Request, conn):
system = await db.get_system_by_hid(conn, request.match_info["id"]) system = await db.get_system_by_hid(conn, request.match_info["id"])
@ -84,6 +107,7 @@ async def get_switch(request: web.Request, conn):
} }
return web.json_response(data) return web.json_response(data)
@db_handler @db_handler
async def get_switch_name(request: web.Request, conn): async def get_switch_name(request: web.Request, conn):
system = await db.get_system_by_hid(conn, request.match_info["id"]) system = await db.get_system_by_hid(conn, request.match_info["id"])
@ -94,6 +118,7 @@ async def get_switch_name(request: web.Request, conn):
members, stamp = await utils.get_fronters(conn, system.id) members, stamp = await utils.get_fronters(conn, system.id)
return web.Response(text=members[0].name if members else "(nobody)") return web.Response(text=members[0].name if members else "(nobody)")
@db_handler @db_handler
async def get_switch_color(request: web.Request, conn): async def get_switch_color(request: web.Request, conn):
system = await db.get_system_by_hid(conn, request.match_info["id"]) system = await db.get_system_by_hid(conn, request.match_info["id"])
@ -104,6 +129,34 @@ async def get_switch_color(request: web.Request, conn):
members, stamp = await utils.get_fronters(conn, system.id) members, stamp = await utils.get_fronters(conn, system.id)
return web.Response(text=members[0].color if members else "#ffffff") return web.Response(text=members[0].color if members else "#ffffff")
@db_handler
@system_auth
async def put_switch(request: web.Request, system: System, conn):
try:
req = await request.json()
except json.JSONDecodeError:
raise web.HTTPBadRequest(body="Invalid JSON")
if isinstance(req, str):
req = [req]
elif not isinstance(req, list):
raise web.HTTPBadRequest(body="Body must be JSON string or list")
members = []
for member_name in req:
if not isinstance(member_name, str):
raise web.HTTPBadRequest(body="List value must be string")
member = await Member.get_member_fuzzy(conn, system.id, member_name)
if not member:
raise web.HTTPBadRequest(body="Member '{}' not found".format(member_name))
members.append(member)
switch = await system.add_switch(conn, members)
return web.json_response(await switch.to_json(conn))
@db_handler @db_handler
async def get_stats(request: web.Request, conn): async def get_stats(request: web.Request, conn):
system_count = await db.system_count(conn) system_count = await db.system_count(conn)
@ -116,11 +169,21 @@ async def get_stats(request: web.Request, conn):
"messages": message_count "messages": message_count
}) })
app = web.Application()
@web.middleware
async def render_pk_errors(request, handler):
try:
return await handler(request)
except PluralKitError as e:
raise web.HTTPBadRequest(body=e.message)
app = web.Application(middlewares=[render_pk_errors])
app.add_routes([ app.add_routes([
web.get("/systems/{id}", get_system), web.get("/systems/{id}", get_system),
web.get("/systems/{id}/switches", get_switches), web.get("/systems/{id}/switches", get_switches),
web.get("/systems/{id}/switch", get_switch), web.get("/systems/{id}/switch", get_switch),
web.put("/systems/{id}/switch", put_switch),
web.get("/systems/{id}/switch/name", get_switch_name), web.get("/systems/{id}/switch/name", get_switch_name),
web.get("/systems/{id}/switch/color", get_switch_color), web.get("/systems/{id}/switch/color", get_switch_color),
web.get("/members/{id}", get_member), web.get("/members/{id}", get_member),

View File

@ -12,11 +12,10 @@ async def member_root(ctx: CommandContext):
elif ctx.match("set"): elif ctx.match("set"):
await member_set(ctx) await member_set(ctx)
# TODO "pk;member list" # TODO "pk;member list"
elif not ctx.has_next():
if not ctx.has_next():
raise CommandError("Must pass a subcommand. For a list of subcommands, type `pk;member help`.") raise CommandError("Must pass a subcommand. For a list of subcommands, type `pk;member help`.")
else:
await specific_member_root(ctx) await specific_member_root(ctx)
async def specific_member_root(ctx: CommandContext): async def specific_member_root(ctx: CommandContext):

View File

@ -33,19 +33,6 @@ async def switch_member(ctx: CommandContext):
while ctx.has_next(): while ctx.has_next():
members.append(await ctx.pop_member()) members.append(await ctx.pop_member())
# Compare requested switch IDs and existing fronter IDs to check for existing switches
# Lists, because order matters, it makes sense to just swap fronters
member_ids = [member.id for member in members]
fronter_ids = (await pluralkit.utils.get_fronter_ids(ctx.conn, system.id))[0]
if member_ids == fronter_ids:
if len(members) == 1:
raise CommandError("{} is already fronting.".format(members[0].name))
raise CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members])))
# Also make sure there aren't any duplicates
if len(set(member_ids)) != len(member_ids):
raise CommandError("Duplicate members in member list.")
# Log the switch # Log the switch
await system.add_switch(ctx.conn, members) await system.add_switch(ctx.conn, members)

View File

@ -12,7 +12,8 @@ class PluralKitError(Exception):
class ExistingSystemError(PluralKitError): class ExistingSystemError(PluralKitError):
def __init__(self): def __init__(self):
super().__init__("You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.") super().__init__(
"You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
class DescriptionTooLongError(PluralKitError): class DescriptionTooLongError(PluralKitError):
@ -27,13 +28,16 @@ class TagTooLongError(PluralKitError):
class TagTooLongWithMembersError(PluralKitError): class TagTooLongWithMembersError(PluralKitError):
def __init__(self, member_names): def __init__(self, member_names):
super().__init__("The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(", ".join(member_names))) super().__init__(
"The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(
", ".join(member_names)))
self.member_names = member_names self.member_names = member_names
class CustomEmojiError(PluralKitError): class CustomEmojiError(PluralKitError):
def __init__(self): def __init__(self):
super().__init__("Due to a Discord limitation, custom emojis aren't supported. Please use a standard emoji instead.") super().__init__(
"Due to a Discord limitation, custom emojis aren't supported. Please use a standard emoji instead.")
class InvalidAvatarURLError(PluralKitError): class InvalidAvatarURLError(PluralKitError):
@ -60,7 +64,8 @@ class UnlinkingLastAccountError(PluralKitError):
class MemberNameTooLongError(PluralKitError): class MemberNameTooLongError(PluralKitError):
def __init__(self, tag_present: bool): def __init__(self, tag_present: bool):
if tag_present: if tag_present:
super().__init__("The maximum length of a name plus the system tag is 32 characters. Please reduce the length of the tag, or choose a shorter member name.") super().__init__(
"The maximum length of a name plus the system tag is 32 characters. Please reduce the length of the tag, or choose a shorter member name.")
else: else:
super().__init__("The maximum length of a member name is 32 characters.") super().__init__("The maximum length of a member name is 32 characters.")
@ -69,6 +74,22 @@ class InvalidColorError(PluralKitError):
def __init__(self): def __init__(self):
super().__init__("Color must be a valid hex color. (eg. #ff0000)") super().__init__("Color must be a valid hex color. (eg. #ff0000)")
class InvalidDateStringError(PluralKitError): class InvalidDateStringError(PluralKitError):
def __init__(self): def __init__(self):
super().__init__("Invalid date string. Date must be in ISO-8601 format (YYYY-MM-DD, eg. 1999-07-25).") super().__init__("Invalid date string. Date must be in ISO-8601 format (YYYY-MM-DD, eg. 1999-07-25).")
class MembersAlreadyFrontingError(PluralKitError):
def __init__(self, members: "List[Member]"):
if len(members) == 0:
super().__init__("There are already no members fronting.")
elif len(members) == 1:
super().__init__("Member {} is already fronting.".format(members[0].name))
else:
super().__init__("Members {} are already fronting.".format(", ".join([member.name for member in members])))
class DuplicateSwitchMembersError(PluralKitError):
def __init__(self):
super().__init__("Duplicate members in member list.")

View File

@ -54,6 +54,16 @@ class Member(namedtuple("Member",
return member return member
@staticmethod
async def get_member_fuzzy(conn, system_id: int, name: str) -> "Optional[Member]":
by_hid = await Member.get_member_by_hid(conn, system_id, name)
if by_hid:
return by_hid
by_name = await Member.get_member_by_name(conn, system_id, name)
return by_name
async def set_name(self, conn, new_name: str): async def set_name(self, conn, new_name: str):
""" """
Set the name of a member. Set the name of a member.

View File

@ -1,6 +1,5 @@
from collections import namedtuple from collections import namedtuple
from datetime import datetime from datetime import datetime
from typing import List from typing import List
from pluralkit import db from pluralkit import db
@ -17,4 +16,10 @@ class Switch(namedtuple("Switch", ["id", "system", "timestamp", "members"])):
return await db.get_members(conn, self.members) return await db.get_members(conn, self.members)
async def delete(self, conn): async def delete(self, conn):
await db.delete_switch(conn, self.id) await db.delete_switch(conn, self.id)
async def to_json(self, conn):
return {
"timestamp": self.timestamp.isoformat(),
"members": [member.hid for member in await self.fetch_members(conn)]
}

View File

@ -1,9 +1,8 @@
import random import random
import re import re
import string import string
from datetime import datetime
from collections.__init__ import namedtuple from collections.__init__ import namedtuple
from datetime import datetime
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from pluralkit import db, errors from pluralkit import db, errors
@ -22,6 +21,10 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
token: str token: str
created: datetime created: datetime
@staticmethod
async def get_by_id(conn, system_id: int) -> Optional["System"]:
return await db.get_system(conn, system_id)
@staticmethod @staticmethod
async def get_by_account(conn, account_id: int) -> Optional["System"]: async def get_by_account(conn, account_id: int) -> Optional["System"]:
return await db.get_system_by_account(conn, account_id) return await db.get_system_by_account(conn, account_id)
@ -128,7 +131,29 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
else: else:
return None return None
async def add_switch(self, conn, members: List[Member]): async def add_switch(self, conn, members: List[Member]) -> Switch:
"""
Logs a new switch for a system.
:raises: MembersAlreadyFrontingError, DuplicateSwitchMembersError
"""
new_ids = [member.id for member in members]
last_switch = await self.get_latest_switch(conn)
# If we have a switch logged before, make sure this isn't a dupe switch
if last_switch:
last_switch_members = await last_switch.fetch_members(conn)
last_ids = [member.id for member in last_switch_members]
# We don't compare by set() here because swapping multiple is a valid operation
if last_ids == new_ids:
raise errors.MembersAlreadyFrontingError(members)
# Check for dupes
if len(set(new_ids)) != len(new_ids):
raise errors.DuplicateSwitchMembersError()
async with conn.transaction(): async with conn.transaction():
switch_id = await db.add_switch(conn, self.id) switch_id = await db.add_switch(conn, self.id)
@ -136,6 +161,8 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
for member in members: for member in members:
await db.add_switch_member(conn, switch_id, member.id) await db.add_switch_member(conn, switch_id, member.id)
return await self.get_latest_switch(conn)
def get_member_name_limit(self) -> int: def get_member_name_limit(self) -> int:
"""Returns the maximum length a member's name or nickname is allowed to be in order for the member to be proxied. Depends on the system tag.""" """Returns the maximum length a member's name or nickname is allowed to be in order for the member to be proxied. Depends on the system tag."""
if self.tag: if self.tag: