Migrate API to ASP.NET Core Auth services + refactor

This commit is contained in:
Ske 2020-06-16 01:15:59 +02:00
parent 7fde54050a
commit 627f544ee8
25 changed files with 289 additions and 141 deletions

View File

@ -0,0 +1,32 @@
using System;
using System.Security.Claims;
using PluralKit.Core;
namespace PluralKit.API
{
public static class AuthExt
{
public static SystemId CurrentSystem(this ClaimsPrincipal user)
{
var claim = user.FindFirst(PKClaims.SystemId);
if (claim == null) throw new ArgumentException("User is unauthorized");
if (int.TryParse(claim.Value, out var id))
return new SystemId(id);
throw new ArgumentException("User has non-integer system ID claim");
}
public static LookupContext ContextFor(this ClaimsPrincipal user, PKSystem system)
{
if (!user.Identity.IsAuthenticated) return LookupContext.API;
return system.Id == user.CurrentSystem() ? LookupContext.ByOwner : LookupContext.API;
}
public static LookupContext ContextFor(this ClaimsPrincipal user, PKMember member)
{
if (!user.Identity.IsAuthenticated) return LookupContext.API;
return member.System == user.CurrentSystem() ? LookupContext.ByOwner : LookupContext.API;
}
}
}

View File

@ -0,0 +1,7 @@
namespace PluralKit.API
{
public class PKClaims
{
public const string SystemId = "PluralKit:SystemId";
}
}

View File

@ -0,0 +1,49 @@
using System;
using System.Linq;
using System.Security.Claims;
using System.Text.Encodings.Web;
using System.Threading.Tasks;
using Dapper;
using Microsoft.AspNetCore.Authentication;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using PluralKit.Core;
namespace PluralKit.API
{
public class SystemTokenAuthenticationHandler: AuthenticationHandler<SystemTokenAuthenticationHandler.Opts>
{
private readonly IDatabase _db;
public SystemTokenAuthenticationHandler(IOptionsMonitor<Opts> options, ILoggerFactory logger, UrlEncoder encoder, ISystemClock clock, IDatabase db): base(options, logger, encoder, clock)
{
_db = db;
}
protected override async Task<AuthenticateResult> HandleAuthenticateAsync()
{
if (!Request.Headers.ContainsKey("Authorization"))
return AuthenticateResult.NoResult();
var token = Request.Headers["Authorization"].FirstOrDefault();
var systemId = await _db.Execute(c => c.QuerySingleOrDefaultAsync<SystemId?>("select id from systems where token = @token", new { token }));
if (systemId == null) return AuthenticateResult.Fail("Invalid system token");
var claims = new[] {new Claim(PKClaims.SystemId, systemId.Value.Value.ToString())};
var identity = new ClaimsIdentity(claims, Scheme.Name);
var principal = new ClaimsPrincipal(identity);
var ticket = new AuthenticationTicket(principal, Scheme.Name);
ticket.Properties.IsPersistent = false;
ticket.Properties.AllowRefresh = false;
return AuthenticateResult.Success(ticket);
}
public class Opts: AuthenticationSchemeOptions
{
}
}
}

View File

@ -0,0 +1,19 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using PluralKit.Core;
namespace PluralKit.API
{
public class MemberOwnerHandler: AuthorizationHandler<OwnSystemRequirement, PKMember> {
protected override Task HandleRequirementAsync(AuthorizationHandlerContext context,
OwnSystemRequirement requirement, PKMember resource)
{
if (!context.User.Identity.IsAuthenticated) return Task.CompletedTask;
if (resource.System == context.User.CurrentSystem())
context.Succeed(requirement);
return Task.CompletedTask;
}
}
}

View File

@ -0,0 +1,21 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using PluralKit.Core;
namespace PluralKit.API
{
public class MemberPrivacyHandler: AuthorizationHandler<PrivacyRequirement<PKMember>, PKMember>
{
protected override Task HandleRequirementAsync(AuthorizationHandlerContext context,
PrivacyRequirement<PKMember> requirement, PKMember resource)
{
var level = requirement.Mapper(resource);
var ctx = context.User.ContextFor(resource);
if (level.CanAccess(ctx))
context.Succeed(requirement);
return Task.CompletedTask;
}
}
}

View File

@ -0,0 +1,6 @@
using Microsoft.AspNetCore.Authorization;
namespace PluralKit.API
{
public class OwnSystemRequirement: IAuthorizationRequirement { }
}

View File

@ -0,0 +1,18 @@
using System;
using Microsoft.AspNetCore.Authorization;
using PluralKit.Core;
namespace PluralKit.API
{
public class PrivacyRequirement<T>: IAuthorizationRequirement
{
public readonly Func<T, PrivacyLevel> Mapper;
public PrivacyRequirement(Func<T, PrivacyLevel> mapper)
{
Mapper = mapper;
}
}
}

View File

@ -0,0 +1,20 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using PluralKit.Core;
namespace PluralKit.API
{
public class SystemOwnerHandler: AuthorizationHandler<OwnSystemRequirement, PKSystem>
{
protected override Task HandleRequirementAsync(AuthorizationHandlerContext context,
OwnSystemRequirement requirement, PKSystem resource)
{
if (!context.User.Identity.IsAuthenticated) return Task.CompletedTask;
if (resource.Id == context.User.CurrentSystem())
context.Succeed(requirement);
return Task.CompletedTask;
}
}
}

View File

@ -0,0 +1,21 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using PluralKit.Core;
namespace PluralKit.API
{
public class SystemPrivacyHandler: AuthorizationHandler<PrivacyRequirement<PKSystem>, PKSystem>
{
protected override Task HandleRequirementAsync(AuthorizationHandlerContext context,
PrivacyRequirement<PKSystem> requirement, PKSystem resource)
{
var level = requirement.Mapper(resource);
var ctx = context.User.ContextFor(resource);
if (level.CanAccess(ctx))
context.Succeed(requirement);
return Task.CompletedTask;
}
}
}

View File

@ -15,12 +15,10 @@ namespace PluralKit.API
public class AccountController: ControllerBase public class AccountController: ControllerBase
{ {
private IDataStore _data; private IDataStore _data;
private TokenAuthService _auth;
public AccountController(IDataStore data, TokenAuthService auth) public AccountController(IDataStore data)
{ {
_data = data; _data = data;
_auth = auth;
} }
[HttpGet("{aid}")] [HttpGet("{aid}")]
@ -29,7 +27,7 @@ namespace PluralKit.API
var system = await _data.GetSystemByAccount(aid); var system = await _data.GetSystemByAccount(aid);
if (system == null) return NotFound("Account not found."); if (system == null) return NotFound("Account not found.");
return Ok(system.ToJson(_auth.ContextFor(system))); return Ok(system.ToJson(User.ContextFor(system)));
} }
} }
} }

View File

@ -1,5 +1,6 @@
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Newtonsoft.Json.Linq; using Newtonsoft.Json.Linq;
@ -15,9 +16,9 @@ namespace PluralKit.API
public class MemberController: ControllerBase public class MemberController: ControllerBase
{ {
private IDataStore _data; private IDataStore _data;
private TokenAuthService _auth; private IAuthorizationService _auth;
public MemberController(IDataStore data, TokenAuthService auth) public MemberController(IDataStore data, IAuthorizationService auth)
{ {
_data = data; _data = data;
_auth = auth; _auth = auth;
@ -29,14 +30,14 @@ namespace PluralKit.API
var member = await _data.GetMemberByHid(hid); var member = await _data.GetMemberByHid(hid);
if (member == null) return NotFound("Member not found."); if (member == null) return NotFound("Member not found.");
return Ok(member.ToJson(_auth.ContextFor(member))); return Ok(member.ToJson(User.ContextFor(member)));
} }
[HttpPost] [HttpPost]
[RequiresSystem] [Authorize]
public async Task<ActionResult<JObject>> PostMember([FromBody] JObject properties) public async Task<ActionResult<JObject>> PostMember([FromBody] JObject properties)
{ {
var system = _auth.CurrentSystem; var system = User.CurrentSystem();
if (!properties.ContainsKey("name")) if (!properties.ContainsKey("name"))
return BadRequest("Member name must be specified."); return BadRequest("Member name must be specified.");
@ -57,17 +58,18 @@ namespace PluralKit.API
} }
await _data.SaveMember(member); await _data.SaveMember(member);
return Ok(member.ToJson(_auth.ContextFor(member))); return Ok(member.ToJson(User.ContextFor(member)));
} }
[HttpPatch("{hid}")] [HttpPatch("{hid}")]
[RequiresSystem] [Authorize]
public async Task<ActionResult<JObject>> PatchMember(string hid, [FromBody] JObject changes) public async Task<ActionResult<JObject>> PatchMember(string hid, [FromBody] JObject changes)
{ {
var member = await _data.GetMemberByHid(hid); var member = await _data.GetMemberByHid(hid);
if (member == null) return NotFound("Member not found."); if (member == null) return NotFound("Member not found.");
if (member.System != _auth.CurrentSystem.Id) return Unauthorized($"Member '{hid}' is not part of your system."); var res = await _auth.AuthorizeAsync(User, member, "EditMember");
if (!res.Succeeded) return Unauthorized($"Member '{hid}' is not part of your system.");
try try
{ {
@ -79,17 +81,18 @@ namespace PluralKit.API
} }
await _data.SaveMember(member); await _data.SaveMember(member);
return Ok(member.ToJson(_auth.ContextFor(member))); return Ok(member.ToJson(User.ContextFor(member)));
} }
[HttpDelete("{hid}")] [HttpDelete("{hid}")]
[RequiresSystem] [Authorize]
public async Task<ActionResult> DeleteMember(string hid) public async Task<ActionResult> DeleteMember(string hid)
{ {
var member = await _data.GetMemberByHid(hid); var member = await _data.GetMemberByHid(hid);
if (member == null) return NotFound("Member not found."); if (member == null) return NotFound("Member not found.");
if (member.System != _auth.CurrentSystem.Id) return Unauthorized($"Member '{hid}' is not part of your system."); var res = await _auth.AuthorizeAsync(User, member, "EditMember");
if (!res.Succeeded) return Unauthorized($"Member '{hid}' is not part of your system.");
await _data.DeleteMember(member); await _data.DeleteMember(member);
return Ok(); return Ok();

View File

@ -30,12 +30,10 @@ namespace PluralKit.API
public class MessageController: ControllerBase public class MessageController: ControllerBase
{ {
private IDataStore _data; private IDataStore _data;
private TokenAuthService _auth;
public MessageController(IDataStore _data, TokenAuthService auth) public MessageController(IDataStore _data)
{ {
this._data = _data; this._data = _data;
_auth = auth;
} }
[HttpGet("{mid}")] [HttpGet("{mid}")]
@ -50,8 +48,8 @@ namespace PluralKit.API
Id = msg.Message.Mid.ToString(), Id = msg.Message.Mid.ToString(),
Channel = msg.Message.Channel.ToString(), Channel = msg.Message.Channel.ToString(),
Sender = msg.Message.Sender.ToString(), Sender = msg.Message.Sender.ToString(),
Member = msg.Member.ToJson(_auth.ContextFor(msg.System)), Member = msg.Member.ToJson(User.ContextFor(msg.System)),
System = msg.System.ToJson(_auth.ContextFor(msg.System)), System = msg.System.ToJson(User.ContextFor(msg.System)),
Original = msg.Message.OriginalMid?.ToString() Original = msg.Message.OriginalMid?.ToString()
}; };
} }

View File

@ -4,6 +4,7 @@ using System.Threading.Tasks;
using Dapper; using Dapper;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
@ -41,9 +42,9 @@ namespace PluralKit.API
{ {
private IDataStore _data; private IDataStore _data;
private IDatabase _conn; private IDatabase _conn;
private TokenAuthService _auth; private IAuthorizationService _auth;
public SystemController(IDataStore data, IDatabase conn, TokenAuthService auth) public SystemController(IDataStore data, IDatabase conn, IAuthorizationService auth)
{ {
_data = data; _data = data;
_conn = conn; _conn = conn;
@ -51,10 +52,11 @@ namespace PluralKit.API
} }
[HttpGet] [HttpGet]
[RequiresSystem] [Authorize]
public Task<ActionResult<JObject>> GetOwnSystem() public async Task<ActionResult<JObject>> GetOwnSystem()
{ {
return Task.FromResult<ActionResult<JObject>>(Ok(_auth.CurrentSystem.ToJson(_auth.ContextFor(_auth.CurrentSystem)))); var system = await _conn.Execute(c => c.QuerySystem(User.CurrentSystem()));
return system.ToJson(User.ContextFor(system));
} }
[HttpGet("{hid}")] [HttpGet("{hid}")]
@ -62,7 +64,7 @@ namespace PluralKit.API
{ {
var system = await _data.GetSystemByHid(hid); var system = await _data.GetSystemByHid(hid);
if (system == null) return NotFound("System not found."); if (system == null) return NotFound("System not found.");
return Ok(system.ToJson(_auth.ContextFor(system))); return Ok(system.ToJson(User.ContextFor(system)));
} }
[HttpGet("{hid}/members")] [HttpGet("{hid}/members")]
@ -71,13 +73,13 @@ namespace PluralKit.API
var system = await _data.GetSystemByHid(hid); var system = await _data.GetSystemByHid(hid);
if (system == null) return NotFound("System not found."); if (system == null) return NotFound("System not found.");
if (!system.MemberListPrivacy.CanAccess(_auth.ContextFor(system))) if (!system.MemberListPrivacy.CanAccess(User.ContextFor(system)))
return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view member list."); return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view member list.");
var members = _data.GetSystemMembers(system); var members = _data.GetSystemMembers(system);
return Ok(await members return Ok(await members
.Where(m => m.MemberPrivacy.CanAccess(_auth.ContextFor(system))) .Where(m => m.MemberPrivacy.CanAccess(User.ContextFor(system)))
.Select(m => m.ToJson(_auth.ContextFor(system))) .Select(m => m.ToJson(User.ContextFor(system)))
.ToListAsync()); .ToListAsync());
} }
@ -89,8 +91,8 @@ namespace PluralKit.API
var system = await _data.GetSystemByHid(hid); var system = await _data.GetSystemByHid(hid);
if (system == null) return NotFound("System not found."); if (system == null) return NotFound("System not found.");
if (!system.FrontHistoryPrivacy.CanAccess(_auth.ContextFor(system))) var auth = await _auth.AuthorizeAsync(User, system, "ViewFrontHistory");
return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view front history."); if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view front history.");
using (var conn = await _conn.Obtain()) using (var conn = await _conn.Obtain())
{ {
@ -112,26 +114,25 @@ namespace PluralKit.API
var system = await _data.GetSystemByHid(hid); var system = await _data.GetSystemByHid(hid);
if (system == null) return NotFound("System not found."); if (system == null) return NotFound("System not found.");
if (!system.FrontPrivacy.CanAccess(_auth.ContextFor(system))) var auth = await _auth.AuthorizeAsync(User, system, "ViewFront");
return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view fronter."); if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view fronter.");
var sw = await _data.GetLatestSwitch(system); var sw = await _data.GetLatestSwitch(system.Id);
if (sw == null) return NotFound("System has no registered switches."); if (sw == null) return NotFound("System has no registered switches.");
var members = _data.GetSwitchMembers(sw); var members = _data.GetSwitchMembers(sw);
return Ok(new FrontersReturn return Ok(new FrontersReturn
{ {
Timestamp = sw.Timestamp, Timestamp = sw.Timestamp,
Members = await members.Select(m => m.ToJson(_auth.ContextFor(system))).ToListAsync() Members = await members.Select(m => m.ToJson(User.ContextFor(system))).ToListAsync()
}); });
} }
[HttpPatch] [HttpPatch]
[RequiresSystem] [Authorize]
public async Task<ActionResult<JObject>> EditSystem([FromBody] JObject changes) public async Task<ActionResult<JObject>> EditSystem([FromBody] JObject changes)
{ {
var system = _auth.CurrentSystem; var system = await _conn.Execute(c => c.QuerySystem(User.CurrentSystem()));
try try
{ {
system.ApplyJson(changes); system.ApplyJson(changes);
@ -142,18 +143,18 @@ namespace PluralKit.API
} }
await _data.SaveSystem(system); await _data.SaveSystem(system);
return Ok(system.ToJson(_auth.ContextFor(system))); return Ok(system.ToJson(User.ContextFor(system)));
} }
[HttpPost("switches")] [HttpPost("switches")]
[RequiresSystem] [Authorize]
public async Task<IActionResult> PostSwitch([FromBody] PostSwitchParams param) public async Task<IActionResult> PostSwitch([FromBody] PostSwitchParams param)
{ {
if (param.Members.Distinct().Count() != param.Members.Count()) if (param.Members.Distinct().Count() != param.Members.Count)
return BadRequest("Duplicate members in member list."); return BadRequest("Duplicate members in member list.");
// We get the current switch, if it exists // We get the current switch, if it exists
var latestSwitch = await _data.GetLatestSwitch(_auth.CurrentSystem); var latestSwitch = await _data.GetLatestSwitch(User.CurrentSystem());
if (latestSwitch != null) if (latestSwitch != null)
{ {
var latestSwitchMembers = _data.GetSwitchMembers(latestSwitch); var latestSwitchMembers = _data.GetSwitchMembers(latestSwitch);
@ -169,7 +170,7 @@ namespace PluralKit.API
membersList = (await conn.QueryAsync<PKMember>("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList(); membersList = (await conn.QueryAsync<PKMember>("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList();
foreach (var member in membersList) foreach (var member in membersList)
if (member.System != _auth.CurrentSystem.Id) if (member.System != User.CurrentSystem())
return BadRequest($"Cannot switch to member '{member.Hid}' not in system."); return BadRequest($"Cannot switch to member '{member.Hid}' not in system.");
// membersList is in DB order, and we want it in actual input order // membersList is in DB order, and we want it in actual input order
@ -185,7 +186,7 @@ namespace PluralKit.API
} }
// Finally, log the switch (yay!) // Finally, log the switch (yay!)
await _data.AddSwitch(_auth.CurrentSystem, membersInOrder); await _data.AddSwitch(User.CurrentSystem(), membersInOrder);
return NoContent(); return NoContent();
} }
} }

View File

@ -6,8 +6,6 @@ namespace PluralKit.API
{ {
protected override void Load(ContainerBuilder builder) protected override void Load(ContainerBuilder builder)
{ {
// Lifetime scope so the service, RequiresSystem, and handler itself all get the same value
builder.RegisterType<TokenAuthService>().AsSelf().InstancePerLifetimeScope();
} }
} }
} }

View File

@ -1,41 +0,0 @@
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using Microsoft.AspNetCore.Http;
using PluralKit.Core;
namespace PluralKit.API
{
public class TokenAuthService: IMiddleware
{
public PKSystem CurrentSystem { get; set; }
private readonly IDatabase _db;
public TokenAuthService(IDatabase db)
{
_db = db;
}
public async Task InvokeAsync(HttpContext context, RequestDelegate next)
{
var token = context.Request.Headers["Authorization"].FirstOrDefault();
if (token != null)
{
CurrentSystem = await _db.Execute(c => c.QueryFirstOrDefaultAsync("select * from systems where token = @token", new { token }));
}
await next.Invoke(context);
CurrentSystem = null;
}
public LookupContext ContextFor(PKSystem system) =>
system.Id == CurrentSystem?.Id ? LookupContext.ByOwner : LookupContext.API;
public LookupContext ContextFor(PKMember member) =>
member.System == CurrentSystem?.Id ? LookupContext.ByOwner : LookupContext.API;
}
}

View File

@ -4,6 +4,8 @@ using System.Reflection;
using Autofac; using Autofac;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
@ -13,6 +15,7 @@ using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Hosting;
using Microsoft.OpenApi.Models; using Microsoft.OpenApi.Models;
using PluralKit.API;
using PluralKit.Core; using PluralKit.Core;
namespace PluralKit.API namespace PluralKit.API
@ -30,6 +33,23 @@ namespace PluralKit.API
public void ConfigureServices(IServiceCollection services) public void ConfigureServices(IServiceCollection services)
{ {
services.AddCors(); services.AddCors();
services.AddAuthentication("SystemToken")
.AddScheme<SystemTokenAuthenticationHandler.Opts, SystemTokenAuthenticationHandler>("SystemToken", null);
services.AddAuthorization(options =>
{
options.AddPolicy("EditSystem", p => p.RequireAuthenticatedUser().AddRequirements(new OwnSystemRequirement()));
options.AddPolicy("EditMember", p => p.RequireAuthenticatedUser().AddRequirements(new OwnSystemRequirement()));
options.AddPolicy("ViewMembers", p => p.AddRequirements(new PrivacyRequirement<PKSystem>(s => s.MemberListPrivacy)));
options.AddPolicy("ViewFront", p => p.AddRequirements(new PrivacyRequirement<PKSystem>(s => s.FrontPrivacy)));
options.AddPolicy("ViewFrontHistory", p => p.AddRequirements(new PrivacyRequirement<PKSystem>(s => s.FrontHistoryPrivacy)));
});
services.AddSingleton<IAuthenticationHandler, SystemTokenAuthenticationHandler>();
services.AddSingleton<IAuthorizationHandler, MemberOwnerHandler>();
services.AddSingleton<IAuthorizationHandler, SystemOwnerHandler>();
services.AddSingleton<IAuthorizationHandler, SystemPrivacyHandler>();
services.AddControllers() services.AddControllers()
.SetCompatibilityVersion(CompatibilityVersion.Latest) .SetCompatibilityVersion(CompatibilityVersion.Latest)
.AddNewtonsoftJson(); // sorry MS, this just does *more* .AddNewtonsoftJson(); // sorry MS, this just does *more*
@ -105,9 +125,10 @@ namespace PluralKit.API
//app.UseHttpsRedirection(); //app.UseHttpsRedirection();
app.UseCors(opts => opts.AllowAnyMethod().AllowAnyOrigin().WithHeaders("Content-Type", "Authorization")); app.UseCors(opts => opts.AllowAnyMethod().AllowAnyOrigin().WithHeaders("Content-Type", "Authorization"));
app.UseMiddleware<TokenAuthService>();
app.UseRouting(); app.UseRouting();
app.UseAuthentication();
app.UseAuthorization();
app.UseEndpoints(endpoints => endpoints.MapControllers()); app.UseEndpoints(endpoints => endpoints.MapControllers());
} }
} }

View File

@ -1,23 +0,0 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Filters;
using Microsoft.Extensions.DependencyInjection;
namespace PluralKit.API
{
public class RequiresSystemAttribute: ActionFilterAttribute
{
public override async Task OnActionExecutionAsync(ActionExecutingContext context, ActionExecutionDelegate next)
{
var auth = context.HttpContext.RequestServices.GetRequiredService<TokenAuthService>();
if (auth.CurrentSystem == null)
{
context.Result = new UnauthorizedObjectResult("Invalid or missing token in Authorization header.");
return;
}
await base.OnActionExecutionAsync(context, next);
}
}
}

View File

@ -33,12 +33,12 @@ namespace PluralKit.Bot
} }
// Enforce per-system member limit // Enforce per-system member limit
var memberCount = await _data.GetSystemMemberCount(ctx.System, true); var memberCount = await _data.GetSystemMemberCount(ctx.System.Id, true);
if (memberCount >= Limits.MaxMemberCount) if (memberCount >= Limits.MaxMemberCount)
throw Errors.MemberLimitReachedError; throw Errors.MemberLimitReachedError;
// Create the member // Create the member
var member = await _data.CreateMember(ctx.System, memberName); var member = await _data.CreateMember(ctx.System.Id, memberName);
memberCount++; memberCount++;
// Send confirmation and space hint // Send confirmation and space hint

View File

@ -56,7 +56,7 @@ namespace PluralKit.Bot
if (members.Select(m => m.Id).Distinct().Count() != members.Count) throw Errors.DuplicateSwitchMembers; if (members.Select(m => m.Id).Distinct().Count() != members.Count) throw Errors.DuplicateSwitchMembers;
// Find the last switch and its members if applicable // Find the last switch and its members if applicable
var lastSwitch = await _data.GetLatestSwitch(ctx.System); var lastSwitch = await _data.GetLatestSwitch(ctx.System.Id);
if (lastSwitch != null) if (lastSwitch != null)
{ {
var lastSwitchMembers = _data.GetSwitchMembers(lastSwitch); var lastSwitchMembers = _data.GetSwitchMembers(lastSwitch);
@ -65,7 +65,7 @@ namespace PluralKit.Bot
throw Errors.SameSwitch(members); throw Errors.SameSwitch(members);
} }
await _data.AddSwitch(ctx.System, members); await _data.AddSwitch(ctx.System.Id, members);
if (members.Count == 0) if (members.Count == 0)
await ctx.Reply($"{Emojis.Success} Switch-out registered."); await ctx.Reply($"{Emojis.Success} Switch-out registered.");
@ -87,7 +87,7 @@ namespace PluralKit.Bot
if (time.ToInstant() > SystemClock.Instance.GetCurrentInstant()) throw Errors.SwitchTimeInFuture; if (time.ToInstant() > SystemClock.Instance.GetCurrentInstant()) throw Errors.SwitchTimeInFuture;
// Fetch the last two switches for the system to do bounds checking on // Fetch the last two switches for the system to do bounds checking on
var lastTwoSwitches = await _data.GetSwitches(ctx.System).Take(2).ToListAsync(); var lastTwoSwitches = await _data.GetSwitches(ctx.System.Id).Take(2).ToListAsync();
// If we don't have a switch to move, don't bother // If we don't have a switch to move, don't bother
if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches; if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches;
@ -133,7 +133,7 @@ namespace PluralKit.Bot
} }
// Fetch the last two switches for the system to do bounds checking on // Fetch the last two switches for the system to do bounds checking on
var lastTwoSwitches = await _data.GetSwitches(ctx.System).Take(2).ToListAsync(); var lastTwoSwitches = await _data.GetSwitches(ctx.System.Id).Take(2).ToListAsync();
if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches; if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches;
var lastSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[0]); var lastSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[0]);

View File

@ -36,7 +36,7 @@ namespace PluralKit.Bot
if (system == null) throw Errors.NoSystemError; if (system == null) throw Errors.NoSystemError;
ctx.CheckSystemPrivacy(system, system.FrontPrivacy); ctx.CheckSystemPrivacy(system, system.FrontPrivacy);
var sw = await _data.GetLatestSwitch(system); var sw = await _data.GetLatestSwitch(system.Id);
if (sw == null) throw Errors.NoRegisteredSwitches; if (sw == null) throw Errors.NoRegisteredSwitches;
await ctx.Reply(embed: await _embeds.CreateFronterEmbed(sw, system.Zone)); await ctx.Reply(embed: await _embeds.CreateFronterEmbed(sw, system.Zone));
@ -47,7 +47,7 @@ namespace PluralKit.Bot
if (system == null) throw Errors.NoSystemError; if (system == null) throw Errors.NoSystemError;
ctx.CheckSystemPrivacy(system, system.FrontHistoryPrivacy); ctx.CheckSystemPrivacy(system, system.FrontHistoryPrivacy);
var sws = _data.GetSwitches(system) var sws = _data.GetSwitches(system.Id)
.Scan(new FrontHistoryEntry(null, null), (lastEntry, newSwitch) => new FrontHistoryEntry(lastEntry.ThisSwitch?.Timestamp, newSwitch)); .Scan(new FrontHistoryEntry(null, null), (lastEntry, newSwitch) => new FrontHistoryEntry(lastEntry.ThisSwitch?.Timestamp, newSwitch));
var totalSwitches = await _data.GetSwitchCount(system); var totalSwitches = await _data.GetSwitchCount(system);
if (totalSwitches == 0) throw Errors.NoRegisteredSwitches; if (totalSwitches == 0) throw Errors.NoRegisteredSwitches;

View File

@ -32,14 +32,14 @@ namespace PluralKit.Bot {
// Fetch/render info for all accounts simultaneously // Fetch/render info for all accounts simultaneously
var users = await Task.WhenAll(accounts.Select(async uid => (await client.GetUserAsync(uid))?.NameAndMention() ?? $"(deleted account {uid})")); var users = await Task.WhenAll(accounts.Select(async uid => (await client.GetUserAsync(uid))?.NameAndMention() ?? $"(deleted account {uid})"));
var memberCount = await _data.GetSystemMemberCount(system, false); var memberCount = await _data.GetSystemMemberCount(system.Id, false);
var eb = new DiscordEmbedBuilder() var eb = new DiscordEmbedBuilder()
.WithColor(DiscordUtils.Gray) .WithColor(DiscordUtils.Gray)
.WithTitle(system.Name ?? null) .WithTitle(system.Name ?? null)
.WithThumbnailUrl(system.AvatarUrl) .WithThumbnailUrl(system.AvatarUrl)
.WithFooter($"System ID: {system.Hid} | Created on {DateTimeFormats.ZonedDateTimeFormat.Format(system.Created.InZone(system.Zone))}"); .WithFooter($"System ID: {system.Hid} | Created on {DateTimeFormats.ZonedDateTimeFormat.Format(system.Created.InZone(system.Zone))}");
var latestSwitch = await _data.GetLatestSwitch(system); var latestSwitch = await _data.GetLatestSwitch(system.Id);
if (latestSwitch != null && system.FrontPrivacy.CanAccess(ctx)) if (latestSwitch != null && system.FrontPrivacy.CanAccess(ctx))
{ {
var switchMembers = await _data.GetSwitchMembers(latestSwitch).ToListAsync(); var switchMembers = await _data.GetSwitchMembers(latestSwitch).ToListAsync();

View File

@ -132,7 +132,7 @@ namespace PluralKit.Core
{ {
// Tally up the members that didn't exist before, and check member count on import // Tally up the members that didn't exist before, and check member count on import
// If creating the unmatched members would put us over the member limit, abort before creating any members // If creating the unmatched members would put us over the member limit, abort before creating any members
var memberCountBefore = await _data.GetSystemMemberCount(system, true); var memberCountBefore = await _data.GetSystemMemberCount(system.Id, true);
var membersToAdd = data.Members.Count(m => imp.IsNewMember(m.Id, m.Name)); var membersToAdd = data.Members.Count(m => imp.IsNewMember(m.Id, m.Name));
if (memberCountBefore + membersToAdd > Limits.MaxMemberCount) if (memberCountBefore + membersToAdd > Limits.MaxMemberCount)
{ {

View File

@ -79,7 +79,7 @@ namespace PluralKit.Core {
/// Gets the member count of a system. /// Gets the member count of a system.
/// </summary> /// </summary>
/// <param name="includePrivate">Whether the returned count should include private members.</param> /// <param name="includePrivate">Whether the returned count should include private members.</param>
Task<int> GetSystemMemberCount(PKSystem system, bool includePrivate); Task<int> GetSystemMemberCount(SystemId system, bool includePrivate);
/// <summary> /// <summary>
/// Gets a list of members with proxy tags that conflict with the given tags. /// Gets a list of members with proxy tags that conflict with the given tags.
@ -162,7 +162,7 @@ namespace PluralKit.Core {
/// <param name="system">The system in which to create the member.</param> /// <param name="system">The system in which to create the member.</param>
/// <param name="name">The name of the member to create.</param> /// <param name="name">The name of the member to create.</param>
/// <returns>The created system model.</returns> /// <returns>The created system model.</returns>
Task<PKMember> CreateMember(PKSystem system, string name); Task<PKMember> CreateMember(SystemId system, string name);
/// <summary> /// <summary>
/// Saves the information within the given <see cref="PKMember"/> struct to the data store. /// Saves the information within the given <see cref="PKMember"/> struct to the data store.
@ -213,7 +213,7 @@ namespace PluralKit.Core {
/// Gets switches from a system. /// Gets switches from a system.
/// </summary> /// </summary>
/// <returns>An enumerable of the *count* latest switches in the system, in latest-first order. May contain fewer elements than requested.</returns> /// <returns>An enumerable of the *count* latest switches in the system, in latest-first order. May contain fewer elements than requested.</returns>
IAsyncEnumerable<PKSwitch> GetSwitches(PKSystem system); IAsyncEnumerable<PKSwitch> GetSwitches(SystemId system);
/// <summary> /// <summary>
/// Gets the total amount of switches in a given system. /// Gets the total amount of switches in a given system.
@ -223,7 +223,7 @@ namespace PluralKit.Core {
/// <summary> /// <summary>
/// Gets the latest (temporally; closest to now) switch of a given system. /// Gets the latest (temporally; closest to now) switch of a given system.
/// </summary> /// </summary>
Task<PKSwitch> GetLatestSwitch(PKSystem system); Task<PKSwitch> GetLatestSwitch(SystemId system);
/// <summary> /// <summary>
/// Gets the members a given switch consists of. /// Gets the members a given switch consists of.
@ -261,7 +261,7 @@ namespace PluralKit.Core {
/// Registers a switch with the given members in the given system. /// Registers a switch with the given members in the given system.
/// </summary> /// </summary>
/// <exception>Throws an exception (TODO: which?) if any of the members are not in the given system.</exception> /// <exception>Throws an exception (TODO: which?) if any of the members are not in the given system.</exception>
Task AddSwitch(PKSystem system, IEnumerable<PKMember> switchMembers); Task AddSwitch(SystemId system, IEnumerable<PKMember> switchMembers);
/// <summary> /// <summary>
/// Updates the timestamp of a given switch. /// Updates the timestamp of a given switch.

View File

@ -113,11 +113,11 @@ namespace PluralKit.Core {
await conn.ExecuteAsync("delete from switches where system = @Id", system); await conn.ExecuteAsync("delete from switches where system = @Id", system);
} }
public async Task<PKMember> CreateMember(PKSystem system, string name) { public async Task<PKMember> CreateMember(SystemId system, string name) {
PKMember member; PKMember member;
using (var conn = await _conn.Obtain()) using (var conn = await _conn.Obtain())
member = await conn.QuerySingleAsync<PKMember>("insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *", new { member = await conn.QuerySingleAsync<PKMember>("insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *", new {
SystemID = system.Id, SystemID = system,
Name = name Name = name
}); });
@ -162,13 +162,13 @@ namespace PluralKit.Core {
_logger.Information("Deleted member {@Member}", member); _logger.Information("Deleted member {@Member}", member);
} }
public async Task<int> GetSystemMemberCount(PKSystem system, bool includePrivate) public async Task<int> GetSystemMemberCount(SystemId id, bool includePrivate)
{ {
var query = "select count(*) from members where system = @Id"; var query = "select count(*) from members where system = @id";
if (!includePrivate) query += " and member_privacy = 1"; // 1 = public if (!includePrivate) query += " and member_privacy = 1"; // 1 = public
using (var conn = await _conn.Obtain()) using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<int>(query, system); return await conn.ExecuteScalarAsync<int>(query, new { id });
} }
public async Task<ulong> GetTotalMembers() public async Task<ulong> GetTotalMembers()
@ -220,7 +220,7 @@ namespace PluralKit.Core {
} }
} }
public async Task AddSwitch(PKSystem system, IEnumerable<PKMember> members) public async Task AddSwitch(SystemId system, IEnumerable<PKMember> members)
{ {
// Use a transaction here since we're doing multiple executed commands in one // Use a transaction here since we're doing multiple executed commands in one
await using var conn = await _conn.Obtain(); await using var conn = await _conn.Obtain();
@ -228,7 +228,7 @@ namespace PluralKit.Core {
// First, we insert the switch itself // First, we insert the switch itself
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *", var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
new {System = system.Id}); new {System = system});
// Then we insert each member in the switch in the switch_members table // Then we insert each member in the switch in the switch_members table
// TODO: can we parallelize this or send it in bulk somehow? // TODO: can we parallelize this or send it in bulk somehow?
@ -242,16 +242,16 @@ namespace PluralKit.Core {
// Finally we commit the tx, since the using block will otherwise rollback it // Finally we commit the tx, since the using block will otherwise rollback it
await tx.CommitAsync(); await tx.CommitAsync();
_logger.Information("Registered switch {Switch} in system {System} with members {@Members}", sw.Id, system.Id, members.Select(m => m.Id)); _logger.Information("Registered switch {Switch} in system {System} with members {@Members}", sw.Id, system, members.Select(m => m.Id));
} }
public IAsyncEnumerable<PKSwitch> GetSwitches(PKSystem system) public IAsyncEnumerable<PKSwitch> GetSwitches(SystemId system)
{ {
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list // TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
// (maybe when we get caching in?) // (maybe when we get caching in?)
return _conn.QueryStreamAsync<PKSwitch>( return _conn.QueryStreamAsync<PKSwitch>(
"select * from switches where system = @System order by timestamp desc", "select * from switches where system = @System order by timestamp desc",
new {System = system.Id}); new {System = system});
} }
public async Task<int> GetSwitchCount(PKSystem system) public async Task<int> GetSwitchCount(PKSystem system)
@ -304,7 +304,7 @@ namespace PluralKit.Core {
new {Switch = sw.Id}); new {Switch = sw.Id});
} }
public async Task<PKSwitch> GetLatestSwitch(PKSystem system) => public async Task<PKSwitch> GetLatestSwitch(SystemId system) =>
await GetSwitches(system).FirstOrDefaultAsync(); await GetSwitches(system).FirstOrDefaultAsync();
public async Task MoveSwitch(PKSwitch sw, Instant time) public async Task MoveSwitch(PKSwitch sw, Instant time)