diff --git a/PluralKit.Core/Database/Database.cs b/PluralKit.Core/Database/Database.cs index de98f468..679674b3 100644 --- a/PluralKit.Core/Database/Database.cs +++ b/PluralKit.Core/Database/Database.cs @@ -186,6 +186,16 @@ namespace PluralKit.Core return await conn.ExecuteAsync(query.Sql + $" {extraSql}", query.NamedBindings); } + public async Task ExecuteQuery(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = "") + { + if (conn == null) + return await ExecuteQuery(q, extraSql, queryName); + + var query = _compiler.Compile(q); + using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName))) + return await conn.ExecuteAsync(query.Sql + $" {extraSql}", query.NamedBindings); + } + public async Task QueryFirst(Query q, string extraSql = "", [CallerMemberName] string queryName = "") { var query = _compiler.Compile(q); diff --git a/PluralKit.Core/Database/IDatabase.cs b/PluralKit.Core/Database/IDatabase.cs index 6c48d6be..ead9458c 100644 --- a/PluralKit.Core/Database/IDatabase.cs +++ b/PluralKit.Core/Database/IDatabase.cs @@ -15,6 +15,7 @@ namespace PluralKit.Core Task Execute(Func> func); IAsyncEnumerable Execute(Func> func); Task ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = ""); + Task ExecuteQuery(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = ""); Task QueryFirst(Query q, string extraSql = "", [CallerMemberName] string queryName = ""); Task QueryFirst(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = ""); Task> Query(Query q, [CallerMemberName] string queryName = ""); diff --git a/PluralKit.Core/Database/Repository/ModelRepository.System.cs b/PluralKit.Core/Database/Repository/ModelRepository.System.cs index a9a48e28..805892c9 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.System.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.System.cs @@ -88,7 +88,7 @@ namespace PluralKit.Core return _db.QueryFirst(conn, query, extraSql: "returning *"); } - public Task AddAccount(SystemId system, ulong accountId) + public Task AddAccount(SystemId system, ulong accountId, IPKConnection? conn = null) { // We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent // This is used in import/export, although the pk;link command checks for this case beforehand @@ -100,7 +100,7 @@ namespace PluralKit.Core }); _logger.Information("Linked account {UserId} to {SystemId}", accountId, system); - return _db.ExecuteQuery(query, extraSql: "on conflict do nothing"); + return _db.ExecuteQuery(conn, query, extraSql: "on conflict do nothing"); } public async Task RemoveAccount(SystemId system, ulong accountId) diff --git a/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs b/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs index 5ad0297e..c7bae342 100644 --- a/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs +++ b/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs @@ -49,7 +49,7 @@ namespace PluralKit.Core if (system == null) { system = await repo.CreateSystem(null, importer._conn); - await repo.AddAccount(system.Id, userId); + await repo.AddAccount(system.Id, userId, importer._conn); importer._result.CreatedSystem = system.Hid; importer._system = system; }