Rework API
This commit is contained in:
		
							
								
								
									
										326
									
								
								src/api_main.py
									
									
									
									
									
								
							
							
						
						
									
										326
									
								
								src/api_main.py
									
									
									
									
									
								
							| @@ -12,187 +12,175 @@ from pluralkit.system import System | |||||||
| logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") | logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") | ||||||
| logger = logging.getLogger("pluralkit.api") | logger = logging.getLogger("pluralkit.api") | ||||||
|  |  | ||||||
|  | def require_system(f): | ||||||
| def db_handler(f): |     async def inner(request): | ||||||
|     async def inner(request, *args, **kwargs): |         if "system" not in request: | ||||||
|         async with request.app["pool"].acquire() as conn: |  | ||||||
|             return await f(request, conn=conn, *args, **kwargs) |  | ||||||
|  |  | ||||||
|     return inner |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def system_auth(f): |  | ||||||
|     async def inner(request: web.Request, conn, *args, **kwargs): |  | ||||||
|         token = request.headers.get("X-Token") |  | ||||||
|         if not token: |  | ||||||
|             token = request.query.get("token") |  | ||||||
|  |  | ||||||
|         if not token: |  | ||||||
|             raise web.HTTPUnauthorized() |             raise web.HTTPUnauthorized() | ||||||
|  |         return await f(request) | ||||||
|         system = await System.get_by_token(conn, token) |  | ||||||
|         if not system: |  | ||||||
|             raise web.HTTPUnauthorized() |  | ||||||
|  |  | ||||||
|         return await f(request, conn=conn, system=system, *args, **kwargs) |  | ||||||
|  |  | ||||||
|     return inner |     return inner | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_handler |  | ||||||
| async def get_system(request: web.Request, conn): |  | ||||||
|     system = await db.get_system_by_hid(conn, request.match_info["id"]) |  | ||||||
|  |  | ||||||
|     if not system: |  | ||||||
|         raise web.HTTPNotFound() |  | ||||||
|  |  | ||||||
|     members = await db.get_all_members(conn, system.id) |  | ||||||
|  |  | ||||||
|     system_json = system.to_json() |  | ||||||
|     system_json["members"] = [member.to_json() for member in members] |  | ||||||
|     return web.json_response(system_json) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_handler |  | ||||||
| async def get_member(request: web.Request, conn): |  | ||||||
|     member = await db.get_member_by_hid(conn, request.match_info["id"]) |  | ||||||
|  |  | ||||||
|     if not member: |  | ||||||
|         raise web.HTTPNotFound() |  | ||||||
|  |  | ||||||
|     return web.json_response(member.to_json()) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_handler |  | ||||||
| async def get_switches(request: web.Request, conn): |  | ||||||
|     system = await db.get_system_by_hid(conn, request.match_info["id"]) |  | ||||||
|  |  | ||||||
|     if not system: |  | ||||||
|         raise web.HTTPNotFound() |  | ||||||
|  |  | ||||||
|     switches = await utils.get_front_history(conn, system.id, 99999) |  | ||||||
|  |  | ||||||
|     data = [{ |  | ||||||
|         "timestamp": stamp.isoformat(), |  | ||||||
|         "members": [member.hid for member in members] |  | ||||||
|     } for stamp, members in switches] |  | ||||||
|  |  | ||||||
|     return web.json_response(data) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_handler |  | ||||||
| async def get_message(request: web.Request, conn): |  | ||||||
|     message = await db.get_message(conn, request.match_info["id"]) |  | ||||||
|     if not message: |  | ||||||
|         raise web.HTTPNotFound() |  | ||||||
|  |  | ||||||
|     return web.json_response(message.to_json()) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_handler |  | ||||||
| async def get_switch(request: web.Request, conn): |  | ||||||
|     system = await db.get_system_by_hid(conn, request.match_info["id"]) |  | ||||||
|  |  | ||||||
|     if not system: |  | ||||||
|         raise web.HTTPNotFound() |  | ||||||
|  |  | ||||||
|     members, stamp = await utils.get_fronters(conn, system.id) |  | ||||||
|     if not stamp: |  | ||||||
|         # No switch has been registered at all |  | ||||||
|         raise web.HTTPNotFound() |  | ||||||
|  |  | ||||||
|     data = { |  | ||||||
|         "timestamp": stamp.isoformat(), |  | ||||||
|         "members": [member.to_json() for member in members] |  | ||||||
|     } |  | ||||||
|     return web.json_response(data) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_handler |  | ||||||
| async def get_switch_name(request: web.Request, conn): |  | ||||||
|     system = await db.get_system_by_hid(conn, request.match_info["id"]) |  | ||||||
|  |  | ||||||
|     if not system: |  | ||||||
|         raise web.HTTPNotFound() |  | ||||||
|  |  | ||||||
|     members, stamp = await utils.get_fronters(conn, system.id) |  | ||||||
|     return web.Response(text=members[0].name if members else "(nobody)") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_handler |  | ||||||
| async def get_switch_color(request: web.Request, conn): |  | ||||||
|     system = await db.get_system_by_hid(conn, request.match_info["id"]) |  | ||||||
|  |  | ||||||
|     if not system: |  | ||||||
|         raise web.HTTPNotFound() |  | ||||||
|  |  | ||||||
|     members, stamp = await utils.get_fronters(conn, system.id) |  | ||||||
|     return web.Response(text=members[0].color if members else "#ffffff") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_handler |  | ||||||
| @system_auth |  | ||||||
| async def put_switch(request: web.Request, system: System, conn): |  | ||||||
|     try: |  | ||||||
|         req = await request.json() |  | ||||||
|     except json.JSONDecodeError: |  | ||||||
|         raise web.HTTPBadRequest(body="Invalid JSON") |  | ||||||
|  |  | ||||||
|     if isinstance(req, str): |  | ||||||
|         req = [req] |  | ||||||
|     elif not isinstance(req, list): |  | ||||||
|         raise web.HTTPBadRequest(body="Body must be JSON string or list") |  | ||||||
|  |  | ||||||
|     members = [] |  | ||||||
|     for member_name in req: |  | ||||||
|         if not isinstance(member_name, str): |  | ||||||
|             raise web.HTTPBadRequest(body="List value must be string") |  | ||||||
|  |  | ||||||
|         member = await Member.get_member_fuzzy(conn, system.id, member_name) |  | ||||||
|         if not member: |  | ||||||
|             raise web.HTTPBadRequest(body="Member '{}' not found".format(member_name)) |  | ||||||
|         members.append(member) |  | ||||||
|  |  | ||||||
|     switch = await system.add_switch(conn, members) |  | ||||||
|     return web.json_response(await switch.to_json(conn)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @db_handler |  | ||||||
| async def get_stats(request: web.Request, conn): |  | ||||||
|     system_count = await db.system_count(conn) |  | ||||||
|     member_count = await db.member_count(conn) |  | ||||||
|     message_count = await db.message_count(conn) |  | ||||||
|  |  | ||||||
|     return web.json_response({ |  | ||||||
|         "systems": system_count, |  | ||||||
|         "members": member_count, |  | ||||||
|         "messages": message_count |  | ||||||
|     }) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @web.middleware | @web.middleware | ||||||
| async def render_pk_errors(request, handler): | async def error_middleware(request, handler): | ||||||
|     try: |     try: | ||||||
|         return await handler(request) |         return await handler(request) | ||||||
|  |     except json.JSONDecodeError: | ||||||
|  |         raise web.HTTPBadRequest() | ||||||
|     except PluralKitError as e: |     except PluralKitError as e: | ||||||
|         raise web.HTTPBadRequest(body=e.message) |         return web.json_response({"error": e.message}, status=400) | ||||||
|  |  | ||||||
|  | @web.middleware | ||||||
|  | async def db_middleware(request, handler): | ||||||
|  |     async with request.app["pool"].acquire() as conn: | ||||||
|  |         request["conn"] = conn | ||||||
|  |         return await handler(request) | ||||||
|  |  | ||||||
| app = web.Application(middlewares=[render_pk_errors]) | @web.middleware | ||||||
| app.add_routes([ | async def auth_middleware(request, handler): | ||||||
|     web.get("/systems/{id}", get_system), |     token = request.headers.get("X-Token") or request.query.get("token") | ||||||
|     web.get("/systems/{id}/switches", get_switches), |     if token: | ||||||
|     web.get("/systems/{id}/switch", get_switch), |         system = await System.get_by_token(request["conn"], token) | ||||||
|     web.put("/systems/{id}/switch", put_switch), |         if system: | ||||||
|     web.get("/systems/{id}/switch/name", get_switch_name), |             request["system"] = system | ||||||
|     web.get("/systems/{id}/switch/color", get_switch_color), |     return await handler(request) | ||||||
|     web.get("/members/{id}", get_member), |  | ||||||
|     web.get("/messages/{id}", get_message), |  | ||||||
|     web.get("/stats", get_stats) |  | ||||||
| ]) |  | ||||||
|  |  | ||||||
|  | class Handlers: | ||||||
|  |     @require_system | ||||||
|  |     async def get_system(request): | ||||||
|  |         return web.json_response(request["system"].to_json()) | ||||||
|  |      | ||||||
|  |     async def get_other_system(request): | ||||||
|  |         system_id = request.match_info.get("system") | ||||||
|  |         system = await System.get_by_hid(request["conn"], system_id) | ||||||
|  |         if not system: | ||||||
|  |             raise web.HTTPNotFound() | ||||||
|  |         return web.json_response(system.to_json()) | ||||||
|  |  | ||||||
|  |     async def get_system_members(request): | ||||||
|  |         system_id = request.match_info.get("system") | ||||||
|  |         system = await System.get_by_hid(request["conn"], system_id) | ||||||
|  |         if not system: | ||||||
|  |             raise web.HTTPNotFound() | ||||||
|  |  | ||||||
|  |         members = await system.get_members(request["conn"]) | ||||||
|  |         return web.json_response([m.to_json() for m in members]) | ||||||
|  |  | ||||||
|  |     async def get_system_switches(request): | ||||||
|  |         system_id = request.match_info.get("system") | ||||||
|  |         system = await System.get_by_hid(request["conn"], system_id) | ||||||
|  |         if not system: | ||||||
|  |             raise web.HTTPNotFound() | ||||||
|  |  | ||||||
|  |         switches = await system.get_switches(request["conn"], 9999) | ||||||
|  |  | ||||||
|  |         cache = {} | ||||||
|  |         async def hid_getter(member_id): | ||||||
|  |             if not member_id in cache: | ||||||
|  |                 cache[member_id] = await Member.get_member_by_id(request["conn"], member_id) | ||||||
|  |             return cache[member_id].hid | ||||||
|  |  | ||||||
|  |         return web.json_response([await s.to_json(hid_getter) for s in switches]) | ||||||
|  |  | ||||||
|  |     @require_system | ||||||
|  |     async def patch_system(request): | ||||||
|  |         req = await request.json() | ||||||
|  |         if "name" in req: | ||||||
|  |             await request["system"].set_name(request["conn"], req["name"]) | ||||||
|  |         if "description" in req: | ||||||
|  |             await request["system"].set_description(request["conn"], req["description"]) | ||||||
|  |         if "tag" in req: | ||||||
|  |             await request["system"].set_tag(request["conn"], req["tag"]) | ||||||
|  |         if "avatar_url" in req: | ||||||
|  |             await request["system"].set_avatar(request["conn"], req["name"]) | ||||||
|  |         if "tz" in req: | ||||||
|  |             await request["system"].set_time_zone(request["conn"], req["tz"]) | ||||||
|  |         return web.json_response((await System.get_by_id(request["conn"], request["system"].id)).to_json()) | ||||||
|  |  | ||||||
|  |     async def get_member(request): | ||||||
|  |         member_id = request.match_info.get("member") | ||||||
|  |         member = await Member.get_member_by_hid(request["conn"], None, member_id) | ||||||
|  |         if not member: | ||||||
|  |             raise web.HTTPNotFound() | ||||||
|  |         return web.json_response(member.to_json()) | ||||||
|  |  | ||||||
|  |     @require_system | ||||||
|  |     async def post_member(request): | ||||||
|  |         member = await request["system"].create_member(request["conn"]) | ||||||
|  |         return web.json_response(member.to_json()) | ||||||
|  |  | ||||||
|  |     @require_system | ||||||
|  |     async def patch_member(request): | ||||||
|  |         member_id = request.match_info.get("member") | ||||||
|  |         member = await Member.get_member_by_hid(request["conn"], None, member_id) | ||||||
|  |         if not member: | ||||||
|  |             raise web.HTTPNotFound() | ||||||
|  |         if member.system != request["system"].id: | ||||||
|  |             raise web.HTTPUnauthorized() | ||||||
|  |          | ||||||
|  |         req = await request.json() | ||||||
|  |         if "name" in req: | ||||||
|  |             await member.set_name(request["conn"], req["name"]) | ||||||
|  |         if "description" in req: | ||||||
|  |             await member.set_description(request["conn"], req["description"]) | ||||||
|  |         if "avatar_url" in req: | ||||||
|  |             await member.set_avatar_url(request["conn"], req["avatar_url"]) | ||||||
|  |         if "color" in req: | ||||||
|  |             await member.set_color(request["conn"], req["color"]) | ||||||
|  |         if "birthday" in req: | ||||||
|  |             await member.set_birthdate(request["conn"], req["birthday"]) | ||||||
|  |         if "pronouns" in req: | ||||||
|  |             await member.set_pronouns(request["conn"], req["pronouns"]) | ||||||
|  |         if "prefix" in req or "suffix" in req: | ||||||
|  |             await member.set_proxy_tags(request["conn"], req.get("prefix", member.prefix), req.get("suffix", member.suffix)) | ||||||
|  |         return web.json_response((await Member.get_member_by_id(request["conn"], member.id)).to_json()) | ||||||
|  |      | ||||||
|  |     @require_system | ||||||
|  |     async def delete_member(request): | ||||||
|  |         member_id = request.match_info.get("member") | ||||||
|  |         member = await Member.get_member_by_hid(request["conn"], None, member_id) | ||||||
|  |         if not member: | ||||||
|  |             raise web.HTTPNotFound() | ||||||
|  |         if member.system != request["system"].id: | ||||||
|  |             raise web.HTTPUnauthorized() | ||||||
|  |  | ||||||
|  |         await member.delete(request["conn"]) | ||||||
|  |  | ||||||
|  |     @require_system | ||||||
|  |     async def post_switch(request): | ||||||
|  |         req = await request.json() | ||||||
|  |         if isinstance(req, str): | ||||||
|  |             req = [req] | ||||||
|  |         if req is None: | ||||||
|  |             req = [] | ||||||
|  |         if not isinstance(req, list): | ||||||
|  |             raise web.HTTPBadRequest() | ||||||
|  |  | ||||||
|  |         members = [await Member.get_member_by_hid(request["conn"], request["system"].id, hid) for hid in req] | ||||||
|  |         if not all(members): | ||||||
|  |             raise web.HTTPNotFound(body=json.dumps({"error": "One or more members not found."})) | ||||||
|  |  | ||||||
|  |         switch = await request["system"].add_switch(request["conn"], members) | ||||||
|  |  | ||||||
|  |         hids = {member.id: member.hid for member in members} | ||||||
|  |         async def hid_getter(mid): | ||||||
|  |             return hids[mid] | ||||||
|  |  | ||||||
|  |         return web.json_response(await switch.to_json(hid_getter)) | ||||||
|  |  | ||||||
| async def run(): | async def run(): | ||||||
|  |     app = web.Application(middlewares=[db_middleware, auth_middleware, error_middleware]) | ||||||
|  |  | ||||||
|  |     app.add_routes([ | ||||||
|  |         web.get("/s", Handlers.get_system), | ||||||
|  |         web.post("/s/switches", Handlers.post_switch), | ||||||
|  |         web.get("/s/{system}", Handlers.get_other_system), | ||||||
|  |         web.get("/s/{system}/members", Handlers.get_system_members), | ||||||
|  |         web.get("/s/{system}/switches", Handlers.get_system_switches), | ||||||
|  |         web.patch("/s", Handlers.patch_system), | ||||||
|  |         web.get("/m/{member}", Handlers.get_member), | ||||||
|  |         web.post("/m", Handlers.post_member), | ||||||
|  |         web.patch("/m/{member}", Handlers.patch_member), | ||||||
|  |         web.delete("/m/{member}", Handlers.delete_member) | ||||||
|  |     ]) | ||||||
|     app["pool"] = await db.connect( |     app["pool"] = await db.connect( | ||||||
|         os.environ["DATABASE_URI"] |         os.environ["DATABASE_URI"] | ||||||
|     ) |     ) | ||||||
|   | |||||||
| @@ -14,13 +14,19 @@ from pluralkit.bot import commands, proxy, channel_logger, embeds | |||||||
|  |  | ||||||
| logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") | logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") | ||||||
|  |  | ||||||
| class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])): | class Config: | ||||||
|     required_fields = ["database_uri", "token"] |     required_fields = ["database_uri", "token"] | ||||||
|  |     fields = ["database_uri", "token", "log_channel"] | ||||||
|  |  | ||||||
|     database_uri: str |     database_uri: str | ||||||
|     token: str |     token: str | ||||||
|     log_channel: str |     log_channel: str | ||||||
|  |  | ||||||
|  |     def __init__(self, database_uri: str, token: str, log_channel: str = None): | ||||||
|  |         self.database_uri = database_uri | ||||||
|  |         self.token = token | ||||||
|  |         self.log_channel = log_channel | ||||||
|  |          | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def from_file_and_env(filename: str) -> "Config": |     def from_file_and_env(filename: str) -> "Config": | ||||||
|         try: |         try: | ||||||
| @@ -36,7 +42,7 @@ class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])): | |||||||
|                 raise e |                 raise e | ||||||
|  |  | ||||||
|         # Override with environment variables |         # Override with environment variables | ||||||
|         for f in Config._fields: |         for f in Config.fields: | ||||||
|             if f.upper() in os.environ: |             if f.upper() in os.environ: | ||||||
|                 config[f] = os.environ[f.upper()] |                 config[f] = os.environ[f.upper()] | ||||||
|  |  | ||||||
|   | |||||||
| @@ -38,6 +38,11 @@ class Member(namedtuple("Member", | |||||||
|             "suffix": self.suffix |             "suffix": self.suffix | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     async def get_member_by_id(conn, member_id: int) -> Optional["Member"]: | ||||||
|  |         """Fetch a member with the given internal member ID from the database.""" | ||||||
|  |         return await db.get_member(conn, member_id) | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     async def get_member_by_name(conn, system_id: int, member_name: str) -> "Optional[Member]": |     async def get_member_by_name(conn, system_id: int, member_name: str) -> "Optional[Member]": | ||||||
|         """Fetch a member by the given name in the given system from the database.""" |         """Fetch a member by the given name in the given system from the database.""" | ||||||
|   | |||||||
| @@ -21,8 +21,8 @@ class Switch(namedtuple("Switch", ["id", "system", "timestamp", "members"])): | |||||||
|     async def move(self, conn, new_timestamp): |     async def move(self, conn, new_timestamp): | ||||||
|         await db.move_switch(conn, self.system, self.id, new_timestamp) |         await db.move_switch(conn, self.system, self.id, new_timestamp) | ||||||
|  |  | ||||||
|     async def to_json(self, conn): |     async def to_json(self, hid_getter): | ||||||
|         return { |         return { | ||||||
|             "timestamp": self.timestamp.isoformat(), |             "timestamp": self.timestamp.isoformat(), | ||||||
|             "members": [member.hid for member in await self.fetch_members(conn)] |             "members": [await hid_getter(m) for m in self.members] | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -38,6 +38,10 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a | |||||||
|     @staticmethod |     @staticmethod | ||||||
|     async def get_by_token(conn, token: str) -> Optional["System"]: |     async def get_by_token(conn, token: str) -> Optional["System"]: | ||||||
|         return await db.get_system_by_token(conn, token) |         return await db.get_system_by_token(conn, token) | ||||||
|  |      | ||||||
|  |     @staticmethod | ||||||
|  |     async def get_by_hid(conn, hid: str) -> Optional["System"]: | ||||||
|  |         return await db.get_system_by_hid(conn, hid) | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System": |     async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System": | ||||||
| @@ -234,7 +238,11 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a | |||||||
|         :returns: The `pytz.tzinfo` instance of the newly set time zone. |         :returns: The `pytz.tzinfo` instance of the newly set time zone. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         tz = pytz.timezone(tz_name or "UTC") |         try: | ||||||
|  |             tz = pytz.timezone(tz_name or "UTC") | ||||||
|  |         except pytz.UnknownTimeZoneError: | ||||||
|  |             raise errors.InvalidTimeZoneError(tz_name) | ||||||
|  |  | ||||||
|         await db.update_system_field(conn, self.id, "ui_tz", tz.zone) |         await db.update_system_field(conn, self.id, "ui_tz", tz.zone) | ||||||
|         return tz |         return tz | ||||||
|  |  | ||||||
| @@ -304,5 +312,6 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a | |||||||
|             "name": self.name, |             "name": self.name, | ||||||
|             "description": self.description, |             "description": self.description, | ||||||
|             "tag": self.tag, |             "tag": self.tag, | ||||||
|             "avatar_url": self.avatar_url |             "avatar_url": self.avatar_url, | ||||||
|  |             "tz": self.ui_tz | ||||||
|         } |         } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user