Basic proxy functionality fixed

This commit is contained in:
Ske 2018-10-27 22:00:41 +02:00
parent c8caeadec4
commit 4217d5d5d8
6 changed files with 499 additions and 369 deletions

View File

@ -5,7 +5,4 @@ import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from pluralkit import bot from pluralkit import bot
bot.run()
pk = bot.PluralKitBot(os.environ["TOKEN"])
loop = asyncio.get_event_loop()
loop.run_until_complete(pk.run())

View File

@ -1,127 +1,187 @@
import asyncio import asyncpg
import json import sys
import logging
import os
import time
import traceback import asyncio
from datetime import datetime import os
import logging
import discord import discord
from pluralkit import db from pluralkit import db
from pluralkit.bot import channel_logger, commands, proxy, embeds from pluralkit.bot import commands, proxy
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
def connect_to_database() -> asyncpg.pool.Pool:
username = os.environ["DATABASE_USER"]
password = os.environ["DATABASE_PASS"]
name = os.environ["DATABASE_NAME"]
host = os.environ["DATABASE_HOST"]
port = os.environ["DATABASE_PORT"]
# logging.getLogger("pluralkit").setLevel(logging.DEBUG) if username is None or password is None or name is None or host is None or port is None:
print("Database credentials not specified. Please pass valid PostgreSQL database credentials in the DATABASE_[USER|PASS|NAME|HOST|PORT] environment variable.", file=sys.stderr)
sys.exit(1)
class PluralKitBot: try:
def __init__(self, token): port = int(port)
self.token = token except ValueError:
self.logger = logging.getLogger("pluralkit.bot") print("Please pass a valid integer as the DATABASE_PORT environment variable.", file=sys.stderr)
sys.exit(1)
self.client = discord.Client() return asyncio.get_event_loop().run_until_complete(db.connect(
self.client.event(self.on_error) username=username,
self.client.event(self.on_ready) password=password,
self.client.event(self.on_message) database=name,
self.client.event(self.on_socket_raw_receive) host=host,
port=port
))
self.channel_logger = channel_logger.ChannelLogger(self.client) def run():
pool = connect_to_database()
self.proxy = proxy.Proxy(self.client, token, self.channel_logger) client = discord.Client()
async def on_error(self, evt, *args, **kwargs): @client.event
self.logger.exception("Error while handling event {} with arguments {}:".format(evt, args)) async def on_ready():
print("PluralKit started.")
print("User: {}#{} (ID: {})".format(client.user.name, client.user.discriminator, client.user.id))
print("{} servers".format(len(client.guilds)))
print("{} shards".format(client.shard_count or 1))
async def on_ready(self): @client.event
self.logger.info("Connected to Discord.") async def on_message(message: discord.Message):
self.logger.info("- Account: {}#{}".format(self.client.user.name, self.client.user.discriminator)) # Ignore messages from bots
self.logger.info("- User ID: {}".format(self.client.user.id))
self.logger.info("- {} servers".format(len(self.client.servers)))
async def on_message(self, message):
# Ignore bot messages
if message.author.bot: if message.author.bot:
return return
try: # Grab a database connection from the pool
if await self.handle_command_dispatch(message): async with pool.acquire() as conn:
# First pass: do command handling
did_run_command = await commands.command_dispatch(client, message, conn)
if did_run_command:
return return
if await self.handle_proxy_dispatch(message): # Second pass: do proxy matching
return await proxy.try_proxy_message(message, conn)
except Exception:
await self.log_error_in_channel(message)
async def on_socket_raw_receive(self, msg):
# Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
# we parse socket data manually for the reaction add event
if isinstance(msg, str):
try:
msg_data = json.loads(msg)
if msg_data.get("t") == "MESSAGE_REACTION_ADD":
evt_data = msg_data.get("d")
if evt_data:
user_id = evt_data["user_id"]
message_id = evt_data["message_id"]
emoji = evt_data["emoji"]["name"]
async with self.pool.acquire() as conn: bot_token = os.environ["TOKEN"]
await self.proxy.handle_reaction(conn, user_id, message_id, emoji) if not bot_token:
elif msg_data.get("t") == "MESSAGE_DELETE": print("No token specified. Please pass a valid Discord bot token in the TOKEN environment variable.",
evt_data = msg_data.get("d") file=sys.stderr)
if evt_data: sys.exit(1)
message_id = evt_data["id"]
async with self.pool.acquire() as conn:
await self.proxy.handle_deletion(conn, message_id)
except ValueError:
pass
async def handle_command_dispatch(self, message): client.run(bot_token)
async with self.pool.acquire() as conn:
result = await commands.command_dispatch(self.client, message, conn)
return result
async def handle_proxy_dispatch(self, message): # logging.getLogger("pluralkit").setLevel(logging.DEBUG)
# Try doing proxy parsing
async with self.pool.acquire() as conn:
return await self.proxy.try_proxy_message(conn, message)
async def log_error_in_channel(self, message): # class PluralKitBot:
channel_id = os.environ["LOG_CHANNEL"] # def __init__(self, token):
if not channel_id: # self.token = token
return # self.logger = logging.getLogger("pluralkit.bot")
#
channel = self.client.get_channel(channel_id) # self.client = discord.Client()
embed = embeds.exception_log( # self.client.event(self.on_error)
message.content, # self.client.event(self.on_ready)
message.author.name, # self.client.event(self.on_message)
message.author.discriminator, # self.client.event(self.on_socket_raw_receive)
message.server.id if message.server else None, #
message.channel.id # self.channel_logger = channel_logger.ChannelLogger(self.client)
) #
# self.proxy = proxy.Proxy(self.client, token, self.channel_logger)
await self.client.send_message(channel, "```python\n{}```".format(traceback.format_exc()), embed=embed) #
# async def on_error(self, evt, *args, **kwargs):
async def run(self): # self.logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
try: #
self.logger.info("Connecting to database...") # async def on_ready(self):
self.pool = await db.connect( # self.logger.info("Connected to Discord.")
os.environ["DATABASE_USER"], # self.logger.info("- Account: {}#{}".format(self.client.user.name, self.client.user.discriminator))
os.environ["DATABASE_PASS"], # self.logger.info("- User ID: {}".format(self.client.user.id))
os.environ["DATABASE_NAME"], # self.logger.info("- {} servers".format(len(self.client.servers)))
os.environ["DATABASE_HOST"], #
int(os.environ["DATABASE_PORT"]) # async def on_message(self, message):
) # # Ignore bot messages
# if message.author.bot:
self.logger.info("Attempting to create tables...") # return
async with self.pool.acquire() as conn: #
await db.create_tables(conn) # try:
# if await self.handle_command_dispatch(message):
self.logger.info("Connecting to Discord...") # return
await self.client.start(self.token) #
finally: # if await self.handle_proxy_dispatch(message):
self.logger.info("Logging out from Discord...") # return
await self.client.logout() # except Exception:
# await self.log_error_in_channel(message)
#
# async def on_socket_raw_receive(self, msg):
# # Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
# # we parse socket data manually for the reaction add event
# if isinstance(msg, str):
# try:
# msg_data = json.loads(msg)
# if msg_data.get("t") == "MESSAGE_REACTION_ADD":
# evt_data = msg_data.get("d")
# if evt_data:
# user_id = evt_data["user_id"]
# message_id = evt_data["message_id"]
# emoji = evt_data["emoji"]["name"]
#
# async with self.pool.acquire() as conn:
# await self.proxy.handle_reaction(conn, user_id, message_id, emoji)
# elif msg_data.get("t") == "MESSAGE_DELETE":
# evt_data = msg_data.get("d")
# if evt_data:
# message_id = evt_data["id"]
# async with self.pool.acquire() as conn:
# await self.proxy.handle_deletion(conn, message_id)
# except ValueError:
# pass
#
# async def handle_command_dispatch(self, message):
# async with self.pool.acquire() as conn:
# result = await commands.command_dispatch(self.client, message, conn)
# return result
#
# async def handle_proxy_dispatch(self, message):
# # Try doing proxy parsing
# async with self.pool.acquire() as conn:
# return await self.proxy.try_proxy_message(conn, message)
#
# async def log_error_in_channel(self, message):
# channel_id = os.environ["LOG_CHANNEL"]
# if not channel_id:
# return
#
# channel = self.client.get_channel(channel_id)
# embed = embeds.exception_log(
# message.content,
# message.author.name,
# message.author.discriminator,
# message.server.id if message.server else None,
# message.channel.id
# )
#
# await self.client.send_message(channel, "```python\n{}```".format(traceback.format_exc()), embed=embed)
#
# async def run(self):
# try:
# self.logger.info("Connecting to database...")
# self.pool = await db.connect(
# os.environ["DATABASE_USER"],
# os.environ["DATABASE_PASS"],
# os.environ["DATABASE_NAME"],
# os.environ["DATABASE_HOST"],
# int(os.environ["DATABASE_PORT"])
# )
#
# self.logger.info("Attempting to create tables...")
# async with self.pool.acquire() as conn:
# await db.create_tables(conn)
#
# self.logger.info("Connecting to Discord...")
# await self.client.start(self.token)
# finally:
# self.logger.info("Logging out from Discord...")
# await self.client.logout()

View File

@ -118,7 +118,7 @@ class CommandContext:
raise CommandError("Timed out - try again.") raise CommandError("Timed out - try again.")
return reaction.reaction.emoji == "" return reaction.reaction.emoji == ""
async def confirm_text(self, user: discord.Member, channel: discord.Channel, confirm_text: str, message: str): async def confirm_text(self, user: discord.Member, channel: discord.TextChannel, confirm_text: str, message: str):
await self.reply(message) await self.reply(message)
message = await self.client.wait_for_message(channel=channel, author=user, timeout=60.0*5) message = await self.client.wait_for_message(channel=channel, author=user, timeout=60.0*5)

View File

@ -1,17 +1,16 @@
import ciso8601 from io import BytesIO
import discord
import logging import logging
import re import re
import time from typing import List
from typing import List, Optional
import aiohttp
import discord
from pluralkit import db from pluralkit import db
from pluralkit.bot import channel_logger, utils, embeds from pluralkit.bot import utils
logger = logging.getLogger("pluralkit.bot.proxy") logger = logging.getLogger("pluralkit.bot.proxy")
def extract_leading_mentions(message_text): def extract_leading_mentions(message_text):
# This regex matches one or more mentions at the start of a message, separated by any amount of spaces # This regex matches one or more mentions at the start of a message, separated by any amount of spaces
match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message_text) match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message_text)
@ -34,7 +33,9 @@ def match_member_proxy_tags(member: db.ProxyMember, message_text: str):
# Ignore mentions at the very start of the message, and match proxy tags after those # Ignore mentions at the very start of the message, and match proxy tags after those
message_text, leading_mentions = extract_leading_mentions(message_text) message_text, leading_mentions = extract_leading_mentions(message_text)
logger.debug("Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions, prefix, suffix)) logger.debug(
"Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions,
prefix, suffix))
if message_text.startswith(member.prefix or "") and message_text.endswith(member.suffix or ""): if message_text.startswith(member.prefix or "") and message_text.endswith(member.suffix or ""):
prefix_length = len(prefix) prefix_length = len(prefix)
@ -59,243 +60,314 @@ def match_proxy_tags(members: List[db.ProxyMember], message_text: str):
for member in members: for member in members:
match = match_member_proxy_tags(member, message_text) match = match_member_proxy_tags(member, message_text)
if match is not None: # Using "is not None" because an empty string is OK here too if match is not None: # Using "is not None" because an empty string is OK here too
logger.debug("Matched member {} with inner text '{}'".format(member.hid, match)) logger.debug("Matched member {} with inner text '{}'".format(member.hid, match))
return member, match return member, match
def get_message_attachment_url(message: discord.Message): async def get_or_create_webhook_for_channel(conn, channel: discord.TextChannel):
# First, check if we have one saved in the DB
webhook_from_db = await db.get_webhook(conn, channel.id)
if webhook_from_db:
webhook_id, webhook_token = webhook_from_db
session = channel._state.http._session
hook = discord.Webhook.partial(webhook_id, webhook_token, adapter=discord.AsyncWebhookAdapter(session))
# Workaround for https://github.com/Rapptz/discord.py/issues/1242
hook._adapter.store_user = hook._adapter._store_user
return hook
# If not, we create one and save it
created_webhook = await channel.create_webhook(name="PluralKit Proxy Webhook")
await db.add_webhook(conn, channel.id, created_webhook.id, created_webhook.token)
return created_webhook
async def make_attachment_file(message: discord.Message):
if not message.attachments: if not message.attachments:
return None return None
attachment = message.attachments[0] first_attachment = message.attachments[0]
if "proxy_url" in attachment:
return attachment["proxy_url"]
if "url" in attachment: # Copy the file data to the buffer
return attachment["url"] # TODO: do this without buffering... somehow
bio = BytesIO()
await first_attachment.save(bio)
return discord.File(bio, first_attachment.filename)
# TODO: possibly move this to bot __init__ so commands can access it too async def do_proxy_message(conn, original_message: discord.Message, proxy_member: db.ProxyMember,
class WebhookPermissionError(Exception): inner_text: str):
pass # Send the message through the webhook
webhook = await get_or_create_webhook_for_channel(conn, original_message.channel)
sent_message = await webhook.send(
content=inner_text,
username=proxy_member.name,
avatar_url=proxy_member.avatar_url,
file=await make_attachment_file(original_message),
wait=True
)
# Save the proxied message in the database
await db.add_message(conn, sent_message.id, original_message.channel.id, proxy_member.id,
original_message.author.id)
# TODO: log message in log channel
class DeletionPermissionError(Exception): async def try_proxy_message(message: discord.Message, conn) -> bool:
pass # Don't bother proxying in DMs with the bot
if isinstance(message.channel, discord.abc.PrivateChannel):
return False
# Return every member associated with the account
members = await db.get_members_by_account(conn, message.author.id)
proxy_match = match_proxy_tags(members, message.content)
if not proxy_match:
# No proxy tags match here, we done
return False
class Proxy: member, inner_text = proxy_match
def __init__(self, client: discord.Client, token: str, logger: channel_logger.ChannelLogger):
self.logger = logging.getLogger("pluralkit.bot.proxy")
self.session = aiohttp.ClientSession()
self.client = client
self.token = token
self.channel_logger = logger
async def save_channel_webhook(self, conn, channel: discord.Channel, id: str, token: str) -> (str, str): # Sanitize inner text for @everyones and such
await db.add_webhook(conn, channel.id, id, token) inner_text = utils.sanitize(inner_text)
return id, token
async def create_and_add_channel_webhook(self, conn, channel: discord.Channel) -> (str, str): # If we don't have an inner text OR an attachment, we cancel because the hook can't send that
# This method is only called if there's no webhook found in the DB (and hopefully within a transaction) if not inner_text and not message.attachments:
# No need to worry about error handling if there's a DB conflict (which will throw an exception because DB constraints) return False
req_headers = {"Authorization": "Bot {}".format(self.token)}
# First, check if there's already a webhook belonging to the bot # So, we now have enough information to successfully proxy a message
async with self.session.get("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id), async with conn.transaction():
headers=req_headers) as resp: await do_proxy_message(conn, message, member, inner_text)
if resp.status == 200: return True
webhooks = await resp.json()
for webhook in webhooks:
if webhook["user"]["id"] == self.client.user.id:
# This webhook belongs to us, we can use that, return it and save it
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
elif resp.status == 403:
self.logger.warning(
"Did not have permission to fetch webhook list (server={}, channel={})".format(channel.server.id,
channel.id))
raise WebhookPermissionError()
else:
raise discord.HTTPException(resp, await resp.text())
# Then, try submitting a new one # # TODO: possibly move this to bot __init__ so commands can access it too
req_data = {"name": "PluralKit Proxy Webhook"} # class WebhookPermissionError(Exception):
async with self.session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id), # pass
json=req_data, headers=req_headers) as resp: #
if resp.status == 200: #
webhook = await resp.json() # class DeletionPermissionError(Exception):
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"]) # pass
elif resp.status == 403: #
self.logger.warning( #
"Did not have permission to create webhook (server={}, channel={})".format(channel.server.id, # class Proxy:
channel.id)) # def __init__(self, client: discord.Client, token: str, logger: channel_logger.ChannelLogger):
raise WebhookPermissionError() # self.logger = logging.getLogger("pluralkit.bot.proxy")
else: # self.session = aiohttp.ClientSession()
raise discord.HTTPException(resp, await resp.text()) # self.client = client
# self.token = token
# Should not be reached without an exception being thrown # self.channel_logger = logger
#
async def get_webhook_for_channel(self, conn, channel: discord.Channel): # async def save_channel_webhook(self, conn, channel: discord.Channel, id: str, token: str) -> (str, str):
async with conn.transaction(): # await db.add_webhook(conn, channel.id, id, token)
hook_match = await db.get_webhook(conn, channel.id) # return id, token
if not hook_match: #
# We don't have a webhook, create/add one # async def create_and_add_channel_webhook(self, conn, channel: discord.Channel) -> (str, str):
return await self.create_and_add_channel_webhook(conn, channel) # # This method is only called if there's no webhook found in the DB (and hopefully within a transaction)
else: # # No need to worry about error handling if there's a DB conflict (which will throw an exception because DB constraints)
return hook_match # req_headers = {"Authorization": "Bot {}".format(self.token)}
#
async def do_proxy_message(self, conn, member: db.ProxyMember, original_message: discord.Message, text: str, # # First, check if there's already a webhook belonging to the bot
attachment_url: str, has_already_retried=False): # async with self.session.get("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
hook_id, hook_token = await self.get_webhook_for_channel(conn, original_message.channel) # headers=req_headers) as resp:
# if resp.status == 200:
form_data = aiohttp.FormData() # webhooks = await resp.json()
form_data.add_field("username", "{} {}".format(member.name, member.tag or "").strip()) # for webhook in webhooks:
# if webhook["user"]["id"] == self.client.user.id:
if text: # # This webhook belongs to us, we can use that, return it and save it
form_data.add_field("content", text) # return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
# elif resp.status == 403:
if attachment_url: # self.logger.warning(
attachment_resp = await self.session.get(attachment_url) # "Did not have permission to fetch webhook list (server={}, channel={})".format(channel.server.id,
form_data.add_field("file", attachment_resp.content, content_type=attachment_resp.content_type, # channel.id))
filename=attachment_resp.url.name) # raise WebhookPermissionError()
# else:
if member.avatar_url: # raise discord.HTTPException(resp, await resp.text())
form_data.add_field("avatar_url", member.avatar_url) #
# # Then, try submitting a new one
async with self.session.post( # req_data = {"name": "PluralKit Proxy Webhook"}
"https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token), # async with self.session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
data=form_data) as resp: # json=req_data, headers=req_headers) as resp:
if resp.status == 200: # if resp.status == 200:
message = await resp.json() # webhook = await resp.json()
# return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
await db.add_message(conn, message["id"], message["channel_id"], member.id, original_message.author.id) # elif resp.status == 403:
# self.logger.warning(
try: # "Did not have permission to create webhook (server={}, channel={})".format(channel.server.id,
await self.client.delete_message(original_message) # channel.id))
except discord.Forbidden: # raise WebhookPermissionError()
self.logger.warning( # else:
"Did not have permission to delete original message (server={}, channel={})".format( # raise discord.HTTPException(resp, await resp.text())
original_message.server.id, original_message.channel.id)) #
raise DeletionPermissionError() # # Should not be reached without an exception being thrown
except discord.NotFound: #
self.logger.warning("Tried to delete message when proxying, but message was already gone (server={}, channel={})".format(original_message.server.id, original_message.channel.id)) # async def get_webhook_for_channel(self, conn, channel: discord.Channel):
# async with conn.transaction():
message_image = None # hook_match = await db.get_webhook(conn, channel.id)
if message["attachments"]: # if not hook_match:
first_attachment = message["attachments"][0] # # We don't have a webhook, create/add one
if "width" in first_attachment and "height" in first_attachment: # return await self.create_and_add_channel_webhook(conn, channel)
# Only log attachments that are actually images # else:
message_image = first_attachment["url"] # return hook_match
#
await self.channel_logger.log_message_proxied(conn, # async def do_proxy_message(self, conn, member: db.ProxyMember, original_message: discord.Message, text: str,
server_id=original_message.server.id, # attachment_url: str, has_already_retried=False):
channel_name=original_message.channel.name, # hook_id, hook_token = await self.get_webhook_for_channel(conn, original_message.channel)
channel_id=original_message.channel.id, #
sender_name=original_message.author.name, # form_data = aiohttp.FormData()
sender_disc=original_message.author.discriminator, # form_data.add_field("username", "{} {}".format(member.name, member.tag or "").strip())
sender_id=original_message.author.id, #
member_name=member.name, # if text:
member_hid=member.hid, # form_data.add_field("content", text)
member_avatar_url=member.avatar_url, #
system_name=member.system_name, # if attachment_url:
system_hid=member.system_hid, # attachment_resp = await self.session.get(attachment_url)
message_text=text, # form_data.add_field("file", attachment_resp.content, content_type=attachment_resp.content_type,
message_image=message_image, # filename=attachment_resp.url.name)
message_timestamp=ciso8601.parse_datetime( #
message["timestamp"]), # if member.avatar_url:
message_id=message["id"]) # form_data.add_field("avatar_url", member.avatar_url)
elif resp.status == 404 and not has_already_retried: #
# Webhook doesn't exist. Delete it from the DB, create, and add a new one # async with self.session.post(
self.logger.warning("Webhook registered in DB doesn't exist, deleting hook from DB, re-adding, and trying again (channel={}, hook={})".format(original_message.channel.id, hook_id)) # "https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token),
await db.delete_webhook(conn, original_message.channel.id) # data=form_data) as resp:
await self.create_and_add_channel_webhook(conn, original_message.channel) # if resp.status == 200:
# message = await resp.json()
# Then try again all over, making sure to not retry again and go in a loop should it continually fail #
return await self.do_proxy_message(conn, member, original_message, text, attachment_url, has_already_retried=True) # await db.add_message(conn, message["id"], message["channel_id"], member.id, original_message.author.id)
else: #
raise discord.HTTPException(resp, await resp.text()) # try:
# await self.client.delete_message(original_message)
async def try_proxy_message(self, conn, message: discord.Message): # except discord.Forbidden:
# Can't proxy in DMs, webhook creation will explode # self.logger.warning(
if message.channel.is_private: # "Did not have permission to delete original message (server={}, channel={})".format(
return False # original_message.server.id, original_message.channel.id))
# raise DeletionPermissionError()
# Big fat query to find every member associated with this account # except discord.NotFound:
# Returned member object has a few more keys (system tag, for example) # self.logger.warning("Tried to delete message when proxying, but message was already gone (server={}, channel={})".format(original_message.server.id, original_message.channel.id))
members = await db.get_members_by_account(conn, account_id=message.author.id) #
# message_image = None
match = match_proxy_tags(members, message.content) # if message["attachments"]:
if not match: # first_attachment = message["attachments"][0]
return False # if "width" in first_attachment and "height" in first_attachment:
# # Only log attachments that are actually images
member, text = match # message_image = first_attachment["url"]
attachment_url = get_message_attachment_url(message) #
# await self.channel_logger.log_message_proxied(conn,
# Can't proxy a message with no text AND no attachment # server_id=original_message.server.id,
if not text and not attachment_url: # channel_name=original_message.channel.name,
self.logger.debug("Skipping message because of no text and no attachment") # channel_id=original_message.channel.id,
return False # sender_name=original_message.author.name,
# sender_disc=original_message.author.discriminator,
# Remember to sanitize the text (remove @everyones and such) # sender_id=original_message.author.id,
text = utils.sanitize(text) # member_name=member.name,
# member_hid=member.hid,
try: # member_avatar_url=member.avatar_url,
async with conn.transaction(): # system_name=member.system_name,
await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url) # system_hid=member.system_hid,
except WebhookPermissionError: # message_text=text,
embed = embeds.error("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.") # message_image=message_image,
await self.client.send_message(message.channel, embed=embed) # message_timestamp=ciso8601.parse_datetime(
except DeletionPermissionError: # message["timestamp"]),
embed = embeds.error("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.") # message_id=message["id"])
await self.client.send_message(message.channel, embed=embed) # elif resp.status == 404 and not has_already_retried:
# # Webhook doesn't exist. Delete it from the DB, create, and add a new one
return True # self.logger.warning("Webhook registered in DB doesn't exist, deleting hook from DB, re-adding, and trying again (channel={}, hook={})".format(original_message.channel.id, hook_id))
# await db.delete_webhook(conn, original_message.channel.id)
async def try_delete_message(self, conn, message_id: str, check_user_id: Optional[str], delete_message: bool, deleted_by_moderator: bool): # await self.create_and_add_channel_webhook(conn, original_message.channel)
async with conn.transaction(): #
# Find the message in the DB, and make sure it's sent by the user (if we need to check) # # Then try again all over, making sure to not retry again and go in a loop should it continually fail
if check_user_id: # return await self.do_proxy_message(conn, member, original_message, text, attachment_url, has_already_retried=True)
db_message = await db.get_message_by_sender_and_id(conn, message_id=message_id, sender_id=check_user_id) # else:
else: # raise discord.HTTPException(resp, await resp.text())
db_message = await db.get_message(conn, message_id=message_id) #
# async def try_proxy_message(self, conn, message: discord.Message):
if db_message: # # Can't proxy in DMs, webhook creation will explode
self.logger.debug("Deleting message {}".format(message_id)) # if message.channel.is_private:
channel = self.client.get_channel(str(db_message.channel)) # return False
#
# If we should also delete the actual message, do that # # Big fat query to find every member associated with this account
if delete_message: # # Returned member object has a few more keys (system tag, for example)
message = await self.client.get_message(channel, message_id) # members = await db.get_members_by_account(conn, account_id=message.author.id)
#
try: # match = match_proxy_tags(members, message.content)
await self.client.delete_message(message) # if not match:
except discord.Forbidden: # return False
self.logger.warning( #
"Did not have permission to remove message, aborting deletion (server={}, channel={})".format( # member, text = match
channel.server.id, channel.id)) # attachment_url = get_message_attachment_url(message)
return #
# # Can't proxy a message with no text AND no attachment
# Remove it from the DB # if not text and not attachment_url:
await db.delete_message(conn, message_id) # self.logger.debug("Skipping message because of no text and no attachment")
# return False
# Then log deletion to logging channel #
await self.channel_logger.log_message_deleted(conn, # # Remember to sanitize the text (remove @everyones and such)
server_id=channel.server.id, # text = utils.sanitize(text)
channel_name=channel.name, #
member_name=db_message.name, # try:
member_hid=db_message.hid, # async with conn.transaction():
member_avatar_url=db_message.avatar_url, # await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url)
system_name=db_message.system_name, # except WebhookPermissionError:
system_hid=db_message.system_hid, # embed = embeds.error("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.")
message_text=db_message.content, # await self.client.send_message(message.channel, embed=embed)
message_id=message_id, # except DeletionPermissionError:
deleted_by_moderator=deleted_by_moderator) # embed = embeds.error("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.")
# await self.client.send_message(message.channel, embed=embed)
async def handle_reaction(self, conn, user_id: str, message_id: str, emoji: str): #
if emoji == "": # return True
await self.try_delete_message(conn, message_id, check_user_id=user_id, delete_message=True, deleted_by_moderator=False) #
# async def try_delete_message(self, conn, message_id: str, check_user_id: Optional[str], delete_message: bool, deleted_by_moderator: bool):
async def handle_deletion(self, conn, message_id: str): # async with conn.transaction():
# Don't delete the message, it's already gone at this point, just handle DB deletion and logging # # Find the message in the DB, and make sure it's sent by the user (if we need to check)
await self.try_delete_message(conn, message_id, check_user_id=None, delete_message=False, deleted_by_moderator=True) # if check_user_id:
# db_message = await db.get_message_by_sender_and_id(conn, message_id=message_id, sender_id=check_user_id)
# else:
# db_message = await db.get_message(conn, message_id=message_id)
#
# if db_message:
# self.logger.debug("Deleting message {}".format(message_id))
# channel = self.client.get_channel(str(db_message.channel))
#
# # Fetch the original message from the server
# try:
# original_message = await self.client.get_message(channel, message_id)
# except discord.NotFound:
# # Just in case it's already gone
# original_message = None
#
# # If we should also delete the actual message, do that
# if delete_message and original_message is not None:
# try:
# await self.client.delete_message(original_message)
# except discord.Forbidden:
# self.logger.warning(
# "Did not have permission to remove message, aborting deletion (server={}, channel={})".format(
# channel.server.id, channel.id))
# return
#
# # Remove it from the DB
# await db.delete_message(conn, message_id)
#
# # Then log deletion to logging channel
# await self.channel_logger.log_message_deleted(conn,
# server_id=channel.server.id,
# channel_name=channel.name,
# member_name=db_message.name,
# member_hid=db_message.hid,
# member_avatar_url=db_message.avatar_url,
# system_name=db_message.system_name,
# system_hid=db_message.system_hid,
# message_text=original_message.content if original_message else "*(original not found)*",
# message_id=message_id,
# deleted_by_moderator=deleted_by_moderator)
#
# async def handle_reaction(self, conn, user_id: str, message_id: str, emoji: str):
# if emoji == "❌":
# await self.try_delete_message(conn, message_id, check_user_id=user_id, delete_message=True, deleted_by_moderator=False)
#
# async def handle_deletion(self, conn, message_id: str):
# # Don't delete the message, it's already gone at this point, just handle DB deletion and logging
# await self.try_delete_message(conn, message_id, check_user_id=None, delete_message=False, deleted_by_moderator=True)

View File

@ -2,6 +2,7 @@ import logging
import re import re
import discord import discord
from typing import Optional
from pluralkit import db from pluralkit import db
from pluralkit.system import System from pluralkit.system import System
@ -20,28 +21,28 @@ def bounds_check_member_name(new_name, system_tag):
if len("{} {}".format(new_name, system_tag)) > 32: if len("{} {}".format(new_name, system_tag)) > 32:
return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(system_tag) return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(system_tag)
async def parse_mention(client: discord.Client, mention: str) -> discord.User: async def parse_mention(client: discord.Client, mention: str) -> Optional[discord.User]:
# First try matching mention format # First try matching mention format
match = re.fullmatch("<@!?(\\d+)>", mention) match = re.fullmatch("<@!?(\\d+)>", mention)
if match: if match:
try: try:
return await client.get_user_info(match.group(1)) return await client.get_user_info(int(match.group(1)))
except discord.NotFound: except discord.NotFound:
return None return None
# Then try with just ID # Then try with just ID
try: try:
return await client.get_user_info(str(int(mention))) return await client.get_user_info(int(mention))
except (ValueError, discord.NotFound): except (ValueError, discord.NotFound):
return None return None
def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel: def parse_channel_mention(mention: str, server: discord.Guild) -> Optional[discord.TextChannel]:
match = re.fullmatch("<#(\\d+)>", mention) match = re.fullmatch("<#(\\d+)>", mention)
if match: if match:
return server.get_channel(match.group(1)) return server.get_channel(match.group(1))
try: try:
return server.get_channel(str(int(mention))) return server.get_channel(int(mention))
except ValueError: except ValueError:
return None return None

View File

@ -62,17 +62,17 @@ async def delete_member(conn, member_id: int):
@db_wrap @db_wrap
async def link_account(conn, system_id: int, account_id: str): async def link_account(conn, system_id: int, account_id: int):
logger.debug("Linking account (account_id={}, system_id={})".format( logger.debug("Linking account (account_id={}, system_id={})".format(
account_id, system_id)) account_id, system_id))
await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id) await conn.execute("insert into accounts (uid, system) values ($1, $2)", account_id, system_id)
@db_wrap @db_wrap
async def unlink_account(conn, system_id: int, account_id: str): async def unlink_account(conn, system_id: int, account_id: int):
logger.debug("Unlinking account (account_id={}, system_id={})".format( logger.debug("Unlinking account (account_id={}, system_id={})".format(
account_id, system_id)) account_id, system_id))
await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id) await conn.execute("delete from accounts where uid = $1 and system = $2", account_id, system_id)
@db_wrap @db_wrap
@ -81,8 +81,8 @@ async def get_linked_accounts(conn, system_id: int) -> List[int]:
@db_wrap @db_wrap
async def get_system_by_account(conn, account_id: str) -> System: async def get_system_by_account(conn, account_id: int) -> System:
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id)) row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", account_id)
return System(**row) if row else None return System(**row) if row else None
@db_wrap @db_wrap
@ -151,26 +151,26 @@ async def get_members_exceeding(conn, system_id: int, length: int) -> List[Membe
@db_wrap @db_wrap
async def get_webhook(conn, channel_id: str) -> (str, str): async def get_webhook(conn, channel_id: int) -> (str, str):
row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id)) row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", channel_id)
return (str(row["webhook"]), row["token"]) if row else None return (str(row["webhook"]), row["token"]) if row else None
@db_wrap @db_wrap
async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str): async def add_webhook(conn, channel_id: int, webhook_id: int, webhook_token: str):
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format( logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(
channel_id, webhook_id, webhook_token)) channel_id, webhook_id, webhook_token))
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token) await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", channel_id, webhook_id, webhook_token)
@db_wrap @db_wrap
async def delete_webhook(conn, channel_id: str): async def delete_webhook(conn, channel_id: int):
await conn.execute("delete from webhooks where channel = $1", int(channel_id)) await conn.execute("delete from webhooks where channel = $1", channel_id)
@db_wrap @db_wrap
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str): async def add_message(conn, message_id: int, channel_id: int, member_id: int, sender_id: int):
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format( logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(
message_id, channel_id, member_id, sender_id)) message_id, channel_id, member_id, sender_id))
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id)) await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", message_id, channel_id, member_id, sender_id)
class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "color", "name", "avatar_url", "tag", "system_name", "system_hid"])): class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "color", "name", "avatar_url", "tag", "system_name", "system_hid"])):
id: int id: int
@ -185,7 +185,7 @@ class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "c
system_hid: str system_hid: str
@db_wrap @db_wrap
async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]: async def get_members_by_account(conn, account_id: int) -> List[ProxyMember]:
# Returns a "chimera" object # Returns a "chimera" object
rows = await conn.fetch("""select rows = await conn.fetch("""select
members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url, members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url,
@ -195,7 +195,7 @@ async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]:
where where
accounts.uid = $1 accounts.uid = $1
and systems.id = accounts.system and systems.id = accounts.system
and members.system = systems.id""", int(account_id)) and members.system = systems.id""", account_id)
return [ProxyMember(**row) for row in rows] return [ProxyMember(**row) for row in rows]
class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "sender", "name", "hid", "avatar_url", "system_name", "system_hid"])): class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "sender", "name", "hid", "avatar_url", "system_name", "system_hid"])):
@ -220,7 +220,7 @@ class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "sender"
} }
@db_wrap @db_wrap
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) -> MessageInfo: async def get_message_by_sender_and_id(conn, message_id: int, sender_id: int) -> MessageInfo:
row = await conn.fetchrow("""select row = await conn.fetchrow("""select
messages.*, messages.*,
members.name, members.hid, members.avatar_url, members.name, members.hid, members.avatar_url,
@ -231,12 +231,12 @@ async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) ->
messages.member = members.id messages.member = members.id
and members.system = systems.id and members.system = systems.id
and mid = $1 and mid = $1
and sender = $2""", int(message_id), int(sender_id)) and sender = $2""", message_id, sender_id)
return MessageInfo(**row) if row else None return MessageInfo(**row) if row else None
@db_wrap @db_wrap
async def get_message(conn, message_id: str) -> MessageInfo: async def get_message(conn, message_id: int) -> MessageInfo:
row = await conn.fetchrow("""select row = await conn.fetchrow("""select
messages.*, messages.*,
members.name, members.hid, members.avatar_url, members.name, members.hid, members.avatar_url,
@ -246,14 +246,14 @@ async def get_message(conn, message_id: str) -> MessageInfo:
where where
messages.member = members.id messages.member = members.id
and members.system = systems.id and members.system = systems.id
and mid = $1""", int(message_id)) and mid = $1""", message_id)
return MessageInfo(**row) if row else None return MessageInfo(**row) if row else None
@db_wrap @db_wrap
async def delete_message(conn, message_id: str): async def delete_message(conn, message_id: int):
logger.debug("Deleting message (id={})".format(message_id)) logger.debug("Deleting message (id={})".format(message_id))
await conn.execute("delete from messages where mid = $1", int(message_id)) await conn.execute("delete from messages where mid = $1", message_id)
@db_wrap @db_wrap
async def get_member_message_count(conn, member_id: int) -> int: async def get_member_message_count(conn, member_id: int) -> int:
@ -290,14 +290,14 @@ async def add_switch_member(conn, switch_id: int, member_id: int):
await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id) await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id)
@db_wrap @db_wrap
async def get_server_info(conn, server_id: str): async def get_server_info(conn, server_id: int):
return await conn.fetchrow("select * from servers where id = $1", int(server_id)) return await conn.fetchrow("select * from servers where id = $1", server_id)
@db_wrap @db_wrap
async def update_server(conn, server_id: str, logging_channel_id: str): async def update_server(conn, server_id: int, logging_channel_id: int):
logging_channel_id = int(logging_channel_id) if logging_channel_id else None logging_channel_id = logging_channel_id if logging_channel_id else None
logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id)) logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id))
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id) await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", server_id, logging_channel_id)
@db_wrap @db_wrap
async def member_count(conn) -> int: async def member_count(conn) -> int: