Rework API
This commit is contained in:
parent
ac911b170d
commit
47187138b6
326
src/api_main.py
326
src/api_main.py
@ -12,187 +12,175 @@ from pluralkit.system import System
|
|||||||
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
||||||
logger = logging.getLogger("pluralkit.api")
|
logger = logging.getLogger("pluralkit.api")
|
||||||
|
|
||||||
|
def require_system(f):
|
||||||
def db_handler(f):
|
async def inner(request):
|
||||||
async def inner(request, *args, **kwargs):
|
if "system" not in request:
|
||||||
async with request.app["pool"].acquire() as conn:
|
|
||||||
return await f(request, conn=conn, *args, **kwargs)
|
|
||||||
|
|
||||||
return inner
|
|
||||||
|
|
||||||
|
|
||||||
def system_auth(f):
|
|
||||||
async def inner(request: web.Request, conn, *args, **kwargs):
|
|
||||||
token = request.headers.get("X-Token")
|
|
||||||
if not token:
|
|
||||||
token = request.query.get("token")
|
|
||||||
|
|
||||||
if not token:
|
|
||||||
raise web.HTTPUnauthorized()
|
raise web.HTTPUnauthorized()
|
||||||
|
return await f(request)
|
||||||
system = await System.get_by_token(conn, token)
|
|
||||||
if not system:
|
|
||||||
raise web.HTTPUnauthorized()
|
|
||||||
|
|
||||||
return await f(request, conn=conn, system=system, *args, **kwargs)
|
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
|
||||||
async def get_system(request: web.Request, conn):
|
|
||||||
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
|
||||||
|
|
||||||
if not system:
|
|
||||||
raise web.HTTPNotFound()
|
|
||||||
|
|
||||||
members = await db.get_all_members(conn, system.id)
|
|
||||||
|
|
||||||
system_json = system.to_json()
|
|
||||||
system_json["members"] = [member.to_json() for member in members]
|
|
||||||
return web.json_response(system_json)
|
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
|
||||||
async def get_member(request: web.Request, conn):
|
|
||||||
member = await db.get_member_by_hid(conn, request.match_info["id"])
|
|
||||||
|
|
||||||
if not member:
|
|
||||||
raise web.HTTPNotFound()
|
|
||||||
|
|
||||||
return web.json_response(member.to_json())
|
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
|
||||||
async def get_switches(request: web.Request, conn):
|
|
||||||
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
|
||||||
|
|
||||||
if not system:
|
|
||||||
raise web.HTTPNotFound()
|
|
||||||
|
|
||||||
switches = await utils.get_front_history(conn, system.id, 99999)
|
|
||||||
|
|
||||||
data = [{
|
|
||||||
"timestamp": stamp.isoformat(),
|
|
||||||
"members": [member.hid for member in members]
|
|
||||||
} for stamp, members in switches]
|
|
||||||
|
|
||||||
return web.json_response(data)
|
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
|
||||||
async def get_message(request: web.Request, conn):
|
|
||||||
message = await db.get_message(conn, request.match_info["id"])
|
|
||||||
if not message:
|
|
||||||
raise web.HTTPNotFound()
|
|
||||||
|
|
||||||
return web.json_response(message.to_json())
|
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
|
||||||
async def get_switch(request: web.Request, conn):
|
|
||||||
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
|
||||||
|
|
||||||
if not system:
|
|
||||||
raise web.HTTPNotFound()
|
|
||||||
|
|
||||||
members, stamp = await utils.get_fronters(conn, system.id)
|
|
||||||
if not stamp:
|
|
||||||
# No switch has been registered at all
|
|
||||||
raise web.HTTPNotFound()
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"timestamp": stamp.isoformat(),
|
|
||||||
"members": [member.to_json() for member in members]
|
|
||||||
}
|
|
||||||
return web.json_response(data)
|
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
|
||||||
async def get_switch_name(request: web.Request, conn):
|
|
||||||
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
|
||||||
|
|
||||||
if not system:
|
|
||||||
raise web.HTTPNotFound()
|
|
||||||
|
|
||||||
members, stamp = await utils.get_fronters(conn, system.id)
|
|
||||||
return web.Response(text=members[0].name if members else "(nobody)")
|
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
|
||||||
async def get_switch_color(request: web.Request, conn):
|
|
||||||
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
|
||||||
|
|
||||||
if not system:
|
|
||||||
raise web.HTTPNotFound()
|
|
||||||
|
|
||||||
members, stamp = await utils.get_fronters(conn, system.id)
|
|
||||||
return web.Response(text=members[0].color if members else "#ffffff")
|
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
|
||||||
@system_auth
|
|
||||||
async def put_switch(request: web.Request, system: System, conn):
|
|
||||||
try:
|
|
||||||
req = await request.json()
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise web.HTTPBadRequest(body="Invalid JSON")
|
|
||||||
|
|
||||||
if isinstance(req, str):
|
|
||||||
req = [req]
|
|
||||||
elif not isinstance(req, list):
|
|
||||||
raise web.HTTPBadRequest(body="Body must be JSON string or list")
|
|
||||||
|
|
||||||
members = []
|
|
||||||
for member_name in req:
|
|
||||||
if not isinstance(member_name, str):
|
|
||||||
raise web.HTTPBadRequest(body="List value must be string")
|
|
||||||
|
|
||||||
member = await Member.get_member_fuzzy(conn, system.id, member_name)
|
|
||||||
if not member:
|
|
||||||
raise web.HTTPBadRequest(body="Member '{}' not found".format(member_name))
|
|
||||||
members.append(member)
|
|
||||||
|
|
||||||
switch = await system.add_switch(conn, members)
|
|
||||||
return web.json_response(await switch.to_json(conn))
|
|
||||||
|
|
||||||
|
|
||||||
@db_handler
|
|
||||||
async def get_stats(request: web.Request, conn):
|
|
||||||
system_count = await db.system_count(conn)
|
|
||||||
member_count = await db.member_count(conn)
|
|
||||||
message_count = await db.message_count(conn)
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
"systems": system_count,
|
|
||||||
"members": member_count,
|
|
||||||
"messages": message_count
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def render_pk_errors(request, handler):
|
async def error_middleware(request, handler):
|
||||||
try:
|
try:
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise web.HTTPBadRequest()
|
||||||
except PluralKitError as e:
|
except PluralKitError as e:
|
||||||
raise web.HTTPBadRequest(body=e.message)
|
return web.json_response({"error": e.message}, status=400)
|
||||||
|
|
||||||
|
@web.middleware
|
||||||
|
async def db_middleware(request, handler):
|
||||||
|
async with request.app["pool"].acquire() as conn:
|
||||||
|
request["conn"] = conn
|
||||||
|
return await handler(request)
|
||||||
|
|
||||||
app = web.Application(middlewares=[render_pk_errors])
|
@web.middleware
|
||||||
app.add_routes([
|
async def auth_middleware(request, handler):
|
||||||
web.get("/systems/{id}", get_system),
|
token = request.headers.get("X-Token") or request.query.get("token")
|
||||||
web.get("/systems/{id}/switches", get_switches),
|
if token:
|
||||||
web.get("/systems/{id}/switch", get_switch),
|
system = await System.get_by_token(request["conn"], token)
|
||||||
web.put("/systems/{id}/switch", put_switch),
|
if system:
|
||||||
web.get("/systems/{id}/switch/name", get_switch_name),
|
request["system"] = system
|
||||||
web.get("/systems/{id}/switch/color", get_switch_color),
|
return await handler(request)
|
||||||
web.get("/members/{id}", get_member),
|
|
||||||
web.get("/messages/{id}", get_message),
|
|
||||||
web.get("/stats", get_stats)
|
|
||||||
])
|
|
||||||
|
|
||||||
|
class Handlers:
|
||||||
|
@require_system
|
||||||
|
async def get_system(request):
|
||||||
|
return web.json_response(request["system"].to_json())
|
||||||
|
|
||||||
|
async def get_other_system(request):
|
||||||
|
system_id = request.match_info.get("system")
|
||||||
|
system = await System.get_by_hid(request["conn"], system_id)
|
||||||
|
if not system:
|
||||||
|
raise web.HTTPNotFound()
|
||||||
|
return web.json_response(system.to_json())
|
||||||
|
|
||||||
|
async def get_system_members(request):
|
||||||
|
system_id = request.match_info.get("system")
|
||||||
|
system = await System.get_by_hid(request["conn"], system_id)
|
||||||
|
if not system:
|
||||||
|
raise web.HTTPNotFound()
|
||||||
|
|
||||||
|
members = await system.get_members(request["conn"])
|
||||||
|
return web.json_response([m.to_json() for m in members])
|
||||||
|
|
||||||
|
async def get_system_switches(request):
|
||||||
|
system_id = request.match_info.get("system")
|
||||||
|
system = await System.get_by_hid(request["conn"], system_id)
|
||||||
|
if not system:
|
||||||
|
raise web.HTTPNotFound()
|
||||||
|
|
||||||
|
switches = await system.get_switches(request["conn"], 9999)
|
||||||
|
|
||||||
|
cache = {}
|
||||||
|
async def hid_getter(member_id):
|
||||||
|
if not member_id in cache:
|
||||||
|
cache[member_id] = await Member.get_member_by_id(request["conn"], member_id)
|
||||||
|
return cache[member_id].hid
|
||||||
|
|
||||||
|
return web.json_response([await s.to_json(hid_getter) for s in switches])
|
||||||
|
|
||||||
|
@require_system
|
||||||
|
async def patch_system(request):
|
||||||
|
req = await request.json()
|
||||||
|
if "name" in req:
|
||||||
|
await request["system"].set_name(request["conn"], req["name"])
|
||||||
|
if "description" in req:
|
||||||
|
await request["system"].set_description(request["conn"], req["description"])
|
||||||
|
if "tag" in req:
|
||||||
|
await request["system"].set_tag(request["conn"], req["tag"])
|
||||||
|
if "avatar_url" in req:
|
||||||
|
await request["system"].set_avatar(request["conn"], req["name"])
|
||||||
|
if "tz" in req:
|
||||||
|
await request["system"].set_time_zone(request["conn"], req["tz"])
|
||||||
|
return web.json_response((await System.get_by_id(request["conn"], request["system"].id)).to_json())
|
||||||
|
|
||||||
|
async def get_member(request):
|
||||||
|
member_id = request.match_info.get("member")
|
||||||
|
member = await Member.get_member_by_hid(request["conn"], None, member_id)
|
||||||
|
if not member:
|
||||||
|
raise web.HTTPNotFound()
|
||||||
|
return web.json_response(member.to_json())
|
||||||
|
|
||||||
|
@require_system
|
||||||
|
async def post_member(request):
|
||||||
|
member = await request["system"].create_member(request["conn"])
|
||||||
|
return web.json_response(member.to_json())
|
||||||
|
|
||||||
|
@require_system
|
||||||
|
async def patch_member(request):
|
||||||
|
member_id = request.match_info.get("member")
|
||||||
|
member = await Member.get_member_by_hid(request["conn"], None, member_id)
|
||||||
|
if not member:
|
||||||
|
raise web.HTTPNotFound()
|
||||||
|
if member.system != request["system"].id:
|
||||||
|
raise web.HTTPUnauthorized()
|
||||||
|
|
||||||
|
req = await request.json()
|
||||||
|
if "name" in req:
|
||||||
|
await member.set_name(request["conn"], req["name"])
|
||||||
|
if "description" in req:
|
||||||
|
await member.set_description(request["conn"], req["description"])
|
||||||
|
if "avatar_url" in req:
|
||||||
|
await member.set_avatar_url(request["conn"], req["avatar_url"])
|
||||||
|
if "color" in req:
|
||||||
|
await member.set_color(request["conn"], req["color"])
|
||||||
|
if "birthday" in req:
|
||||||
|
await member.set_birthdate(request["conn"], req["birthday"])
|
||||||
|
if "pronouns" in req:
|
||||||
|
await member.set_pronouns(request["conn"], req["pronouns"])
|
||||||
|
if "prefix" in req or "suffix" in req:
|
||||||
|
await member.set_proxy_tags(request["conn"], req.get("prefix", member.prefix), req.get("suffix", member.suffix))
|
||||||
|
return web.json_response((await Member.get_member_by_id(request["conn"], member.id)).to_json())
|
||||||
|
|
||||||
|
@require_system
|
||||||
|
async def delete_member(request):
|
||||||
|
member_id = request.match_info.get("member")
|
||||||
|
member = await Member.get_member_by_hid(request["conn"], None, member_id)
|
||||||
|
if not member:
|
||||||
|
raise web.HTTPNotFound()
|
||||||
|
if member.system != request["system"].id:
|
||||||
|
raise web.HTTPUnauthorized()
|
||||||
|
|
||||||
|
await member.delete(request["conn"])
|
||||||
|
|
||||||
|
@require_system
|
||||||
|
async def post_switch(request):
|
||||||
|
req = await request.json()
|
||||||
|
if isinstance(req, str):
|
||||||
|
req = [req]
|
||||||
|
if req is None:
|
||||||
|
req = []
|
||||||
|
if not isinstance(req, list):
|
||||||
|
raise web.HTTPBadRequest()
|
||||||
|
|
||||||
|
members = [await Member.get_member_by_hid(request["conn"], request["system"].id, hid) for hid in req]
|
||||||
|
if not all(members):
|
||||||
|
raise web.HTTPNotFound(body=json.dumps({"error": "One or more members not found."}))
|
||||||
|
|
||||||
|
switch = await request["system"].add_switch(request["conn"], members)
|
||||||
|
|
||||||
|
hids = {member.id: member.hid for member in members}
|
||||||
|
async def hid_getter(mid):
|
||||||
|
return hids[mid]
|
||||||
|
|
||||||
|
return web.json_response(await switch.to_json(hid_getter))
|
||||||
|
|
||||||
async def run():
|
async def run():
|
||||||
|
app = web.Application(middlewares=[db_middleware, auth_middleware, error_middleware])
|
||||||
|
|
||||||
|
app.add_routes([
|
||||||
|
web.get("/s", Handlers.get_system),
|
||||||
|
web.post("/s/switches", Handlers.post_switch),
|
||||||
|
web.get("/s/{system}", Handlers.get_other_system),
|
||||||
|
web.get("/s/{system}/members", Handlers.get_system_members),
|
||||||
|
web.get("/s/{system}/switches", Handlers.get_system_switches),
|
||||||
|
web.patch("/s", Handlers.patch_system),
|
||||||
|
web.get("/m/{member}", Handlers.get_member),
|
||||||
|
web.post("/m", Handlers.post_member),
|
||||||
|
web.patch("/m/{member}", Handlers.patch_member),
|
||||||
|
web.delete("/m/{member}", Handlers.delete_member)
|
||||||
|
])
|
||||||
app["pool"] = await db.connect(
|
app["pool"] = await db.connect(
|
||||||
os.environ["DATABASE_URI"]
|
os.environ["DATABASE_URI"]
|
||||||
)
|
)
|
||||||
|
@ -14,13 +14,19 @@ from pluralkit.bot import commands, proxy, channel_logger, embeds
|
|||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
||||||
|
|
||||||
class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])):
|
class Config:
|
||||||
required_fields = ["database_uri", "token"]
|
required_fields = ["database_uri", "token"]
|
||||||
|
fields = ["database_uri", "token", "log_channel"]
|
||||||
|
|
||||||
database_uri: str
|
database_uri: str
|
||||||
token: str
|
token: str
|
||||||
log_channel: str
|
log_channel: str
|
||||||
|
|
||||||
|
def __init__(self, database_uri: str, token: str, log_channel: str = None):
|
||||||
|
self.database_uri = database_uri
|
||||||
|
self.token = token
|
||||||
|
self.log_channel = log_channel
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_file_and_env(filename: str) -> "Config":
|
def from_file_and_env(filename: str) -> "Config":
|
||||||
try:
|
try:
|
||||||
@ -36,7 +42,7 @@ class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
# Override with environment variables
|
# Override with environment variables
|
||||||
for f in Config._fields:
|
for f in Config.fields:
|
||||||
if f.upper() in os.environ:
|
if f.upper() in os.environ:
|
||||||
config[f] = os.environ[f.upper()]
|
config[f] = os.environ[f.upper()]
|
||||||
|
|
||||||
|
@ -38,6 +38,11 @@ class Member(namedtuple("Member",
|
|||||||
"suffix": self.suffix
|
"suffix": self.suffix
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_member_by_id(conn, member_id: int) -> Optional["Member"]:
|
||||||
|
"""Fetch a member with the given internal member ID from the database."""
|
||||||
|
return await db.get_member(conn, member_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_member_by_name(conn, system_id: int, member_name: str) -> "Optional[Member]":
|
async def get_member_by_name(conn, system_id: int, member_name: str) -> "Optional[Member]":
|
||||||
"""Fetch a member by the given name in the given system from the database."""
|
"""Fetch a member by the given name in the given system from the database."""
|
||||||
|
@ -21,8 +21,8 @@ class Switch(namedtuple("Switch", ["id", "system", "timestamp", "members"])):
|
|||||||
async def move(self, conn, new_timestamp):
|
async def move(self, conn, new_timestamp):
|
||||||
await db.move_switch(conn, self.system, self.id, new_timestamp)
|
await db.move_switch(conn, self.system, self.id, new_timestamp)
|
||||||
|
|
||||||
async def to_json(self, conn):
|
async def to_json(self, hid_getter):
|
||||||
return {
|
return {
|
||||||
"timestamp": self.timestamp.isoformat(),
|
"timestamp": self.timestamp.isoformat(),
|
||||||
"members": [member.hid for member in await self.fetch_members(conn)]
|
"members": [await hid_getter(m) for m in self.members]
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,10 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_by_token(conn, token: str) -> Optional["System"]:
|
async def get_by_token(conn, token: str) -> Optional["System"]:
|
||||||
return await db.get_system_by_token(conn, token)
|
return await db.get_system_by_token(conn, token)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_by_hid(conn, hid: str) -> Optional["System"]:
|
||||||
|
return await db.get_system_by_hid(conn, hid)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System":
|
async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System":
|
||||||
@ -234,7 +238,11 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
|||||||
:returns: The `pytz.tzinfo` instance of the newly set time zone.
|
:returns: The `pytz.tzinfo` instance of the newly set time zone.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tz = pytz.timezone(tz_name or "UTC")
|
try:
|
||||||
|
tz = pytz.timezone(tz_name or "UTC")
|
||||||
|
except pytz.UnknownTimeZoneError:
|
||||||
|
raise errors.InvalidTimeZoneError(tz_name)
|
||||||
|
|
||||||
await db.update_system_field(conn, self.id, "ui_tz", tz.zone)
|
await db.update_system_field(conn, self.id, "ui_tz", tz.zone)
|
||||||
return tz
|
return tz
|
||||||
|
|
||||||
@ -304,5 +312,6 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
|||||||
"name": self.name,
|
"name": self.name,
|
||||||
"description": self.description,
|
"description": self.description,
|
||||||
"tag": self.tag,
|
"tag": self.tag,
|
||||||
"avatar_url": self.avatar_url
|
"avatar_url": self.avatar_url,
|
||||||
|
"tz": self.ui_tz
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user