using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Net.Http.Json;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;

using Myriad.Rest.Exceptions;
using Myriad.Rest.Ratelimit;
using Myriad.Rest.Types;
using Myriad.Serialization;

using Polly;

using Serilog;

namespace Myriad.Rest
{
    /// <summary>
    /// Low-level Discord REST client. Builds requests against <see cref="ApiBaseUrl"/>, sends them through a
    /// Polly pipeline (overall timeout -> retry on exhausted ratelimit bucket -> ratelimit bookkeeping),
    /// and converts non-success responses into the exception types from <c>Myriad.Rest.Exceptions</c>.
    /// </summary>
    public class BaseRestClient: IAsyncDisposable
    {
        private const string ApiBaseUrl = "https://discord.com/api/v8";

        private readonly Version _httpVersion = new(2, 0);
        private readonly JsonSerializerOptions _jsonSerializerOptions;
        private readonly ILogger _logger;
        private readonly Ratelimiter _ratelimiter;
        private readonly AsyncPolicy<HttpResponseMessage> _retryPolicy;

        /// <summary>Creates a client with its own <see cref="HttpClient"/> and ratelimiter.</summary>
        /// <param name="userAgent">Sent as the User-Agent header on every request.</param>
        /// <param name="token">Bot token; the "Bot " prefix is added if it is not already present.</param>
        /// <param name="logger">Parent logger, also handed to the <see cref="Ratelimiter"/>.</param>
        public BaseRestClient(string userAgent, string token, ILogger logger)
        {
            _logger = logger.ForContext<BaseRestClient>();

            // Ordinal: "Bot " is a protocol literal, not linguistic text.
            if (!token.StartsWith("Bot ", StringComparison.Ordinal))
                token = "Bot " + token;

            Client = new HttpClient();
            Client.DefaultRequestHeaders.TryAddWithoutValidation("User-Agent", userAgent);
            Client.DefaultRequestHeaders.TryAddWithoutValidation("Authorization", token);

            _jsonSerializerOptions = new JsonSerializerOptions().ConfigureForNewcord();

            _ratelimiter = new Ratelimiter(logger);
            var discordPolicy = new DiscordRateLimitPolicy(_ratelimiter);

            // Polly's default (optimistic) timeout only cancels the CancellationToken it passes to the executed
            // delegate. Send() forwards that token to HttpClient.SendAsync, which is what lets this timeout
            // actually abort a hung HTTP request.
            var timeoutPolicy = Policy.TimeoutAsync<HttpResponseMessage>(TimeSpan.FromSeconds(10));

            // Up to 3 retries when the local ratelimiter reports an exhausted bucket, each waiting the
            // exception's RetryAfter.
            var waitPolicy = Policy
                .Handle<RatelimitBucketExhaustedException>()
                .WaitAndRetryAsync(3,
                    (_, e, _) => ((RatelimitBucketExhaustedException) e).RetryAfter,
                    (_, _, _, _) => Task.CompletedTask)
                .AsAsyncPolicy<HttpResponseMessage>();

            // WrapAsync lists policies outermost first, so the timeout bounds the whole operation,
            // including any ratelimit waits performed by waitPolicy.
            _retryPolicy = Policy.WrapAsync(timeoutPolicy, waitPolicy, discordPolicy);
        }

        /// <summary>The underlying HTTP client, preconfigured with User-Agent and Authorization headers.</summary>
        public HttpClient Client { get; }

        /// <summary>Disposes the ratelimiter and the HTTP client. Completes synchronously.</summary>
        public ValueTask DisposeAsync()
        {
            _ratelimiter.Dispose();
            Client.Dispose();
            return default;
        }

        /// <summary>Sends a GET and deserializes the JSON response.</summary>
        /// <returns>The deserialized body, or <c>null</c> on 404 Not Found or 204 No Content.</returns>
        public async Task<T?> Get<T>(string path, (string endpointName, ulong major) ratelimitParams) where T: class
        {
            var request = new HttpRequestMessage(HttpMethod.Get, ApiBaseUrl + path);

            // Disposed here so an unread body (e.g. the 404 case) does not keep the connection busy;
            // responses are fetched with ResponseHeadersRead, so the body is still pending at this point.
            using var response = await Send(request, ratelimitParams, true);

            // GET-only special case: 404s are nulls and not exceptions
            if (response.StatusCode == HttpStatusCode.NotFound)
                return null;

            return await ReadResponse<T>(response);
        }

        /// <summary>Sends a POST with an optional JSON body and deserializes the JSON response.</summary>
        /// <returns>The deserialized body, or <c>null</c> on 204 No Content.</returns>
        public async Task<T?> Post<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
            where T: class
        {
            var request = new HttpRequestMessage(HttpMethod.Post, ApiBaseUrl + path);
            SetRequestJsonBody(request, body);

            using var response = await Send(request, ratelimitParams);
            return await ReadResponse<T>(response);
        }

        /// <summary>
        /// Sends a multipart/form-data POST: <paramref name="payload"/> as the <c>payload_json</c> part, plus one
        /// <c>file{i}</c> part per entry in <paramref name="files"/>.
        /// </summary>
        /// <remarks>
        /// The request (and therefore each file's StreamContent) is not disposed here; the streams remain
        /// owned by the caller.
        /// </remarks>
        /// <returns>The deserialized body, or <c>null</c> on 204 No Content.</returns>
        public async Task<T?> PostMultipart<T>(string path, (string endpointName, ulong major) ratelimitParams,
            object? payload, MultipartFile[]? files) where T: class
        {
            var request = new HttpRequestMessage(HttpMethod.Post, ApiBaseUrl + path);
            SetRequestFormDataBody(request, payload, files);

            using var response = await Send(request, ratelimitParams);
            return await ReadResponse<T>(response);
        }

        /// <summary>Sends a PATCH with an optional JSON body and deserializes the JSON response.</summary>
        /// <returns>The deserialized body, or <c>null</c> on 204 No Content.</returns>
        public async Task<T?> Patch<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
            where T: class
        {
            var request = new HttpRequestMessage(HttpMethod.Patch, ApiBaseUrl + path);
            SetRequestJsonBody(request, body);

            using var response = await Send(request, ratelimitParams);
            return await ReadResponse<T>(response);
        }

        /// <summary>Sends a PUT with an optional JSON body and deserializes the JSON response.</summary>
        /// <returns>The deserialized body, or <c>null</c> on 204 No Content.</returns>
        public async Task<T?> Put<T>(string path, (string endpointName, ulong major) ratelimitParams, object? body)
            where T: class
        {
            var request = new HttpRequestMessage(HttpMethod.Put, ApiBaseUrl + path);
            SetRequestJsonBody(request, body);

            using var response = await Send(request, ratelimitParams);
            return await ReadResponse<T>(response);
        }

        /// <summary>Sends a DELETE; the response body is discarded.</summary>
        public async Task Delete(string path, (string endpointName, ulong major) ratelimitParams)
        {
            var request = new HttpRequestMessage(HttpMethod.Delete, ApiBaseUrl + path);

            // Nothing is read from the response, so release it immediately.
            using var response = await Send(request, ratelimitParams);
        }

        /// <summary>Serializes <paramref name="body"/> as application/json; leaves the request bodiless when null.</summary>
        private void SetRequestJsonBody(HttpRequestMessage request, object? body)
        {
            if (body == null) return;

            request.Content = new ReadOnlyMemoryContent(JsonSerializer.SerializeToUtf8Bytes(body, _jsonSerializerOptions));
            request.Content.Headers.ContentType = new MediaTypeHeaderValue("application/json");
        }

        /// <summary>Builds a multipart body with a <c>payload_json</c> part and one <c>file{i}</c> part per file.</summary>
        private void SetRequestFormDataBody(HttpRequestMessage request, object? payload, MultipartFile[]? files)
        {
            var bodyJson = JsonSerializer.SerializeToUtf8Bytes(payload, _jsonSerializerOptions);

            var mfd = new MultipartFormDataContent();
            mfd.Add(new ByteArrayContent(bodyJson), "payload_json");

            if (files != null)
            {
                for (var i = 0; i < files.Length; i++)
                {
                    var (filename, stream) = files[i];
                    mfd.Add(new StreamContent(stream), $"file{i}", filename);
                }
            }

            request.Content = mfd;
        }

        /// <summary>Deserializes the JSON body, or returns <c>null</c> for 204 No Content.</summary>
        private async Task<T?> ReadResponse<T>(HttpResponseMessage response) where T: class
        {
            if (response.StatusCode == HttpStatusCode.NoContent)
                return null;
            return await response.Content.ReadFromJsonAsync<T>(_jsonSerializerOptions);
        }

        /// <summary>
        /// Sends <paramref name="request"/> through the timeout/ratelimit policy pipeline and throws a
        /// Discord exception for any non-success status (optionally tolerating 404).
        /// </summary>
        /// <param name="ratelimitParams">Endpoint name and major parameter, passed to the ratelimit policy via the Polly context.</param>
        /// <param name="ignoreNotFound">When true, a 404 is returned to the caller instead of throwing.</param>
        /// <returns>The response, with only headers guaranteed read (ResponseHeadersRead); the caller owns it.</returns>
        private async Task<HttpResponseMessage> Send(HttpRequestMessage request,
            (string endpointName, ulong major) ratelimitParams, bool ignoreNotFound = false)
        {
            var contextData = new Dictionary<string, object>
            {
                {DiscordRateLimitPolicy.EndpointContextKey, ratelimitParams.endpointName},
                {DiscordRateLimitPolicy.MajorContextKey, ratelimitParams.major}
            };

            return await _retryPolicy.ExecuteAsync(async (_, ct) =>
            {
                _logger.Debug("Sending request: {RequestMethod} {RequestPath}",
                    request.Method, request.RequestUri);

                request.Version = _httpVersion;
                request.VersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;

                var stopwatch = Stopwatch.StartNew();
                // ct is the token the timeout policy cancels; passing it here is what makes the timeout effective.
                var response = await Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ct);
                stopwatch.Stop();

                _logger.Debug(
                    "Received response in {ResponseDurationMs} ms: {RequestMethod} {RequestPath} -> {StatusCode} {ReasonPhrase}",
                    stopwatch.ElapsedMilliseconds, request.Method, request.RequestUri, (int) response.StatusCode,
                    response.ReasonPhrase);

                await HandleApiError(response, ignoreNotFound);

                return response;
            }, contextData, CancellationToken.None);
        }

        /// <summary>Returns for success (and for 404 when <paramref name="ignoreNotFound"/>); otherwise throws.</summary>
        private async ValueTask HandleApiError(HttpResponseMessage response, bool ignoreNotFound)
        {
            if (response.IsSuccessStatusCode)
                return;

            if (response.StatusCode == HttpStatusCode.NotFound && ignoreNotFound)
                return;

            throw await CreateDiscordException(response);
        }

        /// <summary>Reads the error body and maps the status code to the matching Discord exception type.</summary>
        private async ValueTask<Exception> CreateDiscordException(HttpResponseMessage response)
        {
            var body = await response.Content.ReadAsStringAsync();
            var apiError = TryParseApiError(body);

            return response.StatusCode switch
            {
                HttpStatusCode.BadRequest => new BadRequestException(response, body, apiError),
                HttpStatusCode.Forbidden => new ForbiddenException(response, body, apiError),
                HttpStatusCode.Unauthorized => new UnauthorizedException(response, body, apiError),
                HttpStatusCode.NotFound => new NotFoundException(response, body, apiError),
                HttpStatusCode.Conflict => new ConflictException(response, body, apiError),
                HttpStatusCode.TooManyRequests => new TooManyRequestsException(response, body, apiError),
                _ => new UnknownDiscordRequestException(response, body, apiError)
            };
        }

        /// <summary>Best-effort parse of a Discord error body; returns null when empty or not valid JSON.</summary>
        private DiscordApiError? TryParseApiError(string responseBody)
        {
            if (string.IsNullOrWhiteSpace(responseBody))
                return null;

            try
            {
                return JsonSerializer.Deserialize<DiscordApiError>(responseBody, _jsonSerializerOptions);
            }
            catch (JsonException e)
            {
                // Deliberately non-fatal: the raw body is still carried by the exception.
                _logger.Verbose(e, "Error deserializing API error");
            }

            return null;
        }
    }
}