From 2579683da90f58f30094e68028c19912771d0160 Mon Sep 17 00:00:00 2001 From: Ske Date: Tue, 5 May 2020 16:03:46 +0200 Subject: [PATCH] Refactor interactive event handlers --- PluralKit.Bot/Bot.cs | 12 ++-- PluralKit.Bot/CommandSystem/Context.cs | 2 + PluralKit.Bot/Commands/SystemEdit.cs | 6 +- PluralKit.Bot/Modules.cs | 4 ++ PluralKit.Bot/Utils/ContextUtils.cs | 87 +++++++++++++++++--------- PluralKit.Bot/Utils/HandlerQueue.cs | 79 +++++++++++++++++++++++ 6 files changed, 155 insertions(+), 35 deletions(-) create mode 100644 PluralKit.Bot/Utils/HandlerQueue.cs diff --git a/PluralKit.Bot/Bot.cs b/PluralKit.Bot/Bot.cs index ec4a1d32..3035908a 100644 --- a/PluralKit.Bot/Bot.cs +++ b/PluralKit.Bot/Bot.cs @@ -72,16 +72,20 @@ namespace PluralKit.Bot { var serviceScope = _services.BeginLifetimeScope(); - // Find an event handler that can handle the type of event () we're given - var handler = serviceScope.Resolve>(); - // Also, find a Sentry enricher for the event type (if one is present), and ask it to put some event data in the Sentry scope var sentryEnricher = serviceScope.ResolveOptional>(); sentryEnricher?.Enrich(serviceScope.Resolve(), evt); + // Find an event handler that can handle the type of event () we're given + var handler = serviceScope.Resolve>(); + var queue = serviceScope.ResolveOptional>(); try { - await handler.Handle(evt); + // Delegate to the queue to see if it wants to handle this event + // the TryHandle call returns true if it's handled the event + // Usually it won't, so just pass it on to the main handler + if (queue == null || !await queue.TryHandle(evt)) + await handler.Handle(evt); } catch (Exception exc) { diff --git a/PluralKit.Bot/CommandSystem/Context.cs b/PluralKit.Bot/CommandSystem/Context.cs index 54c3e7c0..eeabda93 100644 --- a/PluralKit.Bot/CommandSystem/Context.cs +++ b/PluralKit.Bot/CommandSystem/Context.cs @@ -304,5 +304,7 @@ namespace PluralKit.Bot return null; } } + + public IComponentContext Services => _provider; } } \ No newline at end of file diff --git a/PluralKit.Bot/Commands/SystemEdit.cs b/PluralKit.Bot/Commands/SystemEdit.cs index d5cf1645..a4bd18b0 100644 --- a/PluralKit.Bot/Commands/SystemEdit.cs +++ b/PluralKit.Bot/Commands/SystemEdit.cs @@ -169,9 +169,9 @@ namespace PluralKit.Bot public async Task Delete(Context ctx) { ctx.CheckSystem(); - var msg = await ctx.Reply($"{Emojis.Warn} Are you sure you want to delete your system? If so, reply to this message with your system's ID (`{ctx.System.Hid}`).\n**Note: this action is permanent.**"); - var reply = await ctx.AwaitMessage(ctx.Channel, ctx.Author, timeout: TimeSpan.FromMinutes(1)); - if (reply.Content != ctx.System.Hid) throw new PKError($"System deletion cancelled. Note that you must reply with your system ID (`{ctx.System.Hid}`) *verbatim*."); + await ctx.Reply($"{Emojis.Warn} Are you sure you want to delete your system? If so, reply to this message with your system's ID (`{ctx.System.Hid}`).\n**Note: this action is permanent.**"); + if (!await ctx.ConfirmWithReply(ctx.System.Hid)) + throw new PKError($"System deletion cancelled. Note that you must reply with your system ID (`{ctx.System.Hid}`) *verbatim*."); await _data.DeleteSystem(ctx.System); await ctx.Reply($"{Emojis.Success} System deleted."); diff --git a/PluralKit.Bot/Modules.cs b/PluralKit.Bot/Modules.cs index 4de422cc..01dc12cf 100644 --- a/PluralKit.Bot/Modules.cs +++ b/PluralKit.Bot/Modules.cs @@ -54,6 +54,10 @@ namespace PluralKit.Bot builder.RegisterType().As>(); builder.RegisterType().As>(); + // Event handler queue + builder.RegisterType>().AsSelf().SingleInstance(); + builder.RegisterType>().AsSelf().SingleInstance(); + // Bot services builder.RegisterType().AsSelf().SingleInstance(); builder.RegisterType().AsSelf().SingleInstance(); diff --git a/PluralKit.Bot/Utils/ContextUtils.cs b/PluralKit.Bot/Utils/ContextUtils.cs index 71123bd2..9c0067e1 100644 --- a/PluralKit.Bot/Utils/ContextUtils.cs +++ b/PluralKit.Bot/Utils/ContextUtils.cs @@ -1,22 +1,71 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; +using Autofac; + using DSharpPlus; using DSharpPlus.Entities; using DSharpPlus.EventArgs; using DSharpPlus.Exceptions; +using NodaTime; + using PluralKit.Core; namespace PluralKit.Bot { public static class ContextUtils { - public static async Task PromptYesNo(this Context ctx, DiscordMessage message, DiscordUser user = null, TimeSpan? timeout = null) { + public static async Task PromptYesNo(this Context ctx, DiscordMessage message, DiscordUser user = null, Duration? timeout = null) + { + var cts = new CancellationTokenSource(); + if (user == null) user = ctx.Author; + if (timeout == null) timeout = Duration.FromMinutes(5); + // "Fork" the task adding the reactions off so we don't have to wait for them to be finished to start listening for presses var _ = message.CreateReactionsBulk(new[] {Emojis.Success, Emojis.Error}); - var reaction = await ctx.AwaitReaction(message, user ?? ctx.Author, r => r.Emoji.Name == Emojis.Success || r.Emoji.Name == Emojis.Error, timeout ?? TimeSpan.FromMinutes(1)); - return reaction.Emoji.Name == Emojis.Success; + + bool ReactionPredicate(MessageReactionAddEventArgs e) + { + if (e.Channel.Id != message.ChannelId || e.Message.Id != message.Id) return false; + if (e.User.Id != user.Id) return false; + return true; + } + + bool MessagePredicate(MessageCreateEventArgs e) + { + if (e.Channel.Id != message.ChannelId) return false; + if (e.Author.Id != user.Id) return false; + + var strings = new [] {"y", "yes", "n", "no"}; + foreach (var str in strings) + if (e.Message.Content.Equals(str, StringComparison.InvariantCultureIgnoreCase)) + return true; + + return false; + } + + var messageTask = ctx.Services.Resolve>().WaitFor(MessagePredicate, timeout, cts.Token); + var reactionTask = ctx.Services.Resolve>().WaitFor(ReactionPredicate, timeout, cts.Token); + + var theTask = await Task.WhenAny(messageTask, reactionTask); + cts.Cancel(); + + if (theTask == messageTask) + { + var responseMsg = (await messageTask).Message; + var positives = new[] {"y", "yes"}; + foreach (var p in positives) + if (responseMsg.Content.Equals(p, StringComparison.InvariantCultureIgnoreCase)) + return true; + return false; + } + + if (theTask == reactionTask) + return (await reactionTask).Emoji.Name == Emojis.Success; + + return false; } public static async Task AwaitReaction(this Context ctx, DiscordMessage message, DiscordUser user = null, Func predicate = null, TimeSpan? timeout = null) { @@ -37,33 +86,15 @@ namespace PluralKit.Bot { } } - public static async Task AwaitMessage(this Context ctx, DiscordChannel channel, DiscordUser user = null, Func predicate = null, TimeSpan? timeout = null) { - var tcs = new TaskCompletionSource(); - Task Inner(MessageCreateEventArgs args) - { - var msg = args.Message; - if (channel != msg.Channel) return Task.CompletedTask; // Ignore messages in a different channel - if (user != null && user != msg.Author) return Task.CompletedTask; // Ignore messages from other users - if (predicate != null && !predicate.Invoke(msg)) return Task.CompletedTask; // Check predicate - tcs.SetResult(msg); - return Task.CompletedTask; - } - - ctx.Shard.MessageCreated += Inner; - try - { - return await (tcs.Task.TimeoutAfter(timeout)); - } - finally - { - ctx.Shard.MessageCreated -= Inner; - } - } - public static async Task ConfirmWithReply(this Context ctx, string expectedReply) { - var msg = await ctx.AwaitMessage(ctx.Channel, ctx.Author, timeout: TimeSpan.FromMinutes(1)); - return string.Equals(msg.Content, expectedReply, StringComparison.InvariantCultureIgnoreCase); + bool Predicate(MessageCreateEventArgs e) => + e.Author == ctx.Author && e.Channel.Id == ctx.Channel.Id; + + var msg = await ctx.Services.Resolve>() + .WaitFor(Predicate, Duration.FromMinutes(1)); + + return string.Equals(msg.Message.Content, expectedReply, StringComparison.InvariantCultureIgnoreCase); } public static async Task Paginate(this Context ctx, IAsyncEnumerable items, int totalCount, int itemsPerPage, string title, Func, Task> renderer) { diff --git a/PluralKit.Bot/Utils/HandlerQueue.cs b/PluralKit.Bot/Utils/HandlerQueue.cs new file mode 100644 index 00000000..732194c5 --- /dev/null +++ b/PluralKit.Bot/Utils/HandlerQueue.cs @@ -0,0 +1,79 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +using NodaTime; + +namespace PluralKit.Bot +{ + public class HandlerQueue + { + private readonly List _handlers = new List(); + + public HandlerEntry Add(Func> handler) + { + var entry = new HandlerEntry {Handler = handler}; + _handlers.Add(entry); + return entry; + } + + public async Task WaitFor(Func predicate, Duration? timeout = null, CancellationToken ct = default) + { + var timeoutTask = Task.Delay(timeout?.ToTimeSpan() ?? TimeSpan.FromMilliseconds(-1), ct); + var tcs = new TaskCompletionSource(); + + Task Handler(T e) + { + var matches = predicate(e); + if (matches) tcs.SetResult(e); + return Task.FromResult(matches); + } + + var entry = new HandlerEntry {Handler = Handler}; + _handlers.Add(entry); + + // Wait for either the event task or the timeout task + // If the timeout task finishes first, raise, otherwise pass event through + try + { + var theTask = await Task.WhenAny(timeoutTask, tcs.Task); + if (theTask == timeoutTask) + throw new TimeoutException(); + } + finally + { + entry.Remove(); + } + + return await tcs.Task; + } + + public async Task TryHandle(T evt) + { + _handlers.RemoveAll(he => !he.Alive); + + var now = SystemClock.Instance.GetCurrentInstant(); + foreach (var entry in _handlers) + { + if (entry.Expiry < now) entry.Alive = false; + else if (entry.Alive && await entry.Handler(evt)) + { + entry.Alive = false; + return true; + } + } + + return false; + } + + public class HandlerEntry + { + internal Func> Handler; + internal bool Alive = true; + internal Instant Expiry = SystemClock.Instance.GetCurrentInstant() + Duration.FromMinutes(30); + + public void Remove() => Alive = false; + } + } +} \ No newline at end of file