diff --git a/PluralKit.Core/Database/Wrappers/IPKConnection.cs b/PluralKit.Core/Database/Wrappers/IPKConnection.cs index b83a5b77..f82971d2 100644 --- a/PluralKit.Core/Database/Wrappers/IPKConnection.cs +++ b/PluralKit.Core/Database/Wrappers/IPKConnection.cs @@ -17,8 +17,8 @@ namespace PluralKit.Core public Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default); - public ValueTask BeginTransactionAsync(CancellationToken ct = default) => BeginTransactionAsync(IsolationLevel.Unspecified, ct); - public ValueTask BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default); + public ValueTask BeginTransactionAsync(CancellationToken ct = default) => BeginTransactionAsync(IsolationLevel.Unspecified, ct); + public ValueTask BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default); public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand); public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand); diff --git a/PluralKit.Core/Database/Wrappers/IPKTransaction.cs b/PluralKit.Core/Database/Wrappers/IPKTransaction.cs new file mode 100644 index 00000000..324450f7 --- /dev/null +++ b/PluralKit.Core/Database/Wrappers/IPKTransaction.cs @@ -0,0 +1,13 @@ +using System; +using System.Data; +using System.Threading; +using System.Threading.Tasks; + +namespace PluralKit.Core +{ + public interface IPKTransaction: IDbTransaction, IAsyncDisposable + { + public Task CommitAsync(CancellationToken ct = default); + public Task RollbackAsync(CancellationToken ct = default); + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Wrappers/PKCommand.cs b/PluralKit.Core/Database/Wrappers/PKCommand.cs index ed048750..aa21241a 100644 --- a/PluralKit.Core/Database/Wrappers/PKCommand.cs +++ b/PluralKit.Core/Database/Wrappers/PKCommand.cs @@ -17,80 +17,87 @@ namespace PluralKit.Core { public class PKCommand: DbCommand, IPKCommand { - private readonly NpgsqlCommand _inner; + public NpgsqlCommand Inner { get; } + private readonly PKConnection _ourConnection; private readonly ILogger _logger; private readonly IMetrics _metrics; public PKCommand(NpgsqlCommand inner, PKConnection ourConnection, ILogger logger, IMetrics metrics) { - _inner = inner; + Inner = inner; _ourConnection = ourConnection; _logger = logger.ForContext(); _metrics = metrics; } - public override int ExecuteNonQuery() => throw SyncError(nameof(ExecuteNonQuery)); - public override object ExecuteScalar() => throw SyncError(nameof(ExecuteScalar)); - protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => throw SyncError(nameof(ExecuteDbDataReader)); + public override Task ExecuteNonQueryAsync(CancellationToken ct) => LogQuery(Inner.ExecuteNonQueryAsync(ct)); + public override Task ExecuteScalarAsync(CancellationToken ct) => LogQuery(Inner.ExecuteScalarAsync(ct)); + protected override async Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken ct) => await LogQuery(Inner.ExecuteReaderAsync(behavior, ct)); - public override Task ExecuteNonQueryAsync(CancellationToken ct) => LogQuery(_inner.ExecuteNonQueryAsync(ct)); - public override Task ExecuteScalarAsync(CancellationToken ct) => LogQuery(_inner.ExecuteScalarAsync(ct)); - protected override async Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken ct) => await LogQuery(_inner.ExecuteReaderAsync(behavior, ct)); - - public override void Prepare() => _inner.Prepare(); - public override void Cancel() => _inner.Cancel(); - protected override DbParameter CreateDbParameter() => _inner.CreateParameter(); + public override Task PrepareAsync(CancellationToken ct = default) => Inner.PrepareAsync(ct); + public override void Cancel() => Inner.Cancel(); + protected override DbParameter CreateDbParameter() => Inner.CreateParameter(); public override string CommandText { - get => _inner.CommandText; - set => _inner.CommandText = value; + get => Inner.CommandText; + set => Inner.CommandText = value; } public override int CommandTimeout { - get => _inner.CommandTimeout; - set => _inner.CommandTimeout = value; + get => Inner.CommandTimeout; + set => Inner.CommandTimeout = value; } public override CommandType CommandType { - get => _inner.CommandType; - set => _inner.CommandType = value; + get => Inner.CommandType; + set => Inner.CommandType = value; } public override UpdateRowSource UpdatedRowSource { - get => _inner.UpdatedRowSource; - set => _inner.UpdatedRowSource = value; + get => Inner.UpdatedRowSource; + set => Inner.UpdatedRowSource = value; } - protected override DbParameterCollection DbParameterCollection => _inner.Parameters; + protected override DbParameterCollection DbParameterCollection => Inner.Parameters; protected override DbTransaction? DbTransaction { - get => _inner.Transaction; - set => _inner.Transaction = (NpgsqlTransaction?) value; + get => Inner.Transaction; + set => Inner.Transaction = value switch + { + NpgsqlTransaction npg => npg, + PKTransaction pk => pk.Inner, + _ => throw new ArgumentException($"Can't convert input type {value?.GetType()} to NpgsqlTransaction") + }; } public override bool DesignTimeVisible { - get => _inner.DesignTimeVisible; - set => _inner.DesignTimeVisible = value; + get => Inner.DesignTimeVisible; + set => Inner.DesignTimeVisible = value; } protected override DbConnection? DbConnection { - get => _inner.Connection; + get => Inner.Connection; set => - _inner.Connection = value switch + Inner.Connection = value switch { NpgsqlConnection npg => npg, PKConnection pk => pk.Inner, _ => throw new ArgumentException($"Can't convert input type {value?.GetType()} to NpgsqlConnection") }; } - + + public override int ExecuteNonQuery() => throw SyncError(nameof(ExecuteNonQuery)); + public override object ExecuteScalar() => throw SyncError(nameof(ExecuteScalar)); + protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => throw SyncError(nameof(ExecuteDbDataReader)); + public override void Prepare() => throw SyncError(nameof(Prepare)); + private async Task LogQuery(Task task) { var start = SystemClock.Instance.GetCurrentInstant(); @@ -112,6 +119,6 @@ namespace PluralKit.Core } } - private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IPKCommand function {caller}!"); + private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IDbCommand function {caller}!"); } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Wrappers/PKConnection.cs b/PluralKit.Core/Database/Wrappers/PKConnection.cs index 46947fcc..153a7f02 100644 --- a/PluralKit.Core/Database/Wrappers/PKConnection.cs +++ b/PluralKit.Core/Database/Wrappers/PKConnection.cs @@ -52,14 +52,16 @@ namespace PluralKit.Core protected override DbCommand CreateDbCommand() => new PKCommand(Inner.CreateCommand(), this, _logger, _metrics); public void ReloadTypes() => Inner.ReloadTypes(); - + + public new async ValueTask BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default) => new PKTransaction(await Inner.BeginTransactionAsync(level, ct)); + public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand) => Inner.BeginBinaryImport(copyFromCommand); public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand) => Inner.BeginBinaryExport(copyToCommand); public override void ChangeDatabase(string databaseName) => Inner.ChangeDatabase(databaseName); public override Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default) => Inner.ChangeDatabaseAsync(databaseName, ct); protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw SyncError(nameof(BeginDbTransaction)); - protected override async ValueTask BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => await Inner.BeginTransactionAsync(level, ct); + protected override async ValueTask BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => new PKTransaction(await Inner.BeginTransactionAsync(level, ct)); public override void Open() => throw SyncError(nameof(Open)); public override void Close() => throw SyncError(nameof(Close)); @@ -102,6 +104,6 @@ namespace PluralKit.Core _logger.Verbose("Closed database connection {ConnectionId} (open for {ConnectionDuration}), new connection count {ConnectionCount}", ConnectionId, duration, _countHolder.ConnectionCount); } - private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IPKConnection function {caller}!"); + private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IDbCommand function {caller}!"); } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Wrappers/PKTransaction.cs b/PluralKit.Core/Database/Wrappers/PKTransaction.cs new file mode 100644 index 00000000..84c17722 --- /dev/null +++ b/PluralKit.Core/Database/Wrappers/PKTransaction.cs @@ -0,0 +1,31 @@ +using System; +using System.Data; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; + +using Npgsql; + +namespace PluralKit.Core +{ + public class PKTransaction: DbTransaction, IPKTransaction + { + public NpgsqlTransaction Inner { get; } + + public PKTransaction(NpgsqlTransaction inner) + { + Inner = inner; + } + + public override void Commit() => throw SyncError(nameof(Commit)); + public override Task CommitAsync(CancellationToken ct = default) => Inner.CommitAsync(ct); + + public override void Rollback() => throw SyncError(nameof(Rollback)); + public override Task RollbackAsync(CancellationToken ct = default) => Inner.RollbackAsync(ct); + + protected override DbConnection DbConnection => Inner.Connection; + public override IsolationLevel IsolationLevel => Inner.IsolationLevel; + + private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IDbTransaction function {caller}!"); + } +} \ No newline at end of file diff --git a/PluralKit.Core/Utils/BulkImporter.cs b/PluralKit.Core/Utils/BulkImporter.cs index 196f1267..41d329e8 100644 --- a/PluralKit.Core/Utils/BulkImporter.cs +++ b/PluralKit.Core/Utils/BulkImporter.cs @@ -21,12 +21,12 @@ namespace PluralKit.Core { private readonly int _systemId; private readonly IPKConnection _conn; - private readonly DbTransaction _tx; + private readonly IPKTransaction _tx; private readonly Dictionary _knownMembers = new Dictionary(); private readonly Dictionary _existingMembersByHid = new Dictionary(); private readonly Dictionary _existingMembersByName = new Dictionary(); - private BulkImporter(int systemId, IPKConnection conn, DbTransaction tx) + private BulkImporter(int systemId, IPKConnection conn, IPKTransaction tx) { _systemId = systemId; _conn = conn;