diff --git a/Myriad/Extensions/MessageExtensions.cs b/Myriad/Extensions/MessageExtensions.cs index 56664154..60adb532 100644 --- a/Myriad/Extensions/MessageExtensions.cs +++ b/Myriad/Extensions/MessageExtensions.cs @@ -1,6 +1,4 @@ -using System; - -using Myriad.Gateway; +using Myriad.Gateway; using Myriad.Types; namespace Myriad.Extensions diff --git a/Myriad/Extensions/PermissionExtensions.cs b/Myriad/Extensions/PermissionExtensions.cs index d78288a3..6eafbbfa 100644 --- a/Myriad/Extensions/PermissionExtensions.cs +++ b/Myriad/Extensions/PermissionExtensions.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Threading.Tasks; using Myriad.Cache; using Myriad.Gateway; diff --git a/Myriad/Gateway/Cluster.cs b/Myriad/Gateway/Cluster.cs index bc5805fa..3f6a1134 100644 --- a/Myriad/Gateway/Cluster.cs +++ b/Myriad/Gateway/Cluster.cs @@ -15,6 +15,7 @@ namespace Myriad.Gateway private readonly GatewaySettings _gatewaySettings; private readonly ILogger _logger; private readonly ConcurrentDictionary _shards = new(); + private ShardIdentifyRatelimiter? _ratelimiter; public Cluster(GatewaySettings gatewaySettings, ILogger logger) { @@ -26,81 +27,35 @@ namespace Myriad.Gateway public event Action? ShardCreated; public IReadOnlyDictionary Shards => _shards; - public ClusterSessionState SessionState => GetClusterState(); public User? User => _shards.Values.Select(s => s.User).FirstOrDefault(s => s != null); public ApplicationPartial? Application => _shards.Values.Select(s => s.Application).FirstOrDefault(s => s != null); - - private ClusterSessionState GetClusterState() + + public async Task Start(GatewayInfo.Bot info) { - var shards = new List(); - foreach (var (id, shard) in _shards) - shards.Add(new ClusterSessionState.ShardState - { - Shard = shard.ShardInfo, - Session = shard.SessionInfo - }); - - return new ClusterSessionState {Shards = shards}; + var concurrency = GetActualShardConcurrency(info.SessionStartLimit.MaxConcurrency); + _ratelimiter = new(_logger, concurrency); + + await Start(info.Url, info.Shards); } - public async Task Start(GatewayInfo.Bot info, ClusterSessionState? lastState = null) - { - if (lastState != null && lastState.Shards.Count == info.Shards) - await Resume(info.Url, lastState, info.SessionStartLimit.MaxConcurrency); - else - await Start(info.Url, info.Shards, info.SessionStartLimit.MaxConcurrency); - } - - public async Task Resume(string url, ClusterSessionState sessionState, int concurrency) - { - _logger.Information("Resuming session with {ShardCount} shards at {Url}", sessionState.Shards.Count, url); - foreach (var shardState in sessionState.Shards) - CreateAndAddShard(url, shardState.Shard, shardState.Session); - - await StartShards(concurrency); - } - - public async Task Start(string url, int shardCount, int concurrency) + public async Task Start(string url, int shardCount) { _logger.Information("Starting {ShardCount} shards at {Url}", shardCount, url); for (var i = 0; i < shardCount; i++) - CreateAndAddShard(url, new ShardInfo(i, shardCount), null); + CreateAndAddShard(url, new ShardInfo(i, shardCount)); - await StartShards(concurrency); + await StartShards(); } - private async Task StartShards(int concurrency) + private async Task StartShards() { - concurrency = GetActualShardConcurrency(concurrency); - - var lastTime = DateTimeOffset.UtcNow; - var identifyCalls = 0; - _logger.Information("Connecting shards..."); - foreach (var shard in _shards.Values) - { - if (identifyCalls >= concurrency) - { - var timeout = lastTime + TimeSpan.FromSeconds(5.5); - var delay = timeout - DateTimeOffset.UtcNow; - - if (delay > TimeSpan.Zero) - { - _logger.Information("Hit shard concurrency limit, waiting {Delay}", delay); - await Task.Delay(delay); - } - - identifyCalls = 0; - lastTime = DateTimeOffset.UtcNow; - } - + foreach (var shard in _shards.Values) await shard.Start(); - identifyCalls++; - } } - private void CreateAndAddShard(string url, ShardInfo shardInfo, ShardSessionInfo? session) + private void CreateAndAddShard(string url, ShardInfo shardInfo) { - var shard = new Shard(_logger, new Uri(url), _gatewaySettings, shardInfo, session); + var shard = new Shard(_gatewaySettings, shardInfo, _ratelimiter!, url, _logger); shard.OnEventReceived += evt => OnShardEventReceived(shard, evt); _shards[shardInfo.ShardId] = shard; diff --git a/Myriad/Gateway/ClusterSessionState.cs b/Myriad/Gateway/ClusterSessionState.cs deleted file mode 100644 index aafb14be..00000000 --- a/Myriad/Gateway/ClusterSessionState.cs +++ /dev/null @@ -1,15 +0,0 @@ -using System.Collections.Generic; - -namespace Myriad.Gateway -{ - public record ClusterSessionState - { - public List Shards { get; init; } - - public record ShardState - { - public ShardInfo Shard { get; init; } - public ShardSessionInfo Session { get; init; } - } - } -} \ No newline at end of file diff --git a/Myriad/Gateway/Events/GuildCreateEvent.cs b/Myriad/Gateway/Events/GuildCreateEvent.cs index acfc9132..41f6220c 100644 --- a/Myriad/Gateway/Events/GuildCreateEvent.cs +++ b/Myriad/Gateway/Events/GuildCreateEvent.cs @@ -1,6 +1,4 @@ -using System.Collections.Generic; - -using Myriad.Types; +using Myriad.Types; namespace Myriad.Gateway { diff --git a/Myriad/Gateway/Shard.cs b/Myriad/Gateway/Shard.cs index 25cbba81..b0f91158 100644 --- a/Myriad/Gateway/Shard.cs +++ b/Myriad/Gateway/Shard.cs @@ -3,6 +3,7 @@ using System.Net.WebSockets; using System.Text.Json; using System.Threading.Tasks; +using Myriad.Gateway.State; using Myriad.Serialization; using Myriad.Types; @@ -10,340 +11,192 @@ using Serilog; namespace Myriad.Gateway { - public class Shard: IAsyncDisposable + public class Shard { private const string LibraryName = "Myriad (for PluralKit)"; - private readonly JsonSerializerOptions _jsonSerializerOptions = - new JsonSerializerOptions().ConfigureForMyriad(); - + private readonly GatewaySettings _settings; + private readonly ShardInfo _info; + private readonly ShardIdentifyRatelimiter _ratelimiter; + private readonly string _url; private readonly ILogger _logger; - private readonly Uri _uri; - - private ShardConnection? _conn; - private TimeSpan? _currentHeartbeatInterval; - private bool _hasReceivedAck; - private DateTimeOffset? _lastHeartbeatSent; - private Task _worker; - - public ShardInfo ShardInfo { get; private set; } - public int ShardId => ShardInfo.ShardId; - public GatewaySettings Settings { get; } - public ShardSessionInfo SessionInfo { get; private set; } - public ShardState State { get; private set; } - public TimeSpan? Latency { get; private set; } - public User? User { get; private set; } - public ApplicationPartial? Application { get; private set; } + private readonly ShardStateManager _stateManager; + private readonly JsonSerializerOptions _jsonSerializerOptions; + private readonly ShardConnection _conn; - public Func? OnEventReceived { get; set; } + public int ShardId => _info.ShardId; + public ShardState State => _stateManager.State; + public TimeSpan? Latency => _stateManager.Latency; + public User? User => _stateManager.User; + public ApplicationPartial? Application => _stateManager.Application; + + // TODO: I wanna get rid of these or move them at some point + public event Func? OnEventReceived; public event Action? HeartbeatReceived; public event Action? SocketOpened; public event Action? Resumed; public event Action? Ready; - public event Action? SocketClosed; + public event Action? SocketClosed; + + private TimeSpan _reconnectDelay = TimeSpan.Zero; + private Task? _worker; + + public Shard(GatewaySettings settings, ShardInfo info, ShardIdentifyRatelimiter ratelimiter, string url, ILogger logger) + { + _jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad(); + + _settings = settings; + _info = info; + _ratelimiter = ratelimiter; + _url = url; + _logger = logger; + _stateManager = new ShardStateManager(info, _jsonSerializerOptions, logger) + { + HandleEvent = HandleEvent, + SendHeartbeat = SendHeartbeat, + SendIdentify = SendIdentify, + SendResume = SendResume, + Connect = ConnectInner, + Reconnect = Reconnect, + }; + _stateManager.OnHeartbeatReceived += latency => + { + HeartbeatReceived?.Invoke(latency); + }; + + _conn = new ShardConnection(_jsonSerializerOptions, _logger); + } - public Shard(ILogger logger, Uri uri, GatewaySettings settings, ShardInfo info, - ShardSessionInfo? sessionInfo = null) + private async Task ShardLoop() { - _logger = logger.ForContext(); - _uri = uri; + while (true) + { + try + { + await ConnectInner(); + await HandleConnectionOpened(); - Settings = settings; - ShardInfo = info; - SessionInfo = sessionInfo ?? new ShardSessionInfo(); + while (_conn.State == WebSocketState.Open) + { + var packet = await _conn.Read(); + if (packet == null) + break; + + await _stateManager.HandlePacketReceived(packet); + } + + await HandleConnectionClosed(_conn.CloseStatus, _conn.CloseStatusDescription); + + _logger.Information("Shard {ShardId}: Reconnecting after delay {ReconnectDelay}", + _info.ShardId, _reconnectDelay); + + if (_reconnectDelay > TimeSpan.Zero) + await Task.Delay(_reconnectDelay); + } + catch (Exception e) + { + _logger.Error(e, "Shard {ShardId}: Error in main shard loop, reconnecting in 5 seconds...", _info.ShardId); + + // todo: exponential backoff here? this should never happen, ideally... + await Task.Delay(TimeSpan.FromSeconds(5)); + } + } } - - public async ValueTask DisposeAsync() - { - if (_conn != null) - await _conn.DisposeAsync(); - } - + public Task Start() { - _worker = MainLoop(); + if (_worker == null) + _worker = ShardLoop(); return Task.CompletedTask; } public async Task UpdateStatus(GatewayStatusUpdate payload) { - if (_conn != null && _conn.State == WebSocketState.Open) - await _conn!.Send(new GatewayPacket {Opcode = GatewayOpcode.PresenceUpdate, Payload = payload}); - } - - private async Task MainLoop() - { - while (true) - try - { - _logger.Information("Shard {ShardId}: Connecting...", ShardId); - - State = ShardState.Connecting; - await Connect(); - - _logger.Information("Shard {ShardId}: Connected. Entering main loop...", ShardId); - - // Tick returns false if we need to stop and reconnect - while (await Tick(_conn!)) - await Task.Delay(TimeSpan.FromMilliseconds(1000)); - - _logger.Information("Shard {ShardId}: Connection closed, reconnecting...", ShardId); - State = ShardState.Closed; - } - catch (Exception e) - { - _logger.Error(e, "Shard {ShardId}: Error in shard state handler", ShardId); - } - } - - private async Task Tick(ShardConnection conn) - { - if (conn.State != WebSocketState.Connecting && conn.State != WebSocketState.Open) - return false; - - if (!await TickHeartbeat(conn)) - // TickHeartbeat returns false if we're disconnecting - return false; - - return true; - } - - private async Task TickHeartbeat(ShardConnection conn) - { - // If we don't need to heartbeat, do nothing - if (_lastHeartbeatSent == null || _currentHeartbeatInterval == null) - return true; - - if (DateTimeOffset.UtcNow - _lastHeartbeatSent < _currentHeartbeatInterval) - return true; - - // If we haven't received the ack in time, close w/ error - if (!_hasReceivedAck) + await _conn.Send(new GatewayPacket { - _logger.Warning( - "Shard {ShardId}: Did not receive heartbeat Ack from gateway within interval ({HeartbeatInterval})", - ShardId, _currentHeartbeatInterval); - State = ShardState.Closing; - await conn.Disconnect(WebSocketCloseStatus.ProtocolError, "Did not receive ACK in time"); - return false; - } - - // Otherwise just send it :) - await SendHeartbeat(conn); - _hasReceivedAck = false; - return true; + Opcode = GatewayOpcode.PresenceUpdate, + Payload = payload + }); } - - private async Task SendHeartbeat(ShardConnection conn) + + private async Task ConnectInner() { - _logger.Debug("Shard {ShardId}: Sending heartbeat with seq.no. {LastSequence}", - ShardId, SessionInfo.LastSequence); + await _ratelimiter.Acquire(_info.ShardId); - await conn.Send(new GatewayPacket {Opcode = GatewayOpcode.Heartbeat, Payload = SessionInfo.LastSequence}); - _lastHeartbeatSent = DateTimeOffset.UtcNow; + _logger.Information("Shard {ShardId}: Connecting to WebSocket", _info.ShardId); + await _conn.Connect(_url, default); } - - private async Task Connect() + + private async Task DisconnectInner(WebSocketCloseStatus closeStatus) { - if (_conn != null) - await _conn.DisposeAsync(); - - _currentHeartbeatInterval = null; - - _conn = new ShardConnection(_uri, _logger, _jsonSerializerOptions) - { - OnReceive = OnReceive, - OnOpen = () => SocketOpened?.Invoke(), - OnClose = (closeStatus, message) => SocketClosed?.Invoke(closeStatus, message) - }; + await _conn.Disconnect(closeStatus, null); } - - private async Task OnReceive(GatewayPacket packet) - { - switch (packet.Opcode) - { - case GatewayOpcode.Hello: - { - await HandleHello((JsonElement) packet.Payload!); - break; - } - case GatewayOpcode.Heartbeat: - { - _logger.Debug("Shard {ShardId}: Received heartbeat request from shard, sending Ack", ShardId); - await _conn!.Send(new GatewayPacket {Opcode = GatewayOpcode.HeartbeatAck}); - break; - } - case GatewayOpcode.HeartbeatAck: - { - Latency = DateTimeOffset.UtcNow - _lastHeartbeatSent; - _logger.Debug("Shard {ShardId}: Received heartbeat Ack with latency {Latency}", ShardId, Latency); - if (Latency != null) - HeartbeatReceived?.Invoke(Latency!.Value); - - _hasReceivedAck = true; - break; - } - case GatewayOpcode.Reconnect: - { - _logger.Information("Shard {ShardId}: Received Reconnect, closing and reconnecting", ShardId); - await _conn!.Disconnect(WebSocketCloseStatus.Empty, null); - break; - } - case GatewayOpcode.InvalidSession: - { - var canResume = ((JsonElement) packet.Payload!).GetBoolean(); - - // Clear session info before DCing - if (!canResume) - SessionInfo = SessionInfo with { Session = null }; - - var delay = TimeSpan.FromMilliseconds(new Random().Next(1000, 5000)); - - _logger.Information( - "Shard {ShardId}: Received Invalid Session (can resume? {CanResume}), reconnecting after {ReconnectDelay}", - ShardId, canResume, delay); - await _conn!.Disconnect(WebSocketCloseStatus.Empty, null); - - // Will reconnect after exiting this "loop" - await Task.Delay(delay); - break; - } - case GatewayOpcode.Dispatch: - { - SessionInfo = SessionInfo with { LastSequence = packet.Sequence }; - var evt = DeserializeEvent(packet.EventType!, (JsonElement) packet.Payload!)!; - - if (evt is ReadyEvent rdy) - { - if (State == ShardState.Connecting) - await HandleReady(rdy); - else - _logger.Warning("Shard {ShardId}: Received Ready event in unexpected state {ShardState}, ignoring?", - ShardId, State); - } - else if (evt is ResumedEvent) - { - if (State == ShardState.Connecting) - await HandleResumed(); - else - _logger.Warning("Shard {ShardId}: Received Resumed event in unexpected state {ShardState}, ignoring?", - ShardId, State); - } - - await HandleEvent(evt); - break; - } - default: - { - _logger.Debug("Shard {ShardId}: Received unknown gateway opcode {Opcode}", ShardId, packet.Opcode); - break; - } - } - } - - private async Task HandleEvent(IGatewayEvent evt) - { - if (OnEventReceived != null) - await OnEventReceived.Invoke(evt); - } - - - private IGatewayEvent? DeserializeEvent(string eventType, JsonElement data) - { - if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType)) - { - _logger.Information("Shard {ShardId}: Received unknown event type {EventType}", ShardId, eventType); - return null; - } - - try - { - _logger.Verbose("Shard {ShardId}: Deserializing {EventType} to {ClrType}", ShardId, eventType, clrType); - return JsonSerializer.Deserialize(data.GetRawText(), clrType, _jsonSerializerOptions) - as IGatewayEvent; - } - catch (JsonException e) - { - _logger.Error(e, "Shard {ShardId}: Error deserializing event {EventType} to {ClrType}", ShardId, eventType, clrType); - return null; - } - } - - private Task HandleReady(ReadyEvent ready) - { - // TODO: when is ready.Shard ever null? - ShardInfo = ready.Shard ?? new ShardInfo(0, 0); - SessionInfo = SessionInfo with { Session = ready.SessionId }; - User = ready.User; - Application = ready.Application; - State = ShardState.Open; - - Ready?.Invoke(); - return Task.CompletedTask; - } - - private Task HandleResumed() - { - State = ShardState.Open; - Resumed?.Invoke(); - return Task.CompletedTask; - } - - private async Task HandleHello(JsonElement json) - { - var hello = JsonSerializer.Deserialize(json.GetRawText(), _jsonSerializerOptions)!; - _logger.Debug("Shard {ShardId}: Received Hello with interval {Interval} ms", ShardId, hello.HeartbeatInterval); - _currentHeartbeatInterval = TimeSpan.FromMilliseconds(hello.HeartbeatInterval); - - await SendHeartbeat(_conn!); - - await SendIdentifyOrResume(); - } - - private async Task SendIdentifyOrResume() - { - if (SessionInfo.Session != null && SessionInfo.LastSequence != null) - await SendResume(SessionInfo.Session, SessionInfo.LastSequence!.Value); - else - await SendIdentify(); - } - + private async Task SendIdentify() { - _logger.Information("Shard {ShardId}: Sending gateway Identify for shard {@ShardInfo}", ShardId, ShardInfo); - await _conn!.Send(new GatewayPacket + await _conn.Send(new GatewayPacket { Opcode = GatewayOpcode.Identify, Payload = new GatewayIdentify { - Token = Settings.Token, + Compress = false, + Intents = _settings.Intents, Properties = new GatewayIdentify.ConnectionProperties { - Browser = LibraryName, Device = LibraryName, Os = Environment.OSVersion.ToString() + Browser = LibraryName, + Device = LibraryName, + Os = Environment.OSVersion.ToString() }, - Intents = Settings.Intents, - Shard = ShardInfo + Shard = _info, + Token = _settings.Token, + LargeThreshold = 50 } }); } - - private async Task SendResume(string session, int lastSequence) + + private async Task SendResume((string SessionId, int? LastSeq) arg) { - _logger.Information("Shard {ShardId}: Sending gateway Resume for session {@SessionInfo}", - ShardId, SessionInfo); - await _conn!.Send(new GatewayPacket + await _conn.Send(new GatewayPacket { - Opcode = GatewayOpcode.Resume, - Payload = new GatewayResume(Settings.Token, session, lastSequence) + Opcode = GatewayOpcode.Resume, + Payload = new GatewayResume(_settings.Token, arg.SessionId, arg.LastSeq ?? 0) }); } - public enum ShardState + private async Task SendHeartbeat(int? lastSeq) { - Closed, - Connecting, - Open, - Closing + await _conn.Send(new GatewayPacket {Opcode = GatewayOpcode.Heartbeat, Payload = lastSeq}); + } + + private async Task Reconnect(WebSocketCloseStatus closeStatus, TimeSpan delay) + { + _reconnectDelay = delay; + await DisconnectInner(closeStatus); + } + + private async Task HandleEvent(IGatewayEvent arg) + { + if (arg is ReadyEvent) + Ready?.Invoke(); + if (arg is ResumedEvent) + Resumed?.Invoke(); + + await (OnEventReceived?.Invoke(arg) ?? Task.CompletedTask); + } + + private async Task HandleConnectionOpened() + { + _logger.Information("Shard {ShardId}: Connection opened", _info.ShardId); + await _stateManager.HandleConnectionOpened(); + SocketOpened?.Invoke(); + } + + private async Task HandleConnectionClosed(WebSocketCloseStatus? closeStatus, string? description) + { + _logger.Information("Shard {ShardId}: Connection closed ({CloseStatus}/{Description})", + _info.ShardId, closeStatus, description ?? ""); + await _stateManager.HandleConnectionClosed(); + SocketClosed?.Invoke(closeStatus, description); } } } \ No newline at end of file diff --git a/Myriad/Gateway/ShardConnection.cs b/Myriad/Gateway/ShardConnection.cs index 886e0664..250ef84b 100644 --- a/Myriad/Gateway/ShardConnection.cs +++ b/Myriad/Gateway/ShardConnection.cs @@ -1,6 +1,4 @@ using System; -using System.Buffers; -using System.IO; using System.Net.WebSockets; using System.Text.Json; using System.Threading; @@ -12,120 +10,95 @@ namespace Myriad.Gateway { public class ShardConnection: IAsyncDisposable { - private readonly MemoryStream _bufStream = new(); - - private readonly ClientWebSocket _client = new(); - private readonly CancellationTokenSource _cts = new(); - private readonly JsonSerializerOptions _jsonSerializerOptions; + private ClientWebSocket? _client; private readonly ILogger _logger; - private readonly Task _worker; - - public ShardConnection(Uri uri, ILogger logger, JsonSerializerOptions jsonSerializerOptions) - { - _logger = logger; - _jsonSerializerOptions = jsonSerializerOptions; - - _worker = Worker(uri); - } - - public Func? OnReceive { get; set; } - public Action? OnOpen { get; set; } - - public Action? OnClose { get; set; } - - public WebSocketState State => _client.State; + private readonly ShardPacketSerializer _serializer; - public async ValueTask DisposeAsync() + public WebSocketState State => _client?.State ?? WebSocketState.Closed; + public WebSocketCloseStatus? CloseStatus => _client?.CloseStatus; + public string? CloseStatusDescription => _client?.CloseStatusDescription; + + public ShardConnection(JsonSerializerOptions jsonSerializerOptions, ILogger logger) { - _cts.Cancel(); - await _worker; - - _client.Dispose(); - await _bufStream.DisposeAsync(); - _cts.Dispose(); + _logger = logger.ForContext(); + _serializer = new(jsonSerializerOptions); } - private async Task Worker(Uri uri) + public async Task Connect(string url, CancellationToken ct) { - var realUrl = new UriBuilder(uri) - { - Query = "v=8&encoding=json" - }.Uri; - _logger.Debug("Connecting to gateway WebSocket at {GatewayUrl}", realUrl); - await _client.ConnectAsync(realUrl, default); - _logger.Debug("Gateway connection opened"); + _client?.Dispose(); + _client = new ClientWebSocket(); - OnOpen?.Invoke(); - - // Main worker loop, spins until we manually disconnect (which hits the cancellation token) - // or the server disconnects us (which sets state to closed) - while (!_cts.IsCancellationRequested && _client.State == WebSocketState.Open) - { - try - { - await HandleReceive(); - } - catch (Exception e) - { - _logger.Error(e, "Error in WebSocket receive worker"); - } - } - - OnClose?.Invoke(_client.CloseStatus ?? default, _client.CloseStatusDescription); + await _client.ConnectAsync(GetConnectionUri(url), ct); } - private async Task HandleReceive() + public async Task Disconnect(WebSocketCloseStatus closeStatus, string? reason) { - _bufStream.SetLength(0); - var result = await ReadData(_bufStream); - var data = _bufStream.GetBuffer().AsMemory(0, (int) _bufStream.Position); - - if (result.MessageType == WebSocketMessageType.Text) - await HandleReceiveData(data); - else if (result.MessageType == WebSocketMessageType.Close) - _logger.Information("WebSocket closed by server: {StatusCode} {Reason}", _client.CloseStatus, - _client.CloseStatusDescription); - } - - private async Task HandleReceiveData(Memory data) - { - var packet = JsonSerializer.Deserialize(data.Span, _jsonSerializerOptions)!; - - try - { - if (OnReceive != null) - await OnReceive.Invoke(packet); - } - catch (Exception e) - { - _logger.Error(e, "Error in gateway handler for {OpcodeType}", packet.Opcode); - } - } - - private async Task ReadData(MemoryStream stream) - { - // TODO: does this throw if we disconnect mid-read? - using var buf = MemoryPool.Shared.Rent(); - ValueWebSocketReceiveResult result; - do - { - result = await _client.ReceiveAsync(buf.Memory, _cts.Token); - stream.Write(buf.Memory.Span.Slice(0, result.Count)); - } while (!result.EndOfMessage); - - return result; + await CloseInner(closeStatus, reason); } public async Task Send(GatewayPacket packet) { - var bytes = JsonSerializer.SerializeToUtf8Bytes(packet, _jsonSerializerOptions); - await _client.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, default); + if (_client == null || _client.State != WebSocketState.Open) + return; + + try + { + await _serializer.WritePacket(_client, packet); + } + catch (Exception e) + { + _logger.Error(e, "Error sending WebSocket message"); + } } - public async Task Disconnect(WebSocketCloseStatus status, string? description) + public async ValueTask DisposeAsync() { - await _client.CloseAsync(status, description, default); - _cts.Cancel(); + await CloseInner(WebSocketCloseStatus.NormalClosure, null); + _client?.Dispose(); + } + + public async Task Read() + { + if (_client == null || _client.State != WebSocketState.Open) + return null; + + try + { + var (_, packet) = await _serializer.ReadPacket(_client); + return packet; + } + catch (Exception e) + { + _logger.Error(e, "Error reading from WebSocket"); + } + + return null; + } + + private Uri GetConnectionUri(string baseUri) => new UriBuilder(baseUri) + { + Query = "v=8&encoding=json" + }.Uri; + + private async Task CloseInner(WebSocketCloseStatus closeStatus, string? description) + { + if (_client == null) + return; + + if (_client.State != WebSocketState.Connecting && _client.State != WebSocketState.Open) + return; + + // Close with timeout, mostly to work around https://github.com/dotnet/runtime/issues/51590 + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + try + { + await _client.CloseAsync(closeStatus, description, cts.Token); + } + catch (Exception e) + { + _logger.Error(e, "Error closing WebSocket connection"); + } } } } \ No newline at end of file diff --git a/Myriad/Gateway/ShardIdentifyRatelimiter.cs b/Myriad/Gateway/ShardIdentifyRatelimiter.cs new file mode 100644 index 00000000..2d364950 --- /dev/null +++ b/Myriad/Gateway/ShardIdentifyRatelimiter.cs @@ -0,0 +1,72 @@ +using System; +using System.Collections.Concurrent; +using System.Threading.Tasks; + +using Serilog; + +namespace Myriad.Gateway +{ + public class ShardIdentifyRatelimiter + { + private static readonly TimeSpan BucketLength = TimeSpan.FromSeconds(5); + + private readonly ConcurrentDictionary> _buckets = new(); + private readonly int _maxConcurrency; + + private Task? _refillTask; + private readonly ILogger _logger; + + public ShardIdentifyRatelimiter(ILogger logger, int maxConcurrency) + { + _logger = logger.ForContext(); + _maxConcurrency = maxConcurrency; + } + + public Task Acquire(int shard) + { + var bucket = shard % _maxConcurrency; + var queue = _buckets.GetOrAdd(bucket, _ => new ConcurrentQueue()); + var tcs = new TaskCompletionSource(); + queue.Enqueue(tcs); + + ScheduleRefill(); + + return tcs.Task; + } + + private void ScheduleRefill() + { + if (_refillTask != null && !_refillTask.IsCompleted) + return; + + _refillTask?.Dispose(); + _refillTask = RefillTask(); + } + + private async Task RefillTask() + { + await Task.Delay(TimeSpan.FromMilliseconds(250)); + + while (true) + { + var isClear = true; + foreach (var (bucket, queue) in _buckets) + { + if (!queue.TryDequeue(out var tcs)) + continue; + + _logger.Information( + "Allowing identify for bucket {BucketId} through ({QueueLength} left in bucket queue)", + bucket, queue.Count); + tcs.SetResult(); + isClear = false; + } + + if (isClear) + return; + + await Task.Delay(BucketLength); + } + } + } +} \ No newline at end of file diff --git a/Myriad/Gateway/ShardPacketSerializer.cs b/Myriad/Gateway/ShardPacketSerializer.cs new file mode 100644 index 00000000..133ec960 --- /dev/null +++ b/Myriad/Gateway/ShardPacketSerializer.cs @@ -0,0 +1,70 @@ +using System; +using System.Buffers; +using System.IO; +using System.Net.WebSockets; +using System.Text.Json; +using System.Threading.Tasks; + +namespace Myriad.Gateway +{ + public class ShardPacketSerializer + { + private const int BufferSize = 64 * 1024; + + private readonly JsonSerializerOptions _jsonSerializerOptions; + + public ShardPacketSerializer(JsonSerializerOptions jsonSerializerOptions) + { + _jsonSerializerOptions = jsonSerializerOptions; + } + + public async ValueTask<(WebSocketMessageType type, GatewayPacket? packet)> ReadPacket(ClientWebSocket socket) + { + using var buf = MemoryPool.Shared.Rent(BufferSize); + + var res = await socket.ReceiveAsync(buf.Memory, default); + if (res.MessageType == WebSocketMessageType.Close) + return (res.MessageType, null); + + if (res.EndOfMessage) + // Entire packet fits within one buffer, deserialize directly + return DeserializeSingleBuffer(buf, res); + + // Otherwise copy to stream buffer and deserialize from there + return await DeserializeMultipleBuffer(socket, buf, res); + } + + public async Task WritePacket(ClientWebSocket socket, GatewayPacket packet) + { + var bytes = JsonSerializer.SerializeToUtf8Bytes(packet, _jsonSerializerOptions); + await socket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, default); + } + + private async Task<(WebSocketMessageType type, GatewayPacket packet)> DeserializeMultipleBuffer(ClientWebSocket socket, IMemoryOwner buf, ValueWebSocketReceiveResult res) + { + await using var stream = new MemoryStream(BufferSize * 4); + stream.Write(buf.Memory.Span.Slice(0, res.Count)); + + while (!res.EndOfMessage) + { + res = await socket.ReceiveAsync(buf.Memory, default); + stream.Write(buf.Memory.Span.Slice(0, res.Count)); + } + + return DeserializeObject(res, stream.GetBuffer().AsSpan(0, (int) stream.Length)); + } + + private (WebSocketMessageType type, GatewayPacket packet) DeserializeSingleBuffer( + IMemoryOwner buf, ValueWebSocketReceiveResult res) + { + var span = buf.Memory.Span.Slice(0, res.Count); + return DeserializeObject(res, span); + } + + private (WebSocketMessageType type, GatewayPacket packet) DeserializeObject(ValueWebSocketReceiveResult res, Span span) + { + var packet = JsonSerializer.Deserialize(span, _jsonSerializerOptions)!; + return (res.MessageType, packet); + } + } +} \ No newline at end of file diff --git a/Myriad/Gateway/ShardSessionInfo.cs b/Myriad/Gateway/ShardSessionInfo.cs deleted file mode 100644 index 81d6ee5f..00000000 --- a/Myriad/Gateway/ShardSessionInfo.cs +++ /dev/null @@ -1,8 +0,0 @@ -namespace Myriad.Gateway -{ - public record ShardSessionInfo - { - public string? Session { get; init; } - public int? LastSequence { get; init; } - } -} \ No newline at end of file diff --git a/Myriad/Gateway/State/HeartbeatWorker.cs b/Myriad/Gateway/State/HeartbeatWorker.cs new file mode 100644 index 00000000..794cfc2c --- /dev/null +++ b/Myriad/Gateway/State/HeartbeatWorker.cs @@ -0,0 +1,63 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Myriad.Gateway.State +{ + public class HeartbeatWorker: IAsyncDisposable + { + private Task? _worker; + private CancellationTokenSource? _workerCts; + + public TimeSpan? CurrentHeartbeatInterval { get; private set; } + + public async ValueTask Start(TimeSpan heartbeatInterval, Func callback) + { + if (_worker != null) + await Stop(); + + CurrentHeartbeatInterval = heartbeatInterval; + _workerCts = new CancellationTokenSource(); + _worker = Worker(heartbeatInterval, callback, _workerCts.Token); + } + + public async ValueTask Stop() + { + if (_worker == null) + return; + + _workerCts?.Cancel(); + try + { + await _worker; + } + catch (TaskCanceledException) { } + + _worker?.Dispose(); + _workerCts?.Dispose(); + _worker = null; + CurrentHeartbeatInterval = null; + } + + private async Task Worker(TimeSpan heartbeatInterval, Func callback, CancellationToken ct) + { + var initialDelay = GetInitialHeartbeatDelay(heartbeatInterval); + await Task.Delay(initialDelay, ct); + + while (!ct.IsCancellationRequested) + { + await callback(); + await Task.Delay(heartbeatInterval, ct); + } + } + + private static TimeSpan GetInitialHeartbeatDelay(TimeSpan heartbeatInterval) => + // Docs specify `heartbeat_interval * random.random()` but we'll add a lil buffer :) + heartbeatInterval * (new Random().NextDouble() * 0.9 + 0.05); + + public async ValueTask DisposeAsync() + { + await Stop(); + } + } +} \ No newline at end of file diff --git a/Myriad/Gateway/State/ShardState.cs b/Myriad/Gateway/State/ShardState.cs new file mode 100644 index 00000000..a3bacaee --- /dev/null +++ b/Myriad/Gateway/State/ShardState.cs @@ -0,0 +1,11 @@ +namespace Myriad.Gateway.State +{ + public enum ShardState + { + Disconnected, + Handshaking, + Identifying, + Connected, + Reconnecting + } +} \ No newline at end of file diff --git a/Myriad/Gateway/State/ShardStateManager.cs b/Myriad/Gateway/State/ShardStateManager.cs new file mode 100644 index 00000000..f57b1c91 --- /dev/null +++ b/Myriad/Gateway/State/ShardStateManager.cs @@ -0,0 +1,244 @@ +using System; +using System.Net.WebSockets; +using System.Text.Json; +using System.Threading.Tasks; + +using Myriad.Gateway.State; +using Myriad.Types; + +using Serilog; + +namespace Myriad.Gateway +{ + public class ShardStateManager + { + private readonly HeartbeatWorker _heartbeatWorker = new(); + private readonly ILogger _logger; + + private readonly ShardInfo _info; + private readonly JsonSerializerOptions _jsonSerializerOptions; + private ShardState _state = ShardState.Disconnected; + + private DateTimeOffset? _lastHeartbeatSent; + private TimeSpan? _latency; + private bool _hasReceivedHeartbeatAck; + + private string? _sessionId; + private int? _lastSeq; + + public ShardState State => _state; + public TimeSpan? Latency => _latency; + public User? User { get; private set; } + public ApplicationPartial? Application { get; private set; } + + public Func SendIdentify { get; init; } + public Func<(string SessionId, int? LastSeq), Task> SendResume { get; init; } + public Func SendHeartbeat { get; init; } + public Func Reconnect { get; init; } + public Func Connect { get; init; } + public Func HandleEvent { get; init; } + + public event Action OnHeartbeatReceived; + + public ShardStateManager(ShardInfo info, JsonSerializerOptions jsonSerializerOptions, ILogger logger) + { + _info = info; + _jsonSerializerOptions = jsonSerializerOptions; + _logger = logger.ForContext(); + } + + public Task HandleConnectionOpened() + { + _state = ShardState.Handshaking; + return Task.CompletedTask; + } + + public async Task HandleConnectionClosed() + { + _latency = null; + await _heartbeatWorker.Stop(); + } + + public async Task HandlePacketReceived(GatewayPacket packet) + { + switch (packet.Opcode) + { + case GatewayOpcode.Hello: + var hello = DeserializePayload(packet); + await HandleHello(hello); + break; + + case GatewayOpcode.Heartbeat: + await HandleHeartbeatRequest(); + break; + + case GatewayOpcode.HeartbeatAck: + await HandleHeartbeatAck(); + break; + + case GatewayOpcode.Reconnect: + { + await HandleReconnect(); + break; + } + + case GatewayOpcode.InvalidSession: + { + var canResume = DeserializePayload(packet); + await HandleInvalidSession(canResume); + break; + } + + case GatewayOpcode.Dispatch: + _lastSeq = packet.Sequence; + + var evt = DeserializeEvent(packet.EventType!, (JsonElement) packet.Payload!); + if (evt != null) + { + if (evt is ReadyEvent ready) + await HandleReady(ready); + + if (evt is ResumedEvent) + await HandleResumed(); + + await HandleEvent(evt); + } + break; + } + } + + private async Task HandleHello(GatewayHello hello) + { + var interval = TimeSpan.FromMilliseconds(hello.HeartbeatInterval); + + _hasReceivedHeartbeatAck = true; + await _heartbeatWorker.Start(interval, HandleHeartbeatTimer); + await IdentifyOrResume(); + } + + private async Task IdentifyOrResume() + { + _state = ShardState.Identifying; + + if (_sessionId != null) + { + _logger.Information("Shard {ShardId}: Received Hello, attempting to resume (seq {LastSeq})", + _info.ShardId, _lastSeq); + await SendResume((_sessionId!, _lastSeq)); + } + else + { + _logger.Information("Shard {ShardId}: Received Hello, identifying", + _info.ShardId); + + await SendIdentify(); + } + } + + private Task HandleHeartbeatAck() + { + _hasReceivedHeartbeatAck = true; + _latency = DateTimeOffset.UtcNow - _lastHeartbeatSent; + OnHeartbeatReceived?.Invoke(_latency!.Value); + _logger.Information("Shard {ShardId}: Received Heartbeat (latency {Latency} ms)", + _info.ShardId, _latency); + return Task.CompletedTask; + } + + private async Task HandleInvalidSession(bool canResume) + { + if (!canResume) + { + _sessionId = null; + _lastSeq = null; + } + + _logger.Information("Shard {ShardId}: Received Invalid Session (can resume? {CanResume})", + _info.ShardId, canResume); + + var delay = TimeSpan.FromMilliseconds(new Random().Next(1000, 5000)); + await DoReconnect(WebSocketCloseStatus.NormalClosure, delay); + } + + private async Task HandleReconnect() + { + _logger.Information("Shard {ShardId}: Received Reconnect", _info.ShardId); + await DoReconnect(WebSocketCloseStatus.NormalClosure, TimeSpan.FromSeconds(1)); + } + + private Task HandleReady(ReadyEvent ready) + { + _logger.Information("Shard {ShardId}: Received Ready", _info.ShardId); + + _sessionId = ready.SessionId; + _state = ShardState.Connected; + User = ready.User; + Application = ready.Application; + return Task.CompletedTask; + } + + private Task HandleResumed() + { + _logger.Information("Shard {ShardId}: Received Resume", _info.ShardId); + + _state = ShardState.Connected; + return Task.CompletedTask; + } + + private async Task HandleHeartbeatRequest() + { + await SendHeartbeatInternal(); + } + + private async Task SendHeartbeatInternal() + { + await SendHeartbeat(_lastSeq); + _lastHeartbeatSent = DateTimeOffset.UtcNow; + } + + private async Task HandleHeartbeatTimer() + { + if (!_hasReceivedHeartbeatAck) + { + _logger.Warning("Shard {ShardId}: Heartbeat worker timed out", _info.ShardId); + await DoReconnect(WebSocketCloseStatus.ProtocolError, TimeSpan.Zero); + return; + } + + await SendHeartbeatInternal(); + } + + private async Task DoReconnect(WebSocketCloseStatus closeStatus, TimeSpan delay) + { + _state = ShardState.Reconnecting; + await Reconnect(closeStatus, delay); + } + + private T DeserializePayload(GatewayPacket packet) + { + var packetPayload = (JsonElement) packet.Payload!; + return JsonSerializer.Deserialize(packetPayload.GetRawText(), _jsonSerializerOptions)!; + } + + private IGatewayEvent? DeserializeEvent(string eventType, JsonElement payload) + { + if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType)) + { + _logger.Debug("Shard {ShardId}: Received unknown event type {EventType}", _info.ShardId, eventType); + return null; + } + + try + { + _logger.Verbose("Shard {ShardId}: Deserializing {EventType} to {ClrType}", _info.ShardId, eventType, clrType); + return JsonSerializer.Deserialize(payload.GetRawText(), clrType, _jsonSerializerOptions) + as IGatewayEvent; + } + catch (JsonException e) + { + _logger.Error(e, "Shard {ShardId}: Error deserializing event {EventType} to {ClrType}", _info.ShardId, eventType, clrType); + return null; + } + } + } +} \ No newline at end of file diff --git a/Myriad/Rest/DiscordApiClient.cs b/Myriad/Rest/DiscordApiClient.cs index 4612fd2c..76f8b4b9 100644 --- a/Myriad/Rest/DiscordApiClient.cs +++ b/Myriad/Rest/DiscordApiClient.cs @@ -1,5 +1,4 @@ using System; -using System.IO; using System.Net; using System.Threading.Tasks; diff --git a/Myriad/Rest/Ratelimit/BucketManager.cs b/Myriad/Rest/Ratelimit/BucketManager.cs index edea0825..0f8ccf0e 100644 --- a/Myriad/Rest/Ratelimit/BucketManager.cs +++ b/Myriad/Rest/Ratelimit/BucketManager.cs @@ -69,12 +69,14 @@ namespace Myriad.Rest.Ratelimit private void PruneStaleBuckets(DateTimeOffset now) { foreach (var (key, bucket) in _buckets) - if (now - bucket.LastUsed > StaleBucketTimeout) - { - _logger.Debug("Pruning unused bucket {Bucket} (last used at {BucketLastUsed})", bucket, - bucket.LastUsed); - _buckets.TryRemove(key, out _); - } + { + if (now - bucket.LastUsed <= StaleBucketTimeout) + continue; + + _logger.Debug("Pruning unused bucket {BucketKey}/{BucketMajor} (last used at {BucketLastUsed})", + bucket.Key, bucket.Major, bucket.LastUsed); + _buckets.TryRemove(key, out _); + } } } } \ No newline at end of file diff --git a/Myriad/Types/Application/Application.cs b/Myriad/Types/Application/Application.cs index 1fe04127..e277e946 100644 --- a/Myriad/Types/Application/Application.cs +++ b/Myriad/Types/Application/Application.cs @@ -1,6 +1,4 @@ -using System.Collections.Generic; - -namespace Myriad.Types +namespace Myriad.Types { public record Application: ApplicationPartial { diff --git a/Myriad/Types/Application/ApplicationCommand.cs b/Myriad/Types/Application/ApplicationCommand.cs index 92ecd856..53f88dd6 100644 --- a/Myriad/Types/Application/ApplicationCommand.cs +++ b/Myriad/Types/Application/ApplicationCommand.cs @@ -1,6 +1,4 @@ -using System.Collections.Generic; - -namespace Myriad.Types +namespace Myriad.Types { public record ApplicationCommand { diff --git a/Myriad/Types/Application/InteractionApplicationCommandCallbackData.cs b/Myriad/Types/Application/InteractionApplicationCommandCallbackData.cs index 2718aa0e..32c9aaac 100644 --- a/Myriad/Types/Application/InteractionApplicationCommandCallbackData.cs +++ b/Myriad/Types/Application/InteractionApplicationCommandCallbackData.cs @@ -1,6 +1,4 @@ -using System.Collections.Generic; - -using Myriad.Rest.Types; +using Myriad.Rest.Types; namespace Myriad.Types { diff --git a/Myriad/Types/Guild.cs b/Myriad/Types/Guild.cs index 9b9cccfe..413425cb 100644 --- a/Myriad/Types/Guild.cs +++ b/Myriad/Types/Guild.cs @@ -1,6 +1,4 @@ -using System.Collections.Generic; - -namespace Myriad.Types +namespace Myriad.Types { public record Guild { diff --git a/Myriad/Types/Message.cs b/Myriad/Types/Message.cs index a7cb88c6..71ecbedb 100644 --- a/Myriad/Types/Message.cs +++ b/Myriad/Types/Message.cs @@ -1,6 +1,4 @@ using System; -using System.Collections.Generic; -using System.Net.Mail; using System.Text.Json.Serialization; using Myriad.Utils; diff --git a/PluralKit.Bot/Bot.cs b/PluralKit.Bot/Bot.cs index 550c8d48..9b600876 100644 --- a/PluralKit.Bot/Bot.cs +++ b/PluralKit.Bot/Bot.cs @@ -285,7 +285,7 @@ namespace PluralKit.Bot { new ActivityPartial { - Name = $"pk;help | in {totalGuilds} servers | shard #{shard.ShardInfo?.ShardId}", + Name = $"pk;help | in {totalGuilds} servers | shard #{shard.ShardId}", Type = ActivityType.Game, Url = "https://pluralkit.me/" } diff --git a/PluralKit.Bot/CommandSystem/Context.cs b/PluralKit.Bot/CommandSystem/Context.cs index 0ad82b6e..adcb8642 100644 --- a/PluralKit.Bot/CommandSystem/Context.cs +++ b/PluralKit.Bot/CommandSystem/Context.cs @@ -29,8 +29,6 @@ namespace PluralKit.Bot private readonly MessageCreateEvent _message; private readonly Parameters _parameters; private readonly MessageContext _messageContext; - private readonly PermissionSet _botPermissions; - private readonly PermissionSet _userPermissions; private readonly IDatabase _db; private readonly ModelRepository _repo; @@ -42,7 +40,7 @@ namespace PluralKit.Bot private Command _currentCommand; public Context(ILifetimeScope provider, Shard shard, Guild? guild, Channel channel, MessageCreateEvent message, int commandParseOffset, - PKSystem senderSystem, MessageContext messageContext, PermissionSet botPermissions) + PKSystem senderSystem, MessageContext messageContext) { _message = message; _shard = shard; @@ -59,9 +57,6 @@ namespace PluralKit.Bot _parameters = new Parameters(message.Content?.Substring(commandParseOffset)); _rest = provider.Resolve(); _cluster = provider.Resolve(); - - _botPermissions = botPermissions; - _userPermissions = _cache.PermissionsFor(message); } public IDiscordCache Cache => _cache; @@ -76,8 +71,8 @@ namespace PluralKit.Bot public Cluster Cluster => _cluster; public MessageContext MessageContext => _messageContext; - public PermissionSet BotPermissions => _botPermissions; - public PermissionSet UserPermissions => _userPermissions; + public PermissionSet BotPermissions => _provider.Resolve().PermissionsIn(_channel.Id); + public PermissionSet UserPermissions => _cache.PermissionsFor(_message); public DiscordApiClient Rest => _rest; diff --git a/PluralKit.Bot/Commands/Misc.cs b/PluralKit.Bot/Commands/Misc.cs index 23b654f3..4b510fa7 100644 --- a/PluralKit.Bot/Commands/Misc.cs +++ b/PluralKit.Bot/Commands/Misc.cs @@ -84,7 +84,7 @@ namespace PluralKit.Bot { var totalSwitches = _metrics.Snapshot.GetForContext("Application").Gauges.FirstOrDefault(m => m.MultidimensionalName == CoreMetrics.SwitchCount.Name)?.Value ?? 0; var totalMessages = _metrics.Snapshot.GetForContext("Application").Gauges.FirstOrDefault(m => m.MultidimensionalName == CoreMetrics.MessageCount.Name)?.Value ?? 0; - var shardId = ctx.Shard.ShardInfo.ShardId; + var shardId = ctx.Shard.ShardId; var shardTotal = ctx.Cluster.Shards.Count; var shardUpTotal = _shards.Shards.Where(x => x.Connected).Count(); var shardInfo = _shards.GetShardInfo(ctx.Shard); diff --git a/PluralKit.Bot/Handlers/MessageCreated.cs b/PluralKit.Bot/Handlers/MessageCreated.cs index 30ced3ee..0d8da48f 100644 --- a/PluralKit.Bot/Handlers/MessageCreated.cs +++ b/PluralKit.Bot/Handlers/MessageCreated.cs @@ -114,7 +114,7 @@ namespace PluralKit.Bot try { var system = ctx.SystemId != null ? await _db.Execute(c => _repo.GetSystem(c, ctx.SystemId.Value)) : null; - await _tree.ExecuteCommand(new Context(_services, shard, guild, channel, evt, cmdStart, system, ctx, _bot.PermissionsIn(channel.Id))); + await _tree.ExecuteCommand(new Context(_services, shard, guild, channel, evt, cmdStart, system, ctx)); } catch (PKError) { diff --git a/PluralKit.Bot/Init.cs b/PluralKit.Bot/Init.cs index f4b7c6f5..222a40b1 100644 --- a/PluralKit.Bot/Init.cs +++ b/PluralKit.Bot/Init.cs @@ -49,7 +49,7 @@ namespace PluralKit.Bot // Start the Discord shards themselves (handlers already set up) logger.Information("Connecting to Discord"); var info = await services.Resolve().GetGatewayBot(); - await services.Resolve().Start(info); + await services.Resolve().Start(info with { Shards = 10 }); logger.Information("Connected! All is good (probably)."); // Lastly, we just... wait. Everything else is handled in the DiscordClient event loop diff --git a/PluralKit.Bot/Services/ShardInfoService.cs b/PluralKit.Bot/Services/ShardInfoService.cs index d35cc299..5fdc2aa1 100644 --- a/PluralKit.Bot/Services/ShardInfoService.cs +++ b/PluralKit.Bot/Services/ShardInfoService.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; using System.Linq; using System.Net.WebSockets; -using System.Threading.Tasks; using App.Metrics; @@ -66,11 +65,8 @@ namespace PluralKit.Bot } else _shardInfo[shard.ShardId] = info = new ShardInfo(); // Call our own SocketOpened listener manually (and then attach the listener properly) - SocketOpened(shard); - shard.SocketOpened += () => SocketOpened(shard); - + // Register listeners for new shards - _logger.Information("Attaching listeners to new shard #{Shard}", shard.ShardId); shard.Resumed += () => Resumed(shard); shard.Ready += () => Ready(shard); shard.SocketClosed += (closeStatus, message) => SocketClosed(shard, closeStatus, message); @@ -78,14 +74,6 @@ namespace PluralKit.Bot // Register that we've seen it info.HasAttachedListeners = true; - - } - - private void SocketOpened(Shard shard) - { - // We do nothing else here, since this kinda doesn't mean *much*? It's only really started once we get Ready/Resumed - // And it doesn't get fired first time around since we don't have time to add the event listener before it's fired' - _logger.Information("Shard #{Shard} opened socket", shard.ShardId); } private ShardInfo TryGetShard(Shard shard) @@ -100,29 +88,22 @@ namespace PluralKit.Bot private void Resumed(Shard shard) { - _logger.Information("Shard #{Shard} resumed connection", shard.ShardId); - - var info = TryGetShard(shard); - // info.LastConnectionTime = SystemClock.Instance.GetCurrentInstant(); - info.Connected = true; - ReportShardStatus(); - } - - private void Ready(Shard shard) - { - _logger.Information("Shard #{Shard} sent Ready event", shard.ShardId); - var info = TryGetShard(shard); info.LastConnectionTime = SystemClock.Instance.GetCurrentInstant(); info.Connected = true; ReportShardStatus(); } - private void SocketClosed(Shard shard, WebSocketCloseStatus closeStatus, string message) + private void Ready(Shard shard) + { + var info = TryGetShard(shard); + info.LastConnectionTime = SystemClock.Instance.GetCurrentInstant(); + info.Connected = true; + ReportShardStatus(); + } + + private void SocketClosed(Shard shard, WebSocketCloseStatus? closeStatus, string message) { - _logger.Warning("Shard #{Shard} disconnected ({CloseCode}: {CloseMessage})", - shard.ShardId, closeStatus, message); - var info = TryGetShard(shard); info.DisconnectionCount++; info.Connected = false; @@ -131,9 +112,6 @@ namespace PluralKit.Bot private void Heartbeated(Shard shard, TimeSpan latency) { - _logger.Information("Shard #{Shard} received heartbeat (latency: {Latency} ms)", - shard.ShardId, latency.Milliseconds); - var info = TryGetShard(shard); info.LastHeartbeatTime = SystemClock.Instance.GetCurrentInstant(); info.Connected = true; diff --git a/PluralKit.Bot/Utils/ContextUtils.cs b/PluralKit.Bot/Utils/ContextUtils.cs index 71da564c..f4c4759c 100644 --- a/PluralKit.Bot/Utils/ContextUtils.cs +++ b/PluralKit.Bot/Utils/ContextUtils.cs @@ -155,6 +155,7 @@ namespace PluralKit.Bot { // "escape hatch", clean up as if we hit X } + // todo: re-check if (ctx.BotPermissions.HasFlag(PermissionSet.ManageMessages)) await ctx.Rest.DeleteAllReactions(msg.ChannelId, msg.Id); } diff --git a/PluralKit.Bot/Utils/SentryUtils.cs b/PluralKit.Bot/Utils/SentryUtils.cs index ed1bf2f5..a43bfad9 100644 --- a/PluralKit.Bot/Utils/SentryUtils.cs +++ b/PluralKit.Bot/Utils/SentryUtils.cs @@ -38,7 +38,7 @@ namespace PluralKit.Bot {"guild", evt.GuildId.ToString()}, {"message", evt.Id.ToString()}, }); - scope.SetTag("shard", shard.ShardInfo.ShardId.ToString()); + scope.SetTag("shard", shard.ShardId.ToString()); // Also report information about the bot's permissions in the channel // We get a lot of permission errors so this'll be useful for determining problems @@ -55,7 +55,7 @@ namespace PluralKit.Bot {"guild", evt.GuildId.ToString()}, {"message", evt.Id.ToString()}, }); - scope.SetTag("shard", shard.ShardInfo.ShardId.ToString()); + scope.SetTag("shard", shard.ShardId.ToString()); } public void Enrich(Scope scope, Shard shard, MessageUpdateEvent evt) @@ -67,7 +67,7 @@ namespace PluralKit.Bot {"guild", evt.GuildId.Value.ToString()}, {"message", evt.Id.ToString()} }); - scope.SetTag("shard", shard.ShardInfo.ShardId.ToString()); + scope.SetTag("shard", shard.ShardId.ToString()); } public void Enrich(Scope scope, Shard shard, MessageDeleteBulkEvent evt) @@ -79,7 +79,7 @@ namespace PluralKit.Bot {"guild", evt.GuildId.ToString()}, {"messages", string.Join(",", evt.Ids)}, }); - scope.SetTag("shard", shard.ShardInfo.ShardId.ToString()); + scope.SetTag("shard", shard.ShardId.ToString()); } public void Enrich(Scope scope, Shard shard, MessageReactionAddEvent evt) @@ -93,7 +93,7 @@ namespace PluralKit.Bot {"message", evt.MessageId.ToString()}, {"reaction", evt.Emoji.Name} }); - scope.SetTag("shard", shard.ShardInfo.ShardId.ToString()); + scope.SetTag("shard", shard.ShardId.ToString()); } } } \ No newline at end of file