PluralKit/src/pluralkit/bot/__init__.py

244 lines
8.8 KiB
Python

import asyncpg
import sys
import asyncio
import os
import logging
import discord
import traceback
from pluralkit import db
from pluralkit.bot import commands, proxy, channel_logger, embeds
# Configure the root logger once at import time: INFO level and above, with
# each record rendered as "[time] [logger name] [level] message".
logging.basicConfig(
    format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s",
    level=logging.INFO,
)
def connect_to_database() -> "asyncpg.pool.Pool":
    """Connect to the database using credentials from the environment.

    Reads the ``DATABASE_USER``, ``DATABASE_PASS``, ``DATABASE_NAME``,
    ``DATABASE_HOST`` and ``DATABASE_PORT`` environment variables, converts
    the port to an integer and runs ``db.connect(...)`` to completion on the
    current asyncio event loop.

    Returns:
        Whatever ``db.connect`` resolves to (annotated as an asyncpg pool).

    Exits:
        Prints a message to stderr and calls ``sys.exit(1)`` when any of the
        variables is unset or when ``DATABASE_PORT`` is not a valid integer.
    """
    # Use .get() so an unset variable yields None and reaches the friendly
    # error below; indexing os.environ directly would raise KeyError first.
    username = os.environ.get("DATABASE_USER")
    password = os.environ.get("DATABASE_PASS")
    name = os.environ.get("DATABASE_NAME")
    host = os.environ.get("DATABASE_HOST")
    port = os.environ.get("DATABASE_PORT")

    if any(value is None for value in (username, password, name, host, port)):
        print(
            "Database credentials not specified. Please pass valid PostgreSQL database credentials in the DATABASE_[USER|PASS|NAME|HOST|PORT] environment variable.",
            file=sys.stderr)
        sys.exit(1)

    try:
        port = int(port)
    except ValueError:
        print("Please pass a valid integer as the DATABASE_PORT environment variable.", file=sys.stderr)
        sys.exit(1)

    return asyncio.get_event_loop().run_until_complete(db.connect(
        username=username,
        password=password,
        database=name,
        host=host,
        port=port
    ))
def run():
    """Start the bot: connect to the database, register handlers, run the client.

    Steps, in order:

    1. Obtain the database pool via ``connect_to_database()`` and run
       ``db.create_tables`` on a pooled connection.
    2. Create a ``discord.Client`` and a ``channel_logger.ChannelLogger`` for it.
    3. Register the event handlers defined below (message dispatch, message
       deletion, reaction handling, error reporting).
    4. Read the bot token from the ``TOKEN`` environment variable and call
       ``client.run`` with it.

    Exits:
        Prints a message to stderr and calls ``sys.exit(1)`` when ``TOKEN`` is
        unset or empty.
    """
    pool = connect_to_database()

    async def create_tables():
        async with pool.acquire() as conn:
            await db.create_tables(conn)

    asyncio.get_event_loop().run_until_complete(create_tables())

    client = discord.Client()
    logger = channel_logger.ChannelLogger(client)
    error_log = logging.getLogger("pluralkit.bot")

    @client.event
    async def on_ready():
        print("PluralKit started.")
        print("User: {}#{} (ID: {})".format(client.user.name, client.user.discriminator, client.user.id))
        print("{} servers".format(len(client.guilds)))
        print("{} shards".format(client.shard_count or 1))

    @client.event
    async def on_message(message: discord.Message):
        # Ignore messages from bots
        if message.author.bot:
            return

        # Grab a database connection from the pool
        async with pool.acquire() as conn:
            # First pass: do command handling
            did_run_command = await commands.command_dispatch(client, message, conn)
            if did_run_command:
                return

            # Second pass: do proxy matching
            await proxy.try_proxy_message(conn, message, logger)

    @client.event
    async def on_raw_message_delete(payload: discord.RawMessageDeleteEvent):
        async with pool.acquire() as conn:
            await proxy.handle_deleted_message(conn, client, payload.message_id, None, logger)

    @client.event
    async def on_raw_bulk_message_delete(payload: discord.RawBulkMessageDeleteEvent):
        async with pool.acquire() as conn:
            for message_id in payload.message_ids:
                await proxy.handle_deleted_message(conn, client, message_id, None, logger)

    @client.event
    async def on_raw_reaction_add(payload: discord.RawReactionActionEvent):
        if payload.emoji.name == "\u274c":  # Red X
            async with pool.acquire() as conn:
                await proxy.try_delete_by_reaction(conn, client, payload.message_id, payload.user_id, logger)

    @client.event
    async def on_error(event_name, *args, **kwargs):
        # Registering this handler replaces the client's default error
        # reporting, so always record the traceback locally first; otherwise
        # the error would vanish whenever the log channel is unavailable.
        error_log.exception("Error while handling event %s", event_name)
        traceback_text = traceback.format_exc()

        # .get() so that an unset LOG_CHANNEL means "no channel logging"
        # instead of raising KeyError inside the error handler.
        log_channel_id = os.environ.get("LOG_CHANNEL")
        if not log_channel_id:
            return

        try:
            log_channel = client.get_channel(int(log_channel_id))
        except ValueError:
            error_log.warning("LOG_CHANNEL is not a valid integer channel ID: %r", log_channel_id)
            return

        # get_channel() yields None when the channel is not known to the client.
        if log_channel is None:
            error_log.warning("Log channel %s not found; cannot post error report.", log_channel_id)
            return

        # If this is a message event, we can attach additional information in an embed
        # ie. username, channel, content, etc
        if args and isinstance(args[0], discord.Message):
            message: discord.Message = args[0]
            embed = embeds.exception_log(
                message.content,
                message.author.name,
                message.author.discriminator,
                message.author.id,
                message.guild.id if message.guild else None,
                message.channel.id
            )
        else:
            # If not, just post the string itself
            embed = None

        # Discord rejects message content above 2000 characters. Keep the tail
        # of the traceback, which holds the innermost frames and the exception.
        template = "```python\n{}```"
        max_traceback_len = 2000 - len(template.format(""))
        if len(traceback_text) > max_traceback_len:
            traceback_text = traceback_text[-max_traceback_len:]

        await log_channel.send(content=template.format(traceback_text), embed=embed)

    # .get() so an unset TOKEN reaches the friendly error below rather than
    # raising KeyError.
    bot_token = os.environ.get("TOKEN")
    if not bot_token:
        print("No token specified. Please pass a valid Discord bot token in the TOKEN environment variable.",
              file=sys.stderr)
        sys.exit(1)

    client.run(bot_token)
# logging.getLogger("pluralkit").setLevel(logging.DEBUG)
# class PluralKitBot:
# def __init__(self, token):
# self.token = token
# self.logger = logging.getLogger("pluralkit.bot")
#
# self.client = discord.Client()
# self.client.event(self.on_error)
# self.client.event(self.on_ready)
# self.client.event(self.on_message)
# self.client.event(self.on_socket_raw_receive)
#
# self.channel_logger = channel_logger.ChannelLogger(self.client)
#
# self.proxy = proxy.Proxy(self.client, token, self.channel_logger)
#
# async def on_error(self, evt, *args, **kwargs):
# self.logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
#
# async def on_ready(self):
# self.logger.info("Connected to Discord.")
# self.logger.info("- Account: {}#{}".format(self.client.user.name, self.client.user.discriminator))
# self.logger.info("- User ID: {}".format(self.client.user.id))
# self.logger.info("- {} servers".format(len(self.client.servers)))
#
# async def on_message(self, message):
# # Ignore bot messages
# if message.author.bot:
# return
#
# try:
# if await self.handle_command_dispatch(message):
# return
#
# if await self.handle_proxy_dispatch(message):
# return
# except Exception:
# await self.log_error_in_channel(message)
#
# async def on_socket_raw_receive(self, msg):
# # Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
# # we parse socket data manually for the reaction add event
# if isinstance(msg, str):
# try:
# msg_data = json.loads(msg)
# if msg_data.get("t") == "MESSAGE_REACTION_ADD":
# evt_data = msg_data.get("d")
# if evt_data:
# user_id = evt_data["user_id"]
# message_id = evt_data["message_id"]
# emoji = evt_data["emoji"]["name"]
#
# async with self.pool.acquire() as conn:
# await self.proxy.handle_reaction(conn, user_id, message_id, emoji)
# elif msg_data.get("t") == "MESSAGE_DELETE":
# evt_data = msg_data.get("d")
# if evt_data:
# message_id = evt_data["id"]
# async with self.pool.acquire() as conn:
# await self.proxy.handle_deletion(conn, message_id)
# except ValueError:
# pass
#
# async def handle_command_dispatch(self, message):
# async with self.pool.acquire() as conn:
# result = await commands.command_dispatch(self.client, message, conn)
# return result
#
# async def handle_proxy_dispatch(self, message):
# # Try doing proxy parsing
# async with self.pool.acquire() as conn:
# return await self.proxy.try_proxy_message(conn, message)
#
# async def log_error_in_channel(self, message):
# channel_id = os.environ["LOG_CHANNEL"]
# if not channel_id:
# return
#
# channel = self.client.get_channel(channel_id)
# embed = embeds.exception_log(
# message.content,
# message.author.name,
# message.author.discriminator,
# message.server.id if message.server else None,
# message.channel.id
# )
#
# await self.client.send_message(channel, "```python\n{}```".format(traceback.format_exc()), embed=embed)
#
# async def run(self):
# try:
# self.logger.info("Connecting to database...")
# self.pool = await db.connect(
# os.environ["DATABASE_USER"],
# os.environ["DATABASE_PASS"],
# os.environ["DATABASE_NAME"],
# os.environ["DATABASE_HOST"],
# int(os.environ["DATABASE_PORT"])
# )
#
# self.logger.info("Attempting to create tables...")
# async with self.pool.acquire() as conn:
# await db.create_tables(conn)
#
# self.logger.info("Connecting to Discord...")
# await self.client.start(self.token)
# finally:
# self.logger.info("Logging out from Discord...")
# await self.client.logout()