diff --git a/src/bot_main.py b/src/bot_main.py index 74ff7a02..211c162b 100644 --- a/src/bot_main.py +++ b/src/bot_main.py @@ -1,6 +1,4 @@ import asyncio -import json -import os import sys try: @@ -10,13 +8,5 @@ try: except ImportError: pass -with open(sys.argv[1] if len(sys.argv) > 1 else "pluralkit.conf") as f: - config = json.load(f) - -if "database_uri" not in config and "DATABASE_URI" not in os.environ: - print("Config file must contain key 'database_uri', or the environment variable DATABASE_URI must be present.") -elif "token" not in config and "TOKEN" not in os.environ: - print("Config file must contain key 'token', or the environment variable TOKEN must be present.") -else: - from pluralkit import bot - bot.run(os.environ.get("TOKEN", config.get("token")), os.environ.get("DATABASE_URI", config.get("database_uri")), int(config.get("log_channel", 0))) \ No newline at end of file +from pluralkit import bot +bot.run(bot.Config.from_file_and_env(sys.argv[1] if len(sys.argv) > 1 else "pluralkit.conf")) \ No newline at end of file diff --git a/src/pluralkit/bot/__init__.py b/src/pluralkit/bot/__init__.py index d160d2f4..357d1a9d 100644 --- a/src/pluralkit/bot/__init__.py +++ b/src/pluralkit/bot/__init__.py @@ -2,8 +2,10 @@ import asyncio import sys import asyncpg +from collections import namedtuple import discord import logging +import json import os import traceback @@ -12,13 +14,45 @@ from pluralkit.bot import commands, proxy, channel_logger, embeds logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") +class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])): + required_fields = ["database_uri", "token"] + + database_uri: str + token: str + log_channel: str + + @staticmethod + def from_file_and_env(filename: str) -> "Config": + try: + with open(filename, "r") as f: + config = json.load(f) + except IOError as e: + # If all the required fields are specified as environment variables, it's OK to + # not raise the IOError, we can just construct the dict from these + if all([rf.upper() in os.environ for rf in Config.required_fields]): + config = {} + else: + # If they aren't, though, then rethrow + raise e + + # Override with environment variables + for f in Config._fields: + if f.upper() in os.environ: + config[f] = os.environ[f.upper()] + + # If we currently don't have all the required fields, then raise + if not all([rf in config for rf in Config.required_fields]): + raise RuntimeError("Some required config fields were missing: " + ", ".join(filter(lambda rf: rf not in config, Config.required_fields))) + + return Config(**config) + def connect_to_database(uri: str) -> asyncpg.pool.Pool: return asyncio.get_event_loop().run_until_complete(db.connect(uri)) -def run(token: str, db_uri: str, log_channel_id: int): - pool = connect_to_database(db_uri) +def run(config: Config): + pool = connect_to_database(config.database_uri) async def create_tables(): async with pool.acquire() as conn: @@ -78,9 +112,9 @@ def run(token: str, db_uri: str, log_channel_id: int): # Then log it to the given log channel # TODO: replace this with Sentry or something - if not log_channel_id: + if not config.log_channel: return - log_channel = client.get_channel(log_channel_id) + log_channel = client.get_channel(int(config.log_channel)) # If this is a message event, we can attach additional information in an event # ie. username, channel, content, etc @@ -102,4 +136,4 @@ def run(token: str, db_uri: str, log_channel_id: int): if len(traceback.format_exc()) >= (2000 - len("```python\n```")): traceback_str = "```python\n...{}```".format(traceback.format_exc()[- (2000 - len("```python\n...```")):]) await log_channel.send(content=traceback_str, embed=embed) - client.run(token) + client.run(config.token)