Refactored config file loading

This commit is contained in:
Ske 2019-03-08 17:22:05 +01:00
parent abda846ca3
commit 560b79c2ae
2 changed files with 41 additions and 17 deletions

View File

@ -1,6 +1,4 @@
import asyncio import asyncio
import json
import os
import sys import sys
try: try:
@ -10,13 +8,5 @@ try:
except ImportError: except ImportError:
pass pass
with open(sys.argv[1] if len(sys.argv) > 1 else "pluralkit.conf") as f: from pluralkit import bot
config = json.load(f) bot.run(bot.Config.from_file_and_env(sys.argv[1] if len(sys.argv) > 1 else "pluralkit.conf"))
if "database_uri" not in config and "DATABASE_URI" not in os.environ:
print("Config file must contain key 'database_uri', or the environment variable DATABASE_URI must be present.")
elif "token" not in config and "TOKEN" not in os.environ:
print("Config file must contain key 'token', or the environment variable TOKEN must be present.")
else:
from pluralkit import bot
bot.run(os.environ.get("TOKEN", config.get("token")), os.environ.get("DATABASE_URI", config.get("database_uri")), int(config.get("log_channel", 0)))

View File

@ -2,8 +2,10 @@ import asyncio
import sys import sys
import asyncpg import asyncpg
from collections import namedtuple
import discord import discord
import logging import logging
import json
import os import os
import traceback import traceback
@ -12,13 +14,45 @@ from pluralkit.bot import commands, proxy, channel_logger, embeds
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s") logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])):
required_fields = ["database_uri", "token"]
database_uri: str
token: str
log_channel: str
@staticmethod
def from_file_and_env(filename: str) -> "Config":
try:
with open(filename, "r") as f:
config = json.load(f)
except IOError as e:
# If all the required fields are specified as environment variables, it's OK to
# not raise the IOError, we can just construct the dict from these
if all([rf.upper() in os.environ for rf in Config.required_fields]):
config = {}
else:
# If they aren't, though, then rethrow
raise e
# Override with environment variables
for f in Config._fields:
if f.upper() in os.environ:
config[f] = os.environ[f.upper()]
# If we currently don't have all the required fields, then raise
if not all([rf in config for rf in Config.required_fields]):
raise RuntimeError("Some required config fields were missing: " + ", ".join(filter(lambda rf: rf not in config, Config.required_fields)))
return Config(**config)
def connect_to_database(uri: str) -> asyncpg.pool.Pool: def connect_to_database(uri: str) -> asyncpg.pool.Pool:
return asyncio.get_event_loop().run_until_complete(db.connect(uri)) return asyncio.get_event_loop().run_until_complete(db.connect(uri))
def run(token: str, db_uri: str, log_channel_id: int): def run(config: Config):
pool = connect_to_database(db_uri) pool = connect_to_database(config.database_uri)
async def create_tables(): async def create_tables():
async with pool.acquire() as conn: async with pool.acquire() as conn:
@ -78,9 +112,9 @@ def run(token: str, db_uri: str, log_channel_id: int):
# Then log it to the given log channel # Then log it to the given log channel
# TODO: replace this with Sentry or something # TODO: replace this with Sentry or something
if not log_channel_id: if not config.log_channel:
return return
log_channel = client.get_channel(log_channel_id) log_channel = client.get_channel(int(config.log_channel))
# If this is a message event, we can attach additional information in an event # If this is a message event, we can attach additional information in an event
# ie. username, channel, content, etc # ie. username, channel, content, etc
@ -102,4 +136,4 @@ def run(token: str, db_uri: str, log_channel_id: int):
if len(traceback.format_exc()) >= (2000 - len("```python\n```")): if len(traceback.format_exc()) >= (2000 - len("```python\n```")):
traceback_str = "```python\n...{}```".format(traceback.format_exc()[- (2000 - len("```python\n...```")):]) traceback_str = "```python\n...{}```".format(traceback.format_exc()[- (2000 - len("```python\n...```")):])
await log_channel.send(content=traceback_str, embed=embed) await log_channel.send(content=traceback_str, embed=embed)
client.run(token) client.run(config.token)