diff --git a/PluralKit.Bot/Handlers/MessageCreated.cs b/PluralKit.Bot/Handlers/MessageCreated.cs index 34422a9a..a152b384 100644 --- a/PluralKit.Bot/Handlers/MessageCreated.cs +++ b/PluralKit.Bot/Handlers/MessageCreated.cs @@ -11,10 +11,6 @@ using DSharpPlus.EventArgs; using PluralKit.Core; -using Sentry; - -using Serilog; - namespace PluralKit.Bot { public class MessageCreated: IEventHandler @@ -22,83 +18,84 @@ namespace PluralKit.Bot private readonly CommandTree _tree; private readonly DiscordShardedClient _client; private readonly LastMessageCacheService _lastMessageCache; - private readonly ILogger _logger; private readonly LoggerCleanService _loggerClean; private readonly IMetrics _metrics; private readonly ProxyService _proxy; - private readonly ProxyCache _proxyCache; - private readonly Scope _sentryScope; private readonly ILifetimeScope _services; - - public MessageCreated(LastMessageCacheService lastMessageCache, ILogger logger, LoggerCleanService loggerClean, IMetrics metrics, ProxyService proxy, ProxyCache proxyCache, Scope sentryScope, DiscordShardedClient client, CommandTree tree, ILifetimeScope services) + private readonly DbConnectionFactory _db; + private readonly IDataStore _data; + + public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean, + IMetrics metrics, ProxyService proxy, DiscordShardedClient client, + CommandTree tree, ILifetimeScope services, DbConnectionFactory db, IDataStore data) { _lastMessageCache = lastMessageCache; - _logger = logger; _loggerClean = loggerClean; _metrics = metrics; _proxy = proxy; - _proxyCache = proxyCache; - _sentryScope = sentryScope; _client = client; _tree = tree; _services = services; + _db = db; + _data = data; } public DiscordChannel ErrorChannelFor(MessageCreateEventArgs evt) => evt.Channel; + private bool IsDuplicateMessage(DiscordMessage evt) => + // We consider a message duplicate if it has the same ID as the previous message that hit the gateway + _lastMessageCache.GetLastMessage(evt.ChannelId) == evt.Id; + public async Task Handle(MessageCreateEventArgs evt) { - // Drop the message if we've already received it. - // Gateway occasionally resends events for whatever reason and it can break stuff relying on IDs being unique - // Not a perfect fix since reordering could still break things but... it's good enough for now - // (was considering checking the order of the IDs but IDs aren't guaranteed to be *perfectly* in order, so that'd cause false positives) - // LastMessageCache is updated in RegisterMessageMetrics so the ordering here is correct. - if (_lastMessageCache.GetLastMessage(evt.Channel.Id) == evt.Message.Id) return; - RegisterMessageMetrics(evt); + if (evt.Message.MessageType != MessageType.Default) return; + if (IsDuplicateMessage(evt.Message)) return; - // Ignore system messages (member joined, message pinned, etc) - var msg = evt.Message; - if (msg.MessageType != MessageType.Default) return; - - var cachedGuild = await _proxyCache.GetGuildDataCached(msg.Channel.GuildId); - var cachedAccount = await _proxyCache.GetAccountDataCached(msg.Author.Id); - // this ^ may be null, do remember that down the line + // Log metrics and message info + _metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived); + _lastMessageCache.AddMessage(evt.Channel.Id, evt.Message.Id); - // Pass guild bot/WH messages onto the logger cleanup service - if (msg.Author.IsBot && msg.Channel.Type == ChannelType.Text) - { - await _loggerClean.HandleLoggerBotCleanup(msg, cachedGuild); - return; - } - - // First try parsing a command, then try proxying - if (await TryHandleCommand(evt, cachedGuild, cachedAccount)) return; - await TryHandleProxy(evt, cachedGuild, cachedAccount); + // Try each handler until we find one that succeeds + var ctx = await _db.Execute(c => c.QueryMessageContext(evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id)); + var _ = await TryHandleLogClean(evt, ctx) || + await TryHandleCommand(evt, ctx) || + await TryHandleProxy(evt, ctx); } - private async Task TryHandleCommand(MessageCreateEventArgs evt, GuildConfig cachedGuild, CachedAccount cachedAccount) + private async ValueTask TryHandleLogClean(MessageCreateEventArgs evt, MessageContext ctx) { - var msg = evt.Message; - - int argPos = -1; + if (!evt.Message.Author.IsBot || evt.Message.Channel.Type != ChannelType.Text || + !ctx.LogCleanupEnabled) return false; + + await _loggerClean.HandleLoggerBotCleanup(evt.Message); + return true; + } + + private async ValueTask TryHandleCommand(MessageCreateEventArgs evt, MessageContext ctx) + { + var content = evt.Message.Content; + if (content == null) return false; + + var argPos = -1; // Check if message starts with the command prefix - if (msg.Content.StartsWith("pk;", StringComparison.InvariantCultureIgnoreCase)) argPos = 3; - else if (msg.Content.StartsWith("pk!", StringComparison.InvariantCultureIgnoreCase)) argPos = 3; - else if (msg.Content != null && StringUtils.HasMentionPrefix(msg.Content, ref argPos, out var id)) // Set argPos to the proper value + if (content.StartsWith("pk;", StringComparison.InvariantCultureIgnoreCase)) argPos = 3; + else if (content.StartsWith("pk!", StringComparison.InvariantCultureIgnoreCase)) argPos = 3; + else if (StringUtils.HasMentionPrefix(content, ref argPos, out var id)) // Set argPos to the proper value if (id != _client.CurrentUser.Id) // But undo it if it's someone else's ping argPos = -1; - + // If we didn't find a prefix, give up handling commands if (argPos == -1) return false; - + // Trim leading whitespace from command without actually modifying the wring // This just moves the argPos pointer by however much whitespace is at the start of the post-argPos string - var trimStartLengthDiff = msg.Content.Substring(argPos).Length - msg.Content.Substring(argPos).TrimStart().Length; + var trimStartLengthDiff = content.Substring(argPos).Length - content.Substring(argPos).TrimStart().Length; argPos += trimStartLengthDiff; try { - await _tree.ExecuteCommand(new Context(_services, evt.Client, msg, argPos, cachedAccount?.System)); + var system = ctx.SystemId != null ? await _data.GetSystemById(ctx.SystemId.Value) : null; + await _tree.ExecuteCommand(new Context(_services, evt.Client, evt.Message, argPos, system)); } catch (PKError) { @@ -109,31 +106,21 @@ namespace PluralKit.Bot return true; } - private async Task TryHandleProxy(MessageCreateEventArgs evt, GuildConfig cachedGuild, CachedAccount cachedAccount) + private async ValueTask TryHandleProxy(MessageCreateEventArgs evt, MessageContext ctx) { - var msg = evt.Message; - - // If we don't have any cached account data, this means no member in the account has a proxy tag set - if (cachedAccount == null) return false; - try { - await _proxy.HandleIncomingMessage(evt.Message, allowAutoproxy: true); + return await _proxy.HandleIncomingMessage(evt.Message, ctx, allowAutoproxy: true); } catch (PKError e) { // User-facing errors, print to the channel properly formatted + var msg = evt.Message; if (msg.Channel.Guild == null || msg.Channel.BotHasAllPermissions(Permissions.SendMessages)) await msg.Channel.SendMessageAsync($"{Emojis.Error} {e.Message}"); } - return true; - } - - private void RegisterMessageMetrics(MessageCreateEventArgs evt) - { - _metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived); - _lastMessageCache.AddMessage(evt.Channel.Id, evt.Message.Id); + return false; } } } \ No newline at end of file diff --git a/PluralKit.Bot/Handlers/MessageEdited.cs b/PluralKit.Bot/Handlers/MessageEdited.cs index c1cef6ce..8791e554 100644 --- a/PluralKit.Bot/Handlers/MessageEdited.cs +++ b/PluralKit.Bot/Handlers/MessageEdited.cs @@ -1,12 +1,9 @@ -using System.Collections.Generic; using System.Threading.Tasks; using DSharpPlus.EventArgs; using PluralKit.Core; -using Sentry; - namespace PluralKit.Bot { @@ -14,38 +11,26 @@ namespace PluralKit.Bot { private readonly LastMessageCacheService _lastMessageCache; private readonly ProxyService _proxy; - private readonly ProxyCache _proxyCache; - private readonly Scope _sentryScope; + private readonly DbConnectionFactory _db; - public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, ProxyCache proxyCache, Scope sentryScope) + public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, DbConnectionFactory db) { _lastMessageCache = lastMessageCache; _proxy = proxy; - _proxyCache = proxyCache; - _sentryScope = sentryScope; + _db = db; } public async Task Handle(MessageUpdateEventArgs evt) { - // Sometimes edit message events arrive for other reasons (eg. an embed gets updated server-side) - // If this wasn't a *content change* (ie. there's message contents to read), bail - // It'll also sometimes arrive with no *author*, so we'll go ahead and ignore those messages too - if (evt.Message.Content == null) return; - if (evt.Author == null) return; + // Edit message events sometimes arrive with missing data; double-check it's all there + if (evt.Message.Content == null || evt.Author == null || evt.Channel.Guild == null) return; - // Also, if this is in DMs don't bother either - if (evt.Channel.Guild == null) return; - - // If this isn't the last message in the channel, don't do anything + // Only react to the last message in the channel if (_lastMessageCache.GetLastMessage(evt.Channel.Id) != evt.Message.Id) return; - // Fetch account and guild info from cache if there is any - var account = await _proxyCache.GetAccountDataCached(evt.Author.Id); - if (account == null) return; // Again: no cache = no account = no system = no proxy - var guild = await _proxyCache.GetGuildDataCached(evt.Channel.GuildId); - - // Just run the normal message handling stuff, with a flag to disable autoproxying - await _proxy.HandleIncomingMessage(evt.Message, allowAutoproxy: false); + // Just run the normal message handling code, with a flag to disable autoproxying + var ctx = await _db.Execute(c => c.QueryMessageContext(evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id)); + await _proxy.HandleIncomingMessage(evt.Message, ctx, allowAutoproxy: false); } } } \ No newline at end of file diff --git a/PluralKit.Bot/Proxy/ProxyMatcher.cs b/PluralKit.Bot/Proxy/ProxyMatcher.cs index f2079913..339e7e8c 100644 --- a/PluralKit.Bot/Proxy/ProxyMatcher.cs +++ b/PluralKit.Bot/Proxy/ProxyMatcher.cs @@ -20,11 +20,11 @@ namespace PluralKit.Bot _clock = clock; } - public bool TryMatch(IReadOnlyCollection members, out ProxyMatch match, string messageContent, + public bool TryMatch(MessageContext ctx, IReadOnlyCollection members, out ProxyMatch match, string messageContent, bool hasAttachments, bool allowAutoproxy) { if (TryMatchTags(members, messageContent, hasAttachments, out match)) return true; - if (allowAutoproxy && TryMatchAutoproxy(members, messageContent, out match)) return true; + if (allowAutoproxy && TryMatchAutoproxy(ctx, members, messageContent, out match)) return true; return false; } @@ -37,33 +37,44 @@ namespace PluralKit.Bot return hasAttachments || match.Content.Length > 0; } - private bool TryMatchAutoproxy(IReadOnlyCollection members, string messageContent, + private bool TryMatchAutoproxy(MessageContext ctx, IReadOnlyCollection members, string messageContent, out ProxyMatch match) { match = default; - // We handle most autoproxy logic in the database function, so we just look for the member that's marked - var info = members.FirstOrDefault(i => i.IsAutoproxyMember); - if (info == null) return false; - - // If we're in latch mode and the latch message is too old, fail the match too - if (info.AutoproxyMode == AutoproxyMode.Latch && info.LatchMessage != null) + // Find the member we should autoproxy (null if none) + var member = ctx.AutoproxyMode switch { - var timestamp = DiscordUtils.SnowflakeToInstant(info.LatchMessage.Value); - if (_clock.GetCurrentInstant() - timestamp > LatchExpiryTime) return false; - } + AutoproxyMode.Member when ctx.AutoproxyMember != null => + members.FirstOrDefault(m => m.Id == ctx.AutoproxyMember), + + AutoproxyMode.Front when ctx.LastSwitchMembers.Count > 0 => + members.FirstOrDefault(m => m.Id == ctx.LastSwitchMembers[0]), + + AutoproxyMode.Latch when ctx.LastMessageMember != null && !IsLatchExpired(ctx.LastMessage) => + members.FirstOrDefault(m => m.Id == ctx.LastMessageMember.Value), + + _ => null + }; - // Match succeeded, build info object and return + if (member == null) return false; match = new ProxyMatch { Content = messageContent, - Member = info, - + Member = member, + // We're autoproxying, so not using any proxy tags here // we just find the first pair of tags (if any), otherwise null - ProxyTags = info.ProxyTags.FirstOrDefault() + ProxyTags = member.ProxyTags.FirstOrDefault() }; return true; } + + private bool IsLatchExpired(ulong? messageId) + { + if (messageId == null) return true; + var timestamp = DiscordUtils.SnowflakeToInstant(messageId.Value); + return _clock.GetCurrentInstant() - timestamp > LatchExpiryTime; + } } } \ No newline at end of file diff --git a/PluralKit.Bot/Proxy/ProxyService.cs b/PluralKit.Bot/Proxy/ProxyService.cs index 34b21e03..fe2e5881 100644 --- a/PluralKit.Bot/Proxy/ProxyService.cs +++ b/PluralKit.Bot/Proxy/ProxyService.cs @@ -20,11 +20,11 @@ namespace PluralKit.Bot { public static readonly TimeSpan MessageDeletionDelay = TimeSpan.FromMilliseconds(1000); - private LogChannelService _logChannel; - private DbConnectionFactory _db; - private IDataStore _data; - private ILogger _logger; - private WebhookExecutorService _webhookExecutor; + private readonly LogChannelService _logChannel; + private readonly DbConnectionFactory _db; + private readonly IDataStore _data; + private readonly ILogger _logger; + private readonly WebhookExecutorService _webhookExecutor; private readonly ProxyMatcher _matcher; public ProxyService(LogChannelService logChannel, IDataStore data, ILogger logger, @@ -38,35 +38,57 @@ namespace PluralKit.Bot _logger = logger.ForContext(); } - public async Task HandleIncomingMessage(DiscordMessage message, bool allowAutoproxy) + public async Task HandleIncomingMessage(DiscordMessage message, MessageContext ctx, bool allowAutoproxy) { - // Quick context checks to quit early - if (!IsMessageValid(message)) return; + if (!ShouldProxy(message, ctx)) return false; // Fetch members and try to match to a specific member - var members = await FetchProxyMembers(message.Author.Id, message.Channel.GuildId); - if (!_matcher.TryMatch(members, out var match, message.Content, message.Attachments.Count > 0, - allowAutoproxy)) return; + var members = (await _db.Execute(c => c.QueryProxyMembers(message.Author.Id, message.Channel.GuildId))).ToList(); + if (!_matcher.TryMatch(ctx, members, out var match, message.Content, message.Attachments.Count > 0, + allowAutoproxy)) return false; - // Do some quick permission checks before going through with the proxy - // (do channel checks *after* checking other perms to make sure we don't spam errors when eg. channel is blacklisted) - if (!IsProxyValid(message, match)) return; - if (!await CheckBotPermissionsOrError(message.Channel)) return; - if (!CheckProxyNameBoundsOrError(match)) return; + // Permission check after proxy match so we don't get spammed when not actually proxying + if (!await CheckBotPermissionsOrError(message.Channel)) return false; + if (!CheckProxyNameBoundsOrError(match)) return false; // Everything's in order, we can execute the proxy! - await ExecuteProxy(message, match); + await ExecuteProxy(message, ctx, match); + return true; } - - private async Task ExecuteProxy(DiscordMessage trigger, ProxyMatch match) + + private bool ShouldProxy(DiscordMessage msg, MessageContext ctx) + { + // Make sure author has a system + if (ctx.SystemId == null) return false; + + // Make sure channel is a guild text channel and this is a normal message + if (msg.Channel.Type != ChannelType.Text || msg.MessageType != MessageType.Default) return false; + + // Make sure author is a normal user + if (msg.Author.IsSystem == true || msg.Author.IsBot || msg.WebhookMessage) return false; + + // Make sure proxying is enabled here + if (!ctx.ProxyEnabled || ctx.InBlacklist) return false; + + // Make sure we have either an attachment or message content + var isMessageBlank = msg.Content == null || msg.Content.Trim().Length == 0; + if (isMessageBlank && msg.Attachments.Count == 0) return false; + + // All good! + return true; + } + + private async Task ExecuteProxy(DiscordMessage trigger, MessageContext ctx, ProxyMatch match) { // Send the webhook - var id = await _webhookExecutor.ExecuteWebhook(trigger.Channel, match.Member.ProxyName, match.Member.ProxyAvatar, + var id = await _webhookExecutor.ExecuteWebhook(trigger.Channel, match.Member.ProxyName, + match.Member.ProxyAvatar, match.Content, trigger.Attachments); - + // Handle post-proxy actions - await _data.AddMessage(trigger.Author.Id, trigger.Channel.GuildId, trigger.Channel.Id, id, trigger.Id, match.Member.MemberId); - await _logChannel.LogMessage(match, trigger, id); + await _data.AddMessage(trigger.Author.Id, trigger.Channel.GuildId, trigger.Channel.Id, id, trigger.Id, + match.Member.Id); + await _logChannel.LogMessage(ctx, match, trigger, id); // Wait a second or so before deleting the original message await Task.Delay(MessageDeletionDelay); @@ -81,14 +103,6 @@ namespace PluralKit.Bot } } - private async Task> FetchProxyMembers(ulong account, ulong guild) - { - await using var conn = await _db.Obtain(); - var members = await conn.QueryAsync("proxy_info", - new {account_id = account, guild_id = guild}, commandType: CommandType.StoredProcedure); - return members.ToList(); - } - private async Task CheckBotPermissionsOrError(DiscordChannel channel) { var permissions = channel.BotPermissions(); @@ -114,43 +128,15 @@ namespace PluralKit.Bot return true; } - + private bool CheckProxyNameBoundsOrError(ProxyMatch match) { var proxyName = match.Member.ProxyName; if (proxyName.Length < 2) throw Errors.ProxyNameTooShort(proxyName); if (proxyName.Length > Limits.MaxProxyNameLength) throw Errors.ProxyNameTooLong(proxyName); - + // TODO: this never returns false as it throws instead, should this happen? return true; } - - private bool IsMessageValid(DiscordMessage message) - { - return - // Must be a guild text channel - message.Channel.Type == ChannelType.Text && - - // Must not be a system message - message.MessageType == MessageType.Default && - !(message.Author.IsSystem ?? false) && - - // Must not be a bot or webhook message - !message.WebhookMessage && - !message.Author.IsBot && - - // Must have either an attachment or content (or both, but not neither) - (message.Attachments.Count > 0 || (message.Content != null && message.Content.Trim().Length > 0)); - } - - private bool IsProxyValid(DiscordMessage message, ProxyMatch match) - { - return - // System and member must have proxying enabled in this guild - match.Member.ProxyEnabled && - - // Channel must not be blacklisted here - !match.Member.ChannelBlacklist.Contains(message.ChannelId); - } } } \ No newline at end of file diff --git a/PluralKit.Bot/Services/LogChannelService.cs b/PluralKit.Bot/Services/LogChannelService.cs index cf14124b..12099d33 100644 --- a/PluralKit.Bot/Services/LogChannelService.cs +++ b/PluralKit.Bot/Services/LogChannelService.cs @@ -28,12 +28,12 @@ namespace PluralKit.Bot { _logger = logger.ForContext(); } - public async Task LogMessage(ProxyMatch proxy, DiscordMessage trigger, ulong hookMessage) + public async ValueTask LogMessage(MessageContext ctx, ProxyMatch proxy, DiscordMessage trigger, ulong hookMessage) { - if (proxy.Member.LogChannel == null || proxy.Member.LogBlacklist.Contains(trigger.ChannelId)) return; + if (ctx.SystemId == null || ctx.LogChannel == null || ctx.InLogBlacklist) return; // Find log channel and check if valid - var logChannel = await FindLogChannel(trigger.Channel.GuildId, proxy); + var logChannel = await FindLogChannel(trigger.Channel.GuildId, ctx.LogChannel.Value); if (logChannel == null || logChannel.Type != ChannelType.Text) return; // Check bot permissions @@ -41,25 +41,23 @@ namespace PluralKit.Bot { // Send embed! await using var conn = await _db.Obtain(); - var embed = _embed.CreateLoggedMessageEmbed(await _data.GetSystemById(proxy.Member.SystemId), - await _data.GetMemberById(proxy.Member.MemberId), hookMessage, trigger.Id, trigger.Author, proxy.Content, + var embed = _embed.CreateLoggedMessageEmbed(await _data.GetSystemById(ctx.SystemId.Value), + await _data.GetMemberById(proxy.Member.Id), hookMessage, trigger.Id, trigger.Author, proxy.Content, trigger.Channel); var url = $"https://discord.com/channels/{trigger.Channel.GuildId}/{trigger.ChannelId}/{hookMessage}"; await logChannel.SendMessageAsync(content: url, embed: embed); } - private async Task FindLogChannel(ulong guild, ProxyMatch proxy) + private async Task FindLogChannel(ulong guild, ulong channel) { - var logChannel = proxy.Member.LogChannel.Value; - try { - return await _rest.GetChannelAsync(logChannel); + return await _rest.GetChannelAsync(channel); } catch (NotFoundException) { // Channel doesn't exist, let's remove it from the database too - _logger.Warning("Attempted to fetch missing log channel {LogChannel}, removing from database", logChannel); + _logger.Warning("Attempted to fetch missing log channel {LogChannel}, removing from database", channel); await using var conn = await _db.Obtain(); await conn.ExecuteAsync("update servers set log_channel = null where server = @Guild", new {Guild = guild}); diff --git a/PluralKit.Bot/Services/LoggerCleanService.cs b/PluralKit.Bot/Services/LoggerCleanService.cs index 4296c774..cf196a91 100644 --- a/PluralKit.Bot/Services/LoggerCleanService.cs +++ b/PluralKit.Bot/Services/LoggerCleanService.cs @@ -64,10 +64,8 @@ namespace PluralKit.Bot public ICollection Bots => _bots.Values; - public async ValueTask HandleLoggerBotCleanup(DiscordMessage msg, GuildConfig cachedGuild) + public async ValueTask HandleLoggerBotCleanup(DiscordMessage msg) { - // Bail if not enabled, or if we don't have permission here - if (!cachedGuild.LogCleanupEnabled) return; if (msg.Channel.Type != ChannelType.Text) return; if (!msg.Channel.BotHasAllPermissions(Permissions.ManageMessages)) return; diff --git a/PluralKit.Core/Database/Functions/DatabaseFunctionsExt.cs b/PluralKit.Core/Database/Functions/DatabaseFunctionsExt.cs new file mode 100644 index 00000000..fd7d7728 --- /dev/null +++ b/PluralKit.Core/Database/Functions/DatabaseFunctionsExt.cs @@ -0,0 +1,25 @@ +using System.Collections.Generic; +using System.Data; +using System.Threading.Tasks; + +using Dapper; + +namespace PluralKit.Core +{ + public static class DatabaseFunctionsExt + { + public static Task QueryMessageContext(this IDbConnection conn, ulong account, ulong guild, ulong channel) + { + return conn.QueryFirstAsync("message_context", + new { account_id = account, guild_id = guild, channel_id = channel }, + commandType: CommandType.StoredProcedure); + } + + public static Task> QueryProxyMembers(this IDbConnection conn, ulong account, ulong guild) + { + return conn.QueryAsync("proxy_members", + new { account_id = account, guild_id = guild }, + commandType: CommandType.StoredProcedure); + } + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Functions/MessageContext.cs b/PluralKit.Core/Database/Functions/MessageContext.cs new file mode 100644 index 00000000..c9f31af0 --- /dev/null +++ b/PluralKit.Core/Database/Functions/MessageContext.cs @@ -0,0 +1,27 @@ +#nullable enable +using System.Collections.Generic; + +using NodaTime; + +namespace PluralKit.Core +{ + /// + /// Model for the `message_context` PL/pgSQL function in `functions.sql` + /// + public class MessageContext + { + public int? SystemId { get; set; } + public ulong? LogChannel { get; set; } + public bool InBlacklist { get; set; } + public bool InLogBlacklist { get; set; } + public bool LogCleanupEnabled { get; set; } + public bool ProxyEnabled { get; set; } + public AutoproxyMode AutoproxyMode { get; set; } + public int? AutoproxyMember { get; set; } + public ulong? LastMessage { get; set; } + public int? LastMessageMember { get; set; } + public int LastSwitch { get; set; } + public IReadOnlyList LastSwitchMembers { get; set; } = new int[0]; + public Instant LastSwitchTimestamp { get; set; } + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Functions/ProxyMember.cs b/PluralKit.Core/Database/Functions/ProxyMember.cs new file mode 100644 index 00000000..9466edc8 --- /dev/null +++ b/PluralKit.Core/Database/Functions/ProxyMember.cs @@ -0,0 +1,17 @@ +#nullable enable +using System.Collections.Generic; + +namespace PluralKit.Core +{ + /// + /// Model for the `proxy_members` PL/pgSQL function in `functions.sql` + /// + public class ProxyMember + { + public int Id { get; set; } + public IReadOnlyCollection ProxyTags { get; set; } = new ProxyTag[0]; + public bool KeepProxy { get; set; } + public string ProxyName { get; set; } = ""; + public string? ProxyAvatar { get; set; } + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Functions/functions.sql b/PluralKit.Core/Database/Functions/functions.sql new file mode 100644 index 00000000..4b65afd0 --- /dev/null +++ b/PluralKit.Core/Database/Functions/functions.sql @@ -0,0 +1,73 @@ +create function message_context(account_id bigint, guild_id bigint, channel_id bigint) + returns table ( + system_id int, + log_channel bigint, + in_blacklist bool, + in_log_blacklist bool, + log_cleanup_enabled bool, + proxy_enabled bool, + autoproxy_mode int, + autoproxy_member int, + last_message bigint, + last_message_member int, + last_switch int, + last_switch_members int[], + last_switch_timestamp timestamp + ) +as $$ + with + system as (select systems.* from accounts inner join systems on systems.id = accounts.system where accounts.uid = account_id), + guild as (select * from servers where id = guild_id), + last_message as (select * from messages where messages.guild = guild_id and messages.sender = account_id order by mid desc limit 1) + select + system.id as system_id, + guild.log_channel, + (channel_id = any(guild.blacklist)) as in_blacklist, + (channel_id = any(guild.log_blacklist)) as in_log_blacklist, + guild.log_cleanup_enabled, + coalesce(system_guild.proxy_enabled, true) as proxy_enabled, + coalesce(system_guild.autoproxy_mode, 1) as autoproxy_mode, + system_guild.autoproxy_member, + last_message.mid as last_message, + last_message.member as last_message_member, + system_last_switch.switch as last_switch, + system_last_switch.members as last_switch_members, + system_last_switch.timestamp as last_switch_timestamp + from system + full join guild on true + left join last_message on true + + left join system_last_switch on system_last_switch.system = system.id + left join system_guild on system_guild.system = system.id and system_guild.guild = guild_id +$$ language sql stable rows 1; + + +-- Fetches info about proxying related to a given account/guild +-- Returns one row per member in system, should be used in conjuction with `message_context` too +create function proxy_members(account_id bigint, guild_id bigint) + returns table ( + id int, + proxy_tags proxy_tag[], + keep_proxy bool, + proxy_name text, + proxy_avatar text + ) +as $$ + select + -- Basic data + members.id as id, + members.proxy_tags as proxy_tags, + members.keep_proxy as keep_proxy, + + -- Proxy info + case + when systems.tag is not null then (coalesce(member_guild.display_name, members.display_name, members.name) || ' ' || systems.tag) + else coalesce(member_guild.display_name, members.display_name, members.name) + end as proxy_name, + coalesce(member_guild.avatar_url, members.avatar_url, systems.avatar_url) as proxy_avatar + from accounts + inner join systems on systems.id = accounts.system + inner join members on members.system = systems.id + left join member_guild on member_guild.member = members.id and member_guild.guild = guild_id + where accounts.uid = account_id +$$ language sql stable rows 10; \ No newline at end of file diff --git a/PluralKit.Core/Database/ProxyMember.cs b/PluralKit.Core/Database/ProxyMember.cs deleted file mode 100644 index 97361cac..00000000 --- a/PluralKit.Core/Database/ProxyMember.cs +++ /dev/null @@ -1,26 +0,0 @@ -#nullable enable -using System.Collections.Generic; - -namespace PluralKit.Core -{ - /// - /// Model for the `proxy_info` PL/pgSQL function in `functions.sql` - /// - public class ProxyMember - { - public int SystemId { get; set; } - public int MemberId { get; set; } - public bool ProxyEnabled { get; set; } - public AutoproxyMode AutoproxyMode { get; set; } - public bool IsAutoproxyMember { get; set; } - public ulong? LatchMessage { get; set; } - public string ProxyName { get; set; } = ""; - public string? ProxyAvatar { get; set; } - public IReadOnlyCollection ProxyTags { get; set; } = new ProxyTag[0]; - public bool KeepProxy { get; set; } - - public IReadOnlyCollection ChannelBlacklist { get; set; } = new ulong[0]; - public IReadOnlyCollection LogBlacklist { get; set; } = new ulong[0]; - public ulong? LogChannel { get; set; } - } -} \ No newline at end of file diff --git a/PluralKit.Core/Database/Schemas.cs b/PluralKit.Core/Database/Schemas.cs index 2882c2f3..280fb214 100644 --- a/PluralKit.Core/Database/Schemas.cs +++ b/PluralKit.Core/Database/Schemas.cs @@ -46,7 +46,7 @@ namespace PluralKit.Core // Now, reapply views/functions (we deleted them above, no need to worry about conflicts) await ExecuteSqlFile($"{RootPath}.views.sql", conn, tx); - await ExecuteSqlFile($"{RootPath}.functions.sql", conn, tx); + await ExecuteSqlFile($"{RootPath}.Functions.functions.sql", conn, tx); // Finally, commit tx tx.Commit(); diff --git a/PluralKit.Core/Database/clean.sql b/PluralKit.Core/Database/clean.sql index b03b2ee2..d1a197a4 100644 --- a/PluralKit.Core/Database/clean.sql +++ b/PluralKit.Core/Database/clean.sql @@ -1,3 +1,5 @@ drop view if exists system_last_switch; drop view if exists member_list; -drop function if exists proxy_info; + +drop function if exists message_context; +drop function if exists proxy_members; \ No newline at end of file diff --git a/PluralKit.Core/Database/functions.sql b/PluralKit.Core/Database/functions.sql deleted file mode 100644 index 2f80ad9c..00000000 --- a/PluralKit.Core/Database/functions.sql +++ /dev/null @@ -1,85 +0,0 @@ --- Giant "mega-function" to find all information relevant for message proxying --- Returns one row per member, computes several properties from others -create function proxy_info(account_id bigint, guild_id bigint) - returns table - ( - -- Note: table type gets matched *by index*, not *by name* (make sure order here and in `select` match) - system_id int, -- from: systems.id - member_id int, -- from: members.id - proxy_tags proxy_tag[], -- from: members.proxy_tags - keep_proxy bool, -- from: members.keep_proxy - proxy_enabled bool, -- from: system_guild.proxy_enabled - proxy_name text, -- calculated: name we should proxy under - proxy_avatar text, -- calculated: avatar we should proxy with - autoproxy_mode int, -- from: system_guild.autoproxy_mode - is_autoproxy_member bool, -- calculated: should this member be used for AP? - latch_message bigint, -- calculated: last message from this account in this guild - channel_blacklist bigint[], -- from: servers.blacklist - log_blacklist bigint[], -- from: servers.log_blacklist - log_channel bigint -- from: servers.log_channel - ) -as -$$ -select - -- Basic data - systems.id as system_id, - members.id as member_id, - members.proxy_tags as proxy_tags, - members.keep_proxy as keep_proxy, - - -- Proxy info - coalesce(system_guild.proxy_enabled, true) as proxy_enabled, - case - when systems.tag is not null then (coalesce(member_guild.display_name, members.display_name, members.name) || ' ' || systems.tag) - else coalesce(member_guild.display_name, members.display_name, members.name) - end as proxy_name, - coalesce(member_guild.avatar_url, members.avatar_url, systems.avatar_url) as proxy_avatar, - - -- Autoproxy data - coalesce(system_guild.autoproxy_mode, 1) as autoproxy_mode, - - -- Autoproxy logic is essentially: "is this member the one we should autoproxy?" - case - -- Front mode: check if this is the first fronter - when system_guild.autoproxy_mode = 2 then members.id = (select sls.members[1] - from system_last_switch as sls - where sls.system = systems.id) - - -- Latch mode: check if this is the last proxier - when system_guild.autoproxy_mode = 3 then members.id = last_message_in_guild.member - - -- Member mode: check if this is the selected memebr - when system_guild.autoproxy_mode = 4 then members.id = system_guild.autoproxy_member - - -- no autoproxy: then this member definitely shouldn't be autoproxied :) - else false end as is_autoproxy_member, - - last_message_in_guild.mid as latch_message, - - -- Guild info - coalesce(servers.blacklist, array[]::bigint[]) as channel_blacklist, - coalesce(servers.log_blacklist, array[]::bigint[]) as log_blacklist, - servers.log_channel as log_channel -from accounts - -- Fetch guild info - left join servers on servers.id = guild_id - - -- Fetch the system for this account (w/ guild config) - inner join systems on systems.id = accounts.system - left join system_guild on system_guild.system = accounts.system and system_guild.guild = guild_id - - -- Fetch all members from this system (w/ guild config) - inner join members on members.system = systems.id - left join member_guild on member_guild.member = members.id and member_guild.guild = guild_id - - -- Find ID and member for the last message sent in this guild - left join lateral (select mid, member - from messages - where messages.guild = guild_id - and messages.sender = account_id - and system_guild.autoproxy_mode = 3 - order by mid desc - limit 1) as last_message_in_guild on true -where accounts.uid = account_id; -$$ language sql stable - rows 10; \ No newline at end of file diff --git a/PluralKit.Core/Modules.cs b/PluralKit.Core/Modules.cs index 0a032451..bd4b82db 100644 --- a/PluralKit.Core/Modules.cs +++ b/PluralKit.Core/Modules.cs @@ -28,7 +28,6 @@ namespace PluralKit.Core builder.RegisterType().AsSelf(); builder.Populate(new ServiceCollection().AddMemoryCache()); - builder.RegisterType().AsSelf().SingleInstance(); } } diff --git a/PluralKit.Core/Services/PostgresDataStore.cs b/PluralKit.Core/Services/PostgresDataStore.cs index 0bcdd5fa..c81d3fb5 100644 --- a/PluralKit.Core/Services/PostgresDataStore.cs +++ b/PluralKit.Core/Services/PostgresDataStore.cs @@ -12,13 +12,11 @@ namespace PluralKit.Core { public class PostgresDataStore: IDataStore { private DbConnectionFactory _conn; private ILogger _logger; - private ProxyCache _cache; - public PostgresDataStore(DbConnectionFactory conn, ILogger logger, ProxyCache cache) + public PostgresDataStore(DbConnectionFactory conn, ILogger logger) { _conn = conn; _logger = logger; - _cache = cache; } public async Task> GetConflictingProxies(PKSystem system, ProxyTag tag) @@ -56,7 +54,6 @@ namespace PluralKit.Core { settings.AutoproxyMode, settings.AutoproxyMember }); - await _cache.InvalidateSystem(system); _logger.Information("Updated system guild settings {@SystemGuildSettings}", settings); } @@ -83,7 +80,6 @@ namespace PluralKit.Core { await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing", new { Id = accountId, SystemId = system.Id }); _logger.Information("Linked system {System} to account {Account}", system.Id, accountId); - await _cache.InvalidateSystem(system); } public async Task RemoveAccount(PKSystem system, ulong accountId) { @@ -91,8 +87,6 @@ namespace PluralKit.Core { await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", new { Id = accountId, SystemId = system.Id }); _logger.Information("Unlinked system {System} from account {Account}", system.Id, accountId); - await _cache.InvalidateSystem(system); - _cache.InvalidateAccounts(new [] { accountId }); } public async Task GetSystemByAccount(ulong accountId) { @@ -121,7 +115,6 @@ namespace PluralKit.Core { await conn.ExecuteAsync("update systems set name = @Name, description = @Description, tag = @Tag, avatar_url = @AvatarUrl, token = @Token, ui_tz = @UiTz, description_privacy = @DescriptionPrivacy, member_list_privacy = @MemberListPrivacy, front_privacy = @FrontPrivacy, front_history_privacy = @FrontHistoryPrivacy, pings_enabled = @PingsEnabled where id = @Id", system); _logger.Information("Updated system {@System}", system); - await _cache.InvalidateSystem(system); } public async Task DeleteSystem(PKSystem system) @@ -133,7 +126,6 @@ namespace PluralKit.Core { await conn.ExecuteAsync("delete from systems where id = @Id", system); _logger.Information("Deleted system {System}", system.Id); - _cache.InvalidateDeletedSystem(system.Id, accounts); } public async Task> GetSystemAccounts(PKSystem system) @@ -170,7 +162,6 @@ namespace PluralKit.Core { }); _logger.Information("Created member {Member}", member.Id); - await _cache.InvalidateSystem(system); return member; } @@ -202,7 +193,6 @@ namespace PluralKit.Core { await conn.ExecuteAsync("update members set name = @Name, display_name = @DisplayName, description = @Description, color = @Color, avatar_url = @AvatarUrl, birthday = @Birthday, pronouns = @Pronouns, proxy_tags = @ProxyTags, keep_proxy = @KeepProxy, member_privacy = @MemberPrivacy where id = @Id", member); _logger.Information("Updated member {@Member}", member); - await _cache.InvalidateSystem(member.System); } public async Task DeleteMember(PKMember member) { @@ -210,7 +200,6 @@ namespace PluralKit.Core { await conn.ExecuteAsync("delete from members where id = @Id", member); _logger.Information("Deleted member {@Member}", member); - await _cache.InvalidateSystem(member.System); } public async Task GetMemberGuildSettings(PKMember member, ulong guild) @@ -228,7 +217,6 @@ namespace PluralKit.Core { "insert into member_guild (member, guild, display_name, avatar_url) values (@Member, @Guild, @DisplayName, @AvatarUrl) on conflict (member, guild) do update set display_name = @DisplayName, avatar_url = @AvatarUrl", settings); _logger.Information("Updated member guild settings {@MemberGuildSettings}", settings); - await _cache.InvalidateSystem(member.System); } public async Task GetSystemMemberCount(PKSystem system, bool includePrivate) @@ -350,7 +338,6 @@ namespace PluralKit.Core { Blacklist = cfg.Blacklist.Select(c => (long) c).ToList() }); _logger.Information("Updated guild configuration {@GuildCfg}", cfg); - _cache.InvalidateGuild(cfg.Id); } public async Task GetFirstFronter(PKSystem system) diff --git a/PluralKit.Core/Services/ProxyCacheService.cs b/PluralKit.Core/Services/ProxyCacheService.cs deleted file mode 100644 index bfd80f4a..00000000 --- a/PluralKit.Core/Services/ProxyCacheService.cs +++ /dev/null @@ -1,196 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; - -using Dapper; - -using Microsoft.Extensions.Caching.Memory; - -using Serilog; - -namespace PluralKit.Core -{ - public class ProxyCache - { - // We can NOT depend on IDataStore as that creates a cycle, since it needs access to call the invalidation methods - private IMemoryCache _cache; - private DbConnectionFactory _db; - private ILogger _logger; - - public ProxyCache(IMemoryCache cache, DbConnectionFactory db, ILogger logger) - { - _cache = cache; - _db = db; - _logger = logger; - } - - public Task InvalidateSystem(PKSystem system) => InvalidateSystem(system.Id); - - public void InvalidateAccounts(IEnumerable accounts) - { - foreach (var account in accounts) - _cache.Remove(KeyForAccount(account)); - } - - public void InvalidateDeletedSystem(int systemId, IEnumerable accounts) - { - // Used when the system's already removed so we can't look up accounts - // We assume the account list is saved already somewhere and can be passed here (which is the case in Store) - - _cache.Remove(KeyForSystem(systemId)); - InvalidateAccounts(accounts); - } - - public async Task InvalidateSystem(int systemId) - { - if (_cache.TryGetValue(KeyForSystem(systemId), out var systemCache)) - { - // If we have the system cached here, just invalidate for all the accounts we have in the cache - _logger.Debug("Invalidating cache for system {System} and accounts {Accounts}", systemId, systemCache.Accounts); - _cache.Remove(KeyForSystem(systemId)); - foreach (var account in systemCache.Accounts) - _cache.Remove(KeyForAccount(account)); - return; - } - - // If we don't, look up the accounts from the database and invalidate *those* - - _cache.Remove(KeyForSystem(systemId)); - using var conn = await _db.Obtain(); - var accounts = (await conn.QueryAsync("select uid from accounts where system = @System", new {System = systemId})).ToArray(); - _logger.Debug("Invalidating cache for system {System} and accounts {Accounts}", systemId, accounts); - foreach (var account in accounts) - _cache.Remove(KeyForAccount(account)); - } - - public void InvalidateGuild(ulong guild) - { - _logger.Debug("Invalidating cache for guild {Guild}", guild); - _cache.Remove(KeyForGuild(guild)); - } - - public async Task GetGuildDataCached(ulong guild) - { - if (_cache.TryGetValue(KeyForGuild(guild), out var item)) - { - _logger.Verbose("Cache hit for guild {Guild}", guild); - return item; - } - - // When changing this, also see PostgresDataStore::GetOrCreateGuildConfig - using var conn = await _db.Obtain(); - - _logger.Verbose("Cache miss for guild {Guild}", guild); - var guildConfig = (await conn.QuerySingleOrDefaultAsync( - "insert into servers (id) values (@Id) on conflict do nothing; select * from servers where id = @Id", - new {Id = guild})).Into(); - - _cache.CreateEntry(KeyForGuild(guild)) - .SetValue(guildConfig) - .SetSlidingExpiration(TimeSpan.FromMinutes(5)) - .SetAbsoluteExpiration(TimeSpan.FromMinutes(30)) - .Dispose(); // Don't ask, but this *saves* the entry. Somehow. - return guildConfig; - } - - public async Task GetAccountDataCached(ulong account) - { - if (_cache.TryGetValue(KeyForAccount(account), out var item)) - { - _logger.Verbose("Cache hit for account {Account}", account); - return item; - } - - _logger.Verbose("Cache miss for account {Account}", account); - - var data = await GetAccountData(account); - if (data == null) - { - _logger.Debug("Cached data for account {Account} (no system)", account); - - // If we didn't find any value, set a pretty long expiry and the value to null - _cache.CreateEntry(KeyForAccount(account)) - .SetValue(null) - .SetSlidingExpiration(TimeSpan.FromMinutes(5)) - .SetAbsoluteExpiration(TimeSpan.FromHours(1)) - .Dispose(); // Don't ask, but this *saves* the entry. Somehow. - return null; - } - - // If we *did* find the value, cache it for *every account in the system* with a shorter expiry - _logger.Debug("Cached data for system {System} and accounts {Account}", data.System.Id, data.Accounts); - foreach (var linkedAccount in data.Accounts) - { - _cache.CreateEntry(KeyForAccount(linkedAccount)) - .SetValue(data) - .SetSlidingExpiration(TimeSpan.FromMinutes(5)) - .SetAbsoluteExpiration(TimeSpan.FromMinutes(20)) - .Dispose(); // Don't ask, but this *saves* the entry. Somehow. - - // And also do it for the system itself so we can look up by that - _cache.CreateEntry(KeyForSystem(data.System.Id)) - .SetValue(data) - .SetSlidingExpiration(TimeSpan.FromMinutes(5)) - .SetAbsoluteExpiration(TimeSpan.FromMinutes(20)) - .Dispose(); // Don't ask, but this *saves* the entry. Somehow. - } - - return data; - } - - private async Task GetAccountData(ulong account) - { - using var conn = await _db.Obtain(); - - // Doing this as two queries instead of a two-step join to avoid sending duplicate rows for the system over the network for each member - // This *may* be less efficient, haven't done too much stuff about this but having the system ID saved is very useful later on - - var system = await conn.QuerySingleOrDefaultAsync("select systems.* from accounts inner join systems on systems.id = accounts.system where accounts.uid = @Account", new { Account = account }); - if (system == null) return null; // No system = no members = no cache value - - // Fetches: - // - List of accounts in the system - // - List of members in the system - // - List of guild settings for the system (for every guild) - // - List of guild settings for each member (for every guild) - // I'm slightly worried the volume of guild settings will get too much, but for simplicity reasons I decided - // against caching them individually per-guild, since I can't imagine they'll be edited *that* much - var result = await conn.QueryMultipleAsync(@" - select uid from accounts where system = @System; - select * from members where system = @System; - select * from system_guild where system = @System; - select member_guild.* from members inner join member_guild on member_guild.member = members.id where members.system = @System; - ", new {System = system.Id}); - - return new CachedAccount - { - System = system, - Accounts = (await result.ReadAsync()).ToArray(), - Members = (await result.ReadAsync()).ToArray(), - SystemGuild = (await result.ReadAsync()).ToArray(), - MemberGuild = (await result.ReadAsync()).ToArray() - }; - } - - private string KeyForAccount(ulong account) => $"_account_cache_{account}"; - private string KeyForSystem(int system) => $"_system_cache_{system}"; - private string KeyForGuild(ulong guild) => $"_guild_cache_{guild}"; - } - - public class CachedAccount - { - public PKSystem System; - public PKMember[] Members; - public SystemGuildSettings[] SystemGuild; - public MemberGuildSettings[] MemberGuild; - public ulong[] Accounts; - - public SystemGuildSettings SettingsForGuild(ulong guild) => - // O(n) lookup since n is small (max ~100 in prod) and we're more constrained by memory (for a dictionary) here - SystemGuild.FirstOrDefault(s => s.Guild == guild) ?? new SystemGuildSettings(); - - public MemberGuildSettings SettingsForMemberGuild(int memberId, ulong guild) => - MemberGuild.FirstOrDefault(m => m.Member == memberId && m.Guild == guild) ?? new MemberGuildSettings(); - } -} \ No newline at end of file diff --git a/PluralKit.Core/Utils/DatabaseUtils.cs b/PluralKit.Core/Utils/DatabaseUtils.cs index 710cac85..b7d90705 100644 --- a/PluralKit.Core/Utils/DatabaseUtils.cs +++ b/PluralKit.Core/Utils/DatabaseUtils.cs @@ -308,4 +308,13 @@ namespace PluralKit.Core parameter.Value = (long) value; } } + + public static class DatabaseExt + { + public static async Task Execute(this DbConnectionFactory db, Func> func) + { + await using var conn = await db.Obtain(); + return await func(conn); + } + } } \ No newline at end of file