feat(api): init rust rewrite

This commit is contained in:
spiral
2023-02-15 19:27:36 -05:00
parent 5da3c84bce
commit 5440386969
24 changed files with 2443 additions and 586 deletions

View File

@@ -2,25 +2,26 @@ package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"time"
"github.com/go-redis/redis/v8"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"web-proxy/redis_rate"
)
var limiter *redis_rate.Limiter
// todo: be able to raise ratelimits for >1 consumers
var token2 string
func proxyTo(host string) *httputil.ReverseProxy {
rp := httputil.NewSingleHostReverseProxy(&url.URL{
Scheme: "http",
Host: host,
RawQuery: "",
})
rp.ModifyResponse = logTimeElapsed
return rp
}
// todo: this shouldn't be in this repo
var remotes = map[string]*httputil.ReverseProxy{
@@ -29,22 +30,6 @@ var remotes = map[string]*httputil.ReverseProxy{
"sentry.pluralkit.me": proxyTo("[fdaa:0:ae33:a7b:8dd7:0:a:202]:9000"),
}
func init() {
redisHost := requireEnv("REDIS_HOST")
redisPassword := requireEnv("REDIS_PASSWORD")
rdb := redis.NewClient(&redis.Options{
Addr: redisHost,
Username: "default",
Password: redisPassword,
})
limiter = redis_rate.NewLimiter(rdb)
token2 = requireEnv("TOKEN2")
remotes["dash.pluralkit.me"].ModifyResponse = modifyDashResponse
}
type ProxyHandler struct{}
func (p ProxyHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
@@ -61,48 +46,6 @@ func (p ProxyHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
return
}
if r.Host == "api.pluralkit.me" {
// root
if r.URL.Path == "" {
// api root path redirects to docs
http.Redirect(rw, r, "https://pluralkit.me/api/", http.StatusFound)
return
}
// CORS headers
rw.Header().Add("Access-Control-Allow-Origin", "*")
rw.Header().Add("Access-Control-Allow-Methods", "*")
rw.Header().Add("Access-Control-Allow-Credentials", "true")
rw.Header().Add("Access-Control-Allow-Headers", "Content-Type, Authorization, sentry-trace, User-Agent")
rw.Header().Add("Access-Control-Max-Age", "86400")
if r.Method == http.MethodOptions {
rw.WriteHeader(200)
return
}
if r.URL.Path == "/" {
http.Redirect(rw, r, "https://pluralkit.me/api", http.StatusFound)
return
}
if strings.HasPrefix(r.URL.Path, "/v1") {
rw.Header().Set("content-type", "application/json")
rw.WriteHeader(410)
rw.Write([]byte(`{"message":"Unsupported API version","code":0}`))
}
if is_trying_to_use_v1_path_on_v2(r.URL.Path) {
rw.WriteHeader(400)
rw.Write([]byte(`{"message":"Invalid path for API version","code":0}`))
return
}
if is_api_ratelimited(rw, r) {
return
}
}
startTime := time.Now()
r = r.WithContext(context.WithValue(r.Context(), "req-time", startTime))
@@ -119,40 +62,14 @@ func logTimeElapsed(resp *http.Response) error {
"domain": r.Host,
"method": r.Method,
"status": strconv.Itoa(resp.StatusCode),
"route": cleanPath(r.Host, r.URL.Path),
"route": r.URL.Path,
}).Observe(elapsed.Seconds())
log, _ := json.Marshal(map[string]interface{}{
"remote_ip": r.Header.Get("Fly-Client-IP"),
"method": r.Method,
"host": r.Host,
"route": r.URL.Path,
"route_clean": cleanPath(r.Host, r.URL.Path),
"status": resp.StatusCode,
"elapsed": elapsed.Milliseconds(),
"user_agent": r.Header.Get("User-Agent"),
})
fmt.Println(string(log))
// log.Printf("[%s] \"%s %s%s\" %d - %vms %s\n", r.Header.Get("Fly-Client-IP"), r.Method, r.Host, r.URL.Path, resp.StatusCode, elapsed.Milliseconds(), r.Header.Get("User-Agent"))
log.Printf("[%s] \"%s %s%s\" %d - %vms %s\n", r.Header.Get("Fly-Client-IP"), r.Method, r.Host, r.URL.Path, resp.StatusCode, elapsed.Milliseconds(), r.Header.Get("User-Agent"))
return nil
}
func modifyDashResponse(resp *http.Response) error {
r := resp.Request
// cache built+hashed dashboard js/css files forever
is_dash_static_asset := strings.HasPrefix(r.URL.Path, "/assets/") &&
(strings.HasSuffix(r.URL.Path, ".js") || strings.HasSuffix(r.URL.Path, ".css") || strings.HasSuffix(r.URL.Path, ".map"))
if is_dash_static_asset && resp.StatusCode == 200 {
resp.Header.Add("Cache-Control", "max-age=31536000, s-maxage=31536000, immutable")
}
return logTimeElapsed(resp)
}
func main() {
prometheus.MustRegister(metric)
@@ -161,3 +78,11 @@ func main() {
http.ListenAndServe(":8080", ProxyHandler{})
}
var metric = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "pk_http_requests",
Buckets: []float64{.1, .25, 1, 2.5, 5, 20},
},
[]string{"domain", "method", "status", "route"},
)

View File

@@ -1,43 +0,0 @@
package main
import (
"fmt"
"net/http"
"time"
"web-proxy/redis_rate"
)
func is_api_ratelimited(rw http.ResponseWriter, r *http.Request) bool {
var limit int
var key string
if r.Header.Get("X-PluralKit-App") == token2 {
limit = 20
key = "token2"
} else {
limit = 2
key = r.Header.Get("Fly-Client-IP")
}
res, err := limiter.Allow(r.Context(), "ratelimit:"+key, redis_rate.Limit{
Period: time.Second,
Rate: limit,
Burst: 5,
})
if err != nil {
panic(err)
}
rw.Header().Set("X-RateLimit-Limit", fmt.Sprint(limit))
rw.Header().Set("X-RateLimit-Remaining", fmt.Sprint(res.Remaining))
rw.Header().Set("X-RateLimit-Reset", fmt.Sprint(time.Now().Add(res.ResetAfter).UnixNano()/1_000_000))
if res.Allowed < 1 {
rw.WriteHeader(429)
rw.Write([]byte(`{"message":"429: too many requests","retry_after":` + fmt.Sprint(res.RetryAfter.Milliseconds()) + `,"code":0}`))
return true
}
return false
}

View File

@@ -1,25 +0,0 @@
Copyright (c) 2013 The github.com/go-redis/redis_rate Authors.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,140 +0,0 @@
package redis_rate
import "github.com/go-redis/redis/v8"
// pluralkit changes:
// fly's hosted redis doesn't support replicate commands
// we can remove it since it's a single host
// Copyright (c) 2017 Pavel Pravosud
// https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua
var allowN = redis.NewScript(`
-- this script has side-effects, so it requires replicate commands mode
-- redis.replicate_commands()
local rate_limit_key = KEYS[1]
local burst = ARGV[1]
local rate = ARGV[2]
local period = ARGV[3]
local cost = tonumber(ARGV[4])
local emission_interval = period / rate
local increment = emission_interval * cost
local burst_offset = emission_interval * burst
-- redis returns time as an array containing two integers: seconds of the epoch
-- time (10 digits) and microseconds (6 digits). for convenience we need to
-- convert them to a floating point number. the resulting number is 16 digits,
-- bordering on the limits of a 64-bit double-precision floating point number.
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
local jan_1_2017 = 1483228800
local now = redis.call("TIME")
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
local tat = redis.call("GET", rate_limit_key)
if not tat then
tat = now
else
tat = tonumber(tat)
end
tat = math.max(tat, now)
local new_tat = tat + increment
local allow_at = new_tat - burst_offset
local diff = now - allow_at
local remaining = diff / emission_interval
if remaining < 0 then
local reset_after = tat - now
local retry_after = diff * -1
return {
0, -- allowed
0, -- remaining
tostring(retry_after),
tostring(reset_after),
}
end
local reset_after = new_tat - now
if reset_after > 0 then
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
end
local retry_after = -1
return {cost, remaining, tostring(retry_after), tostring(reset_after)}
`)
var allowAtMost = redis.NewScript(`
-- this script has side-effects, so it requires replicate commands mode
-- redis.replicate_commands()
local rate_limit_key = KEYS[1]
local burst = ARGV[1]
local rate = ARGV[2]
local period = ARGV[3]
local cost = tonumber(ARGV[4])
local emission_interval = period / rate
local burst_offset = emission_interval * burst
-- redis returns time as an array containing two integers: seconds of the epoch
-- time (10 digits) and microseconds (6 digits). for convenience we need to
-- convert them to a floating point number. the resulting number is 16 digits,
-- bordering on the limits of a 64-bit double-precision floating point number.
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
local jan_1_2017 = 1483228800
local now = redis.call("TIME")
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
local tat = redis.call("GET", rate_limit_key)
if not tat then
tat = now
else
tat = tonumber(tat)
end
tat = math.max(tat, now)
local diff = now - (tat - burst_offset)
local remaining = diff / emission_interval
if remaining < 1 then
local reset_after = tat - now
local retry_after = emission_interval - diff
return {
0, -- allowed
0, -- remaining
tostring(retry_after),
tostring(reset_after),
}
end
if remaining < cost then
cost = remaining
remaining = 0
else
remaining = remaining - cost
end
local increment = emission_interval * cost
local new_tat = tat + increment
local reset_after = new_tat - now
if reset_after > 0 then
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
end
return {
cost,
remaining,
tostring(-1),
tostring(reset_after),
}
`)

View File

@@ -1,198 +0,0 @@
package redis_rate
import (
"context"
"fmt"
"strconv"
"time"
"github.com/go-redis/redis/v8"
)
const redisPrefix = "rate:"
type rediser interface {
Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd
ScriptLoad(ctx context.Context, script string) *redis.StringCmd
Del(ctx context.Context, keys ...string) *redis.IntCmd
}
type Limit struct {
Rate int
Burst int
Period time.Duration
}
func (l Limit) String() string {
return fmt.Sprintf("%d req/%s (burst %d)", l.Rate, fmtDur(l.Period), l.Burst)
}
func (l Limit) IsZero() bool {
return l == Limit{}
}
func fmtDur(d time.Duration) string {
switch d {
case time.Second:
return "s"
case time.Minute:
return "m"
case time.Hour:
return "h"
}
return d.String()
}
func PerSecond(rate int) Limit {
return Limit{
Rate: rate,
Period: time.Second,
Burst: rate,
}
}
func PerMinute(rate int) Limit {
return Limit{
Rate: rate,
Period: time.Minute,
Burst: rate,
}
}
func PerHour(rate int) Limit {
return Limit{
Rate: rate,
Period: time.Hour,
Burst: rate,
}
}
//------------------------------------------------------------------------------
// Limiter controls how frequently events are allowed to happen.
type Limiter struct {
rdb rediser
}
// NewLimiter returns a new Limiter.
func NewLimiter(rdb rediser) *Limiter {
return &Limiter{
rdb: rdb,
}
}
// Allow is a shortcut for AllowN(ctx, key, limit, 1).
func (l Limiter) Allow(ctx context.Context, key string, limit Limit) (*Result, error) {
return l.AllowN(ctx, key, limit, 1)
}
// AllowN reports whether n events may happen at time now.
func (l Limiter) AllowN(
ctx context.Context,
key string,
limit Limit,
n int,
) (*Result, error) {
values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
v, err := allowN.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
if err != nil {
return nil, err
}
values = v.([]interface{})
retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
if err != nil {
return nil, err
}
resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
if err != nil {
return nil, err
}
res := &Result{
Limit: limit,
Allowed: int(values[0].(int64)),
Remaining: int(values[1].(int64)),
RetryAfter: dur(retryAfter),
ResetAfter: dur(resetAfter),
}
return res, nil
}
// AllowAtMost reports whether at most n events may happen at time now.
// It returns number of allowed events that is less than or equal to n.
func (l Limiter) AllowAtMost(
ctx context.Context,
key string,
limit Limit,
n int,
) (*Result, error) {
values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
v, err := allowAtMost.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
if err != nil {
return nil, err
}
values = v.([]interface{})
retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
if err != nil {
return nil, err
}
resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
if err != nil {
return nil, err
}
res := &Result{
Limit: limit,
Allowed: int(values[0].(int64)),
Remaining: int(values[1].(int64)),
RetryAfter: dur(retryAfter),
ResetAfter: dur(resetAfter),
}
return res, nil
}
// Reset gets a key and reset all limitations and previous usages
func (l *Limiter) Reset(ctx context.Context, key string) error {
return l.rdb.Del(ctx, redisPrefix+key).Err()
}
func dur(f float64) time.Duration {
if f == -1 {
return -1
}
return time.Duration(f * float64(time.Second))
}
type Result struct {
// Limit is the limit that was used to obtain this result.
Limit Limit
// Allowed is the number of events that may happen at time now.
Allowed int
// Remaining is the maximum number of requests that could be
// permitted instantaneously for this key given the current
// state. For example, if a rate limiter allows 10 requests per
// second and has already received 6 requests for this key this
// second, Remaining would be 4.
Remaining int
// RetryAfter is the time until the next request will be permitted.
// It should be -1 unless the rate limit has been exceeded.
RetryAfter time.Duration
// ResetAfter is the time until the RateLimiter returns to its
// initial state for a given key. For example, if a rate limiter
// manages requests per second and received one request 200ms ago,
// Reset would return 800ms. You can also think of this as the time
// until Limit and Remaining will be equal.
ResetAfter time.Duration
}

View File

@@ -1,74 +0,0 @@
package main
import (
"net/http/httputil"
"net/url"
"os"
"regexp"
"strings"
"github.com/prometheus/client_golang/prometheus"
)
var metric = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "pk_http_requests",
Buckets: []float64{.1, .25, 1, 2.5, 5, 20},
},
[]string{"domain", "method", "status", "route"},
)
func proxyTo(host string) *httputil.ReverseProxy {
rp := httputil.NewSingleHostReverseProxy(&url.URL{
Scheme: "http",
Host: host,
RawQuery: "",
})
rp.ModifyResponse = logTimeElapsed
return rp
}
var systemsRegex = regexp.MustCompile("systems/[^/{}]+")
var membersRegex = regexp.MustCompile("members/[^/{}]+")
var groupsRegex = regexp.MustCompile("groups/[^/{}]+")
var switchesRegex = regexp.MustCompile("switches/[^/{}]+")
var guildsRegex = regexp.MustCompile("guilds/[^/{}]+")
var messagesRegex = regexp.MustCompile("messages/[^/{}]+")
func cleanPath(host, path string) string {
if host != "api.pluralkit.me" {
return ""
}
path = strings.ToLower(path)
if !(strings.HasPrefix(path, "/v2") || strings.HasPrefix(path, "/private")) {
return ""
}
path = systemsRegex.ReplaceAllString(path, "systems/{systemRef}")
path = membersRegex.ReplaceAllString(path, "members/{memberRef}")
path = groupsRegex.ReplaceAllString(path, "groups/{groupRef}")
path = switchesRegex.ReplaceAllString(path, "switches/{switchRef}")
path = guildsRegex.ReplaceAllString(path, "guilds/{guild_id}")
path = messagesRegex.ReplaceAllString(path, "messages/{message_id}")
return path
}
func requireEnv(key string) string {
if val, ok := os.LookupEnv(key); !ok {
panic("missing `" + key + "` in environment")
} else {
return val
}
}
func is_trying_to_use_v1_path_on_v2(path string) bool {
return strings.HasPrefix(path, "/v2/s/") ||
strings.HasPrefix(path, "/v2/m/") ||
strings.HasPrefix(path, "/v2/a/") ||
strings.HasPrefix(path, "/v2/msg/") ||
path == "/v2/s" ||
path == "/v2/m"
}