Split message/proxy data up in MessageContext and ProxyMember

This commit is contained in:
Ske 2020-06-12 23:13:21 +02:00
parent ba441a15cc
commit 3d62a0d33c
18 changed files with 296 additions and 499 deletions

View File

@ -11,10 +11,6 @@ using DSharpPlus.EventArgs;
using PluralKit.Core;
using Sentry;
using Serilog;
namespace PluralKit.Bot
{
public class MessageCreated: IEventHandler<MessageCreateEventArgs>
@ -22,69 +18,69 @@ namespace PluralKit.Bot
private readonly CommandTree _tree;
private readonly DiscordShardedClient _client;
private readonly LastMessageCacheService _lastMessageCache;
private readonly ILogger _logger;
private readonly LoggerCleanService _loggerClean;
private readonly IMetrics _metrics;
private readonly ProxyService _proxy;
private readonly ProxyCache _proxyCache;
private readonly Scope _sentryScope;
private readonly ILifetimeScope _services;
private readonly DbConnectionFactory _db;
private readonly IDataStore _data;
public MessageCreated(LastMessageCacheService lastMessageCache, ILogger logger, LoggerCleanService loggerClean, IMetrics metrics, ProxyService proxy, ProxyCache proxyCache, Scope sentryScope, DiscordShardedClient client, CommandTree tree, ILifetimeScope services)
public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean,
IMetrics metrics, ProxyService proxy, DiscordShardedClient client,
CommandTree tree, ILifetimeScope services, DbConnectionFactory db, IDataStore data)
{
_lastMessageCache = lastMessageCache;
_logger = logger;
_loggerClean = loggerClean;
_metrics = metrics;
_proxy = proxy;
_proxyCache = proxyCache;
_sentryScope = sentryScope;
_client = client;
_tree = tree;
_services = services;
_db = db;
_data = data;
}
public DiscordChannel ErrorChannelFor(MessageCreateEventArgs evt) => evt.Channel;
private bool IsDuplicateMessage(DiscordMessage evt) =>
// We consider a message duplicate if it has the same ID as the previous message that hit the gateway
_lastMessageCache.GetLastMessage(evt.ChannelId) == evt.Id;
public async Task Handle(MessageCreateEventArgs evt)
{
// Drop the message if we've already received it.
// Gateway occasionally resends events for whatever reason and it can break stuff relying on IDs being unique
// Not a perfect fix since reordering could still break things but... it's good enough for now
// (was considering checking the order of the IDs but IDs aren't guaranteed to be *perfectly* in order, so that'd cause false positives)
// LastMessageCache is updated in RegisterMessageMetrics so the ordering here is correct.
if (_lastMessageCache.GetLastMessage(evt.Channel.Id) == evt.Message.Id) return;
RegisterMessageMetrics(evt);
if (evt.Message.MessageType != MessageType.Default) return;
if (IsDuplicateMessage(evt.Message)) return;
// Ignore system messages (member joined, message pinned, etc)
var msg = evt.Message;
if (msg.MessageType != MessageType.Default) return;
// Log metrics and message info
_metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived);
_lastMessageCache.AddMessage(evt.Channel.Id, evt.Message.Id);
var cachedGuild = await _proxyCache.GetGuildDataCached(msg.Channel.GuildId);
var cachedAccount = await _proxyCache.GetAccountDataCached(msg.Author.Id);
// this ^ may be null, do remember that down the line
// Pass guild bot/WH messages onto the logger cleanup service
if (msg.Author.IsBot && msg.Channel.Type == ChannelType.Text)
{
await _loggerClean.HandleLoggerBotCleanup(msg, cachedGuild);
return;
// Try each handler until we find one that succeeds
var ctx = await _db.Execute(c => c.QueryMessageContext(evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id));
var _ = await TryHandleLogClean(evt, ctx) ||
await TryHandleCommand(evt, ctx) ||
await TryHandleProxy(evt, ctx);
}
// First try parsing a command, then try proxying
if (await TryHandleCommand(evt, cachedGuild, cachedAccount)) return;
await TryHandleProxy(evt, cachedGuild, cachedAccount);
private async ValueTask<bool> TryHandleLogClean(MessageCreateEventArgs evt, MessageContext ctx)
{
if (!evt.Message.Author.IsBot || evt.Message.Channel.Type != ChannelType.Text ||
!ctx.LogCleanupEnabled) return false;
await _loggerClean.HandleLoggerBotCleanup(evt.Message);
return true;
}
private async Task<bool> TryHandleCommand(MessageCreateEventArgs evt, GuildConfig cachedGuild, CachedAccount cachedAccount)
private async ValueTask<bool> TryHandleCommand(MessageCreateEventArgs evt, MessageContext ctx)
{
var msg = evt.Message;
var content = evt.Message.Content;
if (content == null) return false;
int argPos = -1;
var argPos = -1;
// Check if message starts with the command prefix
if (msg.Content.StartsWith("pk;", StringComparison.InvariantCultureIgnoreCase)) argPos = 3;
else if (msg.Content.StartsWith("pk!", StringComparison.InvariantCultureIgnoreCase)) argPos = 3;
else if (msg.Content != null && StringUtils.HasMentionPrefix(msg.Content, ref argPos, out var id)) // Set argPos to the proper value
if (content.StartsWith("pk;", StringComparison.InvariantCultureIgnoreCase)) argPos = 3;
else if (content.StartsWith("pk!", StringComparison.InvariantCultureIgnoreCase)) argPos = 3;
else if (StringUtils.HasMentionPrefix(content, ref argPos, out var id)) // Set argPos to the proper value
if (id != _client.CurrentUser.Id) // But undo it if it's someone else's ping
argPos = -1;
@ -93,12 +89,13 @@ namespace PluralKit.Bot
// Trim leading whitespace from command without actually modifying the wring
// This just moves the argPos pointer by however much whitespace is at the start of the post-argPos string
var trimStartLengthDiff = msg.Content.Substring(argPos).Length - msg.Content.Substring(argPos).TrimStart().Length;
var trimStartLengthDiff = content.Substring(argPos).Length - content.Substring(argPos).TrimStart().Length;
argPos += trimStartLengthDiff;
try
{
await _tree.ExecuteCommand(new Context(_services, evt.Client, msg, argPos, cachedAccount?.System));
var system = ctx.SystemId != null ? await _data.GetSystemById(ctx.SystemId.Value) : null;
await _tree.ExecuteCommand(new Context(_services, evt.Client, evt.Message, argPos, system));
}
catch (PKError)
{
@ -109,31 +106,21 @@ namespace PluralKit.Bot
return true;
}
private async Task<bool> TryHandleProxy(MessageCreateEventArgs evt, GuildConfig cachedGuild, CachedAccount cachedAccount)
private async ValueTask<bool> TryHandleProxy(MessageCreateEventArgs evt, MessageContext ctx)
{
var msg = evt.Message;
// If we don't have any cached account data, this means no member in the account has a proxy tag set
if (cachedAccount == null) return false;
try
{
await _proxy.HandleIncomingMessage(evt.Message, allowAutoproxy: true);
return await _proxy.HandleIncomingMessage(evt.Message, ctx, allowAutoproxy: true);
}
catch (PKError e)
{
// User-facing errors, print to the channel properly formatted
var msg = evt.Message;
if (msg.Channel.Guild == null || msg.Channel.BotHasAllPermissions(Permissions.SendMessages))
await msg.Channel.SendMessageAsync($"{Emojis.Error} {e.Message}");
}
return true;
}
private void RegisterMessageMetrics(MessageCreateEventArgs evt)
{
_metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived);
_lastMessageCache.AddMessage(evt.Channel.Id, evt.Message.Id);
return false;
}
}
}

View File

@ -1,12 +1,9 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DSharpPlus.EventArgs;
using PluralKit.Core;
using Sentry;
namespace PluralKit.Bot
{
@ -14,38 +11,26 @@ namespace PluralKit.Bot
{
private readonly LastMessageCacheService _lastMessageCache;
private readonly ProxyService _proxy;
private readonly ProxyCache _proxyCache;
private readonly Scope _sentryScope;
private readonly DbConnectionFactory _db;
public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, ProxyCache proxyCache, Scope sentryScope)
public MessageEdited(LastMessageCacheService lastMessageCache, ProxyService proxy, DbConnectionFactory db)
{
_lastMessageCache = lastMessageCache;
_proxy = proxy;
_proxyCache = proxyCache;
_sentryScope = sentryScope;
_db = db;
}
public async Task Handle(MessageUpdateEventArgs evt)
{
// Sometimes edit message events arrive for other reasons (eg. an embed gets updated server-side)
// If this wasn't a *content change* (ie. there's message contents to read), bail
// It'll also sometimes arrive with no *author*, so we'll go ahead and ignore those messages too
if (evt.Message.Content == null) return;
if (evt.Author == null) return;
// Edit message events sometimes arrive with missing data; double-check it's all there
if (evt.Message.Content == null || evt.Author == null || evt.Channel.Guild == null) return;
// Also, if this is in DMs don't bother either
if (evt.Channel.Guild == null) return;
// If this isn't the last message in the channel, don't do anything
// Only react to the last message in the channel
if (_lastMessageCache.GetLastMessage(evt.Channel.Id) != evt.Message.Id) return;
// Fetch account and guild info from cache if there is any
var account = await _proxyCache.GetAccountDataCached(evt.Author.Id);
if (account == null) return; // Again: no cache = no account = no system = no proxy
var guild = await _proxyCache.GetGuildDataCached(evt.Channel.GuildId);
// Just run the normal message handling stuff, with a flag to disable autoproxying
await _proxy.HandleIncomingMessage(evt.Message, allowAutoproxy: false);
// Just run the normal message handling code, with a flag to disable autoproxying
var ctx = await _db.Execute(c => c.QueryMessageContext(evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id));
await _proxy.HandleIncomingMessage(evt.Message, ctx, allowAutoproxy: false);
}
}
}

View File

@ -20,11 +20,11 @@ namespace PluralKit.Bot
_clock = clock;
}
public bool TryMatch(IReadOnlyCollection<ProxyMember> members, out ProxyMatch match, string messageContent,
public bool TryMatch(MessageContext ctx, IReadOnlyCollection<ProxyMember> members, out ProxyMatch match, string messageContent,
bool hasAttachments, bool allowAutoproxy)
{
if (TryMatchTags(members, messageContent, hasAttachments, out match)) return true;
if (allowAutoproxy && TryMatchAutoproxy(members, messageContent, out match)) return true;
if (allowAutoproxy && TryMatchAutoproxy(ctx, members, messageContent, out match)) return true;
return false;
}
@ -37,33 +37,44 @@ namespace PluralKit.Bot
return hasAttachments || match.Content.Length > 0;
}
private bool TryMatchAutoproxy(IReadOnlyCollection<ProxyMember> members, string messageContent,
private bool TryMatchAutoproxy(MessageContext ctx, IReadOnlyCollection<ProxyMember> members, string messageContent,
out ProxyMatch match)
{
match = default;
// We handle most autoproxy logic in the database function, so we just look for the member that's marked
var info = members.FirstOrDefault(i => i.IsAutoproxyMember);
if (info == null) return false;
// If we're in latch mode and the latch message is too old, fail the match too
if (info.AutoproxyMode == AutoproxyMode.Latch && info.LatchMessage != null)
// Find the member we should autoproxy (null if none)
var member = ctx.AutoproxyMode switch
{
var timestamp = DiscordUtils.SnowflakeToInstant(info.LatchMessage.Value);
if (_clock.GetCurrentInstant() - timestamp > LatchExpiryTime) return false;
}
AutoproxyMode.Member when ctx.AutoproxyMember != null =>
members.FirstOrDefault(m => m.Id == ctx.AutoproxyMember),
// Match succeeded, build info object and return
AutoproxyMode.Front when ctx.LastSwitchMembers.Count > 0 =>
members.FirstOrDefault(m => m.Id == ctx.LastSwitchMembers[0]),
AutoproxyMode.Latch when ctx.LastMessageMember != null && !IsLatchExpired(ctx.LastMessage) =>
members.FirstOrDefault(m => m.Id == ctx.LastMessageMember.Value),
_ => null
};
if (member == null) return false;
match = new ProxyMatch
{
Content = messageContent,
Member = info,
Member = member,
// We're autoproxying, so not using any proxy tags here
// we just find the first pair of tags (if any), otherwise null
ProxyTags = info.ProxyTags.FirstOrDefault()
ProxyTags = member.ProxyTags.FirstOrDefault()
};
return true;
}
private bool IsLatchExpired(ulong? messageId)
{
if (messageId == null) return true;
var timestamp = DiscordUtils.SnowflakeToInstant(messageId.Value);
return _clock.GetCurrentInstant() - timestamp > LatchExpiryTime;
}
}
}

View File

@ -20,11 +20,11 @@ namespace PluralKit.Bot
{
public static readonly TimeSpan MessageDeletionDelay = TimeSpan.FromMilliseconds(1000);
private LogChannelService _logChannel;
private DbConnectionFactory _db;
private IDataStore _data;
private ILogger _logger;
private WebhookExecutorService _webhookExecutor;
private readonly LogChannelService _logChannel;
private readonly DbConnectionFactory _db;
private readonly IDataStore _data;
private readonly ILogger _logger;
private readonly WebhookExecutorService _webhookExecutor;
private readonly ProxyMatcher _matcher;
public ProxyService(LogChannelService logChannel, IDataStore data, ILogger logger,
@ -38,35 +38,57 @@ namespace PluralKit.Bot
_logger = logger.ForContext<ProxyService>();
}
public async Task HandleIncomingMessage(DiscordMessage message, bool allowAutoproxy)
public async Task<bool> HandleIncomingMessage(DiscordMessage message, MessageContext ctx, bool allowAutoproxy)
{
// Quick context checks to quit early
if (!IsMessageValid(message)) return;
if (!ShouldProxy(message, ctx)) return false;
// Fetch members and try to match to a specific member
var members = await FetchProxyMembers(message.Author.Id, message.Channel.GuildId);
if (!_matcher.TryMatch(members, out var match, message.Content, message.Attachments.Count > 0,
allowAutoproxy)) return;
var members = (await _db.Execute(c => c.QueryProxyMembers(message.Author.Id, message.Channel.GuildId))).ToList();
if (!_matcher.TryMatch(ctx, members, out var match, message.Content, message.Attachments.Count > 0,
allowAutoproxy)) return false;
// Do some quick permission checks before going through with the proxy
// (do channel checks *after* checking other perms to make sure we don't spam errors when eg. channel is blacklisted)
if (!IsProxyValid(message, match)) return;
if (!await CheckBotPermissionsOrError(message.Channel)) return;
if (!CheckProxyNameBoundsOrError(match)) return;
// Permission check after proxy match so we don't get spammed when not actually proxying
if (!await CheckBotPermissionsOrError(message.Channel)) return false;
if (!CheckProxyNameBoundsOrError(match)) return false;
// Everything's in order, we can execute the proxy!
await ExecuteProxy(message, match);
await ExecuteProxy(message, ctx, match);
return true;
}
private async Task ExecuteProxy(DiscordMessage trigger, ProxyMatch match)
private bool ShouldProxy(DiscordMessage msg, MessageContext ctx)
{
// Make sure author has a system
if (ctx.SystemId == null) return false;
// Make sure channel is a guild text channel and this is a normal message
if (msg.Channel.Type != ChannelType.Text || msg.MessageType != MessageType.Default) return false;
// Make sure author is a normal user
if (msg.Author.IsSystem == true || msg.Author.IsBot || msg.WebhookMessage) return false;
// Make sure proxying is enabled here
if (!ctx.ProxyEnabled || ctx.InBlacklist) return false;
// Make sure we have either an attachment or message content
var isMessageBlank = msg.Content == null || msg.Content.Trim().Length == 0;
if (isMessageBlank && msg.Attachments.Count == 0) return false;
// All good!
return true;
}
private async Task ExecuteProxy(DiscordMessage trigger, MessageContext ctx, ProxyMatch match)
{
// Send the webhook
var id = await _webhookExecutor.ExecuteWebhook(trigger.Channel, match.Member.ProxyName, match.Member.ProxyAvatar,
var id = await _webhookExecutor.ExecuteWebhook(trigger.Channel, match.Member.ProxyName,
match.Member.ProxyAvatar,
match.Content, trigger.Attachments);
// Handle post-proxy actions
await _data.AddMessage(trigger.Author.Id, trigger.Channel.GuildId, trigger.Channel.Id, id, trigger.Id, match.Member.MemberId);
await _logChannel.LogMessage(match, trigger, id);
await _data.AddMessage(trigger.Author.Id, trigger.Channel.GuildId, trigger.Channel.Id, id, trigger.Id,
match.Member.Id);
await _logChannel.LogMessage(ctx, match, trigger, id);
// Wait a second or so before deleting the original message
await Task.Delay(MessageDeletionDelay);
@ -81,14 +103,6 @@ namespace PluralKit.Bot
}
}
private async Task<IReadOnlyCollection<ProxyMember>> FetchProxyMembers(ulong account, ulong guild)
{
await using var conn = await _db.Obtain();
var members = await conn.QueryAsync<ProxyMember>("proxy_info",
new {account_id = account, guild_id = guild}, commandType: CommandType.StoredProcedure);
return members.ToList();
}
private async Task<bool> CheckBotPermissionsOrError(DiscordChannel channel)
{
var permissions = channel.BotPermissions();
@ -124,33 +138,5 @@ namespace PluralKit.Bot
// TODO: this never returns false as it throws instead, should this happen?
return true;
}
private bool IsMessageValid(DiscordMessage message)
{
return
// Must be a guild text channel
message.Channel.Type == ChannelType.Text &&
// Must not be a system message
message.MessageType == MessageType.Default &&
!(message.Author.IsSystem ?? false) &&
// Must not be a bot or webhook message
!message.WebhookMessage &&
!message.Author.IsBot &&
// Must have either an attachment or content (or both, but not neither)
(message.Attachments.Count > 0 || (message.Content != null && message.Content.Trim().Length > 0));
}
private bool IsProxyValid(DiscordMessage message, ProxyMatch match)
{
return
// System and member must have proxying enabled in this guild
match.Member.ProxyEnabled &&
// Channel must not be blacklisted here
!match.Member.ChannelBlacklist.Contains(message.ChannelId);
}
}
}

View File

@ -28,12 +28,12 @@ namespace PluralKit.Bot {
_logger = logger.ForContext<LogChannelService>();
}
public async Task LogMessage(ProxyMatch proxy, DiscordMessage trigger, ulong hookMessage)
public async ValueTask LogMessage(MessageContext ctx, ProxyMatch proxy, DiscordMessage trigger, ulong hookMessage)
{
if (proxy.Member.LogChannel == null || proxy.Member.LogBlacklist.Contains(trigger.ChannelId)) return;
if (ctx.SystemId == null || ctx.LogChannel == null || ctx.InLogBlacklist) return;
// Find log channel and check if valid
var logChannel = await FindLogChannel(trigger.Channel.GuildId, proxy);
var logChannel = await FindLogChannel(trigger.Channel.GuildId, ctx.LogChannel.Value);
if (logChannel == null || logChannel.Type != ChannelType.Text) return;
// Check bot permissions
@ -41,25 +41,23 @@ namespace PluralKit.Bot {
// Send embed!
await using var conn = await _db.Obtain();
var embed = _embed.CreateLoggedMessageEmbed(await _data.GetSystemById(proxy.Member.SystemId),
await _data.GetMemberById(proxy.Member.MemberId), hookMessage, trigger.Id, trigger.Author, proxy.Content,
var embed = _embed.CreateLoggedMessageEmbed(await _data.GetSystemById(ctx.SystemId.Value),
await _data.GetMemberById(proxy.Member.Id), hookMessage, trigger.Id, trigger.Author, proxy.Content,
trigger.Channel);
var url = $"https://discord.com/channels/{trigger.Channel.GuildId}/{trigger.ChannelId}/{hookMessage}";
await logChannel.SendMessageAsync(content: url, embed: embed);
}
private async Task<DiscordChannel> FindLogChannel(ulong guild, ProxyMatch proxy)
private async Task<DiscordChannel> FindLogChannel(ulong guild, ulong channel)
{
var logChannel = proxy.Member.LogChannel.Value;
try
{
return await _rest.GetChannelAsync(logChannel);
return await _rest.GetChannelAsync(channel);
}
catch (NotFoundException)
{
// Channel doesn't exist, let's remove it from the database too
_logger.Warning("Attempted to fetch missing log channel {LogChannel}, removing from database", logChannel);
_logger.Warning("Attempted to fetch missing log channel {LogChannel}, removing from database", channel);
await using var conn = await _db.Obtain();
await conn.ExecuteAsync("update servers set log_channel = null where server = @Guild",
new {Guild = guild});

View File

@ -64,10 +64,8 @@ namespace PluralKit.Bot
public ICollection<LoggerBot> Bots => _bots.Values;
public async ValueTask HandleLoggerBotCleanup(DiscordMessage msg, GuildConfig cachedGuild)
public async ValueTask HandleLoggerBotCleanup(DiscordMessage msg)
{
// Bail if not enabled, or if we don't have permission here
if (!cachedGuild.LogCleanupEnabled) return;
if (msg.Channel.Type != ChannelType.Text) return;
if (!msg.Channel.BotHasAllPermissions(Permissions.ManageMessages)) return;

View File

@ -0,0 +1,25 @@
using System.Collections.Generic;
using System.Data;
using System.Threading.Tasks;
using Dapper;
namespace PluralKit.Core
{
public static class DatabaseFunctionsExt
{
public static Task<MessageContext> QueryMessageContext(this IDbConnection conn, ulong account, ulong guild, ulong channel)
{
return conn.QueryFirstAsync<MessageContext>("message_context",
new { account_id = account, guild_id = guild, channel_id = channel },
commandType: CommandType.StoredProcedure);
}
public static Task<IEnumerable<ProxyMember>> QueryProxyMembers(this IDbConnection conn, ulong account, ulong guild)
{
return conn.QueryAsync<ProxyMember>("proxy_members",
new { account_id = account, guild_id = guild },
commandType: CommandType.StoredProcedure);
}
}
}

View File

@ -0,0 +1,27 @@
#nullable enable
using System.Collections.Generic;
using NodaTime;
namespace PluralKit.Core
{
/// <summary>
/// Model for the `message_context` PL/pgSQL function in `functions.sql`
/// </summary>
public class MessageContext
{
public int? SystemId { get; set; }
public ulong? LogChannel { get; set; }
public bool InBlacklist { get; set; }
public bool InLogBlacklist { get; set; }
public bool LogCleanupEnabled { get; set; }
public bool ProxyEnabled { get; set; }
public AutoproxyMode AutoproxyMode { get; set; }
public int? AutoproxyMember { get; set; }
public ulong? LastMessage { get; set; }
public int? LastMessageMember { get; set; }
public int LastSwitch { get; set; }
public IReadOnlyList<int> LastSwitchMembers { get; set; } = new int[0];
public Instant LastSwitchTimestamp { get; set; }
}
}

View File

@ -0,0 +1,17 @@
#nullable enable
using System.Collections.Generic;
namespace PluralKit.Core
{
/// <summary>
/// Model for the `proxy_members` PL/pgSQL function in `functions.sql`
/// </summary>
public class ProxyMember
{
public int Id { get; set; }
public IReadOnlyCollection<ProxyTag> ProxyTags { get; set; } = new ProxyTag[0];
public bool KeepProxy { get; set; }
public string ProxyName { get; set; } = "";
public string? ProxyAvatar { get; set; }
}
}

View File

@ -0,0 +1,73 @@
create function message_context(account_id bigint, guild_id bigint, channel_id bigint)
returns table (
system_id int,
log_channel bigint,
in_blacklist bool,
in_log_blacklist bool,
log_cleanup_enabled bool,
proxy_enabled bool,
autoproxy_mode int,
autoproxy_member int,
last_message bigint,
last_message_member int,
last_switch int,
last_switch_members int[],
last_switch_timestamp timestamp
)
as $$
with
system as (select systems.* from accounts inner join systems on systems.id = accounts.system where accounts.uid = account_id),
guild as (select * from servers where id = guild_id),
last_message as (select * from messages where messages.guild = guild_id and messages.sender = account_id order by mid desc limit 1)
select
system.id as system_id,
guild.log_channel,
(channel_id = any(guild.blacklist)) as in_blacklist,
(channel_id = any(guild.log_blacklist)) as in_log_blacklist,
guild.log_cleanup_enabled,
coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
coalesce(system_guild.autoproxy_mode, 1) as autoproxy_mode,
system_guild.autoproxy_member,
last_message.mid as last_message,
last_message.member as last_message_member,
system_last_switch.switch as last_switch,
system_last_switch.members as last_switch_members,
system_last_switch.timestamp as last_switch_timestamp
from system
full join guild on true
left join last_message on true
left join system_last_switch on system_last_switch.system = system.id
left join system_guild on system_guild.system = system.id and system_guild.guild = guild_id
$$ language sql stable rows 1;
-- Fetches info about proxying related to a given account/guild
-- Returns one row per member in system, should be used in conjuction with `message_context` too
create function proxy_members(account_id bigint, guild_id bigint)
returns table (
id int,
proxy_tags proxy_tag[],
keep_proxy bool,
proxy_name text,
proxy_avatar text
)
as $$
select
-- Basic data
members.id as id,
members.proxy_tags as proxy_tags,
members.keep_proxy as keep_proxy,
-- Proxy info
case
when systems.tag is not null then (coalesce(member_guild.display_name, members.display_name, members.name) || ' ' || systems.tag)
else coalesce(member_guild.display_name, members.display_name, members.name)
end as proxy_name,
coalesce(member_guild.avatar_url, members.avatar_url, systems.avatar_url) as proxy_avatar
from accounts
inner join systems on systems.id = accounts.system
inner join members on members.system = systems.id
left join member_guild on member_guild.member = members.id and member_guild.guild = guild_id
where accounts.uid = account_id
$$ language sql stable rows 10;

View File

@ -1,26 +0,0 @@
#nullable enable
using System.Collections.Generic;
namespace PluralKit.Core
{
/// <summary>
/// Model for the `proxy_info` PL/pgSQL function in `functions.sql`
/// </summary>
public class ProxyMember
{
public int SystemId { get; set; }
public int MemberId { get; set; }
public bool ProxyEnabled { get; set; }
public AutoproxyMode AutoproxyMode { get; set; }
public bool IsAutoproxyMember { get; set; }
public ulong? LatchMessage { get; set; }
public string ProxyName { get; set; } = "";
public string? ProxyAvatar { get; set; }
public IReadOnlyCollection<ProxyTag> ProxyTags { get; set; } = new ProxyTag[0];
public bool KeepProxy { get; set; }
public IReadOnlyCollection<ulong> ChannelBlacklist { get; set; } = new ulong[0];
public IReadOnlyCollection<ulong> LogBlacklist { get; set; } = new ulong[0];
public ulong? LogChannel { get; set; }
}
}

View File

@ -46,7 +46,7 @@ namespace PluralKit.Core
// Now, reapply views/functions (we deleted them above, no need to worry about conflicts)
await ExecuteSqlFile($"{RootPath}.views.sql", conn, tx);
await ExecuteSqlFile($"{RootPath}.functions.sql", conn, tx);
await ExecuteSqlFile($"{RootPath}.Functions.functions.sql", conn, tx);
// Finally, commit tx
tx.Commit();

View File

@ -1,3 +1,5 @@
drop view if exists system_last_switch;
drop view if exists member_list;
drop function if exists proxy_info;
drop function if exists message_context;
drop function if exists proxy_members;

View File

@ -1,85 +0,0 @@
-- Giant "mega-function" to find all information relevant for message proxying
-- Returns one row per member, computes several properties from others
create function proxy_info(account_id bigint, guild_id bigint)
returns table
(
-- Note: table type gets matched *by index*, not *by name* (make sure order here and in `select` match)
system_id int, -- from: systems.id
member_id int, -- from: members.id
proxy_tags proxy_tag[], -- from: members.proxy_tags
keep_proxy bool, -- from: members.keep_proxy
proxy_enabled bool, -- from: system_guild.proxy_enabled
proxy_name text, -- calculated: name we should proxy under
proxy_avatar text, -- calculated: avatar we should proxy with
autoproxy_mode int, -- from: system_guild.autoproxy_mode
is_autoproxy_member bool, -- calculated: should this member be used for AP?
latch_message bigint, -- calculated: last message from this account in this guild
channel_blacklist bigint[], -- from: servers.blacklist
log_blacklist bigint[], -- from: servers.log_blacklist
log_channel bigint -- from: servers.log_channel
)
as
$$
select
-- Basic data
systems.id as system_id,
members.id as member_id,
members.proxy_tags as proxy_tags,
members.keep_proxy as keep_proxy,
-- Proxy info
coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
case
when systems.tag is not null then (coalesce(member_guild.display_name, members.display_name, members.name) || ' ' || systems.tag)
else coalesce(member_guild.display_name, members.display_name, members.name)
end as proxy_name,
coalesce(member_guild.avatar_url, members.avatar_url, systems.avatar_url) as proxy_avatar,
-- Autoproxy data
coalesce(system_guild.autoproxy_mode, 1) as autoproxy_mode,
-- Autoproxy logic is essentially: "is this member the one we should autoproxy?"
case
-- Front mode: check if this is the first fronter
when system_guild.autoproxy_mode = 2 then members.id = (select sls.members[1]
from system_last_switch as sls
where sls.system = systems.id)
-- Latch mode: check if this is the last proxier
when system_guild.autoproxy_mode = 3 then members.id = last_message_in_guild.member
-- Member mode: check if this is the selected memebr
when system_guild.autoproxy_mode = 4 then members.id = system_guild.autoproxy_member
-- no autoproxy: then this member definitely shouldn't be autoproxied :)
else false end as is_autoproxy_member,
last_message_in_guild.mid as latch_message,
-- Guild info
coalesce(servers.blacklist, array[]::bigint[]) as channel_blacklist,
coalesce(servers.log_blacklist, array[]::bigint[]) as log_blacklist,
servers.log_channel as log_channel
from accounts
-- Fetch guild info
left join servers on servers.id = guild_id
-- Fetch the system for this account (w/ guild config)
inner join systems on systems.id = accounts.system
left join system_guild on system_guild.system = accounts.system and system_guild.guild = guild_id
-- Fetch all members from this system (w/ guild config)
inner join members on members.system = systems.id
left join member_guild on member_guild.member = members.id and member_guild.guild = guild_id
-- Find ID and member for the last message sent in this guild
left join lateral (select mid, member
from messages
where messages.guild = guild_id
and messages.sender = account_id
and system_guild.autoproxy_mode = 3
order by mid desc
limit 1) as last_message_in_guild on true
where accounts.uid = account_id;
$$ language sql stable
rows 10;

View File

@ -28,7 +28,6 @@ namespace PluralKit.Core
builder.RegisterType<Schemas>().AsSelf();
builder.Populate(new ServiceCollection().AddMemoryCache());
builder.RegisterType<ProxyCache>().AsSelf().SingleInstance();
}
}

View File

@ -12,13 +12,11 @@ namespace PluralKit.Core {
public class PostgresDataStore: IDataStore {
private DbConnectionFactory _conn;
private ILogger _logger;
private ProxyCache _cache;
public PostgresDataStore(DbConnectionFactory conn, ILogger logger, ProxyCache cache)
public PostgresDataStore(DbConnectionFactory conn, ILogger logger)
{
_conn = conn;
_logger = logger;
_cache = cache;
}
public async Task<IEnumerable<PKMember>> GetConflictingProxies(PKSystem system, ProxyTag tag)
@ -56,7 +54,6 @@ namespace PluralKit.Core {
settings.AutoproxyMode,
settings.AutoproxyMember
});
await _cache.InvalidateSystem(system);
_logger.Information("Updated system guild settings {@SystemGuildSettings}", settings);
}
@ -83,7 +80,6 @@ namespace PluralKit.Core {
await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing", new { Id = accountId, SystemId = system.Id });
_logger.Information("Linked system {System} to account {Account}", system.Id, accountId);
await _cache.InvalidateSystem(system);
}
public async Task RemoveAccount(PKSystem system, ulong accountId) {
@ -91,8 +87,6 @@ namespace PluralKit.Core {
await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", new { Id = accountId, SystemId = system.Id });
_logger.Information("Unlinked system {System} from account {Account}", system.Id, accountId);
await _cache.InvalidateSystem(system);
_cache.InvalidateAccounts(new [] { accountId });
}
public async Task<PKSystem> GetSystemByAccount(ulong accountId) {
@ -121,7 +115,6 @@ namespace PluralKit.Core {
await conn.ExecuteAsync("update systems set name = @Name, description = @Description, tag = @Tag, avatar_url = @AvatarUrl, token = @Token, ui_tz = @UiTz, description_privacy = @DescriptionPrivacy, member_list_privacy = @MemberListPrivacy, front_privacy = @FrontPrivacy, front_history_privacy = @FrontHistoryPrivacy, pings_enabled = @PingsEnabled where id = @Id", system);
_logger.Information("Updated system {@System}", system);
await _cache.InvalidateSystem(system);
}
public async Task DeleteSystem(PKSystem system)
@ -133,7 +126,6 @@ namespace PluralKit.Core {
await conn.ExecuteAsync("delete from systems where id = @Id", system);
_logger.Information("Deleted system {System}", system.Id);
_cache.InvalidateDeletedSystem(system.Id, accounts);
}
public async Task<IEnumerable<ulong>> GetSystemAccounts(PKSystem system)
@ -170,7 +162,6 @@ namespace PluralKit.Core {
});
_logger.Information("Created member {Member}", member.Id);
await _cache.InvalidateSystem(system);
return member;
}
@ -202,7 +193,6 @@ namespace PluralKit.Core {
await conn.ExecuteAsync("update members set name = @Name, display_name = @DisplayName, description = @Description, color = @Color, avatar_url = @AvatarUrl, birthday = @Birthday, pronouns = @Pronouns, proxy_tags = @ProxyTags, keep_proxy = @KeepProxy, member_privacy = @MemberPrivacy where id = @Id", member);
_logger.Information("Updated member {@Member}", member);
await _cache.InvalidateSystem(member.System);
}
public async Task DeleteMember(PKMember member) {
@ -210,7 +200,6 @@ namespace PluralKit.Core {
await conn.ExecuteAsync("delete from members where id = @Id", member);
_logger.Information("Deleted member {@Member}", member);
await _cache.InvalidateSystem(member.System);
}
public async Task<MemberGuildSettings> GetMemberGuildSettings(PKMember member, ulong guild)
@ -228,7 +217,6 @@ namespace PluralKit.Core {
"insert into member_guild (member, guild, display_name, avatar_url) values (@Member, @Guild, @DisplayName, @AvatarUrl) on conflict (member, guild) do update set display_name = @DisplayName, avatar_url = @AvatarUrl",
settings);
_logger.Information("Updated member guild settings {@MemberGuildSettings}", settings);
await _cache.InvalidateSystem(member.System);
}
public async Task<int> GetSystemMemberCount(PKSystem system, bool includePrivate)
@ -350,7 +338,6 @@ namespace PluralKit.Core {
Blacklist = cfg.Blacklist.Select(c => (long) c).ToList()
});
_logger.Information("Updated guild configuration {@GuildCfg}", cfg);
_cache.InvalidateGuild(cfg.Id);
}
public async Task<PKMember> GetFirstFronter(PKSystem system)

View File

@ -1,196 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Dapper;
using Microsoft.Extensions.Caching.Memory;
using Serilog;
namespace PluralKit.Core
{
public class ProxyCache
{
// We can NOT depend on IDataStore as that creates a cycle, since it needs access to call the invalidation methods
private IMemoryCache _cache;
private DbConnectionFactory _db;
private ILogger _logger;
public ProxyCache(IMemoryCache cache, DbConnectionFactory db, ILogger logger)
{
_cache = cache;
_db = db;
_logger = logger;
}
public Task InvalidateSystem(PKSystem system) => InvalidateSystem(system.Id);
public void InvalidateAccounts(IEnumerable<ulong> accounts)
{
foreach (var account in accounts)
_cache.Remove(KeyForAccount(account));
}
public void InvalidateDeletedSystem(int systemId, IEnumerable<ulong> accounts)
{
// Used when the system's already removed so we can't look up accounts
// We assume the account list is saved already somewhere and can be passed here (which is the case in Store)
_cache.Remove(KeyForSystem(systemId));
InvalidateAccounts(accounts);
}
public async Task InvalidateSystem(int systemId)
{
if (_cache.TryGetValue<CachedAccount>(KeyForSystem(systemId), out var systemCache))
{
// If we have the system cached here, just invalidate for all the accounts we have in the cache
_logger.Debug("Invalidating cache for system {System} and accounts {Accounts}", systemId, systemCache.Accounts);
_cache.Remove(KeyForSystem(systemId));
foreach (var account in systemCache.Accounts)
_cache.Remove(KeyForAccount(account));
return;
}
// If we don't, look up the accounts from the database and invalidate *those*
_cache.Remove(KeyForSystem(systemId));
using var conn = await _db.Obtain();
var accounts = (await conn.QueryAsync<ulong>("select uid from accounts where system = @System", new {System = systemId})).ToArray();
_logger.Debug("Invalidating cache for system {System} and accounts {Accounts}", systemId, accounts);
foreach (var account in accounts)
_cache.Remove(KeyForAccount(account));
}
public void InvalidateGuild(ulong guild)
{
_logger.Debug("Invalidating cache for guild {Guild}", guild);
_cache.Remove(KeyForGuild(guild));
}
public async Task<GuildConfig> GetGuildDataCached(ulong guild)
{
if (_cache.TryGetValue<GuildConfig>(KeyForGuild(guild), out var item))
{
_logger.Verbose("Cache hit for guild {Guild}", guild);
return item;
}
// When changing this, also see PostgresDataStore::GetOrCreateGuildConfig
using var conn = await _db.Obtain();
_logger.Verbose("Cache miss for guild {Guild}", guild);
var guildConfig = (await conn.QuerySingleOrDefaultAsync<PostgresDataStore.DatabaseCompatibleGuildConfig>(
"insert into servers (id) values (@Id) on conflict do nothing; select * from servers where id = @Id",
new {Id = guild})).Into();
_cache.CreateEntry(KeyForGuild(guild))
.SetValue(guildConfig)
.SetSlidingExpiration(TimeSpan.FromMinutes(5))
.SetAbsoluteExpiration(TimeSpan.FromMinutes(30))
.Dispose(); // Don't ask, but this *saves* the entry. Somehow.
return guildConfig;
}
public async Task<CachedAccount> GetAccountDataCached(ulong account)
{
if (_cache.TryGetValue<CachedAccount>(KeyForAccount(account), out var item))
{
_logger.Verbose("Cache hit for account {Account}", account);
return item;
}
_logger.Verbose("Cache miss for account {Account}", account);
var data = await GetAccountData(account);
if (data == null)
{
_logger.Debug("Cached data for account {Account} (no system)", account);
// If we didn't find any value, set a pretty long expiry and the value to null
_cache.CreateEntry(KeyForAccount(account))
.SetValue(null)
.SetSlidingExpiration(TimeSpan.FromMinutes(5))
.SetAbsoluteExpiration(TimeSpan.FromHours(1))
.Dispose(); // Don't ask, but this *saves* the entry. Somehow.
return null;
}
// If we *did* find the value, cache it for *every account in the system* with a shorter expiry
_logger.Debug("Cached data for system {System} and accounts {Account}", data.System.Id, data.Accounts);
foreach (var linkedAccount in data.Accounts)
{
_cache.CreateEntry(KeyForAccount(linkedAccount))
.SetValue(data)
.SetSlidingExpiration(TimeSpan.FromMinutes(5))
.SetAbsoluteExpiration(TimeSpan.FromMinutes(20))
.Dispose(); // Don't ask, but this *saves* the entry. Somehow.
// And also do it for the system itself so we can look up by that
_cache.CreateEntry(KeyForSystem(data.System.Id))
.SetValue(data)
.SetSlidingExpiration(TimeSpan.FromMinutes(5))
.SetAbsoluteExpiration(TimeSpan.FromMinutes(20))
.Dispose(); // Don't ask, but this *saves* the entry. Somehow.
}
return data;
}
private async Task<CachedAccount> GetAccountData(ulong account)
{
using var conn = await _db.Obtain();
// Doing this as two queries instead of a two-step join to avoid sending duplicate rows for the system over the network for each member
// This *may* be less efficient, haven't done too much stuff about this but having the system ID saved is very useful later on
var system = await conn.QuerySingleOrDefaultAsync<PKSystem>("select systems.* from accounts inner join systems on systems.id = accounts.system where accounts.uid = @Account", new { Account = account });
if (system == null) return null; // No system = no members = no cache value
// Fetches:
// - List of accounts in the system
// - List of members in the system
// - List of guild settings for the system (for every guild)
// - List of guild settings for each member (for every guild)
// I'm slightly worried the volume of guild settings will get too much, but for simplicity reasons I decided
// against caching them individually per-guild, since I can't imagine they'll be edited *that* much
var result = await conn.QueryMultipleAsync(@"
select uid from accounts where system = @System;
select * from members where system = @System;
select * from system_guild where system = @System;
select member_guild.* from members inner join member_guild on member_guild.member = members.id where members.system = @System;
", new {System = system.Id});
return new CachedAccount
{
System = system,
Accounts = (await result.ReadAsync<ulong>()).ToArray(),
Members = (await result.ReadAsync<PKMember>()).ToArray(),
SystemGuild = (await result.ReadAsync<SystemGuildSettings>()).ToArray(),
MemberGuild = (await result.ReadAsync<MemberGuildSettings>()).ToArray()
};
}
private string KeyForAccount(ulong account) => $"_account_cache_{account}";
private string KeyForSystem(int system) => $"_system_cache_{system}";
private string KeyForGuild(ulong guild) => $"_guild_cache_{guild}";
}
public class CachedAccount
{
public PKSystem System;
public PKMember[] Members;
public SystemGuildSettings[] SystemGuild;
public MemberGuildSettings[] MemberGuild;
public ulong[] Accounts;
public SystemGuildSettings SettingsForGuild(ulong guild) =>
// O(n) lookup since n is small (max ~100 in prod) and we're more constrained by memory (for a dictionary) here
SystemGuild.FirstOrDefault(s => s.Guild == guild) ?? new SystemGuildSettings();
public MemberGuildSettings SettingsForMemberGuild(int memberId, ulong guild) =>
MemberGuild.FirstOrDefault(m => m.Member == memberId && m.Guild == guild) ?? new MemberGuildSettings();
}
}

View File

@ -308,4 +308,13 @@ namespace PluralKit.Core
parameter.Value = (long) value;
}
}
public static class DatabaseExt
{
public static async Task<T> Execute<T>(this DbConnectionFactory db, Func<IDbConnection, Task<T>> func)
{
await using var conn = await db.Obtain();
return await func(conn);
}
}
}