feat: cache Discord DM channels in database
This commit is contained in:
parent
ddbf0e8691
commit
89c44a3482
@ -20,7 +20,6 @@ public interface IDiscordCache
|
||||
public Task<ulong> GetOwnUser();
|
||||
public Task<Guild?> TryGetGuild(ulong guildId);
|
||||
public Task<Channel?> TryGetChannel(ulong channelId);
|
||||
public Task<Channel?> TryGetDmChannel(ulong userId);
|
||||
public Task<User?> TryGetUser(ulong userId);
|
||||
public Task<GuildMemberPartial?> TryGetSelfMember(ulong guildId);
|
||||
public Task<Role?> TryGetRole(ulong roleId);
|
||||
|
@ -7,7 +7,6 @@ namespace Myriad.Cache;
|
||||
public class MemoryDiscordCache: IDiscordCache
|
||||
{
|
||||
private readonly ConcurrentDictionary<ulong, Channel> _channels = new();
|
||||
private readonly ConcurrentDictionary<ulong, ulong> _dmChannels = new();
|
||||
private readonly ConcurrentDictionary<ulong, GuildMemberPartial> _guildMembers = new();
|
||||
private readonly ConcurrentDictionary<ulong, CachedGuild> _guilds = new();
|
||||
private readonly ConcurrentDictionary<ulong, Role> _roles = new();
|
||||
@ -35,10 +34,7 @@ public class MemoryDiscordCache: IDiscordCache
|
||||
|
||||
if (channel.Recipients != null)
|
||||
foreach (var recipient in channel.Recipients)
|
||||
{
|
||||
_dmChannels[recipient.Id] = channel.Id;
|
||||
await SaveUser(recipient);
|
||||
}
|
||||
}
|
||||
|
||||
public ValueTask SaveOwnUser(ulong userId)
|
||||
@ -140,13 +136,6 @@ public class MemoryDiscordCache: IDiscordCache
|
||||
return Task.FromResult(channel);
|
||||
}
|
||||
|
||||
public Task<Channel?> TryGetDmChannel(ulong userId)
|
||||
{
|
||||
if (!_dmChannels.TryGetValue(userId, out var channelId))
|
||||
return Task.FromResult((Channel?)null);
|
||||
return TryGetChannel(channelId);
|
||||
}
|
||||
|
||||
public Task<User?> TryGetUser(ulong userId)
|
||||
{
|
||||
_users.TryGetValue(userId, out var user);
|
||||
|
@ -58,17 +58,6 @@ public static class CacheExtensions
|
||||
return restChannel;
|
||||
}
|
||||
|
||||
public static async Task<Channel> GetOrCreateDmChannel(this IDiscordCache cache, DiscordApiClient rest,
|
||||
ulong recipientId)
|
||||
{
|
||||
if (await cache.TryGetDmChannel(recipientId) is { } cacheChannel)
|
||||
return cacheChannel;
|
||||
|
||||
var restChannel = await rest.CreateDm(recipientId);
|
||||
await cache.SaveChannel(restChannel);
|
||||
return restChannel;
|
||||
}
|
||||
|
||||
public static async Task<Channel> GetRootChannel(this IDiscordCache cache, ulong channelOrThread)
|
||||
{
|
||||
var channel = await cache.GetChannel(channelOrThread);
|
||||
|
@ -17,12 +17,14 @@ public class Api
|
||||
private readonly BotConfig _botConfig;
|
||||
private readonly DispatchService _dispatch;
|
||||
private readonly ModelRepository _repo;
|
||||
private readonly PrivateChannelService _dmCache;
|
||||
|
||||
public Api(BotConfig botConfig, ModelRepository repo, DispatchService dispatch)
|
||||
public Api(BotConfig botConfig, ModelRepository repo, DispatchService dispatch, PrivateChannelService dmCache)
|
||||
{
|
||||
_botConfig = botConfig;
|
||||
_repo = repo;
|
||||
_dispatch = dispatch;
|
||||
_dmCache = dmCache;
|
||||
}
|
||||
|
||||
public async Task GetToken(Context ctx)
|
||||
@ -35,17 +37,17 @@ public class Api
|
||||
try
|
||||
{
|
||||
// DM the user a security disclaimer, and then the token in a separate message (for easy copying on mobile)
|
||||
var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id);
|
||||
await ctx.Rest.CreateMessage(dm.Id,
|
||||
var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id);
|
||||
await ctx.Rest.CreateMessage(dm,
|
||||
new MessageRequest
|
||||
{
|
||||
Content = $"{Emojis.Warn} Please note that this grants access to modify (and delete!) all your system data, so keep it safe and secure."
|
||||
+ $" If it leaks or you need a new one, you can invalidate this one with `pk;token refresh`.\n\nYour token is below:"
|
||||
});
|
||||
await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = token });
|
||||
await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = token });
|
||||
|
||||
if (_botConfig.IsBetaBot)
|
||||
await ctx.Rest.CreateMessage(dm.Id, new MessageRequest
|
||||
await ctx.Rest.CreateMessage(dm, new MessageRequest
|
||||
{
|
||||
Content = $"{Emojis.Note} The beta bot's API base URL is currently <{_botConfig.BetaBotAPIUrl}>."
|
||||
+ " You need to use this URL instead of the base URL listed on the documentation website."
|
||||
@ -84,8 +86,8 @@ public class Api
|
||||
try
|
||||
{
|
||||
// DM the user an invalidation disclaimer, and then the token in a separate message (for easy copying on mobile)
|
||||
var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id);
|
||||
await ctx.Rest.CreateMessage(dm.Id,
|
||||
var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id);
|
||||
await ctx.Rest.CreateMessage(dm,
|
||||
new MessageRequest
|
||||
{
|
||||
Content = $"{Emojis.Warn} Your previous API token has been invalidated. You will need to change it anywhere it's currently used.\n\nYour token is below:"
|
||||
@ -94,10 +96,10 @@ public class Api
|
||||
// Make the new token after sending the first DM; this ensures if we can't DM, we also don't end up
|
||||
// breaking their existing token as a side effect :)
|
||||
var token = await MakeAndSetNewToken(ctx.System);
|
||||
await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = token });
|
||||
await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = token });
|
||||
|
||||
if (_botConfig.IsBetaBot)
|
||||
await ctx.Rest.CreateMessage(dm.Id, new MessageRequest
|
||||
await ctx.Rest.CreateMessage(dm, new MessageRequest
|
||||
{
|
||||
Content = $"{Emojis.Note} The beta bot's API base URL is currently <{_botConfig.BetaBotAPIUrl}>."
|
||||
+ " You need to use this URL instead of the base URL listed on the documentation website."
|
||||
|
@ -17,6 +17,7 @@ public class ImportExport
|
||||
{
|
||||
private readonly HttpClient _client;
|
||||
private readonly DataFileService _dataFiles;
|
||||
private readonly PrivateChannelService _dmCache;
|
||||
|
||||
private readonly JsonSerializerSettings _settings = new()
|
||||
{
|
||||
@ -24,10 +25,11 @@ public class ImportExport
|
||||
DateParseHandling = DateParseHandling.None
|
||||
};
|
||||
|
||||
public ImportExport(DataFileService dataFiles, HttpClient client)
|
||||
public ImportExport(DataFileService dataFiles, HttpClient client, PrivateChannelService dmCache)
|
||||
{
|
||||
_dataFiles = dataFiles;
|
||||
_client = client;
|
||||
_dmCache = dmCache;
|
||||
}
|
||||
|
||||
public async Task Import(Context ctx)
|
||||
@ -110,12 +112,12 @@ public class ImportExport
|
||||
|
||||
try
|
||||
{
|
||||
var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id);
|
||||
var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id);
|
||||
|
||||
var msg = await ctx.Rest.CreateMessage(dm.Id,
|
||||
var msg = await ctx.Rest.CreateMessage(dm,
|
||||
new MessageRequest { Content = $"{Emojis.Success} Here you go!" },
|
||||
new[] { new MultipartFile("system.json", stream, null) });
|
||||
await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = $"<{msg.Attachments[0].Url}>" });
|
||||
await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = $"<{msg.Attachments[0].Url}>" });
|
||||
|
||||
// If the original message wasn't posted in DMs, send a public reminder
|
||||
if (ctx.Channel.Type != Channel.ChannelType.Dm)
|
||||
|
@ -28,12 +28,13 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
|
||||
private readonly DiscordApiClient _rest;
|
||||
private readonly ILifetimeScope _services;
|
||||
private readonly CommandTree _tree;
|
||||
private readonly PrivateChannelService _dmCache;
|
||||
|
||||
public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean,
|
||||
IMetrics metrics, ProxyService proxy,
|
||||
CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config,
|
||||
ModelRepository repo, IDiscordCache cache,
|
||||
Bot bot, Cluster cluster, DiscordApiClient rest)
|
||||
Bot bot, Cluster cluster, DiscordApiClient rest, PrivateChannelService dmCache)
|
||||
{
|
||||
_lastMessageCache = lastMessageCache;
|
||||
_loggerClean = loggerClean;
|
||||
@ -48,6 +49,7 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
|
||||
_bot = bot;
|
||||
_cluster = cluster;
|
||||
_rest = rest;
|
||||
_dmCache = dmCache;
|
||||
}
|
||||
|
||||
// for now, only return error messages for explicit commands
|
||||
@ -66,6 +68,10 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
|
||||
if (evt.Type != Message.MessageType.Default && evt.Type != Message.MessageType.Reply) return;
|
||||
if (IsDuplicateMessage(evt)) return;
|
||||
|
||||
// spawn off saving the private channel into another thread
|
||||
// it is not a fatal error if this fails, and it shouldn't block message processing
|
||||
_ = _dmCache.TrySavePrivateChannel(evt);
|
||||
|
||||
var guild = evt.GuildId != null ? await _cache.GetGuild(evt.GuildId.Value) : null;
|
||||
var channel = await _cache.GetChannel(evt.ChannelId);
|
||||
var rootChannel = await _cache.GetRootChannel(evt.ChannelId);
|
||||
|
@ -26,10 +26,11 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
|
||||
private readonly ILogger _logger;
|
||||
private readonly ModelRepository _repo;
|
||||
private readonly DiscordApiClient _rest;
|
||||
private readonly PrivateChannelService _dmCache;
|
||||
|
||||
public ReactionAdded(ILogger logger, IDatabase db, ModelRepository repo,
|
||||
CommandMessageService commandMessageService, IDiscordCache cache, Bot bot, Cluster cluster,
|
||||
DiscordApiClient rest, EmbedService embeds)
|
||||
DiscordApiClient rest, EmbedService embeds, PrivateChannelService dmCache)
|
||||
{
|
||||
_db = db;
|
||||
_repo = repo;
|
||||
@ -40,6 +41,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
|
||||
_rest = rest;
|
||||
_embeds = embeds;
|
||||
_logger = logger.ForContext<ReactionAdded>();
|
||||
_dmCache = dmCache;
|
||||
}
|
||||
|
||||
public async Task Handle(int shardId, MessageReactionAddEvent evt)
|
||||
@ -168,9 +170,9 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
|
||||
// Try to DM the user info about the message
|
||||
try
|
||||
{
|
||||
var dm = await _cache.GetOrCreateDmChannel(_rest, evt.UserId);
|
||||
var dm = await _dmCache.GetOrCreateDmChannel(evt.UserId);
|
||||
if (msg.Member != null)
|
||||
await _rest.CreateMessage(dm.Id, new MessageRequest
|
||||
await _rest.CreateMessage(dm, new MessageRequest
|
||||
{
|
||||
Embed = await _embeds.CreateMemberEmbed(
|
||||
msg.System,
|
||||
@ -182,7 +184,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
|
||||
});
|
||||
|
||||
await _rest.CreateMessage(
|
||||
dm.Id,
|
||||
dm,
|
||||
new MessageRequest { Embed = await _embeds.CreateMessageInfoEmbed(msg, true) }
|
||||
);
|
||||
}
|
||||
@ -234,15 +236,15 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
|
||||
// If not, tell them in DMs (if we can)
|
||||
try
|
||||
{
|
||||
var dm = await _cache.GetOrCreateDmChannel(_rest, evt.UserId);
|
||||
await _rest.CreateMessage(dm.Id,
|
||||
var dm = await _dmCache.GetOrCreateDmChannel(evt.UserId);
|
||||
await _rest.CreateMessage(dm,
|
||||
new MessageRequest
|
||||
{
|
||||
Content =
|
||||
$"{Emojis.Error} {msg.Member.DisplayName()}'s system has disabled reaction pings. If you want to mention them anyway, you can copy/paste the following message:"
|
||||
});
|
||||
await _rest.CreateMessage(
|
||||
dm.Id,
|
||||
dm,
|
||||
new MessageRequest { Content = $"<@{msg.Message.Sender}>".AsCode() }
|
||||
);
|
||||
}
|
||||
|
@ -43,6 +43,7 @@ public class BotModule: Module
|
||||
}).AsSelf().SingleInstance();
|
||||
builder.RegisterType<Cluster>().AsSelf().SingleInstance();
|
||||
builder.Register(c => { return new MemoryDiscordCache(); }).AsSelf().As<IDiscordCache>().SingleInstance();
|
||||
builder.RegisterType<PrivateChannelService>().AsSelf().SingleInstance();
|
||||
|
||||
builder.Register(c =>
|
||||
{
|
||||
|
63
PluralKit.Bot/Services/PrivateChannelService.cs
Normal file
63
PluralKit.Bot/Services/PrivateChannelService.cs
Normal file
@ -0,0 +1,63 @@
|
||||
using Serilog;
|
||||
|
||||
using Myriad.Cache;
|
||||
using Myriad.Gateway;
|
||||
using Myriad.Rest;
|
||||
|
||||
using PluralKit.Core;
|
||||
|
||||
namespace PluralKit.Bot;
|
||||
|
||||
public class PrivateChannelService
|
||||
{
|
||||
private readonly ILogger _logger;
|
||||
private readonly ModelRepository _repo;
|
||||
private readonly DiscordApiClient _rest;
|
||||
|
||||
private static Dictionary<ulong, ulong> _channelsCache = new();
|
||||
public PrivateChannelService(ILogger logger, ModelRepository repo, DiscordApiClient rest)
|
||||
{
|
||||
_logger = logger;
|
||||
_repo = repo;
|
||||
_rest = rest;
|
||||
}
|
||||
|
||||
public async Task TrySavePrivateChannel(MessageCreateEvent evt)
|
||||
{
|
||||
if (evt.GuildId != null) return;
|
||||
if (_channelsCache.TryGetValue(evt.Author.Id, out _)) return;
|
||||
|
||||
await SaveDmChannel(evt.Author.Id, evt.ChannelId);
|
||||
}
|
||||
|
||||
public async Task<ulong> GetOrCreateDmChannel(ulong userId)
|
||||
{
|
||||
if (_channelsCache.TryGetValue(userId, out var cachedChannelId))
|
||||
return cachedChannelId;
|
||||
|
||||
var channelId = await _repo.GetDmChannel(userId);
|
||||
if (channelId == null)
|
||||
{
|
||||
var channel = await _rest.CreateDm(userId);
|
||||
channelId = channel.Id;
|
||||
}
|
||||
|
||||
// spawn off saving the channel as to not block the current thread
|
||||
_ = SaveDmChannel(userId, channelId.Value);
|
||||
|
||||
return channelId.Value;
|
||||
}
|
||||
|
||||
private async Task SaveDmChannel(ulong userId, ulong channelId)
|
||||
{
|
||||
try
|
||||
{
|
||||
_channelsCache.Add(userId, channelId);
|
||||
await _repo.UpdateAccount(userId, new() { DmChannel = channelId });
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
_logger.Error(e, "Failed to save DM channel {ChannelId} for user {UserId}", channelId, userId);
|
||||
}
|
||||
}
|
||||
}
|
12
PluralKit.Core/Database/Migrations/26.sql
Normal file
12
PluralKit.Core/Database/Migrations/26.sql
Normal file
@ -0,0 +1,12 @@
|
||||
-- schema version 26
|
||||
-- cache Discord DM channels in the database
|
||||
|
||||
alter table accounts alter column system drop not null;
|
||||
alter table accounts drop constraint accounts_system_fkey;
|
||||
alter table accounts
|
||||
add constraint accounts_system_fkey
|
||||
foreign key (system) references systems(id) on delete set null;
|
||||
|
||||
alter table accounts add column dm_channel bigint;
|
||||
|
||||
update info set schema_version = 26;
|
@ -1,9 +1,14 @@
|
||||
using Dapper;
|
||||
|
||||
using SqlKata;
|
||||
|
||||
namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public async Task<ulong?> GetDmChannel(ulong id)
|
||||
=> await _db.Execute(c => c.QueryFirstOrDefaultAsync<ulong?>("select dm_channel from accounts where uid = @id", new { id = id }));
|
||||
|
||||
public async Task UpdateAccount(ulong id, AccountPatch patch)
|
||||
{
|
||||
_logger.Information("Updated account {accountId}: {@AccountPatch}", id, patch);
|
||||
|
@ -24,7 +24,8 @@ public partial class ModelRepository
|
||||
var query = new Query("accounts")
|
||||
.Select("systems.*")
|
||||
.LeftJoin("systems", "systems.id", "accounts.system")
|
||||
.Where("uid", accountId);
|
||||
.Where("uid", accountId)
|
||||
.WhereNotNull("system");
|
||||
return _db.QueryFirst<PKSystem?>(query);
|
||||
}
|
||||
|
||||
@ -111,10 +112,13 @@ public partial class ModelRepository
|
||||
// We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent
|
||||
// This is used in import/export, although the pk;link command checks for this case beforehand
|
||||
|
||||
// update 2022-01: the accounts table is now independent of systems
|
||||
// we MUST check for the presence of a system before inserting, or it will move the new account to the current system
|
||||
|
||||
var query = new Query("accounts").AsInsert(new { system, uid = accountId });
|
||||
await _db.ExecuteQuery(conn, query, "on conflict (uid) do update set system = @p0");
|
||||
|
||||
_logger.Information("Linked account {UserId} to {SystemId}", accountId, system);
|
||||
await _db.ExecuteQuery(conn, query, "on conflict do nothing");
|
||||
|
||||
_ = _dispatch.Dispatch(system, new UpdateDispatchData
|
||||
{
|
||||
@ -125,7 +129,10 @@ public partial class ModelRepository
|
||||
|
||||
public async Task RemoveAccount(SystemId system, ulong accountId)
|
||||
{
|
||||
var query = new Query("accounts").AsDelete().Where("uid", accountId).Where("system", system);
|
||||
var query = new Query("accounts").AsUpdate(new
|
||||
{
|
||||
system = (ulong?)null
|
||||
}).Where("uid", accountId).Where("system", system);
|
||||
await _db.ExecuteQuery(query);
|
||||
_logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system);
|
||||
_ = _dispatch.Dispatch(system, new UpdateDispatchData
|
||||
|
@ -9,7 +9,7 @@ namespace PluralKit.Core;
|
||||
internal class DatabaseMigrator
|
||||
{
|
||||
private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files
|
||||
private const int TargetSchemaVersion = 25;
|
||||
private const int TargetSchemaVersion = 26;
|
||||
private readonly ILogger _logger;
|
||||
|
||||
public DatabaseMigrator(ILogger logger)
|
||||
|
@ -6,9 +6,11 @@ namespace PluralKit.Core;
|
||||
|
||||
public class AccountPatch: PatchObject
|
||||
{
|
||||
public Partial<ulong> DmChannel { get; set; }
|
||||
public Partial<bool> AllowAutoproxy { get; set; }
|
||||
|
||||
public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper
|
||||
.With("dm_channel", DmChannel)
|
||||
.With("allow_autoproxy", AllowAutoproxy)
|
||||
);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user