Major database refactor (again)

This commit is contained in:
Ske 2020-08-29 13:46:27 +02:00
parent 3996cd48c7
commit c7612df37e
55 changed files with 1014 additions and 1100 deletions

View File

@ -13,18 +13,20 @@ namespace PluralKit.API
[Route( "v{version:apiVersion}/a" )] [Route( "v{version:apiVersion}/a" )]
public class AccountController: ControllerBase public class AccountController: ControllerBase
{ {
private IDataStore _data; private readonly IDatabase _db;
private readonly ModelRepository _repo;
public AccountController(IDataStore data) public AccountController(IDatabase db, ModelRepository repo)
{ {
_data = data; _db = db;
_repo = repo;
} }
[HttpGet("{aid}")] [HttpGet("{aid}")]
public async Task<ActionResult<JObject>> GetSystemByAccount(ulong aid) public async Task<ActionResult<JObject>> GetSystemByAccount(ulong aid)
{ {
var system = await _data.GetSystemByAccount(aid); var system = await _db.Execute(c => _repo.GetSystemByAccount(c, aid));
if (system == null) return NotFound("Account not found."); if (system == null)
return NotFound("Account not found.");
return Ok(system.ToJson(User.ContextFor(system))); return Ok(system.ToJson(User.ContextFor(system)));
} }

View File

@ -16,19 +16,21 @@ namespace PluralKit.API
[Route( "v{version:apiVersion}/m" )] [Route( "v{version:apiVersion}/m" )]
public class MemberController: ControllerBase public class MemberController: ControllerBase
{ {
private IDatabase _db; private readonly IDatabase _db;
private IAuthorizationService _auth; private readonly ModelRepository _repo;
private readonly IAuthorizationService _auth;
public MemberController(IAuthorizationService auth, IDatabase db) public MemberController(IAuthorizationService auth, IDatabase db, ModelRepository repo)
{ {
_auth = auth; _auth = auth;
_db = db; _db = db;
_repo = repo;
} }
[HttpGet("{hid}")] [HttpGet("{hid}")]
public async Task<ActionResult<JObject>> GetMember(string hid) public async Task<ActionResult<JObject>> GetMember(string hid)
{ {
var member = await _db.Execute(conn => conn.QueryMemberByHid(hid)); var member = await _db.Execute(conn => _repo.GetMemberByHid(conn, hid));
if (member == null) return NotFound("Member not found."); if (member == null) return NotFound("Member not found.");
return Ok(member.ToJson(User.ContextFor(member))); return Ok(member.ToJson(User.ContextFor(member)));
@ -49,7 +51,7 @@ namespace PluralKit.API
if (memberCount >= Limits.MaxMemberCount) if (memberCount >= Limits.MaxMemberCount)
return BadRequest($"Member limit reached ({memberCount} / {Limits.MaxMemberCount})."); return BadRequest($"Member limit reached ({memberCount} / {Limits.MaxMemberCount}).");
var member = await conn.CreateMember(system, properties.Value<string>("name")); var member = await _repo.CreateMember(conn, system, properties.Value<string>("name"));
MemberPatch patch; MemberPatch patch;
try try
{ {
@ -60,7 +62,7 @@ namespace PluralKit.API
return BadRequest(e.Message); return BadRequest(e.Message);
} }
member = await conn.UpdateMember(member.Id, patch); member = await _repo.UpdateMember(conn, member.Id, patch);
return Ok(member.ToJson(User.ContextFor(member))); return Ok(member.ToJson(User.ContextFor(member)));
} }
@ -70,7 +72,7 @@ namespace PluralKit.API
{ {
await using var conn = await _db.Obtain(); await using var conn = await _db.Obtain();
var member = await conn.QueryMemberByHid(hid); var member = await _repo.GetMemberByHid(conn, hid);
if (member == null) return NotFound("Member not found."); if (member == null) return NotFound("Member not found.");
var res = await _auth.AuthorizeAsync(User, member, "EditMember"); var res = await _auth.AuthorizeAsync(User, member, "EditMember");
@ -86,7 +88,7 @@ namespace PluralKit.API
return BadRequest(e.Message); return BadRequest(e.Message);
} }
var newMember = await conn.UpdateMember(member.Id, patch); var newMember = await _repo.UpdateMember(conn, member.Id, patch);
return Ok(newMember.ToJson(User.ContextFor(newMember))); return Ok(newMember.ToJson(User.ContextFor(newMember)));
} }
@ -96,13 +98,13 @@ namespace PluralKit.API
{ {
await using var conn = await _db.Obtain(); await using var conn = await _db.Obtain();
var member = await conn.QueryMemberByHid(hid); var member = await _repo.GetMemberByHid(conn, hid);
if (member == null) return NotFound("Member not found."); if (member == null) return NotFound("Member not found.");
var res = await _auth.AuthorizeAsync(User, member, "EditMember"); var res = await _auth.AuthorizeAsync(User, member, "EditMember");
if (!res.Succeeded) return Unauthorized($"Member '{hid}' is not part of your system."); if (!res.Succeeded) return Unauthorized($"Member '{hid}' is not part of your system.");
await conn.DeleteMember(member.Id); await _repo.DeleteMember(conn, member.Id);
return Ok(); return Ok();
} }
} }

View File

@ -28,17 +28,19 @@ namespace PluralKit.API
[Route( "v{version:apiVersion}/msg" )] [Route( "v{version:apiVersion}/msg" )]
public class MessageController: ControllerBase public class MessageController: ControllerBase
{ {
private IDataStore _data; private readonly IDatabase _db;
private readonly ModelRepository _repo;
public MessageController(IDataStore _data) public MessageController(ModelRepository repo, IDatabase db)
{ {
this._data = _data; _repo = repo;
_db = db;
} }
[HttpGet("{mid}")] [HttpGet("{mid}")]
public async Task<ActionResult<MessageReturn>> GetMessage(ulong mid) public async Task<ActionResult<MessageReturn>> GetMessage(ulong mid)
{ {
var msg = await _data.GetMessage(mid); var msg = await _db.Execute(c => _repo.GetMessage(c, mid));
if (msg == null) return NotFound("Message not found."); if (msg == null) return NotFound("Message not found.");
return new MessageReturn return new MessageReturn

View File

@ -39,29 +39,29 @@ namespace PluralKit.API
[Route( "v{version:apiVersion}/s" )] [Route( "v{version:apiVersion}/s" )]
public class SystemController : ControllerBase public class SystemController : ControllerBase
{ {
private IDataStore _data; private readonly IDatabase _db;
private IDatabase _db; private readonly ModelRepository _repo;
private IAuthorizationService _auth; private readonly IAuthorizationService _auth;
public SystemController(IDataStore data, IDatabase db, IAuthorizationService auth) public SystemController(IDatabase db, IAuthorizationService auth, ModelRepository repo)
{ {
_data = data;
_db = db; _db = db;
_auth = auth; _auth = auth;
_repo = repo;
} }
[HttpGet] [HttpGet]
[Authorize] [Authorize]
public async Task<ActionResult<JObject>> GetOwnSystem() public async Task<ActionResult<JObject>> GetOwnSystem()
{ {
var system = await _db.Execute(c => c.QuerySystem(User.CurrentSystem())); var system = await _db.Execute(c => _repo.GetSystem(c, User.CurrentSystem()));
return system.ToJson(User.ContextFor(system)); return system.ToJson(User.ContextFor(system));
} }
[HttpGet("{hid}")] [HttpGet("{hid}")]
public async Task<ActionResult<JObject>> GetSystem(string hid) public async Task<ActionResult<JObject>> GetSystem(string hid)
{ {
var system = await _data.GetSystemByHid(hid); var system = await _db.Execute(c => _repo.GetSystemByHid(c, hid));
if (system == null) return NotFound("System not found."); if (system == null) return NotFound("System not found.");
return Ok(system.ToJson(User.ContextFor(system))); return Ok(system.ToJson(User.ContextFor(system)));
} }
@ -69,13 +69,14 @@ namespace PluralKit.API
[HttpGet("{hid}/members")] [HttpGet("{hid}/members")]
public async Task<ActionResult<IEnumerable<JObject>>> GetMembers(string hid) public async Task<ActionResult<IEnumerable<JObject>>> GetMembers(string hid)
{ {
var system = await _data.GetSystemByHid(hid); var system = await _db.Execute(c => _repo.GetSystemByHid(c, hid));
if (system == null) return NotFound("System not found."); if (system == null)
return NotFound("System not found.");
if (!system.MemberListPrivacy.CanAccess(User.ContextFor(system))) if (!system.MemberListPrivacy.CanAccess(User.ContextFor(system)))
return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view member list."); return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view member list.");
var members = _data.GetSystemMembers(system); var members = _db.Execute(c => _repo.GetSystemMembers(c, system.Id));
return Ok(await members return Ok(await members
.Where(m => m.MemberVisibility.CanAccess(User.ContextFor(system))) .Where(m => m.MemberVisibility.CanAccess(User.ContextFor(system)))
.Select(m => m.ToJson(User.ContextFor(system))) .Select(m => m.ToJson(User.ContextFor(system)))
@ -87,14 +88,14 @@ namespace PluralKit.API
{ {
if (before == null) before = SystemClock.Instance.GetCurrentInstant(); if (before == null) before = SystemClock.Instance.GetCurrentInstant();
var system = await _data.GetSystemByHid(hid); await using var conn = await _db.Obtain();
var system = await _repo.GetSystemByHid(conn, hid);
if (system == null) return NotFound("System not found."); if (system == null) return NotFound("System not found.");
var auth = await _auth.AuthorizeAsync(User, system, "ViewFrontHistory"); var auth = await _auth.AuthorizeAsync(User, system, "ViewFrontHistory");
if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view front history."); if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view front history.");
using (var conn = await _db.Obtain())
{
var res = await conn.QueryAsync<SwitchesReturn>( var res = await conn.QueryAsync<SwitchesReturn>(
@"select *, array( @"select *, array(
select members.hid from switch_members, members select members.hid from switch_members, members
@ -105,21 +106,22 @@ namespace PluralKit.API
limit 100;", new {System = system.Id, Before = before}); limit 100;", new {System = system.Id, Before = before});
return Ok(res); return Ok(res);
} }
}
[HttpGet("{hid}/fronters")] [HttpGet("{hid}/fronters")]
public async Task<ActionResult<FrontersReturn>> GetFronters(string hid) public async Task<ActionResult<FrontersReturn>> GetFronters(string hid)
{ {
var system = await _data.GetSystemByHid(hid); await using var conn = await _db.Obtain();
var system = await _repo.GetSystemByHid(conn, hid);
if (system == null) return NotFound("System not found."); if (system == null) return NotFound("System not found.");
var auth = await _auth.AuthorizeAsync(User, system, "ViewFront"); var auth = await _auth.AuthorizeAsync(User, system, "ViewFront");
if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view fronter."); if (!auth.Succeeded) return StatusCode(StatusCodes.Status403Forbidden, "Unauthorized to view fronter.");
var sw = await _data.GetLatestSwitch(system.Id); var sw = await _repo.GetLatestSwitch(conn, system.Id);
if (sw == null) return NotFound("System has no registered switches."); if (sw == null) return NotFound("System has no registered switches.");
var members = _data.GetSwitchMembers(sw); var members = _repo.GetSwitchMembers(conn, sw.Id);
return Ok(new FrontersReturn return Ok(new FrontersReturn
{ {
Timestamp = sw.Timestamp, Timestamp = sw.Timestamp,
@ -131,7 +133,8 @@ namespace PluralKit.API
[Authorize] [Authorize]
public async Task<ActionResult<JObject>> EditSystem([FromBody] JObject changes) public async Task<ActionResult<JObject>> EditSystem([FromBody] JObject changes)
{ {
var system = await _db.Execute(c => c.QuerySystem(User.CurrentSystem())); await using var conn = await _db.Obtain();
var system = await _repo.GetSystem(conn, User.CurrentSystem());
SystemPatch patch; SystemPatch patch;
try try
@ -143,7 +146,7 @@ namespace PluralKit.API
return BadRequest(e.Message); return BadRequest(e.Message);
} }
await _db.Execute(conn => conn.UpdateSystem(system.Id, patch)); await _repo.UpdateSystem(conn, system!.Id, patch);
return Ok(system.ToJson(User.ContextFor(system))); return Ok(system.ToJson(User.ContextFor(system)));
} }
@ -154,11 +157,13 @@ namespace PluralKit.API
if (param.Members.Distinct().Count() != param.Members.Count) if (param.Members.Distinct().Count() != param.Members.Count)
return BadRequest("Duplicate members in member list."); return BadRequest("Duplicate members in member list.");
await using var conn = await _db.Obtain();
// We get the current switch, if it exists // We get the current switch, if it exists
var latestSwitch = await _data.GetLatestSwitch(User.CurrentSystem()); var latestSwitch = await _repo.GetLatestSwitch(conn, User.CurrentSystem());
if (latestSwitch != null) if (latestSwitch != null)
{ {
var latestSwitchMembers = _data.GetSwitchMembers(latestSwitch); var latestSwitchMembers = _repo.GetSwitchMembers(conn, latestSwitch.Id);
// Bail if this switch is identical to the latest one // Bail if this switch is identical to the latest one
if (await latestSwitchMembers.Select(m => m.Hid).SequenceEqualAsync(param.Members.ToAsyncEnumerable())) if (await latestSwitchMembers.Select(m => m.Hid).SequenceEqualAsync(param.Members.ToAsyncEnumerable()))
@ -166,9 +171,7 @@ namespace PluralKit.API
} }
// Resolve member objects for all given IDs // Resolve member objects for all given IDs
IEnumerable<PKMember> membersList; var membersList = (await conn.QueryAsync<PKMember>("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList();
using (var conn = await _db.Obtain())
membersList = (await conn.QueryAsync<PKMember>("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList();
foreach (var member in membersList) foreach (var member in membersList)
if (member.System != User.CurrentSystem()) if (member.System != User.CurrentSystem())
@ -182,12 +185,13 @@ namespace PluralKit.API
// We do this without .Select() since we want to have the early return bail if it doesn't find the member // We do this without .Select() since we want to have the early return bail if it doesn't find the member
foreach (var givenMemberId in param.Members) foreach (var givenMemberId in param.Members)
{ {
if (!membersDict.TryGetValue(givenMemberId, out var member)) return BadRequest($"Member '{givenMemberId}' not found."); if (!membersDict.TryGetValue(givenMemberId, out var member))
return BadRequest($"Member '{givenMemberId}' not found.");
membersInOrder.Add(member); membersInOrder.Add(member);
} }
// Finally, log the switch (yay!) // Finally, log the switch (yay!)
await _data.AddSwitch(User.CurrentSystem(), membersInOrder); await _repo.AddSwitch(conn, User.CurrentSystem(), membersInOrder.Select(m => m.Id).ToList());
return NoContent(); return NoContent();
} }
} }

View File

@ -15,7 +15,7 @@ namespace PluralKit.Bot
{ {
public class Context public class Context
{ {
private ILifetimeScope _provider; private readonly ILifetimeScope _provider;
private readonly DiscordRestClient _rest; private readonly DiscordRestClient _rest;
private readonly DiscordShardedClient _client; private readonly DiscordShardedClient _client;
@ -24,8 +24,8 @@ namespace PluralKit.Bot
private readonly Parameters _parameters; private readonly Parameters _parameters;
private readonly MessageContext _messageContext; private readonly MessageContext _messageContext;
private readonly IDataStore _data;
private readonly IDatabase _db; private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly PKSystem _senderSystem; private readonly PKSystem _senderSystem;
private readonly IMetrics _metrics; private readonly IMetrics _metrics;
@ -38,10 +38,10 @@ namespace PluralKit.Bot
_client = provider.Resolve<DiscordShardedClient>(); _client = provider.Resolve<DiscordShardedClient>();
_message = message; _message = message;
_shard = shard; _shard = shard;
_data = provider.Resolve<IDataStore>();
_senderSystem = senderSystem; _senderSystem = senderSystem;
_messageContext = messageContext; _messageContext = messageContext;
_db = provider.Resolve<IDatabase>(); _db = provider.Resolve<IDatabase>();
_repo = provider.Resolve<ModelRepository>();
_metrics = provider.Resolve<IMetrics>(); _metrics = provider.Resolve<IMetrics>();
_provider = provider; _provider = provider;
_parameters = new Parameters(message.Content.Substring(commandParseOffset)); _parameters = new Parameters(message.Content.Substring(commandParseOffset));
@ -61,9 +61,8 @@ namespace PluralKit.Bot
public Parameters Parameters => _parameters; public Parameters Parameters => _parameters;
// TODO: this is just here so the extension methods can access it; should it be public/private/?
internal IDataStore DataStore => _data;
internal IDatabase Database => _db; internal IDatabase Database => _db;
internal ModelRepository Repository => _repo;
public Task<DiscordMessage> Reply(string text = null, DiscordEmbed embed = null, IEnumerable<IMention> mentions = null) public Task<DiscordMessage> Reply(string text = null, DiscordEmbed embed = null, IEnumerable<IMention> mentions = null)
{ {

View File

@ -47,12 +47,14 @@ namespace PluralKit.Bot
// - A @mention of an account connected to the system (<@uid>) // - A @mention of an account connected to the system (<@uid>)
// - A system hid // - A system hid
await using var conn = await ctx.Database.Obtain();
// Direct IDs and mentions are both handled by the below method: // Direct IDs and mentions are both handled by the below method:
if (input.TryParseMention(out var id)) if (input.TryParseMention(out var id))
return await ctx.DataStore.GetSystemByAccount(id); return await ctx.Repository.GetSystemByAccount(conn, id);
// Finally, try HID parsing // Finally, try HID parsing
var system = await ctx.DataStore.GetSystemByHid(input); var system = await ctx.Repository.GetSystemByHid(conn, input);
return system; return system;
} }
@ -67,15 +69,16 @@ namespace PluralKit.Bot
// - a textual display name of a member *in your own system* // - a textual display name of a member *in your own system*
// First, if we have a system, try finding by member name in system // First, if we have a system, try finding by member name in system
if (ctx.System != null && await ctx.DataStore.GetMemberByName(ctx.System, input) is PKMember memberByName) await using var conn = await ctx.Database.Obtain();
if (ctx.System != null && await ctx.Repository.GetMemberByName(conn, ctx.System.Id, input) is PKMember memberByName)
return memberByName; return memberByName;
// Then, try member HID parsing: // Then, try member HID parsing:
if (await ctx.DataStore.GetMemberByHid(input) is PKMember memberByHid) if (await ctx.Repository.GetMemberByHid(conn, input) is PKMember memberByHid)
return memberByHid; return memberByHid;
// And if that again fails, we try finding a member with a display name matching the argument from the system // And if that again fails, we try finding a member with a display name matching the argument from the system
if (ctx.System != null && await ctx.DataStore.GetMemberByDisplayName(ctx.System, input) is PKMember memberByDisplayName) if (ctx.System != null && await ctx.Repository.GetMemberByDisplayName(conn, ctx.System.Id, input) is PKMember memberByDisplayName)
return memberByDisplayName; return memberByDisplayName;
// We didn't find anything, so we return null. // We didn't find anything, so we return null.
@ -103,9 +106,9 @@ namespace PluralKit.Bot
var input = ctx.PeekArgument(); var input = ctx.PeekArgument();
await using var conn = await ctx.Database.Obtain(); await using var conn = await ctx.Database.Obtain();
if (ctx.System != null && await conn.QueryGroupByName(ctx.System.Id, input) is {} byName) if (ctx.System != null && await ctx.Repository.GetGroupByName(conn, ctx.System.Id, input) is {} byName)
return byName; return byName;
if (await conn.QueryGroupByHid(input) is {} byHid) if (await ctx.Repository.GetGroupByHid(conn, input) is {} byHid)
return byHid; return byHid;
return null; return null;

View File

@ -36,15 +36,15 @@ namespace PluralKit.Bot
private struct WordPosition private struct WordPosition
{ {
// Start of the word // Start of the word
internal int startPos; internal readonly int startPos;
// End of the word // End of the word
internal int endPos; internal readonly int endPos;
// How much to advance word pointer afterwards to point at the start of the *next* word // How much to advance word pointer afterwards to point at the start of the *next* word
internal int advanceAfterWord; internal readonly int advanceAfterWord;
internal bool wasQuoted; internal readonly bool wasQuoted;
public WordPosition(int startPos, int endPos, int advanceAfterWord, bool wasQuoted) public WordPosition(int startPos, int endPos, int advanceAfterWord, bool wasQuoted)
{ {

View File

@ -12,10 +12,12 @@ namespace PluralKit.Bot
public class Autoproxy public class Autoproxy
{ {
private readonly IDatabase _db; private readonly IDatabase _db;
private readonly ModelRepository _repo;
public Autoproxy(IDatabase db) public Autoproxy(IDatabase db, ModelRepository repo)
{ {
_db = db; _db = db;
_repo = repo;
} }
public async Task AutoproxyRoot(Context ctx) public async Task AutoproxyRoot(Context ctx)
@ -87,8 +89,8 @@ namespace PluralKit.Bot
var fronters = ctx.MessageContext.LastSwitchMembers; var fronters = ctx.MessageContext.LastSwitchMembers;
var relevantMember = ctx.MessageContext.AutoproxyMode switch var relevantMember = ctx.MessageContext.AutoproxyMode switch
{ {
AutoproxyMode.Front => fronters.Length > 0 ? await _db.Execute(c => c.QueryMember(fronters[0])) : null, AutoproxyMode.Front => fronters.Length > 0 ? await _db.Execute(c => _repo.GetMember(c, fronters[0])) : null,
AutoproxyMode.Member => await _db.Execute(c => c.QueryMember(ctx.MessageContext.AutoproxyMember.Value)), AutoproxyMode.Member => await _db.Execute(c => _repo.GetMember(c, ctx.MessageContext.AutoproxyMember.Value)),
_ => null _ => null
}; };
@ -126,7 +128,7 @@ namespace PluralKit.Bot
private Task UpdateAutoproxy(Context ctx, AutoproxyMode autoproxyMode, MemberId? autoproxyMember) private Task UpdateAutoproxy(Context ctx, AutoproxyMode autoproxyMode, MemberId? autoproxyMember)
{ {
var patch = new SystemGuildPatch {AutoproxyMode = autoproxyMode, AutoproxyMember = autoproxyMember}; var patch = new SystemGuildPatch {AutoproxyMode = autoproxyMode, AutoproxyMember = autoproxyMember};
return _db.Execute(conn => conn.UpsertSystemGuild(ctx.System.Id, ctx.Guild.Id, patch)); return _db.Execute(conn => _repo.UpsertSystemGuild(conn, ctx.System.Id, ctx.Guild.Id, patch));
} }
} }
} }

View File

@ -17,10 +17,12 @@ namespace PluralKit.Bot
public class Groups public class Groups
{ {
private readonly IDatabase _db; private readonly IDatabase _db;
private readonly ModelRepository _repo;
public Groups(IDatabase db) public Groups(IDatabase db, ModelRepository repo)
{ {
_db = db; _db = db;
_repo = repo;
} }
public async Task CreateGroup(Context ctx) public async Task CreateGroup(Context ctx)
@ -40,14 +42,14 @@ namespace PluralKit.Bot
throw new PKError($"System has reached the maximum number of groups ({Limits.MaxGroupCount}). Please delete unused groups first in order to create new ones."); throw new PKError($"System has reached the maximum number of groups ({Limits.MaxGroupCount}). Please delete unused groups first in order to create new ones.");
// Warn if there's already a group by this name // Warn if there's already a group by this name
var existingGroup = await conn.QueryGroupByName(ctx.System.Id, groupName); var existingGroup = await _repo.GetGroupByName(conn, ctx.System.Id, groupName);
if (existingGroup != null) { if (existingGroup != null) {
var msg = $"{Emojis.Warn} You already have a group in your system with the name \"{existingGroup.Name}\" (with ID `{existingGroup.Hid}`). Do you want to create another group with the same name?"; var msg = $"{Emojis.Warn} You already have a group in your system with the name \"{existingGroup.Name}\" (with ID `{existingGroup.Hid}`). Do you want to create another group with the same name?";
if (!await ctx.PromptYesNo(msg)) if (!await ctx.PromptYesNo(msg))
throw new PKError("Group creation cancelled."); throw new PKError("Group creation cancelled.");
} }
var newGroup = await conn.CreateGroup(ctx.System.Id, groupName); var newGroup = await _repo.CreateGroup(conn, ctx.System.Id, groupName);
var eb = new DiscordEmbedBuilder() var eb = new DiscordEmbedBuilder()
.WithDescription($"Your new group, **{groupName}**, has been created, with the group ID **`{newGroup.Hid}`**.\nBelow are a couple of useful commands:") .WithDescription($"Your new group, **{groupName}**, has been created, with the group ID **`{newGroup.Hid}`**.\nBelow are a couple of useful commands:")
@ -70,14 +72,14 @@ namespace PluralKit.Bot
await using var conn = await _db.Obtain(); await using var conn = await _db.Obtain();
// Warn if there's already a group by this name // Warn if there's already a group by this name
var existingGroup = await conn.QueryGroupByName(ctx.System.Id, newName); var existingGroup = await _repo.GetGroupByName(conn, ctx.System.Id, newName);
if (existingGroup != null && existingGroup.Id != target.Id) { if (existingGroup != null && existingGroup.Id != target.Id) {
var msg = $"{Emojis.Warn} You already have a group in your system with the name \"{existingGroup.Name}\" (with ID `{existingGroup.Hid}`). Do you want to rename this member to that name too?"; var msg = $"{Emojis.Warn} You already have a group in your system with the name \"{existingGroup.Name}\" (with ID `{existingGroup.Hid}`). Do you want to rename this member to that name too?";
if (!await ctx.PromptYesNo(msg)) if (!await ctx.PromptYesNo(msg))
throw new PKError("Group creation cancelled."); throw new PKError("Group creation cancelled.");
} }
await conn.UpdateGroup(target.Id, new GroupPatch {Name = newName}); await _repo.UpdateGroup(conn, target.Id, new GroupPatch {Name = newName});
await ctx.Reply($"{Emojis.Success} Group name changed from **{target.Name}** to **{newName}**."); await ctx.Reply($"{Emojis.Success} Group name changed from **{target.Name}** to **{newName}**.");
} }
@ -89,7 +91,7 @@ namespace PluralKit.Bot
ctx.CheckOwnGroup(target); ctx.CheckOwnGroup(target);
var patch = new GroupPatch {DisplayName = Partial<string>.Null()}; var patch = new GroupPatch {DisplayName = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateGroup(target.Id, patch)); await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Group display name cleared."); await ctx.Reply($"{Emojis.Success} Group display name cleared.");
} }
@ -112,7 +114,7 @@ namespace PluralKit.Bot
var newDisplayName = ctx.RemainderOrNull(); var newDisplayName = ctx.RemainderOrNull();
var patch = new GroupPatch {DisplayName = Partial<string>.Present(newDisplayName)}; var patch = new GroupPatch {DisplayName = Partial<string>.Present(newDisplayName)};
await _db.Execute(conn => conn.UpdateGroup(target.Id, patch)); await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Group display name changed."); await ctx.Reply($"{Emojis.Success} Group display name changed.");
} }
@ -125,7 +127,7 @@ namespace PluralKit.Bot
ctx.CheckOwnGroup(target); ctx.CheckOwnGroup(target);
var patch = new GroupPatch {Description = Partial<string>.Null()}; var patch = new GroupPatch {Description = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateGroup(target.Id, patch)); await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Group description cleared."); await ctx.Reply($"{Emojis.Success} Group description cleared.");
} }
else if (!ctx.HasNext()) else if (!ctx.HasNext())
@ -154,7 +156,7 @@ namespace PluralKit.Bot
throw Errors.DescriptionTooLongError(description.Length); throw Errors.DescriptionTooLongError(description.Length);
var patch = new GroupPatch {Description = Partial<string>.Present(description)}; var patch = new GroupPatch {Description = Partial<string>.Present(description)};
await _db.Execute(conn => conn.UpdateGroup(target.Id, patch)); await _db.Execute(conn => _repo.UpdateGroup(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Group description changed."); await ctx.Reply($"{Emojis.Success} Group description changed.");
} }
@ -166,7 +168,7 @@ namespace PluralKit.Bot
{ {
ctx.CheckOwnGroup(target); ctx.CheckOwnGroup(target);
await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch {Icon = null})); await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch {Icon = null}));
await ctx.Reply($"{Emojis.Success} Group icon cleared."); await ctx.Reply($"{Emojis.Success} Group icon cleared.");
} }
@ -178,7 +180,7 @@ namespace PluralKit.Bot
throw Errors.InvalidUrl(img.Url); throw Errors.InvalidUrl(img.Url);
await AvatarUtils.VerifyAvatarOrThrow(img.Url); await AvatarUtils.VerifyAvatarOrThrow(img.Url);
await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch {Icon = img.Url})); await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch {Icon = img.Url}));
var msg = img.Source switch var msg = img.Source switch
{ {
@ -282,7 +284,7 @@ namespace PluralKit.Bot
var system = await GetGroupSystem(ctx, target, conn); var system = await GetGroupSystem(ctx, target, conn);
var pctx = ctx.LookupContextFor(system); var pctx = ctx.LookupContextFor(system);
var memberCount = await conn.QueryGroupMemberCount(target.Id, PrivacyLevel.Public); var memberCount = await _repo.GetGroupMemberCount(conn, target.Id, PrivacyLevel.Public);
var nameField = target.Name; var nameField = target.Name;
if (system.Name != null) if (system.Name != null)
@ -333,7 +335,7 @@ namespace PluralKit.Bot
.Select(m => m.Id) .Select(m => m.Id)
.Distinct() .Distinct()
.ToList(); .ToList();
await conn.AddMembersToGroup(target.Id, membersNotInGroup); await _repo.AddMembersToGroup(conn, target.Id, membersNotInGroup);
if (membersNotInGroup.Count == members.Count) if (membersNotInGroup.Count == members.Count)
await ctx.Reply($"{Emojis.Success} {"members".ToQuantity(membersNotInGroup.Count)} added to group."); await ctx.Reply($"{Emojis.Success} {"members".ToQuantity(membersNotInGroup.Count)} added to group.");
@ -347,7 +349,7 @@ namespace PluralKit.Bot
.Select(m => m.Id) .Select(m => m.Id)
.Distinct() .Distinct()
.ToList(); .ToList();
await conn.RemoveMembersFromGroup(target.Id, membersInGroup); await _repo.RemoveMembersFromGroup(conn, target.Id, membersInGroup);
if (membersInGroup.Count == members.Count) if (membersInGroup.Count == members.Count)
await ctx.Reply($"{Emojis.Success} {"members".ToQuantity(membersInGroup.Count)} removed from group."); await ctx.Reply($"{Emojis.Success} {"members".ToQuantity(membersInGroup.Count)} removed from group.");
@ -422,7 +424,7 @@ namespace PluralKit.Bot
async Task SetAll(PrivacyLevel level) async Task SetAll(PrivacyLevel level)
{ {
await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch().WithAllPrivacy(level))); await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch().WithAllPrivacy(level)));
if (level == PrivacyLevel.Private) if (level == PrivacyLevel.Private)
await ctx.Reply($"{Emojis.Success} All {target.Name}'s privacy settings have been set to **{level.LevelName()}**. Other accounts will now see nothing on the group card."); await ctx.Reply($"{Emojis.Success} All {target.Name}'s privacy settings have been set to **{level.LevelName()}**. Other accounts will now see nothing on the group card.");
@ -432,7 +434,7 @@ namespace PluralKit.Bot
async Task SetLevel(GroupPrivacySubject subject, PrivacyLevel level) async Task SetLevel(GroupPrivacySubject subject, PrivacyLevel level)
{ {
await _db.Execute(c => c.UpdateGroup(target.Id, new GroupPatch().WithPrivacy(subject, level))); await _db.Execute(c => _repo.UpdateGroup(c, target.Id, new GroupPatch().WithPrivacy(subject, level)));
var subjectName = subject switch var subjectName = subject switch
{ {
@ -475,17 +477,17 @@ namespace PluralKit.Bot
if (!await ctx.ConfirmWithReply(target.Hid)) if (!await ctx.ConfirmWithReply(target.Hid))
throw new PKError($"Group deletion cancelled. Note that you must reply with your group ID (`{target.Hid}`) *verbatim*."); throw new PKError($"Group deletion cancelled. Note that you must reply with your group ID (`{target.Hid}`) *verbatim*.");
await _db.Execute(conn => conn.DeleteGroup(target.Id)); await _db.Execute(conn => _repo.DeleteGroup(conn, target.Id));
await ctx.Reply($"{Emojis.Success} Group deleted."); await ctx.Reply($"{Emojis.Success} Group deleted.");
} }
private static async Task<PKSystem> GetGroupSystem(Context ctx, PKGroup target, IPKConnection conn) private async Task<PKSystem> GetGroupSystem(Context ctx, PKGroup target, IPKConnection conn)
{ {
var system = ctx.System; var system = ctx.System;
if (system?.Id == target.System) if (system?.Id == target.System)
return system; return system;
return await conn.QuerySystem(target.System)!; return await _repo.GetSystem(conn, target.System)!;
} }
} }
} }

View File

@ -15,7 +15,7 @@ namespace PluralKit.Bot
{ {
public class ImportExport public class ImportExport
{ {
private DataFileService _dataFiles; private readonly DataFileService _dataFiles;
public ImportExport(DataFileService dataFiles) public ImportExport(DataFileService dataFiles)
{ {
_dataFiles = dataFiles; _dataFiles = dataFiles;

View File

@ -8,15 +8,15 @@ namespace PluralKit.Bot
{ {
public class Member public class Member
{ {
private IDataStore _data; private readonly IDatabase _db;
private IDatabase _db; private readonly ModelRepository _repo;
private EmbedService _embeds; private readonly EmbedService _embeds;
public Member(IDataStore data, EmbedService embeds, IDatabase db) public Member(EmbedService embeds, IDatabase db, ModelRepository repo)
{ {
_data = data;
_embeds = embeds; _embeds = embeds;
_db = db; _db = db;
_repo = repo;
} }
public async Task NewMember(Context ctx) { public async Task NewMember(Context ctx) {
@ -27,7 +27,7 @@ namespace PluralKit.Bot
if (memberName.Length > Limits.MaxMemberNameLength) throw Errors.MemberNameTooLongError(memberName.Length); if (memberName.Length > Limits.MaxMemberNameLength) throw Errors.MemberNameTooLongError(memberName.Length);
// Warn if there's already a member by this name // Warn if there's already a member by this name
var existingMember = await _data.GetMemberByName(ctx.System, memberName); var existingMember = await _db.Execute(c => _repo.GetMemberByName(c, ctx.System.Id, memberName));
if (existingMember != null) { if (existingMember != null) {
var msg = $"{Emojis.Warn} You already have a member in your system with the name \"{existingMember.NameFor(ctx)}\" (with ID `{existingMember.Hid}`). Do you want to create another member with the same name?"; var msg = $"{Emojis.Warn} You already have a member in your system with the name \"{existingMember.NameFor(ctx)}\" (with ID `{existingMember.Hid}`). Do you want to create another member with the same name?";
if (!await ctx.PromptYesNo(msg)) throw new PKError("Member creation cancelled."); if (!await ctx.PromptYesNo(msg)) throw new PKError("Member creation cancelled.");
@ -36,12 +36,12 @@ namespace PluralKit.Bot
await using var conn = await _db.Obtain(); await using var conn = await _db.Obtain();
// Enforce per-system member limit // Enforce per-system member limit
var memberCount = await conn.GetSystemMemberCount(ctx.System.Id); var memberCount = await _repo.GetSystemMemberCount(conn, ctx.System.Id);
if (memberCount >= Limits.MaxMemberCount) if (memberCount >= Limits.MaxMemberCount)
throw Errors.MemberLimitReachedError; throw Errors.MemberLimitReachedError;
// Create the member // Create the member
var member = await conn.CreateMember(ctx.System.Id, memberName); var member = await _repo.CreateMember(conn, ctx.System.Id, memberName);
memberCount++; memberCount++;
// Send confirmation and space hint // Send confirmation and space hint
@ -63,9 +63,13 @@ namespace PluralKit.Bot
// TODO: don't buffer these, find something else to do ig // TODO: don't buffer these, find something else to do ig
List<PKMember> members; var members = await _db.Execute(c =>
if (ctx.MatchFlag("all", "a")) members = await _data.GetSystemMembers(ctx.System).ToListAsync(); {
else members = await _data.GetSystemMembers(ctx.System).Where(m => m.MemberVisibility == PrivacyLevel.Public).ToListAsync(); if (ctx.MatchFlag("all", "a"))
return _repo.GetSystemMembers(c, ctx.System.Id);
return _repo.GetSystemMembers(c, ctx.System.Id)
.Where(m => m.MemberVisibility == PrivacyLevel.Public);
}).ToListAsync();
if (members == null || !members.Any()) if (members == null || !members.Any())
throw Errors.NoMembersError; throw Errors.NoMembersError;
@ -75,8 +79,7 @@ namespace PluralKit.Bot
public async Task ViewMember(Context ctx, PKMember target) public async Task ViewMember(Context ctx, PKMember target)
{ {
var system = await _db.Execute(c => _repo.GetSystem(c, target.System));
var system = await _db.Execute(c => c.QuerySystem(target.System));
await ctx.Reply(embed: await _embeds.CreateMemberEmbed(system, target, ctx.Guild, ctx.LookupContextFor(system))); await ctx.Reply(embed: await _embeds.CreateMemberEmbed(system, target, ctx.Guild, ctx.LookupContextFor(system)));
} }
} }

View File

@ -11,10 +11,12 @@ namespace PluralKit.Bot
public class MemberAvatar public class MemberAvatar
{ {
private readonly IDatabase _db; private readonly IDatabase _db;
private readonly ModelRepository _repo;
public MemberAvatar(IDatabase db) public MemberAvatar(IDatabase db, ModelRepository repo)
{ {
_db = db; _db = db;
_repo = repo;
} }
private async Task AvatarClear(AvatarLocation location, Context ctx, PKMember target, MemberGuildSettings? mgs) private async Task AvatarClear(AvatarLocation location, Context ctx, PKMember target, MemberGuildSettings? mgs)
@ -67,14 +69,14 @@ namespace PluralKit.Bot
public async Task ServerAvatar(Context ctx, PKMember target) public async Task ServerAvatar(Context ctx, PKMember target)
{ {
ctx.CheckGuildContext(); ctx.CheckGuildContext();
var guildData = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)); var guildData = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id));
await AvatarCommandTree(AvatarLocation.Server, ctx, target, guildData); await AvatarCommandTree(AvatarLocation.Server, ctx, target, guildData);
} }
public async Task Avatar(Context ctx, PKMember target) public async Task Avatar(Context ctx, PKMember target)
{ {
var guildData = ctx.Guild != null ? var guildData = ctx.Guild != null ?
await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)) await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id))
: null; : null;
await AvatarCommandTree(AvatarLocation.Member, ctx, target, guildData); await AvatarCommandTree(AvatarLocation.Member, ctx, target, guildData);
@ -150,10 +152,10 @@ namespace PluralKit.Bot
{ {
case AvatarLocation.Server: case AvatarLocation.Server:
var serverPatch = new MemberGuildPatch { AvatarUrl = url }; var serverPatch = new MemberGuildPatch { AvatarUrl = url };
return _db.Execute(c => c.UpsertMemberGuild(target.Id, ctx.Guild.Id, serverPatch)); return _db.Execute(c => _repo.UpsertMemberGuild(c, target.Id, ctx.Guild.Id, serverPatch));
case AvatarLocation.Member: case AvatarLocation.Member:
var memberPatch = new MemberPatch { AvatarUrl = url }; var memberPatch = new MemberPatch { AvatarUrl = url };
return _db.Execute(c => c.UpdateMember(target.Id, memberPatch)); return _db.Execute(c => _repo.UpdateMember(c, target.Id, memberPatch));
default: default:
throw new ArgumentOutOfRangeException($"Unknown avatar location {location}"); throw new ArgumentOutOfRangeException($"Unknown avatar location {location}");
} }

View File

@ -15,13 +15,13 @@ namespace PluralKit.Bot
{ {
public class MemberEdit public class MemberEdit
{ {
private readonly IDataStore _data;
private readonly IDatabase _db; private readonly IDatabase _db;
private readonly ModelRepository _repo;
public MemberEdit(IDataStore data, IDatabase db) public MemberEdit(IDatabase db, ModelRepository repo)
{ {
_data = data;
_db = db; _db = db;
_repo = repo;
} }
public async Task Name(Context ctx, PKMember target) { public async Task Name(Context ctx, PKMember target) {
@ -35,7 +35,7 @@ namespace PluralKit.Bot
if (newName.Length > Limits.MaxMemberNameLength) throw Errors.MemberNameTooLongError(newName.Length); if (newName.Length > Limits.MaxMemberNameLength) throw Errors.MemberNameTooLongError(newName.Length);
// Warn if there's already a member by this name // Warn if there's already a member by this name
var existingMember = await _data.GetMemberByName(ctx.System, newName); var existingMember = await _db.Execute(conn => _repo.GetMemberByName(conn, ctx.System.Id, newName));
if (existingMember != null && existingMember.Id != target.Id) if (existingMember != null && existingMember.Id != target.Id)
{ {
var msg = $"{Emojis.Warn} You already have a member in your system with the name \"{existingMember.NameFor(ctx)}\" (`{existingMember.Hid}`). Do you want to rename this member to that name too?"; var msg = $"{Emojis.Warn} You already have a member in your system with the name \"{existingMember.NameFor(ctx)}\" (`{existingMember.Hid}`). Do you want to rename this member to that name too?";
@ -44,7 +44,7 @@ namespace PluralKit.Bot
// Rename the member // Rename the member
var patch = new MemberPatch {Name = Partial<string>.Present(newName)}; var patch = new MemberPatch {Name = Partial<string>.Present(newName)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member renamed."); await ctx.Reply($"{Emojis.Success} Member renamed.");
if (newName.Contains(" ")) await ctx.Reply($"{Emojis.Note} Note that this member's name now contains spaces. You will need to surround it with \"double quotes\" when using commands referring to it."); if (newName.Contains(" ")) await ctx.Reply($"{Emojis.Note} Note that this member's name now contains spaces. You will need to surround it with \"double quotes\" when using commands referring to it.");
@ -52,7 +52,7 @@ namespace PluralKit.Bot
if (ctx.Guild != null) if (ctx.Guild != null)
{ {
var memberGuildConfig = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)); var memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id));
if (memberGuildConfig.DisplayName != null) if (memberGuildConfig.DisplayName != null)
await ctx.Reply($"{Emojis.Note} Note that this member has a server name set ({memberGuildConfig.DisplayName}) in this server ({ctx.Guild.Name}), and will be proxied using that name here."); await ctx.Reply($"{Emojis.Note} Note that this member has a server name set ({memberGuildConfig.DisplayName}) in this server ({ctx.Guild.Name}), and will be proxied using that name here.");
} }
@ -69,7 +69,7 @@ namespace PluralKit.Bot
CheckEditMemberPermission(ctx, target); CheckEditMemberPermission(ctx, target);
var patch = new MemberPatch {Description = Partial<string>.Null()}; var patch = new MemberPatch {Description = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member description cleared."); await ctx.Reply($"{Emojis.Success} Member description cleared.");
} }
else if (!ctx.HasNext()) else if (!ctx.HasNext())
@ -100,7 +100,7 @@ namespace PluralKit.Bot
throw Errors.DescriptionTooLongError(description.Length); throw Errors.DescriptionTooLongError(description.Length);
var patch = new MemberPatch {Description = Partial<string>.Present(description)}; var patch = new MemberPatch {Description = Partial<string>.Present(description)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member description changed."); await ctx.Reply($"{Emojis.Success} Member description changed.");
} }
@ -111,7 +111,7 @@ namespace PluralKit.Bot
{ {
CheckEditMemberPermission(ctx, target); CheckEditMemberPermission(ctx, target);
var patch = new MemberPatch {Pronouns = Partial<string>.Null()}; var patch = new MemberPatch {Pronouns = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member pronouns cleared."); await ctx.Reply($"{Emojis.Success} Member pronouns cleared.");
} }
else if (!ctx.HasNext()) else if (!ctx.HasNext())
@ -136,7 +136,7 @@ namespace PluralKit.Bot
throw Errors.MemberPronounsTooLongError(pronouns.Length); throw Errors.MemberPronounsTooLongError(pronouns.Length);
var patch = new MemberPatch {Pronouns = Partial<string>.Present(pronouns)}; var patch = new MemberPatch {Pronouns = Partial<string>.Present(pronouns)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member pronouns changed."); await ctx.Reply($"{Emojis.Success} Member pronouns changed.");
} }
@ -150,7 +150,7 @@ namespace PluralKit.Bot
CheckEditMemberPermission(ctx, target); CheckEditMemberPermission(ctx, target);
var patch = new MemberPatch {Color = Partial<string>.Null()}; var patch = new MemberPatch {Color = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member color cleared."); await ctx.Reply($"{Emojis.Success} Member color cleared.");
} }
@ -182,7 +182,7 @@ namespace PluralKit.Bot
if (!Regex.IsMatch(color, "^[0-9a-fA-F]{6}$")) throw Errors.InvalidColorError(color); if (!Regex.IsMatch(color, "^[0-9a-fA-F]{6}$")) throw Errors.InvalidColorError(color);
var patch = new MemberPatch {Color = Partial<string>.Present(color.ToLowerInvariant())}; var patch = new MemberPatch {Color = Partial<string>.Present(color.ToLowerInvariant())};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply(embed: new DiscordEmbedBuilder() await ctx.Reply(embed: new DiscordEmbedBuilder()
.WithTitle($"{Emojis.Success} Member color changed.") .WithTitle($"{Emojis.Success} Member color changed.")
@ -198,7 +198,7 @@ namespace PluralKit.Bot
CheckEditMemberPermission(ctx, target); CheckEditMemberPermission(ctx, target);
var patch = new MemberPatch {Birthday = Partial<LocalDate?>.Null()}; var patch = new MemberPatch {Birthday = Partial<LocalDate?>.Null()};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member birthdate cleared."); await ctx.Reply($"{Emojis.Success} Member birthdate cleared.");
} }
@ -223,7 +223,7 @@ namespace PluralKit.Bot
if (birthday == null) throw Errors.BirthdayParseError(birthdayStr); if (birthday == null) throw Errors.BirthdayParseError(birthdayStr);
var patch = new MemberPatch {Birthday = Partial<LocalDate?>.Present(birthday)}; var patch = new MemberPatch {Birthday = Partial<LocalDate?>.Present(birthday)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member birthdate changed."); await ctx.Reply($"{Emojis.Success} Member birthdate changed.");
} }
@ -235,7 +235,7 @@ namespace PluralKit.Bot
MemberGuildSettings memberGuildConfig = null; MemberGuildSettings memberGuildConfig = null;
if (ctx.Guild != null) if (ctx.Guild != null)
memberGuildConfig = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)); memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id));
var eb = new DiscordEmbedBuilder().WithTitle($"Member names") var eb = new DiscordEmbedBuilder().WithTitle($"Member names")
.WithFooter($"Member ID: {target.Hid} | Active name in bold. Server name overrides display name, which overrides base name."); .WithFooter($"Member ID: {target.Hid} | Active name in bold. Server name overrides display name, which overrides base name.");
@ -271,7 +271,7 @@ namespace PluralKit.Bot
var successStr = text; var successStr = text;
if (ctx.Guild != null) if (ctx.Guild != null)
{ {
var memberGuildConfig = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)); var memberGuildConfig = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id));
if (memberGuildConfig.DisplayName != null) if (memberGuildConfig.DisplayName != null)
successStr += $" However, this member has a server name set in this server ({ctx.Guild.Name}), and will be proxied using that name, \"{memberGuildConfig.DisplayName}\", here."; successStr += $" However, this member has a server name set in this server ({ctx.Guild.Name}), and will be proxied using that name, \"{memberGuildConfig.DisplayName}\", here.";
} }
@ -284,7 +284,7 @@ namespace PluralKit.Bot
CheckEditMemberPermission(ctx, target); CheckEditMemberPermission(ctx, target);
var patch = new MemberPatch {DisplayName = Partial<string>.Null()}; var patch = new MemberPatch {DisplayName = Partial<string>.Null()};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await PrintSuccess($"{Emojis.Success} Member display name cleared. This member will now be proxied using their member name \"{target.NameFor(ctx)}\"."); await PrintSuccess($"{Emojis.Success} Member display name cleared. This member will now be proxied using their member name \"{target.NameFor(ctx)}\".");
} }
@ -303,7 +303,7 @@ namespace PluralKit.Bot
var newDisplayName = ctx.RemainderOrNull(); var newDisplayName = ctx.RemainderOrNull();
var patch = new MemberPatch {DisplayName = Partial<string>.Present(newDisplayName)}; var patch = new MemberPatch {DisplayName = Partial<string>.Present(newDisplayName)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await PrintSuccess($"{Emojis.Success} Member display name changed. This member will now be proxied using the name \"{newDisplayName}\"."); await PrintSuccess($"{Emojis.Success} Member display name changed. This member will now be proxied using the name \"{newDisplayName}\".");
} }
@ -318,7 +318,7 @@ namespace PluralKit.Bot
CheckEditMemberPermission(ctx, target); CheckEditMemberPermission(ctx, target);
var patch = new MemberGuildPatch {DisplayName = null}; var patch = new MemberGuildPatch {DisplayName = null};
await _db.Execute(conn => conn.UpsertMemberGuild(target.Id, ctx.Guild.Id, patch)); await _db.Execute(conn => _repo.UpsertMemberGuild(conn, target.Id, ctx.Guild.Id, patch));
if (target.DisplayName != null) if (target.DisplayName != null)
await ctx.Reply($"{Emojis.Success} Member server name cleared. This member will now be proxied using their global display name \"{target.DisplayName}\" in this server ({ctx.Guild.Name})."); await ctx.Reply($"{Emojis.Success} Member server name cleared. This member will now be proxied using their global display name \"{target.DisplayName}\" in this server ({ctx.Guild.Name}).");
@ -340,7 +340,7 @@ namespace PluralKit.Bot
var newServerName = ctx.RemainderOrNull(); var newServerName = ctx.RemainderOrNull();
var patch = new MemberGuildPatch {DisplayName = newServerName}; var patch = new MemberGuildPatch {DisplayName = newServerName};
await _db.Execute(conn => conn.UpsertMemberGuild(target.Id, ctx.Guild.Id, patch)); await _db.Execute(conn => _repo.UpsertMemberGuild(conn, target.Id, ctx.Guild.Id, patch));
await ctx.Reply($"{Emojis.Success} Member server name changed. This member will now be proxied using the name \"{newServerName}\" in this server ({ctx.Guild.Name})."); await ctx.Reply($"{Emojis.Success} Member server name changed. This member will now be proxied using the name \"{newServerName}\" in this server ({ctx.Guild.Name}).");
} }
@ -365,7 +365,7 @@ namespace PluralKit.Bot
}; };
var patch = new MemberPatch {KeepProxy = Partial<bool>.Present(newValue)}; var patch = new MemberPatch {KeepProxy = Partial<bool>.Present(newValue)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
if (newValue) if (newValue)
await ctx.Reply($"{Emojis.Success} Member proxy tags will now be included in the resulting message when proxying."); await ctx.Reply($"{Emojis.Success} Member proxy tags will now be included in the resulting message when proxying.");
@ -398,11 +398,11 @@ namespace PluralKit.Bot
// Get guild settings (mostly for warnings and such) // Get guild settings (mostly for warnings and such)
MemberGuildSettings guildSettings = null; MemberGuildSettings guildSettings = null;
if (ctx.Guild != null) if (ctx.Guild != null)
guildSettings = await _db.Execute(c => c.QueryOrInsertMemberGuildConfig(ctx.Guild.Id, target.Id)); guildSettings = await _db.Execute(c => _repo.GetMemberGuild(c, ctx.Guild.Id, target.Id));
async Task SetAll(PrivacyLevel level) async Task SetAll(PrivacyLevel level)
{ {
await _db.Execute(c => c.UpdateMember(target.Id, new MemberPatch().WithAllPrivacy(level))); await _db.Execute(c => _repo.UpdateMember(c, target.Id, new MemberPatch().WithAllPrivacy(level)));
if (level == PrivacyLevel.Private) if (level == PrivacyLevel.Private)
await ctx.Reply($"{Emojis.Success} All {target.NameFor(ctx)}'s privacy settings have been set to **{level.LevelName()}**. Other accounts will now see nothing on the member card."); await ctx.Reply($"{Emojis.Success} All {target.NameFor(ctx)}'s privacy settings have been set to **{level.LevelName()}**. Other accounts will now see nothing on the member card.");
@ -412,7 +412,7 @@ namespace PluralKit.Bot
async Task SetLevel(MemberPrivacySubject subject, PrivacyLevel level) async Task SetLevel(MemberPrivacySubject subject, PrivacyLevel level)
{ {
await _db.Execute(c => c.UpdateMember(target.Id, new MemberPatch().WithPrivacy(subject, level))); await _db.Execute(c => _repo.UpdateMember(c, target.Id, new MemberPatch().WithPrivacy(subject, level)));
var subjectName = subject switch var subjectName = subject switch
{ {
@ -472,7 +472,7 @@ namespace PluralKit.Bot
await ctx.Reply($"{Emojis.Warn} Are you sure you want to delete \"{target.NameFor(ctx)}\"? If so, reply to this message with the member's ID (`{target.Hid}`). __***This cannot be undone!***__"); await ctx.Reply($"{Emojis.Warn} Are you sure you want to delete \"{target.NameFor(ctx)}\"? If so, reply to this message with the member's ID (`{target.Hid}`). __***This cannot be undone!***__");
if (!await ctx.ConfirmWithReply(target.Hid)) throw Errors.MemberDeleteCancelled; if (!await ctx.ConfirmWithReply(target.Hid)) throw Errors.MemberDeleteCancelled;
await _db.Execute(conn => conn.DeleteMember(target.Id)); await _db.Execute(conn => _repo.DeleteMember(conn, target.Id));
await ctx.Reply($"{Emojis.Success} Member deleted."); await ctx.Reply($"{Emojis.Success} Member deleted.");
} }

View File

@ -10,10 +10,12 @@ namespace PluralKit.Bot
public class MemberProxy public class MemberProxy
{ {
private readonly IDatabase _db; private readonly IDatabase _db;
private readonly ModelRepository _repo;
public MemberProxy(IDatabase db) public MemberProxy(IDatabase db, ModelRepository repo)
{ {
_db = db; _db = db;
_repo = repo;
} }
public async Task Proxy(Context ctx, PKMember target) public async Task Proxy(Context ctx, PKMember target)
@ -55,7 +57,7 @@ namespace PluralKit.Bot
} }
var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(new ProxyTag[0])}; var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(new ProxyTag[0])};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Proxy tags cleared."); await ctx.Reply($"{Emojis.Success} Proxy tags cleared.");
} }
@ -83,7 +85,7 @@ namespace PluralKit.Bot
var newTags = target.ProxyTags.ToList(); var newTags = target.ProxyTags.ToList();
newTags.Add(tagToAdd); newTags.Add(tagToAdd);
var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(newTags.ToArray())}; var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(newTags.ToArray())};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Added proxy tags {tagToAdd.ProxyString.AsCode()}."); await ctx.Reply($"{Emojis.Success} Added proxy tags {tagToAdd.ProxyString.AsCode()}.");
} }
@ -100,7 +102,7 @@ namespace PluralKit.Bot
var newTags = target.ProxyTags.ToList(); var newTags = target.ProxyTags.ToList();
newTags.Remove(tagToRemove); newTags.Remove(tagToRemove);
var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(newTags.ToArray())}; var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(newTags.ToArray())};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Removed proxy tags {tagToRemove.ProxyString.AsCode()}."); await ctx.Reply($"{Emojis.Success} Removed proxy tags {tagToRemove.ProxyString.AsCode()}.");
} }
@ -124,7 +126,7 @@ namespace PluralKit.Bot
var newTags = new[] {requestedTag}; var newTags = new[] {requestedTag};
var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(newTags)}; var patch = new MemberPatch {ProxyTags = Partial<ProxyTag[]>.Present(newTags)};
await _db.Execute(conn => conn.UpdateMember(target.Id, patch)); await _db.Execute(conn => _repo.UpdateMember(conn, target.Id, patch));
await ctx.Reply($"{Emojis.Success} Member proxy tags set to {requestedTag.ProxyString.AsCode()}."); await ctx.Reply($"{Emojis.Success} Member proxy tags set to {requestedTag.ProxyString.AsCode()}.");
} }

View File

@ -18,21 +18,23 @@ using DSharpPlus.Entities;
namespace PluralKit.Bot { namespace PluralKit.Bot {
public class Misc public class Misc
{ {
private BotConfig _botConfig; private readonly BotConfig _botConfig;
private IMetrics _metrics; private readonly IMetrics _metrics;
private CpuStatService _cpu; private readonly CpuStatService _cpu;
private ShardInfoService _shards; private readonly ShardInfoService _shards;
private IDataStore _data; private readonly EmbedService _embeds;
private EmbedService _embeds; private readonly IDatabase _db;
private readonly ModelRepository _repo;
public Misc(BotConfig botConfig, IMetrics metrics, CpuStatService cpu, ShardInfoService shards, IDataStore data, EmbedService embeds) public Misc(BotConfig botConfig, IMetrics metrics, CpuStatService cpu, ShardInfoService shards, EmbedService embeds, ModelRepository repo, IDatabase db)
{ {
_botConfig = botConfig; _botConfig = botConfig;
_metrics = metrics; _metrics = metrics;
_cpu = cpu; _cpu = cpu;
_shards = shards; _shards = shards;
_data = data;
_embeds = embeds; _embeds = embeds;
_repo = repo;
_db = db;
} }
public async Task Invite(Context ctx) public async Task Invite(Context ctx)
@ -198,7 +200,7 @@ namespace PluralKit.Bot {
messageId = ulong.Parse(match.Groups[1].Value); messageId = ulong.Parse(match.Groups[1].Value);
else throw new PKSyntaxError($"Could not parse {word.AsCode()} as a message ID or link."); else throw new PKSyntaxError($"Could not parse {word.AsCode()} as a message ID or link.");
var message = await _data.GetMessage(messageId); var message = await _db.Execute(c => _repo.GetMessage(c, messageId));
if (message == null) throw Errors.MessageNotFound(messageId); if (message == null) throw Errors.MessageNotFound(messageId);
await ctx.Reply(embed: await _embeds.CreateMessageInfoEmbed(ctx.Shard, message)); await ctx.Reply(embed: await _embeds.CreateMessageInfoEmbed(ctx.Shard, message));

View File

@ -12,12 +12,14 @@ namespace PluralKit.Bot
{ {
public class ServerConfig public class ServerConfig
{ {
private IDatabase _db; private readonly IDatabase _db;
private LoggerCleanService _cleanService; private readonly ModelRepository _repo;
public ServerConfig(LoggerCleanService cleanService, IDatabase db) private readonly LoggerCleanService _cleanService;
public ServerConfig(LoggerCleanService cleanService, IDatabase db, ModelRepository repo)
{ {
_cleanService = cleanService; _cleanService = cleanService;
_db = db; _db = db;
_repo = repo;
} }
public async Task SetLogChannel(Context ctx) public async Task SetLogChannel(Context ctx)
@ -32,7 +34,7 @@ namespace PluralKit.Bot
if (channel == null || channel.GuildId != ctx.Guild.Id) throw Errors.ChannelNotFound(channelString); if (channel == null || channel.GuildId != ctx.Guild.Id) throw Errors.ChannelNotFound(channelString);
var patch = new GuildPatch {LogChannel = channel?.Id}; var patch = new GuildPatch {LogChannel = channel?.Id};
await _db.Execute(conn => conn.UpsertGuild(ctx.Guild.Id, patch)); await _db.Execute(conn => _repo.UpsertGuild(conn, ctx.Guild.Id, patch));
if (channel != null) if (channel != null)
await ctx.Reply($"{Emojis.Success} Proxy logging channel set to #{channel.Name}."); await ctx.Reply($"{Emojis.Success} Proxy logging channel set to #{channel.Name}.");
@ -59,7 +61,7 @@ namespace PluralKit.Bot
ulong? logChannel = null; ulong? logChannel = null;
await using (var conn = await _db.Obtain()) await using (var conn = await _db.Obtain())
{ {
var config = await conn.QueryOrInsertGuildConfig(ctx.Guild.Id); var config = await _repo.GetGuild(conn, ctx.Guild.Id);
logChannel = config.LogChannel; logChannel = config.LogChannel;
var blacklist = config.LogBlacklist.ToHashSet(); var blacklist = config.LogBlacklist.ToHashSet();
if (enable) if (enable)
@ -68,7 +70,7 @@ namespace PluralKit.Bot
blacklist.UnionWith(affectedChannels.Select(c => c.Id)); blacklist.UnionWith(affectedChannels.Select(c => c.Id));
var patch = new GuildPatch {LogBlacklist = blacklist.ToArray()}; var patch = new GuildPatch {LogBlacklist = blacklist.ToArray()};
await conn.UpsertGuild(ctx.Guild.Id, patch); await _repo.UpsertGuild(conn, ctx.Guild.Id, patch);
} }
await ctx.Reply( await ctx.Reply(
@ -80,7 +82,7 @@ namespace PluralKit.Bot
{ {
ctx.CheckGuildContext().CheckAuthorPermission(Permissions.ManageGuild, "Manage Server"); ctx.CheckGuildContext().CheckAuthorPermission(Permissions.ManageGuild, "Manage Server");
var blacklist = await _db.Execute(c => c.QueryOrInsertGuildConfig(ctx.Guild.Id)); var blacklist = await _db.Execute(c => _repo.GetGuild(c, ctx.Guild.Id));
// Resolve all channels from the cache and order by position // Resolve all channels from the cache and order by position
var channels = blacklist.Blacklist var channels = blacklist.Blacklist
@ -139,7 +141,7 @@ namespace PluralKit.Bot
await using (var conn = await _db.Obtain()) await using (var conn = await _db.Obtain())
{ {
var guild = await conn.QueryOrInsertGuildConfig(ctx.Guild.Id); var guild = await _repo.GetGuild(conn, ctx.Guild.Id);
var blacklist = guild.Blacklist.ToHashSet(); var blacklist = guild.Blacklist.ToHashSet();
if (shouldAdd) if (shouldAdd)
blacklist.UnionWith(affectedChannels.Select(c => c.Id)); blacklist.UnionWith(affectedChannels.Select(c => c.Id));
@ -147,7 +149,7 @@ namespace PluralKit.Bot
blacklist.ExceptWith(affectedChannels.Select(c => c.Id)); blacklist.ExceptWith(affectedChannels.Select(c => c.Id));
var patch = new GuildPatch {Blacklist = blacklist.ToArray()}; var patch = new GuildPatch {Blacklist = blacklist.ToArray()};
await conn.UpsertGuild(ctx.Guild.Id, patch); await _repo.UpsertGuild(conn, ctx.Guild.Id, patch);
} }
await ctx.Reply($"{Emojis.Success} Channels {(shouldAdd ? "added to" : "removed from")} the proxy blacklist."); await ctx.Reply($"{Emojis.Success} Channels {(shouldAdd ? "added to" : "removed from")} the proxy blacklist.");
@ -170,7 +172,7 @@ namespace PluralKit.Bot
.WithTitle("Log cleanup settings") .WithTitle("Log cleanup settings")
.AddField("Supported bots", botList); .AddField("Supported bots", botList);
var guildCfg = await _db.Execute(c => c.QueryOrInsertGuildConfig(ctx.Guild.Id)); var guildCfg = await _db.Execute(c => _repo.GetGuild(c, ctx.Guild.Id));
if (guildCfg.LogCleanupEnabled) if (guildCfg.LogCleanupEnabled)
eb.WithDescription("Log cleanup is currently **on** for this server. To disable it, type `pk;logclean off`."); eb.WithDescription("Log cleanup is currently **on** for this server. To disable it, type `pk;logclean off`.");
else else
@ -180,7 +182,7 @@ namespace PluralKit.Bot
} }
var patch = new GuildPatch {LogCleanupEnabled = newValue}; var patch = new GuildPatch {LogCleanupEnabled = newValue};
await _db.Execute(conn => conn.UpsertGuild(ctx.Guild.Id, patch)); await _db.Execute(conn => _repo.UpsertGuild(conn, ctx.Guild.Id, patch));
if (newValue) if (newValue)
await ctx.Reply($"{Emojis.Success} Log cleanup has been **enabled** for this server. Messages deleted by PluralKit will now be cleaned up from logging channels managed by the following bots:\n- **{botList}**\n\n{Emojis.Note} Make sure PluralKit has the **Manage Messages** permission in the channels in question.\n{Emojis.Note} Also, make sure to blacklist the logging channel itself from the bots in question to prevent conflicts."); await ctx.Reply($"{Emojis.Success} Log cleanup has been **enabled** for this server. Messages deleted by PluralKit will now be cleaned up from logging channels managed by the following bots:\n- **{botList}**\n\n{Emojis.Note} Make sure PluralKit has the **Manage Messages** permission in the channels in question.\n{Emojis.Note} Also, make sure to blacklist the logging channel itself from the bots in question to prevent conflicts.");

View File

@ -13,11 +13,13 @@ namespace PluralKit.Bot
{ {
public class Switch public class Switch
{ {
private IDataStore _data; private readonly IDatabase _db;
private readonly ModelRepository _repo;
public Switch(IDataStore data) public Switch(IDatabase db, ModelRepository repo)
{ {
_data = data; _db = db;
_repo = repo;
} }
public async Task SwitchDo(Context ctx) public async Task SwitchDo(Context ctx)
@ -42,16 +44,17 @@ namespace PluralKit.Bot
if (members.Select(m => m.Id).Distinct().Count() != members.Count) throw Errors.DuplicateSwitchMembers; if (members.Select(m => m.Id).Distinct().Count() != members.Count) throw Errors.DuplicateSwitchMembers;
// Find the last switch and its members if applicable // Find the last switch and its members if applicable
var lastSwitch = await _data.GetLatestSwitch(ctx.System.Id); await using var conn = await _db.Obtain();
var lastSwitch = await _repo.GetLatestSwitch(conn, ctx.System.Id);
if (lastSwitch != null) if (lastSwitch != null)
{ {
var lastSwitchMembers = _data.GetSwitchMembers(lastSwitch); var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastSwitch.Id);
// Make sure the requested switch isn't identical to the last one // Make sure the requested switch isn't identical to the last one
if (await lastSwitchMembers.Select(m => m.Id).SequenceEqualAsync(members.Select(m => m.Id).ToAsyncEnumerable())) if (await lastSwitchMembers.Select(m => m.Id).SequenceEqualAsync(members.Select(m => m.Id).ToAsyncEnumerable()))
throw Errors.SameSwitch(members, ctx.LookupContextFor(ctx.System)); throw Errors.SameSwitch(members, ctx.LookupContextFor(ctx.System));
} }
await _data.AddSwitch(ctx.System.Id, members); await _repo.AddSwitch(conn, ctx.System.Id, members.Select(m => m.Id).ToList());
if (members.Count == 0) if (members.Count == 0)
await ctx.Reply($"{Emojis.Success} Switch-out registered."); await ctx.Reply($"{Emojis.Success} Switch-out registered.");
@ -69,11 +72,13 @@ namespace PluralKit.Bot
var result = DateUtils.ParseDateTime(timeToMove, true, tz); var result = DateUtils.ParseDateTime(timeToMove, true, tz);
if (result == null) throw Errors.InvalidDateTime(timeToMove); if (result == null) throw Errors.InvalidDateTime(timeToMove);
await using var conn = await _db.Obtain();
var time = result.Value; var time = result.Value;
if (time.ToInstant() > SystemClock.Instance.GetCurrentInstant()) throw Errors.SwitchTimeInFuture; if (time.ToInstant() > SystemClock.Instance.GetCurrentInstant()) throw Errors.SwitchTimeInFuture;
// Fetch the last two switches for the system to do bounds checking on // Fetch the last two switches for the system to do bounds checking on
var lastTwoSwitches = await _data.GetSwitches(ctx.System.Id).Take(2).ToListAsync(); var lastTwoSwitches = await _repo.GetSwitches(conn, ctx.System.Id).Take(2).ToListAsync();
// If we don't have a switch to move, don't bother // If we don't have a switch to move, don't bother
if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches; if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches;
@ -87,7 +92,7 @@ namespace PluralKit.Bot
// Now we can actually do the move, yay! // Now we can actually do the move, yay!
// But, we do a prompt to confirm. // But, we do a prompt to confirm.
var lastSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[0]); var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[0].Id);
var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync()); var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync());
var lastSwitchTimeStr = lastTwoSwitches[0].Timestamp.FormatZoned(ctx.System); var lastSwitchTimeStr = lastTwoSwitches[0].Timestamp.FormatZoned(ctx.System);
var lastSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp).FormatDuration(); var lastSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp).FormatDuration();
@ -99,7 +104,7 @@ namespace PluralKit.Bot
if (!await ctx.PromptYesNo(msg)) throw Errors.SwitchMoveCancelled; if (!await ctx.PromptYesNo(msg)) throw Errors.SwitchMoveCancelled;
// aaaand *now* we do the move // aaaand *now* we do the move
await _data.MoveSwitch(lastTwoSwitches[0], time.ToInstant()); await _repo.MoveSwitch(conn, lastTwoSwitches[0].Id, time.ToInstant());
await ctx.Reply($"{Emojis.Success} Switch moved."); await ctx.Reply($"{Emojis.Success} Switch moved.");
} }
@ -113,16 +118,18 @@ namespace PluralKit.Bot
var purgeMsg = $"{Emojis.Warn} This will delete *all registered switches* in your system. Are you sure you want to proceed?"; var purgeMsg = $"{Emojis.Warn} This will delete *all registered switches* in your system. Are you sure you want to proceed?";
if (!await ctx.PromptYesNo(purgeMsg)) if (!await ctx.PromptYesNo(purgeMsg))
throw Errors.GenericCancelled(); throw Errors.GenericCancelled();
await _data.DeleteAllSwitches(ctx.System); await _db.Execute(c => _repo.DeleteAllSwitches(c, ctx.System.Id));
await ctx.Reply($"{Emojis.Success} Cleared system switches!"); await ctx.Reply($"{Emojis.Success} Cleared system switches!");
return; return;
} }
await using var conn = await _db.Obtain();
// Fetch the last two switches for the system to do bounds checking on // Fetch the last two switches for the system to do bounds checking on
var lastTwoSwitches = await _data.GetSwitches(ctx.System.Id).Take(2).ToListAsync(); var lastTwoSwitches = await _repo.GetSwitches(conn, ctx.System.Id).Take(2).ToListAsync();
if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches; if (lastTwoSwitches.Count == 0) throw Errors.NoRegisteredSwitches;
var lastSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[0]); var lastSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[0].Id);
var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync()); var lastSwitchMemberStr = string.Join(", ", await lastSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync());
var lastSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp).FormatDuration(); var lastSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[0].Timestamp).FormatDuration();
@ -133,14 +140,14 @@ namespace PluralKit.Bot
} }
else else
{ {
var secondSwitchMembers = _data.GetSwitchMembers(lastTwoSwitches[1]); var secondSwitchMembers = _repo.GetSwitchMembers(conn, lastTwoSwitches[1].Id);
var secondSwitchMemberStr = string.Join(", ", await secondSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync()); var secondSwitchMemberStr = string.Join(", ", await secondSwitchMembers.Select(m => m.NameFor(ctx)).ToListAsync());
var secondSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[1].Timestamp).FormatDuration(); var secondSwitchDeltaStr = (SystemClock.Instance.GetCurrentInstant() - lastTwoSwitches[1].Timestamp).FormatDuration();
msg = $"{Emojis.Warn} This will delete the latest switch ({lastSwitchMemberStr}, {lastSwitchDeltaStr} ago). The next latest switch is {secondSwitchMemberStr} ({secondSwitchDeltaStr} ago). Is this okay?"; msg = $"{Emojis.Warn} This will delete the latest switch ({lastSwitchMemberStr}, {lastSwitchDeltaStr} ago). The next latest switch is {secondSwitchMemberStr} ({secondSwitchDeltaStr} ago). Is this okay?";
} }
if (!await ctx.PromptYesNo(msg)) throw Errors.SwitchDeleteCancelled; if (!await ctx.PromptYesNo(msg)) throw Errors.SwitchDeleteCancelled;
await _data.DeleteSwitch(lastTwoSwitches[0]); await _repo.DeleteSwitch(conn, lastTwoSwitches[0].Id);
await ctx.Reply($"{Emojis.Success} Switch deleted."); await ctx.Reply($"{Emojis.Success} Switch deleted.");
} }

View File

@ -6,13 +6,15 @@ namespace PluralKit.Bot
{ {
public class System public class System
{ {
private IDataStore _data; private readonly EmbedService _embeds;
private EmbedService _embeds; private readonly IDatabase _db;
private readonly ModelRepository _repo;
public System(EmbedService embeds, IDataStore data) public System(EmbedService embeds, IDatabase db, ModelRepository repo)
{ {
_embeds = embeds; _embeds = embeds;
_data = data; _db = db;
_repo = repo;
} }
public async Task Query(Context ctx, PKSystem system) { public async Task Query(Context ctx, PKSystem system) {
@ -29,8 +31,14 @@ namespace PluralKit.Bot
if (systemName != null && systemName.Length > Limits.MaxSystemNameLength) if (systemName != null && systemName.Length > Limits.MaxSystemNameLength)
throw Errors.SystemNameTooLongError(systemName.Length); throw Errors.SystemNameTooLongError(systemName.Length);
var system = await _data.CreateSystem(systemName); var system = _db.Execute(async c =>
await _data.AddAccount(system, ctx.Author.Id); {
var system = await _repo.CreateSystem(c, systemName);
await _repo.AddAccount(c, system.Id, ctx.Author.Id);
return system;
});
// TODO: better message, perhaps embed like in groups?
await ctx.Reply($"{Emojis.Success} Your system has been created. Type `pk;system` to view it, and type `pk;system help` for more information about commands you can use now. Now that you have that set up, check out the getting started guide on setting up members and proxies: <https://pluralkit.me/start>"); await ctx.Reply($"{Emojis.Success} Your system has been created. Type `pk;system` to view it, and type `pk;system help` for more information about commands you can use now. Now that you have that set up, check out the getting started guide on setting up members and proxies: <https://pluralkit.me/start>");
} }
} }

View File

@ -19,15 +19,13 @@ namespace PluralKit.Bot
{ {
public class SystemEdit public class SystemEdit
{ {
private IDataStore _data; private readonly IDatabase _db;
private IDatabase _db; private readonly ModelRepository _repo;
private EmbedService _embeds;
public SystemEdit(IDataStore data, EmbedService embeds, IDatabase db) public SystemEdit(IDatabase db, ModelRepository repo)
{ {
_data = data;
_embeds = embeds;
_db = db; _db = db;
_repo = repo;
} }
public async Task Name(Context ctx) public async Task Name(Context ctx)
@ -37,7 +35,7 @@ namespace PluralKit.Bot
if (ctx.MatchClear()) if (ctx.MatchClear())
{ {
var clearPatch = new SystemPatch {Name = null}; var clearPatch = new SystemPatch {Name = null};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, clearPatch)); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, clearPatch));
await ctx.Reply($"{Emojis.Success} System name cleared."); await ctx.Reply($"{Emojis.Success} System name cleared.");
return; return;
@ -57,7 +55,7 @@ namespace PluralKit.Bot
throw Errors.SystemNameTooLongError(newSystemName.Length); throw Errors.SystemNameTooLongError(newSystemName.Length);
var patch = new SystemPatch {Name = newSystemName}; var patch = new SystemPatch {Name = newSystemName};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"{Emojis.Success} System name changed."); await ctx.Reply($"{Emojis.Success} System name changed.");
} }
@ -68,7 +66,7 @@ namespace PluralKit.Bot
if (ctx.MatchClear()) if (ctx.MatchClear())
{ {
var patch = new SystemPatch {Description = null}; var patch = new SystemPatch {Description = null};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"{Emojis.Success} System description cleared."); await ctx.Reply($"{Emojis.Success} System description cleared.");
return; return;
@ -93,7 +91,7 @@ namespace PluralKit.Bot
if (newDescription.Length > Limits.MaxDescriptionLength) throw Errors.DescriptionTooLongError(newDescription.Length); if (newDescription.Length > Limits.MaxDescriptionLength) throw Errors.DescriptionTooLongError(newDescription.Length);
var patch = new SystemPatch {Description = newDescription}; var patch = new SystemPatch {Description = newDescription};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"{Emojis.Success} System description changed."); await ctx.Reply($"{Emojis.Success} System description changed.");
} }
@ -106,7 +104,7 @@ namespace PluralKit.Bot
if (ctx.MatchClear()) if (ctx.MatchClear())
{ {
var patch = new SystemPatch {Tag = null}; var patch = new SystemPatch {Tag = null};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"{Emojis.Success} System tag cleared."); await ctx.Reply($"{Emojis.Success} System tag cleared.");
} else if (!ctx.HasNext(skipFlags: false)) } else if (!ctx.HasNext(skipFlags: false))
@ -124,7 +122,7 @@ namespace PluralKit.Bot
throw Errors.SystemNameTooLongError(newTag.Length); throw Errors.SystemNameTooLongError(newTag.Length);
var patch = new SystemPatch {Tag = newTag}; var patch = new SystemPatch {Tag = newTag};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"{Emojis.Success} System tag changed. Member names will now end with {newTag.AsCode()} when proxied."); await ctx.Reply($"{Emojis.Success} System tag changed. Member names will now end with {newTag.AsCode()} when proxied.");
} }
@ -136,7 +134,7 @@ namespace PluralKit.Bot
async Task ClearIcon() async Task ClearIcon()
{ {
await _db.Execute(c => c.UpdateSystem(ctx.System.Id, new SystemPatch {AvatarUrl = null})); await _db.Execute(c => _repo.UpdateSystem(c, ctx.System.Id, new SystemPatch {AvatarUrl = null}));
await ctx.Reply($"{Emojis.Success} System icon cleared."); await ctx.Reply($"{Emojis.Success} System icon cleared.");
} }
@ -146,7 +144,7 @@ namespace PluralKit.Bot
throw Errors.InvalidUrl(img.Url); throw Errors.InvalidUrl(img.Url);
await AvatarUtils.VerifyAvatarOrThrow(img.Url); await AvatarUtils.VerifyAvatarOrThrow(img.Url);
await _db.Execute(c => c.UpdateSystem(ctx.System.Id, new SystemPatch {AvatarUrl = img.Url})); await _db.Execute(c => _repo.UpdateSystem(c, ctx.System.Id, new SystemPatch {AvatarUrl = img.Url}));
var msg = img.Source switch var msg = img.Source switch
{ {
@ -192,7 +190,7 @@ namespace PluralKit.Bot
if (!await ctx.ConfirmWithReply(ctx.System.Hid)) if (!await ctx.ConfirmWithReply(ctx.System.Hid))
throw new PKError($"System deletion cancelled. Note that you must reply with your system ID (`{ctx.System.Hid}`) *verbatim*."); throw new PKError($"System deletion cancelled. Note that you must reply with your system ID (`{ctx.System.Hid}`) *verbatim*.");
await _db.Execute(conn => conn.DeleteSystem(ctx.System.Id)); await _db.Execute(conn => _repo.DeleteSystem(conn, ctx.System.Id));
await ctx.Reply($"{Emojis.Success} System deleted."); await ctx.Reply($"{Emojis.Success} System deleted.");
} }
@ -200,7 +198,7 @@ namespace PluralKit.Bot
public async Task SystemProxy(Context ctx) public async Task SystemProxy(Context ctx)
{ {
ctx.CheckSystem().CheckGuildContext(); ctx.CheckSystem().CheckGuildContext();
var gs = await _db.Execute(c => c.QueryOrInsertSystemGuildConfig(ctx.Guild.Id, ctx.System.Id)); var gs = await _db.Execute(c => _repo.GetSystemGuild(c, ctx.Guild.Id, ctx.System.Id));
bool newValue; bool newValue;
if (ctx.Match("on", "enabled", "true", "yes")) newValue = true; if (ctx.Match("on", "enabled", "true", "yes")) newValue = true;
@ -216,7 +214,7 @@ namespace PluralKit.Bot
} }
var patch = new SystemGuildPatch {ProxyEnabled = newValue}; var patch = new SystemGuildPatch {ProxyEnabled = newValue};
await _db.Execute(conn => conn.UpsertSystemGuild(ctx.System.Id, ctx.Guild.Id, patch)); await _db.Execute(conn => _repo.UpsertSystemGuild(conn, ctx.System.Id, ctx.Guild.Id, patch));
if (newValue) if (newValue)
await ctx.Reply($"Message proxying in this server ({ctx.Guild.Name.EscapeMarkdown()}) is now **enabled** for your system."); await ctx.Reply($"Message proxying in this server ({ctx.Guild.Name.EscapeMarkdown()}) is now **enabled** for your system.");
@ -231,7 +229,7 @@ namespace PluralKit.Bot
if (ctx.MatchClear()) if (ctx.MatchClear())
{ {
var clearPatch = new SystemPatch {UiTz = "UTC"}; var clearPatch = new SystemPatch {UiTz = "UTC"};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, clearPatch)); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, clearPatch));
await ctx.Reply($"{Emojis.Success} System time zone cleared (set to UTC)."); await ctx.Reply($"{Emojis.Success} System time zone cleared (set to UTC).");
return; return;
@ -253,7 +251,7 @@ namespace PluralKit.Bot
if (!await ctx.PromptYesNo(msg)) throw Errors.TimezoneChangeCancelled; if (!await ctx.PromptYesNo(msg)) throw Errors.TimezoneChangeCancelled;
var patch = new SystemPatch {UiTz = zone.Id}; var patch = new SystemPatch {UiTz = zone.Id};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply($"System time zone changed to **{zone.Id}**."); await ctx.Reply($"System time zone changed to **{zone.Id}**.");
} }
@ -277,7 +275,7 @@ namespace PluralKit.Bot
async Task SetLevel(SystemPrivacySubject subject, PrivacyLevel level) async Task SetLevel(SystemPrivacySubject subject, PrivacyLevel level)
{ {
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, new SystemPatch().WithPrivacy(subject, level))); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, new SystemPatch().WithPrivacy(subject, level)));
var levelExplanation = level switch var levelExplanation = level switch
{ {
@ -302,7 +300,7 @@ namespace PluralKit.Bot
async Task SetAll(PrivacyLevel level) async Task SetAll(PrivacyLevel level)
{ {
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, new SystemPatch().WithAllPrivacy(level))); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, new SystemPatch().WithAllPrivacy(level)));
var msg = level switch var msg = level switch
{ {
@ -334,13 +332,13 @@ namespace PluralKit.Bot
else { else {
if (ctx.Match("on", "enable")) { if (ctx.Match("on", "enable")) {
var patch = new SystemPatch {PingsEnabled = true}; var patch = new SystemPatch {PingsEnabled = true};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply("Reaction pings have now been enabled."); await ctx.Reply("Reaction pings have now been enabled.");
} }
if (ctx.Match("off", "disable")) { if (ctx.Match("off", "disable")) {
var patch = new SystemPatch {PingsEnabled = false}; var patch = new SystemPatch {PingsEnabled = false};
await _db.Execute(conn => conn.UpdateSystem(ctx.System.Id, patch)); await _db.Execute(conn => _repo.UpdateSystem(conn, ctx.System.Id, patch));
await ctx.Reply("Reaction pings have now been disabled."); await ctx.Reply("Reaction pings have now been disabled.");
} }

View File

@ -1,4 +1,5 @@
using System; using System;
using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -10,19 +11,21 @@ namespace PluralKit.Bot
{ {
public class SystemFront public class SystemFront
{ {
private IDataStore _data; private readonly IDatabase _db;
private EmbedService _embeds; private readonly ModelRepository _repo;
private readonly EmbedService _embeds;
public SystemFront(IDataStore data, EmbedService embeds) public SystemFront(EmbedService embeds, IDatabase db, ModelRepository repo)
{ {
_data = data;
_embeds = embeds; _embeds = embeds;
_db = db;
_repo = repo;
} }
struct FrontHistoryEntry struct FrontHistoryEntry
{ {
public Instant? LastTime; public readonly Instant? LastTime;
public PKSwitch ThisSwitch; public readonly PKSwitch ThisSwitch;
public FrontHistoryEntry(Instant? lastTime, PKSwitch thisSwitch) public FrontHistoryEntry(Instant? lastTime, PKSwitch thisSwitch)
{ {
@ -36,7 +39,9 @@ namespace PluralKit.Bot
if (system == null) throw Errors.NoSystemError; if (system == null) throw Errors.NoSystemError;
ctx.CheckSystemPrivacy(system, system.FrontPrivacy); ctx.CheckSystemPrivacy(system, system.FrontPrivacy);
var sw = await _data.GetLatestSwitch(system.Id); await using var conn = await _db.Obtain();
var sw = await _repo.GetLatestSwitch(conn, system.Id);
if (sw == null) throw Errors.NoRegisteredSwitches; if (sw == null) throw Errors.NoRegisteredSwitches;
await ctx.Reply(embed: await _embeds.CreateFronterEmbed(sw, system.Zone, ctx.LookupContextFor(system))); await ctx.Reply(embed: await _embeds.CreateFronterEmbed(sw, system.Zone, ctx.LookupContextFor(system)));
@ -47,11 +52,16 @@ namespace PluralKit.Bot
if (system == null) throw Errors.NoSystemError; if (system == null) throw Errors.NoSystemError;
ctx.CheckSystemPrivacy(system, system.FrontHistoryPrivacy); ctx.CheckSystemPrivacy(system, system.FrontHistoryPrivacy);
var sws = _data.GetSwitches(system.Id) // Gotta be careful here: if we dispose of the connection while the IAE is alive, boom
.Scan(new FrontHistoryEntry(null, null), (lastEntry, newSwitch) => new FrontHistoryEntry(lastEntry.ThisSwitch?.Timestamp, newSwitch)); await using var conn = await _db.Obtain();
var totalSwitches = await _data.GetSwitchCount(system);
var totalSwitches = await _repo.GetSwitchCount(conn, system.Id);
if (totalSwitches == 0) throw Errors.NoRegisteredSwitches; if (totalSwitches == 0) throw Errors.NoRegisteredSwitches;
var sws = _repo.GetSwitches(conn, system.Id)
.Scan(new FrontHistoryEntry(null, null),
(lastEntry, newSwitch) => new FrontHistoryEntry(lastEntry.ThisSwitch?.Timestamp, newSwitch));
var embedTitle = system.Name != null ? $"Front history of {system.Name} (`{system.Hid}`)" : $"Front history of `{system.Hid}`"; var embedTitle = system.Name != null ? $"Front history of {system.Name} (`{system.Hid}`)" : $"Front history of `{system.Hid}`";
await ctx.Paginate( await ctx.Paginate(
@ -66,8 +76,11 @@ namespace PluralKit.Bot
var lastSw = entry.LastTime; var lastSw = entry.LastTime;
var sw = entry.ThisSwitch; var sw = entry.ThisSwitch;
// Fetch member list and format // Fetch member list and format
var members = await _data.GetSwitchMembers(sw).ToListAsync(); await using var conn = await _db.Obtain();
var members = await _db.Execute(c => _repo.GetSwitchMembers(c, sw.Id)).ToListAsync();
var membersStr = members.Any() ? string.Join(", ", members.Select(m => m.NameFor(ctx))) : "no fronter"; var membersStr = members.Any() ? string.Join(", ", members.Select(m => m.NameFor(ctx))) : "no fronter";
var switchSince = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp; var switchSince = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp;
@ -112,7 +125,7 @@ namespace PluralKit.Bot
if (rangeStart == null) throw Errors.InvalidDateTime(durationStr); if (rangeStart == null) throw Errors.InvalidDateTime(durationStr);
if (rangeStart.Value.ToInstant() > now) throw Errors.FrontPercentTimeInFuture; if (rangeStart.Value.ToInstant() > now) throw Errors.FrontPercentTimeInFuture;
var frontpercent = await _data.GetFrontBreakdown(system, rangeStart.Value.ToInstant(), now); var frontpercent = await _db.Execute(c => _repo.GetFrontBreakdown(c, system.Id, rangeStart.Value.ToInstant(), now));
await ctx.Reply(embed: await _embeds.CreateFrontPercentEmbed(frontpercent, system.Zone, ctx.LookupContextFor(system))); await ctx.Reply(embed: await _embeds.CreateFrontPercentEmbed(frontpercent, system.Zone, ctx.LookupContextFor(system)));
} }
} }

View File

@ -9,28 +9,34 @@ namespace PluralKit.Bot
{ {
public class SystemLink public class SystemLink
{ {
private IDataStore _data; private readonly IDatabase _db;
private readonly ModelRepository _repo;
public SystemLink(IDataStore data) public SystemLink(IDatabase db, ModelRepository repo)
{ {
_data = data; _db = db;
_repo = repo;
} }
public async Task LinkSystem(Context ctx) public async Task LinkSystem(Context ctx)
{ {
ctx.CheckSystem(); ctx.CheckSystem();
var account = await ctx.MatchUser() ?? throw new PKSyntaxError("You must pass an account to link with (either ID or @mention)."); await using var conn = await _db.Obtain();
var accountIds = await _data.GetSystemAccounts(ctx.System);
if (accountIds.Contains(account.Id)) throw Errors.AccountAlreadyLinked;
var existingAccount = await _data.GetSystemByAccount(account.Id); var account = await ctx.MatchUser() ?? throw new PKSyntaxError("You must pass an account to link with (either ID or @mention).");
if (existingAccount != null) throw Errors.AccountInOtherSystem(existingAccount); var accountIds = await _repo.GetSystemAccounts(conn, ctx.System.Id);
if (accountIds.Contains(account.Id))
throw Errors.AccountAlreadyLinked;
var existingAccount = await _repo.GetSystemByAccount(conn, account.Id);
if (existingAccount != null)
throw Errors.AccountInOtherSystem(existingAccount);
var msg = $"{account.Mention}, please confirm the link by clicking the {Emojis.Success} reaction on this message."; var msg = $"{account.Mention}, please confirm the link by clicking the {Emojis.Success} reaction on this message.";
var mentions = new IMention[] { new UserMention(account) }; var mentions = new IMention[] { new UserMention(account) };
if (!await ctx.PromptYesNo(msg, user: account, mentions: mentions)) throw Errors.MemberLinkCancelled; if (!await ctx.PromptYesNo(msg, user: account, mentions: mentions)) throw Errors.MemberLinkCancelled;
await _data.AddAccount(ctx.System, account.Id); await _repo.AddAccount(conn, ctx.System.Id, account.Id);
await ctx.Reply($"{Emojis.Success} Account linked to system."); await ctx.Reply($"{Emojis.Success} Account linked to system.");
} }
@ -38,20 +44,22 @@ namespace PluralKit.Bot
{ {
ctx.CheckSystem(); ctx.CheckSystem();
await using var conn = await _db.Obtain();
ulong id; ulong id;
if (!ctx.HasNext()) if (!ctx.HasNext())
id = ctx.Author.Id; id = ctx.Author.Id;
else if (!ctx.MatchUserRaw(out id)) else if (!ctx.MatchUserRaw(out id))
throw new PKSyntaxError("You must pass an account to link with (either ID or @mention)."); throw new PKSyntaxError("You must pass an account to link with (either ID or @mention).");
var accountIds = (await _data.GetSystemAccounts(ctx.System)).ToList(); var accountIds = (await _repo.GetSystemAccounts(conn, ctx.System.Id)).ToList();
if (!accountIds.Contains(id)) throw Errors.AccountNotLinked; if (!accountIds.Contains(id)) throw Errors.AccountNotLinked;
if (accountIds.Count == 1) throw Errors.UnlinkingLastAccount; if (accountIds.Count == 1) throw Errors.UnlinkingLastAccount;
var msg = $"Are you sure you want to unlink <@{id}> from your system?"; var msg = $"Are you sure you want to unlink <@{id}> from your system?";
if (!await ctx.PromptYesNo(msg)) throw Errors.MemberUnlinkCancelled; if (!await ctx.PromptYesNo(msg)) throw Errors.MemberUnlinkCancelled;
await _data.RemoveAccount(ctx.System, id); await _repo.RemoveAccount(conn, ctx.System.Id, id);
await ctx.Reply($"{Emojis.Success} Account unlinked."); await ctx.Reply($"{Emojis.Success} Account unlinked.");
} }
} }

View File

@ -10,9 +10,11 @@ namespace PluralKit.Bot
public class Token public class Token
{ {
private readonly IDatabase _db; private readonly IDatabase _db;
public Token(IDatabase db) private readonly ModelRepository _repo;
public Token(IDatabase db, ModelRepository repo)
{ {
_db = db; _db = db;
_repo = repo;
} }
public async Task GetToken(Context ctx) public async Task GetToken(Context ctx)
@ -45,7 +47,7 @@ namespace PluralKit.Bot
private async Task<string> MakeAndSetNewToken(PKSystem system) private async Task<string> MakeAndSetNewToken(PKSystem system)
{ {
var patch = new SystemPatch {Token = StringUtils.GenerateToken()}; var patch = new SystemPatch {Token = StringUtils.GenerateToken()};
system = await _db.Execute(conn => conn.UpdateSystem(system.Id, patch)); system = await _db.Execute(conn => _repo.UpdateSystem(conn, system.Id, patch));
return system.Token; return system.Token;
} }

View File

@ -23,11 +23,12 @@ namespace PluralKit.Bot
private readonly ProxyService _proxy; private readonly ProxyService _proxy;
private readonly ILifetimeScope _services; private readonly ILifetimeScope _services;
private readonly IDatabase _db; private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly BotConfig _config; private readonly BotConfig _config;
public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean, public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean,
IMetrics metrics, ProxyService proxy, DiscordShardedClient client, IMetrics metrics, ProxyService proxy, DiscordShardedClient client,
CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config) CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config, ModelRepository repo)
{ {
_lastMessageCache = lastMessageCache; _lastMessageCache = lastMessageCache;
_loggerClean = loggerClean; _loggerClean = loggerClean;
@ -38,6 +39,7 @@ namespace PluralKit.Bot
_services = services; _services = services;
_db = db; _db = db;
_config = config; _config = config;
_repo = repo;
} }
public DiscordChannel ErrorChannelFor(MessageCreateEventArgs evt) => evt.Channel; public DiscordChannel ErrorChannelFor(MessageCreateEventArgs evt) => evt.Channel;
@ -59,7 +61,7 @@ namespace PluralKit.Bot
MessageContext ctx; MessageContext ctx;
await using (var conn = await _db.Obtain()) await using (var conn = await _db.Obtain())
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await conn.QueryMessageContext(evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id); ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id);
// Try each handler until we find one that succeeds // Try each handler until we find one that succeeds
if (await TryHandleLogClean(evt, ctx)) if (await TryHandleLogClean(evt, ctx))
@ -98,7 +100,7 @@ namespace PluralKit.Bot
try try
{ {
var system = ctx.SystemId != null ? await _db.Execute(c => c.QuerySystem(ctx.SystemId.Value)) : null; var system = ctx.SystemId != null ? await _db.Execute(c => _repo.GetSystem(c, ctx.SystemId.Value)) : null;
await _tree.ExecuteCommand(new Context(_services, evt.Client, evt.Message, cmdStart, system, ctx)); await _tree.ExecuteCommand(new Context(_services, evt.Client, evt.Message, cmdStart, system, ctx));
} }
catch (PKError) catch (PKError)

View File

@ -12,28 +12,29 @@ namespace PluralKit.Bot
// Double duty :) // Double duty :)
public class MessageDeleted: IEventHandler<MessageDeleteEventArgs>, IEventHandler<MessageBulkDeleteEventArgs> public class MessageDeleted: IEventHandler<MessageDeleteEventArgs>, IEventHandler<MessageBulkDeleteEventArgs>
{ {
private readonly IDataStore _data; private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly ILogger _logger; private readonly ILogger _logger;
public MessageDeleted(IDataStore data, ILogger logger) public MessageDeleted(ILogger logger, IDatabase db, ModelRepository repo)
{ {
_data = data; _db = db;
_repo = repo;
_logger = logger.ForContext<MessageDeleted>(); _logger = logger.ForContext<MessageDeleted>();
} }
public async Task Handle(MessageDeleteEventArgs evt) public async Task Handle(MessageDeleteEventArgs evt)
{ {
// Delete deleted webhook messages from the data store // Delete deleted webhook messages from the data store
// (if we don't know whether it's a webhook, delete it just to be safe) // Most of the data in the given message is wrong/missing, so always delete just to be sure.
if (!evt.Message.WebhookMessage) return; await _db.Execute(c => _repo.DeleteMessage(c, evt.Message.Id));
await _data.DeleteMessage(evt.Message.Id);
} }
public async Task Handle(MessageBulkDeleteEventArgs evt) public async Task Handle(MessageBulkDeleteEventArgs evt)
{ {
// Same as above, but bulk // Same as above, but bulk
_logger.Information("Bulk deleting {Count} messages in channel {Channel}", evt.Messages.Count, evt.Channel.Id); _logger.Information("Bulk deleting {Count} messages in channel {Channel}", evt.Messages.Count, evt.Channel.Id);
await _data.DeleteMessagesBulk(evt.Messages.Select(m => m.Id).ToList()); await _db.Execute(c => _repo.DeleteMessagesBulk(c, evt.Messages.Select(m => m.Id).ToList()));
} }
} }
} }

View File

@ -14,14 +14,16 @@ namespace PluralKit.Bot
private readonly LastMessageCacheService _lastMessageCache; private readonly LastMessageCacheService _lastMessageCache;
private readonly ProxyService _proxy; private readonly ProxyService _proxy;
private readonly IDatabase _db; private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly IMetrics _metrics; private readonly IMetrics _metrics;
public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, IDatabase db, IMetrics metrics) public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, IDatabase db, IMetrics metrics, ModelRepository repo)
{ {
_lastMessageCache = lastMessageCache; _lastMessageCache = lastMessageCache;
_proxy = proxy; _proxy = proxy;
_db = db; _db = db;
_metrics = metrics; _metrics = metrics;
_repo = repo;
} }
public async Task Handle(MessageUpdateEventArgs evt) public async Task Handle(MessageUpdateEventArgs evt)
@ -36,7 +38,7 @@ namespace PluralKit.Bot
MessageContext ctx; MessageContext ctx;
await using (var conn = await _db.Obtain()) await using (var conn = await _db.Obtain())
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await conn.QueryMessageContext(evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id); ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id);
await _proxy.HandleIncomingMessage(evt.Message, ctx, allowAutoproxy: false); await _proxy.HandleIncomingMessage(evt.Message, ctx, allowAutoproxy: false);
} }
} }

View File

@ -13,14 +13,16 @@ namespace PluralKit.Bot
{ {
public class ReactionAdded: IEventHandler<MessageReactionAddEventArgs> public class ReactionAdded: IEventHandler<MessageReactionAddEventArgs>
{ {
private IDataStore _data; private readonly IDatabase _db;
private EmbedService _embeds; private readonly ModelRepository _repo;
private ILogger _logger; private readonly EmbedService _embeds;
private readonly ILogger _logger;
public ReactionAdded(IDataStore data, EmbedService embeds, ILogger logger) public ReactionAdded(EmbedService embeds, ILogger logger, IDatabase db, ModelRepository repo)
{ {
_data = data;
_embeds = embeds; _embeds = embeds;
_db = db;
_repo = repo;
_logger = logger.ForContext<ReactionAdded>(); _logger = logger.ForContext<ReactionAdded>();
} }
@ -42,18 +44,21 @@ namespace PluralKit.Bot
// Ignore reactions from bots (we can't DM them anyway) // Ignore reactions from bots (we can't DM them anyway)
if (evt.User.IsBot) return; if (evt.User.IsBot) return;
Task<FullMessage> GetMessage() =>
_db.Execute(c => _repo.GetMessage(c, evt.Message.Id));
FullMessage msg; FullMessage msg;
switch (evt.Emoji.Name) switch (evt.Emoji.Name)
{ {
// Message deletion // Message deletion
case "\u274C": // Red X case "\u274C": // Red X
if ((msg = await _data.GetMessage(evt.Message.Id)) != null) if ((msg = await GetMessage()) != null)
await HandleDeleteReaction(evt, msg); await HandleDeleteReaction(evt, msg);
break; break;
case "\u2753": // Red question mark case "\u2753": // Red question mark
case "\u2754": // White question mark case "\u2754": // White question mark
if ((msg = await _data.GetMessage(evt.Message.Id)) != null) if ((msg = await GetMessage()) != null)
await HandleQueryReaction(evt, msg); await HandleQueryReaction(evt, msg);
break; break;
@ -62,7 +67,7 @@ namespace PluralKit.Bot
case "\U0001F3D3": // Ping pong paddle (lol) case "\U0001F3D3": // Ping pong paddle (lol)
case "\u23F0": // Alarm clock case "\u23F0": // Alarm clock
case "\u2757": // Exclamation mark case "\u2757": // Exclamation mark
if ((msg = await _data.GetMessage(evt.Message.Id)) != null) if ((msg = await GetMessage()) != null)
await HandlePingReaction(evt, msg); await HandlePingReaction(evt, msg);
break; break;
} }
@ -84,7 +89,7 @@ namespace PluralKit.Bot
// Message was deleted by something/someone else before we got to it // Message was deleted by something/someone else before we got to it
} }
await _data.DeleteMessage(evt.Message.Id); await _db.Execute(c => _repo.DeleteMessage(c, evt.Message.Id));
} }
private async ValueTask HandleQueryReaction(MessageReactionAddEventArgs evt, FullMessage msg) private async ValueTask HandleQueryReaction(MessageReactionAddEventArgs evt, FullMessage msg)

View File

@ -21,21 +21,21 @@ namespace PluralKit.Bot
private readonly LogChannelService _logChannel; private readonly LogChannelService _logChannel;
private readonly IDatabase _db; private readonly IDatabase _db;
private readonly IDataStore _data; private readonly ModelRepository _repo;
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly WebhookExecutorService _webhookExecutor; private readonly WebhookExecutorService _webhookExecutor;
private readonly ProxyMatcher _matcher; private readonly ProxyMatcher _matcher;
private readonly IMetrics _metrics; private readonly IMetrics _metrics;
public ProxyService(LogChannelService logChannel, IDataStore data, ILogger logger, public ProxyService(LogChannelService logChannel, ILogger logger,
WebhookExecutorService webhookExecutor, IDatabase db, ProxyMatcher matcher, IMetrics metrics) WebhookExecutorService webhookExecutor, IDatabase db, ProxyMatcher matcher, IMetrics metrics, ModelRepository repo)
{ {
_logChannel = logChannel; _logChannel = logChannel;
_data = data;
_webhookExecutor = webhookExecutor; _webhookExecutor = webhookExecutor;
_db = db; _db = db;
_matcher = matcher; _matcher = matcher;
_metrics = metrics; _metrics = metrics;
_repo = repo;
_logger = logger.ForContext<ProxyService>(); _logger = logger.ForContext<ProxyService>();
} }
@ -48,7 +48,7 @@ namespace PluralKit.Bot
List<ProxyMember> members; List<ProxyMember> members;
using (_metrics.Measure.Timer.Time(BotMetrics.ProxyMembersQueryTime)) using (_metrics.Measure.Timer.Time(BotMetrics.ProxyMembersQueryTime))
members = (await conn.QueryProxyMembers(message.Author.Id, message.Channel.GuildId)).ToList(); members = (await _repo.GetProxyMembers(conn, message.Author.Id, message.Channel.GuildId)).ToList();
if (!_matcher.TryMatch(ctx, members, out var match, message.Content, message.Attachments.Count > 0, if (!_matcher.TryMatch(ctx, members, out var match, message.Content, message.Attachments.Count > 0,
allowAutoproxy)) return false; allowAutoproxy)) return false;
@ -100,7 +100,16 @@ namespace PluralKit.Bot
match.Member.ProxyAvatar(ctx), match.Member.ProxyAvatar(ctx),
content, trigger.Attachments, allowEveryone); content, trigger.Attachments, allowEveryone);
Task SaveMessage() => _data.AddMessage(conn, trigger.Author.Id, trigger.Channel.GuildId, trigger.ChannelId, id, trigger.Id, match.Member.Id); Task SaveMessage() => _repo.AddMessage(conn, new PKMessage
{
Channel = trigger.ChannelId,
Guild = trigger.Channel.GuildId,
Member = match.Member.Id,
Mid = id,
OriginalMid = trigger.Id,
Sender = trigger.Author.Id
});
Task LogMessage() => _logChannel.LogMessage(ctx, match, trigger, id).AsTask(); Task LogMessage() => _logChannel.LogMessage(ctx, match, trigger, id).AsTask();
async Task DeleteMessage() async Task DeleteMessage()
{ {

View File

@ -7,7 +7,7 @@ namespace PluralKit.Bot
{ {
public class CpuStatService public class CpuStatService
{ {
private ILogger _logger; private readonly ILogger _logger;
public double LastCpuMeasure { get; private set; } public double LastCpuMeasure { get; private set; }

View File

@ -15,38 +15,36 @@ using PluralKit.Core;
namespace PluralKit.Bot { namespace PluralKit.Bot {
public class EmbedService public class EmbedService
{ {
private IDataStore _data; private readonly IDatabase _db;
private IDatabase _db; private readonly ModelRepository _repo;
private DiscordShardedClient _client; private readonly DiscordShardedClient _client;
public EmbedService(DiscordShardedClient client, IDataStore data, IDatabase db) public EmbedService(DiscordShardedClient client, IDatabase db, ModelRepository repo)
{ {
_client = client; _client = client;
_data = data;
_db = db; _db = db;
_repo = repo;
} }
public async Task<DiscordEmbed> CreateSystemEmbed(DiscordClient client, PKSystem system, LookupContext ctx) public async Task<DiscordEmbed> CreateSystemEmbed(DiscordClient client, PKSystem system, LookupContext ctx)
{ {
await using var conn = await _db.Obtain(); await using var conn = await _db.Obtain();
// Fetch/render info for all accounts simultaneously // Fetch/render info for all accounts simultaneously
var accounts = await conn.GetLinkedAccounts(system.Id); var accounts = await _repo.GetSystemAccounts(conn, system.Id);
var users = await Task.WhenAll(accounts.Select(async uid => (await client.GetUser(uid))?.NameAndMention() ?? $"(deleted account {uid})")); var users = await Task.WhenAll(accounts.Select(async uid => (await client.GetUser(uid))?.NameAndMention() ?? $"(deleted account {uid})"));
var memberCount = await conn.GetSystemMemberCount(system.Id, PrivacyLevel.Public); var memberCount = await _repo.GetSystemMemberCount(conn, system.Id, PrivacyLevel.Public);
var eb = new DiscordEmbedBuilder() var eb = new DiscordEmbedBuilder()
.WithColor(DiscordUtils.Gray) .WithColor(DiscordUtils.Gray)
.WithTitle(system.Name ?? null) .WithTitle(system.Name ?? null)
.WithThumbnail(system.AvatarUrl) .WithThumbnail(system.AvatarUrl)
.WithFooter($"System ID: {system.Hid} | Created on {system.Created.FormatZoned(system)}"); .WithFooter($"System ID: {system.Hid} | Created on {system.Created.FormatZoned(system)}");
var latestSwitch = await _data.GetLatestSwitch(system.Id); var latestSwitch = await _repo.GetLatestSwitch(conn, system.Id);
if (latestSwitch != null && system.FrontPrivacy.CanAccess(ctx)) if (latestSwitch != null && system.FrontPrivacy.CanAccess(ctx))
{ {
var switchMembers = await _data.GetSwitchMembers(latestSwitch).ToListAsync(); var switchMembers = await _repo.GetSwitchMembers(conn, latestSwitch.Id).ToListAsync();
if (switchMembers.Count > 0) if (switchMembers.Count > 0)
eb.AddField("Fronter".ToQuantity(switchMembers.Count(), ShowQuantityAs.None), eb.AddField("Fronter".ToQuantity(switchMembers.Count(), ShowQuantityAs.None),
string.Join(", ", switchMembers.Select(m => m.NameFor(ctx)))); string.Join(", ", switchMembers.Select(m => m.NameFor(ctx))));
@ -105,11 +103,13 @@ namespace PluralKit.Bot {
await using var conn = await _db.Obtain(); await using var conn = await _db.Obtain();
var guildSettings = guild != null ? await conn.QueryOrInsertMemberGuildConfig(guild.Id, member.Id) : null; var guildSettings = guild != null ? await _repo.GetMemberGuild(conn, guild.Id, member.Id) : null;
var guildDisplayName = guildSettings?.DisplayName; var guildDisplayName = guildSettings?.DisplayName;
var avatar = guildSettings?.AvatarUrl ?? member.AvatarFor(ctx); var avatar = guildSettings?.AvatarUrl ?? member.AvatarFor(ctx);
var groups = (await conn.QueryMemberGroups(member.Id)).Where(g => g.Visibility.CanAccess(ctx)).ToList(); var groups = await _repo.GetMemberGroups(conn, member.Id)
.Where(g => g.Visibility.CanAccess(ctx))
.ToListAsync();
var eb = new DiscordEmbedBuilder() var eb = new DiscordEmbedBuilder()
// TODO: add URL of website when that's up // TODO: add URL of website when that's up
@ -157,7 +157,7 @@ namespace PluralKit.Bot {
public async Task<DiscordEmbed> CreateFronterEmbed(PKSwitch sw, DateTimeZone zone, LookupContext ctx) public async Task<DiscordEmbed> CreateFronterEmbed(PKSwitch sw, DateTimeZone zone, LookupContext ctx)
{ {
var members = await _data.GetSwitchMembers(sw).ToListAsync(); var members = await _db.Execute(c => _repo.GetSwitchMembers(c, sw.Id).ToListAsync().AsTask());
var timeSinceSwitch = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp; var timeSinceSwitch = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp;
return new DiscordEmbedBuilder() return new DiscordEmbedBuilder()
.WithColor(members.FirstOrDefault()?.Color?.ToDiscordColor() ?? DiscordUtils.Gray) .WithColor(members.FirstOrDefault()?.Color?.ToDiscordColor() ?? DiscordUtils.Gray)

View File

@ -10,7 +10,7 @@ namespace PluralKit.Bot
// TODO: is this still needed after the D#+ migration? // TODO: is this still needed after the D#+ migration?
public class LastMessageCacheService public class LastMessageCacheService
{ {
private IDictionary<ulong, ulong> _cache = new ConcurrentDictionary<ulong, ulong>(); private readonly IDictionary<ulong, ulong> _cache = new ConcurrentDictionary<ulong, ulong>();
public void AddMessage(ulong channel, ulong message) public void AddMessage(ulong channel, ulong message)
{ {

View File

@ -15,16 +15,16 @@ namespace PluralKit.Bot {
public class LogChannelService { public class LogChannelService {
private readonly EmbedService _embed; private readonly EmbedService _embed;
private readonly IDatabase _db; private readonly IDatabase _db;
private readonly IDataStore _data; private readonly ModelRepository _repo;
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly DiscordRestClient _rest; private readonly DiscordRestClient _rest;
public LogChannelService(EmbedService embed, ILogger logger, DiscordRestClient rest, IDatabase db, IDataStore data) public LogChannelService(EmbedService embed, ILogger logger, DiscordRestClient rest, IDatabase db, ModelRepository repo)
{ {
_embed = embed; _embed = embed;
_rest = rest; _rest = rest;
_db = db; _db = db;
_data = data; _repo = repo;
_logger = logger.ForContext<LogChannelService>(); _logger = logger.ForContext<LogChannelService>();
} }
@ -47,8 +47,8 @@ namespace PluralKit.Bot {
// Send embed! // Send embed!
await using var conn = await _db.Obtain(); await using var conn = await _db.Obtain();
var embed = _embed.CreateLoggedMessageEmbed(await conn.QuerySystem(ctx.SystemId.Value), var embed = _embed.CreateLoggedMessageEmbed(await _repo.GetSystem(conn, ctx.SystemId.Value),
await conn.QueryMember(proxy.Member.Id), hookMessage, trigger.Id, trigger.Author, proxy.Content, await _repo.GetMember(conn, proxy.Member.Id), hookMessage, trigger.Id, trigger.Author, proxy.Content,
trigger.Channel); trigger.Channel);
var url = $"https://discord.com/channels/{trigger.Channel.GuildId}/{trigger.ChannelId}/{hookMessage}"; var url = $"https://discord.com/channels/{trigger.Channel.GuildId}/{trigger.ChannelId}/{hookMessage}";
await logChannel.SendMessageFixedAsync(content: url, embed: embed); await logChannel.SendMessageFixedAsync(content: url, embed: embed);

View File

@ -16,19 +16,19 @@ namespace PluralKit.Bot
{ {
public class LoggerCleanService public class LoggerCleanService
{ {
private static Regex _basicRegex = new Regex("(\\d{17,19})"); private static readonly Regex _basicRegex = new Regex("(\\d{17,19})");
private static Regex _dynoRegex = new Regex("Message ID: (\\d{17,19})"); private static readonly Regex _dynoRegex = new Regex("Message ID: (\\d{17,19})");
private static Regex _carlRegex = new Regex("ID: (\\d{17,19})"); private static readonly Regex _carlRegex = new Regex("ID: (\\d{17,19})");
private static Regex _circleRegex = new Regex("\\(`(\\d{17,19})`\\)"); private static readonly Regex _circleRegex = new Regex("\\(`(\\d{17,19})`\\)");
private static Regex _loggerARegex = new Regex("Message = (\\d{17,19})"); private static readonly Regex _loggerARegex = new Regex("Message = (\\d{17,19})");
private static Regex _loggerBRegex = new Regex("MessageID:(\\d{17,19})"); private static readonly Regex _loggerBRegex = new Regex("MessageID:(\\d{17,19})");
private static Regex _auttajaRegex = new Regex("Message (\\d{17,19}) deleted"); private static readonly Regex _auttajaRegex = new Regex("Message (\\d{17,19}) deleted");
private static Regex _mantaroRegex = new Regex("Message \\(?ID:? (\\d{17,19})\\)? created by .* in channel .* was deleted\\."); private static readonly Regex _mantaroRegex = new Regex("Message \\(?ID:? (\\d{17,19})\\)? created by .* in channel .* was deleted\\.");
private static Regex _pancakeRegex = new Regex("Message from <@(\\d{17,19})> deleted in"); private static readonly Regex _pancakeRegex = new Regex("Message from <@(\\d{17,19})> deleted in");
private static Regex _unbelievaboatRegex = new Regex("Message ID: (\\d{17,19})"); private static readonly Regex _unbelievaboatRegex = new Regex("Message ID: (\\d{17,19})");
private static Regex _vanessaRegex = new Regex("Message sent by <@!?(\\d{17,19})> deleted in"); private static readonly Regex _vanessaRegex = new Regex("Message sent by <@!?(\\d{17,19})> deleted in");
private static Regex _salRegex = new Regex("\\(ID: (\\d{17,19})\\)"); private static readonly Regex _salRegex = new Regex("\\(ID: (\\d{17,19})\\)");
private static Regex _GearBotRegex = new Regex("\\(``(\\d{17,19})``\\) in <#\\d{17,19}> has been removed."); private static readonly Regex _GearBotRegex = new Regex("\\(``(\\d{17,19})``\\) in <#\\d{17,19}> has been removed.");
private static readonly Dictionary<ulong, LoggerBot> _bots = new[] private static readonly Dictionary<ulong, LoggerBot> _bots = new[]
{ {
@ -55,7 +55,7 @@ namespace PluralKit.Bot
.Where(b => b.WebhookName != null) .Where(b => b.WebhookName != null)
.ToDictionary(b => b.WebhookName); .ToDictionary(b => b.WebhookName);
private IDatabase _db; private readonly IDatabase _db;
private DiscordShardedClient _client; private DiscordShardedClient _client;
public LoggerCleanService(IDatabase db, DiscordShardedClient client) public LoggerCleanService(IDatabase db, DiscordShardedClient client)

View File

@ -17,17 +17,17 @@ namespace PluralKit.Bot
{ {
public class PeriodicStatCollector public class PeriodicStatCollector
{ {
private DiscordShardedClient _client; private readonly DiscordShardedClient _client;
private IMetrics _metrics; private readonly IMetrics _metrics;
private CpuStatService _cpu; private readonly CpuStatService _cpu;
private IDatabase _db; private readonly IDatabase _db;
private WebhookCacheService _webhookCache; private readonly WebhookCacheService _webhookCache;
private DbConnectionCountHolder _countHolder; private readonly DbConnectionCountHolder _countHolder;
private ILogger _logger; private readonly ILogger _logger;
public PeriodicStatCollector(DiscordShardedClient client, IMetrics metrics, ILogger logger, WebhookCacheService webhookCache, DbConnectionCountHolder countHolder, CpuStatService cpu, IDatabase db) public PeriodicStatCollector(DiscordShardedClient client, IMetrics metrics, ILogger logger, WebhookCacheService webhookCache, DbConnectionCountHolder countHolder, CpuStatService cpu, IDatabase db)
{ {

View File

@ -16,7 +16,6 @@ namespace PluralKit.Bot
{ {
public class ShardInfoService public class ShardInfoService
{ {
public class ShardInfo public class ShardInfo
{ {
public bool HasAttachedListeners; public bool HasAttachedListeners;
@ -27,10 +26,10 @@ namespace PluralKit.Bot
public bool Connected; public bool Connected;
} }
private IMetrics _metrics; private readonly IMetrics _metrics;
private ILogger _logger; private readonly ILogger _logger;
private DiscordShardedClient _client; private readonly DiscordShardedClient _client;
private Dictionary<int, ShardInfo> _shardInfo = new Dictionary<int, ShardInfo>(); private readonly Dictionary<int, ShardInfo> _shardInfo = new Dictionary<int, ShardInfo>();
public ShardInfoService(ILogger logger, DiscordShardedClient client, IMetrics metrics) public ShardInfoService(ILogger logger, DiscordShardedClient client, IMetrics metrics)
{ {

View File

@ -18,11 +18,11 @@ namespace PluralKit.Bot
{ {
public static readonly string WebhookName = "PluralKit Proxy Webhook"; public static readonly string WebhookName = "PluralKit Proxy Webhook";
private DiscordShardedClient _client; private readonly DiscordShardedClient _client;
private ConcurrentDictionary<ulong, Lazy<Task<DiscordWebhook>>> _webhooks; private readonly ConcurrentDictionary<ulong, Lazy<Task<DiscordWebhook>>> _webhooks;
private IMetrics _metrics; private readonly IMetrics _metrics;
private ILogger _logger; private readonly ILogger _logger;
public WebhookCacheService(DiscordShardedClient client, ILogger logger, IMetrics metrics) public WebhookCacheService(DiscordShardedClient client, ILogger logger, IMetrics metrics)
{ {

View File

@ -29,10 +29,10 @@ namespace PluralKit.Bot
public class WebhookExecutorService public class WebhookExecutorService
{ {
private WebhookCacheService _webhookCache; private readonly WebhookCacheService _webhookCache;
private ILogger _logger; private readonly ILogger _logger;
private IMetrics _metrics; private readonly IMetrics _metrics;
private HttpClient _client; private readonly HttpClient _client;
public WebhookExecutorService(IMetrics metrics, WebhookCacheService webhookCache, ILogger logger, HttpClient client) public WebhookExecutorService(IMetrics metrics, WebhookCacheService webhookCache, ILogger logger, HttpClient client)
{ {

View File

@ -2,7 +2,6 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Data; using System.Data;
using System.IO; using System.IO;
using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using App.Metrics; using App.Metrics;
@ -213,5 +212,13 @@ namespace PluralKit.Core
await using var conn = await db.Obtain(); await using var conn = await db.Obtain();
return await func(conn); return await func(conn);
} }
public static async IAsyncEnumerable<T> Execute<T>(this IDatabase db, Func<IPKConnection, IAsyncEnumerable<T>> func)
{
await using var conn = await db.Obtain();
await foreach (var val in func(conn))
yield return val;
}
} }
} }

View File

@ -6,16 +6,16 @@ using Dapper;
namespace PluralKit.Core namespace PluralKit.Core
{ {
public static class DatabaseFunctionsExt public partial class ModelRepository
{ {
public static Task<MessageContext> QueryMessageContext(this IPKConnection conn, ulong account, ulong guild, ulong channel) public Task<MessageContext> GetMessageContext(IPKConnection conn, ulong account, ulong guild, ulong channel)
{ {
return conn.QueryFirstAsync<MessageContext>("message_context", return conn.QueryFirstAsync<MessageContext>("message_context",
new { account_id = account, guild_id = guild, channel_id = channel }, new { account_id = account, guild_id = guild, channel_id = channel },
commandType: CommandType.StoredProcedure); commandType: CommandType.StoredProcedure);
} }
public static Task<IEnumerable<ProxyMember>> QueryProxyMembers(this IPKConnection conn, ulong account, ulong guild) public Task<IEnumerable<ProxyMember>> GetProxyMembers(IPKConnection conn, ulong account, ulong guild)
{ {
return conn.QueryAsync<ProxyMember>("proxy_members", return conn.QueryAsync<ProxyMember>("proxy_members",
new { account_id = account, guild_id = guild }, new { account_id = account, guild_id = guild },

View File

@ -0,0 +1,83 @@
#nullable enable
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task<PKGroup?> GetGroupByName(IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where system = @System and lower(Name) = lower(@Name)", new {System = system, Name = name});
public Task<PKGroup?> GetGroupByHid(IPKConnection conn, string hid) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where hid = @hid", new {hid = hid.ToLowerInvariant()});
public Task<int> GetGroupMemberCount(IPKConnection conn, GroupId id, PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from group_members");
if (privacyFilter != null)
query.Append(" inner join members on group_members.member_id = members.id");
query.Append(" where group_members.group_id = @Id");
if (privacyFilter != null)
query.Append(" and members.member_visibility = @PrivacyFilter");
return conn.QuerySingleOrDefaultAsync<int>(query.ToString(), new {Id = id, PrivacyFilter = privacyFilter});
}
public IAsyncEnumerable<PKGroup> GetMemberGroups(IPKConnection conn, MemberId id) =>
conn.QueryStreamAsync<PKGroup>(
"select groups.* from group_members inner join groups on group_members.group_id = groups.id where group_members.member_id = @Id",
new {Id = id});
public async Task<PKGroup> CreateGroup(IPKConnection conn, SystemId system, string name)
{
var group = await conn.QueryFirstAsync<PKGroup>(
"insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *",
new {System = system, Name = name});
_logger.Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name);
return group;
}
public Task<PKGroup> UpdateGroup(IPKConnection conn, GroupId id, GroupPatch patch)
{
_logger.Information("Updated {GroupId}: {@GroupPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKGroup>(query, pms);
}
public Task DeleteGroup(IPKConnection conn, GroupId group)
{
_logger.Information("Deleted {GroupId}", group);
return conn.ExecuteAsync("delete from groups where id = @Id", new {Id = @group});
}
public async Task AddMembersToGroup(IPKConnection conn, GroupId group,
IReadOnlyCollection<MemberId> members)
{
await using var w =
conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)");
foreach (var member in members)
{
await w.StartRowAsync();
await w.WriteAsync(group.Value);
await w.WriteAsync(member.Value);
}
await w.CompleteAsync();
_logger.Information("Added members to {GroupId}: {MemberIds}", group, members);
}
public Task RemoveMembersFromGroup(IPKConnection conn, GroupId group,
IReadOnlyCollection<MemberId> members)
{
_logger.Information("Removed members from {GroupId}: {MemberIds}", group, members);
return conn.ExecuteAsync("delete from group_members where group_id = @Group and member_id = any(@Members)",
new {Group = @group, Members = members.ToArray()});
}
}
}

View File

@ -0,0 +1,53 @@
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task UpsertGuild(IPKConnection conn, ulong guild, GuildPatch patch)
{
_logger.Information("Updated guild {GuildId}: {@GuildPatch}", guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("servers", "id"))
.WithConstant("id", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public Task UpsertSystemGuild(IPKConnection conn, SystemId system, ulong guild,
SystemGuildPatch patch)
{
_logger.Information("Updated {SystemId} in guild {GuildId}: {@SystemGuildPatch}", system, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("system_guild", "system, guild"))
.WithConstant("system", system)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public Task UpsertMemberGuild(IPKConnection conn, MemberId member, ulong guild,
MemberGuildPatch patch)
{
_logger.Information("Updated {MemberId} in guild {GuildId}: {@MemberGuildPatch}", member, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("member_guild", "member, guild"))
.WithConstant("member", member)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public Task<GuildConfig> GetGuild(IPKConnection conn, ulong guild) =>
conn.QueryFirstAsync<GuildConfig>("insert into servers (id) values (@guild) on conflict (id) do update set id = @guild returning *", new {guild});
public Task<SystemGuildSettings> GetSystemGuild(IPKConnection conn, ulong guild, SystemId system) =>
conn.QueryFirstAsync<SystemGuildSettings>(
"insert into system_guild (guild, system) values (@guild, @system) on conflict (guild, system) do update set guild = @guild, system = @system returning *",
new {guild, system});
public Task<MemberGuildSettings> GetMemberGuild(IPKConnection conn, ulong guild, MemberId member) =>
conn.QueryFirstAsync<MemberGuildSettings>(
"insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *",
new {guild, member});
}
}

View File

@ -0,0 +1,47 @@
#nullable enable
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task<PKMember?> GetMember(IPKConnection conn, MemberId id) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where id = @id", new {id});
public Task<PKMember?> GetMemberByHid(IPKConnection conn, string hid) =>
conn.QuerySingleOrDefaultAsync<PKMember?>("select * from members where hid = @Hid", new { Hid = hid.ToLower() });
public Task<PKMember?> GetMemberByName(IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system });
public Task<PKMember?> GetMemberByDisplayName(IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system });
public async Task<PKMember> CreateMember(IPKConnection conn, SystemId id, string memberName)
{
var member = await conn.QueryFirstAsync<PKMember>(
"insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *",
new {SystemId = id, Name = memberName});
_logger.Information("Created {MemberId} in {SystemId}: {MemberName}",
member.Id, id, memberName);
return member;
}
public Task<PKMember> UpdateMember(IPKConnection conn, MemberId id, MemberPatch patch)
{
_logger.Information("Updated {MemberId}: {@MemberPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKMember>(query, pms);
}
public Task DeleteMember(IPKConnection conn, MemberId id)
{
_logger.Information("Deleted {MemberId}", id);
return conn.ExecuteAsync("delete from members where id = @Id", new {Id = id});
}
}
}

View File

@ -0,0 +1,63 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public async Task AddMessage(IPKConnection conn, PKMessage msg) {
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@Mid, @Guild, @Channel, @Member, @Sender, @OriginalMid) on conflict do nothing", msg);
_logger.Debug("Stored message {@StoredMessage} in channel {Channel}", msg, msg.Channel);
}
public async Task<FullMessage> GetMessage(IPKConnection conn, ulong id)
{
FullMessage Mapper(PKMessage msg, PKMember member, PKSystem system) =>
new FullMessage {Message = msg, System = system, Member = member};
var result = await conn.QueryAsync<PKMessage, PKMember, PKSystem, FullMessage>(
"select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system",
Mapper, new {Id = id});
return result.FirstOrDefault();
}
public async Task DeleteMessage(IPKConnection conn, ulong id)
{
var rowCount = await conn.ExecuteAsync("delete from messages where mid = @Id", new {Id = id});
if (rowCount > 0)
_logger.Information("Deleted message {MessageId} from database", id);
}
public async Task DeleteMessagesBulk(IPKConnection conn, IReadOnlyCollection<ulong> ids)
{
// Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong
// Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway)
var rowCount = await conn.ExecuteAsync("delete from messages where mid = any(@Ids)",
new {Ids = ids.Select(id => (long) id).ToArray()});
if (rowCount > 0)
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount,
ids);
}
}
public class PKMessage
{
public ulong Mid { get; set; }
public ulong? Guild { get; set; } // null value means "no data" (ie. from before this field being added)
public ulong Channel { get; set; }
public MemberId Member { get; set; }
public ulong Sender { get; set; }
public ulong? OriginalMid { get; set; }
}
public class FullMessage
{
public PKMessage Message;
public PKMember Member;
public PKSystem System;
}
}

View File

@ -0,0 +1,236 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using NodaTime;
using NpgsqlTypes;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public async Task AddSwitch(IPKConnection conn, SystemId system, IReadOnlyCollection<MemberId> members)
{
// Use a transaction here since we're doing multiple executed commands in one
await using var tx = await conn.BeginTransactionAsync();
// First, we insert the switch itself
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
new {System = system});
// Then we insert each member in the switch in the switch_members table
await using (var w = conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)"))
{
foreach (var member in members)
{
await w.StartRowAsync();
await w.WriteAsync(sw.Id.Value, NpgsqlDbType.Integer);
await w.WriteAsync(member.Value, NpgsqlDbType.Integer);
}
await w.CompleteAsync();
}
// Finally we commit the tx, since the using block will otherwise rollback it
await tx.CommitAsync();
_logger.Information("Created {SwitchId} in {SystemId}: {Members}", sw.Id, system, members);
}
public async Task MoveSwitch(IPKConnection conn, SwitchId id, Instant time)
{
await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id",
new {Time = time, Id = id});
_logger.Information("Updated {SwitchId} timestamp: {SwitchTimestamp}", id, time);
}
public async Task DeleteSwitch(IPKConnection conn, SwitchId id)
{
await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = id});
_logger.Information("Deleted {Switch}", id);
}
public async Task DeleteAllSwitches(IPKConnection conn, SystemId system)
{
await conn.ExecuteAsync("delete from switches where system = @Id", new {Id = system});
_logger.Information("Deleted all switches in {SystemId}", system);
}
public IAsyncEnumerable<PKSwitch> GetSwitches(IPKConnection conn, SystemId system)
{
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
return conn.QueryStreamAsync<PKSwitch>(
"select * from switches where system = @System order by timestamp desc",
new {System = system});
}
public async Task<int> GetSwitchCount(IPKConnection conn, SystemId system)
{
return await conn.QuerySingleAsync<int>("select count(*) from switches where system = @Id", new { Id = system });
}
public async IAsyncEnumerable<SwitchMembersListEntry> GetSwitchMembersList(IPKConnection conn,
SystemId system, Instant start, Instant end)
{
// Wrap multiple commands in a single transaction for performance
await using var tx = await conn.BeginTransactionAsync();
// Find the time of the last switch outside the range as it overlaps the range
// If no prior switch exists, the lower bound of the range remains the start time
var lastSwitch = await conn.QuerySingleOrDefaultAsync<Instant>(
@"SELECT COALESCE(MAX(timestamp), @Start)
FROM switches
WHERE switches.system = @System
AND switches.timestamp < @Start",
new {System = system, Start = start});
// Then collect the time and members of all switches that overlap the range
var switchMembersEntries = conn.QueryStreamAsync<SwitchMembersListEntry>(
@"SELECT switch_members.member, switches.timestamp
FROM switches
LEFT JOIN switch_members
ON switches.id = switch_members.switch
WHERE switches.system = @System
AND (
switches.timestamp >= @Start
OR switches.timestamp = @LastSwitch
)
AND switches.timestamp < @End
ORDER BY switches.timestamp DESC",
new {System = system, Start = start, End = end, LastSwitch = lastSwitch});
// Yield each value here
await foreach (var entry in switchMembersEntries)
yield return entry;
// Don't really need to worry about the transaction here, we're not doing any *writes*
}
public IAsyncEnumerable<PKMember> GetSwitchMembers(IPKConnection conn, SwitchId sw)
{
return conn.QueryStreamAsync<PKMember>(
"select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id",
new {Switch = sw});
}
public async Task<PKSwitch> GetLatestSwitch(IPKConnection conn, SystemId system) =>
// TODO: should query directly for perf
await GetSwitches(conn, system).FirstOrDefaultAsync();
public async Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(IPKConnection conn,
SystemId system, Instant periodStart,
Instant periodEnd)
{
// TODO: IAsyncEnumerable-ify this one
// TODO: this doesn't belong in the repo
// Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order
var switchMembers = await GetSwitchMembersList(conn, system, periodStart, periodEnd).ToListAsync();
// query DB for all members involved in any of the switches above and collect into a dictionary for future use
// this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary
// key used in GetPerMemberSwitchDuration below
var membersList = await conn.QueryAsync<PKMember>(
"select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax
new {Switches = switchMembers.Select(m => m.Member.Value).Distinct().ToList()});
var memberObjects = membersList.ToDictionary(m => m.Id);
// Initialize entries - still need to loop to determine the TimespanEnd below
var entries =
from item in switchMembers
group item by item.Timestamp
into g
select new SwitchListEntry
{
TimespanStart = g.Key,
Members = g.Where(x => x.Member != default(MemberId)).Select(x => memberObjects[x.Member])
.ToList()
};
// Loop through every switch that overlaps the range and add it to the output list
// end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared
var endTime = periodEnd;
var outList = new List<SwitchListEntry>();
foreach (var e in entries)
{
// Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above)
var switchStartClamped = e.TimespanStart < periodStart
? periodStart
: e.TimespanStart;
outList.Add(new SwitchListEntry
{
Members = e.Members, TimespanStart = switchStartClamped, TimespanEnd = endTime
});
// next switch's end is this switch's start (we're working backward in time)
endTime = e.TimespanStart;
}
return outList;
}
public async Task<FrontBreakdown> GetFrontBreakdown(IPKConnection conn, SystemId system, Instant periodStart,
Instant periodEnd)
{
// TODO: this doesn't belong in the repo
var dict = new Dictionary<PKMember, Duration>();
var noFronterDuration = Duration.Zero;
// Sum up all switch durations for each member
// switches with multiple members will result in the duration to add up to more than the actual period range
var actualStart = periodEnd; // will be "pulled" down
var actualEnd = periodStart; // will be "pulled" up
foreach (var sw in await GetPeriodFronters(conn, system, periodStart, periodEnd))
{
var span = sw.TimespanEnd - sw.TimespanStart;
foreach (var member in sw.Members)
{
if (!dict.ContainsKey(member)) dict.Add(member, span);
else dict[member] += span;
}
if (sw.Members.Count == 0) noFronterDuration += span;
if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart;
if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd;
}
return new FrontBreakdown
{
MemberSwitchDurations = dict,
NoFronterDuration = noFronterDuration,
RangeStart = actualStart,
RangeEnd = actualEnd
};
}
}
public struct SwitchListEntry
{
public ICollection<PKMember> Members;
public Instant TimespanStart;
public Instant TimespanEnd;
}
public struct FrontBreakdown
{
public Dictionary<PKMember, Duration> MemberSwitchDurations;
public Duration NoFronterDuration;
public Instant RangeStart;
public Instant RangeEnd;
}
public struct SwitchMembersListEntry
{
public MemberId Member;
public Instant Timestamp;
}
}

View File

@ -0,0 +1,78 @@
#nullable enable
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task<PKSystem?> GetSystem(IPKConnection conn, SystemId id) =>
conn.QueryFirstOrDefaultAsync<PKSystem?>("select * from systems where id = @id", new {id});
public Task<PKSystem?> GetSystemByAccount(IPKConnection conn, ulong accountId) =>
conn.QuerySingleOrDefaultAsync<PKSystem?>(
"select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id",
new {Id = accountId});
public Task<PKSystem?> GetSystemByHid(IPKConnection conn, string hid) =>
conn.QuerySingleOrDefaultAsync<PKSystem?>("select * from systems where systems.hid = @Hid",
new {Hid = hid.ToLower()});
public Task<IEnumerable<ulong>> GetSystemAccounts(IPKConnection conn, SystemId system) =>
conn.QueryAsync<ulong>("select uid from accounts where system = @Id", new {Id = system});
public IAsyncEnumerable<PKMember> GetSystemMembers(IPKConnection conn, SystemId system) =>
conn.QueryStreamAsync<PKMember>("select * from members where system = @SystemID", new {SystemID = system});
public Task<int> GetSystemMemberCount(IPKConnection conn, SystemId id, PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from members where system = @Id");
if (privacyFilter != null)
query.Append($" and member_visibility = {(int) privacyFilter.Value}");
return conn.QuerySingleAsync<int>(query.ToString(), new {Id = id});
}
public async Task<PKSystem> CreateSystem(IPKConnection conn, string? systemName = null)
{
var system = await conn.QuerySingleAsync<PKSystem>(
"insert into systems (hid, name) values (find_free_system_hid(), @Name) returning *",
new {Name = systemName});
_logger.Information("Created {SystemId}", system.Id);
return system;
}
public Task<PKSystem> UpdateSystem(IPKConnection conn, SystemId id, SystemPatch patch)
{
_logger.Information("Updated {SystemId}: {@SystemPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("systems", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKSystem>(query, pms);
}
public async Task AddAccount(IPKConnection conn, SystemId system, ulong accountId)
{
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
// This is used in import/export, although the pk;link command checks for this case beforehand
await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing",
new {Id = accountId, SystemId = system});
_logger.Information("Linked account {UserId} to {SystemId}", accountId, system);
}
public async Task RemoveAccount(IPKConnection conn, SystemId system, ulong accountId)
{
await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId",
new {Id = accountId, SystemId = system});
_logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system);
}
public Task DeleteSystem(IPKConnection conn, SystemId id)
{
_logger.Information("Deleted {SystemId}", id);
return conn.ExecuteAsync("delete from systems where id = @Id", new {Id = id});
}
}
}

View File

@ -0,0 +1,15 @@
using Serilog;
namespace PluralKit.Core
{
public partial class ModelRepository
{
private readonly ILogger _logger;
public ModelRepository(ILogger logger)
{
_logger = logger.ForContext<ILogger>()
.ForContext("Elastic", "yes?");
}
}
}

View File

@ -64,7 +64,11 @@ namespace PluralKit.Core
protected override async ValueTask<DbTransaction> BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => new PKTransaction(await Inner.BeginTransactionAsync(level, ct)); protected override async ValueTask<DbTransaction> BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => new PKTransaction(await Inner.BeginTransactionAsync(level, ct));
public override void Open() => throw SyncError(nameof(Open)); public override void Open() => throw SyncError(nameof(Open));
public override void Close() => throw SyncError(nameof(Close)); public override void Close()
{
// Don't throw SyncError here, Dapper calls sync Close() internally so that sucks
Inner.Close();
}
IDbTransaction IPKConnection.BeginTransaction() => throw SyncError(nameof(BeginTransaction)); IDbTransaction IPKConnection.BeginTransaction() => throw SyncError(nameof(BeginTransaction));
IDbTransaction IPKConnection.BeginTransaction(IsolationLevel level) => throw SyncError(nameof(BeginTransaction)); IDbTransaction IPKConnection.BeginTransaction(IsolationLevel level) => throw SyncError(nameof(BeginTransaction));

View File

@ -1,69 +0,0 @@
#nullable enable
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public static class ModelQueryExt
{
public static Task<PKSystem?> QuerySystem(this IPKConnection conn, SystemId id) =>
conn.QueryFirstOrDefaultAsync<PKSystem?>("select * from systems where id = @id", new {id});
public static Task<int> GetSystemMemberCount(this IPKConnection conn, SystemId id, PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from members where system = @Id");
if (privacyFilter != null)
query.Append($" and member_visibility = {(int) privacyFilter.Value}");
return conn.QuerySingleAsync<int>(query.ToString(), new {Id = id});
}
public static Task<IEnumerable<ulong>> GetLinkedAccounts(this IPKConnection conn, SystemId id) =>
conn.QueryAsync<ulong>("select uid from accounts where system = @Id", new {Id = id});
public static Task<PKMember?> QueryMember(this IPKConnection conn, MemberId id) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where id = @id", new {id});
public static Task<PKMember?> QueryMemberByHid(this IPKConnection conn, string hid) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where hid = @hid", new {hid = hid.ToLowerInvariant()});
public static Task<PKGroup?> QueryGroupByName(this IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where system = @System and lower(Name) = lower(@Name)", new {System = system, Name = name});
public static Task<PKGroup?> QueryGroupByHid(this IPKConnection conn, string hid) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where hid = @hid", new {hid = hid.ToLowerInvariant()});
public static Task<int> QueryGroupMemberCount(this IPKConnection conn, GroupId id,
PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from group_members");
if (privacyFilter != null)
query.Append(" inner join members on group_members.member_id = members.id");
query.Append(" where group_members.group_id = @Id");
if (privacyFilter != null)
query.Append(" and members.member_visibility = @PrivacyFilter");
return conn.QuerySingleOrDefaultAsync<int>(query.ToString(), new {Id = id, PrivacyFilter = privacyFilter});
}
public static Task<IEnumerable<PKGroup>> QueryMemberGroups(this IPKConnection conn, MemberId id) =>
conn.QueryAsync<PKGroup>(
"select groups.* from group_members inner join groups on group_members.group_id = groups.id where group_members.member_id = @Id",
new {Id = id});
public static Task<GuildConfig> QueryOrInsertGuildConfig(this IPKConnection conn, ulong guild) =>
conn.QueryFirstAsync<GuildConfig>("insert into servers (id) values (@guild) on conflict (id) do update set id = @guild returning *", new {guild});
public static Task<SystemGuildSettings> QueryOrInsertSystemGuildConfig(this IPKConnection conn, ulong guild, SystemId system) =>
conn.QueryFirstAsync<SystemGuildSettings>(
"insert into system_guild (guild, system) values (@guild, @system) on conflict (guild, system) do update set guild = @guild, system = @system returning *",
new {guild, system});
public static Task<MemberGuildSettings> QueryOrInsertMemberGuildConfig(
this IPKConnection conn, ulong guild, MemberId member) =>
conn.QueryFirstAsync<MemberGuildSettings>(
"insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *",
new {guild, member});
}
}

View File

@ -1,128 +0,0 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using Serilog;
namespace PluralKit.Core
{
public static class ModelPatchExt
{
public static Task<PKSystem> UpdateSystem(this IPKConnection conn, SystemId id, SystemPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {SystemId}: {@SystemPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("systems", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKSystem>(query, pms);
}
public static Task DeleteSystem(this IPKConnection conn, SystemId id)
{
Log.ForContext("Elastic", "yes?").Information("Deleted {SystemId}", id);
return conn.ExecuteAsync("delete from systems where id = @Id", new {Id = id});
}
public static async Task<PKMember> CreateMember(this IPKConnection conn, SystemId system, string memberName)
{
var member = await conn.QueryFirstAsync<PKMember>(
"insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *",
new {SystemId = system, Name = memberName});
Log.ForContext("Elastic", "yes?").Information("Created {MemberId} in {SystemId}: {MemberName}",
member.Id, system, memberName);
return member;
}
public static Task<PKMember> UpdateMember(this IPKConnection conn, MemberId id, MemberPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {MemberId}: {@MemberPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKMember>(query, pms);
}
public static Task DeleteMember(this IPKConnection conn, MemberId id)
{
Log.ForContext("Elastic", "yes?").Information("Deleted {MemberId}", id);
return conn.ExecuteAsync("delete from members where id = @Id", new {Id = id});
}
public static Task UpsertSystemGuild(this IPKConnection conn, SystemId system, ulong guild,
SystemGuildPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {SystemId} in guild {GuildId}: {@SystemGuildPatch}", system, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("system_guild", "system, guild"))
.WithConstant("system", system)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public static Task UpsertMemberGuild(this IPKConnection conn, MemberId member, ulong guild,
MemberGuildPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {MemberId} in guild {GuildId}: {@MemberGuildPatch}", member, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("member_guild", "member, guild"))
.WithConstant("member", member)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public static Task UpsertGuild(this IPKConnection conn, ulong guild, GuildPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated guild {GuildId}: {@GuildPatch}", guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("servers", "id"))
.WithConstant("id", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public static async Task<PKGroup> CreateGroup(this IPKConnection conn, SystemId system, string name)
{
var group = await conn.QueryFirstAsync<PKGroup>(
"insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *",
new {System = system, Name = name});
Log.ForContext("Elastic", "yes?").Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name);
return group;
}
public static Task<PKGroup> UpdateGroup(this IPKConnection conn, GroupId id, GroupPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {GroupId}: {@GroupPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKGroup>(query, pms);
}
public static Task DeleteGroup(this IPKConnection conn, GroupId group)
{
Log.ForContext("Elastic", "yes?").Information("Deleted {GroupId}", group);
return conn.ExecuteAsync("delete from groups where id = @Id", new {Id = @group});
}
public static async Task AddMembersToGroup(this IPKConnection conn, GroupId group, IReadOnlyCollection<MemberId> members)
{
await using var w = conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)");
foreach (var member in members)
{
await w.StartRowAsync();
await w.WriteAsync(group.Value);
await w.WriteAsync(member.Value);
}
await w.CompleteAsync();
Log.ForContext("Elastic", "yes?").Information("Added members to {GroupId}: {MemberIds}", group, members);
}
public static Task RemoveMembersFromGroup(this IPKConnection conn, GroupId group, IReadOnlyCollection<MemberId> members)
{
Log.ForContext("Elastic", "yes?").Information("Removed members from {GroupId}: {MemberIds}", group, members);
return conn.ExecuteAsync("delete from group_members where group_id = @Group and member_id = any(@Members)",
new {Group = @group, Members = members.ToArray()});
}
}
}

View File

@ -1,5 +1,13 @@
namespace PluralKit.Core namespace PluralKit.Core
{ {
public enum AutoproxyMode
{
Off = 1,
Front = 2,
Latch = 3,
Member = 4
}
public class SystemGuildSettings public class SystemGuildSettings
{ {
public ulong Guild { get; } public ulong Guild { get; }

View File

@ -25,7 +25,7 @@ namespace PluralKit.Core
{ {
builder.RegisterType<DbConnectionCountHolder>().SingleInstance(); builder.RegisterType<DbConnectionCountHolder>().SingleInstance();
builder.RegisterType<Database>().As<IDatabase>().SingleInstance(); builder.RegisterType<Database>().As<IDatabase>().SingleInstance();
builder.RegisterType<PostgresDataStore>().AsSelf().As<IDataStore>(); builder.RegisterType<ModelRepository>().AsSelf().SingleInstance();
builder.Populate(new ServiceCollection().AddMemoryCache()); builder.Populate(new ServiceCollection().AddMemoryCache());
} }
@ -33,7 +33,7 @@ namespace PluralKit.Core
public class ConfigModule<T>: Module where T: new() public class ConfigModule<T>: Module where T: new()
{ {
private string _submodule; private readonly string _submodule;
public ConfigModule(string submodule = null) public ConfigModule(string submodule = null)
{ {

View File

@ -14,22 +14,24 @@ namespace PluralKit.Core
{ {
public class DataFileService public class DataFileService
{ {
private IDataStore _data; private readonly IDatabase _db;
private IDatabase _db; private readonly ModelRepository _repo;
private ILogger _logger; private readonly ILogger _logger;
public DataFileService(ILogger logger, IDataStore data, IDatabase db) public DataFileService(ILogger logger, IDatabase db, ModelRepository repo)
{ {
_data = data;
_db = db; _db = db;
_repo = repo;
_logger = logger.ForContext<DataFileService>(); _logger = logger.ForContext<DataFileService>();
} }
public async Task<DataFileSystem> ExportSystem(PKSystem system) public async Task<DataFileSystem> ExportSystem(PKSystem system)
{ {
await using var conn = await _db.Obtain();
// Export members // Export members
var members = new List<DataFileMember>(); var members = new List<DataFileMember>();
var pkMembers = _data.GetSystemMembers(system); // Read all members in the system var pkMembers = _repo.GetSystemMembers(conn, system.Id); // Read all members in the system
await foreach (var member in pkMembers.Select(m => new DataFileMember await foreach (var member in pkMembers.Select(m => new DataFileMember
{ {
@ -49,7 +51,7 @@ namespace PluralKit.Core
// Export switches // Export switches
var switches = new List<DataFileSwitch>(); var switches = new List<DataFileSwitch>();
var switchList = await _data.GetPeriodFronters(system, Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant()); var switchList = await _repo.GetPeriodFronters(conn, system.Id, Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant());
switches.AddRange(switchList.Select(x => new DataFileSwitch switches.AddRange(switchList.Select(x => new DataFileSwitch
{ {
Timestamp = x.TimespanStart.FormatExport(), Timestamp = x.TimespanStart.FormatExport(),
@ -68,7 +70,7 @@ namespace PluralKit.Core
Members = members, Members = members,
Switches = switches, Switches = switches,
Created = system.Created.FormatExport(), Created = system.Created.FormatExport(),
LinkedAccounts = (await _data.GetSystemAccounts(system)).ToList() LinkedAccounts = (await _repo.GetSystemAccounts(conn, system.Id)).ToList()
}; };
} }
@ -102,6 +104,8 @@ namespace PluralKit.Core
public async Task<ImportResult> ImportSystem(DataFileSystem data, PKSystem system, ulong accountId) public async Task<ImportResult> ImportSystem(DataFileSystem data, PKSystem system, ulong accountId)
{ {
await using var conn = await _db.Obtain();
var result = new ImportResult { var result = new ImportResult {
AddedNames = new List<string>(), AddedNames = new List<string>(),
ModifiedNames = new List<string>(), ModifiedNames = new List<string>(),
@ -112,26 +116,24 @@ namespace PluralKit.Core
// If we don't already have a system to save to, create one // If we don't already have a system to save to, create one
if (system == null) if (system == null)
{ {
system = result.System = await _data.CreateSystem(data.Name); system = result.System = await _repo.CreateSystem(conn, data.Name);
await _data.AddAccount(system, accountId); await _repo.AddAccount(conn, system.Id, accountId);
} }
await using var conn = await _db.Obtain();
// Apply system info // Apply system info
var patch = new SystemPatch {Name = data.Name}; var patch = new SystemPatch {Name = data.Name};
if (data.Description != null) patch.Description = data.Description; if (data.Description != null) patch.Description = data.Description;
if (data.Tag != null) patch.Tag = data.Tag; if (data.Tag != null) patch.Tag = data.Tag;
if (data.AvatarUrl != null) patch.AvatarUrl = data.AvatarUrl; if (data.AvatarUrl != null) patch.AvatarUrl = data.AvatarUrl;
if (data.TimeZone != null) patch.UiTz = data.TimeZone ?? "UTC"; if (data.TimeZone != null) patch.UiTz = data.TimeZone ?? "UTC";
await conn.UpdateSystem(system.Id, patch); await _repo.UpdateSystem(conn, system.Id, patch);
// -- Member/switch import -- // -- Member/switch import --
await using (var imp = await BulkImporter.Begin(system, conn)) await using (var imp = await BulkImporter.Begin(system, conn))
{ {
// Tally up the members that didn't exist before, and check member count on import // Tally up the members that didn't exist before, and check member count on import
// If creating the unmatched members would put us over the member limit, abort before creating any members // If creating the unmatched members would put us over the member limit, abort before creating any members
var memberCountBefore = await conn.GetSystemMemberCount(system.Id); var memberCountBefore = await _repo.GetSystemMemberCount(conn, system.Id);
var membersToAdd = data.Members.Count(m => imp.IsNewMember(m.Id, m.Name)); var membersToAdd = data.Members.Count(m => imp.IsNewMember(m.Id, m.Name));
if (memberCountBefore + membersToAdd > Limits.MaxMemberCount) if (memberCountBefore + membersToAdd > Limits.MaxMemberCount)
{ {

View File

@ -1,223 +0,0 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using NodaTime;
namespace PluralKit.Core {
public enum AutoproxyMode
{
Off = 1,
Front = 2,
Latch = 3,
Member = 4
}
public class FullMessage
{
public PKMessage Message;
public PKMember Member;
public PKSystem System;
}
public struct PKMessage
{
public ulong Mid;
public ulong? Guild; // null value means "no data" (ie. from before this field being added)
public ulong Channel;
public ulong Sender;
public ulong? OriginalMid;
}
public struct SwitchListEntry
{
public ICollection<PKMember> Members;
public Instant TimespanStart;
public Instant TimespanEnd;
}
public struct FrontBreakdown
{
public Dictionary<PKMember, Duration> MemberSwitchDurations;
public Duration NoFronterDuration;
public Instant RangeStart;
public Instant RangeEnd;
}
public struct SwitchMembersListEntry
{
public MemberId Member;
public Instant Timestamp;
}
public interface IDataStore
{
/// <summary>
/// Gets a system by its user-facing human ID.
/// </summary>
/// <returns>The <see cref="PKSystem"/> with the given human ID, or null if no system was found.</returns>
Task<PKSystem> GetSystemByHid(string systemHid);
/// <summary>
/// Gets a system by one of its linked Discord account IDs. Multiple IDs can return the same system.
/// </summary>
/// <returns>The <see cref="PKSystem"/> with the given linked account, or null if no system was found.</returns>
Task<PKSystem> GetSystemByAccount(ulong linkedAccount);
/// <summary>
/// Gets the Discord account IDs linked to a system.
/// </summary>
/// <returns>An enumerable of Discord account IDs linked to this system.</returns>
Task<IEnumerable<ulong>> GetSystemAccounts(PKSystem system);
/// <summary>
/// Creates a system, auto-generating its corresponding IDs.
/// </summary>
/// <param name="systemName">An optional system name to set. If `null`, will not set a system name.</param>
/// <returns>The created system model.</returns>
Task<PKSystem> CreateSystem(string systemName);
// TODO: throw exception if account is present (when adding) or account isn't present (when removing)
/// <summary>
/// Links a Discord account to a system.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if the given account is already linked to a system.</exception>
Task AddAccount(PKSystem system, ulong accountToAdd);
/// <summary>
/// Unlinks a Discord account from a system.
///
/// Will *not* throw if this results in an orphaned system - this is the caller's responsibility to ensure.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if the given account is not linked to the given system.</exception>
Task RemoveAccount(PKSystem system, ulong accountToRemove);
/// <summary>
/// Gets a member by its user-facing human ID.
/// </summary>
/// <returns>The <see cref="PKMember"/> with the given human ID, or null if no member was found.</returns>
Task<PKMember> GetMemberByHid(string memberHid);
/// <summary>
/// Gets a member by its member name within one system.
/// </summary>
/// <para>
/// Member names are *usually* unique within a system (but not always), whereas member names
/// are almost certainly *not* unique globally - therefore only intra-system lookup is
/// allowed.
/// </para>
/// <returns>The <see cref="PKMember"/> with the given name, or null if no member was found.</returns>
Task<PKMember> GetMemberByName(PKSystem system, string name);
/// <summary>
/// Gets a member by its display name within one system.
/// </summary>
/// <returns>The <see cref="PKMember"/> with the given name, or null if no member was found.</returns>
Task<PKMember> GetMemberByDisplayName(PKSystem system, string name);
/// <summary>
/// Gets all members inside a given system.
/// </summary>
/// <returns>An enumerable of <see cref="PKMember"/> structs representing each member in the system, in no particular order.</returns>
IAsyncEnumerable<PKMember> GetSystemMembers(PKSystem system, bool orderByName = false);
/// <summary>
/// Gets a message and its information by its ID.
/// </summary>
/// <param name="id">The message ID to look up. This can be either the ID of the trigger message containing the proxy tags or the resulting proxied webhook message.</param>
/// <returns>An extended message object, containing not only the message data itself but the associated system and member structs.</returns>
Task<FullMessage> GetMessage(ulong id); // id is both original and trigger, also add return type struct
/// <summary>
/// Saves a posted message to the database.
/// </summary>
/// <param name="senderAccount">The ID of the account that sent the original trigger message.</param>
/// <param name="guildId">The ID of the guild the message was posted to.</param>
/// <param name="channelId">The ID of the channel the message was posted to.</param>
/// <param name="postedMessageId">The ID of the message posted by the webhook.</param>
/// <param name="triggerMessageId">The ID of the original trigger message containing the proxy tags.</param>
/// <param name="proxiedMemberId">The member (and by extension system) that was proxied.</param>
/// <returns></returns>
Task AddMessage(IPKConnection conn, ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, MemberId proxiedMemberId);
/// <summary>
/// Deletes a message from the data store.
/// </summary>
/// <param name="postedMessageId">The ID of the webhook message to delete.</param>
Task DeleteMessage(ulong postedMessageId);
/// <summary>
/// Deletes messages from the data store in bulk.
/// </summary>
/// <param name="postedMessageIds">The IDs of the webhook messages to delete.</param>
Task DeleteMessagesBulk(IReadOnlyCollection<ulong> postedMessageIds);
/// <summary>
/// Gets switches from a system.
/// </summary>
/// <returns>An enumerable of the *count* latest switches in the system, in latest-first order. May contain fewer elements than requested.</returns>
IAsyncEnumerable<PKSwitch> GetSwitches(SystemId system);
/// <summary>
/// Gets the total amount of switches in a given system.
/// </summary>
Task<int> GetSwitchCount(PKSystem system);
/// <summary>
/// Gets the latest (temporally; closest to now) switch of a given system.
/// </summary>
Task<PKSwitch> GetLatestSwitch(SystemId system);
/// <summary>
/// Gets the members a given switch consists of.
/// </summary>
IAsyncEnumerable<PKMember> GetSwitchMembers(PKSwitch sw);
/// <summary>
/// Gets a list of fronters over a given period of time.
/// </summary>
/// <para>
/// This list is returned as an enumerable of "switch members", each containing a timestamp
/// and a member ID. <seealso cref="GetMemberById"/>
///
/// Switches containing multiple members will be returned as multiple switch members each with the same
/// timestamp, and a change in timestamp should be interpreted as the start of a new switch.
/// </para>
/// <returns>An enumerable of the aforementioned "switch members".</returns>
Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(PKSystem system, Instant periodStart, Instant periodEnd);
/// <summary>
/// Calculates a breakdown of a system's fronters over a given period, including how long each member has
/// been fronting, and how long *no* member has been fronting.
/// </summary>
/// <para>
/// Switches containing multiple members will count the full switch duration for all members, meaning
/// the total duration may add up to longer than the breakdown period.
/// </para>
/// <param name="system"></param>
/// <param name="periodStart"></param>
/// <param name="periodEnd"></param>
/// <returns></returns>
Task<FrontBreakdown> GetFrontBreakdown(PKSystem system, Instant periodStart, Instant periodEnd);
/// <summary>
/// Registers a switch with the given members in the given system.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if any of the members are not in the given system.</exception>
Task AddSwitch(SystemId system, IEnumerable<PKMember> switchMembers);
/// <summary>
/// Updates the timestamp of a given switch.
/// </summary>
Task MoveSwitch(PKSwitch sw, Instant time);
/// <summary>
/// Deletes a given switch from the data store.
/// </summary>
Task DeleteSwitch(PKSwitch sw);
/// <summary>
/// Deletes all switches in a given system from the data store.
/// </summary>
Task DeleteAllSwitches(PKSystem system);
}
}

View File

@ -1,334 +0,0 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using NodaTime;
using Serilog;
namespace PluralKit.Core {
public class PostgresDataStore: IDataStore {
private readonly IDatabase _conn;
private readonly ILogger _logger;
public PostgresDataStore(IDatabase conn, ILogger logger)
{
_conn = conn;
_logger = logger
.ForContext<PostgresDataStore>()
.ForContext("Elastic", "yes?");
}
public async Task<PKSystem> CreateSystem(string systemName = null) {
PKSystem system;
using (var conn = await _conn.Obtain())
system = await conn.QuerySingleAsync<PKSystem>("insert into systems (hid, name) values (find_free_system_hid(), @Name) returning *", new { Name = systemName });
_logger.Information("Created {SystemId}", system.Id);
// New system has no accounts, therefore nothing gets cached, therefore no need to invalidate caches right here
return system;
}
public async Task AddAccount(PKSystem system, ulong accountId) {
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
// This is used in import/export, although the pk;link command checks for this case beforehand
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing", new { Id = accountId, SystemId = system.Id });
_logger.Information("Linked account {UserId} to {SystemId}", accountId, system.Id);
}
public async Task RemoveAccount(PKSystem system, ulong accountId) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", new { Id = accountId, SystemId = system.Id });
_logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system.Id);
}
public async Task<PKSystem> GetSystemByAccount(ulong accountId) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id", new { Id = accountId });
}
public async Task<PKSystem> GetSystemByHid(string hid) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select * from systems where systems.hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<IEnumerable<ulong>> GetSystemAccounts(PKSystem system)
{
using (var conn = await _conn.Obtain())
return await conn.QueryAsync<ulong>("select uid from accounts where system = @Id", new { Id = system.Id });
}
public async Task DeleteAllSwitches(PKSystem system)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from switches where system = @Id", system);
_logger.Information("Deleted all switches in {SystemId}", system.Id);
}
public async Task<PKMember> GetMemberByHid(string hid) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKMember>("select * from members where hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<PKMember> GetMemberByName(PKSystem system, string name) {
// QueryFirst, since members can (in rare cases) share names
using (var conn = await _conn.Obtain())
return await conn.QueryFirstOrDefaultAsync<PKMember>("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id });
}
public async Task<PKMember> GetMemberByDisplayName(PKSystem system, string name) {
// QueryFirst, since members can (in rare cases) share display names
using (var conn = await _conn.Obtain())
return await conn.QueryFirstOrDefaultAsync<PKMember>("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id });
}
public IAsyncEnumerable<PKMember> GetSystemMembers(PKSystem system, bool orderByName)
{
var sql = "select * from members where system = @SystemID";
if (orderByName) sql += " order by lower(name) asc";
return _conn.QueryStreamAsync<PKMember>(sql, new { SystemID = system.Id });
}
public async Task AddMessage(IPKConnection conn, ulong senderId, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, MemberId proxiedMemberId) {
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@MessageId, @GuildId, @ChannelId, @MemberId, @SenderId, @OriginalMid) on conflict do nothing", new {
MessageId = postedMessageId,
GuildId = guildId,
ChannelId = channelId,
MemberId = proxiedMemberId,
SenderId = senderId,
OriginalMid = triggerMessageId
});
// todo: _logger.Debug("Stored message {Message} in channel {Channel}", postedMessageId, channelId);
}
public async Task<FullMessage> GetMessage(ulong id)
{
using (var conn = await _conn.Obtain())
return (await conn.QueryAsync<PKMessage, PKMember, PKSystem, FullMessage>("select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system", (msg, member, system) => new FullMessage
{
Message = msg,
System = system,
Member = member
}, new { Id = id })).FirstOrDefault();
}
public async Task DeleteMessage(ulong id) {
using (var conn = await _conn.Obtain())
if (await conn.ExecuteAsync("delete from messages where mid = @Id", new { Id = id }) > 0)
_logger.Information("Deleted message {MessageId} from database", id);
}
public async Task DeleteMessagesBulk(IReadOnlyCollection<ulong> ids)
{
using (var conn = await _conn.Obtain())
{
// Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong
// Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway)
var foundCount = await conn.ExecuteAsync("delete from messages where mid = any(@Ids)", new {Ids = ids.Select(id => (long) id).ToArray()});
if (foundCount > 0)
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", foundCount, ids);
}
}
public async Task AddSwitch(SystemId system, IEnumerable<PKMember> members)
{
// Use a transaction here since we're doing multiple executed commands in one
await using var conn = await _conn.Obtain();
await using var tx = await conn.BeginTransactionAsync();
// First, we insert the switch itself
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
new {System = system});
// Then we insert each member in the switch in the switch_members table
// TODO: can we parallelize this or send it in bulk somehow?
foreach (var member in members)
{
await conn.ExecuteAsync(
"insert into switch_members(switch, member) values(@Switch, @Member)",
new {Switch = sw.Id, Member = member.Id});
}
// Finally we commit the tx, since the using block will otherwise rollback it
await tx.CommitAsync();
_logger.Information("Created {SwitchId} in {SystemId}: {Members}", sw.Id, system, members.Select(m => m.Id));
}
public IAsyncEnumerable<PKSwitch> GetSwitches(SystemId system)
{
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
// (maybe when we get caching in?)
return _conn.QueryStreamAsync<PKSwitch>(
"select * from switches where system = @System order by timestamp desc",
new {System = system});
}
public async Task<int> GetSwitchCount(PKSystem system)
{
using var conn = await _conn.Obtain();
return await conn.QuerySingleAsync<int>("select count(*) from switches where system = @Id", system);
}
public async IAsyncEnumerable<SwitchMembersListEntry> GetSwitchMembersList(PKSystem system, Instant start, Instant end)
{
// Wrap multiple commands in a single transaction for performance
await using var conn = await _conn.Obtain();
await using var tx = await conn.BeginTransactionAsync();
// Find the time of the last switch outside the range as it overlaps the range
// If no prior switch exists, the lower bound of the range remains the start time
var lastSwitch = await conn.QuerySingleOrDefaultAsync<Instant>(
@"SELECT COALESCE(MAX(timestamp), @Start)
FROM switches
WHERE switches.system = @System
AND switches.timestamp < @Start",
new { System = system.Id, Start = start });
// Then collect the time and members of all switches that overlap the range
var switchMembersEntries = conn.QueryStreamAsync<SwitchMembersListEntry>(
@"SELECT switch_members.member, switches.timestamp
FROM switches
LEFT JOIN switch_members
ON switches.id = switch_members.switch
WHERE switches.system = @System
AND (
switches.timestamp >= @Start
OR switches.timestamp = @LastSwitch
)
AND switches.timestamp < @End
ORDER BY switches.timestamp DESC",
new { System = system.Id, Start = start, End = end, LastSwitch = lastSwitch });
// Yield each value here
await foreach (var entry in switchMembersEntries)
yield return entry;
// Don't really need to worry about the transaction here, we're not doing any *writes*
}
public IAsyncEnumerable<PKMember> GetSwitchMembers(PKSwitch sw)
{
return _conn.QueryStreamAsync<PKMember>(
"select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id",
new {Switch = sw.Id});
}
public async Task<PKSwitch> GetLatestSwitch(SystemId system) =>
await GetSwitches(system).FirstOrDefaultAsync();
public async Task MoveSwitch(PKSwitch sw, Instant time)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id",
new {Time = time, Id = sw.Id});
_logger.Information("Updated {SwitchId} timestamp: {SwitchTimestamp}", sw.Id, time);
}
public async Task DeleteSwitch(PKSwitch sw)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = sw.Id});
_logger.Information("Deleted {Switch}", sw.Id);
}
public async Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(PKSystem system, Instant periodStart, Instant periodEnd)
{
// TODO: IAsyncEnumerable-ify this one
// Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order
var switchMembers = await GetSwitchMembersList(system, periodStart, periodEnd).ToListAsync();
// query DB for all members involved in any of the switches above and collect into a dictionary for future use
// this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary
// key used in GetPerMemberSwitchDuration below
Dictionary<MemberId, PKMember> memberObjects;
using (var conn = await _conn.Obtain())
{
memberObjects = (
await conn.QueryAsync<PKMember>(
"select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax
new { Switches = switchMembers.Select(m => m.Member.Value).Distinct().ToList() })
).ToDictionary(m => m.Id);
}
// Initialize entries - still need to loop to determine the TimespanEnd below
var entries =
from item in switchMembers
group item by item.Timestamp into g
select new SwitchListEntry
{
TimespanStart = g.Key,
Members = g.Where(x => x.Member != default(MemberId)).Select(x => memberObjects[x.Member]).ToList()
};
// Loop through every switch that overlaps the range and add it to the output list
// end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared
var endTime = periodEnd;
var outList = new List<SwitchListEntry>();
foreach (var e in entries)
{
// Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above)
var switchStartClamped = e.TimespanStart < periodStart
? periodStart
: e.TimespanStart;
outList.Add(new SwitchListEntry
{
Members = e.Members,
TimespanStart = switchStartClamped,
TimespanEnd = endTime
});
// next switch's end is this switch's start (we're working backward in time)
endTime = e.TimespanStart;
}
return outList;
}
public async Task<FrontBreakdown> GetFrontBreakdown(PKSystem system, Instant periodStart, Instant periodEnd)
{
var dict = new Dictionary<PKMember, Duration>();
var noFronterDuration = Duration.Zero;
// Sum up all switch durations for each member
// switches with multiple members will result in the duration to add up to more than the actual period range
var actualStart = periodEnd; // will be "pulled" down
var actualEnd = periodStart; // will be "pulled" up
foreach (var sw in await GetPeriodFronters(system, periodStart, periodEnd))
{
var span = sw.TimespanEnd - sw.TimespanStart;
foreach (var member in sw.Members)
{
if (!dict.ContainsKey(member)) dict.Add(member, span);
else dict[member] += span;
}
if (sw.Members.Count == 0) noFronterDuration += span;
if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart;
if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd;
}
return new FrontBreakdown
{
MemberSwitchDurations = dict,
NoFronterDuration = noFronterDuration,
RangeStart = actualStart,
RangeEnd = actualEnd
};
}
}
}

View File

@ -6,20 +6,11 @@ using Dapper;
namespace PluralKit.Core { namespace PluralKit.Core {
public static class ConnectionUtils public static class ConnectionUtils
{ {
public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this IDatabase connFactory, string sql, object param)
{
await using var conn = await connFactory.Obtain();
await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param);
var parser = reader.GetRowParser<T>();
while (await reader.ReadAsync())
yield return parser(reader);
}
public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this IPKConnection conn, string sql, object param) public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this IPKConnection conn, string sql, object param)
{ {
await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param); await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param);
var parser = reader.GetRowParser<T>(); var parser = reader.GetRowParser<T>();
while (await reader.ReadAsync()) while (await reader.ReadAsync())
yield return parser(reader); yield return parser(reader);
} }

View File

@ -8,9 +8,9 @@ namespace PluralKit.Core
{ {
private readonly string? _conflictField; private readonly string? _conflictField;
private readonly string? _condition; private readonly string? _condition;
private StringBuilder _insertFragment = new StringBuilder(); private readonly StringBuilder _insertFragment = new StringBuilder();
private StringBuilder _valuesFragment = new StringBuilder(); private readonly StringBuilder _valuesFragment = new StringBuilder();
private StringBuilder _updateFragment = new StringBuilder(); private readonly StringBuilder _updateFragment = new StringBuilder();
private bool _firstInsert = true; private bool _firstInsert = true;
private bool _firstUpdate = true; private bool _firstUpdate = true;
public QueryType Type { get; } public QueryType Type { get; }