feat: use redis cache for non-id message lookups
This commit is contained in:
		| @@ -18,6 +18,7 @@ public class PKControllerBase: ControllerBase | ||||
|     protected readonly ApiConfig _config; | ||||
|     protected readonly IDatabase _db; | ||||
|     protected readonly ModelRepository _repo; | ||||
|     protected readonly RedisService _redis; | ||||
|     protected readonly DispatchService _dispatch; | ||||
|  | ||||
|     public PKControllerBase(IServiceProvider svc) | ||||
| @@ -25,6 +26,7 @@ public class PKControllerBase: ControllerBase | ||||
|         _config = svc.GetRequiredService<ApiConfig>(); | ||||
|         _db = svc.GetRequiredService<IDatabase>(); | ||||
|         _repo = svc.GetRequiredService<ModelRepository>(); | ||||
|         _redis = svc.GetRequiredService<RedisService>(); | ||||
|         _dispatch = svc.GetRequiredService<DispatchService>(); | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -20,12 +20,7 @@ namespace PluralKit.API; | ||||
| [Route("private")] | ||||
| public class PrivateController: PKControllerBase | ||||
| { | ||||
|     private readonly RedisService _redis; | ||||
|  | ||||
|     public PrivateController(IServiceProvider svc) : base(svc) | ||||
|     { | ||||
|         _redis = svc.GetRequiredService<RedisService>(); | ||||
|     } | ||||
|     public PrivateController(IServiceProvider svc) : base(svc) { } | ||||
|  | ||||
|     [HttpGet("meta")] | ||||
|     public async Task<ActionResult<JObject>> Meta() | ||||
|   | ||||
| @@ -92,7 +92,9 @@ public class DiscordControllerV2: PKControllerBase | ||||
|     [HttpGet("messages/{messageId}")] | ||||
|     public async Task<ActionResult<JObject>> MessageGet(ulong messageId) | ||||
|     { | ||||
|         var msg = await _repo.GetFullMessage(messageId); | ||||
|         var messageByOriginal = await _redis.GetOriginalMid(messageId); | ||||
|  | ||||
|         var msg = await _repo.GetFullMessage(messageByOriginal ?? messageId); | ||||
|         if (msg == null) | ||||
|             throw Errors.MessageNotFound; | ||||
|  | ||||
|   | ||||
| @@ -40,6 +40,7 @@ public class Context | ||||
|         Cache = provider.Resolve<IDiscordCache>(); | ||||
|         Database = provider.Resolve<IDatabase>(); | ||||
|         Repository = provider.Resolve<ModelRepository>(); | ||||
|         Redis = provider.Resolve<RedisService>(); | ||||
|         _metrics = provider.Resolve<IMetrics>(); | ||||
|         _provider = provider; | ||||
|         _commandMessageService = provider.Resolve<CommandMessageService>(); | ||||
| @@ -74,6 +75,7 @@ public class Context | ||||
|  | ||||
|     internal readonly IDatabase Database; | ||||
|     internal readonly ModelRepository Repository; | ||||
|     internal readonly RedisService Redis; | ||||
|  | ||||
|     public async Task<Message> Reply(string text = null, Embed embed = null, AllowedMentions? mentions = null) | ||||
|     { | ||||
|   | ||||
| @@ -164,7 +164,7 @@ public class ProxiedMessage | ||||
|             ulong? recent = null; | ||||
|  | ||||
|             if (isReproxy) | ||||
|                 recent = await ctx.Repository.GetLastMessage(ctx.Guild.Id, ctx.Channel.Id, ctx.Author.Id); | ||||
|                 recent = await ctx.Redis.GetLastMessage(ctx.Author.Id, ctx.Channel.Id); | ||||
|             else | ||||
|                 recent = await FindRecentMessage(ctx, timeout); | ||||
|  | ||||
| @@ -210,13 +210,13 @@ public class ProxiedMessage | ||||
|         return (msg, member.System); | ||||
|     } | ||||
|  | ||||
|     private async Task<PKMessage?> FindRecentMessage(Context ctx, Duration timeout) | ||||
|     private async Task<ulong?> FindRecentMessage(Context ctx, Duration timeout) | ||||
|     { | ||||
|         var lastMessage = await ctx.Repository.GetLastMessage(ctx.Guild.Id, ctx.Channel.Id, ctx.Author.Id); | ||||
|         var lastMessage = await ctx.Redis.GetLastMessage(ctx.Author.Id, ctx.Channel.Id); | ||||
|         if (lastMessage == null) | ||||
|             return null; | ||||
|  | ||||
|         var timestamp = DiscordUtils.SnowflakeToInstant(lastMessage.Mid); | ||||
|         var timestamp = DiscordUtils.SnowflakeToInstant(lastMessage.Value); | ||||
|         if (SystemClock.Instance.GetCurrentInstant() - timestamp > timeout) | ||||
|             return null; | ||||
|  | ||||
|   | ||||
| @@ -22,6 +22,7 @@ public class ProxyService | ||||
|     private static readonly TimeSpan MessageDeletionDelay = TimeSpan.FromMilliseconds(1000); | ||||
|     private readonly IDiscordCache _cache; | ||||
|     private readonly IDatabase _db; | ||||
|     private readonly RedisService _redis; | ||||
|     private readonly DispatchService _dispatch; | ||||
|     private readonly LastMessageCacheService _lastMessage; | ||||
|  | ||||
| @@ -35,13 +36,14 @@ public class ProxyService | ||||
|     private readonly NodaTime.IClock _clock; | ||||
|  | ||||
|     public ProxyService(LogChannelService logChannel, ILogger logger, WebhookExecutorService webhookExecutor, | ||||
|             DispatchService dispatch, IDatabase db, ProxyMatcher matcher, IMetrics metrics, ModelRepository repo, | ||||
|             DispatchService dispatch, IDatabase db, RedisService redis, ProxyMatcher matcher, IMetrics metrics, ModelRepository repo, | ||||
|                       NodaTime.IClock clock, IDiscordCache cache, DiscordApiClient rest, LastMessageCacheService lastMessage) | ||||
|     { | ||||
|         _logChannel = logChannel; | ||||
|         _webhookExecutor = webhookExecutor; | ||||
|         _dispatch = dispatch; | ||||
|         _db = db; | ||||
|         _redis = redis; | ||||
|         _matcher = matcher; | ||||
|         _metrics = metrics; | ||||
|         _repo = repo; | ||||
| @@ -420,6 +422,18 @@ public class ProxyService | ||||
|         Task SaveMessageInDatabase() | ||||
|             => _repo.AddMessage(sentMessage); | ||||
|  | ||||
|         async Task SaveMessageInRedis() | ||||
|         { | ||||
|             // logclean info | ||||
|             await _redis.SetLogCleanup(triggerMessage.Author.Id, triggerMessage.GuildId.Value); | ||||
|  | ||||
|             // last message info (edit/reproxy) | ||||
|             await _redis.SetLastMessage(triggerMessage.Author.Id, triggerMessage.ChannelId, sentMessage.Mid); | ||||
|  | ||||
|             // "by original mid" lookup | ||||
|             await _redis.SetOriginalMid(triggerMessage.Id, proxyMessage.Id); | ||||
|         } | ||||
|  | ||||
|         Task LogMessageToChannel() => | ||||
|             _logChannel.LogMessage(sentMessage, triggerMessage, proxyMessage).AsTask(); | ||||
|  | ||||
| @@ -458,6 +472,7 @@ public class ProxyService | ||||
|         await Task.WhenAll( | ||||
|             DeleteProxyTriggerMessage(), | ||||
|             SaveMessageInDatabase(), | ||||
|             SaveMessageInRedis(), | ||||
|             LogMessageToChannel(), | ||||
|             SaveLatchAutoproxy(), | ||||
|             DispatchWebhook() | ||||
|   | ||||
| @@ -79,12 +79,12 @@ public class LoggerCleanService | ||||
|     private readonly IDiscordCache _cache; | ||||
|     private readonly DiscordApiClient _client; | ||||
|  | ||||
|     private readonly IDatabase _db; | ||||
|     private readonly RedisService _redis; | ||||
|     private readonly ILogger _logger; | ||||
|  | ||||
|     public LoggerCleanService(IDatabase db, DiscordApiClient client, IDiscordCache cache, ILogger logger) | ||||
|     public LoggerCleanService(RedisService redis, DiscordApiClient client, IDiscordCache cache, ILogger logger) | ||||
|     { | ||||
|         _db = db; | ||||
|         _redis = redis; | ||||
|         _client = client; | ||||
|         _cache = cache; | ||||
|         _logger = logger.ForContext<LoggerCleanService>(); | ||||
| @@ -124,20 +124,10 @@ public class LoggerCleanService | ||||
|                 _logger.Debug("Fuzzy logclean for {BotName} on {MessageId}: {@FuzzyExtractResult}", | ||||
|                     bot.Name, msg.Id, fuzzy); | ||||
|  | ||||
|                 var mid = await _db.Execute(conn => | ||||
|                     conn.QuerySingleOrDefaultAsync<ulong?>( | ||||
|                         "select mid from messages where sender = @User and mid > @ApproxID and guild = @Guild limit 1", | ||||
|                         new | ||||
|                         { | ||||
|                             fuzzy.Value.User, | ||||
|                             Guild = msg.GuildId, | ||||
|                             ApproxId = DiscordUtils.InstantToSnowflake( | ||||
|                                 fuzzy.Value.ApproxTimestamp - Duration.FromSeconds(3)) | ||||
|                         })); | ||||
|                 var exists = await _redis.HasLogCleanup(fuzzy.Value.User, msg.GuildId.Value); | ||||
|  | ||||
|                 // If we didn't find a corresponding message, bail | ||||
|                 if (mid == null) | ||||
|                     return; | ||||
|                 if (!exists) return; | ||||
|  | ||||
|                 // Otherwise, we can *reasonably assume* that this is a logged deletion, so delete the log message. | ||||
|                 await _client.DeleteMessage(msg.ChannelId, msg.Id); | ||||
| @@ -151,8 +141,7 @@ public class LoggerCleanService | ||||
|                 _logger.Debug("Pure logclean for {BotName} on {MessageId}: {@FuzzyExtractResult}", | ||||
|                     bot.Name, msg.Id, extractedId); | ||||
|  | ||||
|                 var mid = await _db.Execute(conn => conn.QuerySingleOrDefaultAsync<ulong?>( | ||||
|                     "select mid from messages where original_mid = @Mid", new { Mid = extractedId.Value })); | ||||
|                 var mid = await _redis.GetOriginalMid(extractedId.Value); | ||||
|                 if (mid == null) return; | ||||
|  | ||||
|                 // If we've gotten this far, we found a logged deletion of a trigger message. Just yeet it! | ||||
|   | ||||
| @@ -11,4 +11,41 @@ public class RedisService | ||||
|         if (config.RedisAddr != null) | ||||
|             Connection = await ConnectionMultiplexer.ConnectAsync(config.RedisAddr); | ||||
|     } | ||||
|  | ||||
|     private string LastMessageKey(ulong userId, ulong channelId) => $"user_last_message:{userId}:{channelId}"; | ||||
|     public Task SetLastMessage(ulong userId, ulong channelId, ulong mid) | ||||
|         => Connection.GetDatabase().UlongSetAsync(LastMessageKey(userId, channelId), mid, expiry: TimeSpan.FromMinutes(10)); | ||||
|     public Task<ulong?> GetLastMessage(ulong userId, ulong channelId) | ||||
|         => Connection.GetDatabase().UlongGetAsync(LastMessageKey(userId, channelId)); | ||||
|  | ||||
|     private string LoggerCleanKey(ulong userId, ulong guildId) => $"log_cleanup:{userId}:{guildId}"; | ||||
|     public Task SetLogCleanup(ulong userId, ulong guildId) | ||||
|         => Connection.GetDatabase().StringSetAsync(LoggerCleanKey(userId, guildId), 1, expiry: TimeSpan.FromSeconds(3)); | ||||
|     public Task<bool> HasLogCleanup(ulong userId, ulong guildId) | ||||
|         => Connection.GetDatabase().KeyExistsAsync(LoggerCleanKey(userId, guildId)); | ||||
|  | ||||
|     // note: these methods are named weird - they actually get the proxied mid from the original mid | ||||
|     // but anything else would've been more confusing | ||||
|     private string OriginalMidKey(ulong original_mid) => $"original_mid:{original_mid}"; | ||||
|     public Task SetOriginalMid(ulong original_mid, ulong proxied_mid) | ||||
|         => Connection.GetDatabase().UlongSetAsync(OriginalMidKey(original_mid), proxied_mid, expiry: TimeSpan.FromMinutes(30)); | ||||
|     public Task<ulong?> GetOriginalMid(ulong original_mid) | ||||
|         => Connection.GetDatabase().UlongGetAsync(OriginalMidKey(original_mid)); | ||||
| } | ||||
|  | ||||
| public static class RedisExt | ||||
| { | ||||
|     public static async Task<ulong?> UlongGetAsync(this StackExchange.Redis.IDatabase database, string key) | ||||
|     { | ||||
|         var data = await database.StringGetAsync(key); | ||||
|         if (data == RedisValue.Null) return null; | ||||
|  | ||||
|         if (ulong.TryParse(data, out var value)) | ||||
|             return value; | ||||
|  | ||||
|         return null; | ||||
|     } | ||||
|  | ||||
|     public static Task UlongSetAsync(this StackExchange.Redis.IDatabase database, string key, ulong value, TimeSpan? expiry = null) | ||||
|         => database.StringSetAsync(key, value.ToString(), expiry); | ||||
| } | ||||
| @@ -333,4 +333,8 @@ GET `/messages/{message}` | ||||
|  | ||||
| Message can be the ID of a proxied message, or the ID of the message that sent the proxy. | ||||
|  | ||||
| ::: warning | ||||
| Looking up messages by the original message ID only works **up to 30 minutes** after the message was sent. | ||||
| ::: | ||||
|  | ||||
| Returns a [message object](/api/models#message-object). | ||||
|   | ||||
		Reference in New Issue
	
	Block a user