Add commands for API token retrieval/refreshing

This commit is contained in:
Ske 2018-11-13 12:34:19 +01:00
parent a72a7c3de9
commit 6da7436aed
4 changed files with 60 additions and 7 deletions

View File

@ -129,6 +129,7 @@ class CommandContext:
raise CommandError("Timed out - try again.") raise CommandError("Timed out - try again.")
import pluralkit.bot.commands.api_commands
import pluralkit.bot.commands.import_commands import pluralkit.bot.commands.import_commands
import pluralkit.bot.commands.member_commands import pluralkit.bot.commands.member_commands
import pluralkit.bot.commands.message_commands import pluralkit.bot.commands.message_commands
@ -179,7 +180,10 @@ async def command_dispatch(client: discord.Client, message: discord.Message, con
(r"switch move", switch_commands.switch_move), (r"switch move", switch_commands.switch_move),
(r"switch out", switch_commands.switch_out), (r"switch out", switch_commands.switch_out),
(r"switch", switch_commands.switch_member) (r"switch", switch_commands.switch_member),
(r"token (refresh|expire|update)", api_commands.refresh_token),
(r"token", api_commands.get_token)
] ]
for pattern, func in commands: for pattern, func in commands:

View File

@ -0,0 +1,31 @@
import logging
from discord import DMChannel
from pluralkit.bot.commands import CommandContext, CommandSuccess
logger = logging.getLogger("pluralkit.commands")
disclaimer = "Please note that this grants access to modify (and delete!) all your system data, so keep it safe and secure. If it leaks or you need a new one, you can invalidate this one with `pk;token refresh`."
async def reply_dm(ctx: CommandContext, message: str):
await ctx.message.author.send(message)
if not isinstance(ctx.message.channel, DMChannel):
return CommandSuccess("DM'd!")
async def get_token(ctx: CommandContext):
system = await ctx.ensure_system()
if system.token:
token = system.token
else:
token = await system.refresh_token(ctx.conn)
token_message = "Here's your API token: \n**`{}`**\n{}".format(token, disclaimer)
return await reply_dm(ctx, token_message)
async def refresh_token(ctx: CommandContext):
system = await ctx.ensure_system()
token = await system.refresh_token(ctx.conn)
token_message = "Your previous API token has been invalidated. You will need to change it anywhere it's currently used.\nHere's your new API token:\n**`{}`**\n{}".format(token, disclaimer)
return await reply_dm(ctx, token_message)

View File

@ -1,7 +1,7 @@
from collections import namedtuple from collections import namedtuple
from datetime import datetime from datetime import datetime
import logging import logging
from typing import List from typing import List, Optional
import time import time
import asyncpg import asyncpg
@ -85,6 +85,11 @@ async def get_system_by_account(conn, account_id: int) -> System:
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", account_id) row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", account_id)
return System(**row) if row else None return System(**row) if row else None
@db_wrap
async def get_system_by_token(conn, token: str) -> Optional[System]:
row = await conn.fetchrow("select * from systems where token = $1", token)
return System(**row) if row else None
@db_wrap @db_wrap
async def get_system_by_hid(conn, system_hid: str) -> System: async def get_system_by_hid(conn, system_hid: str) -> System:
row = await conn.fetchrow("select * from systems where hid = $1", system_hid) row = await conn.fetchrow("select * from systems where hid = $1", system_hid)
@ -323,6 +328,7 @@ async def create_tables(conn):
description text, description text,
tag text, tag text,
avatar_url text, avatar_url text,
token text,
created timestamp not null default (current_timestamp at time zone 'utc') created timestamp not null default (current_timestamp at time zone 'utc')
)""") )""")
await conn.execute("""create table if not exists members ( await conn.execute("""create table if not exists members (

View File

@ -1,3 +1,5 @@
import random
import string
from datetime import datetime from datetime import datetime
from collections.__init__ import namedtuple from collections.__init__ import namedtuple
@ -9,21 +11,26 @@ from pluralkit.switch import Switch
from pluralkit.utils import generate_hid, contains_custom_emoji, validate_avatar_url_or_raise from pluralkit.utils import generate_hid, contains_custom_emoji, validate_avatar_url_or_raise
class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "created"])): class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "token", "created"])):
id: int id: int
hid: str hid: str
name: str name: str
description: str description: str
tag: str tag: str
avatar_url: str avatar_url: str
token: str
created: datetime created: datetime
@staticmethod @staticmethod
async def get_by_account(conn, account_id: str) -> "System": async def get_by_account(conn, account_id: int) -> Optional["System"]:
return await db.get_system_by_account(conn, account_id) return await db.get_system_by_account(conn, account_id)
@staticmethod @staticmethod
async def create_system(conn, account_id: str, system_name: Optional[str] = None) -> "System": async def get_by_token(conn, token: str) -> Optional["System"]:
return await db.get_system_by_token(conn, token)
@staticmethod
async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System":
async with conn.transaction(): async with conn.transaction():
existing_system = await System.get_by_account(conn, account_id) existing_system = await System.get_by_account(conn, account_id)
if existing_system: if existing_system:
@ -66,7 +73,7 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
await db.update_system_field(conn, self.id, "avatar_url", new_avatar_url) await db.update_system_field(conn, self.id, "avatar_url", new_avatar_url)
async def link_account(self, conn, new_account_id: str): async def link_account(self, conn, new_account_id: int):
async with conn.transaction(): async with conn.transaction():
existing_system = await System.get_by_account(conn, new_account_id) existing_system = await System.get_by_account(conn, new_account_id)
@ -78,7 +85,7 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
await db.link_account(conn, self.id, new_account_id) await db.link_account(conn, self.id, new_account_id)
async def unlink_account(self, conn, account_id: str): async def unlink_account(self, conn, account_id: int):
async with conn.transaction(): async with conn.transaction():
linked_accounts = await db.get_linked_accounts(conn, self.id) linked_accounts = await db.get_linked_accounts(conn, self.id)
if len(linked_accounts) == 1: if len(linked_accounts) == 1:
@ -92,6 +99,11 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
async def delete(self, conn): async def delete(self, conn):
await db.remove_system(conn, self.id) await db.remove_system(conn, self.id)
async def refresh_token(self, conn) -> str:
new_token = "".join(random.choices(string.ascii_letters + string.digits, k=64))
await db.update_system_field(conn, self.id, "token", new_token)
return new_token
async def create_member(self, conn, member_name: str) -> Member: async def create_member(self, conn, member_name: str) -> Member:
# TODO: figure out what to do if this errors out on collision on generate_hid # TODO: figure out what to do if this errors out on collision on generate_hid
new_hid = generate_hid() new_hid = generate_hid()