diff --git a/Myriad/Cache/IDiscordCache.cs b/Myriad/Cache/IDiscordCache.cs index 7f522afc..a9ecf4de 100644 --- a/Myriad/Cache/IDiscordCache.cs +++ b/Myriad/Cache/IDiscordCache.cs @@ -20,7 +20,6 @@ public interface IDiscordCache public Task GetOwnUser(); public Task TryGetGuild(ulong guildId); public Task TryGetChannel(ulong channelId); - public Task TryGetDmChannel(ulong userId); public Task TryGetUser(ulong userId); public Task TryGetSelfMember(ulong guildId); public Task TryGetRole(ulong roleId); diff --git a/Myriad/Cache/MemoryDiscordCache.cs b/Myriad/Cache/MemoryDiscordCache.cs index b715266e..6ff9f48a 100644 --- a/Myriad/Cache/MemoryDiscordCache.cs +++ b/Myriad/Cache/MemoryDiscordCache.cs @@ -7,7 +7,6 @@ namespace Myriad.Cache; public class MemoryDiscordCache: IDiscordCache { private readonly ConcurrentDictionary _channels = new(); - private readonly ConcurrentDictionary _dmChannels = new(); private readonly ConcurrentDictionary _guildMembers = new(); private readonly ConcurrentDictionary _guilds = new(); private readonly ConcurrentDictionary _roles = new(); @@ -35,10 +34,7 @@ public class MemoryDiscordCache: IDiscordCache if (channel.Recipients != null) foreach (var recipient in channel.Recipients) - { - _dmChannels[recipient.Id] = channel.Id; await SaveUser(recipient); - } } public ValueTask SaveOwnUser(ulong userId) @@ -140,13 +136,6 @@ public class MemoryDiscordCache: IDiscordCache return Task.FromResult(channel); } - public Task TryGetDmChannel(ulong userId) - { - if (!_dmChannels.TryGetValue(userId, out var channelId)) - return Task.FromResult((Channel?)null); - return TryGetChannel(channelId); - } - public Task TryGetUser(ulong userId) { _users.TryGetValue(userId, out var user); diff --git a/Myriad/Extensions/CacheExtensions.cs b/Myriad/Extensions/CacheExtensions.cs index 8e491031..17660002 100644 --- a/Myriad/Extensions/CacheExtensions.cs +++ b/Myriad/Extensions/CacheExtensions.cs @@ -58,17 +58,6 @@ public static class CacheExtensions return restChannel; } - public static async Task GetOrCreateDmChannel(this IDiscordCache cache, DiscordApiClient rest, - ulong recipientId) - { - if (await cache.TryGetDmChannel(recipientId) is { } cacheChannel) - return cacheChannel; - - var restChannel = await rest.CreateDm(recipientId); - await cache.SaveChannel(restChannel); - return restChannel; - } - public static async Task GetRootChannel(this IDiscordCache cache, ulong channelOrThread) { var channel = await cache.GetChannel(channelOrThread); diff --git a/PluralKit.Bot/Commands/Api.cs b/PluralKit.Bot/Commands/Api.cs index 7b0cc3ce..6911c0f4 100644 --- a/PluralKit.Bot/Commands/Api.cs +++ b/PluralKit.Bot/Commands/Api.cs @@ -17,12 +17,14 @@ public class Api private readonly BotConfig _botConfig; private readonly DispatchService _dispatch; private readonly ModelRepository _repo; + private readonly PrivateChannelService _dmCache; - public Api(BotConfig botConfig, ModelRepository repo, DispatchService dispatch) + public Api(BotConfig botConfig, ModelRepository repo, DispatchService dispatch, PrivateChannelService dmCache) { _botConfig = botConfig; _repo = repo; _dispatch = dispatch; + _dmCache = dmCache; } public async Task GetToken(Context ctx) @@ -35,17 +37,17 @@ public class Api try { // DM the user a security disclaimer, and then the token in a separate message (for easy copying on mobile) - var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id); - await ctx.Rest.CreateMessage(dm.Id, + var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id); + await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = $"{Emojis.Warn} Please note that this grants access to modify (and delete!) all your system data, so keep it safe and secure." + $" If it leaks or you need a new one, you can invalidate this one with `pk;token refresh`.\n\nYour token is below:" }); - await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = token }); + await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = token }); if (_botConfig.IsBetaBot) - await ctx.Rest.CreateMessage(dm.Id, new MessageRequest + await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = $"{Emojis.Note} The beta bot's API base URL is currently <{_botConfig.BetaBotAPIUrl}>." + " You need to use this URL instead of the base URL listed on the documentation website." @@ -84,8 +86,8 @@ public class Api try { // DM the user an invalidation disclaimer, and then the token in a separate message (for easy copying on mobile) - var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id); - await ctx.Rest.CreateMessage(dm.Id, + var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id); + await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = $"{Emojis.Warn} Your previous API token has been invalidated. You will need to change it anywhere it's currently used.\n\nYour token is below:" @@ -94,10 +96,10 @@ public class Api // Make the new token after sending the first DM; this ensures if we can't DM, we also don't end up // breaking their existing token as a side effect :) var token = await MakeAndSetNewToken(ctx.System); - await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = token }); + await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = token }); if (_botConfig.IsBetaBot) - await ctx.Rest.CreateMessage(dm.Id, new MessageRequest + await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = $"{Emojis.Note} The beta bot's API base URL is currently <{_botConfig.BetaBotAPIUrl}>." + " You need to use this URL instead of the base URL listed on the documentation website." diff --git a/PluralKit.Bot/Commands/ImportExport.cs b/PluralKit.Bot/Commands/ImportExport.cs index 3981899f..046e9ce2 100644 --- a/PluralKit.Bot/Commands/ImportExport.cs +++ b/PluralKit.Bot/Commands/ImportExport.cs @@ -17,6 +17,7 @@ public class ImportExport { private readonly HttpClient _client; private readonly DataFileService _dataFiles; + private readonly PrivateChannelService _dmCache; private readonly JsonSerializerSettings _settings = new() { @@ -24,10 +25,11 @@ public class ImportExport DateParseHandling = DateParseHandling.None }; - public ImportExport(DataFileService dataFiles, HttpClient client) + public ImportExport(DataFileService dataFiles, HttpClient client, PrivateChannelService dmCache) { _dataFiles = dataFiles; _client = client; + _dmCache = dmCache; } public async Task Import(Context ctx) @@ -110,12 +112,12 @@ public class ImportExport try { - var dm = await ctx.Cache.GetOrCreateDmChannel(ctx.Rest, ctx.Author.Id); + var dm = await _dmCache.GetOrCreateDmChannel(ctx.Author.Id); - var msg = await ctx.Rest.CreateMessage(dm.Id, + var msg = await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = $"{Emojis.Success} Here you go!" }, new[] { new MultipartFile("system.json", stream, null) }); - await ctx.Rest.CreateMessage(dm.Id, new MessageRequest { Content = $"<{msg.Attachments[0].Url}>" }); + await ctx.Rest.CreateMessage(dm, new MessageRequest { Content = $"<{msg.Attachments[0].Url}>" }); // If the original message wasn't posted in DMs, send a public reminder if (ctx.Channel.Type != Channel.ChannelType.Dm) diff --git a/PluralKit.Bot/Handlers/MessageCreated.cs b/PluralKit.Bot/Handlers/MessageCreated.cs index d97fbf05..68f32b17 100644 --- a/PluralKit.Bot/Handlers/MessageCreated.cs +++ b/PluralKit.Bot/Handlers/MessageCreated.cs @@ -28,12 +28,13 @@ public class MessageCreated: IEventHandler private readonly DiscordApiClient _rest; private readonly ILifetimeScope _services; private readonly CommandTree _tree; + private readonly PrivateChannelService _dmCache; public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean, IMetrics metrics, ProxyService proxy, CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config, ModelRepository repo, IDiscordCache cache, - Bot bot, Cluster cluster, DiscordApiClient rest) + Bot bot, Cluster cluster, DiscordApiClient rest, PrivateChannelService dmCache) { _lastMessageCache = lastMessageCache; _loggerClean = loggerClean; @@ -48,6 +49,7 @@ public class MessageCreated: IEventHandler _bot = bot; _cluster = cluster; _rest = rest; + _dmCache = dmCache; } // for now, only return error messages for explicit commands @@ -66,6 +68,10 @@ public class MessageCreated: IEventHandler if (evt.Type != Message.MessageType.Default && evt.Type != Message.MessageType.Reply) return; if (IsDuplicateMessage(evt)) return; + // spawn off saving the private channel into another thread + // it is not a fatal error if this fails, and it shouldn't block message processing + _ = _dmCache.TrySavePrivateChannel(evt); + var guild = evt.GuildId != null ? await _cache.GetGuild(evt.GuildId.Value) : null; var channel = await _cache.GetChannel(evt.ChannelId); var rootChannel = await _cache.GetRootChannel(evt.ChannelId); diff --git a/PluralKit.Bot/Handlers/ReactionAdded.cs b/PluralKit.Bot/Handlers/ReactionAdded.cs index 1b8bd461..0d52e6ef 100644 --- a/PluralKit.Bot/Handlers/ReactionAdded.cs +++ b/PluralKit.Bot/Handlers/ReactionAdded.cs @@ -26,10 +26,11 @@ public class ReactionAdded: IEventHandler private readonly ILogger _logger; private readonly ModelRepository _repo; private readonly DiscordApiClient _rest; + private readonly PrivateChannelService _dmCache; public ReactionAdded(ILogger logger, IDatabase db, ModelRepository repo, CommandMessageService commandMessageService, IDiscordCache cache, Bot bot, Cluster cluster, - DiscordApiClient rest, EmbedService embeds) + DiscordApiClient rest, EmbedService embeds, PrivateChannelService dmCache) { _db = db; _repo = repo; @@ -40,6 +41,7 @@ public class ReactionAdded: IEventHandler _rest = rest; _embeds = embeds; _logger = logger.ForContext(); + _dmCache = dmCache; } public async Task Handle(int shardId, MessageReactionAddEvent evt) @@ -168,9 +170,9 @@ public class ReactionAdded: IEventHandler // Try to DM the user info about the message try { - var dm = await _cache.GetOrCreateDmChannel(_rest, evt.UserId); + var dm = await _dmCache.GetOrCreateDmChannel(evt.UserId); if (msg.Member != null) - await _rest.CreateMessage(dm.Id, new MessageRequest + await _rest.CreateMessage(dm, new MessageRequest { Embed = await _embeds.CreateMemberEmbed( msg.System, @@ -182,7 +184,7 @@ public class ReactionAdded: IEventHandler }); await _rest.CreateMessage( - dm.Id, + dm, new MessageRequest { Embed = await _embeds.CreateMessageInfoEmbed(msg, true) } ); } @@ -234,15 +236,15 @@ public class ReactionAdded: IEventHandler // If not, tell them in DMs (if we can) try { - var dm = await _cache.GetOrCreateDmChannel(_rest, evt.UserId); - await _rest.CreateMessage(dm.Id, + var dm = await _dmCache.GetOrCreateDmChannel(evt.UserId); + await _rest.CreateMessage(dm, new MessageRequest { Content = $"{Emojis.Error} {msg.Member.DisplayName()}'s system has disabled reaction pings. If you want to mention them anyway, you can copy/paste the following message:" }); await _rest.CreateMessage( - dm.Id, + dm, new MessageRequest { Content = $"<@{msg.Message.Sender}>".AsCode() } ); } diff --git a/PluralKit.Bot/Modules.cs b/PluralKit.Bot/Modules.cs index 60ed490c..bf176dd5 100644 --- a/PluralKit.Bot/Modules.cs +++ b/PluralKit.Bot/Modules.cs @@ -43,6 +43,7 @@ public class BotModule: Module }).AsSelf().SingleInstance(); builder.RegisterType().AsSelf().SingleInstance(); builder.Register(c => { return new MemoryDiscordCache(); }).AsSelf().As().SingleInstance(); + builder.RegisterType().AsSelf().SingleInstance(); builder.Register(c => { diff --git a/PluralKit.Bot/Services/PrivateChannelService.cs b/PluralKit.Bot/Services/PrivateChannelService.cs new file mode 100644 index 00000000..4daa58cb --- /dev/null +++ b/PluralKit.Bot/Services/PrivateChannelService.cs @@ -0,0 +1,63 @@ +using Serilog; + +using Myriad.Cache; +using Myriad.Gateway; +using Myriad.Rest; + +using PluralKit.Core; + +namespace PluralKit.Bot; + +public class PrivateChannelService +{ + private readonly ILogger _logger; + private readonly ModelRepository _repo; + private readonly DiscordApiClient _rest; + + private static Dictionary _channelsCache = new(); + public PrivateChannelService(ILogger logger, ModelRepository repo, DiscordApiClient rest) + { + _logger = logger; + _repo = repo; + _rest = rest; + } + + public async Task TrySavePrivateChannel(MessageCreateEvent evt) + { + if (evt.GuildId != null) return; + if (_channelsCache.TryGetValue(evt.Author.Id, out _)) return; + + await SaveDmChannel(evt.Author.Id, evt.ChannelId); + } + + public async Task GetOrCreateDmChannel(ulong userId) + { + if (_channelsCache.TryGetValue(userId, out var cachedChannelId)) + return cachedChannelId; + + var channelId = await _repo.GetDmChannel(userId); + if (channelId == null) + { + var channel = await _rest.CreateDm(userId); + channelId = channel.Id; + } + + // spawn off saving the channel as to not block the current thread + _ = SaveDmChannel(userId, channelId.Value); + + return channelId.Value; + } + + private async Task SaveDmChannel(ulong userId, ulong channelId) + { + try + { + _channelsCache.Add(userId, channelId); + await _repo.UpdateAccount(userId, new() { DmChannel = channelId }); + } + catch (Exception e) + { + _logger.Error(e, "Failed to save DM channel {ChannelId} for user {UserId}", channelId, userId); + } + } +} \ No newline at end of file diff --git a/PluralKit.Core/Database/Migrations/26.sql b/PluralKit.Core/Database/Migrations/26.sql new file mode 100644 index 00000000..a3865649 --- /dev/null +++ b/PluralKit.Core/Database/Migrations/26.sql @@ -0,0 +1,12 @@ +-- schema version 26 +-- cache Discord DM channels in the database + +alter table accounts alter column system drop not null; +alter table accounts drop constraint accounts_system_fkey; +alter table accounts + add constraint accounts_system_fkey + foreign key (system) references systems(id) on delete set null; + +alter table accounts add column dm_channel bigint; + +update info set schema_version = 26; \ No newline at end of file diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Account.cs b/PluralKit.Core/Database/Repository/ModelRepository.Account.cs index cd7d8847..d8676cf5 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Account.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Account.cs @@ -1,9 +1,14 @@ +using Dapper; + using SqlKata; namespace PluralKit.Core; public partial class ModelRepository { + public async Task GetDmChannel(ulong id) + => await _db.Execute(c => c.QueryFirstOrDefaultAsync("select dm_channel from accounts where uid = @id", new { id = id })); + public async Task UpdateAccount(ulong id, AccountPatch patch) { _logger.Information("Updated account {accountId}: {@AccountPatch}", id, patch); diff --git a/PluralKit.Core/Database/Repository/ModelRepository.System.cs b/PluralKit.Core/Database/Repository/ModelRepository.System.cs index 05bb5942..bfaa31d2 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.System.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.System.cs @@ -24,7 +24,8 @@ public partial class ModelRepository var query = new Query("accounts") .Select("systems.*") .LeftJoin("systems", "systems.id", "accounts.system") - .Where("uid", accountId); + .Where("uid", accountId) + .WhereNotNull("system"); return _db.QueryFirst(query); } @@ -111,10 +112,13 @@ public partial class ModelRepository // We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent // This is used in import/export, although the pk;link command checks for this case beforehand + // update 2022-01: the accounts table is now independent of systems + // we MUST check for the presence of a system before inserting, or it will move the new account to the current system + var query = new Query("accounts").AsInsert(new { system, uid = accountId }); + await _db.ExecuteQuery(conn, query, "on conflict (uid) do update set system = @p0"); _logger.Information("Linked account {UserId} to {SystemId}", accountId, system); - await _db.ExecuteQuery(conn, query, "on conflict do nothing"); _ = _dispatch.Dispatch(system, new UpdateDispatchData { @@ -125,7 +129,10 @@ public partial class ModelRepository public async Task RemoveAccount(SystemId system, ulong accountId) { - var query = new Query("accounts").AsDelete().Where("uid", accountId).Where("system", system); + var query = new Query("accounts").AsUpdate(new + { + system = (ulong?)null + }).Where("uid", accountId).Where("system", system); await _db.ExecuteQuery(query); _logger.Information("Unlinked account {UserId} from {SystemId}", accountId, system); _ = _dispatch.Dispatch(system, new UpdateDispatchData diff --git a/PluralKit.Core/Database/Utils/DatabaseMigrator.cs b/PluralKit.Core/Database/Utils/DatabaseMigrator.cs index d735b351..777f3eb6 100644 --- a/PluralKit.Core/Database/Utils/DatabaseMigrator.cs +++ b/PluralKit.Core/Database/Utils/DatabaseMigrator.cs @@ -9,7 +9,7 @@ namespace PluralKit.Core; internal class DatabaseMigrator { private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files - private const int TargetSchemaVersion = 25; + private const int TargetSchemaVersion = 26; private readonly ILogger _logger; public DatabaseMigrator(ILogger logger) diff --git a/PluralKit.Core/Models/Patch/AccountPatch.cs b/PluralKit.Core/Models/Patch/AccountPatch.cs index fd0b8ac6..daba9b9c 100644 --- a/PluralKit.Core/Models/Patch/AccountPatch.cs +++ b/PluralKit.Core/Models/Patch/AccountPatch.cs @@ -6,9 +6,11 @@ namespace PluralKit.Core; public class AccountPatch: PatchObject { + public Partial DmChannel { get; set; } public Partial AllowAutoproxy { get; set; } public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper + .With("dm_channel", DmChannel) .With("allow_autoproxy", AllowAutoproxy) );