feat(gateway): initial commit

This commit is contained in:
spiral 2022-04-11 15:55:10 -04:00
parent 8e5b987b2c
commit fadf007abc
No known key found for this signature in database
GPG Key ID: 244A11E4B0BCF40E
12 changed files with 2487 additions and 1 deletions

View File

@ -73,6 +73,8 @@ public class Bot
}
};
_services.Resolve<RedisGatewayService>().OnEventReceived += (evt) => OnEventReceived(0, evt);
// Init the shard stuff
_services.Resolve<ShardInfoService>().Init();

View File

@ -21,6 +21,8 @@ public class BotConfig
public string? GatewayQueueUrl { get; set; }
public bool UseRedisRatelimiter { get; set; } = false;
public string? RedisGatewayUrl { get; set; }
public string? DiscordBaseUrl { get; set; }
public bool DisableErrorReporting { get; set; } = false;

View File

@ -74,7 +74,12 @@ public class Init
// Start the Discord shards themselves (handlers already set up)
logger.Information("Connecting to Discord");
await StartCluster(services);
if (config.RedisGatewayUrl != null)
await services.Resolve<RedisGatewayService>().Start();
else
await StartCluster(services);
logger.Information("Connected! All is good (probably).");
// Lastly, we just... wait. Everything else is handled in the DiscordClient event loop
@ -98,6 +103,7 @@ public class Init
// - Wraps the given function in an exception handler that properly logs errors
// - Adds a SIGINT (Ctrl-C) listener through Console.CancelKeyPress to gracefully shut down
// - Adds a SIGTERM (kill, systemctl stop, docker stop) listener through AppDomain.ProcessExit (same as above)
// todo: move run-clustered.sh to here
var logger = services.Resolve<ILogger>().ForContext<Init>();
var shutdown = new TaskCompletionSource<object>();

View File

@ -43,6 +43,7 @@ public class BotModule: Module
};
}).AsSelf().SingleInstance();
builder.RegisterType<Cluster>().AsSelf().SingleInstance();
builder.RegisterType<RedisGatewayService>().AsSelf().SingleInstance();
builder.Register(c => { return new MemoryDiscordCache(); }).AsSelf().As<IDiscordCache>().SingleInstance();
builder.RegisterType<PrivateChannelService>().AsSelf().SingleInstance();

View File

@ -0,0 +1,63 @@
using System.Text.Json;
using Serilog;
using StackExchange.Redis;
using Myriad.Gateway;
using Myriad.Serialization;
namespace PluralKit.Bot;
public class RedisGatewayService
{
private readonly BotConfig _config;
private readonly JsonSerializerOptions _jsonSerializerOptions;
private ConnectionMultiplexer _redis;
private ILogger _logger;
public RedisGatewayService(BotConfig config, ILogger logger)
{
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
_config = config;
_logger = logger.ForContext<RedisGatewayService>();
}
public event Func<IGatewayEvent, Task>? OnEventReceived;
public async Task Start()
{
_redis = await ConnectionMultiplexer.ConnectAsync(_config.RedisGatewayUrl);
var channel = await _redis.GetSubscriber().SubscribeAsync("evt");
channel.OnMessage(Handle);
}
public async Task Handle(ChannelMessage message)
{
var packet = JsonSerializer.Deserialize<GatewayPacket>(message.Message, _jsonSerializerOptions);
if (packet.Opcode != GatewayOpcode.Dispatch) return;
var evt = DeserializeEvent(packet.EventType, (JsonElement)packet.Payload);
if (evt == null) return;
await OnEventReceived(evt);
}
private IGatewayEvent? DeserializeEvent(string eventType, JsonElement payload)
{
if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType))
{
_logger.Debug("Received unknown event type {EventType}", eventType);
return null;
}
try
{
_logger.Verbose("Deserializing {EventType} to {ClrType}", eventType, clrType);
return JsonSerializer.Deserialize(payload.GetRawText(), clrType, _jsonSerializerOptions) as IGatewayEvent;
}
catch (JsonException e)
{
_logger.Error(e, "Error deserializing event {EventType} to {ClrType}", eventType, clrType);
return null;
}
}
}

3
gateway/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
/target
config.json

1943
gateway/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

33
gateway/Cargo.toml Normal file
View File

@ -0,0 +1,33 @@
[package]
name = "pluralkit"
version = "0.1.0"
edition = "2021"
[dependencies]
# Infrastructure
anyhow = "1"
config = { version = "0.11", default-features = false, features = ["json"] }
futures = "0.3"
serde = { version = "1", features = ["derive"] }
tracing = "0.1"
tracing-subscriber = "0.3"
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1", features = ["sync"] }
procfs = "0.12.0"
libc = "0.2.122"
# Twilight
twilight-cache-inmemory = "0.10.0"
twilight-gateway = "0.10.0"
twilight-gateway-queue = "0.10.0"
twilight-http = "0.10.0"
twilight-model = "0.10.0"
# Database
deadpool = "0.9"
deadpool-postgres = "0.10"
postgres-types = { version = "0.2", features = ["derive"] }
tokio-postgres = { version = "0.7", features = ["with-serde_json-1", "with-uuid-0_8"] }
redis = { version = "0.21.5", features = ["aio", "tokio-comp"] }

22
gateway/src/config.rs Normal file
View File

@ -0,0 +1,22 @@
use config::{self, Config};
use serde::Deserialize;
#[derive(Deserialize, Debug)]
pub struct BotConfig {
pub token: String,
pub max_concurrency: u64,
pub database: String,
pub redis_addr: String,
pub redis_gateway_queue_addr: String,
pub shard_count: u64
}
pub fn load_config() -> BotConfig {
let mut settings = Config::default();
settings.merge(config::File::with_name("config")).unwrap();
settings
.merge(config::Environment::with_prefix("PluralKit"))
.unwrap();
settings.try_into::<BotConfig>().unwrap()
}

144
gateway/src/db.rs Normal file
View File

@ -0,0 +1,144 @@
use std::{str::FromStr, time::SystemTime};
use crate::config;
use anyhow::Context;
use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
use tokio_postgres::{self, types::FromSql, Row};
use twilight_model::id::Id;
use twilight_model::id::marker::ChannelMarker;
pub async fn init_db(cfg: &config::BotConfig) -> anyhow::Result<Pool> {
let pg_config = tokio_postgres::config::Config::from_str(&cfg.database)
.context("could not parse connection string")?;
let mgr_config = ManagerConfig {
recycling_method: RecyclingMethod::Fast,
};
let mgr = Manager::from_config(pg_config, tokio_postgres::NoTls, mgr_config);
let pool = Pool::builder(mgr)
.max_size(16)
.build()
.context("could not initialize pool")?;
Ok(pool)
}
pub async fn get_message_context(
pool: &Pool,
account_id: u64,
guild_id: u64,
channel_id: u64,
) -> anyhow::Result<Option<MessageContext>> {
let client = pool.get().await?;
let stmt = client
.prepare_cached("select * from message_context($1, $2, $3)")
.await?;
let result = client
.query_opt(
&stmt,
&[
&(account_id as i64),
&(guild_id as i64),
&(channel_id as i64),
],
)
.await
.context("could not fetch message context")?;
Ok(result.map(parse_message_context))
}
pub async fn get_proxy_members(
pool: &Pool,
account_id: u64,
guild_id: u64,
) -> anyhow::Result<Vec<ProxyMember>> {
let client = pool.get().await?;
let stmt = client
.prepare_cached("select * from proxy_members($1, $2)")
.await?;
let result = client
.query(&stmt, &[&(account_id as i64), &(guild_id as i64)])
.await
.context("could not fetch proxy members")?;
Ok(result.into_iter().map(parse_proxy_member).collect())
}
#[derive(Debug)]
pub struct MessageContext {
pub system_id: Option<i32>,
pub log_channel: Option<Id<ChannelMarker>>,
pub in_blacklist: bool,
pub in_log_blacklist: bool,
pub log_cleanup_enabled: bool,
pub proxy_enabled: bool,
pub last_switch: Option<i32>,
pub last_switch_members: Option<Vec<i32>>,
pub last_switch_timestamp: Option<SystemTime>,
pub system_tag: Option<String>,
pub system_guild_tag: Option<String>,
pub tag_enabled: bool,
pub system_avatar: Option<String>,
pub allow_autoproxy: bool,
pub latch_timeout: Option<i32>,
}
#[derive(Debug, FromSql)]
#[postgres(name = "proxy_tag")]
pub struct ProxyTag {
pub prefix: Option<String>,
pub suffix: Option<String>,
}
#[derive(Debug)]
pub struct ProxyMember {
pub id: i32,
pub proxy_tags: Vec<ProxyTag>,
pub keep_proxy: bool,
pub server_name: Option<String>,
pub display_name: Option<String>,
pub name: String,
pub server_avatar: Option<String>,
pub avatar: Option<String>,
pub allow_autoproxy: bool,
pub color: Option<String>,
}
fn parse_message_context(row: Row) -> MessageContext {
MessageContext {
system_id: row.get("system_id"),
log_channel: to_channel_id_opt(row.get("log_channel")),
in_blacklist: row.get::<_, Option<_>>("in_blacklist").unwrap_or(false),
in_log_blacklist: row.get::<_, Option<_>>("in_log_blacklist").unwrap_or(false),
log_cleanup_enabled: row.get("log_cleanup_enabled"),
proxy_enabled: row.get("proxy_enabled"),
last_switch: row.get("last_switch"),
last_switch_members: row.get("last_switch_members"),
last_switch_timestamp: row.get("last_switch_timestamp"),
system_tag: row.get("system_tag"),
system_guild_tag: row.get("system_guild_tag"),
tag_enabled: row.get("tag_enabled"),
system_avatar: row.get("system_avatar"),
allow_autoproxy: row.get("allow_autoproxy"),
latch_timeout: row.get("latch_timeout"),
}
}
fn parse_proxy_member(row: Row) -> ProxyMember {
ProxyMember {
id: row.get("id"),
proxy_tags: row.get("proxy_tags"),
keep_proxy: row.get("keep_proxy"),
server_name: row.get("server_name"),
display_name: row.get("display_name"),
name: row.get("name"),
server_avatar: row.get("server_avatar"),
avatar: row.get("avatar"),
allow_autoproxy: row.get("allow_autoproxy"),
color: row.get("color"),
}
}
fn to_channel_id_opt(id: Option<i64>) -> Option<Id<ChannelMarker>> {
id.and_then(|x| Some(Id::<ChannelMarker>::new(x as u64)))
}

235
gateway/src/main.rs Normal file
View File

@ -0,0 +1,235 @@
use deadpool_postgres::Pool;
use futures::StreamExt;
use redis::AsyncCommands;
use std::{sync::Arc, env};
use tracing::{error, info, Level};
use twilight_cache_inmemory::{InMemoryCache, ResourceType};
use twilight_gateway::{
cluster::{Events, ShardScheme},
Cluster, Event, EventTypeFlags, Intents,
};
use twilight_http::Client as HttpClient;
mod config;
mod db;
mod util;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
init_tracing();
info!("starting...");
let cfg = config::load_config();
let http = Arc::new(HttpClient::new(cfg.token.clone()));
let rconn = redis::Client::open(cfg.redis_addr.clone()).unwrap();
let (_cluster, events) = init_gateway(&cfg, rconn.clone()).await?;
let cache = init_cache();
let db = db::init_db(&cfg).await?;
run(http, events, cache, db, rconn).await?;
Ok(())
}
async fn run(
http: Arc<HttpClient>,
mut events: Events,
cache: Arc<InMemoryCache>,
db: Pool,
rconn: redis::Client,
) -> anyhow::Result<()> {
while let Some((shard_id, event)) = events.next().await {
cache.update(&event);
let http_cloned = http.clone();
let cache_cloned = cache.clone();
let db_cloned = db.clone();
let rconn_cloned = rconn.clone();
tokio::spawn(async move {
let result = handle_event(
shard_id,
event,
http_cloned,
cache_cloned,
db_cloned,
rconn_cloned
)
.await;
if let Err(e) = result {
error!("error in event handler: {:?}", e);
}
});
}
Ok(())
}
async fn handle_event<'a>(
shard_id: u64,
event: Event,
http: Arc<HttpClient>,
cache: Arc<InMemoryCache>,
_db: Pool,
rconn: redis::Client
) -> anyhow::Result<()> {
match event {
Event::GatewayInvalidateSession(resumable) => {
info!("shard {} session invalidated, resumable? {}", shard_id, resumable);
}
Event::ShardConnected(_) => {
info!("shard {} connected", shard_id);
}
Event::ShardDisconnected(info) => {
info!("shard {} disconnected, code: {:?}, reason: {:?}", shard_id, info.code, info.reason);
}
Event::ShardPayload(payload) => {
let mut conn = rconn.get_async_connection().await?;
conn.publish::<&str, Vec<u8>, i32>("evt", payload.bytes).await?;
}
Event::MessageCreate(msg) => {
if msg.content == "pkt;test" {
// let message_context = db::get_message_context(
// &db,
// msg.author.id.get(),
// msg.guild_id.map(|x| x.get()).unwrap_or(0),
// msg.channel_id.get(),
// )
// .await?;
// let content = format!("message context:\n```\n{:#?}\n```", message_context);
// http.create_message(msg.channel_id)
// .reply(msg.id)
// .content(&content)?
// .exec()
// .await?;
// let proxy_members = db::get_proxy_members(
// &db,
// msg.author.id.get(),
// msg.guild_id.map(|x| x.get()).unwrap_or(0),
// )
// .await?;
// let content = format!("proxy members:\n```\n{:#?}\n```", proxy_members);
// info!("{}", content);
// http.create_message(msg.channel_id)
// .reply(msg.id)
// .content(&content)?
// .exec()
// .await?;
let cache_stats = cache.stats();
let pid = unsafe { libc::getpid() };
let pagesize = {
unsafe {
libc::sysconf(libc::_SC_PAGESIZE)
}
};
let p = procfs::process::Process::new(pid)?;
let content = format!(
"[rust]\nguilds:{}\nchannels:{}\nroles:{}\nusers:{}\nmembers:{}\n\nmemory usage: {}",
cache_stats.guilds(),
cache_stats.channels(),
cache_stats.roles(),
cache_stats.users(),
cache_stats.members(),
p.stat.rss * pagesize
);
http.create_message(msg.channel_id)
.reply(msg.id)
.content(&content)?
.exec()
.await?;
}
}
_ => {}
}
Ok(())
}
fn init_tracing() {
tracing_subscriber::fmt()
.with_max_level(Level::INFO)
.init();
}
async fn init_gateway(
cfg: &config::BotConfig,
rconn: redis::Client,
) -> anyhow::Result<(Arc<Cluster>, Events)> {
let shard_count = cfg.shard_count.clone();
let scheme: ShardScheme;
if shard_count < 16 {
scheme = ShardScheme::Auto;
} else {
let cluster_id = env::var("NOMAD_ALLOC_INDEX")?.parse::<u64>().unwrap();
let first_shard_id = 16 * cluster_id;
scheme = ShardScheme::try_from((first_shard_id..first_shard_id+16, shard_count)).unwrap();
}
let queue = util::RedisQueue {
client: rconn.clone(),
concurrency: cfg.max_concurrency.clone()
};
let (cluster, events) = Cluster::builder(
cfg.token.clone(),
Intents::GUILDS
| Intents::DIRECT_MESSAGES
| Intents::DIRECT_MESSAGE_REACTIONS
| Intents::GUILD_EMOJIS_AND_STICKERS
| Intents::GUILD_MESSAGES
| Intents::GUILD_MESSAGE_REACTIONS
| Intents::GUILD_WEBHOOKS
| Intents::MESSAGE_CONTENT
)
.shard_scheme(scheme)
.event_types(
// EventTypeFlags::all()
EventTypeFlags::READY
| EventTypeFlags::GATEWAY_INVALIDATE_SESSION
| EventTypeFlags::GATEWAY_RECONNECT
| EventTypeFlags::SHARD_PAYLOAD
| EventTypeFlags::SHARD_CONNECTED
| EventTypeFlags::SHARD_DISCONNECTED
| EventTypeFlags::GUILD_CREATE
| EventTypeFlags::CHANNEL_CREATE
| EventTypeFlags::MESSAGE_CREATE
// | EventTypeFlags::MESSAGE_UPDATE
)
.queue(Arc::new(queue))
.build()
.await?;
let cluster = Arc::new(cluster);
let cluster_spawn = Arc::clone(&cluster);
tokio::spawn(async move {
cluster_spawn.up().await;
});
Ok((cluster, events))
}
fn init_cache() -> Arc<InMemoryCache> {
let cache = InMemoryCache::builder()
.resource_types(
ResourceType::GUILD
| ResourceType::CHANNEL
| ResourceType::ROLE
| ResourceType::USER
// | ResourceType::MEMBER
)
.build();
Arc::new(cache)
}

32
gateway/src/util.rs Normal file
View File

@ -0,0 +1,32 @@
use std::time::Duration;
use twilight_gateway_queue::Queue;
#[derive(Debug, Clone)]
pub struct RedisQueue {
pub client: redis::Client,
pub concurrency: u64
}
impl Queue for RedisQueue {
fn request<'a>(&'a self, shard_id: [u64; 2]) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + 'a>> {
Box::pin(request_inner(self.client.clone(), self.concurrency, *shard_id.first().unwrap()))
}
}
async fn request_inner(client: redis::Client, concurrency: u64, shard_id: u64) {
let mut conn = client.get_async_connection().await.unwrap();
let key = format!("pluralkit:identify:{}", (shard_id % concurrency));
let mut cmd = redis::cmd("SET");
cmd.arg(key).arg("1").arg("EX").arg(6i8).arg("NX");
loop {
let done = cmd.clone().query_async::<redis::aio::Connection, Option<String>>(&mut conn).await;
if done.unwrap().is_some() {
return
}
tokio::time::sleep(Duration::from_millis(500)).await;
}
}