PluralKit/src/api_main.py
2019-04-19 18:08:58 +02:00

255 lines
9.8 KiB
Python

import functools
import json
import logging
import os

from aiohttp import web, ClientSession

from pluralkit import db, utils
from pluralkit.errors import PluralKitError
from pluralkit.member import Member
from pluralkit.system import System
# Root logging configuration for the API process: INFO and above, with a
# timestamp / logger-name / level prefix on every line.
_LOG_FORMAT = "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger for the HTTP API.
logger = logging.getLogger("pluralkit.api")
def require_system(f):
    """Decorate a request handler so it only runs for authenticated requests.

    ``auth_middleware`` stores the authenticated system as
    ``request["system"]``. When that key is absent, the wrapper raises
    401 Unauthorized without calling ``f``.

    :param f: async handler taking the aiohttp request.
    :returns: async wrapper with the same name/docstring as ``f``.
    """

    # functools.wraps keeps the handler's __name__/__doc__ so logs, tracebacks
    # and route introspection show the real handler instead of "inner".
    @functools.wraps(f)
    async def inner(request):
        if "system" not in request:
            raise web.HTTPUnauthorized()
        return await f(request)

    return inner
@web.middleware
async def error_middleware(request, handler):
    """Map known handler failures onto HTTP 400 responses.

    - A body that fails JSON decoding becomes a bare 400 Bad Request.
    - A ``PluralKitError`` becomes a 400 whose JSON body is
      ``{"error": <the error's message>}``.

    Any other exception propagates unchanged.
    """
    try:
        response = await handler(request)
    except json.JSONDecodeError:
        raise web.HTTPBadRequest()
    except PluralKitError as err:
        return web.json_response({"error": err.message}, status=400)
    return response
@web.middleware
async def db_middleware(request, handler):
    """Check out a connection from ``app["pool"]`` for the duration of one request.

    Handlers find the connection at ``request["conn"]``. The ``async with``
    scope ends once the handler returns or raises.
    """
    pool = request.app["pool"]
    async with pool.acquire() as connection:
        request["conn"] = connection
        return await handler(request)
@web.middleware
async def auth_middleware(request, handler):
    """Resolve an API token to a system and attach it to the request.

    The token comes from the ``X-Token`` header, falling back to the
    ``token`` query parameter. When it resolves to a system, that system is
    stored as ``request["system"]``. Otherwise the request continues
    unauthenticated, and individual handlers decide whether that is allowed.
    Needs ``request["conn"]``, so it must run after ``db_middleware``.
    """
    token = request.headers.get("X-Token") or request.query.get("token")
    system = await System.get_by_token(request["conn"], token) if token else None
    if system:
        request["system"] = system
    return await handler(request)
@web.middleware
async def cors_middleware(request, handler):
    """Add permissive CORS headers to every response, including HTTP errors.

    A ``web.HTTPException`` raised further down the chain is caught and
    returned as the response, so error responses carry the CORS headers as
    well. Any other exception propagates without them.
    """
    try:
        resp = await handler(request)
    except web.HTTPException as exc:
        # aiohttp HTTP exceptions expose .headers and can be returned as the response.
        resp = exc
    resp.headers["Access-Control-Allow-Origin"] = "*"
    # DELETE must be listed because /m/{member} is served via DELETE; browsers
    # refuse the preflighted request otherwise.
    resp.headers["Access-Control-Allow-Methods"] = "GET, POST, PATCH, DELETE"
    # Content-Type must be allowed for browser requests that send JSON bodies
    # (application/json is not a CORS-safelisted content type).
    resp.headers["Access-Control-Allow-Headers"] = "X-Token, Content-Type"
    return resp
class Handlers:
    """Namespace for the API's aiohttp route handlers.

    Each handler takes only the aiohttp ``request`` (there is no ``self``).
    They are referenced as plain functions, e.g. ``Handlers.get_system``.
    Handlers rely on ``request["conn"]`` (a DB connection) and, when
    authenticated, ``request["system"]``. ``@require_system`` rejects
    requests without the latter.
    """

    async def _system_from_path(request):
        """Resolve the ``{system}`` path segment to a System.

        Raises 404 with body ``null`` if no system has that hid.
        """
        system_id = request.match_info.get("system")
        system = await System.get_by_hid(request["conn"], system_id)
        if not system:
            raise web.HTTPNotFound(body="null")
        return system

    async def _owned_member_from_path(request):
        """Resolve ``{member}`` to a Member owned by the authenticated system.

        Raises 404 if no member has that hid, and 401 if it belongs to a
        different system.
        """
        member_id = request.match_info.get("member")
        member = await Member.get_member_by_hid(request["conn"], None, member_id)
        if not member:
            raise web.HTTPNotFound()
        if member.system != request["system"].id:
            raise web.HTTPUnauthorized()
        return member

    @require_system
    async def get_system(request):
        """Return the authenticated system as JSON."""
        return web.json_response(request["system"].to_json())

    async def get_other_system(request):
        """Return the system named by the ``{system}`` hid as JSON."""
        system = await Handlers._system_from_path(request)
        return web.json_response(system.to_json())

    async def get_system_members(request):
        """Return a JSON list of the members of the ``{system}`` system."""
        system = await Handlers._system_from_path(request)
        members = await system.get_members(request["conn"])
        return web.json_response([m.to_json() for m in members])

    async def get_system_switches(request):
        """Return a JSON list of the ``{system}`` system's switches."""
        system = await Handlers._system_from_path(request)
        # NOTE(review): 9999 is presumably the max number of switches fetched — confirm.
        switches = await system.get_switches(request["conn"], 9999)
        cache = {}

        async def hid_getter(member_id):
            # Memoize: the same member typically appears in many switches.
            if member_id not in cache:
                cache[member_id] = await Member.get_member_by_id(request["conn"], member_id)
            return cache[member_id].hid

        return web.json_response([await s.to_json(hid_getter) for s in switches])

    async def get_system_fronters(request):
        """Return the current fronters of the ``{system}`` system.

        Response: ``{"timestamp": <ISO-8601>, "members": [...]}``. Returns 404
        with body ``null`` if the system has never registered a switch.
        """
        system = await Handlers._system_from_path(request)
        members, stamp = await utils.get_fronters(request["conn"], system.id)
        if not stamp:
            # No switch has been registered at all
            raise web.HTTPNotFound(body="null")
        data = {
            "timestamp": stamp.isoformat(),
            "members": [member.to_json() for member in members],
        }
        return web.json_response(data)

    @require_system
    async def patch_system(request):
        """Update the authenticated system's fields from a JSON object body.

        Only the keys present are applied. Responds with the re-fetched
        system; 400 if the body is not a JSON object.
        """
        req = await request.json()
        if not isinstance(req, dict):
            raise web.HTTPBadRequest()
        conn = request["conn"]
        system = request["system"]
        # Each request key is passed, with its own value, to the matching setter.
        setters = {
            "name": system.set_name,
            "description": system.set_description,
            "tag": system.set_tag,
            "avatar_url": system.set_avatar,
            "tz": system.set_time_zone,
        }
        for key, setter in setters.items():
            if key in req:
                await setter(conn, req[key])
        return web.json_response((await System.get_by_id(conn, system.id)).to_json())

    async def get_member(request):
        """Return the ``{member}`` member as JSON, with its system embedded under ``"system"``.

        Returns 404 with body ``{}`` if no member has that hid.
        """
        member_id = request.match_info.get("member")
        member = await Member.get_member_by_hid(request["conn"], None, member_id)
        if not member:
            raise web.HTTPNotFound(body="{}")
        system = await System.get_by_id(request["conn"], member.system)
        member_json = member.to_json()
        member_json["system"] = system.to_json()
        return web.json_response(member_json)

    @require_system
    async def post_member(request):
        """Create a member in the authenticated system from ``{"name": ...}``.

        Returns 400 if the body is not a JSON object with a ``name`` key.
        """
        req = await request.json()
        if not isinstance(req, dict) or "name" not in req:
            raise web.HTTPBadRequest()
        member = await request["system"].create_member(request["conn"], req["name"])
        return web.json_response(member.to_json())

    @require_system
    async def patch_member(request):
        """Update an owned member's fields from a JSON object body.

        Only the keys present are applied. Responds with the re-fetched
        member. Returns 404/401 for unknown or foreign members, and 400 if
        the body is not a JSON object.
        """
        member = await Handlers._owned_member_from_path(request)
        req = await request.json()
        if not isinstance(req, dict):
            raise web.HTTPBadRequest()
        conn = request["conn"]
        setters = {
            "name": member.set_name,
            "description": member.set_description,
            "avatar_url": member.set_avatar_url,
            "color": member.set_color,
            "birthday": member.set_birthdate,
            "pronouns": member.set_pronouns,
        }
        for key, setter in setters.items():
            if key in req:
                await setter(conn, req[key])
        if "prefix" in req or "suffix" in req:
            # Proxy tags are set as a pair; keep the current value of whichever side is omitted.
            await member.set_proxy_tags(conn, req.get("prefix", member.prefix), req.get("suffix", member.suffix))
        return web.json_response((await Member.get_member_by_id(conn, member.id)).to_json())

    @require_system
    async def delete_member(request):
        """Delete an owned member and respond 204 No Content (404/401 as in patch_member)."""
        member = await Handlers._owned_member_from_path(request)
        await member.delete(request["conn"])
        # aiohttp requires handlers to return a response object.
        return web.Response(status=204)

    @require_system
    async def post_switch(request):
        """Register a switch for the authenticated system.

        The body is one member hid (string), a list of hids, or ``null``
        (an empty switch). Returns 400 for any other shape, and 404 if any
        hid is not a member of this system.
        """
        req = await request.json()
        if isinstance(req, str):
            req = [req]
        if req is None:
            req = []
        if not isinstance(req, list):
            raise web.HTTPBadRequest()
        members = [await Member.get_member_by_hid(request["conn"], request["system"].id, hid) for hid in req]
        if not all(members):
            raise web.HTTPNotFound(body=json.dumps({"error": "One or more members not found."}))
        switch = await request["system"].add_switch(request["conn"], members)
        hids = {member.id: member.hid for member in members}

        async def hid_getter(mid):
            return hids[mid]

        return web.json_response(await switch.to_json(hid_getter))

    async def discord_oauth(request):
        """Exchange a Discord OAuth2 code (the raw request body) for the user's system token.

        Reads CLIENT_ID, CLIENT_SECRET and REDIRECT_URI from the environment.
        Returns 400 if either Discord call fails, and 401 if the Discord
        account has no system.
        """
        code = await request.text()
        data = {
            'client_id': os.environ["CLIENT_ID"],
            'client_secret': os.environ["CLIENT_SECRET"],
            'grant_type': 'authorization_code',
            'code': code,
            'redirect_uri': os.environ["REDIRECT_URI"],
            'scope': 'identify'
        }
        headers = {
            'Content-Type': 'application/x-www-form-urlencoded'
        }
        async with ClientSession() as sess:
            # `async with` releases each response's connection back to the session.
            async with sess.post("https://discordapp.com/api/v6/oauth2/token", data=data, headers=headers) as res:
                if res.status != 200:
                    raise web.HTTPBadRequest()
                access_token = (await res.json())["access_token"]
            async with sess.get("https://discordapp.com/api/v6/users/@me", headers={"Authorization": "Bearer " + access_token}) as res:
                if res.status != 200:
                    raise web.HTTPBadRequest()
                user_id = int((await res.json())["id"])
        system = await System.get_by_account(request["conn"], user_id)
        if not system:
            raise web.HTTPUnauthorized()
        return web.Response(text=await system.get_token(request["conn"]))
async def run():
    """Build the aiohttp application: middlewares, routes and DB pool.

    Reads ``DATABASE_URI`` from the environment and returns the configured
    ``web.Application``.
    """
    # Middlewares run outermost-first: cors wraps everything so even error
    # responses get CORS headers; db runs before auth because auth reads
    # request["conn"].
    app = web.Application(middlewares=[cors_middleware, db_middleware, auth_middleware, error_middleware])

    async def cors_fallback(req):
        # Catch-all route (registered last): answers CORS preflight OPTIONS
        # requests with 200 and anything else it receives with 404.
        return web.Response(
            headers={
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "X-Token, Content-Type",
                "Access-Control-Allow-Methods": "GET, POST, PATCH, DELETE",
            },
            status=200 if req.method == "OPTIONS" else 404,
        )

    app.add_routes([
        web.get("/s", Handlers.get_system),
        web.post("/s/switches", Handlers.post_switch),
        web.get("/s/{system}", Handlers.get_other_system),
        web.get("/s/{system}/members", Handlers.get_system_members),
        web.get("/s/{system}/switches", Handlers.get_system_switches),
        web.get("/s/{system}/fronters", Handlers.get_system_fronters),
        web.patch("/s", Handlers.patch_system),
        web.get("/m/{member}", Handlers.get_member),
        web.post("/m", Handlers.post_member),
        web.patch("/m/{member}", Handlers.patch_member),
        web.delete("/m/{member}", Handlers.delete_member),
        web.post("/discord_oauth", Handlers.discord_oauth),
        web.route("*", "/{tail:.*}", cors_fallback),
    ])
    app["pool"] = await db.connect(os.environ["DATABASE_URI"])
    return app
# Start the HTTP server; run_app accepts the un-awaited app-factory coroutine.
web.run_app(app=run())