feat(api): init rust rewrite
This commit is contained in:
		@@ -3,16 +3,17 @@
 | 
			
		||||
 | 
			
		||||
# Include project code and build files
 | 
			
		||||
!PluralKit.*/
 | 
			
		||||
!gateway/
 | 
			
		||||
!myriad_rs/
 | 
			
		||||
!services/
 | 
			
		||||
!lib/
 | 
			
		||||
!Myriad/
 | 
			
		||||
!PluralKit.sln
 | 
			
		||||
!nuget.config
 | 
			
		||||
!.git
 | 
			
		||||
!proto
 | 
			
		||||
!scripts/run-clustered.sh
 | 
			
		||||
!dashboard
 | 
			
		||||
!scheduled_tasks
 | 
			
		||||
 | 
			
		||||
!Cargo.toml
 | 
			
		||||
!Cargo.lock
 | 
			
		||||
!PluralKit.sln
 | 
			
		||||
!nuget.config
 | 
			
		||||
 | 
			
		||||
# Re-exclude host build artifact directories
 | 
			
		||||
**/bin
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										36
									
								
								.github/workflows/api.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								.github/workflows/api.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,36 @@
 | 
			
		||||
name: Build and push API Docker image
 | 
			
		||||
on:
 | 
			
		||||
  push:
 | 
			
		||||
    branches:
 | 
			
		||||
      - main
 | 
			
		||||
      - 'rust-api'
 | 
			
		||||
    paths:
 | 
			
		||||
    - 'lib/pklib/**'
 | 
			
		||||
    - 'services/api/**'
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  deploy:
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    permissions:
 | 
			
		||||
      packages: write
 | 
			
		||||
    if: github.repository == 'PluralKit/PluralKit'
 | 
			
		||||
    steps:
 | 
			
		||||
      - uses: docker/login-action@v1
 | 
			
		||||
        with:
 | 
			
		||||
          registry: ghcr.io
 | 
			
		||||
          username: ${{ github.actor }}
 | 
			
		||||
          password: ${{ secrets.CR_PAT }}
 | 
			
		||||
      - uses: actions/checkout@v2
 | 
			
		||||
      - run: echo "BRANCH_NAME=${GITHUB_REF#refs/heads/}" >> $GITHUB_ENV
 | 
			
		||||
      - uses: docker/build-push-action@v2
 | 
			
		||||
        with:
 | 
			
		||||
          # https://github.com/docker/build-push-action/issues/378
 | 
			
		||||
          context: .
 | 
			
		||||
          file: services/api/Dockerfile
 | 
			
		||||
          push: true
 | 
			
		||||
          tags: |
 | 
			
		||||
            ghcr.io/pluralkit/api:${{ env.BRANCH_NAME }}
 | 
			
		||||
            ghcr.io/pluralkit/api:${{ github.sha }}
 | 
			
		||||
            ghcr.io/pluralkit/api:latest
 | 
			
		||||
          cache-from: type=registry,ref=ghcr.io/pluralkit/pluralkit:${{ env.BRANCH_NAME }}
 | 
			
		||||
          cache-to: type=inline
 | 
			
		||||
							
								
								
									
										1741
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						
									
										1741
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										7
									
								
								Cargo.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								Cargo.toml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,7 @@
 | 
			
		||||
[workspace]
 | 
			
		||||
members = [
 | 
			
		||||
    "./lib/libpk",
 | 
			
		||||
    "./services/api"
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
# todo: add workspace dependencies here
 | 
			
		||||
@@ -3,9 +3,11 @@ package main
 | 
			
		||||
import (
 | 
			
		||||
	"embed"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"html"
 | 
			
		||||
	"io"
 | 
			
		||||
	_fs "io/fs"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
@@ -78,12 +80,21 @@ func notFoundHandler(rw http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		data = []byte(strings.Replace(string(data), `<!-- extra data -->`, defaultEmbed+versionJS, 1))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
	if errors.Is(err, _fs.ErrNotExist) {
 | 
			
		||||
		rw.WriteHeader(http.StatusNotFound)
 | 
			
		||||
	} else if err != nil {
 | 
			
		||||
		rw.WriteHeader(http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	} else {
 | 
			
		||||
		// cache built+hashed dashboard js/css files forever
 | 
			
		||||
		is_dash_static_asset := strings.HasPrefix(r.URL.Path, "/assets/") &&
 | 
			
		||||
			(strings.HasSuffix(r.URL.Path, ".js") || strings.HasSuffix(r.URL.Path, ".css") || strings.HasSuffix(r.URL.Path, ".map"))
 | 
			
		||||
 | 
			
		||||
	rw.Write(data)
 | 
			
		||||
		if is_dash_static_asset {
 | 
			
		||||
			rw.Header().Add("Cache-Control", "max-age=31536000, s-maxage=31536000, immutable")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		rw.Write(data)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// explanation for createEmbed:
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										17
									
								
								lib/libpk/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								lib/libpk/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
			
		||||
[package]
 | 
			
		||||
name = "libpk"
 | 
			
		||||
version = "0.1.0"
 | 
			
		||||
edition = "2021"
 | 
			
		||||
 | 
			
		||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 | 
			
		||||
 | 
			
		||||
[dependencies]
 | 
			
		||||
anyhow = "1.0.69"
 | 
			
		||||
config = "0.13.3"
 | 
			
		||||
gethostname = "0.4.1"
 | 
			
		||||
lazy_static = "1.4.0"
 | 
			
		||||
serde = "1.0.152"
 | 
			
		||||
tokio = { version = "1.25.0", features = ["full"] }
 | 
			
		||||
tracing = "0.1.37"
 | 
			
		||||
tracing-gelf = "0.7.1"
 | 
			
		||||
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
 | 
			
		||||
							
								
								
									
										50
									
								
								lib/libpk/src/_config.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								lib/libpk/src/_config.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,50 @@
 | 
			
		||||
use config::Config;
 | 
			
		||||
use lazy_static::lazy_static;
 | 
			
		||||
use serde::Deserialize;
 | 
			
		||||
use std::sync::Arc;
 | 
			
		||||
 | 
			
		||||
#[derive(Deserialize, Debug)]
 | 
			
		||||
pub struct DiscordConfig {
 | 
			
		||||
    pub client_id: u32,
 | 
			
		||||
    pub bot_token: String,
 | 
			
		||||
    pub client_secret: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Deserialize, Debug)]
 | 
			
		||||
pub struct DatabaseConfig {
 | 
			
		||||
    pub(crate) _data_db_uri: String,
 | 
			
		||||
    pub(crate) _messages_db_uri: String,
 | 
			
		||||
    pub(crate) _db_password: Option<String>,
 | 
			
		||||
    pub data_redis_addr: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn _default_api_addr() -> String {
 | 
			
		||||
    "0.0.0.0:5000".to_string()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Deserialize, Debug)]
 | 
			
		||||
pub struct ApiConfig {
 | 
			
		||||
    #[serde(default = "_default_api_addr")]
 | 
			
		||||
    pub addr: String,
 | 
			
		||||
 | 
			
		||||
    #[serde(default)]
 | 
			
		||||
    pub ratelimit_redis_addr: Option<String>,
 | 
			
		||||
 | 
			
		||||
    pub remote_url: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Deserialize, Debug)]
 | 
			
		||||
pub struct PKConfig {
 | 
			
		||||
    pub discord: DiscordConfig,
 | 
			
		||||
    pub api: ApiConfig,
 | 
			
		||||
 | 
			
		||||
    pub(crate) gelf_log_url: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
lazy_static! {
 | 
			
		||||
    #[derive(Debug)]
 | 
			
		||||
    pub static ref CONFIG: Arc<PKConfig> = Arc::new(Config::builder()
 | 
			
		||||
    .add_source(config::Environment::with_prefix("pluralkit").separator("__"))
 | 
			
		||||
    .build().unwrap()
 | 
			
		||||
    .try_deserialize::<PKConfig>().unwrap());
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										27
									
								
								lib/libpk/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								lib/libpk/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
			
		||||
use gethostname::gethostname;
 | 
			
		||||
use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, EnvFilter, Registry};
 | 
			
		||||
 | 
			
		||||
mod _config;
 | 
			
		||||
pub use crate::_config::CONFIG as config;
 | 
			
		||||
 | 
			
		||||
pub fn init_logging(component: &str) -> anyhow::Result<()> {
 | 
			
		||||
    let subscriber = Registry::default()
 | 
			
		||||
        .with(EnvFilter::from_default_env())
 | 
			
		||||
        .with(tracing_subscriber::fmt::layer());
 | 
			
		||||
 | 
			
		||||
    if let Some(gelf_url) = &config.gelf_log_url {
 | 
			
		||||
        let gelf_logger = tracing_gelf::Logger::builder()
 | 
			
		||||
            .additional_field("component", component)
 | 
			
		||||
            .additional_field("hostname", gethostname().to_str());
 | 
			
		||||
        let mut conn_handle = gelf_logger
 | 
			
		||||
            .init_udp_with_subscriber(gelf_url, subscriber)
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        tokio::spawn(async move { conn_handle.connect().await });
 | 
			
		||||
    } else {
 | 
			
		||||
        // gelf_logger internally sets the global subscriber
 | 
			
		||||
        tracing::subscriber::set_global_default(subscriber)
 | 
			
		||||
            .expect("unable to set global subscriber");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										18
									
								
								services/api/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								services/api/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
			
		||||
[package]
 | 
			
		||||
name = "api"
 | 
			
		||||
version = "0.1.0"
 | 
			
		||||
edition = "2021"
 | 
			
		||||
 | 
			
		||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 | 
			
		||||
 | 
			
		||||
[dependencies]
 | 
			
		||||
anyhow = "1.0.69"
 | 
			
		||||
axum = "0.6.4"
 | 
			
		||||
fred = { version = "5.2.0", default-features = false, features = ["tracing", "pool-prefer-active"] }
 | 
			
		||||
http = "0.2.8"
 | 
			
		||||
hyper-reverse-proxy = "0.5.1"
 | 
			
		||||
lazy_static = "1.4.0"
 | 
			
		||||
libpk = { path = "../../lib/libpk" }
 | 
			
		||||
tokio = { version = "1.25.0", features = ["full"] }
 | 
			
		||||
tower = "0.4.13"
 | 
			
		||||
tracing = "0.1.37"
 | 
			
		||||
							
								
								
									
										27
									
								
								services/api/Dockerfile
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								services/api/Dockerfile
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
			
		||||
FROM alpine:latest AS builder
 | 
			
		||||
 | 
			
		||||
WORKDIR /build
 | 
			
		||||
 | 
			
		||||
RUN apk add rustup build-base
 | 
			
		||||
# todo: arm64 target
 | 
			
		||||
RUN rustup-init --default-host x86_64-unknown-linux-musl --default-toolchain stable --profile default -y
 | 
			
		||||
 | 
			
		||||
COPY Cargo.toml /build/
 | 
			
		||||
COPY Cargo.lock /build/
 | 
			
		||||
 | 
			
		||||
# todo: fetch dependencies first to cache
 | 
			
		||||
# RUN cargo fetch
 | 
			
		||||
 | 
			
		||||
COPY lib/libpk /build/lib/libpk
 | 
			
		||||
COPY services/api/ /build/services/api
 | 
			
		||||
 | 
			
		||||
RUN source "$HOME/.cargo/env" && RUSTFLAGS='-C link-arg=-s' cargo build --bin api --release --target x86_64-unknown-linux-musl
 | 
			
		||||
 | 
			
		||||
RUN ls /build/target
 | 
			
		||||
RUN ls /build/target/release
 | 
			
		||||
 | 
			
		||||
FROM alpine:latest
 | 
			
		||||
 | 
			
		||||
COPY --from=builder /build/target/x86_64-unknown-linux-musl/release/api /bin/api
 | 
			
		||||
 | 
			
		||||
ENTRYPOINT [ "/bin/api" ]
 | 
			
		||||
							
								
								
									
										85
									
								
								services/api/src/main.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								services/api/src/main.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,85 @@
 | 
			
		||||
use axum::{
 | 
			
		||||
    routing::{delete, get, patch, post},
 | 
			
		||||
    Router,
 | 
			
		||||
};
 | 
			
		||||
use tracing::info;
 | 
			
		||||
 | 
			
		||||
mod middleware;
 | 
			
		||||
mod util;
 | 
			
		||||
 | 
			
		||||
// this function is manually formatted for easier legibility of routes
 | 
			
		||||
#[rustfmt::skip]
 | 
			
		||||
#[tokio::main]
 | 
			
		||||
async fn main() -> anyhow::Result<()> {
 | 
			
		||||
    libpk::init_logging("api")?;
 | 
			
		||||
    info!("hello world");
 | 
			
		||||
 | 
			
		||||
    // processed upside down (???) so we have to put middleware at the end
 | 
			
		||||
    let app = Router::new()
 | 
			
		||||
        .route("/v2/systems/:system_id", get(util::rproxy))
 | 
			
		||||
        .route("/v2/systems/:system_id", patch(util::rproxy))
 | 
			
		||||
        .route("/v2/systems/:system_id/settings", get(util::rproxy))
 | 
			
		||||
        .route("/v2/systems/:system_id/settings", patch(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .route("/v2/systems/:system_id/members", get(util::rproxy))
 | 
			
		||||
        .route("/v2/members", post(util::rproxy))
 | 
			
		||||
        .route("/v2/members/:member_id", get(util::rproxy))
 | 
			
		||||
        .route("/v2/members/:member_id", patch(util::rproxy))
 | 
			
		||||
        .route("/v2/members/:member_id", delete(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .route("/v2/systems/:system_id/groups", get(util::rproxy))
 | 
			
		||||
        .route("/v2/groups", post(util::rproxy))
 | 
			
		||||
        .route("/v2/groups/:group_id", get(util::rproxy))
 | 
			
		||||
        .route("/v2/groups/:group_id", patch(util::rproxy))
 | 
			
		||||
        .route("/v2/groups/:group_id", delete(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .route("/v2/groups/:group_id/members", get(util::rproxy))
 | 
			
		||||
        .route("/v2/groups/:group_id/members/add", post(util::rproxy))
 | 
			
		||||
        .route("/v2/groups/:group_id/members/remove", post(util::rproxy))
 | 
			
		||||
        .route("/v2/groups/:group_id/members/overwrite", post(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .route("/v2/members/:member_id/groups", get(util::rproxy))
 | 
			
		||||
        .route("/v2/members/:member_id/groups/add", post(util::rproxy))
 | 
			
		||||
        .route("/v2/members/:member_id/groups/remove", post(util::rproxy))
 | 
			
		||||
        .route("/v2/members/:member_id/groups/overwrite", post(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .route("/v2/systems/:system_id/switches", get(util::rproxy))
 | 
			
		||||
        .route("/v2/systems/:system_id/switches", post(util::rproxy))
 | 
			
		||||
        .route("/v2/systems/:system_id/fronters", get(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .route("/v2/systems/:system_id/switches/:switch_id", get(util::rproxy))
 | 
			
		||||
        .route("/v2/systems/:system_id/switches/:switch_id", patch(util::rproxy))
 | 
			
		||||
        .route("/v2/systems/:system_id/switches/:switch_id/members", patch(util::rproxy))
 | 
			
		||||
        .route("/v2/systems/:system_id/switches/:switch_id", delete(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .route("/v2/systems/:system_id/guilds/:guild_id", get(util::rproxy))
 | 
			
		||||
        .route("/v2/systems/:system_id/guilds/:guild_id", patch(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .route("/v2/members/:member_id/guilds/:guild_id", get(util::rproxy))
 | 
			
		||||
        .route("/v2/members/:member_id/guilds/:guild_id", patch(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .route("/v2/messages/:message_id", get(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .route("/private/meta", get(util::rproxy))
 | 
			
		||||
        .route("/private/bulk_privacy/member", post(util::rproxy))
 | 
			
		||||
        .route("/private/bulk_privacy/group", post(util::rproxy))
 | 
			
		||||
        .route("/private/discord/callback", post(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .route("/v2/systems/:system_id/oembed.json", get(util::rproxy))
 | 
			
		||||
        .route("/v2/members/:member_id/oembed.json", get(util::rproxy))
 | 
			
		||||
        .route("/v2/groups/:group_id/oembed.json", get(util::rproxy))
 | 
			
		||||
 | 
			
		||||
        .layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
 | 
			
		||||
        .layer(axum::middleware::from_fn(middleware::logger))
 | 
			
		||||
        .layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
 | 
			
		||||
        .layer(axum::middleware::from_fn(middleware::cors))
 | 
			
		||||
 | 
			
		||||
        .route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }));
 | 
			
		||||
 | 
			
		||||
    let addr: &str = libpk::config.api.addr.as_ref();
 | 
			
		||||
    axum::Server::bind(&addr.parse()?)
 | 
			
		||||
        .serve(app.into_make_service())
 | 
			
		||||
        .await?;
 | 
			
		||||
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										26
									
								
								services/api/src/middleware/cors.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								services/api/src/middleware/cors.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,26 @@
 | 
			
		||||
use axum::{
 | 
			
		||||
    http::{HeaderMap, HeaderValue, Method, Request, StatusCode},
 | 
			
		||||
    middleware::Next,
 | 
			
		||||
    response::{IntoResponse, Response},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#[rustfmt::skip]
 | 
			
		||||
fn add_cors_headers(headers: &mut HeaderMap) {
 | 
			
		||||
    headers.append("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
 | 
			
		||||
    headers.append("Access-Control-Allow-Methods", HeaderValue::from_static("*"));
 | 
			
		||||
    headers.append("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
 | 
			
		||||
    headers.append("Access-Control-Allow-Headers", HeaderValue::from_static("Content-Type, Authorization, sentry-trace, User-Agent"));
 | 
			
		||||
    headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub async fn cors<B>(request: Request<B>, next: Next<B>) -> Response {
 | 
			
		||||
    let mut response = if request.method() == Method::OPTIONS {
 | 
			
		||||
        StatusCode::OK.into_response()
 | 
			
		||||
    } else {
 | 
			
		||||
        next.run(request).await
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    add_cors_headers(response.headers_mut());
 | 
			
		||||
 | 
			
		||||
    response
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										61
									
								
								services/api/src/middleware/ignore_invalid_routes.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								services/api/src/middleware/ignore_invalid_routes.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,61 @@
 | 
			
		||||
use axum::{
 | 
			
		||||
    extract::MatchedPath,
 | 
			
		||||
    http::{Request, StatusCode},
 | 
			
		||||
    middleware::Next,
 | 
			
		||||
    response::{IntoResponse, Response},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::util::header_or_unknown;
 | 
			
		||||
 | 
			
		||||
fn is_trying_to_use_v1_path_on_v2(path: &str) -> bool {
 | 
			
		||||
    path.starts_with("/v2/s/")
 | 
			
		||||
        || path.starts_with("/v2/m/")
 | 
			
		||||
        || path.starts_with("/v2/a/")
 | 
			
		||||
        || path.starts_with("/v2/msg/")
 | 
			
		||||
        || path == "/v2/s"
 | 
			
		||||
        || path == "/v2/m"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub async fn ignore_invalid_routes<B>(request: Request<B>, next: Next<B>) -> Response {
 | 
			
		||||
    let path = request
 | 
			
		||||
        .extensions()
 | 
			
		||||
        .get::<MatchedPath>()
 | 
			
		||||
        .cloned()
 | 
			
		||||
        .map(|v| v.as_str().to_string())
 | 
			
		||||
        .unwrap_or("unknown".to_string());
 | 
			
		||||
    let user_agent = header_or_unknown(request.headers().get("User-Agent"));
 | 
			
		||||
 | 
			
		||||
    if request.uri().path().starts_with("/v1") {
 | 
			
		||||
        (
 | 
			
		||||
            StatusCode::GONE,
 | 
			
		||||
            r#"{"message":"Unsupported API version","code":0}"#,
 | 
			
		||||
        )
 | 
			
		||||
            .into_response()
 | 
			
		||||
    } else if is_trying_to_use_v1_path_on_v2(request.uri().path()) {
 | 
			
		||||
        (
 | 
			
		||||
            StatusCode::BAD_REQUEST,
 | 
			
		||||
            r#"{"message":"Invalid path for API version","code":0}"#,
 | 
			
		||||
        )
 | 
			
		||||
            .into_response()
 | 
			
		||||
    }
 | 
			
		||||
    // we ignored v1 routes earlier, now let's ignore all non-v2 routes
 | 
			
		||||
    else if !request.uri().clone().path().starts_with("/v2") {
 | 
			
		||||
        return (
 | 
			
		||||
            StatusCode::BAD_REQUEST,
 | 
			
		||||
            r#"{"message":"Unsupported API version","code":0}"#,
 | 
			
		||||
        )
 | 
			
		||||
            .into_response();
 | 
			
		||||
    } else if path == "unknown" {
 | 
			
		||||
        // current prod api responds with 404 with empty body to invalid endpoints
 | 
			
		||||
        // just doing that here as well but i'm not sure if it's the correct behaviour
 | 
			
		||||
        return StatusCode::NOT_FOUND.into_response();
 | 
			
		||||
    }
 | 
			
		||||
    // yes, technically because of how we parse headers this will break for user-agents literally set to "unknown"
 | 
			
		||||
    // but "unknown" isn't really a valid user-agent
 | 
			
		||||
    else if user_agent == "unknown" {
 | 
			
		||||
        // please set a valid user-agent
 | 
			
		||||
        return StatusCode::FORBIDDEN.into_response();
 | 
			
		||||
    } else {
 | 
			
		||||
        next.run(request).await
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										41
									
								
								services/api/src/middleware/logger.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								services/api/src/middleware/logger.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
			
		||||
use std::time::Instant;
 | 
			
		||||
 | 
			
		||||
use axum::{extract::MatchedPath, http::Request, middleware::Next, response::Response};
 | 
			
		||||
use tracing::{info, span, Instrument, Level};
 | 
			
		||||
 | 
			
		||||
use crate::util::header_or_unknown;
 | 
			
		||||
 | 
			
		||||
pub async fn logger<B>(request: Request<B>, next: Next<B>) -> Response {
 | 
			
		||||
    let method = request.method().clone();
 | 
			
		||||
 | 
			
		||||
    let request_id = header_or_unknown(request.headers().get("Fly-Request-Id"));
 | 
			
		||||
    let remote_ip = header_or_unknown(request.headers().get("Fly-Client-IP"));
 | 
			
		||||
    let user_agent = header_or_unknown(request.headers().get("User-Agent"));
 | 
			
		||||
 | 
			
		||||
    let path = request
 | 
			
		||||
        .extensions()
 | 
			
		||||
        .get::<MatchedPath>()
 | 
			
		||||
        .cloned()
 | 
			
		||||
        .map(|v| v.as_str().to_string())
 | 
			
		||||
        .unwrap_or("unknown".to_string());
 | 
			
		||||
 | 
			
		||||
    // todo: prometheus metrics
 | 
			
		||||
 | 
			
		||||
    let request_id_span = span!(
 | 
			
		||||
        Level::INFO,
 | 
			
		||||
        "request",
 | 
			
		||||
        request_id,
 | 
			
		||||
        remote_ip,
 | 
			
		||||
        method = method.as_str(),
 | 
			
		||||
        path,
 | 
			
		||||
        user_agent
 | 
			
		||||
    );
 | 
			
		||||
 | 
			
		||||
    let start = Instant::now();
 | 
			
		||||
    let response = next.run(request).instrument(request_id_span).await;
 | 
			
		||||
    let elapsed = start.elapsed().as_millis();
 | 
			
		||||
 | 
			
		||||
    info!("handled request for {} {} in {}ms", method, path, elapsed);
 | 
			
		||||
 | 
			
		||||
    response
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										11
									
								
								services/api/src/middleware/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								services/api/src/middleware/mod.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,11 @@
 | 
			
		||||
mod cors;
 | 
			
		||||
 | 
			
		||||
pub use cors::cors;
 | 
			
		||||
 | 
			
		||||
mod logger;
 | 
			
		||||
pub use logger::logger;
 | 
			
		||||
 | 
			
		||||
mod ignore_invalid_routes;
 | 
			
		||||
pub use ignore_invalid_routes::ignore_invalid_routes;
 | 
			
		||||
 | 
			
		||||
pub mod ratelimit;
 | 
			
		||||
							
								
								
									
										57
									
								
								services/api/src/middleware/ratelimit.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								services/api/src/middleware/ratelimit.lua
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,57 @@
 | 
			
		||||
-- Copyright (c) 2017 Pavel Pravosud
 | 
			
		||||
-- https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua
 | 
			
		||||
 | 
			
		||||
local rate_limit_key = KEYS[1]
 | 
			
		||||
local burst = ARGV[1]
 | 
			
		||||
local rate = ARGV[2]
 | 
			
		||||
local period = ARGV[3]
 | 
			
		||||
-- local cost = tonumber(ARGV[4])
 | 
			
		||||
local cost = 1
 | 
			
		||||
 | 
			
		||||
local emission_interval = period / rate
 | 
			
		||||
local increment = emission_interval * cost
 | 
			
		||||
local burst_offset = emission_interval * burst
 | 
			
		||||
 | 
			
		||||
-- redis returns time as an array containing two integers: seconds of the epoch
 | 
			
		||||
-- time (10 digits) and microseconds (6 digits). for convenience we need to
 | 
			
		||||
-- convert them to a floating point number. the resulting number is 16 digits,
 | 
			
		||||
-- bordering on the limits of a 64-bit double-precision floating point number.
 | 
			
		||||
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
 | 
			
		||||
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
 | 
			
		||||
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
 | 
			
		||||
local jan_1_2017 = 1483228800
 | 
			
		||||
local now = redis.call("TIME")
 | 
			
		||||
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
 | 
			
		||||
 | 
			
		||||
local tat = redis.call("GET", rate_limit_key)
 | 
			
		||||
 | 
			
		||||
if not tat then
 | 
			
		||||
    tat = now
 | 
			
		||||
else
 | 
			
		||||
    tat = tonumber(tat)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
tat = math.max(tat, now)
 | 
			
		||||
 | 
			
		||||
local new_tat = tat + increment
 | 
			
		||||
local allow_at = new_tat - burst_offset
 | 
			
		||||
 | 
			
		||||
local diff = now - allow_at
 | 
			
		||||
local remaining = diff / emission_interval
 | 
			
		||||
 | 
			
		||||
if remaining < 0 then
 | 
			
		||||
    local reset_after = tat - now
 | 
			
		||||
    local retry_after = diff * -1
 | 
			
		||||
    return {
 | 
			
		||||
        0, -- remaining
 | 
			
		||||
        retry_after,
 | 
			
		||||
        reset_after,
 | 
			
		||||
    }
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
local reset_after = new_tat - now
 | 
			
		||||
if reset_after > 0 then
 | 
			
		||||
    redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
 | 
			
		||||
end
 | 
			
		||||
local retry_after = -1
 | 
			
		||||
return {remaining, retry_after, reset_after}
 | 
			
		||||
							
								
								
									
										154
									
								
								services/api/src/middleware/ratelimit.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								services/api/src/middleware/ratelimit.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,154 @@
 | 
			
		||||
use std::time::{Duration, SystemTime};
 | 
			
		||||
 | 
			
		||||
use axum::{
 | 
			
		||||
    extract::State,
 | 
			
		||||
    http::Request,
 | 
			
		||||
    middleware::{FromFnLayer, Next},
 | 
			
		||||
    response::Response,
 | 
			
		||||
};
 | 
			
		||||
use fred::{pool::RedisPool, prelude::LuaInterface, types::ReconnectPolicy, util::sha1_hash};
 | 
			
		||||
use http::{HeaderValue, StatusCode};
 | 
			
		||||
use tracing::{error, info, warn};
 | 
			
		||||
 | 
			
		||||
use crate::util::{header_or_unknown, json_err};
 | 
			
		||||
 | 
			
		||||
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
 | 
			
		||||
 | 
			
		||||
lazy_static::lazy_static! {
 | 
			
		||||
    static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// todo lol
 | 
			
		||||
const TOKEN2: &'static str = "h";
 | 
			
		||||
 | 
			
		||||
// this is awful but it works
 | 
			
		||||
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
 | 
			
		||||
    let redis = libpk::config.api.ratelimit_redis_addr.as_ref().map(|val| {
 | 
			
		||||
        let r = fred::pool::RedisPool::new(
 | 
			
		||||
            fred::types::RedisConfig::from_url_centralized(val.as_ref())
 | 
			
		||||
                .expect("redis url is invalid"),
 | 
			
		||||
            10,
 | 
			
		||||
        )
 | 
			
		||||
        .expect("failed to connect to redis");
 | 
			
		||||
 | 
			
		||||
        let handle = r.connect(Some(ReconnectPolicy::default()));
 | 
			
		||||
 | 
			
		||||
        tokio::spawn(async move { handle });
 | 
			
		||||
 | 
			
		||||
        let rscript = r.clone();
 | 
			
		||||
        tokio::spawn(async move {
 | 
			
		||||
            if let Ok(()) = rscript.wait_for_connect().await {
 | 
			
		||||
                match rscript.script_load(LUA_SCRIPT).await {
 | 
			
		||||
                    Ok(_) => info!("connected to redis for request rate limiting"),
 | 
			
		||||
                    Err(err) => error!("could not load redis script: {}", err),
 | 
			
		||||
                }
 | 
			
		||||
            } else {
 | 
			
		||||
                error!("could not wait for connection to load redis script!");
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        r
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    if redis.is_none() {
 | 
			
		||||
        warn!("running without request rate limiting!");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    axum::middleware::from_fn_with_state(redis, f)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub async fn do_request_ratelimited<B>(
 | 
			
		||||
    State(redis): State<Option<RedisPool>>,
 | 
			
		||||
    request: Request<B>,
 | 
			
		||||
    next: Next<B>,
 | 
			
		||||
) -> Response {
 | 
			
		||||
    if let Some(redis) = redis {
 | 
			
		||||
        let headers = request.headers().clone();
 | 
			
		||||
        let source_ip = header_or_unknown(headers.get("Fly-Client-IP"));
 | 
			
		||||
 | 
			
		||||
        let (rl_key, rate) = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
 | 
			
		||||
        {
 | 
			
		||||
            if header == TOKEN2 {
 | 
			
		||||
                ("token2", 20)
 | 
			
		||||
            } else {
 | 
			
		||||
                (source_ip, 2)
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            (source_ip, 2)
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let burst = 5;
 | 
			
		||||
        let period = 1; // seconds
 | 
			
		||||
 | 
			
		||||
        // todo: make this static
 | 
			
		||||
        // though even if it's not static, it's probably cheaper than sending the entire script to redis every time
 | 
			
		||||
        let scriptsha = sha1_hash(&LUA_SCRIPT);
 | 
			
		||||
 | 
			
		||||
        // local rate_limit_key = KEYS[1]
 | 
			
		||||
        // local burst = ARGV[1]
 | 
			
		||||
        // local rate = ARGV[2]
 | 
			
		||||
        // local period = ARGV[3]
 | 
			
		||||
        // return {remaining, retry_after, reset_after}
 | 
			
		||||
        let resp = redis
 | 
			
		||||
            .evalsha::<(i32, String, u64), String, Vec<&str>, Vec<i32>>(
 | 
			
		||||
                scriptsha,
 | 
			
		||||
                vec![rl_key],
 | 
			
		||||
                vec![burst, rate, period],
 | 
			
		||||
            )
 | 
			
		||||
            .await;
 | 
			
		||||
 | 
			
		||||
        match resp {
 | 
			
		||||
            Ok((mut remaining, retry_after, reset_after)) => {
 | 
			
		||||
                let mut response = if remaining > 0 {
 | 
			
		||||
                    next.run(request).await
 | 
			
		||||
                } else {
 | 
			
		||||
                    json_err(
 | 
			
		||||
                        StatusCode::TOO_MANY_REQUESTS,
 | 
			
		||||
                        format!(
 | 
			
		||||
                            // todo: the retry_after is horribly wrong
 | 
			
		||||
                            r#"{{"message":"429: too many requests","retry_after":{retry_after}}}"#
 | 
			
		||||
                        ),
 | 
			
		||||
                    )
 | 
			
		||||
                };
 | 
			
		||||
 | 
			
		||||
                // the redis script puts burst in remaining for ??? some reason
 | 
			
		||||
                remaining -= burst - rate;
 | 
			
		||||
 | 
			
		||||
                let reset_time = SystemTime::now()
 | 
			
		||||
                    .checked_add(Duration::from_secs(reset_after))
 | 
			
		||||
                    .expect("invalid timestamp")
 | 
			
		||||
                    .duration_since(std::time::UNIX_EPOCH)
 | 
			
		||||
                    .expect("invalid duration")
 | 
			
		||||
                    .as_secs();
 | 
			
		||||
 | 
			
		||||
                let headers = response.headers_mut();
 | 
			
		||||
                headers.insert(
 | 
			
		||||
                    "X-RateLimit-Limit",
 | 
			
		||||
                    HeaderValue::from_str(format!("{}", rate).as_str())
 | 
			
		||||
                        .expect("invalid header value"),
 | 
			
		||||
                );
 | 
			
		||||
                headers.insert(
 | 
			
		||||
                    "X-RateLimit-Remaining",
 | 
			
		||||
                    HeaderValue::from_str(format!("{}", remaining).as_str())
 | 
			
		||||
                        .expect("invalid header value"),
 | 
			
		||||
                );
 | 
			
		||||
                headers.insert(
 | 
			
		||||
                    "X-RateLimit-Reset",
 | 
			
		||||
                    HeaderValue::from_str(format!("{}", reset_time).as_str())
 | 
			
		||||
                        .expect("invalid header value"),
 | 
			
		||||
                );
 | 
			
		||||
 | 
			
		||||
                return response;
 | 
			
		||||
            }
 | 
			
		||||
            Err(err) => {
 | 
			
		||||
                tracing::error!("error getting ratelimit info: {}", err);
 | 
			
		||||
                return json_err(
 | 
			
		||||
                    StatusCode::INTERNAL_SERVER_ERROR,
 | 
			
		||||
                    r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
 | 
			
		||||
                );
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    next.run(request).await
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										42
									
								
								services/api/src/util.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								services/api/src/util.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,42 @@
 | 
			
		||||
use axum::{
 | 
			
		||||
    body::Body,
 | 
			
		||||
    http::{HeaderValue, Request, Response, StatusCode, Uri},
 | 
			
		||||
    response::IntoResponse,
 | 
			
		||||
};
 | 
			
		||||
use tracing::error;
 | 
			
		||||
 | 
			
		||||
pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
 | 
			
		||||
    if let Some(value) = header {
 | 
			
		||||
        match value.to_str() {
 | 
			
		||||
            Ok(v) => v,
 | 
			
		||||
            Err(err) => {
 | 
			
		||||
                error!("failed to parse header value {:#?}: {:#?}", value, err);
 | 
			
		||||
                "failed to parse"
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    } else {
 | 
			
		||||
        "unknown"
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub async fn rproxy(req: Request<Body>) -> Response<Body> {
 | 
			
		||||
    let uri = Uri::from_static(&libpk::config.api.remote_url).to_string();
 | 
			
		||||
 | 
			
		||||
    match hyper_reverse_proxy::call("0.0.0.0".parse().unwrap(), &uri[..uri.len() - 1], req).await {
 | 
			
		||||
        Ok(response) => response,
 | 
			
		||||
        Err(error) => {
 | 
			
		||||
            error!("error proxying request: {:?}", error);
 | 
			
		||||
            Response::builder()
 | 
			
		||||
                .status(StatusCode::INTERNAL_SERVER_ERROR)
 | 
			
		||||
                .body(Body::empty())
 | 
			
		||||
                .unwrap()
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
 | 
			
		||||
    let mut response = (code, text).into_response();
 | 
			
		||||
    let headers = response.headers_mut();
 | 
			
		||||
    headers.insert("content-type", HeaderValue::from_static("application/json"));
 | 
			
		||||
    response
 | 
			
		||||
}
 | 
			
		||||
@@ -2,25 +2,26 @@ package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httputil"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus/promhttp"
 | 
			
		||||
 | 
			
		||||
	"web-proxy/redis_rate"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var limiter *redis_rate.Limiter
 | 
			
		||||
 | 
			
		||||
// todo: be able to raise ratelimits for >1 consumers
 | 
			
		||||
var token2 string
 | 
			
		||||
func proxyTo(host string) *httputil.ReverseProxy {
 | 
			
		||||
	rp := httputil.NewSingleHostReverseProxy(&url.URL{
 | 
			
		||||
		Scheme:   "http",
 | 
			
		||||
		Host:     host,
 | 
			
		||||
		RawQuery: "",
 | 
			
		||||
	})
 | 
			
		||||
	rp.ModifyResponse = logTimeElapsed
 | 
			
		||||
	return rp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// todo: this shouldn't be in this repo
 | 
			
		||||
var remotes = map[string]*httputil.ReverseProxy{
 | 
			
		||||
@@ -29,22 +30,6 @@ var remotes = map[string]*httputil.ReverseProxy{
 | 
			
		||||
	"sentry.pluralkit.me": proxyTo("[fdaa:0:ae33:a7b:8dd7:0:a:202]:9000"),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	redisHost := requireEnv("REDIS_HOST")
 | 
			
		||||
	redisPassword := requireEnv("REDIS_PASSWORD")
 | 
			
		||||
 | 
			
		||||
	rdb := redis.NewClient(&redis.Options{
 | 
			
		||||
		Addr:     redisHost,
 | 
			
		||||
		Username: "default",
 | 
			
		||||
		Password: redisPassword,
 | 
			
		||||
	})
 | 
			
		||||
	limiter = redis_rate.NewLimiter(rdb)
 | 
			
		||||
 | 
			
		||||
	token2 = requireEnv("TOKEN2")
 | 
			
		||||
 | 
			
		||||
	remotes["dash.pluralkit.me"].ModifyResponse = modifyDashResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ProxyHandler struct{}
 | 
			
		||||
 | 
			
		||||
func (p ProxyHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
 | 
			
		||||
@@ -61,48 +46,6 @@ func (p ProxyHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.Host == "api.pluralkit.me" {
 | 
			
		||||
		// root
 | 
			
		||||
		if r.URL.Path == "" {
 | 
			
		||||
			// api root path redirects to docs
 | 
			
		||||
			http.Redirect(rw, r, "https://pluralkit.me/api/", http.StatusFound)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// CORS headers
 | 
			
		||||
		rw.Header().Add("Access-Control-Allow-Origin", "*")
 | 
			
		||||
		rw.Header().Add("Access-Control-Allow-Methods", "*")
 | 
			
		||||
		rw.Header().Add("Access-Control-Allow-Credentials", "true")
 | 
			
		||||
		rw.Header().Add("Access-Control-Allow-Headers", "Content-Type, Authorization, sentry-trace, User-Agent")
 | 
			
		||||
		rw.Header().Add("Access-Control-Max-Age", "86400")
 | 
			
		||||
 | 
			
		||||
		if r.Method == http.MethodOptions {
 | 
			
		||||
			rw.WriteHeader(200)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if r.URL.Path == "/" {
 | 
			
		||||
			http.Redirect(rw, r, "https://pluralkit.me/api", http.StatusFound)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if strings.HasPrefix(r.URL.Path, "/v1") {
 | 
			
		||||
			rw.Header().Set("content-type", "application/json")
 | 
			
		||||
			rw.WriteHeader(410)
 | 
			
		||||
			rw.Write([]byte(`{"message":"Unsupported API version","code":0}`))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if is_trying_to_use_v1_path_on_v2(r.URL.Path) {
 | 
			
		||||
			rw.WriteHeader(400)
 | 
			
		||||
			rw.Write([]byte(`{"message":"Invalid path for API version","code":0}`))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if is_api_ratelimited(rw, r) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	startTime := time.Now()
 | 
			
		||||
	r = r.WithContext(context.WithValue(r.Context(), "req-time", startTime))
 | 
			
		||||
 | 
			
		||||
@@ -119,40 +62,14 @@ func logTimeElapsed(resp *http.Response) error {
 | 
			
		||||
		"domain": r.Host,
 | 
			
		||||
		"method": r.Method,
 | 
			
		||||
		"status": strconv.Itoa(resp.StatusCode),
 | 
			
		||||
		"route":  cleanPath(r.Host, r.URL.Path),
 | 
			
		||||
		"route":  r.URL.Path,
 | 
			
		||||
	}).Observe(elapsed.Seconds())
 | 
			
		||||
 | 
			
		||||
	log, _ := json.Marshal(map[string]interface{}{
 | 
			
		||||
		"remote_ip":   r.Header.Get("Fly-Client-IP"),
 | 
			
		||||
		"method":      r.Method,
 | 
			
		||||
		"host":        r.Host,
 | 
			
		||||
		"route":       r.URL.Path,
 | 
			
		||||
		"route_clean": cleanPath(r.Host, r.URL.Path),
 | 
			
		||||
		"status":      resp.StatusCode,
 | 
			
		||||
		"elapsed":     elapsed.Milliseconds(),
 | 
			
		||||
		"user_agent":  r.Header.Get("User-Agent"),
 | 
			
		||||
	})
 | 
			
		||||
	fmt.Println(string(log))
 | 
			
		||||
 | 
			
		||||
	// log.Printf("[%s] \"%s %s%s\" %d - %vms %s\n", r.Header.Get("Fly-Client-IP"), r.Method, r.Host, r.URL.Path, resp.StatusCode, elapsed.Milliseconds(), r.Header.Get("User-Agent"))
 | 
			
		||||
	log.Printf("[%s] \"%s %s%s\" %d - %vms %s\n", r.Header.Get("Fly-Client-IP"), r.Method, r.Host, r.URL.Path, resp.StatusCode, elapsed.Milliseconds(), r.Header.Get("User-Agent"))
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func modifyDashResponse(resp *http.Response) error {
 | 
			
		||||
	r := resp.Request
 | 
			
		||||
 | 
			
		||||
	// cache built+hashed dashboard js/css files forever
 | 
			
		||||
	is_dash_static_asset := strings.HasPrefix(r.URL.Path, "/assets/") &&
 | 
			
		||||
		(strings.HasSuffix(r.URL.Path, ".js") || strings.HasSuffix(r.URL.Path, ".css") || strings.HasSuffix(r.URL.Path, ".map"))
 | 
			
		||||
 | 
			
		||||
	if is_dash_static_asset && resp.StatusCode == 200 {
 | 
			
		||||
		resp.Header.Add("Cache-Control", "max-age=31536000, s-maxage=31536000, immutable")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return logTimeElapsed(resp)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	prometheus.MustRegister(metric)
 | 
			
		||||
 | 
			
		||||
@@ -161,3 +78,11 @@ func main() {
 | 
			
		||||
 | 
			
		||||
	http.ListenAndServe(":8080", ProxyHandler{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var metric = prometheus.NewHistogramVec(
 | 
			
		||||
	prometheus.HistogramOpts{
 | 
			
		||||
		Name:    "pk_http_requests",
 | 
			
		||||
		Buckets: []float64{.1, .25, 1, 2.5, 5, 20},
 | 
			
		||||
	},
 | 
			
		||||
	[]string{"domain", "method", "status", "route"},
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,43 +0,0 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"web-proxy/redis_rate"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func is_api_ratelimited(rw http.ResponseWriter, r *http.Request) bool {
 | 
			
		||||
	var limit int
 | 
			
		||||
	var key string
 | 
			
		||||
 | 
			
		||||
	if r.Header.Get("X-PluralKit-App") == token2 {
 | 
			
		||||
		limit = 20
 | 
			
		||||
		key = "token2"
 | 
			
		||||
	} else {
 | 
			
		||||
		limit = 2
 | 
			
		||||
		key = r.Header.Get("Fly-Client-IP")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res, err := limiter.Allow(r.Context(), "ratelimit:"+key, redis_rate.Limit{
 | 
			
		||||
		Period: time.Second,
 | 
			
		||||
		Rate:   limit,
 | 
			
		||||
		Burst:  5,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rw.Header().Set("X-RateLimit-Limit", fmt.Sprint(limit))
 | 
			
		||||
	rw.Header().Set("X-RateLimit-Remaining", fmt.Sprint(res.Remaining))
 | 
			
		||||
	rw.Header().Set("X-RateLimit-Reset", fmt.Sprint(time.Now().Add(res.ResetAfter).UnixNano()/1_000_000))
 | 
			
		||||
 | 
			
		||||
	if res.Allowed < 1 {
 | 
			
		||||
		rw.WriteHeader(429)
 | 
			
		||||
		rw.Write([]byte(`{"message":"429: too many requests","retry_after":` + fmt.Sprint(res.RetryAfter.Milliseconds()) + `,"code":0}`))
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
@@ -1,25 +0,0 @@
 | 
			
		||||
Copyright (c) 2013 The github.com/go-redis/redis_rate Authors.
 | 
			
		||||
All rights reserved.
 | 
			
		||||
 | 
			
		||||
Redistribution and use in source and binary forms, with or without
 | 
			
		||||
modification, are permitted provided that the following conditions are
 | 
			
		||||
met:
 | 
			
		||||
 | 
			
		||||
   * Redistributions of source code must retain the above copyright
 | 
			
		||||
notice, this list of conditions and the following disclaimer.
 | 
			
		||||
   * Redistributions in binary form must reproduce the above
 | 
			
		||||
copyright notice, this list of conditions and the following disclaimer
 | 
			
		||||
in the documentation and/or other materials provided with the
 | 
			
		||||
distribution.
 | 
			
		||||
 | 
			
		||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 | 
			
		||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 | 
			
		||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 | 
			
		||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 | 
			
		||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 | 
			
		||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 | 
			
		||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 | 
			
		||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 | 
			
		||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 | 
			
		||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | 
			
		||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | 
			
		||||
@@ -1,140 +0,0 @@
 | 
			
		||||
package redis_rate
 | 
			
		||||
 | 
			
		||||
import "github.com/go-redis/redis/v8"
 | 
			
		||||
 | 
			
		||||
// pluralkit changes:
 | 
			
		||||
// fly's hosted redis doesn't support replicate commands
 | 
			
		||||
// we can remove it since it's a single host
 | 
			
		||||
 | 
			
		||||
// Copyright (c) 2017 Pavel Pravosud
 | 
			
		||||
// https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua
 | 
			
		||||
var allowN = redis.NewScript(`
 | 
			
		||||
-- this script has side-effects, so it requires replicate commands mode
 | 
			
		||||
-- redis.replicate_commands()
 | 
			
		||||
 | 
			
		||||
local rate_limit_key = KEYS[1]
 | 
			
		||||
local burst = ARGV[1]
 | 
			
		||||
local rate = ARGV[2]
 | 
			
		||||
local period = ARGV[3]
 | 
			
		||||
local cost = tonumber(ARGV[4])
 | 
			
		||||
 | 
			
		||||
local emission_interval = period / rate
 | 
			
		||||
local increment = emission_interval * cost
 | 
			
		||||
local burst_offset = emission_interval * burst
 | 
			
		||||
 | 
			
		||||
-- redis returns time as an array containing two integers: seconds of the epoch
 | 
			
		||||
-- time (10 digits) and microseconds (6 digits). for convenience we need to
 | 
			
		||||
-- convert them to a floating point number. the resulting number is 16 digits,
 | 
			
		||||
-- bordering on the limits of a 64-bit double-precision floating point number.
 | 
			
		||||
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
 | 
			
		||||
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
 | 
			
		||||
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
 | 
			
		||||
local jan_1_2017 = 1483228800
 | 
			
		||||
local now = redis.call("TIME")
 | 
			
		||||
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
 | 
			
		||||
 | 
			
		||||
local tat = redis.call("GET", rate_limit_key)
 | 
			
		||||
 | 
			
		||||
if not tat then
 | 
			
		||||
  tat = now
 | 
			
		||||
else
 | 
			
		||||
  tat = tonumber(tat)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
tat = math.max(tat, now)
 | 
			
		||||
 | 
			
		||||
local new_tat = tat + increment
 | 
			
		||||
local allow_at = new_tat - burst_offset
 | 
			
		||||
 | 
			
		||||
local diff = now - allow_at
 | 
			
		||||
local remaining = diff / emission_interval
 | 
			
		||||
 | 
			
		||||
if remaining < 0 then
 | 
			
		||||
  local reset_after = tat - now
 | 
			
		||||
  local retry_after = diff * -1
 | 
			
		||||
  return {
 | 
			
		||||
    0, -- allowed
 | 
			
		||||
    0, -- remaining
 | 
			
		||||
    tostring(retry_after),
 | 
			
		||||
    tostring(reset_after),
 | 
			
		||||
  }
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
local reset_after = new_tat - now
 | 
			
		||||
if reset_after > 0 then
 | 
			
		||||
  redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
 | 
			
		||||
end
 | 
			
		||||
local retry_after = -1
 | 
			
		||||
return {cost, remaining, tostring(retry_after), tostring(reset_after)}
 | 
			
		||||
`)
 | 
			
		||||
 | 
			
		||||
var allowAtMost = redis.NewScript(`
 | 
			
		||||
-- this script has side-effects, so it requires replicate commands mode
 | 
			
		||||
-- redis.replicate_commands()
 | 
			
		||||
 | 
			
		||||
local rate_limit_key = KEYS[1]
 | 
			
		||||
local burst = ARGV[1]
 | 
			
		||||
local rate = ARGV[2]
 | 
			
		||||
local period = ARGV[3]
 | 
			
		||||
local cost = tonumber(ARGV[4])
 | 
			
		||||
 | 
			
		||||
local emission_interval = period / rate
 | 
			
		||||
local burst_offset = emission_interval * burst
 | 
			
		||||
 | 
			
		||||
-- redis returns time as an array containing two integers: seconds of the epoch
 | 
			
		||||
-- time (10 digits) and microseconds (6 digits). for convenience we need to
 | 
			
		||||
-- convert them to a floating point number. the resulting number is 16 digits,
 | 
			
		||||
-- bordering on the limits of a 64-bit double-precision floating point number.
 | 
			
		||||
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
 | 
			
		||||
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
 | 
			
		||||
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
 | 
			
		||||
local jan_1_2017 = 1483228800
 | 
			
		||||
local now = redis.call("TIME")
 | 
			
		||||
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
 | 
			
		||||
 | 
			
		||||
local tat = redis.call("GET", rate_limit_key)
 | 
			
		||||
 | 
			
		||||
if not tat then
 | 
			
		||||
  tat = now
 | 
			
		||||
else
 | 
			
		||||
  tat = tonumber(tat)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
tat = math.max(tat, now)
 | 
			
		||||
 | 
			
		||||
local diff = now - (tat - burst_offset)
 | 
			
		||||
local remaining = diff / emission_interval
 | 
			
		||||
 | 
			
		||||
if remaining < 1 then
 | 
			
		||||
  local reset_after = tat - now
 | 
			
		||||
  local retry_after = emission_interval - diff
 | 
			
		||||
  return {
 | 
			
		||||
    0, -- allowed
 | 
			
		||||
    0, -- remaining
 | 
			
		||||
    tostring(retry_after),
 | 
			
		||||
    tostring(reset_after),
 | 
			
		||||
  }
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
if remaining < cost then
 | 
			
		||||
  cost = remaining
 | 
			
		||||
  remaining = 0
 | 
			
		||||
else
 | 
			
		||||
  remaining = remaining - cost
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
local increment = emission_interval * cost
 | 
			
		||||
local new_tat = tat + increment
 | 
			
		||||
 | 
			
		||||
local reset_after = new_tat - now
 | 
			
		||||
if reset_after > 0 then
 | 
			
		||||
  redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
return {
 | 
			
		||||
  cost,
 | 
			
		||||
  remaining,
 | 
			
		||||
  tostring(-1),
 | 
			
		||||
  tostring(reset_after),
 | 
			
		||||
}
 | 
			
		||||
`)
 | 
			
		||||
@@ -1,198 +0,0 @@
 | 
			
		||||
package redis_rate
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const redisPrefix = "rate:"
 | 
			
		||||
 | 
			
		||||
type rediser interface {
 | 
			
		||||
	Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
 | 
			
		||||
	EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
 | 
			
		||||
	ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd
 | 
			
		||||
	ScriptLoad(ctx context.Context, script string) *redis.StringCmd
 | 
			
		||||
	Del(ctx context.Context, keys ...string) *redis.IntCmd
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Limit struct {
 | 
			
		||||
	Rate   int
 | 
			
		||||
	Burst  int
 | 
			
		||||
	Period time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l Limit) String() string {
 | 
			
		||||
	return fmt.Sprintf("%d req/%s (burst %d)", l.Rate, fmtDur(l.Period), l.Burst)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l Limit) IsZero() bool {
 | 
			
		||||
	return l == Limit{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func fmtDur(d time.Duration) string {
 | 
			
		||||
	switch d {
 | 
			
		||||
	case time.Second:
 | 
			
		||||
		return "s"
 | 
			
		||||
	case time.Minute:
 | 
			
		||||
		return "m"
 | 
			
		||||
	case time.Hour:
 | 
			
		||||
		return "h"
 | 
			
		||||
	}
 | 
			
		||||
	return d.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PerSecond(rate int) Limit {
 | 
			
		||||
	return Limit{
 | 
			
		||||
		Rate:   rate,
 | 
			
		||||
		Period: time.Second,
 | 
			
		||||
		Burst:  rate,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PerMinute(rate int) Limit {
 | 
			
		||||
	return Limit{
 | 
			
		||||
		Rate:   rate,
 | 
			
		||||
		Period: time.Minute,
 | 
			
		||||
		Burst:  rate,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PerHour(rate int) Limit {
 | 
			
		||||
	return Limit{
 | 
			
		||||
		Rate:   rate,
 | 
			
		||||
		Period: time.Hour,
 | 
			
		||||
		Burst:  rate,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//------------------------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
// Limiter controls how frequently events are allowed to happen.
 | 
			
		||||
type Limiter struct {
 | 
			
		||||
	rdb rediser
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewLimiter returns a new Limiter.
 | 
			
		||||
func NewLimiter(rdb rediser) *Limiter {
 | 
			
		||||
	return &Limiter{
 | 
			
		||||
		rdb: rdb,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Allow is a shortcut for AllowN(ctx, key, limit, 1).
 | 
			
		||||
func (l Limiter) Allow(ctx context.Context, key string, limit Limit) (*Result, error) {
 | 
			
		||||
	return l.AllowN(ctx, key, limit, 1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AllowN reports whether n events may happen at time now.
 | 
			
		||||
func (l Limiter) AllowN(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	key string,
 | 
			
		||||
	limit Limit,
 | 
			
		||||
	n int,
 | 
			
		||||
) (*Result, error) {
 | 
			
		||||
	values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
 | 
			
		||||
	v, err := allowN.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	values = v.([]interface{})
 | 
			
		||||
 | 
			
		||||
	retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := &Result{
 | 
			
		||||
		Limit:      limit,
 | 
			
		||||
		Allowed:    int(values[0].(int64)),
 | 
			
		||||
		Remaining:  int(values[1].(int64)),
 | 
			
		||||
		RetryAfter: dur(retryAfter),
 | 
			
		||||
		ResetAfter: dur(resetAfter),
 | 
			
		||||
	}
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AllowAtMost reports whether at most n events may happen at time now.
 | 
			
		||||
// It returns number of allowed events that is less than or equal to n.
 | 
			
		||||
func (l Limiter) AllowAtMost(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	key string,
 | 
			
		||||
	limit Limit,
 | 
			
		||||
	n int,
 | 
			
		||||
) (*Result, error) {
 | 
			
		||||
	values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
 | 
			
		||||
	v, err := allowAtMost.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	values = v.([]interface{})
 | 
			
		||||
 | 
			
		||||
	retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := &Result{
 | 
			
		||||
		Limit:      limit,
 | 
			
		||||
		Allowed:    int(values[0].(int64)),
 | 
			
		||||
		Remaining:  int(values[1].(int64)),
 | 
			
		||||
		RetryAfter: dur(retryAfter),
 | 
			
		||||
		ResetAfter: dur(resetAfter),
 | 
			
		||||
	}
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Reset gets a key and reset all limitations and previous usages
 | 
			
		||||
func (l *Limiter) Reset(ctx context.Context, key string) error {
 | 
			
		||||
	return l.rdb.Del(ctx, redisPrefix+key).Err()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func dur(f float64) time.Duration {
 | 
			
		||||
	if f == -1 {
 | 
			
		||||
		return -1
 | 
			
		||||
	}
 | 
			
		||||
	return time.Duration(f * float64(time.Second))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Result struct {
 | 
			
		||||
	// Limit is the limit that was used to obtain this result.
 | 
			
		||||
	Limit Limit
 | 
			
		||||
 | 
			
		||||
	// Allowed is the number of events that may happen at time now.
 | 
			
		||||
	Allowed int
 | 
			
		||||
 | 
			
		||||
	// Remaining is the maximum number of requests that could be
 | 
			
		||||
	// permitted instantaneously for this key given the current
 | 
			
		||||
	// state. For example, if a rate limiter allows 10 requests per
 | 
			
		||||
	// second and has already received 6 requests for this key this
 | 
			
		||||
	// second, Remaining would be 4.
 | 
			
		||||
	Remaining int
 | 
			
		||||
 | 
			
		||||
	// RetryAfter is the time until the next request will be permitted.
 | 
			
		||||
	// It should be -1 unless the rate limit has been exceeded.
 | 
			
		||||
	RetryAfter time.Duration
 | 
			
		||||
 | 
			
		||||
	// ResetAfter is the time until the RateLimiter returns to its
 | 
			
		||||
	// initial state for a given key. For example, if a rate limiter
 | 
			
		||||
	// manages requests per second and received one request 200ms ago,
 | 
			
		||||
	// Reset would return 800ms. You can also think of this as the time
 | 
			
		||||
	// until Limit and Remaining will be equal.
 | 
			
		||||
	ResetAfter time.Duration
 | 
			
		||||
}
 | 
			
		||||
@@ -1,74 +0,0 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http/httputil"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var metric = prometheus.NewHistogramVec(
 | 
			
		||||
	prometheus.HistogramOpts{
 | 
			
		||||
		Name:    "pk_http_requests",
 | 
			
		||||
		Buckets: []float64{.1, .25, 1, 2.5, 5, 20},
 | 
			
		||||
	},
 | 
			
		||||
	[]string{"domain", "method", "status", "route"},
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func proxyTo(host string) *httputil.ReverseProxy {
 | 
			
		||||
	rp := httputil.NewSingleHostReverseProxy(&url.URL{
 | 
			
		||||
		Scheme:   "http",
 | 
			
		||||
		Host:     host,
 | 
			
		||||
		RawQuery: "",
 | 
			
		||||
	})
 | 
			
		||||
	rp.ModifyResponse = logTimeElapsed
 | 
			
		||||
	return rp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var systemsRegex = regexp.MustCompile("systems/[^/{}]+")
 | 
			
		||||
var membersRegex = regexp.MustCompile("members/[^/{}]+")
 | 
			
		||||
var groupsRegex = regexp.MustCompile("groups/[^/{}]+")
 | 
			
		||||
var switchesRegex = regexp.MustCompile("switches/[^/{}]+")
 | 
			
		||||
var guildsRegex = regexp.MustCompile("guilds/[^/{}]+")
 | 
			
		||||
var messagesRegex = regexp.MustCompile("messages/[^/{}]+")
 | 
			
		||||
 | 
			
		||||
func cleanPath(host, path string) string {
 | 
			
		||||
	if host != "api.pluralkit.me" {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	path = strings.ToLower(path)
 | 
			
		||||
 | 
			
		||||
	if !(strings.HasPrefix(path, "/v2") || strings.HasPrefix(path, "/private")) {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	path = systemsRegex.ReplaceAllString(path, "systems/{systemRef}")
 | 
			
		||||
	path = membersRegex.ReplaceAllString(path, "members/{memberRef}")
 | 
			
		||||
	path = groupsRegex.ReplaceAllString(path, "groups/{groupRef}")
 | 
			
		||||
	path = switchesRegex.ReplaceAllString(path, "switches/{switchRef}")
 | 
			
		||||
	path = guildsRegex.ReplaceAllString(path, "guilds/{guild_id}")
 | 
			
		||||
	path = messagesRegex.ReplaceAllString(path, "messages/{message_id}")
 | 
			
		||||
 | 
			
		||||
	return path
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requireEnv(key string) string {
 | 
			
		||||
	if val, ok := os.LookupEnv(key); !ok {
 | 
			
		||||
		panic("missing `" + key + "` in environment")
 | 
			
		||||
	} else {
 | 
			
		||||
		return val
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func is_trying_to_use_v1_path_on_v2(path string) bool {
 | 
			
		||||
	return strings.HasPrefix(path, "/v2/s/") ||
 | 
			
		||||
		strings.HasPrefix(path, "/v2/m/") ||
 | 
			
		||||
		strings.HasPrefix(path, "/v2/a/") ||
 | 
			
		||||
		strings.HasPrefix(path, "/v2/msg/") ||
 | 
			
		||||
		path == "/v2/s" ||
 | 
			
		||||
		path == "/v2/m"
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user