diff --git a/PluralKit.API/Authentication/AuthExt.cs b/PluralKit.API/Authentication/AuthExt.cs new file mode 100644 index 00000000..1d259eec --- /dev/null +++ b/PluralKit.API/Authentication/AuthExt.cs @@ -0,0 +1,32 @@ +using System; +using System.Security.Claims; + +using PluralKit.Core; + +namespace PluralKit.API +{ + public static class AuthExt + { + public static SystemId CurrentSystem(this ClaimsPrincipal user) + { + var claim = user.FindFirst(PKClaims.SystemId); + if (claim == null) throw new ArgumentException("User is unauthorized"); + + if (int.TryParse(claim.Value, out var id)) + return new SystemId(id); + throw new ArgumentException("User has non-integer system ID claim"); + } + + public static LookupContext ContextFor(this ClaimsPrincipal user, PKSystem system) + { + if (!user.Identity.IsAuthenticated) return LookupContext.API; + return system.Id == user.CurrentSystem() ? LookupContext.ByOwner : LookupContext.API; + } + + public static LookupContext ContextFor(this ClaimsPrincipal user, PKMember member) + { + if (!user.Identity.IsAuthenticated) return LookupContext.API; + return member.System == user.CurrentSystem() ? LookupContext.ByOwner : LookupContext.API; + } + } +} \ No newline at end of file diff --git a/PluralKit.API/Authentication/PKClaims.cs b/PluralKit.API/Authentication/PKClaims.cs new file mode 100644 index 00000000..2ab31e1a --- /dev/null +++ b/PluralKit.API/Authentication/PKClaims.cs @@ -0,0 +1,7 @@ +namespace PluralKit.API +{ + public class PKClaims + { + public const string SystemId = "PluralKit:SystemId"; + } +} \ No newline at end of file diff --git a/PluralKit.API/Authentication/SystemTokenAuthenticationHandler.cs b/PluralKit.API/Authentication/SystemTokenAuthenticationHandler.cs new file mode 100644 index 00000000..8080f60a --- /dev/null +++ b/PluralKit.API/Authentication/SystemTokenAuthenticationHandler.cs @@ -0,0 +1,49 @@ +using System; +using System.Linq; +using System.Security.Claims; +using System.Text.Encodings.Web; +using System.Threading.Tasks; + +using Dapper; + +using Microsoft.AspNetCore.Authentication; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +using PluralKit.Core; + +namespace PluralKit.API +{ + public class SystemTokenAuthenticationHandler: AuthenticationHandler + { + private readonly IDatabase _db; + + public SystemTokenAuthenticationHandler(IOptionsMonitor options, ILoggerFactory logger, UrlEncoder encoder, ISystemClock clock, IDatabase db): base(options, logger, encoder, clock) + { + _db = db; + } + + protected override async Task HandleAuthenticateAsync() + { + if (!Request.Headers.ContainsKey("Authorization")) + return AuthenticateResult.NoResult(); + + var token = Request.Headers["Authorization"].FirstOrDefault(); + var systemId = await _db.Execute(c => c.QuerySingleOrDefaultAsync("select id from systems where token = @token", new { token })); + if (systemId == null) return AuthenticateResult.Fail("Invalid system token"); + + var claims = new[] {new Claim(PKClaims.SystemId, systemId.Value.Value.ToString())}; + var identity = new ClaimsIdentity(claims, Scheme.Name); + var principal = new ClaimsPrincipal(identity); + var ticket = new AuthenticationTicket(principal, Scheme.Name); + ticket.Properties.IsPersistent = false; + ticket.Properties.AllowRefresh = false; + return AuthenticateResult.Success(ticket); + } + + public class Opts: AuthenticationSchemeOptions + { + + } + } +} \ No newline at end of file diff --git a/PluralKit.API/Authorization/MemberOwnerHandler.cs b/PluralKit.API/Authorization/MemberOwnerHandler.cs new file mode 100644 index 00000000..a212ad2c --- /dev/null +++ b/PluralKit.API/Authorization/MemberOwnerHandler.cs @@ -0,0 +1,19 @@ +using System.Threading.Tasks; + +using Microsoft.AspNetCore.Authorization; + +using PluralKit.Core; + +namespace PluralKit.API +{ + public class MemberOwnerHandler: AuthorizationHandler { + protected override Task HandleRequirementAsync(AuthorizationHandlerContext context, + OwnSystemRequirement requirement, PKMember resource) + { + if (!context.User.Identity.IsAuthenticated) return Task.CompletedTask; + if (resource.System == context.User.CurrentSystem()) + context.Succeed(requirement); + return Task.CompletedTask; + } + } +} \ No newline at end of file diff --git a/PluralKit.API/Authorization/MemberPrivacyHandler.cs b/PluralKit.API/Authorization/MemberPrivacyHandler.cs new file mode 100644 index 00000000..41437ed6 --- /dev/null +++ b/PluralKit.API/Authorization/MemberPrivacyHandler.cs @@ -0,0 +1,21 @@ +using System.Threading.Tasks; + +using Microsoft.AspNetCore.Authorization; + +using PluralKit.Core; + +namespace PluralKit.API +{ + public class MemberPrivacyHandler: AuthorizationHandler, PKMember> + { + protected override Task HandleRequirementAsync(AuthorizationHandlerContext context, + PrivacyRequirement requirement, PKMember resource) + { + var level = requirement.Mapper(resource); + var ctx = context.User.ContextFor(resource); + if (level.CanAccess(ctx)) + context.Succeed(requirement); + return Task.CompletedTask; + } + } +} \ No newline at end of file diff --git a/PluralKit.API/Authorization/OwnSystemRequirement.cs b/PluralKit.API/Authorization/OwnSystemRequirement.cs new file mode 100644 index 00000000..e292db75 --- /dev/null +++ b/PluralKit.API/Authorization/OwnSystemRequirement.cs @@ -0,0 +1,6 @@ +using Microsoft.AspNetCore.Authorization; + +namespace PluralKit.API +{ + public class OwnSystemRequirement: IAuthorizationRequirement { } +} \ No newline at end of file diff --git a/PluralKit.API/Authorization/PrivacyRequirement.cs b/PluralKit.API/Authorization/PrivacyRequirement.cs new file mode 100644 index 00000000..ef9312e1 --- /dev/null +++ b/PluralKit.API/Authorization/PrivacyRequirement.cs @@ -0,0 +1,18 @@ +using System; + +using Microsoft.AspNetCore.Authorization; + +using PluralKit.Core; + +namespace PluralKit.API +{ + public class PrivacyRequirement: IAuthorizationRequirement + { + public readonly Func Mapper; + + public PrivacyRequirement(Func mapper) + { + Mapper = mapper; + } + } +} \ No newline at end of file diff --git a/PluralKit.API/Authorization/SystemOwnerHandler.cs b/PluralKit.API/Authorization/SystemOwnerHandler.cs new file mode 100644 index 00000000..72cfede7 --- /dev/null +++ b/PluralKit.API/Authorization/SystemOwnerHandler.cs @@ -0,0 +1,20 @@ +using System.Threading.Tasks; + +using Microsoft.AspNetCore.Authorization; + +using PluralKit.Core; + +namespace PluralKit.API +{ + public class SystemOwnerHandler: AuthorizationHandler + { + protected override Task HandleRequirementAsync(AuthorizationHandlerContext context, + OwnSystemRequirement requirement, PKSystem resource) + { + if (!context.User.Identity.IsAuthenticated) return Task.CompletedTask; + if (resource.Id == context.User.CurrentSystem()) + context.Succeed(requirement); + return Task.CompletedTask; + } + } +} \ No newline at end of file diff --git a/PluralKit.API/Authorization/SystemPrivacyHandler.cs b/PluralKit.API/Authorization/SystemPrivacyHandler.cs new file mode 100644 index 00000000..469324b1 --- /dev/null +++ b/PluralKit.API/Authorization/SystemPrivacyHandler.cs @@ -0,0 +1,21 @@ +using System.Threading.Tasks; + +using Microsoft.AspNetCore.Authorization; + +using PluralKit.Core; + +namespace PluralKit.API +{ + public class SystemPrivacyHandler: AuthorizationHandler, PKSystem> + { + protected override Task HandleRequirementAsync(AuthorizationHandlerContext context, + PrivacyRequirement requirement, PKSystem resource) + { + var level = requirement.Mapper(resource); + var ctx = context.User.ContextFor(resource); + if (level.CanAccess(ctx)) + context.Succeed(requirement); + return Task.CompletedTask; + } + } +} \ No newline at end of file diff --git a/PluralKit.API/Controllers/AccountController.cs b/PluralKit.API/Controllers/v1/AccountController.cs similarity index 76% rename from PluralKit.API/Controllers/AccountController.cs rename to PluralKit.API/Controllers/v1/AccountController.cs index e1541b47..0ffec724 100644 --- a/PluralKit.API/Controllers/AccountController.cs +++ b/PluralKit.API/Controllers/v1/AccountController.cs @@ -15,12 +15,10 @@ namespace PluralKit.API public class AccountController: ControllerBase { private IDataStore _data; - private TokenAuthService _auth; - public AccountController(IDataStore data, TokenAuthService auth) + public AccountController(IDataStore data) { _data = data; - _auth = auth; } [HttpGet("{aid}")] @@ -29,7 +27,7 @@ namespace PluralKit.API var system = await _data.GetSystemByAccount(aid); if (system == null) return NotFound("Account not found."); - return Ok(system.ToJson(_auth.ContextFor(system))); + return Ok(system.ToJson(User.ContextFor(system))); } } } \ No newline at end of file diff --git a/PluralKit.API/Utils/JsonModelExt.cs b/PluralKit.API/Controllers/v1/JsonModelExt.cs similarity index 100% rename from PluralKit.API/Utils/JsonModelExt.cs rename to PluralKit.API/Controllers/v1/JsonModelExt.cs diff --git a/PluralKit.API/Controllers/MemberController.cs b/PluralKit.API/Controllers/v1/MemberController.cs similarity index 74% rename from PluralKit.API/Controllers/MemberController.cs rename to PluralKit.API/Controllers/v1/MemberController.cs index 72318a1e..40213aad 100644 --- a/PluralKit.API/Controllers/MemberController.cs +++ b/PluralKit.API/Controllers/v1/MemberController.cs @@ -1,5 +1,6 @@ using System.Threading.Tasks; +using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Newtonsoft.Json.Linq; @@ -15,9 +16,9 @@ namespace PluralKit.API public class MemberController: ControllerBase { private IDataStore _data; - private TokenAuthService _auth; + private IAuthorizationService _auth; - public MemberController(IDataStore data, TokenAuthService auth) + public MemberController(IDataStore data, IAuthorizationService auth) { _data = data; _auth = auth; @@ -29,15 +30,15 @@ namespace PluralKit.API var member = await _data.GetMemberByHid(hid); if (member == null) return NotFound("Member not found."); - return Ok(member.ToJson(_auth.ContextFor(member))); + return Ok(member.ToJson(User.ContextFor(member))); } [HttpPost] - [RequiresSystem] + [Authorize] public async Task> PostMember([FromBody] JObject properties) { - var system = _auth.CurrentSystem; - + var system = User.CurrentSystem(); + if (!properties.ContainsKey("name")) return BadRequest("Member name must be specified."); @@ -57,17 +58,18 @@ namespace PluralKit.API } await _data.SaveMember(member); - return Ok(member.ToJson(_auth.ContextFor(member))); + return Ok(member.ToJson(User.ContextFor(member))); } [HttpPatch("{hid}")] - [RequiresSystem] + [Authorize] public async Task> PatchMember(string hid, [FromBody] JObject changes) { var member = await _data.GetMemberByHid(hid); if (member == null) return NotFound("Member not found."); - - if (member.System != _auth.CurrentSystem.Id) return Unauthorized($"Member '{hid}' is not part of your system."); + + var res = await _auth.AuthorizeAsync(User, member, "EditMember"); + if (!res.Succeeded) return Unauthorized($"Member '{hid}' is not part of your system."); try { @@ -79,17 +81,18 @@ namespace PluralKit.API } await _data.SaveMember(member); - return Ok(member.ToJson(_auth.ContextFor(member))); + return Ok(member.ToJson(User.ContextFor(member))); } [HttpDelete("{hid}")] - [RequiresSystem] + [Authorize] public async Task DeleteMember(string hid) { var member = await _data.GetMemberByHid(hid); if (member == null) return NotFound("Member not found."); - if (member.System != _auth.CurrentSystem.Id) return Unauthorized($"Member '{hid}' is not part of your system."); + var res = await _auth.AuthorizeAsync(User, member, "EditMember"); + if (!res.Succeeded) return Unauthorized($"Member '{hid}' is not part of your system."); await _data.DeleteMember(member); return Ok(); diff --git a/PluralKit.API/Controllers/MessageController.cs b/PluralKit.API/Controllers/v1/MessageController.cs similarity index 84% rename from PluralKit.API/Controllers/MessageController.cs rename to PluralKit.API/Controllers/v1/MessageController.cs index f5a5b849..f3738742 100644 --- a/PluralKit.API/Controllers/MessageController.cs +++ b/PluralKit.API/Controllers/v1/MessageController.cs @@ -30,12 +30,10 @@ namespace PluralKit.API public class MessageController: ControllerBase { private IDataStore _data; - private TokenAuthService _auth; - public MessageController(IDataStore _data, TokenAuthService auth) + public MessageController(IDataStore _data) { this._data = _data; - _auth = auth; } [HttpGet("{mid}")] @@ -50,8 +48,8 @@ namespace PluralKit.API Id = msg.Message.Mid.ToString(), Channel = msg.Message.Channel.ToString(), Sender = msg.Message.Sender.ToString(), - Member = msg.Member.ToJson(_auth.ContextFor(msg.System)), - System = msg.System.ToJson(_auth.ContextFor(msg.System)), + Member = msg.Member.ToJson(User.ContextFor(msg.System)), + System = msg.System.ToJson(User.ContextFor(msg.System)), Original = msg.Message.OriginalMid?.ToString() }; } diff --git a/PluralKit.API/Controllers/SystemController.cs b/PluralKit.API/Controllers/v1/SystemController.cs similarity index 78% rename from PluralKit.API/Controllers/SystemController.cs rename to PluralKit.API/Controllers/v1/SystemController.cs index 4d3bd90a..8129b874 100644 --- a/PluralKit.API/Controllers/SystemController.cs +++ b/PluralKit.API/Controllers/v1/SystemController.cs @@ -4,6 +4,7 @@ using System.Threading.Tasks; using Dapper; +using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; @@ -41,9 +42,9 @@ namespace PluralKit.API { private IDataStore _data; private IDatabase _conn; - private TokenAuthService _auth; + private IAuthorizationService _auth; - public SystemController(IDataStore data, IDatabase conn, TokenAuthService auth) + public SystemController(IDataStore data, IDatabase conn, IAuthorizationService auth) { _data = data; _conn = conn; @@ -51,10 +52,11 @@ namespace PluralKit.API } [HttpGet] - [RequiresSystem] - public Task> GetOwnSystem() + [Authorize] + public async Task> GetOwnSystem() { - return Task.FromResult>(Ok(_auth.CurrentSystem.ToJson(_auth.ContextFor(_auth.CurrentSystem)))); + var system = await _conn.Execute(c => c.QuerySystem(User.CurrentSystem())); + return system.ToJson(User.ContextFor(system)); } [HttpGet("{hid}")] @@ -62,7 +64,7 @@ namespace PluralKit.API { var system = await _data.GetSystemByHid(hid); if (system == null) return NotFound("System not found."); - return Ok(system.ToJson(_auth.ContextFor(system))); + return Ok(system.ToJson(User.ContextFor(system))); } [HttpGet("{hid}/members")] @@ -71,13 +73,13 @@ namespace PluralKit.API var system = await _data.GetSystemByHid(hid); if (system == null) return NotFound("System not found."); - if (!system.MemberListPrivacy.CanAccess(_auth.ContextFor(system))) + if (!system.MemberListPrivacy.CanAccess(User.ContextFor(system))) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view member list."); var members = _data.GetSystemMembers(system); return Ok(await members - .Where(m => m.MemberPrivacy.CanAccess(_auth.ContextFor(system))) - .Select(m => m.ToJson(_auth.ContextFor(system))) + .Where(m => m.MemberPrivacy.CanAccess(User.ContextFor(system))) + .Select(m => m.ToJson(User.ContextFor(system))) .ToListAsync()); } @@ -88,9 +90,9 @@ namespace PluralKit.API var system = await _data.GetSystemByHid(hid); if (system == null) return NotFound("System not found."); - - if (!system.FrontHistoryPrivacy.CanAccess(_auth.ContextFor(system))) - return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view front history."); + + var auth = await _auth.AuthorizeAsync(User, system, "ViewFrontHistory"); + if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view front history."); using (var conn = await _conn.Obtain()) { @@ -112,26 +114,25 @@ namespace PluralKit.API var system = await _data.GetSystemByHid(hid); if (system == null) return NotFound("System not found."); - if (!system.FrontPrivacy.CanAccess(_auth.ContextFor(system))) - return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view fronter."); + var auth = await _auth.AuthorizeAsync(User, system, "ViewFront"); + if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view fronter."); - var sw = await _data.GetLatestSwitch(system); + var sw = await _data.GetLatestSwitch(system.Id); if (sw == null) return NotFound("System has no registered switches."); var members = _data.GetSwitchMembers(sw); return Ok(new FrontersReturn { Timestamp = sw.Timestamp, - Members = await members.Select(m => m.ToJson(_auth.ContextFor(system))).ToListAsync() + Members = await members.Select(m => m.ToJson(User.ContextFor(system))).ToListAsync() }); } [HttpPatch] - [RequiresSystem] + [Authorize] public async Task> EditSystem([FromBody] JObject changes) { - var system = _auth.CurrentSystem; - + var system = await _conn.Execute(c => c.QuerySystem(User.CurrentSystem())); try { system.ApplyJson(changes); @@ -142,18 +143,18 @@ namespace PluralKit.API } await _data.SaveSystem(system); - return Ok(system.ToJson(_auth.ContextFor(system))); + return Ok(system.ToJson(User.ContextFor(system))); } [HttpPost("switches")] - [RequiresSystem] + [Authorize] public async Task PostSwitch([FromBody] PostSwitchParams param) { - if (param.Members.Distinct().Count() != param.Members.Count()) + if (param.Members.Distinct().Count() != param.Members.Count) return BadRequest("Duplicate members in member list."); // We get the current switch, if it exists - var latestSwitch = await _data.GetLatestSwitch(_auth.CurrentSystem); + var latestSwitch = await _data.GetLatestSwitch(User.CurrentSystem()); if (latestSwitch != null) { var latestSwitchMembers = _data.GetSwitchMembers(latestSwitch); @@ -169,7 +170,7 @@ namespace PluralKit.API membersList = (await conn.QueryAsync("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList(); foreach (var member in membersList) - if (member.System != _auth.CurrentSystem.Id) + if (member.System != User.CurrentSystem()) return BadRequest($"Cannot switch to member '{member.Hid}' not in system."); // membersList is in DB order, and we want it in actual input order @@ -185,7 +186,7 @@ namespace PluralKit.API } // Finally, log the switch (yay!) - await _data.AddSwitch(_auth.CurrentSystem, membersInOrder); + await _data.AddSwitch(User.CurrentSystem(), membersInOrder); return NoContent(); } } diff --git a/PluralKit.API/Modules.cs b/PluralKit.API/Modules.cs index b844dba4..827e32b9 100644 --- a/PluralKit.API/Modules.cs +++ b/PluralKit.API/Modules.cs @@ -6,8 +6,6 @@ namespace PluralKit.API { protected override void Load(ContainerBuilder builder) { - // Lifetime scope so the service, RequiresSystem, and handler itself all get the same value - builder.RegisterType().AsSelf().InstancePerLifetimeScope(); } } } \ No newline at end of file diff --git a/PluralKit.API/Services/TokenAuthService.cs b/PluralKit.API/Services/TokenAuthService.cs deleted file mode 100644 index 36c6defe..00000000 --- a/PluralKit.API/Services/TokenAuthService.cs +++ /dev/null @@ -1,41 +0,0 @@ -using System.Linq; -using System.Threading.Tasks; - -using Dapper; - -using Microsoft.AspNetCore.Http; - -using PluralKit.Core; - -namespace PluralKit.API -{ - public class TokenAuthService: IMiddleware - { - public PKSystem CurrentSystem { get; set; } - - private readonly IDatabase _db; - - public TokenAuthService(IDatabase db) - { - _db = db; - } - - public async Task InvokeAsync(HttpContext context, RequestDelegate next) - { - var token = context.Request.Headers["Authorization"].FirstOrDefault(); - if (token != null) - { - CurrentSystem = await _db.Execute(c => c.QueryFirstOrDefaultAsync("select * from systems where token = @token", new { token })); - } - - await next.Invoke(context); - CurrentSystem = null; - } - - public LookupContext ContextFor(PKSystem system) => - system.Id == CurrentSystem?.Id ? LookupContext.ByOwner : LookupContext.API; - - public LookupContext ContextFor(PKMember member) => - member.System == CurrentSystem?.Id ? LookupContext.ByOwner : LookupContext.API; - } -} \ No newline at end of file diff --git a/PluralKit.API/Startup.cs b/PluralKit.API/Startup.cs index de975d8f..80f3fb3e 100644 --- a/PluralKit.API/Startup.cs +++ b/PluralKit.API/Startup.cs @@ -4,6 +4,8 @@ using System.Reflection; using Autofac; +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Mvc; @@ -13,6 +15,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.OpenApi.Models; +using PluralKit.API; using PluralKit.Core; namespace PluralKit.API @@ -30,6 +33,23 @@ namespace PluralKit.API public void ConfigureServices(IServiceCollection services) { services.AddCors(); + services.AddAuthentication("SystemToken") + .AddScheme("SystemToken", null); + + services.AddAuthorization(options => + { + options.AddPolicy("EditSystem", p => p.RequireAuthenticatedUser().AddRequirements(new OwnSystemRequirement())); + options.AddPolicy("EditMember", p => p.RequireAuthenticatedUser().AddRequirements(new OwnSystemRequirement())); + + options.AddPolicy("ViewMembers", p => p.AddRequirements(new PrivacyRequirement(s => s.MemberListPrivacy))); + options.AddPolicy("ViewFront", p => p.AddRequirements(new PrivacyRequirement(s => s.FrontPrivacy))); + options.AddPolicy("ViewFrontHistory", p => p.AddRequirements(new PrivacyRequirement(s => s.FrontHistoryPrivacy))); + }); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddControllers() .SetCompatibilityVersion(CompatibilityVersion.Latest) .AddNewtonsoftJson(); // sorry MS, this just does *more* @@ -105,9 +125,10 @@ namespace PluralKit.API //app.UseHttpsRedirection(); app.UseCors(opts => opts.AllowAnyMethod().AllowAnyOrigin().WithHeaders("Content-Type", "Authorization")); - app.UseMiddleware(); app.UseRouting(); + app.UseAuthentication(); + app.UseAuthorization(); app.UseEndpoints(endpoints => endpoints.MapControllers()); } } diff --git a/PluralKit.API/Utils/RequiresSystemAttribute.cs b/PluralKit.API/Utils/RequiresSystemAttribute.cs deleted file mode 100644 index 381cf29e..00000000 --- a/PluralKit.API/Utils/RequiresSystemAttribute.cs +++ /dev/null @@ -1,23 +0,0 @@ -using System.Threading.Tasks; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.Filters; -using Microsoft.Extensions.DependencyInjection; - -namespace PluralKit.API -{ - public class RequiresSystemAttribute: ActionFilterAttribute - { - - public override async Task OnActionExecutionAsync(ActionExecutingContext context, ActionExecutionDelegate next) - { - var auth = context.HttpContext.RequestServices.GetRequiredService(); - if (auth.CurrentSystem == null) - { - context.Result = new UnauthorizedObjectResult("Invalid or missing token in Authorization header."); - return; - } - - await base.OnActionExecutionAsync(context, next); - } - } -} \ No newline at end of file diff --git a/PluralKit.Bot/Commands/Member.cs b/PluralKit.Bot/Commands/Member.cs index 932906b9..0d56ec96 100644 --- a/PluralKit.Bot/Commands/Member.cs +++ b/PluralKit.Bot/Commands/Member.cs @@ -33,12 +33,12 @@ namespace PluralKit.Bot } // Enforce per-system member limit - var memberCount = await _data.GetSystemMemberCount(ctx.System, true); + var memberCount = await _data.GetSystemMemberCount(ctx.System.Id, true); if (memberCount >= Limits.MaxMemberCount) throw Errors.MemberLimitReachedError; // Create the member - var member = await _data.CreateMember(ctx.System, memberName); + var member = await _data.CreateMember(ctx.System.Id, memberName); memberCount++; // Send confirmation and space hint diff --git a/PluralKit.Bot/Commands/Switch.cs b/PluralKit.Bot/Commands/Switch.cs index e650db19..e0086a3e 100644 --- a/PluralKit.Bot/Commands/Switch.cs +++ b/PluralKit.Bot/Commands/Switch.cs @@ -56,7 +56,7 @@ namespace PluralKit.Bot if (members.Select(m => m.Id).Distinct().Count() != members.Count) throw Errors.DuplicateSwitchMembers; // Find the last switch and its members if applicable - var lastSwitch = await _data.GetLatestSwitch(ctx.System); + var lastSwitch = await _data.GetLatestSwitch(ctx.System.Id); if (lastSwitch != null) { var lastSwitchMembers = _data.GetSwitchMembers(lastSwitch); @@ -65,7 +65,7 @@ namespace PluralKit.Bot throw Errors.SameSwitch(members); } - await _data.AddSwitch(ctx.System, members); + await _data.AddSwitch(ctx.System.Id, members); if (members.Count == 0) await ctx.Reply($"{Emojis.Success} Switch-out registered."); @@ -87,7 +87,7 @@ namespace PluralKit.Bot if (time.ToInstant() > SystemClock.Instance.GetCurrentInstant()) throw Errors.SwitchTimeInFuture; // Fetch the last two switches for the system to do bounds checking on - var lastTwoSwitches = await _data.GetSwitches(ctx.System).Take(2).ToListAsync(); + var lastTwoSwitches = await _data.GetSwitches(ctx.System.Id).Take(2).ToListAsync(); // If we don't have a switch to move, don't bother if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches; @@ -133,7 +133,7 @@ namespace PluralKit.Bot } // Fetch the last two switches for the system to do bounds checking on - var lastTwoSwitches = await _data.GetSwitches(ctx.System).Take(2).ToListAsync(); + var lastTwoSwitches = await _data.GetSwitches(ctx.System.Id).Take(2).ToListAsync(); if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches; var lastSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[0]); diff --git a/PluralKit.Bot/Commands/SystemFront.cs b/PluralKit.Bot/Commands/SystemFront.cs index 92e99223..b48d2e69 100644 --- a/PluralKit.Bot/Commands/SystemFront.cs +++ b/PluralKit.Bot/Commands/SystemFront.cs @@ -36,7 +36,7 @@ namespace PluralKit.Bot if (system == null) throw Errors.NoSystemError; ctx.CheckSystemPrivacy(system, system.FrontPrivacy); - var sw = await _data.GetLatestSwitch(system); + var sw = await _data.GetLatestSwitch(system.Id); if (sw == null) throw Errors.NoRegisteredSwitches; await ctx.Reply(embed: await _embeds.CreateFronterEmbed(sw, system.Zone)); @@ -47,7 +47,7 @@ namespace PluralKit.Bot if (system == null) throw Errors.NoSystemError; ctx.CheckSystemPrivacy(system, system.FrontHistoryPrivacy); - var sws = _data.GetSwitches(system) + var sws = _data.GetSwitches(system.Id) .Scan(new FrontHistoryEntry(null, null), (lastEntry, newSwitch) => new FrontHistoryEntry(lastEntry.ThisSwitch?.Timestamp, newSwitch)); var totalSwitches = await _data.GetSwitchCount(system); if (totalSwitches == 0) throw Errors.NoRegisteredSwitches; diff --git a/PluralKit.Bot/Services/EmbedService.cs b/PluralKit.Bot/Services/EmbedService.cs index 4b96d5f7..2a2e1ccd 100644 --- a/PluralKit.Bot/Services/EmbedService.cs +++ b/PluralKit.Bot/Services/EmbedService.cs @@ -32,14 +32,14 @@ namespace PluralKit.Bot { // Fetch/render info for all accounts simultaneously var users = await Task.WhenAll(accounts.Select(async uid => (await client.GetUserAsync(uid))?.NameAndMention() ?? $"(deleted account {uid})")); - var memberCount = await _data.GetSystemMemberCount(system, false); + var memberCount = await _data.GetSystemMemberCount(system.Id, false); var eb = new DiscordEmbedBuilder() .WithColor(DiscordUtils.Gray) .WithTitle(system.Name ?? null) .WithThumbnailUrl(system.AvatarUrl) .WithFooter($"System ID: {system.Hid} | Created on {DateTimeFormats.ZonedDateTimeFormat.Format(system.Created.InZone(system.Zone))}"); - var latestSwitch = await _data.GetLatestSwitch(system); + var latestSwitch = await _data.GetLatestSwitch(system.Id); if (latestSwitch != null && system.FrontPrivacy.CanAccess(ctx)) { var switchMembers = await _data.GetSwitchMembers(latestSwitch).ToListAsync(); diff --git a/PluralKit.Core/Services/DataFileService.cs b/PluralKit.Core/Services/DataFileService.cs index 7eb889c0..82a5f41d 100644 --- a/PluralKit.Core/Services/DataFileService.cs +++ b/PluralKit.Core/Services/DataFileService.cs @@ -132,7 +132,7 @@ namespace PluralKit.Core { // Tally up the members that didn't exist before, and check member count on import // If creating the unmatched members would put us over the member limit, abort before creating any members - var memberCountBefore = await _data.GetSystemMemberCount(system, true); + var memberCountBefore = await _data.GetSystemMemberCount(system.Id, true); var membersToAdd = data.Members.Count(m => imp.IsNewMember(m.Id, m.Name)); if (memberCountBefore + membersToAdd > Limits.MaxMemberCount) { diff --git a/PluralKit.Core/Services/IDataStore.cs b/PluralKit.Core/Services/IDataStore.cs index f6493465..522db4b5 100644 --- a/PluralKit.Core/Services/IDataStore.cs +++ b/PluralKit.Core/Services/IDataStore.cs @@ -79,7 +79,7 @@ namespace PluralKit.Core { /// Gets the member count of a system. /// /// Whether the returned count should include private members. - Task GetSystemMemberCount(PKSystem system, bool includePrivate); + Task GetSystemMemberCount(SystemId system, bool includePrivate); /// /// Gets a list of members with proxy tags that conflict with the given tags. @@ -162,7 +162,7 @@ namespace PluralKit.Core { /// The system in which to create the member. /// The name of the member to create. /// The created system model. - Task CreateMember(PKSystem system, string name); + Task CreateMember(SystemId system, string name); /// /// Saves the information within the given struct to the data store. @@ -213,7 +213,7 @@ namespace PluralKit.Core { /// Gets switches from a system. /// /// An enumerable of the *count* latest switches in the system, in latest-first order. May contain fewer elements than requested. - IAsyncEnumerable GetSwitches(PKSystem system); + IAsyncEnumerable GetSwitches(SystemId system); /// /// Gets the total amount of switches in a given system. @@ -223,7 +223,7 @@ namespace PluralKit.Core { /// /// Gets the latest (temporally; closest to now) switch of a given system. /// - Task GetLatestSwitch(PKSystem system); + Task GetLatestSwitch(SystemId system); /// /// Gets the members a given switch consists of. @@ -261,7 +261,7 @@ namespace PluralKit.Core { /// Registers a switch with the given members in the given system. /// /// Throws an exception (TODO: which?) if any of the members are not in the given system. - Task AddSwitch(PKSystem system, IEnumerable switchMembers); + Task AddSwitch(SystemId system, IEnumerable switchMembers); /// /// Updates the timestamp of a given switch. diff --git a/PluralKit.Core/Services/PostgresDataStore.cs b/PluralKit.Core/Services/PostgresDataStore.cs index 4579f012..84e34df6 100644 --- a/PluralKit.Core/Services/PostgresDataStore.cs +++ b/PluralKit.Core/Services/PostgresDataStore.cs @@ -113,11 +113,11 @@ namespace PluralKit.Core { await conn.ExecuteAsync("delete from switches where system = @Id", system); } - public async Task CreateMember(PKSystem system, string name) { + public async Task CreateMember(SystemId system, string name) { PKMember member; using (var conn = await _conn.Obtain()) member = await conn.QuerySingleAsync("insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *", new { - SystemID = system.Id, + SystemID = system, Name = name }); @@ -162,13 +162,13 @@ namespace PluralKit.Core { _logger.Information("Deleted member {@Member}", member); } - public async Task GetSystemMemberCount(PKSystem system, bool includePrivate) + public async Task GetSystemMemberCount(SystemId id, bool includePrivate) { - var query = "select count(*) from members where system = @Id"; + var query = "select count(*) from members where system = @id"; if (!includePrivate) query += " and member_privacy = 1"; // 1 = public using (var conn = await _conn.Obtain()) - return await conn.ExecuteScalarAsync(query, system); + return await conn.ExecuteScalarAsync(query, new { id }); } public async Task GetTotalMembers() @@ -220,7 +220,7 @@ namespace PluralKit.Core { } } - public async Task AddSwitch(PKSystem system, IEnumerable members) + public async Task AddSwitch(SystemId system, IEnumerable members) { // Use a transaction here since we're doing multiple executed commands in one await using var conn = await _conn.Obtain(); @@ -228,7 +228,7 @@ namespace PluralKit.Core { // First, we insert the switch itself var sw = await conn.QuerySingleAsync("insert into switches(system) values (@System) returning *", - new {System = system.Id}); + new {System = system}); // Then we insert each member in the switch in the switch_members table // TODO: can we parallelize this or send it in bulk somehow? @@ -242,16 +242,16 @@ namespace PluralKit.Core { // Finally we commit the tx, since the using block will otherwise rollback it await tx.CommitAsync(); - _logger.Information("Registered switch {Switch} in system {System} with members {@Members}", sw.Id, system.Id, members.Select(m => m.Id)); + _logger.Information("Registered switch {Switch} in system {System} with members {@Members}", sw.Id, system, members.Select(m => m.Id)); } - public IAsyncEnumerable GetSwitches(PKSystem system) + public IAsyncEnumerable GetSwitches(SystemId system) { // TODO: refactor the PKSwitch data structure to somehow include a hydrated member list // (maybe when we get caching in?) return _conn.QueryStreamAsync( "select * from switches where system = @System order by timestamp desc", - new {System = system.Id}); + new {System = system}); } public async Task GetSwitchCount(PKSystem system) @@ -304,7 +304,7 @@ namespace PluralKit.Core { new {Switch = sw.Id}); } - public async Task GetLatestSwitch(PKSystem system) => + public async Task GetLatestSwitch(SystemId system) => await GetSwitches(system).FirstOrDefaultAsync(); public async Task MoveSwitch(PKSwitch sw, Instant time)