using System.Data;
using System.Runtime.CompilerServices;

using App.Metrics;

using Dapper;

using NodaTime;

using Npgsql;

using Serilog;

using SqlKata;
using SqlKata.Compilers;

namespace PluralKit.Core;

/// <summary>
/// PostgreSQL access point: builds the Npgsql connection string, registers the global Dapper/Npgsql
/// type mappings, and hands out traced connections.
/// </summary>
internal partial class Database: IDatabase
{
    private readonly CoreConfig _config;
    private readonly ILogger _logger;
    private readonly IMetrics _metrics;
    private readonly DbConnectionCountHolder _countHolder;
    private readonly DatabaseMigrator _migrator;
    private readonly string _connectionString;

    /// <summary>
    /// Creates the database accessor and builds the connection string from <c>config.Database</c>.
    /// </summary>
    /// <param name="config">Supplies the base connection string and an optional password override.</param>
    /// <param name="countHolder">Passed to every <see cref="PKConnection"/> created by <see cref="Obtain"/>.</param>
    /// <param name="logger">Base logger; a <see cref="Database"/>-scoped child logger is derived from it.</param>
    /// <param name="metrics">Used to record database request metrics.</param>
    /// <param name="migrator">Used by <see cref="ApplyMigrations"/>.</param>
    public Database(CoreConfig config, DbConnectionCountHolder countHolder, ILogger logger, IMetrics metrics,
                    DatabaseMigrator migrator)
    {
        _config = config;
        _countHolder = countHolder;
        _metrics = metrics;
        _migrator = migrator;
        _logger = logger.ForContext<Database>();

        var connectionString = new NpgsqlConnectionStringBuilder(_config.Database)
        {
            Pooling = true,
            Enlist = false,
            NoResetOnClose = true,

            // Lower timeout than default (15s -> 2s), should ideally fail-fast instead of hanging
            Timeout = 2
        };

        // A separately configured password overrides any password embedded in the connection string.
        if (_config.DatabasePassword is not null)
        {
            connectionString.Password = _config.DatabasePassword;
        }

        _connectionString = connectionString.ConnectionString;
    }

    // Not referenced in this file; presumably used by other parts of this partial class.
    private static readonly PostgresCompiler _compiler = new();

    /// <summary>
    /// Registers process-wide Dapper type handlers and Npgsql type mappings.
    /// </summary>
    /// <remarks>
    /// Mutates global static state (<see cref="DefaultTypeMap"/>, <see cref="SqlMapper"/> and
    /// <c>NpgsqlConnection.GlobalTypeMapper</c>). It is presumably meant to run once at startup, before any
    /// connection is opened — TODO confirm with callers.
    /// </remarks>
    public static void InitStatic()
    {
        // Let Dapper match snake_case column names to PascalCase members.
        DefaultTypeMap.MatchNamesWithUnderscores = true;

        // Dapper by default tries to pass ulongs to Npgsql, which rejects them since PostgreSQL technically
        // doesn't support unsigned types on its own.
        // Instead we add a custom mapper to encode them as signed integers instead, converting them back and forth.
        SqlMapper.RemoveTypeMap(typeof(ulong));
        SqlMapper.AddTypeHandler(new UlongEncodeAsLongHandler());
        SqlMapper.AddTypeHandler(new UlongArrayHandler());

        NpgsqlConnection.GlobalTypeMapper.UseNodaTime();

        // With the UseNodaTime() mapping above, Npgsql already handles NodaTime types natively.
        // This confuses Dapper: it thinks it has to convert them anyway and doesn't understand the types.
        // So we add a custom type handler that literally just passes the value through to Npgsql.
        // NOTE(review): confirm Instant and LocalDate are the complete set of NodaTime types read/written via Dapper.
        SqlMapper.AddTypeHandler(new PassthroughTypeHandler<Instant>());
        SqlMapper.AddTypeHandler(new PassthroughTypeHandler<LocalDate>());

        // Add ID types to Dapper.
        // NOTE(review): the ID types are assumed to wrap int columns — confirm against their INumericId declarations.
        SqlMapper.AddTypeHandler(new NumericIdHandler<SystemId, int>(i => new SystemId(i)));
        SqlMapper.AddTypeHandler(new NumericIdHandler<MemberId, int>(i => new MemberId(i)));
        SqlMapper.AddTypeHandler(new NumericIdHandler<SwitchId, int>(i => new SwitchId(i)));
        SqlMapper.AddTypeHandler(new NumericIdHandler<GroupId, int>(i => new GroupId(i)));
        SqlMapper.AddTypeHandler(new NumericIdArrayHandler<SystemId, int>(i => new SystemId(i)));
        SqlMapper.AddTypeHandler(new NumericIdArrayHandler<MemberId, int>(i => new MemberId(i)));
        SqlMapper.AddTypeHandler(new NumericIdArrayHandler<SwitchId, int>(i => new SwitchId(i)));
        SqlMapper.AddTypeHandler(new NumericIdArrayHandler<GroupId, int>(i => new GroupId(i)));

        // Register our custom types to Npgsql.
        // Without these it'll still *work* but break at the first launch + probably cause other small issues.
        // NOTE(review): confirm ProxyTag / PrivacyLevel are the CLR types backing these PostgreSQL types.
        NpgsqlConnection.GlobalTypeMapper.MapComposite<ProxyTag>("proxy_tag");
        NpgsqlConnection.GlobalTypeMapper.MapEnum<PrivacyLevel>("privacy_level");
    }

    // TODO: make sure every SQL query is behind a logged query method
    /// <summary>
    /// Records a database request metric, then creates and opens a new traced connection.
    /// </summary>
    /// <returns>An open connection. The caller owns it and must dispose it.</returns>
    /// <remarks>
    /// If opening fails, the connection is disposed before the exception propagates. A fast failure is
    /// expected here because of the short connect timeout set in the constructor.
    /// </remarks>
    // NOTE(review): return type assumed to be IPKConnection to match IDatabase — confirm.
    public async Task<IPKConnection> Obtain()
    {
        // Mark the request (for a handle, I guess) in the metrics
        _metrics.Measure.Meter.Mark(CoreMetrics.DatabaseRequests);

        // Create a connection and open it.
        // We wrap it in PKConnection for tracing purposes.
        var conn = new PKConnection(new NpgsqlConnection(_connectionString), _countHolder, _logger, _metrics);
        try
        {
            await conn.OpenAsync();
        }
        catch
        {
            // The connection was never handed to a caller, so nobody else would dispose it.
            // Dispose it here to avoid leaking the NpgsqlConnection (and whatever PKConnection tracks) on failure.
            conn.Dispose();
            throw;
        }

        return conn;
    }

    /// <summary>
    /// Opens a connection, passes it to <see cref="DatabaseMigrator"/>, and disposes it afterwards.
    /// </summary>
    public async Task ApplyMigrations()
    {
        using var conn = await Obtain();
        await _migrator.ApplyMigrations(conn);
    }

    /// <summary>
    /// Passes <typeparamref name="T"/> values to the provider unchanged and casts read values straight back.
    /// Used for types that Npgsql maps natively but Dapper would otherwise try to convert.
    /// </summary>
    private class PassthroughTypeHandler<T>: SqlMapper.TypeHandler<T>
    {
        public override void SetValue(IDbDataParameter parameter, T value) => parameter.Value = value;

        public override T Parse(object value) => (T)value;
    }

    /// <summary>
    /// Encodes <see cref="ulong"/> as <see cref="long"/> on write and converts back on read. In the default
    /// unchecked context, values above <see cref="long.MaxValue"/> are stored as negative numbers and still
    /// round-trip correctly.
    /// </summary>
    private class UlongEncodeAsLongHandler: SqlMapper.TypeHandler<ulong>
    {
        // Presumably the provider hands back a boxed long. A boxed value can only be unboxed to its exact
        // type (a direct (ulong) unbox would throw InvalidCastException), so unbox to long first, then convert.
        public override ulong Parse(object value) => (ulong)(long)value;

        public override void SetValue(IDbDataParameter parameter, ulong value) => parameter.Value = (long)value;
    }

    /// <summary>
    /// Array counterpart of <see cref="UlongEncodeAsLongHandler"/>: converts each element between
    /// <see cref="ulong"/> and <see cref="long"/>.
    /// </summary>
    private class UlongArrayHandler: SqlMapper.TypeHandler<ulong[]>
    {
        public override void SetValue(IDbDataParameter parameter, ulong[] value) =>
            parameter.Value = Array.ConvertAll(value, i => (long)i);

        public override ulong[] Parse(object value) => Array.ConvertAll((long[])value, i => (ulong)i);
    }

    /// <summary>
    /// Maps a strongly-typed ID <typeparamref name="T"/> to and from its underlying
    /// <typeparamref name="TInner"/> column value.
    /// </summary>
    // NOTE(review): INumericId arity assumed to be <T, TInner> — confirm against its declaration.
    private class NumericIdHandler<T, TInner>: SqlMapper.TypeHandler<T>
        where T : INumericId<T, TInner>
        where TInner : IEquatable<TInner>, IComparable<TInner>
    {
        private readonly Func<TInner, T> _factory;

        /// <param name="factory">Builds an ID from the raw column value.</param>
        public NumericIdHandler(Func<TInner, T> factory)
        {
            _factory = factory;
        }

        public override void SetValue(IDbDataParameter parameter, T value) => parameter.Value = value.Value;

        public override T Parse(object value) => _factory((TInner)value);
    }

    /// <summary>
    /// Array counterpart of <see cref="NumericIdHandler{T, TInner}"/>: converts each element between the
    /// ID type and its underlying value.
    /// </summary>
    private class NumericIdArrayHandler<T, TInner>: SqlMapper.TypeHandler<T[]>
        where T : INumericId<T, TInner>
        where TInner : IEquatable<TInner>, IComparable<TInner>
    {
        private readonly Func<TInner, T> _factory;

        /// <param name="factory">Builds an ID from a raw array element.</param>
        public NumericIdArrayHandler(Func<TInner, T> factory)
        {
            _factory = factory;
        }

        public override void SetValue(IDbDataParameter parameter, T[] value) =>
            parameter.Value = Array.ConvertAll(value, v => v.Value);

        public override T[] Parse(object value) => Array.ConvertAll((TInner[])value, v => _factory(v));
    }
}