Basic proxy functionality fixed

This commit is contained in:
Ske 2018-10-27 22:00:41 +02:00
parent c8caeadec4
commit 4217d5d5d8
6 changed files with 499 additions and 369 deletions

View File

@ -5,7 +5,4 @@ import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from pluralkit import bot
pk = bot.PluralKitBot(os.environ["TOKEN"])
loop = asyncio.get_event_loop()
loop.run_until_complete(pk.run())
bot.run()

View File

@ -1,127 +1,187 @@
import asyncio
import json
import logging
import os
import time
import asyncpg
import sys
import traceback
from datetime import datetime
import asyncio
import os
import logging
import discord
from pluralkit import db
from pluralkit.bot import channel_logger, commands, proxy, embeds
from pluralkit.bot import commands, proxy
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
def connect_to_database() -> asyncpg.pool.Pool:
username = os.environ["DATABASE_USER"]
password = os.environ["DATABASE_PASS"]
name = os.environ["DATABASE_NAME"]
host = os.environ["DATABASE_HOST"]
port = os.environ["DATABASE_PORT"]
# logging.getLogger("pluralkit").setLevel(logging.DEBUG)
if username is None or password is None or name is None or host is None or port is None:
print("Database credentials not specified. Please pass valid PostgreSQL database credentials in the DATABASE_[USER|PASS|NAME|HOST|PORT] environment variable.", file=sys.stderr)
sys.exit(1)
class PluralKitBot:
def __init__(self, token):
self.token = token
self.logger = logging.getLogger("pluralkit.bot")
try:
port = int(port)
except ValueError:
print("Please pass a valid integer as the DATABASE_PORT environment variable.", file=sys.stderr)
sys.exit(1)
self.client = discord.Client()
self.client.event(self.on_error)
self.client.event(self.on_ready)
self.client.event(self.on_message)
self.client.event(self.on_socket_raw_receive)
return asyncio.get_event_loop().run_until_complete(db.connect(
username=username,
password=password,
database=name,
host=host,
port=port
))
self.channel_logger = channel_logger.ChannelLogger(self.client)
def run():
pool = connect_to_database()
self.proxy = proxy.Proxy(self.client, token, self.channel_logger)
client = discord.Client()
async def on_error(self, evt, *args, **kwargs):
self.logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
@client.event
async def on_ready():
print("PluralKit started.")
print("User: {}#{} (ID: {})".format(client.user.name, client.user.discriminator, client.user.id))
print("{} servers".format(len(client.guilds)))
print("{} shards".format(client.shard_count or 1))
async def on_ready(self):
self.logger.info("Connected to Discord.")
self.logger.info("- Account: {}#{}".format(self.client.user.name, self.client.user.discriminator))
self.logger.info("- User ID: {}".format(self.client.user.id))
self.logger.info("- {} servers".format(len(self.client.servers)))
async def on_message(self, message):
# Ignore bot messages
@client.event
async def on_message(message: discord.Message):
# Ignore messages from bots
if message.author.bot:
return
try:
if await self.handle_command_dispatch(message):
# Grab a database connection from the pool
async with pool.acquire() as conn:
# First pass: do command handling
did_run_command = await commands.command_dispatch(client, message, conn)
if did_run_command:
return
if await self.handle_proxy_dispatch(message):
return
except Exception:
await self.log_error_in_channel(message)
# Second pass: do proxy matching
await proxy.try_proxy_message(message, conn)
async def on_socket_raw_receive(self, msg):
# Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
# we parse socket data manually for the reaction add event
if isinstance(msg, str):
try:
msg_data = json.loads(msg)
if msg_data.get("t") == "MESSAGE_REACTION_ADD":
evt_data = msg_data.get("d")
if evt_data:
user_id = evt_data["user_id"]
message_id = evt_data["message_id"]
emoji = evt_data["emoji"]["name"]
async with self.pool.acquire() as conn:
await self.proxy.handle_reaction(conn, user_id, message_id, emoji)
elif msg_data.get("t") == "MESSAGE_DELETE":
evt_data = msg_data.get("d")
if evt_data:
message_id = evt_data["id"]
async with self.pool.acquire() as conn:
await self.proxy.handle_deletion(conn, message_id)
except ValueError:
pass
bot_token = os.environ["TOKEN"]
if not bot_token:
print("No token specified. Please pass a valid Discord bot token in the TOKEN environment variable.",
file=sys.stderr)
sys.exit(1)
async def handle_command_dispatch(self, message):
async with self.pool.acquire() as conn:
result = await commands.command_dispatch(self.client, message, conn)
return result
client.run(bot_token)
async def handle_proxy_dispatch(self, message):
# Try doing proxy parsing
async with self.pool.acquire() as conn:
return await self.proxy.try_proxy_message(conn, message)
# logging.getLogger("pluralkit").setLevel(logging.DEBUG)
async def log_error_in_channel(self, message):
channel_id = os.environ["LOG_CHANNEL"]
if not channel_id:
return
channel = self.client.get_channel(channel_id)
embed = embeds.exception_log(
message.content,
message.author.name,
message.author.discriminator,
message.server.id if message.server else None,
message.channel.id
)
await self.client.send_message(channel, "```python\n{}```".format(traceback.format_exc()), embed=embed)
async def run(self):
try:
self.logger.info("Connecting to database...")
self.pool = await db.connect(
os.environ["DATABASE_USER"],
os.environ["DATABASE_PASS"],
os.environ["DATABASE_NAME"],
os.environ["DATABASE_HOST"],
int(os.environ["DATABASE_PORT"])
)
self.logger.info("Attempting to create tables...")
async with self.pool.acquire() as conn:
await db.create_tables(conn)
self.logger.info("Connecting to Discord...")
await self.client.start(self.token)
finally:
self.logger.info("Logging out from Discord...")
await self.client.logout()
# class PluralKitBot:
# def __init__(self, token):
# self.token = token
# self.logger = logging.getLogger("pluralkit.bot")
#
# self.client = discord.Client()
# self.client.event(self.on_error)
# self.client.event(self.on_ready)
# self.client.event(self.on_message)
# self.client.event(self.on_socket_raw_receive)
#
# self.channel_logger = channel_logger.ChannelLogger(self.client)
#
# self.proxy = proxy.Proxy(self.client, token, self.channel_logger)
#
# async def on_error(self, evt, *args, **kwargs):
# self.logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
#
# async def on_ready(self):
# self.logger.info("Connected to Discord.")
# self.logger.info("- Account: {}#{}".format(self.client.user.name, self.client.user.discriminator))
# self.logger.info("- User ID: {}".format(self.client.user.id))
# self.logger.info("- {} servers".format(len(self.client.servers)))
#
# async def on_message(self, message):
# # Ignore bot messages
# if message.author.bot:
# return
#
# try:
# if await self.handle_command_dispatch(message):
# return
#
# if await self.handle_proxy_dispatch(message):
# return
# except Exception:
# await self.log_error_in_channel(message)
#
# async def on_socket_raw_receive(self, msg):
# # Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
# # we parse socket data manually for the reaction add event
# if isinstance(msg, str):
# try:
# msg_data = json.loads(msg)
# if msg_data.get("t") == "MESSAGE_REACTION_ADD":
# evt_data = msg_data.get("d")
# if evt_data:
# user_id = evt_data["user_id"]
# message_id = evt_data["message_id"]
# emoji = evt_data["emoji"]["name"]
#
# async with self.pool.acquire() as conn:
# await self.proxy.handle_reaction(conn, user_id, message_id, emoji)
# elif msg_data.get("t") == "MESSAGE_DELETE":
# evt_data = msg_data.get("d")
# if evt_data:
# message_id = evt_data["id"]
# async with self.pool.acquire() as conn:
# await self.proxy.handle_deletion(conn, message_id)
# except ValueError:
# pass
#
# async def handle_command_dispatch(self, message):
# async with self.pool.acquire() as conn:
# result = await commands.command_dispatch(self.client, message, conn)
# return result
#
# async def handle_proxy_dispatch(self, message):
# # Try doing proxy parsing
# async with self.pool.acquire() as conn:
# return await self.proxy.try_proxy_message(conn, message)
#
# async def log_error_in_channel(self, message):
# channel_id = os.environ["LOG_CHANNEL"]
# if not channel_id:
# return
#
# channel = self.client.get_channel(channel_id)
# embed = embeds.exception_log(
# message.content,
# message.author.name,
# message.author.discriminator,
# message.server.id if message.server else None,
# message.channel.id
# )
#
# await self.client.send_message(channel, "```python\n{}```".format(traceback.format_exc()), embed=embed)
#
# async def run(self):
# try:
# self.logger.info("Connecting to database...")
# self.pool = await db.connect(
# os.environ["DATABASE_USER"],
# os.environ["DATABASE_PASS"],
# os.environ["DATABASE_NAME"],
# os.environ["DATABASE_HOST"],
# int(os.environ["DATABASE_PORT"])
# )
#
# self.logger.info("Attempting to create tables...")
# async with self.pool.acquire() as conn:
# await db.create_tables(conn)
#
# self.logger.info("Connecting to Discord...")
# await self.client.start(self.token)
# finally:
# self.logger.info("Logging out from Discord...")
# await self.client.logout()

View File

@ -118,7 +118,7 @@ class CommandContext:
raise CommandError("Timed out - try again.")
return reaction.reaction.emoji == ""
async def confirm_text(self, user: discord.Member, channel: discord.Channel, confirm_text: str, message: str):
async def confirm_text(self, user: discord.Member, channel: discord.TextChannel, confirm_text: str, message: str):
await self.reply(message)
message = await self.client.wait_for_message(channel=channel, author=user, timeout=60.0*5)

View File

@ -1,17 +1,16 @@
import ciso8601
from io import BytesIO
import discord
import logging
import re
import time
from typing import List, Optional
import aiohttp
import discord
from typing import List
from pluralkit import db
from pluralkit.bot import channel_logger, utils, embeds
from pluralkit.bot import utils
logger = logging.getLogger("pluralkit.bot.proxy")
def extract_leading_mentions(message_text):
# This regex matches one or more mentions at the start of a message, separated by any amount of spaces
match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message_text)
@ -34,7 +33,9 @@ def match_member_proxy_tags(member: db.ProxyMember, message_text: str):
# Ignore mentions at the very start of the message, and match proxy tags after those
message_text, leading_mentions = extract_leading_mentions(message_text)
logger.debug("Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions, prefix, suffix))
logger.debug(
"Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions,
prefix, suffix))
if message_text.startswith(member.prefix or "") and message_text.endswith(member.suffix or ""):
prefix_length = len(prefix)
@ -59,243 +60,314 @@ def match_proxy_tags(members: List[db.ProxyMember], message_text: str):
for member in members:
match = match_member_proxy_tags(member, message_text)
if match is not None: # Using "is not None" because an empty string is OK here too
if match is not None: # Using "is not None" because an empty string is OK here too
logger.debug("Matched member {} with inner text '{}'".format(member.hid, match))
return member, match
def get_message_attachment_url(message: discord.Message):
async def get_or_create_webhook_for_channel(conn, channel: discord.TextChannel):
# First, check if we have one saved in the DB
webhook_from_db = await db.get_webhook(conn, channel.id)
if webhook_from_db:
webhook_id, webhook_token = webhook_from_db
session = channel._state.http._session
hook = discord.Webhook.partial(webhook_id, webhook_token, adapter=discord.AsyncWebhookAdapter(session))
# Workaround for https://github.com/Rapptz/discord.py/issues/1242
hook._adapter.store_user = hook._adapter._store_user
return hook
# If not, we create one and save it
created_webhook = await channel.create_webhook(name="PluralKit Proxy Webhook")
await db.add_webhook(conn, channel.id, created_webhook.id, created_webhook.token)
return created_webhook
async def make_attachment_file(message: discord.Message):
if not message.attachments:
return None
attachment = message.attachments[0]
if "proxy_url" in attachment:
return attachment["proxy_url"]
first_attachment = message.attachments[0]
if "url" in attachment:
return attachment["url"]
# Copy the file data to the buffer
# TODO: do this without buffering... somehow
bio = BytesIO()
await first_attachment.save(bio)
return discord.File(bio, first_attachment.filename)
# TODO: possibly move this to bot __init__ so commands can access it too
class WebhookPermissionError(Exception):
pass
async def do_proxy_message(conn, original_message: discord.Message, proxy_member: db.ProxyMember,
inner_text: str):
# Send the message through the webhook
webhook = await get_or_create_webhook_for_channel(conn, original_message.channel)
sent_message = await webhook.send(
content=inner_text,
username=proxy_member.name,
avatar_url=proxy_member.avatar_url,
file=await make_attachment_file(original_message),
wait=True
)
# Save the proxied message in the database
await db.add_message(conn, sent_message.id, original_message.channel.id, proxy_member.id,
original_message.author.id)
# TODO: log message in log channel
class DeletionPermissionError(Exception):
pass
async def try_proxy_message(message: discord.Message, conn) -> bool:
# Don't bother proxying in DMs with the bot
if isinstance(message.channel, discord.abc.PrivateChannel):
return False
# Return every member associated with the account
members = await db.get_members_by_account(conn, message.author.id)
proxy_match = match_proxy_tags(members, message.content)
if not proxy_match:
# No proxy tags match here, we done
return False
class Proxy:
def __init__(self, client: discord.Client, token: str, logger: channel_logger.ChannelLogger):
self.logger = logging.getLogger("pluralkit.bot.proxy")
self.session = aiohttp.ClientSession()
self.client = client
self.token = token
self.channel_logger = logger
member, inner_text = proxy_match
async def save_channel_webhook(self, conn, channel: discord.Channel, id: str, token: str) -> (str, str):
await db.add_webhook(conn, channel.id, id, token)
return id, token
# Sanitize inner text for @everyones and such
inner_text = utils.sanitize(inner_text)
async def create_and_add_channel_webhook(self, conn, channel: discord.Channel) -> (str, str):
# This method is only called if there's no webhook found in the DB (and hopefully within a transaction)
# No need to worry about error handling if there's a DB conflict (which will throw an exception because DB constraints)
req_headers = {"Authorization": "Bot {}".format(self.token)}
# If we don't have an inner text OR an attachment, we cancel because the hook can't send that
if not inner_text and not message.attachments:
return False
# First, check if there's already a webhook belonging to the bot
async with self.session.get("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
headers=req_headers) as resp:
if resp.status == 200:
webhooks = await resp.json()
for webhook in webhooks:
if webhook["user"]["id"] == self.client.user.id:
# This webhook belongs to us, we can use that, return it and save it
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
elif resp.status == 403:
self.logger.warning(
"Did not have permission to fetch webhook list (server={}, channel={})".format(channel.server.id,
channel.id))
raise WebhookPermissionError()
else:
raise discord.HTTPException(resp, await resp.text())
# So, we now have enough information to successfully proxy a message
async with conn.transaction():
await do_proxy_message(conn, message, member, inner_text)
return True
# Then, try submitting a new one
req_data = {"name": "PluralKit Proxy Webhook"}
async with self.session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
json=req_data, headers=req_headers) as resp:
if resp.status == 200:
webhook = await resp.json()
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
elif resp.status == 403:
self.logger.warning(
"Did not have permission to create webhook (server={}, channel={})".format(channel.server.id,
channel.id))
raise WebhookPermissionError()
else:
raise discord.HTTPException(resp, await resp.text())
# Should not be reached without an exception being thrown
async def get_webhook_for_channel(self, conn, channel: discord.Channel):
async with conn.transaction():
hook_match = await db.get_webhook(conn, channel.id)
if not hook_match:
# We don't have a webhook, create/add one
return await self.create_and_add_channel_webhook(conn, channel)
else:
return hook_match
async def do_proxy_message(self, conn, member: db.ProxyMember, original_message: discord.Message, text: str,
attachment_url: str, has_already_retried=False):
hook_id, hook_token = await self.get_webhook_for_channel(conn, original_message.channel)
form_data = aiohttp.FormData()
form_data.add_field("username", "{} {}".format(member.name, member.tag or "").strip())
if text:
form_data.add_field("content", text)
if attachment_url:
attachment_resp = await self.session.get(attachment_url)
form_data.add_field("file", attachment_resp.content, content_type=attachment_resp.content_type,
filename=attachment_resp.url.name)
if member.avatar_url:
form_data.add_field("avatar_url", member.avatar_url)
async with self.session.post(
"https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token),
data=form_data) as resp:
if resp.status == 200:
message = await resp.json()
await db.add_message(conn, message["id"], message["channel_id"], member.id, original_message.author.id)
try:
await self.client.delete_message(original_message)
except discord.Forbidden:
self.logger.warning(
"Did not have permission to delete original message (server={}, channel={})".format(
original_message.server.id, original_message.channel.id))
raise DeletionPermissionError()
except discord.NotFound:
self.logger.warning("Tried to delete message when proxying, but message was already gone (server={}, channel={})".format(original_message.server.id, original_message.channel.id))
message_image = None
if message["attachments"]:
first_attachment = message["attachments"][0]
if "width" in first_attachment and "height" in first_attachment:
# Only log attachments that are actually images
message_image = first_attachment["url"]
await self.channel_logger.log_message_proxied(conn,
server_id=original_message.server.id,
channel_name=original_message.channel.name,
channel_id=original_message.channel.id,
sender_name=original_message.author.name,
sender_disc=original_message.author.discriminator,
sender_id=original_message.author.id,
member_name=member.name,
member_hid=member.hid,
member_avatar_url=member.avatar_url,
system_name=member.system_name,
system_hid=member.system_hid,
message_text=text,
message_image=message_image,
message_timestamp=ciso8601.parse_datetime(
message["timestamp"]),
message_id=message["id"])
elif resp.status == 404 and not has_already_retried:
# Webhook doesn't exist. Delete it from the DB, create, and add a new one
self.logger.warning("Webhook registered in DB doesn't exist, deleting hook from DB, re-adding, and trying again (channel={}, hook={})".format(original_message.channel.id, hook_id))
await db.delete_webhook(conn, original_message.channel.id)
await self.create_and_add_channel_webhook(conn, original_message.channel)
# Then try again all over, making sure to not retry again and go in a loop should it continually fail
return await self.do_proxy_message(conn, member, original_message, text, attachment_url, has_already_retried=True)
else:
raise discord.HTTPException(resp, await resp.text())
async def try_proxy_message(self, conn, message: discord.Message):
# Can't proxy in DMs, webhook creation will explode
if message.channel.is_private:
return False
# Big fat query to find every member associated with this account
# Returned member object has a few more keys (system tag, for example)
members = await db.get_members_by_account(conn, account_id=message.author.id)
match = match_proxy_tags(members, message.content)
if not match:
return False
member, text = match
attachment_url = get_message_attachment_url(message)
# Can't proxy a message with no text AND no attachment
if not text and not attachment_url:
self.logger.debug("Skipping message because of no text and no attachment")
return False
# Remember to sanitize the text (remove @everyones and such)
text = utils.sanitize(text)
try:
async with conn.transaction():
await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url)
except WebhookPermissionError:
embed = embeds.error("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed)
except DeletionPermissionError:
embed = embeds.error("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed)
return True
async def try_delete_message(self, conn, message_id: str, check_user_id: Optional[str], delete_message: bool, deleted_by_moderator: bool):
async with conn.transaction():
# Find the message in the DB, and make sure it's sent by the user (if we need to check)
if check_user_id:
db_message = await db.get_message_by_sender_and_id(conn, message_id=message_id, sender_id=check_user_id)
else:
db_message = await db.get_message(conn, message_id=message_id)
if db_message:
self.logger.debug("Deleting message {}".format(message_id))
channel = self.client.get_channel(str(db_message.channel))
# If we should also delete the actual message, do that
if delete_message:
message = await self.client.get_message(channel, message_id)
try:
await self.client.delete_message(message)
except discord.Forbidden:
self.logger.warning(
"Did not have permission to remove message, aborting deletion (server={}, channel={})".format(
channel.server.id, channel.id))
return
# Remove it from the DB
await db.delete_message(conn, message_id)
# Then log deletion to logging channel
await self.channel_logger.log_message_deleted(conn,
server_id=channel.server.id,
channel_name=channel.name,
member_name=db_message.name,
member_hid=db_message.hid,
member_avatar_url=db_message.avatar_url,
system_name=db_message.system_name,
system_hid=db_message.system_hid,
message_text=db_message.content,
message_id=message_id,
deleted_by_moderator=deleted_by_moderator)
async def handle_reaction(self, conn, user_id: str, message_id: str, emoji: str):
if emoji == "":
await self.try_delete_message(conn, message_id, check_user_id=user_id, delete_message=True, deleted_by_moderator=False)
async def handle_deletion(self, conn, message_id: str):
# Don't delete the message, it's already gone at this point, just handle DB deletion and logging
await self.try_delete_message(conn, message_id, check_user_id=None, delete_message=False, deleted_by_moderator=True)
# # TODO: possibly move this to bot __init__ so commands can access it too
# class WebhookPermissionError(Exception):
# pass
#
#
# class DeletionPermissionError(Exception):
# pass
#
#
# class Proxy:
# def __init__(self, client: discord.Client, token: str, logger: channel_logger.ChannelLogger):
# self.logger = logging.getLogger("pluralkit.bot.proxy")
# self.session = aiohttp.ClientSession()
# self.client = client
# self.token = token
# self.channel_logger = logger
#
# async def save_channel_webhook(self, conn, channel: discord.Channel, id: str, token: str) -> (str, str):
# await db.add_webhook(conn, channel.id, id, token)
# return id, token
#
# async def create_and_add_channel_webhook(self, conn, channel: discord.Channel) -> (str, str):
# # This method is only called if there's no webhook found in the DB (and hopefully within a transaction)
# # No need to worry about error handling if there's a DB conflict (which will throw an exception because DB constraints)
# req_headers = {"Authorization": "Bot {}".format(self.token)}
#
# # First, check if there's already a webhook belonging to the bot
# async with self.session.get("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
# headers=req_headers) as resp:
# if resp.status == 200:
# webhooks = await resp.json()
# for webhook in webhooks:
# if webhook["user"]["id"] == self.client.user.id:
# # This webhook belongs to us, we can use that, return it and save it
# return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
# elif resp.status == 403:
# self.logger.warning(
# "Did not have permission to fetch webhook list (server={}, channel={})".format(channel.server.id,
# channel.id))
# raise WebhookPermissionError()
# else:
# raise discord.HTTPException(resp, await resp.text())
#
# # Then, try submitting a new one
# req_data = {"name": "PluralKit Proxy Webhook"}
# async with self.session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
# json=req_data, headers=req_headers) as resp:
# if resp.status == 200:
# webhook = await resp.json()
# return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
# elif resp.status == 403:
# self.logger.warning(
# "Did not have permission to create webhook (server={}, channel={})".format(channel.server.id,
# channel.id))
# raise WebhookPermissionError()
# else:
# raise discord.HTTPException(resp, await resp.text())
#
# # Should not be reached without an exception being thrown
#
# async def get_webhook_for_channel(self, conn, channel: discord.Channel):
# async with conn.transaction():
# hook_match = await db.get_webhook(conn, channel.id)
# if not hook_match:
# # We don't have a webhook, create/add one
# return await self.create_and_add_channel_webhook(conn, channel)
# else:
# return hook_match
#
# async def do_proxy_message(self, conn, member: db.ProxyMember, original_message: discord.Message, text: str,
# attachment_url: str, has_already_retried=False):
# hook_id, hook_token = await self.get_webhook_for_channel(conn, original_message.channel)
#
# form_data = aiohttp.FormData()
# form_data.add_field("username", "{} {}".format(member.name, member.tag or "").strip())
#
# if text:
# form_data.add_field("content", text)
#
# if attachment_url:
# attachment_resp = await self.session.get(attachment_url)
# form_data.add_field("file", attachment_resp.content, content_type=attachment_resp.content_type,
# filename=attachment_resp.url.name)
#
# if member.avatar_url:
# form_data.add_field("avatar_url", member.avatar_url)
#
# async with self.session.post(
# "https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token),
# data=form_data) as resp:
# if resp.status == 200:
# message = await resp.json()
#
# await db.add_message(conn, message["id"], message["channel_id"], member.id, original_message.author.id)
#
# try:
# await self.client.delete_message(original_message)
# except discord.Forbidden:
# self.logger.warning(
# "Did not have permission to delete original message (server={}, channel={})".format(
# original_message.server.id, original_message.channel.id))
# raise DeletionPermissionError()
# except discord.NotFound:
# self.logger.warning("Tried to delete message when proxying, but message was already gone (server={}, channel={})".format(original_message.server.id, original_message.channel.id))
#
# message_image = None
# if message["attachments"]:
# first_attachment = message["attachments"][0]
# if "width" in first_attachment and "height" in first_attachment:
# # Only log attachments that are actually images
# message_image = first_attachment["url"]
#
# await self.channel_logger.log_message_proxied(conn,
# server_id=original_message.server.id,
# channel_name=original_message.channel.name,
# channel_id=original_message.channel.id,
# sender_name=original_message.author.name,
# sender_disc=original_message.author.discriminator,
# sender_id=original_message.author.id,
# member_name=member.name,
# member_hid=member.hid,
# member_avatar_url=member.avatar_url,
# system_name=member.system_name,
# system_hid=member.system_hid,
# message_text=text,
# message_image=message_image,
# message_timestamp=ciso8601.parse_datetime(
# message["timestamp"]),
# message_id=message["id"])
# elif resp.status == 404 and not has_already_retried:
# # Webhook doesn't exist. Delete it from the DB, create, and add a new one
# self.logger.warning("Webhook registered in DB doesn't exist, deleting hook from DB, re-adding, and trying again (channel={}, hook={})".format(original_message.channel.id, hook_id))
# await db.delete_webhook(conn, original_message.channel.id)
# await self.create_and_add_channel_webhook(conn, original_message.channel)
#
# # Then try again all over, making sure to not retry again and go in a loop should it continually fail
# return await self.do_proxy_message(conn, member, original_message, text, attachment_url, has_already_retried=True)
# else:
# raise discord.HTTPException(resp, await resp.text())
#
# async def try_proxy_message(self, conn, message: discord.Message):
# # Can't proxy in DMs, webhook creation will explode
# if message.channel.is_private:
# return False
#
# # Big fat query to find every member associated with this account
# # Returned member object has a few more keys (system tag, for example)
# members = await db.get_members_by_account(conn, account_id=message.author.id)
#
# match = match_proxy_tags(members, message.content)
# if not match:
# return False
#
# member, text = match
# attachment_url = get_message_attachment_url(message)
#
# # Can't proxy a message with no text AND no attachment
# if not text and not attachment_url:
# self.logger.debug("Skipping message because of no text and no attachment")
# return False
#
# # Remember to sanitize the text (remove @everyones and such)
# text = utils.sanitize(text)
#
# try:
# async with conn.transaction():
# await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url)
# except WebhookPermissionError:
# embed = embeds.error("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.")
# await self.client.send_message(message.channel, embed=embed)
# except DeletionPermissionError:
# embed = embeds.error("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.")
# await self.client.send_message(message.channel, embed=embed)
#
# return True
#
# async def try_delete_message(self, conn, message_id: str, check_user_id: Optional[str], delete_message: bool, deleted_by_moderator: bool):
# async with conn.transaction():
# # Find the message in the DB, and make sure it's sent by the user (if we need to check)
# if check_user_id:
# db_message = await db.get_message_by_sender_and_id(conn, message_id=message_id, sender_id=check_user_id)
# else:
# db_message = await db.get_message(conn, message_id=message_id)
#
# if db_message:
# self.logger.debug("Deleting message {}".format(message_id))
# channel = self.client.get_channel(str(db_message.channel))
#
# # Fetch the original message from the server
# try:
# original_message = await self.client.get_message(channel, message_id)
# except discord.NotFound:
# # Just in case it's already gone
# original_message = None
#
# # If we should also delete the actual message, do that
# if delete_message and original_message is not None:
# try:
# await self.client.delete_message(original_message)
# except discord.Forbidden:
# self.logger.warning(
# "Did not have permission to remove message, aborting deletion (server={}, channel={})".format(
# channel.server.id, channel.id))
# return
#
# # Remove it from the DB
# await db.delete_message(conn, message_id)
#
# # Then log deletion to logging channel
# await self.channel_logger.log_message_deleted(conn,
# server_id=channel.server.id,
# channel_name=channel.name,
# member_name=db_message.name,
# member_hid=db_message.hid,
# member_avatar_url=db_message.avatar_url,
# system_name=db_message.system_name,
# system_hid=db_message.system_hid,
# message_text=original_message.content if original_message else "*(original not found)*",
# message_id=message_id,
# deleted_by_moderator=deleted_by_moderator)
#
# async def handle_reaction(self, conn, user_id: str, message_id: str, emoji: str):
# if emoji == "❌":
# await self.try_delete_message(conn, message_id, check_user_id=user_id, delete_message=True, deleted_by_moderator=False)
#
# async def handle_deletion(self, conn, message_id: str):
# # Don't delete the message, it's already gone at this point, just handle DB deletion and logging
# await self.try_delete_message(conn, message_id, check_user_id=None, delete_message=False, deleted_by_moderator=True)

View File

@ -2,6 +2,7 @@ import logging
import re
import discord
from typing import Optional
from pluralkit import db
from pluralkit.system import System
@ -20,28 +21,28 @@ def bounds_check_member_name(new_name, system_tag):
if len("{} {}".format(new_name, system_tag)) > 32:
return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(system_tag)
async def parse_mention(client: discord.Client, mention: str) -> discord.User:
async def parse_mention(client: discord.Client, mention: str) -> Optional[discord.User]:
# First try matching mention format
match = re.fullmatch("<@!?(\\d+)>", mention)
if match:
try:
return await client.get_user_info(match.group(1))
return await client.get_user_info(int(match.group(1)))
except discord.NotFound:
return None
# Then try with just ID
try:
return await client.get_user_info(str(int(mention)))
return await client.get_user_info(int(mention))
except (ValueError, discord.NotFound):
return None
def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel:
def parse_channel_mention(mention: str, server: discord.Guild) -> Optional[discord.TextChannel]:
match = re.fullmatch("<#(\\d+)>", mention)
if match:
return server.get_channel(match.group(1))
try:
return server.get_channel(str(int(mention)))
return server.get_channel(int(mention))
except ValueError:
return None

View File

@ -62,17 +62,17 @@ async def delete_member(conn, member_id: int):
@db_wrap
async def link_account(conn, system_id: int, account_id: str):
async def link_account(conn, system_id: int, account_id: int):
logger.debug("Linking account (account_id={}, system_id={})".format(
account_id, system_id))
await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id)
await conn.execute("insert into accounts (uid, system) values ($1, $2)", account_id, system_id)
@db_wrap
async def unlink_account(conn, system_id: int, account_id: str):
async def unlink_account(conn, system_id: int, account_id: int):
logger.debug("Unlinking account (account_id={}, system_id={})".format(
account_id, system_id))
await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id)
await conn.execute("delete from accounts where uid = $1 and system = $2", account_id, system_id)
@db_wrap
@ -81,8 +81,8 @@ async def get_linked_accounts(conn, system_id: int) -> List[int]:
@db_wrap
async def get_system_by_account(conn, account_id: str) -> System:
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id))
async def get_system_by_account(conn, account_id: int) -> System:
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", account_id)
return System(**row) if row else None
@db_wrap
@ -151,26 +151,26 @@ async def get_members_exceeding(conn, system_id: int, length: int) -> List[Membe
@db_wrap
async def get_webhook(conn, channel_id: str) -> (str, str):
row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
async def get_webhook(conn, channel_id: int) -> (str, str):
row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", channel_id)
return (str(row["webhook"]), row["token"]) if row else None
@db_wrap
async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str):
async def add_webhook(conn, channel_id: int, webhook_id: int, webhook_token: str):
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(
channel_id, webhook_id, webhook_token))
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token)
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", channel_id, webhook_id, webhook_token)
@db_wrap
async def delete_webhook(conn, channel_id: str):
await conn.execute("delete from webhooks where channel = $1", int(channel_id))
async def delete_webhook(conn, channel_id: int):
await conn.execute("delete from webhooks where channel = $1", channel_id)
@db_wrap
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str):
async def add_message(conn, message_id: int, channel_id: int, member_id: int, sender_id: int):
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(
message_id, channel_id, member_id, sender_id))
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id))
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", message_id, channel_id, member_id, sender_id)
class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "color", "name", "avatar_url", "tag", "system_name", "system_hid"])):
id: int
@ -185,7 +185,7 @@ class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "c
system_hid: str
@db_wrap
async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]:
async def get_members_by_account(conn, account_id: int) -> List[ProxyMember]:
# Returns a "chimera" object
rows = await conn.fetch("""select
members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url,
@ -195,7 +195,7 @@ async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]:
where
accounts.uid = $1
and systems.id = accounts.system
and members.system = systems.id""", int(account_id))
and members.system = systems.id""", account_id)
return [ProxyMember(**row) for row in rows]
class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "sender", "name", "hid", "avatar_url", "system_name", "system_hid"])):
@ -220,7 +220,7 @@ class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "sender"
}
@db_wrap
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) -> MessageInfo:
async def get_message_by_sender_and_id(conn, message_id: int, sender_id: int) -> MessageInfo:
row = await conn.fetchrow("""select
messages.*,
members.name, members.hid, members.avatar_url,
@ -231,12 +231,12 @@ async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) ->
messages.member = members.id
and members.system = systems.id
and mid = $1
and sender = $2""", int(message_id), int(sender_id))
and sender = $2""", message_id, sender_id)
return MessageInfo(**row) if row else None
@db_wrap
async def get_message(conn, message_id: str) -> MessageInfo:
async def get_message(conn, message_id: int) -> MessageInfo:
row = await conn.fetchrow("""select
messages.*,
members.name, members.hid, members.avatar_url,
@ -246,14 +246,14 @@ async def get_message(conn, message_id: str) -> MessageInfo:
where
messages.member = members.id
and members.system = systems.id
and mid = $1""", int(message_id))
and mid = $1""", message_id)
return MessageInfo(**row) if row else None
@db_wrap
async def delete_message(conn, message_id: str):
async def delete_message(conn, message_id: int):
logger.debug("Deleting message (id={})".format(message_id))
await conn.execute("delete from messages where mid = $1", int(message_id))
await conn.execute("delete from messages where mid = $1", message_id)
@db_wrap
async def get_member_message_count(conn, member_id: int) -> int:
@ -290,14 +290,14 @@ async def add_switch_member(conn, switch_id: int, member_id: int):
await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id)
@db_wrap
async def get_server_info(conn, server_id: str):
return await conn.fetchrow("select * from servers where id = $1", int(server_id))
async def get_server_info(conn, server_id: int):
return await conn.fetchrow("select * from servers where id = $1", server_id)
@db_wrap
async def update_server(conn, server_id: str, logging_channel_id: str):
logging_channel_id = int(logging_channel_id) if logging_channel_id else None
async def update_server(conn, server_id: int, logging_channel_id: int):
logging_channel_id = logging_channel_id if logging_channel_id else None
logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id))
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id)
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", server_id, logging_channel_id)
@db_wrap
async def member_count(conn) -> int: