diff --git a/PluralKit.API/Controllers/SystemController.cs b/PluralKit.API/Controllers/SystemController.cs index 9ff9fd6b..05bc4223 100644 --- a/PluralKit.API/Controllers/SystemController.cs +++ b/PluralKit.API/Controllers/SystemController.cs @@ -40,10 +40,10 @@ namespace PluralKit.API public class SystemController : ControllerBase { private IDataStore _data; - private DbConnectionFactory _conn; + private Database _conn; private TokenAuthService _auth; - public SystemController(IDataStore data, DbConnectionFactory conn, TokenAuthService auth) + public SystemController(IDataStore data, Database conn, TokenAuthService auth) { _data = data; _conn = conn; diff --git a/PluralKit.Bot/Commands/Autoproxy.cs b/PluralKit.Bot/Commands/Autoproxy.cs index 0f73ad09..a89e9e9f 100644 --- a/PluralKit.Bot/Commands/Autoproxy.cs +++ b/PluralKit.Bot/Commands/Autoproxy.cs @@ -11,9 +11,9 @@ namespace PluralKit.Bot { public class Autoproxy { - private readonly DbConnectionFactory _db; + private readonly Database _db; - public Autoproxy(DbConnectionFactory db) + public Autoproxy(Database db) { _db = db; } diff --git a/PluralKit.Bot/Commands/MemberAvatar.cs b/PluralKit.Bot/Commands/MemberAvatar.cs index c425381c..1ddb8266 100644 --- a/PluralKit.Bot/Commands/MemberAvatar.cs +++ b/PluralKit.Bot/Commands/MemberAvatar.cs @@ -14,9 +14,9 @@ namespace PluralKit.Bot { public class MemberAvatar { - private readonly DbConnectionFactory _db; + private readonly Database _db; - public MemberAvatar(DbConnectionFactory db) + public MemberAvatar(Database db) { _db = db; } diff --git a/PluralKit.Bot/Commands/MemberEdit.cs b/PluralKit.Bot/Commands/MemberEdit.cs index 251d5494..2fc6019d 100644 --- a/PluralKit.Bot/Commands/MemberEdit.cs +++ b/PluralKit.Bot/Commands/MemberEdit.cs @@ -12,9 +12,9 @@ namespace PluralKit.Bot public class MemberEdit { private readonly IDataStore _data; - private readonly DbConnectionFactory _db; + private readonly Database _db; - public MemberEdit(IDataStore data, DbConnectionFactory db) + public MemberEdit(IDataStore data, Database db) { _data = data; _db = db; diff --git a/PluralKit.Bot/Commands/ServerConfig.cs b/PluralKit.Bot/Commands/ServerConfig.cs index b8cdb144..c0a130ec 100644 --- a/PluralKit.Bot/Commands/ServerConfig.cs +++ b/PluralKit.Bot/Commands/ServerConfig.cs @@ -13,9 +13,9 @@ namespace PluralKit.Bot { public class ServerConfig { - private DbConnectionFactory _db; + private Database _db; private LoggerCleanService _cleanService; - public ServerConfig(LoggerCleanService cleanService, DbConnectionFactory db) + public ServerConfig(LoggerCleanService cleanService, Database db) { _cleanService = cleanService; _db = db; diff --git a/PluralKit.Bot/Commands/SystemEdit.cs b/PluralKit.Bot/Commands/SystemEdit.cs index ac6b40b7..798bb5c5 100644 --- a/PluralKit.Bot/Commands/SystemEdit.cs +++ b/PluralKit.Bot/Commands/SystemEdit.cs @@ -18,10 +18,10 @@ namespace PluralKit.Bot public class SystemEdit { private IDataStore _data; - private DbConnectionFactory _db; + private Database _db; private EmbedService _embeds; - public SystemEdit(IDataStore data, EmbedService embeds, DbConnectionFactory db) + public SystemEdit(IDataStore data, EmbedService embeds, Database db) { _data = data; _embeds = embeds; diff --git a/PluralKit.Bot/Commands/SystemList.cs b/PluralKit.Bot/Commands/SystemList.cs index da2cccc4..64e1bbec 100644 --- a/PluralKit.Bot/Commands/SystemList.cs +++ b/PluralKit.Bot/Commands/SystemList.cs @@ -16,10 +16,10 @@ namespace PluralKit.Bot public class SystemList { private readonly IClock _clock; - private readonly DbConnectionFactory _db; + private readonly Database _db; private readonly ILogger _logger; - public SystemList(DbConnectionFactory db, ILogger logger, IClock clock) + public SystemList(Database db, ILogger logger, IClock clock) { _db = db; _logger = logger; diff --git a/PluralKit.Bot/Handlers/MessageCreated.cs b/PluralKit.Bot/Handlers/MessageCreated.cs index 59919cf5..87333b0f 100644 --- a/PluralKit.Bot/Handlers/MessageCreated.cs +++ b/PluralKit.Bot/Handlers/MessageCreated.cs @@ -22,12 +22,12 @@ namespace PluralKit.Bot private readonly IMetrics _metrics; private readonly ProxyService _proxy; private readonly ILifetimeScope _services; - private readonly DbConnectionFactory _db; + private readonly Database _db; private readonly IDataStore _data; public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean, IMetrics metrics, ProxyService proxy, DiscordShardedClient client, - CommandTree tree, ILifetimeScope services, DbConnectionFactory db, IDataStore data) + CommandTree tree, ILifetimeScope services, Database db, IDataStore data) { _lastMessageCache = lastMessageCache; _loggerClean = loggerClean; diff --git a/PluralKit.Bot/Handlers/MessageEdited.cs b/PluralKit.Bot/Handlers/MessageEdited.cs index 8791e554..b1423f3d 100644 --- a/PluralKit.Bot/Handlers/MessageEdited.cs +++ b/PluralKit.Bot/Handlers/MessageEdited.cs @@ -11,9 +11,9 @@ namespace PluralKit.Bot { private readonly LastMessageCacheService _lastMessageCache; private readonly ProxyService _proxy; - private readonly DbConnectionFactory _db; + private readonly Database _db; - public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, DbConnectionFactory db) + public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, Database db) { _lastMessageCache = lastMessageCache; _proxy = proxy; diff --git a/PluralKit.Bot/Proxy/ProxyService.cs b/PluralKit.Bot/Proxy/ProxyService.cs index 1971480d..3c1844c1 100644 --- a/PluralKit.Bot/Proxy/ProxyService.cs +++ b/PluralKit.Bot/Proxy/ProxyService.cs @@ -21,14 +21,14 @@ namespace PluralKit.Bot public static readonly TimeSpan MessageDeletionDelay = TimeSpan.FromMilliseconds(1000); private readonly LogChannelService _logChannel; - private readonly DbConnectionFactory _db; + private readonly Database _db; private readonly IDataStore _data; private readonly ILogger _logger; private readonly WebhookExecutorService _webhookExecutor; private readonly ProxyMatcher _matcher; public ProxyService(LogChannelService logChannel, IDataStore data, ILogger logger, - WebhookExecutorService webhookExecutor, DbConnectionFactory db, ProxyMatcher matcher) + WebhookExecutorService webhookExecutor, Database db, ProxyMatcher matcher) { _logChannel = logChannel; _data = data; @@ -43,7 +43,8 @@ namespace PluralKit.Bot if (!ShouldProxy(message, ctx)) return false; // Fetch members and try to match to a specific member - var members = (await _db.Execute(c => c.QueryProxyMembers(message.Author.Id, message.Channel.GuildId))).ToList(); + await using var conn = await _db.Obtain(); + var members = (await conn.QueryProxyMembers(message.Author.Id, message.Channel.GuildId)).ToList(); if (!_matcher.TryMatch(ctx, members, out var match, message.Content, message.Attachments.Count > 0, allowAutoproxy)) return false; @@ -52,7 +53,7 @@ namespace PluralKit.Bot if (!CheckProxyNameBoundsOrError(match.Member.ProxyName(ctx))) return false; // Everything's in order, we can execute the proxy! - await ExecuteProxy(message, ctx, match); + await ExecuteProxy(conn, message, ctx, match); return true; } @@ -78,29 +79,39 @@ namespace PluralKit.Bot return true; } - private async Task ExecuteProxy(DiscordMessage trigger, MessageContext ctx, ProxyMatch match) + private async Task ExecuteProxy(IPKConnection conn, DiscordMessage trigger, MessageContext ctx, + ProxyMatch match) { // Send the webhook var id = await _webhookExecutor.ExecuteWebhook(trigger.Channel, match.Member.ProxyName(ctx), match.Member.ProxyAvatar(ctx), match.Content, trigger.Attachments); - // Handle post-proxy actions - await _data.AddMessage(trigger.Author.Id, trigger.Channel.GuildId, trigger.Channel.Id, id, trigger.Id, - match.Member.Id); - await _logChannel.LogMessage(ctx, match, trigger, id); - - // Wait a second or so before deleting the original message - await Task.Delay(MessageDeletionDelay); - try + + Task SaveMessage() => _data.AddMessage(conn, trigger.Author.Id, trigger.Channel.GuildId, trigger.Channel.Id, id, trigger.Id, match.Member.Id); + Task LogMessage() => _logChannel.LogMessage(ctx, match, trigger, id).AsTask(); + async Task DeleteMessage() { - await trigger.DeleteAsync(); - } - catch (NotFoundException) - { - // If it's already deleted, we just log and swallow the exception - _logger.Warning("Attempted to delete already deleted proxy trigger message {Message}", trigger.Id); + // Wait a second or so before deleting the original message + await Task.Delay(MessageDeletionDelay); + try + { + await trigger.DeleteAsync(); + } + catch (NotFoundException) + { + // If it's already deleted, we just log and swallow the exception + _logger.Warning("Attempted to delete already deleted proxy trigger message {Message}", trigger.Id); + } } + + // Run post-proxy actions (simultaneously; order doesn't matter) + // Note that only AddMessage is using our passed-in connection, careful not to pass it elsewhere and run into conflicts + await Task.WhenAll( + DeleteMessage(), + SaveMessage(), + LogMessage() + ); } private async Task CheckBotPermissionsOrError(DiscordChannel channel) diff --git a/PluralKit.Bot/Services/EmbedService.cs b/PluralKit.Bot/Services/EmbedService.cs index 96b32cb4..7e469b9f 100644 --- a/PluralKit.Bot/Services/EmbedService.cs +++ b/PluralKit.Bot/Services/EmbedService.cs @@ -16,10 +16,10 @@ namespace PluralKit.Bot { public class EmbedService { private IDataStore _data; - private DbConnectionFactory _db; + private Database _db; private DiscordShardedClient _client; - public EmbedService(DiscordShardedClient client, IDataStore data, DbConnectionFactory db) + public EmbedService(DiscordShardedClient client, IDataStore data, Database db) { _client = client; _data = data; diff --git a/PluralKit.Bot/Services/LogChannelService.cs b/PluralKit.Bot/Services/LogChannelService.cs index 12099d33..93afafeb 100644 --- a/PluralKit.Bot/Services/LogChannelService.cs +++ b/PluralKit.Bot/Services/LogChannelService.cs @@ -14,12 +14,12 @@ using Serilog; namespace PluralKit.Bot { public class LogChannelService { private readonly EmbedService _embed; - private readonly DbConnectionFactory _db; + private readonly Database _db; private readonly IDataStore _data; private readonly ILogger _logger; private readonly DiscordRestClient _rest; - public LogChannelService(EmbedService embed, ILogger logger, DiscordRestClient rest, DbConnectionFactory db, IDataStore data) + public LogChannelService(EmbedService embed, ILogger logger, DiscordRestClient rest, Database db, IDataStore data) { _embed = embed; _rest = rest; diff --git a/PluralKit.Bot/Services/LoggerCleanService.cs b/PluralKit.Bot/Services/LoggerCleanService.cs index cf196a91..a4d9fdf3 100644 --- a/PluralKit.Bot/Services/LoggerCleanService.cs +++ b/PluralKit.Bot/Services/LoggerCleanService.cs @@ -53,10 +53,10 @@ namespace PluralKit.Bot .Where(b => b.WebhookName != null) .ToDictionary(b => b.WebhookName); - private DbConnectionFactory _db; + private Database _db; private DiscordShardedClient _client; - public LoggerCleanService(DbConnectionFactory db, DiscordShardedClient client) + public LoggerCleanService(Database db, DiscordShardedClient client) { _db = db; _client = client; diff --git a/PluralKit.Core/CoreConfig.cs b/PluralKit.Core/CoreConfig.cs index c33d1500..d036731a 100644 --- a/PluralKit.Core/CoreConfig.cs +++ b/PluralKit.Core/CoreConfig.cs @@ -1,3 +1,5 @@ +using Serilog.Events; + namespace PluralKit.Core { public class CoreConfig @@ -7,5 +9,8 @@ namespace PluralKit.Core public string InfluxUrl { get; set; } public string InfluxDb { get; set; } public string LogDir { get; set; } + + public LogEventLevel ConsoleLogLevel { get; set; } = LogEventLevel.Verbose; + public LogEventLevel FileLogLevel { get; set; } = LogEventLevel.Information; } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Database.cs b/PluralKit.Core/Database/Database.cs new file mode 100644 index 00000000..217845bd --- /dev/null +++ b/PluralKit.Core/Database/Database.cs @@ -0,0 +1,52 @@ +using System; +using System.Threading.Tasks; + +using App.Metrics; + +using Npgsql; + +using Serilog; + +namespace PluralKit.Core +{ + public class Database + { + private readonly CoreConfig _config; + private readonly ILogger _logger; + private readonly IMetrics _metrics; + private readonly DbConnectionCountHolder _countHolder; + + public Database(CoreConfig config, DbConnectionCountHolder countHolder, ILogger logger, + IMetrics metrics) + { + _config = config; + _countHolder = countHolder; + _metrics = metrics; + _logger = logger; + } + + public async Task Obtain() + { + // Mark the request (for a handle, I guess) in the metrics + _metrics.Measure.Meter.Mark(CoreMetrics.DatabaseRequests); + + // Create a connection and open it + // We wrap it in PKConnection for tracing purposes + var conn = new PKConnection(new NpgsqlConnection(_config.Database), _countHolder, _logger, _metrics); + await conn.OpenAsync(); + return conn; + } + + public async Task Execute(Func func) + { + await using var conn = await Obtain(); + await func(conn); + } + + public async Task Execute(Func> func) + { + await using var conn = await Obtain(); + return await func(conn); + } + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Functions/DatabaseFunctionsExt.cs b/PluralKit.Core/Database/Functions/DatabaseFunctionsExt.cs index fd7d7728..9800fd88 100644 --- a/PluralKit.Core/Database/Functions/DatabaseFunctionsExt.cs +++ b/PluralKit.Core/Database/Functions/DatabaseFunctionsExt.cs @@ -8,14 +8,14 @@ namespace PluralKit.Core { public static class DatabaseFunctionsExt { - public static Task QueryMessageContext(this IDbConnection conn, ulong account, ulong guild, ulong channel) + public static Task QueryMessageContext(this IPKConnection conn, ulong account, ulong guild, ulong channel) { return conn.QueryFirstAsync("message_context", new { account_id = account, guild_id = guild, channel_id = channel }, commandType: CommandType.StoredProcedure); } - public static Task> QueryProxyMembers(this IDbConnection conn, ulong account, ulong guild) + public static Task> QueryProxyMembers(this IPKConnection conn, ulong account, ulong guild) { return conn.QueryAsync("proxy_members", new { account_id = account, guild_id = guild }, diff --git a/PluralKit.Core/Database/Schemas.cs b/PluralKit.Core/Database/Schemas.cs index 280fb214..49ba79b1 100644 --- a/PluralKit.Core/Database/Schemas.cs +++ b/PluralKit.Core/Database/Schemas.cs @@ -16,10 +16,10 @@ namespace PluralKit.Core private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files private const int TargetSchemaVersion = 7; - private DbConnectionFactory _conn; + private Database _conn; private ILogger _logger; - public Schemas(DbConnectionFactory conn, ILogger logger) + public Schemas(Database conn, ILogger logger) { _conn = conn; _logger = logger.ForContext(); @@ -36,7 +36,7 @@ namespace PluralKit.Core { // Run everything in a transaction await using var conn = await _conn.Obtain(); - using var tx = conn.BeginTransaction(); + await using var tx = await conn.BeginTransactionAsync(); // Before applying migrations, clean out views/functions to prevent type errors await ExecuteSqlFile($"{RootPath}.clean.sql", conn, tx); @@ -49,10 +49,10 @@ namespace PluralKit.Core await ExecuteSqlFile($"{RootPath}.Functions.functions.sql", conn, tx); // Finally, commit tx - tx.Commit(); + await tx.CommitAsync(); } - private async Task ApplyMigrations(IAsyncDbConnection conn, IDbTransaction tx) + private async Task ApplyMigrations(IPKConnection conn, IDbTransaction tx) { var currentVersion = await GetCurrentDatabaseVersion(conn); _logger.Information("Current schema version: {CurrentVersion}", currentVersion); @@ -63,7 +63,7 @@ namespace PluralKit.Core } } - private async Task ExecuteSqlFile(string resourceName, IDbConnection conn, IDbTransaction tx = null) + private async Task ExecuteSqlFile(string resourceName, IPKConnection conn, IDbTransaction tx = null) { await using var stream = typeof(Schemas).Assembly.GetManifestResourceStream(resourceName); if (stream == null) throw new ArgumentException($"Invalid resource name '{resourceName}'"); @@ -76,10 +76,10 @@ namespace PluralKit.Core // If the above creates new enum/composite types, we must tell Npgsql to reload the internal type caches // This will propagate to every other connection as well, since it marks the global type mapper collection dirty. // TODO: find a way to get around the cast to our internal tracker wrapper... this could break if that ever changes - ((PerformanceTrackingConnection) conn)._impl.ReloadTypes(); + conn.ReloadTypes(); } - private async Task GetCurrentDatabaseVersion(IDbConnection conn) + private async Task GetCurrentDatabaseVersion(IPKConnection conn) { // First, check if the "info" table exists (it may not, if this is a *really* old database) var hasInfoTable = diff --git a/PluralKit.Core/Database/Wrappers/IPKCommand.cs b/PluralKit.Core/Database/Wrappers/IPKCommand.cs new file mode 100644 index 00000000..3f814d6c --- /dev/null +++ b/PluralKit.Core/Database/Wrappers/IPKCommand.cs @@ -0,0 +1,17 @@ +using System; +using System.Data; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; + +namespace PluralKit.Core +{ + public interface IPKCommand: IDbCommand, IAsyncDisposable + { + public Task PrepareAsync(CancellationToken ct = default); + public Task ExecuteNonQueryAsync(CancellationToken ct = default); + public Task ExecuteScalarAsync(CancellationToken ct = default); + public Task ExecuteReaderAsync(CancellationToken ct = default); + public Task ExecuteReaderAsync(CommandBehavior behavior, CancellationToken ct = default); + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Wrappers/IPKConnection.cs b/PluralKit.Core/Database/Wrappers/IPKConnection.cs new file mode 100644 index 00000000..b83a5b77 --- /dev/null +++ b/PluralKit.Core/Database/Wrappers/IPKConnection.cs @@ -0,0 +1,34 @@ +using System; +using System.Data; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; + +using Npgsql; + +namespace PluralKit.Core +{ + public interface IPKConnection: IDbConnection, IAsyncDisposable + { + public Guid ConnectionId { get; } + + public Task OpenAsync(CancellationToken cancellationToken = default); + public Task CloseAsync(); + + public Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default); + + public ValueTask BeginTransactionAsync(CancellationToken ct = default) => BeginTransactionAsync(IsolationLevel.Unspecified, ct); + public ValueTask BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default); + + public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand); + public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand); + + public void ReloadTypes(); + + [Obsolete] new void Open(); + [Obsolete] new void Close(); + + [Obsolete] new IDbTransaction BeginTransaction(); + [Obsolete] new IDbTransaction BeginTransaction(IsolationLevel il); + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Wrappers/PKCommand.cs b/PluralKit.Core/Database/Wrappers/PKCommand.cs new file mode 100644 index 00000000..ed048750 --- /dev/null +++ b/PluralKit.Core/Database/Wrappers/PKCommand.cs @@ -0,0 +1,117 @@ +#nullable enable +using System; +using System.Data; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; + +using App.Metrics; + +using NodaTime; + +using Npgsql; + +using Serilog; + +namespace PluralKit.Core +{ + public class PKCommand: DbCommand, IPKCommand + { + private readonly NpgsqlCommand _inner; + private readonly PKConnection _ourConnection; + private readonly ILogger _logger; + private readonly IMetrics _metrics; + + public PKCommand(NpgsqlCommand inner, PKConnection ourConnection, ILogger logger, IMetrics metrics) + { + _inner = inner; + _ourConnection = ourConnection; + _logger = logger.ForContext(); + _metrics = metrics; + } + + public override int ExecuteNonQuery() => throw SyncError(nameof(ExecuteNonQuery)); + public override object ExecuteScalar() => throw SyncError(nameof(ExecuteScalar)); + protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => throw SyncError(nameof(ExecuteDbDataReader)); + + public override Task ExecuteNonQueryAsync(CancellationToken ct) => LogQuery(_inner.ExecuteNonQueryAsync(ct)); + public override Task ExecuteScalarAsync(CancellationToken ct) => LogQuery(_inner.ExecuteScalarAsync(ct)); + protected override async Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken ct) => await LogQuery(_inner.ExecuteReaderAsync(behavior, ct)); + + public override void Prepare() => _inner.Prepare(); + public override void Cancel() => _inner.Cancel(); + protected override DbParameter CreateDbParameter() => _inner.CreateParameter(); + + public override string CommandText + { + get => _inner.CommandText; + set => _inner.CommandText = value; + } + + public override int CommandTimeout + { + get => _inner.CommandTimeout; + set => _inner.CommandTimeout = value; + } + + public override CommandType CommandType + { + get => _inner.CommandType; + set => _inner.CommandType = value; + } + + public override UpdateRowSource UpdatedRowSource + { + get => _inner.UpdatedRowSource; + set => _inner.UpdatedRowSource = value; + } + + protected override DbParameterCollection DbParameterCollection => _inner.Parameters; + protected override DbTransaction? DbTransaction + { + get => _inner.Transaction; + set => _inner.Transaction = (NpgsqlTransaction?) value; + } + + public override bool DesignTimeVisible + { + get => _inner.DesignTimeVisible; + set => _inner.DesignTimeVisible = value; + } + + protected override DbConnection? DbConnection + { + get => _inner.Connection; + set => + _inner.Connection = value switch + { + NpgsqlConnection npg => npg, + PKConnection pk => pk.Inner, + _ => throw new ArgumentException($"Can't convert input type {value?.GetType()} to NpgsqlConnection") + }; + } + + private async Task LogQuery(Task task) + { + var start = SystemClock.Instance.GetCurrentInstant(); + try + { + return await task; + } + finally + { + var end = SystemClock.Instance.GetCurrentInstant(); + var elapsed = end - start; + + _logger.Verbose("Executed query {Query} in {ElapsedTime} on connection {ConnectionId}", CommandText, elapsed, _ourConnection.ConnectionId); + + // One "BCL compatible tick" is 100 nanoseconds + var micros = elapsed.BclCompatibleTicks / 10; + _metrics.Provider.Timer.Instance(CoreMetrics.DatabaseQuery, new MetricTags("query", CommandText)) + .Record(micros, TimeUnit.Microseconds, CommandText); + } + } + + private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IPKCommand function {caller}!"); + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Wrappers/PKConnection.cs b/PluralKit.Core/Database/Wrappers/PKConnection.cs new file mode 100644 index 00000000..46947fcc --- /dev/null +++ b/PluralKit.Core/Database/Wrappers/PKConnection.cs @@ -0,0 +1,107 @@ +#nullable enable +using System; +using System.Data; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; + +using App.Metrics; + +using NodaTime; + +using Npgsql; + +using Serilog; + +namespace PluralKit.Core +{ + public class PKConnection: DbConnection, IPKConnection + { + public NpgsqlConnection Inner { get; } + public Guid ConnectionId { get; } + + private readonly DbConnectionCountHolder _countHolder; + private readonly ILogger _logger; + private readonly IMetrics _metrics; + + private bool _hasOpened; + private bool _hasClosed; + private Instant _openTime; + + public PKConnection(NpgsqlConnection inner, DbConnectionCountHolder countHolder, ILogger logger, IMetrics metrics) + { + Inner = inner; + ConnectionId = Guid.NewGuid(); + _countHolder = countHolder; + _logger = logger.ForContext(); + _metrics = metrics; + } + + public override Task OpenAsync(CancellationToken ct) + { + if (_hasOpened) return Inner.OpenAsync(ct); + _countHolder.Increment(); + _hasOpened = true; + _openTime = SystemClock.Instance.GetCurrentInstant(); + _logger.Verbose("Opened database connection {ConnectionId}, new connection count {ConnectionCount}", ConnectionId, _countHolder.ConnectionCount); + return Inner.OpenAsync(ct); + } + + public override Task CloseAsync() => Inner.CloseAsync(); + + protected override DbCommand CreateDbCommand() => new PKCommand(Inner.CreateCommand(), this, _logger, _metrics); + + public void ReloadTypes() => Inner.ReloadTypes(); + + public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand) => Inner.BeginBinaryImport(copyFromCommand); + public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand) => Inner.BeginBinaryExport(copyToCommand); + + public override void ChangeDatabase(string databaseName) => Inner.ChangeDatabase(databaseName); + public override Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default) => Inner.ChangeDatabaseAsync(databaseName, ct); + protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw SyncError(nameof(BeginDbTransaction)); + protected override async ValueTask BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => await Inner.BeginTransactionAsync(level, ct); + + public override void Open() => throw SyncError(nameof(Open)); + public override void Close() => throw SyncError(nameof(Close)); + + IDbTransaction IPKConnection.BeginTransaction() => throw SyncError(nameof(BeginTransaction)); + IDbTransaction IPKConnection.BeginTransaction(IsolationLevel level) => throw SyncError(nameof(BeginTransaction)); + + public override string ConnectionString + { + get => Inner.ConnectionString; + set => Inner.ConnectionString = value; + } + + public override string? Database => Inner.Database; + public override ConnectionState State => Inner.State; + public override string DataSource => Inner.DataSource; + public override string ServerVersion => Inner.ServerVersion; + + protected override void Dispose(bool disposing) + { + Inner.Dispose(); + if (_hasClosed) return; + + LogClose(); + } + + public override ValueTask DisposeAsync() + { + if (_hasClosed) return Inner.DisposeAsync(); + LogClose(); + return Inner.DisposeAsync(); + } + + private void LogClose() + { + _countHolder.Decrement(); + _hasClosed = true; + + var duration = SystemClock.Instance.GetCurrentInstant() - _openTime; + _logger.Verbose("Closed database connection {ConnectionId} (open for {ConnectionDuration}), new connection count {ConnectionCount}", ConnectionId, duration, _countHolder.ConnectionCount); + } + + private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IPKConnection function {caller}!"); + } +} \ No newline at end of file diff --git a/PluralKit.Core/Models/ModelQueryExt.cs b/PluralKit.Core/Models/ModelQueryExt.cs index 3cac2599..726bdde6 100644 --- a/PluralKit.Core/Models/ModelQueryExt.cs +++ b/PluralKit.Core/Models/ModelQueryExt.cs @@ -1,5 +1,4 @@ #nullable enable -using System; using System.Data; using System.Threading.Tasks; @@ -9,19 +8,19 @@ namespace PluralKit.Core { public static class ModelQueryExt { - public static Task QueryMember(this IDbConnection conn, int id) => + public static Task QueryMember(this IPKConnection conn, int id) => conn.QueryFirstOrDefaultAsync("select * from members where id = @id", new {id}); - public static Task QueryOrInsertGuildConfig(this IDbConnection conn, ulong guild) => + public static Task QueryOrInsertGuildConfig(this IPKConnection conn, ulong guild) => conn.QueryFirstAsync("insert into servers (id) values (@Guild) on conflict do nothing returning *", new {Guild = guild}); - public static Task QueryOrInsertSystemGuildConfig(this IDbConnection conn, ulong guild, int system) => + public static Task QueryOrInsertSystemGuildConfig(this IPKConnection conn, ulong guild, int system) => conn.QueryFirstAsync( "insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *", new {guild, system}); public static Task QueryOrInsertMemberGuildConfig( - this IDbConnection conn, ulong guild, int member) => + this IPKConnection conn, ulong guild, int member) => conn.QueryFirstAsync( "insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *", new {guild, member}); diff --git a/PluralKit.Core/Modules.cs b/PluralKit.Core/Modules.cs index bd4b82db..6229286c 100644 --- a/PluralKit.Core/Modules.cs +++ b/PluralKit.Core/Modules.cs @@ -23,7 +23,7 @@ namespace PluralKit.Core protected override void Load(ContainerBuilder builder) { builder.RegisterType().SingleInstance(); - builder.RegisterType().AsSelf().SingleInstance(); + builder.RegisterType().AsSelf().SingleInstance(); builder.RegisterType().AsSelf().As(); builder.RegisterType().AsSelf(); @@ -99,7 +99,7 @@ namespace PluralKit.Core return new LoggerConfiguration() .ConfigureForNodaTime(DateTimeZoneProviders.Tzdb) - .MinimumLevel.Debug() + .MinimumLevel.Is(config.ConsoleLogLevel) .WriteTo.Async(a => { // Both the same output, except one is raw compact JSON and one is plain text. @@ -110,7 +110,7 @@ namespace PluralKit.Core outputTemplate: outputTemplate, rollingInterval: RollingInterval.Day, flushToDiskInterval: TimeSpan.FromMilliseconds(50), - restrictedToMinimumLevel: LogEventLevel.Information, + restrictedToMinimumLevel: config.FileLogLevel, formatProvider: new UTCTimestampFormatProvider(), buffered: true); @@ -119,7 +119,7 @@ namespace PluralKit.Core (config.LogDir ?? "logs") + $"/pluralkit.{_component}.json", rollingInterval: RollingInterval.Day, flushToDiskInterval: TimeSpan.FromMilliseconds(50), - restrictedToMinimumLevel: LogEventLevel.Information, + restrictedToMinimumLevel: config.FileLogLevel, buffered: true); }) // TODO: render as UTC in the console, too? or just in log files diff --git a/PluralKit.Core/Services/DataFileService.cs b/PluralKit.Core/Services/DataFileService.cs index cce2effa..6cd7cb87 100644 --- a/PluralKit.Core/Services/DataFileService.cs +++ b/PluralKit.Core/Services/DataFileService.cs @@ -15,10 +15,10 @@ namespace PluralKit.Core public class DataFileService { private IDataStore _data; - private DbConnectionFactory _db; + private Database _db; private ILogger _logger; - public DataFileService(ILogger logger, IDataStore data, DbConnectionFactory db) + public DataFileService(ILogger logger, IDataStore data, Database db) { _data = data; _db = db; @@ -127,8 +127,8 @@ namespace PluralKit.Core await _data.SaveSystem(system); // -- Member/switch import -- - await using var conn = (PerformanceTrackingConnection) await _db.Obtain(); - await using (var imp = await BulkImporter.Begin(system, conn._impl)) + await using var conn = await _db.Obtain(); + await using (var imp = await BulkImporter.Begin(system, conn)) { // Tally up the members that didn't exist before, and check member count on import // If creating the unmatched members would put us over the member limit, abort before creating any members diff --git a/PluralKit.Core/Services/IDataStore.cs b/PluralKit.Core/Services/IDataStore.cs index 6c90336b..94f2aec9 100644 --- a/PluralKit.Core/Services/IDataStore.cs +++ b/PluralKit.Core/Services/IDataStore.cs @@ -201,7 +201,7 @@ namespace PluralKit.Core { /// The ID of the original trigger message containing the proxy tags. /// The member (and by extension system) that was proxied. /// - Task AddMessage(ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId); + Task AddMessage(IPKConnection conn, ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId); /// /// Deletes a message from the data store. diff --git a/PluralKit.Core/Services/PostgresDataStore.cs b/PluralKit.Core/Services/PostgresDataStore.cs index 80775cef..70562071 100644 --- a/PluralKit.Core/Services/PostgresDataStore.cs +++ b/PluralKit.Core/Services/PostgresDataStore.cs @@ -1,4 +1,4 @@ -using System.Collections.Generic; +using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -10,10 +10,10 @@ using Serilog; namespace PluralKit.Core { public class PostgresDataStore: IDataStore { - private DbConnectionFactory _conn; + private Database _conn; private ILogger _logger; - public PostgresDataStore(DbConnectionFactory conn, ILogger logger) + public PostgresDataStore(Database conn, ILogger logger) { _conn = conn; _logger = logger; @@ -182,17 +182,17 @@ namespace PluralKit.Core { using (var conn = await _conn.Obtain()) return await conn.ExecuteScalarAsync("select count(id) from members"); } - public async Task AddMessage(ulong senderId, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId) { - using (var conn = await _conn.Obtain()) - // "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before - await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@MessageId, @GuildId, @ChannelId, @MemberId, @SenderId, @OriginalMid) on conflict do nothing", new { - MessageId = postedMessageId, - GuildId = guildId, - ChannelId = channelId, - MemberId = proxiedMemberId, - SenderId = senderId, - OriginalMid = triggerMessageId - }); + + public async Task AddMessage(IPKConnection conn, ulong senderId, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId) { + // "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before + await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@MessageId, @GuildId, @ChannelId, @MemberId, @SenderId, @OriginalMid) on conflict do nothing", new { + MessageId = postedMessageId, + GuildId = guildId, + ChannelId = channelId, + MemberId = proxiedMemberId, + SenderId = senderId, + OriginalMid = triggerMessageId + }); _logger.Debug("Stored message {Message} in channel {Channel}", postedMessageId, channelId); } @@ -235,27 +235,26 @@ namespace PluralKit.Core { public async Task AddSwitch(PKSystem system, IEnumerable members) { // Use a transaction here since we're doing multiple executed commands in one - using (var conn = await _conn.Obtain()) - using (var tx = conn.BeginTransaction()) + await using var conn = await _conn.Obtain(); + using var tx = await conn.BeginTransactionAsync(); + + // First, we insert the switch itself + var sw = await conn.QuerySingleAsync("insert into switches(system) values (@System) returning *", + new {System = system.Id}); + + // Then we insert each member in the switch in the switch_members table + // TODO: can we parallelize this or send it in bulk somehow? + foreach (var member in members) { - // First, we insert the switch itself - var sw = await conn.QuerySingleAsync("insert into switches(system) values (@System) returning *", - new {System = system.Id}); - - // Then we insert each member in the switch in the switch_members table - // TODO: can we parallelize this or send it in bulk somehow? - foreach (var member in members) - { - await conn.ExecuteAsync( - "insert into switch_members(switch, member) values(@Switch, @Member)", - new {Switch = sw.Id, Member = member.Id}); - } - - // Finally we commit the tx, since the using block will otherwise rollback it - tx.Commit(); - - _logger.Information("Registered switch {Switch} in system {System} with members {@Members}", sw.Id, system.Id, members.Select(m => m.Id)); + await conn.ExecuteAsync( + "insert into switch_members(switch, member) values(@Switch, @Member)", + new {Switch = sw.Id, Member = member.Id}); } + + // Finally we commit the tx, since the using block will otherwise rollback it + tx.Commit(); + + _logger.Information("Registered switch {Switch} in system {System} with members {@Members}", sw.Id, system.Id, members.Select(m => m.Id)); } public IAsyncEnumerable GetSwitches(PKSystem system) @@ -276,8 +275,8 @@ namespace PluralKit.Core { public async IAsyncEnumerable GetSwitchMembersList(PKSystem system, Instant start, Instant end) { // Wrap multiple commands in a single transaction for performance - using var conn = await _conn.Obtain(); - using var tx = conn.BeginTransaction(); + await using var conn = await _conn.Obtain(); + await using var tx = await conn.BeginTransactionAsync(); // Find the time of the last switch outside the range as it overlaps the range // If no prior switch exists, the lower bound of the range remains the start time diff --git a/PluralKit.Core/Utils/BulkImporter.cs b/PluralKit.Core/Utils/BulkImporter.cs index 99323097..196f1267 100644 --- a/PluralKit.Core/Utils/BulkImporter.cs +++ b/PluralKit.Core/Utils/BulkImporter.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Data; +using System.Data.Common; using System.Linq; using System.Threading.Tasks; @@ -19,20 +20,20 @@ namespace PluralKit.Core public class BulkImporter: IAsyncDisposable { private readonly int _systemId; - private readonly NpgsqlConnection _conn; - private readonly NpgsqlTransaction _tx; + private readonly IPKConnection _conn; + private readonly DbTransaction _tx; private readonly Dictionary _knownMembers = new Dictionary(); private readonly Dictionary _existingMembersByHid = new Dictionary(); private readonly Dictionary _existingMembersByName = new Dictionary(); - private BulkImporter(int systemId, NpgsqlConnection conn, NpgsqlTransaction tx) + private BulkImporter(int systemId, IPKConnection conn, DbTransaction tx) { _systemId = systemId; _conn = conn; _tx = tx; } - public static async Task Begin(PKSystem system, NpgsqlConnection conn) + public static async Task Begin(PKSystem system, IPKConnection conn) { var tx = await conn.BeginTransactionAsync(); var importer = new BulkImporter(system.Id, conn, tx); diff --git a/PluralKit.Core/Utils/ConnectionUtils.cs b/PluralKit.Core/Utils/ConnectionUtils.cs index 544a8070..d133b8d5 100644 --- a/PluralKit.Core/Utils/ConnectionUtils.cs +++ b/PluralKit.Core/Utils/ConnectionUtils.cs @@ -7,21 +7,21 @@ using Dapper; namespace PluralKit.Core { public static class ConnectionUtils { - public static async IAsyncEnumerable QueryStreamAsync(this DbConnectionFactory connFactory, string sql, object param) + public static async IAsyncEnumerable QueryStreamAsync(this Database connFactory, string sql, object param) { - using var conn = await connFactory.Obtain(); + await using var conn = await connFactory.Obtain(); await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param); var parser = reader.GetRowParser(); - while (reader.Read()) + while (await reader.ReadAsync()) yield return parser(reader); } - public static async IAsyncEnumerable QueryStreamAsync(this IDbConnection conn, string sql, object param) + public static async IAsyncEnumerable QueryStreamAsync(this IPKConnection conn, string sql, object param) { await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param); var parser = reader.GetRowParser(); - while (reader.Read()) + while (await reader.ReadAsync()) yield return parser(reader); } } diff --git a/PluralKit.Core/Utils/DatabaseUtils.cs b/PluralKit.Core/Utils/DatabaseUtils.cs index 4ef869eb..70849f31 100644 --- a/PluralKit.Core/Utils/DatabaseUtils.cs +++ b/PluralKit.Core/Utils/DatabaseUtils.cs @@ -1,234 +1,10 @@ -using System; -using System.Data; -using System.Data.Common; -using System.Diagnostics; +using System.Data; using System.Threading; -using System.Threading.Tasks; - -using App.Metrics; using Dapper; -using Npgsql; - -using Serilog; - namespace PluralKit.Core { - public class QueryLogger : IDisposable - { - private ILogger _logger; - private IMetrics _metrics; - private string _commandText; - private Stopwatch _stopwatch; - - public QueryLogger(ILogger logger, IMetrics metrics, string commandText) - { - _metrics = metrics; - _commandText = commandText; - _logger = logger; - - _stopwatch = new Stopwatch(); - _stopwatch.Start(); - } - - public void Dispose() - { - _stopwatch.Stop(); - _logger.Verbose("Executed query {Query} in {ElapsedTime}", _commandText, _stopwatch.Elapsed); - - // One tick is 100 nanoseconds - _metrics.Provider.Timer.Instance(CoreMetrics.DatabaseQuery, new MetricTags("query", _commandText)) - .Record(_stopwatch.ElapsedTicks / 10, TimeUnit.Microseconds, _commandText); - } - } - - public class PerformanceTrackingCommand: DbCommand - { - private NpgsqlCommand _impl; - private ILogger _logger; - private IMetrics _metrics; - - public PerformanceTrackingCommand(NpgsqlCommand impl, ILogger logger, IMetrics metrics) - { - _impl = impl; - _metrics = metrics; - _logger = logger; - } - - public override void Cancel() - { - _impl.Cancel(); - } - - public override int ExecuteNonQuery() - { - return _impl.ExecuteNonQuery(); - } - - public override object ExecuteScalar() - { - return _impl.ExecuteScalar(); - } - - public override void Prepare() - { - _impl.Prepare(); - } - - public override string CommandText - { - get => _impl.CommandText; - set => _impl.CommandText = value; - } - - public override int CommandTimeout - { - get => _impl.CommandTimeout; - set => _impl.CommandTimeout = value; - } - - public override CommandType CommandType - { - get => _impl.CommandType; - set => _impl.CommandType = value; - } - - public override UpdateRowSource UpdatedRowSource - { - get => _impl.UpdatedRowSource; - set => _impl.UpdatedRowSource = value; - } - - protected override DbConnection DbConnection - { - get => _impl.Connection; - set => _impl.Connection = (NpgsqlConnection) value; - } - - protected override DbParameterCollection DbParameterCollection => _impl.Parameters; - - protected override DbTransaction DbTransaction - { - get => _impl.Transaction; - set => _impl.Transaction = (NpgsqlTransaction) value; - } - - public override bool DesignTimeVisible - { - get => _impl.DesignTimeVisible; - set => _impl.DesignTimeVisible = value; - } - - protected override DbParameter CreateDbParameter() - { - return _impl.CreateParameter(); - } - - protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) - { - return _impl.ExecuteReader(behavior); - } - - private IDisposable LogQuery() - { - return new QueryLogger(_logger, _metrics, CommandText); - } - - protected override async Task ExecuteDbDataReaderAsync( - CommandBehavior behavior, CancellationToken cancellationToken) - { - using (LogQuery()) - return await _impl.ExecuteReaderAsync(behavior, cancellationToken); - } - - public override async Task ExecuteNonQueryAsync(CancellationToken cancellationToken) - { - using (LogQuery()) - return await _impl.ExecuteNonQueryAsync(cancellationToken); - } - - public override async Task ExecuteScalarAsync(CancellationToken cancellationToken) - { - using (LogQuery()) - return await _impl.ExecuteScalarAsync(cancellationToken); - } - } - - public class PerformanceTrackingConnection: IAsyncDbConnection - { - // Simple delegation of everything. - internal NpgsqlConnection _impl; - - private DbConnectionCountHolder _countHolder; - private ILogger _logger; - private IMetrics _metrics; - - public PerformanceTrackingConnection(NpgsqlConnection impl, DbConnectionCountHolder countHolder, - ILogger logger, IMetrics metrics) - { - _impl = impl; - _countHolder = countHolder; - _logger = logger; - _metrics = metrics; - } - - public void Dispose() - { - _impl.Dispose(); - - _countHolder.Decrement(); - } - - public IDbTransaction BeginTransaction() - { - return _impl.BeginTransaction(); - } - - public IDbTransaction BeginTransaction(IsolationLevel il) - { - return _impl.BeginTransaction(il); - } - - public void ChangeDatabase(string databaseName) - { - _impl.ChangeDatabase(databaseName); - } - - public void Close() - { - _impl.Close(); - } - - public IDbCommand CreateCommand() - { - return new PerformanceTrackingCommand(_impl.CreateCommand(), _logger, _metrics); - } - - public void Open() - { - _impl.Open(); - } - - public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand) - { - return _impl.BeginBinaryImport(copyFromCommand); - } - - public string ConnectionString - { - get => _impl.ConnectionString; - set => _impl.ConnectionString = value; - } - - public int ConnectionTimeout => _impl.ConnectionTimeout; - - public string Database => _impl.Database; - - public ConnectionState State => _impl.State; - public ValueTask DisposeAsync() => _impl.DisposeAsync(); - } - public class DbConnectionCountHolder { private int _connectionCount; @@ -245,43 +21,6 @@ namespace PluralKit.Core } } - public interface IAsyncDbConnection: IDbConnection, IAsyncDisposable - { - - } - - public class DbConnectionFactory - { - private CoreConfig _config; - private ILogger _logger; - private IMetrics _metrics; - private DbConnectionCountHolder _countHolder; - - public DbConnectionFactory(CoreConfig config, DbConnectionCountHolder countHolder, ILogger logger, - IMetrics metrics) - { - _config = config; - _countHolder = countHolder; - _metrics = metrics; - _logger = logger; - } - - public async Task Obtain() - { - // Mark the request (for a handle, I guess) in the metrics - _metrics.Measure.Meter.Mark(CoreMetrics.DatabaseRequests); - - // Actually create and try to open the connection - var conn = new NpgsqlConnection(_config.Database); - await conn.OpenAsync(); - - // Increment the count - _countHolder.Increment(); - // Return a wrapped connection which will decrement the counter on dispose - return new PerformanceTrackingConnection(conn, _countHolder, _logger, _metrics); - } - } - public class PassthroughTypeHandler: SqlMapper.TypeHandler { public override void SetValue(IDbDataParameter parameter, T value) @@ -308,18 +47,4 @@ namespace PluralKit.Core parameter.Value = (long) value; } } - - public static class DatabaseExt - { - public static async Task Execute(this DbConnectionFactory db, Func func) - { - await using var conn = await db.Obtain(); - await func(conn); - } - public static async Task Execute(this DbConnectionFactory db, Func> func) - { - await using var conn = await db.Obtain(); - return await func(conn); - } - } } \ No newline at end of file