WIP new shard implementation
This commit is contained in:
@@ -3,6 +3,7 @@ using System.Net.WebSockets;
|
||||
using System.Text.Json;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Myriad.Gateway.State;
|
||||
using Myriad.Serialization;
|
||||
using Myriad.Types;
|
||||
|
||||
@@ -10,340 +11,192 @@ using Serilog;
|
||||
|
||||
namespace Myriad.Gateway
|
||||
{
|
||||
public class Shard: IAsyncDisposable
|
||||
public class Shard
|
||||
{
|
||||
private const string LibraryName = "Myriad (for PluralKit)";
|
||||
|
||||
private readonly JsonSerializerOptions _jsonSerializerOptions =
|
||||
new JsonSerializerOptions().ConfigureForMyriad();
|
||||
|
||||
private readonly GatewaySettings _settings;
|
||||
private readonly ShardInfo _info;
|
||||
private readonly ShardIdentifyRatelimiter _ratelimiter;
|
||||
private readonly string _url;
|
||||
private readonly ILogger _logger;
|
||||
private readonly Uri _uri;
|
||||
|
||||
private ShardConnection? _conn;
|
||||
private TimeSpan? _currentHeartbeatInterval;
|
||||
private bool _hasReceivedAck;
|
||||
private DateTimeOffset? _lastHeartbeatSent;
|
||||
private Task _worker;
|
||||
|
||||
public ShardInfo ShardInfo { get; private set; }
|
||||
public int ShardId => ShardInfo.ShardId;
|
||||
public GatewaySettings Settings { get; }
|
||||
public ShardSessionInfo SessionInfo { get; private set; }
|
||||
public ShardState State { get; private set; }
|
||||
public TimeSpan? Latency { get; private set; }
|
||||
public User? User { get; private set; }
|
||||
public ApplicationPartial? Application { get; private set; }
|
||||
private readonly ShardStateManager _stateManager;
|
||||
private readonly JsonSerializerOptions _jsonSerializerOptions;
|
||||
private readonly ShardConnection _conn;
|
||||
|
||||
public Func<IGatewayEvent, Task>? OnEventReceived { get; set; }
|
||||
public int ShardId => _info.ShardId;
|
||||
public ShardState State => _stateManager.State;
|
||||
public TimeSpan? Latency => _stateManager.Latency;
|
||||
public User? User => _stateManager.User;
|
||||
public ApplicationPartial? Application => _stateManager.Application;
|
||||
|
||||
// TODO: I wanna get rid of these or move them at some point
|
||||
public event Func<IGatewayEvent, Task>? OnEventReceived;
|
||||
public event Action<TimeSpan>? HeartbeatReceived;
|
||||
public event Action? SocketOpened;
|
||||
public event Action? Resumed;
|
||||
public event Action? Ready;
|
||||
public event Action<WebSocketCloseStatus, string?>? SocketClosed;
|
||||
public event Action<WebSocketCloseStatus?, string?>? SocketClosed;
|
||||
|
||||
private TimeSpan _reconnectDelay = TimeSpan.Zero;
|
||||
private Task? _worker;
|
||||
|
||||
public Shard(GatewaySettings settings, ShardInfo info, ShardIdentifyRatelimiter ratelimiter, string url, ILogger logger)
|
||||
{
|
||||
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
|
||||
|
||||
_settings = settings;
|
||||
_info = info;
|
||||
_ratelimiter = ratelimiter;
|
||||
_url = url;
|
||||
_logger = logger;
|
||||
_stateManager = new ShardStateManager(info, _jsonSerializerOptions, logger)
|
||||
{
|
||||
HandleEvent = HandleEvent,
|
||||
SendHeartbeat = SendHeartbeat,
|
||||
SendIdentify = SendIdentify,
|
||||
SendResume = SendResume,
|
||||
Connect = ConnectInner,
|
||||
Reconnect = Reconnect,
|
||||
};
|
||||
_stateManager.OnHeartbeatReceived += latency =>
|
||||
{
|
||||
HeartbeatReceived?.Invoke(latency);
|
||||
};
|
||||
|
||||
_conn = new ShardConnection(_jsonSerializerOptions, _logger);
|
||||
}
|
||||
|
||||
public Shard(ILogger logger, Uri uri, GatewaySettings settings, ShardInfo info,
|
||||
ShardSessionInfo? sessionInfo = null)
|
||||
private async Task ShardLoop()
|
||||
{
|
||||
_logger = logger.ForContext<Shard>();
|
||||
_uri = uri;
|
||||
while (true)
|
||||
{
|
||||
try
|
||||
{
|
||||
await ConnectInner();
|
||||
await HandleConnectionOpened();
|
||||
|
||||
Settings = settings;
|
||||
ShardInfo = info;
|
||||
SessionInfo = sessionInfo ?? new ShardSessionInfo();
|
||||
while (_conn.State == WebSocketState.Open)
|
||||
{
|
||||
var packet = await _conn.Read();
|
||||
if (packet == null)
|
||||
break;
|
||||
|
||||
await _stateManager.HandlePacketReceived(packet);
|
||||
}
|
||||
|
||||
await HandleConnectionClosed(_conn.CloseStatus, _conn.CloseStatusDescription);
|
||||
|
||||
_logger.Information("Shard {ShardId}: Reconnecting after delay {ReconnectDelay}",
|
||||
_info.ShardId, _reconnectDelay);
|
||||
|
||||
if (_reconnectDelay > TimeSpan.Zero)
|
||||
await Task.Delay(_reconnectDelay);
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
_logger.Error(e, "Shard {ShardId}: Error in main shard loop, reconnecting in 5 seconds...", _info.ShardId);
|
||||
|
||||
// todo: exponential backoff here? this should never happen, ideally...
|
||||
await Task.Delay(TimeSpan.FromSeconds(5));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
if (_conn != null)
|
||||
await _conn.DisposeAsync();
|
||||
}
|
||||
|
||||
|
||||
public Task Start()
|
||||
{
|
||||
_worker = MainLoop();
|
||||
if (_worker == null)
|
||||
_worker = ShardLoop();
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public async Task UpdateStatus(GatewayStatusUpdate payload)
|
||||
{
|
||||
if (_conn != null && _conn.State == WebSocketState.Open)
|
||||
await _conn!.Send(new GatewayPacket {Opcode = GatewayOpcode.PresenceUpdate, Payload = payload});
|
||||
}
|
||||
|
||||
private async Task MainLoop()
|
||||
{
|
||||
while (true)
|
||||
try
|
||||
{
|
||||
_logger.Information("Shard {ShardId}: Connecting...", ShardId);
|
||||
|
||||
State = ShardState.Connecting;
|
||||
await Connect();
|
||||
|
||||
_logger.Information("Shard {ShardId}: Connected. Entering main loop...", ShardId);
|
||||
|
||||
// Tick returns false if we need to stop and reconnect
|
||||
while (await Tick(_conn!))
|
||||
await Task.Delay(TimeSpan.FromMilliseconds(1000));
|
||||
|
||||
_logger.Information("Shard {ShardId}: Connection closed, reconnecting...", ShardId);
|
||||
State = ShardState.Closed;
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
_logger.Error(e, "Shard {ShardId}: Error in shard state handler", ShardId);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<bool> Tick(ShardConnection conn)
|
||||
{
|
||||
if (conn.State != WebSocketState.Connecting && conn.State != WebSocketState.Open)
|
||||
return false;
|
||||
|
||||
if (!await TickHeartbeat(conn))
|
||||
// TickHeartbeat returns false if we're disconnecting
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private async Task<bool> TickHeartbeat(ShardConnection conn)
|
||||
{
|
||||
// If we don't need to heartbeat, do nothing
|
||||
if (_lastHeartbeatSent == null || _currentHeartbeatInterval == null)
|
||||
return true;
|
||||
|
||||
if (DateTimeOffset.UtcNow - _lastHeartbeatSent < _currentHeartbeatInterval)
|
||||
return true;
|
||||
|
||||
// If we haven't received the ack in time, close w/ error
|
||||
if (!_hasReceivedAck)
|
||||
await _conn.Send(new GatewayPacket
|
||||
{
|
||||
_logger.Warning(
|
||||
"Shard {ShardId}: Did not receive heartbeat Ack from gateway within interval ({HeartbeatInterval})",
|
||||
ShardId, _currentHeartbeatInterval);
|
||||
State = ShardState.Closing;
|
||||
await conn.Disconnect(WebSocketCloseStatus.ProtocolError, "Did not receive ACK in time");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Otherwise just send it :)
|
||||
await SendHeartbeat(conn);
|
||||
_hasReceivedAck = false;
|
||||
return true;
|
||||
Opcode = GatewayOpcode.PresenceUpdate,
|
||||
Payload = payload
|
||||
});
|
||||
}
|
||||
|
||||
private async Task SendHeartbeat(ShardConnection conn)
|
||||
|
||||
private async Task ConnectInner()
|
||||
{
|
||||
_logger.Debug("Shard {ShardId}: Sending heartbeat with seq.no. {LastSequence}",
|
||||
ShardId, SessionInfo.LastSequence);
|
||||
await _ratelimiter.Acquire(_info.ShardId);
|
||||
|
||||
await conn.Send(new GatewayPacket {Opcode = GatewayOpcode.Heartbeat, Payload = SessionInfo.LastSequence});
|
||||
_lastHeartbeatSent = DateTimeOffset.UtcNow;
|
||||
_logger.Information("Shard {ShardId}: Connecting to WebSocket", _info.ShardId);
|
||||
await _conn.Connect(_url, default);
|
||||
}
|
||||
|
||||
private async Task Connect()
|
||||
|
||||
private async Task DisconnectInner(WebSocketCloseStatus closeStatus)
|
||||
{
|
||||
if (_conn != null)
|
||||
await _conn.DisposeAsync();
|
||||
|
||||
_currentHeartbeatInterval = null;
|
||||
|
||||
_conn = new ShardConnection(_uri, _logger, _jsonSerializerOptions)
|
||||
{
|
||||
OnReceive = OnReceive,
|
||||
OnOpen = () => SocketOpened?.Invoke(),
|
||||
OnClose = (closeStatus, message) => SocketClosed?.Invoke(closeStatus, message)
|
||||
};
|
||||
await _conn.Disconnect(closeStatus, null);
|
||||
}
|
||||
|
||||
private async Task OnReceive(GatewayPacket packet)
|
||||
{
|
||||
switch (packet.Opcode)
|
||||
{
|
||||
case GatewayOpcode.Hello:
|
||||
{
|
||||
await HandleHello((JsonElement) packet.Payload!);
|
||||
break;
|
||||
}
|
||||
case GatewayOpcode.Heartbeat:
|
||||
{
|
||||
_logger.Debug("Shard {ShardId}: Received heartbeat request from shard, sending Ack", ShardId);
|
||||
await _conn!.Send(new GatewayPacket {Opcode = GatewayOpcode.HeartbeatAck});
|
||||
break;
|
||||
}
|
||||
case GatewayOpcode.HeartbeatAck:
|
||||
{
|
||||
Latency = DateTimeOffset.UtcNow - _lastHeartbeatSent;
|
||||
_logger.Debug("Shard {ShardId}: Received heartbeat Ack with latency {Latency}", ShardId, Latency);
|
||||
if (Latency != null)
|
||||
HeartbeatReceived?.Invoke(Latency!.Value);
|
||||
|
||||
_hasReceivedAck = true;
|
||||
break;
|
||||
}
|
||||
case GatewayOpcode.Reconnect:
|
||||
{
|
||||
_logger.Information("Shard {ShardId}: Received Reconnect, closing and reconnecting", ShardId);
|
||||
await _conn!.Disconnect(WebSocketCloseStatus.Empty, null);
|
||||
break;
|
||||
}
|
||||
case GatewayOpcode.InvalidSession:
|
||||
{
|
||||
var canResume = ((JsonElement) packet.Payload!).GetBoolean();
|
||||
|
||||
// Clear session info before DCing
|
||||
if (!canResume)
|
||||
SessionInfo = SessionInfo with { Session = null };
|
||||
|
||||
var delay = TimeSpan.FromMilliseconds(new Random().Next(1000, 5000));
|
||||
|
||||
_logger.Information(
|
||||
"Shard {ShardId}: Received Invalid Session (can resume? {CanResume}), reconnecting after {ReconnectDelay}",
|
||||
ShardId, canResume, delay);
|
||||
await _conn!.Disconnect(WebSocketCloseStatus.Empty, null);
|
||||
|
||||
// Will reconnect after exiting this "loop"
|
||||
await Task.Delay(delay);
|
||||
break;
|
||||
}
|
||||
case GatewayOpcode.Dispatch:
|
||||
{
|
||||
SessionInfo = SessionInfo with { LastSequence = packet.Sequence };
|
||||
var evt = DeserializeEvent(packet.EventType!, (JsonElement) packet.Payload!)!;
|
||||
|
||||
if (evt is ReadyEvent rdy)
|
||||
{
|
||||
if (State == ShardState.Connecting)
|
||||
await HandleReady(rdy);
|
||||
else
|
||||
_logger.Warning("Shard {ShardId}: Received Ready event in unexpected state {ShardState}, ignoring?",
|
||||
ShardId, State);
|
||||
}
|
||||
else if (evt is ResumedEvent)
|
||||
{
|
||||
if (State == ShardState.Connecting)
|
||||
await HandleResumed();
|
||||
else
|
||||
_logger.Warning("Shard {ShardId}: Received Resumed event in unexpected state {ShardState}, ignoring?",
|
||||
ShardId, State);
|
||||
}
|
||||
|
||||
await HandleEvent(evt);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
_logger.Debug("Shard {ShardId}: Received unknown gateway opcode {Opcode}", ShardId, packet.Opcode);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async Task HandleEvent(IGatewayEvent evt)
|
||||
{
|
||||
if (OnEventReceived != null)
|
||||
await OnEventReceived.Invoke(evt);
|
||||
}
|
||||
|
||||
|
||||
private IGatewayEvent? DeserializeEvent(string eventType, JsonElement data)
|
||||
{
|
||||
if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType))
|
||||
{
|
||||
_logger.Information("Shard {ShardId}: Received unknown event type {EventType}", ShardId, eventType);
|
||||
return null;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
_logger.Verbose("Shard {ShardId}: Deserializing {EventType} to {ClrType}", ShardId, eventType, clrType);
|
||||
return JsonSerializer.Deserialize(data.GetRawText(), clrType, _jsonSerializerOptions)
|
||||
as IGatewayEvent;
|
||||
}
|
||||
catch (JsonException e)
|
||||
{
|
||||
_logger.Error(e, "Shard {ShardId}: Error deserializing event {EventType} to {ClrType}", ShardId, eventType, clrType);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Task HandleReady(ReadyEvent ready)
|
||||
{
|
||||
// TODO: when is ready.Shard ever null?
|
||||
ShardInfo = ready.Shard ?? new ShardInfo(0, 0);
|
||||
SessionInfo = SessionInfo with { Session = ready.SessionId };
|
||||
User = ready.User;
|
||||
Application = ready.Application;
|
||||
State = ShardState.Open;
|
||||
|
||||
Ready?.Invoke();
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private Task HandleResumed()
|
||||
{
|
||||
State = ShardState.Open;
|
||||
Resumed?.Invoke();
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private async Task HandleHello(JsonElement json)
|
||||
{
|
||||
var hello = JsonSerializer.Deserialize<GatewayHello>(json.GetRawText(), _jsonSerializerOptions)!;
|
||||
_logger.Debug("Shard {ShardId}: Received Hello with interval {Interval} ms", ShardId, hello.HeartbeatInterval);
|
||||
_currentHeartbeatInterval = TimeSpan.FromMilliseconds(hello.HeartbeatInterval);
|
||||
|
||||
await SendHeartbeat(_conn!);
|
||||
|
||||
await SendIdentifyOrResume();
|
||||
}
|
||||
|
||||
private async Task SendIdentifyOrResume()
|
||||
{
|
||||
if (SessionInfo.Session != null && SessionInfo.LastSequence != null)
|
||||
await SendResume(SessionInfo.Session, SessionInfo.LastSequence!.Value);
|
||||
else
|
||||
await SendIdentify();
|
||||
}
|
||||
|
||||
|
||||
private async Task SendIdentify()
|
||||
{
|
||||
_logger.Information("Shard {ShardId}: Sending gateway Identify for shard {@ShardInfo}", ShardId, ShardInfo);
|
||||
await _conn!.Send(new GatewayPacket
|
||||
await _conn.Send(new GatewayPacket
|
||||
{
|
||||
Opcode = GatewayOpcode.Identify,
|
||||
Payload = new GatewayIdentify
|
||||
{
|
||||
Token = Settings.Token,
|
||||
Compress = false,
|
||||
Intents = _settings.Intents,
|
||||
Properties = new GatewayIdentify.ConnectionProperties
|
||||
{
|
||||
Browser = LibraryName, Device = LibraryName, Os = Environment.OSVersion.ToString()
|
||||
Browser = LibraryName,
|
||||
Device = LibraryName,
|
||||
Os = Environment.OSVersion.ToString()
|
||||
},
|
||||
Intents = Settings.Intents,
|
||||
Shard = ShardInfo
|
||||
Shard = _info,
|
||||
Token = _settings.Token,
|
||||
LargeThreshold = 50
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private async Task SendResume(string session, int lastSequence)
|
||||
|
||||
private async Task SendResume((string SessionId, int? LastSeq) arg)
|
||||
{
|
||||
_logger.Information("Shard {ShardId}: Sending gateway Resume for session {@SessionInfo}",
|
||||
ShardId, SessionInfo);
|
||||
await _conn!.Send(new GatewayPacket
|
||||
await _conn.Send(new GatewayPacket
|
||||
{
|
||||
Opcode = GatewayOpcode.Resume,
|
||||
Payload = new GatewayResume(Settings.Token, session, lastSequence)
|
||||
Opcode = GatewayOpcode.Resume,
|
||||
Payload = new GatewayResume(_settings.Token, arg.SessionId, arg.LastSeq ?? 0)
|
||||
});
|
||||
}
|
||||
|
||||
public enum ShardState
|
||||
private async Task SendHeartbeat(int? lastSeq)
|
||||
{
|
||||
Closed,
|
||||
Connecting,
|
||||
Open,
|
||||
Closing
|
||||
await _conn.Send(new GatewayPacket {Opcode = GatewayOpcode.Heartbeat, Payload = lastSeq});
|
||||
}
|
||||
|
||||
private async Task Reconnect(WebSocketCloseStatus closeStatus, TimeSpan delay)
|
||||
{
|
||||
_reconnectDelay = delay;
|
||||
await DisconnectInner(closeStatus);
|
||||
}
|
||||
|
||||
private async Task HandleEvent(IGatewayEvent arg)
|
||||
{
|
||||
if (arg is ReadyEvent)
|
||||
Ready?.Invoke();
|
||||
if (arg is ResumedEvent)
|
||||
Resumed?.Invoke();
|
||||
|
||||
await (OnEventReceived?.Invoke(arg) ?? Task.CompletedTask);
|
||||
}
|
||||
|
||||
private async Task HandleConnectionOpened()
|
||||
{
|
||||
_logger.Information("Shard {ShardId}: Connection opened", _info.ShardId);
|
||||
await _stateManager.HandleConnectionOpened();
|
||||
SocketOpened?.Invoke();
|
||||
}
|
||||
|
||||
private async Task HandleConnectionClosed(WebSocketCloseStatus? closeStatus, string? description)
|
||||
{
|
||||
_logger.Information("Shard {ShardId}: Connection closed ({CloseStatus}/{Description})",
|
||||
_info.ShardId, closeStatus, description ?? "<null>");
|
||||
await _stateManager.HandleConnectionClosed();
|
||||
SocketClosed?.Invoke(closeStatus, description);
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user