API patch improvements

- add PatchObject.CheckIsValid
- use transaction when creating member, as to not create a member if the
patch is invalid
- return edited system in `PATCH /s` endpoint
This commit is contained in:
spiral 2021-04-21 22:57:19 +01:00
parent a2d2036851
commit b34ed5c4c0
No known key found for this signature in database
GPG Key ID: A6059F0CA0E1BD31
10 changed files with 103 additions and 22 deletions

View File

@ -1,3 +1,4 @@
using System;
using System.Threading.Tasks; using System.Threading.Tasks;
using Dapper; using Dapper;
@ -54,18 +55,28 @@ namespace PluralKit.API
if (memberCount >= memberLimit) if (memberCount >= memberLimit)
return BadRequest($"Member limit reached ({memberCount} / {memberLimit})."); return BadRequest($"Member limit reached ({memberCount} / {memberLimit}).");
var member = await _repo.CreateMember(conn, systemId, properties.Value<string>("name")); await using var tx = await conn.BeginTransactionAsync();
var member = await _repo.CreateMember(conn, systemId, properties.Value<string>("name"), transaction: tx);
MemberPatch patch; MemberPatch patch;
try try
{ {
patch = JsonModelExt.ToMemberPatch(properties); patch = JsonModelExt.ToMemberPatch(properties);
patch.CheckIsValid();
} }
catch (JsonModelParseError e) catch (JsonModelParseError e)
{ {
await tx.RollbackAsync();
return BadRequest(e.Message); return BadRequest(e.Message);
} }
catch (InvalidPatchException e)
{
await tx.RollbackAsync();
return BadRequest($"Request field '{e.Message}' is invalid.");
}
member = await _repo.UpdateMember(conn, member.Id, patch); member = await _repo.UpdateMember(conn, member.Id, patch, transaction: tx);
await tx.CommitAsync();
return Ok(member.ToJson(User.ContextFor(member))); return Ok(member.ToJson(User.ContextFor(member)));
} }
@ -85,11 +96,16 @@ namespace PluralKit.API
try try
{ {
patch = JsonModelExt.ToMemberPatch(changes); patch = JsonModelExt.ToMemberPatch(changes);
patch.CheckIsValid();
} }
catch (JsonModelParseError e) catch (JsonModelParseError e)
{ {
return BadRequest(e.Message); return BadRequest(e.Message);
} }
catch (InvalidPatchException e)
{
return BadRequest($"Request field is invalid: {e.Message}");
}
var newMember = await _repo.UpdateMember(conn, member.Id, patch); var newMember = await _repo.UpdateMember(conn, member.Id, patch);
return Ok(newMember.ToJson(User.ContextFor(newMember))); return Ok(newMember.ToJson(User.ContextFor(newMember)));

View File

@ -1,3 +1,4 @@
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -140,13 +141,18 @@ namespace PluralKit.API
try try
{ {
patch = JsonModelExt.ToSystemPatch(changes); patch = JsonModelExt.ToSystemPatch(changes);
patch.CheckIsValid();
} }
catch (JsonModelParseError e) catch (JsonModelParseError e)
{ {
return BadRequest(e.Message); return BadRequest(e.Message);
} }
catch (InvalidPatchException e)
{
return BadRequest($"Request field '{e.Message}' is invalid.");
}
await _repo.UpdateSystem(conn, system!.Id, patch); system = await _repo.UpdateSystem(conn, system!.Id, patch);
return Ok(system.ToJson(User.ContextFor(system))); return Ok(system.ToJson(User.ContextFor(system)));
} }

View File

@ -25,17 +25,8 @@ namespace PluralKit.Bot {
using (var client = new HttpClient()) using (var client = new HttpClient())
{ {
Uri uri; if (!PluralKit.Core.MiscUtils.TryMatchUri(url, out var uri))
try
{
uri = new Uri(url);
if (!uri.IsAbsoluteUri || (uri.Scheme != "http" && uri.Scheme != "https"))
throw Errors.InvalidUrl(url);
}
catch (UriFormatException)
{
throw Errors.InvalidUrl(url); throw Errors.InvalidUrl(url);
}
var response = await client.GetAsync(uri); var response = await client.GetAsync(uri);
if (!response.IsSuccessStatusCode) // Check status code if (!response.IsSuccessStatusCode) // Check status code

View File

@ -1,5 +1,6 @@
#nullable enable #nullable enable
using System.Collections.Generic; using System.Collections.Generic;
using System.Data;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -30,22 +31,22 @@ namespace PluralKit.Core
return conn.QuerySingleOrDefaultAsync<int>(query.ToString(), new {Id = id, PrivacyFilter = privacyFilter}); return conn.QuerySingleOrDefaultAsync<int>(query.ToString(), new {Id = id, PrivacyFilter = privacyFilter});
} }
public async Task<PKGroup> CreateGroup(IPKConnection conn, SystemId system, string name) public async Task<PKGroup> CreateGroup(IPKConnection conn, SystemId system, string name, IDbTransaction? transaction = null)
{ {
var group = await conn.QueryFirstAsync<PKGroup>( var group = await conn.QueryFirstAsync<PKGroup>(
"insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *", "insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *",
new {System = system, Name = name}); new {System = system, Name = name}, transaction);
_logger.Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name); _logger.Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name);
return group; return group;
} }
public Task<PKGroup> UpdateGroup(IPKConnection conn, GroupId id, GroupPatch patch) public Task<PKGroup> UpdateGroup(IPKConnection conn, GroupId id, GroupPatch patch, IDbTransaction? transaction = null)
{ {
_logger.Information("Updated {GroupId}: {@GroupPatch}", id, patch); _logger.Information("Updated {GroupId}: {@GroupPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id")) var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id"))
.WithConstant("id", id) .WithConstant("id", id)
.Build("returning *"); .Build("returning *");
return conn.QueryFirstAsync<PKGroup>(query, pms); return conn.QueryFirstAsync<PKGroup>(query, pms, transaction);
} }
public Task DeleteGroup(IPKConnection conn, GroupId group) public Task DeleteGroup(IPKConnection conn, GroupId group)

View File

@ -1,4 +1,5 @@
#nullable enable #nullable enable
using System.Data;
using System.Threading.Tasks; using System.Threading.Tasks;
using Dapper; using Dapper;
@ -19,23 +20,23 @@ namespace PluralKit.Core
public Task<PKMember?> GetMemberByDisplayName(IPKConnection conn, SystemId system, string name) => public Task<PKMember?> GetMemberByDisplayName(IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system }); conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system });
public async Task<PKMember> CreateMember(IPKConnection conn, SystemId id, string memberName) public async Task<PKMember> CreateMember(IPKConnection conn, SystemId id, string memberName, IDbTransaction? transaction = null)
{ {
var member = await conn.QueryFirstAsync<PKMember>( var member = await conn.QueryFirstAsync<PKMember>(
"insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *", "insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *",
new {SystemId = id, Name = memberName}); new {SystemId = id, Name = memberName}, transaction);
_logger.Information("Created {MemberId} in {SystemId}: {MemberName}", _logger.Information("Created {MemberId} in {SystemId}: {MemberName}",
member.Id, id, memberName); member.Id, id, memberName);
return member; return member;
} }
public Task<PKMember> UpdateMember(IPKConnection conn, MemberId id, MemberPatch patch) public Task<PKMember> UpdateMember(IPKConnection conn, MemberId id, MemberPatch patch, IDbTransaction? transaction = null)
{ {
_logger.Information("Updated {MemberId}: {@MemberPatch}", id, patch); _logger.Information("Updated {MemberId}: {@MemberPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id")) var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id"))
.WithConstant("id", id) .WithConstant("id", id)
.Build("returning *"); .Build("returning *");
return conn.QueryFirstAsync<PKMember>(query, pms); return conn.QueryFirstAsync<PKMember>(query, pms, transaction);
} }
public Task DeleteMember(IPKConnection conn, MemberId id) public Task DeleteMember(IPKConnection conn, MemberId id)

View File

@ -1,4 +1,6 @@
#nullable enable #nullable enable
using System.Text.RegularExpressions;
namespace PluralKit.Core namespace PluralKit.Core
{ {
public class GroupPatch: PatchObject public class GroupPatch: PatchObject
@ -24,5 +26,14 @@ namespace PluralKit.Core
.With("icon_privacy", IconPrivacy) .With("icon_privacy", IconPrivacy)
.With("list_privacy", ListPrivacy) .With("list_privacy", ListPrivacy)
.With("visibility", Visibility); .With("visibility", Visibility);
public new void CheckIsValid()
{
if (Icon.Value != null && !MiscUtils.TryMatchUri(Icon.Value, out var avatarUri))
throw new InvalidPatchException("avatar_url");
if (Color.Value != null && (!Regex.IsMatch(Color.Value, "^[0-9a-fA-F]{6}$")))
throw new InvalidPatchException("color");
}
} }
} }

View File

@ -1,4 +1,5 @@
#nullable enable #nullable enable
using System.Text.RegularExpressions;
using NodaTime; using NodaTime;
@ -44,5 +45,14 @@ namespace PluralKit.Core
.With("birthday_privacy", BirthdayPrivacy) .With("birthday_privacy", BirthdayPrivacy)
.With("avatar_privacy", AvatarPrivacy) .With("avatar_privacy", AvatarPrivacy)
.With("metadata_privacy", MetadataPrivacy); .With("metadata_privacy", MetadataPrivacy);
public new void CheckIsValid()
{
if (AvatarUrl.Value != null && !MiscUtils.TryMatchUri(AvatarUrl.Value, out var avatarUri))
throw new InvalidPatchException("avatar_url");
if (Color.Value != null && (!Regex.IsMatch(Color.Value, "^[0-9a-fA-F]{6}$")))
throw new InvalidPatchException("color");
}
} }
} }

View File

@ -1,7 +1,17 @@
namespace PluralKit.Core using System;
namespace PluralKit.Core
{ {
public class InvalidPatchException : Exception
{
public InvalidPatchException(string message) : base(message) {}
}
public abstract class PatchObject public abstract class PatchObject
{ {
public abstract UpdateQueryBuilder Apply(UpdateQueryBuilder b); public abstract UpdateQueryBuilder Apply(UpdateQueryBuilder b);
public void CheckIsValid() {}
} }
} }

View File

@ -1,4 +1,6 @@
#nullable enable #nullable enable
using System.Text.RegularExpressions;
namespace PluralKit.Core namespace PluralKit.Core
{ {
public class SystemPatch: PatchObject public class SystemPatch: PatchObject
@ -33,5 +35,14 @@ namespace PluralKit.Core
.With("front_history_privacy", FrontHistoryPrivacy) .With("front_history_privacy", FrontHistoryPrivacy)
.With("pings_enabled", PingsEnabled) .With("pings_enabled", PingsEnabled)
.With("latch_timeout", LatchTimeout); .With("latch_timeout", LatchTimeout);
public new void CheckIsValid()
{
if (AvatarUrl.Value != null && !MiscUtils.TryMatchUri(AvatarUrl.Value, out var avatarUri))
throw new InvalidPatchException("avatar_url");
if (Color.Value != null && (!Regex.IsMatch(Color.Value, "^[0-9a-fA-F]{6}$")))
throw new InvalidPatchException("color");
}
} }
} }

View File

@ -0,0 +1,24 @@
using System;
namespace PluralKit.Core
{
public static class MiscUtils
{
public static bool TryMatchUri(string input, out Uri uri)
{
try
{
uri = new Uri(input);
if (!uri.IsAbsoluteUri || (uri.Scheme != "http" && uri.Scheme != "https"))
return false;
}
catch (UriFormatException)
{
uri = null;
return false;
}
return true;
}
}
}