Add support for Twilight gateway queue

This commit is contained in:
Ske 2021-06-09 16:22:10 +02:00
parent 333530d24d
commit 26dc69e5a4
8 changed files with 70 additions and 15 deletions

View File

@ -4,6 +4,7 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Myriad.Gateway.Limit;
using Myriad.Types; using Myriad.Types;
using Serilog; using Serilog;
@ -15,7 +16,7 @@ namespace Myriad.Gateway
private readonly GatewaySettings _gatewaySettings; private readonly GatewaySettings _gatewaySettings;
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly ConcurrentDictionary<int, Shard> _shards = new(); private readonly ConcurrentDictionary<int, Shard> _shards = new();
private ShardIdentifyRatelimiter? _ratelimiter; private IGatewayRatelimiter? _ratelimiter;
public Cluster(GatewaySettings gatewaySettings, ILogger logger) public Cluster(GatewaySettings gatewaySettings, ILogger logger)
{ {
@ -35,10 +36,9 @@ namespace Myriad.Gateway
await Start(info.Url, 0, info.Shards - 1, info.Shards, info.SessionStartLimit.MaxConcurrency); await Start(info.Url, 0, info.Shards - 1, info.Shards, info.SessionStartLimit.MaxConcurrency);
} }
public async Task Start(string url, int shardMin, int shardMax, int shardTotal, int concurrency) public async Task Start(string url, int shardMin, int shardMax, int shardTotal, int recommendedConcurrency)
{ {
concurrency = GetActualShardConcurrency(concurrency); _ratelimiter = GetRateLimiter(recommendedConcurrency);
_ratelimiter = new(_logger, concurrency);
var shardCount = shardMax - shardMin + 1; var shardCount = shardMax - shardMin + 1;
_logger.Information("Starting {ShardCount} of {ShardTotal} shards (#{ShardMin}-#{ShardMax}) at {Url}", _logger.Information("Starting {ShardCount} of {ShardTotal} shards (#{ShardMin}-#{ShardMax}) at {Url}",
@ -77,5 +77,16 @@ namespace Myriad.Gateway
return Math.Min(_gatewaySettings.MaxShardConcurrency.Value, recommendedConcurrency); return Math.Min(_gatewaySettings.MaxShardConcurrency.Value, recommendedConcurrency);
} }
private IGatewayRatelimiter GetRateLimiter(int recommendedConcurrency)
{
if (_gatewaySettings.GatewayQueueUrl != null)
{
return new TwilightGatewayRatelimiter(_logger, _gatewaySettings.GatewayQueueUrl);
}
var concurrency = GetActualShardConcurrency(recommendedConcurrency);
return new LocalGatewayRatelimiter(_logger, concurrency);
}
} }
} }

View File

@ -5,5 +5,6 @@
public string Token { get; init; } public string Token { get; init; }
public GatewayIntent Intents { get; init; } public GatewayIntent Intents { get; init; }
public int? MaxShardConcurrency { get; init; } public int? MaxShardConcurrency { get; init; }
public string? GatewayQueueUrl { get; init; }
} }
} }

View File

@ -0,0 +1,9 @@
using System.Threading.Tasks;
namespace Myriad.Gateway.Limit
{
public interface IGatewayRatelimiter
{
public Task Identify(int shard);
}
}

View File

@ -4,9 +4,9 @@ using System.Threading.Tasks;
using Serilog; using Serilog;
namespace Myriad.Gateway namespace Myriad.Gateway.Limit
{ {
public class ShardIdentifyRatelimiter public class LocalGatewayRatelimiter: IGatewayRatelimiter
{ {
// docs specify 5 seconds, but we're actually throttling connections, not identify, so we need a bit of leeway // docs specify 5 seconds, but we're actually throttling connections, not identify, so we need a bit of leeway
private static readonly TimeSpan BucketLength = TimeSpan.FromSeconds(6); private static readonly TimeSpan BucketLength = TimeSpan.FromSeconds(6);
@ -17,13 +17,13 @@ namespace Myriad.Gateway
private Task? _refillTask; private Task? _refillTask;
private readonly ILogger _logger; private readonly ILogger _logger;
public ShardIdentifyRatelimiter(ILogger logger, int maxConcurrency) public LocalGatewayRatelimiter(ILogger logger, int maxConcurrency)
{ {
_logger = logger.ForContext<ShardIdentifyRatelimiter>(); _logger = logger.ForContext<LocalGatewayRatelimiter>();
_maxConcurrency = maxConcurrency; _maxConcurrency = maxConcurrency;
} }
public Task Acquire(int shard) public Task Identify(int shard)
{ {
var bucket = shard % _maxConcurrency; var bucket = shard % _maxConcurrency;
var queue = _buckets.GetOrAdd(bucket, _ => new ConcurrentQueue<TaskCompletionSource>()); var queue = _buckets.GetOrAdd(bucket, _ => new ConcurrentQueue<TaskCompletionSource>());

View File

@ -0,0 +1,27 @@
using System.Net.Http;
using System.Threading.Tasks;
using Serilog;
namespace Myriad.Gateway.Limit
{
public class TwilightGatewayRatelimiter: IGatewayRatelimiter
{
private readonly string _url;
private readonly ILogger _logger;
private readonly HttpClient _httpClient = new();
public TwilightGatewayRatelimiter(ILogger logger, string url)
{
_url = url;
_logger = logger;
}
public async Task Identify(int shard)
{
// Literally just request and wait :p
_logger.Information("Shard {ShardId}: Requesting identify at gateway queue {GatewayQueueUrl}", shard, _url);
await _httpClient.GetAsync(_url);
}
}
}

View File

@ -3,6 +3,7 @@ using System.Net.WebSockets;
using System.Text.Json; using System.Text.Json;
using System.Threading.Tasks; using System.Threading.Tasks;
using Myriad.Gateway.Limit;
using Myriad.Gateway.State; using Myriad.Gateway.State;
using Myriad.Serialization; using Myriad.Serialization;
using Myriad.Types; using Myriad.Types;
@ -17,7 +18,7 @@ namespace Myriad.Gateway
private readonly GatewaySettings _settings; private readonly GatewaySettings _settings;
private readonly ShardInfo _info; private readonly ShardInfo _info;
private readonly ShardIdentifyRatelimiter _ratelimiter; private readonly IGatewayRatelimiter _ratelimiter;
private readonly string _url; private readonly string _url;
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly ShardStateManager _stateManager; private readonly ShardStateManager _stateManager;
@ -41,7 +42,7 @@ namespace Myriad.Gateway
private TimeSpan _reconnectDelay = TimeSpan.Zero; private TimeSpan _reconnectDelay = TimeSpan.Zero;
private Task? _worker; private Task? _worker;
public Shard(GatewaySettings settings, ShardInfo info, ShardIdentifyRatelimiter ratelimiter, string url, ILogger logger) public Shard(GatewaySettings settings, ShardInfo info, IGatewayRatelimiter ratelimiter, string url, ILogger logger)
{ {
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad(); _jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
@ -105,11 +106,14 @@ namespace Myriad.Gateway
} }
} }
public Task Start() public async Task Start()
{ {
if (_worker == null) if (_worker == null)
_worker = ShardLoop(); _worker = ShardLoop();
return Task.CompletedTask;
// we can probably TCS this instead of spin loop but w/e
while (State != ShardState.Connected)
await Task.Delay(100);
} }
public async Task UpdateStatus(GatewayStatusUpdate payload) public async Task UpdateStatus(GatewayStatusUpdate payload)
@ -125,7 +129,7 @@ namespace Myriad.Gateway
{ {
while (true) while (true)
{ {
await _ratelimiter.Acquire(_info.ShardId); await _ratelimiter.Identify(_info.ShardId);
_logger.Information("Shard {ShardId}: Connecting to WebSocket", _info.ShardId); _logger.Information("Shard {ShardId}: Connecting to WebSocket", _info.ShardId);
try try

View File

@ -18,6 +18,8 @@ namespace PluralKit.Bot
public ClusterSettings? Cluster { get; set; } public ClusterSettings? Cluster { get; set; }
public string? GatewayQueueUrl { get; set; }
public record ClusterSettings public record ClusterSettings
{ {
public string NodeName { get; set; } public string NodeName { get; set; }

View File

@ -28,6 +28,7 @@ namespace PluralKit.Bot
{ {
Token = botConfig.Token, Token = botConfig.Token,
MaxShardConcurrency = botConfig.MaxShardConcurrency, MaxShardConcurrency = botConfig.MaxShardConcurrency,
GatewayQueueUrl = botConfig.GatewayQueueUrl,
Intents = GatewayIntent.Guilds | Intents = GatewayIntent.Guilds |
GatewayIntent.DirectMessages | GatewayIntent.DirectMessages |
GatewayIntent.DirectMessageReactions | GatewayIntent.DirectMessageReactions |