go-zero/core/limit/tokenlimit.go
2024-03-08 22:35:17 +08:00

183 lines
4.7 KiB
Go

package limit
import (
"context"
"errors"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
xrate "golang.org/x/time/rate"
)
const (
// tokenFormat is the fmt pattern NewTokenLimiter uses to build the Redis key
// that stores the remaining tokens.
// NOTE(review): the braces look like a Redis Cluster hash tag that keeps the
// token and timestamp keys in the same slot — confirm.
tokenFormat = "{%s}.tokens"
// timestampFormat is the fmt pattern NewTokenLimiter uses to build the Redis key
// that stores the time of the last refresh.
timestampFormat = "{%s}.ts"
// pingInterval is the period of the ticker waitForRedis uses to ping Redis
// while the in-process rescue limiter is in charge.
pingInterval = time.Millisecond * 100
)
// script is the Lua token-bucket script that reserveN runs on Redis.
//
// Keys:
//   - KEYS[1]: tokens key, holds the tokens left after the previous call
//   - KEYS[2]: timestamp key, holds the `now` value of the previous call
//
// Args:
//   - ARGV[1]: rate, tokens added per unit of `now` (reserveN passes Unix seconds)
//   - ARGV[2]: capacity, the maximum number of tokens in the bucket
//   - ARGV[3]: now, the current time
//   - ARGV[4]: requested, the number of tokens wanted by this call
//
// A missing tokens key is treated as a full bucket and a missing timestamp as 0.
// The bucket is refilled by max(0, now-last_refreshed)*rate, capped at capacity,
// and the request is allowed when the refilled amount is at least `requested`;
// only then are the requested tokens deducted. Both keys are rewritten on every
// call, allowed or not, with a TTL of twice the time needed to fill an empty
// bucket (rounded down). The script returns the Lua boolean `allowed`.
//
// To be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key.
var script = redis.NewScript(`local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)
local last_tokens = tonumber(redis.call("get", KEYS[1]))
if last_tokens == nil then
last_tokens = capacity
end
local last_refreshed = tonumber(redis.call("get", KEYS[2]))
if last_refreshed == nil then
last_refreshed = 0
end
local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
if allowed then
new_tokens = filled_tokens - requested
end
redis.call("setex", KEYS[1], ttl, new_tokens)
redis.call("setex", KEYS[2], ttl, now)
return allowed`)
// A TokenLimiter controls how frequently events are allowed to happen within one second.
// The bucket state lives in Redis; while Redis is considered down, decisions are
// made by an in-process limiter instead.
type TokenLimiter struct {
// rate is the number of tokens added to the bucket per second.
rate int
// burst is the bucket capacity, i.e. the largest burst that is permitted.
burst int
// store is the Redis client used to run the limiter script and to ping Redis.
store *redis.Redis
// tokenKey is the Redis key that holds the remaining tokens.
tokenKey string
// timestampKey is the Redis key that holds the time of the last refresh.
timestampKey string
// rescueLock guards monitorStarted.
rescueLock sync.Mutex
// redisAlive is 1 while Redis is considered usable and 0 while the rescue
// limiter is in charge; it is read and written with sync/atomic.
redisAlive uint32
// monitorStarted reports whether a waitForRedis goroutine is currently running.
monitorStarted bool
// rescueLimiter is the in-process fallback used when redisAlive is 0 or a
// Redis call fails.
rescueLimiter *xrate.Limiter
}
// NewTokenLimiter returns a new TokenLimiter that allows events up to rate per second
// and permits bursts of at most burst tokens.
//
// store is the Redis client that holds the bucket state, and key is formatted into
// the names of the two Redis keys (tokens and timestamp) used by this limiter.
//
// A rate of zero no longer panics: the in-process rescue limiter is then built
// with a zero xrate.Limit instead of dividing by zero.
func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter {
	// time.Second/time.Duration(rate) is an integer division and panics at run
	// time when rate is 0, so only derive the interval for a non-zero rate.
	// The zero value of xrate.Limit is documented as allowing no events.
	var rescueLimit xrate.Limit
	if rate != 0 {
		rescueLimit = xrate.Every(time.Second / time.Duration(rate))
	}

	return &TokenLimiter{
		rate:          rate,
		burst:         burst,
		store:         store,
		tokenKey:      fmt.Sprintf(tokenFormat, key),
		timestampKey:  fmt.Sprintf(timestampFormat, key),
		redisAlive:    1, // assume Redis is reachable until a call proves otherwise
		rescueLimiter: xrate.NewLimiter(rescueLimit, burst),
	}
}
// Allow reports whether a single event may happen right now.
// It gives the same answer as AllowN(time.Now(), 1).
func (lim *TokenLimiter) Allow() bool {
	return lim.reserveN(context.Background(), time.Now(), 1)
}
// AllowCtx reports whether a single event may happen right now, using ctx for
// the Redis call. It gives the same answer as AllowNCtx(ctx, time.Now(), 1).
func (lim *TokenLimiter) AllowCtx(ctx context.Context) bool {
	return lim.reserveN(ctx, time.Now(), 1)
}
// AllowN reports whether n events may happen at time now.
// Use this method if you intend to drop / skip events that exceed the rate.
// The Redis call is made with a background context.
func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
	return lim.AllowNCtx(context.Background(), now, n)
}
// AllowNCtx reports whether n events may happen at time now, using ctx for the
// Redis call.
// Use this method if you intend to drop / skip events that exceed the rate.
func (lim *TokenLimiter) AllowNCtx(ctx context.Context, now time.Time, n int) bool {
	allowed := lim.reserveN(ctx, now, n)
	return allowed
}
// reserveN decides whether n tokens may be taken at time now.
//
// While Redis is marked as down the in-process rescue limiter answers. Otherwise
// the Lua script is run on Redis with the limiter's keys and
// (rate, burst, now in Unix seconds, n) as arguments. A context timeout or
// cancellation denies the request; any other Redis failure starts the Redis
// monitor and falls back to the rescue limiter.
func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) bool {
	// Redis is flagged as unavailable: do not even try it.
	if atomic.LoadUint32(&lim.redisAlive) == 0 {
		return lim.rescueLimiter.AllowN(now, n)
	}

	keys := []string{lim.tokenKey, lim.timestampKey}
	args := []string{
		strconv.Itoa(lim.rate),
		strconv.Itoa(lim.burst),
		strconv.FormatInt(now.Unix(), 10),
		strconv.Itoa(n),
	}
	result, err := lim.store.ScriptRunCtx(ctx, script, keys, args)

	switch {
	case errors.Is(err, redis.Nil):
		// The script returned Lua false, which surfaces as a nil reply: not allowed.
		return false
	case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled):
		// The caller's context ended; deny without switching to rescue mode.
		logx.Errorf("fail to use rate limiter: %s", err)
		return false
	case err != nil:
		logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
		lim.startMonitor()
		return lim.rescueLimiter.AllowN(now, n)
	}

	// The script returned Lua true, which surfaces as the integer 1.
	if granted, isInt := result.(int64); isInt {
		return granted == 1
	}

	// Unexpected reply type: treat it like a Redis failure.
	logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", result)
	lim.startMonitor()
	return lim.rescueLimiter.AllowN(now, n)
}
// startMonitor switches the limiter to rescue mode and launches the goroutine
// that waits for Redis to come back. It does nothing when a monitor is already
// running.
func (lim *TokenLimiter) startMonitor() {
	lim.rescueLock.Lock()
	defer lim.rescueLock.Unlock()

	if !lim.monitorStarted {
		lim.monitorStarted = true
		// Mark Redis as down so reserveN uses the rescue limiter from now on.
		atomic.StoreUint32(&lim.redisAlive, 0)
		go lim.waitForRedis()
	}
}
// waitForRedis pings Redis once per pingInterval until a ping succeeds, then
// marks Redis as alive again. On the way out it stops the ticker and clears
// monitorStarted so that a later failure can start a new monitor.
func (lim *TokenLimiter) waitForRedis() {
	// Deferred calls run last-in first-out, so the ticker is stopped before
	// the flag is cleared.
	defer func() {
		lim.rescueLock.Lock()
		lim.monitorStarted = false
		lim.rescueLock.Unlock()
	}()

	ticker := time.NewTicker(pingInterval)
	defer ticker.Stop()

	for {
		<-ticker.C
		if alive := lim.store.Ping(); alive {
			atomic.StoreUint32(&lim.redisAlive, 1)
			return
		}
	}
}