Large refactor and project restructuring

This commit is contained in:
Ske
2020-02-12 15:16:19 +01:00
parent c10e197c39
commit 6d5004bf54
71 changed files with 1664 additions and 1607 deletions

View File

@@ -0,0 +1,351 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Newtonsoft.Json;
using NodaTime;
using NodaTime.Text;
using Serilog;
namespace PluralKit.Core
{
public class DataFileService
{
private IDataStore _data;
private ILogger _logger;
public DataFileService(ILogger logger, IDataStore data)
{
_data = data;
_logger = logger.ForContext<DataFileService>();
}
public async Task<DataFileSystem> ExportSystem(PKSystem system)
{
// Export members
var members = new List<DataFileMember>();
var pkMembers = _data.GetSystemMembers(system); // Read all members in the system
var messageCounts = await _data.GetMemberMessageCountBulk(system); // Count messages proxied by all members in the system
await foreach (var member in pkMembers.Select(m => new DataFileMember
{
Id = m.Hid,
Name = m.Name,
DisplayName = m.DisplayName,
Description = m.Description,
Birthday = m.Birthday != null ? DateTimeFormats.DateExportFormat.Format(m.Birthday.Value) : null,
Pronouns = m.Pronouns,
Color = m.Color,
AvatarUrl = m.AvatarUrl,
ProxyTags = m.ProxyTags,
KeepProxy = m.KeepProxy,
Created = DateTimeFormats.TimestampExportFormat.Format(m.Created),
MessageCount = messageCounts.Where(x => x.Member == m.Id).Select(x => x.MessageCount).FirstOrDefault()
})) members.Add(member);
// Export switches
var switches = new List<DataFileSwitch>();
var switchList = await _data.GetPeriodFronters(system, Instant.FromDateTimeUtc(DateTime.MinValue.ToUniversalTime()), SystemClock.Instance.GetCurrentInstant());
switches.AddRange(switchList.Select(x => new DataFileSwitch
{
Timestamp = DateTimeFormats.TimestampExportFormat.Format(x.TimespanStart),
Members = x.Members.Select(m => m.Hid).ToList() // Look up member's HID using the member export from above
}));
return new DataFileSystem
{
Id = system.Hid,
Name = system.Name,
Description = system.Description,
Tag = system.Tag,
AvatarUrl = system.AvatarUrl,
TimeZone = system.UiTz,
Members = members,
Switches = switches,
Created = DateTimeFormats.TimestampExportFormat.Format(system.Created),
LinkedAccounts = (await _data.GetSystemAccounts(system)).ToList()
};
}
public async Task<ImportResult> ImportSystem(DataFileSystem data, PKSystem system, ulong accountId)
{
// TODO: make atomic, somehow - we'd need to obtain one IDbConnection and reuse it
// which probably means refactoring SystemStore.Save and friends etc
var result = new ImportResult {
AddedNames = new List<string>(),
ModifiedNames = new List<string>(),
Success = true // Assume success unless indicated otherwise
};
var dataFileToMemberMapping = new Dictionary<string, PKMember>();
var unmappedMembers = new List<DataFileMember>();
// If we don't already have a system to save to, create one
if (system == null)
system = await _data.CreateSystem(data.Name);
result.System = system;
// Apply system info
system.Name = data.Name;
if (data.Description != null) system.Description = data.Description;
if (data.Tag != null) system.Tag = data.Tag;
if (data.AvatarUrl != null) system.AvatarUrl = data.AvatarUrl;
if (data.TimeZone != null) system.UiTz = data.TimeZone ?? "UTC";
await _data.SaveSystem(system);
// Make sure to link the sender account, too
await _data.AddAccount(system, accountId);
// Determine which members already exist and which ones need to be created
var membersByHid = new Dictionary<string, PKMember>();
var membersByName = new Dictionary<string, PKMember>();
await foreach (var member in _data.GetSystemMembers(system))
{
membersByHid[member.Hid] = member;
membersByName[member.Name] = member;
}
foreach (var d in data.Members)
{
PKMember match = null;
if (membersByHid.TryGetValue(d.Id, out var matchByHid)) match = matchByHid; // Try to look up the member with the given ID
else if (membersByName.TryGetValue(d.Name, out var matchByName)) match = matchByName; // Try with the name instead
if (match != null)
{
dataFileToMemberMapping.Add(d.Id, match); // Relate the data file ID to the PKMember for importing switches
result.ModifiedNames.Add(d.Name);
}
else
{
unmappedMembers.Add(d); // Track members that weren't found so we can create them all
result.AddedNames.Add(d.Name);
}
}
// If creating the unmatched members would put us over the member limit, abort before creating any members
// new total: # in the system + (# in the file - # in the file that already exist)
if (data.Members.Count - dataFileToMemberMapping.Count + membersByHid.Count > Limits.MaxMemberCount)
{
result.Success = false;
result.Message = $"Import would exceed the maximum number of members ({Limits.MaxMemberCount}).";
result.AddedNames.Clear();
result.ModifiedNames.Clear();
return result;
}
// Create all unmapped members in one transaction
// These consist of members from another PluralKit system or another framework (e.g. Tupperbox)
var membersToCreate = new Dictionary<string, string>();
unmappedMembers.ForEach(x => membersToCreate.Add(x.Id, x.Name));
var newMembers = await _data.CreateMembersBulk(system, membersToCreate);
foreach (var member in newMembers)
dataFileToMemberMapping.Add(member.Key, member.Value);
// Update members with data file properties
// TODO: parallelize?
foreach (var dataMember in data.Members)
{
dataFileToMemberMapping.TryGetValue(dataMember.Id, out PKMember member);
if (member == null)
continue;
// Apply member info
member.Name = dataMember.Name;
if (dataMember.DisplayName != null) member.DisplayName = dataMember.DisplayName;
if (dataMember.Description != null) member.Description = dataMember.Description;
if (dataMember.Color != null) member.Color = dataMember.Color;
if (dataMember.AvatarUrl != null) member.AvatarUrl = dataMember.AvatarUrl;
if (dataMember.Prefix != null || dataMember.Suffix != null)
{
member.ProxyTags = new List<ProxyTag> { new ProxyTag(dataMember.Prefix, dataMember.Suffix) };
}
else
{
// Ignore proxy tags where both prefix and suffix are set to null (would be invalid anyway)
member.ProxyTags = (dataMember.ProxyTags ?? new ProxyTag[] { }).Where(tag => !tag.IsEmpty).ToList();
}
member.KeepProxy = dataMember.KeepProxy;
if (dataMember.Birthday != null)
{
var birthdayParse = DateTimeFormats.DateExportFormat.Parse(dataMember.Birthday);
member.Birthday = birthdayParse.Success ? (LocalDate?)birthdayParse.Value : null;
}
await _data.SaveMember(member);
}
// Re-map the switch members in the likely case IDs have changed
var mappedSwitches = new List<ImportedSwitch>();
foreach (var sw in data.Switches)
{
var timestamp = InstantPattern.ExtendedIso.Parse(sw.Timestamp).Value;
var swMembers = new List<PKMember>();
swMembers.AddRange(sw.Members.Select(x =>
dataFileToMemberMapping.FirstOrDefault(y => y.Key.Equals(x)).Value));
mappedSwitches.Add(new ImportedSwitch
{
Timestamp = timestamp,
Members = swMembers
});
}
// Import switches
if (mappedSwitches.Any())
await _data.AddSwitchesBulk(system, mappedSwitches);
_logger.Information("Imported system {System}", system.Hid);
return result;
}
}
public struct ImportResult
{
public ICollection<string> AddedNames;
public ICollection<string> ModifiedNames;
public PKSystem System;
public bool Success;
public string Message;
}
public struct DataFileSystem
{
[JsonProperty("id")] public string Id;
[JsonProperty("name")] public string Name;
[JsonProperty("description")] public string Description;
[JsonProperty("tag")] public string Tag;
[JsonProperty("avatar_url")] public string AvatarUrl;
[JsonProperty("timezone")] public string TimeZone;
[JsonProperty("members")] public ICollection<DataFileMember> Members;
[JsonProperty("switches")] public ICollection<DataFileSwitch> Switches;
[JsonProperty("accounts")] public ICollection<ulong> LinkedAccounts;
[JsonProperty("created")] public string Created;
private bool TimeZoneValid => TimeZone == null || DateTimeZoneProviders.Tzdb.GetZoneOrNull(TimeZone) != null;
[JsonIgnore] public bool Valid => TimeZoneValid && Members != null && Members.All(m => m.Valid);
}
public struct DataFileMember
{
[JsonProperty("id")] public string Id;
[JsonProperty("name")] public string Name;
[JsonProperty("display_name")] public string DisplayName;
[JsonProperty("description")] public string Description;
[JsonProperty("birthday")] public string Birthday;
[JsonProperty("pronouns")] public string Pronouns;
[JsonProperty("color")] public string Color;
[JsonProperty("avatar_url")] public string AvatarUrl;
// For legacy single-tag imports
[JsonProperty("prefix")] [JsonIgnore] public string Prefix;
[JsonProperty("suffix")] [JsonIgnore] public string Suffix;
// ^ is superseded by v
[JsonProperty("proxy_tags")] public ICollection<ProxyTag> ProxyTags;
[JsonProperty("keep_proxy")] public bool KeepProxy;
[JsonProperty("message_count")] public int MessageCount;
[JsonProperty("created")] public string Created;
[JsonIgnore] public bool Valid => Name != null;
}
public struct DataFileSwitch
{
[JsonProperty("timestamp")] public string Timestamp;
[JsonProperty("members")] public ICollection<string> Members;
}
public struct TupperboxConversionResult
{
public bool HadGroups;
public bool HadIndividualTags;
public DataFileSystem System;
}
public struct TupperboxProfile
{
[JsonProperty("tuppers")] public ICollection<TupperboxTupper> Tuppers;
[JsonProperty("groups")] public ICollection<TupperboxGroup> Groups;
[JsonIgnore] public bool Valid => Tuppers != null && Groups != null && Tuppers.All(t => t.Valid) && Groups.All(g => g.Valid);
public TupperboxConversionResult ToPluralKit()
{
// Set by member conversion function
string lastSetTag = null;
TupperboxConversionResult output = default(TupperboxConversionResult);
var members = Tuppers.Select(t => t.ToPluralKit(ref lastSetTag, ref output.HadIndividualTags,
ref output.HadGroups)).ToList();
// Nowadays we set each member's display name to their name + tag, so we don't set a global system tag
output.System = new DataFileSystem
{
Members = members,
Switches = new List<DataFileSwitch>()
};
return output;
}
}
public struct TupperboxTupper
{
[JsonProperty("name")] public string Name;
[JsonProperty("avatar_url")] public string AvatarUrl;
[JsonProperty("brackets")] public IList<string> Brackets;
[JsonProperty("posts")] public int Posts; // Not supported by PK
[JsonProperty("show_brackets")] public bool ShowBrackets;
[JsonProperty("birthday")] public string Birthday;
[JsonProperty("description")] public string Description;
[JsonProperty("tag")] public string Tag;
[JsonProperty("group_id")] public string GroupId; // Not supported by PK
[JsonProperty("group_pos")] public int? GroupPos; // Not supported by PK
[JsonIgnore] public bool Valid => Name != null && Brackets != null && Brackets.Count % 2 == 0;
public DataFileMember ToPluralKit(ref string lastSetTag, ref bool multipleTags, ref bool hasGroup)
{
// If we've set a tag before and it's not the same as this one,
// then we have multiple unique tags and we pass that flag back to the caller
if (Tag != null && lastSetTag != null && lastSetTag != Tag) multipleTags = true;
lastSetTag = Tag;
// If this member is in a group, we have a (used) group and we flag that
if (GroupId != null) hasGroup = true;
// Brackets in Tupperbox format are arranged as a single array
// [prefix1, suffix1, prefix2, suffix2, prefix3... etc]
var tags = new List<ProxyTag>();
for (var i = 0; i < Brackets.Count / 2; i++)
tags.Add(new ProxyTag(Brackets[i * 2], Brackets[i * 2 + 1]));
return new DataFileMember
{
Id = Guid.NewGuid().ToString(), // Note: this is only ever used for lookup purposes
Name = Name,
AvatarUrl = AvatarUrl,
Birthday = Birthday,
Description = Description,
ProxyTags = tags,
KeepProxy = ShowBrackets,
DisplayName = Tag != null ? $"{Name} {Tag}" : null
};
}
}
public struct TupperboxGroup
{
[JsonProperty("id")] public int Id;
[JsonProperty("name")] public string Name;
[JsonProperty("description")] public string Description;
[JsonProperty("tag")] public string Tag;
[JsonIgnore] public bool Valid => true;
}
}

View File

@@ -0,0 +1,422 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using NodaTime;
namespace PluralKit.Core {
public enum AutoproxyMode
{
Off = 1,
Front = 2,
Latch = 3,
Member = 4
}
public class FullMessage
{
public PKMessage Message;
public PKMember Member;
public PKSystem System;
}
public struct PKMessage
{
public ulong Mid;
public ulong? Guild; // null value means "no data" (ie. from before this field being added)
public ulong Channel;
public ulong Sender;
public ulong? OriginalMid;
}
public struct ImportedSwitch
{
public Instant Timestamp;
public IReadOnlyCollection<PKMember> Members;
}
public struct SwitchListEntry
{
public ICollection<PKMember> Members;
public Instant TimespanStart;
public Instant TimespanEnd;
}
public struct MemberMessageCount
{
public int Member;
public int MessageCount;
}
public struct FrontBreakdown
{
public Dictionary<PKMember, Duration> MemberSwitchDurations;
public Duration NoFronterDuration;
public Instant RangeStart;
public Instant RangeEnd;
}
public struct SwitchMembersListEntry
{
public int Member;
public Instant Timestamp;
}
public struct GuildConfig
{
public ulong Id { get; set; }
public ulong? LogChannel { get; set; }
public ISet<ulong> LogBlacklist { get; set; }
public ISet<ulong> Blacklist { get; set; }
}
public class SystemGuildSettings
{
public ulong Guild { get; set; }
public bool ProxyEnabled { get; set; } = true;
public AutoproxyMode AutoproxyMode { get; set; } = AutoproxyMode.Off;
public int? AutoproxyMember { get; set; }
}
public class MemberGuildSettings
{
public int Member { get; set; }
public ulong Guild { get; set; }
public string DisplayName { get; set; }
}
public class AuxillaryProxyInformation
{
public GuildConfig Guild { get; set; }
public SystemGuildSettings SystemGuild { get; set; }
public MemberGuildSettings MemberGuild { get; set; }
}
public interface IDataStore
{
/// <summary>
/// Gets a system by its internal system ID.
/// </summary>
/// <returns>The <see cref="PKSystem"/> with the given internal ID, or null if no system was found.</returns>
Task<PKSystem> GetSystemById(int systemId);
/// <summary>
/// Gets a system by its user-facing human ID.
/// </summary>
/// <returns>The <see cref="PKSystem"/> with the given human ID, or null if no system was found.</returns>
Task<PKSystem> GetSystemByHid(string systemHid);
/// <summary>
/// Gets a system by one of its linked Discord account IDs. Multiple IDs can return the same system.
/// </summary>
/// <returns>The <see cref="PKSystem"/> with the given linked account, or null if no system was found.</returns>
Task<PKSystem> GetSystemByAccount(ulong linkedAccount);
/// <summary>
/// Gets a system by its API token.
/// </summary>
/// <returns>The <see cref="PKSystem"/> with the given API token, or null if no corresponding system was found.</returns>
Task<PKSystem> GetSystemByToken(string apiToken);
/// <summary>
/// Gets the Discord account IDs linked to a system.
/// </summary>
/// <returns>An enumerable of Discord account IDs linked to this system.</returns>
Task<IEnumerable<ulong>> GetSystemAccounts(PKSystem system);
/// <summary>
/// Gets the member count of a system.
/// </summary>
/// <param name="includePrivate">Whether the returned count should include private members.</param>
Task<int> GetSystemMemberCount(PKSystem system, bool includePrivate);
/// <summary>
/// Gets a list of members with proxy tags that conflict with the given tags.
///
/// A set of proxy tags A conflict with proxy tags B if both A's prefix and suffix
/// are a "subset" of B's. In other words, if A's prefix *starts* with B's prefix
/// and A's suffix *ends* with B's suffix, the tag pairs are considered conflicting.
/// </summary>
/// <param name="system">The system to check in.</param>
Task<IEnumerable<PKMember>> GetConflictingProxies(PKSystem system, ProxyTag tag);
/// <summary>
/// Gets a specific system's guild-specific settings for a given guild.
/// </summary>
Task<SystemGuildSettings> GetSystemGuildSettings(PKSystem system, ulong guild);
/// <summary>
/// Saves a specific system's guild-specific settings.
/// </summary>
Task SetSystemGuildSettings(PKSystem system, ulong guild, SystemGuildSettings settings);
/// <summary>
/// Creates a system, auto-generating its corresponding IDs.
/// </summary>
/// <param name="systemName">An optional system name to set. If `null`, will not set a system name.</param>
/// <returns>The created system model.</returns>
Task<PKSystem> CreateSystem(string systemName);
// TODO: throw exception if account is present (when adding) or account isn't present (when removing)
/// <summary>
/// Links a Discord account to a system.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if the given account is already linked to a system.</exception>
Task AddAccount(PKSystem system, ulong accountToAdd);
/// <summary>
/// Unlinks a Discord account from a system.
///
/// Will *not* throw if this results in an orphaned system - this is the caller's responsibility to ensure.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if the given account is not linked to the given system.</exception>
Task RemoveAccount(PKSystem system, ulong accountToRemove);
/// <summary>
/// Saves the information within the given <see cref="PKSystem"/> struct to the data store.
/// </summary>
Task SaveSystem(PKSystem system);
/// <summary>
/// Deletes the given system from the database.
/// </summary>
/// <para>
/// This will also delete all the system's members, all system switches, and every message that has been proxied
/// by members in the system.
/// </para>
Task DeleteSystem(PKSystem system);
/// <summary>
/// Gets a system by its internal member ID.
/// </summary>
/// <returns>The <see cref="PKMember"/> with the given internal ID, or null if no member was found.</returns>
Task<PKMember> GetMemberById(int memberId);
/// <summary>
/// Gets a member by its user-facing human ID.
/// </summary>
/// <returns>The <see cref="PKMember"/> with the given human ID, or null if no member was found.</returns>
Task<PKMember> GetMemberByHid(string memberHid);
/// <summary>
/// Gets a member by its member name within one system.
/// </summary>
/// <para>
/// Member names are *usually* unique within a system (but not always), whereas member names
/// are almost certainly *not* unique globally - therefore only intra-system lookup is
/// allowed.
/// </para>
/// <returns>The <see cref="PKMember"/> with the given name, or null if no member was found.</returns>
Task<PKMember> GetMemberByName(PKSystem system, string name);
/// <summary>
/// Gets all members inside a given system.
/// </summary>
/// <returns>An enumerable of <see cref="PKMember"/> structs representing each member in the system, in no particular order.</returns>
IAsyncEnumerable<PKMember> GetSystemMembers(PKSystem system, bool orderByName = false);
/// <summary>
/// Gets the amount of messages proxied by a given member.
/// </summary>
/// <returns>The message count of the given member.</returns>
Task<ulong> GetMemberMessageCount(PKMember member);
/// <summary>
/// Collects a breakdown of each member in a system's message count.
/// </summary>
/// <returns>An enumerable of members along with their message counts.</returns>
Task<IEnumerable<MemberMessageCount>> GetMemberMessageCountBulk(PKSystem system);
/// <summary>
/// Creates a member, auto-generating its corresponding IDs.
/// </summary>
/// <param name="system">The system in which to create the member.</param>
/// <param name="name">The name of the member to create.</param>
/// <returns>The created system model.</returns>
Task<PKMember> CreateMember(PKSystem system, string name);
/// <summary>
/// Creates multiple members, auto-generating each corresponding ID.
/// </summary>
/// <param name="system">The system to create the member in.</param>
/// <param name="memberNames">A dictionary containing a mapping from an arbitrary key to the member's name.</param>
/// <returns>A dictionary containing the resulting member structs, each mapped to the key given in the argument dictionary.</returns>
Task<Dictionary<string, PKMember>> CreateMembersBulk(PKSystem system, Dictionary<string, string> memberNames);
/// <summary>
/// Saves the information within the given <see cref="PKMember"/> struct to the data store.
/// </summary>
Task SaveMember(PKMember member);
/// <summary>
/// Deletes the given member from the database.
/// </summary>
/// <para>
/// This will remove this member from any switches it's involved in, as well as all the messages
/// proxied by this member.
/// </para>
Task DeleteMember(PKMember member);
/// <summary>
/// Gets a specific member's guild-specific settings for a given guild.
/// </summary>
Task<MemberGuildSettings> GetMemberGuildSettings(PKMember member, ulong guild);
/// <summary>
/// Saves a specific member's guild-specific settings.
/// </summary>
Task SetMemberGuildSettings(PKMember member, ulong guild, MemberGuildSettings settings);
/// <summary>
/// Gets a message and its information by its ID.
/// </summary>
/// <param name="id">The message ID to look up. This can be either the ID of the trigger message containing the proxy tags or the resulting proxied webhook message.</param>
/// <returns>An extended message object, containing not only the message data itself but the associated system and member structs.</returns>
Task<FullMessage> GetMessage(ulong id); // id is both original and trigger, also add return type struct
/// <summary>
/// Saves a posted message to the database.
/// </summary>
/// <param name="senderAccount">The ID of the account that sent the original trigger message.</param>
/// <param name="guildId">The ID of the guild the message was posted to.</param>
/// <param name="channelId">The ID of the channel the message was posted to.</param>
/// <param name="postedMessageId">The ID of the message posted by the webhook.</param>
/// <param name="triggerMessageId">The ID of the original trigger message containing the proxy tags.</param>
/// <param name="proxiedMember">The member (and by extension system) that was proxied.</param>
/// <returns></returns>
Task AddMessage(ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, PKMember proxiedMember);
/// <summary>
/// Deletes a message from the data store.
/// </summary>
/// <param name="postedMessageId">The ID of the webhook message to delete.</param>
Task DeleteMessage(ulong postedMessageId);
/// <summary>
/// Deletes messages from the data store in bulk.
/// </summary>
/// <param name="postedMessageIds">The IDs of the webhook messages to delete.</param>
Task DeleteMessagesBulk(IEnumerable<ulong> postedMessageIds);
/// <summary>
/// Gets the most recent message sent by a given account in a given guild.
/// </summary>
/// <returns>The full message object, or null if none was found.</returns>
Task<FullMessage> GetLastMessageInGuild(ulong account, ulong guild);
/// <summary>
/// Gets switches from a system.
/// </summary>
/// <returns>An enumerable of the *count* latest switches in the system, in latest-first order. May contain fewer elements than requested.</returns>
IAsyncEnumerable<PKSwitch> GetSwitches(PKSystem system);
/// <summary>
/// Gets the total amount of switches in a given system.
/// </summary>
Task<int> GetSwitchCount(PKSystem system);
/// <summary>
/// Gets the latest (temporally; closest to now) switch of a given system.
/// </summary>
Task<PKSwitch> GetLatestSwitch(PKSystem system);
/// <summary>
/// Gets the members a given switch consists of.
/// </summary>
IAsyncEnumerable<PKMember> GetSwitchMembers(PKSwitch sw);
/// <summary>
/// Gets a list of fronters over a given period of time.
/// </summary>
/// <para>
/// This list is returned as an enumerable of "switch members", each containing a timestamp
/// and a member ID. <seealso cref="GetMemberById"/>
///
/// Switches containing multiple members will be returned as multiple switch members each with the same
/// timestamp, and a change in timestamp should be interpreted as the start of a new switch.
/// </para>
/// <returns>An enumerable of the aforementioned "switch members".</returns>
Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(PKSystem system, Instant periodStart, Instant periodEnd);
/// <summary>
/// Calculates a breakdown of a system's fronters over a given period, including how long each member has
/// been fronting, and how long *no* member has been fronting.
/// </summary>
/// <para>
/// Switches containing multiple members will count the full switch duration for all members, meaning
/// the total duration may add up to longer than the breakdown period.
/// </para>
/// <param name="system"></param>
/// <param name="periodStart"></param>
/// <param name="periodEnd"></param>
/// <returns></returns>
Task<FrontBreakdown> GetFrontBreakdown(PKSystem system, Instant periodStart, Instant periodEnd);
/// <summary>
/// Gets the first listed fronter in a system.
/// </summary>
/// <returns>The first fronter, or null if none are registered.</returns>
Task<PKMember> GetFirstFronter(PKSystem system);
/// <summary>
/// Registers a switch with the given members in the given system.
/// </summary>
/// <exception>Throws an exception (TODO: which?) if any of the members are not in the given system.</exception>
Task AddSwitch(PKSystem system, IEnumerable<PKMember> switchMembers);
/// <summary>
/// Registers switches in bulk.
/// </summary>
/// <param name="switches">A list of switch structs, each containing a timestamp and a list of members.</param>
/// <exception>Throws an exception (TODO: which?) if any of the given members are not in the given system.</exception>
Task AddSwitchesBulk(PKSystem system, IEnumerable<ImportedSwitch> switches);
/// <summary>
/// Updates the timestamp of a given switch.
/// </summary>
Task MoveSwitch(PKSwitch sw, Instant time);
/// <summary>
/// Deletes a given switch from the data store.
/// </summary>
Task DeleteSwitch(PKSwitch sw);
/// <summary>
/// Deletes all switches in a given system from the data store.
/// </summary>
Task DeleteAllSwitches(PKSystem system);
/// <summary>
/// Gets the total amount of systems in the data store.
/// </summary>
Task<ulong> GetTotalSystems();
/// <summary>
/// Gets the total amount of members in the data store.
/// </summary>
Task<ulong> GetTotalMembers();
/// <summary>
/// Gets the total amount of switches in the data store.
/// </summary>
Task<ulong> GetTotalSwitches();
/// <summary>
/// Gets the total amount of messages in the data store.
/// </summary>
Task<ulong> GetTotalMessages();
/// <summary>
/// Gets the guild configuration struct for a given guild, creating and saving one if none was found.
/// </summary>
/// <returns>The guild's configuration struct.</returns>
Task<GuildConfig> GetOrCreateGuildConfig(ulong guild);
/// <summary>
/// Saves the given guild configuration struct to the data store.
/// </summary>
Task SaveGuildConfig(GuildConfig cfg);
Task<AuxillaryProxyInformation> GetAuxillaryProxyInformation(ulong guild, PKSystem system, PKMember member);
}
}

View File

@@ -0,0 +1,700 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using NodaTime;
using Serilog;
namespace PluralKit.Core {
public class PostgresDataStore: IDataStore {
private DbConnectionFactory _conn;
private ILogger _logger;
private ProxyCache _cache;
public PostgresDataStore(DbConnectionFactory conn, ILogger logger, ProxyCache cache)
{
_conn = conn;
_logger = logger;
_cache = cache;
}
public async Task<IEnumerable<PKMember>> GetConflictingProxies(PKSystem system, ProxyTag tag)
{
using (var conn = await _conn.Obtain())
// return await conn.QueryAsync<PKMember>("select * from (select *, (unnest(proxy_tags)).prefix as prefix, (unnest(proxy_tags)).suffix as suffix from members where system = @System) as _ where prefix ilike @Prefix and suffix ilike @Suffix", new
// {
// System = system.Id,
// Prefix = tag.Prefix.Replace("%", "\\%") + "%",
// Suffix = "%" + tag.Suffix.Replace("%", "\\%")
// });
return await conn.QueryAsync<PKMember>("select * from (select *, (unnest(proxy_tags)).prefix as prefix, (unnest(proxy_tags)).suffix as suffix from members where system = @System) as _ where prefix = @Prefix and suffix = @Suffix", new
{
System = system.Id,
Prefix = tag.Prefix,
Suffix = tag.Suffix
});
}
public async Task<SystemGuildSettings> GetSystemGuildSettings(PKSystem system, ulong guild)
{
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<SystemGuildSettings>(
"select * from system_guild where system = @System and guild = @Guild",
new {System = system.Id, Guild = guild}) ?? new SystemGuildSettings();
}
public async Task SetSystemGuildSettings(PKSystem system, ulong guild, SystemGuildSettings settings)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("insert into system_guild (system, guild, proxy_enabled, autoproxy_mode, autoproxy_member) values (@System, @Guild, @ProxyEnabled, @AutoproxyMode, @AutoproxyMember) on conflict (system, guild) do update set proxy_enabled = @ProxyEnabled, autoproxy_mode = @AutoproxyMode, autoproxy_member = @AutoproxyMember", new
{
System = system.Id,
Guild = guild,
settings.ProxyEnabled,
settings.AutoproxyMode,
settings.AutoproxyMember
});
await _cache.InvalidateSystem(system);
}
public async Task<PKSystem> CreateSystem(string systemName = null) {
string hid;
do
{
hid = StringUtils.GenerateHid();
} while (await GetSystemByHid(hid) != null);
PKSystem system;
using (var conn = await _conn.Obtain())
system = await conn.QuerySingleAsync<PKSystem>("insert into systems (hid, name) values (@Hid, @Name) returning *", new { Hid = hid, Name = systemName });
_logger.Information("Created system {System}", system.Id);
// New system has no accounts, therefore nothing gets cached, therefore no need to invalidate caches right here
return system;
}
public async Task AddAccount(PKSystem system, ulong accountId) {
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
// This is used in import/export, although the pk;link command checks for this case beforehand
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing", new { Id = accountId, SystemId = system.Id });
_logger.Information("Linked system {System} to account {Account}", system.Id, accountId);
await _cache.InvalidateSystem(system);
}
public async Task RemoveAccount(PKSystem system, ulong accountId) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", new { Id = accountId, SystemId = system.Id });
_logger.Information("Unlinked system {System} from account {Account}", system.Id, accountId);
await _cache.InvalidateSystem(system);
_cache.InvalidateAccounts(new [] { accountId });
}
public async Task<PKSystem> GetSystemByAccount(ulong accountId) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id", new { Id = accountId });
}
public async Task<PKSystem> GetSystemByHid(string hid) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select * from systems where systems.hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<PKSystem> GetSystemByToken(string token) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select * from systems where token = @Token", new { Token = token });
}
public async Task<PKSystem> GetSystemById(int id)
{
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKSystem>("select * from systems where id = @Id", new { Id = id });
}
public async Task SaveSystem(PKSystem system) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("update systems set name = @Name, description = @Description, tag = @Tag, avatar_url = @AvatarUrl, token = @Token, ui_tz = @UiTz, description_privacy = @DescriptionPrivacy, member_list_privacy = @MemberListPrivacy, front_privacy = @FrontPrivacy, front_history_privacy = @FrontHistoryPrivacy where id = @Id", system);
_logger.Information("Updated system {@System}", system);
await _cache.InvalidateSystem(system);
}
public async Task DeleteSystem(PKSystem system)
{
using var conn = await _conn.Obtain();
// Fetch the list of accounts *before* deletion so we can cache-bust all of those
var accounts = (await conn.QueryAsync<ulong>("select uid from accounts where system = @Id", system)).ToArray();
await conn.ExecuteAsync("delete from systems where id = @Id", system);
_logger.Information("Deleted system {System}", system.Id);
_cache.InvalidateDeletedSystem(system.Id, accounts);
}
public async Task<IEnumerable<ulong>> GetSystemAccounts(PKSystem system)
{
using (var conn = await _conn.Obtain())
return await conn.QueryAsync<ulong>("select uid from accounts where system = @Id", new { Id = system.Id });
}
public async Task DeleteAllSwitches(PKSystem system)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from switches where system = @Id", system);
}
public async Task<ulong> GetTotalSystems()
{
using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<ulong>("select count(id) from systems");
}
public async Task<PKMember> CreateMember(PKSystem system, string name) {
string hid;
do
{
hid = StringUtils.GenerateHid();
} while (await GetMemberByHid(hid) != null);
PKMember member;
using (var conn = await _conn.Obtain())
member = await conn.QuerySingleAsync<PKMember>("insert into members (hid, system, name) values (@Hid, @SystemId, @Name) returning *", new {
Hid = hid,
SystemID = system.Id,
Name = name
});
_logger.Information("Created member {Member}", member.Id);
await _cache.InvalidateSystem(system);
return member;
}
public async Task<Dictionary<string,PKMember>> CreateMembersBulk(PKSystem system, Dictionary<string,string> names)
{
using (var conn = await _conn.Obtain())
using (var tx = conn.BeginTransaction())
{
var results = new Dictionary<string, PKMember>();
foreach (var name in names)
{
string hid;
do
{
hid = await conn.QuerySingleOrDefaultAsync<string>("SELECT @Hid WHERE NOT EXISTS (SELECT id FROM members WHERE hid = @Hid LIMIT 1)", new
{
Hid = StringUtils.GenerateHid()
});
} while (hid == null);
var member = await conn.QuerySingleAsync<PKMember>("INSERT INTO members (hid, system, name) VALUES (@Hid, @SystemId, @Name) RETURNING *", new
{
Hid = hid,
SystemID = system.Id,
Name = name.Value
});
results.Add(name.Key, member);
}
tx.Commit();
_logger.Information("Created {MemberCount} members for system {SystemID}", names.Count(), system.Hid);
await _cache.InvalidateSystem(system);
return results;
}
}
public async Task<PKMember> GetMemberById(int id) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKMember>("select * from members where id = @Id", new { Id = id });
}
public async Task<PKMember> GetMemberByHid(string hid) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKMember>("select * from members where hid = @Hid", new { Hid = hid.ToLower() });
}
public async Task<PKMember> GetMemberByName(PKSystem system, string name) {
// QueryFirst, since members can (in rare cases) share names
using (var conn = await _conn.Obtain())
return await conn.QueryFirstOrDefaultAsync<PKMember>("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id });
}
public IAsyncEnumerable<PKMember> GetSystemMembers(PKSystem system, bool orderByName)
{
var sql = "select * from members where system = @SystemID";
if (orderByName) sql += " order by lower(name) asc";
return _conn.QueryStreamAsync<PKMember>(sql, new { SystemID = system.Id });
}
public async Task SaveMember(PKMember member) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("update members set name = @Name, display_name = @DisplayName, description = @Description, color = @Color, avatar_url = @AvatarUrl, birthday = @Birthday, pronouns = @Pronouns, proxy_tags = @ProxyTags, keep_proxy = @KeepProxy, member_privacy = @MemberPrivacy where id = @Id", member);
_logger.Information("Updated member {@Member}", member);
await _cache.InvalidateSystem(member.System);
}
public async Task DeleteMember(PKMember member) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from members where id = @Id", member);
_logger.Information("Deleted member {@Member}", member);
await _cache.InvalidateSystem(member.System);
}
public async Task<MemberGuildSettings> GetMemberGuildSettings(PKMember member, ulong guild)
{
using var conn = await _conn.Obtain();
return await conn.QuerySingleOrDefaultAsync<MemberGuildSettings>(
"select * from member_guild where member = @Member and guild = @Guild", new { Member = member.Id, Guild = guild})
?? new MemberGuildSettings();
}
public async Task SetMemberGuildSettings(PKMember member, ulong guild, MemberGuildSettings settings)
{
using var conn = await _conn.Obtain();
await conn.ExecuteAsync(
"insert into member_guild (member, guild, display_name) values (@Member, @Guild, @DisplayName) on conflict (member, guild) do update set display_name = @Displayname",
new {Member = member.Id, Guild = guild, DisplayName = settings.DisplayName});
await _cache.InvalidateSystem(member.System);
}
public async Task<ulong> GetMemberMessageCount(PKMember member)
{
using (var conn = await _conn.Obtain())
return await conn.QuerySingleAsync<ulong>("select count(*) from messages where member = @Id", member);
}
public async Task<IEnumerable<MemberMessageCount>> GetMemberMessageCountBulk(PKSystem system)
{
using (var conn = await _conn.Obtain())
return await conn.QueryAsync<MemberMessageCount>(
@"SELECT messages.member, COUNT(messages.member) messagecount
FROM members
JOIN messages
ON members.id = messages.member
WHERE members.system = @System
GROUP BY messages.member",
new { System = system.Id });
}
public async Task<int> GetSystemMemberCount(PKSystem system, bool includePrivate)
{
var query = "select count(*) from members where system = @Id";
if (!includePrivate) query += " and member_privacy = 1"; // 1 = public
using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<int>(query, system);
}
public async Task<ulong> GetTotalMembers()
{
using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<ulong>("select count(id) from members");
}
public async Task AddMessage(ulong senderId, ulong messageId, ulong guildId, ulong channelId, ulong originalMessage, PKMember member) {
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@MessageId, @GuildId, @ChannelId, @MemberId, @SenderId, @OriginalMid)", new {
MessageId = messageId,
GuildId = guildId,
ChannelId = channelId,
MemberId = member.Id,
SenderId = senderId,
OriginalMid = originalMessage
});
_logger.Information("Stored message {Message} in channel {Channel}", messageId, channelId);
}
public async Task<FullMessage> GetMessage(ulong id)
{
using (var conn = await _conn.Obtain())
return (await conn.QueryAsync<PKMessage, PKMember, PKSystem, FullMessage>("select messages.*, members.*, systems.* from messages, members, systems where (mid = @Id or original_mid = @Id) and messages.member = members.id and systems.id = members.system", (msg, member, system) => new FullMessage
{
Message = msg,
System = system,
Member = member
}, new { Id = id })).FirstOrDefault();
}
public async Task DeleteMessage(ulong id) {
using (var conn = await _conn.Obtain())
if (await conn.ExecuteAsync("delete from messages where mid = @Id", new { Id = id }) > 0)
_logger.Information("Deleted message {Message}", id);
}
public async Task DeleteMessagesBulk(IEnumerable<ulong> ids)
{
using (var conn = await _conn.Obtain())
{
// Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong
// Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway)
var foundCount = await conn.ExecuteAsync("delete from messages where mid = any(@Ids)", new {Ids = ids.Select(id => (long) id).ToArray()});
if (foundCount > 0)
_logger.Information("Bulk deleted messages {Messages}, {FoundCount} found", ids, foundCount);
}
}
public async Task<FullMessage> GetLastMessageInGuild(ulong account, ulong guild)
{
using var conn = await _conn.Obtain();
return (await conn.QueryAsync<PKMessage, PKMember, PKSystem, FullMessage>("select messages.*, members.*, systems.* from messages, members, systems where messages.guild = @Guild and messages.sender = @Uid and messages.member = members.id and systems.id = members.system order by mid desc limit 1", (msg, member, system) => new FullMessage
{
Message = msg,
System = system,
Member = member
}, new { Uid = account, Guild = guild })).FirstOrDefault();
}
public async Task<ulong> GetTotalMessages()
{
using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<ulong>("select count(mid) from messages");
}
// Same as GuildConfig, but with ISet<ulong> as long[] instead.
public struct DatabaseCompatibleGuildConfig
{
public ulong Id { get; set; }
public ulong? LogChannel { get; set; }
public long[] LogBlacklist { get; set; }
public long[] Blacklist { get; set; }
public GuildConfig Into() =>
new GuildConfig
{
Id = Id,
LogChannel = LogChannel,
LogBlacklist = new HashSet<ulong>(LogBlacklist?.Select(c => (ulong) c) ?? new ulong[] {}),
Blacklist = new HashSet<ulong>(Blacklist?.Select(c => (ulong) c) ?? new ulong[]{})
};
}
public async Task<GuildConfig> GetOrCreateGuildConfig(ulong guild)
{
// When changing this, also see ProxyCache::GetGuildDataCached
using (var conn = await _conn.Obtain())
{
return (await conn.QuerySingleOrDefaultAsync<DatabaseCompatibleGuildConfig>(
"insert into servers (id) values (@Id) on conflict do nothing; select * from servers where id = @Id",
new {Id = guild})).Into();
}
}
public async Task SaveGuildConfig(GuildConfig cfg)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("insert into servers (id, log_channel, log_blacklist, blacklist) values (@Id, @LogChannel, @LogBlacklist, @Blacklist) on conflict (id) do update set log_channel = @LogChannel, log_blacklist = @LogBlacklist, blacklist = @Blacklist", new
{
cfg.Id,
cfg.LogChannel,
LogBlacklist = cfg.LogBlacklist.Select(c => (long) c).ToList(),
Blacklist = cfg.Blacklist.Select(c => (long) c).ToList()
});
_logger.Information("Updated guild configuration {@GuildCfg}", cfg);
_cache.InvalidateGuild(cfg.Id);
}
public async Task<AuxillaryProxyInformation> GetAuxillaryProxyInformation(ulong guild, PKSystem system, PKMember member)
{
using var conn = await _conn.Obtain();
var args = new {Guild = guild, System = system.Id, Member = member.Id};
var multi = await conn.QueryMultipleAsync(@"
select servers.* from servers where id = @Guild;
select * from system_guild where guild = @Guild and system = @System;
select * from member_guild where guild = @Guild and member = @Member", args);
return new AuxillaryProxyInformation
{
Guild = (await multi.ReadSingleOrDefaultAsync<DatabaseCompatibleGuildConfig>()).Into(),
SystemGuild = await multi.ReadSingleOrDefaultAsync<SystemGuildSettings>() ?? new SystemGuildSettings(),
MemberGuild = await multi.ReadSingleOrDefaultAsync<MemberGuildSettings>() ?? new MemberGuildSettings()
};
}
public async Task<PKMember> GetFirstFronter(PKSystem system)
{
// TODO: move to extension method since it doesn't rely on internals
var lastSwitch = await GetLatestSwitch(system);
if (lastSwitch == null) return null;
return await GetSwitchMembers(lastSwitch).FirstOrDefaultAsync();
}
public async Task AddSwitch(PKSystem system, IEnumerable<PKMember> members)
{
// Use a transaction here since we're doing multiple executed commands in one
using (var conn = await _conn.Obtain())
using (var tx = conn.BeginTransaction())
{
// First, we insert the switch itself
var sw = await conn.QuerySingleAsync<PKSwitch>("insert into switches(system) values (@System) returning *",
new {System = system.Id});
// Then we insert each member in the switch in the switch_members table
// TODO: can we parallelize this or send it in bulk somehow?
foreach (var member in members)
{
await conn.ExecuteAsync(
"insert into switch_members(switch, member) values(@Switch, @Member)",
new {Switch = sw.Id, Member = member.Id});
}
// Finally we commit the tx, since the using block will otherwise rollback it
tx.Commit();
_logger.Information("Registered switch {Switch} in system {System} with members {@Members}", sw.Id, system.Id, members.Select(m => m.Id));
}
}
public async Task AddSwitchesBulk(PKSystem system, IEnumerable<ImportedSwitch> switches)
{
// Read existing switches to enforce unique timestamps
var priorSwitches = new List<PKSwitch>();
await foreach (var sw in GetSwitches(system)) priorSwitches.Add(sw);
var lastSwitchId = priorSwitches.Any()
? priorSwitches.Max(x => x.Id)
: 0;
using (var conn = (PerformanceTrackingConnection) await _conn.Obtain())
{
using (var tx = conn.BeginTransaction())
{
// Import switches in bulk
using (var importer = conn.BeginBinaryImport("COPY switches (system, timestamp) FROM STDIN (FORMAT BINARY)"))
{
foreach (var sw in switches)
{
// If there's already a switch at this time, move on
if (priorSwitches.Any(x => x.Timestamp.Equals(sw.Timestamp)))
continue;
// Otherwise, add it to the importer
importer.StartRow();
importer.Write(system.Id, NpgsqlTypes.NpgsqlDbType.Integer);
importer.Write(sw.Timestamp, NpgsqlTypes.NpgsqlDbType.Timestamp);
}
importer.Complete(); // Commits the copy operation so dispose won't roll it back
}
// Get all switches that were created above and don't have members for ID lookup
var switchesWithoutMembers =
await conn.QueryAsync<PKSwitch>(@"
SELECT switches.*
FROM switches
LEFT JOIN switch_members
ON switch_members.switch = switches.id
WHERE switches.id > @LastSwitchId
AND switches.system = @System
AND switch_members.id IS NULL", new { LastSwitchId = lastSwitchId, System = system.Id });
// Import switch_members in bulk
using (var importer = conn.BeginBinaryImport("COPY switch_members (switch, member) FROM STDIN (FORMAT BINARY)"))
{
// Iterate over the switches we created above and set their members
foreach (var pkSwitch in switchesWithoutMembers)
{
// If this isn't in our import set, move on
var sw = switches.Select(x => (ImportedSwitch?) x).FirstOrDefault(x => x.Value.Timestamp.Equals(pkSwitch.Timestamp));
if (sw == null)
continue;
// Loop through associated members to add each to the switch
foreach (var m in sw.Value.Members)
{
// Skip switch-outs - these don't have switch_members
if (m == null)
continue;
importer.StartRow();
importer.Write(pkSwitch.Id, NpgsqlTypes.NpgsqlDbType.Integer);
importer.Write(m.Id, NpgsqlTypes.NpgsqlDbType.Integer);
}
}
importer.Complete(); // Commits the copy operation so dispose won't roll it back
}
tx.Commit();
}
}
_logger.Information("Completed bulk import of switches for system {0}", system.Hid);
}
public IAsyncEnumerable<PKSwitch> GetSwitches(PKSystem system)
{
// TODO: refactor the PKSwitch data structure to somehow include a hydrated member list
// (maybe when we get caching in?)
return _conn.QueryStreamAsync<PKSwitch>(
"select * from switches where system = @System order by timestamp desc",
new {System = system.Id});
}
public async Task<int> GetSwitchCount(PKSystem system)
{
using var conn = await _conn.Obtain();
return await conn.QuerySingleAsync<int>("select count(*) from switches where system = @Id", system);
}
public async IAsyncEnumerable<SwitchMembersListEntry> GetSwitchMembersList(PKSystem system, Instant start, Instant end)
{
// Wrap multiple commands in a single transaction for performance
using var conn = await _conn.Obtain();
using var tx = conn.BeginTransaction();
// Find the time of the last switch outside the range as it overlaps the range
// If no prior switch exists, the lower bound of the range remains the start time
var lastSwitch = await conn.QuerySingleOrDefaultAsync<Instant>(
@"SELECT COALESCE(MAX(timestamp), @Start)
FROM switches
WHERE switches.system = @System
AND switches.timestamp < @Start",
new { System = system.Id, Start = start });
// Then collect the time and members of all switches that overlap the range
var switchMembersEntries = conn.QueryStreamAsync<SwitchMembersListEntry>(
@"SELECT switch_members.member, switches.timestamp
FROM switches
LEFT JOIN switch_members
ON switches.id = switch_members.switch
WHERE switches.system = @System
AND (
switches.timestamp >= @Start
OR switches.timestamp = @LastSwitch
)
AND switches.timestamp < @End
ORDER BY switches.timestamp DESC",
new { System = system.Id, Start = start, End = end, LastSwitch = lastSwitch });
// Yield each value here
await foreach (var entry in switchMembersEntries)
yield return entry;
// Don't really need to worry about the transaction here, we're not doing any *writes*
}
public IAsyncEnumerable<PKMember> GetSwitchMembers(PKSwitch sw)
{
return _conn.QueryStreamAsync<PKMember>(
"select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch order by switch_members.id",
new {Switch = sw.Id});
}
public async Task<PKSwitch> GetLatestSwitch(PKSystem system) =>
await GetSwitches(system).FirstOrDefaultAsync();
public async Task MoveSwitch(PKSwitch sw, Instant time)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id",
new {Time = time, Id = sw.Id});
_logger.Information("Moved switch {Switch} to {Time}", sw.Id, time);
}
public async Task DeleteSwitch(PKSwitch sw)
{
using (var conn = await _conn.Obtain())
await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = sw.Id});
_logger.Information("Deleted switch {Switch}");
}
public async Task<ulong> GetTotalSwitches()
{
using (var conn = await _conn.Obtain())
return await conn.ExecuteScalarAsync<ulong>("select count(id) from switches");
}
public async Task<IEnumerable<SwitchListEntry>> GetPeriodFronters(PKSystem system, Instant periodStart, Instant periodEnd)
{
// TODO: IAsyncEnumerable-ify this one
// Returns the timestamps and member IDs of switches overlapping the range, in chronological (newest first) order
var switchMembers = await GetSwitchMembersList(system, periodStart, periodEnd).ToListAsync();
// query DB for all members involved in any of the switches above and collect into a dictionary for future use
// this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary
// key used in GetPerMemberSwitchDuration below
Dictionary<int, PKMember> memberObjects;
using (var conn = await _conn.Obtain())
{
memberObjects = (
await conn.QueryAsync<PKMember>(
"select * from members where id = any(@Switches)", // lol postgres specific `= any()` syntax
new { Switches = switchMembers.Select(m => m.Member).Distinct().ToList() })
).ToDictionary(m => m.Id);
}
// Initialize entries - still need to loop to determine the TimespanEnd below
var entries =
from item in switchMembers
group item by item.Timestamp into g
select new SwitchListEntry
{
TimespanStart = g.Key,
Members = g.Where(x => x.Member != 0).Select(x => memberObjects[x.Member]).ToList()
};
// Loop through every switch that overlaps the range and add it to the output list
// end time is the *FOLLOWING* switch's timestamp - we cheat by working backwards from the range end, so no dates need to be compared
var endTime = periodEnd;
var outList = new List<SwitchListEntry>();
foreach (var e in entries)
{
// Override the start time of the switch if it's outside the range (only true for the "out of range" switch we included above)
var switchStartClamped = e.TimespanStart < periodStart
? periodStart
: e.TimespanStart;
outList.Add(new SwitchListEntry
{
Members = e.Members,
TimespanStart = switchStartClamped,
TimespanEnd = endTime
});
// next switch's end is this switch's start (we're working backward in time)
endTime = e.TimespanStart;
}
return outList;
}
public async Task<FrontBreakdown> GetFrontBreakdown(PKSystem system, Instant periodStart, Instant periodEnd)
{
var dict = new Dictionary<PKMember, Duration>();
var noFronterDuration = Duration.Zero;
// Sum up all switch durations for each member
// switches with multiple members will result in the duration to add up to more than the actual period range
var actualStart = periodEnd; // will be "pulled" down
var actualEnd = periodStart; // will be "pulled" up
foreach (var sw in await GetPeriodFronters(system, periodStart, periodEnd))
{
var span = sw.TimespanEnd - sw.TimespanStart;
foreach (var member in sw.Members)
{
if (!dict.ContainsKey(member)) dict.Add(member, span);
else dict[member] += span;
}
if (sw.Members.Count == 0) noFronterDuration += span;
if (sw.TimespanStart < actualStart) actualStart = sw.TimespanStart;
if (sw.TimespanEnd > actualEnd) actualEnd = sw.TimespanEnd;
}
return new FrontBreakdown
{
MemberSwitchDurations = dict,
NoFronterDuration = noFronterDuration,
RangeStart = actualStart,
RangeEnd = actualEnd
};
}
}
}

View File

@@ -0,0 +1,195 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using Microsoft.Extensions.Caching.Memory;
using Serilog;
namespace PluralKit.Core
{
public class ProxyCache
{
// We can NOT depend on IDataStore as that creates a cycle, since it needs access to call the invalidation methods
private IMemoryCache _cache;
private DbConnectionFactory _db;
private ILogger _logger;
public ProxyCache(IMemoryCache cache, DbConnectionFactory db, ILogger logger)
{
_cache = cache;
_db = db;
_logger = logger;
}
public Task InvalidateSystem(PKSystem system) => InvalidateSystem(system.Id);
public void InvalidateAccounts(IEnumerable<ulong> accounts)
{
foreach (var account in accounts)
_cache.Remove(KeyForAccount(account));
}
public void InvalidateDeletedSystem(int systemId, IEnumerable<ulong> accounts)
{
// Used when the system's already removed so we can't look up accounts
// We assume the account list is saved already somewhere and can be passed here (which is the case in Store)
_cache.Remove(KeyForSystem(systemId));
InvalidateAccounts(accounts);
}
public async Task InvalidateSystem(int systemId)
{
if (_cache.TryGetValue<CachedAccount>(KeyForSystem(systemId), out var systemCache))
{
// If we have the system cached here, just invalidate for all the accounts we have in the cache
_logger.Debug("Invalidating cache for system {System} and accounts {Accounts}", systemId, systemCache.Accounts);
_cache.Remove(KeyForSystem(systemId));
foreach (var account in systemCache.Accounts)
_cache.Remove(KeyForAccount(account));
return;
}
// If we don't, look up the accounts from the database and invalidate *those*
_cache.Remove(KeyForSystem(systemId));
using var conn = await _db.Obtain();
var accounts = (await conn.QueryAsync<ulong>("select uid from accounts where system = @System", new {System = systemId})).ToArray();
_logger.Debug("Invalidating cache for system {System} and accounts {Accounts}", systemId, accounts);
foreach (var account in accounts)
_cache.Remove(KeyForAccount(account));
}
public void InvalidateGuild(ulong guild)
{
_logger.Debug("Invalidating cache for guild {Guild}", guild);
_cache.Remove(KeyForGuild(guild));
}
public async Task<GuildConfig> GetGuildDataCached(ulong guild)
{
if (_cache.TryGetValue<GuildConfig>(KeyForGuild(guild), out var item))
{
_logger.Verbose("Cache hit for guild {Guild}", guild);
return item;
}
// When changing this, also see PostgresDataStore::GetOrCreateGuildConfig
using var conn = await _db.Obtain();
_logger.Verbose("Cache miss for guild {Guild}", guild);
var guildConfig = (await conn.QuerySingleOrDefaultAsync<PostgresDataStore.DatabaseCompatibleGuildConfig>(
"insert into servers (id) values (@Id) on conflict do nothing; select * from servers where id = @Id",
new {Id = guild})).Into();
_cache.CreateEntry(KeyForGuild(guild))
.SetValue(guildConfig)
.SetSlidingExpiration(TimeSpan.FromMinutes(5))
.SetAbsoluteExpiration(TimeSpan.FromMinutes(30))
.Dispose(); // Don't ask, but this *saves* the entry. Somehow.
return guildConfig;
}
public async Task<CachedAccount> GetAccountDataCached(ulong account)
{
if (_cache.TryGetValue<CachedAccount>(KeyForAccount(account), out var item))
{
_logger.Verbose("Cache hit for account {Account}", account);
return item;
}
_logger.Verbose("Cache miss for account {Account}", account);
var data = await GetAccountData(account);
if (data == null)
{
_logger.Debug("Cached data for account {Account} (no system)", account);
// If we didn't find any value, set a pretty long expiry and the value to null
_cache.CreateEntry(KeyForAccount(account))
.SetValue(null)
.SetSlidingExpiration(TimeSpan.FromMinutes(5))
.SetAbsoluteExpiration(TimeSpan.FromHours(1))
.Dispose(); // Don't ask, but this *saves* the entry. Somehow.
return null;
}
// If we *did* find the value, cache it for *every account in the system* with a shorter expiry
_logger.Debug("Cached data for system {System} and accounts {Account}", data.System.Id, data.Accounts);
foreach (var linkedAccount in data.Accounts)
{
_cache.CreateEntry(KeyForAccount(linkedAccount))
.SetValue(data)
.SetSlidingExpiration(TimeSpan.FromMinutes(5))
.SetAbsoluteExpiration(TimeSpan.FromMinutes(20))
.Dispose(); // Don't ask, but this *saves* the entry. Somehow.
// And also do it for the system itself so we can look up by that
_cache.CreateEntry(KeyForSystem(data.System.Id))
.SetValue(data)
.SetSlidingExpiration(TimeSpan.FromMinutes(5))
.SetAbsoluteExpiration(TimeSpan.FromMinutes(20))
.Dispose(); // Don't ask, but this *saves* the entry. Somehow.
}
return data;
}
private async Task<CachedAccount> GetAccountData(ulong account)
{
using var conn = await _db.Obtain();
// Doing this as two queries instead of a two-step join to avoid sending duplicate rows for the system over the network for each member
// This *may* be less efficient, haven't done too much stuff about this but having the system ID saved is very useful later on
var system = await conn.QuerySingleOrDefaultAsync<PKSystem>("select systems.* from accounts inner join systems on systems.id = accounts.system where accounts.uid = @Account", new { Account = account });
if (system == null) return null; // No system = no members = no cache value
// Fetches:
// - List of accounts in the system
// - List of members in the system
// - List of guild settings for the system (for every guild)
// - List of guild settings for each member (for every guild)
// I'm slightly worried the volume of guild settings will get too much, but for simplicity reasons I decided
// against caching them individually per-guild, since I can't imagine they'll be edited *that* much
var result = await conn.QueryMultipleAsync(@"
select uid from accounts where system = @System;
select * from members where system = @System;
select * from system_guild where system = @System;
select member_guild.* from members inner join member_guild on member_guild.member = members.id where members.system = @System;
", new {System = system.Id});
return new CachedAccount
{
System = system,
Accounts = (await result.ReadAsync<ulong>()).ToArray(),
Members = (await result.ReadAsync<PKMember>()).ToArray(),
SystemGuild = (await result.ReadAsync<SystemGuildSettings>()).ToArray(),
MemberGuild = (await result.ReadAsync<MemberGuildSettings>()).ToArray()
};
}
private string KeyForAccount(ulong account) => $"_account_cache_{account}";
private string KeyForSystem(int system) => $"_system_cache_{system}";
private string KeyForGuild(ulong guild) => $"_guild_cache_{guild}";
}
public class CachedAccount
{
public PKSystem System;
public PKMember[] Members;
public SystemGuildSettings[] SystemGuild;
public MemberGuildSettings[] MemberGuild;
public ulong[] Accounts;
public SystemGuildSettings SettingsForGuild(ulong guild) =>
SystemGuild.FirstOrDefault(s => s.Guild == guild) ?? new SystemGuildSettings();
public MemberGuildSettings SettingsForMemberGuild(int memberId, ulong guild) =>
MemberGuild.FirstOrDefault(m => m.Member == memberId && m.Guild == guild) ?? new MemberGuildSettings();
}
}

View File

@@ -0,0 +1,74 @@
using System;
using System.IO;
using System.Threading.Tasks;
using Dapper;
using Npgsql;
using Serilog;
namespace PluralKit.Core {
public class SchemaService
{
private const int TargetSchemaVersion = 3;
private DbConnectionFactory _conn;
private ILogger _logger;
public SchemaService(DbConnectionFactory conn, ILogger logger)
{
_conn = conn;
_logger = logger.ForContext<SchemaService>();
}
public static void Initialize()
{
// Without these it'll still *work* but break at the first launch + probably cause other small issues
NpgsqlConnection.GlobalTypeMapper.MapComposite<ProxyTag>("proxy_tag");
NpgsqlConnection.GlobalTypeMapper.MapEnum<PrivacyLevel>("privacy_level");
}
public async Task ApplyMigrations()
{
for (var version = 0; version <= TargetSchemaVersion; version++)
await ApplyMigration(version);
}
private async Task ApplyMigration(int migrationId)
{
// migrationId is the *target* version
using var conn = await _conn.Obtain();
using var tx = conn.BeginTransaction();
// See if we even have the info table... if not, we implicitly define the version as -1
// This means migration 0 will get executed, which ensures we're at a consistent state.
// *Technically* this also means schema version 0 will be identified as -1, but since we're only doing these
// checks in the above for loop, this doesn't matter.
var hasInfoTable = await conn.QuerySingleOrDefaultAsync<int>("select count(*) from information_schema.tables where table_name = 'info'") == 1;
int currentVersion;
if (hasInfoTable)
currentVersion = await conn.QuerySingleOrDefaultAsync<int>("select schema_version from info");
else currentVersion = -1;
if (currentVersion >= migrationId)
return; // Don't execute the migration if we're already at the target version.
using var stream = typeof(SchemaService).Assembly.GetManifestResourceStream($"PluralKit.Core.Migrations.{migrationId}.sql");
if (stream == null) throw new ArgumentException("Invalid migration ID");
using var reader = new StreamReader(stream);
var migrationQuery = await reader.ReadToEndAsync();
_logger.Information("Current schema version is {CurrentVersion}, applying migration {MigrationId}", currentVersion, migrationId);
await conn.ExecuteAsync(migrationQuery, transaction: tx);
tx.Commit();
// If the above migration creates new enum/composite types, we must tell Npgsql to reload the internal type caches
// This will propagate to every other connection as well, since it marks the global type mapper collection dirty.
// TODO: find a way to get around the cast to our internal tracker wrapper... this could break if that ever changes
((PerformanceTrackingConnection) conn)._impl.ReloadTypes();
}
}
}