Add message editing command

Signed-off-by: Ske <voltasalt@gmail.com>
This commit is contained in:
Ske 2021-05-03 12:33:30 +02:00
parent 33cabff359
commit 3d624b39e4
9 changed files with 173 additions and 11 deletions

View File

@ -121,6 +121,11 @@ namespace Myriad.Rest
_client.PostMultipart<Message>($"/webhooks/{webhookId}/{webhookToken}?wait=true", _client.PostMultipart<Message>($"/webhooks/{webhookId}/{webhookToken}?wait=true",
("ExecuteWebhook", webhookId), request, files)!; ("ExecuteWebhook", webhookId), request, files)!;
public Task<Message> EditWebhookMessage(ulong webhookId, string webhookToken, ulong messageId,
WebhookMessageEditRequest request) =>
_client.Patch<Message>($"/webhooks/{webhookId}/{webhookToken}/messages/{messageId}",
("EditWebhookMessage", webhookId), request)!;
public Task<Channel> CreateDm(ulong recipientId) => public Task<Channel> CreateDm(ulong recipientId) =>
_client.Post<Channel>($"/users/@me/channels", ("CreateDM", default), new CreateDmRequest(recipientId))!; _client.Post<Channel>($"/users/@me/channels", ("CreateDM", default), new CreateDmRequest(recipientId))!;

View File

@ -0,0 +1,15 @@
using System.Text.Json.Serialization;
using Myriad.Utils;
namespace Myriad.Rest.Types.Requests
{
public record WebhookMessageEditRequest
{
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<string?> Content { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<AllowedMentions> AllowedMentions { get; init; }
}
}

View File

@ -1,8 +1,11 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text.RegularExpressions;
using System.Threading.Tasks; using System.Threading.Tasks;
using Myriad.Types;
using PluralKit.Core; using PluralKit.Core;
namespace PluralKit.Bot namespace PluralKit.Bot
@ -68,6 +71,27 @@ namespace PluralKit.Bot
return matched; return matched;
} }
public static ulong? MatchMessage(this Context ctx, bool parseRawMessageId)
{
if (ctx.Message.Type == Message.MessageType.Reply && ctx.Message.MessageReference != null)
return ctx.Message.MessageReference.MessageId;
var word = ctx.PeekArgument();
if (word == null)
return null;
if (parseRawMessageId && ulong.TryParse(word, out var mid))
return mid;
var match = Regex.Match(word, "https://(?:\\w+.)?discord(?:app)?.com/channels/\\d+/\\d+/(\\d+)");
if (!match.Success)
return null;
var messageId = ulong.Parse(match.Groups[1].Value);
ctx.PopArgument();
return messageId;
}
public static async Task<List<PKMember>> ParseMemberList(this Context ctx, SystemId? restrictToSystem) public static async Task<List<PKMember>> ParseMemberList(this Context ctx, SystemId? restrictToSystem)
{ {
var members = new List<PKMember>(); var members = new List<PKMember>();

View File

@ -79,6 +79,7 @@ namespace PluralKit.Bot
public static Command Help = new Command("help", "help", "Shows help information about PluralKit"); public static Command Help = new Command("help", "help", "Shows help information about PluralKit");
public static Command Explain = new Command("explain", "explain", "Explains the basics of systems and proxying"); public static Command Explain = new Command("explain", "explain", "Explains the basics of systems and proxying");
public static Command Message = new Command("message", "message <id|link> [delete|author]", "Looks up a proxied message"); public static Command Message = new Command("message", "message <id|link> [delete|author]", "Looks up a proxied message");
public static Command MessageEdit = new Command("edit", "edit [link] <text>", "Edit a previously proxied message");
public static Command LogChannel = new Command("log channel", "log channel <channel>", "Designates a channel to post proxied messages to"); public static Command LogChannel = new Command("log channel", "log channel <channel>", "Designates a channel to post proxied messages to");
public static Command LogChannelClear = new Command("log channel", "log channel -clear", "Clears the currently set log channel"); public static Command LogChannelClear = new Command("log channel", "log channel -clear", "Clears the currently set log channel");
public static Command LogEnable = new Command("log enable", "log enable all|<channel> [channel 2] [channel 3...]", "Enables message logging in certain channels"); public static Command LogEnable = new Command("log enable", "log enable all|<channel> [channel 2] [channel 3...]", "Enables message logging in certain channels");
@ -160,6 +161,8 @@ namespace PluralKit.Bot
return ctx.Execute<Help>(Explain, m => m.Explain(ctx)); return ctx.Execute<Help>(Explain, m => m.Explain(ctx));
if (ctx.Match("message", "msg")) if (ctx.Match("message", "msg"))
return ctx.Execute<Misc>(Message, m => m.GetMessage(ctx)); return ctx.Execute<Misc>(Message, m => m.GetMessage(ctx));
if (ctx.Match("edit", "e"))
return ctx.Execute<MessageEdit>(MessageEdit, m => m.EditMessage(ctx));
if (ctx.Match("log")) if (ctx.Match("log"))
if (ctx.Match("channel")) if (ctx.Match("channel"))
return ctx.Execute<ServerConfig>(LogChannel, m => m.SetLogChannel(ctx)); return ctx.Execute<ServerConfig>(LogChannel, m => m.SetLogChannel(ctx));

View File

@ -0,0 +1,91 @@
#nullable enable
using System.Threading.Tasks;
using Myriad.Rest;
using Myriad.Rest.Exceptions;
using Myriad.Types;
using NodaTime;
using PluralKit.Core;
namespace PluralKit.Bot
{
public class MessageEdit
{
private static readonly Duration EditTimeout = Duration.FromMinutes(10);
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly IClock _clock;
private readonly DiscordApiClient _rest;
private readonly WebhookExecutorService _webhookExecutor;
public MessageEdit(IDatabase db, ModelRepository repo, IClock clock, DiscordApiClient rest, WebhookExecutorService webhookExecutor)
{
_db = db;
_repo = repo;
_clock = clock;
_rest = rest;
_webhookExecutor = webhookExecutor;
}
public async Task EditMessage(Context ctx)
{
var msg = await GetMessageToEdit(ctx);
if (!ctx.HasNext())
throw new PKSyntaxError("You need to include the message to edit in.");
if (ctx.Author.Id != msg.Sender)
throw new PKError("Can't edit a message sent from a different account.");
var newContent = ctx.RemainderOrNull();
try
{
await _webhookExecutor.EditWebhookMessage(msg.Channel, msg.Mid, newContent);
if (ctx.BotPermissions.HasFlag(PermissionSet.ManageMessages))
await _rest.DeleteMessage(ctx.Channel.Id, ctx.Message.Id);
}
catch (NotFoundException)
{
throw new PKError("Could not edit message.");
}
}
private async Task<PKMessage> GetMessageToEdit(Context ctx)
{
var referencedMessage = ctx.MatchMessage(false);
if (referencedMessage != null)
{
await using var conn = await _db.Obtain();
var msg = await _repo.GetMessage(conn, referencedMessage.Value);
if (msg == null)
throw new PKError("This is not a message proxied by PluralKit.");
return msg.Message;
}
var recent = await FindRecentMessage(ctx);
if (recent == null)
throw new PKError("Could not find a recent message to edit.");
return recent;
}
private async Task<PKMessage?> FindRecentMessage(Context ctx)
{
await using var conn = await _db.Obtain();
var lastMessage = await _repo.GetLastMessage(conn, ctx.Guild.Id, ctx.Channel.Id, ctx.Author.Id);
if (lastMessage == null)
return null;
var timestamp = DiscordUtils.SnowflakeToInstant(lastMessage.Mid);
if (_clock.GetCurrentInstant() - timestamp > EditTimeout)
return null;
return lastMessage;
}
}
}

View File

@ -215,17 +215,16 @@ namespace PluralKit.Bot {
public async Task GetMessage(Context ctx) public async Task GetMessage(Context ctx)
{ {
var word = ctx.PopArgument() ?? throw new PKSyntaxError("You must pass a message ID or link."); var messageId = ctx.MatchMessage(true);
if (messageId == null)
{
if (!ctx.HasNext())
throw new PKSyntaxError("You must pass a message ID or link.");
throw new PKSyntaxError($"Could not parse {ctx.PeekArgument().AsCode()} as a message ID or link.");
}
ulong messageId; var message = await _db.Execute(c => _repo.GetMessage(c, messageId.Value));
if (ulong.TryParse(word, out var id)) if (message == null) throw Errors.MessageNotFound(messageId.Value);
messageId = id;
else if (Regex.Match(word, "https://(?:\\w+.)?discord(?:app)?.com/channels/\\d+/\\d+/(\\d+)") is Match match && match.Success)
messageId = ulong.Parse(match.Groups[1].Value);
else throw new PKSyntaxError($"Could not parse {word.AsCode()} as a message ID or link.");
var message = await _db.Execute(c => _repo.GetMessage(c, messageId));
if (message == null) throw Errors.MessageNotFound(messageId);
if (ctx.Match("delete") || ctx.MatchFlag("delete")) if (ctx.Match("delete") || ctx.MatchFlag("delete"))
{ {

View File

@ -54,6 +54,7 @@ namespace PluralKit.Bot
builder.RegisterType<MemberEdit>().AsSelf(); builder.RegisterType<MemberEdit>().AsSelf();
builder.RegisterType<MemberGroup>().AsSelf(); builder.RegisterType<MemberGroup>().AsSelf();
builder.RegisterType<MemberProxy>().AsSelf(); builder.RegisterType<MemberProxy>().AsSelf();
builder.RegisterType<MessageEdit>().AsSelf();
builder.RegisterType<Misc>().AsSelf(); builder.RegisterType<Misc>().AsSelf();
builder.RegisterType<Random>().AsSelf(); builder.RegisterType<Random>().AsSelf();
builder.RegisterType<ServerConfig>().AsSelf(); builder.RegisterType<ServerConfig>().AsSelf();

View File

@ -76,6 +76,18 @@ namespace PluralKit.Bot
return webhookMessage; return webhookMessage;
} }
public async Task<Message> EditWebhookMessage(ulong channelId, ulong messageId, string newContent)
{
var webhook = await _webhookCache.GetWebhook(channelId);
var allowedMentions = newContent.ParseMentions() with {
Roles = Array.Empty<ulong>(),
Parse = Array.Empty<AllowedMentions.ParseType>()
};
return await _rest.EditWebhookMessage(webhook.Id, webhook.Token, messageId,
new WebhookMessageEditRequest {Content = newContent, AllowedMentions = allowedMentions});
}
private async Task<Message> ExecuteWebhookInner(Webhook webhook, ProxyRequest req, bool hasRetried = false) private async Task<Message> ExecuteWebhookInner(Webhook webhook, ProxyRequest req, bool hasRetried = false)
{ {
var guild = _cache.GetGuild(req.GuildId); var guild = _cache.GetGuild(req.GuildId);

View File

@ -42,6 +42,18 @@ namespace PluralKit.Core
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount, _logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount,
ids); ids);
} }
public async Task<PKMessage?> GetLastMessage(IPKConnection conn, ulong guildId, ulong channelId, ulong accountId)
{
// Want to index scan on the (guild, sender, mid) index so need the additional constraint
return await conn.QuerySingleOrDefaultAsync<PKMessage>(
"select * from messages where guild = @Guild and channel = @Channel and sender = @Sender order by mid desc limit 1", new
{
Guild = guildId,
Channel = channelId,
Sender = accountId
});
}
} }
public class PKMessage public class PKMessage