Merge branch 'rewrite-port'

This commit is contained in:
Ske 2018-11-08 16:43:09 +01:00
commit a72a7c3de9
13 changed files with 369 additions and 450 deletions

View File

@ -5,7 +5,4 @@ import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from pluralkit import bot
pk = bot.PluralKitBot(os.environ["TOKEN"])
loop = asyncio.get_event_loop()
loop.run_until_complete(pk.run())
bot.run()

View File

@ -1,131 +1,134 @@
import asyncio
import json
import logging
import os
import time
import asyncpg
import sys
import traceback
from datetime import datetime
import asyncio
import os
import logging
import discord
import traceback
from pluralkit import db
from pluralkit.bot import channel_logger, commands, proxy, embeds
from pluralkit.bot import commands, proxy, channel_logger, embeds
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
# logging.getLogger("pluralkit").setLevel(logging.DEBUG)
def connect_to_database() -> asyncpg.pool.Pool:
username = os.environ["DATABASE_USER"]
password = os.environ["DATABASE_PASS"]
name = os.environ["DATABASE_NAME"]
host = os.environ["DATABASE_HOST"]
port = os.environ["DATABASE_PORT"]
class PluralKitBot:
def __init__(self, token):
self.token = token
self.logger = logging.getLogger("pluralkit.bot")
if username is None or password is None or name is None or host is None or port is None:
print(
"Database credentials not specified. Please pass valid PostgreSQL database credentials in the DATABASE_[USER|PASS|NAME|HOST|PORT] environment variable.",
file=sys.stderr)
sys.exit(1)
self.client = discord.Client()
self.client.event(self.on_error)
self.client.event(self.on_ready)
self.client.event(self.on_message)
self.client.event(self.on_socket_raw_receive)
try:
port = int(port)
except ValueError:
print("Please pass a valid integer as the DATABASE_PORT environment variable.", file=sys.stderr)
sys.exit(1)
self.channel_logger = channel_logger.ChannelLogger(self.client)
return asyncio.get_event_loop().run_until_complete(db.connect(
username=username,
password=password,
database=name,
host=host,
port=port
))
self.proxy = proxy.Proxy(self.client, token, self.channel_logger)
async def on_error(self, evt, *args, **kwargs):
self.logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
def run():
pool = connect_to_database()
async def on_ready(self):
self.logger.info("Connected to Discord.")
self.logger.info("- Account: {}#{}".format(self.client.user.name, self.client.user.discriminator))
self.logger.info("- User ID: {}".format(self.client.user.id))
self.logger.info("- {} servers".format(len(self.client.servers)))
async def create_tables():
async with pool.acquire() as conn:
await db.create_tables(conn)
# Set playing message
# TODO: change this when merging rewrite-port branch, kwarg game -> activity
await self.client.change_presence(game=discord.Game(name="pk;help"))
asyncio.get_event_loop().run_until_complete(create_tables())
async def on_message(self, message):
# Ignore bot messages
client = discord.Client()
logger = channel_logger.ChannelLogger(client)
@client.event
async def on_ready():
print("PluralKit started.")
print("User: {}#{} (ID: {})".format(client.user.name, client.user.discriminator, client.user.id))
print("{} servers".format(len(client.guilds)))
print("{} shards".format(client.shard_count or 1))
await client.change_presence(activity=discord.Game(name="pk;help"))
@client.event
async def on_message(message: discord.Message):
# Ignore messages from bots
if message.author.bot:
return
try:
if await self.handle_command_dispatch(message):
# Grab a database connection from the pool
async with pool.acquire() as conn:
# First pass: do command handling
did_run_command = await commands.command_dispatch(client, message, conn)
if did_run_command:
return
if await self.handle_proxy_dispatch(message):
return
except Exception:
await self.log_error_in_channel(message)
# Second pass: do proxy matching
await proxy.try_proxy_message(conn, message, logger)
async def on_socket_raw_receive(self, msg):
# Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
# we parse socket data manually for the reaction add event
if isinstance(msg, str):
try:
msg_data = json.loads(msg)
if msg_data.get("t") == "MESSAGE_REACTION_ADD":
evt_data = msg_data.get("d")
if evt_data:
user_id = evt_data["user_id"]
message_id = evt_data["message_id"]
emoji = evt_data["emoji"]["name"]
@client.event
async def on_raw_message_delete(payload: discord.RawMessageDeleteEvent):
async with pool.acquire() as conn:
await proxy.handle_deleted_message(conn, client, payload.message_id, None, logger)
async with self.pool.acquire() as conn:
await self.proxy.handle_reaction(conn, user_id, message_id, emoji)
elif msg_data.get("t") == "MESSAGE_DELETE":
evt_data = msg_data.get("d")
if evt_data:
message_id = evt_data["id"]
async with self.pool.acquire() as conn:
await self.proxy.handle_deletion(conn, message_id)
except ValueError:
pass
@client.event
async def on_raw_bulk_message_delete(payload: discord.RawBulkMessageDeleteEvent):
async with pool.acquire() as conn:
for message_id in payload.message_ids:
await proxy.handle_deleted_message(conn, client, message_id, None, logger)
async def handle_command_dispatch(self, message):
async with self.pool.acquire() as conn:
result = await commands.command_dispatch(self.client, message, conn)
return result
@client.event
async def on_raw_reaction_add(payload: discord.RawReactionActionEvent):
if payload.emoji.name == "\u274c": # Red X
async with pool.acquire() as conn:
await proxy.try_delete_by_reaction(conn, client, payload.message_id, payload.user_id, logger)
async def handle_proxy_dispatch(self, message):
# Try doing proxy parsing
async with self.pool.acquire() as conn:
return await self.proxy.try_proxy_message(conn, message)
async def log_error_in_channel(self, message):
channel_id = os.environ["LOG_CHANNEL"]
if not channel_id:
@client.event
async def on_error(event_name, *args, **kwargs):
log_channel_id = os.environ["LOG_CHANNEL"]
if not log_channel_id:
return
channel = self.client.get_channel(channel_id)
embed = embeds.exception_log(
message.content,
message.author.name,
message.author.discriminator,
message.server.id if message.server else None,
message.channel.id
)
log_channel = client.get_channel(int(log_channel_id))
await self.client.send_message(channel, "```python\n{}```".format(traceback.format_exc()), embed=embed)
async def run(self):
try:
self.logger.info("Connecting to database...")
self.pool = await db.connect(
os.environ["DATABASE_USER"],
os.environ["DATABASE_PASS"],
os.environ["DATABASE_NAME"],
os.environ["DATABASE_HOST"],
int(os.environ["DATABASE_PORT"])
# If this is a message event, we can attach additional information in an event
# ie. username, channel, content, etc
if args and isinstance(args[0], discord.Message):
message: discord.Message = args[0]
embed = embeds.exception_log(
message.content,
message.author.name,
message.author.discriminator,
message.author.id,
message.guild.id if message.guild else None,
message.channel.id
)
else:
# If not, just post the string itself
embed = None
self.logger.info("Attempting to create tables...")
async with self.pool.acquire() as conn:
await db.create_tables(conn)
traceback_str = "```python\n{}```".format(traceback.format_exc())
await log_channel.send(content=traceback_str, embed=embed)
self.logger.info("Connecting to Discord...")
await self.client.start(self.token)
finally:
self.logger.info("Logging out from Discord...")
await self.client.logout()
bot_token = os.environ["TOKEN"]
if not bot_token:
print("No token specified. Please pass a valid Discord bot token in the TOKEN environment variable.",
file=sys.stderr)
sys.exit(1)
client.run(bot_token)

View File

@ -19,7 +19,7 @@ class ChannelLogger:
self.logger = logging.getLogger("pluralkit.bot.channel_logger")
self.client = client
async def get_log_channel(self, conn, server_id: str):
async def get_log_channel(self, conn, server_id: int):
server_info = await db.get_server_info(conn, server_id)
if not server_info:
@ -30,21 +30,21 @@ class ChannelLogger:
if not log_channel:
return None
return self.client.get_channel(str(log_channel))
return self.client.get_channel(log_channel)
async def send_to_log_channel(self, log_channel: discord.Channel, embed: discord.Embed, text: str = None):
async def send_to_log_channel(self, log_channel: discord.TextChannel, embed: discord.Embed, text: str = None):
try:
await self.client.send_message(log_channel, embed=embed, content=text)
await log_channel.send(content=text, embed=embed)
except discord.Forbidden:
# TODO: spew big error
self.logger.warning(
"Did not have permission to send message to logging channel (server={}, channel={})".format(
log_channel.server.id, log_channel.id))
log_channel.guild.id, log_channel.id))
async def log_message_proxied(self, conn,
server_id: str,
server_id: int,
channel_name: str,
channel_id: str,
channel_id: int,
sender_name: str,
sender_disc: int,
sender_id: int,
@ -56,11 +56,13 @@ class ChannelLogger:
message_text: str,
message_image: str,
message_timestamp: datetime,
message_id: str):
message_id: int):
log_channel = await self.get_log_channel(conn, server_id)
if not log_channel:
return
message_link = "https://discordapp.com/channels/{}/{}/{}".format(server_id, channel_id, message_id)
embed = discord.Embed()
embed.colour = discord.Colour.blue()
embed.description = message_text
@ -75,11 +77,10 @@ class ChannelLogger:
if message_image:
embed.set_thumbnail(url=message_image)
message_link = "https://discordapp.com/channels/{}/{}/{}".format(server_id, channel_id, message_id)
await self.send_to_log_channel(log_channel, embed, message_link)
async def log_message_deleted(self, conn,
server_id: str,
server_id: int,
channel_name: str,
member_name: str,
member_hid: str,
@ -87,22 +88,17 @@ class ChannelLogger:
system_name: str,
system_hid: str,
message_text: str,
message_id: str,
deleted_by_moderator: bool):
message_id: int):
log_channel = await self.get_log_channel(conn, server_id)
if not log_channel:
return
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.description = message_text
embed.description = message_text or "*(unknown, message deleted by moderator)*"
embed.timestamp = datetime.utcnow()
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
embed.set_footer(
text="System ID: {} | Member ID: {} | Message ID: {} | Deleted by moderator? {}".format(system_hid,
member_hid,
message_id,
"Yes" if deleted_by_moderator else "No"))
embed.set_footer(text="System ID: {} | Member ID: {} | Message ID: {}".format(system_hid, member_hid, message_id))
await self.send_to_log_channel(log_channel, embed)

View File

@ -1,3 +1,4 @@
import asyncio
import discord
import logging
import re
@ -69,7 +70,7 @@ class CommandContext:
def has_next(self) -> bool:
return bool(self.args)
def pop_str(self, error: CommandError = None) -> str:
def pop_str(self, error: CommandError = None) -> Optional[str]:
if not self.args:
if error:
raise error
@ -105,26 +106,27 @@ class CommandContext:
return self.args
async def reply(self, content=None, embed=None):
return await self.client.send_message(self.message.channel, content=content, embed=embed)
return await self.message.channel.send(content=content, embed=embed)
async def confirm_react(self, user: Union[discord.Member, discord.User], message: str):
message = await self.reply(message)
await message.add_reaction("\u2705") # Checkmark
await message.add_reaction("\u274c") # Red X
await self.client.add_reaction(message, "")
await self.client.add_reaction(message, "")
reaction = await self.client.wait_for_reaction(emoji=["", ""], user=user, timeout=60.0*5)
if not reaction:
try:
reaction, _ = await self.client.wait_for("reaction_add", check=lambda r, u: u.id == user.id and r.emoji in ["\u2705", "\u274c"], timeout=60.0*5)
return reaction.emoji == "\u2705"
except asyncio.TimeoutError:
raise CommandError("Timed out - try again.")
return reaction.reaction.emoji == ""
async def confirm_text(self, user: discord.Member, channel: discord.Channel, confirm_text: str, message: str):
async def confirm_text(self, user: discord.Member, channel: discord.TextChannel, confirm_text: str, message: str):
await self.reply(message)
message = await self.client.wait_for_message(channel=channel, author=user, timeout=60.0*5)
if not message:
try:
message = await self.client.wait_for("message", check=lambda m: m.channel.id == channel.id and m.author.id == user.id, timeout=60.0*5)
return message.content.lower() == confirm_text.lower()
except asyncio.TimeoutError:
raise CommandError("Timed out - try again.")
return message.content == confirm_text
import pluralkit.bot.commands.import_commands

View File

@ -8,22 +8,16 @@ logger = logging.getLogger("pluralkit.commands")
async def import_tupperware(ctx: CommandContext):
tupperware_ids = ["431544605209788416", "433916057053560832"] # Main bot instance and Multi-Pals-specific fork
tupperware_members = [ctx.message.server.get_member(bot_id) for bot_id in tupperware_ids if
ctx.message.server.get_member(bot_id)]
tupperware_member = ctx.message.guild.get_member(431544605209788416)
# Check if there's any Tupperware bot on the server
if not tupperware_members:
# Check if there's a Tupperware bot on the server
if not tupperware_member:
return CommandError("This command only works in a server where the Tupperware bot is also present.")
# Make sure at least one of the bts have send/read permissions here
for bot_member in tupperware_members:
channel_permissions = ctx.message.channel.permissions_for(bot_member)
if channel_permissions.read_messages and channel_permissions.send_messages:
# If so, break out of the loop
break
else:
# If no bots have permission (ie. loop doesn't break), throw error
# Make sure at the bot has send/read permissions here
channel_permissions = ctx.message.channel.permissions_for(tupperware_member)
if not (channel_permissions.read_messages and channel_permissions.send_messages):
# If it doesn't, throw error
return CommandError("This command only works in a channel where the Tupperware bot has read/send access.")
await ctx.reply(
@ -31,29 +25,31 @@ async def import_tupperware(ctx: CommandContext):
# Check to make sure the message is sent by Tupperware, and that the Tupperware response actually belongs to the correct user
def ensure_account(tw_msg):
if tw_msg.author not in tupperware_members:
if tw_msg.channel.id != ctx.message.channel.id:
return False
if tw_msg.author.id != tupperware_member.id:
return False
if not tw_msg.embeds:
return False
if not tw_msg.embeds[0]["title"]:
if not tw_msg.embeds[0].title:
return False
return tw_msg.embeds[0]["title"].startswith(
return tw_msg.embeds[0].title.startswith(
"{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator))
tupperware_page_embeds = []
tw_msg: discord.Message = await ctx.client.wait_for_message(channel=ctx.message.channel, timeout=60.0 * 5,
check=ensure_account)
tw_msg: discord.Message = await ctx.client.wait_for("message", check=ensure_account, timeout=60.0 * 5)
if not tw_msg:
return CommandError("Tupperware import timed out.")
tupperware_page_embeds.append(tw_msg.embeds[0])
tupperware_page_embeds.append(tw_msg.embeds[0].to_dict())
# Handle Tupperware pagination
def match_pagination():
pagination_match = re.search(r"\(page (\d+)/(\d+), \d+ total\)", tw_msg.embeds[0]["title"])
pagination_match = re.search(r"\(page (\d+)/(\d+), \d+ total\)", tw_msg.embeds[0].title)
if not pagination_match:
return None
return int(pagination_match.group(1)), int(pagination_match.group(2))
@ -72,13 +68,12 @@ async def import_tupperware(ctx: CommandContext):
new_page, total_pages = match_pagination()
# Put the found page in the pages dict
pages_found[new_page] = dict(tw_msg.embeds[0])
pages_found[new_page] = tw_msg.embeds[0].to_dict()
# If this isn't the same page as last check, edit the status message
if new_page != current_page:
last_found_time = datetime.utcnow()
await ctx.client.edit_message(status_msg,
"Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(
await status_msg.edit(content="Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(
len(pages_found), total_pages))
current_page = new_page
@ -94,7 +89,7 @@ async def import_tupperware(ctx: CommandContext):
tupperware_page_embeds = list([embed for page, embed in sorted(pages_found.items(), key=lambda x: x[0])])
# Also edit the status message to indicate we're now importing, and it may take a while because there's probably a lot of members
await ctx.client.edit_message(status_msg, "All pages read. Now importing...")
await status_msg.edit(content="All pages read. Now importing...")
logger.debug("Importing from Tupperware...")

View File

@ -8,7 +8,7 @@ async def get_message_contents(client: discord.Client, channel_id: int, message_
channel = client.get_channel(str(channel_id))
if channel:
try:
original_message = await client.get_message(channel, str(message_id))
original_message = await client.get_channel(channel).get_message(message_id)
return original_message.content or None
except (discord.errors.Forbidden, discord.errors.NotFound):
pass
@ -24,19 +24,19 @@ async def message_info(ctx: CommandContext):
return CommandError("You must pass a valid number as a message ID.", help=help.message_lookup)
# Find the message in the DB
message = await db.get_message(ctx.conn, str(mid))
message = await db.get_message(ctx.conn, mid)
if not message:
raise CommandError("Message with ID '{}' not found.".format(mid))
# Get the original sender of the messages
try:
original_sender = await ctx.client.get_user_info(str(message.sender))
original_sender = await ctx.client.get_user_info(message.sender)
except discord.NotFound:
# Account was since deleted - rare but we're handling it anyway
original_sender = None
embed = discord.Embed()
embed.timestamp = discord.utils.snowflake_time(str(mid))
embed.timestamp = discord.utils.snowflake_time(mid)
embed.colour = discord.Colour.blue()
if message.system_name:
@ -55,8 +55,7 @@ async def message_info(ctx: CommandContext):
embed.add_field(name="Sent by", value=sender_name)
message_content = await get_message_contents(ctx.client, message.channel, message.mid)
if message_content:
embed.description = message_content
embed.description = message_content or "(unknown, message deleted)"
embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty)

View File

@ -88,4 +88,4 @@ async def export(ctx: CommandContext):
}
f = io.BytesIO(json.dumps(data).encode("utf-8"))
await ctx.client.send_file(ctx.message.channel, f, filename="system.json")
await ctx.message.channel.send(content="Here you go!", file=discord.File(fp=f, filename="system.json"))

View File

@ -4,10 +4,13 @@ logger = logging.getLogger("pluralkit.commands")
async def set_log(ctx: CommandContext):
if not ctx.message.author.server_permissions.administrator:
if not ctx.message.author.guild_permissions.administrator:
return CommandError("You must be a server administrator to use this command.")
server = ctx.message.server
server = ctx.message.guild
if not server:
return CommandError("This command can not be run in a DM.")
if not ctx.has_next():
channel_id = None
else:

View File

@ -38,13 +38,13 @@ def status(text: str) -> discord.Embed:
return embed
def exception_log(message_content, author_name, author_discriminator, server_id, channel_id) -> discord.Embed:
def exception_log(message_content, author_name, author_discriminator, author_id, server_id, channel_id) -> discord.Embed:
embed = discord.Embed()
embed.colour = discord.Colour.dark_red()
embed.title = message_content
embed.set_footer(text="Sender: {}#{} | Server: {} | Channel: {}".format(
author_name, author_discriminator,
embed.set_footer(text="Sender: {}#{} ({}) | Server: {} | Channel: {}".format(
author_name, author_discriminator, author_id,
server_id if server_id else "(DMs)",
channel_id
))
@ -72,7 +72,7 @@ async def system_card(conn, client: discord.Client, system: System) -> discord.E
account_names = []
for account_id in await system.get_linked_account_ids(conn):
account = await client.get_user_info(str(account_id))
account = await client.get_user_info(account_id)
account_names.append("{}#{}".format(account.name, account.discriminator))
card.add_field(name="Linked accounts", value="\n".join(account_names))
@ -82,29 +82,31 @@ async def system_card(conn, client: discord.Client, system: System) -> discord.E
value=system.description, inline=False)
# Get names of all members
member_texts = []
for member in await system.get_members(conn):
member_texts.append("{} (`{}`)".format(escape(member.name), member.hid))
all_members = await system.get_members(conn)
if all_members:
member_texts = []
for member in all_members:
member_texts.append("{} (`{}`)".format(escape(member.name), member.hid))
# Interim solution for pagination of large systems
# Previously a lot of systems would hit the 1024 character limit and thus break the message
# This splits large system lists into multiple embed fields
# The 6000 character total limit will still apply here but this sort of pushes the problem until I find a better fix
pages = [""]
for member in member_texts:
last_page = pages[-1]
new_page = last_page + "\n" + member
# Interim solution for pagination of large systems
# Previously a lot of systems would hit the 1024 character limit and thus break the message
# This splits large system lists into multiple embed fields
# The 6000 character total limit will still apply here but this sort of pushes the problem until I find a better fix
pages = [""]
for member in member_texts:
last_page = pages[-1]
new_page = last_page + "\n" + member
if len(new_page) >= 1024:
pages.append(member)
else:
pages[-1] = new_page
if len(new_page) >= 1024:
pages.append(member)
else:
pages[-1] = new_page
for index, page in enumerate(pages):
field_name = "Members"
if index >= 1:
field_name = "Members (part {})".format(index + 1)
card.add_field(name=field_name, value=page, inline=False)
for index, page in enumerate(pages):
field_name = "Members"
if index >= 1:
field_name = "Members (part {})".format(index + 1)
card.add_field(name=field_name, value=page, inline=False)
card.set_footer(text="System ID: {}".format(system.hid))
return card

View File

@ -1,17 +1,17 @@
import ciso8601
from io import BytesIO
import discord
import logging
import re
import time
from typing import List, Optional
import aiohttp
import discord
from pluralkit import db
from pluralkit.bot import channel_logger, utils, embeds
from pluralkit.bot import utils, channel_logger
from pluralkit.bot.channel_logger import ChannelLogger
logger = logging.getLogger("pluralkit.bot.proxy")
def extract_leading_mentions(message_text):
# This regex matches one or more mentions at the start of a message, separated by any amount of spaces
match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message_text)
@ -34,7 +34,9 @@ def match_member_proxy_tags(member: db.ProxyMember, message_text: str):
# Ignore mentions at the very start of the message, and match proxy tags after those
message_text, leading_mentions = extract_leading_mentions(message_text)
logger.debug("Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions, prefix, suffix))
logger.debug(
"Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions,
prefix, suffix))
if message_text.startswith(member.prefix or "") and message_text.endswith(member.suffix or ""):
prefix_length = len(prefix)
@ -59,243 +61,162 @@ def match_proxy_tags(members: List[db.ProxyMember], message_text: str):
for member in members:
match = match_member_proxy_tags(member, message_text)
if match is not None: # Using "is not None" because an empty string is OK here too
if match is not None: # Using "is not None" because an empty string is OK here too
logger.debug("Matched member {} with inner text '{}'".format(member.hid, match))
return member, match
def get_message_attachment_url(message: discord.Message):
async def get_or_create_webhook_for_channel(conn, channel: discord.TextChannel):
# First, check if we have one saved in the DB
webhook_from_db = await db.get_webhook(conn, channel.id)
if webhook_from_db:
webhook_id, webhook_token = webhook_from_db
session = channel._state.http._session
hook = discord.Webhook.partial(webhook_id, webhook_token, adapter=discord.AsyncWebhookAdapter(session))
# Workaround for https://github.com/Rapptz/discord.py/issues/1242
hook._adapter.store_user = hook._adapter._store_user
return hook
# If not, we create one and save it
created_webhook = await channel.create_webhook(name="PluralKit Proxy Webhook")
await db.add_webhook(conn, channel.id, created_webhook.id, created_webhook.token)
return created_webhook
async def make_attachment_file(message: discord.Message):
if not message.attachments:
return None
attachment = message.attachments[0]
if "proxy_url" in attachment:
return attachment["proxy_url"]
first_attachment = message.attachments[0]
if "url" in attachment:
return attachment["url"]
# Copy the file data to the buffer
# TODO: do this without buffering... somehow
bio = BytesIO()
await first_attachment.save(bio)
return discord.File(bio, first_attachment.filename)
# TODO: possibly move this to bot __init__ so commands can access it too
class WebhookPermissionError(Exception):
pass
async def do_proxy_message(conn, original_message: discord.Message, proxy_member: db.ProxyMember,
inner_text: str, logger: ChannelLogger):
# Send the message through the webhook
webhook = await get_or_create_webhook_for_channel(conn, original_message.channel)
try:
sent_message = await webhook.send(
content=inner_text,
username="{} {}".format(proxy_member.name, proxy_member.tag),
avatar_url=proxy_member.avatar_url,
file=await make_attachment_file(original_message),
wait=True
)
except discord.NotFound:
# The webhook we got from the DB doesn't actually exist
# If we delete it from the DB then call the function again, it'll re-create one for us
await db.delete_webhook(conn, original_message.channel.id)
await do_proxy_message(conn, original_message, proxy_member, inner_text, logger)
return
# Save the proxied message in the database
await db.add_message(conn, sent_message.id, original_message.channel.id, proxy_member.id,
original_message.author.id)
await logger.log_message_proxied(
conn,
original_message.channel.guild.id,
original_message.channel.name,
original_message.channel.id,
original_message.author.name,
original_message.author.discriminator,
original_message.author.id,
proxy_member.name,
proxy_member.hid,
proxy_member.avatar_url,
proxy_member.system_name,
proxy_member.system_hid,
inner_text,
sent_message.attachments[0].url if sent_message.attachments else None,
sent_message.created_at,
sent_message.id
)
# And finally, gotta delete the original.
await original_message.delete()
class DeletionPermissionError(Exception):
pass
async def try_proxy_message(conn, message: discord.Message, logger: ChannelLogger) -> bool:
# Don't bother proxying in DMs with the bot
if isinstance(message.channel, discord.abc.PrivateChannel):
return False
# Return every member associated with the account
members = await db.get_members_by_account(conn, message.author.id)
proxy_match = match_proxy_tags(members, message.content)
if not proxy_match:
# No proxy tags match here, we done
return False
member, inner_text = proxy_match
# Sanitize inner text for @everyones and such
inner_text = utils.sanitize(inner_text)
# If we don't have an inner text OR an attachment, we cancel because the hook can't send that
# Strip so it counts a string of solely spaces as blank too
if not inner_text.strip() and not message.attachments:
return False
# So, we now have enough information to successfully proxy a message
async with conn.transaction():
await do_proxy_message(conn, message, member, inner_text, logger)
return True
class Proxy:
def __init__(self, client: discord.Client, token: str, logger: channel_logger.ChannelLogger):
self.logger = logging.getLogger("pluralkit.bot.proxy")
self.session = aiohttp.ClientSession()
self.client = client
self.token = token
self.channel_logger = logger
async def handle_deleted_message(conn, client: discord.Client, message_id: int,
message_content: Optional[str], logger: channel_logger.ChannelLogger) -> bool:
msg = await db.get_message(conn, message_id)
if not msg:
return False
async def save_channel_webhook(self, conn, channel: discord.Channel, id: str, token: str) -> (str, str):
await db.add_webhook(conn, channel.id, id, token)
return id, token
channel = client.get_channel(msg.channel)
if not channel:
# Weird edge case, but channel *could* be deleted at this point (can't think of any scenarios it would be tho)
return False
async def create_and_add_channel_webhook(self, conn, channel: discord.Channel) -> (str, str):
# This method is only called if there's no webhook found in the DB (and hopefully within a transaction)
# No need to worry about error handling if there's a DB conflict (which will throw an exception because DB constraints)
req_headers = {"Authorization": "Bot {}".format(self.token)}
await db.delete_message(conn, message_id)
await logger.log_message_deleted(
conn,
channel.guild.id,
channel.name,
msg.name,
msg.hid,
msg.avatar_url,
msg.system_name,
msg.system_hid,
message_content,
message_id
)
return True
# First, check if there's already a webhook belonging to the bot
async with self.session.get("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
headers=req_headers) as resp:
if resp.status == 200:
webhooks = await resp.json()
for webhook in webhooks:
if webhook["user"]["id"] == self.client.user.id:
# This webhook belongs to us, we can use that, return it and save it
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
elif resp.status == 403:
self.logger.warning(
"Did not have permission to fetch webhook list (server={}, channel={})".format(channel.server.id,
channel.id))
raise WebhookPermissionError()
else:
raise discord.HTTPException(resp, await resp.text())
# Then, try submitting a new one
req_data = {"name": "PluralKit Proxy Webhook"}
async with self.session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
json=req_data, headers=req_headers) as resp:
if resp.status == 200:
webhook = await resp.json()
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
elif resp.status == 403:
self.logger.warning(
"Did not have permission to create webhook (server={}, channel={})".format(channel.server.id,
channel.id))
raise WebhookPermissionError()
else:
raise discord.HTTPException(resp, await resp.text())
async def try_delete_by_reaction(conn, client: discord.Client, message_id: int, reaction_user: int,
logger: channel_logger.ChannelLogger) -> bool:
# Find the message by the given message id or reaction user
msg = await db.get_message_by_sender_and_id(conn, message_id, reaction_user)
if not msg:
# Either the wrong user reacted or the message isn't a proxy message
# In either case - not our problem
return False
# Should not be reached without an exception being thrown
# Find the original message
original_message = await client.get_channel(msg.channel).get_message(message_id)
if not original_message:
# Message got deleted, possibly race condition, eh
return False
async def get_webhook_for_channel(self, conn, channel: discord.Channel):
async with conn.transaction():
hook_match = await db.get_webhook(conn, channel.id)
if not hook_match:
# We don't have a webhook, create/add one
return await self.create_and_add_channel_webhook(conn, channel)
else:
return hook_match
# Then delete the original message
await original_message.delete()
async def do_proxy_message(self, conn, member: db.ProxyMember, original_message: discord.Message, text: str,
attachment_url: str, has_already_retried=False):
hook_id, hook_token = await self.get_webhook_for_channel(conn, original_message.channel)
form_data = aiohttp.FormData()
form_data.add_field("username", "{} {}".format(member.name, member.tag or "").strip())
if text:
form_data.add_field("content", text)
if attachment_url:
attachment_resp = await self.session.get(attachment_url)
form_data.add_field("file", attachment_resp.content, content_type=attachment_resp.content_type,
filename=attachment_resp.url.name)
if member.avatar_url:
form_data.add_field("avatar_url", member.avatar_url)
async with self.session.post(
"https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token),
data=form_data) as resp:
if resp.status == 200:
message = await resp.json()
await db.add_message(conn, message["id"], message["channel_id"], member.id, original_message.author.id)
try:
await self.client.delete_message(original_message)
except discord.Forbidden:
self.logger.warning(
"Did not have permission to delete original message (server={}, channel={})".format(
original_message.server.id, original_message.channel.id))
raise DeletionPermissionError()
except discord.NotFound:
self.logger.warning("Tried to delete message when proxying, but message was already gone (server={}, channel={})".format(original_message.server.id, original_message.channel.id))
message_image = None
if message["attachments"]:
first_attachment = message["attachments"][0]
if "width" in first_attachment and "height" in first_attachment:
# Only log attachments that are actually images
message_image = first_attachment["url"]
await self.channel_logger.log_message_proxied(conn,
server_id=original_message.server.id,
channel_name=original_message.channel.name,
channel_id=original_message.channel.id,
sender_name=original_message.author.name,
sender_disc=original_message.author.discriminator,
sender_id=original_message.author.id,
member_name=member.name,
member_hid=member.hid,
member_avatar_url=member.avatar_url,
system_name=member.system_name,
system_hid=member.system_hid,
message_text=text,
message_image=message_image,
message_timestamp=ciso8601.parse_datetime(
message["timestamp"]),
message_id=message["id"])
elif resp.status == 404 and not has_already_retried:
# Webhook doesn't exist. Delete it from the DB, create, and add a new one
self.logger.warning("Webhook registered in DB doesn't exist, deleting hook from DB, re-adding, and trying again (channel={}, hook={})".format(original_message.channel.id, hook_id))
await db.delete_webhook(conn, original_message.channel.id)
await self.create_and_add_channel_webhook(conn, original_message.channel)
# Then try again all over, making sure to not retry again and go in a loop should it continually fail
return await self.do_proxy_message(conn, member, original_message, text, attachment_url, has_already_retried=True)
else:
raise discord.HTTPException(resp, await resp.text())
async def try_proxy_message(self, conn, message: discord.Message):
# Can't proxy in DMs, webhook creation will explode
if message.channel.is_private:
return False
# Big fat query to find every member associated with this account
# Returned member object has a few more keys (system tag, for example)
members = await db.get_members_by_account(conn, account_id=message.author.id)
match = match_proxy_tags(members, message.content)
if not match:
return False
member, text = match
attachment_url = get_message_attachment_url(message)
# Can't proxy a message with no text AND no attachment
if not text and not attachment_url:
self.logger.debug("Skipping message because of no text and no attachment")
return False
# Remember to sanitize the text (remove @everyones and such)
text = utils.sanitize(text)
try:
async with conn.transaction():
await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url)
except WebhookPermissionError:
embed = embeds.error("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed)
except DeletionPermissionError:
embed = embeds.error("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.")
await self.client.send_message(message.channel, embed=embed)
return True
async def try_delete_message(self, conn, message_id: str, check_user_id: Optional[str], delete_message: bool, deleted_by_moderator: bool):
async with conn.transaction():
# Find the message in the DB, and make sure it's sent by the user (if we need to check)
if check_user_id:
db_message = await db.get_message_by_sender_and_id(conn, message_id=message_id, sender_id=check_user_id)
else:
db_message = await db.get_message(conn, message_id=message_id)
if db_message:
self.logger.debug("Deleting message {}".format(message_id))
channel = self.client.get_channel(str(db_message.channel))
# If we should also delete the actual message, do that
if delete_message:
message = await self.client.get_message(channel, message_id)
try:
await self.client.delete_message(message)
except discord.Forbidden:
self.logger.warning(
"Did not have permission to remove message, aborting deletion (server={}, channel={})".format(
channel.server.id, channel.id))
return
# Remove it from the DB
await db.delete_message(conn, message_id)
# Then log deletion to logging channel
await self.channel_logger.log_message_deleted(conn,
server_id=channel.server.id,
channel_name=channel.name,
member_name=db_message.name,
member_hid=db_message.hid,
member_avatar_url=db_message.avatar_url,
system_name=db_message.system_name,
system_hid=db_message.system_hid,
message_text=db_message.content,
message_id=message_id,
deleted_by_moderator=deleted_by_moderator)
async def handle_reaction(self, conn, user_id: str, message_id: str, emoji: str):
if emoji == "":
await self.try_delete_message(conn, message_id, check_user_id=user_id, delete_message=True, deleted_by_moderator=False)
async def handle_deletion(self, conn, message_id: str):
# Don't delete the message, it's already gone at this point, just handle DB deletion and logging
await self.try_delete_message(conn, message_id, check_user_id=None, delete_message=False, deleted_by_moderator=True)
await handle_deleted_message(conn, client, message_id, original_message.content, logger)

View File

@ -2,6 +2,7 @@ import logging
import re
import discord
from typing import Optional
from pluralkit import db
from pluralkit.system import System
@ -20,28 +21,28 @@ def bounds_check_member_name(new_name, system_tag):
if len("{} {}".format(new_name, system_tag)) > 32:
return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(system_tag)
async def parse_mention(client: discord.Client, mention: str) -> discord.User:
async def parse_mention(client: discord.Client, mention: str) -> Optional[discord.User]:
# First try matching mention format
match = re.fullmatch("<@!?(\\d+)>", mention)
if match:
try:
return await client.get_user_info(match.group(1))
return await client.get_user_info(int(match.group(1)))
except discord.NotFound:
return None
# Then try with just ID
try:
return await client.get_user_info(str(int(mention)))
return await client.get_user_info(int(mention))
except (ValueError, discord.NotFound):
return None
def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel:
def parse_channel_mention(mention: str, server: discord.Guild) -> Optional[discord.TextChannel]:
match = re.fullmatch("<#(\\d+)>", mention)
if match:
return server.get_channel(match.group(1))
return server.get_channel(int(match.group(1)))
try:
return server.get_channel(str(int(mention)))
return server.get_channel(int(mention))
except ValueError:
return None

View File

@ -62,17 +62,17 @@ async def delete_member(conn, member_id: int):
@db_wrap
async def link_account(conn, system_id: int, account_id: str):
async def link_account(conn, system_id: int, account_id: int):
logger.debug("Linking account (account_id={}, system_id={})".format(
account_id, system_id))
await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id)
await conn.execute("insert into accounts (uid, system) values ($1, $2)", account_id, system_id)
@db_wrap
async def unlink_account(conn, system_id: int, account_id: str):
async def unlink_account(conn, system_id: int, account_id: int):
logger.debug("Unlinking account (account_id={}, system_id={})".format(
account_id, system_id))
await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id)
await conn.execute("delete from accounts where uid = $1 and system = $2", account_id, system_id)
@db_wrap
@ -81,8 +81,8 @@ async def get_linked_accounts(conn, system_id: int) -> List[int]:
@db_wrap
async def get_system_by_account(conn, account_id: str) -> System:
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id))
async def get_system_by_account(conn, account_id: int) -> System:
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", account_id)
return System(**row) if row else None
@db_wrap
@ -151,26 +151,26 @@ async def get_members_exceeding(conn, system_id: int, length: int) -> List[Membe
@db_wrap
async def get_webhook(conn, channel_id: str) -> (str, str):
row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
async def get_webhook(conn, channel_id: int) -> (str, str):
row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", channel_id)
return (str(row["webhook"]), row["token"]) if row else None
@db_wrap
async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str):
async def add_webhook(conn, channel_id: int, webhook_id: int, webhook_token: str):
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(
channel_id, webhook_id, webhook_token))
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token)
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", channel_id, webhook_id, webhook_token)
@db_wrap
async def delete_webhook(conn, channel_id: str):
await conn.execute("delete from webhooks where channel = $1", int(channel_id))
async def delete_webhook(conn, channel_id: int):
await conn.execute("delete from webhooks where channel = $1", channel_id)
@db_wrap
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str):
async def add_message(conn, message_id: int, channel_id: int, member_id: int, sender_id: int):
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(
message_id, channel_id, member_id, sender_id))
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id))
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", message_id, channel_id, member_id, sender_id)
class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "color", "name", "avatar_url", "tag", "system_name", "system_hid"])):
id: int
@ -185,7 +185,7 @@ class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "c
system_hid: str
@db_wrap
async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]:
async def get_members_by_account(conn, account_id: int) -> List[ProxyMember]:
# Returns a "chimera" object
rows = await conn.fetch("""select
members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url,
@ -195,7 +195,7 @@ async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]:
where
accounts.uid = $1
and systems.id = accounts.system
and members.system = systems.id""", int(account_id))
and members.system = systems.id""", account_id)
return [ProxyMember(**row) for row in rows]
class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "sender", "name", "hid", "avatar_url", "system_name", "system_hid"])):
@ -220,7 +220,7 @@ class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "sender"
}
@db_wrap
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) -> MessageInfo:
async def get_message_by_sender_and_id(conn, message_id: int, sender_id: int) -> MessageInfo:
row = await conn.fetchrow("""select
messages.*,
members.name, members.hid, members.avatar_url,
@ -231,12 +231,12 @@ async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) ->
messages.member = members.id
and members.system = systems.id
and mid = $1
and sender = $2""", int(message_id), int(sender_id))
and sender = $2""", message_id, sender_id)
return MessageInfo(**row) if row else None
@db_wrap
async def get_message(conn, message_id: str) -> MessageInfo:
async def get_message(conn, message_id: int) -> MessageInfo:
row = await conn.fetchrow("""select
messages.*,
members.name, members.hid, members.avatar_url,
@ -246,14 +246,14 @@ async def get_message(conn, message_id: str) -> MessageInfo:
where
messages.member = members.id
and members.system = systems.id
and mid = $1""", int(message_id))
and mid = $1""", message_id)
return MessageInfo(**row) if row else None
@db_wrap
async def delete_message(conn, message_id: str):
async def delete_message(conn, message_id: int):
logger.debug("Deleting message (id={})".format(message_id))
await conn.execute("delete from messages where mid = $1", int(message_id))
await conn.execute("delete from messages where mid = $1", message_id)
@db_wrap
async def get_member_message_count(conn, member_id: int) -> int:
@ -290,14 +290,14 @@ async def add_switch_member(conn, switch_id: int, member_id: int):
await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id)
@db_wrap
async def get_server_info(conn, server_id: str):
return await conn.fetchrow("select * from servers where id = $1", int(server_id))
async def get_server_info(conn, server_id: int):
return await conn.fetchrow("select * from servers where id = $1", server_id)
@db_wrap
async def update_server(conn, server_id: str, logging_channel_id: str):
logging_channel_id = int(logging_channel_id) if logging_channel_id else None
async def update_server(conn, server_id: int, logging_channel_id: int):
logging_channel_id = logging_channel_id if logging_channel_id else None
logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id))
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id)
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", server_id, logging_channel_id)
@db_wrap
async def member_count(conn) -> int:

View File

@ -2,7 +2,7 @@ aiodns
aiohttp
asyncpg
dateparser
discord.py
https://github.com/Rapptz/discord.py/archive/860d6a9ace8248dfeec18b8b159e7b757d9f56bb.zip#egg=discord.py
humanize
uvloop
ciso8601