feat: refactor external input handling code

- refactor import/export code
- make import/export use the same JSON parsing as API
- make Patch.AssertIsValid actually useful
This commit is contained in:
spiral 2021-08-25 21:43:31 -04:00
parent f912805ecc
commit 4b944e2b20
No known key found for this signature in database
GPG Key ID: A6059F0CA0E1BD31
18 changed files with 619 additions and 694 deletions

View File

@ -34,7 +34,7 @@ namespace PluralKit.API
var member = await _db.Execute(conn => _repo.GetMemberByHid(conn, hid));
if (member == null) return NotFound("Member not found.");
return Ok(member.ToJson(User.ContextFor(member)));
return Ok(member.ToJson(User.ContextFor(member), needsLegacyProxyTags: true));
}
[HttpPost]
@ -62,14 +62,14 @@ namespace PluralKit.API
try
{
patch = MemberPatch.FromJSON(properties);
patch.CheckIsValid();
patch.AssertIsValid();
}
catch (JsonModelParseError e)
catch (FieldTooLongError e)
{
await tx.RollbackAsync();
return BadRequest(e.Message);
}
catch (InvalidPatchException e)
catch (ValidationError e)
{
await tx.RollbackAsync();
return BadRequest($"Request field '{e.Message}' is invalid.");
@ -77,7 +77,7 @@ namespace PluralKit.API
member = await _repo.UpdateMember(conn, member.Id, patch, transaction: tx);
await tx.CommitAsync();
return Ok(member.ToJson(User.ContextFor(member)));
return Ok(member.ToJson(User.ContextFor(member), needsLegacyProxyTags: true));
}
[HttpPatch("{hid}")]
@ -96,19 +96,19 @@ namespace PluralKit.API
try
{
patch = MemberPatch.FromJSON(changes);
patch.CheckIsValid();
patch.AssertIsValid();
}
catch (JsonModelParseError e)
catch (FieldTooLongError e)
{
return BadRequest(e.Message);
}
catch (InvalidPatchException e)
catch (ValidationError e)
{
return BadRequest($"Request field '{e.Message}' is invalid.");
}
var newMember = await _repo.UpdateMember(conn, member.Id, patch);
return Ok(newMember.ToJson(User.ContextFor(newMember)));
return Ok(newMember.ToJson(User.ContextFor(newMember), needsLegacyProxyTags: true));
}
[HttpDelete("{hid}")]

View File

@ -49,7 +49,7 @@ namespace PluralKit.API
Id = msg.Message.Mid.ToString(),
Channel = msg.Message.Channel.ToString(),
Sender = msg.Message.Sender.ToString(),
Member = msg.Member.ToJson(User.ContextFor(msg.System)),
Member = msg.Member.ToJson(User.ContextFor(msg.System), needsLegacyProxyTags: true),
System = msg.System.ToJson(User.ContextFor(msg.System)),
Original = msg.Message.OriginalMid?.ToString()
};

View File

@ -80,7 +80,7 @@ namespace PluralKit.API
var members = _db.Execute(c => _repo.GetSystemMembers(c, system.Id));
return Ok(await members
.Where(m => m.MemberVisibility.CanAccess(User.ContextFor(system)))
.Select(m => m.ToJson(User.ContextFor(system)))
.Select(m => m.ToJson(User.ContextFor(system), needsLegacyProxyTags: true))
.ToListAsync());
}
@ -126,7 +126,7 @@ namespace PluralKit.API
return Ok(new FrontersReturn
{
Timestamp = sw.Timestamp,
Members = await members.Select(m => m.ToJson(User.ContextFor(system))).ToListAsync()
Members = await members.Select(m => m.ToJson(User.ContextFor(system), needsLegacyProxyTags: true)).ToListAsync()
});
}
@ -141,13 +141,13 @@ namespace PluralKit.API
try
{
patch = SystemPatch.FromJSON(changes);
patch.CheckIsValid();
patch.AssertIsValid();
}
catch (JsonModelParseError e)
catch (FieldTooLongError e)
{
return BadRequest(e.Message);
}
catch (InvalidPatchException e)
catch (ValidationError e)
{
return BadRequest($"Request field '{e.Message}' is invalid.");
}

View File

@ -41,91 +41,57 @@ namespace PluralKit.Bot
if (url == null) throw Errors.NoImportFilePassed;
await ctx.BusyIndicator(async () =>
{
HttpResponseMessage response;
{
JObject data;
try
{
response = await _client.GetAsync(url);
var response = await _client.GetAsync(url);
if (!response.IsSuccessStatusCode)
throw Errors.InvalidImportFile;
data = JsonConvert.DeserializeObject<JObject>(await response.Content.ReadAsStringAsync(), _settings);
if (data == null)
throw Errors.InvalidImportFile;
}
catch (InvalidOperationException)
{
// Invalid URL throws this, we just error back out
throw Errors.InvalidImportFile;
}
if (!response.IsSuccessStatusCode)
throw Errors.InvalidImportFile;
DataFileSystem data;
try
{
var json = JsonConvert.DeserializeObject<JObject>(await response.Content.ReadAsStringAsync(), _settings);
data = await LoadSystem(ctx, json);
}
catch (JsonException)
{
throw Errors.InvalidImportFile;
}
if (!data.Valid)
throw Errors.InvalidImportFile;
if (data.LinkedAccounts != null && !data.LinkedAccounts.Contains(ctx.Author.Id))
async Task ConfirmImport(string message)
{
var msg = $"{message}\n\nDo you want to proceed with the import?";
if (!await ctx.PromptYesNo(msg, "Proceed"))
throw Errors.ImportCancelled;
}
if (data.ContainsKey("accounts")
&& data.Value<JArray>("accounts").Type != JTokenType.Null
&& data.Value<JArray>("accounts").Contains((JToken) ctx.Author.Id.ToString()))
{
var msg = $"{Emojis.Warn} You seem to importing a system profile belonging to another account. Are you sure you want to proceed?";
if (!await ctx.PromptYesNo(msg, "Import")) throw Errors.ImportCancelled;
}
// If passed system is null, it'll create a new one
// (and that's okay!)
var result = await _dataFiles.ImportSystem(data, ctx.System, ctx.Author.Id);
var result = await _dataFiles.ImportSystem(ctx.Author.Id, ctx.System, data, ConfirmImport);
if (!result.Success)
await ctx.Reply($"{Emojis.Error} The provided system profile could not be imported. {result.Message}");
if (result.Message == null)
throw Errors.InvalidImportFile;
else
await ctx.Reply($"{Emojis.Error} The provided system profile could not be imported: {result.Message}");
else if (ctx.System == null)
{
// We didn't have a system prior to importing, so give them the new system's ID
await ctx.Reply($"{Emojis.Success} PluralKit has created a system for you based on the given file. Your system ID is `{result.System.Hid}`. Type `pk;system` for more information.");
}
await ctx.Reply($"{Emojis.Success} PluralKit has created a system for you based on the given file. Your system ID is `{result.CreatedSystem}`. Type `pk;system` for more information.");
else
{
// We already had a system, so show them what changed
await ctx.Reply($"{Emojis.Success} Updated {result.ModifiedNames.Count} members, created {result.AddedNames.Count} members. Type `pk;system list` to check!");
}
await ctx.Reply($"{Emojis.Success} Updated {result.Modified} members, created {result.Added} members. Type `pk;system list` to check!");
});
}
private async Task<DataFileSystem> LoadSystem(Context ctx, JObject json)
{
if (json.ContainsKey("tuppers"))
return await ImportFromTupperbox(ctx, json);
return json.ToObject<DataFileSystem>();
}
private async Task<DataFileSystem> ImportFromTupperbox(Context ctx, JObject json)
{
var tupperbox = json.ToObject<TupperboxProfile>();
if (!tupperbox.Valid)
throw Errors.InvalidImportFile;
var res = tupperbox.ToPluralKit();
if (res.HadGroups || res.HadIndividualTags)
{
var issueStr =
$"{Emojis.Warn} The following potential issues were detected converting your Tupperbox input file:";
if (res.HadGroups)
issueStr += "\n- PluralKit does not support member groups. Members will be imported without groups.";
if (res.HadIndividualTags)
issueStr += "\n- PluralKit does not support per-member system tags. Since you had multiple members with distinct tags, those tags will be applied to the members' *display names*/nicknames instead.";
var msg = $"{issueStr}\n\nDo you want to proceed with the import?";
if (!await ctx.PromptYesNo(msg, "Proceed"))
throw Errors.ImportCancelled;
}
return res.System;
}
public async Task Export(Context ctx)
{
ctx.CheckSystem();

View File

@ -35,22 +35,23 @@ namespace PluralKit.Core
return conn.QuerySingleAsync<int>(query.ToString(), new {Id = id});
}
public async Task<PKSystem> CreateSystem(IPKConnection conn, string? systemName = null)
public async Task<PKSystem> CreateSystem(IPKConnection conn, string? systemName = null, IPKTransaction? tx = null)
{
var system = await conn.QuerySingleAsync<PKSystem>(
"insert into systems (hid, name) values (find_free_system_hid(), @Name) returning *",
new {Name = systemName});
new {Name = systemName},
transaction: tx);
_logger.Information("Created {SystemId}", system.Id);
return system;
}
public Task<PKSystem> UpdateSystem(IPKConnection conn, SystemId id, SystemPatch patch)
public Task<PKSystem> UpdateSystem(IPKConnection conn, SystemId id, SystemPatch patch, IPKTransaction? tx = null)
{
_logger.Information("Updated {SystemId}: {@SystemPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("systems", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKSystem>(query, pms);
return conn.QueryFirstAsync<PKSystem>(query, pms, transaction: tx);
}
public async Task AddAccount(IPKConnection conn, SystemId system, ulong accountId)

View File

@ -102,7 +102,7 @@ namespace PluralKit.Core {
public static int MessageCountFor(this PKMember member, LookupContext ctx) =>
member.MetadataPrivacy.Get(ctx, member.MessageCount);
public static JObject ToJson(this PKMember member, LookupContext ctx)
public static JObject ToJson(this PKMember member, LookupContext ctx, bool needsLegacyProxyTags = false)
{
var includePrivacy = ctx == LookupContext.ByOwner;
@ -138,7 +138,7 @@ namespace PluralKit.Core {
o.Add("created", member.CreatedFor(ctx)?.FormatExport());
if (member.ProxyTags.Count > 0)
if (member.ProxyTags.Count > 0 && needsLegacyProxyTags)
{
// Legacy compatibility only, TODO: remove at some point
o.Add("prefix", member.ProxyTags?.FirstOrDefault().Prefix);

View File

@ -31,14 +31,14 @@ namespace PluralKit.Core
.With("list_privacy", ListPrivacy)
.With("visibility", Visibility);
public new void CheckIsValid()
public new void AssertIsValid()
{
if (Icon.Value != null && !MiscUtils.TryMatchUri(Icon.Value, out var avatarUri))
throw new InvalidPatchException("icon");
throw new ValidationError("icon");
if (BannerImage.Value != null && !MiscUtils.TryMatchUri(BannerImage.Value, out var bannerImage))
throw new InvalidPatchException("banner");
throw new ValidationError("banner");
if (Color.Value != null && (!Regex.IsMatch(Color.Value, "^[0-9a-fA-F]{6}$")))
throw new InvalidPatchException("color");
throw new ValidationError("color");
}
}

View File

@ -53,14 +53,28 @@ namespace PluralKit.Core
.With("avatar_privacy", AvatarPrivacy)
.With("metadata_privacy", MetadataPrivacy);
public new void CheckIsValid()
public new void AssertIsValid()
{
if (AvatarUrl.Value != null && !MiscUtils.TryMatchUri(AvatarUrl.Value, out var avatarUri))
throw new InvalidPatchException("avatar_url");
if (BannerImage.Value != null && !MiscUtils.TryMatchUri(BannerImage.Value, out var bannerImage))
throw new InvalidPatchException("banner");
if (Color.Value != null && (!Regex.IsMatch(Color.Value, "^[0-9a-fA-F]{6}$")))
throw new InvalidPatchException("color");
if (Name.IsPresent)
AssertValid(Name.Value, "display_name", Limits.MaxMemberNameLength);
if (DisplayName.Value != null)
AssertValid(DisplayName.Value, "display_name", Limits.MaxMemberNameLength);
if (AvatarUrl.Value != null)
AssertValid(AvatarUrl.Value, "avatar_url", Limits.MaxUriLength,
s => MiscUtils.TryMatchUri(s, out var avatarUri));
if (BannerImage.Value != null)
AssertValid(BannerImage.Value, "banner", Limits.MaxUriLength,
s => MiscUtils.TryMatchUri(s, out var bannerUri));
if (Color.Value != null)
AssertValid(Color.Value, "color", "^[0-9a-fA-F]{6}$");
if (Pronouns.Value != null)
AssertValid(Pronouns.Value, "pronouns", Limits.MaxPronounsLength);
if (Description.Value != null)
AssertValid(Description.Value, "description", Limits.MaxDescriptionLength);
if (ProxyTags.IsPresent && (ProxyTags.Value.Length > 100 ||
ProxyTags.Value.Any(tag => tag.ProxyString.IsLongerThan(100))))
// todo: have a better error for this
throw new ValidationError("proxy_tags");
}
#nullable disable
@ -70,13 +84,13 @@ namespace PluralKit.Core
var patch = new MemberPatch();
if (o.ContainsKey("name") && o["name"].Type == JTokenType.Null)
throw new JsonModelParseError("Member name can not be set to null.");
throw new ValidationError("Member name can not be set to null.");
if (o.ContainsKey("name")) patch.Name = o.Value<string>("name").BoundsCheckField(Limits.MaxMemberNameLength, "Member name");
if (o.ContainsKey("name")) patch.Name = o.Value<string>("name");
if (o.ContainsKey("color")) patch.Color = o.Value<string>("color").NullIfEmpty()?.ToLower();
if (o.ContainsKey("display_name")) patch.DisplayName = o.Value<string>("display_name").NullIfEmpty().BoundsCheckField(Limits.MaxMemberNameLength, "Member display name");
if (o.ContainsKey("avatar_url")) patch.AvatarUrl = o.Value<string>("avatar_url").NullIfEmpty().BoundsCheckField(Limits.MaxUriLength, "Member avatar URL");
if (o.ContainsKey("banner")) patch.BannerImage = o.Value<string>("banner").NullIfEmpty().BoundsCheckField(Limits.MaxUriLength, "Member banner URL");
if (o.ContainsKey("display_name")) patch.DisplayName = o.Value<string>("display_name").NullIfEmpty();
if (o.ContainsKey("avatar_url")) patch.AvatarUrl = o.Value<string>("avatar_url").NullIfEmpty();
if (o.ContainsKey("banner")) patch.BannerImage = o.Value<string>("banner").NullIfEmpty();
if (o.ContainsKey("birthday"))
{
@ -84,26 +98,25 @@ namespace PluralKit.Core
var res = DateTimeFormats.DateExportFormat.Parse(str);
if (res.Success) patch.Birthday = res.Value;
else if (str == null) patch.Birthday = null;
else throw new JsonModelParseError("Could not parse member birthday.");
else throw new ValidationError("birthday");
}
if (o.ContainsKey("pronouns")) patch.Pronouns = o.Value<string>("pronouns").NullIfEmpty().BoundsCheckField(Limits.MaxPronounsLength, "Member pronouns");
if (o.ContainsKey("description")) patch.Description = o.Value<string>("description").NullIfEmpty().BoundsCheckField(Limits.MaxDescriptionLength, "Member descriptoin");
if (o.ContainsKey("pronouns")) patch.Pronouns = o.Value<string>("pronouns").NullIfEmpty();
if (o.ContainsKey("description")) patch.Description = o.Value<string>("description").NullIfEmpty();
if (o.ContainsKey("keep_proxy")) patch.KeepProxy = o.Value<bool>("keep_proxy");
// legacy: used in old export files and APIv1
// todo: should we parse `proxy_tags` first?
if (o.ContainsKey("prefix") || o.ContainsKey("suffix") && !o.ContainsKey("proxy_tags"))
patch.ProxyTags = new[] {new ProxyTag(o.Value<string>("prefix"), o.Value<string>("suffix"))};
else if (o.ContainsKey("proxy_tags"))
{
patch.ProxyTags = o.Value<JArray>("proxy_tags")
.OfType<JObject>().Select(o => new ProxyTag(o.Value<string>("prefix"), o.Value<string>("suffix")))
.Where(p => p.Valid)
.ToArray();
}
if(o.ContainsKey("privacy")) //TODO: Deprecate this completely in api v2
{
var plevel = o.Value<string>("privacy").ParsePrivacy("member");
var plevel = o.ParsePrivacy("privacy");
patch.Visibility = plevel;
patch.NamePrivacy = plevel;
@ -116,14 +129,14 @@ namespace PluralKit.Core
}
else
{
if (o.ContainsKey("visibility")) patch.Visibility = o.Value<string>("visibility").ParsePrivacy("member");
if (o.ContainsKey("name_privacy")) patch.NamePrivacy = o.Value<string>("name_privacy").ParsePrivacy("member");
if (o.ContainsKey("description_privacy")) patch.DescriptionPrivacy = o.Value<string>("description_privacy").ParsePrivacy("member");
if (o.ContainsKey("avatar_privacy")) patch.AvatarPrivacy = o.Value<string>("avatar_privacy").ParsePrivacy("member");
if (o.ContainsKey("birthday_privacy")) patch.BirthdayPrivacy = o.Value<string>("birthday_privacy").ParsePrivacy("member");
if (o.ContainsKey("pronoun_privacy")) patch.PronounPrivacy = o.Value<string>("pronoun_privacy").ParsePrivacy("member");
// if (o.ContainsKey("color_privacy")) member.ColorPrivacy = o.Value<string>("color_privacy").ParsePrivacy("member");
if (o.ContainsKey("metadata_privacy")) patch.MetadataPrivacy = o.Value<string>("metadata_privacy").ParsePrivacy("member");
if (o.ContainsKey("visibility")) patch.Visibility = o.ParsePrivacy("visibility");
if (o.ContainsKey("name_privacy")) patch.NamePrivacy = o.ParsePrivacy("name_privacy");
if (o.ContainsKey("description_privacy")) patch.DescriptionPrivacy = o.ParsePrivacy("description_privacy");
if (o.ContainsKey("avatar_privacy")) patch.AvatarPrivacy = o.ParsePrivacy("avatar_privacy");
if (o.ContainsKey("birthday_privacy")) patch.BirthdayPrivacy = o.ParsePrivacy("birthday_privacy");
if (o.ContainsKey("pronoun_privacy")) patch.PronounPrivacy = o.ParsePrivacy("pronoun_privacy");
// if (o.ContainsKey("color_privacy")) member.ColorPrivacy = o.ParsePrivacy("member");
if (o.ContainsKey("metadata_privacy")) patch.MetadataPrivacy = o.ParsePrivacy("metadata_privacy");
}
return patch;

View File

@ -1,17 +1,48 @@
using System;
using System.Text.RegularExpressions;
namespace PluralKit.Core
{
public class InvalidPatchException : Exception
{
public InvalidPatchException(string message) : base(message) {}
}
public abstract class PatchObject
{
public abstract UpdateQueryBuilder Apply(UpdateQueryBuilder b);
public void CheckIsValid() {}
public void AssertIsValid() {}
protected bool AssertValid(string input, string name, int maxLength, Func<string, bool>? validate = null)
{
if (input.Length > maxLength)
throw new FieldTooLongError(name, maxLength, input.Length);
if (validate != null && !validate(input))
throw new ValidationError(name);
return true;
}
protected bool AssertValid(string input, string name, string pattern)
{
if (!Regex.IsMatch(input, pattern))
throw new ValidationError(name);
return true;
}
}
public class ValidationError: Exception
{
public ValidationError(string message): base(message) { }
}
public class FieldTooLongError: ValidationError
{
public string Name;
public int MaxLength;
public int ActualLength;
public FieldTooLongError(string name, int maxLength, int actualLength):
base($"{name} too long ({actualLength} > {maxLength})")
{
Name = name;
MaxLength = maxLength;
ActualLength = actualLength;
}
}
}

View File

@ -1,8 +1,11 @@
#nullable enable
using System;
using System.Text.RegularExpressions;
using Newtonsoft.Json.Linq;
using NodaTime;
namespace PluralKit.Core
{
public class SystemPatch: PatchObject
@ -46,34 +49,44 @@ namespace PluralKit.Core
.With("member_limit_override", MemberLimitOverride)
.With("group_limit_override", GroupLimitOverride);
public new void CheckIsValid()
public new void AssertIsValid()
{
if (AvatarUrl.Value != null && !MiscUtils.TryMatchUri(AvatarUrl.Value, out var avatarUri))
throw new InvalidPatchException("avatar_url");
if (BannerImage.Value != null && !MiscUtils.TryMatchUri(BannerImage.Value, out var bannerImage))
throw new InvalidPatchException("banner");
if (Color.Value != null && (!Regex.IsMatch(Color.Value, "^[0-9a-fA-F]{6}$")))
throw new InvalidPatchException("color");
if (Name.Value != null)
AssertValid(Name.Value, "name", Limits.MaxSystemNameLength);
if (Description.Value != null)
AssertValid(Description.Value, "description", Limits.MaxDescriptionLength);
if (Tag.Value != null)
AssertValid(Tag.Value, "tag", Limits.MaxSystemTagLength);
if (AvatarUrl.Value != null)
AssertValid(AvatarUrl.Value, "avatar_url", Limits.MaxUriLength,
s => MiscUtils.TryMatchUri(s, out var avatarUri));
if (BannerImage.Value != null)
AssertValid(BannerImage.Value, "banner", Limits.MaxUriLength,
s => MiscUtils.TryMatchUri(s, out var bannerUri));
if (Color.Value != null)
AssertValid(Color.Value, "color", "^[0-9a-fA-F]{6}$");
if (UiTz.IsPresent && DateTimeZoneProviders.Tzdb.GetZoneOrNull(UiTz.Value) == null)
throw new ValidationError("avatar_url");
}
public static SystemPatch FromJSON(JObject o)
{
var patch = new SystemPatch();
if (o.ContainsKey("name")) patch.Name = o.Value<string>("name").NullIfEmpty().BoundsCheckField(Limits.MaxSystemNameLength, "System name");
if (o.ContainsKey("description")) patch.Description = o.Value<string>("description").NullIfEmpty().BoundsCheckField(Limits.MaxDescriptionLength, "System description");
if (o.ContainsKey("tag")) patch.Tag = o.Value<string>("tag").NullIfEmpty().BoundsCheckField(Limits.MaxSystemTagLength, "System tag");
if (o.ContainsKey("avatar_url")) patch.AvatarUrl = o.Value<string>("avatar_url").NullIfEmpty().BoundsCheckField(Limits.MaxUriLength, "System avatar URL");
if (o.ContainsKey("banner")) patch.BannerImage = o.Value<string>("banner").NullIfEmpty().BoundsCheckField(Limits.MaxUriLength, "System banner URL");
if (o.ContainsKey("name")) patch.Name = o.Value<string>("name").NullIfEmpty();
if (o.ContainsKey("description")) patch.Description = o.Value<string>("description").NullIfEmpty();
if (o.ContainsKey("tag")) patch.Tag = o.Value<string>("tag").NullIfEmpty();
if (o.ContainsKey("avatar_url")) patch.AvatarUrl = o.Value<string>("avatar_url").NullIfEmpty();
if (o.ContainsKey("banner")) patch.BannerImage = o.Value<string>("banner").NullIfEmpty();
if (o.ContainsKey("timezone")) patch.UiTz = o.Value<string>("tz") ?? "UTC";
// legacy: APIv1 uses "tz" instead of "timezone"
// todo: remove in APIv2
if (o.ContainsKey("tz")) patch.UiTz = o.Value<string>("tz") ?? "UTC";
if (o.ContainsKey("description_privacy")) patch.DescriptionPrivacy = o.Value<string>("description_privacy").ParsePrivacy("description");
if (o.ContainsKey("member_list_privacy")) patch.MemberListPrivacy = o.Value<string>("member_list_privacy").ParsePrivacy("member list");
if (o.ContainsKey("front_privacy")) patch.FrontPrivacy = o.Value<string>("front_privacy").ParsePrivacy("front");
if (o.ContainsKey("front_history_privacy")) patch.FrontHistoryPrivacy = o.Value<string>("front_history_privacy").ParsePrivacy("front history");
if (o.ContainsKey("description_privacy")) patch.DescriptionPrivacy = o.ParsePrivacy("description_privacy");
if (o.ContainsKey("member_list_privacy")) patch.MemberListPrivacy = o.ParsePrivacy("member_list_privacy");
if (o.ContainsKey("front_privacy")) patch.FrontPrivacy = o.ParsePrivacy("front_privacy");
if (o.ContainsKey("front_history_privacy")) patch.FrontHistoryPrivacy = o.ParsePrivacy("front_history_privacy");
return patch;
}
}

View File

@ -1,5 +1,7 @@
using System;
using Newtonsoft.Json.Linq;
namespace PluralKit.Core
{
public enum PrivacyLevel
@ -41,13 +43,16 @@ namespace PluralKit.Core
public static string ToJsonString(this PrivacyLevel level) => level.LevelName();
public static PrivacyLevel ParsePrivacy(this string input, string errorName)
public static PrivacyLevel ParsePrivacy(this JObject o, string propertyName)
{
var input = o.Value<string>(propertyName);
if (input == null) return PrivacyLevel.Public;
if (input == "") return PrivacyLevel.Private;
if (input == "private") return PrivacyLevel.Private;
if (input == "public") return PrivacyLevel.Public;
throw new JsonModelParseError($"Could not parse {errorName} privacy.");
throw new ValidationError(propertyName);
}
}

View File

@ -14,6 +14,10 @@ namespace PluralKit.Core
[JsonProperty("prefix")] public string Prefix { get; set; }
[JsonProperty("suffix")] public string Suffix { get; set; }
[JsonIgnore] public bool Valid =>
Prefix != null || Suffix != null
&& ProxyString.Length <= Limits.MaxProxyTagLength;
[JsonIgnore] public string ProxyString => $"{Prefix ?? ""}text{Suffix ?? ""}";
[JsonIgnore] public bool IsEmpty => Prefix == null && Suffix == null;

View File

@ -1,10 +1,10 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using Newtonsoft.Json;
using Dapper;
using Newtonsoft.Json.Linq;
using NodaTime;
@ -17,352 +17,52 @@ namespace PluralKit.Core
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly ILogger _logger;
public DataFileService(ILogger logger, IDatabase db, ModelRepository repo)
public DataFileService(IDatabase db, ModelRepository repo, ILogger logger)
{
_db = db;
_repo = repo;
_logger = logger.ForContext<DataFileService>();
_logger = logger;
}
public async Task<DataFileSystem> ExportSystem(PKSystem system)
public async Task<JObject> ExportSystem(PKSystem system)
{
await using var conn = await _db.Obtain();
// Export members
var members = new List<DataFileMember>();
var pkMembers = _repo.GetSystemMembers(conn, system.Id); // Read all members in the system
await foreach (var member in pkMembers.Select(m => new DataFileMember
{
Id = m.Hid,
Name = m.Name,
DisplayName = m.DisplayName,
Description = m.Description,
Birthday = m.Birthday?.FormatExport(),
Pronouns = m.Pronouns,
Color = m.Color,
AvatarUrl = m.AvatarUrl.TryGetCleanCdnUrl(),
ProxyTags = m.ProxyTags,
KeepProxy = m.KeepProxy,
Created = m.Created.FormatExport(),
MessageCount = m.MessageCount
})) members.Add(member);
// Export switches
var switches = new List<DataFileSwitch>();
var switchList = await _repo.GetPeriodFronters(conn, system.Id, null, Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant());
switches.AddRange(switchList.Select(x => new DataFileSwitch
{
Timestamp = x.TimespanStart.FormatExport(),
Members = x.Members.Select(m => m.Hid).ToList() // Look up member's HID using the member export from above
}));
var o = new JObject();
return new DataFileSystem
o.Add("version", 1);
o.Add("id", system.Hid);
o.Add("name", system.Name);
o.Add("description", system.Description);
o.Add("tag", system.Tag);
o.Add("avatar_url", system.AvatarUrl);
o.Add("timezone", system.UiTz);
o.Add("created", system.Created.FormatExport());
o.Add("accounts", new JArray((await _repo.GetSystemAccounts(conn, system.Id)).ToList()));
o.Add("members", new JArray((await _repo.GetSystemMembers(conn, system.Id).ToListAsync()).Select(m => m.ToJson(LookupContext.ByOwner))));
var switches = new JArray();
var switchList = await _repo.GetPeriodFronters(conn, system.Id, null,
Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant());
foreach (var sw in switchList)
{
Version = 1,
Id = system.Hid,
Name = system.Name,
Description = system.Description,
Tag = system.Tag,
AvatarUrl = system.AvatarUrl,
TimeZone = system.UiTz,
Members = members,
Switches = switches,
Created = system.Created.FormatExport(),
LinkedAccounts = (await _repo.GetSystemAccounts(conn, system.Id)).ToList()
};
}
private MemberPatch ToMemberPatch(DataFileMember fileMember)
{
var newMember = new MemberPatch
{
Name = fileMember.Name,
DisplayName = fileMember.DisplayName,
Description = fileMember.Description,
Color = fileMember.Color,
Pronouns = fileMember.Pronouns,
AvatarUrl = fileMember.AvatarUrl,
KeepProxy = fileMember.KeepProxy,
MessageCount = fileMember.MessageCount,
};
if (fileMember.Prefix != null || fileMember.Suffix != null)
newMember.ProxyTags = new[] {new ProxyTag(fileMember.Prefix, fileMember.Suffix)};
else
// Ignore proxy tags where both prefix and suffix are set to null (would be invalid anyway)
newMember.ProxyTags = (fileMember.ProxyTags ?? new ProxyTag[] { }).Where(tag => !tag.IsEmpty).ToArray();
if (fileMember.Birthday != null)
{
var birthdayParse = DateTimeFormats.DateExportFormat.Parse(fileMember.Birthday);
newMember.Birthday = birthdayParse.Success ? (LocalDate?)birthdayParse.Value : null;
var s = new JObject();
s.Add("timestamp", sw.TimespanStart.FormatExport());
s.Add("members", new JArray(sw.Members.Select(m => m.Hid)));
switches.Add(s);
}
return newMember;
o.Add("switches", switches);
return o;
}
public async Task<ImportResult> ImportSystem(DataFileSystem data, PKSystem system, ulong accountId)
public async Task<ImportResultNew> ImportSystem(ulong userId, PKSystem? system, JObject importFile, Func<string, Task> confirmFunc)
{
await using var conn = await _db.Obtain();
await using var tx = await conn.BeginTransactionAsync();
var result = new ImportResult {
AddedNames = new List<string>(),
ModifiedNames = new List<string>(),
System = system,
Success = true // Assume success unless indicated otherwise
};
// If we don't already have a system to save to, create one
if (system == null)
{
system = result.System = await _repo.CreateSystem(conn, data.Name);
await _repo.AddAccount(conn, system.Id, accountId);
}
var memberLimit = system.MemberLimitOverride ?? Limits.MaxMemberCount;
// Apply system info
var patch = new SystemPatch {Name = data.Name};
if (data.Description != null) patch.Description = data.Description;
if (data.Tag != null) patch.Tag = data.Tag;
if (data.AvatarUrl != null) patch.AvatarUrl = data.AvatarUrl.TryGetCleanCdnUrl();
if (data.TimeZone != null) patch.UiTz = data.TimeZone ?? "UTC";
await _repo.UpdateSystem(conn, system.Id, patch);
// -- Member/switch import --
await using (var imp = await BulkImporter.Begin(system, conn))
{
// Tally up the members that didn't exist before, and check member count on import
// If creating the unmatched members would put us over the member limit, abort before creating any members
var memberCountBefore = await _repo.GetSystemMemberCount(conn, system.Id);
var membersToAdd = data.Members.Count(m => imp.IsNewMember(m.Id, m.Name));
if (memberCountBefore + membersToAdd > memberLimit)
{
result.Success = false;
result.Message = $"Import would exceed the maximum number of members ({memberLimit}).";
return result;
}
async Task DoImportMember(BulkImporter imp, DataFileMember fileMember)
{
var isCreatingNewMember = imp.IsNewMember(fileMember.Id, fileMember.Name);
// Use the file member's id as the "unique identifier" for the importing (actual value is irrelevant but needs to be consistent)
_logger.Debug(
"Importing member with identifier {FileId} to system {System} (is creating new member? {IsCreatingNewMember})",
fileMember.Id, system.Id, isCreatingNewMember);
var newMember = await imp.AddMember(fileMember.Id, fileMember.Id, fileMember.Name, ToMemberPatch(fileMember));
if (isCreatingNewMember)
result.AddedNames.Add(newMember.Name);
else
result.ModifiedNames.Add(newMember.Name);
}
// Can't parallelize this because we can't reuse the same connection/tx inside the importer
foreach (var m in data.Members)
await DoImportMember(imp, m);
// Lastly, import the switches
await imp.AddSwitches(data.Switches.Select(sw => new BulkImporter.SwitchInfo
{
Timestamp = DateTimeFormats.TimestampExportFormat.Parse(sw.Timestamp).Value,
// "Members" here is from whatever ID the data file uses, which the bulk importer can map to the real IDs! :)
MemberIdentifiers = sw.Members.ToList()
}).ToList());
}
_logger.Information("Imported system {System}", system.Hid);
return result;
return await BulkImporter.PerformImport(conn, tx, _repo, _logger, userId, system, importFile, confirmFunc);
}
}
public struct ImportResult
{
public ICollection<string> AddedNames;
public ICollection<string> ModifiedNames;
public PKSystem System;
public bool Success;
public string Message;
}
public struct DataFileSystem
{
[JsonProperty("version")] public int Version;
[JsonProperty("id")] public string Id;
[JsonProperty("name")] public string Name;
[JsonProperty("description")] public string Description;
[JsonProperty("tag")] public string Tag;
[JsonProperty("avatar_url")] public string AvatarUrl;
[JsonProperty("timezone")] public string TimeZone;
[JsonProperty("members")] public ICollection<DataFileMember> Members;
[JsonProperty("switches")] public ICollection<DataFileSwitch> Switches;
[JsonProperty("accounts")] public ICollection<ulong> LinkedAccounts;
[JsonProperty("created")] public string Created;
private bool TimeZoneValid => TimeZone == null || DateTimeZoneProviders.Tzdb.GetZoneOrNull(TimeZone) != null;
[JsonIgnore] public bool Valid =>
TimeZoneValid &&
Members != null &&
// no need to check this here, it is checked later as part of the import
// Members.Count <= Limits.MaxMemberCount &&
Members.All(m => m.Valid) &&
Switches != null &&
Switches.Count < 10000 &&
Switches.All(s => s.Valid) &&
!Name.IsLongerThan(Limits.MaxSystemNameLength) &&
!Description.IsLongerThan(Limits.MaxDescriptionLength) &&
!Tag.IsLongerThan(Limits.MaxSystemTagLength) &&
!AvatarUrl.IsLongerThan(1000);
}
public struct DataFileMember
{
[JsonProperty("id")] public string Id;
[JsonProperty("name")] public string Name;
[JsonProperty("display_name")] public string DisplayName;
[JsonProperty("description")] public string Description;
[JsonProperty("birthday")] public string Birthday;
[JsonProperty("pronouns")] public string Pronouns;
[JsonProperty("color")] public string Color;
[JsonProperty("avatar_url")] public string AvatarUrl;
// For legacy single-tag imports
[JsonProperty("prefix")] [JsonIgnore] public string Prefix;
[JsonProperty("suffix")] [JsonIgnore] public string Suffix;
// ^ is superseded by v
[JsonProperty("proxy_tags")] public ICollection<ProxyTag> ProxyTags;
[JsonProperty("keep_proxy")] public bool KeepProxy;
[JsonProperty("message_count")] public int MessageCount;
[JsonProperty("created")] public string Created;
[JsonIgnore] public bool Valid =>
Name != null &&
!Name.IsLongerThan(Limits.MaxMemberNameLength) &&
!DisplayName.IsLongerThan(Limits.MaxMemberNameLength) &&
!Description.IsLongerThan(Limits.MaxDescriptionLength) &&
!Pronouns.IsLongerThan(Limits.MaxPronounsLength) &&
(Color == null || Regex.IsMatch(Color, "[0-9a-fA-F]{6}")) &&
(Birthday == null || DateTimeFormats.DateExportFormat.Parse(Birthday).Success) &&
// Sanity checks
!AvatarUrl.IsLongerThan(1000) &&
// Older versions have Prefix and Suffix as fields, meaning ProxyTags is null
(ProxyTags == null || ProxyTags.Count < 100 &&
ProxyTags.All(t => !t.ProxyString.IsLongerThan(100))) &&
!Prefix.IsLongerThan(100) && !Suffix.IsLongerThan(100);
}
public struct DataFileSwitch
{
[JsonProperty("timestamp")] public string Timestamp;
[JsonProperty("members")] public ICollection<string> Members;
[JsonIgnore] public bool Valid =>
Members != null &&
Members.Count < 100 &&
DateTimeFormats.TimestampExportFormat.Parse(Timestamp).Success;
}
public struct TupperboxConversionResult
{
public bool HadGroups;
public bool HadIndividualTags;
public DataFileSystem System;
}
public struct TupperboxProfile
{
[JsonProperty("tuppers")] public ICollection<TupperboxTupper> Tuppers;
[JsonProperty("groups")] public ICollection<TupperboxGroup> Groups;
[JsonIgnore] public bool Valid => Tuppers != null && Groups != null && Tuppers.All(t => t.Valid) && Groups.All(g => g.Valid);
public TupperboxConversionResult ToPluralKit()
{
// Set by member conversion function
string lastSetTag = null;
TupperboxConversionResult output = default(TupperboxConversionResult);
var members = Tuppers.Select(t => t.ToPluralKit(ref lastSetTag, ref output.HadIndividualTags,
ref output.HadGroups)).ToList();
// Nowadays we set each member's display name to their name + tag, so we don't set a global system tag
output.System = new DataFileSystem
{
Members = members,
Switches = new List<DataFileSwitch>()
};
return output;
}
}
public struct TupperboxTupper
{
[JsonProperty("name")] public string Name;
[JsonProperty("avatar_url")] public string AvatarUrl;
[JsonProperty("brackets")] public IList<string> Brackets;
[JsonProperty("posts")] public int Posts; // Not supported by PK
[JsonProperty("show_brackets")] public bool ShowBrackets;
[JsonProperty("birthday")] public string Birthday;
[JsonProperty("description")] public string Description;
[JsonProperty("tag")] public string Tag;
[JsonProperty("group_id")] public string GroupId; // Not supported by PK
[JsonProperty("group_pos")] public int? GroupPos; // Not supported by PK
[JsonIgnore] public bool Valid =>
Name != null && Brackets != null && Brackets.Count % 2 == 0 &&
(Birthday == null || DateTimeFormats.TimestampExportFormat.Parse(Birthday).Success);
public DataFileMember ToPluralKit(ref string lastSetTag, ref bool multipleTags, ref bool hasGroup)
{
// If we've set a tag before and it's not the same as this one,
// then we have multiple unique tags and we pass that flag back to the caller
if (Tag != null && lastSetTag != null && lastSetTag != Tag) multipleTags = true;
lastSetTag = Tag;
// If this member is in a group, we have a (used) group and we flag that
if (GroupId != null) hasGroup = true;
// Brackets in Tupperbox format are arranged as a single array
// [prefix1, suffix1, prefix2, suffix2, prefix3... etc]
var tags = new List<ProxyTag>();
for (var i = 0; i < Brackets.Count / 2; i++)
tags.Add(new ProxyTag(Brackets[i * 2], Brackets[i * 2 + 1]));
// Convert birthday from ISO timestamp format to ISO date
var convertedBirthdate = Birthday != null
? LocalDate.FromDateTime(DateTimeFormats.TimestampExportFormat.Parse(Birthday).Value.ToDateTimeUtc())
: (LocalDate?) null;
return new DataFileMember
{
Id = Guid.NewGuid().ToString(), // Note: this is only ever used for lookup purposes
Name = Name,
AvatarUrl = AvatarUrl,
Birthday = convertedBirthdate?.FormatExport(),
Description = Description,
ProxyTags = tags,
KeepProxy = ShowBrackets,
DisplayName = Tag != null ? $"{Name} {Tag}" : null
};
}
}
public struct TupperboxGroup
{
[JsonProperty("id")] public int Id;
[JsonProperty("name")] public string Name;
[JsonProperty("description")] public string Description;
[JsonProperty("tag")] public string Tag;
[JsonIgnore] public bool Valid => true;
}
}
}

View File

@ -1,204 +0,0 @@
#nullable enable
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Data;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using NodaTime;
using NpgsqlTypes;
namespace PluralKit.Core
{
public class BulkImporter: IAsyncDisposable
{
private readonly SystemId _systemId;
private readonly IPKConnection _conn;
private readonly IPKTransaction _tx;
private readonly Dictionary<string, MemberId> _knownMembers = new Dictionary<string, MemberId>();
private readonly Dictionary<string, PKMember> _existingMembersByHid = new Dictionary<string, PKMember>();
private readonly Dictionary<string, PKMember> _existingMembersByName = new Dictionary<string, PKMember>();
private BulkImporter(SystemId systemId, IPKConnection conn, IPKTransaction tx)
{
_systemId = systemId;
_conn = conn;
_tx = tx;
}
public static async Task<BulkImporter> Begin(PKSystem system, IPKConnection conn)
{
var tx = await conn.BeginTransactionAsync();
var importer = new BulkImporter(system.Id, conn, tx);
await importer.Begin();
return importer;
}
public async Task Begin()
{
// Fetch all members in the system and log their names and hids
var members = await _conn.QueryAsync<PKMember>("select id, hid, name from members where system = @System",
new {System = _systemId});
foreach (var m in members)
{
_existingMembersByHid[m.Hid] = m;
_existingMembersByName[m.Name] = m;
}
}
/// <summary>
/// Checks whether trying to add a member with the given hid and name would result in creating a new member (as opposed to just updating one).
/// </summary>
public bool IsNewMember(string hid, string name) => FindExistingMemberInSystem(hid, name) == null;
/// <summary>
/// Imports a member into the database
/// </summary>
/// <remarks>If an existing member exists in this system that matches this member in either HID or name, it'll overlay the member information on top of this instead.</remarks>
/// <param name="identifier">An opaque identifier string that refers to this member regardless of source. Is used when importing switches. Value is irrelevant, but should be consistent with the same member later.</param>
/// <param name="potentialHid">When trying to match the member to an existing member, will use a member with this HID if present in system.</param>
/// <param name="potentialName">When trying to match the member to an existing member, will use a member with this name if present in system.</param>
/// <param name="patch">A member patch struct containing the data to apply to this member </param>
/// <returns>The inserted member object, which may or may not share an ID or HID with the input member.</returns>
public async Task<PKMember> AddMember(string identifier, string potentialHid, string potentialName, MemberPatch patch)
{
// See if we can find a member that matches this one
// if not, roll a new hid and we'll insert one with that
// (we can't trust the hid given in the member, it might let us overwrite another system's members)
var existingMember = FindExistingMemberInSystem(potentialHid, potentialName);
string newHid = existingMember?.Hid ?? await _conn.QuerySingleAsync<string>("find_free_member_hid", commandType: CommandType.StoredProcedure);
// Upsert member data and return the ID
QueryBuilder qb = QueryBuilder.Upsert("members", "hid")
.Constant("hid", "@Hid")
.Constant("system", "@System");
if (patch.Name.IsPresent) qb.Variable("name", "@Name");
if (patch.DisplayName.IsPresent) qb.Variable("display_name", "@DisplayName");
if (patch.Description.IsPresent) qb.Variable("description", "@Description");
if (patch.Pronouns.IsPresent) qb.Variable("pronouns", "@Pronouns");
if (patch.Color.IsPresent) qb.Variable("color", "@Color");
if (patch.AvatarUrl.IsPresent) qb.Variable("avatar_url", "@AvatarUrl");
if (patch.ProxyTags.IsPresent) qb.Variable("proxy_tags", "@ProxyTags");
if (patch.Birthday.IsPresent) qb.Variable("birthday", "@Birthday");
if (patch.KeepProxy.IsPresent) qb.Variable("keep_proxy", "@KeepProxy");
// don't overwrite message count on existing members
if (existingMember == null)
if (patch.MessageCount.IsPresent) qb.Variable("message_count", "@MessageCount");
var newMember = await _conn.QueryFirstAsync<PKMember>(qb.Build("returning *"),
new
{
Hid = newHid,
System = _systemId,
Name = patch.Name.Value,
DisplayName = patch.DisplayName.Value,
Description = patch.Description.Value,
Pronouns = patch.Pronouns.Value,
Color = patch.Color.Value,
AvatarUrl = patch.AvatarUrl.Value?.TryGetCleanCdnUrl(),
KeepProxy = patch.KeepProxy.Value,
ProxyTags = patch.ProxyTags.Value,
Birthday = patch.Birthday.Value,
MessageCount = patch.MessageCount.Value,
});
// Log this member ID by the given identifier
_knownMembers[identifier] = newMember.Id;
return newMember;
}
private PKMember? FindExistingMemberInSystem(string hid, string name)
{
if (_existingMembersByHid.TryGetValue(hid, out var byHid)) return byHid;
if (_existingMembersByName.TryGetValue(name, out var byName)) return byName;
return null;
}
/// <summary>
/// Register switches in bulk.
/// </summary>
/// <remarks>This function assumes there are no duplicate switches (ie. switches with the same timestamp).</remarks>
public async Task AddSwitches(IReadOnlyCollection<SwitchInfo> switches)
{
// Ensure we're aware of all the members we're trying to import from
if (!switches.All(sw => sw.MemberIdentifiers.All(m => _knownMembers.ContainsKey(m))))
throw new ArgumentException("One or more switch members haven't been added using this importer");
// Fetch the existing switches in the database so we can avoid duplicates
var existingSwitches = (await _conn.QueryAsync<PKSwitch>("select * from switches where system = @System", new {System = _systemId})).ToList();
var existingTimestamps = existingSwitches.Select(sw => sw.Timestamp).ToImmutableHashSet();
var lastSwitchId = existingSwitches.Count != 0 ? existingSwitches.Select(sw => sw.Id).Max() : (SwitchId?) null;
// Import switch definitions
var importedSwitches = new Dictionary<Instant, SwitchInfo>();
await using (var importer = _conn.BeginBinaryImport("copy switches (system, timestamp) from stdin (format binary)"))
{
foreach (var sw in switches)
{
// Don't import duplicate switches
if (existingTimestamps.Contains(sw.Timestamp)) continue;
// Otherwise, write to importer
await importer.StartRowAsync();
await importer.WriteAsync(_systemId.Value, NpgsqlDbType.Integer);
await importer.WriteAsync(sw.Timestamp, NpgsqlDbType.Timestamp);
// Note that we've imported a switch with this timestamp
importedSwitches[sw.Timestamp] = sw;
}
// Commit the import
await importer.CompleteAsync();
}
// Now, fetch all the switches we just added (so, now we get their IDs too)
// IDs are sequential, so any ID in this system, with a switch ID > the last max, will be one we just added
var justAddedSwitches = await _conn.QueryAsync<PKSwitch>(
"select * from switches where system = @System and id > @LastSwitchId",
new {System = _systemId, LastSwitchId = lastSwitchId?.Value ?? -1});
// Lastly, import the switch members
await using (var importer = _conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)"))
{
foreach (var justAddedSwitch in justAddedSwitches)
{
if (!importedSwitches.TryGetValue(justAddedSwitch.Timestamp, out var switchInfo))
throw new Exception($"Found 'just-added' switch (by ID) with timestamp {justAddedSwitch.Timestamp}, but this did not correspond to a timestamp we just added a switch entry of! :/");
// We still assume timestamps are unique and non-duplicate, so:
var members = switchInfo.MemberIdentifiers;
foreach (var memberIdentifier in members)
{
if (!_knownMembers.TryGetValue(memberIdentifier, out var memberId))
throw new Exception($"Attempted to import switch with member identifier {memberIdentifier} but could not find an entry in the id map for this! :/");
await importer.StartRowAsync();
await importer.WriteAsync(justAddedSwitch.Id.Value, NpgsqlDbType.Integer);
await importer.WriteAsync(memberId.Value, NpgsqlDbType.Integer);
}
}
await importer.CompleteAsync();
}
}
public struct SwitchInfo
{
public Instant Timestamp;
/// <summary>
/// An ordered list of "member identifiers" matching with the identifier parameter passed to <see cref="BulkImporter.AddMember"/>.
/// </summary>
public IReadOnlyList<string> MemberIdentifiers;
}
public async ValueTask DisposeAsync() =>
await _tx.CommitAsync();
}
}

View File

@ -0,0 +1,124 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Newtonsoft.Json.Linq;
using Autofac;
using Dapper;
using Serilog;
namespace PluralKit.Core
{
public partial class BulkImporter : IAsyncDisposable
{
private ILogger _logger { get; init; }
private ModelRepository _repo { get; init; }
private PKSystem _system { get; set; }
private IPKConnection _conn { get; init; }
private IPKTransaction _tx { get; init; }
private Func<string, Task> _confirmFunc { get; init; }
private readonly Dictionary<string, MemberId> _existingMemberHids = new();
private readonly Dictionary<string, MemberId> _existingMemberNames = new();
private readonly Dictionary<string, MemberId> _knownIdentifiers = new();
private ImportResultNew _result = new();
internal static async Task<ImportResultNew> PerformImport(IPKConnection conn, IPKTransaction tx, ModelRepository repo, ILogger logger,
ulong userId, PKSystem? system, JObject importFile, Func<string, Task> confirmFunc)
{
await using var importer = new BulkImporter()
{
_logger = logger,
_repo = repo,
_system = system,
_conn = conn,
_tx = tx,
_confirmFunc = confirmFunc,
};
if (system == null) {
system = await repo.CreateSystem(conn, null, tx);
await repo.AddAccount(conn, system.Id, userId);
importer._result.CreatedSystem = system.Hid;
importer._system = system;
}
// Fetch all members in the system and log their names and hids
var members = await conn.QueryAsync<PKMember>("select id, hid, name from members where system = @System",
new {System = system.Id});
foreach (var m in members)
{
importer._existingMemberHids[m.Hid] = m.Id;
importer._existingMemberNames[m.Name] = m.Id;
}
try
{
if (importFile.ContainsKey("tuppers"))
await importer.ImportTupperbox(importFile);
else if (importFile.ContainsKey("switches"))
await importer.ImportPluralKit(importFile);
else
throw new ImportException("File type is unknown.");
importer._result.Success = true;
await tx.CommitAsync();
}
catch (ImportException e)
{
importer._result.Success = false;
importer._result.Message = e.Message;
}
catch (ArgumentNullException)
{
importer._result.Success = false;
}
return importer._result;
}
private (MemberId?, bool) TryGetExistingMember(string hid, string name)
{
if (_existingMemberHids.TryGetValue(hid, out var byHid)) return (byHid, true);
if (_existingMemberNames.TryGetValue(name, out var byName)) return (byName, false);
return (null, false);
}
private async Task AssertLimitNotReached(int newMembers)
{
var memberLimit = _system.MemberLimitOverride ?? Limits.MaxMemberCount;
var existingMembers = await _repo.GetSystemMemberCount(_conn, _system.Id);
if (existingMembers + newMembers > memberLimit)
throw new ImportException($"Import would exceed the maximum number of members ({memberLimit}).");
}
public async ValueTask DisposeAsync()
{
// try rolling back the transaction
// this will throw if the transaction was committed, but that's fine
// so we just catch InvalidOperationException
try
{
await _tx.RollbackAsync();
}
catch (InvalidOperationException) {}
}
private class ImportException : Exception {
public ImportException(string Message) : base(Message) {}
}
}
public record ImportResultNew
{
public int Added = 0;
public int Modified = 0;
public bool Success;
public string? CreatedSystem;
public string? Message;
}
}

View File

@ -0,0 +1,169 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using Newtonsoft.Json.Linq;
using NodaTime;
using NpgsqlTypes;
namespace PluralKit.Core
{
public partial class BulkImporter
{
private async Task<ImportResultNew> ImportPluralKit(JObject importFile)
{
var patch = SystemPatch.FromJSON(importFile);
try
{
patch.AssertIsValid();
}
catch (ValidationError e)
{
throw new ImportException($"Field {e.Message} in export file is invalid.");
}
await _repo.UpdateSystem(_conn, _system.Id, patch, _tx);
var members = importFile.Value<JArray>("members");
var switches = importFile.Value<JArray>("switches");
var newMembers = members.Count(m => {
var (found, _) = TryGetExistingMember(m.Value<string>("id"), m.Value<string>("name"));
return found == null;
});
await AssertLimitNotReached(newMembers);
foreach (JObject member in members)
await ImportMember(member);
if (switches.Any(sw => sw.Value<JArray>("members").Any(m => !_knownIdentifiers.ContainsKey((string) m))))
throw new ImportException("One or more switches include members that haven't been imported.");
await ImportSwitches(switches);
return _result;
}
private async Task ImportMember(JObject member)
{
var id = member.Value<string>("id");
var name = member.Value<string>("name");
var (found, isHidExisting) = TryGetExistingMember(id, name);
var isNewMember = found == null;
var referenceName = isHidExisting ? id : name;
if (isNewMember)
_result.Added++;
else
_result.Modified++;
_logger.Debug(
"Importing member with identifier {FileId} to system {System} (is creating new member? {IsCreatingNewMember})",
referenceName, _system.Id, isNewMember
);
var patch = MemberPatch.FromJSON(member);
try
{
patch.AssertIsValid();
}
catch (FieldTooLongError e)
{
throw new ImportException($"Field {e.Name} in member {referenceName} is too long ({e.ActualLength} > {e.MaxLength}).");
}
catch (ValidationError e)
{
throw new ImportException($"Field {e.Message} in member {referenceName} is invalid.");
}
MemberId? memberId = found;
if (isNewMember)
{
var newMember = await _repo.CreateMember(_conn, _system.Id, patch.Name.Value, _tx);
memberId = newMember.Id;
}
_knownIdentifiers[id] = memberId.Value;
await _repo.UpdateMember(_conn, memberId.Value, patch, _tx);
}
private async Task ImportSwitches(JArray switches)
{
var existingSwitches = (await _conn.QueryAsync<PKSwitch>("select * from switches where system = @System", new {System = _system.Id})).ToList();
var existingTimestamps = existingSwitches.Select(sw => sw.Timestamp).ToImmutableHashSet();
var lastSwitchId = existingSwitches.Count != 0 ? existingSwitches.Select(sw => sw.Id).Max() : (SwitchId?) null;
if (switches.Count > 10000)
throw new ImportException($"Too many switches present in import file.");
// Import switch definitions
var importedSwitches = new Dictionary<Instant, JArray>();
await using (var importer = _conn.BeginBinaryImport("copy switches (system, timestamp) from stdin (format binary)"))
{
foreach (var sw in switches)
{
var timestampString = sw.Value<string>("timestamp");
var timestamp = DateTimeFormats.TimestampExportFormat.Parse(timestampString);
if (!timestamp.Success) throw new ImportException($"Switch timestamp {timestampString} is not an valid timestamp.");
// Don't import duplicate switches
if (existingTimestamps.Contains(timestamp.Value)) continue;
// Otherwise, write to importer
await importer.StartRowAsync();
await importer.WriteAsync(_system.Id.Value, NpgsqlDbType.Integer);
await importer.WriteAsync(timestamp.Value, NpgsqlDbType.Timestamp);
var members = sw.Value<JArray>("members");
if (members.Count > Limits.MaxSwitchMemberCount)
throw new ImportException($"Switch with timestamp {timestampString} contains too many members ({members.Count} > 100).");
// Note that we've imported a switch with this timestamp
importedSwitches[timestamp.Value] = sw.Value<JArray>("members");
}
// Commit the import
await importer.CompleteAsync();
}
// Now, fetch all the switches we just added (so, now we get their IDs too)
// IDs are sequential, so any ID in this system, with a switch ID > the last max, will be one we just added
var justAddedSwitches = await _conn.QueryAsync<PKSwitch>(
"select * from switches where system = @System and id > @LastSwitchId",
new {System = _system.Id, LastSwitchId = lastSwitchId?.Value ?? -1});
// Lastly, import the switch members
await using (var importer = _conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)"))
{
foreach (var justAddedSwitch in justAddedSwitches)
{
if (!importedSwitches.TryGetValue(justAddedSwitch.Timestamp, out var switchMembers))
throw new Exception($"Found 'just-added' switch (by ID) with timestamp {justAddedSwitch.Timestamp}, but this did not correspond to a timestamp we just added a switch entry of! :/");
// We still assume timestamps are unique and non-duplicate, so:
foreach (var memberIdentifier in switchMembers)
{
if (!_knownIdentifiers.TryGetValue((string) memberIdentifier, out var memberId))
throw new Exception($"Attempted to import switch with member identifier {memberIdentifier} but could not find an entry in the id map for this! :/");
await importer.StartRowAsync();
await importer.WriteAsync(justAddedSwitch.Id.Value, NpgsqlDbType.Integer);
await importer.WriteAsync(memberId.Value, NpgsqlDbType.Integer);
}
}
await importer.CompleteAsync();
}
}
}
}

View File

@ -0,0 +1,122 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Newtonsoft.Json.Linq;
using NodaTime;
namespace PluralKit.Core
{
public partial class BulkImporter
{
private async Task<ImportResultNew> ImportTupperbox(JObject importFile)
{
var tuppers = importFile.Value<JArray>("tuppers");
var newMembers = tuppers.Count(t => !_existingMemberNames.TryGetValue("name", out var memberId));
await AssertLimitNotReached(newMembers);
string lastSetTag = null;
bool multipleTags = false;
bool hasGroup = false;
foreach (JObject tupper in tuppers)
(lastSetTag, multipleTags, hasGroup) = await ImportTupper(tupper, lastSetTag);
if (multipleTags || hasGroup)
{
var issueStr =
$"{Emojis.Warn} The following potential issues were detected converting your Tupperbox input file:";
if (hasGroup)
issueStr +=
"\n- PluralKit does not support member groups. Members will be imported without groups.";
if (multipleTags)
issueStr +=
"\n- PluralKit does not support per-member system tags. Since you had multiple members with distinct tags, those tags will be applied to the members' *display names*/nicknames instead.";
await _confirmFunc(issueStr);
_result.Success = true;
}
return _result;
}
private async Task<(string lastSetTag, bool multipleTags, bool hasGroup)> ImportTupper(JObject tupper, string lastSetTag)
{
if (!tupper.ContainsKey("name") || tupper["name"].Type == JTokenType.Null)
throw new ImportException("Field 'name' cannot be null.");
var hasGroup = tupper.ContainsKey("group_id") && tupper["group_id"].Type != JTokenType.Null;
var multipleTags = false;
var name = tupper.Value<string>("name");
var patch = new MemberPatch();
patch.Name = name;
if (tupper.ContainsKey("avatar_url") && tupper["avatar_url"].Type != JTokenType.Null) patch.AvatarUrl = tupper.Value<string>("avatar_url").NullIfEmpty();
if (tupper.ContainsKey("brackets"))
{
var brackets = tupper.Value<JArray>("brackets");
if (brackets.Count % 2 != 0)
throw new ImportException($"Field 'brackets' in tupper {name} is invalid.");
var tags = new List<ProxyTag>();
for (var i = 0; i < brackets.Count / 2; i++)
tags.Add(new ProxyTag((string) brackets[i * 2], (string) brackets[i * 2 + 1]));
patch.ProxyTags = tags.ToArray();
}
// todo: && if is new member
if (tupper.ContainsKey("posts")) patch.MessageCount = tupper.Value<int>("posts");
if (tupper.ContainsKey("show_brackets")) patch.KeepProxy = tupper.Value<bool>("show_brackets");
if (tupper.ContainsKey("birthday") && tupper["birthday"].Type != JTokenType.Null)
{
var parsed = DateTimeFormats.TimestampExportFormat.Parse(tupper.Value<string>("birthday"));
if (!parsed.Success)
throw new ImportException($"Field 'birthday' in tupper {name} is invalid.");
patch.Birthday = LocalDate.FromDateTime(parsed.Value.ToDateTimeUtc());
}
if (tupper.ContainsKey("description")) patch.Description = tupper.Value<string>("description");
if (tupper.ContainsKey("tag") && tupper["tag"].Type != JTokenType.Null)
{
var tag = tupper.Value<string>("tag");
if (tag != lastSetTag)
{
lastSetTag = tag;
multipleTags = true;
}
patch.DisplayName = $"{name} {tag}";
}
var isNewMember = false;
if (!_existingMemberNames.TryGetValue(name, out var memberId))
{
var newMember = await _repo.CreateMember(_conn, _system.Id, name, _tx);
memberId = newMember.Id;
isNewMember = true;
_result.Added++;
}
else
_result.Modified++;
_logger.Debug("Importing member with identifier {FileId} to system {System} (is creating new member? {IsCreatingNewMember})",
name, _system.Id, isNewMember);
try
{
patch.AssertIsValid();
}
catch (FieldTooLongError e)
{
throw new ImportException($"Field {e.Name} in tupper {name} is too long ({e.ActualLength} > {e.MaxLength}).");
}
catch (ValidationError e)
{
throw new ImportException($"Field {e.Message} in tupper {name} is invalid.");
}
await _repo.UpdateMember(_conn, memberId, patch, _tx);
return (lastSetTag, multipleTags, hasGroup);
}
}
}

View File

@ -1,19 +0,0 @@
using System;
namespace PluralKit.Core
{
internal static class JsonUtils
{
public static string BoundsCheckField(this string input, int maxLength, string nameInError)
{
if (input != null && input.Length > maxLength)
throw new JsonModelParseError($"{nameInError} too long ({input.Length} > {maxLength}).");
return input;
}
}
public class JsonModelParseError: Exception
{
public JsonModelParseError(string message): base(message) { }
}
}