fix(bot): make proxy/log blacklists work properly with threads

Handling of both blacklists was inconsistent when dealing with threads:
- proxy blacklist of root channel blacklists all threads
- log blacklist of root channel _did not apply_ to threads
- couldn't proxy blacklist threads while leaving root channel proxyable

This change fixes the inconsistencies:
- proxy _and_ log blacklist of root channel affects all threads
- now able to individually proxy/log blacklist threads, with root
  channel unaffected
This commit is contained in:
Iris System 2023-05-10 13:16:16 +12:00
parent 24f0fcd563
commit 66544b9d40
7 changed files with 34 additions and 26 deletions

View File

@ -17,6 +17,7 @@ public class Checks
private readonly ProxyMatcher _matcher;
private readonly ProxyService _proxy;
private readonly DiscordApiClient _rest;
private readonly IDiscordCache _cache;
private readonly PermissionSet[] requiredPermissions =
{
@ -26,12 +27,13 @@ public class Checks
};
// todo: make sure everything uses the minimum amount of REST calls necessary
public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher)
public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher, IDiscordCache cache)
{
_rest = rest;
_botConfig = botConfig;
_proxy = proxy;
_matcher = matcher;
_cache = cache;
}
public async Task PermCheckGuild(Context ctx)
@ -230,11 +232,12 @@ public class Checks
if (channel == null)
throw new PKError("Unable to get the channel associated with this message.");
var rootChannel = await _cache.GetRootChannel(channel.Id);
if (channel.GuildId == null)
throw new PKError("PluralKit is not able to proxy messages in DMs.");
// using channel.GuildId here since _rest.GetMessage() doesn't return the GuildId
var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, msg.ChannelId);
var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, rootChannel.Id, msg.ChannelId);
var members = (await ctx.Repository.GetProxyMembers(msg.Author.Id, channel.GuildId.Value)).ToList();
// for now this is just server

View File

@ -159,7 +159,7 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
// Get message context from DB (tracking w/ metrics)
MessageContext ctx;
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel);
ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel, channel.Id != rootChannel ? channel.Id : default);
try
{

View File

@ -55,6 +55,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
var channel = await _cache.GetChannel(evt.ChannelId);
if (!DiscordUtils.IsValidGuildChannel(channel))
return;
var rootChannel = await _cache.GetRootChannel(channel.Id);
var guild = await _cache.GetGuild(channel.GuildId!.Value);
var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current;
@ -65,7 +66,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
// Just run the normal message handling code, with a flag to disable autoproxying
MessageContext ctx;
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, evt.ChannelId);
ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, rootChannel.Id, evt.ChannelId);
var equivalentEvt = await GetMessageCreateEvent(evt, lastMessage, channel);
var botPermissions = await _cache.PermissionsIn(channel.Id);

View File

@ -229,9 +229,12 @@ public class ProxyService
if (originalMsg == null)
throw new PKError("Could not reproxy message.");
var messageChannel = await _rest.GetChannelOrNull(msg.Channel!);
var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel;
// Get a MessageContext for the original message
MessageContext ctx =
await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, msg.Channel);
await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, rootChannel.Id, msg.Channel);
// Make sure proxying is enabled here
if (ctx.InBlacklist)
@ -250,8 +253,6 @@ public class ProxyService
ProxyTags = member.ProxyTags.FirstOrDefault(),
};
var messageChannel = await _rest.GetChannelOrNull(msg.Channel!);
var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel;
var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null;
var guild = await _rest.GetGuildOrNull(msg.Guild!.Value);
var guildMember = await _rest.GetGuildMember(msg.Guild!.Value, trigger.Author.Id);

View File

@ -63,11 +63,12 @@ public class LogChannelService
return null;
var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value;
var rootChannel = await _cache.GetRootChannel(trigger.ChannelId);
// get log channel info from the database
var guild = await _repo.GetGuild(guildId);
var logChannelId = guild.LogChannel;
var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId);
var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId || x == rootChannel.Id);
// if (ctx.SystemId == null ||
// removed the above, there shouldn't be a way to get to this code path if you don't have a system registered

View File

@ -1,4 +1,4 @@
create function message_context(account_id bigint, guild_id bigint, channel_id bigint)
create function message_context(account_id bigint, guild_id bigint, channel_id bigint, thread_id bigint)
returns table (
system_id int,
log_channel bigint,
@ -29,23 +29,25 @@ as $$
where accounts.uid = account_id),
guild as (select * from servers where id = guild_id)
select
system.id as system_id,
system.id as system_id,
guild.log_channel,
(channel_id = any (guild.blacklist)) as in_blacklist,
(channel_id = any (guild.log_blacklist)) as in_log_blacklist,
((channel_id = any (guild.blacklist))
or (thread_id = any (guild.blacklist))) as in_blacklist,
((channel_id = any (guild.log_blacklist))
or (thread_id = any (guild.log_blacklist))) as in_log_blacklist,
coalesce(guild.log_cleanup_enabled, false),
coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
system_last_switch.switch as last_switch,
system_last_switch.members as last_switch_members,
system_last_switch.timestamp as last_switch_timestamp,
system.tag as system_tag,
system.guild_tag as system_guild_tag,
coalesce(system.tag_enabled, true) as tag_enabled,
system.avatar_url as system_avatar,
system.account_autoproxy as allow_autoproxy,
system.latch_timeout as latch_timeout,
system.case_sensitive_proxy_tags as case_sensitive_proxy_tags,
system.proxy_error_message_enabled as proxy_error_message_enabled
coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
system_last_switch.switch as last_switch,
system_last_switch.members as last_switch_members,
system_last_switch.timestamp as last_switch_timestamp,
system.tag as system_tag,
system.guild_tag as system_guild_tag,
coalesce(system.tag_enabled, true) as tag_enabled,
system.avatar_url as system_avatar,
system.account_autoproxy as allow_autoproxy,
system.latch_timeout as latch_timeout,
system.case_sensitive_proxy_tags as case_sensitive_proxy_tags,
system.proxy_error_message_enabled as proxy_error_message_enabled
-- We need a "from" clause, so we just use some bogus data that's always present
-- This ensure we always have exactly one row going forward, so we can left join afterwards and still get data
from (select 1) as _placeholder

View File

@ -2,9 +2,9 @@ namespace PluralKit.Core;
public partial class ModelRepository
{
public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel)
public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel, ulong thread)
=> _db.QuerySingleProcedure<MessageContext>("message_context",
new { account_id = account, guild_id = guild, channel_id = channel });
new { account_id = account, guild_id = guild, channel_id = channel, thread_id = thread });
public Task<IEnumerable<ProxyMember>> GetProxyMembers(ulong account, ulong guild)
=> _db.QueryProcedure<ProxyMember>("proxy_members", new { account_id = account, guild_id = guild });