199 lines
4.6 KiB
Go
199 lines
4.6 KiB
Go
|
package redis_rate
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"strconv"
|
||
|
"time"
|
||
|
|
||
|
"github.com/go-redis/redis/v8"
|
||
|
)
|
||
|
|
||
|
// redisPrefix is prepended to every caller-supplied key before the key is
// handed to Redis (see AllowN, AllowAtMost and Reset).
const redisPrefix = "rate:"
|
||
|
|
||
|
type rediser interface {
|
||
|
Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
|
||
|
EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
|
||
|
ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd
|
||
|
ScriptLoad(ctx context.Context, script string) *redis.StringCmd
|
||
|
Del(ctx context.Context, keys ...string) *redis.IntCmd
|
||
|
}
|
||
|
|
||
|
// Limit describes a rate limit: Rate requests per Period, with a burst
// size of Burst. The three values are forwarded to the rate-limit scripts.
type Limit struct {
	// Rate is the number of requests per Period.
	Rate int
	// Burst is the burst size passed to the rate-limit script.
	Burst int
	// Period is the time span that Rate refers to.
	Period time.Duration
}
|
||
|
|
||
|
func (l Limit) String() string {
|
||
|
return fmt.Sprintf("%d req/%s (burst %d)", l.Rate, fmtDur(l.Period), l.Burst)
|
||
|
}
|
||
|
|
||
|
func (l Limit) IsZero() bool {
|
||
|
return l == Limit{}
|
||
|
}
|
||
|
|
||
|
// fmtDur formats a period for display. Exactly one second, minute or hour is
// shortened to "s", "m" or "h"; any other duration uses time.Duration.String.
func fmtDur(d time.Duration) string {
	if d == time.Second {
		return "s"
	}
	if d == time.Minute {
		return "m"
	}
	if d == time.Hour {
		return "h"
	}
	return d.String()
}
|
||
|
|
||
|
func PerSecond(rate int) Limit {
|
||
|
return Limit{
|
||
|
Rate: rate,
|
||
|
Period: time.Second,
|
||
|
Burst: rate,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func PerMinute(rate int) Limit {
|
||
|
return Limit{
|
||
|
Rate: rate,
|
||
|
Period: time.Minute,
|
||
|
Burst: rate,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func PerHour(rate int) Limit {
|
||
|
return Limit{
|
||
|
Rate: rate,
|
||
|
Period: time.Hour,
|
||
|
Burst: rate,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
// Limiter controls how frequently events are allowed to happen.
|
||
|
type Limiter struct {
|
||
|
rdb rediser
|
||
|
}
|
||
|
|
||
|
// NewLimiter returns a new Limiter.
|
||
|
func NewLimiter(rdb rediser) *Limiter {
|
||
|
return &Limiter{
|
||
|
rdb: rdb,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Allow is a shortcut for AllowN(ctx, key, limit, 1).
|
||
|
func (l Limiter) Allow(ctx context.Context, key string, limit Limit) (*Result, error) {
|
||
|
return l.AllowN(ctx, key, limit, 1)
|
||
|
}
|
||
|
|
||
|
// AllowN reports whether n events may happen at time now.
|
||
|
func (l Limiter) AllowN(
|
||
|
ctx context.Context,
|
||
|
key string,
|
||
|
limit Limit,
|
||
|
n int,
|
||
|
) (*Result, error) {
|
||
|
values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
|
||
|
v, err := allowN.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
values = v.([]interface{})
|
||
|
|
||
|
retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
res := &Result{
|
||
|
Limit: limit,
|
||
|
Allowed: int(values[0].(int64)),
|
||
|
Remaining: int(values[1].(int64)),
|
||
|
RetryAfter: dur(retryAfter),
|
||
|
ResetAfter: dur(resetAfter),
|
||
|
}
|
||
|
return res, nil
|
||
|
}
|
||
|
|
||
|
// AllowAtMost reports whether at most n events may happen at time now.
|
||
|
// It returns number of allowed events that is less than or equal to n.
|
||
|
func (l Limiter) AllowAtMost(
|
||
|
ctx context.Context,
|
||
|
key string,
|
||
|
limit Limit,
|
||
|
n int,
|
||
|
) (*Result, error) {
|
||
|
values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
|
||
|
v, err := allowAtMost.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
values = v.([]interface{})
|
||
|
|
||
|
retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
res := &Result{
|
||
|
Limit: limit,
|
||
|
Allowed: int(values[0].(int64)),
|
||
|
Remaining: int(values[1].(int64)),
|
||
|
RetryAfter: dur(retryAfter),
|
||
|
ResetAfter: dur(resetAfter),
|
||
|
}
|
||
|
return res, nil
|
||
|
}
|
||
|
|
||
|
// Reset gets a key and reset all limitations and previous usages
|
||
|
func (l *Limiter) Reset(ctx context.Context, key string) error {
|
||
|
return l.rdb.Del(ctx, redisPrefix+key).Err()
|
||
|
}
|
||
|
|
||
|
// dur converts a number of seconds into a time.Duration. The value -1 is a
// sentinel and is passed through unchanged as time.Duration(-1) rather than
// being scaled to minus one second.
func dur(f float64) time.Duration {
	if f != -1 {
		return time.Duration(f * float64(time.Second))
	}
	return -1
}
|
||
|
|
||
|
type Result struct {
|
||
|
// Limit is the limit that was used to obtain this result.
|
||
|
Limit Limit
|
||
|
|
||
|
// Allowed is the number of events that may happen at time now.
|
||
|
Allowed int
|
||
|
|
||
|
// Remaining is the maximum number of requests that could be
|
||
|
// permitted instantaneously for this key given the current
|
||
|
// state. For example, if a rate limiter allows 10 requests per
|
||
|
// second and has already received 6 requests for this key this
|
||
|
// second, Remaining would be 4.
|
||
|
Remaining int
|
||
|
|
||
|
// RetryAfter is the time until the next request will be permitted.
|
||
|
// It should be -1 unless the rate limit has been exceeded.
|
||
|
RetryAfter time.Duration
|
||
|
|
||
|
// ResetAfter is the time until the RateLimiter returns to its
|
||
|
// initial state for a given key. For example, if a rate limiter
|
||
|
// manages requests per second and received one request 200ms ago,
|
||
|
// Reset would return 800ms. You can also think of this as the time
|
||
|
// until Limit and Remaining will be equal.
|
||
|
ResetAfter time.Duration
|
||
|
}
|