From 66544b9d40ef963428256b6b6921edd0f44db2ea Mon Sep 17 00:00:00 2001 From: Iris System Date: Wed, 10 May 2023 13:16:16 +1200 Subject: [PATCH] fix(bot): make proxy/log blacklists work properly with threads Handling of both blacklists was inconsistent when dealing with threads: - proxy blacklist of root channel blacklists all threads - log blacklist of root channel _did not apply_ to threads - couldn't proxy blacklist threads while leaving root channel proxyable This change fixes the inconsistencies: - proxy _and_ log blacklist of root channel affects all threads - now able to individually proxy/log blacklist threads, with root channel unaffected --- PluralKit.Bot/Commands/Checks.cs | 7 ++-- PluralKit.Bot/Handlers/MessageCreated.cs | 2 +- PluralKit.Bot/Handlers/MessageEdited.cs | 3 +- PluralKit.Bot/Proxy/ProxyService.cs | 7 ++-- PluralKit.Bot/Services/LogChannelService.cs | 3 +- .../Database/Functions/functions.sql | 34 ++++++++++--------- .../Repository/ModelRepository.Context.cs | 4 +-- 7 files changed, 34 insertions(+), 26 deletions(-) diff --git a/PluralKit.Bot/Commands/Checks.cs b/PluralKit.Bot/Commands/Checks.cs index a2be32cf..43f2c8a0 100644 --- a/PluralKit.Bot/Commands/Checks.cs +++ b/PluralKit.Bot/Commands/Checks.cs @@ -17,6 +17,7 @@ public class Checks private readonly ProxyMatcher _matcher; private readonly ProxyService _proxy; private readonly DiscordApiClient _rest; + private readonly IDiscordCache _cache; private readonly PermissionSet[] requiredPermissions = { @@ -26,12 +27,13 @@ public class Checks }; // todo: make sure everything uses the minimum amount of REST calls necessary - public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher) + public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher, IDiscordCache cache) { _rest = rest; _botConfig = botConfig; _proxy = proxy; _matcher = matcher; + _cache = cache; } public async Task PermCheckGuild(Context ctx) @@ -230,11 +232,12 @@ public class Checks if (channel == null) throw new PKError("Unable to get the channel associated with this message."); + var rootChannel = await _cache.GetRootChannel(channel.Id); if (channel.GuildId == null) throw new PKError("PluralKit is not able to proxy messages in DMs."); // using channel.GuildId here since _rest.GetMessage() doesn't return the GuildId - var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, msg.ChannelId); + var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, rootChannel.Id, msg.ChannelId); var members = (await ctx.Repository.GetProxyMembers(msg.Author.Id, channel.GuildId.Value)).ToList(); // for now this is just server diff --git a/PluralKit.Bot/Handlers/MessageCreated.cs b/PluralKit.Bot/Handlers/MessageCreated.cs index 3625ea58..f4d84988 100644 --- a/PluralKit.Bot/Handlers/MessageCreated.cs +++ b/PluralKit.Bot/Handlers/MessageCreated.cs @@ -159,7 +159,7 @@ public class MessageCreated: IEventHandler // Get message context from DB (tracking w/ metrics) MessageContext ctx; using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) - ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel); + ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel, channel.Id != rootChannel ? channel.Id : default); try { diff --git a/PluralKit.Bot/Handlers/MessageEdited.cs b/PluralKit.Bot/Handlers/MessageEdited.cs index c4d20468..2c869c7e 100644 --- a/PluralKit.Bot/Handlers/MessageEdited.cs +++ b/PluralKit.Bot/Handlers/MessageEdited.cs @@ -55,6 +55,7 @@ public class MessageEdited: IEventHandler var channel = await _cache.GetChannel(evt.ChannelId); if (!DiscordUtils.IsValidGuildChannel(channel)) return; + var rootChannel = await _cache.GetRootChannel(channel.Id); var guild = await _cache.GetGuild(channel.GuildId!.Value); var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current; @@ -65,7 +66,7 @@ public class MessageEdited: IEventHandler // Just run the normal message handling code, with a flag to disable autoproxying MessageContext ctx; using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) - ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, evt.ChannelId); + ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, rootChannel.Id, evt.ChannelId); var equivalentEvt = await GetMessageCreateEvent(evt, lastMessage, channel); var botPermissions = await _cache.PermissionsIn(channel.Id); diff --git a/PluralKit.Bot/Proxy/ProxyService.cs b/PluralKit.Bot/Proxy/ProxyService.cs index b2f82a38..ca85c396 100644 --- a/PluralKit.Bot/Proxy/ProxyService.cs +++ b/PluralKit.Bot/Proxy/ProxyService.cs @@ -229,9 +229,12 @@ public class ProxyService if (originalMsg == null) throw new PKError("Could not reproxy message."); + var messageChannel = await _rest.GetChannelOrNull(msg.Channel!); + var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel; + // Get a MessageContext for the original message MessageContext ctx = - await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, msg.Channel); + await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, rootChannel.Id, msg.Channel); // Make sure proxying is enabled here if (ctx.InBlacklist) @@ -250,8 +253,6 @@ public class ProxyService ProxyTags = member.ProxyTags.FirstOrDefault(), }; - var messageChannel = await _rest.GetChannelOrNull(msg.Channel!); - var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel; var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null; var guild = await _rest.GetGuildOrNull(msg.Guild!.Value); var guildMember = await _rest.GetGuildMember(msg.Guild!.Value, trigger.Author.Id); diff --git a/PluralKit.Bot/Services/LogChannelService.cs b/PluralKit.Bot/Services/LogChannelService.cs index 337d513a..c8b21282 100644 --- a/PluralKit.Bot/Services/LogChannelService.cs +++ b/PluralKit.Bot/Services/LogChannelService.cs @@ -63,11 +63,12 @@ public class LogChannelService return null; var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value; + var rootChannel = await _cache.GetRootChannel(trigger.ChannelId); // get log channel info from the database var guild = await _repo.GetGuild(guildId); var logChannelId = guild.LogChannel; - var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId); + var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId || x == rootChannel.Id); // if (ctx.SystemId == null || // removed the above, there shouldn't be a way to get to this code path if you don't have a system registered diff --git a/PluralKit.Core/Database/Functions/functions.sql b/PluralKit.Core/Database/Functions/functions.sql index b3fe6df6..6ce9e360 100644 --- a/PluralKit.Core/Database/Functions/functions.sql +++ b/PluralKit.Core/Database/Functions/functions.sql @@ -1,4 +1,4 @@ -create function message_context(account_id bigint, guild_id bigint, channel_id bigint) +create function message_context(account_id bigint, guild_id bigint, channel_id bigint, thread_id bigint) returns table ( system_id int, log_channel bigint, @@ -29,23 +29,25 @@ as $$ where accounts.uid = account_id), guild as (select * from servers where id = guild_id) select - system.id as system_id, + system.id as system_id, guild.log_channel, - (channel_id = any (guild.blacklist)) as in_blacklist, - (channel_id = any (guild.log_blacklist)) as in_log_blacklist, + ((channel_id = any (guild.blacklist)) + or (thread_id = any (guild.blacklist))) as in_blacklist, + ((channel_id = any (guild.log_blacklist)) + or (thread_id = any (guild.log_blacklist))) as in_log_blacklist, coalesce(guild.log_cleanup_enabled, false), - coalesce(system_guild.proxy_enabled, true) as proxy_enabled, - system_last_switch.switch as last_switch, - system_last_switch.members as last_switch_members, - system_last_switch.timestamp as last_switch_timestamp, - system.tag as system_tag, - system.guild_tag as system_guild_tag, - coalesce(system.tag_enabled, true) as tag_enabled, - system.avatar_url as system_avatar, - system.account_autoproxy as allow_autoproxy, - system.latch_timeout as latch_timeout, - system.case_sensitive_proxy_tags as case_sensitive_proxy_tags, - system.proxy_error_message_enabled as proxy_error_message_enabled + coalesce(system_guild.proxy_enabled, true) as proxy_enabled, + system_last_switch.switch as last_switch, + system_last_switch.members as last_switch_members, + system_last_switch.timestamp as last_switch_timestamp, + system.tag as system_tag, + system.guild_tag as system_guild_tag, + coalesce(system.tag_enabled, true) as tag_enabled, + system.avatar_url as system_avatar, + system.account_autoproxy as allow_autoproxy, + system.latch_timeout as latch_timeout, + system.case_sensitive_proxy_tags as case_sensitive_proxy_tags, + system.proxy_error_message_enabled as proxy_error_message_enabled -- We need a "from" clause, so we just use some bogus data that's always present -- This ensure we always have exactly one row going forward, so we can left join afterwards and still get data from (select 1) as _placeholder diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Context.cs b/PluralKit.Core/Database/Repository/ModelRepository.Context.cs index c839ad1e..8ea6ab9b 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Context.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Context.cs @@ -2,9 +2,9 @@ namespace PluralKit.Core; public partial class ModelRepository { - public Task GetMessageContext(ulong account, ulong guild, ulong channel) + public Task GetMessageContext(ulong account, ulong guild, ulong channel, ulong thread) => _db.QuerySingleProcedure("message_context", - new { account_id = account, guild_id = guild, channel_id = channel }); + new { account_id = account, guild_id = guild, channel_id = channel, thread_id = thread }); public Task> GetProxyMembers(ulong account, ulong guild) => _db.QueryProcedure("proxy_members", new { account_id = account, guild_id = guild });