Preliminary support for APIv9 and threads

This commit is contained in:
Ske 2021-07-15 12:41:19 +02:00
parent 0e7bcb993e
commit 1f2b9f998d
19 changed files with 127 additions and 24 deletions

View File

@ -38,6 +38,14 @@ namespace Myriad.Cache
return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId); return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId);
case MessageDeleteBulkEvent md: case MessageDeleteBulkEvent md:
return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId); return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId);
case ThreadCreateEvent tc:
return cache.SaveChannel(tc);
case ThreadUpdateEvent tu:
return cache.SaveChannel(tu);
case ThreadDeleteEvent td:
return cache.RemoveChannel(td.Id);
case ThreadListSyncEvent tls:
return cache.SaveThreadListSync(tls);
} }
return default; return default;
@ -53,6 +61,9 @@ namespace Myriad.Cache
foreach (var member in guildCreate.Members) foreach (var member in guildCreate.Members)
await cache.SaveUser(member.User); await cache.SaveUser(member.User);
foreach (var thread in guildCreate.Threads)
await cache.SaveChannel(thread);
} }
private static async ValueTask SaveMessageCreate(this IDiscordCache cache, MessageCreateEvent evt) private static async ValueTask SaveMessageCreate(this IDiscordCache cache, MessageCreateEvent evt)
@ -70,5 +81,11 @@ namespace Myriad.Cache
// some kind of stub channel object until we get the real one // some kind of stub channel object until we get the real one
return guildId != null ? default : cache.SaveDmChannelStub(channelId); return guildId != null ? default : cache.SaveDmChannelStub(channelId);
} }
private static async ValueTask SaveThreadListSync(this IDiscordCache cache, ThreadListSyncEvent evt)
{
foreach (var thread in evt.Threads)
await cache.SaveChannel(thread);
}
} }
} }

View File

@ -75,5 +75,15 @@ namespace Myriad.Extensions
await cache.SaveChannel(restChannel); await cache.SaveChannel(restChannel);
return restChannel; return restChannel;
} }
public static Channel GetRootChannel(this IDiscordCache cache, ulong channelOrThread)
{
var channel = cache.GetChannel(channelOrThread);
if (!channel.IsThread())
return channel;
var parent = cache.GetChannel(channel.ParentId!.Value);
return parent;
}
} }
} }

View File

@ -5,5 +5,12 @@ namespace Myriad.Extensions
public static class ChannelExtensions public static class ChannelExtensions
{ {
public static string Mention(this Channel channel) => $"<#{channel.Id}>"; public static string Mention(this Channel channel) => $"<#{channel.Id}>";
public static bool IsThread(this Channel channel) => channel.Type.IsThread();
public static bool IsThread(this Channel.ChannelType type) =>
type is Channel.ChannelType.GuildPublicThread
or Channel.ChannelType.GuildPrivateThread
or Channel.ChannelType.GuildNewsThread;
} }
} }

View File

@ -11,7 +11,7 @@ namespace Myriad.Extensions
public static class PermissionExtensions public static class PermissionExtensions
{ {
public static PermissionSet PermissionsFor(this IDiscordCache cache, MessageCreateEvent message) => public static PermissionSet PermissionsFor(this IDiscordCache cache, MessageCreateEvent message) =>
PermissionsFor(cache, message.ChannelId, message.Author.Id, message.Member?.Roles, isWebhook: message.Author.Discriminator == "0000"); PermissionsFor(cache, message.ChannelId, message.Author.Id, message.Member?.Roles, isWebhook: message.WebhookId != null);
public static PermissionSet PermissionsFor(this IDiscordCache cache, ulong channelId, GuildMember member) => public static PermissionSet PermissionsFor(this IDiscordCache cache, ulong channelId, GuildMember member) =>
PermissionsFor(cache, channelId, member.User.Id, member.Roles); PermissionsFor(cache, channelId, member.User.Id, member.Roles);
@ -21,16 +21,21 @@ namespace Myriad.Extensions
public static PermissionSet PermissionsFor(this IDiscordCache cache, ulong channelId, ulong userId, ICollection<ulong>? userRoles, bool isWebhook = false) public static PermissionSet PermissionsFor(this IDiscordCache cache, ulong channelId, ulong userId, ICollection<ulong>? userRoles, bool isWebhook = false)
{ {
var channel = cache.GetChannel(channelId); if (!cache.TryGetChannel(channelId, out var channel))
// todo: handle channel not found better
return PermissionSet.Dm;
if (channel.GuildId == null) if (channel.GuildId == null)
return PermissionSet.Dm; return PermissionSet.Dm;
var rootChannel = cache.GetRootChannel(channelId);
var guild = cache.GetGuild(channel.GuildId.Value); var guild = cache.GetGuild(channel.GuildId.Value);
if (isWebhook) if (isWebhook)
return EveryonePermissions(guild); return EveryonePermissions(guild);
return PermissionsFor(guild, channel, userId, userRoles); return PermissionsFor(guild, rootChannel, userId, userRoles);
} }
public static PermissionSet EveryonePermissions(this Guild guild) => public static PermissionSet EveryonePermissions(this Guild guild) =>

View File

@ -6,5 +6,6 @@ namespace Myriad.Gateway
{ {
public Channel[] Channels { get; init; } public Channel[] Channels { get; init; }
public GuildMember[] Members { get; init; } public GuildMember[] Members { get; init; }
public Channel[] Threads { get; init; }
} }
} }

View File

@ -21,6 +21,10 @@ namespace Myriad.Gateway
{"CHANNEL_CREATE", typeof(ChannelCreateEvent)}, {"CHANNEL_CREATE", typeof(ChannelCreateEvent)},
{"CHANNEL_UPDATE", typeof(ChannelUpdateEvent)}, {"CHANNEL_UPDATE", typeof(ChannelUpdateEvent)},
{"CHANNEL_DELETE", typeof(ChannelDeleteEvent)}, {"CHANNEL_DELETE", typeof(ChannelDeleteEvent)},
{"THREAD_CREATE", typeof(ThreadCreateEvent)},
{"THREAD_UPDATE", typeof(ThreadUpdateEvent)},
{"THREAD_DELETE", typeof(ThreadDeleteEvent)},
{"THREAD_LIST_SYNC", typeof(ThreadListSyncEvent)},
{"MESSAGE_CREATE", typeof(MessageCreateEvent)}, {"MESSAGE_CREATE", typeof(MessageCreateEvent)},
{"MESSAGE_UPDATE", typeof(MessageUpdateEvent)}, {"MESSAGE_UPDATE", typeof(MessageUpdateEvent)},
{"MESSAGE_DELETE", typeof(MessageDeleteEvent)}, {"MESSAGE_DELETE", typeof(MessageDeleteEvent)},

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ThreadCreateEvent: Channel, IGatewayEvent;
}

View File

@ -0,0 +1,12 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ThreadDeleteEvent: IGatewayEvent
{
public ulong Id { get; init; }
public ulong? GuildId { get; init; }
public ulong? ParentId { get; init; }
public Channel.ChannelType Type { get; init; }
}
}

View File

@ -0,0 +1,11 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ThreadListSyncEvent: IGatewayEvent
{
public ulong GuildId { get; init; }
public ulong[]? ChannelIds { get; init; }
public Channel[] Threads { get; init; }
}
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ThreadUpdateEvent: Channel, IGatewayEvent;
}

View File

@ -82,7 +82,7 @@ namespace Myriad.Gateway
private Uri GetConnectionUri(string baseUri) => new UriBuilder(baseUri) private Uri GetConnectionUri(string baseUri) => new UriBuilder(baseUri)
{ {
Query = "v=8&encoding=json" Query = "v=9&encoding=json"
}.Uri; }.Uri;
private async Task CloseInner(WebSocketCloseStatus closeStatus, string? description) private async Task CloseInner(WebSocketCloseStatus closeStatus, string? description)

View File

@ -22,7 +22,7 @@ namespace Myriad.Rest
{ {
public class BaseRestClient: IAsyncDisposable public class BaseRestClient: IAsyncDisposable
{ {
private const string ApiBaseUrl = "https://discord.com/api/v8"; private const string ApiBaseUrl = "https://discord.com/api/v9";
private readonly Version _httpVersion = new(2, 0); private readonly Version _httpVersion = new(2, 0);
private readonly JsonSerializerOptions _jsonSerializerOptions; private readonly JsonSerializerOptions _jsonSerializerOptions;

View File

@ -116,9 +116,15 @@ namespace Myriad.Rest
_client.Get<Webhook[]>($"/channels/{channelId}/webhooks", ("GetChannelWebhooks", channelId))!; _client.Get<Webhook[]>($"/channels/{channelId}/webhooks", ("GetChannelWebhooks", channelId))!;
public Task<Message> ExecuteWebhook(ulong webhookId, string webhookToken, ExecuteWebhookRequest request, public Task<Message> ExecuteWebhook(ulong webhookId, string webhookToken, ExecuteWebhookRequest request,
MultipartFile[]? files = null) => MultipartFile[]? files = null, ulong? threadId = null)
_client.PostMultipart<Message>($"/webhooks/{webhookId}/{webhookToken}?wait=true", {
var url = $"/webhooks/{webhookId}/{webhookToken}?wait=true";
if (threadId != null)
url += $"&thread_id={threadId}";
return _client.PostMultipart<Message>(url,
("ExecuteWebhook", webhookId), request, files)!; ("ExecuteWebhook", webhookId), request, files)!;
}
public Task<Message> EditWebhookMessage(ulong webhookId, string webhookToken, ulong messageId, public Task<Message> EditWebhookMessage(ulong webhookId, string webhookToken, ulong messageId,
WebhookMessageEditRequest request) => WebhookMessageEditRequest request) =>

View File

@ -78,7 +78,7 @@ namespace PluralKit.Bot
public PermissionSet PermissionsIn(ulong channelId) public PermissionSet PermissionsIn(ulong channelId)
{ {
var channel = _cache.GetChannel(channelId); var channel = _cache.GetRootChannel(channelId);
if (channel.GuildId != null) if (channel.GuildId != null)
{ {

View File

@ -63,6 +63,7 @@ namespace PluralKit.Bot
var guild = evt.GuildId != null ? _cache.GetGuild(evt.GuildId.Value) : null; var guild = evt.GuildId != null ? _cache.GetGuild(evt.GuildId.Value) : null;
var channel = _cache.GetChannel(evt.ChannelId); var channel = _cache.GetChannel(evt.ChannelId);
var rootChannel = _cache.GetRootChannel(evt.ChannelId);
// Log metrics and message info // Log metrics and message info
_metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived); _metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived);
@ -72,7 +73,7 @@ namespace PluralKit.Bot
MessageContext ctx; MessageContext ctx;
await using (var conn = await _db.Obtain()) await using (var conn = await _db.Obtain())
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.GuildId ?? default, evt.ChannelId); ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.GuildId ?? default, rootChannel.Id);
// Try each handler until we find one that succeeds // Try each handler until we find one that succeeds
if (await TryHandleLogClean(evt, ctx)) if (await TryHandleLogClean(evt, ctx))

View File

@ -57,6 +57,8 @@ namespace PluralKit.Bot
// Fetch members and try to match to a specific member // Fetch members and try to match to a specific member
await using var conn = await _db.Obtain(); await using var conn = await _db.Obtain();
var rootChannel = _cache.GetRootChannel(message.ChannelId);
List<ProxyMember> members; List<ProxyMember> members;
using (_metrics.Measure.Timer.Time(BotMetrics.ProxyMembersQueryTime)) using (_metrics.Measure.Timer.Time(BotMetrics.ProxyMembersQueryTime))
members = (await _repo.GetProxyMembers(conn, message.Author.Id, message.GuildId!.Value)).ToList(); members = (await _repo.GetProxyMembers(conn, message.Author.Id, message.GuildId!.Value)).ToList();
@ -68,7 +70,7 @@ namespace PluralKit.Bot
if (message.Content != null && message.Content.Length > 2000) throw new PKError("PluralKit cannot proxy messages over 2000 characters in length."); if (message.Content != null && message.Content.Length > 2000) throw new PKError("PluralKit cannot proxy messages over 2000 characters in length.");
// Permission check after proxy match so we don't get spammed when not actually proxying // Permission check after proxy match so we don't get spammed when not actually proxying
if (!await CheckBotPermissionsOrError(botPermissions, message.ChannelId)) if (!await CheckBotPermissionsOrError(botPermissions, rootChannel.Id))
return false; return false;
// this method throws, so no need to wrap it in an if statement // this method throws, so no need to wrap it in an if statement
@ -76,7 +78,7 @@ namespace PluralKit.Bot
// Check if the sender account can mention everyone/here + embed links // Check if the sender account can mention everyone/here + embed links
// we need to "mirror" these permissions when proxying to prevent exploits // we need to "mirror" these permissions when proxying to prevent exploits
var senderPermissions = PermissionExtensions.PermissionsFor(guild, channel, message); var senderPermissions = PermissionExtensions.PermissionsFor(guild, rootChannel, message);
var allowEveryone = senderPermissions.HasFlag(PermissionSet.MentionEveryone); var allowEveryone = senderPermissions.HasFlag(PermissionSet.MentionEveryone);
var allowEmbeds = senderPermissions.HasFlag(PermissionSet.EmbedLinks); var allowEmbeds = senderPermissions.HasFlag(PermissionSet.EmbedLinks);
@ -131,10 +133,15 @@ namespace PluralKit.Bot
var content = match.ProxyContent; var content = match.ProxyContent;
if (!allowEmbeds) content = content.BreakLinkEmbeds(); if (!allowEmbeds) content = content.BreakLinkEmbeds();
var messageChannel = _cache.GetChannel(trigger.ChannelId);
var rootChannel = _cache.GetRootChannel(trigger.ChannelId);
var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null;
var proxyMessage = await _webhookExecutor.ExecuteWebhook(new ProxyRequest var proxyMessage = await _webhookExecutor.ExecuteWebhook(new ProxyRequest
{ {
GuildId = trigger.GuildId!.Value, GuildId = trigger.GuildId!.Value,
ChannelId = trigger.ChannelId, ChannelId = rootChannel.Id,
ThreadId = threadId,
Name = match.Member.ProxyName(ctx), Name = match.Member.ProxyName(ctx),
AvatarUrl = match.Member.ProxyAvatar(ctx), AvatarUrl = match.Member.ProxyAvatar(ctx),
Content = content, Content = content,

View File

@ -33,6 +33,7 @@ namespace PluralKit.Bot
{ {
public ulong GuildId { get; init; } public ulong GuildId { get; init; }
public ulong ChannelId { get; init; } public ulong ChannelId { get; init; }
public ulong? ThreadId { get; init; }
public string Name { get; init; } public string Name { get; init; }
public string? AvatarUrl { get; init; } public string? AvatarUrl { get; init; }
public string? Content { get; init; } public string? Content { get; init; }
@ -70,8 +71,8 @@ namespace PluralKit.Bot
// Log the relevant metrics // Log the relevant metrics
_metrics.Measure.Meter.Mark(BotMetrics.MessagesProxied); _metrics.Measure.Meter.Mark(BotMetrics.MessagesProxied);
_logger.Information("Invoked webhook {Webhook} in channel {Channel}", webhook.Id, _logger.Information("Invoked webhook {Webhook} in channel {Channel} (thread {ThreadId}", webhook.Id,
req.ChannelId); req.ChannelId, req.ThreadId);
return webhookMessage; return webhookMessage;
} }
@ -122,7 +123,7 @@ namespace PluralKit.Bot
using (_metrics.Measure.Timer.Time(BotMetrics.WebhookResponseTime)) { using (_metrics.Measure.Timer.Time(BotMetrics.WebhookResponseTime)) {
try try
{ {
webhookMessage = await _rest.ExecuteWebhook(webhook.Id, webhook.Token, webhookReq, files); webhookMessage = await _rest.ExecuteWebhook(webhook.Id, webhook.Token, webhookReq, files, req.ThreadId);
} }
catch (JsonReaderException) catch (JsonReaderException)
{ {
@ -136,7 +137,8 @@ namespace PluralKit.Bot
{ {
// Error 10015 = "Unknown Webhook" - this likely means the webhook was deleted // Error 10015 = "Unknown Webhook" - this likely means the webhook was deleted
// but is still in our cache. Invalidate, refresh, try again // but is still in our cache. Invalidate, refresh, try again
_logger.Warning("Error invoking webhook {Webhook} in channel {Channel}", webhook.Id, webhook.ChannelId); _logger.Warning("Error invoking webhook {Webhook} in channel {Channel} (thread {ThreadId}",
webhook.Id, webhook.ChannelId, req.ThreadId);
var newWebhook = await _webhookCache.InvalidateAndRefreshWebhook(req.ChannelId, webhook); var newWebhook = await _webhookCache.InvalidateAndRefreshWebhook(req.ChannelId, webhook);
return await ExecuteWebhookInner(newWebhook, req, hasRetried: true); return await ExecuteWebhookInner(newWebhook, req, hasRetried: true);
@ -147,12 +149,12 @@ namespace PluralKit.Bot
} }
// We don't care about whether the sending succeeds, and we don't want to *wait* for it, so we just fork it off // We don't care about whether the sending succeeds, and we don't want to *wait* for it, so we just fork it off
var _ = TrySendRemainingAttachments(webhook, req.Name, req.AvatarUrl, attachmentChunks); var _ = TrySendRemainingAttachments(webhook, req.Name, req.AvatarUrl, attachmentChunks, req.ThreadId);
return webhookMessage; return webhookMessage;
} }
private async Task TrySendRemainingAttachments(Webhook webhook, string name, string avatarUrl, IReadOnlyList<IReadOnlyCollection<Message.Attachment>> attachmentChunks) private async Task TrySendRemainingAttachments(Webhook webhook, string name, string avatarUrl, IReadOnlyList<IReadOnlyCollection<Message.Attachment>> attachmentChunks, ulong? threadId)
{ {
if (attachmentChunks.Count <= 1) return; if (attachmentChunks.Count <= 1) return;
@ -160,7 +162,7 @@ namespace PluralKit.Bot
{ {
var files = await GetAttachmentFiles(attachmentChunks[i]); var files = await GetAttachmentFiles(attachmentChunks[i]);
var req = new ExecuteWebhookRequest {Username = name, AvatarUrl = avatarUrl}; var req = new ExecuteWebhookRequest {Username = name, AvatarUrl = avatarUrl};
await _rest.ExecuteWebhook(webhook.Id, webhook.Token!, req, files); await _rest.ExecuteWebhook(webhook.Id, webhook.Token!, req, files, threadId);
} }
} }

View File

@ -194,6 +194,11 @@ namespace PluralKit.Bot
} }
public static bool IsValidGuildChannel(Channel channel) => public static bool IsValidGuildChannel(Channel channel) =>
channel.Type == Channel.ChannelType.GuildText || channel.Type == Channel.ChannelType.GuildNews; channel.Type is
Channel.ChannelType.GuildText or
Channel.ChannelType.GuildNews or
Channel.ChannelType.GuildPublicThread or
Channel.ChannelType.GuildPrivateThread or
Channel.ChannelType.GuildNewsThread;
} }
} }

View File

@ -36,11 +36,14 @@ namespace PluralKit.Bot
if (channel != null) if (channel != null)
{ {
props.Add(new("GuildId", new ScalarValue(channel.Value))); props.Add(new("ChannelId", new ScalarValue(channel.Value)));
if (_cache.TryGetChannel(channel.Value, out _))
{
var botPermissions = _bot.PermissionsIn(channel.Value); var botPermissions = _bot.PermissionsIn(channel.Value);
props.Add(new("BotPermissions", new ScalarValue(botPermissions))); props.Add(new("BotPermissions", new ScalarValue(botPermissions)));
} }
}
if (message != null) if (message != null)
props.Add(new("MessageId", new ScalarValue(message.Value))); props.Add(new("MessageId", new ScalarValue(message.Value)));