diff --git a/PluralKit.API/Authentication/SystemTokenAuthenticationHandler.cs b/PluralKit.API/Authentication/SystemTokenAuthenticationHandler.cs index c843eb7e..759d4b87 100644 --- a/PluralKit.API/Authentication/SystemTokenAuthenticationHandler.cs +++ b/PluralKit.API/Authentication/SystemTokenAuthenticationHandler.cs @@ -29,6 +29,7 @@ namespace PluralKit.API return AuthenticateResult.NoResult(); var token = Request.Headers["Authorization"].FirstOrDefault(); + // todo: move this to ModelRepository var systemId = await _db.Execute(c => c.QuerySingleOrDefaultAsync("select id from systems where token = @token", new { token })); if (systemId == null) return AuthenticateResult.Fail("Invalid system token"); diff --git a/PluralKit.API/Controllers/v1/AccountController.cs b/PluralKit.API/Controllers/v1/AccountController.cs index 948e6dbf..4d9407da 100644 --- a/PluralKit.API/Controllers/v1/AccountController.cs +++ b/PluralKit.API/Controllers/v1/AccountController.cs @@ -24,7 +24,7 @@ namespace PluralKit.API [HttpGet("{aid}")] public async Task> GetSystemByAccount(ulong aid) { - var system = await _db.Execute(c => _repo.GetSystemByAccount(c, aid)); + var system = await _repo.GetSystemByAccount(aid); if (system == null) return NotFound("Account not found."); diff --git a/PluralKit.API/Controllers/v1/MemberController.cs b/PluralKit.API/Controllers/v1/MemberController.cs index 050c8674..8ca35626 100644 --- a/PluralKit.API/Controllers/v1/MemberController.cs +++ b/PluralKit.API/Controllers/v1/MemberController.cs @@ -31,7 +31,7 @@ namespace PluralKit.API [HttpGet("{hid}")] public async Task> GetMember(string hid) { - var member = await _db.Execute(conn => _repo.GetMemberByHid(conn, hid)); + var member = await _repo.GetMemberByHid(hid); if (member == null) return NotFound("Member not found."); return Ok(member.ToJson(User.ContextFor(member), needsLegacyProxyTags: true)); @@ -45,9 +45,9 @@ namespace PluralKit.API return BadRequest("Member name must be specified."); var systemId = User.CurrentSystem(); + var systemData = await _repo.GetSystem(systemId); await using var conn = await _db.Obtain(); - var systemData = await _repo.GetSystem(conn, systemId); // Enforce per-system member limit var memberCount = await conn.QuerySingleAsync("select count(*) from members where system = @System", new { System = systemId }); @@ -56,7 +56,7 @@ namespace PluralKit.API return BadRequest($"Member limit reached ({memberCount} / {memberLimit})."); await using var tx = await conn.BeginTransactionAsync(); - var member = await _repo.CreateMember(conn, systemId, properties.Value("name"), transaction: tx); + var member = await _repo.CreateMember(systemId, properties.Value("name"), conn); MemberPatch patch; try @@ -75,7 +75,7 @@ namespace PluralKit.API return BadRequest($"Request field '{e.Message}' is invalid."); } - member = await _repo.UpdateMember(conn, member.Id, patch, transaction: tx); + member = await _repo.UpdateMember(member.Id, patch, conn); await tx.CommitAsync(); return Ok(member.ToJson(User.ContextFor(member), needsLegacyProxyTags: true)); } @@ -84,9 +84,7 @@ namespace PluralKit.API [Authorize] public async Task> PatchMember(string hid, [FromBody] JObject changes) { - await using var conn = await _db.Obtain(); - - var member = await _repo.GetMemberByHid(conn, hid); + var member = await _repo.GetMemberByHid(hid); if (member == null) return NotFound("Member not found."); var res = await _auth.AuthorizeAsync(User, member, "EditMember"); @@ -107,7 +105,7 @@ namespace PluralKit.API return BadRequest($"Request field '{e.Message}' is invalid."); } - var newMember = await _repo.UpdateMember(conn, member.Id, patch); + var newMember = await _repo.UpdateMember(member.Id, patch); return Ok(newMember.ToJson(User.ContextFor(newMember), needsLegacyProxyTags: true)); } @@ -115,15 +113,13 @@ namespace PluralKit.API [Authorize] public async Task DeleteMember(string hid) { - await using var conn = await _db.Obtain(); - - var member = await _repo.GetMemberByHid(conn, hid); + var member = await _repo.GetMemberByHid(hid); if (member == null) return NotFound("Member not found."); var res = await _auth.AuthorizeAsync(User, member, "EditMember"); if (!res.Succeeded) return Unauthorized($"Member '{hid}' is not part of your system."); - await _repo.DeleteMember(conn, member.Id); + await _repo.DeleteMember(member.Id); return Ok(); } } diff --git a/PluralKit.API/Controllers/v1/SystemController.cs b/PluralKit.API/Controllers/v1/SystemController.cs index 496c402d..30f648aa 100644 --- a/PluralKit.API/Controllers/v1/SystemController.cs +++ b/PluralKit.API/Controllers/v1/SystemController.cs @@ -55,14 +55,14 @@ namespace PluralKit.API [Authorize] public async Task> GetOwnSystem() { - var system = await _db.Execute(c => _repo.GetSystem(c, User.CurrentSystem())); + var system = await _repo.GetSystem(User.CurrentSystem()); return system.ToJson(User.ContextFor(system)); } [HttpGet("{hid}")] public async Task> GetSystem(string hid) { - var system = await _db.Execute(c => _repo.GetSystemByHid(c, hid)); + var system = await _repo.GetSystemByHid(hid); if (system == null) return NotFound("System not found."); return Ok(system.ToJson(User.ContextFor(system))); } @@ -70,14 +70,14 @@ namespace PluralKit.API [HttpGet("{hid}/members")] public async Task>> GetMembers(string hid) { - var system = await _db.Execute(c => _repo.GetSystemByHid(c, hid)); + var system = await _repo.GetSystemByHid(hid); if (system == null) return NotFound("System not found."); if (!system.MemberListPrivacy.CanAccess(User.ContextFor(system))) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view member list."); - var members = _db.Execute(c => _repo.GetSystemMembers(c, system.Id)); + var members = _repo.GetSystemMembers(system.Id); return Ok(await members .Where(m => m.MemberVisibility.CanAccess(User.ContextFor(system))) .Select(m => m.ToJson(User.ContextFor(system), needsLegacyProxyTags: true)) @@ -89,40 +89,36 @@ namespace PluralKit.API { if (before == null) before = SystemClock.Instance.GetCurrentInstant(); - await using var conn = await _db.Obtain(); - - var system = await _repo.GetSystemByHid(conn, hid); + var system = await _repo.GetSystemByHid(hid); if (system == null) return NotFound("System not found."); var auth = await _auth.AuthorizeAsync(User, system, "ViewFrontHistory"); if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view front history."); - var res = await conn.QueryAsync( + var res = await _db.Execute(conn => conn.QueryAsync( @"select *, array( select members.hid from switch_members, members where switch_members.switch = switches.id and members.id = switch_members.member ) as members from switches where switches.system = @System and switches.timestamp < @Before order by switches.timestamp desc - limit 100;", new { System = system.Id, Before = before }); + limit 100;", new { System = system.Id, Before = before })); return Ok(res); } [HttpGet("{hid}/fronters")] public async Task> GetFronters(string hid) { - await using var conn = await _db.Obtain(); - - var system = await _repo.GetSystemByHid(conn, hid); + var system = await _repo.GetSystemByHid(hid); if (system == null) return NotFound("System not found."); var auth = await _auth.AuthorizeAsync(User, system, "ViewFront"); if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view fronter."); - var sw = await _repo.GetLatestSwitch(conn, system.Id); + var sw = await _repo.GetLatestSwitch(system.Id); if (sw == null) return NotFound("System has no registered switches."); - var members = _repo.GetSwitchMembers(conn, sw.Id); + var members = _db.Execute(conn => _repo.GetSwitchMembers(conn, sw.Id)); return Ok(new FrontersReturn { Timestamp = sw.Timestamp, @@ -134,8 +130,7 @@ namespace PluralKit.API [Authorize] public async Task> EditSystem([FromBody] JObject changes) { - await using var conn = await _db.Obtain(); - var system = await _repo.GetSystem(conn, User.CurrentSystem()); + var system = await _repo.GetSystem(User.CurrentSystem()); SystemPatch patch; try @@ -152,7 +147,7 @@ namespace PluralKit.API return BadRequest($"Request field '{e.Message}' is invalid."); } - system = await _repo.UpdateSystem(conn, system!.Id, patch); + system = await _repo.UpdateSystem(system!.Id, patch); return Ok(system.ToJson(User.ContextFor(system))); } @@ -166,7 +161,7 @@ namespace PluralKit.API await using var conn = await _db.Obtain(); // We get the current switch, if it exists - var latestSwitch = await _repo.GetLatestSwitch(conn, User.CurrentSystem()); + var latestSwitch = await _repo.GetLatestSwitch(User.CurrentSystem()); if (latestSwitch != null) { var latestSwitchMembers = _repo.GetSwitchMembers(conn, latestSwitch.Id); diff --git a/PluralKit.Bot/CommandSystem/ContextArgumentsExt.cs b/PluralKit.Bot/CommandSystem/ContextArgumentsExt.cs index 61f7e881..7814a095 100644 --- a/PluralKit.Bot/CommandSystem/ContextArgumentsExt.cs +++ b/PluralKit.Bot/CommandSystem/ContextArgumentsExt.cs @@ -130,6 +130,7 @@ namespace PluralKit.Bot // if we can't, big error. Every group name must be valid. throw new PKError(ctx.CreateGroupNotFoundError(ctx.PopArgument())); + // todo: remove this, the database query enforces the restriction if (restrictToSystem != null && group.System != restrictToSystem) throw Errors.NotOwnGroupError; // TODO: name *which* group? diff --git a/PluralKit.Bot/CommandSystem/ContextEntityArgumentsExt.cs b/PluralKit.Bot/CommandSystem/ContextEntityArgumentsExt.cs index f7061150..7458a597 100644 --- a/PluralKit.Bot/CommandSystem/ContextEntityArgumentsExt.cs +++ b/PluralKit.Bot/CommandSystem/ContextEntityArgumentsExt.cs @@ -49,14 +49,12 @@ namespace PluralKit.Bot // - A @mention of an account connected to the system (<@uid>) // - A system hid - await using var conn = await ctx.Database.Obtain(); - // Direct IDs and mentions are both handled by the below method: if (input.TryParseMention(out var id)) - return await ctx.Repository.GetSystemByAccount(conn, id); + return await ctx.Repository.GetSystemByAccount(id); // Finally, try HID parsing - var system = await ctx.Repository.GetSystemByHid(conn, input); + var system = await ctx.Repository.GetSystemByHid(input); return system; } @@ -71,16 +69,15 @@ namespace PluralKit.Bot // - a textual display name of a member *in your own system* // First, if we have a system, try finding by member name in system - await using var conn = await ctx.Database.Obtain(); - if (ctx.System != null && await ctx.Repository.GetMemberByName(conn, ctx.System.Id, input) is PKMember memberByName) + if (ctx.System != null && await ctx.Repository.GetMemberByName(ctx.System.Id, input) is PKMember memberByName) return memberByName; // Then, try member HID parsing: - if (await ctx.Repository.GetMemberByHid(conn, input, restrictToSystem) is PKMember memberByHid) + if (await ctx.Repository.GetMemberByHid(input, restrictToSystem) is PKMember memberByHid) return memberByHid; // And if that again fails, we try finding a member with a display name matching the argument from the system - if (ctx.System != null && await ctx.Repository.GetMemberByDisplayName(conn, ctx.System.Id, input) is PKMember memberByDisplayName) + if (ctx.System != null && await ctx.Repository.GetMemberByDisplayName(ctx.System.Id, input) is PKMember memberByDisplayName) return memberByDisplayName; // We didn't find anything, so we return null. @@ -107,12 +104,11 @@ namespace PluralKit.Bot { var input = ctx.PeekArgument(); - await using var conn = await ctx.Database.Obtain(); - if (ctx.System != null && await ctx.Repository.GetGroupByName(conn, ctx.System.Id, input) is { } byName) + if (ctx.System != null && await ctx.Repository.GetGroupByName(ctx.System.Id, input) is { } byName) return byName; - if (await ctx.Repository.GetGroupByHid(conn, input, restrictToSystem) is { } byHid) + if (await ctx.Repository.GetGroupByHid(input, restrictToSystem) is { } byHid) return byHid; - if (await ctx.Repository.GetGroupByDisplayName(conn, ctx.System.Id, input) is { } byDisplayName) + if (await ctx.Repository.GetGroupByDisplayName(ctx.System.Id, input) is { } byDisplayName) return byDisplayName; return null; diff --git a/PluralKit.Bot/Commands/Admin.cs b/PluralKit.Bot/Commands/Admin.cs index d47b0fe2..db563da9 100644 --- a/PluralKit.Bot/Commands/Admin.cs +++ b/PluralKit.Bot/Commands/Admin.cs @@ -31,14 +31,14 @@ namespace PluralKit.Bot if (!Regex.IsMatch(newHid, "^[a-z]{5}$")) throw new PKError($"Invalid new system ID `{newHid}`."); - var existingSystem = await _db.Execute(c => _repo.GetSystemByHid(c, newHid)); + var existingSystem = _repo.GetSystemByHid(newHid); if (existingSystem != null) throw new PKError($"Another system already exists with ID `{newHid}`."); if (!await ctx.PromptYesNo($"Change system ID of `{target.Hid}` to `{newHid}`?", "Change")) throw new PKError("ID change cancelled."); - await _db.Execute(c => _repo.UpdateSystem(c, target.Id, new SystemPatch { Hid = newHid })); + await _repo.UpdateSystem(target.Id, new() { Hid = newHid }); await ctx.Reply($"{Emojis.Success} System ID updated (`{target.Hid}` -> `{newHid}`)."); } @@ -54,14 +54,14 @@ namespace PluralKit.Bot if (!Regex.IsMatch(newHid, "^[a-z]{5}$")) throw new PKError($"Invalid new member ID `{newHid}`."); - var existingMember = await _db.Execute(c => _repo.GetMemberByHid(c, newHid)); + var existingMember = await _repo.GetMemberByHid(newHid); if (existingMember != null) throw new PKError($"Another member already exists with ID `{newHid}`."); if (!await ctx.PromptYesNo($"Change member ID of **{target.NameFor(LookupContext.ByNonOwner)}** (`{target.Hid}`) to `{newHid}`?", "Change")) throw new PKError("ID change cancelled."); - await _db.Execute(c => _repo.UpdateMember(c, target.Id, new MemberPatch { Hid = newHid })); + await _repo.UpdateMember(target.Id, new() { Hid = newHid }); await ctx.Reply($"{Emojis.Success} Member ID updated (`{target.Hid}` -> `{newHid}`)."); } @@ -77,14 +77,14 @@ namespace PluralKit.Bot if (!Regex.IsMatch(newHid, "^[a-z]{5}$")) throw new PKError($"Invalid new group ID `{newHid}`."); - var existingGroup = await _db.Execute(c => _repo.GetGroupByHid(c, newHid)); + var existingGroup = _repo.GetGroupByHid(newHid); if (existingGroup != null) throw new PKError($"Another group already exists with ID `{newHid}`."); if (!await ctx.PromptYesNo($"Change group ID of **{target.Name}** (`{target.Hid}`) to `{newHid}`?", "Change")) throw new PKError("ID change cancelled."); - await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch { Hid = newHid })); + await _repo.UpdateGroup(target.Id, new() { Hid = newHid }); await ctx.Reply($"{Emojis.Success} Group ID updated (`{target.Hid}` -> `{newHid}`)."); } @@ -110,11 +110,7 @@ namespace PluralKit.Bot if (!await ctx.PromptYesNo($"Update member limit from **{currentLimit}** to **{newLimit}**?", "Update")) throw new PKError("Member limit change cancelled."); - await using var conn = await _db.Obtain(); - await _repo.UpdateSystem(conn, target.Id, new SystemPatch - { - MemberLimitOverride = newLimit - }); + await _repo.UpdateSystem(target.Id, new() { MemberLimitOverride = newLimit }); await ctx.Reply($"{Emojis.Success} Member limit updated."); } @@ -140,11 +136,7 @@ namespace PluralKit.Bot if (!await ctx.PromptYesNo($"Update group limit from **{currentLimit}** to **{newLimit}**?", "Update")) throw new PKError("Group limit change cancelled."); - await using var conn = await _db.Obtain(); - await _repo.UpdateSystem(conn, target.Id, new SystemPatch - { - GroupLimitOverride = newLimit - }); + await _repo.UpdateSystem(target.Id, new() { GroupLimitOverride = newLimit }); await ctx.Reply($"{Emojis.Success} Group limit updated."); } } diff --git a/PluralKit.Bot/Commands/Autoproxy.cs b/PluralKit.Bot/Commands/Autoproxy.cs index 0d094937..990060c9 100644 --- a/PluralKit.Bot/Commands/Autoproxy.cs +++ b/PluralKit.Bot/Commands/Autoproxy.cs @@ -94,8 +94,8 @@ namespace PluralKit.Bot var fronters = ctx.MessageContext.LastSwitchMembers; var relevantMember = ctx.MessageContext.AutoproxyMode switch { - AutoproxyMode.Front => fronters.Length > 0 ? await _db.Execute(c => _repo.GetMember(c, fronters[0])) : null, - AutoproxyMode.Member => await _db.Execute(c => _repo.GetMember(c, ctx.MessageContext.AutoproxyMember.Value)), + AutoproxyMode.Front => fronters.Length > 0 ? await _repo.GetMember(fronters[0]) : null, + AutoproxyMode.Member => await _repo.GetMember(ctx.MessageContext.AutoproxyMember.Value), _ => null }; @@ -171,8 +171,7 @@ namespace PluralKit.Bot else newTimeout = timeoutPeriod; } - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, - new SystemPatch { LatchTimeout = (int?)newTimeout?.TotalSeconds })); + await _repo.UpdateSystem(ctx.System.Id, new() { LatchTimeout = (int?)newTimeout?.TotalSeconds }); if (newTimeout == null) await ctx.Reply($"{Emojis.Success} Latch timeout reset to default ({ProxyMatcher.DefaultLatchExpiryTime.ToTimeSpan().Humanize(4)})."); @@ -209,14 +208,16 @@ namespace PluralKit.Bot return; } var patch = new AccountPatch { AllowAutoproxy = allow }; - await _db.Execute(conn => _repo.UpdateAccount(conn, ctx.Author.Id, patch)); + await _repo.UpdateAccount(ctx.Author.Id, patch); await ctx.Reply($"{Emojis.Success} Autoproxy {statusString} for account <@{ctx.Author.Id}>."); } - private Task UpdateAutoproxy(Context ctx, AutoproxyMode autoproxyMode, MemberId? autoproxyMember) + private async Task UpdateAutoproxy(Context ctx, AutoproxyMode autoproxyMode, MemberId? autoproxyMember) { + await _repo.GetSystemGuild(ctx.Guild.Id, ctx.System.Id); + var patch = new SystemGuildPatch { AutoproxyMode = autoproxyMode, AutoproxyMember = autoproxyMember }; - return _db.Execute(conn => _repo.UpsertSystemGuild(conn, ctx.System.Id, ctx.Guild.Id, patch)); + await _repo.UpdateSystemGuild(ctx.System.Id, ctx.Guild.Id, patch); } } } \ No newline at end of file diff --git a/PluralKit.Bot/Commands/Checks.cs b/PluralKit.Bot/Commands/Checks.cs index 084f0df4..c3af38ed 100644 --- a/PluralKit.Bot/Commands/Checks.cs +++ b/PluralKit.Bot/Commands/Checks.cs @@ -237,9 +237,7 @@ namespace PluralKit.Bot if (messageId == null || channelId == null) throw new PKError(failedToGetMessage); - await using var conn = await _db.Obtain(); - - var proxiedMsg = await _repo.GetMessage(conn, messageId.Value); + var proxiedMsg = await _db.Execute(conn => _repo.GetMessage(conn, messageId.Value)); if (proxiedMsg != null) { await ctx.Reply($"{Emojis.Success} This message was proxied successfully."); @@ -276,8 +274,8 @@ namespace PluralKit.Bot throw new PKError("Unable to get the channel associated with this message."); // using channel.GuildId here since _rest.GetMessage() doesn't return the GuildId - var context = await _repo.GetMessageContext(conn, msg.Author.Id, channel.GuildId.Value, msg.ChannelId); - var members = (await _repo.GetProxyMembers(conn, msg.Author.Id, channel.GuildId.Value)).ToList(); + var context = await _repo.GetMessageContext(msg.Author.Id, channel.GuildId.Value, msg.ChannelId); + var members = (await _repo.GetProxyMembers(msg.Author.Id, channel.GuildId.Value)).ToList(); // Run everything through the checks, catch the ProxyCheckFailedException, and reply with the error message. try diff --git a/PluralKit.Bot/Commands/Groups.cs b/PluralKit.Bot/Commands/Groups.cs index 7c35dd6f..e5539384 100644 --- a/PluralKit.Bot/Commands/Groups.cs +++ b/PluralKit.Bot/Commands/Groups.cs @@ -42,16 +42,14 @@ namespace PluralKit.Bot if (groupName.Length > Limits.MaxGroupNameLength) throw new PKError($"Group name too long ({groupName.Length}/{Limits.MaxGroupNameLength} characters)."); - await using var conn = await _db.Obtain(); - // Check group cap - var existingGroupCount = await _repo.GetSystemGroupCount(conn, ctx.System.Id); + var existingGroupCount = await _repo.GetSystemGroupCount(ctx.System.Id); var groupLimit = ctx.System.GroupLimitOverride ?? Limits.MaxGroupCount; if (existingGroupCount >= groupLimit) throw new PKError($"System has reached the maximum number of groups ({groupLimit}). Please delete unused groups first in order to create new ones."); // Warn if there's already a group by this name - var existingGroup = await _repo.GetGroupByName(conn, ctx.System.Id, groupName); + var existingGroup = await _repo.GetGroupByName(ctx.System.Id, groupName); if (existingGroup != null) { var msg = $"{Emojis.Warn} You already have a group in your system with the name \"{existingGroup.Name}\" (with ID `{existingGroup.Hid}`). Do you want to create another group with the same name?"; @@ -59,7 +57,7 @@ namespace PluralKit.Bot throw new PKError("Group creation cancelled."); } - var newGroup = await _repo.CreateGroup(conn, ctx.System.Id, groupName); + var newGroup = await _repo.CreateGroup(ctx.System.Id, groupName); var eb = new EmbedBuilder() .Description($"Your new group, **{groupName}**, has been created, with the group ID **`{newGroup.Hid}`**.\nBelow are a couple of useful commands:") @@ -82,10 +80,8 @@ namespace PluralKit.Bot if (newName.Length > Limits.MaxGroupNameLength) throw new PKError($"New group name too long ({newName.Length}/{Limits.MaxMemberNameLength} characters)."); - await using var conn = await _db.Obtain(); - // Warn if there's already a group by this name - var existingGroup = await _repo.GetGroupByName(conn, ctx.System.Id, newName); + var existingGroup = await _repo.GetGroupByName(ctx.System.Id, newName); if (existingGroup != null && existingGroup.Id != target.Id) { var msg = $"{Emojis.Warn} You already have a group in your system with the name \"{existingGroup.Name}\" (with ID `{existingGroup.Hid}`). Do you want to rename this group to that name too?"; @@ -93,7 +89,7 @@ namespace PluralKit.Bot throw new PKError("Group rename cancelled."); } - await _repo.UpdateGroup(conn, target.Id, new GroupPatch { Name = newName }); + await _repo.UpdateGroup(target.Id, new() { Name = newName }); await ctx.Reply($"{Emojis.Success} Group name changed from **{target.Name}** to **{newName}**."); } @@ -139,7 +135,7 @@ namespace PluralKit.Bot if (await ctx.MatchClear("this group's display name")) { var patch = new GroupPatch { DisplayName = Partial.Null() }; - await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch)); + await _repo.UpdateGroup(target.Id, patch); await ctx.Reply($"{Emojis.Success} Group display name cleared."); } @@ -148,7 +144,7 @@ namespace PluralKit.Bot var newDisplayName = ctx.RemainderOrNull(skipFlags: false).NormalizeLineEndSpacing(); var patch = new GroupPatch { DisplayName = Partial.Present(newDisplayName) }; - await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch)); + await _repo.UpdateGroup(target.Id, patch); await ctx.Reply($"{Emojis.Success} Group display name changed."); } @@ -190,7 +186,7 @@ namespace PluralKit.Bot if (await ctx.MatchClear("this group's description")) { var patch = new GroupPatch { Description = Partial.Null() }; - await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch)); + await _repo.UpdateGroup(target.Id, patch); await ctx.Reply($"{Emojis.Success} Group description cleared."); } else @@ -200,7 +196,7 @@ namespace PluralKit.Bot throw Errors.StringTooLongError("Description", description.Length, Limits.MaxDescriptionLength); var patch = new GroupPatch { Description = Partial.Present(description) }; - await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch)); + await _repo.UpdateGroup(target.Id, patch); await ctx.Reply($"{Emojis.Success} Group description changed."); } @@ -212,7 +208,7 @@ namespace PluralKit.Bot { ctx.CheckOwnGroup(target); - await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch { Icon = null })); + await _repo.UpdateGroup(target.Id, new() { Icon = null }); await ctx.Reply($"{Emojis.Success} Group icon cleared."); } @@ -222,7 +218,7 @@ namespace PluralKit.Bot await AvatarUtils.VerifyAvatarOrThrow(_client, img.Url); - await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch { Icon = img.Url })); + await _repo.UpdateGroup(target.Id, new() { Icon = img.Url }); var msg = img.Source switch { @@ -274,7 +270,7 @@ namespace PluralKit.Bot { ctx.CheckOwnGroup(target); - await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch { BannerImage = null })); + await _repo.UpdateGroup(target.Id, new() { BannerImage = null }); await ctx.Reply($"{Emojis.Success} Group banner image cleared."); } @@ -284,7 +280,7 @@ namespace PluralKit.Bot await AvatarUtils.VerifyAvatarOrThrow(_client, img.Url, isFullSizeImage: true); - await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch { BannerImage = img.Url })); + await _repo.UpdateGroup(target.Id, new() { BannerImage = img.Url }); var msg = img.Source switch { @@ -338,7 +334,7 @@ namespace PluralKit.Bot ctx.CheckOwnGroup(target); var patch = new GroupPatch { Color = Partial.Null() }; - await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch)); + await _repo.UpdateGroup(target.Id, patch); await ctx.Reply($"{Emojis.Success} Group color cleared."); } @@ -368,7 +364,7 @@ namespace PluralKit.Bot if (!Regex.IsMatch(color, "^[0-9a-fA-F]{6}$")) throw Errors.InvalidColorError(color); var patch = new GroupPatch { Color = Partial.Present(color.ToLowerInvariant()) }; - await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch)); + await _repo.UpdateGroup(target.Id, patch); await ctx.Reply(embed: new EmbedBuilder() .Title($"{Emojis.Success} Group color changed.") @@ -389,7 +385,6 @@ namespace PluralKit.Bot ctx.CheckSystemPrivacy(system, system.GroupListPrivacy); // TODO: integrate with the normal "search" system - await using var conn = await _db.Obtain(); var pctx = LookupContext.ByNonOwner; if (ctx.MatchFlag("a", "all")) @@ -400,7 +395,7 @@ namespace PluralKit.Bot throw new PKError("You do not have permission to access this information."); } - var groups = (await conn.QueryGroupList(system.Id)) + var groups = (await _db.Execute(conn => conn.QueryGroupList(system.Id))) .Where(g => g.Visibility.CanAccess(pctx)) .OrderBy(g => g.Name, StringComparer.InvariantCultureIgnoreCase) .ToList(); @@ -434,8 +429,7 @@ namespace PluralKit.Bot public async Task ShowGroupCard(Context ctx, PKGroup target) { - await using var conn = await _db.Obtain(); - var system = await GetGroupSystem(ctx, target, conn); + var system = await GetGroupSystem(ctx, target); await ctx.Reply(embed: await _embeds.CreateGroupEmbed(ctx, system, target)); } @@ -448,10 +442,8 @@ namespace PluralKit.Bot .Distinct() .ToList(); - await using var conn = await _db.Obtain(); - - var existingMembersInGroup = (await conn.QueryMemberList(target.System, - new DatabaseViewsExt.MemberListQueryOptions { GroupFilter = target.Id })) + var existingMembersInGroup = (await _db.Execute(conn => conn.QueryMemberList(target.System, + new DatabaseViewsExt.MemberListQueryOptions { GroupFilter = target.Id }))) .Select(m => m.Id.Value) .Distinct() .ToHashSet(); @@ -463,14 +455,14 @@ namespace PluralKit.Bot toAction = members .Where(m => !existingMembersInGroup.Contains(m.Value)) .ToList(); - await _repo.AddMembersToGroup(conn, target.Id, toAction); + await _repo.AddMembersToGroup(target.Id, toAction); } else if (op == AddRemoveOperation.Remove) { toAction = members .Where(m => existingMembersInGroup.Contains(m.Value)) .ToList(); - await _repo.RemoveMembersFromGroup(conn, target.Id, toAction); + await _repo.RemoveMembersFromGroup(target.Id, toAction); } else return; // otherwise toAction "may be undefined" @@ -479,9 +471,7 @@ namespace PluralKit.Bot public async Task ListGroupMembers(Context ctx, PKGroup target) { - await using var conn = await _db.Obtain(); - - var targetSystem = await GetGroupSystem(ctx, target, conn); + var targetSystem = await GetGroupSystem(ctx, target); ctx.CheckSystemPrivacy(targetSystem, target.ListPrivacy); var opts = ctx.ParseMemberListOptions(ctx.LookupContextFor(target.System)); @@ -523,7 +513,7 @@ namespace PluralKit.Bot async Task SetAll(PrivacyLevel level) { - await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch().WithAllPrivacy(level))); + await _repo.UpdateGroup(target.Id, new GroupPatch().WithAllPrivacy(level)); if (level == PrivacyLevel.Private) await ctx.Reply($"{Emojis.Success} All {target.Name}'s privacy settings have been set to **{level.LevelName()}**. Other accounts will now see nothing on the group card."); @@ -533,7 +523,7 @@ namespace PluralKit.Bot async Task SetLevel(GroupPrivacySubject subject, PrivacyLevel level) { - await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch().WithPrivacy(subject, level))); + await _repo.UpdateGroup(target.Id, new GroupPatch().WithPrivacy(subject, level)); var subjectName = subject switch { @@ -576,19 +566,17 @@ namespace PluralKit.Bot if (!await ctx.ConfirmWithReply(target.Hid)) throw new PKError($"Group deletion cancelled. Note that you must reply with your group ID (`{target.Hid}`) *verbatim*."); - await _db.Execute(conn => _repo.DeleteGroup(conn, target.Id)); + await _repo.DeleteGroup(target.Id); await ctx.Reply($"{Emojis.Success} Group deleted."); } public async Task GroupFrontPercent(Context ctx, PKGroup target) { - await using var conn = await _db.Obtain(); - - var targetSystem = await GetGroupSystem(ctx, target, conn); + var targetSystem = await GetGroupSystem(ctx, target); ctx.CheckSystemPrivacy(targetSystem, targetSystem.FrontHistoryPrivacy); - var totalSwitches = await _db.Execute(conn => _repo.GetSwitchCount(conn, targetSystem.Id)); + var totalSwitches = await _repo.GetSwitchCount(targetSystem.Id); if (totalSwitches == 0) throw Errors.NoRegisteredSwitches; string durationStr = ctx.RemainderOrNull() ?? "30d"; @@ -611,12 +599,12 @@ namespace PluralKit.Bot await ctx.Reply(embed: await _embeds.CreateFrontPercentEmbed(frontpercent, targetSystem, target, targetSystem.Zone, ctx.LookupContextFor(targetSystem), title.ToString(), ignoreNoFronters, showFlat)); } - private async Task GetGroupSystem(Context ctx, PKGroup target, IPKConnection conn) + private async Task GetGroupSystem(Context ctx, PKGroup target) { var system = ctx.System; if (system?.Id == target.System) return system; - return await _repo.GetSystem(conn, target.System)!; + return await _repo.GetSystem(target.System)!; } } } \ No newline at end of file diff --git a/PluralKit.Bot/Commands/Member.cs b/PluralKit.Bot/Commands/Member.cs index 2dba812d..f46e034b 100644 --- a/PluralKit.Bot/Commands/Member.cs +++ b/PluralKit.Bot/Commands/Member.cs @@ -40,7 +40,7 @@ namespace PluralKit.Bot throw Errors.StringTooLongError("Member name", memberName.Length, Limits.MaxMemberNameLength); // Warn if there's already a member by this name - var existingMember = await _db.Execute(c => _repo.GetMemberByName(c, ctx.System.Id, memberName)); + var existingMember = await _repo.GetMemberByName(ctx.System.Id, memberName); if (existingMember != null) { var msg = $"{Emojis.Warn} You already have a member in your system with the name \"{existingMember.NameFor(ctx)}\" (with ID `{existingMember.Hid}`). Do you want to create another member with the same name?"; @@ -50,13 +50,13 @@ namespace PluralKit.Bot await using var conn = await _db.Obtain(); // Enforce per-system member limit - var memberCount = await _repo.GetSystemMemberCount(conn, ctx.System.Id); + var memberCount = await _repo.GetSystemMemberCount(ctx.System.Id); var memberLimit = ctx.System.MemberLimitOverride ?? Limits.MaxMemberCount; if (memberCount >= memberLimit) throw Errors.MemberLimitReachedError(memberLimit); // Create the member - var member = await _repo.CreateMember(conn, ctx.System.Id, memberName); + var member = await _repo.CreateMember(ctx.System.Id, memberName); memberCount++; // Try to match an image attached to the message @@ -67,7 +67,7 @@ namespace PluralKit.Bot try { await AvatarUtils.VerifyAvatarOrThrow(_client, avatarArg.Url); - await _db.Execute(conn => _repo.UpdateMember(conn, member.Id, new MemberPatch { AvatarUrl = avatarArg.Url })); + await _repo.UpdateMember(member.Id, new MemberPatch { AvatarUrl = avatarArg.Url }); } catch (Exception e) { @@ -77,6 +77,7 @@ namespace PluralKit.Bot // Send confirmation and space hint await ctx.Reply($"{Emojis.Success} Member \"{memberName}\" (`{member.Hid}`) registered! Check out the getting started page for how to get a member up and running: https://pluralkit.me/start#create-a-member"); + // todo: move this to ModelRepository if (await _db.Execute(conn => conn.QuerySingleAsync("select has_private_members(@System)", new { System = ctx.System.Id }))) //if has private members await ctx.Reply($"{Emojis.Warn} This member is currently **public**. To change this, use `pk;member {member.Hid} private`."); @@ -95,7 +96,7 @@ namespace PluralKit.Bot public async Task ViewMember(Context ctx, PKMember target) { - var system = await _db.Execute(c => _repo.GetSystem(c, target.System)); + var system = await _repo.GetSystem(target.System); await ctx.Reply(embed: await _embeds.CreateMemberEmbed(system, target, ctx.Guild, ctx.LookupContextFor(system))); } diff --git a/PluralKit.Bot/Commands/MemberAvatar.cs b/PluralKit.Bot/Commands/MemberAvatar.cs index bb4c0f2c..8aa60a3c 100644 --- a/PluralKit.Bot/Commands/MemberAvatar.cs +++ b/PluralKit.Bot/Commands/MemberAvatar.cs @@ -72,14 +72,14 @@ namespace PluralKit.Bot public async Task ServerAvatar(Context ctx, PKMember target) { ctx.CheckGuildContext(); - var guildData = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)); + var guildData = await _repo.GetMemberGuild(ctx.Guild.Id, target.Id); await AvatarCommandTree(AvatarLocation.Server, ctx, target, guildData); } public async Task Avatar(Context ctx, PKMember target) { var guildData = ctx.Guild != null ? - await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)) + await _repo.GetMemberGuild(ctx.Guild.Id, target.Id) : null; await AvatarCommandTree(AvatarLocation.Member, ctx, target, guildData); @@ -147,11 +147,9 @@ namespace PluralKit.Bot switch (location) { case AvatarLocation.Server: - var serverPatch = new MemberGuildPatch { AvatarUrl = url }; - return _db.Execute(c => _repo.UpsertMemberGuild(c, target.Id, ctx.Guild.Id, serverPatch)); + return _repo.UpdateMemberGuild(target.Id, ctx.Guild.Id, new() { AvatarUrl = url }); case AvatarLocation.Member: - var memberPatch = new MemberPatch { AvatarUrl = url }; - return _db.Execute(c => _repo.UpdateMember(c, target.Id, memberPatch)); + return _repo.UpdateMember(target.Id, new() { AvatarUrl = url }); default: throw new ArgumentOutOfRangeException($"Unknown avatar location {location}"); } diff --git a/PluralKit.Bot/Commands/MemberEdit.cs b/PluralKit.Bot/Commands/MemberEdit.cs index 4d3f9941..6e987e69 100644 --- a/PluralKit.Bot/Commands/MemberEdit.cs +++ b/PluralKit.Bot/Commands/MemberEdit.cs @@ -35,7 +35,7 @@ namespace PluralKit.Bot throw Errors.StringTooLongError("Member name", newName.Length, Limits.MaxMemberNameLength); // Warn if there's already a member by this name - var existingMember = await _db.Execute(conn => _repo.GetMemberByName(conn, ctx.System.Id, newName)); + var existingMember = await _repo.GetMemberByName(ctx.System.Id, newName); if (existingMember != null && existingMember.Id != target.Id) { var msg = $"{Emojis.Warn} You already have a member in your system with the name \"{existingMember.NameFor(ctx)}\" (`{existingMember.Hid}`). Do you want to rename this member to that name too?"; @@ -44,7 +44,7 @@ namespace PluralKit.Bot // Rename the member var patch = new MemberPatch { Name = Partial.Present(newName) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Member renamed."); if (newName.Contains(" ")) await ctx.Reply($"{Emojis.Note} Note that this member's name now contains spaces. You will need to surround it with \"double quotes\" when using commands referring to it."); @@ -52,7 +52,7 @@ namespace PluralKit.Bot if (ctx.Guild != null) { - var memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)); + var memberGuildConfig = await _repo.GetMemberGuild(ctx.Guild.Id, target.Id); if (memberGuildConfig.DisplayName != null) await ctx.Reply($"{Emojis.Note} Note that this member has a server name set ({memberGuildConfig.DisplayName}) in this server ({ctx.Guild.Name}), and will be proxied using that name here."); } @@ -94,7 +94,7 @@ namespace PluralKit.Bot if (await ctx.MatchClear("this member's description")) { var patch = new MemberPatch { Description = Partial.Null() }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Member description cleared."); } else @@ -104,7 +104,7 @@ namespace PluralKit.Bot throw Errors.StringTooLongError("Description", description.Length, Limits.MaxDescriptionLength); var patch = new MemberPatch { Description = Partial.Present(description) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Member description changed."); } @@ -142,7 +142,7 @@ namespace PluralKit.Bot if (await ctx.MatchClear("this member's pronouns")) { var patch = new MemberPatch { Pronouns = Partial.Null() }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Member pronouns cleared."); } else @@ -152,7 +152,7 @@ namespace PluralKit.Bot throw Errors.StringTooLongError("Pronouns", pronouns.Length, Limits.MaxPronounsLength); var patch = new MemberPatch { Pronouns = Partial.Present(pronouns) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Member pronouns changed."); } @@ -164,7 +164,7 @@ namespace PluralKit.Bot async Task ClearBannerImage() { - await _db.Execute(c => _repo.UpdateMember(c, target.Id, new MemberPatch { BannerImage = null })); + await _repo.UpdateMember(target.Id, new() { BannerImage = null }); await ctx.Reply($"{Emojis.Success} Member banner image cleared."); } @@ -172,7 +172,7 @@ namespace PluralKit.Bot { await AvatarUtils.VerifyAvatarOrThrow(_client, img.Url, isFullSizeImage: true); - await _db.Execute(c => _repo.UpdateMember(c, target.Id, new MemberPatch { BannerImage = img.Url })); + await _repo.UpdateMember(target.Id, new() { BannerImage = img.Url }); var msg = img.Source switch { @@ -219,7 +219,7 @@ namespace PluralKit.Bot ctx.CheckOwnMember(target); var patch = new MemberPatch { Color = Partial.Null() }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Member color cleared."); } @@ -251,7 +251,7 @@ namespace PluralKit.Bot if (!Regex.IsMatch(color, "^[0-9a-fA-F]{6}$")) throw Errors.InvalidColorError(color); var patch = new MemberPatch { Color = Partial.Present(color.ToLowerInvariant()) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply(embed: new EmbedBuilder() .Title($"{Emojis.Success} Member color changed.") @@ -267,7 +267,7 @@ namespace PluralKit.Bot ctx.CheckOwnMember(target); var patch = new MemberPatch { Birthday = Partial.Null() }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Member birthdate cleared."); } @@ -292,7 +292,7 @@ namespace PluralKit.Bot if (birthday == null) throw Errors.BirthdayParseError(birthdayStr); var patch = new MemberPatch { Birthday = Partial.Present(birthday) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Member birthdate changed."); } @@ -304,7 +304,7 @@ namespace PluralKit.Bot MemberGuildSettings memberGuildConfig = null; if (ctx.Guild != null) - memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)); + memberGuildConfig = await _repo.GetMemberGuild(ctx.Guild.Id, target.Id); var eb = new EmbedBuilder() .Title($"Member names") @@ -341,7 +341,7 @@ namespace PluralKit.Bot var successStr = text; if (ctx.Guild != null) { - var memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)); + var memberGuildConfig = await _repo.GetMemberGuild(ctx.Guild.Id, target.Id); if (memberGuildConfig.DisplayName != null) successStr += $" However, this member has a server name set in this server ({ctx.Guild.Name}), and will be proxied using that name, \"{memberGuildConfig.DisplayName}\", here."; } @@ -379,7 +379,7 @@ namespace PluralKit.Bot if (await ctx.MatchClear("this member's display name")) { var patch = new MemberPatch { DisplayName = Partial.Null() }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await PrintSuccess($"{Emojis.Success} Member display name cleared. This member will now be proxied using their member name \"{target.NameFor(ctx)}\"."); } @@ -388,7 +388,7 @@ namespace PluralKit.Bot var newDisplayName = ctx.RemainderOrNull(skipFlags: false).NormalizeLineEndSpacing(); var patch = new MemberPatch { DisplayName = Partial.Present(newDisplayName) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await PrintSuccess($"{Emojis.Success} Member display name changed. This member will now be proxied using the name \"{newDisplayName}\"."); } @@ -403,10 +403,10 @@ namespace PluralKit.Bot noServerNameSetMessage += $" To set one, type `pk;member {target.Reference()} servername `."; // No perms check, display name isn't covered by member privacy + var memberGuildConfig = await _repo.GetMemberGuild(ctx.Guild.Id, target.Id); if (ctx.MatchRaw()) { - MemberGuildSettings memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)); if (memberGuildConfig.DisplayName == null) await ctx.Reply(noServerNameSetMessage); @@ -427,8 +427,7 @@ namespace PluralKit.Bot if (await ctx.MatchClear("this member's server name")) { - var patch = new MemberGuildPatch { DisplayName = null }; - await _db.Execute(conn => _repo.UpsertMemberGuild(conn, target.Id, ctx.Guild.Id, patch)); + await _repo.UpdateMemberGuild(target.Id, ctx.Guild.Id, new() { DisplayName = null }); if (target.DisplayName != null) await ctx.Reply($"{Emojis.Success} Member server name cleared. This member will now be proxied using their global display name \"{target.DisplayName}\" in this server ({ctx.Guild.Name})."); @@ -439,8 +438,7 @@ namespace PluralKit.Bot { var newServerName = ctx.RemainderOrNull(skipFlags: false).NormalizeLineEndSpacing(); - var patch = new MemberGuildPatch { DisplayName = newServerName }; - await _db.Execute(conn => _repo.UpsertMemberGuild(conn, target.Id, ctx.Guild.Id, patch)); + await _repo.UpdateMemberGuild(target.Id, ctx.Guild.Id, new() { DisplayName = newServerName }); await ctx.Reply($"{Emojis.Success} Member server name changed. This member will now be proxied using the name \"{newServerName}\" in this server ({ctx.Guild.Name})."); } @@ -464,7 +462,7 @@ namespace PluralKit.Bot }; var patch = new MemberPatch { KeepProxy = Partial.Present(newValue) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); if (newValue) await ctx.Reply($"{Emojis.Success} Member proxy tags will now be included in the resulting message when proxying."); @@ -491,7 +489,7 @@ namespace PluralKit.Bot }; var patch = new MemberPatch { AllowAutoproxy = Partial.Present(newValue) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); if (newValue) await ctx.Reply($"{Emojis.Success} Latch / front autoproxy have been **enabled** for this member."); @@ -523,11 +521,11 @@ namespace PluralKit.Bot // Get guild settings (mostly for warnings and such) MemberGuildSettings guildSettings = null; if (ctx.Guild != null) - guildSettings = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)); + guildSettings = await _repo.GetMemberGuild(ctx.Guild.Id, target.Id); async Task SetAll(PrivacyLevel level) { - await _db.Execute(c => _repo.UpdateMember(c, target.Id, new MemberPatch().WithAllPrivacy(level))); + await _repo.UpdateMember(target.Id, new MemberPatch().WithAllPrivacy(level)); if (level == PrivacyLevel.Private) await ctx.Reply($"{Emojis.Success} All {target.NameFor(ctx)}'s privacy settings have been set to **{level.LevelName()}**. Other accounts will now see nothing on the member card."); @@ -537,7 +535,7 @@ namespace PluralKit.Bot async Task SetLevel(MemberPrivacySubject subject, PrivacyLevel level) { - await _db.Execute(c => _repo.UpdateMember(c, target.Id, new MemberPatch().WithPrivacy(subject, level))); + await _repo.UpdateMember(target.Id, new MemberPatch().WithPrivacy(subject, level)); var subjectName = subject switch { @@ -596,7 +594,7 @@ namespace PluralKit.Bot await ctx.Reply($"{Emojis.Warn} Are you sure you want to delete \"{target.NameFor(ctx)}\"? If so, reply to this message with the member's ID (`{target.Hid}`). __***This cannot be undone!***__"); if (!await ctx.ConfirmWithReply(target.Hid)) throw Errors.MemberDeleteCancelled; - await _db.Execute(conn => _repo.DeleteMember(conn, target.Id)); + await _repo.DeleteMember(target.Id); await ctx.Reply($"{Emojis.Success} Member deleted."); } diff --git a/PluralKit.Bot/Commands/MemberGroup.cs b/PluralKit.Bot/Commands/MemberGroup.cs index 37cdc281..8be13ff1 100644 --- a/PluralKit.Bot/Commands/MemberGroup.cs +++ b/PluralKit.Bot/Commands/MemberGroup.cs @@ -29,8 +29,7 @@ namespace PluralKit.Bot .Distinct() .ToList(); - await using var conn = await _db.Obtain(); - var existingGroups = (await _repo.GetMemberGroups(conn, target.Id).ToListAsync()) + var existingGroups = (await _repo.GetMemberGroups(target.Id).ToListAsync()) .Select(g => g.Id) .Distinct() .ToList(); @@ -43,7 +42,7 @@ namespace PluralKit.Bot .Where(group => !existingGroups.Contains(group)) .ToList(); - await _repo.AddGroupsToMember(conn, target.Id, toAction); + await _repo.AddGroupsToMember(target.Id, toAction); } else if (op == Groups.AddRemoveOperation.Remove) { @@ -51,7 +50,7 @@ namespace PluralKit.Bot .Where(group => existingGroups.Contains(group)) .ToList(); - await _repo.RemoveGroupsFromMember(conn, target.Id, toAction); + await _repo.RemoveGroupsFromMember(target.Id, toAction); } else return; // otherwise toAction "may be unassigned" @@ -60,11 +59,9 @@ namespace PluralKit.Bot public async Task List(Context ctx, PKMember target) { - await using var conn = await _db.Obtain(); - var pctx = ctx.LookupContextFor(target.System); - var groups = await _repo.GetMemberGroups(conn, target.Id) + var groups = await _repo.GetMemberGroups(target.Id) .Where(g => g.Visibility.CanAccess(pctx)) .OrderBy(g => g.Name, StringComparer.InvariantCultureIgnoreCase) .ToListAsync(); diff --git a/PluralKit.Bot/Commands/MemberProxy.cs b/PluralKit.Bot/Commands/MemberProxy.cs index e46f2bb9..0651445f 100644 --- a/PluralKit.Bot/Commands/MemberProxy.cs +++ b/PluralKit.Bot/Commands/MemberProxy.cs @@ -57,7 +57,7 @@ namespace PluralKit.Bot } var patch = new MemberPatch { ProxyTags = Partial.Present(new ProxyTag[0]) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Proxy tags cleared."); } @@ -87,7 +87,7 @@ namespace PluralKit.Bot var newTags = target.ProxyTags.ToList(); newTags.Add(tagToAdd); var patch = new MemberPatch { ProxyTags = Partial.Present(newTags.ToArray()) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Added proxy tags {tagToAdd.ProxyString.AsCode()}."); } @@ -104,7 +104,7 @@ namespace PluralKit.Bot var newTags = target.ProxyTags.ToList(); newTags.Remove(tagToRemove); var patch = new MemberPatch { ProxyTags = Partial.Present(newTags.ToArray()) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Removed proxy tags {tagToRemove.ProxyString.AsCode()}."); } @@ -128,7 +128,7 @@ namespace PluralKit.Bot var newTags = new[] { requestedTag }; var patch = new MemberPatch { ProxyTags = Partial.Present(newTags) }; - await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); + await _repo.UpdateMember(target.Id, patch); await ctx.Reply($"{Emojis.Success} Member proxy tags set to {requestedTag.ProxyString.AsCode()}."); } diff --git a/PluralKit.Bot/Commands/Message.cs b/PluralKit.Bot/Commands/Message.cs index af177f01..6f7ba28f 100644 --- a/PluralKit.Bot/Commands/Message.cs +++ b/PluralKit.Bot/Commands/Message.cs @@ -1,4 +1,5 @@ #nullable enable +using System; using System.Threading.Tasks; using Myriad.Builders; @@ -105,8 +106,7 @@ namespace PluralKit.Bot private async Task FindRecentMessage(Context ctx) { - await using var conn = await _db.Obtain(); - var lastMessage = await _repo.GetLastMessage(conn, ctx.Guild.Id, ctx.Channel.Id, ctx.Author.Id); + var lastMessage = await _repo.GetLastMessage(ctx.Guild.Id, ctx.Channel.Id, ctx.Author.Id); if (lastMessage == null) return null; @@ -168,7 +168,7 @@ namespace PluralKit.Bot private async Task DeleteCommandMessage(Context ctx, ulong messageId) { - var message = await _db.Execute(conn => _repo.GetCommandMessage(conn, messageId)); + var message = await _repo.GetCommandMessage(messageId); if (message == null) throw Errors.MessageNotFound(messageId); diff --git a/PluralKit.Bot/Commands/Random.cs b/PluralKit.Bot/Commands/Random.cs index 79afe3ba..ff761f38 100644 --- a/PluralKit.Bot/Commands/Random.cs +++ b/PluralKit.Bot/Commands/Random.cs @@ -1,3 +1,4 @@ +using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -26,13 +27,10 @@ namespace PluralKit.Bot { ctx.CheckSystem(); - var members = await _db.Execute(c => - { - if (ctx.MatchFlag("all", "a")) - return _repo.GetSystemMembers(c, ctx.System.Id); - return _repo.GetSystemMembers(c, ctx.System.Id) - .Where(m => m.MemberVisibility == PrivacyLevel.Public); - }).ToListAsync(); + var members = await _repo.GetSystemMembers(ctx.System.Id).ToListAsync(); + + if (!ctx.MatchFlag("all", "a")) + members = members.Where(m => m.MemberVisibility == PrivacyLevel.Public).ToList(); if (members == null || !members.Any()) throw new PKError("Your system has no members! Please create at least one member before using this command."); diff --git a/PluralKit.Bot/Commands/ServerConfig.cs b/PluralKit.Bot/Commands/ServerConfig.cs index f8d70733..7849e190 100644 --- a/PluralKit.Bot/Commands/ServerConfig.cs +++ b/PluralKit.Bot/Commands/ServerConfig.cs @@ -29,10 +29,11 @@ namespace PluralKit.Bot public async Task SetLogChannel(Context ctx) { ctx.CheckGuildContext().CheckAuthorPermission(PermissionSet.ManageGuild, "Manage Server"); + await _repo.GetGuild(ctx.Guild.Id); if (await ctx.MatchClear("the server log channel")) { - await _db.Execute(conn => _repo.UpsertGuild(conn, ctx.Guild.Id, new GuildPatch { LogChannel = null })); + await _repo.UpdateGuild(ctx.Guild.Id, new() { LogChannel = null }); await ctx.Reply($"{Emojis.Success} Proxy logging channel cleared."); return; } @@ -45,8 +46,7 @@ namespace PluralKit.Bot channel = await ctx.MatchChannel(); if (channel == null || channel.GuildId != ctx.Guild.Id) throw Errors.ChannelNotFound(channelString); - var patch = new GuildPatch { LogChannel = channel.Id }; - await _db.Execute(conn => _repo.UpsertGuild(conn, ctx.Guild.Id, patch)); + await _repo.UpdateGuild(ctx.Guild.Id, new() { LogChannel = channel.Id }); await ctx.Reply($"{Emojis.Success} Proxy logging channel set to #{channel.Name}."); } @@ -67,19 +67,16 @@ namespace PluralKit.Bot } ulong? logChannel = null; - await using (var conn = await _db.Obtain()) - { - var config = await _repo.GetGuild(conn, ctx.Guild.Id); - logChannel = config.LogChannel; - var blacklist = config.LogBlacklist.ToHashSet(); - if (enable) - blacklist.ExceptWith(affectedChannels.Select(c => c.Id)); - else - blacklist.UnionWith(affectedChannels.Select(c => c.Id)); + var config = await _repo.GetGuild(ctx.Guild.Id); + logChannel = config.LogChannel; - var patch = new GuildPatch { LogBlacklist = blacklist.ToArray() }; - await _repo.UpsertGuild(conn, ctx.Guild.Id, patch); - } + var blacklist = config.LogBlacklist.ToHashSet(); + if (enable) + blacklist.ExceptWith(affectedChannels.Select(c => c.Id)); + else + blacklist.UnionWith(affectedChannels.Select(c => c.Id)); + + await _repo.UpdateGuild(ctx.Guild.Id, new() { LogBlacklist = blacklist.ToArray() }); await ctx.Reply( $"{Emojis.Success} Message logging for the given channels {(enable ? "enabled" : "disabled")}." + @@ -90,7 +87,7 @@ namespace PluralKit.Bot { ctx.CheckGuildContext().CheckAuthorPermission(PermissionSet.ManageGuild, "Manage Server"); - var blacklist = await _db.Execute(c => _repo.GetGuild(c, ctx.Guild.Id)); + var blacklist = await _repo.GetGuild(ctx.Guild.Id); // Resolve all channels from the cache and order by position var channels = blacklist.Blacklist @@ -151,18 +148,15 @@ namespace PluralKit.Bot affectedChannels.Add(channel); } - await using (var conn = await _db.Obtain()) - { - var guild = await _repo.GetGuild(conn, ctx.Guild.Id); - var blacklist = guild.Blacklist.ToHashSet(); - if (shouldAdd) - blacklist.UnionWith(affectedChannels.Select(c => c.Id)); - else - blacklist.ExceptWith(affectedChannels.Select(c => c.Id)); + var guild = await _repo.GetGuild(ctx.Guild.Id); - var patch = new GuildPatch { Blacklist = blacklist.ToArray() }; - await _repo.UpsertGuild(conn, ctx.Guild.Id, patch); - } + var blacklist = guild.Blacklist.ToHashSet(); + if (shouldAdd) + blacklist.UnionWith(affectedChannels.Select(c => c.Id)); + else + blacklist.ExceptWith(affectedChannels.Select(c => c.Id)); + + await _repo.UpdateGuild(ctx.Guild.Id, new() { Blacklist = blacklist.ToArray() }); await ctx.Reply($"{Emojis.Success} Channels {(shouldAdd ? "added to" : "removed from")} the proxy blacklist."); } @@ -184,7 +178,7 @@ namespace PluralKit.Bot .Title("Log cleanup settings") .Field(new("Supported bots", botList)); - var guildCfg = await _db.Execute(c => _repo.GetGuild(c, ctx.Guild.Id)); + var guildCfg = await _repo.GetGuild(ctx.Guild.Id); if (guildCfg.LogCleanupEnabled) eb.Description("Log cleanup is currently **on** for this server. To disable it, type `pk;logclean off`."); else @@ -193,8 +187,7 @@ namespace PluralKit.Bot return; } - var patch = new GuildPatch { LogCleanupEnabled = newValue }; - await _db.Execute(conn => _repo.UpsertGuild(conn, ctx.Guild.Id, patch)); + await _repo.UpdateGuild(ctx.Guild.Id, new() { LogCleanupEnabled = newValue }); if (newValue) await ctx.Reply($"{Emojis.Success} Log cleanup has been **enabled** for this server. Messages deleted by PluralKit will now be cleaned up from logging channels managed by the following bots:\n- **{botList}**\n\n{Emojis.Note} Make sure PluralKit has the **Manage Messages** permission in the channels in question.\n{Emojis.Note} Also, make sure to blacklist the logging channel itself from the bots in question to prevent conflicts."); diff --git a/PluralKit.Bot/Commands/Switch.cs b/PluralKit.Bot/Commands/Switch.cs index 3581aa9f..2aa425b7 100644 --- a/PluralKit.Bot/Commands/Switch.cs +++ b/PluralKit.Bot/Commands/Switch.cs @@ -45,7 +45,7 @@ namespace PluralKit.Bot // Find the last switch and its members if applicable await using var conn = await _db.Obtain(); - var lastSwitch = await _repo.GetLatestSwitch(conn, ctx.System.Id); + var lastSwitch = await _repo.GetLatestSwitch(ctx.System.Id); if (lastSwitch != null) { var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastSwitch.Id); @@ -72,13 +72,12 @@ namespace PluralKit.Bot var result = DateUtils.ParseDateTime(timeToMove, true, tz); if (result == null) throw Errors.InvalidDateTime(timeToMove); - await using var conn = await _db.Obtain(); var time = result.Value; if (time.ToInstant() > SystemClock.Instance.GetCurrentInstant()) throw Errors.SwitchTimeInFuture; // Fetch the last two switches for the system to do bounds checking on - var lastTwoSwitches = await _repo.GetSwitches(conn, ctx.System.Id).Take(2).ToListAsync(); + var lastTwoSwitches = await _repo.GetSwitches(ctx.System.Id).Take(2).ToListAsync(); // If we don't have a switch to move, don't bother if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches; @@ -92,7 +91,7 @@ namespace PluralKit.Bot // Now we can actually do the move, yay! // But, we do a prompt to confirm. - var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[0].Id); + var lastSwitchMembers = _db.Execute(conn => _repo.GetSwitchMembers(conn, lastTwoSwitches[0].Id)); var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync()); var lastSwitchTime = lastTwoSwitches[0].Timestamp.ToUnixTimeSeconds(); // .FormatZoned(ctx.System) var lastSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp).FormatDuration(); @@ -104,7 +103,7 @@ namespace PluralKit.Bot if (!await ctx.PromptYesNo(msg, "Move Switch")) throw Errors.SwitchMoveCancelled; // aaaand *now* we do the move - await _repo.MoveSwitch(conn, lastTwoSwitches[0].Id, time.ToInstant()); + await _repo.MoveSwitch(lastTwoSwitches[0].Id, time.ToInstant()); await ctx.Reply($"{Emojis.Success} Switch moved to ({newSwitchDeltaStr} ago)."); } @@ -130,7 +129,7 @@ namespace PluralKit.Bot // Find the switch to edit await using var conn = await _db.Obtain(); - var lastSwitch = await _repo.GetLatestSwitch(conn, ctx.System.Id); + var lastSwitch = await _repo.GetLatestSwitch(ctx.System.Id); // Make sure there's at least one switch if (lastSwitch == null) throw Errors.NoRegisteredSwitches; var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastSwitch.Id); @@ -170,18 +169,16 @@ namespace PluralKit.Bot var purgeMsg = $"{Emojis.Warn} This will delete *all registered switches* in your system. Are you sure you want to proceed?"; if (!await ctx.PromptYesNo(purgeMsg, "Clear Switches")) throw Errors.GenericCancelled(); - await _db.Execute(c => _repo.DeleteAllSwitches(c, ctx.System.Id)); + await _repo.DeleteAllSwitches(ctx.System.Id); await ctx.Reply($"{Emojis.Success} Cleared system switches!"); return; } - await using var conn = await _db.Obtain(); - // Fetch the last two switches for the system to do bounds checking on - var lastTwoSwitches = await _repo.GetSwitches(conn, ctx.System.Id).Take(2).ToListAsync(); + var lastTwoSwitches = await _repo.GetSwitches(ctx.System.Id).Take(2).ToListAsync(); if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches; - var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[0].Id); + var lastSwitchMembers = _db.Execute(conn => _repo.GetSwitchMembers(conn, lastTwoSwitches[0].Id)); var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync()); var lastSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp).FormatDuration(); @@ -192,14 +189,14 @@ namespace PluralKit.Bot } else { - var secondSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[1].Id); + var secondSwitchMembers = _db.Execute(conn => _repo.GetSwitchMembers(conn, lastTwoSwitches[1].Id)); var secondSwitchMemberStr = string.Join(", ", await secondSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync()); var secondSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[1].Timestamp).FormatDuration(); msg = $"{Emojis.Warn} This will delete the latest switch ({lastSwitchMemberStr}, {lastSwitchDeltaStr} ago). The next latest switch is {secondSwitchMemberStr} ({secondSwitchDeltaStr} ago). Is this okay?"; } if (!await ctx.PromptYesNo(msg, "Delete Switch")) throw Errors.SwitchDeleteCancelled; - await _repo.DeleteSwitch(conn, lastTwoSwitches[0].Id); + await _repo.DeleteSwitch(lastTwoSwitches[0].Id); await ctx.Reply($"{Emojis.Success} Switch deleted."); } diff --git a/PluralKit.Bot/Commands/System.cs b/PluralKit.Bot/Commands/System.cs index acd25cea..3013679c 100644 --- a/PluralKit.Bot/Commands/System.cs +++ b/PluralKit.Bot/Commands/System.cs @@ -32,12 +32,8 @@ namespace PluralKit.Bot if (systemName != null && systemName.Length > Limits.MaxSystemNameLength) throw Errors.StringTooLongError("System name", systemName.Length, Limits.MaxSystemNameLength); - var system = _db.Execute(async c => - { - var system = await _repo.CreateSystem(c, systemName); - await _repo.AddAccount(c, system.Id, ctx.Author.Id); - return system; - }); + var system = await _repo.CreateSystem(systemName); + await _repo.AddAccount(system.Id, ctx.Author.Id); // TODO: better message, perhaps embed like in groups? await ctx.Reply($"{Emojis.Success} Your system has been created. Type `pk;system` to view it, and type `pk;system help` for more information about commands you can use now. Now that you have that set up, check out the getting started guide on setting up members and proxies: "); diff --git a/PluralKit.Bot/Commands/SystemEdit.cs b/PluralKit.Bot/Commands/SystemEdit.cs index 0dc684ae..7b2ae9ed 100644 --- a/PluralKit.Bot/Commands/SystemEdit.cs +++ b/PluralKit.Bot/Commands/SystemEdit.cs @@ -52,8 +52,7 @@ namespace PluralKit.Bot if (await ctx.MatchClear("your system's name")) { - var clearPatch = new SystemPatch { Name = null }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, clearPatch)); + await _repo.UpdateSystem(ctx.System.Id, new() { Name = null }); await ctx.Reply($"{Emojis.Success} System name cleared."); } @@ -64,8 +63,7 @@ namespace PluralKit.Bot if (newSystemName.Length > Limits.MaxSystemNameLength) throw Errors.StringTooLongError("System name", newSystemName.Length, Limits.MaxSystemNameLength); - var patch = new SystemPatch { Name = newSystemName }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); + await _repo.UpdateSystem(ctx.System.Id, new() { Name = newSystemName }); await ctx.Reply($"{Emojis.Success} System name changed."); } @@ -100,11 +98,9 @@ namespace PluralKit.Bot if (await ctx.MatchClear("your system's description")) { - var patch = new SystemPatch { Description = null }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); + await _repo.UpdateSystem(ctx.System.Id, new() { Description = null }); await ctx.Reply($"{Emojis.Success} System description cleared."); - return; } else { @@ -112,8 +108,7 @@ namespace PluralKit.Bot if (newDescription.Length > Limits.MaxDescriptionLength) throw Errors.StringTooLongError("Description", newDescription.Length, Limits.MaxDescriptionLength); - var patch = new SystemPatch { Description = newDescription }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); + await _repo.UpdateSystem(ctx.System.Id, new() { Description = newDescription }); await ctx.Reply($"{Emojis.Success} System description changed."); } @@ -125,8 +120,7 @@ namespace PluralKit.Bot if (await ctx.MatchClear()) { - var patch = new SystemPatch { Color = Partial.Null() }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); + await _repo.UpdateSystem(ctx.System.Id, new() { Color = Partial.Null() }); await ctx.Reply($"{Emojis.Success} System color cleared."); } @@ -150,8 +144,7 @@ namespace PluralKit.Bot if (color.StartsWith("#")) color = color.Substring(1); if (!Regex.IsMatch(color, "^[0-9a-fA-F]{6}$")) throw Errors.InvalidColorError(color); - var patch = new SystemPatch { Color = Partial.Present(color.ToLowerInvariant()) }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); + await _repo.UpdateSystem(ctx.System.Id, new() { Color = Partial.Present(color.ToLowerInvariant()) }); await ctx.Reply(embed: new EmbedBuilder() .Title($"{Emojis.Success} System color changed.") @@ -186,8 +179,7 @@ namespace PluralKit.Bot if (await ctx.MatchClear("your system's tag")) { - var patch = new SystemPatch { Tag = null }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); + await _repo.UpdateSystem(ctx.System.Id, new() { Tag = null }); await ctx.Reply($"{Emojis.Success} System tag cleared."); } @@ -198,8 +190,7 @@ namespace PluralKit.Bot if (newTag.Length > Limits.MaxSystemTagLength) throw Errors.StringTooLongError("System tag", newTag.Length, Limits.MaxSystemTagLength); - var patch = new SystemPatch { Tag = newTag }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); + await _repo.UpdateSystem(ctx.System.Id, new() { Tag = newTag }); await ctx.Reply($"{Emojis.Success} System tag changed. Member names will now end with {newTag.AsCode()} when proxied."); } @@ -211,18 +202,20 @@ namespace PluralKit.Bot var setDisabledWarning = $"{Emojis.Warn} Your system tag is currently **disabled** in this server. No tag will be applied when proxying.\nTo re-enable the system tag in the current server, type `pk;s servertag -enable`."; + var settings = await _repo.GetSystemGuild(ctx.Guild.Id, ctx.System.Id); + async Task Show(bool raw = false) { - if (ctx.MessageContext.SystemGuildTag != null) + if (settings.Tag != null) { if (raw) { - await ctx.Reply($"```{ctx.MessageContext.SystemGuildTag}```"); + await ctx.Reply($"```{settings.Tag}```"); return; } - var msg = $"Your current system tag in '{ctx.Guild.Name}' is {ctx.MessageContext.SystemGuildTag.AsCode()}"; - if (!ctx.MessageContext.TagEnabled) + var msg = $"Your current system tag in '{ctx.Guild.Name}' is {settings.Tag.AsCode()}"; + if (!settings.TagEnabled) msg += ", but it is currently **disabled**. To re-enable it, type `pk;s servertag -enable`."; else msg += ". To change it, type `pk;s servertag `. To clear it, type `pk;s servertag -clear`."; @@ -231,7 +224,7 @@ namespace PluralKit.Bot return; } - else if (!ctx.MessageContext.TagEnabled) + else if (!settings.TagEnabled) await ctx.Reply($"Your global system tag is {ctx.System.Tag}, but it is **disabled** in this server. To re-enable it, type `pk;s servertag -enable`"); else await ctx.Reply($"You currently have no system tag specific to the server '{ctx.Guild.Name}'. To set one, type `pk;s servertag `. To disable the system tag in the current server, type `pk;s servertag -disable`."); @@ -243,8 +236,7 @@ namespace PluralKit.Bot if (newTag != null && newTag.Length > Limits.MaxSystemTagLength) throw Errors.StringTooLongError("System server tag", newTag.Length, Limits.MaxSystemTagLength); - var patch = new SystemGuildPatch { Tag = newTag }; - await _db.Execute(conn => _repo.UpsertSystemGuild(conn, ctx.System.Id, ctx.Guild.Id, patch)); + await _repo.UpdateSystemGuild(ctx.System.Id, ctx.Guild.Id, new() { Tag = newTag }); await ctx.Reply($"{Emojis.Success} System server tag changed. Member names will now end with {newTag.AsCode()} when proxied in the current server '{ctx.Guild.Name}'."); @@ -254,8 +246,7 @@ namespace PluralKit.Bot async Task Clear() { - var patch = new SystemGuildPatch { Tag = null }; - await _db.Execute(conn => _repo.UpsertSystemGuild(conn, ctx.System.Id, ctx.Guild.Id, patch)); + await _repo.UpdateSystemGuild(ctx.System.Id, ctx.Guild.Id, new() { Tag = null }); await ctx.Reply($"{Emojis.Success} System server tag cleared. Member names will now end with the global system tag, if there is one set."); @@ -265,8 +256,7 @@ namespace PluralKit.Bot async Task EnableDisable(bool newValue) { - var patch = new SystemGuildPatch { TagEnabled = newValue }; - await _db.Execute(conn => _repo.UpsertSystemGuild(conn, ctx.System.Id, ctx.Guild.Id, patch)); + await _repo.UpdateSystemGuild(ctx.System.Id, ctx.Guild.Id, new() { TagEnabled = newValue }); await ctx.Reply(PrintEnableDisableResult(newValue, newValue != ctx.MessageContext.TagEnabled)); } @@ -320,7 +310,7 @@ namespace PluralKit.Bot async Task ClearIcon() { - await _db.Execute(c => _repo.UpdateSystem(c, ctx.System.Id, new SystemPatch { AvatarUrl = null })); + await _repo.UpdateSystem(ctx.System.Id, new() { AvatarUrl = null }); await ctx.Reply($"{Emojis.Success} System icon cleared."); } @@ -328,7 +318,7 @@ namespace PluralKit.Bot { await AvatarUtils.VerifyAvatarOrThrow(_client, img.Url); - await _db.Execute(c => _repo.UpdateSystem(c, ctx.System.Id, new SystemPatch { AvatarUrl = img.Url })); + await _repo.UpdateSystem(ctx.System.Id, new() { AvatarUrl = img.Url }); var msg = img.Source switch { @@ -373,7 +363,7 @@ namespace PluralKit.Bot async Task ClearImage() { - await _db.Execute(c => _repo.UpdateSystem(c, ctx.System.Id, new SystemPatch { BannerImage = null })); + await _repo.UpdateSystem(ctx.System.Id, new() { BannerImage = null }); await ctx.Reply($"{Emojis.Success} System banner image cleared."); } @@ -381,7 +371,7 @@ namespace PluralKit.Bot { await AvatarUtils.VerifyAvatarOrThrow(_client, img.Url, isFullSizeImage: true); - await _db.Execute(c => _repo.UpdateSystem(c, ctx.System.Id, new SystemPatch { BannerImage = img.Url })); + await _repo.UpdateSystem(ctx.System.Id, new() { BannerImage = img.Url }); var msg = img.Source switch { @@ -428,7 +418,7 @@ namespace PluralKit.Bot if (!await ctx.ConfirmWithReply(ctx.System.Hid)) throw new PKError($"System deletion cancelled. Note that you must reply with your system ID (`{ctx.System.Hid}`) *verbatim*."); - await _db.Execute(conn => _repo.DeleteSystem(conn, ctx.System.Id)); + await _repo.DeleteSystem(ctx.System.Id); await ctx.Reply($"{Emojis.Success} System deleted."); } @@ -440,7 +430,7 @@ namespace PluralKit.Bot var guild = ctx.MatchGuild() ?? ctx.Guild ?? throw new PKError("You must run this command in a server or pass a server ID."); - var gs = await _db.Execute(c => _repo.GetSystemGuild(c, guild.Id, ctx.System.Id)); + var gs = await _repo.GetSystemGuild(guild.Id, ctx.System.Id); string serverText; if (guild.Id == ctx.Guild?.Id) @@ -461,8 +451,7 @@ namespace PluralKit.Bot return; } - var patch = new SystemGuildPatch { ProxyEnabled = newValue }; - await _db.Execute(conn => _repo.UpsertSystemGuild(conn, ctx.System.Id, guild.Id, patch)); + await _repo.UpdateSystemGuild(ctx.System.Id, guild.Id, new() { ProxyEnabled = newValue }); if (newValue) await ctx.Reply($"Message proxying in {serverText} is now **enabled** for your system."); @@ -476,8 +465,7 @@ namespace PluralKit.Bot if (await ctx.MatchClear()) { - var clearPatch = new SystemPatch { UiTz = "UTC" }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, clearPatch)); + await _repo.UpdateSystem(ctx.System.Id, new() { UiTz = "UTC" }); await ctx.Reply($"{Emojis.Success} System time zone cleared (set to UTC)."); return; @@ -498,8 +486,7 @@ namespace PluralKit.Bot var msg = $"This will change the system time zone to **{zone.Id}**. The current time is **{currentTime.FormatZoned()}**. Is this correct?"; if (!await ctx.PromptYesNo(msg, "Change Timezone")) throw Errors.TimezoneChangeCancelled; - var patch = new SystemPatch { UiTz = zone.Id }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); + await _repo.UpdateSystem(ctx.System.Id, new() { UiTz = zone.Id }); await ctx.Reply($"System time zone changed to **{zone.Id}**."); } @@ -523,7 +510,7 @@ namespace PluralKit.Bot async Task SetLevel(SystemPrivacySubject subject, PrivacyLevel level) { - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, new SystemPatch().WithPrivacy(subject, level))); + await _repo.UpdateSystem(ctx.System.Id, new SystemPatch().WithPrivacy(subject, level)); var levelExplanation = level switch { @@ -548,7 +535,7 @@ namespace PluralKit.Bot async Task SetAll(PrivacyLevel level) { - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, new SystemPatch().WithAllPrivacy(level))); + await _repo.UpdateSystem(ctx.System.Id, new SystemPatch().WithAllPrivacy(level)); var msg = level switch { @@ -581,15 +568,13 @@ namespace PluralKit.Bot { if (ctx.Match("on", "enable")) { - var patch = new SystemPatch { PingsEnabled = true }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); + await _repo.UpdateSystem(ctx.System.Id, new() { PingsEnabled = true }); await ctx.Reply("Reaction pings have now been enabled."); } if (ctx.Match("off", "disable")) { - var patch = new SystemPatch { PingsEnabled = false }; - await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); + await _repo.UpdateSystem(ctx.System.Id, new() { PingsEnabled = false }); await ctx.Reply("Reaction pings have now been disabled."); } diff --git a/PluralKit.Bot/Commands/SystemFront.cs b/PluralKit.Bot/Commands/SystemFront.cs index beb5da6a..201a258a 100644 --- a/PluralKit.Bot/Commands/SystemFront.cs +++ b/PluralKit.Bot/Commands/SystemFront.cs @@ -39,9 +39,7 @@ namespace PluralKit.Bot if (system == null) throw Errors.NoSystemError; ctx.CheckSystemPrivacy(system, system.FrontPrivacy); - await using var conn = await _db.Obtain(); - - var sw = await _repo.GetLatestSwitch(conn, system.Id); + var sw = await _repo.GetLatestSwitch(system.Id); if (sw == null) throw Errors.NoRegisteredSwitches; await ctx.Reply(embed: await _embeds.CreateFronterEmbed(sw, system.Zone, ctx.LookupContextFor(system))); @@ -53,12 +51,13 @@ namespace PluralKit.Bot ctx.CheckSystemPrivacy(system, system.FrontHistoryPrivacy); // Gotta be careful here: if we dispose of the connection while the IAE is alive, boom - await using var conn = await _db.Obtain(); + // todo: this comment was here, but we're not getting a connection here anymore + // hopefully nothing breaks? - var totalSwitches = await _repo.GetSwitchCount(conn, system.Id); + var totalSwitches = await _repo.GetSwitchCount(system.Id); if (totalSwitches == 0) throw Errors.NoRegisteredSwitches; - var sws = _repo.GetSwitches(conn, system.Id) + var sws = _repo.GetSwitches(system.Id) .Scan(new FrontHistoryEntry(null, null), (lastEntry, newSwitch) => new FrontHistoryEntry(lastEntry.ThisSwitch?.Timestamp, newSwitch)); @@ -80,7 +79,6 @@ namespace PluralKit.Bot var sw = entry.ThisSwitch; // Fetch member list and format - await using var conn = await _db.Obtain(); var members = await _db.Execute(c => _repo.GetSwitchMembers(c, sw.Id)).ToListAsync(); var membersStr = members.Any() ? string.Join(", ", members.Select(m => m.NameFor(ctx))) : "no fronter"; @@ -117,7 +115,7 @@ namespace PluralKit.Bot if (system == null) throw Errors.NoSystemError; ctx.CheckSystemPrivacy(system, system.FrontHistoryPrivacy); - var totalSwitches = await _db.Execute(conn => _repo.GetSwitchCount(conn, system.Id)); + var totalSwitches = await _repo.GetSwitchCount(system.Id); if (totalSwitches == 0) throw Errors.NoRegisteredSwitches; string durationStr = ctx.RemainderOrNull() ?? "30d"; diff --git a/PluralKit.Bot/Commands/SystemLink.cs b/PluralKit.Bot/Commands/SystemLink.cs index d81d8154..7f9cbfa4 100644 --- a/PluralKit.Bot/Commands/SystemLink.cs +++ b/PluralKit.Bot/Commands/SystemLink.cs @@ -23,20 +23,18 @@ namespace PluralKit.Bot { ctx.CheckSystem(); - await using var conn = await _db.Obtain(); - var account = await ctx.MatchUser() ?? throw new PKSyntaxError("You must pass an account to link with (either ID or @mention)."); - var accountIds = await _repo.GetSystemAccounts(conn, ctx.System.Id); + var accountIds = await _repo.GetSystemAccounts(ctx.System.Id); if (accountIds.Contains(account.Id)) throw Errors.AccountAlreadyLinked; - var existingAccount = await _repo.GetSystemByAccount(conn, account.Id); + var existingAccount = await _repo.GetSystemByAccount(account.Id); if (existingAccount != null) throw Errors.AccountInOtherSystem(existingAccount); var msg = $"{account.Mention()}, please confirm the link."; if (!await ctx.PromptYesNo(msg, "Confirm", user: account, matchFlag: false)) throw Errors.MemberLinkCancelled; - await _repo.AddAccount(conn, ctx.System.Id, account.Id); + await _repo.AddAccount(ctx.System.Id, account.Id); await ctx.Reply($"{Emojis.Success} Account linked to system."); } @@ -44,20 +42,18 @@ namespace PluralKit.Bot { ctx.CheckSystem(); - await using var conn = await _db.Obtain(); - ulong id; if (!ctx.MatchUserRaw(out id)) throw new PKSyntaxError("You must pass an account to link with (either ID or @mention)."); - var accountIds = (await _repo.GetSystemAccounts(conn, ctx.System.Id)).ToList(); + var accountIds = (await _repo.GetSystemAccounts(ctx.System.Id)).ToList(); if (!accountIds.Contains(id)) throw Errors.AccountNotLinked; if (accountIds.Count == 1) throw Errors.UnlinkingLastAccount; var msg = $"Are you sure you want to unlink <@{id}> from your system?"; if (!await ctx.PromptYesNo(msg, "Unlink")) throw Errors.MemberUnlinkCancelled; - await _repo.RemoveAccount(conn, ctx.System.Id, id); + await _repo.RemoveAccount(ctx.System.Id, id); await ctx.Reply($"{Emojis.Success} Account unlinked."); } } diff --git a/PluralKit.Bot/Commands/Token.cs b/PluralKit.Bot/Commands/Token.cs index f46baf6d..8b4c8dfd 100644 --- a/PluralKit.Bot/Commands/Token.cs +++ b/PluralKit.Bot/Commands/Token.cs @@ -50,8 +50,7 @@ namespace PluralKit.Bot private async Task MakeAndSetNewToken(PKSystem system) { - var patch = new SystemPatch { Token = StringUtils.GenerateToken() }; - system = await _db.Execute(conn => _repo.UpdateSystem(conn, system.Id, patch)); + system = await _repo.UpdateSystem(system.Id, new() { Token = StringUtils.GenerateToken() }); return system.Token; } diff --git a/PluralKit.Bot/Handlers/MessageCreated.cs b/PluralKit.Bot/Handlers/MessageCreated.cs index a10bcdc2..7f1115a9 100644 --- a/PluralKit.Bot/Handlers/MessageCreated.cs +++ b/PluralKit.Bot/Handlers/MessageCreated.cs @@ -71,9 +71,8 @@ namespace PluralKit.Bot // Get message context from DB (tracking w/ metrics) MessageContext ctx; - await using (var conn = await _db.Obtain()) using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) - ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.GuildId ?? default, rootChannel.Id); + ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel.Id); // Try each handler until we find one that succeeds if (await TryHandleLogClean(evt, ctx)) @@ -114,7 +113,7 @@ namespace PluralKit.Bot try { - var system = ctx.SystemId != null ? await _db.Execute(c => _repo.GetSystem(c, ctx.SystemId.Value)) : null; + var system = ctx.SystemId != null ? await _repo.GetSystem(ctx.SystemId.Value) : null; await _tree.ExecuteCommand(new Context(_services, shard, guild, channel, evt, cmdStart, system, ctx)); } catch (PKError) diff --git a/PluralKit.Bot/Handlers/MessageDeleted.cs b/PluralKit.Bot/Handlers/MessageDeleted.cs index b575b482..cb78fefd 100644 --- a/PluralKit.Bot/Handlers/MessageDeleted.cs +++ b/PluralKit.Bot/Handlers/MessageDeleted.cs @@ -36,7 +36,7 @@ namespace PluralKit.Bot async Task Inner() { await Task.Delay(MessageDeleteDelay); - await _db.Execute(c => _repo.DeleteMessage(c, evt.Id)); + await _repo.DeleteMessage(evt.Id); } _lastMessage.HandleMessageDeletion(evt.ChannelId, evt.Id); @@ -56,7 +56,7 @@ namespace PluralKit.Bot _logger.Information("Bulk deleting {Count} messages in channel {Channel}", evt.Ids.Length, evt.ChannelId); - await _db.Execute(c => _repo.DeleteMessagesBulk(c, evt.Ids)); + await _repo.DeleteMessagesBulk(evt.Ids); } _lastMessage.HandleMessageDeletion(evt.ChannelId, evt.Ids.ToList()); diff --git a/PluralKit.Bot/Handlers/MessageEdited.cs b/PluralKit.Bot/Handlers/MessageEdited.cs index f22acfc6..32760e0b 100644 --- a/PluralKit.Bot/Handlers/MessageEdited.cs +++ b/PluralKit.Bot/Handlers/MessageEdited.cs @@ -63,9 +63,8 @@ namespace PluralKit.Bot // Just run the normal message handling code, with a flag to disable autoproxying MessageContext ctx; - await using (var conn = await _db.Obtain()) using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) - ctx = await _repo.GetMessageContext(conn, evt.Author.Value!.Id, channel.GuildId!.Value, evt.ChannelId); + ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, evt.ChannelId); var equivalentEvt = await GetMessageCreateEvent(evt, lastMessage, channel); var botPermissions = _bot.PermissionsIn(channel.Id); diff --git a/PluralKit.Bot/Handlers/ReactionAdded.cs b/PluralKit.Bot/Handlers/ReactionAdded.cs index 47f1f3b1..c76677c5 100644 --- a/PluralKit.Bot/Handlers/ReactionAdded.cs +++ b/PluralKit.Bot/Handlers/ReactionAdded.cs @@ -73,7 +73,7 @@ namespace PluralKit.Bot return; } - var commandMsg = await _db.Execute(c => _commandMessageService.GetCommandMessage(c, evt.MessageId)); + var commandMsg = await _commandMessageService.GetCommandMessage(evt.MessageId); if (commandMsg != null) { await HandleCommandDeleteReaction(evt, commandMsg); @@ -124,7 +124,7 @@ namespace PluralKit.Bot if (!_bot.PermissionsIn(evt.ChannelId).HasFlag(PermissionSet.ManageMessages)) return; - var system = await _db.Execute(c => _repo.GetSystemByAccount(c, evt.UserId)); + var system = await _repo.GetSystemByAccount(evt.UserId); // Can only delete your own message if (msg.System.Id != system?.Id) return; @@ -138,7 +138,7 @@ namespace PluralKit.Bot // Message was deleted by something/someone else before we got to it } - await _db.Execute(c => _repo.DeleteMessage(c, evt.MessageId)); + await _repo.DeleteMessage(evt.MessageId); } private async ValueTask HandleCommandDeleteReaction(MessageReactionAddEvent evt, CommandMessage? msg) diff --git a/PluralKit.Bot/Proxy/ProxyService.cs b/PluralKit.Bot/Proxy/ProxyService.cs index 13f952d4..74d61293 100644 --- a/PluralKit.Bot/Proxy/ProxyService.cs +++ b/PluralKit.Bot/Proxy/ProxyService.cs @@ -61,10 +61,7 @@ namespace PluralKit.Bot List members; // Fetch members and try to match to a specific member using (_metrics.Measure.Timer.Time(BotMetrics.ProxyMembersQueryTime)) - { - await using var conn = await _db.Obtain(); - members = (await _repo.GetProxyMembers(conn, message.Author.Id, message.GuildId!.Value)).ToList(); - } + members = (await _repo.GetProxyMembers(message.Author.Id, message.GuildId!.Value)).ToList(); if (!_matcher.TryMatch(ctx, members, out var match, message.Content, message.Attachments.Length > 0, allowAutoproxy)) return false; @@ -293,11 +290,8 @@ namespace PluralKit.Bot Sender = triggerMessage.Author.Id }; - async Task SaveMessageInDatabase() - { - await using var conn = await _db.Obtain(); - await _repo.AddMessage(conn, sentMessage); - } + Task SaveMessageInDatabase() + => _repo.AddMessage(sentMessage); Task LogMessageToChannel() => _logChannel.LogMessage(ctx, sentMessage, triggerMessage, proxyMessage).AsTask(); diff --git a/PluralKit.Bot/Services/CommandMessageService.cs b/PluralKit.Bot/Services/CommandMessageService.cs index fc205d50..b3007686 100644 --- a/PluralKit.Bot/Services/CommandMessageService.cs +++ b/PluralKit.Bot/Services/CommandMessageService.cs @@ -28,12 +28,12 @@ namespace PluralKit.Bot public async Task RegisterMessage(ulong messageId, ulong channelId, ulong authorId) { _logger.Debug("Registering command response {MessageId} from author {AuthorId} in {ChannelId}", messageId, authorId, channelId); - await _db.Execute(conn => _repo.SaveCommandMessage(conn, messageId, channelId, authorId)); + await _repo.SaveCommandMessage(messageId, channelId, authorId); } - public async Task GetCommandMessage(IPKConnection conn, ulong messageId) + public async Task GetCommandMessage(ulong messageId) { - return await _repo.GetCommandMessage(conn, messageId); + return await _repo.GetCommandMessage(messageId); } public async Task CleanupOldMessages() @@ -41,7 +41,7 @@ namespace PluralKit.Bot var deleteThresholdInstant = _clock.GetCurrentInstant() - CommandMessageRetention; var deleteThresholdSnowflake = DiscordUtils.InstantToSnowflake(deleteThresholdInstant); - var deletedRows = await _db.Execute(conn => _repo.DeleteCommandMessagesBefore(conn, deleteThresholdSnowflake)); + var deletedRows = await _repo.DeleteCommandMessagesBefore(deleteThresholdSnowflake); _logger.Information("Pruned {DeletedRows} command messages older than retention {Retention} (older than {DeleteThresholdInstant} / {DeleteThresholdSnowflake})", deletedRows, CommandMessageRetention, deleteThresholdInstant, deleteThresholdSnowflake); diff --git a/PluralKit.Bot/Services/EmbedService.cs b/PluralKit.Bot/Services/EmbedService.cs index 7b6a18d9..6dbb1716 100644 --- a/PluralKit.Bot/Services/EmbedService.cs +++ b/PluralKit.Bot/Services/EmbedService.cs @@ -46,13 +46,12 @@ namespace PluralKit.Bot public async Task CreateSystemEmbed(Context cctx, PKSystem system, LookupContext ctx) { - await using var conn = await _db.Obtain(); // Fetch/render info for all accounts simultaneously - var accounts = await _repo.GetSystemAccounts(conn, system.Id); + var accounts = await _repo.GetSystemAccounts(system.Id); var users = (await GetUsers(accounts)).Select(x => x.User?.NameAndMention() ?? $"(deleted account {x.Id})"); - var memberCount = cctx.MatchPrivateFlag(ctx) ? await _repo.GetSystemMemberCount(conn, system.Id, PrivacyLevel.Public) : await _repo.GetSystemMemberCount(conn, system.Id); + var memberCount = cctx.MatchPrivateFlag(ctx) ? await _repo.GetSystemMemberCount(system.Id, PrivacyLevel.Public) : await _repo.GetSystemMemberCount(system.Id); uint color; try @@ -74,10 +73,10 @@ namespace PluralKit.Bot if (system.DescriptionPrivacy.CanAccess(ctx)) eb.Image(new(system.BannerImage)); - var latestSwitch = await _repo.GetLatestSwitch(conn, system.Id); + var latestSwitch = await _repo.GetLatestSwitch(system.Id); if (latestSwitch != null && system.FrontPrivacy.CanAccess(ctx)) { - var switchMembers = await _repo.GetSwitchMembers(conn, latestSwitch.Id).ToListAsync(); + var switchMembers = await _db.Execute(conn => _repo.GetSwitchMembers(conn, latestSwitch.Id)).ToListAsync(); if (switchMembers.Count > 0) eb.Field(new("Fronter".ToQuantity(switchMembers.Count, ShowQuantityAs.None), string.Join(", ", switchMembers.Select(m => m.NameFor(ctx))))); } @@ -87,7 +86,7 @@ namespace PluralKit.Bot if (cctx.Guild != null) { - var guildSettings = await _repo.GetSystemGuild(conn, cctx.Guild.Id, system.Id); + var guildSettings = await _repo.GetSystemGuild(cctx.Guild.Id, system.Id); if (guildSettings.Tag != null && guildSettings.TagEnabled) eb.Field(new($"Tag (in server '{cctx.Guild.Name}')", guildSettings.Tag @@ -151,18 +150,16 @@ namespace PluralKit.Bot catch (ArgumentException) { // Bad API use can cause an invalid color string - // TODO: fix that in the API - // for now we just default to a blank color, yolo + // this is now fixed in the API, but might still have some remnants in the database + // so we just default to a blank color, yolo color = DiscordUtils.Gray; } - await using var conn = await _db.Obtain(); - - var guildSettings = guild != null ? await _repo.GetMemberGuild(conn, guild.Id, member.Id) : null; + var guildSettings = guild != null ? await _repo.GetMemberGuild(guild.Id, member.Id) : null; var guildDisplayName = guildSettings?.DisplayName; var avatar = guildSettings?.AvatarUrl ?? member.AvatarFor(ctx); - var groups = await _repo.GetMemberGroups(conn, member.Id) + var groups = await _repo.GetMemberGroups(member.Id) .Where(g => g.Visibility.CanAccess(ctx)) .OrderBy(g => g.Name, StringComparer.InvariantCultureIgnoreCase) .ToListAsync(); @@ -218,10 +215,8 @@ namespace PluralKit.Bot public async Task CreateGroupEmbed(Context ctx, PKSystem system, PKGroup target) { - await using var conn = await _db.Obtain(); - var pctx = ctx.LookupContextFor(system); - var memberCount = ctx.MatchPrivateFlag(pctx) ? await _repo.GetGroupMemberCount(conn, target.Id, PrivacyLevel.Public) : await _repo.GetGroupMemberCount(conn, target.Id); + var memberCount = ctx.MatchPrivateFlag(pctx) ? await _repo.GetGroupMemberCount(target.Id, PrivacyLevel.Public) : await _repo.GetGroupMemberCount(target.Id); var nameField = target.Name; if (system.Name != null) diff --git a/PluralKit.Bot/Services/LogChannelService.cs b/PluralKit.Bot/Services/LogChannelService.cs index 3037183e..52774a6f 100644 --- a/PluralKit.Bot/Services/LogChannelService.cs +++ b/PluralKit.Bot/Services/LogChannelService.cs @@ -43,13 +43,8 @@ namespace PluralKit.Bot var triggerChannel = _cache.GetChannel(proxiedMessage.Channel); - PKSystem system; - PKMember member; - await using (var conn = await _db.Obtain()) - { - system = await _repo.GetSystem(conn, ctx.SystemId.Value); - member = await _repo.GetMember(conn, proxiedMessage.Member); - } + var system = await _repo.GetSystem(ctx.SystemId.Value); + var member = await _repo.GetMember(proxiedMessage.Member); // Send embed! var embed = _embed.CreateLoggedMessageEmbed(trigger, hookMessage, system.Hid, member, triggerChannel.Name, oldContent); @@ -71,7 +66,7 @@ namespace PluralKit.Bot if (proxiedMessage.Guild != trigger.GuildId) { // we're editing a message from a different server, get log channel info from the database - var guild = await _db.Execute(c => _repo.GetGuild(c, proxiedMessage.Guild.Value)); + var guild = await _repo.GetGuild(proxiedMessage.Guild.Value); logChannelId = guild.LogChannel; isBlacklisted = guild.Blacklist.Any(x => x == logChannelId); } diff --git a/PluralKit.Core/Database/Database.cs b/PluralKit.Core/Database/Database.cs index 2ede1507..de98f468 100644 --- a/PluralKit.Core/Database/Database.cs +++ b/PluralKit.Core/Database/Database.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.Data; using System.IO; +using System.Runtime.CompilerServices; using System.Threading.Tasks; using App.Metrics; @@ -14,6 +15,9 @@ using Npgsql; using Serilog; +using SqlKata; +using SqlKata.Compilers; + namespace PluralKit.Core { internal class Database: IDatabase @@ -46,6 +50,8 @@ namespace PluralKit.Core }.ConnectionString; } + private static readonly PostgresCompiler _compiler = new(); + public static void InitStatic() { DefaultTypeMap.MatchNamesWithUnderscores = true; @@ -151,28 +157,83 @@ namespace PluralKit.Core public override T[] Parse(object value) => Array.ConvertAll((TInner[])value, v => _factory(v)); } - } - public static class DatabaseExt - { - public static async Task Execute(this IDatabase db, Func func) + public async Task Execute(Func func) { - await using var conn = await db.Obtain(); + await using var conn = await Obtain(); await func(conn); } - public static async Task Execute(this IDatabase db, Func> func) + public async Task Execute(Func> func) { - await using var conn = await db.Obtain(); + await using var conn = await Obtain(); return await func(conn); } - public static async IAsyncEnumerable Execute(this IDatabase db, Func> func) + public async IAsyncEnumerable Execute(Func> func) { - await using var conn = await db.Obtain(); + await using var conn = await Obtain(); await foreach (var val in func(conn)) yield return val; } + + public async Task ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "") + { + var query = _compiler.Compile(q); + using var conn = await Obtain(); + using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName))) + return await conn.ExecuteAsync(query.Sql + $" {extraSql}", query.NamedBindings); + } + + public async Task QueryFirst(Query q, string extraSql = "", [CallerMemberName] string queryName = "") + { + var query = _compiler.Compile(q); + using var conn = await Obtain(); + using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName))) + return await conn.QueryFirstOrDefaultAsync(query.Sql + $" {extraSql}", query.NamedBindings); + } + + public async Task QueryFirst(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = "") + { + if (conn == null) + return await QueryFirst(q, extraSql, queryName); + + var query = _compiler.Compile(q); + using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName))) + return await conn.QueryFirstOrDefaultAsync(query.Sql + $" {extraSql}", query.NamedBindings); + } + + public async Task> Query(Query q, [CallerMemberName] string queryName = "") + { + var query = _compiler.Compile(q); + using var conn = await Obtain(); + using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName))) + return await conn.QueryAsync(query.Sql, query.NamedBindings); + } + + public async IAsyncEnumerable QueryStream(Query q, [CallerMemberName] string queryName = "") + { + var query = _compiler.Compile(q); + using var conn = await Obtain(); + using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName))) + await foreach (var val in conn.QueryStreamAsync(query.Sql, query.NamedBindings)) + yield return val; + } + + // the procedures (message_context and proxy_members, as of writing) have their own metrics tracking elsewhere + // still, including them here for consistency + + public async Task QuerySingleProcedure(string queryName, object param) + { + using var conn = await Obtain(); + return await conn.QueryFirstAsync(queryName, param, commandType: CommandType.StoredProcedure); + } + + public async Task> QueryProcedure(string queryName, object param) + { + using var conn = await Obtain(); + return await conn.QueryAsync(queryName, param, commandType: CommandType.StoredProcedure); + } } } \ No newline at end of file diff --git a/PluralKit.Core/Database/IDatabase.cs b/PluralKit.Core/Database/IDatabase.cs index f6789e94..6c48d6be 100644 --- a/PluralKit.Core/Database/IDatabase.cs +++ b/PluralKit.Core/Database/IDatabase.cs @@ -1,10 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; using System.Threading.Tasks; +using SqlKata; + namespace PluralKit.Core { public interface IDatabase { Task ApplyMigrations(); Task Obtain(); + Task Execute(Func func); + Task Execute(Func> func); + IAsyncEnumerable Execute(Func> func); + Task ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = ""); + Task QueryFirst(Query q, string extraSql = "", [CallerMemberName] string queryName = ""); + Task QueryFirst(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = ""); + Task> Query(Query q, [CallerMemberName] string queryName = ""); + IAsyncEnumerable QueryStream(Query q, [CallerMemberName] string queryName = ""); + Task QuerySingleProcedure(string queryName, object param); + Task> QueryProcedure(string queryName, object param); } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Account.cs b/PluralKit.Core/Database/Repository/ModelRepository.Account.cs index 98682b0b..a0c606dc 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Account.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Account.cs @@ -1,21 +1,16 @@ -using System.Collections.Generic; -using System.Data; using System.Threading.Tasks; -using Dapper; +using SqlKata; namespace PluralKit.Core { public partial class ModelRepository { - public async Task UpdateAccount(IPKConnection conn, ulong id, AccountPatch patch) + public async Task UpdateAccount(ulong id, AccountPatch patch) { _logger.Information("Updated account {accountId}: {@AccountPatch}", id, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("accounts", "uid = @uid")) - .WithConstant("uid", id) - .Build(); - await conn.ExecuteAsync(query, pms); + var query = patch.Apply(new Query("accounts").Where("uid", id)); + await _db.ExecuteQuery(query); } - } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.CommandMessage.cs b/PluralKit.Core/Database/Repository/ModelRepository.CommandMessage.cs index 91344561..8e87027e 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.CommandMessage.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.CommandMessage.cs @@ -1,22 +1,33 @@ using System.Threading.Tasks; -using Dapper; +using SqlKata; namespace PluralKit.Core { public partial class ModelRepository { - public Task SaveCommandMessage(IPKConnection conn, ulong messageId, ulong channelId, ulong authorId) => - conn.QueryAsync("insert into command_messages (message_id, channel_id, author_id) values (@Message, @Channel, @Author)", - new { Message = messageId, Channel = channelId, Author = authorId }); + public Task SaveCommandMessage(ulong messageId, ulong channelId, ulong authorId) + { + var query = new Query("command_messages").AsInsert(new + { + message_id = messageId, + channel_id = channelId, + author_id = authorId, + }); + return _db.ExecuteQuery(query); + } - public Task GetCommandMessage(IPKConnection conn, ulong messageId) => - conn.QuerySingleOrDefaultAsync("select * from command_messages where message_id = @Message", - new { Message = messageId }); + public Task GetCommandMessage(ulong messageId) + { + var query = new Query("command_messages").Where("message_id", messageId); + return _db.QueryFirst(query); + } - public Task DeleteCommandMessagesBefore(IPKConnection conn, ulong messageIdThreshold) => - conn.ExecuteAsync("delete from command_messages where message_id < @Threshold", - new { Threshold = messageIdThreshold }); + public Task DeleteCommandMessagesBefore(ulong messageIdThreshold) + { + var query = new Query("command_messages").AsDelete().Where("message_id", "<", messageIdThreshold); + return _db.QueryFirst(query); + } } public class CommandMessage diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Context.cs b/PluralKit.Core/Database/Repository/ModelRepository.Context.cs index 589ce2ec..442ba422 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Context.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Context.cs @@ -1,25 +1,23 @@ using System.Collections.Generic; -using System.Data; using System.Threading.Tasks; -using Dapper; - namespace PluralKit.Core { public partial class ModelRepository { - public Task GetMessageContext(IPKConnection conn, ulong account, ulong guild, ulong channel) - { - return conn.QueryFirstAsync("message_context", - new { account_id = account, guild_id = guild, channel_id = channel }, - commandType: CommandType.StoredProcedure); - } + public Task GetMessageContext(ulong account, ulong guild, ulong channel) + => _db.QuerySingleProcedure("message_context", new + { + account_id = account, + guild_id = guild, + channel_id = channel + }); - public Task> GetProxyMembers(IPKConnection conn, ulong account, ulong guild) - { - return conn.QueryAsync("proxy_members", - new { account_id = account, guild_id = guild }, - commandType: CommandType.StoredProcedure); - } + public Task> GetProxyMembers(ulong account, ulong guild) + => _db.QueryProcedure("proxy_members", new + { + account_id = account, + guild_id = guild + }); } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Group.cs b/PluralKit.Core/Database/Repository/ModelRepository.Group.cs index 74e96358..ddf1f26d 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Group.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Group.cs @@ -1,89 +1,77 @@ #nullable enable using System; -using System.Collections.Generic; -using System.Data; -using System.Linq; -using System.Text; using System.Threading.Tasks; -using Dapper; +using SqlKata; namespace PluralKit.Core { public partial class ModelRepository { - public Task GetGroupByName(IPKConnection conn, SystemId system, string name) => - conn.QueryFirstOrDefaultAsync("select * from groups where system = @System and lower(Name) = lower(@Name)", new { System = system, Name = name }); - - public Task GetGroupByDisplayName(IPKConnection conn, SystemId system, string display_name) => - conn.QueryFirstOrDefaultAsync("select * from groups where system = @System and lower(display_name) = lower(@Name)", new { System = system, Name = display_name }); - - public Task GetGroupByHid(IPKConnection conn, string hid, SystemId? system = null) - => conn.QueryFirstOrDefaultAsync( - "select * from groups where hid = @hid" + (system != null ? " and system = @System" : ""), - new { hid = hid.ToLowerInvariant(), System = system } - ); - - public Task GetGroupByGuid(IPKConnection conn, Guid guid) => - conn.QueryFirstOrDefaultAsync("select * from groups where uuid = @Uuid", new { Uuid = guid }); - - public Task GetGroupMemberCount(IPKConnection conn, GroupId id, PrivacyLevel? privacyFilter = null) + public Task GetGroupByName(SystemId system, string name) { - var query = new StringBuilder("select count(*) from group_members"); - if (privacyFilter != null) - query.Append(" inner join members on group_members.member_id = members.id"); - query.Append(" where group_members.group_id = @Id"); - if (privacyFilter != null) - query.Append(" and members.member_visibility = @PrivacyFilter"); - return conn.QuerySingleOrDefaultAsync(query.ToString(), new { Id = id, PrivacyFilter = privacyFilter }); + var query = new Query("groups").Where("system", system).WhereRaw("lower(name) = lower(?)", name.ToLower()); + return _db.QueryFirst(query); } - public async Task CreateGroup(IPKConnection conn, SystemId system, string name, IDbTransaction? transaction = null) + public Task GetGroupByDisplayName(SystemId system, string display_name) { - var group = await conn.QueryFirstAsync( - "insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *", - new { System = system, Name = name }, transaction); + var query = new Query("groups").Where("system", system).WhereRaw("lower(display_name) = lower(?)", display_name.ToLower()); + return _db.QueryFirst(query); + } + + public Task GetGroupByHid(string hid, SystemId? system = null) + { + var query = new Query("groups").Where("hid", hid.ToLower()); + if (system != null) + query = query.Where("system", system); + return _db.QueryFirst(query); + } + + public Task GetGroupByGuid(Guid uuid) + { + var query = new Query("groups").Where("uuid", uuid); + return _db.QueryFirst(query); + } + + public Task GetGroupMemberCount(GroupId id, PrivacyLevel? privacyFilter = null) + { + var query = new Query("group_members") + .SelectRaw("count(*)") + .Where("group_members.group_id", id); + + if (privacyFilter != null) query = query + .Join("members", "group_members.member_id", "members.id") + .Where("members.member_visibility", privacyFilter); + + return _db.QueryFirst(query); + } + + public async Task CreateGroup(SystemId system, string name, IPKConnection? conn = null) + { + var query = new Query("groups").AsInsert(new + { + hid = new UnsafeLiteral("find_free_group_hid()"), + system = system, + name = name + }); + var group = await _db.QueryFirst(conn, query, extraSql: "returning *"); _logger.Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name); return group; } - public Task UpdateGroup(IPKConnection conn, GroupId id, GroupPatch patch, IDbTransaction? transaction = null) + public Task UpdateGroup(GroupId id, GroupPatch patch, IPKConnection? conn = null) { _logger.Information("Updated {GroupId}: {@GroupPatch}", id, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id")) - .WithConstant("id", id) - .Build("returning *"); - return conn.QueryFirstAsync(query, pms, transaction); + var query = patch.Apply(new Query("groups").Where("id", id)); + return _db.QueryFirst(conn, query, extraSql: "returning *"); } - public Task DeleteGroup(IPKConnection conn, GroupId group) + public Task DeleteGroup(GroupId group) { _logger.Information("Deleted {GroupId}", group); - return conn.ExecuteAsync("delete from groups where id = @Id", new { Id = @group }); - } - - public async Task AddMembersToGroup(IPKConnection conn, GroupId group, - IReadOnlyCollection members) - { - await using var w = - conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)"); - foreach (var member in members) - { - await w.StartRowAsync(); - await w.WriteAsync(group.Value); - await w.WriteAsync(member.Value); - } - - await w.CompleteAsync(); - _logger.Information("Added members to {GroupId}: {MemberIds}", group, members); - } - - public Task RemoveMembersFromGroup(IPKConnection conn, GroupId group, - IReadOnlyCollection members) - { - _logger.Information("Removed members from {GroupId}: {MemberIds}", group, members); - return conn.ExecuteAsync("delete from group_members where group_id = @Group and member_id = any(@Members)", - new { Group = @group, Members = members.ToArray() }); + var query = new Query("groups").AsDelete().Where("id", group); + return _db.ExecuteQuery(query); } } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.GroupMember.cs b/PluralKit.Core/Database/Repository/ModelRepository.GroupMember.cs index a3b82149..992b0ed5 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.GroupMember.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.GroupMember.cs @@ -1,21 +1,25 @@ using System.Collections.Generic; -using System.Linq; using System.Threading.Tasks; -using Dapper; +using SqlKata; namespace PluralKit.Core { public partial class ModelRepository { - public IAsyncEnumerable GetMemberGroups(IPKConnection conn, MemberId id) => - conn.QueryStreamAsync( - "select groups.* from group_members inner join groups on group_members.group_id = groups.id where group_members.member_id = @Id", - new { Id = id }); - - - public async Task AddGroupsToMember(IPKConnection conn, MemberId member, IReadOnlyCollection groups) + public IAsyncEnumerable GetMemberGroups(MemberId id) { + var query = new Query("group_members") + .Select("groups.*") + .Join("groups", "group_members.group_id", "groups.id") + .Where("group_members.member_id", id); + return _db.QueryStream(query); + } + + // todo: add this to metrics tracking + public async Task AddGroupsToMember(MemberId member, IReadOnlyCollection groups) + { + await using var conn = await _db.Obtain(); await using var w = conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)"); foreach (var group in groups) @@ -29,12 +33,39 @@ namespace PluralKit.Core _logger.Information("Added member {MemberId} to groups {GroupIds}", member, groups); } - public Task RemoveGroupsFromMember(IPKConnection conn, MemberId member, IReadOnlyCollection groups) + public Task RemoveGroupsFromMember(MemberId member, IReadOnlyCollection groups) { _logger.Information("Removed groups from {MemberId}: {GroupIds}", member, groups); - return conn.ExecuteAsync("delete from group_members where member_id = @Member and group_id = any(@Groups)", - new { Member = @member, Groups = groups.ToArray() }); + var query = new Query("group_members").AsDelete() + .Where("member_id", member) + .WhereIn("group_id", groups); + return _db.ExecuteQuery(query); } + // todo: add this to metrics tracking + public async Task AddMembersToGroup(GroupId group, IReadOnlyCollection members) + { + await using var conn = await _db.Obtain(); + await using var w = + conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)"); + foreach (var member in members) + { + await w.StartRowAsync(); + await w.WriteAsync(group.Value); + await w.WriteAsync(member.Value); + } + + await w.CompleteAsync(); + _logger.Information("Added members to {GroupId}: {MemberIds}", group, members); + } + + public Task RemoveMembersFromGroup(GroupId group, IReadOnlyCollection members) + { + _logger.Information("Removed members from {GroupId}: {MemberIds}", group, members); + var query = new Query("group_members").AsDelete() + .Where("group_id", group) + .WhereIn("member_id", members); + return _db.ExecuteQuery(query); + } } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Guild.cs b/PluralKit.Core/Database/Repository/ModelRepository.Guild.cs index 43027049..538fce39 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Guild.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Guild.cs @@ -1,53 +1,63 @@ using System.Threading.Tasks; -using Dapper; +using SqlKata; namespace PluralKit.Core { public partial class ModelRepository { - public Task UpsertGuild(IPKConnection conn, ulong guild, GuildPatch patch) + public Task GetGuild(ulong guild) + { + var query = new Query("servers").AsInsert(new { id = guild }); + // sqlkata doesn't support postgres on conflict, so we just hack it on here + return _db.QueryFirst(query, "on conflict (id) do update set id = @$1 returning *"); + } + + public Task UpdateGuild(ulong guild, GuildPatch patch) { _logger.Information("Updated guild {GuildId}: {@GuildPatch}", guild, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("servers", "id")) - .WithConstant("id", guild) - .Build(); - return conn.ExecuteAsync(query, pms); + var query = patch.Apply(new Query("servers").Where("id", guild)); + return _db.ExecuteQuery(query); } - public Task UpsertSystemGuild(IPKConnection conn, SystemId system, ulong guild, - SystemGuildPatch patch) + + public Task GetSystemGuild(ulong guild, SystemId system) + { + var query = new Query("system_guild").AsInsert(new + { + guild = guild, + system = system + }); + return _db.QueryFirst(query, + extraSql: "on conflict (guild, system) do update set guild = $1, system = $2 returning *" + ); + } + + public Task UpdateSystemGuild(SystemId system, ulong guild, SystemGuildPatch patch) { _logger.Information("Updated {SystemId} in guild {GuildId}: {@SystemGuildPatch}", system, guild, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("system_guild", "system, guild")) - .WithConstant("system", system) - .WithConstant("guild", guild) - .Build(); - return conn.ExecuteAsync(query, pms); + var query = patch.Apply(new Query("system_guild").Where("system", system).Where("guild", guild)); + return _db.ExecuteQuery(query); } - public Task UpsertMemberGuild(IPKConnection conn, MemberId member, ulong guild, - MemberGuildPatch patch) + + public Task GetMemberGuild(ulong guild, MemberId member) + { + var query = new Query("member_guild").AsInsert(new + { + guild = guild, + member = member + }); + return _db.QueryFirst(query, + extraSql: "on conflict (guild, member) do update set guild = $1, member = $2 returning *" + ); + } + + public Task UpdateMemberGuild(MemberId member, ulong guild, MemberGuildPatch patch) { _logger.Information("Updated {MemberId} in guild {GuildId}: {@MemberGuildPatch}", member, guild, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("member_guild", "member, guild")) - .WithConstant("member", member) - .WithConstant("guild", guild) - .Build(); - return conn.ExecuteAsync(query, pms); + var query = patch.Apply(new Query("member_guild").Where("member", member).Where("guild", guild)); + return _db.ExecuteQuery(query); } - - public Task GetGuild(IPKConnection conn, ulong guild) => - conn.QueryFirstAsync("insert into servers (id) values (@guild) on conflict (id) do update set id = @guild returning *", new { guild }); - - public Task GetSystemGuild(IPKConnection conn, ulong guild, SystemId system) => - conn.QueryFirstAsync( - "insert into system_guild (guild, system) values (@guild, @system) on conflict (guild, system) do update set guild = @guild, system = @system returning *", - new { guild, system }); - - public Task GetMemberGuild(IPKConnection conn, ulong guild, MemberId member) => - conn.QueryFirstAsync( - "insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *", - new { guild, member }); } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Member.cs b/PluralKit.Core/Database/Repository/ModelRepository.Member.cs index 74d680a0..646a5235 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Member.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Member.cs @@ -1,55 +1,77 @@ #nullable enable using System; -using System.Data; using System.Threading.Tasks; -using Dapper; +using SqlKata; namespace PluralKit.Core { public partial class ModelRepository { - public Task GetMember(IPKConnection conn, MemberId id) => - conn.QueryFirstOrDefaultAsync("select * from members where id = @id", new { id }); - - public Task GetMemberByHid(IPKConnection conn, string hid, SystemId? system = null) - => conn.QuerySingleOrDefaultAsync( - "select * from members where hid = @Hid" + (system != null ? " and system = @System" : ""), - new { Hid = hid.ToLower(), System = system } - ); - - public Task GetMemberByGuid(IPKConnection conn, Guid guid) => - conn.QuerySingleOrDefaultAsync("select * from members where uuid = @Uuid", new { Uuid = guid }); - - public Task GetMemberByName(IPKConnection conn, SystemId system, string name) => - conn.QueryFirstOrDefaultAsync("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system }); - - public Task GetMemberByDisplayName(IPKConnection conn, SystemId system, string name) => - conn.QueryFirstOrDefaultAsync("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system }); - - public async Task CreateMember(IPKConnection conn, SystemId id, string memberName, IDbTransaction? transaction = null) + public Task GetMember(MemberId id) { - var member = await conn.QueryFirstAsync( - "insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *", - new { SystemId = id, Name = memberName }, transaction); + var query = new Query("members").Where("id", id); + return _db.QueryFirst(query); + } + + public Task GetMemberByHid(string hid, SystemId? system = null) + { + var query = new Query("members").Where("hid", hid.ToLower()); + if (system != null) + query = query.Where("system", system); + return _db.QueryFirst(query); + } + + public Task GetMemberByGuid(Guid uuid) + { + var query = new Query("members").Where("uuid", uuid); + return _db.QueryFirst(query); + } + + public Task GetMemberByName(SystemId system, string name) + { + var query = new Query("members").WhereRaw( + "lower(name) = lower(?)", + name.ToLower() + ).Where("system", system); + return _db.QueryFirst(query); + } + + public Task GetMemberByDisplayName(SystemId system, string name) + { + var query = new Query("members").WhereRaw( + "lower(display_name) = lower(?)", + name.ToLower() + ).Where("system", system); + return _db.QueryFirst(query); + } + + public async Task CreateMember(SystemId systemId, string memberName, IPKConnection? conn = null) + { + var query = new Query("members").AsInsert(new + { + hid = new UnsafeLiteral("find_free_member_hid()"), + system = systemId, + name = memberName + }); + var member = await _db.QueryFirst(conn, query, "returning *"); _logger.Information("Created {MemberId} in {SystemId}: {MemberName}", - member.Id, id, memberName); + member.Id, systemId, memberName); return member; } - public Task UpdateMember(IPKConnection conn, MemberId id, MemberPatch patch, IDbTransaction? transaction = null) + public Task UpdateMember(MemberId id, MemberPatch patch, IPKConnection? conn = null) { _logger.Information("Updated {MemberId}: {@MemberPatch}", id, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id")) - .WithConstant("id", id) - .Build("returning *"); - return conn.QueryFirstAsync(query, pms, transaction); + var query = patch.Apply(new Query("members").Where("id", id)); + return _db.QueryFirst(conn, query); } - public Task DeleteMember(IPKConnection conn, MemberId id) + public Task DeleteMember(MemberId id) { _logger.Information("Deleted {MemberId}", id); - return conn.ExecuteAsync("delete from members where id = @Id", new { Id = id }); + var query = new Query("members").AsDelete().Where("id", id); + return _db.ExecuteQuery(query); } } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Message.cs b/PluralKit.Core/Database/Repository/ModelRepository.Message.cs index 38780bdc..715f4264 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Message.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Message.cs @@ -4,17 +4,30 @@ using System.Threading.Tasks; using Dapper; +using SqlKata; + namespace PluralKit.Core { public partial class ModelRepository { - public async Task AddMessage(IPKConnection conn, PKMessage msg) + public Task AddMessage(PKMessage msg) { - // "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before - await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@Mid, @Guild, @Channel, @Member, @Sender, @OriginalMid) on conflict do nothing", msg); + var query = new Query("messages").AsInsert(new + { + mid = msg.Mid, + guild = msg.Guild, + channel = msg.Channel, + member = msg.Member, + sender = msg.Sender, + original_mid = msg.OriginalMid, + }); _logger.Debug("Stored message {@StoredMessage} in channel {Channel}", msg, msg.Channel); + + // "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before + return _db.ExecuteQuery(query, extraSql: "on conflict do nothing"); } + // todo: add a Mapper to QuerySingle and move this to SqlKata public async Task GetMessage(IPKConnection conn, ulong id) { FullMessage Mapper(PKMessage msg, PKMember member, PKSystem system) => @@ -26,34 +39,36 @@ namespace PluralKit.Core return result.FirstOrDefault(); } - public async Task DeleteMessage(IPKConnection conn, ulong id) + public async Task DeleteMessage(ulong id) { - var rowCount = await conn.ExecuteAsync("delete from messages where mid = @Id", new { Id = id }); + var query = new Query("messages").AsDelete().Where("mid", id); + var rowCount = await _db.ExecuteQuery(query); if (rowCount > 0) _logger.Information("Deleted message {MessageId} from database", id); } - public async Task DeleteMessagesBulk(IPKConnection conn, IReadOnlyCollection ids) + public async Task DeleteMessagesBulk(IReadOnlyCollection ids) { // Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong // Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway) - var rowCount = await conn.ExecuteAsync("delete from messages where mid = any(@Ids)", - new { Ids = ids.Select(id => (long)id).ToArray() }); + var query = new Query("messages").AsDelete().WhereIn("mid", ids.Select(id => (long)id).ToArray()); + var rowCount = await _db.ExecuteQuery(query); if (rowCount > 0) _logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount, ids); } - public async Task GetLastMessage(IPKConnection conn, ulong guildId, ulong channelId, ulong accountId) + public Task GetLastMessage(ulong guildId, ulong channelId, ulong accountId) { // Want to index scan on the (guild, sender, mid) index so need the additional constraint - return await conn.QuerySingleOrDefaultAsync( - "select * from messages where guild = @Guild and channel = @Channel and sender = @Sender order by mid desc limit 1", new - { - Guild = guildId, - Channel = channelId, - Sender = accountId - }); + var query = new Query("messages") + .Where("guild", guildId) + .Where("channel", channelId) + .Where("sender", accountId) + .OrderByDesc("mid") + .Limit(1); + + return _db.QueryFirst(query); } } diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Switch.cs b/PluralKit.Core/Database/Repository/ModelRepository.Switch.cs index fb4a1441..0a8a5533 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Switch.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Switch.cs @@ -9,8 +9,11 @@ using NodaTime; using NpgsqlTypes; +using SqlKata; + namespace PluralKit.Core { + // todo: move the rest of the queries in here to SqlKata, if possible public partial class ModelRepository { public async Task AddSwitch(IPKConnection conn, SystemId system, IReadOnlyCollection members) @@ -69,40 +72,44 @@ namespace PluralKit.Core _logger.Information("Updated {SwitchId} members: {Members}", switchId, members); } - public async Task MoveSwitch(IPKConnection conn, SwitchId id, Instant time) + public Task MoveSwitch(SwitchId id, Instant time) { - await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id", - new { Time = time, Id = id }); - _logger.Information("Updated {SwitchId} timestamp: {SwitchTimestamp}", id, time); + var query = new Query("switches").AsUpdate(new { timestamp = time }).Where("id", id); + return _db.ExecuteQuery(query); } - public async Task DeleteSwitch(IPKConnection conn, SwitchId id) + public Task DeleteSwitch(SwitchId id) { - await conn.ExecuteAsync("delete from switches where id = @Id", new { Id = id }); _logger.Information("Deleted {Switch}", id); + var query = new Query("switches").AsDelete().Where("id", id); + return _db.ExecuteQuery(query); } - public async Task DeleteAllSwitches(IPKConnection conn, SystemId system) + public Task DeleteAllSwitches(SystemId system) { - await conn.ExecuteAsync("delete from switches where system = @Id", new { Id = system }); _logger.Information("Deleted all switches in {SystemId}", system); + var query = new Query("switches").AsDelete().Where("system", system); + return _db.ExecuteQuery(query); } - public IAsyncEnumerable GetSwitches(IPKConnection conn, SystemId system) + public IAsyncEnumerable GetSwitches(SystemId system) { // TODO: refactor the PKSwitch data structure to somehow include a hydrated member list - return conn.QueryStreamAsync( - "select * from switches where system = @System order by timestamp desc", - new { System = system }); + var query = new Query("switches").Where("system", system).OrderByDesc("timestamp"); + return _db.QueryStream(query); } - public Task GetSwitchByUuid(IPKConnection conn, Guid uuid) => - conn.QuerySingleOrDefaultAsync("select * from switches where uuid = @Uuid", new { Uuid = uuid }); - - public async Task GetSwitchCount(IPKConnection conn, SystemId system) + public Task GetSwitchByUuid(Guid uuid) { - return await conn.QuerySingleAsync("select count(*) from switches where system = @Id", new { Id = system }); + var query = new Query("switches").Where("uuid", uuid); + return _db.QueryFirst(query); + } + + public Task GetSwitchCount(SystemId system) + { + var query = new Query("switches").SelectRaw("count(*)").Where("system", system); + return _db.QueryFirst(query); } public async IAsyncEnumerable GetSwitchMembersList(IPKConnection conn, @@ -149,9 +156,11 @@ namespace PluralKit.Core new { Switch = sw }); } - public async Task GetLatestSwitch(IPKConnection conn, SystemId system) => - // TODO: should query directly for perf - await GetSwitches(conn, system).FirstOrDefaultAsync(); + public Task GetLatestSwitch(SystemId system) + { + var query = new Query("switches").Where("system", system).OrderByDesc("timestamp").Limit(1); + return _db.QueryFirst(query); + } public async Task> GetPeriodFronters(IPKConnection conn, SystemId system, GroupId? group, Instant periodStart, diff --git a/PluralKit.Core/Database/Repository/ModelRepository.System.cs b/PluralKit.Core/Database/Repository/ModelRepository.System.cs index f03d03c4..a9a48e28 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.System.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.System.cs @@ -1,93 +1,120 @@ #nullable enable using System; using System.Collections.Generic; -using System.Text; using System.Threading.Tasks; -using Dapper; +using SqlKata; namespace PluralKit.Core { public partial class ModelRepository { - public Task GetSystem(IPKConnection conn, SystemId id) => - conn.QueryFirstOrDefaultAsync("select * from systems where id = @id", new { id }); - - public Task GetSystemByGuid(IPKConnection conn, Guid id) => - conn.QueryFirstOrDefaultAsync("select * from systems where uuid = @id", new { id }); - - public Task GetSystemByAccount(IPKConnection conn, ulong accountId) => - conn.QuerySingleOrDefaultAsync( - "select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id", - new { Id = accountId }); - - public Task GetSystemByHid(IPKConnection conn, string hid) => - conn.QuerySingleOrDefaultAsync("select * from systems where systems.hid = @Hid", - new { Hid = hid.ToLower() }); - - public Task> GetSystemAccounts(IPKConnection conn, SystemId system) => - conn.QueryAsync("select uid from accounts where system = @Id", new { Id = system }); - - public IAsyncEnumerable GetSystemMembers(IPKConnection conn, SystemId system) => - conn.QueryStreamAsync("select * from members where system = @SystemID", new { SystemID = system }); - - public IAsyncEnumerable GetSystemGroups(IPKConnection conn, SystemId system) => - conn.QueryStreamAsync("select * from groups where system = @System", new { System = system }); - - public Task GetSystemMemberCount(IPKConnection conn, SystemId id, PrivacyLevel? privacyFilter = null) + public Task GetSystem(SystemId id) { - var query = new StringBuilder("select count(*) from members where system = @Id"); - if (privacyFilter != null) - query.Append($" and member_visibility = {(int)privacyFilter.Value}"); - return conn.QuerySingleAsync(query.ToString(), new { Id = id }); + var query = new Query("systems").Where("id", id); + return _db.QueryFirst(query); } - public Task GetSystemGroupCount(IPKConnection conn, SystemId id, PrivacyLevel? privacyFilter = null) + public Task GetSystemByGuid(Guid id) { - var query = new StringBuilder("select count(*) from groups where system = @Id"); - if (privacyFilter != null) - query.Append($" and visibility = {(int)privacyFilter.Value}"); - return conn.QuerySingleAsync(query.ToString(), new { Id = id }); + var query = new Query("systems").Where("uuid", id); + return _db.QueryFirst(query); } - public async Task CreateSystem(IPKConnection conn, string? systemName = null, IPKTransaction? tx = null) + + public Task GetSystemByAccount(ulong accountId) { - var system = await conn.QuerySingleAsync( - "insert into systems (hid, name) values (find_free_system_hid(), @Name) returning *", - new { Name = systemName }, - transaction: tx); + var query = new Query("accounts").Select("systems.*").LeftJoin("systems", "systems.id", "accounts.system", "=").Where("uid", accountId); + return _db.QueryFirst(query); + } + + public Task GetSystemByHid(string hid) + { + var query = new Query("systems").Where("hid", hid.ToLower()); + return _db.QueryFirst(query); + } + + public Task> GetSystemAccounts(SystemId system) + { + var query = new Query("accounts").Select("uid").Where("system", system); + return _db.Query(query); + } + + public IAsyncEnumerable GetSystemMembers(SystemId system) + { + var query = new Query("members").Where("system", system); + return _db.QueryStream(query); + } + + public IAsyncEnumerable GetSystemGroups(SystemId system) + { + var query = new Query("groups").Where("system", system); + return _db.QueryStream(query); + } + + public Task GetSystemMemberCount(SystemId system, PrivacyLevel? privacyFilter = null) + { + var query = new Query("members").SelectRaw("count(*)").Where("system", system); + if (privacyFilter != null) + query.Where("member_visibility", (int)privacyFilter.Value); + + return _db.QueryFirst(query); + } + + public Task GetSystemGroupCount(SystemId system, PrivacyLevel? privacyFilter = null) + { + var query = new Query("groups").SelectRaw("count(*)").Where("system", system); + if (privacyFilter != null) + query.Where("visibility", (int)privacyFilter.Value); + + return _db.QueryFirst(query); + } + + public async Task CreateSystem(string? systemName = null, IPKConnection? conn = null) + { + var query = new Query("systems").AsInsert(new + { + hid = new UnsafeLiteral("find_free_system_hid()"), + name = systemName + }); + var system = await _db.QueryFirst(conn, query, extraSql: "returning *"); _logger.Information("Created {SystemId}", system.Id); return system; } - public Task UpdateSystem(IPKConnection conn, SystemId id, SystemPatch patch, IPKTransaction? tx = null) + public Task UpdateSystem(SystemId id, SystemPatch patch, IPKConnection? conn = null) { _logger.Information("Updated {SystemId}: {@SystemPatch}", id, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("systems", "id = @id")) - .WithConstant("id", id) - .Build("returning *"); - return conn.QueryFirstAsync(query, pms, transaction: tx); + var query = patch.Apply(new Query("systems").Where("id", id)); + return _db.QueryFirst(conn, query, extraSql: "returning *"); } - public async Task AddAccount(IPKConnection conn, SystemId system, ulong accountId) + public Task AddAccount(SystemId system, ulong accountId) { // We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent // This is used in import/export, although the pk;link command checks for this case beforehand - await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing", - new { Id = accountId, SystemId = system }); + + var query = new Query("accounts").AsInsert(new + { + system = system, + uid = accountId, + }); + _logger.Information("Linked account {UserId} to {SystemId}", accountId, system); + return _db.ExecuteQuery(query, extraSql: "on conflict do nothing"); } - public async Task RemoveAccount(IPKConnection conn, SystemId system, ulong accountId) + public async Task RemoveAccount(SystemId system, ulong accountId) { - await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", - new { Id = accountId, SystemId = system }); + var query = new Query("accounts").AsDelete().Where("uid", accountId).Where("system", system); + await _db.ExecuteQuery(query); _logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system); } - public Task DeleteSystem(IPKConnection conn, SystemId id) + public Task DeleteSystem(SystemId id) { + var query = new Query("systems").AsDelete().Where("id", id); _logger.Information("Deleted {SystemId}", id); - return conn.ExecuteAsync("delete from systems where id = @Id", new { Id = id }); + return _db.ExecuteQuery(query); } } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.cs b/PluralKit.Core/Database/Repository/ModelRepository.cs index 6a2a324e..a3a35574 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.cs @@ -5,10 +5,11 @@ namespace PluralKit.Core public partial class ModelRepository { private readonly ILogger _logger; - - public ModelRepository(ILogger logger) + private readonly IDatabase _db; + public ModelRepository(ILogger logger, IDatabase db) { _logger = logger.ForContext(); + _db = db; } } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Utils/QueryPatchWrapper.cs b/PluralKit.Core/Database/Utils/QueryPatchWrapper.cs new file mode 100644 index 00000000..d76f60b1 --- /dev/null +++ b/PluralKit.Core/Database/Utils/QueryPatchWrapper.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; + +using SqlKata; + +namespace PluralKit.Core +{ + internal class QueryPatchWrapper + { + private Dictionary _dict = new(); + + public QueryPatchWrapper With(string columnName, Partial partialValue) + { + if (partialValue.IsPresent) + _dict.Add(columnName, partialValue); + + return this; + } + + public Query ToQuery(Query q) => q.AsUpdate(_dict); + } + + internal static class SqlKataExtensions + { + internal static Query ApplyPatch(this Query query, Func func) + => func(new QueryPatchWrapper()).ToQuery(query); + } +} \ No newline at end of file diff --git a/PluralKit.Core/Models/Patch/AccountPatch.cs b/PluralKit.Core/Models/Patch/AccountPatch.cs index c4059d5e..9f2480d0 100644 --- a/PluralKit.Core/Models/Patch/AccountPatch.cs +++ b/PluralKit.Core/Models/Patch/AccountPatch.cs @@ -1,10 +1,13 @@ +using SqlKata; + namespace PluralKit.Core { public class AccountPatch: PatchObject { public Partial AllowAutoproxy { get; set; } - public override UpdateQueryBuilder Apply(UpdateQueryBuilder b) => b - .With("allow_autoproxy", AllowAutoproxy); + public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper + .With("allow_autoproxy", AllowAutoproxy) + ); } } \ No newline at end of file diff --git a/PluralKit.Core/Models/Patch/GroupPatch.cs b/PluralKit.Core/Models/Patch/GroupPatch.cs index ea5b0a35..75db6773 100644 --- a/PluralKit.Core/Models/Patch/GroupPatch.cs +++ b/PluralKit.Core/Models/Patch/GroupPatch.cs @@ -3,6 +3,8 @@ using System.Text.RegularExpressions; using Newtonsoft.Json.Linq; +using SqlKata; + namespace PluralKit.Core { public class GroupPatch: PatchObject @@ -20,7 +22,7 @@ namespace PluralKit.Core public Partial ListPrivacy { get; set; } public Partial Visibility { get; set; } - public override UpdateQueryBuilder Apply(UpdateQueryBuilder b) => b + public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper .With("name", Name) .With("hid", Hid) .With("display_name", DisplayName) @@ -31,7 +33,8 @@ namespace PluralKit.Core .With("description_privacy", DescriptionPrivacy) .With("icon_privacy", IconPrivacy) .With("list_privacy", ListPrivacy) - .With("visibility", Visibility); + .With("visibility", Visibility) + ); public new void AssertIsValid() { diff --git a/PluralKit.Core/Models/Patch/GuildPatch.cs b/PluralKit.Core/Models/Patch/GuildPatch.cs index 721f889f..b971a41d 100644 --- a/PluralKit.Core/Models/Patch/GuildPatch.cs +++ b/PluralKit.Core/Models/Patch/GuildPatch.cs @@ -1,3 +1,5 @@ +using SqlKata; + namespace PluralKit.Core { public class GuildPatch: PatchObject @@ -7,10 +9,11 @@ namespace PluralKit.Core public Partial Blacklist { get; set; } public Partial LogCleanupEnabled { get; set; } - public override UpdateQueryBuilder Apply(UpdateQueryBuilder b) => b + public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper .With("log_channel", LogChannel) .With("log_blacklist", LogBlacklist) .With("blacklist", Blacklist) - .With("log_cleanup_enabled", LogCleanupEnabled); + .With("log_cleanup_enabled", LogCleanupEnabled) + ); } } \ No newline at end of file diff --git a/PluralKit.Core/Models/Patch/MemberGuildPatch.cs b/PluralKit.Core/Models/Patch/MemberGuildPatch.cs index 5daf7e49..5207fb7e 100644 --- a/PluralKit.Core/Models/Patch/MemberGuildPatch.cs +++ b/PluralKit.Core/Models/Patch/MemberGuildPatch.cs @@ -1,4 +1,7 @@ #nullable enable + +using SqlKata; + namespace PluralKit.Core { public class MemberGuildPatch: PatchObject @@ -6,8 +9,9 @@ namespace PluralKit.Core public Partial DisplayName { get; set; } public Partial AvatarUrl { get; set; } - public override UpdateQueryBuilder Apply(UpdateQueryBuilder b) => b + public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper .With("display_name", DisplayName) - .With("avatar_url", AvatarUrl); + .With("avatar_url", AvatarUrl) + ); } } \ No newline at end of file diff --git a/PluralKit.Core/Models/Patch/MemberPatch.cs b/PluralKit.Core/Models/Patch/MemberPatch.cs index ac2207cd..b69577cd 100644 --- a/PluralKit.Core/Models/Patch/MemberPatch.cs +++ b/PluralKit.Core/Models/Patch/MemberPatch.cs @@ -6,6 +6,8 @@ using NodaTime; using Newtonsoft.Json.Linq; +using SqlKata; + namespace PluralKit.Core { public class MemberPatch: PatchObject @@ -31,7 +33,7 @@ namespace PluralKit.Core public Partial AvatarPrivacy { get; set; } public Partial MetadataPrivacy { get; set; } - public override UpdateQueryBuilder Apply(UpdateQueryBuilder b) => b + public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper .With("name", Name) .With("hid", Hid) .With("display_name", DisplayName) @@ -51,7 +53,8 @@ namespace PluralKit.Core .With("pronoun_privacy", PronounPrivacy) .With("birthday_privacy", BirthdayPrivacy) .With("avatar_privacy", AvatarPrivacy) - .With("metadata_privacy", MetadataPrivacy); + .With("metadata_privacy", MetadataPrivacy) + ); public new void AssertIsValid() { diff --git a/PluralKit.Core/Models/Patch/PatchObject.cs b/PluralKit.Core/Models/Patch/PatchObject.cs index d7e626fd..f1ddc84a 100644 --- a/PluralKit.Core/Models/Patch/PatchObject.cs +++ b/PluralKit.Core/Models/Patch/PatchObject.cs @@ -1,11 +1,13 @@ using System; using System.Text.RegularExpressions; +using SqlKata; + namespace PluralKit.Core { public abstract class PatchObject { - public abstract UpdateQueryBuilder Apply(UpdateQueryBuilder b); + public abstract Query Apply(Query q); public void AssertIsValid() { } diff --git a/PluralKit.Core/Models/Patch/SystemGuildPatch.cs b/PluralKit.Core/Models/Patch/SystemGuildPatch.cs index e07e15e1..a5642f85 100644 --- a/PluralKit.Core/Models/Patch/SystemGuildPatch.cs +++ b/PluralKit.Core/Models/Patch/SystemGuildPatch.cs @@ -1,4 +1,7 @@ #nullable enable + +using SqlKata; + namespace PluralKit.Core { public class SystemGuildPatch: PatchObject @@ -9,11 +12,12 @@ namespace PluralKit.Core public Partial Tag { get; set; } public Partial TagEnabled { get; set; } - public override UpdateQueryBuilder Apply(UpdateQueryBuilder b) => b + public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper .With("proxy_enabled", ProxyEnabled) .With("autoproxy_mode", AutoproxyMode) .With("autoproxy_member", AutoproxyMember) .With("tag", Tag) - .With("tag_enabled", TagEnabled); + .With("tag_enabled", TagEnabled) + ); } } \ No newline at end of file diff --git a/PluralKit.Core/Models/Patch/SystemPatch.cs b/PluralKit.Core/Models/Patch/SystemPatch.cs index bf4129d4..a2ba3b3e 100644 --- a/PluralKit.Core/Models/Patch/SystemPatch.cs +++ b/PluralKit.Core/Models/Patch/SystemPatch.cs @@ -6,6 +6,8 @@ using Newtonsoft.Json.Linq; using NodaTime; +using SqlKata; + namespace PluralKit.Core { public class SystemPatch: PatchObject @@ -29,7 +31,7 @@ namespace PluralKit.Core public Partial MemberLimitOverride { get; set; } public Partial GroupLimitOverride { get; set; } - public override UpdateQueryBuilder Apply(UpdateQueryBuilder b) => b + public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper .With("name", Name) .With("hid", Hid) .With("description", Description) @@ -47,7 +49,8 @@ namespace PluralKit.Core .With("pings_enabled", PingsEnabled) .With("latch_timeout", LatchTimeout) .With("member_limit_override", MemberLimitOverride) - .With("group_limit_override", GroupLimitOverride); + .With("group_limit_override", GroupLimitOverride) + ); public new void AssertIsValid() { diff --git a/PluralKit.Core/PluralKit.Core.csproj b/PluralKit.Core/PluralKit.Core.csproj index b16e47a1..c7633f61 100644 --- a/PluralKit.Core/PluralKit.Core.csproj +++ b/PluralKit.Core/PluralKit.Core.csproj @@ -42,6 +42,8 @@ + + @@ -50,10 +52,7 @@ - + diff --git a/PluralKit.Core/Services/DataFileService.cs b/PluralKit.Core/Services/DataFileService.cs index 9126341b..b0c2e721 100644 --- a/PluralKit.Core/Services/DataFileService.cs +++ b/PluralKit.Core/Services/DataFileService.cs @@ -40,10 +40,10 @@ namespace PluralKit.Core o.Add("avatar_url", system.AvatarUrl); o.Add("timezone", system.UiTz); o.Add("created", system.Created.FormatExport()); - o.Add("accounts", new JArray((await _repo.GetSystemAccounts(conn, system.Id)).ToList())); - o.Add("members", new JArray((await _repo.GetSystemMembers(conn, system.Id).ToListAsync()).Select(m => m.ToJson(LookupContext.ByOwner)))); + o.Add("accounts", new JArray((await _repo.GetSystemAccounts(system.Id)).ToList())); + o.Add("members", new JArray((await _repo.GetSystemMembers(system.Id).ToListAsync()).Select(m => m.ToJson(LookupContext.ByOwner)))); - var groups = (await _repo.GetSystemGroups(conn, system.Id).ToListAsync()); + var groups = (await _repo.GetSystemGroups(system.Id).ToListAsync()); var j_groups = groups.Select(x => x.ToJson(LookupContext.ByOwner, isExport: true)).ToList(); if (groups.Count > 0) diff --git a/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs b/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs index da76294e..5ad0297e 100644 --- a/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs +++ b/PluralKit.Core/Utils/BulkImporter/BulkImporter.cs @@ -48,8 +48,8 @@ namespace PluralKit.Core if (system == null) { - system = await repo.CreateSystem(conn, null, tx); - await repo.AddAccount(conn, system.Id, userId); + system = await repo.CreateSystem(null, importer._conn); + await repo.AddAccount(system.Id, userId); importer._result.CreatedSystem = system.Hid; importer._system = system; } @@ -113,7 +113,7 @@ namespace PluralKit.Core private async Task AssertMemberLimitNotReached(int newMembers) { var memberLimit = _system.MemberLimitOverride ?? Limits.MaxMemberCount; - var existingMembers = await _repo.GetSystemMemberCount(_conn, _system.Id); + var existingMembers = await _repo.GetSystemMemberCount(_system.Id); if (existingMembers + newMembers > memberLimit) throw new ImportException($"Import would exceed the maximum number of members ({memberLimit})."); } @@ -121,7 +121,7 @@ namespace PluralKit.Core private async Task AssertGroupLimitNotReached(int newGroups) { var limit = _system.GroupLimitOverride ?? Limits.MaxGroupCount; - var existing = await _repo.GetSystemGroupCount(_conn, _system.Id); + var existing = await _repo.GetSystemGroupCount(_system.Id); if (existing + newGroups > limit) throw new ImportException($"Import would exceed the maximum number of groups ({limit})."); } diff --git a/PluralKit.Core/Utils/BulkImporter/PluralKitImport.cs b/PluralKit.Core/Utils/BulkImporter/PluralKitImport.cs index 94389102..1412d612 100644 --- a/PluralKit.Core/Utils/BulkImporter/PluralKitImport.cs +++ b/PluralKit.Core/Utils/BulkImporter/PluralKitImport.cs @@ -29,7 +29,7 @@ namespace PluralKit.Core throw new ImportException($"Field {e.Message} in export file is invalid."); } - await _repo.UpdateSystem(_conn, _system.Id, patch, _tx); + await _repo.UpdateSystem(_system.Id, patch, _conn); var members = importFile.Value("members"); var groups = importFile.Value("groups"); @@ -104,13 +104,13 @@ namespace PluralKit.Core if (isNewMember) { - var newMember = await _repo.CreateMember(_conn, _system.Id, patch.Name.Value, _tx); + var newMember = await _repo.CreateMember(_system.Id, patch.Name.Value, _conn); memberId = newMember.Id; } _knownMemberIdentifiers[id] = memberId.Value; - await _repo.UpdateMember(_conn, memberId.Value, patch, _tx); + await _repo.UpdateMember(memberId.Value, patch, _conn); } private async Task ImportGroup(JObject group) @@ -145,13 +145,13 @@ namespace PluralKit.Core if (isNewGroup) { - var newGroup = await _repo.CreateGroup(_conn, _system.Id, patch.Name.Value, _tx); + var newGroup = await _repo.CreateGroup(_system.Id, patch.Name.Value, _conn); groupId = newGroup.Id; } _knownGroupIdentifiers[id] = groupId.Value; - await _repo.UpdateGroup(_conn, groupId.Value, patch, _tx); + await _repo.UpdateGroup(groupId.Value, patch, _conn); var groupMembers = group.Value("members"); var currentGroupMembers = (await _conn.QueryAsync( diff --git a/PluralKit.Core/Utils/BulkImporter/TupperboxImport.cs b/PluralKit.Core/Utils/BulkImporter/TupperboxImport.cs index df5137e7..43d8eb3f 100644 --- a/PluralKit.Core/Utils/BulkImporter/TupperboxImport.cs +++ b/PluralKit.Core/Utils/BulkImporter/TupperboxImport.cs @@ -90,7 +90,7 @@ namespace PluralKit.Core var isNewMember = false; if (!_existingMemberNames.TryGetValue(name, out var memberId)) { - var newMember = await _repo.CreateMember(_conn, _system.Id, name, _tx); + var newMember = await _repo.CreateMember(_system.Id, name, _conn); memberId = newMember.Id; isNewMember = true; _result.Added++; @@ -114,7 +114,7 @@ namespace PluralKit.Core throw new ImportException($"Field {e.Message} in tupper {name} is invalid."); } - await _repo.UpdateMember(_conn, memberId, patch, _tx); + await _repo.UpdateMember(memberId, patch, _conn); return (lastSetTag, multipleTags, hasGroup); } diff --git a/PluralKit.Core/packages.lock.json b/PluralKit.Core/packages.lock.json index d6f8267f..55c1aa6d 100644 --- a/PluralKit.Core/packages.lock.json +++ b/PluralKit.Core/packages.lock.json @@ -268,6 +268,26 @@ "System.Threading.Timer": "4.0.1" } }, + "SqlKata": { + "type": "Direct", + "requested": "[2.3.7, )", + "resolved": "2.3.7", + "contentHash": "erKffEMhrS2IFKXjYV83M4uc1IOCl91yeP/3uY5yIm6pRNFDNrqnTk3La1en6EGDlMRol9abTNO1erQCYf08tg==", + "dependencies": { + "System.Collections.Concurrent": "4.3.0" + } + }, + "SqlKata.Execution": { + "type": "Direct", + "requested": "[2.3.7, )", + "resolved": "2.3.7", + "contentHash": "LybTYj99riLRH7YQNt9Kuc8VpZOvaQ7H4sQBrj2zefktS8LASOaXsHRYC/k8NEcj25w6huQpOi+HrEZ5qHXl0w==", + "dependencies": { + "Humanizer.Core": "2.8.26", + "SqlKata": "2.3.7", + "dapper": "1.50.5" + } + }, "System.Interactive.Async": { "type": "Direct", "requested": "[5.0.0, )", @@ -339,6 +359,11 @@ "System.Diagnostics.DiagnosticSource": "4.5.1" } }, + "Humanizer.Core": { + "type": "Transitive", + "resolved": "2.8.26", + "contentHash": "OiKusGL20vby4uDEswj2IgkdchC1yQ6rwbIkZDVBPIR6al2b7n3pC91elBul9q33KaBgRKhbZH3+2Ur4fnWx2A==" + }, "Microsoft.Bcl.AsyncInterfaces": { "type": "Transitive", "resolved": "1.0.0",