From 1f2b9f998de836e759e54c9ddf17434c6968ccc7 Mon Sep 17 00:00:00 2001 From: Ske Date: Thu, 15 Jul 2021 12:41:19 +0200 Subject: [PATCH] Preliminary support for APIv9 and threads --- Myriad/Cache/DiscordCacheExtensions.cs | 17 +++++++++++++++++ Myriad/Extensions/CacheExtensions.cs | 10 ++++++++++ Myriad/Extensions/ChannelExtensions.cs | 7 +++++++ Myriad/Extensions/PermissionExtensions.cs | 11 ++++++++--- Myriad/Gateway/Events/GuildCreateEvent.cs | 1 + Myriad/Gateway/Events/IGatewayEvent.cs | 4 ++++ Myriad/Gateway/Events/ThreadCreateEvent.cs | 6 ++++++ Myriad/Gateway/Events/ThreadDeleteEvent.cs | 12 ++++++++++++ Myriad/Gateway/Events/ThreadListSyncEvent.cs | 11 +++++++++++ Myriad/Gateway/Events/ThreadUpdateEvent.cs | 6 ++++++ Myriad/Gateway/ShardConnection.cs | 2 +- Myriad/Rest/BaseRestClient.cs | 2 +- Myriad/Rest/DiscordApiClient.cs | 10 ++++++++-- PluralKit.Bot/Bot.cs | 2 +- PluralKit.Bot/Handlers/MessageCreated.cs | 3 ++- PluralKit.Bot/Proxy/ProxyService.cs | 13 ++++++++++--- .../Services/WebhookExecutorService.cs | 16 +++++++++------- PluralKit.Bot/Utils/DiscordUtils.cs | 7 ++++++- .../Utils/SerilogGatewayEnricherFactory.cs | 11 +++++++---- 19 files changed, 127 insertions(+), 24 deletions(-) create mode 100644 Myriad/Gateway/Events/ThreadCreateEvent.cs create mode 100644 Myriad/Gateway/Events/ThreadDeleteEvent.cs create mode 100644 Myriad/Gateway/Events/ThreadListSyncEvent.cs create mode 100644 Myriad/Gateway/Events/ThreadUpdateEvent.cs diff --git a/Myriad/Cache/DiscordCacheExtensions.cs b/Myriad/Cache/DiscordCacheExtensions.cs index 03fbc9b4..91c9f268 100644 --- a/Myriad/Cache/DiscordCacheExtensions.cs +++ b/Myriad/Cache/DiscordCacheExtensions.cs @@ -38,6 +38,14 @@ namespace Myriad.Cache return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId); case MessageDeleteBulkEvent md: return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId); + case ThreadCreateEvent tc: + return cache.SaveChannel(tc); + case ThreadUpdateEvent tu: + return cache.SaveChannel(tu); + case ThreadDeleteEvent td: + return cache.RemoveChannel(td.Id); + case ThreadListSyncEvent tls: + return cache.SaveThreadListSync(tls); } return default; @@ -53,6 +61,9 @@ namespace Myriad.Cache foreach (var member in guildCreate.Members) await cache.SaveUser(member.User); + + foreach (var thread in guildCreate.Threads) + await cache.SaveChannel(thread); } private static async ValueTask SaveMessageCreate(this IDiscordCache cache, MessageCreateEvent evt) @@ -70,5 +81,11 @@ namespace Myriad.Cache // some kind of stub channel object until we get the real one return guildId != null ? default : cache.SaveDmChannelStub(channelId); } + + private static async ValueTask SaveThreadListSync(this IDiscordCache cache, ThreadListSyncEvent evt) + { + foreach (var thread in evt.Threads) + await cache.SaveChannel(thread); + } } } \ No newline at end of file diff --git a/Myriad/Extensions/CacheExtensions.cs b/Myriad/Extensions/CacheExtensions.cs index d331e9e5..5be4df99 100644 --- a/Myriad/Extensions/CacheExtensions.cs +++ b/Myriad/Extensions/CacheExtensions.cs @@ -75,5 +75,15 @@ namespace Myriad.Extensions await cache.SaveChannel(restChannel); return restChannel; } + + public static Channel GetRootChannel(this IDiscordCache cache, ulong channelOrThread) + { + var channel = cache.GetChannel(channelOrThread); + if (!channel.IsThread()) + return channel; + + var parent = cache.GetChannel(channel.ParentId!.Value); + return parent; + } } } \ No newline at end of file diff --git a/Myriad/Extensions/ChannelExtensions.cs b/Myriad/Extensions/ChannelExtensions.cs index 0f04cb03..b511b970 100644 --- a/Myriad/Extensions/ChannelExtensions.cs +++ b/Myriad/Extensions/ChannelExtensions.cs @@ -5,5 +5,12 @@ namespace Myriad.Extensions public static class ChannelExtensions { public static string Mention(this Channel channel) => $"<#{channel.Id}>"; + + public static bool IsThread(this Channel channel) => channel.Type.IsThread(); + + public static bool IsThread(this Channel.ChannelType type) => + type is Channel.ChannelType.GuildPublicThread + or Channel.ChannelType.GuildPrivateThread + or Channel.ChannelType.GuildNewsThread; } } \ No newline at end of file diff --git a/Myriad/Extensions/PermissionExtensions.cs b/Myriad/Extensions/PermissionExtensions.cs index 9c403a3f..099c2d3d 100644 --- a/Myriad/Extensions/PermissionExtensions.cs +++ b/Myriad/Extensions/PermissionExtensions.cs @@ -11,7 +11,7 @@ namespace Myriad.Extensions public static class PermissionExtensions { public static PermissionSet PermissionsFor(this IDiscordCache cache, MessageCreateEvent message) => - PermissionsFor(cache, message.ChannelId, message.Author.Id, message.Member?.Roles, isWebhook: message.Author.Discriminator == "0000"); + PermissionsFor(cache, message.ChannelId, message.Author.Id, message.Member?.Roles, isWebhook: message.WebhookId != null); public static PermissionSet PermissionsFor(this IDiscordCache cache, ulong channelId, GuildMember member) => PermissionsFor(cache, channelId, member.User.Id, member.Roles); @@ -21,16 +21,21 @@ namespace Myriad.Extensions public static PermissionSet PermissionsFor(this IDiscordCache cache, ulong channelId, ulong userId, ICollection? userRoles, bool isWebhook = false) { - var channel = cache.GetChannel(channelId); + if (!cache.TryGetChannel(channelId, out var channel)) + // todo: handle channel not found better + return PermissionSet.Dm; + if (channel.GuildId == null) return PermissionSet.Dm; + + var rootChannel = cache.GetRootChannel(channelId); var guild = cache.GetGuild(channel.GuildId.Value); if (isWebhook) return EveryonePermissions(guild); - return PermissionsFor(guild, channel, userId, userRoles); + return PermissionsFor(guild, rootChannel, userId, userRoles); } public static PermissionSet EveryonePermissions(this Guild guild) => diff --git a/Myriad/Gateway/Events/GuildCreateEvent.cs b/Myriad/Gateway/Events/GuildCreateEvent.cs index 41f6220c..6b3cfb5e 100644 --- a/Myriad/Gateway/Events/GuildCreateEvent.cs +++ b/Myriad/Gateway/Events/GuildCreateEvent.cs @@ -6,5 +6,6 @@ namespace Myriad.Gateway { public Channel[] Channels { get; init; } public GuildMember[] Members { get; init; } + public Channel[] Threads { get; init; } } } \ No newline at end of file diff --git a/Myriad/Gateway/Events/IGatewayEvent.cs b/Myriad/Gateway/Events/IGatewayEvent.cs index 17c5068b..679ff2e4 100644 --- a/Myriad/Gateway/Events/IGatewayEvent.cs +++ b/Myriad/Gateway/Events/IGatewayEvent.cs @@ -21,6 +21,10 @@ namespace Myriad.Gateway {"CHANNEL_CREATE", typeof(ChannelCreateEvent)}, {"CHANNEL_UPDATE", typeof(ChannelUpdateEvent)}, {"CHANNEL_DELETE", typeof(ChannelDeleteEvent)}, + {"THREAD_CREATE", typeof(ThreadCreateEvent)}, + {"THREAD_UPDATE", typeof(ThreadUpdateEvent)}, + {"THREAD_DELETE", typeof(ThreadDeleteEvent)}, + {"THREAD_LIST_SYNC", typeof(ThreadListSyncEvent)}, {"MESSAGE_CREATE", typeof(MessageCreateEvent)}, {"MESSAGE_UPDATE", typeof(MessageUpdateEvent)}, {"MESSAGE_DELETE", typeof(MessageDeleteEvent)}, diff --git a/Myriad/Gateway/Events/ThreadCreateEvent.cs b/Myriad/Gateway/Events/ThreadCreateEvent.cs new file mode 100644 index 00000000..ee204188 --- /dev/null +++ b/Myriad/Gateway/Events/ThreadCreateEvent.cs @@ -0,0 +1,6 @@ +using Myriad.Types; + +namespace Myriad.Gateway +{ + public record ThreadCreateEvent: Channel, IGatewayEvent; +} \ No newline at end of file diff --git a/Myriad/Gateway/Events/ThreadDeleteEvent.cs b/Myriad/Gateway/Events/ThreadDeleteEvent.cs new file mode 100644 index 00000000..694ae56b --- /dev/null +++ b/Myriad/Gateway/Events/ThreadDeleteEvent.cs @@ -0,0 +1,12 @@ +using Myriad.Types; + +namespace Myriad.Gateway +{ + public record ThreadDeleteEvent: IGatewayEvent + { + public ulong Id { get; init; } + public ulong? GuildId { get; init; } + public ulong? ParentId { get; init; } + public Channel.ChannelType Type { get; init; } + } +} \ No newline at end of file diff --git a/Myriad/Gateway/Events/ThreadListSyncEvent.cs b/Myriad/Gateway/Events/ThreadListSyncEvent.cs new file mode 100644 index 00000000..56ff7426 --- /dev/null +++ b/Myriad/Gateway/Events/ThreadListSyncEvent.cs @@ -0,0 +1,11 @@ +using Myriad.Types; + +namespace Myriad.Gateway +{ + public record ThreadListSyncEvent: IGatewayEvent + { + public ulong GuildId { get; init; } + public ulong[]? ChannelIds { get; init; } + public Channel[] Threads { get; init; } + } +} \ No newline at end of file diff --git a/Myriad/Gateway/Events/ThreadUpdateEvent.cs b/Myriad/Gateway/Events/ThreadUpdateEvent.cs new file mode 100644 index 00000000..68cc3afb --- /dev/null +++ b/Myriad/Gateway/Events/ThreadUpdateEvent.cs @@ -0,0 +1,6 @@ +using Myriad.Types; + +namespace Myriad.Gateway +{ + public record ThreadUpdateEvent: Channel, IGatewayEvent; +} \ No newline at end of file diff --git a/Myriad/Gateway/ShardConnection.cs b/Myriad/Gateway/ShardConnection.cs index 3e20615d..03e84792 100644 --- a/Myriad/Gateway/ShardConnection.cs +++ b/Myriad/Gateway/ShardConnection.cs @@ -82,7 +82,7 @@ namespace Myriad.Gateway private Uri GetConnectionUri(string baseUri) => new UriBuilder(baseUri) { - Query = "v=8&encoding=json" + Query = "v=9&encoding=json" }.Uri; private async Task CloseInner(WebSocketCloseStatus closeStatus, string? description) diff --git a/Myriad/Rest/BaseRestClient.cs b/Myriad/Rest/BaseRestClient.cs index 8f327508..923a6f5a 100644 --- a/Myriad/Rest/BaseRestClient.cs +++ b/Myriad/Rest/BaseRestClient.cs @@ -22,7 +22,7 @@ namespace Myriad.Rest { public class BaseRestClient: IAsyncDisposable { - private const string ApiBaseUrl = "https://discord.com/api/v8"; + private const string ApiBaseUrl = "https://discord.com/api/v9"; private readonly Version _httpVersion = new(2, 0); private readonly JsonSerializerOptions _jsonSerializerOptions; diff --git a/Myriad/Rest/DiscordApiClient.cs b/Myriad/Rest/DiscordApiClient.cs index 0e68e931..77b73a81 100644 --- a/Myriad/Rest/DiscordApiClient.cs +++ b/Myriad/Rest/DiscordApiClient.cs @@ -116,9 +116,15 @@ namespace Myriad.Rest _client.Get($"/channels/{channelId}/webhooks", ("GetChannelWebhooks", channelId))!; public Task ExecuteWebhook(ulong webhookId, string webhookToken, ExecuteWebhookRequest request, - MultipartFile[]? files = null) => - _client.PostMultipart($"/webhooks/{webhookId}/{webhookToken}?wait=true", + MultipartFile[]? files = null, ulong? threadId = null) + { + var url = $"/webhooks/{webhookId}/{webhookToken}?wait=true"; + if (threadId != null) + url += $"&thread_id={threadId}"; + + return _client.PostMultipart(url, ("ExecuteWebhook", webhookId), request, files)!; + } public Task EditWebhookMessage(ulong webhookId, string webhookToken, ulong messageId, WebhookMessageEditRequest request) => diff --git a/PluralKit.Bot/Bot.cs b/PluralKit.Bot/Bot.cs index 087e5d38..37d73600 100644 --- a/PluralKit.Bot/Bot.cs +++ b/PluralKit.Bot/Bot.cs @@ -78,7 +78,7 @@ namespace PluralKit.Bot public PermissionSet PermissionsIn(ulong channelId) { - var channel = _cache.GetChannel(channelId); + var channel = _cache.GetRootChannel(channelId); if (channel.GuildId != null) { diff --git a/PluralKit.Bot/Handlers/MessageCreated.cs b/PluralKit.Bot/Handlers/MessageCreated.cs index 9206835a..43a93d63 100644 --- a/PluralKit.Bot/Handlers/MessageCreated.cs +++ b/PluralKit.Bot/Handlers/MessageCreated.cs @@ -63,6 +63,7 @@ namespace PluralKit.Bot var guild = evt.GuildId != null ? _cache.GetGuild(evt.GuildId.Value) : null; var channel = _cache.GetChannel(evt.ChannelId); + var rootChannel = _cache.GetRootChannel(evt.ChannelId); // Log metrics and message info _metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived); @@ -72,7 +73,7 @@ namespace PluralKit.Bot MessageContext ctx; await using (var conn = await _db.Obtain()) using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) - ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.GuildId ?? default, evt.ChannelId); + ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.GuildId ?? default, rootChannel.Id); // Try each handler until we find one that succeeds if (await TryHandleLogClean(evt, ctx)) diff --git a/PluralKit.Bot/Proxy/ProxyService.cs b/PluralKit.Bot/Proxy/ProxyService.cs index e1267f92..1439cddb 100644 --- a/PluralKit.Bot/Proxy/ProxyService.cs +++ b/PluralKit.Bot/Proxy/ProxyService.cs @@ -57,6 +57,8 @@ namespace PluralKit.Bot // Fetch members and try to match to a specific member await using var conn = await _db.Obtain(); + var rootChannel = _cache.GetRootChannel(message.ChannelId); + List members; using (_metrics.Measure.Timer.Time(BotMetrics.ProxyMembersQueryTime)) members = (await _repo.GetProxyMembers(conn, message.Author.Id, message.GuildId!.Value)).ToList(); @@ -68,7 +70,7 @@ namespace PluralKit.Bot if (message.Content != null && message.Content.Length > 2000) throw new PKError("PluralKit cannot proxy messages over 2000 characters in length."); // Permission check after proxy match so we don't get spammed when not actually proxying - if (!await CheckBotPermissionsOrError(botPermissions, message.ChannelId)) + if (!await CheckBotPermissionsOrError(botPermissions, rootChannel.Id)) return false; // this method throws, so no need to wrap it in an if statement @@ -76,7 +78,7 @@ namespace PluralKit.Bot // Check if the sender account can mention everyone/here + embed links // we need to "mirror" these permissions when proxying to prevent exploits - var senderPermissions = PermissionExtensions.PermissionsFor(guild, channel, message); + var senderPermissions = PermissionExtensions.PermissionsFor(guild, rootChannel, message); var allowEveryone = senderPermissions.HasFlag(PermissionSet.MentionEveryone); var allowEmbeds = senderPermissions.HasFlag(PermissionSet.EmbedLinks); @@ -131,10 +133,15 @@ namespace PluralKit.Bot var content = match.ProxyContent; if (!allowEmbeds) content = content.BreakLinkEmbeds(); + var messageChannel = _cache.GetChannel(trigger.ChannelId); + var rootChannel = _cache.GetRootChannel(trigger.ChannelId); + var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null; + var proxyMessage = await _webhookExecutor.ExecuteWebhook(new ProxyRequest { GuildId = trigger.GuildId!.Value, - ChannelId = trigger.ChannelId, + ChannelId = rootChannel.Id, + ThreadId = threadId, Name = match.Member.ProxyName(ctx), AvatarUrl = match.Member.ProxyAvatar(ctx), Content = content, diff --git a/PluralKit.Bot/Services/WebhookExecutorService.cs b/PluralKit.Bot/Services/WebhookExecutorService.cs index bbc4c20a..398318ec 100644 --- a/PluralKit.Bot/Services/WebhookExecutorService.cs +++ b/PluralKit.Bot/Services/WebhookExecutorService.cs @@ -33,6 +33,7 @@ namespace PluralKit.Bot { public ulong GuildId { get; init; } public ulong ChannelId { get; init; } + public ulong? ThreadId { get; init; } public string Name { get; init; } public string? AvatarUrl { get; init; } public string? Content { get; init; } @@ -70,8 +71,8 @@ namespace PluralKit.Bot // Log the relevant metrics _metrics.Measure.Meter.Mark(BotMetrics.MessagesProxied); - _logger.Information("Invoked webhook {Webhook} in channel {Channel}", webhook.Id, - req.ChannelId); + _logger.Information("Invoked webhook {Webhook} in channel {Channel} (thread {ThreadId}", webhook.Id, + req.ChannelId, req.ThreadId); return webhookMessage; } @@ -122,7 +123,7 @@ namespace PluralKit.Bot using (_metrics.Measure.Timer.Time(BotMetrics.WebhookResponseTime)) { try { - webhookMessage = await _rest.ExecuteWebhook(webhook.Id, webhook.Token, webhookReq, files); + webhookMessage = await _rest.ExecuteWebhook(webhook.Id, webhook.Token, webhookReq, files, req.ThreadId); } catch (JsonReaderException) { @@ -136,7 +137,8 @@ namespace PluralKit.Bot { // Error 10015 = "Unknown Webhook" - this likely means the webhook was deleted // but is still in our cache. Invalidate, refresh, try again - _logger.Warning("Error invoking webhook {Webhook} in channel {Channel}", webhook.Id, webhook.ChannelId); + _logger.Warning("Error invoking webhook {Webhook} in channel {Channel} (thread {ThreadId}", + webhook.Id, webhook.ChannelId, req.ThreadId); var newWebhook = await _webhookCache.InvalidateAndRefreshWebhook(req.ChannelId, webhook); return await ExecuteWebhookInner(newWebhook, req, hasRetried: true); @@ -147,12 +149,12 @@ namespace PluralKit.Bot } // We don't care about whether the sending succeeds, and we don't want to *wait* for it, so we just fork it off - var _ = TrySendRemainingAttachments(webhook, req.Name, req.AvatarUrl, attachmentChunks); + var _ = TrySendRemainingAttachments(webhook, req.Name, req.AvatarUrl, attachmentChunks, req.ThreadId); return webhookMessage; } - private async Task TrySendRemainingAttachments(Webhook webhook, string name, string avatarUrl, IReadOnlyList> attachmentChunks) + private async Task TrySendRemainingAttachments(Webhook webhook, string name, string avatarUrl, IReadOnlyList> attachmentChunks, ulong? threadId) { if (attachmentChunks.Count <= 1) return; @@ -160,7 +162,7 @@ namespace PluralKit.Bot { var files = await GetAttachmentFiles(attachmentChunks[i]); var req = new ExecuteWebhookRequest {Username = name, AvatarUrl = avatarUrl}; - await _rest.ExecuteWebhook(webhook.Id, webhook.Token!, req, files); + await _rest.ExecuteWebhook(webhook.Id, webhook.Token!, req, files, threadId); } } diff --git a/PluralKit.Bot/Utils/DiscordUtils.cs b/PluralKit.Bot/Utils/DiscordUtils.cs index 59f3711d..62ce0169 100644 --- a/PluralKit.Bot/Utils/DiscordUtils.cs +++ b/PluralKit.Bot/Utils/DiscordUtils.cs @@ -194,6 +194,11 @@ namespace PluralKit.Bot } public static bool IsValidGuildChannel(Channel channel) => - channel.Type == Channel.ChannelType.GuildText || channel.Type == Channel.ChannelType.GuildNews; + channel.Type is + Channel.ChannelType.GuildText or + Channel.ChannelType.GuildNews or + Channel.ChannelType.GuildPublicThread or + Channel.ChannelType.GuildPrivateThread or + Channel.ChannelType.GuildNewsThread; } } diff --git a/PluralKit.Bot/Utils/SerilogGatewayEnricherFactory.cs b/PluralKit.Bot/Utils/SerilogGatewayEnricherFactory.cs index 79577d3a..3ff85a12 100644 --- a/PluralKit.Bot/Utils/SerilogGatewayEnricherFactory.cs +++ b/PluralKit.Bot/Utils/SerilogGatewayEnricherFactory.cs @@ -36,10 +36,13 @@ namespace PluralKit.Bot if (channel != null) { - props.Add(new("GuildId", new ScalarValue(channel.Value))); - - var botPermissions = _bot.PermissionsIn(channel.Value); - props.Add(new("BotPermissions", new ScalarValue(botPermissions))); + props.Add(new("ChannelId", new ScalarValue(channel.Value))); + + if (_cache.TryGetChannel(channel.Value, out _)) + { + var botPermissions = _bot.PermissionsIn(channel.Value); + props.Add(new("BotPermissions", new ScalarValue(botPermissions))); + } } if (message != null)