Initial commit, basic proxying working

This commit is contained in:
Ske
2020-12-22 13:15:26 +01:00
parent c3f6becea4
commit a6fbd869be
109 changed files with 3539 additions and 359 deletions

View File

@@ -0,0 +1,240 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Net.Http.Json;
using System.Text.Json;
using System.Threading.Tasks;
using Myriad.Rest.Exceptions;
using Myriad.Rest.Ratelimit;
using Myriad.Rest.Types;
using Myriad.Serialization;
using Polly;
using Serilog;
namespace Myriad.Rest
{
public class BaseRestClient: IAsyncDisposable
{
private const string ApiBaseUrl = "https://discord.com/api/v8";
private readonly Version _httpVersion = new(2, 0);
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly ILogger _logger;
private readonly Ratelimiter _ratelimiter;
private readonly AsyncPolicy<HttpResponseMessage> _retryPolicy;
public BaseRestClient(string userAgent, string token, ILogger logger)
{
_logger = logger.ForContext<BaseRestClient>();
if (!token.StartsWith("Bot "))
token = "Bot " + token;
Client = new HttpClient();
Client.DefaultRequestHeaders.TryAddWithoutValidation("User-Agent", userAgent);
Client.DefaultRequestHeaders.TryAddWithoutValidation("Authorization", token);
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForNewcord();
_ratelimiter = new Ratelimiter(logger);
var discordPolicy = new DiscordRateLimitPolicy(_ratelimiter);
// todo: why doesn't the timeout work? o.o
var timeoutPolicy = Policy.TimeoutAsync<HttpResponseMessage>(TimeSpan.FromSeconds(10));
var waitPolicy = Policy
.Handle<RatelimitBucketExhaustedException>()
.WaitAndRetryAsync(3,
(_, e, _) => ((RatelimitBucketExhaustedException) e).RetryAfter,
(_, _, _, _) => Task.CompletedTask)
.AsAsyncPolicy<HttpResponseMessage>();
_retryPolicy = Policy.WrapAsync(timeoutPolicy, waitPolicy, discordPolicy);
}
public HttpClient Client { get; }
public ValueTask DisposeAsync()
{
_ratelimiter.Dispose();
Client.Dispose();
return default;
}
public async Task<T?> Get<T>(string path, (string endpointName, ulong major) ratelimitParams) where T: class
{
var request = new HttpRequestMessage(HttpMethod.Get, ApiBaseUrl + path);
var response = await Send(request, ratelimitParams, true);
// GET-only special case: 404s are nulls and not exceptions
if (response.StatusCode == HttpStatusCode.NotFound)
return null;
return await ReadResponse<T>(response);
}
public async Task<T?> Post<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Post, ApiBaseUrl + path);
SetRequestJsonBody(request, body);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> PostMultipart<T>(string path, (string endpointName, ulong major) ratelimitParams, object? payload, MultipartFile[]? files)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Post, ApiBaseUrl + path);
SetRequestFormDataBody(request, payload, files);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> Patch<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Patch, ApiBaseUrl + path);
SetRequestJsonBody(request, body);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task<T?> Put<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
where T: class
{
var request = new HttpRequestMessage(HttpMethod.Put, ApiBaseUrl + path);
SetRequestJsonBody(request, body);
var response = await Send(request, ratelimitParams);
return await ReadResponse<T>(response);
}
public async Task Delete(string path, (string endpointName, ulong major) ratelimitParams)
{
var request = new HttpRequestMessage(HttpMethod.Delete, ApiBaseUrl + path);
await Send(request, ratelimitParams);
}
private void SetRequestJsonBody(HttpRequestMessage request, object? body)
{
if (body == null) return;
request.Content =
new ReadOnlyMemoryContent(JsonSerializer.SerializeToUtf8Bytes(body, _jsonSerializerOptions));
request.Content.Headers.ContentType = new MediaTypeHeaderValue("application/json");
}
private void SetRequestFormDataBody(HttpRequestMessage request, object? payload, MultipartFile[]? files)
{
var bodyJson = JsonSerializer.SerializeToUtf8Bytes(payload, _jsonSerializerOptions);
var mfd = new MultipartFormDataContent();
mfd.Add(new ByteArrayContent(bodyJson), "payload_json");
if (files != null)
{
for (var i = 0; i < files.Length; i++)
{
var (filename, stream) = files[i];
mfd.Add(new StreamContent(stream), $"file{i}", filename);
}
}
request.Content = mfd;
}
private async Task<T?> ReadResponse<T>(HttpResponseMessage response) where T: class
{
if (response.StatusCode == HttpStatusCode.NoContent)
return null;
return await response.Content.ReadFromJsonAsync<T>(_jsonSerializerOptions);
}
private async Task<HttpResponseMessage> Send(HttpRequestMessage request,
(string endpointName, ulong major) ratelimitParams,
bool ignoreNotFound = false)
{
return await _retryPolicy.ExecuteAsync(async _ =>
{
_logger.Debug("Sending request: {RequestMethod} {RequestPath}",
request.Method, request.RequestUri);
request.Version = _httpVersion;
request.VersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
var stopwatch = new Stopwatch();
stopwatch.Start();
var response = await Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
stopwatch.Stop();
_logger.Debug(
"Received response in {ResponseDurationMs} ms: {RequestMethod} {RequestPath} -> {StatusCode} {ReasonPhrase}",
stopwatch.ElapsedMilliseconds, request.Method, request.RequestUri, (int) response.StatusCode,
response.ReasonPhrase);
await HandleApiError(response, ignoreNotFound);
return response;
},
new Dictionary<string, object>
{
{DiscordRateLimitPolicy.EndpointContextKey, ratelimitParams.endpointName},
{DiscordRateLimitPolicy.MajorContextKey, ratelimitParams.major}
});
}
private async ValueTask HandleApiError(HttpResponseMessage response, bool ignoreNotFound)
{
if (response.IsSuccessStatusCode)
return;
if (response.StatusCode == HttpStatusCode.NotFound && ignoreNotFound)
return;
throw await CreateDiscordException(response);
}
private async ValueTask<DiscordRequestException> CreateDiscordException(HttpResponseMessage response)
{
var body = await response.Content.ReadAsStringAsync();
var apiError = TryParseApiError(body);
return response.StatusCode switch
{
HttpStatusCode.BadRequest => new BadRequestException(response, body, apiError),
HttpStatusCode.Forbidden => new ForbiddenException(response, body, apiError),
HttpStatusCode.Unauthorized => new UnauthorizedException(response, body, apiError),
HttpStatusCode.NotFound => new NotFoundException(response, body, apiError),
HttpStatusCode.Conflict => new ConflictException(response, body, apiError),
HttpStatusCode.TooManyRequests => new TooManyRequestsException(response, body, apiError),
_ => new UnknownDiscordRequestException(response, body, apiError)
};
}
private DiscordApiError? TryParseApiError(string responseBody)
{
if (string.IsNullOrWhiteSpace(responseBody))
return null;
try
{
return JsonSerializer.Deserialize<DiscordApiError>(responseBody, _jsonSerializerOptions);
}
catch (JsonException e)
{
_logger.Verbose(e, "Error deserializing API error");
}
return null;
}
}
}

View File

@@ -0,0 +1,120 @@
using System;
using System.IO;
using System.Net;
using System.Threading.Tasks;
using Myriad.Rest.Types;
using Myriad.Rest.Types.Requests;
using Myriad.Types;
using Serilog;
namespace Myriad.Rest
{
public class DiscordApiClient
{
private const string UserAgent = "Test Discord Library by @Ske#6201";
private readonly BaseRestClient _client;
public DiscordApiClient(string token, ILogger logger)
{
_client = new BaseRestClient(UserAgent, token, logger);
}
public Task<GatewayInfo> GetGateway() =>
_client.Get<GatewayInfo>("/gateway", ("GetGateway", default))!;
public Task<GatewayInfo.Bot> GetGatewayBot() =>
_client.Get<GatewayInfo.Bot>("/gateway/bot", ("GetGatewayBot", default))!;
public Task<Channel?> GetChannel(ulong channelId) =>
_client.Get<Channel>($"/channels/{channelId}", ("GetChannel", channelId));
public Task<Message?> GetMessage(ulong channelId, ulong messageId) =>
_client.Get<Message>($"/channels/{channelId}/messages/{messageId}", ("GetMessage", channelId));
public Task<Channel?> GetGuild(ulong id) =>
_client.Get<Channel>($"/guilds/{id}", ("GetGuild", id));
public Task<User?> GetUser(ulong id) =>
_client.Get<User>($"/users/{id}", ("GetUser", default));
public Task<Message> CreateMessage(ulong channelId, MessageRequest request) =>
_client.Post<Message>($"/channels/{channelId}/messages", ("CreateMessage", channelId), request)!;
public Task<Message> EditMessage(ulong channelId, ulong messageId, MessageEditRequest request) =>
_client.Patch<Message>($"/channels/{channelId}/messages/{messageId}", ("EditMessage", channelId), request)!;
public Task DeleteMessage(ulong channelId, ulong messageId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}", ("DeleteMessage", channelId));
public Task CreateReaction(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Put<object>($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/@me",
("CreateReaction", channelId), null);
public Task DeleteOwnReaction(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/@me",
("DeleteOwnReaction", channelId));
public Task DeleteUserReaction(ulong channelId, ulong messageId, Emoji emoji, ulong userId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}/{userId}",
("DeleteUserReaction", channelId));
public Task DeleteAllReactions(ulong channelId, ulong messageId) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions",
("DeleteAllReactions", channelId));
public Task DeleteAllReactionsForEmoji(ulong channelId, ulong messageId, Emoji emoji) =>
_client.Delete($"/channels/{channelId}/messages/{messageId}/reactions/{EncodeEmoji(emoji)}",
("DeleteAllReactionsForEmoji", channelId));
public Task<ApplicationCommand> CreateGlobalApplicationCommand(ulong applicationId,
ApplicationCommandRequest request) =>
_client.Post<ApplicationCommand>($"/applications/{applicationId}/commands",
("CreateGlobalApplicationCommand", applicationId), request)!;
public Task<ApplicationCommand[]> GetGuildApplicationCommands(ulong applicationId, ulong guildId) =>
_client.Get<ApplicationCommand[]>($"/applications/{applicationId}/guilds/{guildId}/commands",
("GetGuildApplicationCommands", applicationId))!;
public Task<ApplicationCommand> CreateGuildApplicationCommand(ulong applicationId, ulong guildId,
ApplicationCommandRequest request) =>
_client.Post<ApplicationCommand>($"/applications/{applicationId}/guilds/{guildId}/commands",
("CreateGuildApplicationCommand", applicationId), request)!;
public Task<ApplicationCommand> EditGuildApplicationCommand(ulong applicationId, ulong guildId,
ApplicationCommandRequest request) =>
_client.Patch<ApplicationCommand>($"/applications/{applicationId}/guilds/{guildId}/commands",
("EditGuildApplicationCommand", applicationId), request)!;
public Task DeleteGuildApplicationCommand(ulong applicationId, ulong commandId) =>
_client.Delete($"/applications/{applicationId}/commands/{commandId}",
("DeleteGuildApplicationCommand", applicationId));
public Task CreateInteractionResponse(ulong interactionId, string token, InteractionResponse response) =>
_client.Post<object>($"/interactions/{interactionId}/{token}/callback",
("CreateInteractionResponse", interactionId), response);
public Task ModifyGuildMember(ulong guildId, ulong userId, ModifyGuildMemberRequest request) =>
_client.Patch<object>($"/guilds/{guildId}/members/{userId}",
("ModifyGuildMember", guildId), request);
public Task<Webhook> CreateWebhook(ulong channelId, CreateWebhookRequest request) =>
_client.Post<Webhook>($"/channels/{channelId}/webhooks", ("CreateWebhook", channelId), request)!;
public Task<Webhook> GetWebhook(ulong webhookId) =>
_client.Get<Webhook>($"/webhooks/{webhookId}/webhooks", ("GetWebhook", webhookId))!;
public Task<Webhook[]> GetChannelWebhooks(ulong channelId) =>
_client.Get<Webhook[]>($"/channels/{channelId}/webhooks", ("GetChannelWebhooks", channelId))!;
public Task<Message> ExecuteWebhook(ulong webhookId, string webhookToken, ExecuteWebhookRequest request,
MultipartFile[]? files = null) =>
_client.PostMultipart<Message>($"/webhooks/{webhookId}/{webhookToken}",
("ExecuteWebhook", webhookId), request, files)!;
private static string EncodeEmoji(Emoji emoji) =>
WebUtility.UrlEncode(emoji.Name) ?? emoji.Id?.ToString() ??
throw new ArgumentException("Could not encode emoji");
}
}

View File

@@ -0,0 +1,9 @@
using System.Text.Json;
namespace Myriad.Rest
{
public record DiscordApiError(string Message, int Code)
{
public JsonElement? Errors { get; init; }
}
}

View File

@@ -0,0 +1,71 @@
using System;
using System.Net;
using System.Net.Http;
namespace Myriad.Rest.Exceptions
{
public class DiscordRequestException: Exception
{
public DiscordRequestException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError)
{
RequestBody = requestBody;
Response = response;
ApiError = apiError;
}
public string RequestBody { get; init; } = null!;
public HttpResponseMessage Response { get; init; } = null!;
public HttpStatusCode StatusCode => Response.StatusCode;
public int? ErrorCode => ApiError?.Code;
internal DiscordApiError? ApiError { get; init; }
public override string Message =>
(ApiError?.Message ?? Response.ReasonPhrase ?? "") + (FormError != null ? $": {FormError}" : "");
public string? FormError => ApiError?.Errors?.ToString();
}
public class NotFoundException: DiscordRequestException
{
public NotFoundException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError): base(
response, requestBody, apiError) { }
}
public class UnauthorizedException: DiscordRequestException
{
public UnauthorizedException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError): base(
response, requestBody, apiError) { }
}
public class ForbiddenException: DiscordRequestException
{
public ForbiddenException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError): base(
response, requestBody, apiError) { }
}
public class ConflictException: DiscordRequestException
{
public ConflictException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError): base(
response, requestBody, apiError) { }
}
public class BadRequestException: DiscordRequestException
{
public BadRequestException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError): base(
response, requestBody, apiError) { }
}
public class TooManyRequestsException: DiscordRequestException
{
public TooManyRequestsException(HttpResponseMessage response, string requestBody, DiscordApiError? apiError):
base(response, requestBody, apiError) { }
}
public class UnknownDiscordRequestException: DiscordRequestException
{
public UnknownDiscordRequestException(HttpResponseMessage response, string requestBody,
DiscordApiError? apiError): base(response, requestBody, apiError) { }
}
}

View File

@@ -0,0 +1,29 @@
using System;
using Myriad.Rest.Ratelimit;
namespace Myriad.Rest.Exceptions
{
public class RatelimitException: Exception
{
public RatelimitException(string? message): base(message) { }
}
public class RatelimitBucketExhaustedException: RatelimitException
{
public RatelimitBucketExhaustedException(Bucket bucket, TimeSpan retryAfter): base(
"Rate limit bucket exhausted, request blocked")
{
Bucket = bucket;
RetryAfter = retryAfter;
}
public Bucket Bucket { get; }
public TimeSpan RetryAfter { get; }
}
public class GloballyRatelimitedException: RatelimitException
{
public GloballyRatelimitedException(): base("Global rate limit hit") { }
}
}

View File

@@ -0,0 +1,152 @@
using System;
using System.Threading;
using Serilog;
namespace Myriad.Rest.Ratelimit
{
public class Bucket
{
private static readonly TimeSpan Epsilon = TimeSpan.FromMilliseconds(10);
private static readonly TimeSpan FallbackDelay = TimeSpan.FromMilliseconds(200);
private static readonly TimeSpan StaleTimeout = TimeSpan.FromSeconds(5);
private readonly ILogger _logger;
private readonly SemaphoreSlim _semaphore = new(1, 1);
private DateTimeOffset _nextReset;
private bool _resetTimeValid;
public Bucket(ILogger logger, string key, ulong major, int limit)
{
_logger = logger.ForContext<Bucket>();
Key = key;
Major = major;
Limit = limit;
Remaining = limit;
_resetTimeValid = false;
}
public string Key { get; }
public ulong Major { get; }
public int Remaining { get; private set; }
public int Limit { get; private set; }
public DateTimeOffset LastUsed { get; private set; } = DateTimeOffset.UtcNow;
public bool TryAcquire()
{
LastUsed = DateTimeOffset.Now;
try
{
_semaphore.Wait();
if (Remaining > 0)
{
_logger.Debug(
"{BucketKey}/{BucketMajor}: Bucket has [{BucketRemaining}/{BucketLimit} left], allowing through",
Key, Major, Remaining, Limit);
Remaining--;
return true;
}
_logger.Debug("{BucketKey}/{BucketMajor}: Bucket has [{BucketRemaining}/{BucketLimit}] left, denying",
Key, Major, Remaining, Limit);
return false;
}
finally
{
_semaphore.Release();
}
}
public void HandleResponse(RatelimitHeaders headers)
{
try
{
_semaphore.Wait();
if (headers.ResetAfter != null)
{
var headerNextReset = DateTimeOffset.UtcNow + headers.ResetAfter.Value; // todo: server time
if (headerNextReset > _nextReset)
{
_logger.Debug("{BucketKey}/{BucketMajor}: Received reset time {NextReset} from server",
Key, Major, _nextReset);
_nextReset = headerNextReset;
_resetTimeValid = true;
}
}
if (headers.Limit != null)
Limit = headers.Limit.Value;
}
finally
{
_semaphore.Release();
}
}
public void Tick(DateTimeOffset now)
{
try
{
_semaphore.Wait();
// If we're past the reset time *and* we haven't reset already, do that
var timeSinceReset = _nextReset - now;
var shouldReset = _resetTimeValid && timeSinceReset > TimeSpan.Zero;
if (shouldReset)
{
_logger.Debug("{BucketKey}/{BucketMajor}: Bucket timed out, refreshing with {BucketLimit} requests",
Key, Major, Limit);
Remaining = Limit;
_resetTimeValid = false;
return;
}
// We've run out of requests without having any new reset time,
// *and* it's been longer than a set amount - add one request back to the pool and hope that one returns
var isBucketStale = !_resetTimeValid && Remaining <= 0 && timeSinceReset > StaleTimeout;
if (isBucketStale)
{
_logger.Warning(
"{BucketKey}/{BucketMajor}: Bucket is stale ({StaleTimeout} passed with no rate limit info), allowing one request through",
Key, Major, StaleTimeout);
Remaining = 1;
// Reset the (still-invalid) reset time to now, so we don't keep hitting this conditional over and over...
_nextReset = now;
}
}
finally
{
_semaphore.Release();
}
}
public TimeSpan GetResetDelay(DateTimeOffset now)
{
// If we don't have a valid reset time, return the fallback delay always
// (so it'll keep spinning until we hopefully have one...)
if (!_resetTimeValid)
return FallbackDelay;
var delay = _nextReset - now;
// If we have a really small (or negative) value, return a fallback delay too
if (delay < Epsilon)
return FallbackDelay;
return delay;
}
}
}

View File

@@ -0,0 +1,79 @@
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
using Serilog;
namespace Myriad.Rest.Ratelimit
{
public class BucketManager: IDisposable
{
private static readonly TimeSpan StaleBucketTimeout = TimeSpan.FromMinutes(5);
private static readonly TimeSpan PruneWorkerInterval = TimeSpan.FromMinutes(1);
private readonly ConcurrentDictionary<(string key, ulong major), Bucket> _buckets = new();
private readonly ConcurrentDictionary<string, string> _endpointKeyMap = new();
private readonly ConcurrentDictionary<string, int> _knownKeyLimits = new();
private readonly ILogger _logger;
private readonly Task _worker;
private readonly CancellationTokenSource _workerCts = new();
public BucketManager(ILogger logger)
{
_logger = logger.ForContext<BucketManager>();
_worker = PruneWorker(_workerCts.Token);
}
public void Dispose()
{
_workerCts.Dispose();
_worker.Dispose();
}
public Bucket? GetBucket(string endpoint, ulong major)
{
if (!_endpointKeyMap.TryGetValue(endpoint, out var key))
return null;
if (_buckets.TryGetValue((key, major), out var bucket))
return bucket;
if (!_knownKeyLimits.TryGetValue(key, out var knownLimit))
return null;
return _buckets.GetOrAdd((key, major),
k => new Bucket(_logger, k.Item1, k.Item2, knownLimit));
}
public void UpdateEndpointInfo(string endpoint, string key, int? limit)
{
_endpointKeyMap[endpoint] = key;
if (limit != null)
_knownKeyLimits[key] = limit.Value;
}
private async Task PruneWorker(CancellationToken ct)
{
while (!ct.IsCancellationRequested)
{
await Task.Delay(PruneWorkerInterval, ct);
PruneStaleBuckets(DateTimeOffset.UtcNow);
}
}
private void PruneStaleBuckets(DateTimeOffset now)
{
foreach (var (key, bucket) in _buckets)
if (now - bucket.LastUsed > StaleBucketTimeout)
{
_logger.Debug("Pruning unused bucket {Bucket} (last used at {BucketLastUsed})", bucket,
bucket.LastUsed);
_buckets.TryRemove(key, out _);
}
}
}
}

View File

@@ -0,0 +1,46 @@
using System;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Polly;
namespace Myriad.Rest.Ratelimit
{
public class DiscordRateLimitPolicy: AsyncPolicy<HttpResponseMessage>
{
public const string EndpointContextKey = "Endpoint";
public const string MajorContextKey = "Major";
private readonly Ratelimiter _ratelimiter;
public DiscordRateLimitPolicy(Ratelimiter ratelimiter, PolicyBuilder<HttpResponseMessage>? policyBuilder = null)
: base(policyBuilder)
{
_ratelimiter = ratelimiter;
}
protected override async Task<HttpResponseMessage> ImplementationAsync(
Func<Context, CancellationToken, Task<HttpResponseMessage>> action, Context context, CancellationToken ct,
bool continueOnCapturedContext)
{
if (!context.TryGetValue(EndpointContextKey, out var endpointObj) || !(endpointObj is string endpoint))
throw new ArgumentException("Must provide endpoint in Polly context");
if (!context.TryGetValue(MajorContextKey, out var majorObj) || !(majorObj is ulong major))
throw new ArgumentException("Must provide major in Polly context");
// Check rate limit, throw if we're not allowed...
_ratelimiter.AllowRequestOrThrow(endpoint, major, DateTimeOffset.Now);
// We're OK, push it through
var response = await action(context, ct).ConfigureAwait(continueOnCapturedContext);
// Update rate limit state with headers
var headers = new RatelimitHeaders(response);
_ratelimiter.HandleResponse(headers, endpoint, major);
return response;
}
}
}

View File

@@ -0,0 +1,46 @@
using System;
using System.Linq;
using System.Net.Http;
namespace Myriad.Rest.Ratelimit
{
public record RatelimitHeaders
{
public RatelimitHeaders() { }
public RatelimitHeaders(HttpResponseMessage response)
{
ServerDate = response.Headers.Date;
if (response.Headers.TryGetValues("X-RateLimit-Limit", out var limit))
Limit = int.Parse(limit!.First());
if (response.Headers.TryGetValues("X-RateLimit-Remaining", out var remaining))
Remaining = int.Parse(remaining!.First());
if (response.Headers.TryGetValues("X-RateLimit-Reset", out var reset))
Reset = DateTimeOffset.FromUnixTimeMilliseconds((long) (double.Parse(reset!.First()) * 1000));
if (response.Headers.TryGetValues("X-RateLimit-Reset-After", out var resetAfter))
ResetAfter = TimeSpan.FromSeconds(double.Parse(resetAfter!.First()));
if (response.Headers.TryGetValues("X-RateLimit-Bucket", out var bucket))
Bucket = bucket.First();
if (response.Headers.TryGetValues("X-RateLimit-Global", out var global))
Global = bool.Parse(global!.First());
}
public bool Global { get; init; }
public int? Limit { get; init; }
public int? Remaining { get; init; }
public DateTimeOffset? Reset { get; init; }
public TimeSpan? ResetAfter { get; init; }
public string? Bucket { get; init; }
public DateTimeOffset? ServerDate { get; init; }
public bool HasRatelimitInfo =>
Limit != null && Remaining != null && Reset != null && ResetAfter != null && Bucket != null;
}
}

View File

@@ -0,0 +1,86 @@
using System;
using Myriad.Rest.Exceptions;
using Serilog;
namespace Myriad.Rest.Ratelimit
{
public class Ratelimiter: IDisposable
{
private readonly BucketManager _buckets;
private readonly ILogger _logger;
private DateTimeOffset? _globalRateLimitExpiry;
public Ratelimiter(ILogger logger)
{
_logger = logger.ForContext<Ratelimiter>();
_buckets = new BucketManager(logger);
}
public void Dispose()
{
_buckets.Dispose();
}
public void AllowRequestOrThrow(string endpoint, ulong major, DateTimeOffset now)
{
if (IsGloballyRateLimited(now))
{
_logger.Warning("Globally rate limited until {GlobalRateLimitExpiry}, cancelling request",
_globalRateLimitExpiry);
throw new GloballyRatelimitedException();
}
var bucket = _buckets.GetBucket(endpoint, major);
if (bucket == null)
{
// No rate limit for this endpoint (yet), allow through
_logger.Debug("No rate limit data for endpoint {Endpoint}, allowing through", endpoint);
return;
}
bucket.Tick(now);
if (bucket.TryAcquire())
// We're allowed to send it! :)
return;
// We can't send this request right now; retrying...
var waitTime = bucket.GetResetDelay(now);
// add a small buffer for Timing:tm:
waitTime += TimeSpan.FromMilliseconds(50);
// (this is caught by a WaitAndRetry Polly handler, if configured)
throw new RatelimitBucketExhaustedException(bucket, waitTime);
}
public void HandleResponse(RatelimitHeaders headers, string endpoint, ulong major)
{
if (!headers.HasRatelimitInfo)
return;
// TODO: properly calculate server time?
if (headers.Global)
{
_logger.Warning(
"Global rate limit hit, resetting at {GlobalRateLimitExpiry} (in {GlobalRateLimitResetAfter}!",
_globalRateLimitExpiry, headers.ResetAfter);
_globalRateLimitExpiry = headers.Reset;
}
else
{
// Update buckets first, then get it again, to properly "transfer" this info over to the new value
_buckets.UpdateEndpointInfo(endpoint, headers.Bucket!, headers.Limit);
var bucket = _buckets.GetBucket(endpoint, major);
bucket?.HandleResponse(headers);
}
}
private bool IsGloballyRateLimited(DateTimeOffset now) =>
_globalRateLimitExpiry > now;
}
}

View File

@@ -0,0 +1,19 @@
using System.Collections.Generic;
namespace Myriad.Rest.Types
{
public record AllowedMentions
{
public enum ParseType
{
Roles,
Users,
Everyone
}
public List<ParseType>? Parse { get; set; }
public List<ulong>? Users { get; set; }
public List<ulong>? Roles { get; set; }
public bool RepliedUser { get; set; }
}
}

View File

@@ -0,0 +1,6 @@
using System.IO;
namespace Myriad.Rest.Types
{
public record MultipartFile(string Filename, Stream Data);
}

View File

@@ -0,0 +1,13 @@
using System.Collections.Generic;
using Myriad.Types;
namespace Myriad.Rest.Types
{
public record ApplicationCommandRequest
{
public string Name { get; init; }
public string Description { get; init; }
public List<ApplicationCommandOption>? Options { get; init; }
}
}

View File

@@ -0,0 +1,4 @@
namespace Myriad.Rest.Types.Requests
{
public record CreateWebhookRequest(string Name);
}

View File

@@ -0,0 +1,13 @@
using Myriad.Types;
namespace Myriad.Rest.Types.Requests
{
public record ExecuteWebhookRequest
{
public string? Content { get; init; }
public string? Username { get; init; }
public string? AvatarUrl { get; init; }
public Embed[] Embeds { get; init; }
public AllowedMentions? AllowedMentions { get; init; }
}
}

View File

@@ -0,0 +1,10 @@
using Myriad.Types;
namespace Myriad.Rest.Types.Requests
{
public record MessageEditRequest
{
public string? Content { get; set; }
public Embed? Embed { get; set; }
}
}

View File

@@ -0,0 +1,13 @@
using Myriad.Types;
namespace Myriad.Rest.Types.Requests
{
public record MessageRequest
{
public string? Content { get; set; }
public object? Nonce { get; set; }
public bool Tts { get; set; }
public AllowedMentions AllowedMentions { get; set; }
public Embed? Embeds { get; set; }
}
}

View File

@@ -0,0 +1,7 @@
namespace Myriad.Rest.Types
{
public record ModifyGuildMemberRequest
{
public string? Nick { get; init; }
}
}