fix(bot): make proxy/log blacklists work properly with threads
Handling of both blacklists was inconsistent when dealing with threads: - proxy blacklist of root channel blacklists all threads - log blacklist of root channel _did not apply_ to threads - couldn't proxy blacklist threads while leaving root channel proxyable This change fixes the inconsistencies: - proxy _and_ log blacklist of root channel affects all threads - now able to individually proxy/log blacklist threads, with root channel unaffected
This commit is contained in:
parent
24f0fcd563
commit
66544b9d40
@ -17,6 +17,7 @@ public class Checks
|
||||
private readonly ProxyMatcher _matcher;
|
||||
private readonly ProxyService _proxy;
|
||||
private readonly DiscordApiClient _rest;
|
||||
private readonly IDiscordCache _cache;
|
||||
|
||||
private readonly PermissionSet[] requiredPermissions =
|
||||
{
|
||||
@ -26,12 +27,13 @@ public class Checks
|
||||
};
|
||||
|
||||
// todo: make sure everything uses the minimum amount of REST calls necessary
|
||||
public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher)
|
||||
public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher, IDiscordCache cache)
|
||||
{
|
||||
_rest = rest;
|
||||
_botConfig = botConfig;
|
||||
_proxy = proxy;
|
||||
_matcher = matcher;
|
||||
_cache = cache;
|
||||
}
|
||||
|
||||
public async Task PermCheckGuild(Context ctx)
|
||||
@ -230,11 +232,12 @@ public class Checks
|
||||
if (channel == null)
|
||||
throw new PKError("Unable to get the channel associated with this message.");
|
||||
|
||||
var rootChannel = await _cache.GetRootChannel(channel.Id);
|
||||
if (channel.GuildId == null)
|
||||
throw new PKError("PluralKit is not able to proxy messages in DMs.");
|
||||
|
||||
// using channel.GuildId here since _rest.GetMessage() doesn't return the GuildId
|
||||
var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, msg.ChannelId);
|
||||
var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, rootChannel.Id, msg.ChannelId);
|
||||
var members = (await ctx.Repository.GetProxyMembers(msg.Author.Id, channel.GuildId.Value)).ToList();
|
||||
|
||||
// for now this is just server
|
||||
|
@ -159,7 +159,7 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
|
||||
// Get message context from DB (tracking w/ metrics)
|
||||
MessageContext ctx;
|
||||
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
|
||||
ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel);
|
||||
ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel, channel.Id != rootChannel ? channel.Id : default);
|
||||
|
||||
try
|
||||
{
|
||||
|
@ -55,6 +55,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
|
||||
var channel = await _cache.GetChannel(evt.ChannelId);
|
||||
if (!DiscordUtils.IsValidGuildChannel(channel))
|
||||
return;
|
||||
var rootChannel = await _cache.GetRootChannel(channel.Id);
|
||||
var guild = await _cache.GetGuild(channel.GuildId!.Value);
|
||||
var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current;
|
||||
|
||||
@ -65,7 +66,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
|
||||
// Just run the normal message handling code, with a flag to disable autoproxying
|
||||
MessageContext ctx;
|
||||
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
|
||||
ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, evt.ChannelId);
|
||||
ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, rootChannel.Id, evt.ChannelId);
|
||||
|
||||
var equivalentEvt = await GetMessageCreateEvent(evt, lastMessage, channel);
|
||||
var botPermissions = await _cache.PermissionsIn(channel.Id);
|
||||
|
@ -229,9 +229,12 @@ public class ProxyService
|
||||
if (originalMsg == null)
|
||||
throw new PKError("Could not reproxy message.");
|
||||
|
||||
var messageChannel = await _rest.GetChannelOrNull(msg.Channel!);
|
||||
var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel;
|
||||
|
||||
// Get a MessageContext for the original message
|
||||
MessageContext ctx =
|
||||
await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, msg.Channel);
|
||||
await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, rootChannel.Id, msg.Channel);
|
||||
|
||||
// Make sure proxying is enabled here
|
||||
if (ctx.InBlacklist)
|
||||
@ -250,8 +253,6 @@ public class ProxyService
|
||||
ProxyTags = member.ProxyTags.FirstOrDefault(),
|
||||
};
|
||||
|
||||
var messageChannel = await _rest.GetChannelOrNull(msg.Channel!);
|
||||
var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel;
|
||||
var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null;
|
||||
var guild = await _rest.GetGuildOrNull(msg.Guild!.Value);
|
||||
var guildMember = await _rest.GetGuildMember(msg.Guild!.Value, trigger.Author.Id);
|
||||
|
@ -63,11 +63,12 @@ public class LogChannelService
|
||||
return null;
|
||||
|
||||
var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value;
|
||||
var rootChannel = await _cache.GetRootChannel(trigger.ChannelId);
|
||||
|
||||
// get log channel info from the database
|
||||
var guild = await _repo.GetGuild(guildId);
|
||||
var logChannelId = guild.LogChannel;
|
||||
var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId);
|
||||
var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId || x == rootChannel.Id);
|
||||
|
||||
// if (ctx.SystemId == null ||
|
||||
// removed the above, there shouldn't be a way to get to this code path if you don't have a system registered
|
||||
|
@ -1,4 +1,4 @@
|
||||
create function message_context(account_id bigint, guild_id bigint, channel_id bigint)
|
||||
create function message_context(account_id bigint, guild_id bigint, channel_id bigint, thread_id bigint)
|
||||
returns table (
|
||||
system_id int,
|
||||
log_channel bigint,
|
||||
@ -29,23 +29,25 @@ as $$
|
||||
where accounts.uid = account_id),
|
||||
guild as (select * from servers where id = guild_id)
|
||||
select
|
||||
system.id as system_id,
|
||||
system.id as system_id,
|
||||
guild.log_channel,
|
||||
(channel_id = any (guild.blacklist)) as in_blacklist,
|
||||
(channel_id = any (guild.log_blacklist)) as in_log_blacklist,
|
||||
((channel_id = any (guild.blacklist))
|
||||
or (thread_id = any (guild.blacklist))) as in_blacklist,
|
||||
((channel_id = any (guild.log_blacklist))
|
||||
or (thread_id = any (guild.log_blacklist))) as in_log_blacklist,
|
||||
coalesce(guild.log_cleanup_enabled, false),
|
||||
coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
|
||||
system_last_switch.switch as last_switch,
|
||||
system_last_switch.members as last_switch_members,
|
||||
system_last_switch.timestamp as last_switch_timestamp,
|
||||
system.tag as system_tag,
|
||||
system.guild_tag as system_guild_tag,
|
||||
coalesce(system.tag_enabled, true) as tag_enabled,
|
||||
system.avatar_url as system_avatar,
|
||||
system.account_autoproxy as allow_autoproxy,
|
||||
system.latch_timeout as latch_timeout,
|
||||
system.case_sensitive_proxy_tags as case_sensitive_proxy_tags,
|
||||
system.proxy_error_message_enabled as proxy_error_message_enabled
|
||||
coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
|
||||
system_last_switch.switch as last_switch,
|
||||
system_last_switch.members as last_switch_members,
|
||||
system_last_switch.timestamp as last_switch_timestamp,
|
||||
system.tag as system_tag,
|
||||
system.guild_tag as system_guild_tag,
|
||||
coalesce(system.tag_enabled, true) as tag_enabled,
|
||||
system.avatar_url as system_avatar,
|
||||
system.account_autoproxy as allow_autoproxy,
|
||||
system.latch_timeout as latch_timeout,
|
||||
system.case_sensitive_proxy_tags as case_sensitive_proxy_tags,
|
||||
system.proxy_error_message_enabled as proxy_error_message_enabled
|
||||
-- We need a "from" clause, so we just use some bogus data that's always present
|
||||
-- This ensure we always have exactly one row going forward, so we can left join afterwards and still get data
|
||||
from (select 1) as _placeholder
|
||||
|
@ -2,9 +2,9 @@ namespace PluralKit.Core;
|
||||
|
||||
public partial class ModelRepository
|
||||
{
|
||||
public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel)
|
||||
public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel, ulong thread)
|
||||
=> _db.QuerySingleProcedure<MessageContext>("message_context",
|
||||
new { account_id = account, guild_id = guild, channel_id = channel });
|
||||
new { account_id = account, guild_id = guild, channel_id = channel, thread_id = thread });
|
||||
|
||||
public Task<IEnumerable<ProxyMember>> GetProxyMembers(ulong account, ulong guild)
|
||||
=> _db.QueryProcedure<ProxyMember>("proxy_members", new { account_id = account, guild_id = guild });
|
||||
|
Loading…
Reference in New Issue
Block a user