255 lines
9.8 KiB
Python
255 lines
9.8 KiB
Python
import functools
import json
import logging
import os

from aiohttp import web, ClientSession

from pluralkit import db, utils
from pluralkit.errors import PluralKitError
from pluralkit.member import Member
from pluralkit.system import System
|
|
|
|
# Configure logging once at import time: INFO and above, with each record
# showing its timestamp, logger name and level in brackets before the message.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s",
)

# Module-wide logger for the API.
logger = logging.getLogger("pluralkit.api")
|
|
|
|
def require_system(f):
    """Decorate a request handler so it only runs for authenticated requests.

    The wrapped handler is called only when the request carries a "system"
    entry (stored by ``auth_middleware`` when a valid token was supplied).

    Args:
        f: Coroutine function taking the request and returning a response.

    Returns:
        A coroutine function with the same call signature as ``f``. Thanks to
        ``functools.wraps`` it also keeps ``f``'s name and docstring, so
        wrapped handlers stay identifiable in logs and tracebacks.

    Raises:
        web.HTTPUnauthorized: (from the wrapper) when the request has no
            "system" entry.
    """
    @functools.wraps(f)
    async def inner(request):
        if "system" not in request:
            raise web.HTTPUnauthorized()
        return await f(request)

    return inner
|
|
|
|
@web.middleware
async def error_middleware(request, handler):
    """Translate known failures raised by the handler into HTTP 400 answers.

    * ``json.JSONDecodeError`` is re-raised as ``web.HTTPBadRequest``.
    * ``PluralKitError`` becomes a JSON body ``{"error": <message>}`` with
      status 400.

    Any other exception propagates unchanged.
    """
    try:
        response = await handler(request)
    except json.JSONDecodeError:
        raise web.HTTPBadRequest()
    except PluralKitError as err:
        response = web.json_response({"error": err.message}, status=400)
    return response
|
|
|
|
@web.middleware
async def db_middleware(request, handler):
    """Acquire a connection from the app's pool for the duration of the request.

    The connection is stored under ``request["conn"]`` and handed back to the
    pool when the ``async with`` block exits, i.e. once the handler is done.
    """
    pool = request.app["pool"]
    async with pool.acquire() as connection:
        request["conn"] = connection
        return await handler(request)
|
|
|
|
@web.middleware
async def auth_middleware(request, handler):
    """Attach the authenticated system, if any, to the request.

    The token is read from the ``X-Token`` header, falling back to the
    ``token`` query parameter. When it resolves to a system, that system is
    stored under ``request["system"]``; otherwise the request continues
    without it. Relies on ``request["conn"]`` already being set.
    """
    token = request.headers.get("X-Token") or request.query.get("token")

    # Only hit the database when a token was actually supplied.
    system = await System.get_by_token(request["conn"], token) if token else None
    if system:
        request["system"] = system

    return await handler(request)
|
|
|
|
@web.middleware
async def cors_middleware(request, handler):
    """Add permissive CORS headers to every response, including HTTP errors.

    ``web.HTTPException`` instances raised further down the chain are caught
    and returned as the response so that error answers carry the CORS headers
    too.

    DELETE is listed in the allowed methods because the API exposes a
    ``DELETE /m/{member}`` route; without it a browser's preflight check
    rejects cross-origin delete requests.
    """
    try:
        resp = await handler(request)
    except web.HTTPException as exc:
        # The raised HTTP exception is used as the response object itself.
        resp = exc

    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, PATCH, DELETE",
        "Access-Control-Allow-Headers": "X-Token",
    }
    for header_name, header_value in cors_headers.items():
        resp.headers[header_name] = header_value
    return resp
|
|
|
|
class Handlers:
    """Namespace for the HTTP request handlers registered in ``run()``.

    The handlers take no ``self``: they are looked up on the class itself
    (``Handlers.get_system`` etc.) and used as plain coroutine functions that
    receive the aiohttp request. They read the database connection from
    ``request["conn"]`` and, for handlers decorated with ``@require_system``,
    the authenticated system from ``request["system"]``.
    """

    @require_system
    async def get_system(request):
        """Return the authenticated system as JSON."""
        return web.json_response(request["system"].to_json())

    async def get_other_system(request):
        """Return the system named by the ``system`` path segment as JSON.

        Raises:
            web.HTTPNotFound: (body ``null``) when no such system exists.
        """
        system_id = request.match_info.get("system")
        system = await System.get_by_hid(request["conn"], system_id)
        if not system:
            raise web.HTTPNotFound(body="null")
        return web.json_response(system.to_json())

    async def get_system_members(request):
        """Return a JSON list of the members of the given system.

        Raises:
            web.HTTPNotFound: (body ``null``) when no such system exists.
        """
        system_id = request.match_info.get("system")
        system = await System.get_by_hid(request["conn"], system_id)
        if not system:
            raise web.HTTPNotFound(body="null")

        members = await system.get_members(request["conn"])
        return web.json_response([m.to_json() for m in members])

    async def get_system_switches(request):
        """Return a JSON list of the given system's switches.

        Raises:
            web.HTTPNotFound: (body ``null``) when no such system exists.
        """
        system_id = request.match_info.get("system")
        system = await System.get_by_hid(request["conn"], system_id)
        if not system:
            raise web.HTTPNotFound(body="null")

        # NOTE(review): 9999 is presumably a cap on the number of switches
        # returned - confirm against System.get_switches.
        switches = await system.get_switches(request["conn"], 9999)

        # Per-request memo so each member is fetched from the database at most
        # once, however many switches reference it.
        cache = {}

        async def hid_getter(member_id):
            if member_id not in cache:
                cache[member_id] = await Member.get_member_by_id(request["conn"], member_id)
            return cache[member_id].hid

        return web.json_response([await s.to_json(hid_getter) for s in switches])

    async def get_system_fronters(request):
        """Return the given system's current fronters and the switch timestamp.

        The JSON body has the keys ``timestamp`` (ISO 8601 string) and
        ``members`` (list of member objects).

        Raises:
            web.HTTPNotFound: (body ``null``) when the system does not exist
                or has no switch registered.
        """
        system_id = request.match_info.get("system")
        system = await System.get_by_hid(request["conn"], system_id)

        if not system:
            raise web.HTTPNotFound(body="null")

        members, stamp = await utils.get_fronters(request["conn"], system.id)
        if not stamp:
            # No switch has been registered at all
            raise web.HTTPNotFound(body="null")

        data = {
            "timestamp": stamp.isoformat(),
            "members": [member.to_json() for member in members]
        }
        return web.json_response(data)

    @require_system
    async def patch_system(request):
        """Update fields of the authenticated system from the JSON request body.

        Recognised keys: ``name``, ``description``, ``tag``, ``avatar_url`` and
        ``tz``; keys that are absent are left untouched. The system is then
        re-read from the database and returned as JSON.
        """
        req = await request.json()
        if "name" in req:
            await request["system"].set_name(request["conn"], req["name"])
        if "description" in req:
            await request["system"].set_description(request["conn"], req["description"])
        if "tag" in req:
            await request["system"].set_tag(request["conn"], req["tag"])
        if "avatar_url" in req:
            # Use the avatar URL itself; this previously passed req["name"],
            # which stored the wrong value (or raised KeyError when only
            # "avatar_url" was sent).
            await request["system"].set_avatar(request["conn"], req["avatar_url"])
        if "tz" in req:
            await request["system"].set_time_zone(request["conn"], req["tz"])
        return web.json_response((await System.get_by_id(request["conn"], request["system"].id)).to_json())

    async def get_member(request):
        """Return the member named by the ``member`` path segment as JSON.

        The member's JSON is extended with a ``system`` key holding the JSON
        of the system it belongs to.

        Raises:
            web.HTTPNotFound: (body ``{}``) when no such member exists.
        """
        member_id = request.match_info.get("member")
        # NOTE(review): passing None as the system id presumably means "search
        # across all systems" - confirm against Member.get_member_by_hid.
        member = await Member.get_member_by_hid(request["conn"], None, member_id)
        if not member:
            raise web.HTTPNotFound(body="{}")
        system = await System.get_by_id(request["conn"], member.system)
        member_json = member.to_json()
        member_json["system"] = system.to_json()
        return web.json_response(member_json)

    @require_system
    async def post_member(request):
        """Create a member in the authenticated system and return it as JSON.

        The JSON request body must contain a ``name`` key.
        """
        req = await request.json()
        member = await request["system"].create_member(request["conn"], req["name"])
        return web.json_response(member.to_json())

    @require_system
    async def patch_member(request):
        """Update fields of a member owned by the authenticated system.

        Recognised keys: ``name``, ``description``, ``avatar_url``, ``color``,
        ``birthday``, ``pronouns``, ``prefix`` and ``suffix``. When only one of
        ``prefix``/``suffix`` is given, the other keeps the member's current
        value. The member is then re-read from the database and returned as
        JSON.

        Raises:
            web.HTTPNotFound: when no such member exists.
            web.HTTPUnauthorized: when the member belongs to another system.
        """
        member_id = request.match_info.get("member")
        member = await Member.get_member_by_hid(request["conn"], None, member_id)
        if not member:
            raise web.HTTPNotFound()
        if member.system != request["system"].id:
            raise web.HTTPUnauthorized()

        req = await request.json()
        if "name" in req:
            await member.set_name(request["conn"], req["name"])
        if "description" in req:
            await member.set_description(request["conn"], req["description"])
        if "avatar_url" in req:
            await member.set_avatar_url(request["conn"], req["avatar_url"])
        if "color" in req:
            await member.set_color(request["conn"], req["color"])
        if "birthday" in req:
            await member.set_birthdate(request["conn"], req["birthday"])
        if "pronouns" in req:
            await member.set_pronouns(request["conn"], req["pronouns"])
        if "prefix" in req or "suffix" in req:
            await member.set_proxy_tags(request["conn"], req.get("prefix", member.prefix), req.get("suffix", member.suffix))
        return web.json_response((await Member.get_member_by_id(request["conn"], member.id)).to_json())

    @require_system
    async def delete_member(request):
        """Delete a member owned by the authenticated system.

        Returns:
            An empty ``204 No Content`` response. The handler previously fell
            off the end and returned ``None``, which aiohttp does not accept
            as a handler result.

        Raises:
            web.HTTPNotFound: when no such member exists.
            web.HTTPUnauthorized: when the member belongs to another system.
        """
        member_id = request.match_info.get("member")
        member = await Member.get_member_by_hid(request["conn"], None, member_id)
        if not member:
            raise web.HTTPNotFound()
        if member.system != request["system"].id:
            raise web.HTTPUnauthorized()

        await member.delete(request["conn"])
        return web.Response(status=204)

    @require_system
    async def post_switch(request):
        """Register a switch for the authenticated system.

        The JSON body may be a list of member ids, a single id string (treated
        as a one-element list) or ``null`` (treated as an empty list). The
        created switch is returned as JSON.

        Raises:
            web.HTTPBadRequest: when the body is any other JSON type.
            web.HTTPNotFound: when at least one id does not resolve to a
                member of the authenticated system.
        """
        req = await request.json()
        if isinstance(req, str):
            req = [req]
        if req is None:
            req = []
        if not isinstance(req, list):
            raise web.HTTPBadRequest()

        members = [await Member.get_member_by_hid(request["conn"], request["system"].id, hid) for hid in req]
        if not all(members):
            raise web.HTTPNotFound(body=json.dumps({"error": "One or more members not found."}))

        switch = await request["system"].add_switch(request["conn"], members)

        # The members were just loaded, so their hids can be served from memory
        # instead of querying the database again.
        hids = {member.id: member.hid for member in members}

        async def hid_getter(mid):
            return hids[mid]

        return web.json_response(await switch.to_json(hid_getter))

    async def discord_oauth(request):
        """Exchange a Discord OAuth2 authorization code for a system token.

        The request body is the raw authorization code. It is exchanged with
        Discord for an access token, the Discord user id is looked up, and the
        token of the system linked to that account is returned as plain text.
        Reads ``CLIENT_ID``, ``CLIENT_SECRET`` and ``REDIRECT_URI`` from the
        environment.

        Raises:
            web.HTTPBadRequest: when Discord rejects the code exchange.
            web.HTTPUnauthorized: when no system is linked to the account.
        """
        code = await request.text()
        async with ClientSession() as sess:
            data = {
                'client_id': os.environ["CLIENT_ID"],
                'client_secret': os.environ["CLIENT_SECRET"],
                'grant_type': 'authorization_code',
                'code': code,
                'redirect_uri': os.environ["REDIRECT_URI"],
                'scope': 'identify'
            }
            headers = {
                'Content-Type': 'application/x-www-form-urlencoded'
            }
            res = await sess.post("https://discordapp.com/api/v6/oauth2/token", data=data, headers=headers)
            if res.status != 200:
                raise web.HTTPBadRequest()

            access_token = (await res.json())["access_token"]
            res = await sess.get("https://discordapp.com/api/v6/users/@me", headers={"Authorization": "Bearer " + access_token})
            user_id = int((await res.json())["id"])

            system = await System.get_by_account(request["conn"], user_id)
            if not system:
                raise web.HTTPUnauthorized()
            return web.Response(text=await system.get_token(request["conn"]))
|
|
|
|
async def run():
    """Build the aiohttp application.

    Installs the middleware chain, registers every API route plus a catch-all
    fallback, and opens the database pool from the ``DATABASE_URI``
    environment variable (stored under ``app["pool"]``).

    Returns:
        The configured ``web.Application``.
    """
    app = web.Application(middlewares=[cors_middleware, db_middleware, auth_middleware, error_middleware])

    async def cors_fallback(req):
        """Answer any unrouted request: 200 for OPTIONS, 404 otherwise.

        Declared ``async`` because aiohttp request handlers are expected to be
        coroutines. DELETE is listed in the allowed methods because a DELETE
        route is registered below.
        """
        return web.Response(
            headers={
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "x-token",
                "Access-Control-Allow-Methods": "GET, POST, PATCH, DELETE",
            },
            status=404 if req.method != "OPTIONS" else 200,
        )

    app.add_routes([
        web.get("/s", Handlers.get_system),
        web.post("/s/switches", Handlers.post_switch),
        web.get("/s/{system}", Handlers.get_other_system),
        web.get("/s/{system}/members", Handlers.get_system_members),
        web.get("/s/{system}/switches", Handlers.get_system_switches),
        web.get("/s/{system}/fronters", Handlers.get_system_fronters),
        web.patch("/s", Handlers.patch_system),
        web.get("/m/{member}", Handlers.get_member),
        web.post("/m", Handlers.post_member),
        web.patch("/m/{member}", Handlers.patch_member),
        web.delete("/m/{member}", Handlers.delete_member),
        web.post("/discord_oauth", Handlers.discord_oauth),
        # Catch-all: must stay last so the specific routes above win.
        web.route("*", "/{tail:.*}", cors_fallback)
    ])
    app["pool"] = await db.connect(
        os.environ["DATABASE_URI"]
    )
    return app
|
|
|
|
|
|
# Entry point: executed whenever this module is run or imported. run() is an
# async function, so the coroutine it returns (which yields the configured
# application) is handed to web.run_app to start serving.
web.run_app(run())
|