Migrate DI container to Autofac

This commit is contained in:
Ske
2020-01-26 01:27:45 +01:00
parent 4311cb3ad1
commit 1ea0526ef8
10 changed files with 384 additions and 292 deletions

View File

@@ -7,6 +7,9 @@ using System.Threading;
using System.Threading.Tasks;
using App.Metrics;
using Autofac;
using Autofac.Core;
using Dapper;
using Discord;
using Discord.WebSocket;
@@ -15,11 +18,16 @@ using Microsoft.Extensions.DependencyInjection;
using PluralKit.Bot.Commands;
using PluralKit.Bot.CommandSystem;
using PluralKit.Core;
using Sentry;
using Sentry.Infrastructure;
using Serilog;
using Serilog.Events;
using SystemClock = NodaTime.SystemClock;
namespace PluralKit.Bot
{
class Initialize
@@ -45,105 +53,57 @@ namespace PluralKit.Bot
args.Cancel = true;
token.Cancel();
};
var builder = new ContainerBuilder();
builder.RegisterInstance(_config);
builder.RegisterModule(new ConfigModule<BotConfig>("Bot"));
builder.RegisterModule(new LoggingModule("bot"));
builder.RegisterModule(new MetricsModule());
builder.RegisterModule<DataStoreModule>();
builder.RegisterModule<BotModule>();
using var services = builder.Build();
using (var services = BuildServiceProvider())
var logger = services.Resolve<ILogger>().ForContext<Initialize>();
try
{
SchemaService.Initialize();
var coreConfig = services.Resolve<CoreConfig>();
var botConfig = services.Resolve<BotConfig>();
var schema = services.Resolve<SchemaService>();
using var _ = Sentry.SentrySdk.Init(coreConfig.SentryUrl);
var logger = services.GetRequiredService<ILogger>().ForContext<Initialize>();
var coreConfig = services.GetRequiredService<CoreConfig>();
var botConfig = services.GetRequiredService<BotConfig>();
var schema = services.GetRequiredService<SchemaService>();
logger.Information("Connecting to database");
await schema.ApplyMigrations();
using (Sentry.SentrySdk.Init(coreConfig.SentryUrl))
logger.Information("Connecting to Discord");
var client = services.Resolve<DiscordShardedClient>();
await client.LoginAsync(TokenType.Bot, botConfig.Token);
logger.Information("Initializing bot");
await client.StartAsync();
await services.Resolve<Bot>().Init();
try
{
logger.Information("Connecting to database");
await schema.ApplyMigrations();
logger.Information("Connecting to Discord");
var client = services.GetRequiredService<IDiscordClient>() as DiscordShardedClient;
await client.LoginAsync(TokenType.Bot, botConfig.Token);
logger.Information("Initializing bot");
await services.GetRequiredService<Bot>().Init();
await client.StartAsync();
try
{
await Task.Delay(-1, token.Token);
}
catch (TaskCanceledException) { } // We'll just exit normally
logger.Information("Shutting down");
await Task.Delay(-1, token.Token);
}
catch (TaskCanceledException) { } // We'll just exit normally
}
}
public ServiceProvider BuildServiceProvider() => new ServiceCollection()
.AddTransient(_ => _config.GetSection("PluralKit").Get<CoreConfig>() ?? new CoreConfig())
.AddTransient(_ => _config.GetSection("PluralKit").GetSection("Bot").Get<BotConfig>() ?? new BotConfig())
.AddSingleton<DbConnectionCountHolder>()
.AddTransient<DbConnectionFactory>()
.AddTransient<SchemaService>()
.AddSingleton<IDiscordClient, DiscordShardedClient>(_ => new DiscordShardedClient(new DiscordSocketConfig
catch (Exception e)
{
MessageCacheSize = 0,
ConnectionTimeout = 2*60*1000,
ExclusiveBulkDelete = true,
LargeThreshold = 50,
DefaultRetryMode = RetryMode.RetryTimeouts | RetryMode.RetryRatelimit
// Commented this out since Debug actually sends, uh, quite a lot that's not necessary in production
// but leaving it here in case I (or someone else) get[s] confused about why logging isn't working again :p
// LogLevel = LogSeverity.Debug // We filter log levels in Serilog, so just pass everything through (Debug is lower than Verbose)
}))
.AddSingleton<Bot>()
.AddSingleton(_ => new HttpClient { Timeout = TimeSpan.FromSeconds(5) })
.AddTransient<CommandTree>()
.AddTransient<SystemCommands>()
.AddTransient<MemberCommands>()
.AddTransient<SwitchCommands>()
.AddTransient<LinkCommands>()
.AddTransient<APICommands>()
.AddTransient<ImportExportCommands>()
.AddTransient<HelpCommands>()
.AddTransient<ModCommands>()
.AddTransient<MiscCommands>()
.AddTransient<AutoproxyCommands>()
.AddTransient<EmbedService>()
.AddTransient<ProxyService>()
.AddTransient<LogChannelService>()
.AddTransient<DataFileService>()
.AddTransient<WebhookExecutorService>()
logger.Fatal(e, "Unrecoverable error while initializing bot");
}
.AddTransient<ProxyCacheService>()
.AddSingleton<WebhookCacheService>()
.AddSingleton<AutoproxyCacheService>()
.AddSingleton<ShardInfoService>()
.AddSingleton<CpuStatService>()
.AddTransient<IDataStore, PostgresDataStore>()
.AddSingleton(svc => InitUtils.InitMetrics(svc.GetRequiredService<CoreConfig>()))
.AddSingleton<PeriodicStatCollector>()
.AddScoped(_ => new Sentry.Scope(null))
.AddTransient<PKEventHandler>()
.AddScoped<EventIdProvider>()
.AddSingleton(svc => new LoggerProvider(svc.GetRequiredService<CoreConfig>(), "bot"))
.AddScoped(svc => svc.GetRequiredService<LoggerProvider>().RootLogger.ForContext("EventId", svc.GetRequiredService<EventIdProvider>().EventId))
.AddMemoryCache()
.BuildServiceProvider();
logger.Information("Shutting down");
}
}
class Bot
{
private IServiceProvider _services;
private ILifetimeScope _services;
private DiscordShardedClient _client;
private Timer _updateTimer;
private IMetrics _metrics;
@@ -151,7 +111,7 @@ namespace PluralKit.Bot
private ILogger _logger;
private PKPerformanceEventListener _pl;
public Bot(IServiceProvider services, IDiscordClient client, IMetrics metrics, PeriodicStatCollector collector, ILogger logger)
public Bot(ILifetimeScope services, IDiscordClient client, IMetrics metrics, PeriodicStatCollector collector, ILogger logger)
{
_pl = new PKPerformanceEventListener();
_services = services;
@@ -167,12 +127,12 @@ namespace PluralKit.Bot
_client.ShardReady += ShardReady;
_client.Log += FrameworkLog;
_client.MessageReceived += (msg) => HandleEvent(s => s.AddMessageBreadcrumb(msg), eh => eh.HandleMessage(msg));
_client.ReactionAdded += (msg, channel, reaction) => HandleEvent(s => s.AddReactionAddedBreadcrumb(msg, channel, reaction), eh => eh.HandleReactionAdded(msg, channel, reaction));
_client.MessageDeleted += (msg, channel) => HandleEvent(s => s.AddMessageDeleteBreadcrumb(msg, channel), eh => eh.HandleMessageDeleted(msg, channel));
_client.MessagesBulkDeleted += (msgs, channel) => HandleEvent(s => s.AddMessageBulkDeleteBreadcrumb(msgs, channel), eh => eh.HandleMessagesBulkDelete(msgs, channel));
_client.MessageReceived += (msg) => HandleEvent(eh => eh.HandleMessage(msg));
_client.ReactionAdded += (msg, channel, reaction) => HandleEvent(eh => eh.HandleReactionAdded(msg, channel, reaction));
_client.MessageDeleted += (msg, channel) => HandleEvent(eh => eh.HandleMessageDeleted(msg, channel));
_client.MessagesBulkDeleted += (msgs, channel) => HandleEvent(eh => eh.HandleMessagesBulkDelete(msgs, channel));
_services.GetService<ShardInfoService>().Init(_client);
_services.Resolve<ShardInfoService>().Init(_client);
return Task.CompletedTask;
}
@@ -218,24 +178,24 @@ namespace PluralKit.Bot
private Task ShardReady(DiscordSocketClient shardClient)
{
_logger.Information("Shard {Shard} connected", shardClient.ShardId);
Console.WriteLine($"Shard #{shardClient.ShardId} connected to {shardClient.Guilds.Sum(g => g.Channels.Count)} channels in {shardClient.Guilds.Count} guilds.");
_logger.Information("Shard {Shard} connected to {ChannelCount} channels in {GuildCount} guilds", shardClient.ShardId, shardClient.Guilds.Sum(g => g.Channels.Count), shardClient.Guilds.Count);
if (shardClient.ShardId == 0)
{
_updateTimer = new Timer((_) => {
HandleEvent(s => s.AddPeriodicBreadcrumb(), __ => UpdatePeriodic());
HandleEvent(_ => UpdatePeriodic());
}, null, TimeSpan.Zero, TimeSpan.FromMinutes(1));
Console.WriteLine(
$"PluralKit started as {_client.CurrentUser.Username}#{_client.CurrentUser.Discriminator} ({_client.CurrentUser.Id}).");
}
_logger.Information("PluralKit started as {Username}#{Discriminator} ({Id})", _client.CurrentUser.Username, _client.CurrentUser.Discriminator, _client.CurrentUser.Id);
}
return Task.CompletedTask;
}
private Task HandleEvent(Action<Scope> breadcrumbFactory, Func<PKEventHandler, Task> handler)
private Task HandleEvent(Func<PKEventHandler, Task> handler)
{
_logger.Debug("Received event");
// Inner function so we can await the handler without stalling the entire pipeline
async Task Inner()
{
@@ -243,46 +203,36 @@ namespace PluralKit.Bot
// This prevents any synchronous nonsense from also stalling the pipeline before the first await point
await Task.Yield();
// Create a DI scope for this event
// and log the breadcrumb to the newly created (in-svc-scope) Sentry scope
using (var scope = _services.CreateScope())
{
var evtid = scope.ServiceProvider.GetService<EventIdProvider>().EventId;
try
{
await handler(scope.ServiceProvider.GetRequiredService<PKEventHandler>());
}
catch (Exception e)
{
var sentryScope = scope.ServiceProvider.GetRequiredService<Scope>();
sentryScope.SetTag("evtid", evtid.ToString());
breadcrumbFactory(sentryScope);
HandleRuntimeError(e, scope.ServiceProvider);
}
}
using var containerScope = _services.BeginLifetimeScope();
var sentryScope = containerScope.Resolve<Scope>();
var eventHandler = containerScope.Resolve<PKEventHandler>();
try
{
await handler(eventHandler);
}
catch (Exception e)
{
await HandleRuntimeError(eventHandler, e, sentryScope);
}
}
#pragma warning disable 4014
Inner();
#pragma warning restore 4014
var _ = Inner();
return Task.CompletedTask;
}
private void HandleRuntimeError(Exception e, IServiceProvider services)
private async Task HandleRuntimeError(PKEventHandler eventHandler, Exception exc, Scope scope)
{
var logger = services.GetRequiredService<ILogger>();
var scope = services.GetRequiredService<Scope>();
_logger.Error(exc, "Exception in bot event handler");
logger.Error(e, "Exception in bot event handler");
var evt = new SentryEvent(e);
var evt = new SentryEvent(exc);
// Don't blow out our Sentry budget on sporadic not-our-problem erorrs
if (e.IsOurProblem())
if (exc.IsOurProblem())
SentrySdk.CaptureEvent(evt, scope);
// Once we've sent it to Sentry, report it to the user
await eventHandler.ReportError(evt, exc);
}
}
@@ -292,28 +242,33 @@ namespace PluralKit.Bot
private IMetrics _metrics;
private DiscordShardedClient _client;
private DbConnectionFactory _connectionFactory;
private IServiceProvider _services;
private ILifetimeScope _services;
private CommandTree _tree;
private IDataStore _data;
private Scope _sentryScope;
public PKEventHandler(ProxyService proxy, ILogger logger, IMetrics metrics, IDiscordClient client, DbConnectionFactory connectionFactory, IServiceProvider services, CommandTree tree, IDataStore data)
// We're defining in the Autofac module that this class is instantiated with one instance per event
// This means that the HandleMessage function will either be called once, or not at all
// The ReportError function will be called on an error, and needs to refer back to the "trigger message"
// hence, we just store it in a local variable, ignoring it entirely if it's null.
private IUserMessage _msg = null;
public PKEventHandler(ProxyService proxy, ILogger logger, IMetrics metrics, DiscordShardedClient client, DbConnectionFactory connectionFactory, ILifetimeScope services, CommandTree tree, Scope sentryScope)
{
_proxy = proxy;
_logger = logger;
_metrics = metrics;
_client = (DiscordShardedClient) client;
_client = client;
_connectionFactory = connectionFactory;
_services = services;
_tree = tree;
_data = data;
_sentryScope = sentryScope;
}
public async Task HandleMessage(SocketMessage arg)
{
if (_client.GetShardFor((arg.Channel as IGuildChannel)?.Guild).ConnectionState != ConnectionState.Connected)
return; // Discard messages while the bot "catches up" to avoid unnecessary CPU pressure causing timeouts
RegisterMessageMetrics(arg);
// Ignore system messages (member joined, message pinned, etc)
@@ -323,6 +278,16 @@ namespace PluralKit.Bot
// Ignore bot messages
if (msg.Author.IsBot || msg.Author.IsWebhook) return;
// Add message info as Sentry breadcrumb
_msg = msg;
_sentryScope.AddBreadcrumb(msg.Content, "event.message", data: new Dictionary<string, string>
{
{"user", msg.Author.Id.ToString()},
{"channel", msg.Channel.Id.ToString()},
{"guild", ((msg.Channel as IGuildChannel)?.GuildId ?? 0).ToString()},
{"message", msg.Id.ToString()},
});
int argPos = -1;
// Check if message starts with the command prefix
if (msg.Content.StartsWith("pk;", StringComparison.InvariantCultureIgnoreCase)) argPos = 3;
@@ -349,18 +314,8 @@ namespace PluralKit.Bot
system = await conn.QueryFirstOrDefaultAsync<PKSystem>(
"select systems.* from systems, accounts where accounts.uid = @Id and systems.id = accounts.system",
new {Id = msg.Author.Id});
try
{
await _tree.ExecuteCommand(new Context(_services, msg, argPos, system));
}
catch (Exception e)
{
await HandleCommandError(msg, e);
// HandleCommandError only *reports* the error, we gotta pass it through to the parent
// error handler by rethrowing:
throw;
}
await _tree.ExecuteCommand(new Context(_services, msg, argPos, system));
}
else
{
@@ -376,16 +331,19 @@ namespace PluralKit.Bot
}
}
private async Task HandleCommandError(SocketUserMessage msg, Exception exception)
public async Task ReportError(SentryEvent evt, Exception exc)
{
// If we don't have a "trigger message", bail
if (_msg == null) return;
// This function *specifically* handles reporting a command execution error to the user.
// We'll fetch the event ID and send a user-facing error message.
// ONLY IF this error's actually our problem. As for what defines an error as "our problem",
// check the extension method :)
if (exception.IsOurProblem())
if (exc.IsOurProblem())
{
var eid = _services.GetService<EventIdProvider>().EventId;
await msg.Channel.SendMessageAsync(
var eid = evt.EventId;
await _msg.Channel.SendMessageAsync(
$"{Emojis.Error} Internal error occurred. Please join the support server (<https://discord.gg/PczBt78>), and send the developer this ID: `{eid}`\nBe sure to include a description of what you were doing to make the error occur.");
}
@@ -401,12 +359,43 @@ namespace PluralKit.Bot
}
public Task HandleReactionAdded(Cacheable<IUserMessage, ulong> message, ISocketMessageChannel channel,
SocketReaction reaction) => _proxy.HandleReactionAddedAsync(message, channel, reaction);
SocketReaction reaction)
{
_sentryScope.AddBreadcrumb("", "event.reaction", data: new Dictionary<string, string>()
{
{"user", reaction.UserId.ToString()},
{"channel", channel.Id.ToString()},
{"guild", ((channel as IGuildChannel)?.GuildId ?? 0).ToString()},
{"message", message.Id.ToString()},
{"reaction", reaction.Emote.Name}
});
return _proxy.HandleReactionAddedAsync(message, channel, reaction);
}
public Task HandleMessageDeleted(Cacheable<IMessage, ulong> message, ISocketMessageChannel channel) =>
_proxy.HandleMessageDeletedAsync(message, channel);
public Task HandleMessageDeleted(Cacheable<IMessage, ulong> message, ISocketMessageChannel channel)
{
_sentryScope.AddBreadcrumb("", "event.messageDelete", data: new Dictionary<string, string>()
{
{"channel", channel.Id.ToString()},
{"guild", ((channel as IGuildChannel)?.GuildId ?? 0).ToString()},
{"message", message.Id.ToString()},
});
return _proxy.HandleMessageDeletedAsync(message, channel);
}
public Task HandleMessagesBulkDelete(IReadOnlyCollection<Cacheable<IMessage, ulong>> messages,
IMessageChannel channel) => _proxy.HandleMessageBulkDeleteAsync(messages, channel);
IMessageChannel channel)
{
_sentryScope.AddBreadcrumb("", "event.messageDelete", data: new Dictionary<string, string>()
{
{"channel", channel.Id.ToString()},
{"guild", ((channel as IGuildChannel)?.GuildId ?? 0).ToString()},
{"messages", string.Join(",", messages.Select(m => m.Id))},
});
return _proxy.HandleMessageBulkDeleteAsync(messages, channel);
}
}
}