feat: fetch from REST instead of cache for cross-cluster lookups

This commit is contained in:
spiral 2022-03-09 20:06:53 -05:00
parent d0ad7abb03
commit ae543b9c18
No known key found for this signature in database
GPG Key ID: 244A11E4B0BCF40E
7 changed files with 80 additions and 19 deletions

View File

@ -59,14 +59,14 @@ public static class PermissionExtensions
} }
public static PermissionSet EveryonePermissions(this Guild guild) => public static PermissionSet EveryonePermissions(this Guild guild) =>
guild.Roles.FirstOrDefault(r => r.Id == guild.Id)?.Permissions ?? PermissionSet.Dm; guild.Roles.FirstOrDefault(r => r.Id == guild.Id)!.Permissions;
public static async Task<PermissionSet> EveryonePermissions(this IDiscordCache cache, Channel channel) public static PermissionSet EveryonePermissions(Guild guild, Channel channel)
{ {
if (channel.Type == Channel.ChannelType.Dm) if (channel.Type == Channel.ChannelType.Dm)
return PermissionSet.Dm; return PermissionSet.Dm;
var defaultPermissions = (await cache.GetGuild(channel.GuildId!.Value)).EveryonePermissions(); var defaultPermissions = guild.EveryonePermissions();
var overwrite = channel.PermissionOverwrites?.FirstOrDefault(r => r.Id == channel.GuildId); var overwrite = channel.PermissionOverwrites?.FirstOrDefault(r => r.Id == channel.GuildId);
if (overwrite == null) if (overwrite == null)
return defaultPermissions; return defaultPermissions;

View File

@ -81,7 +81,10 @@ public static class ContextChecksExt
public static async Task<bool> CheckPermissionsInGuildChannel(this Context ctx, Channel channel, public static async Task<bool> CheckPermissionsInGuildChannel(this Context ctx, Channel channel,
PermissionSet neededPerms) PermissionSet neededPerms)
{ {
var guild = await ctx.Cache.GetGuild(channel.GuildId.Value); // this is a quick hack, should probably do it properly eventually
var guild = await ctx.Cache.TryGetGuild(channel.GuildId.Value);
if (guild == null)
await ctx.Rest.GetGuild(channel.GuildId.Value);
if (guild == null) if (guild == null)
return false; return false;

View File

@ -151,7 +151,10 @@ public static class ContextEntityArgumentsExt
if (!MentionUtils.TryParseChannel(ctx.PeekArgument(), out var id)) if (!MentionUtils.TryParseChannel(ctx.PeekArgument(), out var id))
return null; return null;
if (!(await ctx.Cache.TryGetChannel(id) is Channel channel)) var channel = await ctx.Cache.TryGetChannel(id);
if (channel == null)
channel = await ctx.Rest.GetChannelOrNull(id);
if (channel == null)
return null; return null;
if (!DiscordUtils.IsValidGuildChannel(channel)) if (!DiscordUtils.IsValidGuildChannel(channel))

View File

@ -14,6 +14,7 @@ namespace PluralKit.Bot;
public class Checks public class Checks
{ {
private readonly BotConfig _botConfig; private readonly BotConfig _botConfig;
// this must ONLY be used to get the bot's user ID
private readonly IDiscordCache _cache; private readonly IDiscordCache _cache;
private readonly ProxyMatcher _matcher; private readonly ProxyMatcher _matcher;
private readonly ProxyService _proxy; private readonly ProxyService _proxy;
@ -26,6 +27,7 @@ public class Checks
PermissionSet.ManageWebhooks, PermissionSet.ReadMessageHistory PermissionSet.ManageWebhooks, PermissionSet.ReadMessageHistory
}; };
// todo: make sure everything uses the minimum amount of REST calls necessary
public Checks(DiscordApiClient rest, IDiscordCache cache, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher) public Checks(DiscordApiClient rest, IDiscordCache cache, BotConfig botConfig, ProxyService proxy, ProxyMatcher matcher)
{ {
_rest = rest; _rest = rest;
@ -67,16 +69,17 @@ public class Checks
throw Errors.GuildNotFound(guildId); throw Errors.GuildNotFound(guildId);
} }
var guildMember = await _rest.GetGuildMember(guild.Id, await _cache.GetOwnUser());
// Loop through every channel and group them by sets of permissions missing // Loop through every channel and group them by sets of permissions missing
var permissionsMissing = new Dictionary<ulong, List<Channel>>(); var permissionsMissing = new Dictionary<ulong, List<Channel>>();
var hiddenChannels = false; var hiddenChannels = false;
var missingEmojiPermissions = false; var missingEmojiPermissions = false;
foreach (var channel in await _rest.GetGuildChannels(guild.Id)) foreach (var channel in await _rest.GetGuildChannels(guild.Id))
{ {
var botPermissions = await _cache.PermissionsIn(channel.Id); var botPermissions = PermissionExtensions.PermissionsFor(guild, channel, await _cache.GetOwnUser(), guildMember);
var webhookPermissions = await _cache.EveryonePermissions(channel); var webhookPermissions = PermissionExtensions.EveryonePermissions(guild, channel);
var userPermissions = var userPermissions = PermissionExtensions.PermissionsFor(guild, channel, ctx.Author.Id, senderGuildUser);
PermissionExtensions.PermissionsFor(guild, channel, ctx.Author.Id, senderGuildUser);
if ((userPermissions & PermissionSet.ViewChannel) == 0) if ((userPermissions & PermissionSet.ViewChannel) == 0)
{ {
@ -153,15 +156,23 @@ public class Checks
throw new PKSyntaxError("You need to specify a channel."); throw new PKSyntaxError("You need to specify a channel.");
var error = "Channel not found or you do not have permissions to access it."; var error = "Channel not found or you do not have permissions to access it.";
// todo: this breaks if channel is not in cache and bot does not have View Channel permissions
var channel = await ctx.MatchChannel(); var channel = await ctx.MatchChannel();
if (channel == null || channel.GuildId == null) if (channel == null || channel.GuildId == null)
throw new PKError(error); throw new PKError(error);
var guild = await _rest.GetGuildOrNull(channel.GuildId.Value);
if (guild == null)
throw new PKError(error);
var guildMember = await _rest.GetGuildMember(channel.GuildId.Value, await _cache.GetOwnUser());
if (!await ctx.CheckPermissionsInGuildChannel(channel, PermissionSet.ViewChannel)) if (!await ctx.CheckPermissionsInGuildChannel(channel, PermissionSet.ViewChannel))
throw new PKError(error); throw new PKError(error);
var botPermissions = await _cache.PermissionsIn(channel.Id); var botPermissions = PermissionExtensions.PermissionsFor(guild, channel, await _cache.GetOwnUser(), guildMember);
var webhookPermissions = await _cache.EveryonePermissions(channel); var webhookPermissions = PermissionExtensions.EveryonePermissions(guild, channel);
// We use a bitfield so we can set individual permission bits // We use a bitfield so we can set individual permission bits
ulong missingPermissions = 0; ulong missingPermissions = 0;
@ -240,7 +251,7 @@ public class Checks
throw new PKError("You can only check your own messages."); throw new PKError("You can only check your own messages.");
// get the channel info // get the channel info
var channel = await _cache.GetChannel(channelId.Value); var channel = await _rest.GetChannelOrNull(channelId.Value);
if (channel == null) if (channel == null)
throw new PKError("Unable to get the channel associated with this message."); throw new PKError("Unable to get the channel associated with this message.");

View File

@ -20,7 +20,7 @@ namespace PluralKit.Bot;
public class ProxiedMessage public class ProxiedMessage
{ {
private static readonly Duration EditTimeout = Duration.FromMinutes(10); private static readonly Duration EditTimeout = Duration.FromMinutes(10);
private readonly IDiscordCache _cache; // private readonly IDiscordCache _cache;
private readonly IClock _clock; private readonly IClock _clock;
private readonly EmbedService _embeds; private readonly EmbedService _embeds;
@ -37,7 +37,7 @@ public class ProxiedMessage
_rest = rest; _rest = rest;
_webhookExecutor = webhookExecutor; _webhookExecutor = webhookExecutor;
_logChannel = logChannel; _logChannel = logChannel;
_cache = cache; // _cache = cache;
} }
public async Task EditMessage(Context ctx) public async Task EditMessage(Context ctx)
@ -112,7 +112,7 @@ public class ProxiedMessage
var error = var error =
"The channel where the message was sent does not exist anymore, or you are missing permissions to access it."; "The channel where the message was sent does not exist anymore, or you are missing permissions to access it.";
var channel = await _cache.GetChannel(msg.Message.Channel); var channel = await _rest.GetChannelOrNull(msg.Message.Channel);
if (channel == null) if (channel == null)
throw new PKError(error); throw new PKError(error);
@ -165,7 +165,7 @@ public class ProxiedMessage
var showContent = true; var showContent = true;
var noShowContentError = "Message deleted or inaccessible."; var noShowContentError = "Message deleted or inaccessible.";
var channel = await _cache.GetChannel(message.Message.Channel); var channel = await _rest.GetChannelOrNull(message.Message.Channel);
if (channel == null) if (channel == null)
showContent = false; showContent = false;
else if (!await ctx.CheckPermissionsInGuildChannel(channel, PermissionSet.ViewChannel)) else if (!await ctx.CheckPermissionsInGuildChannel(channel, PermissionSet.ViewChannel))
@ -222,7 +222,7 @@ public class ProxiedMessage
if (ctx.Match("author") || ctx.MatchFlag("author")) if (ctx.Match("author") || ctx.MatchFlag("author"))
{ {
var user = await _cache.GetOrFetchUser(_rest, message.Message.Sender); var user = await _rest.GetUser(message.Message.Sender);
var eb = new EmbedBuilder() var eb = new EmbedBuilder()
.Author(new Embed.EmbedAuthor( .Author(new Embed.EmbedAuthor(
user != null user != null

View File

@ -81,24 +81,42 @@ public class LogChannelService
if (logChannel == null || logChannel.Type != Channel.ChannelType.GuildText) return null; if (logChannel == null || logChannel.Type != Channel.ChannelType.GuildText) return null;
// Check bot permissions // Check bot permissions
var perms = await _cache.PermissionsIn(logChannel.Id); var perms = await GetPermissionsInLogChannel(logChannel);
if (!perms.HasFlag(PermissionSet.SendMessages | PermissionSet.EmbedLinks)) if (!perms.HasFlag(PermissionSet.SendMessages | PermissionSet.EmbedLinks))
{ {
_logger.Information( _logger.Information(
"Does not have permission to log proxy, ignoring (channel: {ChannelId}, guild: {GuildId}, bot permissions: {BotPermissions})", "Does not have permission to log proxy, ignoring (channel: {ChannelId}, guild: {GuildId}, bot permissions: {BotPermissions})",
ctx.LogChannel.Value, trigger.GuildId!.Value, perms); logChannel.Id, trigger.GuildId!.Value, perms);
return null; return null;
} }
return logChannel.Id; return logChannel.Id;
} }
// todo: move this somewhere else
private async Task<PermissionSet> GetPermissionsInLogChannel(Channel channel)
{
var guild = await _cache.TryGetGuild(channel.GuildId.Value);
if (guild == null)
guild = await _rest.GetGuild(channel.GuildId.Value);
var guildMember = await _cache.TryGetSelfMember(channel.GuildId.Value);
if (guildMember == null)
guildMember = await _rest.GetGuildMember(channel.GuildId.Value, await _cache.GetOwnUser());
var perms = PermissionExtensions.PermissionsFor(guild, channel, await _cache.GetOwnUser(), guildMember);
return perms;
}
private async Task<Channel?> FindLogChannel(ulong guildId, ulong channelId) private async Task<Channel?> FindLogChannel(ulong guildId, ulong channelId)
{ {
// TODO: fetch it directly on cache miss? // TODO: fetch it directly on cache miss?
if (await _cache.TryGetChannel(channelId) is Channel channel) if (await _cache.TryGetChannel(channelId) is Channel channel)
return channel; return channel;
if (await _rest.GetChannelOrNull(channelId) is Channel restChannel)
return restChannel;
// Channel doesn't exist or we don't have permission to access it, let's remove it from the database too // Channel doesn't exist or we don't have permission to access it, let's remove it from the database too
_logger.Warning( _logger.Warning(
"Attempted to fetch missing log channel {LogChannel} for guild {Guild}, removing from database", "Attempted to fetch missing log channel {LogChannel} for guild {Guild}, removing from database",

View File

@ -49,6 +49,32 @@ public static class DiscordUtils
await rest.CreateReaction(msg.ChannelId, msg.Id, new Emoji { Name = reaction }); await rest.CreateReaction(msg.ChannelId, msg.Id, new Emoji { Name = reaction });
} }
public static async Task<Guild?> GetGuildOrNull(this DiscordApiClient rest, ulong guildId)
{
try
{
return await rest.GetGuild(guildId);
}
catch (ForbiddenException)
{
// no permission, couldn't fetch, oh well
return null;
}
}
public static async Task<Channel?> GetChannelOrNull(this DiscordApiClient rest, ulong channelId)
{
try
{
return await rest.GetChannel(channelId);
}
catch (ForbiddenException)
{
// no permission, couldn't fetch, oh well
return null;
}
}
public static async Task<Message?> GetMessageOrNull(this DiscordApiClient rest, ulong channelId, public static async Task<Message?> GetMessageOrNull(this DiscordApiClient rest, ulong channelId,
ulong messageId) ulong messageId)
{ {