go-zero/core/limit/periodlimit.go

129 lines
2.9 KiB
Go

package limit
import (
"context"
"errors"
"strconv"
"time"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// Permit states reported to callers of Take/TakeCtx.
const (
	// Unknown means not initialized state.
	Unknown = iota
	// Allowed means the request was granted and the quota is not yet reached.
	Allowed
	// HitQuota means this request was granted and exactly reached the quota.
	HitQuota
	// OverQuota means the quota had already been used up when this request arrived.
	OverQuota
)

// Raw reply codes produced by periodScript. They are numbered differently from
// the exported states above, so TakeCtx translates one into the other.
const (
	internalOverQuota = 0
	internalAllowed   = 1
	internalHitQuota  = 2
)
// ErrUnknownCode is returned when the limiter script replies with a value that
// is not one of the recognized status codes.
var ErrUnknownCode = errors.New("unknown status code")

// periodScript increments the counter at KEYS[1] by one. On the first increment
// (counter == 1) it sets the key's expiry to ARGV[2] seconds. It then compares the
// counter with ARGV[1] (the quota) and replies 1 when the counter is below the
// quota, 2 when it equals the quota, and 0 when it exceeds it.
//
// KEYS[1] is referenced directly each time instead of being copied into a Lua
// local (`local key = KEYS[1]`), because Aliyun Redis does not accept that form.
var periodScript = redis.NewScript(`local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = redis.call("INCRBY", KEYS[1], 1)
if current == 1 then
redis.call("expire", KEYS[1], window)
end
if current < limit then
return 1
elseif current == limit then
return 2
else
return 0
end`)
// PeriodOption mutates a PeriodLimit while NewPeriodLimit builds it.
type PeriodOption func(l *PeriodLimit)

// PeriodLimit caps how many times each key may be taken within a window of
// period seconds. The per-key counters are kept in redis.
type PeriodLimit struct {
	period     int          // window length in seconds; becomes the key's expiry
	quota      int          // number of takes granted per window
	limitStore *redis.Redis // redis client that runs periodScript
	keyPrefix  string       // prepended to every caller-supplied key
	align      bool         // when true, windows end on local-clock multiples of period
}
// NewPeriodLimit builds a PeriodLimit that allows quota takes per key in every
// window of period seconds, storing counters in limitStore under keyPrefix.
// Each opts entry is applied in order to the new limiter before it is returned.
func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string,
	opts ...PeriodOption) *PeriodLimit {
	pl := new(PeriodLimit)
	pl.period = period
	pl.quota = quota
	pl.limitStore = limitStore
	pl.keyPrefix = keyPrefix

	for i := range opts {
		opts[i](pl)
	}

	return pl
}
// Take calls TakeCtx with context.Background(). It consumes one permit for key
// and returns the resulting state.
func (h *PeriodLimit) Take(key string) (int, error) {
	ctx := context.Background()
	return h.TakeCtx(ctx, key)
}
// TakeCtx consumes one permit for key by running periodScript against the
// redis key keyPrefix+key.
//
// It returns Allowed while the counter is below the quota, HitQuota when the
// counter equals the quota, and OverQuota once the counter exceeds it. If the
// script call fails, it returns Unknown and that error. If the reply is not a
// recognized int64 code, it returns Unknown and ErrUnknownCode.
func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) {
	keys := []string{h.keyPrefix + key}
	// quota and window length travel as ARGV[1] and ARGV[2] of the script.
	args := []string{
		strconv.Itoa(h.quota),
		strconv.Itoa(h.calcExpireSeconds()),
	}

	resp, err := h.limitStore.ScriptRunCtx(ctx, periodScript, keys, args)
	if err != nil {
		return Unknown, err
	}

	if code, ok := resp.(int64); ok {
		switch code {
		case internalAllowed:
			return Allowed, nil
		case internalHitQuota:
			return HitQuota, nil
		case internalOverQuota:
			return OverQuota, nil
		}
	}

	// Either the reply was not an int64 or it carried an unrecognized code.
	return Unknown, ErrUnknownCode
}
// calcExpireSeconds returns the expiry, in seconds, to set on a counter key.
//
// Without alignment, the expiry is simply the full period. With alignment, it is
// the number of seconds left until the next multiple of period. That multiple is
// measured on the local wall clock, i.e. Unix time shifted by the current
// time-zone offset. The result is therefore in the range 1..period.
func (h *PeriodLimit) calcExpireSeconds() int {
	if !h.align {
		return h.period
	}

	now := time.Now()
	_, offsetSecs := now.Zone()
	localSecs := now.Unix() + int64(offsetSecs)
	elapsed := int(localSecs % int64(h.period))

	return h.period - elapsed
}
// Align returns a PeriodOption that makes windows end on local-clock multiples
// of the period, rather than starting at the first request for a key.
// For example, with a one-day period, you can limit end users to 5 SMS
// verification messages per calendar day in the server's local time zone.
func Align() PeriodOption {
	enableAlign := func(l *PeriodLimit) {
		l.align = true
	}
	return enableAlign
}