Add API endpoint for logging new switches (+ refactor)
This commit is contained in:
parent
121f8ab8c3
commit
8ccee1d6fa
@ -1,19 +1,40 @@
|
|||||||
import os
|
import json
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from pluralkit import db, utils
|
from pluralkit import db, utils
|
||||||
|
from pluralkit.errors import PluralKitError
|
||||||
|
from pluralkit.member import Member
|
||||||
|
from pluralkit.system import System
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
||||||
logger = logging.getLogger("pluralkit.api")
|
logger = logging.getLogger("pluralkit.api")
|
||||||
|
|
||||||
|
|
||||||
def db_handler(f):
|
def db_handler(f):
|
||||||
async def inner(request):
|
async def inner(request, *args, **kwargs):
|
||||||
async with request.app["pool"].acquire() as conn:
|
async with request.app["pool"].acquire() as conn:
|
||||||
return await f(request, conn)
|
return await f(request, conn=conn, *args, **kwargs)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
def system_auth(f):
|
||||||
|
async def inner(request: web.Request, conn, *args, **kwargs):
|
||||||
|
token = request.headers.get("X-Token")
|
||||||
|
if not token:
|
||||||
|
token = request.query.get("token")
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
raise web.HTTPUnauthorized()
|
||||||
|
|
||||||
|
system = await System.get_by_token(conn, token)
|
||||||
|
if not system:
|
||||||
|
raise web.HTTPUnauthorized()
|
||||||
|
|
||||||
|
return await f(request, conn=conn, system=system, *args, **kwargs)
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
@ -58,6 +79,7 @@ async def get_switches(request: web.Request, conn):
|
|||||||
|
|
||||||
return web.json_response(data)
|
return web.json_response(data)
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
@db_handler
|
||||||
async def get_message(request: web.Request, conn):
|
async def get_message(request: web.Request, conn):
|
||||||
message = await db.get_message(conn, request.match_info["id"])
|
message = await db.get_message(conn, request.match_info["id"])
|
||||||
@ -66,6 +88,7 @@ async def get_message(request: web.Request, conn):
|
|||||||
|
|
||||||
return web.json_response(message.to_json())
|
return web.json_response(message.to_json())
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
@db_handler
|
||||||
async def get_switch(request: web.Request, conn):
|
async def get_switch(request: web.Request, conn):
|
||||||
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
||||||
@ -84,6 +107,7 @@ async def get_switch(request: web.Request, conn):
|
|||||||
}
|
}
|
||||||
return web.json_response(data)
|
return web.json_response(data)
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
@db_handler
|
||||||
async def get_switch_name(request: web.Request, conn):
|
async def get_switch_name(request: web.Request, conn):
|
||||||
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
||||||
@ -94,6 +118,7 @@ async def get_switch_name(request: web.Request, conn):
|
|||||||
members, stamp = await utils.get_fronters(conn, system.id)
|
members, stamp = await utils.get_fronters(conn, system.id)
|
||||||
return web.Response(text=members[0].name if members else "(nobody)")
|
return web.Response(text=members[0].name if members else "(nobody)")
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
@db_handler
|
||||||
async def get_switch_color(request: web.Request, conn):
|
async def get_switch_color(request: web.Request, conn):
|
||||||
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
||||||
@ -104,6 +129,34 @@ async def get_switch_color(request: web.Request, conn):
|
|||||||
members, stamp = await utils.get_fronters(conn, system.id)
|
members, stamp = await utils.get_fronters(conn, system.id)
|
||||||
return web.Response(text=members[0].color if members else "#ffffff")
|
return web.Response(text=members[0].color if members else "#ffffff")
|
||||||
|
|
||||||
|
|
||||||
|
@db_handler
|
||||||
|
@system_auth
|
||||||
|
async def put_switch(request: web.Request, system: System, conn):
|
||||||
|
try:
|
||||||
|
req = await request.json()
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise web.HTTPBadRequest(body="Invalid JSON")
|
||||||
|
|
||||||
|
if isinstance(req, str):
|
||||||
|
req = [req]
|
||||||
|
elif not isinstance(req, list):
|
||||||
|
raise web.HTTPBadRequest(body="Body must be JSON string or list")
|
||||||
|
|
||||||
|
members = []
|
||||||
|
for member_name in req:
|
||||||
|
if not isinstance(member_name, str):
|
||||||
|
raise web.HTTPBadRequest(body="List value must be string")
|
||||||
|
|
||||||
|
member = await Member.get_member_fuzzy(conn, system.id, member_name)
|
||||||
|
if not member:
|
||||||
|
raise web.HTTPBadRequest(body="Member '{}' not found".format(member_name))
|
||||||
|
members.append(member)
|
||||||
|
|
||||||
|
switch = await system.add_switch(conn, members)
|
||||||
|
return web.json_response(await switch.to_json(conn))
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
@db_handler
|
||||||
async def get_stats(request: web.Request, conn):
|
async def get_stats(request: web.Request, conn):
|
||||||
system_count = await db.system_count(conn)
|
system_count = await db.system_count(conn)
|
||||||
@ -116,11 +169,21 @@ async def get_stats(request: web.Request, conn):
|
|||||||
"messages": message_count
|
"messages": message_count
|
||||||
})
|
})
|
||||||
|
|
||||||
app = web.Application()
|
|
||||||
|
@web.middleware
|
||||||
|
async def render_pk_errors(request, handler):
|
||||||
|
try:
|
||||||
|
return await handler(request)
|
||||||
|
except PluralKitError as e:
|
||||||
|
raise web.HTTPBadRequest(body=e.message)
|
||||||
|
|
||||||
|
|
||||||
|
app = web.Application(middlewares=[render_pk_errors])
|
||||||
app.add_routes([
|
app.add_routes([
|
||||||
web.get("/systems/{id}", get_system),
|
web.get("/systems/{id}", get_system),
|
||||||
web.get("/systems/{id}/switches", get_switches),
|
web.get("/systems/{id}/switches", get_switches),
|
||||||
web.get("/systems/{id}/switch", get_switch),
|
web.get("/systems/{id}/switch", get_switch),
|
||||||
|
web.put("/systems/{id}/switch", put_switch),
|
||||||
web.get("/systems/{id}/switch/name", get_switch_name),
|
web.get("/systems/{id}/switch/name", get_switch_name),
|
||||||
web.get("/systems/{id}/switch/color", get_switch_color),
|
web.get("/systems/{id}/switch/color", get_switch_color),
|
||||||
web.get("/members/{id}", get_member),
|
web.get("/members/{id}", get_member),
|
||||||
|
@ -12,11 +12,10 @@ async def member_root(ctx: CommandContext):
|
|||||||
elif ctx.match("set"):
|
elif ctx.match("set"):
|
||||||
await member_set(ctx)
|
await member_set(ctx)
|
||||||
# TODO "pk;member list"
|
# TODO "pk;member list"
|
||||||
|
elif not ctx.has_next():
|
||||||
if not ctx.has_next():
|
|
||||||
raise CommandError("Must pass a subcommand. For a list of subcommands, type `pk;member help`.")
|
raise CommandError("Must pass a subcommand. For a list of subcommands, type `pk;member help`.")
|
||||||
|
else:
|
||||||
await specific_member_root(ctx)
|
await specific_member_root(ctx)
|
||||||
|
|
||||||
|
|
||||||
async def specific_member_root(ctx: CommandContext):
|
async def specific_member_root(ctx: CommandContext):
|
||||||
|
@ -33,19 +33,6 @@ async def switch_member(ctx: CommandContext):
|
|||||||
while ctx.has_next():
|
while ctx.has_next():
|
||||||
members.append(await ctx.pop_member())
|
members.append(await ctx.pop_member())
|
||||||
|
|
||||||
# Compare requested switch IDs and existing fronter IDs to check for existing switches
|
|
||||||
# Lists, because order matters, it makes sense to just swap fronters
|
|
||||||
member_ids = [member.id for member in members]
|
|
||||||
fronter_ids = (await pluralkit.utils.get_fronter_ids(ctx.conn, system.id))[0]
|
|
||||||
if member_ids == fronter_ids:
|
|
||||||
if len(members) == 1:
|
|
||||||
raise CommandError("{} is already fronting.".format(members[0].name))
|
|
||||||
raise CommandError("Members {} are already fronting.".format(", ".join([m.name for m in members])))
|
|
||||||
|
|
||||||
# Also make sure there aren't any duplicates
|
|
||||||
if len(set(member_ids)) != len(member_ids):
|
|
||||||
raise CommandError("Duplicate members in member list.")
|
|
||||||
|
|
||||||
# Log the switch
|
# Log the switch
|
||||||
await system.add_switch(ctx.conn, members)
|
await system.add_switch(ctx.conn, members)
|
||||||
|
|
||||||
|
@ -12,7 +12,8 @@ class PluralKitError(Exception):
|
|||||||
|
|
||||||
class ExistingSystemError(PluralKitError):
|
class ExistingSystemError(PluralKitError):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
|
super().__init__(
|
||||||
|
"You already have a system registered. To delete your system, use `pk;system delete`, or to unlink your system from this account, use `pk;system unlink`.")
|
||||||
|
|
||||||
|
|
||||||
class DescriptionTooLongError(PluralKitError):
|
class DescriptionTooLongError(PluralKitError):
|
||||||
@ -27,13 +28,16 @@ class TagTooLongError(PluralKitError):
|
|||||||
|
|
||||||
class TagTooLongWithMembersError(PluralKitError):
|
class TagTooLongWithMembersError(PluralKitError):
|
||||||
def __init__(self, member_names):
|
def __init__(self, member_names):
|
||||||
super().__init__("The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(", ".join(member_names)))
|
super().__init__(
|
||||||
|
"The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(
|
||||||
|
", ".join(member_names)))
|
||||||
self.member_names = member_names
|
self.member_names = member_names
|
||||||
|
|
||||||
|
|
||||||
class CustomEmojiError(PluralKitError):
|
class CustomEmojiError(PluralKitError):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("Due to a Discord limitation, custom emojis aren't supported. Please use a standard emoji instead.")
|
super().__init__(
|
||||||
|
"Due to a Discord limitation, custom emojis aren't supported. Please use a standard emoji instead.")
|
||||||
|
|
||||||
|
|
||||||
class InvalidAvatarURLError(PluralKitError):
|
class InvalidAvatarURLError(PluralKitError):
|
||||||
@ -60,7 +64,8 @@ class UnlinkingLastAccountError(PluralKitError):
|
|||||||
class MemberNameTooLongError(PluralKitError):
|
class MemberNameTooLongError(PluralKitError):
|
||||||
def __init__(self, tag_present: bool):
|
def __init__(self, tag_present: bool):
|
||||||
if tag_present:
|
if tag_present:
|
||||||
super().__init__("The maximum length of a name plus the system tag is 32 characters. Please reduce the length of the tag, or choose a shorter member name.")
|
super().__init__(
|
||||||
|
"The maximum length of a name plus the system tag is 32 characters. Please reduce the length of the tag, or choose a shorter member name.")
|
||||||
else:
|
else:
|
||||||
super().__init__("The maximum length of a member name is 32 characters.")
|
super().__init__("The maximum length of a member name is 32 characters.")
|
||||||
|
|
||||||
@ -69,6 +74,22 @@ class InvalidColorError(PluralKitError):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("Color must be a valid hex color. (eg. #ff0000)")
|
super().__init__("Color must be a valid hex color. (eg. #ff0000)")
|
||||||
|
|
||||||
|
|
||||||
class InvalidDateStringError(PluralKitError):
|
class InvalidDateStringError(PluralKitError):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("Invalid date string. Date must be in ISO-8601 format (YYYY-MM-DD, eg. 1999-07-25).")
|
super().__init__("Invalid date string. Date must be in ISO-8601 format (YYYY-MM-DD, eg. 1999-07-25).")
|
||||||
|
|
||||||
|
|
||||||
|
class MembersAlreadyFrontingError(PluralKitError):
|
||||||
|
def __init__(self, members: "List[Member]"):
|
||||||
|
if len(members) == 0:
|
||||||
|
super().__init__("There are already no members fronting.")
|
||||||
|
elif len(members) == 1:
|
||||||
|
super().__init__("Member {} is already fronting.".format(members[0].name))
|
||||||
|
else:
|
||||||
|
super().__init__("Members {} are already fronting.".format(", ".join([member.name for member in members])))
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateSwitchMembersError(PluralKitError):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("Duplicate members in member list.")
|
||||||
|
@ -54,6 +54,16 @@ class Member(namedtuple("Member",
|
|||||||
|
|
||||||
return member
|
return member
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_member_fuzzy(conn, system_id: int, name: str) -> "Optional[Member]":
|
||||||
|
by_hid = await Member.get_member_by_hid(conn, system_id, name)
|
||||||
|
if by_hid:
|
||||||
|
return by_hid
|
||||||
|
|
||||||
|
by_name = await Member.get_member_by_name(conn, system_id, name)
|
||||||
|
return by_name
|
||||||
|
|
||||||
|
|
||||||
async def set_name(self, conn, new_name: str):
|
async def set_name(self, conn, new_name: str):
|
||||||
"""
|
"""
|
||||||
Set the name of a member.
|
Set the name of a member.
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pluralkit import db
|
from pluralkit import db
|
||||||
@ -17,4 +16,10 @@ class Switch(namedtuple("Switch", ["id", "system", "timestamp", "members"])):
|
|||||||
return await db.get_members(conn, self.members)
|
return await db.get_members(conn, self.members)
|
||||||
|
|
||||||
async def delete(self, conn):
|
async def delete(self, conn):
|
||||||
await db.delete_switch(conn, self.id)
|
await db.delete_switch(conn, self.id)
|
||||||
|
|
||||||
|
async def to_json(self, conn):
|
||||||
|
return {
|
||||||
|
"timestamp": self.timestamp.isoformat(),
|
||||||
|
"members": [member.hid for member in await self.fetch_members(conn)]
|
||||||
|
}
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import string
|
import string
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from collections.__init__ import namedtuple
|
from collections.__init__ import namedtuple
|
||||||
|
from datetime import datetime
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from pluralkit import db, errors
|
from pluralkit import db, errors
|
||||||
@ -22,6 +21,10 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
|||||||
token: str
|
token: str
|
||||||
created: datetime
|
created: datetime
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_by_id(conn, system_id: int) -> Optional["System"]:
|
||||||
|
return await db.get_system(conn, system_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_by_account(conn, account_id: int) -> Optional["System"]:
|
async def get_by_account(conn, account_id: int) -> Optional["System"]:
|
||||||
return await db.get_system_by_account(conn, account_id)
|
return await db.get_system_by_account(conn, account_id)
|
||||||
@ -128,7 +131,29 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def add_switch(self, conn, members: List[Member]):
|
async def add_switch(self, conn, members: List[Member]) -> Switch:
|
||||||
|
"""
|
||||||
|
Logs a new switch for a system.
|
||||||
|
|
||||||
|
:raises: MembersAlreadyFrontingError, DuplicateSwitchMembersError
|
||||||
|
"""
|
||||||
|
new_ids = [member.id for member in members]
|
||||||
|
|
||||||
|
last_switch = await self.get_latest_switch(conn)
|
||||||
|
|
||||||
|
# If we have a switch logged before, make sure this isn't a dupe switch
|
||||||
|
if last_switch:
|
||||||
|
last_switch_members = await last_switch.fetch_members(conn)
|
||||||
|
last_ids = [member.id for member in last_switch_members]
|
||||||
|
|
||||||
|
# We don't compare by set() here because swapping multiple is a valid operation
|
||||||
|
if last_ids == new_ids:
|
||||||
|
raise errors.MembersAlreadyFrontingError(members)
|
||||||
|
|
||||||
|
# Check for dupes
|
||||||
|
if len(set(new_ids)) != len(new_ids):
|
||||||
|
raise errors.DuplicateSwitchMembersError()
|
||||||
|
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
switch_id = await db.add_switch(conn, self.id)
|
switch_id = await db.add_switch(conn, self.id)
|
||||||
|
|
||||||
@ -136,6 +161,8 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
|||||||
for member in members:
|
for member in members:
|
||||||
await db.add_switch_member(conn, switch_id, member.id)
|
await db.add_switch_member(conn, switch_id, member.id)
|
||||||
|
|
||||||
|
return await self.get_latest_switch(conn)
|
||||||
|
|
||||||
def get_member_name_limit(self) -> int:
|
def get_member_name_limit(self) -> int:
|
||||||
"""Returns the maximum length a member's name or nickname is allowed to be in order for the member to be proxied. Depends on the system tag."""
|
"""Returns the maximum length a member's name or nickname is allowed to be in order for the member to be proxied. Depends on the system tag."""
|
||||||
if self.tag:
|
if self.tag:
|
||||||
|
Loading…
Reference in New Issue
Block a user