refactor: don't use 'out' in IDiscordCache methods

this change is required for async cache (.NET doesn't support async ref/out params)
This commit is contained in:
spiral 2021-11-18 13:11:02 -05:00
parent 3ba46ff456
commit 0efaab6d95
No known key found for this signature in database
GPG Key ID: A6059F0CA0E1BD31
11 changed files with 48 additions and 52 deletions

View File

@ -18,11 +18,11 @@ namespace Myriad.Cache
public ValueTask RemoveUser(ulong userId); public ValueTask RemoveUser(ulong userId);
public ValueTask RemoveRole(ulong guildId, ulong roleId); public ValueTask RemoveRole(ulong guildId, ulong roleId);
public Task<bool> TryGetGuild(ulong guildId, out Guild guild); public Task<Guild?> TryGetGuild(ulong guildId);
public Task<bool> TryGetChannel(ulong channelId, out Channel channel); public Task<Channel?> TryGetChannel(ulong channelId);
public Task<bool> TryGetDmChannel(ulong userId, out Channel channel); public Task<Channel?> TryGetDmChannel(ulong userId);
public Task<bool> TryGetUser(ulong userId, out User user); public Task<User?> TryGetUser(ulong userId);
public Task<bool> TryGetRole(ulong roleId, out Role role); public Task<Role?> TryGetRole(ulong roleId);
public IAsyncEnumerable<Guild> GetAllGuilds(); public IAsyncEnumerable<Guild> GetAllGuilds();
public Task<IEnumerable<Channel>> GetGuildChannels(ulong guildId); public Task<IEnumerable<Channel>> GetGuildChannels(ulong guildId);

View File

@ -124,34 +124,36 @@ namespace Myriad.Cache
return default; return default;
} }
public Task<bool> TryGetGuild(ulong guildId, out Guild guild) public Task<Guild?> TryGetGuild(ulong guildId)
{ {
if (_guilds.TryGetValue(guildId, out var cg)) _guilds.TryGetValue(guildId, out var cg);
{ return Task.FromResult(cg?.Guild);
guild = cg.Guild;
return Task.FromResult(true);
}
guild = null!;
return Task.FromResult(false);
} }
public Task<bool> TryGetChannel(ulong channelId, out Channel channel) => public Task<Channel?> TryGetChannel(ulong channelId)
Task.FromResult(_channels.TryGetValue(channelId, out channel!)); {
_channels.TryGetValue(channelId, out var channel);
public Task<bool> TryGetDmChannel(ulong userId, out Channel channel) return Task.FromResult(channel);
}
public Task<Channel?> TryGetDmChannel(ulong userId)
{ {
channel = default!;
if (!_dmChannels.TryGetValue(userId, out var channelId)) if (!_dmChannels.TryGetValue(userId, out var channelId))
return Task.FromResult(false); return Task.FromResult((Channel?) null);
return TryGetChannel(channelId, out channel); return TryGetChannel(channelId);
} }
public Task<bool> TryGetUser(ulong userId, out User user) => public Task<User?> TryGetUser(ulong userId)
Task.FromResult(_users.TryGetValue(userId, out user!)); {
_users.TryGetValue(userId, out var user);
return Task.FromResult(user);
}
public Task<bool> TryGetRole(ulong roleId, out Role role) => public Task<Role?> TryGetRole(ulong roleId)
Task.FromResult(_roles.TryGetValue(roleId, out role!)); {
_roles.TryGetValue(roleId, out var role);
return Task.FromResult(role);
}
public IAsyncEnumerable<Guild> GetAllGuilds() public IAsyncEnumerable<Guild> GetAllGuilds()
{ {

View File

@ -11,42 +11,35 @@ namespace Myriad.Extensions
{ {
public static async Task<Guild> GetGuild(this IDiscordCache cache, ulong guildId) public static async Task<Guild> GetGuild(this IDiscordCache cache, ulong guildId)
{ {
if (!await cache.TryGetGuild(guildId, out var guild)) if (!(await cache.TryGetGuild(guildId) is Guild guild))
throw new KeyNotFoundException($"Guild {guildId} not found in cache"); throw new KeyNotFoundException($"Guild {guildId} not found in cache");
return guild; return guild;
} }
public static async Task<Channel> GetChannel(this IDiscordCache cache, ulong channelId) public static async Task<Channel> GetChannel(this IDiscordCache cache, ulong channelId)
{ {
if (!await cache.TryGetChannel(channelId, out var channel)) if (!(await cache.TryGetChannel(channelId) is Channel channel))
throw new KeyNotFoundException($"Channel {channelId} not found in cache"); throw new KeyNotFoundException($"Channel {channelId} not found in cache");
return channel; return channel;
} }
public static async Task<Channel?> GetChannelOrNull(this IDiscordCache cache, ulong channelId)
{
if (await cache.TryGetChannel(channelId, out var channel))
return channel;
return null;
}
public static async Task<User> GetUser(this IDiscordCache cache, ulong userId) public static async Task<User> GetUser(this IDiscordCache cache, ulong userId)
{ {
if (!await cache.TryGetUser(userId, out var user)) if (!(await cache.TryGetUser(userId) is User user))
throw new KeyNotFoundException($"User {userId} not found in cache"); throw new KeyNotFoundException($"User {userId} not found in cache");
return user; return user;
} }
public static async Task<Role> GetRole(this IDiscordCache cache, ulong roleId) public static async Task<Role> GetRole(this IDiscordCache cache, ulong roleId)
{ {
if (!await cache.TryGetRole(roleId, out var role)) if (!(await cache.TryGetRole(roleId) is Role role))
throw new KeyNotFoundException($"Role {roleId} not found in cache"); throw new KeyNotFoundException($"Role {roleId} not found in cache");
return role; return role;
} }
public static async ValueTask<User?> GetOrFetchUser(this IDiscordCache cache, DiscordApiClient rest, ulong userId) public static async ValueTask<User?> GetOrFetchUser(this IDiscordCache cache, DiscordApiClient rest, ulong userId)
{ {
if (await cache.TryGetUser(userId, out var cacheUser)) if (await cache.TryGetUser(userId) is User cacheUser)
return cacheUser; return cacheUser;
var restUser = await rest.GetUser(userId); var restUser = await rest.GetUser(userId);
@ -57,7 +50,7 @@ namespace Myriad.Extensions
public static async ValueTask<Channel?> GetOrFetchChannel(this IDiscordCache cache, DiscordApiClient rest, ulong channelId) public static async ValueTask<Channel?> GetOrFetchChannel(this IDiscordCache cache, DiscordApiClient rest, ulong channelId)
{ {
if (await cache.TryGetChannel(channelId, out var cacheChannel)) if (await cache.TryGetChannel(channelId) is {} cacheChannel)
return cacheChannel; return cacheChannel;
var restChannel = await rest.GetChannel(channelId); var restChannel = await rest.GetChannel(channelId);
@ -68,7 +61,7 @@ namespace Myriad.Extensions
public static async Task<Channel> GetOrCreateDmChannel(this IDiscordCache cache, DiscordApiClient rest, ulong recipientId) public static async Task<Channel> GetOrCreateDmChannel(this IDiscordCache cache, DiscordApiClient rest, ulong recipientId)
{ {
if (await cache.TryGetDmChannel(recipientId, out var cacheChannel)) if (await cache.TryGetDmChannel(recipientId) is {} cacheChannel)
return cacheChannel; return cacheChannel;
var restChannel = await rest.CreateDm(recipientId); var restChannel = await rest.CreateDm(recipientId);

View File

@ -19,7 +19,7 @@ namespace Myriad.Extensions
public static async Task<PermissionSet> PermissionsFor(this IDiscordCache cache, ulong channelId, ulong userId, GuildMemberPartial? member, bool isWebhook = false) public static async Task<PermissionSet> PermissionsFor(this IDiscordCache cache, ulong channelId, ulong userId, GuildMemberPartial? member, bool isWebhook = false)
{ {
if (!await cache.TryGetChannel(channelId, out var channel)) if (!(await cache.TryGetChannel(channelId) is Channel channel))
// todo: handle channel not found better // todo: handle channel not found better
return PermissionSet.Dm; return PermissionSet.Dm;

View File

@ -157,7 +157,7 @@ namespace PluralKit.Bot
if (!MentionUtils.TryParseChannel(ctx.PeekArgument(), out var id)) if (!MentionUtils.TryParseChannel(ctx.PeekArgument(), out var id))
return null; return null;
if (!await ctx.Cache.TryGetChannel(id, out var channel)) if (!(await ctx.Cache.TryGetChannel(id) is Channel channel))
return null; return null;
if (!DiscordUtils.IsValidGuildChannel(channel)) if (!DiscordUtils.IsValidGuildChannel(channel))
@ -167,12 +167,12 @@ namespace PluralKit.Bot
return channel; return channel;
} }
public static Guild MatchGuild(this Context ctx) public static async Task<Guild> MatchGuild(this Context ctx)
{ {
try try
{ {
var id = ulong.Parse(ctx.PeekArgument()); var id = ulong.Parse(ctx.PeekArgument());
ctx.Cache.TryGetGuild(id, out var guild); var guild = await ctx.Cache.TryGetGuild(id);
if (guild != null) if (guild != null)
ctx.PopArgument(); ctx.PopArgument();

View File

@ -110,7 +110,7 @@ namespace PluralKit.Bot
// Resolve all channels from the cache and order by position // Resolve all channels from the cache and order by position
var channels = (await Task.WhenAll(blacklist.Blacklist var channels = (await Task.WhenAll(blacklist.Blacklist
.Select(id => _cache.GetChannelOrNull(id)))) .Select(id => _cache.TryGetChannel(id))))
.Where(c => c != null) .Where(c => c != null)
.OrderBy(c => c.Position) .OrderBy(c => c.Position)
.ToList(); .ToList();

View File

@ -436,7 +436,7 @@ namespace PluralKit.Bot
{ {
ctx.CheckSystem(); ctx.CheckSystem();
var guild = ctx.MatchGuild() ?? ctx.Guild ?? var guild = await ctx.MatchGuild() ?? ctx.Guild ??
throw new PKError("You must run this command in a server or pass a server ID."); throw new PKError("You must run this command in a server or pass a server ID.");
var gs = await _repo.GetSystemGuild(guild.Id, ctx.System.Id); var gs = await _repo.GetSystemGuild(guild.Id, ctx.System.Id);

View File

@ -50,7 +50,7 @@ namespace PluralKit.Bot
{ {
// Sometimes we get events from users that aren't in the user cache // Sometimes we get events from users that aren't in the user cache
// We just ignore all of those for now, should be quite rare... // We just ignore all of those for now, should be quite rare...
if (!await _cache.TryGetUser(evt.UserId, out var user)) if (!(await _cache.TryGetUser(evt.UserId) is User user))
return; return;
// ignore any reactions added by *us* // ignore any reactions added by *us*

View File

@ -335,10 +335,10 @@ namespace PluralKit.Bot
var roles = memberInfo?.Roles?.ToList(); var roles = memberInfo?.Roles?.ToList();
if (roles != null && roles.Count > 0 && showContent) if (roles != null && roles.Count > 0 && showContent)
{ {
var rolesString = string.Join(", ", roles var rolesString = string.Join(", ", (await Task.WhenAll(roles
.Select(id => .Select(async id =>
{ {
_cache.TryGetRole(id, out var role); var role = await _cache.TryGetRole(id);
if (role != null) if (role != null)
return role; return role;
return new Role() return new Role()
@ -346,7 +346,7 @@ namespace PluralKit.Bot
Name = "*(unknown role)*", Name = "*(unknown role)*",
Position = 0, Position = 0,
}; };
}) })))
.OrderByDescending(role => role.Position) .OrderByDescending(role => role.Position)
.Select(role => role.Name)); .Select(role => role.Name));
eb.Field(new($"Account roles ({roles.Count})", rolesString.Truncate(1024))); eb.Field(new($"Account roles ({roles.Count})", rolesString.Truncate(1024)));

View File

@ -93,7 +93,7 @@ namespace PluralKit.Bot
private async Task<Channel?> FindLogChannel(ulong guildId, ulong channelId) private async Task<Channel?> FindLogChannel(ulong guildId, ulong channelId)
{ {
// TODO: fetch it directly on cache miss? // TODO: fetch it directly on cache miss?
if (await _cache.TryGetChannel(channelId, out var channel)) if (await _cache.TryGetChannel(channelId) is Channel channel)
return channel; return channel;
// Channel doesn't exist or we don't have permission to access it, let's remove it from the database too // Channel doesn't exist or we don't have permission to access it, let's remove it from the database too

View File

@ -4,6 +4,7 @@ using System.Threading.Tasks;
using Myriad.Cache; using Myriad.Cache;
using Myriad.Extensions; using Myriad.Extensions;
using Myriad.Gateway; using Myriad.Gateway;
using Myriad.Types;
using Serilog.Core; using Serilog.Core;
using Serilog.Events; using Serilog.Events;
@ -39,7 +40,7 @@ namespace PluralKit.Bot
{ {
props.Add(new("ChannelId", new ScalarValue(channel.Value))); props.Add(new("ChannelId", new ScalarValue(channel.Value)));
if (await _cache.TryGetChannel(channel.Value, out _)) if (await _cache.TryGetChannel(channel.Value) != null)
{ {
var botPermissions = await _bot.PermissionsIn(channel.Value); var botPermissions = await _bot.PermissionsIn(channel.Value);
props.Add(new("BotPermissions", new ScalarValue(botPermissions))); props.Add(new("BotPermissions", new ScalarValue(botPermissions)));