Initial commit
This commit is contained in:
commit
b81b768b04
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
.env
|
||||
.vscode/
|
11
bot/Dockerfile
Normal file
11
bot/Dockerfile
Normal file
@ -0,0 +1,11 @@
|
||||
FROM python:3.6-alpine
|
||||
|
||||
RUN apk --no-cache add build-base
|
||||
|
||||
WORKDIR /app
|
||||
ADD requirements.txt /app
|
||||
RUN pip install --trusted-host pypi.python.org -r requirements.txt
|
||||
|
||||
ADD . /app
|
||||
ENTRYPOINT ["python", "main.py"]
|
||||
|
6
bot/main.py
Normal file
6
bot/main.py
Normal file
@ -0,0 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
from pluralkit import bot
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(bot.run())
|
1
bot/pluralkit/__init__.py
Normal file
1
bot/pluralkit/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from . import commands, db, proxy
|
82
bot/pluralkit/bot.py
Normal file
82
bot/pluralkit/bot.py
Normal file
@ -0,0 +1,82 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import discord
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logging.getLogger("discord").setLevel(logging.INFO)
|
||||
logging.getLogger("websockets").setLevel(logging.INFO)
|
||||
|
||||
logger = logging.getLogger("pluralkit.bot")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
client = discord.Client()
|
||||
|
||||
@client.event
|
||||
async def on_error(evt, *args, **kwargs):
|
||||
logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
# Print status info
|
||||
logger.info("Connected to Discord.")
|
||||
logger.info("Account: {}#{}".format(client.user.name, client.user.discriminator))
|
||||
logger.info("User ID: {}".format(client.user.id))
|
||||
|
||||
@client.event
|
||||
async def on_message(message):
|
||||
# Ignore bot messages
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
# Split into args. shlex sucks so we don't bother with quotes
|
||||
args = message.content.split(" ")
|
||||
|
||||
from pluralkit import proxy, utils
|
||||
|
||||
# Find and execute command in map
|
||||
if len(args) > 0 and args[0] in utils.command_map:
|
||||
subcommand_map = utils.command_map[args[0]]
|
||||
|
||||
if len(args) >= 2 and args[1] in subcommand_map:
|
||||
async with client.pool.acquire() as conn:
|
||||
await subcommand_map[args[1]][0](conn, message, args[2:])
|
||||
elif None in subcommand_map:
|
||||
async with client.pool.acquire() as conn:
|
||||
await subcommand_map[None][0](conn, message, args[1:])
|
||||
elif len(args) >= 2:
|
||||
embed = discord.Embed()
|
||||
embed.colour = discord.Colour.dark_red()
|
||||
embed.description = "Subcommand \"{}\" not found.".format(args[1])
|
||||
await client.send_message(message.channel, embed=embed)
|
||||
else:
|
||||
# Try doing proxy parsing
|
||||
async with client.pool.acquire() as conn:
|
||||
await proxy.handle_proxying(conn, message)
|
||||
|
||||
@client.event
|
||||
async def on_reaction_add(reaction, user):
|
||||
from pluralkit import proxy
|
||||
|
||||
# Pass reactions to proxy system
|
||||
async with client.pool.acquire() as conn:
|
||||
await proxy.handle_reaction(conn, reaction, user)
|
||||
|
||||
async def run():
|
||||
from pluralkit import db
|
||||
try:
|
||||
logger.info("Connecting to database...")
|
||||
pool = await db.connect()
|
||||
|
||||
logger.info("Attempting to create tables...")
|
||||
async with pool.acquire() as conn:
|
||||
await db.create_tables(conn)
|
||||
|
||||
logger.info("Connecting to InfluxDB...")
|
||||
|
||||
client.pool = pool
|
||||
logger.info("Connecting to Discord...")
|
||||
await client.start(os.environ["TOKEN"])
|
||||
finally:
|
||||
logger.info("Logging out from Discord...")
|
||||
await client.logout()
|
392
bot/pluralkit/commands.py
Normal file
392
bot/pluralkit/commands.py
Normal file
@ -0,0 +1,392 @@
|
||||
from datetime import datetime
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import discord
|
||||
|
||||
from pluralkit import db
|
||||
from pluralkit.bot import client, logger
|
||||
from pluralkit.utils import command, generate_hid, generate_member_info_card, generate_system_info_card, member_command, parse_mention, text_input, get_system_fuzzy, get_member_fuzzy, command_map
|
||||
|
||||
@command(cmd="pk;system", subcommand=None, description="Shows information about your system.")
|
||||
async def this_system_info(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
await client.send_message(message.channel, embed=await generate_system_info_card(conn, system))
|
||||
return True
|
||||
|
||||
@command(cmd="pk;system", subcommand="new", usage="[name]", description="Registers a new system to this account.")
|
||||
async def new_system(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is not None:
|
||||
return False, "You already have a system registered. To remove your system, use `pk;system remove`, or to unlink your system from this account, use `pk;system unlink`."
|
||||
|
||||
system_name = None
|
||||
if len(args) > 2:
|
||||
system_name = " ".join(args[2:])
|
||||
|
||||
async with conn.transaction():
|
||||
# TODO: figure out what to do if this errors out on collision on generate_hid
|
||||
hid = generate_hid()
|
||||
|
||||
system = await db.create_system(conn, system_name=system_name, system_hid=hid)
|
||||
|
||||
# Link account
|
||||
await db.link_account(conn, system_id=system["id"], account_id=message.author.id)
|
||||
return True, "System registered! To begin adding members, use `pk;member new <name>`."
|
||||
|
||||
@command(cmd="pk;system", subcommand="info", usage="[system]", description="Shows information about a system.")
|
||||
async def system_info(conn, message, args):
|
||||
if len(args) == 0:
|
||||
# Use sender's system
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
else:
|
||||
# Look one up
|
||||
system = await get_system_fuzzy(conn, args[0])
|
||||
|
||||
if system is None:
|
||||
return False, "Unable to find system \"{}\".".format(args[0])
|
||||
|
||||
await client.send_message(message.channel, embed=await generate_system_info_card(conn, system))
|
||||
return True
|
||||
|
||||
@command(cmd="pk;system", subcommand="name", usage="[name]", description="Renames your system. Leave blank to clear.")
|
||||
async def system_name(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
if len(args) == 0:
|
||||
new_name = None
|
||||
else:
|
||||
new_name = " ".join(args)
|
||||
|
||||
async with conn.transaction():
|
||||
await db.update_system_field(conn, system_id=system["id"], field="name", value=new_name)
|
||||
return True, "Name updated to {}.".format(new_name) if new_name else "Name cleared."
|
||||
|
||||
@command(cmd="pk;system", subcommand="description", usage="[clear]", description="Updates your system description. Add \"clear\" to clear.")
|
||||
async def system_description(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
# If "clear" in args, clear
|
||||
if len(args) > 0 and args[0] == "clear":
|
||||
new_description = None
|
||||
else:
|
||||
new_description = await text_input(message, "your system")
|
||||
|
||||
if not new_description:
|
||||
return True, "Description update cancelled."
|
||||
|
||||
async with conn.transaction():
|
||||
await db.update_system_field(conn, system_id=system["id"], field="description", value=new_description)
|
||||
return True, "Description set." if new_description else "Description cleared."
|
||||
|
||||
@command(cmd="pk;system", subcommand="tag", usage="[tag]", description="Updates your system tag. Leave blank to clear.")
|
||||
async def system_tag(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
if len(args) == 0:
|
||||
tag = None
|
||||
else:
|
||||
tag = " ".join(args)
|
||||
max_length = 32
|
||||
|
||||
# Make sure there are no members which would make the combined length exceed 32
|
||||
members_exceeding = await db.get_members_exceeding(conn, system_id=system["id"], length=max_length - len(tag))
|
||||
if len(members_exceeding) > 0:
|
||||
# If so, error out and warn
|
||||
member_names = ", ".join([member["name"] for member in members_exceeding])
|
||||
logger.debug("Members exceeding combined length with tag '{}': {}".format(tag, member_names))
|
||||
return False, "The maximum length of a name plus the system tag is 32 characters. The following members would exceed the limit: {}. Please reduce the length of the tag, or rename the members.".format(member_names)
|
||||
|
||||
async with conn.transaction():
|
||||
await db.update_system_field(conn, system_id=system["id"], field="tag", value=tag)
|
||||
|
||||
return True, "Tag updated to {}.".format(tag) if tag else "Tag cleared."
|
||||
|
||||
@command(cmd="pk;system", subcommand="remove", description="Removes your system ***permanently***.")
|
||||
async def system_remove(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
await client.send_message(message.channel, "Are you sure you want to remove your system? If so, reply to this message with the system's ID (`{}`).".format(system["hid"]))
|
||||
|
||||
msg = await client.wait_for_message(author=message.author, channel=message.channel)
|
||||
if msg.content == system["hid"]:
|
||||
await db.remove_system(conn, system_id=system["id"])
|
||||
return True, "System removed."
|
||||
else:
|
||||
return True, "System removal cancelled."
|
||||
|
||||
|
||||
@command(cmd="pk;system", subcommand="link", usage="<account>", description="Links another account to your system.")
|
||||
async def system_link(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
if len(args) == 0:
|
||||
return False
|
||||
|
||||
# Find account to link
|
||||
linkee = await parse_mention(args[0])
|
||||
if not linkee:
|
||||
return False, "Account not found."
|
||||
|
||||
# Make sure account doesn't already have a system
|
||||
account_system = await db.get_system_by_account(conn, linkee.id)
|
||||
if account_system:
|
||||
return False, "Account is already linked to a system (`{}`)".format(account_system["hid"])
|
||||
|
||||
# Send confirmation message
|
||||
msg = await client.send_message(message.channel, "{}, please confirm the link by clicking the ✅ reaction on this message.".format(linkee.mention))
|
||||
await client.add_reaction(msg, "✅")
|
||||
await client.add_reaction(msg, "❌")
|
||||
|
||||
reaction = await client.wait_for_reaction(emoji=["✅", "❌"], message=msg, user=linkee)
|
||||
# If account to be linked confirms...
|
||||
if reaction.reaction.emoji == "✅":
|
||||
async with conn.transaction():
|
||||
# Execute the link
|
||||
await db.link_account(conn, system_id=system["id"], account_id=linkee.id)
|
||||
return True, "Account linked to system."
|
||||
else:
|
||||
await client.clear_reactions(msg)
|
||||
return False, "Account link cancelled."
|
||||
|
||||
|
||||
@command(cmd="pk;system", subcommand="unlink", description="Unlinks your system from this account. There must be at least one other account linked.")
|
||||
async def system_unlink(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
# Make sure you can't unlink every account
|
||||
linked_accounts = await db.get_linked_accounts(conn, system_id=system["id"])
|
||||
if len(linked_accounts) == 1:
|
||||
return False, "This is the only account on your system, so you can't unlink it."
|
||||
|
||||
async with conn.transaction():
|
||||
await db.unlink_account(conn, system_id=system["id"], account_id=message.author.id)
|
||||
return True, "Account unlinked."
|
||||
|
||||
@command(cmd="pk;member", subcommand="new", usage="<name>", description="Adds a new member to your system.")
|
||||
async def new_member(conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
|
||||
if system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
if len(args) == 0:
|
||||
return False
|
||||
|
||||
name = " ".join(args)
|
||||
async with conn.transaction():
|
||||
# TODO: figure out what to do if this errors out on collision on generate_hid
|
||||
hid = generate_hid()
|
||||
|
||||
# Insert member row
|
||||
await db.create_member(conn, system_id=system["id"], member_name=name, member_hid=hid)
|
||||
return True, "Member \"{}\" (`{}`) registered!".format(name, hid)
|
||||
|
||||
@member_command(cmd="pk;member", subcommand="info", description="Shows information about a system member.", system_only=False)
|
||||
async def member_info(conn, message, member, args):
|
||||
await client.send_message(message.channel, embed=await generate_member_info_card(conn, member))
|
||||
return True
|
||||
|
||||
@member_command(cmd="pk;member", subcommand="color", usage="[color]", description="Updates a member's associated color. Leave blank to clear.")
|
||||
async def member_color(conn, message, member, args):
|
||||
if len(args) == 0:
|
||||
color = None
|
||||
else:
|
||||
match = re.fullmatch("#?([0-9a-f]{6})", args[0])
|
||||
if not match:
|
||||
return False, "Color must be a valid hex color (eg. #ff0000)"
|
||||
|
||||
color = match.group(1)
|
||||
|
||||
async with conn.transaction():
|
||||
await db.update_member_field(conn, member_id=member["id"], field="color", value=color)
|
||||
return True, "Color updated to #{}.".format(color) if color else "Color cleared."
|
||||
|
||||
@member_command(cmd="pk;member", subcommand="pronouns", usage="[pronouns]", description="Updates a member's pronouns. Leave blank to clear.")
|
||||
async def member_pronouns(conn, message, member, args):
|
||||
if len(args) == 0:
|
||||
pronouns = None
|
||||
else:
|
||||
pronouns = " ".join(args)
|
||||
|
||||
async with conn.transaction():
|
||||
await db.update_member_field(conn, member_id=member["id"], field="pronouns", value=pronouns)
|
||||
return True, "Pronouns set to {}".format(pronouns) if pronouns else "Pronouns cleared."
|
||||
|
||||
@member_command(cmd="pk;member", subcommand="birthdate", usage="[birthdate]", description="Updates a member's birthdate. Must be in ISO-8601 format (eg. 1999-07-25). Leave blank to clear.")
|
||||
async def member_birthday(conn, message, member, args):
|
||||
if len(args) == 0:
|
||||
new_date = None
|
||||
else:
|
||||
# Parse date
|
||||
try:
|
||||
new_date = datetime.strptime(args[0], "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
return False, "Invalid date. Date must be in ISO-8601 format (eg. 1999-07-25)."
|
||||
|
||||
async with conn.transaction():
|
||||
await db.update_member_field(conn, member_id=member["id"], field="birthday", value=new_date)
|
||||
return True, "Birthdate set to {}".format(new_date) if new_date else "Birthdate cleared."
|
||||
|
||||
@member_command(cmd="pk;member", subcommand="description", description="Updates a member's description. Add \"clear\" to clear.")
|
||||
async def member_description(conn, message, member, args):
|
||||
if len(args) > 0 and args[0] == "clear":
|
||||
new_description = None
|
||||
else:
|
||||
new_description = await text_input(message, member["name"])
|
||||
|
||||
if not new_description:
|
||||
return True, "Description update cancelled."
|
||||
|
||||
async with conn.transaction():
|
||||
await db.update_member_field(conn, member_id=member["id"], field="description", value=new_description)
|
||||
return True, "Description set." if new_description else "Description cleared."
|
||||
|
||||
@member_command(cmd="pk;member", subcommand="remove", description="Removes a member from your system.")
|
||||
async def member_remove(conn, message, member, args):
|
||||
await client.send_message(message.channel, "Are you sure you want to remove {}? If so, reply to this message with the member's name.".format(member["name"]))
|
||||
|
||||
msg = await client.wait_for_message(author=message.author, channel=message.channel)
|
||||
if msg.content == member["name"]:
|
||||
await db.delete_member(conn, member_id=member["id"])
|
||||
return True, "Member removed."
|
||||
else:
|
||||
return True, "Member removal cancelled."
|
||||
|
||||
@member_command(cmd="pk;member", subcommand="avatar", usage="[user|url]", description="Updates a member's avatar. Can be an account mention (which will use that account's avatar), or a link to an image. Leave blank to clear.")
|
||||
async def member_avatar(conn, message, member, args):
|
||||
if len(args) == 0:
|
||||
avatar_url = None
|
||||
else:
|
||||
user = await parse_mention(args[0])
|
||||
if user:
|
||||
# Set the avatar to the mentioned user's avatar
|
||||
# Discord doesn't like webp, but also hosts png alternatives
|
||||
avatar_url = user.avatar_url.replace(".webp", ".png")
|
||||
else:
|
||||
# Validate URL
|
||||
u = urlparse(args[0])
|
||||
if u.scheme in ["http", "https"] and u.netloc and u.path:
|
||||
avatar_url = args[0]
|
||||
else:
|
||||
return False, "Invalid URL."
|
||||
|
||||
async with conn.transaction():
|
||||
await db.update_member_field(conn, member_id=member["id"], field="avatar_url", value=avatar_url)
|
||||
return True, "Avatar set." if avatar_url else "Avatar cleared."
|
||||
|
||||
|
||||
@member_command(cmd="pk;member", subcommand="proxy", usage="[example]", description="Updates a member's proxy settings. Needs an \"example\" proxied message containing the string \"text\" (eg. [text], |text|, etc).")
|
||||
async def member_proxy(conn, message, member, args):
|
||||
if len(args) == 0:
|
||||
prefix, suffix = None, None
|
||||
else:
|
||||
# Sanity checking
|
||||
example = " ".join(args)
|
||||
if "text" not in example:
|
||||
return False, "Example proxy message must contain the string 'text'."
|
||||
|
||||
if example.count("text") != 1:
|
||||
return False, "Example proxy message must contain the string 'text' exactly once."
|
||||
|
||||
# Extract prefix and suffix
|
||||
prefix = example[:example.index("text")].strip()
|
||||
suffix = example[example.index("text")+4:].strip()
|
||||
logger.debug("Matched prefix '{}' and suffix '{}'".format(prefix, suffix))
|
||||
|
||||
# DB stores empty strings as None, make that work
|
||||
if not prefix:
|
||||
prefix = None
|
||||
if not suffix:
|
||||
suffix = None
|
||||
|
||||
async with conn.transaction():
|
||||
await db.update_member_field(conn, member_id=member["id"], field="prefix", value=prefix)
|
||||
await db.update_member_field(conn, member_id=member["id"], field="suffix", value=suffix)
|
||||
return True, "Proxy settings updated." if prefix or suffix else "Proxy settings cleared."
|
||||
|
||||
@command(cmd="pk;message", subcommand=None, usage="<id>", description="Shows information about a proxied message. Requires the message ID.")
|
||||
async def message_info(conn, message, args):
|
||||
try:
|
||||
mid = int(args[0])
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# Find the message in the DB
|
||||
message_row = await db.get_message(conn, mid)
|
||||
if not message_row:
|
||||
return False, "Message not found."
|
||||
|
||||
# Find the actual message object
|
||||
channel = client.get_channel(str(message_row["channel"]))
|
||||
message = await client.get_message(channel, str(message_row["mid"]))
|
||||
|
||||
# Get the original sender of the message
|
||||
original_sender = await client.get_user_info(str(message_row["sender"]))
|
||||
|
||||
# Get sender member and system
|
||||
member = await db.get_member(conn, message_row["member"])
|
||||
system = await db.get_system(conn, member["system"])
|
||||
|
||||
embed = discord.Embed()
|
||||
embed.timestamp = message.timestamp
|
||||
embed.colour = discord.Colour.blue()
|
||||
|
||||
if system["name"]:
|
||||
system_value = "`{}`: {}".format(system["hid"], system["name"])
|
||||
else:
|
||||
system_value = "`{}`".format(system["hid"])
|
||||
embed.add_field(name="System", value=system_value)
|
||||
embed.add_field(name="Member", value="`{}`: {}".format(member["hid"], member["name"]))
|
||||
embed.add_field(name="Sent by", value="{}#{}".format(original_sender.name, original_sender.discriminator))
|
||||
embed.add_field(name="Content", value=message.clean_content, inline=False)
|
||||
|
||||
embed.set_author(name=member["name"], url=member["avatar_url"])
|
||||
|
||||
await client.send_message(message.channel, embed=embed)
|
||||
return True
|
||||
|
||||
@command(cmd="pk;help", subcommand=None, usage="[system|member|message]", description="Shows this help message.")
|
||||
async def show_help(conn, message, args):
|
||||
embed = discord.Embed()
|
||||
embed.colour = discord.Colour.blue()
|
||||
embed.title = "PluralKit Help"
|
||||
embed.set_footer(text="<> denotes mandatory arguments, [] denotes optional arguments")
|
||||
|
||||
if len(args) > 0 and ("pk;" + args[0]) in command_map:
|
||||
cmds = ["", ("pk;" + args[0], command_map["pk;" + args[0]])]
|
||||
else:
|
||||
cmds = command_map.items()
|
||||
|
||||
for cmd, subcommands in cmds:
|
||||
for subcmd, (_, usage, description) in subcommands.items():
|
||||
embed.add_field(name="{} {} {}".format(cmd, subcmd or "", usage or ""), value=description, inline=False)
|
||||
|
||||
await client.send_message(message.channel, embed=embed)
|
||||
return True
|
193
bot/pluralkit/db.py
Normal file
193
bot/pluralkit/db.py
Normal file
@ -0,0 +1,193 @@
|
||||
import time
|
||||
|
||||
import asyncpg
|
||||
import asyncpg.exceptions
|
||||
|
||||
from pluralkit.bot import logger
|
||||
|
||||
async def connect():
|
||||
while True:
|
||||
try:
|
||||
return await asyncpg.create_pool(user="postgres", password="postgres", database="postgres", host="db")
|
||||
except (ConnectionError, asyncpg.exceptions.CannotConnectNowError):
|
||||
pass
|
||||
|
||||
def db_wrap(func):
|
||||
async def inner(*args, **kwargs):
|
||||
before = time.perf_counter()
|
||||
res = await func(*args, **kwargs)
|
||||
after = time.perf_counter()
|
||||
|
||||
logger.debug(" - DB took {:.2f} ms".format((after - before) * 1000))
|
||||
return res
|
||||
return inner
|
||||
|
||||
webhook_cache = {}
|
||||
|
||||
@db_wrap
|
||||
async def create_system(conn, system_name: str, system_hid: str):
|
||||
logger.debug("Creating system (name={}, hid={})".format(system_name, system_hid))
|
||||
return await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid)
|
||||
|
||||
@db_wrap
|
||||
async def remove_system(conn, system_id: int):
|
||||
logger.debug("Deleting system (id={})".format(system_id))
|
||||
await conn.execute("delete from systems where id = $1", system_id)
|
||||
|
||||
@db_wrap
|
||||
async def create_member(conn, system_id: int, member_name: str, member_hid: str):
|
||||
logger.debug("Creating member (system={}, name={}, hid={})".format(system_id, member_name, member_hid))
|
||||
return await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid)
|
||||
|
||||
@db_wrap
|
||||
async def delete_member(conn, member_id: int):
|
||||
logger.debug("Deleting member (id={})".format(member_id))
|
||||
await conn.execute("update switches set member = null, member_del = true where member = $1", member_id)
|
||||
await conn.execute("delete from members where id = $1", member_id)
|
||||
|
||||
@db_wrap
|
||||
async def link_account(conn, system_id: int, account_id: str):
|
||||
logger.debug("Linking account (account_id={}, system_id={})".format(account_id, system_id))
|
||||
await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id)
|
||||
|
||||
@db_wrap
|
||||
async def unlink_account(conn, system_id: int, account_id: str):
|
||||
logger.debug("Unlinking account (account_id={}, system_id={})".format(account_id, system_id))
|
||||
await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id)
|
||||
|
||||
@db_wrap
|
||||
async def get_linked_accounts(conn, system_id: int):
|
||||
return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)]
|
||||
|
||||
@db_wrap
|
||||
async def get_system_by_account(conn, account_id: str):
|
||||
return await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id))
|
||||
|
||||
@db_wrap
|
||||
async def get_system_by_hid(conn, system_hid: str):
|
||||
return await conn.fetchrow("select * from systems where hid = $1", system_hid)
|
||||
|
||||
@db_wrap
|
||||
async def get_system(conn, system_id: int):
|
||||
return await conn.fetchrow("select * from systems where id = $1", system_id)
|
||||
|
||||
@db_wrap
|
||||
async def get_member_by_name(conn, system_id: int, member_name: str):
|
||||
return await conn.fetchrow("select * from members where system = $1 and name = $2", system_id, member_name)
|
||||
|
||||
@db_wrap
|
||||
async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str):
|
||||
return await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid)
|
||||
|
||||
@db_wrap
|
||||
async def get_member_by_hid(conn, member_hid: str):
|
||||
return await conn.fetchrow("select * from members where hid = $1", member_hid)
|
||||
|
||||
@db_wrap
|
||||
async def get_member(conn, member_id: int):
|
||||
return await conn.fetchrow("select * from members where id = $1", member_id)
|
||||
|
||||
@db_wrap
|
||||
async def get_message(conn, message_id: str):
|
||||
return await conn.fetchrow("select * from messages where mid = $1", message_id)
|
||||
|
||||
@db_wrap
|
||||
async def update_system_field(conn, system_id: int, field: str, value):
|
||||
logger.debug("Updating system field (id={}, {}={})".format(system_id, field, value))
|
||||
await conn.execute("update systems set {} = $1 where id = $2".format(field), value, system_id)
|
||||
|
||||
@db_wrap
|
||||
async def update_member_field(conn, member_id: int, field: str, value):
|
||||
logger.debug("Updating member field (id={}, {}={})".format(member_id, field, value))
|
||||
await conn.execute("update members set {} = $1 where id = $2".format(field), value, member_id)
|
||||
|
||||
@db_wrap
|
||||
async def get_all_members(conn, system_id: int):
|
||||
return await conn.fetch("select * from members where system = $1", system_id)
|
||||
|
||||
@db_wrap
|
||||
async def get_members_exceeding(conn, system_id: int, length: int):
|
||||
return await conn.fetch("select * from members where system = $1 and length(name) >= $2", system_id, length)
|
||||
|
||||
@db_wrap
|
||||
async def get_webhook(conn, channel_id: str):
|
||||
if channel_id in webhook_cache:
|
||||
return webhook_cache[channel_id]
|
||||
res = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
|
||||
webhook_cache[channel_id] = res
|
||||
return res
|
||||
|
||||
@db_wrap
|
||||
async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str):
|
||||
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(channel_id, webhook_id, webhook_token))
|
||||
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token)
|
||||
|
||||
@db_wrap
|
||||
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str):
|
||||
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(message_id, channel_id, member_id, sender_id))
|
||||
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id))
|
||||
|
||||
@db_wrap
|
||||
async def get_members_by_account(conn, account_id: str):
|
||||
# Returns a "chimera" object
|
||||
return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id))
|
||||
|
||||
@db_wrap
|
||||
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str):
|
||||
await conn.fetchrow("select * from messages where mid = $1 and sender = $2", int(message_id), int(sender_id))
|
||||
|
||||
@db_wrap
|
||||
async def delete_message(conn, message_id: str):
|
||||
logger.debug("Deleting message (id={})".format(message_id))
|
||||
await conn.execute("delete from messages where mid = $1", int(message_id))
|
||||
|
||||
async def create_tables(conn):
|
||||
await conn.execute("""create table if not exists systems (
|
||||
id serial primary key,
|
||||
hid char(5) unique not null,
|
||||
name text,
|
||||
description text,
|
||||
tag text,
|
||||
created timestamp not null default current_timestamp
|
||||
)""")
|
||||
await conn.execute("""create table if not exists members (
|
||||
id serial primary key,
|
||||
hid char(5) unique not null,
|
||||
system serial not null references systems(id) on delete cascade,
|
||||
color char(6),
|
||||
avatar_url text,
|
||||
name text not null,
|
||||
birthday date,
|
||||
pronouns text,
|
||||
description text,
|
||||
prefix text,
|
||||
suffix text,
|
||||
created timestamp not null default current_timestamp
|
||||
)""")
|
||||
await conn.execute("""create table if not exists accounts (
|
||||
uid bigint primary key,
|
||||
system serial not null references systems(id) on delete cascade
|
||||
)""")
|
||||
await conn.execute("""create table if not exists messages (
|
||||
mid bigint primary key,
|
||||
channel bigint not null,
|
||||
member serial not null references members(id) on delete cascade,
|
||||
sender bigint not null references accounts(uid)
|
||||
)""")
|
||||
await conn.execute("""create table if not exists switches (
|
||||
id serial primary key,
|
||||
system serial not null references systems(id) on delete cascade,
|
||||
member serial references members(id) on delete restrict,
|
||||
timestamp timestamp not null default current_timestamp,
|
||||
member_del bool default false
|
||||
)""")
|
||||
await conn.execute("""create table if not exists webhooks (
|
||||
channel bigint primary key,
|
||||
webhook bigint not null,
|
||||
token text not null
|
||||
)""")
|
||||
await conn.execute("""create table if not exists servers (
|
||||
id bigint primary key,
|
||||
cmd_chans bigint[],
|
||||
proxy_chans bigint[]
|
||||
)""")
|
93
bot/pluralkit/proxy.py
Normal file
93
bot/pluralkit/proxy.py
Normal file
@ -0,0 +1,93 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
|
||||
from pluralkit import db
|
||||
from pluralkit.bot import client, logger
|
||||
|
||||
async def get_webhook(conn, channel):
|
||||
async with conn.transaction():
|
||||
# Try to find an existing webhook
|
||||
hook_row = await db.get_webhook(conn, channel_id=channel.id)
|
||||
# There's none, we'll make one
|
||||
if not hook_row:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
req_data = {"name": "PluralKit Proxy Webhook"}
|
||||
req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])}
|
||||
|
||||
async with session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id), json=req_data, headers=req_headers) as resp:
|
||||
data = await resp.json()
|
||||
hook_id = data["id"]
|
||||
token = data["token"]
|
||||
|
||||
# Insert new hook into DB
|
||||
await db.add_webhook(conn, channel_id=channel.id, webhook_id=hook_id, webhook_token=token)
|
||||
return hook_id, token
|
||||
|
||||
return hook_row["webhook"], hook_row["token"]
|
||||
|
||||
async def proxy_message(conn, member, message, inner):
|
||||
logger.debug("Proxying message '{}' for member {}".format(inner, member["hid"]))
|
||||
# Delete the original message
|
||||
await client.delete_message(message)
|
||||
|
||||
# Get the webhook details
|
||||
hook_id, hook_token = await get_webhook(conn, message.channel)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
req_data = {
|
||||
"username": "{} {}".format(member["name"], member["tag"] or "").strip(),
|
||||
"avatar_url": member["avatar_url"],
|
||||
"content": inner
|
||||
}
|
||||
req_headers = {"Authorization": "Bot {}".format(os.environ["TOKEN"])}
|
||||
# And send the message
|
||||
async with session.post("https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), json=req_data, headers=req_headers) as resp:
|
||||
resp_data = await resp.json()
|
||||
logger.debug("Discord webhook response: {}".format(resp_data))
|
||||
|
||||
# Insert new message details into the DB
|
||||
await db.add_message(conn, message_id=resp_data["id"], channel_id=message.channel.id, member_id=member["id"], sender_id=message.author.id)
|
||||
|
||||
async def handle_proxying(conn, message):
|
||||
# Big fat query to find every member associated with this account
|
||||
# Returned member object has a few more keys (system tag, for example)
|
||||
members = await db.get_members_by_account(conn, account_id=message.author.id)
|
||||
|
||||
# Sort by specificity (members with both prefix and suffix go higher)
|
||||
members = sorted(members, key=lambda x: int(bool(x["prefix"])) + int(bool(x["suffix"])), reverse=True)
|
||||
|
||||
msg = message.content
|
||||
for member in members:
|
||||
# If no proxy details are configured, skip
|
||||
if not member["prefix"] and not member["suffix"]:
|
||||
continue
|
||||
|
||||
# Database stores empty strings as null, fix that here
|
||||
prefix = member["prefix"] or ""
|
||||
suffix = member["suffix"] or ""
|
||||
|
||||
# If we have a match, proxy the message
|
||||
if msg.startswith(prefix) and msg.endswith(suffix):
|
||||
# Extract the actual message contents sans tags
|
||||
if suffix:
|
||||
inner_message = message.content[len(prefix):-len(suffix)].strip()
|
||||
else:
|
||||
# Slicing to -0 breaks, don't do that
|
||||
inner_message = message.content[len(prefix):].strip()
|
||||
|
||||
await proxy_message(conn, member, message, inner_message)
|
||||
break
|
||||
|
||||
|
||||
|
||||
async def handle_reaction(conn, reaction, user):
|
||||
if reaction.emoji == "❌":
|
||||
async with conn.transaction():
|
||||
# Find the message in the DB, and make sure it's sent by the user who reacted
|
||||
message = await db.get_message_by_sender_and_id(conn, message_id=reaction.message.id, sender_id=user.id)
|
||||
|
||||
if message:
|
||||
# If so, delete the message and remove it from the DB
|
||||
await db.delete_message(conn, message["mid"])
|
||||
await client.delete_message(reaction.message)
|
211
bot/pluralkit/utils.py
Normal file
211
bot/pluralkit/utils.py
Normal file
@ -0,0 +1,211 @@
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
|
||||
import asyncio
|
||||
import asyncpg
|
||||
import discord
|
||||
|
||||
from pluralkit import db
|
||||
from pluralkit.bot import client, logger
|
||||
|
||||
def generate_hid() -> str:
|
||||
return "".join(random.choices(string.ascii_lowercase, k=5))
|
||||
|
||||
async def parse_mention(mention: str) -> discord.User:
|
||||
# First try matching mention format
|
||||
match = re.fullmatch("<@!?(\\d+)>", mention)
|
||||
if match:
|
||||
try:
|
||||
return await client.get_user_info(match.group(1))
|
||||
except discord.NotFound:
|
||||
return None
|
||||
|
||||
# Then try with just ID
|
||||
try:
|
||||
return await client.get_user_info(str(int(mention)))
|
||||
except (ValueError, discord.NotFound):
|
||||
return None
|
||||
|
||||
async def get_system_fuzzy(conn, key) -> asyncpg.Record:
|
||||
if isinstance(key, discord.User):
|
||||
return await db.get_system_by_account(conn, account_id=key.id)
|
||||
|
||||
if isinstance(key, str) and len(key) == 5:
|
||||
return await db.get_system_by_hid(conn, system_hid=key)
|
||||
|
||||
system = parse_mention(key)
|
||||
|
||||
if system:
|
||||
return system
|
||||
return None
|
||||
|
||||
async def get_member_fuzzy(conn, system_id: int, key: str, system_only=True) -> asyncpg.Record:
|
||||
# First search by hid
|
||||
if system_only:
|
||||
member = await db.get_member_by_hid_in_system(conn, system_id=system_id, member_hid=key)
|
||||
else:
|
||||
member = await db.get_member_by_hid(conn, member_hid=key)
|
||||
if member is not None:
|
||||
return member
|
||||
|
||||
# Then search by name, if we have a system
|
||||
if system_id:
|
||||
member = await db.get_member_by_name(conn, system_id=system_id, member_name=key)
|
||||
if member is not None:
|
||||
return member
|
||||
|
||||
command_map = {}
|
||||
|
||||
# Command wrapper
|
||||
# Return True for success, return False for failure
|
||||
# Second parameter is the message it'll send. If just False, will print usage
|
||||
def command(cmd, subcommand, usage=None, description=None):
|
||||
def wrap(func):
|
||||
async def wrapper(conn, message, args):
|
||||
res = await func(conn, message, args)
|
||||
|
||||
if res is not None:
|
||||
if not isinstance(res, tuple):
|
||||
success, msg = res, None
|
||||
else:
|
||||
success, msg = res
|
||||
|
||||
if not success and not msg:
|
||||
# Failure, no message, print usage
|
||||
usage_embed = discord.Embed()
|
||||
usage_embed.colour = discord.Colour.blue()
|
||||
usage_embed.add_field(name="Usage", value=usage, inline=False)
|
||||
|
||||
await client.send_message(message.channel, embed=usage_embed)
|
||||
elif not success:
|
||||
# Failure, print message
|
||||
error_embed = discord.Embed()
|
||||
error_embed.colour = discord.Colour.dark_red()
|
||||
error_embed.description = msg
|
||||
await client.send_message(message.channel, embed=error_embed)
|
||||
elif msg:
|
||||
# Success, print message
|
||||
success_embed = discord.Embed()
|
||||
success_embed.colour = discord.Colour.blue()
|
||||
success_embed.description = msg
|
||||
await client.send_message(message.channel, embed=success_embed)
|
||||
# Success, don't print anything
|
||||
if cmd not in command_map:
|
||||
command_map[cmd] = {}
|
||||
if subcommand not in command_map[cmd]:
|
||||
command_map[cmd][subcommand] = {}
|
||||
|
||||
command_map[cmd][subcommand] = (wrapper, usage, description)
|
||||
return wrapper
|
||||
return wrap
|
||||
|
||||
# Member command wrapper
|
||||
# Tries to find member by first argument
|
||||
# If system_only=False, allows members from other systems by hid
|
||||
def member_command(cmd, subcommand, usage=None, description=None, system_only=True):
|
||||
def wrap(func):
|
||||
async def wrapper(conn, message, args):
|
||||
# Return if no member param
|
||||
if len(args) == 0:
|
||||
return False
|
||||
|
||||
# If system_only, we need a system to check
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
if system_only and system is None:
|
||||
return False, "No system is registered to this account."
|
||||
|
||||
# System is allowed to be none if not system_only
|
||||
system_id = system["id"] if system else None
|
||||
# And find member by key
|
||||
member = await get_member_fuzzy(conn, system_id=system_id, key=args[0], system_only=system_only)
|
||||
|
||||
if member is None:
|
||||
return False, "Can't find member \"{}\".".format(args[0])
|
||||
|
||||
return await func(conn, message, member, args[1:])
|
||||
return command(cmd=cmd, subcommand=subcommand, usage=usage, description=description)(wrapper)
|
||||
return wrap
|
||||
|
||||
async def generate_system_info_card(conn, system: asyncpg.Record) -> discord.Embed:
|
||||
card = discord.Embed()
|
||||
|
||||
if system["name"]:
|
||||
card.title = system["name"]
|
||||
|
||||
if system["description"]:
|
||||
card.add_field(name="Description", value=system["description"], inline=False)
|
||||
|
||||
if system["tag"]:
|
||||
card.add_field(name="Tag", value=system["tag"])
|
||||
|
||||
# Get names of all linked accounts
|
||||
async def get_name(account_id):
|
||||
account = await client.get_user_info(account_id)
|
||||
return "{}#{}".format(account.name, account.discriminator)
|
||||
|
||||
account_name_futures = []
|
||||
for account_id in await db.get_linked_accounts(conn, system_id=system["id"]):
|
||||
account_name_futures.append(get_name(account_id))
|
||||
# Run in parallel
|
||||
account_names = await asyncio.gather(*account_name_futures)
|
||||
|
||||
card.add_field(name="Linked accounts", value=", ".join(account_names))
|
||||
|
||||
# Get names of all members
|
||||
member_texts = []
|
||||
for member in await db.get_all_members(conn, system_id=system["id"]):
|
||||
member_texts.append("`{}`: {}".format(member["hid"], member["name"]))
|
||||
|
||||
if len(member_texts) > 0:
|
||||
card.add_field(name="Members", value="\n".join(member_texts), inline=False)
|
||||
|
||||
card.set_footer(text="System ID: {}".format(system["hid"]))
|
||||
return card
|
||||
|
||||
async def generate_member_info_card(conn, member: asyncpg.Record) -> discord.Embed:
|
||||
card = discord.Embed()
|
||||
card.set_author(name=member["name"], icon_url=member["avatar_url"])
|
||||
|
||||
if member["color"]:
|
||||
card.colour = int(member["color"], 16)
|
||||
|
||||
if member["birthday"]:
|
||||
card.add_field(name="Birthdate", value=member["birthday"].strftime("%b %d, %Y"))
|
||||
|
||||
if member["pronouns"]:
|
||||
card.add_field(name="Pronouns", value=member["pronouns"])
|
||||
|
||||
if member["prefix"] or member["suffix"]:
|
||||
prefix = member["prefix"] or ""
|
||||
suffix = member["suffix"] or ""
|
||||
card.add_field(name="Proxy Tags", value="{}text{}".format(prefix, suffix))
|
||||
|
||||
if member["description"]:
|
||||
card.add_field(name="Description", value=member["description"], inline=False)
|
||||
|
||||
# Get system name and hid
|
||||
system = await db.get_system(conn, system_id=member["system"])
|
||||
if system["name"]:
|
||||
system_value = "`{}`: {}".format(system["hid"], system["name"])
|
||||
else:
|
||||
system_value = "`{}`".format(system["hid"])
|
||||
card.add_field(name="System", value=system_value, inline=False)
|
||||
|
||||
card.set_footer(text="System ID: {} | Member ID: {}".format(system["hid"], member["hid"]))
|
||||
return card
|
||||
|
||||
async def text_input(message, subject):
|
||||
await client.send_message(message.channel, "Reply in this channel with the new description you want to set for {}.".format(subject))
|
||||
msg = await client.wait_for_message(author=message.author, channel=message.channel)
|
||||
|
||||
await client.send_message(message.channel, "Alright. When you're happy with the new description, click the ✅ reaction. To cancel, click the ❌ reaction.")
|
||||
await client.add_reaction(msg, "✅")
|
||||
await client.add_reaction(msg, "❌")
|
||||
|
||||
reaction = await client.wait_for_reaction(emoji=["✅", "❌"], message=msg, user=message.author)
|
||||
if reaction.reaction.emoji == "✅":
|
||||
return msg.content
|
||||
else:
|
||||
await client.clear_reactions(msg)
|
||||
return None
|
4
bot/requirements.txt
Normal file
4
bot/requirements.txt
Normal file
@ -0,0 +1,4 @@
|
||||
discord.py
|
||||
asyncpg
|
||||
aiohttp
|
||||
aioinflux
|
27
docker-compose.yml
Normal file
27
docker-compose.yml
Normal file
@ -0,0 +1,27 @@
|
||||
version: '3'
|
||||
services:
|
||||
bot:
|
||||
build: bot
|
||||
depends_on:
|
||||
- db
|
||||
environment:
|
||||
- TOKEN
|
||||
db:
|
||||
image: postgres:alpine
|
||||
volumes:
|
||||
- db_data:/var/lib/postgres
|
||||
restart: always
|
||||
influx:
|
||||
image: influxdb:alpine
|
||||
volumes:
|
||||
- /var/lib/influxdb
|
||||
environment:
|
||||
- INFLUXDB_GRAPHITE_ENABLED=true
|
||||
restart: always
|
||||
grafana:
|
||||
image: grafana/grafana
|
||||
ports:
|
||||
- "3000:3000"
|
||||
restart: always
|
||||
volumes:
|
||||
db_data:
|
Loading…
Reference in New Issue
Block a user