fix: include shard_id in redis channel

This commit is contained in:
spiral 2022-05-10 07:32:14 -04:00
parent abb01aaf2c
commit 799279054d
No known key found for this signature in database
GPG Key ID: 244A11E4B0BCF40E
4 changed files with 23 additions and 15 deletions

View File

@ -73,7 +73,7 @@ public class Bot
}
};
_services.Resolve<RedisGatewayService>().OnEventReceived += (evt) => OnEventReceivedInner(0, evt);
_services.Resolve<RedisGatewayService>().OnEventReceived += (e) => OnEventReceivedInner(e.Item1, e.Item2);
// Init the shard stuff
_services.Resolve<ShardInfoService>().Init();

View File

@ -79,10 +79,6 @@ public class Init
// Start the Discord shards themselves (handlers already set up)
logger.Information("Connecting to Discord");
if (config.RedisGatewayUrl != null)
await services.Resolve<RedisGatewayService>().Start();
else
await StartCluster(services);
logger.Information("Connected! All is good (probably).");
@ -189,6 +185,14 @@ public class Init
var shardMin = (int)Math.Round(totalShards * (float)nodeIndex / totalNodes);
var shardMax = (int)Math.Round(totalShards * (float)(nodeIndex + 1) / totalNodes) - 1;
if (config.RedisGatewayUrl != null)
{
var shardService = services.Resolve<RedisGatewayService>();
for (var i = shardMin; i <= shardMax; i++)
await shardService.Start(i);
}
else
await cluster.Start(info.Url, shardMin, shardMax, totalShards, info.SessionStartLimit.MaxConcurrency, redis.Connection);
}
else

View File

@ -23,22 +23,26 @@ public class RedisGatewayService
_logger = logger.ForContext<RedisGatewayService>();
}
public event Func<IGatewayEvent, Task>? OnEventReceived;
public event Func<(int, IGatewayEvent), Task>? OnEventReceived;
public async Task Start()
public async Task Start(int shardId)
{
if (_redis == null)
_redis = await ConnectionMultiplexer.ConnectAsync(_config.RedisGatewayUrl);
var channel = await _redis.GetSubscriber().SubscribeAsync("evt");
channel.OnMessage(Handle);
_logger.Debug("Subscribing to shard {ShardId} on redis", shardId);
var channel = await _redis.GetSubscriber().SubscribeAsync($"evt-{shardId}");
channel.OnMessage((evt) => Handle(shardId, evt));
}
public async Task Handle(ChannelMessage message)
public async Task Handle(int shardId, ChannelMessage message)
{
var packet = JsonSerializer.Deserialize<GatewayPacket>(message.Message, _jsonSerializerOptions);
if (packet.Opcode != GatewayOpcode.Dispatch) return;
var evt = DeserializeEvent(packet.EventType, (JsonElement)packet.Payload);
if (evt == null) return;
await OnEventReceived(evt);
await OnEventReceived((shardId, evt));
}
private IGatewayEvent? DeserializeEvent(string eventType, JsonElement payload)

View File

@ -41,7 +41,7 @@ pub async fn handle_event<'a>(
let deserializer = GatewayEventDeserializer::from_json(std::str::from_utf8(&payload.bytes)?).unwrap();
if deserializer.op() == 0 && ALLOWED_EVENTS.contains(&deserializer.event_type_ref().unwrap()) {
let mut conn = rconn.get_async_connection().await?;
conn.publish::<&str, Vec<u8>, i32>("evt", payload.bytes).await?;
conn.publish::<&str, Vec<u8>, i32>(&format!("evt-{shard_id}"), payload.bytes).await?;
}
}
Event::MessageCreate(msg) => {