feat(api): init rust rewrite

This commit is contained in:
spiral 2023-02-15 19:27:36 -05:00
parent 5da3c84bce
commit 5440386969
24 changed files with 2443 additions and 586 deletions

View File

@ -3,16 +3,17 @@
# Include project code and build files
!PluralKit.*/
!gateway/
!myriad_rs/
!services/
!lib/
!Myriad/
!PluralKit.sln
!nuget.config
!.git
!proto
!scripts/run-clustered.sh
!dashboard
!scheduled_tasks
!Cargo.toml
!Cargo.lock
!PluralKit.sln
!nuget.config
# Re-exclude host build artifact directories
**/bin

36
.github/workflows/api.yml vendored Normal file
View File

@ -0,0 +1,36 @@
name: Build and push API Docker image
on:
push:
branches:
- main
- 'rust-api'
paths:
- 'lib/pklib/**'
- 'services/api/**'
jobs:
deploy:
runs-on: ubuntu-latest
permissions:
packages: write
if: github.repository == 'PluralKit/PluralKit'
steps:
- uses: docker/login-action@v1
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.CR_PAT }}
- uses: actions/checkout@v2
- run: echo "BRANCH_NAME=${GITHUB_REF#refs/heads/}" >> $GITHUB_ENV
- uses: docker/build-push-action@v2
with:
# https://github.com/docker/build-push-action/issues/378
context: .
file: services/api/Dockerfile
push: true
tags: |
ghcr.io/pluralkit/api:${{ env.BRANCH_NAME }}
ghcr.io/pluralkit/api:${{ github.sha }}
ghcr.io/pluralkit/api:latest
cache-from: type=registry,ref=ghcr.io/pluralkit/pluralkit:${{ env.BRANCH_NAME }}
cache-to: type=inline

1741
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

7
Cargo.toml Normal file
View File

@ -0,0 +1,7 @@
[workspace]
members = [
"./lib/libpk",
"./services/api"
]
# todo: add workspace dependencies here

View File

@ -3,9 +3,11 @@ package main
import (
"embed"
"encoding/json"
"errors"
"fmt"
"html"
"io"
_fs "io/fs"
"net/http"
"strings"
@ -78,12 +80,21 @@ func notFoundHandler(rw http.ResponseWriter, r *http.Request) {
data = []byte(strings.Replace(string(data), `<!-- extra data -->`, defaultEmbed+versionJS, 1))
}
if err != nil {
if errors.Is(err, _fs.ErrNotExist) {
rw.WriteHeader(http.StatusNotFound)
} else if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
} else {
// cache built+hashed dashboard js/css files forever
is_dash_static_asset := strings.HasPrefix(r.URL.Path, "/assets/") &&
(strings.HasSuffix(r.URL.Path, ".js") || strings.HasSuffix(r.URL.Path, ".css") || strings.HasSuffix(r.URL.Path, ".map"))
if is_dash_static_asset {
rw.Header().Add("Cache-Control", "max-age=31536000, s-maxage=31536000, immutable")
}
rw.Write(data)
}
}
// explanation for createEmbed:

17
lib/libpk/Cargo.toml Normal file
View File

@ -0,0 +1,17 @@
[package]
name = "libpk"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.69"
config = "0.13.3"
gethostname = "0.4.1"
lazy_static = "1.4.0"
serde = "1.0.152"
tokio = { version = "1.25.0", features = ["full"] }
tracing = "0.1.37"
tracing-gelf = "0.7.1"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }

50
lib/libpk/src/_config.rs Normal file
View File

@ -0,0 +1,50 @@
use config::Config;
use lazy_static::lazy_static;
use serde::Deserialize;
use std::sync::Arc;
#[derive(Deserialize, Debug)]
pub struct DiscordConfig {
pub client_id: u32,
pub bot_token: String,
pub client_secret: String,
}
#[derive(Deserialize, Debug)]
pub struct DatabaseConfig {
pub(crate) _data_db_uri: String,
pub(crate) _messages_db_uri: String,
pub(crate) _db_password: Option<String>,
pub data_redis_addr: String,
}
fn _default_api_addr() -> String {
"0.0.0.0:5000".to_string()
}
#[derive(Deserialize, Debug)]
pub struct ApiConfig {
#[serde(default = "_default_api_addr")]
pub addr: String,
#[serde(default)]
pub ratelimit_redis_addr: Option<String>,
pub remote_url: String,
}
#[derive(Deserialize, Debug)]
pub struct PKConfig {
pub discord: DiscordConfig,
pub api: ApiConfig,
pub(crate) gelf_log_url: Option<String>,
}
lazy_static! {
#[derive(Debug)]
pub static ref CONFIG: Arc<PKConfig> = Arc::new(Config::builder()
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))
.build().unwrap()
.try_deserialize::<PKConfig>().unwrap());
}

27
lib/libpk/src/lib.rs Normal file
View File

@ -0,0 +1,27 @@
use gethostname::gethostname;
use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, EnvFilter, Registry};
mod _config;
pub use crate::_config::CONFIG as config;
pub fn init_logging(component: &str) -> anyhow::Result<()> {
let subscriber = Registry::default()
.with(EnvFilter::from_default_env())
.with(tracing_subscriber::fmt::layer());
if let Some(gelf_url) = &config.gelf_log_url {
let gelf_logger = tracing_gelf::Logger::builder()
.additional_field("component", component)
.additional_field("hostname", gethostname().to_str());
let mut conn_handle = gelf_logger
.init_udp_with_subscriber(gelf_url, subscriber)
.unwrap();
tokio::spawn(async move { conn_handle.connect().await });
} else {
// gelf_logger internally sets the global subscriber
tracing::subscriber::set_global_default(subscriber)
.expect("unable to set global subscriber");
}
Ok(())
}

18
services/api/Cargo.toml Normal file
View File

@ -0,0 +1,18 @@
[package]
name = "api"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.69"
axum = "0.6.4"
fred = { version = "5.2.0", default-features = false, features = ["tracing", "pool-prefer-active"] }
http = "0.2.8"
hyper-reverse-proxy = "0.5.1"
lazy_static = "1.4.0"
libpk = { path = "../../lib/libpk" }
tokio = { version = "1.25.0", features = ["full"] }
tower = "0.4.13"
tracing = "0.1.37"

27
services/api/Dockerfile Normal file
View File

@ -0,0 +1,27 @@
FROM alpine:latest AS builder
WORKDIR /build
RUN apk add rustup build-base
# todo: arm64 target
RUN rustup-init --default-host x86_64-unknown-linux-musl --default-toolchain stable --profile default -y
COPY Cargo.toml /build/
COPY Cargo.lock /build/
# todo: fetch dependencies first to cache
# RUN cargo fetch
COPY lib/libpk /build/lib/libpk
COPY services/api/ /build/services/api
RUN source "$HOME/.cargo/env" && RUSTFLAGS='-C link-arg=-s' cargo build --bin api --release --target x86_64-unknown-linux-musl
RUN ls /build/target
RUN ls /build/target/release
FROM alpine:latest
COPY --from=builder /build/target/x86_64-unknown-linux-musl/release/api /bin/api
ENTRYPOINT [ "/bin/api" ]

85
services/api/src/main.rs Normal file
View File

@ -0,0 +1,85 @@
use axum::{
routing::{delete, get, patch, post},
Router,
};
use tracing::info;
mod middleware;
mod util;
// this function is manually formatted for easier legibility of routes
#[rustfmt::skip]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
libpk::init_logging("api")?;
info!("hello world");
// processed upside down (???) so we have to put middleware at the end
let app = Router::new()
.route("/v2/systems/:system_id", get(util::rproxy))
.route("/v2/systems/:system_id", patch(util::rproxy))
.route("/v2/systems/:system_id/settings", get(util::rproxy))
.route("/v2/systems/:system_id/settings", patch(util::rproxy))
.route("/v2/systems/:system_id/members", get(util::rproxy))
.route("/v2/members", post(util::rproxy))
.route("/v2/members/:member_id", get(util::rproxy))
.route("/v2/members/:member_id", patch(util::rproxy))
.route("/v2/members/:member_id", delete(util::rproxy))
.route("/v2/systems/:system_id/groups", get(util::rproxy))
.route("/v2/groups", post(util::rproxy))
.route("/v2/groups/:group_id", get(util::rproxy))
.route("/v2/groups/:group_id", patch(util::rproxy))
.route("/v2/groups/:group_id", delete(util::rproxy))
.route("/v2/groups/:group_id/members", get(util::rproxy))
.route("/v2/groups/:group_id/members/add", post(util::rproxy))
.route("/v2/groups/:group_id/members/remove", post(util::rproxy))
.route("/v2/groups/:group_id/members/overwrite", post(util::rproxy))
.route("/v2/members/:member_id/groups", get(util::rproxy))
.route("/v2/members/:member_id/groups/add", post(util::rproxy))
.route("/v2/members/:member_id/groups/remove", post(util::rproxy))
.route("/v2/members/:member_id/groups/overwrite", post(util::rproxy))
.route("/v2/systems/:system_id/switches", get(util::rproxy))
.route("/v2/systems/:system_id/switches", post(util::rproxy))
.route("/v2/systems/:system_id/fronters", get(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", get(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", patch(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id/members", patch(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", delete(util::rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", get(util::rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", patch(util::rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", get(util::rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", patch(util::rproxy))
.route("/v2/messages/:message_id", get(util::rproxy))
.route("/private/meta", get(util::rproxy))
.route("/private/bulk_privacy/member", post(util::rproxy))
.route("/private/bulk_privacy/group", post(util::rproxy))
.route("/private/discord/callback", post(util::rproxy))
.route("/v2/systems/:system_id/oembed.json", get(util::rproxy))
.route("/v2/members/:member_id/oembed.json", get(util::rproxy))
.route("/v2/groups/:group_id/oembed.json", get(util::rproxy))
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
.layer(axum::middleware::from_fn(middleware::logger))
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
.layer(axum::middleware::from_fn(middleware::cors))
.route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }));
let addr: &str = libpk::config.api.addr.as_ref();
axum::Server::bind(&addr.parse()?)
.serve(app.into_make_service())
.await?;
Ok(())
}

View File

@ -0,0 +1,26 @@
use axum::{
http::{HeaderMap, HeaderValue, Method, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
#[rustfmt::skip]
fn add_cors_headers(headers: &mut HeaderMap) {
headers.append("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
headers.append("Access-Control-Allow-Methods", HeaderValue::from_static("*"));
headers.append("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
headers.append("Access-Control-Allow-Headers", HeaderValue::from_static("Content-Type, Authorization, sentry-trace, User-Agent"));
headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
}
pub async fn cors<B>(request: Request<B>, next: Next<B>) -> Response {
let mut response = if request.method() == Method::OPTIONS {
StatusCode::OK.into_response()
} else {
next.run(request).await
};
add_cors_headers(response.headers_mut());
response
}

View File

@ -0,0 +1,61 @@
use axum::{
extract::MatchedPath,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use crate::util::header_or_unknown;
fn is_trying_to_use_v1_path_on_v2(path: &str) -> bool {
path.starts_with("/v2/s/")
|| path.starts_with("/v2/m/")
|| path.starts_with("/v2/a/")
|| path.starts_with("/v2/msg/")
|| path == "/v2/s"
|| path == "/v2/m"
}
pub async fn ignore_invalid_routes<B>(request: Request<B>, next: Next<B>) -> Response {
let path = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
if request.uri().path().starts_with("/v1") {
(
StatusCode::GONE,
r#"{"message":"Unsupported API version","code":0}"#,
)
.into_response()
} else if is_trying_to_use_v1_path_on_v2(request.uri().path()) {
(
StatusCode::BAD_REQUEST,
r#"{"message":"Invalid path for API version","code":0}"#,
)
.into_response()
}
// we ignored v1 routes earlier, now let's ignore all non-v2 routes
else if !request.uri().clone().path().starts_with("/v2") {
return (
StatusCode::BAD_REQUEST,
r#"{"message":"Unsupported API version","code":0}"#,
)
.into_response();
} else if path == "unknown" {
// current prod api responds with 404 with empty body to invalid endpoints
// just doing that here as well but i'm not sure if it's the correct behaviour
return StatusCode::NOT_FOUND.into_response();
}
// yes, technically because of how we parse headers this will break for user-agents literally set to "unknown"
// but "unknown" isn't really a valid user-agent
else if user_agent == "unknown" {
// please set a valid user-agent
return StatusCode::FORBIDDEN.into_response();
} else {
next.run(request).await
}
}

View File

@ -0,0 +1,41 @@
use std::time::Instant;
use axum::{extract::MatchedPath, http::Request, middleware::Next, response::Response};
use tracing::{info, span, Instrument, Level};
use crate::util::header_or_unknown;
pub async fn logger<B>(request: Request<B>, next: Next<B>) -> Response {
let method = request.method().clone();
let request_id = header_or_unknown(request.headers().get("Fly-Request-Id"));
let remote_ip = header_or_unknown(request.headers().get("Fly-Client-IP"));
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
let path = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
// todo: prometheus metrics
let request_id_span = span!(
Level::INFO,
"request",
request_id,
remote_ip,
method = method.as_str(),
path,
user_agent
);
let start = Instant::now();
let response = next.run(request).instrument(request_id_span).await;
let elapsed = start.elapsed().as_millis();
info!("handled request for {} {} in {}ms", method, path, elapsed);
response
}

View File

@ -0,0 +1,11 @@
mod cors;
pub use cors::cors;
mod logger;
pub use logger::logger;
mod ignore_invalid_routes;
pub use ignore_invalid_routes::ignore_invalid_routes;
pub mod ratelimit;

View File

@ -0,0 +1,57 @@
-- Copyright (c) 2017 Pavel Pravosud
-- https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua
local rate_limit_key = KEYS[1]
local burst = ARGV[1]
local rate = ARGV[2]
local period = ARGV[3]
-- local cost = tonumber(ARGV[4])
local cost = 1
local emission_interval = period / rate
local increment = emission_interval * cost
local burst_offset = emission_interval * burst
-- redis returns time as an array containing two integers: seconds of the epoch
-- time (10 digits) and microseconds (6 digits). for convenience we need to
-- convert them to a floating point number. the resulting number is 16 digits,
-- bordering on the limits of a 64-bit double-precision floating point number.
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
local jan_1_2017 = 1483228800
local now = redis.call("TIME")
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
local tat = redis.call("GET", rate_limit_key)
if not tat then
tat = now
else
tat = tonumber(tat)
end
tat = math.max(tat, now)
local new_tat = tat + increment
local allow_at = new_tat - burst_offset
local diff = now - allow_at
local remaining = diff / emission_interval
if remaining < 0 then
local reset_after = tat - now
local retry_after = diff * -1
return {
0, -- remaining
retry_after,
reset_after,
}
end
local reset_after = new_tat - now
if reset_after > 0 then
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
end
local retry_after = -1
return {remaining, retry_after, reset_after}

View File

@ -0,0 +1,154 @@
use std::time::{Duration, SystemTime};
use axum::{
extract::State,
http::Request,
middleware::{FromFnLayer, Next},
response::Response,
};
use fred::{pool::RedisPool, prelude::LuaInterface, types::ReconnectPolicy, util::sha1_hash};
use http::{HeaderValue, StatusCode};
use tracing::{error, info, warn};
use crate::util::{header_or_unknown, json_err};
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
lazy_static::lazy_static! {
static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT);
}
// todo lol
const TOKEN2: &'static str = "h";
// this is awful but it works
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
let redis = libpk::config.api.ratelimit_redis_addr.as_ref().map(|val| {
let r = fred::pool::RedisPool::new(
fred::types::RedisConfig::from_url_centralized(val.as_ref())
.expect("redis url is invalid"),
10,
)
.expect("failed to connect to redis");
let handle = r.connect(Some(ReconnectPolicy::default()));
tokio::spawn(async move { handle });
let rscript = r.clone();
tokio::spawn(async move {
if let Ok(()) = rscript.wait_for_connect().await {
match rscript.script_load(LUA_SCRIPT).await {
Ok(_) => info!("connected to redis for request rate limiting"),
Err(err) => error!("could not load redis script: {}", err),
}
} else {
error!("could not wait for connection to load redis script!");
}
});
r
});
if redis.is_none() {
warn!("running without request rate limiting!");
}
axum::middleware::from_fn_with_state(redis, f)
}
pub async fn do_request_ratelimited<B>(
State(redis): State<Option<RedisPool>>,
request: Request<B>,
next: Next<B>,
) -> Response {
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("Fly-Client-IP"));
let (rl_key, rate) = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
{
if header == TOKEN2 {
("token2", 20)
} else {
(source_ip, 2)
}
} else {
(source_ip, 2)
};
let burst = 5;
let period = 1; // seconds
// todo: make this static
// though even if it's not static, it's probably cheaper than sending the entire script to redis every time
let scriptsha = sha1_hash(&LUA_SCRIPT);
// local rate_limit_key = KEYS[1]
// local burst = ARGV[1]
// local rate = ARGV[2]
// local period = ARGV[3]
// return {remaining, retry_after, reset_after}
let resp = redis
.evalsha::<(i32, String, u64), String, Vec<&str>, Vec<i32>>(
scriptsha,
vec![rl_key],
vec![burst, rate, period],
)
.await;
match resp {
Ok((mut remaining, retry_after, reset_after)) => {
let mut response = if remaining > 0 {
next.run(request).await
} else {
json_err(
StatusCode::TOO_MANY_REQUESTS,
format!(
// todo: the retry_after is horribly wrong
r#"{{"message":"429: too many requests","retry_after":{retry_after}}}"#
),
)
};
// the redis script puts burst in remaining for ??? some reason
remaining -= burst - rate;
let reset_time = SystemTime::now()
.checked_add(Duration::from_secs(reset_after))
.expect("invalid timestamp")
.duration_since(std::time::UNIX_EPOCH)
.expect("invalid duration")
.as_secs();
let headers = response.headers_mut();
headers.insert(
"X-RateLimit-Limit",
HeaderValue::from_str(format!("{}", rate).as_str())
.expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Remaining",
HeaderValue::from_str(format!("{}", remaining).as_str())
.expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Reset",
HeaderValue::from_str(format!("{}", reset_time).as_str())
.expect("invalid header value"),
);
return response;
}
Err(err) => {
tracing::error!("error getting ratelimit info: {}", err);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
);
}
}
}
next.run(request).await
}

42
services/api/src/util.rs Normal file
View File

@ -0,0 +1,42 @@
use axum::{
body::Body,
http::{HeaderValue, Request, Response, StatusCode, Uri},
response::IntoResponse,
};
use tracing::error;
pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
if let Some(value) = header {
match value.to_str() {
Ok(v) => v,
Err(err) => {
error!("failed to parse header value {:#?}: {:#?}", value, err);
"failed to parse"
}
}
} else {
"unknown"
}
}
pub async fn rproxy(req: Request<Body>) -> Response<Body> {
let uri = Uri::from_static(&libpk::config.api.remote_url).to_string();
match hyper_reverse_proxy::call("0.0.0.0".parse().unwrap(), &uri[..uri.len() - 1], req).await {
Ok(response) => response,
Err(error) => {
error!("error proxying request: {:?}", error);
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
}
}
}
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
let mut response = (code, text).into_response();
let headers = response.headers_mut();
headers.insert("content-type", HeaderValue::from_static("application/json"));
response
}

View File

@ -2,25 +2,26 @@ package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"time"
"github.com/go-redis/redis/v8"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"web-proxy/redis_rate"
)
var limiter *redis_rate.Limiter
// todo: be able to raise ratelimits for >1 consumers
var token2 string
func proxyTo(host string) *httputil.ReverseProxy {
rp := httputil.NewSingleHostReverseProxy(&url.URL{
Scheme: "http",
Host: host,
RawQuery: "",
})
rp.ModifyResponse = logTimeElapsed
return rp
}
// todo: this shouldn't be in this repo
var remotes = map[string]*httputil.ReverseProxy{
@ -29,22 +30,6 @@ var remotes = map[string]*httputil.ReverseProxy{
"sentry.pluralkit.me": proxyTo("[fdaa:0:ae33:a7b:8dd7:0:a:202]:9000"),
}
func init() {
redisHost := requireEnv("REDIS_HOST")
redisPassword := requireEnv("REDIS_PASSWORD")
rdb := redis.NewClient(&redis.Options{
Addr: redisHost,
Username: "default",
Password: redisPassword,
})
limiter = redis_rate.NewLimiter(rdb)
token2 = requireEnv("TOKEN2")
remotes["dash.pluralkit.me"].ModifyResponse = modifyDashResponse
}
type ProxyHandler struct{}
func (p ProxyHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
@ -61,48 +46,6 @@ func (p ProxyHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
return
}
if r.Host == "api.pluralkit.me" {
// root
if r.URL.Path == "" {
// api root path redirects to docs
http.Redirect(rw, r, "https://pluralkit.me/api/", http.StatusFound)
return
}
// CORS headers
rw.Header().Add("Access-Control-Allow-Origin", "*")
rw.Header().Add("Access-Control-Allow-Methods", "*")
rw.Header().Add("Access-Control-Allow-Credentials", "true")
rw.Header().Add("Access-Control-Allow-Headers", "Content-Type, Authorization, sentry-trace, User-Agent")
rw.Header().Add("Access-Control-Max-Age", "86400")
if r.Method == http.MethodOptions {
rw.WriteHeader(200)
return
}
if r.URL.Path == "/" {
http.Redirect(rw, r, "https://pluralkit.me/api", http.StatusFound)
return
}
if strings.HasPrefix(r.URL.Path, "/v1") {
rw.Header().Set("content-type", "application/json")
rw.WriteHeader(410)
rw.Write([]byte(`{"message":"Unsupported API version","code":0}`))
}
if is_trying_to_use_v1_path_on_v2(r.URL.Path) {
rw.WriteHeader(400)
rw.Write([]byte(`{"message":"Invalid path for API version","code":0}`))
return
}
if is_api_ratelimited(rw, r) {
return
}
}
startTime := time.Now()
r = r.WithContext(context.WithValue(r.Context(), "req-time", startTime))
@ -119,40 +62,14 @@ func logTimeElapsed(resp *http.Response) error {
"domain": r.Host,
"method": r.Method,
"status": strconv.Itoa(resp.StatusCode),
"route": cleanPath(r.Host, r.URL.Path),
"route": r.URL.Path,
}).Observe(elapsed.Seconds())
log, _ := json.Marshal(map[string]interface{}{
"remote_ip": r.Header.Get("Fly-Client-IP"),
"method": r.Method,
"host": r.Host,
"route": r.URL.Path,
"route_clean": cleanPath(r.Host, r.URL.Path),
"status": resp.StatusCode,
"elapsed": elapsed.Milliseconds(),
"user_agent": r.Header.Get("User-Agent"),
})
fmt.Println(string(log))
// log.Printf("[%s] \"%s %s%s\" %d - %vms %s\n", r.Header.Get("Fly-Client-IP"), r.Method, r.Host, r.URL.Path, resp.StatusCode, elapsed.Milliseconds(), r.Header.Get("User-Agent"))
log.Printf("[%s] \"%s %s%s\" %d - %vms %s\n", r.Header.Get("Fly-Client-IP"), r.Method, r.Host, r.URL.Path, resp.StatusCode, elapsed.Milliseconds(), r.Header.Get("User-Agent"))
return nil
}
func modifyDashResponse(resp *http.Response) error {
r := resp.Request
// cache built+hashed dashboard js/css files forever
is_dash_static_asset := strings.HasPrefix(r.URL.Path, "/assets/") &&
(strings.HasSuffix(r.URL.Path, ".js") || strings.HasSuffix(r.URL.Path, ".css") || strings.HasSuffix(r.URL.Path, ".map"))
if is_dash_static_asset && resp.StatusCode == 200 {
resp.Header.Add("Cache-Control", "max-age=31536000, s-maxage=31536000, immutable")
}
return logTimeElapsed(resp)
}
func main() {
prometheus.MustRegister(metric)
@ -161,3 +78,11 @@ func main() {
http.ListenAndServe(":8080", ProxyHandler{})
}
var metric = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "pk_http_requests",
Buckets: []float64{.1, .25, 1, 2.5, 5, 20},
},
[]string{"domain", "method", "status", "route"},
)

View File

@ -1,43 +0,0 @@
package main
import (
"fmt"
"net/http"
"time"
"web-proxy/redis_rate"
)
func is_api_ratelimited(rw http.ResponseWriter, r *http.Request) bool {
var limit int
var key string
if r.Header.Get("X-PluralKit-App") == token2 {
limit = 20
key = "token2"
} else {
limit = 2
key = r.Header.Get("Fly-Client-IP")
}
res, err := limiter.Allow(r.Context(), "ratelimit:"+key, redis_rate.Limit{
Period: time.Second,
Rate: limit,
Burst: 5,
})
if err != nil {
panic(err)
}
rw.Header().Set("X-RateLimit-Limit", fmt.Sprint(limit))
rw.Header().Set("X-RateLimit-Remaining", fmt.Sprint(res.Remaining))
rw.Header().Set("X-RateLimit-Reset", fmt.Sprint(time.Now().Add(res.ResetAfter).UnixNano()/1_000_000))
if res.Allowed < 1 {
rw.WriteHeader(429)
rw.Write([]byte(`{"message":"429: too many requests","retry_after":` + fmt.Sprint(res.RetryAfter.Milliseconds()) + `,"code":0}`))
return true
}
return false
}

View File

@ -1,25 +0,0 @@
Copyright (c) 2013 The github.com/go-redis/redis_rate Authors.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,140 +0,0 @@
package redis_rate
import "github.com/go-redis/redis/v8"
// pluralkit changes:
// fly's hosted redis doesn't support replicate commands
// we can remove it since it's a single host
// Copyright (c) 2017 Pavel Pravosud
// https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua
var allowN = redis.NewScript(`
-- this script has side-effects, so it requires replicate commands mode
-- redis.replicate_commands()
local rate_limit_key = KEYS[1]
local burst = ARGV[1]
local rate = ARGV[2]
local period = ARGV[3]
local cost = tonumber(ARGV[4])
local emission_interval = period / rate
local increment = emission_interval * cost
local burst_offset = emission_interval * burst
-- redis returns time as an array containing two integers: seconds of the epoch
-- time (10 digits) and microseconds (6 digits). for convenience we need to
-- convert them to a floating point number. the resulting number is 16 digits,
-- bordering on the limits of a 64-bit double-precision floating point number.
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
local jan_1_2017 = 1483228800
local now = redis.call("TIME")
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
local tat = redis.call("GET", rate_limit_key)
if not tat then
tat = now
else
tat = tonumber(tat)
end
tat = math.max(tat, now)
local new_tat = tat + increment
local allow_at = new_tat - burst_offset
local diff = now - allow_at
local remaining = diff / emission_interval
if remaining < 0 then
local reset_after = tat - now
local retry_after = diff * -1
return {
0, -- allowed
0, -- remaining
tostring(retry_after),
tostring(reset_after),
}
end
local reset_after = new_tat - now
if reset_after > 0 then
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
end
local retry_after = -1
return {cost, remaining, tostring(retry_after), tostring(reset_after)}
`)
var allowAtMost = redis.NewScript(`
-- this script has side-effects, so it requires replicate commands mode
-- redis.replicate_commands()
local rate_limit_key = KEYS[1]
local burst = ARGV[1]
local rate = ARGV[2]
local period = ARGV[3]
local cost = tonumber(ARGV[4])
local emission_interval = period / rate
local burst_offset = emission_interval * burst
-- redis returns time as an array containing two integers: seconds of the epoch
-- time (10 digits) and microseconds (6 digits). for convenience we need to
-- convert them to a floating point number. the resulting number is 16 digits,
-- bordering on the limits of a 64-bit double-precision floating point number.
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
local jan_1_2017 = 1483228800
local now = redis.call("TIME")
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
local tat = redis.call("GET", rate_limit_key)
if not tat then
tat = now
else
tat = tonumber(tat)
end
tat = math.max(tat, now)
local diff = now - (tat - burst_offset)
local remaining = diff / emission_interval
if remaining < 1 then
local reset_after = tat - now
local retry_after = emission_interval - diff
return {
0, -- allowed
0, -- remaining
tostring(retry_after),
tostring(reset_after),
}
end
if remaining < cost then
cost = remaining
remaining = 0
else
remaining = remaining - cost
end
local increment = emission_interval * cost
local new_tat = tat + increment
local reset_after = new_tat - now
if reset_after > 0 then
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
end
return {
cost,
remaining,
tostring(-1),
tostring(reset_after),
}
`)

View File

@ -1,198 +0,0 @@
package redis_rate
import (
"context"
"fmt"
"strconv"
"time"
"github.com/go-redis/redis/v8"
)
const redisPrefix = "rate:"
type rediser interface {
Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd
ScriptLoad(ctx context.Context, script string) *redis.StringCmd
Del(ctx context.Context, keys ...string) *redis.IntCmd
}
type Limit struct {
Rate int
Burst int
Period time.Duration
}
func (l Limit) String() string {
return fmt.Sprintf("%d req/%s (burst %d)", l.Rate, fmtDur(l.Period), l.Burst)
}
func (l Limit) IsZero() bool {
return l == Limit{}
}
func fmtDur(d time.Duration) string {
switch d {
case time.Second:
return "s"
case time.Minute:
return "m"
case time.Hour:
return "h"
}
return d.String()
}
func PerSecond(rate int) Limit {
return Limit{
Rate: rate,
Period: time.Second,
Burst: rate,
}
}
func PerMinute(rate int) Limit {
return Limit{
Rate: rate,
Period: time.Minute,
Burst: rate,
}
}
func PerHour(rate int) Limit {
return Limit{
Rate: rate,
Period: time.Hour,
Burst: rate,
}
}
//------------------------------------------------------------------------------
// Limiter controls how frequently events are allowed to happen.
type Limiter struct {
rdb rediser
}
// NewLimiter returns a new Limiter.
func NewLimiter(rdb rediser) *Limiter {
return &Limiter{
rdb: rdb,
}
}
// Allow is a shortcut for AllowN(ctx, key, limit, 1).
func (l Limiter) Allow(ctx context.Context, key string, limit Limit) (*Result, error) {
return l.AllowN(ctx, key, limit, 1)
}
// AllowN reports whether n events may happen at time now.
func (l Limiter) AllowN(
ctx context.Context,
key string,
limit Limit,
n int,
) (*Result, error) {
values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
v, err := allowN.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
if err != nil {
return nil, err
}
values = v.([]interface{})
retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
if err != nil {
return nil, err
}
resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
if err != nil {
return nil, err
}
res := &Result{
Limit: limit,
Allowed: int(values[0].(int64)),
Remaining: int(values[1].(int64)),
RetryAfter: dur(retryAfter),
ResetAfter: dur(resetAfter),
}
return res, nil
}
// AllowAtMost reports whether at most n events may happen at time now.
// It returns number of allowed events that is less than or equal to n.
func (l Limiter) AllowAtMost(
ctx context.Context,
key string,
limit Limit,
n int,
) (*Result, error) {
values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
v, err := allowAtMost.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
if err != nil {
return nil, err
}
values = v.([]interface{})
retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
if err != nil {
return nil, err
}
resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
if err != nil {
return nil, err
}
res := &Result{
Limit: limit,
Allowed: int(values[0].(int64)),
Remaining: int(values[1].(int64)),
RetryAfter: dur(retryAfter),
ResetAfter: dur(resetAfter),
}
return res, nil
}
// Reset gets a key and reset all limitations and previous usages
func (l *Limiter) Reset(ctx context.Context, key string) error {
return l.rdb.Del(ctx, redisPrefix+key).Err()
}
func dur(f float64) time.Duration {
if f == -1 {
return -1
}
return time.Duration(f * float64(time.Second))
}
type Result struct {
// Limit is the limit that was used to obtain this result.
Limit Limit
// Allowed is the number of events that may happen at time now.
Allowed int
// Remaining is the maximum number of requests that could be
// permitted instantaneously for this key given the current
// state. For example, if a rate limiter allows 10 requests per
// second and has already received 6 requests for this key this
// second, Remaining would be 4.
Remaining int
// RetryAfter is the time until the next request will be permitted.
// It should be -1 unless the rate limit has been exceeded.
RetryAfter time.Duration
// ResetAfter is the time until the RateLimiter returns to its
// initial state for a given key. For example, if a rate limiter
// manages requests per second and received one request 200ms ago,
// Reset would return 800ms. You can also think of this as the time
// until Limit and Remaining will be equal.
ResetAfter time.Duration
}

View File

@ -1,74 +0,0 @@
package main
import (
"net/http/httputil"
"net/url"
"os"
"regexp"
"strings"
"github.com/prometheus/client_golang/prometheus"
)
var metric = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "pk_http_requests",
Buckets: []float64{.1, .25, 1, 2.5, 5, 20},
},
[]string{"domain", "method", "status", "route"},
)
func proxyTo(host string) *httputil.ReverseProxy {
rp := httputil.NewSingleHostReverseProxy(&url.URL{
Scheme: "http",
Host: host,
RawQuery: "",
})
rp.ModifyResponse = logTimeElapsed
return rp
}
var systemsRegex = regexp.MustCompile("systems/[^/{}]+")
var membersRegex = regexp.MustCompile("members/[^/{}]+")
var groupsRegex = regexp.MustCompile("groups/[^/{}]+")
var switchesRegex = regexp.MustCompile("switches/[^/{}]+")
var guildsRegex = regexp.MustCompile("guilds/[^/{}]+")
var messagesRegex = regexp.MustCompile("messages/[^/{}]+")
func cleanPath(host, path string) string {
if host != "api.pluralkit.me" {
return ""
}
path = strings.ToLower(path)
if !(strings.HasPrefix(path, "/v2") || strings.HasPrefix(path, "/private")) {
return ""
}
path = systemsRegex.ReplaceAllString(path, "systems/{systemRef}")
path = membersRegex.ReplaceAllString(path, "members/{memberRef}")
path = groupsRegex.ReplaceAllString(path, "groups/{groupRef}")
path = switchesRegex.ReplaceAllString(path, "switches/{switchRef}")
path = guildsRegex.ReplaceAllString(path, "guilds/{guild_id}")
path = messagesRegex.ReplaceAllString(path, "messages/{message_id}")
return path
}
func requireEnv(key string) string {
if val, ok := os.LookupEnv(key); !ok {
panic("missing `" + key + "` in environment")
} else {
return val
}
}
func is_trying_to_use_v1_path_on_v2(path string) bool {
return strings.HasPrefix(path, "/v2/s/") ||
strings.HasPrefix(path, "/v2/m/") ||
strings.HasPrefix(path, "/v2/a/") ||
strings.HasPrefix(path, "/v2/msg/") ||
path == "/v2/s" ||
path == "/v2/m"
}