Major database refactor (again)

This commit is contained in:
Ske
2020-08-29 13:46:27 +02:00
parent 3996cd48c7
commit c7612df37e
55 changed files with 1014 additions and 1100 deletions

View File

@@ -2,7 +2,6 @@
using System.Collections.Generic;
using System.Data;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using App.Metrics;
@@ -207,11 +206,19 @@ namespace PluralKit.Core
await using var conn = await db.Obtain();
await func(conn);
}
public static async Task<T> Execute<T>(this IDatabase db, Func<IPKConnection, Task<T>> func)
{
await using var conn = await db.Obtain();
return await func(conn);
}
public static async IAsyncEnumerable<T> Execute<T>(this IDatabase db, Func<IPKConnection, IAsyncEnumerable<T>> func)
{
await using var conn = await db.Obtain();
await foreach (var val in func(conn))
yield return val;
}
}
}

View File

@@ -6,16 +6,16 @@ using Dapper;
namespace PluralKit.Core
{
public static class DatabaseFunctionsExt
public partial class ModelRepository
{
public static Task<MessageContext> QueryMessageContext(this IPKConnection conn, ulong account, ulong guild, ulong channel)
public Task<MessageContext> GetMessageContext(IPKConnection conn, ulong account, ulong guild, ulong channel)
{
return conn.QueryFirstAsync<MessageContext>("message_context",
new { account_id = account, guild_id = guild, channel_id = channel },
commandType: CommandType.StoredProcedure);
}
public static Task<IEnumerable<ProxyMember>> QueryProxyMembers(this IPKConnection conn, ulong account, ulong guild)
public Task<IEnumerable<ProxyMember>> GetProxyMembers(IPKConnection conn, ulong account, ulong guild)
{
return conn.QueryAsync<ProxyMember>("proxy_members",
new { account_id = account, guild_id = guild },

View File

@@ -0,0 +1,83 @@
#nullable enable
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task<PKGroup?> GetGroupByName(IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where system = @System and lower(Name) = lower(@Name)", new {System = system, Name = name});
public Task<PKGroup?> GetGroupByHid(IPKConnection conn, string hid) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where hid = @hid", new {hid = hid.ToLowerInvariant()});
public Task<int> GetGroupMemberCount(IPKConnection conn, GroupId id, PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from group_members");
if (privacyFilter != null)
query.Append(" inner join members on group_members.member_id = members.id");
query.Append(" where group_members.group_id = @Id");
if (privacyFilter != null)
query.Append(" and members.member_visibility = @PrivacyFilter");
return conn.QuerySingleOrDefaultAsync<int>(query.ToString(), new {Id = id, PrivacyFilter = privacyFilter});
}
public IAsyncEnumerable<PKGroup> GetMemberGroups(IPKConnection conn, MemberId id) =>
conn.QueryStreamAsync<PKGroup>(
"select groups.* from group_members inner join groups on group_members.group_id = groups.id where group_members.member_id = @Id",
new {Id = id});
public async Task<PKGroup> CreateGroup(IPKConnection conn, SystemId system, string name)
{
var group = await conn.QueryFirstAsync<PKGroup>(
"insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *",
new {System = system, Name = name});
_logger.Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name);
return group;
}
public Task<PKGroup> UpdateGroup(IPKConnection conn, GroupId id, GroupPatch patch)
{
_logger.Information("Updated {GroupId}: {@GroupPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKGroup>(query, pms);
}
public Task DeleteGroup(IPKConnection conn, GroupId group)
{
_logger.Information("Deleted {GroupId}", group);
return conn.ExecuteAsync("delete from groups where id = @Id", new {Id = @group});
}
public async Task AddMembersToGroup(IPKConnection conn, GroupId group,
IReadOnlyCollection<MemberId> members)
{
await using var w =
conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)");
foreach (var member in members)
{
await w.StartRowAsync();
await w.WriteAsync(group.Value);
await w.WriteAsync(member.Value);
}
await w.CompleteAsync();
_logger.Information("Added members to {GroupId}: {MemberIds}", group, members);
}
public Task RemoveMembersFromGroup(IPKConnection conn, GroupId group,
IReadOnlyCollection<MemberId> members)
{
_logger.Information("Removed members from {GroupId}: {MemberIds}", group, members);
return conn.ExecuteAsync("delete from group_members where group_id = @Group and member_id = any(@Members)",
new {Group = @group, Members = members.ToArray()});
}
}
}

View File

@@ -0,0 +1,53 @@
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task UpsertGuild(IPKConnection conn, ulong guild, GuildPatch patch)
{
_logger.Information("Updated guild {GuildId}: {@GuildPatch}", guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("servers", "id"))
.WithConstant("id", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public Task UpsertSystemGuild(IPKConnection conn, SystemId system, ulong guild,
SystemGuildPatch patch)
{
_logger.Information("Updated {SystemId} in guild {GuildId}: {@SystemGuildPatch}", system, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("system_guild", "system, guild"))
.WithConstant("system", system)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public Task UpsertMemberGuild(IPKConnection conn, MemberId member, ulong guild,
MemberGuildPatch patch)
{
_logger.Information("Updated {MemberId} in guild {GuildId}: {@MemberGuildPatch}", member, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("member_guild", "member, guild"))
.WithConstant("member", member)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public Task<GuildConfig> GetGuild(IPKConnection conn, ulong guild) =>
conn.QueryFirstAsync<GuildConfig>("insert into servers (id) values (@guild) on conflict (id) do update set id = @guild returning *", new {guild});
public Task<SystemGuildSettings> GetSystemGuild(IPKConnection conn, ulong guild, SystemId system) =>
conn.QueryFirstAsync<SystemGuildSettings>(
"insert into system_guild (guild, system) values (@guild, @system) on conflict (guild, system) do update set guild = @guild, system = @system returning *",
new {guild, system});
public Task<MemberGuildSettings> GetMemberGuild(IPKConnection conn, ulong guild, MemberId member) =>
conn.QueryFirstAsync<MemberGuildSettings>(
"insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *",
new {guild, member});
}
}

View File

@@ -0,0 +1,47 @@
#nullable enable
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task<PKMember?> GetMember(IPKConnection conn, MemberId id) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where id = @id", new {id});
public Task<PKMember?> GetMemberByHid(IPKConnection conn, string hid) =>
conn.QuerySingleOrDefaultAsync<PKMember?>("select * from members where hid = @Hid", new { Hid = hid.ToLower() });
public Task<PKMember?> GetMemberByName(IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system });
public Task<PKMember?> GetMemberByDisplayName(IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system });
public async Task<PKMember> CreateMember(IPKConnection conn, SystemId id, string memberName)
{
var member = await conn.QueryFirstAsync<PKMember>(
"insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *",
new {SystemId = id, Name = memberName});
_logger.Information("Created {MemberId} in {SystemId}: {MemberName}",
member.Id, id, memberName);
return member;
}
public Task<PKMember> UpdateMember(IPKConnection conn, MemberId id, MemberPatch patch)
{
_logger.Information("Updated {MemberId}: {@MemberPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKMember>(query, pms);
}
public Task DeleteMember(IPKConnection conn, MemberId id)
{
_logger.Information("Deleted {MemberId}", id);
return conn.ExecuteAsync("delete from members where id = @Id", new {Id = id});
}
}
}

View File

@@ -0,0 +1,63 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public async Task AddMessage(IPKConnection conn, PKMessage msg) {
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@Mid, @Guild, @Channel, @Member, @Sender, @OriginalMid) on conflict do nothing", msg);
_logger.Debug("Stored message {@StoredMessage} in channel {Channel}", msg, msg.Channel);
}
public async Task<FullMessage> GetMessage(IPKConnection conn, ulong id)
{
FullMessage Mapper(PKMessage msg, PKMember member, PKSystem system) =>
new FullMessage {Message = msg, System = system, Member = member};
var result = await conn.QueryAsync<PKMessage, PKMember, PKSystem, FullMessage>(
"select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system",
Mapper, new {Id = id});
return result.FirstOrDefault();
}
public async Task DeleteMessage(IPKConnection conn, ulong id)
{
var rowCount = await conn.ExecuteAsync("delete from messages where mid = @Id", new {Id = id});
if (rowCount > 0)
_logger.Information("Deleted message {MessageId} from database", id);
}
public async Task DeleteMessagesBulk(IPKConnection conn, IReadOnlyCollection<ulong> ids)
{
// Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong
// Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway)
var rowCount = await conn.ExecuteAsync("delete from messages where mid = any(@Ids)",
new {Ids = ids.Select(id => (long) id).ToArray()});
if (rowCount > 0)
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount,
ids);
}
}
public class PKMessage
{
public ulong Mid { get; set; }
public ulong? Guild { get; set; } // null value means "no data" (ie. from before this field being added)
public ulong Channel { get; set; }
public MemberId Member { get; set; }
public ulong Sender { get; set; }
public ulong? OriginalMid { get; set; }
}
public class FullMessage
{
public PKMessage Message;
public PKMember Member;
public PKSystem System;
}
}

View File

@@ -0,0 +1,236 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using NodaTime;
using NpgsqlTypes;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public async Task AddSwitch(IPKConnection conn, SystemId system, IReadOnlyCollection<MemberId> members)
{
// Use a transaction here since we're doing multiple executed commands in one
await using var tx = await conn.BeginTransactionAsync();
// First, we insert the switch itself
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
new {System = system});
// Then we insert each member in the switch in the switch_members table
await using (var w = conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)"))
{
foreach (var member in members)
{
await w.StartRowAsync();
await w.WriteAsync(sw.Id.Value, NpgsqlDbType.Integer);
await w.WriteAsync(member.Value, NpgsqlDbType.Integer);
}
await w.CompleteAsync();
}
// Finally we commit the tx, since the using block will otherwise rollback it
await tx.CommitAsync();
_logger.Information("Created {SwitchId} in {SystemId}: {Members}", sw.Id, system, members);
}
public async Task MoveSwitch(IPKConnection conn, SwitchId id, Instant time)
{
await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id",
new {Time = time, Id = id});
_logger.Information("Updated {SwitchId} timestamp: {SwitchTimestamp}", id, time);
}
public async Task DeleteSwitch(IPKConnection conn, SwitchId id)
{
await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = id});
_logger.Information("Deleted {Switch}", id);
}
public async Task DeleteAllSwitches(IPKConnection conn, SystemId system)
{
await conn.ExecuteAsync("delete from switches where system = @Id", new {Id = system});
_logger.Information("Deleted all switches in {SystemId}", system);
}
public IAsyncEnumerable<PKSwitch> GetSwitches(IPKConnection conn, SystemId system)
{
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
return conn.QueryStreamAsync<PKSwitch>(
"select * from switches where system = @System order by timestamp desc",
new {System = system});
}
public async Task<int> GetSwitchCount(IPKConnection conn, SystemId system)
{
return await conn.QuerySingleAsync<int>("select count(*) from switches where system = @Id", new { Id = system });
}
public async IAsyncEnumerable<SwitchMembersListEntry> GetSwitchMembersList(IPKConnection conn,
SystemId system, Instant start, Instant end)
{
// Wrap multiple commands in a single transaction for performance
await using var tx = await conn.BeginTransactionAsync();
// Find the time of the last switch outside the range as it overlaps the range
// If no prior switch exists, the lower bound of the range remains the start time
var lastSwitch = await conn.QuerySingleOrDefaultAsync<Instant>(
@"SELECT COALESCE(MAX(timestamp), @Start)
FROM switches
WHERE switches.system = @System
AND switches.timestamp < @Start",
new {System = system, Start = start});
// Then collect the time and members of all switches that overlap the range
var switchMembersEntries = conn.QueryStreamAsync<SwitchMembersListEntry>(
@"SELECT switch_members.member, switches.timestamp
FROM switches
LEFT JOIN switch_members
ON switches.id = switch_members.switch
WHERE switches.system = @System
AND (
switches.timestamp >= @Start
OR switches.timestamp = @LastSwitch
)
AND switches.timestamp < @End
ORDER BY switches.timestamp DESC",
new {System = system, Start = start, End = end, LastSwitch = lastSwitch});
// Yield each value here
await foreach (var entry in switchMembersEntries)
yield return entry;
// Don't really need to worry about the transaction here, we're not doing any *writes*
}
public IAsyncEnumerable<PKMember> GetSwitchMembers(IPKConnection conn, SwitchId sw)
{
return conn.QueryStreamAsync<PKMember>(
"select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id",
new {Switch = sw});
}
public async Task<PKSwitch> GetLatestSwitch(IPKConnection conn, SystemId system) =>
// TODO: should query directly for perf
await GetSwitches(conn, system).FirstOrDefaultAsync();
public async Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(IPKConnection conn,
SystemId system, Instant periodStart,
Instant periodEnd)
{
// TODO: IAsyncEnumerable-ify this one
// TODO: this doesn't belong in the repo
// Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order
var switchMembers = await GetSwitchMembersList(conn, system, periodStart, periodEnd).ToListAsync();
// query DB for all members involved in any of the switches above and collect into a dictionary for future use
// this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary
// key used in GetPerMemberSwitchDuration below
var membersList = await conn.QueryAsync<PKMember>(
"select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax
new {Switches = switchMembers.Select(m => m.Member.Value).Distinct().ToList()});
var memberObjects = membersList.ToDictionary(m => m.Id);
// Initialize entries - still need to loop to determine the TimespanEnd below
var entries =
from item in switchMembers
group item by item.Timestamp
into g
select new SwitchListEntry
{
TimespanStart = g.Key,
Members = g.Where(x => x.Member != default(MemberId)).Select(x => memberObjects[x.Member])
.ToList()
};
// Loop through every switch that overlaps the range and add it to the output list
// end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared
var endTime = periodEnd;
var outList = new List<SwitchListEntry>();
foreach (var e in entries)
{
// Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above)
var switchStartClamped = e.TimespanStart < periodStart
? periodStart
: e.TimespanStart;
outList.Add(new SwitchListEntry
{
Members = e.Members, TimespanStart = switchStartClamped, TimespanEnd = endTime
});
// next switch's end is this switch's start (we're working backward in time)
endTime = e.TimespanStart;
}
return outList;
}
public async Task<FrontBreakdown> GetFrontBreakdown(IPKConnection conn, SystemId system, Instant periodStart,
Instant periodEnd)
{
// TODO: this doesn't belong in the repo
var dict = new Dictionary<PKMember, Duration>();
var noFronterDuration = Duration.Zero;
// Sum up all switch durations for each member
// switches with multiple members will result in the duration to add up to more than the actual period range
var actualStart = periodEnd; // will be "pulled" down
var actualEnd = periodStart; // will be "pulled" up
foreach (var sw in await GetPeriodFronters(conn, system, periodStart, periodEnd))
{
var span = sw.TimespanEnd - sw.TimespanStart;
foreach (var member in sw.Members)
{
if (!dict.ContainsKey(member)) dict.Add(member, span);
else dict[member] += span;
}
if (sw.Members.Count == 0) noFronterDuration += span;
if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart;
if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd;
}
return new FrontBreakdown
{
MemberSwitchDurations = dict,
NoFronterDuration = noFronterDuration,
RangeStart = actualStart,
RangeEnd = actualEnd
};
}
}
public struct SwitchListEntry
{
public ICollection<PKMember> Members;
public Instant TimespanStart;
public Instant TimespanEnd;
}
public struct FrontBreakdown
{
public Dictionary<PKMember, Duration> MemberSwitchDurations;
public Duration NoFronterDuration;
public Instant RangeStart;
public Instant RangeEnd;
}
public struct SwitchMembersListEntry
{
public MemberId Member;
public Instant Timestamp;
}
}

View File

@@ -0,0 +1,78 @@
#nullable enable
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public partial class ModelRepository
{
public Task<PKSystem?> GetSystem(IPKConnection conn, SystemId id) =>
conn.QueryFirstOrDefaultAsync<PKSystem?>("select * from systems where id = @id", new {id});
public Task<PKSystem?> GetSystemByAccount(IPKConnection conn, ulong accountId) =>
conn.QuerySingleOrDefaultAsync<PKSystem?>(
"select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id",
new {Id = accountId});
public Task<PKSystem?> GetSystemByHid(IPKConnection conn, string hid) =>
conn.QuerySingleOrDefaultAsync<PKSystem?>("select * from systems where systems.hid = @Hid",
new {Hid = hid.ToLower()});
public Task<IEnumerable<ulong>> GetSystemAccounts(IPKConnection conn, SystemId system) =>
conn.QueryAsync<ulong>("select uid from accounts where system = @Id", new {Id = system});
public IAsyncEnumerable<PKMember> GetSystemMembers(IPKConnection conn, SystemId system) =>
conn.QueryStreamAsync<PKMember>("select * from members where system = @SystemID", new {SystemID = system});
public Task<int> GetSystemMemberCount(IPKConnection conn, SystemId id, PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from members where system = @Id");
if (privacyFilter != null)
query.Append($" and member_visibility = {(int) privacyFilter.Value}");
return conn.QuerySingleAsync<int>(query.ToString(), new {Id = id});
}
public async Task<PKSystem> CreateSystem(IPKConnection conn, string? systemName = null)
{
var system = await conn.QuerySingleAsync<PKSystem>(
"insert into systems (hid, name) values (find_free_system_hid(), @Name) returning *",
new {Name = systemName});
_logger.Information("Created {SystemId}", system.Id);
return system;
}
public Task<PKSystem> UpdateSystem(IPKConnection conn, SystemId id, SystemPatch patch)
{
_logger.Information("Updated {SystemId}: {@SystemPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("systems", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKSystem>(query, pms);
}
public async Task AddAccount(IPKConnection conn, SystemId system, ulong accountId)
{
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
// This is used in import/export, although the pk;link command checks for this case beforehand
await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing",
new {Id = accountId, SystemId = system});
_logger.Information("Linked account {UserId} to {SystemId}", accountId, system);
}
public async Task RemoveAccount(IPKConnection conn, SystemId system, ulong accountId)
{
await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId",
new {Id = accountId, SystemId = system});
_logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system);
}
public Task DeleteSystem(IPKConnection conn, SystemId id)
{
_logger.Information("Deleted {SystemId}", id);
return conn.ExecuteAsync("delete from systems where id = @Id", new {Id = id});
}
}
}

View File

@@ -0,0 +1,15 @@
using Serilog;
namespace PluralKit.Core
{
public partial class ModelRepository
{
private readonly ILogger _logger;
public ModelRepository(ILogger logger)
{
_logger = logger.ForContext<ILogger>()
.ForContext("Elastic", "yes?");
}
}
}

View File

@@ -64,7 +64,11 @@ namespace PluralKit.Core
protected override async ValueTask<DbTransaction> BeginDbTransactionAsync(IsolationLevel level, CancellationToken ct) => new PKTransaction(await Inner.BeginTransactionAsync(level, ct));
public override void Open() => throw SyncError(nameof(Open));
public override void Close() => throw SyncError(nameof(Close));
public override void Close()
{
// Don't throw SyncError here, Dapper calls sync Close() internally so that sucks
Inner.Close();
}
IDbTransaction IPKConnection.BeginTransaction() => throw SyncError(nameof(BeginTransaction));
IDbTransaction IPKConnection.BeginTransaction(IsolationLevel level) => throw SyncError(nameof(BeginTransaction));

View File

@@ -1,69 +0,0 @@
#nullable enable
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public static class ModelQueryExt
{
public static Task<PKSystem?> QuerySystem(this IPKConnection conn, SystemId id) =>
conn.QueryFirstOrDefaultAsync<PKSystem?>("select * from systems where id = @id", new {id});
public static Task<int> GetSystemMemberCount(this IPKConnection conn, SystemId id, PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from members where system = @Id");
if (privacyFilter != null)
query.Append($" and member_visibility = {(int) privacyFilter.Value}");
return conn.QuerySingleAsync<int>(query.ToString(), new {Id = id});
}
public static Task<IEnumerable<ulong>> GetLinkedAccounts(this IPKConnection conn, SystemId id) =>
conn.QueryAsync<ulong>("select uid from accounts where system = @Id", new {Id = id});
public static Task<PKMember?> QueryMember(this IPKConnection conn, MemberId id) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where id = @id", new {id});
public static Task<PKMember?> QueryMemberByHid(this IPKConnection conn, string hid) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where hid = @hid", new {hid = hid.ToLowerInvariant()});
public static Task<PKGroup?> QueryGroupByName(this IPKConnection conn, SystemId system, string name) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where system = @System and lower(Name) = lower(@Name)", new {System = system, Name = name});
public static Task<PKGroup?> QueryGroupByHid(this IPKConnection conn, string hid) =>
conn.QueryFirstOrDefaultAsync<PKGroup?>("select * from groups where hid = @hid", new {hid = hid.ToLowerInvariant()});
public static Task<int> QueryGroupMemberCount(this IPKConnection conn, GroupId id,
PrivacyLevel? privacyFilter = null)
{
var query = new StringBuilder("select count(*) from group_members");
if (privacyFilter != null)
query.Append(" inner join members on group_members.member_id = members.id");
query.Append(" where group_members.group_id = @Id");
if (privacyFilter != null)
query.Append(" and members.member_visibility = @PrivacyFilter");
return conn.QuerySingleOrDefaultAsync<int>(query.ToString(), new {Id = id, PrivacyFilter = privacyFilter});
}
public static Task<IEnumerable<PKGroup>> QueryMemberGroups(this IPKConnection conn, MemberId id) =>
conn.QueryAsync<PKGroup>(
"select groups.* from group_members inner join groups on group_members.group_id = groups.id where group_members.member_id = @Id",
new {Id = id});
public static Task<GuildConfig> QueryOrInsertGuildConfig(this IPKConnection conn, ulong guild) =>
conn.QueryFirstAsync<GuildConfig>("insert into servers (id) values (@guild) on conflict (id) do update set id = @guild returning *", new {guild});
public static Task<SystemGuildSettings> QueryOrInsertSystemGuildConfig(this IPKConnection conn, ulong guild, SystemId system) =>
conn.QueryFirstAsync<SystemGuildSettings>(
"insert into system_guild (guild, system) values (@guild, @system) on conflict (guild, system) do update set guild = @guild, system = @system returning *",
new {guild, system});
public static Task<MemberGuildSettings> QueryOrInsertMemberGuildConfig(
this IPKConnection conn, ulong guild, MemberId member) =>
conn.QueryFirstAsync<MemberGuildSettings>(
"insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *",
new {guild, member});
}
}

View File

@@ -1,128 +0,0 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using Serilog;
namespace PluralKit.Core
{
public static class ModelPatchExt
{
public static Task<PKSystem> UpdateSystem(this IPKConnection conn, SystemId id, SystemPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {SystemId}: {@SystemPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("systems", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKSystem>(query, pms);
}
public static Task DeleteSystem(this IPKConnection conn, SystemId id)
{
Log.ForContext("Elastic", "yes?").Information("Deleted {SystemId}", id);
return conn.ExecuteAsync("delete from systems where id = @Id", new {Id = id});
}
public static async Task<PKMember> CreateMember(this IPKConnection conn, SystemId system, string memberName)
{
var member = await conn.QueryFirstAsync<PKMember>(
"insert into members (hid, system, name) values (find_free_member_hid(), @SystemId, @Name) returning *",
new {SystemId = system, Name = memberName});
Log.ForContext("Elastic", "yes?").Information("Created {MemberId} in {SystemId}: {MemberName}",
member.Id, system, memberName);
return member;
}
public static Task<PKMember> UpdateMember(this IPKConnection conn, MemberId id, MemberPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {MemberId}: {@MemberPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("members", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKMember>(query, pms);
}
public static Task DeleteMember(this IPKConnection conn, MemberId id)
{
Log.ForContext("Elastic", "yes?").Information("Deleted {MemberId}", id);
return conn.ExecuteAsync("delete from members where id = @Id", new {Id = id});
}
public static Task UpsertSystemGuild(this IPKConnection conn, SystemId system, ulong guild,
SystemGuildPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {SystemId} in guild {GuildId}: {@SystemGuildPatch}", system, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("system_guild", "system, guild"))
.WithConstant("system", system)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public static Task UpsertMemberGuild(this IPKConnection conn, MemberId member, ulong guild,
MemberGuildPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {MemberId} in guild {GuildId}: {@MemberGuildPatch}", member, guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("member_guild", "member, guild"))
.WithConstant("member", member)
.WithConstant("guild", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public static Task UpsertGuild(this IPKConnection conn, ulong guild, GuildPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated guild {GuildId}: {@GuildPatch}", guild, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Upsert("servers", "id"))
.WithConstant("id", guild)
.Build();
return conn.ExecuteAsync(query, pms);
}
public static async Task<PKGroup> CreateGroup(this IPKConnection conn, SystemId system, string name)
{
var group = await conn.QueryFirstAsync<PKGroup>(
"insert into groups (hid, system, name) values (find_free_group_hid(), @System, @Name) returning *",
new {System = system, Name = name});
Log.ForContext("Elastic", "yes?").Information("Created group {GroupId} in system {SystemId}: {GroupName}", group.Id, system, name);
return group;
}
public static Task<PKGroup> UpdateGroup(this IPKConnection conn, GroupId id, GroupPatch patch)
{
Log.ForContext("Elastic", "yes?").Information("Updated {GroupId}: {@GroupPatch}", id, patch);
var (query, pms) = patch.Apply(UpdateQueryBuilder.Update("groups", "id = @id"))
.WithConstant("id", id)
.Build("returning *");
return conn.QueryFirstAsync<PKGroup>(query, pms);
}
public static Task DeleteGroup(this IPKConnection conn, GroupId group)
{
Log.ForContext("Elastic", "yes?").Information("Deleted {GroupId}", group);
return conn.ExecuteAsync("delete from groups where id = @Id", new {Id = @group});
}
public static async Task AddMembersToGroup(this IPKConnection conn, GroupId group, IReadOnlyCollection<MemberId> members)
{
await using var w = conn.BeginBinaryImport("copy group_members (group_id, member_id) from stdin (format binary)");
foreach (var member in members)
{
await w.StartRowAsync();
await w.WriteAsync(group.Value);
await w.WriteAsync(member.Value);
}
await w.CompleteAsync();
Log.ForContext("Elastic", "yes?").Information("Added members to {GroupId}: {MemberIds}", group, members);
}
public static Task RemoveMembersFromGroup(this IPKConnection conn, GroupId group, IReadOnlyCollection<MemberId> members)
{
Log.ForContext("Elastic", "yes?").Information("Removed members from {GroupId}: {MemberIds}", group, members);
return conn.ExecuteAsync("delete from group_members where group_id = @Group and member_id = any(@Members)",
new {Group = @group, Members = members.ToArray()});
}
}
}

View File

@@ -1,5 +1,13 @@
namespace PluralKit.Core
{
public enum AutoproxyMode
{
Off = 1,
Front = 2,
Latch = 3,
Member = 4
}
public class SystemGuildSettings
{
public ulong Guild { get; }

View File

@@ -25,7 +25,7 @@ namespace PluralKit.Core
{
builder.RegisterType<DbConnectionCountHolder>().SingleInstance();
builder.RegisterType<Database>().As<IDatabase>().SingleInstance();
builder.RegisterType<PostgresDataStore>().AsSelf().As<IDataStore>();
builder.RegisterType<ModelRepository>().AsSelf().SingleInstance();
builder.Populate(new ServiceCollection().AddMemoryCache());
}
@@ -33,7 +33,7 @@ namespace PluralKit.Core
public class ConfigModule<T>: Module where T: new()
{
private string _submodule;
private readonly string _submodule;
public ConfigModule(string submodule = null)
{

View File

@@ -14,22 +14,24 @@ namespace PluralKit.Core
{
public class DataFileService
{
private IDataStore _data;
private IDatabase _db;
private ILogger _logger;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly ILogger _logger;
public DataFileService(ILogger logger, IDataStore data, IDatabase db)
public DataFileService(ILogger logger, IDatabase db, ModelRepository repo)
{
_data = data;
_db = db;
_repo = repo;
_logger = logger.ForContext<DataFileService>();
}
public async Task<DataFileSystem> ExportSystem(PKSystem system)
{
await using var conn = await _db.Obtain();
// Export members
var members = new List<DataFileMember>();
var pkMembers = _data.GetSystemMembers(system); // Read all members in the system
var pkMembers = _repo.GetSystemMembers(conn, system.Id); // Read all members in the system
await foreach (var member in pkMembers.Select(m => new DataFileMember
{
@@ -49,7 +51,7 @@ namespace PluralKit.Core
// Export switches
var switches = new List<DataFileSwitch>();
var switchList = await _data.GetPeriodFronters(system, Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant());
var switchList = await _repo.GetPeriodFronters(conn, system.Id, Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant());
switches.AddRange(switchList.Select(x => new DataFileSwitch
{
Timestamp = x.TimespanStart.FormatExport(),
@@ -68,7 +70,7 @@ namespace PluralKit.Core
Members = members,
Switches = switches,
Created = system.Created.FormatExport(),
LinkedAccounts = (await _data.GetSystemAccounts(system)).ToList()
LinkedAccounts = (await _repo.GetSystemAccounts(conn, system.Id)).ToList()
};
}
@@ -102,6 +104,8 @@ namespace PluralKit.Core
public async Task<ImportResult> ImportSystem(DataFileSystem data, PKSystem system, ulong accountId)
{
await using var conn = await _db.Obtain();
var result = new ImportResult {
AddedNames = new List<string>(),
ModifiedNames = new List<string>(),
@@ -112,26 +116,24 @@ namespace PluralKit.Core
// If we don't already have a system to save to, create one
if (system == null)
{
system = result.System = await _data.CreateSystem(data.Name);
await _data.AddAccount(system, accountId);
system = result.System = await _repo.CreateSystem(conn, data.Name);
await _repo.AddAccount(conn, system.Id, accountId);
}
await using var conn = await _db.Obtain();
// Apply system info
var patch = new SystemPatch {Name = data.Name};
if (data.Description != null) patch.Description = data.Description;
if (data.Tag != null) patch.Tag = data.Tag;
if (data.AvatarUrl != null) patch.AvatarUrl = data.AvatarUrl;
if (data.TimeZone != null) patch.UiTz = data.TimeZone ?? "UTC";
await conn.UpdateSystem(system.Id, patch);
await _repo.UpdateSystem(conn, system.Id, patch);
// -- Member/switch import --
await using (var imp = await BulkImporter.Begin(system, conn))
{
// Tally up the members that didn't exist before, and check member count on import
// If creating the unmatched members would put us over the member limit, abort before creating any members
var memberCountBefore = await conn.GetSystemMemberCount(system.Id);
var memberCountBefore = await _repo.GetSystemMemberCount(conn, system.Id);
var membersToAdd = data.Members.Count(m => imp.IsNewMember(m.Id, m.Name));
if (memberCountBefore + membersToAdd > Limits.MaxMemberCount)
{

View File

@@ -1,223 +0,0 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using NodaTime;
namespace PluralKit.Core {
public enum AutoproxyMode
{
Off = 1,
Front = 2,
Latch = 3,
Member = 4
}
public class FullMessage
{
public PKMessage Message;
public PKMember Member;
public PKSystem System;
}
public struct PKMessage
{
public ulong Mid;
public ulong? Guild; // null value means "no data" (ie. from before this field being added)
public ulong Channel;
public ulong Sender;
public ulong? OriginalMid;
}
public struct SwitchListEntry
{
public ICollection<PKMember> Members;
public Instant TimespanStart;
public Instant TimespanEnd;
}
public struct FrontBreakdown
{
public Dictionary<PKMember, Duration> MemberSwitchDurations;
public Duration NoFronterDuration;
public Instant RangeStart;
public Instant RangeEnd;
}
public struct SwitchMembersListEntry
{
public MemberId Member;
public Instant Timestamp;
}
public interface IDataStore
{
/// <summary>
/// Gets a system by its user-facing human ID.
/// </summary>
/// <returns>The <see cref="PKSystem"/> with the given human ID, or null if no system was found.</returns>
Task<PKSystem> GetSystemByHid(string systemHid);
/// <summary>
/// Gets a system by one of its linked Discord account IDs. Multiple IDs can return the same system.
/// </summary>
/// <returns>The <see cref="PKSystem"/> with the given linked account, or null if no system was found.</returns>
Task<PKSystem> GetSystemByAccount(ulong linkedAccount);
/// <summary>
/// Gets the Discord account IDs linked to a system.
/// </summary>
/// <returns>An enumerable of Discord account IDs linked to this system.</returns>
Task<IEnumerable<ulong>> GetSystemAccounts(PKSystem system);
/// <summary>
/// Creates a system, auto-generating its corresponding IDs.
/// </summary>
/// <param name="systemName">An optional system name to set. If `null`, will not set a system name.</param>
/// <returns>The created system model.</returns>
Task<PKSystem> CreateSystem(string systemName);
// TODO: throw exception if account is present (when adding) or account isn't present (when removing)
/// <summary>
/// Links a Discord account to a system.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if the given account is already linked to a system.</exception>
Task AddAccount(PKSystem system, ulong accountToAdd);
/// <summary>
/// Unlinks a Discord account from a system.
///
/// Will *not* throw if this results in an orphaned system - this is the caller's responsibility to ensure.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if the given account is not linked to the given system.</exception>
Task RemoveAccount(PKSystem system, ulong accountToRemove);
/// <summary>
/// Gets a member by its user-facing human ID.
/// </summary>
/// <returns>The <see cref="PKMember"/> with the given human ID, or null if no member was found.</returns>
Task<PKMember> GetMemberByHid(string memberHid);
/// <summary>
/// Gets a member by its member name within one system.
/// </summary>
/// <para>
/// Member names are *usually* unique within a system (but not always), whereas member names
/// are almost certainly *not* unique globally - therefore only intra-system lookup is
/// allowed.
/// </para>
/// <returns>The <see cref="PKMember"/> with the given name, or null if no member was found.</returns>
Task<PKMember> GetMemberByName(PKSystem system, string name);
/// <summary>
/// Gets a member by its display name within one system.
/// </summary>
/// <returns>The <see cref="PKMember"/> with the given name, or null if no member was found.</returns>
Task<PKMember> GetMemberByDisplayName(PKSystem system, string name);
/// <summary>
/// Gets all members inside a given system.
/// </summary>
/// <returns>An enumerable of <see cref="PKMember"/> structs representing each member in the system, in no particular order.</returns>
IAsyncEnumerable<PKMember> GetSystemMembers(PKSystem system, bool orderByName = false);
/// <summary>
/// Gets a message and its information by its ID.
/// </summary>
/// <param name="id">The message ID to look up. This can be either the ID of the trigger message containing the proxy tags or the resulting proxied webhook message.</param>
/// <returns>An extended message object, containing not only the message data itself but the associated system and member structs.</returns>
Task<FullMessage> GetMessage(ulong id); // id is both original and trigger, also add return type struct
/// <summary>
/// Saves a posted message to the database.
/// </summary>
/// <param name="senderAccount">The ID of the account that sent the original trigger message.</param>
/// <param name="guildId">The ID of the guild the message was posted to.</param>
/// <param name="channelId">The ID of the channel the message was posted to.</param>
/// <param name="postedMessageId">The ID of the message posted by the webhook.</param>
/// <param name="triggerMessageId">The ID of the original trigger message containing the proxy tags.</param>
/// <param name="proxiedMemberId">The member (and by extension system) that was proxied.</param>
/// <returns></returns>
Task AddMessage(IPKConnection conn, ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, MemberId proxiedMemberId);
/// <summary>
/// Deletes a message from the data store.
/// </summary>
/// <param name="postedMessageId">The ID of the webhook message to delete.</param>
Task DeleteMessage(ulong postedMessageId);
/// <summary>
/// Deletes messages from the data store in bulk.
/// </summary>
/// <param name="postedMessageIds">The IDs of the webhook messages to delete.</param>
Task DeleteMessagesBulk(IReadOnlyCollection<ulong> postedMessageIds);
/// <summary>
/// Gets switches from a system.
/// </summary>
/// <returns>An enumerable of the *count* latest switches in the system, in latest-first order. May contain fewer elements than requested.</returns>
IAsyncEnumerable<PKSwitch> GetSwitches(SystemId system);
/// <summary>
/// Gets the total amount of switches in a given system.
/// </summary>
Task<int> GetSwitchCount(PKSystem system);
/// <summary>
/// Gets the latest (temporally; closest to now) switch of a given system.
/// </summary>
Task<PKSwitch> GetLatestSwitch(SystemId system);
/// <summary>
/// Gets the members a given switch consists of.
/// </summary>
IAsyncEnumerable<PKMember> GetSwitchMembers(PKSwitch sw);
/// <summary>
/// Gets a list of fronters over a given period of time.
/// </summary>
/// <para>
/// This list is returned as an enumerable of "switch members", each containing a timestamp
/// and a member ID. <seealso cref="GetMemberById"/>
///
/// Switches containing multiple members will be returned as multiple switch members each with the same
/// timestamp, and a change in timestamp should be interpreted as the start of a new switch.
/// </para>
/// <returns>An enumerable of the aforementioned "switch members".</returns>
Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(PKSystem system, Instant periodStart, Instant periodEnd);
/// <summary>
/// Calculates a breakdown of a system's fronters over a given period, including how long each member has
/// been fronting, and how long *no* member has been fronting.
/// </summary>
/// <para>
/// Switches containing multiple members will count the full switch duration for all members, meaning
/// the total duration may add up to longer than the breakdown period.
/// </para>
/// <param name="system"></param>
/// <param name="periodStart"></param>
/// <param name="periodEnd"></param>
/// <returns></returns>
Task<FrontBreakdown> GetFrontBreakdown(PKSystem system, Instant periodStart, Instant periodEnd);
/// <summary>
/// Registers a switch with the given members in the given system.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if any of the members are not in the given system.</exception>
Task AddSwitch(SystemId system, IEnumerable<PKMember> switchMembers);
/// <summary>
/// Updates the timestamp of a given switch.
/// </summary>
Task MoveSwitch(PKSwitch sw, Instant time);
/// <summary>
/// Deletes a given switch from the data store.
/// </summary>
Task DeleteSwitch(PKSwitch sw);
/// <summary>
/// Deletes all switches in a given system from the data store.
/// </summary>
Task DeleteAllSwitches(PKSystem system);
}
}

View File

@@ -1,334 +0,0 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using NodaTime;
using Serilog;
namespace PluralKit.Core {
public class PostgresDataStore: IDataStore {
private readonly IDatabase _conn;
private readonly ILogger _logger;
public PostgresDataStore(IDatabase conn, ILogger logger)
{
_conn = conn;
_logger = logger
.ForContext<PostgresDataStore>()
.ForContext("Elastic", "yes?");
}
public async Task<PKSystem> CreateSystem(string systemName = null) {
PKSystem system;
using (var conn = await _conn.Obtain())
system = await conn.QuerySingleAsync<PKSystem>("insert into systems (hid, name) values (find_free_system_hid(), @Name) returning *", new { Name = systemName });
_logger.Information("Created {SystemId}", system.Id);
// New system has no accounts, therefore nothing gets cached, therefore no need to invalidate caches right here
return system;
}
public async Task AddAccount(PKSystem system, ulong accountId) {
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
// This is used in import/export, although the pk;link command checks for this case beforehand
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing", new { Id = accountId, SystemId = system.Id });
_logger.Information("Linked account {UserId} to {SystemId}", accountId, system.Id);
}
public async Task RemoveAccount(PKSystem system, ulong accountId) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", new { Id = accountId, SystemId = system.Id });
_logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system.Id);
}
public async Task<PKSystem> GetSystemByAccount(ulong accountId) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id", new { Id = accountId });
}
public async Task<PKSystem> GetSystemByHid(string hid) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select * from systems where systems.hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<IEnumerable<ulong>> GetSystemAccounts(PKSystem system)
{
using (var conn = await _conn.Obtain())
return await conn.QueryAsync<ulong>("select uid from accounts where system = @Id", new { Id = system.Id });
}
public async Task DeleteAllSwitches(PKSystem system)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from switches where system = @Id", system);
_logger.Information("Deleted all switches in {SystemId}", system.Id);
}
public async Task<PKMember> GetMemberByHid(string hid) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKMember>("select * from members where hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<PKMember> GetMemberByName(PKSystem system, string name) {
// QueryFirst, since members can (in rare cases) share names
using (var conn = await _conn.Obtain())
return await conn.QueryFirstOrDefaultAsync<PKMember>("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id });
}
public async Task<PKMember> GetMemberByDisplayName(PKSystem system, string name) {
// QueryFirst, since members can (in rare cases) share display names
using (var conn = await _conn.Obtain())
return await conn.QueryFirstOrDefaultAsync<PKMember>("select * from members where lower(display_name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id });
}
public IAsyncEnumerable<PKMember> GetSystemMembers(PKSystem system, bool orderByName)
{
var sql = "select * from members where system = @SystemID";
if (orderByName) sql += " order by lower(name) asc";
return _conn.QueryStreamAsync<PKMember>(sql, new { SystemID = system.Id });
}
public async Task AddMessage(IPKConnection conn, ulong senderId, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, MemberId proxiedMemberId) {
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@MessageId, @GuildId, @ChannelId, @MemberId, @SenderId, @OriginalMid) on conflict do nothing", new {
MessageId = postedMessageId,
GuildId = guildId,
ChannelId = channelId,
MemberId = proxiedMemberId,
SenderId = senderId,
OriginalMid = triggerMessageId
});
// todo: _logger.Debug("Stored message {Message} in channel {Channel}", postedMessageId, channelId);
}
public async Task<FullMessage> GetMessage(ulong id)
{
using (var conn = await _conn.Obtain())
return (await conn.QueryAsync<PKMessage, PKMember, PKSystem, FullMessage>("select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system", (msg, member, system) => new FullMessage
{
Message = msg,
System = system,
Member = member
}, new { Id = id })).FirstOrDefault();
}
public async Task DeleteMessage(ulong id) {
using (var conn = await _conn.Obtain())
if (await conn.ExecuteAsync("delete from messages where mid = @Id", new { Id = id }) > 0)
_logger.Information("Deleted message {MessageId} from database", id);
}
public async Task DeleteMessagesBulk(IReadOnlyCollection<ulong> ids)
{
using (var conn = await _conn.Obtain())
{
// Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong
// Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway)
var foundCount = await conn.ExecuteAsync("delete from messages where mid = any(@Ids)", new {Ids = ids.Select(id => (long) id).ToArray()});
if (foundCount > 0)
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", foundCount, ids);
}
}
public async Task AddSwitch(SystemId system, IEnumerable<PKMember> members)
{
// Use a transaction here since we're doing multiple executed commands in one
await using var conn = await _conn.Obtain();
await using var tx = await conn.BeginTransactionAsync();
// First, we insert the switch itself
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
new {System = system});
// Then we insert each member in the switch in the switch_members table
// TODO: can we parallelize this or send it in bulk somehow?
foreach (var member in members)
{
await conn.ExecuteAsync(
"insert into switch_members(switch, member) values(@Switch, @Member)",
new {Switch = sw.Id, Member = member.Id});
}
// Finally we commit the tx, since the using block will otherwise rollback it
await tx.CommitAsync();
_logger.Information("Created {SwitchId} in {SystemId}: {Members}", sw.Id, system, members.Select(m => m.Id));
}
public IAsyncEnumerable<PKSwitch> GetSwitches(SystemId system)
{
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
// (maybe when we get caching in?)
return _conn.QueryStreamAsync<PKSwitch>(
"select * from switches where system = @System order by timestamp desc",
new {System = system});
}
public async Task<int> GetSwitchCount(PKSystem system)
{
using var conn = await _conn.Obtain();
return await conn.QuerySingleAsync<int>("select count(*) from switches where system = @Id", system);
}
public async IAsyncEnumerable<SwitchMembersListEntry> GetSwitchMembersList(PKSystem system, Instant start, Instant end)
{
// Wrap multiple commands in a single transaction for performance
await using var conn = await _conn.Obtain();
await using var tx = await conn.BeginTransactionAsync();
// Find the time of the last switch outside the range as it overlaps the range
// If no prior switch exists, the lower bound of the range remains the start time
var lastSwitch = await conn.QuerySingleOrDefaultAsync<Instant>(
@"SELECT COALESCE(MAX(timestamp), @Start)
FROM switches
WHERE switches.system = @System
AND switches.timestamp < @Start",
new { System = system.Id, Start = start });
// Then collect the time and members of all switches that overlap the range
var switchMembersEntries = conn.QueryStreamAsync<SwitchMembersListEntry>(
@"SELECT switch_members.member, switches.timestamp
FROM switches
LEFT JOIN switch_members
ON switches.id = switch_members.switch
WHERE switches.system = @System
AND (
switches.timestamp >= @Start
OR switches.timestamp = @LastSwitch
)
AND switches.timestamp < @End
ORDER BY switches.timestamp DESC",
new { System = system.Id, Start = start, End = end, LastSwitch = lastSwitch });
// Yield each value here
await foreach (var entry in switchMembersEntries)
yield return entry;
// Don't really need to worry about the transaction here, we're not doing any *writes*
}
public IAsyncEnumerable<PKMember> GetSwitchMembers(PKSwitch sw)
{
return _conn.QueryStreamAsync<PKMember>(
"select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id",
new {Switch = sw.Id});
}
public async Task<PKSwitch> GetLatestSwitch(SystemId system) =>
await GetSwitches(system).FirstOrDefaultAsync();
public async Task MoveSwitch(PKSwitch sw, Instant time)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id",
new {Time = time, Id = sw.Id});
_logger.Information("Updated {SwitchId} timestamp: {SwitchTimestamp}", sw.Id, time);
}
public async Task DeleteSwitch(PKSwitch sw)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = sw.Id});
_logger.Information("Deleted {Switch}", sw.Id);
}
public async Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(PKSystem system, Instant periodStart, Instant periodEnd)
{
// TODO: IAsyncEnumerable-ify this one
// Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order
var switchMembers = await GetSwitchMembersList(system, periodStart, periodEnd).ToListAsync();
// query DB for all members involved in any of the switches above and collect into a dictionary for future use
// this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary
// key used in GetPerMemberSwitchDuration below
Dictionary<MemberId, PKMember> memberObjects;
using (var conn = await _conn.Obtain())
{
memberObjects = (
await conn.QueryAsync<PKMember>(
"select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax
new { Switches = switchMembers.Select(m => m.Member.Value).Distinct().ToList() })
).ToDictionary(m => m.Id);
}
// Initialize entries - still need to loop to determine the TimespanEnd below
var entries =
from item in switchMembers
group item by item.Timestamp into g
select new SwitchListEntry
{
TimespanStart = g.Key,
Members = g.Where(x => x.Member != default(MemberId)).Select(x => memberObjects[x.Member]).ToList()
};
// Loop through every switch that overlaps the range and add it to the output list
// end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared
var endTime = periodEnd;
var outList = new List<SwitchListEntry>();
foreach (var e in entries)
{
// Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above)
var switchStartClamped = e.TimespanStart < periodStart
? periodStart
: e.TimespanStart;
outList.Add(new SwitchListEntry
{
Members = e.Members,
TimespanStart = switchStartClamped,
TimespanEnd = endTime
});
// next switch's end is this switch's start (we're working backward in time)
endTime = e.TimespanStart;
}
return outList;
}
public async Task<FrontBreakdown> GetFrontBreakdown(PKSystem system, Instant periodStart, Instant periodEnd)
{
var dict = new Dictionary<PKMember, Duration>();
var noFronterDuration = Duration.Zero;
// Sum up all switch durations for each member
// switches with multiple members will result in the duration to add up to more than the actual period range
var actualStart = periodEnd; // will be "pulled" down
var actualEnd = periodStart; // will be "pulled" up
foreach (var sw in await GetPeriodFronters(system, periodStart, periodEnd))
{
var span = sw.TimespanEnd - sw.TimespanStart;
foreach (var member in sw.Members)
{
if (!dict.ContainsKey(member)) dict.Add(member, span);
else dict[member] += span;
}
if (sw.Members.Count == 0) noFronterDuration += span;
if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart;
if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd;
}
return new FrontBreakdown
{
MemberSwitchDurations = dict,
NoFronterDuration = noFronterDuration,
RangeStart = actualStart,
RangeEnd = actualEnd
};
}
}
}

View File

@@ -6,20 +6,11 @@ using Dapper;
namespace PluralKit.Core {
public static class ConnectionUtils
{
public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this IDatabase connFactory, string sql, object param)
{
await using var conn = await connFactory.Obtain();
await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param);
var parser = reader.GetRowParser<T>();
while (await reader.ReadAsync())
yield return parser(reader);
}
public static async IAsyncEnumerable<T> QueryStreamAsync<T>(this IPKConnection conn, string sql, object param)
{
await using var reader = (DbDataReader) await conn.ExecuteReaderAsync(sql, param);
var parser = reader.GetRowParser<T>();
while (await reader.ReadAsync())
yield return parser(reader);
}

View File

@@ -8,9 +8,9 @@ namespace PluralKit.Core
{
private readonly string? _conflictField;
private readonly string? _condition;
private StringBuilder _insertFragment = new StringBuilder();
private StringBuilder _valuesFragment = new StringBuilder();
private StringBuilder _updateFragment = new StringBuilder();
private readonly StringBuilder _insertFragment = new StringBuilder();
private readonly StringBuilder _valuesFragment = new StringBuilder();
private readonly StringBuilder _updateFragment = new StringBuilder();
private bool _firstInsert = true;
private bool _firstUpdate = true;
public QueryType Type { get; }