Merge branch 'rewrite-port'
This commit is contained in:
commit
a72a7c3de9
@ -5,7 +5,4 @@ import uvloop
|
|||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
from pluralkit import bot
|
from pluralkit import bot
|
||||||
|
bot.run()
|
||||||
pk = bot.PluralKitBot(os.environ["TOKEN"])
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
loop.run_until_complete(pk.run())
|
|
@ -1,131 +1,134 @@
|
|||||||
import asyncio
|
import asyncpg
|
||||||
import json
|
import sys
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
|
|
||||||
import traceback
|
import asyncio
|
||||||
from datetime import datetime
|
import os
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
import traceback
|
||||||
|
|
||||||
from pluralkit import db
|
from pluralkit import db
|
||||||
from pluralkit.bot import channel_logger, commands, proxy, embeds
|
from pluralkit.bot import commands, proxy, channel_logger, embeds
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
||||||
|
|
||||||
|
|
||||||
# logging.getLogger("pluralkit").setLevel(logging.DEBUG)
|
def connect_to_database() -> asyncpg.pool.Pool:
|
||||||
|
username = os.environ["DATABASE_USER"]
|
||||||
|
password = os.environ["DATABASE_PASS"]
|
||||||
|
name = os.environ["DATABASE_NAME"]
|
||||||
|
host = os.environ["DATABASE_HOST"]
|
||||||
|
port = os.environ["DATABASE_PORT"]
|
||||||
|
|
||||||
class PluralKitBot:
|
if username is None or password is None or name is None or host is None or port is None:
|
||||||
def __init__(self, token):
|
print(
|
||||||
self.token = token
|
"Database credentials not specified. Please pass valid PostgreSQL database credentials in the DATABASE_[USER|PASS|NAME|HOST|PORT] environment variable.",
|
||||||
self.logger = logging.getLogger("pluralkit.bot")
|
file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
self.client = discord.Client()
|
try:
|
||||||
self.client.event(self.on_error)
|
port = int(port)
|
||||||
self.client.event(self.on_ready)
|
except ValueError:
|
||||||
self.client.event(self.on_message)
|
print("Please pass a valid integer as the DATABASE_PORT environment variable.", file=sys.stderr)
|
||||||
self.client.event(self.on_socket_raw_receive)
|
sys.exit(1)
|
||||||
|
|
||||||
self.channel_logger = channel_logger.ChannelLogger(self.client)
|
return asyncio.get_event_loop().run_until_complete(db.connect(
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
database=name,
|
||||||
|
host=host,
|
||||||
|
port=port
|
||||||
|
))
|
||||||
|
|
||||||
self.proxy = proxy.Proxy(self.client, token, self.channel_logger)
|
|
||||||
|
|
||||||
async def on_error(self, evt, *args, **kwargs):
|
def run():
|
||||||
self.logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
|
pool = connect_to_database()
|
||||||
|
|
||||||
async def on_ready(self):
|
async def create_tables():
|
||||||
self.logger.info("Connected to Discord.")
|
async with pool.acquire() as conn:
|
||||||
self.logger.info("- Account: {}#{}".format(self.client.user.name, self.client.user.discriminator))
|
await db.create_tables(conn)
|
||||||
self.logger.info("- User ID: {}".format(self.client.user.id))
|
|
||||||
self.logger.info("- {} servers".format(len(self.client.servers)))
|
|
||||||
|
|
||||||
# Set playing message
|
asyncio.get_event_loop().run_until_complete(create_tables())
|
||||||
# TODO: change this when merging rewrite-port branch, kwarg game -> activity
|
|
||||||
await self.client.change_presence(game=discord.Game(name="pk;help"))
|
|
||||||
|
|
||||||
async def on_message(self, message):
|
client = discord.Client()
|
||||||
# Ignore bot messages
|
|
||||||
|
logger = channel_logger.ChannelLogger(client)
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_ready():
|
||||||
|
print("PluralKit started.")
|
||||||
|
print("User: {}#{} (ID: {})".format(client.user.name, client.user.discriminator, client.user.id))
|
||||||
|
print("{} servers".format(len(client.guilds)))
|
||||||
|
print("{} shards".format(client.shard_count or 1))
|
||||||
|
|
||||||
|
await client.change_presence(activity=discord.Game(name="pk;help"))
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_message(message: discord.Message):
|
||||||
|
# Ignore messages from bots
|
||||||
if message.author.bot:
|
if message.author.bot:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
# Grab a database connection from the pool
|
||||||
if await self.handle_command_dispatch(message):
|
async with pool.acquire() as conn:
|
||||||
|
# First pass: do command handling
|
||||||
|
did_run_command = await commands.command_dispatch(client, message, conn)
|
||||||
|
if did_run_command:
|
||||||
return
|
return
|
||||||
|
|
||||||
if await self.handle_proxy_dispatch(message):
|
# Second pass: do proxy matching
|
||||||
return
|
await proxy.try_proxy_message(conn, message, logger)
|
||||||
except Exception:
|
|
||||||
await self.log_error_in_channel(message)
|
|
||||||
|
|
||||||
async def on_socket_raw_receive(self, msg):
|
@client.event
|
||||||
# Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
|
async def on_raw_message_delete(payload: discord.RawMessageDeleteEvent):
|
||||||
# we parse socket data manually for the reaction add event
|
async with pool.acquire() as conn:
|
||||||
if isinstance(msg, str):
|
await proxy.handle_deleted_message(conn, client, payload.message_id, None, logger)
|
||||||
try:
|
|
||||||
msg_data = json.loads(msg)
|
|
||||||
if msg_data.get("t") == "MESSAGE_REACTION_ADD":
|
|
||||||
evt_data = msg_data.get("d")
|
|
||||||
if evt_data:
|
|
||||||
user_id = evt_data["user_id"]
|
|
||||||
message_id = evt_data["message_id"]
|
|
||||||
emoji = evt_data["emoji"]["name"]
|
|
||||||
|
|
||||||
async with self.pool.acquire() as conn:
|
@client.event
|
||||||
await self.proxy.handle_reaction(conn, user_id, message_id, emoji)
|
async def on_raw_bulk_message_delete(payload: discord.RawBulkMessageDeleteEvent):
|
||||||
elif msg_data.get("t") == "MESSAGE_DELETE":
|
async with pool.acquire() as conn:
|
||||||
evt_data = msg_data.get("d")
|
for message_id in payload.message_ids:
|
||||||
if evt_data:
|
await proxy.handle_deleted_message(conn, client, message_id, None, logger)
|
||||||
message_id = evt_data["id"]
|
|
||||||
async with self.pool.acquire() as conn:
|
|
||||||
await self.proxy.handle_deletion(conn, message_id)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def handle_command_dispatch(self, message):
|
@client.event
|
||||||
async with self.pool.acquire() as conn:
|
async def on_raw_reaction_add(payload: discord.RawReactionActionEvent):
|
||||||
result = await commands.command_dispatch(self.client, message, conn)
|
if payload.emoji.name == "\u274c": # Red X
|
||||||
return result
|
async with pool.acquire() as conn:
|
||||||
|
await proxy.try_delete_by_reaction(conn, client, payload.message_id, payload.user_id, logger)
|
||||||
|
|
||||||
async def handle_proxy_dispatch(self, message):
|
@client.event
|
||||||
# Try doing proxy parsing
|
async def on_error(event_name, *args, **kwargs):
|
||||||
async with self.pool.acquire() as conn:
|
log_channel_id = os.environ["LOG_CHANNEL"]
|
||||||
return await self.proxy.try_proxy_message(conn, message)
|
if not log_channel_id:
|
||||||
|
|
||||||
async def log_error_in_channel(self, message):
|
|
||||||
channel_id = os.environ["LOG_CHANNEL"]
|
|
||||||
if not channel_id:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
channel = self.client.get_channel(channel_id)
|
log_channel = client.get_channel(int(log_channel_id))
|
||||||
embed = embeds.exception_log(
|
|
||||||
message.content,
|
|
||||||
message.author.name,
|
|
||||||
message.author.discriminator,
|
|
||||||
message.server.id if message.server else None,
|
|
||||||
message.channel.id
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.client.send_message(channel, "```python\n{}```".format(traceback.format_exc()), embed=embed)
|
# If this is a message event, we can attach additional information in an event
|
||||||
|
# ie. username, channel, content, etc
|
||||||
async def run(self):
|
if args and isinstance(args[0], discord.Message):
|
||||||
try:
|
message: discord.Message = args[0]
|
||||||
self.logger.info("Connecting to database...")
|
embed = embeds.exception_log(
|
||||||
self.pool = await db.connect(
|
message.content,
|
||||||
os.environ["DATABASE_USER"],
|
message.author.name,
|
||||||
os.environ["DATABASE_PASS"],
|
message.author.discriminator,
|
||||||
os.environ["DATABASE_NAME"],
|
message.author.id,
|
||||||
os.environ["DATABASE_HOST"],
|
message.guild.id if message.guild else None,
|
||||||
int(os.environ["DATABASE_PORT"])
|
message.channel.id
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# If not, just post the string itself
|
||||||
|
embed = None
|
||||||
|
|
||||||
self.logger.info("Attempting to create tables...")
|
traceback_str = "```python\n{}```".format(traceback.format_exc())
|
||||||
async with self.pool.acquire() as conn:
|
await log_channel.send(content=traceback_str, embed=embed)
|
||||||
await db.create_tables(conn)
|
|
||||||
|
|
||||||
self.logger.info("Connecting to Discord...")
|
bot_token = os.environ["TOKEN"]
|
||||||
await self.client.start(self.token)
|
if not bot_token:
|
||||||
finally:
|
print("No token specified. Please pass a valid Discord bot token in the TOKEN environment variable.",
|
||||||
self.logger.info("Logging out from Discord...")
|
file=sys.stderr)
|
||||||
await self.client.logout()
|
sys.exit(1)
|
||||||
|
|
||||||
|
client.run(bot_token)
|
@ -19,7 +19,7 @@ class ChannelLogger:
|
|||||||
self.logger = logging.getLogger("pluralkit.bot.channel_logger")
|
self.logger = logging.getLogger("pluralkit.bot.channel_logger")
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
async def get_log_channel(self, conn, server_id: str):
|
async def get_log_channel(self, conn, server_id: int):
|
||||||
server_info = await db.get_server_info(conn, server_id)
|
server_info = await db.get_server_info(conn, server_id)
|
||||||
|
|
||||||
if not server_info:
|
if not server_info:
|
||||||
@ -30,21 +30,21 @@ class ChannelLogger:
|
|||||||
if not log_channel:
|
if not log_channel:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self.client.get_channel(str(log_channel))
|
return self.client.get_channel(log_channel)
|
||||||
|
|
||||||
async def send_to_log_channel(self, log_channel: discord.Channel, embed: discord.Embed, text: str = None):
|
async def send_to_log_channel(self, log_channel: discord.TextChannel, embed: discord.Embed, text: str = None):
|
||||||
try:
|
try:
|
||||||
await self.client.send_message(log_channel, embed=embed, content=text)
|
await log_channel.send(content=text, embed=embed)
|
||||||
except discord.Forbidden:
|
except discord.Forbidden:
|
||||||
# TODO: spew big error
|
# TODO: spew big error
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"Did not have permission to send message to logging channel (server={}, channel={})".format(
|
"Did not have permission to send message to logging channel (server={}, channel={})".format(
|
||||||
log_channel.server.id, log_channel.id))
|
log_channel.guild.id, log_channel.id))
|
||||||
|
|
||||||
async def log_message_proxied(self, conn,
|
async def log_message_proxied(self, conn,
|
||||||
server_id: str,
|
server_id: int,
|
||||||
channel_name: str,
|
channel_name: str,
|
||||||
channel_id: str,
|
channel_id: int,
|
||||||
sender_name: str,
|
sender_name: str,
|
||||||
sender_disc: int,
|
sender_disc: int,
|
||||||
sender_id: int,
|
sender_id: int,
|
||||||
@ -56,11 +56,13 @@ class ChannelLogger:
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_image: str,
|
message_image: str,
|
||||||
message_timestamp: datetime,
|
message_timestamp: datetime,
|
||||||
message_id: str):
|
message_id: int):
|
||||||
log_channel = await self.get_log_channel(conn, server_id)
|
log_channel = await self.get_log_channel(conn, server_id)
|
||||||
if not log_channel:
|
if not log_channel:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
message_link = "https://discordapp.com/channels/{}/{}/{}".format(server_id, channel_id, message_id)
|
||||||
|
|
||||||
embed = discord.Embed()
|
embed = discord.Embed()
|
||||||
embed.colour = discord.Colour.blue()
|
embed.colour = discord.Colour.blue()
|
||||||
embed.description = message_text
|
embed.description = message_text
|
||||||
@ -75,11 +77,10 @@ class ChannelLogger:
|
|||||||
if message_image:
|
if message_image:
|
||||||
embed.set_thumbnail(url=message_image)
|
embed.set_thumbnail(url=message_image)
|
||||||
|
|
||||||
message_link = "https://discordapp.com/channels/{}/{}/{}".format(server_id, channel_id, message_id)
|
|
||||||
await self.send_to_log_channel(log_channel, embed, message_link)
|
await self.send_to_log_channel(log_channel, embed, message_link)
|
||||||
|
|
||||||
async def log_message_deleted(self, conn,
|
async def log_message_deleted(self, conn,
|
||||||
server_id: str,
|
server_id: int,
|
||||||
channel_name: str,
|
channel_name: str,
|
||||||
member_name: str,
|
member_name: str,
|
||||||
member_hid: str,
|
member_hid: str,
|
||||||
@ -87,22 +88,17 @@ class ChannelLogger:
|
|||||||
system_name: str,
|
system_name: str,
|
||||||
system_hid: str,
|
system_hid: str,
|
||||||
message_text: str,
|
message_text: str,
|
||||||
message_id: str,
|
message_id: int):
|
||||||
deleted_by_moderator: bool):
|
|
||||||
log_channel = await self.get_log_channel(conn, server_id)
|
log_channel = await self.get_log_channel(conn, server_id)
|
||||||
if not log_channel:
|
if not log_channel:
|
||||||
return
|
return
|
||||||
|
|
||||||
embed = discord.Embed()
|
embed = discord.Embed()
|
||||||
embed.colour = discord.Colour.dark_red()
|
embed.colour = discord.Colour.dark_red()
|
||||||
embed.description = message_text
|
embed.description = message_text or "*(unknown, message deleted by moderator)*"
|
||||||
embed.timestamp = datetime.utcnow()
|
embed.timestamp = datetime.utcnow()
|
||||||
|
|
||||||
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
|
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
|
||||||
embed.set_footer(
|
embed.set_footer(text="System ID: {} | Member ID: {} | Message ID: {}".format(system_hid, member_hid, message_id))
|
||||||
text="System ID: {} | Member ID: {} | Message ID: {} | Deleted by moderator? {}".format(system_hid,
|
|
||||||
member_hid,
|
|
||||||
message_id,
|
|
||||||
"Yes" if deleted_by_moderator else "No"))
|
|
||||||
|
|
||||||
await self.send_to_log_channel(log_channel, embed)
|
await self.send_to_log_channel(log_channel, embed)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import discord
|
import discord
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
@ -69,7 +70,7 @@ class CommandContext:
|
|||||||
def has_next(self) -> bool:
|
def has_next(self) -> bool:
|
||||||
return bool(self.args)
|
return bool(self.args)
|
||||||
|
|
||||||
def pop_str(self, error: CommandError = None) -> str:
|
def pop_str(self, error: CommandError = None) -> Optional[str]:
|
||||||
if not self.args:
|
if not self.args:
|
||||||
if error:
|
if error:
|
||||||
raise error
|
raise error
|
||||||
@ -105,26 +106,27 @@ class CommandContext:
|
|||||||
return self.args
|
return self.args
|
||||||
|
|
||||||
async def reply(self, content=None, embed=None):
|
async def reply(self, content=None, embed=None):
|
||||||
return await self.client.send_message(self.message.channel, content=content, embed=embed)
|
return await self.message.channel.send(content=content, embed=embed)
|
||||||
|
|
||||||
async def confirm_react(self, user: Union[discord.Member, discord.User], message: str):
|
async def confirm_react(self, user: Union[discord.Member, discord.User], message: str):
|
||||||
message = await self.reply(message)
|
message = await self.reply(message)
|
||||||
|
await message.add_reaction("\u2705") # Checkmark
|
||||||
|
await message.add_reaction("\u274c") # Red X
|
||||||
|
|
||||||
await self.client.add_reaction(message, "✅")
|
try:
|
||||||
await self.client.add_reaction(message, "❌")
|
reaction, _ = await self.client.wait_for("reaction_add", check=lambda r, u: u.id == user.id and r.emoji in ["\u2705", "\u274c"], timeout=60.0*5)
|
||||||
|
return reaction.emoji == "\u2705"
|
||||||
reaction = await self.client.wait_for_reaction(emoji=["✅", "❌"], user=user, timeout=60.0*5)
|
except asyncio.TimeoutError:
|
||||||
if not reaction:
|
|
||||||
raise CommandError("Timed out - try again.")
|
raise CommandError("Timed out - try again.")
|
||||||
return reaction.reaction.emoji == "✅"
|
|
||||||
|
|
||||||
async def confirm_text(self, user: discord.Member, channel: discord.Channel, confirm_text: str, message: str):
|
async def confirm_text(self, user: discord.Member, channel: discord.TextChannel, confirm_text: str, message: str):
|
||||||
await self.reply(message)
|
await self.reply(message)
|
||||||
|
|
||||||
message = await self.client.wait_for_message(channel=channel, author=user, timeout=60.0*5)
|
try:
|
||||||
if not message:
|
message = await self.client.wait_for("message", check=lambda m: m.channel.id == channel.id and m.author.id == user.id, timeout=60.0*5)
|
||||||
|
return message.content.lower() == confirm_text.lower()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
raise CommandError("Timed out - try again.")
|
raise CommandError("Timed out - try again.")
|
||||||
return message.content == confirm_text
|
|
||||||
|
|
||||||
|
|
||||||
import pluralkit.bot.commands.import_commands
|
import pluralkit.bot.commands.import_commands
|
||||||
|
@ -8,22 +8,16 @@ logger = logging.getLogger("pluralkit.commands")
|
|||||||
|
|
||||||
|
|
||||||
async def import_tupperware(ctx: CommandContext):
|
async def import_tupperware(ctx: CommandContext):
|
||||||
tupperware_ids = ["431544605209788416", "433916057053560832"] # Main bot instance and Multi-Pals-specific fork
|
tupperware_member = ctx.message.guild.get_member(431544605209788416)
|
||||||
tupperware_members = [ctx.message.server.get_member(bot_id) for bot_id in tupperware_ids if
|
|
||||||
ctx.message.server.get_member(bot_id)]
|
|
||||||
|
|
||||||
# Check if there's any Tupperware bot on the server
|
# Check if there's a Tupperware bot on the server
|
||||||
if not tupperware_members:
|
if not tupperware_member:
|
||||||
return CommandError("This command only works in a server where the Tupperware bot is also present.")
|
return CommandError("This command only works in a server where the Tupperware bot is also present.")
|
||||||
|
|
||||||
# Make sure at least one of the bts have send/read permissions here
|
# Make sure at the bot has send/read permissions here
|
||||||
for bot_member in tupperware_members:
|
channel_permissions = ctx.message.channel.permissions_for(tupperware_member)
|
||||||
channel_permissions = ctx.message.channel.permissions_for(bot_member)
|
if not (channel_permissions.read_messages and channel_permissions.send_messages):
|
||||||
if channel_permissions.read_messages and channel_permissions.send_messages:
|
# If it doesn't, throw error
|
||||||
# If so, break out of the loop
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# If no bots have permission (ie. loop doesn't break), throw error
|
|
||||||
return CommandError("This command only works in a channel where the Tupperware bot has read/send access.")
|
return CommandError("This command only works in a channel where the Tupperware bot has read/send access.")
|
||||||
|
|
||||||
await ctx.reply(
|
await ctx.reply(
|
||||||
@ -31,29 +25,31 @@ async def import_tupperware(ctx: CommandContext):
|
|||||||
|
|
||||||
# Check to make sure the message is sent by Tupperware, and that the Tupperware response actually belongs to the correct user
|
# Check to make sure the message is sent by Tupperware, and that the Tupperware response actually belongs to the correct user
|
||||||
def ensure_account(tw_msg):
|
def ensure_account(tw_msg):
|
||||||
if tw_msg.author not in tupperware_members:
|
if tw_msg.channel.id != ctx.message.channel.id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if tw_msg.author.id != tupperware_member.id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not tw_msg.embeds:
|
if not tw_msg.embeds:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not tw_msg.embeds[0]["title"]:
|
if not tw_msg.embeds[0].title:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return tw_msg.embeds[0]["title"].startswith(
|
return tw_msg.embeds[0].title.startswith(
|
||||||
"{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator))
|
"{}#{}".format(ctx.message.author.name, ctx.message.author.discriminator))
|
||||||
|
|
||||||
tupperware_page_embeds = []
|
tupperware_page_embeds = []
|
||||||
|
|
||||||
tw_msg: discord.Message = await ctx.client.wait_for_message(channel=ctx.message.channel, timeout=60.0 * 5,
|
tw_msg: discord.Message = await ctx.client.wait_for("message", check=ensure_account, timeout=60.0 * 5)
|
||||||
check=ensure_account)
|
|
||||||
if not tw_msg:
|
if not tw_msg:
|
||||||
return CommandError("Tupperware import timed out.")
|
return CommandError("Tupperware import timed out.")
|
||||||
tupperware_page_embeds.append(tw_msg.embeds[0])
|
tupperware_page_embeds.append(tw_msg.embeds[0].to_dict())
|
||||||
|
|
||||||
# Handle Tupperware pagination
|
# Handle Tupperware pagination
|
||||||
def match_pagination():
|
def match_pagination():
|
||||||
pagination_match = re.search(r"\(page (\d+)/(\d+), \d+ total\)", tw_msg.embeds[0]["title"])
|
pagination_match = re.search(r"\(page (\d+)/(\d+), \d+ total\)", tw_msg.embeds[0].title)
|
||||||
if not pagination_match:
|
if not pagination_match:
|
||||||
return None
|
return None
|
||||||
return int(pagination_match.group(1)), int(pagination_match.group(2))
|
return int(pagination_match.group(1)), int(pagination_match.group(2))
|
||||||
@ -72,13 +68,12 @@ async def import_tupperware(ctx: CommandContext):
|
|||||||
new_page, total_pages = match_pagination()
|
new_page, total_pages = match_pagination()
|
||||||
|
|
||||||
# Put the found page in the pages dict
|
# Put the found page in the pages dict
|
||||||
pages_found[new_page] = dict(tw_msg.embeds[0])
|
pages_found[new_page] = tw_msg.embeds[0].to_dict()
|
||||||
|
|
||||||
# If this isn't the same page as last check, edit the status message
|
# If this isn't the same page as last check, edit the status message
|
||||||
if new_page != current_page:
|
if new_page != current_page:
|
||||||
last_found_time = datetime.utcnow()
|
last_found_time = datetime.utcnow()
|
||||||
await ctx.client.edit_message(status_msg,
|
await status_msg.edit(content="Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(
|
||||||
"Multi-page member list found. Please manually scroll through all the pages. Read {}/{} pages.".format(
|
|
||||||
len(pages_found), total_pages))
|
len(pages_found), total_pages))
|
||||||
current_page = new_page
|
current_page = new_page
|
||||||
|
|
||||||
@ -94,7 +89,7 @@ async def import_tupperware(ctx: CommandContext):
|
|||||||
tupperware_page_embeds = list([embed for page, embed in sorted(pages_found.items(), key=lambda x: x[0])])
|
tupperware_page_embeds = list([embed for page, embed in sorted(pages_found.items(), key=lambda x: x[0])])
|
||||||
|
|
||||||
# Also edit the status message to indicate we're now importing, and it may take a while because there's probably a lot of members
|
# Also edit the status message to indicate we're now importing, and it may take a while because there's probably a lot of members
|
||||||
await ctx.client.edit_message(status_msg, "All pages read. Now importing...")
|
await status_msg.edit(content="All pages read. Now importing...")
|
||||||
|
|
||||||
logger.debug("Importing from Tupperware...")
|
logger.debug("Importing from Tupperware...")
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ async def get_message_contents(client: discord.Client, channel_id: int, message_
|
|||||||
channel = client.get_channel(str(channel_id))
|
channel = client.get_channel(str(channel_id))
|
||||||
if channel:
|
if channel:
|
||||||
try:
|
try:
|
||||||
original_message = await client.get_message(channel, str(message_id))
|
original_message = await client.get_channel(channel).get_message(message_id)
|
||||||
return original_message.content or None
|
return original_message.content or None
|
||||||
except (discord.errors.Forbidden, discord.errors.NotFound):
|
except (discord.errors.Forbidden, discord.errors.NotFound):
|
||||||
pass
|
pass
|
||||||
@ -24,19 +24,19 @@ async def message_info(ctx: CommandContext):
|
|||||||
return CommandError("You must pass a valid number as a message ID.", help=help.message_lookup)
|
return CommandError("You must pass a valid number as a message ID.", help=help.message_lookup)
|
||||||
|
|
||||||
# Find the message in the DB
|
# Find the message in the DB
|
||||||
message = await db.get_message(ctx.conn, str(mid))
|
message = await db.get_message(ctx.conn, mid)
|
||||||
if not message:
|
if not message:
|
||||||
raise CommandError("Message with ID '{}' not found.".format(mid))
|
raise CommandError("Message with ID '{}' not found.".format(mid))
|
||||||
|
|
||||||
# Get the original sender of the messages
|
# Get the original sender of the messages
|
||||||
try:
|
try:
|
||||||
original_sender = await ctx.client.get_user_info(str(message.sender))
|
original_sender = await ctx.client.get_user_info(message.sender)
|
||||||
except discord.NotFound:
|
except discord.NotFound:
|
||||||
# Account was since deleted - rare but we're handling it anyway
|
# Account was since deleted - rare but we're handling it anyway
|
||||||
original_sender = None
|
original_sender = None
|
||||||
|
|
||||||
embed = discord.Embed()
|
embed = discord.Embed()
|
||||||
embed.timestamp = discord.utils.snowflake_time(str(mid))
|
embed.timestamp = discord.utils.snowflake_time(mid)
|
||||||
embed.colour = discord.Colour.blue()
|
embed.colour = discord.Colour.blue()
|
||||||
|
|
||||||
if message.system_name:
|
if message.system_name:
|
||||||
@ -55,8 +55,7 @@ async def message_info(ctx: CommandContext):
|
|||||||
embed.add_field(name="Sent by", value=sender_name)
|
embed.add_field(name="Sent by", value=sender_name)
|
||||||
|
|
||||||
message_content = await get_message_contents(ctx.client, message.channel, message.mid)
|
message_content = await get_message_contents(ctx.client, message.channel, message.mid)
|
||||||
if message_content:
|
embed.description = message_content or "(unknown, message deleted)"
|
||||||
embed.description = message_content
|
|
||||||
|
|
||||||
embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty)
|
embed.set_author(name=message.name, icon_url=message.avatar_url or discord.Embed.Empty)
|
||||||
|
|
||||||
|
@ -88,4 +88,4 @@ async def export(ctx: CommandContext):
|
|||||||
}
|
}
|
||||||
|
|
||||||
f = io.BytesIO(json.dumps(data).encode("utf-8"))
|
f = io.BytesIO(json.dumps(data).encode("utf-8"))
|
||||||
await ctx.client.send_file(ctx.message.channel, f, filename="system.json")
|
await ctx.message.channel.send(content="Here you go!", file=discord.File(fp=f, filename="system.json"))
|
||||||
|
@ -4,10 +4,13 @@ logger = logging.getLogger("pluralkit.commands")
|
|||||||
|
|
||||||
|
|
||||||
async def set_log(ctx: CommandContext):
|
async def set_log(ctx: CommandContext):
|
||||||
if not ctx.message.author.server_permissions.administrator:
|
if not ctx.message.author.guild_permissions.administrator:
|
||||||
return CommandError("You must be a server administrator to use this command.")
|
return CommandError("You must be a server administrator to use this command.")
|
||||||
|
|
||||||
server = ctx.message.server
|
server = ctx.message.guild
|
||||||
|
if not server:
|
||||||
|
return CommandError("This command can not be run in a DM.")
|
||||||
|
|
||||||
if not ctx.has_next():
|
if not ctx.has_next():
|
||||||
channel_id = None
|
channel_id = None
|
||||||
else:
|
else:
|
||||||
|
@ -38,13 +38,13 @@ def status(text: str) -> discord.Embed:
|
|||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
|
||||||
def exception_log(message_content, author_name, author_discriminator, server_id, channel_id) -> discord.Embed:
|
def exception_log(message_content, author_name, author_discriminator, author_id, server_id, channel_id) -> discord.Embed:
|
||||||
embed = discord.Embed()
|
embed = discord.Embed()
|
||||||
embed.colour = discord.Colour.dark_red()
|
embed.colour = discord.Colour.dark_red()
|
||||||
embed.title = message_content
|
embed.title = message_content
|
||||||
|
|
||||||
embed.set_footer(text="Sender: {}#{} | Server: {} | Channel: {}".format(
|
embed.set_footer(text="Sender: {}#{} ({}) | Server: {} | Channel: {}".format(
|
||||||
author_name, author_discriminator,
|
author_name, author_discriminator, author_id,
|
||||||
server_id if server_id else "(DMs)",
|
server_id if server_id else "(DMs)",
|
||||||
channel_id
|
channel_id
|
||||||
))
|
))
|
||||||
@ -72,7 +72,7 @@ async def system_card(conn, client: discord.Client, system: System) -> discord.E
|
|||||||
|
|
||||||
account_names = []
|
account_names = []
|
||||||
for account_id in await system.get_linked_account_ids(conn):
|
for account_id in await system.get_linked_account_ids(conn):
|
||||||
account = await client.get_user_info(str(account_id))
|
account = await client.get_user_info(account_id)
|
||||||
account_names.append("{}#{}".format(account.name, account.discriminator))
|
account_names.append("{}#{}".format(account.name, account.discriminator))
|
||||||
|
|
||||||
card.add_field(name="Linked accounts", value="\n".join(account_names))
|
card.add_field(name="Linked accounts", value="\n".join(account_names))
|
||||||
@ -82,29 +82,31 @@ async def system_card(conn, client: discord.Client, system: System) -> discord.E
|
|||||||
value=system.description, inline=False)
|
value=system.description, inline=False)
|
||||||
|
|
||||||
# Get names of all members
|
# Get names of all members
|
||||||
member_texts = []
|
all_members = await system.get_members(conn)
|
||||||
for member in await system.get_members(conn):
|
if all_members:
|
||||||
member_texts.append("{} (`{}`)".format(escape(member.name), member.hid))
|
member_texts = []
|
||||||
|
for member in all_members:
|
||||||
|
member_texts.append("{} (`{}`)".format(escape(member.name), member.hid))
|
||||||
|
|
||||||
# Interim solution for pagination of large systems
|
# Interim solution for pagination of large systems
|
||||||
# Previously a lot of systems would hit the 1024 character limit and thus break the message
|
# Previously a lot of systems would hit the 1024 character limit and thus break the message
|
||||||
# This splits large system lists into multiple embed fields
|
# This splits large system lists into multiple embed fields
|
||||||
# The 6000 character total limit will still apply here but this sort of pushes the problem until I find a better fix
|
# The 6000 character total limit will still apply here but this sort of pushes the problem until I find a better fix
|
||||||
pages = [""]
|
pages = [""]
|
||||||
for member in member_texts:
|
for member in member_texts:
|
||||||
last_page = pages[-1]
|
last_page = pages[-1]
|
||||||
new_page = last_page + "\n" + member
|
new_page = last_page + "\n" + member
|
||||||
|
|
||||||
if len(new_page) >= 1024:
|
if len(new_page) >= 1024:
|
||||||
pages.append(member)
|
pages.append(member)
|
||||||
else:
|
else:
|
||||||
pages[-1] = new_page
|
pages[-1] = new_page
|
||||||
|
|
||||||
for index, page in enumerate(pages):
|
for index, page in enumerate(pages):
|
||||||
field_name = "Members"
|
field_name = "Members"
|
||||||
if index >= 1:
|
if index >= 1:
|
||||||
field_name = "Members (part {})".format(index + 1)
|
field_name = "Members (part {})".format(index + 1)
|
||||||
card.add_field(name=field_name, value=page, inline=False)
|
card.add_field(name=field_name, value=page, inline=False)
|
||||||
|
|
||||||
card.set_footer(text="System ID: {}".format(system.hid))
|
card.set_footer(text="System ID: {}".format(system.hid))
|
||||||
return card
|
return card
|
||||||
|
@ -1,17 +1,17 @@
|
|||||||
import ciso8601
|
from io import BytesIO
|
||||||
|
|
||||||
|
import discord
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import discord
|
|
||||||
|
|
||||||
from pluralkit import db
|
from pluralkit import db
|
||||||
from pluralkit.bot import channel_logger, utils, embeds
|
from pluralkit.bot import utils, channel_logger
|
||||||
|
from pluralkit.bot.channel_logger import ChannelLogger
|
||||||
|
|
||||||
logger = logging.getLogger("pluralkit.bot.proxy")
|
logger = logging.getLogger("pluralkit.bot.proxy")
|
||||||
|
|
||||||
|
|
||||||
def extract_leading_mentions(message_text):
|
def extract_leading_mentions(message_text):
|
||||||
# This regex matches one or more mentions at the start of a message, separated by any amount of spaces
|
# This regex matches one or more mentions at the start of a message, separated by any amount of spaces
|
||||||
match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message_text)
|
match = re.match(r"^(<(@|@!|#|@&|a?:\w+:)\d+>\s*)+", message_text)
|
||||||
@ -34,7 +34,9 @@ def match_member_proxy_tags(member: db.ProxyMember, message_text: str):
|
|||||||
# Ignore mentions at the very start of the message, and match proxy tags after those
|
# Ignore mentions at the very start of the message, and match proxy tags after those
|
||||||
message_text, leading_mentions = extract_leading_mentions(message_text)
|
message_text, leading_mentions = extract_leading_mentions(message_text)
|
||||||
|
|
||||||
logger.debug("Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions, prefix, suffix))
|
logger.debug(
|
||||||
|
"Matching text '{}' and leading mentions '{}' to proxy tags {}text{}".format(message_text, leading_mentions,
|
||||||
|
prefix, suffix))
|
||||||
|
|
||||||
if message_text.startswith(member.prefix or "") and message_text.endswith(member.suffix or ""):
|
if message_text.startswith(member.prefix or "") and message_text.endswith(member.suffix or ""):
|
||||||
prefix_length = len(prefix)
|
prefix_length = len(prefix)
|
||||||
@ -59,243 +61,162 @@ def match_proxy_tags(members: List[db.ProxyMember], message_text: str):
|
|||||||
|
|
||||||
for member in members:
|
for member in members:
|
||||||
match = match_member_proxy_tags(member, message_text)
|
match = match_member_proxy_tags(member, message_text)
|
||||||
if match is not None: # Using "is not None" because an empty string is OK here too
|
if match is not None: # Using "is not None" because an empty string is OK here too
|
||||||
logger.debug("Matched member {} with inner text '{}'".format(member.hid, match))
|
logger.debug("Matched member {} with inner text '{}'".format(member.hid, match))
|
||||||
return member, match
|
return member, match
|
||||||
|
|
||||||
|
|
||||||
def get_message_attachment_url(message: discord.Message):
|
async def get_or_create_webhook_for_channel(conn, channel: discord.TextChannel):
|
||||||
|
# First, check if we have one saved in the DB
|
||||||
|
webhook_from_db = await db.get_webhook(conn, channel.id)
|
||||||
|
if webhook_from_db:
|
||||||
|
webhook_id, webhook_token = webhook_from_db
|
||||||
|
|
||||||
|
session = channel._state.http._session
|
||||||
|
hook = discord.Webhook.partial(webhook_id, webhook_token, adapter=discord.AsyncWebhookAdapter(session))
|
||||||
|
|
||||||
|
# Workaround for https://github.com/Rapptz/discord.py/issues/1242
|
||||||
|
hook._adapter.store_user = hook._adapter._store_user
|
||||||
|
return hook
|
||||||
|
|
||||||
|
# If not, we create one and save it
|
||||||
|
created_webhook = await channel.create_webhook(name="PluralKit Proxy Webhook")
|
||||||
|
await db.add_webhook(conn, channel.id, created_webhook.id, created_webhook.token)
|
||||||
|
return created_webhook
|
||||||
|
|
||||||
|
|
||||||
|
async def make_attachment_file(message: discord.Message):
|
||||||
if not message.attachments:
|
if not message.attachments:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
attachment = message.attachments[0]
|
first_attachment = message.attachments[0]
|
||||||
if "proxy_url" in attachment:
|
|
||||||
return attachment["proxy_url"]
|
|
||||||
|
|
||||||
if "url" in attachment:
|
# Copy the file data to the buffer
|
||||||
return attachment["url"]
|
# TODO: do this without buffering... somehow
|
||||||
|
bio = BytesIO()
|
||||||
|
await first_attachment.save(bio)
|
||||||
|
|
||||||
|
return discord.File(bio, first_attachment.filename)
|
||||||
|
|
||||||
|
|
||||||
# TODO: possibly move this to bot __init__ so commands can access it too
|
async def do_proxy_message(conn, original_message: discord.Message, proxy_member: db.ProxyMember,
|
||||||
class WebhookPermissionError(Exception):
|
inner_text: str, logger: ChannelLogger):
|
||||||
pass
|
# Send the message through the webhook
|
||||||
|
webhook = await get_or_create_webhook_for_channel(conn, original_message.channel)
|
||||||
|
|
||||||
|
try:
|
||||||
|
sent_message = await webhook.send(
|
||||||
|
content=inner_text,
|
||||||
|
username="{} {}".format(proxy_member.name, proxy_member.tag),
|
||||||
|
avatar_url=proxy_member.avatar_url,
|
||||||
|
file=await make_attachment_file(original_message),
|
||||||
|
wait=True
|
||||||
|
)
|
||||||
|
except discord.NotFound:
|
||||||
|
# The webhook we got from the DB doesn't actually exist
|
||||||
|
# If we delete it from the DB then call the function again, it'll re-create one for us
|
||||||
|
await db.delete_webhook(conn, original_message.channel.id)
|
||||||
|
await do_proxy_message(conn, original_message, proxy_member, inner_text, logger)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Save the proxied message in the database
|
||||||
|
await db.add_message(conn, sent_message.id, original_message.channel.id, proxy_member.id,
|
||||||
|
original_message.author.id)
|
||||||
|
|
||||||
|
await logger.log_message_proxied(
|
||||||
|
conn,
|
||||||
|
original_message.channel.guild.id,
|
||||||
|
original_message.channel.name,
|
||||||
|
original_message.channel.id,
|
||||||
|
original_message.author.name,
|
||||||
|
original_message.author.discriminator,
|
||||||
|
original_message.author.id,
|
||||||
|
proxy_member.name,
|
||||||
|
proxy_member.hid,
|
||||||
|
proxy_member.avatar_url,
|
||||||
|
proxy_member.system_name,
|
||||||
|
proxy_member.system_hid,
|
||||||
|
inner_text,
|
||||||
|
sent_message.attachments[0].url if sent_message.attachments else None,
|
||||||
|
sent_message.created_at,
|
||||||
|
sent_message.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# And finally, gotta delete the original.
|
||||||
|
await original_message.delete()
|
||||||
|
|
||||||
|
|
||||||
class DeletionPermissionError(Exception):
|
async def try_proxy_message(conn, message: discord.Message, logger: ChannelLogger) -> bool:
|
||||||
pass
|
# Don't bother proxying in DMs with the bot
|
||||||
|
if isinstance(message.channel, discord.abc.PrivateChannel):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Return every member associated with the account
|
||||||
|
members = await db.get_members_by_account(conn, message.author.id)
|
||||||
|
proxy_match = match_proxy_tags(members, message.content)
|
||||||
|
if not proxy_match:
|
||||||
|
# No proxy tags match here, we done
|
||||||
|
return False
|
||||||
|
|
||||||
|
member, inner_text = proxy_match
|
||||||
|
|
||||||
|
# Sanitize inner text for @everyones and such
|
||||||
|
inner_text = utils.sanitize(inner_text)
|
||||||
|
|
||||||
|
# If we don't have an inner text OR an attachment, we cancel because the hook can't send that
|
||||||
|
# Strip so it counts a string of solely spaces as blank too
|
||||||
|
if not inner_text.strip() and not message.attachments:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# So, we now have enough information to successfully proxy a message
|
||||||
|
async with conn.transaction():
|
||||||
|
await do_proxy_message(conn, message, member, inner_text, logger)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class Proxy:
|
async def handle_deleted_message(conn, client: discord.Client, message_id: int,
|
||||||
def __init__(self, client: discord.Client, token: str, logger: channel_logger.ChannelLogger):
|
message_content: Optional[str], logger: channel_logger.ChannelLogger) -> bool:
|
||||||
self.logger = logging.getLogger("pluralkit.bot.proxy")
|
msg = await db.get_message(conn, message_id)
|
||||||
self.session = aiohttp.ClientSession()
|
if not msg:
|
||||||
self.client = client
|
return False
|
||||||
self.token = token
|
|
||||||
self.channel_logger = logger
|
|
||||||
|
|
||||||
async def save_channel_webhook(self, conn, channel: discord.Channel, id: str, token: str) -> (str, str):
|
channel = client.get_channel(msg.channel)
|
||||||
await db.add_webhook(conn, channel.id, id, token)
|
if not channel:
|
||||||
return id, token
|
# Weird edge case, but channel *could* be deleted at this point (can't think of any scenarios it would be tho)
|
||||||
|
return False
|
||||||
|
|
||||||
async def create_and_add_channel_webhook(self, conn, channel: discord.Channel) -> (str, str):
|
await db.delete_message(conn, message_id)
|
||||||
# This method is only called if there's no webhook found in the DB (and hopefully within a transaction)
|
await logger.log_message_deleted(
|
||||||
# No need to worry about error handling if there's a DB conflict (which will throw an exception because DB constraints)
|
conn,
|
||||||
req_headers = {"Authorization": "Bot {}".format(self.token)}
|
channel.guild.id,
|
||||||
|
channel.name,
|
||||||
|
msg.name,
|
||||||
|
msg.hid,
|
||||||
|
msg.avatar_url,
|
||||||
|
msg.system_name,
|
||||||
|
msg.system_hid,
|
||||||
|
message_content,
|
||||||
|
message_id
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
# First, check if there's already a webhook belonging to the bot
|
|
||||||
async with self.session.get("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
|
|
||||||
headers=req_headers) as resp:
|
|
||||||
if resp.status == 200:
|
|
||||||
webhooks = await resp.json()
|
|
||||||
for webhook in webhooks:
|
|
||||||
if webhook["user"]["id"] == self.client.user.id:
|
|
||||||
# This webhook belongs to us, we can use that, return it and save it
|
|
||||||
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
|
|
||||||
elif resp.status == 403:
|
|
||||||
self.logger.warning(
|
|
||||||
"Did not have permission to fetch webhook list (server={}, channel={})".format(channel.server.id,
|
|
||||||
channel.id))
|
|
||||||
raise WebhookPermissionError()
|
|
||||||
else:
|
|
||||||
raise discord.HTTPException(resp, await resp.text())
|
|
||||||
|
|
||||||
# Then, try submitting a new one
|
async def try_delete_by_reaction(conn, client: discord.Client, message_id: int, reaction_user: int,
|
||||||
req_data = {"name": "PluralKit Proxy Webhook"}
|
logger: channel_logger.ChannelLogger) -> bool:
|
||||||
async with self.session.post("https://discordapp.com/api/v6/channels/{}/webhooks".format(channel.id),
|
# Find the message by the given message id or reaction user
|
||||||
json=req_data, headers=req_headers) as resp:
|
msg = await db.get_message_by_sender_and_id(conn, message_id, reaction_user)
|
||||||
if resp.status == 200:
|
if not msg:
|
||||||
webhook = await resp.json()
|
# Either the wrong user reacted or the message isn't a proxy message
|
||||||
return await self.save_channel_webhook(conn, channel, webhook["id"], webhook["token"])
|
# In either case - not our problem
|
||||||
elif resp.status == 403:
|
return False
|
||||||
self.logger.warning(
|
|
||||||
"Did not have permission to create webhook (server={}, channel={})".format(channel.server.id,
|
|
||||||
channel.id))
|
|
||||||
raise WebhookPermissionError()
|
|
||||||
else:
|
|
||||||
raise discord.HTTPException(resp, await resp.text())
|
|
||||||
|
|
||||||
# Should not be reached without an exception being thrown
|
# Find the original message
|
||||||
|
original_message = await client.get_channel(msg.channel).get_message(message_id)
|
||||||
|
if not original_message:
|
||||||
|
# Message got deleted, possibly race condition, eh
|
||||||
|
return False
|
||||||
|
|
||||||
async def get_webhook_for_channel(self, conn, channel: discord.Channel):
|
# Then delete the original message
|
||||||
async with conn.transaction():
|
await original_message.delete()
|
||||||
hook_match = await db.get_webhook(conn, channel.id)
|
|
||||||
if not hook_match:
|
|
||||||
# We don't have a webhook, create/add one
|
|
||||||
return await self.create_and_add_channel_webhook(conn, channel)
|
|
||||||
else:
|
|
||||||
return hook_match
|
|
||||||
|
|
||||||
async def do_proxy_message(self, conn, member: db.ProxyMember, original_message: discord.Message, text: str,
|
await handle_deleted_message(conn, client, message_id, original_message.content, logger)
|
||||||
attachment_url: str, has_already_retried=False):
|
|
||||||
hook_id, hook_token = await self.get_webhook_for_channel(conn, original_message.channel)
|
|
||||||
|
|
||||||
form_data = aiohttp.FormData()
|
|
||||||
form_data.add_field("username", "{} {}".format(member.name, member.tag or "").strip())
|
|
||||||
|
|
||||||
if text:
|
|
||||||
form_data.add_field("content", text)
|
|
||||||
|
|
||||||
if attachment_url:
|
|
||||||
attachment_resp = await self.session.get(attachment_url)
|
|
||||||
form_data.add_field("file", attachment_resp.content, content_type=attachment_resp.content_type,
|
|
||||||
filename=attachment_resp.url.name)
|
|
||||||
|
|
||||||
if member.avatar_url:
|
|
||||||
form_data.add_field("avatar_url", member.avatar_url)
|
|
||||||
|
|
||||||
async with self.session.post(
|
|
||||||
"https://discordapp.com/api/v6/webhooks/{}/{}?wait=true".format(hook_id, hook_token),
|
|
||||||
data=form_data) as resp:
|
|
||||||
if resp.status == 200:
|
|
||||||
message = await resp.json()
|
|
||||||
|
|
||||||
await db.add_message(conn, message["id"], message["channel_id"], member.id, original_message.author.id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.client.delete_message(original_message)
|
|
||||||
except discord.Forbidden:
|
|
||||||
self.logger.warning(
|
|
||||||
"Did not have permission to delete original message (server={}, channel={})".format(
|
|
||||||
original_message.server.id, original_message.channel.id))
|
|
||||||
raise DeletionPermissionError()
|
|
||||||
except discord.NotFound:
|
|
||||||
self.logger.warning("Tried to delete message when proxying, but message was already gone (server={}, channel={})".format(original_message.server.id, original_message.channel.id))
|
|
||||||
|
|
||||||
message_image = None
|
|
||||||
if message["attachments"]:
|
|
||||||
first_attachment = message["attachments"][0]
|
|
||||||
if "width" in first_attachment and "height" in first_attachment:
|
|
||||||
# Only log attachments that are actually images
|
|
||||||
message_image = first_attachment["url"]
|
|
||||||
|
|
||||||
await self.channel_logger.log_message_proxied(conn,
|
|
||||||
server_id=original_message.server.id,
|
|
||||||
channel_name=original_message.channel.name,
|
|
||||||
channel_id=original_message.channel.id,
|
|
||||||
sender_name=original_message.author.name,
|
|
||||||
sender_disc=original_message.author.discriminator,
|
|
||||||
sender_id=original_message.author.id,
|
|
||||||
member_name=member.name,
|
|
||||||
member_hid=member.hid,
|
|
||||||
member_avatar_url=member.avatar_url,
|
|
||||||
system_name=member.system_name,
|
|
||||||
system_hid=member.system_hid,
|
|
||||||
message_text=text,
|
|
||||||
message_image=message_image,
|
|
||||||
message_timestamp=ciso8601.parse_datetime(
|
|
||||||
message["timestamp"]),
|
|
||||||
message_id=message["id"])
|
|
||||||
elif resp.status == 404 and not has_already_retried:
|
|
||||||
# Webhook doesn't exist. Delete it from the DB, create, and add a new one
|
|
||||||
self.logger.warning("Webhook registered in DB doesn't exist, deleting hook from DB, re-adding, and trying again (channel={}, hook={})".format(original_message.channel.id, hook_id))
|
|
||||||
await db.delete_webhook(conn, original_message.channel.id)
|
|
||||||
await self.create_and_add_channel_webhook(conn, original_message.channel)
|
|
||||||
|
|
||||||
# Then try again all over, making sure to not retry again and go in a loop should it continually fail
|
|
||||||
return await self.do_proxy_message(conn, member, original_message, text, attachment_url, has_already_retried=True)
|
|
||||||
else:
|
|
||||||
raise discord.HTTPException(resp, await resp.text())
|
|
||||||
|
|
||||||
async def try_proxy_message(self, conn, message: discord.Message):
|
|
||||||
# Can't proxy in DMs, webhook creation will explode
|
|
||||||
if message.channel.is_private:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Big fat query to find every member associated with this account
|
|
||||||
# Returned member object has a few more keys (system tag, for example)
|
|
||||||
members = await db.get_members_by_account(conn, account_id=message.author.id)
|
|
||||||
|
|
||||||
match = match_proxy_tags(members, message.content)
|
|
||||||
if not match:
|
|
||||||
return False
|
|
||||||
|
|
||||||
member, text = match
|
|
||||||
attachment_url = get_message_attachment_url(message)
|
|
||||||
|
|
||||||
# Can't proxy a message with no text AND no attachment
|
|
||||||
if not text and not attachment_url:
|
|
||||||
self.logger.debug("Skipping message because of no text and no attachment")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Remember to sanitize the text (remove @everyones and such)
|
|
||||||
text = utils.sanitize(text)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with conn.transaction():
|
|
||||||
await self.do_proxy_message(conn, member, message, text=text, attachment_url=attachment_url)
|
|
||||||
except WebhookPermissionError:
|
|
||||||
embed = embeds.error("PluralKit does not have permission to manage webhooks for this channel. Contact your local server administrator to fix this.")
|
|
||||||
await self.client.send_message(message.channel, embed=embed)
|
|
||||||
except DeletionPermissionError:
|
|
||||||
embed = embeds.error("PluralKit does not have permission to delete messages in this channel. Contact your local server administrator to fix this.")
|
|
||||||
await self.client.send_message(message.channel, embed=embed)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def try_delete_message(self, conn, message_id: str, check_user_id: Optional[str], delete_message: bool, deleted_by_moderator: bool):
|
|
||||||
async with conn.transaction():
|
|
||||||
# Find the message in the DB, and make sure it's sent by the user (if we need to check)
|
|
||||||
if check_user_id:
|
|
||||||
db_message = await db.get_message_by_sender_and_id(conn, message_id=message_id, sender_id=check_user_id)
|
|
||||||
else:
|
|
||||||
db_message = await db.get_message(conn, message_id=message_id)
|
|
||||||
|
|
||||||
if db_message:
|
|
||||||
self.logger.debug("Deleting message {}".format(message_id))
|
|
||||||
channel = self.client.get_channel(str(db_message.channel))
|
|
||||||
|
|
||||||
# If we should also delete the actual message, do that
|
|
||||||
if delete_message:
|
|
||||||
message = await self.client.get_message(channel, message_id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.client.delete_message(message)
|
|
||||||
except discord.Forbidden:
|
|
||||||
self.logger.warning(
|
|
||||||
"Did not have permission to remove message, aborting deletion (server={}, channel={})".format(
|
|
||||||
channel.server.id, channel.id))
|
|
||||||
return
|
|
||||||
|
|
||||||
# Remove it from the DB
|
|
||||||
await db.delete_message(conn, message_id)
|
|
||||||
|
|
||||||
# Then log deletion to logging channel
|
|
||||||
await self.channel_logger.log_message_deleted(conn,
|
|
||||||
server_id=channel.server.id,
|
|
||||||
channel_name=channel.name,
|
|
||||||
member_name=db_message.name,
|
|
||||||
member_hid=db_message.hid,
|
|
||||||
member_avatar_url=db_message.avatar_url,
|
|
||||||
system_name=db_message.system_name,
|
|
||||||
system_hid=db_message.system_hid,
|
|
||||||
message_text=db_message.content,
|
|
||||||
message_id=message_id,
|
|
||||||
deleted_by_moderator=deleted_by_moderator)
|
|
||||||
|
|
||||||
async def handle_reaction(self, conn, user_id: str, message_id: str, emoji: str):
|
|
||||||
if emoji == "❌":
|
|
||||||
await self.try_delete_message(conn, message_id, check_user_id=user_id, delete_message=True, deleted_by_moderator=False)
|
|
||||||
|
|
||||||
async def handle_deletion(self, conn, message_id: str):
|
|
||||||
# Don't delete the message, it's already gone at this point, just handle DB deletion and logging
|
|
||||||
await self.try_delete_message(conn, message_id, check_user_id=None, delete_message=False, deleted_by_moderator=True)
|
|
@ -2,6 +2,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from pluralkit import db
|
from pluralkit import db
|
||||||
from pluralkit.system import System
|
from pluralkit.system import System
|
||||||
@ -20,28 +21,28 @@ def bounds_check_member_name(new_name, system_tag):
|
|||||||
if len("{} {}".format(new_name, system_tag)) > 32:
|
if len("{} {}".format(new_name, system_tag)) > 32:
|
||||||
return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(system_tag)
|
return "This name, combined with the system tag ({}), would exceed the maximum length of 32 characters. Please reduce the length of the tag, or use a shorter name.".format(system_tag)
|
||||||
|
|
||||||
async def parse_mention(client: discord.Client, mention: str) -> discord.User:
|
async def parse_mention(client: discord.Client, mention: str) -> Optional[discord.User]:
|
||||||
# First try matching mention format
|
# First try matching mention format
|
||||||
match = re.fullmatch("<@!?(\\d+)>", mention)
|
match = re.fullmatch("<@!?(\\d+)>", mention)
|
||||||
if match:
|
if match:
|
||||||
try:
|
try:
|
||||||
return await client.get_user_info(match.group(1))
|
return await client.get_user_info(int(match.group(1)))
|
||||||
except discord.NotFound:
|
except discord.NotFound:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Then try with just ID
|
# Then try with just ID
|
||||||
try:
|
try:
|
||||||
return await client.get_user_info(str(int(mention)))
|
return await client.get_user_info(int(mention))
|
||||||
except (ValueError, discord.NotFound):
|
except (ValueError, discord.NotFound):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def parse_channel_mention(mention: str, server: discord.Server) -> discord.Channel:
|
def parse_channel_mention(mention: str, server: discord.Guild) -> Optional[discord.TextChannel]:
|
||||||
match = re.fullmatch("<#(\\d+)>", mention)
|
match = re.fullmatch("<#(\\d+)>", mention)
|
||||||
if match:
|
if match:
|
||||||
return server.get_channel(match.group(1))
|
return server.get_channel(int(match.group(1)))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return server.get_channel(str(int(mention)))
|
return server.get_channel(int(mention))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -62,17 +62,17 @@ async def delete_member(conn, member_id: int):
|
|||||||
|
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def link_account(conn, system_id: int, account_id: str):
|
async def link_account(conn, system_id: int, account_id: int):
|
||||||
logger.debug("Linking account (account_id={}, system_id={})".format(
|
logger.debug("Linking account (account_id={}, system_id={})".format(
|
||||||
account_id, system_id))
|
account_id, system_id))
|
||||||
await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id)
|
await conn.execute("insert into accounts (uid, system) values ($1, $2)", account_id, system_id)
|
||||||
|
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def unlink_account(conn, system_id: int, account_id: str):
|
async def unlink_account(conn, system_id: int, account_id: int):
|
||||||
logger.debug("Unlinking account (account_id={}, system_id={})".format(
|
logger.debug("Unlinking account (account_id={}, system_id={})".format(
|
||||||
account_id, system_id))
|
account_id, system_id))
|
||||||
await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id)
|
await conn.execute("delete from accounts where uid = $1 and system = $2", account_id, system_id)
|
||||||
|
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
@ -81,8 +81,8 @@ async def get_linked_accounts(conn, system_id: int) -> List[int]:
|
|||||||
|
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_system_by_account(conn, account_id: str) -> System:
|
async def get_system_by_account(conn, account_id: int) -> System:
|
||||||
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id))
|
row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", account_id)
|
||||||
return System(**row) if row else None
|
return System(**row) if row else None
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
@ -151,26 +151,26 @@ async def get_members_exceeding(conn, system_id: int, length: int) -> List[Membe
|
|||||||
|
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_webhook(conn, channel_id: str) -> (str, str):
|
async def get_webhook(conn, channel_id: int) -> (str, str):
|
||||||
row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", int(channel_id))
|
row = await conn.fetchrow("select webhook, token from webhooks where channel = $1", channel_id)
|
||||||
return (str(row["webhook"]), row["token"]) if row else None
|
return (str(row["webhook"]), row["token"]) if row else None
|
||||||
|
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str):
|
async def add_webhook(conn, channel_id: int, webhook_id: int, webhook_token: str):
|
||||||
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(
|
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(
|
||||||
channel_id, webhook_id, webhook_token))
|
channel_id, webhook_id, webhook_token))
|
||||||
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token)
|
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", channel_id, webhook_id, webhook_token)
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def delete_webhook(conn, channel_id: str):
|
async def delete_webhook(conn, channel_id: int):
|
||||||
await conn.execute("delete from webhooks where channel = $1", int(channel_id))
|
await conn.execute("delete from webhooks where channel = $1", channel_id)
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str):
|
async def add_message(conn, message_id: int, channel_id: int, member_id: int, sender_id: int):
|
||||||
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(
|
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(
|
||||||
message_id, channel_id, member_id, sender_id))
|
message_id, channel_id, member_id, sender_id))
|
||||||
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id))
|
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", message_id, channel_id, member_id, sender_id)
|
||||||
|
|
||||||
class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "color", "name", "avatar_url", "tag", "system_name", "system_hid"])):
|
class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "color", "name", "avatar_url", "tag", "system_name", "system_hid"])):
|
||||||
id: int
|
id: int
|
||||||
@ -185,7 +185,7 @@ class ProxyMember(namedtuple("ProxyMember", ["id", "hid", "prefix", "suffix", "c
|
|||||||
system_hid: str
|
system_hid: str
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]:
|
async def get_members_by_account(conn, account_id: int) -> List[ProxyMember]:
|
||||||
# Returns a "chimera" object
|
# Returns a "chimera" object
|
||||||
rows = await conn.fetch("""select
|
rows = await conn.fetch("""select
|
||||||
members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url,
|
members.id, members.hid, members.prefix, members.suffix, members.color, members.name, members.avatar_url,
|
||||||
@ -195,7 +195,7 @@ async def get_members_by_account(conn, account_id: str) -> List[ProxyMember]:
|
|||||||
where
|
where
|
||||||
accounts.uid = $1
|
accounts.uid = $1
|
||||||
and systems.id = accounts.system
|
and systems.id = accounts.system
|
||||||
and members.system = systems.id""", int(account_id))
|
and members.system = systems.id""", account_id)
|
||||||
return [ProxyMember(**row) for row in rows]
|
return [ProxyMember(**row) for row in rows]
|
||||||
|
|
||||||
class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "sender", "name", "hid", "avatar_url", "system_name", "system_hid"])):
|
class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "sender", "name", "hid", "avatar_url", "system_name", "system_hid"])):
|
||||||
@ -220,7 +220,7 @@ class MessageInfo(namedtuple("MemberInfo", ["mid", "channel", "member", "sender"
|
|||||||
}
|
}
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) -> MessageInfo:
|
async def get_message_by_sender_and_id(conn, message_id: int, sender_id: int) -> MessageInfo:
|
||||||
row = await conn.fetchrow("""select
|
row = await conn.fetchrow("""select
|
||||||
messages.*,
|
messages.*,
|
||||||
members.name, members.hid, members.avatar_url,
|
members.name, members.hid, members.avatar_url,
|
||||||
@ -231,12 +231,12 @@ async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str) ->
|
|||||||
messages.member = members.id
|
messages.member = members.id
|
||||||
and members.system = systems.id
|
and members.system = systems.id
|
||||||
and mid = $1
|
and mid = $1
|
||||||
and sender = $2""", int(message_id), int(sender_id))
|
and sender = $2""", message_id, sender_id)
|
||||||
return MessageInfo(**row) if row else None
|
return MessageInfo(**row) if row else None
|
||||||
|
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_message(conn, message_id: str) -> MessageInfo:
|
async def get_message(conn, message_id: int) -> MessageInfo:
|
||||||
row = await conn.fetchrow("""select
|
row = await conn.fetchrow("""select
|
||||||
messages.*,
|
messages.*,
|
||||||
members.name, members.hid, members.avatar_url,
|
members.name, members.hid, members.avatar_url,
|
||||||
@ -246,14 +246,14 @@ async def get_message(conn, message_id: str) -> MessageInfo:
|
|||||||
where
|
where
|
||||||
messages.member = members.id
|
messages.member = members.id
|
||||||
and members.system = systems.id
|
and members.system = systems.id
|
||||||
and mid = $1""", int(message_id))
|
and mid = $1""", message_id)
|
||||||
return MessageInfo(**row) if row else None
|
return MessageInfo(**row) if row else None
|
||||||
|
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def delete_message(conn, message_id: str):
|
async def delete_message(conn, message_id: int):
|
||||||
logger.debug("Deleting message (id={})".format(message_id))
|
logger.debug("Deleting message (id={})".format(message_id))
|
||||||
await conn.execute("delete from messages where mid = $1", int(message_id))
|
await conn.execute("delete from messages where mid = $1", message_id)
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_member_message_count(conn, member_id: int) -> int:
|
async def get_member_message_count(conn, member_id: int) -> int:
|
||||||
@ -290,14 +290,14 @@ async def add_switch_member(conn, switch_id: int, member_id: int):
|
|||||||
await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id)
|
await conn.execute("insert into switch_members (switch, member) values ($1, $2)", switch_id, member_id)
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def get_server_info(conn, server_id: str):
|
async def get_server_info(conn, server_id: int):
|
||||||
return await conn.fetchrow("select * from servers where id = $1", int(server_id))
|
return await conn.fetchrow("select * from servers where id = $1", server_id)
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def update_server(conn, server_id: str, logging_channel_id: str):
|
async def update_server(conn, server_id: int, logging_channel_id: int):
|
||||||
logging_channel_id = int(logging_channel_id) if logging_channel_id else None
|
logging_channel_id = logging_channel_id if logging_channel_id else None
|
||||||
logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id))
|
logger.debug("Updating server settings (id={}, log_channel={})".format(server_id, logging_channel_id))
|
||||||
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", int(server_id), logging_channel_id)
|
await conn.execute("insert into servers (id, log_channel) values ($1, $2) on conflict (id) do update set log_channel = $2", server_id, logging_channel_id)
|
||||||
|
|
||||||
@db_wrap
|
@db_wrap
|
||||||
async def member_count(conn) -> int:
|
async def member_count(conn) -> int:
|
||||||
|
@ -2,7 +2,7 @@ aiodns
|
|||||||
aiohttp
|
aiohttp
|
||||||
asyncpg
|
asyncpg
|
||||||
dateparser
|
dateparser
|
||||||
discord.py
|
https://github.com/Rapptz/discord.py/archive/860d6a9ace8248dfeec18b8b159e7b757d9f56bb.zip#egg=discord.py
|
||||||
humanize
|
humanize
|
||||||
uvloop
|
uvloop
|
||||||
ciso8601
|
ciso8601
|
||||||
|
Loading…
x
Reference in New Issue
Block a user