use std::time::{Duration, SystemTime}; use axum::{ extract::State, http::Request, middleware::{FromFnLayer, Next}, response::Response, }; use fred::{pool::RedisPool, prelude::LuaInterface, types::ReconnectPolicy, util::sha1_hash}; use http::{HeaderValue, StatusCode}; use metrics::increment_counter; use tracing::{debug, error, info, warn}; use crate::util::{header_or_unknown, json_err}; const LUA_SCRIPT: &str = include_str!("ratelimit.lua"); lazy_static::lazy_static! { static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT); } // todo lol const TOKEN2: &'static str = "h"; // this is awful but it works pub fn ratelimiter(f: F) -> FromFnLayer, T> { let redis = libpk::config.api.ratelimit_redis_addr.as_ref().map(|val| { let r = fred::pool::RedisPool::new( fred::types::RedisConfig::from_url_centralized(val.as_ref()) .expect("redis url is invalid"), 10, ) .expect("failed to connect to redis"); let handle = r.connect(Some(ReconnectPolicy::default())); tokio::spawn(async move { handle }); let rscript = r.clone(); tokio::spawn(async move { if let Ok(()) = rscript.wait_for_connect().await { match rscript.script_load(LUA_SCRIPT).await { Ok(_) => info!("connected to redis for request rate limiting"), Err(err) => error!("could not load redis script: {}", err), } } else { error!("could not wait for connection to load redis script!"); } }); r }); if redis.is_none() { warn!("running without request rate limiting!"); } axum::middleware::from_fn_with_state(redis, f) } pub async fn do_request_ratelimited( State(redis): State>, request: Request, next: Next, ) -> Response { if let Some(redis) = redis { let headers = request.headers().clone(); let source_ip = header_or_unknown(headers.get("Fly-Client-IP")); let (rl_key, rate) = if let Some(header) = request.headers().clone().get("X-PluralKit-App") { if header == TOKEN2 { ("token2", 20) } else { (source_ip, 2) } } else { (source_ip, 2) }; let burst = 5; let period = 1; // seconds // todo: make this static // though even if it's not static, it's probably cheaper than sending the entire script to redis every time let scriptsha = sha1_hash(&LUA_SCRIPT); // local rate_limit_key = KEYS[1] // local burst = ARGV[1] // local rate = ARGV[2] // local period = ARGV[3] // return {remaining, tostring(retry_after), reset_after} let resp = redis .evalsha::<(i32, String, u64), String, Vec<&str>, Vec>( scriptsha, vec![rl_key], vec![burst, rate, period], ) .await; match resp { Ok((mut remaining, retry_after, reset_after)) => { // redis's lua doesn't support returning floats let retry_after: f64 = retry_after .parse() .expect("got something that isn't a f64 from redis"); let mut response = if remaining > 0 { next.run(request).await } else { let retry_after = (retry_after * 1_000_f64).ceil() as u64; debug!("ratelimited request from {rl_key}, retry_after={retry_after}"); increment_counter!("pk_http_requests_ratelimited"); json_err( StatusCode::TOO_MANY_REQUESTS, format!( r#"{{"message":"429: too many requests","retry_after":{retry_after},"code":0}}"#, ), ) }; // the redis script puts burst in remaining for ??? some reason remaining -= burst - rate; let reset_time = SystemTime::now() .checked_add(Duration::from_secs(reset_after)) .expect("invalid timestamp") .duration_since(std::time::UNIX_EPOCH) .expect("invalid duration") .as_secs(); let headers = response.headers_mut(); headers.insert( "X-RateLimit-Limit", HeaderValue::from_str(format!("{}", rate).as_str()) .expect("invalid header value"), ); headers.insert( "X-RateLimit-Remaining", HeaderValue::from_str(format!("{}", remaining).as_str()) .expect("invalid header value"), ); headers.insert( "X-RateLimit-Reset", HeaderValue::from_str(format!("{}", reset_time).as_str()) .expect("invalid header value"), ); return response; } Err(err) => { tracing::error!("error getting ratelimit info: {}", err); return json_err( StatusCode::INTERNAL_SERVER_ERROR, r#"{"message": "500: internal server error", "code": 0}"#.to_string(), ); } } } next.run(request).await }