fix(bot): make proxy/log blacklists work properly with threads

Handling of both blacklists was inconsistent when dealing with threads:
- proxy blacklist of root channel blacklists all threads
- log blacklist of root channel _did not apply_ to threads
- couldn't proxy blacklist threads while leaving root channel proxyable

This change fixes the inconsistencies:
- proxy _and_ log blacklist of root channel affects all threads
- now able to individually proxy/log blacklist threads, with root
  channel unaffected
This commit is contained in:
Iris System 2023-05-10 13:16:16 +12:00
parent 24f0fcd563
commit 66544b9d40
7 changed files with 34 additions and 26 deletions

View File

@ -17,6 +17,7 @@ public class Checks
private readonly ProxyMatcher _matcher; private readonly ProxyMatcher _matcher;
private readonly ProxyService _proxy; private readonly ProxyService _proxy;
private readonly DiscordApiClient _rest; private readonly DiscordApiClient _rest;
private readonly IDiscordCache _cache;
private readonly PermissionSet[] requiredPermissions = private readonly PermissionSet[] requiredPermissions =
{ {
@ -26,12 +27,13 @@ public class Checks
}; };
// todo: make sure everything uses the minimum amount of REST calls necessary // todo: make sure everything uses the minimum amount of REST calls necessary
public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher) public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher, IDiscordCache cache)
{ {
_rest = rest; _rest = rest;
_botConfig = botConfig; _botConfig = botConfig;
_proxy = proxy; _proxy = proxy;
_matcher = matcher; _matcher = matcher;
_cache = cache;
} }
public async Task PermCheckGuild(Context ctx) public async Task PermCheckGuild(Context ctx)
@ -230,11 +232,12 @@ public class Checks
if (channel == null) if (channel == null)
throw new PKError("Unable to get the channel associated with this message."); throw new PKError("Unable to get the channel associated with this message.");
var rootChannel = await _cache.GetRootChannel(channel.Id);
if (channel.GuildId == null) if (channel.GuildId == null)
throw new PKError("PluralKit is not able to proxy messages in DMs."); throw new PKError("PluralKit is not able to proxy messages in DMs.");
// using channel.GuildId here since _rest.GetMessage() doesn't return the GuildId // using channel.GuildId here since _rest.GetMessage() doesn't return the GuildId
var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, msg.ChannelId); var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, rootChannel.Id, msg.ChannelId);
var members = (await ctx.Repository.GetProxyMembers(msg.Author.Id, channel.GuildId.Value)).ToList(); var members = (await ctx.Repository.GetProxyMembers(msg.Author.Id, channel.GuildId.Value)).ToList();
// for now this is just server // for now this is just server

View File

@ -159,7 +159,7 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
// Get message context from DB (tracking w/ metrics) // Get message context from DB (tracking w/ metrics)
MessageContext ctx; MessageContext ctx;
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel); ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel, channel.Id != rootChannel ? channel.Id : default);
try try
{ {

View File

@ -55,6 +55,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
var channel = await _cache.GetChannel(evt.ChannelId); var channel = await _cache.GetChannel(evt.ChannelId);
if (!DiscordUtils.IsValidGuildChannel(channel)) if (!DiscordUtils.IsValidGuildChannel(channel))
return; return;
var rootChannel = await _cache.GetRootChannel(channel.Id);
var guild = await _cache.GetGuild(channel.GuildId!.Value); var guild = await _cache.GetGuild(channel.GuildId!.Value);
var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current; var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current;
@ -65,7 +66,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
// Just run the normal message handling code, with a flag to disable autoproxying // Just run the normal message handling code, with a flag to disable autoproxying
MessageContext ctx; MessageContext ctx;
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, evt.ChannelId); ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, rootChannel.Id, evt.ChannelId);
var equivalentEvt = await GetMessageCreateEvent(evt, lastMessage, channel); var equivalentEvt = await GetMessageCreateEvent(evt, lastMessage, channel);
var botPermissions = await _cache.PermissionsIn(channel.Id); var botPermissions = await _cache.PermissionsIn(channel.Id);

View File

@ -229,9 +229,12 @@ public class ProxyService
if (originalMsg == null) if (originalMsg == null)
throw new PKError("Could not reproxy message."); throw new PKError("Could not reproxy message.");
var messageChannel = await _rest.GetChannelOrNull(msg.Channel!);
var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel;
// Get a MessageContext for the original message // Get a MessageContext for the original message
MessageContext ctx = MessageContext ctx =
await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, msg.Channel); await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, rootChannel.Id, msg.Channel);
// Make sure proxying is enabled here // Make sure proxying is enabled here
if (ctx.InBlacklist) if (ctx.InBlacklist)
@ -250,8 +253,6 @@ public class ProxyService
ProxyTags = member.ProxyTags.FirstOrDefault(), ProxyTags = member.ProxyTags.FirstOrDefault(),
}; };
var messageChannel = await _rest.GetChannelOrNull(msg.Channel!);
var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel;
var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null; var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null;
var guild = await _rest.GetGuildOrNull(msg.Guild!.Value); var guild = await _rest.GetGuildOrNull(msg.Guild!.Value);
var guildMember = await _rest.GetGuildMember(msg.Guild!.Value, trigger.Author.Id); var guildMember = await _rest.GetGuildMember(msg.Guild!.Value, trigger.Author.Id);

View File

@ -63,11 +63,12 @@ public class LogChannelService
return null; return null;
var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value; var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value;
var rootChannel = await _cache.GetRootChannel(trigger.ChannelId);
// get log channel info from the database // get log channel info from the database
var guild = await _repo.GetGuild(guildId); var guild = await _repo.GetGuild(guildId);
var logChannelId = guild.LogChannel; var logChannelId = guild.LogChannel;
var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId); var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId || x == rootChannel.Id);
// if (ctx.SystemId == null || // if (ctx.SystemId == null ||
// removed the above, there shouldn't be a way to get to this code path if you don't have a system registered // removed the above, there shouldn't be a way to get to this code path if you don't have a system registered

View File

@ -1,4 +1,4 @@
create function message_context(account_id bigint, guild_id bigint, channel_id bigint) create function message_context(account_id bigint, guild_id bigint, channel_id bigint, thread_id bigint)
returns table ( returns table (
system_id int, system_id int,
log_channel bigint, log_channel bigint,
@ -31,8 +31,10 @@ as $$
select select
system.id as system_id, system.id as system_id,
guild.log_channel, guild.log_channel,
(channel_id = any (guild.blacklist)) as in_blacklist, ((channel_id = any (guild.blacklist))
(channel_id = any (guild.log_blacklist)) as in_log_blacklist, or (thread_id = any (guild.blacklist))) as in_blacklist,
((channel_id = any (guild.log_blacklist))
or (thread_id = any (guild.log_blacklist))) as in_log_blacklist,
coalesce(guild.log_cleanup_enabled, false), coalesce(guild.log_cleanup_enabled, false),
coalesce(system_guild.proxy_enabled, true) as proxy_enabled, coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
system_last_switch.switch as last_switch, system_last_switch.switch as last_switch,

View File

@ -2,9 +2,9 @@ namespace PluralKit.Core;
public partial class ModelRepository public partial class ModelRepository
{ {
public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel) public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel, ulong thread)
=> _db.QuerySingleProcedure<MessageContext>("message_context", => _db.QuerySingleProcedure<MessageContext>("message_context",
new { account_id = account, guild_id = guild, channel_id = channel }); new { account_id = account, guild_id = guild, channel_id = channel, thread_id = thread });
public Task<IEnumerable<ProxyMember>> GetProxyMembers(ulong account, ulong guild) public Task<IEnumerable<ProxyMember>> GetProxyMembers(ulong account, ulong guild)
=> _db.QueryProcedure<ProxyMember>("proxy_members", new { account_id = account, guild_id = guild }); => _db.QueryProcedure<ProxyMember>("proxy_members", new { account_id = account, guild_id = guild });