fix(bot): make proxy/log blacklists work properly with threads
Handling of both blacklists was inconsistent when dealing with threads: - proxy blacklist of root channel blacklists all threads - log blacklist of root channel _did not apply_ to threads - couldn't proxy blacklist threads while leaving root channel proxyable This change fixes the inconsistencies: - proxy _and_ log blacklist of root channel affects all threads - now able to individually proxy/log blacklist threads, with root channel unaffected
This commit is contained in:
parent
24f0fcd563
commit
66544b9d40
@ -17,6 +17,7 @@ public class Checks
|
|||||||
private readonly ProxyMatcher _matcher;
|
private readonly ProxyMatcher _matcher;
|
||||||
private readonly ProxyService _proxy;
|
private readonly ProxyService _proxy;
|
||||||
private readonly DiscordApiClient _rest;
|
private readonly DiscordApiClient _rest;
|
||||||
|
private readonly IDiscordCache _cache;
|
||||||
|
|
||||||
private readonly PermissionSet[] requiredPermissions =
|
private readonly PermissionSet[] requiredPermissions =
|
||||||
{
|
{
|
||||||
@ -26,12 +27,13 @@ public class Checks
|
|||||||
};
|
};
|
||||||
|
|
||||||
// todo: make sure everything uses the minimum amount of REST calls necessary
|
// todo: make sure everything uses the minimum amount of REST calls necessary
|
||||||
public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher)
|
public Checks(DiscordApiClient rest, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher, IDiscordCache cache)
|
||||||
{
|
{
|
||||||
_rest = rest;
|
_rest = rest;
|
||||||
_botConfig = botConfig;
|
_botConfig = botConfig;
|
||||||
_proxy = proxy;
|
_proxy = proxy;
|
||||||
_matcher = matcher;
|
_matcher = matcher;
|
||||||
|
_cache = cache;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task PermCheckGuild(Context ctx)
|
public async Task PermCheckGuild(Context ctx)
|
||||||
@ -230,11 +232,12 @@ public class Checks
|
|||||||
if (channel == null)
|
if (channel == null)
|
||||||
throw new PKError("Unable to get the channel associated with this message.");
|
throw new PKError("Unable to get the channel associated with this message.");
|
||||||
|
|
||||||
|
var rootChannel = await _cache.GetRootChannel(channel.Id);
|
||||||
if (channel.GuildId == null)
|
if (channel.GuildId == null)
|
||||||
throw new PKError("PluralKit is not able to proxy messages in DMs.");
|
throw new PKError("PluralKit is not able to proxy messages in DMs.");
|
||||||
|
|
||||||
// using channel.GuildId here since _rest.GetMessage() doesn't return the GuildId
|
// using channel.GuildId here since _rest.GetMessage() doesn't return the GuildId
|
||||||
var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, msg.ChannelId);
|
var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, rootChannel.Id, msg.ChannelId);
|
||||||
var members = (await ctx.Repository.GetProxyMembers(msg.Author.Id, channel.GuildId.Value)).ToList();
|
var members = (await ctx.Repository.GetProxyMembers(msg.Author.Id, channel.GuildId.Value)).ToList();
|
||||||
|
|
||||||
// for now this is just server
|
// for now this is just server
|
||||||
|
@ -159,7 +159,7 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
|
|||||||
// Get message context from DB (tracking w/ metrics)
|
// Get message context from DB (tracking w/ metrics)
|
||||||
MessageContext ctx;
|
MessageContext ctx;
|
||||||
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
|
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
|
||||||
ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel);
|
ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel, channel.Id != rootChannel ? channel.Id : default);
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
|
@ -55,6 +55,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
|
|||||||
var channel = await _cache.GetChannel(evt.ChannelId);
|
var channel = await _cache.GetChannel(evt.ChannelId);
|
||||||
if (!DiscordUtils.IsValidGuildChannel(channel))
|
if (!DiscordUtils.IsValidGuildChannel(channel))
|
||||||
return;
|
return;
|
||||||
|
var rootChannel = await _cache.GetRootChannel(channel.Id);
|
||||||
var guild = await _cache.GetGuild(channel.GuildId!.Value);
|
var guild = await _cache.GetGuild(channel.GuildId!.Value);
|
||||||
var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current;
|
var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current;
|
||||||
|
|
||||||
@ -65,7 +66,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
|
|||||||
// Just run the normal message handling code, with a flag to disable autoproxying
|
// Just run the normal message handling code, with a flag to disable autoproxying
|
||||||
MessageContext ctx;
|
MessageContext ctx;
|
||||||
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
|
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
|
||||||
ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, evt.ChannelId);
|
ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, rootChannel.Id, evt.ChannelId);
|
||||||
|
|
||||||
var equivalentEvt = await GetMessageCreateEvent(evt, lastMessage, channel);
|
var equivalentEvt = await GetMessageCreateEvent(evt, lastMessage, channel);
|
||||||
var botPermissions = await _cache.PermissionsIn(channel.Id);
|
var botPermissions = await _cache.PermissionsIn(channel.Id);
|
||||||
|
@ -229,9 +229,12 @@ public class ProxyService
|
|||||||
if (originalMsg == null)
|
if (originalMsg == null)
|
||||||
throw new PKError("Could not reproxy message.");
|
throw new PKError("Could not reproxy message.");
|
||||||
|
|
||||||
|
var messageChannel = await _rest.GetChannelOrNull(msg.Channel!);
|
||||||
|
var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel;
|
||||||
|
|
||||||
// Get a MessageContext for the original message
|
// Get a MessageContext for the original message
|
||||||
MessageContext ctx =
|
MessageContext ctx =
|
||||||
await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, msg.Channel);
|
await _repo.GetMessageContext(msg.Sender, msg.Guild!.Value, rootChannel.Id, msg.Channel);
|
||||||
|
|
||||||
// Make sure proxying is enabled here
|
// Make sure proxying is enabled here
|
||||||
if (ctx.InBlacklist)
|
if (ctx.InBlacklist)
|
||||||
@ -250,8 +253,6 @@ public class ProxyService
|
|||||||
ProxyTags = member.ProxyTags.FirstOrDefault(),
|
ProxyTags = member.ProxyTags.FirstOrDefault(),
|
||||||
};
|
};
|
||||||
|
|
||||||
var messageChannel = await _rest.GetChannelOrNull(msg.Channel!);
|
|
||||||
var rootChannel = messageChannel.IsThread() ? await _rest.GetChannelOrNull(messageChannel.ParentId!.Value) : messageChannel;
|
|
||||||
var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null;
|
var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null;
|
||||||
var guild = await _rest.GetGuildOrNull(msg.Guild!.Value);
|
var guild = await _rest.GetGuildOrNull(msg.Guild!.Value);
|
||||||
var guildMember = await _rest.GetGuildMember(msg.Guild!.Value, trigger.Author.Id);
|
var guildMember = await _rest.GetGuildMember(msg.Guild!.Value, trigger.Author.Id);
|
||||||
|
@ -63,11 +63,12 @@ public class LogChannelService
|
|||||||
return null;
|
return null;
|
||||||
|
|
||||||
var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value;
|
var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value;
|
||||||
|
var rootChannel = await _cache.GetRootChannel(trigger.ChannelId);
|
||||||
|
|
||||||
// get log channel info from the database
|
// get log channel info from the database
|
||||||
var guild = await _repo.GetGuild(guildId);
|
var guild = await _repo.GetGuild(guildId);
|
||||||
var logChannelId = guild.LogChannel;
|
var logChannelId = guild.LogChannel;
|
||||||
var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId);
|
var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId || x == rootChannel.Id);
|
||||||
|
|
||||||
// if (ctx.SystemId == null ||
|
// if (ctx.SystemId == null ||
|
||||||
// removed the above, there shouldn't be a way to get to this code path if you don't have a system registered
|
// removed the above, there shouldn't be a way to get to this code path if you don't have a system registered
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
create function message_context(account_id bigint, guild_id bigint, channel_id bigint)
|
create function message_context(account_id bigint, guild_id bigint, channel_id bigint, thread_id bigint)
|
||||||
returns table (
|
returns table (
|
||||||
system_id int,
|
system_id int,
|
||||||
log_channel bigint,
|
log_channel bigint,
|
||||||
@ -29,23 +29,25 @@ as $$
|
|||||||
where accounts.uid = account_id),
|
where accounts.uid = account_id),
|
||||||
guild as (select * from servers where id = guild_id)
|
guild as (select * from servers where id = guild_id)
|
||||||
select
|
select
|
||||||
system.id as system_id,
|
system.id as system_id,
|
||||||
guild.log_channel,
|
guild.log_channel,
|
||||||
(channel_id = any (guild.blacklist)) as in_blacklist,
|
((channel_id = any (guild.blacklist))
|
||||||
(channel_id = any (guild.log_blacklist)) as in_log_blacklist,
|
or (thread_id = any (guild.blacklist))) as in_blacklist,
|
||||||
|
((channel_id = any (guild.log_blacklist))
|
||||||
|
or (thread_id = any (guild.log_blacklist))) as in_log_blacklist,
|
||||||
coalesce(guild.log_cleanup_enabled, false),
|
coalesce(guild.log_cleanup_enabled, false),
|
||||||
coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
|
coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
|
||||||
system_last_switch.switch as last_switch,
|
system_last_switch.switch as last_switch,
|
||||||
system_last_switch.members as last_switch_members,
|
system_last_switch.members as last_switch_members,
|
||||||
system_last_switch.timestamp as last_switch_timestamp,
|
system_last_switch.timestamp as last_switch_timestamp,
|
||||||
system.tag as system_tag,
|
system.tag as system_tag,
|
||||||
system.guild_tag as system_guild_tag,
|
system.guild_tag as system_guild_tag,
|
||||||
coalesce(system.tag_enabled, true) as tag_enabled,
|
coalesce(system.tag_enabled, true) as tag_enabled,
|
||||||
system.avatar_url as system_avatar,
|
system.avatar_url as system_avatar,
|
||||||
system.account_autoproxy as allow_autoproxy,
|
system.account_autoproxy as allow_autoproxy,
|
||||||
system.latch_timeout as latch_timeout,
|
system.latch_timeout as latch_timeout,
|
||||||
system.case_sensitive_proxy_tags as case_sensitive_proxy_tags,
|
system.case_sensitive_proxy_tags as case_sensitive_proxy_tags,
|
||||||
system.proxy_error_message_enabled as proxy_error_message_enabled
|
system.proxy_error_message_enabled as proxy_error_message_enabled
|
||||||
-- We need a "from" clause, so we just use some bogus data that's always present
|
-- We need a "from" clause, so we just use some bogus data that's always present
|
||||||
-- This ensure we always have exactly one row going forward, so we can left join afterwards and still get data
|
-- This ensure we always have exactly one row going forward, so we can left join afterwards and still get data
|
||||||
from (select 1) as _placeholder
|
from (select 1) as _placeholder
|
||||||
|
@ -2,9 +2,9 @@ namespace PluralKit.Core;
|
|||||||
|
|
||||||
public partial class ModelRepository
|
public partial class ModelRepository
|
||||||
{
|
{
|
||||||
public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel)
|
public Task<MessageContext> GetMessageContext(ulong account, ulong guild, ulong channel, ulong thread)
|
||||||
=> _db.QuerySingleProcedure<MessageContext>("message_context",
|
=> _db.QuerySingleProcedure<MessageContext>("message_context",
|
||||||
new { account_id = account, guild_id = guild, channel_id = channel });
|
new { account_id = account, guild_id = guild, channel_id = channel, thread_id = thread });
|
||||||
|
|
||||||
public Task<IEnumerable<ProxyMember>> GetProxyMembers(ulong account, ulong guild)
|
public Task<IEnumerable<ProxyMember>> GetProxyMembers(ulong account, ulong guild)
|
||||||
=> _db.QueryProcedure<ProxyMember>("proxy_members", new { account_id = account, guild_id = guild });
|
=> _db.QueryProcedure<ProxyMember>("proxy_members", new { account_id = account, guild_id = guild });
|
||||||
|
Loading…
Reference in New Issue
Block a user