feat(bot): don't query db message context when running commands

This commit is contained in:
spiral 2022-06-15 19:28:34 -04:00
parent 7cb3a3ea0f
commit 9848b88d5b
No known key found for this signature in database
GPG Key ID: 244A11E4B0BCF40E
9 changed files with 67 additions and 65 deletions

View File

@ -28,8 +28,8 @@ public class Context
private Command? _currentCommand; private Command? _currentCommand;
public Context(ILifetimeScope provider, int shardId, Guild? guild, Channel channel, MessageCreateEvent message, int commandParseOffset, public Context(ILifetimeScope provider, int shardId, Guild? guild, Channel channel, MessageCreateEvent message,
PKSystem senderSystem, SystemConfig config, MessageContext messageContext) int commandParseOffset, PKSystem senderSystem, SystemConfig config)
{ {
Message = (Message)message; Message = (Message)message;
ShardId = shardId; ShardId = shardId;
@ -37,7 +37,6 @@ public class Context
Channel = channel; Channel = channel;
System = senderSystem; System = senderSystem;
Config = config; Config = config;
MessageContext = messageContext;
Cache = provider.Resolve<IDiscordCache>(); Cache = provider.Resolve<IDiscordCache>();
Database = provider.Resolve<IDatabase>(); Database = provider.Resolve<IDatabase>();
Repository = provider.Resolve<ModelRepository>(); Repository = provider.Resolve<ModelRepository>();
@ -61,7 +60,6 @@ public class Context
public readonly Guild Guild; public readonly Guild Guild;
public readonly int ShardId; public readonly int ShardId;
public readonly Cluster Cluster; public readonly Cluster Cluster;
public readonly MessageContext MessageContext;
public Task<PermissionSet> BotPermissions => Cache.PermissionsIn(Channel.Id); public Task<PermissionSet> BotPermissions => Cache.PermissionsIn(Channel.Id);
public Task<PermissionSet> UserPermissions => Cache.PermissionsFor((MessageCreateEvent)Message); public Task<PermissionSet> UserPermissions => Cache.PermissionsFor((MessageCreateEvent)Message);

View File

@ -89,10 +89,12 @@ public class Autoproxy
var eb = new EmbedBuilder() var eb = new EmbedBuilder()
.Title($"Current autoproxy status (for {ctx.Guild.Name.EscapeMarkdown()})"); .Title($"Current autoproxy status (for {ctx.Guild.Name.EscapeMarkdown()})");
var fronters = ctx.MessageContext.LastSwitchMembers; var sw = await ctx.Repository.GetLatestSwitch(ctx.System.Id);
var fronters = await ctx.Database.Execute(c => ctx.Repository.GetSwitchMembers(c, sw.Id)).ToListAsync();
var relevantMember = settings.AutoproxyMode switch var relevantMember = settings.AutoproxyMode switch
{ {
AutoproxyMode.Front => fronters.Length > 0 ? await ctx.Repository.GetMember(fronters[0]) : null, AutoproxyMode.Front => fronters.Count > 0 ? fronters[0] : null,
AutoproxyMode.Member when settings.AutoproxyMember.HasValue => await ctx.Repository.GetMember(settings.AutoproxyMember.Value), AutoproxyMode.Member when settings.AutoproxyMember.HasValue => await ctx.Repository.GetMember(settings.AutoproxyMember.Value),
_ => null _ => null
}; };
@ -104,7 +106,7 @@ public class Autoproxy
break; break;
case AutoproxyMode.Front: case AutoproxyMode.Front:
{ {
if (fronters.Length == 0) if (fronters.Count == 0)
{ {
eb.Description("Autoproxy is currently set to **front mode** in this server, but there are currently no fronters registered. Use the `pk;switch` command to log a switch."); eb.Description("Autoproxy is currently set to **front mode** in this server, but there are currently no fronters registered. Use the `pk;switch` command to log a switch.");
} }
@ -135,7 +137,8 @@ public class Autoproxy
default: throw new ArgumentOutOfRangeException(); default: throw new ArgumentOutOfRangeException();
} }
if (!ctx.MessageContext.AllowAutoproxy) var allowAutoproxy = await ctx.Repository.GetAutoproxyEnabled(ctx.Author.Id);
if (!allowAutoproxy)
eb.Field(new Embed.Field("\u200b", $"{Emojis.Note} Autoproxy is currently **disabled** for your account (<@{ctx.Author.Id}>). To enable it, use `pk;autoproxy account enable`.")); eb.Field(new Embed.Field("\u200b", $"{Emojis.Note} Autoproxy is currently **disabled** for your account (<@{ctx.Author.Id}>). To enable it, use `pk;autoproxy account enable`."));
return eb.Build(); return eb.Build();

View File

@ -17,10 +17,12 @@ public class Config
{ {
var items = new List<PaginatedConfigItem>(); var items = new List<PaginatedConfigItem>();
var allowAutoproxy = await ctx.Repository.GetAutoproxyEnabled(ctx.Author.Id);
items.Add(new( items.Add(new(
"autoproxy account", "autoproxy account",
"Whether autoproxy is enabled for the current account", "Whether autoproxy is enabled for the current account",
EnabledDisabled(ctx.MessageContext.AllowAutoproxy), EnabledDisabled(allowAutoproxy),
"enabled" "enabled"
)); ));
@ -122,16 +124,18 @@ public class Config
public async Task AutoproxyAccount(Context ctx) public async Task AutoproxyAccount(Context ctx)
{ {
var allowAutoproxy = await ctx.Repository.GetAutoproxyEnabled(ctx.Author.Id);
if (!ctx.HasNext()) if (!ctx.HasNext())
{ {
await ctx.Reply($"Autoproxy is currently **{EnabledDisabled(ctx.MessageContext.AllowAutoproxy)}** for account <@{ctx.Author.Id}>."); await ctx.Reply($"Autoproxy is currently **{EnabledDisabled(allowAutoproxy)}** for account <@{ctx.Author.Id}>.");
return; return;
} }
var allow = ctx.MatchToggle(true); var allow = ctx.MatchToggle(true);
var statusString = EnabledDisabled(allow); var statusString = EnabledDisabled(allow);
if (ctx.MessageContext.AllowAutoproxy == allow) if (allowAutoproxy == allow)
{ {
await ctx.Reply($"{Emojis.Note} Autoproxy is already {statusString} for account <@{ctx.Author.Id}>."); await ctx.Reply($"{Emojis.Note} Autoproxy is already {statusString} for account <@{ctx.Author.Id}>.");
return; return;

View File

@ -126,8 +126,7 @@ public class ProxiedMessage
if ((await ctx.BotPermissions).HasFlag(PermissionSet.ManageMessages)) if ((await ctx.BotPermissions).HasFlag(PermissionSet.ManageMessages))
await _rest.DeleteMessage(ctx.Channel.Id, ctx.Message.Id); await _rest.DeleteMessage(ctx.Channel.Id, ctx.Message.Id);
await _logChannel.LogMessage(ctx.MessageContext, msg.Message, ctx.Message, editedMsg, await _logChannel.LogMessage(msg.Message, ctx.Message, editedMsg, originalMsg!.Content!);
originalMsg!.Content!);
} }
catch (NotFoundException) catch (NotFoundException)
{ {

View File

@ -277,7 +277,7 @@ public class SystemEdit
await ctx.Reply( await ctx.Reply(
$"{Emojis.Success} System server tag changed. Member names will now end with {newTag.AsCode()} when proxied in the current server '{ctx.Guild.Name}'."); $"{Emojis.Success} System server tag changed. Member names will now end with {newTag.AsCode()} when proxied in the current server '{ctx.Guild.Name}'.");
if (!ctx.MessageContext.TagEnabled) if (!settings.TagEnabled)
await ctx.Reply(setDisabledWarning); await ctx.Reply(setDisabledWarning);
} }
@ -288,7 +288,7 @@ public class SystemEdit
await ctx.Reply( await ctx.Reply(
$"{Emojis.Success} System server tag cleared. Member names will now end with the global system tag, if there is one set."); $"{Emojis.Success} System server tag cleared. Member names will now end with the global system tag, if there is one set.");
if (!ctx.MessageContext.TagEnabled) if (!settings.TagEnabled)
await ctx.Reply(setDisabledWarning); await ctx.Reply(setDisabledWarning);
} }
@ -297,7 +297,7 @@ public class SystemEdit
await ctx.Repository.UpdateSystemGuild(target.Id, ctx.Guild.Id, await ctx.Repository.UpdateSystemGuild(target.Id, ctx.Guild.Id,
new SystemGuildPatch { TagEnabled = newValue }); new SystemGuildPatch { TagEnabled = newValue });
await ctx.Reply(PrintEnableDisableResult(newValue, newValue != ctx.MessageContext.TagEnabled)); await ctx.Reply(PrintEnableDisableResult(newValue, newValue != settings.TagEnabled));
} }
string PrintEnableDisableResult(bool newValue, bool changedValue) string PrintEnableDisableResult(bool newValue, bool changedValue)
@ -312,20 +312,20 @@ public class SystemEdit
if (newValue) if (newValue)
{ {
if (ctx.MessageContext.TagEnabled) if (settings.TagEnabled)
{ {
if (ctx.MessageContext.SystemGuildTag == null) if (settings.Tag == null)
str += str +=
" However, you do not have a system tag specific to this server. Messages will be proxied using your global system tag, if there is one set."; " However, you do not have a system tag specific to this server. Messages will be proxied using your global system tag, if there is one set.";
else else
str += str +=
$" Your current system tag in '{ctx.Guild.Name}' is {ctx.MessageContext.SystemGuildTag.AsCode()}."; $" Your current system tag in '{ctx.Guild.Name}' is {settings.Tag.AsCode()}.";
} }
else else
{ {
if (ctx.MessageContext.SystemGuildTag != null) if (settings.Tag != null)
str += str +=
$" Member names will now end with the server-specific tag {ctx.MessageContext.SystemGuildTag.AsCode()} when proxied in the current server '{ctx.Guild.Name}'."; $" Member names will now end with the server-specific tag {settings.Tag.AsCode()} when proxied in the current server '{ctx.Guild.Name}'.";
else else
str += str +=
" Member names will now end with the global system tag when proxied in the current server, if there is one set."; " Member names will now end with the global system tag when proxied in the current server, if there is one set.";

View File

@ -63,7 +63,8 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
if (evt.Type != Message.MessageType.Default && evt.Type != Message.MessageType.Reply) return; if (evt.Type != Message.MessageType.Default && evt.Type != Message.MessageType.Reply) return;
if (IsDuplicateMessage(evt)) return; if (IsDuplicateMessage(evt)) return;
if (!(await _cache.PermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.SendMessages)) return; var botPermissions = await _cache.PermissionsIn(evt.ChannelId);
if (!botPermissions.HasFlag(PermissionSet.SendMessages)) return;
// spawn off saving the private channel into another thread // spawn off saving the private channel into another thread
// it is not a fatal error if this fails, and it shouldn't block message processing // it is not a fatal error if this fails, and it shouldn't block message processing
@ -77,36 +78,33 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
_metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived); _metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived);
_lastMessageCache.AddMessage(evt); _lastMessageCache.AddMessage(evt);
// Get message context from DB (tracking w/ metrics) // if the message was not sent by an user account, only try running log cleanup
MessageContext ctx; if (evt.Author.Bot || evt.WebhookId != null || evt.Author.System == true)
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime)) {
ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel.Id); await TryHandleLogClean(channel, evt);
return;
}
// Try each handler until we find one that succeeds // Try each handler until we find one that succeeds
if (await TryHandleLogClean(evt, ctx))
if (await TryHandleCommand(shardId, evt, guild, channel))
return; return;
// Only do command/proxy handling if it's a user account await TryHandleProxy(evt, guild, channel, rootChannel.Id, botPermissions);
if (evt.Author.Bot || evt.WebhookId != null || evt.Author.System == true)
return;
if (await TryHandleCommand(shardId, evt, guild, channel, ctx))
return;
await TryHandleProxy(evt, guild, channel, ctx);
} }
private async ValueTask<bool> TryHandleLogClean(MessageCreateEvent evt, MessageContext ctx) private async Task TryHandleLogClean(Channel channel, MessageCreateEvent evt)
{ {
var channel = await _cache.GetChannel(evt.ChannelId); if (evt.GuildId != null) return;
if (!evt.Author.Bot || channel.Type != Channel.ChannelType.GuildText || if (channel.Type != Channel.ChannelType.GuildText) return;
!ctx.LogCleanupEnabled) return false;
await _loggerClean.HandleLoggerBotCleanup(evt); var guildSettings = await _repo.GetGuild(evt.GuildId!.Value);
return true;
if (guildSettings.LogCleanupEnabled)
await _loggerClean.HandleLoggerBotCleanup(evt);
} }
private async ValueTask<bool> TryHandleCommand(int shardId, MessageCreateEvent evt, Guild? guild, private async ValueTask<bool> TryHandleCommand(int shardId, MessageCreateEvent evt, Guild? guild, Channel channel)
Channel channel, MessageContext ctx)
{ {
var content = evt.Content; var content = evt.Content;
if (content == null) return false; if (content == null) return false;
@ -125,9 +123,9 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
try try
{ {
var system = ctx.SystemId != null ? await _repo.GetSystem(ctx.SystemId.Value) : null; var system = await _repo.GetSystemByAccount(evt.Author.Id);
var config = ctx.SystemId != null ? await _repo.GetSystemConfig(ctx.SystemId.Value) : null; var config = system != null ? await _repo.GetSystemConfig(system.Id) : null;
await _tree.ExecuteCommand(new Context(_services, shardId, guild, channel, evt, cmdStart, system, config, ctx)); await _tree.ExecuteCommand(new Context(_services, shardId, guild, channel, evt, cmdStart, system, config));
} }
catch (PKError) catch (PKError)
{ {
@ -158,10 +156,12 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
return false; return false;
} }
private async ValueTask<bool> TryHandleProxy(MessageCreateEvent evt, Guild guild, Channel channel, private async ValueTask<bool> TryHandleProxy(MessageCreateEvent evt, Guild guild, Channel channel, ulong rootChannel, PermissionSet botPermissions)
MessageContext ctx)
{ {
var botPermissions = await _cache.PermissionsIn(channel.Id); // Get message context from DB (tracking w/ metrics)
MessageContext ctx;
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await _repo.GetMessageContext(evt.Author.Id, evt.GuildId ?? default, rootChannel);
try try
{ {

View File

@ -386,7 +386,7 @@ public class ProxyService
=> _repo.AddMessage(sentMessage); => _repo.AddMessage(sentMessage);
Task LogMessageToChannel() => Task LogMessageToChannel() =>
_logChannel.LogMessage(ctx, sentMessage, triggerMessage, proxyMessage).AsTask(); _logChannel.LogMessage(sentMessage, triggerMessage, proxyMessage).AsTask();
Task SaveLatchAutoproxy() => autoproxySettings.AutoproxyMode == AutoproxyMode.Latch Task SaveLatchAutoproxy() => autoproxySettings.AutoproxyMode == AutoproxyMode.Latch
? _repo.UpdateAutoproxy(ctx.SystemId.Value, triggerMessage.GuildId, null, new() ? _repo.UpdateAutoproxy(ctx.SystemId.Value, triggerMessage.GuildId, null, new()

View File

@ -34,17 +34,16 @@ public class LogChannelService
_logger = logger.ForContext<LogChannelService>(); _logger = logger.ForContext<LogChannelService>();
} }
public async ValueTask LogMessage(MessageContext ctx, PKMessage proxiedMessage, Message trigger, public async ValueTask LogMessage(PKMessage proxiedMessage, Message trigger, Message hookMessage, string oldContent = null)
Message hookMessage, string oldContent = null)
{ {
var logChannelId = await GetAndCheckLogChannel(ctx, trigger, proxiedMessage); var logChannelId = await GetAndCheckLogChannel(trigger, proxiedMessage);
if (logChannelId == null) if (logChannelId == null)
return; return;
var triggerChannel = await _cache.GetChannel(proxiedMessage.Channel); var triggerChannel = await _cache.GetChannel(proxiedMessage.Channel);
var system = await _repo.GetSystem(ctx.SystemId.Value);
var member = await _repo.GetMember(proxiedMessage.Member!.Value); var member = await _repo.GetMember(proxiedMessage.Member!.Value);
var system = await _repo.GetSystem(member.System);
// Send embed! // Send embed!
var embed = _embed.CreateLoggedMessageEmbed(trigger, hookMessage, system.Hid, member, triggerChannel.Name, var embed = _embed.CreateLoggedMessageEmbed(trigger, hookMessage, system.Hid, member, triggerChannel.Name,
@ -54,8 +53,7 @@ public class LogChannelService
await _rest.CreateMessage(logChannelId.Value, new MessageRequest { Content = url, Embeds = new[] { embed } }); await _rest.CreateMessage(logChannelId.Value, new MessageRequest { Content = url, Embeds = new[] { embed } });
} }
private async Task<ulong?> GetAndCheckLogChannel(MessageContext ctx, Message trigger, private async Task<ulong?> GetAndCheckLogChannel(Message trigger, PKMessage proxiedMessage)
PKMessage proxiedMessage)
{ {
if (proxiedMessage.Guild == null && proxiedMessage.Channel != trigger.ChannelId) if (proxiedMessage.Guild == null && proxiedMessage.Channel != trigger.ChannelId)
// a very old message is being edited outside of its original channel // a very old message is being edited outside of its original channel
@ -63,18 +61,15 @@ public class LogChannelService
return null; return null;
var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value; var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value;
var logChannelId = ctx.LogChannel;
var isBlacklisted = ctx.InLogBlacklist;
if (proxiedMessage.Guild != trigger.GuildId) // get log channel info from the database
{ var guild = await _repo.GetGuild(guildId);
// we're editing a message from a different server, get log channel info from the database var logChannelId = guild.LogChannel;
var guild = await _repo.GetGuild(proxiedMessage.Guild.Value); var isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId);
logChannelId = guild.LogChannel;
isBlacklisted = guild.LogBlacklist.Any(x => x == trigger.ChannelId);
}
if (ctx.SystemId == null || logChannelId == null || isBlacklisted) return null; // if (ctx.SystemId == null ||
// removed the above, there shouldn't be a way to get to this code path if you don't have a system registered
if (logChannelId == null || isBlacklisted) return null;
// Find log channel and check if valid // Find log channel and check if valid
var logChannel = await FindLogChannel(guildId, logChannelId.Value); var logChannel = await FindLogChannel(guildId, logChannelId.Value);

View File

@ -9,6 +9,9 @@ public partial class ModelRepository
public async Task<ulong?> GetDmChannel(ulong id) public async Task<ulong?> GetDmChannel(ulong id)
=> await _db.Execute(c => c.QueryFirstOrDefaultAsync<ulong?>("select dm_channel from accounts where uid = @id", new { id = id })); => await _db.Execute(c => c.QueryFirstOrDefaultAsync<ulong?>("select dm_channel from accounts where uid = @id", new { id = id }));
public async Task<bool> GetAutoproxyEnabled(ulong id)
=> await _db.QueryFirst<bool>(new Query("accounts").Select("allow_autoproxy").Where("uid", id));
public async Task UpdateAccount(ulong id, AccountPatch patch) public async Task UpdateAccount(ulong id, AccountPatch patch)
{ {
_logger.Information("Updated account {accountId}: {@AccountPatch}", id, patch); _logger.Information("Updated account {accountId}: {@AccountPatch}", id, patch);