Add OAuth2 token flow to API

This commit is contained in:
Ske 2019-03-11 21:53:08 +01:00
parent 47187138b6
commit aeac3c4b10
3 changed files with 38 additions and 2 deletions

View File

@ -24,6 +24,9 @@ services:
- "2939:8080"
environment:
- "DATABASE_URI=postgres://postgres:postgres@db:5432/postgres"
- "CLIENT_ID"
- "CLIENT_SECRET"
- "REDIRECT_URI"
db:
image: postgres:alpine
volumes:

View File

@ -2,7 +2,7 @@ import json
import logging
import os
from aiohttp import web
from aiohttp import web, ClientSession
from pluralkit import db, utils
from pluralkit.errors import PluralKitError
@ -166,6 +166,33 @@ class Handlers:
return web.json_response(await switch.to_json(hid_getter))
async def discord_oauth(request):
code = await request.text()
async with ClientSession() as sess:
data = {
'client_id': os.environ["CLIENT_ID"],
'client_secret': os.environ["CLIENT_SECRET"],
'grant_type': 'authorization_code',
'code': code,
'redirect_uri': os.environ["REDIRECT_URI"],
'scope': 'identify'
}
headers = {
'Content-Type': 'application/x-www-form-urlencoded'
}
res = await sess.post("https://discordapp.com/api/v6/oauth2/token", data=data, headers=headers)
if res.status != 200:
raise web.HTTPBadRequest()
access_token = (await res.json())["access_token"]
res = await sess.get("https://discordapp.com/api/v6/users/@me", headers={"Authorization": "Bearer " + access_token})
user_id = int((await res.json())["id"])
system = await System.get_by_account(request["conn"], user_id)
if not system:
raise web.HTTPUnauthorized()
return web.Response(text=await system.get_token(request["conn"]))
async def run():
app = web.Application(middlewares=[db_middleware, auth_middleware, error_middleware])
@ -179,7 +206,8 @@ async def run():
web.get("/m/{member}", Handlers.get_member),
web.post("/m", Handlers.post_member),
web.patch("/m/{member}", Handlers.patch_member),
web.delete("/m/{member}", Handlers.delete_member)
web.delete("/m/{member}", Handlers.delete_member),
web.post("/discord_oauth", Handlers.discord_oauth)
])
app["pool"] = await db.connect(
os.environ["DATABASE_URI"]

View File

@ -116,6 +116,11 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
await db.update_system_field(conn, self.id, "token", new_token)
return new_token
async def get_token(self, conn) -> str:
if self.token:
return self.token
return await self.refresh_token(conn)
async def create_member(self, conn, member_name: str) -> Member:
# TODO: figure out what to do if this errors out on collision on generate_hid
new_hid = generate_hid()