feat: upgrade to .NET 6, refactor everything
This commit is contained in:
@@ -1,9 +1,5 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Data;
|
||||
using System.IO;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using App.Metrics;
|
||||
|
||||
@@ -18,232 +14,231 @@ using Serilog;
|
||||
using SqlKata;
|
||||
using SqlKata.Compilers;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
internal class Database: IDatabase
|
||||
{
|
||||
internal class Database: IDatabase
|
||||
|
||||
private readonly CoreConfig _config;
|
||||
private readonly ILogger _logger;
|
||||
private readonly IMetrics _metrics;
|
||||
private readonly DbConnectionCountHolder _countHolder;
|
||||
private readonly DatabaseMigrator _migrator;
|
||||
private readonly string _connectionString;
|
||||
|
||||
public Database(CoreConfig config, DbConnectionCountHolder countHolder, ILogger logger,
|
||||
IMetrics metrics, DatabaseMigrator migrator)
|
||||
{
|
||||
_config = config;
|
||||
_countHolder = countHolder;
|
||||
_metrics = metrics;
|
||||
_migrator = migrator;
|
||||
_logger = logger.ForContext<Database>();
|
||||
|
||||
private readonly CoreConfig _config;
|
||||
private readonly ILogger _logger;
|
||||
private readonly IMetrics _metrics;
|
||||
private readonly DbConnectionCountHolder _countHolder;
|
||||
private readonly DatabaseMigrator _migrator;
|
||||
private readonly string _connectionString;
|
||||
|
||||
public Database(CoreConfig config, DbConnectionCountHolder countHolder, ILogger logger,
|
||||
IMetrics metrics, DatabaseMigrator migrator)
|
||||
_connectionString = new NpgsqlConnectionStringBuilder(_config.Database)
|
||||
{
|
||||
_config = config;
|
||||
_countHolder = countHolder;
|
||||
_metrics = metrics;
|
||||
_migrator = migrator;
|
||||
_logger = logger.ForContext<Database>();
|
||||
Pooling = true,
|
||||
Enlist = false,
|
||||
NoResetOnClose = true,
|
||||
|
||||
_connectionString = new NpgsqlConnectionStringBuilder(_config.Database)
|
||||
{
|
||||
Pooling = true,
|
||||
Enlist = false,
|
||||
NoResetOnClose = true,
|
||||
// Lower timeout than default (15s -> 2s), should ideally fail-fast instead of hanging
|
||||
Timeout = 2
|
||||
}.ConnectionString;
|
||||
}
|
||||
|
||||
// Lower timeout than default (15s -> 2s), should ideally fail-fast instead of hanging
|
||||
Timeout = 2
|
||||
}.ConnectionString;
|
||||
private static readonly PostgresCompiler _compiler = new();
|
||||
|
||||
public static void InitStatic()
|
||||
{
|
||||
DefaultTypeMap.MatchNamesWithUnderscores = true;
|
||||
|
||||
// Dapper by default tries to pass ulongs to Npgsql, which rejects them since PostgreSQL technically
|
||||
// doesn't support unsigned types on its own.
|
||||
// Instead we add a custom mapper to encode them as signed integers instead, converting them back and forth.
|
||||
SqlMapper.RemoveTypeMap(typeof(ulong));
|
||||
SqlMapper.AddTypeHandler(new UlongEncodeAsLongHandler());
|
||||
SqlMapper.AddTypeHandler(new UlongArrayHandler());
|
||||
|
||||
NpgsqlConnection.GlobalTypeMapper.UseNodaTime();
|
||||
// With the thing we add above, Npgsql already handles NodaTime integration
|
||||
// This makes Dapper confused since it thinks it has to convert it anyway and doesn't understand the types
|
||||
// So we add a custom type handler that literally just passes the type through to Npgsql
|
||||
SqlMapper.AddTypeHandler(new PassthroughTypeHandler<Instant>());
|
||||
SqlMapper.AddTypeHandler(new PassthroughTypeHandler<LocalDate>());
|
||||
|
||||
// Add ID types to Dapper
|
||||
SqlMapper.AddTypeHandler(new NumericIdHandler<SystemId, int>(i => new SystemId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdHandler<MemberId, int>(i => new MemberId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdHandler<SwitchId, int>(i => new SwitchId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdHandler<GroupId, int>(i => new GroupId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdArrayHandler<SystemId, int>(i => new SystemId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdArrayHandler<MemberId, int>(i => new MemberId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdArrayHandler<SwitchId, int>(i => new SwitchId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdArrayHandler<GroupId, int>(i => new GroupId(i)));
|
||||
|
||||
// Register our custom types to Npgsql
|
||||
// Without these it'll still *work* but break at the first launch + probably cause other small issues
|
||||
NpgsqlConnection.GlobalTypeMapper.MapComposite<ProxyTag>("proxy_tag");
|
||||
NpgsqlConnection.GlobalTypeMapper.MapEnum<PrivacyLevel>("privacy_level");
|
||||
}
|
||||
|
||||
public async Task<IPKConnection> Obtain()
|
||||
{
|
||||
// Mark the request (for a handle, I guess) in the metrics
|
||||
_metrics.Measure.Meter.Mark(CoreMetrics.DatabaseRequests);
|
||||
|
||||
// Create a connection and open it
|
||||
// We wrap it in PKConnection for tracing purposes
|
||||
var conn = new PKConnection(new NpgsqlConnection(_connectionString), _countHolder, _logger, _metrics);
|
||||
await conn.OpenAsync();
|
||||
return conn;
|
||||
}
|
||||
|
||||
public async Task ApplyMigrations()
|
||||
{
|
||||
using var conn = await Obtain();
|
||||
await _migrator.ApplyMigrations(conn);
|
||||
}
|
||||
|
||||
private class PassthroughTypeHandler<T>: SqlMapper.TypeHandler<T>
|
||||
{
|
||||
public override void SetValue(IDbDataParameter parameter, T value) => parameter.Value = value;
|
||||
public override T Parse(object value) => (T)value;
|
||||
}
|
||||
|
||||
private class UlongEncodeAsLongHandler: SqlMapper.TypeHandler<ulong>
|
||||
{
|
||||
public override ulong Parse(object value) =>
|
||||
// Cast to long to unbox, then to ulong (???)
|
||||
(ulong)(long)value;
|
||||
|
||||
public override void SetValue(IDbDataParameter parameter, ulong value) => parameter.Value = (long)value;
|
||||
}
|
||||
|
||||
private class UlongArrayHandler: SqlMapper.TypeHandler<ulong[]>
|
||||
{
|
||||
public override void SetValue(IDbDataParameter parameter, ulong[] value) => parameter.Value = Array.ConvertAll(value, i => (long)i);
|
||||
|
||||
public override ulong[] Parse(object value) => Array.ConvertAll((long[])value, i => (ulong)i);
|
||||
}
|
||||
|
||||
private class NumericIdHandler<T, TInner>: SqlMapper.TypeHandler<T>
|
||||
where T : INumericId<T, TInner>
|
||||
where TInner : IEquatable<TInner>, IComparable<TInner>
|
||||
{
|
||||
private readonly Func<TInner, T> _factory;
|
||||
|
||||
public NumericIdHandler(Func<TInner, T> factory)
|
||||
{
|
||||
_factory = factory;
|
||||
}
|
||||
|
||||
private static readonly PostgresCompiler _compiler = new();
|
||||
public override void SetValue(IDbDataParameter parameter, T value) => parameter.Value = value.Value;
|
||||
|
||||
public static void InitStatic()
|
||||
public override T Parse(object value) => _factory((TInner)value);
|
||||
}
|
||||
|
||||
private class NumericIdArrayHandler<T, TInner>: SqlMapper.TypeHandler<T[]>
|
||||
where T : INumericId<T, TInner>
|
||||
where TInner : IEquatable<TInner>, IComparable<TInner>
|
||||
{
|
||||
private readonly Func<TInner, T> _factory;
|
||||
|
||||
public NumericIdArrayHandler(Func<TInner, T> factory)
|
||||
{
|
||||
DefaultTypeMap.MatchNamesWithUnderscores = true;
|
||||
|
||||
// Dapper by default tries to pass ulongs to Npgsql, which rejects them since PostgreSQL technically
|
||||
// doesn't support unsigned types on its own.
|
||||
// Instead we add a custom mapper to encode them as signed integers instead, converting them back and forth.
|
||||
SqlMapper.RemoveTypeMap(typeof(ulong));
|
||||
SqlMapper.AddTypeHandler(new UlongEncodeAsLongHandler());
|
||||
SqlMapper.AddTypeHandler(new UlongArrayHandler());
|
||||
|
||||
NpgsqlConnection.GlobalTypeMapper.UseNodaTime();
|
||||
// With the thing we add above, Npgsql already handles NodaTime integration
|
||||
// This makes Dapper confused since it thinks it has to convert it anyway and doesn't understand the types
|
||||
// So we add a custom type handler that literally just passes the type through to Npgsql
|
||||
SqlMapper.AddTypeHandler(new PassthroughTypeHandler<Instant>());
|
||||
SqlMapper.AddTypeHandler(new PassthroughTypeHandler<LocalDate>());
|
||||
|
||||
// Add ID types to Dapper
|
||||
SqlMapper.AddTypeHandler(new NumericIdHandler<SystemId, int>(i => new SystemId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdHandler<MemberId, int>(i => new MemberId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdHandler<SwitchId, int>(i => new SwitchId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdHandler<GroupId, int>(i => new GroupId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdArrayHandler<SystemId, int>(i => new SystemId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdArrayHandler<MemberId, int>(i => new MemberId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdArrayHandler<SwitchId, int>(i => new SwitchId(i)));
|
||||
SqlMapper.AddTypeHandler(new NumericIdArrayHandler<GroupId, int>(i => new GroupId(i)));
|
||||
|
||||
// Register our custom types to Npgsql
|
||||
// Without these it'll still *work* but break at the first launch + probably cause other small issues
|
||||
NpgsqlConnection.GlobalTypeMapper.MapComposite<ProxyTag>("proxy_tag");
|
||||
NpgsqlConnection.GlobalTypeMapper.MapEnum<PrivacyLevel>("privacy_level");
|
||||
_factory = factory;
|
||||
}
|
||||
|
||||
public async Task<IPKConnection> Obtain()
|
||||
{
|
||||
// Mark the request (for a handle, I guess) in the metrics
|
||||
_metrics.Measure.Meter.Mark(CoreMetrics.DatabaseRequests);
|
||||
public override void SetValue(IDbDataParameter parameter, T[] value) => parameter.Value = Array.ConvertAll(value, v => v.Value);
|
||||
|
||||
// Create a connection and open it
|
||||
// We wrap it in PKConnection for tracing purposes
|
||||
var conn = new PKConnection(new NpgsqlConnection(_connectionString), _countHolder, _logger, _metrics);
|
||||
await conn.OpenAsync();
|
||||
return conn;
|
||||
}
|
||||
public override T[] Parse(object value) => Array.ConvertAll((TInner[])value, v => _factory(v));
|
||||
}
|
||||
|
||||
public async Task ApplyMigrations()
|
||||
{
|
||||
using var conn = await Obtain();
|
||||
await _migrator.ApplyMigrations(conn);
|
||||
}
|
||||
public async Task Execute(Func<IPKConnection, Task> func)
|
||||
{
|
||||
await using var conn = await Obtain();
|
||||
await func(conn);
|
||||
}
|
||||
|
||||
private class PassthroughTypeHandler<T>: SqlMapper.TypeHandler<T>
|
||||
{
|
||||
public override void SetValue(IDbDataParameter parameter, T value) => parameter.Value = value;
|
||||
public override T Parse(object value) => (T)value;
|
||||
}
|
||||
public async Task<T> Execute<T>(Func<IPKConnection, Task<T>> func)
|
||||
{
|
||||
await using var conn = await Obtain();
|
||||
return await func(conn);
|
||||
}
|
||||
|
||||
private class UlongEncodeAsLongHandler: SqlMapper.TypeHandler<ulong>
|
||||
{
|
||||
public override ulong Parse(object value) =>
|
||||
// Cast to long to unbox, then to ulong (???)
|
||||
(ulong)(long)value;
|
||||
public async IAsyncEnumerable<T> Execute<T>(Func<IPKConnection, IAsyncEnumerable<T>> func)
|
||||
{
|
||||
await using var conn = await Obtain();
|
||||
|
||||
public override void SetValue(IDbDataParameter parameter, ulong value) => parameter.Value = (long)value;
|
||||
}
|
||||
await foreach (var val in func(conn))
|
||||
yield return val;
|
||||
}
|
||||
|
||||
private class UlongArrayHandler: SqlMapper.TypeHandler<ulong[]>
|
||||
{
|
||||
public override void SetValue(IDbDataParameter parameter, ulong[] value) => parameter.Value = Array.ConvertAll(value, i => (long)i);
|
||||
public async Task<int> ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.ExecuteAsync(query.Sql + $" {extraSql}", query.NamedBindings);
|
||||
}
|
||||
|
||||
public override ulong[] Parse(object value) => Array.ConvertAll((long[])value, i => (ulong)i);
|
||||
}
|
||||
public async Task<int> ExecuteQuery(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = "")
|
||||
{
|
||||
if (conn == null)
|
||||
return await ExecuteQuery(q, extraSql, queryName);
|
||||
|
||||
private class NumericIdHandler<T, TInner>: SqlMapper.TypeHandler<T>
|
||||
where T : INumericId<T, TInner>
|
||||
where TInner : IEquatable<TInner>, IComparable<TInner>
|
||||
{
|
||||
private readonly Func<TInner, T> _factory;
|
||||
var query = _compiler.Compile(q);
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.ExecuteAsync(query.Sql + $" {extraSql}", query.NamedBindings);
|
||||
}
|
||||
|
||||
public NumericIdHandler(Func<TInner, T> factory)
|
||||
{
|
||||
_factory = factory;
|
||||
}
|
||||
public async Task<T> QueryFirst<T>(Query q, string extraSql = "", [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.QueryFirstOrDefaultAsync<T>(query.Sql + $" {extraSql}", query.NamedBindings);
|
||||
}
|
||||
|
||||
public override void SetValue(IDbDataParameter parameter, T value) => parameter.Value = value.Value;
|
||||
public async Task<T> QueryFirst<T>(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = "")
|
||||
{
|
||||
if (conn == null)
|
||||
return await QueryFirst<T>(q, extraSql, queryName);
|
||||
|
||||
public override T Parse(object value) => _factory((TInner)value);
|
||||
}
|
||||
var query = _compiler.Compile(q);
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.QueryFirstOrDefaultAsync<T>(query.Sql + $" {extraSql}", query.NamedBindings);
|
||||
}
|
||||
|
||||
private class NumericIdArrayHandler<T, TInner>: SqlMapper.TypeHandler<T[]>
|
||||
where T : INumericId<T, TInner>
|
||||
where TInner : IEquatable<TInner>, IComparable<TInner>
|
||||
{
|
||||
private readonly Func<TInner, T> _factory;
|
||||
public async Task<IEnumerable<T>> Query<T>(Query q, [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.QueryAsync<T>(query.Sql, query.NamedBindings);
|
||||
}
|
||||
|
||||
public NumericIdArrayHandler(Func<TInner, T> factory)
|
||||
{
|
||||
_factory = factory;
|
||||
}
|
||||
|
||||
public override void SetValue(IDbDataParameter parameter, T[] value) => parameter.Value = Array.ConvertAll(value, v => v.Value);
|
||||
|
||||
public override T[] Parse(object value) => Array.ConvertAll((TInner[])value, v => _factory(v));
|
||||
}
|
||||
|
||||
public async Task Execute(Func<IPKConnection, Task> func)
|
||||
{
|
||||
await using var conn = await Obtain();
|
||||
await func(conn);
|
||||
}
|
||||
|
||||
public async Task<T> Execute<T>(Func<IPKConnection, Task<T>> func)
|
||||
{
|
||||
await using var conn = await Obtain();
|
||||
return await func(conn);
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<T> Execute<T>(Func<IPKConnection, IAsyncEnumerable<T>> func)
|
||||
{
|
||||
await using var conn = await Obtain();
|
||||
|
||||
await foreach (var val in func(conn))
|
||||
public async IAsyncEnumerable<T> QueryStream<T>(Query q, [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
await foreach (var val in conn.QueryStreamAsync<T>(query.Sql, query.NamedBindings))
|
||||
yield return val;
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<int> ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.ExecuteAsync(query.Sql + $" {extraSql}", query.NamedBindings);
|
||||
}
|
||||
// the procedures (message_context and proxy_members, as of writing) have their own metrics tracking elsewhere
|
||||
// still, including them here for consistency
|
||||
|
||||
public async Task<int> ExecuteQuery(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = "")
|
||||
{
|
||||
if (conn == null)
|
||||
return await ExecuteQuery(q, extraSql, queryName);
|
||||
public async Task<T> QuerySingleProcedure<T>(string queryName, object param)
|
||||
{
|
||||
using var conn = await Obtain();
|
||||
return await conn.QueryFirstAsync<T>(queryName, param, commandType: CommandType.StoredProcedure);
|
||||
}
|
||||
|
||||
var query = _compiler.Compile(q);
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.ExecuteAsync(query.Sql + $" {extraSql}", query.NamedBindings);
|
||||
}
|
||||
|
||||
public async Task<T> QueryFirst<T>(Query q, string extraSql = "", [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.QueryFirstOrDefaultAsync<T>(query.Sql + $" {extraSql}", query.NamedBindings);
|
||||
}
|
||||
|
||||
public async Task<T> QueryFirst<T>(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = "")
|
||||
{
|
||||
if (conn == null)
|
||||
return await QueryFirst<T>(q, extraSql, queryName);
|
||||
|
||||
var query = _compiler.Compile(q);
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.QueryFirstOrDefaultAsync<T>(query.Sql + $" {extraSql}", query.NamedBindings);
|
||||
}
|
||||
|
||||
public async Task<IEnumerable<T>> Query<T>(Query q, [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.QueryAsync<T>(query.Sql, query.NamedBindings);
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<T> QueryStream<T>(Query q, [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
await foreach (var val in conn.QueryStreamAsync<T>(query.Sql, query.NamedBindings))
|
||||
yield return val;
|
||||
}
|
||||
|
||||
// the procedures (message_context and proxy_members, as of writing) have their own metrics tracking elsewhere
|
||||
// still, including them here for consistency
|
||||
|
||||
public async Task<T> QuerySingleProcedure<T>(string queryName, object param)
|
||||
{
|
||||
using var conn = await Obtain();
|
||||
return await conn.QueryFirstAsync<T>(queryName, param, commandType: CommandType.StoredProcedure);
|
||||
}
|
||||
|
||||
public async Task<IEnumerable<T>> QueryProcedure<T>(string queryName, object param)
|
||||
{
|
||||
using var conn = await Obtain();
|
||||
return await conn.QueryAsync<T>(queryName, param, commandType: CommandType.StoredProcedure);
|
||||
}
|
||||
public async Task<IEnumerable<T>> QueryProcedure<T>(string queryName, object param)
|
||||
{
|
||||
using var conn = await Obtain();
|
||||
return await conn.QueryAsync<T>(queryName, param, commandType: CommandType.StoredProcedure);
|
||||
}
|
||||
}
|
@@ -2,31 +2,30 @@
|
||||
|
||||
using NodaTime;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
/// <summary>
|
||||
/// Model for the `message_context` PL/pgSQL function in `functions.sql`
|
||||
/// </summary>
|
||||
public class MessageContext
|
||||
{
|
||||
/// <summary>
|
||||
/// Model for the `message_context` PL/pgSQL function in `functions.sql`
|
||||
/// </summary>
|
||||
public class MessageContext
|
||||
{
|
||||
public SystemId? SystemId { get; }
|
||||
public ulong? LogChannel { get; }
|
||||
public bool InBlacklist { get; }
|
||||
public bool InLogBlacklist { get; }
|
||||
public bool LogCleanupEnabled { get; }
|
||||
public bool ProxyEnabled { get; }
|
||||
public AutoproxyMode AutoproxyMode { get; }
|
||||
public MemberId? AutoproxyMember { get; }
|
||||
public ulong? LastMessage { get; }
|
||||
public MemberId? LastMessageMember { get; }
|
||||
public SwitchId? LastSwitch { get; }
|
||||
public MemberId[] LastSwitchMembers { get; } = new MemberId[0];
|
||||
public Instant? LastSwitchTimestamp { get; }
|
||||
public string? SystemTag { get; }
|
||||
public string? SystemGuildTag { get; }
|
||||
public bool TagEnabled { get; }
|
||||
public string? SystemAvatar { get; }
|
||||
public bool AllowAutoproxy { get; }
|
||||
public int? LatchTimeout { get; }
|
||||
}
|
||||
public SystemId? SystemId { get; }
|
||||
public ulong? LogChannel { get; }
|
||||
public bool InBlacklist { get; }
|
||||
public bool InLogBlacklist { get; }
|
||||
public bool LogCleanupEnabled { get; }
|
||||
public bool ProxyEnabled { get; }
|
||||
public AutoproxyMode AutoproxyMode { get; }
|
||||
public MemberId? AutoproxyMember { get; }
|
||||
public ulong? LastMessage { get; }
|
||||
public MemberId? LastMessageMember { get; }
|
||||
public SwitchId? LastSwitch { get; }
|
||||
public MemberId[] LastSwitchMembers { get; } = new MemberId[0];
|
||||
public Instant? LastSwitchTimestamp { get; }
|
||||
public string? SystemTag { get; }
|
||||
public string? SystemGuildTag { get; }
|
||||
public bool TagEnabled { get; }
|
||||
public string? SystemAvatar { get; }
|
||||
public bool AllowAutoproxy { get; }
|
||||
public int? LatchTimeout { get; }
|
||||
}
|
@@ -1,48 +1,46 @@
|
||||
#nullable enable
|
||||
using System.Collections.Generic;
|
||||
namespace PluralKit.Core;
|
||||
|
||||
namespace PluralKit.Core
|
||||
/// <summary>
|
||||
/// Model for the `proxy_members` PL/pgSQL function in `functions.sql`
|
||||
/// </summary>
|
||||
public class ProxyMember
|
||||
{
|
||||
/// <summary>
|
||||
/// Model for the `proxy_members` PL/pgSQL function in `functions.sql`
|
||||
/// </summary>
|
||||
public class ProxyMember
|
||||
public ProxyMember() { }
|
||||
|
||||
public ProxyMember(string name, params ProxyTag[] tags)
|
||||
{
|
||||
public MemberId Id { get; }
|
||||
public IReadOnlyCollection<ProxyTag> ProxyTags { get; } = new ProxyTag[0];
|
||||
public bool KeepProxy { get; }
|
||||
|
||||
public string? ServerName { get; }
|
||||
public string? DisplayName { get; }
|
||||
public string Name { get; } = "";
|
||||
|
||||
public string? ServerAvatar { get; }
|
||||
public string? Avatar { get; }
|
||||
|
||||
|
||||
public bool AllowAutoproxy { get; }
|
||||
public string? Color { get; }
|
||||
|
||||
public string ProxyName(MessageContext ctx)
|
||||
{
|
||||
var memberName = ServerName ?? DisplayName ?? Name;
|
||||
if (!ctx.TagEnabled)
|
||||
return memberName;
|
||||
|
||||
if (ctx.SystemGuildTag != null)
|
||||
return $"{memberName} {ctx.SystemGuildTag}";
|
||||
else if (ctx.SystemTag != null)
|
||||
return $"{memberName} {ctx.SystemTag}";
|
||||
else return memberName;
|
||||
}
|
||||
public string? ProxyAvatar(MessageContext ctx) => ServerAvatar ?? Avatar ?? ctx.SystemAvatar;
|
||||
|
||||
public ProxyMember() { }
|
||||
|
||||
public ProxyMember(string name, params ProxyTag[] tags)
|
||||
{
|
||||
Name = name;
|
||||
ProxyTags = tags;
|
||||
}
|
||||
Name = name;
|
||||
ProxyTags = tags;
|
||||
}
|
||||
|
||||
public MemberId Id { get; }
|
||||
public IReadOnlyCollection<ProxyTag> ProxyTags { get; } = new ProxyTag[0];
|
||||
public bool KeepProxy { get; }
|
||||
|
||||
public string? ServerName { get; }
|
||||
public string? DisplayName { get; }
|
||||
public string Name { get; } = "";
|
||||
|
||||
public string? ServerAvatar { get; }
|
||||
public string? Avatar { get; }
|
||||
|
||||
|
||||
public bool AllowAutoproxy { get; }
|
||||
public string? Color { get; }
|
||||
|
||||
public string ProxyName(MessageContext ctx)
|
||||
{
|
||||
var memberName = ServerName ?? DisplayName ?? Name;
|
||||
if (!ctx.TagEnabled)
|
||||
return memberName;
|
||||
|
||||
if (ctx.SystemGuildTag != null)
|
||||
return $"{memberName} {ctx.SystemGuildTag}";
|
||||
if (ctx.SystemTag != null)
|
||||
return $"{memberName} {ctx.SystemTag}";
|
||||
return memberName;
|
||||
}
|
||||
|
||||
public string? ProxyAvatar(MessageContext ctx) => ServerAvatar ?? Avatar ?? ctx.SystemAvatar;
|
||||
}
|
@@ -1,26 +1,28 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public interface IDatabase
|
||||
{
|
||||
public interface IDatabase
|
||||
{
|
||||
Task ApplyMigrations();
|
||||
Task<IPKConnection> Obtain();
|
||||
Task Execute(Func<IPKConnection, Task> func);
|
||||
Task<T> Execute<T>(Func<IPKConnection, Task<T>> func);
|
||||
IAsyncEnumerable<T> Execute<T>(Func<IPKConnection, IAsyncEnumerable<T>> func);
|
||||
Task<int> ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "");
|
||||
Task<int> ExecuteQuery(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = "");
|
||||
Task<T> QueryFirst<T>(Query q, string extraSql = "", [CallerMemberName] string queryName = "");
|
||||
Task<T> QueryFirst<T>(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = "");
|
||||
Task<IEnumerable<T>> Query<T>(Query q, [CallerMemberName] string queryName = "");
|
||||
IAsyncEnumerable<T> QueryStream<T>(Query q, [CallerMemberName] string queryName = "");
|
||||
Task<T> QuerySingleProcedure<T>(string queryName, object param);
|
||||
Task<IEnumerable<T>> QueryProcedure<T>(string queryName, object param);
|
||||
}
|
||||
Task ApplyMigrations();
|
||||
Task<IPKConnection> Obtain();
|
||||
Task Execute(Func<IPKConnection, Task> func);
|
||||
Task<T> Execute<T>(Func<IPKConnection, Task<T>> func);
|
||||
IAsyncEnumerable<T> Execute<T>(Func<IPKConnection, IAsyncEnumerable<T>> func);
|
||||
Task<int> ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "");
|
||||
|
||||
Task<int> ExecuteQuery(IPKConnection? conn, Query q, string extraSql = "",
|
||||
[CallerMemberName] string queryName = "");
|
||||
|
||||
Task<T> QueryFirst<T>(Query q, string extraSql = "", [CallerMemberName] string queryName = "");
|
||||
|
||||
Task<T> QueryFirst<T>(IPKConnection? conn, Query q, string extraSql = "",
|
||||
[CallerMemberName] string queryName = "");
|
||||
|
||||
Task<IEnumerable<T>> Query<T>(Query q, [CallerMemberName] string queryName = "");
|
||||
IAsyncEnumerable<T> QueryStream<T>(Query q, [CallerMemberName] string queryName = "");
|
||||
Task<T> QuerySingleProcedure<T>(string queryName, object param);
|
||||
Task<IEnumerable<T>> QueryProcedure<T>(string queryName, object param);
|
||||
}
|
@@ -1,17 +1,14 @@
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
public async Task UpdateAccount(ulong id, AccountPatch patch)
|
||||
{
|
||||
public async Task UpdateAccount(ulong id, AccountPatch patch)
|
||||
{
|
||||
_logger.Information("Updated account {accountId}: {@AccountPatch}", id, patch);
|
||||
var query = patch.Apply(new Query("accounts").Where("uid", id));
|
||||
_ = _dispatch.Dispatch(id, patch);
|
||||
await _db.ExecuteQuery(query, extraSql: "returning *");
|
||||
}
|
||||
_logger.Information("Updated account {accountId}: {@AccountPatch}", id, patch);
|
||||
var query = patch.Apply(new Query("accounts").Where("uid", id));
|
||||
_ = _dispatch.Dispatch(id, patch);
|
||||
await _db.ExecuteQuery(query, "returning *");
|
||||
}
|
||||
}
|
@@ -1,39 +1,36 @@
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
public Task SaveCommandMessage(ulong messageId, ulong channelId, ulong authorId)
|
||||
{
|
||||
public Task SaveCommandMessage(ulong messageId, ulong channelId, ulong authorId)
|
||||
var query = new Query("command_messages").AsInsert(new
|
||||
{
|
||||
var query = new Query("command_messages").AsInsert(new
|
||||
{
|
||||
message_id = messageId,
|
||||
channel_id = channelId,
|
||||
author_id = authorId,
|
||||
});
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
|
||||
public Task<CommandMessage?> GetCommandMessage(ulong messageId)
|
||||
{
|
||||
var query = new Query("command_messages").Where("message_id", messageId);
|
||||
return _db.QueryFirst<CommandMessage?>(query);
|
||||
}
|
||||
|
||||
public Task<int> DeleteCommandMessagesBefore(ulong messageIdThreshold)
|
||||
{
|
||||
var query = new Query("command_messages").AsDelete().Where("message_id", "<", messageIdThreshold);
|
||||
return _db.QueryFirst<int>(query);
|
||||
}
|
||||
message_id = messageId,
|
||||
channel_id = channelId,
|
||||
author_id = authorId,
|
||||
});
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
|
||||
public class CommandMessage
|
||||
public Task<CommandMessage?> GetCommandMessage(ulong messageId)
|
||||
{
|
||||
public ulong AuthorId { get; set; }
|
||||
public ulong MessageId { get; set; }
|
||||
public ulong ChannelId { get; set; }
|
||||
var query = new Query("command_messages").Where("message_id", messageId);
|
||||
return _db.QueryFirst<CommandMessage?>(query);
|
||||
}
|
||||
|
||||
public Task<int> DeleteCommandMessagesBefore(ulong messageIdThreshold)
|
||||
{
|
||||
var query = new Query("command_messages").AsDelete().Where("message_id", "<", messageIdThreshold);
|
||||
return _db.QueryFirst<int>(query);
|
||||
}
|
||||
}
|
||||
|
||||
public class CommandMessage
|
||||
{
|
||||
public ulong AuthorId { get; set; }
|
||||
public ulong MessageId { get; set; }
|
||||
public ulong ChannelId { get; set; }
|
||||
}
|
@@ -1,23 +1,11 @@
|
||||
using System.Collections.Generic;
|
||||
using System.Threading.Tasks;
|
||||
namespace PluralKit.Core;
|
||||
|
||||
namespace PluralKit.Core
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel)
|
||||
=> _db.QuerySingleProcedure<MessageContext>("message_context", new
|
||||
{
|
||||
account_id = account,
|
||||
guild_id = guild,
|
||||
channel_id = channel
|
||||
});
|
||||
public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel)
|
||||
=> _db.QuerySingleProcedure<MessageContext>("message_context",
|
||||
new { account_id = account, guild_id = guild, channel_id = channel });
|
||||
|
||||
public Task<IEnumerable<ProxyMember>> GetProxyMembers(ulong account, ulong guild)
|
||||
=> _db.QueryProcedure<ProxyMember>("proxy_members", new
|
||||
{
|
||||
account_id = account,
|
||||
guild_id = guild
|
||||
});
|
||||
}
|
||||
public Task<IEnumerable<ProxyMember>> GetProxyMembers(ulong account, ulong guild)
|
||||
=> _db.QueryProcedure<ProxyMember>("proxy_members", new { account_id = account, guild_id = guild });
|
||||
}
|
@@ -1,98 +1,86 @@
|
||||
#nullable enable
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Newtonsoft.Json.Linq;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
public Task<PKGroup?> GetGroup(GroupId id)
|
||||
{
|
||||
public Task<PKGroup?> GetGroup(GroupId id)
|
||||
{
|
||||
var query = new Query("groups").Where("id", id);
|
||||
return _db.QueryFirst<PKGroup?>(query);
|
||||
}
|
||||
var query = new Query("groups").Where("id", id);
|
||||
return _db.QueryFirst<PKGroup?>(query);
|
||||
}
|
||||
|
||||
public Task<PKGroup?> GetGroupByName(SystemId system, string name)
|
||||
{
|
||||
var query = new Query("groups").Where("system", system).WhereRaw("lower(name) = lower(?)", name.ToLower());
|
||||
return _db.QueryFirst<PKGroup?>(query);
|
||||
}
|
||||
public Task<PKGroup?> GetGroupByName(SystemId system, string name)
|
||||
{
|
||||
var query = new Query("groups").Where("system", system).WhereRaw("lower(name) = lower(?)", name.ToLower());
|
||||
return _db.QueryFirst<PKGroup?>(query);
|
||||
}
|
||||
|
||||
public Task<PKGroup?> GetGroupByDisplayName(SystemId system, string display_name)
|
||||
{
|
||||
var query = new Query("groups").Where("system", system).WhereRaw("lower(display_name) = lower(?)", display_name.ToLower());
|
||||
return _db.QueryFirst<PKGroup?>(query);
|
||||
}
|
||||
public Task<PKGroup?> GetGroupByDisplayName(SystemId system, string display_name)
|
||||
{
|
||||
var query = new Query("groups").Where("system", system)
|
||||
.WhereRaw("lower(display_name) = lower(?)", display_name.ToLower());
|
||||
return _db.QueryFirst<PKGroup?>(query);
|
||||
}
|
||||
|
||||
public Task<PKGroup?> GetGroupByHid(string hid, SystemId? system = null)
|
||||
{
|
||||
var query = new Query("groups").Where("hid", hid.ToLower());
|
||||
if (system != null)
|
||||
query = query.Where("system", system);
|
||||
return _db.QueryFirst<PKGroup?>(query);
|
||||
}
|
||||
public Task<PKGroup?> GetGroupByHid(string hid, SystemId? system = null)
|
||||
{
|
||||
var query = new Query("groups").Where("hid", hid.ToLower());
|
||||
if (system != null)
|
||||
query = query.Where("system", system);
|
||||
return _db.QueryFirst<PKGroup?>(query);
|
||||
}
|
||||
|
||||
public Task<PKGroup?> GetGroupByGuid(Guid uuid)
|
||||
{
|
||||
var query = new Query("groups").Where("uuid", uuid);
|
||||
return _db.QueryFirst<PKGroup?>(query);
|
||||
}
|
||||
public Task<PKGroup?> GetGroupByGuid(Guid uuid)
|
||||
{
|
||||
var query = new Query("groups").Where("uuid", uuid);
|
||||
return _db.QueryFirst<PKGroup?>(query);
|
||||
}
|
||||
|
||||
public Task<int> GetGroupMemberCount(GroupId id, PrivacyLevel? privacyFilter = null)
|
||||
{
|
||||
var query = new Query("group_members")
|
||||
.SelectRaw("count(*)")
|
||||
.Where("group_members.group_id", id);
|
||||
public Task<int> GetGroupMemberCount(GroupId id, PrivacyLevel? privacyFilter = null)
|
||||
{
|
||||
var query = new Query("group_members")
|
||||
.SelectRaw("count(*)")
|
||||
.Where("group_members.group_id", id);
|
||||
|
||||
if (privacyFilter != null) query = query
|
||||
if (privacyFilter != null)
|
||||
query = query
|
||||
.Join("members", "group_members.member_id", "members.id")
|
||||
.Where("members.member_visibility", privacyFilter);
|
||||
|
||||
return _db.QueryFirst<int>(query);
|
||||
}
|
||||
return _db.QueryFirst<int>(query);
|
||||
}
|
||||
|
||||
public async Task<PKGroup> CreateGroup(SystemId system, string name, IPKConnection? conn = null)
|
||||
{
|
||||
var query = new Query("groups").AsInsert(new
|
||||
{
|
||||
hid = new UnsafeLiteral("find_free_group_hid()"),
|
||||
system = system,
|
||||
name = name
|
||||
});
|
||||
var group = await _db.QueryFirst<PKGroup>(conn, query, extraSql: "returning *");
|
||||
_logger.Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name);
|
||||
return group;
|
||||
}
|
||||
public async Task<PKGroup> CreateGroup(SystemId system, string name, IPKConnection? conn = null)
|
||||
{
|
||||
var query = new Query("groups").AsInsert(new { hid = new UnsafeLiteral("find_free_group_hid()"), system, name });
|
||||
var group = await _db.QueryFirst<PKGroup>(conn, query, "returning *");
|
||||
_logger.Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name);
|
||||
return group;
|
||||
}
|
||||
|
||||
public async Task<PKGroup> UpdateGroup(GroupId id, GroupPatch patch, IPKConnection? conn = null)
|
||||
{
|
||||
_logger.Information("Updated {GroupId}: {@GroupPatch}", id, patch);
|
||||
var query = patch.Apply(new Query("groups").Where("id", id));
|
||||
var group = await _db.QueryFirst<PKGroup>(conn, query, extraSql: "returning *");
|
||||
public async Task<PKGroup> UpdateGroup(GroupId id, GroupPatch patch, IPKConnection? conn = null)
|
||||
{
|
||||
_logger.Information("Updated {GroupId}: {@GroupPatch}", id, patch);
|
||||
var query = patch.Apply(new Query("groups").Where("id", id));
|
||||
var group = await _db.QueryFirst<PKGroup>(conn, query, "returning *");
|
||||
|
||||
if (conn == null)
|
||||
_ = _dispatch.Dispatch(id, new()
|
||||
{
|
||||
Event = DispatchEvent.UPDATE_GROUP,
|
||||
EventData = patch.ToJson(),
|
||||
});
|
||||
return group;
|
||||
}
|
||||
if (conn == null)
|
||||
_ = _dispatch.Dispatch(id,
|
||||
new UpdateDispatchData { Event = DispatchEvent.UPDATE_GROUP, EventData = patch.ToJson() });
|
||||
return group;
|
||||
}
|
||||
|
||||
public async Task DeleteGroup(GroupId group)
|
||||
{
|
||||
var oldGroup = await GetGroup(group);
|
||||
public async Task DeleteGroup(GroupId group)
|
||||
{
|
||||
var oldGroup = await GetGroup(group);
|
||||
|
||||
_logger.Information("Deleted {GroupId}", group);
|
||||
var query = new Query("groups").AsDelete().Where("id", group);
|
||||
await _db.ExecuteQuery(query);
|
||||
_logger.Information("Deleted {GroupId}", group);
|
||||
var query = new Query("groups").AsDelete().Where("id", group);
|
||||
await _db.ExecuteQuery(query);
|
||||
|
||||
if (oldGroup != null)
|
||||
_ = _dispatch.Dispatch(oldGroup.System, oldGroup.Uuid, DispatchEvent.DELETE_GROUP);
|
||||
}
|
||||
if (oldGroup != null)
|
||||
_ = _dispatch.Dispatch(oldGroup.System, oldGroup.Uuid, DispatchEvent.DELETE_GROUP);
|
||||
}
|
||||
}
|
@@ -1,115 +1,110 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
public IAsyncEnumerable<PKGroup> GetMemberGroups(MemberId id)
|
||||
{
|
||||
public IAsyncEnumerable<PKGroup> GetMemberGroups(MemberId id)
|
||||
{
|
||||
var query = new Query("group_members")
|
||||
.Select("groups.*")
|
||||
.Join("groups", "group_members.group_id", "groups.id")
|
||||
.Where("group_members.member_id", id);
|
||||
return _db.QueryStream<PKGroup>(query);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<PKMember> GetGroupMembers(GroupId id)
|
||||
{
|
||||
var query = new Query("group_members")
|
||||
.Select("members.*")
|
||||
.Join("members", "group_members.member_id", "members.id")
|
||||
.Where("group_members.group_id", id);
|
||||
return _db.QueryStream<PKMember>(query);
|
||||
}
|
||||
|
||||
public Task<IEnumerable<GroupMember>> GetGroupMemberInfo(IEnumerable<GroupId> ids)
|
||||
{
|
||||
return _db.Query<GroupMember>(new Query("group_members")
|
||||
.LeftJoin("groups", "groups.id", "group_members.group_id")
|
||||
.LeftJoin("members", "members.id", "group_members.member_id")
|
||||
.Select("groups.hid as group", "members.hid as member", "members.uuid as member_uuid", "members.member_visibility")
|
||||
.WhereIn("group_members.group_id", ids.Select(x => x.Value).ToArray()));
|
||||
}
|
||||
|
||||
// todo: add this to metrics tracking
|
||||
public async Task AddGroupsToMember(MemberId member, IReadOnlyCollection<GroupId> groups)
|
||||
{
|
||||
await using var conn = await _db.Obtain();
|
||||
await using var w =
|
||||
conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)");
|
||||
foreach (var group in groups)
|
||||
{
|
||||
await w.StartRowAsync();
|
||||
await w.WriteAsync(group.Value);
|
||||
await w.WriteAsync(member.Value);
|
||||
}
|
||||
|
||||
await w.CompleteAsync();
|
||||
_logger.Information("Added member {MemberId} to groups {GroupIds}", member, groups);
|
||||
}
|
||||
|
||||
public Task RemoveGroupsFromMember(MemberId member, IReadOnlyCollection<GroupId> groups)
|
||||
{
|
||||
_logger.Information("Removed groups from {MemberId}: {GroupIds}", member, groups);
|
||||
var query = new Query("group_members").AsDelete()
|
||||
.Where("member_id", member)
|
||||
.WhereIn("group_id", groups);
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
|
||||
// todo: add this to metrics tracking
|
||||
public async Task AddMembersToGroup(GroupId group, IReadOnlyCollection<MemberId> members)
|
||||
{
|
||||
await using var conn = await _db.Obtain();
|
||||
await using var w =
|
||||
conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)");
|
||||
foreach (var member in members)
|
||||
{
|
||||
await w.StartRowAsync();
|
||||
await w.WriteAsync(group.Value);
|
||||
await w.WriteAsync(member.Value);
|
||||
}
|
||||
|
||||
await w.CompleteAsync();
|
||||
_logger.Information("Added members to {GroupId}: {MemberIds}", group, members);
|
||||
}
|
||||
|
||||
public Task RemoveMembersFromGroup(GroupId group, IReadOnlyCollection<MemberId> members)
|
||||
{
|
||||
_logger.Information("Removed members from {GroupId}: {MemberIds}", group, members);
|
||||
var query = new Query("group_members").AsDelete()
|
||||
.Where("group_id", group)
|
||||
.WhereIn("member_id", members);
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
|
||||
public Task ClearGroupMembers(GroupId group)
|
||||
{
|
||||
_logger.Information("Cleared members of {GroupId}", group);
|
||||
var query = new Query("group_members").AsDelete()
|
||||
.Where("group_id", group);
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
|
||||
public Task ClearMemberGroups(MemberId member)
|
||||
{
|
||||
_logger.Information("Cleared groups of {GroupId}", member);
|
||||
var query = new Query("group_members").AsDelete()
|
||||
.Where("member_id", member);
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
var query = new Query("group_members")
|
||||
.Select("groups.*")
|
||||
.Join("groups", "group_members.group_id", "groups.id")
|
||||
.Where("group_members.member_id", id);
|
||||
return _db.QueryStream<PKGroup>(query);
|
||||
}
|
||||
|
||||
public class GroupMember
|
||||
public IAsyncEnumerable<PKMember> GetGroupMembers(GroupId id)
|
||||
{
|
||||
public string Group { get; set; }
|
||||
public string Member { get; set; }
|
||||
public Guid MemberUuid { get; set; }
|
||||
public PrivacyLevel MemberVisibility { get; set; }
|
||||
var query = new Query("group_members")
|
||||
.Select("members.*")
|
||||
.Join("members", "group_members.member_id", "members.id")
|
||||
.Where("group_members.group_id", id);
|
||||
return _db.QueryStream<PKMember>(query);
|
||||
}
|
||||
|
||||
public Task<IEnumerable<GroupMember>> GetGroupMemberInfo(IEnumerable<GroupId> ids)
|
||||
{
|
||||
return _db.Query<GroupMember>(new Query("group_members")
|
||||
.LeftJoin("groups", "groups.id", "group_members.group_id")
|
||||
.LeftJoin("members", "members.id", "group_members.member_id")
|
||||
.Select("groups.hid as group", "members.hid as member", "members.uuid as member_uuid",
|
||||
"members.member_visibility")
|
||||
.WhereIn("group_members.group_id", ids.Select(x => x.Value).ToArray()));
|
||||
}
|
||||
|
||||
// todo: add this to metrics tracking
|
||||
public async Task AddGroupsToMember(MemberId member, IReadOnlyCollection<GroupId> groups)
|
||||
{
|
||||
await using var conn = await _db.Obtain();
|
||||
await using var w =
|
||||
conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)");
|
||||
foreach (var group in groups)
|
||||
{
|
||||
await w.StartRowAsync();
|
||||
await w.WriteAsync(group.Value);
|
||||
await w.WriteAsync(member.Value);
|
||||
}
|
||||
|
||||
await w.CompleteAsync();
|
||||
_logger.Information("Added member {MemberId} to groups {GroupIds}", member, groups);
|
||||
}
|
||||
|
||||
public Task RemoveGroupsFromMember(MemberId member, IReadOnlyCollection<GroupId> groups)
|
||||
{
|
||||
_logger.Information("Removed groups from {MemberId}: {GroupIds}", member, groups);
|
||||
var query = new Query("group_members").AsDelete()
|
||||
.Where("member_id", member)
|
||||
.WhereIn("group_id", groups);
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
|
||||
// todo: add this to metrics tracking
|
||||
public async Task AddMembersToGroup(GroupId group, IReadOnlyCollection<MemberId> members)
|
||||
{
|
||||
await using var conn = await _db.Obtain();
|
||||
await using var w =
|
||||
conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)");
|
||||
foreach (var member in members)
|
||||
{
|
||||
await w.StartRowAsync();
|
||||
await w.WriteAsync(group.Value);
|
||||
await w.WriteAsync(member.Value);
|
||||
}
|
||||
|
||||
await w.CompleteAsync();
|
||||
_logger.Information("Added members to {GroupId}: {MemberIds}", group, members);
|
||||
}
|
||||
|
||||
public Task RemoveMembersFromGroup(GroupId group, IReadOnlyCollection<MemberId> members)
|
||||
{
|
||||
_logger.Information("Removed members from {GroupId}: {MemberIds}", group, members);
|
||||
var query = new Query("group_members").AsDelete()
|
||||
.Where("group_id", group)
|
||||
.WhereIn("member_id", members);
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
|
||||
public Task ClearGroupMembers(GroupId group)
|
||||
{
|
||||
_logger.Information("Cleared members of {GroupId}", group);
|
||||
var query = new Query("group_members").AsDelete()
|
||||
.Where("group_id", group);
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
|
||||
public Task ClearMemberGroups(MemberId member)
|
||||
{
|
||||
_logger.Information("Cleared groups of {GroupId}", member);
|
||||
var query = new Query("group_members").AsDelete()
|
||||
.Where("member_id", member);
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
}
|
||||
|
||||
public class GroupMember
|
||||
{
|
||||
public string Group { get; set; }
|
||||
public string Member { get; set; }
|
||||
public Guid MemberUuid { get; set; }
|
||||
public PrivacyLevel MemberVisibility { get; set; }
|
||||
}
|
@@ -1,77 +1,66 @@
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
public Task<GuildConfig> GetGuild(ulong guild)
|
||||
{
|
||||
public Task<GuildConfig> GetGuild(ulong guild)
|
||||
{
|
||||
var query = new Query("servers").AsInsert(new { id = guild });
|
||||
// sqlkata doesn't support postgres on conflict, so we just hack it on here
|
||||
return _db.QueryFirst<GuildConfig>(query, "on conflict (id) do update set id = @$1 returning *");
|
||||
}
|
||||
var query = new Query("servers").AsInsert(new { id = guild });
|
||||
// sqlkata doesn't support postgres on conflict, so we just hack it on here
|
||||
return _db.QueryFirst<GuildConfig>(query, "on conflict (id) do update set id = @$1 returning *");
|
||||
}
|
||||
|
||||
public Task UpdateGuild(ulong guild, GuildPatch patch)
|
||||
{
|
||||
_logger.Information("Updated guild {GuildId}: {@GuildPatch}", guild, patch);
|
||||
var query = patch.Apply(new Query("servers").Where("id", guild));
|
||||
return _db.ExecuteQuery(query, extraSql: "returning *");
|
||||
}
|
||||
public Task UpdateGuild(ulong guild, GuildPatch patch)
|
||||
{
|
||||
_logger.Information("Updated guild {GuildId}: {@GuildPatch}", guild, patch);
|
||||
var query = patch.Apply(new Query("servers").Where("id", guild));
|
||||
return _db.ExecuteQuery(query, "returning *");
|
||||
}
|
||||
|
||||
|
||||
public Task<SystemGuildSettings> GetSystemGuild(ulong guild, SystemId system, bool defaultInsert = true)
|
||||
{
|
||||
if (!defaultInsert)
|
||||
return _db.QueryFirst<SystemGuildSettings>(new Query("system_guild")
|
||||
.Where("guild", guild)
|
||||
.Where("system", system)
|
||||
);
|
||||
|
||||
var query = new Query("system_guild").AsInsert(new
|
||||
{
|
||||
guild = guild,
|
||||
system = system
|
||||
});
|
||||
return _db.QueryFirst<SystemGuildSettings>(query,
|
||||
extraSql: "on conflict (guild, system) do update set guild = $1, system = $2 returning *"
|
||||
public Task<SystemGuildSettings> GetSystemGuild(ulong guild, SystemId system, bool defaultInsert = true)
|
||||
{
|
||||
if (!defaultInsert)
|
||||
return _db.QueryFirst<SystemGuildSettings>(new Query("system_guild")
|
||||
.Where("guild", guild)
|
||||
.Where("system", system)
|
||||
);
|
||||
}
|
||||
|
||||
public async Task<SystemGuildSettings> UpdateSystemGuild(SystemId system, ulong guild, SystemGuildPatch patch)
|
||||
{
|
||||
_logger.Information("Updated {SystemId} in guild {GuildId}: {@SystemGuildPatch}", system, guild, patch);
|
||||
var query = patch.Apply(new Query("system_guild").Where("system", system).Where("guild", guild));
|
||||
var settings = await _db.QueryFirst<SystemGuildSettings>(query, extraSql: "returning *");
|
||||
_ = _dispatch.Dispatch(system, guild, patch);
|
||||
return settings;
|
||||
}
|
||||
var query = new Query("system_guild").AsInsert(new { guild, system });
|
||||
return _db.QueryFirst<SystemGuildSettings>(query,
|
||||
"on conflict (guild, system) do update set guild = $1, system = $2 returning *"
|
||||
);
|
||||
}
|
||||
|
||||
public Task<MemberGuildSettings> GetMemberGuild(ulong guild, MemberId member, bool defaultInsert = true)
|
||||
{
|
||||
if (!defaultInsert)
|
||||
return _db.QueryFirst<MemberGuildSettings>(new Query("member_guild")
|
||||
.Where("guild", guild)
|
||||
.Where("member", member)
|
||||
);
|
||||
public async Task<SystemGuildSettings> UpdateSystemGuild(SystemId system, ulong guild, SystemGuildPatch patch)
|
||||
{
|
||||
_logger.Information("Updated {SystemId} in guild {GuildId}: {@SystemGuildPatch}", system, guild, patch);
|
||||
var query = patch.Apply(new Query("system_guild").Where("system", system).Where("guild", guild));
|
||||
var settings = await _db.QueryFirst<SystemGuildSettings>(query, "returning *");
|
||||
_ = _dispatch.Dispatch(system, guild, patch);
|
||||
return settings;
|
||||
}
|
||||
|
||||
var query = new Query("member_guild").AsInsert(new
|
||||
{
|
||||
guild = guild,
|
||||
member = member
|
||||
});
|
||||
return _db.QueryFirst<MemberGuildSettings>(query,
|
||||
extraSql: "on conflict (guild, member) do update set guild = $1, member = $2 returning *"
|
||||
public Task<MemberGuildSettings> GetMemberGuild(ulong guild, MemberId member, bool defaultInsert = true)
|
||||
{
|
||||
if (!defaultInsert)
|
||||
return _db.QueryFirst<MemberGuildSettings>(new Query("member_guild")
|
||||
.Where("guild", guild)
|
||||
.Where("member", member)
|
||||
);
|
||||
}
|
||||
|
||||
public Task<MemberGuildSettings> UpdateMemberGuild(MemberId member, ulong guild, MemberGuildPatch patch)
|
||||
{
|
||||
_logger.Information("Updated {MemberId} in guild {GuildId}: {@MemberGuildPatch}", member, guild, patch);
|
||||
var query = patch.Apply(new Query("member_guild").Where("member", member).Where("guild", guild));
|
||||
_ = _dispatch.Dispatch(member, guild, patch);
|
||||
return _db.QueryFirst<MemberGuildSettings>(query, extraSql: "returning *");
|
||||
}
|
||||
var query = new Query("member_guild").AsInsert(new { guild, member });
|
||||
return _db.QueryFirst<MemberGuildSettings>(query,
|
||||
"on conflict (guild, member) do update set guild = $1, member = $2 returning *"
|
||||
);
|
||||
}
|
||||
|
||||
public Task<MemberGuildSettings> UpdateMemberGuild(MemberId member, ulong guild, MemberGuildPatch patch)
|
||||
{
|
||||
_logger.Information("Updated {MemberId} in guild {GuildId}: {@MemberGuildPatch}", member, guild, patch);
|
||||
var query = patch.Apply(new Query("member_guild").Where("member", member).Where("guild", guild));
|
||||
_ = _dispatch.Dispatch(member, guild, patch);
|
||||
return _db.QueryFirst<MemberGuildSettings>(query, "returning *");
|
||||
}
|
||||
}
|
@@ -1,102 +1,92 @@
|
||||
#nullable enable
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Newtonsoft.Json.Linq;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
public Task<PKMember?> GetMember(MemberId id)
|
||||
{
|
||||
public Task<PKMember?> GetMember(MemberId id)
|
||||
var query = new Query("members").Where("id", id);
|
||||
return _db.QueryFirst<PKMember?>(query);
|
||||
}
|
||||
|
||||
public Task<PKMember?> GetMemberByHid(string hid, SystemId? system = null)
|
||||
{
|
||||
var query = new Query("members").Where("hid", hid.ToLower());
|
||||
if (system != null)
|
||||
query = query.Where("system", system);
|
||||
return _db.QueryFirst<PKMember?>(query);
|
||||
}
|
||||
|
||||
public Task<PKMember?> GetMemberByGuid(Guid uuid)
|
||||
{
|
||||
var query = new Query("members").Where("uuid", uuid);
|
||||
return _db.QueryFirst<PKMember?>(query);
|
||||
}
|
||||
|
||||
public Task<PKMember?> GetMemberByName(SystemId system, string name)
|
||||
{
|
||||
var query = new Query("members").WhereRaw(
|
||||
"lower(name) = lower(?)",
|
||||
name.ToLower()
|
||||
).Where("system", system);
|
||||
return _db.QueryFirst<PKMember?>(query);
|
||||
}
|
||||
|
||||
public Task<PKMember?> GetMemberByDisplayName(SystemId system, string name)
|
||||
{
|
||||
var query = new Query("members").WhereRaw(
|
||||
"lower(display_name) = lower(?)",
|
||||
name.ToLower()
|
||||
).Where("system", system);
|
||||
return _db.QueryFirst<PKMember?>(query);
|
||||
}
|
||||
|
||||
public Task<IEnumerable<Guid>> GetMemberGuids(IEnumerable<MemberId> ids)
|
||||
{
|
||||
var query = new Query("members")
|
||||
.Select("uuid")
|
||||
.WhereIn("id", ids);
|
||||
|
||||
return _db.Query<Guid>(query);
|
||||
}
|
||||
|
||||
public async Task<PKMember> CreateMember(SystemId systemId, string memberName, IPKConnection? conn = null)
|
||||
{
|
||||
var query = new Query("members").AsInsert(new
|
||||
{
|
||||
var query = new Query("members").Where("id", id);
|
||||
return _db.QueryFirst<PKMember?>(query);
|
||||
}
|
||||
hid = new UnsafeLiteral("find_free_member_hid()"),
|
||||
system = systemId,
|
||||
name = memberName
|
||||
});
|
||||
var member = await _db.QueryFirst<PKMember>(conn, query, "returning *");
|
||||
_logger.Information("Created {MemberId} in {SystemId}: {MemberName}",
|
||||
member.Id, systemId, memberName);
|
||||
return member;
|
||||
}
|
||||
|
||||
public Task<PKMember?> GetMemberByHid(string hid, SystemId? system = null)
|
||||
{
|
||||
var query = new Query("members").Where("hid", hid.ToLower());
|
||||
if (system != null)
|
||||
query = query.Where("system", system);
|
||||
return _db.QueryFirst<PKMember?>(query);
|
||||
}
|
||||
public Task<PKMember> UpdateMember(MemberId id, MemberPatch patch, IPKConnection? conn = null)
|
||||
{
|
||||
_logger.Information("Updated {MemberId}: {@MemberPatch}", id, patch);
|
||||
var query = patch.Apply(new Query("members").Where("id", id));
|
||||
|
||||
public Task<PKMember?> GetMemberByGuid(Guid uuid)
|
||||
{
|
||||
var query = new Query("members").Where("uuid", uuid);
|
||||
return _db.QueryFirst<PKMember?>(query);
|
||||
}
|
||||
if (conn == null)
|
||||
_ = _dispatch.Dispatch(id,
|
||||
new UpdateDispatchData { Event = DispatchEvent.UPDATE_MEMBER, EventData = patch.ToJson() });
|
||||
return _db.QueryFirst<PKMember>(conn, query, "returning *");
|
||||
}
|
||||
|
||||
public Task<PKMember?> GetMemberByName(SystemId system, string name)
|
||||
{
|
||||
var query = new Query("members").WhereRaw(
|
||||
"lower(name) = lower(?)",
|
||||
name.ToLower()
|
||||
).Where("system", system);
|
||||
return _db.QueryFirst<PKMember?>(query);
|
||||
}
|
||||
public async Task DeleteMember(MemberId id)
|
||||
{
|
||||
var oldMember = await GetMember(id);
|
||||
|
||||
public Task<PKMember?> GetMemberByDisplayName(SystemId system, string name)
|
||||
{
|
||||
var query = new Query("members").WhereRaw(
|
||||
"lower(display_name) = lower(?)",
|
||||
name.ToLower()
|
||||
).Where("system", system);
|
||||
return _db.QueryFirst<PKMember?>(query);
|
||||
}
|
||||
_logger.Information("Deleted {MemberId}", id);
|
||||
var query = new Query("members").AsDelete().Where("id", id);
|
||||
await _db.ExecuteQuery(query);
|
||||
|
||||
public Task<IEnumerable<Guid>> GetMemberGuids(IEnumerable<MemberId> ids)
|
||||
{
|
||||
var query = new Query("members")
|
||||
.Select("uuid")
|
||||
.WhereIn("id", ids);
|
||||
|
||||
return _db.Query<Guid>(query);
|
||||
}
|
||||
|
||||
public async Task<PKMember> CreateMember(SystemId systemId, string memberName, IPKConnection? conn = null)
|
||||
{
|
||||
var query = new Query("members").AsInsert(new
|
||||
{
|
||||
hid = new UnsafeLiteral("find_free_member_hid()"),
|
||||
system = systemId,
|
||||
name = memberName
|
||||
});
|
||||
var member = await _db.QueryFirst<PKMember>(conn, query, "returning *");
|
||||
_logger.Information("Created {MemberId} in {SystemId}: {MemberName}",
|
||||
member.Id, systemId, memberName);
|
||||
return member;
|
||||
}
|
||||
|
||||
public Task<PKMember> UpdateMember(MemberId id, MemberPatch patch, IPKConnection? conn = null)
|
||||
{
|
||||
_logger.Information("Updated {MemberId}: {@MemberPatch}", id, patch);
|
||||
var query = patch.Apply(new Query("members").Where("id", id));
|
||||
|
||||
if (conn == null)
|
||||
_ = _dispatch.Dispatch(id, new()
|
||||
{
|
||||
Event = DispatchEvent.UPDATE_MEMBER,
|
||||
EventData = patch.ToJson(),
|
||||
});
|
||||
return _db.QueryFirst<PKMember>(conn, query, extraSql: "returning *");
|
||||
}
|
||||
|
||||
public async Task DeleteMember(MemberId id)
|
||||
{
|
||||
var oldMember = await GetMember(id);
|
||||
|
||||
_logger.Information("Deleted {MemberId}", id);
|
||||
var query = new Query("members").AsDelete().Where("id", id);
|
||||
await _db.ExecuteQuery(query);
|
||||
|
||||
// shh, compiler
|
||||
if (oldMember != null)
|
||||
_ = _dispatch.Dispatch(oldMember.System, oldMember.Uuid, DispatchEvent.DELETE_MEMBER);
|
||||
}
|
||||
// shh, compiler
|
||||
if (oldMember != null)
|
||||
_ = _dispatch.Dispatch(oldMember.System, oldMember.Uuid, DispatchEvent.DELETE_MEMBER);
|
||||
}
|
||||
}
|
@@ -1,74 +1,69 @@
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Dapper;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
public Task AddMessage(PKMessage msg)
|
||||
{
|
||||
public Task AddMessage(PKMessage msg)
|
||||
var query = new Query("messages").AsInsert(new
|
||||
{
|
||||
var query = new Query("messages").AsInsert(new
|
||||
{
|
||||
mid = msg.Mid,
|
||||
guild = msg.Guild,
|
||||
channel = msg.Channel,
|
||||
member = msg.Member,
|
||||
sender = msg.Sender,
|
||||
original_mid = msg.OriginalMid,
|
||||
});
|
||||
_logger.Debug("Stored message {@StoredMessage} in channel {Channel}", msg, msg.Channel);
|
||||
mid = msg.Mid,
|
||||
guild = msg.Guild,
|
||||
channel = msg.Channel,
|
||||
member = msg.Member,
|
||||
sender = msg.Sender,
|
||||
original_mid = msg.OriginalMid
|
||||
});
|
||||
_logger.Debug("Stored message {@StoredMessage} in channel {Channel}", msg, msg.Channel);
|
||||
|
||||
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
|
||||
return _db.ExecuteQuery(query, extraSql: "on conflict do nothing");
|
||||
}
|
||||
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
|
||||
return _db.ExecuteQuery(query, "on conflict do nothing");
|
||||
}
|
||||
|
||||
// todo: add a Mapper to QuerySingle and move this to SqlKata
|
||||
public async Task<FullMessage?> GetMessage(IPKConnection conn, ulong id)
|
||||
{
|
||||
FullMessage Mapper(PKMessage msg, PKMember member, PKSystem system) =>
|
||||
new FullMessage { Message = msg, System = system, Member = member };
|
||||
// todo: add a Mapper to QuerySingle and move this to SqlKata
|
||||
public async Task<FullMessage?> GetMessage(IPKConnection conn, ulong id)
|
||||
{
|
||||
FullMessage Mapper(PKMessage msg, PKMember member, PKSystem system) =>
|
||||
new() { Message = msg, System = system, Member = member };
|
||||
|
||||
var result = await conn.QueryAsync<PKMessage, PKMember, PKSystem, FullMessage>(
|
||||
"select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system",
|
||||
Mapper, new { Id = id });
|
||||
return result.FirstOrDefault();
|
||||
}
|
||||
var result = await conn.QueryAsync<PKMessage, PKMember, PKSystem, FullMessage>(
|
||||
"select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system",
|
||||
Mapper, new { Id = id });
|
||||
return result.FirstOrDefault();
|
||||
}
|
||||
|
||||
public async Task DeleteMessage(ulong id)
|
||||
{
|
||||
var query = new Query("messages").AsDelete().Where("mid", id);
|
||||
var rowCount = await _db.ExecuteQuery(query);
|
||||
if (rowCount > 0)
|
||||
_logger.Information("Deleted message {MessageId} from database", id);
|
||||
}
|
||||
public async Task DeleteMessage(ulong id)
|
||||
{
|
||||
var query = new Query("messages").AsDelete().Where("mid", id);
|
||||
var rowCount = await _db.ExecuteQuery(query);
|
||||
if (rowCount > 0)
|
||||
_logger.Information("Deleted message {MessageId} from database", id);
|
||||
}
|
||||
|
||||
public async Task DeleteMessagesBulk(IReadOnlyCollection<ulong> ids)
|
||||
{
|
||||
// Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong
|
||||
// Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway)
|
||||
var query = new Query("messages").AsDelete().WhereIn("mid", ids.Select(id => (long)id).ToArray());
|
||||
var rowCount = await _db.ExecuteQuery(query);
|
||||
if (rowCount > 0)
|
||||
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount,
|
||||
ids);
|
||||
}
|
||||
public async Task DeleteMessagesBulk(IReadOnlyCollection<ulong> ids)
|
||||
{
|
||||
// Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong
|
||||
// Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway)
|
||||
var query = new Query("messages").AsDelete().WhereIn("mid", ids.Select(id => (long)id).ToArray());
|
||||
var rowCount = await _db.ExecuteQuery(query);
|
||||
if (rowCount > 0)
|
||||
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount,
|
||||
ids);
|
||||
}
|
||||
|
||||
public Task<PKMessage?> GetLastMessage(ulong guildId, ulong channelId, ulong accountId)
|
||||
{
|
||||
// Want to index scan on the (guild, sender, mid) index so need the additional constraint
|
||||
var query = new Query("messages")
|
||||
.Where("guild", guildId)
|
||||
.Where("channel", channelId)
|
||||
.Where("sender", accountId)
|
||||
.OrderByDesc("mid")
|
||||
.Limit(1);
|
||||
public Task<PKMessage?> GetLastMessage(ulong guildId, ulong channelId, ulong accountId)
|
||||
{
|
||||
// Want to index scan on the (guild, sender, mid) index so need the additional constraint
|
||||
var query = new Query("messages")
|
||||
.Where("guild", guildId)
|
||||
.Where("channel", channelId)
|
||||
.Where("sender", accountId)
|
||||
.OrderByDesc("mid")
|
||||
.Limit(1);
|
||||
|
||||
return _db.QueryFirst<PKMessage?>(query);
|
||||
}
|
||||
return _db.QueryFirst<PKMessage?>(query);
|
||||
}
|
||||
}
|
@@ -1,30 +1,26 @@
|
||||
using System.Collections.Generic;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Dapper;
|
||||
|
||||
using NodaTime;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public Task<IEnumerable<PKShardInfo>> GetShards() =>
|
||||
_db.Execute(conn => conn.QueryAsync<PKShardInfo>("select * from shards order by id"));
|
||||
public Task<IEnumerable<PKShardInfo>> GetShards() =>
|
||||
_db.Execute(conn => conn.QueryAsync<PKShardInfo>("select * from shards order by id"));
|
||||
|
||||
public Task SetShardStatus(IPKConnection conn, int shard, PKShardInfo.ShardStatus status) =>
|
||||
conn.ExecuteAsync(
|
||||
"insert into shards (id, status) values (@Id, @Status) on conflict (id) do update set status = @Status",
|
||||
new { Id = shard, Status = status });
|
||||
public Task SetShardStatus(IPKConnection conn, int shard, PKShardInfo.ShardStatus status) =>
|
||||
conn.ExecuteAsync(
|
||||
"insert into shards (id, status) values (@Id, @Status) on conflict (id) do update set status = @Status",
|
||||
new { Id = shard, Status = status });
|
||||
|
||||
public Task RegisterShardHeartbeat(IPKConnection conn, int shard, Duration ping) =>
|
||||
conn.ExecuteAsync(
|
||||
"insert into shards (id, last_heartbeat, ping) values (@Id, now(), @Ping) on conflict (id) do update set last_heartbeat = now(), ping = @Ping",
|
||||
new { Id = shard, Ping = ping.TotalSeconds });
|
||||
public Task RegisterShardHeartbeat(IPKConnection conn, int shard, Duration ping) =>
|
||||
conn.ExecuteAsync(
|
||||
"insert into shards (id, last_heartbeat, ping) values (@Id, now(), @Ping) on conflict (id) do update set last_heartbeat = now(), ping = @Ping",
|
||||
new { Id = shard, Ping = ping.TotalSeconds });
|
||||
|
||||
public Task RegisterShardConnection(IPKConnection conn, int shard) =>
|
||||
conn.ExecuteAsync(
|
||||
"insert into shards (id, last_connection) values (@Id, now()) on conflict (id) do update set last_connection = now()",
|
||||
new { Id = shard });
|
||||
}
|
||||
public Task RegisterShardConnection(IPKConnection conn, int shard) =>
|
||||
conn.ExecuteAsync(
|
||||
"insert into shards (id, last_connection) values (@Id, now()) on conflict (id) do update set last_connection = now()",
|
||||
new { Id = shard });
|
||||
}
|
@@ -1,33 +1,32 @@
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Dapper;
|
||||
|
||||
using SqlKata;
|
||||
namespace PluralKit.Core;
|
||||
|
||||
namespace PluralKit.Core
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
public async Task UpdateStats()
|
||||
{
|
||||
public async Task UpdateStats()
|
||||
{
|
||||
await _db.Execute(conn => conn.ExecuteAsync("update info set system_count = (select count(*) from systems)"));
|
||||
await _db.Execute(conn => conn.ExecuteAsync("update info set member_count = (select count(*) from members)"));
|
||||
await _db.Execute(conn => conn.ExecuteAsync("update info set group_count = (select count(*) from groups)"));
|
||||
await _db.Execute(conn => conn.ExecuteAsync("update info set switch_count = (select count(*) from switches)"));
|
||||
await _db.Execute(conn => conn.ExecuteAsync("update info set message_count = (select count(*) from messages)"));
|
||||
}
|
||||
await _db.Execute(conn =>
|
||||
conn.ExecuteAsync("update info set system_count = (select count(*) from systems)"));
|
||||
await _db.Execute(conn =>
|
||||
conn.ExecuteAsync("update info set member_count = (select count(*) from members)"));
|
||||
await _db.Execute(conn =>
|
||||
conn.ExecuteAsync("update info set group_count = (select count(*) from groups)"));
|
||||
await _db.Execute(conn =>
|
||||
conn.ExecuteAsync("update info set switch_count = (select count(*) from switches)"));
|
||||
await _db.Execute(conn =>
|
||||
conn.ExecuteAsync("update info set message_count = (select count(*) from messages)"));
|
||||
}
|
||||
|
||||
public Task<Counts> GetStats()
|
||||
=> _db.Execute(conn => conn.QuerySingleAsync<Counts>("select * from info"));
|
||||
public Task<Counts> GetStats()
|
||||
=> _db.Execute(conn => conn.QuerySingleAsync<Counts>("select * from info"));
|
||||
|
||||
public class Counts
|
||||
{
|
||||
public int SystemCount { get; }
|
||||
public int MemberCount { get; }
|
||||
public int GroupCount { get; }
|
||||
public int SwitchCount { get; }
|
||||
public int MessageCount { get; }
|
||||
}
|
||||
public class Counts
|
||||
{
|
||||
public int SystemCount { get; }
|
||||
public int MemberCount { get; }
|
||||
public int GroupCount { get; }
|
||||
public int SwitchCount { get; }
|
||||
public int MessageCount { get; }
|
||||
}
|
||||
}
|
@@ -1,8 +1,3 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Dapper;
|
||||
|
||||
using Newtonsoft.Json.Linq;
|
||||
@@ -13,314 +8,316 @@ using NpgsqlTypes;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
// todo: move the rest of the queries in here to SqlKata, if possible
|
||||
public partial class ModelRepository
|
||||
{
|
||||
// todo: move the rest of the queries in here to SqlKata, if possible
|
||||
public partial class ModelRepository
|
||||
public async Task<PKSwitch> AddSwitch(IPKConnection conn, SystemId system,
|
||||
IReadOnlyCollection<MemberId> members)
|
||||
{
|
||||
public async Task<PKSwitch> AddSwitch(IPKConnection conn, SystemId system, IReadOnlyCollection<MemberId> members)
|
||||
// Use a transaction here since we're doing multiple executed commands in one
|
||||
await using var tx = await conn.BeginTransactionAsync();
|
||||
|
||||
// First, we insert the switch itself
|
||||
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
|
||||
new { System = system });
|
||||
|
||||
// Then we insert each member in the switch in the switch_members table
|
||||
await using (var w =
|
||||
conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)"))
|
||||
{
|
||||
// Use a transaction here since we're doing multiple executed commands in one
|
||||
await using var tx = await conn.BeginTransactionAsync();
|
||||
|
||||
// First, we insert the switch itself
|
||||
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
|
||||
new { System = system });
|
||||
|
||||
// Then we insert each member in the switch in the switch_members table
|
||||
await using (var w = conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)"))
|
||||
foreach (var member in members)
|
||||
{
|
||||
foreach (var member in members)
|
||||
{
|
||||
await w.StartRowAsync();
|
||||
await w.WriteAsync(sw.Id.Value, NpgsqlDbType.Integer);
|
||||
await w.WriteAsync(member.Value, NpgsqlDbType.Integer);
|
||||
}
|
||||
|
||||
await w.CompleteAsync();
|
||||
await w.StartRowAsync();
|
||||
await w.WriteAsync(sw.Id.Value, NpgsqlDbType.Integer);
|
||||
await w.WriteAsync(member.Value, NpgsqlDbType.Integer);
|
||||
}
|
||||
|
||||
// Finally we commit the tx, since the using block will otherwise rollback it
|
||||
await tx.CommitAsync();
|
||||
|
||||
_logger.Information("Created {SwitchId} in {SystemId}: {Members}", sw.Id, system, members);
|
||||
_ = _dispatch.Dispatch(sw.Id, new()
|
||||
{
|
||||
Event = DispatchEvent.CREATE_SWITCH,
|
||||
EventData = JObject.FromObject(new
|
||||
{
|
||||
id = sw.Uuid.ToString(),
|
||||
timestamp = sw.Timestamp.FormatExport(),
|
||||
members = await GetMemberGuids(members),
|
||||
}),
|
||||
});
|
||||
return sw;
|
||||
await w.CompleteAsync();
|
||||
}
|
||||
|
||||
public async Task EditSwitch(IPKConnection conn, SwitchId switchId, IReadOnlyCollection<MemberId> members)
|
||||
// Finally we commit the tx, since the using block will otherwise rollback it
|
||||
await tx.CommitAsync();
|
||||
|
||||
_logger.Information("Created {SwitchId} in {SystemId}: {Members}", sw.Id, system, members);
|
||||
_ = _dispatch.Dispatch(sw.Id, new UpdateDispatchData
|
||||
{
|
||||
// Use a transaction here since we're doing multiple executed commands in one
|
||||
await using var tx = await conn.BeginTransactionAsync();
|
||||
|
||||
// Remove the old members from the switch
|
||||
await conn.ExecuteAsync("delete from switch_members where switch = @Switch",
|
||||
new { Switch = switchId });
|
||||
|
||||
// Add the new members
|
||||
await using (var w = conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)"))
|
||||
Event = DispatchEvent.CREATE_SWITCH,
|
||||
EventData = JObject.FromObject(new
|
||||
{
|
||||
foreach (var member in members)
|
||||
{
|
||||
await w.StartRowAsync();
|
||||
await w.WriteAsync(switchId.Value, NpgsqlDbType.Integer);
|
||||
await w.WriteAsync(member.Value, NpgsqlDbType.Integer);
|
||||
}
|
||||
id = sw.Uuid.ToString(),
|
||||
timestamp = sw.Timestamp.FormatExport(),
|
||||
members = await GetMemberGuids(members),
|
||||
}),
|
||||
});
|
||||
return sw;
|
||||
}
|
||||
|
||||
await w.CompleteAsync();
|
||||
public async Task EditSwitch(IPKConnection conn, SwitchId switchId, IReadOnlyCollection<MemberId> members)
|
||||
{
|
||||
// Use a transaction here since we're doing multiple executed commands in one
|
||||
await using var tx = await conn.BeginTransactionAsync();
|
||||
|
||||
// Remove the old members from the switch
|
||||
await conn.ExecuteAsync("delete from switch_members where switch = @Switch",
|
||||
new { Switch = switchId });
|
||||
|
||||
// Add the new members
|
||||
await using (var w =
|
||||
conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)"))
|
||||
{
|
||||
foreach (var member in members)
|
||||
{
|
||||
await w.StartRowAsync();
|
||||
await w.WriteAsync(switchId.Value, NpgsqlDbType.Integer);
|
||||
await w.WriteAsync(member.Value, NpgsqlDbType.Integer);
|
||||
}
|
||||
|
||||
// Finally we commit the tx, since the using block will otherwise rollback it
|
||||
await tx.CommitAsync();
|
||||
await w.CompleteAsync();
|
||||
}
|
||||
|
||||
_ = _dispatch.Dispatch(switchId, new()
|
||||
// Finally we commit the tx, since the using block will otherwise rollback it
|
||||
await tx.CommitAsync();
|
||||
|
||||
_ = _dispatch.Dispatch(switchId, new UpdateDispatchData
|
||||
{
|
||||
Event = DispatchEvent.UPDATE_SWITCH_MEMBERS,
|
||||
EventData = JObject.FromObject(new
|
||||
{
|
||||
Event = DispatchEvent.UPDATE_SWITCH_MEMBERS,
|
||||
EventData = JObject.FromObject(new
|
||||
{
|
||||
members = await GetMemberGuids(members),
|
||||
}),
|
||||
});
|
||||
members = await GetMemberGuids(members),
|
||||
}),
|
||||
});
|
||||
|
||||
_logger.Information("Updated {SwitchId} members: {Members}", switchId, members);
|
||||
}
|
||||
_logger.Information("Updated {SwitchId} members: {Members}", switchId, members);
|
||||
}
|
||||
|
||||
public async Task MoveSwitch(SwitchId id, Instant time)
|
||||
public async Task MoveSwitch(SwitchId id, Instant time)
|
||||
{
|
||||
_logger.Information("Updated {SwitchId} timestamp: {SwitchTimestamp}", id, time);
|
||||
var query = new Query("switches").AsUpdate(new { timestamp = time }).Where("id", id);
|
||||
await _db.ExecuteQuery(query);
|
||||
_ = _dispatch.Dispatch(id, new UpdateDispatchData
|
||||
{
|
||||
_logger.Information("Updated {SwitchId} timestamp: {SwitchTimestamp}", id, time);
|
||||
var query = new Query("switches").AsUpdate(new { timestamp = time }).Where("id", id);
|
||||
await _db.ExecuteQuery(query);
|
||||
_ = _dispatch.Dispatch(id, new()
|
||||
Event = DispatchEvent.UPDATE_SWITCH,
|
||||
EventData = JObject.FromObject(new
|
||||
{
|
||||
Event = DispatchEvent.UPDATE_SWITCH,
|
||||
EventData = JObject.FromObject(new
|
||||
{
|
||||
timestamp = time.FormatExport(),
|
||||
}),
|
||||
});
|
||||
}
|
||||
timestamp = time.FormatExport(),
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
public async Task DeleteSwitch(SwitchId id)
|
||||
public async Task DeleteSwitch(SwitchId id)
|
||||
{
|
||||
var existingSwitch = await GetSwitch(id);
|
||||
|
||||
var query = new Query("switches").AsDelete().Where("id", id);
|
||||
await _db.ExecuteQuery(query);
|
||||
_logger.Information("Deleted {Switch}", id);
|
||||
_ = _dispatch.Dispatch(existingSwitch.System, existingSwitch.Uuid, DispatchEvent.DELETE_SWITCH);
|
||||
}
|
||||
|
||||
public async Task DeleteAllSwitches(SystemId system)
|
||||
{
|
||||
_logger.Information("Deleted all switches in {SystemId}", system);
|
||||
var query = new Query("switches").AsDelete().Where("system", system);
|
||||
await _db.ExecuteQuery(query);
|
||||
_ = _dispatch.Dispatch(system, new UpdateDispatchData
|
||||
{
|
||||
var existingSwitch = await GetSwitch(id);
|
||||
Event = DispatchEvent.DELETE_ALL_SWITCHES
|
||||
});
|
||||
}
|
||||
|
||||
var query = new Query("switches").AsDelete().Where("id", id);
|
||||
await _db.ExecuteQuery(query);
|
||||
_logger.Information("Deleted {Switch}", id);
|
||||
_ = _dispatch.Dispatch(existingSwitch.System, existingSwitch.Uuid, DispatchEvent.DELETE_SWITCH);
|
||||
}
|
||||
public IAsyncEnumerable<PKSwitch> GetSwitches(SystemId system)
|
||||
{
|
||||
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
|
||||
var query = new Query("switches").Where("system", system).OrderByDesc("timestamp");
|
||||
return _db.QueryStream<PKSwitch>(query);
|
||||
}
|
||||
|
||||
public async Task DeleteAllSwitches(SystemId system)
|
||||
{
|
||||
_logger.Information("Deleted all switches in {SystemId}", system);
|
||||
var query = new Query("switches").AsDelete().Where("system", system);
|
||||
await _db.ExecuteQuery(query);
|
||||
_ = _dispatch.Dispatch(system, new UpdateDispatchData()
|
||||
{
|
||||
Event = DispatchEvent.DELETE_ALL_SWITCHES
|
||||
});
|
||||
}
|
||||
public Task<PKSwitch?> GetSwitch(SwitchId id)
|
||||
=> _db.QueryFirst<PKSwitch?>(new Query("switches").Where("id", id));
|
||||
|
||||
public IAsyncEnumerable<PKSwitch> GetSwitches(SystemId system)
|
||||
{
|
||||
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
|
||||
var query = new Query("switches").Where("system", system).OrderByDesc("timestamp");
|
||||
return _db.QueryStream<PKSwitch>(query);
|
||||
}
|
||||
public Task<PKSwitch> GetSwitchByUuid(Guid uuid)
|
||||
{
|
||||
var query = new Query("switches").Where("uuid", uuid);
|
||||
return _db.QueryFirst<PKSwitch>(query);
|
||||
}
|
||||
|
||||
public Task<PKSwitch?> GetSwitch(SwitchId id)
|
||||
=> _db.QueryFirst<PKSwitch?>(new Query("switches").Where("id", id));
|
||||
public Task<int> GetSwitchCount(SystemId system)
|
||||
{
|
||||
var query = new Query("switches").SelectRaw("count(*)").Where("system", system);
|
||||
return _db.QueryFirst<int>(query);
|
||||
}
|
||||
|
||||
public Task<PKSwitch> GetSwitchByUuid(Guid uuid)
|
||||
{
|
||||
var query = new Query("switches").Where("uuid", uuid);
|
||||
return _db.QueryFirst<PKSwitch>(query);
|
||||
}
|
||||
public async IAsyncEnumerable<SwitchMembersListEntry> GetSwitchMembersList(IPKConnection conn,
|
||||
SystemId system, Instant start,
|
||||
Instant end)
|
||||
{
|
||||
// Wrap multiple commands in a single transaction for performance
|
||||
await using var tx = await conn.BeginTransactionAsync();
|
||||
|
||||
public Task<int> GetSwitchCount(SystemId system)
|
||||
{
|
||||
var query = new Query("switches").SelectRaw("count(*)").Where("system", system);
|
||||
return _db.QueryFirst<int>(query);
|
||||
}
|
||||
// Find the time of the last switch outside the range as it overlaps the range
|
||||
// If no prior switch exists, the lower bound of the range remains the start time
|
||||
var lastSwitch = await conn.QuerySingleOrDefaultAsync<Instant>(
|
||||
@"SELECT COALESCE(MAX(timestamp), @Start)
|
||||
FROM switches
|
||||
WHERE switches.system = @System
|
||||
AND switches.timestamp < @Start",
|
||||
new { System = system, Start = start });
|
||||
|
||||
public async IAsyncEnumerable<SwitchMembersListEntry> GetSwitchMembersList(IPKConnection conn,
|
||||
SystemId system, Instant start, Instant end)
|
||||
{
|
||||
// Wrap multiple commands in a single transaction for performance
|
||||
await using var tx = await conn.BeginTransactionAsync();
|
||||
// Then collect the time and members of all switches that overlap the range
|
||||
var switchMembersEntries = conn.QueryStreamAsync<SwitchMembersListEntry>(
|
||||
@"SELECT switch_members.member, switches.timestamp
|
||||
FROM switches
|
||||
LEFT JOIN switch_members
|
||||
ON switches.id = switch_members.switch
|
||||
WHERE switches.system = @System
|
||||
AND (
|
||||
switches.timestamp >= @Start
|
||||
OR switches.timestamp = @LastSwitch
|
||||
)
|
||||
AND switches.timestamp < @End
|
||||
ORDER BY switches.timestamp DESC",
|
||||
new { System = system, Start = start, End = end, LastSwitch = lastSwitch });
|
||||
|
||||
// Find the time of the last switch outside the range as it overlaps the range
|
||||
// If no prior switch exists, the lower bound of the range remains the start time
|
||||
var lastSwitch = await conn.QuerySingleOrDefaultAsync<Instant>(
|
||||
@"SELECT COALESCE(MAX(timestamp), @Start)
|
||||
FROM switches
|
||||
WHERE switches.system = @System
|
||||
AND switches.timestamp < @Start",
|
||||
new { System = system, Start = start });
|
||||
// Yield each value here
|
||||
await foreach (var entry in switchMembersEntries)
|
||||
yield return entry;
|
||||
|
||||
// Then collect the time and members of all switches that overlap the range
|
||||
var switchMembersEntries = conn.QueryStreamAsync<SwitchMembersListEntry>(
|
||||
@"SELECT switch_members.member, switches.timestamp
|
||||
FROM switches
|
||||
LEFT JOIN switch_members
|
||||
ON switches.id = switch_members.switch
|
||||
WHERE switches.system = @System
|
||||
AND (
|
||||
switches.timestamp >= @Start
|
||||
OR switches.timestamp = @LastSwitch
|
||||
)
|
||||
AND switches.timestamp < @End
|
||||
ORDER BY switches.timestamp DESC",
|
||||
new { System = system, Start = start, End = end, LastSwitch = lastSwitch });
|
||||
// Don't really need to worry about the transaction here, we're not doing any *writes*
|
||||
}
|
||||
|
||||
// Yield each value here
|
||||
await foreach (var entry in switchMembersEntries)
|
||||
yield return entry;
|
||||
public IAsyncEnumerable<PKMember> GetSwitchMembers(IPKConnection conn, SwitchId sw) =>
|
||||
conn.QueryStreamAsync<PKMember>(
|
||||
"select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id",
|
||||
new { Switch = sw });
|
||||
|
||||
// Don't really need to worry about the transaction here, we're not doing any *writes*
|
||||
}
|
||||
public Task<PKSwitch> GetLatestSwitch(SystemId system)
|
||||
{
|
||||
var query = new Query("switches").Where("system", system).OrderByDesc("timestamp").Limit(1);
|
||||
return _db.QueryFirst<PKSwitch>(query);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<PKMember> GetSwitchMembers(IPKConnection conn, SwitchId sw)
|
||||
{
|
||||
return conn.QueryStreamAsync<PKMember>(
|
||||
"select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id",
|
||||
new { Switch = sw });
|
||||
}
|
||||
public async Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(IPKConnection conn,
|
||||
SystemId system, GroupId? group, Instant periodStart, Instant periodEnd)
|
||||
{
|
||||
// TODO: IAsyncEnumerable-ify this one
|
||||
// TODO: this doesn't belong in the repo
|
||||
|
||||
public Task<PKSwitch> GetLatestSwitch(SystemId system)
|
||||
{
|
||||
var query = new Query("switches").Where("system", system).OrderByDesc("timestamp").Limit(1);
|
||||
return _db.QueryFirst<PKSwitch>(query);
|
||||
}
|
||||
// Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order
|
||||
var switchMembers = await GetSwitchMembersList(conn, system, periodStart, periodEnd).ToListAsync();
|
||||
|
||||
public async Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(IPKConnection conn,
|
||||
SystemId system, GroupId? group, Instant periodStart,
|
||||
Instant periodEnd)
|
||||
{
|
||||
// TODO: IAsyncEnumerable-ify this one
|
||||
// TODO: this doesn't belong in the repo
|
||||
// query DB for all members involved in any of the switches above and collect into a dictionary for future use
|
||||
// this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary
|
||||
// key used in GetPerMemberSwitchDuration below
|
||||
var membersList = await conn.QueryAsync<PKMember>(
|
||||
"select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax
|
||||
new { Switches = switchMembers.Select(m => m.Member.Value).Distinct().ToList() });
|
||||
var memberObjects = membersList.ToDictionary(m => m.Id);
|
||||
|
||||
// Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order
|
||||
var switchMembers = await GetSwitchMembersList(conn, system, periodStart, periodEnd).ToListAsync();
|
||||
|
||||
// query DB for all members involved in any of the switches above and collect into a dictionary for future use
|
||||
// this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary
|
||||
// key used in GetPerMemberSwitchDuration below
|
||||
var membersList = await conn.QueryAsync<PKMember>(
|
||||
"select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax
|
||||
new { Switches = switchMembers.Select(m => m.Member.Value).Distinct().ToList() });
|
||||
var memberObjects = membersList.ToDictionary(m => m.Id);
|
||||
|
||||
// check if a group ID is provided. if so, query DB for all members of said group, otherwise use membersList
|
||||
var groupMembersList = group != null ? await conn.QueryAsync<PKMember>(
|
||||
// check if a group ID is provided. if so, query DB for all members of said group, otherwise use membersList
|
||||
var groupMembersList = group != null
|
||||
? await conn.QueryAsync<PKMember>(
|
||||
"select * from members inner join group_members on members.id = group_members.member_id where group_id = @id",
|
||||
new { id = group }) : membersList;
|
||||
var groupMemberObjects = groupMembersList.ToDictionary(m => m.Id);
|
||||
new { id = group })
|
||||
: membersList;
|
||||
var groupMemberObjects = groupMembersList.ToDictionary(m => m.Id);
|
||||
|
||||
// Initialize entries - still need to loop to determine the TimespanEnd below
|
||||
// use groupMemberObjects to make sure no members outside of the specified group (if present) are selected
|
||||
var entries =
|
||||
from item in switchMembers
|
||||
group item by item.Timestamp
|
||||
into g
|
||||
select new SwitchListEntry
|
||||
{
|
||||
TimespanStart = g.Key,
|
||||
Members = g.Where(x => x.Member != default(MemberId) && groupMemberObjects.Any(m => x.Member == m.Key)).Select(x => memberObjects[x.Member])
|
||||
.ToList()
|
||||
};
|
||||
|
||||
// Loop through every switch that overlaps the range and add it to the output list
|
||||
// end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared
|
||||
var endTime = periodEnd;
|
||||
var outList = new List<SwitchListEntry>();
|
||||
foreach (var e in entries)
|
||||
// Initialize entries - still need to loop to determine the TimespanEnd below
|
||||
// use groupMemberObjects to make sure no members outside of the specified group (if present) are selected
|
||||
var entries =
|
||||
from item in switchMembers
|
||||
group item by item.Timestamp
|
||||
into g
|
||||
select new SwitchListEntry
|
||||
{
|
||||
// Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above)
|
||||
var switchStartClamped = e.TimespanStart < periodStart
|
||||
? periodStart
|
||||
: e.TimespanStart;
|
||||
|
||||
outList.Add(new SwitchListEntry
|
||||
{
|
||||
Members = e.Members,
|
||||
TimespanStart = switchStartClamped,
|
||||
TimespanEnd = endTime
|
||||
});
|
||||
|
||||
// next switch's end is this switch's start (we're working backward in time)
|
||||
endTime = e.TimespanStart;
|
||||
}
|
||||
|
||||
return outList;
|
||||
}
|
||||
|
||||
public async Task<FrontBreakdown> GetFrontBreakdown(IPKConnection conn, SystemId system, GroupId? group, Instant periodStart,
|
||||
Instant periodEnd)
|
||||
{
|
||||
// TODO: this doesn't belong in the repo
|
||||
var dict = new Dictionary<PKMember, Duration>();
|
||||
|
||||
var noFronterDuration = Duration.Zero;
|
||||
|
||||
// Sum up all switch durations for each member
|
||||
// switches with multiple members will result in the duration to add up to more than the actual period range
|
||||
|
||||
var actualStart = periodEnd; // will be "pulled" down
|
||||
var actualEnd = periodStart; // will be "pulled" up
|
||||
|
||||
foreach (var sw in await GetPeriodFronters(conn, system, group, periodStart, periodEnd))
|
||||
{
|
||||
var span = sw.TimespanEnd - sw.TimespanStart;
|
||||
foreach (var member in sw.Members)
|
||||
{
|
||||
if (!dict.ContainsKey(member)) dict.Add(member, span);
|
||||
else dict[member] += span;
|
||||
}
|
||||
|
||||
if (sw.Members.Count == 0) noFronterDuration += span;
|
||||
|
||||
if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart;
|
||||
if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd;
|
||||
}
|
||||
|
||||
return new FrontBreakdown
|
||||
{
|
||||
MemberSwitchDurations = dict,
|
||||
NoFronterDuration = noFronterDuration,
|
||||
RangeStart = actualStart,
|
||||
RangeEnd = actualEnd
|
||||
TimespanStart = g.Key,
|
||||
Members = g.Where(x => x.Member != default && groupMemberObjects.Any(m => x.Member == m.Key))
|
||||
.Select(x => memberObjects[x.Member])
|
||||
.ToList()
|
||||
};
|
||||
|
||||
// Loop through every switch that overlaps the range and add it to the output list
|
||||
// end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared
|
||||
var endTime = periodEnd;
|
||||
var outList = new List<SwitchListEntry>();
|
||||
foreach (var e in entries)
|
||||
{
|
||||
// Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above)
|
||||
var switchStartClamped = e.TimespanStart < periodStart
|
||||
? periodStart
|
||||
: e.TimespanStart;
|
||||
|
||||
outList.Add(new SwitchListEntry
|
||||
{
|
||||
Members = e.Members,
|
||||
TimespanStart = switchStartClamped,
|
||||
TimespanEnd = endTime
|
||||
});
|
||||
|
||||
// next switch's end is this switch's start (we're working backward in time)
|
||||
endTime = e.TimespanStart;
|
||||
}
|
||||
|
||||
return outList;
|
||||
}
|
||||
|
||||
public struct SwitchListEntry
|
||||
public async Task<FrontBreakdown> GetFrontBreakdown(IPKConnection conn, SystemId system, GroupId? group,
|
||||
Instant periodStart,
|
||||
Instant periodEnd)
|
||||
{
|
||||
public ICollection<PKMember> Members;
|
||||
public Instant TimespanStart;
|
||||
public Instant TimespanEnd;
|
||||
}
|
||||
// TODO: this doesn't belong in the repo
|
||||
var dict = new Dictionary<PKMember, Duration>();
|
||||
|
||||
public struct FrontBreakdown
|
||||
{
|
||||
public Dictionary<PKMember, Duration> MemberSwitchDurations;
|
||||
public Duration NoFronterDuration;
|
||||
public Instant RangeStart;
|
||||
public Instant RangeEnd;
|
||||
}
|
||||
var noFronterDuration = Duration.Zero;
|
||||
|
||||
public struct SwitchMembersListEntry
|
||||
{
|
||||
public MemberId Member;
|
||||
public Instant Timestamp;
|
||||
// Sum up all switch durations for each member
|
||||
// switches with multiple members will result in the duration to add up to more than the actual period range
|
||||
|
||||
var actualStart = periodEnd; // will be "pulled" down
|
||||
var actualEnd = periodStart; // will be "pulled" up
|
||||
|
||||
foreach (var sw in await GetPeriodFronters(conn, system, group, periodStart, periodEnd))
|
||||
{
|
||||
var span = sw.TimespanEnd - sw.TimespanStart;
|
||||
foreach (var member in sw.Members)
|
||||
if (!dict.ContainsKey(member)) dict.Add(member, span);
|
||||
else dict[member] += span;
|
||||
|
||||
if (sw.Members.Count == 0) noFronterDuration += span;
|
||||
|
||||
if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart;
|
||||
if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd;
|
||||
}
|
||||
|
||||
return new FrontBreakdown
|
||||
{
|
||||
MemberSwitchDurations = dict,
|
||||
NoFronterDuration = noFronterDuration,
|
||||
RangeStart = actualStart,
|
||||
RangeEnd = actualEnd
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
public struct SwitchListEntry
|
||||
{
|
||||
public ICollection<PKMember> Members;
|
||||
public Instant TimespanStart;
|
||||
public Instant TimespanEnd;
|
||||
}
|
||||
|
||||
public struct FrontBreakdown
|
||||
{
|
||||
public Dictionary<PKMember, Duration> MemberSwitchDurations;
|
||||
public Duration NoFronterDuration;
|
||||
public Instant RangeStart;
|
||||
public Instant RangeEnd;
|
||||
}
|
||||
|
||||
public struct SwitchMembersListEntry
|
||||
{
|
||||
public MemberId Member;
|
||||
public Instant Timestamp;
|
||||
}
|
@@ -1,141 +1,135 @@
|
||||
#nullable enable
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
public Task<PKSystem?> GetSystem(SystemId id)
|
||||
{
|
||||
public Task<PKSystem?> GetSystem(SystemId id)
|
||||
var query = new Query("systems").Where("id", id);
|
||||
return _db.QueryFirst<PKSystem?>(query);
|
||||
}
|
||||
|
||||
public Task<PKSystem?> GetSystemByGuid(Guid id)
|
||||
{
|
||||
var query = new Query("systems").Where("uuid", id);
|
||||
return _db.QueryFirst<PKSystem?>(query);
|
||||
}
|
||||
|
||||
public Task<PKSystem?> GetSystemByAccount(ulong accountId)
|
||||
{
|
||||
var query = new Query("accounts")
|
||||
.Select("systems.*")
|
||||
.LeftJoin("systems", "systems.id", "accounts.system")
|
||||
.Where("uid", accountId);
|
||||
return _db.QueryFirst<PKSystem?>(query);
|
||||
}
|
||||
|
||||
public Task<PKSystem?> GetSystemByHid(string hid)
|
||||
{
|
||||
var query = new Query("systems").Where("hid", hid.ToLower());
|
||||
return _db.QueryFirst<PKSystem?>(query);
|
||||
}
|
||||
|
||||
public Task<IEnumerable<ulong>> GetSystemAccounts(SystemId system)
|
||||
{
|
||||
var query = new Query("accounts").Select("uid").Where("system", system);
|
||||
return _db.Query<ulong>(query);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<PKMember> GetSystemMembers(SystemId system)
|
||||
{
|
||||
var query = new Query("members").Where("system", system);
|
||||
return _db.QueryStream<PKMember>(query);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<PKGroup> GetSystemGroups(SystemId system)
|
||||
{
|
||||
var query = new Query("groups").Where("system", system);
|
||||
return _db.QueryStream<PKGroup>(query);
|
||||
}
|
||||
|
||||
public Task<int> GetSystemMemberCount(SystemId system, PrivacyLevel? privacyFilter = null)
|
||||
{
|
||||
var query = new Query("members").SelectRaw("count(*)").Where("system", system);
|
||||
if (privacyFilter != null)
|
||||
query.Where("member_visibility", (int)privacyFilter.Value);
|
||||
|
||||
return _db.QueryFirst<int>(query);
|
||||
}
|
||||
|
||||
public Task<int> GetSystemGroupCount(SystemId system, PrivacyLevel? privacyFilter = null)
|
||||
{
|
||||
var query = new Query("groups").SelectRaw("count(*)").Where("system", system);
|
||||
if (privacyFilter != null)
|
||||
query.Where("visibility", (int)privacyFilter.Value);
|
||||
|
||||
return _db.QueryFirst<int>(query);
|
||||
}
|
||||
|
||||
public async Task<PKSystem> CreateSystem(string? systemName = null, IPKConnection? conn = null)
|
||||
{
|
||||
var query = new Query("systems").AsInsert(new
|
||||
{
|
||||
var query = new Query("systems").Where("id", id);
|
||||
return _db.QueryFirst<PKSystem?>(query);
|
||||
}
|
||||
hid = new UnsafeLiteral("find_free_system_hid()"),
|
||||
name = systemName
|
||||
});
|
||||
var system = await _db.QueryFirst<PKSystem>(conn, query, "returning *");
|
||||
_logger.Information("Created {SystemId}", system.Id);
|
||||
|
||||
public Task<PKSystem?> GetSystemByGuid(Guid id)
|
||||
// no dispatch call here - system was just created, we don't have a webhook URL
|
||||
return system;
|
||||
}
|
||||
|
||||
public async Task<PKSystem> UpdateSystem(SystemId id, SystemPatch patch, IPKConnection? conn = null)
|
||||
{
|
||||
_logger.Information("Updated {SystemId}: {@SystemPatch}", id, patch);
|
||||
var query = patch.Apply(new Query("systems").Where("id", id));
|
||||
var res = await _db.QueryFirst<PKSystem>(conn, query, "returning *");
|
||||
|
||||
_ = _dispatch.Dispatch(id, new UpdateDispatchData
|
||||
{
|
||||
var query = new Query("systems").Where("uuid", id);
|
||||
return _db.QueryFirst<PKSystem?>(query);
|
||||
}
|
||||
Event = DispatchEvent.UPDATE_SYSTEM,
|
||||
EventData = patch.ToJson(),
|
||||
});
|
||||
|
||||
public Task<PKSystem?> GetSystemByAccount(ulong accountId)
|
||||
return res;
|
||||
}
|
||||
|
||||
public async Task AddAccount(SystemId system, ulong accountId, IPKConnection? conn = null)
|
||||
{
|
||||
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
|
||||
// This is used in import/export, although the pk;link command checks for this case beforehand
|
||||
|
||||
var query = new Query("accounts").AsInsert(new { system, uid = accountId });
|
||||
|
||||
_logger.Information("Linked account {UserId} to {SystemId}", accountId, system);
|
||||
await _db.ExecuteQuery(conn, query, "on conflict do nothing");
|
||||
|
||||
_ = _dispatch.Dispatch(system, new UpdateDispatchData
|
||||
{
|
||||
var query = new Query("accounts").Select("systems.*").LeftJoin("systems", "systems.id", "accounts.system", "=").Where("uid", accountId);
|
||||
return _db.QueryFirst<PKSystem?>(query);
|
||||
}
|
||||
Event = DispatchEvent.LINK_ACCOUNT,
|
||||
EntityId = accountId.ToString(),
|
||||
});
|
||||
}
|
||||
|
||||
public Task<PKSystem?> GetSystemByHid(string hid)
|
||||
public async Task RemoveAccount(SystemId system, ulong accountId)
|
||||
{
|
||||
var query = new Query("accounts").AsDelete().Where("uid", accountId).Where("system", system);
|
||||
await _db.ExecuteQuery(query);
|
||||
_logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system);
|
||||
_ = _dispatch.Dispatch(system, new UpdateDispatchData
|
||||
{
|
||||
var query = new Query("systems").Where("hid", hid.ToLower());
|
||||
return _db.QueryFirst<PKSystem?>(query);
|
||||
}
|
||||
Event = DispatchEvent.UNLINK_ACCOUNT,
|
||||
EntityId = accountId.ToString(),
|
||||
});
|
||||
}
|
||||
|
||||
public Task<IEnumerable<ulong>> GetSystemAccounts(SystemId system)
|
||||
{
|
||||
var query = new Query("accounts").Select("uid").Where("system", system);
|
||||
return _db.Query<ulong>(query);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<PKMember> GetSystemMembers(SystemId system)
|
||||
{
|
||||
var query = new Query("members").Where("system", system);
|
||||
return _db.QueryStream<PKMember>(query);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<PKGroup> GetSystemGroups(SystemId system)
|
||||
{
|
||||
var query = new Query("groups").Where("system", system);
|
||||
return _db.QueryStream<PKGroup>(query);
|
||||
}
|
||||
|
||||
public Task<int> GetSystemMemberCount(SystemId system, PrivacyLevel? privacyFilter = null)
|
||||
{
|
||||
var query = new Query("members").SelectRaw("count(*)").Where("system", system);
|
||||
if (privacyFilter != null)
|
||||
query.Where("member_visibility", (int)privacyFilter.Value);
|
||||
|
||||
return _db.QueryFirst<int>(query);
|
||||
}
|
||||
|
||||
public Task<int> GetSystemGroupCount(SystemId system, PrivacyLevel? privacyFilter = null)
|
||||
{
|
||||
var query = new Query("groups").SelectRaw("count(*)").Where("system", system);
|
||||
if (privacyFilter != null)
|
||||
query.Where("visibility", (int)privacyFilter.Value);
|
||||
|
||||
return _db.QueryFirst<int>(query);
|
||||
}
|
||||
|
||||
public async Task<PKSystem> CreateSystem(string? systemName = null, IPKConnection? conn = null)
|
||||
{
|
||||
var query = new Query("systems").AsInsert(new
|
||||
{
|
||||
hid = new UnsafeLiteral("find_free_system_hid()"),
|
||||
name = systemName
|
||||
});
|
||||
var system = await _db.QueryFirst<PKSystem>(conn, query, extraSql: "returning *");
|
||||
_logger.Information("Created {SystemId}", system.Id);
|
||||
|
||||
// no dispatch call here - system was just created, we don't have a webhook URL
|
||||
return system;
|
||||
}
|
||||
|
||||
public async Task<PKSystem> UpdateSystem(SystemId id, SystemPatch patch, IPKConnection? conn = null)
|
||||
{
|
||||
_logger.Information("Updated {SystemId}: {@SystemPatch}", id, patch);
|
||||
var query = patch.Apply(new Query("systems").Where("id", id));
|
||||
var res = await _db.QueryFirst<PKSystem>(conn, query, extraSql: "returning *");
|
||||
|
||||
_ = _dispatch.Dispatch(id, new UpdateDispatchData()
|
||||
{
|
||||
Event = DispatchEvent.UPDATE_SYSTEM,
|
||||
EventData = patch.ToJson(),
|
||||
});
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
public async Task AddAccount(SystemId system, ulong accountId, IPKConnection? conn = null)
|
||||
{
|
||||
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
|
||||
// This is used in import/export, although the pk;link command checks for this case beforehand
|
||||
|
||||
var query = new Query("accounts").AsInsert(new
|
||||
{
|
||||
system = system,
|
||||
uid = accountId,
|
||||
});
|
||||
|
||||
_logger.Information("Linked account {UserId} to {SystemId}", accountId, system);
|
||||
await _db.ExecuteQuery(conn, query, extraSql: "on conflict do nothing");
|
||||
|
||||
_ = _dispatch.Dispatch(system, new UpdateDispatchData()
|
||||
{
|
||||
Event = DispatchEvent.LINK_ACCOUNT,
|
||||
EntityId = accountId.ToString(),
|
||||
});
|
||||
}
|
||||
|
||||
public async Task RemoveAccount(SystemId system, ulong accountId)
|
||||
{
|
||||
var query = new Query("accounts").AsDelete().Where("uid", accountId).Where("system", system);
|
||||
await _db.ExecuteQuery(query);
|
||||
_logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system);
|
||||
_ = _dispatch.Dispatch(system, new UpdateDispatchData()
|
||||
{
|
||||
Event = DispatchEvent.UNLINK_ACCOUNT,
|
||||
EntityId = accountId.ToString(),
|
||||
});
|
||||
}
|
||||
|
||||
public Task DeleteSystem(SystemId id)
|
||||
{
|
||||
var query = new Query("systems").AsDelete().Where("id", id);
|
||||
_logger.Information("Deleted {SystemId}", id);
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
public Task DeleteSystem(SystemId id)
|
||||
{
|
||||
var query = new Query("systems").AsDelete().Where("id", id);
|
||||
_logger.Information("Deleted {SystemId}", id);
|
||||
return _db.ExecuteQuery(query);
|
||||
}
|
||||
}
|
@@ -1,17 +1,17 @@
|
||||
using Serilog;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public partial class ModelRepository
|
||||
private readonly IDatabase _db;
|
||||
private readonly DispatchService _dispatch;
|
||||
private readonly ILogger _logger;
|
||||
|
||||
public ModelRepository(ILogger logger, IDatabase db, DispatchService dispatch)
|
||||
{
|
||||
private readonly ILogger _logger;
|
||||
private readonly IDatabase _db;
|
||||
private readonly DispatchService _dispatch;
|
||||
public ModelRepository(ILogger logger, IDatabase db, DispatchService dispatch)
|
||||
{
|
||||
_logger = logger.ForContext<ModelRepository>();
|
||||
_db = db;
|
||||
_dispatch = dispatch;
|
||||
}
|
||||
_logger = logger.ForContext<ModelRepository>();
|
||||
_db = db;
|
||||
_dispatch = dispatch;
|
||||
}
|
||||
}
|
@@ -1,19 +1,17 @@
|
||||
using System.Collections.Generic;
|
||||
using System.Data.Common;
|
||||
|
||||
using Dapper;
|
||||
|
||||
namespace PluralKit.Core
|
||||
{
|
||||
public static class ConnectionUtils
|
||||
{
|
||||
public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this IPKConnection conn, string sql, object param)
|
||||
{
|
||||
await using var reader = (DbDataReader)await conn.ExecuteReaderAsync(sql, param);
|
||||
var parser = reader.GetRowParser<T>();
|
||||
namespace PluralKit.Core;
|
||||
|
||||
while (await reader.ReadAsync())
|
||||
yield return parser(reader);
|
||||
}
|
||||
public static class ConnectionUtils
|
||||
{
|
||||
public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this IPKConnection conn, string sql, object param)
|
||||
{
|
||||
await using var reader = (DbDataReader)await conn.ExecuteReaderAsync(sql, param);
|
||||
var parser = reader.GetRowParser<T>();
|
||||
|
||||
while (await reader.ReadAsync())
|
||||
yield return parser(reader);
|
||||
}
|
||||
}
|
@@ -1,85 +1,81 @@
|
||||
using System;
|
||||
using System.Data;
|
||||
using System.IO;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Dapper;
|
||||
|
||||
using Serilog;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
internal class DatabaseMigrator
|
||||
{
|
||||
internal class DatabaseMigrator
|
||||
private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files
|
||||
private const int TargetSchemaVersion = 20;
|
||||
private readonly ILogger _logger;
|
||||
|
||||
public DatabaseMigrator(ILogger logger)
|
||||
{
|
||||
private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files
|
||||
private const int TargetSchemaVersion = 20;
|
||||
private readonly ILogger _logger;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public DatabaseMigrator(ILogger logger)
|
||||
public async Task ApplyMigrations(IPKConnection conn)
|
||||
{
|
||||
// Run everything in a transaction
|
||||
await using var tx = await conn.BeginTransactionAsync();
|
||||
|
||||
// Before applying migrations, clean out views/functions to prevent type errors
|
||||
await ExecuteSqlFile($"{RootPath}.clean.sql", conn, tx);
|
||||
|
||||
// Apply all migrations between the current database version and the target version
|
||||
await ApplyMigrations(conn, tx);
|
||||
|
||||
// Now, reapply views/functions (we deleted them above, no need to worry about conflicts)
|
||||
await ExecuteSqlFile($"{RootPath}.Views.views.sql", conn, tx);
|
||||
await ExecuteSqlFile($"{RootPath}.Functions.functions.sql", conn, tx);
|
||||
|
||||
// Finally, commit tx
|
||||
await tx.CommitAsync();
|
||||
}
|
||||
|
||||
private async Task ApplyMigrations(IPKConnection conn, IDbTransaction tx)
|
||||
{
|
||||
var currentVersion = await GetCurrentDatabaseVersion(conn);
|
||||
_logger.Information("Current schema version: {CurrentVersion}", currentVersion);
|
||||
for (var migration = currentVersion + 1; migration <= TargetSchemaVersion; migration++)
|
||||
{
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public async Task ApplyMigrations(IPKConnection conn)
|
||||
{
|
||||
// Run everything in a transaction
|
||||
await using var tx = await conn.BeginTransactionAsync();
|
||||
|
||||
// Before applying migrations, clean out views/functions to prevent type errors
|
||||
await ExecuteSqlFile($"{RootPath}.clean.sql", conn, tx);
|
||||
|
||||
// Apply all migrations between the current database version and the target version
|
||||
await ApplyMigrations(conn, tx);
|
||||
|
||||
// Now, reapply views/functions (we deleted them above, no need to worry about conflicts)
|
||||
await ExecuteSqlFile($"{RootPath}.Views.views.sql", conn, tx);
|
||||
await ExecuteSqlFile($"{RootPath}.Functions.functions.sql", conn, tx);
|
||||
|
||||
// Finally, commit tx
|
||||
await tx.CommitAsync();
|
||||
}
|
||||
|
||||
private async Task ApplyMigrations(IPKConnection conn, IDbTransaction tx)
|
||||
{
|
||||
var currentVersion = await GetCurrentDatabaseVersion(conn);
|
||||
_logger.Information("Current schema version: {CurrentVersion}", currentVersion);
|
||||
for (var migration = currentVersion + 1; migration <= TargetSchemaVersion; migration++)
|
||||
{
|
||||
_logger.Information("Applying schema migration {MigrationId}", migration);
|
||||
await ExecuteSqlFile($"{RootPath}.Migrations.{migration}.sql", conn, tx);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ExecuteSqlFile(string resourceName, IPKConnection conn, IDbTransaction tx = null)
|
||||
{
|
||||
await using var stream = typeof(Database).Assembly.GetManifestResourceStream(resourceName);
|
||||
if (stream == null) throw new ArgumentException($"Invalid resource name '{resourceName}'");
|
||||
|
||||
using var reader = new StreamReader(stream);
|
||||
var query = await reader.ReadToEndAsync();
|
||||
|
||||
await conn.ExecuteAsync(query, transaction: tx);
|
||||
|
||||
// If the above creates new enum/composite types, we must tell Npgsql to reload the internal type caches
|
||||
// This will propagate to every other connection as well, since it marks the global type mapper collection dirty.
|
||||
((PKConnection)conn).ReloadTypes();
|
||||
}
|
||||
|
||||
private async Task<int> GetCurrentDatabaseVersion(IPKConnection conn)
|
||||
{
|
||||
// First, check if the "info" table exists (it may not, if this is a *really* old database)
|
||||
var hasInfoTable =
|
||||
await conn.QuerySingleOrDefaultAsync<int>(
|
||||
"select count(*) from information_schema.tables where table_name = 'info'") == 1;
|
||||
|
||||
// If we have the table, read the schema version
|
||||
if (hasInfoTable)
|
||||
return await conn.QuerySingleOrDefaultAsync<int>("select schema_version from info");
|
||||
|
||||
// If not, we return version "-1"
|
||||
// This means migration 0 will get executed, getting us into a consistent state
|
||||
// Then, migration 1 gets executed, which creates the info table and sets version to 1
|
||||
return -1;
|
||||
_logger.Information("Applying schema migration {MigrationId}", migration);
|
||||
await ExecuteSqlFile($"{RootPath}.Migrations.{migration}.sql", conn, tx);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ExecuteSqlFile(string resourceName, IPKConnection conn, IDbTransaction tx = null)
|
||||
{
|
||||
await using var stream = typeof(Database).Assembly.GetManifestResourceStream(resourceName);
|
||||
if (stream == null) throw new ArgumentException($"Invalid resource name '{resourceName}'");
|
||||
|
||||
using var reader = new StreamReader(stream);
|
||||
var query = await reader.ReadToEndAsync();
|
||||
|
||||
await conn.ExecuteAsync(query, transaction: tx);
|
||||
|
||||
// If the above creates new enum/composite types, we must tell Npgsql to reload the internal type caches
|
||||
// This will propagate to every other connection as well, since it marks the global type mapper collection dirty.
|
||||
((PKConnection)conn).ReloadTypes();
|
||||
}
|
||||
|
||||
private async Task<int> GetCurrentDatabaseVersion(IPKConnection conn)
|
||||
{
|
||||
// First, check if the "info" table exists (it may not, if this is a *really* old database)
|
||||
var hasInfoTable =
|
||||
await conn.QuerySingleOrDefaultAsync<int>(
|
||||
"select count(*) from information_schema.tables where table_name = 'info'") == 1;
|
||||
|
||||
// If we have the table, read the schema version
|
||||
if (hasInfoTable)
|
||||
return await conn.QuerySingleOrDefaultAsync<int>("select schema_version from info");
|
||||
|
||||
// If not, we return version "-1"
|
||||
// This means migration 0 will get executed, getting us into a consistent state
|
||||
// Then, migration 1 gets executed, which creates the info table and sets version to 1
|
||||
return -1;
|
||||
}
|
||||
}
|
@@ -1,20 +1,17 @@
|
||||
using System.Threading;
|
||||
namespace PluralKit.Core;
|
||||
|
||||
namespace PluralKit.Core
|
||||
public class DbConnectionCountHolder
|
||||
{
|
||||
public class DbConnectionCountHolder
|
||||
private int _connectionCount;
|
||||
public int ConnectionCount => _connectionCount;
|
||||
|
||||
public void Increment()
|
||||
{
|
||||
private int _connectionCount;
|
||||
public int ConnectionCount => _connectionCount;
|
||||
Interlocked.Increment(ref _connectionCount);
|
||||
}
|
||||
|
||||
public void Increment()
|
||||
{
|
||||
Interlocked.Increment(ref _connectionCount);
|
||||
}
|
||||
|
||||
public void Decrement()
|
||||
{
|
||||
Interlocked.Decrement(ref _connectionCount);
|
||||
}
|
||||
public void Decrement()
|
||||
{
|
||||
Interlocked.Decrement(ref _connectionCount);
|
||||
}
|
||||
}
|
@@ -1,28 +1,24 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
internal class QueryPatchWrapper
|
||||
{
|
||||
internal class QueryPatchWrapper
|
||||
private readonly Dictionary<string, object> _dict = new();
|
||||
|
||||
public QueryPatchWrapper With<T>(string columnName, Partial<T> partialValue)
|
||||
{
|
||||
private Dictionary<string, object> _dict = new();
|
||||
if (partialValue.IsPresent)
|
||||
_dict.Add(columnName, partialValue);
|
||||
|
||||
public QueryPatchWrapper With<T>(string columnName, Partial<T> partialValue)
|
||||
{
|
||||
if (partialValue.IsPresent)
|
||||
_dict.Add(columnName, partialValue);
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
public Query ToQuery(Query q) => q.AsUpdate(_dict);
|
||||
return this;
|
||||
}
|
||||
|
||||
internal static class SqlKataExtensions
|
||||
{
|
||||
internal static Query ApplyPatch(this Query query, Func<QueryPatchWrapper, QueryPatchWrapper> func)
|
||||
=> func(new QueryPatchWrapper()).ToQuery(query);
|
||||
}
|
||||
public Query ToQuery(Query q) => q.AsUpdate(_dict);
|
||||
}
|
||||
|
||||
internal static class SqlKataExtensions
|
||||
{
|
||||
internal static Query ApplyPatch(this Query query, Func<QueryPatchWrapper, QueryPatchWrapper> func)
|
||||
=> func(new QueryPatchWrapper()).ToQuery(query);
|
||||
}
|
@@ -1,45 +1,41 @@
|
||||
using System.Text;
|
||||
|
||||
using Dapper;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public class UpdateQueryBuilder
|
||||
{
|
||||
public class UpdateQueryBuilder
|
||||
private readonly DynamicParameters _params = new();
|
||||
private readonly QueryBuilder _qb;
|
||||
|
||||
private UpdateQueryBuilder(QueryBuilder qb)
|
||||
{
|
||||
private readonly QueryBuilder _qb;
|
||||
private readonly DynamicParameters _params = new DynamicParameters();
|
||||
|
||||
private UpdateQueryBuilder(QueryBuilder qb)
|
||||
{
|
||||
_qb = qb;
|
||||
}
|
||||
|
||||
public static UpdateQueryBuilder Insert(string table) => new UpdateQueryBuilder(QueryBuilder.Insert(table));
|
||||
public static UpdateQueryBuilder Update(string table, string condition) => new UpdateQueryBuilder(QueryBuilder.Update(table, condition));
|
||||
public static UpdateQueryBuilder Upsert(string table, string conflictField) => new UpdateQueryBuilder(QueryBuilder.Upsert(table, conflictField));
|
||||
|
||||
public UpdateQueryBuilder WithConstant<T>(string name, T value)
|
||||
{
|
||||
_params.Add(name, value);
|
||||
_qb.Constant(name, $"@{name}");
|
||||
return this;
|
||||
}
|
||||
|
||||
public UpdateQueryBuilder With<T>(string columnName, T value)
|
||||
{
|
||||
_params.Add(columnName, value);
|
||||
_qb.Variable(columnName, $"@{columnName}");
|
||||
return this;
|
||||
}
|
||||
|
||||
public UpdateQueryBuilder With<T>(string columnName, Partial<T> partialValue)
|
||||
{
|
||||
return partialValue.IsPresent ? With(columnName, partialValue.Value) : this;
|
||||
}
|
||||
|
||||
public (string Query, DynamicParameters Parameters) Build(string suffix = "")
|
||||
{
|
||||
return (_qb.Build(suffix), _params);
|
||||
}
|
||||
_qb = qb;
|
||||
}
|
||||
|
||||
public static UpdateQueryBuilder Insert(string table) => new(QueryBuilder.Insert(table));
|
||||
|
||||
public static UpdateQueryBuilder Update(string table, string condition) =>
|
||||
new(QueryBuilder.Update(table, condition));
|
||||
|
||||
public static UpdateQueryBuilder Upsert(string table, string conflictField) =>
|
||||
new(QueryBuilder.Upsert(table, conflictField));
|
||||
|
||||
public UpdateQueryBuilder WithConstant<T>(string name, T value)
|
||||
{
|
||||
_params.Add(name, value);
|
||||
_qb.Constant(name, $"@{name}");
|
||||
return this;
|
||||
}
|
||||
|
||||
public UpdateQueryBuilder With<T>(string columnName, T value)
|
||||
{
|
||||
_params.Add(columnName, value);
|
||||
_qb.Variable(columnName, $"@{columnName}");
|
||||
return this;
|
||||
}
|
||||
|
||||
public UpdateQueryBuilder With<T>(string columnName, Partial<T> partialValue) =>
|
||||
partialValue.IsPresent ? With(columnName, partialValue.Value) : this;
|
||||
|
||||
public (string Query, DynamicParameters Parameters) Build(string suffix = "") => (_qb.Build(suffix), _params);
|
||||
}
|
@@ -1,57 +1,60 @@
|
||||
#nullable enable
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Dapper;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public static class DatabaseViewsExt
|
||||
{
|
||||
public static class DatabaseViewsExt
|
||||
public static Task<IEnumerable<SystemFronter>> QueryCurrentFronters(this IPKConnection conn, SystemId system) =>
|
||||
conn.QueryAsync<SystemFronter>("select * from system_fronters where system = @system", new { system });
|
||||
|
||||
public static Task<IEnumerable<ListedGroup>> QueryGroupList(this IPKConnection conn, SystemId system) =>
|
||||
conn.QueryAsync<ListedGroup>("select * from group_list where system = @System", new { System = system });
|
||||
|
||||
public static Task<IEnumerable<ListedMember>> QueryMemberList(this IPKConnection conn, SystemId system,
|
||||
MemberListQueryOptions opts)
|
||||
{
|
||||
public static Task<IEnumerable<SystemFronter>> QueryCurrentFronters(this IPKConnection conn, SystemId system) =>
|
||||
conn.QueryAsync<SystemFronter>("select * from system_fronters where system = @system", new { system });
|
||||
StringBuilder query;
|
||||
if (opts.GroupFilter == null)
|
||||
query = new StringBuilder("select * from member_list where system = @system");
|
||||
else
|
||||
query = new StringBuilder(
|
||||
"select member_list.* from group_members inner join member_list on member_list.id = group_members.member_id where group_id = @groupFilter");
|
||||
|
||||
public static Task<IEnumerable<ListedGroup>> QueryGroupList(this IPKConnection conn, SystemId system) =>
|
||||
conn.QueryAsync<ListedGroup>("select * from group_list where system = @System", new { System = system });
|
||||
if (opts.PrivacyFilter != null)
|
||||
query.Append($" and member_visibility = {(int)opts.PrivacyFilter}");
|
||||
|
||||
public static Task<IEnumerable<ListedMember>> QueryMemberList(this IPKConnection conn, SystemId system, MemberListQueryOptions opts)
|
||||
if (opts.Search != null)
|
||||
{
|
||||
StringBuilder query;
|
||||
if (opts.GroupFilter == null)
|
||||
query = new StringBuilder("select * from member_list where system = @system");
|
||||
else
|
||||
query = new StringBuilder("select member_list.* from group_members inner join member_list on member_list.id = group_members.member_id where group_id = @groupFilter");
|
||||
static string Filter(string column) =>
|
||||
$"(position(lower(@filter) in lower(coalesce({column}, ''))) > 0)";
|
||||
|
||||
if (opts.PrivacyFilter != null)
|
||||
query.Append($" and member_visibility = {(int)opts.PrivacyFilter}");
|
||||
|
||||
if (opts.Search != null)
|
||||
query.Append($" and ({Filter("name")} or {Filter("display_name")}");
|
||||
if (opts.SearchDescription)
|
||||
{
|
||||
static string Filter(string column) => $"(position(lower(@filter) in lower(coalesce({column}, ''))) > 0)";
|
||||
|
||||
query.Append($" and ({Filter("name")} or {Filter("display_name")}");
|
||||
if (opts.SearchDescription)
|
||||
{
|
||||
// We need to account for the possibility of description privacy when searching
|
||||
// If we're looking up from the outside, only search "public_description" (defined in the view; null if desc is private)
|
||||
// If we're the owner, just search the full description
|
||||
var descriptionColumn = opts.Context == LookupContext.ByOwner ? "description" : "public_description";
|
||||
query.Append($"or {Filter(descriptionColumn)}");
|
||||
}
|
||||
query.Append(")");
|
||||
// We need to account for the possibility of description privacy when searching
|
||||
// If we're looking up from the outside, only search "public_description" (defined in the view; null if desc is private)
|
||||
// If we're the owner, just search the full description
|
||||
var descriptionColumn =
|
||||
opts.Context == LookupContext.ByOwner ? "description" : "public_description";
|
||||
query.Append($"or {Filter(descriptionColumn)}");
|
||||
}
|
||||
|
||||
return conn.QueryAsync<ListedMember>(query.ToString(), new { system, filter = opts.Search, groupFilter = opts.GroupFilter });
|
||||
query.Append(")");
|
||||
}
|
||||
|
||||
public struct MemberListQueryOptions
|
||||
{
|
||||
public PrivacyLevel? PrivacyFilter;
|
||||
public string? Search;
|
||||
public bool SearchDescription;
|
||||
public LookupContext Context;
|
||||
public GroupId? GroupFilter;
|
||||
}
|
||||
return conn.QueryAsync<ListedMember>(query.ToString(),
|
||||
new { system, filter = opts.Search, groupFilter = opts.GroupFilter });
|
||||
}
|
||||
|
||||
public struct MemberListQueryOptions
|
||||
{
|
||||
public PrivacyLevel? PrivacyFilter;
|
||||
public string? Search;
|
||||
public bool SearchDescription;
|
||||
public LookupContext Context;
|
||||
public GroupId? GroupFilter;
|
||||
}
|
||||
}
|
@@ -1,7 +1,6 @@
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public class ListedGroup: PKGroup
|
||||
{
|
||||
public class ListedGroup: PKGroup
|
||||
{
|
||||
public int MemberCount { get; }
|
||||
}
|
||||
public int MemberCount { get; }
|
||||
}
|
@@ -1,17 +1,16 @@
|
||||
#nullable enable
|
||||
using NodaTime;
|
||||
|
||||
namespace PluralKit.Core
|
||||
{
|
||||
// TODO: is inheritance here correct?
|
||||
public class ListedMember: PKMember
|
||||
{
|
||||
// public ulong? LastMessage { get; }
|
||||
public Instant? LastSwitchTime { get; }
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public AnnualDate? AnnualBirthday =>
|
||||
Birthday != null
|
||||
? new AnnualDate(Birthday.Value.Month, Birthday.Value.Day)
|
||||
: (AnnualDate?)null;
|
||||
}
|
||||
// TODO: is inheritance here correct?
|
||||
public class ListedMember: PKMember
|
||||
{
|
||||
// public ulong? LastMessage { get; }
|
||||
public Instant? LastSwitchTime { get; }
|
||||
|
||||
public AnnualDate? AnnualBirthday =>
|
||||
Birthday != null
|
||||
? new AnnualDate(Birthday.Value.Month, Birthday.Value.Day)
|
||||
: null;
|
||||
}
|
@@ -1,14 +1,13 @@
|
||||
using NodaTime;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public class SystemFronter
|
||||
{
|
||||
public class SystemFronter
|
||||
{
|
||||
public SystemId SystemId { get; }
|
||||
public SwitchId SwitchId { get; }
|
||||
public Instant SwitchTimestamp { get; }
|
||||
public MemberId MemberId { get; }
|
||||
public string MemberHid { get; }
|
||||
public string MemberName { get; }
|
||||
}
|
||||
public SystemId SystemId { get; }
|
||||
public SwitchId SwitchId { get; }
|
||||
public Instant SwitchTimestamp { get; }
|
||||
public MemberId MemberId { get; }
|
||||
public string MemberHid { get; }
|
||||
public string MemberName { get; }
|
||||
}
|
@@ -1,17 +1,13 @@
|
||||
using System;
|
||||
using System.Data;
|
||||
using System.Data.Common;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public interface IPKCommand: IDbCommand, IAsyncDisposable
|
||||
{
|
||||
public interface IPKCommand: IDbCommand, IAsyncDisposable
|
||||
{
|
||||
public Task PrepareAsync(CancellationToken ct = default);
|
||||
public Task<int> ExecuteNonQueryAsync(CancellationToken ct = default);
|
||||
public Task<object?> ExecuteScalarAsync(CancellationToken ct = default);
|
||||
public Task<DbDataReader> ExecuteReaderAsync(CancellationToken ct = default);
|
||||
public Task<DbDataReader> ExecuteReaderAsync(CommandBehavior behavior, CancellationToken ct = default);
|
||||
}
|
||||
public Task PrepareAsync(CancellationToken ct = default);
|
||||
public Task<int> ExecuteNonQueryAsync(CancellationToken ct = default);
|
||||
public Task<object?> ExecuteScalarAsync(CancellationToken ct = default);
|
||||
public Task<DbDataReader> ExecuteReaderAsync(CancellationToken ct = default);
|
||||
public Task<DbDataReader> ExecuteReaderAsync(CommandBehavior behavior, CancellationToken ct = default);
|
||||
}
|
@@ -1,31 +1,29 @@
|
||||
using System;
|
||||
using System.Data;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Npgsql;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public interface IPKConnection: IDbConnection, IAsyncDisposable
|
||||
{
|
||||
public interface IPKConnection: IDbConnection, IAsyncDisposable
|
||||
{
|
||||
public Guid ConnectionId { get; }
|
||||
public Guid ConnectionId { get; }
|
||||
|
||||
public Task OpenAsync(CancellationToken cancellationToken = default);
|
||||
public Task CloseAsync();
|
||||
public Task OpenAsync(CancellationToken cancellationToken = default);
|
||||
public Task CloseAsync();
|
||||
|
||||
public Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default);
|
||||
public Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default);
|
||||
|
||||
public ValueTask<IPKTransaction> BeginTransactionAsync(CancellationToken ct = default) => BeginTransactionAsync(IsolationLevel.Unspecified, ct);
|
||||
public ValueTask<IPKTransaction> BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default);
|
||||
public ValueTask<IPKTransaction> BeginTransactionAsync(CancellationToken ct = default) =>
|
||||
BeginTransactionAsync(IsolationLevel.Unspecified, ct);
|
||||
|
||||
public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand);
|
||||
public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand);
|
||||
public ValueTask<IPKTransaction> BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default);
|
||||
|
||||
[Obsolete] new void Open();
|
||||
[Obsolete] new void Close();
|
||||
public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand);
|
||||
public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand);
|
||||
|
||||
[Obsolete] new IDbTransaction BeginTransaction();
|
||||
[Obsolete] new IDbTransaction BeginTransaction(IsolationLevel il);
|
||||
}
|
||||
[Obsolete] new void Open();
|
||||
[Obsolete] new void Close();
|
||||
|
||||
[Obsolete] new IDbTransaction BeginTransaction();
|
||||
[Obsolete] new IDbTransaction BeginTransaction(IsolationLevel il);
|
||||
}
|
@@ -1,13 +1,9 @@
|
||||
using System;
|
||||
using System.Data;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public interface IPKTransaction: IDbTransaction, IAsyncDisposable
|
||||
{
|
||||
public interface IPKTransaction: IDbTransaction, IAsyncDisposable
|
||||
{
|
||||
public Task CommitAsync(CancellationToken ct = default);
|
||||
public Task RollbackAsync(CancellationToken ct = default);
|
||||
}
|
||||
public Task CommitAsync(CancellationToken ct = default);
|
||||
public Task RollbackAsync(CancellationToken ct = default);
|
||||
}
|
@@ -1,10 +1,7 @@
|
||||
#nullable enable
|
||||
using System;
|
||||
using System.Data;
|
||||
using System.Data.Common;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using App.Metrics;
|
||||
|
||||
@@ -14,113 +11,125 @@ using Npgsql;
|
||||
|
||||
using Serilog;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
internal class PKCommand: DbCommand, IPKCommand
|
||||
{
|
||||
internal class PKCommand: DbCommand, IPKCommand
|
||||
private readonly ILogger _logger;
|
||||
private readonly IMetrics _metrics;
|
||||
|
||||
private readonly PKConnection _ourConnection;
|
||||
|
||||
public PKCommand(NpgsqlCommand inner, PKConnection ourConnection, ILogger logger, IMetrics metrics)
|
||||
{
|
||||
private NpgsqlCommand Inner { get; }
|
||||
|
||||
private readonly PKConnection _ourConnection;
|
||||
private readonly ILogger _logger;
|
||||
private readonly IMetrics _metrics;
|
||||
|
||||
public PKCommand(NpgsqlCommand inner, PKConnection ourConnection, ILogger logger, IMetrics metrics)
|
||||
{
|
||||
Inner = inner;
|
||||
_ourConnection = ourConnection;
|
||||
_logger = logger.ForContext<PKCommand>();
|
||||
_metrics = metrics;
|
||||
}
|
||||
|
||||
public override Task<int> ExecuteNonQueryAsync(CancellationToken ct) => LogQuery(Inner.ExecuteNonQueryAsync(ct));
|
||||
public override Task<object?> ExecuteScalarAsync(CancellationToken ct) => LogQuery(Inner.ExecuteScalarAsync(ct));
|
||||
protected override async Task<DbDataReader> ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken ct) => await LogQuery(Inner.ExecuteReaderAsync(behavior, ct));
|
||||
|
||||
public override Task PrepareAsync(CancellationToken ct = default) => Inner.PrepareAsync(ct);
|
||||
public override void Cancel() => Inner.Cancel();
|
||||
protected override DbParameter CreateDbParameter() => Inner.CreateParameter();
|
||||
|
||||
[AllowNull]
|
||||
public override string CommandText
|
||||
{
|
||||
get => Inner.CommandText;
|
||||
set => Inner.CommandText = value;
|
||||
}
|
||||
|
||||
public override int CommandTimeout
|
||||
{
|
||||
get => Inner.CommandTimeout;
|
||||
set => Inner.CommandTimeout = value;
|
||||
}
|
||||
|
||||
public override CommandType CommandType
|
||||
{
|
||||
get => Inner.CommandType;
|
||||
set => Inner.CommandType = value;
|
||||
}
|
||||
|
||||
public override UpdateRowSource UpdatedRowSource
|
||||
{
|
||||
get => Inner.UpdatedRowSource;
|
||||
set => Inner.UpdatedRowSource = value;
|
||||
}
|
||||
|
||||
protected override DbParameterCollection DbParameterCollection => Inner.Parameters;
|
||||
protected override DbTransaction? DbTransaction
|
||||
{
|
||||
get => Inner.Transaction;
|
||||
set => Inner.Transaction = value switch
|
||||
{
|
||||
NpgsqlTransaction npg => npg,
|
||||
PKTransaction pk => pk.Inner,
|
||||
_ => throw new ArgumentException($"Can't convert input type {value?.GetType()} to NpgsqlTransaction")
|
||||
};
|
||||
}
|
||||
|
||||
public override bool DesignTimeVisible
|
||||
{
|
||||
get => Inner.DesignTimeVisible;
|
||||
set => Inner.DesignTimeVisible = value;
|
||||
}
|
||||
|
||||
protected override DbConnection? DbConnection
|
||||
{
|
||||
get => Inner.Connection;
|
||||
set =>
|
||||
Inner.Connection = value switch
|
||||
{
|
||||
NpgsqlConnection npg => npg,
|
||||
PKConnection pk => pk.Inner,
|
||||
_ => throw new ArgumentException($"Can't convert input type {value?.GetType()} to NpgsqlConnection")
|
||||
};
|
||||
}
|
||||
|
||||
public override int ExecuteNonQuery() => throw SyncError(nameof(ExecuteNonQuery));
|
||||
public override object ExecuteScalar() => throw SyncError(nameof(ExecuteScalar));
|
||||
protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => throw SyncError(nameof(ExecuteDbDataReader));
|
||||
public override void Prepare() => throw SyncError(nameof(Prepare));
|
||||
|
||||
private async Task<T> LogQuery<T>(Task<T> task)
|
||||
{
|
||||
var start = SystemClock.Instance.GetCurrentInstant();
|
||||
try
|
||||
{
|
||||
return await task;
|
||||
}
|
||||
finally
|
||||
{
|
||||
var end = SystemClock.Instance.GetCurrentInstant();
|
||||
var elapsed = end - start;
|
||||
|
||||
_logger.Verbose("Executed query {Query} in {ElapsedTime} on connection {ConnectionId}", CommandText, elapsed, _ourConnection.ConnectionId);
|
||||
|
||||
// One "BCL compatible tick" is 100 nanoseconds
|
||||
var micros = elapsed.BclCompatibleTicks / 10;
|
||||
_metrics.Provider.Timer.Instance(CoreMetrics.DatabaseQuery, new MetricTags("query", CommandText))
|
||||
.Record(micros, TimeUnit.Microseconds, CommandText);
|
||||
}
|
||||
}
|
||||
|
||||
private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IDbCommand function {caller}!");
|
||||
Inner = inner;
|
||||
_ourConnection = ourConnection;
|
||||
_logger = logger.ForContext<PKCommand>();
|
||||
_metrics = metrics;
|
||||
}
|
||||
|
||||
private NpgsqlCommand Inner { get; }
|
||||
|
||||
protected override DbParameterCollection DbParameterCollection => Inner.Parameters;
|
||||
|
||||
protected override DbTransaction? DbTransaction
|
||||
{
|
||||
get => Inner.Transaction;
|
||||
set => Inner.Transaction = value switch
|
||||
{
|
||||
NpgsqlTransaction npg => npg,
|
||||
PKTransaction pk => pk.Inner,
|
||||
_ => throw new ArgumentException($"Can't convert input type {value?.GetType()} to NpgsqlTransaction")
|
||||
};
|
||||
}
|
||||
|
||||
public override bool DesignTimeVisible
|
||||
{
|
||||
get => Inner.DesignTimeVisible;
|
||||
set => Inner.DesignTimeVisible = value;
|
||||
}
|
||||
|
||||
protected override DbConnection? DbConnection
|
||||
{
|
||||
get => Inner.Connection;
|
||||
set =>
|
||||
Inner.Connection = value switch
|
||||
{
|
||||
NpgsqlConnection npg => npg,
|
||||
PKConnection pk => pk.Inner,
|
||||
_ => throw new ArgumentException($"Can't convert input type {value?.GetType()} to NpgsqlConnection")
|
||||
};
|
||||
}
|
||||
|
||||
public override Task<int> ExecuteNonQueryAsync(CancellationToken ct) =>
|
||||
LogQuery(Inner.ExecuteNonQueryAsync(ct));
|
||||
|
||||
public override Task<object?> ExecuteScalarAsync(CancellationToken ct) =>
|
||||
LogQuery(Inner.ExecuteScalarAsync(ct));
|
||||
|
||||
public override Task PrepareAsync(CancellationToken ct = default) => Inner.PrepareAsync(ct);
|
||||
public override void Cancel() => Inner.Cancel();
|
||||
|
||||
[AllowNull]
|
||||
public override string CommandText
|
||||
{
|
||||
get => Inner.CommandText;
|
||||
set => Inner.CommandText = value;
|
||||
}
|
||||
|
||||
public override int CommandTimeout
|
||||
{
|
||||
get => Inner.CommandTimeout;
|
||||
set => Inner.CommandTimeout = value;
|
||||
}
|
||||
|
||||
public override CommandType CommandType
|
||||
{
|
||||
get => Inner.CommandType;
|
||||
set => Inner.CommandType = value;
|
||||
}
|
||||
|
||||
public override UpdateRowSource UpdatedRowSource
|
||||
{
|
||||
get => Inner.UpdatedRowSource;
|
||||
set => Inner.UpdatedRowSource = value;
|
||||
}
|
||||
|
||||
public override int ExecuteNonQuery() => throw SyncError(nameof(ExecuteNonQuery));
|
||||
public override object ExecuteScalar() => throw SyncError(nameof(ExecuteScalar));
|
||||
public override void Prepare() => throw SyncError(nameof(Prepare));
|
||||
|
||||
protected override async Task<DbDataReader>
|
||||
ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken ct) =>
|
||||
await LogQuery(Inner.ExecuteReaderAsync(behavior, ct));
|
||||
|
||||
protected override DbParameter CreateDbParameter() => Inner.CreateParameter();
|
||||
|
||||
protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) =>
|
||||
throw SyncError(nameof(ExecuteDbDataReader));
|
||||
|
||||
private async Task<T> LogQuery<T>(Task<T> task)
|
||||
{
|
||||
var start = SystemClock.Instance.GetCurrentInstant();
|
||||
try
|
||||
{
|
||||
return await task;
|
||||
}
|
||||
finally
|
||||
{
|
||||
var end = SystemClock.Instance.GetCurrentInstant();
|
||||
var elapsed = end - start;
|
||||
|
||||
_logger.Verbose("Executed query {Query} in {ElapsedTime} on connection {ConnectionId}", CommandText,
|
||||
elapsed, _ourConnection.ConnectionId);
|
||||
|
||||
// One "BCL compatible tick" is 100 nanoseconds
|
||||
var micros = elapsed.BclCompatibleTicks / 10;
|
||||
_metrics.Provider.Timer.Instance(CoreMetrics.DatabaseQuery, new MetricTags("query", CommandText))
|
||||
.Record(micros, TimeUnit.Microseconds, CommandText);
|
||||
}
|
||||
}
|
||||
|
||||
private static Exception SyncError(string caller) =>
|
||||
throw new Exception($"Executed synchronous IDbCommand function {caller}!");
|
||||
}
|
@@ -1,10 +1,7 @@
|
||||
#nullable enable
|
||||
using System;
|
||||
using System.Data;
|
||||
using System.Data.Common;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using App.Metrics;
|
||||
|
||||
@@ -14,102 +11,120 @@ using Npgsql;
|
||||
|
||||
using Serilog;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
internal class PKConnection: DbConnection, IPKConnection
|
||||
{
|
||||
internal class PKConnection: DbConnection, IPKConnection
|
||||
private readonly DbConnectionCountHolder _countHolder;
|
||||
private readonly ILogger _logger;
|
||||
private readonly IMetrics _metrics;
|
||||
private bool _hasClosed;
|
||||
|
||||
private bool _hasOpened;
|
||||
private Instant _openTime;
|
||||
|
||||
public PKConnection(NpgsqlConnection inner, DbConnectionCountHolder countHolder, ILogger logger,
|
||||
IMetrics metrics)
|
||||
{
|
||||
public NpgsqlConnection Inner { get; }
|
||||
public Guid ConnectionId { get; }
|
||||
|
||||
private readonly DbConnectionCountHolder _countHolder;
|
||||
private readonly ILogger _logger;
|
||||
private readonly IMetrics _metrics;
|
||||
|
||||
private bool _hasOpened;
|
||||
private bool _hasClosed;
|
||||
private Instant _openTime;
|
||||
|
||||
public PKConnection(NpgsqlConnection inner, DbConnectionCountHolder countHolder, ILogger logger, IMetrics metrics)
|
||||
{
|
||||
Inner = inner;
|
||||
ConnectionId = Guid.NewGuid();
|
||||
_countHolder = countHolder;
|
||||
_logger = logger.ForContext<PKConnection>();
|
||||
_metrics = metrics;
|
||||
}
|
||||
|
||||
public override Task OpenAsync(CancellationToken ct)
|
||||
{
|
||||
if (_hasOpened) return Inner.OpenAsync(ct);
|
||||
_countHolder.Increment();
|
||||
_hasOpened = true;
|
||||
_openTime = SystemClock.Instance.GetCurrentInstant();
|
||||
_logger.Verbose("Opened database connection {ConnectionId}, new connection count {ConnectionCount}", ConnectionId, _countHolder.ConnectionCount);
|
||||
return Inner.OpenAsync(ct);
|
||||
}
|
||||
|
||||
public override Task CloseAsync() => Inner.CloseAsync();
|
||||
|
||||
protected override DbCommand CreateDbCommand() => new PKCommand(Inner.CreateCommand(), this, _logger, _metrics);
|
||||
|
||||
public void ReloadTypes() => Inner.ReloadTypes();
|
||||
|
||||
public new async ValueTask<IPKTransaction> BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default) => new PKTransaction(await Inner.BeginTransactionAsync(level, ct));
|
||||
|
||||
public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand) => Inner.BeginBinaryImport(copyFromCommand);
|
||||
public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand) => Inner.BeginBinaryExport(copyToCommand);
|
||||
|
||||
public override void ChangeDatabase(string databaseName) => Inner.ChangeDatabase(databaseName);
|
||||
public override Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default) => Inner.ChangeDatabaseAsync(databaseName, ct);
|
||||
protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw SyncError(nameof(BeginDbTransaction));
|
||||
protected override async ValueTask<DbTransaction> BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => new PKTransaction(await Inner.BeginTransactionAsync(level, ct));
|
||||
|
||||
public override void Open() => throw SyncError(nameof(Open));
|
||||
public override void Close()
|
||||
{
|
||||
// Don't throw SyncError here, Dapper calls sync Close() internally so that sucks
|
||||
Inner.Close();
|
||||
}
|
||||
|
||||
IDbTransaction IPKConnection.BeginTransaction() => throw SyncError(nameof(BeginTransaction));
|
||||
IDbTransaction IPKConnection.BeginTransaction(IsolationLevel level) => throw SyncError(nameof(BeginTransaction));
|
||||
|
||||
[AllowNull]
|
||||
public override string ConnectionString
|
||||
{
|
||||
get => Inner.ConnectionString;
|
||||
set => Inner.ConnectionString = value;
|
||||
}
|
||||
|
||||
public override string Database => Inner.Database!;
|
||||
public override ConnectionState State => Inner.State;
|
||||
public override string DataSource => Inner.DataSource;
|
||||
public override string ServerVersion => Inner.ServerVersion;
|
||||
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
Inner.Dispose();
|
||||
if (_hasClosed) return;
|
||||
|
||||
LogClose();
|
||||
}
|
||||
|
||||
public override ValueTask DisposeAsync()
|
||||
{
|
||||
if (_hasClosed) return Inner.DisposeAsync();
|
||||
LogClose();
|
||||
return Inner.DisposeAsync();
|
||||
}
|
||||
|
||||
private void LogClose()
|
||||
{
|
||||
_countHolder.Decrement();
|
||||
_hasClosed = true;
|
||||
|
||||
var duration = SystemClock.Instance.GetCurrentInstant() - _openTime;
|
||||
_logger.Verbose("Closed database connection {ConnectionId} (open for {ConnectionDuration}), new connection count {ConnectionCount}", ConnectionId, duration, _countHolder.ConnectionCount);
|
||||
}
|
||||
|
||||
private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IDbCommand function {caller}!");
|
||||
Inner = inner;
|
||||
ConnectionId = Guid.NewGuid();
|
||||
_countHolder = countHolder;
|
||||
_logger = logger.ForContext<PKConnection>();
|
||||
_metrics = metrics;
|
||||
}
|
||||
|
||||
public NpgsqlConnection Inner { get; }
|
||||
public override string DataSource => Inner.DataSource;
|
||||
public override string ServerVersion => Inner.ServerVersion;
|
||||
public Guid ConnectionId { get; }
|
||||
|
||||
public override Task OpenAsync(CancellationToken ct)
|
||||
{
|
||||
if (_hasOpened) return Inner.OpenAsync(ct);
|
||||
_countHolder.Increment();
|
||||
_hasOpened = true;
|
||||
_openTime = SystemClock.Instance.GetCurrentInstant();
|
||||
_logger.Verbose("Opened database connection {ConnectionId}, new connection count {ConnectionCount}",
|
||||
ConnectionId, _countHolder.ConnectionCount);
|
||||
return Inner.OpenAsync(ct);
|
||||
}
|
||||
|
||||
public override Task CloseAsync() => Inner.CloseAsync();
|
||||
|
||||
public new async ValueTask<IPKTransaction>
|
||||
BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default) =>
|
||||
new PKTransaction(await Inner.BeginTransactionAsync(level, ct));
|
||||
|
||||
public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand) =>
|
||||
Inner.BeginBinaryImport(copyFromCommand);
|
||||
|
||||
public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand) => Inner.BeginBinaryExport(copyToCommand);
|
||||
|
||||
public override void ChangeDatabase(string databaseName) => Inner.ChangeDatabase(databaseName);
|
||||
|
||||
public override Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default) =>
|
||||
Inner.ChangeDatabaseAsync(databaseName, ct);
|
||||
|
||||
public override void Open() => throw SyncError(nameof(Open));
|
||||
|
||||
public override void Close()
|
||||
{
|
||||
// Don't throw SyncError here, Dapper calls sync Close() internally so that sucks
|
||||
Inner.Close();
|
||||
}
|
||||
|
||||
IDbTransaction IPKConnection.BeginTransaction() => throw SyncError(nameof(BeginTransaction));
|
||||
|
||||
IDbTransaction IPKConnection.BeginTransaction(IsolationLevel level) =>
|
||||
throw SyncError(nameof(BeginTransaction));
|
||||
|
||||
[AllowNull]
|
||||
public override string ConnectionString
|
||||
{
|
||||
get => Inner.ConnectionString;
|
||||
set => Inner.ConnectionString = value;
|
||||
}
|
||||
|
||||
public override string Database => Inner.Database!;
|
||||
public override ConnectionState State => Inner.State;
|
||||
|
||||
public override ValueTask DisposeAsync()
|
||||
{
|
||||
if (_hasClosed) return Inner.DisposeAsync();
|
||||
LogClose();
|
||||
return Inner.DisposeAsync();
|
||||
}
|
||||
|
||||
protected override DbCommand CreateDbCommand() => new PKCommand(Inner.CreateCommand(), this, _logger, _metrics);
|
||||
|
||||
public void ReloadTypes() => Inner.ReloadTypes();
|
||||
|
||||
protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) =>
|
||||
throw SyncError(nameof(BeginDbTransaction));
|
||||
|
||||
protected override async ValueTask<DbTransaction>
|
||||
BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) =>
|
||||
new PKTransaction(await Inner.BeginTransactionAsync(level, ct));
|
||||
|
||||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
Inner.Dispose();
|
||||
if (_hasClosed) return;
|
||||
|
||||
LogClose();
|
||||
}
|
||||
|
||||
private void LogClose()
|
||||
{
|
||||
_countHolder.Decrement();
|
||||
_hasClosed = true;
|
||||
|
||||
var duration = SystemClock.Instance.GetCurrentInstant() - _openTime;
|
||||
_logger.Verbose(
|
||||
"Closed database connection {ConnectionId} (open for {ConnectionDuration}), new connection count {ConnectionCount}",
|
||||
ConnectionId, duration, _countHolder.ConnectionCount);
|
||||
}
|
||||
|
||||
private static Exception SyncError(string caller) =>
|
||||
throw new Exception($"Executed synchronous IDbCommand function {caller}!");
|
||||
}
|
@@ -1,31 +1,28 @@
|
||||
using System;
|
||||
using System.Data;
|
||||
using System.Data.Common;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Npgsql;
|
||||
|
||||
namespace PluralKit.Core
|
||||
namespace PluralKit.Core;
|
||||
|
||||
internal class PKTransaction: DbTransaction, IPKTransaction
|
||||
{
|
||||
internal class PKTransaction: DbTransaction, IPKTransaction
|
||||
public PKTransaction(NpgsqlTransaction inner)
|
||||
{
|
||||
public NpgsqlTransaction Inner { get; }
|
||||
|
||||
public PKTransaction(NpgsqlTransaction inner)
|
||||
{
|
||||
Inner = inner;
|
||||
}
|
||||
|
||||
public override void Commit() => throw SyncError(nameof(Commit));
|
||||
public override Task CommitAsync(CancellationToken ct = default) => Inner.CommitAsync(ct);
|
||||
|
||||
public override void Rollback() => throw SyncError(nameof(Rollback));
|
||||
public override Task RollbackAsync(CancellationToken ct = default) => Inner.RollbackAsync(ct);
|
||||
|
||||
protected override DbConnection DbConnection => Inner.Connection;
|
||||
public override IsolationLevel IsolationLevel => Inner.IsolationLevel;
|
||||
|
||||
private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IDbTransaction function {caller}!");
|
||||
Inner = inner;
|
||||
}
|
||||
|
||||
public NpgsqlTransaction Inner { get; }
|
||||
|
||||
protected override DbConnection DbConnection => Inner.Connection;
|
||||
|
||||
public override void Commit() => throw SyncError(nameof(Commit));
|
||||
public override Task CommitAsync(CancellationToken ct = default) => Inner.CommitAsync(ct);
|
||||
|
||||
public override void Rollback() => throw SyncError(nameof(Rollback));
|
||||
public override Task RollbackAsync(CancellationToken ct = default) => Inner.RollbackAsync(ct);
|
||||
public override IsolationLevel IsolationLevel => Inner.IsolationLevel;
|
||||
|
||||
private static Exception SyncError(string caller) =>
|
||||
throw new Exception($"Executed synchronous IDbTransaction function {caller}!");
|
||||
}
|
Reference in New Issue
Block a user