feat(gateway): initial commit
This commit is contained in:
		
							
								
								
									
										3
									
								
								gateway/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								gateway/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
/target
 | 
			
		||||
 | 
			
		||||
config.json
 | 
			
		||||
							
								
								
									
										1943
									
								
								gateway/Cargo.lock
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						
									
										1943
									
								
								gateway/Cargo.lock
									
									
									
										generated
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										33
									
								
								gateway/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								gateway/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,33 @@
 | 
			
		||||
[package]
 | 
			
		||||
name = "pluralkit"
 | 
			
		||||
version = "0.1.0"
 | 
			
		||||
edition = "2021"
 | 
			
		||||
 | 
			
		||||
[dependencies]
 | 
			
		||||
# Infrastructure
 | 
			
		||||
anyhow = "1"
 | 
			
		||||
config = { version = "0.11", default-features = false, features = ["json"] }
 | 
			
		||||
futures = "0.3"
 | 
			
		||||
serde = { version = "1", features = ["derive"] }
 | 
			
		||||
tracing = "0.1"
 | 
			
		||||
tracing-subscriber = "0.3"
 | 
			
		||||
tokio = { version = "1", features = ["full"] }
 | 
			
		||||
tokio-stream = { version = "0.1", features = ["sync"] }
 | 
			
		||||
 | 
			
		||||
procfs = "0.12.0"
 | 
			
		||||
libc = "0.2.122"
 | 
			
		||||
 | 
			
		||||
# Twilight
 | 
			
		||||
twilight-cache-inmemory = "0.10.0"
 | 
			
		||||
twilight-gateway = "0.10.0"
 | 
			
		||||
twilight-gateway-queue = "0.10.0"
 | 
			
		||||
twilight-http = "0.10.0"
 | 
			
		||||
twilight-model = "0.10.0"
 | 
			
		||||
 | 
			
		||||
# Database
 | 
			
		||||
deadpool = "0.9"
 | 
			
		||||
deadpool-postgres = "0.10"
 | 
			
		||||
postgres-types = { version = "0.2", features = ["derive"] }
 | 
			
		||||
tokio-postgres = { version = "0.7", features = ["with-serde_json-1", "with-uuid-0_8"] }
 | 
			
		||||
 | 
			
		||||
redis = { version = "0.21.5", features = ["aio", "tokio-comp"] }
 | 
			
		||||
							
								
								
									
										22
									
								
								gateway/src/config.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								gateway/src/config.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
			
		||||
use config::{self, Config};
 | 
			
		||||
use serde::Deserialize;
 | 
			
		||||
 | 
			
		||||
#[derive(Deserialize, Debug)]
 | 
			
		||||
pub struct BotConfig {
 | 
			
		||||
    pub token: String,
 | 
			
		||||
    pub max_concurrency: u64,
 | 
			
		||||
    pub database: String,
 | 
			
		||||
    pub redis_addr: String,
 | 
			
		||||
    pub redis_gateway_queue_addr: String,
 | 
			
		||||
    pub shard_count: u64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn load_config() -> BotConfig {
 | 
			
		||||
    let mut settings = Config::default();
 | 
			
		||||
    settings.merge(config::File::with_name("config")).unwrap();
 | 
			
		||||
    settings
 | 
			
		||||
        .merge(config::Environment::with_prefix("PluralKit"))
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
    settings.try_into::<BotConfig>().unwrap()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										144
									
								
								gateway/src/db.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								gateway/src/db.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,144 @@
 | 
			
		||||
use std::{str::FromStr, time::SystemTime};
 | 
			
		||||
 | 
			
		||||
use crate::config;
 | 
			
		||||
use anyhow::Context;
 | 
			
		||||
use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
 | 
			
		||||
use tokio_postgres::{self, types::FromSql, Row};
 | 
			
		||||
use twilight_model::id::Id;
 | 
			
		||||
use twilight_model::id::marker::ChannelMarker;
 | 
			
		||||
 | 
			
		||||
pub async fn init_db(cfg: &config::BotConfig) -> anyhow::Result<Pool> {
 | 
			
		||||
    let pg_config = tokio_postgres::config::Config::from_str(&cfg.database)
 | 
			
		||||
        .context("could not parse connection string")?;
 | 
			
		||||
 | 
			
		||||
    let mgr_config = ManagerConfig {
 | 
			
		||||
        recycling_method: RecyclingMethod::Fast,
 | 
			
		||||
    };
 | 
			
		||||
    let mgr = Manager::from_config(pg_config, tokio_postgres::NoTls, mgr_config);
 | 
			
		||||
    let pool = Pool::builder(mgr)
 | 
			
		||||
        .max_size(16)
 | 
			
		||||
        .build()
 | 
			
		||||
        .context("could not initialize pool")?;
 | 
			
		||||
    Ok(pool)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub async fn get_message_context(
 | 
			
		||||
    pool: &Pool,
 | 
			
		||||
    account_id: u64,
 | 
			
		||||
    guild_id: u64,
 | 
			
		||||
    channel_id: u64,
 | 
			
		||||
) -> anyhow::Result<Option<MessageContext>> {
 | 
			
		||||
    let client = pool.get().await?;
 | 
			
		||||
    let stmt = client
 | 
			
		||||
        .prepare_cached("select * from message_context($1, $2, $3)")
 | 
			
		||||
        .await?;
 | 
			
		||||
    let result = client
 | 
			
		||||
        .query_opt(
 | 
			
		||||
            &stmt,
 | 
			
		||||
            &[
 | 
			
		||||
                &(account_id as i64),
 | 
			
		||||
                &(guild_id as i64),
 | 
			
		||||
                &(channel_id as i64),
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
        .await
 | 
			
		||||
        .context("could not fetch message context")?;
 | 
			
		||||
 | 
			
		||||
    Ok(result.map(parse_message_context))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub async fn get_proxy_members(
 | 
			
		||||
    pool: &Pool,
 | 
			
		||||
    account_id: u64,
 | 
			
		||||
    guild_id: u64,
 | 
			
		||||
) -> anyhow::Result<Vec<ProxyMember>> {
 | 
			
		||||
    let client = pool.get().await?;
 | 
			
		||||
    let stmt = client
 | 
			
		||||
        .prepare_cached("select * from proxy_members($1, $2)")
 | 
			
		||||
        .await?;
 | 
			
		||||
    let result = client
 | 
			
		||||
        .query(&stmt, &[&(account_id as i64), &(guild_id as i64)])
 | 
			
		||||
        .await
 | 
			
		||||
        .context("could not fetch proxy members")?;
 | 
			
		||||
 | 
			
		||||
    Ok(result.into_iter().map(parse_proxy_member).collect())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct MessageContext {
 | 
			
		||||
    pub system_id: Option<i32>,
 | 
			
		||||
    pub log_channel: Option<Id<ChannelMarker>>,
 | 
			
		||||
    pub in_blacklist: bool,
 | 
			
		||||
    pub in_log_blacklist: bool,
 | 
			
		||||
    pub log_cleanup_enabled: bool,
 | 
			
		||||
    pub proxy_enabled: bool,
 | 
			
		||||
    pub last_switch: Option<i32>,
 | 
			
		||||
    pub last_switch_members: Option<Vec<i32>>,
 | 
			
		||||
    pub last_switch_timestamp: Option<SystemTime>,
 | 
			
		||||
    pub system_tag: Option<String>,
 | 
			
		||||
    pub system_guild_tag: Option<String>,
 | 
			
		||||
    pub tag_enabled: bool,
 | 
			
		||||
    pub system_avatar: Option<String>,
 | 
			
		||||
    pub allow_autoproxy: bool,
 | 
			
		||||
    pub latch_timeout: Option<i32>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, FromSql)]
 | 
			
		||||
#[postgres(name = "proxy_tag")]
 | 
			
		||||
pub struct ProxyTag {
 | 
			
		||||
    pub prefix: Option<String>,
 | 
			
		||||
    pub suffix: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct ProxyMember {
 | 
			
		||||
    pub id: i32,
 | 
			
		||||
    pub proxy_tags: Vec<ProxyTag>,
 | 
			
		||||
    pub keep_proxy: bool,
 | 
			
		||||
    pub server_name: Option<String>,
 | 
			
		||||
    pub display_name: Option<String>,
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub server_avatar: Option<String>,
 | 
			
		||||
    pub avatar: Option<String>,
 | 
			
		||||
    pub allow_autoproxy: bool,
 | 
			
		||||
    pub color: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn parse_message_context(row: Row) -> MessageContext {
 | 
			
		||||
    MessageContext {
 | 
			
		||||
        system_id: row.get("system_id"),
 | 
			
		||||
        log_channel: to_channel_id_opt(row.get("log_channel")),
 | 
			
		||||
        in_blacklist: row.get::<_, Option<_>>("in_blacklist").unwrap_or(false),
 | 
			
		||||
        in_log_blacklist: row.get::<_, Option<_>>("in_log_blacklist").unwrap_or(false),
 | 
			
		||||
        log_cleanup_enabled: row.get("log_cleanup_enabled"),
 | 
			
		||||
        proxy_enabled: row.get("proxy_enabled"),
 | 
			
		||||
        last_switch: row.get("last_switch"),
 | 
			
		||||
        last_switch_members: row.get("last_switch_members"),
 | 
			
		||||
        last_switch_timestamp: row.get("last_switch_timestamp"),
 | 
			
		||||
        system_tag: row.get("system_tag"),
 | 
			
		||||
        system_guild_tag: row.get("system_guild_tag"),
 | 
			
		||||
        tag_enabled: row.get("tag_enabled"),
 | 
			
		||||
        system_avatar: row.get("system_avatar"),
 | 
			
		||||
        allow_autoproxy: row.get("allow_autoproxy"),
 | 
			
		||||
        latch_timeout: row.get("latch_timeout"),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn parse_proxy_member(row: Row) -> ProxyMember {
 | 
			
		||||
    ProxyMember {
 | 
			
		||||
        id: row.get("id"),
 | 
			
		||||
        proxy_tags: row.get("proxy_tags"),
 | 
			
		||||
        keep_proxy: row.get("keep_proxy"),
 | 
			
		||||
        server_name: row.get("server_name"),
 | 
			
		||||
        display_name: row.get("display_name"),
 | 
			
		||||
        name: row.get("name"),
 | 
			
		||||
        server_avatar: row.get("server_avatar"),
 | 
			
		||||
        avatar: row.get("avatar"),
 | 
			
		||||
        allow_autoproxy: row.get("allow_autoproxy"),
 | 
			
		||||
        color: row.get("color"),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn to_channel_id_opt(id: Option<i64>) -> Option<Id<ChannelMarker>> {
 | 
			
		||||
    id.and_then(|x| Some(Id::<ChannelMarker>::new(x as u64)))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										235
									
								
								gateway/src/main.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										235
									
								
								gateway/src/main.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,235 @@
 | 
			
		||||
use deadpool_postgres::Pool;
 | 
			
		||||
use futures::StreamExt;
 | 
			
		||||
use redis::AsyncCommands;
 | 
			
		||||
use std::{sync::Arc, env};
 | 
			
		||||
use tracing::{error, info, Level};
 | 
			
		||||
 | 
			
		||||
use twilight_cache_inmemory::{InMemoryCache, ResourceType};
 | 
			
		||||
use twilight_gateway::{
 | 
			
		||||
    cluster::{Events, ShardScheme},
 | 
			
		||||
    Cluster, Event, EventTypeFlags, Intents,
 | 
			
		||||
};
 | 
			
		||||
use twilight_http::Client as HttpClient;
 | 
			
		||||
 | 
			
		||||
mod config;
 | 
			
		||||
mod db;
 | 
			
		||||
mod util;
 | 
			
		||||
 | 
			
		||||
#[tokio::main]
 | 
			
		||||
async fn main() -> anyhow::Result<()> {
 | 
			
		||||
    init_tracing();
 | 
			
		||||
    info!("starting...");
 | 
			
		||||
 | 
			
		||||
    let cfg = config::load_config();
 | 
			
		||||
 | 
			
		||||
    let http = Arc::new(HttpClient::new(cfg.token.clone()));
 | 
			
		||||
    let rconn = redis::Client::open(cfg.redis_addr.clone()).unwrap();
 | 
			
		||||
    let (_cluster, events) = init_gateway(&cfg, rconn.clone()).await?;
 | 
			
		||||
    let cache = init_cache();
 | 
			
		||||
    let db = db::init_db(&cfg).await?;
 | 
			
		||||
 | 
			
		||||
    run(http, events, cache, db, rconn).await?;
 | 
			
		||||
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn run(
 | 
			
		||||
    http: Arc<HttpClient>,
 | 
			
		||||
    mut events: Events,
 | 
			
		||||
    cache: Arc<InMemoryCache>,
 | 
			
		||||
    db: Pool,
 | 
			
		||||
    rconn: redis::Client,
 | 
			
		||||
) -> anyhow::Result<()> {
 | 
			
		||||
    while let Some((shard_id, event)) = events.next().await {
 | 
			
		||||
 | 
			
		||||
        cache.update(&event);
 | 
			
		||||
 | 
			
		||||
        let http_cloned = http.clone();
 | 
			
		||||
        let cache_cloned = cache.clone();
 | 
			
		||||
        let db_cloned = db.clone();
 | 
			
		||||
        let rconn_cloned = rconn.clone();
 | 
			
		||||
        
 | 
			
		||||
        tokio::spawn(async move {
 | 
			
		||||
            let result = handle_event(
 | 
			
		||||
                shard_id,
 | 
			
		||||
                event,
 | 
			
		||||
                http_cloned,
 | 
			
		||||
                cache_cloned,
 | 
			
		||||
                db_cloned,
 | 
			
		||||
                rconn_cloned
 | 
			
		||||
            )
 | 
			
		||||
            .await;
 | 
			
		||||
            if let Err(e) = result {
 | 
			
		||||
                error!("error in event handler: {:?}", e);
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn handle_event<'a>(
 | 
			
		||||
    shard_id: u64,
 | 
			
		||||
    event: Event,
 | 
			
		||||
    http: Arc<HttpClient>,
 | 
			
		||||
    cache: Arc<InMemoryCache>,
 | 
			
		||||
    _db: Pool,
 | 
			
		||||
    rconn: redis::Client
 | 
			
		||||
) -> anyhow::Result<()> {
 | 
			
		||||
    match event {
 | 
			
		||||
        Event::GatewayInvalidateSession(resumable) => {
 | 
			
		||||
            info!("shard {} session invalidated, resumable? {}", shard_id, resumable);
 | 
			
		||||
        }
 | 
			
		||||
        Event::ShardConnected(_) => {
 | 
			
		||||
            info!("shard {} connected", shard_id);
 | 
			
		||||
        }
 | 
			
		||||
        Event::ShardDisconnected(info) => {
 | 
			
		||||
            info!("shard {} disconnected, code: {:?}, reason: {:?}", shard_id, info.code, info.reason);
 | 
			
		||||
        }
 | 
			
		||||
        Event::ShardPayload(payload) => {
 | 
			
		||||
            let mut conn = rconn.get_async_connection().await?;
 | 
			
		||||
            conn.publish::<&str, Vec<u8>, i32>("evt", payload.bytes).await?;
 | 
			
		||||
        }
 | 
			
		||||
        Event::MessageCreate(msg) => {
 | 
			
		||||
            if msg.content == "pkt;test" {
 | 
			
		||||
                // let message_context = db::get_message_context(
 | 
			
		||||
                //     &db,
 | 
			
		||||
                //     msg.author.id.get(),
 | 
			
		||||
                //     msg.guild_id.map(|x| x.get()).unwrap_or(0),
 | 
			
		||||
                //     msg.channel_id.get(),
 | 
			
		||||
                // )
 | 
			
		||||
                // .await?;
 | 
			
		||||
 | 
			
		||||
                // let content = format!("message context:\n```\n{:#?}\n```", message_context);
 | 
			
		||||
                // http.create_message(msg.channel_id)
 | 
			
		||||
                //     .reply(msg.id)
 | 
			
		||||
                //     .content(&content)?
 | 
			
		||||
                //     .exec()
 | 
			
		||||
                //     .await?;
 | 
			
		||||
 | 
			
		||||
                // let proxy_members = db::get_proxy_members(
 | 
			
		||||
                //     &db,
 | 
			
		||||
                //     msg.author.id.get(),
 | 
			
		||||
                //     msg.guild_id.map(|x| x.get()).unwrap_or(0),
 | 
			
		||||
                // )
 | 
			
		||||
                // .await?;
 | 
			
		||||
 | 
			
		||||
                // let content = format!("proxy members:\n```\n{:#?}\n```", proxy_members);
 | 
			
		||||
                // info!("{}", content);
 | 
			
		||||
                // http.create_message(msg.channel_id)
 | 
			
		||||
                //     .reply(msg.id)
 | 
			
		||||
                //     .content(&content)?
 | 
			
		||||
                //     .exec()
 | 
			
		||||
                //     .await?;
 | 
			
		||||
 | 
			
		||||
                let cache_stats = cache.stats();
 | 
			
		||||
 | 
			
		||||
                let pid = unsafe { libc::getpid() };
 | 
			
		||||
                let pagesize = {
 | 
			
		||||
                    unsafe {
 | 
			
		||||
                        libc::sysconf(libc::_SC_PAGESIZE)
 | 
			
		||||
                    }
 | 
			
		||||
                };
 | 
			
		||||
                
 | 
			
		||||
                let p = procfs::process::Process::new(pid)?;
 | 
			
		||||
                let content = format!(
 | 
			
		||||
                    "[rust]\nguilds:{}\nchannels:{}\nroles:{}\nusers:{}\nmembers:{}\n\nmemory usage: {}",
 | 
			
		||||
                    cache_stats.guilds(),
 | 
			
		||||
                    cache_stats.channels(),
 | 
			
		||||
                    cache_stats.roles(),
 | 
			
		||||
                    cache_stats.users(),
 | 
			
		||||
                    cache_stats.members(),
 | 
			
		||||
                    p.stat.rss * pagesize
 | 
			
		||||
                );
 | 
			
		||||
 | 
			
		||||
                http.create_message(msg.channel_id)
 | 
			
		||||
                .reply(msg.id)
 | 
			
		||||
                .content(&content)?
 | 
			
		||||
                .exec()
 | 
			
		||||
                .await?;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        _ => {}
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn init_tracing() {
 | 
			
		||||
    tracing_subscriber::fmt()
 | 
			
		||||
        .with_max_level(Level::INFO)
 | 
			
		||||
        .init();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn init_gateway(
 | 
			
		||||
    cfg: &config::BotConfig,
 | 
			
		||||
    rconn: redis::Client,
 | 
			
		||||
) -> anyhow::Result<(Arc<Cluster>, Events)> {
 | 
			
		||||
    let shard_count = cfg.shard_count.clone();
 | 
			
		||||
 | 
			
		||||
    let scheme: ShardScheme;
 | 
			
		||||
 | 
			
		||||
    if shard_count < 16 {
 | 
			
		||||
        scheme = ShardScheme::Auto;
 | 
			
		||||
    } else {
 | 
			
		||||
        let cluster_id = env::var("NOMAD_ALLOC_INDEX")?.parse::<u64>().unwrap();
 | 
			
		||||
        let first_shard_id = 16 * cluster_id;
 | 
			
		||||
 | 
			
		||||
        scheme = ShardScheme::try_from((first_shard_id..first_shard_id+16, shard_count)).unwrap();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let queue = util::RedisQueue {
 | 
			
		||||
        client: rconn.clone(),
 | 
			
		||||
        concurrency: cfg.max_concurrency.clone()
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    let (cluster, events) = Cluster::builder(
 | 
			
		||||
        cfg.token.clone(),
 | 
			
		||||
        Intents::GUILDS
 | 
			
		||||
        | Intents::DIRECT_MESSAGES
 | 
			
		||||
        | Intents::DIRECT_MESSAGE_REACTIONS
 | 
			
		||||
        | Intents::GUILD_EMOJIS_AND_STICKERS
 | 
			
		||||
        | Intents::GUILD_MESSAGES
 | 
			
		||||
        | Intents::GUILD_MESSAGE_REACTIONS
 | 
			
		||||
        | Intents::GUILD_WEBHOOKS
 | 
			
		||||
        | Intents::MESSAGE_CONTENT
 | 
			
		||||
    )
 | 
			
		||||
        .shard_scheme(scheme)
 | 
			
		||||
        .event_types(
 | 
			
		||||
            // EventTypeFlags::all()
 | 
			
		||||
                EventTypeFlags::READY
 | 
			
		||||
              | EventTypeFlags::GATEWAY_INVALIDATE_SESSION
 | 
			
		||||
              | EventTypeFlags::GATEWAY_RECONNECT
 | 
			
		||||
              | EventTypeFlags::SHARD_PAYLOAD
 | 
			
		||||
              | EventTypeFlags::SHARD_CONNECTED
 | 
			
		||||
              | EventTypeFlags::SHARD_DISCONNECTED
 | 
			
		||||
              | EventTypeFlags::GUILD_CREATE
 | 
			
		||||
              | EventTypeFlags::CHANNEL_CREATE
 | 
			
		||||
              | EventTypeFlags::MESSAGE_CREATE
 | 
			
		||||
            // | EventTypeFlags::MESSAGE_UPDATE
 | 
			
		||||
        )
 | 
			
		||||
        .queue(Arc::new(queue))
 | 
			
		||||
        .build()
 | 
			
		||||
        .await?;
 | 
			
		||||
    let cluster = Arc::new(cluster);
 | 
			
		||||
 | 
			
		||||
    let cluster_spawn = Arc::clone(&cluster);
 | 
			
		||||
    tokio::spawn(async move {
 | 
			
		||||
        cluster_spawn.up().await;
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    Ok((cluster, events))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn init_cache() -> Arc<InMemoryCache> {
 | 
			
		||||
    let cache = InMemoryCache::builder()
 | 
			
		||||
        .resource_types(
 | 
			
		||||
              ResourceType::GUILD
 | 
			
		||||
            | ResourceType::CHANNEL
 | 
			
		||||
            | ResourceType::ROLE
 | 
			
		||||
            | ResourceType::USER
 | 
			
		||||
            // | ResourceType::MEMBER
 | 
			
		||||
        )
 | 
			
		||||
        .build();
 | 
			
		||||
    Arc::new(cache)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										32
									
								
								gateway/src/util.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								gateway/src/util.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,32 @@
 | 
			
		||||
use std::time::Duration;
 | 
			
		||||
 | 
			
		||||
use twilight_gateway_queue::Queue;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
pub struct RedisQueue {
 | 
			
		||||
    pub client: redis::Client,
 | 
			
		||||
    pub concurrency: u64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Queue for RedisQueue {
 | 
			
		||||
    fn request<'a>(&'a self, shard_id: [u64; 2]) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + 'a>> {
 | 
			
		||||
        Box::pin(request_inner(self.client.clone(), self.concurrency, *shard_id.first().unwrap()))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn request_inner(client: redis::Client, concurrency: u64, shard_id: u64) {
 | 
			
		||||
    let mut conn = client.get_async_connection().await.unwrap();
 | 
			
		||||
    let key = format!("pluralkit:identify:{}", (shard_id % concurrency));
 | 
			
		||||
 | 
			
		||||
    let mut cmd = redis::cmd("SET");
 | 
			
		||||
    cmd.arg(key).arg("1").arg("EX").arg(6i8).arg("NX");
 | 
			
		||||
 | 
			
		||||
    loop {
 | 
			
		||||
        let done = cmd.clone().query_async::<redis::aio::Connection, Option<String>>(&mut conn).await;
 | 
			
		||||
        if done.unwrap().is_some() {
 | 
			
		||||
            return
 | 
			
		||||
        }
 | 
			
		||||
        tokio::time::sleep(Duration::from_millis(500)).await;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user