From 81cd5496d5e7582d52a3f5e5ed0a62d030945bc8 Mon Sep 17 00:00:00 2001 From: Ske Date: Fri, 23 Oct 2020 12:18:28 +0200 Subject: [PATCH] Refactor command message deletion --- PluralKit.Bot/Bot.cs | 8 +-- PluralKit.Bot/CommandSystem/Context.cs | 10 +++- PluralKit.Bot/Handlers/ReactionAdded.cs | 56 ++++++++++++------- PluralKit.Bot/Modules.cs | 1 + .../Services/CommandMessageService.cs | 50 +++++++++++++++++ PluralKit.Core/Database/Migrations/11.sql | 15 ++--- .../ModelRepository.CommandMessage.cs | 23 ++++---- 7 files changed, 117 insertions(+), 46 deletions(-) create mode 100644 PluralKit.Bot/Services/CommandMessageService.cs diff --git a/PluralKit.Bot/Bot.cs b/PluralKit.Bot/Bot.cs index 5e87f3da..73a61949 100644 --- a/PluralKit.Bot/Bot.cs +++ b/PluralKit.Bot/Bot.cs @@ -36,13 +36,13 @@ namespace PluralKit.Bot private readonly PeriodicStatCollector _collector; private readonly IMetrics _metrics; private readonly ErrorMessageService _errorMessageService; - private readonly IDatabase _db; + private readonly CommandMessageService _commandMessageService; private bool _hasReceivedReady = false; private Timer _periodicTask; // Never read, just kept here for GC reasons public Bot(DiscordShardedClient client, ILifetimeScope services, ILogger logger, PeriodicStatCollector collector, IMetrics metrics, - ErrorMessageService errorMessageService, IDatabase db) + ErrorMessageService errorMessageService, CommandMessageService commandMessageService) { _client = client; _logger = logger.ForContext(); @@ -50,7 +50,7 @@ namespace PluralKit.Bot _collector = collector; _metrics = metrics; _errorMessageService = errorMessageService; - _db = db; + _commandMessageService = commandMessageService; } public void Init() @@ -183,7 +183,7 @@ namespace PluralKit.Bot await UpdateBotStatus(); // Clean up message cache in postgres - await _db.Execute(conn => conn.QueryAsync("select from cleanup_command_message()")); + await _commandMessageService.CleanupOldMessages(); // Collect some stats, submit them to the metrics backend await _collector.CollectStats(); diff --git a/PluralKit.Bot/CommandSystem/Context.cs b/PluralKit.Bot/CommandSystem/Context.cs index 1097cc18..effdfe46 100644 --- a/PluralKit.Bot/CommandSystem/Context.cs +++ b/PluralKit.Bot/CommandSystem/Context.cs @@ -28,6 +28,7 @@ namespace PluralKit.Bot private readonly ModelRepository _repo; private readonly PKSystem _senderSystem; private readonly IMetrics _metrics; + private readonly CommandMessageService _commandMessageService; private Command _currentCommand; @@ -44,6 +45,7 @@ namespace PluralKit.Bot _repo = provider.Resolve(); _metrics = provider.Resolve(); _provider = provider; + _commandMessageService = provider.Resolve(); _parameters = new Parameters(message.Content.Substring(commandParseOffset)); } @@ -73,10 +75,14 @@ namespace PluralKit.Bot if (embed != null && !this.BotHasAllPermissions(Permissions.EmbedLinks)) throw new PKError("PluralKit does not have permission to send embeds in this channel. Please ensure I have the **Embed Links** permission enabled."); var msg = await Channel.SendMessageFixedAsync(text, embed: embed, mentions: mentions); - if (embed != null) + + if (embed != null) + { // Sensitive information that might want to be deleted by :x: reaction is typically in an embed format (member cards, for example) // This may need to be changed at some point but works well enough for now - await _db.Execute(conn => _repo.SaveCommandMessage(conn, msg.Id, Author.Id)); + await _commandMessageService.RegisterMessage(msg.Id, Author.Id); + } + return msg; } diff --git a/PluralKit.Bot/Handlers/ReactionAdded.cs b/PluralKit.Bot/Handlers/ReactionAdded.cs index 57b08dcb..6210386c 100644 --- a/PluralKit.Bot/Handlers/ReactionAdded.cs +++ b/PluralKit.Bot/Handlers/ReactionAdded.cs @@ -15,14 +15,16 @@ namespace PluralKit.Bot { private readonly IDatabase _db; private readonly ModelRepository _repo; + private readonly CommandMessageService _commandMessageService; private readonly EmbedService _embeds; private readonly ILogger _logger; - public ReactionAdded(EmbedService embeds, ILogger logger, IDatabase db, ModelRepository repo) + public ReactionAdded(EmbedService embeds, ILogger logger, IDatabase db, ModelRepository repo, CommandMessageService commandMessageService) { _embeds = embeds; _db = db; _repo = repo; + _commandMessageService = commandMessageService; _logger = logger.ForContext(); } @@ -43,40 +45,54 @@ namespace PluralKit.Bot // Ignore reactions from bots (we can't DM them anyway) if (evt.User.IsBot) return; - - Task GetMessage() => - _db.Execute(c => _repo.GetMessage(c, evt.Message.Id)); - - FullMessage msg; - CommandMessage cmdmsg; + switch (evt.Emoji.Name) { // Message deletion case "\u274C": // Red X - if ((msg = await GetMessage()) != null) - await HandleDeleteReaction(evt, msg); - else if ((cmdmsg = await _db.Execute(conn => _repo.GetCommandMessage(conn, evt.Message.Id))) != null) - await HandleCommandDeleteReaction(evt, cmdmsg); - break; - + { + await using var conn = await _db.Obtain(); + var msg = await _repo.GetMessage(conn, evt.Message.Id); + if (msg != null) + { + await HandleProxyDeleteReaction(evt, msg); + break; + } + + var commandMsg = await _commandMessageService.GetCommandMessage(conn, evt.Message.Id); + if (commandMsg != null) + await HandleCommandDeleteReaction(evt, commandMsg); + + break; + } + case "\u2753": // Red question mark case "\u2754": // White question mark - if ((msg = await GetMessage()) != null) + { + await using var conn = await _db.Obtain(); + var msg = await _repo.GetMessage(conn, evt.Message.Id); + if (msg != null) await HandleQueryReaction(evt, msg); + break; - + } + case "\U0001F514": // Bell case "\U0001F6CE": // Bellhop bell case "\U0001F3D3": // Ping pong paddle (lol) case "\u23F0": // Alarm clock case "\u2757": // Exclamation mark - if ((msg = await GetMessage()) != null) + { + await using var conn = await _db.Obtain(); + var msg = await _repo.GetMessage(conn, evt.Message.Id); + if (msg != null) await HandlePingReaction(evt, msg); break; + } } } - private async ValueTask HandleDeleteReaction(MessageReactionAddEventArgs evt, FullMessage msg) + private async ValueTask HandleProxyDeleteReaction(MessageReactionAddEventArgs evt, FullMessage msg) { if (!evt.Channel.BotHasAllPermissions(Permissions.ManageMessages)) return; @@ -97,10 +113,12 @@ namespace PluralKit.Bot private async ValueTask HandleCommandDeleteReaction(MessageReactionAddEventArgs evt, CommandMessage msg) { - if (!evt.Channel.BotHasAllPermissions(Permissions.ManageMessages)) return; + if (!evt.Channel.BotHasAllPermissions(Permissions.ManageMessages)) + return; // Can only delete your own message - if (msg.author_id != evt.User.Id) return; + if (msg.AuthorId != evt.User.Id) + return; try { diff --git a/PluralKit.Bot/Modules.cs b/PluralKit.Bot/Modules.cs index ba3e15b9..a6752cbc 100644 --- a/PluralKit.Bot/Modules.cs +++ b/PluralKit.Bot/Modules.cs @@ -74,6 +74,7 @@ namespace PluralKit.Bot builder.RegisterType().AsSelf().SingleInstance(); builder.RegisterType().AsSelf().SingleInstance(); builder.RegisterType().AsSelf().SingleInstance(); + builder.RegisterType().AsSelf().SingleInstance(); // Sentry stuff builder.Register(_ => new Scope(null)).AsSelf().InstancePerLifetimeScope(); diff --git a/PluralKit.Bot/Services/CommandMessageService.cs b/PluralKit.Bot/Services/CommandMessageService.cs new file mode 100644 index 00000000..5ed33c4c --- /dev/null +++ b/PluralKit.Bot/Services/CommandMessageService.cs @@ -0,0 +1,50 @@ +using System.Threading.Tasks; + +using NodaTime; + +using PluralKit.Core; + +using Serilog; + +namespace PluralKit.Bot +{ + public class CommandMessageService + { + private static readonly Duration CommandMessageRetention = Duration.FromSeconds(2); + + private readonly IDatabase _db; + private readonly ModelRepository _repo; + private readonly IClock _clock; + private readonly ILogger _logger; + + public CommandMessageService(IDatabase db, ModelRepository repo, IClock clock, ILogger logger) + { + _db = db; + _repo = repo; + _clock = clock; + _logger = logger; + } + + public async Task RegisterMessage(ulong messageId, ulong authorId) + { + _logger.Debug("Registering command response {MessageId} from author {AuthorId}", messageId, authorId); + await _db.Execute(conn => _repo.SaveCommandMessage(conn, messageId, authorId)); + } + + public async Task GetCommandMessage(IPKConnection conn, ulong messageId) + { + return await _repo.GetCommandMessage(conn, messageId); + } + + public async Task CleanupOldMessages() + { + var deleteThresholdInstant = _clock.GetCurrentInstant() - CommandMessageRetention; + var deleteThresholdSnowflake = DiscordUtils.InstantToSnowflake(deleteThresholdInstant); + + var deletedRows = await _db.Execute(conn => _repo.DeleteCommandMessagesBefore(conn, deleteThresholdSnowflake)); + + _logger.Information("Pruned {DeletedRows} command messages older than retention {Retention} (older than {DeleteThresholdInstant} / {DeleteThresholdSnowflake})", + deletedRows, CommandMessageRetention, deleteThresholdInstant, deleteThresholdSnowflake); + } + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Migrations/11.sql b/PluralKit.Core/Database/Migrations/11.sql index bfed37e1..a16bf097 100644 --- a/PluralKit.Core/Database/Migrations/11.sql +++ b/PluralKit.Core/Database/Migrations/11.sql @@ -1,17 +1,10 @@ --- SCHEMA VERSION 11: (insert date) -- +-- SCHEMA VERSION 11: 2020-10-23 -- -- Create command message table -- -create table command_message +create table command_messages ( - message_id bigint primary key, - author_id bigint not null, - timestamp timestamp not null default now() + message_id bigint primary key not null, + author_id bigint not null ); -create function cleanup_command_message() returns void as $$ -begin - delete from command_message where timestamp < now() - interval '2 hours'; -end; -$$ language plpgsql; - update info set schema_version = 11; diff --git a/PluralKit.Core/Database/Repository/ModelRepository.CommandMessage.cs b/PluralKit.Core/Database/Repository/ModelRepository.CommandMessage.cs index 1e38c447..69222313 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.CommandMessage.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.CommandMessage.cs @@ -1,5 +1,3 @@ -using System.Collections.Generic; -using System.Data; using System.Threading.Tasks; using Dapper; @@ -8,17 +6,22 @@ namespace PluralKit.Core { public partial class ModelRepository { - public Task SaveCommandMessage(IPKConnection conn, ulong message_id, ulong author_id) => - conn.QueryAsync("insert into command_message (message_id, author_id) values (@Message, @Author)", - new {Message = message_id, Author = author_id }); + public Task SaveCommandMessage(IPKConnection conn, ulong messageId, ulong authorId) => + conn.QueryAsync("insert into command_messages (message_id, author_id) values (@Message, @Author)", + new {Message = messageId, Author = authorId }); - public Task GetCommandMessage(IPKConnection conn, ulong message_id) => - conn.QuerySingleOrDefaultAsync("select message_id, author_id from command_message where message_id = @Message", - new {Message = message_id}); - } + public Task GetCommandMessage(IPKConnection conn, ulong messageId) => + conn.QuerySingleOrDefaultAsync("select * from command_messages where message_id = @Message", + new {Message = messageId}); + + public Task DeleteCommandMessagesBefore(IPKConnection conn, ulong messageIdThreshold) => + conn.ExecuteAsync("delete from command_messages where message_id < @Threshold", + new {Threshold = messageIdThreshold}); + } public class CommandMessage { - public ulong author_id { get; set; } + public ulong AuthorId { get; set; } + public ulong MessageId { get; set; } } } \ No newline at end of file