feat(api): init rust rewrite

This commit is contained in:
spiral
2023-02-15 19:27:36 -05:00
parent 5da3c84bce
commit 5440386969
24 changed files with 2443 additions and 586 deletions

18
services/api/Cargo.toml Normal file
View File

@@ -0,0 +1,18 @@
[package]
name = "api"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.69"
axum = "0.6.4"
fred = { version = "5.2.0", default-features = false, features = ["tracing", "pool-prefer-active"] }
http = "0.2.8"
hyper-reverse-proxy = "0.5.1"
lazy_static = "1.4.0"
libpk = { path = "../../lib/libpk" }
tokio = { version = "1.25.0", features = ["full"] }
tower = "0.4.13"
tracing = "0.1.37"

27
services/api/Dockerfile Normal file
View File

@@ -0,0 +1,27 @@
FROM alpine:latest AS builder
WORKDIR /build
RUN apk add rustup build-base
# todo: arm64 target
RUN rustup-init --default-host x86_64-unknown-linux-musl --default-toolchain stable --profile default -y
COPY Cargo.toml /build/
COPY Cargo.lock /build/
# todo: fetch dependencies first to cache
# RUN cargo fetch
COPY lib/libpk /build/lib/libpk
COPY services/api/ /build/services/api
RUN source "$HOME/.cargo/env" && RUSTFLAGS='-C link-arg=-s' cargo build --bin api --release --target x86_64-unknown-linux-musl
RUN ls /build/target
RUN ls /build/target/release
FROM alpine:latest
COPY --from=builder /build/target/x86_64-unknown-linux-musl/release/api /bin/api
ENTRYPOINT [ "/bin/api" ]

85
services/api/src/main.rs Normal file
View File

@@ -0,0 +1,85 @@
use axum::{
routing::{delete, get, patch, post},
Router,
};
use tracing::info;
mod middleware;
mod util;
// this function is manually formatted for easier legibility of routes
#[rustfmt::skip]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
libpk::init_logging("api")?;
info!("hello world");
// processed upside down (???) so we have to put middleware at the end
let app = Router::new()
.route("/v2/systems/:system_id", get(util::rproxy))
.route("/v2/systems/:system_id", patch(util::rproxy))
.route("/v2/systems/:system_id/settings", get(util::rproxy))
.route("/v2/systems/:system_id/settings", patch(util::rproxy))
.route("/v2/systems/:system_id/members", get(util::rproxy))
.route("/v2/members", post(util::rproxy))
.route("/v2/members/:member_id", get(util::rproxy))
.route("/v2/members/:member_id", patch(util::rproxy))
.route("/v2/members/:member_id", delete(util::rproxy))
.route("/v2/systems/:system_id/groups", get(util::rproxy))
.route("/v2/groups", post(util::rproxy))
.route("/v2/groups/:group_id", get(util::rproxy))
.route("/v2/groups/:group_id", patch(util::rproxy))
.route("/v2/groups/:group_id", delete(util::rproxy))
.route("/v2/groups/:group_id/members", get(util::rproxy))
.route("/v2/groups/:group_id/members/add", post(util::rproxy))
.route("/v2/groups/:group_id/members/remove", post(util::rproxy))
.route("/v2/groups/:group_id/members/overwrite", post(util::rproxy))
.route("/v2/members/:member_id/groups", get(util::rproxy))
.route("/v2/members/:member_id/groups/add", post(util::rproxy))
.route("/v2/members/:member_id/groups/remove", post(util::rproxy))
.route("/v2/members/:member_id/groups/overwrite", post(util::rproxy))
.route("/v2/systems/:system_id/switches", get(util::rproxy))
.route("/v2/systems/:system_id/switches", post(util::rproxy))
.route("/v2/systems/:system_id/fronters", get(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", get(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", patch(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id/members", patch(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", delete(util::rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", get(util::rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", patch(util::rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", get(util::rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", patch(util::rproxy))
.route("/v2/messages/:message_id", get(util::rproxy))
.route("/private/meta", get(util::rproxy))
.route("/private/bulk_privacy/member", post(util::rproxy))
.route("/private/bulk_privacy/group", post(util::rproxy))
.route("/private/discord/callback", post(util::rproxy))
.route("/v2/systems/:system_id/oembed.json", get(util::rproxy))
.route("/v2/members/:member_id/oembed.json", get(util::rproxy))
.route("/v2/groups/:group_id/oembed.json", get(util::rproxy))
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
.layer(axum::middleware::from_fn(middleware::logger))
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
.layer(axum::middleware::from_fn(middleware::cors))
.route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }));
let addr: &str = libpk::config.api.addr.as_ref();
axum::Server::bind(&addr.parse()?)
.serve(app.into_make_service())
.await?;
Ok(())
}

View File

@@ -0,0 +1,26 @@
use axum::{
http::{HeaderMap, HeaderValue, Method, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
#[rustfmt::skip]
fn add_cors_headers(headers: &mut HeaderMap) {
headers.append("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
headers.append("Access-Control-Allow-Methods", HeaderValue::from_static("*"));
headers.append("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
headers.append("Access-Control-Allow-Headers", HeaderValue::from_static("Content-Type, Authorization, sentry-trace, User-Agent"));
headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
}
pub async fn cors<B>(request: Request<B>, next: Next<B>) -> Response {
let mut response = if request.method() == Method::OPTIONS {
StatusCode::OK.into_response()
} else {
next.run(request).await
};
add_cors_headers(response.headers_mut());
response
}

View File

@@ -0,0 +1,61 @@
use axum::{
extract::MatchedPath,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use crate::util::header_or_unknown;
fn is_trying_to_use_v1_path_on_v2(path: &str) -> bool {
path.starts_with("/v2/s/")
|| path.starts_with("/v2/m/")
|| path.starts_with("/v2/a/")
|| path.starts_with("/v2/msg/")
|| path == "/v2/s"
|| path == "/v2/m"
}
pub async fn ignore_invalid_routes<B>(request: Request<B>, next: Next<B>) -> Response {
let path = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
if request.uri().path().starts_with("/v1") {
(
StatusCode::GONE,
r#"{"message":"Unsupported API version","code":0}"#,
)
.into_response()
} else if is_trying_to_use_v1_path_on_v2(request.uri().path()) {
(
StatusCode::BAD_REQUEST,
r#"{"message":"Invalid path for API version","code":0}"#,
)
.into_response()
}
// we ignored v1 routes earlier, now let's ignore all non-v2 routes
else if !request.uri().clone().path().starts_with("/v2") {
return (
StatusCode::BAD_REQUEST,
r#"{"message":"Unsupported API version","code":0}"#,
)
.into_response();
} else if path == "unknown" {
// current prod api responds with 404 with empty body to invalid endpoints
// just doing that here as well but i'm not sure if it's the correct behaviour
return StatusCode::NOT_FOUND.into_response();
}
// yes, technically because of how we parse headers this will break for user-agents literally set to "unknown"
// but "unknown" isn't really a valid user-agent
else if user_agent == "unknown" {
// please set a valid user-agent
return StatusCode::FORBIDDEN.into_response();
} else {
next.run(request).await
}
}

View File

@@ -0,0 +1,41 @@
use std::time::Instant;
use axum::{extract::MatchedPath, http::Request, middleware::Next, response::Response};
use tracing::{info, span, Instrument, Level};
use crate::util::header_or_unknown;
pub async fn logger<B>(request: Request<B>, next: Next<B>) -> Response {
let method = request.method().clone();
let request_id = header_or_unknown(request.headers().get("Fly-Request-Id"));
let remote_ip = header_or_unknown(request.headers().get("Fly-Client-IP"));
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
let path = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
// todo: prometheus metrics
let request_id_span = span!(
Level::INFO,
"request",
request_id,
remote_ip,
method = method.as_str(),
path,
user_agent
);
let start = Instant::now();
let response = next.run(request).instrument(request_id_span).await;
let elapsed = start.elapsed().as_millis();
info!("handled request for {} {} in {}ms", method, path, elapsed);
response
}

View File

@@ -0,0 +1,11 @@
mod cors;
pub use cors::cors;
mod logger;
pub use logger::logger;
mod ignore_invalid_routes;
pub use ignore_invalid_routes::ignore_invalid_routes;
pub mod ratelimit;

View File

@@ -0,0 +1,57 @@
-- Copyright (c) 2017 Pavel Pravosud
-- https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua
local rate_limit_key = KEYS[1]
local burst = ARGV[1]
local rate = ARGV[2]
local period = ARGV[3]
-- local cost = tonumber(ARGV[4])
local cost = 1
local emission_interval = period / rate
local increment = emission_interval * cost
local burst_offset = emission_interval * burst
-- redis returns time as an array containing two integers: seconds of the epoch
-- time (10 digits) and microseconds (6 digits). for convenience we need to
-- convert them to a floating point number. the resulting number is 16 digits,
-- bordering on the limits of a 64-bit double-precision floating point number.
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
local jan_1_2017 = 1483228800
local now = redis.call("TIME")
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
local tat = redis.call("GET", rate_limit_key)
if not tat then
tat = now
else
tat = tonumber(tat)
end
tat = math.max(tat, now)
local new_tat = tat + increment
local allow_at = new_tat - burst_offset
local diff = now - allow_at
local remaining = diff / emission_interval
if remaining < 0 then
local reset_after = tat - now
local retry_after = diff * -1
return {
0, -- remaining
retry_after,
reset_after,
}
end
local reset_after = new_tat - now
if reset_after > 0 then
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
end
local retry_after = -1
return {remaining, retry_after, reset_after}

View File

@@ -0,0 +1,154 @@
use std::time::{Duration, SystemTime};
use axum::{
extract::State,
http::Request,
middleware::{FromFnLayer, Next},
response::Response,
};
use fred::{pool::RedisPool, prelude::LuaInterface, types::ReconnectPolicy, util::sha1_hash};
use http::{HeaderValue, StatusCode};
use tracing::{error, info, warn};
use crate::util::{header_or_unknown, json_err};
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
lazy_static::lazy_static! {
static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT);
}
// todo lol
const TOKEN2: &'static str = "h";
// this is awful but it works
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
let redis = libpk::config.api.ratelimit_redis_addr.as_ref().map(|val| {
let r = fred::pool::RedisPool::new(
fred::types::RedisConfig::from_url_centralized(val.as_ref())
.expect("redis url is invalid"),
10,
)
.expect("failed to connect to redis");
let handle = r.connect(Some(ReconnectPolicy::default()));
tokio::spawn(async move { handle });
let rscript = r.clone();
tokio::spawn(async move {
if let Ok(()) = rscript.wait_for_connect().await {
match rscript.script_load(LUA_SCRIPT).await {
Ok(_) => info!("connected to redis for request rate limiting"),
Err(err) => error!("could not load redis script: {}", err),
}
} else {
error!("could not wait for connection to load redis script!");
}
});
r
});
if redis.is_none() {
warn!("running without request rate limiting!");
}
axum::middleware::from_fn_with_state(redis, f)
}
pub async fn do_request_ratelimited<B>(
State(redis): State<Option<RedisPool>>,
request: Request<B>,
next: Next<B>,
) -> Response {
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("Fly-Client-IP"));
let (rl_key, rate) = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
{
if header == TOKEN2 {
("token2", 20)
} else {
(source_ip, 2)
}
} else {
(source_ip, 2)
};
let burst = 5;
let period = 1; // seconds
// todo: make this static
// though even if it's not static, it's probably cheaper than sending the entire script to redis every time
let scriptsha = sha1_hash(&LUA_SCRIPT);
// local rate_limit_key = KEYS[1]
// local burst = ARGV[1]
// local rate = ARGV[2]
// local period = ARGV[3]
// return {remaining, retry_after, reset_after}
let resp = redis
.evalsha::<(i32, String, u64), String, Vec<&str>, Vec<i32>>(
scriptsha,
vec![rl_key],
vec![burst, rate, period],
)
.await;
match resp {
Ok((mut remaining, retry_after, reset_after)) => {
let mut response = if remaining > 0 {
next.run(request).await
} else {
json_err(
StatusCode::TOO_MANY_REQUESTS,
format!(
// todo: the retry_after is horribly wrong
r#"{{"message":"429: too many requests","retry_after":{retry_after}}}"#
),
)
};
// the redis script puts burst in remaining for ??? some reason
remaining -= burst - rate;
let reset_time = SystemTime::now()
.checked_add(Duration::from_secs(reset_after))
.expect("invalid timestamp")
.duration_since(std::time::UNIX_EPOCH)
.expect("invalid duration")
.as_secs();
let headers = response.headers_mut();
headers.insert(
"X-RateLimit-Limit",
HeaderValue::from_str(format!("{}", rate).as_str())
.expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Remaining",
HeaderValue::from_str(format!("{}", remaining).as_str())
.expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Reset",
HeaderValue::from_str(format!("{}", reset_time).as_str())
.expect("invalid header value"),
);
return response;
}
Err(err) => {
tracing::error!("error getting ratelimit info: {}", err);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
);
}
}
}
next.run(request).await
}

42
services/api/src/util.rs Normal file
View File

@@ -0,0 +1,42 @@
use axum::{
body::Body,
http::{HeaderValue, Request, Response, StatusCode, Uri},
response::IntoResponse,
};
use tracing::error;
pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
if let Some(value) = header {
match value.to_str() {
Ok(v) => v,
Err(err) => {
error!("failed to parse header value {:#?}: {:#?}", value, err);
"failed to parse"
}
}
} else {
"unknown"
}
}
pub async fn rproxy(req: Request<Body>) -> Response<Body> {
let uri = Uri::from_static(&libpk::config.api.remote_url).to_string();
match hyper_reverse_proxy::call("0.0.0.0".parse().unwrap(), &uri[..uri.len() - 1], req).await {
Ok(response) => response,
Err(error) => {
error!("error proxying request: {:?}", error);
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
}
}
}
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
let mut response = (code, text).into_response();
let headers = response.headers_mut();
headers.insert("content-type", HeaderValue::from_static("application/json"));
response
}