Refactored config file loading
This commit is contained in:
parent
abda846ca3
commit
560b79c2ae
@ -1,6 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -10,13 +8,5 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
with open(sys.argv[1] if len(sys.argv) > 1 else "pluralkit.conf") as f:
|
from pluralkit import bot
|
||||||
config = json.load(f)
|
bot.run(bot.Config.from_file_and_env(sys.argv[1] if len(sys.argv) > 1 else "pluralkit.conf"))
|
||||||
|
|
||||||
if "database_uri" not in config and "DATABASE_URI" not in os.environ:
|
|
||||||
print("Config file must contain key 'database_uri', or the environment variable DATABASE_URI must be present.")
|
|
||||||
elif "token" not in config and "TOKEN" not in os.environ:
|
|
||||||
print("Config file must contain key 'token', or the environment variable TOKEN must be present.")
|
|
||||||
else:
|
|
||||||
from pluralkit import bot
|
|
||||||
bot.run(os.environ.get("TOKEN", config.get("token")), os.environ.get("DATABASE_URI", config.get("database_uri")), int(config.get("log_channel", 0)))
|
|
@ -2,8 +2,10 @@ import asyncio
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import asyncpg
|
import asyncpg
|
||||||
|
from collections import namedtuple
|
||||||
import discord
|
import discord
|
||||||
import logging
|
import logging
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
@ -12,13 +14,45 @@ from pluralkit.bot import commands, proxy, channel_logger, embeds
|
|||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
||||||
|
|
||||||
|
class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])):
|
||||||
|
required_fields = ["database_uri", "token"]
|
||||||
|
|
||||||
|
database_uri: str
|
||||||
|
token: str
|
||||||
|
log_channel: str
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_file_and_env(filename: str) -> "Config":
|
||||||
|
try:
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
except IOError as e:
|
||||||
|
# If all the required fields are specified as environment variables, it's OK to
|
||||||
|
# not raise the IOError, we can just construct the dict from these
|
||||||
|
if all([rf.upper() in os.environ for rf in Config.required_fields]):
|
||||||
|
config = {}
|
||||||
|
else:
|
||||||
|
# If they aren't, though, then rethrow
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# Override with environment variables
|
||||||
|
for f in Config._fields:
|
||||||
|
if f.upper() in os.environ:
|
||||||
|
config[f] = os.environ[f.upper()]
|
||||||
|
|
||||||
|
# If we currently don't have all the required fields, then raise
|
||||||
|
if not all([rf in config for rf in Config.required_fields]):
|
||||||
|
raise RuntimeError("Some required config fields were missing: " + ", ".join(filter(lambda rf: rf not in config, Config.required_fields)))
|
||||||
|
|
||||||
|
return Config(**config)
|
||||||
|
|
||||||
|
|
||||||
def connect_to_database(uri: str) -> asyncpg.pool.Pool:
|
def connect_to_database(uri: str) -> asyncpg.pool.Pool:
|
||||||
return asyncio.get_event_loop().run_until_complete(db.connect(uri))
|
return asyncio.get_event_loop().run_until_complete(db.connect(uri))
|
||||||
|
|
||||||
|
|
||||||
def run(token: str, db_uri: str, log_channel_id: int):
|
def run(config: Config):
|
||||||
pool = connect_to_database(db_uri)
|
pool = connect_to_database(config.database_uri)
|
||||||
|
|
||||||
async def create_tables():
|
async def create_tables():
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
@ -78,9 +112,9 @@ def run(token: str, db_uri: str, log_channel_id: int):
|
|||||||
|
|
||||||
# Then log it to the given log channel
|
# Then log it to the given log channel
|
||||||
# TODO: replace this with Sentry or something
|
# TODO: replace this with Sentry or something
|
||||||
if not log_channel_id:
|
if not config.log_channel:
|
||||||
return
|
return
|
||||||
log_channel = client.get_channel(log_channel_id)
|
log_channel = client.get_channel(int(config.log_channel))
|
||||||
|
|
||||||
# If this is a message event, we can attach additional information in an event
|
# If this is a message event, we can attach additional information in an event
|
||||||
# ie. username, channel, content, etc
|
# ie. username, channel, content, etc
|
||||||
@ -102,4 +136,4 @@ def run(token: str, db_uri: str, log_channel_id: int):
|
|||||||
if len(traceback.format_exc()) >= (2000 - len("```python\n```")):
|
if len(traceback.format_exc()) >= (2000 - len("```python\n```")):
|
||||||
traceback_str = "```python\n...{}```".format(traceback.format_exc()[- (2000 - len("```python\n...```")):])
|
traceback_str = "```python\n...{}```".format(traceback.format_exc()[- (2000 - len("```python\n...```")):])
|
||||||
await log_channel.send(content=traceback_str, embed=embed)
|
await log_channel.send(content=traceback_str, embed=embed)
|
||||||
client.run(token)
|
client.run(config.token)
|
||||||
|
Loading…
Reference in New Issue
Block a user