Merge pull request #1 from Spectralitree/newdiscord

Merge Newdiscord
This commit is contained in:
Spectralitree 2021-03-27 23:40:13 +01:00 committed by GitHub
commit cf4909586d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
147 changed files with 5008 additions and 1313 deletions

View File

@ -0,0 +1,86 @@
using System.Collections.Generic;
using Myriad.Types;
namespace Myriad.Builders
{
public class EmbedBuilder
{
private Embed _embed = new();
private readonly List<Embed.Field> _fields = new();
public EmbedBuilder Title(string? title)
{
_embed = _embed with {Title = title};
return this;
}
public EmbedBuilder Description(string? description)
{
_embed = _embed with { Description = description};
return this;
}
public EmbedBuilder Url(string? url)
{
_embed = _embed with {Url = url};
return this;
}
public EmbedBuilder Color(uint? color)
{
_embed = _embed with {Color = color};
return this;
}
public EmbedBuilder Footer(Embed.EmbedFooter? footer)
{
_embed = _embed with {
Footer = footer
};
return this;
}
public EmbedBuilder Image(Embed.EmbedImage? image)
{
_embed = _embed with {
Image = image
};
return this;
}
public EmbedBuilder Thumbnail(Embed.EmbedThumbnail? thumbnail)
{
_embed = _embed with {
Thumbnail = thumbnail
};
return this;
}
public EmbedBuilder Author(Embed.EmbedAuthor? author)
{
_embed = _embed with {
Author = author
};
return this;
}
public EmbedBuilder Timestamp(string? timestamp)
{
_embed = _embed with {
Timestamp = timestamp
};
return this;
}
public EmbedBuilder Field(Embed.Field field)
{
_fields.Add(field);
return this;
}
public Embed Build() =>
_embed with { Fields = _fields.ToArray() };
}
}

View File

@ -0,0 +1,59 @@
using System.Threading.Tasks;
using Myriad.Gateway;
using Myriad.Rest;
using Myriad.Types;
namespace Myriad.Cache
{
public static class DiscordCacheExtensions
{
public static ValueTask HandleGatewayEvent(this IDiscordCache cache, IGatewayEvent evt)
{
switch (evt)
{
case GuildCreateEvent gc:
return cache.SaveGuildCreate(gc);
case GuildUpdateEvent gu:
return cache.SaveGuild(gu);
case GuildDeleteEvent gd:
return cache.RemoveGuild(gd.Id);
case ChannelCreateEvent cc:
return cache.SaveChannel(cc);
case ChannelUpdateEvent cu:
return cache.SaveChannel(cu);
case ChannelDeleteEvent cd:
return cache.RemoveChannel(cd.Id);
case GuildRoleCreateEvent grc:
return cache.SaveRole(grc.GuildId, grc.Role);
case GuildRoleUpdateEvent gru:
return cache.SaveRole(gru.GuildId, gru.Role);
case GuildRoleDeleteEvent grd:
return cache.RemoveRole(grd.GuildId, grd.RoleId);
case MessageCreateEvent mc:
return cache.SaveMessageCreate(mc);
}
return default;
}
private static async ValueTask SaveGuildCreate(this IDiscordCache cache, GuildCreateEvent guildCreate)
{
await cache.SaveGuild(guildCreate);
foreach (var channel in guildCreate.Channels)
// The channel object does not include GuildId for some reason...
await cache.SaveChannel(channel with { GuildId = guildCreate.Id });
foreach (var member in guildCreate.Members)
await cache.SaveUser(member.User);
}
private static async ValueTask SaveMessageCreate(this IDiscordCache cache, MessageCreateEvent evt)
{
await cache.SaveUser(evt.Author);
foreach (var mention in evt.Mentions)
await cache.SaveUser(mention);
}
}
}

View File

@ -0,0 +1,29 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using Myriad.Types;
namespace Myriad.Cache
{
public interface IDiscordCache
{
public ValueTask SaveGuild(Guild guild);
public ValueTask SaveChannel(Channel channel);
public ValueTask SaveUser(User user);
public ValueTask SaveRole(ulong guildId, Role role);
public ValueTask RemoveGuild(ulong guildId);
public ValueTask RemoveChannel(ulong channelId);
public ValueTask RemoveUser(ulong userId);
public ValueTask RemoveRole(ulong guildId, ulong roleId);
public bool TryGetGuild(ulong guildId, out Guild guild);
public bool TryGetChannel(ulong channelId, out Channel channel);
public bool TryGetDmChannel(ulong userId, out Channel channel);
public bool TryGetUser(ulong userId, out User user);
public bool TryGetRole(ulong roleId, out Role role);
public IAsyncEnumerable<Guild> GetAllGuilds();
public IEnumerable<Channel> GetGuildChannels(ulong guildId);
}
}

View File

@ -0,0 +1,164 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Myriad.Types;
namespace Myriad.Cache
{
public class MemoryDiscordCache: IDiscordCache
{
private readonly ConcurrentDictionary<ulong, Channel> _channels = new();
private readonly ConcurrentDictionary<ulong, ulong> _dmChannels = new();
private readonly ConcurrentDictionary<ulong, CachedGuild> _guilds = new();
private readonly ConcurrentDictionary<ulong, Role> _roles = new();
private readonly ConcurrentDictionary<ulong, User> _users = new();
public ValueTask SaveGuild(Guild guild)
{
SaveGuildRaw(guild);
foreach (var role in guild.Roles)
// Don't call SaveRole because that updates guild state
// and we just got a brand new one :)
_roles[role.Id] = role;
return default;
}
public async ValueTask SaveChannel(Channel channel)
{
_channels[channel.Id] = channel;
if (channel.GuildId != null && _guilds.TryGetValue(channel.GuildId.Value, out var guild))
guild.Channels.TryAdd(channel.Id, true);
if (channel.Recipients != null)
{
foreach (var recipient in channel.Recipients)
{
_dmChannels[recipient.Id] = channel.Id;
await SaveUser(recipient);
}
}
}
public ValueTask SaveUser(User user)
{
_users[user.Id] = user;
return default;
}
public ValueTask SaveRole(ulong guildId, Role role)
{
_roles[role.Id] = role;
if (_guilds.TryGetValue(guildId, out var guild))
{
// TODO: this code is stinky
var found = false;
for (var i = 0; i < guild.Guild.Roles.Length; i++)
{
if (guild.Guild.Roles[i].Id != role.Id)
continue;
guild.Guild.Roles[i] = role;
found = true;
}
if (!found)
{
_guilds[guildId] = guild with {
Guild = guild.Guild with {
Roles = guild.Guild.Roles.Concat(new[] { role}).ToArray()
}
};
}
}
return default;
}
public ValueTask RemoveGuild(ulong guildId)
{
_guilds.TryRemove(guildId, out _);
return default;
}
public ValueTask RemoveChannel(ulong channelId)
{
if (!_channels.TryRemove(channelId, out var channel))
return default;
if (channel.GuildId != null && _guilds.TryGetValue(channel.GuildId.Value, out var guild))
guild.Channels.TryRemove(channel.Id, out _);
return default;
}
public ValueTask RemoveUser(ulong userId)
{
_users.TryRemove(userId, out _);
return default;
}
public ValueTask RemoveRole(ulong guildId, ulong roleId)
{
_roles.TryRemove(roleId, out _);
return default;
}
public bool TryGetGuild(ulong guildId, out Guild guild)
{
if (_guilds.TryGetValue(guildId, out var cg))
{
guild = cg.Guild;
return true;
}
guild = null!;
return false;
}
public bool TryGetChannel(ulong channelId, out Channel channel) =>
_channels.TryGetValue(channelId, out channel!);
public bool TryGetDmChannel(ulong userId, out Channel channel)
{
channel = default!;
if (!_dmChannels.TryGetValue(userId, out var channelId))
return false;
return TryGetChannel(channelId, out channel);
}
public bool TryGetUser(ulong userId, out User user) =>
_users.TryGetValue(userId, out user!);
public bool TryGetRole(ulong roleId, out Role role) =>
_roles.TryGetValue(roleId, out role!);
public async IAsyncEnumerable<Guild> GetAllGuilds()
{
foreach (var guild in _guilds.Values)
yield return guild.Guild;
}
public IEnumerable<Channel> GetGuildChannels(ulong guildId)
{
if (!_guilds.TryGetValue(guildId, out var guild))
throw new ArgumentException("Guild not found", nameof(guildId));
return guild.Channels.Keys.Select(c => _channels[c]);
}
private CachedGuild SaveGuildRaw(Guild guild) =>
_guilds.GetOrAdd(guild.Id, (_, g) => new CachedGuild(g), guild);
private record CachedGuild(Guild Guild)
{
public readonly ConcurrentDictionary<ulong, bool> Channels = new();
}
}
}

View File

@ -0,0 +1,79 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using Myriad.Cache;
using Myriad.Rest;
using Myriad.Types;
namespace Myriad.Extensions
{
public static class CacheExtensions
{
public static Guild GetGuild(this IDiscordCache cache, ulong guildId)
{
if (!cache.TryGetGuild(guildId, out var guild))
throw new KeyNotFoundException($"Guild {guildId} not found in cache");
return guild;
}
public static Channel GetChannel(this IDiscordCache cache, ulong channelId)
{
if (!cache.TryGetChannel(channelId, out var channel))
throw new KeyNotFoundException($"Channel {channelId} not found in cache");
return channel;
}
public static Channel? GetChannelOrNull(this IDiscordCache cache, ulong channelId)
{
if (cache.TryGetChannel(channelId, out var channel))
return channel;
return null;
}
public static User GetUser(this IDiscordCache cache, ulong userId)
{
if (!cache.TryGetUser(userId, out var user))
throw new KeyNotFoundException($"User {userId} not found in cache");
return user;
}
public static Role GetRole(this IDiscordCache cache, ulong roleId)
{
if (!cache.TryGetRole(roleId, out var role))
throw new KeyNotFoundException($"User {roleId} not found in cache");
return role;
}
public static async ValueTask<User?> GetOrFetchUser(this IDiscordCache cache, DiscordApiClient rest, ulong userId)
{
if (cache.TryGetUser(userId, out var cacheUser))
return cacheUser;
var restUser = await rest.GetUser(userId);
if (restUser != null)
await cache.SaveUser(restUser);
return restUser;
}
public static async ValueTask<Channel?> GetOrFetchChannel(this IDiscordCache cache, DiscordApiClient rest, ulong channelId)
{
if (cache.TryGetChannel(channelId, out var cacheChannel))
return cacheChannel;
var restChannel = await rest.GetChannel(channelId);
if (restChannel != null)
await cache.SaveChannel(restChannel);
return restChannel;
}
public static async Task<Channel> GetOrCreateDmChannel(this IDiscordCache cache, DiscordApiClient rest, ulong recipientId)
{
if (cache.TryGetDmChannel(recipientId, out var cacheChannel))
return cacheChannel;
var restChannel = await rest.CreateDm(recipientId);
await cache.SaveChannel(restChannel);
return restChannel;
}
}
}

View File

@ -0,0 +1,9 @@
using Myriad.Types;
namespace Myriad.Extensions
{
public static class ChannelExtensions
{
public static string Mention(this Channel channel) => $"<#{channel.Id}>";
}
}

View File

@ -0,0 +1,7 @@
namespace Myriad.Extensions
{
public static class GuildExtensions
{
}
}

View File

@ -0,0 +1,16 @@
using System;
using Myriad.Gateway;
using Myriad.Types;
namespace Myriad.Extensions
{
public static class MessageExtensions
{
public static string JumpLink(this Message msg) =>
$"https://discord.com/channels/{msg.GuildId}/{msg.ChannelId}/{msg.Id}";
public static string JumpLink(this MessageReactionAddEvent msg) =>
$"https://discord.com/channels/{msg.GuildId}/{msg.ChannelId}/{msg.MessageId}";
}
}

View File

@ -0,0 +1,153 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Myriad.Cache;
using Myriad.Gateway;
using Myriad.Types;
namespace Myriad.Extensions
{
public static class PermissionExtensions
{
public static PermissionSet PermissionsFor(this IDiscordCache cache, MessageCreateEvent message) =>
PermissionsFor(cache, message.ChannelId, message.Author.Id, message.Member?.Roles);
public static PermissionSet PermissionsFor(this IDiscordCache cache, ulong channelId, GuildMember member) =>
PermissionsFor(cache, channelId, member.User.Id, member.Roles);
public static PermissionSet PermissionsFor(this IDiscordCache cache, ulong channelId, ulong userId, GuildMemberPartial member) =>
PermissionsFor(cache, channelId, userId, member.Roles);
public static PermissionSet PermissionsFor(this IDiscordCache cache, ulong channelId, ulong userId, ICollection<ulong>? userRoles)
{
var channel = cache.GetChannel(channelId);
if (channel.GuildId == null)
return PermissionSet.Dm;
var guild = cache.GetGuild(channel.GuildId.Value);
return PermissionsFor(guild, channel, userId, userRoles);
}
public static PermissionSet EveryonePermissions(this Guild guild) =>
guild.Roles.FirstOrDefault(r => r.Id == guild.Id)?.Permissions ?? PermissionSet.Dm;
public static PermissionSet PermissionsFor(Guild guild, Channel channel, MessageCreateEvent msg) =>
PermissionsFor(guild, channel, msg.Author.Id, msg.Member?.Roles);
public static PermissionSet PermissionsFor(Guild guild, Channel channel, ulong userId,
ICollection<ulong>? roleIds)
{
if (channel.Type == Channel.ChannelType.Dm)
return PermissionSet.Dm;
if (roleIds == null)
throw new ArgumentException($"User roles must be specified for guild channels");
var perms = GuildPermissions(guild, userId, roleIds);
perms = ApplyChannelOverwrites(perms, channel, userId, roleIds);
if ((perms & PermissionSet.Administrator) == PermissionSet.Administrator)
return PermissionSet.All;
if ((perms & PermissionSet.ViewChannel) == 0)
perms &= ~NeedsViewChannel;
if ((perms & PermissionSet.SendMessages) == 0)
perms &= ~NeedsSendMessages;
return perms;
}
public static PermissionSet GuildPermissions(this Guild guild, ulong userId, ICollection<ulong> roleIds)
{
if (guild.OwnerId == userId)
return PermissionSet.All;
var perms = PermissionSet.None;
foreach (var role in guild.Roles)
{
if (role.Id == guild.Id || roleIds.Contains(role.Id))
perms |= role.Permissions;
}
if (perms.HasFlag(PermissionSet.Administrator))
return PermissionSet.All;
return perms;
}
public static PermissionSet ApplyChannelOverwrites(PermissionSet perms, Channel channel, ulong userId,
ICollection<ulong> roleIds)
{
if (channel.PermissionOverwrites == null)
return perms;
var everyoneDeny = PermissionSet.None;
var everyoneAllow = PermissionSet.None;
var roleDeny = PermissionSet.None;
var roleAllow = PermissionSet.None;
var userDeny = PermissionSet.None;
var userAllow = PermissionSet.None;
foreach (var overwrite in channel.PermissionOverwrites)
{
switch (overwrite.Type)
{
case Channel.OverwriteType.Role when overwrite.Id == channel.GuildId:
everyoneDeny |= overwrite.Deny;
everyoneAllow |= overwrite.Allow;
break;
case Channel.OverwriteType.Role when roleIds.Contains(overwrite.Id):
roleDeny |= overwrite.Deny;
roleAllow |= overwrite.Allow;
break;
case Channel.OverwriteType.Member when overwrite.Id == userId:
userDeny |= overwrite.Deny;
userAllow |= overwrite.Allow;
break;
}
}
perms &= ~everyoneDeny;
perms |= everyoneAllow;
perms &= ~roleDeny;
perms |= roleAllow;
perms &= ~userDeny;
perms |= userAllow;
return perms;
}
private const PermissionSet NeedsViewChannel =
PermissionSet.SendMessages |
PermissionSet.SendTtsMessages |
PermissionSet.ManageMessages |
PermissionSet.EmbedLinks |
PermissionSet.AttachFiles |
PermissionSet.ReadMessageHistory |
PermissionSet.MentionEveryone |
PermissionSet.UseExternalEmojis |
PermissionSet.AddReactions |
PermissionSet.Connect |
PermissionSet.Speak |
PermissionSet.MuteMembers |
PermissionSet.DeafenMembers |
PermissionSet.MoveMembers |
PermissionSet.UseVad |
PermissionSet.Stream |
PermissionSet.PrioritySpeaker;
private const PermissionSet NeedsSendMessages =
PermissionSet.MentionEveryone |
PermissionSet.SendTtsMessages |
PermissionSet.AttachFiles |
PermissionSet.EmbedLinks;
public static string ToPermissionString(this PermissionSet perms)
{
// TODO: clean string
return perms.ToString();
}
}
}

View File

@ -0,0 +1,20 @@
using System;
using Myriad.Types;
namespace Myriad.Extensions
{
public static class SnowflakeExtensions
{
public static readonly DateTimeOffset DiscordEpoch = new(2015, 1, 1, 0, 0, 0, TimeSpan.Zero);
public static DateTimeOffset SnowflakeToTimestamp(ulong snowflake) =>
DiscordEpoch + TimeSpan.FromMilliseconds(snowflake >> 22);
public static DateTimeOffset Timestamp(this Message msg) => SnowflakeToTimestamp(msg.Id);
public static DateTimeOffset Timestamp(this Channel channel) => SnowflakeToTimestamp(channel.Id);
public static DateTimeOffset Timestamp(this Guild guild) => SnowflakeToTimestamp(guild.Id);
public static DateTimeOffset Timestamp(this Webhook webhook) => SnowflakeToTimestamp(webhook.Id);
public static DateTimeOffset Timestamp(this User user) => SnowflakeToTimestamp(user.Id);
}
}

View File

@ -0,0 +1,12 @@
using Myriad.Types;
namespace Myriad.Extensions
{
public static class UserExtensions
{
public static string Mention(this User user) => $"<@{user.Id}>";
public static string AvatarUrl(this User user, string? format = "png", int? size = 128) =>
$"https://cdn.discordapp.com/avatars/{user.Id}/{user.Avatar}.{format}?size={size}";
}
}

116
Myriad/Gateway/Cluster.cs Normal file
View File

@ -0,0 +1,116 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Myriad.Types;
using Serilog;
namespace Myriad.Gateway
{
public class Cluster
{
private readonly GatewaySettings _gatewaySettings;
private readonly ILogger _logger;
private readonly ConcurrentDictionary<int, Shard> _shards = new();
public Cluster(GatewaySettings gatewaySettings, ILogger logger)
{
_gatewaySettings = gatewaySettings;
_logger = logger;
}
public Func<Shard, IGatewayEvent, Task>? EventReceived { get; set; }
public event Action<Shard>? ShardCreated;
public IReadOnlyDictionary<int, Shard> Shards => _shards;
public ClusterSessionState SessionState => GetClusterState();
public User? User => _shards.Values.Select(s => s.User).FirstOrDefault(s => s != null);
public ApplicationPartial? Application => _shards.Values.Select(s => s.Application).FirstOrDefault(s => s != null);
private ClusterSessionState GetClusterState()
{
var shards = new List<ClusterSessionState.ShardState>();
foreach (var (id, shard) in _shards)
shards.Add(new ClusterSessionState.ShardState
{
Shard = shard.ShardInfo,
Session = shard.SessionInfo
});
return new ClusterSessionState {Shards = shards};
}
public async Task Start(GatewayInfo.Bot info, ClusterSessionState? lastState = null)
{
if (lastState != null && lastState.Shards.Count == info.Shards)
await Resume(info.Url, lastState, info.SessionStartLimit.MaxConcurrency);
else
await Start(info.Url, info.Shards, info.SessionStartLimit.MaxConcurrency);
}
public async Task Resume(string url, ClusterSessionState sessionState, int concurrency)
{
_logger.Information("Resuming session with {ShardCount} shards at {Url}", sessionState.Shards.Count, url);
foreach (var shardState in sessionState.Shards)
CreateAndAddShard(url, shardState.Shard, shardState.Session);
await StartShards(concurrency);
}
public async Task Start(string url, int shardCount, int concurrency)
{
_logger.Information("Starting {ShardCount} shards at {Url}", shardCount, url);
for (var i = 0; i < shardCount; i++)
CreateAndAddShard(url, new ShardInfo(i, shardCount), null);
await StartShards(concurrency);
}
private async Task StartShards(int concurrency)
{
var lastTime = DateTimeOffset.UtcNow;
var identifyCalls = 0;
_logger.Information("Connecting shards...");
foreach (var shard in _shards.Values)
{
if (identifyCalls >= concurrency)
{
var timeout = lastTime + TimeSpan.FromSeconds(5.5);
var delay = timeout - DateTimeOffset.UtcNow;
if (delay > TimeSpan.Zero)
{
_logger.Information("Hit shard concurrency limit, waiting {Delay}", delay);
await Task.Delay(delay);
}
identifyCalls = 0;
lastTime = DateTimeOffset.UtcNow;
}
await shard.Start();
identifyCalls++;
}
}
private void CreateAndAddShard(string url, ShardInfo shardInfo, ShardSessionInfo? session)
{
var shard = new Shard(_logger, new Uri(url), _gatewaySettings, shardInfo, session);
shard.OnEventReceived += evt => OnShardEventReceived(shard, evt);
_shards[shardInfo.ShardId] = shard;
ShardCreated?.Invoke(shard);
}
private async Task OnShardEventReceived(Shard shard, IGatewayEvent evt)
{
if (EventReceived != null)
await EventReceived(shard, evt);
}
}
}

View File

@ -0,0 +1,15 @@
using System.Collections.Generic;
namespace Myriad.Gateway
{
public record ClusterSessionState
{
public List<ShardState> Shards { get; init; }
public record ShardState
{
public ShardInfo Shard { get; init; }
public ShardSessionInfo Session { get; init; }
}
}
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ChannelCreateEvent: Channel, IGatewayEvent;
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ChannelDeleteEvent: Channel, IGatewayEvent;
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record ChannelUpdateEvent: Channel, IGatewayEvent;
}

View File

@ -0,0 +1,12 @@
using System.Collections.Generic;
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildCreateEvent: Guild, IGatewayEvent
{
public Channel[] Channels { get; init; }
public GuildMember[] Members { get; init; }
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record GuildDeleteEvent(ulong Id, bool Unavailable): IGatewayEvent;
}

View File

@ -0,0 +1,9 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildMemberAddEvent: GuildMember, IGatewayEvent
{
public ulong GuildId { get; init; }
}
}

View File

@ -0,0 +1,10 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public class GuildMemberRemoveEvent: IGatewayEvent
{
public ulong GuildId { get; init; }
public User User { get; init; }
}
}

View File

@ -0,0 +1,9 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildMemberUpdateEvent: GuildMember, IGatewayEvent
{
public ulong GuildId { get; init; }
}
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildRoleCreateEvent(ulong GuildId, Role Role): IGatewayEvent;
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record GuildRoleDeleteEvent(ulong GuildId, ulong RoleId): IGatewayEvent;
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildRoleUpdateEvent(ulong GuildId, Role Role): IGatewayEvent;
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record GuildUpdateEvent: Guild, IGatewayEvent;
}

View File

@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
namespace Myriad.Gateway
{
public interface IGatewayEvent
{
public static readonly Dictionary<string, Type> EventTypes = new()
{
{"READY", typeof(ReadyEvent)},
{"RESUMED", typeof(ResumedEvent)},
{"GUILD_CREATE", typeof(GuildCreateEvent)},
{"GUILD_UPDATE", typeof(GuildUpdateEvent)},
{"GUILD_DELETE", typeof(GuildDeleteEvent)},
{"GUILD_MEMBER_ADD", typeof(GuildMemberAddEvent)},
{"GUILD_MEMBER_REMOVE", typeof(GuildMemberRemoveEvent)},
{"GUILD_MEMBER_UPDATE", typeof(GuildMemberUpdateEvent)},
{"GUILD_ROLE_CREATE", typeof(GuildRoleCreateEvent)},
{"GUILD_ROLE_UPDATE", typeof(GuildRoleUpdateEvent)},
{"GUILD_ROLE_DELETE", typeof(GuildRoleDeleteEvent)},
{"CHANNEL_CREATE", typeof(ChannelCreateEvent)},
{"CHANNEL_UPDATE", typeof(ChannelUpdateEvent)},
{"CHANNEL_DELETE", typeof(ChannelDeleteEvent)},
{"MESSAGE_CREATE", typeof(MessageCreateEvent)},
{"MESSAGE_UPDATE", typeof(MessageUpdateEvent)},
{"MESSAGE_DELETE", typeof(MessageDeleteEvent)},
{"MESSAGE_DELETE_BULK", typeof(MessageDeleteBulkEvent)},
{"MESSAGE_REACTION_ADD", typeof(MessageReactionAddEvent)},
{"MESSAGE_REACTION_REMOVE", typeof(MessageReactionRemoveEvent)},
{"MESSAGE_REACTION_REMOVE_ALL", typeof(MessageReactionRemoveAllEvent)},
{"MESSAGE_REACTION_REMOVE_EMOJI", typeof(MessageReactionRemoveEmojiEvent)},
{"INTERACTION_CREATE", typeof(InteractionCreateEvent)}
};
}
}

View File

@ -0,0 +1,6 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record InteractionCreateEvent: Interaction, IGatewayEvent;
}

View File

@ -0,0 +1,9 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record MessageCreateEvent: Message, IGatewayEvent
{
public GuildMemberPartial? Member { get; init; }
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record MessageDeleteBulkEvent(ulong[] Ids, ulong ChannelId, ulong? GuildId): IGatewayEvent;
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record MessageDeleteEvent(ulong Id, ulong ChannelId, ulong? GuildId): IGatewayEvent;
}

View File

@ -0,0 +1,8 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record MessageReactionAddEvent(ulong UserId, ulong ChannelId, ulong MessageId, ulong? GuildId,
GuildMember? Member,
Emoji Emoji): IGatewayEvent;
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record MessageReactionRemoveAllEvent(ulong ChannelId, ulong MessageId, ulong? GuildId): IGatewayEvent;
}

View File

@ -0,0 +1,7 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record MessageReactionRemoveEmojiEvent
(ulong ChannelId, ulong MessageId, ulong? GuildId, Emoji Emoji): IGatewayEvent;
}

View File

@ -0,0 +1,7 @@
using Myriad.Types;
namespace Myriad.Gateway
{
public record MessageReactionRemoveEvent
(ulong UserId, ulong ChannelId, ulong MessageId, ulong? GuildId, Emoji Emoji): IGatewayEvent;
}

View File

@ -0,0 +1,15 @@
using Myriad.Types;
using Myriad.Utils;
namespace Myriad.Gateway
{
public record MessageUpdateEvent(ulong Id, ulong ChannelId): IGatewayEvent
{
public Optional<string?> Content { get; init; }
public Optional<User> Author { get; init; }
public Optional<GuildMemberPartial> Member { get; init; }
public Optional<Message.Attachment[]> Attachments { get; init; }
public Optional<ulong?> GuildId { get; init; }
// TODO: lots of partials
}
}

View File

@ -0,0 +1,15 @@
using System.Text.Json.Serialization;
using Myriad.Types;
namespace Myriad.Gateway
{
public record ReadyEvent: IGatewayEvent
{
[JsonPropertyName("v")] public int Version { get; init; }
public User User { get; init; }
public string SessionId { get; init; }
public ShardInfo? Shard { get; init; }
public ApplicationPartial Application { get; init; }
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record ResumedEvent: IGatewayEvent;
}

View File

@ -0,0 +1,35 @@
using System;
namespace Myriad.Gateway
{
// TODO: unused?
public class GatewayCloseException: Exception
{
public GatewayCloseException(int closeCode, string closeReason): base($"{closeCode}: {closeReason}")
{
CloseCode = closeCode;
CloseReason = closeReason;
}
public int CloseCode { get; }
public string CloseReason { get; }
}
public class GatewayCloseCode
{
public const int UnknownError = 4000;
public const int UnknownOpcode = 4001;
public const int DecodeError = 4002;
public const int NotAuthenticated = 4003;
public const int AuthenticationFailed = 4004;
public const int AlreadyAuthenticated = 4005;
public const int InvalidSeq = 4007;
public const int RateLimited = 4008;
public const int SessionTimedOut = 4009;
public const int InvalidShard = 4010;
public const int ShardingRequired = 4011;
public const int InvalidApiVersion = 4012;
public const int InvalidIntent = 4013;
public const int DisallowedIntent = 4014;
}
}

View File

@ -0,0 +1,24 @@
using System;
namespace Myriad.Gateway
{
[Flags]
public enum GatewayIntent
{
Guilds = 1 << 0,
GuildMembers = 1 << 1,
GuildBans = 1 << 2,
GuildEmojis = 1 << 3,
GuildIntegrations = 1 << 4,
GuildWebhooks = 1 << 5,
GuildInvites = 1 << 6,
GuildVoiceStates = 1 << 7,
GuildPresences = 1 << 8,
GuildMessages = 1 << 9,
GuildMessageReactions = 1 << 10,
GuildMessageTyping = 1 << 11,
DirectMessages = 1 << 12,
DirectMessageReactions = 1 << 13,
DirectMessageTyping = 1 << 14
}
}

View File

@ -0,0 +1,31 @@
using System.Text.Json.Serialization;
namespace Myriad.Gateway
{
public record GatewayPacket
{
[JsonPropertyName("op")] public GatewayOpcode Opcode { get; init; }
[JsonPropertyName("d")] public object? Payload { get; init; }
[JsonPropertyName("s")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? Sequence { get; init; }
[JsonPropertyName("t")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? EventType { get; init; }
}
public enum GatewayOpcode
{
Dispatch = 0,
Heartbeat = 1,
Identify = 2,
PresenceUpdate = 3,
VoiceStateUpdate = 4,
Resume = 6,
Reconnect = 7,
RequestGuildMembers = 8,
InvalidSession = 9,
Hello = 10,
HeartbeatAck = 11
}
}

View File

@ -0,0 +1,8 @@
namespace Myriad.Gateway
{
public record GatewaySettings
{
public string Token { get; init; }
public GatewayIntent Intents { get; init; }
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record GatewayHello(int HeartbeatInterval);
}

View File

@ -0,0 +1,28 @@
using System.Text.Json.Serialization;
namespace Myriad.Gateway
{
public record GatewayIdentify
{
public string Token { get; init; }
public ConnectionProperties Properties { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public bool? Compress { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? LargeThreshold { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public ShardInfo? Shard { get; init; }
public GatewayIntent Intents { get; init; }
public record ConnectionProperties
{
[JsonPropertyName("$os")] public string Os { get; init; }
[JsonPropertyName("$browser")] public string Browser { get; init; }
[JsonPropertyName("$device")] public string Device { get; init; }
}
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record GatewayResume(string Token, string SessionId, int Seq);
}

View File

@ -0,0 +1,23 @@
using System.Collections.Generic;
using Myriad.Types;
namespace Myriad.Gateway
{
public record GatewayStatusUpdate
{
public enum UserStatus
{
Online,
Dnd,
Idle,
Invisible,
Offline
}
public ulong? Since { get; init; }
public ActivityPartial[]? Activities { get; init; }
public UserStatus Status { get; init; }
public bool Afk { get; init; }
}
}

349
Myriad/Gateway/Shard.cs Normal file
View File

@ -0,0 +1,349 @@
using System;
using System.Net.WebSockets;
using System.Text.Json;
using System.Threading.Tasks;
using Myriad.Serialization;
using Myriad.Types;
using Serilog;
namespace Myriad.Gateway
{
public class Shard: IAsyncDisposable
{
private const string LibraryName = "Myriad (for PluralKit)";
private readonly JsonSerializerOptions _jsonSerializerOptions =
new JsonSerializerOptions().ConfigureForMyriad();
private readonly ILogger _logger;
private readonly Uri _uri;
private ShardConnection? _conn;
private TimeSpan? _currentHeartbeatInterval;
private bool _hasReceivedAck;
private DateTimeOffset? _lastHeartbeatSent;
private Task _worker;
public ShardInfo ShardInfo { get; private set; }
public int ShardId => ShardInfo.ShardId;
public GatewaySettings Settings { get; }
public ShardSessionInfo SessionInfo { get; private set; }
public ShardState State { get; private set; }
public TimeSpan? Latency { get; private set; }
public User? User { get; private set; }
public ApplicationPartial? Application { get; private set; }
public Func<IGatewayEvent, Task>? OnEventReceived { get; set; }
public event Action<TimeSpan>? HeartbeatReceived;
public event Action? SocketOpened;
public event Action? Resumed;
public event Action? Ready;
public event Action<WebSocketCloseStatus, string?>? SocketClosed;
public Shard(ILogger logger, Uri uri, GatewaySettings settings, ShardInfo info,
ShardSessionInfo? sessionInfo = null)
{
_logger = logger.ForContext<Shard>();
_uri = uri;
Settings = settings;
ShardInfo = info;
SessionInfo = sessionInfo ?? new ShardSessionInfo();
}
public async ValueTask DisposeAsync()
{
if (_conn != null)
await _conn.DisposeAsync();
}
public Task Start()
{
_worker = MainLoop();
return Task.CompletedTask;
}
public async Task UpdateStatus(GatewayStatusUpdate payload)
{
if (_conn != null && _conn.State == WebSocketState.Open)
await _conn!.Send(new GatewayPacket {Opcode = GatewayOpcode.PresenceUpdate, Payload = payload});
}
private async Task MainLoop()
{
while (true)
try
{
_logger.Information("Shard {ShardId}: Connecting...", ShardId);
State = ShardState.Connecting;
await Connect();
_logger.Information("Shard {ShardId}: Connected. Entering main loop...", ShardId);
// Tick returns false if we need to stop and reconnect
while (await Tick(_conn!))
await Task.Delay(TimeSpan.FromMilliseconds(1000));
_logger.Information("Shard {ShardId}: Connection closed, reconnecting...", ShardId);
State = ShardState.Closed;
}
catch (Exception e)
{
_logger.Error(e, "Shard {ShardId}: Error in shard state handler", ShardId);
}
}
private async Task<bool> Tick(ShardConnection conn)
{
if (conn.State != WebSocketState.Connecting && conn.State != WebSocketState.Open)
return false;
if (!await TickHeartbeat(conn))
// TickHeartbeat returns false if we're disconnecting
return false;
return true;
}
private async Task<bool> TickHeartbeat(ShardConnection conn)
{
// If we don't need to heartbeat, do nothing
if (_lastHeartbeatSent == null || _currentHeartbeatInterval == null)
return true;
if (DateTimeOffset.UtcNow - _lastHeartbeatSent < _currentHeartbeatInterval)
return true;
// If we haven't received the ack in time, close w/ error
if (!_hasReceivedAck)
{
_logger.Warning(
"Shard {ShardId}: Did not receive heartbeat Ack from gateway within interval ({HeartbeatInterval})",
ShardId, _currentHeartbeatInterval);
State = ShardState.Closing;
await conn.Disconnect(WebSocketCloseStatus.ProtocolError, "Did not receive ACK in time");
return false;
}
// Otherwise just send it :)
await SendHeartbeat(conn);
_hasReceivedAck = false;
return true;
}
private async Task SendHeartbeat(ShardConnection conn)
{
_logger.Debug("Shard {ShardId}: Sending heartbeat with seq.no. {LastSequence}",
ShardId, SessionInfo.LastSequence);
await conn.Send(new GatewayPacket {Opcode = GatewayOpcode.Heartbeat, Payload = SessionInfo.LastSequence});
_lastHeartbeatSent = DateTimeOffset.UtcNow;
}
private async Task Connect()
{
if (_conn != null)
await _conn.DisposeAsync();
_currentHeartbeatInterval = null;
_conn = new ShardConnection(_uri, _logger, _jsonSerializerOptions)
{
OnReceive = OnReceive,
OnOpen = () => SocketOpened?.Invoke(),
OnClose = (closeStatus, message) => SocketClosed?.Invoke(closeStatus, message)
};
}
private async Task OnReceive(GatewayPacket packet)
{
switch (packet.Opcode)
{
case GatewayOpcode.Hello:
{
await HandleHello((JsonElement) packet.Payload!);
break;
}
case GatewayOpcode.Heartbeat:
{
_logger.Debug("Shard {ShardId}: Received heartbeat request from shard, sending Ack", ShardId);
await _conn!.Send(new GatewayPacket {Opcode = GatewayOpcode.HeartbeatAck});
break;
}
case GatewayOpcode.HeartbeatAck:
{
Latency = DateTimeOffset.UtcNow - _lastHeartbeatSent;
_logger.Debug("Shard {ShardId}: Received heartbeat Ack with latency {Latency}", ShardId, Latency);
if (Latency != null)
HeartbeatReceived?.Invoke(Latency!.Value);
_hasReceivedAck = true;
break;
}
case GatewayOpcode.Reconnect:
{
_logger.Information("Shard {ShardId}: Received Reconnect, closing and reconnecting", ShardId);
await _conn!.Disconnect(WebSocketCloseStatus.Empty, null);
break;
}
case GatewayOpcode.InvalidSession:
{
var canResume = ((JsonElement) packet.Payload!).GetBoolean();
// Clear session info before DCing
if (!canResume)
SessionInfo = SessionInfo with { Session = null };
var delay = TimeSpan.FromMilliseconds(new Random().Next(1000, 5000));
_logger.Information(
"Shard {ShardId}: Received Invalid Session (can resume? {CanResume}), reconnecting after {ReconnectDelay}",
ShardId, canResume, delay);
await _conn!.Disconnect(WebSocketCloseStatus.Empty, null);
// Will reconnect after exiting this "loop"
await Task.Delay(delay);
break;
}
case GatewayOpcode.Dispatch:
{
SessionInfo = SessionInfo with { LastSequence = packet.Sequence };
var evt = DeserializeEvent(packet.EventType!, (JsonElement) packet.Payload!)!;
if (evt is ReadyEvent rdy)
{
if (State == ShardState.Connecting)
await HandleReady(rdy);
else
_logger.Warning("Shard {ShardId}: Received Ready event in unexpected state {ShardState}, ignoring?",
ShardId, State);
}
else if (evt is ResumedEvent)
{
if (State == ShardState.Connecting)
await HandleResumed();
else
_logger.Warning("Shard {ShardId}: Received Resumed event in unexpected state {ShardState}, ignoring?",
ShardId, State);
}
await HandleEvent(evt);
break;
}
default:
{
_logger.Debug("Shard {ShardId}: Received unknown gateway opcode {Opcode}", ShardId, packet.Opcode);
break;
}
}
}
private async Task HandleEvent(IGatewayEvent evt)
{
if (OnEventReceived != null)
await OnEventReceived.Invoke(evt);
}
private IGatewayEvent? DeserializeEvent(string eventType, JsonElement data)
{
if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType))
{
_logger.Information("Shard {ShardId}: Received unknown event type {EventType}", ShardId, eventType);
return null;
}
try
{
_logger.Verbose("Shard {ShardId}: Deserializing {EventType} to {ClrType}", ShardId, eventType, clrType);
return JsonSerializer.Deserialize(data.GetRawText(), clrType, _jsonSerializerOptions)
as IGatewayEvent;
}
catch (JsonException e)
{
_logger.Error(e, "Shard {ShardId}: Error deserializing event {EventType} to {ClrType}", ShardId, eventType, clrType);
return null;
}
}
private Task HandleReady(ReadyEvent ready)
{
// TODO: when is ready.Shard ever null?
ShardInfo = ready.Shard ?? new ShardInfo(0, 0);
SessionInfo = SessionInfo with { Session = ready.SessionId };
User = ready.User;
Application = ready.Application;
State = ShardState.Open;
Ready?.Invoke();
return Task.CompletedTask;
}
private Task HandleResumed()
{
State = ShardState.Open;
Resumed?.Invoke();
return Task.CompletedTask;
}
private async Task HandleHello(JsonElement json)
{
var hello = JsonSerializer.Deserialize<GatewayHello>(json.GetRawText(), _jsonSerializerOptions)!;
_logger.Debug("Shard {ShardId}: Received Hello with interval {Interval} ms", ShardId, hello.HeartbeatInterval);
_currentHeartbeatInterval = TimeSpan.FromMilliseconds(hello.HeartbeatInterval);
await SendHeartbeat(_conn!);
await SendIdentifyOrResume();
}
private async Task SendIdentifyOrResume()
{
if (SessionInfo.Session != null && SessionInfo.LastSequence != null)
await SendResume(SessionInfo.Session, SessionInfo.LastSequence!.Value);
else
await SendIdentify();
}
private async Task SendIdentify()
{
_logger.Information("Shard {ShardId}: Sending gateway Identify for shard {@ShardInfo}", ShardId, ShardInfo);
await _conn!.Send(new GatewayPacket
{
Opcode = GatewayOpcode.Identify,
Payload = new GatewayIdentify
{
Token = Settings.Token,
Properties = new GatewayIdentify.ConnectionProperties
{
Browser = LibraryName, Device = LibraryName, Os = Environment.OSVersion.ToString()
},
Intents = Settings.Intents,
Shard = ShardInfo
}
});
}
private async Task SendResume(string session, int lastSequence)
{
_logger.Information("Shard {ShardId}: Sending gateway Resume for session {@SessionInfo}",
ShardId, SessionInfo);
await _conn!.Send(new GatewayPacket
{
Opcode = GatewayOpcode.Resume,
Payload = new GatewayResume(Settings.Token, session, lastSequence)
});
}
public enum ShardState
{
Closed,
Connecting,
Open,
Closing
}
}
}

View File

@ -0,0 +1,131 @@
using System;
using System.Buffers;
using System.IO;
using System.Net.WebSockets;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Serilog;
namespace Myriad.Gateway
{
public class ShardConnection: IAsyncDisposable
{
private readonly MemoryStream _bufStream = new();
private readonly ClientWebSocket _client = new();
private readonly CancellationTokenSource _cts = new();
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly ILogger _logger;
private readonly Task _worker;
public ShardConnection(Uri uri, ILogger logger, JsonSerializerOptions jsonSerializerOptions)
{
_logger = logger;
_jsonSerializerOptions = jsonSerializerOptions;
_worker = Worker(uri);
}
public Func<GatewayPacket, Task>? OnReceive { get; set; }
public Action? OnOpen { get; set; }
public Action<WebSocketCloseStatus, string?>? OnClose { get; set; }
public WebSocketState State => _client.State;
public async ValueTask DisposeAsync()
{
_cts.Cancel();
await _worker;
_client.Dispose();
await _bufStream.DisposeAsync();
_cts.Dispose();
}
private async Task Worker(Uri uri)
{
var realUrl = new UriBuilder(uri)
{
Query = "v=8&encoding=json"
}.Uri;
_logger.Debug("Connecting to gateway WebSocket at {GatewayUrl}", realUrl);
await _client.ConnectAsync(realUrl, default);
_logger.Debug("Gateway connection opened");
OnOpen?.Invoke();
// Main worker loop, spins until we manually disconnect (which hits the cancellation token)
// or the server disconnects us (which sets state to closed)
while (!_cts.IsCancellationRequested && _client.State == WebSocketState.Open)
{
try
{
await HandleReceive();
}
catch (Exception e)
{
_logger.Error(e, "Error in WebSocket receive worker");
}
}
OnClose?.Invoke(_client.CloseStatus ?? default, _client.CloseStatusDescription);
}
private async Task HandleReceive()
{
_bufStream.SetLength(0);
var result = await ReadData(_bufStream);
var data = _bufStream.GetBuffer().AsMemory(0, (int) _bufStream.Position);
if (result.MessageType == WebSocketMessageType.Text)
await HandleReceiveData(data);
else if (result.MessageType == WebSocketMessageType.Close)
_logger.Information("WebSocket closed by server: {StatusCode} {Reason}", _client.CloseStatus,
_client.CloseStatusDescription);
}
private async Task HandleReceiveData(Memory<byte> data)
{
var packet = JsonSerializer.Deserialize<GatewayPacket>(data.Span, _jsonSerializerOptions)!;
try
{
if (OnReceive != null)
await OnReceive.Invoke(packet);
}
catch (Exception e)
{
_logger.Error(e, "Error in gateway handler for {OpcodeType}", packet.Opcode);
}
}
private async Task<ValueWebSocketReceiveResult> ReadData(MemoryStream stream)
{
// TODO: does this throw if we disconnect mid-read?
using var buf = MemoryPool<byte>.Shared.Rent();
ValueWebSocketReceiveResult result;
do
{
result = await _client.ReceiveAsync(buf.Memory, _cts.Token);
stream.Write(buf.Memory.Span.Slice(0, result.Count));
} while (!result.EndOfMessage);
return result;
}
public async Task Send(GatewayPacket packet)
{
var bytes = JsonSerializer.SerializeToUtf8Bytes(packet, _jsonSerializerOptions);
await _client.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, default);
}
public async Task Disconnect(WebSocketCloseStatus status, string? description)
{
await _client.CloseAsync(status, description, default);
_cts.Cancel();
}
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Gateway
{
public record ShardInfo(int ShardId, int NumShards);
}

View File

@ -0,0 +1,8 @@
namespace Myriad.Gateway
{
public record ShardSessionInfo
{
public string? Session { get; init; }
public int? LastSequence { get; init; }
}
}

19
Myriad/Myriad.csproj Normal file
View File

@ -0,0 +1,19 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net5.0</TargetFramework>
<Nullable>enable</Nullable>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)' == 'Release' ">
<DebugSymbols>true</DebugSymbols>
<DebugType>full</DebugType>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Polly" Version="7.2.1" />
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
<PackageReference Include="Serilog" Version="2.10.0" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,240 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Net.Http.Json;
using System.Text.Json;
using System.Threading.Tasks;
using Myriad.Rest.Exceptions;
using Myriad.Rest.Ratelimit;
using Myriad.Rest.Types;
using Myriad.Serialization;
using Polly;
using Serilog;
namespace Myriad.Rest
{
public class BaseRestClient: IAsyncDisposable
{
private const string ApiBaseUrl = "https://discord.com/api/v8";
private readonly Version _httpVersion = new(2, 0);
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly ILogger _logger;
private readonly Ratelimiter _ratelimiter;
private readonly AsyncPolicy<HttpResponseMessage> _retryPolicy;
public BaseRestClient(string userAgent, string token, ILogger logger)
{
_logger = logger.ForContext<BaseRestClient>();
if (!token.StartsWith("Bot "))
token = "Bot " + token;
Client = new HttpClient();
Client.DefaultRequestHeaders.TryAddWithoutValidation("User-Agent", userAgent);
Client.DefaultRequestHeaders.TryAddWithoutValidation("Authorization", token);
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
_ratelimiter = new Ratelimiter(logger);
var discordPolicy = new DiscordRateLimitPolicy(_ratelimiter);
// todo: why doesn't the timeout work? o.o
var timeoutPolicy = Policy.TimeoutAsync<HttpResponseMessage>(TimeSpan.FromSeconds(10));
var waitPolicy = Policy
.Handle<RatelimitBucketExhaustedException>()
.WaitAndRetryAsync(3,
(_, e, _) => ((RatelimitBucketExhaustedException) e).RetryAfter,
(_, _, _, _) => Task.CompletedTask)
.AsAsyncPolicy<HttpResponseMessage>();
_retryPolicy = Policy.WrapAsync(timeoutPolicy, waitPolicy, discordPolicy);
}
public HttpClient Client { get; }
public ValueTask DisposeAsync()
{
_ratelimiter.Dispose();
Client.Dispose();
return default;
}
public async Task<T?> Get<T>(string path, (string endpointName, ulong major) ratelimitParams) where T: class
{
var request = new HttpRequestMessage(HttpMethod.Get, ApiBaseUrl + path);
var response = await Send(request, ratelimitParams, true);
// GET-only special case: 404s are nulls and not exceptions
if (response.StatusCode == HttpStatusCode.NotFound)
return null;
return await ReadResponse<T>(response);
}
public async Task<T?> Post<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Post, ApiBaseUrl + path);
SetRequestJsonBody(request, body);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> PostMultipart<T>(string path, (string endpointName, ulong major) ratelimitParams, object? payload, MultipartFile[]? files)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Post, ApiBaseUrl + path);
SetRequestFormDataBody(request, payload, files);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> Patch<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Patch, ApiBaseUrl + path);
SetRequestJsonBody(request, body);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> Put<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Put, ApiBaseUrl + path);
SetRequestJsonBody(request, body);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task Delete(string path, (string endpointName, ulong major) ratelimitParams)
{
var request = new HttpRequestMessage(HttpMethod.Delete, ApiBaseUrl + path);
await Send(request, ratelimitParams);
}
private void SetRequestJsonBody(HttpRequestMessage request, object? body)
{
if (body == null) return;
request.Content =
new ReadOnlyMemoryContent(JsonSerializer.SerializeToUtf8Bytes(body, _jsonSerializerOptions));
request.Content.Headers.ContentType = new MediaTypeHeaderValue("application/json");
}
private void SetRequestFormDataBody(HttpRequestMessage request, object? payload, MultipartFile[]? files)
{
var bodyJson = JsonSerializer.SerializeToUtf8Bytes(payload, _jsonSerializerOptions);
var mfd = new MultipartFormDataContent();
mfd.Add(new ByteArrayContent(bodyJson), "payload_json");
if (files != null)
{
for (var i = 0; i < files.Length; i++)
{
var (filename, stream) = files[i];
mfd.Add(new StreamContent(stream), $"file{i}", filename);
}
}
request.Content = mfd;
}
private async Task<T?> ReadResponse<T>(HttpResponseMessage response) where T: class
{
if (response.StatusCode == HttpStatusCode.NoContent)
return null;
return await response.Content.ReadFromJsonAsync<T>(_jsonSerializerOptions);
}
private async Task<HttpResponseMessage> Send(HttpRequestMessage request,
(string endpointName, ulong major) ratelimitParams,
bool ignoreNotFound = false)
{
return await _retryPolicy.ExecuteAsync(async _ =>
{
_logger.Debug("Sending request: {RequestMethod} {RequestPath}",
request.Method, request.RequestUri);
request.Version = _httpVersion;
request.VersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
var stopwatch = new Stopwatch();
stopwatch.Start();
var response = await Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
stopwatch.Stop();
_logger.Debug(
"Received response in {ResponseDurationMs} ms: {RequestMethod} {RequestPath} -> {StatusCode} {ReasonPhrase}",
stopwatch.ElapsedMilliseconds, request.Method, request.RequestUri, (int) response.StatusCode,
response.ReasonPhrase);
await HandleApiError(response, ignoreNotFound);
return response;
},
new Dictionary<string, object>
{
{DiscordRateLimitPolicy.EndpointContextKey, ratelimitParams.endpointName},
{DiscordRateLimitPolicy.MajorContextKey, ratelimitParams.major}
});
}
private async ValueTask HandleApiError(HttpResponseMessage response, bool ignoreNotFound)
{
if (response.IsSuccessStatusCode)
return;
if (response.StatusCode == HttpStatusCode.NotFound && ignoreNotFound)
return;
throw await CreateDiscordException(response);
}
private async ValueTask<DiscordRequestException> CreateDiscordException(HttpResponseMessage response)
{
var body = await response.Content.ReadAsStringAsync();
var apiError = TryParseApiError(body);
return response.StatusCode switch
{
HttpStatusCode.BadRequest => new BadRequestException(response, body, apiError),
HttpStatusCode.Forbidden => new ForbiddenException(response, body, apiError),
HttpStatusCode.Unauthorized => new UnauthorizedException(response, body, apiError),
HttpStatusCode.NotFound => new NotFoundException(response, body, apiError),
HttpStatusCode.Conflict => new ConflictException(response, body, apiError),
HttpStatusCode.TooManyRequests => new TooManyRequestsException(response, body, apiError),
_ => new UnknownDiscordRequestException(response, body, apiError)
};
}
private DiscordApiError? TryParseApiError(string responseBody)
{
if (string.IsNullOrWhiteSpace(responseBody))
return null;
try
{
return JsonSerializer.Deserialize<DiscordApiError>(responseBody, _jsonSerializerOptions);
}
catch (JsonException e)
{
_logger.Verbose(e, "Error deserializing API error");
}
return null;
}
}
}

View File

@ -0,0 +1,130 @@
using System;
using System.IO;
using System.Net;
using System.Threading.Tasks;
using Myriad.Rest.Types;
using Myriad.Rest.Types.Requests;
using Myriad.Types;
using Serilog;
namespace Myriad.Rest
{
public class DiscordApiClient
{
private const string UserAgent = "Test Discord Library by @Ske#6201";
private readonly BaseRestClient _client;
public DiscordApiClient(string token, ILogger logger)
{
_client = new BaseRestClient(UserAgent, token, logger);
}
public Task<GatewayInfo> GetGateway() =>
_client.Get<GatewayInfo>("/gateway", ("GetGateway", default))!;
public Task<GatewayInfo.Bot> GetGatewayBot() =>
_client.Get<GatewayInfo.Bot>("/gateway/bot", ("GetGatewayBot", default))!;
public Task<Channel?> GetChannel(ulong channelId) =>
_client.Get<Channel>($"/channels/{channelId}", ("GetChannel", channelId));
public Task<Message?> GetMessage(ulong channelId, ulong messageId) =>
_client.Get<Message>($"/channels/{channelId}/messages/{messageId}", ("GetMessage", channelId));
public Task<Guild?> GetGuild(ulong id) =>
_client.Get<Guild>($"/guilds/{id}", ("GetGuild", id));
public Task<Channel[]> GetGuildChannels(ulong id) =>
_client.Get<Channel[]>($"/guilds/{id}/channels", ("GetGuildChannels", id))!;
public Task<User?> GetUser(ulong id) =>
_client.Get<User>($"/users/{id}", ("GetUser", default));
public Task<GuildMember?> GetGuildMember(ulong guildId, ulong userId) =>
_client.Get<GuildMember>($"/guilds/{guildId}/members/{userId}",
("GetGuildMember", guildId));
public Task<Message> CreateMessage(ulong channelId, MessageRequest request, MultipartFile[]? files = null) =>
_client.PostMultipart<Message>($"/channels/{channelId}/messages", ("CreateMessage", channelId), request, files)!;
public Task<Message> EditMessage(ulong channelId, ulong messageId, MessageEditRequest request) =>
_client.Patch<Message>($"/channels/{channelId}/messages/{messageId}", ("EditMessage", channelId), request)!;
public Task DeleteMessage(ulong channelId, ulong messageId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}", ("DeleteMessage", channelId));
public Task CreateReaction(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Put<object>($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/@me",
("CreateReaction", channelId), null);
public Task DeleteOwnReaction(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/@me",
("DeleteOwnReaction", channelId));
public Task DeleteUserReaction(ulong channelId, ulong messageId, Emoji emoji, ulong userId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/{userId}",
("DeleteUserReaction", channelId));
public Task DeleteAllReactions(ulong channelId, ulong messageId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions",
("DeleteAllReactions", channelId));
public Task DeleteAllReactionsForEmoji(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}",
("DeleteAllReactionsForEmoji", channelId));
public Task<ApplicationCommand> CreateGlobalApplicationCommand(ulong applicationId,
ApplicationCommandRequest request) =>
_client.Post<ApplicationCommand>($"/applications/{applicationId}/commands",
("CreateGlobalApplicationCommand", applicationId), request)!;
public Task<ApplicationCommand[]> GetGuildApplicationCommands(ulong applicationId, ulong guildId) =>
_client.Get<ApplicationCommand[]>($"/applications/{applicationId}/guilds/{guildId}/commands",
("GetGuildApplicationCommands", applicationId))!;
public Task<ApplicationCommand> CreateGuildApplicationCommand(ulong applicationId, ulong guildId,
ApplicationCommandRequest request) =>
_client.Post<ApplicationCommand>($"/applications/{applicationId}/guilds/{guildId}/commands",
("CreateGuildApplicationCommand", applicationId), request)!;
public Task<ApplicationCommand> EditGuildApplicationCommand(ulong applicationId, ulong guildId,
ApplicationCommandRequest request) =>
_client.Patch<ApplicationCommand>($"/applications/{applicationId}/guilds/{guildId}/commands",
("EditGuildApplicationCommand", applicationId), request)!;
public Task DeleteGuildApplicationCommand(ulong applicationId, ulong commandId) =>
_client.Delete($"/applications/{applicationId}/commands/{commandId}",
("DeleteGuildApplicationCommand", applicationId));
public Task CreateInteractionResponse(ulong interactionId, string token, InteractionResponse response) =>
_client.Post<object>($"/interactions/{interactionId}/{token}/callback",
("CreateInteractionResponse", interactionId), response);
public Task ModifyGuildMember(ulong guildId, ulong userId, ModifyGuildMemberRequest request) =>
_client.Patch<object>($"/guilds/{guildId}/members/{userId}",
("ModifyGuildMember", guildId), request);
public Task<Webhook> CreateWebhook(ulong channelId, CreateWebhookRequest request) =>
_client.Post<Webhook>($"/channels/{channelId}/webhooks", ("CreateWebhook", channelId), request)!;
public Task<Webhook> GetWebhook(ulong webhookId) =>
_client.Get<Webhook>($"/webhooks/{webhookId}/webhooks", ("GetWebhook", webhookId))!;
public Task<Webhook[]> GetChannelWebhooks(ulong channelId) =>
_client.Get<Webhook[]>($"/channels/{channelId}/webhooks", ("GetChannelWebhooks", channelId))!;
public Task<Message> ExecuteWebhook(ulong webhookId, string webhookToken, ExecuteWebhookRequest request,
MultipartFile[]? files = null) =>
_client.PostMultipart<Message>($"/webhooks/{webhookId}/{webhookToken}?wait=true",
("ExecuteWebhook", webhookId), request, files)!;
public Task<Channel> CreateDm(ulong recipientId) =>
_client.Post<Channel>($"/users/@me/channels", ("CreateDM", default), new CreateDmRequest(recipientId))!;
private static string EncodeEmoji(Emoji emoji) =>
WebUtility.UrlEncode(emoji.Name) ?? emoji.Id?.ToString() ??
throw new ArgumentException("Could not encode emoji");
}
}

View File

@ -0,0 +1,9 @@
using System.Text.Json;
namespace Myriad.Rest
{
public record DiscordApiError(string Message, int Code)
{
public JsonElement? Errors { get; init; }
}
}

View File

@ -0,0 +1,71 @@
using System;
using System.Net;
using System.Net.Http;
namespace Myriad.Rest.Exceptions
{
public class DiscordRequestException: Exception
{
public DiscordRequestException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError)
{
ResponseBody = responseBody;
Response = response;
ApiError = apiError;
}
public string ResponseBody { get; init; } = null!;
public HttpResponseMessage Response { get; init; } = null!;
public HttpStatusCode StatusCode => Response.StatusCode;
public int? ErrorCode => ApiError?.Code;
internal DiscordApiError? ApiError { get; init; }
public override string Message =>
(ApiError?.Message ?? Response.ReasonPhrase ?? "") + (FormError != null ? $": {FormError}" : "");
public string? FormError => ApiError?.Errors?.ToString();
}
public class NotFoundException: DiscordRequestException
{
public NotFoundException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError): base(
response, responseBody, apiError) { }
}
public class UnauthorizedException: DiscordRequestException
{
public UnauthorizedException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError): base(
response, responseBody, apiError) { }
}
public class ForbiddenException: DiscordRequestException
{
public ForbiddenException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError): base(
response, responseBody, apiError) { }
}
public class ConflictException: DiscordRequestException
{
public ConflictException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError): base(
response, responseBody, apiError) { }
}
public class BadRequestException: DiscordRequestException
{
public BadRequestException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError): base(
response, responseBody, apiError) { }
}
public class TooManyRequestsException: DiscordRequestException
{
public TooManyRequestsException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError):
base(response, responseBody, apiError) { }
}
public class UnknownDiscordRequestException: DiscordRequestException
{
public UnknownDiscordRequestException(HttpResponseMessage response, string responseBody,
DiscordApiError? apiError): base(response, responseBody, apiError) { }
}
}

View File

@ -0,0 +1,29 @@
using System;
using Myriad.Rest.Ratelimit;
namespace Myriad.Rest.Exceptions
{
public class RatelimitException: Exception
{
public RatelimitException(string? message): base(message) { }
}
public class RatelimitBucketExhaustedException: RatelimitException
{
public RatelimitBucketExhaustedException(Bucket bucket, TimeSpan retryAfter): base(
"Rate limit bucket exhausted, request blocked")
{
Bucket = bucket;
RetryAfter = retryAfter;
}
public Bucket Bucket { get; }
public TimeSpan RetryAfter { get; }
}
public class GloballyRatelimitedException: RatelimitException
{
public GloballyRatelimitedException(): base("Global rate limit hit") { }
}
}

View File

@ -0,0 +1,173 @@
using System;
using System.Threading;
using Serilog;
namespace Myriad.Rest.Ratelimit
{
public class Bucket
{
private static readonly TimeSpan Epsilon = TimeSpan.FromMilliseconds(10);
private static readonly TimeSpan FallbackDelay = TimeSpan.FromMilliseconds(200);
private static readonly TimeSpan StaleTimeout = TimeSpan.FromSeconds(5);
private readonly ILogger _logger;
private readonly SemaphoreSlim _semaphore = new(1, 1);
private DateTimeOffset? _nextReset;
private bool _resetTimeValid;
private bool _hasReceivedHeaders;
public Bucket(ILogger logger, string key, ulong major, int limit)
{
_logger = logger.ForContext<Bucket>();
Key = key;
Major = major;
Limit = limit;
Remaining = limit;
_resetTimeValid = false;
}
public string Key { get; }
public ulong Major { get; }
public int Remaining { get; private set; }
public int Limit { get; private set; }
public DateTimeOffset LastUsed { get; private set; } = DateTimeOffset.UtcNow;
public bool TryAcquire()
{
LastUsed = DateTimeOffset.Now;
try
{
_semaphore.Wait();
if (Remaining > 0)
{
_logger.Debug(
"{BucketKey}/{BucketMajor}: Bucket has [{BucketRemaining}/{BucketLimit} left], allowing through",
Key, Major, Remaining, Limit);
Remaining--;
return true;
}
_logger.Debug("{BucketKey}/{BucketMajor}: Bucket has [{BucketRemaining}/{BucketLimit}] left, denying",
Key, Major, Remaining, Limit);
return false;
}
finally
{
_semaphore.Release();
}
}
public void HandleResponse(RatelimitHeaders headers)
{
try
{
_semaphore.Wait();
_logger.Verbose("{BucketKey}/{BucketMajor}: Received rate limit headers: {@RateLimitHeaders}",
Key, Major, headers);
if (headers.ResetAfter != null)
{
var headerNextReset = DateTimeOffset.UtcNow + headers.ResetAfter.Value; // todo: server time
if (_nextReset == null || headerNextReset > _nextReset)
{
_logger.Debug("{BucketKey}/{BucketMajor}: Received reset time {NextReset} from server (after: {NextResetAfter}, remaining: {Remaining}, local remaining: {LocalRemaining})",
Key, Major, headerNextReset, headers.ResetAfter.Value, headers.Remaining, Remaining);
_nextReset = headerNextReset;
_resetTimeValid = true;
}
}
if (headers.Limit != null)
Limit = headers.Limit.Value;
if (headers.Remaining != null && !_hasReceivedHeaders)
{
var oldRemaining = Remaining;
Remaining = Math.Min(headers.Remaining.Value, Remaining);
_logger.Debug("{BucketKey}/{BucketMajor}: Received first remaining of {HeaderRemaining}, previous local remaining is {LocalRemaining}, new local remaining is {Remaining}",
Key, Major, headers.Remaining.Value, oldRemaining, Remaining);
_hasReceivedHeaders = true;
}
}
finally
{
_semaphore.Release();
}
}
public void Tick(DateTimeOffset now)
{
try
{
_semaphore.Wait();
// If we don't have any reset data, "snap" it to now
// This happens before first request and at this point the reset is invalid anyway, so it's fine
// but it ensures the stale timeout doesn't trigger early by using `default` value
if (_nextReset == null)
_nextReset = now;
// If we're past the reset time *and* we haven't reset already, do that
var timeSinceReset = now - _nextReset;
var shouldReset = _resetTimeValid && timeSinceReset > TimeSpan.Zero;
if (shouldReset)
{
_logger.Debug("{BucketKey}/{BucketMajor}: Bucket timed out, refreshing with {BucketLimit} requests",
Key, Major, Limit);
Remaining = Limit;
_resetTimeValid = false;
return;
}
// We've run out of requests without having any new reset time,
// *and* it's been longer than a set amount - add one request back to the pool and hope that one returns
var isBucketStale = !_resetTimeValid && Remaining <= 0 && timeSinceReset > StaleTimeout;
if (isBucketStale)
{
_logger.Warning(
"{BucketKey}/{BucketMajor}: Bucket is stale ({StaleTimeout} passed with no rate limit info), allowing one request through",
Key, Major, StaleTimeout);
Remaining = 1;
// Reset the (still-invalid) reset time to now, so we don't keep hitting this conditional over and over...
_nextReset = now;
}
}
finally
{
_semaphore.Release();
}
}
public TimeSpan GetResetDelay(DateTimeOffset now)
{
// If we don't have a valid reset time, return the fallback delay always
// (so it'll keep spinning until we hopefully have one...)
if (!_resetTimeValid)
return FallbackDelay;
var delay = (_nextReset ?? now) - now;
// If we have a really small (or negative) value, return a fallback delay too
if (delay < Epsilon)
return FallbackDelay;
return delay;
}
}
}

View File

@ -0,0 +1,80 @@
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
using Serilog;
namespace Myriad.Rest.Ratelimit
{
public class BucketManager: IDisposable
{
private static readonly TimeSpan StaleBucketTimeout = TimeSpan.FromMinutes(5);
private static readonly TimeSpan PruneWorkerInterval = TimeSpan.FromMinutes(1);
private readonly ConcurrentDictionary<(string key, ulong major), Bucket> _buckets = new();
private readonly ConcurrentDictionary<string, string> _endpointKeyMap = new();
private readonly ConcurrentDictionary<string, int> _knownKeyLimits = new();
private readonly ILogger _logger;
private readonly Task _worker;
private readonly CancellationTokenSource _workerCts = new();
public BucketManager(ILogger logger)
{
_logger = logger.ForContext<BucketManager>();
_worker = PruneWorker(_workerCts.Token);
}
public void Dispose()
{
_workerCts.Dispose();
_worker.Dispose();
}
public Bucket? GetBucket(string endpoint, ulong major)
{
if (!_endpointKeyMap.TryGetValue(endpoint, out var key))
return null;
if (_buckets.TryGetValue((key, major), out var bucket))
return bucket;
if (!_knownKeyLimits.TryGetValue(key, out var knownLimit))
return null;
_logger.Debug("Creating new bucket {BucketKey}/{BucketMajor} with limit {KnownLimit}", key, major, knownLimit);
return _buckets.GetOrAdd((key, major),
k => new Bucket(_logger, k.Item1, k.Item2, knownLimit));
}
public void UpdateEndpointInfo(string endpoint, string key, int? limit)
{
_endpointKeyMap[endpoint] = key;
if (limit != null)
_knownKeyLimits[key] = limit.Value;
}
private async Task PruneWorker(CancellationToken ct)
{
while (!ct.IsCancellationRequested)
{
await Task.Delay(PruneWorkerInterval, ct);
PruneStaleBuckets(DateTimeOffset.UtcNow);
}
}
private void PruneStaleBuckets(DateTimeOffset now)
{
foreach (var (key, bucket) in _buckets)
if (now - bucket.LastUsed > StaleBucketTimeout)
{
_logger.Debug("Pruning unused bucket {Bucket} (last used at {BucketLastUsed})", bucket,
bucket.LastUsed);
_buckets.TryRemove(key, out _);
}
}
}
}

View File

@ -0,0 +1,46 @@
using System;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Polly;
namespace Myriad.Rest.Ratelimit
{
public class DiscordRateLimitPolicy: AsyncPolicy<HttpResponseMessage>
{
public const string EndpointContextKey = "Endpoint";
public const string MajorContextKey = "Major";
private readonly Ratelimiter _ratelimiter;
public DiscordRateLimitPolicy(Ratelimiter ratelimiter, PolicyBuilder<HttpResponseMessage>? policyBuilder = null)
: base(policyBuilder)
{
_ratelimiter = ratelimiter;
}
protected override async Task<HttpResponseMessage> ImplementationAsync(
Func<Context, CancellationToken, Task<HttpResponseMessage>> action, Context context, CancellationToken ct,
bool continueOnCapturedContext)
{
if (!context.TryGetValue(EndpointContextKey, out var endpointObj) || !(endpointObj is string endpoint))
throw new ArgumentException("Must provide endpoint in Polly context");
if (!context.TryGetValue(MajorContextKey, out var majorObj) || !(majorObj is ulong major))
throw new ArgumentException("Must provide major in Polly context");
// Check rate limit, throw if we're not allowed...
_ratelimiter.AllowRequestOrThrow(endpoint, major, DateTimeOffset.Now);
// We're OK, push it through
var response = await action(context, ct).ConfigureAwait(continueOnCapturedContext);
// Update rate limit state with headers
var headers = RatelimitHeaders.Parse(response);
_ratelimiter.HandleResponse(headers, endpoint, major);
return response;
}
}
}

View File

@ -0,0 +1,85 @@
using System;
using System.Globalization;
using System.Linq;
using System.Net.Http;
namespace Myriad.Rest.Ratelimit
{
public record RatelimitHeaders
{
private const string LimitHeader = "X-RateLimit-Limit";
private const string RemainingHeader = "X-RateLimit-Remaining";
private const string ResetHeader = "X-RateLimit-Reset";
private const string ResetAfterHeader = "X-RateLimit-Reset-After";
private const string BucketHeader = "X-RateLimit-Bucket";
private const string GlobalHeader = "X-RateLimit-Global";
public bool Global { get; private set; }
public int? Limit { get; private set; }
public int? Remaining { get; private set; }
public DateTimeOffset? Reset { get; private set; }
public TimeSpan? ResetAfter { get; private set; }
public string? Bucket { get; private set; }
public DateTimeOffset? ServerDate { get; private set; }
public bool HasRatelimitInfo =>
Limit != null && Remaining != null && Reset != null && ResetAfter != null && Bucket != null;
public RatelimitHeaders() { }
public static RatelimitHeaders Parse(HttpResponseMessage response)
{
var headers = new RatelimitHeaders
{
ServerDate = response.Headers.Date,
Limit = TryGetInt(response, LimitHeader),
Remaining = TryGetInt(response, RemainingHeader),
Bucket = TryGetHeader(response, BucketHeader)
};
var resetTimestamp = TryGetDouble(response, ResetHeader);
if (resetTimestamp != null)
headers.Reset = DateTimeOffset.FromUnixTimeMilliseconds((long) (resetTimestamp.Value * 1000));
var resetAfterSeconds = TryGetDouble(response, ResetAfterHeader);
if (resetAfterSeconds != null)
headers.ResetAfter = TimeSpan.FromSeconds(resetAfterSeconds.Value);
var global = TryGetHeader(response, GlobalHeader);
if (global != null && bool.TryParse(global, out var globalBool))
headers.Global = globalBool;
return headers;
}
private static string? TryGetHeader(HttpResponseMessage response, string headerName)
{
if (!response.Headers.TryGetValues(headerName, out var values))
return null;
return values.FirstOrDefault();
}
private static int? TryGetInt(HttpResponseMessage response, string headerName)
{
var valueString = TryGetHeader(response, headerName);
if (!int.TryParse(valueString, NumberStyles.Integer, CultureInfo.InvariantCulture, out var value))
return null;
return value;
}
private static double? TryGetDouble(HttpResponseMessage response, string headerName)
{
var valueString = TryGetHeader(response, headerName);
if (!double.TryParse(valueString, NumberStyles.Float, CultureInfo.InvariantCulture, out var value))
return null;
return value;
}
}
}

View File

@ -0,0 +1,86 @@
using System;
using Myriad.Rest.Exceptions;
using Serilog;
namespace Myriad.Rest.Ratelimit
{
public class Ratelimiter: IDisposable
{
private readonly BucketManager _buckets;
private readonly ILogger _logger;
private DateTimeOffset? _globalRateLimitExpiry;
public Ratelimiter(ILogger logger)
{
_logger = logger.ForContext<Ratelimiter>();
_buckets = new BucketManager(logger);
}
public void Dispose()
{
_buckets.Dispose();
}
public void AllowRequestOrThrow(string endpoint, ulong major, DateTimeOffset now)
{
if (IsGloballyRateLimited(now))
{
_logger.Warning("Globally rate limited until {GlobalRateLimitExpiry}, cancelling request",
_globalRateLimitExpiry);
throw new GloballyRatelimitedException();
}
var bucket = _buckets.GetBucket(endpoint, major);
if (bucket == null)
{
// No rate limit for this endpoint (yet), allow through
_logger.Debug("No rate limit data for endpoint {Endpoint}, allowing through", endpoint);
return;
}
bucket.Tick(now);
if (bucket.TryAcquire())
// We're allowed to send it! :)
return;
// We can't send this request right now; retrying...
var waitTime = bucket.GetResetDelay(now);
// add a small buffer for Timing:tm:
waitTime += TimeSpan.FromMilliseconds(50);
// (this is caught by a WaitAndRetry Polly handler, if configured)
throw new RatelimitBucketExhaustedException(bucket, waitTime);
}
public void HandleResponse(RatelimitHeaders headers, string endpoint, ulong major)
{
if (!headers.HasRatelimitInfo)
return;
// TODO: properly calculate server time?
if (headers.Global)
{
_logger.Warning(
"Global rate limit hit, resetting at {GlobalRateLimitExpiry} (in {GlobalRateLimitResetAfter}!",
_globalRateLimitExpiry, headers.ResetAfter);
_globalRateLimitExpiry = headers.Reset;
}
else
{
// Update buckets first, then get it again, to properly "transfer" this info over to the new value
_buckets.UpdateEndpointInfo(endpoint, headers.Bucket!, headers.Limit);
var bucket = _buckets.GetBucket(endpoint, major);
bucket?.HandleResponse(headers);
}
}
private bool IsGloballyRateLimited(DateTimeOffset now) =>
_globalRateLimitExpiry > now;
}
}

View File

@ -0,0 +1,19 @@
using System.Collections.Generic;
namespace Myriad.Rest.Types
{
public record AllowedMentions
{
public enum ParseType
{
Roles,
Users,
Everyone
}
public ParseType[]? Parse { get; set; }
public ulong[]? Users { get; set; }
public ulong[]? Roles { get; set; }
public bool RepliedUser { get; set; }
}
}

View File

@ -0,0 +1,6 @@
using System.IO;
namespace Myriad.Rest.Types
{
public record MultipartFile(string Filename, Stream Data);
}

View File

@ -0,0 +1,13 @@
using System.Collections.Generic;
using Myriad.Types;
namespace Myriad.Rest.Types
{
public record ApplicationCommandRequest
{
public string Name { get; init; }
public string Description { get; init; }
public List<ApplicationCommandOption>? Options { get; init; }
}
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Rest.Types.Requests
{
public record CreateDmRequest(ulong RecipientId);
}

View File

@ -0,0 +1,4 @@
namespace Myriad.Rest.Types.Requests
{
public record CreateWebhookRequest(string Name);
}

View File

@ -0,0 +1,13 @@
using Myriad.Types;
namespace Myriad.Rest.Types.Requests
{
public record ExecuteWebhookRequest
{
public string? Content { get; init; }
public string? Username { get; init; }
public string? AvatarUrl { get; init; }
public Embed[] Embeds { get; init; }
public AllowedMentions? AllowedMentions { get; init; }
}
}

View File

@ -0,0 +1,22 @@
using System.Text.Json.Serialization;
using Myriad.Types;
using Myriad.Utils;
namespace Myriad.Rest.Types.Requests
{
public record MessageEditRequest
{
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<string?> Content { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<Embed?> Embed { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<Message.MessageFlags> Flags { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<AllowedMentions> AllowedMentions { get; init; }
}
}

View File

@ -0,0 +1,13 @@
using Myriad.Types;
namespace Myriad.Rest.Types.Requests
{
public record MessageRequest
{
public string? Content { get; set; }
public object? Nonce { get; set; }
public bool Tts { get; set; }
public AllowedMentions? AllowedMentions { get; set; }
public Embed? Embed { get; set; }
}
}

View File

@ -0,0 +1,7 @@
namespace Myriad.Rest.Types
{
public record ModifyGuildMemberRequest
{
public string? Nick { get; init; }
}
}

View File

@ -0,0 +1,21 @@
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Myriad.Serialization
{
public static class JsonSerializerOptionsExtensions
{
public static JsonSerializerOptions ConfigureForMyriad(this JsonSerializerOptions opts)
{
opts.PropertyNamingPolicy = new JsonSnakeCaseNamingPolicy();
opts.NumberHandling = JsonNumberHandling.AllowReadingFromString;
opts.IncludeFields = true;
opts.Converters.Add(new PermissionSetJsonConverter());
opts.Converters.Add(new ShardInfoJsonConverter());
opts.Converters.Add(new OptionalConverterFactory());
return opts;
}
}
}

View File

@ -0,0 +1,88 @@
using System;
using System.Text;
using System.Text.Json;
namespace Myriad.Serialization
{
// From https://github.com/J0rgeSerran0/JsonNamingPolicy/blob/master/JsonSnakeCaseNamingPolicy.cs, no NuGet :/
public class JsonSnakeCaseNamingPolicy: JsonNamingPolicy
{
private readonly string _separator = "_";
public override string ConvertName(string name)
{
if (string.IsNullOrEmpty(name) || string.IsNullOrWhiteSpace(name)) return string.Empty;
ReadOnlySpan<char> spanName = name.Trim();
var stringBuilder = new StringBuilder();
var addCharacter = true;
var isPreviousSpace = false;
var isPreviousSeparator = false;
var isCurrentSpace = false;
var isNextLower = false;
var isNextUpper = false;
var isNextSpace = false;
for (var position = 0; position < spanName.Length; position++)
{
if (position != 0)
{
isCurrentSpace = spanName[position] == 32;
isPreviousSpace = spanName[position - 1] == 32;
isPreviousSeparator = spanName[position - 1] == 95;
if (position + 1 != spanName.Length)
{
isNextLower = spanName[position + 1] > 96 && spanName[position + 1] < 123;
isNextUpper = spanName[position + 1] > 64 && spanName[position + 1] < 91;
isNextSpace = spanName[position + 1] == 32;
}
if (isCurrentSpace &&
(isPreviousSpace ||
isPreviousSeparator ||
isNextUpper ||
isNextSpace))
{
addCharacter = false;
}
else
{
var isCurrentUpper = spanName[position] > 64 && spanName[position] < 91;
var isPreviousLower = spanName[position - 1] > 96 && spanName[position - 1] < 123;
var isPreviousNumber = spanName[position - 1] > 47 && spanName[position - 1] < 58;
if (isCurrentUpper &&
(isPreviousLower ||
isPreviousNumber ||
isNextLower ||
isNextSpace ||
isNextLower && !isPreviousSpace))
{
stringBuilder.Append(_separator);
}
else
{
if (isCurrentSpace &&
!isPreviousSpace &&
!isNextSpace)
{
stringBuilder.Append(_separator);
addCharacter = false;
}
}
}
}
if (addCharacter)
stringBuilder.Append(spanName[position]);
else
addCharacter = true;
}
return stringBuilder.ToString().ToLower();
}
}
}

View File

@ -0,0 +1,22 @@
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Myriad.Serialization
{
public class JsonStringConverter: JsonConverter<object>
{
public override object? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var str = JsonSerializer.Deserialize<string>(ref reader);
var inner = JsonSerializer.Deserialize(str!, typeToConvert, options);
return inner;
}
public override void Write(Utf8JsonWriter writer, object value, JsonSerializerOptions options)
{
var inner = JsonSerializer.Serialize(value, options);
writer.WriteStringValue(inner);
}
}
}

View File

@ -0,0 +1,48 @@
using System;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;
using Myriad.Utils;
namespace Myriad.Serialization
{
public class OptionalConverterFactory: JsonConverterFactory
{
public class Inner<T>: JsonConverter<Optional<T>>
{
public override Optional<T> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var inner = JsonSerializer.Deserialize<T>(ref reader, options);
return new(inner!);
}
public override void Write(Utf8JsonWriter writer, Optional<T> value, JsonSerializerOptions options)
{
JsonSerializer.Serialize(writer, value.HasValue ? value.GetValue() : default, typeof(T), options);
}
}
public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options)
{
var innerType = typeToConvert.GetGenericArguments()[0];
return (JsonConverter?) Activator.CreateInstance(
typeof(Inner<>).MakeGenericType(innerType),
BindingFlags.Instance | BindingFlags.Public,
null,
null,
null);
}
public override bool CanConvert(Type typeToConvert)
{
if (!typeToConvert.IsGenericType)
return false;
if (typeToConvert.GetGenericTypeDefinition() != typeof(Optional<>))
return false;
return true;
}
}
}

View File

@ -0,0 +1,24 @@
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
using Myriad.Types;
namespace Myriad.Serialization
{
public class PermissionSetJsonConverter: JsonConverter<PermissionSet>
{
public override PermissionSet Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var str = reader.GetString();
if (str == null) return default;
return (PermissionSet) ulong.Parse(str);
}
public override void Write(Utf8JsonWriter writer, PermissionSet value, JsonSerializerOptions options)
{
writer.WriteStringValue(((ulong) value).ToString());
}
}
}

View File

@ -0,0 +1,28 @@
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
using Myriad.Gateway;
namespace Myriad.Serialization
{
public class ShardInfoJsonConverter: JsonConverter<ShardInfo>
{
public override ShardInfo? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var arr = JsonSerializer.Deserialize<int[]>(ref reader);
if (arr?.Length != 2)
throw new JsonException("Expected shard info as array of length 2");
return new ShardInfo(arr[0], arr[1]);
}
public override void Write(Utf8JsonWriter writer, ShardInfo value, JsonSerializerOptions options)
{
writer.WriteStartArray();
writer.WriteNumberValue(value.ShardId);
writer.WriteNumberValue(value.NumShards);
writer.WriteEndArray();
}
}
}

22
Myriad/Types/Activity.cs Normal file
View File

@ -0,0 +1,22 @@
namespace Myriad.Types
{
public record Activity: ActivityPartial
{
}
public record ActivityPartial
{
public string Name { get; init; }
public ActivityType Type { get; init; }
public string? Url { get; init; }
}
public enum ActivityType
{
Game = 0,
Streaming = 1,
Listening = 2,
Custom = 4,
Competing = 5
}
}

View File

@ -0,0 +1,27 @@
using System.Collections.Generic;
namespace Myriad.Types
{
public record Application: ApplicationPartial
{
public string Name { get; init; }
public string? Icon { get; init; }
public string Description { get; init; }
public string[]? RpcOrigins { get; init; }
public bool BotPublic { get; init; }
public bool BotRequireCodeGrant { get; init; }
public User Owner { get; init; } // TODO: docs specify this is "partial", what does that mean
public string Summary { get; init; }
public string VerifyKey { get; init; }
public ulong? GuildId { get; init; }
public ulong? PrimarySkuId { get; init; }
public string? Slug { get; init; }
public string? CoverImage { get; init; }
}
public record ApplicationPartial
{
public ulong Id { get; init; }
public int Flags { get; init; }
}
}

View File

@ -0,0 +1,13 @@
using System.Collections.Generic;
namespace Myriad.Types
{
public record ApplicationCommand
{
public ulong Id { get; init; }
public ulong ApplicationId { get; init; }
public string Name { get; init; }
public string Description { get; init; }
public ApplicationCommandOption[]? Options { get; init; }
}
}

View File

@ -0,0 +1,9 @@
namespace Myriad.Types
{
public record ApplicationCommandInteractionData
{
public ulong Id { get; init; }
public string Name { get; init; }
public ApplicationCommandInteractionDataOption[] Options { get; init; }
}
}

View File

@ -0,0 +1,9 @@
namespace Myriad.Types
{
public record ApplicationCommandInteractionDataOption
{
public string Name { get; init; }
public object? Value { get; init; }
public ApplicationCommandInteractionDataOption[]? Options { get; init; }
}
}

View File

@ -0,0 +1,24 @@
namespace Myriad.Types
{
public record ApplicationCommandOption(ApplicationCommandOption.OptionType Type, string Name, string Description)
{
public enum OptionType
{
Subcommand = 1,
SubcommandGroup = 2,
String = 3,
Integer = 4,
Boolean = 5,
User = 6,
Channel = 7,
Role = 8
}
public bool Default { get; init; }
public bool Required { get; init; }
public Choice[]? Choices { get; init; }
public ApplicationCommandOption[]? Options { get; init; }
public record Choice(string Name, object Value);
}
}

View File

@ -0,0 +1,19 @@
namespace Myriad.Types
{
public record Interaction
{
public enum InteractionType
{
Ping = 1,
ApplicationCommand = 2
}
public ulong Id { get; init; }
public InteractionType Type { get; init; }
public ApplicationCommandInteractionData? Data { get; init; }
public ulong GuildId { get; init; }
public ulong ChannelId { get; init; }
public GuildMember Member { get; init; }
public string Token { get; init; }
}
}

View File

@ -0,0 +1,15 @@
using System.Collections.Generic;
using Myriad.Rest.Types;
namespace Myriad.Types
{
public record InteractionApplicationCommandCallbackData
{
public bool? Tts { get; init; }
public string Content { get; init; }
public Embed[]? Embeds { get; init; }
public AllowedMentions? AllowedMentions { get; init; }
public Message.MessageFlags Flags { get; init; }
}
}

View File

@ -0,0 +1,17 @@
namespace Myriad.Types
{
public record InteractionResponse
{
public enum ResponseType
{
Pong = 1,
Acknowledge = 2,
ChannelMessage = 3,
ChannelMessageWithSource = 4,
AckWithSource = 5
}
public ResponseType Type { get; init; }
public InteractionApplicationCommandCallbackData? Data { get; init; }
}
}

41
Myriad/Types/Channel.cs Normal file
View File

@ -0,0 +1,41 @@
namespace Myriad.Types
{
public record Channel
{
public enum ChannelType
{
GuildText = 0,
Dm = 1,
GuildVoice = 2,
GroupDm = 3,
GuildCategory = 4,
GuildNews = 5,
GuildStore = 6
}
public ulong Id { get; init; }
public ChannelType Type { get; init; }
public ulong? GuildId { get; init; }
public int? Position { get; init; }
public string? Name { get; init; }
public string? Topic { get; init; }
public bool? Nsfw { get; init; }
public ulong? ParentId { get; init; }
public Overwrite[]? PermissionOverwrites { get; init; }
public User[]? Recipients { get; init; }
public record Overwrite
{
public ulong Id { get; init; }
public OverwriteType Type { get; init; }
public PermissionSet Allow { get; init; }
public PermissionSet Deny { get; init; }
}
public enum OverwriteType
{
Role = 0,
Member = 1
}
}
}

62
Myriad/Types/Embed.cs Normal file
View File

@ -0,0 +1,62 @@
namespace Myriad.Types
{
public record Embed
{
public string? Title { get; init; }
public string? Type { get; init; }
public string? Description { get; init; }
public string? Url { get; init; }
public string? Timestamp { get; init; }
public uint? Color { get; init; }
public EmbedFooter? Footer { get; init; }
public EmbedImage? Image { get; init; }
public EmbedThumbnail? Thumbnail { get; init; }
public EmbedVideo? Video { get; init; }
public EmbedProvider? Provider { get; init; }
public EmbedAuthor? Author { get; init; }
public Field[]? Fields { get; init; }
public record EmbedFooter (
string Text,
string? IconUrl = null,
string? ProxyIconUrl = null
);
public record EmbedImage (
string? Url,
uint? Width = null,
uint? Height = null
);
public record EmbedThumbnail (
string? Url,
string? ProxyUrl = null,
uint? Width = null,
uint? Height = null
);
public record EmbedVideo (
string? Url,
uint? Width = null,
uint? Height = null
);
public record EmbedProvider (
string? Name,
string? Url
);
public record EmbedAuthor (
string? Name = null,
string? Url = null,
string? IconUrl = null,
string? ProxyIconUrl = null
);
public record Field (
string Name,
string Value,
bool Inline = false
);
}
}

9
Myriad/Types/Emoji.cs Normal file
View File

@ -0,0 +1,9 @@
namespace Myriad.Types
{
public record Emoji
{
public ulong? Id { get; init; }
public string? Name { get; init; }
public bool? Animated { get; init; }
}
}

View File

@ -0,0 +1,13 @@
namespace Myriad.Types
{
public record GatewayInfo
{
public string Url { get; init; }
public record Bot: GatewayInfo
{
public int Shards { get; init; }
public SessionStartLimit SessionStartLimit { get; init; }
}
}
}

View File

@ -0,0 +1,10 @@
namespace Myriad.Types
{
public record SessionStartLimit
{
public int Total { get; init; }
public int Remaining { get; init; }
public int ResetAfter { get; init; }
public int MaxConcurrency { get; init; }
}
}

24
Myriad/Types/Guild.cs Normal file
View File

@ -0,0 +1,24 @@
using System.Collections.Generic;
namespace Myriad.Types
{
public record Guild
{
public ulong Id { get; init; }
public string Name { get; init; }
public string? Icon { get; init; }
public string? Splash { get; init; }
public string? DiscoverySplash { get; init; }
public bool? Owner { get; init; }
public ulong OwnerId { get; init; }
public string Region { get; init; }
public ulong? AfkChannelId { get; init; }
public int AfkTimeout { get; init; }
public bool? WidgetEnabled { get; init; }
public bool? WidgetChannelId { get; init; }
public int VerificationLevel { get; init; }
public Role[] Roles { get; init; }
public string[] Features { get; init; }
}
}

View File

@ -0,0 +1,14 @@
namespace Myriad.Types
{
public record GuildMember: GuildMemberPartial
{
public User User { get; init; }
}
public record GuildMemberPartial
{
public string? Nick { get; init; }
public ulong[] Roles { get; init; }
public string JoinedAt { get; init; }
}
}

88
Myriad/Types/Message.cs Normal file
View File

@ -0,0 +1,88 @@
using System;
using System.Collections.Generic;
using System.Net.Mail;
using System.Text.Json.Serialization;
using Myriad.Utils;
namespace Myriad.Types
{
public record Message
{
[Flags]
public enum MessageFlags
{
Crossposted = 1 << 0,
IsCrosspost = 1 << 1,
SuppressEmbeds = 1 << 2,
SourceMessageDeleted = 1 << 3,
Urgent = 1 << 4,
Ephemeral = 1 << 6
}
public enum MessageType
{
Default = 0,
RecipientAdd = 1,
RecipientRemove = 2,
Call = 3,
ChannelNameChange = 4,
ChannelIconChange = 5,
ChannelPinnedMessage = 6,
GuildMemberJoin = 7,
UserPremiumGuildSubscription = 8,
UserPremiumGuildSubscriptionTier1 = 9,
UserPremiumGuildSubscriptionTier2 = 10,
UserPremiumGuildSubscriptionTier3 = 11,
ChannelFollowAdd = 12,
GuildDiscoveryDisqualified = 14,
GuildDiscoveryRequalified = 15,
Reply = 19,
ApplicationCommand = 20
}
public ulong Id { get; init; }
public ulong ChannelId { get; init; }
public ulong? GuildId { get; init; }
public User Author { get; init; }
public string? Content { get; init; }
public string? Timestamp { get; init; }
public string? EditedTimestamp { get; init; }
public bool Tts { get; init; }
public bool MentionEveryone { get; init; }
public User.Extra[] Mentions { get; init; }
public ulong[] MentionRoles { get; init; }
public Attachment[] Attachments { get; init; }
public Embed[] Embeds { get; init; }
public Reaction[] Reactions { get; init; }
public bool Pinned { get; init; }
public ulong? WebhookId { get; init; }
public MessageType Type { get; init; }
public Reference? MessageReference { get; set; }
public MessageFlags Flags { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<Message?> ReferencedMessage { get; init; }
public record Reference(ulong? GuildId, ulong? ChannelId, ulong? MessageId);
public record Attachment
{
public ulong Id { get; init; }
public string Filename { get; init; }
public int Size { get; init; }
public string Url { get; init; }
public string ProxyUrl { get; init; }
public int? Width { get; init; }
public int? Height { get; init; }
}
public record Reaction
{
public int Count { get; init; }
public bool Me { get; init; }
public Emoji Emoji { get; init; }
}
}
}

View File

@ -0,0 +1,47 @@
using System;
namespace Myriad.Types
{
[Flags]
public enum PermissionSet: ulong
{
CreateInvite = 0x1,
KickMembers = 0x2,
BanMembers = 0x4,
Administrator = 0x8,
ManageChannels = 0x10,
ManageGuild = 0x20,
AddReactions = 0x40,
ViewAuditLog = 0x80,
PrioritySpeaker = 0x100,
Stream = 0x200,
ViewChannel = 0x400,
SendMessages = 0x800,
SendTtsMessages = 0x1000,
ManageMessages = 0x2000,
EmbedLinks = 0x4000,
AttachFiles = 0x8000,
ReadMessageHistory = 0x10000,
MentionEveryone = 0x20000,
UseExternalEmojis = 0x40000,
ViewGuildInsights = 0x80000,
Connect = 0x100000,
Speak = 0x200000,
MuteMembers = 0x400000,
DeafenMembers = 0x800000,
MoveMembers = 0x1000000,
UseVad = 0x2000000,
ChangeNickname = 0x4000000,
ManageNicknames = 0x8000000,
ManageRoles = 0x10000000,
ManageWebhooks = 0x20000000,
ManageEmojis = 0x40000000,
// Special:
None = 0,
All = 0x7FFFFFFF,
Dm = ViewChannel | SendMessages | ReadMessageHistory | AddReactions | AttachFiles | EmbedLinks |
UseExternalEmojis | Connect | Speak | UseVad
}
}

View File

@ -0,0 +1,6 @@
namespace Myriad.Types
{
public static class Permissions
{
}
}

14
Myriad/Types/Role.cs Normal file
View File

@ -0,0 +1,14 @@
namespace Myriad.Types
{
public record Role
{
public ulong Id { get; init; }
public string Name { get; init; }
public uint Color { get; init; }
public bool Hoist { get; init; }
public int Position { get; init; }
public PermissionSet Permissions { get; init; }
public bool Managed { get; init; }
public bool Mentionable { get; init; }
}
}

38
Myriad/Types/User.cs Normal file
View File

@ -0,0 +1,38 @@
using System;
namespace Myriad.Types
{
public record User
{
[Flags]
public enum Flags
{
DiscordEmployee = 1 << 0,
PartneredServerOwner = 1 << 1,
HypeSquadEvents = 1 << 2,
BugHunterLevel1 = 1 << 3,
HouseBravery = 1 << 6,
HouseBrilliance = 1 << 7,
HouseBalance = 1 << 8,
EarlySupporter = 1 << 9,
TeamUser = 1 << 10,
System = 1 << 12,
BugHunterLevel2 = 1 << 14,
VerifiedBot = 1 << 16,
EarlyVerifiedBotDeveloper = 1 << 17
}
public ulong Id { get; init; }
public string Username { get; init; }
public string Discriminator { get; init; }
public string? Avatar { get; init; }
public bool Bot { get; init; }
public bool? System { get; init; }
public Flags PublicFlags { get; init; }
public record Extra: User
{
public GuildMemberPartial? Member { get; init; }
}
}
}

21
Myriad/Types/Webhook.cs Normal file
View File

@ -0,0 +1,21 @@
namespace Myriad.Types
{
public record Webhook
{
public ulong Id { get; init; }
public WebhookType Type { get; init; }
public ulong? GuildId { get; init; }
public ulong ChannelId { get; init; }
public User? User { get; init; }
public string? Name { get; init; }
public string? Avatar { get; init; }
public string? Token { get; init; }
public ulong? ApplicationId { get; init; }
}
public enum WebhookType
{
Incoming = 1,
ChannelFollower = 2
}
}

26
Myriad/Utils/Optional.cs Normal file
View File

@ -0,0 +1,26 @@
namespace Myriad.Utils
{
public interface IOptional
{
object? GetValue();
}
public readonly struct Optional<T>: IOptional
{
public Optional(T value)
{
HasValue = true;
Value = value;
}
public bool HasValue { get; }
public object? GetValue() => Value;
public T Value { get; }
public static implicit operator Optional<T>(T value) => new(value);
public static Optional<T> Some(T value) => new(value);
public static Optional<T> None() => default;
}
}

View File

@ -1,4 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net.WebSockets;
@ -9,10 +10,12 @@ using App.Metrics;
using Autofac;
using DSharpPlus;
using DSharpPlus.Entities;
using DSharpPlus.EventArgs;
using DSharpPlus.Exceptions;
using Myriad.Cache;
using Myriad.Extensions;
using Myriad.Gateway;
using Myriad.Rest;
using Myriad.Rest.Exceptions;
using Myriad.Types;
using NodaTime;
@ -27,47 +30,38 @@ namespace PluralKit.Bot
{
public class Bot
{
private readonly DiscordShardedClient _client;
private readonly ConcurrentDictionary<ulong, GuildMemberPartial> _guildMembers = new();
private readonly Cluster _cluster;
private readonly DiscordApiClient _rest;
private readonly ILogger _logger;
private readonly ILifetimeScope _services;
private readonly PeriodicStatCollector _collector;
private readonly IMetrics _metrics;
private readonly ErrorMessageService _errorMessageService;
private readonly CommandMessageService _commandMessageService;
private readonly IDiscordCache _cache;
private bool _hasReceivedReady = false;
private Timer _periodicTask; // Never read, just kept here for GC reasons
public Bot(DiscordShardedClient client, ILifetimeScope services, ILogger logger, PeriodicStatCollector collector, IMetrics metrics,
ErrorMessageService errorMessageService, CommandMessageService commandMessageService)
public Bot(ILifetimeScope services, ILogger logger, PeriodicStatCollector collector, IMetrics metrics,
ErrorMessageService errorMessageService, CommandMessageService commandMessageService, Cluster cluster, DiscordApiClient rest, IDiscordCache cache)
{
_client = client;
_logger = logger.ForContext<Bot>();
_services = services;
_collector = collector;
_metrics = metrics;
_errorMessageService = errorMessageService;
_commandMessageService = commandMessageService;
_cluster = cluster;
_rest = rest;
_cache = cache;
}
public void Init()
{
// HandleEvent takes a type parameter, automatically inferred by the event type
// It will then look up an IEventHandler<TypeOfEvent> in the DI container and call that object's handler method
// For registering new ones, see Modules.cs
_client.MessageCreated += HandleEvent;
_client.MessageDeleted += HandleEvent;
_client.MessageUpdated += HandleEvent;
_client.MessagesBulkDeleted += HandleEvent;
_client.MessageReactionAdded += HandleEvent;
// Update shard status for shards immediately on connect
_client.Ready += (client, _) =>
{
_hasReceivedReady = true;
return UpdateBotStatus(client);
};
_client.Resumed += (client, _) => UpdateBotStatus(client);
_cluster.EventReceived += OnEventReceived;
// Init the shard stuff
_services.Resolve<ShardInfoService>().Init();
@ -83,6 +77,69 @@ namespace PluralKit.Bot
}, null, timeTillNextWholeMinute, TimeSpan.FromMinutes(1));
}
public PermissionSet PermissionsIn(ulong channelId)
{
var channel = _cache.GetChannel(channelId);
if (channel.GuildId != null)
{
var member = _guildMembers.GetValueOrDefault(channel.GuildId.Value);
return _cache.PermissionsFor(channelId, _cluster.User?.Id ?? default, member?.Roles);
}
return PermissionSet.Dm;
}
private async Task OnEventReceived(Shard shard, IGatewayEvent evt)
{
await _cache.HandleGatewayEvent(evt);
TryUpdateSelfMember(shard, evt);
// HandleEvent takes a type parameter, automatically inferred by the event type
// It will then look up an IEventHandler<TypeOfEvent> in the DI container and call that object's handler method
// For registering new ones, see Modules.cs
if (evt is MessageCreateEvent mc)
await HandleEvent(shard, mc);
if (evt is MessageUpdateEvent mu)
await HandleEvent(shard, mu);
if (evt is MessageDeleteEvent md)
await HandleEvent(shard, md);
if (evt is MessageDeleteBulkEvent mdb)
await HandleEvent(shard, mdb);
if (evt is MessageReactionAddEvent mra)
await HandleEvent(shard, mra);
// Update shard status for shards immediately on connect
if (evt is ReadyEvent re)
await HandleReady(shard, re);
if (evt is ResumedEvent)
await HandleResumed(shard);
}
private void TryUpdateSelfMember(Shard shard, IGatewayEvent evt)
{
if (evt is GuildCreateEvent gc)
_guildMembers[gc.Id] = gc.Members.FirstOrDefault(m => m.User.Id == shard.User?.Id);
if (evt is MessageCreateEvent mc && mc.Member != null && mc.Author.Id == shard.User?.Id)
_guildMembers[mc.GuildId!.Value] = mc.Member;
if (evt is GuildMemberAddEvent gma && gma.User.Id == shard.User?.Id)
_guildMembers[gma.GuildId] = gma;
if (evt is GuildMemberUpdateEvent gmu && gmu.User.Id == shard.User?.Id)
_guildMembers[gmu.GuildId] = gmu;
}
private Task HandleResumed(Shard shard)
{
return UpdateBotStatus(shard);
}
private Task HandleReady(Shard shard, ReadyEvent _)
{
_hasReceivedReady = true;
return UpdateBotStatus(shard);
}
public async Task Shutdown()
{
// This will stop the timer and prevent any subsequent invocations
@ -92,10 +149,24 @@ namespace PluralKit.Bot
// We're not actually properly disconnecting from the gateway (lol) so it'll linger for a few minutes
// Should be plenty of time for the bot to connect again next startup and set the real status
if (_hasReceivedReady)
await _client.UpdateStatusAsync(new DiscordActivity("Restarting... (please wait)"), UserStatus.Idle);
{
await Task.WhenAll(_cluster.Shards.Values.Select(shard =>
shard.UpdateStatus(new GatewayStatusUpdate
{
Activities = new[]
{
new ActivityPartial
{
Name = "Restarting... (please wait)",
Type = ActivityType.Game
}
},
Status = GatewayStatusUpdate.UserStatus.Idle
})));
}
}
private Task HandleEvent<T>(DiscordClient shard, T evt) where T: DiscordEventArgs
private Task HandleEvent<T>(Shard shard, T evt) where T: IGatewayEvent
{
// We don't want to stall the event pipeline, so we'll "fork" inside here
var _ = HandleEventInner();
@ -103,6 +174,8 @@ namespace PluralKit.Bot
async Task HandleEventInner()
{
await Task.Yield();
using var _ = LogContext.PushProperty("EventId", Guid.NewGuid());
_logger
.ForContext("Elastic", "yes?")
@ -121,7 +194,7 @@ namespace PluralKit.Bot
try
{
using var timer = _metrics.Measure.Timer.Time(BotMetrics.EventsHandled,
new MetricTags("event", typeof(T).Name.Replace("EventArgs", "")));
new MetricTags("event", typeof(T).Name.Replace("Event", "")));
// Delegate to the queue to see if it wants to handle this event
// the TryHandle call returns true if it's handled the event
@ -131,13 +204,13 @@ namespace PluralKit.Bot
}
catch (Exception exc)
{
await HandleError(handler, evt, serviceScope, exc);
await HandleError(shard, handler, evt, serviceScope, exc);
}
}
}
private async Task HandleError<T>(IEventHandler<T> handler, T evt, ILifetimeScope serviceScope, Exception exc)
where T: DiscordEventArgs
private async Task HandleError<T>(Shard shard, IEventHandler<T> handler, T evt, ILifetimeScope serviceScope, Exception exc)
where T: IGatewayEvent
{
_metrics.Measure.Meter.Mark(BotMetrics.BotErrors, exc.GetType().FullName);
@ -149,7 +222,7 @@ namespace PluralKit.Bot
.Error(exc, "Exception in event handler: {SentryEventId}", sentryEvent.EventId);
// If the event is us responding to our own error messages, don't bother logging
if (evt is MessageCreateEventArgs mc && mc.Author.Id == _client.CurrentUser.Id)
if (evt is MessageCreateEvent mc && mc.Author.Id == shard.User?.Id)
return;
var shouldReport = exc.IsOurProblem();
@ -160,19 +233,24 @@ namespace PluralKit.Bot
var sentryScope = serviceScope.Resolve<Scope>();
// Add some specific info about Discord error responses, as a breadcrumb
if (exc is BadRequestException bre)
sentryScope.AddBreadcrumb(bre.WebResponse.Response, "response.error", data: new Dictionary<string, string>(bre.WebResponse.Headers));
if (exc is NotFoundException nfe)
sentryScope.AddBreadcrumb(nfe.WebResponse.Response, "response.error", data: new Dictionary<string, string>(nfe.WebResponse.Headers));
if (exc is UnauthorizedException ue)
sentryScope.AddBreadcrumb(ue.WebResponse.Response, "response.error", data: new Dictionary<string, string>(ue.WebResponse.Headers));
// TODO: headers to dict
// if (exc is BadRequestException bre)
// sentryScope.AddBreadcrumb(bre.Response, "response.error", data: new Dictionary<string, string>(bre.Response.Headers));
// if (exc is NotFoundException nfe)
// sentryScope.AddBreadcrumb(nfe.Response, "response.error", data: new Dictionary<string, string>(nfe.Response.Headers));
// if (exc is UnauthorizedException ue)
// sentryScope.AddBreadcrumb(ue.Response, "response.error", data: new Dictionary<string, string>(ue.Response.Headers));
SentrySdk.CaptureEvent(sentryEvent, sentryScope);
// Once we've sent it to Sentry, report it to the user (if we have permission to)
var reportChannel = handler.ErrorChannelFor(evt);
if (reportChannel != null && reportChannel.BotHasAllPermissions(Permissions.SendMessages | Permissions.EmbedLinks))
await _errorMessageService.SendErrorMessage(reportChannel, sentryEvent.EventId.ToString());
if (reportChannel != null)
{
var botPerms = PermissionsIn(reportChannel.Value);
if (botPerms.HasFlag(PermissionSet.SendMessages | PermissionSet.EmbedLinks))
await _errorMessageService.SendErrorMessage(reportChannel.Value, sentryEvent.EventId.ToString());
}
}
}
@ -191,23 +269,38 @@ namespace PluralKit.Bot
_logger.Debug("Submitted metrics to backend");
}
private async Task UpdateBotStatus(DiscordClient specificShard = null)
private async Task UpdateBotStatus(Shard specificShard = null)
{
// If we're not on any shards, don't bother (this happens if the periodic timer fires before the first Ready)
if (!_hasReceivedReady) return;
var totalGuilds = _client.ShardClients.Values.Sum(c => c.Guilds.Count);
var totalGuilds = await _cache.GetAllGuilds().CountAsync();
try // DiscordClient may throw an exception if the socket is closed (e.g just after OP 7 received)
{
Task UpdateStatus(DiscordClient shard) =>
shard.UpdateStatusAsync(new DiscordActivity($"pk;help | in {totalGuilds} servers | shard #{shard.ShardId}"));
Task UpdateStatus(Shard shard) =>
shard.UpdateStatus(new GatewayStatusUpdate
{
Activities = new[]
{
new ActivityPartial
{
Name = $"pk;help | in {totalGuilds} servers | shard #{shard.ShardInfo?.ShardId}",
Type = ActivityType.Game,
Url = "https://pluralkit.me/"
}
}
});
if (specificShard != null)
await UpdateStatus(specificShard);
else // Run shard updates concurrently
await Task.WhenAll(_client.ShardClients.Values.Select(UpdateStatus));
await Task.WhenAll(_cluster.Shards.Values.Select(UpdateStatus));
}
catch (WebSocketException)
{
// TODO: this still thrown?
}
catch (WebSocketException) { }
}
}
}

View File

@ -1,13 +1,17 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using App.Metrics;
using Autofac;
using DSharpPlus;
using DSharpPlus.Entities;
using Myriad.Cache;
using Myriad.Extensions;
using Myriad.Gateway;
using Myriad.Rest;
using Myriad.Rest.Types;
using Myriad.Rest.Types.Requests;
using Myriad.Types;
using PluralKit.Core;
@ -17,47 +21,65 @@ namespace PluralKit.Bot
{
private readonly ILifetimeScope _provider;
private readonly DiscordRestClient _rest;
private readonly DiscordShardedClient _client;
private readonly DiscordClient _shard;
private readonly DiscordMessage _message;
private readonly DiscordApiClient _rest;
private readonly Cluster _cluster;
private readonly Shard _shard;
private readonly Guild? _guild;
private readonly Channel _channel;
private readonly MessageCreateEvent _message;
private readonly Parameters _parameters;
private readonly MessageContext _messageContext;
private readonly PermissionSet _botPermissions;
private readonly PermissionSet _userPermissions;
private readonly IDatabase _db;
private readonly ModelRepository _repo;
private readonly PKSystem _senderSystem;
private readonly IMetrics _metrics;
private readonly CommandMessageService _commandMessageService;
private readonly IDiscordCache _cache;
private Command _currentCommand;
public Context(ILifetimeScope provider, DiscordClient shard, DiscordMessage message, int commandParseOffset,
PKSystem senderSystem, MessageContext messageContext)
public Context(ILifetimeScope provider, Shard shard, Guild? guild, Channel channel, MessageCreateEvent message, int commandParseOffset,
PKSystem senderSystem, MessageContext messageContext, PermissionSet botPermissions)
{
_rest = provider.Resolve<DiscordRestClient>();
_client = provider.Resolve<DiscordShardedClient>();
_message = message;
_shard = shard;
_guild = guild;
_channel = channel;
_senderSystem = senderSystem;
_messageContext = messageContext;
_cache = provider.Resolve<IDiscordCache>();
_db = provider.Resolve<IDatabase>();
_repo = provider.Resolve<ModelRepository>();
_metrics = provider.Resolve<IMetrics>();
_provider = provider;
_commandMessageService = provider.Resolve<CommandMessageService>();
_parameters = new Parameters(message.Content.Substring(commandParseOffset));
_parameters = new Parameters(message.Content?.Substring(commandParseOffset));
_rest = provider.Resolve<DiscordApiClient>();
_cluster = provider.Resolve<Cluster>();
_botPermissions = botPermissions;
_userPermissions = _cache.PermissionsFor(message);
}
public DiscordUser Author => _message.Author;
public DiscordChannel Channel => _message.Channel;
public DiscordMessage Message => _message;
public DiscordGuild Guild => _message.Channel.Guild;
public DiscordClient Shard => _shard;
public DiscordShardedClient Client => _client;
public IDiscordCache Cache => _cache;
public Channel Channel => _channel;
public User Author => _message.Author;
public GuildMemberPartial Member => _message.Member;
public Message Message => _message;
public Guild Guild => _guild;
public Shard Shard => _shard;
public Cluster Cluster => _cluster;
public MessageContext MessageContext => _messageContext;
public DiscordRestClient Rest => _rest;
public PermissionSet BotPermissions => _botPermissions;
public PermissionSet UserPermissions => _userPermissions;
public DiscordApiClient Rest => _rest;
public PKSystem System => _senderSystem;
@ -66,15 +88,22 @@ namespace PluralKit.Bot
internal IDatabase Database => _db;
internal ModelRepository Repository => _repo;
public async Task<DiscordMessage> Reply(string text = null, DiscordEmbed embed = null, IEnumerable<IMention> mentions = null)
public async Task<Message> Reply(string text = null, Embed embed = null, AllowedMentions? mentions = null)
{
if (!this.BotHasAllPermissions(Permissions.SendMessages))
if (!BotPermissions.HasFlag(PermissionSet.SendMessages))
// Will be "swallowed" during the error handler anyway, this message is never shown.
throw new PKError("PluralKit does not have permission to send messages in this channel.");
if (embed != null && !this.BotHasAllPermissions(Permissions.EmbedLinks))
if (embed != null && !BotPermissions.HasFlag(PermissionSet.EmbedLinks))
throw new PKError("PluralKit does not have permission to send embeds in this channel. Please ensure I have the **Embed Links** permission enabled.");
var msg = await Channel.SendMessageFixedAsync(text, embed: embed, mentions: mentions);
var msg = await _rest.CreateMessage(_channel.Id, new MessageRequest
{
Content = text,
Embed = embed,
// Default to an empty allowed mentions object instead of null (which means no mentions allowed)
AllowedMentions = mentions ?? new AllowedMentions()
});
if (embed != null)
{

Some files were not shown because too many files have changed in this diff Show More