diff --git a/Myriad/Rest/DiscordApiClient.cs b/Myriad/Rest/DiscordApiClient.cs index 4612fd2c..6ab105a5 100644 --- a/Myriad/Rest/DiscordApiClient.cs +++ b/Myriad/Rest/DiscordApiClient.cs @@ -121,6 +121,11 @@ namespace Myriad.Rest _client.PostMultipart($"/webhooks/{webhookId}/{webhookToken}?wait=true", ("ExecuteWebhook", webhookId), request, files)!; + public Task EditWebhookMessage(ulong webhookId, string webhookToken, ulong messageId, + WebhookMessageEditRequest request) => + _client.Patch($"/webhooks/{webhookId}/{webhookToken}/messages/{messageId}", + ("EditWebhookMessage", webhookId), request)!; + public Task CreateDm(ulong recipientId) => _client.Post($"/users/@me/channels", ("CreateDM", default), new CreateDmRequest(recipientId))!; diff --git a/Myriad/Rest/Types/Requests/WebhookMessageEditRequest.cs b/Myriad/Rest/Types/Requests/WebhookMessageEditRequest.cs new file mode 100644 index 00000000..039ac625 --- /dev/null +++ b/Myriad/Rest/Types/Requests/WebhookMessageEditRequest.cs @@ -0,0 +1,15 @@ +using System.Text.Json.Serialization; + +using Myriad.Utils; + +namespace Myriad.Rest.Types.Requests +{ + public record WebhookMessageEditRequest + { + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public Optional Content { get; init; } + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public Optional AllowedMentions { get; init; } + } +} \ No newline at end of file diff --git a/PluralKit.Bot/CommandSystem/ContextArgumentsExt.cs b/PluralKit.Bot/CommandSystem/ContextArgumentsExt.cs index 3e1b2572..06dfa5fe 100644 --- a/PluralKit.Bot/CommandSystem/ContextArgumentsExt.cs +++ b/PluralKit.Bot/CommandSystem/ContextArgumentsExt.cs @@ -1,8 +1,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.RegularExpressions; using System.Threading.Tasks; +using Myriad.Types; + using PluralKit.Core; namespace PluralKit.Bot @@ -68,6 +71,27 @@ namespace PluralKit.Bot return matched; } + public static ulong? MatchMessage(this Context ctx, bool parseRawMessageId) + { + if (ctx.Message.Type == Message.MessageType.Reply && ctx.Message.MessageReference != null) + return ctx.Message.MessageReference.MessageId; + + var word = ctx.PeekArgument(); + if (word == null) + return null; + + if (parseRawMessageId && ulong.TryParse(word, out var mid)) + return mid; + + var match = Regex.Match(word, "https://(?:\\w+.)?discord(?:app)?.com/channels/\\d+/\\d+/(\\d+)"); + if (!match.Success) + return null; + + var messageId = ulong.Parse(match.Groups[1].Value); + ctx.PopArgument(); + return messageId; + } + public static async Task> ParseMemberList(this Context ctx, SystemId? restrictToSystem) { var members = new List(); diff --git a/PluralKit.Bot/Commands/CommandTree.cs b/PluralKit.Bot/Commands/CommandTree.cs index 872bc86b..717a5667 100644 --- a/PluralKit.Bot/Commands/CommandTree.cs +++ b/PluralKit.Bot/Commands/CommandTree.cs @@ -79,6 +79,7 @@ namespace PluralKit.Bot public static Command Help = new Command("help", "help", "Shows help information about PluralKit"); public static Command Explain = new Command("explain", "explain", "Explains the basics of systems and proxying"); public static Command Message = new Command("message", "message [delete|author]", "Looks up a proxied message"); + public static Command MessageEdit = new Command("edit", "edit [link] ", "Edit a previously proxied message"); public static Command LogChannel = new Command("log channel", "log channel ", "Designates a channel to post proxied messages to"); public static Command LogChannelClear = new Command("log channel", "log channel -clear", "Clears the currently set log channel"); public static Command LogEnable = new Command("log enable", "log enable all| [channel 2] [channel 3...]", "Enables message logging in certain channels"); @@ -160,6 +161,8 @@ namespace PluralKit.Bot return ctx.Execute(Explain, m => m.Explain(ctx)); if (ctx.Match("message", "msg")) return ctx.Execute(Message, m => m.GetMessage(ctx)); + if (ctx.Match("edit", "e")) + return ctx.Execute(MessageEdit, m => m.EditMessage(ctx)); if (ctx.Match("log")) if (ctx.Match("channel")) return ctx.Execute(LogChannel, m => m.SetLogChannel(ctx)); diff --git a/PluralKit.Bot/Commands/MessageEdit.cs b/PluralKit.Bot/Commands/MessageEdit.cs new file mode 100644 index 00000000..c979bc59 --- /dev/null +++ b/PluralKit.Bot/Commands/MessageEdit.cs @@ -0,0 +1,91 @@ +#nullable enable +using System.Threading.Tasks; + +using Myriad.Rest; +using Myriad.Rest.Exceptions; +using Myriad.Types; + +using NodaTime; + +using PluralKit.Core; + +namespace PluralKit.Bot +{ + public class MessageEdit + { + private static readonly Duration EditTimeout = Duration.FromMinutes(10); + + private readonly IDatabase _db; + private readonly ModelRepository _repo; + private readonly IClock _clock; + private readonly DiscordApiClient _rest; + private readonly WebhookExecutorService _webhookExecutor; + + public MessageEdit(IDatabase db, ModelRepository repo, IClock clock, DiscordApiClient rest, WebhookExecutorService webhookExecutor) + { + _db = db; + _repo = repo; + _clock = clock; + _rest = rest; + _webhookExecutor = webhookExecutor; + } + + public async Task EditMessage(Context ctx) + { + var msg = await GetMessageToEdit(ctx); + if (!ctx.HasNext()) + throw new PKSyntaxError("You need to include the message to edit in."); + + if (ctx.Author.Id != msg.Sender) + throw new PKError("Can't edit a message sent from a different account."); + + var newContent = ctx.RemainderOrNull(); + + try + { + await _webhookExecutor.EditWebhookMessage(msg.Channel, msg.Mid, newContent); + + if (ctx.BotPermissions.HasFlag(PermissionSet.ManageMessages)) + await _rest.DeleteMessage(ctx.Channel.Id, ctx.Message.Id); + } + catch (NotFoundException) + { + throw new PKError("Could not edit message."); + } + } + + private async Task GetMessageToEdit(Context ctx) + { + var referencedMessage = ctx.MatchMessage(false); + if (referencedMessage != null) + { + await using var conn = await _db.Obtain(); + var msg = await _repo.GetMessage(conn, referencedMessage.Value); + if (msg == null) + throw new PKError("This is not a message proxied by PluralKit."); + + return msg.Message; + } + + var recent = await FindRecentMessage(ctx); + if (recent == null) + throw new PKError("Could not find a recent message to edit."); + + return recent; + } + + private async Task FindRecentMessage(Context ctx) + { + await using var conn = await _db.Obtain(); + var lastMessage = await _repo.GetLastMessage(conn, ctx.Guild.Id, ctx.Channel.Id, ctx.Author.Id); + if (lastMessage == null) + return null; + + var timestamp = DiscordUtils.SnowflakeToInstant(lastMessage.Mid); + if (_clock.GetCurrentInstant() - timestamp > EditTimeout) + return null; + + return lastMessage; + } + } +} \ No newline at end of file diff --git a/PluralKit.Bot/Commands/Misc.cs b/PluralKit.Bot/Commands/Misc.cs index fda8e14c..7d208954 100644 --- a/PluralKit.Bot/Commands/Misc.cs +++ b/PluralKit.Bot/Commands/Misc.cs @@ -215,17 +215,16 @@ namespace PluralKit.Bot { public async Task GetMessage(Context ctx) { - var word = ctx.PopArgument() ?? throw new PKSyntaxError("You must pass a message ID or link."); - - ulong messageId; - if (ulong.TryParse(word, out var id)) - messageId = id; - else if (Regex.Match(word, "https://(?:\\w+.)?discord(?:app)?.com/channels/\\d+/\\d+/(\\d+)") is Match match && match.Success) - messageId = ulong.Parse(match.Groups[1].Value); - else throw new PKSyntaxError($"Could not parse {word.AsCode()} as a message ID or link."); - - var message = await _db.Execute(c => _repo.GetMessage(c, messageId)); - if (message == null) throw Errors.MessageNotFound(messageId); + var messageId = ctx.MatchMessage(true); + if (messageId == null) + { + if (!ctx.HasNext()) + throw new PKSyntaxError("You must pass a message ID or link."); + throw new PKSyntaxError($"Could not parse {ctx.PeekArgument().AsCode()} as a message ID or link."); + } + + var message = await _db.Execute(c => _repo.GetMessage(c, messageId.Value)); + if (message == null) throw Errors.MessageNotFound(messageId.Value); if (ctx.Match("delete") || ctx.MatchFlag("delete")) { diff --git a/PluralKit.Bot/Modules.cs b/PluralKit.Bot/Modules.cs index cef29570..f96a64a8 100644 --- a/PluralKit.Bot/Modules.cs +++ b/PluralKit.Bot/Modules.cs @@ -54,6 +54,7 @@ namespace PluralKit.Bot builder.RegisterType().AsSelf(); builder.RegisterType().AsSelf(); builder.RegisterType().AsSelf(); + builder.RegisterType().AsSelf(); builder.RegisterType().AsSelf(); builder.RegisterType().AsSelf(); builder.RegisterType().AsSelf(); diff --git a/PluralKit.Bot/Services/WebhookExecutorService.cs b/PluralKit.Bot/Services/WebhookExecutorService.cs index 39379cdf..bbc4c20a 100644 --- a/PluralKit.Bot/Services/WebhookExecutorService.cs +++ b/PluralKit.Bot/Services/WebhookExecutorService.cs @@ -76,6 +76,18 @@ namespace PluralKit.Bot return webhookMessage; } + public async Task EditWebhookMessage(ulong channelId, ulong messageId, string newContent) + { + var webhook = await _webhookCache.GetWebhook(channelId); + var allowedMentions = newContent.ParseMentions() with { + Roles = Array.Empty(), + Parse = Array.Empty() + }; + + return await _rest.EditWebhookMessage(webhook.Id, webhook.Token, messageId, + new WebhookMessageEditRequest {Content = newContent, AllowedMentions = allowedMentions}); + } + private async Task ExecuteWebhookInner(Webhook webhook, ProxyRequest req, bool hasRetried = false) { var guild = _cache.GetGuild(req.GuildId); diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Message.cs b/PluralKit.Core/Database/Repository/ModelRepository.Message.cs index 9acf8be2..2622ef38 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Message.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Message.cs @@ -42,6 +42,18 @@ namespace PluralKit.Core _logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount, ids); } + + public async Task GetLastMessage(IPKConnection conn, ulong guildId, ulong channelId, ulong accountId) + { + // Want to index scan on the (guild, sender, mid) index so need the additional constraint + return await conn.QuerySingleOrDefaultAsync( + "select * from messages where guild = @Guild and channel = @Channel and sender = @Sender order by mid desc limit 1", new + { + Guild = guildId, + Channel = channelId, + Sender = accountId + }); + } } public class PKMessage