feat: upgrade to .NET 6, refactor everything

This commit is contained in:
spiral 2021-11-26 21:10:56 -05:00
parent d28e99ba43
commit 1918c56937
No known key found for this signature in database
GPG Key ID: A6059F0CA0E1BD31
314 changed files with 27954 additions and 27966 deletions

View File

@ -1,53 +1,52 @@
[*] [*]
charset=utf-8 charset = utf-8
end_of_line=lf end_of_line = lf
trim_trailing_whitespace=false trim_trailing_whitespace = false
insert_final_newline=false insert_final_newline = false
indent_style=space indent_style = space
indent_size=4 indent_size = 4
# Microsoft .NET properties # Microsoft .NET properties
csharp_preferred_modifier_order=public, private, protected, internal, new, abstract, virtual, sealed, override, static, readonly, extern, unsafe, volatile, async:suggestion csharp_preferred_modifier_order = public, private, protected, internal, new, abstract, virtual, sealed, override, static, readonly, extern, unsafe, volatile, async:suggestion
csharp_space_before_colon_in_inheritance_clause=false csharp_space_before_colon_in_inheritance_clause = false
csharp_style_var_elsewhere=true:hint csharp_style_var_elsewhere = true:hint
csharp_style_var_for_built_in_types=true:hint csharp_style_var_for_built_in_types = true:hint
csharp_style_var_when_type_is_apparent=true:hint csharp_style_var_when_type_is_apparent = true:hint
dotnet_style_predefined_type_for_locals_parameters_members=true:hint dotnet_style_predefined_type_for_locals_parameters_members = true:hint
dotnet_style_predefined_type_for_member_access=true:hint dotnet_style_predefined_type_for_member_access = true:hint
dotnet_style_qualification_for_event=false:warning dotnet_style_qualification_for_event = false:warning
dotnet_style_qualification_for_field=false:warning dotnet_style_qualification_for_field = false:warning
dotnet_style_qualification_for_method=false:warning dotnet_style_qualification_for_method = false:warning
dotnet_style_qualification_for_property=false:warning dotnet_style_qualification_for_property = false:warning
dotnet_style_require_accessibility_modifiers=for_non_interface_members:hint dotnet_style_require_accessibility_modifiers = for_non_interface_members:hint
# ReSharper properties # ReSharper properties
resharper_align_multiline_parameter=true resharper_align_multiline_parameter = true
resharper_autodetect_indent_settings=true resharper_autodetect_indent_settings = true
resharper_blank_lines_between_using_groups=1 resharper_blank_lines_between_using_groups = 1
resharper_braces_for_using=required_for_multiline resharper_braces_for_using = required_for_multiline
resharper_csharp_stick_comment=false resharper_csharp_stick_comment = false
resharper_empty_block_style=together_same_line resharper_empty_block_style = together_same_line
resharper_keep_existing_attribute_arrangement=true resharper_keep_existing_attribute_arrangement = true
resharper_keep_existing_initializer_arrangement=false resharper_keep_existing_initializer_arrangement = false
resharper_local_function_body=expression_body resharper_local_function_body = expression_body
resharper_method_or_operator_body=expression_body resharper_method_or_operator_body = expression_body
resharper_place_accessor_with_attrs_holder_on_single_line=true resharper_place_accessor_with_attrs_holder_on_single_line = true
resharper_place_simple_case_statement_on_same_line=if_owner_is_single_line resharper_place_simple_case_statement_on_same_line = if_owner_is_single_line
resharper_space_before_type_parameter_constraint_colon=false resharper_space_before_type_parameter_constraint_colon = false
resharper_use_indent_from_vs=false resharper_use_indent_from_vs = false
resharper_wrap_before_first_type_parameter_constraint=true resharper_wrap_before_first_type_parameter_constraint = true
# ReSharper inspection severities: # ReSharper inspection severities:
resharper_web_config_module_not_resolved_highlighting=warning resharper_web_config_module_not_resolved_highlighting = warning
resharper_web_config_type_not_resolved_highlighting=warning resharper_web_config_type_not_resolved_highlighting = warning
resharper_web_config_wrong_module_highlighting=warning resharper_web_config_wrong_module_highlighting = warning
[{*.yml,*.yaml}] [{*.yml,*.yaml}]
indent_style=space indent_style = space
indent_size=2 indent_size = 2
[*.{appxmanifest,asax,ascx,aspx,build,config,cs,cshtml,csproj,dbml,discomap,dtd,fs,fsi,fsscript,fsx,htm,html,jsproj,lsproj,master,ml,mli,njsproj,nuspec,proj,props,razor,resw,resx,skin,StyleCop,targets,tasks,vb,vbproj,xaml,xamlx,xml,xoml,xsd}] [*.{appxmanifest,asax,ascx,aspx,build,config,cs,cshtml,csproj,dbml,discomap,dtd,fs,fsi,fsscript,fsx,htm,html,jsproj,lsproj,master,ml,mli,njsproj,nuspec,proj,props,razor,resw,resx,skin,StyleCop,targets,tasks,vb,vbproj,xaml,xamlx,xml,xoml,xsd}]
indent_style=space indent_style = space
indent_size=4 indent_size = 4
tab_width=4 tab_width = 4

View File

@ -1,91 +1,73 @@
using System.Collections.Generic;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Builders namespace Myriad.Builders;
public class EmbedBuilder
{ {
public class EmbedBuilder private readonly List<Embed.Field> _fields = new();
private Embed _embed = new();
public EmbedBuilder Title(string? title)
{ {
private Embed _embed = new(); _embed = _embed with { Title = title };
private readonly List<Embed.Field> _fields = new(); return this;
public EmbedBuilder Title(string? title)
{
_embed = _embed with { Title = title };
return this;
}
public EmbedBuilder Description(string? description)
{
_embed = _embed with { Description = description };
return this;
}
public EmbedBuilder Url(string? url)
{
_embed = _embed with { Url = url };
return this;
}
public EmbedBuilder Color(uint? color)
{
_embed = _embed with { Color = color };
return this;
}
public EmbedBuilder Footer(Embed.EmbedFooter? footer)
{
_embed = _embed with
{
Footer = footer
};
return this;
}
public EmbedBuilder Image(Embed.EmbedImage? image)
{
_embed = _embed with
{
Image = image
};
return this;
}
public EmbedBuilder Thumbnail(Embed.EmbedThumbnail? thumbnail)
{
_embed = _embed with
{
Thumbnail = thumbnail
};
return this;
}
public EmbedBuilder Author(Embed.EmbedAuthor? author)
{
_embed = _embed with
{
Author = author
};
return this;
}
public EmbedBuilder Timestamp(string? timestamp)
{
_embed = _embed with
{
Timestamp = timestamp
};
return this;
}
public EmbedBuilder Field(Embed.Field field)
{
_fields.Add(field);
return this;
}
public Embed Build() =>
_embed with { Fields = _fields.ToArray() };
} }
public EmbedBuilder Description(string? description)
{
_embed = _embed with { Description = description };
return this;
}
public EmbedBuilder Url(string? url)
{
_embed = _embed with { Url = url };
return this;
}
public EmbedBuilder Color(uint? color)
{
_embed = _embed with { Color = color };
return this;
}
public EmbedBuilder Footer(Embed.EmbedFooter? footer)
{
_embed = _embed with { Footer = footer };
return this;
}
public EmbedBuilder Image(Embed.EmbedImage? image)
{
_embed = _embed with { Image = image };
return this;
}
public EmbedBuilder Thumbnail(Embed.EmbedThumbnail? thumbnail)
{
_embed = _embed with { Thumbnail = thumbnail };
return this;
}
public EmbedBuilder Author(Embed.EmbedAuthor? author)
{
_embed = _embed with { Author = author };
return this;
}
public EmbedBuilder Timestamp(string? timestamp)
{
_embed = _embed with { Timestamp = timestamp };
return this;
}
public EmbedBuilder Field(Embed.Field field)
{
_fields.Add(field);
return this;
}
public Embed Build() =>
_embed with { Fields = _fields.ToArray() };
} }

View File

@ -1,125 +1,118 @@
using System.Linq;
using System.Threading.Tasks;
using Myriad.Extensions; using Myriad.Extensions;
using Myriad.Gateway; using Myriad.Gateway;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Cache namespace Myriad.Cache;
public static class DiscordCacheExtensions
{ {
public static class DiscordCacheExtensions public static ValueTask HandleGatewayEvent(this IDiscordCache cache, IGatewayEvent evt)
{ {
public static ValueTask HandleGatewayEvent(this IDiscordCache cache, IGatewayEvent evt) switch (evt)
{ {
switch (evt) case ReadyEvent ready:
{ return cache.SaveOwnUser(ready.User.Id);
case ReadyEvent ready: case GuildCreateEvent gc:
return cache.SaveOwnUser(ready.User.Id); return cache.SaveGuildCreate(gc);
case GuildCreateEvent gc: case GuildUpdateEvent gu:
return cache.SaveGuildCreate(gc); return cache.SaveGuild(gu);
case GuildUpdateEvent gu: case GuildDeleteEvent gd:
return cache.SaveGuild(gu); return cache.RemoveGuild(gd.Id);
case GuildDeleteEvent gd: case ChannelCreateEvent cc:
return cache.RemoveGuild(gd.Id); return cache.SaveChannel(cc);
case ChannelCreateEvent cc: case ChannelUpdateEvent cu:
return cache.SaveChannel(cc); return cache.SaveChannel(cu);
case ChannelUpdateEvent cu: case ChannelDeleteEvent cd:
return cache.SaveChannel(cu); return cache.RemoveChannel(cd.Id);
case ChannelDeleteEvent cd: case GuildRoleCreateEvent grc:
return cache.RemoveChannel(cd.Id); return cache.SaveRole(grc.GuildId, grc.Role);
case GuildRoleCreateEvent grc: case GuildRoleUpdateEvent gru:
return cache.SaveRole(grc.GuildId, grc.Role); return cache.SaveRole(gru.GuildId, gru.Role);
case GuildRoleUpdateEvent gru: case GuildRoleDeleteEvent grd:
return cache.SaveRole(gru.GuildId, gru.Role); return cache.RemoveRole(grd.GuildId, grd.RoleId);
case GuildRoleDeleteEvent grd: case MessageReactionAddEvent mra:
return cache.RemoveRole(grd.GuildId, grd.RoleId); return cache.TrySaveDmChannelStub(mra.GuildId, mra.ChannelId);
case MessageReactionAddEvent mra: case MessageCreateEvent mc:
return cache.TrySaveDmChannelStub(mra.GuildId, mra.ChannelId); return cache.SaveMessageCreate(mc);
case MessageCreateEvent mc: case MessageUpdateEvent mu:
return cache.SaveMessageCreate(mc); return cache.TrySaveDmChannelStub(mu.GuildId.Value, mu.ChannelId);
case MessageUpdateEvent mu: case MessageDeleteEvent md:
return cache.TrySaveDmChannelStub(mu.GuildId.Value, mu.ChannelId); return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId);
case MessageDeleteEvent md: case MessageDeleteBulkEvent md:
return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId); return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId);
case MessageDeleteBulkEvent md: case ThreadCreateEvent tc:
return cache.TrySaveDmChannelStub(md.GuildId, md.ChannelId); return cache.SaveChannel(tc);
case ThreadCreateEvent tc: case ThreadUpdateEvent tu:
return cache.SaveChannel(tc); return cache.SaveChannel(tu);
case ThreadUpdateEvent tu: case ThreadDeleteEvent td:
return cache.SaveChannel(tu); return cache.RemoveChannel(td.Id);
case ThreadDeleteEvent td: case ThreadListSyncEvent tls:
return cache.RemoveChannel(td.Id); return cache.SaveThreadListSync(tls);
case ThreadListSyncEvent tls:
return cache.SaveThreadListSync(tls);
}
return default;
} }
public static ValueTask TryUpdateSelfMember(this IDiscordCache cache, Shard shard, IGatewayEvent evt) return default;
{ }
if (evt is GuildCreateEvent gc)
return cache.SaveSelfMember(gc.Id, gc.Members.FirstOrDefault(m => m.User.Id == shard.User?.Id)!);
if (evt is MessageCreateEvent mc && mc.Member != null && mc.Author.Id == shard.User?.Id)
return cache.SaveSelfMember(mc.GuildId!.Value, mc.Member);
if (evt is GuildMemberAddEvent gma && gma.User.Id == shard.User?.Id)
return cache.SaveSelfMember(gma.GuildId, gma);
if (evt is GuildMemberUpdateEvent gmu && gmu.User.Id == shard.User?.Id)
return cache.SaveSelfMember(gmu.GuildId, gmu);
return default; public static ValueTask TryUpdateSelfMember(this IDiscordCache cache, Shard shard, IGatewayEvent evt)
} {
if (evt is GuildCreateEvent gc)
private static async ValueTask SaveGuildCreate(this IDiscordCache cache, GuildCreateEvent guildCreate) return cache.SaveSelfMember(gc.Id, gc.Members.FirstOrDefault(m => m.User.Id == shard.User?.Id)!);
{ if (evt is MessageCreateEvent mc && mc.Member != null && mc.Author.Id == shard.User?.Id)
await cache.SaveGuild(guildCreate); return cache.SaveSelfMember(mc.GuildId!.Value, mc.Member);
if (evt is GuildMemberAddEvent gma && gma.User.Id == shard.User?.Id)
foreach (var channel in guildCreate.Channels) return cache.SaveSelfMember(gma.GuildId, gma);
// The channel object does not include GuildId for some reason... if (evt is GuildMemberUpdateEvent gmu && gmu.User.Id == shard.User?.Id)
await cache.SaveChannel(channel with { GuildId = guildCreate.Id }); return cache.SaveSelfMember(gmu.GuildId, gmu);
foreach (var member in guildCreate.Members) return default;
await cache.SaveUser(member.User); }
foreach (var thread in guildCreate.Threads) private static async ValueTask SaveGuildCreate(this IDiscordCache cache, GuildCreateEvent guildCreate)
await cache.SaveChannel(thread); {
} await cache.SaveGuild(guildCreate);
private static async ValueTask SaveMessageCreate(this IDiscordCache cache, MessageCreateEvent evt) foreach (var channel in guildCreate.Channels)
{ // The channel object does not include GuildId for some reason...
await cache.TrySaveDmChannelStub(evt.GuildId, evt.ChannelId); await cache.SaveChannel(channel with { GuildId = guildCreate.Id });
await cache.SaveUser(evt.Author); foreach (var member in guildCreate.Members)
foreach (var mention in evt.Mentions) await cache.SaveUser(member.User);
await cache.SaveUser(mention);
} foreach (var thread in guildCreate.Threads)
await cache.SaveChannel(thread);
private static ValueTask TrySaveDmChannelStub(this IDiscordCache cache, ulong? guildId, ulong channelId) }
{
// DM messages don't get Channel Create events first, so we need to save private static async ValueTask SaveMessageCreate(this IDiscordCache cache, MessageCreateEvent evt)
// some kind of stub channel object until we get the real one {
return guildId != null ? default : cache.SaveDmChannelStub(channelId); await cache.TrySaveDmChannelStub(evt.GuildId, evt.ChannelId);
}
await cache.SaveUser(evt.Author);
private static async ValueTask SaveThreadListSync(this IDiscordCache cache, ThreadListSyncEvent evt) foreach (var mention in evt.Mentions)
{ await cache.SaveUser(mention);
foreach (var thread in evt.Threads) }
await cache.SaveChannel(thread);
} private static ValueTask TrySaveDmChannelStub(this IDiscordCache cache, ulong? guildId, ulong channelId) =>
// DM messages don't get Channel Create events first, so we need to save
public static async Task<PermissionSet> PermissionsIn(this IDiscordCache cache, ulong channelId) // some kind of stub channel object until we get the real one
{ guildId != null ? default : cache.SaveDmChannelStub(channelId);
var channel = await cache.GetRootChannel(channelId);
private static async ValueTask SaveThreadListSync(this IDiscordCache cache, ThreadListSyncEvent evt)
if (channel.GuildId != null) {
{ foreach (var thread in evt.Threads)
var userId = await cache.GetOwnUser(); await cache.SaveChannel(thread);
var member = await cache.TryGetSelfMember(channel.GuildId.Value); }
return await cache.PermissionsFor(channelId, userId, member);
} public static async Task<PermissionSet> PermissionsIn(this IDiscordCache cache, ulong channelId)
{
return PermissionSet.Dm; var channel = await cache.GetRootChannel(channelId);
if (channel.GuildId != null)
{
var userId = await cache.GetOwnUser();
var member = await cache.TryGetSelfMember(channel.GuildId.Value);
return await cache.PermissionsFor(channelId, userId, member);
} }
return PermissionSet.Dm;
} }
} }

View File

@ -1,34 +1,30 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Cache namespace Myriad.Cache;
public interface IDiscordCache
{ {
public interface IDiscordCache public ValueTask SaveOwnUser(ulong userId);
{ public ValueTask SaveGuild(Guild guild);
public ValueTask SaveOwnUser(ulong userId); public ValueTask SaveChannel(Channel channel);
public ValueTask SaveGuild(Guild guild); public ValueTask SaveUser(User user);
public ValueTask SaveChannel(Channel channel); public ValueTask SaveSelfMember(ulong guildId, GuildMemberPartial member);
public ValueTask SaveUser(User user); public ValueTask SaveRole(ulong guildId, Role role);
public ValueTask SaveSelfMember(ulong guildId, GuildMemberPartial member); public ValueTask SaveDmChannelStub(ulong channelId);
public ValueTask SaveRole(ulong guildId, Role role);
public ValueTask SaveDmChannelStub(ulong channelId);
public ValueTask RemoveGuild(ulong guildId); public ValueTask RemoveGuild(ulong guildId);
public ValueTask RemoveChannel(ulong channelId); public ValueTask RemoveChannel(ulong channelId);
public ValueTask RemoveUser(ulong userId); public ValueTask RemoveUser(ulong userId);
public ValueTask RemoveRole(ulong guildId, ulong roleId); public ValueTask RemoveRole(ulong guildId, ulong roleId);
public Task<ulong> GetOwnUser(); public Task<ulong> GetOwnUser();
public Task<Guild?> TryGetGuild(ulong guildId); public Task<Guild?> TryGetGuild(ulong guildId);
public Task<Channel?> TryGetChannel(ulong channelId); public Task<Channel?> TryGetChannel(ulong channelId);
public Task<Channel?> TryGetDmChannel(ulong userId); public Task<Channel?> TryGetDmChannel(ulong userId);
public Task<User?> TryGetUser(ulong userId); public Task<User?> TryGetUser(ulong userId);
public Task<GuildMemberPartial?> TryGetSelfMember(ulong guildId); public Task<GuildMemberPartial?> TryGetSelfMember(ulong guildId);
public Task<Role?> TryGetRole(ulong roleId); public Task<Role?> TryGetRole(ulong roleId);
public IAsyncEnumerable<Guild> GetAllGuilds(); public IAsyncEnumerable<Guild> GetAllGuilds();
public Task<IEnumerable<Channel>> GetGuildChannels(ulong guildId); public Task<IEnumerable<Channel>> GetGuildChannels(ulong guildId);
}
} }

View File

@ -1,206 +1,190 @@
using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Cache namespace Myriad.Cache;
public class MemoryDiscordCache: IDiscordCache
{ {
public class MemoryDiscordCache: IDiscordCache private readonly ConcurrentDictionary<ulong, Channel> _channels = new();
private readonly ConcurrentDictionary<ulong, ulong> _dmChannels = new();
private readonly ConcurrentDictionary<ulong, GuildMemberPartial> _guildMembers = new();
private readonly ConcurrentDictionary<ulong, CachedGuild> _guilds = new();
private readonly ConcurrentDictionary<ulong, Role> _roles = new();
private readonly ConcurrentDictionary<ulong, User> _users = new();
private ulong? _ownUserId { get; set; }
public ValueTask SaveGuild(Guild guild)
{ {
private readonly ConcurrentDictionary<ulong, Channel> _channels = new(); SaveGuildRaw(guild);
private readonly ConcurrentDictionary<ulong, ulong> _dmChannels = new();
private readonly ConcurrentDictionary<ulong, CachedGuild> _guilds = new();
private readonly ConcurrentDictionary<ulong, Role> _roles = new();
private readonly ConcurrentDictionary<ulong, User> _users = new();
private readonly ConcurrentDictionary<ulong, GuildMemberPartial> _guildMembers = new();
private ulong? _ownUserId { get; set; }
public ValueTask SaveGuild(Guild guild) foreach (var role in guild.Roles)
{ // Don't call SaveRole because that updates guild state
SaveGuildRaw(guild); // and we just got a brand new one :)
foreach (var role in guild.Roles)
// Don't call SaveRole because that updates guild state
// and we just got a brand new one :)
_roles[role.Id] = role;
return default;
}
public async ValueTask SaveChannel(Channel channel)
{
_channels[channel.Id] = channel;
if (channel.GuildId != null && _guilds.TryGetValue(channel.GuildId.Value, out var guild))
guild.Channels.TryAdd(channel.Id, true);
if (channel.Recipients != null)
{
foreach (var recipient in channel.Recipients)
{
_dmChannels[recipient.Id] = channel.Id;
await SaveUser(recipient);
}
}
}
public ValueTask SaveOwnUser(ulong userId)
{
// this (hopefully) never changes at runtime, so we skip out on re-assigning it
if (_ownUserId == null)
_ownUserId = userId;
return default;
}
public ValueTask SaveUser(User user)
{
_users[user.Id] = user;
return default;
}
public ValueTask SaveSelfMember(ulong guildId, GuildMemberPartial member)
{
_guildMembers[guildId] = member;
return default;
}
public ValueTask SaveRole(ulong guildId, Role role)
{
_roles[role.Id] = role; _roles[role.Id] = role;
if (_guilds.TryGetValue(guildId, out var guild)) return default;
}
public async ValueTask SaveChannel(Channel channel)
{
_channels[channel.Id] = channel;
if (channel.GuildId != null && _guilds.TryGetValue(channel.GuildId.Value, out var guild))
guild.Channels.TryAdd(channel.Id, true);
if (channel.Recipients != null)
foreach (var recipient in channel.Recipients)
{ {
// TODO: this code is stinky _dmChannels[recipient.Id] = channel.Id;
var found = false; await SaveUser(recipient);
for (var i = 0; i < guild.Guild.Roles.Length; i++) }
{ }
if (guild.Guild.Roles[i].Id != role.Id)
continue;
guild.Guild.Roles[i] = role; public ValueTask SaveOwnUser(ulong userId)
found = true; {
} // this (hopefully) never changes at runtime, so we skip out on re-assigning it
if (_ownUserId == null)
_ownUserId = userId;
if (!found) return default;
{ }
_guilds[guildId] = guild with
{ public ValueTask SaveUser(User user)
Guild = guild.Guild with {
{ _users[user.Id] = user;
Roles = guild.Guild.Roles.Concat(new[] { role }).ToArray() return default;
} }
};
} public ValueTask SaveSelfMember(ulong guildId, GuildMemberPartial member)
{
_guildMembers[guildId] = member;
return default;
}
public ValueTask SaveRole(ulong guildId, Role role)
{
_roles[role.Id] = role;
if (_guilds.TryGetValue(guildId, out var guild))
{
// TODO: this code is stinky
var found = false;
for (var i = 0; i < guild.Guild.Roles.Length; i++)
{
if (guild.Guild.Roles[i].Id != role.Id)
continue;
guild.Guild.Roles[i] = role;
found = true;
} }
if (!found)
_guilds[guildId] = guild with
{
Guild = guild.Guild with { Roles = guild.Guild.Roles.Concat(new[] { role }).ToArray() }
};
}
return default;
}
public ValueTask SaveDmChannelStub(ulong channelId)
{
// Use existing channel object if present, otherwise add a stub
// We may get a message create before channel create and we want to have it saved
_channels.GetOrAdd(channelId, id => new Channel { Id = id, Type = Channel.ChannelType.Dm });
return default;
}
public ValueTask RemoveGuild(ulong guildId)
{
_guilds.TryRemove(guildId, out _);
return default;
}
public ValueTask RemoveChannel(ulong channelId)
{
if (!_channels.TryRemove(channelId, out var channel))
return default; return default;
}
public ValueTask SaveDmChannelStub(ulong channelId) if (channel.GuildId != null && _guilds.TryGetValue(channel.GuildId.Value, out var guild))
{ guild.Channels.TryRemove(channel.Id, out _);
// Use existing channel object if present, otherwise add a stub
// We may get a message create before channel create and we want to have it saved
_channels.GetOrAdd(channelId, id => new Channel
{
Id = id,
Type = Channel.ChannelType.Dm
});
return default;
}
public ValueTask RemoveGuild(ulong guildId) return default;
{ }
_guilds.TryRemove(guildId, out _);
return default;
}
public ValueTask RemoveChannel(ulong channelId) public ValueTask RemoveUser(ulong userId)
{ {
if (!_channels.TryRemove(channelId, out var channel)) _users.TryRemove(userId, out _);
return default; return default;
}
if (channel.GuildId != null && _guilds.TryGetValue(channel.GuildId.Value, out var guild)) public Task<ulong> GetOwnUser() => Task.FromResult(_ownUserId!.Value);
guild.Channels.TryRemove(channel.Id, out _);
return default; public ValueTask RemoveRole(ulong guildId, ulong roleId)
} {
_roles.TryRemove(roleId, out _);
return default;
}
public ValueTask RemoveUser(ulong userId) public Task<Guild?> TryGetGuild(ulong guildId)
{ {
_users.TryRemove(userId, out _); _guilds.TryGetValue(guildId, out var cg);
return default; return Task.FromResult(cg?.Guild);
} }
public Task<ulong> GetOwnUser() => Task.FromResult(_ownUserId!.Value); public Task<Channel?> TryGetChannel(ulong channelId)
{
_channels.TryGetValue(channelId, out var channel);
return Task.FromResult(channel);
}
public ValueTask RemoveRole(ulong guildId, ulong roleId) public Task<Channel?> TryGetDmChannel(ulong userId)
{ {
_roles.TryRemove(roleId, out _); if (!_dmChannels.TryGetValue(userId, out var channelId))
return default; return Task.FromResult((Channel?)null);
} return TryGetChannel(channelId);
}
public Task<Guild?> TryGetGuild(ulong guildId) public Task<User?> TryGetUser(ulong userId)
{ {
_guilds.TryGetValue(guildId, out var cg); _users.TryGetValue(userId, out var user);
return Task.FromResult(cg?.Guild); return Task.FromResult(user);
} }
public Task<Channel?> TryGetChannel(ulong channelId) public Task<GuildMemberPartial?> TryGetSelfMember(ulong guildId)
{ {
_channels.TryGetValue(channelId, out var channel); _guildMembers.TryGetValue(guildId, out var guildMember);
return Task.FromResult(channel); return Task.FromResult(guildMember);
} }
public Task<Channel?> TryGetDmChannel(ulong userId) public Task<Role?> TryGetRole(ulong roleId)
{ {
if (!_dmChannels.TryGetValue(userId, out var channelId)) _roles.TryGetValue(roleId, out var role);
return Task.FromResult((Channel?)null); return Task.FromResult(role);
return TryGetChannel(channelId); }
}
public Task<User?> TryGetUser(ulong userId) public IAsyncEnumerable<Guild> GetAllGuilds()
{ {
_users.TryGetValue(userId, out var user); return _guilds.Values
return Task.FromResult(user); .Select(g => g.Guild)
} .ToAsyncEnumerable();
}
public Task<GuildMemberPartial?> TryGetSelfMember(ulong guildId) public Task<IEnumerable<Channel>> GetGuildChannels(ulong guildId)
{ {
_guildMembers.TryGetValue(guildId, out var guildMember); if (!_guilds.TryGetValue(guildId, out var guild))
return Task.FromResult(guildMember); throw new ArgumentException("Guild not found", nameof(guildId));
}
public Task<Role?> TryGetRole(ulong roleId) return Task.FromResult(guild.Channels.Keys.Select(c => _channels[c]));
{ }
_roles.TryGetValue(roleId, out var role);
return Task.FromResult(role);
}
public IAsyncEnumerable<Guild> GetAllGuilds() private CachedGuild SaveGuildRaw(Guild guild) =>
{ _guilds.GetOrAdd(guild.Id, (_, g) => new CachedGuild(g), guild);
return _guilds.Values
.Select(g => g.Guild)
.ToAsyncEnumerable();
}
public Task<IEnumerable<Channel>> GetGuildChannels(ulong guildId) private record CachedGuild(Guild Guild)
{ {
if (!_guilds.TryGetValue(guildId, out var guild)) public readonly ConcurrentDictionary<ulong, bool> Channels = new();
throw new ArgumentException("Guild not found", nameof(guildId));
return Task.FromResult(guild.Channels.Keys.Select(c => _channels[c]));
}
private CachedGuild SaveGuildRaw(Guild guild) =>
_guilds.GetOrAdd(guild.Id, (_, g) => new CachedGuild(g), guild);
private record CachedGuild(Guild Guild)
{
public readonly ConcurrentDictionary<ulong, bool> Channels = new();
}
} }
} }

View File

@ -1,82 +1,81 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using Myriad.Cache; using Myriad.Cache;
using Myriad.Rest; using Myriad.Rest;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Extensions namespace Myriad.Extensions;
public static class CacheExtensions
{ {
public static class CacheExtensions public static async Task<Guild> GetGuild(this IDiscordCache cache, ulong guildId)
{ {
public static async Task<Guild> GetGuild(this IDiscordCache cache, ulong guildId) if (!(await cache.TryGetGuild(guildId) is Guild guild))
{ throw new KeyNotFoundException($"Guild {guildId} not found in cache");
if (!(await cache.TryGetGuild(guildId) is Guild guild)) return guild;
throw new KeyNotFoundException($"Guild {guildId} not found in cache"); }
return guild;
}
public static async Task<Channel> GetChannel(this IDiscordCache cache, ulong channelId) public static async Task<Channel> GetChannel(this IDiscordCache cache, ulong channelId)
{ {
if (!(await cache.TryGetChannel(channelId) is Channel channel)) if (!(await cache.TryGetChannel(channelId) is Channel channel))
throw new KeyNotFoundException($"Channel {channelId} not found in cache"); throw new KeyNotFoundException($"Channel {channelId} not found in cache");
return channel; return channel;
} }
public static async Task<User> GetUser(this IDiscordCache cache, ulong userId) public static async Task<User> GetUser(this IDiscordCache cache, ulong userId)
{ {
if (!(await cache.TryGetUser(userId) is User user)) if (!(await cache.TryGetUser(userId) is User user))
throw new KeyNotFoundException($"User {userId} not found in cache"); throw new KeyNotFoundException($"User {userId} not found in cache");
return user; return user;
} }
public static async Task<Role> GetRole(this IDiscordCache cache, ulong roleId) public static async Task<Role> GetRole(this IDiscordCache cache, ulong roleId)
{ {
if (!(await cache.TryGetRole(roleId) is Role role)) if (!(await cache.TryGetRole(roleId) is Role role))
throw new KeyNotFoundException($"Role {roleId} not found in cache"); throw new KeyNotFoundException($"Role {roleId} not found in cache");
return role; return role;
} }
public static async ValueTask<User?> GetOrFetchUser(this IDiscordCache cache, DiscordApiClient rest, ulong userId) public static async ValueTask<User?> GetOrFetchUser(this IDiscordCache cache, DiscordApiClient rest,
{ ulong userId)
if (await cache.TryGetUser(userId) is User cacheUser) {
return cacheUser; if (await cache.TryGetUser(userId) is User cacheUser)
return cacheUser;
var restUser = await rest.GetUser(userId); var restUser = await rest.GetUser(userId);
if (restUser != null) if (restUser != null)
await cache.SaveUser(restUser); await cache.SaveUser(restUser);
return restUser; return restUser;
} }
public static async ValueTask<Channel?> GetOrFetchChannel(this IDiscordCache cache, DiscordApiClient rest, ulong channelId) public static async ValueTask<Channel?> GetOrFetchChannel(this IDiscordCache cache, DiscordApiClient rest,
{ ulong channelId)
if (await cache.TryGetChannel(channelId) is { } cacheChannel) {
return cacheChannel; if (await cache.TryGetChannel(channelId) is { } cacheChannel)
return cacheChannel;
var restChannel = await rest.GetChannel(channelId); var restChannel = await rest.GetChannel(channelId);
if (restChannel != null) if (restChannel != null)
await cache.SaveChannel(restChannel);
return restChannel;
}
public static async Task<Channel> GetOrCreateDmChannel(this IDiscordCache cache, DiscordApiClient rest, ulong recipientId)
{
if (await cache.TryGetDmChannel(recipientId) is { } cacheChannel)
return cacheChannel;
var restChannel = await rest.CreateDm(recipientId);
await cache.SaveChannel(restChannel); await cache.SaveChannel(restChannel);
return restChannel; return restChannel;
} }
public static async Task<Channel> GetRootChannel(this IDiscordCache cache, ulong channelOrThread) public static async Task<Channel> GetOrCreateDmChannel(this IDiscordCache cache, DiscordApiClient rest,
{ ulong recipientId)
var channel = await cache.GetChannel(channelOrThread); {
if (!channel.IsThread()) if (await cache.TryGetDmChannel(recipientId) is { } cacheChannel)
return channel; return cacheChannel;
var parent = await cache.GetChannel(channel.ParentId!.Value); var restChannel = await rest.CreateDm(recipientId);
return parent; await cache.SaveChannel(restChannel);
} return restChannel;
}
public static async Task<Channel> GetRootChannel(this IDiscordCache cache, ulong channelOrThread)
{
var channel = await cache.GetChannel(channelOrThread);
if (!channel.IsThread())
return channel;
var parent = await cache.GetChannel(channel.ParentId!.Value);
return parent;
} }
} }

View File

@ -1,16 +1,15 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Extensions namespace Myriad.Extensions;
public static class ChannelExtensions
{ {
public static class ChannelExtensions public static string Mention(this Channel channel) => $"<#{channel.Id}>";
{
public static string Mention(this Channel channel) => $"<#{channel.Id}>";
public static bool IsThread(this Channel channel) => channel.Type.IsThread(); public static bool IsThread(this Channel channel) => channel.Type.IsThread();
public static bool IsThread(this Channel.ChannelType type) => public static bool IsThread(this Channel.ChannelType type) =>
type is Channel.ChannelType.GuildPublicThread type is Channel.ChannelType.GuildPublicThread
or Channel.ChannelType.GuildPrivateThread or Channel.ChannelType.GuildPrivateThread
or Channel.ChannelType.GuildNewsThread; or Channel.ChannelType.GuildNewsThread;
}
} }

View File

@ -1,22 +1,21 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Extensions namespace Myriad.Extensions;
public static class GuildExtensions
{ {
public static class GuildExtensions public static int FileSizeLimit(this Guild guild)
{ {
public static int FileSizeLimit(this Guild guild) switch (guild.PremiumTier)
{ {
switch (guild.PremiumTier) default:
{ case PremiumTier.NONE:
default: case PremiumTier.TIER_1:
case PremiumTier.NONE: return 8;
case PremiumTier.TIER_1: case PremiumTier.TIER_2:
return 8; return 50;
case PremiumTier.TIER_2: case PremiumTier.TIER_3:
return 50; return 100;
case PremiumTier.TIER_3:
return 100;
}
} }
} }
} }

View File

@ -1,14 +1,13 @@
using Myriad.Gateway; using Myriad.Gateway;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Extensions namespace Myriad.Extensions;
{
public static class MessageExtensions
{
public static string JumpLink(this Message msg) =>
$"https://discord.com/channels/{msg.GuildId}/{msg.ChannelId}/{msg.Id}";
public static string JumpLink(this MessageReactionAddEvent msg) => public static class MessageExtensions
$"https://discord.com/channels/{msg.GuildId}/{msg.ChannelId}/{msg.MessageId}"; {
} public static string JumpLink(this Message msg) =>
$"https://discord.com/channels/{msg.GuildId}/{msg.ChannelId}/{msg.Id}";
public static string JumpLink(this MessageReactionAddEvent msg) =>
$"https://discord.com/channels/{msg.GuildId}/{msg.ChannelId}/{msg.MessageId}";
} }

View File

@ -1,176 +1,167 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Myriad.Cache; using Myriad.Cache;
using Myriad.Gateway; using Myriad.Gateway;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Extensions namespace Myriad.Extensions;
public static class PermissionExtensions
{ {
public static class PermissionExtensions private const PermissionSet NeedsViewChannel =
PermissionSet.SendMessages |
PermissionSet.SendTtsMessages |
PermissionSet.ManageMessages |
PermissionSet.EmbedLinks |
PermissionSet.AttachFiles |
PermissionSet.ReadMessageHistory |
PermissionSet.MentionEveryone |
PermissionSet.UseExternalEmojis |
PermissionSet.AddReactions |
PermissionSet.Connect |
PermissionSet.Speak |
PermissionSet.MuteMembers |
PermissionSet.DeafenMembers |
PermissionSet.MoveMembers |
PermissionSet.UseVad |
PermissionSet.Stream |
PermissionSet.PrioritySpeaker;
private const PermissionSet NeedsSendMessages =
PermissionSet.MentionEveryone |
PermissionSet.SendTtsMessages |
PermissionSet.AttachFiles |
PermissionSet.EmbedLinks;
public static Task<PermissionSet> PermissionsFor(this IDiscordCache cache, MessageCreateEvent message) =>
PermissionsFor(cache, message.ChannelId, message.Author.Id, message.Member, message.WebhookId != null);
public static Task<PermissionSet>
PermissionsFor(this IDiscordCache cache, ulong channelId, GuildMember member) =>
PermissionsFor(cache, channelId, member.User.Id, member);
public static async Task<PermissionSet> PermissionsFor(this IDiscordCache cache, ulong channelId, ulong userId,
GuildMemberPartial? member, bool isWebhook = false)
{ {
public static Task<PermissionSet> PermissionsFor(this IDiscordCache cache, MessageCreateEvent message) => if (!(await cache.TryGetChannel(channelId) is Channel channel))
PermissionsFor(cache, message.ChannelId, message.Author.Id, message.Member, isWebhook: message.WebhookId != null); // todo: handle channel not found better
return PermissionSet.Dm;
public static Task<PermissionSet> PermissionsFor(this IDiscordCache cache, ulong channelId, GuildMember member) => if (channel.GuildId == null)
PermissionsFor(cache, channelId, member.User.Id, member); return PermissionSet.Dm;
public static async Task<PermissionSet> PermissionsFor(this IDiscordCache cache, ulong channelId, ulong userId, GuildMemberPartial? member, bool isWebhook = false) var rootChannel = await cache.GetRootChannel(channelId);
{
if (!(await cache.TryGetChannel(channelId) is Channel channel))
// todo: handle channel not found better
return PermissionSet.Dm;
if (channel.GuildId == null) var guild = await cache.GetGuild(channel.GuildId.Value);
return PermissionSet.Dm;
var rootChannel = await cache.GetRootChannel(channelId); if (isWebhook)
return EveryonePermissions(guild);
var guild = await cache.GetGuild(channel.GuildId.Value); return PermissionsFor(guild, rootChannel, userId, member);
if (isWebhook)
return EveryonePermissions(guild);
return PermissionsFor(guild, rootChannel, userId, member);
}
public static PermissionSet EveryonePermissions(this Guild guild) =>
guild.Roles.FirstOrDefault(r => r.Id == guild.Id)?.Permissions ?? PermissionSet.Dm;
public static async Task<PermissionSet> EveryonePermissions(this IDiscordCache cache, Channel channel)
{
if (channel.Type == Channel.ChannelType.Dm)
return PermissionSet.Dm;
var defaultPermissions = (await cache.GetGuild(channel.GuildId!.Value)).EveryonePermissions();
var overwrite = channel.PermissionOverwrites?.FirstOrDefault(r => r.Id == channel.GuildId);
if (overwrite == null)
return defaultPermissions;
var perms = defaultPermissions;
perms &= ~overwrite.Deny;
perms |= overwrite.Allow;
return perms;
}
public static PermissionSet PermissionsFor(Guild guild, Channel channel, MessageCreateEvent msg) =>
PermissionsFor(guild, channel, msg.Author.Id, msg.Member);
public static PermissionSet PermissionsFor(Guild guild, Channel channel, ulong userId, GuildMemberPartial? member)
{
if (channel.Type == Channel.ChannelType.Dm)
return PermissionSet.Dm;
if (member == null)
// this happens with system (Discord platform-owned) users - they're not actually in the guild, so there is no member object.
return EveryonePermissions(guild);
var perms = GuildPermissions(guild, userId, member.Roles);
perms = ApplyChannelOverwrites(perms, channel, userId, member.Roles);
if ((perms & PermissionSet.Administrator) == PermissionSet.Administrator)
return PermissionSet.All;
if ((perms & PermissionSet.ViewChannel) == 0)
perms &= ~NeedsViewChannel;
if ((perms & PermissionSet.SendMessages) == 0)
perms &= ~NeedsSendMessages;
return perms;
}
public static PermissionSet GuildPermissions(this Guild guild, ulong userId, ICollection<ulong> roleIds)
{
if (guild.OwnerId == userId)
return PermissionSet.All;
var perms = PermissionSet.None;
foreach (var role in guild.Roles)
{
if (role.Id == guild.Id || roleIds.Contains(role.Id))
perms |= role.Permissions;
}
if (perms.HasFlag(PermissionSet.Administrator))
return PermissionSet.All;
return perms;
}
public static PermissionSet ApplyChannelOverwrites(PermissionSet perms, Channel channel, ulong userId,
ICollection<ulong> roleIds)
{
if (channel.PermissionOverwrites == null)
return perms;
var everyoneDeny = PermissionSet.None;
var everyoneAllow = PermissionSet.None;
var roleDeny = PermissionSet.None;
var roleAllow = PermissionSet.None;
var userDeny = PermissionSet.None;
var userAllow = PermissionSet.None;
foreach (var overwrite in channel.PermissionOverwrites)
{
switch (overwrite.Type)
{
case Channel.OverwriteType.Role when overwrite.Id == channel.GuildId:
everyoneDeny |= overwrite.Deny;
everyoneAllow |= overwrite.Allow;
break;
case Channel.OverwriteType.Role when roleIds.Contains(overwrite.Id):
roleDeny |= overwrite.Deny;
roleAllow |= overwrite.Allow;
break;
case Channel.OverwriteType.Member when overwrite.Id == userId:
userDeny |= overwrite.Deny;
userAllow |= overwrite.Allow;
break;
}
}
perms &= ~everyoneDeny;
perms |= everyoneAllow;
perms &= ~roleDeny;
perms |= roleAllow;
perms &= ~userDeny;
perms |= userAllow;
return perms;
}
private const PermissionSet NeedsViewChannel =
PermissionSet.SendMessages |
PermissionSet.SendTtsMessages |
PermissionSet.ManageMessages |
PermissionSet.EmbedLinks |
PermissionSet.AttachFiles |
PermissionSet.ReadMessageHistory |
PermissionSet.MentionEveryone |
PermissionSet.UseExternalEmojis |
PermissionSet.AddReactions |
PermissionSet.Connect |
PermissionSet.Speak |
PermissionSet.MuteMembers |
PermissionSet.DeafenMembers |
PermissionSet.MoveMembers |
PermissionSet.UseVad |
PermissionSet.Stream |
PermissionSet.PrioritySpeaker;
private const PermissionSet NeedsSendMessages =
PermissionSet.MentionEveryone |
PermissionSet.SendTtsMessages |
PermissionSet.AttachFiles |
PermissionSet.EmbedLinks;
public static string ToPermissionString(this PermissionSet perms)
{
// TODO: clean string
return perms.ToString();
}
} }
public static PermissionSet EveryonePermissions(this Guild guild) =>
guild.Roles.FirstOrDefault(r => r.Id == guild.Id)?.Permissions ?? PermissionSet.Dm;
public static async Task<PermissionSet> EveryonePermissions(this IDiscordCache cache, Channel channel)
{
if (channel.Type == Channel.ChannelType.Dm)
return PermissionSet.Dm;
var defaultPermissions = (await cache.GetGuild(channel.GuildId!.Value)).EveryonePermissions();
var overwrite = channel.PermissionOverwrites?.FirstOrDefault(r => r.Id == channel.GuildId);
if (overwrite == null)
return defaultPermissions;
var perms = defaultPermissions;
perms &= ~overwrite.Deny;
perms |= overwrite.Allow;
return perms;
}
public static PermissionSet PermissionsFor(Guild guild, Channel channel, MessageCreateEvent msg) =>
PermissionsFor(guild, channel, msg.Author.Id, msg.Member);
public static PermissionSet PermissionsFor(Guild guild, Channel channel, ulong userId,
GuildMemberPartial? member)
{
if (channel.Type == Channel.ChannelType.Dm)
return PermissionSet.Dm;
if (member == null)
// this happens with system (Discord platform-owned) users - they're not actually in the guild, so there is no member object.
return EveryonePermissions(guild);
var perms = GuildPermissions(guild, userId, member.Roles);
perms = ApplyChannelOverwrites(perms, channel, userId, member.Roles);
if ((perms & PermissionSet.Administrator) == PermissionSet.Administrator)
return PermissionSet.All;
if ((perms & PermissionSet.ViewChannel) == 0)
perms &= ~NeedsViewChannel;
if ((perms & PermissionSet.SendMessages) == 0)
perms &= ~NeedsSendMessages;
return perms;
}
public static PermissionSet GuildPermissions(this Guild guild, ulong userId, ICollection<ulong> roleIds)
{
if (guild.OwnerId == userId)
return PermissionSet.All;
var perms = PermissionSet.None;
foreach (var role in guild.Roles)
if (role.Id == guild.Id || roleIds.Contains(role.Id))
perms |= role.Permissions;
if (perms.HasFlag(PermissionSet.Administrator))
return PermissionSet.All;
return perms;
}
public static PermissionSet ApplyChannelOverwrites(PermissionSet perms, Channel channel, ulong userId,
ICollection<ulong> roleIds)
{
if (channel.PermissionOverwrites == null)
return perms;
var everyoneDeny = PermissionSet.None;
var everyoneAllow = PermissionSet.None;
var roleDeny = PermissionSet.None;
var roleAllow = PermissionSet.None;
var userDeny = PermissionSet.None;
var userAllow = PermissionSet.None;
foreach (var overwrite in channel.PermissionOverwrites)
switch (overwrite.Type)
{
case Channel.OverwriteType.Role when overwrite.Id == channel.GuildId:
everyoneDeny |= overwrite.Deny;
everyoneAllow |= overwrite.Allow;
break;
case Channel.OverwriteType.Role when roleIds.Contains(overwrite.Id):
roleDeny |= overwrite.Deny;
roleAllow |= overwrite.Allow;
break;
case Channel.OverwriteType.Member when overwrite.Id == userId:
userDeny |= overwrite.Deny;
userAllow |= overwrite.Allow;
break;
}
perms &= ~everyoneDeny;
perms |= everyoneAllow;
perms &= ~roleDeny;
perms |= roleAllow;
perms &= ~userDeny;
perms |= userAllow;
return perms;
}
public static string ToPermissionString(this PermissionSet perms) =>
// TODO: clean string
perms.ToString();
} }

View File

@ -1,20 +1,17 @@
using System;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Extensions namespace Myriad.Extensions;
public static class SnowflakeExtensions
{ {
public static class SnowflakeExtensions public static readonly DateTimeOffset DiscordEpoch = new(2015, 1, 1, 0, 0, 0, TimeSpan.Zero);
{
public static readonly DateTimeOffset DiscordEpoch = new(2015, 1, 1, 0, 0, 0, TimeSpan.Zero);
public static DateTimeOffset SnowflakeToTimestamp(ulong snowflake) => public static DateTimeOffset SnowflakeToTimestamp(ulong snowflake) =>
DiscordEpoch + TimeSpan.FromMilliseconds(snowflake >> 22); DiscordEpoch + TimeSpan.FromMilliseconds(snowflake >> 22);
public static DateTimeOffset Timestamp(this Message msg) => SnowflakeToTimestamp(msg.Id); public static DateTimeOffset Timestamp(this Message msg) => SnowflakeToTimestamp(msg.Id);
public static DateTimeOffset Timestamp(this Channel channel) => SnowflakeToTimestamp(channel.Id); public static DateTimeOffset Timestamp(this Channel channel) => SnowflakeToTimestamp(channel.Id);
public static DateTimeOffset Timestamp(this Guild guild) => SnowflakeToTimestamp(guild.Id); public static DateTimeOffset Timestamp(this Guild guild) => SnowflakeToTimestamp(guild.Id);
public static DateTimeOffset Timestamp(this Webhook webhook) => SnowflakeToTimestamp(webhook.Id); public static DateTimeOffset Timestamp(this Webhook webhook) => SnowflakeToTimestamp(webhook.Id);
public static DateTimeOffset Timestamp(this User user) => SnowflakeToTimestamp(user.Id); public static DateTimeOffset Timestamp(this User user) => SnowflakeToTimestamp(user.Id);
}
} }

View File

@ -1,12 +1,11 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Extensions namespace Myriad.Extensions;
{
public static class UserExtensions
{
public static string Mention(this User user) => $"<@{user.Id}>";
public static string AvatarUrl(this User user, string? format = "png", int? size = 128) => public static class UserExtensions
$"https://cdn.discordapp.com/avatars/{user.Id}/{user.Avatar}.{format}?size={size}"; {
} public static string Mention(this User user) => $"<@{user.Id}>";
public static string AvatarUrl(this User user, string? format = "png", int? size = 128) =>
$"https://cdn.discordapp.com/avatars/{user.Id}/{user.Avatar}.{format}?size={size}";
} }

View File

@ -1,90 +1,84 @@
using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Myriad.Gateway.Limit; using Myriad.Gateway.Limit;
using Myriad.Types; using Myriad.Types;
using Serilog; using Serilog;
namespace Myriad.Gateway namespace Myriad.Gateway;
public class Cluster
{ {
public class Cluster private readonly GatewaySettings _gatewaySettings;
private readonly ILogger _logger;
private readonly ConcurrentDictionary<int, Shard> _shards = new();
private IGatewayRatelimiter? _ratelimiter;
public Cluster(GatewaySettings gatewaySettings, ILogger logger)
{ {
private readonly GatewaySettings _gatewaySettings; _gatewaySettings = gatewaySettings;
private readonly ILogger _logger; _logger = logger.ForContext<Cluster>();
private readonly ConcurrentDictionary<int, Shard> _shards = new(); }
private IGatewayRatelimiter? _ratelimiter;
public Cluster(GatewaySettings gatewaySettings, ILogger logger) public Func<Shard, IGatewayEvent, Task>? EventReceived { get; set; }
{
_gatewaySettings = gatewaySettings;
_logger = logger.ForContext<Cluster>();
}
public Func<Shard, IGatewayEvent, Task>? EventReceived { get; set; } public IReadOnlyDictionary<int, Shard> Shards => _shards;
public event Action<Shard>? ShardCreated; public event Action<Shard>? ShardCreated;
public IReadOnlyDictionary<int, Shard> Shards => _shards; public async Task Start(GatewayInfo.Bot info)
{
await Start(info.Url, 0, info.Shards - 1, info.Shards, info.SessionStartLimit.MaxConcurrency);
}
public async Task Start(GatewayInfo.Bot info) public async Task Start(string url, int shardMin, int shardMax, int shardTotal, int recommendedConcurrency)
{ {
await Start(info.Url, 0, info.Shards - 1, info.Shards, info.SessionStartLimit.MaxConcurrency); _ratelimiter = GetRateLimiter(recommendedConcurrency);
}
public async Task Start(string url, int shardMin, int shardMax, int shardTotal, int recommendedConcurrency) var shardCount = shardMax - shardMin + 1;
{ _logger.Information("Starting {ShardCount} of {ShardTotal} shards (#{ShardMin}-#{ShardMax}) at {Url}",
_ratelimiter = GetRateLimiter(recommendedConcurrency); shardCount, shardTotal, shardMin, shardMax, url);
for (var i = shardMin; i <= shardMax; i++)
CreateAndAddShard(url, new ShardInfo(i, shardTotal));
var shardCount = shardMax - shardMin + 1; await StartShards();
_logger.Information("Starting {ShardCount} of {ShardTotal} shards (#{ShardMin}-#{ShardMax}) at {Url}", }
shardCount, shardTotal, shardMin, shardMax, url);
for (var i = shardMin; i <= shardMax; i++)
CreateAndAddShard(url, new ShardInfo(i, shardTotal));
await StartShards(); private async Task StartShards()
} {
private async Task StartShards() _logger.Information("Connecting shards...");
{ foreach (var shard in _shards.Values)
_logger.Information("Connecting shards..."); await shard.Start();
foreach (var shard in _shards.Values) }
await shard.Start();
}
private void CreateAndAddShard(string url, ShardInfo shardInfo) private void CreateAndAddShard(string url, ShardInfo shardInfo)
{ {
var shard = new Shard(_gatewaySettings, shardInfo, _ratelimiter!, url, _logger); var shard = new Shard(_gatewaySettings, shardInfo, _ratelimiter!, url, _logger);
shard.OnEventReceived += evt => OnShardEventReceived(shard, evt); shard.OnEventReceived += evt => OnShardEventReceived(shard, evt);
_shards[shardInfo.ShardId] = shard; _shards[shardInfo.ShardId] = shard;
ShardCreated?.Invoke(shard); ShardCreated?.Invoke(shard);
} }
private async Task OnShardEventReceived(Shard shard, IGatewayEvent evt) private async Task OnShardEventReceived(Shard shard, IGatewayEvent evt)
{ {
if (EventReceived != null) if (EventReceived != null)
await EventReceived(shard, evt); await EventReceived(shard, evt);
} }
private int GetActualShardConcurrency(int recommendedConcurrency) private int GetActualShardConcurrency(int recommendedConcurrency)
{ {
if (_gatewaySettings.MaxShardConcurrency == null) if (_gatewaySettings.MaxShardConcurrency == null)
return recommendedConcurrency; return recommendedConcurrency;
return Math.Min(_gatewaySettings.MaxShardConcurrency.Value, recommendedConcurrency); return Math.Min(_gatewaySettings.MaxShardConcurrency.Value, recommendedConcurrency);
} }
private IGatewayRatelimiter GetRateLimiter(int recommendedConcurrency) private IGatewayRatelimiter GetRateLimiter(int recommendedConcurrency)
{ {
if (_gatewaySettings.GatewayQueueUrl != null) if (_gatewaySettings.GatewayQueueUrl != null)
{ return new TwilightGatewayRatelimiter(_logger, _gatewaySettings.GatewayQueueUrl);
return new TwilightGatewayRatelimiter(_logger, _gatewaySettings.GatewayQueueUrl);
}
var concurrency = GetActualShardConcurrency(recommendedConcurrency); var concurrency = GetActualShardConcurrency(recommendedConcurrency);
return new LocalGatewayRatelimiter(_logger, concurrency); return new LocalGatewayRatelimiter(_logger, concurrency);
}
} }
} }

View File

@ -1,6 +1,5 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record ChannelCreateEvent: Channel, IGatewayEvent; public record ChannelCreateEvent: Channel, IGatewayEvent;
}

View File

@ -1,6 +1,5 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record ChannelDeleteEvent: Channel, IGatewayEvent; public record ChannelDeleteEvent: Channel, IGatewayEvent;
}

View File

@ -1,6 +1,5 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record ChannelUpdateEvent: Channel, IGatewayEvent; public record ChannelUpdateEvent: Channel, IGatewayEvent;
}

View File

@ -1,11 +1,10 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
public record GuildCreateEvent: Guild, IGatewayEvent
{ {
public record GuildCreateEvent: Guild, IGatewayEvent public Channel[] Channels { get; init; }
{ public GuildMember[] Members { get; init; }
public Channel[] Channels { get; init; } public Channel[] Threads { get; init; }
public GuildMember[] Members { get; init; }
public Channel[] Threads { get; init; }
}
} }

View File

@ -1,4 +1,3 @@
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record GuildDeleteEvent(ulong Id, bool Unavailable): IGatewayEvent; public record GuildDeleteEvent(ulong Id, bool Unavailable): IGatewayEvent;
}

View File

@ -1,9 +1,8 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
public record GuildMemberAddEvent: GuildMember, IGatewayEvent
{ {
public record GuildMemberAddEvent: GuildMember, IGatewayEvent public ulong GuildId { get; init; }
{
public ulong GuildId { get; init; }
}
} }

View File

@ -1,10 +1,9 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
public class GuildMemberRemoveEvent: IGatewayEvent
{ {
public class GuildMemberRemoveEvent: IGatewayEvent public ulong GuildId { get; init; }
{ public User User { get; init; }
public ulong GuildId { get; init; }
public User User { get; init; }
}
} }

View File

@ -1,9 +1,8 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
public record GuildMemberUpdateEvent: GuildMember, IGatewayEvent
{ {
public record GuildMemberUpdateEvent: GuildMember, IGatewayEvent public ulong GuildId { get; init; }
{
public ulong GuildId { get; init; }
}
} }

View File

@ -1,6 +1,5 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record GuildRoleCreateEvent(ulong GuildId, Role Role): IGatewayEvent; public record GuildRoleCreateEvent(ulong GuildId, Role Role): IGatewayEvent;
}

View File

@ -1,4 +1,3 @@
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record GuildRoleDeleteEvent(ulong GuildId, ulong RoleId): IGatewayEvent; public record GuildRoleDeleteEvent(ulong GuildId, ulong RoleId): IGatewayEvent;
}

View File

@ -1,6 +1,5 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record GuildRoleUpdateEvent(ulong GuildId, Role Role): IGatewayEvent; public record GuildRoleUpdateEvent(ulong GuildId, Role Role): IGatewayEvent;
}

View File

@ -1,6 +1,5 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record GuildUpdateEvent: Guild, IGatewayEvent; public record GuildUpdateEvent: Guild, IGatewayEvent;
}

View File

@ -1,39 +1,35 @@
using System; namespace Myriad.Gateway;
using System.Collections.Generic;
namespace Myriad.Gateway public interface IGatewayEvent
{ {
public interface IGatewayEvent public static readonly Dictionary<string, Type> EventTypes = new()
{ {
public static readonly Dictionary<string, Type> EventTypes = new() { "READY", typeof(ReadyEvent) },
{ { "RESUMED", typeof(ResumedEvent) },
{ "READY", typeof(ReadyEvent) }, { "GUILD_CREATE", typeof(GuildCreateEvent) },
{ "RESUMED", typeof(ResumedEvent) }, { "GUILD_UPDATE", typeof(GuildUpdateEvent) },
{ "GUILD_CREATE", typeof(GuildCreateEvent) }, { "GUILD_DELETE", typeof(GuildDeleteEvent) },
{ "GUILD_UPDATE", typeof(GuildUpdateEvent) }, { "GUILD_MEMBER_ADD", typeof(GuildMemberAddEvent) },
{ "GUILD_DELETE", typeof(GuildDeleteEvent) }, { "GUILD_MEMBER_REMOVE", typeof(GuildMemberRemoveEvent) },
{ "GUILD_MEMBER_ADD", typeof(GuildMemberAddEvent) }, { "GUILD_MEMBER_UPDATE", typeof(GuildMemberUpdateEvent) },
{ "GUILD_MEMBER_REMOVE", typeof(GuildMemberRemoveEvent) }, { "GUILD_ROLE_CREATE", typeof(GuildRoleCreateEvent) },
{ "GUILD_MEMBER_UPDATE", typeof(GuildMemberUpdateEvent) }, { "GUILD_ROLE_UPDATE", typeof(GuildRoleUpdateEvent) },
{ "GUILD_ROLE_CREATE", typeof(GuildRoleCreateEvent) }, { "GUILD_ROLE_DELETE", typeof(GuildRoleDeleteEvent) },
{ "GUILD_ROLE_UPDATE", typeof(GuildRoleUpdateEvent) }, { "CHANNEL_CREATE", typeof(ChannelCreateEvent) },
{ "GUILD_ROLE_DELETE", typeof(GuildRoleDeleteEvent) }, { "CHANNEL_UPDATE", typeof(ChannelUpdateEvent) },
{ "CHANNEL_CREATE", typeof(ChannelCreateEvent) }, { "CHANNEL_DELETE", typeof(ChannelDeleteEvent) },
{ "CHANNEL_UPDATE", typeof(ChannelUpdateEvent) }, { "THREAD_CREATE", typeof(ThreadCreateEvent) },
{ "CHANNEL_DELETE", typeof(ChannelDeleteEvent) }, { "THREAD_UPDATE", typeof(ThreadUpdateEvent) },
{ "THREAD_CREATE", typeof(ThreadCreateEvent) }, { "THREAD_DELETE", typeof(ThreadDeleteEvent) },
{ "THREAD_UPDATE", typeof(ThreadUpdateEvent) }, { "THREAD_LIST_SYNC", typeof(ThreadListSyncEvent) },
{ "THREAD_DELETE", typeof(ThreadDeleteEvent) }, { "MESSAGE_CREATE", typeof(MessageCreateEvent) },
{ "THREAD_LIST_SYNC", typeof(ThreadListSyncEvent) }, { "MESSAGE_UPDATE", typeof(MessageUpdateEvent) },
{ "MESSAGE_CREATE", typeof(MessageCreateEvent) }, { "MESSAGE_DELETE", typeof(MessageDeleteEvent) },
{ "MESSAGE_UPDATE", typeof(MessageUpdateEvent) }, { "MESSAGE_DELETE_BULK", typeof(MessageDeleteBulkEvent) },
{ "MESSAGE_DELETE", typeof(MessageDeleteEvent) }, { "MESSAGE_REACTION_ADD", typeof(MessageReactionAddEvent) },
{ "MESSAGE_DELETE_BULK", typeof(MessageDeleteBulkEvent) }, { "MESSAGE_REACTION_REMOVE", typeof(MessageReactionRemoveEvent) },
{ "MESSAGE_REACTION_ADD", typeof(MessageReactionAddEvent) }, { "MESSAGE_REACTION_REMOVE_ALL", typeof(MessageReactionRemoveAllEvent) },
{ "MESSAGE_REACTION_REMOVE", typeof(MessageReactionRemoveEvent) }, { "MESSAGE_REACTION_REMOVE_EMOJI", typeof(MessageReactionRemoveEmojiEvent) },
{ "MESSAGE_REACTION_REMOVE_ALL", typeof(MessageReactionRemoveAllEvent) }, { "INTERACTION_CREATE", typeof(InteractionCreateEvent) }
{ "MESSAGE_REACTION_REMOVE_EMOJI", typeof(MessageReactionRemoveEmojiEvent) }, };
{ "INTERACTION_CREATE", typeof(InteractionCreateEvent) }
};
}
} }

View File

@ -1,6 +1,5 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record InteractionCreateEvent: Interaction, IGatewayEvent; public record InteractionCreateEvent: Interaction, IGatewayEvent;
}

View File

@ -1,9 +1,8 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
public record MessageCreateEvent: Message, IGatewayEvent
{ {
public record MessageCreateEvent: Message, IGatewayEvent public GuildMemberPartial? Member { get; init; }
{
public GuildMemberPartial? Member { get; init; }
}
} }

View File

@ -1,4 +1,3 @@
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record MessageDeleteBulkEvent(ulong[] Ids, ulong ChannelId, ulong? GuildId): IGatewayEvent; public record MessageDeleteBulkEvent(ulong[] Ids, ulong ChannelId, ulong? GuildId): IGatewayEvent;
}

View File

@ -1,4 +1,3 @@
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record MessageDeleteEvent(ulong Id, ulong ChannelId, ulong? GuildId): IGatewayEvent; public record MessageDeleteEvent(ulong Id, ulong ChannelId, ulong? GuildId): IGatewayEvent;
}

View File

@ -1,8 +1,7 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record MessageReactionAddEvent(ulong UserId, ulong ChannelId, ulong MessageId, ulong? GuildId, public record MessageReactionAddEvent(ulong UserId, ulong ChannelId, ulong MessageId, ulong? GuildId,
GuildMember? Member, GuildMember? Member,
Emoji Emoji): IGatewayEvent; Emoji Emoji): IGatewayEvent;
}

View File

@ -1,4 +1,3 @@
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record MessageReactionRemoveAllEvent(ulong ChannelId, ulong MessageId, ulong? GuildId): IGatewayEvent; public record MessageReactionRemoveAllEvent(ulong ChannelId, ulong MessageId, ulong? GuildId): IGatewayEvent;
}

View File

@ -1,7 +1,6 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record MessageReactionRemoveEmojiEvent public record MessageReactionRemoveEmojiEvent
(ulong ChannelId, ulong MessageId, ulong? GuildId, Emoji Emoji): IGatewayEvent; (ulong ChannelId, ulong MessageId, ulong? GuildId, Emoji Emoji): IGatewayEvent;
}

View File

@ -1,7 +1,6 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record MessageReactionRemoveEvent public record MessageReactionRemoveEvent
(ulong UserId, ulong ChannelId, ulong MessageId, ulong? GuildId, Emoji Emoji): IGatewayEvent; (ulong UserId, ulong ChannelId, ulong MessageId, ulong? GuildId, Emoji Emoji): IGatewayEvent;
}

View File

@ -1,15 +1,15 @@
using Myriad.Types; using Myriad.Types;
using Myriad.Utils; using Myriad.Utils;
namespace Myriad.Gateway namespace Myriad.Gateway;
public record MessageUpdateEvent(ulong Id, ulong ChannelId): IGatewayEvent
{ {
public record MessageUpdateEvent(ulong Id, ulong ChannelId): IGatewayEvent public Optional<string?> Content { get; init; }
{ public Optional<User> Author { get; init; }
public Optional<string?> Content { get; init; } public Optional<GuildMemberPartial> Member { get; init; }
public Optional<User> Author { get; init; } public Optional<Message.Attachment[]> Attachments { get; init; }
public Optional<GuildMemberPartial> Member { get; init; }
public Optional<Message.Attachment[]> Attachments { get; init; } public Optional<ulong?> GuildId { get; init; }
public Optional<ulong?> GuildId { get; init; } // TODO: lots of partials
// TODO: lots of partials
}
} }

View File

@ -2,14 +2,13 @@ using System.Text.Json.Serialization;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
public record ReadyEvent: IGatewayEvent
{ {
public record ReadyEvent: IGatewayEvent [JsonPropertyName("v")] public int Version { get; init; }
{ public User User { get; init; }
[JsonPropertyName("v")] public int Version { get; init; } public string SessionId { get; init; }
public User User { get; init; } public ShardInfo? Shard { get; init; }
public string SessionId { get; init; } public ApplicationPartial Application { get; init; }
public ShardInfo? Shard { get; init; }
public ApplicationPartial Application { get; init; }
}
} }

View File

@ -1,4 +1,3 @@
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record ResumedEvent: IGatewayEvent; public record ResumedEvent: IGatewayEvent;
}

View File

@ -1,6 +1,5 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record ThreadCreateEvent: Channel, IGatewayEvent; public record ThreadCreateEvent: Channel, IGatewayEvent;
}

View File

@ -1,12 +1,11 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
public record ThreadDeleteEvent: IGatewayEvent
{ {
public record ThreadDeleteEvent: IGatewayEvent public ulong Id { get; init; }
{ public ulong? GuildId { get; init; }
public ulong Id { get; init; } public ulong? ParentId { get; init; }
public ulong? GuildId { get; init; } public Channel.ChannelType Type { get; init; }
public ulong? ParentId { get; init; }
public Channel.ChannelType Type { get; init; }
}
} }

View File

@ -1,11 +1,10 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
public record ThreadListSyncEvent: IGatewayEvent
{ {
public record ThreadListSyncEvent: IGatewayEvent public ulong GuildId { get; init; }
{ public ulong[]? ChannelIds { get; init; }
public ulong GuildId { get; init; } public Channel[] Threads { get; init; }
public ulong[]? ChannelIds { get; init; }
public Channel[] Threads { get; init; }
}
} }

View File

@ -1,6 +1,5 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record ThreadUpdateEvent: Channel, IGatewayEvent; public record ThreadUpdateEvent: Channel, IGatewayEvent;
}

View File

@ -1,35 +1,32 @@
using System; namespace Myriad.Gateway;
namespace Myriad.Gateway // TODO: unused?
public class GatewayCloseException: Exception
{ {
// TODO: unused? public GatewayCloseException(int closeCode, string closeReason) : base($"{closeCode}: {closeReason}")
public class GatewayCloseException: Exception
{ {
public GatewayCloseException(int closeCode, string closeReason) : base($"{closeCode}: {closeReason}") CloseCode = closeCode;
{ CloseReason = closeReason;
CloseCode = closeCode;
CloseReason = closeReason;
}
public int CloseCode { get; }
public string CloseReason { get; }
} }
public class GatewayCloseCode public int CloseCode { get; }
{ public string CloseReason { get; }
public const int UnknownError = 4000; }
public const int UnknownOpcode = 4001;
public const int DecodeError = 4002; public class GatewayCloseCode
public const int NotAuthenticated = 4003; {
public const int AuthenticationFailed = 4004; public const int UnknownError = 4000;
public const int AlreadyAuthenticated = 4005; public const int UnknownOpcode = 4001;
public const int InvalidSeq = 4007; public const int DecodeError = 4002;
public const int RateLimited = 4008; public const int NotAuthenticated = 4003;
public const int SessionTimedOut = 4009; public const int AuthenticationFailed = 4004;
public const int InvalidShard = 4010; public const int AlreadyAuthenticated = 4005;
public const int ShardingRequired = 4011; public const int InvalidSeq = 4007;
public const int InvalidApiVersion = 4012; public const int RateLimited = 4008;
public const int InvalidIntent = 4013; public const int SessionTimedOut = 4009;
public const int DisallowedIntent = 4014; public const int InvalidShard = 4010;
} public const int ShardingRequired = 4011;
public const int InvalidApiVersion = 4012;
public const int InvalidIntent = 4013;
public const int DisallowedIntent = 4014;
} }

View File

@ -1,24 +1,21 @@
using System; namespace Myriad.Gateway;
namespace Myriad.Gateway [Flags]
public enum GatewayIntent
{ {
[Flags] Guilds = 1 << 0,
public enum GatewayIntent GuildMembers = 1 << 1,
{ GuildBans = 1 << 2,
Guilds = 1 << 0, GuildEmojis = 1 << 3,
GuildMembers = 1 << 1, GuildIntegrations = 1 << 4,
GuildBans = 1 << 2, GuildWebhooks = 1 << 5,
GuildEmojis = 1 << 3, GuildInvites = 1 << 6,
GuildIntegrations = 1 << 4, GuildVoiceStates = 1 << 7,
GuildWebhooks = 1 << 5, GuildPresences = 1 << 8,
GuildInvites = 1 << 6, GuildMessages = 1 << 9,
GuildVoiceStates = 1 << 7, GuildMessageReactions = 1 << 10,
GuildPresences = 1 << 8, GuildMessageTyping = 1 << 11,
GuildMessages = 1 << 9, DirectMessages = 1 << 12,
GuildMessageReactions = 1 << 10, DirectMessageReactions = 1 << 13,
GuildMessageTyping = 1 << 11, DirectMessageTyping = 1 << 14
DirectMessages = 1 << 12,
DirectMessageReactions = 1 << 13,
DirectMessageTyping = 1 << 14
}
} }

View File

@ -1,33 +1,32 @@
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
namespace Myriad.Gateway namespace Myriad.Gateway;
public record GatewayPacket
{ {
public record GatewayPacket [JsonPropertyName("op")] public GatewayOpcode Opcode { get; init; }
{ [JsonPropertyName("d")] public object? Payload { get; init; }
[JsonPropertyName("op")] public GatewayOpcode Opcode { get; init; }
[JsonPropertyName("d")] public object? Payload { get; init; }
[JsonPropertyName("s")] [JsonPropertyName("s")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? Sequence { get; init; } public int? Sequence { get; init; }
[JsonPropertyName("t")] [JsonPropertyName("t")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? EventType { get; init; } public string? EventType { get; init; }
} }
public enum GatewayOpcode public enum GatewayOpcode
{ {
Dispatch = 0, Dispatch = 0,
Heartbeat = 1, Heartbeat = 1,
Identify = 2, Identify = 2,
PresenceUpdate = 3, PresenceUpdate = 3,
VoiceStateUpdate = 4, VoiceStateUpdate = 4,
Resume = 6, Resume = 6,
Reconnect = 7, Reconnect = 7,
RequestGuildMembers = 8, RequestGuildMembers = 8,
InvalidSession = 9, InvalidSession = 9,
Hello = 10, Hello = 10,
HeartbeatAck = 11 HeartbeatAck = 11
}
} }

View File

@ -1,10 +1,9 @@
namespace Myriad.Gateway namespace Myriad.Gateway;
public record GatewaySettings
{ {
public record GatewaySettings public string Token { get; init; }
{ public GatewayIntent Intents { get; init; }
public string Token { get; init; } public int? MaxShardConcurrency { get; init; }
public GatewayIntent Intents { get; init; } public string? GatewayQueueUrl { get; init; }
public int? MaxShardConcurrency { get; init; }
public string? GatewayQueueUrl { get; init; }
}
} }

View File

@ -1,9 +1,6 @@
using System.Threading.Tasks; namespace Myriad.Gateway.Limit;
namespace Myriad.Gateway.Limit public interface IGatewayRatelimiter
{ {
public interface IGatewayRatelimiter public Task Identify(int shard);
{
public Task Identify(int shard);
}
} }

View File

@ -1,73 +1,70 @@
using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Threading.Tasks;
using Serilog; using Serilog;
namespace Myriad.Gateway.Limit namespace Myriad.Gateway.Limit;
public class LocalGatewayRatelimiter: IGatewayRatelimiter
{ {
public class LocalGatewayRatelimiter: IGatewayRatelimiter // docs specify 5 seconds, but we're actually throttling connections, not identify, so we need a bit of leeway
private static readonly TimeSpan BucketLength = TimeSpan.FromSeconds(6);
private readonly ConcurrentDictionary<int, ConcurrentQueue<TaskCompletionSource>> _buckets = new();
private readonly ILogger _logger;
private readonly int _maxConcurrency;
private Task? _refillTask;
public LocalGatewayRatelimiter(ILogger logger, int maxConcurrency)
{ {
// docs specify 5 seconds, but we're actually throttling connections, not identify, so we need a bit of leeway _logger = logger.ForContext<LocalGatewayRatelimiter>();
private static readonly TimeSpan BucketLength = TimeSpan.FromSeconds(6); _maxConcurrency = maxConcurrency;
}
private readonly ConcurrentDictionary<int, ConcurrentQueue<TaskCompletionSource>> _buckets = new(); public Task Identify(int shard)
private readonly int _maxConcurrency; {
var bucket = shard % _maxConcurrency;
var queue = _buckets.GetOrAdd(bucket, _ => new ConcurrentQueue<TaskCompletionSource>());
var tcs = new TaskCompletionSource();
queue.Enqueue(tcs);
private Task? _refillTask; ScheduleRefill();
private readonly ILogger _logger;
public LocalGatewayRatelimiter(ILogger logger, int maxConcurrency) return tcs.Task;
}
private void ScheduleRefill()
{
if (_refillTask != null && !_refillTask.IsCompleted)
return;
_refillTask?.Dispose();
_refillTask = RefillTask();
}
private async Task RefillTask()
{
await Task.Delay(TimeSpan.FromMilliseconds(250));
while (true)
{ {
_logger = logger.ForContext<LocalGatewayRatelimiter>(); var isClear = true;
_maxConcurrency = maxConcurrency; foreach (var (bucket, queue) in _buckets)
} {
if (!queue.TryDequeue(out var tcs))
continue;
public Task Identify(int shard) _logger.Debug(
{ "Allowing identify for bucket {BucketId} through ({QueueLength} left in bucket queue)",
var bucket = shard % _maxConcurrency; bucket, queue.Count);
var queue = _buckets.GetOrAdd(bucket, _ => new ConcurrentQueue<TaskCompletionSource>()); tcs.SetResult();
var tcs = new TaskCompletionSource(); isClear = false;
queue.Enqueue(tcs); }
ScheduleRefill(); if (isClear)
return tcs.Task;
}
private void ScheduleRefill()
{
if (_refillTask != null && !_refillTask.IsCompleted)
return; return;
_refillTask?.Dispose(); await Task.Delay(BucketLength);
_refillTask = RefillTask();
}
private async Task RefillTask()
{
await Task.Delay(TimeSpan.FromMilliseconds(250));
while (true)
{
var isClear = true;
foreach (var (bucket, queue) in _buckets)
{
if (!queue.TryDequeue(out var tcs))
continue;
_logger.Debug(
"Allowing identify for bucket {BucketId} through ({QueueLength} left in bucket queue)",
bucket, queue.Count);
tcs.SetResult();
isClear = false;
}
if (isClear)
return;
await Task.Delay(BucketLength);
}
} }
} }
} }

View File

@ -1,41 +1,30 @@
using System;
using System.Net.Http;
using System.Threading.Tasks;
using Serilog; using Serilog;
namespace Myriad.Gateway.Limit namespace Myriad.Gateway.Limit;
public class TwilightGatewayRatelimiter: IGatewayRatelimiter
{ {
public class TwilightGatewayRatelimiter: IGatewayRatelimiter private readonly HttpClient _httpClient = new() { Timeout = TimeSpan.FromSeconds(60) };
private readonly ILogger _logger;
private readonly string _url;
public TwilightGatewayRatelimiter(ILogger logger, string url)
{ {
private readonly string _url; _url = url;
private readonly ILogger _logger; _logger = logger.ForContext<TwilightGatewayRatelimiter>();
private readonly HttpClient _httpClient = new() }
{
Timeout = TimeSpan.FromSeconds(60)
};
public TwilightGatewayRatelimiter(ILogger logger, string url) public async Task Identify(int shard)
{ {
_url = url; while (true)
_logger = logger.ForContext<TwilightGatewayRatelimiter>(); try
}
public async Task Identify(int shard)
{
while (true)
{ {
try _logger.Information("Shard {ShardId}: Requesting identify at gateway queue {GatewayQueueUrl}",
{ shard, _url);
_logger.Information("Shard {ShardId}: Requesting identify at gateway queue {GatewayQueueUrl}", await _httpClient.GetAsync(_url);
shard, _url); return;
await _httpClient.GetAsync(_url);
return;
}
catch (TimeoutException)
{
}
} }
} catch (TimeoutException) { }
} }
} }

View File

@ -1,4 +1,3 @@
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record GatewayHello(int HeartbeatInterval); public record GatewayHello(int HeartbeatInterval);
}

View File

@ -1,28 +1,27 @@
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
namespace Myriad.Gateway namespace Myriad.Gateway;
public record GatewayIdentify
{ {
public record GatewayIdentify public string Token { get; init; }
public ConnectionProperties Properties { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public bool? Compress { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? LargeThreshold { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public ShardInfo? Shard { get; init; }
public GatewayIntent Intents { get; init; }
public record ConnectionProperties
{ {
public string Token { get; init; } [JsonPropertyName("$os")] public string Os { get; init; }
public ConnectionProperties Properties { get; init; } [JsonPropertyName("$browser")] public string Browser { get; init; }
[JsonPropertyName("$device")] public string Device { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public bool? Compress { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? LargeThreshold { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public ShardInfo? Shard { get; init; }
public GatewayIntent Intents { get; init; }
public record ConnectionProperties
{
[JsonPropertyName("$os")] public string Os { get; init; }
[JsonPropertyName("$browser")] public string Browser { get; init; }
[JsonPropertyName("$device")] public string Device { get; init; }
}
} }
} }

View File

@ -1,4 +1,3 @@
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record GatewayResume(string Token, string SessionId, int Seq); public record GatewayResume(string Token, string SessionId, int Seq);
}

View File

@ -3,23 +3,22 @@ using System.Text.Json.Serialization;
using Myriad.Serialization; using Myriad.Serialization;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record GatewayStatusUpdate
{
[JsonConverter(typeof(JsonSnakeCaseStringEnumConverter))]
public enum UserStatus
{
Online,
Dnd,
Idle,
Invisible,
Offline
}
public ulong? Since { get; init; } public record GatewayStatusUpdate
public ActivityPartial[]? Activities { get; init; } {
public UserStatus Status { get; init; } [JsonConverter(typeof(JsonSnakeCaseStringEnumConverter))]
public bool Afk { get; init; } public enum UserStatus
{
Online,
Dnd,
Idle,
Invisible,
Offline
} }
public ulong? Since { get; init; }
public ActivityPartial[]? Activities { get; init; }
public UserStatus Status { get; init; }
public bool Afk { get; init; }
} }

View File

@ -1,7 +1,5 @@
using System;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Text.Json; using System.Text.Json;
using System.Threading.Tasks;
using Myriad.Gateway.Limit; using Myriad.Gateway.Limit;
using Myriad.Gateway.State; using Myriad.Gateway.State;
@ -11,214 +9,203 @@ using Myriad.Types;
using Serilog; using Serilog;
using Serilog.Context; using Serilog.Context;
namespace Myriad.Gateway namespace Myriad.Gateway;
public class Shard
{ {
public class Shard private const string LibraryName = "Myriad (for PluralKit)";
private readonly GatewaySettings _settings;
private readonly ShardInfo _info;
private readonly IGatewayRatelimiter _ratelimiter;
private readonly string _url;
private readonly ILogger _logger;
private readonly ShardStateManager _stateManager;
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly ShardConnection _conn;
public int ShardId => _info.ShardId;
public ShardState State => _stateManager.State;
public TimeSpan? Latency => _stateManager.Latency;
public User? User => _stateManager.User;
public ApplicationPartial? Application => _stateManager.Application;
// TODO: I wanna get rid of these or move them at some point
public event Func<IGatewayEvent, Task>? OnEventReceived;
public event Action<TimeSpan>? HeartbeatReceived;
public event Action? SocketOpened;
public event Action? Resumed;
public event Action? Ready;
public event Action<WebSocketCloseStatus?, string?>? SocketClosed;
private TimeSpan _reconnectDelay = TimeSpan.Zero;
private Task? _worker;
public Shard(GatewaySettings settings, ShardInfo info, IGatewayRatelimiter ratelimiter, string url, ILogger logger)
{ {
private const string LibraryName = "Myriad (for PluralKit)"; _jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
private readonly GatewaySettings _settings; _settings = settings;
private readonly ShardInfo _info; _info = info;
private readonly IGatewayRatelimiter _ratelimiter; _ratelimiter = ratelimiter;
private readonly string _url; _url = url;
private readonly ILogger _logger; _logger = logger.ForContext<Shard>().ForContext("ShardId", info.ShardId);
private readonly ShardStateManager _stateManager; _stateManager = new ShardStateManager(info, _jsonSerializerOptions, logger)
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly ShardConnection _conn;
public int ShardId => _info.ShardId;
public ShardState State => _stateManager.State;
public TimeSpan? Latency => _stateManager.Latency;
public User? User => _stateManager.User;
public ApplicationPartial? Application => _stateManager.Application;
// TODO: I wanna get rid of these or move them at some point
public event Func<IGatewayEvent, Task>? OnEventReceived;
public event Action<TimeSpan>? HeartbeatReceived;
public event Action? SocketOpened;
public event Action? Resumed;
public event Action? Ready;
public event Action<WebSocketCloseStatus?, string?>? SocketClosed;
private TimeSpan _reconnectDelay = TimeSpan.Zero;
private Task? _worker;
public Shard(GatewaySettings settings, ShardInfo info, IGatewayRatelimiter ratelimiter, string url, ILogger logger)
{ {
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad(); HandleEvent = HandleEvent,
SendHeartbeat = SendHeartbeat,
_settings = settings; SendIdentify = SendIdentify,
_info = info; SendResume = SendResume,
_ratelimiter = ratelimiter; Connect = ConnectInner,
_url = url; Reconnect = Reconnect,
_logger = logger.ForContext<Shard>().ForContext("ShardId", info.ShardId); };
_stateManager = new ShardStateManager(info, _jsonSerializerOptions, logger) _stateManager.OnHeartbeatReceived += latency =>
{
HandleEvent = HandleEvent,
SendHeartbeat = SendHeartbeat,
SendIdentify = SendIdentify,
SendResume = SendResume,
Connect = ConnectInner,
Reconnect = Reconnect,
};
_stateManager.OnHeartbeatReceived += latency =>
{
HeartbeatReceived?.Invoke(latency);
};
_conn = new ShardConnection(_jsonSerializerOptions, _logger);
}
private async Task ShardLoop()
{ {
// may be superfluous but this adds shard id to ambient context which is nice HeartbeatReceived?.Invoke(latency);
using var _ = LogContext.PushProperty("ShardId", _info.ShardId); };
while (true) _conn = new ShardConnection(_jsonSerializerOptions, _logger);
}
private async Task ShardLoop()
{
// may be superfluous but this adds shard id to ambient context which is nice
using var _ = LogContext.PushProperty("ShardId", _info.ShardId);
while (true)
{
try
{ {
try await ConnectInner();
await HandleConnectionOpened();
while (_conn.State == WebSocketState.Open)
{ {
await ConnectInner(); var packet = await _conn.Read();
if (packet == null)
break;
await HandleConnectionOpened(); await _stateManager.HandlePacketReceived(packet);
while (_conn.State == WebSocketState.Open)
{
var packet = await _conn.Read();
if (packet == null)
break;
await _stateManager.HandlePacketReceived(packet);
}
await HandleConnectionClosed(_conn.CloseStatus, _conn.CloseStatusDescription);
_logger.Information("Shard {ShardId}: Reconnecting after delay {ReconnectDelay}",
_info.ShardId, _reconnectDelay);
if (_reconnectDelay > TimeSpan.Zero)
await Task.Delay(_reconnectDelay);
_reconnectDelay = TimeSpan.Zero;
} }
catch (Exception e)
{
_logger.Error(e, "Shard {ShardId}: Error in main shard loop, reconnecting in 5 seconds...", _info.ShardId);
// todo: exponential backoff here? this should never happen, ideally... await HandleConnectionClosed(_conn.CloseStatus, _conn.CloseStatusDescription);
await Task.Delay(TimeSpan.FromSeconds(5));
} _logger.Information("Shard {ShardId}: Reconnecting after delay {ReconnectDelay}",
_info.ShardId, _reconnectDelay);
if (_reconnectDelay > TimeSpan.Zero)
await Task.Delay(_reconnectDelay);
_reconnectDelay = TimeSpan.Zero;
} }
} catch (Exception e)
public async Task Start()
{
if (_worker == null)
_worker = ShardLoop();
// Ideally we'd stagger the startups so we don't smash the websocket but that's difficult with the
// identify rate limiter so this is the best we can do rn, maybe?
await Task.Delay(200);
}
public async Task UpdateStatus(GatewayStatusUpdate payload)
{
await _conn.Send(new GatewayPacket
{ {
Opcode = GatewayOpcode.PresenceUpdate, _logger.Error(e, "Shard {ShardId}: Error in main shard loop, reconnecting in 5 seconds...", _info.ShardId);
Payload = payload
});
}
private async Task ConnectInner() // todo: exponential backoff here? this should never happen, ideally...
{ await Task.Delay(TimeSpan.FromSeconds(5));
while (true)
{
await _ratelimiter.Identify(_info.ShardId);
_logger.Information("Shard {ShardId}: Connecting to WebSocket", _info.ShardId);
try
{
await _conn.Connect(_url, default);
break;
}
catch (WebSocketException e)
{
_logger.Error(e, "Shard {ShardId}: Error connecting to WebSocket, retrying in 5 seconds...", _info.ShardId);
await Task.Delay(TimeSpan.FromSeconds(5));
}
} }
} }
}
private async Task DisconnectInner(WebSocketCloseStatus closeStatus)
{ public async Task Start()
await _conn.Disconnect(closeStatus, null); {
} if (_worker == null)
_worker = ShardLoop();
private async Task SendIdentify()
{ // Ideally we'd stagger the startups so we don't smash the websocket but that's difficult with the
await _conn.Send(new GatewayPacket // identify rate limiter so this is the best we can do rn, maybe?
{ await Task.Delay(200);
Opcode = GatewayOpcode.Identify, }
Payload = new GatewayIdentify
{ public async Task UpdateStatus(GatewayStatusUpdate payload)
Compress = false, => await _conn.Send(new GatewayPacket
Intents = _settings.Intents, {
Properties = new GatewayIdentify.ConnectionProperties Opcode = GatewayOpcode.PresenceUpdate,
{ Payload = payload
Browser = LibraryName, });
Device = LibraryName,
Os = Environment.OSVersion.ToString() private async Task ConnectInner()
}, {
Shard = _info, while (true)
Token = _settings.Token, {
LargeThreshold = 50 await _ratelimiter.Identify(_info.ShardId);
}
}); _logger.Information("Shard {ShardId}: Connecting to WebSocket", _info.ShardId);
} try
{
private async Task SendResume((string SessionId, int? LastSeq) arg) await _conn.Connect(_url, default);
{ break;
await _conn.Send(new GatewayPacket }
{ catch (WebSocketException e)
Opcode = GatewayOpcode.Resume, {
Payload = new GatewayResume(_settings.Token, arg.SessionId, arg.LastSeq ?? 0) _logger.Error(e, "Shard {ShardId}: Error connecting to WebSocket, retrying in 5 seconds...", _info.ShardId);
}); await Task.Delay(TimeSpan.FromSeconds(5));
} }
}
private async Task SendHeartbeat(int? lastSeq) }
{
await _conn.Send(new GatewayPacket { Opcode = GatewayOpcode.Heartbeat, Payload = lastSeq }); private Task DisconnectInner(WebSocketCloseStatus closeStatus)
} => _conn.Disconnect(closeStatus, null);
private async Task Reconnect(WebSocketCloseStatus closeStatus, TimeSpan delay) private async Task SendIdentify()
{ => await _conn.Send(new GatewayPacket
_reconnectDelay = delay; {
await DisconnectInner(closeStatus); Opcode = GatewayOpcode.Identify,
} Payload = new GatewayIdentify
{
private async Task HandleEvent(IGatewayEvent arg) Compress = false,
{ Intents = _settings.Intents,
if (arg is ReadyEvent) Properties = new GatewayIdentify.ConnectionProperties
Ready?.Invoke(); {
if (arg is ResumedEvent) Browser = LibraryName,
Resumed?.Invoke(); Device = LibraryName,
Os = Environment.OSVersion.ToString()
await (OnEventReceived?.Invoke(arg) ?? Task.CompletedTask); },
} Shard = _info,
Token = _settings.Token,
private async Task HandleConnectionOpened() LargeThreshold = 50
{ }
_logger.Information("Shard {ShardId}: Connection opened", _info.ShardId); });
await _stateManager.HandleConnectionOpened();
SocketOpened?.Invoke(); private async Task SendResume((string SessionId, int? LastSeq) arg)
} => await _conn.Send(new GatewayPacket
{
private async Task HandleConnectionClosed(WebSocketCloseStatus? closeStatus, string? description) Opcode = GatewayOpcode.Resume,
{ Payload = new GatewayResume(_settings.Token, arg.SessionId, arg.LastSeq ?? 0)
_logger.Information("Shard {ShardId}: Connection closed ({CloseStatus}/{Description})", });
_info.ShardId, closeStatus, description ?? "<null>");
await _stateManager.HandleConnectionClosed(); private async Task SendHeartbeat(int? lastSeq)
SocketClosed?.Invoke(closeStatus, description); => await _conn.Send(new GatewayPacket { Opcode = GatewayOpcode.Heartbeat, Payload = lastSeq });
}
private async Task Reconnect(WebSocketCloseStatus closeStatus, TimeSpan delay)
{
_reconnectDelay = delay;
await DisconnectInner(closeStatus);
}
private async Task HandleEvent(IGatewayEvent arg)
{
if (arg is ReadyEvent)
Ready?.Invoke();
if (arg is ResumedEvent)
Resumed?.Invoke();
await (OnEventReceived?.Invoke(arg) ?? Task.CompletedTask);
}
private async Task HandleConnectionOpened()
{
_logger.Information("Shard {ShardId}: Connection opened", _info.ShardId);
await _stateManager.HandleConnectionOpened();
SocketOpened?.Invoke();
}
private async Task HandleConnectionClosed(WebSocketCloseStatus? closeStatus, string? description)
{
_logger.Information("Shard {ShardId}: Connection closed ({CloseStatus}/{Description})",
_info.ShardId, closeStatus, description ?? "<null>");
await _stateManager.HandleConnectionClosed();
SocketClosed?.Invoke(closeStatus, description);
} }
} }

View File

@ -1,122 +1,115 @@
using System;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Text.Json; using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Serilog; using Serilog;
namespace Myriad.Gateway namespace Myriad.Gateway;
public class ShardConnection: IAsyncDisposable
{ {
public class ShardConnection: IAsyncDisposable private readonly ILogger _logger;
private readonly ShardPacketSerializer _serializer;
private ClientWebSocket? _client;
public ShardConnection(JsonSerializerOptions jsonSerializerOptions, ILogger logger)
{ {
private ClientWebSocket? _client; _logger = logger.ForContext<ShardConnection>();
private readonly ILogger _logger; _serializer = new ShardPacketSerializer(jsonSerializerOptions);
private readonly ShardPacketSerializer _serializer; }
public WebSocketState State => _client?.State ?? WebSocketState.Closed; public WebSocketState State => _client?.State ?? WebSocketState.Closed;
public WebSocketCloseStatus? CloseStatus => _client?.CloseStatus; public WebSocketCloseStatus? CloseStatus => _client?.CloseStatus;
public string? CloseStatusDescription => _client?.CloseStatusDescription; public string? CloseStatusDescription => _client?.CloseStatusDescription;
public ShardConnection(JsonSerializerOptions jsonSerializerOptions, ILogger logger) public async ValueTask DisposeAsync()
{
await CloseInner(WebSocketCloseStatus.NormalClosure, null);
_client?.Dispose();
}
public async Task Connect(string url, CancellationToken ct)
{
_client?.Dispose();
_client = new ClientWebSocket();
await _client.ConnectAsync(GetConnectionUri(url), ct);
}
public async Task Disconnect(WebSocketCloseStatus closeStatus, string? reason)
{
await CloseInner(closeStatus, reason);
}
public async Task Send(GatewayPacket packet)
{
// from `ManagedWebSocket.s_validSendStates`
if (_client is not { State: WebSocketState.Open or WebSocketState.CloseReceived })
return;
try
{ {
_logger = logger.ForContext<ShardConnection>(); await _serializer.WritePacket(_client, packet);
_serializer = new(jsonSerializerOptions);
} }
catch (Exception e)
public async Task Connect(string url, CancellationToken ct)
{ {
_client?.Dispose(); _logger.Error(e, "Error sending WebSocket message");
_client = new ClientWebSocket();
await _client.ConnectAsync(GetConnectionUri(url), ct);
} }
}
public async Task Disconnect(WebSocketCloseStatus closeStatus, string? reason) public async Task<GatewayPacket?> Read()
{ {
await CloseInner(closeStatus, reason); // from `ManagedWebSocket.s_validReceiveStates`
} if (_client is not { State: WebSocketState.Open or WebSocketState.CloseSent })
public async Task Send(GatewayPacket packet)
{
// from `ManagedWebSocket.s_validSendStates`
if (_client is not { State: WebSocketState.Open or WebSocketState.CloseReceived })
return;
try
{
await _serializer.WritePacket(_client, packet);
}
catch (Exception e)
{
_logger.Error(e, "Error sending WebSocket message");
}
}
public async ValueTask DisposeAsync()
{
await CloseInner(WebSocketCloseStatus.NormalClosure, null);
_client?.Dispose();
}
public async Task<GatewayPacket?> Read()
{
// from `ManagedWebSocket.s_validReceiveStates`
if (_client is not { State: WebSocketState.Open or WebSocketState.CloseSent })
return null;
try
{
var (_, packet) = await _serializer.ReadPacket(_client);
return packet;
}
catch (Exception e)
{
_logger.Error(e, "Error reading from WebSocket");
// force close so we can "reset"
await CloseInner(WebSocketCloseStatus.NormalClosure, null);
}
return null; return null;
try
{
var (_, packet) = await _serializer.ReadPacket(_client);
return packet;
}
catch (Exception e)
{
_logger.Error(e, "Error reading from WebSocket");
// force close so we can "reset"
await CloseInner(WebSocketCloseStatus.NormalClosure, null);
} }
private Uri GetConnectionUri(string baseUri) => new UriBuilder(baseUri) return null;
}
private Uri GetConnectionUri(string baseUri) => new UriBuilder(baseUri) { Query = "v=9&encoding=json" }.Uri;
private async Task CloseInner(WebSocketCloseStatus closeStatus, string? description)
{
if (_client == null)
return;
var client = _client;
_client = null;
// from `ManagedWebSocket.s_validCloseStates`
if (client.State is WebSocketState.Open or WebSocketState.CloseReceived or WebSocketState.CloseSent)
{ {
Query = "v=9&encoding=json" // Close with timeout, mostly to work around https://github.com/dotnet/runtime/issues/51590
}.Uri; var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2));
private async Task CloseInner(WebSocketCloseStatus closeStatus, string? description)
{
if (_client == null)
return;
var client = _client;
_client = null;
// from `ManagedWebSocket.s_validCloseStates`
if (client.State is WebSocketState.Open or WebSocketState.CloseReceived or WebSocketState.CloseSent)
{
// Close with timeout, mostly to work around https://github.com/dotnet/runtime/issues/51590
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2));
try
{
await client.CloseAsync(closeStatus, description, cts.Token);
}
catch (Exception e)
{
_logger.Error(e, "Error closing WebSocket connection");
}
}
// This shouldn't need to be wrapped in a try/catch but doing it anyway :/
try try
{ {
client.Dispose(); await client.CloseAsync(closeStatus, description, cts.Token);
} }
catch (Exception e) catch (Exception e)
{ {
_logger.Error(e, "Error disposing WebSocket connection"); _logger.Error(e, "Error closing WebSocket connection");
} }
} }
// This shouldn't need to be wrapped in a try/catch but doing it anyway :/
try
{
client.Dispose();
}
catch (Exception e)
{
_logger.Error(e, "Error disposing WebSocket connection");
}
} }
} }

View File

@ -1,4 +1,3 @@
namespace Myriad.Gateway namespace Myriad.Gateway;
{
public record ShardInfo(int ShardId, int NumShards); public record ShardInfo(int ShardId, int NumShards);
}

View File

@ -1,70 +1,68 @@
using System;
using System.Buffers; using System.Buffers;
using System.IO;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Text.Json; using System.Text.Json;
using System.Threading.Tasks;
namespace Myriad.Gateway namespace Myriad.Gateway;
public class ShardPacketSerializer
{ {
public class ShardPacketSerializer private const int BufferSize = 64 * 1024;
private readonly JsonSerializerOptions _jsonSerializerOptions;
public ShardPacketSerializer(JsonSerializerOptions jsonSerializerOptions)
{ {
private const int BufferSize = 64 * 1024; _jsonSerializerOptions = jsonSerializerOptions;
}
private readonly JsonSerializerOptions _jsonSerializerOptions; public async ValueTask<(WebSocketMessageType type, GatewayPacket? packet)> ReadPacket(ClientWebSocket socket)
{
using var buf = MemoryPool<byte>.Shared.Rent(BufferSize);
public ShardPacketSerializer(JsonSerializerOptions jsonSerializerOptions) var res = await socket.ReceiveAsync(buf.Memory, default);
if (res.MessageType == WebSocketMessageType.Close)
return (res.MessageType, null);
if (res.EndOfMessage)
// Entire packet fits within one buffer, deserialize directly
return DeserializeSingleBuffer(buf, res);
// Otherwise copy to stream buffer and deserialize from there
return await DeserializeMultipleBuffer(socket, buf, res);
}
public async Task WritePacket(ClientWebSocket socket, GatewayPacket packet)
{
var bytes = JsonSerializer.SerializeToUtf8Bytes(packet, _jsonSerializerOptions);
await socket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, default);
}
private async Task<(WebSocketMessageType type, GatewayPacket packet)> DeserializeMultipleBuffer(
ClientWebSocket socket, IMemoryOwner<byte> buf, ValueWebSocketReceiveResult res)
{
await using var stream = new MemoryStream(BufferSize * 4);
stream.Write(buf.Memory.Span.Slice(0, res.Count));
while (!res.EndOfMessage)
{ {
_jsonSerializerOptions = jsonSerializerOptions; res = await socket.ReceiveAsync(buf.Memory, default);
}
public async ValueTask<(WebSocketMessageType type, GatewayPacket? packet)> ReadPacket(ClientWebSocket socket)
{
using var buf = MemoryPool<byte>.Shared.Rent(BufferSize);
var res = await socket.ReceiveAsync(buf.Memory, default);
if (res.MessageType == WebSocketMessageType.Close)
return (res.MessageType, null);
if (res.EndOfMessage)
// Entire packet fits within one buffer, deserialize directly
return DeserializeSingleBuffer(buf, res);
// Otherwise copy to stream buffer and deserialize from there
return await DeserializeMultipleBuffer(socket, buf, res);
}
public async Task WritePacket(ClientWebSocket socket, GatewayPacket packet)
{
var bytes = JsonSerializer.SerializeToUtf8Bytes(packet, _jsonSerializerOptions);
await socket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, default);
}
private async Task<(WebSocketMessageType type, GatewayPacket packet)> DeserializeMultipleBuffer(ClientWebSocket socket, IMemoryOwner<byte> buf, ValueWebSocketReceiveResult res)
{
await using var stream = new MemoryStream(BufferSize * 4);
stream.Write(buf.Memory.Span.Slice(0, res.Count)); stream.Write(buf.Memory.Span.Slice(0, res.Count));
while (!res.EndOfMessage)
{
res = await socket.ReceiveAsync(buf.Memory, default);
stream.Write(buf.Memory.Span.Slice(0, res.Count));
}
return DeserializeObject(res, stream.GetBuffer().AsSpan(0, (int)stream.Length));
} }
private (WebSocketMessageType type, GatewayPacket packet) DeserializeSingleBuffer( return DeserializeObject(res, stream.GetBuffer().AsSpan(0, (int)stream.Length));
IMemoryOwner<byte> buf, ValueWebSocketReceiveResult res) }
{
var span = buf.Memory.Span.Slice(0, res.Count);
return DeserializeObject(res, span);
}
private (WebSocketMessageType type, GatewayPacket packet) DeserializeObject(ValueWebSocketReceiveResult res, Span<byte> span) private (WebSocketMessageType type, GatewayPacket packet) DeserializeSingleBuffer(
{ IMemoryOwner<byte> buf, ValueWebSocketReceiveResult res)
var packet = JsonSerializer.Deserialize<GatewayPacket>(span, _jsonSerializerOptions)!; {
return (res.MessageType, packet); var span = buf.Memory.Span.Slice(0, res.Count);
} return DeserializeObject(res, span);
}
private (WebSocketMessageType type, GatewayPacket packet) DeserializeObject(
ValueWebSocketReceiveResult res, Span<byte> span)
{
var packet = JsonSerializer.Deserialize<GatewayPacket>(span, _jsonSerializerOptions)!;
return (res.MessageType, packet);
} }
} }

View File

@ -1,63 +1,58 @@
using System; namespace Myriad.Gateway.State;
using System.Threading;
using System.Threading.Tasks;
namespace Myriad.Gateway.State public class HeartbeatWorker: IAsyncDisposable
{ {
public class HeartbeatWorker: IAsyncDisposable private Task? _worker;
private CancellationTokenSource? _workerCts;
public TimeSpan? CurrentHeartbeatInterval { get; private set; }
public async ValueTask DisposeAsync()
{ {
private Task? _worker; await Stop();
private CancellationTokenSource? _workerCts; }
public TimeSpan? CurrentHeartbeatInterval { get; private set; } public async ValueTask Start(TimeSpan heartbeatInterval, Func<Task> callback)
{
public async ValueTask Start(TimeSpan heartbeatInterval, Func<Task> callback) if (_worker != null)
{
if (_worker != null)
await Stop();
CurrentHeartbeatInterval = heartbeatInterval;
_workerCts = new CancellationTokenSource();
_worker = Worker(heartbeatInterval, callback, _workerCts.Token);
}
public async ValueTask Stop()
{
if (_worker == null)
return;
_workerCts?.Cancel();
try
{
await _worker;
}
catch (TaskCanceledException) { }
_worker?.Dispose();
_workerCts?.Dispose();
_worker = null;
CurrentHeartbeatInterval = null;
}
private async Task Worker(TimeSpan heartbeatInterval, Func<Task> callback, CancellationToken ct)
{
var initialDelay = GetInitialHeartbeatDelay(heartbeatInterval);
await Task.Delay(initialDelay, ct);
while (!ct.IsCancellationRequested)
{
await callback();
await Task.Delay(heartbeatInterval, ct);
}
}
private static TimeSpan GetInitialHeartbeatDelay(TimeSpan heartbeatInterval) =>
// Docs specify `heartbeat_interval * random.random()` but we'll add a lil buffer :)
heartbeatInterval * (new Random().NextDouble() * 0.9 + 0.05);
public async ValueTask DisposeAsync()
{
await Stop(); await Stop();
CurrentHeartbeatInterval = heartbeatInterval;
_workerCts = new CancellationTokenSource();
_worker = Worker(heartbeatInterval, callback, _workerCts.Token);
}
public async ValueTask Stop()
{
if (_worker == null)
return;
_workerCts?.Cancel();
try
{
await _worker;
}
catch (TaskCanceledException) { }
_worker?.Dispose();
_workerCts?.Dispose();
_worker = null;
CurrentHeartbeatInterval = null;
}
private async Task Worker(TimeSpan heartbeatInterval, Func<Task> callback, CancellationToken ct)
{
var initialDelay = GetInitialHeartbeatDelay(heartbeatInterval);
await Task.Delay(initialDelay, ct);
while (!ct.IsCancellationRequested)
{
await callback();
await Task.Delay(heartbeatInterval, ct);
} }
} }
private static TimeSpan GetInitialHeartbeatDelay(TimeSpan heartbeatInterval) =>
// Docs specify `heartbeat_interval * random.random()` but we'll add a lil buffer :)
heartbeatInterval * (new Random().NextDouble() * 0.9 + 0.05);
} }

View File

@ -1,11 +1,10 @@
namespace Myriad.Gateway.State namespace Myriad.Gateway.State;
public enum ShardState
{ {
public enum ShardState Disconnected,
{ Handshaking,
Disconnected, Identifying,
Handshaking, Connected,
Identifying, Reconnecting
Connected,
Reconnecting
}
} }

View File

@ -1,246 +1,246 @@
using System;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Text.Json; using System.Text.Json;
using System.Threading.Tasks;
using Myriad.Gateway.State; using Myriad.Gateway.State;
using Myriad.Types; using Myriad.Types;
using Serilog; using Serilog;
namespace Myriad.Gateway namespace Myriad.Gateway;
public class ShardStateManager
{ {
public class ShardStateManager private readonly HeartbeatWorker _heartbeatWorker = new();
private readonly ShardInfo _info;
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly ILogger _logger;
private bool _hasReceivedHeartbeatAck;
private DateTimeOffset? _lastHeartbeatSent;
private int? _lastSeq;
private TimeSpan? _latency;
private string? _sessionId;
public ShardStateManager(ShardInfo info, JsonSerializerOptions jsonSerializerOptions, ILogger logger)
{ {
private readonly HeartbeatWorker _heartbeatWorker = new(); _info = info;
private readonly ILogger _logger; _jsonSerializerOptions = jsonSerializerOptions;
_logger = logger.ForContext<ShardStateManager>();
}
private readonly ShardInfo _info; public ShardState State { get; private set; } = ShardState.Disconnected;
private readonly JsonSerializerOptions _jsonSerializerOptions;
private ShardState _state = ShardState.Disconnected;
private DateTimeOffset? _lastHeartbeatSent; public TimeSpan? Latency => _latency;
private TimeSpan? _latency; public User? User { get; private set; }
private bool _hasReceivedHeartbeatAck; public ApplicationPartial? Application { get; private set; }
private string? _sessionId; public Func<Task> SendIdentify { get; init; }
private int? _lastSeq; public Func<(string SessionId, int? LastSeq), Task> SendResume { get; init; }
public Func<int?, Task> SendHeartbeat { get; init; }
public Func<WebSocketCloseStatus, TimeSpan, Task> Reconnect { get; init; }
public Func<Task> Connect { get; init; }
public Func<IGatewayEvent, Task> HandleEvent { get; init; }
public ShardState State => _state; public event Action<TimeSpan> OnHeartbeatReceived;
public TimeSpan? Latency => _latency;
public User? User { get; private set; }
public ApplicationPartial? Application { get; private set; }
public Func<Task> SendIdentify { get; init; } public Task HandleConnectionOpened()
public Func<(string SessionId, int? LastSeq), Task> SendResume { get; init; } {
public Func<int?, Task> SendHeartbeat { get; init; } State = ShardState.Handshaking;
public Func<WebSocketCloseStatus, TimeSpan, Task> Reconnect { get; init; } return Task.CompletedTask;
public Func<Task> Connect { get; init; } }
public Func<IGatewayEvent, Task> HandleEvent { get; init; }
public event Action<TimeSpan> OnHeartbeatReceived; public async Task HandleConnectionClosed()
{
_latency = null;
await _heartbeatWorker.Stop();
}
public ShardStateManager(ShardInfo info, JsonSerializerOptions jsonSerializerOptions, ILogger logger) public async Task HandlePacketReceived(GatewayPacket packet)
{
switch (packet.Opcode)
{ {
_info = info; case GatewayOpcode.Hello:
_jsonSerializerOptions = jsonSerializerOptions; var hello = DeserializePayload<GatewayHello>(packet);
_logger = logger.ForContext<ShardStateManager>(); await HandleHello(hello);
} break;
public Task HandleConnectionOpened() case GatewayOpcode.Heartbeat:
{ await HandleHeartbeatRequest();
_state = ShardState.Handshaking; break;
return Task.CompletedTask;
}
public async Task HandleConnectionClosed() case GatewayOpcode.HeartbeatAck:
{ await HandleHeartbeatAck();
_latency = null; break;
await _heartbeatWorker.Stop();
}
public async Task HandlePacketReceived(GatewayPacket packet) case GatewayOpcode.Reconnect:
{ {
switch (packet.Opcode) await HandleReconnect();
{
case GatewayOpcode.Hello:
var hello = DeserializePayload<GatewayHello>(packet);
await HandleHello(hello);
break; break;
}
case GatewayOpcode.Heartbeat: case GatewayOpcode.InvalidSession:
await HandleHeartbeatRequest(); {
var canResume = DeserializePayload<bool>(packet);
await HandleInvalidSession(canResume);
break; break;
}
case GatewayOpcode.HeartbeatAck: case GatewayOpcode.Dispatch:
await HandleHeartbeatAck(); _lastSeq = packet.Sequence;
break;
case GatewayOpcode.Reconnect: var evt = DeserializeEvent(packet.EventType!, (JsonElement)packet.Payload!);
{ if (evt != null)
await HandleReconnect(); {
break; if (evt is ReadyEvent ready)
} await HandleReady(ready);
case GatewayOpcode.InvalidSession: if (evt is ResumedEvent)
{ await HandleResumed();
var canResume = DeserializePayload<bool>(packet);
await HandleInvalidSession(canResume);
break;
}
case GatewayOpcode.Dispatch: await HandleEvent(evt);
_lastSeq = packet.Sequence; }
var evt = DeserializeEvent(packet.EventType!, (JsonElement)packet.Payload!); break;
if (evt != null) }
{ }
if (evt is ReadyEvent ready)
await HandleReady(ready);
if (evt is ResumedEvent) private async Task HandleHello(GatewayHello hello)
await HandleResumed(); {
var interval = TimeSpan.FromMilliseconds(hello.HeartbeatInterval);
await HandleEvent(evt); _hasReceivedHeartbeatAck = true;
} await _heartbeatWorker.Start(interval, HandleHeartbeatTimer);
break; await IdentifyOrResume();
} }
private async Task IdentifyOrResume()
{
State = ShardState.Identifying;
if (_sessionId != null)
{
_logger.Information("Shard {ShardId}: Received Hello, attempting to resume (seq {LastSeq})",
_info.ShardId, _lastSeq);
await SendResume((_sessionId!, _lastSeq));
}
else
{
_logger.Information("Shard {ShardId}: Received Hello, identifying",
_info.ShardId);
await SendIdentify();
}
}
private Task HandleHeartbeatAck()
{
_hasReceivedHeartbeatAck = true;
_latency = DateTimeOffset.UtcNow - _lastHeartbeatSent;
OnHeartbeatReceived?.Invoke(_latency!.Value);
_logger.Debug("Shard {ShardId}: Received Heartbeat (latency {Latency:N2} ms)",
_info.ShardId, _latency?.TotalMilliseconds);
return Task.CompletedTask;
}
private async Task HandleInvalidSession(bool canResume)
{
if (!canResume)
{
_sessionId = null;
_lastSeq = null;
} }
private async Task HandleHello(GatewayHello hello) _logger.Information("Shard {ShardId}: Received Invalid Session (can resume? {CanResume})",
{ _info.ShardId, canResume);
var interval = TimeSpan.FromMilliseconds(hello.HeartbeatInterval);
_hasReceivedHeartbeatAck = true; var delay = TimeSpan.FromMilliseconds(new Random().Next(1000, 5000));
await _heartbeatWorker.Start(interval, HandleHeartbeatTimer); await DoReconnect(WebSocketCloseStatus.NormalClosure, delay);
await IdentifyOrResume(); }
private async Task HandleReconnect()
{
_logger.Information("Shard {ShardId}: Received Reconnect", _info.ShardId);
// close code 1000 kills the session, so can't reconnect
// we use 1005 (no error specified) instead
await DoReconnect(WebSocketCloseStatus.Empty, TimeSpan.FromSeconds(1));
}
private Task HandleReady(ReadyEvent ready)
{
_logger.Information("Shard {ShardId}: Received Ready", _info.ShardId);
_sessionId = ready.SessionId;
State = ShardState.Connected;
User = ready.User;
Application = ready.Application;
return Task.CompletedTask;
}
private Task HandleResumed()
{
_logger.Information("Shard {ShardId}: Received Resume", _info.ShardId);
State = ShardState.Connected;
return Task.CompletedTask;
}
private async Task HandleHeartbeatRequest()
{
await SendHeartbeatInternal();
}
private async Task SendHeartbeatInternal()
{
await SendHeartbeat(_lastSeq);
_lastHeartbeatSent = DateTimeOffset.UtcNow;
}
private async Task HandleHeartbeatTimer()
{
if (!_hasReceivedHeartbeatAck)
{
_logger.Warning("Shard {ShardId}: Heartbeat worker timed out", _info.ShardId);
await DoReconnect(WebSocketCloseStatus.ProtocolError, TimeSpan.Zero);
return;
} }
private async Task IdentifyOrResume() await SendHeartbeatInternal();
}
private async Task DoReconnect(WebSocketCloseStatus closeStatus, TimeSpan delay)
{
State = ShardState.Reconnecting;
await Reconnect(closeStatus, delay);
}
private T DeserializePayload<T>(GatewayPacket packet)
{
var packetPayload = (JsonElement)packet.Payload!;
return JsonSerializer.Deserialize<T>(packetPayload.GetRawText(), _jsonSerializerOptions)!;
}
private IGatewayEvent? DeserializeEvent(string eventType, JsonElement payload)
{
if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType))
{ {
_state = ShardState.Identifying; _logger.Debug("Shard {ShardId}: Received unknown event type {EventType}", _info.ShardId, eventType);
return null;
if (_sessionId != null)
{
_logger.Information("Shard {ShardId}: Received Hello, attempting to resume (seq {LastSeq})",
_info.ShardId, _lastSeq);
await SendResume((_sessionId!, _lastSeq));
}
else
{
_logger.Information("Shard {ShardId}: Received Hello, identifying",
_info.ShardId);
await SendIdentify();
}
} }
private Task HandleHeartbeatAck() try
{ {
_hasReceivedHeartbeatAck = true; _logger.Verbose("Shard {ShardId}: Deserializing {EventType} to {ClrType}", _info.ShardId, eventType,
_latency = DateTimeOffset.UtcNow - _lastHeartbeatSent; clrType);
OnHeartbeatReceived?.Invoke(_latency!.Value); return JsonSerializer.Deserialize(payload.GetRawText(), clrType, _jsonSerializerOptions)
_logger.Debug("Shard {ShardId}: Received Heartbeat (latency {Latency:N2} ms)", as IGatewayEvent;
_info.ShardId, _latency?.TotalMilliseconds);
return Task.CompletedTask;
} }
catch (JsonException e)
private async Task HandleInvalidSession(bool canResume)
{ {
if (!canResume) _logger.Error(e, "Shard {ShardId}: Error deserializing event {EventType} to {ClrType}", _info.ShardId,
{ eventType, clrType);
_sessionId = null; return null;
_lastSeq = null;
}
_logger.Information("Shard {ShardId}: Received Invalid Session (can resume? {CanResume})",
_info.ShardId, canResume);
var delay = TimeSpan.FromMilliseconds(new Random().Next(1000, 5000));
await DoReconnect(WebSocketCloseStatus.NormalClosure, delay);
}
private async Task HandleReconnect()
{
_logger.Information("Shard {ShardId}: Received Reconnect", _info.ShardId);
// close code 1000 kills the session, so can't reconnect
// we use 1005 (no error specified) instead
await DoReconnect(WebSocketCloseStatus.Empty, TimeSpan.FromSeconds(1));
}
private Task HandleReady(ReadyEvent ready)
{
_logger.Information("Shard {ShardId}: Received Ready", _info.ShardId);
_sessionId = ready.SessionId;
_state = ShardState.Connected;
User = ready.User;
Application = ready.Application;
return Task.CompletedTask;
}
private Task HandleResumed()
{
_logger.Information("Shard {ShardId}: Received Resume", _info.ShardId);
_state = ShardState.Connected;
return Task.CompletedTask;
}
private async Task HandleHeartbeatRequest()
{
await SendHeartbeatInternal();
}
private async Task SendHeartbeatInternal()
{
await SendHeartbeat(_lastSeq);
_lastHeartbeatSent = DateTimeOffset.UtcNow;
}
private async Task HandleHeartbeatTimer()
{
if (!_hasReceivedHeartbeatAck)
{
_logger.Warning("Shard {ShardId}: Heartbeat worker timed out", _info.ShardId);
await DoReconnect(WebSocketCloseStatus.ProtocolError, TimeSpan.Zero);
return;
}
await SendHeartbeatInternal();
}
private async Task DoReconnect(WebSocketCloseStatus closeStatus, TimeSpan delay)
{
_state = ShardState.Reconnecting;
await Reconnect(closeStatus, delay);
}
private T DeserializePayload<T>(GatewayPacket packet)
{
var packetPayload = (JsonElement)packet.Payload!;
return JsonSerializer.Deserialize<T>(packetPayload.GetRawText(), _jsonSerializerOptions)!;
}
private IGatewayEvent? DeserializeEvent(string eventType, JsonElement payload)
{
if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType))
{
_logger.Debug("Shard {ShardId}: Received unknown event type {EventType}", _info.ShardId, eventType);
return null;
}
try
{
_logger.Verbose("Shard {ShardId}: Deserializing {EventType} to {ClrType}", _info.ShardId, eventType, clrType);
return JsonSerializer.Deserialize(payload.GetRawText(), clrType, _jsonSerializerOptions)
as IGatewayEvent;
}
catch (JsonException e)
{
_logger.Error(e, "Shard {ShardId}: Error deserializing event {EventType} to {ClrType}", _info.ShardId, eventType, clrType);
return null;
}
} }
} }
} }

View File

@ -1,8 +1,9 @@
<Project Sdk="Microsoft.NET.Sdk"> <Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup> <PropertyGroup>
<TargetFramework>net5.0</TargetFramework> <TargetFramework>net6.0</TargetFramework>
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
<ImplicitUsings>enable</ImplicitUsings>
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>
@ -20,10 +21,10 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Polly" Version="7.2.1" /> <PackageReference Include="Polly" Version="7.2.1"/>
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" /> <PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1"/>
<PackageReference Include="Serilog" Version="2.10.0" /> <PackageReference Include="Serilog" Version="2.10.0"/>
<PackageReference Include="System.Linq.Async" Version="5.0.0" /> <PackageReference Include="System.Linq.Async" Version="5.0.0"/>
</ItemGroup> </ItemGroup>
</Project> </Project>

View File

@ -1,13 +1,9 @@
using System;
using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Net; using System.Net;
using System.Net.Http;
using System.Net.Http.Headers; using System.Net.Http.Headers;
using System.Net.Http.Json; using System.Net.Http.Json;
using System.Text.Json; using System.Text.Json;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using System.Threading.Tasks;
using Myriad.Rest.Exceptions; using Myriad.Rest.Exceptions;
using Myriad.Rest.Ratelimit; using Myriad.Rest.Ratelimit;
@ -19,305 +15,306 @@ using Polly;
using Serilog; using Serilog;
using Serilog.Context; using Serilog.Context;
namespace Myriad.Rest namespace Myriad.Rest;
public class BaseRestClient: IAsyncDisposable
{ {
public class BaseRestClient: IAsyncDisposable private readonly string _baseUrl;
private readonly Version _httpVersion = new(2, 0);
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly ILogger _logger;
private readonly Ratelimiter _ratelimiter;
private readonly AsyncPolicy<HttpResponseMessage> _retryPolicy;
public EventHandler<(string, int, long)> OnResponseEvent;
public BaseRestClient(string userAgent, string token, ILogger logger, string baseUrl)
{ {
private readonly Version _httpVersion = new(2, 0); _logger = logger.ForContext<BaseRestClient>();
private readonly JsonSerializerOptions _jsonSerializerOptions; _baseUrl = baseUrl;
private readonly ILogger _logger;
private readonly Ratelimiter _ratelimiter;
private readonly AsyncPolicy<HttpResponseMessage> _retryPolicy;
private readonly string _baseUrl;
public BaseRestClient(string userAgent, string token, ILogger logger, string baseUrl) if (!token.StartsWith("Bot "))
{ token = "Bot " + token;
_logger = logger.ForContext<BaseRestClient>();
_baseUrl = baseUrl;
if (!token.StartsWith("Bot ")) Client = new HttpClient();
token = "Bot " + token; Client.DefaultRequestHeaders.TryAddWithoutValidation("User-Agent", userAgent);
Client.DefaultRequestHeaders.TryAddWithoutValidation("Authorization", token);
Client = new HttpClient(); _jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
Client.DefaultRequestHeaders.TryAddWithoutValidation("User-Agent", userAgent);
Client.DefaultRequestHeaders.TryAddWithoutValidation("Authorization", token);
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad(); _ratelimiter = new Ratelimiter(logger);
var discordPolicy = new DiscordRateLimitPolicy(_ratelimiter);
_ratelimiter = new Ratelimiter(logger); // todo: why doesn't the timeout work? o.o
var discordPolicy = new DiscordRateLimitPolicy(_ratelimiter); var timeoutPolicy = Policy.TimeoutAsync<HttpResponseMessage>(TimeSpan.FromSeconds(10));
// todo: why doesn't the timeout work? o.o var waitPolicy = Policy
var timeoutPolicy = Policy.TimeoutAsync<HttpResponseMessage>(TimeSpan.FromSeconds(10)); .Handle<RatelimitBucketExhaustedException>()
.WaitAndRetryAsync(3,
(_, e, _) => ((RatelimitBucketExhaustedException)e).RetryAfter,
(_, _, _, _) => Task.CompletedTask)
.AsAsyncPolicy<HttpResponseMessage>();
var waitPolicy = Policy _retryPolicy = Policy.WrapAsync(timeoutPolicy, waitPolicy, discordPolicy);
.Handle<RatelimitBucketExhaustedException>() }
.WaitAndRetryAsync(3,
(_, e, _) => ((RatelimitBucketExhaustedException)e).RetryAfter,
(_, _, _, _) => Task.CompletedTask)
.AsAsyncPolicy<HttpResponseMessage>();
_retryPolicy = Policy.WrapAsync(timeoutPolicy, waitPolicy, discordPolicy); public HttpClient Client { get; }
}
public HttpClient Client { get; } public ValueTask DisposeAsync()
public EventHandler<(string, int, long)> OnResponseEvent; {
_ratelimiter.Dispose();
Client.Dispose();
return default;
}
public ValueTask DisposeAsync() public async Task<T?> Get<T>(string path, (string endpointName, ulong major) ratelimitParams) where T : class
{ {
_ratelimiter.Dispose(); using var response = await Send(() => new HttpRequestMessage(HttpMethod.Get, _baseUrl + path),
Client.Dispose(); ratelimitParams, true);
return default;
}
public async Task<T?> Get<T>(string path, (string endpointName, ulong major) ratelimitParams) where T : class
{
using var response = await Send(() => new HttpRequestMessage(HttpMethod.Get, _baseUrl + path),
ratelimitParams, true);
// GET-only special case: 404s are nulls and not exceptions
if (response.StatusCode == HttpStatusCode.NotFound)
return null;
return await ReadResponse<T>(response);
}
public async Task<T?> Post<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T : class
{
using var response = await Send(() =>
{
var request = new HttpRequestMessage(HttpMethod.Post, _baseUrl + path);
SetRequestJsonBody(request, body);
return request;
}, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> PostMultipart<T>(string path, (string endpointName, ulong major) ratelimitParams, object? payload, MultipartFile[]? files)
where T : class
{
using var response = await Send(() =>
{
var request = new HttpRequestMessage(HttpMethod.Post, _baseUrl + path);
SetRequestFormDataBody(request, payload, files);
return request;
}, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> Patch<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T : class
{
using var response = await Send(() =>
{
var request = new HttpRequestMessage(HttpMethod.Patch, _baseUrl + path);
SetRequestJsonBody(request, body);
return request;
}, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> Put<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T : class
{
using var response = await Send(() =>
{
var request = new HttpRequestMessage(HttpMethod.Put, _baseUrl + path);
SetRequestJsonBody(request, body);
return request;
}, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task Delete(string path, (string endpointName, ulong major) ratelimitParams)
{
using var _ = await Send(() => new HttpRequestMessage(HttpMethod.Delete, _baseUrl + path), ratelimitParams);
}
private void SetRequestJsonBody(HttpRequestMessage request, object? body)
{
if (body == null) return;
request.Content =
new ReadOnlyMemoryContent(JsonSerializer.SerializeToUtf8Bytes(body, _jsonSerializerOptions));
request.Content.Headers.ContentType = new MediaTypeHeaderValue("application/json");
}
private void SetRequestFormDataBody(HttpRequestMessage request, object? payload, MultipartFile[]? files)
{
var bodyJson = JsonSerializer.SerializeToUtf8Bytes(payload, _jsonSerializerOptions);
var mfd = new MultipartFormDataContent();
mfd.Add(new ByteArrayContent(bodyJson), "payload_json");
if (files != null)
{
for (var i = 0; i < files.Length; i++)
{
var (filename, stream, _) = files[i];
mfd.Add(new StreamContent(stream), $"files[{i}]", filename);
}
}
request.Content = mfd;
}
private async Task<T?> ReadResponse<T>(HttpResponseMessage response) where T : class
{
if (response.StatusCode == HttpStatusCode.NoContent)
return null;
return await response.Content.ReadFromJsonAsync<T>(_jsonSerializerOptions);
}
private async Task<HttpResponseMessage> Send(Func<HttpRequestMessage> createRequest,
(string endpointName, ulong major) ratelimitParams,
bool ignoreNotFound = false)
{
return await _retryPolicy.ExecuteAsync(async _ =>
{
using var __ = LogContext.PushProperty("EndpointName", ratelimitParams.endpointName);
var request = createRequest();
_logger.Debug("Request: {RequestMethod} {RequestPath}",
request.Method, CleanForLogging(request.RequestUri!));
request.Version = _httpVersion;
request.VersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
HttpResponseMessage response;
var stopwatch = new Stopwatch();
stopwatch.Start();
try
{
response = await Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
stopwatch.Stop();
}
catch (Exception exc)
{
_logger.Error(exc, "HTTP error: {RequestMethod} {RequestUrl}", request.Method, request.RequestUri);
// kill the running thread
// in PluralKit.Bot, this error is ignored in "IsOurProblem" (PluralKit.Bot/Utils/MiscUtils.cs)
throw;
}
_logger.Debug(
"Response: {RequestMethod} {RequestPath} -> {StatusCode} {ReasonPhrase} (in {ResponseDurationMs} ms)",
request.Method, CleanForLogging(request.RequestUri!), (int)response.StatusCode,
response.ReasonPhrase, stopwatch.ElapsedMilliseconds);
await HandleApiError(response, ignoreNotFound);
OnResponseEvent?.Invoke(null, (
GetEndpointMetricsName(response.RequestMessage!),
(int)response.StatusCode,
stopwatch.ElapsedTicks
));
return response;
},
new Dictionary<string, object>
{
{DiscordRateLimitPolicy.EndpointContextKey, ratelimitParams.endpointName},
{DiscordRateLimitPolicy.MajorContextKey, ratelimitParams.major}
});
}
private async ValueTask HandleApiError(HttpResponseMessage response, bool ignoreNotFound)
{
if (response.IsSuccessStatusCode)
return;
if (response.StatusCode == HttpStatusCode.NotFound && ignoreNotFound)
return;
var body = await response.Content.ReadAsStringAsync();
var apiError = TryParseApiError(body);
if (apiError != null)
_logger.Warning("Discord API error: {DiscordErrorCode} {DiscordErrorMessage}", apiError.Code, apiError.Message);
throw CreateDiscordException(response, body, apiError);
}
private DiscordRequestException CreateDiscordException(HttpResponseMessage response, string body, DiscordApiError? apiError)
{
return response.StatusCode switch
{
HttpStatusCode.BadRequest => new BadRequestException(response, body, apiError),
HttpStatusCode.Forbidden => new ForbiddenException(response, body, apiError),
HttpStatusCode.Unauthorized => new UnauthorizedException(response, body, apiError),
HttpStatusCode.NotFound => new NotFoundException(response, body, apiError),
HttpStatusCode.Conflict => new ConflictException(response, body, apiError),
HttpStatusCode.TooManyRequests => new TooManyRequestsException(response, body, apiError),
_ => new UnknownDiscordRequestException(response, body, apiError)
};
}
private DiscordApiError? TryParseApiError(string responseBody)
{
if (string.IsNullOrWhiteSpace(responseBody))
return null;
try
{
return JsonSerializer.Deserialize<DiscordApiError>(responseBody, _jsonSerializerOptions);
}
catch (JsonException e)
{
_logger.Verbose(e, "Error deserializing API error");
}
// GET-only special case: 404s are nulls and not exceptions
if (response.StatusCode == HttpStatusCode.NotFound)
return null; return null;
}
private string NormalizeRoutePath(string url) return await ReadResponse<T>(response);
}
public async Task<T?> Post<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T : class
{
using var response = await Send(() =>
{ {
url = Regex.Replace(url, @"/channels/\d+", "/channels/{channel_id}"); var request = new HttpRequestMessage(HttpMethod.Post, _baseUrl + path);
url = Regex.Replace(url, @"/messages/\d+", "/messages/{message_id}"); SetRequestJsonBody(request, body);
url = Regex.Replace(url, @"/members/\d+", "/members/{user_id}"); return request;
url = Regex.Replace(url, @"/webhooks/\d+/[^/]+", "/webhooks/{webhook_id}/{webhook_token}"); }, ratelimitParams);
url = Regex.Replace(url, @"/webhooks/\d+", "/webhooks/{webhook_id}"); return await ReadResponse<T>(response);
url = Regex.Replace(url, @"/users/\d+", "/users/{user_id}"); }
url = Regex.Replace(url, @"/bans/\d+", "/bans/{user_id}");
url = Regex.Replace(url, @"/roles/\d+", "/roles/{role_id}");
url = Regex.Replace(url, @"/pins/\d+", "/pins/{message_id}");
url = Regex.Replace(url, @"/emojis/\d+", "/emojis/{emoji_id}");
url = Regex.Replace(url, @"/guilds/\d+", "/guilds/{guild_id}");
url = Regex.Replace(url, @"/integrations/\d+", "/integrations/{integration_id}");
url = Regex.Replace(url, @"/permissions/\d+", "/permissions/{overwrite_id}");
url = Regex.Replace(url, @"/reactions/[^{/]+/\d+", "/reactions/{emoji}/{user_id}");
url = Regex.Replace(url, @"/reactions/[^{/]+", "/reactions/{emoji}");
url = Regex.Replace(url, @"/invites/[^{/]+", "/invites/{invite_code}");
url = Regex.Replace(url, @"/interactions/\d+/[^{/]+", "/interactions/{interaction_id}/{interaction_token}");
url = Regex.Replace(url, @"/interactions/\d+", "/interactions/{interaction_id}");
// catch-all for missed IDs public async Task<T?> PostMultipart<T>(string path, (string endpointName, ulong major) ratelimitParams,
url = Regex.Replace(url, @"\d{17,19}", "{snowflake}"); object? payload, MultipartFile[]? files)
where T : class
return url; {
} using var response = await Send(() =>
private string GetEndpointMetricsName(HttpRequestMessage req)
{ {
var localPath = Regex.Replace(req.RequestUri!.LocalPath, @"/api/v\d+", ""); var request = new HttpRequestMessage(HttpMethod.Post, _baseUrl + path);
var routePath = NormalizeRoutePath(localPath); SetRequestFormDataBody(request, payload, files);
return $"{req.Method} {routePath}"; return request;
} }, ratelimitParams);
return await ReadResponse<T>(response);
}
private string CleanForLogging(Uri uri) public async Task<T?> Patch<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T : class
{
using var response = await Send(() =>
{ {
var path = uri.ToString(); var request = new HttpRequestMessage(HttpMethod.Patch, _baseUrl + path);
SetRequestJsonBody(request, body);
return request;
}, ratelimitParams);
return await ReadResponse<T>(response);
}
// don't show tokens in logs public async Task<T?> Put<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
// todo: anything missing here? where T : class
path = Regex.Replace(path, @"/webhooks/(\d+)/[^/]+", "/webhooks/$1/:token"); {
path = Regex.Replace(path, @"/interactions/(\d+)/[^{/]+", "/interactions/$1/:token"); using var response = await Send(() =>
{
var request = new HttpRequestMessage(HttpMethod.Put, _baseUrl + path);
SetRequestJsonBody(request, body);
return request;
}, ratelimitParams);
return await ReadResponse<T>(response);
}
// remove base URL public async Task Delete(string path, (string endpointName, ulong major) ratelimitParams)
path = path.Substring(_baseUrl.Length); {
using var _ = await Send(() => new HttpRequestMessage(HttpMethod.Delete, _baseUrl + path), ratelimitParams);
}
return path; private void SetRequestJsonBody(HttpRequestMessage request, object? body)
{
if (body == null) return;
request.Content =
new ReadOnlyMemoryContent(JsonSerializer.SerializeToUtf8Bytes(body, _jsonSerializerOptions));
request.Content.Headers.ContentType = new MediaTypeHeaderValue("application/json");
}
private void SetRequestFormDataBody(HttpRequestMessage request, object? payload, MultipartFile[]? files)
{
var bodyJson = JsonSerializer.SerializeToUtf8Bytes(payload, _jsonSerializerOptions);
var mfd = new MultipartFormDataContent();
mfd.Add(new ByteArrayContent(bodyJson), "payload_json");
if (files != null)
for (var i = 0; i < files.Length; i++)
{
var (filename, stream, _) = files[i];
mfd.Add(new StreamContent(stream), $"files[{i}]", filename);
}
request.Content = mfd;
}
private async Task<T?> ReadResponse<T>(HttpResponseMessage response) where T : class
{
if (response.StatusCode == HttpStatusCode.NoContent)
return null;
return await response.Content.ReadFromJsonAsync<T>(_jsonSerializerOptions);
}
private async Task<HttpResponseMessage> Send(Func<HttpRequestMessage> createRequest,
(string endpointName, ulong major) ratelimitParams,
bool ignoreNotFound = false)
{
return await _retryPolicy.ExecuteAsync(async _ =>
{
using var __ = LogContext.PushProperty("EndpointName", ratelimitParams.endpointName);
var request = createRequest();
_logger.Debug("Request: {RequestMethod} {RequestPath}",
request.Method, CleanForLogging(request.RequestUri!));
request.Version = _httpVersion;
request.VersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
HttpResponseMessage response;
var stopwatch = new Stopwatch();
stopwatch.Start();
try
{
response = await Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
stopwatch.Stop();
}
catch (Exception exc)
{
_logger.Error(exc, "HTTP error: {RequestMethod} {RequestUrl}", request.Method,
request.RequestUri);
// kill the running thread
// in PluralKit.Bot, this error is ignored in "IsOurProblem" (PluralKit.Bot/Utils/MiscUtils.cs)
throw;
}
_logger.Debug(
"Response: {RequestMethod} {RequestPath} -> {StatusCode} {ReasonPhrase} (in {ResponseDurationMs} ms)",
request.Method, CleanForLogging(request.RequestUri!), (int)response.StatusCode,
response.ReasonPhrase, stopwatch.ElapsedMilliseconds);
await HandleApiError(response, ignoreNotFound);
OnResponseEvent?.Invoke(null, (
GetEndpointMetricsName(response.RequestMessage!),
(int)response.StatusCode,
stopwatch.ElapsedTicks
));
return response;
},
new Dictionary<string, object>
{
{DiscordRateLimitPolicy.EndpointContextKey, ratelimitParams.endpointName},
{DiscordRateLimitPolicy.MajorContextKey, ratelimitParams.major}
});
}
private async ValueTask HandleApiError(HttpResponseMessage response, bool ignoreNotFound)
{
if (response.IsSuccessStatusCode)
return;
if (response.StatusCode == HttpStatusCode.NotFound && ignoreNotFound)
return;
var body = await response.Content.ReadAsStringAsync();
var apiError = TryParseApiError(body);
if (apiError != null)
_logger.Warning("Discord API error: {DiscordErrorCode} {DiscordErrorMessage}", apiError.Code,
apiError.Message);
throw CreateDiscordException(response, body, apiError);
}
private DiscordRequestException CreateDiscordException(HttpResponseMessage response, string body,
DiscordApiError? apiError)
{
return response.StatusCode switch
{
HttpStatusCode.BadRequest => new BadRequestException(response, body, apiError),
HttpStatusCode.Forbidden => new ForbiddenException(response, body, apiError),
HttpStatusCode.Unauthorized => new UnauthorizedException(response, body, apiError),
HttpStatusCode.NotFound => new NotFoundException(response, body, apiError),
HttpStatusCode.Conflict => new ConflictException(response, body, apiError),
HttpStatusCode.TooManyRequests => new TooManyRequestsException(response, body, apiError),
_ => new UnknownDiscordRequestException(response, body, apiError)
};
}
private DiscordApiError? TryParseApiError(string responseBody)
{
if (string.IsNullOrWhiteSpace(responseBody))
return null;
try
{
return JsonSerializer.Deserialize<DiscordApiError>(responseBody, _jsonSerializerOptions);
} }
catch (JsonException e)
{
_logger.Verbose(e, "Error deserializing API error");
}
return null;
}
private string NormalizeRoutePath(string url)
{
url = Regex.Replace(url, @"/channels/\d+", "/channels/{channel_id}");
url = Regex.Replace(url, @"/messages/\d+", "/messages/{message_id}");
url = Regex.Replace(url, @"/members/\d+", "/members/{user_id}");
url = Regex.Replace(url, @"/webhooks/\d+/[^/]+", "/webhooks/{webhook_id}/{webhook_token}");
url = Regex.Replace(url, @"/webhooks/\d+", "/webhooks/{webhook_id}");
url = Regex.Replace(url, @"/users/\d+", "/users/{user_id}");
url = Regex.Replace(url, @"/bans/\d+", "/bans/{user_id}");
url = Regex.Replace(url, @"/roles/\d+", "/roles/{role_id}");
url = Regex.Replace(url, @"/pins/\d+", "/pins/{message_id}");
url = Regex.Replace(url, @"/emojis/\d+", "/emojis/{emoji_id}");
url = Regex.Replace(url, @"/guilds/\d+", "/guilds/{guild_id}");
url = Regex.Replace(url, @"/integrations/\d+", "/integrations/{integration_id}");
url = Regex.Replace(url, @"/permissions/\d+", "/permissions/{overwrite_id}");
url = Regex.Replace(url, @"/reactions/[^{/]+/\d+", "/reactions/{emoji}/{user_id}");
url = Regex.Replace(url, @"/reactions/[^{/]+", "/reactions/{emoji}");
url = Regex.Replace(url, @"/invites/[^{/]+", "/invites/{invite_code}");
url = Regex.Replace(url, @"/interactions/\d+/[^{/]+", "/interactions/{interaction_id}/{interaction_token}");
url = Regex.Replace(url, @"/interactions/\d+", "/interactions/{interaction_id}");
// catch-all for missed IDs
url = Regex.Replace(url, @"\d{17,19}", "{snowflake}");
return url;
}
private string GetEndpointMetricsName(HttpRequestMessage req)
{
var localPath = Regex.Replace(req.RequestUri!.LocalPath, @"/api/v\d+", "");
var routePath = NormalizeRoutePath(localPath);
return $"{req.Method} {routePath}";
}
private string CleanForLogging(Uri uri)
{
var path = uri.ToString();
// don't show tokens in logs
// todo: anything missing here?
path = Regex.Replace(path, @"/webhooks/(\d+)/[^/]+", "/webhooks/$1/:token");
path = Regex.Replace(path, @"/interactions/(\d+)/[^{/]+", "/interactions/$1/:token");
// remove base URL
path = path.Substring(_baseUrl.Length);
return path;
} }
} }

View File

@ -1,6 +1,4 @@
using System;
using System.Net; using System.Net;
using System.Threading.Tasks;
using Myriad.Rest.Types; using Myriad.Rest.Types;
using Myriad.Rest.Types.Requests; using Myriad.Rest.Types.Requests;
@ -8,143 +6,146 @@ using Myriad.Types;
using Serilog; using Serilog;
namespace Myriad.Rest namespace Myriad.Rest;
public class DiscordApiClient
{ {
public class DiscordApiClient public const string UserAgent = "DiscordBot (https://github.com/xSke/PluralKit/tree/main/Myriad/, v1)";
private const string DefaultApiBaseUrl = "https://discord.com/api/v9";
private readonly BaseRestClient _client;
public EventHandler<(string, int, long)> OnResponseEvent;
public DiscordApiClient(string token, ILogger logger, string? baseUrl = null)
{ {
public const string UserAgent = "DiscordBot (https://github.com/xSke/PluralKit/tree/main/Myriad/, v1)"; _client = new BaseRestClient(UserAgent, token, logger, baseUrl ?? DefaultApiBaseUrl);
private const string DefaultApiBaseUrl = "https://discord.com/api/v9"; _client.OnResponseEvent += (_, ev) => OnResponseEvent?.Invoke(null, ev);
private readonly BaseRestClient _client;
public DiscordApiClient(string token, ILogger logger, string? baseUrl = null)
{
_client = new BaseRestClient(UserAgent, token, logger, baseUrl ?? DefaultApiBaseUrl);
_client.OnResponseEvent += (_, ev) => OnResponseEvent?.Invoke(null, ev);
}
public EventHandler<(string, int, long)> OnResponseEvent;
public Task<GatewayInfo> GetGateway() =>
_client.Get<GatewayInfo>("/gateway", ("GetGateway", default))!;
public Task<GatewayInfo.Bot> GetGatewayBot() =>
_client.Get<GatewayInfo.Bot>("/gateway/bot", ("GetGatewayBot", default))!;
public Task<Channel?> GetChannel(ulong channelId) =>
_client.Get<Channel>($"/channels/{channelId}", ("GetChannel", channelId));
public Task<Message?> GetMessage(ulong channelId, ulong messageId) =>
_client.Get<Message>($"/channels/{channelId}/messages/{messageId}", ("GetMessage", channelId));
public Task<Guild?> GetGuild(ulong id) =>
_client.Get<Guild>($"/guilds/{id}", ("GetGuild", id));
public Task<Channel[]> GetGuildChannels(ulong id) =>
_client.Get<Channel[]>($"/guilds/{id}/channels", ("GetGuildChannels", id))!;
public Task<User?> GetUser(ulong id) =>
_client.Get<User>($"/users/{id}", ("GetUser", default));
public Task<GuildMember?> GetGuildMember(ulong guildId, ulong userId) =>
_client.Get<GuildMember>($"/guilds/{guildId}/members/{userId}",
("GetGuildMember", guildId));
public Task<Message> CreateMessage(ulong channelId, MessageRequest request, MultipartFile[]? files = null) =>
_client.PostMultipart<Message>($"/channels/{channelId}/messages", ("CreateMessage", channelId), request, files)!;
public Task<Message> EditMessage(ulong channelId, ulong messageId, MessageEditRequest request) =>
_client.Patch<Message>($"/channels/{channelId}/messages/{messageId}", ("EditMessage", channelId), request)!;
public Task DeleteMessage(ulong channelId, ulong messageId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}", ("DeleteMessage", channelId));
public Task DeleteMessage(Message message) =>
_client.Delete($"/channels/{message.ChannelId}/messages/{message.Id}", ("DeleteMessage", message.ChannelId));
public Task CreateReaction(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Put<object>($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/@me",
("CreateReaction", channelId), null);
public Task DeleteOwnReaction(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/@me",
("DeleteOwnReaction", channelId));
public Task DeleteUserReaction(ulong channelId, ulong messageId, Emoji emoji, ulong userId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/{userId}",
("DeleteUserReaction", channelId));
public Task DeleteAllReactions(ulong channelId, ulong messageId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions",
("DeleteAllReactions", channelId));
public Task DeleteAllReactionsForEmoji(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}",
("DeleteAllReactionsForEmoji", channelId));
public Task<ApplicationCommand> CreateGlobalApplicationCommand(ulong applicationId,
ApplicationCommandRequest request) =>
_client.Post<ApplicationCommand>($"/applications/{applicationId}/commands",
("CreateGlobalApplicationCommand", applicationId), request)!;
public Task<ApplicationCommand[]> GetGuildApplicationCommands(ulong applicationId, ulong guildId) =>
_client.Get<ApplicationCommand[]>($"/applications/{applicationId}/guilds/{guildId}/commands",
("GetGuildApplicationCommands", applicationId))!;
public Task<ApplicationCommand> CreateGuildApplicationCommand(ulong applicationId, ulong guildId,
ApplicationCommandRequest request) =>
_client.Post<ApplicationCommand>($"/applications/{applicationId}/guilds/{guildId}/commands",
("CreateGuildApplicationCommand", applicationId), request)!;
public Task<ApplicationCommand> EditGuildApplicationCommand(ulong applicationId, ulong guildId,
ApplicationCommandRequest request) =>
_client.Patch<ApplicationCommand>($"/applications/{applicationId}/guilds/{guildId}/commands",
("EditGuildApplicationCommand", applicationId), request)!;
public Task DeleteGuildApplicationCommand(ulong applicationId, ulong commandId) =>
_client.Delete($"/applications/{applicationId}/commands/{commandId}",
("DeleteGuildApplicationCommand", applicationId));
public Task CreateInteractionResponse(ulong interactionId, string token, InteractionResponse response) =>
_client.Post<object>($"/interactions/{interactionId}/{token}/callback",
("CreateInteractionResponse", interactionId), response);
public Task ModifyGuildMember(ulong guildId, ulong userId, ModifyGuildMemberRequest request) =>
_client.Patch<object>($"/guilds/{guildId}/members/{userId}",
("ModifyGuildMember", guildId), request);
public Task<Webhook> CreateWebhook(ulong channelId, CreateWebhookRequest request) =>
_client.Post<Webhook>($"/channels/{channelId}/webhooks", ("CreateWebhook", channelId), request)!;
public Task<Webhook> GetWebhook(ulong webhookId) =>
_client.Get<Webhook>($"/webhooks/{webhookId}/webhooks", ("GetWebhook", webhookId))!;
public Task<Webhook[]> GetChannelWebhooks(ulong channelId) =>
_client.Get<Webhook[]>($"/channels/{channelId}/webhooks", ("GetChannelWebhooks", channelId))!;
public Task<Message> ExecuteWebhook(ulong webhookId, string webhookToken, ExecuteWebhookRequest request,
MultipartFile[]? files = null, ulong? threadId = null)
{
var url = $"/webhooks/{webhookId}/{webhookToken}?wait=true";
if (threadId != null)
url += $"&thread_id={threadId}";
return _client.PostMultipart<Message>(url,
("ExecuteWebhook", webhookId), request, files)!;
}
public Task<Message> EditWebhookMessage(ulong webhookId, string webhookToken, ulong messageId,
WebhookMessageEditRequest request, ulong? threadId = null)
{
var url = $"/webhooks/{webhookId}/{webhookToken}/messages/{messageId}";
if (threadId != null)
url += $"?thread_id={threadId}";
return _client.Patch<Message>(url, ("EditWebhookMessage", webhookId), request)!;
}
public Task<Channel> CreateDm(ulong recipientId) =>
_client.Post<Channel>($"/users/@me/channels", ("CreateDM", default), new CreateDmRequest(recipientId))!;
private static string EncodeEmoji(Emoji emoji) =>
WebUtility.UrlEncode(emoji.Id != null ? $"{emoji.Name}:{emoji.Id}" : emoji.Name) ??
throw new ArgumentException("Could not encode emoji");
} }
public Task<GatewayInfo> GetGateway() =>
_client.Get<GatewayInfo>("/gateway", ("GetGateway", default))!;
public Task<GatewayInfo.Bot> GetGatewayBot() =>
_client.Get<GatewayInfo.Bot>("/gateway/bot", ("GetGatewayBot", default))!;
public Task<Channel?> GetChannel(ulong channelId) =>
_client.Get<Channel>($"/channels/{channelId}", ("GetChannel", channelId));
public Task<Message?> GetMessage(ulong channelId, ulong messageId) =>
_client.Get<Message>($"/channels/{channelId}/messages/{messageId}", ("GetMessage", channelId));
public Task<Guild?> GetGuild(ulong id) =>
_client.Get<Guild>($"/guilds/{id}", ("GetGuild", id));
public Task<Channel[]> GetGuildChannels(ulong id) =>
_client.Get<Channel[]>($"/guilds/{id}/channels", ("GetGuildChannels", id))!;
public Task<User?> GetUser(ulong id) =>
_client.Get<User>($"/users/{id}", ("GetUser", default));
public Task<GuildMember?> GetGuildMember(ulong guildId, ulong userId) =>
_client.Get<GuildMember>($"/guilds/{guildId}/members/{userId}",
("GetGuildMember", guildId));
public Task<Message> CreateMessage(ulong channelId, MessageRequest request, MultipartFile[]? files = null) =>
_client.PostMultipart<Message>($"/channels/{channelId}/messages", ("CreateMessage", channelId), request,
files)!;
public Task<Message> EditMessage(ulong channelId, ulong messageId, MessageEditRequest request) =>
_client.Patch<Message>($"/channels/{channelId}/messages/{messageId}", ("EditMessage", channelId), request)!;
public Task DeleteMessage(ulong channelId, ulong messageId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}", ("DeleteMessage", channelId));
public Task DeleteMessage(Message message) =>
_client.Delete($"/channels/{message.ChannelId}/messages/{message.Id}",
("DeleteMessage", message.ChannelId));
public Task CreateReaction(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Put<object>($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/@me",
("CreateReaction", channelId), null);
public Task DeleteOwnReaction(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/@me",
("DeleteOwnReaction", channelId));
public Task DeleteUserReaction(ulong channelId, ulong messageId, Emoji emoji, ulong userId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/{userId}",
("DeleteUserReaction", channelId));
public Task DeleteAllReactions(ulong channelId, ulong messageId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions",
("DeleteAllReactions", channelId));
public Task DeleteAllReactionsForEmoji(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}",
("DeleteAllReactionsForEmoji", channelId));
public Task<ApplicationCommand> CreateGlobalApplicationCommand(ulong applicationId,
ApplicationCommandRequest request) =>
_client.Post<ApplicationCommand>($"/applications/{applicationId}/commands",
("CreateGlobalApplicationCommand", applicationId), request)!;
public Task<ApplicationCommand[]> GetGuildApplicationCommands(ulong applicationId, ulong guildId) =>
_client.Get<ApplicationCommand[]>($"/applications/{applicationId}/guilds/{guildId}/commands",
("GetGuildApplicationCommands", applicationId))!;
public Task<ApplicationCommand> CreateGuildApplicationCommand(ulong applicationId, ulong guildId,
ApplicationCommandRequest request) =>
_client.Post<ApplicationCommand>($"/applications/{applicationId}/guilds/{guildId}/commands",
("CreateGuildApplicationCommand", applicationId), request)!;
public Task<ApplicationCommand> EditGuildApplicationCommand(ulong applicationId, ulong guildId,
ApplicationCommandRequest request) =>
_client.Patch<ApplicationCommand>($"/applications/{applicationId}/guilds/{guildId}/commands",
("EditGuildApplicationCommand", applicationId), request)!;
public Task DeleteGuildApplicationCommand(ulong applicationId, ulong commandId) =>
_client.Delete($"/applications/{applicationId}/commands/{commandId}",
("DeleteGuildApplicationCommand", applicationId));
public Task CreateInteractionResponse(ulong interactionId, string token, InteractionResponse response) =>
_client.Post<object>($"/interactions/{interactionId}/{token}/callback",
("CreateInteractionResponse", interactionId), response);
public Task ModifyGuildMember(ulong guildId, ulong userId, ModifyGuildMemberRequest request) =>
_client.Patch<object>($"/guilds/{guildId}/members/{userId}",
("ModifyGuildMember", guildId), request);
public Task<Webhook> CreateWebhook(ulong channelId, CreateWebhookRequest request) =>
_client.Post<Webhook>($"/channels/{channelId}/webhooks", ("CreateWebhook", channelId), request)!;
public Task<Webhook> GetWebhook(ulong webhookId) =>
_client.Get<Webhook>($"/webhooks/{webhookId}/webhooks", ("GetWebhook", webhookId))!;
public Task<Webhook[]> GetChannelWebhooks(ulong channelId) =>
_client.Get<Webhook[]>($"/channels/{channelId}/webhooks", ("GetChannelWebhooks", channelId))!;
public Task<Message> ExecuteWebhook(ulong webhookId, string webhookToken, ExecuteWebhookRequest request,
MultipartFile[]? files = null, ulong? threadId = null)
{
var url = $"/webhooks/{webhookId}/{webhookToken}?wait=true";
if (threadId != null)
url += $"&thread_id={threadId}";
return _client.PostMultipart<Message>(url,
("ExecuteWebhook", webhookId), request, files)!;
}
public Task<Message> EditWebhookMessage(ulong webhookId, string webhookToken, ulong messageId,
WebhookMessageEditRequest request, ulong? threadId = null)
{
var url = $"/webhooks/{webhookId}/{webhookToken}/messages/{messageId}";
if (threadId != null)
url += $"?thread_id={threadId}";
return _client.Patch<Message>(url, ("EditWebhookMessage", webhookId), request)!;
}
public Task<Channel> CreateDm(ulong recipientId) =>
_client.Post<Channel>("/users/@me/channels", ("CreateDM", default), new CreateDmRequest(recipientId))!;
private static string EncodeEmoji(Emoji emoji) =>
WebUtility.UrlEncode(emoji.Id != null ? $"{emoji.Name}:{emoji.Id}" : emoji.Name) ??
throw new ArgumentException("Could not encode emoji");
} }

View File

@ -1,9 +1,8 @@
using System.Text.Json; using System.Text.Json;
namespace Myriad.Rest namespace Myriad.Rest;
public record DiscordApiError(string Message, int Code)
{ {
public record DiscordApiError(string Message, int Code) public JsonElement? Errors { get; init; }
{
public JsonElement? Errors { get; init; }
}
} }

View File

@ -1,77 +1,75 @@
using System;
using System.Net; using System.Net;
using System.Net.Http;
namespace Myriad.Rest.Exceptions namespace Myriad.Rest.Exceptions;
public class DiscordRequestException: Exception
{ {
public class DiscordRequestException: Exception public DiscordRequestException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError)
{ {
public DiscordRequestException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) ResponseBody = responseBody;
{ Response = response;
ResponseBody = responseBody; ApiError = apiError;
Response = response;
ApiError = apiError;
}
public string ResponseBody { get; init; } = null!;
public HttpResponseMessage Response { get; init; } = null!;
public HttpStatusCode StatusCode => Response.StatusCode;
public int? ErrorCode => ApiError?.Code;
internal DiscordApiError? ApiError { get; init; }
public override string Message =>
(ApiError?.Message ?? Response.ReasonPhrase ?? "") + (FormError != null ? $": {FormError}" : "");
public string? FormError => ApiError?.Errors?.ToString();
} }
public class NotFoundException: DiscordRequestException public string ResponseBody { get; init; } = null!;
{ public HttpResponseMessage Response { get; init; } = null!;
public NotFoundException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) : base(
response, responseBody, apiError)
{ }
}
public class UnauthorizedException: DiscordRequestException public HttpStatusCode StatusCode => Response.StatusCode;
{ public int? ErrorCode => ApiError?.Code;
public UnauthorizedException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) : base(
response, responseBody, apiError)
{ }
}
public class ForbiddenException: DiscordRequestException internal DiscordApiError? ApiError { get; init; }
{
public ForbiddenException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) : base(
response, responseBody, apiError)
{ }
}
public class ConflictException: DiscordRequestException public override string Message =>
{ (ApiError?.Message ?? Response.ReasonPhrase ?? "") + (FormError != null ? $": {FormError}" : "");
public ConflictException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) : base(
response, responseBody, apiError)
{ }
}
public class BadRequestException: DiscordRequestException public string? FormError => ApiError?.Errors?.ToString();
{ }
public BadRequestException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) : base(
response, responseBody, apiError) public class NotFoundException: DiscordRequestException
{ } {
} public NotFoundException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) : base(
response, responseBody, apiError)
public class TooManyRequestsException: DiscordRequestException { }
{ }
public TooManyRequestsException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) :
base(response, responseBody, apiError) public class UnauthorizedException: DiscordRequestException
{ } {
} public UnauthorizedException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) :
base(
public class UnknownDiscordRequestException: DiscordRequestException response, responseBody, apiError)
{ { }
public UnknownDiscordRequestException(HttpResponseMessage response, string responseBody, }
DiscordApiError? apiError) : base(response, responseBody, apiError) { }
} public class ForbiddenException: DiscordRequestException
{
public ForbiddenException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) : base(
response, responseBody, apiError)
{ }
}
public class ConflictException: DiscordRequestException
{
public ConflictException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) : base(
response, responseBody, apiError)
{ }
}
public class BadRequestException: DiscordRequestException
{
public BadRequestException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) : base(
response, responseBody, apiError)
{ }
}
public class TooManyRequestsException: DiscordRequestException
{
public TooManyRequestsException(HttpResponseMessage response, string responseBody, DiscordApiError? apiError) :
base(response, responseBody, apiError)
{ }
}
public class UnknownDiscordRequestException: DiscordRequestException
{
public UnknownDiscordRequestException(HttpResponseMessage response, string responseBody,
DiscordApiError? apiError) : base(response, responseBody, apiError) { }
} }

View File

@ -1,29 +1,26 @@
using System;
using Myriad.Rest.Ratelimit; using Myriad.Rest.Ratelimit;
namespace Myriad.Rest.Exceptions namespace Myriad.Rest.Exceptions;
public class RatelimitException: Exception
{ {
public class RatelimitException: Exception public RatelimitException(string? message) : base(message) { }
{ }
public RatelimitException(string? message) : base(message) { }
} public class RatelimitBucketExhaustedException: RatelimitException
{
public class RatelimitBucketExhaustedException: RatelimitException public RatelimitBucketExhaustedException(Bucket bucket, TimeSpan retryAfter) : base(
{ "Rate limit bucket exhausted, request blocked")
public RatelimitBucketExhaustedException(Bucket bucket, TimeSpan retryAfter) : base( {
"Rate limit bucket exhausted, request blocked") Bucket = bucket;
{ RetryAfter = retryAfter;
Bucket = bucket; }
RetryAfter = retryAfter;
} public Bucket Bucket { get; }
public TimeSpan RetryAfter { get; }
public Bucket Bucket { get; } }
public TimeSpan RetryAfter { get; }
} public class GloballyRatelimitedException: RatelimitException
{
public class GloballyRatelimitedException: RatelimitException public GloballyRatelimitedException() : base("Global rate limit hit") { }
{
public GloballyRatelimitedException() : base("Global rate limit hit") { }
}
} }

View File

@ -1,173 +1,172 @@
using System;
using System.Threading;
using Serilog; using Serilog;
namespace Myriad.Rest.Ratelimit namespace Myriad.Rest.Ratelimit;
public class Bucket
{ {
public class Bucket private static readonly TimeSpan Epsilon = TimeSpan.FromMilliseconds(10);
private static readonly TimeSpan FallbackDelay = TimeSpan.FromMilliseconds(200);
private static readonly TimeSpan StaleTimeout = TimeSpan.FromSeconds(5);
private readonly ILogger _logger;
private readonly SemaphoreSlim _semaphore = new(1, 1);
private bool _hasReceivedHeaders;
private DateTimeOffset? _nextReset;
private bool _resetTimeValid;
public Bucket(ILogger logger, string key, ulong major, int limit)
{ {
private static readonly TimeSpan Epsilon = TimeSpan.FromMilliseconds(10); _logger = logger.ForContext<Bucket>();
private static readonly TimeSpan FallbackDelay = TimeSpan.FromMilliseconds(200);
private static readonly TimeSpan StaleTimeout = TimeSpan.FromSeconds(5); Key = key;
Major = major;
private readonly ILogger _logger; Limit = limit;
private readonly SemaphoreSlim _semaphore = new(1, 1); Remaining = limit;
_resetTimeValid = false;
}
private DateTimeOffset? _nextReset; public string Key { get; }
private bool _resetTimeValid; public ulong Major { get; }
private bool _hasReceivedHeaders;
public Bucket(ILogger logger, string key, ulong major, int limit) public int Remaining { get; private set; }
public int Limit { get; private set; }
public DateTimeOffset LastUsed { get; private set; } = DateTimeOffset.UtcNow;
public bool TryAcquire()
{
LastUsed = DateTimeOffset.Now;
try
{ {
_logger = logger.ForContext<Bucket>(); _semaphore.Wait();
Key = key; if (Remaining > 0)
Major = major;
Limit = limit;
Remaining = limit;
_resetTimeValid = false;
}
public string Key { get; }
public ulong Major { get; }
public int Remaining { get; private set; }
public int Limit { get; private set; }
public DateTimeOffset LastUsed { get; private set; } = DateTimeOffset.UtcNow;
public bool TryAcquire()
{
LastUsed = DateTimeOffset.Now;
try
{ {
_semaphore.Wait(); _logger.Verbose(
"{BucketKey}/{BucketMajor}: Bucket has [{BucketRemaining}/{BucketLimit} left], allowing through",
if (Remaining > 0)
{
_logger.Verbose(
"{BucketKey}/{BucketMajor}: Bucket has [{BucketRemaining}/{BucketLimit} left], allowing through",
Key, Major, Remaining, Limit);
Remaining--;
return true;
}
_logger.Debug("{BucketKey}/{BucketMajor}: Bucket has [{BucketRemaining}/{BucketLimit}] left, denying",
Key, Major, Remaining, Limit); Key, Major, Remaining, Limit);
return false; Remaining--;
}
finally return true;
{
_semaphore.Release();
} }
_logger.Debug("{BucketKey}/{BucketMajor}: Bucket has [{BucketRemaining}/{BucketLimit}] left, denying",
Key, Major, Remaining, Limit);
return false;
} }
finally
public void HandleResponse(RatelimitHeaders headers)
{ {
try _semaphore.Release();
{
_semaphore.Wait();
_logger.Verbose("{BucketKey}/{BucketMajor}: Received rate limit headers: {@RateLimitHeaders}",
Key, Major, headers);
if (headers.ResetAfter != null)
{
var headerNextReset = DateTimeOffset.UtcNow + headers.ResetAfter.Value; // todo: server time
if (_nextReset == null || headerNextReset > _nextReset)
{
_logger.Verbose("{BucketKey}/{BucketMajor}: Received reset time {NextReset} from server (after: {NextResetAfter}, remaining: {Remaining}, local remaining: {LocalRemaining})",
Key, Major, headerNextReset, headers.ResetAfter.Value, headers.Remaining, Remaining);
_nextReset = headerNextReset;
_resetTimeValid = true;
}
}
if (headers.Limit != null)
Limit = headers.Limit.Value;
if (headers.Remaining != null && !_hasReceivedHeaders)
{
var oldRemaining = Remaining;
Remaining = Math.Min(headers.Remaining.Value, Remaining);
_logger.Debug("{BucketKey}/{BucketMajor}: Received first remaining of {HeaderRemaining}, previous local remaining is {LocalRemaining}, new local remaining is {Remaining}",
Key, Major, headers.Remaining.Value, oldRemaining, Remaining);
_hasReceivedHeaders = true;
}
}
finally
{
_semaphore.Release();
}
}
public void Tick(DateTimeOffset now)
{
try
{
_semaphore.Wait();
// If we don't have any reset data, "snap" it to now
// This happens before first request and at this point the reset is invalid anyway, so it's fine
// but it ensures the stale timeout doesn't trigger early by using `default` value
if (_nextReset == null)
_nextReset = now;
// If we're past the reset time *and* we haven't reset already, do that
var timeSinceReset = now - _nextReset;
var shouldReset = _resetTimeValid && timeSinceReset > TimeSpan.Zero;
if (shouldReset)
{
_logger.Debug("{BucketKey}/{BucketMajor}: Bucket timed out, refreshing with {BucketLimit} requests",
Key, Major, Limit);
Remaining = Limit;
_resetTimeValid = false;
return;
}
// We've run out of requests without having any new reset time,
// *and* it's been longer than a set amount - add one request back to the pool and hope that one returns
var isBucketStale = !_resetTimeValid && Remaining <= 0 && timeSinceReset > StaleTimeout;
if (isBucketStale)
{
_logger.Warning(
"{BucketKey}/{BucketMajor}: Bucket is stale ({StaleTimeout} passed with no rate limit info), allowing one request through",
Key, Major, StaleTimeout);
Remaining = 1;
// Reset the (still-invalid) reset time to now, so we don't keep hitting this conditional over and over...
_nextReset = now;
}
}
finally
{
_semaphore.Release();
}
}
public TimeSpan GetResetDelay(DateTimeOffset now)
{
// If we don't have a valid reset time, return the fallback delay always
// (so it'll keep spinning until we hopefully have one...)
if (!_resetTimeValid)
return FallbackDelay;
var delay = (_nextReset ?? now) - now;
// If we have a really small (or negative) value, return a fallback delay too
if (delay < Epsilon)
return FallbackDelay;
return delay;
} }
} }
public void HandleResponse(RatelimitHeaders headers)
{
try
{
_semaphore.Wait();
_logger.Verbose("{BucketKey}/{BucketMajor}: Received rate limit headers: {@RateLimitHeaders}",
Key, Major, headers);
if (headers.ResetAfter != null)
{
var headerNextReset = DateTimeOffset.UtcNow + headers.ResetAfter.Value; // todo: server time
if (_nextReset == null || headerNextReset > _nextReset)
{
_logger.Verbose(
"{BucketKey}/{BucketMajor}: Received reset time {NextReset} from server (after: {NextResetAfter}, remaining: {Remaining}, local remaining: {LocalRemaining})",
Key, Major, headerNextReset, headers.ResetAfter.Value, headers.Remaining, Remaining
);
_nextReset = headerNextReset;
_resetTimeValid = true;
}
}
if (headers.Limit != null)
Limit = headers.Limit.Value;
if (headers.Remaining != null && !_hasReceivedHeaders)
{
var oldRemaining = Remaining;
Remaining = Math.Min(headers.Remaining.Value, Remaining);
_logger.Debug(
"{BucketKey}/{BucketMajor}: Received first remaining of {HeaderRemaining}, previous local remaining is {LocalRemaining}, new local remaining is {Remaining}",
Key, Major, headers.Remaining.Value, oldRemaining, Remaining);
_hasReceivedHeaders = true;
}
}
finally
{
_semaphore.Release();
}
}
public void Tick(DateTimeOffset now)
{
try
{
_semaphore.Wait();
// If we don't have any reset data, "snap" it to now
// This happens before first request and at this point the reset is invalid anyway, so it's fine
// but it ensures the stale timeout doesn't trigger early by using `default` value
if (_nextReset == null)
_nextReset = now;
// If we're past the reset time *and* we haven't reset already, do that
var timeSinceReset = now - _nextReset;
var shouldReset = _resetTimeValid && timeSinceReset > TimeSpan.Zero;
if (shouldReset)
{
_logger.Debug("{BucketKey}/{BucketMajor}: Bucket timed out, refreshing with {BucketLimit} requests",
Key, Major, Limit);
Remaining = Limit;
_resetTimeValid = false;
return;
}
// We've run out of requests without having any new reset time,
// *and* it's been longer than a set amount - add one request back to the pool and hope that one returns
var isBucketStale = !_resetTimeValid && Remaining <= 0 && timeSinceReset > StaleTimeout;
if (isBucketStale)
{
_logger.Warning(
"{BucketKey}/{BucketMajor}: Bucket is stale ({StaleTimeout} passed with no rate limit info), allowing one request through",
Key, Major, StaleTimeout);
Remaining = 1;
// Reset the (still-invalid) reset time to now, so we don't keep hitting this conditional over and over...
_nextReset = now;
}
}
finally
{
_semaphore.Release();
}
}
public TimeSpan GetResetDelay(DateTimeOffset now)
{
// If we don't have a valid reset time, return the fallback delay always
// (so it'll keep spinning until we hopefully have one...)
if (!_resetTimeValid)
return FallbackDelay;
var delay = (_nextReset ?? now) - now;
// If we have a really small (or negative) value, return a fallback delay too
if (delay < Epsilon)
return FallbackDelay;
return delay;
}
} }

View File

@ -1,82 +1,79 @@
using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
using Serilog; using Serilog;
namespace Myriad.Rest.Ratelimit namespace Myriad.Rest.Ratelimit;
public class BucketManager: IDisposable
{ {
public class BucketManager: IDisposable private static readonly TimeSpan StaleBucketTimeout = TimeSpan.FromMinutes(5);
private static readonly TimeSpan PruneWorkerInterval = TimeSpan.FromMinutes(1);
private readonly ConcurrentDictionary<(string key, ulong major), Bucket> _buckets = new();
private readonly ConcurrentDictionary<string, string> _endpointKeyMap = new();
private readonly ConcurrentDictionary<string, int> _knownKeyLimits = new();
private readonly ILogger _logger;
private readonly Task _worker;
private readonly CancellationTokenSource _workerCts = new();
public BucketManager(ILogger logger)
{ {
private static readonly TimeSpan StaleBucketTimeout = TimeSpan.FromMinutes(5); _logger = logger.ForContext<BucketManager>();
private static readonly TimeSpan PruneWorkerInterval = TimeSpan.FromMinutes(1); _worker = PruneWorker(_workerCts.Token);
private readonly ConcurrentDictionary<(string key, ulong major), Bucket> _buckets = new(); }
private readonly ConcurrentDictionary<string, string> _endpointKeyMap = new(); public void Dispose()
private readonly ConcurrentDictionary<string, int> _knownKeyLimits = new(); {
_workerCts.Dispose();
_worker.Dispose();
}
private readonly ILogger _logger; public Bucket? GetBucket(string endpoint, ulong major)
{
if (!_endpointKeyMap.TryGetValue(endpoint, out var key))
return null;
private readonly Task _worker; if (_buckets.TryGetValue((key, major), out var bucket))
private readonly CancellationTokenSource _workerCts = new(); return bucket;
public BucketManager(ILogger logger) if (!_knownKeyLimits.TryGetValue(key, out var knownLimit))
return null;
_logger.Debug("Creating new bucket {BucketKey}/{BucketMajor} with limit {KnownLimit}", key, major,
knownLimit);
return _buckets.GetOrAdd((key, major),
k => new Bucket(_logger, k.Item1, k.Item2, knownLimit));
}
public void UpdateEndpointInfo(string endpoint, string key, int? limit)
{
_endpointKeyMap[endpoint] = key;
if (limit != null)
_knownKeyLimits[key] = limit.Value;
}
private async Task PruneWorker(CancellationToken ct)
{
while (!ct.IsCancellationRequested)
{ {
_logger = logger.ForContext<BucketManager>(); await Task.Delay(PruneWorkerInterval, ct);
_worker = PruneWorker(_workerCts.Token); PruneStaleBuckets(DateTimeOffset.UtcNow);
} }
}
public void Dispose() private void PruneStaleBuckets(DateTimeOffset now)
{
foreach (var (key, bucket) in _buckets)
{ {
_workerCts.Dispose(); if (now - bucket.LastUsed <= StaleBucketTimeout)
_worker.Dispose(); continue;
}
public Bucket? GetBucket(string endpoint, ulong major) _logger.Debug("Pruning unused bucket {BucketKey}/{BucketMajor} (last used at {BucketLastUsed})",
{ bucket.Key, bucket.Major, bucket.LastUsed);
if (!_endpointKeyMap.TryGetValue(endpoint, out var key)) _buckets.TryRemove(key, out _);
return null;
if (_buckets.TryGetValue((key, major), out var bucket))
return bucket;
if (!_knownKeyLimits.TryGetValue(key, out var knownLimit))
return null;
_logger.Debug("Creating new bucket {BucketKey}/{BucketMajor} with limit {KnownLimit}", key, major, knownLimit);
return _buckets.GetOrAdd((key, major),
k => new Bucket(_logger, k.Item1, k.Item2, knownLimit));
}
public void UpdateEndpointInfo(string endpoint, string key, int? limit)
{
_endpointKeyMap[endpoint] = key;
if (limit != null)
_knownKeyLimits[key] = limit.Value;
}
private async Task PruneWorker(CancellationToken ct)
{
while (!ct.IsCancellationRequested)
{
await Task.Delay(PruneWorkerInterval, ct);
PruneStaleBuckets(DateTimeOffset.UtcNow);
}
}
private void PruneStaleBuckets(DateTimeOffset now)
{
foreach (var (key, bucket) in _buckets)
{
if (now - bucket.LastUsed <= StaleBucketTimeout)
continue;
_logger.Debug("Pruning unused bucket {BucketKey}/{BucketMajor} (last used at {BucketLastUsed})",
bucket.Key, bucket.Major, bucket.LastUsed);
_buckets.TryRemove(key, out _);
}
} }
} }
} }

View File

@ -1,46 +1,40 @@
using System;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Polly; using Polly;
namespace Myriad.Rest.Ratelimit namespace Myriad.Rest.Ratelimit;
public class DiscordRateLimitPolicy: AsyncPolicy<HttpResponseMessage>
{ {
public class DiscordRateLimitPolicy: AsyncPolicy<HttpResponseMessage> public const string EndpointContextKey = "Endpoint";
public const string MajorContextKey = "Major";
private readonly Ratelimiter _ratelimiter;
public DiscordRateLimitPolicy(Ratelimiter ratelimiter, PolicyBuilder<HttpResponseMessage>? policyBuilder = null)
: base(policyBuilder)
{ {
public const string EndpointContextKey = "Endpoint"; _ratelimiter = ratelimiter;
public const string MajorContextKey = "Major"; }
private readonly Ratelimiter _ratelimiter; protected override async Task<HttpResponseMessage> ImplementationAsync(
Func<Context, CancellationToken, Task<HttpResponseMessage>> action, Context context, CancellationToken ct,
bool continueOnCapturedContext)
{
if (!context.TryGetValue(EndpointContextKey, out var endpointObj) || !(endpointObj is string endpoint))
throw new ArgumentException("Must provide endpoint in Polly context");
public DiscordRateLimitPolicy(Ratelimiter ratelimiter, PolicyBuilder<HttpResponseMessage>? policyBuilder = null) if (!context.TryGetValue(MajorContextKey, out var majorObj) || !(majorObj is ulong major))
: base(policyBuilder) throw new ArgumentException("Must provide major in Polly context");
{
_ratelimiter = ratelimiter;
}
protected override async Task<HttpResponseMessage> ImplementationAsync( // Check rate limit, throw if we're not allowed...
Func<Context, CancellationToken, Task<HttpResponseMessage>> action, Context context, CancellationToken ct, _ratelimiter.AllowRequestOrThrow(endpoint, major, DateTimeOffset.Now);
bool continueOnCapturedContext)
{
if (!context.TryGetValue(EndpointContextKey, out var endpointObj) || !(endpointObj is string endpoint))
throw new ArgumentException("Must provide endpoint in Polly context");
if (!context.TryGetValue(MajorContextKey, out var majorObj) || !(majorObj is ulong major)) // We're OK, push it through
throw new ArgumentException("Must provide major in Polly context"); var response = await action(context, ct).ConfigureAwait(continueOnCapturedContext);
// Check rate limit, throw if we're not allowed... // Update rate limit state with headers
_ratelimiter.AllowRequestOrThrow(endpoint, major, DateTimeOffset.Now); var headers = RatelimitHeaders.Parse(response);
_ratelimiter.HandleResponse(headers, endpoint, major);
// We're OK, push it through return response;
var response = await action(context, ct).ConfigureAwait(continueOnCapturedContext);
// Update rate limit state with headers
var headers = RatelimitHeaders.Parse(response);
_ratelimiter.HandleResponse(headers, endpoint, major);
return response;
}
} }
} }

View File

@ -1,85 +1,79 @@
using System;
using System.Globalization; using System.Globalization;
using System.Linq;
using System.Net.Http;
namespace Myriad.Rest.Ratelimit namespace Myriad.Rest.Ratelimit;
public record RatelimitHeaders
{ {
public record RatelimitHeaders private const string LimitHeader = "X-RateLimit-Limit";
private const string RemainingHeader = "X-RateLimit-Remaining";
private const string ResetHeader = "X-RateLimit-Reset";
private const string ResetAfterHeader = "X-RateLimit-Reset-After";
private const string BucketHeader = "X-RateLimit-Bucket";
private const string GlobalHeader = "X-RateLimit-Global";
public bool Global { get; private set; }
public int? Limit { get; private set; }
public int? Remaining { get; private set; }
public DateTimeOffset? Reset { get; private set; }
public TimeSpan? ResetAfter { get; private set; }
public string? Bucket { get; private set; }
public DateTimeOffset? ServerDate { get; private set; }
public bool HasRatelimitInfo =>
Limit != null && Remaining != null && Reset != null && ResetAfter != null && Bucket != null;
public static RatelimitHeaders Parse(HttpResponseMessage response)
{ {
private const string LimitHeader = "X-RateLimit-Limit"; var headers = new RatelimitHeaders
private const string RemainingHeader = "X-RateLimit-Remaining";
private const string ResetHeader = "X-RateLimit-Reset";
private const string ResetAfterHeader = "X-RateLimit-Reset-After";
private const string BucketHeader = "X-RateLimit-Bucket";
private const string GlobalHeader = "X-RateLimit-Global";
public bool Global { get; private set; }
public int? Limit { get; private set; }
public int? Remaining { get; private set; }
public DateTimeOffset? Reset { get; private set; }
public TimeSpan? ResetAfter { get; private set; }
public string? Bucket { get; private set; }
public DateTimeOffset? ServerDate { get; private set; }
public bool HasRatelimitInfo =>
Limit != null && Remaining != null && Reset != null && ResetAfter != null && Bucket != null;
public RatelimitHeaders() { }
public static RatelimitHeaders Parse(HttpResponseMessage response)
{ {
var headers = new RatelimitHeaders ServerDate = response.Headers.Date,
{ Limit = TryGetInt(response, LimitHeader),
ServerDate = response.Headers.Date, Remaining = TryGetInt(response, RemainingHeader),
Limit = TryGetInt(response, LimitHeader), Bucket = TryGetHeader(response, BucketHeader)
Remaining = TryGetInt(response, RemainingHeader), };
Bucket = TryGetHeader(response, BucketHeader)
};
var resetTimestamp = TryGetDouble(response, ResetHeader); var resetTimestamp = TryGetDouble(response, ResetHeader);
if (resetTimestamp != null) if (resetTimestamp != null)
headers.Reset = DateTimeOffset.FromUnixTimeMilliseconds((long)(resetTimestamp.Value * 1000)); headers.Reset = DateTimeOffset.FromUnixTimeMilliseconds((long)(resetTimestamp.Value * 1000));
var resetAfterSeconds = TryGetDouble(response, ResetAfterHeader); var resetAfterSeconds = TryGetDouble(response, ResetAfterHeader);
if (resetAfterSeconds != null) if (resetAfterSeconds != null)
headers.ResetAfter = TimeSpan.FromSeconds(resetAfterSeconds.Value); headers.ResetAfter = TimeSpan.FromSeconds(resetAfterSeconds.Value);
var global = TryGetHeader(response, GlobalHeader); var global = TryGetHeader(response, GlobalHeader);
if (global != null && bool.TryParse(global, out var globalBool)) if (global != null && bool.TryParse(global, out var globalBool))
headers.Global = globalBool; headers.Global = globalBool;
return headers; return headers;
} }
private static string? TryGetHeader(HttpResponseMessage response, string headerName) private static string? TryGetHeader(HttpResponseMessage response, string headerName)
{ {
if (!response.Headers.TryGetValues(headerName, out var values)) if (!response.Headers.TryGetValues(headerName, out var values))
return null; return null;
return values.FirstOrDefault(); return values.FirstOrDefault();
} }
private static int? TryGetInt(HttpResponseMessage response, string headerName) private static int? TryGetInt(HttpResponseMessage response, string headerName)
{ {
var valueString = TryGetHeader(response, headerName); var valueString = TryGetHeader(response, headerName);
if (!int.TryParse(valueString, NumberStyles.Integer, CultureInfo.InvariantCulture, out var value)) if (!int.TryParse(valueString, NumberStyles.Integer, CultureInfo.InvariantCulture, out var value))
return null; return null;
return value; return value;
} }
private static double? TryGetDouble(HttpResponseMessage response, string headerName) private static double? TryGetDouble(HttpResponseMessage response, string headerName)
{ {
var valueString = TryGetHeader(response, headerName); var valueString = TryGetHeader(response, headerName);
if (!double.TryParse(valueString, NumberStyles.Float, CultureInfo.InvariantCulture, out var value)) if (!double.TryParse(valueString, NumberStyles.Float, CultureInfo.InvariantCulture, out var value))
return null; return null;
return value; return value;
}
} }
} }

View File

@ -1,86 +1,83 @@
using System;
using Myriad.Rest.Exceptions; using Myriad.Rest.Exceptions;
using Serilog; using Serilog;
namespace Myriad.Rest.Ratelimit namespace Myriad.Rest.Ratelimit;
public class Ratelimiter: IDisposable
{ {
public class Ratelimiter: IDisposable private readonly BucketManager _buckets;
private readonly ILogger _logger;
private DateTimeOffset? _globalRateLimitExpiry;
public Ratelimiter(ILogger logger)
{ {
private readonly BucketManager _buckets; _logger = logger.ForContext<Ratelimiter>();
private readonly ILogger _logger; _buckets = new BucketManager(logger);
}
private DateTimeOffset? _globalRateLimitExpiry; public void Dispose()
{
_buckets.Dispose();
}
public Ratelimiter(ILogger logger) public void AllowRequestOrThrow(string endpoint, ulong major, DateTimeOffset now)
{
if (IsGloballyRateLimited(now))
{ {
_logger = logger.ForContext<Ratelimiter>(); _logger.Warning("Globally rate limited until {GlobalRateLimitExpiry}, cancelling request",
_buckets = new BucketManager(logger); _globalRateLimitExpiry);
throw new GloballyRatelimitedException();
} }
public void Dispose() var bucket = _buckets.GetBucket(endpoint, major);
if (bucket == null)
{ {
_buckets.Dispose(); // No rate limit for this endpoint (yet), allow through
_logger.Debug("No rate limit data for endpoint {Endpoint}, allowing through", endpoint);
return;
} }
public void AllowRequestOrThrow(string endpoint, ulong major, DateTimeOffset now) bucket.Tick(now);
if (bucket.TryAcquire())
// We're allowed to send it! :)
return;
// We can't send this request right now; retrying...
var waitTime = bucket.GetResetDelay(now);
// add a small buffer for Timing:tm:
waitTime += TimeSpan.FromMilliseconds(50);
// (this is caught by a WaitAndRetry Polly handler, if configured)
throw new RatelimitBucketExhaustedException(bucket, waitTime);
}
public void HandleResponse(RatelimitHeaders headers, string endpoint, ulong major)
{
if (!headers.HasRatelimitInfo)
return;
// TODO: properly calculate server time?
if (headers.Global)
{ {
if (IsGloballyRateLimited(now)) _logger.Warning(
{ "Global rate limit hit, resetting at {GlobalRateLimitExpiry} (in {GlobalRateLimitResetAfter}!",
_logger.Warning("Globally rate limited until {GlobalRateLimitExpiry}, cancelling request", _globalRateLimitExpiry, headers.ResetAfter);
_globalRateLimitExpiry); _globalRateLimitExpiry = headers.Reset;
throw new GloballyRatelimitedException(); }
} else
{
// Update buckets first, then get it again, to properly "transfer" this info over to the new value
_buckets.UpdateEndpointInfo(endpoint, headers.Bucket!, headers.Limit);
var bucket = _buckets.GetBucket(endpoint, major); var bucket = _buckets.GetBucket(endpoint, major);
if (bucket == null) bucket?.HandleResponse(headers);
{
// No rate limit for this endpoint (yet), allow through
_logger.Debug("No rate limit data for endpoint {Endpoint}, allowing through", endpoint);
return;
}
bucket.Tick(now);
if (bucket.TryAcquire())
// We're allowed to send it! :)
return;
// We can't send this request right now; retrying...
var waitTime = bucket.GetResetDelay(now);
// add a small buffer for Timing:tm:
waitTime += TimeSpan.FromMilliseconds(50);
// (this is caught by a WaitAndRetry Polly handler, if configured)
throw new RatelimitBucketExhaustedException(bucket, waitTime);
} }
public void HandleResponse(RatelimitHeaders headers, string endpoint, ulong major)
{
if (!headers.HasRatelimitInfo)
return;
// TODO: properly calculate server time?
if (headers.Global)
{
_logger.Warning(
"Global rate limit hit, resetting at {GlobalRateLimitExpiry} (in {GlobalRateLimitResetAfter}!",
_globalRateLimitExpiry, headers.ResetAfter);
_globalRateLimitExpiry = headers.Reset;
}
else
{
// Update buckets first, then get it again, to properly "transfer" this info over to the new value
_buckets.UpdateEndpointInfo(endpoint, headers.Bucket!, headers.Limit);
var bucket = _buckets.GetBucket(endpoint, major);
bucket?.HandleResponse(headers);
}
}
private bool IsGloballyRateLimited(DateTimeOffset now) =>
_globalRateLimitExpiry > now;
} }
private bool IsGloballyRateLimited(DateTimeOffset now) =>
_globalRateLimitExpiry > now;
} }

View File

@ -2,21 +2,20 @@ using System.Text.Json.Serialization;
using Myriad.Serialization; using Myriad.Serialization;
namespace Myriad.Rest.Types namespace Myriad.Rest.Types;
{
public record AllowedMentions
{
[JsonConverter(typeof(JsonSnakeCaseStringEnumConverter))]
public enum ParseType
{
Roles,
Users,
Everyone
}
public ParseType[]? Parse { get; set; } public record AllowedMentions
public ulong[]? Users { get; set; } {
public ulong[]? Roles { get; set; } [JsonConverter(typeof(JsonSnakeCaseStringEnumConverter))]
public bool RepliedUser { get; set; } public enum ParseType
{
Roles,
Users,
Everyone
} }
public ParseType[]? Parse { get; set; }
public ulong[]? Users { get; set; }
public ulong[]? Roles { get; set; }
public bool RepliedUser { get; set; }
} }

View File

@ -1,6 +1,3 @@
using System.IO; namespace Myriad.Rest.Types;
namespace Myriad.Rest.Types public record MultipartFile(string Filename, Stream Data, string? Description);
{
public record MultipartFile(string Filename, Stream Data, string? Description);
}

View File

@ -1,13 +1,10 @@
using System.Collections.Generic;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Rest.Types namespace Myriad.Rest.Types;
public record ApplicationCommandRequest
{ {
public record ApplicationCommandRequest public string Name { get; init; }
{ public string Description { get; init; }
public string Name { get; init; } public List<ApplicationCommandOption>? Options { get; init; }
public string Description { get; init; }
public List<ApplicationCommandOption>? Options { get; init; }
}
} }

View File

@ -1,4 +1,3 @@
namespace Myriad.Rest.Types.Requests namespace Myriad.Rest.Types.Requests;
{
public record CreateDmRequest(ulong RecipientId); public record CreateDmRequest(ulong RecipientId);
}

View File

@ -1,4 +1,3 @@
namespace Myriad.Rest.Types.Requests namespace Myriad.Rest.Types.Requests;
{
public record CreateWebhookRequest(string Name); public record CreateWebhookRequest(string Name);
}

View File

@ -1,14 +1,13 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Rest.Types.Requests namespace Myriad.Rest.Types.Requests;
public record ExecuteWebhookRequest
{ {
public record ExecuteWebhookRequest public string? Content { get; init; }
{ public string? Username { get; init; }
public string? Content { get; init; } public string? AvatarUrl { get; init; }
public string? Username { get; init; } public Embed[] Embeds { get; init; }
public string? AvatarUrl { get; init; } public Message.Attachment[] Attachments { get; set; }
public Embed[] Embeds { get; init; } public AllowedMentions? AllowedMentions { get; init; }
public Message.Attachment[] Attachments { get; set; }
public AllowedMentions? AllowedMentions { get; init; }
}
} }

View File

@ -3,23 +3,22 @@ using System.Text.Json.Serialization;
using Myriad.Types; using Myriad.Types;
using Myriad.Utils; using Myriad.Utils;
namespace Myriad.Rest.Types.Requests namespace Myriad.Rest.Types.Requests;
public record MessageEditRequest
{ {
public record MessageEditRequest [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
{ public Optional<string?> Content { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<string?> Content { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<Embed?> Embed { get; init; } public Optional<Embed?> Embed { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<Message.MessageFlags> Flags { get; init; } public Optional<Message.MessageFlags> Flags { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<AllowedMentions> AllowedMentions { get; init; } public Optional<AllowedMentions> AllowedMentions { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<MessageComponent[]?> Components { get; init; } public Optional<MessageComponent[]?> Components { get; init; }
}
} }

View File

@ -1,14 +1,13 @@
using Myriad.Types; using Myriad.Types;
namespace Myriad.Rest.Types.Requests namespace Myriad.Rest.Types.Requests;
public record MessageRequest
{ {
public record MessageRequest public string? Content { get; set; }
{ public object? Nonce { get; set; }
public string? Content { get; set; } public bool Tts { get; set; }
public object? Nonce { get; set; } public AllowedMentions? AllowedMentions { get; set; }
public bool Tts { get; set; } public Embed? Embed { get; set; }
public AllowedMentions? AllowedMentions { get; set; } public MessageComponent[]? Components { get; set; }
public Embed? Embed { get; set; }
public MessageComponent[]? Components { get; set; }
}
} }

View File

@ -1,7 +1,6 @@
namespace Myriad.Rest.Types namespace Myriad.Rest.Types;
public record ModifyGuildMemberRequest
{ {
public record ModifyGuildMemberRequest public string? Nick { get; init; }
{
public string? Nick { get; init; }
}
} }

View File

@ -2,14 +2,13 @@ using System.Text.Json.Serialization;
using Myriad.Utils; using Myriad.Utils;
namespace Myriad.Rest.Types.Requests namespace Myriad.Rest.Types.Requests;
{
public record WebhookMessageEditRequest
{
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<string?> Content { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public record WebhookMessageEditRequest
public Optional<AllowedMentions> AllowedMentions { get; init; } {
} [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<string?> Content { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<AllowedMentions> AllowedMentions { get; init; }
} }

View File

@ -1,21 +1,20 @@
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
namespace Myriad.Serialization namespace Myriad.Serialization;
public static class JsonSerializerOptionsExtensions
{ {
public static class JsonSerializerOptionsExtensions public static JsonSerializerOptions ConfigureForMyriad(this JsonSerializerOptions opts)
{ {
public static JsonSerializerOptions ConfigureForMyriad(this JsonSerializerOptions opts) opts.PropertyNamingPolicy = new JsonSnakeCaseNamingPolicy();
{ opts.NumberHandling = JsonNumberHandling.AllowReadingFromString;
opts.PropertyNamingPolicy = new JsonSnakeCaseNamingPolicy(); opts.IncludeFields = true;
opts.NumberHandling = JsonNumberHandling.AllowReadingFromString;
opts.IncludeFields = true;
opts.Converters.Add(new PermissionSetJsonConverter()); opts.Converters.Add(new PermissionSetJsonConverter());
opts.Converters.Add(new ShardInfoJsonConverter()); opts.Converters.Add(new ShardInfoJsonConverter());
opts.Converters.Add(new OptionalConverterFactory()); opts.Converters.Add(new OptionalConverterFactory());
return opts; return opts;
}
} }
} }

View File

@ -1,88 +1,86 @@
using System;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
namespace Myriad.Serialization namespace Myriad.Serialization;
// From https://github.com/J0rgeSerran0/JsonNamingPolicy/blob/master/JsonSnakeCaseNamingPolicy.cs, no NuGet :/
public class JsonSnakeCaseNamingPolicy: JsonNamingPolicy
{ {
// From https://github.com/J0rgeSerran0/JsonNamingPolicy/blob/master/JsonSnakeCaseNamingPolicy.cs, no NuGet :/ private readonly string _separator = "_";
public class JsonSnakeCaseNamingPolicy: JsonNamingPolicy
public override string ConvertName(string name)
{ {
private readonly string _separator = "_"; if (string.IsNullOrEmpty(name) || string.IsNullOrWhiteSpace(name)) return string.Empty;
public override string ConvertName(string name) ReadOnlySpan<char> spanName = name.Trim();
var stringBuilder = new StringBuilder();
var addCharacter = true;
var isPreviousSpace = false;
var isPreviousSeparator = false;
var isCurrentSpace = false;
var isNextLower = false;
var isNextUpper = false;
var isNextSpace = false;
for (var position = 0; position < spanName.Length; position++)
{ {
if (string.IsNullOrEmpty(name) || string.IsNullOrWhiteSpace(name)) return string.Empty; if (position != 0)
ReadOnlySpan<char> spanName = name.Trim();
var stringBuilder = new StringBuilder();
var addCharacter = true;
var isPreviousSpace = false;
var isPreviousSeparator = false;
var isCurrentSpace = false;
var isNextLower = false;
var isNextUpper = false;
var isNextSpace = false;
for (var position = 0; position < spanName.Length; position++)
{ {
if (position != 0) isCurrentSpace = spanName[position] == 32;
isPreviousSpace = spanName[position - 1] == 32;
isPreviousSeparator = spanName[position - 1] == 95;
if (position + 1 != spanName.Length)
{ {
isCurrentSpace = spanName[position] == 32; isNextLower = spanName[position + 1] > 96 && spanName[position + 1] < 123;
isPreviousSpace = spanName[position - 1] == 32; isNextUpper = spanName[position + 1] > 64 && spanName[position + 1] < 91;
isPreviousSeparator = spanName[position - 1] == 95; isNextSpace = spanName[position + 1] == 32;
}
if (position + 1 != spanName.Length) if (isCurrentSpace &&
{ (isPreviousSpace ||
isNextLower = spanName[position + 1] > 96 && spanName[position + 1] < 123; isPreviousSeparator ||
isNextUpper = spanName[position + 1] > 64 && spanName[position + 1] < 91; isNextUpper ||
isNextSpace = spanName[position + 1] == 32; isNextSpace))
} {
addCharacter = false;
}
else
{
var isCurrentUpper = spanName[position] > 64 && spanName[position] < 91;
var isPreviousLower = spanName[position - 1] > 96 && spanName[position - 1] < 123;
var isPreviousNumber = spanName[position - 1] > 47 && spanName[position - 1] < 58;
if (isCurrentSpace && if (isCurrentUpper &&
(isPreviousSpace || (isPreviousLower ||
isPreviousSeparator || isPreviousNumber ||
isNextUpper || isNextLower ||
isNextSpace)) isNextSpace ||
isNextLower && !isPreviousSpace))
{ {
addCharacter = false; stringBuilder.Append(_separator);
} }
else else
{ {
var isCurrentUpper = spanName[position] > 64 && spanName[position] < 91; if (isCurrentSpace &&
var isPreviousLower = spanName[position - 1] > 96 && spanName[position - 1] < 123; !isPreviousSpace &&
var isPreviousNumber = spanName[position - 1] > 47 && spanName[position - 1] < 58; !isNextSpace)
if (isCurrentUpper &&
(isPreviousLower ||
isPreviousNumber ||
isNextLower ||
isNextSpace ||
isNextLower && !isPreviousSpace))
{ {
stringBuilder.Append(_separator); stringBuilder.Append(_separator);
} addCharacter = false;
else
{
if (isCurrentSpace &&
!isPreviousSpace &&
!isNextSpace)
{
stringBuilder.Append(_separator);
addCharacter = false;
}
} }
} }
} }
if (addCharacter)
stringBuilder.Append(spanName[position]);
else
addCharacter = true;
} }
return stringBuilder.ToString().ToLower(); if (addCharacter)
stringBuilder.Append(spanName[position]);
else
addCharacter = true;
} }
return stringBuilder.ToString().ToLower();
} }
} }

View File

@ -1,17 +1,15 @@
using System;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
namespace Myriad.Serialization namespace Myriad.Serialization;
public class JsonSnakeCaseStringEnumConverter: JsonConverterFactory
{ {
public class JsonSnakeCaseStringEnumConverter: JsonConverterFactory private readonly JsonStringEnumConverter _inner = new(new JsonSnakeCaseNamingPolicy());
{
private readonly JsonStringEnumConverter _inner = new(new JsonSnakeCaseNamingPolicy());
public override bool CanConvert(Type typeToConvert) => public override bool CanConvert(Type typeToConvert) =>
_inner.CanConvert(typeToConvert); _inner.CanConvert(typeToConvert);
public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) => public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) =>
_inner.CreateConverter(typeToConvert, options); _inner.CreateConverter(typeToConvert, options);
}
} }

View File

@ -1,22 +1,20 @@
using System;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
namespace Myriad.Serialization namespace Myriad.Serialization;
{
public class JsonStringConverter: JsonConverter<object>
{
public override object? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var str = JsonSerializer.Deserialize<string>(ref reader);
var inner = JsonSerializer.Deserialize(str!, typeToConvert, options);
return inner;
}
public override void Write(Utf8JsonWriter writer, object value, JsonSerializerOptions options) public class JsonStringConverter: JsonConverter<object>
{ {
var inner = JsonSerializer.Serialize(value, options); public override object? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
writer.WriteStringValue(inner); {
} var str = JsonSerializer.Deserialize<string>(ref reader);
var inner = JsonSerializer.Deserialize(str!, typeToConvert, options);
return inner;
}
public override void Write(Utf8JsonWriter writer, object value, JsonSerializerOptions options)
{
var inner = JsonSerializer.Serialize(value, options);
writer.WriteStringValue(inner);
} }
} }

View File

@ -1,48 +1,47 @@
using System;
using System.Reflection; using System.Reflection;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using Myriad.Utils; using Myriad.Utils;
namespace Myriad.Serialization namespace Myriad.Serialization;
public class OptionalConverterFactory: JsonConverterFactory
{ {
public class OptionalConverterFactory: JsonConverterFactory public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options)
{ {
public class Inner<T>: JsonConverter<Optional<T>> var innerType = typeToConvert.GetGenericArguments()[0];
{ return (JsonConverter?)Activator.CreateInstance(
public override Optional<T> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) typeof(Inner<>).MakeGenericType(innerType),
{ BindingFlags.Instance | BindingFlags.Public,
var inner = JsonSerializer.Deserialize<T>(ref reader, options); null,
return new(inner!); null,
} null);
}
public override void Write(Utf8JsonWriter writer, Optional<T> value, JsonSerializerOptions options) public override bool CanConvert(Type typeToConvert)
{ {
JsonSerializer.Serialize(writer, value.HasValue ? value.GetValue() : default, typeof(T), options); if (!typeToConvert.IsGenericType)
} return false;
if (typeToConvert.GetGenericTypeDefinition() != typeof(Optional<>))
return false;
return true;
}
public class Inner<T>: JsonConverter<Optional<T>>
{
public override Optional<T> Read(ref Utf8JsonReader reader, Type typeToConvert,
JsonSerializerOptions options)
{
var inner = JsonSerializer.Deserialize<T>(ref reader, options);
return new Optional<T>(inner!);
} }
public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) public override void Write(Utf8JsonWriter writer, Optional<T> value, JsonSerializerOptions options)
{ {
var innerType = typeToConvert.GetGenericArguments()[0]; JsonSerializer.Serialize(writer, value.HasValue ? value.GetValue() : default, typeof(T), options);
return (JsonConverter?)Activator.CreateInstance(
typeof(Inner<>).MakeGenericType(innerType),
BindingFlags.Instance | BindingFlags.Public,
null,
null,
null);
}
public override bool CanConvert(Type typeToConvert)
{
if (!typeToConvert.IsGenericType)
return false;
if (typeToConvert.GetGenericTypeDefinition() != typeof(Optional<>))
return false;
return true;
} }
} }
} }

View File

@ -1,24 +1,22 @@
using System;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Serialization namespace Myriad.Serialization;
public class PermissionSetJsonConverter: JsonConverter<PermissionSet>
{ {
public class PermissionSetJsonConverter: JsonConverter<PermissionSet> public override PermissionSet Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{ {
public override PermissionSet Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) var str = reader.GetString();
{ if (str == null) return default;
var str = reader.GetString();
if (str == null) return default;
return (PermissionSet)ulong.Parse(str); return (PermissionSet)ulong.Parse(str);
} }
public override void Write(Utf8JsonWriter writer, PermissionSet value, JsonSerializerOptions options) public override void Write(Utf8JsonWriter writer, PermissionSet value, JsonSerializerOptions options)
{ {
writer.WriteStringValue(((ulong)value).ToString()); writer.WriteStringValue(((ulong)value).ToString());
}
} }
} }

View File

@ -1,28 +1,26 @@
using System;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using Myriad.Gateway; using Myriad.Gateway;
namespace Myriad.Serialization namespace Myriad.Serialization;
public class ShardInfoJsonConverter: JsonConverter<ShardInfo>
{ {
public class ShardInfoJsonConverter: JsonConverter<ShardInfo> public override ShardInfo? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{ {
public override ShardInfo? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) var arr = JsonSerializer.Deserialize<int[]>(ref reader);
{ if (arr?.Length != 2)
var arr = JsonSerializer.Deserialize<int[]>(ref reader); throw new JsonException("Expected shard info as array of length 2");
if (arr?.Length != 2)
throw new JsonException("Expected shard info as array of length 2");
return new ShardInfo(arr[0], arr[1]); return new ShardInfo(arr[0], arr[1]);
} }
public override void Write(Utf8JsonWriter writer, ShardInfo value, JsonSerializerOptions options) public override void Write(Utf8JsonWriter writer, ShardInfo value, JsonSerializerOptions options)
{ {
writer.WriteStartArray(); writer.WriteStartArray();
writer.WriteNumberValue(value.ShardId); writer.WriteNumberValue(value.ShardId);
writer.WriteNumberValue(value.NumShards); writer.WriteNumberValue(value.NumShards);
writer.WriteEndArray(); writer.WriteEndArray();
}
} }
} }

View File

@ -1,25 +1,24 @@
namespace Myriad.Types namespace Myriad.Types;
{
public record Application: ApplicationPartial
{
public string Name { get; init; }
public string? Icon { get; init; }
public string Description { get; init; }
public string[]? RpcOrigins { get; init; }
public bool BotPublic { get; init; }
public bool BotRequireCodeGrant { get; init; }
public User Owner { get; init; } // TODO: docs specify this is "partial", what does that mean
public string Summary { get; init; }
public string VerifyKey { get; init; }
public ulong? GuildId { get; init; }
public ulong? PrimarySkuId { get; init; }
public string? Slug { get; init; }
public string? CoverImage { get; init; }
}
public record ApplicationPartial public record Application: ApplicationPartial
{ {
public ulong Id { get; init; } public string Name { get; init; }
public int Flags { get; init; } public string? Icon { get; init; }
} public string Description { get; init; }
public string[]? RpcOrigins { get; init; }
public bool BotPublic { get; init; }
public bool BotRequireCodeGrant { get; init; }
public User Owner { get; init; } // TODO: docs specify this is "partial", what does that mean
public string Summary { get; init; }
public string VerifyKey { get; init; }
public ulong? GuildId { get; init; }
public ulong? PrimarySkuId { get; init; }
public string? Slug { get; init; }
public string? CoverImage { get; init; }
}
public record ApplicationPartial
{
public ulong Id { get; init; }
public int Flags { get; init; }
} }

View File

@ -1,11 +1,10 @@
namespace Myriad.Types namespace Myriad.Types;
public record ApplicationCommand
{ {
public record ApplicationCommand public ulong Id { get; init; }
{ public ulong ApplicationId { get; init; }
public ulong Id { get; init; } public string Name { get; init; }
public ulong ApplicationId { get; init; } public string Description { get; init; }
public string Name { get; init; } public ApplicationCommandOption[]? Options { get; init; }
public string Description { get; init; }
public ApplicationCommandOption[]? Options { get; init; }
}
} }

View File

@ -1,11 +1,10 @@
namespace Myriad.Types namespace Myriad.Types;
public record ApplicationCommandInteractionData
{ {
public record ApplicationCommandInteractionData public ulong? Id { get; init; }
{ public string? Name { get; init; }
public ulong? Id { get; init; } public ApplicationCommandInteractionDataOption[]? Options { get; init; }
public string? Name { get; init; } public string? CustomId { get; init; }
public ApplicationCommandInteractionDataOption[]? Options { get; init; } public ComponentType? ComponentType { get; init; }
public string? CustomId { get; init; }
public ComponentType? ComponentType { get; init; }
}
} }

View File

@ -1,9 +1,8 @@
namespace Myriad.Types namespace Myriad.Types;
public record ApplicationCommandInteractionDataOption
{ {
public record ApplicationCommandInteractionDataOption public string Name { get; init; }
{ public object? Value { get; init; }
public string Name { get; init; } public ApplicationCommandInteractionDataOption[]? Options { get; init; }
public object? Value { get; init; }
public ApplicationCommandInteractionDataOption[]? Options { get; init; }
}
} }

View File

@ -1,24 +1,23 @@
namespace Myriad.Types namespace Myriad.Types;
public record ApplicationCommandOption(ApplicationCommandOption.OptionType Type, string Name, string Description)
{ {
public record ApplicationCommandOption(ApplicationCommandOption.OptionType Type, string Name, string Description) public enum OptionType
{ {
public enum OptionType Subcommand = 1,
{ SubcommandGroup = 2,
Subcommand = 1, String = 3,
SubcommandGroup = 2, Integer = 4,
String = 3, Boolean = 5,
Integer = 4, User = 6,
Boolean = 5, Channel = 7,
User = 6, Role = 8
Channel = 7,
Role = 8
}
public bool Default { get; init; }
public bool Required { get; init; }
public Choice[]? Choices { get; init; }
public ApplicationCommandOption[]? Options { get; init; }
public record Choice(string Name, object Value);
} }
public bool Default { get; init; }
public bool Required { get; init; }
public Choice[]? Choices { get; init; }
public ApplicationCommandOption[]? Options { get; init; }
public record Choice(string Name, object Value);
} }

View File

@ -1,22 +1,21 @@
namespace Myriad.Types namespace Myriad.Types;
{
public record Interaction
{
public enum InteractionType
{
Ping = 1,
ApplicationCommand = 2,
MessageComponent = 3
}
public ulong Id { get; init; } public record Interaction
public InteractionType Type { get; init; } {
public ApplicationCommandInteractionData? Data { get; init; } public enum InteractionType
public ulong GuildId { get; init; } {
public ulong ChannelId { get; init; } Ping = 1,
public GuildMember? Member { get; init; } ApplicationCommand = 2,
public User? User { get; init; } MessageComponent = 3
public string Token { get; init; }
public Message? Message { get; init; }
} }
public ulong Id { get; init; }
public InteractionType Type { get; init; }
public ApplicationCommandInteractionData? Data { get; init; }
public ulong GuildId { get; init; }
public ulong ChannelId { get; init; }
public GuildMember? Member { get; init; }
public User? User { get; init; }
public string Token { get; init; }
public Message? Message { get; init; }
} }

View File

@ -3,26 +3,25 @@ using System.Text.Json.Serialization;
using Myriad.Rest.Types; using Myriad.Rest.Types;
using Myriad.Utils; using Myriad.Utils;
namespace Myriad.Types namespace Myriad.Types;
public record InteractionApplicationCommandCallbackData
{ {
public record InteractionApplicationCommandCallbackData [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
{ public Optional<bool?> Tts { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<bool?> Tts { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<string?> Content { get; init; } public Optional<string?> Content { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<Embed[]?> Embeds { get; init; } public Optional<Embed[]?> Embeds { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<AllowedMentions?> AllowedMentions { get; init; } public Optional<AllowedMentions?> AllowedMentions { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<Message.MessageFlags> Flags { get; init; } public Optional<Message.MessageFlags> Flags { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public Optional<MessageComponent[]?> Components { get; init; } public Optional<MessageComponent[]?> Components { get; init; }
}
} }

View File

@ -1,17 +1,16 @@
namespace Myriad.Types namespace Myriad.Types;
{
public record InteractionResponse
{
public enum ResponseType
{
Pong = 1,
ChannelMessageWithSource = 4,
DeferredChannelMessageWithSource = 5,
DeferredUpdateMessage = 6,
UpdateMessage = 7
}
public ResponseType Type { get; init; } public record InteractionResponse
public InteractionApplicationCommandCallbackData? Data { get; init; } {
public enum ResponseType
{
Pong = 1,
ChannelMessageWithSource = 4,
DeferredChannelMessageWithSource = 5,
DeferredUpdateMessage = 6,
UpdateMessage = 7
} }
public ResponseType Type { get; init; }
public InteractionApplicationCommandCallbackData? Data { get; init; }
} }

View File

@ -1,45 +1,45 @@
namespace Myriad.Types namespace Myriad.Types;
public record Channel
{ {
public record Channel public enum ChannelType
{ {
public enum ChannelType GuildText = 0,
{ Dm = 1,
GuildText = 0, GuildVoice = 2,
Dm = 1, GroupDm = 3,
GuildVoice = 2, GuildCategory = 4,
GroupDm = 3, GuildNews = 5,
GuildCategory = 4, GuildStore = 6,
GuildNews = 5, GuildNewsThread = 10,
GuildStore = 6, GuildPublicThread = 11,
GuildNewsThread = 10, GuildPrivateThread = 12,
GuildPublicThread = 11, GuildStageVoice = 13
GuildPrivateThread = 12,
GuildStageVoice = 13
}
public ulong Id { get; init; }
public ChannelType Type { get; init; }
public ulong? GuildId { get; init; }
public int? Position { get; init; }
public string? Name { get; init; }
public string? Topic { get; init; }
public bool? Nsfw { get; init; }
public ulong? ParentId { get; init; }
public Overwrite[]? PermissionOverwrites { get; init; }
public User[]? Recipients { get; init; } // NOTE: this may be null for stub channel objects
public record Overwrite
{
public ulong Id { get; init; }
public OverwriteType Type { get; init; }
public PermissionSet Allow { get; init; }
public PermissionSet Deny { get; init; }
}
public enum OverwriteType
{
Role = 0,
Member = 1
}
} }
public ulong Id { get; init; }
public ChannelType Type { get; init; }
public ulong? GuildId { get; init; }
public int? Position { get; init; }
public string? Name { get; init; }
public string? Topic { get; init; }
public bool? Nsfw { get; init; }
public ulong? ParentId { get; init; }
public Overwrite[]? PermissionOverwrites { get; init; }
public User[]? Recipients { get; init; } // NOTE: this may be null for stub channel objects
public record Overwrite
{
public ulong Id { get; init; }
public OverwriteType Type { get; init; }
public PermissionSet Allow { get; init; }
public PermissionSet Deny { get; init; }
}
public enum OverwriteType
{
Role = 0,
Member = 1
}
} }

View File

@ -1,11 +1,10 @@
namespace Myriad.Types namespace Myriad.Types;
public enum ButtonStyle
{ {
public enum ButtonStyle Primary = 1,
{ Secondary = 2,
Primary = 1, Success = 3,
Secondary = 2, Danger = 4,
Success = 3, Link = 5
Danger = 4,
Link = 5
}
} }

View File

@ -1,8 +1,7 @@
namespace Myriad.Types namespace Myriad.Types;
public enum ComponentType
{ {
public enum ComponentType ActionRow = 1,
{ Button = 2
ActionRow = 1,
Button = 2
}
} }

View File

@ -1,14 +1,13 @@
namespace Myriad.Types namespace Myriad.Types;
public record MessageComponent
{ {
public record MessageComponent public ComponentType Type { get; init; }
{ public ButtonStyle? Style { get; init; }
public ComponentType Type { get; init; } public string? Label { get; init; }
public ButtonStyle? Style { get; init; } public Emoji? Emoji { get; init; }
public string? Label { get; init; } public string? CustomId { get; init; }
public Emoji? Emoji { get; init; } public string? Url { get; init; }
public string? CustomId { get; init; } public bool? Disabled { get; init; }
public string? Url { get; init; } public MessageComponent[]? Components { get; init; }
public bool? Disabled { get; init; }
public MessageComponent[]? Components { get; init; }
}
} }

View File

@ -1,62 +1,61 @@
namespace Myriad.Types namespace Myriad.Types;
public record Embed
{ {
public record Embed public string? Title { get; init; }
{ public string? Type { get; init; }
public string? Title { get; init; } public string? Description { get; init; }
public string? Type { get; init; } public string? Url { get; init; }
public string? Description { get; init; } public string? Timestamp { get; init; }
public string? Url { get; init; } public uint? Color { get; init; }
public string? Timestamp { get; init; } public EmbedFooter? Footer { get; init; }
public uint? Color { get; init; } public EmbedImage? Image { get; init; }
public EmbedFooter? Footer { get; init; } public EmbedThumbnail? Thumbnail { get; init; }
public EmbedImage? Image { get; init; } public EmbedVideo? Video { get; init; }
public EmbedThumbnail? Thumbnail { get; init; } public EmbedProvider? Provider { get; init; }
public EmbedVideo? Video { get; init; } public EmbedAuthor? Author { get; init; }
public EmbedProvider? Provider { get; init; } public Field[]? Fields { get; init; }
public EmbedAuthor? Author { get; init; }
public Field[]? Fields { get; init; }
public record EmbedFooter( public record EmbedFooter(
string Text, string Text,
string? IconUrl = null, string? IconUrl = null,
string? ProxyIconUrl = null string? ProxyIconUrl = null
); );
public record EmbedImage( public record EmbedImage(
string? Url, string? Url,
uint? Width = null, uint? Width = null,
uint? Height = null uint? Height = null
); );
public record EmbedThumbnail( public record EmbedThumbnail(
string? Url, string? Url,
string? ProxyUrl = null, string? ProxyUrl = null,
uint? Width = null, uint? Width = null,
uint? Height = null uint? Height = null
); );
public record EmbedVideo( public record EmbedVideo(
string? Url, string? Url,
uint? Width = null, uint? Width = null,
uint? Height = null uint? Height = null
); );
public record EmbedProvider( public record EmbedProvider(
string? Name, string? Name,
string? Url string? Url
); );
public record EmbedAuthor( public record EmbedAuthor(
string? Name = null, string? Name = null,
string? Url = null, string? Url = null,
string? IconUrl = null, string? IconUrl = null,
string? ProxyIconUrl = null string? ProxyIconUrl = null
); );
public record Field( public record Field(
string Name, string Name,
string Value, string Value,
bool Inline = false bool Inline = false
); );
}
} }

Some files were not shown because too many files have changed in this diff Show More