Wrap DbTransaction too

This commit is contained in:
Ske 2020-06-13 18:49:05 +02:00
parent e176ccbab5
commit 37b99f9521
6 changed files with 89 additions and 36 deletions

View File

@ -17,8 +17,8 @@ namespace PluralKit.Core
public Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default); public Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default);
public ValueTask<DbTransaction> BeginTransactionAsync(CancellationToken ct = default) => BeginTransactionAsync(IsolationLevel.Unspecified, ct); public ValueTask<IPKTransaction> BeginTransactionAsync(CancellationToken ct = default) => BeginTransactionAsync(IsolationLevel.Unspecified, ct);
public ValueTask<DbTransaction> BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default); public ValueTask<IPKTransaction> BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default);
public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand); public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand);
public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand); public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand);

View File

@ -0,0 +1,13 @@
using System;
using System.Data;
using System.Threading;
using System.Threading.Tasks;
namespace PluralKit.Core
{
public interface IPKTransaction: IDbTransaction, IAsyncDisposable
{
public Task CommitAsync(CancellationToken ct = default);
public Task RollbackAsync(CancellationToken ct = default);
}
}

View File

@ -17,80 +17,87 @@ namespace PluralKit.Core
{ {
public class PKCommand: DbCommand, IPKCommand public class PKCommand: DbCommand, IPKCommand
{ {
private readonly NpgsqlCommand _inner; public NpgsqlCommand Inner { get; }
private readonly PKConnection _ourConnection; private readonly PKConnection _ourConnection;
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly IMetrics _metrics; private readonly IMetrics _metrics;
public PKCommand(NpgsqlCommand inner, PKConnection ourConnection, ILogger logger, IMetrics metrics) public PKCommand(NpgsqlCommand inner, PKConnection ourConnection, ILogger logger, IMetrics metrics)
{ {
_inner = inner; Inner = inner;
_ourConnection = ourConnection; _ourConnection = ourConnection;
_logger = logger.ForContext<PKCommand>(); _logger = logger.ForContext<PKCommand>();
_metrics = metrics; _metrics = metrics;
} }
public override int ExecuteNonQuery() => throw SyncError(nameof(ExecuteNonQuery)); public override Task<int> ExecuteNonQueryAsync(CancellationToken ct) => LogQuery(Inner.ExecuteNonQueryAsync(ct));
public override object ExecuteScalar() => throw SyncError(nameof(ExecuteScalar)); public override Task<object> ExecuteScalarAsync(CancellationToken ct) => LogQuery(Inner.ExecuteScalarAsync(ct));
protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => throw SyncError(nameof(ExecuteDbDataReader)); protected override async Task<DbDataReader> ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken ct) => await LogQuery(Inner.ExecuteReaderAsync(behavior, ct));
public override Task<int> ExecuteNonQueryAsync(CancellationToken ct) => LogQuery(_inner.ExecuteNonQueryAsync(ct)); public override Task PrepareAsync(CancellationToken ct = default) => Inner.PrepareAsync(ct);
public override Task<object> ExecuteScalarAsync(CancellationToken ct) => LogQuery(_inner.ExecuteScalarAsync(ct)); public override void Cancel() => Inner.Cancel();
protected override async Task<DbDataReader> ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken ct) => await LogQuery(_inner.ExecuteReaderAsync(behavior, ct)); protected override DbParameter CreateDbParameter() => Inner.CreateParameter();
public override void Prepare() => _inner.Prepare();
public override void Cancel() => _inner.Cancel();
protected override DbParameter CreateDbParameter() => _inner.CreateParameter();
public override string CommandText public override string CommandText
{ {
get => _inner.CommandText; get => Inner.CommandText;
set => _inner.CommandText = value; set => Inner.CommandText = value;
} }
public override int CommandTimeout public override int CommandTimeout
{ {
get => _inner.CommandTimeout; get => Inner.CommandTimeout;
set => _inner.CommandTimeout = value; set => Inner.CommandTimeout = value;
} }
public override CommandType CommandType public override CommandType CommandType
{ {
get => _inner.CommandType; get => Inner.CommandType;
set => _inner.CommandType = value; set => Inner.CommandType = value;
} }
public override UpdateRowSource UpdatedRowSource public override UpdateRowSource UpdatedRowSource
{ {
get => _inner.UpdatedRowSource; get => Inner.UpdatedRowSource;
set => _inner.UpdatedRowSource = value; set => Inner.UpdatedRowSource = value;
} }
protected override DbParameterCollection DbParameterCollection => _inner.Parameters; protected override DbParameterCollection DbParameterCollection => Inner.Parameters;
protected override DbTransaction? DbTransaction protected override DbTransaction? DbTransaction
{ {
get => _inner.Transaction; get => Inner.Transaction;
set => _inner.Transaction = (NpgsqlTransaction?) value; set => Inner.Transaction = value switch
{
NpgsqlTransaction npg => npg,
PKTransaction pk => pk.Inner,
_ => throw new ArgumentException($"Can't convert input type {value?.GetType()} to NpgsqlTransaction")
};
} }
public override bool DesignTimeVisible public override bool DesignTimeVisible
{ {
get => _inner.DesignTimeVisible; get => Inner.DesignTimeVisible;
set => _inner.DesignTimeVisible = value; set => Inner.DesignTimeVisible = value;
} }
protected override DbConnection? DbConnection protected override DbConnection? DbConnection
{ {
get => _inner.Connection; get => Inner.Connection;
set => set =>
_inner.Connection = value switch Inner.Connection = value switch
{ {
NpgsqlConnection npg => npg, NpgsqlConnection npg => npg,
PKConnection pk => pk.Inner, PKConnection pk => pk.Inner,
_ => throw new ArgumentException($"Can't convert input type {value?.GetType()} to NpgsqlConnection") _ => throw new ArgumentException($"Can't convert input type {value?.GetType()} to NpgsqlConnection")
}; };
} }
public override int ExecuteNonQuery() => throw SyncError(nameof(ExecuteNonQuery));
public override object ExecuteScalar() => throw SyncError(nameof(ExecuteScalar));
protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => throw SyncError(nameof(ExecuteDbDataReader));
public override void Prepare() => throw SyncError(nameof(Prepare));
private async Task<T> LogQuery<T>(Task<T> task) private async Task<T> LogQuery<T>(Task<T> task)
{ {
var start = SystemClock.Instance.GetCurrentInstant(); var start = SystemClock.Instance.GetCurrentInstant();
@ -112,6 +119,6 @@ namespace PluralKit.Core
} }
} }
private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IPKCommand function {caller}!"); private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IDbCommand function {caller}!");
} }
} }

View File

@ -52,14 +52,16 @@ namespace PluralKit.Core
protected override DbCommand CreateDbCommand() => new PKCommand(Inner.CreateCommand(), this, _logger, _metrics); protected override DbCommand CreateDbCommand() => new PKCommand(Inner.CreateCommand(), this, _logger, _metrics);
public void ReloadTypes() => Inner.ReloadTypes(); public void ReloadTypes() => Inner.ReloadTypes();
public new async ValueTask<IPKTransaction> BeginTransactionAsync(IsolationLevel level, CancellationToken ct = default) => new PKTransaction(await Inner.BeginTransactionAsync(level, ct));
public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand) => Inner.BeginBinaryImport(copyFromCommand); public NpgsqlBinaryImporter BeginBinaryImport(string copyFromCommand) => Inner.BeginBinaryImport(copyFromCommand);
public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand) => Inner.BeginBinaryExport(copyToCommand); public NpgsqlBinaryExporter BeginBinaryExport(string copyToCommand) => Inner.BeginBinaryExport(copyToCommand);
public override void ChangeDatabase(string databaseName) => Inner.ChangeDatabase(databaseName); public override void ChangeDatabase(string databaseName) => Inner.ChangeDatabase(databaseName);
public override Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default) => Inner.ChangeDatabaseAsync(databaseName, ct); public override Task ChangeDatabaseAsync(string databaseName, CancellationToken ct = default) => Inner.ChangeDatabaseAsync(databaseName, ct);
protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw SyncError(nameof(BeginDbTransaction)); protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw SyncError(nameof(BeginDbTransaction));
protected override async ValueTask<DbTransaction> BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => await Inner.BeginTransactionAsync(level, ct); protected override async ValueTask<DbTransaction> BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => new PKTransaction(await Inner.BeginTransactionAsync(level, ct));
public override void Open() => throw SyncError(nameof(Open)); public override void Open() => throw SyncError(nameof(Open));
public override void Close() => throw SyncError(nameof(Close)); public override void Close() => throw SyncError(nameof(Close));
@ -102,6 +104,6 @@ namespace PluralKit.Core
_logger.Verbose("Closed database connection {ConnectionId} (open for {ConnectionDuration}), new connection count {ConnectionCount}", ConnectionId, duration, _countHolder.ConnectionCount); _logger.Verbose("Closed database connection {ConnectionId} (open for {ConnectionDuration}), new connection count {ConnectionCount}", ConnectionId, duration, _countHolder.ConnectionCount);
} }
private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IPKConnection function {caller}!"); private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IDbCommand function {caller}!");
} }
} }

View File

@ -0,0 +1,31 @@
using System;
using System.Data;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using Npgsql;
namespace PluralKit.Core
{
public class PKTransaction: DbTransaction, IPKTransaction
{
public NpgsqlTransaction Inner { get; }
public PKTransaction(NpgsqlTransaction inner)
{
Inner = inner;
}
public override void Commit() => throw SyncError(nameof(Commit));
public override Task CommitAsync(CancellationToken ct = default) => Inner.CommitAsync(ct);
public override void Rollback() => throw SyncError(nameof(Rollback));
public override Task RollbackAsync(CancellationToken ct = default) => Inner.RollbackAsync(ct);
protected override DbConnection DbConnection => Inner.Connection;
public override IsolationLevel IsolationLevel => Inner.IsolationLevel;
private static Exception SyncError(string caller) => throw new Exception($"Executed synchronous IDbTransaction function {caller}!");
}
}

View File

@ -21,12 +21,12 @@ namespace PluralKit.Core
{ {
private readonly int _systemId; private readonly int _systemId;
private readonly IPKConnection _conn; private readonly IPKConnection _conn;
private readonly DbTransaction _tx; private readonly IPKTransaction _tx;
private readonly Dictionary<string, int> _knownMembers = new Dictionary<string, int>(); private readonly Dictionary<string, int> _knownMembers = new Dictionary<string, int>();
private readonly Dictionary<string, PKMember> _existingMembersByHid = new Dictionary<string, PKMember>(); private readonly Dictionary<string, PKMember> _existingMembersByHid = new Dictionary<string, PKMember>();
private readonly Dictionary<string, PKMember> _existingMembersByName = new Dictionary<string, PKMember>(); private readonly Dictionary<string, PKMember> _existingMembersByName = new Dictionary<string, PKMember>();
private BulkImporter(int systemId, IPKConnection conn, DbTransaction tx) private BulkImporter(int systemId, IPKConnection conn, IPKTransaction tx)
{ {
_systemId = systemId; _systemId = systemId;
_conn = conn; _conn = conn;