From c7612df37eeb26e4a5846563f9260bf405826c7d Mon Sep 17 00:00:00 2001 From: Ske Date: Sat, 29 Aug 2020 13:46:27 +0200 Subject: [PATCH] Major database refactor (again) --- .../Controllers/v1/AccountController.cs | 14 +- .../Controllers/v1/MemberController.cs | 22 +- .../Controllers/v1/MessageController.cs | 10 +- .../Controllers/v1/SystemController.cs | 62 ++-- PluralKit.Bot/CommandSystem/Context.cs | 9 +- .../ContextEntityArgumentsExt.cs | 17 +- PluralKit.Bot/CommandSystem/Parameters.cs | 8 +- PluralKit.Bot/Commands/Autoproxy.cs | 10 +- PluralKit.Bot/Commands/Groups.cs | 40 ++- PluralKit.Bot/Commands/ImportExport.cs | 2 +- PluralKit.Bot/Commands/Member.cs | 31 +- PluralKit.Bot/Commands/MemberAvatar.cs | 12 +- PluralKit.Bot/Commands/MemberEdit.cs | 50 +-- PluralKit.Bot/Commands/MemberProxy.cs | 12 +- PluralKit.Bot/Commands/Misc.cs | 20 +- PluralKit.Bot/Commands/ServerConfig.cs | 24 +- PluralKit.Bot/Commands/Switch.cs | 35 +- PluralKit.Bot/Commands/System.cs | 20 +- PluralKit.Bot/Commands/SystemEdit.cs | 44 ++- PluralKit.Bot/Commands/SystemFront.cs | 41 ++- PluralKit.Bot/Commands/SystemLink.cs | 28 +- PluralKit.Bot/Commands/Token.cs | 6 +- PluralKit.Bot/Handlers/MessageCreated.cs | 8 +- PluralKit.Bot/Handlers/MessageDeleted.cs | 15 +- PluralKit.Bot/Handlers/MessageEdited.cs | 6 +- PluralKit.Bot/Handlers/ReactionAdded.cs | 23 +- PluralKit.Bot/Proxy/ProxyService.cs | 21 +- PluralKit.Bot/Services/CpuStatService.cs | 2 +- PluralKit.Bot/Services/EmbedService.cs | 32 +- .../Services/LastMessageCacheService.cs | 2 +- PluralKit.Bot/Services/LogChannelService.cs | 10 +- PluralKit.Bot/Services/LoggerCleanService.cs | 28 +- .../Services/PeriodicStatCollector.cs | 14 +- PluralKit.Bot/Services/ShardInfoService.cs | 9 +- PluralKit.Bot/Services/WebhookCacheService.cs | 8 +- .../Services/WebhookExecutorService.cs | 8 +- PluralKit.Core/Database/Database.cs | 11 +- .../ModelRepository.Context.cs} | 6 +- .../Repository/ModelRepository.Group.cs | 83 +++++ .../Repository/ModelRepository.Guild.cs | 53 +++ .../Repository/ModelRepository.Member.cs | 47 +++ .../Repository/ModelRepository.Message.cs | 63 ++++ .../Repository/ModelRepository.Switch.cs | 236 +++++++++++++ .../Repository/ModelRepository.System.cs | 78 ++++ .../Database/Repository/ModelRepository.cs | 15 + .../Database/Wrappers/PKConnection.cs | 6 +- PluralKit.Core/Models/ModelQueryExt.cs | 69 ---- PluralKit.Core/Models/Patch/ModelPatchExt.cs | 128 ------- PluralKit.Core/Models/SystemGuildSettings.cs | 8 + PluralKit.Core/Modules.cs | 4 +- PluralKit.Core/Services/DataFileService.cs | 30 +- PluralKit.Core/Services/IDataStore.cs | 223 ------------ PluralKit.Core/Services/PostgresDataStore.cs | 334 ------------------ PluralKit.Core/Utils/ConnectionUtils.cs | 11 +- PluralKit.Core/Utils/QueryBuilder.cs | 6 +- 55 files changed, 1014 insertions(+), 1100 deletions(-) rename PluralKit.Core/Database/{Functions/DatabaseFunctionsExt.cs => Repository/ModelRepository.Context.cs} (66%) create mode 100644 PluralKit.Core/Database/Repository/ModelRepository.Group.cs create mode 100644 PluralKit.Core/Database/Repository/ModelRepository.Guild.cs create mode 100644 PluralKit.Core/Database/Repository/ModelRepository.Member.cs create mode 100644 PluralKit.Core/Database/Repository/ModelRepository.Message.cs create mode 100644 PluralKit.Core/Database/Repository/ModelRepository.Switch.cs create mode 100644 PluralKit.Core/Database/Repository/ModelRepository.System.cs create mode 100644 PluralKit.Core/Database/Repository/ModelRepository.cs delete mode 100644 PluralKit.Core/Models/ModelQueryExt.cs delete mode 100644 PluralKit.Core/Models/Patch/ModelPatchExt.cs delete mode 100644 PluralKit.Core/Services/IDataStore.cs delete mode 100644 PluralKit.Core/Services/PostgresDataStore.cs diff --git a/PluralKit.API/Controllers/v1/AccountController.cs b/PluralKit.API/Controllers/v1/AccountController.cs index b563bfca..edb3114a 100644 --- a/PluralKit.API/Controllers/v1/AccountController.cs +++ b/PluralKit.API/Controllers/v1/AccountController.cs @@ -13,18 +13,20 @@ namespace PluralKit.API [Route( "v{version:apiVersion}/a" )] public class AccountController: ControllerBase { - private IDataStore _data; - - public AccountController(IDataStore data) + private readonly IDatabase _db; + private readonly ModelRepository _repo; + public AccountController(IDatabase db, ModelRepository repo) { - _data = data; + _db = db; + _repo = repo; } [HttpGet("{aid}")] public async Task> GetSystemByAccount(ulong aid) { - var system = await _data.GetSystemByAccount(aid); - if (system == null) return NotFound("Account not found."); + var system = await _db.Execute(c => _repo.GetSystemByAccount(c, aid)); + if (system == null) + return NotFound("Account not found."); return Ok(system.ToJson(User.ContextFor(system))); } diff --git a/PluralKit.API/Controllers/v1/MemberController.cs b/PluralKit.API/Controllers/v1/MemberController.cs index 30bc8bef..bea62798 100644 --- a/PluralKit.API/Controllers/v1/MemberController.cs +++ b/PluralKit.API/Controllers/v1/MemberController.cs @@ -16,19 +16,21 @@ namespace PluralKit.API [Route( "v{version:apiVersion}/m" )] public class MemberController: ControllerBase { - private IDatabase _db; - private IAuthorizationService _auth; + private readonly IDatabase _db; + private readonly ModelRepository _repo; + private readonly IAuthorizationService _auth; - public MemberController(IAuthorizationService auth, IDatabase db) + public MemberController(IAuthorizationService auth, IDatabase db, ModelRepository repo) { _auth = auth; _db = db; + _repo = repo; } [HttpGet("{hid}")] public async Task> GetMember(string hid) { - var member = await _db.Execute(conn => conn.QueryMemberByHid(hid)); + var member = await _db.Execute(conn => _repo.GetMemberByHid(conn, hid)); if (member == null) return NotFound("Member not found."); return Ok(member.ToJson(User.ContextFor(member))); @@ -49,7 +51,7 @@ namespace PluralKit.API if (memberCount >= Limits.MaxMemberCount) return BadRequest($"Member limit reached ({memberCount} / {Limits.MaxMemberCount})."); - var member = await conn.CreateMember(system, properties.Value("name")); + var member = await _repo.CreateMember(conn, system, properties.Value("name")); MemberPatch patch; try { @@ -60,7 +62,7 @@ namespace PluralKit.API return BadRequest(e.Message); } - member = await conn.UpdateMember(member.Id, patch); + member = await _repo.UpdateMember(conn, member.Id, patch); return Ok(member.ToJson(User.ContextFor(member))); } @@ -70,7 +72,7 @@ namespace PluralKit.API { await using var conn = await _db.Obtain(); - var member = await conn.QueryMemberByHid(hid); + var member = await _repo.GetMemberByHid(conn, hid); if (member == null) return NotFound("Member not found."); var res = await _auth.AuthorizeAsync(User, member, "EditMember"); @@ -86,7 +88,7 @@ namespace PluralKit.API return BadRequest(e.Message); } - var newMember = await conn.UpdateMember(member.Id, patch); + var newMember = await _repo.UpdateMember(conn, member.Id, patch); return Ok(newMember.ToJson(User.ContextFor(newMember))); } @@ -96,13 +98,13 @@ namespace PluralKit.API { await using var conn = await _db.Obtain(); - var member = await conn.QueryMemberByHid(hid); + var member = await _repo.GetMemberByHid(conn, hid); if (member == null) return NotFound("Member not found."); var res = await _auth.AuthorizeAsync(User, member, "EditMember"); if (!res.Succeeded) return Unauthorized($"Member '{hid}' is not part of your system."); - await conn.DeleteMember(member.Id); + await _repo.DeleteMember(conn, member.Id); return Ok(); } } diff --git a/PluralKit.API/Controllers/v1/MessageController.cs b/PluralKit.API/Controllers/v1/MessageController.cs index f4f12f67..a036a4c0 100644 --- a/PluralKit.API/Controllers/v1/MessageController.cs +++ b/PluralKit.API/Controllers/v1/MessageController.cs @@ -28,17 +28,19 @@ namespace PluralKit.API [Route( "v{version:apiVersion}/msg" )] public class MessageController: ControllerBase { - private IDataStore _data; + private readonly IDatabase _db; + private readonly ModelRepository _repo; - public MessageController(IDataStore _data) + public MessageController(ModelRepository repo, IDatabase db) { - this._data = _data; + _repo = repo; + _db = db; } [HttpGet("{mid}")] public async Task> GetMessage(ulong mid) { - var msg = await _data.GetMessage(mid); + var msg = await _db.Execute(c => _repo.GetMessage(c, mid)); if (msg == null) return NotFound("Message not found."); return new MessageReturn diff --git a/PluralKit.API/Controllers/v1/SystemController.cs b/PluralKit.API/Controllers/v1/SystemController.cs index ef44acf1..0dce14e1 100644 --- a/PluralKit.API/Controllers/v1/SystemController.cs +++ b/PluralKit.API/Controllers/v1/SystemController.cs @@ -39,29 +39,29 @@ namespace PluralKit.API [Route( "v{version:apiVersion}/s" )] public class SystemController : ControllerBase { - private IDataStore _data; - private IDatabase _db; - private IAuthorizationService _auth; + private readonly IDatabase _db; + private readonly ModelRepository _repo; + private readonly IAuthorizationService _auth; - public SystemController(IDataStore data, IDatabase db, IAuthorizationService auth) + public SystemController(IDatabase db, IAuthorizationService auth, ModelRepository repo) { - _data = data; _db = db; _auth = auth; + _repo = repo; } [HttpGet] [Authorize] public async Task> GetOwnSystem() { - var system = await _db.Execute(c => c.QuerySystem(User.CurrentSystem())); + var system = await _db.Execute(c => _repo.GetSystem(c, User.CurrentSystem())); return system.ToJson(User.ContextFor(system)); } [HttpGet("{hid}")] public async Task> GetSystem(string hid) { - var system = await _data.GetSystemByHid(hid); + var system = await _db.Execute(c => _repo.GetSystemByHid(c, hid)); if (system == null) return NotFound("System not found."); return Ok(system.ToJson(User.ContextFor(system))); } @@ -69,13 +69,14 @@ namespace PluralKit.API [HttpGet("{hid}/members")] public async Task>> GetMembers(string hid) { - var system = await _data.GetSystemByHid(hid); - if (system == null) return NotFound("System not found."); + var system = await _db.Execute(c => _repo.GetSystemByHid(c, hid)); + if (system == null) + return NotFound("System not found."); if (!system.MemberListPrivacy.CanAccess(User.ContextFor(system))) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view member list."); - var members = _data.GetSystemMembers(system); + var members = _db.Execute(c => _repo.GetSystemMembers(c, system.Id)); return Ok(await members .Where(m => m.MemberVisibility.CanAccess(User.ContextFor(system))) .Select(m => m.ToJson(User.ContextFor(system))) @@ -87,39 +88,40 @@ namespace PluralKit.API { if (before == null) before = SystemClock.Instance.GetCurrentInstant(); - var system = await _data.GetSystemByHid(hid); + await using var conn = await _db.Obtain(); + + var system = await _repo.GetSystemByHid(conn, hid); if (system == null) return NotFound("System not found."); var auth = await _auth.AuthorizeAsync(User, system, "ViewFrontHistory"); if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view front history."); - using (var conn = await _db.Obtain()) - { - var res = await conn.QueryAsync( - @"select *, array( + var res = await conn.QueryAsync( + @"select *, array( select members.hid from switch_members, members where switch_members.switch = switches.id and members.id = switch_members.member ) as members from switches where switches.system = @System and switches.timestamp < @Before order by switches.timestamp desc limit 100;", new {System = system.Id, Before = before}); - return Ok(res); - } + return Ok(res); } [HttpGet("{hid}/fronters")] public async Task> GetFronters(string hid) { - var system = await _data.GetSystemByHid(hid); + await using var conn = await _db.Obtain(); + + var system = await _repo.GetSystemByHid(conn, hid); if (system == null) return NotFound("System not found."); var auth = await _auth.AuthorizeAsync(User, system, "ViewFront"); if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view fronter."); - var sw = await _data.GetLatestSwitch(system.Id); + var sw = await _repo.GetLatestSwitch(conn, system.Id); if (sw == null) return NotFound("System has no registered switches."); - var members = _data.GetSwitchMembers(sw); + var members = _repo.GetSwitchMembers(conn, sw.Id); return Ok(new FrontersReturn { Timestamp = sw.Timestamp, @@ -131,7 +133,8 @@ namespace PluralKit.API [Authorize] public async Task> EditSystem([FromBody] JObject changes) { - var system = await _db.Execute(c => c.QuerySystem(User.CurrentSystem())); + await using var conn = await _db.Obtain(); + var system = await _repo.GetSystem(conn, User.CurrentSystem()); SystemPatch patch; try @@ -143,7 +146,7 @@ namespace PluralKit.API return BadRequest(e.Message); } - await _db.Execute(conn => conn.UpdateSystem(system.Id, patch)); + await _repo.UpdateSystem(conn, system!.Id, patch); return Ok(system.ToJson(User.ContextFor(system))); } @@ -154,11 +157,13 @@ namespace PluralKit.API if (param.Members.Distinct().Count() != param.Members.Count) return BadRequest("Duplicate members in member list."); + await using var conn = await _db.Obtain(); + // We get the current switch, if it exists - var latestSwitch = await _data.GetLatestSwitch(User.CurrentSystem()); + var latestSwitch = await _repo.GetLatestSwitch(conn, User.CurrentSystem()); if (latestSwitch != null) { - var latestSwitchMembers = _data.GetSwitchMembers(latestSwitch); + var latestSwitchMembers = _repo.GetSwitchMembers(conn, latestSwitch.Id); // Bail if this switch is identical to the latest one if (await latestSwitchMembers.Select(m => m.Hid).SequenceEqualAsync(param.Members.ToAsyncEnumerable())) @@ -166,9 +171,7 @@ namespace PluralKit.API } // Resolve member objects for all given IDs - IEnumerable membersList; - using (var conn = await _db.Obtain()) - membersList = (await conn.QueryAsync("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList(); + var membersList = (await conn.QueryAsync("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList(); foreach (var member in membersList) if (member.System != User.CurrentSystem()) @@ -182,12 +185,13 @@ namespace PluralKit.API // We do this without .Select() since we want to have the early return bail if it doesn't find the member foreach (var givenMemberId in param.Members) { - if (!membersDict.TryGetValue(givenMemberId, out var member)) return BadRequest($"Member '{givenMemberId}' not found."); + if (!membersDict.TryGetValue(givenMemberId, out var member)) + return BadRequest($"Member '{givenMemberId}' not found."); membersInOrder.Add(member); } // Finally, log the switch (yay!) - await _data.AddSwitch(User.CurrentSystem(), membersInOrder); + await _repo.AddSwitch(conn, User.CurrentSystem(), membersInOrder.Select(m => m.Id).ToList()); return NoContent(); } } diff --git a/PluralKit.Bot/CommandSystem/Context.cs b/PluralKit.Bot/CommandSystem/Context.cs index ccdd03be..1cff801b 100644 --- a/PluralKit.Bot/CommandSystem/Context.cs +++ b/PluralKit.Bot/CommandSystem/Context.cs @@ -15,7 +15,7 @@ namespace PluralKit.Bot { public class Context { - private ILifetimeScope _provider; + private readonly ILifetimeScope _provider; private readonly DiscordRestClient _rest; private readonly DiscordShardedClient _client; @@ -24,8 +24,8 @@ namespace PluralKit.Bot private readonly Parameters _parameters; private readonly MessageContext _messageContext; - private readonly IDataStore _data; private readonly IDatabase _db; + private readonly ModelRepository _repo; private readonly PKSystem _senderSystem; private readonly IMetrics _metrics; @@ -38,10 +38,10 @@ namespace PluralKit.Bot _client = provider.Resolve(); _message = message; _shard = shard; - _data = provider.Resolve(); _senderSystem = senderSystem; _messageContext = messageContext; _db = provider.Resolve(); + _repo = provider.Resolve(); _metrics = provider.Resolve(); _provider = provider; _parameters = new Parameters(message.Content.Substring(commandParseOffset)); @@ -61,9 +61,8 @@ namespace PluralKit.Bot public Parameters Parameters => _parameters; - // TODO: this is just here so the extension methods can access it; should it be public/private/? - internal IDataStore DataStore => _data; internal IDatabase Database => _db; + internal ModelRepository Repository => _repo; public Task Reply(string text = null, DiscordEmbed embed = null, IEnumerable mentions = null) { diff --git a/PluralKit.Bot/CommandSystem/ContextEntityArgumentsExt.cs b/PluralKit.Bot/CommandSystem/ContextEntityArgumentsExt.cs index 4926efe7..7c58fdd7 100644 --- a/PluralKit.Bot/CommandSystem/ContextEntityArgumentsExt.cs +++ b/PluralKit.Bot/CommandSystem/ContextEntityArgumentsExt.cs @@ -47,12 +47,14 @@ namespace PluralKit.Bot // - A @mention of an account connected to the system (<@uid>) // - A system hid + await using var conn = await ctx.Database.Obtain(); + // Direct IDs and mentions are both handled by the below method: if (input.TryParseMention(out var id)) - return await ctx.DataStore.GetSystemByAccount(id); + return await ctx.Repository.GetSystemByAccount(conn, id); // Finally, try HID parsing - var system = await ctx.DataStore.GetSystemByHid(input); + var system = await ctx.Repository.GetSystemByHid(conn, input); return system; } @@ -67,15 +69,16 @@ namespace PluralKit.Bot // - a textual display name of a member *in your own system* // First, if we have a system, try finding by member name in system - if (ctx.System != null && await ctx.DataStore.GetMemberByName(ctx.System, input) is PKMember memberByName) + await using var conn = await ctx.Database.Obtain(); + if (ctx.System != null && await ctx.Repository.GetMemberByName(conn, ctx.System.Id, input) is PKMember memberByName) return memberByName; // Then, try member HID parsing: - if (await ctx.DataStore.GetMemberByHid(input) is PKMember memberByHid) + if (await ctx.Repository.GetMemberByHid(conn, input) is PKMember memberByHid) return memberByHid; // And if that again fails, we try finding a member with a display name matching the argument from the system - if (ctx.System != null && await ctx.DataStore.GetMemberByDisplayName(ctx.System, input) is PKMember memberByDisplayName) + if (ctx.System != null && await ctx.Repository.GetMemberByDisplayName(conn, ctx.System.Id, input) is PKMember memberByDisplayName) return memberByDisplayName; // We didn't find anything, so we return null. @@ -103,9 +106,9 @@ namespace PluralKit.Bot var input = ctx.PeekArgument(); await using var conn = await ctx.Database.Obtain(); - if (ctx.System != null && await conn.QueryGroupByName(ctx.System.Id, input) is {} byName) + if (ctx.System != null && await ctx.Repository.GetGroupByName(conn, ctx.System.Id, input) is {} byName) return byName; - if (await conn.QueryGroupByHid(input) is {} byHid) + if (await ctx.Repository.GetGroupByHid(conn, input) is {} byHid) return byHid; return null; diff --git a/PluralKit.Bot/CommandSystem/Parameters.cs b/PluralKit.Bot/CommandSystem/Parameters.cs index ebb69194..8dc229af 100644 --- a/PluralKit.Bot/CommandSystem/Parameters.cs +++ b/PluralKit.Bot/CommandSystem/Parameters.cs @@ -36,15 +36,15 @@ namespace PluralKit.Bot private struct WordPosition { // Start of the word - internal int startPos; + internal readonly int startPos; // End of the word - internal int endPos; + internal readonly int endPos; // How much to advance word pointer afterwards to point at the start of the *next* word - internal int advanceAfterWord; + internal readonly int advanceAfterWord; - internal bool wasQuoted; + internal readonly bool wasQuoted; public WordPosition(int startPos, int endPos, int advanceAfterWord, bool wasQuoted) { diff --git a/PluralKit.Bot/Commands/Autoproxy.cs b/PluralKit.Bot/Commands/Autoproxy.cs index e78f226f..7181c1cf 100644 --- a/PluralKit.Bot/Commands/Autoproxy.cs +++ b/PluralKit.Bot/Commands/Autoproxy.cs @@ -12,10 +12,12 @@ namespace PluralKit.Bot public class Autoproxy { private readonly IDatabase _db; + private readonly ModelRepository _repo; - public Autoproxy(IDatabase db) + public Autoproxy(IDatabase db, ModelRepository repo) { _db = db; + _repo = repo; } public async Task AutoproxyRoot(Context ctx) @@ -87,8 +89,8 @@ namespace PluralKit.Bot var fronters = ctx.MessageContext.LastSwitchMembers; var relevantMember = ctx.MessageContext.AutoproxyMode switch { - AutoproxyMode.Front => fronters.Length > 0 ? await _db.Execute(c => c.QueryMember(fronters[0])) : null, - AutoproxyMode.Member => await _db.Execute(c => c.QueryMember(ctx.MessageContext.AutoproxyMember.Value)), + AutoproxyMode.Front => fronters.Length > 0 ? await _db.Execute(c => _repo.GetMember(c, fronters[0])) : null, + AutoproxyMode.Member => await _db.Execute(c => _repo.GetMember(c, ctx.MessageContext.AutoproxyMember.Value)), _ => null }; @@ -126,7 +128,7 @@ namespace PluralKit.Bot private Task UpdateAutoproxy(Context ctx, AutoproxyMode autoproxyMode, MemberId? autoproxyMember) { var patch = new SystemGuildPatch {AutoproxyMode = autoproxyMode, AutoproxyMember = autoproxyMember}; - return _db.Execute(conn => conn.UpsertSystemGuild(ctx.System.Id, ctx.Guild.Id, patch)); + return _db.Execute(conn => _repo.UpsertSystemGuild(conn, ctx.System.Id, ctx.Guild.Id, patch)); } } } \ No newline at end of file diff --git a/PluralKit.Bot/Commands/Groups.cs b/PluralKit.Bot/Commands/Groups.cs index 926218de..24f6b928 100644 --- a/PluralKit.Bot/Commands/Groups.cs +++ b/PluralKit.Bot/Commands/Groups.cs @@ -17,10 +17,12 @@ namespace PluralKit.Bot public class Groups { private readonly IDatabase _db; + private readonly ModelRepository _repo; - public Groups(IDatabase db) + public Groups(IDatabase db, ModelRepository repo) { _db = db; + _repo = repo; } public async Task CreateGroup(Context ctx) @@ -40,14 +42,14 @@ namespace PluralKit.Bot throw new PKError($"System has reached the maximum number of groups ({Limits.MaxGroupCount}). Please delete unused groups first in order to create new ones."); // Warn if there's already a group by this name - var existingGroup = await conn.QueryGroupByName(ctx.System.Id, groupName); + var existingGroup = await _repo.GetGroupByName(conn, ctx.System.Id, groupName); if (existingGroup != null) { var msg = $"{Emojis.Warn} You already have a group in your system with the name \"{existingGroup.Name}\" (with ID `{existingGroup.Hid}`). Do you want to create another group with the same name?"; if (!await ctx.PromptYesNo(msg)) throw new PKError("Group creation cancelled."); } - var newGroup = await conn.CreateGroup(ctx.System.Id, groupName); + var newGroup = await _repo.CreateGroup(conn, ctx.System.Id, groupName); var eb = new DiscordEmbedBuilder() .WithDescription($"Your new group, **{groupName}**, has been created, with the group ID **`{newGroup.Hid}`**.\nBelow are a couple of useful commands:") @@ -70,14 +72,14 @@ namespace PluralKit.Bot await using var conn = await _db.Obtain(); // Warn if there's already a group by this name - var existingGroup = await conn.QueryGroupByName(ctx.System.Id, newName); + var existingGroup = await _repo.GetGroupByName(conn, ctx.System.Id, newName); if (existingGroup != null && existingGroup.Id != target.Id) { var msg = $"{Emojis.Warn} You already have a group in your system with the name \"{existingGroup.Name}\" (with ID `{existingGroup.Hid}`). Do you want to rename this member to that name too?"; if (!await ctx.PromptYesNo(msg)) throw new PKError("Group creation cancelled."); } - await conn.UpdateGroup(target.Id, new GroupPatch {Name = newName}); + await _repo.UpdateGroup(conn, target.Id, new GroupPatch {Name = newName}); await ctx.Reply($"{Emojis.Success} Group name changed from **{target.Name}** to **{newName}**."); } @@ -89,7 +91,7 @@ namespace PluralKit.Bot ctx.CheckOwnGroup(target); var patch = new GroupPatch {DisplayName = Partial.Null()}; - await _db.Execute(conn => conn.UpdateGroup(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Group display name cleared."); } @@ -112,7 +114,7 @@ namespace PluralKit.Bot var newDisplayName = ctx.RemainderOrNull(); var patch = new GroupPatch {DisplayName = Partial.Present(newDisplayName)}; - await _db.Execute(conn => conn.UpdateGroup(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Group display name changed."); } @@ -125,7 +127,7 @@ namespace PluralKit.Bot ctx.CheckOwnGroup(target); var patch = new GroupPatch {Description = Partial.Null()}; - await _db.Execute(conn => conn.UpdateGroup(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Group description cleared."); } else if (!ctx.HasNext()) @@ -154,7 +156,7 @@ namespace PluralKit.Bot throw Errors.DescriptionTooLongError(description.Length); var patch = new GroupPatch {Description = Partial.Present(description)}; - await _db.Execute(conn => conn.UpdateGroup(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Group description changed."); } @@ -166,7 +168,7 @@ namespace PluralKit.Bot { ctx.CheckOwnGroup(target); - await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch {Icon = null})); + await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch {Icon = null})); await ctx.Reply($"{Emojis.Success} Group icon cleared."); } @@ -178,7 +180,7 @@ namespace PluralKit.Bot throw Errors.InvalidUrl(img.Url); await AvatarUtils.VerifyAvatarOrThrow(img.Url); - await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch {Icon = img.Url})); + await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch {Icon = img.Url})); var msg = img.Source switch { @@ -282,7 +284,7 @@ namespace PluralKit.Bot var system = await GetGroupSystem(ctx, target, conn); var pctx = ctx.LookupContextFor(system); - var memberCount = await conn.QueryGroupMemberCount(target.Id, PrivacyLevel.Public); + var memberCount = await _repo.GetGroupMemberCount(conn, target.Id, PrivacyLevel.Public); var nameField = target.Name; if (system.Name != null) @@ -333,7 +335,7 @@ namespace PluralKit.Bot .Select(m => m.Id) .Distinct() .ToList(); - await conn.AddMembersToGroup(target.Id, membersNotInGroup); + await _repo.AddMembersToGroup(conn, target.Id, membersNotInGroup); if (membersNotInGroup.Count == members.Count) await ctx.Reply($"{Emojis.Success} {"members".ToQuantity(membersNotInGroup.Count)} added to group."); @@ -347,7 +349,7 @@ namespace PluralKit.Bot .Select(m => m.Id) .Distinct() .ToList(); - await conn.RemoveMembersFromGroup(target.Id, membersInGroup); + await _repo.RemoveMembersFromGroup(conn, target.Id, membersInGroup); if (membersInGroup.Count == members.Count) await ctx.Reply($"{Emojis.Success} {"members".ToQuantity(membersInGroup.Count)} removed from group."); @@ -422,7 +424,7 @@ namespace PluralKit.Bot async Task SetAll(PrivacyLevel level) { - await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch().WithAllPrivacy(level))); + await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch().WithAllPrivacy(level))); if (level == PrivacyLevel.Private) await ctx.Reply($"{Emojis.Success} All {target.Name}'s privacy settings have been set to **{level.LevelName()}**. Other accounts will now see nothing on the group card."); @@ -432,7 +434,7 @@ namespace PluralKit.Bot async Task SetLevel(GroupPrivacySubject subject, PrivacyLevel level) { - await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch().WithPrivacy(subject, level))); + await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch().WithPrivacy(subject, level))); var subjectName = subject switch { @@ -475,17 +477,17 @@ namespace PluralKit.Bot if (!await ctx.ConfirmWithReply(target.Hid)) throw new PKError($"Group deletion cancelled. Note that you must reply with your group ID (`{target.Hid}`) *verbatim*."); - await _db.Execute(conn => conn.DeleteGroup(target.Id)); + await _db.Execute(conn => _repo.DeleteGroup(conn, target.Id)); await ctx.Reply($"{Emojis.Success} Group deleted."); } - private static async Task GetGroupSystem(Context ctx, PKGroup target, IPKConnection conn) + private async Task GetGroupSystem(Context ctx, PKGroup target, IPKConnection conn) { var system = ctx.System; if (system?.Id == target.System) return system; - return await conn.QuerySystem(target.System)!; + return await _repo.GetSystem(conn, target.System)!; } } } \ No newline at end of file diff --git a/PluralKit.Bot/Commands/ImportExport.cs b/PluralKit.Bot/Commands/ImportExport.cs index eb53a58d..377580b0 100644 --- a/PluralKit.Bot/Commands/ImportExport.cs +++ b/PluralKit.Bot/Commands/ImportExport.cs @@ -15,7 +15,7 @@ namespace PluralKit.Bot { public class ImportExport { - private DataFileService _dataFiles; + private readonly DataFileService _dataFiles; public ImportExport(DataFileService dataFiles) { _dataFiles = dataFiles; diff --git a/PluralKit.Bot/Commands/Member.cs b/PluralKit.Bot/Commands/Member.cs index 92fb5dcb..737b84f5 100644 --- a/PluralKit.Bot/Commands/Member.cs +++ b/PluralKit.Bot/Commands/Member.cs @@ -8,15 +8,15 @@ namespace PluralKit.Bot { public class Member { - private IDataStore _data; - private IDatabase _db; - private EmbedService _embeds; + private readonly IDatabase _db; + private readonly ModelRepository _repo; + private readonly EmbedService _embeds; - public Member(IDataStore data, EmbedService embeds, IDatabase db) + public Member(EmbedService embeds, IDatabase db, ModelRepository repo) { - _data = data; _embeds = embeds; _db = db; + _repo = repo; } public async Task NewMember(Context ctx) { @@ -27,7 +27,7 @@ namespace PluralKit.Bot if (memberName.Length > Limits.MaxMemberNameLength) throw Errors.MemberNameTooLongError(memberName.Length); // Warn if there's already a member by this name - var existingMember = await _data.GetMemberByName(ctx.System, memberName); + var existingMember = await _db.Execute(c => _repo.GetMemberByName(c, ctx.System.Id, memberName)); if (existingMember != null) { var msg = $"{Emojis.Warn} You already have a member in your system with the name \"{existingMember.NameFor(ctx)}\" (with ID `{existingMember.Hid}`). Do you want to create another member with the same name?"; if (!await ctx.PromptYesNo(msg)) throw new PKError("Member creation cancelled."); @@ -36,12 +36,12 @@ namespace PluralKit.Bot await using var conn = await _db.Obtain(); // Enforce per-system member limit - var memberCount = await conn.GetSystemMemberCount(ctx.System.Id); + var memberCount = await _repo.GetSystemMemberCount(conn, ctx.System.Id); if (memberCount >= Limits.MaxMemberCount) throw Errors.MemberLimitReachedError; // Create the member - var member = await conn.CreateMember(ctx.System.Id, memberName); + var member = await _repo.CreateMember(conn, ctx.System.Id, memberName); memberCount++; // Send confirmation and space hint @@ -62,10 +62,14 @@ namespace PluralKit.Bot //Maybe move this somewhere else in the file structure since it doesn't need to get created at every command // TODO: don't buffer these, find something else to do ig - - List members; - if (ctx.MatchFlag("all", "a")) members = await _data.GetSystemMembers(ctx.System).ToListAsync(); - else members = await _data.GetSystemMembers(ctx.System).Where(m => m.MemberVisibility == PrivacyLevel.Public).ToListAsync(); + + var members = await _db.Execute(c => + { + if (ctx.MatchFlag("all", "a")) + return _repo.GetSystemMembers(c, ctx.System.Id); + return _repo.GetSystemMembers(c, ctx.System.Id) + .Where(m => m.MemberVisibility == PrivacyLevel.Public); + }).ToListAsync(); if (members == null || !members.Any()) throw Errors.NoMembersError; @@ -75,8 +79,7 @@ namespace PluralKit.Bot public async Task ViewMember(Context ctx, PKMember target) { - - var system = await _db.Execute(c => c.QuerySystem(target.System)); + var system = await _db.Execute(c => _repo.GetSystem(c, target.System)); await ctx.Reply(embed: await _embeds.CreateMemberEmbed(system, target, ctx.Guild, ctx.LookupContextFor(system))); } } diff --git a/PluralKit.Bot/Commands/MemberAvatar.cs b/PluralKit.Bot/Commands/MemberAvatar.cs index bdfb47ec..bf1726dc 100644 --- a/PluralKit.Bot/Commands/MemberAvatar.cs +++ b/PluralKit.Bot/Commands/MemberAvatar.cs @@ -11,10 +11,12 @@ namespace PluralKit.Bot public class MemberAvatar { private readonly IDatabase _db; + private readonly ModelRepository _repo; - public MemberAvatar(IDatabase db) + public MemberAvatar(IDatabase db, ModelRepository repo) { _db = db; + _repo = repo; } private async Task AvatarClear(AvatarLocation location, Context ctx, PKMember target, MemberGuildSettings? mgs) @@ -67,14 +69,14 @@ namespace PluralKit.Bot public async Task ServerAvatar(Context ctx, PKMember target) { ctx.CheckGuildContext(); - var guildData = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)); + var guildData = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)); await AvatarCommandTree(AvatarLocation.Server, ctx, target, guildData); } public async Task Avatar(Context ctx, PKMember target) { var guildData = ctx.Guild != null ? - await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)) + await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)) : null; await AvatarCommandTree(AvatarLocation.Member, ctx, target, guildData); @@ -150,10 +152,10 @@ namespace PluralKit.Bot { case AvatarLocation.Server: var serverPatch = new MemberGuildPatch { AvatarUrl = url }; - return _db.Execute(c => c.UpsertMemberGuild(target.Id, ctx.Guild.Id, serverPatch)); + return _db.Execute(c => _repo.UpsertMemberGuild(c, target.Id, ctx.Guild.Id, serverPatch)); case AvatarLocation.Member: var memberPatch = new MemberPatch { AvatarUrl = url }; - return _db.Execute(c => c.UpdateMember(target.Id, memberPatch)); + return _db.Execute(c => _repo.UpdateMember(c, target.Id, memberPatch)); default: throw new ArgumentOutOfRangeException($"Unknown avatar location {location}"); } diff --git a/PluralKit.Bot/Commands/MemberEdit.cs b/PluralKit.Bot/Commands/MemberEdit.cs index 7fc16590..727af9de 100644 --- a/PluralKit.Bot/Commands/MemberEdit.cs +++ b/PluralKit.Bot/Commands/MemberEdit.cs @@ -15,13 +15,13 @@ namespace PluralKit.Bot { public class MemberEdit { - private readonly IDataStore _data; private readonly IDatabase _db; + private readonly ModelRepository _repo; - public MemberEdit(IDataStore data, IDatabase db) + public MemberEdit(IDatabase db, ModelRepository repo) { - _data = data; _db = db; + _repo = repo; } public async Task Name(Context ctx, PKMember target) { @@ -35,7 +35,7 @@ namespace PluralKit.Bot if (newName.Length > Limits.MaxMemberNameLength) throw Errors.MemberNameTooLongError(newName.Length); // Warn if there's already a member by this name - var existingMember = await _data.GetMemberByName(ctx.System, newName); + var existingMember = await _db.Execute(conn => _repo.GetMemberByName(conn, ctx.System.Id, newName)); if (existingMember != null && existingMember.Id != target.Id) { var msg = $"{Emojis.Warn} You already have a member in your system with the name \"{existingMember.NameFor(ctx)}\" (`{existingMember.Hid}`). Do you want to rename this member to that name too?"; @@ -44,7 +44,7 @@ namespace PluralKit.Bot // Rename the member var patch = new MemberPatch {Name = Partial.Present(newName)}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Member renamed."); if (newName.Contains(" ")) await ctx.Reply($"{Emojis.Note} Note that this member's name now contains spaces. You will need to surround it with \"double quotes\" when using commands referring to it."); @@ -52,7 +52,7 @@ namespace PluralKit.Bot if (ctx.Guild != null) { - var memberGuildConfig = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)); + var memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)); if (memberGuildConfig.DisplayName != null) await ctx.Reply($"{Emojis.Note} Note that this member has a server name set ({memberGuildConfig.DisplayName}) in this server ({ctx.Guild.Name}), and will be proxied using that name here."); } @@ -69,7 +69,7 @@ namespace PluralKit.Bot CheckEditMemberPermission(ctx, target); var patch = new MemberPatch {Description = Partial.Null()}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Member description cleared."); } else if (!ctx.HasNext()) @@ -100,7 +100,7 @@ namespace PluralKit.Bot throw Errors.DescriptionTooLongError(description.Length); var patch = new MemberPatch {Description = Partial.Present(description)}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Member description changed."); } @@ -111,7 +111,7 @@ namespace PluralKit.Bot { CheckEditMemberPermission(ctx, target); var patch = new MemberPatch {Pronouns = Partial.Null()}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Member pronouns cleared."); } else if (!ctx.HasNext()) @@ -136,7 +136,7 @@ namespace PluralKit.Bot throw Errors.MemberPronounsTooLongError(pronouns.Length); var patch = new MemberPatch {Pronouns = Partial.Present(pronouns)}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Member pronouns changed."); } @@ -150,7 +150,7 @@ namespace PluralKit.Bot CheckEditMemberPermission(ctx, target); var patch = new MemberPatch {Color = Partial.Null()}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Member color cleared."); } @@ -182,7 +182,7 @@ namespace PluralKit.Bot if (!Regex.IsMatch(color, "^[0-9a-fA-F]{6}$")) throw Errors.InvalidColorError(color); var patch = new MemberPatch {Color = Partial.Present(color.ToLowerInvariant())}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply(embed: new DiscordEmbedBuilder() .WithTitle($"{Emojis.Success} Member color changed.") @@ -198,7 +198,7 @@ namespace PluralKit.Bot CheckEditMemberPermission(ctx, target); var patch = new MemberPatch {Birthday = Partial.Null()}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Member birthdate cleared."); } @@ -223,7 +223,7 @@ namespace PluralKit.Bot if (birthday == null) throw Errors.BirthdayParseError(birthdayStr); var patch = new MemberPatch {Birthday = Partial.Present(birthday)}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Member birthdate changed."); } @@ -235,7 +235,7 @@ namespace PluralKit.Bot MemberGuildSettings memberGuildConfig = null; if (ctx.Guild != null) - memberGuildConfig = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)); + memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)); var eb = new DiscordEmbedBuilder().WithTitle($"Member names") .WithFooter($"Member ID: {target.Hid} | Active name in bold. Server name overrides display name, which overrides base name."); @@ -271,7 +271,7 @@ namespace PluralKit.Bot var successStr = text; if (ctx.Guild != null) { - var memberGuildConfig = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)); + var memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)); if (memberGuildConfig.DisplayName != null) successStr += $" However, this member has a server name set in this server ({ctx.Guild.Name}), and will be proxied using that name, \"{memberGuildConfig.DisplayName}\", here."; } @@ -284,7 +284,7 @@ namespace PluralKit.Bot CheckEditMemberPermission(ctx, target); var patch = new MemberPatch {DisplayName = Partial.Null()}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await PrintSuccess($"{Emojis.Success} Member display name cleared. This member will now be proxied using their member name \"{target.NameFor(ctx)}\"."); } @@ -303,7 +303,7 @@ namespace PluralKit.Bot var newDisplayName = ctx.RemainderOrNull(); var patch = new MemberPatch {DisplayName = Partial.Present(newDisplayName)}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await PrintSuccess($"{Emojis.Success} Member display name changed. This member will now be proxied using the name \"{newDisplayName}\"."); } @@ -318,7 +318,7 @@ namespace PluralKit.Bot CheckEditMemberPermission(ctx, target); var patch = new MemberGuildPatch {DisplayName = null}; - await _db.Execute(conn => conn.UpsertMemberGuild(target.Id, ctx.Guild.Id, patch)); + await _db.Execute(conn => _repo.UpsertMemberGuild(conn, target.Id, ctx.Guild.Id, patch)); if (target.DisplayName != null) await ctx.Reply($"{Emojis.Success} Member server name cleared. This member will now be proxied using their global display name \"{target.DisplayName}\" in this server ({ctx.Guild.Name})."); @@ -340,7 +340,7 @@ namespace PluralKit.Bot var newServerName = ctx.RemainderOrNull(); var patch = new MemberGuildPatch {DisplayName = newServerName}; - await _db.Execute(conn => conn.UpsertMemberGuild(target.Id, ctx.Guild.Id, patch)); + await _db.Execute(conn => _repo.UpsertMemberGuild(conn, target.Id, ctx.Guild.Id, patch)); await ctx.Reply($"{Emojis.Success} Member server name changed. This member will now be proxied using the name \"{newServerName}\" in this server ({ctx.Guild.Name})."); } @@ -365,7 +365,7 @@ namespace PluralKit.Bot }; var patch = new MemberPatch {KeepProxy = Partial.Present(newValue)}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); if (newValue) await ctx.Reply($"{Emojis.Success} Member proxy tags will now be included in the resulting message when proxying."); @@ -398,11 +398,11 @@ namespace PluralKit.Bot // Get guild settings (mostly for warnings and such) MemberGuildSettings guildSettings = null; if (ctx.Guild != null) - guildSettings = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)); + guildSettings = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id)); async Task SetAll(PrivacyLevel level) { - await _db.Execute(c => c.UpdateMember(target.Id, new MemberPatch().WithAllPrivacy(level))); + await _db.Execute(c => _repo.UpdateMember(c, target.Id, new MemberPatch().WithAllPrivacy(level))); if (level == PrivacyLevel.Private) await ctx.Reply($"{Emojis.Success} All {target.NameFor(ctx)}'s privacy settings have been set to **{level.LevelName()}**. Other accounts will now see nothing on the member card."); @@ -412,7 +412,7 @@ namespace PluralKit.Bot async Task SetLevel(MemberPrivacySubject subject, PrivacyLevel level) { - await _db.Execute(c => c.UpdateMember(target.Id, new MemberPatch().WithPrivacy(subject, level))); + await _db.Execute(c => _repo.UpdateMember(c, target.Id, new MemberPatch().WithPrivacy(subject, level))); var subjectName = subject switch { @@ -472,7 +472,7 @@ namespace PluralKit.Bot await ctx.Reply($"{Emojis.Warn} Are you sure you want to delete \"{target.NameFor(ctx)}\"? If so, reply to this message with the member's ID (`{target.Hid}`). __***This cannot be undone!***__"); if (!await ctx.ConfirmWithReply(target.Hid)) throw Errors.MemberDeleteCancelled; - await _db.Execute(conn => conn.DeleteMember(target.Id)); + await _db.Execute(conn => _repo.DeleteMember(conn, target.Id)); await ctx.Reply($"{Emojis.Success} Member deleted."); } diff --git a/PluralKit.Bot/Commands/MemberProxy.cs b/PluralKit.Bot/Commands/MemberProxy.cs index c80f3c3b..3df730b1 100644 --- a/PluralKit.Bot/Commands/MemberProxy.cs +++ b/PluralKit.Bot/Commands/MemberProxy.cs @@ -10,10 +10,12 @@ namespace PluralKit.Bot public class MemberProxy { private readonly IDatabase _db; + private readonly ModelRepository _repo; - public MemberProxy(IDatabase db) + public MemberProxy(IDatabase db, ModelRepository repo) { _db = db; + _repo = repo; } public async Task Proxy(Context ctx, PKMember target) @@ -55,7 +57,7 @@ namespace PluralKit.Bot } var patch = new MemberPatch {ProxyTags = Partial.Present(new ProxyTag[0])}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Proxy tags cleared."); } @@ -83,7 +85,7 @@ namespace PluralKit.Bot var newTags = target.ProxyTags.ToList(); newTags.Add(tagToAdd); var patch = new MemberPatch {ProxyTags = Partial.Present(newTags.ToArray())}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Added proxy tags {tagToAdd.ProxyString.AsCode()}."); } @@ -100,7 +102,7 @@ namespace PluralKit.Bot var newTags = target.ProxyTags.ToList(); newTags.Remove(tagToRemove); var patch = new MemberPatch {ProxyTags = Partial.Present(newTags.ToArray())}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Removed proxy tags {tagToRemove.ProxyString.AsCode()}."); } @@ -124,7 +126,7 @@ namespace PluralKit.Bot var newTags = new[] {requestedTag}; var patch = new MemberPatch {ProxyTags = Partial.Present(newTags)}; - await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); + await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch)); await ctx.Reply($"{Emojis.Success} Member proxy tags set to {requestedTag.ProxyString.AsCode()}."); } diff --git a/PluralKit.Bot/Commands/Misc.cs b/PluralKit.Bot/Commands/Misc.cs index b63cae5d..9b0337f0 100644 --- a/PluralKit.Bot/Commands/Misc.cs +++ b/PluralKit.Bot/Commands/Misc.cs @@ -18,21 +18,23 @@ using DSharpPlus.Entities; namespace PluralKit.Bot { public class Misc { - private BotConfig _botConfig; - private IMetrics _metrics; - private CpuStatService _cpu; - private ShardInfoService _shards; - private IDataStore _data; - private EmbedService _embeds; + private readonly BotConfig _botConfig; + private readonly IMetrics _metrics; + private readonly CpuStatService _cpu; + private readonly ShardInfoService _shards; + private readonly EmbedService _embeds; + private readonly IDatabase _db; + private readonly ModelRepository _repo; - public Misc(BotConfig botConfig, IMetrics metrics, CpuStatService cpu, ShardInfoService shards, IDataStore data, EmbedService embeds) + public Misc(BotConfig botConfig, IMetrics metrics, CpuStatService cpu, ShardInfoService shards, EmbedService embeds, ModelRepository repo, IDatabase db) { _botConfig = botConfig; _metrics = metrics; _cpu = cpu; _shards = shards; - _data = data; _embeds = embeds; + _repo = repo; + _db = db; } public async Task Invite(Context ctx) @@ -198,7 +200,7 @@ namespace PluralKit.Bot { messageId = ulong.Parse(match.Groups[1].Value); else throw new PKSyntaxError($"Could not parse {word.AsCode()} as a message ID or link."); - var message = await _data.GetMessage(messageId); + var message = await _db.Execute(c => _repo.GetMessage(c, messageId)); if (message == null) throw Errors.MessageNotFound(messageId); await ctx.Reply(embed: await _embeds.CreateMessageInfoEmbed(ctx.Shard, message)); diff --git a/PluralKit.Bot/Commands/ServerConfig.cs b/PluralKit.Bot/Commands/ServerConfig.cs index 1102448f..167cc089 100644 --- a/PluralKit.Bot/Commands/ServerConfig.cs +++ b/PluralKit.Bot/Commands/ServerConfig.cs @@ -12,12 +12,14 @@ namespace PluralKit.Bot { public class ServerConfig { - private IDatabase _db; - private LoggerCleanService _cleanService; - public ServerConfig(LoggerCleanService cleanService, IDatabase db) + private readonly IDatabase _db; + private readonly ModelRepository _repo; + private readonly LoggerCleanService _cleanService; + public ServerConfig(LoggerCleanService cleanService, IDatabase db, ModelRepository repo) { _cleanService = cleanService; _db = db; + _repo = repo; } public async Task SetLogChannel(Context ctx) @@ -32,7 +34,7 @@ namespace PluralKit.Bot if (channel == null || channel.GuildId != ctx.Guild.Id) throw Errors.ChannelNotFound(channelString); var patch = new GuildPatch {LogChannel = channel?.Id}; - await _db.Execute(conn => conn.UpsertGuild(ctx.Guild.Id, patch)); + await _db.Execute(conn => _repo.UpsertGuild(conn, ctx.Guild.Id, patch)); if (channel != null) await ctx.Reply($"{Emojis.Success} Proxy logging channel set to #{channel.Name}."); @@ -59,7 +61,7 @@ namespace PluralKit.Bot ulong? logChannel = null; await using (var conn = await _db.Obtain()) { - var config = await conn.QueryOrInsertGuildConfig(ctx.Guild.Id); + var config = await _repo.GetGuild(conn, ctx.Guild.Id); logChannel = config.LogChannel; var blacklist = config.LogBlacklist.ToHashSet(); if (enable) @@ -68,7 +70,7 @@ namespace PluralKit.Bot blacklist.UnionWith(affectedChannels.Select(c => c.Id)); var patch = new GuildPatch {LogBlacklist = blacklist.ToArray()}; - await conn.UpsertGuild(ctx.Guild.Id, patch); + await _repo.UpsertGuild(conn, ctx.Guild.Id, patch); } await ctx.Reply( @@ -80,7 +82,7 @@ namespace PluralKit.Bot { ctx.CheckGuildContext().CheckAuthorPermission(Permissions.ManageGuild, "Manage Server"); - var blacklist = await _db.Execute(c => c.QueryOrInsertGuildConfig(ctx.Guild.Id)); + var blacklist = await _db.Execute(c => _repo.GetGuild(c, ctx.Guild.Id)); // Resolve all channels from the cache and order by position var channels = blacklist.Blacklist @@ -139,7 +141,7 @@ namespace PluralKit.Bot await using (var conn = await _db.Obtain()) { - var guild = await conn.QueryOrInsertGuildConfig(ctx.Guild.Id); + var guild = await _repo.GetGuild(conn, ctx.Guild.Id); var blacklist = guild.Blacklist.ToHashSet(); if (shouldAdd) blacklist.UnionWith(affectedChannels.Select(c => c.Id)); @@ -147,7 +149,7 @@ namespace PluralKit.Bot blacklist.ExceptWith(affectedChannels.Select(c => c.Id)); var patch = new GuildPatch {Blacklist = blacklist.ToArray()}; - await conn.UpsertGuild(ctx.Guild.Id, patch); + await _repo.UpsertGuild(conn, ctx.Guild.Id, patch); } await ctx.Reply($"{Emojis.Success} Channels {(shouldAdd ? "added to" : "removed from")} the proxy blacklist."); @@ -170,7 +172,7 @@ namespace PluralKit.Bot .WithTitle("Log cleanup settings") .AddField("Supported bots", botList); - var guildCfg = await _db.Execute(c => c.QueryOrInsertGuildConfig(ctx.Guild.Id)); + var guildCfg = await _db.Execute(c => _repo.GetGuild(c, ctx.Guild.Id)); if (guildCfg.LogCleanupEnabled) eb.WithDescription("Log cleanup is currently **on** for this server. To disable it, type `pk;logclean off`."); else @@ -180,7 +182,7 @@ namespace PluralKit.Bot } var patch = new GuildPatch {LogCleanupEnabled = newValue}; - await _db.Execute(conn => conn.UpsertGuild(ctx.Guild.Id, patch)); + await _db.Execute(conn => _repo.UpsertGuild(conn, ctx.Guild.Id, patch)); if (newValue) await ctx.Reply($"{Emojis.Success} Log cleanup has been **enabled** for this server. Messages deleted by PluralKit will now be cleaned up from logging channels managed by the following bots:\n- **{botList}**\n\n{Emojis.Note} Make sure PluralKit has the **Manage Messages** permission in the channels in question.\n{Emojis.Note} Also, make sure to blacklist the logging channel itself from the bots in question to prevent conflicts."); diff --git a/PluralKit.Bot/Commands/Switch.cs b/PluralKit.Bot/Commands/Switch.cs index 7a2251bf..49d48a22 100644 --- a/PluralKit.Bot/Commands/Switch.cs +++ b/PluralKit.Bot/Commands/Switch.cs @@ -13,11 +13,13 @@ namespace PluralKit.Bot { public class Switch { - private IDataStore _data; + private readonly IDatabase _db; + private readonly ModelRepository _repo; - public Switch(IDataStore data) + public Switch(IDatabase db, ModelRepository repo) { - _data = data; + _db = db; + _repo = repo; } public async Task SwitchDo(Context ctx) @@ -42,16 +44,17 @@ namespace PluralKit.Bot if (members.Select(m => m.Id).Distinct().Count() != members.Count) throw Errors.DuplicateSwitchMembers; // Find the last switch and its members if applicable - var lastSwitch = await _data.GetLatestSwitch(ctx.System.Id); + await using var conn = await _db.Obtain(); + var lastSwitch = await _repo.GetLatestSwitch(conn, ctx.System.Id); if (lastSwitch != null) { - var lastSwitchMembers = _data.GetSwitchMembers(lastSwitch); + var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastSwitch.Id); // Make sure the requested switch isn't identical to the last one if (await lastSwitchMembers.Select(m => m.Id).SequenceEqualAsync(members.Select(m => m.Id).ToAsyncEnumerable())) throw Errors.SameSwitch(members, ctx.LookupContextFor(ctx.System)); } - await _data.AddSwitch(ctx.System.Id, members); + await _repo.AddSwitch(conn, ctx.System.Id, members.Select(m => m.Id).ToList()); if (members.Count == 0) await ctx.Reply($"{Emojis.Success} Switch-out registered."); @@ -68,12 +71,14 @@ namespace PluralKit.Bot var result = DateUtils.ParseDateTime(timeToMove, true, tz); if (result == null) throw Errors.InvalidDateTime(timeToMove); + + await using var conn = await _db.Obtain(); var time = result.Value; if (time.ToInstant() > SystemClock.Instance.GetCurrentInstant()) throw Errors.SwitchTimeInFuture; // Fetch the last two switches for the system to do bounds checking on - var lastTwoSwitches = await _data.GetSwitches(ctx.System.Id).Take(2).ToListAsync(); + var lastTwoSwitches = await _repo.GetSwitches(conn, ctx.System.Id).Take(2).ToListAsync(); // If we don't have a switch to move, don't bother if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches; @@ -87,7 +92,7 @@ namespace PluralKit.Bot // Now we can actually do the move, yay! // But, we do a prompt to confirm. - var lastSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[0]); + var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[0].Id); var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync()); var lastSwitchTimeStr = lastTwoSwitches[0].Timestamp.FormatZoned(ctx.System); var lastSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp).FormatDuration(); @@ -99,7 +104,7 @@ namespace PluralKit.Bot if (!await ctx.PromptYesNo(msg)) throw Errors.SwitchMoveCancelled; // aaaand *now* we do the move - await _data.MoveSwitch(lastTwoSwitches[0], time.ToInstant()); + await _repo.MoveSwitch(conn, lastTwoSwitches[0].Id, time.ToInstant()); await ctx.Reply($"{Emojis.Success} Switch moved."); } @@ -113,16 +118,18 @@ namespace PluralKit.Bot var purgeMsg = $"{Emojis.Warn} This will delete *all registered switches* in your system. Are you sure you want to proceed?"; if (!await ctx.PromptYesNo(purgeMsg)) throw Errors.GenericCancelled(); - await _data.DeleteAllSwitches(ctx.System); + await _db.Execute(c => _repo.DeleteAllSwitches(c, ctx.System.Id)); await ctx.Reply($"{Emojis.Success} Cleared system switches!"); return; } + await using var conn = await _db.Obtain(); + // Fetch the last two switches for the system to do bounds checking on - var lastTwoSwitches = await _data.GetSwitches(ctx.System.Id).Take(2).ToListAsync(); + var lastTwoSwitches = await _repo.GetSwitches(conn, ctx.System.Id).Take(2).ToListAsync(); if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches; - var lastSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[0]); + var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[0].Id); var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync()); var lastSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp).FormatDuration(); @@ -133,14 +140,14 @@ namespace PluralKit.Bot } else { - var secondSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[1]); + var secondSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[1].Id); var secondSwitchMemberStr = string.Join(", ", await secondSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync()); var secondSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[1].Timestamp).FormatDuration(); msg = $"{Emojis.Warn} This will delete the latest switch ({lastSwitchMemberStr}, {lastSwitchDeltaStr} ago). The next latest switch is {secondSwitchMemberStr} ({secondSwitchDeltaStr} ago). Is this okay?"; } if (!await ctx.PromptYesNo(msg)) throw Errors.SwitchDeleteCancelled; - await _data.DeleteSwitch(lastTwoSwitches[0]); + await _repo.DeleteSwitch(conn, lastTwoSwitches[0].Id); await ctx.Reply($"{Emojis.Success} Switch deleted."); } diff --git a/PluralKit.Bot/Commands/System.cs b/PluralKit.Bot/Commands/System.cs index 4923adac..509f2b59 100644 --- a/PluralKit.Bot/Commands/System.cs +++ b/PluralKit.Bot/Commands/System.cs @@ -6,13 +6,15 @@ namespace PluralKit.Bot { public class System { - private IDataStore _data; - private EmbedService _embeds; + private readonly EmbedService _embeds; + private readonly IDatabase _db; + private readonly ModelRepository _repo; - public System(EmbedService embeds, IDataStore data) + public System(EmbedService embeds, IDatabase db, ModelRepository repo) { _embeds = embeds; - _data = data; + _db = db; + _repo = repo; } public async Task Query(Context ctx, PKSystem system) { @@ -28,9 +30,15 @@ namespace PluralKit.Bot var systemName = ctx.RemainderOrNull(); if (systemName != null && systemName.Length > Limits.MaxSystemNameLength) throw Errors.SystemNameTooLongError(systemName.Length); + + var system = _db.Execute(async c => + { + var system = await _repo.CreateSystem(c, systemName); + await _repo.AddAccount(c, system.Id, ctx.Author.Id); + return system; + }); - var system = await _data.CreateSystem(systemName); - await _data.AddAccount(system, ctx.Author.Id); + // TODO: better message, perhaps embed like in groups? await ctx.Reply($"{Emojis.Success} Your system has been created. Type `pk;system` to view it, and type `pk;system help` for more information about commands you can use now. Now that you have that set up, check out the getting started guide on setting up members and proxies: "); } } diff --git a/PluralKit.Bot/Commands/SystemEdit.cs b/PluralKit.Bot/Commands/SystemEdit.cs index cea4e113..0ab656eb 100644 --- a/PluralKit.Bot/Commands/SystemEdit.cs +++ b/PluralKit.Bot/Commands/SystemEdit.cs @@ -19,15 +19,13 @@ namespace PluralKit.Bot { public class SystemEdit { - private IDataStore _data; - private IDatabase _db; - private EmbedService _embeds; + private readonly IDatabase _db; + private readonly ModelRepository _repo; - public SystemEdit(IDataStore data, EmbedService embeds, IDatabase db) + public SystemEdit(IDatabase db, ModelRepository repo) { - _data = data; - _embeds = embeds; _db = db; + _repo = repo; } public async Task Name(Context ctx) @@ -37,7 +35,7 @@ namespace PluralKit.Bot if (ctx.MatchClear()) { var clearPatch = new SystemPatch {Name = null}; - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, clearPatch)); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, clearPatch)); await ctx.Reply($"{Emojis.Success} System name cleared."); return; @@ -57,7 +55,7 @@ namespace PluralKit.Bot throw Errors.SystemNameTooLongError(newSystemName.Length); var patch = new SystemPatch {Name = newSystemName}; - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); await ctx.Reply($"{Emojis.Success} System name changed."); } @@ -68,7 +66,7 @@ namespace PluralKit.Bot if (ctx.MatchClear()) { var patch = new SystemPatch {Description = null}; - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); await ctx.Reply($"{Emojis.Success} System description cleared."); return; @@ -93,7 +91,7 @@ namespace PluralKit.Bot if (newDescription.Length > Limits.MaxDescriptionLength) throw Errors.DescriptionTooLongError(newDescription.Length); var patch = new SystemPatch {Description = newDescription}; - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); await ctx.Reply($"{Emojis.Success} System description changed."); } @@ -106,7 +104,7 @@ namespace PluralKit.Bot if (ctx.MatchClear()) { var patch = new SystemPatch {Tag = null}; - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); await ctx.Reply($"{Emojis.Success} System tag cleared."); } else if (!ctx.HasNext(skipFlags: false)) @@ -124,7 +122,7 @@ namespace PluralKit.Bot throw Errors.SystemNameTooLongError(newTag.Length); var patch = new SystemPatch {Tag = newTag}; - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); await ctx.Reply($"{Emojis.Success} System tag changed. Member names will now end with {newTag.AsCode()} when proxied."); } @@ -136,7 +134,7 @@ namespace PluralKit.Bot async Task ClearIcon() { - await _db.Execute(c => c.UpdateSystem(ctx.System.Id, new SystemPatch {AvatarUrl = null})); + await _db.Execute(c => _repo.UpdateSystem(c, ctx.System.Id, new SystemPatch {AvatarUrl = null})); await ctx.Reply($"{Emojis.Success} System icon cleared."); } @@ -146,7 +144,7 @@ namespace PluralKit.Bot throw Errors.InvalidUrl(img.Url); await AvatarUtils.VerifyAvatarOrThrow(img.Url); - await _db.Execute(c => c.UpdateSystem(ctx.System.Id, new SystemPatch {AvatarUrl = img.Url})); + await _db.Execute(c => _repo.UpdateSystem(c, ctx.System.Id, new SystemPatch {AvatarUrl = img.Url})); var msg = img.Source switch { @@ -192,7 +190,7 @@ namespace PluralKit.Bot if (!await ctx.ConfirmWithReply(ctx.System.Hid)) throw new PKError($"System deletion cancelled. Note that you must reply with your system ID (`{ctx.System.Hid}`) *verbatim*."); - await _db.Execute(conn => conn.DeleteSystem(ctx.System.Id)); + await _db.Execute(conn => _repo.DeleteSystem(conn, ctx.System.Id)); await ctx.Reply($"{Emojis.Success} System deleted."); } @@ -200,7 +198,7 @@ namespace PluralKit.Bot public async Task SystemProxy(Context ctx) { ctx.CheckSystem().CheckGuildContext(); - var gs = await _db.Execute(c => c.QueryOrInsertSystemGuildConfig(ctx.Guild.Id, ctx.System.Id)); + var gs = await _db.Execute(c => _repo.GetSystemGuild(c, ctx.Guild.Id, ctx.System.Id)); bool newValue; if (ctx.Match("on", "enabled", "true", "yes")) newValue = true; @@ -216,7 +214,7 @@ namespace PluralKit.Bot } var patch = new SystemGuildPatch {ProxyEnabled = newValue}; - await _db.Execute(conn => conn.UpsertSystemGuild(ctx.System.Id, ctx.Guild.Id, patch)); + await _db.Execute(conn => _repo.UpsertSystemGuild(conn, ctx.System.Id, ctx.Guild.Id, patch)); if (newValue) await ctx.Reply($"Message proxying in this server ({ctx.Guild.Name.EscapeMarkdown()}) is now **enabled** for your system."); @@ -231,7 +229,7 @@ namespace PluralKit.Bot if (ctx.MatchClear()) { var clearPatch = new SystemPatch {UiTz = "UTC"}; - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, clearPatch)); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, clearPatch)); await ctx.Reply($"{Emojis.Success} System time zone cleared (set to UTC)."); return; @@ -253,7 +251,7 @@ namespace PluralKit.Bot if (!await ctx.PromptYesNo(msg)) throw Errors.TimezoneChangeCancelled; var patch = new SystemPatch {UiTz = zone.Id}; - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); await ctx.Reply($"System time zone changed to **{zone.Id}**."); } @@ -277,7 +275,7 @@ namespace PluralKit.Bot async Task SetLevel(SystemPrivacySubject subject, PrivacyLevel level) { - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, new SystemPatch().WithPrivacy(subject, level))); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, new SystemPatch().WithPrivacy(subject, level))); var levelExplanation = level switch { @@ -302,7 +300,7 @@ namespace PluralKit.Bot async Task SetAll(PrivacyLevel level) { - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, new SystemPatch().WithAllPrivacy(level))); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, new SystemPatch().WithAllPrivacy(level))); var msg = level switch { @@ -334,13 +332,13 @@ namespace PluralKit.Bot else { if (ctx.Match("on", "enable")) { var patch = new SystemPatch {PingsEnabled = true}; - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); await ctx.Reply("Reaction pings have now been enabled."); } if (ctx.Match("off", "disable")) { var patch = new SystemPatch {PingsEnabled = false}; - await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); + await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch)); await ctx.Reply("Reaction pings have now been disabled."); } diff --git a/PluralKit.Bot/Commands/SystemFront.cs b/PluralKit.Bot/Commands/SystemFront.cs index 18887e7d..9dfc15da 100644 --- a/PluralKit.Bot/Commands/SystemFront.cs +++ b/PluralKit.Bot/Commands/SystemFront.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -10,19 +11,21 @@ namespace PluralKit.Bot { public class SystemFront { - private IDataStore _data; - private EmbedService _embeds; + private readonly IDatabase _db; + private readonly ModelRepository _repo; + private readonly EmbedService _embeds; - public SystemFront(IDataStore data, EmbedService embeds) + public SystemFront(EmbedService embeds, IDatabase db, ModelRepository repo) { - _data = data; _embeds = embeds; + _db = db; + _repo = repo; } struct FrontHistoryEntry { - public Instant? LastTime; - public PKSwitch ThisSwitch; + public readonly Instant? LastTime; + public readonly PKSwitch ThisSwitch; public FrontHistoryEntry(Instant? lastTime, PKSwitch thisSwitch) { @@ -35,8 +38,10 @@ namespace PluralKit.Bot { if (system == null) throw Errors.NoSystemError; ctx.CheckSystemPrivacy(system, system.FrontPrivacy); + + await using var conn = await _db.Obtain(); - var sw = await _data.GetLatestSwitch(system.Id); + var sw = await _repo.GetLatestSwitch(conn, system.Id); if (sw == null) throw Errors.NoRegisteredSwitches; await ctx.Reply(embed: await _embeds.CreateFronterEmbed(sw, system.Zone, ctx.LookupContextFor(system))); @@ -47,11 +52,16 @@ namespace PluralKit.Bot if (system == null) throw Errors.NoSystemError; ctx.CheckSystemPrivacy(system, system.FrontHistoryPrivacy); - var sws = _data.GetSwitches(system.Id) - .Scan(new FrontHistoryEntry(null, null), (lastEntry, newSwitch) => new FrontHistoryEntry(lastEntry.ThisSwitch?.Timestamp, newSwitch)); - var totalSwitches = await _data.GetSwitchCount(system); - if (totalSwitches == 0) throw Errors.NoRegisteredSwitches; + // Gotta be careful here: if we dispose of the connection while the IAE is alive, boom + await using var conn = await _db.Obtain(); + var totalSwitches = await _repo.GetSwitchCount(conn, system.Id); + if (totalSwitches == 0) throw Errors.NoRegisteredSwitches; + + var sws = _repo.GetSwitches(conn, system.Id) + .Scan(new FrontHistoryEntry(null, null), + (lastEntry, newSwitch) => new FrontHistoryEntry(lastEntry.ThisSwitch?.Timestamp, newSwitch)); + var embedTitle = system.Name != null ? $"Front history of {system.Name} (`{system.Hid}`)" : $"Front history of `{system.Hid}`"; await ctx.Paginate( @@ -66,8 +76,11 @@ namespace PluralKit.Bot var lastSw = entry.LastTime; var sw = entry.ThisSwitch; + // Fetch member list and format - var members = await _data.GetSwitchMembers(sw).ToListAsync(); + await using var conn = await _db.Obtain(); + + var members = await _db.Execute(c => _repo.GetSwitchMembers(c, sw.Id)).ToListAsync(); var membersStr = members.Any() ? string.Join(", ", members.Select(m => m.NameFor(ctx))) : "no fronter"; var switchSince = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp; @@ -111,8 +124,8 @@ namespace PluralKit.Bot var rangeStart = DateUtils.ParseDateTime(durationStr, true, system.Zone); if (rangeStart == null) throw Errors.InvalidDateTime(durationStr); if (rangeStart.Value.ToInstant() > now) throw Errors.FrontPercentTimeInFuture; - - var frontpercent = await _data.GetFrontBreakdown(system, rangeStart.Value.ToInstant(), now); + + var frontpercent = await _db.Execute(c => _repo.GetFrontBreakdown(c, system.Id, rangeStart.Value.ToInstant(), now)); await ctx.Reply(embed: await _embeds.CreateFrontPercentEmbed(frontpercent, system.Zone, ctx.LookupContextFor(system))); } } diff --git a/PluralKit.Bot/Commands/SystemLink.cs b/PluralKit.Bot/Commands/SystemLink.cs index 3a3430b7..af3cd053 100644 --- a/PluralKit.Bot/Commands/SystemLink.cs +++ b/PluralKit.Bot/Commands/SystemLink.cs @@ -9,28 +9,34 @@ namespace PluralKit.Bot { public class SystemLink { - private IDataStore _data; + private readonly IDatabase _db; + private readonly ModelRepository _repo; - public SystemLink(IDataStore data) + public SystemLink(IDatabase db, ModelRepository repo) { - _data = data; + _db = db; + _repo = repo; } public async Task LinkSystem(Context ctx) { ctx.CheckSystem(); + + await using var conn = await _db.Obtain(); var account = await ctx.MatchUser() ?? throw new PKSyntaxError("You must pass an account to link with (either ID or @mention)."); - var accountIds = await _data.GetSystemAccounts(ctx.System); - if (accountIds.Contains(account.Id)) throw Errors.AccountAlreadyLinked; + var accountIds = await _repo.GetSystemAccounts(conn, ctx.System.Id); + if (accountIds.Contains(account.Id)) + throw Errors.AccountAlreadyLinked; - var existingAccount = await _data.GetSystemByAccount(account.Id); - if (existingAccount != null) throw Errors.AccountInOtherSystem(existingAccount); + var existingAccount = await _repo.GetSystemByAccount(conn, account.Id); + if (existingAccount != null) + throw Errors.AccountInOtherSystem(existingAccount); var msg = $"{account.Mention}, please confirm the link by clicking the {Emojis.Success} reaction on this message."; var mentions = new IMention[] { new UserMention(account) }; if (!await ctx.PromptYesNo(msg, user: account, mentions: mentions)) throw Errors.MemberLinkCancelled; - await _data.AddAccount(ctx.System, account.Id); + await _repo.AddAccount(conn, ctx.System.Id, account.Id); await ctx.Reply($"{Emojis.Success} Account linked to system."); } @@ -38,20 +44,22 @@ namespace PluralKit.Bot { ctx.CheckSystem(); + await using var conn = await _db.Obtain(); + ulong id; if (!ctx.HasNext()) id = ctx.Author.Id; else if (!ctx.MatchUserRaw(out id)) throw new PKSyntaxError("You must pass an account to link with (either ID or @mention)."); - var accountIds = (await _data.GetSystemAccounts(ctx.System)).ToList(); + var accountIds = (await _repo.GetSystemAccounts(conn, ctx.System.Id)).ToList(); if (!accountIds.Contains(id)) throw Errors.AccountNotLinked; if (accountIds.Count == 1) throw Errors.UnlinkingLastAccount; var msg = $"Are you sure you want to unlink <@{id}> from your system?"; if (!await ctx.PromptYesNo(msg)) throw Errors.MemberUnlinkCancelled; - await _data.RemoveAccount(ctx.System, id); + await _repo.RemoveAccount(conn, ctx.System.Id, id); await ctx.Reply($"{Emojis.Success} Account unlinked."); } } diff --git a/PluralKit.Bot/Commands/Token.cs b/PluralKit.Bot/Commands/Token.cs index 2cbb8142..e6a26cb9 100644 --- a/PluralKit.Bot/Commands/Token.cs +++ b/PluralKit.Bot/Commands/Token.cs @@ -10,9 +10,11 @@ namespace PluralKit.Bot public class Token { private readonly IDatabase _db; - public Token(IDatabase db) + private readonly ModelRepository _repo; + public Token(IDatabase db, ModelRepository repo) { _db = db; + _repo = repo; } public async Task GetToken(Context ctx) @@ -45,7 +47,7 @@ namespace PluralKit.Bot private async Task MakeAndSetNewToken(PKSystem system) { var patch = new SystemPatch {Token = StringUtils.GenerateToken()}; - system = await _db.Execute(conn => conn.UpdateSystem(system.Id, patch)); + system = await _db.Execute(conn => _repo.UpdateSystem(conn, system.Id, patch)); return system.Token; } diff --git a/PluralKit.Bot/Handlers/MessageCreated.cs b/PluralKit.Bot/Handlers/MessageCreated.cs index 64ece7fe..9ed335da 100644 --- a/PluralKit.Bot/Handlers/MessageCreated.cs +++ b/PluralKit.Bot/Handlers/MessageCreated.cs @@ -23,11 +23,12 @@ namespace PluralKit.Bot private readonly ProxyService _proxy; private readonly ILifetimeScope _services; private readonly IDatabase _db; + private readonly ModelRepository _repo; private readonly BotConfig _config; public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean, IMetrics metrics, ProxyService proxy, DiscordShardedClient client, - CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config) + CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config, ModelRepository repo) { _lastMessageCache = lastMessageCache; _loggerClean = loggerClean; @@ -38,6 +39,7 @@ namespace PluralKit.Bot _services = services; _db = db; _config = config; + _repo = repo; } public DiscordChannel ErrorChannelFor(MessageCreateEventArgs evt) => evt.Channel; @@ -59,7 +61,7 @@ namespace PluralKit.Bot MessageContext ctx; await using (var conn = await _db.Obtain()) using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) - ctx = await conn.QueryMessageContext(evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id); + ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id); // Try each handler until we find one that succeeds if (await TryHandleLogClean(evt, ctx)) @@ -98,7 +100,7 @@ namespace PluralKit.Bot try { - var system = ctx.SystemId != null ? await _db.Execute(c => c.QuerySystem(ctx.SystemId.Value)) : null; + var system = ctx.SystemId != null ? await _db.Execute(c => _repo.GetSystem(c, ctx.SystemId.Value)) : null; await _tree.ExecuteCommand(new Context(_services, evt.Client, evt.Message, cmdStart, system, ctx)); } catch (PKError) diff --git a/PluralKit.Bot/Handlers/MessageDeleted.cs b/PluralKit.Bot/Handlers/MessageDeleted.cs index 6a8bbc38..6869fa0b 100644 --- a/PluralKit.Bot/Handlers/MessageDeleted.cs +++ b/PluralKit.Bot/Handlers/MessageDeleted.cs @@ -12,28 +12,29 @@ namespace PluralKit.Bot // Double duty :) public class MessageDeleted: IEventHandler, IEventHandler { - private readonly IDataStore _data; + private readonly IDatabase _db; + private readonly ModelRepository _repo; private readonly ILogger _logger; - public MessageDeleted(IDataStore data, ILogger logger) + public MessageDeleted(ILogger logger, IDatabase db, ModelRepository repo) { - _data = data; + _db = db; + _repo = repo; _logger = logger.ForContext(); } public async Task Handle(MessageDeleteEventArgs evt) { // Delete deleted webhook messages from the data store - // (if we don't know whether it's a webhook, delete it just to be safe) - if (!evt.Message.WebhookMessage) return; - await _data.DeleteMessage(evt.Message.Id); + // Most of the data in the given message is wrong/missing, so always delete just to be sure. + await _db.Execute(c => _repo.DeleteMessage(c, evt.Message.Id)); } public async Task Handle(MessageBulkDeleteEventArgs evt) { // Same as above, but bulk _logger.Information("Bulk deleting {Count} messages in channel {Channel}", evt.Messages.Count, evt.Channel.Id); - await _data.DeleteMessagesBulk(evt.Messages.Select(m => m.Id).ToList()); + await _db.Execute(c => _repo.DeleteMessagesBulk(c, evt.Messages.Select(m => m.Id).ToList())); } } } \ No newline at end of file diff --git a/PluralKit.Bot/Handlers/MessageEdited.cs b/PluralKit.Bot/Handlers/MessageEdited.cs index 6ed1615b..935190fc 100644 --- a/PluralKit.Bot/Handlers/MessageEdited.cs +++ b/PluralKit.Bot/Handlers/MessageEdited.cs @@ -14,14 +14,16 @@ namespace PluralKit.Bot private readonly LastMessageCacheService _lastMessageCache; private readonly ProxyService _proxy; private readonly IDatabase _db; + private readonly ModelRepository _repo; private readonly IMetrics _metrics; - public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, IDatabase db, IMetrics metrics) + public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, IDatabase db, IMetrics metrics, ModelRepository repo) { _lastMessageCache = lastMessageCache; _proxy = proxy; _db = db; _metrics = metrics; + _repo = repo; } public async Task Handle(MessageUpdateEventArgs evt) @@ -36,7 +38,7 @@ namespace PluralKit.Bot MessageContext ctx; await using (var conn = await _db.Obtain()) using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) - ctx = await conn.QueryMessageContext(evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id); + ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id); await _proxy.HandleIncomingMessage(evt.Message, ctx, allowAutoproxy: false); } } diff --git a/PluralKit.Bot/Handlers/ReactionAdded.cs b/PluralKit.Bot/Handlers/ReactionAdded.cs index ce98d0e8..5fac7357 100644 --- a/PluralKit.Bot/Handlers/ReactionAdded.cs +++ b/PluralKit.Bot/Handlers/ReactionAdded.cs @@ -13,14 +13,16 @@ namespace PluralKit.Bot { public class ReactionAdded: IEventHandler { - private IDataStore _data; - private EmbedService _embeds; - private ILogger _logger; + private readonly IDatabase _db; + private readonly ModelRepository _repo; + private readonly EmbedService _embeds; + private readonly ILogger _logger; - public ReactionAdded(IDataStore data, EmbedService embeds, ILogger logger) + public ReactionAdded(EmbedService embeds, ILogger logger, IDatabase db, ModelRepository repo) { - _data = data; _embeds = embeds; + _db = db; + _repo = repo; _logger = logger.ForContext(); } @@ -42,18 +44,21 @@ namespace PluralKit.Bot // Ignore reactions from bots (we can't DM them anyway) if (evt.User.IsBot) return; + Task GetMessage() => + _db.Execute(c => _repo.GetMessage(c, evt.Message.Id)); + FullMessage msg; switch (evt.Emoji.Name) { // Message deletion case "\u274C": // Red X - if ((msg = await _data.GetMessage(evt.Message.Id)) != null) + if ((msg = await GetMessage()) != null) await HandleDeleteReaction(evt, msg); break; case "\u2753": // Red question mark case "\u2754": // White question mark - if ((msg = await _data.GetMessage(evt.Message.Id)) != null) + if ((msg = await GetMessage()) != null) await HandleQueryReaction(evt, msg); break; @@ -62,7 +67,7 @@ namespace PluralKit.Bot case "\U0001F3D3": // Ping pong paddle (lol) case "\u23F0": // Alarm clock case "\u2757": // Exclamation mark - if ((msg = await _data.GetMessage(evt.Message.Id)) != null) + if ((msg = await GetMessage()) != null) await HandlePingReaction(evt, msg); break; } @@ -84,7 +89,7 @@ namespace PluralKit.Bot // Message was deleted by something/someone else before we got to it } - await _data.DeleteMessage(evt.Message.Id); + await _db.Execute(c => _repo.DeleteMessage(c, evt.Message.Id)); } private async ValueTask HandleQueryReaction(MessageReactionAddEventArgs evt, FullMessage msg) diff --git a/PluralKit.Bot/Proxy/ProxyService.cs b/PluralKit.Bot/Proxy/ProxyService.cs index b616914c..6a9694f8 100644 --- a/PluralKit.Bot/Proxy/ProxyService.cs +++ b/PluralKit.Bot/Proxy/ProxyService.cs @@ -21,21 +21,21 @@ namespace PluralKit.Bot private readonly LogChannelService _logChannel; private readonly IDatabase _db; - private readonly IDataStore _data; + private readonly ModelRepository _repo; private readonly ILogger _logger; private readonly WebhookExecutorService _webhookExecutor; private readonly ProxyMatcher _matcher; private readonly IMetrics _metrics; - public ProxyService(LogChannelService logChannel, IDataStore data, ILogger logger, - WebhookExecutorService webhookExecutor, IDatabase db, ProxyMatcher matcher, IMetrics metrics) + public ProxyService(LogChannelService logChannel, ILogger logger, + WebhookExecutorService webhookExecutor, IDatabase db, ProxyMatcher matcher, IMetrics metrics, ModelRepository repo) { _logChannel = logChannel; - _data = data; _webhookExecutor = webhookExecutor; _db = db; _matcher = matcher; _metrics = metrics; + _repo = repo; _logger = logger.ForContext(); } @@ -48,7 +48,7 @@ namespace PluralKit.Bot List members; using (_metrics.Measure.Timer.Time(BotMetrics.ProxyMembersQueryTime)) - members = (await conn.QueryProxyMembers(message.Author.Id, message.Channel.GuildId)).ToList(); + members = (await _repo.GetProxyMembers(conn, message.Author.Id, message.Channel.GuildId)).ToList(); if (!_matcher.TryMatch(ctx, members, out var match, message.Content, message.Attachments.Count > 0, allowAutoproxy)) return false; @@ -99,8 +99,17 @@ namespace PluralKit.Bot var id = await _webhookExecutor.ExecuteWebhook(trigger.Channel, match.Member.ProxyName(ctx), match.Member.ProxyAvatar(ctx), content, trigger.Attachments, allowEveryone); + + Task SaveMessage() => _repo.AddMessage(conn, new PKMessage + { + Channel = trigger.ChannelId, + Guild = trigger.Channel.GuildId, + Member = match.Member.Id, + Mid = id, + OriginalMid = trigger.Id, + Sender = trigger.Author.Id + }); - Task SaveMessage() => _data.AddMessage(conn, trigger.Author.Id, trigger.Channel.GuildId, trigger.ChannelId, id, trigger.Id, match.Member.Id); Task LogMessage() => _logChannel.LogMessage(ctx, match, trigger, id).AsTask(); async Task DeleteMessage() { diff --git a/PluralKit.Bot/Services/CpuStatService.cs b/PluralKit.Bot/Services/CpuStatService.cs index 00aebdb9..de161be7 100644 --- a/PluralKit.Bot/Services/CpuStatService.cs +++ b/PluralKit.Bot/Services/CpuStatService.cs @@ -7,7 +7,7 @@ namespace PluralKit.Bot { public class CpuStatService { - private ILogger _logger; + private readonly ILogger _logger; public double LastCpuMeasure { get; private set; } diff --git a/PluralKit.Bot/Services/EmbedService.cs b/PluralKit.Bot/Services/EmbedService.cs index e27e892c..da37da5b 100644 --- a/PluralKit.Bot/Services/EmbedService.cs +++ b/PluralKit.Bot/Services/EmbedService.cs @@ -15,40 +15,38 @@ using PluralKit.Core; namespace PluralKit.Bot { public class EmbedService { - private IDataStore _data; - private IDatabase _db; - private DiscordShardedClient _client; + private readonly IDatabase _db; + private readonly ModelRepository _repo; + private readonly DiscordShardedClient _client; - public EmbedService(DiscordShardedClient client, IDataStore data, IDatabase db) + public EmbedService(DiscordShardedClient client, IDatabase db, ModelRepository repo) { _client = client; - _data = data; _db = db; + _repo = repo; } - - public async Task CreateSystemEmbed(DiscordClient client, PKSystem system, LookupContext ctx) { await using var conn = await _db.Obtain(); // Fetch/render info for all accounts simultaneously - var accounts = await conn.GetLinkedAccounts(system.Id); + var accounts = await _repo.GetSystemAccounts(conn, system.Id); var users = await Task.WhenAll(accounts.Select(async uid => (await client.GetUser(uid))?.NameAndMention() ?? $"(deleted account {uid})")); - var memberCount = await conn.GetSystemMemberCount(system.Id, PrivacyLevel.Public); + var memberCount = await _repo.GetSystemMemberCount(conn, system.Id, PrivacyLevel.Public); var eb = new DiscordEmbedBuilder() .WithColor(DiscordUtils.Gray) .WithTitle(system.Name ?? null) .WithThumbnail(system.AvatarUrl) .WithFooter($"System ID: {system.Hid} | Created on {system.Created.FormatZoned(system)}"); - var latestSwitch = await _data.GetLatestSwitch(system.Id); + var latestSwitch = await _repo.GetLatestSwitch(conn, system.Id); if (latestSwitch != null && system.FrontPrivacy.CanAccess(ctx)) { - var switchMembers = await _data.GetSwitchMembers(latestSwitch).ToListAsync(); - if (switchMembers.Count > 0) - eb.AddField("Fronter".ToQuantity(switchMembers.Count(), ShowQuantityAs.None), + var switchMembers = await _repo.GetSwitchMembers(conn, latestSwitch.Id).ToListAsync(); + if (switchMembers.Count > 0) + eb.AddField("Fronter".ToQuantity(switchMembers.Count(), ShowQuantityAs.None), string.Join(", ", switchMembers.Select(m => m.NameFor(ctx)))); } @@ -105,11 +103,13 @@ namespace PluralKit.Bot { await using var conn = await _db.Obtain(); - var guildSettings = guild != null ? await conn.QueryOrInsertMemberGuildConfig(guild.Id, member.Id) : null; + var guildSettings = guild != null ? await _repo.GetMemberGuild(conn, guild.Id, member.Id) : null; var guildDisplayName = guildSettings?.DisplayName; var avatar = guildSettings?.AvatarUrl ?? member.AvatarFor(ctx); - var groups = (await conn.QueryMemberGroups(member.Id)).Where(g => g.Visibility.CanAccess(ctx)).ToList(); + var groups = await _repo.GetMemberGroups(conn, member.Id) + .Where(g => g.Visibility.CanAccess(ctx)) + .ToListAsync(); var eb = new DiscordEmbedBuilder() // TODO: add URL of website when that's up @@ -157,7 +157,7 @@ namespace PluralKit.Bot { public async Task CreateFronterEmbed(PKSwitch sw, DateTimeZone zone, LookupContext ctx) { - var members = await _data.GetSwitchMembers(sw).ToListAsync(); + var members = await _db.Execute(c => _repo.GetSwitchMembers(c, sw.Id).ToListAsync().AsTask()); var timeSinceSwitch = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp; return new DiscordEmbedBuilder() .WithColor(members.FirstOrDefault()?.Color?.ToDiscordColor() ?? DiscordUtils.Gray) diff --git a/PluralKit.Bot/Services/LastMessageCacheService.cs b/PluralKit.Bot/Services/LastMessageCacheService.cs index 2ae108e0..6b88e18a 100644 --- a/PluralKit.Bot/Services/LastMessageCacheService.cs +++ b/PluralKit.Bot/Services/LastMessageCacheService.cs @@ -10,7 +10,7 @@ namespace PluralKit.Bot // TODO: is this still needed after the D#+ migration? public class LastMessageCacheService { - private IDictionary _cache = new ConcurrentDictionary(); + private readonly IDictionary _cache = new ConcurrentDictionary(); public void AddMessage(ulong channel, ulong message) { diff --git a/PluralKit.Bot/Services/LogChannelService.cs b/PluralKit.Bot/Services/LogChannelService.cs index e840e89b..5e563c41 100644 --- a/PluralKit.Bot/Services/LogChannelService.cs +++ b/PluralKit.Bot/Services/LogChannelService.cs @@ -15,16 +15,16 @@ namespace PluralKit.Bot { public class LogChannelService { private readonly EmbedService _embed; private readonly IDatabase _db; - private readonly IDataStore _data; + private readonly ModelRepository _repo; private readonly ILogger _logger; private readonly DiscordRestClient _rest; - public LogChannelService(EmbedService embed, ILogger logger, DiscordRestClient rest, IDatabase db, IDataStore data) + public LogChannelService(EmbedService embed, ILogger logger, DiscordRestClient rest, IDatabase db, ModelRepository repo) { _embed = embed; _rest = rest; _db = db; - _data = data; + _repo = repo; _logger = logger.ForContext(); } @@ -47,8 +47,8 @@ namespace PluralKit.Bot { // Send embed! await using var conn = await _db.Obtain(); - var embed = _embed.CreateLoggedMessageEmbed(await conn.QuerySystem(ctx.SystemId.Value), - await conn.QueryMember(proxy.Member.Id), hookMessage, trigger.Id, trigger.Author, proxy.Content, + var embed = _embed.CreateLoggedMessageEmbed(await _repo.GetSystem(conn, ctx.SystemId.Value), + await _repo.GetMember(conn, proxy.Member.Id), hookMessage, trigger.Id, trigger.Author, proxy.Content, trigger.Channel); var url = $"https://discord.com/channels/{trigger.Channel.GuildId}/{trigger.ChannelId}/{hookMessage}"; await logChannel.SendMessageFixedAsync(content: url, embed: embed); diff --git a/PluralKit.Bot/Services/LoggerCleanService.cs b/PluralKit.Bot/Services/LoggerCleanService.cs index 7ec2a564..5a45ebbd 100644 --- a/PluralKit.Bot/Services/LoggerCleanService.cs +++ b/PluralKit.Bot/Services/LoggerCleanService.cs @@ -16,19 +16,19 @@ namespace PluralKit.Bot { public class LoggerCleanService { - private static Regex _basicRegex = new Regex("(\\d{17,19})"); - private static Regex _dynoRegex = new Regex("Message ID: (\\d{17,19})"); - private static Regex _carlRegex = new Regex("ID: (\\d{17,19})"); - private static Regex _circleRegex = new Regex("\\(`(\\d{17,19})`\\)"); - private static Regex _loggerARegex = new Regex("Message = (\\d{17,19})"); - private static Regex _loggerBRegex = new Regex("MessageID:(\\d{17,19})"); - private static Regex _auttajaRegex = new Regex("Message (\\d{17,19}) deleted"); - private static Regex _mantaroRegex = new Regex("Message \\(?ID:? (\\d{17,19})\\)? created by .* in channel .* was deleted\\."); - private static Regex _pancakeRegex = new Regex("Message from <@(\\d{17,19})> deleted in"); - private static Regex _unbelievaboatRegex = new Regex("Message ID: (\\d{17,19})"); - private static Regex _vanessaRegex = new Regex("Message sent by <@!?(\\d{17,19})> deleted in"); - private static Regex _salRegex = new Regex("\\(ID: (\\d{17,19})\\)"); - private static Regex _GearBotRegex = new Regex("\\(``(\\d{17,19})``\\) in <#\\d{17,19}> has been removed."); + private static readonly Regex _basicRegex = new Regex("(\\d{17,19})"); + private static readonly Regex _dynoRegex = new Regex("Message ID: (\\d{17,19})"); + private static readonly Regex _carlRegex = new Regex("ID: (\\d{17,19})"); + private static readonly Regex _circleRegex = new Regex("\\(`(\\d{17,19})`\\)"); + private static readonly Regex _loggerARegex = new Regex("Message = (\\d{17,19})"); + private static readonly Regex _loggerBRegex = new Regex("MessageID:(\\d{17,19})"); + private static readonly Regex _auttajaRegex = new Regex("Message (\\d{17,19}) deleted"); + private static readonly Regex _mantaroRegex = new Regex("Message \\(?ID:? (\\d{17,19})\\)? created by .* in channel .* was deleted\\."); + private static readonly Regex _pancakeRegex = new Regex("Message from <@(\\d{17,19})> deleted in"); + private static readonly Regex _unbelievaboatRegex = new Regex("Message ID: (\\d{17,19})"); + private static readonly Regex _vanessaRegex = new Regex("Message sent by <@!?(\\d{17,19})> deleted in"); + private static readonly Regex _salRegex = new Regex("\\(ID: (\\d{17,19})\\)"); + private static readonly Regex _GearBotRegex = new Regex("\\(``(\\d{17,19})``\\) in <#\\d{17,19}> has been removed."); private static readonly Dictionary _bots = new[] { @@ -55,7 +55,7 @@ namespace PluralKit.Bot .Where(b => b.WebhookName != null) .ToDictionary(b => b.WebhookName); - private IDatabase _db; + private readonly IDatabase _db; private DiscordShardedClient _client; public LoggerCleanService(IDatabase db, DiscordShardedClient client) diff --git a/PluralKit.Bot/Services/PeriodicStatCollector.cs b/PluralKit.Bot/Services/PeriodicStatCollector.cs index 9794213f..d1b77742 100644 --- a/PluralKit.Bot/Services/PeriodicStatCollector.cs +++ b/PluralKit.Bot/Services/PeriodicStatCollector.cs @@ -17,17 +17,17 @@ namespace PluralKit.Bot { public class PeriodicStatCollector { - private DiscordShardedClient _client; - private IMetrics _metrics; - private CpuStatService _cpu; + private readonly DiscordShardedClient _client; + private readonly IMetrics _metrics; + private readonly CpuStatService _cpu; - private IDatabase _db; + private readonly IDatabase _db; - private WebhookCacheService _webhookCache; + private readonly WebhookCacheService _webhookCache; - private DbConnectionCountHolder _countHolder; + private readonly DbConnectionCountHolder _countHolder; - private ILogger _logger; + private readonly ILogger _logger; public PeriodicStatCollector(DiscordShardedClient client, IMetrics metrics, ILogger logger, WebhookCacheService webhookCache, DbConnectionCountHolder countHolder, CpuStatService cpu, IDatabase db) { diff --git a/PluralKit.Bot/Services/ShardInfoService.cs b/PluralKit.Bot/Services/ShardInfoService.cs index cbfd3c39..6214f66b 100644 --- a/PluralKit.Bot/Services/ShardInfoService.cs +++ b/PluralKit.Bot/Services/ShardInfoService.cs @@ -16,7 +16,6 @@ namespace PluralKit.Bot { public class ShardInfoService { - public class ShardInfo { public bool HasAttachedListeners; @@ -27,10 +26,10 @@ namespace PluralKit.Bot public bool Connected; } - private IMetrics _metrics; - private ILogger _logger; - private DiscordShardedClient _client; - private Dictionary _shardInfo = new Dictionary(); + private readonly IMetrics _metrics; + private readonly ILogger _logger; + private readonly DiscordShardedClient _client; + private readonly Dictionary _shardInfo = new Dictionary(); public ShardInfoService(ILogger logger, DiscordShardedClient client, IMetrics metrics) { diff --git a/PluralKit.Bot/Services/WebhookCacheService.cs b/PluralKit.Bot/Services/WebhookCacheService.cs index 5e403694..d2f10e09 100644 --- a/PluralKit.Bot/Services/WebhookCacheService.cs +++ b/PluralKit.Bot/Services/WebhookCacheService.cs @@ -18,11 +18,11 @@ namespace PluralKit.Bot { public static readonly string WebhookName = "PluralKit Proxy Webhook"; - private DiscordShardedClient _client; - private ConcurrentDictionary>> _webhooks; + private readonly DiscordShardedClient _client; + private readonly ConcurrentDictionary>> _webhooks; - private IMetrics _metrics; - private ILogger _logger; + private readonly IMetrics _metrics; + private readonly ILogger _logger; public WebhookCacheService(DiscordShardedClient client, ILogger logger, IMetrics metrics) { diff --git a/PluralKit.Bot/Services/WebhookExecutorService.cs b/PluralKit.Bot/Services/WebhookExecutorService.cs index 3f48ba1d..ca6fac34 100644 --- a/PluralKit.Bot/Services/WebhookExecutorService.cs +++ b/PluralKit.Bot/Services/WebhookExecutorService.cs @@ -29,10 +29,10 @@ namespace PluralKit.Bot public class WebhookExecutorService { - private WebhookCacheService _webhookCache; - private ILogger _logger; - private IMetrics _metrics; - private HttpClient _client; + private readonly WebhookCacheService _webhookCache; + private readonly ILogger _logger; + private readonly IMetrics _metrics; + private readonly HttpClient _client; public WebhookExecutorService(IMetrics metrics, WebhookCacheService webhookCache, ILogger logger, HttpClient client) { diff --git a/PluralKit.Core/Database/Database.cs b/PluralKit.Core/Database/Database.cs index cea0a72e..e43c9966 100644 --- a/PluralKit.Core/Database/Database.cs +++ b/PluralKit.Core/Database/Database.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Data; using System.IO; -using System.Linq; using System.Threading.Tasks; using App.Metrics; @@ -207,11 +206,19 @@ namespace PluralKit.Core await using var conn = await db.Obtain(); await func(conn); } - + public static async Task Execute(this IDatabase db, Func> func) { await using var conn = await db.Obtain(); return await func(conn); } + + public static async IAsyncEnumerable Execute(this IDatabase db, Func> func) + { + await using var conn = await db.Obtain(); + + await foreach (var val in func(conn)) + yield return val; + } } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Functions/DatabaseFunctionsExt.cs b/PluralKit.Core/Database/Repository/ModelRepository.Context.cs similarity index 66% rename from PluralKit.Core/Database/Functions/DatabaseFunctionsExt.cs rename to PluralKit.Core/Database/Repository/ModelRepository.Context.cs index 9800fd88..0a4d330a 100644 --- a/PluralKit.Core/Database/Functions/DatabaseFunctionsExt.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Context.cs @@ -6,16 +6,16 @@ using Dapper; namespace PluralKit.Core { - public static class DatabaseFunctionsExt + public partial class ModelRepository { - public static Task QueryMessageContext(this IPKConnection conn, ulong account, ulong guild, ulong channel) + public Task GetMessageContext(IPKConnection conn, ulong account, ulong guild, ulong channel) { return conn.QueryFirstAsync("message_context", new { account_id = account, guild_id = guild, channel_id = channel }, commandType: CommandType.StoredProcedure); } - public static Task> QueryProxyMembers(this IPKConnection conn, ulong account, ulong guild) + public Task> GetProxyMembers(IPKConnection conn, ulong account, ulong guild) { return conn.QueryAsync("proxy_members", new { account_id = account, guild_id = guild }, diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Group.cs b/PluralKit.Core/Database/Repository/ModelRepository.Group.cs new file mode 100644 index 00000000..9cdb7385 --- /dev/null +++ b/PluralKit.Core/Database/Repository/ModelRepository.Group.cs @@ -0,0 +1,83 @@ +#nullable enable +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +using Dapper; + +namespace PluralKit.Core +{ + public partial class ModelRepository + { + public Task GetGroupByName(IPKConnection conn, SystemId system, string name) => + conn.QueryFirstOrDefaultAsync("select * from groups where system = @System and lower(Name) = lower(@Name)", new {System = system, Name = name}); + + public Task GetGroupByHid(IPKConnection conn, string hid) => + conn.QueryFirstOrDefaultAsync("select * from groups where hid = @hid", new {hid = hid.ToLowerInvariant()}); + + public Task GetGroupMemberCount(IPKConnection conn, GroupId id, PrivacyLevel? privacyFilter = null) + { + var query = new StringBuilder("select count(*) from group_members"); + if (privacyFilter != null) + query.Append(" inner join members on group_members.member_id = members.id"); + query.Append(" where group_members.group_id = @Id"); + if (privacyFilter != null) + query.Append(" and members.member_visibility = @PrivacyFilter"); + return conn.QuerySingleOrDefaultAsync(query.ToString(), new {Id = id, PrivacyFilter = privacyFilter}); + } + + public IAsyncEnumerable GetMemberGroups(IPKConnection conn, MemberId id) => + conn.QueryStreamAsync( + "select groups.* from group_members inner join groups on group_members.group_id = groups.id where group_members.member_id = @Id", + new {Id = id}); + + public async Task CreateGroup(IPKConnection conn, SystemId system, string name) + { + var group = await conn.QueryFirstAsync( + "insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *", + new {System = system, Name = name}); + _logger.Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name); + return group; + } + + public Task UpdateGroup(IPKConnection conn, GroupId id, GroupPatch patch) + { + _logger.Information("Updated {GroupId}: {@GroupPatch}", id, patch); + var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id")) + .WithConstant("id", id) + .Build("returning *"); + return conn.QueryFirstAsync(query, pms); + } + + public Task DeleteGroup(IPKConnection conn, GroupId group) + { + _logger.Information("Deleted {GroupId}", group); + return conn.ExecuteAsync("delete from groups where id = @Id", new {Id = @group}); + } + + public async Task AddMembersToGroup(IPKConnection conn, GroupId group, + IReadOnlyCollection members) + { + await using var w = + conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)"); + foreach (var member in members) + { + await w.StartRowAsync(); + await w.WriteAsync(group.Value); + await w.WriteAsync(member.Value); + } + + await w.CompleteAsync(); + _logger.Information("Added members to {GroupId}: {MemberIds}", group, members); + } + + public Task RemoveMembersFromGroup(IPKConnection conn, GroupId group, + IReadOnlyCollection members) + { + _logger.Information("Removed members from {GroupId}: {MemberIds}", group, members); + return conn.ExecuteAsync("delete from group_members where group_id = @Group and member_id = any(@Members)", + new {Group = @group, Members = members.ToArray()}); + } + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Guild.cs b/PluralKit.Core/Database/Repository/ModelRepository.Guild.cs new file mode 100644 index 00000000..8be75949 --- /dev/null +++ b/PluralKit.Core/Database/Repository/ModelRepository.Guild.cs @@ -0,0 +1,53 @@ +using System.Threading.Tasks; + +using Dapper; + +namespace PluralKit.Core +{ + public partial class ModelRepository + { + public Task UpsertGuild(IPKConnection conn, ulong guild, GuildPatch patch) + { + _logger.Information("Updated guild {GuildId}: {@GuildPatch}", guild, patch); + var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("servers", "id")) + .WithConstant("id", guild) + .Build(); + return conn.ExecuteAsync(query, pms); + } + + public Task UpsertSystemGuild(IPKConnection conn, SystemId system, ulong guild, + SystemGuildPatch patch) + { + _logger.Information("Updated {SystemId} in guild {GuildId}: {@SystemGuildPatch}", system, guild, patch); + var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("system_guild", "system, guild")) + .WithConstant("system", system) + .WithConstant("guild", guild) + .Build(); + return conn.ExecuteAsync(query, pms); + } + + public Task UpsertMemberGuild(IPKConnection conn, MemberId member, ulong guild, + MemberGuildPatch patch) + { + _logger.Information("Updated {MemberId} in guild {GuildId}: {@MemberGuildPatch}", member, guild, patch); + var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("member_guild", "member, guild")) + .WithConstant("member", member) + .WithConstant("guild", guild) + .Build(); + return conn.ExecuteAsync(query, pms); + } + + public Task GetGuild(IPKConnection conn, ulong guild) => + conn.QueryFirstAsync("insert into servers (id) values (@guild) on conflict (id) do update set id = @guild returning *", new {guild}); + + public Task GetSystemGuild(IPKConnection conn, ulong guild, SystemId system) => + conn.QueryFirstAsync( + "insert into system_guild (guild, system) values (@guild, @system) on conflict (guild, system) do update set guild = @guild, system = @system returning *", + new {guild, system}); + + public Task GetMemberGuild(IPKConnection conn, ulong guild, MemberId member) => + conn.QueryFirstAsync( + "insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *", + new {guild, member}); + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Member.cs b/PluralKit.Core/Database/Repository/ModelRepository.Member.cs new file mode 100644 index 00000000..e2e25888 --- /dev/null +++ b/PluralKit.Core/Database/Repository/ModelRepository.Member.cs @@ -0,0 +1,47 @@ +#nullable enable +using System.Threading.Tasks; + +using Dapper; + +namespace PluralKit.Core +{ + public partial class ModelRepository + { + public Task GetMember(IPKConnection conn, MemberId id) => + conn.QueryFirstOrDefaultAsync("select * from members where id = @id", new {id}); + + public Task GetMemberByHid(IPKConnection conn, string hid) => + conn.QuerySingleOrDefaultAsync("select * from members where hid = @Hid", new { Hid = hid.ToLower() }); + + public Task GetMemberByName(IPKConnection conn, SystemId system, string name) => + conn.QueryFirstOrDefaultAsync("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system }); + + public Task GetMemberByDisplayName(IPKConnection conn, SystemId system, string name) => + conn.QueryFirstOrDefaultAsync("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system }); + + public async Task CreateMember(IPKConnection conn, SystemId id, string memberName) + { + var member = await conn.QueryFirstAsync( + "insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *", + new {SystemId = id, Name = memberName}); + _logger.Information("Created {MemberId} in {SystemId}: {MemberName}", + member.Id, id, memberName); + return member; + } + + public Task UpdateMember(IPKConnection conn, MemberId id, MemberPatch patch) + { + _logger.Information("Updated {MemberId}: {@MemberPatch}", id, patch); + var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id")) + .WithConstant("id", id) + .Build("returning *"); + return conn.QueryFirstAsync(query, pms); + } + + public Task DeleteMember(IPKConnection conn, MemberId id) + { + _logger.Information("Deleted {MemberId}", id); + return conn.ExecuteAsync("delete from members where id = @Id", new {Id = id}); + } + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Message.cs b/PluralKit.Core/Database/Repository/ModelRepository.Message.cs new file mode 100644 index 00000000..9acf8be2 --- /dev/null +++ b/PluralKit.Core/Database/Repository/ModelRepository.Message.cs @@ -0,0 +1,63 @@ +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +using Dapper; + +namespace PluralKit.Core +{ + public partial class ModelRepository + { + public async Task AddMessage(IPKConnection conn, PKMessage msg) { + // "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before + await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@Mid, @Guild, @Channel, @Member, @Sender, @OriginalMid) on conflict do nothing", msg); + _logger.Debug("Stored message {@StoredMessage} in channel {Channel}", msg, msg.Channel); + } + + public async Task GetMessage(IPKConnection conn, ulong id) + { + FullMessage Mapper(PKMessage msg, PKMember member, PKSystem system) => + new FullMessage {Message = msg, System = system, Member = member}; + + var result = await conn.QueryAsync( + "select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system", + Mapper, new {Id = id}); + return result.FirstOrDefault(); + } + + public async Task DeleteMessage(IPKConnection conn, ulong id) + { + var rowCount = await conn.ExecuteAsync("delete from messages where mid = @Id", new {Id = id}); + if (rowCount > 0) + _logger.Information("Deleted message {MessageId} from database", id); + } + + public async Task DeleteMessagesBulk(IPKConnection conn, IReadOnlyCollection ids) + { + // Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong + // Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway) + var rowCount = await conn.ExecuteAsync("delete from messages where mid = any(@Ids)", + new {Ids = ids.Select(id => (long) id).ToArray()}); + if (rowCount > 0) + _logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount, + ids); + } + } + + public class PKMessage + { + public ulong Mid { get; set; } + public ulong? Guild { get; set; } // null value means "no data" (ie. from before this field being added) + public ulong Channel { get; set; } + public MemberId Member { get; set; } + public ulong Sender { get; set; } + public ulong? OriginalMid { get; set; } + } + + public class FullMessage + { + public PKMessage Message; + public PKMember Member; + public PKSystem System; + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Switch.cs b/PluralKit.Core/Database/Repository/ModelRepository.Switch.cs new file mode 100644 index 00000000..f313f00a --- /dev/null +++ b/PluralKit.Core/Database/Repository/ModelRepository.Switch.cs @@ -0,0 +1,236 @@ +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +using Dapper; + +using NodaTime; + +using NpgsqlTypes; + +namespace PluralKit.Core +{ + public partial class ModelRepository + { + public async Task AddSwitch(IPKConnection conn, SystemId system, IReadOnlyCollection members) + { + // Use a transaction here since we're doing multiple executed commands in one + await using var tx = await conn.BeginTransactionAsync(); + + // First, we insert the switch itself + var sw = await conn.QuerySingleAsync("insert into switches(system) values (@System) returning *", + new {System = system}); + + // Then we insert each member in the switch in the switch_members table + await using (var w = conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)")) + { + foreach (var member in members) + { + await w.StartRowAsync(); + await w.WriteAsync(sw.Id.Value, NpgsqlDbType.Integer); + await w.WriteAsync(member.Value, NpgsqlDbType.Integer); + } + + await w.CompleteAsync(); + } + + // Finally we commit the tx, since the using block will otherwise rollback it + await tx.CommitAsync(); + + _logger.Information("Created {SwitchId} in {SystemId}: {Members}", sw.Id, system, members); + } + + public async Task MoveSwitch(IPKConnection conn, SwitchId id, Instant time) + { + await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id", + new {Time = time, Id = id}); + + _logger.Information("Updated {SwitchId} timestamp: {SwitchTimestamp}", id, time); + } + + public async Task DeleteSwitch(IPKConnection conn, SwitchId id) + { + await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = id}); + _logger.Information("Deleted {Switch}", id); + } + + public async Task DeleteAllSwitches(IPKConnection conn, SystemId system) + { + await conn.ExecuteAsync("delete from switches where system = @Id", new {Id = system}); + _logger.Information("Deleted all switches in {SystemId}", system); + } + + public IAsyncEnumerable GetSwitches(IPKConnection conn, SystemId system) + { + // TODO: refactor the PKSwitch data structure to somehow include a hydrated member list + return conn.QueryStreamAsync( + "select * from switches where system = @System order by timestamp desc", + new {System = system}); + } + + public async Task GetSwitchCount(IPKConnection conn, SystemId system) + { + return await conn.QuerySingleAsync("select count(*) from switches where system = @Id", new { Id = system }); + } + + public async IAsyncEnumerable GetSwitchMembersList(IPKConnection conn, + SystemId system, Instant start, Instant end) + { + // Wrap multiple commands in a single transaction for performance + await using var tx = await conn.BeginTransactionAsync(); + + // Find the time of the last switch outside the range as it overlaps the range + // If no prior switch exists, the lower bound of the range remains the start time + var lastSwitch = await conn.QuerySingleOrDefaultAsync( + @"SELECT COALESCE(MAX(timestamp), @Start) + FROM switches + WHERE switches.system = @System + AND switches.timestamp < @Start", + new {System = system, Start = start}); + + // Then collect the time and members of all switches that overlap the range + var switchMembersEntries = conn.QueryStreamAsync( + @"SELECT switch_members.member, switches.timestamp + FROM switches + LEFT JOIN switch_members + ON switches.id = switch_members.switch + WHERE switches.system = @System + AND ( + switches.timestamp >= @Start + OR switches.timestamp = @LastSwitch + ) + AND switches.timestamp < @End + ORDER BY switches.timestamp DESC", + new {System = system, Start = start, End = end, LastSwitch = lastSwitch}); + + // Yield each value here + await foreach (var entry in switchMembersEntries) + yield return entry; + + // Don't really need to worry about the transaction here, we're not doing any *writes* + } + + public IAsyncEnumerable GetSwitchMembers(IPKConnection conn, SwitchId sw) + { + return conn.QueryStreamAsync( + "select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id", + new {Switch = sw}); + } + + public async Task GetLatestSwitch(IPKConnection conn, SystemId system) => + // TODO: should query directly for perf + await GetSwitches(conn, system).FirstOrDefaultAsync(); + + public async Task> GetPeriodFronters(IPKConnection conn, + SystemId system, Instant periodStart, + Instant periodEnd) + { + // TODO: IAsyncEnumerable-ify this one + // TODO: this doesn't belong in the repo + + // Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order + var switchMembers = await GetSwitchMembersList(conn, system, periodStart, periodEnd).ToListAsync(); + + // query DB for all members involved in any of the switches above and collect into a dictionary for future use + // this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary + // key used in GetPerMemberSwitchDuration below + var membersList = await conn.QueryAsync( + "select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax + new {Switches = switchMembers.Select(m => m.Member.Value).Distinct().ToList()}); + var memberObjects = membersList.ToDictionary(m => m.Id); + + // Initialize entries - still need to loop to determine the TimespanEnd below + var entries = + from item in switchMembers + group item by item.Timestamp + into g + select new SwitchListEntry + { + TimespanStart = g.Key, + Members = g.Where(x => x.Member != default(MemberId)).Select(x => memberObjects[x.Member]) + .ToList() + }; + + // Loop through every switch that overlaps the range and add it to the output list + // end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared + var endTime = periodEnd; + var outList = new List(); + foreach (var e in entries) + { + // Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above) + var switchStartClamped = e.TimespanStart < periodStart + ? periodStart + : e.TimespanStart; + + outList.Add(new SwitchListEntry + { + Members = e.Members, TimespanStart = switchStartClamped, TimespanEnd = endTime + }); + + // next switch's end is this switch's start (we're working backward in time) + endTime = e.TimespanStart; + } + + return outList; + } + + public async Task GetFrontBreakdown(IPKConnection conn, SystemId system, Instant periodStart, + Instant periodEnd) + { + // TODO: this doesn't belong in the repo + var dict = new Dictionary(); + + var noFronterDuration = Duration.Zero; + + // Sum up all switch durations for each member + // switches with multiple members will result in the duration to add up to more than the actual period range + + var actualStart = periodEnd; // will be "pulled" down + var actualEnd = periodStart; // will be "pulled" up + + foreach (var sw in await GetPeriodFronters(conn, system, periodStart, periodEnd)) + { + var span = sw.TimespanEnd - sw.TimespanStart; + foreach (var member in sw.Members) + { + if (!dict.ContainsKey(member)) dict.Add(member, span); + else dict[member] += span; + } + + if (sw.Members.Count == 0) noFronterDuration += span; + + if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart; + if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd; + } + + return new FrontBreakdown + { + MemberSwitchDurations = dict, + NoFronterDuration = noFronterDuration, + RangeStart = actualStart, + RangeEnd = actualEnd + }; + } + } + + public struct SwitchListEntry + { + public ICollection Members; + public Instant TimespanStart; + public Instant TimespanEnd; + } + + public struct FrontBreakdown + { + public Dictionary MemberSwitchDurations; + public Duration NoFronterDuration; + public Instant RangeStart; + public Instant RangeEnd; + } + + public struct SwitchMembersListEntry + { + public MemberId Member; + public Instant Timestamp; + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.System.cs b/PluralKit.Core/Database/Repository/ModelRepository.System.cs new file mode 100644 index 00000000..426c9037 --- /dev/null +++ b/PluralKit.Core/Database/Repository/ModelRepository.System.cs @@ -0,0 +1,78 @@ +#nullable enable +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; + +using Dapper; + +namespace PluralKit.Core +{ + public partial class ModelRepository + { + public Task GetSystem(IPKConnection conn, SystemId id) => + conn.QueryFirstOrDefaultAsync("select * from systems where id = @id", new {id}); + + public Task GetSystemByAccount(IPKConnection conn, ulong accountId) => + conn.QuerySingleOrDefaultAsync( + "select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id", + new {Id = accountId}); + + public Task GetSystemByHid(IPKConnection conn, string hid) => + conn.QuerySingleOrDefaultAsync("select * from systems where systems.hid = @Hid", + new {Hid = hid.ToLower()}); + + public Task> GetSystemAccounts(IPKConnection conn, SystemId system) => + conn.QueryAsync("select uid from accounts where system = @Id", new {Id = system}); + + public IAsyncEnumerable GetSystemMembers(IPKConnection conn, SystemId system) => + conn.QueryStreamAsync("select * from members where system = @SystemID", new {SystemID = system}); + + public Task GetSystemMemberCount(IPKConnection conn, SystemId id, PrivacyLevel? privacyFilter = null) + { + var query = new StringBuilder("select count(*) from members where system = @Id"); + if (privacyFilter != null) + query.Append($" and member_visibility = {(int) privacyFilter.Value}"); + return conn.QuerySingleAsync(query.ToString(), new {Id = id}); + } + + public async Task CreateSystem(IPKConnection conn, string? systemName = null) + { + var system = await conn.QuerySingleAsync( + "insert into systems (hid, name) values (find_free_system_hid(), @Name) returning *", + new {Name = systemName}); + _logger.Information("Created {SystemId}", system.Id); + return system; + } + + public Task UpdateSystem(IPKConnection conn, SystemId id, SystemPatch patch) + { + _logger.Information("Updated {SystemId}: {@SystemPatch}", id, patch); + var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("systems", "id = @id")) + .WithConstant("id", id) + .Build("returning *"); + return conn.QueryFirstAsync(query, pms); + } + + public async Task AddAccount(IPKConnection conn, SystemId system, ulong accountId) + { + // We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent + // This is used in import/export, although the pk;link command checks for this case beforehand + await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing", + new {Id = accountId, SystemId = system}); + _logger.Information("Linked account {UserId} to {SystemId}", accountId, system); + } + + public async Task RemoveAccount(IPKConnection conn, SystemId system, ulong accountId) + { + await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", + new {Id = accountId, SystemId = system}); + _logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system); + } + + public Task DeleteSystem(IPKConnection conn, SystemId id) + { + _logger.Information("Deleted {SystemId}", id); + return conn.ExecuteAsync("delete from systems where id = @Id", new {Id = id}); + } + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.cs b/PluralKit.Core/Database/Repository/ModelRepository.cs new file mode 100644 index 00000000..d3814329 --- /dev/null +++ b/PluralKit.Core/Database/Repository/ModelRepository.cs @@ -0,0 +1,15 @@ +using Serilog; + +namespace PluralKit.Core +{ + public partial class ModelRepository + { + private readonly ILogger _logger; + + public ModelRepository(ILogger logger) + { + _logger = logger.ForContext() + .ForContext("Elastic", "yes?"); + } + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Wrappers/PKConnection.cs b/PluralKit.Core/Database/Wrappers/PKConnection.cs index 135d26cc..acf171be 100644 --- a/PluralKit.Core/Database/Wrappers/PKConnection.cs +++ b/PluralKit.Core/Database/Wrappers/PKConnection.cs @@ -64,7 +64,11 @@ namespace PluralKit.Core protected override async ValueTask BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => new PKTransaction(await Inner.BeginTransactionAsync(level, ct)); public override void Open() => throw SyncError(nameof(Open)); - public override void Close() => throw SyncError(nameof(Close)); + public override void Close() + { + // Don't throw SyncError here, Dapper calls sync Close() internally so that sucks + Inner.Close(); + } IDbTransaction IPKConnection.BeginTransaction() => throw SyncError(nameof(BeginTransaction)); IDbTransaction IPKConnection.BeginTransaction(IsolationLevel level) => throw SyncError(nameof(BeginTransaction)); diff --git a/PluralKit.Core/Models/ModelQueryExt.cs b/PluralKit.Core/Models/ModelQueryExt.cs deleted file mode 100644 index 45ee566f..00000000 --- a/PluralKit.Core/Models/ModelQueryExt.cs +++ /dev/null @@ -1,69 +0,0 @@ -#nullable enable -using System.Collections.Generic; -using System.Text; -using System.Threading.Tasks; - -using Dapper; - -namespace PluralKit.Core -{ - public static class ModelQueryExt - { - public static Task QuerySystem(this IPKConnection conn, SystemId id) => - conn.QueryFirstOrDefaultAsync("select * from systems where id = @id", new {id}); - - public static Task GetSystemMemberCount(this IPKConnection conn, SystemId id, PrivacyLevel? privacyFilter = null) - { - var query = new StringBuilder("select count(*) from members where system = @Id"); - if (privacyFilter != null) - query.Append($" and member_visibility = {(int) privacyFilter.Value}"); - return conn.QuerySingleAsync(query.ToString(), new {Id = id}); - } - - public static Task> GetLinkedAccounts(this IPKConnection conn, SystemId id) => - conn.QueryAsync("select uid from accounts where system = @Id", new {Id = id}); - - public static Task QueryMember(this IPKConnection conn, MemberId id) => - conn.QueryFirstOrDefaultAsync("select * from members where id = @id", new {id}); - - public static Task QueryMemberByHid(this IPKConnection conn, string hid) => - conn.QueryFirstOrDefaultAsync("select * from members where hid = @hid", new {hid = hid.ToLowerInvariant()}); - - public static Task QueryGroupByName(this IPKConnection conn, SystemId system, string name) => - conn.QueryFirstOrDefaultAsync("select * from groups where system = @System and lower(Name) = lower(@Name)", new {System = system, Name = name}); - - public static Task QueryGroupByHid(this IPKConnection conn, string hid) => - conn.QueryFirstOrDefaultAsync("select * from groups where hid = @hid", new {hid = hid.ToLowerInvariant()}); - - public static Task QueryGroupMemberCount(this IPKConnection conn, GroupId id, - PrivacyLevel? privacyFilter = null) - { - var query = new StringBuilder("select count(*) from group_members"); - if (privacyFilter != null) - query.Append(" inner join members on group_members.member_id = members.id"); - query.Append(" where group_members.group_id = @Id"); - if (privacyFilter != null) - query.Append(" and members.member_visibility = @PrivacyFilter"); - return conn.QuerySingleOrDefaultAsync(query.ToString(), new {Id = id, PrivacyFilter = privacyFilter}); - } - - public static Task> QueryMemberGroups(this IPKConnection conn, MemberId id) => - conn.QueryAsync( - "select groups.* from group_members inner join groups on group_members.group_id = groups.id where group_members.member_id = @Id", - new {Id = id}); - - public static Task QueryOrInsertGuildConfig(this IPKConnection conn, ulong guild) => - conn.QueryFirstAsync("insert into servers (id) values (@guild) on conflict (id) do update set id = @guild returning *", new {guild}); - - public static Task QueryOrInsertSystemGuildConfig(this IPKConnection conn, ulong guild, SystemId system) => - conn.QueryFirstAsync( - "insert into system_guild (guild, system) values (@guild, @system) on conflict (guild, system) do update set guild = @guild, system = @system returning *", - new {guild, system}); - - public static Task QueryOrInsertMemberGuildConfig( - this IPKConnection conn, ulong guild, MemberId member) => - conn.QueryFirstAsync( - "insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *", - new {guild, member}); - } -} \ No newline at end of file diff --git a/PluralKit.Core/Models/Patch/ModelPatchExt.cs b/PluralKit.Core/Models/Patch/ModelPatchExt.cs deleted file mode 100644 index 105c47c1..00000000 --- a/PluralKit.Core/Models/Patch/ModelPatchExt.cs +++ /dev/null @@ -1,128 +0,0 @@ -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; - -using Dapper; - -using Serilog; - -namespace PluralKit.Core -{ - public static class ModelPatchExt - { - public static Task UpdateSystem(this IPKConnection conn, SystemId id, SystemPatch patch) - { - Log.ForContext("Elastic", "yes?").Information("Updated {SystemId}: {@SystemPatch}", id, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("systems", "id = @id")) - .WithConstant("id", id) - .Build("returning *"); - return conn.QueryFirstAsync(query, pms); - } - - public static Task DeleteSystem(this IPKConnection conn, SystemId id) - { - Log.ForContext("Elastic", "yes?").Information("Deleted {SystemId}", id); - return conn.ExecuteAsync("delete from systems where id = @Id", new {Id = id}); - } - - public static async Task CreateMember(this IPKConnection conn, SystemId system, string memberName) - { - var member = await conn.QueryFirstAsync( - "insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *", - new {SystemId = system, Name = memberName}); - Log.ForContext("Elastic", "yes?").Information("Created {MemberId} in {SystemId}: {MemberName}", - member.Id, system, memberName); - return member; - } - - public static Task UpdateMember(this IPKConnection conn, MemberId id, MemberPatch patch) - { - Log.ForContext("Elastic", "yes?").Information("Updated {MemberId}: {@MemberPatch}", id, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id")) - .WithConstant("id", id) - .Build("returning *"); - return conn.QueryFirstAsync(query, pms); - } - - public static Task DeleteMember(this IPKConnection conn, MemberId id) - { - Log.ForContext("Elastic", "yes?").Information("Deleted {MemberId}", id); - return conn.ExecuteAsync("delete from members where id = @Id", new {Id = id}); - } - - public static Task UpsertSystemGuild(this IPKConnection conn, SystemId system, ulong guild, - SystemGuildPatch patch) - { - Log.ForContext("Elastic", "yes?").Information("Updated {SystemId} in guild {GuildId}: {@SystemGuildPatch}", system, guild, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("system_guild", "system, guild")) - .WithConstant("system", system) - .WithConstant("guild", guild) - .Build(); - return conn.ExecuteAsync(query, pms); - } - - public static Task UpsertMemberGuild(this IPKConnection conn, MemberId member, ulong guild, - MemberGuildPatch patch) - { - Log.ForContext("Elastic", "yes?").Information("Updated {MemberId} in guild {GuildId}: {@MemberGuildPatch}", member, guild, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("member_guild", "member, guild")) - .WithConstant("member", member) - .WithConstant("guild", guild) - .Build(); - return conn.ExecuteAsync(query, pms); - } - - public static Task UpsertGuild(this IPKConnection conn, ulong guild, GuildPatch patch) - { - Log.ForContext("Elastic", "yes?").Information("Updated guild {GuildId}: {@GuildPatch}", guild, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("servers", "id")) - .WithConstant("id", guild) - .Build(); - return conn.ExecuteAsync(query, pms); - } - - public static async Task CreateGroup(this IPKConnection conn, SystemId system, string name) - { - var group = await conn.QueryFirstAsync( - "insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *", - new {System = system, Name = name}); - Log.ForContext("Elastic", "yes?").Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name); - return group; - } - - public static Task UpdateGroup(this IPKConnection conn, GroupId id, GroupPatch patch) - { - Log.ForContext("Elastic", "yes?").Information("Updated {GroupId}: {@GroupPatch}", id, patch); - var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id")) - .WithConstant("id", id) - .Build("returning *"); - return conn.QueryFirstAsync(query, pms); - } - - public static Task DeleteGroup(this IPKConnection conn, GroupId group) - { - Log.ForContext("Elastic", "yes?").Information("Deleted {GroupId}", group); - return conn.ExecuteAsync("delete from groups where id = @Id", new {Id = @group}); - } - - public static async Task AddMembersToGroup(this IPKConnection conn, GroupId group, IReadOnlyCollection members) - { - await using var w = conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)"); - foreach (var member in members) - { - await w.StartRowAsync(); - await w.WriteAsync(group.Value); - await w.WriteAsync(member.Value); - } - await w.CompleteAsync(); - Log.ForContext("Elastic", "yes?").Information("Added members to {GroupId}: {MemberIds}", group, members); - } - - public static Task RemoveMembersFromGroup(this IPKConnection conn, GroupId group, IReadOnlyCollection members) - { - Log.ForContext("Elastic", "yes?").Information("Removed members from {GroupId}: {MemberIds}", group, members); - return conn.ExecuteAsync("delete from group_members where group_id = @Group and member_id = any(@Members)", - new {Group = @group, Members = members.ToArray()}); - } - } -} \ No newline at end of file diff --git a/PluralKit.Core/Models/SystemGuildSettings.cs b/PluralKit.Core/Models/SystemGuildSettings.cs index 2da7b4ca..4f9f6732 100644 --- a/PluralKit.Core/Models/SystemGuildSettings.cs +++ b/PluralKit.Core/Models/SystemGuildSettings.cs @@ -1,5 +1,13 @@ namespace PluralKit.Core { + public enum AutoproxyMode + { + Off = 1, + Front = 2, + Latch = 3, + Member = 4 + } + public class SystemGuildSettings { public ulong Guild { get; } diff --git a/PluralKit.Core/Modules.cs b/PluralKit.Core/Modules.cs index ebe3ceb3..22406ebc 100644 --- a/PluralKit.Core/Modules.cs +++ b/PluralKit.Core/Modules.cs @@ -25,7 +25,7 @@ namespace PluralKit.Core { builder.RegisterType().SingleInstance(); builder.RegisterType().As().SingleInstance(); - builder.RegisterType().AsSelf().As(); + builder.RegisterType().AsSelf().SingleInstance(); builder.Populate(new ServiceCollection().AddMemoryCache()); } @@ -33,7 +33,7 @@ namespace PluralKit.Core public class ConfigModule: Module where T: new() { - private string _submodule; + private readonly string _submodule; public ConfigModule(string submodule = null) { diff --git a/PluralKit.Core/Services/DataFileService.cs b/PluralKit.Core/Services/DataFileService.cs index d38a3d7d..a2422be0 100644 --- a/PluralKit.Core/Services/DataFileService.cs +++ b/PluralKit.Core/Services/DataFileService.cs @@ -14,22 +14,24 @@ namespace PluralKit.Core { public class DataFileService { - private IDataStore _data; - private IDatabase _db; - private ILogger _logger; + private readonly IDatabase _db; + private readonly ModelRepository _repo; + private readonly ILogger _logger; - public DataFileService(ILogger logger, IDataStore data, IDatabase db) + public DataFileService(ILogger logger, IDatabase db, ModelRepository repo) { - _data = data; _db = db; + _repo = repo; _logger = logger.ForContext(); } public async Task ExportSystem(PKSystem system) { + await using var conn = await _db.Obtain(); + // Export members var members = new List(); - var pkMembers = _data.GetSystemMembers(system); // Read all members in the system + var pkMembers = _repo.GetSystemMembers(conn, system.Id); // Read all members in the system await foreach (var member in pkMembers.Select(m => new DataFileMember { @@ -49,7 +51,7 @@ namespace PluralKit.Core // Export switches var switches = new List(); - var switchList = await _data.GetPeriodFronters(system, Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant()); + var switchList = await _repo.GetPeriodFronters(conn, system.Id, Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant()); switches.AddRange(switchList.Select(x => new DataFileSwitch { Timestamp = x.TimespanStart.FormatExport(), @@ -68,7 +70,7 @@ namespace PluralKit.Core Members = members, Switches = switches, Created = system.Created.FormatExport(), - LinkedAccounts = (await _data.GetSystemAccounts(system)).ToList() + LinkedAccounts = (await _repo.GetSystemAccounts(conn, system.Id)).ToList() }; } @@ -102,6 +104,8 @@ namespace PluralKit.Core public async Task ImportSystem(DataFileSystem data, PKSystem system, ulong accountId) { + await using var conn = await _db.Obtain(); + var result = new ImportResult { AddedNames = new List(), ModifiedNames = new List(), @@ -112,26 +116,24 @@ namespace PluralKit.Core // If we don't already have a system to save to, create one if (system == null) { - system = result.System = await _data.CreateSystem(data.Name); - await _data.AddAccount(system, accountId); + system = result.System = await _repo.CreateSystem(conn, data.Name); + await _repo.AddAccount(conn, system.Id, accountId); } - await using var conn = await _db.Obtain(); - // Apply system info var patch = new SystemPatch {Name = data.Name}; if (data.Description != null) patch.Description = data.Description; if (data.Tag != null) patch.Tag = data.Tag; if (data.AvatarUrl != null) patch.AvatarUrl = data.AvatarUrl; if (data.TimeZone != null) patch.UiTz = data.TimeZone ?? "UTC"; - await conn.UpdateSystem(system.Id, patch); + await _repo.UpdateSystem(conn, system.Id, patch); // -- Member/switch import -- await using (var imp = await BulkImporter.Begin(system, conn)) { // Tally up the members that didn't exist before, and check member count on import // If creating the unmatched members would put us over the member limit, abort before creating any members - var memberCountBefore = await conn.GetSystemMemberCount(system.Id); + var memberCountBefore = await _repo.GetSystemMemberCount(conn, system.Id); var membersToAdd = data.Members.Count(m => imp.IsNewMember(m.Id, m.Name)); if (memberCountBefore + membersToAdd > Limits.MaxMemberCount) { diff --git a/PluralKit.Core/Services/IDataStore.cs b/PluralKit.Core/Services/IDataStore.cs deleted file mode 100644 index 68a55d46..00000000 --- a/PluralKit.Core/Services/IDataStore.cs +++ /dev/null @@ -1,223 +0,0 @@ -using System.Collections.Generic; -using System.Threading.Tasks; - -using NodaTime; - -namespace PluralKit.Core { - public enum AutoproxyMode - { - Off = 1, - Front = 2, - Latch = 3, - Member = 4 - } - - public class FullMessage - { - public PKMessage Message; - public PKMember Member; - public PKSystem System; - } - - public struct PKMessage - { - public ulong Mid; - public ulong? Guild; // null value means "no data" (ie. from before this field being added) - public ulong Channel; - public ulong Sender; - public ulong? OriginalMid; - } - - public struct SwitchListEntry - { - public ICollection Members; - public Instant TimespanStart; - public Instant TimespanEnd; - } - - public struct FrontBreakdown - { - public Dictionary MemberSwitchDurations; - public Duration NoFronterDuration; - public Instant RangeStart; - public Instant RangeEnd; - } - - public struct SwitchMembersListEntry - { - public MemberId Member; - public Instant Timestamp; - } - - public interface IDataStore - { - /// - /// Gets a system by its user-facing human ID. - /// - /// The with the given human ID, or null if no system was found. - Task GetSystemByHid(string systemHid); - - /// - /// Gets a system by one of its linked Discord account IDs. Multiple IDs can return the same system. - /// - /// The with the given linked account, or null if no system was found. - Task GetSystemByAccount(ulong linkedAccount); - - /// - /// Gets the Discord account IDs linked to a system. - /// - /// An enumerable of Discord account IDs linked to this system. - Task> GetSystemAccounts(PKSystem system); - - /// - /// Creates a system, auto-generating its corresponding IDs. - /// - /// An optional system name to set. If `null`, will not set a system name. - /// The created system model. - Task CreateSystem(string systemName); - // TODO: throw exception if account is present (when adding) or account isn't present (when removing) - - /// - /// Links a Discord account to a system. - /// - /// Throws an exception (TODO: which?) if the given account is already linked to a system. - Task AddAccount(PKSystem system, ulong accountToAdd); - - /// - /// Unlinks a Discord account from a system. - /// - /// Will *not* throw if this results in an orphaned system - this is the caller's responsibility to ensure. - /// - /// Throws an exception (TODO: which?) if the given account is not linked to the given system. - Task RemoveAccount(PKSystem system, ulong accountToRemove); - - /// - /// Gets a member by its user-facing human ID. - /// - /// The with the given human ID, or null if no member was found. - Task GetMemberByHid(string memberHid); - - /// - /// Gets a member by its member name within one system. - /// - /// - /// Member names are *usually* unique within a system (but not always), whereas member names - /// are almost certainly *not* unique globally - therefore only intra-system lookup is - /// allowed. - /// - /// The with the given name, or null if no member was found. - Task GetMemberByName(PKSystem system, string name); - - /// - /// Gets a member by its display name within one system. - /// - /// The with the given name, or null if no member was found. - Task GetMemberByDisplayName(PKSystem system, string name); - - /// - /// Gets all members inside a given system. - /// - /// An enumerable of structs representing each member in the system, in no particular order. - IAsyncEnumerable GetSystemMembers(PKSystem system, bool orderByName = false); - - /// - /// Gets a message and its information by its ID. - /// - /// The message ID to look up. This can be either the ID of the trigger message containing the proxy tags or the resulting proxied webhook message. - /// An extended message object, containing not only the message data itself but the associated system and member structs. - Task GetMessage(ulong id); // id is both original and trigger, also add return type struct - - /// - /// Saves a posted message to the database. - /// - /// The ID of the account that sent the original trigger message. - /// The ID of the guild the message was posted to. - /// The ID of the channel the message was posted to. - /// The ID of the message posted by the webhook. - /// The ID of the original trigger message containing the proxy tags. - /// The member (and by extension system) that was proxied. - /// - Task AddMessage(IPKConnection conn, ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, MemberId proxiedMemberId); - - /// - /// Deletes a message from the data store. - /// - /// The ID of the webhook message to delete. - Task DeleteMessage(ulong postedMessageId); - - /// - /// Deletes messages from the data store in bulk. - /// - /// The IDs of the webhook messages to delete. - Task DeleteMessagesBulk(IReadOnlyCollection postedMessageIds); - - /// - /// Gets switches from a system. - /// - /// An enumerable of the *count* latest switches in the system, in latest-first order. May contain fewer elements than requested. - IAsyncEnumerable GetSwitches(SystemId system); - - /// - /// Gets the total amount of switches in a given system. - /// - Task GetSwitchCount(PKSystem system); - - /// - /// Gets the latest (temporally; closest to now) switch of a given system. - /// - Task GetLatestSwitch(SystemId system); - - /// - /// Gets the members a given switch consists of. - /// - IAsyncEnumerable GetSwitchMembers(PKSwitch sw); - - /// - /// Gets a list of fronters over a given period of time. - /// - /// - /// This list is returned as an enumerable of "switch members", each containing a timestamp - /// and a member ID. - /// - /// Switches containing multiple members will be returned as multiple switch members each with the same - /// timestamp, and a change in timestamp should be interpreted as the start of a new switch. - /// - /// An enumerable of the aforementioned "switch members". - Task> GetPeriodFronters(PKSystem system, Instant periodStart, Instant periodEnd); - - /// - /// Calculates a breakdown of a system's fronters over a given period, including how long each member has - /// been fronting, and how long *no* member has been fronting. - /// - /// - /// Switches containing multiple members will count the full switch duration for all members, meaning - /// the total duration may add up to longer than the breakdown period. - /// - /// - /// - /// - /// - Task GetFrontBreakdown(PKSystem system, Instant periodStart, Instant periodEnd); - - /// - /// Registers a switch with the given members in the given system. - /// - /// Throws an exception (TODO: which?) if any of the members are not in the given system. - Task AddSwitch(SystemId system, IEnumerable switchMembers); - - /// - /// Updates the timestamp of a given switch. - /// - Task MoveSwitch(PKSwitch sw, Instant time); - - /// - /// Deletes a given switch from the data store. - /// - Task DeleteSwitch(PKSwitch sw); - - /// - /// Deletes all switches in a given system from the data store. - /// - Task DeleteAllSwitches(PKSystem system); - } -} \ No newline at end of file diff --git a/PluralKit.Core/Services/PostgresDataStore.cs b/PluralKit.Core/Services/PostgresDataStore.cs deleted file mode 100644 index 1281addf..00000000 --- a/PluralKit.Core/Services/PostgresDataStore.cs +++ /dev/null @@ -1,334 +0,0 @@ -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; - -using Dapper; - -using NodaTime; - -using Serilog; - -namespace PluralKit.Core { - public class PostgresDataStore: IDataStore { - private readonly IDatabase _conn; - private readonly ILogger _logger; - - public PostgresDataStore(IDatabase conn, ILogger logger) - { - _conn = conn; - _logger = logger - .ForContext() - .ForContext("Elastic", "yes?"); - } - - public async Task CreateSystem(string systemName = null) { - PKSystem system; - using (var conn = await _conn.Obtain()) - system = await conn.QuerySingleAsync("insert into systems (hid, name) values (find_free_system_hid(), @Name) returning *", new { Name = systemName }); - - _logger.Information("Created {SystemId}", system.Id); - // New system has no accounts, therefore nothing gets cached, therefore no need to invalidate caches right here - return system; - } - - public async Task AddAccount(PKSystem system, ulong accountId) { - // We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent - // This is used in import/export, although the pk;link command checks for this case beforehand - using (var conn = await _conn.Obtain()) - await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing", new { Id = accountId, SystemId = system.Id }); - - _logger.Information("Linked account {UserId} to {SystemId}", accountId, system.Id); - } - - public async Task RemoveAccount(PKSystem system, ulong accountId) { - using (var conn = await _conn.Obtain()) - await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", new { Id = accountId, SystemId = system.Id }); - - _logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system.Id); - } - - public async Task GetSystemByAccount(ulong accountId) { - using (var conn = await _conn.Obtain()) - return await conn.QuerySingleOrDefaultAsync("select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id", new { Id = accountId }); - } - - public async Task GetSystemByHid(string hid) { - using (var conn = await _conn.Obtain()) - return await conn.QuerySingleOrDefaultAsync("select * from systems where systems.hid = @Hid", new { Hid = hid.ToLower() }); - } - - public async Task> GetSystemAccounts(PKSystem system) - { - using (var conn = await _conn.Obtain()) - return await conn.QueryAsync("select uid from accounts where system = @Id", new { Id = system.Id }); - } - - public async Task DeleteAllSwitches(PKSystem system) - { - using (var conn = await _conn.Obtain()) - await conn.ExecuteAsync("delete from switches where system = @Id", system); - _logger.Information("Deleted all switches in {SystemId}", system.Id); - } - - public async Task GetMemberByHid(string hid) { - using (var conn = await _conn.Obtain()) - return await conn.QuerySingleOrDefaultAsync("select * from members where hid = @Hid", new { Hid = hid.ToLower() }); - } - - public async Task GetMemberByName(PKSystem system, string name) { - // QueryFirst, since members can (in rare cases) share names - using (var conn = await _conn.Obtain()) - return await conn.QueryFirstOrDefaultAsync("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id }); - } - - public async Task GetMemberByDisplayName(PKSystem system, string name) { - // QueryFirst, since members can (in rare cases) share display names - using (var conn = await _conn.Obtain()) - return await conn.QueryFirstOrDefaultAsync("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id }); - } - - public IAsyncEnumerable GetSystemMembers(PKSystem system, bool orderByName) - { - var sql = "select * from members where system = @SystemID"; - if (orderByName) sql += " order by lower(name) asc"; - return _conn.QueryStreamAsync(sql, new { SystemID = system.Id }); - } - - public async Task AddMessage(IPKConnection conn, ulong senderId, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, MemberId proxiedMemberId) { - // "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before - await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@MessageId, @GuildId, @ChannelId, @MemberId, @SenderId, @OriginalMid) on conflict do nothing", new { - MessageId = postedMessageId, - GuildId = guildId, - ChannelId = channelId, - MemberId = proxiedMemberId, - SenderId = senderId, - OriginalMid = triggerMessageId - }); - - // todo: _logger.Debug("Stored message {Message} in channel {Channel}", postedMessageId, channelId); - } - - public async Task GetMessage(ulong id) - { - using (var conn = await _conn.Obtain()) - return (await conn.QueryAsync("select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system", (msg, member, system) => new FullMessage - { - Message = msg, - System = system, - Member = member - }, new { Id = id })).FirstOrDefault(); - } - - public async Task DeleteMessage(ulong id) { - using (var conn = await _conn.Obtain()) - if (await conn.ExecuteAsync("delete from messages where mid = @Id", new { Id = id }) > 0) - _logger.Information("Deleted message {MessageId} from database", id); - } - - public async Task DeleteMessagesBulk(IReadOnlyCollection ids) - { - using (var conn = await _conn.Obtain()) - { - // Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong - // Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway) - var foundCount = await conn.ExecuteAsync("delete from messages where mid = any(@Ids)", new {Ids = ids.Select(id => (long) id).ToArray()}); - if (foundCount > 0) - _logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", foundCount, ids); - } - } - - public async Task AddSwitch(SystemId system, IEnumerable members) - { - // Use a transaction here since we're doing multiple executed commands in one - await using var conn = await _conn.Obtain(); - await using var tx = await conn.BeginTransactionAsync(); - - // First, we insert the switch itself - var sw = await conn.QuerySingleAsync("insert into switches(system) values (@System) returning *", - new {System = system}); - - // Then we insert each member in the switch in the switch_members table - // TODO: can we parallelize this or send it in bulk somehow? - foreach (var member in members) - { - await conn.ExecuteAsync( - "insert into switch_members(switch, member) values(@Switch, @Member)", - new {Switch = sw.Id, Member = member.Id}); - } - - // Finally we commit the tx, since the using block will otherwise rollback it - await tx.CommitAsync(); - - _logger.Information("Created {SwitchId} in {SystemId}: {Members}", sw.Id, system, members.Select(m => m.Id)); - } - - public IAsyncEnumerable GetSwitches(SystemId system) - { - // TODO: refactor the PKSwitch data structure to somehow include a hydrated member list - // (maybe when we get caching in?) - return _conn.QueryStreamAsync( - "select * from switches where system = @System order by timestamp desc", - new {System = system}); - } - - public async Task GetSwitchCount(PKSystem system) - { - using var conn = await _conn.Obtain(); - return await conn.QuerySingleAsync("select count(*) from switches where system = @Id", system); - } - - public async IAsyncEnumerable GetSwitchMembersList(PKSystem system, Instant start, Instant end) - { - // Wrap multiple commands in a single transaction for performance - await using var conn = await _conn.Obtain(); - await using var tx = await conn.BeginTransactionAsync(); - - // Find the time of the last switch outside the range as it overlaps the range - // If no prior switch exists, the lower bound of the range remains the start time - var lastSwitch = await conn.QuerySingleOrDefaultAsync( - @"SELECT COALESCE(MAX(timestamp), @Start) - FROM switches - WHERE switches.system = @System - AND switches.timestamp < @Start", - new { System = system.Id, Start = start }); - - // Then collect the time and members of all switches that overlap the range - var switchMembersEntries = conn.QueryStreamAsync( - @"SELECT switch_members.member, switches.timestamp - FROM switches - LEFT JOIN switch_members - ON switches.id = switch_members.switch - WHERE switches.system = @System - AND ( - switches.timestamp >= @Start - OR switches.timestamp = @LastSwitch - ) - AND switches.timestamp < @End - ORDER BY switches.timestamp DESC", - new { System = system.Id, Start = start, End = end, LastSwitch = lastSwitch }); - - // Yield each value here - await foreach (var entry in switchMembersEntries) - yield return entry; - - // Don't really need to worry about the transaction here, we're not doing any *writes* - } - - public IAsyncEnumerable GetSwitchMembers(PKSwitch sw) - { - return _conn.QueryStreamAsync( - "select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id", - new {Switch = sw.Id}); - } - - public async Task GetLatestSwitch(SystemId system) => - await GetSwitches(system).FirstOrDefaultAsync(); - - public async Task MoveSwitch(PKSwitch sw, Instant time) - { - using (var conn = await _conn.Obtain()) - await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id", - new {Time = time, Id = sw.Id}); - - _logger.Information("Updated {SwitchId} timestamp: {SwitchTimestamp}", sw.Id, time); - } - - public async Task DeleteSwitch(PKSwitch sw) - { - using (var conn = await _conn.Obtain()) - await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = sw.Id}); - - _logger.Information("Deleted {Switch}", sw.Id); - } - - public async Task> GetPeriodFronters(PKSystem system, Instant periodStart, Instant periodEnd) - { - // TODO: IAsyncEnumerable-ify this one - - // Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order - var switchMembers = await GetSwitchMembersList(system, periodStart, periodEnd).ToListAsync(); - - // query DB for all members involved in any of the switches above and collect into a dictionary for future use - // this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary - // key used in GetPerMemberSwitchDuration below - Dictionary memberObjects; - using (var conn = await _conn.Obtain()) - { - memberObjects = ( - await conn.QueryAsync( - "select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax - new { Switches = switchMembers.Select(m => m.Member.Value).Distinct().ToList() }) - ).ToDictionary(m => m.Id); - } - - // Initialize entries - still need to loop to determine the TimespanEnd below - var entries = - from item in switchMembers - group item by item.Timestamp into g - select new SwitchListEntry - { - TimespanStart = g.Key, - Members = g.Where(x => x.Member != default(MemberId)).Select(x => memberObjects[x.Member]).ToList() - }; - - // Loop through every switch that overlaps the range and add it to the output list - // end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared - var endTime = periodEnd; - var outList = new List(); - foreach (var e in entries) - { - // Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above) - var switchStartClamped = e.TimespanStart < periodStart - ? periodStart - : e.TimespanStart; - - outList.Add(new SwitchListEntry - { - Members = e.Members, - TimespanStart = switchStartClamped, - TimespanEnd = endTime - }); - - // next switch's end is this switch's start (we're working backward in time) - endTime = e.TimespanStart; - } - - return outList; - } - public async Task GetFrontBreakdown(PKSystem system, Instant periodStart, Instant periodEnd) - { - var dict = new Dictionary(); - - var noFronterDuration = Duration.Zero; - - // Sum up all switch durations for each member - // switches with multiple members will result in the duration to add up to more than the actual period range - - var actualStart = periodEnd; // will be "pulled" down - var actualEnd = periodStart; // will be "pulled" up - - foreach (var sw in await GetPeriodFronters(system, periodStart, periodEnd)) - { - var span = sw.TimespanEnd - sw.TimespanStart; - foreach (var member in sw.Members) - { - if (!dict.ContainsKey(member)) dict.Add(member, span); - else dict[member] += span; - } - - if (sw.Members.Count == 0) noFronterDuration += span; - - if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart; - if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd; - } - - return new FrontBreakdown - { - MemberSwitchDurations = dict, - NoFronterDuration = noFronterDuration, - RangeStart = actualStart, - RangeEnd = actualEnd - }; - } - } -} \ No newline at end of file diff --git a/PluralKit.Core/Utils/ConnectionUtils.cs b/PluralKit.Core/Utils/ConnectionUtils.cs index 743ff505..766295bb 100644 --- a/PluralKit.Core/Utils/ConnectionUtils.cs +++ b/PluralKit.Core/Utils/ConnectionUtils.cs @@ -6,20 +6,11 @@ using Dapper; namespace PluralKit.Core { public static class ConnectionUtils { - public static async IAsyncEnumerable QueryStreamAsync(this IDatabase connFactory, string sql, object param) - { - await using var conn = await connFactory.Obtain(); - - await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param); - var parser = reader.GetRowParser(); - while (await reader.ReadAsync()) - yield return parser(reader); - } - public static async IAsyncEnumerable QueryStreamAsync(this IPKConnection conn, string sql, object param) { await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param); var parser = reader.GetRowParser(); + while (await reader.ReadAsync()) yield return parser(reader); } diff --git a/PluralKit.Core/Utils/QueryBuilder.cs b/PluralKit.Core/Utils/QueryBuilder.cs index 14c4502e..2b423754 100644 --- a/PluralKit.Core/Utils/QueryBuilder.cs +++ b/PluralKit.Core/Utils/QueryBuilder.cs @@ -8,9 +8,9 @@ namespace PluralKit.Core { private readonly string? _conflictField; private readonly string? _condition; - private StringBuilder _insertFragment = new StringBuilder(); - private StringBuilder _valuesFragment = new StringBuilder(); - private StringBuilder _updateFragment = new StringBuilder(); + private readonly StringBuilder _insertFragment = new StringBuilder(); + private readonly StringBuilder _valuesFragment = new StringBuilder(); + private readonly StringBuilder _updateFragment = new StringBuilder(); private bool _firstInsert = true; private bool _firstUpdate = true; public QueryType Type { get; }