Preliminary support for APIv9 and threads

This commit is contained in:
Ske 2021-07-15 12:41:19 +02:00
parent 0e7bcb993e
commit 1f2b9f998d
19 changed files with 127 additions and 24 deletions

@ -38,6 +38,14 @@ namespace Myriad.Cache
return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId);
case MessageDeleteBulkEvent md:
return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId);
case ThreadCreateEvent tc:
return cache.SaveChannel(tc);
case ThreadUpdateEvent tu:
return cache.SaveChannel(tu);
case ThreadDeleteEvent td:
return cache.RemoveChannel(td.Id);
case ThreadListSyncEvent tls:
return cache.SaveThreadListSync(tls);
}
return default;
@ -53,6 +61,9 @@ namespace Myriad.Cache
foreach (var member in guildCreate.Members)
await cache.SaveUser(member.User);
foreach (var thread in guildCreate.Threads)
await cache.SaveChannel(thread);
}
private static async ValueTask SaveMessageCreate(this IDiscordCache cache, MessageCreateEvent evt)
@ -70,5 +81,11 @@ namespace Myriad.Cache
// some kind of stub channel object until we get the real one
return guildId != null ? default : cache.SaveDmChannelStub(channelId);
}
private static async ValueTask SaveThreadListSync(this IDiscordCache cache, ThreadListSyncEvent evt)
{
foreach (var thread in evt.Threads)
await cache.SaveChannel(thread);
}
}
}

@ -75,5 +75,15 @@ namespace Myriad.Extensions
await cache.SaveChannel(restChannel);
return restChannel;
}
public static Channel GetRootChannel(this IDiscordCache cache, ulong channelOrThread)
{
var channel = cache.GetChannel(channelOrThread);
if (!channel.IsThread())
return channel;
var parent = cache.GetChannel(channel.ParentId!.Value);
return parent;
}
}
}

@ -5,5 +5,12 @@ namespace Myriad.Extensions
public static class ChannelExtensions
{
public static string Mention(this Channel channel) => $"<#{channel.Id}>";
public static bool IsThread(this Channel channel) => channel.Type.IsThread();
public static bool IsThread(this Channel.ChannelType type) =>
type is Channel.ChannelType.GuildPublicThread
or Channel.ChannelType.GuildPrivateThread
or Channel.ChannelType.GuildNewsThread;
}
}

@ -11,7 +11,7 @@ namespace Myriad.Extensions
public static class PermissionExtensions
{
public static PermissionSet PermissionsFor(this IDiscordCache cache, MessageCreateEvent message) =>
PermissionsFor(cache, message.ChannelId, message.Author.Id, message.Member?.Roles, isWebhook: message.Author.Discriminator == "0000");
PermissionsFor(cache, message.ChannelId, message.Author.Id, message.Member?.Roles, isWebhook: message.WebhookId != null);
public static PermissionSet PermissionsFor(this IDiscordCache cache, ulong channelId, GuildMember member) =>
PermissionsFor(cache, channelId, member.User.Id, member.Roles);
@ -21,16 +21,21 @@ namespace Myriad.Extensions
public static PermissionSet PermissionsFor(this IDiscordCache cache, ulong channelId, ulong userId, ICollection<ulong>? userRoles, bool isWebhook = false)
{
var channel = cache.GetChannel(channelId);
if (!cache.TryGetChannel(channelId, out var channel))
// todo: handle channel not found better
return PermissionSet.Dm;
if (channel.GuildId == null)
return PermissionSet.Dm;
var rootChannel = cache.GetRootChannel(channelId);
var guild = cache.GetGuild(channel.GuildId.Value);
if (isWebhook)
return EveryonePermissions(guild);
return PermissionsFor(guild, channel, userId, userRoles);
return PermissionsFor(guild, rootChannel, userId, userRoles);
}
public static PermissionSet EveryonePermissions(this Guild guild) =>

@ -6,5 +6,6 @@ namespace Myriad.Gateway
{
public Channel[] Channels { get; init; }
public GuildMember[] Members { get; init; }
public Channel[] Threads { get; init; }
}
}

@ -21,6 +21,10 @@ namespace Myriad.Gateway
{"CHANNEL_CREATE", typeof(ChannelCreateEvent)},
{"CHANNEL_UPDATE", typeof(ChannelUpdateEvent)},
{"CHANNEL_DELETE", typeof(ChannelDeleteEvent)},
{"THREAD_CREATE", typeof(ThreadCreateEvent)},
{"THREAD_UPDATE", typeof(ThreadUpdateEvent)},
{"THREAD_DELETE", typeof(ThreadDeleteEvent)},
{"THREAD_LIST_SYNC", typeof(ThreadListSyncEvent)},
{"MESSAGE_CREATE", typeof(MessageCreateEvent)},
{"MESSAGE_UPDATE", typeof(MessageUpdateEvent)},
{"MESSAGE_DELETE", typeof(MessageDeleteEvent)},

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ThreadCreateEvent: Channel, IGatewayEvent;
}

@ -0,0 +1,12 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ThreadDeleteEvent: IGatewayEvent
{
public ulong Id { get; init; }
public ulong? GuildId { get; init; }
public ulong? ParentId { get; init; }
public Channel.ChannelType Type { get; init; }
}
}

@ -0,0 +1,11 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ThreadListSyncEvent: IGatewayEvent
{
public ulong GuildId { get; init; }
public ulong[]? ChannelIds { get; init; }
public Channel[] Threads { get; init; }
}
}

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ThreadUpdateEvent: Channel, IGatewayEvent;
}

@ -82,7 +82,7 @@ namespace Myriad.Gateway
private Uri GetConnectionUri(string baseUri) => new UriBuilder(baseUri)
{
Query = "v=8&encoding=json"
Query = "v=9&encoding=json"
}.Uri;
private async Task CloseInner(WebSocketCloseStatus closeStatus, string? description)

@ -22,7 +22,7 @@ namespace Myriad.Rest
{
public class BaseRestClient: IAsyncDisposable
{
private const string ApiBaseUrl = "https://discord.com/api/v8";
private const string ApiBaseUrl = "https://discord.com/api/v9";
private readonly Version _httpVersion = new(2, 0);
private readonly JsonSerializerOptions _jsonSerializerOptions;

@ -116,9 +116,15 @@ namespace Myriad.Rest
_client.Get<Webhook[]>($"/channels/{channelId}/webhooks", ("GetChannelWebhooks", channelId))!;
public Task<Message> ExecuteWebhook(ulong webhookId, string webhookToken, ExecuteWebhookRequest request,
MultipartFile[]? files = null) =>
_client.PostMultipart<Message>($"/webhooks/{webhookId}/{webhookToken}?wait=true",
MultipartFile[]? files = null, ulong? threadId = null)
{
var url = $"/webhooks/{webhookId}/{webhookToken}?wait=true";
if (threadId != null)
url += $"&thread_id={threadId}";
return _client.PostMultipart<Message>(url,
("ExecuteWebhook", webhookId), request, files)!;
}
public Task<Message> EditWebhookMessage(ulong webhookId, string webhookToken, ulong messageId,
WebhookMessageEditRequest request) =>

@ -78,7 +78,7 @@ namespace PluralKit.Bot
public PermissionSet PermissionsIn(ulong channelId)
{
var channel = _cache.GetChannel(channelId);
var channel = _cache.GetRootChannel(channelId);
if (channel.GuildId != null)
{

@ -63,6 +63,7 @@ namespace PluralKit.Bot
var guild = evt.GuildId != null ? _cache.GetGuild(evt.GuildId.Value) : null;
var channel = _cache.GetChannel(evt.ChannelId);
var rootChannel = _cache.GetRootChannel(evt.ChannelId);
// Log metrics and message info
_metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived);
@ -72,7 +73,7 @@ namespace PluralKit.Bot
MessageContext ctx;
await using (var conn = await _db.Obtain())
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.GuildId ?? default, evt.ChannelId);
ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.GuildId ?? default, rootChannel.Id);
// Try each handler until we find one that succeeds
if (await TryHandleLogClean(evt, ctx))

@ -57,6 +57,8 @@ namespace PluralKit.Bot
// Fetch members and try to match to a specific member
await using var conn = await _db.Obtain();
var rootChannel = _cache.GetRootChannel(message.ChannelId);
List<ProxyMember> members;
using (_metrics.Measure.Timer.Time(BotMetrics.ProxyMembersQueryTime))
members = (await _repo.GetProxyMembers(conn, message.Author.Id, message.GuildId!.Value)).ToList();
@ -68,7 +70,7 @@ namespace PluralKit.Bot
if (message.Content != null && message.Content.Length > 2000) throw new PKError("PluralKit cannot proxy messages over 2000 characters in length.");
// Permission check after proxy match so we don't get spammed when not actually proxying
if (!await CheckBotPermissionsOrError(botPermissions, message.ChannelId))
if (!await CheckBotPermissionsOrError(botPermissions, rootChannel.Id))
return false;
// this method throws, so no need to wrap it in an if statement
@ -76,7 +78,7 @@ namespace PluralKit.Bot
// Check if the sender account can mention everyone/here + embed links
// we need to "mirror" these permissions when proxying to prevent exploits
var senderPermissions = PermissionExtensions.PermissionsFor(guild, channel, message);
var senderPermissions = PermissionExtensions.PermissionsFor(guild, rootChannel, message);
var allowEveryone = senderPermissions.HasFlag(PermissionSet.MentionEveryone);
var allowEmbeds = senderPermissions.HasFlag(PermissionSet.EmbedLinks);
@ -131,10 +133,15 @@ namespace PluralKit.Bot
var content = match.ProxyContent;
if (!allowEmbeds) content = content.BreakLinkEmbeds();
var messageChannel = _cache.GetChannel(trigger.ChannelId);
var rootChannel = _cache.GetRootChannel(trigger.ChannelId);
var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null;
var proxyMessage = await _webhookExecutor.ExecuteWebhook(new ProxyRequest
{
GuildId = trigger.GuildId!.Value,
ChannelId = trigger.ChannelId,
ChannelId = rootChannel.Id,
ThreadId = threadId,
Name = match.Member.ProxyName(ctx),
AvatarUrl = match.Member.ProxyAvatar(ctx),
Content = content,

@ -33,6 +33,7 @@ namespace PluralKit.Bot
{
public ulong GuildId { get; init; }
public ulong ChannelId { get; init; }
public ulong? ThreadId { get; init; }
public string Name { get; init; }
public string? AvatarUrl { get; init; }
public string? Content { get; init; }
@ -70,8 +71,8 @@ namespace PluralKit.Bot
// Log the relevant metrics
_metrics.Measure.Meter.Mark(BotMetrics.MessagesProxied);
_logger.Information("Invoked webhook {Webhook} in channel {Channel}", webhook.Id,
req.ChannelId);
_logger.Information("Invoked webhook {Webhook} in channel {Channel} (thread {ThreadId}", webhook.Id,
req.ChannelId, req.ThreadId);
return webhookMessage;
}
@ -122,7 +123,7 @@ namespace PluralKit.Bot
using (_metrics.Measure.Timer.Time(BotMetrics.WebhookResponseTime)) {
try
{
webhookMessage = await _rest.ExecuteWebhook(webhook.Id, webhook.Token, webhookReq, files);
webhookMessage = await _rest.ExecuteWebhook(webhook.Id, webhook.Token, webhookReq, files, req.ThreadId);
}
catch (JsonReaderException)
{
@ -136,7 +137,8 @@ namespace PluralKit.Bot
{
// Error 10015 = "Unknown Webhook" - this likely means the webhook was deleted
// but is still in our cache. Invalidate, refresh, try again
_logger.Warning("Error invoking webhook {Webhook} in channel {Channel}", webhook.Id, webhook.ChannelId);
_logger.Warning("Error invoking webhook {Webhook} in channel {Channel} (thread {ThreadId}",
webhook.Id, webhook.ChannelId, req.ThreadId);
var newWebhook = await _webhookCache.InvalidateAndRefreshWebhook(req.ChannelId, webhook);
return await ExecuteWebhookInner(newWebhook, req, hasRetried: true);
@ -147,12 +149,12 @@ namespace PluralKit.Bot
}
// We don't care about whether the sending succeeds, and we don't want to *wait* for it, so we just fork it off
var _ = TrySendRemainingAttachments(webhook, req.Name, req.AvatarUrl, attachmentChunks);
var _ = TrySendRemainingAttachments(webhook, req.Name, req.AvatarUrl, attachmentChunks, req.ThreadId);
return webhookMessage;
}
private async Task TrySendRemainingAttachments(Webhook webhook, string name, string avatarUrl, IReadOnlyList<IReadOnlyCollection<Message.Attachment>> attachmentChunks)
private async Task TrySendRemainingAttachments(Webhook webhook, string name, string avatarUrl, IReadOnlyList<IReadOnlyCollection<Message.Attachment>> attachmentChunks, ulong? threadId)
{
if (attachmentChunks.Count <= 1) return;
@ -160,7 +162,7 @@ namespace PluralKit.Bot
{
var files = await GetAttachmentFiles(attachmentChunks[i]);
var req = new ExecuteWebhookRequest {Username = name, AvatarUrl = avatarUrl};
await _rest.ExecuteWebhook(webhook.Id, webhook.Token!, req, files);
await _rest.ExecuteWebhook(webhook.Id, webhook.Token!, req, files, threadId);
}
}

@ -194,6 +194,11 @@ namespace PluralKit.Bot
}
public static bool IsValidGuildChannel(Channel channel) =>
channel.Type == Channel.ChannelType.GuildText || channel.Type == Channel.ChannelType.GuildNews;
channel.Type is
Channel.ChannelType.GuildText or
Channel.ChannelType.GuildNews or
Channel.ChannelType.GuildPublicThread or
Channel.ChannelType.GuildPrivateThread or
Channel.ChannelType.GuildNewsThread;
}
}

@ -36,10 +36,13 @@ namespace PluralKit.Bot
if (channel != null)
{
props.Add(new("GuildId", new ScalarValue(channel.Value)));
var botPermissions = _bot.PermissionsIn(channel.Value);
props.Add(new("BotPermissions", new ScalarValue(botPermissions)));
props.Add(new("ChannelId", new ScalarValue(channel.Value)));
if (_cache.TryGetChannel(channel.Value, out _))
{
var botPermissions = _bot.PermissionsIn(channel.Value);
props.Add(new("BotPermissions", new ScalarValue(botPermissions)));
}
}
if (message != null)