Refactor import/export database code
This commit is contained in:
		| @@ -7,7 +7,6 @@ using System.Threading.Tasks; | ||||
| using Newtonsoft.Json; | ||||
|  | ||||
| using NodaTime; | ||||
| using NodaTime.Text; | ||||
|  | ||||
| using Serilog; | ||||
|  | ||||
| @@ -16,11 +15,13 @@ namespace PluralKit.Core | ||||
|     public class DataFileService | ||||
|     { | ||||
|         private IDataStore _data; | ||||
|         private DbConnectionFactory _db; | ||||
|         private ILogger _logger; | ||||
|  | ||||
|         public DataFileService(ILogger logger, IDataStore data) | ||||
|         public DataFileService(ILogger logger, IDataStore data, DbConnectionFactory db) | ||||
|         { | ||||
|             _data = data; | ||||
|             _db = db; | ||||
|             _logger = logger.ForContext<DataFileService>(); | ||||
|         } | ||||
|  | ||||
| @@ -58,6 +59,7 @@ namespace PluralKit.Core | ||||
|  | ||||
|             return new DataFileSystem | ||||
|             { | ||||
|                 Version = 1, | ||||
|                 Id = system.Hid, | ||||
|                 Name = system.Name, | ||||
|                 Description = system.Description, | ||||
| @@ -71,23 +73,52 @@ namespace PluralKit.Core | ||||
|             }; | ||||
|         } | ||||
|  | ||||
|         private PKMember ConvertMember(PKSystem system, DataFileMember fileMember) | ||||
|         { | ||||
|             var newMember = new PKMember | ||||
|             { | ||||
|                 Hid = fileMember.Id, | ||||
|                 System = system.Id, | ||||
|                 Name = fileMember.Name, | ||||
|                 DisplayName = fileMember.DisplayName, | ||||
|                 Description = fileMember.Description, | ||||
|                 Color = fileMember.Color, | ||||
|                 Pronouns = fileMember.Pronouns, | ||||
|                 AvatarUrl = fileMember.AvatarUrl, | ||||
|                 KeepProxy = fileMember.KeepProxy, | ||||
|             }; | ||||
|  | ||||
|             if (fileMember.Prefix != null || fileMember.Suffix != null) | ||||
|                 newMember.ProxyTags = new List<ProxyTag> {new ProxyTag(fileMember.Prefix, fileMember.Suffix)}; | ||||
|             else | ||||
|                 // Ignore proxy tags where both prefix and suffix are set to null (would be invalid anyway) | ||||
|                 newMember.ProxyTags = (fileMember.ProxyTags ?? new ProxyTag[] { }).Where(tag => !tag.IsEmpty).ToList(); | ||||
|                  | ||||
|             if (fileMember.Birthday != null) | ||||
|             { | ||||
|                 var birthdayParse = DateTimeFormats.DateExportFormat.Parse(fileMember.Birthday); | ||||
|                 newMember.Birthday = birthdayParse.Success ? (LocalDate?)birthdayParse.Value : null; | ||||
|             } | ||||
|  | ||||
|             return newMember; | ||||
|         } | ||||
|          | ||||
|         public async Task<ImportResult> ImportSystem(DataFileSystem data, PKSystem system, ulong accountId) | ||||
|         { | ||||
|             // TODO: make atomic, somehow - we'd need to obtain one IDbConnection and reuse it | ||||
|             // which probably means refactoring SystemStore.Save and friends etc | ||||
|             var result = new ImportResult { | ||||
|                 AddedNames = new List<string>(), | ||||
|                 ModifiedNames = new List<string>(), | ||||
|                 System = system, | ||||
|                 Success = true // Assume success unless indicated otherwise | ||||
|             }; | ||||
|             var dataFileToMemberMapping = new Dictionary<string, PKMember>(); | ||||
|             var unmappedMembers = new List<DataFileMember>(); | ||||
|  | ||||
|              | ||||
|             // If we don't already have a system to save to, create one | ||||
|             if (system == null) | ||||
|                 system = await _data.CreateSystem(data.Name); | ||||
|             result.System = system; | ||||
|  | ||||
|             { | ||||
|                 system = result.System = await _data.CreateSystem(data.Name); | ||||
|                 await _data.AddAccount(system, accountId); | ||||
|             } | ||||
|              | ||||
|             // Apply system info | ||||
|             system.Name = data.Name; | ||||
|             if (data.Description != null) system.Description = data.Description; | ||||
| @@ -95,111 +126,53 @@ namespace PluralKit.Core | ||||
|             if (data.AvatarUrl != null) system.AvatarUrl = data.AvatarUrl; | ||||
|             if (data.TimeZone != null) system.UiTz = data.TimeZone ?? "UTC"; | ||||
|             await _data.SaveSystem(system); | ||||
|  | ||||
|             // Make sure to link the sender account, too | ||||
|             await _data.AddAccount(system, accountId); | ||||
|  | ||||
|             // Determine which members already exist and which ones need to be created | ||||
|             var membersByHid = new Dictionary<string, PKMember>(); | ||||
|             var membersByName = new Dictionary<string, PKMember>(); | ||||
|             await foreach (var member in _data.GetSystemMembers(system)) | ||||
|              | ||||
|             // -- Member/switch import -- | ||||
|             await using var conn = (PerformanceTrackingConnection) await _db.Obtain(); | ||||
|             await using (var imp = await BulkImporter.Begin(system, conn._impl)) | ||||
|             { | ||||
|                 membersByHid[member.Hid] = member; | ||||
|                 membersByName[member.Name] = member; | ||||
|             }  | ||||
|  | ||||
|             foreach (var d in data.Members) | ||||
|             { | ||||
|                 PKMember match = null; | ||||
|                 if (membersByHid.TryGetValue(d.Id, out var matchByHid)) match = matchByHid; // Try to look up the member with the given ID | ||||
|                 else if (membersByName.TryGetValue(d.Name, out var matchByName)) match = matchByName; // Try with the name instead | ||||
|                 // Tally up the members that didn't exist before, and check member count on import | ||||
|                 // If creating the unmatched members would put us over the member limit, abort before creating any members | ||||
|                 var memberCountBefore = await _data.GetSystemMemberCount(system, true); | ||||
|                 var membersToAdd = data.Members.Count(m => imp.IsNewMember(m.Id, m.Name)); | ||||
|                 if (memberCountBefore + membersToAdd > Limits.MaxMemberCount) | ||||
|                 { | ||||
|                     result.Success = false; | ||||
|                     result.Message = $"Import would exceed the maximum number of members ({Limits.MaxMemberCount})."; | ||||
|                     return result; | ||||
|                 } | ||||
|                  | ||||
|                 if (match != null) | ||||
|                 async Task DoImportMember(BulkImporter imp, DataFileMember fileMember) | ||||
|                 { | ||||
|                     dataFileToMemberMapping.Add(d.Id, match); // Relate the data file ID to the PKMember for importing switches | ||||
|                     result.ModifiedNames.Add(d.Name); | ||||
|                 }          | ||||
|                 else | ||||
|                 { | ||||
|                     unmappedMembers.Add(d); // Track members that weren't found so we can create them all | ||||
|                     result.AddedNames.Add(d.Name); | ||||
|                     var isCreatingNewMember = imp.IsNewMember(fileMember.Id, fileMember.Name); | ||||
|  | ||||
|                     // Use the file member's id as the "unique identifier" for the importing (actual value is irrelevant but needs to be consistent) | ||||
|                     _logger.Debug( | ||||
|                         "Importing member with identifier {FileId} to system {System} (is creating new member? {IsCreatingNewMember})", | ||||
|                         fileMember.Id, system.Id, isCreatingNewMember); | ||||
|                     var newMember = await imp.AddMember(fileMember.Id, ConvertMember(system, fileMember)); | ||||
|  | ||||
|                     if (isCreatingNewMember) | ||||
|                         result.AddedNames.Add(newMember.Name); | ||||
|                     else | ||||
|                         result.ModifiedNames.Add(newMember.Name); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // If creating the unmatched members would put us over the member limit, abort before creating any members | ||||
|             // new total: # in the system + (# in the file - # in the file that already exist) | ||||
|             if (data.Members.Count - dataFileToMemberMapping.Count + membersByHid.Count > Limits.MaxMemberCount) | ||||
|             { | ||||
|                 result.Success = false; | ||||
|                 result.Message = $"Import would exceed the maximum number of members ({Limits.MaxMemberCount})."; | ||||
|                 result.AddedNames.Clear(); | ||||
|                 result.ModifiedNames.Clear(); | ||||
|                 return result; | ||||
|             } | ||||
|  | ||||
|             // Create all unmapped members in one transaction | ||||
|             // These consist of members from another PluralKit system or another framework (e.g. Tupperbox) | ||||
|             var membersToCreate = new Dictionary<string, string>(); | ||||
|             unmappedMembers.ForEach(x => membersToCreate.Add(x.Id, x.Name)); | ||||
|             var newMembers = await _data.CreateMembersBulk(system, membersToCreate); | ||||
|             foreach (var member in newMembers) | ||||
|                 dataFileToMemberMapping.Add(member.Key, member.Value); | ||||
|  | ||||
|             // Update members with data file properties | ||||
|             // TODO: parallelize? | ||||
|             foreach (var dataMember in data.Members) | ||||
|             { | ||||
|                 dataFileToMemberMapping.TryGetValue(dataMember.Id, out PKMember member); | ||||
|                 if (member == null) | ||||
|                     continue; | ||||
|  | ||||
|                 // Apply member info | ||||
|                 member.Name = dataMember.Name; | ||||
|                 if (dataMember.DisplayName != null) member.DisplayName = dataMember.DisplayName; | ||||
|                 if (dataMember.Description != null) member.Description = dataMember.Description; | ||||
|                 if (dataMember.Color != null) member.Color = dataMember.Color.ToLower(); | ||||
|                 if (dataMember.AvatarUrl != null) member.AvatarUrl = dataMember.AvatarUrl; | ||||
|                 if (dataMember.Prefix != null || dataMember.Suffix != null) | ||||
|                  | ||||
|                 // Can't parallelize this because we can't reuse the same connection/tx inside the importer | ||||
|                 foreach (var m in data.Members)  | ||||
|                     await DoImportMember(imp, m); | ||||
|                  | ||||
|                 // Lastly, import the switches | ||||
|                 await imp.AddSwitches(data.Switches.Select(sw => new BulkImporter.SwitchInfo | ||||
|                 { | ||||
|                     member.ProxyTags = new List<ProxyTag> { new ProxyTag(dataMember.Prefix, dataMember.Suffix) }; | ||||
|                 } | ||||
|                 else | ||||
|                 { | ||||
|                     // Ignore proxy tags where both prefix and suffix are set to null (would be invalid anyway) | ||||
|                     member.ProxyTags = (dataMember.ProxyTags ?? new ProxyTag[] { }).Where(tag => !tag.IsEmpty).ToList(); | ||||
|                 } | ||||
|  | ||||
|                 member.KeepProxy = dataMember.KeepProxy; | ||||
|  | ||||
|                 if (dataMember.Birthday != null) | ||||
|                 { | ||||
|                     var birthdayParse = DateTimeFormats.DateExportFormat.Parse(dataMember.Birthday); | ||||
|                     member.Birthday = birthdayParse.Success ? (LocalDate?)birthdayParse.Value : null; | ||||
|                 } | ||||
|  | ||||
|                 await _data.SaveMember(member); | ||||
|                     Timestamp = DateTimeFormats.TimestampExportFormat.Parse(sw.Timestamp).Value, | ||||
|                     // "Members" here is from whatever ID the data file uses, which the bulk importer can map to the real IDs! :) | ||||
|                     MemberIdentifiers = sw.Members.ToList() | ||||
|                 }).ToList()); | ||||
|             } | ||||
|  | ||||
|             // Re-map the switch members in the likely case IDs have changed | ||||
|             var mappedSwitches = new List<ImportedSwitch>(); | ||||
|             foreach (var sw in data.Switches) | ||||
|             { | ||||
|                 var timestamp = InstantPattern.ExtendedIso.Parse(sw.Timestamp).Value; | ||||
|                 var swMembers = new List<PKMember>(); | ||||
|                 swMembers.AddRange(sw.Members.Select(x => | ||||
|                     dataFileToMemberMapping.FirstOrDefault(y => y.Key.Equals(x)).Value)); | ||||
|                 mappedSwitches.Add(new ImportedSwitch | ||||
|                 { | ||||
|                     Timestamp = timestamp, | ||||
|                     Members = swMembers | ||||
|                 }); | ||||
|             } | ||||
|             // Import switches | ||||
|             if (mappedSwitches.Any()) | ||||
|                 await _data.AddSwitchesBulk(system, mappedSwitches); | ||||
|  | ||||
|             _logger.Information("Imported system {System}", system.Hid); | ||||
|             return result; | ||||
|             return result;  | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -214,6 +187,7 @@ namespace PluralKit.Core | ||||
|  | ||||
|     public struct DataFileSystem | ||||
|     { | ||||
|         [JsonProperty("version")] public int Version; | ||||
|         [JsonProperty("id")] public string Id; | ||||
|         [JsonProperty("name")] public string Name; | ||||
|         [JsonProperty("description")] public string Description; | ||||
|   | ||||
| @@ -228,14 +228,6 @@ namespace PluralKit.Core { | ||||
|         /// <returns>The created system model.</returns> | ||||
|         Task<PKMember> CreateMember(PKSystem system, string name); | ||||
|          | ||||
|         /// <summary> | ||||
|         /// Creates multiple members, auto-generating each corresponding ID. | ||||
|         /// </summary> | ||||
|         /// <param name="system">The system to create the member in.</param> | ||||
|         /// <param name="memberNames">A dictionary containing a mapping from an arbitrary key to the member's name.</param> | ||||
|         /// <returns>A dictionary containing the resulting member structs, each mapped to the key given in the argument dictionary.</returns> | ||||
|         Task<Dictionary<string, PKMember>> CreateMembersBulk(PKSystem system, Dictionary<string, string> memberNames); | ||||
|          | ||||
|         /// <summary> | ||||
|         /// Saves the information within the given <see cref="PKMember"/> struct to the data store. | ||||
|         /// </summary> | ||||
| @@ -357,14 +349,7 @@ namespace PluralKit.Core { | ||||
|         /// </summary> | ||||
|         /// <exception>Throws an exception (TODO: which?) if any of the members are not in the given system.</exception> | ||||
|         Task AddSwitch(PKSystem system, IEnumerable<PKMember> switchMembers); | ||||
|          | ||||
|         /// <summary> | ||||
|         /// Registers switches in bulk. | ||||
|         /// </summary> | ||||
|         /// <param name="switches">A list of switch structs, each containing a timestamp and a list of members.</param> | ||||
|         /// <exception>Throws an exception (TODO: which?) if any of the given members are not in the given system.</exception> | ||||
|         Task AddSwitchesBulk(PKSystem system, IEnumerable<ImportedSwitch> switches); | ||||
|          | ||||
|  | ||||
|         /// <summary> | ||||
|         /// Updates the timestamp of a given switch.  | ||||
|         /// </summary> | ||||
|   | ||||
| @@ -174,38 +174,6 @@ namespace PluralKit.Core { | ||||
|             return member; | ||||
|         } | ||||
|  | ||||
|         public async Task<Dictionary<string,PKMember>> CreateMembersBulk(PKSystem system, Dictionary<string,string> names) | ||||
|         { | ||||
|             using (var conn = await _conn.Obtain()) | ||||
|             using (var tx = conn.BeginTransaction()) | ||||
|             { | ||||
|                 var results = new Dictionary<string, PKMember>(); | ||||
|                 foreach (var name in names) | ||||
|                 { | ||||
|                     string hid; | ||||
|                     do | ||||
|                     { | ||||
|                         hid = await conn.QuerySingleOrDefaultAsync<string>("SELECT @Hid WHERE NOT EXISTS (SELECT id FROM members WHERE hid = @Hid LIMIT 1)", new | ||||
|                         { | ||||
|                             Hid = StringUtils.GenerateHid() | ||||
|                         }); | ||||
|                     } while (hid == null); | ||||
|                     var member = await conn.QuerySingleAsync<PKMember>("INSERT INTO members (hid, system, name) VALUES (@Hid, @SystemId, @Name) RETURNING *", new | ||||
|                     { | ||||
|                         Hid = hid, | ||||
|                         SystemID = system.Id, | ||||
|                         Name = name.Value | ||||
|                     }); | ||||
|                     results.Add(name.Key, member); | ||||
|                 } | ||||
|  | ||||
|                 tx.Commit(); | ||||
|                 _logger.Information("Created {MemberCount} members for system {SystemID}", names.Count(), system.Hid); | ||||
|                 await _cache.InvalidateSystem(system); | ||||
|                 return results; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         public async Task<PKMember> GetMemberById(int id) { | ||||
|             using (var conn = await _conn.Obtain()) | ||||
|                 return await conn.QuerySingleOrDefaultAsync<PKMember>("select * from members where id = @Id", new { Id = id }); | ||||
| @@ -439,79 +407,6 @@ namespace PluralKit.Core { | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         public async Task AddSwitchesBulk(PKSystem system, IEnumerable<ImportedSwitch> switches) | ||||
|         { | ||||
|             // Read existing switches to enforce unique timestamps | ||||
|             var priorSwitches = new List<PKSwitch>(); | ||||
|             await foreach (var sw in GetSwitches(system)) priorSwitches.Add(sw); | ||||
|              | ||||
|             var lastSwitchId = priorSwitches.Any() | ||||
|                 ? priorSwitches.Max(x => x.Id) | ||||
|                 : 0; | ||||
|              | ||||
|             using (var conn = (PerformanceTrackingConnection) await _conn.Obtain()) | ||||
|             { | ||||
|                 using (var tx = conn.BeginTransaction()) | ||||
|                 { | ||||
|                     // Import switches in bulk | ||||
|                     using (var importer = conn.BeginBinaryImport("COPY switches (system, timestamp) FROM STDIN (FORMAT BINARY)")) | ||||
|                     { | ||||
|                         foreach (var sw in switches) | ||||
|                         { | ||||
|                             // If there's already a switch at this time, move on | ||||
|                             if (priorSwitches.Any(x => x.Timestamp.Equals(sw.Timestamp))) | ||||
|                                 continue; | ||||
|  | ||||
|                             // Otherwise, add it to the importer | ||||
|                             importer.StartRow(); | ||||
|                             importer.Write(system.Id, NpgsqlTypes.NpgsqlDbType.Integer); | ||||
|                             importer.Write(sw.Timestamp, NpgsqlTypes.NpgsqlDbType.Timestamp); | ||||
|                         } | ||||
|                         importer.Complete(); // Commits the copy operation so dispose won't roll it back | ||||
|                     } | ||||
|  | ||||
|                     // Get all switches that were created above and don't have members for ID lookup | ||||
|                     var switchesWithoutMembers = | ||||
|                         await conn.QueryAsync<PKSwitch>(@" | ||||
|                         SELECT switches.* | ||||
|                         FROM switches | ||||
|                         LEFT JOIN switch_members | ||||
|                         ON switch_members.switch = switches.id | ||||
|                         WHERE switches.id > @LastSwitchId | ||||
|                         AND switches.system = @System | ||||
|                         AND switch_members.id IS NULL", new { LastSwitchId = lastSwitchId, System = system.Id }); | ||||
|  | ||||
|                     // Import switch_members in bulk | ||||
|                     using (var importer = conn.BeginBinaryImport("COPY switch_members (switch, member) FROM STDIN (FORMAT BINARY)")) | ||||
|                     { | ||||
|                         // Iterate over the switches we created above and set their members | ||||
|                         foreach (var pkSwitch in switchesWithoutMembers) | ||||
|                         { | ||||
|                             // If this isn't in our import set, move on | ||||
|                             var sw = switches.Select(x => (ImportedSwitch?) x).FirstOrDefault(x => x.Value.Timestamp.Equals(pkSwitch.Timestamp)); | ||||
|                             if (sw == null) | ||||
|                                 continue; | ||||
|  | ||||
|                             // Loop through associated members to add each to the switch | ||||
|                             foreach (var m in sw.Value.Members) | ||||
|                             { | ||||
|                                 // Skip switch-outs - these don't have switch_members | ||||
|                                 if (m == null) | ||||
|                                     continue; | ||||
|                                 importer.StartRow(); | ||||
|                                 importer.Write(pkSwitch.Id, NpgsqlTypes.NpgsqlDbType.Integer); | ||||
|                                 importer.Write(m.Id, NpgsqlTypes.NpgsqlDbType.Integer); | ||||
|                             } | ||||
|                         } | ||||
|                         importer.Complete(); // Commits the copy operation so dispose won't roll it back | ||||
|                     } | ||||
|                     tx.Commit(); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             _logger.Information("Completed bulk import of switches for system {0}", system.Hid); | ||||
|         } | ||||
|          | ||||
|         public IAsyncEnumerable<PKSwitch> GetSwitches(PKSystem system) | ||||
|         { | ||||
|             // TODO: refactor the PKSwitch data structure to somehow include a hydrated member list | ||||
|   | ||||
							
								
								
									
										209
									
								
								PluralKit.Core/Utils/BulkImporter.cs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										209
									
								
								PluralKit.Core/Utils/BulkImporter.cs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,209 @@ | ||||
| #nullable enable | ||||
| using System; | ||||
| using System.Collections.Generic; | ||||
| using System.Collections.Immutable; | ||||
| using System.Linq; | ||||
| using System.Threading.Tasks; | ||||
|  | ||||
| using Dapper; | ||||
|  | ||||
| using NodaTime; | ||||
|  | ||||
| using Npgsql; | ||||
|  | ||||
| using NpgsqlTypes; | ||||
|  | ||||
| namespace PluralKit.Core | ||||
| { | ||||
|     public class BulkImporter: IAsyncDisposable | ||||
|     { | ||||
|         private readonly int _systemId; | ||||
|         private readonly NpgsqlConnection _conn; | ||||
|         private readonly NpgsqlTransaction _tx; | ||||
|         private readonly Dictionary<string, int> _knownMembers = new Dictionary<string, int>(); | ||||
|         private readonly Dictionary<string, PKMember> _existingMembersByHid = new Dictionary<string, PKMember>(); | ||||
|         private readonly Dictionary<string, PKMember> _existingMembersByName = new Dictionary<string, PKMember>(); | ||||
|  | ||||
|         private BulkImporter(int systemId, NpgsqlConnection conn, NpgsqlTransaction tx) | ||||
|         { | ||||
|             _systemId = systemId; | ||||
|             _conn = conn; | ||||
|             _tx = tx; | ||||
|         } | ||||
|  | ||||
|         public static async Task<BulkImporter> Begin(PKSystem system, NpgsqlConnection conn) | ||||
|         { | ||||
|             var tx = await conn.BeginTransactionAsync(); | ||||
|             var importer = new BulkImporter(system.Id, conn, tx); | ||||
|             await importer.Begin(); | ||||
|             return importer; | ||||
|         } | ||||
|  | ||||
|         public async Task Begin() | ||||
|         { | ||||
|             // Fetch all members in the system and log their names and hids | ||||
|             var members = await _conn.QueryAsync<PKMember>("select id, hid, name from members where system = @System", | ||||
|                 new {System = _systemId}); | ||||
|             foreach (var m in members) | ||||
|             { | ||||
|                 _existingMembersByHid[m.Hid] = m; | ||||
|                 _existingMembersByName[m.Name] = m; | ||||
|             } | ||||
|         } | ||||
|          | ||||
|         /// <summary> | ||||
|         /// Checks whether trying to add a member with the given hid and name would result in creating a new member (as opposed to just updating one). | ||||
|         /// </summary> | ||||
|         public bool IsNewMember(string hid, string name) => FindExistingMemberInSystem(hid, name) == null; | ||||
|  | ||||
|         /// <summary> | ||||
|         /// Imports a member into the database | ||||
|         /// </summary> | ||||
|         /// <remarks>If an existing member exists in this system that matches this member in either HID or name, it'll overlay the member information on top of this instead.</remarks> | ||||
|         /// <param name="identifier">An opaque identifier string that refers to this member regardless of source. Is used when importing switches. Value is irrelevant, but should be consistent with the same member later.</param> | ||||
|         /// <param name="member">A member struct containing the data to apply to this member. Null fields will be ignored.</param> | ||||
|         /// <returns>The inserted member object, which may or may not share an ID or HID with the input member.</returns> | ||||
|         public async Task<PKMember> AddMember(string identifier, PKMember member) | ||||
|         { | ||||
|             // See if we can find a member that matches this one | ||||
|             // if not, roll a new hid and we'll insert one with that | ||||
|             // (we can't trust the hid given in the member, it might let us overwrite another system's members) | ||||
|             var existingMember = FindExistingMemberInSystem(member.Hid, member.Name);  | ||||
|             string newHid = existingMember?.Hid ?? await FindFreeHid(); | ||||
|  | ||||
|             // Upsert member data and return the ID | ||||
|             QueryBuilder qb = QueryBuilder.Upsert("members", "hid") | ||||
|                 .Constant("hid", "@Hid") | ||||
|                 .Constant("system", "@System") | ||||
|                 .Variable("name", "@Name") | ||||
|                 .Variable("keep_proxy", "@KeepProxy"); | ||||
|  | ||||
|             if (member.DisplayName != null) qb.Variable("display_name", "@DisplayName"); | ||||
|             if (member.Description != null) qb.Variable("description", "@Description"); | ||||
|             if (member.Color != null) qb.Variable("color", "@Color"); | ||||
|             if (member.AvatarUrl != null) qb.Variable("avatar_url", "@AvatarUrl"); | ||||
|             if (member.ProxyTags != null) qb.Variable("proxy_tags", "@ProxyTags"); | ||||
|             if (member.Birthday != null) qb.Variable("birthday", "@Birthday"); | ||||
|  | ||||
|             var newMember = await _conn.QueryFirstAsync<PKMember>(qb.Build("returning *"), | ||||
|                 new | ||||
|                 { | ||||
|                     Hid = newHid, | ||||
|                     System = _systemId, | ||||
|                     member.Name, | ||||
|                     member.DisplayName, | ||||
|                     member.Description, | ||||
|                     member.Color, | ||||
|                     member.AvatarUrl, | ||||
|                     member.KeepProxy, | ||||
|                     member.ProxyTags, | ||||
|                     member.Birthday | ||||
|                 }); | ||||
|  | ||||
|             // Log this member ID by the given identifier | ||||
|             _knownMembers[identifier] = newMember.Id; | ||||
|             return newMember; | ||||
|         } | ||||
|  | ||||
|         private PKMember? FindExistingMemberInSystem(string hid, string name) | ||||
|         { | ||||
|             if (_existingMembersByHid.TryGetValue(hid, out var byHid)) return byHid; | ||||
|             if (_existingMembersByName.TryGetValue(name, out var byName)) return byName; | ||||
|             return null; | ||||
|         } | ||||
|  | ||||
|         private async Task<string> FindFreeHid() | ||||
|         { | ||||
|             string hid; | ||||
|             do | ||||
|             { | ||||
|                 hid = await _conn.QuerySingleOrDefaultAsync<string>( | ||||
|                     "select @Hid where not exists (select id from members where hid = @Hid)", | ||||
|                     new {Hid = StringUtils.GenerateHid()}); | ||||
|             } while (hid == null); | ||||
|  | ||||
|             return hid; | ||||
|         } | ||||
|  | ||||
|         /// <summary> | ||||
|         /// Register switches in bulk. | ||||
|         /// </summary> | ||||
|         /// <remarks>This function assumes there are no duplicate switches (ie. switches with the same timestamp).</remarks> | ||||
|         public async Task AddSwitches(IReadOnlyCollection<SwitchInfo> switches) | ||||
|         { | ||||
|             // Ensure we're aware of all the members we're trying to import from | ||||
|             if (!switches.All(sw => sw.MemberIdentifiers.All(m => _knownMembers.ContainsKey(m)))) | ||||
|                 throw new ArgumentException("One or more switch members haven't been added using this importer"); | ||||
|              | ||||
|             // Fetch the existing switches in the database so we can avoid duplicates | ||||
|             var existingSwitches = (await _conn.QueryAsync<PKSwitch>("select * from switches where system = @System", new {System = _systemId})).ToList(); | ||||
|             var existingTimestamps = existingSwitches.Select(sw => sw.Timestamp).ToImmutableHashSet(); | ||||
|             var lastSwitchId = existingSwitches.Count != 0 ? existingSwitches.Select(sw => sw.Id).Max() : -1; | ||||
|  | ||||
|             // Import switch definitions | ||||
|             var importedSwitches = new Dictionary<Instant, SwitchInfo>(); | ||||
|             await using (var importer = _conn.BeginBinaryImport("copy switches (system, timestamp) from stdin (format binary)")) | ||||
|             { | ||||
|                 foreach (var sw in switches) | ||||
|                 { | ||||
|                     // Don't import duplicate switches | ||||
|                     if (existingTimestamps.Contains(sw.Timestamp)) continue; | ||||
|                      | ||||
|                     // Otherwise, write to importer | ||||
|                     await importer.StartRowAsync(); | ||||
|                     await importer.WriteAsync(_systemId, NpgsqlDbType.Integer); | ||||
|                     await importer.WriteAsync(sw.Timestamp, NpgsqlDbType.Timestamp); | ||||
|                      | ||||
|                     // Note that we've imported a switch with this timestamp | ||||
|                     importedSwitches[sw.Timestamp] = sw; | ||||
|                 } | ||||
|  | ||||
|                 // Commit the import | ||||
|                 await importer.CompleteAsync(); | ||||
|             } | ||||
|              | ||||
|             // Now, fetch all the switches we just added (so, now we get their IDs too) | ||||
|             // IDs are sequential, so any ID in this system, with a switch ID > the last max, will be one we just added | ||||
|             var justAddedSwitches = await _conn.QueryAsync<PKSwitch>( | ||||
|                 "select * from switches where system = @System and id > @LastSwitchId", | ||||
|                 new {System = _systemId, LastSwitchId = lastSwitchId}); | ||||
|              | ||||
|             // Lastly, import the switch members | ||||
|             await using (var importer = _conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)")) | ||||
|             { | ||||
|                 foreach (var justAddedSwitch in justAddedSwitches) | ||||
|                 { | ||||
|                     if (!importedSwitches.TryGetValue(justAddedSwitch.Timestamp, out var switchInfo)) | ||||
|                         throw new Exception($"Found 'just-added' switch (by ID) with timestamp {justAddedSwitch.Timestamp}, but this did not correspond to a timestamp we just added a switch entry of! :/"); | ||||
|                      | ||||
|                     // We still assume timestamps are unique and non-duplicate, so: | ||||
|                     var members = switchInfo.MemberIdentifiers; | ||||
|                     foreach (var memberIdentifier in members) | ||||
|                     { | ||||
|                         if (!_knownMembers.TryGetValue(memberIdentifier, out var memberId)) | ||||
|                             throw new Exception($"Attempted to import switch with member identifier {memberIdentifier} but could not find an entry in the id map for this! :/"); | ||||
|                          | ||||
|                         await importer.StartRowAsync(); | ||||
|                         await importer.WriteAsync(justAddedSwitch.Id, NpgsqlDbType.Integer); | ||||
|                         await importer.WriteAsync(memberId, NpgsqlDbType.Integer); | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 await importer.CompleteAsync(); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         public struct SwitchInfo | ||||
|         { | ||||
|             public Instant Timestamp; | ||||
|              | ||||
|             /// <summary> | ||||
|             /// An ordered list of "member identifiers" matching with the identifier parameter passed to <see cref="BulkImporter.AddMember"/>. | ||||
|             /// </summary> | ||||
|             public IReadOnlyList<string> MemberIdentifiers; | ||||
|         } | ||||
|  | ||||
|         public async ValueTask DisposeAsync() =>  | ||||
|             await _tx.CommitAsync(); | ||||
|     } | ||||
| } | ||||
| @@ -155,7 +155,7 @@ namespace PluralKit.Core | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     public class PerformanceTrackingConnection: IDbConnection | ||||
|     public class PerformanceTrackingConnection: IDbConnection, IAsyncDisposable | ||||
|     { | ||||
|         // Simple delegation of everything. | ||||
|         internal NpgsqlConnection _impl; | ||||
| @@ -226,6 +226,7 @@ namespace PluralKit.Core | ||||
|         public string Database => _impl.Database; | ||||
|  | ||||
|         public ConnectionState State => _impl.State; | ||||
|         public ValueTask DisposeAsync() => _impl.DisposeAsync(); | ||||
|     } | ||||
|  | ||||
|     public class DbConnectionCountHolder | ||||
|   | ||||
							
								
								
									
										88
									
								
								PluralKit.Core/Utils/QueryBuilder.cs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								PluralKit.Core/Utils/QueryBuilder.cs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,88 @@ | ||||
| #nullable enable | ||||
| using System; | ||||
| using System.Text; | ||||
|  | ||||
| namespace PluralKit.Core | ||||
| { | ||||
|     public class QueryBuilder | ||||
|     { | ||||
|         private readonly string? _conflictField; | ||||
|         private readonly string? _condition; | ||||
|         private StringBuilder _insertFragment = new StringBuilder(); | ||||
|         private StringBuilder _valuesFragment = new StringBuilder(); | ||||
|         private StringBuilder _updateFragment = new StringBuilder(); | ||||
|         private bool _firstInsert = true; | ||||
|         private bool _firstUpdate = true; | ||||
|         public QueryType Type { get; } | ||||
|         public string Table { get; } | ||||
|  | ||||
|         private QueryBuilder(QueryType type, string table, string? conflictField, string? condition) | ||||
|         { | ||||
|             Type = type; | ||||
|             Table = table; | ||||
|             _conflictField = conflictField; | ||||
|             _condition = condition; | ||||
|         } | ||||
|          | ||||
|         public static QueryBuilder Insert(string table) => new QueryBuilder(QueryType.Insert, table,  null, null); | ||||
|         public static QueryBuilder Update(string table, string condition) => new QueryBuilder(QueryType.Update, table, null, condition); | ||||
|         public static QueryBuilder Upsert(string table, string conflictField) => new QueryBuilder(QueryType.Upsert, table, conflictField, null); | ||||
|  | ||||
|         public QueryBuilder Constant(string fieldName, string paramName) | ||||
|         { | ||||
|             if (_firstInsert) _firstInsert = false; | ||||
|             else  | ||||
|             { | ||||
|                 _insertFragment.Append(", "); | ||||
|                 _valuesFragment.Append(", "); | ||||
|             } | ||||
|              | ||||
|             _insertFragment.Append(fieldName); | ||||
|             _valuesFragment.Append(paramName); | ||||
|             return this; | ||||
|         } | ||||
|          | ||||
|         public QueryBuilder Variable(string fieldName, string paramName) | ||||
|         { | ||||
|             Constant(fieldName, paramName); | ||||
|              | ||||
|             if (_firstUpdate) _firstUpdate = false; | ||||
|             else _updateFragment.Append(", "); | ||||
|              | ||||
|             _updateFragment.Append(fieldName); | ||||
|             _updateFragment.Append("="); | ||||
|             _updateFragment.Append(paramName); | ||||
|             return this; | ||||
|         } | ||||
|  | ||||
|         public string Build(string? suffix = null) | ||||
|         { | ||||
|             if (_firstInsert) | ||||
|                 throw new ArgumentException("No fields have been added to the query."); | ||||
|              | ||||
|             StringBuilder query = new StringBuilder(Type switch | ||||
|             { | ||||
|                 QueryType.Insert => $"insert into {Table} ({_insertFragment}) values ({_valuesFragment})", | ||||
|                 QueryType.Upsert => $"insert into {Table} ({_insertFragment}) values ({_valuesFragment}) on conflict ({_conflictField}) do update set {_updateFragment}", | ||||
|                 QueryType.Update => $"update {Table} set {_updateFragment}", | ||||
|                 _ => throw new ArgumentOutOfRangeException($"Unknown query type {Type}") | ||||
|             }); | ||||
|  | ||||
|             if (Type == QueryType.Update && _condition != null) | ||||
|                 query.Append($" where {_condition}"); | ||||
|              | ||||
|             if (suffix != null) | ||||
|                 query.Append($" {suffix}"); | ||||
|             query.Append(";"); | ||||
|  | ||||
|             return query.ToString(); | ||||
|         } | ||||
|  | ||||
|         public enum QueryType | ||||
|         { | ||||
|             Insert, | ||||
|             Update, | ||||
|             Upsert | ||||
|         } | ||||
|     } | ||||
| } | ||||
		Reference in New Issue
	
	Block a user