PluralKit/PluralKit.Core/Stores.cs

546 lines
24 KiB
C#
Raw Normal View History

using System;
2019-08-14 05:16:48 +00:00
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using App.Metrics.Logging;
using Dapper;
using NodaTime;
using PluralKit.Core;
using Serilog;
namespace PluralKit {
public class SystemStore {
private DbConnectionFactory _conn;
private ILogger _logger;
public SystemStore(DbConnectionFactory conn, ILogger logger)
{
this._conn = conn;
_logger = logger.ForContext<SystemStore>();
}
public async Task<PKSystem> Create(string systemName = null) {
string hid;
do
{
hid = Utils.GenerateHid();
} while (await GetByHid(hid) != null);
PKSystem system;
using (var conn = await _conn.Obtain())
system = await conn.QuerySingleAsync<PKSystem>("insert into systems (hid, name) values (@Hid, @Name) returning *", new { Hid = hid, Name = systemName });
_logger.Information("Created system {System}", system.Id);
return system;
}
public async Task Link(PKSystem system, ulong accountId) {
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
// This is used in import/export, although the pk;link command checks for this case beforehand
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing", new { Id = accountId, SystemId = system.Id });
_logger.Information("Linked system {System} to account {Account}", system.Id, accountId);
}
public async Task Unlink(PKSystem system, ulong accountId) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", new { Id = accountId, SystemId = system.Id });
_logger.Information("Unlinked system {System} from account {Account}", system.Id, accountId);
}
public async Task<PKSystem> GetByAccount(ulong accountId) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id", new { Id = accountId });
}
public async Task<PKSystem> GetByHid(string hid) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select * from systems where systems.hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<PKSystem> GetByToken(string token) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select * from systems where token = @Token", new { Token = token });
}
public async Task<PKSystem> GetById(int id)
{
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select * from systems where id = @Id", new { Id = id });
}
public async Task Save(PKSystem system) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("update systems set name = @Name, description = @Description, tag = @Tag, avatar_url = @AvatarUrl, token = @Token, ui_tz = @UiTz where id = @Id", system);
_logger.Information("Updated system {@System}", system);
}
public async Task Delete(PKSystem system) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from systems where id = @Id", system);
_logger.Information("Deleted system {System}", system.Id);
}
public async Task<IEnumerable<ulong>> GetLinkedAccountIds(PKSystem system)
{
using (var conn = await _conn.Obtain())
return await conn.QueryAsync<ulong>("select uid from accounts where system = @Id", new { Id = system.Id });
}
public async Task<ulong> Count()
{
using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<ulong>("select count(id) from systems");
}
}
public class MemberStore {
private DbConnectionFactory _conn;
private ILogger _logger;
public MemberStore(DbConnectionFactory conn, ILogger logger)
{
this._conn = conn;
_logger = logger.ForContext<MemberStore>();
}
public async Task<PKMember> Create(PKSystem system, string name) {
string hid;
do
{
hid = Utils.GenerateHid();
} while (await GetByHid(hid) != null);
PKMember member;
using (var conn = await _conn.Obtain())
member = await conn.QuerySingleAsync<PKMember>("insert into members (hid, system, name) values (@Hid, @SystemId, @Name) returning *", new {
Hid = hid,
SystemID = system.Id,
Name = name
});
_logger.Information("Created member {Member}", member.Id);
return member;
}
public async Task<PKMember> GetByHid(string hid) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKMember>("select * from members where hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<PKMember> GetByName(PKSystem system, string name) {
// QueryFirst, since members can (in rare cases) share names
using (var conn = await _conn.Obtain())
return await conn.QueryFirstOrDefaultAsync<PKMember>("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id });
}
public async Task<ICollection<PKMember>> GetUnproxyableMembers(PKSystem system) {
return (await GetBySystem(system))
.Where((m) => {
var proxiedName = $"{m.Name} {system.Tag}";
return proxiedName.Length > Limits.MaxProxyNameLength || proxiedName.Length < 2;
}).ToList();
}
public async Task<IEnumerable<PKMember>> GetBySystem(PKSystem system) {
using (var conn = await _conn.Obtain())
return await conn.QueryAsync<PKMember>("select * from members where system = @SystemID", new { SystemID = system.Id });
}
public async Task Save(PKMember member) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("update members set name = @Name, display_name = @DisplayName, description = @Description, color = @Color, avatar_url = @AvatarUrl, birthday = @Birthday, pronouns = @Pronouns, prefix = @Prefix, suffix = @Suffix where id = @Id", member);
_logger.Information("Updated member {@Member}", member);
}
public async Task Delete(PKMember member) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from members where id = @Id", member);
_logger.Information("Deleted member {@Member}", member);
}
public async Task<int> MessageCount(PKMember member)
{
using (var conn = await _conn.Obtain())
return await conn.QuerySingleAsync<int>("select count(*) from messages where member = @Id", member);
}
public struct MessageBreakdownListEntry
{
public int Member;
public int MessageCount;
}
public async Task<IEnumerable<MessageBreakdownListEntry>> MessageCountsPerMember(PKSystem system)
{
using (var conn = await _conn.Obtain())
return await conn.QueryAsync<MessageBreakdownListEntry>(
@"SELECT messages.member, COUNT(messages.member) messagecount
FROM members
JOIN messages
ON members.id = messages.member
WHERE members.system = @System
GROUP BY messages.member",
new { System = system.Id });
}
2019-08-14 05:16:48 +00:00
public async Task<int> MemberCount(PKSystem system)
{
using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<int>("select count(*) from members where system = @Id", system);
}
public async Task<ulong> Count()
{
using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<ulong>("select count(id) from members");
}
}
public class MessageStore {
public struct PKMessage
{
public ulong Mid;
public ulong Channel;
public ulong Sender;
public ulong? OriginalMid;
}
public class StoredMessage
{
public PKMessage Message;
public PKMember Member;
public PKSystem System;
}
private DbConnectionFactory _conn;
private ILogger _logger;
public MessageStore(DbConnectionFactory conn, ILogger logger)
{
this._conn = conn;
_logger = logger.ForContext<MessageStore>();
}
public async Task Store(ulong senderId, ulong messageId, ulong channelId, ulong originalMessage, PKMember member) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("insert into messages(mid, channel, member, sender, original_mid) values(@MessageId, @ChannelId, @MemberId, @SenderId, @OriginalMid)", new {
MessageId = messageId,
ChannelId = channelId,
MemberId = member.Id,
SenderId = senderId,
OriginalMid = originalMessage
});
_logger.Information("Stored message {Message} in channel {Channel}", messageId, channelId);
}
public async Task<StoredMessage> Get(ulong id)
{
using (var conn = await _conn.Obtain())
return (await conn.QueryAsync<PKMessage, PKMember, PKSystem, StoredMessage>("select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system", (msg, member, system) => new StoredMessage
{
Message = msg,
System = system,
Member = member
}, new { Id = id })).FirstOrDefault();
}
public async Task Delete(ulong id) {
using (var conn = await _conn.Obtain())
if (await conn.ExecuteAsync("delete from messages where mid = @Id", new { Id = id }) > 0)
_logger.Information("Deleted message {Message}", id);
}
public async Task BulkDelete(IReadOnlyCollection<ulong> ids)
{
using (var conn = await _conn.Obtain())
{
// Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong
// Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway)
var foundCount = await conn.ExecuteAsync("delete from messages where mid = any(@Ids)", new {Ids = ids.Select(id => (long) id).ToArray()});
if (foundCount > 0)
_logger.Information("Bulk deleted messages {Messages}, {FoundCount} found", ids, foundCount);
}
}
public async Task<ulong> Count()
{
using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<ulong>("select count(mid) from messages");
}
}
public class SwitchStore
{
private DbConnectionFactory _conn;
private ILogger _logger;
public SwitchStore(DbConnectionFactory conn, ILogger logger)
{
_conn = conn;
_logger = logger.ForContext<SwitchStore>();
}
public async Task RegisterSwitch(PKSystem system, ICollection<PKMember> members)
{
// Use a transaction here since we're doing multiple executed commands in one
using (var conn = await _conn.Obtain())
using (var tx = conn.BeginTransaction())
{
// First, we insert the switch itself
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
new {System = system.Id});
// Then we insert each member in the switch in the switch_members table
// TODO: can we parallelize this or send it in bulk somehow?
foreach (var member in members)
{
await conn.ExecuteAsync(
"insert into switch_members(switch, member) values(@Switch, @Member)",
new {Switch = sw.Id, Member = member.Id});
}
// Finally we commit the tx, since the using block will otherwise rollback it
tx.Commit();
_logger.Information("Registered switch {Switch} in system {System} with members {@Members}", sw.Id, system.Id, members.Select(m => m.Id));
}
}
public async Task RegisterSwitches(PKSystem system, ICollection<Tuple<Instant, ICollection<PKMember>>> switches)
{
// Use a transaction here since we're doing multiple executed commands in one
using (var conn = await _conn.Obtain())
using (var tx = conn.BeginTransaction())
{
foreach (var s in switches)
{
// First, we insert the switch itself
var sw = await conn.QueryFirstOrDefaultAsync<PKSwitch>(
@"insert into switches(system, timestamp)
select @System, @Timestamp
where not exists (
select * from switches
where system = @System and timestamp::timestamp(0) = @Timestamp
limit 1
)
returning *",
new { System = system.Id, Timestamp = s.Item1 });
// If we inserted a switch, also insert each member in the switch in the switch_members table
if (sw != null && s.Item2.Any())
await conn.ExecuteAsync(
"insert into switch_members(switch, member) select @Switch, * FROM unnest(@Members)",
new { Switch = sw.Id, Members = s.Item2.Select(x => x.Id).ToArray() });
}
// Finally we commit the tx, since the using block will otherwise rollback it
tx.Commit();
_logger.Information("Registered {SwitchCount} switches in system {System}", switches.Count, system.Id);
}
}
2019-08-14 05:16:48 +00:00
public async Task<IEnumerable<PKSwitch>> GetSwitches(PKSystem system, int count = 9999999)
{
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
// (maybe when we get caching in?)
using (var conn = await _conn.Obtain())
return await conn.QueryAsync<PKSwitch>("select * from switches where system = @System order by timestamp desc limit @Count", new {System = system.Id, Count = count});
}
public struct SwitchMembersListEntry
{
public int Member;
public Instant Timestamp;
}
public async Task<IEnumerable<SwitchMembersListEntry>> GetSwitchMembersList(PKSystem system, Instant start, Instant end)
{
// Wrap multiple commands in a single transaction for performance
using (var conn = await _conn.Obtain())
using (var tx = conn.BeginTransaction())
{
// Find the time of the last switch outside the range as it overlaps the range
// If no prior switch exists, the lower bound of the range remains the start time
var lastSwitch = await conn.QuerySingleOrDefaultAsync<Instant>(
@"SELECT COALESCE(MAX(timestamp), @Start)
FROM switches
WHERE switches.system = @System
AND switches.timestamp < @Start",
new { System = system.Id, Start = start });
// Then collect the time and members of all switches that overlap the range
var switchMembersEntries = await conn.QueryAsync<SwitchMembersListEntry>(
@"SELECT switch_members.member, switches.timestamp
FROM switches
LEFT JOIN switch_members
ON switches.id = switch_members.switch
WHERE switches.system = @System
AND (
switches.timestamp >= @Start
OR switches.timestamp = @LastSwitch
)
AND switches.timestamp < @End
ORDER BY switches.timestamp DESC",
new { System = system.Id, Start = start, End = end, LastSwitch = lastSwitch });
// Commit and return the list
tx.Commit();
return switchMembersEntries;
}
}
2019-08-14 05:16:48 +00:00
public async Task<IEnumerable<int>> GetSwitchMemberIds(PKSwitch sw)
{
using (var conn = await _conn.Obtain())
return await conn.QueryAsync<int>("select member from switch_members where switch = @Switch order by switch_members.id",
new {Switch = sw.Id});
}
public async Task<IEnumerable<PKMember>> GetSwitchMembers(PKSwitch sw)
{
using (var conn = await _conn.Obtain())
return await conn.QueryAsync<PKMember>(
"select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id",
new {Switch = sw.Id});
}
public async Task<PKSwitch> GetLatestSwitch(PKSystem system) => (await GetSwitches(system, 1)).FirstOrDefault();
public async Task MoveSwitch(PKSwitch sw, Instant time)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id",
new {Time = time, Id = sw.Id});
_logger.Information("Moved switch {Switch} to {Time}", sw.Id, time);
}
public async Task DeleteSwitch(PKSwitch sw)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = sw.Id});
_logger.Information("Deleted switch {Switch}");
}
public async Task<ulong> Count()
{
using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<ulong>("select count(id) from switches");
}
public struct SwitchListEntry
{
public ICollection<PKMember> Members;
public Instant TimespanStart;
public Instant TimespanEnd;
}
public async Task<IEnumerable<SwitchListEntry>> GetTruncatedSwitchList(PKSystem system, Instant periodStart, Instant periodEnd)
{
// Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order
var switchMembers = await GetSwitchMembersList(system, periodStart, periodEnd);
2019-08-14 05:16:48 +00:00
// query DB for all members involved in any of the switches above and collect into a dictionary for future use
// this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary
// key used in GetPerMemberSwitchDuration below
Dictionary<int, PKMember> memberObjects;
using (var conn = await _conn.Obtain())
{
memberObjects = (
await conn.QueryAsync<PKMember>(
"select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax
new { Switches = switchMembers.Select(m => m.Member).Distinct().ToList() })
).ToDictionary(m => m.Id);
2019-08-14 05:16:48 +00:00
}
// Initialize entries - still need to loop to determine the TimespanEnd below
var entries =
from item in switchMembers
group item by item.Timestamp into g
select new SwitchListEntry
{
TimespanStart = g.Key,
Members = g.Where(x => x.Member != 0).Select(x => memberObjects[x.Member]).ToList()
};
2019-08-14 05:16:48 +00:00
// Loop through every switch that overlaps the range and add it to the output list
// end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared
2019-08-14 05:16:48 +00:00
var endTime = periodEnd;
var outList = new List<SwitchListEntry>();
foreach (var e in entries)
2019-08-14 05:16:48 +00:00
{
// Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above)
var switchStartClamped = e.TimespanStart < periodStart
? periodStart
: e.TimespanStart;
2019-08-14 05:16:48 +00:00
outList.Add(new SwitchListEntry
{
Members = e.Members,
2019-08-14 05:16:48 +00:00
TimespanStart = switchStartClamped,
TimespanEnd = endTime
});
// next switch's end is this switch's start (we're working backward in time)
endTime = e.TimespanStart;
2019-08-14 05:16:48 +00:00
}
return outList;
}
public struct PerMemberSwitchDuration
{
public Dictionary<PKMember, Duration> MemberSwitchDurations;
public Duration NoFronterDuration;
public Instant RangeStart;
public Instant RangeEnd;
}
public async Task<PerMemberSwitchDuration> GetPerMemberSwitchDuration(PKSystem system, Instant periodStart,
Instant periodEnd)
{
var dict = new Dictionary<PKMember, Duration>();
var noFronterDuration = Duration.Zero;
// Sum up all switch durations for each member
// switches with multiple members will result in the duration to add up to more than the actual period range
var actualStart = periodEnd; // will be "pulled" down
var actualEnd = periodStart; // will be "pulled" up
foreach (var sw in await GetTruncatedSwitchList(system, periodStart, periodEnd))
{
var span = sw.TimespanEnd - sw.TimespanStart;
foreach (var member in sw.Members)
{
if (!dict.ContainsKey(member)) dict.Add(member, span);
else dict[member] += span;
}
if (sw.Members.Count == 0) noFronterDuration += span;
if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart;
if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd;
}
return new PerMemberSwitchDuration
{
MemberSwitchDurations = dict,
NoFronterDuration = noFronterDuration,
RangeStart = actualStart,
RangeEnd = actualEnd
};
}
}
2019-04-19 18:48:37 +00:00
}