diff --git a/PluralKit.API/Controllers/v2/DiscordControllerV2.cs b/PluralKit.API/Controllers/v2/DiscordControllerV2.cs index 477ad8ca..ec624236 100644 --- a/PluralKit.API/Controllers/v2/DiscordControllerV2.cs +++ b/PluralKit.API/Controllers/v2/DiscordControllerV2.cs @@ -92,7 +92,7 @@ public class DiscordControllerV2: PKControllerBase [HttpGet("messages/{messageId}")] public async Task> MessageGet(ulong messageId) { - var msg = await _db.Execute(c => _repo.GetMessage(c, messageId)); + var msg = await _repo.GetFullMessage(messageId); if (msg == null) throw Errors.MessageNotFound; diff --git a/PluralKit.Bot/Commands/Checks.cs b/PluralKit.Bot/Commands/Checks.cs index f4d37a27..8d00faab 100644 --- a/PluralKit.Bot/Commands/Checks.cs +++ b/PluralKit.Bot/Commands/Checks.cs @@ -197,7 +197,7 @@ public class Checks if (messageId == null || channelId == null) throw new PKError(failedToGetMessage); - var proxiedMsg = await ctx.Database.Execute(conn => ctx.Repository.GetMessage(conn, messageId.Value)); + var proxiedMsg = await ctx.Repository.GetMessage(messageId.Value); if (proxiedMsg != null) { await ctx.Reply($"{Emojis.Success} This message was proxied successfully."); diff --git a/PluralKit.Bot/Commands/Message.cs b/PluralKit.Bot/Commands/Message.cs index ef42ecbb..1bf1bf92 100644 --- a/PluralKit.Bot/Commands/Message.cs +++ b/PluralKit.Bot/Commands/Message.cs @@ -55,9 +55,9 @@ public class ProxiedMessage public async Task ReproxyMessage(Context ctx) { - var msg = await GetMessageToEdit(ctx, ReproxyTimeout, true); + var (msg, systemId) = await GetMessageToEdit(ctx, ReproxyTimeout, true); - if (ctx.System.Id != msg.System?.Id) + if (ctx.System.Id != systemId) throw new PKError("Can't reproxy a message sent by a different system."); // Get target member ID @@ -68,14 +68,14 @@ public class ProxiedMessage // Fetch members and get the ProxyMember for `target` List members; using (_metrics.Measure.Timer.Time(BotMetrics.ProxyMembersQueryTime)) - members = (await _repo.GetProxyMembers(ctx.Author.Id, msg.Message.Guild!.Value)).ToList(); + members = (await _repo.GetProxyMembers(ctx.Author.Id, msg.Guild!.Value)).ToList(); var match = members.Find(x => x.Id == target.Id); if (match == null) throw new PKError("Could not find a member to reproxy the message with."); try { - await _proxy.ExecuteReproxy(ctx.Message, msg.Message, members, match); + await _proxy.ExecuteReproxy(ctx.Message, msg, members, match); if (ctx.Guild == null) await _rest.CreateReaction(ctx.Channel.Id, ctx.Message.Id, new Emoji { Name = Emojis.Success }); @@ -90,15 +90,15 @@ public class ProxiedMessage public async Task EditMessage(Context ctx) { - var msg = await GetMessageToEdit(ctx, EditTimeout, false); + var (msg, systemId) = await GetMessageToEdit(ctx, EditTimeout, false); - if (ctx.System.Id != msg.System?.Id) + if (ctx.System.Id != systemId) throw new PKError("Can't edit a message sent by a different system."); if (!ctx.HasNext()) throw new PKSyntaxError("You need to include the message to edit in."); - var originalMsg = await _rest.GetMessageOrNull(msg.Message.Channel, msg.Message.Mid); + var originalMsg = await _rest.GetMessageOrNull(msg.Channel, msg.Mid); if (originalMsg == null) throw new PKError("Could not edit message."); @@ -124,7 +124,7 @@ public class ProxiedMessage try { var editedMsg = - await _webhookExecutor.EditWebhookMessage(msg.Message.Channel, msg.Message.Mid, newContent); + await _webhookExecutor.EditWebhookMessage(msg.Channel, msg.Mid, newContent); if (ctx.Guild == null) await _rest.CreateReaction(ctx.Channel.Id, ctx.Message.Id, new Emoji { Name = Emojis.Success }); @@ -132,7 +132,7 @@ public class ProxiedMessage if ((await ctx.BotPermissions).HasFlag(PermissionSet.ManageMessages)) await _rest.DeleteMessage(ctx.Channel.Id, ctx.Message.Id); - await _logChannel.LogMessage(msg.Message, ctx.Message, editedMsg, originalMsg!.Content!); + await _logChannel.LogMessage(msg, ctx.Message, editedMsg, originalMsg!.Content!); } catch (NotFoundException) { @@ -140,18 +140,18 @@ public class ProxiedMessage } } - private async Task GetMessageToEdit(Context ctx, Duration timeout, bool isReproxy) + private async Task<(PKMessage, SystemId)> GetMessageToEdit(Context ctx, Duration timeout, bool isReproxy) { var editType = isReproxy ? "reproxy" : "edit"; var editTypeAction = isReproxy ? "reproxied" : "edited"; - FullMessage? msg = null; + PKMessage? msg = null; var (referencedMessage, _) = ctx.MatchMessage(false); if (referencedMessage != null) { await using var conn = await ctx.Database.Obtain(); - msg = await ctx.Repository.GetMessage(conn, referencedMessage.Value); + msg = await ctx.Repository.GetMessage(referencedMessage.Value); if (msg == null) throw new PKError("This is not a message proxied by PluralKit."); } @@ -161,7 +161,7 @@ public class ProxiedMessage if (ctx.Guild == null) throw new PKSyntaxError($"You must use a message link to {editType} messages in DMs."); - PKMessage? recent; + ulong? recent = null; if (isReproxy) recent = await ctx.Repository.GetLastMessage(ctx.Guild.Id, ctx.Channel.Id, ctx.Author.Id); @@ -172,17 +172,21 @@ public class ProxiedMessage throw new PKSyntaxError($"Could not find a recent message to {editType}."); await using var conn = await ctx.Database.Obtain(); - msg = await ctx.Repository.GetMessage(conn, recent.Mid); + msg = await ctx.Repository.GetMessage(recent.Value); if (msg == null) throw new PKSyntaxError($"Could not find a recent message to {editType}."); } - if (msg.Message.Channel != ctx.Channel.Id) + var member = await ctx.Repository.GetMember(msg.Member!.Value); + if (member == null) + throw new PKSyntaxError($"Could not find a recent message to {editType}."); + + if (msg.Channel != ctx.Channel.Id) { var error = "The channel where the message was sent does not exist anymore, or you are missing permissions to access it."; - var channel = await _rest.GetChannelOrNull(msg.Message.Channel); + var channel = await _rest.GetChannelOrNull(msg.Channel); if (channel == null) throw new PKError(error); @@ -192,16 +196,18 @@ public class ProxiedMessage throw new PKError(error); } - var isLatestMessage = _lastMessageCache.GetLastMessage(ctx.Message.ChannelId)?.Current.Id == ctx.Message.Id - ? _lastMessageCache.GetLastMessage(ctx.Message.ChannelId)?.Previous?.Id == msg.Message.Mid - : _lastMessageCache.GetLastMessage(ctx.Message.ChannelId)?.Current.Id == msg.Message.Mid; + var lastMessage = _lastMessageCache.GetLastMessage(ctx.Message.ChannelId); - var msgTimestamp = DiscordUtils.SnowflakeToInstant(msg.Message.Mid); + var isLatestMessage = lastMessage?.Current.Id == ctx.Message.Id + ? lastMessage?.Previous?.Id == msg.Mid + : lastMessage?.Current.Id == msg.Mid; + + var msgTimestamp = DiscordUtils.SnowflakeToInstant(msg.Mid); if (isReproxy && !isLatestMessage) if (SystemClock.Instance.GetCurrentInstant() - msgTimestamp > timeout) throw new PKError($"The message is too old to be {editTypeAction}."); - return msg; + return (msg, member.System); } private async Task FindRecentMessage(Context ctx, Duration timeout) @@ -229,7 +235,7 @@ public class ProxiedMessage var isDelete = ctx.Match("delete") || ctx.MatchFlag("delete"); - var message = await ctx.Database.Execute(c => ctx.Repository.GetMessage(c, messageId.Value)); + var message = await ctx.Repository.GetFullMessage(messageId.Value); if (message == null) { if (isDelete) diff --git a/PluralKit.Bot/Handlers/ReactionAdded.cs b/PluralKit.Bot/Handlers/ReactionAdded.cs index 50952e08..0331e218 100644 --- a/PluralKit.Bot/Handlers/ReactionAdded.cs +++ b/PluralKit.Bot/Handlers/ReactionAdded.cs @@ -91,7 +91,7 @@ public class ReactionAdded: IEventHandler // Message deletion case "\u274C": // Red X { - var msg = await _db.Execute(c => _repo.GetMessage(c, evt.MessageId)); + var msg = await _repo.GetMessage(evt.MessageId); if (msg != null) await HandleProxyDeleteReaction(evt, msg); @@ -100,7 +100,7 @@ public class ReactionAdded: IEventHandler case "\u2753": // Red question mark case "\u2754": // White question mark { - var msg = await _db.Execute(c => _repo.GetMessage(c, evt.MessageId)); + var msg = await _repo.GetFullMessage(evt.MessageId); if (msg != null) await HandleQueryReaction(evt, msg); @@ -113,7 +113,7 @@ public class ReactionAdded: IEventHandler case "\u23F0": // Alarm clock case "\u2757": // Exclamation mark { - var msg = await _db.Execute(c => _repo.GetMessage(c, evt.MessageId)); + var msg = await _repo.GetFullMessage(evt.MessageId); if (msg != null) await HandlePingReaction(evt, msg); break; @@ -121,15 +121,15 @@ public class ReactionAdded: IEventHandler } } - private async ValueTask HandleProxyDeleteReaction(MessageReactionAddEvent evt, FullMessage msg) + private async ValueTask HandleProxyDeleteReaction(MessageReactionAddEvent evt, PKMessage msg) { if (!(await _cache.PermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.ManageMessages)) return; - var system = await _repo.GetSystemByAccount(evt.UserId); + var isSameSystem = msg.Member != null && await _repo.IsMemberOwnedByAccount(msg.Member.Value, evt.UserId); - // Can only delete your own message - if (msg.System?.Id != system?.Id && msg.Message.Sender != evt.UserId) return; + // Can only delete your own message (same system or same Discord account) + if (!isSameSystem && msg.Sender != evt.UserId) return; try { diff --git a/PluralKit.Bot/Proxy/ProxyService.cs b/PluralKit.Bot/Proxy/ProxyService.cs index abdd4180..52e12b2b 100644 --- a/PluralKit.Bot/Proxy/ProxyService.cs +++ b/PluralKit.Bot/Proxy/ProxyService.cs @@ -375,7 +375,7 @@ public class ProxyService // cache is out of date or channel is empty. return proxyName; - var pkMessage = await _db.Execute(conn => _repo.GetMessage(conn, lastMessage.Id)); + var pkMessage = await _repo.GetMessage(lastMessage.Id); if (lastMessage.AuthorUsername == proxyName) { @@ -385,12 +385,12 @@ public class ProxyService return FixSameNameInner(proxyName); // last message was proxied by a different member - if (pkMessage.Member?.Id != member.Id) + if (pkMessage.Member != member.Id) return FixSameNameInner(proxyName); } // if we fixed the name last message and it's the same member proxying, we want to fix it again - if (lastMessage.AuthorUsername == FixSameNameInner(proxyName) && pkMessage?.Member?.Id == member.Id) + if (lastMessage.AuthorUsername == FixSameNameInner(proxyName) && pkMessage?.Member == member.Id) return FixSameNameInner(proxyName); // No issues found, current proxy name is fine. diff --git a/PluralKit.Core/CoreConfig.cs b/PluralKit.Core/CoreConfig.cs index a4f06c7a..c70f9791 100644 --- a/PluralKit.Core/CoreConfig.cs +++ b/PluralKit.Core/CoreConfig.cs @@ -5,6 +5,7 @@ namespace PluralKit.Core; public class CoreConfig { public string Database { get; set; } + public string? MessagesDatabase { get; set; } public string? DatabasePassword { get; set; } public string RedisAddr { get; set; } public bool UseRedisMetrics { get; set; } = false; diff --git a/PluralKit.Core/Database/Database.cs b/PluralKit.Core/Database/Database.cs index ca11d683..d7625dd9 100644 --- a/PluralKit.Core/Database/Database.cs +++ b/PluralKit.Core/Database/Database.cs @@ -25,6 +25,7 @@ internal partial class Database: IDatabase private readonly DbConnectionCountHolder _countHolder; private readonly DatabaseMigrator _migrator; private readonly string _connectionString; + private readonly string _messagesConnectionString; public Database(CoreConfig config, DbConnectionCountHolder countHolder, ILogger logger, IMetrics metrics, DatabaseMigrator migrator) @@ -35,20 +36,26 @@ internal partial class Database: IDatabase _migrator = migrator; _logger = logger.ForContext(); - var connectionString = new NpgsqlConnectionStringBuilder(_config.Database) + string connectionString(string src) { - Pooling = true, - Enlist = false, - NoResetOnClose = true, + var builder = new NpgsqlConnectionStringBuilder(src) + { + Pooling = true, + Enlist = false, + NoResetOnClose = true, - // Lower timeout than default (15s -> 2s), should ideally fail-fast instead of hanging - Timeout = 2 - }; + // Lower timeout than default (15s -> 2s), should ideally fail-fast instead of hanging + Timeout = 2 + }; - if (_config.DatabasePassword != null) - connectionString.Password = _config.DatabasePassword; + if (_config.DatabasePassword != null) + builder.Password = _config.DatabasePassword; - _connectionString = connectionString.ConnectionString; + return builder.ConnectionString; + } + + _connectionString = connectionString(_config.Database); + _messagesConnectionString = connectionString(_config.MessagesDatabase ?? _config.Database); } private static readonly PostgresCompiler _compiler = new(); @@ -88,14 +95,14 @@ internal partial class Database: IDatabase } // TODO: make sure every SQL query is behind a logged query method - public async Task Obtain() + public async Task Obtain(bool messages = false) { // Mark the request (for a handle, I guess) in the metrics _metrics.Measure.Meter.Mark(CoreMetrics.DatabaseRequests); // Create a connection and open it // We wrap it in PKConnection for tracing purposes - var conn = new PKConnection(new NpgsqlConnection(_connectionString), _countHolder, _logger, _metrics); + var conn = new PKConnection(new NpgsqlConnection(messages ? _messagesConnectionString : _connectionString), _countHolder, _logger, _metrics); await conn.OpenAsync(); return conn; } diff --git a/PluralKit.Core/Database/DatabaseQueries.cs b/PluralKit.Core/Database/DatabaseQueries.cs index a41e101b..073c99dd 100644 --- a/PluralKit.Core/Database/DatabaseQueries.cs +++ b/PluralKit.Core/Database/DatabaseQueries.cs @@ -31,10 +31,17 @@ internal partial class Database: IDatabase yield return val; } - public async Task ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "") + public async Task QueryFirst(string q, object param = null, [CallerMemberName] string queryName = "", bool messages = false) + { + using var conn = await Obtain(messages); + using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName))) + return await conn.QueryFirstOrDefaultAsync(q, param); + } + + public async Task ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "", bool messages = false) { var query = _compiler.Compile(q); - using var conn = await Obtain(); + using var conn = await Obtain(messages); using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName))) return await conn.ExecuteAsync(query.Sql + $" {extraSql}", query.NamedBindings); } diff --git a/PluralKit.Core/Database/IDatabase.cs b/PluralKit.Core/Database/IDatabase.cs index f49153ba..8de9870f 100644 --- a/PluralKit.Core/Database/IDatabase.cs +++ b/PluralKit.Core/Database/IDatabase.cs @@ -7,15 +7,16 @@ namespace PluralKit.Core; public interface IDatabase { Task ApplyMigrations(); - Task Obtain(); + Task Obtain(bool messages = false); Task Execute(Func func); Task Execute(Func> func); IAsyncEnumerable Execute(Func> func); - Task ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = ""); + Task ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "", bool messages = false); Task ExecuteQuery(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = ""); + Task QueryFirst(string q, object param = null, [CallerMemberName] string queryName = "", bool messages = false); Task QueryFirst(Query q, string extraSql = "", [CallerMemberName] string queryName = ""); Task QueryFirst(IPKConnection? conn, Query q, string extraSql = "", diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Member.cs b/PluralKit.Core/Database/Repository/ModelRepository.Member.cs index ae548602..d676b11a 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Member.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Member.cs @@ -89,4 +89,12 @@ public partial class ModelRepository if (oldMember != null) _ = _dispatch.Dispatch(oldMember.System, oldMember.Uuid, DispatchEvent.DELETE_MEMBER); } + + public async Task IsMemberOwnedByAccount(MemberId id, ulong userId) + { + return await _db.QueryFirst( + "select true from accounts, members where members.id = @member and accounts.uid = @account and members.system = accounts.system", + new { member = id, account = userId } + ); + } } \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Message.cs b/PluralKit.Core/Database/Repository/ModelRepository.Message.cs index 9406b6ff..c1a77609 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Message.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Message.cs @@ -20,29 +20,38 @@ public partial class ModelRepository _logger.Debug("Stored message {@StoredMessage} in channel {Channel}", msg, msg.Channel); // "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before - return _db.ExecuteQuery(query, "on conflict do nothing"); + return _db.ExecuteQuery(query, "on conflict do nothing", messages: true); } - // todo: add a Mapper to QuerySingle and move this to SqlKata - public async Task GetMessage(IPKConnection conn, ulong id) + public async Task GetMessage(ulong id) { - FullMessage Mapper(PKMessage msg, PKMember member, PKSystem system) => - new() { Message = msg, System = system, Member = member }; + return await _db.QueryFirst( + "select * from messages where mid = @Id", + new { Id = id }, + messages: true + ); + } - var query = "select * from messages" - + " left join members on messages.member = members.id" - + " left join systems on members.system = systems.id" - + " where (mid = @Id or original_mid = @Id)"; + public async Task GetFullMessage(ulong id) + { + var rawMessage = await GetMessage(id); + if (rawMessage == null) return null; - var result = await conn.QueryAsync( - query, Mapper, new { Id = id }); - return result.FirstOrDefault(); + var member = rawMessage.Member == null ? null : await GetMember(rawMessage.Member.Value); + var system = member == null ? null : await GetSystem(member.System); + + return new FullMessage + { + Message = rawMessage, + Member = member, + System = system, + }; } public async Task DeleteMessage(ulong id) { var query = new Query("messages").AsDelete().Where("mid", id); - var rowCount = await _db.ExecuteQuery(query); + var rowCount = await _db.ExecuteQuery(query, messages: true); if (rowCount > 0) _logger.Information("Deleted message {MessageId} from database", id); } @@ -52,22 +61,9 @@ public partial class ModelRepository // Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong // Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway) var query = new Query("messages").AsDelete().WhereIn("mid", ids.Select(id => (long)id).ToArray()); - var rowCount = await _db.ExecuteQuery(query); + var rowCount = await _db.ExecuteQuery(query, messages: true); if (rowCount > 0) _logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount, ids); } - - public Task GetLastMessage(ulong guildId, ulong channelId, ulong accountId) - { - // Want to index scan on the (guild, sender, mid) index so need the additional constraint - var query = new Query("messages") - .Where("guild", guildId) - .Where("channel", channelId) - .Where("sender", accountId) - .OrderByDesc("mid") - .Limit(1); - - return _db.QueryFirst(query); - } } \ No newline at end of file diff --git a/services/scheduled_tasks/db.go b/services/scheduled_tasks/db.go index 62e7cbab..e16076a2 100644 --- a/services/scheduled_tasks/db.go +++ b/services/scheduled_tasks/db.go @@ -8,6 +8,7 @@ import ( ) var data_db *pgx.Conn +var messages_db *pgx.Conn var stats_db *pgx.Conn var rdb *redis.Client @@ -20,6 +21,7 @@ func run_simple_pg_query(c *pgx.Conn, sql string) { func connect_dbs() { data_db = pg_connect(get_env_var("DATA_DB_URI")) + messages_db = pg_connect(get_env_var("MESSAGES_DB_URI")) stats_db = pg_connect(get_env_var("STATS_DB_URI")) rdb = redis_connect(get_env_var("REDIS_ADDR")) } diff --git a/services/scheduled_tasks/repo.go b/services/scheduled_tasks/repo.go index f12ad2d7..7813bb5c 100644 --- a/services/scheduled_tasks/repo.go +++ b/services/scheduled_tasks/repo.go @@ -37,6 +37,15 @@ func run_redis_query() []rstatval { return values } +func get_message_count() int { + var count int + row := messages_db.QueryRow(context.Background(), "select count(*) as count from systems") + if err := row.Scan(&count); err != nil { + panic(err) + } + return count +} + func run_data_stats_query() map[string]interface{} { s := map[string]interface{}{} diff --git a/services/scheduled_tasks/tasks.go b/services/scheduled_tasks/tasks.go index 6ea87225..374cd90f 100644 --- a/services/scheduled_tasks/tasks.go +++ b/services/scheduled_tasks/tasks.go @@ -36,14 +36,12 @@ func update_db_meta() { } func update_db_message_meta() { - // since we're doing this concurrently, it needs a separate db connection - tmp_db := pg_connect(get_env_var("DATA_DB_URI")) - defer tmp_db.Close(context.Background()) + count := get_message_count() - key := "message" - q := fmt.Sprintf("update info set %s_count = (select count(*) from %s)", key, plural(key)) - log.Println("data db query:", q) - run_simple_pg_query(tmp_db, q) + _, err := data_db.Exec(context.Background(), "update info set message_count = $1", count) + if err != nil { + panic(err) + } } func update_stats() {