Several more database-y refactors

- DbConnectionFactory renamed to "Database", will now be the primary entry point for DB stuff
- Created IPKConnection interface mostly containing async extensions to IDbConnection, use this going forward
- Reworked the Connection/Command wrappers (that have performance/logging extensions)
- Probably more stuff that I forgot???
This commit is contained in:
Ske 2020-06-13 18:31:20 +02:00
parent a915ddb41c
commit e176ccbab5
29 changed files with 454 additions and 387 deletions

View File

@ -40,10 +40,10 @@ namespace PluralKit.API
public class SystemController : ControllerBase public class SystemController : ControllerBase
{ {
private IDataStore _data; private IDataStore _data;
private DbConnectionFactory _conn; private Database _conn;
private TokenAuthService _auth; private TokenAuthService _auth;
public SystemController(IDataStore data, DbConnectionFactory conn, TokenAuthService auth) public SystemController(IDataStore data, Database conn, TokenAuthService auth)
{ {
_data = data; _data = data;
_conn = conn; _conn = conn;

View File

@ -11,9 +11,9 @@ namespace PluralKit.Bot
{ {
public class Autoproxy public class Autoproxy
{ {
private readonly DbConnectionFactory _db; private readonly Database _db;
public Autoproxy(DbConnectionFactory db) public Autoproxy(Database db)
{ {
_db = db; _db = db;
} }

View File

@ -14,9 +14,9 @@ namespace PluralKit.Bot
{ {
public class MemberAvatar public class MemberAvatar
{ {
private readonly DbConnectionFactory _db; private readonly Database _db;
public MemberAvatar(DbConnectionFactory db) public MemberAvatar(Database db)
{ {
_db = db; _db = db;
} }

View File

@ -12,9 +12,9 @@ namespace PluralKit.Bot
public class MemberEdit public class MemberEdit
{ {
private readonly IDataStore _data; private readonly IDataStore _data;
private readonly DbConnectionFactory _db; private readonly Database _db;
public MemberEdit(IDataStore data, DbConnectionFactory db) public MemberEdit(IDataStore data, Database db)
{ {
_data = data; _data = data;
_db = db; _db = db;

View File

@ -13,9 +13,9 @@ namespace PluralKit.Bot
{ {
public class ServerConfig public class ServerConfig
{ {
private DbConnectionFactory _db; private Database _db;
private LoggerCleanService _cleanService; private LoggerCleanService _cleanService;
public ServerConfig(LoggerCleanService cleanService, DbConnectionFactory db) public ServerConfig(LoggerCleanService cleanService, Database db)
{ {
_cleanService = cleanService; _cleanService = cleanService;
_db = db; _db = db;

View File

@ -18,10 +18,10 @@ namespace PluralKit.Bot
public class SystemEdit public class SystemEdit
{ {
private IDataStore _data; private IDataStore _data;
private DbConnectionFactory _db; private Database _db;
private EmbedService _embeds; private EmbedService _embeds;
public SystemEdit(IDataStore data, EmbedService embeds, DbConnectionFactory db) public SystemEdit(IDataStore data, EmbedService embeds, Database db)
{ {
_data = data; _data = data;
_embeds = embeds; _embeds = embeds;

View File

@ -16,10 +16,10 @@ namespace PluralKit.Bot
public class SystemList public class SystemList
{ {
private readonly IClock _clock; private readonly IClock _clock;
private readonly DbConnectionFactory _db; private readonly Database _db;
private readonly ILogger _logger; private readonly ILogger _logger;
public SystemList(DbConnectionFactory db, ILogger logger, IClock clock) public SystemList(Database db, ILogger logger, IClock clock)
{ {
_db = db; _db = db;
_logger = logger; _logger = logger;

View File

@ -22,12 +22,12 @@ namespace PluralKit.Bot
private readonly IMetrics _metrics; private readonly IMetrics _metrics;
private readonly ProxyService _proxy; private readonly ProxyService _proxy;
private readonly ILifetimeScope _services; private readonly ILifetimeScope _services;
private readonly DbConnectionFactory _db; private readonly Database _db;
private readonly IDataStore _data; private readonly IDataStore _data;
public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean, public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean,
IMetrics metrics, ProxyService proxy, DiscordShardedClient client, IMetrics metrics, ProxyService proxy, DiscordShardedClient client,
CommandTree tree, ILifetimeScope services, DbConnectionFactory db, IDataStore data) CommandTree tree, ILifetimeScope services, Database db, IDataStore data)
{ {
_lastMessageCache = lastMessageCache; _lastMessageCache = lastMessageCache;
_loggerClean = loggerClean; _loggerClean = loggerClean;

View File

@ -11,9 +11,9 @@ namespace PluralKit.Bot
{ {
private readonly LastMessageCacheService _lastMessageCache; private readonly LastMessageCacheService _lastMessageCache;
private readonly ProxyService _proxy; private readonly ProxyService _proxy;
private readonly DbConnectionFactory _db; private readonly Database _db;
public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, DbConnectionFactory db) public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, Database db)
{ {
_lastMessageCache = lastMessageCache; _lastMessageCache = lastMessageCache;
_proxy = proxy; _proxy = proxy;

View File

@ -21,14 +21,14 @@ namespace PluralKit.Bot
public static readonly TimeSpan MessageDeletionDelay = TimeSpan.FromMilliseconds(1000); public static readonly TimeSpan MessageDeletionDelay = TimeSpan.FromMilliseconds(1000);
private readonly LogChannelService _logChannel; private readonly LogChannelService _logChannel;
private readonly DbConnectionFactory _db; private readonly Database _db;
private readonly IDataStore _data; private readonly IDataStore _data;
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly WebhookExecutorService _webhookExecutor; private readonly WebhookExecutorService _webhookExecutor;
private readonly ProxyMatcher _matcher; private readonly ProxyMatcher _matcher;
public ProxyService(LogChannelService logChannel, IDataStore data, ILogger logger, public ProxyService(LogChannelService logChannel, IDataStore data, ILogger logger,
WebhookExecutorService webhookExecutor, DbConnectionFactory db, ProxyMatcher matcher) WebhookExecutorService webhookExecutor, Database db, ProxyMatcher matcher)
{ {
_logChannel = logChannel; _logChannel = logChannel;
_data = data; _data = data;
@ -43,7 +43,8 @@ namespace PluralKit.Bot
if (!ShouldProxy(message, ctx)) return false; if (!ShouldProxy(message, ctx)) return false;
// Fetch members and try to match to a specific member // Fetch members and try to match to a specific member
var members = (await _db.Execute(c => c.QueryProxyMembers(message.Author.Id, message.Channel.GuildId))).ToList(); await using var conn = await _db.Obtain();
var members = (await conn.QueryProxyMembers(message.Author.Id, message.Channel.GuildId)).ToList();
if (!_matcher.TryMatch(ctx, members, out var match, message.Content, message.Attachments.Count > 0, if (!_matcher.TryMatch(ctx, members, out var match, message.Content, message.Attachments.Count > 0,
allowAutoproxy)) return false; allowAutoproxy)) return false;
@ -52,7 +53,7 @@ namespace PluralKit.Bot
if (!CheckProxyNameBoundsOrError(match.Member.ProxyName(ctx))) return false; if (!CheckProxyNameBoundsOrError(match.Member.ProxyName(ctx))) return false;
// Everything's in order, we can execute the proxy! // Everything's in order, we can execute the proxy!
await ExecuteProxy(message, ctx, match); await ExecuteProxy(conn, message, ctx, match);
return true; return true;
} }
@ -78,18 +79,19 @@ namespace PluralKit.Bot
return true; return true;
} }
private async Task ExecuteProxy(DiscordMessage trigger, MessageContext ctx, ProxyMatch match) private async Task ExecuteProxy(IPKConnection conn, DiscordMessage trigger, MessageContext ctx,
ProxyMatch match)
{ {
// Send the webhook // Send the webhook
var id = await _webhookExecutor.ExecuteWebhook(trigger.Channel, match.Member.ProxyName(ctx), var id = await _webhookExecutor.ExecuteWebhook(trigger.Channel, match.Member.ProxyName(ctx),
match.Member.ProxyAvatar(ctx), match.Member.ProxyAvatar(ctx),
match.Content, trigger.Attachments); match.Content, trigger.Attachments);
// Handle post-proxy actions
await _data.AddMessage(trigger.Author.Id, trigger.Channel.GuildId, trigger.Channel.Id, id, trigger.Id,
match.Member.Id);
await _logChannel.LogMessage(ctx, match, trigger, id);
Task SaveMessage() => _data.AddMessage(conn, trigger.Author.Id, trigger.Channel.GuildId, trigger.Channel.Id, id, trigger.Id, match.Member.Id);
Task LogMessage() => _logChannel.LogMessage(ctx, match, trigger, id).AsTask();
async Task DeleteMessage()
{
// Wait a second or so before deleting the original message // Wait a second or so before deleting the original message
await Task.Delay(MessageDeletionDelay); await Task.Delay(MessageDeletionDelay);
try try
@ -103,6 +105,15 @@ namespace PluralKit.Bot
} }
} }
// Run post-proxy actions (simultaneously; order doesn't matter)
// Note that only AddMessage is using our passed-in connection, careful not to pass it elsewhere and run into conflicts
await Task.WhenAll(
DeleteMessage(),
SaveMessage(),
LogMessage()
);
}
private async Task<bool> CheckBotPermissionsOrError(DiscordChannel channel) private async Task<bool> CheckBotPermissionsOrError(DiscordChannel channel)
{ {
var permissions = channel.BotPermissions(); var permissions = channel.BotPermissions();

View File

@ -16,10 +16,10 @@ namespace PluralKit.Bot {
public class EmbedService public class EmbedService
{ {
private IDataStore _data; private IDataStore _data;
private DbConnectionFactory _db; private Database _db;
private DiscordShardedClient _client; private DiscordShardedClient _client;
public EmbedService(DiscordShardedClient client, IDataStore data, DbConnectionFactory db) public EmbedService(DiscordShardedClient client, IDataStore data, Database db)
{ {
_client = client; _client = client;
_data = data; _data = data;

View File

@ -14,12 +14,12 @@ using Serilog;
namespace PluralKit.Bot { namespace PluralKit.Bot {
public class LogChannelService { public class LogChannelService {
private readonly EmbedService _embed; private readonly EmbedService _embed;
private readonly DbConnectionFactory _db; private readonly Database _db;
private readonly IDataStore _data; private readonly IDataStore _data;
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly DiscordRestClient _rest; private readonly DiscordRestClient _rest;
public LogChannelService(EmbedService embed, ILogger logger, DiscordRestClient rest, DbConnectionFactory db, IDataStore data) public LogChannelService(EmbedService embed, ILogger logger, DiscordRestClient rest, Database db, IDataStore data)
{ {
_embed = embed; _embed = embed;
_rest = rest; _rest = rest;

View File

@ -53,10 +53,10 @@ namespace PluralKit.Bot
.Where(b => b.WebhookName != null) .Where(b => b.WebhookName != null)
.ToDictionary(b => b.WebhookName); .ToDictionary(b => b.WebhookName);
private DbConnectionFactory _db; private Database _db;
private DiscordShardedClient _client; private DiscordShardedClient _client;
public LoggerCleanService(DbConnectionFactory db, DiscordShardedClient client) public LoggerCleanService(Database db, DiscordShardedClient client)
{ {
_db = db; _db = db;
_client = client; _client = client;

View File

@ -1,3 +1,5 @@
using Serilog.Events;
namespace PluralKit.Core namespace PluralKit.Core
{ {
public class CoreConfig public class CoreConfig
@ -7,5 +9,8 @@ namespace PluralKit.Core
public string InfluxUrl { get; set; } public string InfluxUrl { get; set; }
public string InfluxDb { get; set; } public string InfluxDb { get; set; }
public string LogDir { get; set; } public string LogDir { get; set; }
public LogEventLevel ConsoleLogLevel { get; set; } = LogEventLevel.Verbose;
public LogEventLevel FileLogLevel { get; set; } = LogEventLevel.Information;
} }
} }

View File

@ -0,0 +1,52 @@
using System;
using System.Threading.Tasks;
using App.Metrics;
using Npgsql;
using Serilog;
namespace PluralKit.Core
{
public class Database
{
private readonly CoreConfig _config;
private readonly ILogger _logger;
private readonly IMetrics _metrics;
private readonly DbConnectionCountHolder _countHolder;
public Database(CoreConfig config, DbConnectionCountHolder countHolder, ILogger logger,
IMetrics metrics)
{
_config = config;
_countHolder = countHolder;
_metrics = metrics;
_logger = logger;
}
public async Task<IPKConnection> Obtain()
{
// Mark the request (for a handle, I guess) in the metrics
_metrics.Measure.Meter.Mark(CoreMetrics.DatabaseRequests);
// Create a connection and open it
// We wrap it in PKConnection for tracing purposes
var conn = new PKConnection(new NpgsqlConnection(_config.Database), _countHolder, _logger, _metrics);
await conn.OpenAsync();
return conn;
}
public async Task Execute(Func<IPKConnection, Task> func)
{
await using var conn = await Obtain();
await func(conn);
}
public async Task<T> Execute<T>(Func<IPKConnection, Task<T>> func)
{
await using var conn = await Obtain();
return await func(conn);
}
}
}

View File

@ -8,14 +8,14 @@ namespace PluralKit.Core
{ {
public static class DatabaseFunctionsExt public static class DatabaseFunctionsExt
{ {
public static Task<MessageContext> QueryMessageContext(this IDbConnection conn, ulong account, ulong guild, ulong channel) public static Task<MessageContext> QueryMessageContext(this IPKConnection conn, ulong account, ulong guild, ulong channel)
{ {
return conn.QueryFirstAsync<MessageContext>("message_context", return conn.QueryFirstAsync<MessageContext>("message_context",
new { account_id = account, guild_id = guild, channel_id = channel }, new { account_id = account, guild_id = guild, channel_id = channel },
commandType: CommandType.StoredProcedure); commandType: CommandType.StoredProcedure);
} }
public static Task<IEnumerable<ProxyMember>> QueryProxyMembers(this IDbConnection conn, ulong account, ulong guild) public static Task<IEnumerable<ProxyMember>> QueryProxyMembers(this IPKConnection conn, ulong account, ulong guild)
{ {
return conn.QueryAsync<ProxyMember>("proxy_members", return conn.QueryAsync<ProxyMember>("proxy_members",
new { account_id = account, guild_id = guild }, new { account_id = account, guild_id = guild },

View File

@ -16,10 +16,10 @@ namespace PluralKit.Core
private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files
private const int TargetSchemaVersion = 7; private const int TargetSchemaVersion = 7;
private DbConnectionFactory _conn; private Database _conn;
private ILogger _logger; private ILogger _logger;
public Schemas(DbConnectionFactory conn, ILogger logger) public Schemas(Database conn, ILogger logger)
{ {
_conn = conn; _conn = conn;
_logger = logger.ForContext<Schemas>(); _logger = logger.ForContext<Schemas>();
@ -36,7 +36,7 @@ namespace PluralKit.Core
{ {
// Run everything in a transaction // Run everything in a transaction
await using var conn = await _conn.Obtain(); await using var conn = await _conn.Obtain();
using var tx = conn.BeginTransaction(); await using var tx = await conn.BeginTransactionAsync();
// Before applying migrations, clean out views/functions to prevent type errors // Before applying migrations, clean out views/functions to prevent type errors
await ExecuteSqlFile($"{RootPath}.clean.sql", conn, tx); await ExecuteSqlFile($"{RootPath}.clean.sql", conn, tx);
@ -49,10 +49,10 @@ namespace PluralKit.Core
await ExecuteSqlFile($"{RootPath}.Functions.functions.sql", conn, tx); await ExecuteSqlFile($"{RootPath}.Functions.functions.sql", conn, tx);
// Finally, commit tx // Finally, commit tx
tx.Commit(); await tx.CommitAsync();
} }
private async Task ApplyMigrations(IAsyncDbConnection conn, IDbTransaction tx) private async Task ApplyMigrations(IPKConnection conn, IDbTransaction tx)
{ {
var currentVersion = await GetCurrentDatabaseVersion(conn); var currentVersion = await GetCurrentDatabaseVersion(conn);
_logger.Information("Current schema version: {CurrentVersion}", currentVersion); _logger.Information("Current schema version: {CurrentVersion}", currentVersion);
@ -63,7 +63,7 @@ namespace PluralKit.Core
} }
} }
private async Task ExecuteSqlFile(string resourceName, IDbConnection conn, IDbTransaction tx = null) private async Task ExecuteSqlFile(string resourceName, IPKConnection conn, IDbTransaction tx = null)
{ {
await using var stream = typeof(Schemas).Assembly.GetManifestResourceStream(resourceName); await using var stream = typeof(Schemas).Assembly.GetManifestResourceStream(resourceName);
if (stream == null) throw new ArgumentException($"Invalid resource name '{resourceName}'"); if (stream == null) throw new ArgumentException($"Invalid resource name '{resourceName}'");
@ -76,10 +76,10 @@ namespace PluralKit.Core
// If the above creates new enum/composite types, we must tell Npgsql to reload the internal type caches // If the above creates new enum/composite types, we must tell Npgsql to reload the internal type caches
// This will propagate to every other connection as well, since it marks the global type mapper collection dirty. // This will propagate to every other connection as well, since it marks the global type mapper collection dirty.
// TODO: find a way to get around the cast to our internal tracker wrapper... this could break if that ever changes // TODO: find a way to get around the cast to our internal tracker wrapper... this could break if that ever changes
((PerformanceTrackingConnection) conn)._impl.ReloadTypes(); conn.ReloadTypes();
} }
private async Task<int> GetCurrentDatabaseVersion(IDbConnection conn) private async Task<int> GetCurrentDatabaseVersion(IPKConnection conn)
{ {
// First, check if the "info" table exists (it may not, if this is a *really* old database) // First, check if the "info" table exists (it may not, if this is a *really* old database)
var hasInfoTable = var hasInfoTable =

View File

@ -0,0 +1,17 @@
using System;
using System.Data;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
namespace PluralKit.Core
{
public interface IPKCommand: IDbCommand, IAsyncDisposable
{
public Task PrepareAsync(CancellationToken ct = default);
public Task<int> ExecuteNonQueryAsync(CancellationToken ct = default);
public Task<object> ExecuteScalarAsync(CancellationToken ct = default);
public Task<DbDataReader> ExecuteReaderAsync(CancellationToken ct = default);
public Task<DbDataReader> ExecuteReaderAsync(CommandBehavior behavior, CancellationToken ct = default);
}
}

View File

@ -0,0 +1,34 @@
using System;
using System.Data;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using Npgsql;
namespace PluralKit.Core
{
public interface IPKConnection: IDbConnection, IAsyncDisposable
{
public Guid ConnectionId { get; }
public Task OpenAsync(CancellationToken cancellationToken = default);
public Task CloseAsync();
public Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default);
public ValueTask<DbTransaction> BeginTransactionAsync(CancellationToken ct = default) => BeginTransactionAsync(IsolationLevel.Unspecified, ct);
public ValueTask<DbTransaction> BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default);
public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand);
public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand);
public void ReloadTypes();
[Obsolete] new void Open();
[Obsolete] new void Close();
[Obsolete] new IDbTransaction BeginTransaction();
[Obsolete] new IDbTransaction BeginTransaction(IsolationLevel il);
}
}

View File

@ -0,0 +1,117 @@
#nullable enable
using System;
using System.Data;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using App.Metrics;
using NodaTime;
using Npgsql;
using Serilog;
namespace PluralKit.Core
{
public class PKCommand: DbCommand, IPKCommand
{
private readonly NpgsqlCommand _inner;
private readonly PKConnection _ourConnection;
private readonly ILogger _logger;
private readonly IMetrics _metrics;
public PKCommand(NpgsqlCommand inner, PKConnection ourConnection, ILogger logger, IMetrics metrics)
{
_inner = inner;
_ourConnection = ourConnection;
_logger = logger.ForContext<PKCommand>();
_metrics = metrics;
}
public override int ExecuteNonQuery() => throw SyncError(nameof(ExecuteNonQuery));
public override object ExecuteScalar() => throw SyncError(nameof(ExecuteScalar));
protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => throw SyncError(nameof(ExecuteDbDataReader));
public override Task<int> ExecuteNonQueryAsync(CancellationToken ct) => LogQuery(_inner.ExecuteNonQueryAsync(ct));
public override Task<object> ExecuteScalarAsync(CancellationToken ct) => LogQuery(_inner.ExecuteScalarAsync(ct));
protected override async Task<DbDataReader> ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken ct) => await LogQuery(_inner.ExecuteReaderAsync(behavior, ct));
public override void Prepare() => _inner.Prepare();
public override void Cancel() => _inner.Cancel();
protected override DbParameter CreateDbParameter() => _inner.CreateParameter();
public override string CommandText
{
get => _inner.CommandText;
set => _inner.CommandText = value;
}
public override int CommandTimeout
{
get => _inner.CommandTimeout;
set => _inner.CommandTimeout = value;
}
public override CommandType CommandType
{
get => _inner.CommandType;
set => _inner.CommandType = value;
}
public override UpdateRowSource UpdatedRowSource
{
get => _inner.UpdatedRowSource;
set => _inner.UpdatedRowSource = value;
}
protected override DbParameterCollection DbParameterCollection => _inner.Parameters;
protected override DbTransaction? DbTransaction
{
get => _inner.Transaction;
set => _inner.Transaction = (NpgsqlTransaction?) value;
}
public override bool DesignTimeVisible
{
get => _inner.DesignTimeVisible;
set => _inner.DesignTimeVisible = value;
}
protected override DbConnection? DbConnection
{
get => _inner.Connection;
set =>
_inner.Connection = value switch
{
NpgsqlConnection npg => npg,
PKConnection pk => pk.Inner,
_ => throw new ArgumentException($"Can't convert input type {value?.GetType()} to NpgsqlConnection")
};
}
private async Task<T> LogQuery<T>(Task<T> task)
{
var start = SystemClock.Instance.GetCurrentInstant();
try
{
return await task;
}
finally
{
var end = SystemClock.Instance.GetCurrentInstant();
var elapsed = end - start;
_logger.Verbose("Executed query {Query} in {ElapsedTime} on connection {ConnectionId}", CommandText, elapsed, _ourConnection.ConnectionId);
// One "BCL compatible tick" is 100 nanoseconds
var micros = elapsed.BclCompatibleTicks / 10;
_metrics.Provider.Timer.Instance(CoreMetrics.DatabaseQuery, new MetricTags("query", CommandText))
.Record(micros, TimeUnit.Microseconds, CommandText);
}
}
private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IPKCommand function {caller}!");
}
}

View File

@ -0,0 +1,107 @@
#nullable enable
using System;
using System.Data;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using App.Metrics;
using NodaTime;
using Npgsql;
using Serilog;
namespace PluralKit.Core
{
public class PKConnection: DbConnection, IPKConnection
{
public NpgsqlConnection Inner { get; }
public Guid ConnectionId { get; }
private readonly DbConnectionCountHolder _countHolder;
private readonly ILogger _logger;
private readonly IMetrics _metrics;
private bool _hasOpened;
private bool _hasClosed;
private Instant _openTime;
public PKConnection(NpgsqlConnection inner, DbConnectionCountHolder countHolder, ILogger logger, IMetrics metrics)
{
Inner = inner;
ConnectionId = Guid.NewGuid();
_countHolder = countHolder;
_logger = logger.ForContext<PKConnection>();
_metrics = metrics;
}
public override Task OpenAsync(CancellationToken ct)
{
if (_hasOpened) return Inner.OpenAsync(ct);
_countHolder.Increment();
_hasOpened = true;
_openTime = SystemClock.Instance.GetCurrentInstant();
_logger.Verbose("Opened database connection {ConnectionId}, new connection count {ConnectionCount}", ConnectionId, _countHolder.ConnectionCount);
return Inner.OpenAsync(ct);
}
public override Task CloseAsync() => Inner.CloseAsync();
protected override DbCommand CreateDbCommand() => new PKCommand(Inner.CreateCommand(), this, _logger, _metrics);
public void ReloadTypes() => Inner.ReloadTypes();
public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand) => Inner.BeginBinaryImport(copyFromCommand);
public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand) => Inner.BeginBinaryExport(copyToCommand);
public override void ChangeDatabase(string databaseName) => Inner.ChangeDatabase(databaseName);
public override Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default) => Inner.ChangeDatabaseAsync(databaseName, ct);
protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw SyncError(nameof(BeginDbTransaction));
protected override async ValueTask<DbTransaction> BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => await Inner.BeginTransactionAsync(level, ct);
public override void Open() => throw SyncError(nameof(Open));
public override void Close() => throw SyncError(nameof(Close));
IDbTransaction IPKConnection.BeginTransaction() => throw SyncError(nameof(BeginTransaction));
IDbTransaction IPKConnection.BeginTransaction(IsolationLevel level) => throw SyncError(nameof(BeginTransaction));
public override string ConnectionString
{
get => Inner.ConnectionString;
set => Inner.ConnectionString = value;
}
public override string? Database => Inner.Database;
public override ConnectionState State => Inner.State;
public override string DataSource => Inner.DataSource;
public override string ServerVersion => Inner.ServerVersion;
protected override void Dispose(bool disposing)
{
Inner.Dispose();
if (_hasClosed) return;
LogClose();
}
public override ValueTask DisposeAsync()
{
if (_hasClosed) return Inner.DisposeAsync();
LogClose();
return Inner.DisposeAsync();
}
private void LogClose()
{
_countHolder.Decrement();
_hasClosed = true;
var duration = SystemClock.Instance.GetCurrentInstant() - _openTime;
_logger.Verbose("Closed database connection {ConnectionId} (open for {ConnectionDuration}), new connection count {ConnectionCount}", ConnectionId, duration, _countHolder.ConnectionCount);
}
private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IPKConnection function {caller}!");
}
}

View File

@ -1,5 +1,4 @@
#nullable enable #nullable enable
using System;
using System.Data; using System.Data;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -9,19 +8,19 @@ namespace PluralKit.Core
{ {
public static class ModelQueryExt public static class ModelQueryExt
{ {
public static Task<PKMember?> QueryMember(this IDbConnection conn, int id) => public static Task<PKMember?> QueryMember(this IPKConnection conn, int id) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where id = @id", new {id}); conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where id = @id", new {id});
public static Task<GuildConfig> QueryOrInsertGuildConfig(this IDbConnection conn, ulong guild) => public static Task<GuildConfig> QueryOrInsertGuildConfig(this IPKConnection conn, ulong guild) =>
conn.QueryFirstAsync<GuildConfig>("insert into servers (id) values (@Guild) on conflict do nothing returning *", new {Guild = guild}); conn.QueryFirstAsync<GuildConfig>("insert into servers (id) values (@Guild) on conflict do nothing returning *", new {Guild = guild});
public static Task<SystemGuildSettings> QueryOrInsertSystemGuildConfig(this IDbConnection conn, ulong guild, int system) => public static Task<SystemGuildSettings> QueryOrInsertSystemGuildConfig(this IPKConnection conn, ulong guild, int system) =>
conn.QueryFirstAsync<SystemGuildSettings>( conn.QueryFirstAsync<SystemGuildSettings>(
"insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *", "insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *",
new {guild, system}); new {guild, system});
public static Task<MemberGuildSettings> QueryOrInsertMemberGuildConfig( public static Task<MemberGuildSettings> QueryOrInsertMemberGuildConfig(
this IDbConnection conn, ulong guild, int member) => this IPKConnection conn, ulong guild, int member) =>
conn.QueryFirstAsync<MemberGuildSettings>( conn.QueryFirstAsync<MemberGuildSettings>(
"insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *", "insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *",
new {guild, member}); new {guild, member});

View File

@ -23,7 +23,7 @@ namespace PluralKit.Core
protected override void Load(ContainerBuilder builder) protected override void Load(ContainerBuilder builder)
{ {
builder.RegisterType<DbConnectionCountHolder>().SingleInstance(); builder.RegisterType<DbConnectionCountHolder>().SingleInstance();
builder.RegisterType<DbConnectionFactory>().AsSelf().SingleInstance(); builder.RegisterType<Database>().AsSelf().SingleInstance();
builder.RegisterType<PostgresDataStore>().AsSelf().As<IDataStore>(); builder.RegisterType<PostgresDataStore>().AsSelf().As<IDataStore>();
builder.RegisterType<Schemas>().AsSelf(); builder.RegisterType<Schemas>().AsSelf();
@ -99,7 +99,7 @@ namespace PluralKit.Core
return new LoggerConfiguration() return new LoggerConfiguration()
.ConfigureForNodaTime(DateTimeZoneProviders.Tzdb) .ConfigureForNodaTime(DateTimeZoneProviders.Tzdb)
.MinimumLevel.Debug() .MinimumLevel.Is(config.ConsoleLogLevel)
.WriteTo.Async(a => .WriteTo.Async(a =>
{ {
// Both the same output, except one is raw compact JSON and one is plain text. // Both the same output, except one is raw compact JSON and one is plain text.
@ -110,7 +110,7 @@ namespace PluralKit.Core
outputTemplate: outputTemplate, outputTemplate: outputTemplate,
rollingInterval: RollingInterval.Day, rollingInterval: RollingInterval.Day,
flushToDiskInterval: TimeSpan.FromMilliseconds(50), flushToDiskInterval: TimeSpan.FromMilliseconds(50),
restrictedToMinimumLevel: LogEventLevel.Information, restrictedToMinimumLevel: config.FileLogLevel,
formatProvider: new UTCTimestampFormatProvider(), formatProvider: new UTCTimestampFormatProvider(),
buffered: true); buffered: true);
@ -119,7 +119,7 @@ namespace PluralKit.Core
(config.LogDir ?? "logs") + $"/pluralkit.{_component}.json", (config.LogDir ?? "logs") + $"/pluralkit.{_component}.json",
rollingInterval: RollingInterval.Day, rollingInterval: RollingInterval.Day,
flushToDiskInterval: TimeSpan.FromMilliseconds(50), flushToDiskInterval: TimeSpan.FromMilliseconds(50),
restrictedToMinimumLevel: LogEventLevel.Information, restrictedToMinimumLevel: config.FileLogLevel,
buffered: true); buffered: true);
}) })
// TODO: render as UTC in the console, too? or just in log files // TODO: render as UTC in the console, too? or just in log files

View File

@ -15,10 +15,10 @@ namespace PluralKit.Core
public class DataFileService public class DataFileService
{ {
private IDataStore _data; private IDataStore _data;
private DbConnectionFactory _db; private Database _db;
private ILogger _logger; private ILogger _logger;
public DataFileService(ILogger logger, IDataStore data, DbConnectionFactory db) public DataFileService(ILogger logger, IDataStore data, Database db)
{ {
_data = data; _data = data;
_db = db; _db = db;
@ -127,8 +127,8 @@ namespace PluralKit.Core
await _data.SaveSystem(system); await _data.SaveSystem(system);
// -- Member/switch import -- // -- Member/switch import --
await using var conn = (PerformanceTrackingConnection) await _db.Obtain(); await using var conn = await _db.Obtain();
await using (var imp = await BulkImporter.Begin(system, conn._impl)) await using (var imp = await BulkImporter.Begin(system, conn))
{ {
// Tally up the members that didn't exist before, and check member count on import // Tally up the members that didn't exist before, and check member count on import
// If creating the unmatched members would put us over the member limit, abort before creating any members // If creating the unmatched members would put us over the member limit, abort before creating any members

View File

@ -201,7 +201,7 @@ namespace PluralKit.Core {
/// <param name="triggerMessageId">The ID of the original trigger message containing the proxy tags.</param> /// <param name="triggerMessageId">The ID of the original trigger message containing the proxy tags.</param>
/// <param name="proxiedMemberId">The member (and by extension system) that was proxied.</param> /// <param name="proxiedMemberId">The member (and by extension system) that was proxied.</param>
/// <returns></returns> /// <returns></returns>
Task AddMessage(ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId); Task AddMessage(IPKConnection conn, ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId);
/// <summary> /// <summary>
/// Deletes a message from the data store. /// Deletes a message from the data store.

View File

@ -1,4 +1,4 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -10,10 +10,10 @@ using Serilog;
namespace PluralKit.Core { namespace PluralKit.Core {
public class PostgresDataStore: IDataStore { public class PostgresDataStore: IDataStore {
private DbConnectionFactory _conn; private Database _conn;
private ILogger _logger; private ILogger _logger;
public PostgresDataStore(DbConnectionFactory conn, ILogger logger) public PostgresDataStore(Database conn, ILogger logger)
{ {
_conn = conn; _conn = conn;
_logger = logger; _logger = logger;
@ -182,8 +182,8 @@ namespace PluralKit.Core {
using (var conn = await _conn.Obtain()) using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<ulong>("select count(id) from members"); return await conn.ExecuteScalarAsync<ulong>("select count(id) from members");
} }
public async Task AddMessage(ulong senderId, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId) {
using (var conn = await _conn.Obtain()) public async Task AddMessage(IPKConnection conn, ulong senderId, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId) {
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before // "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@MessageId, @GuildId, @ChannelId, @MemberId, @SenderId, @OriginalMid) on conflict do nothing", new { await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@MessageId, @GuildId, @ChannelId, @MemberId, @SenderId, @OriginalMid) on conflict do nothing", new {
MessageId = postedMessageId, MessageId = postedMessageId,
@ -235,9 +235,9 @@ namespace PluralKit.Core {
public async Task AddSwitch(PKSystem system, IEnumerable<PKMember> members) public async Task AddSwitch(PKSystem system, IEnumerable<PKMember> members)
{ {
// Use a transaction here since we're doing multiple executed commands in one // Use a transaction here since we're doing multiple executed commands in one
using (var conn = await _conn.Obtain()) await using var conn = await _conn.Obtain();
using (var tx = conn.BeginTransaction()) using var tx = await conn.BeginTransactionAsync();
{
// First, we insert the switch itself // First, we insert the switch itself
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *", var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
new {System = system.Id}); new {System = system.Id});
@ -256,7 +256,6 @@ namespace PluralKit.Core {
_logger.Information("Registered switch {Switch} in system {System} with members {@Members}", sw.Id, system.Id, members.Select(m => m.Id)); _logger.Information("Registered switch {Switch} in system {System} with members {@Members}", sw.Id, system.Id, members.Select(m => m.Id));
} }
}
public IAsyncEnumerable<PKSwitch> GetSwitches(PKSystem system) public IAsyncEnumerable<PKSwitch> GetSwitches(PKSystem system)
{ {
@ -276,8 +275,8 @@ namespace PluralKit.Core {
public async IAsyncEnumerable<SwitchMembersListEntry> GetSwitchMembersList(PKSystem system, Instant start, Instant end) public async IAsyncEnumerable<SwitchMembersListEntry> GetSwitchMembersList(PKSystem system, Instant start, Instant end)
{ {
// Wrap multiple commands in a single transaction for performance // Wrap multiple commands in a single transaction for performance
using var conn = await _conn.Obtain(); await using var conn = await _conn.Obtain();
using var tx = conn.BeginTransaction(); await using var tx = await conn.BeginTransactionAsync();
// Find the time of the last switch outside the range as it overlaps the range // Find the time of the last switch outside the range as it overlaps the range
// If no prior switch exists, the lower bound of the range remains the start time // If no prior switch exists, the lower bound of the range remains the start time

View File

@ -3,6 +3,7 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Data; using System.Data;
using System.Data.Common;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -19,20 +20,20 @@ namespace PluralKit.Core
public class BulkImporter: IAsyncDisposable public class BulkImporter: IAsyncDisposable
{ {
private readonly int _systemId; private readonly int _systemId;
private readonly NpgsqlConnection _conn; private readonly IPKConnection _conn;
private readonly NpgsqlTransaction _tx; private readonly DbTransaction _tx;
private readonly Dictionary<string, int> _knownMembers = new Dictionary<string, int>(); private readonly Dictionary<string, int> _knownMembers = new Dictionary<string, int>();
private readonly Dictionary<string, PKMember> _existingMembersByHid = new Dictionary<string, PKMember>(); private readonly Dictionary<string, PKMember> _existingMembersByHid = new Dictionary<string, PKMember>();
private readonly Dictionary<string, PKMember> _existingMembersByName = new Dictionary<string, PKMember>(); private readonly Dictionary<string, PKMember> _existingMembersByName = new Dictionary<string, PKMember>();
private BulkImporter(int systemId, NpgsqlConnection conn, NpgsqlTransaction tx) private BulkImporter(int systemId, IPKConnection conn, DbTransaction tx)
{ {
_systemId = systemId; _systemId = systemId;
_conn = conn; _conn = conn;
_tx = tx; _tx = tx;
} }
public static async Task<BulkImporter> Begin(PKSystem system, NpgsqlConnection conn) public static async Task<BulkImporter> Begin(PKSystem system, IPKConnection conn)
{ {
var tx = await conn.BeginTransactionAsync(); var tx = await conn.BeginTransactionAsync();
var importer = new BulkImporter(system.Id, conn, tx); var importer = new BulkImporter(system.Id, conn, tx);

View File

@ -7,21 +7,21 @@ using Dapper;
namespace PluralKit.Core { namespace PluralKit.Core {
public static class ConnectionUtils public static class ConnectionUtils
{ {
public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this DbConnectionFactory connFactory, string sql, object param) public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this Database connFactory, string sql, object param)
{ {
using var conn = await connFactory.Obtain(); await using var conn = await connFactory.Obtain();
await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param); await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param);
var parser = reader.GetRowParser<T>(); var parser = reader.GetRowParser<T>();
while (reader.Read()) while (await reader.ReadAsync())
yield return parser(reader); yield return parser(reader);
} }
public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this IDbConnection conn, string sql, object param) public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this IPKConnection conn, string sql, object param)
{ {
await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param); await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param);
var parser = reader.GetRowParser<T>(); var parser = reader.GetRowParser<T>();
while (reader.Read()) while (await reader.ReadAsync())
yield return parser(reader); yield return parser(reader);
} }
} }

View File

@ -1,234 +1,10 @@
using System; using System.Data;
using System.Data;
using System.Data.Common;
using System.Diagnostics;
using System.Threading; using System.Threading;
using System.Threading.Tasks;
using App.Metrics;
using Dapper; using Dapper;
using Npgsql;
using Serilog;
namespace PluralKit.Core namespace PluralKit.Core
{ {
public class QueryLogger : IDisposable
{
private ILogger _logger;
private IMetrics _metrics;
private string _commandText;
private Stopwatch _stopwatch;
public QueryLogger(ILogger logger, IMetrics metrics, string commandText)
{
_metrics = metrics;
_commandText = commandText;
_logger = logger;
_stopwatch = new Stopwatch();
_stopwatch.Start();
}
public void Dispose()
{
_stopwatch.Stop();
_logger.Verbose("Executed query {Query} in {ElapsedTime}", _commandText, _stopwatch.Elapsed);
// One tick is 100 nanoseconds
_metrics.Provider.Timer.Instance(CoreMetrics.DatabaseQuery, new MetricTags("query", _commandText))
.Record(_stopwatch.ElapsedTicks / 10, TimeUnit.Microseconds, _commandText);
}
}
public class PerformanceTrackingCommand: DbCommand
{
private NpgsqlCommand _impl;
private ILogger _logger;
private IMetrics _metrics;
public PerformanceTrackingCommand(NpgsqlCommand impl, ILogger logger, IMetrics metrics)
{
_impl = impl;
_metrics = metrics;
_logger = logger;
}
public override void Cancel()
{
_impl.Cancel();
}
public override int ExecuteNonQuery()
{
return _impl.ExecuteNonQuery();
}
public override object ExecuteScalar()
{
return _impl.ExecuteScalar();
}
public override void Prepare()
{
_impl.Prepare();
}
public override string CommandText
{
get => _impl.CommandText;
set => _impl.CommandText = value;
}
public override int CommandTimeout
{
get => _impl.CommandTimeout;
set => _impl.CommandTimeout = value;
}
public override CommandType CommandType
{
get => _impl.CommandType;
set => _impl.CommandType = value;
}
public override UpdateRowSource UpdatedRowSource
{
get => _impl.UpdatedRowSource;
set => _impl.UpdatedRowSource = value;
}
protected override DbConnection DbConnection
{
get => _impl.Connection;
set => _impl.Connection = (NpgsqlConnection) value;
}
protected override DbParameterCollection DbParameterCollection => _impl.Parameters;
protected override DbTransaction DbTransaction
{
get => _impl.Transaction;
set => _impl.Transaction = (NpgsqlTransaction) value;
}
public override bool DesignTimeVisible
{
get => _impl.DesignTimeVisible;
set => _impl.DesignTimeVisible = value;
}
protected override DbParameter CreateDbParameter()
{
return _impl.CreateParameter();
}
protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior)
{
return _impl.ExecuteReader(behavior);
}
private IDisposable LogQuery()
{
return new QueryLogger(_logger, _metrics, CommandText);
}
protected override async Task<DbDataReader> ExecuteDbDataReaderAsync(
CommandBehavior behavior, CancellationToken cancellationToken)
{
using (LogQuery())
return await _impl.ExecuteReaderAsync(behavior, cancellationToken);
}
public override async Task<int> ExecuteNonQueryAsync(CancellationToken cancellationToken)
{
using (LogQuery())
return await _impl.ExecuteNonQueryAsync(cancellationToken);
}
public override async Task<object> ExecuteScalarAsync(CancellationToken cancellationToken)
{
using (LogQuery())
return await _impl.ExecuteScalarAsync(cancellationToken);
}
}
public class PerformanceTrackingConnection: IAsyncDbConnection
{
// Simple delegation of everything.
internal NpgsqlConnection _impl;
private DbConnectionCountHolder _countHolder;
private ILogger _logger;
private IMetrics _metrics;
public PerformanceTrackingConnection(NpgsqlConnection impl, DbConnectionCountHolder countHolder,
ILogger logger, IMetrics metrics)
{
_impl = impl;
_countHolder = countHolder;
_logger = logger;
_metrics = metrics;
}
public void Dispose()
{
_impl.Dispose();
_countHolder.Decrement();
}
public IDbTransaction BeginTransaction()
{
return _impl.BeginTransaction();
}
public IDbTransaction BeginTransaction(IsolationLevel il)
{
return _impl.BeginTransaction(il);
}
public void ChangeDatabase(string databaseName)
{
_impl.ChangeDatabase(databaseName);
}
public void Close()
{
_impl.Close();
}
public IDbCommand CreateCommand()
{
return new PerformanceTrackingCommand(_impl.CreateCommand(), _logger, _metrics);
}
public void Open()
{
_impl.Open();
}
public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand)
{
return _impl.BeginBinaryImport(copyFromCommand);
}
public string ConnectionString
{
get => _impl.ConnectionString;
set => _impl.ConnectionString = value;
}
public int ConnectionTimeout => _impl.ConnectionTimeout;
public string Database => _impl.Database;
public ConnectionState State => _impl.State;
public ValueTask DisposeAsync() => _impl.DisposeAsync();
}
public class DbConnectionCountHolder public class DbConnectionCountHolder
{ {
private int _connectionCount; private int _connectionCount;
@ -245,43 +21,6 @@ namespace PluralKit.Core
} }
} }
public interface IAsyncDbConnection: IDbConnection, IAsyncDisposable
{
}
public class DbConnectionFactory
{
private CoreConfig _config;
private ILogger _logger;
private IMetrics _metrics;
private DbConnectionCountHolder _countHolder;
public DbConnectionFactory(CoreConfig config, DbConnectionCountHolder countHolder, ILogger logger,
IMetrics metrics)
{
_config = config;
_countHolder = countHolder;
_metrics = metrics;
_logger = logger;
}
public async Task<IAsyncDbConnection> Obtain()
{
// Mark the request (for a handle, I guess) in the metrics
_metrics.Measure.Meter.Mark(CoreMetrics.DatabaseRequests);
// Actually create and try to open the connection
var conn = new NpgsqlConnection(_config.Database);
await conn.OpenAsync();
// Increment the count
_countHolder.Increment();
// Return a wrapped connection which will decrement the counter on dispose
return new PerformanceTrackingConnection(conn, _countHolder, _logger, _metrics);
}
}
public class PassthroughTypeHandler<T>: SqlMapper.TypeHandler<T> public class PassthroughTypeHandler<T>: SqlMapper.TypeHandler<T>
{ {
public override void SetValue(IDbDataParameter parameter, T value) public override void SetValue(IDbDataParameter parameter, T value)
@ -308,18 +47,4 @@ namespace PluralKit.Core
parameter.Value = (long) value; parameter.Value = (long) value;
} }
} }
public static class DatabaseExt
{
public static async Task Execute(this DbConnectionFactory db, Func<IDbConnection, Task> func)
{
await using var conn = await db.Obtain();
await func(conn);
}
public static async Task<T> Execute<T>(this DbConnectionFactory db, Func<IDbConnection, Task<T>> func)
{
await using var conn = await db.Obtain();
return await func(conn);
}
}
} }