feat(bot): don't query db message context when running commands
This commit is contained in:
		@@ -28,8 +28,8 @@ public class Context
 | 
			
		||||
 | 
			
		||||
    private Command? _currentCommand;
 | 
			
		||||
 | 
			
		||||
    public Context(ILifetimeScope provider, int shardId, Guild? guild, Channel channel, MessageCreateEvent message, int commandParseOffset,
 | 
			
		||||
                    PKSystem senderSystem, SystemConfig config, MessageContext messageContext)
 | 
			
		||||
    public Context(ILifetimeScope provider, int shardId, Guild? guild, Channel channel, MessageCreateEvent message,
 | 
			
		||||
                                                    int commandParseOffset, PKSystem senderSystem, SystemConfig config)
 | 
			
		||||
    {
 | 
			
		||||
        Message = (Message)message;
 | 
			
		||||
        ShardId = shardId;
 | 
			
		||||
@@ -37,7 +37,6 @@ public class Context
 | 
			
		||||
        Channel = channel;
 | 
			
		||||
        System = senderSystem;
 | 
			
		||||
        Config = config;
 | 
			
		||||
        MessageContext = messageContext;
 | 
			
		||||
        Cache = provider.Resolve<IDiscordCache>();
 | 
			
		||||
        Database = provider.Resolve<IDatabase>();
 | 
			
		||||
        Repository = provider.Resolve<ModelRepository>();
 | 
			
		||||
@@ -61,7 +60,6 @@ public class Context
 | 
			
		||||
    public readonly Guild Guild;
 | 
			
		||||
    public readonly int ShardId;
 | 
			
		||||
    public readonly Cluster Cluster;
 | 
			
		||||
    public readonly MessageContext MessageContext;
 | 
			
		||||
 | 
			
		||||
    public Task<PermissionSet> BotPermissions => Cache.PermissionsIn(Channel.Id);
 | 
			
		||||
    public Task<PermissionSet> UserPermissions => Cache.PermissionsFor((MessageCreateEvent)Message);
 | 
			
		||||
 
 | 
			
		||||
@@ -89,10 +89,12 @@ public class Autoproxy
 | 
			
		||||
        var eb = new EmbedBuilder()
 | 
			
		||||
            .Title($"Current autoproxy status (for {ctx.Guild.Name.EscapeMarkdown()})");
 | 
			
		||||
 | 
			
		||||
        var fronters = ctx.MessageContext.LastSwitchMembers;
 | 
			
		||||
        var sw = await ctx.Repository.GetLatestSwitch(ctx.System.Id);
 | 
			
		||||
        var fronters = await ctx.Database.Execute(c => ctx.Repository.GetSwitchMembers(c, sw.Id)).ToListAsync();
 | 
			
		||||
 | 
			
		||||
        var relevantMember = settings.AutoproxyMode switch
 | 
			
		||||
        {
 | 
			
		||||
            AutoproxyMode.Front => fronters.Length > 0 ? await ctx.Repository.GetMember(fronters[0]) : null,
 | 
			
		||||
            AutoproxyMode.Front => fronters.Count > 0 ? fronters[0] : null,
 | 
			
		||||
            AutoproxyMode.Member when settings.AutoproxyMember.HasValue => await ctx.Repository.GetMember(settings.AutoproxyMember.Value),
 | 
			
		||||
            _ => null
 | 
			
		||||
        };
 | 
			
		||||
@@ -104,7 +106,7 @@ public class Autoproxy
 | 
			
		||||
                break;
 | 
			
		||||
            case AutoproxyMode.Front:
 | 
			
		||||
                {
 | 
			
		||||
                    if (fronters.Length == 0)
 | 
			
		||||
                    if (fronters.Count == 0)
 | 
			
		||||
                    {
 | 
			
		||||
                        eb.Description("Autoproxy is currently set to **front mode** in this server, but there are currently no fronters registered. Use the `pk;switch` command to log a switch.");
 | 
			
		||||
                    }
 | 
			
		||||
@@ -135,7 +137,8 @@ public class Autoproxy
 | 
			
		||||
            default: throw new ArgumentOutOfRangeException();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (!ctx.MessageContext.AllowAutoproxy)
 | 
			
		||||
        var allowAutoproxy = await ctx.Repository.GetAutoproxyEnabled(ctx.Author.Id);
 | 
			
		||||
        if (!allowAutoproxy)
 | 
			
		||||
            eb.Field(new Embed.Field("\u200b", $"{Emojis.Note} Autoproxy is currently **disabled** for your account (<@{ctx.Author.Id}>). To enable it, use `pk;autoproxy account enable`."));
 | 
			
		||||
 | 
			
		||||
        return eb.Build();
 | 
			
		||||
 
 | 
			
		||||
@@ -17,10 +17,12 @@ public class Config
 | 
			
		||||
    {
 | 
			
		||||
        var items = new List<PaginatedConfigItem>();
 | 
			
		||||
 | 
			
		||||
        var allowAutoproxy = await ctx.Repository.GetAutoproxyEnabled(ctx.Author.Id);
 | 
			
		||||
 | 
			
		||||
        items.Add(new(
 | 
			
		||||
            "autoproxy account",
 | 
			
		||||
            "Whether autoproxy is enabled for the current account",
 | 
			
		||||
            EnabledDisabled(ctx.MessageContext.AllowAutoproxy),
 | 
			
		||||
            EnabledDisabled(allowAutoproxy),
 | 
			
		||||
            "enabled"
 | 
			
		||||
        ));
 | 
			
		||||
 | 
			
		||||
@@ -122,16 +124,18 @@ public class Config
 | 
			
		||||
 | 
			
		||||
    public async Task AutoproxyAccount(Context ctx)
 | 
			
		||||
    {
 | 
			
		||||
        var allowAutoproxy = await ctx.Repository.GetAutoproxyEnabled(ctx.Author.Id);
 | 
			
		||||
 | 
			
		||||
        if (!ctx.HasNext())
 | 
			
		||||
        {
 | 
			
		||||
            await ctx.Reply($"Autoproxy is currently **{EnabledDisabled(ctx.MessageContext.AllowAutoproxy)}** for account <@{ctx.Author.Id}>.");
 | 
			
		||||
            await ctx.Reply($"Autoproxy is currently **{EnabledDisabled(allowAutoproxy)}** for account <@{ctx.Author.Id}>.");
 | 
			
		||||
            return;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        var allow = ctx.MatchToggle(true);
 | 
			
		||||
 | 
			
		||||
        var statusString = EnabledDisabled(allow);
 | 
			
		||||
        if (ctx.MessageContext.AllowAutoproxy == allow)
 | 
			
		||||
        if (allowAutoproxy == allow)
 | 
			
		||||
        {
 | 
			
		||||
            await ctx.Reply($"{Emojis.Note} Autoproxy is already {statusString} for account <@{ctx.Author.Id}>.");
 | 
			
		||||
            return;
 | 
			
		||||
 
 | 
			
		||||
@@ -126,8 +126,7 @@ public class ProxiedMessage
 | 
			
		||||
            if ((await ctx.BotPermissions).HasFlag(PermissionSet.ManageMessages))
 | 
			
		||||
                await _rest.DeleteMessage(ctx.Channel.Id, ctx.Message.Id);
 | 
			
		||||
 | 
			
		||||
            await _logChannel.LogMessage(ctx.MessageContext, msg.Message, ctx.Message, editedMsg,
 | 
			
		||||
                originalMsg!.Content!);
 | 
			
		||||
            await _logChannel.LogMessage(msg.Message, ctx.Message, editedMsg, originalMsg!.Content!);
 | 
			
		||||
        }
 | 
			
		||||
        catch (NotFoundException)
 | 
			
		||||
        {
 | 
			
		||||
 
 | 
			
		||||
@@ -277,7 +277,7 @@ public class SystemEdit
 | 
			
		||||
            await ctx.Reply(
 | 
			
		||||
                $"{Emojis.Success} System server tag changed. Member names will now end with {newTag.AsCode()} when proxied in the current server '{ctx.Guild.Name}'.");
 | 
			
		||||
 | 
			
		||||
            if (!ctx.MessageContext.TagEnabled)
 | 
			
		||||
            if (!settings.TagEnabled)
 | 
			
		||||
                await ctx.Reply(setDisabledWarning);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -288,7 +288,7 @@ public class SystemEdit
 | 
			
		||||
            await ctx.Reply(
 | 
			
		||||
                $"{Emojis.Success} System server tag cleared. Member names will now end with the global system tag, if there is one set.");
 | 
			
		||||
 | 
			
		||||
            if (!ctx.MessageContext.TagEnabled)
 | 
			
		||||
            if (!settings.TagEnabled)
 | 
			
		||||
                await ctx.Reply(setDisabledWarning);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -297,7 +297,7 @@ public class SystemEdit
 | 
			
		||||
            await ctx.Repository.UpdateSystemGuild(target.Id, ctx.Guild.Id,
 | 
			
		||||
                new SystemGuildPatch { TagEnabled = newValue });
 | 
			
		||||
 | 
			
		||||
            await ctx.Reply(PrintEnableDisableResult(newValue, newValue != ctx.MessageContext.TagEnabled));
 | 
			
		||||
            await ctx.Reply(PrintEnableDisableResult(newValue, newValue != settings.TagEnabled));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        string PrintEnableDisableResult(bool newValue, bool changedValue)
 | 
			
		||||
@@ -312,20 +312,20 @@ public class SystemEdit
 | 
			
		||||
 | 
			
		||||
            if (newValue)
 | 
			
		||||
            {
 | 
			
		||||
                if (ctx.MessageContext.TagEnabled)
 | 
			
		||||
                if (settings.TagEnabled)
 | 
			
		||||
                {
 | 
			
		||||
                    if (ctx.MessageContext.SystemGuildTag == null)
 | 
			
		||||
                    if (settings.Tag == null)
 | 
			
		||||
                        str +=
 | 
			
		||||
                            " However, you do not have a system tag specific to this server. Messages will be proxied using your global system tag, if there is one set.";
 | 
			
		||||
                    else
 | 
			
		||||
                        str +=
 | 
			
		||||
                            $" Your current system tag in '{ctx.Guild.Name}' is {ctx.MessageContext.SystemGuildTag.AsCode()}.";
 | 
			
		||||
                            $" Your current system tag in '{ctx.Guild.Name}' is {settings.Tag.AsCode()}.";
 | 
			
		||||
                }
 | 
			
		||||
                else
 | 
			
		||||
                {
 | 
			
		||||
                    if (ctx.MessageContext.SystemGuildTag != null)
 | 
			
		||||
                    if (settings.Tag != null)
 | 
			
		||||
                        str +=
 | 
			
		||||
                            $" Member names will now end with the server-specific tag {ctx.MessageContext.SystemGuildTag.AsCode()} when proxied in the current server '{ctx.Guild.Name}'.";
 | 
			
		||||
                            $" Member names will now end with the server-specific tag {settings.Tag.AsCode()} when proxied in the current server '{ctx.Guild.Name}'.";
 | 
			
		||||
                    else
 | 
			
		||||
                        str +=
 | 
			
		||||
                            " Member names will now end with the global system tag when proxied in the current server, if there is one set.";
 | 
			
		||||
 
 | 
			
		||||
@@ -63,7 +63,8 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
 | 
			
		||||
        if (evt.Type != Message.MessageType.Default && evt.Type != Message.MessageType.Reply) return;
 | 
			
		||||
        if (IsDuplicateMessage(evt)) return;
 | 
			
		||||
 | 
			
		||||
        if (!(await _cache.PermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.SendMessages)) return;
 | 
			
		||||
        var botPermissions = await _cache.PermissionsIn(evt.ChannelId);
 | 
			
		||||
        if (!botPermissions.HasFlag(PermissionSet.SendMessages)) return;
 | 
			
		||||
 | 
			
		||||
        // spawn off saving the private channel into another thread
 | 
			
		||||
        // it is not a fatal error if this fails, and it shouldn't block message processing
 | 
			
		||||
@@ -77,36 +78,33 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
 | 
			
		||||
        _metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived);
 | 
			
		||||
        _lastMessageCache.AddMessage(evt);
 | 
			
		||||
 | 
			
		||||
        // Get message context from DB (tracking w/ metrics)
 | 
			
		||||
        MessageContext ctx;
 | 
			
		||||
        using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
 | 
			
		||||
            ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel.Id);
 | 
			
		||||
        // if the message was not sent by an user account, only try running log cleanup
 | 
			
		||||
        if (evt.Author.Bot || evt.WebhookId != null || evt.Author.System == true)
 | 
			
		||||
        {
 | 
			
		||||
            await TryHandleLogClean(channel, evt);
 | 
			
		||||
            return;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Try each handler until we find one that succeeds
 | 
			
		||||
        if (await TryHandleLogClean(evt, ctx))
 | 
			
		||||
 | 
			
		||||
        if (await TryHandleCommand(shardId, evt, guild, channel))
 | 
			
		||||
            return;
 | 
			
		||||
 | 
			
		||||
        // Only do command/proxy handling if it's a user account
 | 
			
		||||
        if (evt.Author.Bot || evt.WebhookId != null || evt.Author.System == true)
 | 
			
		||||
            return;
 | 
			
		||||
 | 
			
		||||
        if (await TryHandleCommand(shardId, evt, guild, channel, ctx))
 | 
			
		||||
            return;
 | 
			
		||||
        await TryHandleProxy(evt, guild, channel, ctx);
 | 
			
		||||
        await TryHandleProxy(evt, guild, channel, rootChannel.Id, botPermissions);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private async ValueTask<bool> TryHandleLogClean(MessageCreateEvent evt, MessageContext ctx)
 | 
			
		||||
    private async Task TryHandleLogClean(Channel channel, MessageCreateEvent evt)
 | 
			
		||||
    {
 | 
			
		||||
        var channel = await _cache.GetChannel(evt.ChannelId);
 | 
			
		||||
        if (!evt.Author.Bot || channel.Type != Channel.ChannelType.GuildText ||
 | 
			
		||||
            !ctx.LogCleanupEnabled) return false;
 | 
			
		||||
        if (evt.GuildId != null) return;
 | 
			
		||||
        if (channel.Type != Channel.ChannelType.GuildText) return;
 | 
			
		||||
 | 
			
		||||
        var guildSettings = await _repo.GetGuild(evt.GuildId!.Value);
 | 
			
		||||
 | 
			
		||||
        if (guildSettings.LogCleanupEnabled)
 | 
			
		||||
            await _loggerClean.HandleLoggerBotCleanup(evt);
 | 
			
		||||
        return true;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private async ValueTask<bool> TryHandleCommand(int shardId, MessageCreateEvent evt, Guild? guild,
 | 
			
		||||
                                                   Channel channel, MessageContext ctx)
 | 
			
		||||
    private async ValueTask<bool> TryHandleCommand(int shardId, MessageCreateEvent evt, Guild? guild, Channel channel)
 | 
			
		||||
    {
 | 
			
		||||
        var content = evt.Content;
 | 
			
		||||
        if (content == null) return false;
 | 
			
		||||
@@ -125,9 +123,9 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
 | 
			
		||||
 | 
			
		||||
        try
 | 
			
		||||
        {
 | 
			
		||||
            var system = ctx.SystemId != null ? await _repo.GetSystem(ctx.SystemId.Value) : null;
 | 
			
		||||
            var config = ctx.SystemId != null ? await _repo.GetSystemConfig(ctx.SystemId.Value) : null;
 | 
			
		||||
            await _tree.ExecuteCommand(new Context(_services, shardId, guild, channel, evt, cmdStart, system, config, ctx));
 | 
			
		||||
            var system = await _repo.GetSystemByAccount(evt.Author.Id);
 | 
			
		||||
            var config = system != null ? await _repo.GetSystemConfig(system.Id) : null;
 | 
			
		||||
            await _tree.ExecuteCommand(new Context(_services, shardId, guild, channel, evt, cmdStart, system, config));
 | 
			
		||||
        }
 | 
			
		||||
        catch (PKError)
 | 
			
		||||
        {
 | 
			
		||||
@@ -158,10 +156,12 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private async ValueTask<bool> TryHandleProxy(MessageCreateEvent evt, Guild guild, Channel channel,
 | 
			
		||||
                                                 MessageContext ctx)
 | 
			
		||||
    private async ValueTask<bool> TryHandleProxy(MessageCreateEvent evt, Guild guild, Channel channel, ulong rootChannel, PermissionSet botPermissions)
 | 
			
		||||
    {
 | 
			
		||||
        var botPermissions = await _cache.PermissionsIn(channel.Id);
 | 
			
		||||
        // Get message context from DB (tracking w/ metrics)
 | 
			
		||||
        MessageContext ctx;
 | 
			
		||||
        using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
 | 
			
		||||
            ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel);
 | 
			
		||||
 | 
			
		||||
        try
 | 
			
		||||
        {
 | 
			
		||||
 
 | 
			
		||||
@@ -386,7 +386,7 @@ public class ProxyService
 | 
			
		||||
            => _repo.AddMessage(sentMessage);
 | 
			
		||||
 | 
			
		||||
        Task LogMessageToChannel() =>
 | 
			
		||||
            _logChannel.LogMessage(ctx, sentMessage, triggerMessage, proxyMessage).AsTask();
 | 
			
		||||
            _logChannel.LogMessage(sentMessage, triggerMessage, proxyMessage).AsTask();
 | 
			
		||||
 | 
			
		||||
        Task SaveLatchAutoproxy() => autoproxySettings.AutoproxyMode == AutoproxyMode.Latch
 | 
			
		||||
            ? _repo.UpdateAutoproxy(ctx.SystemId.Value, triggerMessage.GuildId, null, new()
 | 
			
		||||
 
 | 
			
		||||
@@ -34,17 +34,16 @@ public class LogChannelService
 | 
			
		||||
        _logger = logger.ForContext<LogChannelService>();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public async ValueTask LogMessage(MessageContext ctx, PKMessage proxiedMessage, Message trigger,
 | 
			
		||||
                                      Message hookMessage, string oldContent = null)
 | 
			
		||||
    public async ValueTask LogMessage(PKMessage proxiedMessage, Message trigger, Message hookMessage, string oldContent = null)
 | 
			
		||||
    {
 | 
			
		||||
        var logChannelId = await GetAndCheckLogChannel(ctx, trigger, proxiedMessage);
 | 
			
		||||
        var logChannelId = await GetAndCheckLogChannel(trigger, proxiedMessage);
 | 
			
		||||
        if (logChannelId == null)
 | 
			
		||||
            return;
 | 
			
		||||
 | 
			
		||||
        var triggerChannel = await _cache.GetChannel(proxiedMessage.Channel);
 | 
			
		||||
 | 
			
		||||
        var system = await _repo.GetSystem(ctx.SystemId.Value);
 | 
			
		||||
        var member = await _repo.GetMember(proxiedMessage.Member!.Value);
 | 
			
		||||
        var system = await _repo.GetSystem(member.System);
 | 
			
		||||
 | 
			
		||||
        // Send embed!
 | 
			
		||||
        var embed = _embed.CreateLoggedMessageEmbed(trigger, hookMessage, system.Hid, member, triggerChannel.Name,
 | 
			
		||||
@@ -54,8 +53,7 @@ public class LogChannelService
 | 
			
		||||
        await _rest.CreateMessage(logChannelId.Value, new MessageRequest { Content = url, Embeds = new[] { embed } });
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private async Task<ulong?> GetAndCheckLogChannel(MessageContext ctx, Message trigger,
 | 
			
		||||
                                                       PKMessage proxiedMessage)
 | 
			
		||||
    private async Task<ulong?> GetAndCheckLogChannel(Message trigger, PKMessage proxiedMessage)
 | 
			
		||||
    {
 | 
			
		||||
        if (proxiedMessage.Guild == null && proxiedMessage.Channel != trigger.ChannelId)
 | 
			
		||||
            // a very old message is being edited outside of its original channel
 | 
			
		||||
@@ -63,18 +61,15 @@ public class LogChannelService
 | 
			
		||||
            return null;
 | 
			
		||||
 | 
			
		||||
        var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value;
 | 
			
		||||
        var logChannelId = ctx.LogChannel;
 | 
			
		||||
        var isBlacklisted = ctx.InLogBlacklist;
 | 
			
		||||
 | 
			
		||||
        if (proxiedMessage.Guild != trigger.GuildId)
 | 
			
		||||
        {
 | 
			
		||||
            // we're editing a message from a different server, get log channel info from the database
 | 
			
		||||
            var guild = await _repo.GetGuild(proxiedMessage.Guild.Value);
 | 
			
		||||
            logChannelId = guild.LogChannel;
 | 
			
		||||
            isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId);
 | 
			
		||||
        }
 | 
			
		||||
        // get log channel info from the database
 | 
			
		||||
        var guild = await _repo.GetGuild(guildId);
 | 
			
		||||
        var logChannelId = guild.LogChannel;
 | 
			
		||||
        var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId);
 | 
			
		||||
 | 
			
		||||
        if (ctx.SystemId == null || logChannelId == null || isBlacklisted) return null;
 | 
			
		||||
        // if (ctx.SystemId == null ||
 | 
			
		||||
        // removed the above, there shouldn't be a way to get to this code path if you don't have a system registered
 | 
			
		||||
        if (logChannelId == null || isBlacklisted) return null;
 | 
			
		||||
 | 
			
		||||
        // Find log channel and check if valid
 | 
			
		||||
        var logChannel = await FindLogChannel(guildId, logChannelId.Value);
 | 
			
		||||
 
 | 
			
		||||
@@ -9,6 +9,9 @@ public partial class ModelRepository
 | 
			
		||||
    public async Task<ulong?> GetDmChannel(ulong id)
 | 
			
		||||
        => await _db.Execute(c => c.QueryFirstOrDefaultAsync<ulong?>("select dm_channel from accounts where uid = @id", new { id = id }));
 | 
			
		||||
 | 
			
		||||
    public async Task<bool> GetAutoproxyEnabled(ulong id)
 | 
			
		||||
        => await _db.QueryFirst<bool>(new Query("accounts").Select("allow_autoproxy").Where("uid", id));
 | 
			
		||||
 | 
			
		||||
    public async Task UpdateAccount(ulong id, AccountPatch patch)
 | 
			
		||||
    {
 | 
			
		||||
        _logger.Information("Updated account {accountId}: {@AccountPatch}", id, patch);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user