feat: split out messages table from main database

This commit is contained in:
spiral 2022-11-23 09:17:19 +00:00
parent 09ac002d26
commit bf7747ab34
No known key found for this signature in database
GPG Key ID: 244A11E4B0BCF40E
14 changed files with 119 additions and 84 deletions

View File

@ -92,7 +92,7 @@ public class DiscordControllerV2: PKControllerBase
[HttpGet("messages/{messageId}")]
public async Task<ActionResult<JObject>> MessageGet(ulong messageId)
{
var msg = await _db.Execute(c => _repo.GetMessage(c, messageId));
var msg = await _repo.GetFullMessage(messageId);
if (msg == null)
throw Errors.MessageNotFound;

View File

@ -197,7 +197,7 @@ public class Checks
if (messageId == null || channelId == null)
throw new PKError(failedToGetMessage);
var proxiedMsg = await ctx.Database.Execute(conn => ctx.Repository.GetMessage(conn, messageId.Value));
var proxiedMsg = await ctx.Repository.GetMessage(messageId.Value);
if (proxiedMsg != null)
{
await ctx.Reply($"{Emojis.Success} This message was proxied successfully.");

View File

@ -55,9 +55,9 @@ public class ProxiedMessage
public async Task ReproxyMessage(Context ctx)
{
var msg = await GetMessageToEdit(ctx, ReproxyTimeout, true);
var (msg, systemId) = await GetMessageToEdit(ctx, ReproxyTimeout, true);
if (ctx.System.Id != msg.System?.Id)
if (ctx.System.Id != systemId)
throw new PKError("Can't reproxy a message sent by a different system.");
// Get target member ID
@ -68,14 +68,14 @@ public class ProxiedMessage
// Fetch members and get the ProxyMember for `target`
List<ProxyMember> members;
using (_metrics.Measure.Timer.Time(BotMetrics.ProxyMembersQueryTime))
members = (await _repo.GetProxyMembers(ctx.Author.Id, msg.Message.Guild!.Value)).ToList();
members = (await _repo.GetProxyMembers(ctx.Author.Id, msg.Guild!.Value)).ToList();
var match = members.Find(x => x.Id == target.Id);
if (match == null)
throw new PKError("Could not find a member to reproxy the message with.");
try
{
await _proxy.ExecuteReproxy(ctx.Message, msg.Message, members, match);
await _proxy.ExecuteReproxy(ctx.Message, msg, members, match);
if (ctx.Guild == null)
await _rest.CreateReaction(ctx.Channel.Id, ctx.Message.Id, new Emoji { Name = Emojis.Success });
@ -90,15 +90,15 @@ public class ProxiedMessage
public async Task EditMessage(Context ctx)
{
var msg = await GetMessageToEdit(ctx, EditTimeout, false);
var (msg, systemId) = await GetMessageToEdit(ctx, EditTimeout, false);
if (ctx.System.Id != msg.System?.Id)
if (ctx.System.Id != systemId)
throw new PKError("Can't edit a message sent by a different system.");
if (!ctx.HasNext())
throw new PKSyntaxError("You need to include the message to edit in.");
var originalMsg = await _rest.GetMessageOrNull(msg.Message.Channel, msg.Message.Mid);
var originalMsg = await _rest.GetMessageOrNull(msg.Channel, msg.Mid);
if (originalMsg == null)
throw new PKError("Could not edit message.");
@ -124,7 +124,7 @@ public class ProxiedMessage
try
{
var editedMsg =
await _webhookExecutor.EditWebhookMessage(msg.Message.Channel, msg.Message.Mid, newContent);
await _webhookExecutor.EditWebhookMessage(msg.Channel, msg.Mid, newContent);
if (ctx.Guild == null)
await _rest.CreateReaction(ctx.Channel.Id, ctx.Message.Id, new Emoji { Name = Emojis.Success });
@ -132,7 +132,7 @@ public class ProxiedMessage
if ((await ctx.BotPermissions).HasFlag(PermissionSet.ManageMessages))
await _rest.DeleteMessage(ctx.Channel.Id, ctx.Message.Id);
await _logChannel.LogMessage(msg.Message, ctx.Message, editedMsg, originalMsg!.Content!);
await _logChannel.LogMessage(msg, ctx.Message, editedMsg, originalMsg!.Content!);
}
catch (NotFoundException)
{
@ -140,18 +140,18 @@ public class ProxiedMessage
}
}
private async Task<FullMessage> GetMessageToEdit(Context ctx, Duration timeout, bool isReproxy)
private async Task<(PKMessage, SystemId)> GetMessageToEdit(Context ctx, Duration timeout, bool isReproxy)
{
var editType = isReproxy ? "reproxy" : "edit";
var editTypeAction = isReproxy ? "reproxied" : "edited";
FullMessage? msg = null;
PKMessage? msg = null;
var (referencedMessage, _) = ctx.MatchMessage(false);
if (referencedMessage != null)
{
await using var conn = await ctx.Database.Obtain();
msg = await ctx.Repository.GetMessage(conn, referencedMessage.Value);
msg = await ctx.Repository.GetMessage(referencedMessage.Value);
if (msg == null)
throw new PKError("This is not a message proxied by PluralKit.");
}
@ -161,7 +161,7 @@ public class ProxiedMessage
if (ctx.Guild == null)
throw new PKSyntaxError($"You must use a message link to {editType} messages in DMs.");
PKMessage? recent;
ulong? recent = null;
if (isReproxy)
recent = await ctx.Repository.GetLastMessage(ctx.Guild.Id, ctx.Channel.Id, ctx.Author.Id);
@ -172,17 +172,21 @@ public class ProxiedMessage
throw new PKSyntaxError($"Could not find a recent message to {editType}.");
await using var conn = await ctx.Database.Obtain();
msg = await ctx.Repository.GetMessage(conn, recent.Mid);
msg = await ctx.Repository.GetMessage(recent.Value);
if (msg == null)
throw new PKSyntaxError($"Could not find a recent message to {editType}.");
}
if (msg.Message.Channel != ctx.Channel.Id)
var member = await ctx.Repository.GetMember(msg.Member!.Value);
if (member == null)
throw new PKSyntaxError($"Could not find a recent message to {editType}.");
if (msg.Channel != ctx.Channel.Id)
{
var error =
"The channel where the message was sent does not exist anymore, or you are missing permissions to access it.";
var channel = await _rest.GetChannelOrNull(msg.Message.Channel);
var channel = await _rest.GetChannelOrNull(msg.Channel);
if (channel == null)
throw new PKError(error);
@ -192,16 +196,18 @@ public class ProxiedMessage
throw new PKError(error);
}
var isLatestMessage = _lastMessageCache.GetLastMessage(ctx.Message.ChannelId)?.Current.Id == ctx.Message.Id
? _lastMessageCache.GetLastMessage(ctx.Message.ChannelId)?.Previous?.Id == msg.Message.Mid
: _lastMessageCache.GetLastMessage(ctx.Message.ChannelId)?.Current.Id == msg.Message.Mid;
var lastMessage = _lastMessageCache.GetLastMessage(ctx.Message.ChannelId);
var msgTimestamp = DiscordUtils.SnowflakeToInstant(msg.Message.Mid);
var isLatestMessage = lastMessage?.Current.Id == ctx.Message.Id
? lastMessage?.Previous?.Id == msg.Mid
: lastMessage?.Current.Id == msg.Mid;
var msgTimestamp = DiscordUtils.SnowflakeToInstant(msg.Mid);
if (isReproxy && !isLatestMessage)
if (SystemClock.Instance.GetCurrentInstant() - msgTimestamp > timeout)
throw new PKError($"The message is too old to be {editTypeAction}.");
return msg;
return (msg, member.System);
}
private async Task<PKMessage?> FindRecentMessage(Context ctx, Duration timeout)
@ -229,7 +235,7 @@ public class ProxiedMessage
var isDelete = ctx.Match("delete") || ctx.MatchFlag("delete");
var message = await ctx.Database.Execute(c => ctx.Repository.GetMessage(c, messageId.Value));
var message = await ctx.Repository.GetFullMessage(messageId.Value);
if (message == null)
{
if (isDelete)

View File

@ -91,7 +91,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
// Message deletion
case "\u274C": // Red X
{
var msg = await _db.Execute(c => _repo.GetMessage(c, evt.MessageId));
var msg = await _repo.GetMessage(evt.MessageId);
if (msg != null)
await HandleProxyDeleteReaction(evt, msg);
@ -100,7 +100,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
case "\u2753": // Red question mark
case "\u2754": // White question mark
{
var msg = await _db.Execute(c => _repo.GetMessage(c, evt.MessageId));
var msg = await _repo.GetFullMessage(evt.MessageId);
if (msg != null)
await HandleQueryReaction(evt, msg);
@ -113,7 +113,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
case "\u23F0": // Alarm clock
case "\u2757": // Exclamation mark
{
var msg = await _db.Execute(c => _repo.GetMessage(c, evt.MessageId));
var msg = await _repo.GetFullMessage(evt.MessageId);
if (msg != null)
await HandlePingReaction(evt, msg);
break;
@ -121,15 +121,15 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
}
}
private async ValueTask HandleProxyDeleteReaction(MessageReactionAddEvent evt, FullMessage msg)
private async ValueTask HandleProxyDeleteReaction(MessageReactionAddEvent evt, PKMessage msg)
{
if (!(await _cache.PermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.ManageMessages))
return;
var system = await _repo.GetSystemByAccount(evt.UserId);
var isSameSystem = msg.Member != null && await _repo.IsMemberOwnedByAccount(msg.Member.Value, evt.UserId);
// Can only delete your own message
if (msg.System?.Id != system?.Id && msg.Message.Sender != evt.UserId) return;
// Can only delete your own message (same system or same Discord account)
if (!isSameSystem && msg.Sender != evt.UserId) return;
try
{

View File

@ -375,7 +375,7 @@ public class ProxyService
// cache is out of date or channel is empty.
return proxyName;
var pkMessage = await _db.Execute(conn => _repo.GetMessage(conn, lastMessage.Id));
var pkMessage = await _repo.GetMessage(lastMessage.Id);
if (lastMessage.AuthorUsername == proxyName)
{
@ -385,12 +385,12 @@ public class ProxyService
return FixSameNameInner(proxyName);
// last message was proxied by a different member
if (pkMessage.Member?.Id != member.Id)
if (pkMessage.Member != member.Id)
return FixSameNameInner(proxyName);
}
// if we fixed the name last message and it's the same member proxying, we want to fix it again
if (lastMessage.AuthorUsername == FixSameNameInner(proxyName) && pkMessage?.Member?.Id == member.Id)
if (lastMessage.AuthorUsername == FixSameNameInner(proxyName) && pkMessage?.Member == member.Id)
return FixSameNameInner(proxyName);
// No issues found, current proxy name is fine.

View File

@ -5,6 +5,7 @@ namespace PluralKit.Core;
public class CoreConfig
{
public string Database { get; set; }
public string? MessagesDatabase { get; set; }
public string? DatabasePassword { get; set; }
public string RedisAddr { get; set; }
public bool UseRedisMetrics { get; set; } = false;

View File

@ -25,6 +25,7 @@ internal partial class Database: IDatabase
private readonly DbConnectionCountHolder _countHolder;
private readonly DatabaseMigrator _migrator;
private readonly string _connectionString;
private readonly string _messagesConnectionString;
public Database(CoreConfig config, DbConnectionCountHolder countHolder, ILogger logger,
IMetrics metrics, DatabaseMigrator migrator)
@ -35,20 +36,26 @@ internal partial class Database: IDatabase
_migrator = migrator;
_logger = logger.ForContext<Database>();
var connectionString = new NpgsqlConnectionStringBuilder(_config.Database)
string connectionString(string src)
{
Pooling = true,
Enlist = false,
NoResetOnClose = true,
var builder = new NpgsqlConnectionStringBuilder(src)
{
Pooling = true,
Enlist = false,
NoResetOnClose = true,
// Lower timeout than default (15s -> 2s), should ideally fail-fast instead of hanging
Timeout = 2
};
// Lower timeout than default (15s -> 2s), should ideally fail-fast instead of hanging
Timeout = 2
};
if (_config.DatabasePassword != null)
connectionString.Password = _config.DatabasePassword;
if (_config.DatabasePassword != null)
builder.Password = _config.DatabasePassword;
_connectionString = connectionString.ConnectionString;
return builder.ConnectionString;
}
_connectionString = connectionString(_config.Database);
_messagesConnectionString = connectionString(_config.MessagesDatabase ?? _config.Database);
}
private static readonly PostgresCompiler _compiler = new();
@ -88,14 +95,14 @@ internal partial class Database: IDatabase
}
// TODO: make sure every SQL query is behind a logged query method
public async Task<IPKConnection> Obtain()
public async Task<IPKConnection> Obtain(bool messages = false)
{
// Mark the request (for a handle, I guess) in the metrics
_metrics.Measure.Meter.Mark(CoreMetrics.DatabaseRequests);
// Create a connection and open it
// We wrap it in PKConnection for tracing purposes
var conn = new PKConnection(new NpgsqlConnection(_connectionString), _countHolder, _logger, _metrics);
var conn = new PKConnection(new NpgsqlConnection(messages ? _messagesConnectionString : _connectionString), _countHolder, _logger, _metrics);
await conn.OpenAsync();
return conn;
}

View File

@ -31,10 +31,17 @@ internal partial class Database: IDatabase
yield return val;
}
public async Task<int> ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "")
public async Task<T> QueryFirst<T>(string q, object param = null, [CallerMemberName] string queryName = "", bool messages = false)
{
using var conn = await Obtain(messages);
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
return await conn.QueryFirstOrDefaultAsync<T>(q, param);
}
public async Task<int> ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "", bool messages = false)
{
var query = _compiler.Compile(q);
using var conn = await Obtain();
using var conn = await Obtain(messages);
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
return await conn.ExecuteAsync(query.Sql + $" {extraSql}", query.NamedBindings);
}

View File

@ -7,15 +7,16 @@ namespace PluralKit.Core;
public interface IDatabase
{
Task ApplyMigrations();
Task<IPKConnection> Obtain();
Task<IPKConnection> Obtain(bool messages = false);
Task Execute(Func<IPKConnection, Task> func);
Task<T> Execute<T>(Func<IPKConnection, Task<T>> func);
IAsyncEnumerable<T> Execute<T>(Func<IPKConnection, IAsyncEnumerable<T>> func);
Task<int> ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "");
Task<int> ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "", bool messages = false);
Task<int> ExecuteQuery(IPKConnection? conn, Query q, string extraSql = "",
[CallerMemberName] string queryName = "");
Task<T> QueryFirst<T>(string q, object param = null, [CallerMemberName] string queryName = "", bool messages = false);
Task<T> QueryFirst<T>(Query q, string extraSql = "", [CallerMemberName] string queryName = "");
Task<T> QueryFirst<T>(IPKConnection? conn, Query q, string extraSql = "",

View File

@ -89,4 +89,12 @@ public partial class ModelRepository
if (oldMember != null)
_ = _dispatch.Dispatch(oldMember.System, oldMember.Uuid, DispatchEvent.DELETE_MEMBER);
}
public async Task<bool> IsMemberOwnedByAccount(MemberId id, ulong userId)
{
return await _db.QueryFirst<bool>(
"select true from accounts, members where members.id = @member and accounts.uid = @account and members.system = accounts.system",
new { member = id, account = userId }
);
}
}

View File

@ -20,29 +20,38 @@ public partial class ModelRepository
_logger.Debug("Stored message {@StoredMessage} in channel {Channel}", msg, msg.Channel);
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
return _db.ExecuteQuery(query, "on conflict do nothing");
return _db.ExecuteQuery(query, "on conflict do nothing", messages: true);
}
// todo: add a Mapper to QuerySingle and move this to SqlKata
public async Task<FullMessage?> GetMessage(IPKConnection conn, ulong id)
public async Task<PKMessage?> GetMessage(ulong id)
{
FullMessage Mapper(PKMessage msg, PKMember member, PKSystem system) =>
new() { Message = msg, System = system, Member = member };
return await _db.QueryFirst<PKMessage?>(
"select * from messages where mid = @Id",
new { Id = id },
messages: true
);
}
var query = "select * from messages"
+ " left join members on messages.member = members.id"
+ " left join systems on members.system = systems.id"
+ " where (mid = @Id or original_mid = @Id)";
public async Task<FullMessage?> GetFullMessage(ulong id)
{
var rawMessage = await GetMessage(id);
if (rawMessage == null) return null;
var result = await conn.QueryAsync<PKMessage, PKMember, PKSystem, FullMessage>(
query, Mapper, new { Id = id });
return result.FirstOrDefault();
var member = rawMessage.Member == null ? null : await GetMember(rawMessage.Member.Value);
var system = member == null ? null : await GetSystem(member.System);
return new FullMessage
{
Message = rawMessage,
Member = member,
System = system,
};
}
public async Task DeleteMessage(ulong id)
{
var query = new Query("messages").AsDelete().Where("mid", id);
var rowCount = await _db.ExecuteQuery(query);
var rowCount = await _db.ExecuteQuery(query, messages: true);
if (rowCount > 0)
_logger.Information("Deleted message {MessageId} from database", id);
}
@ -52,22 +61,9 @@ public partial class ModelRepository
// Npgsql doesn't support ulongs in general - we hacked around it for plain ulongs but tbh not worth it for collections of ulong
// Hence we map them to single longs, which *are* supported (this is ok since they're Technically (tm) stored as signed longs in the db anyway)
var query = new Query("messages").AsDelete().WhereIn("mid", ids.Select(id => (long)id).ToArray());
var rowCount = await _db.ExecuteQuery(query);
var rowCount = await _db.ExecuteQuery(query, messages: true);
if (rowCount > 0)
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount,
ids);
}
public Task<PKMessage?> GetLastMessage(ulong guildId, ulong channelId, ulong accountId)
{
// Want to index scan on the (guild, sender, mid) index so need the additional constraint
var query = new Query("messages")
.Where("guild", guildId)
.Where("channel", channelId)
.Where("sender", accountId)
.OrderByDesc("mid")
.Limit(1);
return _db.QueryFirst<PKMessage?>(query);
}
}

View File

@ -8,6 +8,7 @@ import (
)
var data_db *pgx.Conn
var messages_db *pgx.Conn
var stats_db *pgx.Conn
var rdb *redis.Client
@ -20,6 +21,7 @@ func run_simple_pg_query(c *pgx.Conn, sql string) {
func connect_dbs() {
data_db = pg_connect(get_env_var("DATA_DB_URI"))
messages_db = pg_connect(get_env_var("MESSAGES_DB_URI"))
stats_db = pg_connect(get_env_var("STATS_DB_URI"))
rdb = redis_connect(get_env_var("REDIS_ADDR"))
}

View File

@ -37,6 +37,15 @@ func run_redis_query() []rstatval {
return values
}
func get_message_count() int {
var count int
row := messages_db.QueryRow(context.Background(), "select count(*) as count from systems")
if err := row.Scan(&count); err != nil {
panic(err)
}
return count
}
func run_data_stats_query() map[string]interface{} {
s := map[string]interface{}{}

View File

@ -36,14 +36,12 @@ func update_db_meta() {
}
func update_db_message_meta() {
// since we're doing this concurrently, it needs a separate db connection
tmp_db := pg_connect(get_env_var("DATA_DB_URI"))
defer tmp_db.Close(context.Background())
count := get_message_count()
key := "message"
q := fmt.Sprintf("update info set %s_count = (select count(*) from %s)", key, plural(key))
log.Println("data db query:", q)
run_simple_pg_query(tmp_db, q)
_, err := data_db.Exec(context.Background(), "update info set message_count = $1", count)
if err != nil {
panic(err)
}
}
func update_stats() {