Migrate to type-safe model ID structs

This commit is contained in:
Ske 2020-06-14 21:37:04 +02:00
parent e5ac5edc35
commit b9cbd241de
21 changed files with 167 additions and 41 deletions

View File

@ -276,7 +276,7 @@ namespace PluralKit.Bot
public LookupContext LookupContextFor(PKSystem target) =>
System?.Id == target.Id ? LookupContext.ByOwner : LookupContext.ByNonOwner;
public LookupContext LookupContextFor(int systemId) =>
public LookupContext LookupContextFor(SystemId systemId) =>
System?.Id == systemId ? LookupContext.ByOwner : LookupContext.ByNonOwner;
public Context CheckSystemPrivacy(PKSystem target, PrivacyLevel level)

View File

@ -87,7 +87,7 @@ namespace PluralKit.Bot
var fronters = ctx.MessageContext.LastSwitchMembers;
var relevantMember = ctx.MessageContext.AutoproxyMode switch
{
AutoproxyMode.Front => fronters.Count > 0 ? await _db.Execute(c => c.QueryMember(fronters[0])) : null,
AutoproxyMode.Front => fronters.Length > 0 ? await _db.Execute(c => c.QueryMember(fronters[0])) : null,
AutoproxyMode.Member => await _db.Execute(c => c.QueryMember(ctx.MessageContext.AutoproxyMember.Value)),
_ => null
};
@ -97,7 +97,7 @@ namespace PluralKit.Bot
break;
case AutoproxyMode.Front:
{
if (fronters.Count == 0)
if (fronters.Length == 0)
eb.WithDescription("Autoproxy is currently set to **front mode** in this server, but there are currently no fronters registered. Use the `pk;switch` command to log a switch.");
else
{
@ -123,7 +123,7 @@ namespace PluralKit.Bot
return eb.Build();
}
private Task UpdateAutoproxy(Context ctx, AutoproxyMode autoproxyMode, int? autoproxyMember) =>
private Task UpdateAutoproxy(Context ctx, AutoproxyMode autoproxyMode, MemberId? autoproxyMember) =>
_db.Execute(c =>
c.ExecuteAsync(
"update system_guild set autoproxy_mode = @autoproxyMode, autoproxy_member = @autoproxyMember where guild = @guild and system = @system",

View File

@ -48,7 +48,7 @@ namespace PluralKit.Bot
AutoproxyMode.Member when ctx.AutoproxyMember != null =>
members.FirstOrDefault(m => m.Id == ctx.AutoproxyMember),
AutoproxyMode.Front when ctx.LastSwitchMembers.Count > 0 =>
AutoproxyMode.Front when ctx.LastSwitchMembers.Length > 0 =>
members.FirstOrDefault(m => m.Id == ctx.LastSwitchMembers[0]),
AutoproxyMode.Latch when ctx.LastMessageMember != null && !IsLatchExpired(ctx.LastMessage) =>

View File

@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using App.Metrics;
@ -36,13 +38,14 @@ namespace PluralKit.Core
public static void InitStatic()
{
DefaultTypeMap.MatchNamesWithUnderscores = true;
// Dapper by default tries to pass ulongs to Npgsql, which rejects them since PostgreSQL technically
// doesn't support unsigned types on its own.
// Instead we add a custom mapper to encode them as signed integers instead, converting them back and forth.
SqlMapper.RemoveTypeMap(typeof(ulong));
SqlMapper.AddTypeHandler(new UlongEncodeAsLongHandler());
SqlMapper.AddTypeHandler(new UlongArrayHandler());
DefaultTypeMap.MatchNamesWithUnderscores = true;
NpgsqlConnection.GlobalTypeMapper.UseNodaTime();
// With the thing we add above, Npgsql already handles NodaTime integration
@ -51,6 +54,14 @@ namespace PluralKit.Core
SqlMapper.AddTypeHandler(new PassthroughTypeHandler<Instant>());
SqlMapper.AddTypeHandler(new PassthroughTypeHandler<LocalDate>());
// Add ID types to Dapper
SqlMapper.AddTypeHandler(new NumericIdHandler<SystemId, int>(i => new SystemId(i)));
SqlMapper.AddTypeHandler(new NumericIdHandler<MemberId, int>(i => new MemberId(i)));
SqlMapper.AddTypeHandler(new NumericIdHandler<SwitchId, int>(i => new SwitchId(i)));
SqlMapper.AddTypeHandler(new NumericIdArrayHandler<SystemId, int>(i => new SystemId(i)));
SqlMapper.AddTypeHandler(new NumericIdArrayHandler<MemberId, int>(i => new MemberId(i)));
SqlMapper.AddTypeHandler(new NumericIdArrayHandler<SwitchId, int>(i => new SwitchId(i)));
// Register our custom types to Npgsql
// Without these it'll still *work* but break at the first launch + probably cause other small issues
NpgsqlConnection.GlobalTypeMapper.MapComposite<ProxyTag>("proxy_tag");
@ -153,5 +164,37 @@ namespace PluralKit.Core
public override ulong[] Parse(object value) => Array.ConvertAll((long[]) value, i => (ulong) i);
}
private class NumericIdHandler<T, TInner>: SqlMapper.TypeHandler<T>
where T: INumericId<T, TInner>
where TInner: IEquatable<TInner>, IComparable<TInner>
{
private readonly Func<TInner, T> _factory;
public NumericIdHandler(Func<TInner, T> factory)
{
_factory = factory;
}
public override void SetValue(IDbDataParameter parameter, T value) => parameter.Value = value.Value;
public override T Parse(object value) => _factory((TInner) value);
}
private class NumericIdArrayHandler<T, TInner>: SqlMapper.TypeHandler<T[]>
where T: INumericId<T, TInner>
where TInner: IEquatable<TInner>, IComparable<TInner>
{
private readonly Func<TInner, T> _factory;
public NumericIdArrayHandler(Func<TInner, T> factory)
{
_factory = factory;
}
public override void SetValue(IDbDataParameter parameter, T[] value) => parameter.Value = Array.ConvertAll(value, v => v.Value);
public override T[] Parse(object value) => Array.ConvertAll((TInner[]) value, v => _factory(v));
}
}
}

View File

@ -10,18 +10,18 @@ namespace PluralKit.Core
/// </summary>
public class MessageContext
{
public int? SystemId { get; }
public SystemId? SystemId { get; }
public ulong? LogChannel { get; }
public bool InBlacklist { get; }
public bool InLogBlacklist { get; }
public bool LogCleanupEnabled { get; }
public bool ProxyEnabled { get; }
public AutoproxyMode AutoproxyMode { get; }
public int? AutoproxyMember { get; }
public MemberId? AutoproxyMember { get; }
public ulong? LastMessage { get; }
public int? LastMessageMember { get; }
public int LastSwitch { get; }
public IReadOnlyList<int> LastSwitchMembers { get; } = new int[0];
public MemberId? LastMessageMember { get; }
public SwitchId LastSwitch { get; }
public MemberId[] LastSwitchMembers { get; } = new MemberId[0];
public Instant LastSwitchTimestamp { get; }
public string? SystemTag { get; }
public string? SystemAvatar { get; }

View File

@ -8,7 +8,7 @@ namespace PluralKit.Core
/// </summary>
public class ProxyMember
{
public int Id { get; }
public MemberId Id { get; }
public IReadOnlyCollection<ProxyTag> ProxyTags { get; } = new ProxyTag[0];
public bool KeepProxy { get; }

View File

@ -9,10 +9,10 @@ namespace PluralKit.Core
{
public static class DatabaseViewsExt
{
public static Task<IEnumerable<SystemFronter>> QueryCurrentFronters(this IPKConnection conn, int system) =>
public static Task<IEnumerable<SystemFronter>> QueryCurrentFronters(this IPKConnection conn, SystemId system) =>
conn.QueryAsync<SystemFronter>("select * from system_fronters where system = @system", new {system});
public static Task<IEnumerable<ListedMember>> QueryMemberList(this IPKConnection conn, int system, PrivacyLevel? privacyFilter = null, string? filter = null, bool includeDescriptionInNameFilter = false)
public static Task<IEnumerable<ListedMember>> QueryMemberList(this IPKConnection conn, SystemId system, PrivacyLevel? privacyFilter = null, string? filter = null, bool includeDescriptionInNameFilter = false)
{
StringBuilder query = new StringBuilder("select * from member_list where system = @system");

View File

@ -4,10 +4,10 @@ namespace PluralKit.Core
{
public class SystemFronter
{
public int SystemId { get; }
public int SwitchId { get; }
public SystemId SystemId { get; }
public SwitchId SwitchId { get; }
public Instant SwitchTimestamp { get; }
public int MemberId { get; }
public MemberId MemberId { get; }
public string MemberHid { get; }
public string MemberName { get; }
}

View File

@ -0,0 +1,11 @@
using System;
namespace PluralKit.Core
{
public interface INumericId<T, out TInner>: IEquatable<T>, IComparable<T>
where T: INumericId<T, TInner>
where TInner: IEquatable<TInner>, IComparable<TInner>
{
public TInner Value { get; }
}
}

View File

@ -3,7 +3,7 @@ namespace PluralKit.Core
{
public class MemberGuildSettings
{
public int Member { get; }
public MemberId Member { get; }
public ulong Guild { get; }
public string? DisplayName { get; }
public string? AvatarUrl { get; }

View File

@ -0,0 +1,24 @@
namespace PluralKit.Core
{
public readonly struct MemberId: INumericId<MemberId, int>
{
public int Value { get; }
public MemberId(int value)
{
Value = value;
}
public bool Equals(MemberId other) => Value == other.Value;
public override bool Equals(object obj) => obj is MemberId other && Equals(other);
public override int GetHashCode() => Value;
public static bool operator ==(MemberId left, MemberId right) => left.Equals(right);
public static bool operator !=(MemberId left, MemberId right) => !left.Equals(right);
public int CompareTo(MemberId other) => Value.CompareTo(other.Value);
}
}

View File

@ -7,22 +7,22 @@ namespace PluralKit.Core
{
public static class ModelQueryExt
{
public static Task<PKSystem?> QuerySystem(this IPKConnection conn, int id) =>
public static Task<PKSystem?> QuerySystem(this IPKConnection conn, SystemId id) =>
conn.QueryFirstOrDefaultAsync<PKSystem?>("select * from systems where id = @id", new {id});
public static Task<PKMember?> QueryMember(this IPKConnection conn, int id) =>
public static Task<PKMember?> QueryMember(this IPKConnection conn, MemberId id) =>
conn.QueryFirstOrDefaultAsync<PKMember?>("select * from members where id = @id", new {id});
public static Task<GuildConfig> QueryOrInsertGuildConfig(this IPKConnection conn, ulong guild) =>
conn.QueryFirstAsync<GuildConfig>("insert into servers (id) values (@guild) on conflict (id) do update set id = @guild returning *", new {guild});
public static Task<SystemGuildSettings> QueryOrInsertSystemGuildConfig(this IPKConnection conn, ulong guild, int system) =>
public static Task<SystemGuildSettings> QueryOrInsertSystemGuildConfig(this IPKConnection conn, ulong guild, SystemId system) =>
conn.QueryFirstAsync<SystemGuildSettings>(
"insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *",
new {guild, system});
public static Task<MemberGuildSettings> QueryOrInsertMemberGuildConfig(
this IPKConnection conn, ulong guild, int member) =>
this IPKConnection conn, ulong guild, MemberId member) =>
conn.QueryFirstAsync<MemberGuildSettings>(
"insert into member_guild (guild, member) values (@guild, @member) on conflict (guild, member) do update set guild = @guild, member = @member returning *",
new {guild, member});

View File

@ -8,9 +8,9 @@ using NodaTime.Text;
namespace PluralKit.Core {
public class PKMember
{
public int Id { get; }
public MemberId Id { get; }
public string Hid { get; set; }
public int System { get; set; }
public SystemId System { get; set; }
public string Color { get; set; }
public string AvatarUrl { get; set; }
public string Name { get; set; }

View File

@ -3,8 +3,8 @@
namespace PluralKit.Core {
public class PKSwitch
{
public int Id { get; }
public int System { get; set; }
public SwitchId Id { get; }
public SystemId System { get; set; }
public Instant Timestamp { get; }
}
}

View File

@ -8,7 +8,7 @@ namespace PluralKit.Core {
public class PKSystem
{
// Additions here should be mirrored in SystemStore::Save
[Key] public int Id { get; }
[Key] public SystemId Id { get; }
public string Hid { get; }
public string Name { get; set; }
public string Description { get; set; }

View File

@ -0,0 +1,24 @@
namespace PluralKit.Core
{
public readonly struct SwitchId: INumericId<SwitchId, int>
{
public int Value { get; }
public SwitchId(int value)
{
Value = value;
}
public bool Equals(SwitchId other) => Value == other.Value;
public override bool Equals(object obj) => obj is SwitchId other && Equals(other);
public override int GetHashCode() => Value;
public static bool operator ==(SwitchId left, SwitchId right) => left.Equals(right);
public static bool operator !=(SwitchId left, SwitchId right) => !left.Equals(right);
public int CompareTo(SwitchId other) => Value.CompareTo(other.Value);
}
}

View File

@ -2,10 +2,10 @@
{
public class SystemGuildSettings
{
public ulong Guild { get; }
public SystemId Guild { get; }
public bool ProxyEnabled { get; } = true;
public AutoproxyMode AutoproxyMode { get; } = AutoproxyMode.Off;
public int? AutoproxyMember { get; }
public MemberId? AutoproxyMember { get; }
}
}

View File

@ -0,0 +1,24 @@
namespace PluralKit.Core
{
public readonly struct SystemId: INumericId<SystemId, int>
{
public int Value { get; }
public SystemId(int value)
{
Value = value;
}
public bool Equals(SystemId other) => Value == other.Value;
public override bool Equals(object obj) => obj is SystemId other && Equals(other);
public override int GetHashCode() => Value;
public static bool operator ==(SystemId left, SystemId right) => left.Equals(right);
public static bool operator !=(SystemId left, SystemId right) => !left.Equals(right);
public int CompareTo(SystemId other) => Value.CompareTo(other.Value);
}
}

View File

@ -45,7 +45,7 @@ namespace PluralKit.Core {
public struct SwitchMembersListEntry
{
public int Member;
public MemberId Member;
public Instant Timestamp;
}
@ -131,7 +131,7 @@ namespace PluralKit.Core {
/// Gets a system by its internal member ID.
/// </summary>
/// <returns>The <see cref="PKMember"/> with the given internal ID, or null if no member was found.</returns>
Task<PKMember> GetMemberById(int memberId);
Task<PKMember> GetMemberById(MemberId memberId);
/// <summary>
/// Gets a member by its user-facing human ID.
@ -195,7 +195,7 @@ namespace PluralKit.Core {
/// <param name="triggerMessageId">The ID of the original trigger message containing the proxy tags.</param>
/// <param name="proxiedMemberId">The member (and by extension system) that was proxied.</param>
/// <returns></returns>
Task AddMessage(IPKConnection conn, ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId);
Task AddMessage(IPKConnection conn, ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, MemberId proxiedMemberId);
/// <summary>
/// Deletes a message from the data store.

View File

@ -125,7 +125,7 @@ namespace PluralKit.Core {
return member;
}
public async Task<PKMember> GetMemberById(int id) {
public async Task<PKMember> GetMemberById(MemberId id) {
using (var conn = await _conn.Obtain())
return await conn.QuerySingleOrDefaultAsync<PKMember>("select * from members where id = @Id", new { Id = id });
}
@ -177,7 +177,7 @@ namespace PluralKit.Core {
return await conn.ExecuteScalarAsync<ulong>("select count(id) from members");
}
public async Task AddMessage(IPKConnection conn, ulong senderId, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId) {
public async Task AddMessage(IPKConnection conn, ulong senderId, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, MemberId proxiedMemberId) {
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@MessageId, @GuildId, @ChannelId, @MemberId, @SenderId, @OriginalMid) on conflict do nothing", new {
MessageId = postedMessageId,
@ -334,7 +334,7 @@ namespace PluralKit.Core {
// query DB for all members involved in any of the switches above and collect into a dictionary for future use
// this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary
// key used in GetPerMemberSwitchDuration below
Dictionary<int, PKMember> memberObjects;
Dictionary<MemberId, PKMember> memberObjects;
using (var conn = await _conn.Obtain())
{
memberObjects = (
@ -351,7 +351,7 @@ namespace PluralKit.Core {
select new SwitchListEntry
{
TimespanStart = g.Key,
Members = g.Where(x => x.Member != 0).Select(x => memberObjects[x.Member]).ToList()
Members = g.Where(x => x.Member != default(MemberId)).Select(x => memberObjects[x.Member]).ToList()
};
// Loop through every switch that overlaps the range and add it to the output list

View File

@ -16,14 +16,14 @@ namespace PluralKit.Core
{
public class BulkImporter: IAsyncDisposable
{
private readonly int _systemId;
private readonly SystemId _systemId;
private readonly IPKConnection _conn;
private readonly IPKTransaction _tx;
private readonly Dictionary<string, int> _knownMembers = new Dictionary<string, int>();
private readonly Dictionary<string, MemberId> _knownMembers = new Dictionary<string, MemberId>();
private readonly Dictionary<string, PKMember> _existingMembersByHid = new Dictionary<string, PKMember>();
private readonly Dictionary<string, PKMember> _existingMembersByName = new Dictionary<string, PKMember>();
private BulkImporter(int systemId, IPKConnection conn, IPKTransaction tx)
private BulkImporter(SystemId systemId, IPKConnection conn, IPKTransaction tx)
{
_systemId = systemId;
_conn = conn;
@ -124,7 +124,7 @@ namespace PluralKit.Core
// Fetch the existing switches in the database so we can avoid duplicates
var existingSwitches = (await _conn.QueryAsync<PKSwitch>("select * from switches where system = @System", new {System = _systemId})).ToList();
var existingTimestamps = existingSwitches.Select(sw => sw.Timestamp).ToImmutableHashSet();
var lastSwitchId = existingSwitches.Count != 0 ? existingSwitches.Select(sw => sw.Id).Max() : -1;
var lastSwitchId = existingSwitches.Count != 0 ? existingSwitches.Select(sw => sw.Id).Max() : (SwitchId?) null;
// Import switch definitions
var importedSwitches = new Dictionary<Instant, SwitchInfo>();
@ -152,7 +152,7 @@ namespace PluralKit.Core
// IDs are sequential, so any ID in this system, with a switch ID > the last max, will be one we just added
var justAddedSwitches = await _conn.QueryAsync<PKSwitch>(
"select * from switches where system = @System and id > @LastSwitchId",
new {System = _systemId, LastSwitchId = lastSwitchId});
new {System = _systemId, LastSwitchId = lastSwitchId?.Value ?? -1});
// Lastly, import the switch members
await using (var importer = _conn.BeginBinaryImport("copy switch_members (switch, member) from stdin (format binary)"))