Initial commit, basic proxying working

This commit is contained in:
Ske 2020-12-22 13:15:26 +01:00
parent c3f6becea4
commit a6fbd869be
109 changed files with 3539 additions and 359 deletions

View File

@ -0,0 +1,50 @@
using System.Threading.Tasks;
using Myriad.Gateway;
namespace Myriad.Cache
{
public static class DiscordCacheExtensions
{
public static ValueTask HandleGatewayEvent(this IDiscordCache cache, IGatewayEvent evt)
{
switch (evt)
{
case GuildCreateEvent gc:
return cache.SaveGuildCreate(gc);
case GuildUpdateEvent gu:
return cache.SaveGuild(gu);
case GuildDeleteEvent gd:
return cache.RemoveGuild(gd.Id);
case ChannelCreateEvent cc:
return cache.SaveChannel(cc);
case ChannelUpdateEvent cu:
return cache.SaveChannel(cu);
case ChannelDeleteEvent cd:
return cache.RemoveChannel(cd.Id);
case GuildRoleCreateEvent grc:
return cache.SaveRole(grc.GuildId, grc.Role);
case GuildRoleUpdateEvent gru:
return cache.SaveRole(gru.GuildId, gru.Role);
case GuildRoleDeleteEvent grd:
return cache.RemoveRole(grd.GuildId, grd.RoleId);
case MessageCreateEvent mc:
return cache.SaveUser(mc.Author);
}
return default;
}
private static async ValueTask SaveGuildCreate(this IDiscordCache cache, GuildCreateEvent guildCreate)
{
await cache.SaveGuild(guildCreate);
foreach (var channel in guildCreate.Channels)
// The channel object does not include GuildId for some reason...
await cache.SaveChannel(channel with { GuildId = guildCreate.Id });
foreach (var member in guildCreate.Members)
await cache.SaveUser(member.User);
}
}
}

View File

@ -0,0 +1,28 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using Myriad.Types;
namespace Myriad.Cache
{
public interface IDiscordCache
{
public ValueTask SaveGuild(Guild guild);
public ValueTask SaveChannel(Channel channel);
public ValueTask SaveUser(User user);
public ValueTask SaveRole(ulong guildId, Role role);
public ValueTask RemoveGuild(ulong guildId);
public ValueTask RemoveChannel(ulong channelId);
public ValueTask RemoveUser(ulong userId);
public ValueTask RemoveRole(ulong guildId, ulong roleId);
public ValueTask<Guild?> GetGuild(ulong guildId);
public ValueTask<Channel?> GetChannel(ulong channelId);
public ValueTask<User?> GetUser(ulong userId);
public ValueTask<Role?> GetRole(ulong roleId);
public IAsyncEnumerable<Guild> GetAllGuilds();
public ValueTask<IEnumerable<Channel>> GetGuildChannels(ulong guildId);
}
}

View File

@ -0,0 +1,143 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Myriad.Types;
namespace Myriad.Cache
{
public class MemoryDiscordCache: IDiscordCache
{
private readonly ConcurrentDictionary<ulong, Channel> _channels;
private readonly ConcurrentDictionary<ulong, CachedGuild> _guilds;
private readonly ConcurrentDictionary<ulong, Role> _roles;
private readonly ConcurrentDictionary<ulong, User> _users;
public MemoryDiscordCache()
{
_guilds = new ConcurrentDictionary<ulong, CachedGuild>();
_channels = new ConcurrentDictionary<ulong, Channel>();
_users = new ConcurrentDictionary<ulong, User>();
_roles = new ConcurrentDictionary<ulong, Role>();
}
public ValueTask SaveGuild(Guild guild)
{
SaveGuildRaw(guild);
foreach (var role in guild.Roles)
// Don't call SaveRole because that updates guild state
// and we just got a brand new one :)
_roles[role.Id] = role;
return default;
}
public ValueTask SaveChannel(Channel channel)
{
_channels[channel.Id] = channel;
if (channel.GuildId != null && _guilds.TryGetValue(channel.GuildId.Value, out var guild))
guild.Channels.TryAdd(channel.Id, true);
return default;
}
public ValueTask SaveUser(User user)
{
_users[user.Id] = user;
return default;
}
public ValueTask SaveRole(ulong guildId, Role role)
{
_roles[role.Id] = role;
if (_guilds.TryGetValue(guildId, out var guild))
{
// TODO: this code is stinky
var found = false;
for (var i = 0; i < guild.Guild.Roles.Length; i++)
{
if (guild.Guild.Roles[i].Id != role.Id)
continue;
guild.Guild.Roles[i] = role;
found = true;
}
if (!found)
{
_guilds[guildId] = guild with {
Guild = guild.Guild with {
Roles = guild.Guild.Roles.Concat(new[] { role}).ToArray()
}
};
}
}
return default;
}
public ValueTask RemoveGuild(ulong guildId)
{
_guilds.TryRemove(guildId, out _);
return default;
}
public ValueTask RemoveChannel(ulong channelId)
{
if (!_channels.TryRemove(channelId, out var channel))
return default;
if (channel.GuildId != null && _guilds.TryGetValue(channel.GuildId.Value, out var guild))
guild.Channels.TryRemove(channel.Id, out _);
return default;
}
public ValueTask RemoveUser(ulong userId)
{
_users.TryRemove(userId, out _);
return default;
}
public ValueTask RemoveRole(ulong guildId, ulong roleId)
{
_roles.TryRemove(roleId, out _);
return default;
}
public ValueTask<Guild?> GetGuild(ulong guildId) => new(_guilds.GetValueOrDefault(guildId)?.Guild);
public ValueTask<Channel?> GetChannel(ulong channelId) => new(_channels.GetValueOrDefault(channelId));
public ValueTask<User?> GetUser(ulong userId) => new(_users.GetValueOrDefault(userId));
public ValueTask<Role?> GetRole(ulong roleId) => new(_roles.GetValueOrDefault(roleId));
public async IAsyncEnumerable<Guild> GetAllGuilds()
{
foreach (var guild in _guilds.Values)
yield return guild.Guild;
}
public ValueTask<IEnumerable<Channel>> GetGuildChannels(ulong guildId)
{
if (!_guilds.TryGetValue(guildId, out var guild))
throw new ArgumentException("Guild not found", nameof(guildId));
return new ValueTask<IEnumerable<Channel>>(guild.Channels.Keys.Select(c => _channels[c]));
}
private CachedGuild SaveGuildRaw(Guild guild) =>
_guilds.GetOrAdd(guild.Id, (_, g) => new CachedGuild(g), guild);
private record CachedGuild(Guild Guild)
{
public readonly ConcurrentDictionary<ulong, bool> Channels = new();
}
}
}

View File

@ -0,0 +1,7 @@
namespace Myriad.Extensions
{
public static class ChannelExtensions
{
}
}

View File

@ -0,0 +1,7 @@
namespace Myriad.Extensions
{
public class MessageExtensions
{
}
}

View File

@ -0,0 +1,126 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Myriad.Gateway;
using Myriad.Types;
namespace Myriad.Extensions
{
public static class PermissionExtensions
{
public static PermissionSet EveryonePermissions(this Guild guild) =>
guild.Roles.FirstOrDefault(r => r.Id == guild.Id)?.Permissions ?? PermissionSet.Dm;
public static PermissionSet PermissionsFor(Guild guild, Channel channel, MessageCreateEvent msg) =>
PermissionsFor(guild, channel, msg.Author.Id, msg.Member!.Roles);
public static PermissionSet PermissionsFor(Guild guild, Channel channel, ulong userId,
ICollection<ulong> roleIds)
{
if (channel.Type == Channel.ChannelType.Dm)
return PermissionSet.Dm;
var perms = GuildPermissions(guild, userId, roleIds);
perms = ApplyChannelOverwrites(perms, channel, userId, roleIds);
if ((perms & PermissionSet.Administrator) == PermissionSet.Administrator)
return PermissionSet.All;
if ((perms & PermissionSet.ViewChannel) == 0)
perms &= ~NeedsViewChannel;
if ((perms & PermissionSet.SendMessages) == 0)
perms &= ~NeedsSendMessages;
return perms;
}
public static bool Has(this PermissionSet value, PermissionSet flag) =>
(value & flag) == flag;
public static PermissionSet GuildPermissions(this Guild guild, ulong userId, ICollection<ulong> roleIds)
{
if (guild.OwnerId == userId)
return PermissionSet.All;
var perms = PermissionSet.None;
foreach (var role in guild.Roles)
{
if (role.Id == guild.Id || roleIds.Contains(role.Id))
perms |= role.Permissions;
}
if (perms.Has(PermissionSet.Administrator))
return PermissionSet.All;
return perms;
}
public static PermissionSet ApplyChannelOverwrites(PermissionSet perms, Channel channel, ulong userId,
ICollection<ulong> roleIds)
{
if (channel.PermissionOverwrites == null)
return perms;
var everyoneDeny = PermissionSet.None;
var everyoneAllow = PermissionSet.None;
var roleDeny = PermissionSet.None;
var roleAllow = PermissionSet.None;
var userDeny = PermissionSet.None;
var userAllow = PermissionSet.None;
foreach (var overwrite in channel.PermissionOverwrites)
{
switch (overwrite.Type)
{
case Channel.OverwriteType.Role when overwrite.Id == channel.GuildId:
everyoneDeny |= overwrite.Deny;
everyoneAllow |= overwrite.Allow;
break;
case Channel.OverwriteType.Role when roleIds.Contains(overwrite.Id):
roleDeny |= overwrite.Deny;
roleAllow |= overwrite.Allow;
break;
case Channel.OverwriteType.Member when overwrite.Id == userId:
userDeny |= overwrite.Deny;
userAllow |= overwrite.Allow;
break;
}
}
perms &= ~everyoneDeny;
perms |= everyoneAllow;
perms &= ~roleDeny;
perms |= roleAllow;
perms &= ~userDeny;
perms |= userAllow;
return perms;
}
private const PermissionSet NeedsViewChannel =
PermissionSet.SendMessages |
PermissionSet.SendTtsMessages |
PermissionSet.ManageMessages |
PermissionSet.EmbedLinks |
PermissionSet.AttachFiles |
PermissionSet.ReadMessageHistory |
PermissionSet.MentionEveryone |
PermissionSet.UseExternalEmojis |
PermissionSet.AddReactions |
PermissionSet.Connect |
PermissionSet.Speak |
PermissionSet.MuteMembers |
PermissionSet.DeafenMembers |
PermissionSet.MoveMembers |
PermissionSet.UseVad |
PermissionSet.Stream |
PermissionSet.PrioritySpeaker;
private const PermissionSet NeedsSendMessages =
PermissionSet.MentionEveryone |
PermissionSet.SendTtsMessages |
PermissionSet.AttachFiles |
PermissionSet.EmbedLinks;
}
}

View File

@ -0,0 +1,10 @@
using Myriad.Types;
namespace Myriad.Extensions
{
public static class UserExtensions
{
public static string AvatarUrl(this User user) =>
$"https://cdn.discordapp.com/avatars/{user.Id}/{user.Avatar}.png";
}
}

88
Myriad/Gateway/Cluster.cs Normal file
View File

@ -0,0 +1,88 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Myriad.Types;
using Serilog;
namespace Myriad.Gateway
{
public class Cluster
{
private readonly GatewaySettings _gatewaySettings;
private readonly ILogger _logger;
private readonly ConcurrentDictionary<int, Shard> _shards = new();
public Cluster(GatewaySettings gatewaySettings, ILogger logger)
{
_gatewaySettings = gatewaySettings;
_logger = logger;
}
public Func<Shard, IGatewayEvent, Task>? EventReceived { get; set; }
public IReadOnlyDictionary<int, Shard> Shards => _shards;
public ClusterSessionState SessionState => GetClusterState();
public User? User => _shards.Values.Select(s => s.User).FirstOrDefault(s => s != null);
private ClusterSessionState GetClusterState()
{
var shards = new List<ClusterSessionState.ShardState>();
foreach (var (id, shard) in _shards)
shards.Add(new ClusterSessionState.ShardState
{
Shard = shard.ShardInfo ?? new ShardInfo(id, _shards.Count), Session = shard.SessionInfo
});
return new ClusterSessionState {Shards = shards};
}
public async Task Start(GatewayInfo.Bot info, ClusterSessionState? lastState = null)
{
if (lastState != null && lastState.Shards.Count == info.Shards)
await Resume(info.Url, lastState);
else
await Start(info.Url, info.Shards);
}
public async Task Resume(string url, ClusterSessionState sessionState)
{
_logger.Information("Resuming session with {ShardCount} shards at {Url}", sessionState.Shards.Count, url);
foreach (var shardState in sessionState.Shards)
CreateAndAddShard(url, shardState.Shard, shardState.Session);
await StartShards();
}
public async Task Start(string url, int shardCount)
{
_logger.Information("Starting {ShardCount} shards at {Url}", shardCount, url);
for (var i = 0; i < shardCount; i++)
CreateAndAddShard(url, new ShardInfo(i, shardCount), null);
await StartShards();
}
private async Task StartShards()
{
_logger.Information("Connecting shards...");
await Task.WhenAll(_shards.Values.Select(s => s.Start()));
}
private void CreateAndAddShard(string url, ShardInfo shardInfo, ShardSessionInfo? session)
{
var shard = new Shard(_logger, new Uri(url), _gatewaySettings, shardInfo, session);
shard.OnEventReceived += evt => OnShardEventReceived(shard, evt);
_shards[shardInfo.ShardId] = shard;
}
private async Task OnShardEventReceived(Shard shard, IGatewayEvent evt)
{
if (EventReceived != null)
await EventReceived(shard, evt);
}
}
}

View File

@ -0,0 +1,15 @@
using System.Collections.Generic;
namespace Myriad.Gateway
{
public record ClusterSessionState
{
public List<ShardState> Shards { get; init; }
public record ShardState
{
public ShardInfo Shard { get; init; }
public ShardSessionInfo Session { get; init; }
}
}
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ChannelCreateEvent: Channel, IGatewayEvent;
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ChannelDeleteEvent: Channel, IGatewayEvent;
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ChannelUpdateEvent: Channel, IGatewayEvent;
}

View File

@ -0,0 +1,12 @@
using System.Collections.Generic;
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildCreateEvent: Guild, IGatewayEvent
{
public Channel[] Channels { get; init; }
public GuildMember[] Members { get; init; }
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record GuildDeleteEvent(ulong Id, bool Unavailable): IGatewayEvent;
}

View File

@ -0,0 +1,9 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildMemberAddEvent: GuildMember, IGatewayEvent
{
public ulong GuildId { get; init; }
}
}

View File

@ -0,0 +1,10 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public class GuildMemberRemoveEvent: IGatewayEvent
{
public ulong GuildId { get; init; }
public User User { get; init; }
}
}

View File

@ -0,0 +1,9 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildMemberUpdateEvent: GuildMember, IGatewayEvent
{
public ulong GuildId { get; init; }
}
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildRoleCreateEvent(ulong GuildId, Role Role): IGatewayEvent;
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record GuildRoleDeleteEvent(ulong GuildId, ulong RoleId): IGatewayEvent;
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildRoleUpdateEvent(ulong GuildId, Role Role): IGatewayEvent;
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildUpdateEvent: Guild, IGatewayEvent;
}

View File

@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
namespace Myriad.Gateway
{
public interface IGatewayEvent
{
public static readonly Dictionary<string, Type> EventTypes = new()
{
{"READY", typeof(ReadyEvent)},
{"RESUMED", typeof(ResumedEvent)},
{"GUILD_CREATE", typeof(GuildCreateEvent)},
{"GUILD_UPDATE", typeof(GuildUpdateEvent)},
{"GUILD_DELETE", typeof(GuildDeleteEvent)},
{"GUILD_MEMBER_ADD", typeof(GuildMemberAddEvent)},
{"GUILD_MEMBER_REMOVE", typeof(GuildMemberRemoveEvent)},
{"GUILD_MEMBER_UPDATE", typeof(GuildMemberUpdateEvent)},
{"GUILD_ROLE_CREATE", typeof(GuildRoleCreateEvent)},
{"GUILD_ROLE_UPDATE", typeof(GuildRoleUpdateEvent)},
{"GUILD_ROLE_DELETE", typeof(GuildRoleDeleteEvent)},
{"CHANNEL_CREATE", typeof(ChannelCreateEvent)},
{"CHANNEL_UPDATE", typeof(ChannelUpdateEvent)},
{"CHANNEL_DELETE", typeof(ChannelDeleteEvent)},
{"MESSAGE_CREATE", typeof(MessageCreateEvent)},
{"MESSAGE_UPDATE", typeof(MessageUpdateEvent)},
{"MESSAGE_DELETE", typeof(MessageDeleteEvent)},
{"MESSAGE_DELETE_BULK", typeof(MessageDeleteBulkEvent)},
{"MESSAGE_REACTION_ADD", typeof(MessageReactionAddEvent)},
{"MESSAGE_REACTION_REMOVE", typeof(MessageReactionRemoveEvent)},
{"MESSAGE_REACTION_REMOVE_ALL", typeof(MessageReactionRemoveAllEvent)},
{"MESSAGE_REACTION_REMOVE_EMOJI", typeof(MessageReactionRemoveEmojiEvent)},
{"INTERACTION_CREATE", typeof(InteractionCreateEvent)}
};
}
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record InteractionCreateEvent: Interaction, IGatewayEvent;
}

View File

@ -0,0 +1,9 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record MessageCreateEvent: Message, IGatewayEvent
{
public GuildMemberPartial? Member { get; init; }
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record MessageDeleteBulkEvent(ulong[] Ids, ulong ChannelId, ulong? GuildId): IGatewayEvent;
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record MessageDeleteEvent(ulong Id, ulong ChannelId, ulong? GuildId): IGatewayEvent;
}

View File

@ -0,0 +1,8 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record MessageReactionAddEvent(ulong UserId, ulong ChannelId, ulong MessageId, ulong? GuildId,
GuildMember? Member,
Emoji Emoji): IGatewayEvent;
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record MessageReactionRemoveAllEvent(ulong ChannelId, ulong MessageId, ulong? GuildId): IGatewayEvent;
}

View File

@ -0,0 +1,7 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record MessageReactionRemoveEmojiEvent
(ulong ChannelId, ulong MessageId, ulong? GuildId, Emoji Emoji): IGatewayEvent;
}

View File

@ -0,0 +1,7 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record MessageReactionRemoveEvent
(ulong UserId, ulong ChannelId, ulong MessageId, ulong? GuildId, Emoji Emoji): IGatewayEvent;
}

View File

@ -0,0 +1,7 @@
namespace Myriad.Gateway
{
public record MessageUpdateEvent(ulong Id, ulong ChannelId): IGatewayEvent
{
// TODO: lots of partials
}
}

View File

@ -0,0 +1,15 @@
using System.Text.Json.Serialization;
using Myriad.Types;
namespace Myriad.Gateway
{
public record ReadyEvent: IGatewayEvent
{
[JsonPropertyName("v")] public int Version { get; init; }
public User User { get; init; }
public string SessionId { get; init; }
public ShardInfo? Shard { get; init; }
public ApplicationPartial Application { get; init; }
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record ResumedEvent: IGatewayEvent;
}

View File

@ -0,0 +1,35 @@
using System;
namespace Myriad.Gateway
{
// TODO: unused?
public class GatewayCloseException: Exception
{
public GatewayCloseException(int closeCode, string closeReason): base($"{closeCode}: {closeReason}")
{
CloseCode = closeCode;
CloseReason = closeReason;
}
public int CloseCode { get; }
public string CloseReason { get; }
}
public class GatewayCloseCode
{
public const int UnknownError = 4000;
public const int UnknownOpcode = 4001;
public const int DecodeError = 4002;
public const int NotAuthenticated = 4003;
public const int AuthenticationFailed = 4004;
public const int AlreadyAuthenticated = 4005;
public const int InvalidSeq = 4007;
public const int RateLimited = 4008;
public const int SessionTimedOut = 4009;
public const int InvalidShard = 4010;
public const int ShardingRequired = 4011;
public const int InvalidApiVersion = 4012;
public const int InvalidIntent = 4013;
public const int DisallowedIntent = 4014;
}
}

View File

@ -0,0 +1,24 @@
using System;
namespace Myriad.Gateway
{
[Flags]
public enum GatewayIntent
{
Guilds = 1 << 0,
GuildMembers = 1 << 1,
GuildBans = 1 << 2,
GuildEmojis = 1 << 3,
GuildIntegrations = 1 << 4,
GuildWebhooks = 1 << 5,
GuildInvites = 1 << 6,
GuildVoiceStates = 1 << 7,
GuildPresences = 1 << 8,
GuildMessages = 1 << 9,
GuildMessageReactions = 1 << 10,
GuildMessageTyping = 1 << 11,
DirectMessages = 1 << 12,
DirectMessageReactions = 1 << 13,
DirectMessageTyping = 1 << 14
}
}

View File

@ -0,0 +1,31 @@
using System.Text.Json.Serialization;
namespace Myriad.Gateway
{
public record GatewayPacket
{
[JsonPropertyName("op")] public GatewayOpcode Opcode { get; init; }
[JsonPropertyName("d")] public object? Payload { get; init; }
[JsonPropertyName("s")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? Sequence { get; init; }
[JsonPropertyName("t")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? EventType { get; init; }
}
public enum GatewayOpcode
{
Dispatch = 0,
Heartbeat = 1,
Identify = 2,
PresenceUpdate = 3,
VoiceStateUpdate = 4,
Resume = 6,
Reconnect = 7,
RequestGuildMembers = 8,
InvalidSession = 9,
Hello = 10,
HeartbeatAck = 11
}
}

View File

@ -0,0 +1,8 @@
namespace Myriad.Gateway
{
public record GatewaySettings
{
public string Token { get; init; }
public GatewayIntent Intents { get; init; }
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record GatewayHello(int HeartbeatInterval);
}

View File

@ -0,0 +1,28 @@
using System.Text.Json.Serialization;
namespace Myriad.Gateway
{
public record GatewayIdentify
{
public string Token { get; init; }
public ConnectionProperties Properties { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public bool? Compress { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? LargeThreshold { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public ShardInfo? Shard { get; init; }
public GatewayIntent Intents { get; init; }
public record ConnectionProperties
{
[JsonPropertyName("$os")] public string Os { get; init; }
[JsonPropertyName("$browser")] public string Browser { get; init; }
[JsonPropertyName("$device")] public string Device { get; init; }
}
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record GatewayResume(string Token, string SessionId, int Seq);
}

View File

@ -0,0 +1,23 @@
using System.Collections.Generic;
using Myriad.Types;
namespace Myriad.Gateway
{
public record GatewayStatusUpdate
{
public enum UserStatus
{
Online,
Dnd,
Idle,
Invisible,
Offline
}
public ulong? Since { get; init; }
public ActivityPartial[]? Activities { get; init; }
public UserStatus Status { get; init; }
public bool Afk { get; init; }
}
}

328
Myriad/Gateway/Shard.cs Normal file
View File

@ -0,0 +1,328 @@
using System;
using System.Net.WebSockets;
using System.Text.Json;
using System.Threading.Tasks;
using Myriad.Serialization;
using Myriad.Types;
using Serilog;
namespace Myriad.Gateway
{
public class Shard: IAsyncDisposable
{
private const string LibraryName = "Newcord Test";
private readonly JsonSerializerOptions _jsonSerializerOptions =
new JsonSerializerOptions().ConfigureForNewcord();
private readonly ILogger _logger;
private readonly Uri _uri;
private ShardConnection? _conn;
private TimeSpan? _currentHeartbeatInterval;
private bool _hasReceivedAck;
private DateTimeOffset? _lastHeartbeatSent;
private Task _worker;
public ShardInfo? ShardInfo { get; private set; }
public GatewaySettings Settings { get; }
public ShardSessionInfo SessionInfo { get; private set; }
public ShardState State { get; private set; }
public TimeSpan? Latency { get; private set; }
public User? User { get; private set; }
public Func<IGatewayEvent, Task>? OnEventReceived { get; set; }
public Shard(ILogger logger, Uri uri, GatewaySettings settings, ShardInfo? info = null,
ShardSessionInfo? sessionInfo = null)
{
_logger = logger;
_uri = uri;
Settings = settings;
ShardInfo = info;
SessionInfo = sessionInfo ?? new ShardSessionInfo();
}
public async ValueTask DisposeAsync()
{
if (_conn != null)
await _conn.DisposeAsync();
}
public Task Start()
{
_worker = MainLoop();
return Task.CompletedTask;
}
public async Task UpdateStatus(GatewayStatusUpdate payload)
{
if (_conn != null && _conn.State == WebSocketState.Open)
await _conn!.Send(new GatewayPacket {Opcode = GatewayOpcode.PresenceUpdate, Payload = payload});
}
private async Task MainLoop()
{
while (true)
try
{
_logger.Information("Connecting...");
State = ShardState.Connecting;
await Connect();
_logger.Information("Connected. Entering main loop...");
// Tick returns false if we need to stop and reconnect
while (await Tick(_conn!))
await Task.Delay(TimeSpan.FromMilliseconds(1000));
_logger.Information("Connection closed, reconnecting...");
State = ShardState.Closed;
}
catch (Exception e)
{
_logger.Error(e, "Error in shard state handler");
}
}
private async Task<bool> Tick(ShardConnection conn)
{
if (conn.State != WebSocketState.Connecting && conn.State != WebSocketState.Open)
return false;
if (!await TickHeartbeat(conn))
// TickHeartbeat returns false if we're disconnecting
return false;
return true;
}
private async Task<bool> TickHeartbeat(ShardConnection conn)
{
// If we don't need to heartbeat, do nothing
if (_lastHeartbeatSent == null || _currentHeartbeatInterval == null)
return true;
if (DateTimeOffset.UtcNow - _lastHeartbeatSent < _currentHeartbeatInterval)
return true;
// If we haven't received the ack in time, close w/ error
if (!_hasReceivedAck)
{
_logger.Warning(
"Did not receive heartbeat Ack from gateway within interval ({HeartbeatInterval})",
_currentHeartbeatInterval);
State = ShardState.Closing;
await conn.Disconnect(WebSocketCloseStatus.ProtocolError, "Did not receive ACK in time");
return false;
}
// Otherwise just send it :)
await SendHeartbeat(conn);
_hasReceivedAck = false;
return true;
}
private async Task SendHeartbeat(ShardConnection conn)
{
_logger.Debug("Sending heartbeat");
await conn.Send(new GatewayPacket {Opcode = GatewayOpcode.Heartbeat, Payload = SessionInfo.LastSequence});
_lastHeartbeatSent = DateTimeOffset.UtcNow;
}
private async Task Connect()
{
if (_conn != null)
await _conn.DisposeAsync();
_currentHeartbeatInterval = null;
_conn = new ShardConnection(_uri, _logger, _jsonSerializerOptions) {OnReceive = OnReceive};
}
private async Task OnReceive(GatewayPacket packet)
{
switch (packet.Opcode)
{
case GatewayOpcode.Hello:
{
await HandleHello((JsonElement) packet.Payload!);
break;
}
case GatewayOpcode.Heartbeat:
{
_logger.Debug("Received heartbeat request from shard, sending Ack");
await _conn!.Send(new GatewayPacket {Opcode = GatewayOpcode.HeartbeatAck});
break;
}
case GatewayOpcode.HeartbeatAck:
{
Latency = DateTimeOffset.UtcNow - _lastHeartbeatSent;
_logger.Debug("Received heartbeat Ack (latency {Latency})", Latency);
_hasReceivedAck = true;
break;
}
case GatewayOpcode.Reconnect:
{
_logger.Information("Received Reconnect, closing and reconnecting");
await _conn!.Disconnect(WebSocketCloseStatus.Empty, null);
break;
}
case GatewayOpcode.InvalidSession:
{
var canResume = ((JsonElement) packet.Payload!).GetBoolean();
// Clear session info before DCing
if (!canResume)
SessionInfo = SessionInfo with { Session = null };
var delay = TimeSpan.FromMilliseconds(new Random().Next(1000, 5000));
_logger.Information(
"Received Invalid Session (can resume? {CanResume}), reconnecting after {ReconnectDelay}",
canResume, delay);
await _conn!.Disconnect(WebSocketCloseStatus.Empty, null);
// Will reconnect after exiting this "loop"
await Task.Delay(delay);
break;
}
case GatewayOpcode.Dispatch:
{
SessionInfo = SessionInfo with { LastSequence = packet.Sequence };
var evt = DeserializeEvent(packet.EventType!, (JsonElement) packet.Payload!)!;
if (evt is ReadyEvent rdy)
{
if (State == ShardState.Connecting)
await HandleReady(rdy);
else
_logger.Warning("Received Ready event in unexpected state {ShardState}, ignoring?", State);
}
else if (evt is ResumedEvent)
{
if (State == ShardState.Connecting)
await HandleResumed();
else
_logger.Warning("Received Resumed event in unexpected state {ShardState}, ignoring?",
State);
}
await HandleEvent(evt);
break;
}
default:
{
_logger.Debug("Received unknown gateway opcode {Opcode}", packet.Opcode);
break;
}
}
}
private async Task HandleEvent(IGatewayEvent evt)
{
if (OnEventReceived != null)
await OnEventReceived.Invoke(evt);
}
private IGatewayEvent? DeserializeEvent(string eventType, JsonElement data)
{
if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType))
{
_logger.Information("Received unknown event type {EventType}", eventType);
return null;
}
try
{
_logger.Verbose("Deserializing {EventType} to {ClrType}", eventType, clrType);
return JsonSerializer.Deserialize(data.GetRawText(), clrType, _jsonSerializerOptions)
as IGatewayEvent;
}
catch (JsonException e)
{
_logger.Error(e, "Error deserializing event {EventType} to {ClrType}", eventType, clrType);
return null;
}
}
private Task HandleReady(ReadyEvent ready)
{
ShardInfo = ready.Shard;
SessionInfo = SessionInfo with { Session = ready.SessionId };
User = ready.User;
State = ShardState.Open;
return Task.CompletedTask;
}
private Task HandleResumed()
{
State = ShardState.Open;
return Task.CompletedTask;
}
private async Task HandleHello(JsonElement json)
{
var hello = JsonSerializer.Deserialize<GatewayHello>(json.GetRawText(), _jsonSerializerOptions)!;
_logger.Debug("Received Hello with interval {Interval} ms", hello.HeartbeatInterval);
_currentHeartbeatInterval = TimeSpan.FromMilliseconds(hello.HeartbeatInterval);
await SendHeartbeat(_conn!);
await SendIdentifyOrResume();
}
private async Task SendIdentifyOrResume()
{
if (SessionInfo.Session != null && SessionInfo.LastSequence != null)
await SendResume(SessionInfo.Session, SessionInfo.LastSequence!.Value);
else
await SendIdentify();
}
private async Task SendIdentify()
{
_logger.Information("Sending gateway Identify for shard {@ShardInfo}", SessionInfo);
await _conn!.Send(new GatewayPacket
{
Opcode = GatewayOpcode.Identify,
Payload = new GatewayIdentify
{
Token = Settings.Token,
Properties = new GatewayIdentify.ConnectionProperties
{
Browser = LibraryName, Device = LibraryName, Os = Environment.OSVersion.ToString()
},
Intents = Settings.Intents,
Shard = ShardInfo
}
});
}
private async Task SendResume(string session, int lastSequence)
{
_logger.Information("Sending gateway Resume for session {@SessionInfo}", ShardInfo,
SessionInfo);
await _conn!.Send(new GatewayPacket
{
Opcode = GatewayOpcode.Resume, Payload = new GatewayResume(Settings.Token, session, lastSequence)
});
}
public enum ShardState
{
Closed,
Connecting,
Open,
Closing
}
}
}

View File

@ -0,0 +1,118 @@
using System;
using System.Buffers;
using System.IO;
using System.Net.WebSockets;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Serilog;
namespace Myriad.Gateway
{
public class ShardConnection: IAsyncDisposable
{
private readonly MemoryStream _bufStream = new();
private readonly ClientWebSocket _client = new();
private readonly CancellationTokenSource _cts = new();
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly ILogger _logger;
private readonly Task _worker;
public ShardConnection(Uri uri, ILogger logger, JsonSerializerOptions jsonSerializerOptions)
{
_logger = logger;
_jsonSerializerOptions = jsonSerializerOptions;
_worker = Worker(uri);
}
public Func<GatewayPacket, Task>? OnReceive { get; set; }
public WebSocketState State => _client.State;
public async ValueTask DisposeAsync()
{
_cts.Cancel();
await _worker;
_client.Dispose();
await _bufStream.DisposeAsync();
_cts.Dispose();
}
private async Task Worker(Uri uri)
{
var realUrl = new UriBuilder(uri)
{
Query = "v=8&encoding=json"
}.Uri;
_logger.Debug("Connecting to gateway WebSocket at {GatewayUrl}", realUrl);
await _client.ConnectAsync(realUrl, default);
while (!_cts.IsCancellationRequested && _client.State == WebSocketState.Open)
try
{
await HandleReceive();
}
catch (Exception e)
{
_logger.Error(e, "Error in WebSocket receive worker");
}
}
private async Task HandleReceive()
{
_bufStream.SetLength(0);
var result = await ReadData(_bufStream);
var data = _bufStream.GetBuffer().AsMemory(0, (int) _bufStream.Position);
if (result.MessageType == WebSocketMessageType.Text)
await HandleReceiveData(data);
else if (result.MessageType == WebSocketMessageType.Close)
_logger.Information("WebSocket closed by server: {StatusCode} {Reason}", _client.CloseStatus,
_client.CloseStatusDescription);
}
private async Task HandleReceiveData(Memory<byte> data)
{
var packet = JsonSerializer.Deserialize<GatewayPacket>(data.Span, _jsonSerializerOptions)!;
try
{
if (OnReceive != null)
await OnReceive.Invoke(packet);
}
catch (Exception e)
{
_logger.Error(e, "Error in gateway handler for {OpcodeType}", packet.Opcode);
}
}
private async Task<ValueWebSocketReceiveResult> ReadData(MemoryStream stream)
{
using var buf = MemoryPool<byte>.Shared.Rent();
ValueWebSocketReceiveResult result;
do
{
result = await _client.ReceiveAsync(buf.Memory, _cts.Token);
stream.Write(buf.Memory.Span.Slice(0, result.Count));
} while (!result.EndOfMessage);
return result;
}
public async Task Send(GatewayPacket packet)
{
var bytes = JsonSerializer.SerializeToUtf8Bytes(packet, _jsonSerializerOptions);
await _client.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, default);
}
public async Task Disconnect(WebSocketCloseStatus status, string? description)
{
await _client.CloseAsync(status, description, default);
_cts.Cancel();
}
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record ShardInfo(int ShardId, int NumShards);
}

View File

@ -0,0 +1,8 @@
namespace Myriad.Gateway
{
public record ShardSessionInfo
{
public string? Session { get; init; }
public int? LastSequence { get; init; }
}
}

19
Myriad/Myriad.csproj Normal file
View File

@ -0,0 +1,19 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net5.0</TargetFramework>
<Nullable>enable</Nullable>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)' == 'Release' ">
<DebugSymbols>true</DebugSymbols>
<DebugType>full</DebugType>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Polly" Version="7.2.1" />
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
<PackageReference Include="Serilog" Version="2.10.0" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,240 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Net.Http.Json;
using System.Text.Json;
using System.Threading.Tasks;
using Myriad.Rest.Exceptions;
using Myriad.Rest.Ratelimit;
using Myriad.Rest.Types;
using Myriad.Serialization;
using Polly;
using Serilog;
namespace Myriad.Rest
{
public class BaseRestClient: IAsyncDisposable
{
private const string ApiBaseUrl = "https://discord.com/api/v8";
private readonly Version _httpVersion = new(2, 0);
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly ILogger _logger;
private readonly Ratelimiter _ratelimiter;
private readonly AsyncPolicy<HttpResponseMessage> _retryPolicy;
public BaseRestClient(string userAgent, string token, ILogger logger)
{
_logger = logger.ForContext<BaseRestClient>();
if (!token.StartsWith("Bot "))
token = "Bot " + token;
Client = new HttpClient();
Client.DefaultRequestHeaders.TryAddWithoutValidation("User-Agent", userAgent);
Client.DefaultRequestHeaders.TryAddWithoutValidation("Authorization", token);
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForNewcord();
_ratelimiter = new Ratelimiter(logger);
var discordPolicy = new DiscordRateLimitPolicy(_ratelimiter);
// todo: why doesn't the timeout work? o.o
var timeoutPolicy = Policy.TimeoutAsync<HttpResponseMessage>(TimeSpan.FromSeconds(10));
var waitPolicy = Policy
.Handle<RatelimitBucketExhaustedException>()
.WaitAndRetryAsync(3,
(_, e, _) => ((RatelimitBucketExhaustedException) e).RetryAfter,
(_, _, _, _) => Task.CompletedTask)
.AsAsyncPolicy<HttpResponseMessage>();
_retryPolicy = Policy.WrapAsync(timeoutPolicy, waitPolicy, discordPolicy);
}
public HttpClient Client { get; }
public ValueTask DisposeAsync()
{
_ratelimiter.Dispose();
Client.Dispose();
return default;
}
public async Task<T?> Get<T>(string path, (string endpointName, ulong major) ratelimitParams) where T: class
{
var request = new HttpRequestMessage(HttpMethod.Get, ApiBaseUrl + path);
var response = await Send(request, ratelimitParams, true);
// GET-only special case: 404s are nulls and not exceptions
if (response.StatusCode == HttpStatusCode.NotFound)
return null;
return await ReadResponse<T>(response);
}
public async Task<T?> Post<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Post, ApiBaseUrl + path);
SetRequestJsonBody(request, body);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> PostMultipart<T>(string path, (string endpointName, ulong major) ratelimitParams, object? payload, MultipartFile[]? files)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Post, ApiBaseUrl + path);
SetRequestFormDataBody(request, payload, files);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> Patch<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Patch, ApiBaseUrl + path);
SetRequestJsonBody(request, body);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> Put<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Put, ApiBaseUrl + path);
SetRequestJsonBody(request, body);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task Delete(string path, (string endpointName, ulong major) ratelimitParams)
{
var request = new HttpRequestMessage(HttpMethod.Delete, ApiBaseUrl + path);
await Send(request, ratelimitParams);
}
private void SetRequestJsonBody(HttpRequestMessage request, object? body)
{
if (body == null) return;
request.Content =
new ReadOnlyMemoryContent(JsonSerializer.SerializeToUtf8Bytes(body, _jsonSerializerOptions));
request.Content.Headers.ContentType = new MediaTypeHeaderValue("application/json");
}
private void SetRequestFormDataBody(HttpRequestMessage request, object? payload, MultipartFile[]? files)
{
var bodyJson = JsonSerializer.SerializeToUtf8Bytes(payload, _jsonSerializerOptions);
var mfd = new MultipartFormDataContent();
mfd.Add(new ByteArrayContent(bodyJson), "payload_json");
if (files != null)
{
for (var i = 0; i < files.Length; i++)
{
var (filename, stream) = files[i];
mfd.Add(new StreamContent(stream), $"file{i}", filename);
}
}
request.Content = mfd;
}
private async Task<T?> ReadResponse<T>(HttpResponseMessage response) where T: class
{
if (response.StatusCode == HttpStatusCode.NoContent)
return null;
return await response.Content.ReadFromJsonAsync<T>(_jsonSerializerOptions);
}
private async Task<HttpResponseMessage> Send(HttpRequestMessage request,
(string endpointName, ulong major) ratelimitParams,
bool ignoreNotFound = false)
{
return await _retryPolicy.ExecuteAsync(async _ =>
{
_logger.Debug("Sending request: {RequestMethod} {RequestPath}",
request.Method, request.RequestUri);
request.Version = _httpVersion;
request.VersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
var stopwatch = new Stopwatch();
stopwatch.Start();
var response = await Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
stopwatch.Stop();
_logger.Debug(
"Received response in {ResponseDurationMs} ms: {RequestMethod} {RequestPath} -> {StatusCode} {ReasonPhrase}",
stopwatch.ElapsedMilliseconds, request.Method, request.RequestUri, (int) response.StatusCode,
response.ReasonPhrase);
await HandleApiError(response, ignoreNotFound);
return response;
},
new Dictionary<string, object>
{
{DiscordRateLimitPolicy.EndpointContextKey, ratelimitParams.endpointName},
{DiscordRateLimitPolicy.MajorContextKey, ratelimitParams.major}
});
}
private async ValueTask HandleApiError(HttpResponseMessage response, bool ignoreNotFound)
{
if (response.IsSuccessStatusCode)
return;
if (response.StatusCode == HttpStatusCode.NotFound && ignoreNotFound)
return;
throw await CreateDiscordException(response);
}
private async ValueTask<DiscordRequestException> CreateDiscordException(HttpResponseMessage response)
{
var body = await response.Content.ReadAsStringAsync();
var apiError = TryParseApiError(body);
return response.StatusCode switch
{
HttpStatusCode.BadRequest => new BadRequestException(response, body, apiError),
HttpStatusCode.Forbidden => new ForbiddenException(response, body, apiError),
HttpStatusCode.Unauthorized => new UnauthorizedException(response, body, apiError),
HttpStatusCode.NotFound => new NotFoundException(response, body, apiError),
HttpStatusCode.Conflict => new ConflictException(response, body, apiError),
HttpStatusCode.TooManyRequests => new TooManyRequestsException(response, body, apiError),
_ => new UnknownDiscordRequestException(response, body, apiError)
};
}
private DiscordApiError? TryParseApiError(string responseBody)
{
if (string.IsNullOrWhiteSpace(responseBody))
return null;
try
{
return JsonSerializer.Deserialize<DiscordApiError>(responseBody, _jsonSerializerOptions);
}
catch (JsonException e)
{
_logger.Verbose(e, "Error deserializing API error");
}
return null;
}
}
}

View File

@ -0,0 +1,120 @@
using System;
using System.IO;
using System.Net;
using System.Threading.Tasks;
using Myriad.Rest.Types;
using Myriad.Rest.Types.Requests;
using Myriad.Types;
using Serilog;
namespace Myriad.Rest
{
public class DiscordApiClient
{
private const string UserAgent = "Test Discord Library by @Ske#6201";
private readonly BaseRestClient _client;
public DiscordApiClient(string token, ILogger logger)
{
_client = new BaseRestClient(UserAgent, token, logger);
}
public Task<GatewayInfo> GetGateway() =>
_client.Get<GatewayInfo>("/gateway", ("GetGateway", default))!;
public Task<GatewayInfo.Bot> GetGatewayBot() =>
_client.Get<GatewayInfo.Bot>("/gateway/bot", ("GetGatewayBot", default))!;
public Task<Channel?> GetChannel(ulong channelId) =>
_client.Get<Channel>($"/channels/{channelId}", ("GetChannel", channelId));
public Task<Message?> GetMessage(ulong channelId, ulong messageId) =>
_client.Get<Message>($"/channels/{channelId}/messages/{messageId}", ("GetMessage", channelId));
public Task<Channel?> GetGuild(ulong id) =>
_client.Get<Channel>($"/guilds/{id}", ("GetGuild", id));
public Task<User?> GetUser(ulong id) =>
_client.Get<User>($"/users/{id}", ("GetUser", default));
public Task<Message> CreateMessage(ulong channelId, MessageRequest request) =>
_client.Post<Message>($"/channels/{channelId}/messages", ("CreateMessage", channelId), request)!;
public Task<Message> EditMessage(ulong channelId, ulong messageId, MessageEditRequest request) =>
_client.Patch<Message>($"/channels/{channelId}/messages/{messageId}", ("EditMessage", channelId), request)!;
public Task DeleteMessage(ulong channelId, ulong messageId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}", ("DeleteMessage", channelId));
public Task CreateReaction(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Put<object>($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/@me",
("CreateReaction", channelId), null);
public Task DeleteOwnReaction(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/@me",
("DeleteOwnReaction", channelId));
public Task DeleteUserReaction(ulong channelId, ulong messageId, Emoji emoji, ulong userId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/{userId}",
("DeleteUserReaction", channelId));
public Task DeleteAllReactions(ulong channelId, ulong messageId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions",
("DeleteAllReactions", channelId));
public Task DeleteAllReactionsForEmoji(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}",
("DeleteAllReactionsForEmoji", channelId));
public Task<ApplicationCommand> CreateGlobalApplicationCommand(ulong applicationId,
ApplicationCommandRequest request) =>
_client.Post<ApplicationCommand>($"/applications/{applicationId}/commands",
("CreateGlobalApplicationCommand", applicationId), request)!;
public Task<ApplicationCommand[]> GetGuildApplicationCommands(ulong applicationId, ulong guildId) =>
_client.Get<ApplicationCommand[]>($"/applications/{applicationId}/guilds/{guildId}/commands",
("GetGuildApplicationCommands", applicationId))!;
public Task<ApplicationCommand> CreateGuildApplicationCommand(ulong applicationId, ulong guildId,
ApplicationCommandRequest request) =>
_client.Post<ApplicationCommand>($"/applications/{applicationId}/guilds/{guildId}/commands",
("CreateGuildApplicationCommand", applicationId), request)!;
public Task<ApplicationCommand> EditGuildApplicationCommand(ulong applicationId, ulong guildId,
ApplicationCommandRequest request) =>
_client.Patch<ApplicationCommand>($"/applications/{applicationId}/guilds/{guildId}/commands",
("EditGuildApplicationCommand", applicationId), request)!;
public Task DeleteGuildApplicationCommand(ulong applicationId, ulong commandId) =>
_client.Delete($"/applications/{applicationId}/commands/{commandId}",
("DeleteGuildApplicationCommand", applicationId));
public Task CreateInteractionResponse(ulong interactionId, string token, InteractionResponse response) =>
_client.Post<object>($"/interactions/{interactionId}/{token}/callback",
("CreateInteractionResponse", interactionId), response);
public Task ModifyGuildMember(ulong guildId, ulong userId, ModifyGuildMemberRequest request) =>
_client.Patch<object>($"/guilds/{guildId}/members/{userId}",
("ModifyGuildMember", guildId), request);
public Task<Webhook> CreateWebhook(ulong channelId, CreateWebhookRequest request) =>
_client.Post<Webhook>($"/channels/{channelId}/webhooks", ("CreateWebhook", channelId), request)!;
public Task<Webhook> GetWebhook(ulong webhookId) =>
_client.Get<Webhook>($"/webhooks/{webhookId}/webhooks", ("GetWebhook", webhookId))!;
public Task<Webhook[]> GetChannelWebhooks(ulong channelId) =>
_client.Get<Webhook[]>($"/channels/{channelId}/webhooks", ("GetChannelWebhooks", channelId))!;
public Task<Message> ExecuteWebhook(ulong webhookId, string webhookToken, ExecuteWebhookRequest request,
MultipartFile[]? files = null) =>
_client.PostMultipart<Message>($"/webhooks/{webhookId}/{webhookToken}",
("ExecuteWebhook", webhookId), request, files)!;
private static string EncodeEmoji(Emoji emoji) =>
WebUtility.UrlEncode(emoji.Name) ?? emoji.Id?.ToString() ??
throw new ArgumentException("Could not encode emoji");
}
}

View File

@ -0,0 +1,9 @@
using System.Text.Json;
namespace Myriad.Rest
{
public record DiscordApiError(string Message, int Code)
{
public JsonElement? Errors { get; init; }
}
}

View File

@ -0,0 +1,71 @@
using System;
using System.Net;
using System.Net.Http;
namespace Myriad.Rest.Exceptions
{
public class DiscordRequestException: Exception
{
public DiscordRequestException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError)
{
RequestBody = requestBody;
Response = response;
ApiError = apiError;
}
public string RequestBody { get; init; } = null!;
public HttpResponseMessage Response { get; init; } = null!;
public HttpStatusCode StatusCode => Response.StatusCode;
public int? ErrorCode => ApiError?.Code;
internal DiscordApiError? ApiError { get; init; }
public override string Message =>
(ApiError?.Message ?? Response.ReasonPhrase ?? "") + (FormError != null ? $": {FormError}" : "");
public string? FormError => ApiError?.Errors?.ToString();
}
public class NotFoundException: DiscordRequestException
{
public NotFoundException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError): base(
response, requestBody, apiError) { }
}
public class UnauthorizedException: DiscordRequestException
{
public UnauthorizedException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError): base(
response, requestBody, apiError) { }
}
public class ForbiddenException: DiscordRequestException
{
public ForbiddenException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError): base(
response, requestBody, apiError) { }
}
public class ConflictException: DiscordRequestException
{
public ConflictException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError): base(
response, requestBody, apiError) { }
}
public class BadRequestException: DiscordRequestException
{
public BadRequestException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError): base(
response, requestBody, apiError) { }
}
public class TooManyRequestsException: DiscordRequestException
{
public TooManyRequestsException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError):
base(response, requestBody, apiError) { }
}
public class UnknownDiscordRequestException: DiscordRequestException
{
public UnknownDiscordRequestException(HttpResponseMessage response, string requestBody,
DiscordApiError? apiError): base(response, requestBody, apiError) { }
}
}

View File

@ -0,0 +1,29 @@
using System;
using Myriad.Rest.Ratelimit;
namespace Myriad.Rest.Exceptions
{
public class RatelimitException: Exception
{
public RatelimitException(string? message): base(message) { }
}
public class RatelimitBucketExhaustedException: RatelimitException
{
public RatelimitBucketExhaustedException(Bucket bucket, TimeSpan retryAfter): base(
"Rate limit bucket exhausted, request blocked")
{
Bucket = bucket;
RetryAfter = retryAfter;
}
public Bucket Bucket { get; }
public TimeSpan RetryAfter { get; }
}
public class GloballyRatelimitedException: RatelimitException
{
public GloballyRatelimitedException(): base("Global rate limit hit") { }
}
}

View File

@ -0,0 +1,152 @@
using System;
using System.Threading;
using Serilog;
namespace Myriad.Rest.Ratelimit
{
public class Bucket
{
private static readonly TimeSpan Epsilon = TimeSpan.FromMilliseconds(10);
private static readonly TimeSpan FallbackDelay = TimeSpan.FromMilliseconds(200);
private static readonly TimeSpan StaleTimeout = TimeSpan.FromSeconds(5);
private readonly ILogger _logger;
private readonly SemaphoreSlim _semaphore = new(1, 1);
private DateTimeOffset _nextReset;
private bool _resetTimeValid;
public Bucket(ILogger logger, string key, ulong major, int limit)
{
_logger = logger.ForContext<Bucket>();
Key = key;
Major = major;
Limit = limit;
Remaining = limit;
_resetTimeValid = false;
}
public string Key { get; }
public ulong Major { get; }
public int Remaining { get; private set; }
public int Limit { get; private set; }
public DateTimeOffset LastUsed { get; private set; } = DateTimeOffset.UtcNow;
public bool TryAcquire()
{
LastUsed = DateTimeOffset.Now;
try
{
_semaphore.Wait();
if (Remaining > 0)
{
_logger.Debug(
"{BucketKey}/{BucketMajor}: Bucket has [{BucketRemaining}/{BucketLimit} left], allowing through",
Key, Major, Remaining, Limit);
Remaining--;
return true;
}
_logger.Debug("{BucketKey}/{BucketMajor}: Bucket has [{BucketRemaining}/{BucketLimit}] left, denying",
Key, Major, Remaining, Limit);
return false;
}
finally
{
_semaphore.Release();
}
}
public void HandleResponse(RatelimitHeaders headers)
{
try
{
_semaphore.Wait();
if (headers.ResetAfter != null)
{
var headerNextReset = DateTimeOffset.UtcNow + headers.ResetAfter.Value; // todo: server time
if (headerNextReset > _nextReset)
{
_logger.Debug("{BucketKey}/{BucketMajor}: Received reset time {NextReset} from server",
Key, Major, _nextReset);
_nextReset = headerNextReset;
_resetTimeValid = true;
}
}
if (headers.Limit != null)
Limit = headers.Limit.Value;
}
finally
{
_semaphore.Release();
}
}
public void Tick(DateTimeOffset now)
{
try
{
_semaphore.Wait();
// If we're past the reset time *and* we haven't reset already, do that
var timeSinceReset = _nextReset - now;
var shouldReset = _resetTimeValid && timeSinceReset > TimeSpan.Zero;
if (shouldReset)
{
_logger.Debug("{BucketKey}/{BucketMajor}: Bucket timed out, refreshing with {BucketLimit} requests",
Key, Major, Limit);
Remaining = Limit;
_resetTimeValid = false;
return;
}
// We've run out of requests without having any new reset time,
// *and* it's been longer than a set amount - add one request back to the pool and hope that one returns
var isBucketStale = !_resetTimeValid && Remaining <= 0 && timeSinceReset > StaleTimeout;
if (isBucketStale)
{
_logger.Warning(
"{BucketKey}/{BucketMajor}: Bucket is stale ({StaleTimeout} passed with no rate limit info), allowing one request through",
Key, Major, StaleTimeout);
Remaining = 1;
// Reset the (still-invalid) reset time to now, so we don't keep hitting this conditional over and over...
_nextReset = now;
}
}
finally
{
_semaphore.Release();
}
}
public TimeSpan GetResetDelay(DateTimeOffset now)
{
// If we don't have a valid reset time, return the fallback delay always
// (so it'll keep spinning until we hopefully have one...)
if (!_resetTimeValid)
return FallbackDelay;
var delay = _nextReset - now;
// If we have a really small (or negative) value, return a fallback delay too
if (delay < Epsilon)
return FallbackDelay;
return delay;
}
}
}

View File

@ -0,0 +1,79 @@
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
using Serilog;
namespace Myriad.Rest.Ratelimit
{
public class BucketManager: IDisposable
{
private static readonly TimeSpan StaleBucketTimeout = TimeSpan.FromMinutes(5);
private static readonly TimeSpan PruneWorkerInterval = TimeSpan.FromMinutes(1);
private readonly ConcurrentDictionary<(string key, ulong major), Bucket> _buckets = new();
private readonly ConcurrentDictionary<string, string> _endpointKeyMap = new();
private readonly ConcurrentDictionary<string, int> _knownKeyLimits = new();
private readonly ILogger _logger;
private readonly Task _worker;
private readonly CancellationTokenSource _workerCts = new();
public BucketManager(ILogger logger)
{
_logger = logger.ForContext<BucketManager>();
_worker = PruneWorker(_workerCts.Token);
}
public void Dispose()
{
_workerCts.Dispose();
_worker.Dispose();
}
public Bucket? GetBucket(string endpoint, ulong major)
{
if (!_endpointKeyMap.TryGetValue(endpoint, out var key))
return null;
if (_buckets.TryGetValue((key, major), out var bucket))
return bucket;
if (!_knownKeyLimits.TryGetValue(key, out var knownLimit))
return null;
return _buckets.GetOrAdd((key, major),
k => new Bucket(_logger, k.Item1, k.Item2, knownLimit));
}
public void UpdateEndpointInfo(string endpoint, string key, int? limit)
{
_endpointKeyMap[endpoint] = key;
if (limit != null)
_knownKeyLimits[key] = limit.Value;
}
private async Task PruneWorker(CancellationToken ct)
{
while (!ct.IsCancellationRequested)
{
await Task.Delay(PruneWorkerInterval, ct);
PruneStaleBuckets(DateTimeOffset.UtcNow);
}
}
private void PruneStaleBuckets(DateTimeOffset now)
{
foreach (var (key, bucket) in _buckets)
if (now - bucket.LastUsed > StaleBucketTimeout)
{
_logger.Debug("Pruning unused bucket {Bucket} (last used at {BucketLastUsed})", bucket,
bucket.LastUsed);
_buckets.TryRemove(key, out _);
}
}
}
}

View File

@ -0,0 +1,46 @@
using System;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Polly;
namespace Myriad.Rest.Ratelimit
{
public class DiscordRateLimitPolicy: AsyncPolicy<HttpResponseMessage>
{
public const string EndpointContextKey = "Endpoint";
public const string MajorContextKey = "Major";
private readonly Ratelimiter _ratelimiter;
public DiscordRateLimitPolicy(Ratelimiter ratelimiter, PolicyBuilder<HttpResponseMessage>? policyBuilder = null)
: base(policyBuilder)
{
_ratelimiter = ratelimiter;
}
protected override async Task<HttpResponseMessage> ImplementationAsync(
Func<Context, CancellationToken, Task<HttpResponseMessage>> action, Context context, CancellationToken ct,
bool continueOnCapturedContext)
{
if (!context.TryGetValue(EndpointContextKey, out var endpointObj) || !(endpointObj is string endpoint))
throw new ArgumentException("Must provide endpoint in Polly context");
if (!context.TryGetValue(MajorContextKey, out var majorObj) || !(majorObj is ulong major))
throw new ArgumentException("Must provide major in Polly context");
// Check rate limit, throw if we're not allowed...
_ratelimiter.AllowRequestOrThrow(endpoint, major, DateTimeOffset.Now);
// We're OK, push it through
var response = await action(context, ct).ConfigureAwait(continueOnCapturedContext);
// Update rate limit state with headers
var headers = new RatelimitHeaders(response);
_ratelimiter.HandleResponse(headers, endpoint, major);
return response;
}
}
}

View File

@ -0,0 +1,46 @@
using System;
using System.Linq;
using System.Net.Http;
namespace Myriad.Rest.Ratelimit
{
public record RatelimitHeaders
{
public RatelimitHeaders() { }
public RatelimitHeaders(HttpResponseMessage response)
{
ServerDate = response.Headers.Date;
if (response.Headers.TryGetValues("X-RateLimit-Limit", out var limit))
Limit = int.Parse(limit!.First());
if (response.Headers.TryGetValues("X-RateLimit-Remaining", out var remaining))
Remaining = int.Parse(remaining!.First());
if (response.Headers.TryGetValues("X-RateLimit-Reset", out var reset))
Reset = DateTimeOffset.FromUnixTimeMilliseconds((long) (double.Parse(reset!.First()) * 1000));
if (response.Headers.TryGetValues("X-RateLimit-Reset-After", out var resetAfter))
ResetAfter = TimeSpan.FromSeconds(double.Parse(resetAfter!.First()));
if (response.Headers.TryGetValues("X-RateLimit-Bucket", out var bucket))
Bucket = bucket.First();
if (response.Headers.TryGetValues("X-RateLimit-Global", out var global))
Global = bool.Parse(global!.First());
}
public bool Global { get; init; }
public int? Limit { get; init; }
public int? Remaining { get; init; }
public DateTimeOffset? Reset { get; init; }
public TimeSpan? ResetAfter { get; init; }
public string? Bucket { get; init; }
public DateTimeOffset? ServerDate { get; init; }
public bool HasRatelimitInfo =>
Limit != null && Remaining != null && Reset != null && ResetAfter != null && Bucket != null;
}
}

View File

@ -0,0 +1,86 @@
using System;
using Myriad.Rest.Exceptions;
using Serilog;
namespace Myriad.Rest.Ratelimit
{
public class Ratelimiter: IDisposable
{
private readonly BucketManager _buckets;
private readonly ILogger _logger;
private DateTimeOffset? _globalRateLimitExpiry;
public Ratelimiter(ILogger logger)
{
_logger = logger.ForContext<Ratelimiter>();
_buckets = new BucketManager(logger);
}
public void Dispose()
{
_buckets.Dispose();
}
public void AllowRequestOrThrow(string endpoint, ulong major, DateTimeOffset now)
{
if (IsGloballyRateLimited(now))
{
_logger.Warning("Globally rate limited until {GlobalRateLimitExpiry}, cancelling request",
_globalRateLimitExpiry);
throw new GloballyRatelimitedException();
}
var bucket = _buckets.GetBucket(endpoint, major);
if (bucket == null)
{
// No rate limit for this endpoint (yet), allow through
_logger.Debug("No rate limit data for endpoint {Endpoint}, allowing through", endpoint);
return;
}
bucket.Tick(now);
if (bucket.TryAcquire())
// We're allowed to send it! :)
return;
// We can't send this request right now; retrying...
var waitTime = bucket.GetResetDelay(now);
// add a small buffer for Timing:tm:
waitTime += TimeSpan.FromMilliseconds(50);
// (this is caught by a WaitAndRetry Polly handler, if configured)
throw new RatelimitBucketExhaustedException(bucket, waitTime);
}
public void HandleResponse(RatelimitHeaders headers, string endpoint, ulong major)
{
if (!headers.HasRatelimitInfo)
return;
// TODO: properly calculate server time?
if (headers.Global)
{
_logger.Warning(
"Global rate limit hit, resetting at {GlobalRateLimitExpiry} (in {GlobalRateLimitResetAfter}!",
_globalRateLimitExpiry, headers.ResetAfter);
_globalRateLimitExpiry = headers.Reset;
}
else
{
// Update buckets first, then get it again, to properly "transfer" this info over to the new value
_buckets.UpdateEndpointInfo(endpoint, headers.Bucket!, headers.Limit);
var bucket = _buckets.GetBucket(endpoint, major);
bucket?.HandleResponse(headers);
}
}
private bool IsGloballyRateLimited(DateTimeOffset now) =>
_globalRateLimitExpiry > now;
}
}

View File

@ -0,0 +1,19 @@
using System.Collections.Generic;
namespace Myriad.Rest.Types
{
public record AllowedMentions
{
public enum ParseType
{
Roles,
Users,
Everyone
}
public List<ParseType>? Parse { get; set; }
public List<ulong>? Users { get; set; }
public List<ulong>? Roles { get; set; }
public bool RepliedUser { get; set; }
}
}

View File

@ -0,0 +1,6 @@
using System.IO;
namespace Myriad.Rest.Types
{
public record MultipartFile(string Filename, Stream Data);
}

View File

@ -0,0 +1,13 @@
using System.Collections.Generic;
using Myriad.Types;
namespace Myriad.Rest.Types
{
public record ApplicationCommandRequest
{
public string Name { get; init; }
public string Description { get; init; }
public List<ApplicationCommandOption>? Options { get; init; }
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Rest.Types.Requests
{
public record CreateWebhookRequest(string Name);
}

View File

@ -0,0 +1,13 @@
using Myriad.Types;
namespace Myriad.Rest.Types.Requests
{
public record ExecuteWebhookRequest
{
public string? Content { get; init; }
public string? Username { get; init; }
public string? AvatarUrl { get; init; }
public Embed[] Embeds { get; init; }
public AllowedMentions? AllowedMentions { get; init; }
}
}

View File

@ -0,0 +1,10 @@
using Myriad.Types;
namespace Myriad.Rest.Types.Requests
{
public record MessageEditRequest
{
public string? Content { get; set; }
public Embed? Embed { get; set; }
}
}

View File

@ -0,0 +1,13 @@
using Myriad.Types;
namespace Myriad.Rest.Types.Requests
{
public record MessageRequest
{
public string? Content { get; set; }
public object? Nonce { get; set; }
public bool Tts { get; set; }
public AllowedMentions AllowedMentions { get; set; }
public Embed? Embeds { get; set; }
}
}

View File

@ -0,0 +1,7 @@
namespace Myriad.Rest.Types
{
public record ModifyGuildMemberRequest
{
public string? Nick { get; init; }
}
}

View File

@ -0,0 +1,20 @@
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Myriad.Serialization
{
public static class JsonSerializerOptionsExtensions
{
public static JsonSerializerOptions ConfigureForNewcord(this JsonSerializerOptions opts)
{
opts.PropertyNamingPolicy = new JsonSnakeCaseNamingPolicy();
opts.NumberHandling = JsonNumberHandling.AllowReadingFromString;
opts.IncludeFields = true;
opts.Converters.Add(new PermissionSetJsonConverter());
opts.Converters.Add(new ShardInfoJsonConverter());
return opts;
}
}
}

View File

@ -0,0 +1,88 @@
using System;
using System.Text;
using System.Text.Json;
namespace Myriad.Serialization
{
// From https://github.com/J0rgeSerran0/JsonNamingPolicy/blob/master/JsonSnakeCaseNamingPolicy.cs, no NuGet :/
public class JsonSnakeCaseNamingPolicy: JsonNamingPolicy
{
private readonly string _separator = "_";
public override string ConvertName(string name)
{
if (string.IsNullOrEmpty(name) || string.IsNullOrWhiteSpace(name)) return string.Empty;
ReadOnlySpan<char> spanName = name.Trim();
var stringBuilder = new StringBuilder();
var addCharacter = true;
var isPreviousSpace = false;
var isPreviousSeparator = false;
var isCurrentSpace = false;
var isNextLower = false;
var isNextUpper = false;
var isNextSpace = false;
for (var position = 0; position < spanName.Length; position++)
{
if (position != 0)
{
isCurrentSpace = spanName[position] == 32;
isPreviousSpace = spanName[position - 1] == 32;
isPreviousSeparator = spanName[position - 1] == 95;
if (position + 1 != spanName.Length)
{
isNextLower = spanName[position + 1] > 96 && spanName[position + 1] < 123;
isNextUpper = spanName[position + 1] > 64 && spanName[position + 1] < 91;
isNextSpace = spanName[position + 1] == 32;
}
if (isCurrentSpace &&
(isPreviousSpace ||
isPreviousSeparator ||
isNextUpper ||
isNextSpace))
{
addCharacter = false;
}
else
{
var isCurrentUpper = spanName[position] > 64 && spanName[position] < 91;
var isPreviousLower = spanName[position - 1] > 96 && spanName[position - 1] < 123;
var isPreviousNumber = spanName[position - 1] > 47 && spanName[position - 1] < 58;
if (isCurrentUpper &&
(isPreviousLower ||
isPreviousNumber ||
isNextLower ||
isNextSpace ||
isNextLower && !isPreviousSpace))
{
stringBuilder.Append(_separator);
}
else
{
if (isCurrentSpace &&
!isPreviousSpace &&
!isNextSpace)
{
stringBuilder.Append(_separator);
addCharacter = false;
}
}
}
}
if (addCharacter)
stringBuilder.Append(spanName[position]);
else
addCharacter = true;
}
return stringBuilder.ToString().ToLower();
}
}
}

View File

@ -0,0 +1,22 @@
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Myriad.Serialization
{
public class JsonStringConverter: JsonConverter<object>
{
public override object? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var str = JsonSerializer.Deserialize<string>(ref reader);
var inner = JsonSerializer.Deserialize(str!, typeToConvert, options);
return inner;
}
public override void Write(Utf8JsonWriter writer, object value, JsonSerializerOptions options)
{
var inner = JsonSerializer.Serialize(value, options);
writer.WriteStringValue(inner);
}
}
}

View File

@ -0,0 +1,24 @@
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
using Myriad.Types;
namespace Myriad.Serialization
{
public class PermissionSetJsonConverter: JsonConverter<PermissionSet>
{
public override PermissionSet Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var str = reader.GetString();
if (str == null) return default;
return (PermissionSet) ulong.Parse(str);
}
public override void Write(Utf8JsonWriter writer, PermissionSet value, JsonSerializerOptions options)
{
writer.WriteStringValue(((ulong) value).ToString());
}
}
}

View File

@ -0,0 +1,28 @@
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
using Myriad.Gateway;
namespace Myriad.Serialization
{
public class ShardInfoJsonConverter: JsonConverter<ShardInfo>
{
public override ShardInfo? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var arr = JsonSerializer.Deserialize<int[]>(ref reader);
if (arr?.Length != 2)
throw new JsonException("Expected shard info as array of length 2");
return new ShardInfo(arr[0], arr[1]);
}
public override void Write(Utf8JsonWriter writer, ShardInfo value, JsonSerializerOptions options)
{
writer.WriteStartArray();
writer.WriteNumberValue(value.ShardId);
writer.WriteNumberValue(value.NumShards);
writer.WriteEndArray();
}
}
}

22
Myriad/Types/Activity.cs Normal file
View File

@ -0,0 +1,22 @@
namespace Myriad.Types
{
public record Activity: ActivityPartial
{
}
public record ActivityPartial
{
public string Name { get; init; }
public ActivityType Type { get; init; }
public string? Url { get; init; }
}
public enum ActivityType
{
Game = 0,
Streaming = 1,
Listening = 2,
Custom = 4,
Competing = 5
}
}

View File

@ -0,0 +1,27 @@
using System.Collections.Generic;
namespace Myriad.Types
{
public record Application: ApplicationPartial
{
public string Name { get; init; }
public string? Icon { get; init; }
public string Description { get; init; }
public string[]? RpcOrigins { get; init; }
public bool BotPublic { get; init; }
public bool BotRequireCodeGrant { get; init; }
public User Owner { get; init; } // TODO: docs specify this is "partial", what does that mean
public string Summary { get; init; }
public string VerifyKey { get; init; }
public ulong? GuildId { get; init; }
public ulong? PrimarySkuId { get; init; }
public string? Slug { get; init; }
public string? CoverImage { get; init; }
}
public record ApplicationPartial
{
public ulong Id { get; init; }
public int Flags { get; init; }
}
}

View File

@ -0,0 +1,13 @@
using System.Collections.Generic;
namespace Myriad.Types
{
public record ApplicationCommand
{
public ulong Id { get; init; }
public ulong ApplicationId { get; init; }
public string Name { get; init; }
public string Description { get; init; }
public ApplicationCommandOption[]? Options { get; init; }
}
}

View File

@ -0,0 +1,9 @@
namespace Myriad.Types
{
public record ApplicationCommandInteractionData
{
public ulong Id { get; init; }
public string Name { get; init; }
public ApplicationCommandInteractionDataOption[] Options { get; init; }
}
}

View File

@ -0,0 +1,9 @@
namespace Myriad.Types
{
public record ApplicationCommandInteractionDataOption
{
public string Name { get; init; }
public object? Value { get; init; }
public ApplicationCommandInteractionDataOption[]? Options { get; init; }
}
}

View File

@ -0,0 +1,24 @@
namespace Myriad.Types
{
public record ApplicationCommandOption(ApplicationCommandOption.OptionType Type, string Name, string Description)
{
public enum OptionType
{
Subcommand = 1,
SubcommandGroup = 2,
String = 3,
Integer = 4,
Boolean = 5,
User = 6,
Channel = 7,
Role = 8
}
public bool Default { get; init; }
public bool Required { get; init; }
public Choice[]? Choices { get; init; }
public ApplicationCommandOption[]? Options { get; init; }
public record Choice(string Name, object Value);
}
}

View File

@ -0,0 +1,19 @@
namespace Myriad.Types
{
public record Interaction
{
public enum InteractionType
{
Ping = 1,
ApplicationCommand = 2
}
public ulong Id { get; init; }
public InteractionType Type { get; init; }
public ApplicationCommandInteractionData? Data { get; init; }
public ulong GuildId { get; init; }
public ulong ChannelId { get; init; }
public GuildMember Member { get; init; }
public string Token { get; init; }
}
}

View File

@ -0,0 +1,15 @@
using System.Collections.Generic;
using Myriad.Rest.Types;
namespace Myriad.Types
{
public record InteractionApplicationCommandCallbackData
{
public bool? Tts { get; init; }
public string Content { get; init; }
public Embed[]? Embeds { get; init; }
public AllowedMentions? AllowedMentions { get; init; }
public Message.MessageFlags Flags { get; init; }
}
}

View File

@ -0,0 +1,17 @@
namespace Myriad.Types
{
public record InteractionResponse
{
public enum ResponseType
{
Pong = 1,
Acknowledge = 2,
ChannelMessage = 3,
ChannelMessageWithSource = 4,
AckWithSource = 5
}
public ResponseType Type { get; init; }
public InteractionApplicationCommandCallbackData? Data { get; init; }
}
}

40
Myriad/Types/Channel.cs Normal file
View File

@ -0,0 +1,40 @@
namespace Myriad.Types
{
public record Channel
{
public enum ChannelType
{
GuildText = 0,
Dm = 1,
GuildVoice = 2,
GroupDm = 3,
GuildCategory = 4,
GuildNews = 5,
GuildStore = 6
}
public ulong Id { get; init; }
public ChannelType Type { get; init; }
public ulong? GuildId { get; init; }
public int? Position { get; init; }
public string? Name { get; init; }
public string? Topic { get; init; }
public bool? Nsfw { get; init; }
public long? ParentId { get; init; }
public Overwrite[]? PermissionOverwrites { get; init; }
public record Overwrite
{
public ulong Id { get; init; }
public OverwriteType Type { get; init; }
public PermissionSet Allow { get; init; }
public PermissionSet Deny { get; init; }
}
public enum OverwriteType
{
Role = 0,
Member = 1
}
}
}

64
Myriad/Types/Embed.cs Normal file
View File

@ -0,0 +1,64 @@
using System.Collections.Generic;
namespace Myriad.Types
{
public record Embed
{
public string? Title { get; init; }
public string? Type { get; init; }
public string? Description { get; init; }
public string? Url { get; init; }
public string? Timestamp { get; init; }
public uint? Color { get; init; }
public EmbedFooter? Footer { get; init; }
public EmbedImage? Image { get; init; }
public EmbedThumbnail? Thumbnail { get; init; }
public EmbedVideo? Video { get; init; }
public EmbedProvider? Provider { get; init; }
public EmbedAuthor? Author { get; init; }
public Field[]? Fields { get; init; }
public record EmbedFooter (
string Text,
string? IconUrl = null,
string? ProxyIconUrl = null
);
public record EmbedImage (
string? Url,
uint? Width = null,
uint? Height = null
);
public record EmbedThumbnail (
string? Url,
string? ProxyUrl = null,
uint? Width = null,
uint? Height = null
);
public record EmbedVideo (
string? Url,
uint? Width = null,
uint? Height = null
);
public record EmbedProvider (
string? Name,
string? Url
);
public record EmbedAuthor (
string? Name = null,
string? Url = null,
string? IconUrl = null,
string? ProxyIconUrl = null
);
public record Field (
string Name,
string Value,
bool Inline = false
);
}
}

9
Myriad/Types/Emoji.cs Normal file
View File

@ -0,0 +1,9 @@
namespace Myriad.Types
{
public record Emoji
{
public ulong? Id { get; init; }
public string? Name { get; init; }
public bool? Animated { get; init; }
}
}

View File

@ -0,0 +1,13 @@
namespace Myriad.Types
{
public record GatewayInfo
{
public string Url { get; init; }
public record Bot: GatewayInfo
{
public int Shards { get; init; }
public SessionStartLimit SessionStartLimit { get; init; }
}
}
}

View File

@ -0,0 +1,9 @@
namespace Myriad.Types
{
public record SessionStartLimit
{
public int Total { get; init; }
public int Remaining { get; init; }
public int ResetAfter { get; init; }
}
}

24
Myriad/Types/Guild.cs Normal file
View File

@ -0,0 +1,24 @@
using System.Collections.Generic;
namespace Myriad.Types
{
public record Guild
{
public ulong Id { get; init; }
public string Name { get; init; }
public string? Icon { get; init; }
public string? Splash { get; init; }
public string? DiscoverySplash { get; init; }
public bool? Owner { get; init; }
public ulong OwnerId { get; init; }
public string Region { get; init; }
public ulong? AfkChannelId { get; init; }
public int AfkTimeout { get; init; }
public bool? WidgetEnabled { get; init; }
public bool? WidgetChannelId { get; init; }
public int VerificationLevel { get; init; }
public Role[] Roles { get; init; }
public string[] Features { get; init; }
}
}

View File

@ -0,0 +1,14 @@
namespace Myriad.Types
{
public record GuildMember: GuildMemberPartial
{
public User User { get; init; }
}
public record GuildMemberPartial
{
public string Nick { get; init; }
public ulong[] Roles { get; init; }
public string JoinedAt { get; init; }
}
}

85
Myriad/Types/Message.cs Normal file
View File

@ -0,0 +1,85 @@
using System;
using System.Collections.Generic;
using System.Net.Mail;
namespace Myriad.Types
{
public record Message
{
[Flags]
public enum MessageFlags
{
Crossposted = 1 << 0,
IsCrosspost = 1 << 1,
SuppressEmbeds = 1 << 2,
SourceMessageDeleted = 1 << 3,
Urgent = 1 << 4,
Ephemeral = 1 << 6
}
public enum MessageType
{
Default = 0,
RecipientAdd = 1,
RecipientRemove = 2,
Call = 3,
ChannelNameChange = 4,
ChannelIconChange = 5,
ChannelPinnedMessage = 6,
GuildMemberJoin = 7,
UserPremiumGuildSubscription = 8,
UserPremiumGuildSubscriptionTier1 = 9,
UserPremiumGuildSubscriptionTier2 = 10,
UserPremiumGuildSubscriptionTier3 = 11,
ChannelFollowAdd = 12,
GuildDiscoveryDisqualified = 14,
GuildDiscoveryRequalified = 15,
Reply = 19,
ApplicationCommand = 20
}
public ulong Id { get; init; }
public ulong ChannelId { get; init; }
public ulong? GuildId { get; init; }
public User Author { get; init; }
public string? Content { get; init; }
public string? Timestamp { get; init; }
public string? EditedTimestamp { get; init; }
public bool Tts { get; init; }
public bool MentionEveryone { get; init; }
public User.Extra[] Mentions { get; init; }
public ulong[] MentionRoles { get; init; }
public Attachment[] Attachments { get; init; }
public Embed[] Embeds { get; init; }
public Reaction[] Reactions { get; init; }
public bool Pinned { get; init; }
public ulong? WebhookId { get; init; }
public MessageType Type { get; init; }
public Reference? MessageReference { get; set; }
public MessageFlags Flags { get; init; }
// todo: null vs. absence
public Message? ReferencedMessage { get; init; }
public record Reference(ulong? GuildId, ulong? ChannelId, ulong? MessageId);
public record Attachment
{
public ulong Id { get; init; }
public string Filename { get; init; }
public int Size { get; init; }
public string Url { get; init; }
public string ProxyUrl { get; init; }
public int? Width { get; init; }
public int? Height { get; init; }
}
public record Reaction
{
public int Count { get; init; }
public bool Me { get; init; }
public Emoji Emoji { get; init; }
}
}
}

View File

@ -0,0 +1,47 @@
using System;
namespace Myriad.Types
{
[Flags]
public enum PermissionSet: ulong
{
CreateInvite = 0x1,
KickMembers = 0x2,
BanMembers = 0x4,
Administrator = 0x8,
ManageChannels = 0x10,
ManageGuild = 0x20,
AddReactions = 0x40,
ViewAuditLog = 0x80,
PrioritySpeaker = 0x100,
Stream = 0x200,
ViewChannel = 0x400,
SendMessages = 0x800,
SendTtsMessages = 0x1000,
ManageMessages = 0x2000,
EmbedLinks = 0x4000,
AttachFiles = 0x8000,
ReadMessageHistory = 0x10000,
MentionEveryone = 0x20000,
UseExternalEmojis = 0x40000,
ViewGuildInsights = 0x80000,
Connect = 0x100000,
Speak = 0x200000,
MuteMembers = 0x400000,
DeafenMembers = 0x800000,
MoveMembers = 0x1000000,
UseVad = 0x2000000,
ChangeNickname = 0x4000000,
ManageNicknames = 0x8000000,
ManageRoles = 0x10000000,
ManageWebhooks = 0x20000000,
ManageEmojis = 0x40000000,
// Special:
None = 0,
All = 0x7FFFFFFF,
Dm = ViewChannel | SendMessages | ReadMessageHistory | AddReactions | AttachFiles | EmbedLinks |
UseExternalEmojis | Connect | Speak | UseVad
}
}

View File

@ -0,0 +1,6 @@
namespace Myriad.Types
{
public static class Permissions
{
}
}

14
Myriad/Types/Role.cs Normal file
View File

@ -0,0 +1,14 @@
namespace Myriad.Types
{
public record Role
{
public ulong Id { get; init; }
public string Name { get; init; }
public uint Color { get; init; }
public bool Hoist { get; init; }
public int Position { get; init; }
public PermissionSet Permissions { get; init; }
public bool Managed { get; init; }
public bool Mentionable { get; init; }
}
}

38
Myriad/Types/User.cs Normal file
View File

@ -0,0 +1,38 @@
using System;
namespace Myriad.Types
{
public record User
{
[Flags]
public enum Flags
{
DiscordEmployee = 1 << 0,
PartneredServerOwner = 1 << 1,
HypeSquadEvents = 1 << 2,
BugHunterLevel1 = 1 << 3,
HouseBravery = 1 << 6,
HouseBrilliance = 1 << 7,
HouseBalance = 1 << 8,
EarlySupporter = 1 << 9,
TeamUser = 1 << 10,
System = 1 << 12,
BugHunterLevel2 = 1 << 14,
VerifiedBot = 1 << 16,
EarlyVerifiedBotDeveloper = 1 << 17
}
public ulong Id { get; init; }
public string Username { get; init; }
public string Discriminator { get; init; }
public string? Avatar { get; init; }
public bool Bot { get; init; }
public bool? System { get; init; }
public Flags PublicFlags { get; init; }
public record Extra: User
{
public GuildMemberPartial? Member { get; init; }
}
}
}

21
Myriad/Types/Webhook.cs Normal file
View File

@ -0,0 +1,21 @@
namespace Myriad.Types
{
public record Webhook
{
public ulong Id { get; init; }
public WebhookType Type { get; init; }
public ulong? GuildId { get; init; }
public ulong ChannelId { get; init; }
public User? User { get; init; }
public string? Name { get; init; }
public string? Avatar { get; init; }
public string? Token { get; init; }
public ulong? ApplicationId { get; init; }
}
public enum WebhookType
{
Incoming = 1,
ChannelFollower = 2
}
}

View File

@ -1,4 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net.WebSockets;
@ -9,10 +10,10 @@ using App.Metrics;
using Autofac;
using DSharpPlus;
using DSharpPlus.Entities;
using DSharpPlus.EventArgs;
using DSharpPlus.Exceptions;
using Myriad.Cache;
using Myriad.Gateway;
using Myriad.Rest;
using Myriad.Types;
using NodaTime;
@ -27,47 +28,38 @@ namespace PluralKit.Bot
{
public class Bot
{
private readonly DiscordShardedClient _client;
private readonly ConcurrentDictionary<ulong, GuildMemberPartial> _guildMembers = new();
private readonly Cluster _cluster;
private readonly DiscordApiClient _rest;
private readonly ILogger _logger;
private readonly ILifetimeScope _services;
private readonly PeriodicStatCollector _collector;
private readonly IMetrics _metrics;
private readonly ErrorMessageService _errorMessageService;
private readonly CommandMessageService _commandMessageService;
private readonly IDiscordCache _cache;
private bool _hasReceivedReady = false;
private Timer _periodicTask; // Never read, just kept here for GC reasons
public Bot(DiscordShardedClient client, ILifetimeScope services, ILogger logger, PeriodicStatCollector collector, IMetrics metrics,
ErrorMessageService errorMessageService, CommandMessageService commandMessageService)
public Bot(ILifetimeScope services, ILogger logger, PeriodicStatCollector collector, IMetrics metrics,
ErrorMessageService errorMessageService, CommandMessageService commandMessageService, Cluster cluster, DiscordApiClient rest, IDiscordCache cache)
{
_client = client;
_logger = logger.ForContext<Bot>();
_services = services;
_collector = collector;
_metrics = metrics;
_errorMessageService = errorMessageService;
_commandMessageService = commandMessageService;
_cluster = cluster;
_rest = rest;
_cache = cache;
}
public void Init()
{
// HandleEvent takes a type parameter, automatically inferred by the event type
// It will then look up an IEventHandler<TypeOfEvent> in the DI container and call that object's handler method
// For registering new ones, see Modules.cs
_client.MessageCreated += HandleEvent;
_client.MessageDeleted += HandleEvent;
_client.MessageUpdated += HandleEvent;
_client.MessagesBulkDeleted += HandleEvent;
_client.MessageReactionAdded += HandleEvent;
// Update shard status for shards immediately on connect
_client.Ready += (client, _) =>
{
_hasReceivedReady = true;
return UpdateBotStatus(client);
};
_client.Resumed += (client, _) => UpdateBotStatus(client);
_cluster.EventReceived += OnEventReceived;
// Init the shard stuff
_services.Resolve<ShardInfoService>().Init();
@ -83,6 +75,58 @@ namespace PluralKit.Bot
}, null, timeTillNextWholeMinute, TimeSpan.FromMinutes(1));
}
public GuildMemberPartial? BotMemberIn(ulong guildId) => _guildMembers.GetValueOrDefault(guildId);
private async Task OnEventReceived(Shard shard, IGatewayEvent evt)
{
await _cache.HandleGatewayEvent(evt);
TryUpdateSelfMember(shard, evt);
// HandleEvent takes a type parameter, automatically inferred by the event type
// It will then look up an IEventHandler<TypeOfEvent> in the DI container and call that object's handler method
// For registering new ones, see Modules.cs
if (evt is MessageCreateEvent mc)
await HandleEvent(shard, mc);
if (evt is MessageUpdateEvent mu)
await HandleEvent(shard, mu);
if (evt is MessageDeleteEvent md)
await HandleEvent(shard, md);
if (evt is MessageDeleteBulkEvent mdb)
await HandleEvent(shard, mdb);
if (evt is MessageReactionAddEvent mra)
await HandleEvent(shard, mra);
// Update shard status for shards immediately on connect
if (evt is ReadyEvent re)
await HandleReady(shard, re);
if (evt is ResumedEvent)
await HandleResumed(shard);
}
private void TryUpdateSelfMember(Shard shard, IGatewayEvent evt)
{
if (evt is GuildCreateEvent gc)
_guildMembers[gc.Id] = gc.Members.FirstOrDefault(m => m.User.Id == shard.User?.Id);
if (evt is MessageCreateEvent mc && mc.Member != null && mc.Author.Id == shard.User?.Id)
_guildMembers[mc.GuildId!.Value] = mc.Member;
if (evt is GuildMemberAddEvent gma && gma.User.Id == shard.User?.Id)
_guildMembers[gma.GuildId] = gma;
if (evt is GuildMemberUpdateEvent gmu && gmu.User.Id == shard.User?.Id)
_guildMembers[gmu.GuildId] = gmu;
}
private Task HandleResumed(Shard shard)
{
return UpdateBotStatus(shard);
}
private Task HandleReady(Shard shard, ReadyEvent _)
{
_hasReceivedReady = true;
return UpdateBotStatus(shard);
}
public async Task Shutdown()
{
// This will stop the timer and prevent any subsequent invocations
@ -92,10 +136,24 @@ namespace PluralKit.Bot
// We're not actually properly disconnecting from the gateway (lol) so it'll linger for a few minutes
// Should be plenty of time for the bot to connect again next startup and set the real status
if (_hasReceivedReady)
await _client.UpdateStatusAsync(new DiscordActivity("Restarting... (please wait)"), UserStatus.Idle);
{
await Task.WhenAll(_cluster.Shards.Values.Select(shard =>
shard.UpdateStatus(new GatewayStatusUpdate
{
Activities = new[]
{
new ActivityPartial
{
Name = "Restarting... (please wait)",
Type = ActivityType.Game
}
},
Status = GatewayStatusUpdate.UserStatus.Idle
})));
}
}
private Task HandleEvent<T>(DiscordClient shard, T evt) where T: DiscordEventArgs
private Task HandleEvent<T>(Shard shard, T evt) where T: IGatewayEvent
{
// We don't want to stall the event pipeline, so we'll "fork" inside here
var _ = HandleEventInner();
@ -121,7 +179,7 @@ namespace PluralKit.Bot
try
{
using var timer = _metrics.Measure.Timer.Time(BotMetrics.EventsHandled,
new MetricTags("event", typeof(T).Name.Replace("EventArgs", "")));
new MetricTags("event", typeof(T).Name.Replace("Event", "")));
// Delegate to the queue to see if it wants to handle this event
// the TryHandle call returns true if it's handled the event
@ -131,13 +189,13 @@ namespace PluralKit.Bot
}
catch (Exception exc)
{
await HandleError(handler, evt, serviceScope, exc);
await HandleError(shard, handler, evt, serviceScope, exc);
}
}
}
private async Task HandleError<T>(IEventHandler<T> handler, T evt, ILifetimeScope serviceScope, Exception exc)
where T: DiscordEventArgs
private async Task HandleError<T>(Shard shard, IEventHandler<T> handler, T evt, ILifetimeScope serviceScope, Exception exc)
where T: IGatewayEvent
{
_metrics.Measure.Meter.Mark(BotMetrics.BotErrors, exc.GetType().FullName);
@ -149,7 +207,7 @@ namespace PluralKit.Bot
.Error(exc, "Exception in event handler: {SentryEventId}", sentryEvent.EventId);
// If the event is us responding to our own error messages, don't bother logging
if (evt is MessageCreateEventArgs mc && mc.Author.Id == _client.CurrentUser.Id)
if (evt is MessageCreateEvent mc && mc.Author.Id == shard.User?.Id)
return;
var shouldReport = exc.IsOurProblem();
@ -160,19 +218,21 @@ namespace PluralKit.Bot
var sentryScope = serviceScope.Resolve<Scope>();
// Add some specific info about Discord error responses, as a breadcrumb
if (exc is BadRequestException bre)
sentryScope.AddBreadcrumb(bre.WebResponse.Response, "response.error", data: new Dictionary<string, string>(bre.WebResponse.Headers));
if (exc is NotFoundException nfe)
sentryScope.AddBreadcrumb(nfe.WebResponse.Response, "response.error", data: new Dictionary<string, string>(nfe.WebResponse.Headers));
if (exc is UnauthorizedException ue)
sentryScope.AddBreadcrumb(ue.WebResponse.Response, "response.error", data: new Dictionary<string, string>(ue.WebResponse.Headers));
// TODO: headers to dict
// if (exc is BadRequestException bre)
// sentryScope.AddBreadcrumb(bre.Response, "response.error", data: new Dictionary<string, string>(bre.Response.Headers));
// if (exc is NotFoundException nfe)
// sentryScope.AddBreadcrumb(nfe.Response, "response.error", data: new Dictionary<string, string>(nfe.Response.Headers));
// if (exc is UnauthorizedException ue)
// sentryScope.AddBreadcrumb(ue.Response, "response.error", data: new Dictionary<string, string>(ue.Response.Headers));
SentrySdk.CaptureEvent(sentryEvent, sentryScope);
// Once we've sent it to Sentry, report it to the user (if we have permission to)
var reportChannel = handler.ErrorChannelFor(evt);
if (reportChannel != null && reportChannel.BotHasAllPermissions(Permissions.SendMessages | Permissions.EmbedLinks))
await _errorMessageService.SendErrorMessage(reportChannel, sentryEvent.EventId.ToString());
// TODO: ID lookup
// if (reportChannel != null && reportChannel.BotHasAllPermissions(Permissions.SendMessages | Permissions.EmbedLinks))
// await _errorMessageService.SendErrorMessage(reportChannel, sentryEvent.EventId.ToString());
}
}
@ -191,23 +251,38 @@ namespace PluralKit.Bot
_logger.Debug("Submitted metrics to backend");
}
private async Task UpdateBotStatus(DiscordClient specificShard = null)
private async Task UpdateBotStatus(Shard specificShard = null)
{
// If we're not on any shards, don't bother (this happens if the periodic timer fires before the first Ready)
if (!_hasReceivedReady) return;
var totalGuilds = _client.ShardClients.Values.Sum(c => c.Guilds.Count);
var totalGuilds = await _cache.GetAllGuilds().CountAsync();
try // DiscordClient may throw an exception if the socket is closed (e.g just after OP 7 received)
{
Task UpdateStatus(DiscordClient shard) =>
shard.UpdateStatusAsync(new DiscordActivity($"pk;help | in {totalGuilds} servers | shard #{shard.ShardId}"));
Task UpdateStatus(Shard shard) =>
shard.UpdateStatus(new GatewayStatusUpdate
{
Activities = new[]
{
new ActivityPartial
{
Name = $"pk;help | in {totalGuilds} servers | shard #{shard.ShardInfo?.ShardId}",
Type = ActivityType.Game,
Url = "https://pluralkit.me/"
}
}
});
if (specificShard != null)
await UpdateStatus(specificShard);
else // Run shard updates concurrently
await Task.WhenAll(_client.ShardClients.Values.Select(UpdateStatus));
await Task.WhenAll(_cluster.Shards.Values.Select(UpdateStatus));
}
catch (WebSocketException)
{
// TODO: this still thrown?
}
catch (WebSocketException) { }
}
}
}

View File

@ -9,8 +9,14 @@ using Autofac;
using DSharpPlus;
using DSharpPlus.Entities;
using Myriad.Extensions;
using Myriad.Gateway;
using Myriad.Types;
using PluralKit.Core;
using Permissions = DSharpPlus.Permissions;
namespace PluralKit.Bot
{
public class Context
@ -19,10 +25,17 @@ namespace PluralKit.Bot
private readonly DiscordRestClient _rest;
private readonly DiscordShardedClient _client;
private readonly DiscordClient _shard;
private readonly DiscordMessage _message;
private readonly DiscordClient _shard = null;
private readonly Shard _shardNew;
private readonly Guild? _guild;
private readonly Channel _channel;
private readonly DiscordMessage _message = null;
private readonly Message _messageNew;
private readonly Parameters _parameters;
private readonly MessageContext _messageContext;
private readonly GuildMemberPartial? _botMember;
private readonly PermissionSet _botPermissions;
private readonly PermissionSet _userPermissions;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
@ -32,31 +45,47 @@ namespace PluralKit.Bot
private Command _currentCommand;
public Context(ILifetimeScope provider, DiscordClient shard, DiscordMessage message, int commandParseOffset,
PKSystem senderSystem, MessageContext messageContext)
public Context(ILifetimeScope provider, Shard shard, Guild? guild, Channel channel, MessageCreateEvent message, int commandParseOffset,
PKSystem senderSystem, MessageContext messageContext, GuildMemberPartial? botMember)
{
_rest = provider.Resolve<DiscordRestClient>();
_client = provider.Resolve<DiscordShardedClient>();
_message = message;
_shard = shard;
_messageNew = message;
_shardNew = shard;
_guild = guild;
_channel = channel;
_senderSystem = senderSystem;
_messageContext = messageContext;
_botMember = botMember;
_db = provider.Resolve<IDatabase>();
_repo = provider.Resolve<ModelRepository>();
_metrics = provider.Resolve<IMetrics>();
_provider = provider;
_commandMessageService = provider.Resolve<CommandMessageService>();
_parameters = new Parameters(message.Content.Substring(commandParseOffset));
_botPermissions = message.GuildId != null
? PermissionExtensions.PermissionsFor(guild!, channel, shard.User?.Id ?? default, botMember!.Roles)
: PermissionSet.Dm;
_userPermissions = message.GuildId != null
? PermissionExtensions.PermissionsFor(guild!, channel, message.Author.Id, message.Member!.Roles)
: PermissionSet.Dm;
}
public DiscordUser Author => _message.Author;
public DiscordChannel Channel => _message.Channel;
public Channel ChannelNew => _channel;
public DiscordMessage Message => _message;
public Message MessageNew => _messageNew;
public DiscordGuild Guild => _message.Channel.Guild;
public Guild GuildNew => _guild;
public DiscordClient Shard => _shard;
public DiscordShardedClient Client => _client;
public MessageContext MessageContext => _messageContext;
public PermissionSet BotPermissions => _botPermissions;
public PermissionSet UserPermissions => _userPermissions;
public DiscordRestClient Rest => _rest;
public PKSystem System => _senderSystem;

View File

@ -1,15 +1,13 @@
using System.Threading.Tasks;
using DSharpPlus;
using DSharpPlus.Entities;
using DSharpPlus.EventArgs;
using Myriad.Gateway;
namespace PluralKit.Bot
{
public interface IEventHandler<in T> where T: DiscordEventArgs
public interface IEventHandler<in T> where T: IGatewayEvent
{
Task Handle(DiscordClient shard, T evt);
Task Handle(Shard shard, T evt);
DiscordChannel ErrorChannelFor(T evt) => null;
ulong? ErrorChannelFor(T evt) => null;
}
}

View File

@ -5,18 +5,22 @@ using App.Metrics;
using Autofac;
using DSharpPlus;
using DSharpPlus.Entities;
using DSharpPlus.EventArgs;
using Myriad.Cache;
using Myriad.Extensions;
using Myriad.Gateway;
using Myriad.Rest;
using Myriad.Rest.Types.Requests;
using Myriad.Types;
using PluralKit.Core;
namespace PluralKit.Bot
{
public class MessageCreated: IEventHandler<MessageCreateEventArgs>
public class MessageCreated: IEventHandler<MessageCreateEvent>
{
private readonly Bot _bot;
private readonly CommandTree _tree;
private readonly DiscordShardedClient _client;
private readonly IDiscordCache _cache;
private readonly LastMessageCacheService _lastMessageCache;
private readonly LoggerCleanService _loggerClean;
private readonly IMetrics _metrics;
@ -25,73 +29,81 @@ namespace PluralKit.Bot
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly BotConfig _config;
private readonly DiscordApiClient _rest;
public MessageCreated(LastMessageCacheService lastMessageCache, LoggerCleanService loggerClean,
IMetrics metrics, ProxyService proxy, DiscordShardedClient client,
CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config, ModelRepository repo)
IMetrics metrics, ProxyService proxy,
CommandTree tree, ILifetimeScope services, IDatabase db, BotConfig config, ModelRepository repo, IDiscordCache cache, Bot bot, DiscordApiClient rest)
{
_lastMessageCache = lastMessageCache;
_loggerClean = loggerClean;
_metrics = metrics;
_proxy = proxy;
_client = client;
_tree = tree;
_services = services;
_db = db;
_config = config;
_repo = repo;
_cache = cache;
_bot = bot;
_rest = rest;
}
public DiscordChannel ErrorChannelFor(MessageCreateEventArgs evt) => evt.Channel;
public ulong? ErrorChannelFor(MessageCreateEvent evt) => evt.ChannelId;
private bool IsDuplicateMessage(DiscordMessage evt) =>
private bool IsDuplicateMessage(Message msg) =>
// We consider a message duplicate if it has the same ID as the previous message that hit the gateway
_lastMessageCache.GetLastMessage(evt.ChannelId) == evt.Id;
_lastMessageCache.GetLastMessage(msg.ChannelId) == msg.Id;
public async Task Handle(DiscordClient shard, MessageCreateEventArgs evt)
public async Task Handle(Shard shard, MessageCreateEvent evt)
{
if (evt.Author?.Id == _client.CurrentUser?.Id) return;
if (evt.Message.MessageType != MessageType.Default) return;
if (IsDuplicateMessage(evt.Message)) return;
if (evt.Author.Id == shard.User?.Id) return;
if (evt.Type != Message.MessageType.Default) return;
if (IsDuplicateMessage(evt)) return;
var guild = evt.GuildId != null ? await _cache.GetGuild(evt.GuildId.Value) : null;
var channel = await _cache.GetChannel(evt.ChannelId);
// Log metrics and message info
_metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived);
_lastMessageCache.AddMessage(evt.Channel.Id, evt.Message.Id);
_lastMessageCache.AddMessage(evt.ChannelId, evt.Id);
// Get message context from DB (tracking w/ metrics)
MessageContext ctx;
await using (var conn = await _db.Obtain())
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id);
ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.GuildId ?? default, evt.ChannelId);
// Try each handler until we find one that succeeds
if (await TryHandleLogClean(evt, ctx))
return;
// Only do command/proxy handling if it's a user account
if (evt.Message.Author.IsBot || evt.Message.WebhookMessage || evt.Message.Author.IsSystem == true)
if (evt.Author.Bot || evt.WebhookId != null || evt.Author.System == true)
return;
if (await TryHandleCommand(shard, evt, ctx))
if (await TryHandleCommand(shard, evt, guild, channel, ctx))
return;
await TryHandleProxy(shard, evt, ctx);
await TryHandleProxy(shard, evt, guild, channel, ctx);
}
private async ValueTask<bool> TryHandleLogClean(MessageCreateEventArgs evt, MessageContext ctx)
private async ValueTask<bool> TryHandleLogClean(MessageCreateEvent evt, MessageContext ctx)
{
if (!evt.Message.Author.IsBot || evt.Message.Channel.Type != ChannelType.Text ||
var channel = await _cache.GetChannel(evt.ChannelId);
if (!evt.Author.Bot || channel!.Type != Channel.ChannelType.GuildText ||
!ctx.LogCleanupEnabled) return false;
await _loggerClean.HandleLoggerBotCleanup(evt.Message);
await _loggerClean.HandleLoggerBotCleanup(evt);
return true;
}
private async ValueTask<bool> TryHandleCommand(DiscordClient shard, MessageCreateEventArgs evt, MessageContext ctx)
private async ValueTask<bool> TryHandleCommand(Shard shard, MessageCreateEvent evt, Guild? guild, Channel channel, MessageContext ctx)
{
var content = evt.Message.Content;
var content = evt.Content;
if (content == null) return false;
// Check for command prefix
if (!HasCommandPrefix(content, out var cmdStart))
if (!HasCommandPrefix(content, shard.User?.Id ?? default, out var cmdStart))
return false;
// Trim leading whitespace from command without actually modifying the string
@ -102,7 +114,7 @@ namespace PluralKit.Bot
try
{
var system = ctx.SystemId != null ? await _db.Execute(c => _repo.GetSystem(c, ctx.SystemId.Value)) : null;
await _tree.ExecuteCommand(new Context(_services, shard, evt.Message, cmdStart, system, ctx));
await _tree.ExecuteCommand(new Context(_services, shard, guild, channel, evt, cmdStart, system, ctx, _bot.BotMemberIn(channel.GuildId!.Value)));
}
catch (PKError)
{
@ -113,7 +125,7 @@ namespace PluralKit.Bot
return true;
}
private bool HasCommandPrefix(string message, out int argPos)
private bool HasCommandPrefix(string message, ulong currentUserId, out int argPos)
{
// First, try prefixes defined in the config
var prefixes = _config.Prefixes ?? BotConfig.DefaultPrefixes;
@ -128,23 +140,28 @@ namespace PluralKit.Bot
// Then, check mention prefix (must be the bot user, ofc)
argPos = -1;
if (DiscordUtils.HasMentionPrefix(message, ref argPos, out var id))
return id == _client.CurrentUser.Id;
return id == currentUserId;
return false;
}
private async ValueTask<bool> TryHandleProxy(DiscordClient shard, MessageCreateEventArgs evt, MessageContext ctx)
private async ValueTask<bool> TryHandleProxy(Shard shard, MessageCreateEvent evt, Guild guild, Channel channel, MessageContext ctx)
{
var botMember = _bot.BotMemberIn(channel.GuildId!.Value);
var botPermissions = PermissionExtensions.PermissionsFor(guild, channel, shard.User!.Id, botMember!.Roles);
try
{
return await _proxy.HandleIncomingMessage(shard, evt.Message, ctx, allowAutoproxy: ctx.AllowAutoproxy);
return await _proxy.HandleIncomingMessage(shard, evt, ctx, guild, channel, allowAutoproxy: ctx.AllowAutoproxy, botPermissions);
}
catch (PKError e)
{
// User-facing errors, print to the channel properly formatted
var msg = evt.Message;
if (msg.Channel.Guild == null || msg.Channel.BotHasAllPermissions(Permissions.SendMessages))
await msg.Channel.SendMessageFixedAsync($"{Emojis.Error} {e.Message}");
if (botPermissions.HasFlag(PermissionSet.SendMessages))
{
await _rest.CreateMessage(evt.ChannelId,
new MessageRequest {Content = $"{Emojis.Error} {e.Message}"});
}
}
return false;

View File

@ -1,9 +1,7 @@
using System;
using System.Linq;
using System.Threading.Tasks;
using DSharpPlus;
using DSharpPlus.EventArgs;
using Myriad.Gateway;
using PluralKit.Core;
@ -12,7 +10,7 @@ using Serilog;
namespace PluralKit.Bot
{
// Double duty :)
public class MessageDeleted: IEventHandler<MessageDeleteEventArgs>, IEventHandler<MessageBulkDeleteEventArgs>
public class MessageDeleted: IEventHandler<MessageDeleteEvent>, IEventHandler<MessageDeleteBulkEvent>
{
private static readonly TimeSpan MessageDeleteDelay = TimeSpan.FromSeconds(15);
@ -27,7 +25,7 @@ namespace PluralKit.Bot
_logger = logger.ForContext<MessageDeleted>();
}
public Task Handle(DiscordClient shard, MessageDeleteEventArgs evt)
public Task Handle(Shard shard, MessageDeleteEvent evt)
{
// Delete deleted webhook messages from the data store
// Most of the data in the given message is wrong/missing, so always delete just to be sure.
@ -35,7 +33,8 @@ namespace PluralKit.Bot
async Task Inner()
{
await Task.Delay(MessageDeleteDelay);
await _db.Execute(c => _repo.DeleteMessage(c, evt.Message.Id));
// TODO
// await _db.Execute(c => _repo.DeleteMessage(c, evt.Message.Id));
}
// Fork a task to delete the message after a short delay
@ -44,14 +43,15 @@ namespace PluralKit.Bot
return Task.CompletedTask;
}
public Task Handle(DiscordClient shard, MessageBulkDeleteEventArgs evt)
public Task Handle(Shard shard, MessageDeleteBulkEvent evt)
{
// Same as above, but bulk
async Task Inner()
{
await Task.Delay(MessageDeleteDelay);
_logger.Information("Bulk deleting {Count} messages in channel {Channel}", evt.Messages.Count, evt.Channel.Id);
await _db.Execute(c => _repo.DeleteMessagesBulk(c, evt.Messages.Select(m => m.Id).ToList()));
// TODO
// _logger.Information("Bulk deleting {Count} messages in channel {Channel}", evt.Messages.Count, evt.Channel.Id);
// await _db.Execute(c => _repo.DeleteMessagesBulk(c, evt.Messages.Select(m => m.Id).ToList()));
}
_ = Inner();

View File

@ -3,14 +3,15 @@ using System.Threading.Tasks;
using App.Metrics;
using DSharpPlus;
using DSharpPlus.EventArgs;
using Myriad.Gateway;
using PluralKit.Core;
namespace PluralKit.Bot
{
public class MessageEdited: IEventHandler<MessageUpdateEventArgs>
public class MessageEdited: IEventHandler<MessageUpdateEvent>
{
private readonly LastMessageCacheService _lastMessageCache;
private readonly ProxyService _proxy;
@ -29,22 +30,23 @@ namespace PluralKit.Bot
_client = client;
}
public async Task Handle(DiscordClient shard, MessageUpdateEventArgs evt)
public async Task Handle(Shard shard, MessageUpdateEvent evt)
{
if (evt.Author?.Id == _client.CurrentUser?.Id) return;
// Edit message events sometimes arrive with missing data; double-check it's all there
if (evt.Message.Content == null || evt.Author == null || evt.Channel.Guild == null) return;
// Only react to the last message in the channel
if (_lastMessageCache.GetLastMessage(evt.Channel.Id) != evt.Message.Id) return;
// Just run the normal message handling code, with a flag to disable autoproxying
MessageContext ctx;
await using (var conn = await _db.Obtain())
using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id);
await _proxy.HandleIncomingMessage(shard, evt.Message, ctx, allowAutoproxy: false);
// TODO: fix
// if (evt.Author?.Id == _client.CurrentUser?.Id) return;
//
// // Edit message events sometimes arrive with missing data; double-check it's all there
// if (evt.Message.Content == null || evt.Author == null || evt.Channel.Guild == null) return;
//
// // Only react to the last message in the channel
// if (_lastMessageCache.GetLastMessage(evt.Channel.Id) != evt.Message.Id) return;
//
// // Just run the normal message handling code, with a flag to disable autoproxying
// MessageContext ctx;
// await using (var conn = await _db.Obtain())
// using (_metrics.Measure.Timer.Time(BotMetrics.MessageContextQueryTime))
// ctx = await _repo.GetMessageContext(conn, evt.Author.Id, evt.Channel.GuildId, evt.Channel.Id);
// await _proxy.HandleIncomingMessage(shard, evt.Message, ctx, allowAutoproxy: false);
}
}
}

View File

@ -5,13 +5,15 @@ using DSharpPlus.Entities;
using DSharpPlus.EventArgs;
using DSharpPlus.Exceptions;
using Myriad.Gateway;
using PluralKit.Core;
using Serilog;
namespace PluralKit.Bot
{
public class ReactionAdded: IEventHandler<MessageReactionAddEventArgs>
public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
{
private readonly IDatabase _db;
private readonly ModelRepository _repo;
@ -28,9 +30,9 @@ namespace PluralKit.Bot
_logger = logger.ForContext<ReactionAdded>();
}
public async Task Handle(DiscordClient shard, MessageReactionAddEventArgs evt)
public async Task Handle(Shard shard, MessageReactionAddEvent evt)
{
await TryHandleProxyMessageReactions(shard, evt);
// await TryHandleProxyMessageReactions(shard, evt);
}
private async ValueTask TryHandleProxyMessageReactions(DiscordClient shard, MessageReactionAddEventArgs evt)

View File

@ -4,10 +4,11 @@ using System.Threading.Tasks;
using Autofac;
using DSharpPlus;
using Microsoft.Extensions.Configuration;
using Myriad.Gateway;
using Myriad.Rest;
using PluralKit.Core;
using Serilog;
@ -47,7 +48,8 @@ namespace PluralKit.Bot
// Start the Discord shards themselves (handlers already set up)
logger.Information("Connecting to Discord");
await services.Resolve<DiscordShardedClient>().StartAsync();
var info = await services.Resolve<DiscordApiClient>().GetGatewayBot();
await services.Resolve<Cluster>().Start(info);
logger.Information("Connected! All is good (probably).");
// Lastly, we just... wait. Everything else is handled in the DiscordClient event loop

View File

@ -6,12 +6,17 @@ using Autofac;
using DSharpPlus;
using DSharpPlus.EventArgs;
using Myriad.Cache;
using Myriad.Gateway;
using NodaTime;
using PluralKit.Core;
using Sentry;
using Serilog;
namespace PluralKit.Bot
{
public class BotModule: Module
@ -30,6 +35,22 @@ namespace PluralKit.Bot
builder.Register(c => new DiscordShardedClient(c.Resolve<DiscordConfiguration>())).AsSelf().SingleInstance();
builder.Register(c => new DiscordRestClient(c.Resolve<DiscordConfiguration>())).AsSelf().SingleInstance();
builder.Register(c => new GatewaySettings
{
Token = c.Resolve<BotConfig>().Token,
Intents = GatewayIntent.Guilds |
GatewayIntent.DirectMessages |
GatewayIntent.DirectMessageReactions |
GatewayIntent.GuildEmojis |
GatewayIntent.GuildMessages |
GatewayIntent.GuildWebhooks |
GatewayIntent.GuildMessageReactions
}).AsSelf().SingleInstance();
builder.RegisterType<Cluster>().AsSelf().SingleInstance();
builder.Register(c => new Myriad.Rest.DiscordApiClient(c.Resolve<BotConfig>().Token, c.Resolve<ILogger>()))
.AsSelf().SingleInstance();
builder.RegisterType<MemoryDiscordCache>().AsSelf().As<IDiscordCache>().SingleInstance();
// Commands
builder.RegisterType<CommandTree>().AsSelf();
builder.RegisterType<Autoproxy>().AsSelf();
@ -55,10 +76,10 @@ namespace PluralKit.Bot
// Bot core
builder.RegisterType<Bot>().AsSelf().SingleInstance();
builder.RegisterType<MessageCreated>().As<IEventHandler<MessageCreateEventArgs>>();
builder.RegisterType<MessageDeleted>().As<IEventHandler<MessageDeleteEventArgs>>().As<IEventHandler<MessageBulkDeleteEventArgs>>();
builder.RegisterType<MessageEdited>().As<IEventHandler<MessageUpdateEventArgs>>();
builder.RegisterType<ReactionAdded>().As<IEventHandler<MessageReactionAddEventArgs>>();
builder.RegisterType<MessageCreated>().As<IEventHandler<MessageCreateEvent>>();
builder.RegisterType<MessageDeleted>().As<IEventHandler<MessageDeleteEvent>>().As<IEventHandler<MessageDeleteBulkEvent>>();
builder.RegisterType<MessageEdited>().As<IEventHandler<MessageUpdateEvent>>();
builder.RegisterType<ReactionAdded>().As<IEventHandler<MessageReactionAddEvent>>();
// Event handler queue
builder.RegisterType<HandlerQueue<MessageCreateEventArgs>>().AsSelf().SingleInstance();
@ -81,13 +102,14 @@ namespace PluralKit.Bot
// Sentry stuff
builder.Register(_ => new Scope(null)).AsSelf().InstancePerLifetimeScope();
builder.RegisterType<SentryEnricher>()
.As<ISentryEnricher<MessageCreateEventArgs>>()
.As<ISentryEnricher<MessageDeleteEventArgs>>()
.As<ISentryEnricher<MessageUpdateEventArgs>>()
.As<ISentryEnricher<MessageBulkDeleteEventArgs>>()
.As<ISentryEnricher<MessageReactionAddEventArgs>>()
.SingleInstance();
// TODO:
// builder.RegisterType<SentryEnricher>()
// .As<ISentryEnricher<MessageCreateEvent>>()
// .As<ISentryEnricher<MessageDeleteEvent>>()
// .As<ISentryEnricher<MessageUpdateEvent>>()
// .As<ISentryEnricher<MessageDeleteBulkEvent>>()
// .As<ISentryEnricher<MessageReactionAddEvent>>()
// .SingleInstance();
// Proxy stuff
builder.RegisterType<ProxyMatcher>().AsSelf().SingleInstance();

Some files were not shown because too many files have changed in this diff Show More