feat: cache Discord DM channels in database
This commit is contained in:
		| @@ -20,7 +20,6 @@ public interface IDiscordCache | ||||
|     public Task<ulong> GetOwnUser(); | ||||
|     public Task<Guild?> TryGetGuild(ulong guildId); | ||||
|     public Task<Channel?> TryGetChannel(ulong channelId); | ||||
|     public Task<Channel?> TryGetDmChannel(ulong userId); | ||||
|     public Task<User?> TryGetUser(ulong userId); | ||||
|     public Task<GuildMemberPartial?> TryGetSelfMember(ulong guildId); | ||||
|     public Task<Role?> TryGetRole(ulong roleId); | ||||
|   | ||||
| @@ -7,7 +7,6 @@ namespace Myriad.Cache; | ||||
| public class MemoryDiscordCache: IDiscordCache | ||||
| { | ||||
|     private readonly ConcurrentDictionary<ulong, Channel> _channels = new(); | ||||
|     private readonly ConcurrentDictionary<ulong, ulong> _dmChannels = new(); | ||||
|     private readonly ConcurrentDictionary<ulong, GuildMemberPartial> _guildMembers = new(); | ||||
|     private readonly ConcurrentDictionary<ulong, CachedGuild> _guilds = new(); | ||||
|     private readonly ConcurrentDictionary<ulong, Role> _roles = new(); | ||||
| @@ -35,10 +34,7 @@ public class MemoryDiscordCache: IDiscordCache | ||||
|  | ||||
|         if (channel.Recipients != null) | ||||
|             foreach (var recipient in channel.Recipients) | ||||
|             { | ||||
|                 _dmChannels[recipient.Id] = channel.Id; | ||||
|                 await SaveUser(recipient); | ||||
|             } | ||||
|     } | ||||
|  | ||||
|     public ValueTask SaveOwnUser(ulong userId) | ||||
| @@ -140,13 +136,6 @@ public class MemoryDiscordCache: IDiscordCache | ||||
|         return Task.FromResult(channel); | ||||
|     } | ||||
|  | ||||
|     public Task<Channel?> TryGetDmChannel(ulong userId) | ||||
|     { | ||||
|         if (!_dmChannels.TryGetValue(userId, out var channelId)) | ||||
|             return Task.FromResult((Channel?)null); | ||||
|         return TryGetChannel(channelId); | ||||
|     } | ||||
|  | ||||
|     public Task<User?> TryGetUser(ulong userId) | ||||
|     { | ||||
|         _users.TryGetValue(userId, out var user); | ||||
|   | ||||
| @@ -58,17 +58,6 @@ public static class CacheExtensions | ||||
|         return restChannel; | ||||
|     } | ||||
|  | ||||
|     public static async Task<Channel> GetOrCreateDmChannel(this IDiscordCache cache, DiscordApiClient rest, | ||||
|                                                            ulong recipientId) | ||||
|     { | ||||
|         if (await cache.TryGetDmChannel(recipientId) is { } cacheChannel) | ||||
|             return cacheChannel; | ||||
|  | ||||
|         var restChannel = await rest.CreateDm(recipientId); | ||||
|         await cache.SaveChannel(restChannel); | ||||
|         return restChannel; | ||||
|     } | ||||
|  | ||||
|     public static async Task<Channel> GetRootChannel(this IDiscordCache cache, ulong channelOrThread) | ||||
|     { | ||||
|         var channel = await cache.GetChannel(channelOrThread); | ||||
|   | ||||
| @@ -17,12 +17,14 @@ public class Api | ||||
|     private readonly BotConfig _botConfig; | ||||
|     private readonly DispatchService _dispatch; | ||||
|     private readonly ModelRepository _repo; | ||||
|     private readonly PrivateChannelService _dmCache; | ||||
|  | ||||
|     public Api(BotConfig botConfig, ModelRepository repo, DispatchService dispatch) | ||||
|     public Api(BotConfig botConfig, ModelRepository repo, DispatchService dispatch, PrivateChannelService dmCache) | ||||
|     { | ||||
|         _botConfig = botConfig; | ||||
|         _repo = repo; | ||||
|         _dispatch = dispatch; | ||||
|         _dmCache = dmCache; | ||||
|     } | ||||
|  | ||||
|     public async Task GetToken(Context ctx) | ||||
| @@ -35,17 +37,17 @@ public class Api | ||||
|         try | ||||
|         { | ||||
|             // DM the user a security disclaimer, and then the token in a separate message (for easy copying on mobile) | ||||
|             var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id); | ||||
|             await ctx.Rest.CreateMessage(dm.Id, | ||||
|             var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id); | ||||
|             await ctx.Rest.CreateMessage(dm, | ||||
|                 new MessageRequest | ||||
|                 { | ||||
|                     Content = $"{Emojis.Warn} Please note that this grants access to modify (and delete!) all your system data, so keep it safe and secure." | ||||
|                             + $" If it leaks or you need a new one, you can invalidate this one with `pk;token refresh`.\n\nYour token is below:" | ||||
|                 }); | ||||
|             await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = token }); | ||||
|             await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = token }); | ||||
|  | ||||
|             if (_botConfig.IsBetaBot) | ||||
|                 await ctx.Rest.CreateMessage(dm.Id, new MessageRequest | ||||
|                 await ctx.Rest.CreateMessage(dm, new MessageRequest | ||||
|                 { | ||||
|                     Content = $"{Emojis.Note} The beta bot's API base URL is currently <{_botConfig.BetaBotAPIUrl}>." | ||||
|                                                                                     + " You need to use this URL instead of the base URL listed on the documentation website." | ||||
| @@ -84,8 +86,8 @@ public class Api | ||||
|         try | ||||
|         { | ||||
|             // DM the user an invalidation disclaimer, and then the token in a separate message (for easy copying on mobile) | ||||
|             var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id); | ||||
|             await ctx.Rest.CreateMessage(dm.Id, | ||||
|             var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id); | ||||
|             await ctx.Rest.CreateMessage(dm, | ||||
|                 new MessageRequest | ||||
|                 { | ||||
|                     Content = $"{Emojis.Warn} Your previous API token has been invalidated. You will need to change it anywhere it's currently used.\n\nYour token is below:" | ||||
| @@ -94,10 +96,10 @@ public class Api | ||||
|             // Make the new token after sending the first DM; this ensures if we can't DM, we also don't end up | ||||
|             // breaking their existing token as a side effect :) | ||||
|             var token = await MakeAndSetNewToken(ctx.System); | ||||
|             await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = token }); | ||||
|             await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = token }); | ||||
|  | ||||
|             if (_botConfig.IsBetaBot) | ||||
|                 await ctx.Rest.CreateMessage(dm.Id, new MessageRequest | ||||
|                 await ctx.Rest.CreateMessage(dm, new MessageRequest | ||||
|                 { | ||||
|                     Content = $"{Emojis.Note} The beta bot's API base URL is currently <{_botConfig.BetaBotAPIUrl}>." | ||||
|                                                                                    + " You need to use this URL instead of the base URL listed on the documentation website." | ||||
|   | ||||
| @@ -17,6 +17,7 @@ public class ImportExport | ||||
| { | ||||
|     private readonly HttpClient _client; | ||||
|     private readonly DataFileService _dataFiles; | ||||
|     private readonly PrivateChannelService _dmCache; | ||||
|  | ||||
|     private readonly JsonSerializerSettings _settings = new() | ||||
|     { | ||||
| @@ -24,10 +25,11 @@ public class ImportExport | ||||
|         DateParseHandling = DateParseHandling.None | ||||
|     }; | ||||
|  | ||||
|     public ImportExport(DataFileService dataFiles, HttpClient client) | ||||
|     public ImportExport(DataFileService dataFiles, HttpClient client, PrivateChannelService dmCache) | ||||
|     { | ||||
|         _dataFiles = dataFiles; | ||||
|         _client = client; | ||||
|         _dmCache = dmCache; | ||||
|     } | ||||
|  | ||||
|     public async Task Import(Context ctx) | ||||
| @@ -110,12 +112,12 @@ public class ImportExport | ||||
|  | ||||
|         try | ||||
|         { | ||||
|             var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id); | ||||
|             var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id); | ||||
|  | ||||
|             var msg = await ctx.Rest.CreateMessage(dm.Id, | ||||
|             var msg = await ctx.Rest.CreateMessage(dm, | ||||
|                 new MessageRequest { Content = $"{Emojis.Success} Here you go!" }, | ||||
|                 new[] { new MultipartFile("system.json", stream, null) }); | ||||
|             await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = $"<{msg.Attachments[0].Url}>" }); | ||||
|             await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = $"<{msg.Attachments[0].Url}>" }); | ||||
|  | ||||
|             // If the original message wasn't posted in DMs, send a public reminder | ||||
|             if (ctx.Channel.Type != Channel.ChannelType.Dm) | ||||
|   | ||||
| @@ -28,12 +28,13 @@ public class MessageCreated: IEventHandler<MessageCreateEvent> | ||||
|     private readonly DiscordApiClient _rest; | ||||
|     private readonly ILifetimeScope _services; | ||||
|     private readonly CommandTree _tree; | ||||
|     private readonly PrivateChannelService _dmCache; | ||||
|  | ||||
|     public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean, | ||||
|                           IMetrics metrics, ProxyService proxy, | ||||
|                           CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config, | ||||
|                           ModelRepository repo, IDiscordCache cache, | ||||
|                           Bot bot, Cluster cluster, DiscordApiClient rest) | ||||
|                           Bot bot, Cluster cluster, DiscordApiClient rest, PrivateChannelService dmCache) | ||||
|     { | ||||
|         _lastMessageCache = lastMessageCache; | ||||
|         _loggerClean = loggerClean; | ||||
| @@ -48,6 +49,7 @@ public class MessageCreated: IEventHandler<MessageCreateEvent> | ||||
|         _bot = bot; | ||||
|         _cluster = cluster; | ||||
|         _rest = rest; | ||||
|         _dmCache = dmCache; | ||||
|     } | ||||
|  | ||||
|     // for now, only return error messages for explicit commands | ||||
| @@ -66,6 +68,10 @@ public class MessageCreated: IEventHandler<MessageCreateEvent> | ||||
|         if (evt.Type != Message.MessageType.Default && evt.Type != Message.MessageType.Reply) return; | ||||
|         if (IsDuplicateMessage(evt)) return; | ||||
|  | ||||
|         // spawn off saving the private channel into another thread | ||||
|         // it is not a fatal error if this fails, and it shouldn't block message processing | ||||
|         _ = _dmCache.TrySavePrivateChannel(evt); | ||||
|  | ||||
|         var guild = evt.GuildId != null ? await _cache.GetGuild(evt.GuildId.Value) : null; | ||||
|         var channel = await _cache.GetChannel(evt.ChannelId); | ||||
|         var rootChannel = await _cache.GetRootChannel(evt.ChannelId); | ||||
|   | ||||
| @@ -26,10 +26,11 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent> | ||||
|     private readonly ILogger _logger; | ||||
|     private readonly ModelRepository _repo; | ||||
|     private readonly DiscordApiClient _rest; | ||||
|     private readonly PrivateChannelService _dmCache; | ||||
|  | ||||
|     public ReactionAdded(ILogger logger, IDatabase db, ModelRepository repo, | ||||
|                          CommandMessageService commandMessageService, IDiscordCache cache, Bot bot, Cluster cluster, | ||||
|                          DiscordApiClient rest, EmbedService embeds) | ||||
|                          DiscordApiClient rest, EmbedService embeds, PrivateChannelService dmCache) | ||||
|     { | ||||
|         _db = db; | ||||
|         _repo = repo; | ||||
| @@ -40,6 +41,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent> | ||||
|         _rest = rest; | ||||
|         _embeds = embeds; | ||||
|         _logger = logger.ForContext<ReactionAdded>(); | ||||
|         _dmCache = dmCache; | ||||
|     } | ||||
|  | ||||
|     public async Task Handle(int shardId, MessageReactionAddEvent evt) | ||||
| @@ -168,9 +170,9 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent> | ||||
|         // Try to DM the user info about the message | ||||
|         try | ||||
|         { | ||||
|             var dm = await _cache.GetOrCreateDmChannel(_rest, evt.UserId); | ||||
|             var dm = await _dmCache.GetOrCreateDmChannel(evt.UserId); | ||||
|             if (msg.Member != null) | ||||
|                 await _rest.CreateMessage(dm.Id, new MessageRequest | ||||
|                 await _rest.CreateMessage(dm, new MessageRequest | ||||
|                 { | ||||
|                     Embed = await _embeds.CreateMemberEmbed( | ||||
|                         msg.System, | ||||
| @@ -182,7 +184,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent> | ||||
|                 }); | ||||
|  | ||||
|             await _rest.CreateMessage( | ||||
|                 dm.Id, | ||||
|                 dm, | ||||
|                 new MessageRequest { Embed = await _embeds.CreateMessageInfoEmbed(msg, true) } | ||||
|             ); | ||||
|         } | ||||
| @@ -234,15 +236,15 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent> | ||||
|             // If not, tell them in DMs (if we can) | ||||
|             try | ||||
|             { | ||||
|                 var dm = await _cache.GetOrCreateDmChannel(_rest, evt.UserId); | ||||
|                 await _rest.CreateMessage(dm.Id, | ||||
|                 var dm = await _dmCache.GetOrCreateDmChannel(evt.UserId); | ||||
|                 await _rest.CreateMessage(dm, | ||||
|                     new MessageRequest | ||||
|                     { | ||||
|                         Content = | ||||
|                             $"{Emojis.Error} {msg.Member.DisplayName()}'s system has disabled reaction pings. If you want to mention them anyway, you can copy/paste the following message:" | ||||
|                     }); | ||||
|                 await _rest.CreateMessage( | ||||
|                     dm.Id, | ||||
|                     dm, | ||||
|                     new MessageRequest { Content = $"<@{msg.Message.Sender}>".AsCode() } | ||||
|                 ); | ||||
|             } | ||||
|   | ||||
| @@ -43,6 +43,7 @@ public class BotModule: Module | ||||
|         }).AsSelf().SingleInstance(); | ||||
|         builder.RegisterType<Cluster>().AsSelf().SingleInstance(); | ||||
|         builder.Register(c => { return new MemoryDiscordCache(); }).AsSelf().As<IDiscordCache>().SingleInstance(); | ||||
|         builder.RegisterType<PrivateChannelService>().AsSelf().SingleInstance(); | ||||
|  | ||||
|         builder.Register(c => | ||||
|         { | ||||
|   | ||||
							
								
								
									
										63
									
								
								PluralKit.Bot/Services/PrivateChannelService.cs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								PluralKit.Bot/Services/PrivateChannelService.cs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | ||||
| using Serilog; | ||||
|  | ||||
| using Myriad.Cache; | ||||
| using Myriad.Gateway; | ||||
| using Myriad.Rest; | ||||
|  | ||||
| using PluralKit.Core; | ||||
|  | ||||
| namespace PluralKit.Bot; | ||||
|  | ||||
| public class PrivateChannelService | ||||
| { | ||||
|     private readonly ILogger _logger; | ||||
|     private readonly ModelRepository _repo; | ||||
|     private readonly DiscordApiClient _rest; | ||||
|  | ||||
|     private static Dictionary<ulong, ulong> _channelsCache = new(); | ||||
|     public PrivateChannelService(ILogger logger, ModelRepository repo, DiscordApiClient rest) | ||||
|     { | ||||
|         _logger = logger; | ||||
|         _repo = repo; | ||||
|         _rest = rest; | ||||
|     } | ||||
|  | ||||
|     public async Task TrySavePrivateChannel(MessageCreateEvent evt) | ||||
|     { | ||||
|         if (evt.GuildId != null) return; | ||||
|         if (_channelsCache.TryGetValue(evt.Author.Id, out _)) return; | ||||
|  | ||||
|         await SaveDmChannel(evt.Author.Id, evt.ChannelId); | ||||
|     } | ||||
|  | ||||
|     public async Task<ulong> GetOrCreateDmChannel(ulong userId) | ||||
|     { | ||||
|         if (_channelsCache.TryGetValue(userId, out var cachedChannelId)) | ||||
|             return cachedChannelId; | ||||
|  | ||||
|         var channelId = await _repo.GetDmChannel(userId); | ||||
|         if (channelId == null) | ||||
|         { | ||||
|             var channel = await _rest.CreateDm(userId); | ||||
|             channelId = channel.Id; | ||||
|         } | ||||
|  | ||||
|         // spawn off saving the channel as to not block the current thread | ||||
|         _ = SaveDmChannel(userId, channelId.Value); | ||||
|  | ||||
|         return channelId.Value; | ||||
|     } | ||||
|  | ||||
|     private async Task SaveDmChannel(ulong userId, ulong channelId) | ||||
|     { | ||||
|         try | ||||
|         { | ||||
|             _channelsCache.Add(userId, channelId); | ||||
|             await _repo.UpdateAccount(userId, new() { DmChannel = channelId }); | ||||
|         } | ||||
|         catch (Exception e) | ||||
|         { | ||||
|             _logger.Error(e, "Failed to save DM channel {ChannelId} for user {UserId}", channelId, userId); | ||||
|         } | ||||
|     } | ||||
| } | ||||
							
								
								
									
										12
									
								
								PluralKit.Core/Database/Migrations/26.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								PluralKit.Core/Database/Migrations/26.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,12 @@ | ||||
| -- schema version 26 | ||||
| -- cache Discord DM channels in the database | ||||
|  | ||||
| alter table accounts alter column system drop not null; | ||||
| alter table accounts drop constraint accounts_system_fkey; | ||||
| alter table accounts | ||||
|     add constraint accounts_system_fkey | ||||
|     foreign key (system) references systems(id) on delete set null; | ||||
|  | ||||
| alter table accounts add column dm_channel bigint; | ||||
|  | ||||
| update info set schema_version = 26; | ||||
| @@ -1,9 +1,14 @@ | ||||
| using Dapper; | ||||
|  | ||||
| using SqlKata; | ||||
|  | ||||
| namespace PluralKit.Core; | ||||
|  | ||||
| public partial class ModelRepository | ||||
| { | ||||
|     public async Task<ulong?> GetDmChannel(ulong id) | ||||
|         => await _db.Execute(c => c.QueryFirstOrDefaultAsync<ulong?>("select dm_channel from accounts where uid = @id", new { id = id })); | ||||
|  | ||||
|     public async Task UpdateAccount(ulong id, AccountPatch patch) | ||||
|     { | ||||
|         _logger.Information("Updated account {accountId}: {@AccountPatch}", id, patch); | ||||
|   | ||||
| @@ -24,7 +24,8 @@ public partial class ModelRepository | ||||
|         var query = new Query("accounts") | ||||
|             .Select("systems.*") | ||||
|             .LeftJoin("systems", "systems.id", "accounts.system") | ||||
|             .Where("uid", accountId); | ||||
|             .Where("uid", accountId) | ||||
|             .WhereNotNull("system"); | ||||
|         return _db.QueryFirst<PKSystem?>(query); | ||||
|     } | ||||
|  | ||||
| @@ -111,10 +112,13 @@ public partial class ModelRepository | ||||
|         // We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent | ||||
|         // This is used in import/export, although the pk;link command checks for this case beforehand | ||||
|  | ||||
|         // update 2022-01: the accounts table is now independent of systems | ||||
|         // we MUST check for the presence of a system before inserting, or it will move the new account to the current system | ||||
|  | ||||
|         var query = new Query("accounts").AsInsert(new { system, uid = accountId }); | ||||
|         await _db.ExecuteQuery(conn, query, "on conflict (uid) do update set system = @p0"); | ||||
|  | ||||
|         _logger.Information("Linked account {UserId} to {SystemId}", accountId, system); | ||||
|         await _db.ExecuteQuery(conn, query, "on conflict do nothing"); | ||||
|  | ||||
|         _ = _dispatch.Dispatch(system, new UpdateDispatchData | ||||
|         { | ||||
| @@ -125,7 +129,10 @@ public partial class ModelRepository | ||||
|  | ||||
|     public async Task RemoveAccount(SystemId system, ulong accountId) | ||||
|     { | ||||
|         var query = new Query("accounts").AsDelete().Where("uid", accountId).Where("system", system); | ||||
|         var query = new Query("accounts").AsUpdate(new | ||||
|         { | ||||
|             system = (ulong?)null | ||||
|         }).Where("uid", accountId).Where("system", system); | ||||
|         await _db.ExecuteQuery(query); | ||||
|         _logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system); | ||||
|         _ = _dispatch.Dispatch(system, new UpdateDispatchData | ||||
|   | ||||
| @@ -9,7 +9,7 @@ namespace PluralKit.Core; | ||||
| internal class DatabaseMigrator | ||||
| { | ||||
|     private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files | ||||
|     private const int TargetSchemaVersion = 25; | ||||
|     private const int TargetSchemaVersion = 26; | ||||
|     private readonly ILogger _logger; | ||||
|  | ||||
|     public DatabaseMigrator(ILogger logger) | ||||
|   | ||||
| @@ -6,9 +6,11 @@ namespace PluralKit.Core; | ||||
|  | ||||
| public class AccountPatch: PatchObject | ||||
| { | ||||
|     public Partial<ulong> DmChannel { get; set; } | ||||
|     public Partial<bool> AllowAutoproxy { get; set; } | ||||
|  | ||||
|     public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper | ||||
|         .With("dm_channel", DmChannel) | ||||
|         .With("allow_autoproxy", AllowAutoproxy) | ||||
|     ); | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user