Message and error logging, various bugfixes
This commit is contained in:
parent
ea62ede21b
commit
58d8927380
@ -7,12 +7,14 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
import traceback
|
||||||
|
|
||||||
from pluralkit import db
|
from pluralkit import db
|
||||||
from pluralkit.bot import commands, proxy
|
from pluralkit.bot import commands, proxy, channel_logger, embeds
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
||||||
|
|
||||||
|
|
||||||
def connect_to_database() -> asyncpg.pool.Pool:
|
def connect_to_database() -> asyncpg.pool.Pool:
|
||||||
username = os.environ["DATABASE_USER"]
|
username = os.environ["DATABASE_USER"]
|
||||||
password = os.environ["DATABASE_PASS"]
|
password = os.environ["DATABASE_PASS"]
|
||||||
@ -21,7 +23,9 @@ def connect_to_database() -> asyncpg.pool.Pool:
|
|||||||
port = os.environ["DATABASE_PORT"]
|
port = os.environ["DATABASE_PORT"]
|
||||||
|
|
||||||
if username is None or password is None or name is None or host is None or port is None:
|
if username is None or password is None or name is None or host is None or port is None:
|
||||||
print("Database credentials not specified. Please pass valid PostgreSQL database credentials in the DATABASE_[USER|PASS|NAME|HOST|PORT] environment variable.", file=sys.stderr)
|
print(
|
||||||
|
"Database credentials not specified. Please pass valid PostgreSQL database credentials in the DATABASE_[USER|PASS|NAME|HOST|PORT] environment variable.",
|
||||||
|
file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -38,11 +42,20 @@ def connect_to_database() -> asyncpg.pool.Pool:
|
|||||||
port=port
|
port=port
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
pool = connect_to_database()
|
pool = connect_to_database()
|
||||||
|
|
||||||
|
async def create_tables():
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await db.create_tables(conn)
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(create_tables())
|
||||||
|
|
||||||
client = discord.Client()
|
client = discord.Client()
|
||||||
|
|
||||||
|
logger = channel_logger.ChannelLogger(client)
|
||||||
|
|
||||||
@client.event
|
@client.event
|
||||||
async def on_ready():
|
async def on_ready():
|
||||||
print("PluralKit started.")
|
print("PluralKit started.")
|
||||||
@ -64,8 +77,51 @@ def run():
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Second pass: do proxy matching
|
# Second pass: do proxy matching
|
||||||
await proxy.try_proxy_message(message, conn)
|
await proxy.try_proxy_message(conn, message, logger)
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_raw_message_delete(payload: discord.RawMessageDeleteEvent):
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await proxy.handle_deleted_message(conn, client, payload.message_id, None, logger)
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_raw_bulk_message_delete(payload: discord.RawBulkMessageDeleteEvent):
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
for message_id in payload.message_ids:
|
||||||
|
await proxy.handle_deleted_message(conn, client, message_id, None, logger)
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_raw_reaction_add(payload: discord.RawReactionActionEvent):
|
||||||
|
if payload.emoji.name == "\u274c": # Red X
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await proxy.try_delete_by_reaction(conn, client, payload.message_id, payload.user_id, logger)
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_error(event_name, *args, **kwargs):
|
||||||
|
log_channel_id = os.environ["LOG_CHANNEL"]
|
||||||
|
if not log_channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
log_channel = client.get_channel(int(log_channel_id))
|
||||||
|
|
||||||
|
# If this is a message event, we can attach additional information in an event
|
||||||
|
# ie. username, channel, content, etc
|
||||||
|
if args and isinstance(args[0], discord.Message):
|
||||||
|
message: discord.Message = args[0]
|
||||||
|
embed = embeds.exception_log(
|
||||||
|
message.content,
|
||||||
|
message.author.name,
|
||||||
|
message.author.discriminator,
|
||||||
|
message.author.id,
|
||||||
|
message.guild.id if message.guild else None,
|
||||||
|
message.channel.id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If not, just post the string itself
|
||||||
|
embed = None
|
||||||
|
|
||||||
|
traceback_str = "```python\n{}```".format(traceback.format_exc())
|
||||||
|
await log_channel.send(content=traceback_str, embed=embed)
|
||||||
|
|
||||||
bot_token = os.environ["TOKEN"]
|
bot_token = os.environ["TOKEN"]
|
||||||
if not bot_token:
|
if not bot_token:
|
||||||
|
@ -19,7 +19,7 @@ class ChannelLogger:
|
|||||||
self.logger = logging.getLogger("pluralkit.bot.channel_logger")
|
self.logger = logging.getLogger("pluralkit.bot.channel_logger")
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
async def get_log_channel(self, conn, server_id: str):
|
async def get_log_channel(self, conn, server_id: int):
|
||||||
server_info = await db.get_server_info(conn, server_id)
|
server_info = await db.get_server_info(conn, server_id)
|
||||||
|
|
||||||
if not server_info:
|
if not server_info:
|
||||||
@ -30,21 +30,21 @@ class ChannelLogger:
|
|||||||
if not log_channel:
|
if not log_channel:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self.client.get_channel(str(log_channel))
|
return self.client.get_channel(log_channel)
|
||||||
|
|
||||||
async def send_to_log_channel(self, log_channel: discord.Channel, embed: discord.Embed, text: str = None):
|
async def send_to_log_channel(self, log_channel: discord.TextChannel, embed: discord.Embed, text: str = None):
|
||||||
try:
|
try:
|
||||||
await self.client.send_message(log_channel, embed=embed, content=text)
|
await log_channel.send(content=text, embed=embed)
|
||||||
except discord.Forbidden:
|
except discord.Forbidden:
|
||||||
# TODO: spew big error
|
# TODO: spew big error
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"Did not have permission to send message to logging channel (server={}, channel={})".format(
|
"Did not have permission to send message to logging channel (server={}, channel={})".format(
|
||||||
log_channel.server.id, log_channel.id))
|
log_channel.guild.id, log_channel.id))
|
||||||
|
|
||||||
async def log_message_proxied(self, conn,
|
async def log_message_proxied(self, conn,
|
||||||
server_id: str,
|
server_id: int,
|
||||||
channel_name: str,
|
channel_name: str,
|
||||||
channel_id: str,
|
channel_id: int,
|
||||||
sender_name: str,
|
sender_name: str,
|
||||||
sender_disc: int,
|
sender_disc: int,
|
||||||
sender_id: int,
|
sender_id: int,
|
||||||
@ -56,11 +56,13 @@ class ChannelLogger:
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_image: str,
|
message_image: str,
|
||||||
message_timestamp: datetime,
|
message_timestamp: datetime,
|
||||||
message_id: str):
|
message_id: int):
|
||||||
log_channel = await self.get_log_channel(conn, server_id)
|
log_channel = await self.get_log_channel(conn, server_id)
|
||||||
if not log_channel:
|
if not log_channel:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
message_link = "https://discordapp.com/channels/{}/{}/{}".format(server_id, channel_id, message_id)
|
||||||
|
|
||||||
embed = discord.Embed()
|
embed = discord.Embed()
|
||||||
embed.colour = discord.Colour.blue()
|
embed.colour = discord.Colour.blue()
|
||||||
embed.description = message_text
|
embed.description = message_text
|
||||||
@ -75,11 +77,10 @@ class ChannelLogger:
|
|||||||
if message_image:
|
if message_image:
|
||||||
embed.set_thumbnail(url=message_image)
|
embed.set_thumbnail(url=message_image)
|
||||||
|
|
||||||
message_link = "https://discordapp.com/channels/{}/{}/{}".format(server_id, channel_id, message_id)
|
|
||||||
await self.send_to_log_channel(log_channel, embed, message_link)
|
await self.send_to_log_channel(log_channel, embed, message_link)
|
||||||
|
|
||||||
async def log_message_deleted(self, conn,
|
async def log_message_deleted(self, conn,
|
||||||
server_id: str,
|
server_id: int,
|
||||||
channel_name: str,
|
channel_name: str,
|
||||||
member_name: str,
|
member_name: str,
|
||||||
member_hid: str,
|
member_hid: str,
|
||||||
@ -87,22 +88,17 @@ class ChannelLogger:
|
|||||||
system_name: str,
|
system_name: str,
|
||||||
system_hid: str,
|
system_hid: str,
|
||||||
message_text: str,
|
message_text: str,
|
||||||
message_id: str,
|
message_id: int):
|
||||||
deleted_by_moderator: bool):
|
|
||||||
log_channel = await self.get_log_channel(conn, server_id)
|
log_channel = await self.get_log_channel(conn, server_id)
|
||||||
if not log_channel:
|
if not log_channel:
|
||||||
return
|
return
|
||||||
|
|
||||||
embed = discord.Embed()
|
embed = discord.Embed()
|
||||||
embed.colour = discord.Colour.dark_red()
|
embed.colour = discord.Colour.dark_red()
|
||||||
embed.description = message_text
|
embed.description = message_text or "*(unknown, message deleted by moderator)*"
|
||||||
embed.timestamp = datetime.utcnow()
|
embed.timestamp = datetime.utcnow()
|
||||||
|
|
||||||
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
|
embed_set_author_name(embed, channel_name, member_name, system_name, member_avatar_url)
|
||||||
embed.set_footer(
|
embed.set_footer(text="System ID: {} | Member ID: {} | Message ID: {}".format(system_hid, member_hid, message_id))
|
||||||
text="System ID: {} | Member ID: {} | Message ID: {} | Deleted by moderator? {}".format(system_hid,
|
|
||||||
member_hid,
|
|
||||||
message_id,
|
|
||||||
"Yes" if deleted_by_moderator else "No"))
|
|
||||||
|
|
||||||
await self.send_to_log_channel(log_channel, embed)
|
await self.send_to_log_channel(log_channel, embed)
|
||||||
|
@ -8,7 +8,7 @@ async def get_message_contents(client: discord.Client, channel_id: int, message_
|
|||||||
channel = client.get_channel(str(channel_id))
|
channel = client.get_channel(str(channel_id))
|
||||||
if channel:
|
if channel:
|
||||||
try:
|
try:
|
||||||
original_message = await client.get_message(channel, str(message_id))
|
original_message = await client.get_channel(channel).get_message(message_id)
|
||||||
return original_message.content or None
|
return original_message.content or None
|
||||||
except (discord.errors.Forbidden, discord.errors.NotFound):
|
except (discord.errors.Forbidden, discord.errors.NotFound):
|
||||||
pass
|
pass
|
||||||
@ -24,13 +24,13 @@ async def message_info(ctx: CommandContext):
|
|||||||
return CommandError("You must pass a valid number as a message ID.", help=help.message_lookup)
|
return CommandError("You must pass a valid number as a message ID.", help=help.message_lookup)
|
||||||
|
|
||||||
# Find the message in the DB
|
# Find the message in the DB
|
||||||
message = await db.get_message(ctx.conn, str(mid))
|
message = await db.get_message(ctx.conn, mid)
|
||||||
if not message:
|
if not message:
|
||||||
raise CommandError("Message with ID '{}' not found.".format(mid))
|
raise CommandError("Message with ID '{}' not found.".format(mid))
|
||||||
|
|
||||||
# Get the original sender of the messages
|
# Get the original sender of the messages
|
||||||
try:
|
try:
|
||||||
original_sender = await ctx.client.get_user_info(str(message.sender))
|
original_sender = await ctx.client.get_user_info(message.sender)
|
||||||
except discord.NotFound:
|
except discord.NotFound:
|
||||||
# Account was since deleted - rare but we're handling it anyway
|
# Account was since deleted - rare but we're handling it anyway
|
||||||
original_sender = None
|
original_sender = None
|
||||||
|
@ -4,10 +4,13 @@ logger = logging.getLogger("pluralkit.commands")
|
|||||||
|
|
||||||
|
|
||||||
async def set_log(ctx: CommandContext):
|
async def set_log(ctx: CommandContext):
|
||||||
if not ctx.message.author.server_permissions.administrator:
|
if not ctx.message.author.guild_permissions.administrator:
|
||||||
return CommandError("You must be a server administrator to use this command.")
|
return CommandError("You must be a server administrator to use this command.")
|
||||||
|
|
||||||
server = ctx.message.server
|
server = ctx.message.guild
|
||||||
|
if not server:
|
||||||
|
return CommandError("This command can not be run in a DM.")
|
||||||
|
|
||||||
if not ctx.has_next():
|
if not ctx.has_next():
|
||||||
channel_id = None
|
channel_id = None
|
||||||
else:
|
else:
|
||||||
|
@ -38,13 +38,13 @@ def status(text: str) -> discord.Embed:
|
|||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
|
||||||
def exception_log(message_content, author_name, author_discriminator, server_id, channel_id) -> discord.Embed:
|
def exception_log(message_content, author_name, author_discriminator, author_id, server_id, channel_id) -> discord.Embed:
|
||||||
embed = discord.Embed()
|
embed = discord.Embed()
|
||||||
embed.colour = discord.Colour.dark_red()
|
embed.colour = discord.Colour.dark_red()
|
||||||
embed.title = message_content
|
embed.title = message_content
|
||||||
|
|
||||||
embed.set_footer(text="Sender: {}#{} | Server: {} | Channel: {}".format(
|
embed.set_footer(text="Sender: {}#{} ({}) | Server: {} | Channel: {}".format(
|
||||||
author_name, author_discriminator,
|
author_name, author_discriminator, author_id,
|
||||||
server_id if server_id else "(DMs)",
|
server_id if server_id else "(DMs)",
|
||||||
channel_id
|
channel_id
|
||||||
))
|
))
|
||||||
@ -72,7 +72,7 @@ async def system_card(conn, client: discord.Client, system: System) -> discord.E
|
|||||||
|
|
||||||
account_names = []
|
account_names = []
|
||||||
for account_id in await system.get_linked_account_ids(conn):
|
for account_id in await system.get_linked_account_ids(conn):
|
||||||
account = await client.get_user_info(str(account_id))
|
account = await client.get_user_info(account_id)
|
||||||
account_names.append("{}#{}".format(account.name, account.discriminator))
|
account_names.append("{}#{}".format(account.name, account.discriminator))
|
||||||
|
|
||||||
card.add_field(name="Linked accounts", value="\n".join(account_names))
|
card.add_field(name="Linked accounts", value="\n".join(account_names))
|
||||||
|
@ -3,10 +3,11 @@ from io import BytesIO
|
|||||||
import discord
|
import discord
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from pluralkit import db
|
from pluralkit import db
|
||||||
from pluralkit.bot import utils
|
from pluralkit.bot import utils, channel_logger
|
||||||
|
from pluralkit.bot.channel_logger import ChannelLogger
|
||||||
|
|
||||||
logger = logging.getLogger("pluralkit.bot.proxy")
|
logger = logging.getLogger("pluralkit.bot.proxy")
|
||||||
|
|
||||||
@ -99,7 +100,7 @@ async def make_attachment_file(message: discord.Message):
|
|||||||
|
|
||||||
|
|
||||||
async def do_proxy_message(conn, original_message: discord.Message, proxy_member: db.ProxyMember,
|
async def do_proxy_message(conn, original_message: discord.Message, proxy_member: db.ProxyMember,
|
||||||
inner_text: str):
|
inner_text: str, logger: ChannelLogger):
|
||||||
# Send the message through the webhook
|
# Send the message through the webhook
|
||||||
webhook = await get_or_create_webhook_for_channel(conn, original_message.channel)
|
webhook = await get_or_create_webhook_for_channel(conn, original_message.channel)
|
||||||
sent_message = await webhook.send(
|
sent_message = await webhook.send(
|
||||||
@ -114,10 +115,30 @@ async def do_proxy_message(conn, original_message: discord.Message, proxy_member
|
|||||||
await db.add_message(conn, sent_message.id, original_message.channel.id, proxy_member.id,
|
await db.add_message(conn, sent_message.id, original_message.channel.id, proxy_member.id,
|
||||||
original_message.author.id)
|
original_message.author.id)
|
||||||
|
|
||||||
# TODO: log message in log channel
|
await logger.log_message_proxied(
|
||||||
|
conn,
|
||||||
|
original_message.channel.guild.id,
|
||||||
|
original_message.channel.name,
|
||||||
|
original_message.channel.id,
|
||||||
|
original_message.author.name,
|
||||||
|
original_message.author.discriminator,
|
||||||
|
original_message.author.id,
|
||||||
|
proxy_member.name,
|
||||||
|
proxy_member.hid,
|
||||||
|
proxy_member.avatar_url,
|
||||||
|
proxy_member.system_name,
|
||||||
|
proxy_member.system_hid,
|
||||||
|
inner_text,
|
||||||
|
sent_message.attachments[0].url if sent_message.attachments else None,
|
||||||
|
sent_message.created_at,
|
||||||
|
sent_message.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# And finally, gotta delete the original.
|
||||||
|
await original_message.delete()
|
||||||
|
|
||||||
|
|
||||||
async def try_proxy_message(message: discord.Message, conn) -> bool:
|
async def try_proxy_message(conn, message: discord.Message, logger: ChannelLogger) -> bool:
|
||||||
# Don't bother proxying in DMs with the bot
|
# Don't bother proxying in DMs with the bot
|
||||||
if isinstance(message.channel, discord.abc.PrivateChannel):
|
if isinstance(message.channel, discord.abc.PrivateChannel):
|
||||||
return False
|
return False
|
||||||
@ -140,9 +161,57 @@ async def try_proxy_message(message: discord.Message, conn) -> bool:
|
|||||||
|
|
||||||
# So, we now have enough information to successfully proxy a message
|
# So, we now have enough information to successfully proxy a message
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
await do_proxy_message(conn, message, member, inner_text)
|
await do_proxy_message(conn, message, member, inner_text, logger)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_deleted_message(conn, client: discord.Client, message_id: int,
|
||||||
|
message_content: Optional[str], logger: channel_logger.ChannelLogger) -> bool:
|
||||||
|
msg = await db.get_message(conn, message_id)
|
||||||
|
if not msg:
|
||||||
|
return False
|
||||||
|
|
||||||
|
channel = client.get_channel(msg.channel)
|
||||||
|
if not channel:
|
||||||
|
# Weird edge case, but channel *could* be deleted at this point (can't think of any scenarios it would be tho)
|
||||||
|
return False
|
||||||
|
|
||||||
|
await db.delete_message(conn, message_id)
|
||||||
|
await logger.log_message_deleted(
|
||||||
|
conn,
|
||||||
|
channel.guild.id,
|
||||||
|
channel.name,
|
||||||
|
msg.name,
|
||||||
|
msg.hid,
|
||||||
|
msg.avatar_url,
|
||||||
|
msg.system_name,
|
||||||
|
msg.system_hid,
|
||||||
|
message_content,
|
||||||
|
message_id
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def try_delete_by_reaction(conn, client: discord.Client, message_id: int, reaction_user: int,
|
||||||
|
logger: channel_logger.ChannelLogger) -> bool:
|
||||||
|
# Find the message by the given message id or reaction user
|
||||||
|
msg = await db.get_message_by_sender_and_id(conn, message_id, reaction_user)
|
||||||
|
if not msg:
|
||||||
|
# Either the wrong user reacted or the message isn't a proxy message
|
||||||
|
# In either case - not our problem
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Find the original message
|
||||||
|
original_message = await client.get_channel(msg.channel).get_message(message_id)
|
||||||
|
if not original_message:
|
||||||
|
# Message got deleted, possibly race condition, eh
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Then delete the original message
|
||||||
|
await original_message.delete()
|
||||||
|
|
||||||
|
await handle_deleted_message(conn, client, message_id, original_message.content, logger)
|
||||||
|
|
||||||
# # TODO: possibly move this to bot __init__ so commands can access it too
|
# # TODO: possibly move this to bot __init__ so commands can access it too
|
||||||
# class WebhookPermissionError(Exception):
|
# class WebhookPermissionError(Exception):
|
||||||
# pass
|
# pass
|
||||||
|
@ -39,7 +39,7 @@ async def parse_mention(client: discord.Client, mention: str) -> Optional[discor
|
|||||||
def parse_channel_mention(mention: str, server: discord.Guild) -> Optional[discord.TextChannel]:
|
def parse_channel_mention(mention: str, server: discord.Guild) -> Optional[discord.TextChannel]:
|
||||||
match = re.fullmatch("<#(\\d+)>", mention)
|
match = re.fullmatch("<#(\\d+)>", mention)
|
||||||
if match:
|
if match:
|
||||||
return server.get_channel(match.group(1))
|
return server.get_channel(int(match.group(1)))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return server.get_channel(int(mention))
|
return server.get_channel(int(mention))
|
||||||
|
Loading…
Reference in New Issue
Block a user