diff --git a/PluralKit.Bot/Bot.cs b/PluralKit.Bot/Bot.cs index d734f403..289d1941 100644 --- a/PluralKit.Bot/Bot.cs +++ b/PluralKit.Bot/Bot.cs @@ -273,8 +273,9 @@ namespace PluralKit.Bot private DbConnectionFactory _connectionFactory; private IServiceProvider _services; private CommandTree _tree; + private IDataStore _data; - public PKEventHandler(ProxyService proxy, ILogger logger, IMetrics metrics, IDiscordClient client, DbConnectionFactory connectionFactory, IServiceProvider services, CommandTree tree) + public PKEventHandler(ProxyService proxy, ILogger logger, IMetrics metrics, IDiscordClient client, DbConnectionFactory connectionFactory, IServiceProvider services, CommandTree tree, IDataStore data) { _proxy = proxy; _logger = logger; @@ -283,6 +284,7 @@ namespace PluralKit.Bot _connectionFactory = connectionFactory; _services = services; _tree = tree; + _data = data; } public async Task HandleMessage(SocketMessage arg) @@ -298,7 +300,7 @@ namespace PluralKit.Bot // Ignore bot messages if (msg.Author.IsBot || msg.Author.IsWebhook) return; - + int argPos = -1; // Check if message starts with the command prefix if (msg.Content.StartsWith("pk;", StringComparison.InvariantCultureIgnoreCase)) argPos = 3; @@ -307,6 +309,7 @@ namespace PluralKit.Bot if (id != _client.CurrentUser.Id) // But undo it if it's someone else's ping argPos = -1; + // If it does, try executing a command if (argPos > -1) { _logger.Verbose("Parsing command {Command} from message {Channel}-{Message}", msg.Content, msg.Channel.Id, msg.Id); diff --git a/PluralKit.Bot/Commands/CommandTree.cs b/PluralKit.Bot/Commands/CommandTree.cs index 564f0c40..263e99f7 100644 --- a/PluralKit.Bot/Commands/CommandTree.cs +++ b/PluralKit.Bot/Commands/CommandTree.cs @@ -45,7 +45,11 @@ namespace PluralKit.Bot.Commands public static Command Export = new Command("export", "export", "Exports system information to a data file"); public static Command Help = new Command("help", "help", "Shows help information about PluralKit"); public static Command Message = new Command("message", "message ", "Looks up a proxied message"); - public static Command Log = new Command("log", "log ", "Designates a channel to post proxied messages to"); + public static Command LogChannel = new Command("log channel", "log channel ", "Designates a channel to post proxied messages to"); + public static Command LogEnable = new Command("log enable", "log enable all| [channel 2] [channel 3...]", "Enables message logging in certain channels"); + public static Command LogDisable = new Command("log disable", "log disable all| [channel 2] [channel 3...]", "Disables message logging in certain channels"); + public static Command BlacklistAdd = new Command("blacklist add", "blacklist add all| [channel 2] [channel 3...]", "Adds certain channels to the proxy blacklist"); + public static Command BlacklistRemove = new Command("blacklist remove", "blacklist remove all| [channel 2] [channel 3...]", "Removes certain channels from the proxy blacklist"); public static Command Invite = new Command("invite", "invite", "Gets a link to invite PluralKit to other servers"); public static Command PermCheck = new Command("permcheck", "permcheck ", "Checks whether a server's permission setup is correct"); @@ -60,6 +64,8 @@ namespace PluralKit.Bot.Commands }; public static Command[] SwitchCommands = {Switch, SwitchOut, SwitchMove, SwitchDelete}; + + public static Command[] LogCommands = {LogChannel, LogEnable, LogDisable}; private IDiscordClient _client; @@ -100,7 +106,19 @@ namespace PluralKit.Bot.Commands if (ctx.Match("message", "msg")) return ctx.Execute(Message, m => m.GetMessage(ctx)); if (ctx.Match("log")) - return ctx.Execute(Log, m => m.SetLogChannel(ctx)); + if (ctx.Match("channel")) + return ctx.Execute(LogChannel, m => m.SetLogChannel(ctx)); + else if (ctx.Match("enable", "on")) + return ctx.Execute(LogEnable, m => m.SetLogEnabled(ctx, true)); + else if (ctx.Match("disable", "off")) + return ctx.Execute(LogDisable, m => m.SetLogEnabled(ctx, false)); + else return PrintCommandExpectedError(ctx, LogCommands); + if (ctx.Match("blacklist", "bl")) + if (ctx.Match("enable", "on", "add", "deny")) + return ctx.Execute(BlacklistAdd, m => m.SetBlacklisted(ctx, true)); + else if (ctx.Match("disable", "off", "remove", "allow")) + return ctx.Execute(BlacklistRemove, m => m.SetBlacklisted(ctx, false)); + else return PrintCommandExpectedError(ctx, BlacklistAdd, BlacklistRemove); if (ctx.Match("invite")) return ctx.Execute(Invite, m => m.Invite(ctx)); if (ctx.Match("mn")) return ctx.Execute(null, m => m.Mn(ctx)); if (ctx.Match("fire")) return ctx.Execute(null, m => m.Fire(ctx)); diff --git a/PluralKit.Bot/Commands/ModCommands.cs b/PluralKit.Bot/Commands/ModCommands.cs index 996886da..5426144c 100644 --- a/PluralKit.Bot/Commands/ModCommands.cs +++ b/PluralKit.Bot/Commands/ModCommands.cs @@ -1,3 +1,5 @@ +using System.Collections.Generic; +using System.Linq; using System.Text.RegularExpressions; using System.Threading.Tasks; using Discord; @@ -27,8 +29,9 @@ namespace PluralKit.Bot.Commands ITextChannel channel = null; if (ctx.HasNext()) channel = ctx.MatchChannel() ?? throw new PKSyntaxError("You must pass a #channel to set."); + if (channel != null && channel.GuildId != ctx.Guild.Id) throw new PKError("That channel is not in this server!"); - var cfg = await _data.GetGuildConfig(ctx.Guild.Id); + var cfg = await _data.GetOrCreateGuildConfig(ctx.Guild.Id); cfg.LogChannel = channel?.Id; await _data.SaveGuildConfig(cfg); @@ -37,6 +40,57 @@ namespace PluralKit.Bot.Commands else await ctx.Reply($"{Emojis.Success} Proxy logging channel cleared."); } + + public async Task SetLogEnabled(Context ctx, bool enable) + { + ctx.CheckGuildContext().CheckAuthorPermission(GuildPermission.ManageGuild, "Manage Server"); + + var affectedChannels = new List(); + if (ctx.Match("all")) + affectedChannels = (await ctx.Guild.GetChannelsAsync()).OfType().ToList(); + else if (!ctx.HasNext()) throw new PKSyntaxError("You must pass one or more #channels."); + else while (ctx.HasNext()) + { + if (!(ctx.MatchChannel() is ITextChannel channel)) + throw new PKSyntaxError($"Channel \"{ctx.PopArgument().SanitizeMentions()}\" not found."); + if (channel.GuildId != ctx.Guild.Id) throw new PKError($"Channel {ctx.Guild.Id} is not in this server."); + affectedChannels.Add(channel); + } + + var guildCfg = await _data.GetOrCreateGuildConfig(ctx.Guild.Id); + if (enable) guildCfg.LogBlacklist.ExceptWith(affectedChannels.Select(c => c.Id)); + else guildCfg.LogBlacklist.UnionWith(affectedChannels.Select(c => c.Id)); + + await _data.SaveGuildConfig(guildCfg); + await ctx.Reply( + $"{Emojis.Success} Message logging for the given channels {(enable ? "enabled" : "disabled")}." + + (guildCfg.LogChannel == null ? $"\n{Emojis.Warn} Please note that no logging channel is set, so there is nowhere to log messages to. You can set a logging channel using `pk;log channel #your-log-channel`." : "")); + } + + public async Task SetBlacklisted(Context ctx, bool onBlacklist) + { + ctx.CheckGuildContext().CheckAuthorPermission(GuildPermission.ManageGuild, "Manage Server"); + + var affectedChannels = new List(); + if (ctx.Match("all")) + affectedChannels = (await ctx.Guild.GetChannelsAsync()).OfType().ToList(); + else if (!ctx.HasNext()) throw new PKSyntaxError("You must pass one or more #channels."); + else while (ctx.HasNext()) + { + if (!(ctx.MatchChannel() is ITextChannel channel)) + throw new PKSyntaxError($"Channel \"{ctx.PopArgument().SanitizeMentions()}\" not found."); + if (channel.GuildId != ctx.Guild.Id) throw new PKError($"Channel {ctx.Guild.Id} is not in this server."); + affectedChannels.Add(channel); + } + + var guildCfg = await _data.GetOrCreateGuildConfig(ctx.Guild.Id); + if (onBlacklist) guildCfg.Blacklist.UnionWith(affectedChannels.Select(c => c.Id)); + else guildCfg.Blacklist.ExceptWith(affectedChannels.Select(c => c.Id)); + + await _data.SaveGuildConfig(guildCfg); + await ctx.Reply($"{Emojis.Success} Channels {(onBlacklist ? "added to" : "removed from")} the proxy blacklist."); + + } public async Task GetMessage(Context ctx) { diff --git a/PluralKit.Bot/Services/LogChannelService.cs b/PluralKit.Bot/Services/LogChannelService.cs index 9faf81f6..b3f51533 100644 --- a/PluralKit.Bot/Services/LogChannelService.cs +++ b/PluralKit.Bot/Services/LogChannelService.cs @@ -18,21 +18,21 @@ namespace PluralKit.Bot { _logger = logger.ForContext(); } - public async Task LogMessage(PKSystem system, PKMember member, ulong messageId, ulong originalMsgId, IGuildChannel originalChannel, IUser sender, string content) { - var logChannel = await GetLogChannel(originalChannel.Guild); - if (logChannel == null) return; + public async Task LogMessage(PKSystem system, PKMember member, ulong messageId, ulong originalMsgId, IGuildChannel originalChannel, IUser sender, string content) + { + var guildCfg = await _data.GetOrCreateGuildConfig(originalChannel.GuildId); + + // Bail if logging is disabled either globally or for this channel + if (guildCfg.LogChannel == null) return; + if (guildCfg.LogBlacklist.Contains(originalChannel.Id)) return; + + // Bail if we can't find the channel + if (!(await _client.GetChannelAsync(guildCfg.LogChannel.Value) is ITextChannel logChannel)) return; var embed = _embed.CreateLoggedMessageEmbed(system, member, messageId, originalMsgId, sender, content, originalChannel); var url = $"https://discordapp.com/channels/{originalChannel.GuildId}/{originalChannel.Id}/{messageId}"; await logChannel.SendMessageAsync(text: url, embed: embed); } - - private async Task GetLogChannel(IGuild guild) - { - var guildCfg = await _data.GetGuildConfig(guild.Id); - if (guildCfg.LogChannel == null) return null; - return await _client.GetChannelAsync(guildCfg.LogChannel.Value) as ITextChannel; - } } } \ No newline at end of file diff --git a/PluralKit.Bot/Services/ProxyService.cs b/PluralKit.Bot/Services/ProxyService.cs index b5e085ad..c53b0f1f 100644 --- a/PluralKit.Bot/Services/ProxyService.cs +++ b/PluralKit.Bot/Services/ProxyService.cs @@ -80,17 +80,20 @@ namespace PluralKit.Bot public async Task HandleMessageAsync(IMessage message) { // Bail early if this isn't in a guild channel - if (!(message.Channel is ITextChannel)) return; - - var results = await _cache.GetResultsFor(message.Author.Id); - + if (!(message.Channel is ITextChannel channel)) return; + // Find a member with proxy tags matching the message + var results = await _cache.GetResultsFor(message.Author.Id); var match = GetProxyTagMatch(message.Content, results); if (match == null) return; + + // And make sure the channel's not blacklisted from proxying. + var guildCfg = await _data.GetOrCreateGuildConfig(channel.GuildId); + if (guildCfg.Blacklist.Contains(channel.Id)) return; // We know message.Channel can only be ITextChannel as PK doesn't work in DMs/groups // Afterwards we ensure the bot has the right permissions, otherwise bail early - if (!await EnsureBotPermissions(message.Channel as ITextChannel)) return; + if (!await EnsureBotPermissions(channel)) return; // Can't proxy a message with no content and no attachment if (match.InnerText.Trim().Length == 0 && message.Attachments.Count == 0) @@ -114,7 +117,7 @@ namespace PluralKit.Bot // Execute the webhook itself var hookMessageId = await _webhookExecutor.ExecuteWebhook( - (ITextChannel) message.Channel, + channel, proxyName, avatarUrl, messageContents, message.Attachments.FirstOrDefault() diff --git a/PluralKit.Core/Stores.cs b/PluralKit.Core/Stores.cs index ebfa79c8..3a9c93d3 100644 --- a/PluralKit.Core/Stores.cs +++ b/PluralKit.Core/Stores.cs @@ -62,6 +62,15 @@ namespace PluralKit { { public ulong Id { get; set; } public ulong? LogChannel { get; set; } + public ISet LogBlacklist { get; set; } + public ISet Blacklist { get; set; } + } + + public struct ChannelConfig + { + public ulong Id { get; set; } + public bool OnList { get; set; } + public bool LogMessages { get; set; } } public interface IDataStore @@ -329,10 +338,10 @@ namespace PluralKit { Task GetTotalMessages(); /// - /// Gets the guild configuration struct for a given guild. + /// Gets the guild configuration struct for a given guild, creating and saving one if none was found. /// - /// The guild's configuration struct, or a default struct if no guild was found in the data store. - Task GetGuildConfig(ulong guild); + /// The guild's configuration struct. + Task GetOrCreateGuildConfig(ulong guild); /// /// Saves the given guild configuration struct to the data store. @@ -596,28 +605,45 @@ namespace PluralKit { return await conn.ExecuteScalarAsync("select count(mid) from messages"); } - public async Task GetGuildConfig(ulong guild) + // Same as GuildConfig, but with ISet as long[] instead. + private struct DatabaseCompatibleGuildConfig + { + public ulong Id { get; set; } + public ulong? LogChannel { get; set; } + public long[] LogBlacklist { get; set; } + public long[] Blacklist { get; set; } + } + + public async Task GetOrCreateGuildConfig(ulong guild) { using (var conn = await _conn.Obtain()) { - var cfg = await conn.QuerySingleOrDefaultAsync("select * from servers where id = @Id", + var compat = await conn.QuerySingleOrDefaultAsync( + "insert into servers (id) values (@Id) on conflict do nothing; select * from servers where id = @Id", new {Id = guild}); - - if (cfg.Id == 0) - // No entry was found in the db, this is the default entry returned - cfg.Id = guild; - - return cfg; + return new GuildConfig + { + Id = compat.Id, + LogChannel = compat.LogChannel, + LogBlacklist = new HashSet(compat.LogBlacklist.Select(c => (ulong) c)), + Blacklist = new HashSet(compat.Blacklist.Select(c => (ulong) c)), + }; } } public async Task SaveGuildConfig(GuildConfig cfg) { using (var conn = await _conn.Obtain()) - await conn.ExecuteAsync("insert into servers (id, log_channel) values (@Id, @LogChannel) on conflict (id) do update set log_channel = @LogChannel", cfg); + await conn.ExecuteAsync("insert into servers (id, log_channel, log_blacklist, blacklist) values (@Id, @LogChannel, @LogBlacklist, @Blacklist) on conflict (id) do update set log_channel = @LogChannel, log_blacklist = @LogBlacklist, blacklist = @Blacklist", new + { + cfg.Id, + cfg.LogChannel, + LogBlacklist = cfg.LogBlacklist.Select(c => (long) c).ToList(), + Blacklist = cfg.Blacklist.Select(c => (long) c).ToList() + }); _logger.Information("Updated guild configuration {@GuildCfg}", cfg); } - + public async Task AddSwitch(PKSystem system, IEnumerable members) { // Use a transaction here since we're doing multiple executed commands in one diff --git a/PluralKit.Core/db_schema.sql b/PluralKit.Core/db_schema.sql index fd9706b8..d401a6d4 100644 --- a/PluralKit.Core/db_schema.sql +++ b/PluralKit.Core/db_schema.sql @@ -83,6 +83,8 @@ create table if not exists webhooks create table if not exists servers ( - id bigint primary key, - log_channel bigint + id bigint primary key, + log_channel bigint, + log_blacklist bigint[] not null default array[]::bigint[], + blacklist bigint[] not null default array[]::bigint[] ); \ No newline at end of file