fix: include shard_id in redis channel

This commit is contained in:
spiral 2022-05-10 07:32:14 -04:00
parent abb01aaf2c
commit 799279054d
No known key found for this signature in database
GPG Key ID: 244A11E4B0BCF40E
4 changed files with 23 additions and 15 deletions

View File

@ -73,7 +73,7 @@ public class Bot
} }
}; };
_services.Resolve<RedisGatewayService>().OnEventReceived += (evt) => OnEventReceivedInner(0, evt); _services.Resolve<RedisGatewayService>().OnEventReceived += (e) => OnEventReceivedInner(e.Item1, e.Item2);
// Init the shard stuff // Init the shard stuff
_services.Resolve<ShardInfoService>().Init(); _services.Resolve<ShardInfoService>().Init();

View File

@ -79,11 +79,7 @@ public class Init
// Start the Discord shards themselves (handlers already set up) // Start the Discord shards themselves (handlers already set up)
logger.Information("Connecting to Discord"); logger.Information("Connecting to Discord");
await StartCluster(services);
if (config.RedisGatewayUrl != null)
await services.Resolve<RedisGatewayService>().Start();
else
await StartCluster(services);
logger.Information("Connected! All is good (probably)."); logger.Information("Connected! All is good (probably).");
@ -189,7 +185,15 @@ public class Init
var shardMin = (int)Math.Round(totalShards * (float)nodeIndex / totalNodes); var shardMin = (int)Math.Round(totalShards * (float)nodeIndex / totalNodes);
var shardMax = (int)Math.Round(totalShards * (float)(nodeIndex + 1) / totalNodes) - 1; var shardMax = (int)Math.Round(totalShards * (float)(nodeIndex + 1) / totalNodes) - 1;
await cluster.Start(info.Url, shardMin, shardMax, totalShards, info.SessionStartLimit.MaxConcurrency, redis.Connection); if (config.RedisGatewayUrl != null)
{
var shardService = services.Resolve<RedisGatewayService>();
for (var i = shardMin; i <= shardMax; i++)
await shardService.Start(i);
}
else
await cluster.Start(info.Url, shardMin, shardMax, totalShards, info.SessionStartLimit.MaxConcurrency, redis.Connection);
} }
else else
{ {

View File

@ -23,22 +23,26 @@ public class RedisGatewayService
_logger = logger.ForContext<RedisGatewayService>(); _logger = logger.ForContext<RedisGatewayService>();
} }
public event Func<IGatewayEvent, Task>? OnEventReceived; public event Func<(int, IGatewayEvent), Task>? OnEventReceived;
public async Task Start() public async Task Start(int shardId)
{ {
_redis = await ConnectionMultiplexer.ConnectAsync(_config.RedisGatewayUrl); if (_redis == null)
var channel = await _redis.GetSubscriber().SubscribeAsync("evt"); _redis = await ConnectionMultiplexer.ConnectAsync(_config.RedisGatewayUrl);
channel.OnMessage(Handle);
_logger.Debug("Subscribing to shard {ShardId} on redis", shardId);
var channel = await _redis.GetSubscriber().SubscribeAsync($"evt-{shardId}");
channel.OnMessage((evt) => Handle(shardId, evt));
} }
public async Task Handle(ChannelMessage message) public async Task Handle(int shardId, ChannelMessage message)
{ {
var packet = JsonSerializer.Deserialize<GatewayPacket>(message.Message, _jsonSerializerOptions); var packet = JsonSerializer.Deserialize<GatewayPacket>(message.Message, _jsonSerializerOptions);
if (packet.Opcode != GatewayOpcode.Dispatch) return; if (packet.Opcode != GatewayOpcode.Dispatch) return;
var evt = DeserializeEvent(packet.EventType, (JsonElement)packet.Payload); var evt = DeserializeEvent(packet.EventType, (JsonElement)packet.Payload);
if (evt == null) return; if (evt == null) return;
await OnEventReceived(evt); await OnEventReceived((shardId, evt));
} }
private IGatewayEvent? DeserializeEvent(string eventType, JsonElement payload) private IGatewayEvent? DeserializeEvent(string eventType, JsonElement payload)

View File

@ -41,7 +41,7 @@ pub async fn handle_event<'a>(
let deserializer = GatewayEventDeserializer::from_json(std::str::from_utf8(&payload.bytes)?).unwrap(); let deserializer = GatewayEventDeserializer::from_json(std::str::from_utf8(&payload.bytes)?).unwrap();
if deserializer.op() == 0 && ALLOWED_EVENTS.contains(&deserializer.event_type_ref().unwrap()) { if deserializer.op() == 0 && ALLOWED_EVENTS.contains(&deserializer.event_type_ref().unwrap()) {
let mut conn = rconn.get_async_connection().await?; let mut conn = rconn.get_async_connection().await?;
conn.publish::<&str, Vec<u8>, i32>("evt", payload.bytes).await?; conn.publish::<&str, Vec<u8>, i32>(&format!("evt-{shard_id}"), payload.bytes).await?;
} }
} }
Event::MessageCreate(msg) => { Event::MessageCreate(msg) => {