Add commands for API token retrieval/refreshing
This commit is contained in:
parent
a72a7c3de9
commit
6da7436aed
@ -129,6 +129,7 @@ class CommandContext:
|
|||||||
raise CommandError("Timed out - try again.")
|
raise CommandError("Timed out - try again.")
|
||||||
|
|
||||||
|
|
||||||
|
import pluralkit.bot.commands.api_commands
|
||||||
import pluralkit.bot.commands.import_commands
|
import pluralkit.bot.commands.import_commands
|
||||||
import pluralkit.bot.commands.member_commands
|
import pluralkit.bot.commands.member_commands
|
||||||
import pluralkit.bot.commands.message_commands
|
import pluralkit.bot.commands.message_commands
|
||||||
@ -179,7 +180,10 @@ async def command_dispatch(client: discord.Client, message: discord.Message, con
|
|||||||
|
|
||||||
(r"switch move", switch_commands.switch_move),
|
(r"switch move", switch_commands.switch_move),
|
||||||
(r"switch out", switch_commands.switch_out),
|
(r"switch out", switch_commands.switch_out),
|
||||||
(r"switch", switch_commands.switch_member)
|
(r"switch", switch_commands.switch_member),
|
||||||
|
|
||||||
|
(r"token (refresh|expire|update)", api_commands.refresh_token),
|
||||||
|
(r"token", api_commands.get_token)
|
||||||
]
|
]
|
||||||
|
|
||||||
for pattern, func in commands:
|
for pattern, func in commands:
|
||||||
|
31
src/pluralkit/bot/commands/api_commands.py
Normal file
31
src/pluralkit/bot/commands/api_commands.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import logging
|
||||||
|
from discord import DMChannel
|
||||||
|
|
||||||
|
from pluralkit.bot.commands import CommandContext, CommandSuccess
|
||||||
|
|
||||||
|
logger = logging.getLogger("pluralkit.commands")
|
||||||
|
disclaimer = "Please note that this grants access to modify (and delete!) all your system data, so keep it safe and secure. If it leaks or you need a new one, you can invalidate this one with `pk;token refresh`."
|
||||||
|
|
||||||
|
async def reply_dm(ctx: CommandContext, message: str):
|
||||||
|
await ctx.message.author.send(message)
|
||||||
|
|
||||||
|
if not isinstance(ctx.message.channel, DMChannel):
|
||||||
|
return CommandSuccess("DM'd!")
|
||||||
|
|
||||||
|
async def get_token(ctx: CommandContext):
|
||||||
|
system = await ctx.ensure_system()
|
||||||
|
|
||||||
|
if system.token:
|
||||||
|
token = system.token
|
||||||
|
else:
|
||||||
|
token = await system.refresh_token(ctx.conn)
|
||||||
|
|
||||||
|
token_message = "Here's your API token: \n**`{}`**\n{}".format(token, disclaimer)
|
||||||
|
return await reply_dm(ctx, token_message)
|
||||||
|
|
||||||
|
async def refresh_token(ctx: CommandContext):
|
||||||
|
system = await ctx.ensure_system()
|
||||||
|
|
||||||
|
token = await system.refresh_token(ctx.conn)
|
||||||
|
token_message = "Your previous API token has been invalidated. You will need to change it anywhere it's currently used.\nHere's your new API token:\n**`{}`**\n{}".format(token, disclaimer)
|
||||||
|
return await reply_dm(ctx, token_message)
|
@ -1,7 +1,7 @@
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import asyncpg
|
import asyncpg
|
||||||
@ -85,6 +85,11 @@ async def get_system_by_account(conn, account_id: int) -> System:
|
|||||||
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", account_id)
|
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", account_id)
|
||||||
return System(**row) if row else None
|
return System(**row) if row else None
|
||||||
|
|
||||||
|
@db_wrap
|
||||||
|
async def get_system_by_token(conn, token: str) -> Optional[System]:
|
||||||
|
row = await conn.fetchrow("select * from systems where token = $1", token)
|
||||||
|
return System(**row) if row else None
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_system_by_hid(conn, system_hid: str) -> System:
|
async def get_system_by_hid(conn, system_hid: str) -> System:
|
||||||
row = await conn.fetchrow("select * from systems where hid = $1", system_hid)
|
row = await conn.fetchrow("select * from systems where hid = $1", system_hid)
|
||||||
@ -323,6 +328,7 @@ async def create_tables(conn):
|
|||||||
description text,
|
description text,
|
||||||
tag text,
|
tag text,
|
||||||
avatar_url text,
|
avatar_url text,
|
||||||
|
token text,
|
||||||
created timestamp not null default (current_timestamp at time zone 'utc')
|
created timestamp not null default (current_timestamp at time zone 'utc')
|
||||||
)""")
|
)""")
|
||||||
await conn.execute("""create table if not exists members (
|
await conn.execute("""create table if not exists members (
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import random
|
||||||
|
import string
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from collections.__init__ import namedtuple
|
from collections.__init__ import namedtuple
|
||||||
@ -9,21 +11,26 @@ from pluralkit.switch import Switch
|
|||||||
from pluralkit.utils import generate_hid, contains_custom_emoji, validate_avatar_url_or_raise
|
from pluralkit.utils import generate_hid, contains_custom_emoji, validate_avatar_url_or_raise
|
||||||
|
|
||||||
|
|
||||||
class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "created"])):
|
class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "token", "created"])):
|
||||||
id: int
|
id: int
|
||||||
hid: str
|
hid: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
tag: str
|
tag: str
|
||||||
avatar_url: str
|
avatar_url: str
|
||||||
|
token: str
|
||||||
created: datetime
|
created: datetime
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_by_account(conn, account_id: str) -> "System":
|
async def get_by_account(conn, account_id: int) -> Optional["System"]:
|
||||||
return await db.get_system_by_account(conn, account_id)
|
return await db.get_system_by_account(conn, account_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_system(conn, account_id: str, system_name: Optional[str] = None) -> "System":
|
async def get_by_token(conn, token: str) -> Optional["System"]:
|
||||||
|
return await db.get_system_by_token(conn, token)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System":
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
existing_system = await System.get_by_account(conn, account_id)
|
existing_system = await System.get_by_account(conn, account_id)
|
||||||
if existing_system:
|
if existing_system:
|
||||||
@ -66,7 +73,7 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
|||||||
|
|
||||||
await db.update_system_field(conn, self.id, "avatar_url", new_avatar_url)
|
await db.update_system_field(conn, self.id, "avatar_url", new_avatar_url)
|
||||||
|
|
||||||
async def link_account(self, conn, new_account_id: str):
|
async def link_account(self, conn, new_account_id: int):
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
existing_system = await System.get_by_account(conn, new_account_id)
|
existing_system = await System.get_by_account(conn, new_account_id)
|
||||||
|
|
||||||
@ -78,7 +85,7 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
|||||||
|
|
||||||
await db.link_account(conn, self.id, new_account_id)
|
await db.link_account(conn, self.id, new_account_id)
|
||||||
|
|
||||||
async def unlink_account(self, conn, account_id: str):
|
async def unlink_account(self, conn, account_id: int):
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
linked_accounts = await db.get_linked_accounts(conn, self.id)
|
linked_accounts = await db.get_linked_accounts(conn, self.id)
|
||||||
if len(linked_accounts) == 1:
|
if len(linked_accounts) == 1:
|
||||||
@ -92,6 +99,11 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
|||||||
async def delete(self, conn):
|
async def delete(self, conn):
|
||||||
await db.remove_system(conn, self.id)
|
await db.remove_system(conn, self.id)
|
||||||
|
|
||||||
|
async def refresh_token(self, conn) -> str:
|
||||||
|
new_token = "".join(random.choices(string.ascii_letters + string.digits, k=64))
|
||||||
|
await db.update_system_field(conn, self.id, "token", new_token)
|
||||||
|
return new_token
|
||||||
|
|
||||||
async def create_member(self, conn, member_name: str) -> Member:
|
async def create_member(self, conn, member_name: str) -> Member:
|
||||||
# TODO: figure out what to do if this errors out on collision on generate_hid
|
# TODO: figure out what to do if this errors out on collision on generate_hid
|
||||||
new_hid = generate_hid()
|
new_hid = generate_hid()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user