fix(bot): make proxy/log blacklists work properly with threads
Handling of both blacklists was inconsistent when dealing with threads: - proxy blacklist of root channel blacklists all threads - log blacklist of root channel _did not apply_ to threads - couldn't proxy blacklist threads while leaving root channel proxyable This change fixes the inconsistencies: - proxy _and_ log blacklist of root channel affects all threads - now able to individually proxy/log blacklist threads, with root channel unaffected
This commit is contained in:
		@@ -17,6 +17,7 @@ public class Checks
 | 
			
		||||
    private readonly ProxyMatcher _matcher;
 | 
			
		||||
    private readonly ProxyService _proxy;
 | 
			
		||||
    private readonly DiscordApiClient _rest;
 | 
			
		||||
    private readonly IDiscordCache _cache;
 | 
			
		||||
 | 
			
		||||
    private readonly PermissionSet[] requiredPermissions =
 | 
			
		||||
    {
 | 
			
		||||
@@ -26,12 +27,13 @@ public class Checks
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // todo: make sure everything uses the minimum amount of REST calls necessary
 | 
			
		||||
    public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher)
 | 
			
		||||
    public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher, IDiscordCache cache)
 | 
			
		||||
    {
 | 
			
		||||
        _rest = rest;
 | 
			
		||||
        _botConfig = botConfig;
 | 
			
		||||
        _proxy = proxy;
 | 
			
		||||
        _matcher = matcher;
 | 
			
		||||
        _cache = cache;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public async Task PermCheckGuild(Context ctx)
 | 
			
		||||
@@ -230,11 +232,12 @@ public class Checks
 | 
			
		||||
        if (channel == null)
 | 
			
		||||
            throw new PKError("Unable to get the channel associated with this message.");
 | 
			
		||||
 | 
			
		||||
        var rootChannel = await _cache.GetRootChannel(channel.Id);
 | 
			
		||||
        if (channel.GuildId == null)
 | 
			
		||||
            throw new PKError("PluralKit is not able to proxy messages in DMs.");
 | 
			
		||||
 | 
			
		||||
        // using channel.GuildId here since _rest.GetMessage() doesn't return the GuildId
 | 
			
		||||
        var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, msg.ChannelId);
 | 
			
		||||
        var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, rootChannel.Id, msg.ChannelId);
 | 
			
		||||
        var members = (await ctx.Repository.GetProxyMembers(msg.Author.Id, channel.GuildId.Value)).ToList();
 | 
			
		||||
 | 
			
		||||
        // for now this is just server
 | 
			
		||||
 
 | 
			
		||||
@@ -159,7 +159,7 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
 | 
			
		||||
        // Get message context from DB (tracking w/ metrics)
 | 
			
		||||
        MessageContext ctx;
 | 
			
		||||
        using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
 | 
			
		||||
            ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel);
 | 
			
		||||
            ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel, channel.Id != rootChannel ? channel.Id : default);
 | 
			
		||||
 | 
			
		||||
        try
 | 
			
		||||
        {
 | 
			
		||||
 
 | 
			
		||||
@@ -55,6 +55,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
 | 
			
		||||
        var channel = await _cache.GetChannel(evt.ChannelId);
 | 
			
		||||
        if (!DiscordUtils.IsValidGuildChannel(channel))
 | 
			
		||||
            return;
 | 
			
		||||
        var rootChannel = await _cache.GetRootChannel(channel.Id);
 | 
			
		||||
        var guild = await _cache.GetGuild(channel.GuildId!.Value);
 | 
			
		||||
        var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current;
 | 
			
		||||
 | 
			
		||||
@@ -65,7 +66,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
 | 
			
		||||
        // Just run the normal message handling code, with a flag to disable autoproxying
 | 
			
		||||
        MessageContext ctx;
 | 
			
		||||
        using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
 | 
			
		||||
            ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, evt.ChannelId);
 | 
			
		||||
            ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, rootChannel.Id, evt.ChannelId);
 | 
			
		||||
 | 
			
		||||
        var equivalentEvt = await GetMessageCreateEvent(evt, lastMessage, channel);
 | 
			
		||||
        var botPermissions = await _cache.PermissionsIn(channel.Id);
 | 
			
		||||
 
 | 
			
		||||
@@ -229,9 +229,12 @@ public class ProxyService
 | 
			
		||||
        if (originalMsg == null)
 | 
			
		||||
            throw new PKError("Could not reproxy message.");
 | 
			
		||||
 | 
			
		||||
        var messageChannel = await _rest.GetChannelOrNull(msg.Channel!);
 | 
			
		||||
        var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel;
 | 
			
		||||
 | 
			
		||||
        // Get a MessageContext for the original message
 | 
			
		||||
        MessageContext ctx =
 | 
			
		||||
            await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, msg.Channel);
 | 
			
		||||
            await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, rootChannel.Id, msg.Channel);
 | 
			
		||||
 | 
			
		||||
        // Make sure proxying is enabled here
 | 
			
		||||
        if (ctx.InBlacklist)
 | 
			
		||||
@@ -250,8 +253,6 @@ public class ProxyService
 | 
			
		||||
            ProxyTags = member.ProxyTags.FirstOrDefault(),
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        var messageChannel = await _rest.GetChannelOrNull(msg.Channel!);
 | 
			
		||||
        var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel;
 | 
			
		||||
        var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null;
 | 
			
		||||
        var guild = await _rest.GetGuildOrNull(msg.Guild!.Value);
 | 
			
		||||
        var guildMember = await _rest.GetGuildMember(msg.Guild!.Value, trigger.Author.Id);
 | 
			
		||||
 
 | 
			
		||||
@@ -63,11 +63,12 @@ public class LogChannelService
 | 
			
		||||
            return null;
 | 
			
		||||
 | 
			
		||||
        var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value;
 | 
			
		||||
        var rootChannel = await _cache.GetRootChannel(trigger.ChannelId);
 | 
			
		||||
 | 
			
		||||
        // get log channel info from the database
 | 
			
		||||
        var guild = await _repo.GetGuild(guildId);
 | 
			
		||||
        var logChannelId = guild.LogChannel;
 | 
			
		||||
        var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId);
 | 
			
		||||
        var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId || x == rootChannel.Id);
 | 
			
		||||
 | 
			
		||||
        // if (ctx.SystemId == null ||
 | 
			
		||||
        // removed the above, there shouldn't be a way to get to this code path if you don't have a system registered
 | 
			
		||||
 
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
create function message_context(account_id bigint, guild_id bigint, channel_id bigint)
 | 
			
		||||
create function message_context(account_id bigint, guild_id bigint, channel_id bigint, thread_id bigint)
 | 
			
		||||
    returns table (
 | 
			
		||||
        system_id int,
 | 
			
		||||
        log_channel bigint,
 | 
			
		||||
@@ -29,23 +29,25 @@ as $$
 | 
			
		||||
            where accounts.uid = account_id),
 | 
			
		||||
        guild as (select * from servers where id = guild_id)
 | 
			
		||||
    select
 | 
			
		||||
        system.id                                  as system_id,
 | 
			
		||||
        system.id                                    as system_id,
 | 
			
		||||
        guild.log_channel,
 | 
			
		||||
        (channel_id = any (guild.blacklist))       as in_blacklist,
 | 
			
		||||
        (channel_id = any (guild.log_blacklist))   as in_log_blacklist,
 | 
			
		||||
        ((channel_id = any (guild.blacklist))
 | 
			
		||||
         or (thread_id = any (guild.blacklist)))     as in_blacklist,
 | 
			
		||||
        ((channel_id = any (guild.log_blacklist))
 | 
			
		||||
         or (thread_id = any (guild.log_blacklist))) as in_log_blacklist,
 | 
			
		||||
        coalesce(guild.log_cleanup_enabled, false),
 | 
			
		||||
        coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
 | 
			
		||||
        system_last_switch.switch                  as last_switch,
 | 
			
		||||
        system_last_switch.members                 as last_switch_members,
 | 
			
		||||
        system_last_switch.timestamp               as last_switch_timestamp,
 | 
			
		||||
        system.tag                                 as system_tag,
 | 
			
		||||
        system.guild_tag                           as system_guild_tag,
 | 
			
		||||
        coalesce(system.tag_enabled, true)         as tag_enabled,
 | 
			
		||||
        system.avatar_url                          as system_avatar,
 | 
			
		||||
        system.account_autoproxy                   as allow_autoproxy,
 | 
			
		||||
        system.latch_timeout                       as latch_timeout,
 | 
			
		||||
        system.case_sensitive_proxy_tags           as case_sensitive_proxy_tags,
 | 
			
		||||
        system.proxy_error_message_enabled         as proxy_error_message_enabled
 | 
			
		||||
        coalesce(system_guild.proxy_enabled, true)   as proxy_enabled,
 | 
			
		||||
        system_last_switch.switch                    as last_switch,
 | 
			
		||||
        system_last_switch.members                   as last_switch_members,
 | 
			
		||||
        system_last_switch.timestamp                 as last_switch_timestamp,
 | 
			
		||||
        system.tag                                   as system_tag,
 | 
			
		||||
        system.guild_tag                             as system_guild_tag,
 | 
			
		||||
        coalesce(system.tag_enabled, true)           as tag_enabled,
 | 
			
		||||
        system.avatar_url                            as system_avatar,
 | 
			
		||||
        system.account_autoproxy                     as allow_autoproxy,
 | 
			
		||||
        system.latch_timeout                         as latch_timeout,
 | 
			
		||||
        system.case_sensitive_proxy_tags             as case_sensitive_proxy_tags,
 | 
			
		||||
        system.proxy_error_message_enabled           as proxy_error_message_enabled
 | 
			
		||||
    -- We need a "from" clause, so we just use some bogus data that's always present
 | 
			
		||||
    -- This ensure we always have exactly one row going forward, so we can left join afterwards and still get data
 | 
			
		||||
    from (select 1) as _placeholder
 | 
			
		||||
 
 | 
			
		||||
@@ -2,9 +2,9 @@ namespace PluralKit.Core;
 | 
			
		||||
 | 
			
		||||
public partial class ModelRepository
 | 
			
		||||
{
 | 
			
		||||
    public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel)
 | 
			
		||||
    public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel, ulong thread)
 | 
			
		||||
        => _db.QuerySingleProcedure<MessageContext>("message_context",
 | 
			
		||||
            new { account_id = account, guild_id = guild, channel_id = channel });
 | 
			
		||||
            new { account_id = account, guild_id = guild, channel_id = channel, thread_id = thread });
 | 
			
		||||
 | 
			
		||||
    public Task<IEnumerable<ProxyMember>> GetProxyMembers(ulong account, ulong guild)
 | 
			
		||||
        => _db.QueryProcedure<ProxyMember>("proxy_members", new { account_id = account, guild_id = guild });
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user