Port some things, still does not compile

This commit is contained in:
Ske 2020-04-17 23:10:01 +02:00
parent f56c3e819f
commit 23cf06df4c
18 changed files with 543 additions and 538 deletions

View File

@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Data;
using System.Linq; using System.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -7,8 +8,10 @@ using App.Metrics;
using Autofac; using Autofac;
using Discord; using DSharpPlus;
using Discord.WebSocket; using DSharpPlus.Entities;
using DSharpPlus.EventArgs;
using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Configuration;
using PluralKit.Core; using PluralKit.Core;
@ -61,7 +64,6 @@ namespace PluralKit.Bot
SchemaService.Initialize(); SchemaService.Initialize();
var coreConfig = services.Resolve<CoreConfig>(); var coreConfig = services.Resolve<CoreConfig>();
var botConfig = services.Resolve<BotConfig>();
var schema = services.Resolve<SchemaService>(); var schema = services.Resolve<SchemaService>();
using var _ = Sentry.SentrySdk.Init(coreConfig.SentryUrl); using var _ = Sentry.SentrySdk.Init(coreConfig.SentryUrl);
@ -71,10 +73,9 @@ namespace PluralKit.Bot
logger.Information("Connecting to Discord"); logger.Information("Connecting to Discord");
var client = services.Resolve<DiscordShardedClient>(); var client = services.Resolve<DiscordShardedClient>();
await client.LoginAsync(TokenType.Bot, botConfig.Token);
logger.Information("Initializing bot");
await client.StartAsync(); await client.StartAsync();
logger.Information("Initializing bot");
await services.Resolve<Bot>().Init(); await services.Resolve<Bot>().Init();
try try
@ -105,10 +106,10 @@ namespace PluralKit.Bot
private WebhookRateLimitService _webhookRateLimit; private WebhookRateLimitService _webhookRateLimit;
private int _periodicUpdateCount; private int _periodicUpdateCount;
public Bot(ILifetimeScope services, IDiscordClient client, IMetrics metrics, PeriodicStatCollector collector, ILogger logger, WebhookRateLimitService webhookRateLimit) public Bot(ILifetimeScope services, DiscordShardedClient client, IMetrics metrics, PeriodicStatCollector collector, ILogger logger, WebhookRateLimitService webhookRateLimit)
{ {
_services = services; _services = services;
_client = client as DiscordShardedClient; _client = client;
_metrics = metrics; _metrics = metrics;
_collector = collector; _collector = collector;
_webhookRateLimit = webhookRateLimit; _webhookRateLimit = webhookRateLimit;
@ -117,53 +118,51 @@ namespace PluralKit.Bot
public Task Init() public Task Init()
{ {
_client.ShardDisconnected += ShardDisconnected; // _client.ShardDisconnected += ShardDisconnected;
_client.ShardReady += ShardReady; // _client.ShardReady += ShardReady;
_client.Log += FrameworkLog; _client.DebugLogger.LogMessageReceived += FrameworkLog;
_client.MessageReceived += (msg) => HandleEvent(eh => eh.HandleMessage(msg)); _client.MessageCreated += args => HandleEvent(eh => eh.HandleMessage(args));
_client.ReactionAdded += (msg, channel, reaction) => HandleEvent(eh => eh.HandleReactionAdded(msg, channel, reaction)); _client.MessageReactionAdded += args => HandleEvent(eh => eh.HandleReactionAdded(args));
_client.MessageDeleted += (msg, channel) => HandleEvent(eh => eh.HandleMessageDeleted(msg, channel)); _client.MessageDeleted += args => HandleEvent(eh => eh.HandleMessageDeleted(args));
_client.MessagesBulkDeleted += (msgs, channel) => HandleEvent(eh => eh.HandleMessagesBulkDelete(msgs, channel)); _client.MessagesBulkDeleted += args => HandleEvent(eh => eh.HandleMessagesBulkDelete(args));
_client.MessageUpdated += (oldMessage, newMessage, channel) => HandleEvent(eh => eh.HandleMessageEdited(oldMessage, newMessage, channel)); _client.MessageUpdated += args => HandleEvent(eh => eh.HandleMessageEdited(args));
_services.Resolve<ShardInfoService>().Init(_client); _services.Resolve<ShardInfoService>().Init(_client);
return Task.CompletedTask; return Task.CompletedTask;
} }
private Task ShardDisconnected(Exception ex, DiscordSocketClient shard) /*private Task ShardDisconnected(Exception ex, DiscordSocketClient shard)
{ {
_logger.Warning(ex, $"Shard #{shard.ShardId} disconnected"); _logger.Warning(ex, $"Shard #{shard.ShardId} disconnected");
return Task.CompletedTask; return Task.CompletedTask;
} }*/
private Task FrameworkLog(LogMessage msg) private void FrameworkLog(object sender, DebugLogMessageEventArgs args)
{ {
// Bridge D.NET logging to Serilog // Bridge D#+ logging to Serilog
LogEventLevel level = LogEventLevel.Verbose; LogEventLevel level = LogEventLevel.Verbose;
if (msg.Severity == LogSeverity.Critical) if (args.Level == LogLevel.Critical)
level = LogEventLevel.Fatal; level = LogEventLevel.Fatal;
else if (msg.Severity == LogSeverity.Debug) else if (args.Level == LogLevel.Debug)
level = LogEventLevel.Debug; level = LogEventLevel.Debug;
else if (msg.Severity == LogSeverity.Error) else if (args.Level == LogLevel.Error)
level = LogEventLevel.Error; level = LogEventLevel.Error;
else if (msg.Severity == LogSeverity.Info) else if (args.Level == LogLevel.Info)
level = LogEventLevel.Information; level = LogEventLevel.Information;
else if (msg.Severity == LogSeverity.Debug) // D.NET's lowest level is Debug and Verbose is greater, Serilog's is the other way around else if (args.Level == LogLevel.Warning)
level = LogEventLevel.Verbose; level = LogEventLevel.Warning;
else if (msg.Severity == LogSeverity.Verbose)
level = LogEventLevel.Debug;
_logger.Write(level, msg.Exception, "Discord.Net {Source}: {Message}", msg.Source, msg.Message); _logger.Write(level, args.Exception, "D#+ {Source}: {Message}", args.Application, args.Message);
return Task.CompletedTask;
} }
// Method called every 60 seconds // Method called every 60 seconds
private async Task UpdatePeriodic() private async Task UpdatePeriodic()
{ {
// Change bot status // Change bot status
await _client.SetGameAsync($"pk;help | in {_client.Guilds.Count} servers"); var totalGuilds = _client.ShardClients.Values.Sum(c => c.Guilds.Count);
await _client.UpdateStatusAsync(new DiscordActivity($"pk;help | in {totalGuilds} servers"));
// Run webhook rate limit GC every 10 minutes // Run webhook rate limit GC every 10 minutes
if (_periodicUpdateCount++ % 10 == 0) if (_periodicUpdateCount++ % 10 == 0)
@ -177,7 +176,7 @@ namespace PluralKit.Bot
await Task.WhenAll(((IMetricsRoot) _metrics).ReportRunner.RunAllAsync()); await Task.WhenAll(((IMetricsRoot) _metrics).ReportRunner.RunAllAsync());
} }
private Task ShardReady(DiscordSocketClient shardClient) /*private Task ShardReady(DiscordSocketClient shardClient)
{ {
_logger.Information("Shard {Shard} connected to {ChannelCount} channels in {GuildCount} guilds", shardClient.ShardId, shardClient.Guilds.Sum(g => g.Channels.Count), shardClient.Guilds.Count); _logger.Information("Shard {Shard} connected to {ChannelCount} channels in {GuildCount} guilds", shardClient.ShardId, shardClient.Guilds.Sum(g => g.Channels.Count), shardClient.Guilds.Count);
@ -191,7 +190,7 @@ namespace PluralKit.Bot
} }
return Task.CompletedTask; return Task.CompletedTask;
} }*/
private Task HandleEvent(Func<PKEventHandler, Task> handler) private Task HandleEvent(Func<PKEventHandler, Task> handler)
{ {
@ -252,7 +251,7 @@ namespace PluralKit.Bot
// This means that the HandleMessage function will either be called once, or not at all // This means that the HandleMessage function will either be called once, or not at all
// The ReportError function will be called on an error, and needs to refer back to the "trigger message" // The ReportError function will be called on an error, and needs to refer back to the "trigger message"
// hence, we just store it in a local variable, ignoring it entirely if it's null. // hence, we just store it in a local variable, ignoring it entirely if it's null.
private IUserMessage _msg = null; private DiscordMessage _currentlyHandlingMessage = null;
public PKEventHandler(ProxyService proxy, ILogger logger, IMetrics metrics, DiscordShardedClient client, DbConnectionFactory connectionFactory, ILifetimeScope services, CommandTree tree, Scope sentryScope, ProxyCache cache, LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean) public PKEventHandler(ProxyService proxy, ILogger logger, IMetrics metrics, DiscordShardedClient client, DbConnectionFactory connectionFactory, ILifetimeScope services, CommandTree tree, Scope sentryScope, ProxyCache cache, LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean)
{ {
@ -269,42 +268,44 @@ namespace PluralKit.Bot
_loggerClean = loggerClean; _loggerClean = loggerClean;
} }
public async Task HandleMessage(SocketMessage arg) public async Task HandleMessage(MessageCreateEventArgs args)
{ {
var shard = _client.GetShardFor((arg.Channel as IGuildChannel)?.Guild); // TODO
/*var shard = _client.GetShardFor((arg.Channel as IGuildChannel)?.Guild);
if (shard.ConnectionState != ConnectionState.Connected || _client.CurrentUser == null) if (shard.ConnectionState != ConnectionState.Connected || _client.CurrentUser == null)
return; // Discard messages while the bot "catches up" to avoid unnecessary CPU pressure causing timeouts return; // Discard messages while the bot "catches up" to avoid unnecessary CPU pressure causing timeouts*/
RegisterMessageMetrics(arg); RegisterMessageMetrics(args);
// Ignore system messages (member joined, message pinned, etc) // Ignore system messages (member joined, message pinned, etc)
var msg = arg as SocketUserMessage; var msg = args.Message;
if (msg == null) return; if (msg.MessageType != MessageType.Default) return;
// Fetch information about the guild early, as we need it for the logger cleanup // Fetch information about the guild early, as we need it for the logger cleanup
GuildConfig cachedGuild = default; // todo: is this default correct? GuildConfig cachedGuild = default;
if (msg.Channel is ITextChannel textChannel) cachedGuild = await _cache.GetGuildDataCached(textChannel.GuildId); if (msg.Channel.Type == ChannelType.Text) await _cache.GetGuildDataCached(msg.Channel.GuildId);
// Pass guild bot/WH messages onto the logger cleanup service, but otherwise ignore // Pass guild bot/WH messages onto the logger cleanup service, but otherwise ignore
if ((msg.Author.IsBot || msg.Author.IsWebhook) && msg.Channel is ITextChannel) if (msg.Author.IsBot && msg.Channel.Type == ChannelType.Text)
{ {
await _loggerClean.HandleLoggerBotCleanup(arg, cachedGuild); await _loggerClean.HandleLoggerBotCleanup(msg, cachedGuild);
return; return;
} }
_currentlyHandlingMessage = msg;
// Add message info as Sentry breadcrumb // Add message info as Sentry breadcrumb
_msg = msg;
_sentryScope.AddBreadcrumb(msg.Content, "event.message", data: new Dictionary<string, string> _sentryScope.AddBreadcrumb(msg.Content, "event.message", data: new Dictionary<string, string>
{ {
{"user", msg.Author.Id.ToString()}, {"user", msg.Author.Id.ToString()},
{"channel", msg.Channel.Id.ToString()}, {"channel", msg.Channel.Id.ToString()},
{"guild", ((msg.Channel as IGuildChannel)?.GuildId ?? 0).ToString()}, {"guild", msg.Channel.GuildId.ToString()},
{"message", msg.Id.ToString()}, {"message", msg.Id.ToString()},
}); });
_sentryScope.SetTag("shard", shard.ShardId.ToString()); _sentryScope.SetTag("shard", args.Client.ShardId.ToString());
// Add to last message cache // Add to last message cache
_lastMessageCache.AddMessage(arg.Channel.Id, arg.Id); _lastMessageCache.AddMessage(msg.Channel.Id, msg.Id);
// We fetch information about the sending account from the cache // We fetch information about the sending account from the cache
var cachedAccount = await _cache.GetAccountDataCached(msg.Author.Id); var cachedAccount = await _cache.GetAccountDataCached(msg.Author.Id);
@ -330,7 +331,7 @@ namespace PluralKit.Bot
try try
{ {
await _tree.ExecuteCommand(new Context(_services, msg, argPos, cachedAccount?.System)); await _tree.ExecuteCommand(new Context(_services, args.Client, msg, argPos, cachedAccount?.System));
} }
catch (PKError) catch (PKError)
{ {
@ -345,12 +346,12 @@ namespace PluralKit.Bot
// no data = no account = no system = no proxy! // no data = no account = no system = no proxy!
try try
{ {
await _proxy.HandleMessageAsync(cachedGuild, cachedAccount, msg, doAutoProxy: true); await _proxy.HandleMessageAsync(args.Client, cachedGuild, cachedAccount, msg, doAutoProxy: true);
} }
catch (PKError e) catch (PKError e)
{ {
if (arg.Channel.HasPermission(ChannelPermission.SendMessages)) if (msg.Channel.Guild == null || msg.Channel.BotHasPermission(Permissions.SendMessages))
await arg.Channel.SendMessageAsync($"{Emojis.Error} {e.Message}"); await msg.Channel.SendMessageAsync($"{Emojis.Error} {e.Message}");
} }
} }
} }
@ -358,98 +359,95 @@ namespace PluralKit.Bot
public async Task ReportError(SentryEvent evt, Exception exc) public async Task ReportError(SentryEvent evt, Exception exc)
{ {
// If we don't have a "trigger message", bail // If we don't have a "trigger message", bail
if (_msg == null) return; if (_currentlyHandlingMessage == null) return;
// This function *specifically* handles reporting a command execution error to the user. // This function *specifically* handles reporting a command execution error to the user.
// We'll fetch the event ID and send a user-facing error message. // We'll fetch the event ID and send a user-facing error message.
// ONLY IF this error's actually our problem. As for what defines an error as "our problem", // ONLY IF this error's actually our problem. As for what defines an error as "our problem",
// check the extension method :) // check the extension method :)
if (exc.IsOurProblem() && _msg.Channel.HasPermission(ChannelPermission.SendMessages)) if (exc.IsOurProblem() && _currentlyHandlingMessage.Channel.BotHasPermission(Permissions.SendMessages))
{ {
var eid = evt.EventId; var eid = evt.EventId;
await _msg.Channel.SendMessageAsync( await _currentlyHandlingMessage.Channel.SendMessageAsync(
$"{Emojis.Error} Internal error occurred. Please join the support server (<https://discord.gg/PczBt78>), and send the developer this ID: `{eid}`\nBe sure to include a description of what you were doing to make the error occur."); $"{Emojis.Error} Internal error occurred. Please join the support server (<https://discord.gg/PczBt78>), and send the developer this ID: `{eid}`\nBe sure to include a description of what you were doing to make the error occur.");
} }
// If not, don't care. lol. // If not, don't care. lol.
} }
private void RegisterMessageMetrics(SocketMessage msg) private void RegisterMessageMetrics(MessageCreateEventArgs msg)
{ {
_metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived); _metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived);
var gatewayLatency = DateTimeOffset.Now - msg.CreatedAt; var gatewayLatency = DateTimeOffset.Now - msg.Message.Timestamp;
_logger.Verbose("Message received with latency {Latency}", gatewayLatency); _logger.Verbose("Message received with latency {Latency}", gatewayLatency);
} }
public Task HandleReactionAdded(Cacheable<IUserMessage, ulong> message, ISocketMessageChannel channel, public Task HandleReactionAdded(MessageReactionAddEventArgs args)
SocketReaction reaction)
{ {
_sentryScope.AddBreadcrumb("", "event.reaction", data: new Dictionary<string, string>() _sentryScope.AddBreadcrumb("", "event.reaction", data: new Dictionary<string, string>()
{ {
{"user", reaction.UserId.ToString()}, {"user", args.User.Id.ToString()},
{"channel", channel.Id.ToString()}, {"channel", (args.Channel?.Id ?? 0).ToString()},
{"guild", ((channel as IGuildChannel)?.GuildId ?? 0).ToString()}, {"guild", (args.Channel?.GuildId ?? 0).ToString()},
{"message", message.Id.ToString()}, {"message", args.Message.Id.ToString()},
{"reaction", reaction.Emote.Name} {"reaction", args.Emoji.Name}
}); });
_sentryScope.SetTag("shard", _client.GetShardIdFor((channel as IGuildChannel)?.Guild).ToString()); _sentryScope.SetTag("shard", args.Client.ShardId.ToString());
return _proxy.HandleReactionAddedAsync(args);
return _proxy.HandleReactionAddedAsync(message, channel, reaction);
} }
public Task HandleMessageDeleted(Cacheable<IMessage, ulong> message, ISocketMessageChannel channel) public Task HandleMessageDeleted(MessageDeleteEventArgs args)
{ {
_sentryScope.AddBreadcrumb("", "event.messageDelete", data: new Dictionary<string, string>() _sentryScope.AddBreadcrumb("", "event.messageDelete", data: new Dictionary<string, string>()
{ {
{"channel", channel.Id.ToString()}, {"channel", args.Channel.Id.ToString()},
{"guild", ((channel as IGuildChannel)?.GuildId ?? 0).ToString()}, {"guild", args.Channel.GuildId.ToString()},
{"message", message.Id.ToString()}, {"message", args.Message.Id.ToString()},
}); });
_sentryScope.SetTag("shard", _client.GetShardIdFor((channel as IGuildChannel)?.Guild).ToString()); _sentryScope.SetTag("shard", args.Client.ShardId.ToString());
return _proxy.HandleMessageDeletedAsync(message, channel); return _proxy.HandleMessageDeletedAsync(args);
} }
public Task HandleMessagesBulkDelete(IReadOnlyCollection<Cacheable<IMessage, ulong>> messages, public Task HandleMessagesBulkDelete(MessageBulkDeleteEventArgs args)
IMessageChannel channel)
{ {
_sentryScope.AddBreadcrumb("", "event.messageDelete", data: new Dictionary<string, string>() _sentryScope.AddBreadcrumb("", "event.messageDelete", data: new Dictionary<string, string>()
{ {
{"channel", channel.Id.ToString()}, {"channel", args.Channel.Id.ToString()},
{"guild", ((channel as IGuildChannel)?.GuildId ?? 0).ToString()}, {"guild", args.Channel.Id.ToString()},
{"messages", string.Join(",", messages.Select(m => m.Id))}, {"messages", string.Join(",", args.Messages.Select(m => m.Id))},
}); });
_sentryScope.SetTag("shard", _client.GetShardIdFor((channel as IGuildChannel)?.Guild).ToString()); _sentryScope.SetTag("shard", args.Client.ShardId.ToString());
return _proxy.HandleMessageBulkDeleteAsync(messages, channel); return _proxy.HandleMessageBulkDeleteAsync(args);
} }
public async Task HandleMessageEdited(Cacheable<IMessage, ulong> oldMessage, SocketMessage newMessage, ISocketMessageChannel channel) public async Task HandleMessageEdited(MessageUpdateEventArgs args)
{ {
_sentryScope.AddBreadcrumb(newMessage.Content, "event.messageEdit", data: new Dictionary<string, string>() _sentryScope.AddBreadcrumb(args.Message.Content ?? "<unknown>", "event.messageEdit", data: new Dictionary<string, string>()
{ {
{"channel", channel.Id.ToString()}, {"channel", args.Channel.Id.ToString()},
{"guild", ((channel as IGuildChannel)?.GuildId ?? 0).ToString()}, {"guild", args.Channel.GuildId.ToString()},
{"message", newMessage.Id.ToString()} {"message", args.Message.Id.ToString()}
}); });
_sentryScope.SetTag("shard", _client.GetShardIdFor((channel as IGuildChannel)?.Guild).ToString()); _sentryScope.SetTag("shard", args.Client.ShardId.ToString());
// If this isn't a guild, bail // If this isn't a guild, bail
if (!(channel is IGuildChannel gc)) return; if (args.Channel.Guild == null) return;
// If this isn't the last message in the channel, don't do anything // If this isn't the last message in the channel, don't do anything
if (_lastMessageCache.GetLastMessage(channel.Id) != newMessage.Id) return; if (_lastMessageCache.GetLastMessage(args.Channel.Id) != args.Message.Id) return;
// Fetch account from cache if there is any // Fetch account from cache if there is any
var account = await _cache.GetAccountDataCached(newMessage.Author.Id); var account = await _cache.GetAccountDataCached(args.Author.Id);
if (account == null) return; // Again: no cache = no account = no system = no proxy if (account == null) return; // Again: no cache = no account = no system = no proxy
// Also fetch guild cache // Also fetch guild cache
var guild = await _cache.GetGuildDataCached(gc.GuildId); var guild = await _cache.GetGuildDataCached(args.Channel.GuildId);
// Just run the normal message handling stuff // Just run the normal message handling stuff
await _proxy.HandleMessageAsync(guild, account, newMessage, doAutoProxy: false); await _proxy.HandleMessageAsync(args.Client, guild, account, args.Message, doAutoProxy: false);
} }
} }
} }

View File

@ -9,6 +9,9 @@ using Autofac;
using Discord; using Discord;
using Discord.WebSocket; using Discord.WebSocket;
using DSharpPlus;
using DSharpPlus.Entities;
using PluralKit.Core; using PluralKit.Core;
namespace PluralKit.Bot namespace PluralKit.Bot
@ -18,7 +21,8 @@ namespace PluralKit.Bot
private ILifetimeScope _provider; private ILifetimeScope _provider;
private readonly DiscordShardedClient _client; private readonly DiscordShardedClient _client;
private readonly SocketUserMessage _message; private readonly DiscordClient _shard;
private readonly DiscordMessage _message;
private readonly Parameters _parameters; private readonly Parameters _parameters;
private readonly IDataStore _data; private readonly IDataStore _data;
@ -27,11 +31,12 @@ namespace PluralKit.Bot
private Command _currentCommand; private Command _currentCommand;
public Context(ILifetimeScope provider, SocketUserMessage message, int commandParseOffset, public Context(ILifetimeScope provider, DiscordClient shard, DiscordMessage message, int commandParseOffset,
PKSystem senderSystem) PKSystem senderSystem)
{ {
_client = provider.Resolve<DiscordShardedClient>(); _client = provider.Resolve<DiscordShardedClient>();
_message = message; _message = message;
_shard = shard;
_data = provider.Resolve<IDataStore>(); _data = provider.Resolve<IDataStore>();
_senderSystem = senderSystem; _senderSystem = senderSystem;
_metrics = provider.Resolve<IMetrics>(); _metrics = provider.Resolve<IMetrics>();
@ -39,11 +44,11 @@ namespace PluralKit.Bot
_parameters = new Parameters(message.Content.Substring(commandParseOffset)); _parameters = new Parameters(message.Content.Substring(commandParseOffset));
} }
public IUser Author => _message.Author; public DiscordUser Author => _message.Author;
public IMessageChannel Channel => _message.Channel; public DiscordChannel Channel => _message.Channel;
public IUserMessage Message => _message; public DiscordMessage Message => _message;
public IGuild Guild => (_message.Channel as ITextChannel)?.Guild; public DiscordGuild Guild => _message.Channel.Guild;
public DiscordSocketClient Shard => _client.GetShardFor(Guild); public DiscordClient Shard => _shard;
public DiscordShardedClient Client => _client; public DiscordShardedClient Client => _client;
public PKSystem System => _senderSystem; public PKSystem System => _senderSystem;
@ -53,13 +58,13 @@ namespace PluralKit.Bot
public bool HasNext(bool skipFlags = true) => RemainderOrNull(skipFlags) != null; public bool HasNext(bool skipFlags = true) => RemainderOrNull(skipFlags) != null;
public string FullCommand => _parameters.FullCommand; public string FullCommand => _parameters.FullCommand;
public Task<IUserMessage> Reply(string text = null, Embed embed = null) public Task<DiscordMessage> Reply(string text = null, DiscordEmbed embed = null)
{ {
if (!this.BotHasPermission(ChannelPermission.SendMessages)) if (!this.BotHasPermission(Permissions.SendMessages))
// Will be "swallowed" during the error handler anyway, this message is never shown. // Will be "swallowed" during the error handler anyway, this message is never shown.
throw new PKError("PluralKit does not have permission to send messages in this channel."); throw new PKError("PluralKit does not have permission to send messages in this channel.");
if (embed != null && !this.BotHasPermission(ChannelPermission.EmbedLinks)) if (embed != null && !this.BotHasPermission(Permissions.EmbedLinks))
throw new PKError("PluralKit does not have permission to send embeds in this channel. Please ensure I have the **Embed Links** permission enabled."); throw new PKError("PluralKit does not have permission to send embeds in this channel. Please ensure I have the **Embed Links** permission enabled.");
return Channel.SendMessageAsync(text, embed: embed); return Channel.SendMessageAsync(text, embed: embed);
@ -125,11 +130,11 @@ namespace PluralKit.Bot
} }
} }
public async Task<IUser> MatchUser() public async Task<DiscordUser> MatchUser()
{ {
var text = PeekArgument(); var text = PeekArgument();
if (MentionUtils.TryParseUser(text, out var id)) if (text.TryParseMention(out var id))
return await Shard.Rest.GetUserAsync(id); // TODO: this should properly fetch return await Shard.GetUserAsync(id);
return null; return null;
} }
@ -138,11 +143,9 @@ namespace PluralKit.Bot
id = 0; id = 0;
var text = PeekArgument(); var text = PeekArgument();
if (MentionUtils.TryParseUser(text, out var mentionId)) if (text.TryParseMention(out var mentionId))
id = mentionId; id = mentionId;
else if (ulong.TryParse(text, out var rawId))
id = rawId;
return id != 0; return id != 0;
} }
@ -246,41 +249,19 @@ namespace PluralKit.Bot
return this; return this;
} }
public GuildPermissions GetGuildPermissions(IUser user) public Context CheckAuthorPermission(Permissions neededPerms, string permissionName)
{ {
if (user is IGuildUser gu) // TODO: can we always assume Author is a DiscordMember? I would think so, given they always come from a
return gu.GuildPermissions; // message received event...
if (Channel is SocketGuildChannel gc) var hasPerms = Channel.PermissionsInSync(Author);
return gc.GetUser(user.Id).GuildPermissions; if ((hasPerms & neededPerms) != neededPerms)
return GuildPermissions.None;
}
public ChannelPermissions GetChannelPermissions(IUser user)
{
if (user is IGuildUser gu && Channel is IGuildChannel igc)
return gu.GetPermissions(igc);
if (Channel is SocketGuildChannel gc)
return gc.GetUser(user.Id).GetPermissions(gc);
return ChannelPermissions.DM;
}
public Context CheckAuthorPermission(GuildPermission permission, string permissionName)
{
if (!GetGuildPermissions(Author).Has(permission))
throw new PKError($"You must have the \"{permissionName}\" permission in this server to use this command.");
return this;
}
public Context CheckAuthorPermission(ChannelPermission permission, string permissionName)
{
if (!GetChannelPermissions(Author).Has(permission))
throw new PKError($"You must have the \"{permissionName}\" permission in this server to use this command."); throw new PKError($"You must have the \"{permissionName}\" permission in this server to use this command.");
return this; return this;
} }
public Context CheckGuildContext() public Context CheckGuildContext()
{ {
if (Channel is IGuildChannel) return this; if (Channel.Guild != null) return this;
throw new PKError("This command can not be run in a DM."); throw new PKError("This command can not be run in a DM.");
} }
@ -296,10 +277,10 @@ namespace PluralKit.Bot
throw new PKError("You do not have permission to access this information."); throw new PKError("You do not have permission to access this information.");
} }
public ITextChannel MatchChannel() public DiscordChannel MatchChannel()
{ {
if (!MentionUtils.TryParseChannel(PeekArgument(), out var channel)) return null; if (!MentionUtils.TryParseChannel(PeekArgument(), out var channel)) return null;
if (!(_client.GetChannel(channel) is ITextChannel textChannel)) return null; if (!(_client.GetChannelAsync(channel) is ITextChannel textChannel)) return null;
PopArgument(); PopArgument();
return textChannel; return textChannel;

View File

@ -3,9 +3,7 @@ using System.Net.Http;
using Autofac; using Autofac;
using Discord; using DSharpPlus;
using Discord.Rest;
using Discord.WebSocket;
using PluralKit.Core; using PluralKit.Core;
@ -18,18 +16,12 @@ namespace PluralKit.Bot
protected override void Load(ContainerBuilder builder) protected override void Load(ContainerBuilder builder)
{ {
// Client // Client
builder.Register(c => new DiscordShardedClient(new DiscordSocketConfig() builder.Register(c => new DiscordShardedClient(new DiscordConfiguration
{ {
MessageCacheSize = 0, Token = c.Resolve<BotConfig>().Token,
ConnectionTimeout = 2 * 60 * 1000, TokenType = TokenType.Bot,
ExclusiveBulkDelete = true, MessageCacheSize = 0,
LargeThreshold = 50, })).AsSelf().SingleInstance();
GuildSubscriptions = false,
DefaultRetryMode = RetryMode.RetryTimeouts | RetryMode.RetryRatelimit
// Commented this out since Debug actually sends, uh, quite a lot that's not necessary in production
// but leaving it here in case I (or someone else) get[s] confused about why logging isn't working again :p
// LogLevel = LogSeverity.Debug // We filter log levels in Serilog, so just pass everything through (Debug is lower than Verbose)
})).AsSelf().As<BaseDiscordClient>().As<BaseSocketClient>().As<IDiscordClient>().SingleInstance();
// Commands // Commands
builder.RegisterType<CommandTree>().AsSelf(); builder.RegisterType<CommandTree>().AsSelf();

View File

@ -6,7 +6,6 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<ProjectReference Include="..\Discord.Net\src\Discord.Net.WebSocket\Discord.Net.WebSocket.csproj" />
<ProjectReference Include="..\PluralKit.Core\PluralKit.Core.csproj" /> <ProjectReference Include="..\PluralKit.Core\PluralKit.Core.csproj" />
</ItemGroup> </ItemGroup>

View File

@ -2,8 +2,9 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Discord;
using Discord.WebSocket; using DSharpPlus;
using DSharpPlus.Entities;
using Humanizer; using Humanizer;
using NodaTime; using NodaTime;
@ -22,15 +23,15 @@ namespace PluralKit.Bot {
_data = data; _data = data;
} }
public async Task<Embed> CreateSystemEmbed(PKSystem system, LookupContext ctx) { public async Task<DiscordEmbed> CreateSystemEmbed(DiscordClient client, PKSystem system, LookupContext ctx) {
var accounts = await _data.GetSystemAccounts(system); var accounts = await _data.GetSystemAccounts(system);
// Fetch/render info for all accounts simultaneously // Fetch/render info for all accounts simultaneously
var users = await Task.WhenAll(accounts.Select(async uid => (await _client.Rest.GetUserAsync(uid))?.NameAndMention() ?? $"(deleted account {uid})")); var users = await Task.WhenAll(accounts.Select(async uid => (await client.GetUserAsync(uid))?.NameAndMention() ?? $"(deleted account {uid})"));
var memberCount = await _data.GetSystemMemberCount(system, false); var memberCount = await _data.GetSystemMemberCount(system, false);
var eb = new EmbedBuilder() var eb = new DiscordEmbedBuilder()
.WithColor(Color.Blue) .WithColor(DiscordColor.Blue)
.WithTitle(system.Name ?? null) .WithTitle(system.Name ?? null)
.WithThumbnailUrl(system.AvatarUrl ?? null) .WithThumbnailUrl(system.AvatarUrl ?? null)
.WithFooter($"System ID: {system.Hid} | Created on {DateTimeFormats.ZonedDateTimeFormat.Format(system.Created.InZone(system.Zone))}"); .WithFooter($"System ID: {system.Hid} | Created on {DateTimeFormats.ZonedDateTimeFormat.Format(system.Created.InZone(system.Zone))}");
@ -61,33 +62,33 @@ namespace PluralKit.Bot {
return eb.Build(); return eb.Build();
} }
public Embed CreateLoggedMessageEmbed(PKSystem system, PKMember member, ulong messageId, ulong originalMsgId, IUser sender, string content, IGuildChannel channel) { public DiscordEmbed CreateLoggedMessageEmbed(PKSystem system, PKMember member, ulong messageId, ulong originalMsgId, DiscordUser sender, string content, DiscordChannel channel) {
// TODO: pronouns in ?-reacted response using this card // TODO: pronouns in ?-reacted response using this card
var timestamp = SnowflakeUtils.FromSnowflake(messageId); var timestamp = DiscordUtils.SnowflakeToInstant(messageId);
return new EmbedBuilder() return new DiscordEmbedBuilder()
.WithAuthor($"#{channel.Name}: {member.Name}", member.AvatarUrl) .WithAuthor($"#{channel.Name}: {member.Name}", member.AvatarUrl)
.WithDescription(content?.NormalizeLineEndSpacing()) .WithDescription(content?.NormalizeLineEndSpacing())
.WithFooter($"System ID: {system.Hid} | Member ID: {member.Hid} | Sender: {sender.Username}#{sender.Discriminator} ({sender.Id}) | Message ID: {messageId} | Original Message ID: {originalMsgId}") .WithFooter($"System ID: {system.Hid} | Member ID: {member.Hid} | Sender: {sender.Username}#{sender.Discriminator} ({sender.Id}) | Message ID: {messageId} | Original Message ID: {originalMsgId}")
.WithTimestamp(timestamp) .WithTimestamp(timestamp.ToDateTimeOffset())
.Build(); .Build();
} }
public async Task<Embed> CreateMemberEmbed(PKSystem system, PKMember member, IGuild guild, LookupContext ctx) public async Task<DiscordEmbed> CreateMemberEmbed(PKSystem system, PKMember member, DiscordGuild guild, LookupContext ctx)
{ {
var name = member.Name; var name = member.Name;
if (system.Name != null) name = $"{member.Name} ({system.Name})"; if (system.Name != null) name = $"{member.Name} ({system.Name})";
Color color; DiscordColor color;
try try
{ {
color = member.Color?.ToDiscordColor() ?? Color.Default; color = member.Color?.ToDiscordColor() ?? DiscordColor.Gray;
} }
catch (ArgumentException) catch (ArgumentException)
{ {
// Bad API use can cause an invalid color string // Bad API use can cause an invalid color string
// TODO: fix that in the API // TODO: fix that in the API
// for now we just default to a blank color, yolo // for now we just default to a blank color, yolo
color = Color.Default; color = DiscordColor.Gray;
} }
var messageCount = await _data.GetMemberMessageCount(member); var messageCount = await _data.GetMemberMessageCount(member);
@ -98,10 +99,10 @@ namespace PluralKit.Bot {
var proxyTagsStr = string.Join('\n', member.ProxyTags.Select(t => $"`{t.ProxyString}`")); var proxyTagsStr = string.Join('\n', member.ProxyTags.Select(t => $"`{t.ProxyString}`"));
var eb = new EmbedBuilder() var eb = new DiscordEmbedBuilder()
// TODO: add URL of website when that's up // TODO: add URL of website when that's up
.WithAuthor(name, avatar) .WithAuthor(name, avatar)
.WithColor(member.MemberPrivacy.CanAccess(ctx) ? color : Color.Default) .WithColor(member.MemberPrivacy.CanAccess(ctx) ? color : DiscordColor.Gray)
.WithFooter($"System ID: {system.Hid} | Member ID: {member.Hid} | Created on {DateTimeFormats.ZonedDateTimeFormat.Format(member.Created.InZone(system.Zone))}"); .WithFooter($"System ID: {system.Hid} | Member ID: {member.Hid} | Created on {DateTimeFormats.ZonedDateTimeFormat.Format(member.Created.InZone(system.Zone))}");
var description = ""; var description = "";
@ -119,7 +120,7 @@ namespace PluralKit.Bot {
if (guild != null && guildDisplayName != null) eb.AddField($"Server Nickname (for {guild.Name})", guildDisplayName.Truncate(1024), true); if (guild != null && guildDisplayName != null) eb.AddField($"Server Nickname (for {guild.Name})", guildDisplayName.Truncate(1024), true);
if (member.Birthday != null && member.MemberPrivacy.CanAccess(ctx)) eb.AddField("Birthdate", member.BirthdayString, true); if (member.Birthday != null && member.MemberPrivacy.CanAccess(ctx)) eb.AddField("Birthdate", member.BirthdayString, true);
if (!member.Pronouns.EmptyOrNull() && member.MemberPrivacy.CanAccess(ctx)) eb.AddField("Pronouns", member.Pronouns.Truncate(1024), true); if (!member.Pronouns.EmptyOrNull() && member.MemberPrivacy.CanAccess(ctx)) eb.AddField("Pronouns", member.Pronouns.Truncate(1024), true);
if (messageCount > 0 && member.MemberPrivacy.CanAccess(ctx)) eb.AddField("Message Count", messageCount, true); if (messageCount > 0 && member.MemberPrivacy.CanAccess(ctx)) eb.AddField("Message Count", messageCount.ToString(), true);
if (member.HasProxyTags) eb.AddField("Proxy Tags", string.Join('\n', proxyTagsStr).Truncate(1024), true); if (member.HasProxyTags) eb.AddField("Proxy Tags", string.Join('\n', proxyTagsStr).Truncate(1024), true);
if (!member.Color.EmptyOrNull() && member.MemberPrivacy.CanAccess(ctx)) eb.AddField("Color", $"#{member.Color}", true); if (!member.Color.EmptyOrNull() && member.MemberPrivacy.CanAccess(ctx)) eb.AddField("Color", $"#{member.Color}", true);
if (!member.Description.EmptyOrNull() && member.MemberPrivacy.CanAccess(ctx)) eb.AddField("Description", member.Description.NormalizeLineEndSpacing(), false); if (!member.Description.EmptyOrNull() && member.MemberPrivacy.CanAccess(ctx)) eb.AddField("Description", member.Description.NormalizeLineEndSpacing(), false);
@ -127,48 +128,45 @@ namespace PluralKit.Bot {
return eb.Build(); return eb.Build();
} }
public async Task<Embed> CreateFronterEmbed(PKSwitch sw, DateTimeZone zone) public async Task<DiscordEmbed> CreateFronterEmbed(PKSwitch sw, DateTimeZone zone)
{ {
var members = await _data.GetSwitchMembers(sw).ToListAsync(); var members = await _data.GetSwitchMembers(sw).ToListAsync();
var timeSinceSwitch = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp; var timeSinceSwitch = SystemClock.Instance.GetCurrentInstant() - sw.Timestamp;
return new EmbedBuilder() return new DiscordEmbedBuilder()
.WithColor(members.FirstOrDefault()?.Color?.ToDiscordColor() ?? Color.Blue) .WithColor(members.FirstOrDefault()?.Color?.ToDiscordColor() ?? DiscordColor.Blue)
.AddField($"Current {"fronter".ToQuantity(members.Count, ShowQuantityAs.None)}", members.Count > 0 ? string.Join(", ", members.Select(m => m.Name)) : "*(no fronter)*") .AddField($"Current {"fronter".ToQuantity(members.Count, ShowQuantityAs.None)}", members.Count > 0 ? string.Join(", ", members.Select(m => m.Name)) : "*(no fronter)*")
.AddField("Since", $"{DateTimeFormats.ZonedDateTimeFormat.Format(sw.Timestamp.InZone(zone))} ({DateTimeFormats.DurationFormat.Format(timeSinceSwitch)} ago)") .AddField("Since", $"{DateTimeFormats.ZonedDateTimeFormat.Format(sw.Timestamp.InZone(zone))} ({DateTimeFormats.DurationFormat.Format(timeSinceSwitch)} ago)")
.Build(); .Build();
} }
public async Task<Embed> CreateMessageInfoEmbed(FullMessage msg) public async Task<DiscordEmbed> CreateMessageInfoEmbed(DiscordClient client, FullMessage msg)
{ {
var channel = _client.GetChannel(msg.Message.Channel) as ITextChannel; var channel = await client.GetChannelAsync(msg.Message.Channel);
var serverMsg = channel != null ? await channel.GetMessageAsync(msg.Message.Mid) : null; var serverMsg = channel != null ? await channel.GetMessageAsync(msg.Message.Mid) : null;
var memberStr = $"{msg.Member.Name} (`{msg.Member.Hid}`)"; var memberStr = $"{msg.Member.Name} (`{msg.Member.Hid}`)";
var userStr = $"*(deleted user {msg.Message.Sender})*"; var userStr = $"*(deleted user {msg.Message.Sender})*";
ICollection<IRole> roles = null; ICollection<DiscordRole> roles = null;
if (channel != null) if (channel != null)
{ {
// Look up the user with the REST client // Look up the user with the REST client
// this ensures we'll still get the information even if the user's not cached, // this ensures we'll still get the information even if the user's not cached,
// even if this means an extra API request (meh, it'll be fine) // even if this means an extra API request (meh, it'll be fine)
var shard = _client.GetShardFor(channel.Guild); var guildUser = await channel.Guild.GetMemberAsync(msg.Message.Sender);
var guildUser = await shard.Rest.GetGuildUserAsync(channel.Guild.Id, msg.Message.Sender);
if (guildUser != null) if (guildUser != null)
{ {
if (guildUser.RoleIds.Count > 0) roles = guildUser.Roles
roles = guildUser.RoleIds .Where(role => role.Name != "@everyone")
.Select(roleId => channel.Guild.GetRole(roleId)) .OrderByDescending(role => role.Position)
.Where(role => role.Name != "@everyone") .ToList();
.OrderByDescending(role => role.Position)
.ToList();
userStr = guildUser.Nickname != null ? $"**Username:** {guildUser?.NameAndMention()}\n**Nickname:** {guildUser.Nickname}" : guildUser?.NameAndMention(); userStr = guildUser.Nickname != null ? $"**Username:** {guildUser?.NameAndMention()}\n**Nickname:** {guildUser.Nickname}" : guildUser.NameAndMention();
} }
} }
var eb = new EmbedBuilder() var eb = new DiscordEmbedBuilder()
.WithAuthor(msg.Member.Name, msg.Member.AvatarUrl) .WithAuthor(msg.Member.Name, msg.Member.AvatarUrl)
.WithDescription(serverMsg?.Content?.NormalizeLineEndSpacing() ?? "*(message contents deleted or inaccessible)*") .WithDescription(serverMsg?.Content?.NormalizeLineEndSpacing() ?? "*(message contents deleted or inaccessible)*")
.WithImageUrl(serverMsg?.Attachments?.FirstOrDefault()?.Url) .WithImageUrl(serverMsg?.Attachments?.FirstOrDefault()?.Url)
@ -176,18 +174,18 @@ namespace PluralKit.Bot {
msg.System.Name != null ? $"{msg.System.Name} (`{msg.System.Hid}`)" : $"`{msg.System.Hid}`", true) msg.System.Name != null ? $"{msg.System.Name} (`{msg.System.Hid}`)" : $"`{msg.System.Hid}`", true)
.AddField("Member", memberStr, true) .AddField("Member", memberStr, true)
.AddField("Sent by", userStr, inline: true) .AddField("Sent by", userStr, inline: true)
.WithTimestamp(SnowflakeUtils.FromSnowflake(msg.Message.Mid)); .WithTimestamp(DiscordUtils.SnowflakeToInstant(msg.Message.Mid).ToDateTimeOffset());
if (roles != null && roles.Count > 0) if (roles != null && roles.Count > 0)
eb.AddField($"Account roles ({roles.Count})", string.Join(", ", roles.Select(role => role.Name))); eb.AddField($"Account roles ({roles.Count})", string.Join(", ", roles.Select(role => role.Name)));
return eb.Build(); return eb.Build();
} }
public Task<Embed> CreateFrontPercentEmbed(FrontBreakdown breakdown, DateTimeZone tz) public Task<DiscordEmbed> CreateFrontPercentEmbed(FrontBreakdown breakdown, DateTimeZone tz)
{ {
var actualPeriod = breakdown.RangeEnd - breakdown.RangeStart; var actualPeriod = breakdown.RangeEnd - breakdown.RangeStart;
var eb = new EmbedBuilder() var eb = new DiscordEmbedBuilder()
.WithColor(Color.Blue) .WithColor(DiscordColor.Blue)
.WithFooter($"Since {DateTimeFormats.ZonedDateTimeFormat.Format(breakdown.RangeStart.InZone(tz))} ({DateTimeFormats.DurationFormat.Format(actualPeriod)} ago)"); .WithFooter($"Since {DateTimeFormats.ZonedDateTimeFormat.Format(breakdown.RangeStart.InZone(tz))} ({DateTimeFormats.DurationFormat.Format(actualPeriod)} ago)");
var maxEntriesToDisplay = 24; // max 25 fields allowed in embed - reserve 1 for "others" var maxEntriesToDisplay = 24; // max 25 fields allowed in embed - reserve 1 for "others"

View File

@ -7,6 +7,7 @@ namespace PluralKit.Bot
// not particularly efficient? It allocates a dictionary *and* a queue for every single channel (500k in prod!) // not particularly efficient? It allocates a dictionary *and* a queue for every single channel (500k in prod!)
// whereas this is, worst case, one dictionary *entry* of a single ulong per channel, and one dictionary instance // whereas this is, worst case, one dictionary *entry* of a single ulong per channel, and one dictionary instance
// on the whole instance, total. Yeah, much more efficient. // on the whole instance, total. Yeah, much more efficient.
// TODO: is this still needed after the D#+ migration?
public class LastMessageCacheService public class LastMessageCacheService
{ {
private IDictionary<ulong, ulong> _cache = new ConcurrentDictionary<ulong, ulong>(); private IDictionary<ulong, ulong> _cache = new ConcurrentDictionary<ulong, ulong>();

View File

@ -1,6 +1,7 @@
using System.Threading.Tasks; using System.Threading.Tasks;
using Discord; using DSharpPlus;
using DSharpPlus.Entities;
using PluralKit.Core; using PluralKit.Core;
@ -8,20 +9,18 @@ using Serilog;
namespace PluralKit.Bot { namespace PluralKit.Bot {
public class LogChannelService { public class LogChannelService {
private IDiscordClient _client;
private EmbedService _embed; private EmbedService _embed;
private IDataStore _data; private IDataStore _data;
private ILogger _logger; private ILogger _logger;
public LogChannelService(IDiscordClient client, EmbedService embed, ILogger logger, IDataStore data) public LogChannelService(EmbedService embed, ILogger logger, IDataStore data)
{ {
_client = client;
_embed = embed; _embed = embed;
_data = data; _data = data;
_logger = logger.ForContext<LogChannelService>(); _logger = logger.ForContext<LogChannelService>();
} }
public async Task LogMessage(PKSystem system, PKMember member, ulong messageId, ulong originalMsgId, IGuildChannel originalChannel, IUser sender, string content, GuildConfig? guildCfg = null) public async Task LogMessage(DiscordClient client, PKSystem system, PKMember member, ulong messageId, ulong originalMsgId, DiscordChannel originalChannel, DiscordUser sender, string content, GuildConfig? guildCfg = null)
{ {
if (guildCfg == null) if (guildCfg == null)
guildCfg = await _data.GetOrCreateGuildConfig(originalChannel.GuildId); guildCfg = await _data.GetOrCreateGuildConfig(originalChannel.GuildId);
@ -31,17 +30,19 @@ namespace PluralKit.Bot {
if (guildCfg.Value.LogBlacklist.Contains(originalChannel.Id)) return; if (guildCfg.Value.LogBlacklist.Contains(originalChannel.Id)) return;
// Bail if we can't find the channel // Bail if we can't find the channel
if (!(await _client.GetChannelAsync(guildCfg.Value.LogChannel.Value) is ITextChannel logChannel)) return; var channel = await client.GetChannelAsync(guildCfg.Value.LogChannel.Value);
if (channel == null || channel.Type != ChannelType.Text) return;
// Bail if we don't have permission to send stuff here // Bail if we don't have permission to send stuff here
if (!logChannel.HasPermission(ChannelPermission.SendMessages) || !logChannel.HasPermission(ChannelPermission.EmbedLinks)) var neededPermissions = Permissions.SendMessages | Permissions.EmbedLinks;
if ((channel.BotPermissions() & neededPermissions) != neededPermissions)
return; return;
var embed = _embed.CreateLoggedMessageEmbed(system, member, messageId, originalMsgId, sender, content, originalChannel); var embed = _embed.CreateLoggedMessageEmbed(system, member, messageId, originalMsgId, sender, content, originalChannel);
var url = $"https://discordapp.com/channels/{originalChannel.GuildId}/{originalChannel.Id}/{messageId}"; var url = $"https://discordapp.com/channels/{originalChannel.GuildId}/{originalChannel.Id}/{messageId}";
await logChannel.SendMessageAsync(text: url, embed: embed); await channel.SendMessageAsync(content: url, embed: embed);
} }
} }
} }

View File

@ -6,8 +6,8 @@ using System.Threading.Tasks;
using Dapper; using Dapper;
using Discord; using DSharpPlus;
using Discord.WebSocket; using DSharpPlus.Entities;
using PluralKit.Core; using PluralKit.Core;
@ -61,18 +61,18 @@ namespace PluralKit.Bot
public ICollection<LoggerBot> Bots => _bots.Values; public ICollection<LoggerBot> Bots => _bots.Values;
public async ValueTask HandleLoggerBotCleanup(SocketMessage msg, GuildConfig cachedGuild) public async ValueTask HandleLoggerBotCleanup(DiscordMessage msg, GuildConfig cachedGuild)
{ {
// Bail if not enabled, or if we don't have permission here // Bail if not enabled, or if we don't have permission here
if (!cachedGuild.LogCleanupEnabled) return; if (!cachedGuild.LogCleanupEnabled) return;
if (!(msg.Channel is SocketTextChannel channel)) return; if (msg.Channel.Type != ChannelType.Text) return;
if (!channel.Guild.GetUser(_client.CurrentUser.Id).GetPermissions(channel).ManageMessages) return; if (!msg.Channel.BotHasPermission(Permissions.ManageMessages)) return;
// If this message is from a *webhook*, check if the name matches one of the bots we know // If this message is from a *webhook*, check if the name matches one of the bots we know
// TODO: do we need to do a deeper webhook origin check, or would that be too hard on the rate limit? // TODO: do we need to do a deeper webhook origin check, or would that be too hard on the rate limit?
// If it's from a *bot*, check the bot ID to see if we know it. // If it's from a *bot*, check the bot ID to see if we know it.
LoggerBot bot = null; LoggerBot bot = null;
if (msg.Author.IsWebhook) _botsByWebhookName.TryGetValue(msg.Author.Username, out bot); if (msg.WebhookMessage) _botsByWebhookName.TryGetValue(msg.Author.Username, out bot);
else if (msg.Author.IsBot) _bots.TryGetValue(msg.Author.Id, out bot); else if (msg.Author.IsBot) _bots.TryGetValue(msg.Author.Id, out bot);
// If we didn't find anything before, or what we found is an unsupported bot, bail // If we didn't find anything before, or what we found is an unsupported bot, bail
@ -95,8 +95,8 @@ namespace PluralKit.Bot
new new
{ {
fuzzy.Value.User, fuzzy.Value.User,
Guild = (msg.Channel as ITextChannel)?.GuildId ?? 0, Guild = msg.Channel.GuildId,
ApproxId = SnowflakeUtils.ToSnowflake(fuzzy.Value.ApproxTimestamp - TimeSpan.FromSeconds(3)) ApproxId = DiscordUtils.InstantToSnowflake(fuzzy.Value.ApproxTimestamp - TimeSpan.FromSeconds(3))
}); });
if (mid == null) return; // If we didn't find a corresponding message, bail if (mid == null) return; // If we didn't find a corresponding message, bail
// Otherwise, we can *reasonably assume* that this is a logged deletion, so delete the log message. // Otherwise, we can *reasonably assume* that this is a logged deletion, so delete the log message.
@ -118,7 +118,7 @@ namespace PluralKit.Bot
} // else should not happen, but idk, it might } // else should not happen, but idk, it might
} }
private static ulong? ExtractAuttaja(SocketMessage msg) private static ulong? ExtractAuttaja(DiscordMessage msg)
{ {
// Auttaja has an optional "compact mode" that logs without embeds // Auttaja has an optional "compact mode" that logs without embeds
// That one puts the ID in the message content, non-compact puts it in the embed description. // That one puts the ID in the message content, non-compact puts it in the embed description.
@ -130,16 +130,16 @@ namespace PluralKit.Bot
return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null; return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null;
} }
private static ulong? ExtractDyno(SocketMessage msg) private static ulong? ExtractDyno(DiscordMessage msg)
{ {
// Embed *description* contains "Message sent by [mention] deleted in [channel]", contains message ID in footer per regex // Embed *description* contains "Message sent by [mention] deleted in [channel]", contains message ID in footer per regex
var embed = msg.Embeds.FirstOrDefault(); var embed = msg.Embeds.FirstOrDefault();
if (embed?.Footer == null || !(embed.Description?.Contains("deleted in") ?? false)) return null; if (embed?.Footer == null || !(embed.Description?.Contains("deleted in") ?? false)) return null;
var match = _dynoRegex.Match(embed.Footer.Value.Text ?? ""); var match = _dynoRegex.Match(embed.Footer.Text ?? "");
return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null; return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null;
} }
private static ulong? ExtractLoggerA(SocketMessage msg) private static ulong? ExtractLoggerA(DiscordMessage msg)
{ {
// This is for Logger#6088 (298822483060981760), distinct from Logger#6278 (327424261180620801). // This is for Logger#6088 (298822483060981760), distinct from Logger#6278 (327424261180620801).
// Embed contains title "Message deleted in [channel]", and an ID field containing both message and user ID (see regex). // Embed contains title "Message deleted in [channel]", and an ID field containing both message and user ID (see regex).
@ -153,26 +153,26 @@ namespace PluralKit.Bot
return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null; return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null;
} }
private static ulong? ExtractLoggerB(SocketMessage msg) private static ulong? ExtractLoggerB(DiscordMessage msg)
{ {
// This is for Logger#6278 (327424261180620801), distinct from Logger#6088 (298822483060981760). // This is for Logger#6278 (327424261180620801), distinct from Logger#6088 (298822483060981760).
// Embed title ends with "A Message Was Deleted!", footer contains message ID as per regex. // Embed title ends with "A Message Was Deleted!", footer contains message ID as per regex.
var embed = msg.Embeds.FirstOrDefault(); var embed = msg.Embeds.FirstOrDefault();
if (embed?.Footer == null || !(embed.Title?.EndsWith("A Message Was Deleted!") ?? false)) return null; if (embed?.Footer == null || !(embed.Title?.EndsWith("A Message Was Deleted!") ?? false)) return null;
var match = _loggerBRegex.Match(embed.Footer.Value.Text ?? ""); var match = _loggerBRegex.Match(embed.Footer.Text ?? "");
return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null; return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null;
} }
private static ulong? ExtractGenericBot(SocketMessage msg) private static ulong? ExtractGenericBot(DiscordMessage msg)
{ {
// Embed, title is "Message Deleted", ID plain in footer. // Embed, title is "Message Deleted", ID plain in footer.
var embed = msg.Embeds.FirstOrDefault(); var embed = msg.Embeds.FirstOrDefault();
if (embed?.Footer == null || !(embed.Title?.Contains("Message Deleted") ?? false)) return null; if (embed?.Footer == null || !(embed.Title?.Contains("Message Deleted") ?? false)) return null;
var match = _basicRegex.Match(embed.Footer.Value.Text ?? ""); var match = _basicRegex.Match(embed.Footer.Text ?? "");
return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null; return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null;
} }
private static ulong? ExtractBlargBot(SocketMessage msg) private static ulong? ExtractBlargBot(DiscordMessage msg)
{ {
// Embed, title ends with "Message Deleted", contains ID plain in a field. // Embed, title ends with "Message Deleted", contains ID plain in a field.
var embed = msg.Embeds.FirstOrDefault(); var embed = msg.Embeds.FirstOrDefault();
@ -182,7 +182,7 @@ namespace PluralKit.Bot
return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null; return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null;
} }
private static ulong? ExtractMantaro(SocketMessage msg) private static ulong? ExtractMantaro(DiscordMessage msg)
{ {
// Plain message, "Message (ID: [id]) created by [user] (ID: [id]) in channel [channel] was deleted. // Plain message, "Message (ID: [id]) created by [user] (ID: [id]) in channel [channel] was deleted.
if (!(msg.Content?.Contains("was deleted.") ?? false)) return null; if (!(msg.Content?.Contains("was deleted.") ?? false)) return null;
@ -190,19 +190,19 @@ namespace PluralKit.Bot
return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null; return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null;
} }
private static FuzzyExtractResult? ExtractCarlBot(SocketMessage msg) private static FuzzyExtractResult? ExtractCarlBot(DiscordMessage msg)
{ {
// Embed, title is "Message deleted in [channel], **user** ID in the footer, timestamp as, well, timestamp in embed. // Embed, title is "Message deleted in [channel], **user** ID in the footer, timestamp as, well, timestamp in embed.
// This is the *deletion* timestamp, which we can assume is a couple seconds at most after the message was originally sent // This is the *deletion* timestamp, which we can assume is a couple seconds at most after the message was originally sent
var embed = msg.Embeds.FirstOrDefault(); var embed = msg.Embeds.FirstOrDefault();
if (embed?.Footer == null || embed.Timestamp == null || !(embed.Title?.StartsWith("Message deleted in") ?? false)) return null; if (embed?.Footer == null || embed.Timestamp == null || !(embed.Title?.StartsWith("Message deleted in") ?? false)) return null;
var match = _carlRegex.Match(embed.Footer.Value.Text ?? ""); var match = _carlRegex.Match(embed.Footer.Text ?? "");
return match.Success return match.Success
? new FuzzyExtractResult { User = ulong.Parse(match.Groups[1].Value), ApproxTimestamp = embed.Timestamp.Value } ? new FuzzyExtractResult { User = ulong.Parse(match.Groups[1].Value), ApproxTimestamp = embed.Timestamp.Value }
: (FuzzyExtractResult?) null; : (FuzzyExtractResult?) null;
} }
private static FuzzyExtractResult? ExtractCircle(SocketMessage msg) private static FuzzyExtractResult? ExtractCircle(DiscordMessage msg)
{ {
// Like Auttaja, Circle has both embed and compact modes, but the regex works for both. // Like Auttaja, Circle has both embed and compact modes, but the regex works for both.
// Compact: "Message from [user] ([id]) deleted in [channel]", no timestamp (use message time) // Compact: "Message from [user] ([id]) deleted in [channel]", no timestamp (use message time)
@ -211,7 +211,7 @@ namespace PluralKit.Bot
if (msg.Embeds.Count > 0) if (msg.Embeds.Count > 0)
{ {
var embed = msg.Embeds.First(); var embed = msg.Embeds.First();
if (embed.Author?.Name == null || !embed.Author.Value.Name.StartsWith("Message Deleted in")) return null; if (embed.Author?.Name == null || !embed.Author.Name.StartsWith("Message Deleted in")) return null;
var field = embed.Fields.FirstOrDefault(f => f.Name == "Message Author"); var field = embed.Fields.FirstOrDefault(f => f.Name == "Message Author");
if (field.Value == null) return null; if (field.Value == null) return null;
stringWithId = field.Value; stringWithId = field.Value;
@ -224,7 +224,7 @@ namespace PluralKit.Bot
: (FuzzyExtractResult?) null; : (FuzzyExtractResult?) null;
} }
private static FuzzyExtractResult? ExtractPancake(SocketMessage msg) private static FuzzyExtractResult? ExtractPancake(DiscordMessage msg)
{ {
// Embed, author is "Message Deleted", description includes a mention, timestamp is *message send time* (but no ID) // Embed, author is "Message Deleted", description includes a mention, timestamp is *message send time* (but no ID)
// so we use the message timestamp to get somewhere *after* the message was proxied // so we use the message timestamp to get somewhere *after* the message was proxied
@ -236,16 +236,16 @@ namespace PluralKit.Bot
: (FuzzyExtractResult?) null; : (FuzzyExtractResult?) null;
} }
private static ulong? ExtractUnbelievaBoat(SocketMessage msg) private static ulong? ExtractUnbelievaBoat(DiscordMessage msg)
{ {
// Embed author is "Message Deleted", footer contains message ID per regex // Embed author is "Message Deleted", footer contains message ID per regex
var embed = msg.Embeds.FirstOrDefault(); var embed = msg.Embeds.FirstOrDefault();
if (embed?.Footer == null || embed.Author?.Name != "Message Deleted") return null; if (embed?.Footer == null || embed.Author?.Name != "Message Deleted") return null;
var match = _unbelievaboatRegex.Match(embed.Footer.Value.Text ?? ""); var match = _unbelievaboatRegex.Match(embed.Footer.Text ?? "");
return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null; return match.Success ? ulong.Parse(match.Groups[1].Value) : (ulong?) null;
} }
private static FuzzyExtractResult? ExtractVanessa(SocketMessage msg) private static FuzzyExtractResult? ExtractVanessa(DiscordMessage msg)
{ {
// Title is "Message Deleted", embed description contains mention // Title is "Message Deleted", embed description contains mention
var embed = msg.Embeds.FirstOrDefault(); var embed = msg.Embeds.FirstOrDefault();
@ -261,11 +261,11 @@ namespace PluralKit.Bot
{ {
public string Name; public string Name;
public ulong Id; public ulong Id;
public Func<SocketMessage, ulong?> ExtractFunc; public Func<DiscordMessage, ulong?> ExtractFunc;
public Func<SocketMessage, FuzzyExtractResult?> FuzzyExtractFunc; public Func<DiscordMessage, FuzzyExtractResult?> FuzzyExtractFunc;
public string WebhookName; public string WebhookName;
public LoggerBot(string name, ulong id, Func<SocketMessage, ulong?> extractFunc = null, Func<SocketMessage, FuzzyExtractResult?> fuzzyExtractFunc = null, string webhookName = null) public LoggerBot(string name, ulong id, Func<DiscordMessage, ulong?> extractFunc = null, Func<DiscordMessage, FuzzyExtractResult?> fuzzyExtractFunc = null, string webhookName = null)
{ {
Name = name; Name = name;
Id = id; Id = id;

View File

@ -1,10 +1,13 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Data;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using App.Metrics; using App.Metrics;
using Discord;
using Discord.WebSocket; using DSharpPlus;
using DSharpPlus.Entities;
using NodaTime.Extensions; using NodaTime.Extensions;
using PluralKit.Core; using PluralKit.Core;
@ -27,9 +30,9 @@ namespace PluralKit.Bot
private ILogger _logger; private ILogger _logger;
public PeriodicStatCollector(IDiscordClient client, IMetrics metrics, ILogger logger, WebhookCacheService webhookCache, DbConnectionCountHolder countHolder, IDataStore data, CpuStatService cpu, WebhookRateLimitService webhookRateLimitCache) public PeriodicStatCollector(DiscordShardedClient client, IMetrics metrics, ILogger logger, WebhookCacheService webhookCache, DbConnectionCountHolder countHolder, IDataStore data, CpuStatService cpu, WebhookRateLimitService webhookRateLimitCache)
{ {
_client = (DiscordShardedClient) client; _client = client;
_metrics = metrics; _metrics = metrics;
_webhookCache = webhookCache; _webhookCache = webhookCache;
_countHolder = countHolder; _countHolder = countHolder;
@ -45,18 +48,31 @@ namespace PluralKit.Bot
stopwatch.Start(); stopwatch.Start();
// Aggregate guild/channel stats // Aggregate guild/channel stats
_metrics.Measure.Gauge.SetValue(BotMetrics.Guilds, _client.Guilds.Count);
_metrics.Measure.Gauge.SetValue(BotMetrics.Channels, _client.Guilds.Sum(g => g.TextChannels.Count)); var guildCount = 0;
_metrics.Measure.Gauge.SetValue(BotMetrics.ShardsConnected, _client.Shards.Count(shard => shard.ConnectionState == ConnectionState.Connected)); var channelCount = 0;
// No LINQ today, sorry
foreach (var shard in _client.ShardClients.Values)
{
guildCount += shard.Guilds.Count;
foreach (var guild in shard.Guilds.Values)
foreach (var channel in guild.Channels.Values)
if (channel.Type == ChannelType.Text)
channelCount++;
}
_metrics.Measure.Gauge.SetValue(BotMetrics.Guilds, guildCount);
_metrics.Measure.Gauge.SetValue(BotMetrics.Channels, channelCount);
// Aggregate member stats // Aggregate member stats
var usersKnown = new HashSet<ulong>(); var usersKnown = new HashSet<ulong>();
var usersOnline = new HashSet<ulong>(); var usersOnline = new HashSet<ulong>();
foreach (var guild in _client.Guilds) foreach (var shard in _client.ShardClients.Values)
foreach (var user in guild.Users) foreach (var guild in shard.Guilds.Values)
foreach (var user in guild.Members.Values)
{ {
usersKnown.Add(user.Id); usersKnown.Add(user.Id);
if (user.Status == UserStatus.Online) usersOnline.Add(user.Id); if (user.Presence.Status == UserStatus.Online) usersOnline.Add(user.Id);
} }
_metrics.Measure.Gauge.SetValue(BotMetrics.MembersTotal, usersKnown.Count); _metrics.Measure.Gauge.SetValue(BotMetrics.MembersTotal, usersKnown.Count);

View File

@ -3,12 +3,12 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Discord; using DSharpPlus;
using Discord.Net; using DSharpPlus.Entities;
using Discord.WebSocket; using DSharpPlus.EventArgs;
using DSharpPlus.Exceptions;
using NodaTime; using NodaTime;
using NodaTime.Extensions;
using PluralKit.Core; using PluralKit.Core;
@ -83,16 +83,16 @@ namespace PluralKit.Bot
return null; return null;
} }
public async Task HandleMessageAsync(GuildConfig guild, CachedAccount account, IMessage message, bool doAutoProxy) public async Task HandleMessageAsync(DiscordClient client, GuildConfig guild, CachedAccount account, DiscordMessage message, bool doAutoProxy)
{ {
// Bail early if this isn't in a guild channel // Bail early if this isn't in a guild channel
if (!(message.Channel is ITextChannel channel)) return; if (message.Channel.Guild != null) return;
// Find a member with proxy tags matching the message // Find a member with proxy tags matching the message
var match = GetProxyTagMatch(message.Content, account.System, account.Members); var match = GetProxyTagMatch(message.Content, account.System, account.Members);
// O(n) lookup since n is small (max ~100 in prod) and we're more constrained by memory (for a dictionary) here // O(n) lookup since n is small (max ~100 in prod) and we're more constrained by memory (for a dictionary) here
var systemSettingsForGuild = account.SettingsForGuild(channel.GuildId); var systemSettingsForGuild = account.SettingsForGuild(message.Channel.GuildId);
// If we didn't get a match by proxy tags, try to get one by autoproxy // If we didn't get a match by proxy tags, try to get one by autoproxy
// Also try if we *did* get a match, but there's no inner text. This happens if someone sends a message that // Also try if we *did* get a match, but there's no inner text. This happens if someone sends a message that
@ -102,26 +102,26 @@ namespace PluralKit.Bot
// When a normal message is sent, autoproxy is enabled, but if this method is called from a message *edit* // When a normal message is sent, autoproxy is enabled, but if this method is called from a message *edit*
// event, then autoproxy is disabled. This is so AP doesn't "retrigger" when the original message was escaped. // event, then autoproxy is disabled. This is so AP doesn't "retrigger" when the original message was escaped.
if (doAutoProxy && (match == null || (match.InnerText.Trim().Length == 0 && message.Attachments.Count == 0))) if (doAutoProxy && (match == null || (match.InnerText.Trim().Length == 0 && message.Attachments.Count == 0)))
match = await GetAutoproxyMatch(account, systemSettingsForGuild, message, channel); match = await GetAutoproxyMatch(account, systemSettingsForGuild, message, message.Channel);
// If we still haven't found any, just yeet // If we still haven't found any, just yeet
if (match == null) return; if (match == null) return;
// And make sure the channel's not blacklisted from proxying. // And make sure the channel's not blacklisted from proxying.
if (guild.Blacklist.Contains(channel.Id)) return; if (guild.Blacklist.Contains(message.ChannelId)) return;
// Make sure the system hasn't blacklisted the guild either // Make sure the system hasn't blacklisted the guild either
if (!systemSettingsForGuild.ProxyEnabled) return; if (!systemSettingsForGuild.ProxyEnabled) return;
// We know message.Channel can only be ITextChannel as PK doesn't work in DMs/groups // We know message.Channel can only be ITextChannel as PK doesn't work in DMs/groups
// Afterwards we ensure the bot has the right permissions, otherwise bail early // Afterwards we ensure the bot has the right permissions, otherwise bail early
if (!await EnsureBotPermissions(channel)) return; if (!await EnsureBotPermissions(message.Channel)) return;
// Can't proxy a message with no content and no attachment // Can't proxy a message with no content and no attachment
if (match.InnerText.Trim().Length == 0 && message.Attachments.Count == 0) if (match.InnerText.Trim().Length == 0 && message.Attachments.Count == 0)
return; return;
var memberSettingsForGuild = account.SettingsForMemberGuild(match.Member.Id, channel.GuildId); var memberSettingsForGuild = account.SettingsForMemberGuild(match.Member.Id, message.Channel.GuildId);
// Get variables in order and all // Get variables in order and all
var proxyName = match.Member.ProxyName(match.System.Tag, memberSettingsForGuild.DisplayName); var proxyName = match.Member.ProxyName(match.System.Tag, memberSettingsForGuild.DisplayName);
@ -138,19 +138,17 @@ namespace PluralKit.Bot
: match.InnerText; : match.InnerText;
// Sanitize @everyone, but only if the original user wouldn't have permission to // Sanitize @everyone, but only if the original user wouldn't have permission to
messageContents = SanitizeEveryoneMaybe(message, messageContents); messageContents = await SanitizeEveryoneMaybe(message, messageContents);
// Execute the webhook itself // Execute the webhook itself
var hookMessageId = await _webhookExecutor.ExecuteWebhook( var hookMessageId = await _webhookExecutor.ExecuteWebhook(message.Channel, proxyName, avatarUrl,
channel,
proxyName, avatarUrl,
messageContents, messageContents,
message.Attachments message.Attachments
); );
// Store the message in the database, and log it in the log channel (if applicable) // Store the message in the database, and log it in the log channel (if applicable)
await _data.AddMessage(message.Author.Id, hookMessageId, channel.GuildId, message.Channel.Id, message.Id, match.Member); await _data.AddMessage(message.Author.Id, hookMessageId, message.Channel.GuildId, message.Channel.Id, message.Id, match.Member);
await _logChannel.LogMessage(match.System, match.Member, hookMessageId, message.Id, message.Channel as IGuildChannel, message.Author, match.InnerText, guild); await _logChannel.LogMessage(client, match.System, match.Member, hookMessageId, message.Id, message.Channel, message.Author, match.InnerText, guild);
// Wait a second or so before deleting the original message // Wait a second or so before deleting the original message
await Task.Delay(1000); await Task.Delay(1000);
@ -159,14 +157,14 @@ namespace PluralKit.Bot
{ {
await message.DeleteAsync(); await message.DeleteAsync();
} }
catch (HttpException) catch (NotFoundException)
{ {
// If it's already deleted, we just log and swallow the exception // If it's already deleted, we just log and swallow the exception
_logger.Warning("Attempted to delete already deleted proxy trigger message {Message}", message.Id); _logger.Warning("Attempted to delete already deleted proxy trigger message {Message}", message.Id);
} }
} }
private async Task<ProxyMatch> GetAutoproxyMatch(CachedAccount account, SystemGuildSettings guildSettings, IMessage message, IGuildChannel channel) private async Task<ProxyMatch> GetAutoproxyMatch(CachedAccount account, SystemGuildSettings guildSettings, DiscordMessage message, DiscordChannel channel)
{ {
// For now we use a backslash as an "escape character", subject to change later // For now we use a backslash as an "escape character", subject to change later
if ((message.Content ?? "").TrimStart().StartsWith("\\")) return null; if ((message.Content ?? "").TrimStart().StartsWith("\\")) return null;
@ -189,7 +187,7 @@ namespace PluralKit.Bot
// If the message is older than 6 hours, ignore it and force the sender to "refresh" a proxy // If the message is older than 6 hours, ignore it and force the sender to "refresh" a proxy
// This can be revised in the future, it's a preliminary value. // This can be revised in the future, it's a preliminary value.
var timestamp = SnowflakeUtils.FromSnowflake(msg.Message.Mid).ToInstant(); var timestamp = DiscordUtils.SnowflakeToInstant(msg.Message.Mid);
var timeSince = SystemClock.Instance.GetCurrentInstant() - timestamp; var timeSince = SystemClock.Instance.GetCurrentInstant() - timestamp;
if (timeSince > Duration.FromHours(6)) return null; if (timeSince > Duration.FromHours(6)) return null;
@ -214,23 +212,23 @@ namespace PluralKit.Bot
}; };
} }
private static string SanitizeEveryoneMaybe(IMessage message, string messageContents) private static async Task<string> SanitizeEveryoneMaybe(DiscordMessage message,
string messageContents)
{ {
var senderPermissions = ((IGuildUser) message.Author).GetPermissions(message.Channel as IGuildChannel); var member = await message.Channel.Guild.GetMemberAsync(message.Author.Id);
if (!senderPermissions.MentionEveryone) return messageContents.SanitizeEveryone(); if ((member.PermissionsIn(message.Channel) & Permissions.MentionEveryone) == 0) return messageContents.SanitizeEveryone();
return messageContents; return messageContents;
} }
private async Task<bool> EnsureBotPermissions(ITextChannel channel) private async Task<bool> EnsureBotPermissions(DiscordChannel channel)
{ {
var guildUser = await channel.Guild.GetCurrentUserAsync(); var permissions = channel.BotPermissions();
var permissions = guildUser.GetPermissions(channel);
// If we can't send messages at all, just bail immediately. // If we can't send messages at all, just bail immediately.
// TODO: can you have ManageMessages and *not* SendMessages? What happens then? // TODO: can you have ManageMessages and *not* SendMessages? What happens then?
if (!permissions.SendMessages && !permissions.ManageMessages) return false; if ((permissions & (Permissions.SendMessages | Permissions.ManageMessages)) == 0) return false;
if (!permissions.ManageWebhooks) if ((permissions & Permissions.ManageWebhooks) == 0)
{ {
// todo: PKError-ify these // todo: PKError-ify these
await channel.SendMessageAsync( await channel.SendMessageAsync(
@ -238,7 +236,7 @@ namespace PluralKit.Bot
return false; return false;
} }
if (!permissions.ManageMessages) if ((permissions & Permissions.ManageMessages) == 0)
{ {
await channel.SendMessageAsync( await channel.SendMessageAsync(
$"{Emojis.Error} PluralKit does not have the *Manage Messages* permission in this channel, and thus cannot delete the original trigger message. Please contact a server administrator to remedy this."); $"{Emojis.Error} PluralKit does not have the *Manage Messages* permission in this channel, and thus cannot delete the original trigger message. Please contact a server administrator to remedy this.");
@ -248,121 +246,117 @@ namespace PluralKit.Bot
return true; return true;
} }
public Task HandleReactionAddedAsync(Cacheable<IUserMessage, ulong> message, ISocketMessageChannel channel, SocketReaction reaction) public Task HandleReactionAddedAsync(MessageReactionAddEventArgs args)
{ {
// Dispatch on emoji // Dispatch on emoji
switch (reaction.Emote.Name) switch (args.Emoji.Name)
{ {
case "\u274C": // Red X case "\u274C": // Red X
return HandleMessageDeletionByReaction(message, reaction.UserId); return HandleMessageDeletionByReaction(args);
case "\u2753": // Red question mark case "\u2753": // Red question mark
case "\u2754": // White question mark case "\u2754": // White question mark
return HandleMessageQueryByReaction(message, channel, reaction.UserId, reaction.Emote); return HandleMessageQueryByReaction(args);
case "\U0001F514": // Bell case "\U0001F514": // Bell
case "\U0001F6CE": // Bellhop bell case "\U0001F6CE": // Bellhop bell
case "\U0001F3D3": // Ping pong paddle (lol) case "\U0001F3D3": // Ping pong paddle (lol)
case "\u23F0": // Alarm clock case "\u23F0": // Alarm clock
case "\u2757": // Exclamation mark case "\u2757": // Exclamation mark
return HandleMessagePingByReaction(message, channel, reaction.UserId, reaction.Emote); return HandleMessagePingByReaction(args);
default: default:
return Task.CompletedTask; return Task.CompletedTask;
} }
} }
private async Task HandleMessagePingByReaction(Cacheable<IUserMessage, ulong> message, private async Task HandleMessagePingByReaction(MessageReactionAddEventArgs args)
ISocketMessageChannel channel, ulong userWhoReacted,
IEmote reactedEmote)
{ {
// Bail in DMs // Bail in DMs
if (!(channel is SocketGuildChannel gc)) return; if (args.Channel.Type != ChannelType.Text) return;
// Find the message in the DB // Find the message in the DB
var msg = await _data.GetMessage(message.Id); var msg = await _data.GetMessage(args.Message.Id);
if (msg == null) return; if (msg == null) return;
// Check if the pinger has permission to ping in this channel // Check if the pinger has permission to ping in this channel
var guildUser = await _client.Rest.GetGuildUserAsync(gc.Guild.Id, userWhoReacted); var guildUser = await args.Guild.GetMemberAsync(args.User.Id);
var permissions = guildUser.GetPermissions(gc); var permissions = guildUser.PermissionsIn(args.Channel);
var realMessage = await message.GetOrDownloadAsync();
// If they don't have Send Messages permission, bail (since PK shouldn't send anything on their behalf) // If they don't have Send Messages permission, bail (since PK shouldn't send anything on their behalf)
if (!permissions.SendMessages || !permissions.ViewChannel) return; var requiredPerms = Permissions.AccessChannels | Permissions.SendMessages;
if ((permissions & requiredPerms) != requiredPerms) return;
var embed = new EmbedBuilder().WithDescription($"[Jump to pinged message]({realMessage.GetJumpUrl()})");
await channel.SendMessageAsync($"Psst, **{msg.Member.DisplayName ?? msg.Member.Name}** (<@{msg.Message.Sender}>), you have been pinged by <@{userWhoReacted}>.", embed: embed.Build()); var embed = new DiscordEmbedBuilder().WithDescription($"[Jump to pinged message]({args.Message.JumpLink})");
await args.Channel.SendMessageAsync($"Psst, **{msg.Member.DisplayName ?? msg.Member.Name}** (<@{msg.Message.Sender}>), you have been pinged by <@{args.User.Id}>.", embed: embed.Build());
// Finally remove the original reaction (if we can) // Finally remove the original reaction (if we can)
var user = await _client.Rest.GetUserAsync(userWhoReacted); if (args.Channel.BotHasPermission(Permissions.ManageMessages))
if (user != null && realMessage.Channel.HasPermission(ChannelPermission.ManageMessages)) await args.Message.DeleteReactionAsync(args.Emoji, args.User);
await realMessage.RemoveReactionAsync(reactedEmote, user);
} }
private async Task HandleMessageQueryByReaction(Cacheable<IUserMessage, ulong> message, private async Task HandleMessageQueryByReaction(MessageReactionAddEventArgs args)
ISocketMessageChannel channel, ulong userWhoReacted,
IEmote reactedEmote)
{ {
// Find the user who sent the reaction, so we can DM them // Bail if not in guild
var user = await _client.Rest.GetUserAsync(userWhoReacted); if (args.Guild == null) return;
if (user == null) return;
// Find the message in the DB // Find the message in the DB
var msg = await _data.GetMessage(message.Id); var msg = await _data.GetMessage(args.Message.Id);
if (msg == null) return; if (msg == null) return;
// Get guild member so we can DM
var member = await args.Guild.GetMemberAsync(args.User.Id);
// DM them the message card // DM them the message card
try try
{ {
await user.SendMessageAsync(embed: await _embeds.CreateMemberEmbed(msg.System, msg.Member, (channel as IGuildChannel)?.Guild, LookupContext.ByNonOwner)); await member.SendMessageAsync(embed: await _embeds.CreateMemberEmbed(msg.System, msg.Member, args.Guild, LookupContext.ByNonOwner));
await user.SendMessageAsync(embed: await _embeds.CreateMessageInfoEmbed(msg)); await member.SendMessageAsync(embed: await _embeds.CreateMessageInfoEmbed(args.Client, msg));
} }
catch (HttpException e) when (e.DiscordCode == 50007) catch (BadRequestException)
{ {
// TODO: is this the correct exception
// Ignore exception if it means we don't have DM permission to this user // Ignore exception if it means we don't have DM permission to this user
// not much else we can do here :/ // not much else we can do here :/
} }
// And finally remove the original reaction (if we can) // And finally remove the original reaction (if we can)
var msgObj = await message.GetOrDownloadAsync(); await args.Message.DeleteReactionAsync(args.Emoji, args.User);
if (msgObj.Channel.HasPermission(ChannelPermission.ManageMessages))
await msgObj.RemoveReactionAsync(reactedEmote, user);
} }
public async Task HandleMessageDeletionByReaction(Cacheable<IUserMessage, ulong> message, ulong userWhoReacted) public async Task HandleMessageDeletionByReaction(MessageReactionAddEventArgs args)
{ {
// Bail if we don't have permission to delete
if (!args.Channel.BotHasPermission(Permissions.ManageMessages)) return;
// Find the message in the database // Find the message in the database
var storedMessage = await _data.GetMessage(message.Id); var storedMessage = await _data.GetMessage(args.Message.Id);
if (storedMessage == null) return; // (if we can't, that's ok, no worries) if (storedMessage == null) return; // (if we can't, that's ok, no worries)
// Make sure it's the actual sender of that message deleting the message // Make sure it's the actual sender of that message deleting the message
if (storedMessage.Message.Sender != userWhoReacted) return; if (storedMessage.Message.Sender != args.User.Id) return;
try { try
// Then, fetch the Discord message and delete that {
// TODO: this could be faster if we didn't bother fetching it and just deleted it directly await args.Message.DeleteAsync();
// somehow through REST?
await (await message.GetOrDownloadAsync()).DeleteAsync();
} catch (NullReferenceException) { } catch (NullReferenceException) {
// Message was deleted before we got to it... cool, no problem, lmao // Message was deleted before we got to it... cool, no problem, lmao
} }
// Finally, delete it from our database. // Finally, delete it from our database.
await _data.DeleteMessage(message.Id); await _data.DeleteMessage(args.Message.Id);
} }
public async Task HandleMessageDeletedAsync(Cacheable<IMessage, ulong> message, ISocketMessageChannel channel) public async Task HandleMessageDeletedAsync(MessageDeleteEventArgs args)
{ {
// Don't delete messages from the store if they aren't webhooks // Don't delete messages from the store if they aren't webhooks
// Non-webhook messages will never be stored anyway. // Non-webhook messages will never be stored anyway.
// If we're not sure (eg. message outside of cache), delete just to be sure. // If we're not sure (eg. message outside of cache), delete just to be sure.
if (message.HasValue && !message.Value.Author.IsWebhook) return; if (!args.Message.WebhookMessage) return;
await _data.DeleteMessage(message.Id); await _data.DeleteMessage(args.Message.Id);
} }
public async Task HandleMessageBulkDeleteAsync(IReadOnlyCollection<Cacheable<IMessage, ulong>> messages, IMessageChannel channel) public async Task HandleMessageBulkDeleteAsync(MessageBulkDeleteEventArgs args)
{ {
_logger.Information("Bulk deleting {Count} messages in channel {Channel}", messages.Count, channel.Id); _logger.Information("Bulk deleting {Count} messages in channel {Channel}", args.Messages.Count, args.Channel.Id);
await _data.DeleteMessagesBulk(messages.Select(m => m.Id).ToList()); await _data.DeleteMessagesBulk(args.Messages.Select(m => m.Id).ToList());
} }
} }
} }

View File

@ -2,7 +2,7 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading.Tasks; using System.Threading.Tasks;
using Discord.WebSocket; using DSharpPlus;
using NodaTime; using NodaTime;
@ -21,35 +21,36 @@ namespace PluralKit.Bot
public void Init(DiscordShardedClient client) public void Init(DiscordShardedClient client)
{ {
for (var i = 0; i < client.Shards.Count; i++) foreach (var i in client.ShardClients.Keys)
_shardInfo[i] = new ShardInfo(); _shardInfo[i] = new ShardInfo();
client.ShardConnected += ShardConnected; // TODO
client.ShardDisconnected += ShardDisconnected; // client.ShardConnected += ShardConnected;
client.ShardReady += ShardReady; // client.ShardDisconnected += ShardDisconnected;
client.ShardLatencyUpdated += ShardLatencyUpdated; // client.ShardReady += ShardReady;
// client.ShardLatencyUpdated += ShardLatencyUpdated;
} }
public ShardInfo GetShardInfo(DiscordSocketClient shard) => _shardInfo[shard.ShardId]; public ShardInfo GetShardInfo(DiscordClient shard) => _shardInfo[shard.ShardId];
private Task ShardLatencyUpdated(int oldLatency, int newLatency, DiscordSocketClient shard) private Task ShardLatencyUpdated(int oldLatency, int newLatency, DiscordClient shard)
{ {
_shardInfo[shard.ShardId].ShardLatency = newLatency; _shardInfo[shard.ShardId].ShardLatency = newLatency;
return Task.CompletedTask; return Task.CompletedTask;
} }
private Task ShardReady(DiscordSocketClient shard) private Task ShardReady(DiscordClient shard)
{ {
return Task.CompletedTask; return Task.CompletedTask;
} }
private Task ShardDisconnected(Exception e, DiscordSocketClient shard) private Task ShardDisconnected(Exception e, DiscordClient shard)
{ {
_shardInfo[shard.ShardId].DisconnectionCount++; _shardInfo[shard.ShardId].DisconnectionCount++;
return Task.CompletedTask; return Task.CompletedTask;
} }
private Task ShardConnected(DiscordSocketClient shard) private Task ShardConnected(DiscordClient shard)
{ {
_shardInfo[shard.ShardId].LastConnectionTime = SystemClock.Instance.GetCurrentInstant(); _shardInfo[shard.ShardId].LastConnectionTime = SystemClock.Instance.GetCurrentInstant();
return Task.CompletedTask; return Task.CompletedTask;

View File

@ -3,8 +3,9 @@ using System.Collections.Concurrent;
using System.Linq; using System.Linq;
using System.Net.Http; using System.Net.Http;
using System.Threading.Tasks; using System.Threading.Tasks;
using Discord;
using Discord.WebSocket; using DSharpPlus;
using DSharpPlus.Entities;
using Serilog; using Serilog;
@ -15,54 +16,55 @@ namespace PluralKit.Bot
public static readonly string WebhookName = "PluralKit Proxy Webhook"; public static readonly string WebhookName = "PluralKit Proxy Webhook";
private DiscordShardedClient _client; private DiscordShardedClient _client;
private ConcurrentDictionary<ulong, Lazy<Task<IWebhook>>> _webhooks; private ConcurrentDictionary<ulong, Lazy<Task<DiscordWebhook>>> _webhooks;
private ILogger _logger; private ILogger _logger;
public WebhookCacheService(IDiscordClient client, ILogger logger) public WebhookCacheService(DiscordShardedClient client, ILogger logger)
{ {
_client = client as DiscordShardedClient; _client = client;
_logger = logger.ForContext<WebhookCacheService>(); _logger = logger.ForContext<WebhookCacheService>();
_webhooks = new ConcurrentDictionary<ulong, Lazy<Task<IWebhook>>>(); _webhooks = new ConcurrentDictionary<ulong, Lazy<Task<DiscordWebhook>>>();
} }
public async Task<IWebhook> GetWebhook(ulong channelId) public async Task<DiscordWebhook> GetWebhook(DiscordClient client, ulong channelId)
{ {
var channel = _client.GetChannel(channelId) as ITextChannel; var channel = await client.GetChannelAsync(channelId);
if (channel == null) return null; if (channel == null) return null;
if (channel.Type == ChannelType.Text) return null;
return await GetWebhook(channel); return await GetWebhook(channel);
} }
public async Task<IWebhook> GetWebhook(ITextChannel channel) public async Task<DiscordWebhook> GetWebhook(DiscordChannel channel)
{ {
// We cache the webhook through a Lazy<Task<T>>, this way we make sure to only create one webhook per channel // We cache the webhook through a Lazy<Task<T>>, this way we make sure to only create one webhook per channel
// If the webhook is requested twice before it's actually been found, the Lazy<T> wrapper will stop the // If the webhook is requested twice before it's actually been found, the Lazy<T> wrapper will stop the
// webhook from being created twice. // webhook from being created twice.
var lazyWebhookValue = var lazyWebhookValue =
_webhooks.GetOrAdd(channel.Id, new Lazy<Task<IWebhook>>(() => GetOrCreateWebhook(channel))); _webhooks.GetOrAdd(channel.Id, new Lazy<Task<DiscordWebhook>>(() => GetOrCreateWebhook(channel)));
// It's possible to "move" a webhook to a different channel after creation // It's possible to "move" a webhook to a different channel after creation
// Here, we ensure it's actually still pointing towards the proper channel, and if not, wipe and refetch one. // Here, we ensure it's actually still pointing towards the proper channel, and if not, wipe and refetch one.
var webhook = await lazyWebhookValue.Value; var webhook = await lazyWebhookValue.Value;
if (webhook.ChannelId != channel.Id) return await InvalidateAndRefreshWebhook(webhook); if (webhook.ChannelId != channel.Id) return await InvalidateAndRefreshWebhook(channel, webhook);
return webhook; return webhook;
} }
public async Task<IWebhook> InvalidateAndRefreshWebhook(IWebhook webhook) public async Task<DiscordWebhook> InvalidateAndRefreshWebhook(DiscordChannel channel, DiscordWebhook webhook)
{ {
_logger.Information("Refreshing webhook for channel {Channel}", webhook.ChannelId); _logger.Information("Refreshing webhook for channel {Channel}", webhook.ChannelId);
_webhooks.TryRemove(webhook.ChannelId, out _); _webhooks.TryRemove(webhook.ChannelId, out _);
return await GetWebhook(webhook.Channel); return await GetWebhook(channel);
} }
private async Task<IWebhook> GetOrCreateWebhook(ITextChannel channel) private async Task<DiscordWebhook> GetOrCreateWebhook(DiscordChannel channel)
{ {
_logger.Debug("Webhook for channel {Channel} not found in cache, trying to fetch", channel.Id); _logger.Debug("Webhook for channel {Channel} not found in cache, trying to fetch", channel.Id);
return await FindExistingWebhook(channel) ?? await DoCreateWebhook(channel); return await FindExistingWebhook(channel) ?? await DoCreateWebhook(channel);
} }
private async Task<IWebhook> FindExistingWebhook(ITextChannel channel) private async Task<DiscordWebhook> FindExistingWebhook(DiscordChannel channel)
{ {
_logger.Debug("Finding webhook for channel {Channel}", channel.Id); _logger.Debug("Finding webhook for channel {Channel}", channel.Id);
try try
@ -78,13 +80,13 @@ namespace PluralKit.Bot
} }
} }
private Task<IWebhook> DoCreateWebhook(ITextChannel channel) private Task<DiscordWebhook> DoCreateWebhook(DiscordChannel channel)
{ {
_logger.Information("Creating new webhook for channel {Channel}", channel.Id); _logger.Information("Creating new webhook for channel {Channel}", channel.Id);
return channel.CreateWebhookAsync(WebhookName); return channel.CreateWebhookAsync(WebhookName);
} }
private bool IsWebhookMine(IWebhook arg) => arg.Creator.Id == _client.GetShardFor(arg.Guild).CurrentUser.Id && arg.Name == WebhookName; private bool IsWebhookMine(DiscordWebhook arg) => arg.User.Id == _client.CurrentUser.Id && arg.Name == WebhookName;
public int CacheSize => _webhooks.Count; public int CacheSize => _webhooks.Count;
} }

View File

@ -8,7 +8,8 @@ using System.Text.RegularExpressions;
using System.Threading.Tasks; using System.Threading.Tasks;
using App.Metrics; using App.Metrics;
using Discord; using DSharpPlus.Entities;
using DSharpPlus.Exceptions;
using Humanizer; using Humanizer;
@ -44,13 +45,13 @@ namespace PluralKit.Bot
_logger = logger.ForContext<WebhookExecutorService>(); _logger = logger.ForContext<WebhookExecutorService>();
} }
public async Task<ulong> ExecuteWebhook(ITextChannel channel, string name, string avatarUrl, string content, IReadOnlyCollection<IAttachment> attachments) public async Task<ulong> ExecuteWebhook(DiscordChannel channel, string name, string avatarUrl, string content, IReadOnlyList<DiscordAttachment> attachments)
{ {
_logger.Verbose("Invoking webhook in channel {Channel}", channel.Id); _logger.Verbose("Invoking webhook in channel {Channel}", channel.Id);
// Get a webhook, execute it // Get a webhook, execute it
var webhook = await _webhookCache.GetWebhook(channel); var webhook = await _webhookCache.GetWebhook(channel);
var id = await ExecuteWebhookInner(webhook, name, avatarUrl, content, attachments); var id = await ExecuteWebhookInner(channel, webhook, name, avatarUrl, content, attachments);
// Log the relevant metrics // Log the relevant metrics
_metrics.Measure.Meter.Mark(BotMetrics.MessagesProxied); _metrics.Measure.Meter.Mark(BotMetrics.MessagesProxied);
@ -60,112 +61,93 @@ namespace PluralKit.Bot
return id; return id;
} }
private async Task<ulong> ExecuteWebhookInner(IWebhook webhook, string name, string avatarUrl, string content, private async Task<ulong> ExecuteWebhookInner(DiscordChannel channel, DiscordWebhook webhook, string name, string avatarUrl, string content,
IReadOnlyCollection<IAttachment> attachments, bool hasRetried = false) IReadOnlyList<DiscordAttachment> attachments, bool hasRetried = false)
{ {
using var mfd = new MultipartFormDataContent var dwb = new DiscordWebhookBuilder();
{ dwb.WithUsername(FixClyde(name).Truncate(80));
{new StringContent(content.Truncate(2000)), "content"}, dwb.WithContent(content.Truncate(2000));
{new StringContent(FixClyde(name).Truncate(80)), "username"} if (avatarUrl != null) dwb.WithAvatarUrl(avatarUrl);
};
if (avatarUrl != null) mfd.Add(new StringContent(avatarUrl), "avatar_url");
var attachmentChunks = ChunkAttachmentsOrThrow(attachments, 8 * 1024 * 1024); var attachmentChunks = ChunkAttachmentsOrThrow(attachments, 8 * 1024 * 1024);
if (attachmentChunks.Count > 0) if (attachmentChunks.Count > 0)
{ {
_logger.Information("Invoking webhook with {AttachmentCount} attachments totalling {AttachmentSize} MiB in {AttachmentChunks} chunks", attachments.Count, attachments.Select(a => a.Size).Sum() / 1024 / 1024, attachmentChunks.Count); _logger.Information("Invoking webhook with {AttachmentCount} attachments totalling {AttachmentSize} MiB in {AttachmentChunks} chunks", attachments.Count, attachments.Select(a => a.FileSize).Sum() / 1024 / 1024, attachmentChunks.Count);
await AddAttachmentsToMultipart(mfd, attachmentChunks.First()); await AddAttachmentsToBuilder(dwb, attachmentChunks[0]);
} }
mfd.Headers.Add("X-RateLimit-Precision", "millisecond"); // Need this for better rate limit support
// Adding this check as close to the actual send call as possible to prevent potential race conditions (unlikely, but y'know)
if (!_rateLimit.TryExecuteWebhook(webhook))
throw new WebhookRateLimited();
var timerCtx = _metrics.Measure.Timer.Time(BotMetrics.WebhookResponseTime); var timerCtx = _metrics.Measure.Timer.Time(BotMetrics.WebhookResponseTime);
using var response = await _client.PostAsync($"{DiscordConfig.APIUrl}webhooks/{webhook.Id}/{webhook.Token}?wait=true", mfd);
timerCtx.Dispose();
_rateLimit.UpdateRateLimitInfo(webhook, response);
if (response.StatusCode == HttpStatusCode.TooManyRequests) DiscordMessage response;
// Rate limits should be respected, we bail early (already updated the limit info so we hopefully won't hit this again)
throw new WebhookRateLimited();
var responseString = await response.Content.ReadAsStringAsync();
JObject responseJson;
try try
{ {
responseJson = JsonConvert.DeserializeObject<JObject>(responseString); response = await webhook.ExecuteAsync(dwb);
} }
catch (JsonReaderException) catch (NotFoundException e)
{ {
// Sometimes we get invalid JSON from the server, just ignore all of it if (e.JsonMessage.Contains("10015") && !hasRetried)
throw new WebhookExecutionErrorOnDiscordsEnd();
}
if (responseJson.ContainsKey("code"))
{
var errorCode = responseJson["code"].Value<int>();
if (errorCode == 10015 && !hasRetried)
{ {
// Error 10015 = "Unknown Webhook" - this likely means the webhook was deleted // Error 10015 = "Unknown Webhook" - this likely means the webhook was deleted
// but is still in our cache. Invalidate, refresh, try again // but is still in our cache. Invalidate, refresh, try again
_logger.Warning("Error invoking webhook {Webhook} in channel {Channel}", webhook.Id, webhook.ChannelId); _logger.Warning("Error invoking webhook {Webhook} in channel {Channel}", webhook.Id, webhook.ChannelId);
return await ExecuteWebhookInner(await _webhookCache.InvalidateAndRefreshWebhook(webhook), name, avatarUrl, content, attachments, hasRetried: true);
}
if (errorCode == 40005)
throw Errors.AttachmentTooLarge; // should be caught by the check above but just makin' sure
// TODO: look into what this actually throws, and if this is the correct handling
if ((int) response.StatusCode >= 500)
// If it's a 5xx error code, this is on Discord's end, so we throw an execution exception
throw new WebhookExecutionErrorOnDiscordsEnd();
// Otherwise, this is going to throw on 4xx, and bubble up to our Sentry handler
response.EnsureSuccessStatusCode();
}
// If we have any leftover attachment chunks, send those
if (attachmentChunks.Count > 1)
{
// Deliberately not adding a content, just the remaining files
foreach (var chunk in attachmentChunks.Skip(1))
{
using var mfd2 = new MultipartFormDataContent();
mfd2.Add(new StringContent(FixClyde(name).Truncate(80)), "username");
if (avatarUrl != null) mfd2.Add(new StringContent(avatarUrl), "avatar_url");
await AddAttachmentsToMultipart(mfd2, chunk);
// Don't bother with ?wait, we're just kinda firehosing this stuff var newWebhook = await _webhookCache.InvalidateAndRefreshWebhook(channel, webhook);
// also don't error check, the real message itself is already sent return await ExecuteWebhookInner(channel, newWebhook, name, avatarUrl, content, attachments, hasRetried: true);
await _client.PostAsync($"{DiscordConfig.APIUrl}webhooks/{webhook.Id}/{webhook.Token}", mfd2);
} }
throw;
}
timerCtx.Dispose();
// We don't care about whether the sending succeeds, and we don't want to *wait* for it, so we just fork it off
var _ = TrySendRemainingAttachments(webhook, name, avatarUrl, attachmentChunks);
return response.Id;
}
private async Task TrySendRemainingAttachments(DiscordWebhook webhook, string name, string avatarUrl, IReadOnlyList<IReadOnlyCollection<DiscordAttachment>> attachmentChunks)
{
if (attachmentChunks.Count <= 1) return;
for (var i = 1; i < attachmentChunks.Count; i++)
{
var dwb = new DiscordWebhookBuilder();
if (avatarUrl != null) dwb.WithAvatarUrl(avatarUrl);
dwb.WithUsername(name);
await AddAttachmentsToBuilder(dwb, attachmentChunks[i]);
await webhook.ExecuteAsync(dwb);
}
}
private async Task AddAttachmentsToBuilder(DiscordWebhookBuilder dwb, IReadOnlyCollection<DiscordAttachment> attachments)
{
async Task<(DiscordAttachment, Stream)> GetStream(DiscordAttachment attachment)
{
var attachmentResponse = await _client.GetAsync(attachment.Url, HttpCompletionOption.ResponseHeadersRead);
return (attachment, await attachmentResponse.Content.ReadAsStreamAsync());
} }
// At this point we're sure we have a 2xx status code, so just assume success foreach (var (attachment, attachmentStream) in await Task.WhenAll(attachments.Select(GetStream)))
// TODO: can we do this without a round-trip to a string? dwb.AddFile(attachment.FileName, attachmentStream);
return responseJson["id"].Value<ulong>();
} }
private IReadOnlyCollection<IReadOnlyCollection<IAttachment>> ChunkAttachmentsOrThrow(
IReadOnlyCollection<IAttachment> attachments, int sizeThreshold) private IReadOnlyList<IReadOnlyCollection<DiscordAttachment>> ChunkAttachmentsOrThrow(
IReadOnlyList<DiscordAttachment> attachments, int sizeThreshold)
{ {
// Splits a list of attachments into "chunks" of at most 8MB each // Splits a list of attachments into "chunks" of at most 8MB each
// If any individual attachment is larger than 8MB, will throw an error // If any individual attachment is larger than 8MB, will throw an error
var chunks = new List<List<IAttachment>>(); var chunks = new List<List<DiscordAttachment>>();
var list = new List<IAttachment>(); var list = new List<DiscordAttachment>();
foreach (var attachment in attachments) foreach (var attachment in attachments)
{ {
if (attachment.Size >= sizeThreshold) throw Errors.AttachmentTooLarge; if (attachment.FileSize >= sizeThreshold) throw Errors.AttachmentTooLarge;
if (list.Sum(a => a.Size) + attachment.Size >= sizeThreshold) if (list.Sum(a => a.FileSize) + attachment.FileSize >= sizeThreshold)
{ {
chunks.Add(list); chunks.Add(list);
list = new List<IAttachment>(); list = new List<DiscordAttachment>();
} }
list.Add(attachment); list.Add(attachment);
@ -175,20 +157,6 @@ namespace PluralKit.Bot
return chunks; return chunks;
} }
private async Task AddAttachmentsToMultipart(MultipartFormDataContent content,
IReadOnlyCollection<IAttachment> attachments)
{
async Task<(IAttachment, Stream)> GetStream(IAttachment attachment)
{
var attachmentResponse = await _client.GetAsync(attachment.Url, HttpCompletionOption.ResponseHeadersRead);
return (attachment, await attachmentResponse.Content.ReadAsStreamAsync());
}
var attachmentId = 0;
foreach (var (attachment, attachmentStream) in await Task.WhenAll(attachments.Select(GetStream)))
content.Add(new StreamContent(attachmentStream), $"file{attachmentId++}", attachment.Filename);
}
private string FixClyde(string name) private string FixClyde(string name)
{ {
// Check if the name contains "Clyde" - if not, do nothing // Check if the name contains "Clyde" - if not, do nothing

View File

@ -5,7 +5,7 @@ using System.Linq;
using System.Net; using System.Net;
using System.Net.Http; using System.Net.Http;
using Discord; using DSharpPlus.Entities;
using NodaTime; using NodaTime;
@ -26,7 +26,7 @@ namespace PluralKit.Bot
public int CacheSize => _info.Count; public int CacheSize => _info.Count;
public bool TryExecuteWebhook(IWebhook webhook) public bool TryExecuteWebhook(DiscordWebhook webhook)
{ {
// If we have nothing saved, just allow it (we'll save something once the response returns) // If we have nothing saved, just allow it (we'll save something once the response returns)
if (!_info.TryGetValue(webhook.Id, out var info)) return true; if (!_info.TryGetValue(webhook.Id, out var info)) return true;
@ -57,7 +57,7 @@ namespace PluralKit.Bot
return true; return true;
} }
public void UpdateRateLimitInfo(IWebhook webhook, HttpResponseMessage response) public void UpdateRateLimitInfo(DiscordWebhook webhook, HttpResponseMessage response)
{ {
var info = _info.GetOrAdd(webhook.Id, _ => new WebhookRateLimitInfo()); var info = _info.GetOrAdd(webhook.Id, _ => new WebhookRateLimitInfo());

View File

@ -1,58 +1,63 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Net;
using System.Threading.Tasks; using System.Threading.Tasks;
using Discord;
using Discord.Net; using DSharpPlus;
using Discord.WebSocket; using DSharpPlus.Entities;
using DSharpPlus.EventArgs;
using DSharpPlus.Exceptions;
using PluralKit.Core; using PluralKit.Core;
namespace PluralKit.Bot { namespace PluralKit.Bot {
public static class ContextUtils { public static class ContextUtils {
public static async Task<bool> PromptYesNo(this Context ctx, IUserMessage message, IUser user = null, TimeSpan? timeout = null) { public static async Task<bool> PromptYesNo(this Context ctx, DiscordMessage message, DiscordUser user = null, TimeSpan? timeout = null) {
// "Fork" the task adding the reactions off so we don't have to wait for them to be finished to start listening for presses // "Fork" the task adding the reactions off so we don't have to wait for them to be finished to start listening for presses
#pragma warning disable 4014 var _ = message.CreateReactionsBulk(new[] {Emojis.Success, Emojis.Error});
message.AddReactionsAsync(new IEmote[] {new Emoji(Emojis.Success), new Emoji(Emojis.Error)}); var reaction = await ctx.AwaitReaction(message, user ?? ctx.Author, r => r.Emoji.Name == Emojis.Success || r.Emoji.Name == Emojis.Error, timeout ?? TimeSpan.FromMinutes(1));
#pragma warning restore 4014 return reaction.Emoji.Name == Emojis.Success;
var reaction = await ctx.AwaitReaction(message, user ?? ctx.Author, (r) => r.Emote.Name == Emojis.Success || r.Emote.Name == Emojis.Error, timeout ?? TimeSpan.FromMinutes(1));
return reaction.Emote.Name == Emojis.Success;
} }
public static async Task<SocketReaction> AwaitReaction(this Context ctx, IUserMessage message, IUser user = null, Func<SocketReaction, bool> predicate = null, TimeSpan? timeout = null) { public static async Task<MessageReactionAddEventArgs> AwaitReaction(this Context ctx, DiscordMessage message, DiscordUser user = null, Func<MessageReactionAddEventArgs, bool> predicate = null, TimeSpan? timeout = null) {
var tcs = new TaskCompletionSource<SocketReaction>(); var tcs = new TaskCompletionSource<MessageReactionAddEventArgs>();
Task Inner(Cacheable<IUserMessage, ulong> _message, ISocketMessageChannel _channel, SocketReaction reaction) { Task Inner(MessageReactionAddEventArgs args) {
if (message.Id != _message.Id) return Task.CompletedTask; // Ignore reactions for different messages if (message.Id != args.Message.Id) return Task.CompletedTask; // Ignore reactions for different messages
if (user != null && user.Id != reaction.UserId) return Task.CompletedTask; // Ignore messages from other users if a user was defined if (user != null && user.Id != args.User.Id) return Task.CompletedTask; // Ignore messages from other users if a user was defined
if (predicate != null && !predicate.Invoke(reaction)) return Task.CompletedTask; // Check predicate if (predicate != null && !predicate.Invoke(args)) return Task.CompletedTask; // Check predicate
tcs.SetResult(reaction); tcs.SetResult(args);
return Task.CompletedTask; return Task.CompletedTask;
} }
((BaseSocketClient) ctx.Shard).ReactionAdded += Inner; ctx.Shard.MessageReactionAdded += Inner;
try { try {
return await (tcs.Task.TimeoutAfter(timeout)); return await tcs.Task.TimeoutAfter(timeout);
} finally { } finally {
((BaseSocketClient) ctx.Shard).ReactionAdded -= Inner; ctx.Shard.MessageReactionAdded -= Inner;
} }
} }
public static async Task<IUserMessage> AwaitMessage(this Context ctx, IMessageChannel channel, IUser user = null, Func<SocketMessage, bool> predicate = null, TimeSpan? timeout = null) { public static async Task<DiscordMessage> AwaitMessage(this Context ctx, DiscordChannel channel, DiscordUser user = null, Func<DiscordMessage, bool> predicate = null, TimeSpan? timeout = null) {
var tcs = new TaskCompletionSource<IUserMessage>(); var tcs = new TaskCompletionSource<DiscordMessage>();
Task Inner(SocketMessage msg) { Task Inner(MessageCreateEventArgs args)
{
var msg = args.Message;
if (channel != msg.Channel) return Task.CompletedTask; // Ignore messages in a different channel if (channel != msg.Channel) return Task.CompletedTask; // Ignore messages in a different channel
if (user != null && user != msg.Author) return Task.CompletedTask; // Ignore messages from other users if (user != null && user != msg.Author) return Task.CompletedTask; // Ignore messages from other users
if (predicate != null && !predicate.Invoke(msg)) return Task.CompletedTask; // Check predicate if (predicate != null && !predicate.Invoke(msg)) return Task.CompletedTask; // Check predicate
tcs.SetResult(msg);
((BaseSocketClient) ctx.Shard).MessageReceived -= Inner;
tcs.SetResult(msg as IUserMessage);
return Task.CompletedTask; return Task.CompletedTask;
} }
((BaseSocketClient) ctx.Shard).MessageReceived += Inner; ctx.Shard.MessageCreated += Inner;
return await (tcs.Task.TimeoutAfter(timeout)); try
{
return await (tcs.Task.TimeoutAfter(timeout));
}
finally
{
ctx.Shard.MessageCreated -= Inner;
}
} }
public static async Task<bool> ConfirmWithReply(this Context ctx, string expectedReply) public static async Task<bool> ConfirmWithReply(this Context ctx, string expectedReply)
@ -61,20 +66,20 @@ namespace PluralKit.Bot {
return string.Equals(msg.Content, expectedReply, StringComparison.InvariantCultureIgnoreCase); return string.Equals(msg.Content, expectedReply, StringComparison.InvariantCultureIgnoreCase);
} }
public static async Task Paginate<T>(this Context ctx, IAsyncEnumerable<T> items, int totalCount, int itemsPerPage, string title, Func<EmbedBuilder, IEnumerable<T>, Task> renderer) { public static async Task Paginate<T>(this Context ctx, IAsyncEnumerable<T> items, int totalCount, int itemsPerPage, string title, Func<DiscordEmbedBuilder, IEnumerable<T>, Task> renderer) {
// TODO: make this generic enough we can use it in Choose<T> below // TODO: make this generic enough we can use it in Choose<T> below
var buffer = new List<T>(); var buffer = new List<T>();
await using var enumerator = items.GetAsyncEnumerator(); await using var enumerator = items.GetAsyncEnumerator();
var pageCount = (totalCount / itemsPerPage) + 1; var pageCount = (totalCount / itemsPerPage) + 1;
async Task<Embed> MakeEmbedForPage(int page) async Task<DiscordEmbed> MakeEmbedForPage(int page)
{ {
var bufferedItemsNeeded = (page + 1) * itemsPerPage; var bufferedItemsNeeded = (page + 1) * itemsPerPage;
while (buffer.Count < bufferedItemsNeeded && await enumerator.MoveNextAsync()) while (buffer.Count < bufferedItemsNeeded && await enumerator.MoveNextAsync())
buffer.Add(enumerator.Current); buffer.Add(enumerator.Current);
var eb = new EmbedBuilder(); var eb = new DiscordEmbedBuilder();
eb.Title = pageCount > 1 ? $"[{page+1}/{pageCount}] {title}" : title; eb.Title = pageCount > 1 ? $"[{page+1}/{pageCount}] {title}" : title;
await renderer(eb, buffer.Skip(page*itemsPerPage).Take(itemsPerPage)); await renderer(eb, buffer.Skip(page*itemsPerPage).Take(itemsPerPage));
return eb.Build(); return eb.Build();
@ -84,8 +89,9 @@ namespace PluralKit.Bot {
{ {
var msg = await ctx.Reply(embed: await MakeEmbedForPage(0)); var msg = await ctx.Reply(embed: await MakeEmbedForPage(0));
if (pageCount == 1) return; // If we only have one page, don't bother with the reaction/pagination logic, lol if (pageCount == 1) return; // If we only have one page, don't bother with the reaction/pagination logic, lol
IEmote[] botEmojis = { new Emoji("\u23EA"), new Emoji("\u2B05"), new Emoji("\u27A1"), new Emoji("\u23E9"), new Emoji(Emojis.Error) }; string[] botEmojis = { "\u23EA", "\u2B05", "\u27A1", "\u23E9", Emojis.Error };
await msg.AddReactionsAsync(botEmojis);
var _ = msg.CreateReactionsBulk(botEmojis); // Again, "fork"
try { try {
var currentPage = 0; var currentPage = 0;
@ -93,31 +99,30 @@ namespace PluralKit.Bot {
var reaction = await ctx.AwaitReaction(msg, ctx.Author, timeout: TimeSpan.FromMinutes(5)); var reaction = await ctx.AwaitReaction(msg, ctx.Author, timeout: TimeSpan.FromMinutes(5));
// Increment/decrement page counter based on which reaction was clicked // Increment/decrement page counter based on which reaction was clicked
if (reaction.Emote.Name == "\u23EA") currentPage = 0; // << if (reaction.Emoji.Name == "\u23EA") currentPage = 0; // <<
if (reaction.Emote.Name == "\u2B05") currentPage = (currentPage - 1) % pageCount; // < if (reaction.Emoji.Name == "\u2B05") currentPage = (currentPage - 1) % pageCount; // <
if (reaction.Emote.Name == "\u27A1") currentPage = (currentPage + 1) % pageCount; // > if (reaction.Emoji.Name == "\u27A1") currentPage = (currentPage + 1) % pageCount; // >
if (reaction.Emote.Name == "\u23E9") currentPage = pageCount - 1; // >> if (reaction.Emoji.Name == "\u23E9") currentPage = pageCount - 1; // >>
if (reaction.Emote.Name == Emojis.Error) break; // X if (reaction.Emoji.Name == Emojis.Error) break; // X
// C#'s % operator is dumb and wrong, so we fix negative numbers // C#'s % operator is dumb and wrong, so we fix negative numbers
if (currentPage < 0) currentPage += pageCount; if (currentPage < 0) currentPage += pageCount;
// If we can, remove the user's reaction (so they can press again quickly) // If we can, remove the user's reaction (so they can press again quickly)
if (ctx.BotHasPermission(ChannelPermission.ManageMessages) && reaction.User.IsSpecified) await msg.RemoveReactionAsync(reaction.Emote, reaction.User.Value); if (ctx.BotHasPermission(Permissions.ManageMessages)) await msg.DeleteReactionAsync(reaction.Emoji, reaction.User);
// Edit the embed with the new page // Edit the embed with the new page
var embed = await MakeEmbedForPage(currentPage); var embed = await MakeEmbedForPage(currentPage);
await msg.ModifyAsync((mp) => mp.Embed = embed); await msg.ModifyAsync(embed: embed);
} }
} catch (TimeoutException) { } catch (TimeoutException) {
// "escape hatch", clean up as if we hit X // "escape hatch", clean up as if we hit X
} }
if (ctx.BotHasPermission(ChannelPermission.ManageMessages)) await msg.RemoveAllReactionsAsync(); if (ctx.BotHasPermission(Permissions.ManageMessages)) await msg.DeleteAllReactionsAsync();
else await msg.RemoveReactionsAsync(ctx.Shard.CurrentUser, botEmojis);
} }
// If we get a "NotFound" error, the message has been deleted and thus not our problem // If we get a "NotFound" error, the message has been deleted and thus not our problem
catch (HttpException e) when (e.HttpCode == HttpStatusCode.NotFound) { } catch (NotFoundException) { }
} }
public static async Task<T> Choose<T>(this Context ctx, string description, IList<T> items, Func<T, string> display = null) public static async Task<T> Choose<T>(this Context ctx, string description, IList<T> items, Func<T, string> display = null)
@ -152,36 +157,35 @@ namespace PluralKit.Bot {
// Add back/forward reactions and the actual indicator emojis // Add back/forward reactions and the actual indicator emojis
async Task AddEmojis() async Task AddEmojis()
{ {
await msg.AddReactionAsync(new Emoji("\u2B05")); await msg.CreateReactionAsync(DiscordEmoji.FromUnicode("\u2B05"));
await msg.AddReactionAsync(new Emoji("\u27A1")); await msg.CreateReactionAsync(DiscordEmoji.FromUnicode("\u27A1"));
for (int i = 0; i < items.Count; i++) await msg.AddReactionAsync(new Emoji(indicators[i])); for (int i = 0; i < items.Count; i++) await msg.CreateReactionAsync(DiscordEmoji.FromUnicode(indicators[i]));
} }
var _ = AddEmojis(); // Not concerned about awaiting var _ = AddEmojis(); // Not concerned about awaiting
while (true) while (true)
{ {
// Wait for a reaction // Wait for a reaction
var reaction = await ctx.AwaitReaction(msg, ctx.Author); var reaction = await ctx.AwaitReaction(msg, ctx.Author);
// If it's a movement reaction, inc/dec the page index // If it's a movement reaction, inc/dec the page index
if (reaction.Emote.Name == "\u2B05") currPage -= 1; // < if (reaction.Emoji.Name == "\u2B05") currPage -= 1; // <
if (reaction.Emote.Name == "\u27A1") currPage += 1; // > if (reaction.Emoji.Name == "\u27A1") currPage += 1; // >
if (currPage < 0) currPage += pageCount; if (currPage < 0) currPage += pageCount;
if (currPage >= pageCount) currPage -= pageCount; if (currPage >= pageCount) currPage -= pageCount;
// If it's an indicator emoji, return the relevant item // If it's an indicator emoji, return the relevant item
if (indicators.Contains(reaction.Emote.Name)) if (indicators.Contains(reaction.Emoji.Name))
{ {
var idx = Array.IndexOf(indicators, reaction.Emote.Name) + pageSize * currPage; var idx = Array.IndexOf(indicators, reaction.Emoji.Name) + pageSize * currPage;
// only if it's in bounds, though // only if it's in bounds, though
// eg. 8 items, we're on page 2, and I hit D (3 + 1*7 = index 10 on an 8-long list) = boom // eg. 8 items, we're on page 2, and I hit D (3 + 1*7 = index 10 on an 8-long list) = boom
if (idx < items.Count) return items[idx]; if (idx < items.Count) return items[idx];
} }
var __ = msg.RemoveReactionAsync(reaction.Emote, ctx.Author); // don't care about awaiting var __ = msg.DeleteReactionAsync(reaction.Emoji, ctx.Author); // don't care about awaiting
await msg.ModifyAsync(mp => mp.Content = $"**[Page {currPage + 1}/{pageCount}]**\n{description}\n{MakeOptionList(currPage)}"); await msg.ModifyAsync($"**[Page {currPage + 1}/{pageCount}]**\n{description}\n{MakeOptionList(currPage)}");
} }
} }
else else
@ -191,26 +195,21 @@ namespace PluralKit.Bot {
// Add the relevant reactions (we don't care too much about awaiting) // Add the relevant reactions (we don't care too much about awaiting)
async Task AddEmojis() async Task AddEmojis()
{ {
for (int i = 0; i < items.Count; i++) await msg.AddReactionAsync(new Emoji(indicators[i])); for (int i = 0; i < items.Count; i++) await msg.CreateReactionAsync(DiscordEmoji.FromUnicode(indicators[i]));
} }
var _ = AddEmojis(); var _ = AddEmojis();
// Then wait for a reaction and return whichever one we found // Then wait for a reaction and return whichever one we found
var reaction = await ctx.AwaitReaction(msg, ctx.Author,rx => indicators.Contains(rx.Emote.Name)); var reaction = await ctx.AwaitReaction(msg, ctx.Author,rx => indicators.Contains(rx.Emoji.Name));
return items[Array.IndexOf(indicators, reaction.Emote.Name)]; return items[Array.IndexOf(indicators, reaction.Emoji.Name)];
} }
} }
public static ChannelPermissions BotPermissions(this Context ctx) { public static Permissions BotPermissions(this Context ctx) => ctx.Channel.BotPermissions();
if (ctx.Channel is SocketGuildChannel gc) {
var gu = gc.Guild.CurrentUser;
return gu.GetPermissions(gc);
}
return ChannelPermissions.DM;
}
public static bool BotHasPermission(this Context ctx, ChannelPermission permission) => BotPermissions(ctx).Has(permission); public static bool BotHasPermission(this Context ctx, Permissions permission) =>
ctx.Channel.BotHasPermission(permission);
public static async Task BusyIndicator(this Context ctx, Func<Task> f, string emoji = "\u23f3" /* hourglass */) public static async Task BusyIndicator(this Context ctx, Func<Task> f, string emoji = "\u23f3" /* hourglass */)
{ {
@ -226,17 +225,17 @@ namespace PluralKit.Bot {
var task = f(); var task = f();
// If we don't have permission to add reactions, don't bother, and just await the task normally. // If we don't have permission to add reactions, don't bother, and just await the task normally.
if (!ctx.BotHasPermission(ChannelPermission.AddReactions)) return await task; var neededPermissions = Permissions.AddReactions | Permissions.ReadMessageHistory;
if (!ctx.BotHasPermission(ChannelPermission.ReadMessageHistory)) return await task; if ((ctx.BotPermissions() & neededPermissions) != neededPermissions) return await task;
try try
{ {
await Task.WhenAll(ctx.Message.AddReactionAsync(new Emoji(emoji)), task); await Task.WhenAll(ctx.Message.CreateReactionAsync(DiscordEmoji.FromUnicode(emoji)), task);
return await task; return await task;
} }
finally finally
{ {
var _ = ctx.Message.RemoveReactionAsync(new Emoji(emoji), ctx.Shard.CurrentUser); var _ = ctx.Message.DeleteReactionAsync(DiscordEmoji.FromUnicode(emoji), ctx.Shard.CurrentUser);
} }
} }
} }

View File

@ -1,31 +1,76 @@
using Discord; using System;
using Discord.WebSocket; using System.Threading.Tasks;
using DSharpPlus;
using DSharpPlus.Entities;
using NodaTime;
namespace PluralKit.Bot namespace PluralKit.Bot
{ {
public static class DiscordUtils public static class DiscordUtils
{ {
public static string NameAndMention(this IUser user) { public static string NameAndMention(this DiscordUser user) {
return $"{user.Username}#{user.Discriminator} ({user.Mention})"; return $"{user.Username}#{user.Discriminator} ({user.Mention})";
} }
public static ChannelPermissions PermissionsIn(this IChannel channel) public static async Task<Permissions> PermissionsIn(this DiscordChannel channel, DiscordUser user)
{ {
switch (channel) if (channel.Guild != null)
{ {
case IDMChannel _: var member = await channel.Guild.GetMemberAsync(user.Id);
return ChannelPermissions.DM; return member.PermissionsIn(channel);
case IGroupChannel _:
return ChannelPermissions.Group;
case SocketGuildChannel gc:
var currentUser = gc.Guild.CurrentUser;
return currentUser.GetPermissions(gc);
default:
return ChannelPermissions.None;
} }
if (channel.Type == ChannelType.Private)
return (Permissions) 0b00000_1000110_1011100110000_000000;
return Permissions.None;
} }
public static bool HasPermission(this IChannel channel, ChannelPermission permission) => public static Permissions PermissionsInSync(this DiscordChannel channel, DiscordUser user)
PermissionsIn(channel).Has(permission); {
if (user is DiscordMember dm && channel.Guild != null)
return dm.PermissionsIn(channel);
if (channel.Type == ChannelType.Private)
return (Permissions) 0b00000_1000110_1011100110000_000000;
return Permissions.None;
}
public static Permissions BotPermissions(this DiscordChannel channel)
{
if (channel.Guild != null)
{
var member = channel.Guild.CurrentMember;
return channel.PermissionsFor(member);
}
if (channel.Type == ChannelType.Private)
return (Permissions) 0b00000_1000110_1011100110000_000000;
return Permissions.None;
}
public static bool BotHasPermission(this DiscordChannel channel, Permissions permissionSet) =>
(BotPermissions(channel) & permissionSet) == permissionSet;
public static Instant SnowflakeToInstant(ulong snowflake) =>
Instant.FromUtc(2015, 1, 1, 0, 0, 0) + Duration.FromMilliseconds(snowflake << 22);
public static ulong InstantToSnowflake(Instant time) =>
(ulong) (time - Instant.FromUtc(2015, 1, 1, 0, 0, 0)).TotalMilliseconds >> 22;
public static ulong InstantToSnowflake(DateTimeOffset time) =>
(ulong) (time - new DateTimeOffset(2015, 1, 1, 0, 0, 0, TimeSpan.Zero)).TotalMilliseconds >> 22;
public static async Task CreateReactionsBulk(this DiscordMessage msg, string[] reactions)
{
foreach (var reaction in reactions)
{
await msg.CreateReactionAsync(DiscordEmoji.FromUnicode(reaction));
}
}
} }
} }

View File

@ -3,8 +3,6 @@ using System.Linq;
using System.Net.Sockets; using System.Net.Sockets;
using System.Threading.Tasks; using System.Threading.Tasks;
using Discord.Net;
using Npgsql; using Npgsql;
using PluralKit.Core; using PluralKit.Core;
@ -20,7 +18,8 @@ namespace PluralKit.Bot
// otherwise we'd blow out our error reporting budget as soon as Discord takes a dump, or something. // otherwise we'd blow out our error reporting budget as soon as Discord takes a dump, or something.
// Discord server errors are *not our problem* // Discord server errors are *not our problem*
if (e is HttpException he && ((int) he.HttpCode) >= 500) return false; // TODO
// if (e is DSharpPlus.Exceptions he && ((int) he.HttpCode) >= 500) return false;
// Webhook server errors are also *not our problem* // Webhook server errors are also *not our problem*
// (this includes rate limit errors, WebhookRateLimited is a subclass) // (this includes rate limit errors, WebhookRateLimited is a subclass)

View File

@ -2,16 +2,16 @@
using System.Globalization; using System.Globalization;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using Discord; using DSharpPlus.Entities;
namespace PluralKit.Bot namespace PluralKit.Bot
{ {
public static class StringUtils public static class StringUtils
{ {
public static Color? ToDiscordColor(this string color) public static DiscordColor? ToDiscordColor(this string color)
{ {
if (uint.TryParse(color, NumberStyles.HexNumber, null, out var colorInt)) if (int.TryParse(color, NumberStyles.HexNumber, null, out var colorInt))
return new Color(colorInt); return new DiscordColor(colorInt);
throw new ArgumentException($"Invalid color string '{color}'."); throw new ArgumentException($"Invalid color string '{color}'.");
} }
@ -23,7 +23,7 @@ namespace PluralKit.Bot
if (string.IsNullOrEmpty(content) || content.Length <= 3 || (content[0] != '<' || content[1] != '@')) if (string.IsNullOrEmpty(content) || content.Length <= 3 || (content[0] != '<' || content[1] != '@'))
return false; return false;
int num = content.IndexOf('>'); int num = content.IndexOf('>');
if (num == -1 || content.Length < num + 2 || content[num + 1] != ' ' || !MentionUtils.TryParseUser(content.Substring(0, num + 1), out mentionId)) if (num == -1 || content.Length < num + 2 || content[num + 1] != ' ' || !TryParseMention(content.Substring(0, num + 1), out mentionId))
return false; return false;
argPos = num + 2; argPos = num + 2;
return true; return true;
@ -32,7 +32,18 @@ namespace PluralKit.Bot
public static bool TryParseMention(this string potentialMention, out ulong id) public static bool TryParseMention(this string potentialMention, out ulong id)
{ {
if (ulong.TryParse(potentialMention, out id)) return true; if (ulong.TryParse(potentialMention, out id)) return true;
if (MentionUtils.TryParseUser(potentialMention, out id)) return true;
// Roughly ported from Discord.MentionUtils.TryParseUser
if (potentialMention.Length >= 3 && potentialMention[0] == '<' && potentialMention[1] == '@' && potentialMention[potentialMention.Length - 1] == '>')
{
if (potentialMention.Length >= 4 && potentialMention[2] == '!')
potentialMention = potentialMention.Substring(3, potentialMention.Length - 4); //<@!123>
else
potentialMention = potentialMention.Substring(2, potentialMention.Length - 3); //<@123>
if (ulong.TryParse(potentialMention, NumberStyles.None, CultureInfo.InvariantCulture, out id))
return true;
}
return false; return false;
} }