using System.Net.WebSockets;

using Google.Protobuf;

using Myriad.Gateway;

using NodaTime;

using StackExchange.Redis;

using PluralKit.Core;

using Serilog;

namespace PluralKit.Bot;

/// <summary>
/// Tracks per-shard gateway status: up/down, latency, last connection, last heartbeat and disconnection count.
/// Status is persisted as protobuf-encoded <see cref="ShardState"/> values in the Redis hash
/// <c>pluralkit:shardstatus</c>, keyed by shard ID.
/// </summary>
public class ShardInfoService
{
    // Redis hash holding one serialized ShardState per shard ID. Every read and write uses this constant so the
    // key cannot drift between call sites (a previous "pluralkit::shardstatus" typo did exactly that).
    private const string ShardStatusKey = "pluralkit:shardstatus";

    private readonly int? _clusterId;
    private readonly ILogger _logger;
    private readonly Cluster _client;
    private readonly RedisService _redis;

    // Shard IDs whose Shard objects already have our event listeners attached in *this process*.
    // This must be process-local: the Redis hash survives restarts and is shared with other cluster nodes,
    // so an entry there says nothing about whether this process has attached listeners.
    private readonly HashSet<int> _shardsWithListeners = new();

    public ShardInfoService(ILogger logger, Cluster client, RedisService redis, BotConfig config)
    {
        _logger = logger.ForContext<ShardInfoService>();
        _client = client;
        _redis = redis;
        _clusterId = config.Cluster?.NodeIndex;
    }

    /// <summary>
    /// Subscribes to shard creation so that status listeners get attached to every shard the cluster creates.
    /// </summary>
    public void Init()
    {
        // We initialize this before any shards are actually created and connected, so the client doesn't know
        // the shard count yet. Instead we attach listeners to each shard as it gets created.
        _client.ShardCreated += InitializeShard;
    }

    /// <summary>
    /// Reads the persisted status of every shard from Redis.
    /// </summary>
    /// <returns>All stored shard states, or an empty sequence when Redis is disabled.</returns>
    public async Task<IEnumerable<ShardState>> GetShards()
    {
        if (_redis.Connection == null)
            return Array.Empty<ShardState>();

        var db = _redis.Connection.GetDatabase();
        var redisInfo = await db.HashGetAllAsync(ShardStatusKey);

        // Materialize so callers enumerating more than once don't re-deserialize every entry.
        return redisInfo.Select(x => Proto.Unmarshal<ShardState>(x.Value)).ToArray();
    }

    /// <summary>
    /// Handler for <c>Cluster.ShardCreated</c>. It records the shard as seen-but-down in Redis and attaches
    /// the status listeners, at most once per shard ID.
    /// </summary>
    private void InitializeShard(Shard shard)
    {
        if (_redis.Connection == null)
        {
            _logger.Warning("Redis is disabled, shard connection status will be unavailable.");
            return;
        }

        // Skip adding listeners if we've already added them for this shard. Duplicate handlers would double-count
        // disconnections and double the Redis traffic. The thread ShardCreated is raised on isn't visible from
        // here, so the set is guarded.
        lock (_shardsWithListeners)
        {
            if (!_shardsWithListeners.Add(shard.ShardId))
                return;
        }

        // Register that we've seen the shard. DoAsync invokes the delegate synchronously, so this write is issued
        // before the listeners below are attached, ahead of any listener's read-modify-write.
        var db = _redis.Connection.GetDatabase();
        _ = DoAsync(() => db.HashSetAsync(ShardStatusKey, NewShardState(shard).HashWrapper()));

        // Attach listeners with no await in between, so an early Ready/Resumed event can't be missed.
        shard.Resumed += () => ReadyOrResumed(shard);
        shard.Ready += () => ReadyOrResumed(shard);
        shard.SocketClosed += (closeStatus, message) => SocketClosed(shard, closeStatus, message);
        shard.HeartbeatReceived += latency => Heartbeated(shard, latency);
    }

    /// <summary>
    /// Builds the initial (not-yet-up) state for a shard, tagged with this node's cluster ID when configured.
    /// </summary>
    private ShardState NewShardState(Shard shard)
    {
        // Latency = 1 rather than 0: with ShardId = 0 and Up = false every field would hold its default value,
        // and protobuf would serialize shard 0 to an empty byte array.
        var state = new ShardState { ShardId = shard.ShardId, Up = false, Latency = 1 };
        if (_clusterId != null)
            state.ClusterId = _clusterId.Value;
        return state;
    }

    private void ReadyOrResumed(Shard shard) =>
        UpdateShardState(shard, info =>
        {
            info.LastConnection = NowUnixSeconds();
            info.Up = true;
        });

    private void SocketClosed(Shard shard, WebSocketCloseStatus? closeStatus, string message) =>
        UpdateShardState(shard, info =>
        {
            info.DisconnectionCount++;
            info.Up = false;
        });

    private void Heartbeated(Shard shard, TimeSpan latency) =>
        UpdateShardState(shard, info =>
        {
            info.LastHeartbeat = NowUnixSeconds();
            info.Up = true;
            info.Latency = (int)latency.TotalMilliseconds;
        });

    /// <summary>
    /// Loads the shard's state from Redis, applies <paramref name="update"/> to it, and writes it back.
    /// Runs fire-and-forget; failures are logged by <see cref="DoAsync"/>.
    /// </summary>
    private void UpdateShardState(Shard shard, Action<ShardState> update)
    {
        _ = DoAsync(async () =>
        {
            var db = _redis.Connection.GetDatabase();
            var redisInfo = await db.HashGetAsync(ShardStatusKey, shard.ShardId);

            // Fall back to a fresh state if the entry is missing (e.g. the initial write failed or the key was
            // cleared) instead of dereferencing null.
            var info = redisInfo.HasValue
                ? Proto.Unmarshal<ShardState>(redisInfo)
                : NewShardState(shard);

            update(info);

            // NOTE(review): this read-modify-write is not atomic; concurrent events for the same shard can lose
            // an update.
            await db.HashSetAsync(ShardStatusKey, info.HashWrapper());
        });
    }

    /// <summary>Current time as Unix seconds, truncated to <see cref="int"/> to match the ShardState fields.</summary>
    private static int NowUnixSeconds() =>
        (int)SystemClock.Instance.GetCurrentInstant().ToUnixTimeSeconds();

    /// <summary>
    /// Runs <paramref name="fn"/> and logs any exception instead of letting it go unobserved.
    /// Callers discard the returned task, so this is the only place errors surface.
    /// </summary>
    private async Task DoAsync(Func<Task> fn)
    {
        try
        {
            await fn();
        }
        catch (Exception e)
        {
            _logger.Error(e, "Error persisting shard status");
        }
    }
}

public static class RedisExt
{
    /// <summary>
    /// Wraps a <see cref="ShardState"/> as a single Redis hash entry: field = shard ID, value = protobuf bytes.
    /// </summary>
    public static HashEntry[] HashWrapper(this ShardState state) =>
        new[] { new HashEntry(state.ShardId, state.ToByteArray()) };
}