Merge branch 'new-shard-handler'

This commit is contained in:
Ske 2021-06-08 10:20:59 +02:00
commit 41f1c58a9f
24 changed files with 723 additions and 538 deletions

View File

@ -1,6 +1,4 @@
using System; using Myriad.Gateway;
using Myriad.Gateway;
using Myriad.Types; using Myriad.Types;
namespace Myriad.Extensions namespace Myriad.Extensions

View File

@ -1,7 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading.Tasks;
using Myriad.Cache; using Myriad.Cache;
using Myriad.Gateway; using Myriad.Gateway;

View File

@ -15,6 +15,7 @@ namespace Myriad.Gateway
private readonly GatewaySettings _gatewaySettings; private readonly GatewaySettings _gatewaySettings;
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly ConcurrentDictionary<int, Shard> _shards = new(); private readonly ConcurrentDictionary<int, Shard> _shards = new();
private ShardIdentifyRatelimiter? _ratelimiter;
public Cluster(GatewaySettings gatewaySettings, ILogger logger) public Cluster(GatewaySettings gatewaySettings, ILogger logger)
{ {
@ -26,81 +27,35 @@ namespace Myriad.Gateway
public event Action<Shard>? ShardCreated; public event Action<Shard>? ShardCreated;
public IReadOnlyDictionary<int, Shard> Shards => _shards; public IReadOnlyDictionary<int, Shard> Shards => _shards;
public ClusterSessionState SessionState => GetClusterState();
public User? User => _shards.Values.Select(s => s.User).FirstOrDefault(s => s != null); public User? User => _shards.Values.Select(s => s.User).FirstOrDefault(s => s != null);
public ApplicationPartial? Application => _shards.Values.Select(s => s.Application).FirstOrDefault(s => s != null); public ApplicationPartial? Application => _shards.Values.Select(s => s.Application).FirstOrDefault(s => s != null);
private ClusterSessionState GetClusterState() public async Task Start(GatewayInfo.Bot info)
{ {
var shards = new List<ClusterSessionState.ShardState>(); var concurrency = GetActualShardConcurrency(info.SessionStartLimit.MaxConcurrency);
foreach (var (id, shard) in _shards) _ratelimiter = new(_logger, concurrency);
shards.Add(new ClusterSessionState.ShardState
{ await Start(info.Url, info.Shards);
Shard = shard.ShardInfo,
Session = shard.SessionInfo
});
return new ClusterSessionState {Shards = shards};
} }
public async Task Start(GatewayInfo.Bot info, ClusterSessionState? lastState = null) public async Task Start(string url, int shardCount)
{
if (lastState != null && lastState.Shards.Count == info.Shards)
await Resume(info.Url, lastState, info.SessionStartLimit.MaxConcurrency);
else
await Start(info.Url, info.Shards, info.SessionStartLimit.MaxConcurrency);
}
public async Task Resume(string url, ClusterSessionState sessionState, int concurrency)
{
_logger.Information("Resuming session with {ShardCount} shards at {Url}", sessionState.Shards.Count, url);
foreach (var shardState in sessionState.Shards)
CreateAndAddShard(url, shardState.Shard, shardState.Session);
await StartShards(concurrency);
}
public async Task Start(string url, int shardCount, int concurrency)
{ {
_logger.Information("Starting {ShardCount} shards at {Url}", shardCount, url); _logger.Information("Starting {ShardCount} shards at {Url}", shardCount, url);
for (var i = 0; i < shardCount; i++) for (var i = 0; i < shardCount; i++)
CreateAndAddShard(url, new ShardInfo(i, shardCount), null); CreateAndAddShard(url, new ShardInfo(i, shardCount));
await StartShards(concurrency); await StartShards();
} }
private async Task StartShards(int concurrency) private async Task StartShards()
{ {
concurrency = GetActualShardConcurrency(concurrency);
var lastTime = DateTimeOffset.UtcNow;
var identifyCalls = 0;
_logger.Information("Connecting shards..."); _logger.Information("Connecting shards...");
foreach (var shard in _shards.Values) foreach (var shard in _shards.Values)
{
if (identifyCalls >= concurrency)
{
var timeout = lastTime + TimeSpan.FromSeconds(5.5);
var delay = timeout - DateTimeOffset.UtcNow;
if (delay > TimeSpan.Zero)
{
_logger.Information("Hit shard concurrency limit, waiting {Delay}", delay);
await Task.Delay(delay);
}
identifyCalls = 0;
lastTime = DateTimeOffset.UtcNow;
}
await shard.Start(); await shard.Start();
identifyCalls++;
}
} }
private void CreateAndAddShard(string url, ShardInfo shardInfo, ShardSessionInfo? session) private void CreateAndAddShard(string url, ShardInfo shardInfo)
{ {
var shard = new Shard(_logger, new Uri(url), _gatewaySettings, shardInfo, session); var shard = new Shard(_gatewaySettings, shardInfo, _ratelimiter!, url, _logger);
shard.OnEventReceived += evt => OnShardEventReceived(shard, evt); shard.OnEventReceived += evt => OnShardEventReceived(shard, evt);
_shards[shardInfo.ShardId] = shard; _shards[shardInfo.ShardId] = shard;

View File

@ -1,15 +0,0 @@
using System.Collections.Generic;
namespace Myriad.Gateway
{
public record ClusterSessionState
{
public List<ShardState> Shards { get; init; }
public record ShardState
{
public ShardInfo Shard { get; init; }
public ShardSessionInfo Session { get; init; }
}
}
}

View File

@ -1,6 +1,4 @@
using System.Collections.Generic; using Myriad.Types;
using Myriad.Types;
namespace Myriad.Gateway namespace Myriad.Gateway
{ {

View File

@ -3,6 +3,7 @@ using System.Net.WebSockets;
using System.Text.Json; using System.Text.Json;
using System.Threading.Tasks; using System.Threading.Tasks;
using Myriad.Gateway.State;
using Myriad.Serialization; using Myriad.Serialization;
using Myriad.Types; using Myriad.Types;
@ -10,340 +11,192 @@ using Serilog;
namespace Myriad.Gateway namespace Myriad.Gateway
{ {
public class Shard: IAsyncDisposable public class Shard
{ {
private const string LibraryName = "Myriad (for PluralKit)"; private const string LibraryName = "Myriad (for PluralKit)";
private readonly JsonSerializerOptions _jsonSerializerOptions = private readonly GatewaySettings _settings;
new JsonSerializerOptions().ConfigureForMyriad(); private readonly ShardInfo _info;
private readonly ShardIdentifyRatelimiter _ratelimiter;
private readonly string _url;
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly Uri _uri; private readonly ShardStateManager _stateManager;
private readonly JsonSerializerOptions _jsonSerializerOptions;
private ShardConnection? _conn; private readonly ShardConnection _conn;
private TimeSpan? _currentHeartbeatInterval;
private bool _hasReceivedAck;
private DateTimeOffset? _lastHeartbeatSent;
private Task _worker;
public ShardInfo ShardInfo { get; private set; }
public int ShardId => ShardInfo.ShardId;
public GatewaySettings Settings { get; }
public ShardSessionInfo SessionInfo { get; private set; }
public ShardState State { get; private set; }
public TimeSpan? Latency { get; private set; }
public User? User { get; private set; }
public ApplicationPartial? Application { get; private set; }
public Func<IGatewayEvent, Task>? OnEventReceived { get; set; } public int ShardId => _info.ShardId;
public ShardState State => _stateManager.State;
public TimeSpan? Latency => _stateManager.Latency;
public User? User => _stateManager.User;
public ApplicationPartial? Application => _stateManager.Application;
// TODO: I wanna get rid of these or move them at some point
public event Func<IGatewayEvent, Task>? OnEventReceived;
public event Action<TimeSpan>? HeartbeatReceived; public event Action<TimeSpan>? HeartbeatReceived;
public event Action? SocketOpened; public event Action? SocketOpened;
public event Action? Resumed; public event Action? Resumed;
public event Action? Ready; public event Action? Ready;
public event Action<WebSocketCloseStatus, string?>? SocketClosed; public event Action<WebSocketCloseStatus?, string?>? SocketClosed;
private TimeSpan _reconnectDelay = TimeSpan.Zero;
private Task? _worker;
public Shard(GatewaySettings settings, ShardInfo info, ShardIdentifyRatelimiter ratelimiter, string url, ILogger logger)
{
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
_settings = settings;
_info = info;
_ratelimiter = ratelimiter;
_url = url;
_logger = logger;
_stateManager = new ShardStateManager(info, _jsonSerializerOptions, logger)
{
HandleEvent = HandleEvent,
SendHeartbeat = SendHeartbeat,
SendIdentify = SendIdentify,
SendResume = SendResume,
Connect = ConnectInner,
Reconnect = Reconnect,
};
_stateManager.OnHeartbeatReceived += latency =>
{
HeartbeatReceived?.Invoke(latency);
};
_conn = new ShardConnection(_jsonSerializerOptions, _logger);
}
public Shard(ILogger logger, Uri uri, GatewaySettings settings, ShardInfo info, private async Task ShardLoop()
ShardSessionInfo? sessionInfo = null)
{ {
_logger = logger.ForContext<Shard>(); while (true)
_uri = uri; {
try
{
await ConnectInner();
await HandleConnectionOpened();
Settings = settings; while (_conn.State == WebSocketState.Open)
ShardInfo = info; {
SessionInfo = sessionInfo ?? new ShardSessionInfo(); var packet = await _conn.Read();
if (packet == null)
break;
await _stateManager.HandlePacketReceived(packet);
}
await HandleConnectionClosed(_conn.CloseStatus, _conn.CloseStatusDescription);
_logger.Information("Shard {ShardId}: Reconnecting after delay {ReconnectDelay}",
_info.ShardId, _reconnectDelay);
if (_reconnectDelay > TimeSpan.Zero)
await Task.Delay(_reconnectDelay);
}
catch (Exception e)
{
_logger.Error(e, "Shard {ShardId}: Error in main shard loop, reconnecting in 5 seconds...", _info.ShardId);
// todo: exponential backoff here? this should never happen, ideally...
await Task.Delay(TimeSpan.FromSeconds(5));
}
}
} }
public async ValueTask DisposeAsync()
{
if (_conn != null)
await _conn.DisposeAsync();
}
public Task Start() public Task Start()
{ {
_worker = MainLoop(); if (_worker == null)
_worker = ShardLoop();
return Task.CompletedTask; return Task.CompletedTask;
} }
public async Task UpdateStatus(GatewayStatusUpdate payload) public async Task UpdateStatus(GatewayStatusUpdate payload)
{ {
if (_conn != null && _conn.State == WebSocketState.Open) await _conn.Send(new GatewayPacket
await _conn!.Send(new GatewayPacket {Opcode = GatewayOpcode.PresenceUpdate, Payload = payload});
}
private async Task MainLoop()
{
while (true)
try
{
_logger.Information("Shard {ShardId}: Connecting...", ShardId);
State = ShardState.Connecting;
await Connect();
_logger.Information("Shard {ShardId}: Connected. Entering main loop...", ShardId);
// Tick returns false if we need to stop and reconnect
while (await Tick(_conn!))
await Task.Delay(TimeSpan.FromMilliseconds(1000));
_logger.Information("Shard {ShardId}: Connection closed, reconnecting...", ShardId);
State = ShardState.Closed;
}
catch (Exception e)
{
_logger.Error(e, "Shard {ShardId}: Error in shard state handler", ShardId);
}
}
private async Task<bool> Tick(ShardConnection conn)
{
if (conn.State != WebSocketState.Connecting && conn.State != WebSocketState.Open)
return false;
if (!await TickHeartbeat(conn))
// TickHeartbeat returns false if we're disconnecting
return false;
return true;
}
private async Task<bool> TickHeartbeat(ShardConnection conn)
{
// If we don't need to heartbeat, do nothing
if (_lastHeartbeatSent == null || _currentHeartbeatInterval == null)
return true;
if (DateTimeOffset.UtcNow - _lastHeartbeatSent < _currentHeartbeatInterval)
return true;
// If we haven't received the ack in time, close w/ error
if (!_hasReceivedAck)
{ {
_logger.Warning( Opcode = GatewayOpcode.PresenceUpdate,
"Shard {ShardId}: Did not receive heartbeat Ack from gateway within interval ({HeartbeatInterval})", Payload = payload
ShardId, _currentHeartbeatInterval); });
State = ShardState.Closing;
await conn.Disconnect(WebSocketCloseStatus.ProtocolError, "Did not receive ACK in time");
return false;
}
// Otherwise just send it :)
await SendHeartbeat(conn);
_hasReceivedAck = false;
return true;
} }
private async Task SendHeartbeat(ShardConnection conn) private async Task ConnectInner()
{ {
_logger.Debug("Shard {ShardId}: Sending heartbeat with seq.no. {LastSequence}", await _ratelimiter.Acquire(_info.ShardId);
ShardId, SessionInfo.LastSequence);
await conn.Send(new GatewayPacket {Opcode = GatewayOpcode.Heartbeat, Payload = SessionInfo.LastSequence}); _logger.Information("Shard {ShardId}: Connecting to WebSocket", _info.ShardId);
_lastHeartbeatSent = DateTimeOffset.UtcNow; await _conn.Connect(_url, default);
} }
private async Task Connect() private async Task DisconnectInner(WebSocketCloseStatus closeStatus)
{ {
if (_conn != null) await _conn.Disconnect(closeStatus, null);
await _conn.DisposeAsync();
_currentHeartbeatInterval = null;
_conn = new ShardConnection(_uri, _logger, _jsonSerializerOptions)
{
OnReceive = OnReceive,
OnOpen = () => SocketOpened?.Invoke(),
OnClose = (closeStatus, message) => SocketClosed?.Invoke(closeStatus, message)
};
} }
private async Task OnReceive(GatewayPacket packet)
{
switch (packet.Opcode)
{
case GatewayOpcode.Hello:
{
await HandleHello((JsonElement) packet.Payload!);
break;
}
case GatewayOpcode.Heartbeat:
{
_logger.Debug("Shard {ShardId}: Received heartbeat request from shard, sending Ack", ShardId);
await _conn!.Send(new GatewayPacket {Opcode = GatewayOpcode.HeartbeatAck});
break;
}
case GatewayOpcode.HeartbeatAck:
{
Latency = DateTimeOffset.UtcNow - _lastHeartbeatSent;
_logger.Debug("Shard {ShardId}: Received heartbeat Ack with latency {Latency}", ShardId, Latency);
if (Latency != null)
HeartbeatReceived?.Invoke(Latency!.Value);
_hasReceivedAck = true;
break;
}
case GatewayOpcode.Reconnect:
{
_logger.Information("Shard {ShardId}: Received Reconnect, closing and reconnecting", ShardId);
await _conn!.Disconnect(WebSocketCloseStatus.Empty, null);
break;
}
case GatewayOpcode.InvalidSession:
{
var canResume = ((JsonElement) packet.Payload!).GetBoolean();
// Clear session info before DCing
if (!canResume)
SessionInfo = SessionInfo with { Session = null };
var delay = TimeSpan.FromMilliseconds(new Random().Next(1000, 5000));
_logger.Information(
"Shard {ShardId}: Received Invalid Session (can resume? {CanResume}), reconnecting after {ReconnectDelay}",
ShardId, canResume, delay);
await _conn!.Disconnect(WebSocketCloseStatus.Empty, null);
// Will reconnect after exiting this "loop"
await Task.Delay(delay);
break;
}
case GatewayOpcode.Dispatch:
{
SessionInfo = SessionInfo with { LastSequence = packet.Sequence };
var evt = DeserializeEvent(packet.EventType!, (JsonElement) packet.Payload!)!;
if (evt is ReadyEvent rdy)
{
if (State == ShardState.Connecting)
await HandleReady(rdy);
else
_logger.Warning("Shard {ShardId}: Received Ready event in unexpected state {ShardState}, ignoring?",
ShardId, State);
}
else if (evt is ResumedEvent)
{
if (State == ShardState.Connecting)
await HandleResumed();
else
_logger.Warning("Shard {ShardId}: Received Resumed event in unexpected state {ShardState}, ignoring?",
ShardId, State);
}
await HandleEvent(evt);
break;
}
default:
{
_logger.Debug("Shard {ShardId}: Received unknown gateway opcode {Opcode}", ShardId, packet.Opcode);
break;
}
}
}
private async Task HandleEvent(IGatewayEvent evt)
{
if (OnEventReceived != null)
await OnEventReceived.Invoke(evt);
}
private IGatewayEvent? DeserializeEvent(string eventType, JsonElement data)
{
if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType))
{
_logger.Information("Shard {ShardId}: Received unknown event type {EventType}", ShardId, eventType);
return null;
}
try
{
_logger.Verbose("Shard {ShardId}: Deserializing {EventType} to {ClrType}", ShardId, eventType, clrType);
return JsonSerializer.Deserialize(data.GetRawText(), clrType, _jsonSerializerOptions)
as IGatewayEvent;
}
catch (JsonException e)
{
_logger.Error(e, "Shard {ShardId}: Error deserializing event {EventType} to {ClrType}", ShardId, eventType, clrType);
return null;
}
}
private Task HandleReady(ReadyEvent ready)
{
// TODO: when is ready.Shard ever null?
ShardInfo = ready.Shard ?? new ShardInfo(0, 0);
SessionInfo = SessionInfo with { Session = ready.SessionId };
User = ready.User;
Application = ready.Application;
State = ShardState.Open;
Ready?.Invoke();
return Task.CompletedTask;
}
private Task HandleResumed()
{
State = ShardState.Open;
Resumed?.Invoke();
return Task.CompletedTask;
}
private async Task HandleHello(JsonElement json)
{
var hello = JsonSerializer.Deserialize<GatewayHello>(json.GetRawText(), _jsonSerializerOptions)!;
_logger.Debug("Shard {ShardId}: Received Hello with interval {Interval} ms", ShardId, hello.HeartbeatInterval);
_currentHeartbeatInterval = TimeSpan.FromMilliseconds(hello.HeartbeatInterval);
await SendHeartbeat(_conn!);
await SendIdentifyOrResume();
}
private async Task SendIdentifyOrResume()
{
if (SessionInfo.Session != null && SessionInfo.LastSequence != null)
await SendResume(SessionInfo.Session, SessionInfo.LastSequence!.Value);
else
await SendIdentify();
}
private async Task SendIdentify() private async Task SendIdentify()
{ {
_logger.Information("Shard {ShardId}: Sending gateway Identify for shard {@ShardInfo}", ShardId, ShardInfo); await _conn.Send(new GatewayPacket
await _conn!.Send(new GatewayPacket
{ {
Opcode = GatewayOpcode.Identify, Opcode = GatewayOpcode.Identify,
Payload = new GatewayIdentify Payload = new GatewayIdentify
{ {
Token = Settings.Token, Compress = false,
Intents = _settings.Intents,
Properties = new GatewayIdentify.ConnectionProperties Properties = new GatewayIdentify.ConnectionProperties
{ {
Browser = LibraryName, Device = LibraryName, Os = Environment.OSVersion.ToString() Browser = LibraryName,
Device = LibraryName,
Os = Environment.OSVersion.ToString()
}, },
Intents = Settings.Intents, Shard = _info,
Shard = ShardInfo Token = _settings.Token,
LargeThreshold = 50
} }
}); });
} }
private async Task SendResume(string session, int lastSequence) private async Task SendResume((string SessionId, int? LastSeq) arg)
{ {
_logger.Information("Shard {ShardId}: Sending gateway Resume for session {@SessionInfo}", await _conn.Send(new GatewayPacket
ShardId, SessionInfo);
await _conn!.Send(new GatewayPacket
{ {
Opcode = GatewayOpcode.Resume, Opcode = GatewayOpcode.Resume,
Payload = new GatewayResume(Settings.Token, session, lastSequence) Payload = new GatewayResume(_settings.Token, arg.SessionId, arg.LastSeq ?? 0)
}); });
} }
public enum ShardState private async Task SendHeartbeat(int? lastSeq)
{ {
Closed, await _conn.Send(new GatewayPacket {Opcode = GatewayOpcode.Heartbeat, Payload = lastSeq});
Connecting, }
Open,
Closing private async Task Reconnect(WebSocketCloseStatus closeStatus, TimeSpan delay)
{
_reconnectDelay = delay;
await DisconnectInner(closeStatus);
}
private async Task HandleEvent(IGatewayEvent arg)
{
if (arg is ReadyEvent)
Ready?.Invoke();
if (arg is ResumedEvent)
Resumed?.Invoke();
await (OnEventReceived?.Invoke(arg) ?? Task.CompletedTask);
}
private async Task HandleConnectionOpened()
{
_logger.Information("Shard {ShardId}: Connection opened", _info.ShardId);
await _stateManager.HandleConnectionOpened();
SocketOpened?.Invoke();
}
private async Task HandleConnectionClosed(WebSocketCloseStatus? closeStatus, string? description)
{
_logger.Information("Shard {ShardId}: Connection closed ({CloseStatus}/{Description})",
_info.ShardId, closeStatus, description ?? "<null>");
await _stateManager.HandleConnectionClosed();
SocketClosed?.Invoke(closeStatus, description);
} }
} }
} }

View File

@ -1,6 +1,4 @@
using System; using System;
using System.Buffers;
using System.IO;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Text.Json; using System.Text.Json;
using System.Threading; using System.Threading;
@ -12,120 +10,95 @@ namespace Myriad.Gateway
{ {
public class ShardConnection: IAsyncDisposable public class ShardConnection: IAsyncDisposable
{ {
private readonly MemoryStream _bufStream = new(); private ClientWebSocket? _client;
private readonly ClientWebSocket _client = new();
private readonly CancellationTokenSource _cts = new();
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly Task _worker; private readonly ShardPacketSerializer _serializer;
public ShardConnection(Uri uri, ILogger logger, JsonSerializerOptions jsonSerializerOptions)
{
_logger = logger;
_jsonSerializerOptions = jsonSerializerOptions;
_worker = Worker(uri);
}
public Func<GatewayPacket, Task>? OnReceive { get; set; }
public Action? OnOpen { get; set; }
public Action<WebSocketCloseStatus, string?>? OnClose { get; set; }
public WebSocketState State => _client.State;
public async ValueTask DisposeAsync() public WebSocketState State => _client?.State ?? WebSocketState.Closed;
public WebSocketCloseStatus? CloseStatus => _client?.CloseStatus;
public string? CloseStatusDescription => _client?.CloseStatusDescription;
public ShardConnection(JsonSerializerOptions jsonSerializerOptions, ILogger logger)
{ {
_cts.Cancel(); _logger = logger.ForContext<ShardConnection>();
await _worker; _serializer = new(jsonSerializerOptions);
_client.Dispose();
await _bufStream.DisposeAsync();
_cts.Dispose();
} }
private async Task Worker(Uri uri) public async Task Connect(string url, CancellationToken ct)
{ {
var realUrl = new UriBuilder(uri) _client?.Dispose();
{ _client = new ClientWebSocket();
Query = "v=8&encoding=json"
}.Uri;
_logger.Debug("Connecting to gateway WebSocket at {GatewayUrl}", realUrl);
await _client.ConnectAsync(realUrl, default);
_logger.Debug("Gateway connection opened");
OnOpen?.Invoke(); await _client.ConnectAsync(GetConnectionUri(url), ct);
// Main worker loop, spins until we manually disconnect (which hits the cancellation token)
// or the server disconnects us (which sets state to closed)
while (!_cts.IsCancellationRequested && _client.State == WebSocketState.Open)
{
try
{
await HandleReceive();
}
catch (Exception e)
{
_logger.Error(e, "Error in WebSocket receive worker");
}
}
OnClose?.Invoke(_client.CloseStatus ?? default, _client.CloseStatusDescription);
} }
private async Task HandleReceive() public async Task Disconnect(WebSocketCloseStatus closeStatus, string? reason)
{ {
_bufStream.SetLength(0); await CloseInner(closeStatus, reason);
var result = await ReadData(_bufStream);
var data = _bufStream.GetBuffer().AsMemory(0, (int) _bufStream.Position);
if (result.MessageType == WebSocketMessageType.Text)
await HandleReceiveData(data);
else if (result.MessageType == WebSocketMessageType.Close)
_logger.Information("WebSocket closed by server: {StatusCode} {Reason}", _client.CloseStatus,
_client.CloseStatusDescription);
}
private async Task HandleReceiveData(Memory<byte> data)
{
var packet = JsonSerializer.Deserialize<GatewayPacket>(data.Span, _jsonSerializerOptions)!;
try
{
if (OnReceive != null)
await OnReceive.Invoke(packet);
}
catch (Exception e)
{
_logger.Error(e, "Error in gateway handler for {OpcodeType}", packet.Opcode);
}
}
private async Task<ValueWebSocketReceiveResult> ReadData(MemoryStream stream)
{
// TODO: does this throw if we disconnect mid-read?
using var buf = MemoryPool<byte>.Shared.Rent();
ValueWebSocketReceiveResult result;
do
{
result = await _client.ReceiveAsync(buf.Memory, _cts.Token);
stream.Write(buf.Memory.Span.Slice(0, result.Count));
} while (!result.EndOfMessage);
return result;
} }
public async Task Send(GatewayPacket packet) public async Task Send(GatewayPacket packet)
{ {
var bytes = JsonSerializer.SerializeToUtf8Bytes(packet, _jsonSerializerOptions); if (_client == null || _client.State != WebSocketState.Open)
await _client.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, default); return;
try
{
await _serializer.WritePacket(_client, packet);
}
catch (Exception e)
{
_logger.Error(e, "Error sending WebSocket message");
}
} }
public async Task Disconnect(WebSocketCloseStatus status, string? description) public async ValueTask DisposeAsync()
{ {
await _client.CloseAsync(status, description, default); await CloseInner(WebSocketCloseStatus.NormalClosure, null);
_cts.Cancel(); _client?.Dispose();
}
public async Task<GatewayPacket?> Read()
{
if (_client == null || _client.State != WebSocketState.Open)
return null;
try
{
var (_, packet) = await _serializer.ReadPacket(_client);
return packet;
}
catch (Exception e)
{
_logger.Error(e, "Error reading from WebSocket");
}
return null;
}
private Uri GetConnectionUri(string baseUri) => new UriBuilder(baseUri)
{
Query = "v=8&encoding=json"
}.Uri;
private async Task CloseInner(WebSocketCloseStatus closeStatus, string? description)
{
if (_client == null)
return;
if (_client.State != WebSocketState.Connecting && _client.State != WebSocketState.Open)
return;
// Close with timeout, mostly to work around https://github.com/dotnet/runtime/issues/51590
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
try
{
await _client.CloseAsync(closeStatus, description, cts.Token);
}
catch (Exception e)
{
_logger.Error(e, "Error closing WebSocket connection");
}
} }
} }
} }

View File

@ -0,0 +1,72 @@
using System;
using System.Collections.Concurrent;
using System.Threading.Tasks;
using Serilog;
namespace Myriad.Gateway
{
public class ShardIdentifyRatelimiter
{
private static readonly TimeSpan BucketLength = TimeSpan.FromSeconds(5);
private readonly ConcurrentDictionary<int, ConcurrentQueue<TaskCompletionSource>> _buckets = new();
private readonly int _maxConcurrency;
private Task? _refillTask;
private readonly ILogger _logger;
public ShardIdentifyRatelimiter(ILogger logger, int maxConcurrency)
{
_logger = logger.ForContext<ShardIdentifyRatelimiter>();
_maxConcurrency = maxConcurrency;
}
public Task Acquire(int shard)
{
var bucket = shard % _maxConcurrency;
var queue = _buckets.GetOrAdd(bucket, _ => new ConcurrentQueue<TaskCompletionSource>());
var tcs = new TaskCompletionSource();
queue.Enqueue(tcs);
ScheduleRefill();
return tcs.Task;
}
private void ScheduleRefill()
{
if (_refillTask != null && !_refillTask.IsCompleted)
return;
_refillTask?.Dispose();
_refillTask = RefillTask();
}
private async Task RefillTask()
{
await Task.Delay(TimeSpan.FromMilliseconds(250));
while (true)
{
var isClear = true;
foreach (var (bucket, queue) in _buckets)
{
if (!queue.TryDequeue(out var tcs))
continue;
_logger.Information(
"Allowing identify for bucket {BucketId} through ({QueueLength} left in bucket queue)",
bucket, queue.Count);
tcs.SetResult();
isClear = false;
}
if (isClear)
return;
await Task.Delay(BucketLength);
}
}
}
}

View File

@ -0,0 +1,70 @@
using System;
using System.Buffers;
using System.IO;
using System.Net.WebSockets;
using System.Text.Json;
using System.Threading.Tasks;
namespace Myriad.Gateway
{
public class ShardPacketSerializer
{
private const int BufferSize = 64 * 1024;
private readonly JsonSerializerOptions _jsonSerializerOptions;
public ShardPacketSerializer(JsonSerializerOptions jsonSerializerOptions)
{
_jsonSerializerOptions = jsonSerializerOptions;
}
public async ValueTask<(WebSocketMessageType type, GatewayPacket? packet)> ReadPacket(ClientWebSocket socket)
{
using var buf = MemoryPool<byte>.Shared.Rent(BufferSize);
var res = await socket.ReceiveAsync(buf.Memory, default);
if (res.MessageType == WebSocketMessageType.Close)
return (res.MessageType, null);
if (res.EndOfMessage)
// Entire packet fits within one buffer, deserialize directly
return DeserializeSingleBuffer(buf, res);
// Otherwise copy to stream buffer and deserialize from there
return await DeserializeMultipleBuffer(socket, buf, res);
}
public async Task WritePacket(ClientWebSocket socket, GatewayPacket packet)
{
var bytes = JsonSerializer.SerializeToUtf8Bytes(packet, _jsonSerializerOptions);
await socket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, default);
}
private async Task<(WebSocketMessageType type, GatewayPacket packet)> DeserializeMultipleBuffer(ClientWebSocket socket, IMemoryOwner<byte> buf, ValueWebSocketReceiveResult res)
{
await using var stream = new MemoryStream(BufferSize * 4);
stream.Write(buf.Memory.Span.Slice(0, res.Count));
while (!res.EndOfMessage)
{
res = await socket.ReceiveAsync(buf.Memory, default);
stream.Write(buf.Memory.Span.Slice(0, res.Count));
}
return DeserializeObject(res, stream.GetBuffer().AsSpan(0, (int) stream.Length));
}
private (WebSocketMessageType type, GatewayPacket packet) DeserializeSingleBuffer(
IMemoryOwner<byte> buf, ValueWebSocketReceiveResult res)
{
var span = buf.Memory.Span.Slice(0, res.Count);
return DeserializeObject(res, span);
}
private (WebSocketMessageType type, GatewayPacket packet) DeserializeObject(ValueWebSocketReceiveResult res, Span<byte> span)
{
var packet = JsonSerializer.Deserialize<GatewayPacket>(span, _jsonSerializerOptions)!;
return (res.MessageType, packet);
}
}
}

View File

@ -1,8 +0,0 @@
namespace Myriad.Gateway
{
public record ShardSessionInfo
{
public string? Session { get; init; }
public int? LastSequence { get; init; }
}
}

View File

@ -0,0 +1,63 @@
using System;
using System.Threading;
using System.Threading.Tasks;
namespace Myriad.Gateway.State
{
public class HeartbeatWorker: IAsyncDisposable
{
private Task? _worker;
private CancellationTokenSource? _workerCts;
public TimeSpan? CurrentHeartbeatInterval { get; private set; }
public async ValueTask Start(TimeSpan heartbeatInterval, Func<Task> callback)
{
if (_worker != null)
await Stop();
CurrentHeartbeatInterval = heartbeatInterval;
_workerCts = new CancellationTokenSource();
_worker = Worker(heartbeatInterval, callback, _workerCts.Token);
}
public async ValueTask Stop()
{
if (_worker == null)
return;
_workerCts?.Cancel();
try
{
await _worker;
}
catch (TaskCanceledException) { }
_worker?.Dispose();
_workerCts?.Dispose();
_worker = null;
CurrentHeartbeatInterval = null;
}
private async Task Worker(TimeSpan heartbeatInterval, Func<Task> callback, CancellationToken ct)
{
var initialDelay = GetInitialHeartbeatDelay(heartbeatInterval);
await Task.Delay(initialDelay, ct);
while (!ct.IsCancellationRequested)
{
await callback();
await Task.Delay(heartbeatInterval, ct);
}
}
private static TimeSpan GetInitialHeartbeatDelay(TimeSpan heartbeatInterval) =>
// Docs specify `heartbeat_interval * random.random()` but we'll add a lil buffer :)
heartbeatInterval * (new Random().NextDouble() * 0.9 + 0.05);
public async ValueTask DisposeAsync()
{
await Stop();
}
}
}

View File

@ -0,0 +1,11 @@
namespace Myriad.Gateway.State
{
public enum ShardState
{
Disconnected,
Handshaking,
Identifying,
Connected,
Reconnecting
}
}

View File

@ -0,0 +1,244 @@
using System;
using System.Net.WebSockets;
using System.Text.Json;
using System.Threading.Tasks;
using Myriad.Gateway.State;
using Myriad.Types;
using Serilog;
namespace Myriad.Gateway
{
public class ShardStateManager
{
private readonly HeartbeatWorker _heartbeatWorker = new();
private readonly ILogger _logger;
private readonly ShardInfo _info;
private readonly JsonSerializerOptions _jsonSerializerOptions;
private ShardState _state = ShardState.Disconnected;
private DateTimeOffset? _lastHeartbeatSent;
private TimeSpan? _latency;
private bool _hasReceivedHeartbeatAck;
private string? _sessionId;
private int? _lastSeq;
public ShardState State => _state;
public TimeSpan? Latency => _latency;
public User? User { get; private set; }
public ApplicationPartial? Application { get; private set; }
public Func<Task> SendIdentify { get; init; }
public Func<(string SessionId, int? LastSeq), Task> SendResume { get; init; }
public Func<int?, Task> SendHeartbeat { get; init; }
public Func<WebSocketCloseStatus, TimeSpan, Task> Reconnect { get; init; }
public Func<Task> Connect { get; init; }
public Func<IGatewayEvent, Task> HandleEvent { get; init; }
public event Action<TimeSpan> OnHeartbeatReceived;
public ShardStateManager(ShardInfo info, JsonSerializerOptions jsonSerializerOptions, ILogger logger)
{
_info = info;
_jsonSerializerOptions = jsonSerializerOptions;
_logger = logger.ForContext<ShardStateManager>();
}
public Task HandleConnectionOpened()
{
_state = ShardState.Handshaking;
return Task.CompletedTask;
}
public async Task HandleConnectionClosed()
{
_latency = null;
await _heartbeatWorker.Stop();
}
public async Task HandlePacketReceived(GatewayPacket packet)
{
switch (packet.Opcode)
{
case GatewayOpcode.Hello:
var hello = DeserializePayload<GatewayHello>(packet);
await HandleHello(hello);
break;
case GatewayOpcode.Heartbeat:
await HandleHeartbeatRequest();
break;
case GatewayOpcode.HeartbeatAck:
await HandleHeartbeatAck();
break;
case GatewayOpcode.Reconnect:
{
await HandleReconnect();
break;
}
case GatewayOpcode.InvalidSession:
{
var canResume = DeserializePayload<bool>(packet);
await HandleInvalidSession(canResume);
break;
}
case GatewayOpcode.Dispatch:
_lastSeq = packet.Sequence;
var evt = DeserializeEvent(packet.EventType!, (JsonElement) packet.Payload!);
if (evt != null)
{
if (evt is ReadyEvent ready)
await HandleReady(ready);
if (evt is ResumedEvent)
await HandleResumed();
await HandleEvent(evt);
}
break;
}
}
private async Task HandleHello(GatewayHello hello)
{
var interval = TimeSpan.FromMilliseconds(hello.HeartbeatInterval);
_hasReceivedHeartbeatAck = true;
await _heartbeatWorker.Start(interval, HandleHeartbeatTimer);
await IdentifyOrResume();
}
private async Task IdentifyOrResume()
{
_state = ShardState.Identifying;
if (_sessionId != null)
{
_logger.Information("Shard {ShardId}: Received Hello, attempting to resume (seq {LastSeq})",
_info.ShardId, _lastSeq);
await SendResume((_sessionId!, _lastSeq));
}
else
{
_logger.Information("Shard {ShardId}: Received Hello, identifying",
_info.ShardId);
await SendIdentify();
}
}
private Task HandleHeartbeatAck()
{
_hasReceivedHeartbeatAck = true;
_latency = DateTimeOffset.UtcNow - _lastHeartbeatSent;
OnHeartbeatReceived?.Invoke(_latency!.Value);
_logger.Information("Shard {ShardId}: Received Heartbeat (latency {Latency} ms)",
_info.ShardId, _latency);
return Task.CompletedTask;
}
private async Task HandleInvalidSession(bool canResume)
{
if (!canResume)
{
_sessionId = null;
_lastSeq = null;
}
_logger.Information("Shard {ShardId}: Received Invalid Session (can resume? {CanResume})",
_info.ShardId, canResume);
var delay = TimeSpan.FromMilliseconds(new Random().Next(1000, 5000));
await DoReconnect(WebSocketCloseStatus.NormalClosure, delay);
}
private async Task HandleReconnect()
{
_logger.Information("Shard {ShardId}: Received Reconnect", _info.ShardId);
await DoReconnect(WebSocketCloseStatus.NormalClosure, TimeSpan.FromSeconds(1));
}
private Task HandleReady(ReadyEvent ready)
{
_logger.Information("Shard {ShardId}: Received Ready", _info.ShardId);
_sessionId = ready.SessionId;
_state = ShardState.Connected;
User = ready.User;
Application = ready.Application;
return Task.CompletedTask;
}
private Task HandleResumed()
{
_logger.Information("Shard {ShardId}: Received Resume", _info.ShardId);
_state = ShardState.Connected;
return Task.CompletedTask;
}
private async Task HandleHeartbeatRequest()
{
await SendHeartbeatInternal();
}
private async Task SendHeartbeatInternal()
{
await SendHeartbeat(_lastSeq);
_lastHeartbeatSent = DateTimeOffset.UtcNow;
}
private async Task HandleHeartbeatTimer()
{
if (!_hasReceivedHeartbeatAck)
{
_logger.Warning("Shard {ShardId}: Heartbeat worker timed out", _info.ShardId);
await DoReconnect(WebSocketCloseStatus.ProtocolError, TimeSpan.Zero);
return;
}
await SendHeartbeatInternal();
}
private async Task DoReconnect(WebSocketCloseStatus closeStatus, TimeSpan delay)
{
_state = ShardState.Reconnecting;
await Reconnect(closeStatus, delay);
}
private T DeserializePayload<T>(GatewayPacket packet)
{
var packetPayload = (JsonElement) packet.Payload!;
return JsonSerializer.Deserialize<T>(packetPayload.GetRawText(), _jsonSerializerOptions)!;
}
private IGatewayEvent? DeserializeEvent(string eventType, JsonElement payload)
{
if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType))
{
_logger.Debug("Shard {ShardId}: Received unknown event type {EventType}", _info.ShardId, eventType);
return null;
}
try
{
_logger.Verbose("Shard {ShardId}: Deserializing {EventType} to {ClrType}", _info.ShardId, eventType, clrType);
return JsonSerializer.Deserialize(payload.GetRawText(), clrType, _jsonSerializerOptions)
as IGatewayEvent;
}
catch (JsonException e)
{
_logger.Error(e, "Shard {ShardId}: Error deserializing event {EventType} to {ClrType}", _info.ShardId, eventType, clrType);
return null;
}
}
}
}

View File

@ -69,12 +69,14 @@ namespace Myriad.Rest.Ratelimit
private void PruneStaleBuckets(DateTimeOffset now) private void PruneStaleBuckets(DateTimeOffset now)
{ {
foreach (var (key, bucket) in _buckets) foreach (var (key, bucket) in _buckets)
if (now - bucket.LastUsed > StaleBucketTimeout) {
{ if (now - bucket.LastUsed <= StaleBucketTimeout)
_logger.Debug("Pruning unused bucket {Bucket} (last used at {BucketLastUsed})", bucket, continue;
bucket.LastUsed);
_buckets.TryRemove(key, out _); _logger.Debug("Pruning unused bucket {BucketKey}/{BucketMajor} (last used at {BucketLastUsed})",
} bucket.Key, bucket.Major, bucket.LastUsed);
_buckets.TryRemove(key, out _);
}
} }
} }
} }

View File

@ -1,6 +1,4 @@
using System.Collections.Generic; namespace Myriad.Types
namespace Myriad.Types
{ {
public record Application: ApplicationPartial public record Application: ApplicationPartial
{ {

View File

@ -1,6 +1,4 @@
using System.Collections.Generic; namespace Myriad.Types
namespace Myriad.Types
{ {
public record Guild public record Guild
{ {

View File

@ -287,7 +287,7 @@ namespace PluralKit.Bot
{ {
new ActivityPartial new ActivityPartial
{ {
Name = $"pk;help | in {totalGuilds:N0} servers | shard #{shard.ShardInfo?.ShardId}", Name = $"pk;help | in {totalGuilds:N0} servers | shard #{shard.ShardId}",
Type = ActivityType.Game, Type = ActivityType.Game,
Url = "https://pluralkit.me/" Url = "https://pluralkit.me/"
} }

View File

@ -29,8 +29,6 @@ namespace PluralKit.Bot
private readonly MessageCreateEvent _message; private readonly MessageCreateEvent _message;
private readonly Parameters _parameters; private readonly Parameters _parameters;
private readonly MessageContext _messageContext; private readonly MessageContext _messageContext;
private readonly PermissionSet _botPermissions;
private readonly PermissionSet _userPermissions;
private readonly IDatabase _db; private readonly IDatabase _db;
private readonly ModelRepository _repo; private readonly ModelRepository _repo;
@ -42,7 +40,7 @@ namespace PluralKit.Bot
private Command _currentCommand; private Command _currentCommand;
public Context(ILifetimeScope provider, Shard shard, Guild? guild, Channel channel, MessageCreateEvent message, int commandParseOffset, public Context(ILifetimeScope provider, Shard shard, Guild? guild, Channel channel, MessageCreateEvent message, int commandParseOffset,
PKSystem senderSystem, MessageContext messageContext, PermissionSet botPermissions) PKSystem senderSystem, MessageContext messageContext)
{ {
_message = message; _message = message;
_shard = shard; _shard = shard;
@ -59,9 +57,6 @@ namespace PluralKit.Bot
_parameters = new Parameters(message.Content?.Substring(commandParseOffset)); _parameters = new Parameters(message.Content?.Substring(commandParseOffset));
_rest = provider.Resolve<DiscordApiClient>(); _rest = provider.Resolve<DiscordApiClient>();
_cluster = provider.Resolve<Cluster>(); _cluster = provider.Resolve<Cluster>();
_botPermissions = botPermissions;
_userPermissions = _cache.PermissionsFor(message);
} }
public IDiscordCache Cache => _cache; public IDiscordCache Cache => _cache;
@ -76,8 +71,8 @@ namespace PluralKit.Bot
public Cluster Cluster => _cluster; public Cluster Cluster => _cluster;
public MessageContext MessageContext => _messageContext; public MessageContext MessageContext => _messageContext;
public PermissionSet BotPermissions => _botPermissions; public PermissionSet BotPermissions => _provider.Resolve<Bot>().PermissionsIn(_channel.Id);
public PermissionSet UserPermissions => _userPermissions; public PermissionSet UserPermissions => _cache.PermissionsFor(_message);
public DiscordApiClient Rest => _rest; public DiscordApiClient Rest => _rest;

View File

@ -84,7 +84,7 @@ namespace PluralKit.Bot {
var totalSwitches = _metrics.Snapshot.GetForContext("Application").Gauges.FirstOrDefault(m => m.MultidimensionalName == CoreMetrics.SwitchCount.Name)?.Value ?? 0; var totalSwitches = _metrics.Snapshot.GetForContext("Application").Gauges.FirstOrDefault(m => m.MultidimensionalName == CoreMetrics.SwitchCount.Name)?.Value ?? 0;
var totalMessages = _metrics.Snapshot.GetForContext("Application").Gauges.FirstOrDefault(m => m.MultidimensionalName == CoreMetrics.MessageCount.Name)?.Value ?? 0; var totalMessages = _metrics.Snapshot.GetForContext("Application").Gauges.FirstOrDefault(m => m.MultidimensionalName == CoreMetrics.MessageCount.Name)?.Value ?? 0;
var shardId = ctx.Shard.ShardInfo.ShardId; var shardId = ctx.Shard.ShardId;
var shardTotal = ctx.Cluster.Shards.Count; var shardTotal = ctx.Cluster.Shards.Count;
var shardUpTotal = _shards.Shards.Where(x => x.Connected).Count(); var shardUpTotal = _shards.Shards.Where(x => x.Connected).Count();
var shardInfo = _shards.GetShardInfo(ctx.Shard); var shardInfo = _shards.GetShardInfo(ctx.Shard);

View File

@ -114,7 +114,7 @@ namespace PluralKit.Bot
try try
{ {
var system = ctx.SystemId != null ? await _db.Execute(c => _repo.GetSystem(c, ctx.SystemId.Value)) : null; var system = ctx.SystemId != null ? await _db.Execute(c => _repo.GetSystem(c, ctx.SystemId.Value)) : null;
await _tree.ExecuteCommand(new Context(_services, shard, guild, channel, evt, cmdStart, system, ctx, _bot.PermissionsIn(channel.Id))); await _tree.ExecuteCommand(new Context(_services, shard, guild, channel, evt, cmdStart, system, ctx));
} }
catch (PKError) catch (PKError)
{ {

View File

@ -49,7 +49,7 @@ namespace PluralKit.Bot
// Start the Discord shards themselves (handlers already set up) // Start the Discord shards themselves (handlers already set up)
logger.Information("Connecting to Discord"); logger.Information("Connecting to Discord");
var info = await services.Resolve<DiscordApiClient>().GetGatewayBot(); var info = await services.Resolve<DiscordApiClient>().GetGatewayBot();
await services.Resolve<Cluster>().Start(info); await services.Resolve<Cluster>().Start(info with { Shards = 10 });
logger.Information("Connected! All is good (probably)."); logger.Information("Connected! All is good (probably).");
// Lastly, we just... wait. Everything else is handled in the DiscordClient event loop // Lastly, we just... wait. Everything else is handled in the DiscordClient event loop

View File

@ -2,7 +2,6 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Threading.Tasks;
using App.Metrics; using App.Metrics;
@ -66,11 +65,8 @@ namespace PluralKit.Bot
} else _shardInfo[shard.ShardId] = info = new ShardInfo(); } else _shardInfo[shard.ShardId] = info = new ShardInfo();
// Call our own SocketOpened listener manually (and then attach the listener properly) // Call our own SocketOpened listener manually (and then attach the listener properly)
SocketOpened(shard);
shard.SocketOpened += () => SocketOpened(shard);
// Register listeners for new shards // Register listeners for new shards
_logger.Information("Attaching listeners to new shard #{Shard}", shard.ShardId);
shard.Resumed += () => Resumed(shard); shard.Resumed += () => Resumed(shard);
shard.Ready += () => Ready(shard); shard.Ready += () => Ready(shard);
shard.SocketClosed += (closeStatus, message) => SocketClosed(shard, closeStatus, message); shard.SocketClosed += (closeStatus, message) => SocketClosed(shard, closeStatus, message);
@ -78,14 +74,6 @@ namespace PluralKit.Bot
// Register that we've seen it // Register that we've seen it
info.HasAttachedListeners = true; info.HasAttachedListeners = true;
}
private void SocketOpened(Shard shard)
{
// We do nothing else here, since this kinda doesn't mean *much*? It's only really started once we get Ready/Resumed
// And it doesn't get fired first time around since we don't have time to add the event listener before it's fired'
_logger.Information("Shard #{Shard} opened socket", shard.ShardId);
} }
private ShardInfo TryGetShard(Shard shard) private ShardInfo TryGetShard(Shard shard)
@ -100,29 +88,22 @@ namespace PluralKit.Bot
private void Resumed(Shard shard) private void Resumed(Shard shard)
{ {
_logger.Information("Shard #{Shard} resumed connection", shard.ShardId);
var info = TryGetShard(shard);
// info.LastConnectionTime = SystemClock.Instance.GetCurrentInstant();
info.Connected = true;
ReportShardStatus();
}
private void Ready(Shard shard)
{
_logger.Information("Shard #{Shard} sent Ready event", shard.ShardId);
var info = TryGetShard(shard); var info = TryGetShard(shard);
info.LastConnectionTime = SystemClock.Instance.GetCurrentInstant(); info.LastConnectionTime = SystemClock.Instance.GetCurrentInstant();
info.Connected = true; info.Connected = true;
ReportShardStatus(); ReportShardStatus();
} }
private void SocketClosed(Shard shard, WebSocketCloseStatus closeStatus, string message) private void Ready(Shard shard)
{
var info = TryGetShard(shard);
info.LastConnectionTime = SystemClock.Instance.GetCurrentInstant();
info.Connected = true;
ReportShardStatus();
}
private void SocketClosed(Shard shard, WebSocketCloseStatus? closeStatus, string message)
{ {
_logger.Warning("Shard #{Shard} disconnected ({CloseCode}: {CloseMessage})",
shard.ShardId, closeStatus, message);
var info = TryGetShard(shard); var info = TryGetShard(shard);
info.DisconnectionCount++; info.DisconnectionCount++;
info.Connected = false; info.Connected = false;
@ -131,9 +112,6 @@ namespace PluralKit.Bot
private void Heartbeated(Shard shard, TimeSpan latency) private void Heartbeated(Shard shard, TimeSpan latency)
{ {
_logger.Information("Shard #{Shard} received heartbeat (latency: {Latency} ms)",
shard.ShardId, latency.Milliseconds);
var info = TryGetShard(shard); var info = TryGetShard(shard);
info.LastHeartbeatTime = SystemClock.Instance.GetCurrentInstant(); info.LastHeartbeatTime = SystemClock.Instance.GetCurrentInstant();
info.Connected = true; info.Connected = true;

View File

@ -155,6 +155,7 @@ namespace PluralKit.Bot {
// "escape hatch", clean up as if we hit X // "escape hatch", clean up as if we hit X
} }
// todo: re-check
if (ctx.BotPermissions.HasFlag(PermissionSet.ManageMessages)) if (ctx.BotPermissions.HasFlag(PermissionSet.ManageMessages))
await ctx.Rest.DeleteAllReactions(msg.ChannelId, msg.Id); await ctx.Rest.DeleteAllReactions(msg.ChannelId, msg.Id);
} }

View File

@ -38,7 +38,7 @@ namespace PluralKit.Bot
{"guild", evt.GuildId.ToString()}, {"guild", evt.GuildId.ToString()},
{"message", evt.Id.ToString()}, {"message", evt.Id.ToString()},
}); });
scope.SetTag("shard", shard.ShardInfo.ShardId.ToString()); scope.SetTag("shard", shard.ShardId.ToString());
// Also report information about the bot's permissions in the channel // Also report information about the bot's permissions in the channel
// We get a lot of permission errors so this'll be useful for determining problems // We get a lot of permission errors so this'll be useful for determining problems
@ -55,7 +55,7 @@ namespace PluralKit.Bot
{"guild", evt.GuildId.ToString()}, {"guild", evt.GuildId.ToString()},
{"message", evt.Id.ToString()}, {"message", evt.Id.ToString()},
}); });
scope.SetTag("shard", shard.ShardInfo.ShardId.ToString()); scope.SetTag("shard", shard.ShardId.ToString());
} }
public void Enrich(Scope scope, Shard shard, MessageUpdateEvent evt) public void Enrich(Scope scope, Shard shard, MessageUpdateEvent evt)
@ -67,7 +67,7 @@ namespace PluralKit.Bot
{"guild", evt.GuildId.Value.ToString()}, {"guild", evt.GuildId.Value.ToString()},
{"message", evt.Id.ToString()} {"message", evt.Id.ToString()}
}); });
scope.SetTag("shard", shard.ShardInfo.ShardId.ToString()); scope.SetTag("shard", shard.ShardId.ToString());
} }
public void Enrich(Scope scope, Shard shard, MessageDeleteBulkEvent evt) public void Enrich(Scope scope, Shard shard, MessageDeleteBulkEvent evt)
@ -79,7 +79,7 @@ namespace PluralKit.Bot
{"guild", evt.GuildId.ToString()}, {"guild", evt.GuildId.ToString()},
{"messages", string.Join(",", evt.Ids)}, {"messages", string.Join(",", evt.Ids)},
}); });
scope.SetTag("shard", shard.ShardInfo.ShardId.ToString()); scope.SetTag("shard", shard.ShardId.ToString());
} }
public void Enrich(Scope scope, Shard shard, MessageReactionAddEvent evt) public void Enrich(Scope scope, Shard shard, MessageReactionAddEvent evt)
@ -93,7 +93,7 @@ namespace PluralKit.Bot
{"message", evt.MessageId.ToString()}, {"message", evt.MessageId.ToString()},
{"reaction", evt.Emoji.Name} {"reaction", evt.Emoji.Name}
}); });
scope.SetTag("shard", shard.ShardInfo.ShardId.ToString()); scope.SetTag("shard", shard.ShardId.ToString());
} }
} }
} }