Refactored config file loading
This commit is contained in:
		| @@ -1,6 +1,4 @@ | ||||
| import asyncio | ||||
| import json | ||||
| import os | ||||
| import sys | ||||
|  | ||||
| try: | ||||
| @@ -10,13 +8,5 @@ try: | ||||
| except ImportError: | ||||
|     pass | ||||
|  | ||||
| with open(sys.argv[1] if len(sys.argv) > 1 else "pluralkit.conf") as f: | ||||
|     config = json.load(f) | ||||
|  | ||||
| if "database_uri" not in config and "DATABASE_URI" not in os.environ: | ||||
|     print("Config file must contain key 'database_uri', or the environment variable DATABASE_URI must be present.") | ||||
| elif "token" not in config and "TOKEN" not in os.environ: | ||||
|     print("Config file must contain key 'token', or the environment variable TOKEN must be present.") | ||||
| else: | ||||
|     from pluralkit import bot | ||||
|     bot.run(os.environ.get("TOKEN", config.get("token")), os.environ.get("DATABASE_URI", config.get("database_uri")), int(config.get("log_channel", 0))) | ||||
| from pluralkit import bot | ||||
| bot.run(bot.Config.from_file_and_env(sys.argv[1] if len(sys.argv) > 1 else "pluralkit.conf")) | ||||
| @@ -2,8 +2,10 @@ import asyncio | ||||
| import sys | ||||
|  | ||||
| import asyncpg | ||||
| from collections import namedtuple | ||||
| import discord | ||||
| import logging | ||||
| import json | ||||
| import os | ||||
| import traceback | ||||
|  | ||||
| @@ -12,13 +14,45 @@ from pluralkit.bot import commands, proxy, channel_logger, embeds | ||||
|  | ||||
| logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") | ||||
|  | ||||
| class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])): | ||||
|     required_fields = ["database_uri", "token"] | ||||
|  | ||||
|     database_uri: str | ||||
|     token: str | ||||
|     log_channel: str | ||||
|  | ||||
|     @staticmethod | ||||
|     def from_file_and_env(filename: str) -> "Config": | ||||
|         try: | ||||
|             with open(filename, "r") as f: | ||||
|                 config = json.load(f) | ||||
|         except IOError as e: | ||||
|             # If all the required fields are specified as environment variables, it's OK to  | ||||
|             # not raise the IOError, we can just construct the dict from these | ||||
|             if all([rf.upper() in os.environ for rf in Config.required_fields]): | ||||
|                 config = {} | ||||
|             else: | ||||
|                 # If they aren't, though, then rethrow | ||||
|                 raise e | ||||
|  | ||||
|         # Override with environment variables | ||||
|         for f in Config._fields: | ||||
|             if f.upper() in os.environ: | ||||
|                 config[f] = os.environ[f.upper()] | ||||
|  | ||||
|         # If we currently don't have all the required fields, then raise | ||||
|         if not all([rf in config for rf in Config.required_fields]): | ||||
|             raise RuntimeError("Some required config fields were missing: " + ", ".join(filter(lambda rf: rf not in config, Config.required_fields))) | ||||
|  | ||||
|         return Config(**config) | ||||
|  | ||||
|  | ||||
| def connect_to_database(uri: str) -> asyncpg.pool.Pool: | ||||
|     return asyncio.get_event_loop().run_until_complete(db.connect(uri)) | ||||
|  | ||||
|  | ||||
| def run(token: str, db_uri: str, log_channel_id: int): | ||||
|     pool = connect_to_database(db_uri) | ||||
| def run(config: Config): | ||||
|     pool = connect_to_database(config.database_uri) | ||||
|  | ||||
|     async def create_tables(): | ||||
|         async with pool.acquire() as conn: | ||||
| @@ -78,9 +112,9 @@ def run(token: str, db_uri: str, log_channel_id: int): | ||||
|  | ||||
|         # Then log it to the given log channel | ||||
|         # TODO: replace this with Sentry or something | ||||
|         if not log_channel_id: | ||||
|         if not config.log_channel: | ||||
|             return | ||||
|         log_channel = client.get_channel(log_channel_id) | ||||
|         log_channel = client.get_channel(int(config.log_channel)) | ||||
|  | ||||
|         # If this is a message event, we can attach additional information in an event | ||||
|         # ie. username, channel, content, etc | ||||
| @@ -102,4 +136,4 @@ def run(token: str, db_uri: str, log_channel_id: int): | ||||
|         if len(traceback.format_exc()) >= (2000 - len("```python\n```")): | ||||
|             traceback_str = "```python\n...{}```".format(traceback.format_exc()[- (2000 - len("```python\n...```")):]) | ||||
|         await log_channel.send(content=traceback_str, embed=embed) | ||||
|     client.run(token) | ||||
|     client.run(config.token) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user