diff --git a/PluralKit.API/Controllers/v1/MemberController.cs b/PluralKit.API/Controllers/v1/MemberController.cs index 539893c2..56a5ae43 100644 --- a/PluralKit.API/Controllers/v1/MemberController.cs +++ b/PluralKit.API/Controllers/v1/MemberController.cs @@ -1,3 +1,4 @@ +using System; using System.Threading.Tasks; using Dapper; @@ -54,18 +55,28 @@ namespace PluralKit.API if (memberCount >= memberLimit) return BadRequest($"Member limit reached ({memberCount} / {memberLimit})."); - var member = await _repo.CreateMember(conn, systemId, properties.Value("name")); + await using var tx = await conn.BeginTransactionAsync(); + var member = await _repo.CreateMember(conn, systemId, properties.Value("name"), transaction: tx); + MemberPatch patch; try { patch = JsonModelExt.ToMemberPatch(properties); + patch.CheckIsValid(); } catch (JsonModelParseError e) { + await tx.RollbackAsync(); return BadRequest(e.Message); } + catch (InvalidPatchException e) + { + await tx.RollbackAsync(); + return BadRequest($"Request field '{e.Message}' is invalid."); + } - member = await _repo.UpdateMember(conn, member.Id, patch); + member = await _repo.UpdateMember(conn, member.Id, patch, transaction: tx); + await tx.CommitAsync(); return Ok(member.ToJson(User.ContextFor(member))); } @@ -85,11 +96,16 @@ namespace PluralKit.API try { patch = JsonModelExt.ToMemberPatch(changes); + patch.CheckIsValid(); } catch (JsonModelParseError e) { return BadRequest(e.Message); } + catch (InvalidPatchException e) + { + return BadRequest($"Request field '{e.Message}' is invalid."); + } var newMember = await _repo.UpdateMember(conn, member.Id, patch); return Ok(newMember.ToJson(User.ContextFor(newMember))); diff --git a/PluralKit.API/Controllers/v1/SystemController.cs b/PluralKit.API/Controllers/v1/SystemController.cs index 0dce14e1..b0bf7003 100644 --- a/PluralKit.API/Controllers/v1/SystemController.cs +++ b/PluralKit.API/Controllers/v1/SystemController.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -140,13 +141,18 @@ namespace PluralKit.API try { patch = JsonModelExt.ToSystemPatch(changes); + patch.CheckIsValid(); } catch (JsonModelParseError e) { return BadRequest(e.Message); } + catch (InvalidPatchException e) + { + return BadRequest($"Request field '{e.Message}' is invalid."); + } - await _repo.UpdateSystem(conn, system!.Id, patch); + system = await _repo.UpdateSystem(conn, system!.Id, patch); return Ok(system.ToJson(User.ContextFor(system))); } diff --git a/PluralKit.Bot/Utils/AvatarUtils.cs b/PluralKit.Bot/Utils/AvatarUtils.cs index df4881ab..790c6bd3 100644 --- a/PluralKit.Bot/Utils/AvatarUtils.cs +++ b/PluralKit.Bot/Utils/AvatarUtils.cs @@ -25,17 +25,8 @@ namespace PluralKit.Bot { using (var client = new HttpClient()) { - Uri uri; - try - { - uri = new Uri(url); - if (!uri.IsAbsoluteUri || (uri.Scheme != "http" && uri.Scheme != "https")) - throw Errors.InvalidUrl(url); - } - catch (UriFormatException) - { + if (!PluralKit.Core.MiscUtils.TryMatchUri(url, out var uri)) throw Errors.InvalidUrl(url); - } var response = await client.GetAsync(uri); if (!response.IsSuccessStatusCode) // Check status code diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Group.cs b/PluralKit.Core/Database/Repository/ModelRepository.Group.cs index 7a88dba7..2afb43eb 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Group.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Group.cs @@ -1,5 +1,6 @@ #nullable enable using System.Collections.Generic; +using System.Data; using System.Linq; using System.Text; using System.Threading.Tasks; @@ -30,22 +31,22 @@ namespace PluralKit.Core return conn.QuerySingleOrDefaultAsync(query.ToString(), new {Id = id, PrivacyFilter = privacyFilter}); } - public async Task CreateGroup(IPKConnection conn, SystemId system, string name) + public async Task CreateGroup(IPKConnection conn, SystemId system, string name, IDbTransaction? transaction = null) { var group = await conn.QueryFirstAsync( "insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *", - new {System = system, Name = name}); + new {System = system, Name = name}, transaction); _logger.Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name); return group; } - public Task UpdateGroup(IPKConnection conn, GroupId id, GroupPatch patch) + public Task UpdateGroup(IPKConnection conn, GroupId id, GroupPatch patch, IDbTransaction? transaction = null) { _logger.Information("Updated {GroupId}: {@GroupPatch}", id, patch); var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id")) .WithConstant("id", id) .Build("returning *"); - return conn.QueryFirstAsync(query, pms); + return conn.QueryFirstAsync(query, pms, transaction); } public Task DeleteGroup(IPKConnection conn, GroupId group) diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Member.cs b/PluralKit.Core/Database/Repository/ModelRepository.Member.cs index e2e25888..c7dfa34c 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Member.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Member.cs @@ -1,4 +1,5 @@ #nullable enable +using System.Data; using System.Threading.Tasks; using Dapper; @@ -19,23 +20,23 @@ namespace PluralKit.Core public Task GetMemberByDisplayName(IPKConnection conn, SystemId system, string name) => conn.QueryFirstOrDefaultAsync("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system }); - public async Task CreateMember(IPKConnection conn, SystemId id, string memberName) + public async Task CreateMember(IPKConnection conn, SystemId id, string memberName, IDbTransaction? transaction = null) { var member = await conn.QueryFirstAsync( "insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *", - new {SystemId = id, Name = memberName}); + new {SystemId = id, Name = memberName}, transaction); _logger.Information("Created {MemberId} in {SystemId}: {MemberName}", member.Id, id, memberName); return member; } - public Task UpdateMember(IPKConnection conn, MemberId id, MemberPatch patch) + public Task UpdateMember(IPKConnection conn, MemberId id, MemberPatch patch, IDbTransaction? transaction = null) { _logger.Information("Updated {MemberId}: {@MemberPatch}", id, patch); var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id")) .WithConstant("id", id) .Build("returning *"); - return conn.QueryFirstAsync(query, pms); + return conn.QueryFirstAsync(query, pms, transaction); } public Task DeleteMember(IPKConnection conn, MemberId id) diff --git a/PluralKit.Core/Models/Patch/GroupPatch.cs b/PluralKit.Core/Models/Patch/GroupPatch.cs index ee624df8..933a3376 100644 --- a/PluralKit.Core/Models/Patch/GroupPatch.cs +++ b/PluralKit.Core/Models/Patch/GroupPatch.cs @@ -1,4 +1,6 @@ #nullable enable +using System.Text.RegularExpressions; + namespace PluralKit.Core { public class GroupPatch: PatchObject @@ -24,5 +26,14 @@ namespace PluralKit.Core .With("icon_privacy", IconPrivacy) .With("list_privacy", ListPrivacy) .With("visibility", Visibility); + + public new void CheckIsValid() + { + if (Icon.Value != null && !MiscUtils.TryMatchUri(Icon.Value, out var avatarUri)) + throw new InvalidPatchException("avatar_url"); + if (Color.Value != null && (!Regex.IsMatch(Color.Value, "^[0-9a-fA-F]{6}$"))) + throw new InvalidPatchException("color"); + } + } } \ No newline at end of file diff --git a/PluralKit.Core/Models/Patch/MemberPatch.cs b/PluralKit.Core/Models/Patch/MemberPatch.cs index 645e2b1a..f07b3074 100644 --- a/PluralKit.Core/Models/Patch/MemberPatch.cs +++ b/PluralKit.Core/Models/Patch/MemberPatch.cs @@ -1,4 +1,5 @@ #nullable enable +using System.Text.RegularExpressions; using NodaTime; @@ -44,5 +45,14 @@ namespace PluralKit.Core .With("birthday_privacy", BirthdayPrivacy) .With("avatar_privacy", AvatarPrivacy) .With("metadata_privacy", MetadataPrivacy); + + public new void CheckIsValid() + { + if (AvatarUrl.Value != null && !MiscUtils.TryMatchUri(AvatarUrl.Value, out var avatarUri)) + throw new InvalidPatchException("avatar_url"); + if (Color.Value != null && (!Regex.IsMatch(Color.Value, "^[0-9a-fA-F]{6}$"))) + throw new InvalidPatchException("color"); + } + } } \ No newline at end of file diff --git a/PluralKit.Core/Models/Patch/PatchObject.cs b/PluralKit.Core/Models/Patch/PatchObject.cs index 476007ef..03fcaf81 100644 --- a/PluralKit.Core/Models/Patch/PatchObject.cs +++ b/PluralKit.Core/Models/Patch/PatchObject.cs @@ -1,7 +1,17 @@ -namespace PluralKit.Core +using System; + +namespace PluralKit.Core { + + public class InvalidPatchException : Exception + { + public InvalidPatchException(string message) : base(message) {} + } + public abstract class PatchObject { public abstract UpdateQueryBuilder Apply(UpdateQueryBuilder b); + + public void CheckIsValid() {} } } \ No newline at end of file diff --git a/PluralKit.Core/Models/Patch/SystemPatch.cs b/PluralKit.Core/Models/Patch/SystemPatch.cs index 0f787749..dbbfa1f5 100644 --- a/PluralKit.Core/Models/Patch/SystemPatch.cs +++ b/PluralKit.Core/Models/Patch/SystemPatch.cs @@ -1,4 +1,6 @@ #nullable enable +using System.Text.RegularExpressions; + namespace PluralKit.Core { public class SystemPatch: PatchObject @@ -33,5 +35,14 @@ namespace PluralKit.Core .With("front_history_privacy", FrontHistoryPrivacy) .With("pings_enabled", PingsEnabled) .With("latch_timeout", LatchTimeout); + + public new void CheckIsValid() + { + if (AvatarUrl.Value != null && !MiscUtils.TryMatchUri(AvatarUrl.Value, out var avatarUri)) + throw new InvalidPatchException("avatar_url"); + if (Color.Value != null && (!Regex.IsMatch(Color.Value, "^[0-9a-fA-F]{6}$"))) + throw new InvalidPatchException("color"); + } + } } \ No newline at end of file diff --git a/PluralKit.Core/Utils/MiscUtils.cs b/PluralKit.Core/Utils/MiscUtils.cs new file mode 100644 index 00000000..bd21281a --- /dev/null +++ b/PluralKit.Core/Utils/MiscUtils.cs @@ -0,0 +1,24 @@ +using System; + +namespace PluralKit.Core +{ + public static class MiscUtils + { + public static bool TryMatchUri(string input, out Uri uri) + { + try + { + uri = new Uri(input); + if (!uri.IsAbsoluteUri || (uri.Scheme != "http" && uri.Scheme != "https")) + return false; + } + catch (UriFormatException) + { + uri = null; + return false; + } + + return true; + } + } +} \ No newline at end of file