feat(api): init rust rewrite
This commit is contained in:
parent
5da3c84bce
commit
5440386969
@ -3,16 +3,17 @@
|
||||
|
||||
# Include project code and build files
|
||||
!PluralKit.*/
|
||||
!gateway/
|
||||
!myriad_rs/
|
||||
!services/
|
||||
!lib/
|
||||
!Myriad/
|
||||
!PluralKit.sln
|
||||
!nuget.config
|
||||
!.git
|
||||
!proto
|
||||
!scripts/run-clustered.sh
|
||||
!dashboard
|
||||
!scheduled_tasks
|
||||
|
||||
!Cargo.toml
|
||||
!Cargo.lock
|
||||
!PluralKit.sln
|
||||
!nuget.config
|
||||
|
||||
# Re-exclude host build artifact directories
|
||||
**/bin
|
||||
|
36
.github/workflows/api.yml
vendored
Normal file
36
.github/workflows/api.yml
vendored
Normal file
@ -0,0 +1,36 @@
|
||||
name: Build and push API Docker image
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- 'rust-api'
|
||||
paths:
|
||||
- 'lib/pklib/**'
|
||||
- 'services/api/**'
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
packages: write
|
||||
if: github.repository == 'PluralKit/PluralKit'
|
||||
steps:
|
||||
- uses: docker/login-action@v1
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.CR_PAT }}
|
||||
- uses: actions/checkout@v2
|
||||
- run: echo "BRANCH_NAME=${GITHUB_REF#refs/heads/}" >> $GITHUB_ENV
|
||||
- uses: docker/build-push-action@v2
|
||||
with:
|
||||
# https://github.com/docker/build-push-action/issues/378
|
||||
context: .
|
||||
file: services/api/Dockerfile
|
||||
push: true
|
||||
tags: |
|
||||
ghcr.io/pluralkit/api:${{ env.BRANCH_NAME }}
|
||||
ghcr.io/pluralkit/api:${{ github.sha }}
|
||||
ghcr.io/pluralkit/api:latest
|
||||
cache-from: type=registry,ref=ghcr.io/pluralkit/pluralkit:${{ env.BRANCH_NAME }}
|
||||
cache-to: type=inline
|
1741
Cargo.lock
generated
Normal file
1741
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
7
Cargo.toml
Normal file
7
Cargo.toml
Normal file
@ -0,0 +1,7 @@
|
||||
[workspace]
|
||||
members = [
|
||||
"./lib/libpk",
|
||||
"./services/api"
|
||||
]
|
||||
|
||||
# todo: add workspace dependencies here
|
@ -3,9 +3,11 @@ package main
|
||||
import (
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
_fs "io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@ -78,12 +80,21 @@ func notFoundHandler(rw http.ResponseWriter, r *http.Request) {
|
||||
data = []byte(strings.Replace(string(data), `<!-- extra data -->`, defaultEmbed+versionJS, 1))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, _fs.ErrNotExist) {
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
} else if err != nil {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
} else {
|
||||
// cache built+hashed dashboard js/css files forever
|
||||
is_dash_static_asset := strings.HasPrefix(r.URL.Path, "/assets/") &&
|
||||
(strings.HasSuffix(r.URL.Path, ".js") || strings.HasSuffix(r.URL.Path, ".css") || strings.HasSuffix(r.URL.Path, ".map"))
|
||||
|
||||
if is_dash_static_asset {
|
||||
rw.Header().Add("Cache-Control", "max-age=31536000, s-maxage=31536000, immutable")
|
||||
}
|
||||
|
||||
rw.Write(data)
|
||||
}
|
||||
}
|
||||
|
||||
// explanation for createEmbed:
|
||||
|
17
lib/libpk/Cargo.toml
Normal file
17
lib/libpk/Cargo.toml
Normal file
@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "libpk"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.69"
|
||||
config = "0.13.3"
|
||||
gethostname = "0.4.1"
|
||||
lazy_static = "1.4.0"
|
||||
serde = "1.0.152"
|
||||
tokio = { version = "1.25.0", features = ["full"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-gelf = "0.7.1"
|
||||
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
|
50
lib/libpk/src/_config.rs
Normal file
50
lib/libpk/src/_config.rs
Normal file
@ -0,0 +1,50 @@
|
||||
use config::Config;
|
||||
use lazy_static::lazy_static;
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct DiscordConfig {
|
||||
pub client_id: u32,
|
||||
pub bot_token: String,
|
||||
pub client_secret: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct DatabaseConfig {
|
||||
pub(crate) _data_db_uri: String,
|
||||
pub(crate) _messages_db_uri: String,
|
||||
pub(crate) _db_password: Option<String>,
|
||||
pub data_redis_addr: String,
|
||||
}
|
||||
|
||||
fn _default_api_addr() -> String {
|
||||
"0.0.0.0:5000".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ApiConfig {
|
||||
#[serde(default = "_default_api_addr")]
|
||||
pub addr: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub ratelimit_redis_addr: Option<String>,
|
||||
|
||||
pub remote_url: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct PKConfig {
|
||||
pub discord: DiscordConfig,
|
||||
pub api: ApiConfig,
|
||||
|
||||
pub(crate) gelf_log_url: Option<String>,
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
#[derive(Debug)]
|
||||
pub static ref CONFIG: Arc<PKConfig> = Arc::new(Config::builder()
|
||||
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))
|
||||
.build().unwrap()
|
||||
.try_deserialize::<PKConfig>().unwrap());
|
||||
}
|
27
lib/libpk/src/lib.rs
Normal file
27
lib/libpk/src/lib.rs
Normal file
@ -0,0 +1,27 @@
|
||||
use gethostname::gethostname;
|
||||
use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, EnvFilter, Registry};
|
||||
|
||||
mod _config;
|
||||
pub use crate::_config::CONFIG as config;
|
||||
|
||||
pub fn init_logging(component: &str) -> anyhow::Result<()> {
|
||||
let subscriber = Registry::default()
|
||||
.with(EnvFilter::from_default_env())
|
||||
.with(tracing_subscriber::fmt::layer());
|
||||
|
||||
if let Some(gelf_url) = &config.gelf_log_url {
|
||||
let gelf_logger = tracing_gelf::Logger::builder()
|
||||
.additional_field("component", component)
|
||||
.additional_field("hostname", gethostname().to_str());
|
||||
let mut conn_handle = gelf_logger
|
||||
.init_udp_with_subscriber(gelf_url, subscriber)
|
||||
.unwrap();
|
||||
tokio::spawn(async move { conn_handle.connect().await });
|
||||
} else {
|
||||
// gelf_logger internally sets the global subscriber
|
||||
tracing::subscriber::set_global_default(subscriber)
|
||||
.expect("unable to set global subscriber");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
18
services/api/Cargo.toml
Normal file
18
services/api/Cargo.toml
Normal file
@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "api"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.69"
|
||||
axum = "0.6.4"
|
||||
fred = { version = "5.2.0", default-features = false, features = ["tracing", "pool-prefer-active"] }
|
||||
http = "0.2.8"
|
||||
hyper-reverse-proxy = "0.5.1"
|
||||
lazy_static = "1.4.0"
|
||||
libpk = { path = "../../lib/libpk" }
|
||||
tokio = { version = "1.25.0", features = ["full"] }
|
||||
tower = "0.4.13"
|
||||
tracing = "0.1.37"
|
27
services/api/Dockerfile
Normal file
27
services/api/Dockerfile
Normal file
@ -0,0 +1,27 @@
|
||||
FROM alpine:latest AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
RUN apk add rustup build-base
|
||||
# todo: arm64 target
|
||||
RUN rustup-init --default-host x86_64-unknown-linux-musl --default-toolchain stable --profile default -y
|
||||
|
||||
COPY Cargo.toml /build/
|
||||
COPY Cargo.lock /build/
|
||||
|
||||
# todo: fetch dependencies first to cache
|
||||
# RUN cargo fetch
|
||||
|
||||
COPY lib/libpk /build/lib/libpk
|
||||
COPY services/api/ /build/services/api
|
||||
|
||||
RUN source "$HOME/.cargo/env" && RUSTFLAGS='-C link-arg=-s' cargo build --bin api --release --target x86_64-unknown-linux-musl
|
||||
|
||||
RUN ls /build/target
|
||||
RUN ls /build/target/release
|
||||
|
||||
FROM alpine:latest
|
||||
|
||||
COPY --from=builder /build/target/x86_64-unknown-linux-musl/release/api /bin/api
|
||||
|
||||
ENTRYPOINT [ "/bin/api" ]
|
85
services/api/src/main.rs
Normal file
85
services/api/src/main.rs
Normal file
@ -0,0 +1,85 @@
|
||||
use axum::{
|
||||
routing::{delete, get, patch, post},
|
||||
Router,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
mod middleware;
|
||||
mod util;
|
||||
|
||||
// this function is manually formatted for easier legibility of routes
|
||||
#[rustfmt::skip]
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
libpk::init_logging("api")?;
|
||||
info!("hello world");
|
||||
|
||||
// processed upside down (???) so we have to put middleware at the end
|
||||
let app = Router::new()
|
||||
.route("/v2/systems/:system_id", get(util::rproxy))
|
||||
.route("/v2/systems/:system_id", patch(util::rproxy))
|
||||
.route("/v2/systems/:system_id/settings", get(util::rproxy))
|
||||
.route("/v2/systems/:system_id/settings", patch(util::rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/members", get(util::rproxy))
|
||||
.route("/v2/members", post(util::rproxy))
|
||||
.route("/v2/members/:member_id", get(util::rproxy))
|
||||
.route("/v2/members/:member_id", patch(util::rproxy))
|
||||
.route("/v2/members/:member_id", delete(util::rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/groups", get(util::rproxy))
|
||||
.route("/v2/groups", post(util::rproxy))
|
||||
.route("/v2/groups/:group_id", get(util::rproxy))
|
||||
.route("/v2/groups/:group_id", patch(util::rproxy))
|
||||
.route("/v2/groups/:group_id", delete(util::rproxy))
|
||||
|
||||
.route("/v2/groups/:group_id/members", get(util::rproxy))
|
||||
.route("/v2/groups/:group_id/members/add", post(util::rproxy))
|
||||
.route("/v2/groups/:group_id/members/remove", post(util::rproxy))
|
||||
.route("/v2/groups/:group_id/members/overwrite", post(util::rproxy))
|
||||
|
||||
.route("/v2/members/:member_id/groups", get(util::rproxy))
|
||||
.route("/v2/members/:member_id/groups/add", post(util::rproxy))
|
||||
.route("/v2/members/:member_id/groups/remove", post(util::rproxy))
|
||||
.route("/v2/members/:member_id/groups/overwrite", post(util::rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/switches", get(util::rproxy))
|
||||
.route("/v2/systems/:system_id/switches", post(util::rproxy))
|
||||
.route("/v2/systems/:system_id/fronters", get(util::rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", get(util::rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", patch(util::rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id/members", patch(util::rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", delete(util::rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/guilds/:guild_id", get(util::rproxy))
|
||||
.route("/v2/systems/:system_id/guilds/:guild_id", patch(util::rproxy))
|
||||
|
||||
.route("/v2/members/:member_id/guilds/:guild_id", get(util::rproxy))
|
||||
.route("/v2/members/:member_id/guilds/:guild_id", patch(util::rproxy))
|
||||
|
||||
.route("/v2/messages/:message_id", get(util::rproxy))
|
||||
|
||||
.route("/private/meta", get(util::rproxy))
|
||||
.route("/private/bulk_privacy/member", post(util::rproxy))
|
||||
.route("/private/bulk_privacy/group", post(util::rproxy))
|
||||
.route("/private/discord/callback", post(util::rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/oembed.json", get(util::rproxy))
|
||||
.route("/v2/members/:member_id/oembed.json", get(util::rproxy))
|
||||
.route("/v2/groups/:group_id/oembed.json", get(util::rproxy))
|
||||
|
||||
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
|
||||
.layer(axum::middleware::from_fn(middleware::logger))
|
||||
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
|
||||
.layer(axum::middleware::from_fn(middleware::cors))
|
||||
|
||||
.route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }));
|
||||
|
||||
let addr: &str = libpk::config.api.addr.as_ref();
|
||||
axum::Server::bind(&addr.parse()?)
|
||||
.serve(app.into_make_service())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
26
services/api/src/middleware/cors.rs
Normal file
26
services/api/src/middleware/cors.rs
Normal file
@ -0,0 +1,26 @@
|
||||
use axum::{
|
||||
http::{HeaderMap, HeaderValue, Method, Request, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
|
||||
#[rustfmt::skip]
|
||||
fn add_cors_headers(headers: &mut HeaderMap) {
|
||||
headers.append("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
|
||||
headers.append("Access-Control-Allow-Methods", HeaderValue::from_static("*"));
|
||||
headers.append("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
|
||||
headers.append("Access-Control-Allow-Headers", HeaderValue::from_static("Content-Type, Authorization, sentry-trace, User-Agent"));
|
||||
headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
|
||||
}
|
||||
|
||||
pub async fn cors<B>(request: Request<B>, next: Next<B>) -> Response {
|
||||
let mut response = if request.method() == Method::OPTIONS {
|
||||
StatusCode::OK.into_response()
|
||||
} else {
|
||||
next.run(request).await
|
||||
};
|
||||
|
||||
add_cors_headers(response.headers_mut());
|
||||
|
||||
response
|
||||
}
|
61
services/api/src/middleware/ignore_invalid_routes.rs
Normal file
61
services/api/src/middleware/ignore_invalid_routes.rs
Normal file
@ -0,0 +1,61 @@
|
||||
use axum::{
|
||||
extract::MatchedPath,
|
||||
http::{Request, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
|
||||
use crate::util::header_or_unknown;
|
||||
|
||||
fn is_trying_to_use_v1_path_on_v2(path: &str) -> bool {
|
||||
path.starts_with("/v2/s/")
|
||||
|| path.starts_with("/v2/m/")
|
||||
|| path.starts_with("/v2/a/")
|
||||
|| path.starts_with("/v2/msg/")
|
||||
|| path == "/v2/s"
|
||||
|| path == "/v2/m"
|
||||
}
|
||||
|
||||
pub async fn ignore_invalid_routes<B>(request: Request<B>, next: Next<B>) -> Response {
|
||||
let path = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
|
||||
|
||||
if request.uri().path().starts_with("/v1") {
|
||||
(
|
||||
StatusCode::GONE,
|
||||
r#"{"message":"Unsupported API version","code":0}"#,
|
||||
)
|
||||
.into_response()
|
||||
} else if is_trying_to_use_v1_path_on_v2(request.uri().path()) {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
r#"{"message":"Invalid path for API version","code":0}"#,
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
// we ignored v1 routes earlier, now let's ignore all non-v2 routes
|
||||
else if !request.uri().clone().path().starts_with("/v2") {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
r#"{"message":"Unsupported API version","code":0}"#,
|
||||
)
|
||||
.into_response();
|
||||
} else if path == "unknown" {
|
||||
// current prod api responds with 404 with empty body to invalid endpoints
|
||||
// just doing that here as well but i'm not sure if it's the correct behaviour
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
}
|
||||
// yes, technically because of how we parse headers this will break for user-agents literally set to "unknown"
|
||||
// but "unknown" isn't really a valid user-agent
|
||||
else if user_agent == "unknown" {
|
||||
// please set a valid user-agent
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
} else {
|
||||
next.run(request).await
|
||||
}
|
||||
}
|
41
services/api/src/middleware/logger.rs
Normal file
41
services/api/src/middleware/logger.rs
Normal file
@ -0,0 +1,41 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use axum::{extract::MatchedPath, http::Request, middleware::Next, response::Response};
|
||||
use tracing::{info, span, Instrument, Level};
|
||||
|
||||
use crate::util::header_or_unknown;
|
||||
|
||||
pub async fn logger<B>(request: Request<B>, next: Next<B>) -> Response {
|
||||
let method = request.method().clone();
|
||||
|
||||
let request_id = header_or_unknown(request.headers().get("Fly-Request-Id"));
|
||||
let remote_ip = header_or_unknown(request.headers().get("Fly-Client-IP"));
|
||||
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
|
||||
|
||||
let path = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
// todo: prometheus metrics
|
||||
|
||||
let request_id_span = span!(
|
||||
Level::INFO,
|
||||
"request",
|
||||
request_id,
|
||||
remote_ip,
|
||||
method = method.as_str(),
|
||||
path,
|
||||
user_agent
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let response = next.run(request).instrument(request_id_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
info!("handled request for {} {} in {}ms", method, path, elapsed);
|
||||
|
||||
response
|
||||
}
|
11
services/api/src/middleware/mod.rs
Normal file
11
services/api/src/middleware/mod.rs
Normal file
@ -0,0 +1,11 @@
|
||||
mod cors;
|
||||
|
||||
pub use cors::cors;
|
||||
|
||||
mod logger;
|
||||
pub use logger::logger;
|
||||
|
||||
mod ignore_invalid_routes;
|
||||
pub use ignore_invalid_routes::ignore_invalid_routes;
|
||||
|
||||
pub mod ratelimit;
|
57
services/api/src/middleware/ratelimit.lua
Normal file
57
services/api/src/middleware/ratelimit.lua
Normal file
@ -0,0 +1,57 @@
|
||||
-- Copyright (c) 2017 Pavel Pravosud
|
||||
-- https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua
|
||||
|
||||
local rate_limit_key = KEYS[1]
|
||||
local burst = ARGV[1]
|
||||
local rate = ARGV[2]
|
||||
local period = ARGV[3]
|
||||
-- local cost = tonumber(ARGV[4])
|
||||
local cost = 1
|
||||
|
||||
local emission_interval = period / rate
|
||||
local increment = emission_interval * cost
|
||||
local burst_offset = emission_interval * burst
|
||||
|
||||
-- redis returns time as an array containing two integers: seconds of the epoch
|
||||
-- time (10 digits) and microseconds (6 digits). for convenience we need to
|
||||
-- convert them to a floating point number. the resulting number is 16 digits,
|
||||
-- bordering on the limits of a 64-bit double-precision floating point number.
|
||||
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
|
||||
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
|
||||
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
|
||||
local jan_1_2017 = 1483228800
|
||||
local now = redis.call("TIME")
|
||||
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
|
||||
|
||||
local tat = redis.call("GET", rate_limit_key)
|
||||
|
||||
if not tat then
|
||||
tat = now
|
||||
else
|
||||
tat = tonumber(tat)
|
||||
end
|
||||
|
||||
tat = math.max(tat, now)
|
||||
|
||||
local new_tat = tat + increment
|
||||
local allow_at = new_tat - burst_offset
|
||||
|
||||
local diff = now - allow_at
|
||||
local remaining = diff / emission_interval
|
||||
|
||||
if remaining < 0 then
|
||||
local reset_after = tat - now
|
||||
local retry_after = diff * -1
|
||||
return {
|
||||
0, -- remaining
|
||||
retry_after,
|
||||
reset_after,
|
||||
}
|
||||
end
|
||||
|
||||
local reset_after = new_tat - now
|
||||
if reset_after > 0 then
|
||||
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
|
||||
end
|
||||
local retry_after = -1
|
||||
return {remaining, retry_after, reset_after}
|
154
services/api/src/middleware/ratelimit.rs
Normal file
154
services/api/src/middleware/ratelimit.rs
Normal file
@ -0,0 +1,154 @@
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::Request,
|
||||
middleware::{FromFnLayer, Next},
|
||||
response::Response,
|
||||
};
|
||||
use fred::{pool::RedisPool, prelude::LuaInterface, types::ReconnectPolicy, util::sha1_hash};
|
||||
use http::{HeaderValue, StatusCode};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::util::{header_or_unknown, json_err};
|
||||
|
||||
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT);
|
||||
}
|
||||
|
||||
// todo lol
|
||||
const TOKEN2: &'static str = "h";
|
||||
|
||||
// this is awful but it works
|
||||
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
|
||||
let redis = libpk::config.api.ratelimit_redis_addr.as_ref().map(|val| {
|
||||
let r = fred::pool::RedisPool::new(
|
||||
fred::types::RedisConfig::from_url_centralized(val.as_ref())
|
||||
.expect("redis url is invalid"),
|
||||
10,
|
||||
)
|
||||
.expect("failed to connect to redis");
|
||||
|
||||
let handle = r.connect(Some(ReconnectPolicy::default()));
|
||||
|
||||
tokio::spawn(async move { handle });
|
||||
|
||||
let rscript = r.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Ok(()) = rscript.wait_for_connect().await {
|
||||
match rscript.script_load(LUA_SCRIPT).await {
|
||||
Ok(_) => info!("connected to redis for request rate limiting"),
|
||||
Err(err) => error!("could not load redis script: {}", err),
|
||||
}
|
||||
} else {
|
||||
error!("could not wait for connection to load redis script!");
|
||||
}
|
||||
});
|
||||
|
||||
r
|
||||
});
|
||||
|
||||
if redis.is_none() {
|
||||
warn!("running without request rate limiting!");
|
||||
}
|
||||
|
||||
axum::middleware::from_fn_with_state(redis, f)
|
||||
}
|
||||
|
||||
pub async fn do_request_ratelimited<B>(
|
||||
State(redis): State<Option<RedisPool>>,
|
||||
request: Request<B>,
|
||||
next: Next<B>,
|
||||
) -> Response {
|
||||
if let Some(redis) = redis {
|
||||
let headers = request.headers().clone();
|
||||
let source_ip = header_or_unknown(headers.get("Fly-Client-IP"));
|
||||
|
||||
let (rl_key, rate) = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
|
||||
{
|
||||
if header == TOKEN2 {
|
||||
("token2", 20)
|
||||
} else {
|
||||
(source_ip, 2)
|
||||
}
|
||||
} else {
|
||||
(source_ip, 2)
|
||||
};
|
||||
|
||||
let burst = 5;
|
||||
let period = 1; // seconds
|
||||
|
||||
// todo: make this static
|
||||
// though even if it's not static, it's probably cheaper than sending the entire script to redis every time
|
||||
let scriptsha = sha1_hash(&LUA_SCRIPT);
|
||||
|
||||
// local rate_limit_key = KEYS[1]
|
||||
// local burst = ARGV[1]
|
||||
// local rate = ARGV[2]
|
||||
// local period = ARGV[3]
|
||||
// return {remaining, retry_after, reset_after}
|
||||
let resp = redis
|
||||
.evalsha::<(i32, String, u64), String, Vec<&str>, Vec<i32>>(
|
||||
scriptsha,
|
||||
vec![rl_key],
|
||||
vec![burst, rate, period],
|
||||
)
|
||||
.await;
|
||||
|
||||
match resp {
|
||||
Ok((mut remaining, retry_after, reset_after)) => {
|
||||
let mut response = if remaining > 0 {
|
||||
next.run(request).await
|
||||
} else {
|
||||
json_err(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
format!(
|
||||
// todo: the retry_after is horribly wrong
|
||||
r#"{{"message":"429: too many requests","retry_after":{retry_after}}}"#
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
// the redis script puts burst in remaining for ??? some reason
|
||||
remaining -= burst - rate;
|
||||
|
||||
let reset_time = SystemTime::now()
|
||||
.checked_add(Duration::from_secs(reset_after))
|
||||
.expect("invalid timestamp")
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("invalid duration")
|
||||
.as_secs();
|
||||
|
||||
let headers = response.headers_mut();
|
||||
headers.insert(
|
||||
"X-RateLimit-Limit",
|
||||
HeaderValue::from_str(format!("{}", rate).as_str())
|
||||
.expect("invalid header value"),
|
||||
);
|
||||
headers.insert(
|
||||
"X-RateLimit-Remaining",
|
||||
HeaderValue::from_str(format!("{}", remaining).as_str())
|
||||
.expect("invalid header value"),
|
||||
);
|
||||
headers.insert(
|
||||
"X-RateLimit-Reset",
|
||||
HeaderValue::from_str(format!("{}", reset_time).as_str())
|
||||
.expect("invalid header value"),
|
||||
);
|
||||
|
||||
return response;
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("error getting ratelimit info: {}", err);
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
next.run(request).await
|
||||
}
|
42
services/api/src/util.rs
Normal file
42
services/api/src/util.rs
Normal file
@ -0,0 +1,42 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{HeaderValue, Request, Response, StatusCode, Uri},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
|
||||
if let Some(value) = header {
|
||||
match value.to_str() {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
error!("failed to parse header value {:#?}: {:#?}", value, err);
|
||||
"failed to parse"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
"unknown"
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn rproxy(req: Request<Body>) -> Response<Body> {
|
||||
let uri = Uri::from_static(&libpk::config.api.remote_url).to_string();
|
||||
|
||||
match hyper_reverse_proxy::call("0.0.0.0".parse().unwrap(), &uri[..uri.len() - 1], req).await {
|
||||
Ok(response) => response,
|
||||
Err(error) => {
|
||||
error!("error proxying request: {:?}", error);
|
||||
Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::empty())
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
|
||||
let mut response = (code, text).into_response();
|
||||
let headers = response.headers_mut();
|
||||
headers.insert("content-type", HeaderValue::from_static("application/json"));
|
||||
response
|
||||
}
|
@ -2,25 +2,26 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
||||
"web-proxy/redis_rate"
|
||||
)
|
||||
|
||||
var limiter *redis_rate.Limiter
|
||||
|
||||
// todo: be able to raise ratelimits for >1 consumers
|
||||
var token2 string
|
||||
func proxyTo(host string) *httputil.ReverseProxy {
|
||||
rp := httputil.NewSingleHostReverseProxy(&url.URL{
|
||||
Scheme: "http",
|
||||
Host: host,
|
||||
RawQuery: "",
|
||||
})
|
||||
rp.ModifyResponse = logTimeElapsed
|
||||
return rp
|
||||
}
|
||||
|
||||
// todo: this shouldn't be in this repo
|
||||
var remotes = map[string]*httputil.ReverseProxy{
|
||||
@ -29,22 +30,6 @@ var remotes = map[string]*httputil.ReverseProxy{
|
||||
"sentry.pluralkit.me": proxyTo("[fdaa:0:ae33:a7b:8dd7:0:a:202]:9000"),
|
||||
}
|
||||
|
||||
func init() {
|
||||
redisHost := requireEnv("REDIS_HOST")
|
||||
redisPassword := requireEnv("REDIS_PASSWORD")
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: redisHost,
|
||||
Username: "default",
|
||||
Password: redisPassword,
|
||||
})
|
||||
limiter = redis_rate.NewLimiter(rdb)
|
||||
|
||||
token2 = requireEnv("TOKEN2")
|
||||
|
||||
remotes["dash.pluralkit.me"].ModifyResponse = modifyDashResponse
|
||||
}
|
||||
|
||||
type ProxyHandler struct{}
|
||||
|
||||
func (p ProxyHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
@ -61,48 +46,6 @@ func (p ProxyHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if r.Host == "api.pluralkit.me" {
|
||||
// root
|
||||
if r.URL.Path == "" {
|
||||
// api root path redirects to docs
|
||||
http.Redirect(rw, r, "https://pluralkit.me/api/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// CORS headers
|
||||
rw.Header().Add("Access-Control-Allow-Origin", "*")
|
||||
rw.Header().Add("Access-Control-Allow-Methods", "*")
|
||||
rw.Header().Add("Access-Control-Allow-Credentials", "true")
|
||||
rw.Header().Add("Access-Control-Allow-Headers", "Content-Type, Authorization, sentry-trace, User-Agent")
|
||||
rw.Header().Add("Access-Control-Max-Age", "86400")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
rw.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/" {
|
||||
http.Redirect(rw, r, "https://pluralkit.me/api", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(r.URL.Path, "/v1") {
|
||||
rw.Header().Set("content-type", "application/json")
|
||||
rw.WriteHeader(410)
|
||||
rw.Write([]byte(`{"message":"Unsupported API version","code":0}`))
|
||||
}
|
||||
|
||||
if is_trying_to_use_v1_path_on_v2(r.URL.Path) {
|
||||
rw.WriteHeader(400)
|
||||
rw.Write([]byte(`{"message":"Invalid path for API version","code":0}`))
|
||||
return
|
||||
}
|
||||
|
||||
if is_api_ratelimited(rw, r) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
r = r.WithContext(context.WithValue(r.Context(), "req-time", startTime))
|
||||
|
||||
@ -119,40 +62,14 @@ func logTimeElapsed(resp *http.Response) error {
|
||||
"domain": r.Host,
|
||||
"method": r.Method,
|
||||
"status": strconv.Itoa(resp.StatusCode),
|
||||
"route": cleanPath(r.Host, r.URL.Path),
|
||||
"route": r.URL.Path,
|
||||
}).Observe(elapsed.Seconds())
|
||||
|
||||
log, _ := json.Marshal(map[string]interface{}{
|
||||
"remote_ip": r.Header.Get("Fly-Client-IP"),
|
||||
"method": r.Method,
|
||||
"host": r.Host,
|
||||
"route": r.URL.Path,
|
||||
"route_clean": cleanPath(r.Host, r.URL.Path),
|
||||
"status": resp.StatusCode,
|
||||
"elapsed": elapsed.Milliseconds(),
|
||||
"user_agent": r.Header.Get("User-Agent"),
|
||||
})
|
||||
fmt.Println(string(log))
|
||||
|
||||
// log.Printf("[%s] \"%s %s%s\" %d - %vms %s\n", r.Header.Get("Fly-Client-IP"), r.Method, r.Host, r.URL.Path, resp.StatusCode, elapsed.Milliseconds(), r.Header.Get("User-Agent"))
|
||||
log.Printf("[%s] \"%s %s%s\" %d - %vms %s\n", r.Header.Get("Fly-Client-IP"), r.Method, r.Host, r.URL.Path, resp.StatusCode, elapsed.Milliseconds(), r.Header.Get("User-Agent"))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func modifyDashResponse(resp *http.Response) error {
|
||||
r := resp.Request
|
||||
|
||||
// cache built+hashed dashboard js/css files forever
|
||||
is_dash_static_asset := strings.HasPrefix(r.URL.Path, "/assets/") &&
|
||||
(strings.HasSuffix(r.URL.Path, ".js") || strings.HasSuffix(r.URL.Path, ".css") || strings.HasSuffix(r.URL.Path, ".map"))
|
||||
|
||||
if is_dash_static_asset && resp.StatusCode == 200 {
|
||||
resp.Header.Add("Cache-Control", "max-age=31536000, s-maxage=31536000, immutable")
|
||||
}
|
||||
|
||||
return logTimeElapsed(resp)
|
||||
}
|
||||
|
||||
func main() {
|
||||
prometheus.MustRegister(metric)
|
||||
|
||||
@ -161,3 +78,11 @@ func main() {
|
||||
|
||||
http.ListenAndServe(":8080", ProxyHandler{})
|
||||
}
|
||||
|
||||
var metric = prometheus.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "pk_http_requests",
|
||||
Buckets: []float64{.1, .25, 1, 2.5, 5, 20},
|
||||
},
|
||||
[]string{"domain", "method", "status", "route"},
|
||||
)
|
||||
|
@ -1,43 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"web-proxy/redis_rate"
|
||||
)
|
||||
|
||||
func is_api_ratelimited(rw http.ResponseWriter, r *http.Request) bool {
|
||||
var limit int
|
||||
var key string
|
||||
|
||||
if r.Header.Get("X-PluralKit-App") == token2 {
|
||||
limit = 20
|
||||
key = "token2"
|
||||
} else {
|
||||
limit = 2
|
||||
key = r.Header.Get("Fly-Client-IP")
|
||||
}
|
||||
|
||||
res, err := limiter.Allow(r.Context(), "ratelimit:"+key, redis_rate.Limit{
|
||||
Period: time.Second,
|
||||
Rate: limit,
|
||||
Burst: 5,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
rw.Header().Set("X-RateLimit-Limit", fmt.Sprint(limit))
|
||||
rw.Header().Set("X-RateLimit-Remaining", fmt.Sprint(res.Remaining))
|
||||
rw.Header().Set("X-RateLimit-Reset", fmt.Sprint(time.Now().Add(res.ResetAfter).UnixNano()/1_000_000))
|
||||
|
||||
if res.Allowed < 1 {
|
||||
rw.WriteHeader(429)
|
||||
rw.Write([]byte(`{"message":"429: too many requests","retry_after":` + fmt.Sprint(res.RetryAfter.Milliseconds()) + `,"code":0}`))
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
@ -1,25 +0,0 @@
|
||||
Copyright (c) 2013 The github.com/go-redis/redis_rate Authors.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
@ -1,140 +0,0 @@
|
||||
package redis_rate
|
||||
|
||||
import "github.com/go-redis/redis/v8"
|
||||
|
||||
// pluralkit changes:
|
||||
// fly's hosted redis doesn't support replicate commands
|
||||
// we can remove it since it's a single host
|
||||
|
||||
// Copyright (c) 2017 Pavel Pravosud
|
||||
// https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua
|
||||
var allowN = redis.NewScript(`
|
||||
-- this script has side-effects, so it requires replicate commands mode
|
||||
-- redis.replicate_commands()
|
||||
|
||||
local rate_limit_key = KEYS[1]
|
||||
local burst = ARGV[1]
|
||||
local rate = ARGV[2]
|
||||
local period = ARGV[3]
|
||||
local cost = tonumber(ARGV[4])
|
||||
|
||||
local emission_interval = period / rate
|
||||
local increment = emission_interval * cost
|
||||
local burst_offset = emission_interval * burst
|
||||
|
||||
-- redis returns time as an array containing two integers: seconds of the epoch
|
||||
-- time (10 digits) and microseconds (6 digits). for convenience we need to
|
||||
-- convert them to a floating point number. the resulting number is 16 digits,
|
||||
-- bordering on the limits of a 64-bit double-precision floating point number.
|
||||
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
|
||||
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
|
||||
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
|
||||
local jan_1_2017 = 1483228800
|
||||
local now = redis.call("TIME")
|
||||
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
|
||||
|
||||
local tat = redis.call("GET", rate_limit_key)
|
||||
|
||||
if not tat then
|
||||
tat = now
|
||||
else
|
||||
tat = tonumber(tat)
|
||||
end
|
||||
|
||||
tat = math.max(tat, now)
|
||||
|
||||
local new_tat = tat + increment
|
||||
local allow_at = new_tat - burst_offset
|
||||
|
||||
local diff = now - allow_at
|
||||
local remaining = diff / emission_interval
|
||||
|
||||
if remaining < 0 then
|
||||
local reset_after = tat - now
|
||||
local retry_after = diff * -1
|
||||
return {
|
||||
0, -- allowed
|
||||
0, -- remaining
|
||||
tostring(retry_after),
|
||||
tostring(reset_after),
|
||||
}
|
||||
end
|
||||
|
||||
local reset_after = new_tat - now
|
||||
if reset_after > 0 then
|
||||
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
|
||||
end
|
||||
local retry_after = -1
|
||||
return {cost, remaining, tostring(retry_after), tostring(reset_after)}
|
||||
`)
|
||||
|
||||
var allowAtMost = redis.NewScript(`
|
||||
-- this script has side-effects, so it requires replicate commands mode
|
||||
-- redis.replicate_commands()
|
||||
|
||||
local rate_limit_key = KEYS[1]
|
||||
local burst = ARGV[1]
|
||||
local rate = ARGV[2]
|
||||
local period = ARGV[3]
|
||||
local cost = tonumber(ARGV[4])
|
||||
|
||||
local emission_interval = period / rate
|
||||
local burst_offset = emission_interval * burst
|
||||
|
||||
-- redis returns time as an array containing two integers: seconds of the epoch
|
||||
-- time (10 digits) and microseconds (6 digits). for convenience we need to
|
||||
-- convert them to a floating point number. the resulting number is 16 digits,
|
||||
-- bordering on the limits of a 64-bit double-precision floating point number.
|
||||
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
|
||||
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
|
||||
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
|
||||
local jan_1_2017 = 1483228800
|
||||
local now = redis.call("TIME")
|
||||
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
|
||||
|
||||
local tat = redis.call("GET", rate_limit_key)
|
||||
|
||||
if not tat then
|
||||
tat = now
|
||||
else
|
||||
tat = tonumber(tat)
|
||||
end
|
||||
|
||||
tat = math.max(tat, now)
|
||||
|
||||
local diff = now - (tat - burst_offset)
|
||||
local remaining = diff / emission_interval
|
||||
|
||||
if remaining < 1 then
|
||||
local reset_after = tat - now
|
||||
local retry_after = emission_interval - diff
|
||||
return {
|
||||
0, -- allowed
|
||||
0, -- remaining
|
||||
tostring(retry_after),
|
||||
tostring(reset_after),
|
||||
}
|
||||
end
|
||||
|
||||
if remaining < cost then
|
||||
cost = remaining
|
||||
remaining = 0
|
||||
else
|
||||
remaining = remaining - cost
|
||||
end
|
||||
|
||||
local increment = emission_interval * cost
|
||||
local new_tat = tat + increment
|
||||
|
||||
local reset_after = new_tat - now
|
||||
if reset_after > 0 then
|
||||
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
|
||||
end
|
||||
|
||||
return {
|
||||
cost,
|
||||
remaining,
|
||||
tostring(-1),
|
||||
tostring(reset_after),
|
||||
}
|
||||
`)
|
@ -1,198 +0,0 @@
|
||||
package redis_rate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
const redisPrefix = "rate:"
|
||||
|
||||
type rediser interface {
|
||||
Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
|
||||
EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
|
||||
ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd
|
||||
ScriptLoad(ctx context.Context, script string) *redis.StringCmd
|
||||
Del(ctx context.Context, keys ...string) *redis.IntCmd
|
||||
}
|
||||
|
||||
type Limit struct {
|
||||
Rate int
|
||||
Burst int
|
||||
Period time.Duration
|
||||
}
|
||||
|
||||
func (l Limit) String() string {
|
||||
return fmt.Sprintf("%d req/%s (burst %d)", l.Rate, fmtDur(l.Period), l.Burst)
|
||||
}
|
||||
|
||||
func (l Limit) IsZero() bool {
|
||||
return l == Limit{}
|
||||
}
|
||||
|
||||
func fmtDur(d time.Duration) string {
|
||||
switch d {
|
||||
case time.Second:
|
||||
return "s"
|
||||
case time.Minute:
|
||||
return "m"
|
||||
case time.Hour:
|
||||
return "h"
|
||||
}
|
||||
return d.String()
|
||||
}
|
||||
|
||||
func PerSecond(rate int) Limit {
|
||||
return Limit{
|
||||
Rate: rate,
|
||||
Period: time.Second,
|
||||
Burst: rate,
|
||||
}
|
||||
}
|
||||
|
||||
func PerMinute(rate int) Limit {
|
||||
return Limit{
|
||||
Rate: rate,
|
||||
Period: time.Minute,
|
||||
Burst: rate,
|
||||
}
|
||||
}
|
||||
|
||||
func PerHour(rate int) Limit {
|
||||
return Limit{
|
||||
Rate: rate,
|
||||
Period: time.Hour,
|
||||
Burst: rate,
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Limiter controls how frequently events are allowed to happen.
|
||||
type Limiter struct {
|
||||
rdb rediser
|
||||
}
|
||||
|
||||
// NewLimiter returns a new Limiter.
|
||||
func NewLimiter(rdb rediser) *Limiter {
|
||||
return &Limiter{
|
||||
rdb: rdb,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow is a shortcut for AllowN(ctx, key, limit, 1).
|
||||
func (l Limiter) Allow(ctx context.Context, key string, limit Limit) (*Result, error) {
|
||||
return l.AllowN(ctx, key, limit, 1)
|
||||
}
|
||||
|
||||
// AllowN reports whether n events may happen at time now.
|
||||
func (l Limiter) AllowN(
|
||||
ctx context.Context,
|
||||
key string,
|
||||
limit Limit,
|
||||
n int,
|
||||
) (*Result, error) {
|
||||
values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
|
||||
v, err := allowN.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values = v.([]interface{})
|
||||
|
||||
retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := &Result{
|
||||
Limit: limit,
|
||||
Allowed: int(values[0].(int64)),
|
||||
Remaining: int(values[1].(int64)),
|
||||
RetryAfter: dur(retryAfter),
|
||||
ResetAfter: dur(resetAfter),
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// AllowAtMost reports whether at most n events may happen at time now.
|
||||
// It returns number of allowed events that is less than or equal to n.
|
||||
func (l Limiter) AllowAtMost(
|
||||
ctx context.Context,
|
||||
key string,
|
||||
limit Limit,
|
||||
n int,
|
||||
) (*Result, error) {
|
||||
values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
|
||||
v, err := allowAtMost.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values = v.([]interface{})
|
||||
|
||||
retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := &Result{
|
||||
Limit: limit,
|
||||
Allowed: int(values[0].(int64)),
|
||||
Remaining: int(values[1].(int64)),
|
||||
RetryAfter: dur(retryAfter),
|
||||
ResetAfter: dur(resetAfter),
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Reset gets a key and reset all limitations and previous usages
|
||||
func (l *Limiter) Reset(ctx context.Context, key string) error {
|
||||
return l.rdb.Del(ctx, redisPrefix+key).Err()
|
||||
}
|
||||
|
||||
func dur(f float64) time.Duration {
|
||||
if f == -1 {
|
||||
return -1
|
||||
}
|
||||
return time.Duration(f * float64(time.Second))
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
// Limit is the limit that was used to obtain this result.
|
||||
Limit Limit
|
||||
|
||||
// Allowed is the number of events that may happen at time now.
|
||||
Allowed int
|
||||
|
||||
// Remaining is the maximum number of requests that could be
|
||||
// permitted instantaneously for this key given the current
|
||||
// state. For example, if a rate limiter allows 10 requests per
|
||||
// second and has already received 6 requests for this key this
|
||||
// second, Remaining would be 4.
|
||||
Remaining int
|
||||
|
||||
// RetryAfter is the time until the next request will be permitted.
|
||||
// It should be -1 unless the rate limit has been exceeded.
|
||||
RetryAfter time.Duration
|
||||
|
||||
// ResetAfter is the time until the RateLimiter returns to its
|
||||
// initial state for a given key. For example, if a rate limiter
|
||||
// manages requests per second and received one request 200ms ago,
|
||||
// Reset would return 800ms. You can also think of this as the time
|
||||
// until Limit and Remaining will be equal.
|
||||
ResetAfter time.Duration
|
||||
}
|
@ -1,74 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
var metric = prometheus.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "pk_http_requests",
|
||||
Buckets: []float64{.1, .25, 1, 2.5, 5, 20},
|
||||
},
|
||||
[]string{"domain", "method", "status", "route"},
|
||||
)
|
||||
|
||||
func proxyTo(host string) *httputil.ReverseProxy {
|
||||
rp := httputil.NewSingleHostReverseProxy(&url.URL{
|
||||
Scheme: "http",
|
||||
Host: host,
|
||||
RawQuery: "",
|
||||
})
|
||||
rp.ModifyResponse = logTimeElapsed
|
||||
return rp
|
||||
}
|
||||
|
||||
var systemsRegex = regexp.MustCompile("systems/[^/{}]+")
|
||||
var membersRegex = regexp.MustCompile("members/[^/{}]+")
|
||||
var groupsRegex = regexp.MustCompile("groups/[^/{}]+")
|
||||
var switchesRegex = regexp.MustCompile("switches/[^/{}]+")
|
||||
var guildsRegex = regexp.MustCompile("guilds/[^/{}]+")
|
||||
var messagesRegex = regexp.MustCompile("messages/[^/{}]+")
|
||||
|
||||
func cleanPath(host, path string) string {
|
||||
if host != "api.pluralkit.me" {
|
||||
return ""
|
||||
}
|
||||
|
||||
path = strings.ToLower(path)
|
||||
|
||||
if !(strings.HasPrefix(path, "/v2") || strings.HasPrefix(path, "/private")) {
|
||||
return ""
|
||||
}
|
||||
|
||||
path = systemsRegex.ReplaceAllString(path, "systems/{systemRef}")
|
||||
path = membersRegex.ReplaceAllString(path, "members/{memberRef}")
|
||||
path = groupsRegex.ReplaceAllString(path, "groups/{groupRef}")
|
||||
path = switchesRegex.ReplaceAllString(path, "switches/{switchRef}")
|
||||
path = guildsRegex.ReplaceAllString(path, "guilds/{guild_id}")
|
||||
path = messagesRegex.ReplaceAllString(path, "messages/{message_id}")
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
func requireEnv(key string) string {
|
||||
if val, ok := os.LookupEnv(key); !ok {
|
||||
panic("missing `" + key + "` in environment")
|
||||
} else {
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
func is_trying_to_use_v1_path_on_v2(path string) bool {
|
||||
return strings.HasPrefix(path, "/v2/s/") ||
|
||||
strings.HasPrefix(path, "/v2/m/") ||
|
||||
strings.HasPrefix(path, "/v2/a/") ||
|
||||
strings.HasPrefix(path, "/v2/msg/") ||
|
||||
path == "/v2/s" ||
|
||||
path == "/v2/m"
|
||||
}
|
Loading…
Reference in New Issue
Block a user