Major database refactor (again)

This commit is contained in:
Ske 2020-08-29 13:46:27 +02:00
parent 3996cd48c7
commit c7612df37e
55 changed files with 1014 additions and 1100 deletions

View File

@ -13,18 +13,20 @@ namespace PluralKit.API
[Route( "v{version:apiVersion}/a" )]
public class AccountController: ControllerBase
{
private IDataStore _data;
public AccountController(IDataStore data)
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public AccountController(IDatabase db, ModelRepository repo)
{
_data = data;
_db = db;
_repo = repo;
}
[HttpGet("{aid}")]
public async Task<ActionResult<JObject>> GetSystemByAccount(ulong aid)
{
var system = await _data.GetSystemByAccount(aid);
if (system == null) return NotFound("Account not found.");
var system = await _db.Execute(c => _repo.GetSystemByAccount(c, aid));
if (system == null)
return NotFound("Account not found.");
return Ok(system.ToJson(User.ContextFor(system)));
}

View File

@ -16,19 +16,21 @@ namespace PluralKit.API
[Route( "v{version:apiVersion}/m" )]
public class MemberController: ControllerBase
{
private IDatabase _db;
private IAuthorizationService _auth;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly IAuthorizationService _auth;
public MemberController(IAuthorizationService auth, IDatabase db)
public MemberController(IAuthorizationService auth, IDatabase db, ModelRepository repo)
{
_auth = auth;
_db = db;
_repo = repo;
}
[HttpGet("{hid}")]
public async Task<ActionResult<JObject>> GetMember(string hid)
{
var member = await _db.Execute(conn => conn.QueryMemberByHid(hid));
var member = await _db.Execute(conn => _repo.GetMemberByHid(conn, hid));
if (member == null) return NotFound("Member not found.");
return Ok(member.ToJson(User.ContextFor(member)));
@ -49,7 +51,7 @@ namespace PluralKit.API
if (memberCount >= Limits.MaxMemberCount)
return BadRequest($"Member limit reached ({memberCount} / {Limits.MaxMemberCount}).");
var member = await conn.CreateMember(system, properties.Value<string>("name"));
var member = await _repo.CreateMember(conn, system, properties.Value<string>("name"));
MemberPatch patch;
try
{
@ -60,7 +62,7 @@ namespace PluralKit.API
return BadRequest(e.Message);
}
member = await conn.UpdateMember(member.Id, patch);
member = await _repo.UpdateMember(conn, member.Id, patch);
return Ok(member.ToJson(User.ContextFor(member)));
}
@ -70,7 +72,7 @@ namespace PluralKit.API
{
await using var conn = await _db.Obtain();
var member = await conn.QueryMemberByHid(hid);
var member = await _repo.GetMemberByHid(conn, hid);
if (member == null) return NotFound("Member not found.");
var res = await _auth.AuthorizeAsync(User, member, "EditMember");
@ -86,7 +88,7 @@ namespace PluralKit.API
return BadRequest(e.Message);
}
var newMember = await conn.UpdateMember(member.Id, patch);
var newMember = await _repo.UpdateMember(conn, member.Id, patch);
return Ok(newMember.ToJson(User.ContextFor(newMember)));
}
@ -96,13 +98,13 @@ namespace PluralKit.API
{
await using var conn = await _db.Obtain();
var member = await conn.QueryMemberByHid(hid);
var member = await _repo.GetMemberByHid(conn, hid);
if (member == null) return NotFound("Member not found.");
var res = await _auth.AuthorizeAsync(User, member, "EditMember");
if (!res.Succeeded) return Unauthorized($"Member '{hid}' is not part of your system.");
await conn.DeleteMember(member.Id);
await _repo.DeleteMember(conn, member.Id);
return Ok();
}
}

View File

@ -28,17 +28,19 @@ namespace PluralKit.API
[Route( "v{version:apiVersion}/msg" )]
public class MessageController: ControllerBase
{
private IDataStore _data;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public MessageController(IDataStore _data)
public MessageController(ModelRepository repo, IDatabase db)
{
this._data = _data;
_repo = repo;
_db = db;
}
[HttpGet("{mid}")]
public async Task<ActionResult<MessageReturn>> GetMessage(ulong mid)
{
var msg = await _data.GetMessage(mid);
var msg = await _db.Execute(c => _repo.GetMessage(c, mid));
if (msg == null) return NotFound("Message not found.");
return new MessageReturn

View File

@ -39,29 +39,29 @@ namespace PluralKit.API
[Route( "v{version:apiVersion}/s" )]
public class SystemController : ControllerBase
{
private IDataStore _data;
private IDatabase _db;
private IAuthorizationService _auth;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly IAuthorizationService _auth;
public SystemController(IDataStore data, IDatabase db, IAuthorizationService auth)
public SystemController(IDatabase db, IAuthorizationService auth, ModelRepository repo)
{
_data = data;
_db = db;
_auth = auth;
_repo = repo;
}
[HttpGet]
[Authorize]
public async Task<ActionResult<JObject>> GetOwnSystem()
{
var system = await _db.Execute(c => c.QuerySystem(User.CurrentSystem()));
var system = await _db.Execute(c => _repo.GetSystem(c, User.CurrentSystem()));
return system.ToJson(User.ContextFor(system));
}
[HttpGet("{hid}")]
public async Task<ActionResult<JObject>> GetSystem(string hid)
{
var system = await _data.GetSystemByHid(hid);
var system = await _db.Execute(c => _repo.GetSystemByHid(c, hid));
if (system == null) return NotFound("System not found.");
return Ok(system.ToJson(User.ContextFor(system)));
}
@ -69,13 +69,14 @@ namespace PluralKit.API
[HttpGet("{hid}/members")]
public async Task<ActionResult<IEnumerable<JObject>>> GetMembers(string hid)
{
var system = await _data.GetSystemByHid(hid);
if (system == null) return NotFound("System not found.");
var system = await _db.Execute(c => _repo.GetSystemByHid(c, hid));
if (system == null)
return NotFound("System not found.");
if (!system.MemberListPrivacy.CanAccess(User.ContextFor(system)))
return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view member list.");
var members = _data.GetSystemMembers(system);
var members = _db.Execute(c => _repo.GetSystemMembers(c, system.Id));
return Ok(await members
.Where(m => m.MemberVisibility.CanAccess(User.ContextFor(system)))
.Select(m => m.ToJson(User.ContextFor(system)))
@ -87,39 +88,40 @@ namespace PluralKit.API
{
if (before == null) before = SystemClock.Instance.GetCurrentInstant();
var system = await _data.GetSystemByHid(hid);
await using var conn = await _db.Obtain();
var system = await _repo.GetSystemByHid(conn, hid);
if (system == null) return NotFound("System not found.");
var auth = await _auth.AuthorizeAsync(User, system, "ViewFrontHistory");
if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view front history.");
using (var conn = await _db.Obtain())
{
var res = await conn.QueryAsync<SwitchesReturn>(
@"select *, array(
var res = await conn.QueryAsync<SwitchesReturn>(
@"select *, array(
select members.hid from switch_members, members
where switch_members.switch = switches.id and members.id = switch_members.member
) as members from switches
where switches.system = @System and switches.timestamp < @Before
order by switches.timestamp desc
limit 100;", new {System = system.Id, Before = before});
return Ok(res);
}
return Ok(res);
}
[HttpGet("{hid}/fronters")]
public async Task<ActionResult<FrontersReturn>> GetFronters(string hid)
{
var system = await _data.GetSystemByHid(hid);
await using var conn = await _db.Obtain();
var system = await _repo.GetSystemByHid(conn, hid);
if (system == null) return NotFound("System not found.");
var auth = await _auth.AuthorizeAsync(User, system, "ViewFront");
if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view fronter.");
var sw = await _data.GetLatestSwitch(system.Id);
var sw = await _repo.GetLatestSwitch(conn, system.Id);
if (sw == null) return NotFound("System has no registered switches.");
var members = _data.GetSwitchMembers(sw);
var members = _repo.GetSwitchMembers(conn, sw.Id);
return Ok(new FrontersReturn
{
Timestamp = sw.Timestamp,
@ -131,7 +133,8 @@ namespace PluralKit.API
[Authorize]
public async Task<ActionResult<JObject>> EditSystem([FromBody] JObject changes)
{
var system = await _db.Execute(c => c.QuerySystem(User.CurrentSystem()));
await using var conn = await _db.Obtain();
var system = await _repo.GetSystem(conn, User.CurrentSystem());
SystemPatch patch;
try
@ -143,7 +146,7 @@ namespace PluralKit.API
return BadRequest(e.Message);
}
await _db.Execute(conn => conn.UpdateSystem(system.Id, patch));
await _repo.UpdateSystem(conn, system!.Id, patch);
return Ok(system.ToJson(User.ContextFor(system)));
}
@ -154,11 +157,13 @@ namespace PluralKit.API
if (param.Members.Distinct().Count() != param.Members.Count)
return BadRequest("Duplicate members in member list.");
await using var conn = await _db.Obtain();
// We get the current switch, if it exists
var latestSwitch = await _data.GetLatestSwitch(User.CurrentSystem());
var latestSwitch = await _repo.GetLatestSwitch(conn, User.CurrentSystem());
if (latestSwitch != null)
{
var latestSwitchMembers = _data.GetSwitchMembers(latestSwitch);
var latestSwitchMembers = _repo.GetSwitchMembers(conn, latestSwitch.Id);
// Bail if this switch is identical to the latest one
if (await latestSwitchMembers.Select(m => m.Hid).SequenceEqualAsync(param.Members.ToAsyncEnumerable()))
@ -166,9 +171,7 @@ namespace PluralKit.API
}
// Resolve member objects for all given IDs
IEnumerable<PKMember> membersList;
using (var conn = await _db.Obtain())
membersList = (await conn.QueryAsync<PKMember>("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList();
var membersList = (await conn.QueryAsync<PKMember>("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList();
foreach (var member in membersList)
if (member.System != User.CurrentSystem())
@ -182,12 +185,13 @@ namespace PluralKit.API
// We do this without .Select() since we want to have the early return bail if it doesn't find the member
foreach (var givenMemberId in param.Members)
{
if (!membersDict.TryGetValue(givenMemberId, out var member)) return BadRequest($"Member '{givenMemberId}' not found.");
if (!membersDict.TryGetValue(givenMemberId, out var member))
return BadRequest($"Member '{givenMemberId}' not found.");
membersInOrder.Add(member);
}
// Finally, log the switch (yay!)
await _data.AddSwitch(User.CurrentSystem(), membersInOrder);
await _repo.AddSwitch(conn, User.CurrentSystem(), membersInOrder.Select(m => m.Id).ToList());
return NoContent();
}
}

View File

@ -15,7 +15,7 @@ namespace PluralKit.Bot
{
public class Context
{
private ILifetimeScope _provider;
private readonly ILifetimeScope _provider;
private readonly DiscordRestClient _rest;
private readonly DiscordShardedClient _client;
@ -24,8 +24,8 @@ namespace PluralKit.Bot
private readonly Parameters _parameters;
private readonly MessageContext _messageContext;
private readonly IDataStore _data;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly PKSystem _senderSystem;
private readonly IMetrics _metrics;
@ -38,10 +38,10 @@ namespace PluralKit.Bot
_client = provider.Resolve<DiscordShardedClient>();
_message = message;
_shard = shard;
_data = provider.Resolve<IDataStore>();
_senderSystem = senderSystem;
_messageContext = messageContext;
_db = provider.Resolve<IDatabase>();
_repo = provider.Resolve<ModelRepository>();
_metrics = provider.Resolve<IMetrics>();
_provider = provider;
_parameters = new Parameters(message.Content.Substring(commandParseOffset));
@ -61,9 +61,8 @@ namespace PluralKit.Bot
public Parameters Parameters => _parameters;
// TODO: this is just here so the extension methods can access it; should it be public/private/?
internal IDataStore DataStore => _data;
internal IDatabase Database => _db;
internal ModelRepository Repository => _repo;
public Task<DiscordMessage> Reply(string text = null, DiscordEmbed embed = null, IEnumerable<IMention> mentions = null)
{

View File

@ -47,12 +47,14 @@ namespace PluralKit.Bot
// - A @mention of an account connected to the system (<@uid>)
// - A system hid
await using var conn = await ctx.Database.Obtain();
// Direct IDs and mentions are both handled by the below method:
if (input.TryParseMention(out var id))
return await ctx.DataStore.GetSystemByAccount(id);
return await ctx.Repository.GetSystemByAccount(conn, id);
// Finally, try HID parsing
var system = await ctx.DataStore.GetSystemByHid(input);
var system = await ctx.Repository.GetSystemByHid(conn, input);
return system;
}
@ -67,15 +69,16 @@ namespace PluralKit.Bot
// - a textual display name of a member *in your own system*
// First, if we have a system, try finding by member name in system
if (ctx.System != null && await ctx.DataStore.GetMemberByName(ctx.System, input) is PKMember memberByName)
await using var conn = await ctx.Database.Obtain();
if (ctx.System != null && await ctx.Repository.GetMemberByName(conn, ctx.System.Id, input) is PKMember memberByName)
return memberByName;
// Then, try member HID parsing:
if (await ctx.DataStore.GetMemberByHid(input) is PKMember memberByHid)
if (await ctx.Repository.GetMemberByHid(conn, input) is PKMember memberByHid)
return memberByHid;
// And if that again fails, we try finding a member with a display name matching the argument from the system
if (ctx.System != null && await ctx.DataStore.GetMemberByDisplayName(ctx.System, input) is PKMember memberByDisplayName)
if (ctx.System != null && await ctx.Repository.GetMemberByDisplayName(conn, ctx.System.Id, input) is PKMember memberByDisplayName)
return memberByDisplayName;
// We didn't find anything, so we return null.
@ -103,9 +106,9 @@ namespace PluralKit.Bot
var input = ctx.PeekArgument();
await using var conn = await ctx.Database.Obtain();
if (ctx.System != null && await conn.QueryGroupByName(ctx.System.Id, input) is {} byName)
if (ctx.System != null && await ctx.Repository.GetGroupByName(conn, ctx.System.Id, input) is {} byName)
return byName;
if (await conn.QueryGroupByHid(input) is {} byHid)
if (await ctx.Repository.GetGroupByHid(conn, input) is {} byHid)
return byHid;
return null;

View File

@ -36,15 +36,15 @@ namespace PluralKit.Bot
private struct WordPosition
{
// Start of the word
internal int startPos;
internal readonly int startPos;
// End of the word
internal int endPos;
internal readonly int endPos;
// How much to advance word pointer afterwards to point at the start of the *next* word
internal int advanceAfterWord;
internal readonly int advanceAfterWord;
internal bool wasQuoted;
internal readonly bool wasQuoted;
public WordPosition(int startPos, int endPos, int advanceAfterWord, bool wasQuoted)
{

View File

@ -12,10 +12,12 @@ namespace PluralKit.Bot
public class Autoproxy
{
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public Autoproxy(IDatabase db)
public Autoproxy(IDatabase db, ModelRepository repo)
{
_db = db;
_repo = repo;
}
public async Task AutoproxyRoot(Context ctx)
@ -87,8 +89,8 @@ namespace PluralKit.Bot
var fronters = ctx.MessageContext.LastSwitchMembers;
var relevantMember = ctx.MessageContext.AutoproxyMode switch
{
AutoproxyMode.Front => fronters.Length > 0 ? await _db.Execute(c => c.QueryMember(fronters[0])) : null,
AutoproxyMode.Member => await _db.Execute(c => c.QueryMember(ctx.MessageContext.AutoproxyMember.Value)),
AutoproxyMode.Front => fronters.Length > 0 ? await _db.Execute(c => _repo.GetMember(c, fronters[0])) : null,
AutoproxyMode.Member => await _db.Execute(c => _repo.GetMember(c, ctx.MessageContext.AutoproxyMember.Value)),
_ => null
};
@ -126,7 +128,7 @@ namespace PluralKit.Bot
private Task UpdateAutoproxy(Context ctx, AutoproxyMode autoproxyMode, MemberId? autoproxyMember)
{
var patch = new SystemGuildPatch {AutoproxyMode = autoproxyMode, AutoproxyMember = autoproxyMember};
return _db.Execute(conn => conn.UpsertSystemGuild(ctx.System.Id, ctx.Guild.Id, patch));
return _db.Execute(conn => _repo.UpsertSystemGuild(conn, ctx.System.Id, ctx.Guild.Id, patch));
}
}
}

View File

@ -17,10 +17,12 @@ namespace PluralKit.Bot
public class Groups
{
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public Groups(IDatabase db)
public Groups(IDatabase db, ModelRepository repo)
{
_db = db;
_repo = repo;
}
public async Task CreateGroup(Context ctx)
@ -40,14 +42,14 @@ namespace PluralKit.Bot
throw new PKError($"System has reached the maximum number of groups ({Limits.MaxGroupCount}). Please delete unused groups first in order to create new ones.");
// Warn if there's already a group by this name
var existingGroup = await conn.QueryGroupByName(ctx.System.Id, groupName);
var existingGroup = await _repo.GetGroupByName(conn, ctx.System.Id, groupName);
if (existingGroup != null) {
var msg = $"{Emojis.Warn} You already have a group in your system with the name \"{existingGroup.Name}\" (with ID `{existingGroup.Hid}`). Do you want to create another group with the same name?";
if (!await ctx.PromptYesNo(msg))
throw new PKError("Group creation cancelled.");
}
var newGroup = await conn.CreateGroup(ctx.System.Id, groupName);
var newGroup = await _repo.CreateGroup(conn, ctx.System.Id, groupName);
var eb = new DiscordEmbedBuilder()
.WithDescription($"Your new group, **{groupName}**, has been created, with the group ID **`{newGroup.Hid}`**.\nBelow are a couple of useful commands:")
@ -70,14 +72,14 @@ namespace PluralKit.Bot
await using var conn = await _db.Obtain();
// Warn if there's already a group by this name
var existingGroup = await conn.QueryGroupByName(ctx.System.Id, newName);
var existingGroup = await _repo.GetGroupByName(conn, ctx.System.Id, newName);
if (existingGroup != null && existingGroup.Id != target.Id) {
var msg = $"{Emojis.Warn} You already have a group in your system with the name \"{existingGroup.Name}\" (with ID `{existingGroup.Hid}`). Do you want to rename this member to that name too?";
if (!await ctx.PromptYesNo(msg))
throw new PKError("Group creation cancelled.");
}
await conn.UpdateGroup(target.Id, new GroupPatch {Name = newName});
await _repo.UpdateGroup(conn, target.Id, new GroupPatch {Name = newName});
await ctx.Reply($"{Emojis.Success} Group name changed from **{target.Name}** to **{newName}**.");
}
@ -89,7 +91,7 @@ namespace PluralKit.Bot
ctx.CheckOwnGroup(target);
var patch = new GroupPatch {DisplayName = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateGroup(target.Id, patch));
await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Group display name cleared.");
}
@ -112,7 +114,7 @@ namespace PluralKit.Bot
var newDisplayName = ctx.RemainderOrNull();
var patch = new GroupPatch {DisplayName = Partial<string>.Present(newDisplayName)};
await _db.Execute(conn => conn.UpdateGroup(target.Id, patch));
await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Group display name changed.");
}
@ -125,7 +127,7 @@ namespace PluralKit.Bot
ctx.CheckOwnGroup(target);
var patch = new GroupPatch {Description = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateGroup(target.Id, patch));
await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Group description cleared.");
}
else if (!ctx.HasNext())
@ -154,7 +156,7 @@ namespace PluralKit.Bot
throw Errors.DescriptionTooLongError(description.Length);
var patch = new GroupPatch {Description = Partial<string>.Present(description)};
await _db.Execute(conn => conn.UpdateGroup(target.Id, patch));
await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Group description changed.");
}
@ -166,7 +168,7 @@ namespace PluralKit.Bot
{
ctx.CheckOwnGroup(target);
await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch {Icon = null}));
await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch {Icon = null}));
await ctx.Reply($"{Emojis.Success} Group icon cleared.");
}
@ -178,7 +180,7 @@ namespace PluralKit.Bot
throw Errors.InvalidUrl(img.Url);
await AvatarUtils.VerifyAvatarOrThrow(img.Url);
await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch {Icon = img.Url}));
await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch {Icon = img.Url}));
var msg = img.Source switch
{
@ -282,7 +284,7 @@ namespace PluralKit.Bot
var system = await GetGroupSystem(ctx, target, conn);
var pctx = ctx.LookupContextFor(system);
var memberCount = await conn.QueryGroupMemberCount(target.Id, PrivacyLevel.Public);
var memberCount = await _repo.GetGroupMemberCount(conn, target.Id, PrivacyLevel.Public);
var nameField = target.Name;
if (system.Name != null)
@ -333,7 +335,7 @@ namespace PluralKit.Bot
.Select(m => m.Id)
.Distinct()
.ToList();
await conn.AddMembersToGroup(target.Id, membersNotInGroup);
await _repo.AddMembersToGroup(conn, target.Id, membersNotInGroup);
if (membersNotInGroup.Count == members.Count)
await ctx.Reply($"{Emojis.Success} {"members".ToQuantity(membersNotInGroup.Count)} added to group.");
@ -347,7 +349,7 @@ namespace PluralKit.Bot
.Select(m => m.Id)
.Distinct()
.ToList();
await conn.RemoveMembersFromGroup(target.Id, membersInGroup);
await _repo.RemoveMembersFromGroup(conn, target.Id, membersInGroup);
if (membersInGroup.Count == members.Count)
await ctx.Reply($"{Emojis.Success} {"members".ToQuantity(membersInGroup.Count)} removed from group.");
@ -422,7 +424,7 @@ namespace PluralKit.Bot
async Task SetAll(PrivacyLevel level)
{
await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch().WithAllPrivacy(level)));
await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch().WithAllPrivacy(level)));
if (level == PrivacyLevel.Private)
await ctx.Reply($"{Emojis.Success} All {target.Name}'s privacy settings have been set to **{level.LevelName()}**. Other accounts will now see nothing on the group card.");
@ -432,7 +434,7 @@ namespace PluralKit.Bot
async Task SetLevel(GroupPrivacySubject subject, PrivacyLevel level)
{
await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch().WithPrivacy(subject, level)));
await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch().WithPrivacy(subject, level)));
var subjectName = subject switch
{
@ -475,17 +477,17 @@ namespace PluralKit.Bot
if (!await ctx.ConfirmWithReply(target.Hid))
throw new PKError($"Group deletion cancelled. Note that you must reply with your group ID (`{target.Hid}`) *verbatim*.");
await _db.Execute(conn => conn.DeleteGroup(target.Id));
await _db.Execute(conn => _repo.DeleteGroup(conn, target.Id));
await ctx.Reply($"{Emojis.Success} Group deleted.");
}
private static async Task<PKSystem> GetGroupSystem(Context ctx, PKGroup target, IPKConnection conn)
private async Task<PKSystem> GetGroupSystem(Context ctx, PKGroup target, IPKConnection conn)
{
var system = ctx.System;
if (system?.Id == target.System)
return system;
return await conn.QuerySystem(target.System)!;
return await _repo.GetSystem(conn, target.System)!;
}
}
}

View File

@ -15,7 +15,7 @@ namespace PluralKit.Bot
{
public class ImportExport
{
private DataFileService _dataFiles;
private readonly DataFileService _dataFiles;
public ImportExport(DataFileService dataFiles)
{
_dataFiles = dataFiles;

View File

@ -8,15 +8,15 @@ namespace PluralKit.Bot
{
public class Member
{
private IDataStore _data;
private IDatabase _db;
private EmbedService _embeds;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly EmbedService _embeds;
public Member(IDataStore data, EmbedService embeds, IDatabase db)
public Member(EmbedService embeds, IDatabase db, ModelRepository repo)
{
_data = data;
_embeds = embeds;
_db = db;
_repo = repo;
}
public async Task NewMember(Context ctx) {
@ -27,7 +27,7 @@ namespace PluralKit.Bot
if (memberName.Length > Limits.MaxMemberNameLength) throw Errors.MemberNameTooLongError(memberName.Length);
// Warn if there's already a member by this name
var existingMember = await _data.GetMemberByName(ctx.System, memberName);
var existingMember = await _db.Execute(c => _repo.GetMemberByName(c, ctx.System.Id, memberName));
if (existingMember != null) {
var msg = $"{Emojis.Warn} You already have a member in your system with the name \"{existingMember.NameFor(ctx)}\" (with ID `{existingMember.Hid}`). Do you want to create another member with the same name?";
if (!await ctx.PromptYesNo(msg)) throw new PKError("Member creation cancelled.");
@ -36,12 +36,12 @@ namespace PluralKit.Bot
await using var conn = await _db.Obtain();
// Enforce per-system member limit
var memberCount = await conn.GetSystemMemberCount(ctx.System.Id);
var memberCount = await _repo.GetSystemMemberCount(conn, ctx.System.Id);
if (memberCount >= Limits.MaxMemberCount)
throw Errors.MemberLimitReachedError;
// Create the member
var member = await conn.CreateMember(ctx.System.Id, memberName);
var member = await _repo.CreateMember(conn, ctx.System.Id, memberName);
memberCount++;
// Send confirmation and space hint
@ -62,10 +62,14 @@ namespace PluralKit.Bot
//Maybe move this somewhere else in the file structure since it doesn't need to get created at every command
// TODO: don't buffer these, find something else to do ig
List<PKMember> members;
if (ctx.MatchFlag("all", "a")) members = await _data.GetSystemMembers(ctx.System).ToListAsync();
else members = await _data.GetSystemMembers(ctx.System).Where(m => m.MemberVisibility == PrivacyLevel.Public).ToListAsync();
var members = await _db.Execute(c =>
{
if (ctx.MatchFlag("all", "a"))
return _repo.GetSystemMembers(c, ctx.System.Id);
return _repo.GetSystemMembers(c, ctx.System.Id)
.Where(m => m.MemberVisibility == PrivacyLevel.Public);
}).ToListAsync();
if (members == null || !members.Any())
throw Errors.NoMembersError;
@ -75,8 +79,7 @@ namespace PluralKit.Bot
public async Task ViewMember(Context ctx, PKMember target)
{
var system = await _db.Execute(c => c.QuerySystem(target.System));
var system = await _db.Execute(c => _repo.GetSystem(c, target.System));
await ctx.Reply(embed: await _embeds.CreateMemberEmbed(system, target, ctx.Guild, ctx.LookupContextFor(system)));
}
}

View File

@ -11,10 +11,12 @@ namespace PluralKit.Bot
public class MemberAvatar
{
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public MemberAvatar(IDatabase db)
public MemberAvatar(IDatabase db, ModelRepository repo)
{
_db = db;
_repo = repo;
}
private async Task AvatarClear(AvatarLocation location, Context ctx, PKMember target, MemberGuildSettings? mgs)
@ -67,14 +69,14 @@ namespace PluralKit.Bot
public async Task ServerAvatar(Context ctx, PKMember target)
{
ctx.CheckGuildContext();
var guildData = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id));
var guildData = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id));
await AvatarCommandTree(AvatarLocation.Server, ctx, target, guildData);
}
public async Task Avatar(Context ctx, PKMember target)
{
var guildData = ctx.Guild != null ?
await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id))
await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id))
: null;
await AvatarCommandTree(AvatarLocation.Member, ctx, target, guildData);
@ -150,10 +152,10 @@ namespace PluralKit.Bot
{
case AvatarLocation.Server:
var serverPatch = new MemberGuildPatch { AvatarUrl = url };
return _db.Execute(c => c.UpsertMemberGuild(target.Id, ctx.Guild.Id, serverPatch));
return _db.Execute(c => _repo.UpsertMemberGuild(c, target.Id, ctx.Guild.Id, serverPatch));
case AvatarLocation.Member:
var memberPatch = new MemberPatch { AvatarUrl = url };
return _db.Execute(c => c.UpdateMember(target.Id, memberPatch));
return _db.Execute(c => _repo.UpdateMember(c, target.Id, memberPatch));
default:
throw new ArgumentOutOfRangeException($"Unknown avatar location {location}");
}

View File

@ -15,13 +15,13 @@ namespace PluralKit.Bot
{
public class MemberEdit
{
private readonly IDataStore _data;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public MemberEdit(IDataStore data, IDatabase db)
public MemberEdit(IDatabase db, ModelRepository repo)
{
_data = data;
_db = db;
_repo = repo;
}
public async Task Name(Context ctx, PKMember target) {
@ -35,7 +35,7 @@ namespace PluralKit.Bot
if (newName.Length > Limits.MaxMemberNameLength) throw Errors.MemberNameTooLongError(newName.Length);
// Warn if there's already a member by this name
var existingMember = await _data.GetMemberByName(ctx.System, newName);
var existingMember = await _db.Execute(conn => _repo.GetMemberByName(conn, ctx.System.Id, newName));
if (existingMember != null && existingMember.Id != target.Id)
{
var msg = $"{Emojis.Warn} You already have a member in your system with the name \"{existingMember.NameFor(ctx)}\" (`{existingMember.Hid}`). Do you want to rename this member to that name too?";
@ -44,7 +44,7 @@ namespace PluralKit.Bot
// Rename the member
var patch = new MemberPatch {Name = Partial<string>.Present(newName)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member renamed.");
if (newName.Contains(" ")) await ctx.Reply($"{Emojis.Note} Note that this member's name now contains spaces. You will need to surround it with \"double quotes\" when using commands referring to it.");
@ -52,7 +52,7 @@ namespace PluralKit.Bot
if (ctx.Guild != null)
{
var memberGuildConfig = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id));
var memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id));
if (memberGuildConfig.DisplayName != null)
await ctx.Reply($"{Emojis.Note} Note that this member has a server name set ({memberGuildConfig.DisplayName}) in this server ({ctx.Guild.Name}), and will be proxied using that name here.");
}
@ -69,7 +69,7 @@ namespace PluralKit.Bot
CheckEditMemberPermission(ctx, target);
var patch = new MemberPatch {Description = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member description cleared.");
}
else if (!ctx.HasNext())
@ -100,7 +100,7 @@ namespace PluralKit.Bot
throw Errors.DescriptionTooLongError(description.Length);
var patch = new MemberPatch {Description = Partial<string>.Present(description)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member description changed.");
}
@ -111,7 +111,7 @@ namespace PluralKit.Bot
{
CheckEditMemberPermission(ctx, target);
var patch = new MemberPatch {Pronouns = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member pronouns cleared.");
}
else if (!ctx.HasNext())
@ -136,7 +136,7 @@ namespace PluralKit.Bot
throw Errors.MemberPronounsTooLongError(pronouns.Length);
var patch = new MemberPatch {Pronouns = Partial<string>.Present(pronouns)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member pronouns changed.");
}
@ -150,7 +150,7 @@ namespace PluralKit.Bot
CheckEditMemberPermission(ctx, target);
var patch = new MemberPatch {Color = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member color cleared.");
}
@ -182,7 +182,7 @@ namespace PluralKit.Bot
if (!Regex.IsMatch(color, "^[0-9a-fA-F]{6}$")) throw Errors.InvalidColorError(color);
var patch = new MemberPatch {Color = Partial<string>.Present(color.ToLowerInvariant())};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply(embed: new DiscordEmbedBuilder()
.WithTitle($"{Emojis.Success} Member color changed.")
@ -198,7 +198,7 @@ namespace PluralKit.Bot
CheckEditMemberPermission(ctx, target);
var patch = new MemberPatch {Birthday = Partial<LocalDate?>.Null()};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member birthdate cleared.");
}
@ -223,7 +223,7 @@ namespace PluralKit.Bot
if (birthday == null) throw Errors.BirthdayParseError(birthdayStr);
var patch = new MemberPatch {Birthday = Partial<LocalDate?>.Present(birthday)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member birthdate changed.");
}
@ -235,7 +235,7 @@ namespace PluralKit.Bot
MemberGuildSettings memberGuildConfig = null;
if (ctx.Guild != null)
memberGuildConfig = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id));
memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id));
var eb = new DiscordEmbedBuilder().WithTitle($"Member names")
.WithFooter($"Member ID: {target.Hid} | Active name in bold. Server name overrides display name, which overrides base name.");
@ -271,7 +271,7 @@ namespace PluralKit.Bot
var successStr = text;
if (ctx.Guild != null)
{
var memberGuildConfig = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id));
var memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id));
if (memberGuildConfig.DisplayName != null)
successStr += $" However, this member has a server name set in this server ({ctx.Guild.Name}), and will be proxied using that name, \"{memberGuildConfig.DisplayName}\", here.";
}
@ -284,7 +284,7 @@ namespace PluralKit.Bot
CheckEditMemberPermission(ctx, target);
var patch = new MemberPatch {DisplayName = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await PrintSuccess($"{Emojis.Success} Member display name cleared. This member will now be proxied using their member name \"{target.NameFor(ctx)}\".");
}
@ -303,7 +303,7 @@ namespace PluralKit.Bot
var newDisplayName = ctx.RemainderOrNull();
var patch = new MemberPatch {DisplayName = Partial<string>.Present(newDisplayName)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await PrintSuccess($"{Emojis.Success} Member display name changed. This member will now be proxied using the name \"{newDisplayName}\".");
}
@ -318,7 +318,7 @@ namespace PluralKit.Bot
CheckEditMemberPermission(ctx, target);
var patch = new MemberGuildPatch {DisplayName = null};
await _db.Execute(conn => conn.UpsertMemberGuild(target.Id, ctx.Guild.Id, patch));
await _db.Execute(conn => _repo.UpsertMemberGuild(conn, target.Id, ctx.Guild.Id, patch));
if (target.DisplayName != null)
await ctx.Reply($"{Emojis.Success} Member server name cleared. This member will now be proxied using their global display name \"{target.DisplayName}\" in this server ({ctx.Guild.Name}).");
@ -340,7 +340,7 @@ namespace PluralKit.Bot
var newServerName = ctx.RemainderOrNull();
var patch = new MemberGuildPatch {DisplayName = newServerName};
await _db.Execute(conn => conn.UpsertMemberGuild(target.Id, ctx.Guild.Id, patch));
await _db.Execute(conn => _repo.UpsertMemberGuild(conn, target.Id, ctx.Guild.Id, patch));
await ctx.Reply($"{Emojis.Success} Member server name changed. This member will now be proxied using the name \"{newServerName}\" in this server ({ctx.Guild.Name}).");
}
@ -365,7 +365,7 @@ namespace PluralKit.Bot
};
var patch = new MemberPatch {KeepProxy = Partial<bool>.Present(newValue)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
if (newValue)
await ctx.Reply($"{Emojis.Success} Member proxy tags will now be included in the resulting message when proxying.");
@ -398,11 +398,11 @@ namespace PluralKit.Bot
// Get guild settings (mostly for warnings and such)
MemberGuildSettings guildSettings = null;
if (ctx.Guild != null)
guildSettings = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id));
guildSettings = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id));
async Task SetAll(PrivacyLevel level)
{
await _db.Execute(c => c.UpdateMember(target.Id, new MemberPatch().WithAllPrivacy(level)));
await _db.Execute(c => _repo.UpdateMember(c, target.Id, new MemberPatch().WithAllPrivacy(level)));
if (level == PrivacyLevel.Private)
await ctx.Reply($"{Emojis.Success} All {target.NameFor(ctx)}'s privacy settings have been set to **{level.LevelName()}**. Other accounts will now see nothing on the member card.");
@ -412,7 +412,7 @@ namespace PluralKit.Bot
async Task SetLevel(MemberPrivacySubject subject, PrivacyLevel level)
{
await _db.Execute(c => c.UpdateMember(target.Id, new MemberPatch().WithPrivacy(subject, level)));
await _db.Execute(c => _repo.UpdateMember(c, target.Id, new MemberPatch().WithPrivacy(subject, level)));
var subjectName = subject switch
{
@ -472,7 +472,7 @@ namespace PluralKit.Bot
await ctx.Reply($"{Emojis.Warn} Are you sure you want to delete \"{target.NameFor(ctx)}\"? If so, reply to this message with the member's ID (`{target.Hid}`). __***This cannot be undone!***__");
if (!await ctx.ConfirmWithReply(target.Hid)) throw Errors.MemberDeleteCancelled;
await _db.Execute(conn => conn.DeleteMember(target.Id));
await _db.Execute(conn => _repo.DeleteMember(conn, target.Id));
await ctx.Reply($"{Emojis.Success} Member deleted.");
}

View File

@ -10,10 +10,12 @@ namespace PluralKit.Bot
public class MemberProxy
{
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public MemberProxy(IDatabase db)
public MemberProxy(IDatabase db, ModelRepository repo)
{
_db = db;
_repo = repo;
}
public async Task Proxy(Context ctx, PKMember target)
@ -55,7 +57,7 @@ namespace PluralKit.Bot
}
var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(new ProxyTag[0])};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Proxy tags cleared.");
}
@ -83,7 +85,7 @@ namespace PluralKit.Bot
var newTags = target.ProxyTags.ToList();
newTags.Add(tagToAdd);
var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(newTags.ToArray())};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Added proxy tags {tagToAdd.ProxyString.AsCode()}.");
}
@ -100,7 +102,7 @@ namespace PluralKit.Bot
var newTags = target.ProxyTags.ToList();
newTags.Remove(tagToRemove);
var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(newTags.ToArray())};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Removed proxy tags {tagToRemove.ProxyString.AsCode()}.");
}
@ -124,7 +126,7 @@ namespace PluralKit.Bot
var newTags = new[] {requestedTag};
var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(newTags)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch));
await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member proxy tags set to {requestedTag.ProxyString.AsCode()}.");
}

View File

@ -18,21 +18,23 @@ using DSharpPlus.Entities;
namespace PluralKit.Bot {
public class Misc
{
private BotConfig _botConfig;
private IMetrics _metrics;
private CpuStatService _cpu;
private ShardInfoService _shards;
private IDataStore _data;
private EmbedService _embeds;
private readonly BotConfig _botConfig;
private readonly IMetrics _metrics;
private readonly CpuStatService _cpu;
private readonly ShardInfoService _shards;
private readonly EmbedService _embeds;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public Misc(BotConfig botConfig, IMetrics metrics, CpuStatService cpu, ShardInfoService shards, IDataStore data, EmbedService embeds)
public Misc(BotConfig botConfig, IMetrics metrics, CpuStatService cpu, ShardInfoService shards, EmbedService embeds, ModelRepository repo, IDatabase db)
{
_botConfig = botConfig;
_metrics = metrics;
_cpu = cpu;
_shards = shards;
_data = data;
_embeds = embeds;
_repo = repo;
_db = db;
}
public async Task Invite(Context ctx)
@ -198,7 +200,7 @@ namespace PluralKit.Bot {
messageId = ulong.Parse(match.Groups[1].Value);
else throw new PKSyntaxError($"Could not parse {word.AsCode()} as a message ID or link.");
var message = await _data.GetMessage(messageId);
var message = await _db.Execute(c => _repo.GetMessage(c, messageId));
if (message == null) throw Errors.MessageNotFound(messageId);
await ctx.Reply(embed: await _embeds.CreateMessageInfoEmbed(ctx.Shard, message));

View File

@ -12,12 +12,14 @@ namespace PluralKit.Bot
{
public class ServerConfig
{
private IDatabase _db;
private LoggerCleanService _cleanService;
public ServerConfig(LoggerCleanService cleanService, IDatabase db)
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly LoggerCleanService _cleanService;
public ServerConfig(LoggerCleanService cleanService, IDatabase db, ModelRepository repo)
{
_cleanService = cleanService;
_db = db;
_repo = repo;
}
public async Task SetLogChannel(Context ctx)
@ -32,7 +34,7 @@ namespace PluralKit.Bot
if (channel == null || channel.GuildId != ctx.Guild.Id) throw Errors.ChannelNotFound(channelString);
var patch = new GuildPatch {LogChannel = channel?.Id};
await _db.Execute(conn => conn.UpsertGuild(ctx.Guild.Id, patch));
await _db.Execute(conn => _repo.UpsertGuild(conn, ctx.Guild.Id, patch));
if (channel != null)
await ctx.Reply($"{Emojis.Success} Proxy logging channel set to #{channel.Name}.");
@ -59,7 +61,7 @@ namespace PluralKit.Bot
ulong? logChannel = null;
await using (var conn = await _db.Obtain())
{
var config = await conn.QueryOrInsertGuildConfig(ctx.Guild.Id);
var config = await _repo.GetGuild(conn, ctx.Guild.Id);
logChannel = config.LogChannel;
var blacklist = config.LogBlacklist.ToHashSet();
if (enable)
@ -68,7 +70,7 @@ namespace PluralKit.Bot
blacklist.UnionWith(affectedChannels.Select(c => c.Id));
var patch = new GuildPatch {LogBlacklist = blacklist.ToArray()};
await conn.UpsertGuild(ctx.Guild.Id, patch);
await _repo.UpsertGuild(conn, ctx.Guild.Id, patch);
}
await ctx.Reply(
@ -80,7 +82,7 @@ namespace PluralKit.Bot
{
ctx.CheckGuildContext().CheckAuthorPermission(Permissions.ManageGuild, "Manage Server");
var blacklist = await _db.Execute(c => c.QueryOrInsertGuildConfig(ctx.Guild.Id));
var blacklist = await _db.Execute(c => _repo.GetGuild(c, ctx.Guild.Id));
// Resolve all channels from the cache and order by position
var channels = blacklist.Blacklist
@ -139,7 +141,7 @@ namespace PluralKit.Bot
await using (var conn = await _db.Obtain())
{
var guild = await conn.QueryOrInsertGuildConfig(ctx.Guild.Id);
var guild = await _repo.GetGuild(conn, ctx.Guild.Id);
var blacklist = guild.Blacklist.ToHashSet();
if (shouldAdd)
blacklist.UnionWith(affectedChannels.Select(c => c.Id));
@ -147,7 +149,7 @@ namespace PluralKit.Bot
blacklist.ExceptWith(affectedChannels.Select(c => c.Id));
var patch = new GuildPatch {Blacklist = blacklist.ToArray()};
await conn.UpsertGuild(ctx.Guild.Id, patch);
await _repo.UpsertGuild(conn, ctx.Guild.Id, patch);
}
await ctx.Reply($"{Emojis.Success} Channels {(shouldAdd ? "added to" : "removed from")} the proxy blacklist.");
@ -170,7 +172,7 @@ namespace PluralKit.Bot
.WithTitle("Log cleanup settings")
.AddField("Supported bots", botList);
var guildCfg = await _db.Execute(c => c.QueryOrInsertGuildConfig(ctx.Guild.Id));
var guildCfg = await _db.Execute(c => _repo.GetGuild(c, ctx.Guild.Id));
if (guildCfg.LogCleanupEnabled)
eb.WithDescription("Log cleanup is currently **on** for this server. To disable it, type `pk;logclean off`.");
else
@ -180,7 +182,7 @@ namespace PluralKit.Bot
}
var patch = new GuildPatch {LogCleanupEnabled = newValue};
await _db.Execute(conn => conn.UpsertGuild(ctx.Guild.Id, patch));
await _db.Execute(conn => _repo.UpsertGuild(conn, ctx.Guild.Id, patch));
if (newValue)
await ctx.Reply($"{Emojis.Success} Log cleanup has been **enabled** for this server. Messages deleted by PluralKit will now be cleaned up from logging channels managed by the following bots:\n- **{botList}**\n\n{Emojis.Note} Make sure PluralKit has the **Manage Messages** permission in the channels in question.\n{Emojis.Note} Also, make sure to blacklist the logging channel itself from the bots in question to prevent conflicts.");

View File

@ -13,11 +13,13 @@ namespace PluralKit.Bot
{
public class Switch
{
private IDataStore _data;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public Switch(IDataStore data)
public Switch(IDatabase db, ModelRepository repo)
{
_data = data;
_db = db;
_repo = repo;
}
public async Task SwitchDo(Context ctx)
@ -42,16 +44,17 @@ namespace PluralKit.Bot
if (members.Select(m => m.Id).Distinct().Count() != members.Count) throw Errors.DuplicateSwitchMembers;
// Find the last switch and its members if applicable
var lastSwitch = await _data.GetLatestSwitch(ctx.System.Id);
await using var conn = await _db.Obtain();
var lastSwitch = await _repo.GetLatestSwitch(conn, ctx.System.Id);
if (lastSwitch != null)
{
var lastSwitchMembers = _data.GetSwitchMembers(lastSwitch);
var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastSwitch.Id);
// Make sure the requested switch isn't identical to the last one
if (await lastSwitchMembers.Select(m => m.Id).SequenceEqualAsync(members.Select(m => m.Id).ToAsyncEnumerable()))
throw Errors.SameSwitch(members, ctx.LookupContextFor(ctx.System));
}
await _data.AddSwitch(ctx.System.Id, members);
await _repo.AddSwitch(conn, ctx.System.Id, members.Select(m => m.Id).ToList());
if (members.Count == 0)
await ctx.Reply($"{Emojis.Success} Switch-out registered.");
@ -68,12 +71,14 @@ namespace PluralKit.Bot
var result = DateUtils.ParseDateTime(timeToMove, true, tz);
if (result == null) throw Errors.InvalidDateTime(timeToMove);
await using var conn = await _db.Obtain();
var time = result.Value;
if (time.ToInstant() > SystemClock.Instance.GetCurrentInstant()) throw Errors.SwitchTimeInFuture;
// Fetch the last two switches for the system to do bounds checking on
var lastTwoSwitches = await _data.GetSwitches(ctx.System.Id).Take(2).ToListAsync();
var lastTwoSwitches = await _repo.GetSwitches(conn, ctx.System.Id).Take(2).ToListAsync();
// If we don't have a switch to move, don't bother
if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches;
@ -87,7 +92,7 @@ namespace PluralKit.Bot
// Now we can actually do the move, yay!
// But, we do a prompt to confirm.
var lastSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[0]);
var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[0].Id);
var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync());
var lastSwitchTimeStr = lastTwoSwitches[0].Timestamp.FormatZoned(ctx.System);
var lastSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp).FormatDuration();
@ -99,7 +104,7 @@ namespace PluralKit.Bot
if (!await ctx.PromptYesNo(msg)) throw Errors.SwitchMoveCancelled;
// aaaand *now* we do the move
await _data.MoveSwitch(lastTwoSwitches[0], time.ToInstant());
await _repo.MoveSwitch(conn, lastTwoSwitches[0].Id, time.ToInstant());
await ctx.Reply($"{Emojis.Success} Switch moved.");
}
@ -113,16 +118,18 @@ namespace PluralKit.Bot
var purgeMsg = $"{Emojis.Warn} This will delete *all registered switches* in your system. Are you sure you want to proceed?";
if (!await ctx.PromptYesNo(purgeMsg))
throw Errors.GenericCancelled();
await _data.DeleteAllSwitches(ctx.System);
await _db.Execute(c => _repo.DeleteAllSwitches(c, ctx.System.Id));
await ctx.Reply($"{Emojis.Success} Cleared system switches!");
return;
}
await using var conn = await _db.Obtain();
// Fetch the last two switches for the system to do bounds checking on
var lastTwoSwitches = await _data.GetSwitches(ctx.System.Id).Take(2).ToListAsync();
var lastTwoSwitches = await _repo.GetSwitches(conn, ctx.System.Id).Take(2).ToListAsync();
if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches;
var lastSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[0]);
var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[0].Id);
var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync());
var lastSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp).FormatDuration();
@ -133,14 +140,14 @@ namespace PluralKit.Bot
}
else
{
var secondSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[1]);
var secondSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[1].Id);
var secondSwitchMemberStr = string.Join(", ", await secondSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync());
var secondSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[1].Timestamp).FormatDuration();
msg = $"{Emojis.Warn} This will delete the latest switch ({lastSwitchMemberStr}, {lastSwitchDeltaStr} ago). The next latest switch is {secondSwitchMemberStr} ({secondSwitchDeltaStr} ago). Is this okay?";
}
if (!await ctx.PromptYesNo(msg)) throw Errors.SwitchDeleteCancelled;
await _data.DeleteSwitch(lastTwoSwitches[0]);
await _repo.DeleteSwitch(conn, lastTwoSwitches[0].Id);
await ctx.Reply($"{Emojis.Success} Switch deleted.");
}

View File

@ -6,13 +6,15 @@ namespace PluralKit.Bot
{
public class System
{
private IDataStore _data;
private EmbedService _embeds;
private readonly EmbedService _embeds;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public System(EmbedService embeds, IDataStore data)
public System(EmbedService embeds, IDatabase db, ModelRepository repo)
{
_embeds = embeds;
_data = data;
_db = db;
_repo = repo;
}
public async Task Query(Context ctx, PKSystem system) {
@ -28,9 +30,15 @@ namespace PluralKit.Bot
var systemName = ctx.RemainderOrNull();
if (systemName != null && systemName.Length > Limits.MaxSystemNameLength)
throw Errors.SystemNameTooLongError(systemName.Length);
var system = _db.Execute(async c =>
{
var system = await _repo.CreateSystem(c, systemName);
await _repo.AddAccount(c, system.Id, ctx.Author.Id);
return system;
});
var system = await _data.CreateSystem(systemName);
await _data.AddAccount(system, ctx.Author.Id);
// TODO: better message, perhaps embed like in groups?
await ctx.Reply($"{Emojis.Success} Your system has been created. Type `pk;system` to view it, and type `pk;system help` for more information about commands you can use now. Now that you have that set up, check out the getting started guide on setting up members and proxies: <https://pluralkit.me/start>");
}
}

View File

@ -19,15 +19,13 @@ namespace PluralKit.Bot
{
public class SystemEdit
{
private IDataStore _data;
private IDatabase _db;
private EmbedService _embeds;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public SystemEdit(IDataStore data, EmbedService embeds, IDatabase db)
public SystemEdit(IDatabase db, ModelRepository repo)
{
_data = data;
_embeds = embeds;
_db = db;
_repo = repo;
}
public async Task Name(Context ctx)
@ -37,7 +35,7 @@ namespace PluralKit.Bot
if (ctx.MatchClear())
{
var clearPatch = new SystemPatch {Name = null};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, clearPatch));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, clearPatch));
await ctx.Reply($"{Emojis.Success} System name cleared.");
return;
@ -57,7 +55,7 @@ namespace PluralKit.Bot
throw Errors.SystemNameTooLongError(newSystemName.Length);
var patch = new SystemPatch {Name = newSystemName};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"{Emojis.Success} System name changed.");
}
@ -68,7 +66,7 @@ namespace PluralKit.Bot
if (ctx.MatchClear())
{
var patch = new SystemPatch {Description = null};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"{Emojis.Success} System description cleared.");
return;
@ -93,7 +91,7 @@ namespace PluralKit.Bot
if (newDescription.Length > Limits.MaxDescriptionLength) throw Errors.DescriptionTooLongError(newDescription.Length);
var patch = new SystemPatch {Description = newDescription};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"{Emojis.Success} System description changed.");
}
@ -106,7 +104,7 @@ namespace PluralKit.Bot
if (ctx.MatchClear())
{
var patch = new SystemPatch {Tag = null};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"{Emojis.Success} System tag cleared.");
} else if (!ctx.HasNext(skipFlags: false))
@ -124,7 +122,7 @@ namespace PluralKit.Bot
throw Errors.SystemNameTooLongError(newTag.Length);
var patch = new SystemPatch {Tag = newTag};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"{Emojis.Success} System tag changed. Member names will now end with {newTag.AsCode()} when proxied.");
}
@ -136,7 +134,7 @@ namespace PluralKit.Bot
async Task ClearIcon()
{
await _db.Execute(c => c.UpdateSystem(ctx.System.Id, new SystemPatch {AvatarUrl = null}));
await _db.Execute(c => _repo.UpdateSystem(c, ctx.System.Id, new SystemPatch {AvatarUrl = null}));
await ctx.Reply($"{Emojis.Success} System icon cleared.");
}
@ -146,7 +144,7 @@ namespace PluralKit.Bot
throw Errors.InvalidUrl(img.Url);
await AvatarUtils.VerifyAvatarOrThrow(img.Url);
await _db.Execute(c => c.UpdateSystem(ctx.System.Id, new SystemPatch {AvatarUrl = img.Url}));
await _db.Execute(c => _repo.UpdateSystem(c, ctx.System.Id, new SystemPatch {AvatarUrl = img.Url}));
var msg = img.Source switch
{
@ -192,7 +190,7 @@ namespace PluralKit.Bot
if (!await ctx.ConfirmWithReply(ctx.System.Hid))
throw new PKError($"System deletion cancelled. Note that you must reply with your system ID (`{ctx.System.Hid}`) *verbatim*.");
await _db.Execute(conn => conn.DeleteSystem(ctx.System.Id));
await _db.Execute(conn => _repo.DeleteSystem(conn, ctx.System.Id));
await ctx.Reply($"{Emojis.Success} System deleted.");
}
@ -200,7 +198,7 @@ namespace PluralKit.Bot
public async Task SystemProxy(Context ctx)
{
ctx.CheckSystem().CheckGuildContext();
var gs = await _db.Execute(c => c.QueryOrInsertSystemGuildConfig(ctx.Guild.Id, ctx.System.Id));
var gs = await _db.Execute(c => _repo.GetSystemGuild(c, ctx.Guild.Id, ctx.System.Id));
bool newValue;
if (ctx.Match("on", "enabled", "true", "yes")) newValue = true;
@ -216,7 +214,7 @@ namespace PluralKit.Bot
}
var patch = new SystemGuildPatch {ProxyEnabled = newValue};
await _db.Execute(conn => conn.UpsertSystemGuild(ctx.System.Id, ctx.Guild.Id, patch));
await _db.Execute(conn => _repo.UpsertSystemGuild(conn, ctx.System.Id, ctx.Guild.Id, patch));
if (newValue)
await ctx.Reply($"Message proxying in this server ({ctx.Guild.Name.EscapeMarkdown()}) is now **enabled** for your system.");
@ -231,7 +229,7 @@ namespace PluralKit.Bot
if (ctx.MatchClear())
{
var clearPatch = new SystemPatch {UiTz = "UTC"};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, clearPatch));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, clearPatch));
await ctx.Reply($"{Emojis.Success} System time zone cleared (set to UTC).");
return;
@ -253,7 +251,7 @@ namespace PluralKit.Bot
if (!await ctx.PromptYesNo(msg)) throw Errors.TimezoneChangeCancelled;
var patch = new SystemPatch {UiTz = zone.Id};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"System time zone changed to **{zone.Id}**.");
}
@ -277,7 +275,7 @@ namespace PluralKit.Bot
async Task SetLevel(SystemPrivacySubject subject, PrivacyLevel level)
{
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, new SystemPatch().WithPrivacy(subject, level)));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, new SystemPatch().WithPrivacy(subject, level)));
var levelExplanation = level switch
{
@ -302,7 +300,7 @@ namespace PluralKit.Bot
async Task SetAll(PrivacyLevel level)
{
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, new SystemPatch().WithAllPrivacy(level)));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, new SystemPatch().WithAllPrivacy(level)));
var msg = level switch
{
@ -334,13 +332,13 @@ namespace PluralKit.Bot
else {
if (ctx.Match("on", "enable")) {
var patch = new SystemPatch {PingsEnabled = true};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply("Reaction pings have now been enabled.");
}
if (ctx.Match("off", "disable")) {
var patch = new SystemPatch {PingsEnabled = false};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch));
await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply("Reaction pings have now been disabled.");
}

View File

@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
@ -10,19 +11,21 @@ namespace PluralKit.Bot
{
public class SystemFront
{
private IDataStore _data;
private EmbedService _embeds;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly EmbedService _embeds;
public SystemFront(IDataStore data, EmbedService embeds)
public SystemFront(EmbedService embeds, IDatabase db, ModelRepository repo)
{
_data = data;
_embeds = embeds;
_db = db;
_repo = repo;
}
struct FrontHistoryEntry
{
public Instant? LastTime;
public PKSwitch ThisSwitch;
public readonly Instant? LastTime;
public readonly PKSwitch ThisSwitch;
public FrontHistoryEntry(Instant? lastTime, PKSwitch thisSwitch)
{
@ -35,8 +38,10 @@ namespace PluralKit.Bot
{
if (system == null) throw Errors.NoSystemError;
ctx.CheckSystemPrivacy(system, system.FrontPrivacy);
await using var conn = await _db.Obtain();
var sw = await _data.GetLatestSwitch(system.Id);
var sw = await _repo.GetLatestSwitch(conn, system.Id);
if (sw == null) throw Errors.NoRegisteredSwitches;
await ctx.Reply(embed: await _embeds.CreateFronterEmbed(sw, system.Zone, ctx.LookupContextFor(system)));
@ -47,11 +52,16 @@ namespace PluralKit.Bot
if (system == null) throw Errors.NoSystemError;
ctx.CheckSystemPrivacy(system, system.FrontHistoryPrivacy);
var sws = _data.GetSwitches(system.Id)
.Scan(new FrontHistoryEntry(null, null), (lastEntry, newSwitch) => new FrontHistoryEntry(lastEntry.ThisSwitch?.Timestamp, newSwitch));
var totalSwitches = await _data.GetSwitchCount(system);
if (totalSwitches == 0) throw Errors.NoRegisteredSwitches;
// Gotta be careful here: if we dispose of the connection while the IAE is alive, boom
await using var conn = await _db.Obtain();
var totalSwitches = await _repo.GetSwitchCount(conn, system.Id);
if (totalSwitches == 0) throw Errors.NoRegisteredSwitches;
var sws = _repo.GetSwitches(conn, system.Id)
.Scan(new FrontHistoryEntry(null, null),
(lastEntry, newSwitch) => new FrontHistoryEntry(lastEntry.ThisSwitch?.Timestamp, newSwitch));
var embedTitle = system.Name != null ? $"Front history of {system.Name} (`{system.Hid}`)" : $"Front history of `{system.Hid}`";
await ctx.Paginate(
@ -66,8 +76,11 @@ namespace PluralKit.Bot
var lastSw = entry.LastTime;
var sw = entry.ThisSwitch;
// Fetch member list and format
var members = await _data.GetSwitchMembers(sw).ToListAsync();
await using var conn = await _db.Obtain();
var members = await _db.Execute(c => _repo.GetSwitchMembers(c, sw.Id)).ToListAsync();
var membersStr = members.Any() ? string.Join(", ", members.Select(m => m.NameFor(ctx))) : "no fronter";
var switchSince = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp;
@ -111,8 +124,8 @@ namespace PluralKit.Bot
var rangeStart = DateUtils.ParseDateTime(durationStr, true, system.Zone);
if (rangeStart == null) throw Errors.InvalidDateTime(durationStr);
if (rangeStart.Value.ToInstant() > now) throw Errors.FrontPercentTimeInFuture;
var frontpercent = await _data.GetFrontBreakdown(system, rangeStart.Value.ToInstant(), now);
var frontpercent = await _db.Execute(c => _repo.GetFrontBreakdown(c, system.Id, rangeStart.Value.ToInstant(), now));
await ctx.Reply(embed: await _embeds.CreateFrontPercentEmbed(frontpercent, system.Zone, ctx.LookupContextFor(system)));
}
}

View File

@ -9,28 +9,34 @@ namespace PluralKit.Bot
{
public class SystemLink
{
private IDataStore _data;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
public SystemLink(IDataStore data)
public SystemLink(IDatabase db, ModelRepository repo)
{
_data = data;
_db = db;
_repo = repo;
}
public async Task LinkSystem(Context ctx)
{
ctx.CheckSystem();
await using var conn = await _db.Obtain();
var account = await ctx.MatchUser() ?? throw new PKSyntaxError("You must pass an account to link with (either ID or @mention).");
var accountIds = await _data.GetSystemAccounts(ctx.System);
if (accountIds.Contains(account.Id)) throw Errors.AccountAlreadyLinked;
var accountIds = await _repo.GetSystemAccounts(conn, ctx.System.Id);
if (accountIds.Contains(account.Id))
throw Errors.AccountAlreadyLinked;
var existingAccount = await _data.GetSystemByAccount(account.Id);
if (existingAccount != null) throw Errors.AccountInOtherSystem(existingAccount);
var existingAccount = await _repo.GetSystemByAccount(conn, account.Id);
if (existingAccount != null)
throw Errors.AccountInOtherSystem(existingAccount);
var msg = $"{account.Mention}, please confirm the link by clicking the {Emojis.Success} reaction on this message.";
var mentions = new IMention[] { new UserMention(account) };
if (!await ctx.PromptYesNo(msg, user: account, mentions: mentions)) throw Errors.MemberLinkCancelled;
await _data.AddAccount(ctx.System, account.Id);
await _repo.AddAccount(conn, ctx.System.Id, account.Id);
await ctx.Reply($"{Emojis.Success} Account linked to system.");
}
@ -38,20 +44,22 @@ namespace PluralKit.Bot
{
ctx.CheckSystem();
await using var conn = await _db.Obtain();
ulong id;
if (!ctx.HasNext())
id = ctx.Author.Id;
else if (!ctx.MatchUserRaw(out id))
throw new PKSyntaxError("You must pass an account to link with (either ID or @mention).");
var accountIds = (await _data.GetSystemAccounts(ctx.System)).ToList();
var accountIds = (await _repo.GetSystemAccounts(conn, ctx.System.Id)).ToList();
if (!accountIds.Contains(id)) throw Errors.AccountNotLinked;
if (accountIds.Count == 1) throw Errors.UnlinkingLastAccount;
var msg = $"Are you sure you want to unlink <@{id}> from your system?";
if (!await ctx.PromptYesNo(msg)) throw Errors.MemberUnlinkCancelled;
await _data.RemoveAccount(ctx.System, id);
await _repo.RemoveAccount(conn, ctx.System.Id, id);
await ctx.Reply($"{Emojis.Success} Account unlinked.");
}
}

View File

@ -10,9 +10,11 @@ namespace PluralKit.Bot
public class Token
{
private readonly IDatabase _db;
public Token(IDatabase db)
private readonly ModelRepository _repo;
public Token(IDatabase db, ModelRepository repo)
{
_db = db;
_repo = repo;
}
public async Task GetToken(Context ctx)
@ -45,7 +47,7 @@ namespace PluralKit.Bot
private async Task<string> MakeAndSetNewToken(PKSystem system)
{
var patch = new SystemPatch {Token = StringUtils.GenerateToken()};
system = await _db.Execute(conn => conn.UpdateSystem(system.Id, patch));
system = await _db.Execute(conn => _repo.UpdateSystem(conn, system.Id, patch));
return system.Token;
}

View File

@ -23,11 +23,12 @@ namespace PluralKit.Bot
private readonly ProxyService _proxy;
private readonly ILifetimeScope _services;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly BotConfig _config;
public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean,
IMetrics metrics, ProxyService proxy, DiscordShardedClient client,
CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config)
CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config, ModelRepository repo)
{
_lastMessageCache = lastMessageCache;
_loggerClean = loggerClean;
@ -38,6 +39,7 @@ namespace PluralKit.Bot
_services = services;
_db = db;
_config = config;
_repo = repo;
}
public DiscordChannel ErrorChannelFor(MessageCreateEventArgs evt) => evt.Channel;
@ -59,7 +61,7 @@ namespace PluralKit.Bot
MessageContext ctx;
await using (var conn = await _db.Obtain())
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await conn.QueryMessageContext(evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id);
ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id);
// Try each handler until we find one that succeeds
if (await TryHandleLogClean(evt, ctx))
@ -98,7 +100,7 @@ namespace PluralKit.Bot
try
{
var system = ctx.SystemId != null ? await _db.Execute(c => c.QuerySystem(ctx.SystemId.Value)) : null;
var system = ctx.SystemId != null ? await _db.Execute(c => _repo.GetSystem(c, ctx.SystemId.Value)) : null;
await _tree.ExecuteCommand(new Context(_services, evt.Client, evt.Message, cmdStart, system, ctx));
}
catch (PKError)

View File

@ -12,28 +12,29 @@ namespace PluralKit.Bot
// Double duty :)
public class MessageDeleted: IEventHandler<MessageDeleteEventArgs>, IEventHandler<MessageBulkDeleteEventArgs>
{
private readonly IDataStore _data;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly ILogger _logger;
public MessageDeleted(IDataStore data, ILogger logger)
public MessageDeleted(ILogger logger, IDatabase db, ModelRepository repo)
{
_data = data;
_db = db;
_repo = repo;
_logger = logger.ForContext<MessageDeleted>();
}
public async Task Handle(MessageDeleteEventArgs evt)
{
// Delete deleted webhook messages from the data store
// (if we don't know whether it's a webhook, delete it just to be safe)
if (!evt.Message.WebhookMessage) return;
await _data.DeleteMessage(evt.Message.Id);
// Most of the data in the given message is wrong/missing, so always delete just to be sure.
await _db.Execute(c => _repo.DeleteMessage(c, evt.Message.Id));
}
public async Task Handle(MessageBulkDeleteEventArgs evt)
{
// Same as above, but bulk
_logger.Information("Bulk deleting {Count} messages in channel {Channel}", evt.Messages.Count, evt.Channel.Id);
await _data.DeleteMessagesBulk(evt.Messages.Select(m => m.Id).ToList());
await _db.Execute(c => _repo.DeleteMessagesBulk(c, evt.Messages.Select(m => m.Id).ToList()));
}
}
}

View File

@ -14,14 +14,16 @@ namespace PluralKit.Bot
private readonly LastMessageCacheService _lastMessageCache;
private readonly ProxyService _proxy;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly IMetrics _metrics;
public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, IDatabase db, IMetrics metrics)
public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, IDatabase db, IMetrics metrics, ModelRepository repo)
{
_lastMessageCache = lastMessageCache;
_proxy = proxy;
_db = db;
_metrics = metrics;
_repo = repo;
}
public async Task Handle(MessageUpdateEventArgs evt)
@ -36,7 +38,7 @@ namespace PluralKit.Bot
MessageContext ctx;
await using (var conn = await _db.Obtain())
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await conn.QueryMessageContext(evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id);
ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id);
await _proxy.HandleIncomingMessage(evt.Message, ctx, allowAutoproxy: false);
}
}

View File

@ -13,14 +13,16 @@ namespace PluralKit.Bot
{
public class ReactionAdded: IEventHandler<MessageReactionAddEventArgs>
{
private IDataStore _data;
private EmbedService _embeds;
private ILogger _logger;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly EmbedService _embeds;
private readonly ILogger _logger;
public ReactionAdded(IDataStore data, EmbedService embeds, ILogger logger)
public ReactionAdded(EmbedService embeds, ILogger logger, IDatabase db, ModelRepository repo)
{
_data = data;
_embeds = embeds;
_db = db;
_repo = repo;
_logger = logger.ForContext<ReactionAdded>();
}
@ -42,18 +44,21 @@ namespace PluralKit.Bot
// Ignore reactions from bots (we can't DM them anyway)
if (evt.User.IsBot) return;
Task<FullMessage> GetMessage() =>
_db.Execute(c => _repo.GetMessage(c, evt.Message.Id));
FullMessage msg;
switch (evt.Emoji.Name)
{
// Message deletion
case "\u274C": // Red X
if ((msg = await _data.GetMessage(evt.Message.Id)) != null)
if ((msg = await GetMessage()) != null)
await HandleDeleteReaction(evt, msg);
break;
case "\u2753": // Red question mark
case "\u2754": // White question mark
if ((msg = await _data.GetMessage(evt.Message.Id)) != null)
if ((msg = await GetMessage()) != null)
await HandleQueryReaction(evt, msg);
break;
@ -62,7 +67,7 @@ namespace PluralKit.Bot
case "\U0001F3D3": // Ping pong paddle (lol)
case "\u23F0": // Alarm clock
case "\u2757": // Exclamation mark
if ((msg = await _data.GetMessage(evt.Message.Id)) != null)
if ((msg = await GetMessage()) != null)
await HandlePingReaction(evt, msg);
break;
}
@ -84,7 +89,7 @@ namespace PluralKit.Bot
// Message was deleted by something/someone else before we got to it
}
await _data.DeleteMessage(evt.Message.Id);
await _db.Execute(c => _repo.DeleteMessage(c, evt.Message.Id));
}
private async ValueTask HandleQueryReaction(MessageReactionAddEventArgs evt, FullMessage msg)

View File

@ -21,21 +21,21 @@ namespace PluralKit.Bot
private readonly LogChannelService _logChannel;
private readonly IDatabase _db;
private readonly IDataStore _data;
private readonly ModelRepository _repo;
private readonly ILogger _logger;
private readonly WebhookExecutorService _webhookExecutor;
private readonly ProxyMatcher _matcher;
private readonly IMetrics _metrics;
public ProxyService(LogChannelService logChannel, IDataStore data, ILogger logger,
WebhookExecutorService webhookExecutor, IDatabase db, ProxyMatcher matcher, IMetrics metrics)
public ProxyService(LogChannelService logChannel, ILogger logger,
WebhookExecutorService webhookExecutor, IDatabase db, ProxyMatcher matcher, IMetrics metrics, ModelRepository repo)
{
_logChannel = logChannel;
_data = data;
_webhookExecutor = webhookExecutor;
_db = db;
_matcher = matcher;
_metrics = metrics;
_repo = repo;
_logger = logger.ForContext<ProxyService>();
}
@ -48,7 +48,7 @@ namespace PluralKit.Bot
List<ProxyMember> members;
using (_metrics.Measure.Timer.Time(BotMetrics.ProxyMembersQueryTime))
members = (await conn.QueryProxyMembers(message.Author.Id, message.Channel.GuildId)).ToList();
members = (await _repo.GetProxyMembers(conn, message.Author.Id, message.Channel.GuildId)).ToList();
if (!_matcher.TryMatch(ctx, members, out var match, message.Content, message.Attachments.Count > 0,
allowAutoproxy)) return false;
@ -99,8 +99,17 @@ namespace PluralKit.Bot
var id = await _webhookExecutor.ExecuteWebhook(trigger.Channel, match.Member.ProxyName(ctx),
match.Member.ProxyAvatar(ctx),
content, trigger.Attachments, allowEveryone);
Task SaveMessage() => _repo.AddMessage(conn, new PKMessage
{
Channel = trigger.ChannelId,
Guild = trigger.Channel.GuildId,
Member = match.Member.Id,
Mid = id,
OriginalMid = trigger.Id,
Sender = trigger.Author.Id
});
Task SaveMessage() => _data.AddMessage(conn, trigger.Author.Id, trigger.Channel.GuildId, trigger.ChannelId, id, trigger.Id, match.Member.Id);
Task LogMessage() => _logChannel.LogMessage(ctx, match, trigger, id).AsTask();
async Task DeleteMessage()
{

View File

@ -7,7 +7,7 @@ namespace PluralKit.Bot
{
public class CpuStatService
{
private ILogger _logger;
private readonly ILogger _logger;
public double LastCpuMeasure { get; private set; }

View File

@ -15,40 +15,38 @@ using PluralKit.Core;
namespace PluralKit.Bot {
public class EmbedService
{
private IDataStore _data;
private IDatabase _db;
private DiscordShardedClient _client;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly DiscordShardedClient _client;
public EmbedService(DiscordShardedClient client, IDataStore data, IDatabase db)
public EmbedService(DiscordShardedClient client, IDatabase db, ModelRepository repo)
{
_client = client;
_data = data;
_db = db;
_repo = repo;
}
public async Task<DiscordEmbed> CreateSystemEmbed(DiscordClient client, PKSystem system, LookupContext ctx)
{
await using var conn = await _db.Obtain();
// Fetch/render info for all accounts simultaneously
var accounts = await conn.GetLinkedAccounts(system.Id);
var accounts = await _repo.GetSystemAccounts(conn, system.Id);
var users = await Task.WhenAll(accounts.Select(async uid => (await client.GetUser(uid))?.NameAndMention() ?? $"(deleted account {uid})"));
var memberCount = await conn.GetSystemMemberCount(system.Id, PrivacyLevel.Public);
var memberCount = await _repo.GetSystemMemberCount(conn, system.Id, PrivacyLevel.Public);
var eb = new DiscordEmbedBuilder()
.WithColor(DiscordUtils.Gray)
.WithTitle(system.Name ?? null)
.WithThumbnail(system.AvatarUrl)
.WithFooter($"System ID: {system.Hid} | Created on {system.Created.FormatZoned(system)}");
var latestSwitch = await _data.GetLatestSwitch(system.Id);
var latestSwitch = await _repo.GetLatestSwitch(conn, system.Id);
if (latestSwitch != null && system.FrontPrivacy.CanAccess(ctx))
{
var switchMembers = await _data.GetSwitchMembers(latestSwitch).ToListAsync();
if (switchMembers.Count > 0)
eb.AddField("Fronter".ToQuantity(switchMembers.Count(), ShowQuantityAs.None),
var switchMembers = await _repo.GetSwitchMembers(conn, latestSwitch.Id).ToListAsync();
if (switchMembers.Count > 0)
eb.AddField("Fronter".ToQuantity(switchMembers.Count(), ShowQuantityAs.None),
string.Join(", ", switchMembers.Select(m => m.NameFor(ctx))));
}
@ -105,11 +103,13 @@ namespace PluralKit.Bot {
await using var conn = await _db.Obtain();
var guildSettings = guild != null ? await conn.QueryOrInsertMemberGuildConfig(guild.Id, member.Id) : null;
var guildSettings = guild != null ? await _repo.GetMemberGuild(conn, guild.Id, member.Id) : null;
var guildDisplayName = guildSettings?.DisplayName;
var avatar = guildSettings?.AvatarUrl ?? member.AvatarFor(ctx);
var groups = (await conn.QueryMemberGroups(member.Id)).Where(g => g.Visibility.CanAccess(ctx)).ToList();
var groups = await _repo.GetMemberGroups(conn, member.Id)
.Where(g => g.Visibility.CanAccess(ctx))
.ToListAsync();
var eb = new DiscordEmbedBuilder()
// TODO: add URL of website when that's up
@ -157,7 +157,7 @@ namespace PluralKit.Bot {
public async Task<DiscordEmbed> CreateFronterEmbed(PKSwitch sw, DateTimeZone zone, LookupContext ctx)
{
var members = await _data.GetSwitchMembers(sw).ToListAsync();
var members = await _db.Execute(c => _repo.GetSwitchMembers(c, sw.Id).ToListAsync().AsTask());
var timeSinceSwitch = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp;
return new DiscordEmbedBuilder()
.WithColor(members.FirstOrDefault()?.Color?.ToDiscordColor() ?? DiscordUtils.Gray)

View File

@ -10,7 +10,7 @@ namespace PluralKit.Bot
// TODO: is this still needed after the D#+ migration?
public class LastMessageCacheService
{
private IDictionary<ulong, ulong> _cache = new ConcurrentDictionary<ulong, ulong>();
private readonly IDictionary<ulong, ulong> _cache = new ConcurrentDictionary<ulong, ulong>();
public void AddMessage(ulong channel, ulong message)
{

View File

@ -15,16 +15,16 @@ namespace PluralKit.Bot {
public class LogChannelService {
private readonly EmbedService _embed;
private readonly IDatabase _db;
private readonly IDataStore _data;
private readonly ModelRepository _repo;
private readonly ILogger _logger;
private readonly DiscordRestClient _rest;
public LogChannelService(EmbedService embed, ILogger logger, DiscordRestClient rest, IDatabase db, IDataStore data)
public LogChannelService(EmbedService embed, ILogger logger, DiscordRestClient rest, IDatabase db, ModelRepository repo)
{
_embed = embed;
_rest = rest;
_db = db;
_data = data;
_repo = repo;
_logger = logger.ForContext<LogChannelService>();
}
@ -47,8 +47,8 @@ namespace PluralKit.Bot {
// Send embed!
await using var conn = await _db.Obtain();
var embed = _embed.CreateLoggedMessageEmbed(await conn.QuerySystem(ctx.SystemId.Value),
await conn.QueryMember(proxy.Member.Id), hookMessage, trigger.Id, trigger.Author, proxy.Content,
var embed = _embed.CreateLoggedMessageEmbed(await _repo.GetSystem(conn, ctx.SystemId.Value),
await _repo.GetMember(conn, proxy.Member.Id), hookMessage, trigger.Id, trigger.Author, proxy.Content,
trigger.Channel);
var url = $"https://discord.com/channels/{trigger.Channel.GuildId}/{trigger.ChannelId}/{hookMessage}";
await logChannel.SendMessageFixedAsync(content: url, embed: embed);

View File

@ -16,19 +16,19 @@ namespace PluralKit.Bot
{
public class LoggerCleanService
{
private static Regex _basicRegex = new Regex("(\\d{17,19})");
private static Regex _dynoRegex = new Regex("Message ID: (\\d{17,19})");
private static Regex _carlRegex = new Regex("ID: (\\d{17,19})");
private static Regex _circleRegex = new Regex("\\(`(\\d{17,19})`\\)");
private static Regex _loggerARegex = new Regex("Message = (\\d{17,19})");
private static Regex _loggerBRegex = new Regex("MessageID:(\\d{17,19})");
private static Regex _auttajaRegex = new Regex("Message (\\d{17,19}) deleted");
private static Regex _mantaroRegex = new Regex("Message \\(?ID:? (\\d{17,19})\\)? created by .* in channel .* was deleted\\.");
private static Regex _pancakeRegex = new Regex("Message from <@(\\d{17,19})> deleted in");
private static Regex _unbelievaboatRegex = new Regex("Message ID: (\\d{17,19})");
private static Regex _vanessaRegex = new Regex("Message sent by <@!?(\\d{17,19})> deleted in");
private static Regex _salRegex = new Regex("\\(ID: (\\d{17,19})\\)");
private static Regex _GearBotRegex = new Regex("\\(``(\\d{17,19})``\\) in <#\\d{17,19}> has been removed.");
private static readonly Regex _basicRegex = new Regex("(\\d{17,19})");
private static readonly Regex _dynoRegex = new Regex("Message ID: (\\d{17,19})");
private static readonly Regex _carlRegex = new Regex("ID: (\\d{17,19})");
private static readonly Regex _circleRegex = new Regex("\\(`(\\d{17,19})`\\)");
private static readonly Regex _loggerARegex = new Regex("Message = (\\d{17,19})");
private static readonly Regex _loggerBRegex = new Regex("MessageID:(\\d{17,19})");
private static readonly Regex _auttajaRegex = new Regex("Message (\\d{17,19}) deleted");
private static readonly Regex _mantaroRegex = new Regex("Message \\(?ID:? (\\d{17,19})\\)? created by .* in channel .* was deleted\\.");
private static readonly Regex _pancakeRegex = new Regex("Message from <@(\\d{17,19})> deleted in");
private static readonly Regex _unbelievaboatRegex = new Regex("Message ID: (\\d{17,19})");
private static readonly Regex _vanessaRegex = new Regex("Message sent by <@!?(\\d{17,19})> deleted in");
private static readonly Regex _salRegex = new Regex("\\(ID: (\\d{17,19})\\)");
private static readonly Regex _GearBotRegex = new Regex("\\(``(\\d{17,19})``\\) in <#\\d{17,19}> has been removed.");
private static readonly Dictionary<ulong, LoggerBot> _bots = new[]
{
@ -55,7 +55,7 @@ namespace PluralKit.Bot
.Where(b => b.WebhookName != null)
.ToDictionary(b => b.WebhookName);
private IDatabase _db;
private readonly IDatabase _db;
private DiscordShardedClient _client;
public LoggerCleanService(IDatabase db, DiscordShardedClient client)

View File

@ -17,17 +17,17 @@ namespace PluralKit.Bot
{
public class PeriodicStatCollector
{
private DiscordShardedClient _client;
private IMetrics _metrics;
private CpuStatService _cpu;
private readonly DiscordShardedClient _client;
private readonly IMetrics _metrics;
private readonly CpuStatService _cpu;
private IDatabase _db;
private readonly IDatabase _db;
private WebhookCacheService _webhookCache;
private readonly WebhookCacheService _webhookCache;
private DbConnectionCountHolder _countHolder;
private readonly DbConnectionCountHolder _countHolder;
private ILogger _logger;
private readonly ILogger _logger;
public PeriodicStatCollector(DiscordShardedClient client, IMetrics metrics, ILogger logger, WebhookCacheService webhookCache, DbConnectionCountHolder countHolder, CpuStatService cpu, IDatabase db)
{

View File

@ -16,7 +16,6 @@ namespace PluralKit.Bot
{
public class ShardInfoService
{
public class ShardInfo
{
public bool HasAttachedListeners;
@ -27,10 +26,10 @@ namespace PluralKit.Bot
public bool Connected;
}
private IMetrics _metrics;
private ILogger _logger;
private DiscordShardedClient _client;
private Dictionary<int, ShardInfo> _shardInfo = new Dictionary<int, ShardInfo>();
private readonly IMetrics _metrics;
private readonly ILogger _logger;
private readonly DiscordShardedClient _client;
private readonly Dictionary<int, ShardInfo> _shardInfo = new Dictionary<int, ShardInfo>();
public ShardInfoService(ILogger logger, DiscordShardedClient client, IMetrics metrics)
{

View File

@ -18,11 +18,11 @@ namespace PluralKit.Bot
{
public static readonly string WebhookName = "PluralKit Proxy Webhook";
private DiscordShardedClient _client;
private ConcurrentDictionary<ulong, Lazy<Task<DiscordWebhook>>> _webhooks;
private readonly DiscordShardedClient _client;
private readonly ConcurrentDictionary<ulong, Lazy<Task<DiscordWebhook>>> _webhooks;
private IMetrics _metrics;
private ILogger _logger;
private readonly IMetrics _metrics;
private readonly ILogger _logger;
public WebhookCacheService(DiscordShardedClient client, ILogger logger, IMetrics metrics)
{

View File

@ -29,10 +29,10 @@ namespace PluralKit.Bot
public class WebhookExecutorService
{
private WebhookCacheService _webhookCache;
private ILogger _logger;
private IMetrics _metrics;
private HttpClient _client;
private readonly WebhookCacheService _webhookCache;
private readonly ILogger _logger;
private readonly IMetrics _metrics;
private readonly HttpClient _client;
public WebhookExecutorService(IMetrics metrics, WebhookCacheService webhookCache, ILogger logger, HttpClient client)
{

View File

@ -2,7 +2,6 @@
using System.Collections.Generic;
using System.Data;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using App.Metrics;
@ -207,11 +206,19 @@ namespace PluralKit.Core
await using var conn = await db.Obtain();
await func(conn);
}
public static async Task<T> Execute<T>(this IDatabase db, Func<IPKConnection, Task<T>> func)
{
await using var conn = await db.Obtain();
return await func(conn);
}
public static async IAsyncEnumerable<T> Execute<T>(this IDatabase db, Func<IPKConnection, IAsyncEnumerable<T>> func)
{
await using var conn = await db.Obtain();
await foreach (var val in func(conn))
yield return val;
}
}
}

View File

@ -6,16 +6,16 @@ using Dapper;
namespace PluralKit.Core
{
public static class DatabaseFunctionsExt
public partial class ModelRepository
{
public static Task<MessageContext> QueryMessageContext(this IPKConnection conn, ulong account, ulong guild, ulong channel)
public Task<MessageContext> GetMessageContext(IPKConnection conn, ulong account, ulong guild, ulong channel)
{
return conn.QueryFirstAsync<MessageContext>("message_context",
new { account_id = account, guild_id = guild, channel_id = channel },
commandType: CommandType.StoredProcedure);
}
public static Task<IEnumerable<ProxyMember>> QueryProxyMembers(this IPKConnection conn, ulong account, ulong guild)
public Task<IEnumerable<ProxyMember>> GetProxyMembers(IPKConnection conn, ulong account, ulong guild)
{
return conn.QueryAsync<ProxyMember>("proxy_members",
new { account_id = account, guild_id = guild },

View File

@ -0,0 +1,83 @@
#nullable enable
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task<PKGroup?> GetGroupByName(IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where system = @System and lower(Name) = lower(@Name)", new {System = system, Name = name});
public Task<PKGroup?> GetGroupByHid(IPKConnection conn, string hid) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where hid = @hid", new {hid = hid.ToLowerInvariant()});
public Task<int> GetGroupMemberCount(IPKConnection conn, GroupId id, PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from group_members");
if (privacyFilter != null)
query.Append(" inner join members on group_members.member_id = members.id");
query.Append(" where group_members.group_id = @Id");
if (privacyFilter != null)
query.Append(" and members.member_visibility = @PrivacyFilter");
return conn.QuerySingleOrDefaultAsync<int>(query.ToString(), new {Id = id, PrivacyFilter = privacyFilter});
}
public IAsyncEnumerable<PKGroup> GetMemberGroups(IPKConnection conn, MemberId id) =>
conn.QueryStreamAsync<PKGroup>(
"select groups.* from group_members inner join groups on group_members.group_id = groups.id where group_members.member_id = @Id",
new {Id = id});
public async Task<PKGroup> CreateGroup(IPKConnection conn, SystemId system, string name)
{
var group = await conn.QueryFirstAsync<PKGroup>(
"insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *",
new {System = system, Name = name});
_logger.Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name);
return group;
}
public Task<PKGroup> UpdateGroup(IPKConnection conn, GroupId id, GroupPatch patch)
{
_logger.Information("Updated {GroupId}: {@GroupPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKGroup>(query, pms);
}
public Task DeleteGroup(IPKConnection conn, GroupId group)
{
_logger.Information("Deleted {GroupId}", group);
return conn.ExecuteAsync("delete from groups where id = @Id", new {Id = @group});
}
public async Task AddMembersToGroup(IPKConnection conn, GroupId group,
IReadOnlyCollection<MemberId> members)
{
await using var w =
conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)");
foreach (var member in members)
{
await w.StartRowAsync();
await w.WriteAsync(group.Value);
await w.WriteAsync(member.Value);
}
await w.CompleteAsync();
_logger.Information("Added members to {GroupId}: {MemberIds}", group, members);
}
public Task RemoveMembersFromGroup(IPKConnection conn, GroupId group,
IReadOnlyCollection<MemberId> members)
{
_logger.Information("Removed members from {GroupId}: {MemberIds}", group, members);
return conn.ExecuteAsync("delete from group_members where group_id = @Group and member_id = any(@Members)",
new {Group = @group, Members = members.ToArray()});
}
}
}

View File

@ -0,0 +1,53 @@
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task UpsertGuild(IPKConnection conn, ulong guild, GuildPatch patch)
{
_logger.Information("Updated guild {GuildId}: {@GuildPatch}", guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("servers", "id"))
.WithConstant("id", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public Task UpsertSystemGuild(IPKConnection conn, SystemId system, ulong guild,
SystemGuildPatch patch)
{
_logger.Information("Updated {SystemId} in guild {GuildId}: {@SystemGuildPatch}", system, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("system_guild", "system, guild"))
.WithConstant("system", system)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public Task UpsertMemberGuild(IPKConnection conn, MemberId member, ulong guild,
MemberGuildPatch patch)
{
_logger.Information("Updated {MemberId} in guild {GuildId}: {@MemberGuildPatch}", member, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("member_guild", "member, guild"))
.WithConstant("member", member)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public Task<GuildConfig> GetGuild(IPKConnection conn, ulong guild) =>
conn.QueryFirstAsync<GuildConfig>("insert into servers (id) values (@guild) on conflict (id) do update set id = @guild returning *", new {guild});
public Task<SystemGuildSettings> GetSystemGuild(IPKConnection conn, ulong guild, SystemId system) =>
conn.QueryFirstAsync<SystemGuildSettings>(
"insert into system_guild (guild, system) values (@guild, @system) on conflict (guild, system) do update set guild = @guild, system = @system returning *",
new {guild, system});
public Task<MemberGuildSettings> GetMemberGuild(IPKConnection conn, ulong guild, MemberId member) =>
conn.QueryFirstAsync<MemberGuildSettings>(
"insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *",
new {guild, member});
}
}

View File

@ -0,0 +1,47 @@
#nullable enable
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task<PKMember?> GetMember(IPKConnection conn, MemberId id) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where id = @id", new {id});
public Task<PKMember?> GetMemberByHid(IPKConnection conn, string hid) =>
conn.QuerySingleOrDefaultAsync<PKMember?>("select * from members where hid = @Hid", new { Hid = hid.ToLower() });
public Task<PKMember?> GetMemberByName(IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system });
public Task<PKMember?> GetMemberByDisplayName(IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system });
public async Task<PKMember> CreateMember(IPKConnection conn, SystemId id, string memberName)
{
var member = await conn.QueryFirstAsync<PKMember>(
"insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *",
new {SystemId = id, Name = memberName});
_logger.Information("Created {MemberId} in {SystemId}: {MemberName}",
member.Id, id, memberName);
return member;
}
public Task<PKMember> UpdateMember(IPKConnection conn, MemberId id, MemberPatch patch)
{
_logger.Information("Updated {MemberId}: {@MemberPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKMember>(query, pms);
}
public Task DeleteMember(IPKConnection conn, MemberId id)
{
_logger.Information("Deleted {MemberId}", id);
return conn.ExecuteAsync("delete from members where id = @Id", new {Id = id});
}
}
}

View File

@ -0,0 +1,63 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public async Task AddMessage(IPKConnection conn, PKMessage msg) {
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@Mid, @Guild, @Channel, @Member, @Sender, @OriginalMid) on conflict do nothing", msg);
_logger.Debug("Stored message {@StoredMessage} in channel {Channel}", msg, msg.Channel);
}
public async Task<FullMessage> GetMessage(IPKConnection conn, ulong id)
{
FullMessage Mapper(PKMessage msg, PKMember member, PKSystem system) =>
new FullMessage {Message = msg, System = system, Member = member};
var result = await conn.QueryAsync<PKMessage, PKMember, PKSystem, FullMessage>(
"select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system",
Mapper, new {Id = id});
return result.FirstOrDefault();
}
public async Task DeleteMessage(IPKConnection conn, ulong id)
{
var rowCount = await conn.ExecuteAsync("delete from messages where mid = @Id", new {Id = id});
if (rowCount > 0)
_logger.Information("Deleted message {MessageId} from database", id);
}
public async Task DeleteMessagesBulk(IPKConnection conn, IReadOnlyCollection<ulong> ids)
{
// Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong
// Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway)
var rowCount = await conn.ExecuteAsync("delete from messages where mid = any(@Ids)",
new {Ids = ids.Select(id => (long) id).ToArray()});
if (rowCount > 0)
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount,
ids);
}
}
public class PKMessage
{
public ulong Mid { get; set; }
public ulong? Guild { get; set; } // null value means "no data" (ie. from before this field being added)
public ulong Channel { get; set; }
public MemberId Member { get; set; }
public ulong Sender { get; set; }
public ulong? OriginalMid { get; set; }
}
public class FullMessage
{
public PKMessage Message;
public PKMember Member;
public PKSystem System;
}
}

View File

@ -0,0 +1,236 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using NodaTime;
using NpgsqlTypes;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public async Task AddSwitch(IPKConnection conn, SystemId system, IReadOnlyCollection<MemberId> members)
{
// Use a transaction here since we're doing multiple executed commands in one
await using var tx = await conn.BeginTransactionAsync();
// First, we insert the switch itself
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
new {System = system});
// Then we insert each member in the switch in the switch_members table
await using (var w = conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)"))
{
foreach (var member in members)
{
await w.StartRowAsync();
await w.WriteAsync(sw.Id.Value, NpgsqlDbType.Integer);
await w.WriteAsync(member.Value, NpgsqlDbType.Integer);
}
await w.CompleteAsync();
}
// Finally we commit the tx, since the using block will otherwise rollback it
await tx.CommitAsync();
_logger.Information("Created {SwitchId} in {SystemId}: {Members}", sw.Id, system, members);
}
public async Task MoveSwitch(IPKConnection conn, SwitchId id, Instant time)
{
await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id",
new {Time = time, Id = id});
_logger.Information("Updated {SwitchId} timestamp: {SwitchTimestamp}", id, time);
}
public async Task DeleteSwitch(IPKConnection conn, SwitchId id)
{
await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = id});
_logger.Information("Deleted {Switch}", id);
}
public async Task DeleteAllSwitches(IPKConnection conn, SystemId system)
{
await conn.ExecuteAsync("delete from switches where system = @Id", new {Id = system});
_logger.Information("Deleted all switches in {SystemId}", system);
}
public IAsyncEnumerable<PKSwitch> GetSwitches(IPKConnection conn, SystemId system)
{
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
return conn.QueryStreamAsync<PKSwitch>(
"select * from switches where system = @System order by timestamp desc",
new {System = system});
}
public async Task<int> GetSwitchCount(IPKConnection conn, SystemId system)
{
return await conn.QuerySingleAsync<int>("select count(*) from switches where system = @Id", new { Id = system });
}
public async IAsyncEnumerable<SwitchMembersListEntry> GetSwitchMembersList(IPKConnection conn,
SystemId system, Instant start, Instant end)
{
// Wrap multiple commands in a single transaction for performance
await using var tx = await conn.BeginTransactionAsync();
// Find the time of the last switch outside the range as it overlaps the range
// If no prior switch exists, the lower bound of the range remains the start time
var lastSwitch = await conn.QuerySingleOrDefaultAsync<Instant>(
@"SELECT COALESCE(MAX(timestamp), @Start)
FROM switches
WHERE switches.system = @System
AND switches.timestamp < @Start",
new {System = system, Start = start});
// Then collect the time and members of all switches that overlap the range
var switchMembersEntries = conn.QueryStreamAsync<SwitchMembersListEntry>(
@"SELECT switch_members.member, switches.timestamp
FROM switches
LEFT JOIN switch_members
ON switches.id = switch_members.switch
WHERE switches.system = @System
AND (
switches.timestamp >= @Start
OR switches.timestamp = @LastSwitch
)
AND switches.timestamp < @End
ORDER BY switches.timestamp DESC",
new {System = system, Start = start, End = end, LastSwitch = lastSwitch});
// Yield each value here
await foreach (var entry in switchMembersEntries)
yield return entry;
// Don't really need to worry about the transaction here, we're not doing any *writes*
}
public IAsyncEnumerable<PKMember> GetSwitchMembers(IPKConnection conn, SwitchId sw)
{
return conn.QueryStreamAsync<PKMember>(
"select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id",
new {Switch = sw});
}
public async Task<PKSwitch> GetLatestSwitch(IPKConnection conn, SystemId system) =>
// TODO: should query directly for perf
await GetSwitches(conn, system).FirstOrDefaultAsync();
public async Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(IPKConnection conn,
SystemId system, Instant periodStart,
Instant periodEnd)
{
// TODO: IAsyncEnumerable-ify this one
// TODO: this doesn't belong in the repo
// Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order
var switchMembers = await GetSwitchMembersList(conn, system, periodStart, periodEnd).ToListAsync();
// query DB for all members involved in any of the switches above and collect into a dictionary for future use
// this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary
// key used in GetPerMemberSwitchDuration below
var membersList = await conn.QueryAsync<PKMember>(
"select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax
new {Switches = switchMembers.Select(m => m.Member.Value).Distinct().ToList()});
var memberObjects = membersList.ToDictionary(m => m.Id);
// Initialize entries - still need to loop to determine the TimespanEnd below
var entries =
from item in switchMembers
group item by item.Timestamp
into g
select new SwitchListEntry
{
TimespanStart = g.Key,
Members = g.Where(x => x.Member != default(MemberId)).Select(x => memberObjects[x.Member])
.ToList()
};
// Loop through every switch that overlaps the range and add it to the output list
// end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared
var endTime = periodEnd;
var outList = new List<SwitchListEntry>();
foreach (var e in entries)
{
// Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above)
var switchStartClamped = e.TimespanStart < periodStart
? periodStart
: e.TimespanStart;
outList.Add(new SwitchListEntry
{
Members = e.Members, TimespanStart = switchStartClamped, TimespanEnd = endTime
});
// next switch's end is this switch's start (we're working backward in time)
endTime = e.TimespanStart;
}
return outList;
}
public async Task<FrontBreakdown> GetFrontBreakdown(IPKConnection conn, SystemId system, Instant periodStart,
Instant periodEnd)
{
// TODO: this doesn't belong in the repo
var dict = new Dictionary<PKMember, Duration>();
var noFronterDuration = Duration.Zero;
// Sum up all switch durations for each member
// switches with multiple members will result in the duration to add up to more than the actual period range
var actualStart = periodEnd; // will be "pulled" down
var actualEnd = periodStart; // will be "pulled" up
foreach (var sw in await GetPeriodFronters(conn, system, periodStart, periodEnd))
{
var span = sw.TimespanEnd - sw.TimespanStart;
foreach (var member in sw.Members)
{
if (!dict.ContainsKey(member)) dict.Add(member, span);
else dict[member] += span;
}
if (sw.Members.Count == 0) noFronterDuration += span;
if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart;
if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd;
}
return new FrontBreakdown
{
MemberSwitchDurations = dict,
NoFronterDuration = noFronterDuration,
RangeStart = actualStart,
RangeEnd = actualEnd
};
}
}
public struct SwitchListEntry
{
public ICollection<PKMember> Members;
public Instant TimespanStart;
public Instant TimespanEnd;
}
public struct FrontBreakdown
{
public Dictionary<PKMember, Duration> MemberSwitchDurations;
public Duration NoFronterDuration;
public Instant RangeStart;
public Instant RangeEnd;
}
public struct SwitchMembersListEntry
{
public MemberId Member;
public Instant Timestamp;
}
}

View File

@ -0,0 +1,78 @@
#nullable enable
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task<PKSystem?> GetSystem(IPKConnection conn, SystemId id) =>
conn.QueryFirstOrDefaultAsync<PKSystem?>("select * from systems where id = @id", new {id});
public Task<PKSystem?> GetSystemByAccount(IPKConnection conn, ulong accountId) =>
conn.QuerySingleOrDefaultAsync<PKSystem?>(
"select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id",
new {Id = accountId});
public Task<PKSystem?> GetSystemByHid(IPKConnection conn, string hid) =>
conn.QuerySingleOrDefaultAsync<PKSystem?>("select * from systems where systems.hid = @Hid",
new {Hid = hid.ToLower()});
public Task<IEnumerable<ulong>> GetSystemAccounts(IPKConnection conn, SystemId system) =>
conn.QueryAsync<ulong>("select uid from accounts where system = @Id", new {Id = system});
public IAsyncEnumerable<PKMember> GetSystemMembers(IPKConnection conn, SystemId system) =>
conn.QueryStreamAsync<PKMember>("select * from members where system = @SystemID", new {SystemID = system});
public Task<int> GetSystemMemberCount(IPKConnection conn, SystemId id, PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from members where system = @Id");
if (privacyFilter != null)
query.Append($" and member_visibility = {(int) privacyFilter.Value}");
return conn.QuerySingleAsync<int>(query.ToString(), new {Id = id});
}
public async Task<PKSystem> CreateSystem(IPKConnection conn, string? systemName = null)
{
var system = await conn.QuerySingleAsync<PKSystem>(
"insert into systems (hid, name) values (find_free_system_hid(), @Name) returning *",
new {Name = systemName});
_logger.Information("Created {SystemId}", system.Id);
return system;
}
public Task<PKSystem> UpdateSystem(IPKConnection conn, SystemId id, SystemPatch patch)
{
_logger.Information("Updated {SystemId}: {@SystemPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("systems", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKSystem>(query, pms);
}
public async Task AddAccount(IPKConnection conn, SystemId system, ulong accountId)
{
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
// This is used in import/export, although the pk;link command checks for this case beforehand
await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing",
new {Id = accountId, SystemId = system});
_logger.Information("Linked account {UserId} to {SystemId}", accountId, system);
}
public async Task RemoveAccount(IPKConnection conn, SystemId system, ulong accountId)
{
await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId",
new {Id = accountId, SystemId = system});
_logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system);
}
public Task DeleteSystem(IPKConnection conn, SystemId id)
{
_logger.Information("Deleted {SystemId}", id);
return conn.ExecuteAsync("delete from systems where id = @Id", new {Id = id});
}
}
}

View File

@ -0,0 +1,15 @@
using Serilog;
namespace PluralKit.Core
{
public partial class ModelRepository
{
private readonly ILogger _logger;
public ModelRepository(ILogger logger)
{
_logger = logger.ForContext<ILogger>()
.ForContext("Elastic", "yes?");
}
}
}

View File

@ -64,7 +64,11 @@ namespace PluralKit.Core
protected override async ValueTask<DbTransaction> BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => new PKTransaction(await Inner.BeginTransactionAsync(level, ct));
public override void Open() => throw SyncError(nameof(Open));
public override void Close() => throw SyncError(nameof(Close));
public override void Close()
{
// Don't throw SyncError here, Dapper calls sync Close() internally so that sucks
Inner.Close();
}
IDbTransaction IPKConnection.BeginTransaction() => throw SyncError(nameof(BeginTransaction));
IDbTransaction IPKConnection.BeginTransaction(IsolationLevel level) => throw SyncError(nameof(BeginTransaction));

View File

@ -1,69 +0,0 @@
#nullable enable
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public static class ModelQueryExt
{
public static Task<PKSystem?> QuerySystem(this IPKConnection conn, SystemId id) =>
conn.QueryFirstOrDefaultAsync<PKSystem?>("select * from systems where id = @id", new {id});
public static Task<int> GetSystemMemberCount(this IPKConnection conn, SystemId id, PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from members where system = @Id");
if (privacyFilter != null)
query.Append($" and member_visibility = {(int) privacyFilter.Value}");
return conn.QuerySingleAsync<int>(query.ToString(), new {Id = id});
}
public static Task<IEnumerable<ulong>> GetLinkedAccounts(this IPKConnection conn, SystemId id) =>
conn.QueryAsync<ulong>("select uid from accounts where system = @Id", new {Id = id});
public static Task<PKMember?> QueryMember(this IPKConnection conn, MemberId id) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where id = @id", new {id});
public static Task<PKMember?> QueryMemberByHid(this IPKConnection conn, string hid) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where hid = @hid", new {hid = hid.ToLowerInvariant()});
public static Task<PKGroup?> QueryGroupByName(this IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where system = @System and lower(Name) = lower(@Name)", new {System = system, Name = name});
public static Task<PKGroup?> QueryGroupByHid(this IPKConnection conn, string hid) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where hid = @hid", new {hid = hid.ToLowerInvariant()});
public static Task<int> QueryGroupMemberCount(this IPKConnection conn, GroupId id,
PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from group_members");
if (privacyFilter != null)
query.Append(" inner join members on group_members.member_id = members.id");
query.Append(" where group_members.group_id = @Id");
if (privacyFilter != null)
query.Append(" and members.member_visibility = @PrivacyFilter");
return conn.QuerySingleOrDefaultAsync<int>(query.ToString(), new {Id = id, PrivacyFilter = privacyFilter});
}
public static Task<IEnumerable<PKGroup>> QueryMemberGroups(this IPKConnection conn, MemberId id) =>
conn.QueryAsync<PKGroup>(
"select groups.* from group_members inner join groups on group_members.group_id = groups.id where group_members.member_id = @Id",
new {Id = id});
public static Task<GuildConfig> QueryOrInsertGuildConfig(this IPKConnection conn, ulong guild) =>
conn.QueryFirstAsync<GuildConfig>("insert into servers (id) values (@guild) on conflict (id) do update set id = @guild returning *", new {guild});
public static Task<SystemGuildSettings> QueryOrInsertSystemGuildConfig(this IPKConnection conn, ulong guild, SystemId system) =>
conn.QueryFirstAsync<SystemGuildSettings>(
"insert into system_guild (guild, system) values (@guild, @system) on conflict (guild, system) do update set guild = @guild, system = @system returning *",
new {guild, system});
public static Task<MemberGuildSettings> QueryOrInsertMemberGuildConfig(
this IPKConnection conn, ulong guild, MemberId member) =>
conn.QueryFirstAsync<MemberGuildSettings>(
"insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *",
new {guild, member});
}
}

View File

@ -1,128 +0,0 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using Serilog;
namespace PluralKit.Core
{
public static class ModelPatchExt
{
public static Task<PKSystem> UpdateSystem(this IPKConnection conn, SystemId id, SystemPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {SystemId}: {@SystemPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("systems", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKSystem>(query, pms);
}
public static Task DeleteSystem(this IPKConnection conn, SystemId id)
{
Log.ForContext("Elastic", "yes?").Information("Deleted {SystemId}", id);
return conn.ExecuteAsync("delete from systems where id = @Id", new {Id = id});
}
public static async Task<PKMember> CreateMember(this IPKConnection conn, SystemId system, string memberName)
{
var member = await conn.QueryFirstAsync<PKMember>(
"insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *",
new {SystemId = system, Name = memberName});
Log.ForContext("Elastic", "yes?").Information("Created {MemberId} in {SystemId}: {MemberName}",
member.Id, system, memberName);
return member;
}
public static Task<PKMember> UpdateMember(this IPKConnection conn, MemberId id, MemberPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {MemberId}: {@MemberPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKMember>(query, pms);
}
public static Task DeleteMember(this IPKConnection conn, MemberId id)
{
Log.ForContext("Elastic", "yes?").Information("Deleted {MemberId}", id);
return conn.ExecuteAsync("delete from members where id = @Id", new {Id = id});
}
public static Task UpsertSystemGuild(this IPKConnection conn, SystemId system, ulong guild,
SystemGuildPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {SystemId} in guild {GuildId}: {@SystemGuildPatch}", system, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("system_guild", "system, guild"))
.WithConstant("system", system)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public static Task UpsertMemberGuild(this IPKConnection conn, MemberId member, ulong guild,
MemberGuildPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {MemberId} in guild {GuildId}: {@MemberGuildPatch}", member, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("member_guild", "member, guild"))
.WithConstant("member", member)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public static Task UpsertGuild(this IPKConnection conn, ulong guild, GuildPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated guild {GuildId}: {@GuildPatch}", guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("servers", "id"))
.WithConstant("id", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public static async Task<PKGroup> CreateGroup(this IPKConnection conn, SystemId system, string name)
{
var group = await conn.QueryFirstAsync<PKGroup>(
"insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *",
new {System = system, Name = name});
Log.ForContext("Elastic", "yes?").Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name);
return group;
}
public static Task<PKGroup> UpdateGroup(this IPKConnection conn, GroupId id, GroupPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {GroupId}: {@GroupPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKGroup>(query, pms);
}
public static Task DeleteGroup(this IPKConnection conn, GroupId group)
{
Log.ForContext("Elastic", "yes?").Information("Deleted {GroupId}", group);
return conn.ExecuteAsync("delete from groups where id = @Id", new {Id = @group});
}
public static async Task AddMembersToGroup(this IPKConnection conn, GroupId group, IReadOnlyCollection<MemberId> members)
{
await using var w = conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)");
foreach (var member in members)
{
await w.StartRowAsync();
await w.WriteAsync(group.Value);
await w.WriteAsync(member.Value);
}
await w.CompleteAsync();
Log.ForContext("Elastic", "yes?").Information("Added members to {GroupId}: {MemberIds}", group, members);
}
public static Task RemoveMembersFromGroup(this IPKConnection conn, GroupId group, IReadOnlyCollection<MemberId> members)
{
Log.ForContext("Elastic", "yes?").Information("Removed members from {GroupId}: {MemberIds}", group, members);
return conn.ExecuteAsync("delete from group_members where group_id = @Group and member_id = any(@Members)",
new {Group = @group, Members = members.ToArray()});
}
}
}

View File

@ -1,5 +1,13 @@
namespace PluralKit.Core
{
public enum AutoproxyMode
{
Off = 1,
Front = 2,
Latch = 3,
Member = 4
}
public class SystemGuildSettings
{
public ulong Guild { get; }

View File

@ -25,7 +25,7 @@ namespace PluralKit.Core
{
builder.RegisterType<DbConnectionCountHolder>().SingleInstance();
builder.RegisterType<Database>().As<IDatabase>().SingleInstance();
builder.RegisterType<PostgresDataStore>().AsSelf().As<IDataStore>();
builder.RegisterType<ModelRepository>().AsSelf().SingleInstance();
builder.Populate(new ServiceCollection().AddMemoryCache());
}
@ -33,7 +33,7 @@ namespace PluralKit.Core
public class ConfigModule<T>: Module where T: new()
{
private string _submodule;
private readonly string _submodule;
public ConfigModule(string submodule = null)
{

View File

@ -14,22 +14,24 @@ namespace PluralKit.Core
{
public class DataFileService
{
private IDataStore _data;
private IDatabase _db;
private ILogger _logger;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly ILogger _logger;
public DataFileService(ILogger logger, IDataStore data, IDatabase db)
public DataFileService(ILogger logger, IDatabase db, ModelRepository repo)
{
_data = data;
_db = db;
_repo = repo;
_logger = logger.ForContext<DataFileService>();
}
public async Task<DataFileSystem> ExportSystem(PKSystem system)
{
await using var conn = await _db.Obtain();
// Export members
var members = new List<DataFileMember>();
var pkMembers = _data.GetSystemMembers(system); // Read all members in the system
var pkMembers = _repo.GetSystemMembers(conn, system.Id); // Read all members in the system
await foreach (var member in pkMembers.Select(m => new DataFileMember
{
@ -49,7 +51,7 @@ namespace PluralKit.Core
// Export switches
var switches = new List<DataFileSwitch>();
var switchList = await _data.GetPeriodFronters(system, Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant());
var switchList = await _repo.GetPeriodFronters(conn, system.Id, Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant());
switches.AddRange(switchList.Select(x => new DataFileSwitch
{
Timestamp = x.TimespanStart.FormatExport(),
@ -68,7 +70,7 @@ namespace PluralKit.Core
Members = members,
Switches = switches,
Created = system.Created.FormatExport(),
LinkedAccounts = (await _data.GetSystemAccounts(system)).ToList()
LinkedAccounts = (await _repo.GetSystemAccounts(conn, system.Id)).ToList()
};
}
@ -102,6 +104,8 @@ namespace PluralKit.Core
public async Task<ImportResult> ImportSystem(DataFileSystem data, PKSystem system, ulong accountId)
{
await using var conn = await _db.Obtain();
var result = new ImportResult {
AddedNames = new List<string>(),
ModifiedNames = new List<string>(),
@ -112,26 +116,24 @@ namespace PluralKit.Core
// If we don't already have a system to save to, create one
if (system == null)
{
system = result.System = await _data.CreateSystem(data.Name);
await _data.AddAccount(system, accountId);
system = result.System = await _repo.CreateSystem(conn, data.Name);
await _repo.AddAccount(conn, system.Id, accountId);
}
await using var conn = await _db.Obtain();
// Apply system info
var patch = new SystemPatch {Name = data.Name};
if (data.Description != null) patch.Description = data.Description;
if (data.Tag != null) patch.Tag = data.Tag;
if (data.AvatarUrl != null) patch.AvatarUrl = data.AvatarUrl;
if (data.TimeZone != null) patch.UiTz = data.TimeZone ?? "UTC";
await conn.UpdateSystem(system.Id, patch);
await _repo.UpdateSystem(conn, system.Id, patch);
// -- Member/switch import --
await using (var imp = await BulkImporter.Begin(system, conn))
{
// Tally up the members that didn't exist before, and check member count on import
// If creating the unmatched members would put us over the member limit, abort before creating any members
var memberCountBefore = await conn.GetSystemMemberCount(system.Id);
var memberCountBefore = await _repo.GetSystemMemberCount(conn, system.Id);
var membersToAdd = data.Members.Count(m => imp.IsNewMember(m.Id, m.Name));
if (memberCountBefore + membersToAdd > Limits.MaxMemberCount)
{

View File

@ -1,223 +0,0 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using NodaTime;
namespace PluralKit.Core {
public enum AutoproxyMode
{
Off = 1,
Front = 2,
Latch = 3,
Member = 4
}
public class FullMessage
{
public PKMessage Message;
public PKMember Member;
public PKSystem System;
}
public struct PKMessage
{
public ulong Mid;
public ulong? Guild; // null value means "no data" (ie. from before this field being added)
public ulong Channel;
public ulong Sender;
public ulong? OriginalMid;
}
public struct SwitchListEntry
{
public ICollection<PKMember> Members;
public Instant TimespanStart;
public Instant TimespanEnd;
}
public struct FrontBreakdown
{
public Dictionary<PKMember, Duration> MemberSwitchDurations;
public Duration NoFronterDuration;
public Instant RangeStart;
public Instant RangeEnd;
}
public struct SwitchMembersListEntry
{
public MemberId Member;
public Instant Timestamp;
}
public interface IDataStore
{
/// <summary>
/// Gets a system by its user-facing human ID.
/// </summary>
/// <returns>The <see cref="PKSystem"/> with the given human ID, or null if no system was found.</returns>
Task<PKSystem> GetSystemByHid(string systemHid);
/// <summary>
/// Gets a system by one of its linked Discord account IDs. Multiple IDs can return the same system.
/// </summary>
/// <returns>The <see cref="PKSystem"/> with the given linked account, or null if no system was found.</returns>
Task<PKSystem> GetSystemByAccount(ulong linkedAccount);
/// <summary>
/// Gets the Discord account IDs linked to a system.
/// </summary>
/// <returns>An enumerable of Discord account IDs linked to this system.</returns>
Task<IEnumerable<ulong>> GetSystemAccounts(PKSystem system);
/// <summary>
/// Creates a system, auto-generating its corresponding IDs.
/// </summary>
/// <param name="systemName">An optional system name to set. If `null`, will not set a system name.</param>
/// <returns>The created system model.</returns>
Task<PKSystem> CreateSystem(string systemName);
// TODO: throw exception if account is present (when adding) or account isn't present (when removing)
/// <summary>
/// Links a Discord account to a system.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if the given account is already linked to a system.</exception>
Task AddAccount(PKSystem system, ulong accountToAdd);
/// <summary>
/// Unlinks a Discord account from a system.
///
/// Will *not* throw if this results in an orphaned system - this is the caller's responsibility to ensure.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if the given account is not linked to the given system.</exception>
Task RemoveAccount(PKSystem system, ulong accountToRemove);
/// <summary>
/// Gets a member by its user-facing human ID.
/// </summary>
/// <returns>The <see cref="PKMember"/> with the given human ID, or null if no member was found.</returns>
Task<PKMember> GetMemberByHid(string memberHid);
/// <summary>
/// Gets a member by its member name within one system.
/// </summary>
/// <para>
/// Member names are *usually* unique within a system (but not always), whereas member names
/// are almost certainly *not* unique globally - therefore only intra-system lookup is
/// allowed.
/// </para>
/// <returns>The <see cref="PKMember"/> with the given name, or null if no member was found.</returns>
Task<PKMember> GetMemberByName(PKSystem system, string name);
/// <summary>
/// Gets a member by its display name within one system.
/// </summary>
/// <returns>The <see cref="PKMember"/> with the given name, or null if no member was found.</returns>
Task<PKMember> GetMemberByDisplayName(PKSystem system, string name);
/// <summary>
/// Gets all members inside a given system.
/// </summary>
/// <returns>An enumerable of <see cref="PKMember"/> structs representing each member in the system, in no particular order.</returns>
IAsyncEnumerable<PKMember> GetSystemMembers(PKSystem system, bool orderByName = false);
/// <summary>
/// Gets a message and its information by its ID.
/// </summary>
/// <param name="id">The message ID to look up. This can be either the ID of the trigger message containing the proxy tags or the resulting proxied webhook message.</param>
/// <returns>An extended message object, containing not only the message data itself but the associated system and member structs.</returns>
Task<FullMessage> GetMessage(ulong id); // id is both original and trigger, also add return type struct
/// <summary>
/// Saves a posted message to the database.
/// </summary>
/// <param name="senderAccount">The ID of the account that sent the original trigger message.</param>
/// <param name="guildId">The ID of the guild the message was posted to.</param>
/// <param name="channelId">The ID of the channel the message was posted to.</param>
/// <param name="postedMessageId">The ID of the message posted by the webhook.</param>
/// <param name="triggerMessageId">The ID of the original trigger message containing the proxy tags.</param>
/// <param name="proxiedMemberId">The member (and by extension system) that was proxied.</param>
/// <returns></returns>
Task AddMessage(IPKConnection conn, ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, MemberId proxiedMemberId);
/// <summary>
/// Deletes a message from the data store.
/// </summary>
/// <param name="postedMessageId">The ID of the webhook message to delete.</param>
Task DeleteMessage(ulong postedMessageId);
/// <summary>
/// Deletes messages from the data store in bulk.
/// </summary>
/// <param name="postedMessageIds">The IDs of the webhook messages to delete.</param>
Task DeleteMessagesBulk(IReadOnlyCollection<ulong> postedMessageIds);
/// <summary>
/// Gets switches from a system.
/// </summary>
/// <returns>An enumerable of the *count* latest switches in the system, in latest-first order. May contain fewer elements than requested.</returns>
IAsyncEnumerable<PKSwitch> GetSwitches(SystemId system);
/// <summary>
/// Gets the total amount of switches in a given system.
/// </summary>
Task<int> GetSwitchCount(PKSystem system);
/// <summary>
/// Gets the latest (temporally; closest to now) switch of a given system.
/// </summary>
Task<PKSwitch> GetLatestSwitch(SystemId system);
/// <summary>
/// Gets the members a given switch consists of.
/// </summary>
IAsyncEnumerable<PKMember> GetSwitchMembers(PKSwitch sw);
/// <summary>
/// Gets a list of fronters over a given period of time.
/// </summary>
/// <para>
/// This list is returned as an enumerable of "switch members", each containing a timestamp
/// and a member ID. <seealso cref="GetMemberById"/>
///
/// Switches containing multiple members will be returned as multiple switch members each with the same
/// timestamp, and a change in timestamp should be interpreted as the start of a new switch.
/// </para>
/// <returns>An enumerable of the aforementioned "switch members".</returns>
Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(PKSystem system, Instant periodStart, Instant periodEnd);
/// <summary>
/// Calculates a breakdown of a system's fronters over a given period, including how long each member has
/// been fronting, and how long *no* member has been fronting.
/// </summary>
/// <para>
/// Switches containing multiple members will count the full switch duration for all members, meaning
/// the total duration may add up to longer than the breakdown period.
/// </para>
/// <param name="system"></param>
/// <param name="periodStart"></param>
/// <param name="periodEnd"></param>
/// <returns></returns>
Task<FrontBreakdown> GetFrontBreakdown(PKSystem system, Instant periodStart, Instant periodEnd);
/// <summary>
/// Registers a switch with the given members in the given system.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if any of the members are not in the given system.</exception>
Task AddSwitch(SystemId system, IEnumerable<PKMember> switchMembers);
/// <summary>
/// Updates the timestamp of a given switch.
/// </summary>
Task MoveSwitch(PKSwitch sw, Instant time);
/// <summary>
/// Deletes a given switch from the data store.
/// </summary>
Task DeleteSwitch(PKSwitch sw);
/// <summary>
/// Deletes all switches in a given system from the data store.
/// </summary>
Task DeleteAllSwitches(PKSystem system);
}
}

View File

@ -1,334 +0,0 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using NodaTime;
using Serilog;
namespace PluralKit.Core {
public class PostgresDataStore: IDataStore {
private readonly IDatabase _conn;
private readonly ILogger _logger;
public PostgresDataStore(IDatabase conn, ILogger logger)
{
_conn = conn;
_logger = logger
.ForContext<PostgresDataStore>()
.ForContext("Elastic", "yes?");
}
public async Task<PKSystem> CreateSystem(string systemName = null) {
PKSystem system;
using (var conn = await _conn.Obtain())
system = await conn.QuerySingleAsync<PKSystem>("insert into systems (hid, name) values (find_free_system_hid(), @Name) returning *", new { Name = systemName });
_logger.Information("Created {SystemId}", system.Id);
// New system has no accounts, therefore nothing gets cached, therefore no need to invalidate caches right here
return system;
}
public async Task AddAccount(PKSystem system, ulong accountId) {
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
// This is used in import/export, although the pk;link command checks for this case beforehand
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing", new { Id = accountId, SystemId = system.Id });
_logger.Information("Linked account {UserId} to {SystemId}", accountId, system.Id);
}
public async Task RemoveAccount(PKSystem system, ulong accountId) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", new { Id = accountId, SystemId = system.Id });
_logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system.Id);
}
public async Task<PKSystem> GetSystemByAccount(ulong accountId) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id", new { Id = accountId });
}
public async Task<PKSystem> GetSystemByHid(string hid) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select * from systems where systems.hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<IEnumerable<ulong>> GetSystemAccounts(PKSystem system)
{
using (var conn = await _conn.Obtain())
return await conn.QueryAsync<ulong>("select uid from accounts where system = @Id", new { Id = system.Id });
}
public async Task DeleteAllSwitches(PKSystem system)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from switches where system = @Id", system);
_logger.Information("Deleted all switches in {SystemId}", system.Id);
}
public async Task<PKMember> GetMemberByHid(string hid) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKMember>("select * from members where hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<PKMember> GetMemberByName(PKSystem system, string name) {
// QueryFirst, since members can (in rare cases) share names
using (var conn = await _conn.Obtain())
return await conn.QueryFirstOrDefaultAsync<PKMember>("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id });
}
public async Task<PKMember> GetMemberByDisplayName(PKSystem system, string name) {
// QueryFirst, since members can (in rare cases) share display names
using (var conn = await _conn.Obtain())
return await conn.QueryFirstOrDefaultAsync<PKMember>("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id });
}
public IAsyncEnumerable<PKMember> GetSystemMembers(PKSystem system, bool orderByName)
{
var sql = "select * from members where system = @SystemID";
if (orderByName) sql += " order by lower(name) asc";
return _conn.QueryStreamAsync<PKMember>(sql, new { SystemID = system.Id });
}
public async Task AddMessage(IPKConnection conn, ulong senderId, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, MemberId proxiedMemberId) {
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@MessageId, @GuildId, @ChannelId, @MemberId, @SenderId, @OriginalMid) on conflict do nothing", new {
MessageId = postedMessageId,
GuildId = guildId,
ChannelId = channelId,
MemberId = proxiedMemberId,
SenderId = senderId,
OriginalMid = triggerMessageId
});
// todo: _logger.Debug("Stored message {Message} in channel {Channel}", postedMessageId, channelId);
}
public async Task<FullMessage> GetMessage(ulong id)
{
using (var conn = await _conn.Obtain())
return (await conn.QueryAsync<PKMessage, PKMember, PKSystem, FullMessage>("select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system", (msg, member, system) => new FullMessage
{
Message = msg,
System = system,
Member = member
}, new { Id = id })).FirstOrDefault();
}
public async Task DeleteMessage(ulong id) {
using (var conn = await _conn.Obtain())
if (await conn.ExecuteAsync("delete from messages where mid = @Id", new { Id = id }) > 0)
_logger.Information("Deleted message {MessageId} from database", id);
}
public async Task DeleteMessagesBulk(IReadOnlyCollection<ulong> ids)
{
using (var conn = await _conn.Obtain())
{
// Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong
// Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway)
var foundCount = await conn.ExecuteAsync("delete from messages where mid = any(@Ids)", new {Ids = ids.Select(id => (long) id).ToArray()});
if (foundCount > 0)
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", foundCount, ids);
}
}
public async Task AddSwitch(SystemId system, IEnumerable<PKMember> members)
{
// Use a transaction here since we're doing multiple executed commands in one
await using var conn = await _conn.Obtain();
await using var tx = await conn.BeginTransactionAsync();
// First, we insert the switch itself
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
new {System = system});
// Then we insert each member in the switch in the switch_members table
// TODO: can we parallelize this or send it in bulk somehow?
foreach (var member in members)
{
await conn.ExecuteAsync(
"insert into switch_members(switch, member) values(@Switch, @Member)",
new {Switch = sw.Id, Member = member.Id});
}
// Finally we commit the tx, since the using block will otherwise rollback it
await tx.CommitAsync();
_logger.Information("Created {SwitchId} in {SystemId}: {Members}", sw.Id, system, members.Select(m => m.Id));
}
public IAsyncEnumerable<PKSwitch> GetSwitches(SystemId system)
{
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
// (maybe when we get caching in?)
return _conn.QueryStreamAsync<PKSwitch>(
"select * from switches where system = @System order by timestamp desc",
new {System = system});
}
public async Task<int> GetSwitchCount(PKSystem system)
{
using var conn = await _conn.Obtain();
return await conn.QuerySingleAsync<int>("select count(*) from switches where system = @Id", system);
}
public async IAsyncEnumerable<SwitchMembersListEntry> GetSwitchMembersList(PKSystem system, Instant start, Instant end)
{
// Wrap multiple commands in a single transaction for performance
await using var conn = await _conn.Obtain();
await using var tx = await conn.BeginTransactionAsync();
// Find the time of the last switch outside the range as it overlaps the range
// If no prior switch exists, the lower bound of the range remains the start time
var lastSwitch = await conn.QuerySingleOrDefaultAsync<Instant>(
@"SELECT COALESCE(MAX(timestamp), @Start)
FROM switches
WHERE switches.system = @System
AND switches.timestamp < @Start",
new { System = system.Id, Start = start });
// Then collect the time and members of all switches that overlap the range
var switchMembersEntries = conn.QueryStreamAsync<SwitchMembersListEntry>(
@"SELECT switch_members.member, switches.timestamp
FROM switches
LEFT JOIN switch_members
ON switches.id = switch_members.switch
WHERE switches.system = @System
AND (
switches.timestamp >= @Start
OR switches.timestamp = @LastSwitch
)
AND switches.timestamp < @End
ORDER BY switches.timestamp DESC",
new { System = system.Id, Start = start, End = end, LastSwitch = lastSwitch });
// Yield each value here
await foreach (var entry in switchMembersEntries)
yield return entry;
// Don't really need to worry about the transaction here, we're not doing any *writes*
}
public IAsyncEnumerable<PKMember> GetSwitchMembers(PKSwitch sw)
{
return _conn.QueryStreamAsync<PKMember>(
"select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id",
new {Switch = sw.Id});
}
public async Task<PKSwitch> GetLatestSwitch(SystemId system) =>
await GetSwitches(system).FirstOrDefaultAsync();
public async Task MoveSwitch(PKSwitch sw, Instant time)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id",
new {Time = time, Id = sw.Id});
_logger.Information("Updated {SwitchId} timestamp: {SwitchTimestamp}", sw.Id, time);
}
public async Task DeleteSwitch(PKSwitch sw)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = sw.Id});
_logger.Information("Deleted {Switch}", sw.Id);
}
public async Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(PKSystem system, Instant periodStart, Instant periodEnd)
{
// TODO: IAsyncEnumerable-ify this one
// Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order
var switchMembers = await GetSwitchMembersList(system, periodStart, periodEnd).ToListAsync();
// query DB for all members involved in any of the switches above and collect into a dictionary for future use
// this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary
// key used in GetPerMemberSwitchDuration below
Dictionary<MemberId, PKMember> memberObjects;
using (var conn = await _conn.Obtain())
{
memberObjects = (
await conn.QueryAsync<PKMember>(
"select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax
new { Switches = switchMembers.Select(m => m.Member.Value).Distinct().ToList() })
).ToDictionary(m => m.Id);
}
// Initialize entries - still need to loop to determine the TimespanEnd below
var entries =
from item in switchMembers
group item by item.Timestamp into g
select new SwitchListEntry
{
TimespanStart = g.Key,
Members = g.Where(x => x.Member != default(MemberId)).Select(x => memberObjects[x.Member]).ToList()
};
// Loop through every switch that overlaps the range and add it to the output list
// end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared
var endTime = periodEnd;
var outList = new List<SwitchListEntry>();
foreach (var e in entries)
{
// Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above)
var switchStartClamped = e.TimespanStart < periodStart
? periodStart
: e.TimespanStart;
outList.Add(new SwitchListEntry
{
Members = e.Members,
TimespanStart = switchStartClamped,
TimespanEnd = endTime
});
// next switch's end is this switch's start (we're working backward in time)
endTime = e.TimespanStart;
}
return outList;
}
public async Task<FrontBreakdown> GetFrontBreakdown(PKSystem system, Instant periodStart, Instant periodEnd)
{
var dict = new Dictionary<PKMember, Duration>();
var noFronterDuration = Duration.Zero;
// Sum up all switch durations for each member
// switches with multiple members will result in the duration to add up to more than the actual period range
var actualStart = periodEnd; // will be "pulled" down
var actualEnd = periodStart; // will be "pulled" up
foreach (var sw in await GetPeriodFronters(system, periodStart, periodEnd))
{
var span = sw.TimespanEnd - sw.TimespanStart;
foreach (var member in sw.Members)
{
if (!dict.ContainsKey(member)) dict.Add(member, span);
else dict[member] += span;
}
if (sw.Members.Count == 0) noFronterDuration += span;
if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart;
if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd;
}
return new FrontBreakdown
{
MemberSwitchDurations = dict,
NoFronterDuration = noFronterDuration,
RangeStart = actualStart,
RangeEnd = actualEnd
};
}
}
}

View File

@ -6,20 +6,11 @@ using Dapper;
namespace PluralKit.Core {
public static class ConnectionUtils
{
public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this IDatabase connFactory, string sql, object param)
{
await using var conn = await connFactory.Obtain();
await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param);
var parser = reader.GetRowParser<T>();
while (await reader.ReadAsync())
yield return parser(reader);
}
public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this IPKConnection conn, string sql, object param)
{
await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param);
var parser = reader.GetRowParser<T>();
while (await reader.ReadAsync())
yield return parser(reader);
}

View File

@ -8,9 +8,9 @@ namespace PluralKit.Core
{
private readonly string? _conflictField;
private readonly string? _condition;
private StringBuilder _insertFragment = new StringBuilder();
private StringBuilder _valuesFragment = new StringBuilder();
private StringBuilder _updateFragment = new StringBuilder();
private readonly StringBuilder _insertFragment = new StringBuilder();
private readonly StringBuilder _valuesFragment = new StringBuilder();
private readonly StringBuilder _updateFragment = new StringBuilder();
private bool _firstInsert = true;
private bool _firstUpdate = true;
public QueryType Type { get; }