diff --git a/PluralKit.Bot/Bot.cs b/PluralKit.Bot/Bot.cs index 3f79c5c3..cfcfe589 100644 --- a/PluralKit.Bot/Bot.cs +++ b/PluralKit.Bot/Bot.cs @@ -73,7 +73,7 @@ public class Bot } }; - _services.Resolve().OnEventReceived += (evt) => OnEventReceivedInner(0, evt); + _services.Resolve().OnEventReceived += (e) => OnEventReceivedInner(e.Item1, e.Item2); // Init the shard stuff _services.Resolve().Init(); diff --git a/PluralKit.Bot/Init.cs b/PluralKit.Bot/Init.cs index 80a73dca..57af88a8 100644 --- a/PluralKit.Bot/Init.cs +++ b/PluralKit.Bot/Init.cs @@ -79,11 +79,7 @@ public class Init // Start the Discord shards themselves (handlers already set up) logger.Information("Connecting to Discord"); - - if (config.RedisGatewayUrl != null) - await services.Resolve().Start(); - else - await StartCluster(services); + await StartCluster(services); logger.Information("Connected! All is good (probably)."); @@ -189,7 +185,15 @@ public class Init var shardMin = (int)Math.Round(totalShards * (float)nodeIndex / totalNodes); var shardMax = (int)Math.Round(totalShards * (float)(nodeIndex + 1) / totalNodes) - 1; - await cluster.Start(info.Url, shardMin, shardMax, totalShards, info.SessionStartLimit.MaxConcurrency, redis.Connection); + if (config.RedisGatewayUrl != null) + { + var shardService = services.Resolve(); + + for (var i = shardMin; i <= shardMax; i++) + await shardService.Start(i); + } + else + await cluster.Start(info.Url, shardMin, shardMax, totalShards, info.SessionStartLimit.MaxConcurrency, redis.Connection); } else { diff --git a/PluralKit.Bot/Services/RedisGatewayService.cs b/PluralKit.Bot/Services/RedisGatewayService.cs index 83cb3f88..8671dc9c 100644 --- a/PluralKit.Bot/Services/RedisGatewayService.cs +++ b/PluralKit.Bot/Services/RedisGatewayService.cs @@ -23,22 +23,26 @@ public class RedisGatewayService _logger = logger.ForContext(); } - public event Func? OnEventReceived; + public event Func<(int, IGatewayEvent), Task>? OnEventReceived; - public async Task Start() + public async Task Start(int shardId) { - _redis = await ConnectionMultiplexer.ConnectAsync(_config.RedisGatewayUrl); - var channel = await _redis.GetSubscriber().SubscribeAsync("evt"); - channel.OnMessage(Handle); + if (_redis == null) + _redis = await ConnectionMultiplexer.ConnectAsync(_config.RedisGatewayUrl); + + _logger.Debug("Subscribing to shard {ShardId} on redis", shardId); + + var channel = await _redis.GetSubscriber().SubscribeAsync($"evt-{shardId}"); + channel.OnMessage((evt) => Handle(shardId, evt)); } - public async Task Handle(ChannelMessage message) + public async Task Handle(int shardId, ChannelMessage message) { var packet = JsonSerializer.Deserialize(message.Message, _jsonSerializerOptions); if (packet.Opcode != GatewayOpcode.Dispatch) return; var evt = DeserializeEvent(packet.EventType, (JsonElement)packet.Payload); if (evt == null) return; - await OnEventReceived(evt); + await OnEventReceived((shardId, evt)); } private IGatewayEvent? DeserializeEvent(string eventType, JsonElement payload) diff --git a/gateway/src/evt.rs b/gateway/src/evt.rs index 307a4c7f..958c6cf1 100644 --- a/gateway/src/evt.rs +++ b/gateway/src/evt.rs @@ -41,7 +41,7 @@ pub async fn handle_event<'a>( let deserializer = GatewayEventDeserializer::from_json(std::str::from_utf8(&payload.bytes)?).unwrap(); if deserializer.op() == 0 && ALLOWED_EVENTS.contains(&deserializer.event_type_ref().unwrap()) { let mut conn = rconn.get_async_connection().await?; - conn.publish::<&str, Vec, i32>("evt", payload.bytes).await?; + conn.publish::<&str, Vec, i32>(&format!("evt-{shard_id}"), payload.bytes).await?; } } Event::MessageCreate(msg) => {