feat: store shard status in Redis
This commit is contained in:
		| @@ -1,11 +1,12 @@ | ||||
| using System.Net.WebSockets; | ||||
|  | ||||
| using App.Metrics; | ||||
| using Google.Protobuf; | ||||
|  | ||||
| using Myriad.Gateway; | ||||
|  | ||||
| using NodaTime; | ||||
| using NodaTime.Extensions; | ||||
|  | ||||
| using StackExchange.Redis; | ||||
|  | ||||
| using PluralKit.Core; | ||||
|  | ||||
| @@ -13,30 +14,20 @@ using Serilog; | ||||
|  | ||||
| namespace PluralKit.Bot; | ||||
|  | ||||
| // TODO: how much of this do we need now that we have logging in the shard library? | ||||
| // A lot could probably be cleaned up... | ||||
| public class ShardInfoService | ||||
| { | ||||
|     private readonly Cluster _client; | ||||
|  | ||||
|     private readonly IDatabase _db; | ||||
|     private readonly ILogger _logger; | ||||
|  | ||||
|     private readonly IMetrics _metrics; | ||||
|     private readonly ModelRepository _repo; | ||||
|     private readonly Cluster _client; | ||||
|     private readonly RedisService _redis; | ||||
|     private readonly Dictionary<int, ShardInfo> _shardInfo = new(); | ||||
|  | ||||
|     public ShardInfoService(ILogger logger, Cluster client, IMetrics metrics, IDatabase db, ModelRepository repo) | ||||
|     public ShardInfoService(ILogger logger, Cluster client, RedisService redis) | ||||
|     { | ||||
|         _client = client; | ||||
|         _metrics = metrics; | ||||
|         _db = db; | ||||
|         _repo = repo; | ||||
|         _logger = logger.ForContext<ShardInfoService>(); | ||||
|         _client = client; | ||||
|         _redis = redis; | ||||
|     } | ||||
|  | ||||
|     public ICollection<ShardInfo> Shards => _shardInfo.Values; | ||||
|  | ||||
|     public void Init() | ||||
|     { | ||||
|         // We initialize this before any shards are actually created and connected | ||||
| @@ -44,109 +35,109 @@ public class ShardInfoService | ||||
|         _client.ShardCreated += InitializeShard; | ||||
|     } | ||||
|  | ||||
|     private void ReportShardStatus() | ||||
|     public async Task<IEnumerable<ShardState>> GetShards() | ||||
|     { | ||||
|         foreach (var (id, shard) in _shardInfo) | ||||
|             _metrics.Measure.Gauge.SetValue(BotMetrics.ShardLatency, new MetricTags("shard", id.ToString()), | ||||
|                 shard.ShardLatency.TotalMilliseconds); | ||||
|         _metrics.Measure.Gauge.SetValue(BotMetrics.ShardsConnected, _shardInfo.Count(s => s.Value.Connected)); | ||||
|         var db = _redis.Connection.GetDatabase(); | ||||
|         var redisInfo = await db.HashGetAllAsync("pluralkit:shardstatus"); | ||||
|         return redisInfo.Select(x => Proto.Unmarshal<ShardState>(x.Value)); | ||||
|     } | ||||
|  | ||||
|     private void InitializeShard(Shard shard) | ||||
|     { | ||||
|         // Get or insert info in the client dict | ||||
|         if (_shardInfo.TryGetValue(shard.ShardId, out var info)) | ||||
|         _ = Inner(); | ||||
|  | ||||
|         async Task Inner() | ||||
|         { | ||||
|             var db = _redis.Connection.GetDatabase(); | ||||
|             var redisInfo = await db.HashGetAsync("pluralkit::shardstatus", shard.ShardId); | ||||
|  | ||||
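|             // Skip adding listeners if we've seen this shard (a status entry for it already exists in Redis) & already added listeners to it | ||||
|             // NOTE(review): the lookup above uses the key "pluralkit::shardstatus" (double colon), while every other | ||||
|             // read/write in this class uses "pluralkit:shardstatus" - confirm which key is intended. | ||||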
|             if (info.HasAttachedListeners) | ||||
|             if (redisInfo.HasValue) | ||||
|                 return; | ||||
|  | ||||
|             // Latency is set to 1 (not 0) because otherwise shard 0 would serialize to an empty array, thanks protobuf | ||||
|             var state = new ShardState() { ShardId = shard.ShardId, Up = false, Latency = 1 }; | ||||
|  | ||||
|             // Register listeners for new shard | ||||
|             shard.Resumed += () => ReadyOrResumed(shard); | ||||
|             shard.Ready += () => ReadyOrResumed(shard); | ||||
|             shard.SocketClosed += (closeStatus, message) => SocketClosed(shard, closeStatus, message); | ||||
|             shard.HeartbeatReceived += latency => Heartbeated(shard, latency); | ||||
|  | ||||
|             // Register that we've seen it | ||||
|             await db.HashSetAsync("pluralkit:shardstatus", state.HashWrapper()); | ||||
|         } | ||||
|         else | ||||
|         { | ||||
|             _shardInfo[shard.ShardId] = info = new ShardInfo(); | ||||
|         } | ||||
|  | ||||
|         // Call our own SocketOpened listener manually (and then attach the listener properly) | ||||
|  | ||||
|         // Register listeners for new shards | ||||
|         shard.Resumed += () => ReadyOrResumed(shard); | ||||
|         shard.Ready += () => ReadyOrResumed(shard); | ||||
|         shard.SocketClosed += (closeStatus, message) => SocketClosed(shard, closeStatus, message); | ||||
|         shard.HeartbeatReceived += latency => Heartbeated(shard, latency); | ||||
|  | ||||
|         // Register that we've seen it | ||||
|         info.HasAttachedListeners = true; | ||||
|     } | ||||
|  | ||||
|     private ShardInfo TryGetShard(Shard shard) | ||||
|     private async Task<ShardState?> TryGetShard(Shard shard) | ||||
|     { | ||||
|         // If we haven't seen this shard before, add it to the dict! | ||||
|         // I don't think this will ever occur since the shard number is constant up-front and we handle those | ||||
|         // in the RefreshShardList handler above but you never know, I guess~ | ||||
|         if (!_shardInfo.TryGetValue(shard.ShardId, out var info)) | ||||
|             _shardInfo[shard.ShardId] = info = new ShardInfo(); | ||||
|         return info; | ||||
|         var db = _redis.Connection.GetDatabase(); | ||||
|         var redisInfo = await db.HashGetAsync("pluralkit:shardstatus", shard.ShardId); | ||||
|         if (redisInfo.HasValue) | ||||
|             return Proto.Unmarshal<ShardState>(redisInfo); | ||||
|         return null; | ||||
|     } | ||||
|  | ||||
|     private void ReadyOrResumed(Shard shard) | ||||
|     { | ||||
|         var info = TryGetShard(shard); | ||||
|         info.LastConnectionTime = SystemClock.Instance.GetCurrentInstant(); | ||||
|         info.Connected = true; | ||||
|         ReportShardStatus(); | ||||
|  | ||||
|         _ = ExecuteWithDatabase(async c => | ||||
|         _ = DoAsync(async () => | ||||
|         { | ||||
|             await _repo.SetShardStatus(c, shard.ShardId, PKShardInfo.ShardStatus.Up); | ||||
|             await _repo.RegisterShardConnection(c, shard.ShardId); | ||||
|             var info = await TryGetShard(shard); | ||||
|  | ||||
|             info.LastConnection = (int)SystemClock.Instance.GetCurrentInstant().ToUnixTimeSeconds(); | ||||
|             info.Up = true; | ||||
|  | ||||
|             var db = _redis.Connection.GetDatabase(); | ||||
|             await db.HashSetAsync("pluralkit:shardstatus", info.HashWrapper()); | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     private void SocketClosed(Shard shard, WebSocketCloseStatus? closeStatus, string message) | ||||
|     { | ||||
|         var info = TryGetShard(shard); | ||||
|         info.DisconnectionCount++; | ||||
|         info.Connected = false; | ||||
|         ReportShardStatus(); | ||||
|         _ = DoAsync(async () => | ||||
|         { | ||||
|             var info = await TryGetShard(shard); | ||||
|  | ||||
|         _ = ExecuteWithDatabase(c => | ||||
|             _repo.SetShardStatus(c, shard.ShardId, PKShardInfo.ShardStatus.Down)); | ||||
|             info.DisconnectionCount++; | ||||
|             info.Up = false; | ||||
|  | ||||
|             var db = _redis.Connection.GetDatabase(); | ||||
|             await db.HashSetAsync("pluralkit:shardstatus", info.HashWrapper()); | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     private void Heartbeated(Shard shard, TimeSpan latency) | ||||
|     { | ||||
|         var info = TryGetShard(shard); | ||||
|         info.LastHeartbeatTime = SystemClock.Instance.GetCurrentInstant(); | ||||
|         info.Connected = true; | ||||
|         info.ShardLatency = latency.ToDuration(); | ||||
|         _ = DoAsync(async () => | ||||
|         { | ||||
|             var info = await TryGetShard(shard); | ||||
|  | ||||
|         _ = ExecuteWithDatabase(c => | ||||
|             _repo.RegisterShardHeartbeat(c, shard.ShardId, latency.ToDuration())); | ||||
|             info.LastHeartbeat = (int)SystemClock.Instance.GetCurrentInstant().ToUnixTimeSeconds(); | ||||
|             info.Up = true; | ||||
|             info.Latency = (int)latency.TotalMilliseconds; | ||||
|  | ||||
|             var db = _redis.Connection.GetDatabase(); | ||||
|             await db.HashSetAsync("pluralkit:shardstatus", info.HashWrapper()); | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     private async Task ExecuteWithDatabase(Func<IPKConnection, Task> fn) | ||||
|     private async Task DoAsync(Func<Task> fn) | ||||
|     { | ||||
|         // Wrapper function that logs errors, because call sites discard the returned task (fire-and-forget, effectively "async void") :( | ||||
|         try | ||||
|         { | ||||
|             await using var conn = await _db.Obtain(); | ||||
|             await fn(conn); | ||||
|             await fn(); | ||||
|         } | ||||
|         catch (Exception e) | ||||
|         { | ||||
|             _logger.Error(e, "Error persisting shard status"); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
|     public ShardInfo GetShardInfo(int shardId) => _shardInfo[shardId]; | ||||
|  | ||||
|     public class ShardInfo | ||||
|     { | ||||
|         public bool Connected; | ||||
|         public int DisconnectionCount; | ||||
|         public bool HasAttachedListeners; | ||||
|         public Instant LastConnectionTime; | ||||
|         public Instant LastHeartbeatTime; | ||||
|         public Duration ShardLatency; | ||||
|     } | ||||
| public static class RedisExt | ||||
| { | ||||
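|     // Convenience method: wraps a ShardState in a single-entry HashEntry array | ||||
|     // (field = ShardId, value = the state's ToByteArray() bytes) so it can be passed to HashSetAsync. | ||||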
|     public static HashEntry[] HashWrapper(this ShardState state) | ||||
|         => new[] { new HashEntry(state.ShardId, state.ToByteArray()) }; | ||||
| } | ||||
		Reference in New Issue
	
	Block a user