diff --git a/PluralKit.Bot/Commands/Groups.cs b/PluralKit.Bot/Commands/Groups.cs index 81790c7a..bdf5b0fc 100644 --- a/PluralKit.Bot/Commands/Groups.cs +++ b/PluralKit.Bot/Commands/Groups.cs @@ -45,7 +45,7 @@ namespace PluralKit.Bot await using var conn = await _db.Obtain(); // Check group cap - var existingGroupCount = await conn.QuerySingleAsync("select count(*) from groups where system = @System", new { System = ctx.System.Id }); + var existingGroupCount = await _repo.GetSystemGroupCount(conn, ctx.System.Id); var groupLimit = ctx.System.GroupLimitOverride ?? Limits.MaxGroupCount; if (existingGroupCount >= groupLimit) throw new PKError($"System has reached the maximum number of groups ({groupLimit}). Please delete unused groups first in order to create new ones."); diff --git a/PluralKit.Core/Database/Repository/ModelRepository.System.cs b/PluralKit.Core/Database/Repository/ModelRepository.System.cs index 102722ec..3028793f 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.System.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.System.cs @@ -27,6 +27,9 @@ namespace PluralKit.Core public IAsyncEnumerable GetSystemMembers(IPKConnection conn, SystemId system) => conn.QueryStreamAsync("select * from members where system = @SystemID", new { SystemID = system }); + public IAsyncEnumerable GetSystemGroups(IPKConnection conn, SystemId system) => + conn.QueryStreamAsync("select * from groups where system = @System", new { System = system }); + public Task GetSystemMemberCount(IPKConnection conn, SystemId id, PrivacyLevel? privacyFilter = null) { var query = new StringBuilder("select count(*) from members where system = @Id"); @@ -35,6 +38,9 @@ namespace PluralKit.Core return conn.QuerySingleAsync(query.ToString(), new { Id = id }); } + public Task GetSystemGroupCount(IPKConnection conn, SystemId id) => + conn.QuerySingleAsync("select count(*) from groups where system = @System", new { System = id }); + public async Task CreateSystem(IPKConnection conn, string? systemName = null, IPKTransaction? tx = null) { var system = await conn.QuerySingleAsync( diff --git a/PluralKit.Core/Models/PKGroup.cs b/PluralKit.Core/Models/PKGroup.cs index 50baa216..b288c69a 100644 --- a/PluralKit.Core/Models/PKGroup.cs +++ b/PluralKit.Core/Models/PKGroup.cs @@ -1,6 +1,6 @@ using NodaTime; - +using Newtonsoft.Json.Linq; namespace PluralKit.Core { @@ -57,5 +57,39 @@ namespace PluralKit.Core public static string? IconFor(this PKGroup group, LookupContext ctx) => group.IconPrivacy.Get(ctx, group.Icon?.TryGetCleanCdnUrl()); + + public static JObject ToJson(this PKGroup group, LookupContext ctx, bool isExport = false) + { + var o = new JObject(); + + o.Add("id", group.Hid); + o.Add("name", group.Name); + o.Add("display_name", group.DisplayName); + o.Add("description", group.DescriptionPrivacy.Get(ctx, group.Description)); + o.Add("icon", group.Icon); + o.Add("banner", group.DescriptionPrivacy.Get(ctx, group.BannerImage)); + o.Add("color", group.Color); + + o.Add("created", group.Created.FormatExport()); + + if (isExport) + o.Add("members", new JArray()); + + if (ctx == LookupContext.ByOwner) + { + var p = new JObject(); + + p.Add("description_privacy", group.DescriptionPrivacy.ToJsonString()); + p.Add("icon_privacy", group.IconPrivacy.ToJsonString()); + p.Add("list_privacy", group.ListPrivacy.ToJsonString()); + p.Add("visibility", group.Visibility.ToJsonString()); + + o.Add("privacy", p); + } + else + o.Add("privacy", null); + + return o; + } } } \ No newline at end of file diff --git a/PluralKit.Core/Models/Patch/GroupPatch.cs b/PluralKit.Core/Models/Patch/GroupPatch.cs index 95f0116e..4ff0a401 100644 --- a/PluralKit.Core/Models/Patch/GroupPatch.cs +++ b/PluralKit.Core/Models/Patch/GroupPatch.cs @@ -1,6 +1,8 @@ #nullable enable using System.Text.RegularExpressions; +using Newtonsoft.Json.Linq; + namespace PluralKit.Core { public class GroupPatch: PatchObject @@ -40,6 +42,40 @@ namespace PluralKit.Core if (Color.Value != null && (!Regex.IsMatch(Color.Value, "^[0-9a-fA-F]{6}$"))) throw new ValidationError("color"); } +#nullable disable + public static GroupPatch FromJson(JObject o) + { + var patch = new GroupPatch(); + + if (o.ContainsKey("name") && o["name"].Type == JTokenType.Null) + throw new ValidationError("Group name can not be set to null."); + + if (o.ContainsKey("name")) patch.Name = o.Value("name"); + if (o.ContainsKey("display_name")) patch.DisplayName = o.Value("display_name").NullIfEmpty(); + if (o.ContainsKey("description")) patch.Description = o.Value("description").NullIfEmpty(); + if (o.ContainsKey("icon")) patch.Icon = o.Value("icon").NullIfEmpty(); + if (o.ContainsKey("banner")) patch.BannerImage = o.Value("banner").NullIfEmpty(); + if (o.ContainsKey("color")) patch.Color = o.Value("color").NullIfEmpty()?.ToLower(); + + if (o.ContainsKey("privacy") && o["privacy"].Type != JTokenType.Null) + { + var privacy = o.Value("privacy"); + + if (privacy.ContainsKey("description_privacy")) + patch.DescriptionPrivacy = privacy.ParsePrivacy("description_privacy"); + + if (privacy.ContainsKey("icon_privacy")) + patch.IconPrivacy = privacy.ParsePrivacy("icon_privacy"); + + if (privacy.ContainsKey("list_privacy")) + patch.ListPrivacy = privacy.ParsePrivacy("list_privacy"); + + if (privacy.ContainsKey("visibility")) + patch.Visibility = privacy.ParsePrivacy("visibility"); + } + + return patch; + } } } \ No newline at end of file diff --git a/PluralKit.Core/Services/DataFileService.cs b/PluralKit.Core/Services/DataFileService.cs index 0d5f3ff0..a7986787 100644 --- a/PluralKit.Core/Services/DataFileService.cs +++ b/PluralKit.Core/Services/DataFileService.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -42,6 +43,23 @@ namespace PluralKit.Core o.Add("accounts", new JArray((await _repo.GetSystemAccounts(conn, system.Id)).ToList())); o.Add("members", new JArray((await _repo.GetSystemMembers(conn, system.Id).ToListAsync()).Select(m => m.ToJson(LookupContext.ByOwner)))); + var groups = (await _repo.GetSystemGroups(conn, system.Id).ToListAsync()); + var j_groups = groups.Select(x => x.ToJson(LookupContext.ByOwner, isExport: true)).ToList(); + + if (groups.Count > 0) + { + var q = await conn.QueryAsync(@$"select groups.hid as group, members.hid as member from group_members + left join groups on groups.id = group_members.group_id + left join members on members.id = group_members.member_id + where group_members.group_id in ({string.Join(", ", groups.Select(x => x.Id.Value.ToString()))}) + "); + + foreach (var row in q) + ((JArray)j_groups.Find(x => x.Value("id") == row.Group)["members"]).Add(row.Member); + } + + o.Add("groups", new JArray(j_groups)); + var switches = new JArray(); var switchList = await _repo.GetPeriodFronters(conn, system.Id, null, Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant()); @@ -64,5 +82,12 @@ namespace PluralKit.Core return await BulkImporter.PerformImport(conn, tx, _repo, _logger, userId, system, importFile, confirmFunc); } + + } + + public class GroupMember + { + public string Group { get; set; } + public string Member { get; set; } } } \ No newline at end of file diff --git a/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs b/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs index 3055d669..da76294e 100644 --- a/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs +++ b/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs @@ -25,7 +25,12 @@ namespace PluralKit.Core private readonly Dictionary _existingMemberHids = new(); private readonly Dictionary _existingMemberNames = new(); - private readonly Dictionary _knownIdentifiers = new(); + private readonly Dictionary _knownMemberIdentifiers = new(); + + private readonly Dictionary _existingGroupHids = new(); + private readonly Dictionary _existingGroupNames = new(); + private readonly Dictionary _knownGroupIdentifiers = new(); + private ImportResultNew _result = new(); internal static async Task PerformImport(IPKConnection conn, IPKTransaction tx, ModelRepository repo, ILogger logger, @@ -58,6 +63,15 @@ namespace PluralKit.Core importer._existingMemberNames[m.Name] = m.Id; } + // same as above for groups + var groups = await conn.QueryAsync("select id, hid, name from groups where system = @System", + new { System = system.Id }); + foreach (var g in groups) + { + importer._existingGroupHids[g.Hid] = g.Id; + importer._existingGroupNames[g.Name] = g.Id; + } + try { if (importFile.ContainsKey("tuppers")) @@ -89,7 +103,14 @@ namespace PluralKit.Core return (null, false); } - private async Task AssertLimitNotReached(int newMembers) + private (GroupId?, bool) TryGetExistingGroup(string hid, string name) + { + if (_existingGroupHids.TryGetValue(hid, out var byHid)) return (byHid, true); + if (_existingGroupNames.TryGetValue(name, out var byName)) return (byName, false); + return (null, false); + } + + private async Task AssertMemberLimitNotReached(int newMembers) { var memberLimit = _system.MemberLimitOverride ?? Limits.MaxMemberCount; var existingMembers = await _repo.GetSystemMemberCount(_conn, _system.Id); @@ -97,6 +118,14 @@ namespace PluralKit.Core throw new ImportException($"Import would exceed the maximum number of members ({memberLimit})."); } + private async Task AssertGroupLimitNotReached(int newGroups) + { + var limit = _system.GroupLimitOverride ?? Limits.MaxGroupCount; + var existing = await _repo.GetSystemGroupCount(_conn, _system.Id); + if (existing + newGroups > limit) + throw new ImportException($"Import would exceed the maximum number of groups ({limit})."); + } + public async ValueTask DisposeAsync() { // try rolling back the transaction diff --git a/PluralKit.Core/Utils/BulkImporter/PluralKitImport.cs b/PluralKit.Core/Utils/BulkImporter/PluralKitImport.cs index 09c90b0d..94389102 100644 --- a/PluralKit.Core/Utils/BulkImporter/PluralKitImport.cs +++ b/PluralKit.Core/Utils/BulkImporter/PluralKitImport.cs @@ -32,6 +32,7 @@ namespace PluralKit.Core await _repo.UpdateSystem(_conn, _system.Id, patch, _tx); var members = importFile.Value("members"); + var groups = importFile.Value("groups"); var switches = importFile.Value("switches"); var newMembers = members.Count(m => @@ -39,12 +40,26 @@ namespace PluralKit.Core var (found, _) = TryGetExistingMember(m.Value("id"), m.Value("name")); return found == null; }); - await AssertLimitNotReached(newMembers); + await AssertMemberLimitNotReached(newMembers); + + if (groups != null) + { + var newGroups = groups.Count(g => + { + var (found, _) = TryGetExistingGroup(g.Value("id"), g.Value("name")); + return found == null; + }); + await AssertGroupLimitNotReached(newGroups); + } foreach (JObject member in members) await ImportMember(member); - if (switches.Any(sw => sw.Value("members").Any(m => !_knownIdentifiers.ContainsKey((string)m)))) + if (groups != null) + foreach (JObject group in groups) + await ImportGroup(group); + + if (switches.Any(sw => sw.Value("members").Any(m => !_knownMemberIdentifiers.ContainsKey((string)m)))) throw new ImportException("One or more switches include members that haven't been imported."); await ImportSwitches(switches); @@ -93,11 +108,76 @@ namespace PluralKit.Core memberId = newMember.Id; } - _knownIdentifiers[id] = memberId.Value; + _knownMemberIdentifiers[id] = memberId.Value; await _repo.UpdateMember(_conn, memberId.Value, patch, _tx); } + private async Task ImportGroup(JObject group) + { + var id = group.Value("id"); + var name = group.Value("name"); + + var (found, isHidExisting) = TryGetExistingGroup(id, name); + var isNewGroup = found == null; + var referenceName = isHidExisting ? id : name; + + _logger.Debug( + "Importing group with identifier {FileId} to system {System} (is creating new group? {IsCreatingNewGroup})", + referenceName, _system.Id, isNewGroup + ); + + var patch = GroupPatch.FromJson(group); + try + { + patch.AssertIsValid(); + } + catch (FieldTooLongError e) + { + throw new ImportException($"Field {e.Name} in group {referenceName} is too long ({e.ActualLength} > {e.MaxLength})."); + } + catch (ValidationError e) + { + throw new ImportException($"Field {e.Message} in group {referenceName} is invalid."); + } + + GroupId? groupId = found; + + if (isNewGroup) + { + var newGroup = await _repo.CreateGroup(_conn, _system.Id, patch.Name.Value, _tx); + groupId = newGroup.Id; + } + + _knownGroupIdentifiers[id] = groupId.Value; + + await _repo.UpdateGroup(_conn, groupId.Value, patch, _tx); + + var groupMembers = group.Value("members"); + var currentGroupMembers = (await _conn.QueryAsync( + "select member_id from group_members where group_id = @groupId", + new { groupId = groupId.Value } + )).ToList(); + + await using (var importer = _conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)")) + { + foreach (var memberIdentifier in groupMembers) + { + if (!_knownMemberIdentifiers.TryGetValue(memberIdentifier.ToString(), out var memberId)) + throw new Exception($"Attempted to import group member with member identifier {memberIdentifier} but could not find a recently imported member with this id!"); + + if (currentGroupMembers.Contains(memberId)) + continue; + + await importer.StartRowAsync(); + await importer.WriteAsync(groupId.Value.Value, NpgsqlDbType.Integer); + await importer.WriteAsync(memberId.Value, NpgsqlDbType.Integer); + } + + await importer.CompleteAsync(); + } + + } private async Task ImportSwitches(JArray switches) { var existingSwitches = (await _conn.QueryAsync("select * from switches where system = @System", new { System = _system.Id })).ToList(); @@ -154,7 +234,7 @@ namespace PluralKit.Core // We still assume timestamps are unique and non-duplicate, so: foreach (var memberIdentifier in switchMembers) { - if (!_knownIdentifiers.TryGetValue((string)memberIdentifier, out var memberId)) + if (!_knownMemberIdentifiers.TryGetValue((string)memberIdentifier, out var memberId)) throw new Exception($"Attempted to import switch with member identifier {memberIdentifier} but could not find an entry in the id map for this! :/"); await importer.StartRowAsync(); diff --git a/PluralKit.Core/Utils/BulkImporter/TupperboxImport.cs b/PluralKit.Core/Utils/BulkImporter/TupperboxImport.cs index 2023c7d4..df5137e7 100644 --- a/PluralKit.Core/Utils/BulkImporter/TupperboxImport.cs +++ b/PluralKit.Core/Utils/BulkImporter/TupperboxImport.cs @@ -15,7 +15,7 @@ namespace PluralKit.Core { var tuppers = importFile.Value("tuppers"); var newMembers = tuppers.Count(t => !_existingMemberNames.TryGetValue("name", out var memberId)); - await AssertLimitNotReached(newMembers); + await AssertMemberLimitNotReached(newMembers); string lastSetTag = null; bool multipleTags = false;