Add OAuth2 token flow to API

This commit is contained in:
Ske 2019-03-11 21:53:08 +01:00
parent 47187138b6
commit aeac3c4b10
3 changed files with 38 additions and 2 deletions

View File

@ -24,6 +24,9 @@ services:
- "2939:8080" - "2939:8080"
environment: environment:
- "DATABASE_URI=postgres://postgres:postgres@db:5432/postgres" - "DATABASE_URI=postgres://postgres:postgres@db:5432/postgres"
- "CLIENT_ID"
- "CLIENT_SECRET"
- "REDIRECT_URI"
db: db:
image: postgres:alpine image: postgres:alpine
volumes: volumes:

View File

@ -2,7 +2,7 @@ import json
import logging import logging
import os import os
from aiohttp import web from aiohttp import web, ClientSession
from pluralkit import db, utils from pluralkit import db, utils
from pluralkit.errors import PluralKitError from pluralkit.errors import PluralKitError
@ -166,6 +166,33 @@ class Handlers:
return web.json_response(await switch.to_json(hid_getter)) return web.json_response(await switch.to_json(hid_getter))
async def discord_oauth(request):
code = await request.text()
async with ClientSession() as sess:
data = {
'client_id': os.environ["CLIENT_ID"],
'client_secret': os.environ["CLIENT_SECRET"],
'grant_type': 'authorization_code',
'code': code,
'redirect_uri': os.environ["REDIRECT_URI"],
'scope': 'identify'
}
headers = {
'Content-Type': 'application/x-www-form-urlencoded'
}
res = await sess.post("https://discordapp.com/api/v6/oauth2/token", data=data, headers=headers)
if res.status != 200:
raise web.HTTPBadRequest()
access_token = (await res.json())["access_token"]
res = await sess.get("https://discordapp.com/api/v6/users/@me", headers={"Authorization": "Bearer " + access_token})
user_id = int((await res.json())["id"])
system = await System.get_by_account(request["conn"], user_id)
if not system:
raise web.HTTPUnauthorized()
return web.Response(text=await system.get_token(request["conn"]))
async def run(): async def run():
app = web.Application(middlewares=[db_middleware, auth_middleware, error_middleware]) app = web.Application(middlewares=[db_middleware, auth_middleware, error_middleware])
@ -179,7 +206,8 @@ async def run():
web.get("/m/{member}", Handlers.get_member), web.get("/m/{member}", Handlers.get_member),
web.post("/m", Handlers.post_member), web.post("/m", Handlers.post_member),
web.patch("/m/{member}", Handlers.patch_member), web.patch("/m/{member}", Handlers.patch_member),
web.delete("/m/{member}", Handlers.delete_member) web.delete("/m/{member}", Handlers.delete_member),
web.post("/discord_oauth", Handlers.discord_oauth)
]) ])
app["pool"] = await db.connect( app["pool"] = await db.connect(
os.environ["DATABASE_URI"] os.environ["DATABASE_URI"]

View File

@ -116,6 +116,11 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
await db.update_system_field(conn, self.id, "token", new_token) await db.update_system_field(conn, self.id, "token", new_token)
return new_token return new_token
async def get_token(self, conn) -> str:
if self.token:
return self.token
return await self.refresh_token(conn)
async def create_member(self, conn, member_name: str) -> Member: async def create_member(self, conn, member_name: str) -> Member:
# TODO: figure out what to do if this errors out on collision on generate_hid # TODO: figure out what to do if this errors out on collision on generate_hid
new_hid = generate_hid() new_hid = generate_hid()