Refactor command message deletion
This commit is contained in:
		| @@ -36,13 +36,13 @@ namespace PluralKit.Bot | ||||
|         private readonly PeriodicStatCollector _collector; | ||||
|         private readonly IMetrics _metrics; | ||||
|         private readonly ErrorMessageService _errorMessageService; | ||||
|         private readonly IDatabase _db; | ||||
|         private readonly CommandMessageService _commandMessageService; | ||||
|  | ||||
|         private bool _hasReceivedReady = false; | ||||
|         private Timer _periodicTask; // Never read, just kept here for GC reasons | ||||
|  | ||||
|         public Bot(DiscordShardedClient client, ILifetimeScope services, ILogger logger, PeriodicStatCollector collector, IMetrics metrics,  | ||||
|             ErrorMessageService errorMessageService, IDatabase db) | ||||
|             ErrorMessageService errorMessageService, CommandMessageService commandMessageService) | ||||
|         { | ||||
|             _client = client; | ||||
|             _logger = logger.ForContext<Bot>(); | ||||
| @@ -50,7 +50,7 @@ namespace PluralKit.Bot | ||||
|             _collector = collector; | ||||
|             _metrics = metrics; | ||||
|             _errorMessageService = errorMessageService; | ||||
|             _db = db; | ||||
|             _commandMessageService = commandMessageService; | ||||
|         } | ||||
|  | ||||
|         public void Init() | ||||
| @@ -183,7 +183,7 @@ namespace PluralKit.Bot | ||||
|             await UpdateBotStatus(); | ||||
|  | ||||
|             // Clean up message cache in postgres | ||||
|             await _db.Execute(conn => conn.QueryAsync("select from cleanup_command_message()")); | ||||
|             await _commandMessageService.CleanupOldMessages(); | ||||
|  | ||||
|             // Collect some stats, submit them to the metrics backend | ||||
|             await _collector.CollectStats(); | ||||
|   | ||||
| @@ -28,6 +28,7 @@ namespace PluralKit.Bot | ||||
|         private readonly ModelRepository _repo; | ||||
|         private readonly PKSystem _senderSystem; | ||||
|         private readonly IMetrics _metrics; | ||||
|         private readonly CommandMessageService _commandMessageService; | ||||
|  | ||||
|         private Command _currentCommand; | ||||
|  | ||||
| @@ -44,6 +45,7 @@ namespace PluralKit.Bot | ||||
|             _repo = provider.Resolve<ModelRepository>(); | ||||
|             _metrics = provider.Resolve<IMetrics>(); | ||||
|             _provider = provider; | ||||
|             _commandMessageService = provider.Resolve<CommandMessageService>(); | ||||
|             _parameters = new Parameters(message.Content.Substring(commandParseOffset)); | ||||
|         } | ||||
|  | ||||
| @@ -73,10 +75,14 @@ namespace PluralKit.Bot | ||||
|             if (embed != null && !this.BotHasAllPermissions(Permissions.EmbedLinks)) | ||||
|                 throw new PKError("PluralKit does not have permission to send embeds in this channel. Please ensure I have the **Embed Links** permission enabled."); | ||||
|             var msg = await Channel.SendMessageFixedAsync(text, embed: embed, mentions: mentions); | ||||
|             if (embed != null)  | ||||
|  | ||||
|             if (embed != null) | ||||
|             { | ||||
|                 // Sensitive information that might want to be deleted by :x: reaction is typically in an embed format (member cards, for example) | ||||
|                 // This may need to be changed at some point but works well enough for now | ||||
|                 await _db.Execute(conn => _repo.SaveCommandMessage(conn, msg.Id, Author.Id)); | ||||
|                 await _commandMessageService.RegisterMessage(msg.Id, Author.Id); | ||||
|             } | ||||
|  | ||||
|             return msg; | ||||
|         } | ||||
|          | ||||
|   | ||||
| @@ -15,14 +15,16 @@ namespace PluralKit.Bot | ||||
|     { | ||||
|         private readonly IDatabase _db; | ||||
|         private readonly ModelRepository _repo; | ||||
|         private readonly CommandMessageService _commandMessageService; | ||||
|         private readonly EmbedService _embeds; | ||||
|         private readonly ILogger _logger; | ||||
|  | ||||
|         public ReactionAdded(EmbedService embeds, ILogger logger, IDatabase db, ModelRepository repo) | ||||
|         public ReactionAdded(EmbedService embeds, ILogger logger, IDatabase db, ModelRepository repo, CommandMessageService commandMessageService) | ||||
|         { | ||||
|             _embeds = embeds; | ||||
|             _db = db; | ||||
|             _repo = repo; | ||||
|             _commandMessageService = commandMessageService; | ||||
|             _logger = logger.ForContext<ReactionAdded>(); | ||||
|         } | ||||
|  | ||||
| @@ -43,40 +45,54 @@ namespace PluralKit.Bot | ||||
|              | ||||
|             // Ignore reactions from bots (we can't DM them anyway) | ||||
|             if (evt.User.IsBot) return; | ||||
|  | ||||
|             Task<FullMessage> GetMessage() =>  | ||||
|                 _db.Execute(c => _repo.GetMessage(c, evt.Message.Id)); | ||||
|  | ||||
|             FullMessage msg; | ||||
|             CommandMessage cmdmsg; | ||||
|              | ||||
|             switch (evt.Emoji.Name) | ||||
|             { | ||||
|                 // Message deletion | ||||
|                 case "\u274C": // Red X | ||||
|                     if ((msg = await GetMessage()) != null) | ||||
|                         await HandleDeleteReaction(evt, msg); | ||||
|                     else if ((cmdmsg = await _db.Execute(conn => _repo.GetCommandMessage(conn, evt.Message.Id))) != null) | ||||
|                         await HandleCommandDeleteReaction(evt, cmdmsg); | ||||
|                     break;                 | ||||
|                  | ||||
|                 { | ||||
|                     await using var conn = await _db.Obtain(); | ||||
|                     var msg = await _repo.GetMessage(conn, evt.Message.Id); | ||||
|                     if (msg != null) | ||||
|                     { | ||||
|                         await HandleProxyDeleteReaction(evt, msg); | ||||
|                         break; | ||||
|                     } | ||||
|  | ||||
|                     var commandMsg = await _commandMessageService.GetCommandMessage(conn, evt.Message.Id); | ||||
|                     if (commandMsg != null)  | ||||
|                         await HandleCommandDeleteReaction(evt, commandMsg); | ||||
|  | ||||
|                     break; | ||||
|                 } | ||||
|  | ||||
|                 case "\u2753": // Red question mark | ||||
|                 case "\u2754": // White question mark | ||||
|                     if ((msg = await GetMessage()) != null)  | ||||
|                 { | ||||
|                     await using var conn = await _db.Obtain(); | ||||
|                     var msg = await _repo.GetMessage(conn, evt.Message.Id); | ||||
|                     if (msg != null) | ||||
|                         await HandleQueryReaction(evt, msg); | ||||
|                      | ||||
|                     break; | ||||
|                  | ||||
|                 } | ||||
|  | ||||
|                 case "\U0001F514": // Bell | ||||
|                 case "\U0001F6CE": // Bellhop bell | ||||
|                 case "\U0001F3D3": // Ping pong paddle (lol) | ||||
|                 case "\u23F0": // Alarm clock | ||||
|                 case "\u2757": // Exclamation mark | ||||
|                     if ((msg = await GetMessage()) != null) | ||||
|                 { | ||||
|                     await using var conn = await _db.Obtain(); | ||||
|                     var msg = await _repo.GetMessage(conn, evt.Message.Id); | ||||
|                     if (msg != null) | ||||
|                         await HandlePingReaction(evt, msg); | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         private async ValueTask HandleDeleteReaction(MessageReactionAddEventArgs evt, FullMessage msg) | ||||
|         private async ValueTask HandleProxyDeleteReaction(MessageReactionAddEventArgs evt, FullMessage msg) | ||||
|         { | ||||
|             if (!evt.Channel.BotHasAllPermissions(Permissions.ManageMessages)) return; | ||||
|              | ||||
| @@ -97,10 +113,12 @@ namespace PluralKit.Bot | ||||
|  | ||||
|         private async ValueTask HandleCommandDeleteReaction(MessageReactionAddEventArgs evt, CommandMessage msg) | ||||
|         { | ||||
|             if (!evt.Channel.BotHasAllPermissions(Permissions.ManageMessages)) return; | ||||
|             if (!evt.Channel.BotHasAllPermissions(Permissions.ManageMessages))  | ||||
|                 return; | ||||
|  | ||||
|             // Can only delete your own message | ||||
|             if (msg.author_id != evt.User.Id) return; | ||||
|             if (msg.AuthorId != evt.User.Id)  | ||||
|                 return; | ||||
|  | ||||
|             try | ||||
|             { | ||||
|   | ||||
| @@ -74,6 +74,7 @@ namespace PluralKit.Bot | ||||
|             builder.RegisterType<LastMessageCacheService>().AsSelf().SingleInstance(); | ||||
|             builder.RegisterType<LoggerCleanService>().AsSelf().SingleInstance(); | ||||
|             builder.RegisterType<ErrorMessageService>().AsSelf().SingleInstance(); | ||||
|             builder.RegisterType<CommandMessageService>().AsSelf().SingleInstance(); | ||||
|              | ||||
|             // Sentry stuff | ||||
|             builder.Register(_ => new Scope(null)).AsSelf().InstancePerLifetimeScope(); | ||||
|   | ||||
							
								
								
									
										50
									
								
								PluralKit.Bot/Services/CommandMessageService.cs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								PluralKit.Bot/Services/CommandMessageService.cs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | ||||
| using System.Threading.Tasks; | ||||
|  | ||||
| using NodaTime; | ||||
|  | ||||
| using PluralKit.Core; | ||||
|  | ||||
| using Serilog; | ||||
|  | ||||
| namespace PluralKit.Bot | ||||
| { | ||||
|     public class CommandMessageService | ||||
|     { | ||||
|         private static readonly Duration CommandMessageRetention = Duration.FromSeconds(2); | ||||
|          | ||||
|         private readonly IDatabase _db; | ||||
|         private readonly ModelRepository _repo; | ||||
|         private readonly IClock _clock; | ||||
|         private readonly ILogger _logger; | ||||
|          | ||||
|         public CommandMessageService(IDatabase db, ModelRepository repo, IClock clock, ILogger logger) | ||||
|         { | ||||
|             _db = db; | ||||
|             _repo = repo; | ||||
|             _clock = clock; | ||||
|             _logger = logger; | ||||
|         } | ||||
|  | ||||
|         public async Task RegisterMessage(ulong messageId, ulong authorId) | ||||
|         { | ||||
|             _logger.Debug("Registering command response {MessageId} from author {AuthorId}", messageId, authorId); | ||||
|             await _db.Execute(conn => _repo.SaveCommandMessage(conn, messageId, authorId)); | ||||
|         } | ||||
|  | ||||
|         public async Task<CommandMessage> GetCommandMessage(IPKConnection conn, ulong messageId) | ||||
|         { | ||||
|             return await _repo.GetCommandMessage(conn, messageId); | ||||
|         } | ||||
|  | ||||
|         public async Task CleanupOldMessages() | ||||
|         { | ||||
|             var deleteThresholdInstant = _clock.GetCurrentInstant() - CommandMessageRetention; | ||||
|             var deleteThresholdSnowflake = DiscordUtils.InstantToSnowflake(deleteThresholdInstant); | ||||
|  | ||||
|             var deletedRows = await _db.Execute(conn => _repo.DeleteCommandMessagesBefore(conn, deleteThresholdSnowflake)); | ||||
|              | ||||
|             _logger.Information("Pruned {DeletedRows} command messages older than retention {Retention} (older than {DeleteThresholdInstant} / {DeleteThresholdSnowflake})", | ||||
|                 deletedRows, CommandMessageRetention, deleteThresholdInstant, deleteThresholdSnowflake); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -1,17 +1,10 @@ | ||||
| -- SCHEMA VERSION 11: (insert date) -- | ||||
| -- SCHEMA VERSION 11: 2020-10-23  -- | ||||
| -- Create command message table -- | ||||
|  | ||||
| create table command_message | ||||
| create table command_messages | ||||
| ( | ||||
| 	message_id bigint primary key, | ||||
| 	author_id bigint not null, | ||||
| 	timestamp timestamp not null default now() | ||||
| 	message_id bigint primary key not null, | ||||
| 	author_id bigint not null | ||||
| ); | ||||
|  | ||||
| create function cleanup_command_message() returns void as $$ | ||||
| begin | ||||
|     delete from command_message where timestamp < now() - interval '2 hours'; | ||||
| end; | ||||
| $$ language plpgsql; | ||||
|  | ||||
| update info set schema_version = 11; | ||||
|   | ||||
| @@ -1,5 +1,3 @@ | ||||
| using System.Collections.Generic; | ||||
| using System.Data; | ||||
| using System.Threading.Tasks; | ||||
|  | ||||
| using Dapper; | ||||
| @@ -8,17 +6,22 @@ namespace PluralKit.Core | ||||
| { | ||||
|     public partial class ModelRepository | ||||
| 	{ | ||||
| 		public Task SaveCommandMessage(IPKConnection conn, ulong message_id, ulong author_id) => | ||||
| 			conn.QueryAsync("insert into command_message (message_id, author_id) values (@Message, @Author)", | ||||
| 				new {Message = message_id, Author = author_id }); | ||||
| 		public Task SaveCommandMessage(IPKConnection conn, ulong messageId, ulong authorId) => | ||||
| 			conn.QueryAsync("insert into command_messages (message_id, author_id) values (@Message, @Author)", | ||||
| 				new {Message = messageId, Author = authorId }); | ||||
|  | ||||
| 		public Task<CommandMessage> GetCommandMessage(IPKConnection conn, ulong message_id) => | ||||
| 			conn.QuerySingleOrDefaultAsync<CommandMessage>("select message_id, author_id from command_message where message_id = @Message",  | ||||
| 				new {Message = message_id}); | ||||
| 	} | ||||
| 		public Task<CommandMessage> GetCommandMessage(IPKConnection conn, ulong messageId) => | ||||
| 			conn.QuerySingleOrDefaultAsync<CommandMessage>("select * from command_messages where message_id = @Message",  | ||||
| 				new {Message = messageId}); | ||||
|  | ||||
|         public Task<int> DeleteCommandMessagesBefore(IPKConnection conn, ulong messageIdThreshold) => | ||||
|             conn.ExecuteAsync("delete from command_messages where message_id < @Threshold", | ||||
|                 new {Threshold = messageIdThreshold}); | ||||
|     } | ||||
|  | ||||
| 	public class CommandMessage | ||||
| 	{ | ||||
| 		public ulong author_id { get; set; } | ||||
|         public ulong AuthorId { get; set; } | ||||
|         public ulong MessageId { get; set; } | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user