diff --git a/src/api_main.py b/src/api_main.py index e1cf1f1f..39ebc68d 100644 --- a/src/api_main.py +++ b/src/api_main.py @@ -12,187 +12,175 @@ from pluralkit.system import System logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") logger = logging.getLogger("pluralkit.api") - -def db_handler(f): - async def inner(request, *args, **kwargs): - async with request.app["pool"].acquire() as conn: - return await f(request, conn=conn, *args, **kwargs) - - return inner - - -def system_auth(f): - async def inner(request: web.Request, conn, *args, **kwargs): - token = request.headers.get("X-Token") - if not token: - token = request.query.get("token") - - if not token: +def require_system(f): + async def inner(request): + if "system" not in request: raise web.HTTPUnauthorized() - - system = await System.get_by_token(conn, token) - if not system: - raise web.HTTPUnauthorized() - - return await f(request, conn=conn, system=system, *args, **kwargs) - + return await f(request) return inner - -@db_handler -async def get_system(request: web.Request, conn): - system = await db.get_system_by_hid(conn, request.match_info["id"]) - - if not system: - raise web.HTTPNotFound() - - members = await db.get_all_members(conn, system.id) - - system_json = system.to_json() - system_json["members"] = [member.to_json() for member in members] - return web.json_response(system_json) - - -@db_handler -async def get_member(request: web.Request, conn): - member = await db.get_member_by_hid(conn, request.match_info["id"]) - - if not member: - raise web.HTTPNotFound() - - return web.json_response(member.to_json()) - - -@db_handler -async def get_switches(request: web.Request, conn): - system = await db.get_system_by_hid(conn, request.match_info["id"]) - - if not system: - raise web.HTTPNotFound() - - switches = await utils.get_front_history(conn, system.id, 99999) - - data = [{ - "timestamp": stamp.isoformat(), - "members": [member.hid for member in members] - } for stamp, members in switches] - - return web.json_response(data) - - -@db_handler -async def get_message(request: web.Request, conn): - message = await db.get_message(conn, request.match_info["id"]) - if not message: - raise web.HTTPNotFound() - - return web.json_response(message.to_json()) - - -@db_handler -async def get_switch(request: web.Request, conn): - system = await db.get_system_by_hid(conn, request.match_info["id"]) - - if not system: - raise web.HTTPNotFound() - - members, stamp = await utils.get_fronters(conn, system.id) - if not stamp: - # No switch has been registered at all - raise web.HTTPNotFound() - - data = { - "timestamp": stamp.isoformat(), - "members": [member.to_json() for member in members] - } - return web.json_response(data) - - -@db_handler -async def get_switch_name(request: web.Request, conn): - system = await db.get_system_by_hid(conn, request.match_info["id"]) - - if not system: - raise web.HTTPNotFound() - - members, stamp = await utils.get_fronters(conn, system.id) - return web.Response(text=members[0].name if members else "(nobody)") - - -@db_handler -async def get_switch_color(request: web.Request, conn): - system = await db.get_system_by_hid(conn, request.match_info["id"]) - - if not system: - raise web.HTTPNotFound() - - members, stamp = await utils.get_fronters(conn, system.id) - return web.Response(text=members[0].color if members else "#ffffff") - - -@db_handler -@system_auth -async def put_switch(request: web.Request, system: System, conn): - try: - req = await request.json() - except json.JSONDecodeError: - raise web.HTTPBadRequest(body="Invalid JSON") - - if isinstance(req, str): - req = [req] - elif not isinstance(req, list): - raise web.HTTPBadRequest(body="Body must be JSON string or list") - - members = [] - for member_name in req: - if not isinstance(member_name, str): - raise web.HTTPBadRequest(body="List value must be string") - - member = await Member.get_member_fuzzy(conn, system.id, member_name) - if not member: - raise web.HTTPBadRequest(body="Member '{}' not found".format(member_name)) - members.append(member) - - switch = await system.add_switch(conn, members) - return web.json_response(await switch.to_json(conn)) - - -@db_handler -async def get_stats(request: web.Request, conn): - system_count = await db.system_count(conn) - member_count = await db.member_count(conn) - message_count = await db.message_count(conn) - - return web.json_response({ - "systems": system_count, - "members": member_count, - "messages": message_count - }) - - @web.middleware -async def render_pk_errors(request, handler): +async def error_middleware(request, handler): try: return await handler(request) + except json.JSONDecodeError: + raise web.HTTPBadRequest() except PluralKitError as e: - raise web.HTTPBadRequest(body=e.message) + return web.json_response({"error": e.message}, status=400) +@web.middleware +async def db_middleware(request, handler): + async with request.app["pool"].acquire() as conn: + request["conn"] = conn + return await handler(request) -app = web.Application(middlewares=[render_pk_errors]) -app.add_routes([ - web.get("/systems/{id}", get_system), - web.get("/systems/{id}/switches", get_switches), - web.get("/systems/{id}/switch", get_switch), - web.put("/systems/{id}/switch", put_switch), - web.get("/systems/{id}/switch/name", get_switch_name), - web.get("/systems/{id}/switch/color", get_switch_color), - web.get("/members/{id}", get_member), - web.get("/messages/{id}", get_message), - web.get("/stats", get_stats) -]) +@web.middleware +async def auth_middleware(request, handler): + token = request.headers.get("X-Token") or request.query.get("token") + if token: + system = await System.get_by_token(request["conn"], token) + if system: + request["system"] = system + return await handler(request) +class Handlers: + @require_system + async def get_system(request): + return web.json_response(request["system"].to_json()) + + async def get_other_system(request): + system_id = request.match_info.get("system") + system = await System.get_by_hid(request["conn"], system_id) + if not system: + raise web.HTTPNotFound() + return web.json_response(system.to_json()) + + async def get_system_members(request): + system_id = request.match_info.get("system") + system = await System.get_by_hid(request["conn"], system_id) + if not system: + raise web.HTTPNotFound() + + members = await system.get_members(request["conn"]) + return web.json_response([m.to_json() for m in members]) + + async def get_system_switches(request): + system_id = request.match_info.get("system") + system = await System.get_by_hid(request["conn"], system_id) + if not system: + raise web.HTTPNotFound() + + switches = await system.get_switches(request["conn"], 9999) + + cache = {} + async def hid_getter(member_id): + if not member_id in cache: + cache[member_id] = await Member.get_member_by_id(request["conn"], member_id) + return cache[member_id].hid + + return web.json_response([await s.to_json(hid_getter) for s in switches]) + + @require_system + async def patch_system(request): + req = await request.json() + if "name" in req: + await request["system"].set_name(request["conn"], req["name"]) + if "description" in req: + await request["system"].set_description(request["conn"], req["description"]) + if "tag" in req: + await request["system"].set_tag(request["conn"], req["tag"]) + if "avatar_url" in req: + await request["system"].set_avatar(request["conn"], req["name"]) + if "tz" in req: + await request["system"].set_time_zone(request["conn"], req["tz"]) + return web.json_response((await System.get_by_id(request["conn"], request["system"].id)).to_json()) + + async def get_member(request): + member_id = request.match_info.get("member") + member = await Member.get_member_by_hid(request["conn"], None, member_id) + if not member: + raise web.HTTPNotFound() + return web.json_response(member.to_json()) + + @require_system + async def post_member(request): + member = await request["system"].create_member(request["conn"]) + return web.json_response(member.to_json()) + + @require_system + async def patch_member(request): + member_id = request.match_info.get("member") + member = await Member.get_member_by_hid(request["conn"], None, member_id) + if not member: + raise web.HTTPNotFound() + if member.system != request["system"].id: + raise web.HTTPUnauthorized() + + req = await request.json() + if "name" in req: + await member.set_name(request["conn"], req["name"]) + if "description" in req: + await member.set_description(request["conn"], req["description"]) + if "avatar_url" in req: + await member.set_avatar_url(request["conn"], req["avatar_url"]) + if "color" in req: + await member.set_color(request["conn"], req["color"]) + if "birthday" in req: + await member.set_birthdate(request["conn"], req["birthday"]) + if "pronouns" in req: + await member.set_pronouns(request["conn"], req["pronouns"]) + if "prefix" in req or "suffix" in req: + await member.set_proxy_tags(request["conn"], req.get("prefix", member.prefix), req.get("suffix", member.suffix)) + return web.json_response((await Member.get_member_by_id(request["conn"], member.id)).to_json()) + + @require_system + async def delete_member(request): + member_id = request.match_info.get("member") + member = await Member.get_member_by_hid(request["conn"], None, member_id) + if not member: + raise web.HTTPNotFound() + if member.system != request["system"].id: + raise web.HTTPUnauthorized() + + await member.delete(request["conn"]) + + @require_system + async def post_switch(request): + req = await request.json() + if isinstance(req, str): + req = [req] + if req is None: + req = [] + if not isinstance(req, list): + raise web.HTTPBadRequest() + + members = [await Member.get_member_by_hid(request["conn"], request["system"].id, hid) for hid in req] + if not all(members): + raise web.HTTPNotFound(body=json.dumps({"error": "One or more members not found."})) + + switch = await request["system"].add_switch(request["conn"], members) + + hids = {member.id: member.hid for member in members} + async def hid_getter(mid): + return hids[mid] + + return web.json_response(await switch.to_json(hid_getter)) async def run(): + app = web.Application(middlewares=[db_middleware, auth_middleware, error_middleware]) + + app.add_routes([ + web.get("/s", Handlers.get_system), + web.post("/s/switches", Handlers.post_switch), + web.get("/s/{system}", Handlers.get_other_system), + web.get("/s/{system}/members", Handlers.get_system_members), + web.get("/s/{system}/switches", Handlers.get_system_switches), + web.patch("/s", Handlers.patch_system), + web.get("/m/{member}", Handlers.get_member), + web.post("/m", Handlers.post_member), + web.patch("/m/{member}", Handlers.patch_member), + web.delete("/m/{member}", Handlers.delete_member) + ]) app["pool"] = await db.connect( os.environ["DATABASE_URI"] ) diff --git a/src/pluralkit/bot/__init__.py b/src/pluralkit/bot/__init__.py index 357d1a9d..434ebe4f 100644 --- a/src/pluralkit/bot/__init__.py +++ b/src/pluralkit/bot/__init__.py @@ -14,13 +14,19 @@ from pluralkit.bot import commands, proxy, channel_logger, embeds logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") -class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])): +class Config: required_fields = ["database_uri", "token"] + fields = ["database_uri", "token", "log_channel"] database_uri: str token: str log_channel: str + def __init__(self, database_uri: str, token: str, log_channel: str = None): + self.database_uri = database_uri + self.token = token + self.log_channel = log_channel + @staticmethod def from_file_and_env(filename: str) -> "Config": try: @@ -36,7 +42,7 @@ class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])): raise e # Override with environment variables - for f in Config._fields: + for f in Config.fields: if f.upper() in os.environ: config[f] = os.environ[f.upper()] diff --git a/src/pluralkit/member.py b/src/pluralkit/member.py index 442c4d63..e058b39f 100644 --- a/src/pluralkit/member.py +++ b/src/pluralkit/member.py @@ -38,6 +38,11 @@ class Member(namedtuple("Member", "suffix": self.suffix } + @staticmethod + async def get_member_by_id(conn, member_id: int) -> Optional["Member"]: + """Fetch a member with the given internal member ID from the database.""" + return await db.get_member(conn, member_id) + @staticmethod async def get_member_by_name(conn, system_id: int, member_name: str) -> "Optional[Member]": """Fetch a member by the given name in the given system from the database.""" diff --git a/src/pluralkit/switch.py b/src/pluralkit/switch.py index 86d6b539..7536fa63 100644 --- a/src/pluralkit/switch.py +++ b/src/pluralkit/switch.py @@ -21,8 +21,8 @@ class Switch(namedtuple("Switch", ["id", "system", "timestamp", "members"])): async def move(self, conn, new_timestamp): await db.move_switch(conn, self.system, self.id, new_timestamp) - async def to_json(self, conn): + async def to_json(self, hid_getter): return { "timestamp": self.timestamp.isoformat(), - "members": [member.hid for member in await self.fetch_members(conn)] + "members": [await hid_getter(m) for m in self.members] } diff --git a/src/pluralkit/system.py b/src/pluralkit/system.py index bfb70dd5..a0ea39b8 100644 --- a/src/pluralkit/system.py +++ b/src/pluralkit/system.py @@ -38,6 +38,10 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a @staticmethod async def get_by_token(conn, token: str) -> Optional["System"]: return await db.get_system_by_token(conn, token) + + @staticmethod + async def get_by_hid(conn, hid: str) -> Optional["System"]: + return await db.get_system_by_hid(conn, hid) @staticmethod async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System": @@ -234,7 +238,11 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a :returns: The `pytz.tzinfo` instance of the newly set time zone. """ - tz = pytz.timezone(tz_name or "UTC") + try: + tz = pytz.timezone(tz_name or "UTC") + except pytz.UnknownTimeZoneError: + raise errors.InvalidTimeZoneError(tz_name) + await db.update_system_field(conn, self.id, "ui_tz", tz.zone) return tz @@ -304,5 +312,6 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a "name": self.name, "description": self.description, "tag": self.tag, - "avatar_url": self.avatar_url + "avatar_url": self.avatar_url, + "tz": self.ui_tz }