feat: cache Discord DM channels in database

This commit is contained in:
spiral 2022-01-22 02:47:47 -05:00
parent ddbf0e8691
commit 89c44a3482
No known key found for this signature in database
GPG Key ID: A6059F0CA0E1BD31
14 changed files with 127 additions and 48 deletions

View File

@ -20,7 +20,6 @@ public interface IDiscordCache
public Task<ulong> GetOwnUser(); public Task<ulong> GetOwnUser();
public Task<Guild?> TryGetGuild(ulong guildId); public Task<Guild?> TryGetGuild(ulong guildId);
public Task<Channel?> TryGetChannel(ulong channelId); public Task<Channel?> TryGetChannel(ulong channelId);
public Task<Channel?> TryGetDmChannel(ulong userId);
public Task<User?> TryGetUser(ulong userId); public Task<User?> TryGetUser(ulong userId);
public Task<GuildMemberPartial?> TryGetSelfMember(ulong guildId); public Task<GuildMemberPartial?> TryGetSelfMember(ulong guildId);
public Task<Role?> TryGetRole(ulong roleId); public Task<Role?> TryGetRole(ulong roleId);

View File

@ -7,7 +7,6 @@ namespace Myriad.Cache;
public class MemoryDiscordCache: IDiscordCache public class MemoryDiscordCache: IDiscordCache
{ {
private readonly ConcurrentDictionary<ulong, Channel> _channels = new(); private readonly ConcurrentDictionary<ulong, Channel> _channels = new();
private readonly ConcurrentDictionary<ulong, ulong> _dmChannels = new();
private readonly ConcurrentDictionary<ulong, GuildMemberPartial> _guildMembers = new(); private readonly ConcurrentDictionary<ulong, GuildMemberPartial> _guildMembers = new();
private readonly ConcurrentDictionary<ulong, CachedGuild> _guilds = new(); private readonly ConcurrentDictionary<ulong, CachedGuild> _guilds = new();
private readonly ConcurrentDictionary<ulong, Role> _roles = new(); private readonly ConcurrentDictionary<ulong, Role> _roles = new();
@ -35,10 +34,7 @@ public class MemoryDiscordCache: IDiscordCache
if (channel.Recipients != null) if (channel.Recipients != null)
foreach (var recipient in channel.Recipients) foreach (var recipient in channel.Recipients)
{
_dmChannels[recipient.Id] = channel.Id;
await SaveUser(recipient); await SaveUser(recipient);
}
} }
public ValueTask SaveOwnUser(ulong userId) public ValueTask SaveOwnUser(ulong userId)
@ -140,13 +136,6 @@ public class MemoryDiscordCache: IDiscordCache
return Task.FromResult(channel); return Task.FromResult(channel);
} }
public Task<Channel?> TryGetDmChannel(ulong userId)
{
if (!_dmChannels.TryGetValue(userId, out var channelId))
return Task.FromResult((Channel?)null);
return TryGetChannel(channelId);
}
public Task<User?> TryGetUser(ulong userId) public Task<User?> TryGetUser(ulong userId)
{ {
_users.TryGetValue(userId, out var user); _users.TryGetValue(userId, out var user);

View File

@ -58,17 +58,6 @@ public static class CacheExtensions
return restChannel; return restChannel;
} }
public static async Task<Channel> GetOrCreateDmChannel(this IDiscordCache cache, DiscordApiClient rest,
ulong recipientId)
{
if (await cache.TryGetDmChannel(recipientId) is { } cacheChannel)
return cacheChannel;
var restChannel = await rest.CreateDm(recipientId);
await cache.SaveChannel(restChannel);
return restChannel;
}
public static async Task<Channel> GetRootChannel(this IDiscordCache cache, ulong channelOrThread) public static async Task<Channel> GetRootChannel(this IDiscordCache cache, ulong channelOrThread)
{ {
var channel = await cache.GetChannel(channelOrThread); var channel = await cache.GetChannel(channelOrThread);

View File

@ -17,12 +17,14 @@ public class Api
private readonly BotConfig _botConfig; private readonly BotConfig _botConfig;
private readonly DispatchService _dispatch; private readonly DispatchService _dispatch;
private readonly ModelRepository _repo; private readonly ModelRepository _repo;
private readonly PrivateChannelService _dmCache;
public Api(BotConfig botConfig, ModelRepository repo, DispatchService dispatch) public Api(BotConfig botConfig, ModelRepository repo, DispatchService dispatch, PrivateChannelService dmCache)
{ {
_botConfig = botConfig; _botConfig = botConfig;
_repo = repo; _repo = repo;
_dispatch = dispatch; _dispatch = dispatch;
_dmCache = dmCache;
} }
public async Task GetToken(Context ctx) public async Task GetToken(Context ctx)
@ -35,17 +37,17 @@ public class Api
try try
{ {
// DM the user a security disclaimer, and then the token in a separate message (for easy copying on mobile) // DM the user a security disclaimer, and then the token in a separate message (for easy copying on mobile)
var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id); var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id);
await ctx.Rest.CreateMessage(dm.Id, await ctx.Rest.CreateMessage(dm,
new MessageRequest new MessageRequest
{ {
Content = $"{Emojis.Warn} Please note that this grants access to modify (and delete!) all your system data, so keep it safe and secure." Content = $"{Emojis.Warn} Please note that this grants access to modify (and delete!) all your system data, so keep it safe and secure."
+ $" If it leaks or you need a new one, you can invalidate this one with `pk;token refresh`.\n\nYour token is below:" + $" If it leaks or you need a new one, you can invalidate this one with `pk;token refresh`.\n\nYour token is below:"
}); });
await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = token }); await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = token });
if (_botConfig.IsBetaBot) if (_botConfig.IsBetaBot)
await ctx.Rest.CreateMessage(dm.Id, new MessageRequest await ctx.Rest.CreateMessage(dm, new MessageRequest
{ {
Content = $"{Emojis.Note} The beta bot's API base URL is currently <{_botConfig.BetaBotAPIUrl}>." Content = $"{Emojis.Note} The beta bot's API base URL is currently <{_botConfig.BetaBotAPIUrl}>."
+ " You need to use this URL instead of the base URL listed on the documentation website." + " You need to use this URL instead of the base URL listed on the documentation website."
@ -84,8 +86,8 @@ public class Api
try try
{ {
// DM the user an invalidation disclaimer, and then the token in a separate message (for easy copying on mobile) // DM the user an invalidation disclaimer, and then the token in a separate message (for easy copying on mobile)
var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id); var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id);
await ctx.Rest.CreateMessage(dm.Id, await ctx.Rest.CreateMessage(dm,
new MessageRequest new MessageRequest
{ {
Content = $"{Emojis.Warn} Your previous API token has been invalidated. You will need to change it anywhere it's currently used.\n\nYour token is below:" Content = $"{Emojis.Warn} Your previous API token has been invalidated. You will need to change it anywhere it's currently used.\n\nYour token is below:"
@ -94,10 +96,10 @@ public class Api
// Make the new token after sending the first DM; this ensures if we can't DM, we also don't end up // Make the new token after sending the first DM; this ensures if we can't DM, we also don't end up
// breaking their existing token as a side effect :) // breaking their existing token as a side effect :)
var token = await MakeAndSetNewToken(ctx.System); var token = await MakeAndSetNewToken(ctx.System);
await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = token }); await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = token });
if (_botConfig.IsBetaBot) if (_botConfig.IsBetaBot)
await ctx.Rest.CreateMessage(dm.Id, new MessageRequest await ctx.Rest.CreateMessage(dm, new MessageRequest
{ {
Content = $"{Emojis.Note} The beta bot's API base URL is currently <{_botConfig.BetaBotAPIUrl}>." Content = $"{Emojis.Note} The beta bot's API base URL is currently <{_botConfig.BetaBotAPIUrl}>."
+ " You need to use this URL instead of the base URL listed on the documentation website." + " You need to use this URL instead of the base URL listed on the documentation website."

View File

@ -17,6 +17,7 @@ public class ImportExport
{ {
private readonly HttpClient _client; private readonly HttpClient _client;
private readonly DataFileService _dataFiles; private readonly DataFileService _dataFiles;
private readonly PrivateChannelService _dmCache;
private readonly JsonSerializerSettings _settings = new() private readonly JsonSerializerSettings _settings = new()
{ {
@ -24,10 +25,11 @@ public class ImportExport
DateParseHandling = DateParseHandling.None DateParseHandling = DateParseHandling.None
}; };
public ImportExport(DataFileService dataFiles, HttpClient client) public ImportExport(DataFileService dataFiles, HttpClient client, PrivateChannelService dmCache)
{ {
_dataFiles = dataFiles; _dataFiles = dataFiles;
_client = client; _client = client;
_dmCache = dmCache;
} }
public async Task Import(Context ctx) public async Task Import(Context ctx)
@ -110,12 +112,12 @@ public class ImportExport
try try
{ {
var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id); var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id);
var msg = await ctx.Rest.CreateMessage(dm.Id, var msg = await ctx.Rest.CreateMessage(dm,
new MessageRequest { Content = $"{Emojis.Success} Here you go!" }, new MessageRequest { Content = $"{Emojis.Success} Here you go!" },
new[] { new MultipartFile("system.json", stream, null) }); new[] { new MultipartFile("system.json", stream, null) });
await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = $"<{msg.Attachments[0].Url}>" }); await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = $"<{msg.Attachments[0].Url}>" });
// If the original message wasn't posted in DMs, send a public reminder // If the original message wasn't posted in DMs, send a public reminder
if (ctx.Channel.Type != Channel.ChannelType.Dm) if (ctx.Channel.Type != Channel.ChannelType.Dm)

View File

@ -28,12 +28,13 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
private readonly DiscordApiClient _rest; private readonly DiscordApiClient _rest;
private readonly ILifetimeScope _services; private readonly ILifetimeScope _services;
private readonly CommandTree _tree; private readonly CommandTree _tree;
private readonly PrivateChannelService _dmCache;
public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean, public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean,
IMetrics metrics, ProxyService proxy, IMetrics metrics, ProxyService proxy,
CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config, CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config,
ModelRepository repo, IDiscordCache cache, ModelRepository repo, IDiscordCache cache,
Bot bot, Cluster cluster, DiscordApiClient rest) Bot bot, Cluster cluster, DiscordApiClient rest, PrivateChannelService dmCache)
{ {
_lastMessageCache = lastMessageCache; _lastMessageCache = lastMessageCache;
_loggerClean = loggerClean; _loggerClean = loggerClean;
@ -48,6 +49,7 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
_bot = bot; _bot = bot;
_cluster = cluster; _cluster = cluster;
_rest = rest; _rest = rest;
_dmCache = dmCache;
} }
// for now, only return error messages for explicit commands // for now, only return error messages for explicit commands
@ -66,6 +68,10 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
if (evt.Type != Message.MessageType.Default && evt.Type != Message.MessageType.Reply) return; if (evt.Type != Message.MessageType.Default && evt.Type != Message.MessageType.Reply) return;
if (IsDuplicateMessage(evt)) return; if (IsDuplicateMessage(evt)) return;
// spawn off saving the private channel into another thread
// it is not a fatal error if this fails, and it shouldn't block message processing
_ = _dmCache.TrySavePrivateChannel(evt);
var guild = evt.GuildId != null ? await _cache.GetGuild(evt.GuildId.Value) : null; var guild = evt.GuildId != null ? await _cache.GetGuild(evt.GuildId.Value) : null;
var channel = await _cache.GetChannel(evt.ChannelId); var channel = await _cache.GetChannel(evt.ChannelId);
var rootChannel = await _cache.GetRootChannel(evt.ChannelId); var rootChannel = await _cache.GetRootChannel(evt.ChannelId);

View File

@ -26,10 +26,11 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly ModelRepository _repo; private readonly ModelRepository _repo;
private readonly DiscordApiClient _rest; private readonly DiscordApiClient _rest;
private readonly PrivateChannelService _dmCache;
public ReactionAdded(ILogger logger, IDatabase db, ModelRepository repo, public ReactionAdded(ILogger logger, IDatabase db, ModelRepository repo,
CommandMessageService commandMessageService, IDiscordCache cache, Bot bot, Cluster cluster, CommandMessageService commandMessageService, IDiscordCache cache, Bot bot, Cluster cluster,
DiscordApiClient rest, EmbedService embeds) DiscordApiClient rest, EmbedService embeds, PrivateChannelService dmCache)
{ {
_db = db; _db = db;
_repo = repo; _repo = repo;
@ -40,6 +41,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
_rest = rest; _rest = rest;
_embeds = embeds; _embeds = embeds;
_logger = logger.ForContext<ReactionAdded>(); _logger = logger.ForContext<ReactionAdded>();
_dmCache = dmCache;
} }
public async Task Handle(int shardId, MessageReactionAddEvent evt) public async Task Handle(int shardId, MessageReactionAddEvent evt)
@ -168,9 +170,9 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
// Try to DM the user info about the message // Try to DM the user info about the message
try try
{ {
var dm = await _cache.GetOrCreateDmChannel(_rest, evt.UserId); var dm = await _dmCache.GetOrCreateDmChannel(evt.UserId);
if (msg.Member != null) if (msg.Member != null)
await _rest.CreateMessage(dm.Id, new MessageRequest await _rest.CreateMessage(dm, new MessageRequest
{ {
Embed = await _embeds.CreateMemberEmbed( Embed = await _embeds.CreateMemberEmbed(
msg.System, msg.System,
@ -182,7 +184,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
}); });
await _rest.CreateMessage( await _rest.CreateMessage(
dm.Id, dm,
new MessageRequest { Embed = await _embeds.CreateMessageInfoEmbed(msg, true) } new MessageRequest { Embed = await _embeds.CreateMessageInfoEmbed(msg, true) }
); );
} }
@ -234,15 +236,15 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
// If not, tell them in DMs (if we can) // If not, tell them in DMs (if we can)
try try
{ {
var dm = await _cache.GetOrCreateDmChannel(_rest, evt.UserId); var dm = await _dmCache.GetOrCreateDmChannel(evt.UserId);
await _rest.CreateMessage(dm.Id, await _rest.CreateMessage(dm,
new MessageRequest new MessageRequest
{ {
Content = Content =
$"{Emojis.Error} {msg.Member.DisplayName()}'s system has disabled reaction pings. If you want to mention them anyway, you can copy/paste the following message:" $"{Emojis.Error} {msg.Member.DisplayName()}'s system has disabled reaction pings. If you want to mention them anyway, you can copy/paste the following message:"
}); });
await _rest.CreateMessage( await _rest.CreateMessage(
dm.Id, dm,
new MessageRequest { Content = $"<@{msg.Message.Sender}>".AsCode() } new MessageRequest { Content = $"<@{msg.Message.Sender}>".AsCode() }
); );
} }

View File

@ -43,6 +43,7 @@ public class BotModule: Module
}).AsSelf().SingleInstance(); }).AsSelf().SingleInstance();
builder.RegisterType<Cluster>().AsSelf().SingleInstance(); builder.RegisterType<Cluster>().AsSelf().SingleInstance();
builder.Register(c => { return new MemoryDiscordCache(); }).AsSelf().As<IDiscordCache>().SingleInstance(); builder.Register(c => { return new MemoryDiscordCache(); }).AsSelf().As<IDiscordCache>().SingleInstance();
builder.RegisterType<PrivateChannelService>().AsSelf().SingleInstance();
builder.Register(c => builder.Register(c =>
{ {

View File

@ -0,0 +1,63 @@
using Serilog;
using Myriad.Cache;
using Myriad.Gateway;
using Myriad.Rest;
using PluralKit.Core;
namespace PluralKit.Bot;
public class PrivateChannelService
{
private readonly ILogger _logger;
private readonly ModelRepository _repo;
private readonly DiscordApiClient _rest;
private static Dictionary<ulong, ulong> _channelsCache = new();
public PrivateChannelService(ILogger logger, ModelRepository repo, DiscordApiClient rest)
{
_logger = logger;
_repo = repo;
_rest = rest;
}
public async Task TrySavePrivateChannel(MessageCreateEvent evt)
{
if (evt.GuildId != null) return;
if (_channelsCache.TryGetValue(evt.Author.Id, out _)) return;
await SaveDmChannel(evt.Author.Id, evt.ChannelId);
}
public async Task<ulong> GetOrCreateDmChannel(ulong userId)
{
if (_channelsCache.TryGetValue(userId, out var cachedChannelId))
return cachedChannelId;
var channelId = await _repo.GetDmChannel(userId);
if (channelId == null)
{
var channel = await _rest.CreateDm(userId);
channelId = channel.Id;
}
// spawn off saving the channel as to not block the current thread
_ = SaveDmChannel(userId, channelId.Value);
return channelId.Value;
}
private async Task SaveDmChannel(ulong userId, ulong channelId)
{
try
{
_channelsCache.Add(userId, channelId);
await _repo.UpdateAccount(userId, new() { DmChannel = channelId });
}
catch (Exception e)
{
_logger.Error(e, "Failed to save DM channel {ChannelId} for user {UserId}", channelId, userId);
}
}
}

View File

@ -0,0 +1,12 @@
-- schema version 26
-- cache Discord DM channels in the database
alter table accounts alter column system drop not null;
alter table accounts drop constraint accounts_system_fkey;
alter table accounts
add constraint accounts_system_fkey
foreign key (system) references systems(id) on delete set null;
alter table accounts add column dm_channel bigint;
update info set schema_version = 26;

View File

@ -1,9 +1,14 @@
using Dapper;
using SqlKata; using SqlKata;
namespace PluralKit.Core; namespace PluralKit.Core;
public partial class ModelRepository public partial class ModelRepository
{ {
public async Task<ulong?> GetDmChannel(ulong id)
=> await _db.Execute(c => c.QueryFirstOrDefaultAsync<ulong?>("select dm_channel from accounts where uid = @id", new { id = id }));
public async Task UpdateAccount(ulong id, AccountPatch patch) public async Task UpdateAccount(ulong id, AccountPatch patch)
{ {
_logger.Information("Updated account {accountId}: {@AccountPatch}", id, patch); _logger.Information("Updated account {accountId}: {@AccountPatch}", id, patch);

View File

@ -24,7 +24,8 @@ public partial class ModelRepository
var query = new Query("accounts") var query = new Query("accounts")
.Select("systems.*") .Select("systems.*")
.LeftJoin("systems", "systems.id", "accounts.system") .LeftJoin("systems", "systems.id", "accounts.system")
.Where("uid", accountId); .Where("uid", accountId)
.WhereNotNull("system");
return _db.QueryFirst<PKSystem?>(query); return _db.QueryFirst<PKSystem?>(query);
} }
@ -111,10 +112,13 @@ public partial class ModelRepository
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent // We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
// This is used in import/export, although the pk;link command checks for this case beforehand // This is used in import/export, although the pk;link command checks for this case beforehand
// update 2022-01: the accounts table is now independent of systems
// we MUST check for the presence of a system before inserting, or it will move the new account to the current system
var query = new Query("accounts").AsInsert(new { system, uid = accountId }); var query = new Query("accounts").AsInsert(new { system, uid = accountId });
await _db.ExecuteQuery(conn, query, "on conflict (uid) do update set system = @p0");
_logger.Information("Linked account {UserId} to {SystemId}", accountId, system); _logger.Information("Linked account {UserId} to {SystemId}", accountId, system);
await _db.ExecuteQuery(conn, query, "on conflict do nothing");
_ = _dispatch.Dispatch(system, new UpdateDispatchData _ = _dispatch.Dispatch(system, new UpdateDispatchData
{ {
@ -125,7 +129,10 @@ public partial class ModelRepository
public async Task RemoveAccount(SystemId system, ulong accountId) public async Task RemoveAccount(SystemId system, ulong accountId)
{ {
var query = new Query("accounts").AsDelete().Where("uid", accountId).Where("system", system); var query = new Query("accounts").AsUpdate(new
{
system = (ulong?)null
}).Where("uid", accountId).Where("system", system);
await _db.ExecuteQuery(query); await _db.ExecuteQuery(query);
_logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system); _logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system);
_ = _dispatch.Dispatch(system, new UpdateDispatchData _ = _dispatch.Dispatch(system, new UpdateDispatchData

View File

@ -9,7 +9,7 @@ namespace PluralKit.Core;
internal class DatabaseMigrator internal class DatabaseMigrator
{ {
private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files
private const int TargetSchemaVersion = 25; private const int TargetSchemaVersion = 26;
private readonly ILogger _logger; private readonly ILogger _logger;
public DatabaseMigrator(ILogger logger) public DatabaseMigrator(ILogger logger)

View File

@ -6,9 +6,11 @@ namespace PluralKit.Core;
public class AccountPatch: PatchObject public class AccountPatch: PatchObject
{ {
public Partial<ulong> DmChannel { get; set; }
public Partial<bool> AllowAutoproxy { get; set; } public Partial<bool> AllowAutoproxy { get; set; }
public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper
.With("dm_channel", DmChannel)
.With("allow_autoproxy", AllowAutoproxy) .With("allow_autoproxy", AllowAutoproxy)
); );