Initial commit

This commit is contained in:
Ske 2018-07-12 00:47:44 +02:00
commit b81b768b04
11 changed files with 1022 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
.env
.vscode/

11
bot/Dockerfile Normal file
View File

@ -0,0 +1,11 @@
FROM python:3.6-alpine
RUN apk --no-cache add build-base
WORKDIR /app
ADD requirements.txt /app
RUN pip install --trusted-host pypi.python.org -r requirements.txt
ADD . /app
ENTRYPOINT ["python", "main.py"]

6
bot/main.py Normal file
View File

@ -0,0 +1,6 @@
import asyncio
from pluralkit import bot
loop = asyncio.get_event_loop()
loop.run_until_complete(bot.run())

View File

@ -0,0 +1 @@
from . import commands, db, proxy

82
bot/pluralkit/bot.py Normal file
View File

@ -0,0 +1,82 @@
import logging
import os
import discord
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("discord").setLevel(logging.INFO)
logging.getLogger("websockets").setLevel(logging.INFO)
logger = logging.getLogger("pluralkit.bot")
logger.setLevel(logging.DEBUG)
client = discord.Client()
@client.event
async def on_error(evt, *args, **kwargs):
logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
@client.event
async def on_ready():
# Print status info
logger.info("Connected to Discord.")
logger.info("Account: {}#{}".format(client.user.name, client.user.discriminator))
logger.info("User ID: {}".format(client.user.id))
@client.event
async def on_message(message):
# Ignore bot messages
if message.author.bot:
return
# Split into args. shlex sucks so we don't bother with quotes
args = message.content.split(" ")
from pluralkit import proxy, utils
# Find and execute command in map
if len(args) > 0 and args[0] in utils.command_map:
subcommand_map = utils.command_map[args[0]]
if len(args) >= 2 and args[1] in subcommand_map:
async with client.pool.acquire() as conn:
await subcommand_map[args[1]][0](conn, message, args[2:])
elif None in subcommand_map:
async with client.pool.acquire() as conn:
await subcommand_map[None][0](conn, message, args[1:])
elif len(args) >= 2:
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.description = "Subcommand \"{}\" not found.".format(args[1])
await client.send_message(message.channel, embed=embed)
else:
# Try doing proxy parsing
async with client.pool.acquire() as conn:
await proxy.handle_proxying(conn, message)
@client.event
async def on_reaction_add(reaction, user):
from pluralkit import proxy
# Pass reactions to proxy system
async with client.pool.acquire() as conn:
await proxy.handle_reaction(conn, reaction, user)
async def run():
from pluralkit import db
try:
logger.info("Connecting to database...")
pool = await db.connect()
logger.info("Attempting to create tables...")
async with pool.acquire() as conn:
await db.create_tables(conn)
logger.info("Connecting to InfluxDB...")
client.pool = pool
logger.info("Connecting to Discord...")
await client.start(os.environ["TOKEN"])
finally:
logger.info("Logging out from Discord...")
await client.logout()

392
bot/pluralkit/commands.py Normal file
View File

@ -0,0 +1,392 @@
from datetime import datetime
import re
from urllib.parse import urlparse
import discord
from pluralkit import db
from pluralkit.bot import client, logger
from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map
@command(cmd="pk;system", subcommand=None, description="Shows information about your system.")
async def this_system_info(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
await client.send_message(message.channel, embed=await generate_system_info_card(conn, system))
return True
@command(cmd="pk;system", subcommand="new", usage="[name]", description="Registers a new system to this account.")
async def new_system(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is not None:
return False, "You already have a system registered. To remove your system, use `pk;system remove`, or to unlink your system from this account, use `pk;system unlink`."
system_name = None
if len(args) > 2:
system_name = " ".join(args[2:])
async with conn.transaction():
# TODO: figure out what to do if this errors out on collision on generate_hid
hid = generate_hid()
system = await db.create_system(conn, system_name=system_name, system_hid=hid)
# Link account
await db.link_account(conn, system_id=system["id"], account_id=message.author.id)
return True, "System registered! To begin adding members, use `pk;member new <name>`."
@command(cmd="pk;system", subcommand="info", usage="[system]", description="Shows information about a system.")
async def system_info(conn, message, args):
if len(args) == 0:
# Use sender's system
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
else:
# Look one up
system = await get_system_fuzzy(conn, args[0])
if system is None:
return False, "Unable to find system \"{}\".".format(args[0])
await client.send_message(message.channel, embed=await generate_system_info_card(conn, system))
return True
@command(cmd="pk;system", subcommand="name", usage="[name]", description="Renames your system. Leave blank to clear.")
async def system_name(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
if len(args) == 0:
new_name = None
else:
new_name = " ".join(args)
async with conn.transaction():
await db.update_system_field(conn, system_id=system["id"], field="name", value=new_name)
return True, "Name updated to {}.".format(new_name) if new_name else "Name cleared."
@command(cmd="pk;system", subcommand="description", usage="[clear]", description="Updates your system description. Add \"clear\" to clear.")
async def system_description(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
# If "clear" in args, clear
if len(args) > 0 and args[0] == "clear":
new_description = None
else:
new_description = await text_input(message, "your system")
if not new_description:
return True, "Description update cancelled."
async with conn.transaction():
await db.update_system_field(conn, system_id=system["id"], field="description", value=new_description)
return True, "Description set." if new_description else "Description cleared."
@command(cmd="pk;system", subcommand="tag", usage="[tag]", description="Updates your system tag. Leave blank to clear.")
async def system_tag(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
if len(args) == 0:
tag = None
else:
tag = " ".join(args)
max_length = 32
# Make sure there are no members which would make the combined length exceed 32
members_exceeding = await db.get_members_exceeding(conn, system_id=system["id"], length=max_length - len(tag))
if len(members_exceeding) > 0:
# If so, error out and warn
member_names = ", ".join([member["name"] for member in members_exceeding])
logger.debug("Members exceeding combined length with tag '{}': {}".format(tag, member_names))
return False, "The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names)
async with conn.transaction():
await db.update_system_field(conn, system_id=system["id"], field="tag", value=tag)
return True, "Tag updated to {}.".format(tag) if tag else "Tag cleared."
@command(cmd="pk;system", subcommand="remove", description="Removes your system ***permanently***.")
async def system_remove(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
await client.send_message(message.channel, "Are you sure you want to remove your system? If so, reply to this message with the system's ID (`{}`).".format(system["hid"]))
msg = await client.wait_for_message(author=message.author, channel=message.channel)
if msg.content == system["hid"]:
await db.remove_system(conn, system_id=system["id"])
return True, "System removed."
else:
return True, "System removal cancelled."
@command(cmd="pk;system", subcommand="link", usage="<account>", description="Links another account to your system.")
async def system_link(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
if len(args) == 0:
return False
# Find account to link
linkee = await parse_mention(args[0])
if not linkee:
return False, "Account not found."
# Make sure account doesn't already have a system
account_system = await db.get_system_by_account(conn, linkee.id)
if account_system:
return False, "Account is already linked to a system (`{}`)".format(account_system["hid"])
# Send confirmation message
msg = await client.send_message(message.channel, "{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention))
await client.add_reaction(msg, "")
await client.add_reaction(msg, "")
reaction = await client.wait_for_reaction(emoji=["", ""], message=msg, user=linkee)
# If account to be linked confirms...
if reaction.reaction.emoji == "":
async with conn.transaction():
# Execute the link
await db.link_account(conn, system_id=system["id"], account_id=linkee.id)
return True, "Account linked to system."
else:
await client.clear_reactions(msg)
return False, "Account link cancelled."
@command(cmd="pk;system", subcommand="unlink", description="Unlinks your system from this account. There must be at least one other account linked.")
async def system_unlink(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
# Make sure you can't unlink every account
linked_accounts = await db.get_linked_accounts(conn, system_id=system["id"])
if len(linked_accounts) == 1:
return False, "This is the only account on your system, so you can't unlink it."
async with conn.transaction():
await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id)
return True, "Account unlinked."
@command(cmd="pk;member", subcommand="new", usage="<name>", description="Adds a new member to your system.")
async def new_member(conn, message, args):
system = await db.get_system_by_account(conn, message.author.id)
if system is None:
return False, "No system is registered to this account."
if len(args) == 0:
return False
name = " ".join(args)
async with conn.transaction():
# TODO: figure out what to do if this errors out on collision on generate_hid
hid = generate_hid()
# Insert member row
await db.create_member(conn, system_id=system["id"], member_name=name, member_hid=hid)
return True, "Member \"{}\" (`{}`) registered!".format(name, hid)
@member_command(cmd="pk;member", subcommand="info", description="Shows information about a system member.", system_only=False)
async def member_info(conn, message, member, args):
await client.send_message(message.channel, embed=await generate_member_info_card(conn, member))
return True
@member_command(cmd="pk;member", subcommand="color", usage="[color]", description="Updates a member's associated color. Leave blank to clear.")
async def member_color(conn, message, member, args):
if len(args) == 0:
color = None
else:
match = re.fullmatch("#?([0-9a-f]{6})", args[0])
if not match:
return False, "Color must be a valid hex color (eg. #ff0000)"
color = match.group(1)
async with conn.transaction():
await db.update_member_field(conn, member_id=member["id"], field="color", value=color)
return True, "Color updated to #{}.".format(color) if color else "Color cleared."
@member_command(cmd="pk;member", subcommand="pronouns", usage="[pronouns]", description="Updates a member's pronouns. Leave blank to clear.")
async def member_pronouns(conn, message, member, args):
if len(args) == 0:
pronouns = None
else:
pronouns = " ".join(args)
async with conn.transaction():
await db.update_member_field(conn, member_id=member["id"], field="pronouns", value=pronouns)
return True, "Pronouns set to {}".format(pronouns) if pronouns else "Pronouns cleared."
@member_command(cmd="pk;member", subcommand="birthdate", usage="[birthdate]", description="Updates a member's birthdate. Must be in ISO-8601 format (eg. 1999-07-25). Leave blank to clear.")
async def member_birthday(conn, message, member, args):
if len(args) == 0:
new_date = None
else:
# Parse date
try:
new_date = datetime.strptime(args[0], "%Y-%m-%d").date()
except ValueError:
return False, "Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25)."
async with conn.transaction():
await db.update_member_field(conn, member_id=member["id"], field="birthday", value=new_date)
return True, "Birthdate set to {}".format(new_date) if new_date else "Birthdate cleared."
@member_command(cmd="pk;member", subcommand="description", description="Updates a member's description. Add \"clear\" to clear.")
async def member_description(conn, message, member, args):
if len(args) > 0 and args[0] == "clear":
new_description = None
else:
new_description = await text_input(message, member["name"])
if not new_description:
return True, "Description update cancelled."
async with conn.transaction():
await db.update_member_field(conn, member_id=member["id"], field="description", value=new_description)
return True, "Description set." if new_description else "Description cleared."
@member_command(cmd="pk;member", subcommand="remove", description="Removes a member from your system.")
async def member_remove(conn, message, member, args):
await client.send_message(message.channel, "Are you sure you want to remove {}? If so, reply to this message with the member's name.".format(member["name"]))
msg = await client.wait_for_message(author=message.author, channel=message.channel)
if msg.content == member["name"]:
await db.delete_member(conn, member_id=member["id"])
return True, "Member removed."
else:
return True, "Member removal cancelled."
@member_command(cmd="pk;member", subcommand="avatar", usage="[user|url]", description="Updates a member's avatar. Can be an account mention (which will use that account's avatar), or a link to an image. Leave blank to clear.")
async def member_avatar(conn, message, member, args):
if len(args) == 0:
avatar_url = None
else:
user = await parse_mention(args[0])
if user:
# Set the avatar to the mentioned user's avatar
# Discord doesn't like webp, but also hosts png alternatives
avatar_url = user.avatar_url.replace(".webp", ".png")
else:
# Validate URL
u = urlparse(args[0])
if u.scheme in ["http", "https"] and u.netloc and u.path:
avatar_url = args[0]
else:
return False, "Invalid URL."
async with conn.transaction():
await db.update_member_field(conn, member_id=member["id"], field="avatar_url", value=avatar_url)
return True, "Avatar set." if avatar_url else "Avatar cleared."
@member_command(cmd="pk;member", subcommand="proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).")
async def member_proxy(conn, message, member, args):
if len(args) == 0:
prefix, suffix = None, None
else:
# Sanity checking
example = " ".join(args)
if "text" not in example:
return False, "Example proxy message must contain the string 'text'."
if example.count("text") != 1:
return False, "Example proxy message must contain the string 'text' exactly once."
# Extract prefix and suffix
prefix = example[:example.index("text")].strip()
suffix = example[example.index("text")+4:].strip()
logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix))
# DB stores empty strings as None, make that work
if not prefix:
prefix = None
if not suffix:
suffix = None
async with conn.transaction():
await db.update_member_field(conn, member_id=member["id"], field="prefix", value=prefix)
await db.update_member_field(conn, member_id=member["id"], field="suffix", value=suffix)
return True, "Proxy settings updated." if prefix or suffix else "Proxy settings cleared."
@command(cmd="pk;message", subcommand=None, usage="<id>", description="Shows information about a proxied message. Requires the message ID.")
async def message_info(conn, message, args):
try:
mid = int(args[0])
except ValueError:
return False
# Find the message in the DB
message_row = await db.get_message(conn, mid)
if not message_row:
return False, "Message not found."
# Find the actual message object
channel = client.get_channel(str(message_row["channel"]))
message = await client.get_message(channel, str(message_row["mid"]))
# Get the original sender of the message
original_sender = await client.get_user_info(str(message_row["sender"]))
# Get sender member and system
member = await db.get_member(conn, message_row["member"])
system = await db.get_system(conn, member["system"])
embed = discord.Embed()
embed.timestamp = message.timestamp
embed.colour = discord.Colour.blue()
if system["name"]:
system_value = "`{}`: {}".format(system["hid"], system["name"])
else:
system_value = "`{}`".format(system["hid"])
embed.add_field(name="System", value=system_value)
embed.add_field(name="Member", value="`{}`: {}".format(member["hid"], member["name"]))
embed.add_field(name="Sent by", value="{}#{}".format(original_sender.name, original_sender.discriminator))
embed.add_field(name="Content", value=message.clean_content, inline=False)
embed.set_author(name=member["name"], url=member["avatar_url"])
await client.send_message(message.channel, embed=embed)
return True
@command(cmd="pk;help", subcommand=None, usage="[system|member|message]", description="Shows this help message.")
async def show_help(conn, message, args):
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.title = "PluralKit Help"
embed.set_footer(text="<> denotes mandatory arguments, [] denotes optional arguments")
if len(args) > 0 and ("pk;" + args[0]) in command_map:
cmds = ["", ("pk;" + args[0], command_map["pk;" + args[0]])]
else:
cmds = command_map.items()
for cmd, subcommands in cmds:
for subcmd, (_, usage, description) in subcommands.items():
embed.add_field(name="{} {} {}".format(cmd, subcmd or "", usage or ""), value=description, inline=False)
await client.send_message(message.channel, embed=embed)
return True

193
bot/pluralkit/db.py Normal file
View File

@ -0,0 +1,193 @@
import time
import asyncpg
import asyncpg.exceptions
from pluralkit.bot import logger
async def connect():
while True:
try:
return await asyncpg.create_pool(user="postgres", password="postgres", database="postgres", host="db")
except (ConnectionError, asyncpg.exceptions.CannotConnectNowError):
pass
def db_wrap(func):
async def inner(*args, **kwargs):
before = time.perf_counter()
res = await func(*args, **kwargs)
after = time.perf_counter()
logger.debug(" - DB took {:.2f} ms".format((after - before) * 1000))
return res
return inner
webhook_cache = {}
@db_wrap
async def create_system(conn, system_name: str, system_hid: str):
logger.debug("Creating system (name={}, hid={})".format(system_name, system_hid))
return await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid)
@db_wrap
async def remove_system(conn, system_id: int):
logger.debug("Deleting system (id={})".format(system_id))
await conn.execute("delete from systems where id = $1", system_id)
@db_wrap
async def create_member(conn, system_id: int, member_name: str, member_hid: str):
logger.debug("Creating member (system={}, name={}, hid={})".format(system_id, member_name, member_hid))
return await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid)
@db_wrap
async def delete_member(conn, member_id: int):
logger.debug("Deleting member (id={})".format(member_id))
await conn.execute("update switches set member = null, member_del = true where member = $1", member_id)
await conn.execute("delete from members where id = $1", member_id)
@db_wrap
async def link_account(conn, system_id: int, account_id: str):
logger.debug("Linking account (account_id={}, system_id={})".format(account_id, system_id))
await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id)
@db_wrap
async def unlink_account(conn, system_id: int, account_id: str):
logger.debug("Unlinking account (account_id={}, system_id={})".format(account_id, system_id))
await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id)
@db_wrap
async def get_linked_accounts(conn, system_id: int):
return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)]
@db_wrap
async def get_system_by_account(conn, account_id: str):
return await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id))
@db_wrap
async def get_system_by_hid(conn, system_hid: str):
return await conn.fetchrow("select * from systems where hid = $1", system_hid)
@db_wrap
async def get_system(conn, system_id: int):
return await conn.fetchrow("select * from systems where id = $1", system_id)
@db_wrap
async def get_member_by_name(conn, system_id: int, member_name: str):
return await conn.fetchrow("select * from members where system = $1 and name = $2", system_id, member_name)
@db_wrap
async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str):
return await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid)
@db_wrap
async def get_member_by_hid(conn, member_hid: str):
return await conn.fetchrow("select * from members where hid = $1", member_hid)
@db_wrap
async def get_member(conn, member_id: int):
return await conn.fetchrow("select * from members where id = $1", member_id)
@db_wrap
async def get_message(conn, message_id: str):
return await conn.fetchrow("select * from messages where mid = $1", message_id)
@db_wrap
async def update_system_field(conn, system_id: int, field: str, value):
logger.debug("Updating system field (id={}, {}={})".format(system_id, field, value))
await conn.execute("update systems set {} = $1 where id = $2".format(field), value, system_id)
@db_wrap
async def update_member_field(conn, member_id: int, field: str, value):
logger.debug("Updating member field (id={}, {}={})".format(member_id, field, value))
await conn.execute("update members set {} = $1 where id = $2".format(field), value, member_id)
@db_wrap
async def get_all_members(conn, system_id: int):
return await conn.fetch("select * from members where system = $1", system_id)
@db_wrap
async def get_members_exceeding(conn, system_id: int, length: int):
return await conn.fetch("select * from members where system = $1 and length(name) >= $2", system_id, length)
@db_wrap
async def get_webhook(conn, channel_id: str):
if channel_id in webhook_cache:
return webhook_cache[channel_id]
res = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
webhook_cache[channel_id] = res
return res
@db_wrap
async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str):
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(channel_id, webhook_id, webhook_token))
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token)
@db_wrap
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str):
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(message_id, channel_id, member_id, sender_id))
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id))
@db_wrap
async def get_members_by_account(conn, account_id: str):
# Returns a "chimera" object
return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id))
@db_wrap
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str):
await conn.fetchrow("select * from messages where mid = $1 and sender = $2", int(message_id), int(sender_id))
@db_wrap
async def delete_message(conn, message_id: str):
logger.debug("Deleting message (id={})".format(message_id))
await conn.execute("delete from messages where mid = $1", int(message_id))
async def create_tables(conn):
await conn.execute("""create table if not exists systems (
id serial primary key,
hid char(5) unique not null,
name text,
description text,
tag text,
created timestamp not null default current_timestamp
)""")
await conn.execute("""create table if not exists members (
id serial primary key,
hid char(5) unique not null,
system serial not null references systems(id) on delete cascade,
color char(6),
avatar_url text,
name text not null,
birthday date,
pronouns text,
description text,
prefix text,
suffix text,
created timestamp not null default current_timestamp
)""")
await conn.execute("""create table if not exists accounts (
uid bigint primary key,
system serial not null references systems(id) on delete cascade
)""")
await conn.execute("""create table if not exists messages (
mid bigint primary key,
channel bigint not null,
member serial not null references members(id) on delete cascade,
sender bigint not null references accounts(uid)
)""")
await conn.execute("""create table if not exists switches (
id serial primary key,
system serial not null references systems(id) on delete cascade,
member serial references members(id) on delete restrict,
timestamp timestamp not null default current_timestamp,
member_del bool default false
)""")
await conn.execute("""create table if not exists webhooks (
channel bigint primary key,
webhook bigint not null,
token text not null
)""")
await conn.execute("""create table if not exists servers (
id bigint primary key,
cmd_chans bigint[],
proxy_chans bigint[]
)""")

93
bot/pluralkit/proxy.py Normal file
View File

@ -0,0 +1,93 @@
import os
import time
import aiohttp
from pluralkit import db
from pluralkit.bot import client, logger
async def get_webhook(conn, channel):
async with conn.transaction():
# Try to find an existing webhook
hook_row = await db.get_webhook(conn, channel_id=channel.id)
# There's none, we'll make one
if not hook_row:
async with aiohttp.ClientSession() as session:
req_data = {"name": "PluralKit Proxy Webhook"}
req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])}
async with session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id), json=req_data, headers=req_headers) as resp:
data = await resp.json()
hook_id = data["id"]
token = data["token"]
# Insert new hook into DB
await db.add_webhook(conn, channel_id=channel.id, webhook_id=hook_id, webhook_token=token)
return hook_id, token
return hook_row["webhook"], hook_row["token"]
async def proxy_message(conn, member, message, inner):
logger.debug("Proxying message '{}' for member {}".format(inner, member["hid"]))
# Delete the original message
await client.delete_message(message)
# Get the webhook details
hook_id, hook_token = await get_webhook(conn, message.channel)
async with aiohttp.ClientSession() as session:
req_data = {
"username": "{} {}".format(member["name"], member["tag"] or "").strip(),
"avatar_url": member["avatar_url"],
"content": inner
}
req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])}
# And send the message
async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), json=req_data, headers=req_headers) as resp:
resp_data = await resp.json()
logger.debug("Discord webhook response: {}".format(resp_data))
# Insert new message details into the DB
await db.add_message(conn, message_id=resp_data["id"], channel_id=message.channel.id, member_id=member["id"], sender_id=message.author.id)
async def handle_proxying(conn, message):
# Big fat query to find every member associated with this account
# Returned member object has a few more keys (system tag, for example)
members = await db.get_members_by_account(conn, account_id=message.author.id)
# Sort by specificity (members with both prefix and suffix go higher)
members = sorted(members, key=lambda x: int(bool(x["prefix"])) + int(bool(x["suffix"])), reverse=True)
msg = message.content
for member in members:
# If no proxy details are configured, skip
if not member["prefix"] and not member["suffix"]:
continue
# Database stores empty strings as null, fix that here
prefix = member["prefix"] or ""
suffix = member["suffix"] or ""
# If we have a match, proxy the message
if msg.startswith(prefix) and msg.endswith(suffix):
# Extract the actual message contents sans tags
if suffix:
inner_message = message.content[len(prefix):-len(suffix)].strip()
else:
# Slicing to -0 breaks, don't do that
inner_message = message.content[len(prefix):].strip()
await proxy_message(conn, member, message, inner_message)
break
async def handle_reaction(conn, reaction, user):
if reaction.emoji == "":
async with conn.transaction():
# Find the message in the DB, and make sure it's sent by the user who reacted
message = await db.get_message_by_sender_and_id(conn, message_id=reaction.message.id, sender_id=user.id)
if message:
# If so, delete the message and remove it from the DB
await db.delete_message(conn, message["mid"])
await client.delete_message(reaction.message)

211
bot/pluralkit/utils.py Normal file
View File

@ -0,0 +1,211 @@
import random
import re
import string
import asyncio
import asyncpg
import discord
from pluralkit import db
from pluralkit.bot import client, logger
def generate_hid() -> str:
return "".join(random.choices(string.ascii_lowercase, k=5))
async def parse_mention(mention: str) -> discord.User:
# First try matching mention format
match = re.fullmatch("<@!?(\\d+)>", mention)
if match:
try:
return await client.get_user_info(match.group(1))
except discord.NotFound:
return None
# Then try with just ID
try:
return await client.get_user_info(str(int(mention)))
except (ValueError, discord.NotFound):
return None
async def get_system_fuzzy(conn, key) -> asyncpg.Record:
if isinstance(key, discord.User):
return await db.get_system_by_account(conn, account_id=key.id)
if isinstance(key, str) and len(key) == 5:
return await db.get_system_by_hid(conn, system_hid=key)
system = parse_mention(key)
if system:
return system
return None
async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> asyncpg.Record:
# First search by hid
if system_only:
member = await db.get_member_by_hid_in_system(conn, system_id=system_id, member_hid=key)
else:
member = await db.get_member_by_hid(conn, member_hid=key)
if member is not None:
return member
# Then search by name, if we have a system
if system_id:
member = await db.get_member_by_name(conn, system_id=system_id, member_name=key)
if member is not None:
return member
command_map = {}
# Command wrapper
# Return True for success, return False for failure
# Second parameter is the message it'll send. If just False, will print usage
def command(cmd, subcommand, usage=None, description=None):
def wrap(func):
async def wrapper(conn, message, args):
res = await func(conn, message, args)
if res is not None:
if not isinstance(res, tuple):
success, msg = res, None
else:
success, msg = res
if not success and not msg:
# Failure, no message, print usage
usage_embed = discord.Embed()
usage_embed.colour = discord.Colour.blue()
usage_embed.add_field(name="Usage", value=usage, inline=False)
await client.send_message(message.channel, embed=usage_embed)
elif not success:
# Failure, print message
error_embed = discord.Embed()
error_embed.colour = discord.Colour.dark_red()
error_embed.description = msg
await client.send_message(message.channel, embed=error_embed)
elif msg:
# Success, print message
success_embed = discord.Embed()
success_embed.colour = discord.Colour.blue()
success_embed.description = msg
await client.send_message(message.channel, embed=success_embed)
# Success, don't print anything
if cmd not in command_map:
command_map[cmd] = {}
if subcommand not in command_map[cmd]:
command_map[cmd][subcommand] = {}
command_map[cmd][subcommand] = (wrapper, usage, description)
return wrapper
return wrap
# Member command wrapper
# Tries to find member by first argument
# If system_only=False, allows members from other systems by hid
def member_command(cmd, subcommand, usage=None, description=None, system_only=True):
def wrap(func):
async def wrapper(conn, message, args):
# Return if no member param
if len(args) == 0:
return False
# If system_only, we need a system to check
system = await db.get_system_by_account(conn, message.author.id)
if system_only and system is None:
return False, "No system is registered to this account."
# System is allowed to be none if not system_only
system_id = system["id"] if system else None
# And find member by key
member = await get_member_fuzzy(conn, system_id=system_id, key=args[0], system_only=system_only)
if member is None:
return False, "Can't find member \"{}\".".format(args[0])
return await func(conn, message, member, args[1:])
return command(cmd=cmd, subcommand=subcommand, usage=usage, description=description)(wrapper)
return wrap
async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Embed:
card = discord.Embed()
if system["name"]:
card.title = system["name"]
if system["description"]:
card.add_field(name="Description", value=system["description"], inline=False)
if system["tag"]:
card.add_field(name="Tag", value=system["tag"])
# Get names of all linked accounts
async def get_name(account_id):
account = await client.get_user_info(account_id)
return "{}#{}".format(account.name, account.discriminator)
account_name_futures = []
for account_id in await db.get_linked_accounts(conn, system_id=system["id"]):
account_name_futures.append(get_name(account_id))
# Run in parallel
account_names = await asyncio.gather(*account_name_futures)
card.add_field(name="Linked accounts", value=", ".join(account_names))
# Get names of all members
member_texts = []
for member in await db.get_all_members(conn, system_id=system["id"]):
member_texts.append("`{}`: {}".format(member["hid"], member["name"]))
if len(member_texts) > 0:
card.add_field(name="Members", value="\n".join(member_texts), inline=False)
card.set_footer(text="System ID: {}".format(system["hid"]))
return card
async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Embed:
card = discord.Embed()
card.set_author(name=member["name"], icon_url=member["avatar_url"])
if member["color"]:
card.colour = int(member["color"], 16)
if member["birthday"]:
card.add_field(name="Birthdate", value=member["birthday"].strftime("%b %d, %Y"))
if member["pronouns"]:
card.add_field(name="Pronouns", value=member["pronouns"])
if member["prefix"] or member["suffix"]:
prefix = member["prefix"] or ""
suffix = member["suffix"] or ""
card.add_field(name="Proxy Tags", value="{}text{}".format(prefix, suffix))
if member["description"]:
card.add_field(name="Description", value=member["description"], inline=False)
# Get system name and hid
system = await db.get_system(conn, system_id=member["system"])
if system["name"]:
system_value = "`{}`: {}".format(system["hid"], system["name"])
else:
system_value = "`{}`".format(system["hid"])
card.add_field(name="System", value=system_value, inline=False)
card.set_footer(text="System ID: {} | Member ID: {}".format(system["hid"], member["hid"]))
return card
async def text_input(message, subject):
await client.send_message(message.channel, "Reply in this channel with the new description you want to set for {}.".format(subject))
msg = await client.wait_for_message(author=message.author, channel=message.channel)
await client.send_message(message.channel, "Alright. When you're happy with the new description, click the ✅ reaction. To cancel, click the ❌ reaction.")
await client.add_reaction(msg, "")
await client.add_reaction(msg, "")
reaction = await client.wait_for_reaction(emoji=["", ""], message=msg, user=message.author)
if reaction.reaction.emoji == "":
return msg.content
else:
await client.clear_reactions(msg)
return None

4
bot/requirements.txt Normal file
View File

@ -0,0 +1,4 @@
discord.py
asyncpg
aiohttp
aioinflux

27
docker-compose.yml Normal file
View File

@ -0,0 +1,27 @@
version: '3'
services:
bot:
build: bot
depends_on:
- db
environment:
- TOKEN
db:
image: postgres:alpine
volumes:
- db_data:/var/lib/postgres
restart: always
influx:
image: influxdb:alpine
volumes:
- /var/lib/influxdb
environment:
- INFLUXDB_GRAPHITE_ENABLED=true
restart: always
grafana:
image: grafana/grafana
ports:
- "3000:3000"
restart: always
volumes:
db_data: