Mirror of https://github.com/zeromicro/go-zero.git (synced 2025-01-23 00:50:20 +08:00)

Commit 7e3a369a8f: initial import
.dockerignore (new file, 1 line)
@@ -0,0 +1 @@
**/.git
.gitattributes (new file, vendored, 1 line)
@@ -0,0 +1 @@
* text=auto eol=lf
.gitignore (new file, vendored, 46 lines)
@@ -0,0 +1,46 @@
# Ignore all
*

# Unignore all with extensions
!*.*
!**/Dockerfile

# Unignore all dirs
!*/
!api

.idea
**/.DS_Store
**/logs
!vendor/github.com/songtianyi/rrframework/logs
**/*.pem
**/*.prof
**/*.p12
!Makefile

# gitlab ci
.cache

# chatbot
**/production.json
**/*.corpus.json
**/*.txt
**/*.gob

# example
example/**/*.csv

# hera
service/hera/cli/readdata/intergrationtest/data
service/hera/devkit/ch331/data
service/hera/devkit/ch331/ck
service/hera/cli/replaybeat/etc
# goctl
tools/goctl/api/autogen

vendor/*
/service/hera/cli/dboperation/etc/hera.json

# vim auto backup file
*~
!OWNERS
.gitlab-ci.yml (new file, 16 lines)
@@ -0,0 +1,16 @@
stages:
  - analysis

variables:
  GOPATH: '/runner-cache/zero'
  GOCACHE: '/runner-cache/zero'
  GOPROXY: 'https://goproxy.cn,direct'

analysis:
  stage: analysis
  image: golang
  script:
    - go version && go env
    - go test -short $(go list ./...) | grep -v "no test"
  only:
    - merge_requests
.golangci.yml (new file, 43 lines)
@@ -0,0 +1,43 @@
run:
  # concurrency: 6
  timeout: 5m
  skip-dirs:
    - core
    - diq
    - doc
    - dq
    - example
    - kmq
    - kq
    - ngin
    - rq
    - rpcx
    # - service
    - stash
    - tools


linters:
  disable-all: true
  enable:
    - bodyclose
    - deadcode
    - errcheck
    - gosimple
    - govet
    - ineffassign
    - staticcheck
    - structcheck
    - typecheck
    - unused
    - varcheck
    # - dupl


linters-settings:

issues:
  exclude-rules:
    - linters:
        - staticcheck
      text: 'SA1019: (baseresponse.BoolResponse|oldresponse.FormatBadRequestResponse|oldresponse.FormatResponse)|SA5008: unknown JSON option ("optional"|"default=|"range=|"options=)'
core/bloom/bloom.go (new file, 161 lines)
@@ -0,0 +1,161 @@
package bloom

import (
    "errors"
    "strconv"

    "zero/core/hash"
    "zero/core/stores/redis"
)

const (
    // for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
    // maps as k in the error rate table
    maps      = 14
    setScript = `
local key = KEYS[1]
for _, offset in ipairs(ARGV) do
    redis.call("setbit", key, offset, 1)
end
`
    testScript = `
local key = KEYS[1]
for _, offset in ipairs(ARGV) do
    if tonumber(redis.call("getbit", key, offset)) == 0 then
        return false
    end
end
return true
`
)

var ErrTooLargeOffset = errors.New("too large offset")

type (
    BitSetProvider interface {
        check([]uint) (bool, error)
        set([]uint) error
    }

    BloomFilter struct {
        bits   uint
        maps   uint
        bitSet BitSetProvider
    }
)

// New creates a BloomFilter, store is the backing redis, key is the key for the bloom filter,
// bits is how many bits will be used, maps is how many hashes for each addition.
// best practices:
// elements - means how many actual elements
// when maps = 14, formula: 0.7*(bits/maps), bits = 20*elements, the error rate is 0.000067 < 1e-4
// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
func New(store *redis.Redis, key string, bits uint) *BloomFilter {
    return &BloomFilter{
        bits:   bits,
        bitSet: newRedisBitSet(store, key, bits),
    }
}

func (f *BloomFilter) Add(data []byte) error {
    locations := f.getLocations(data)
    err := f.bitSet.set(locations)
    if err != nil {
        return err
    }
    return nil
}

func (f *BloomFilter) Exists(data []byte) (bool, error) {
    locations := f.getLocations(data)
    isSet, err := f.bitSet.check(locations)
    if err != nil {
        return false, err
    }
    if !isSet {
        return false, nil
    }

    return true, nil
}

func (f *BloomFilter) getLocations(data []byte) []uint {
    locations := make([]uint, maps)
    for i := uint(0); i < maps; i++ {
        hashValue := hash.Hash(append(data, byte(i)))
        locations[i] = uint(hashValue % uint64(f.bits))
    }

    return locations
}

type redisBitSet struct {
    store *redis.Redis
    key   string
    bits  uint
}

func newRedisBitSet(store *redis.Redis, key string, bits uint) *redisBitSet {
    return &redisBitSet{
        store: store,
        key:   key,
        bits:  bits,
    }
}

func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) {
    var args []string

    for _, offset := range offsets {
        if offset >= r.bits {
            return nil, ErrTooLargeOffset
        }

        args = append(args, strconv.FormatUint(uint64(offset), 10))
    }

    return args, nil
}

func (r *redisBitSet) check(offsets []uint) (bool, error) {
    args, err := r.buildOffsetArgs(offsets)
    if err != nil {
        return false, err
    }

    resp, err := r.store.Eval(testScript, []string{r.key}, args)
    if err == redis.Nil {
        return false, nil
    } else if err != nil {
        return false, err
    }

    if exists, ok := resp.(int64); !ok {
        return false, nil
    } else {
        return exists == 1, nil
    }
}

func (r *redisBitSet) del() error {
    _, err := r.store.Del(r.key)
    return err
}

func (r *redisBitSet) expire(seconds int) error {
    return r.store.Expire(r.key, seconds)
}

func (r *redisBitSet) set(offsets []uint) error {
    args, err := r.buildOffsetArgs(offsets)
    if err != nil {
        return err
    }

    _, err = r.store.Eval(setScript, []string{r.key}, args)
    if err == redis.Nil {
        return nil
    } else {
        return err
    }
}
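A minimal usage sketch of the bloom filter above, not part of the commit. The redis address is an assumption; with bits = 20 * expected elements and the fixed maps = 14, the doc comment puts the false-positive rate around 1e-4.

package main

import (
    "fmt"

    "zero/core/bloom"
    "zero/core/stores/redis"
)

func main() {
    // 1024*20 bits, sized for roughly 1024 expected elements (see the comment in bloom.go).
    store := redis.NewRedis("localhost:6379", redis.NodeType) // address is illustrative
    filter := bloom.New(store, "user:bloom", 1024*20)

    if err := filter.Add([]byte("uid-42")); err != nil {
        panic(err)
    }

    exists, err := filter.Exists([]byte("uid-42"))
    fmt.Println(exists, err) // true <nil>, barring redis errors
}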
core/bloom/bloom_test.go (new file, 63 lines)
@@ -0,0 +1,63 @@
package bloom

import (
    "testing"

    "zero/core/stores/redis"

    "github.com/alicebob/miniredis"
    "github.com/stretchr/testify/assert"
)

func TestRedisBitSet_New_Set_Test(t *testing.T) {
    s, err := miniredis.Run()
    if err != nil {
        t.Error("Miniredis could not start")
    }
    defer s.Close()

    store := redis.NewRedis(s.Addr(), redis.NodeType)
    bitSet := newRedisBitSet(store, "test_key", 1024)
    isSetBefore, err := bitSet.check([]uint{0})
    if err != nil {
        t.Fatal(err)
    }
    if isSetBefore {
        t.Fatal("Bit should not be set")
    }
    err = bitSet.set([]uint{512})
    if err != nil {
        t.Fatal(err)
    }
    isSetAfter, err := bitSet.check([]uint{512})
    if err != nil {
        t.Fatal(err)
    }
    if !isSetAfter {
        t.Fatal("Bit should be set")
    }
    err = bitSet.expire(3600)
    if err != nil {
        t.Fatal(err)
    }
    err = bitSet.del()
    if err != nil {
        t.Fatal(err)
    }
}

func TestRedisBitSet_Add(t *testing.T) {
    s, err := miniredis.Run()
    if err != nil {
        t.Error("Miniredis could not start")
    }
    defer s.Close()

    store := redis.NewRedis(s.Addr(), redis.NodeType)
    filter := New(store, "test_key", 64)
    assert.Nil(t, filter.Add([]byte("hello")))
    assert.Nil(t, filter.Add([]byte("world")))
    ok, err := filter.Exists([]byte("hello"))
    assert.Nil(t, err)
    assert.True(t, ok)
}
core/breaker/breaker.go (new file, 229 lines)
@@ -0,0 +1,229 @@
package breaker

import (
    "errors"
    "fmt"
    "strings"
    "sync"
    "time"

    "zero/core/mathx"
    "zero/core/proc"
    "zero/core/stat"
    "zero/core/stringx"
)

const (
    StateClosed State = iota
    StateOpen
)

const (
    numHistoryReasons = 5
    timeFormat        = "15:04:05"
)

// ErrServiceUnavailable is returned when the CB state is open
var ErrServiceUnavailable = errors.New("circuit breaker is open")

type (
    State      = int32
    Acceptable func(err error) bool

    Breaker interface {
        // Name returns the name of the netflixBreaker.
        Name() string

        // Allow checks if the request is allowed.
        // If allowed, a promise will be returned, the caller needs to call promise.Accept()
        // on success, or call promise.Reject() on failure.
        // If not allowed, ErrServiceUnavailable will be returned.
        Allow() (Promise, error)

        // Do runs the given request if the netflixBreaker accepts it.
        // Do returns an error instantly if the netflixBreaker rejects the request.
        // If a panic occurs in the request, the netflixBreaker handles it as an error
        // and causes the same panic again.
        Do(req func() error) error

        // DoWithAcceptable runs the given request if the netflixBreaker accepts it.
        // Do returns an error instantly if the netflixBreaker rejects the request.
        // If a panic occurs in the request, the netflixBreaker handles it as an error
        // and causes the same panic again.
        // acceptable checks if it's a successful call, even if the err is not nil.
        DoWithAcceptable(req func() error, acceptable Acceptable) error

        // DoWithFallback runs the given request if the netflixBreaker accepts it.
        // DoWithFallback runs the fallback if the netflixBreaker rejects the request.
        // If a panic occurs in the request, the netflixBreaker handles it as an error
        // and causes the same panic again.
        DoWithFallback(req func() error, fallback func(err error) error) error

        // DoWithFallbackAcceptable runs the given request if the netflixBreaker accepts it.
        // DoWithFallback runs the fallback if the netflixBreaker rejects the request.
        // If a panic occurs in the request, the netflixBreaker handles it as an error
        // and causes the same panic again.
        // acceptable checks if it's a successful call, even if the err is not nil.
        DoWithFallbackAcceptable(req func() error, fallback func(err error) error, acceptable Acceptable) error
    }

    BreakerOption func(breaker *circuitBreaker)

    Promise interface {
        Accept()
        Reject(reason string)
    }

    internalPromise interface {
        Accept()
        Reject()
    }

    circuitBreaker struct {
        name string
        throttle
    }

    internalThrottle interface {
        allow() (internalPromise, error)
        doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
    }

    throttle interface {
        allow() (Promise, error)
        doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
    }
)

func NewBreaker(opts ...BreakerOption) Breaker {
    var b circuitBreaker
    for _, opt := range opts {
        opt(&b)
    }
    if len(b.name) == 0 {
        b.name = stringx.Rand()
    }
    b.throttle = newLoggedThrottle(b.name, newGoogleBreaker())

    return &b
}

func (cb *circuitBreaker) Allow() (Promise, error) {
    return cb.throttle.allow()
}

func (cb *circuitBreaker) Do(req func() error) error {
    return cb.throttle.doReq(req, nil, defaultAcceptable)
}

func (cb *circuitBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error {
    return cb.throttle.doReq(req, nil, acceptable)
}

func (cb *circuitBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
    return cb.throttle.doReq(req, fallback, defaultAcceptable)
}

func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
    acceptable Acceptable) error {
    return cb.throttle.doReq(req, fallback, acceptable)
}

func (cb *circuitBreaker) Name() string {
    return cb.name
}

func WithName(name string) BreakerOption {
    return func(b *circuitBreaker) {
        b.name = name
    }
}

func defaultAcceptable(err error) bool {
    return err == nil
}

type loggedThrottle struct {
    name string
    internalThrottle
    errWin *errorWindow
}

func newLoggedThrottle(name string, t internalThrottle) loggedThrottle {
    return loggedThrottle{
        name:             name,
        internalThrottle: t,
        errWin:           new(errorWindow),
    }
}

func (lt loggedThrottle) allow() (Promise, error) {
    promise, err := lt.internalThrottle.allow()
    return promiseWithReason{
        promise: promise,
        errWin:  lt.errWin,
    }, lt.logError(err)
}

func (lt loggedThrottle) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
    return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool {
        accept := acceptable(err)
        if !accept {
            lt.errWin.add(err.Error())
        }
        return accept
    }))
}

func (lt loggedThrottle) logError(err error) error {
    if err == ErrServiceUnavailable {
        // if circuit open, not possible to have empty error window
        stat.Report(fmt.Sprintf(
            "proc(%s/%d), callee: %s, breaker is open and requests dropped\nlast errors:\n%s",
            proc.ProcessName(), proc.Pid(), lt.name, lt.errWin))
    }

    return err
}

type errorWindow struct {
    reasons [numHistoryReasons]string
    index   int
    count   int
    lock    sync.Mutex
}

func (ew *errorWindow) add(reason string) {
    ew.lock.Lock()
    ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
    ew.index = (ew.index + 1) % numHistoryReasons
    ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
    ew.lock.Unlock()
}

func (ew *errorWindow) String() string {
    var builder strings.Builder

    ew.lock.Lock()
    for i := ew.index + ew.count - 1; i >= ew.index; i-- {
        builder.WriteString(ew.reasons[i%numHistoryReasons])
        builder.WriteByte('\n')
    }
    ew.lock.Unlock()

    return builder.String()
}

type promiseWithReason struct {
    promise internalPromise
    errWin  *errorWindow
}

func (p promiseWithReason) Accept() {
    p.promise.Accept()
}

func (p promiseWithReason) Reject(reason string) {
    p.errWin.add(reason)
    p.promise.Reject()
}
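A minimal sketch of the two ways the Breaker interface above is meant to be driven, not part of the commit; callRemote and the breaker name are illustrative placeholders.

package main

import (
    "errors"
    "fmt"

    "zero/core/breaker"
)

// callRemote stands in for any downstream call (hypothetical).
func callRemote() error {
    return errors.New("downstream timeout")
}

func main() {
    br := breaker.NewBreaker(breaker.WithName("user-rpc"))

    // style 1: Do classifies the call by its returned error.
    if err := br.Do(callRemote); err != nil {
        fmt.Println("call failed or breaker open:", err)
    }

    // style 2: take a promise and report the outcome explicitly.
    promise, err := br.Allow()
    if err != nil {
        // breaker is open: err == breaker.ErrServiceUnavailable
        fmt.Println("dropped:", err)
        return
    }
    if e := callRemote(); e != nil {
        promise.Reject(e.Error())
    } else {
        promise.Accept()
    }
}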
core/breaker/breaker_test.go (new file, 44 lines)
@@ -0,0 +1,44 @@
package breaker

import (
    "errors"
    "strconv"
    "testing"

    "zero/core/stat"

    "github.com/stretchr/testify/assert"
)

func init() {
    stat.SetReporter(nil)
}

func TestCircuitBreaker_Allow(t *testing.T) {
    b := NewBreaker()
    assert.True(t, len(b.Name()) > 0)
    _, err := b.Allow()
    assert.Nil(t, err)
}

func TestLogReason(t *testing.T) {
    b := NewBreaker()
    assert.True(t, len(b.Name()) > 0)

    for i := 0; i < 1000; i++ {
        _ = b.Do(func() error {
            return errors.New(strconv.Itoa(i))
        })
    }
    errs := b.(*circuitBreaker).throttle.(loggedThrottle).errWin
    assert.Equal(t, numHistoryReasons, errs.count)
}

func BenchmarkGoogleBreaker(b *testing.B) {
    br := NewBreaker()
    for i := 0; i < b.N; i++ {
        _ = br.Do(func() error {
            return nil
        })
    }
}
core/breaker/breakers.go (new file, 76 lines)
@@ -0,0 +1,76 @@
package breaker

import "sync"

var (
    lock     sync.RWMutex
    breakers = make(map[string]Breaker)
)

func Do(name string, req func() error) error {
    return do(name, func(b Breaker) error {
        return b.Do(req)
    })
}

func DoWithAcceptable(name string, req func() error, acceptable Acceptable) error {
    return do(name, func(b Breaker) error {
        return b.DoWithAcceptable(req, acceptable)
    })
}

func DoWithFallback(name string, req func() error, fallback func(err error) error) error {
    return do(name, func(b Breaker) error {
        return b.DoWithFallback(req, fallback)
    })
}

func DoWithFallbackAcceptable(name string, req func() error, fallback func(err error) error,
    acceptable Acceptable) error {
    return do(name, func(b Breaker) error {
        return b.DoWithFallbackAcceptable(req, fallback, acceptable)
    })
}

func GetBreaker(name string) Breaker {
    lock.RLock()
    b, ok := breakers[name]
    lock.RUnlock()
    if ok {
        return b
    }

    lock.Lock()
    defer lock.Unlock()

    b = NewBreaker()
    breakers[name] = b
    return b
}

func NoBreakFor(name string) {
    lock.Lock()
    breakers[name] = newNoOpBreaker()
    lock.Unlock()
}

func do(name string, execute func(b Breaker) error) error {
    lock.RLock()
    b, ok := breakers[name]
    lock.RUnlock()
    if ok {
        return execute(b)
    } else {
        lock.Lock()
        b, ok = breakers[name]
        if ok {
            lock.Unlock()
            return execute(b)
        } else {
            b = NewBreaker(WithName(name))
            breakers[name] = b
            lock.Unlock()
            return execute(b)
        }
    }
}
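The package-level helpers above keep one breaker per name in a process-wide registry (note the double-checked read/write locking in do). A short sketch of how that is typically used, not part of the commit; "payment-rpc" and callPayment are illustrative.

package main

import (
    "errors"
    "fmt"

    "zero/core/breaker"
)

// callPayment is a hypothetical downstream call used only for illustration.
func callPayment() error {
    return errors.New("connection refused")
}

func main() {
    // Every caller that uses the name "payment-rpc" shares one breaker instance from the
    // registry, so failures observed anywhere throttle all callers of that dependency.
    err := breaker.Do("payment-rpc", callPayment)
    if err == breaker.ErrServiceUnavailable {
        fmt.Println("payment-rpc breaker is open, request dropped")
    } else if err != nil {
        fmt.Println("call failed:", err)
    }
}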
core/breaker/breakers_test.go (new file, 115 lines)
@@ -0,0 +1,115 @@
package breaker

import (
    "errors"
    "fmt"
    "testing"

    "zero/core/stat"

    "github.com/stretchr/testify/assert"
)

func init() {
    stat.SetReporter(nil)
}

func TestBreakersDo(t *testing.T) {
    assert.Nil(t, Do("any", func() error {
        return nil
    }))

    errDummy := errors.New("any")
    assert.Equal(t, errDummy, Do("any", func() error {
        return errDummy
    }))
}

func TestBreakersDoWithAcceptable(t *testing.T) {
    errDummy := errors.New("anyone")
    for i := 0; i < 10000; i++ {
        assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error {
            return errDummy
        }, func(err error) bool {
            return err == nil || err == errDummy
        }))
    }
    verify(t, func() bool {
        return Do("anyone", func() error {
            return nil
        }) == nil
    })

    for i := 0; i < 10000; i++ {
        err := DoWithAcceptable("another", func() error {
            return errDummy
        }, func(err error) bool {
            return err == nil
        })
        assert.True(t, err == errDummy || err == ErrServiceUnavailable)
    }
    verify(t, func() bool {
        return ErrServiceUnavailable == Do("another", func() error {
            return nil
        })
    })
}

func TestBreakersNoBreakerFor(t *testing.T) {
    NoBreakFor("any")
    errDummy := errors.New("any")
    for i := 0; i < 10000; i++ {
        assert.Equal(t, errDummy, GetBreaker("any").Do(func() error {
            return errDummy
        }))
    }
    assert.Equal(t, nil, Do("any", func() error {
        return nil
    }))
}

func TestBreakersFallback(t *testing.T) {
    errDummy := errors.New("any")
    for i := 0; i < 10000; i++ {
        err := DoWithFallback("fallback", func() error {
            return errDummy
        }, func(err error) error {
            return nil
        })
        assert.True(t, err == nil || err == errDummy)
    }
    verify(t, func() bool {
        return ErrServiceUnavailable == Do("fallback", func() error {
            return nil
        })
    })
}

func TestBreakersAcceptableFallback(t *testing.T) {
    errDummy := errors.New("any")
    for i := 0; i < 10000; i++ {
        err := DoWithFallbackAcceptable("acceptablefallback", func() error {
            return errDummy
        }, func(err error) error {
            return nil
        }, func(err error) bool {
            return err == nil
        })
        assert.True(t, err == nil || err == errDummy)
    }
    verify(t, func() bool {
        return ErrServiceUnavailable == Do("acceptablefallback", func() error {
            return nil
        })
    })
}

func verify(t *testing.T, fn func() bool) {
    var count int
    for i := 0; i < 100; i++ {
        if fn() {
            count++
        }
    }
    assert.True(t, count >= 80, fmt.Sprintf("should be greater than 80, actual %d", count))
}
core/breaker/googlebreaker.go (new file, 125 lines)
@@ -0,0 +1,125 @@
package breaker

import (
    "math"
    "sync/atomic"
    "time"

    "zero/core/collection"
    "zero/core/mathx"
)

const (
    // 250ms for bucket duration
    window     = time.Second * 10
    buckets    = 40
    k          = 1.5
    protection = 5
)

// googleBreaker is a netflixBreaker pattern from google.
// see Client-Side Throttling section in https://landing.google.com/sre/sre-book/chapters/handling-overload/
type googleBreaker struct {
    k     float64
    state int32
    stat  *collection.RollingWindow
    proba *mathx.Proba
}

func newGoogleBreaker() *googleBreaker {
    bucketDuration := time.Duration(int64(window) / int64(buckets))
    st := collection.NewRollingWindow(buckets, bucketDuration)
    return &googleBreaker{
        stat:  st,
        k:     k,
        state: StateClosed,
        proba: mathx.NewProba(),
    }
}

func (b *googleBreaker) accept() error {
    accepts, total := b.history()
    weightedAccepts := b.k * float64(accepts)
    // https://landing.google.com/sre/sre-book/chapters/handling-overload/#eq2101
    dropRatio := math.Max(0, (float64(total-protection)-weightedAccepts)/float64(total+1))
    if dropRatio <= 0 {
        if atomic.LoadInt32(&b.state) == StateOpen {
            atomic.CompareAndSwapInt32(&b.state, StateOpen, StateClosed)
        }
        return nil
    }

    if atomic.LoadInt32(&b.state) == StateClosed {
        atomic.CompareAndSwapInt32(&b.state, StateClosed, StateOpen)
    }
    if b.proba.TrueOnProba(dropRatio) {
        return ErrServiceUnavailable
    }

    return nil
}

func (b *googleBreaker) allow() (internalPromise, error) {
    if err := b.accept(); err != nil {
        return nil, err
    }

    return googlePromise{
        b: b,
    }, nil
}

func (b *googleBreaker) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
    if err := b.accept(); err != nil {
        if fallback != nil {
            return fallback(err)
        } else {
            return err
        }
    }

    defer func() {
        if e := recover(); e != nil {
            b.markFailure()
            panic(e)
        }
    }()

    err := req()
    if acceptable(err) {
        b.markSuccess()
    } else {
        b.markFailure()
    }

    return err
}

func (b *googleBreaker) markSuccess() {
    b.stat.Add(1)
}

func (b *googleBreaker) markFailure() {
    b.stat.Add(0)
}

func (b *googleBreaker) history() (accepts int64, total int64) {
    b.stat.Reduce(func(b *collection.Bucket) {
        accepts += int64(b.Sum)
        total += b.Count
    })

    return
}

type googlePromise struct {
    b *googleBreaker
}

func (p googlePromise) Accept() {
    p.b.markSuccess()
}

func (p googlePromise) Reject() {
    p.b.markFailure()
}
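To make the client-side throttling formula in accept() concrete, the following self-contained sketch (not part of the commit) evaluates dropRatio = max(0, (total - protection - k*accepts) / (total + 1)) for a few made-up windows, using the same k = 1.5 and protection = 5 as the constants above.

package main

import (
    "fmt"
    "math"
)

// dropRatio mirrors the formula in googleBreaker.accept.
func dropRatio(accepts, total int64) float64 {
    const (
        k          = 1.5
        protection = 5
    )
    return math.Max(0, (float64(total-protection)-k*float64(accepts))/float64(total+1))
}

func main() {
    // Healthy window: every request accepted, nothing gets dropped.
    fmt.Printf("%.3f\n", dropRatio(100, 100)) // 0.000

    // Mostly failing window: 20 accepted out of 100, roughly 64% of new requests are dropped.
    fmt.Printf("%.3f\n", dropRatio(20, 100)) // (100-5-30)/101 ≈ 0.644

    // Total failure: accepts = 0, drop probability approaches 1.
    fmt.Printf("%.3f\n", dropRatio(0, 100)) // (100-5)/101 ≈ 0.941
}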
core/breaker/googlebreaker_test.go (new file, 238 lines)
@@ -0,0 +1,238 @@
package breaker

import (
    "errors"
    "math"
    "math/rand"
    "testing"
    "time"

    "zero/core/collection"
    "zero/core/mathx"
    "zero/core/stat"

    "github.com/stretchr/testify/assert"
)

const (
    testBuckets  = 10
    testInterval = time.Millisecond * 10
)

func init() {
    stat.SetReporter(nil)
}

func getGoogleBreaker() *googleBreaker {
    st := collection.NewRollingWindow(testBuckets, testInterval)
    return &googleBreaker{
        stat:  st,
        k:     5,
        state: StateClosed,
        proba: mathx.NewProba(),
    }
}

func markSuccessWithDuration(b *googleBreaker, count int, sleep time.Duration) {
    for i := 0; i < count; i++ {
        b.markSuccess()
        time.Sleep(sleep)
    }
}

func markFailedWithDuration(b *googleBreaker, count int, sleep time.Duration) {
    for i := 0; i < count; i++ {
        b.markFailure()
        time.Sleep(sleep)
    }
}

func TestGoogleBreakerClose(t *testing.T) {
    b := getGoogleBreaker()
    markSuccess(b, 80)
    assert.Nil(t, b.accept())
    markSuccess(b, 120)
    assert.Nil(t, b.accept())
}

func TestGoogleBreakerOpen(t *testing.T) {
    b := getGoogleBreaker()
    markSuccess(b, 10)
    assert.Nil(t, b.accept())
    markFailed(b, 100000)
    time.Sleep(testInterval * 2)
    verify(t, func() bool {
        return b.accept() != nil
    })
}

func TestGoogleBreakerFallback(t *testing.T) {
    b := getGoogleBreaker()
    markSuccess(b, 1)
    assert.Nil(t, b.accept())
    markFailed(b, 10000)
    time.Sleep(testInterval * 2)
    verify(t, func() bool {
        return b.doReq(func() error {
            return errors.New("any")
        }, func(err error) error {
            return nil
        }, defaultAcceptable) == nil
    })
}

func TestGoogleBreakerReject(t *testing.T) {
    b := getGoogleBreaker()
    markSuccess(b, 100)
    assert.Nil(t, b.accept())
    markFailed(b, 10000)
    time.Sleep(testInterval)
    assert.Equal(t, ErrServiceUnavailable, b.doReq(func() error {
        return ErrServiceUnavailable
    }, nil, defaultAcceptable))
}

func TestGoogleBreakerAcceptable(t *testing.T) {
    b := getGoogleBreaker()
    errAcceptable := errors.New("any")
    assert.Equal(t, errAcceptable, b.doReq(func() error {
        return errAcceptable
    }, nil, func(err error) bool {
        return err == errAcceptable
    }))
}

func TestGoogleBreakerNotAcceptable(t *testing.T) {
    b := getGoogleBreaker()
    errAcceptable := errors.New("any")
    assert.Equal(t, errAcceptable, b.doReq(func() error {
        return errAcceptable
    }, nil, func(err error) bool {
        return err != errAcceptable
    }))
}

func TestGoogleBreakerPanic(t *testing.T) {
    b := getGoogleBreaker()
    assert.Panics(t, func() {
        _ = b.doReq(func() error {
            panic("fail")
        }, nil, defaultAcceptable)
    })
}

func TestGoogleBreakerHalfOpen(t *testing.T) {
    b := getGoogleBreaker()
    assert.Nil(t, b.accept())
    t.Run("accept single failed/accept", func(t *testing.T) {
        markFailed(b, 10000)
        time.Sleep(testInterval * 2)
        verify(t, func() bool {
            return b.accept() != nil
        })
    })
    t.Run("accept single failed/allow", func(t *testing.T) {
        markFailed(b, 10000)
        time.Sleep(testInterval * 2)
        verify(t, func() bool {
            _, err := b.allow()
            return err != nil
        })
    })
    time.Sleep(testInterval * testBuckets)
    t.Run("accept single succeed", func(t *testing.T) {
        assert.Nil(t, b.accept())
        markSuccess(b, 10000)
        verify(t, func() bool {
            return b.accept() == nil
        })
    })
}

func TestGoogleBreakerSelfProtection(t *testing.T) {
    t.Run("total request < 100", func(t *testing.T) {
        b := getGoogleBreaker()
        markFailed(b, 4)
        time.Sleep(testInterval)
        assert.Nil(t, b.accept())
    })
    t.Run("total request > 100, total < 2 * success", func(t *testing.T) {
        b := getGoogleBreaker()
        size := rand.Intn(10000)
        accepts := int(math.Ceil(float64(size))) + 1
        markSuccess(b, accepts)
        markFailed(b, size-accepts)
        assert.Nil(t, b.accept())
    })
}

func TestGoogleBreakerHistory(t *testing.T) {
    var b *googleBreaker
    var accepts, total int64

    sleep := testInterval
    t.Run("accepts == total", func(t *testing.T) {
        b = getGoogleBreaker()
        markSuccessWithDuration(b, 10, sleep/2)
        accepts, total = b.history()
        assert.Equal(t, int64(10), accepts)
        assert.Equal(t, int64(10), total)
    })

    t.Run("fail == total", func(t *testing.T) {
        b = getGoogleBreaker()
        markFailedWithDuration(b, 10, sleep/2)
        accepts, total = b.history()
        assert.Equal(t, int64(0), accepts)
        assert.Equal(t, int64(10), total)
    })

    t.Run("accepts = 1/2 * total, fail = 1/2 * total", func(t *testing.T) {
        b = getGoogleBreaker()
        markFailedWithDuration(b, 5, sleep/2)
        markSuccessWithDuration(b, 5, sleep/2)
        accepts, total = b.history()
        assert.Equal(t, int64(5), accepts)
        assert.Equal(t, int64(10), total)
    })

    t.Run("auto reset rolling counter", func(t *testing.T) {
        b = getGoogleBreaker()
        time.Sleep(testInterval * testBuckets)
        accepts, total = b.history()
        assert.Equal(t, int64(0), accepts)
        assert.Equal(t, int64(0), total)
    })
}

func BenchmarkGoogleBreakerAllow(b *testing.B) {
    breaker := getGoogleBreaker()
    b.ResetTimer()
    for i := 0; i <= b.N; i++ {
        breaker.accept()
        if i%2 == 0 {
            breaker.markSuccess()
        } else {
            breaker.markFailure()
        }
    }
}

func markSuccess(b *googleBreaker, count int) {
    for i := 0; i < count; i++ {
        p, err := b.allow()
        if err != nil {
            break
        }
        p.Accept()
    }
}

func markFailed(b *googleBreaker, count int) {
    for i := 0; i < count; i++ {
        p, err := b.allow()
        if err == nil {
            p.Reject()
        }
    }
}
core/breaker/nopbreaker.go (new file, 42 lines)
@@ -0,0 +1,42 @@
package breaker

const noOpBreakerName = "nopBreaker"

type noOpBreaker struct{}

func newNoOpBreaker() Breaker {
    return noOpBreaker{}
}

func (b noOpBreaker) Name() string {
    return noOpBreakerName
}

func (b noOpBreaker) Allow() (Promise, error) {
    return nopPromise{}, nil
}

func (b noOpBreaker) Do(req func() error) error {
    return req()
}

func (b noOpBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error {
    return req()
}

func (b noOpBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
    return req()
}

func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
    acceptable Acceptable) error {
    return req()
}

type nopPromise struct{}

func (p nopPromise) Accept() {
}

func (p nopPromise) Reject(reason string) {
}
core/breaker/nopbreaker_test.go (new file, 38 lines)
@@ -0,0 +1,38 @@
package breaker

import (
    "errors"
    "testing"

    "github.com/stretchr/testify/assert"
)

func TestNopBreaker(t *testing.T) {
    b := newNoOpBreaker()
    assert.Equal(t, noOpBreakerName, b.Name())
    p, err := b.Allow()
    assert.Nil(t, err)
    p.Accept()
    for i := 0; i < 1000; i++ {
        p, err := b.Allow()
        assert.Nil(t, err)
        p.Reject("any")
    }
    assert.Nil(t, b.Do(func() error {
        return nil
    }))
    assert.Nil(t, b.DoWithAcceptable(func() error {
        return nil
    }, defaultAcceptable))
    errDummy := errors.New("any")
    assert.Equal(t, errDummy, b.DoWithFallback(func() error {
        return errDummy
    }, func(err error) error {
        return nil
    }))
    assert.Equal(t, errDummy, b.DoWithFallbackAcceptable(func() error {
        return errDummy
    }, func(err error) error {
        return nil
    }, defaultAcceptable))
}
core/cmdline/input.go (new file, 19 lines)
@@ -0,0 +1,19 @@
package cmdline

import (
    "bufio"
    "fmt"
    "os"
    "strings"
)

func EnterToContinue() {
    fmt.Print("Press 'Enter' to continue...")
    bufio.NewReader(os.Stdin).ReadBytes('\n')
}

func ReadLine(prompt string) string {
    fmt.Print(prompt)
    input, _ := bufio.NewReader(os.Stdin).ReadString('\n')
    return strings.TrimSpace(input)
}
core/codec/aesecb.go (new file, 174 lines)
@@ -0,0 +1,174 @@
package codec

import (
    "bytes"
    "crypto/aes"
    "crypto/cipher"
    "encoding/base64"
    "errors"

    "zero/core/logx"
)

var ErrPaddingSize = errors.New("padding size error")

type ecb struct {
    b         cipher.Block
    blockSize int
}

func newECB(b cipher.Block) *ecb {
    return &ecb{
        b:         b,
        blockSize: b.BlockSize(),
    }
}

type ecbEncrypter ecb

func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
    return (*ecbEncrypter)(newECB(b))
}

func (x *ecbEncrypter) BlockSize() int { return x.blockSize }

// we don't return an error here because cipher.BlockMode doesn't allow it
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
    if len(src)%x.blockSize != 0 {
        logx.Error("crypto/cipher: input not full blocks")
        return
    }
    if len(dst) < len(src) {
        logx.Error("crypto/cipher: output smaller than input")
        return
    }

    for len(src) > 0 {
        x.b.Encrypt(dst, src[:x.blockSize])
        src = src[x.blockSize:]
        dst = dst[x.blockSize:]
    }
}

type ecbDecrypter ecb

func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
    return (*ecbDecrypter)(newECB(b))
}

func (x *ecbDecrypter) BlockSize() int {
    return x.blockSize
}

// we don't return an error here because cipher.BlockMode doesn't allow it
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
    if len(src)%x.blockSize != 0 {
        logx.Error("crypto/cipher: input not full blocks")
        return
    }
    if len(dst) < len(src) {
        logx.Error("crypto/cipher: output smaller than input")
        return
    }

    for len(src) > 0 {
        x.b.Decrypt(dst, src[:x.blockSize])
        src = src[x.blockSize:]
        dst = dst[x.blockSize:]
    }
}

func EcbDecrypt(key, src []byte) ([]byte, error) {
    block, err := aes.NewCipher(key)
    if err != nil {
        logx.Errorf("Decrypt key error: % x", key)
        return nil, err
    }

    decrypter := NewECBDecrypter(block)
    decrypted := make([]byte, len(src))
    decrypter.CryptBlocks(decrypted, src)

    return pkcs5Unpadding(decrypted, decrypter.BlockSize())
}

func EcbDecryptBase64(key, src string) (string, error) {
    keyBytes, err := getKeyBytes(key)
    if err != nil {
        return "", err
    }

    encryptedBytes, err := base64.StdEncoding.DecodeString(src)
    if err != nil {
        return "", err
    }

    decryptedBytes, err := EcbDecrypt(keyBytes, encryptedBytes)
    if err != nil {
        return "", err
    }

    return base64.StdEncoding.EncodeToString(decryptedBytes), nil
}

func EcbEncrypt(key, src []byte) ([]byte, error) {
    block, err := aes.NewCipher(key)
    if err != nil {
        logx.Errorf("Encrypt key error: % x", key)
        return nil, err
    }

    padded := pkcs5Padding(src, block.BlockSize())
    crypted := make([]byte, len(padded))
    encrypter := NewECBEncrypter(block)
    encrypter.CryptBlocks(crypted, padded)

    return crypted, nil
}

func EcbEncryptBase64(key, src string) (string, error) {
    keyBytes, err := getKeyBytes(key)
    if err != nil {
        return "", err
    }

    srcBytes, err := base64.StdEncoding.DecodeString(src)
    if err != nil {
        return "", err
    }

    encryptedBytes, err := EcbEncrypt(keyBytes, srcBytes)
    if err != nil {
        return "", err
    }

    return base64.StdEncoding.EncodeToString(encryptedBytes), nil
}

func getKeyBytes(key string) ([]byte, error) {
    if len(key) > 32 {
        if keyBytes, err := base64.StdEncoding.DecodeString(key); err != nil {
            return nil, err
        } else {
            return keyBytes, nil
        }
    }

    return []byte(key), nil
}

func pkcs5Padding(ciphertext []byte, blockSize int) []byte {
    padding := blockSize - len(ciphertext)%blockSize
    padtext := bytes.Repeat([]byte{byte(padding)}, padding)
    return append(ciphertext, padtext...)
}

func pkcs5Unpadding(src []byte, blockSize int) ([]byte, error) {
    length := len(src)
    unpadding := int(src[length-1])
    if unpadding >= length || unpadding > blockSize {
        return nil, ErrPaddingSize
    }

    return src[:length-unpadding], nil
}
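A minimal round-trip sketch for the ECB helpers above, not part of the commit; the 32-byte key is a throwaway example (AES-256). EcbEncrypt pads the input with PKCS#5 to a multiple of 16 bytes, and EcbDecrypt strips that padding again.

package main

import (
    "fmt"

    "zero/core/codec"
)

func main() {
    // 32 ASCII bytes give AES-256; the key below is purely illustrative.
    key := []byte("0123456789abcdef0123456789abcdef")
    plaintext := []byte("hello, go-zero")

    encrypted, err := codec.EcbEncrypt(key, plaintext)
    if err != nil {
        panic(err)
    }

    decrypted, err := codec.EcbDecrypt(key, encrypted)
    if err != nil {
        panic(err)
    }

    fmt.Println(string(decrypted)) // hello, go-zero
}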
core/codec/dh.go (new file, 88 lines)
@@ -0,0 +1,88 @@
package codec

import (
    "crypto/rand"
    "errors"
    "math/big"
)

// see https://www.zhihu.com/question/29383090/answer/70435297
// see https://www.ietf.org/rfc/rfc3526.txt
// 2048-bit MODP Group

var (
    ErrInvalidPriKey    = errors.New("invalid private key")
    ErrInvalidPubKey    = errors.New("invalid public key")
    ErrPubKeyOutOfBound = errors.New("public key out of bound")

    p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
    g, _ = new(big.Int).SetString("2", 16)
    zero = big.NewInt(0)
)

type DhKey struct {
    PriKey *big.Int
    PubKey *big.Int
}

func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) {
    if pubKey == nil {
        return nil, ErrInvalidPubKey
    }

    if pubKey.Sign() <= 0 && p.Cmp(pubKey) <= 0 {
        return nil, ErrPubKeyOutOfBound
    }

    if priKey == nil {
        return nil, ErrInvalidPriKey
    }

    return new(big.Int).Exp(pubKey, priKey, p), nil
}

func GenerateKey() (*DhKey, error) {
    var err error
    var x *big.Int

    for {
        x, err = rand.Int(rand.Reader, p)
        if err != nil {
            return nil, err
        }

        if zero.Cmp(x) < 0 {
            break
        }
    }

    key := new(DhKey)
    key.PriKey = x
    key.PubKey = new(big.Int).Exp(g, x, p)

    return key, nil
}

func NewPublicKey(bs []byte) *big.Int {
    return new(big.Int).SetBytes(bs)
}

func (k *DhKey) Bytes() []byte {
    if k.PubKey == nil {
        return nil
    }

    byteLen := (p.BitLen() + 7) >> 3
    ret := make([]byte, byteLen)
    copyWithLeftPad(ret, k.PubKey.Bytes())

    return ret
}

func copyWithLeftPad(dst, src []byte) {
    padBytes := len(dst) - len(src)
    for i := 0; i < padBytes; i++ {
        dst[i] = 0
    }
    copy(dst[padBytes:], src)
}
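The exchange underneath: both sides compute g^(ab) mod p, so ComputeKey(peer public, own private) yields the same secret on either end. A small sketch (not part of the commit) that pairs it with the AES helpers from aesecb.go; taking the leading 32 bytes of the shared secret as the AES-256 key mirrors what the tests below do.

package main

import (
    "fmt"
    "math/big"

    "zero/core/codec"
)

func mustKey() *codec.DhKey {
    key, err := codec.GenerateKey()
    if err != nil {
        panic(err)
    }
    return key
}

func mustCompute(pub, pri *big.Int) *big.Int {
    secret, err := codec.ComputeKey(pub, pri)
    if err != nil {
        panic(err)
    }
    return secret
}

func main() {
    alice, bob := mustKey(), mustKey()

    // Each side raises the peer's public key to its own private exponent:
    // (g^a)^b mod p == (g^b)^a mod p, so both secrets match.
    aliceSecret := mustCompute(bob.PubKey, alice.PriKey)
    bobSecret := mustCompute(alice.PubKey, bob.PriKey)
    fmt.Println(aliceSecret.Cmp(bobSecret) == 0) // true

    // Use the leading 32 bytes of the shared secret as an AES-256 key.
    msg, _ := codec.EcbEncrypt(aliceSecret.Bytes()[:32], []byte("ping"))
    out, _ := codec.EcbDecrypt(bobSecret.Bytes()[:32], msg)
    fmt.Println(string(out)) // ping
}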
core/codec/dh_test.go (new file, 73 lines)
@@ -0,0 +1,73 @@
package codec

import (
    "math/big"
    "testing"

    "github.com/stretchr/testify/assert"
)

func TestDiffieHellman(t *testing.T) {
    key1, err := GenerateKey()
    assert.Nil(t, err)
    key2, err := GenerateKey()
    assert.Nil(t, err)

    pubKey1, err := ComputeKey(key1.PubKey, key2.PriKey)
    assert.Nil(t, err)
    pubKey2, err := ComputeKey(key2.PubKey, key1.PriKey)
    assert.Nil(t, err)

    assert.Equal(t, pubKey1, pubKey2)
}

func TestDiffieHellman1024(t *testing.T) {
    old := p
    p, _ = new(big.Int).SetString("F488FD584E49DBCD20B49DE49107366B336C380D451D0F7C88B31C7C5B2D8EF6F3C923C043F0A55B188D8EBB558CB85D38D334FD7C175743A31D186CDE33212CB52AFF3CE1B1294018118D7C84A70A72D686C40319C807297ACA950CD9969FABD00A509B0246D3083D66A45D419F9C7CBD894B221926BAABA25EC355E92F78C7", 16)
    defer func() {
        p = old
    }()

    key1, err := GenerateKey()
    assert.Nil(t, err)
    key2, err := GenerateKey()
    assert.Nil(t, err)

    pubKey1, err := ComputeKey(key1.PubKey, key2.PriKey)
    assert.Nil(t, err)
    pubKey2, err := ComputeKey(key2.PubKey, key1.PriKey)
    assert.Nil(t, err)

    assert.Equal(t, pubKey1, pubKey2)
}

func TestDiffieHellmanMiddleManAttack(t *testing.T) {
    key1, err := GenerateKey()
    assert.Nil(t, err)
    keyMiddle, err := GenerateKey()
    assert.Nil(t, err)
    key2, err := GenerateKey()
    assert.Nil(t, err)

    const aesByteLen = 32
    pubKey1, err := ComputeKey(keyMiddle.PubKey, key1.PriKey)
    assert.Nil(t, err)
    src := []byte(`hello, world!`)
    encryptedSrc, err := EcbEncrypt(pubKey1.Bytes()[:aesByteLen], src)
    assert.Nil(t, err)
    pubKeyMiddle, err := ComputeKey(key1.PubKey, keyMiddle.PriKey)
    assert.Nil(t, err)
    decryptedSrc, err := EcbDecrypt(pubKeyMiddle.Bytes()[:aesByteLen], encryptedSrc)
    assert.Nil(t, err)
    assert.Equal(t, string(src), string(decryptedSrc))

    pubKeyMiddle, err = ComputeKey(key2.PubKey, keyMiddle.PriKey)
    assert.Nil(t, err)
    encryptedSrc, err = EcbEncrypt(pubKeyMiddle.Bytes()[:aesByteLen], decryptedSrc)
    assert.Nil(t, err)
    pubKey2, err := ComputeKey(keyMiddle.PubKey, key2.PriKey)
    assert.Nil(t, err)
    decryptedSrc, err = EcbDecrypt(pubKey2.Bytes()[:aesByteLen], encryptedSrc)
    assert.Nil(t, err)
    assert.Equal(t, string(src), string(decryptedSrc))
}
core/codec/gzip.go (new file, 33 lines)
@@ -0,0 +1,33 @@
package codec

import (
    "bytes"
    "compress/gzip"
    "io"
)

func Gzip(bs []byte) []byte {
    var b bytes.Buffer

    w := gzip.NewWriter(&b)
    w.Write(bs)
    w.Close()

    return b.Bytes()
}

func Gunzip(bs []byte) ([]byte, error) {
    r, err := gzip.NewReader(bytes.NewBuffer(bs))
    if err != nil {
        return nil, err
    }
    defer r.Close()

    var c bytes.Buffer
    _, err = io.Copy(&c, r)
    if err != nil {
        return nil, err
    }

    return c.Bytes(), nil
}
core/codec/gzip_test.go (new file, 23 lines)
@@ -0,0 +1,23 @@
package codec

import (
    "bytes"
    "fmt"
    "testing"

    "github.com/stretchr/testify/assert"
)

func TestGzip(t *testing.T) {
    var buf bytes.Buffer
    for i := 0; i < 10000; i++ {
        fmt.Fprint(&buf, i)
    }

    bs := Gzip(buf.Bytes())
    actual, err := Gunzip(bs)

    assert.Nil(t, err)
    assert.True(t, len(bs) < buf.Len())
    assert.Equal(t, buf.Bytes(), actual)
}
core/codec/hmac.go (new file, 18 lines)
@@ -0,0 +1,18 @@
package codec

import (
    "crypto/hmac"
    "crypto/sha256"
    "encoding/base64"
    "io"
)

func Hmac(key []byte, body string) []byte {
    h := hmac.New(sha256.New, key)
    io.WriteString(h, body)
    return h.Sum(nil)
}

func HmacBase64(key []byte, body string) string {
    return base64.StdEncoding.EncodeToString(Hmac(key, body))
}
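A short sketch (not part of the commit) of how these helpers could sign and verify a request body; the shared secret and header name are illustrative, and comparison should go through hmac.Equal rather than a plain string compare.

package main

import (
    "crypto/hmac"
    "fmt"

    "zero/core/codec"
)

func main() {
    secret := []byte("shared-secret") // illustrative only
    body := `{"order_id":42}`

    // Sender: attach the base64 SHA-256 HMAC, e.g. as an X-Signature header.
    signature := codec.HmacBase64(secret, body)
    fmt.Println("signature:", signature)

    // Receiver: recompute and compare in constant time.
    expected := codec.Hmac(secret, body)
    actual := codec.Hmac(secret, body) // in practice, decode the received header instead
    fmt.Println("valid:", hmac.Equal(expected, actual))
}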
core/codec/rsa.go (new file, 149 lines)
@@ -0,0 +1,149 @@
package codec

import (
    "crypto/rand"
    "crypto/rsa"
    "crypto/x509"
    "encoding/base64"
    "encoding/pem"
    "errors"
    "io/ioutil"
)

var (
    ErrPrivateKey = errors.New("private key error")
    ErrPublicKey  = errors.New("failed to parse PEM block containing the public key")
    ErrNotRsaKey  = errors.New("key type is not RSA")
)

type (
    RsaDecrypter interface {
        Decrypt(input []byte) ([]byte, error)
        DecryptBase64(input string) ([]byte, error)
    }

    RsaEncrypter interface {
        Encrypt(input []byte) ([]byte, error)
    }

    rsaBase struct {
        bytesLimit int
    }

    rsaDecrypter struct {
        rsaBase
        privateKey *rsa.PrivateKey
    }

    rsaEncrypter struct {
        rsaBase
        publicKey *rsa.PublicKey
    }
)

func NewRsaDecrypter(file string) (RsaDecrypter, error) {
    content, err := ioutil.ReadFile(file)
    if err != nil {
        return nil, err
    }

    block, _ := pem.Decode(content)
    if block == nil {
        return nil, ErrPrivateKey
    }

    privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
    if err != nil {
        return nil, err
    }

    return &rsaDecrypter{
        rsaBase: rsaBase{
            bytesLimit: privateKey.N.BitLen() >> 3,
        },
        privateKey: privateKey,
    }, nil
}

func (r *rsaDecrypter) Decrypt(input []byte) ([]byte, error) {
    return r.crypt(input, func(block []byte) ([]byte, error) {
        return rsaDecryptBlock(r.privateKey, block)
    })
}

func (r *rsaDecrypter) DecryptBase64(input string) ([]byte, error) {
    if len(input) == 0 {
        return nil, nil
    }

    base64Decoded, err := base64.StdEncoding.DecodeString(input)
    if err != nil {
        return nil, err
    }

    return r.Decrypt(base64Decoded)
}

func NewRsaEncrypter(key []byte) (RsaEncrypter, error) {
    block, _ := pem.Decode(key)
    if block == nil {
        return nil, ErrPublicKey
    }

    pub, err := x509.ParsePKIXPublicKey(block.Bytes)
    if err != nil {
        return nil, err
    }

    switch pubKey := pub.(type) {
    case *rsa.PublicKey:
        return &rsaEncrypter{
            rsaBase: rsaBase{
                // https://www.ietf.org/rfc/rfc2313.txt
                // The length of the data D shall not be more than k-11 octets, which is
                // positive since the length k of the modulus is at least 12 octets.
                bytesLimit: (pubKey.N.BitLen() >> 3) - 11,
            },
            publicKey: pubKey,
        }, nil
    default:
        return nil, ErrNotRsaKey
    }
}

func (r *rsaEncrypter) Encrypt(input []byte) ([]byte, error) {
    return r.crypt(input, func(block []byte) ([]byte, error) {
        return rsaEncryptBlock(r.publicKey, block)
    })
}

func (r *rsaBase) crypt(input []byte, cryptFn func([]byte) ([]byte, error)) ([]byte, error) {
    var result []byte
    inputLen := len(input)

    for i := 0; i*r.bytesLimit < inputLen; i++ {
        start := r.bytesLimit * i
        var stop int
        if r.bytesLimit*(i+1) > inputLen {
            stop = inputLen
        } else {
            stop = r.bytesLimit * (i + 1)
        }
        bs, err := cryptFn(input[start:stop])
        if err != nil {
            return nil, err
        }

        result = append(result, bs...)
    }

    return result, nil
}

func rsaDecryptBlock(privateKey *rsa.PrivateKey, block []byte) ([]byte, error) {
    return rsa.DecryptPKCS1v15(rand.Reader, privateKey, block)
}

func rsaEncryptBlock(publicKey *rsa.PublicKey, msg []byte) ([]byte, error) {
    return rsa.EncryptPKCS1v15(rand.Reader, publicKey, msg)
}
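The RFC 2313 comment above is why the encrypter chunks its input: a PKCS#1 v1.5 block carries at most k-11 bytes of data, where k is the modulus size in octets. A quick arithmetic sketch (not part of the commit):

package main

import "fmt"

// bytesLimit mirrors rsaEncrypter: (modulus bits >> 3) - 11, per RFC 2313 (k - 11 octets).
func bytesLimit(modulusBits int) int {
    return (modulusBits >> 3) - 11
}

func main() {
    fmt.Println(bytesLimit(2048)) // 245: a 2048-bit key encrypts at most 245 bytes per block
    fmt.Println(bytesLimit(1024)) // 117

    // So crypt() splits a 600-byte input under a 2048-bit key into 3 blocks: 245 + 245 + 110.
    input, limit := 600, bytesLimit(2048)
    fmt.Println((input + limit - 1) / limit) // 3
}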
275
core/collection/cache.go
Normal file
275
core/collection/cache.go
Normal file
@ -0,0 +1,275 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/mathx"
|
||||
"zero/core/syncx"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCacheName = "proc"
|
||||
slots = 300
|
||||
statInterval = time.Minute
|
||||
// make the expiry unstable to avoid lots of cached items expire at the same time
|
||||
// make the unstable expiry to be [0.95, 1.05] * seconds
|
||||
expiryDeviation = 0.05
|
||||
)
|
||||
|
||||
var emptyLruCache = emptyLru{}
|
||||
|
||||
type (
|
||||
CacheOption func(cache *Cache)
|
||||
|
||||
Cache struct {
|
||||
name string
|
||||
lock sync.Mutex
|
||||
data map[string]interface{}
|
||||
evicts *list.List
|
||||
expire time.Duration
|
||||
timingWheel *TimingWheel
|
||||
lruCache lru
|
||||
barrier syncx.SharedCalls
|
||||
unstableExpiry mathx.Unstable
|
||||
stats *cacheStat
|
||||
}
|
||||
)
|
||||
|
||||
func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
|
||||
cache := &Cache{
|
||||
data: make(map[string]interface{}),
|
||||
expire: expire,
|
||||
lruCache: emptyLruCache,
|
||||
barrier: syncx.NewSharedCalls(),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(cache)
|
||||
}
|
||||
|
||||
if len(cache.name) == 0 {
|
||||
cache.name = defaultCacheName
|
||||
}
|
||||
cache.stats = newCacheStat(cache.name, cache.size)
|
||||
|
||||
timingWheel, err := NewTimingWheel(time.Second, slots, func(k, v interface{}) {
|
||||
key, ok := k.(string)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
cache.Del(key)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache.timingWheel = timingWheel
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
func (c *Cache) Del(key string) {
|
||||
c.lock.Lock()
|
||||
delete(c.data, key)
|
||||
c.lruCache.remove(key)
|
||||
c.lock.Unlock()
|
||||
c.timingWheel.RemoveTimer(key)
|
||||
}
|
||||
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.lock.Lock()
|
||||
value, ok := c.data[key]
|
||||
if ok {
|
||||
c.lruCache.add(key)
|
||||
}
|
||||
c.lock.Unlock()
|
||||
if ok {
|
||||
c.stats.IncrementHit()
|
||||
} else {
|
||||
c.stats.IncrementMiss()
|
||||
}
|
||||
|
||||
return value, ok
|
||||
}
|
||||
|
||||
func (c *Cache) Set(key string, value interface{}) {
|
||||
c.lock.Lock()
|
||||
_, ok := c.data[key]
|
||||
c.data[key] = value
|
||||
c.lruCache.add(key)
|
||||
c.lock.Unlock()
|
||||
|
||||
expiry := c.unstableExpiry.AroundDuration(c.expire)
|
||||
if ok {
|
||||
c.timingWheel.MoveTimer(key, expiry)
|
||||
} else {
|
||||
c.timingWheel.SetTimer(key, value, expiry)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}, error) {
|
||||
val, fresh, err := c.barrier.DoEx(key, func() (interface{}, error) {
|
||||
v, e := fetch()
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
c.Set(key, v)
|
||||
return v, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if fresh {
|
||||
c.stats.IncrementMiss()
|
||||
return val, nil
|
||||
} else {
|
||||
// got the result from previous ongoing query
|
||||
c.stats.IncrementHit()
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *Cache) onEvict(key string) {
|
||||
// already locked
|
||||
delete(c.data, key)
|
||||
c.timingWheel.RemoveTimer(key)
|
||||
}
|
||||
|
||||
func (c *Cache) size() int {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
return len(c.data)
|
||||
}
|
||||
|
||||
func WithLimit(limit int) CacheOption {
|
||||
return func(cache *Cache) {
|
||||
if limit > 0 {
|
||||
cache.lruCache = newKeyLru(limit, cache.onEvict)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithName(name string) CacheOption {
|
||||
return func(cache *Cache) {
|
||||
cache.name = name
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
lru interface {
|
||||
add(key string)
|
||||
remove(key string)
|
||||
}
|
||||
|
||||
emptyLru struct{}
|
||||
|
||||
keyLru struct {
|
||||
limit int
|
||||
evicts *list.List
|
||||
elements map[string]*list.Element
|
||||
onEvict func(key string)
|
||||
}
|
||||
)
|
||||
|
||||
func (elru emptyLru) add(string) {
|
||||
}
|
||||
|
||||
func (elru emptyLru) remove(string) {
|
||||
}
|
||||
|
||||
func newKeyLru(limit int, onEvict func(key string)) *keyLru {
|
||||
return &keyLru{
|
||||
limit: limit,
|
||||
evicts: list.New(),
|
||||
elements: make(map[string]*list.Element),
|
||||
onEvict: onEvict,
|
||||
}
|
||||
}
|
||||
|
||||
func (klru *keyLru) add(key string) {
|
||||
if elem, ok := klru.elements[key]; ok {
|
||||
klru.evicts.MoveToFront(elem)
|
||||
return
|
||||
}
|
||||
|
||||
// Add new item
|
||||
elem := klru.evicts.PushFront(key)
|
||||
klru.elements[key] = elem
|
||||
|
||||
// Verify size not exceeded
|
||||
if klru.evicts.Len() > klru.limit {
|
||||
klru.removeOldest()
|
||||
}
|
||||
}
|
||||
|
||||
func (klru *keyLru) remove(key string) {
|
||||
if elem, ok := klru.elements[key]; ok {
|
||||
klru.removeElement(elem)
|
||||
}
|
||||
}
|
||||
|
||||
func (klru *keyLru) removeOldest() {
|
||||
elem := klru.evicts.Back()
|
||||
if elem != nil {
|
||||
klru.removeElement(elem)
|
||||
}
|
||||
}
|
||||
|
||||
func (klru *keyLru) removeElement(e *list.Element) {
|
||||
klru.evicts.Remove(e)
|
||||
key := e.Value.(string)
|
||||
delete(klru.elements, key)
|
||||
klru.onEvict(key)
|
||||
}
|
||||
|
||||
type cacheStat struct {
|
||||
name string
|
||||
hit uint64
|
||||
miss uint64
|
||||
sizeCallback func() int
|
||||
}
|
||||
|
||||
func newCacheStat(name string, sizeCallback func() int) *cacheStat {
|
||||
st := &cacheStat{
|
||||
name: name,
|
||||
sizeCallback: sizeCallback,
|
||||
}
|
||||
go st.statLoop()
|
||||
return st
|
||||
}
|
||||
|
||||
func (cs *cacheStat) IncrementHit() {
|
||||
atomic.AddUint64(&cs.hit, 1)
|
||||
}
|
||||
|
||||
func (cs *cacheStat) IncrementMiss() {
|
||||
atomic.AddUint64(&cs.miss, 1)
|
||||
}
|
||||
|
||||
func (cs *cacheStat) statLoop() {
|
||||
ticker := time.NewTicker(statInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
hit := atomic.SwapUint64(&cs.hit, 0)
|
||||
miss := atomic.SwapUint64(&cs.miss, 0)
|
||||
total := hit + miss
|
||||
if total == 0 {
|
||||
continue
|
||||
}
|
||||
percent := 100 * float32(hit) / float32(total)
|
||||
logx.Statf("cache(%s) - qpm: %d, hit_ratio: %.1f%%, elements: %d, hit: %d, miss: %d",
|
||||
cs.name, total, percent, cs.sizeCallback(), hit, miss)
|
||||
}
|
||||
}
|
||||
}
|
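For orientation, a minimal usage sketch of the Cache defined above (not part of the original file; the zero/core/collection import path follows the intra-repo imports used elsewhere in this commit, and keys/values are illustrative):

package main

import (
	"fmt"
	"time"

	"zero/core/collection"
)

func main() {
	// entries expire roughly one minute after being set; the LRU option caps the key count
	cache, err := collection.NewCache(time.Minute, collection.WithLimit(10000), collection.WithName("users"))
	if err != nil {
		panic(err)
	}

	cache.Set("id:1", "alice")
	if v, ok := cache.Get("id:1"); ok {
		fmt.Println(v)
	}

	// Take runs fetch once per key even under concurrent callers, via the shared barrier
	v, err := cache.Take("id:2", func() (interface{}, error) {
		return "bob", nil // e.g. load from a database
	})
	fmt.Println(v, err)
}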
139
core/collection/cache_test.go
Normal file
@ -0,0 +1,139 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCacheSet(t *testing.T) {
|
||||
cache, err := NewCache(time.Second*2, WithName("any"))
|
||||
assert.Nil(t, err)
|
||||
|
||||
cache.Set("first", "first element")
|
||||
cache.Set("second", "second element")
|
||||
|
||||
value, ok := cache.Get("first")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "first element", value)
|
||||
value, ok = cache.Get("second")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "second element", value)
|
||||
}
|
||||
|
||||
func TestCacheDel(t *testing.T) {
|
||||
cache, err := NewCache(time.Second * 2)
|
||||
assert.Nil(t, err)
|
||||
|
||||
cache.Set("first", "first element")
|
||||
cache.Set("second", "second element")
|
||||
cache.Del("first")
|
||||
|
||||
_, ok := cache.Get("first")
|
||||
assert.False(t, ok)
|
||||
value, ok := cache.Get("second")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "second element", value)
|
||||
}
|
||||
|
||||
func TestCacheTake(t *testing.T) {
|
||||
cache, err := NewCache(time.Second * 2)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var count int32
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
cache.Take("first", func() (interface{}, error) {
|
||||
atomic.AddInt32(&count, 1)
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
return "first element", nil
|
||||
})
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, 1, cache.size())
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
|
||||
}
|
||||
|
||||
func TestCacheWithLruEvicts(t *testing.T) {
|
||||
cache, err := NewCache(time.Minute, WithLimit(3))
|
||||
assert.Nil(t, err)
|
||||
|
||||
cache.Set("first", "first element")
|
||||
cache.Set("second", "second element")
|
||||
cache.Set("third", "third element")
|
||||
cache.Set("fourth", "fourth element")
|
||||
|
||||
value, ok := cache.Get("first")
|
||||
assert.False(t, ok)
|
||||
value, ok = cache.Get("second")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "second element", value)
|
||||
value, ok = cache.Get("third")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "third element", value)
|
||||
value, ok = cache.Get("fourth")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "fourth element", value)
|
||||
}
|
||||
|
||||
func TestCacheWithLruEvicted(t *testing.T) {
|
||||
cache, err := NewCache(time.Minute, WithLimit(3))
|
||||
assert.Nil(t, err)
|
||||
|
||||
cache.Set("first", "first element")
|
||||
cache.Set("second", "second element")
|
||||
cache.Set("third", "third element")
|
||||
cache.Set("fourth", "fourth element")
|
||||
|
||||
value, ok := cache.Get("first")
|
||||
assert.False(t, ok)
|
||||
value, ok = cache.Get("second")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "second element", value)
|
||||
cache.Set("fifth", "fifth element")
|
||||
cache.Set("sixth", "sixth element")
|
||||
_, ok = cache.Get("third")
|
||||
assert.False(t, ok)
|
||||
_, ok = cache.Get("fourth")
|
||||
assert.False(t, ok)
|
||||
value, ok = cache.Get("second")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "second element", value)
|
||||
}
|
||||
|
||||
func BenchmarkCache(b *testing.B) {
|
||||
cache, err := NewCache(time.Second*5, WithLimit(100000))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 10000; i++ {
|
||||
for j := 0; j < 10; j++ {
|
||||
index := strconv.Itoa(i*10000 + j)
|
||||
cache.Set("key:"+index, "value:"+index)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
for i := 0; i < b.N; i++ {
|
||||
index := strconv.Itoa(i % 10000)
|
||||
cache.Get("key:" + index)
|
||||
if i%100 == 0 {
|
||||
cache.Set("key1:"+index, "value1:"+index)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
60
core/collection/fifo.go
Normal file
@ -0,0 +1,60 @@
|
||||
package collection
|
||||
|
||||
import "sync"
|
||||
|
||||
type Queue struct {
|
||||
lock sync.Mutex
|
||||
elements []interface{}
|
||||
size int
|
||||
head int
|
||||
tail int
|
||||
count int
|
||||
}
|
||||
|
||||
func NewQueue(size int) *Queue {
|
||||
return &Queue{
|
||||
elements: make([]interface{}, size),
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Queue) Empty() bool {
|
||||
q.lock.Lock()
|
||||
empty := q.count == 0
|
||||
q.lock.Unlock()
|
||||
|
||||
return empty
|
||||
}
|
||||
|
||||
func (q *Queue) Put(element interface{}) {
|
||||
q.lock.Lock()
|
||||
defer q.lock.Unlock()
|
||||
|
||||
if q.head == q.tail && q.count > 0 {
|
||||
nodes := make([]interface{}, len(q.elements)+q.size)
|
||||
copy(nodes, q.elements[q.head:])
|
||||
copy(nodes[len(q.elements)-q.head:], q.elements[:q.head])
|
||||
q.head = 0
|
||||
q.tail = len(q.elements)
|
||||
q.elements = nodes
|
||||
}
|
||||
|
||||
q.elements[q.tail] = element
|
||||
q.tail = (q.tail + 1) % len(q.elements)
|
||||
q.count++
|
||||
}
|
||||
|
||||
func (q *Queue) Take() (interface{}, bool) {
|
||||
q.lock.Lock()
|
||||
defer q.lock.Unlock()
|
||||
|
||||
if q.count == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
element := q.elements[q.head]
|
||||
q.head = (q.head + 1) % len(q.elements)
|
||||
q.count--
|
||||
|
||||
return element, true
|
||||
}
|
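A short usage sketch of the Queue above (illustrative only; same import-path assumption as before):

package main

import (
	"fmt"

	"zero/core/collection"
)

func main() {
	// the size argument is the initial capacity; Put grows the ring buffer when it fills up
	q := collection.NewQueue(4)

	q.Put("a")
	q.Put("b")

	for !q.Empty() {
		if v, ok := q.Take(); ok {
			fmt.Println(v)
		}
	}
}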
63
core/collection/fifo_test.go
Normal file
@ -0,0 +1,63 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFifo(t *testing.T) {
|
||||
elements := [][]byte{
|
||||
[]byte("hello"),
|
||||
[]byte("world"),
|
||||
[]byte("again"),
|
||||
}
|
||||
queue := NewQueue(8)
|
||||
for i := range elements {
|
||||
queue.Put(elements[i])
|
||||
}
|
||||
|
||||
for _, element := range elements {
|
||||
body, ok := queue.Take()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, string(element), string(body.([]byte)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTakeTooMany(t *testing.T) {
|
||||
elements := [][]byte{
|
||||
[]byte("hello"),
|
||||
[]byte("world"),
|
||||
[]byte("again"),
|
||||
}
|
||||
queue := NewQueue(8)
|
||||
for i := range elements {
|
||||
queue.Put(elements[i])
|
||||
}
|
||||
|
||||
for range elements {
|
||||
queue.Take()
|
||||
}
|
||||
|
||||
assert.True(t, queue.Empty())
|
||||
_, ok := queue.Take()
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPutMore(t *testing.T) {
|
||||
elements := [][]byte{
|
||||
[]byte("hello"),
|
||||
[]byte("world"),
|
||||
[]byte("again"),
|
||||
}
|
||||
queue := NewQueue(2)
|
||||
for i := range elements {
|
||||
queue.Put(elements[i])
|
||||
}
|
||||
|
||||
for _, element := range elements {
|
||||
body, ok := queue.Take()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, string(element), string(body.([]byte)))
|
||||
}
|
||||
}
|
35
core/collection/ring.go
Normal file
@ -0,0 +1,35 @@
|
||||
package collection
|
||||
|
||||
type Ring struct {
|
||||
elements []interface{}
|
||||
index int
|
||||
}
|
||||
|
||||
func NewRing(n int) *Ring {
|
||||
return &Ring{
|
||||
elements: make([]interface{}, n),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Ring) Add(v interface{}) {
|
||||
r.elements[r.index%len(r.elements)] = v
|
||||
r.index++
|
||||
}
|
||||
|
||||
func (r *Ring) Take() []interface{} {
|
||||
var size int
|
||||
var start int
|
||||
if r.index > len(r.elements) {
|
||||
size = len(r.elements)
|
||||
start = r.index % len(r.elements)
|
||||
} else {
|
||||
size = r.index
|
||||
}
|
||||
|
||||
elements := make([]interface{}, size)
|
||||
for i := 0; i < size; i++ {
|
||||
elements[i] = r.elements[(start+i)%len(r.elements)]
|
||||
}
|
||||
|
||||
return elements
|
||||
}
|
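The Ring above keeps only the most recent n values; a minimal sketch (illustrative, not from the original commit):

package main

import (
	"fmt"

	"zero/core/collection"
)

func main() {
	r := collection.NewRing(3) // retains the 3 most recently added values
	for i := 0; i < 5; i++ {
		r.Add(i)
	}
	fmt.Println(r.Take()) // [2 3 4] - older values were overwritten
}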
25
core/collection/ring_test.go
Normal file
@ -0,0 +1,25 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRingLess(t *testing.T) {
|
||||
ring := NewRing(5)
|
||||
for i := 0; i < 3; i++ {
|
||||
ring.Add(i)
|
||||
}
|
||||
elements := ring.Take()
|
||||
assert.ElementsMatch(t, []interface{}{0, 1, 2}, elements)
|
||||
}
|
||||
|
||||
func TestRingMore(t *testing.T) {
|
||||
ring := NewRing(5)
|
||||
for i := 0; i < 11; i++ {
|
||||
ring.Add(i)
|
||||
}
|
||||
elements := ring.Take()
|
||||
assert.ElementsMatch(t, []interface{}{6, 7, 8, 9, 10}, elements)
|
||||
}
|
145
core/collection/rollingwindow.go
Normal file
@ -0,0 +1,145 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"zero/core/timex"
|
||||
)
|
||||
|
||||
type (
|
||||
RollingWindowOption func(rollingWindow *RollingWindow)
|
||||
|
||||
RollingWindow struct {
|
||||
lock sync.RWMutex
|
||||
size int
|
||||
win *window
|
||||
interval time.Duration
|
||||
offset int
|
||||
ignoreCurrent bool
|
||||
lastTime time.Duration
|
||||
}
|
||||
)
|
||||
|
||||
func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow {
|
||||
w := &RollingWindow{
|
||||
size: size,
|
||||
win: newWindow(size),
|
||||
interval: interval,
|
||||
lastTime: timex.Now(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(w)
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
func (rw *RollingWindow) Add(v float64) {
|
||||
rw.lock.Lock()
|
||||
defer rw.lock.Unlock()
|
||||
rw.updateOffset()
|
||||
rw.win.add(rw.offset, v)
|
||||
}
|
||||
|
||||
func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
|
||||
rw.lock.RLock()
|
||||
defer rw.lock.RUnlock()
|
||||
|
||||
var diff int
|
||||
span := rw.span()
|
||||
// ignore current bucket, because of partial data
|
||||
if span == 0 && rw.ignoreCurrent {
|
||||
diff = rw.size - 1
|
||||
} else {
|
||||
diff = rw.size - span
|
||||
}
|
||||
if diff > 0 {
|
||||
offset := (rw.offset + span + 1) % rw.size
|
||||
rw.win.reduce(offset, diff, fn)
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *RollingWindow) span() int {
|
||||
offset := int(timex.Since(rw.lastTime) / rw.interval)
|
||||
if 0 <= offset && offset < rw.size {
|
||||
return offset
|
||||
} else {
|
||||
return rw.size
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *RollingWindow) updateOffset() {
|
||||
span := rw.span()
|
||||
if span > 0 {
|
||||
offset := rw.offset
|
||||
// reset expired buckets
|
||||
start := offset + 1
|
||||
steps := start + span
|
||||
var remainder int
|
||||
if steps > rw.size {
|
||||
remainder = steps - rw.size
|
||||
steps = rw.size
|
||||
}
|
||||
for i := start; i < steps; i++ {
|
||||
rw.win.resetBucket(i)
|
||||
offset = i
|
||||
}
|
||||
for i := 0; i < remainder; i++ {
|
||||
rw.win.resetBucket(i)
|
||||
offset = i
|
||||
}
|
||||
rw.offset = offset
|
||||
rw.lastTime = timex.Now()
|
||||
}
|
||||
}
|
||||
|
||||
type Bucket struct {
|
||||
Sum float64
|
||||
Count int64
|
||||
}
|
||||
|
||||
func (b *Bucket) add(v float64) {
|
||||
b.Sum += v
|
||||
b.Count++
|
||||
}
|
||||
|
||||
func (b *Bucket) reset() {
|
||||
b.Sum = 0
|
||||
b.Count = 0
|
||||
}
|
||||
|
||||
type window struct {
|
||||
buckets []*Bucket
|
||||
size int
|
||||
}
|
||||
|
||||
func newWindow(size int) *window {
|
||||
var buckets []*Bucket
|
||||
for i := 0; i < size; i++ {
|
||||
buckets = append(buckets, new(Bucket))
|
||||
}
|
||||
return &window{
|
||||
buckets: buckets,
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *window) add(offset int, v float64) {
|
||||
w.buckets[offset%w.size].add(v)
|
||||
}
|
||||
|
||||
func (w *window) reduce(start, count int, fn func(b *Bucket)) {
|
||||
for i := 0; i < count; i++ {
|
||||
fn(w.buckets[(start+i)%len(w.buckets)])
|
||||
}
|
||||
}
|
||||
|
||||
func (w *window) resetBucket(offset int) {
|
||||
w.buckets[offset].reset()
|
||||
}
|
||||
|
||||
func IgnoreCurrentBucket() RollingWindowOption {
|
||||
return func(w *RollingWindow) {
|
||||
w.ignoreCurrent = true
|
||||
}
|
||||
}
|
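A minimal sketch of using the RollingWindow above to aggregate recent samples (the bucket count and interval are illustrative choices):

package main

import (
	"fmt"
	"time"

	"zero/core/collection"
)

func main() {
	// 10 buckets of 100ms each, i.e. roughly the last second of samples
	rw := collection.NewRollingWindow(10, time.Millisecond*100)

	rw.Add(1)
	rw.Add(2)

	var sum float64
	var count int64
	rw.Reduce(func(b *collection.Bucket) {
		sum += b.Sum
		count += b.Count
	})
	fmt.Println(sum, count) // 3 2
}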
133
core/collection/rollingwindow_test.go
Normal file
@ -0,0 +1,133 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zero/core/stringx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const duration = time.Millisecond * 50
|
||||
|
||||
func TestRollingWindowAdd(t *testing.T) {
|
||||
const size = 3
|
||||
r := NewRollingWindow(size, duration)
|
||||
listBuckets := func() []float64 {
|
||||
var buckets []float64
|
||||
r.Reduce(func(b *Bucket) {
|
||||
buckets = append(buckets, b.Sum)
|
||||
})
|
||||
return buckets
|
||||
}
|
||||
assert.Equal(t, []float64{0, 0, 0}, listBuckets())
|
||||
r.Add(1)
|
||||
assert.Equal(t, []float64{0, 0, 1}, listBuckets())
|
||||
elapse()
|
||||
r.Add(2)
|
||||
r.Add(3)
|
||||
assert.Equal(t, []float64{0, 1, 5}, listBuckets())
|
||||
elapse()
|
||||
r.Add(4)
|
||||
r.Add(5)
|
||||
r.Add(6)
|
||||
assert.Equal(t, []float64{1, 5, 15}, listBuckets())
|
||||
elapse()
|
||||
r.Add(7)
|
||||
assert.Equal(t, []float64{5, 15, 7}, listBuckets())
|
||||
}
|
||||
|
||||
func TestRollingWindowReset(t *testing.T) {
|
||||
const size = 3
|
||||
r := NewRollingWindow(size, duration, IgnoreCurrentBucket())
|
||||
listBuckets := func() []float64 {
|
||||
var buckets []float64
|
||||
r.Reduce(func(b *Bucket) {
|
||||
buckets = append(buckets, b.Sum)
|
||||
})
|
||||
return buckets
|
||||
}
|
||||
r.Add(1)
|
||||
elapse()
|
||||
assert.Equal(t, []float64{0, 1}, listBuckets())
|
||||
elapse()
|
||||
assert.Equal(t, []float64{1}, listBuckets())
|
||||
elapse()
|
||||
assert.Nil(t, listBuckets())
|
||||
|
||||
// cross window
|
||||
r.Add(1)
|
||||
time.Sleep(duration * 10)
|
||||
assert.Nil(t, listBuckets())
|
||||
}
|
||||
|
||||
func TestRollingWindowReduce(t *testing.T) {
|
||||
const size = 4
|
||||
tests := []struct {
|
||||
win *RollingWindow
|
||||
expect float64
|
||||
}{
|
||||
{
|
||||
win: NewRollingWindow(size, duration),
|
||||
expect: 10,
|
||||
},
|
||||
{
|
||||
win: NewRollingWindow(size, duration, IgnoreCurrentBucket()),
|
||||
expect: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.Rand(), func(t *testing.T) {
|
||||
r := test.win
|
||||
for x := 0; x < size; x = x + 1 {
|
||||
for i := 0; i <= x; i++ {
|
||||
r.Add(float64(i))
|
||||
}
|
||||
if x < size-1 {
|
||||
elapse()
|
||||
}
|
||||
}
|
||||
var result float64
|
||||
r.Reduce(func(b *Bucket) {
|
||||
result += b.Sum
|
||||
})
|
||||
assert.Equal(t, test.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRollingWindowDataRace(t *testing.T) {
|
||||
const size = 3
|
||||
r := NewRollingWindow(size, duration)
|
||||
var stop = make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
r.Add(float64(rand.Int63()))
|
||||
time.Sleep(duration / 2)
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
r.Reduce(func(b *Bucket) {})
|
||||
}
|
||||
}
|
||||
}()
|
||||
time.Sleep(duration * 5)
|
||||
close(stop)
|
||||
}
|
||||
|
||||
func elapse() {
|
||||
time.Sleep(duration)
|
||||
}
|
91
core/collection/safemap.go
Normal file
@ -0,0 +1,91 @@
|
||||
package collection
|
||||
|
||||
import "sync"
|
||||
|
||||
const (
|
||||
copyThreshold = 1000
|
||||
maxDeletion = 10000
|
||||
)
|
||||
|
||||
// SafeMap provides a map alternative to avoid memory leak.
|
||||
// This implementation is only needed until the issue below is fixed.
|
||||
// https://github.com/golang/go/issues/20135
|
||||
type SafeMap struct {
|
||||
lock sync.RWMutex
|
||||
deletionOld int
|
||||
deletionNew int
|
||||
dirtyOld map[interface{}]interface{}
|
||||
dirtyNew map[interface{}]interface{}
|
||||
}
|
||||
|
||||
func NewSafeMap() *SafeMap {
|
||||
return &SafeMap{
|
||||
dirtyOld: make(map[interface{}]interface{}),
|
||||
dirtyNew: make(map[interface{}]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SafeMap) Del(key interface{}) {
|
||||
m.lock.Lock()
|
||||
if _, ok := m.dirtyOld[key]; ok {
|
||||
delete(m.dirtyOld, key)
|
||||
m.deletionOld++
|
||||
} else if _, ok := m.dirtyNew[key]; ok {
|
||||
delete(m.dirtyNew, key)
|
||||
m.deletionNew++
|
||||
}
|
||||
if m.deletionOld >= maxDeletion && len(m.dirtyOld) < copyThreshold {
|
||||
for k, v := range m.dirtyOld {
|
||||
m.dirtyNew[k] = v
|
||||
}
|
||||
m.dirtyOld = m.dirtyNew
|
||||
m.deletionOld = m.deletionNew
|
||||
m.dirtyNew = make(map[interface{}]interface{})
|
||||
m.deletionNew = 0
|
||||
}
|
||||
if m.deletionNew >= maxDeletion && len(m.dirtyNew) < copyThreshold {
|
||||
for k, v := range m.dirtyNew {
|
||||
m.dirtyOld[k] = v
|
||||
}
|
||||
m.dirtyNew = make(map[interface{}]interface{})
|
||||
m.deletionNew = 0
|
||||
}
|
||||
m.lock.Unlock()
|
||||
}
|
||||
|
||||
func (m *SafeMap) Get(key interface{}) (interface{}, bool) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
if val, ok := m.dirtyOld[key]; ok {
|
||||
return val, true
|
||||
} else {
|
||||
val, ok := m.dirtyNew[key]
|
||||
return val, ok
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SafeMap) Set(key, value interface{}) {
|
||||
m.lock.Lock()
|
||||
if m.deletionOld <= maxDeletion {
|
||||
if _, ok := m.dirtyNew[key]; ok {
|
||||
delete(m.dirtyNew, key)
|
||||
m.deletionNew++
|
||||
}
|
||||
m.dirtyOld[key] = value
|
||||
} else {
|
||||
if _, ok := m.dirtyOld[key]; ok {
|
||||
delete(m.dirtyOld, key)
|
||||
m.deletionOld++
|
||||
}
|
||||
m.dirtyNew[key] = value
|
||||
}
|
||||
m.lock.Unlock()
|
||||
}
|
||||
|
||||
func (m *SafeMap) Size() int {
|
||||
m.lock.RLock()
|
||||
size := len(m.dirtyOld) + len(m.dirtyNew)
|
||||
m.lock.RUnlock()
|
||||
return size
|
||||
}
|
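A minimal sketch of the SafeMap above; it behaves like a mutex-guarded map while periodically copying to a fresh map so that heavily deleted keys do not pin memory:

package main

import (
	"fmt"

	"zero/core/collection"
)

func main() {
	m := collection.NewSafeMap()

	m.Set("name", "go-zero")
	if v, ok := m.Get("name"); ok {
		fmt.Println(v)
	}

	m.Del("name")
	fmt.Println(m.Size()) // 0
}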
110
core/collection/safemap_test.go
Normal file
@ -0,0 +1,110 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"zero/core/stringx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSafeMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
size int
|
||||
exception int
|
||||
}{
|
||||
{
|
||||
100000,
|
||||
2000,
|
||||
},
|
||||
{
|
||||
100000,
|
||||
50,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.Rand(), func(t *testing.T) {
|
||||
testSafeMapWithParameters(t, test.size, test.exception)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeMap_CopyNew(t *testing.T) {
|
||||
const (
|
||||
size = 100000
|
||||
exception1 = 5
|
||||
exception2 = 500
|
||||
)
|
||||
m := NewSafeMap()
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
m.Set(i, i)
|
||||
}
|
||||
for i := 0; i < size; i++ {
|
||||
if i%exception1 == 0 {
|
||||
m.Del(i)
|
||||
}
|
||||
}
|
||||
|
||||
for i := size; i < size<<1; i++ {
|
||||
m.Set(i, i)
|
||||
}
|
||||
for i := size; i < size<<1; i++ {
|
||||
if i%exception2 != 0 {
|
||||
m.Del(i)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
val, ok := m.Get(i)
|
||||
if i%exception1 != 0 {
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, i, val.(int))
|
||||
} else {
|
||||
assert.False(t, ok)
|
||||
}
|
||||
}
|
||||
for i := size; i < size<<1; i++ {
|
||||
val, ok := m.Get(i)
|
||||
if i%exception2 == 0 {
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, i, val.(int))
|
||||
} else {
|
||||
assert.False(t, ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testSafeMapWithParameters(t *testing.T, size, exception int) {
|
||||
m := NewSafeMap()
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
m.Set(i, i)
|
||||
}
|
||||
for i := 0; i < size; i++ {
|
||||
if i%exception != 0 {
|
||||
m.Del(i)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, size/exception, m.Size())
|
||||
|
||||
for i := size; i < size<<1; i++ {
|
||||
m.Set(i, i)
|
||||
}
|
||||
for i := size; i < size<<1; i++ {
|
||||
if i%exception != 0 {
|
||||
m.Del(i)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < size<<1; i++ {
|
||||
val, ok := m.Get(i)
|
||||
if i%exception == 0 {
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, i, val.(int))
|
||||
} else {
|
||||
assert.False(t, ok)
|
||||
}
|
||||
}
|
||||
}
|
230
core/collection/set.go
Normal file
@ -0,0 +1,230 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"zero/core/lang"
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
const (
|
||||
unmanaged = iota
|
||||
untyped
|
||||
intType
|
||||
int64Type
|
||||
uintType
|
||||
uint64Type
|
||||
stringType
|
||||
)
|
||||
|
||||
type Set struct {
|
||||
data map[interface{}]lang.PlaceholderType
|
||||
tp int
|
||||
}
|
||||
|
||||
func NewSet() *Set {
|
||||
return &Set{
|
||||
data: make(map[interface{}]lang.PlaceholderType),
|
||||
tp: untyped,
|
||||
}
|
||||
}
|
||||
|
||||
func NewUnmanagedSet() *Set {
|
||||
return &Set{
|
||||
data: make(map[interface{}]lang.PlaceholderType),
|
||||
tp: unmanaged,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set) Add(i ...interface{}) {
|
||||
for _, each := range i {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set) AddInt(ii ...int) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set) AddInt64(ii ...int64) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set) AddUint(ii ...uint) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set) AddUint64(ii ...uint64) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set) AddStr(ss ...string) {
|
||||
for _, each := range ss {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set) Contains(i interface{}) bool {
|
||||
if len(s.data) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
s.validate(i)
|
||||
_, ok := s.data[i]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *Set) Keys() []interface{} {
|
||||
var keys []interface{}
|
||||
|
||||
for key := range s.data {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *Set) KeysInt() []int {
|
||||
var keys []int
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(int); !ok {
|
||||
continue
|
||||
} else {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *Set) KeysInt64() []int64 {
|
||||
var keys []int64
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(int64); !ok {
|
||||
continue
|
||||
} else {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *Set) KeysUint() []uint {
|
||||
var keys []uint
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(uint); !ok {
|
||||
continue
|
||||
} else {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *Set) KeysUint64() []uint64 {
|
||||
var keys []uint64
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(uint64); !ok {
|
||||
continue
|
||||
} else {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *Set) KeysStr() []string {
|
||||
var keys []string
|
||||
|
||||
for key := range s.data {
|
||||
if strKey, ok := key.(string); !ok {
|
||||
continue
|
||||
} else {
|
||||
keys = append(keys, strKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *Set) Remove(i interface{}) {
|
||||
s.validate(i)
|
||||
delete(s.data, i)
|
||||
}
|
||||
|
||||
func (s *Set) Count() int {
|
||||
return len(s.data)
|
||||
}
|
||||
|
||||
func (s *Set) add(i interface{}) {
|
||||
switch s.tp {
|
||||
case unmanaged:
|
||||
// do nothing
|
||||
case untyped:
|
||||
s.setType(i)
|
||||
default:
|
||||
s.validate(i)
|
||||
}
|
||||
s.data[i] = lang.Placeholder
|
||||
}
|
||||
|
||||
func (s *Set) setType(i interface{}) {
|
||||
if s.tp != untyped {
|
||||
return
|
||||
}
|
||||
|
||||
switch i.(type) {
|
||||
case int:
|
||||
s.tp = intType
|
||||
case int64:
|
||||
s.tp = int64Type
|
||||
case uint:
|
||||
s.tp = uintType
|
||||
case uint64:
|
||||
s.tp = uint64Type
|
||||
case string:
|
||||
s.tp = stringType
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set) validate(i interface{}) {
|
||||
if s.tp == unmanaged {
|
||||
return
|
||||
}
|
||||
|
||||
switch i.(type) {
|
||||
case int:
|
||||
if s.tp != intType {
|
||||
logx.Errorf("Error: element is int, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case int64:
|
||||
if s.tp != int64Type {
|
||||
logx.Errorf("Error: element is int64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint:
|
||||
if s.tp != uintType {
|
||||
logx.Errorf("Error: element is uint, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint64:
|
||||
if s.tp != uint64Type {
|
||||
logx.Errorf("Error: element is uint64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case string:
|
||||
if s.tp != stringType {
|
||||
logx.Errorf("Error: element is string, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
}
|
||||
}
|
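A minimal sketch of the Set above; NewSet locks the element type to that of the first added element, while NewUnmanagedSet skips that check (values here are illustrative):

package main

import (
	"fmt"

	"zero/core/collection"
)

func main() {
	s := collection.NewSet()
	s.AddInt(1, 2, 2, 3) // duplicates are collapsed

	fmt.Println(s.Contains(2)) // true
	fmt.Println(s.Count())     // 3
	fmt.Println(s.KeysInt())   // 1, 2, 3 in unspecified order
}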
149
core/collection/set_test.go
Normal file
@ -0,0 +1,149 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func BenchmarkRawSet(b *testing.B) {
|
||||
m := make(map[interface{}]struct{})
|
||||
for i := 0; i < b.N; i++ {
|
||||
m[i] = struct{}{}
|
||||
_ = m[i]
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnmanagedSet(b *testing.B) {
|
||||
s := NewUnmanagedSet()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Add(i)
|
||||
_ = s.Contains(i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSet(b *testing.B) {
|
||||
s := NewSet()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.AddInt(i)
|
||||
_ = s.Contains(i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdd(t *testing.T) {
|
||||
// given
|
||||
set := NewUnmanagedSet()
|
||||
values := []interface{}{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.Add(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
|
||||
assert.Equal(t, len(values), len(set.Keys()))
|
||||
}
|
||||
|
||||
func TestAddInt(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
values := []int{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddInt(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
|
||||
keys := set.KeysInt()
|
||||
sort.Ints(keys)
|
||||
assert.EqualValues(t, values, keys)
|
||||
}
|
||||
|
||||
func TestAddInt64(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
values := []int64{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddInt64(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(int64(1)) && set.Contains(int64(2)) && set.Contains(int64(3)))
|
||||
assert.Equal(t, len(values), len(set.KeysInt64()))
|
||||
}
|
||||
|
||||
func TestAddUint(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
values := []uint{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddUint(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(uint(1)) && set.Contains(uint(2)) && set.Contains(uint(3)))
|
||||
assert.Equal(t, len(values), len(set.KeysUint()))
|
||||
}
|
||||
|
||||
func TestAddUint64(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
values := []uint64{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddUint64(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(uint64(1)) && set.Contains(uint64(2)) && set.Contains(uint64(3)))
|
||||
assert.Equal(t, len(values), len(set.KeysUint64()))
|
||||
}
|
||||
|
||||
func TestAddStr(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
values := []string{"1", "2", "3"}
|
||||
|
||||
// when
|
||||
set.AddStr(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains("1") && set.Contains("2") && set.Contains("3"))
|
||||
assert.Equal(t, len(values), len(set.KeysStr()))
|
||||
}
|
||||
|
||||
func TestContainsWithoutElements(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
|
||||
// then
|
||||
assert.False(t, set.Contains(1))
|
||||
}
|
||||
|
||||
func TestContainsUnmanagedWithoutElements(t *testing.T) {
|
||||
// given
|
||||
set := NewUnmanagedSet()
|
||||
|
||||
// then
|
||||
assert.False(t, set.Contains(1))
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set.Add([]interface{}{1, 2, 3}...)
|
||||
|
||||
// when
|
||||
set.Remove(2)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(1) && !set.Contains(2) && set.Contains(3))
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set.Add([]interface{}{1, 2, 3}...)
|
||||
|
||||
// then
|
||||
assert.Equal(t, set.Count(), 3)
|
||||
}
|
311
core/collection/timingwheel.go
Normal file
@ -0,0 +1,311 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"zero/core/lang"
|
||||
"zero/core/threading"
|
||||
"zero/core/timex"
|
||||
)
|
||||
|
||||
const drainWorkers = 8
|
||||
|
||||
type (
|
||||
Execute func(key, value interface{})
|
||||
|
||||
TimingWheel struct {
|
||||
interval time.Duration
|
||||
ticker timex.Ticker
|
||||
slots []*list.List
|
||||
timers *SafeMap
|
||||
tickedPos int
|
||||
numSlots int
|
||||
execute Execute
|
||||
setChannel chan timingEntry
|
||||
moveChannel chan baseEntry
|
||||
removeChannel chan interface{}
|
||||
drainChannel chan func(key, value interface{})
|
||||
stopChannel chan lang.PlaceholderType
|
||||
}
|
||||
|
||||
timingEntry struct {
|
||||
baseEntry
|
||||
value interface{}
|
||||
circle int
|
||||
diff int
|
||||
removed bool
|
||||
}
|
||||
|
||||
baseEntry struct {
|
||||
delay time.Duration
|
||||
key interface{}
|
||||
}
|
||||
|
||||
positionEntry struct {
|
||||
pos int
|
||||
item *timingEntry
|
||||
}
|
||||
|
||||
timingTask struct {
|
||||
key interface{}
|
||||
value interface{}
|
||||
}
|
||||
)
|
||||
|
||||
func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) {
|
||||
if interval <= 0 || numSlots <= 0 || execute == nil {
|
||||
return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", interval, numSlots, execute)
|
||||
}
|
||||
|
||||
return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval))
|
||||
}
|
||||
|
||||
func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute, ticker timex.Ticker) (
|
||||
*TimingWheel, error) {
|
||||
tw := &TimingWheel{
|
||||
interval: interval,
|
||||
ticker: ticker,
|
||||
slots: make([]*list.List, numSlots),
|
||||
timers: NewSafeMap(),
|
||||
tickedPos: numSlots - 1, // at previous virtual circle
|
||||
execute: execute,
|
||||
numSlots: numSlots,
|
||||
setChannel: make(chan timingEntry),
|
||||
moveChannel: make(chan baseEntry),
|
||||
removeChannel: make(chan interface{}),
|
||||
drainChannel: make(chan func(key, value interface{})),
|
||||
stopChannel: make(chan lang.PlaceholderType),
|
||||
}
|
||||
|
||||
tw.initSlots()
|
||||
go tw.run()
|
||||
|
||||
return tw, nil
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) Drain(fn func(key, value interface{})) {
|
||||
tw.drainChannel <- fn
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) {
|
||||
if delay <= 0 || key == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tw.moveChannel <- baseEntry{
|
||||
delay: delay,
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) RemoveTimer(key interface{}) {
|
||||
if key == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tw.removeChannel <- key
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) {
|
||||
if delay <= 0 || key == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tw.setChannel <- timingEntry{
|
||||
baseEntry: baseEntry{
|
||||
delay: delay,
|
||||
key: key,
|
||||
},
|
||||
value: value,
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) Stop() {
|
||||
close(tw.stopChannel)
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
|
||||
runner := threading.NewTaskRunner(drainWorkers)
|
||||
for _, slot := range tw.slots {
|
||||
for e := slot.Front(); e != nil; {
|
||||
task := e.Value.(*timingEntry)
|
||||
next := e.Next()
|
||||
slot.Remove(e)
|
||||
e = next
|
||||
if !task.removed {
|
||||
runner.Schedule(func() {
|
||||
fn(task.key, task.value)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos int, circle int) {
|
||||
steps := int(d / tw.interval)
|
||||
pos = (tw.tickedPos + steps) % tw.numSlots
|
||||
circle = (steps - 1) / tw.numSlots
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) initSlots() {
|
||||
for i := 0; i < tw.numSlots; i++ {
|
||||
tw.slots[i] = list.New()
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) moveTask(task baseEntry) {
|
||||
val, ok := tw.timers.Get(task.key)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
timer := val.(*positionEntry)
|
||||
if task.delay < tw.interval {
|
||||
threading.GoSafe(func() {
|
||||
tw.execute(timer.item.key, timer.item.value)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
pos, circle := tw.getPositionAndCircle(task.delay)
|
||||
if pos >= timer.pos {
|
||||
timer.item.circle = circle
|
||||
timer.item.diff = pos - timer.pos
|
||||
} else if circle > 0 {
|
||||
circle--
|
||||
timer.item.circle = circle
|
||||
timer.item.diff = tw.numSlots + pos - timer.pos
|
||||
} else {
|
||||
timer.item.removed = true
|
||||
newItem := &timingEntry{
|
||||
baseEntry: task,
|
||||
value: timer.item.value,
|
||||
}
|
||||
tw.slots[pos].PushBack(newItem)
|
||||
tw.setTimerPosition(pos, newItem)
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) onTick() {
|
||||
tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots
|
||||
l := tw.slots[tw.tickedPos]
|
||||
tw.scanAndRunTasks(l)
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) removeTask(key interface{}) {
|
||||
val, ok := tw.timers.Get(key)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
timer := val.(*positionEntry)
|
||||
timer.item.removed = true
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) run() {
|
||||
for {
|
||||
select {
|
||||
case <-tw.ticker.Chan():
|
||||
tw.onTick()
|
||||
case task := <-tw.setChannel:
|
||||
tw.setTask(&task)
|
||||
case key := <-tw.removeChannel:
|
||||
tw.removeTask(key)
|
||||
case task := <-tw.moveChannel:
|
||||
tw.moveTask(task)
|
||||
case fn := <-tw.drainChannel:
|
||||
tw.drainAll(fn)
|
||||
case <-tw.stopChannel:
|
||||
tw.ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) runTasks(tasks []timingTask) {
|
||||
if len(tasks) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
for i := range tasks {
|
||||
threading.RunSafe(func() {
|
||||
tw.execute(tasks[i].key, tasks[i].value)
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) scanAndRunTasks(l *list.List) {
|
||||
var tasks []timingTask
|
||||
|
||||
for e := l.Front(); e != nil; {
|
||||
task := e.Value.(*timingEntry)
|
||||
if task.removed {
|
||||
next := e.Next()
|
||||
l.Remove(e)
|
||||
tw.timers.Del(task.key)
|
||||
e = next
|
||||
continue
|
||||
} else if task.circle > 0 {
|
||||
task.circle--
|
||||
e = e.Next()
|
||||
continue
|
||||
} else if task.diff > 0 {
|
||||
next := e.Next()
|
||||
l.Remove(e)
|
||||
// (tw.tickedPos+task.diff)%tw.numSlots
|
||||
// cannot be the same value as tw.tickedPos
|
||||
pos := (tw.tickedPos + task.diff) % tw.numSlots
|
||||
tw.slots[pos].PushBack(task)
|
||||
tw.setTimerPosition(pos, task)
|
||||
task.diff = 0
|
||||
e = next
|
||||
continue
|
||||
}
|
||||
|
||||
tasks = append(tasks, timingTask{
|
||||
key: task.key,
|
||||
value: task.value,
|
||||
})
|
||||
next := e.Next()
|
||||
l.Remove(e)
|
||||
tw.timers.Del(task.key)
|
||||
e = next
|
||||
}
|
||||
|
||||
tw.runTasks(tasks)
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) setTask(task *timingEntry) {
|
||||
if task.delay < tw.interval {
|
||||
task.delay = tw.interval
|
||||
}
|
||||
|
||||
if val, ok := tw.timers.Get(task.key); ok {
|
||||
entry := val.(*positionEntry)
|
||||
entry.item.value = task.value
|
||||
tw.moveTask(task.baseEntry)
|
||||
} else {
|
||||
pos, circle := tw.getPositionAndCircle(task.delay)
|
||||
task.circle = circle
|
||||
tw.slots[pos].PushBack(task)
|
||||
tw.setTimerPosition(pos, task)
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) {
|
||||
if val, ok := tw.timers.Get(task.key); ok {
|
||||
timer := val.(*positionEntry)
|
||||
timer.pos = pos
|
||||
} else {
|
||||
tw.timers.Set(task.key, &positionEntry{
|
||||
pos: pos,
|
||||
item: task,
|
||||
})
|
||||
}
|
||||
}
|
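A minimal usage sketch of the TimingWheel above (the interval, slot count and keys are illustrative, not part of the original commit):

package main

import (
	"fmt"
	"time"

	"zero/core/collection"
)

func main() {
	// the execute callback is invoked when a timer expires
	tw, err := collection.NewTimingWheel(time.Second, 60, func(key, value interface{}) {
		fmt.Println("expired:", key, value)
	})
	if err != nil {
		panic(err)
	}
	defer tw.Stop()

	tw.SetTimer("session:1", "payload", time.Second*3)
	tw.MoveTimer("session:1", time.Second*5) // postpone the same key
	// tw.RemoveTimer("session:1")           // or cancel it

	time.Sleep(time.Second * 7)
}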
593
core/collection/timingwheel_test.go
Normal file
@ -0,0 +1,593 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zero/core/lang"
|
||||
"zero/core/stringx"
|
||||
"zero/core/syncx"
|
||||
"zero/core/timex"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
testStep = time.Minute
|
||||
waitTime = time.Second
|
||||
)
|
||||
|
||||
func TestNewTimingWheel(t *testing.T) {
|
||||
_, err := NewTimingWheel(0, 10, func(key, value interface{}) {})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestTimingWheel_Drain(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
}, ticker)
|
||||
defer tw.Stop()
|
||||
tw.SetTimer("first", 3, testStep*4)
|
||||
tw.SetTimer("second", 5, testStep*7)
|
||||
tw.SetTimer("third", 7, testStep*7)
|
||||
var keys []string
|
||||
var vals []int
|
||||
var lock sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(3)
|
||||
tw.Drain(func(key, value interface{}) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
keys = append(keys, key.(string))
|
||||
vals = append(vals, value.(int))
|
||||
wg.Done()
|
||||
})
|
||||
wg.Wait()
|
||||
sort.Strings(keys)
|
||||
sort.Ints(vals)
|
||||
assert.Equal(t, 3, len(keys))
|
||||
assert.EqualValues(t, []string{"first", "second", "third"}, keys)
|
||||
assert.EqualValues(t, []int{3, 5, 7}, vals)
|
||||
var count int
|
||||
tw.Drain(func(key, value interface{}) {
|
||||
count++
|
||||
})
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
assert.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
func TestTimingWheel_SetTimerSoon(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
ticker.Done()
|
||||
}, ticker)
|
||||
defer tw.Stop()
|
||||
tw.SetTimer("any", 3, testStep>>1)
|
||||
ticker.Tick()
|
||||
assert.Nil(t, ticker.Wait(waitTime))
|
||||
assert.True(t, run.True())
|
||||
}
|
||||
|
||||
func TestTimingWheel_SetTimerTwice(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 5, v.(int))
|
||||
ticker.Done()
|
||||
}, ticker)
|
||||
defer tw.Stop()
|
||||
tw.SetTimer("any", 3, testStep*4)
|
||||
tw.SetTimer("any", 5, testStep*7)
|
||||
for i := 0; i < 8; i++ {
|
||||
ticker.Tick()
|
||||
}
|
||||
assert.Nil(t, ticker.Wait(waitTime))
|
||||
assert.True(t, run.True())
|
||||
}
|
||||
|
||||
func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
defer tw.Stop()
|
||||
assert.NotPanics(t, func() {
|
||||
tw.SetTimer("any", 3, -testStep)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTimingWheel_MoveTimer(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
ticker.Done()
|
||||
}, ticker)
|
||||
defer tw.Stop()
|
||||
tw.SetTimer("any", 3, testStep*4)
|
||||
tw.MoveTimer("any", testStep*7)
|
||||
tw.MoveTimer("any", -testStep)
|
||||
tw.MoveTimer("none", testStep)
|
||||
for i := 0; i < 5; i++ {
|
||||
ticker.Tick()
|
||||
}
|
||||
assert.False(t, run.True())
|
||||
for i := 0; i < 3; i++ {
|
||||
ticker.Tick()
|
||||
}
|
||||
assert.Nil(t, ticker.Wait(waitTime))
|
||||
assert.True(t, run.True())
|
||||
}
|
||||
|
||||
func TestTimingWheel_MoveTimerSoon(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
ticker.Done()
|
||||
}, ticker)
|
||||
defer tw.Stop()
|
||||
tw.SetTimer("any", 3, testStep*4)
|
||||
tw.MoveTimer("any", testStep>>1)
|
||||
assert.Nil(t, ticker.Wait(waitTime))
|
||||
assert.True(t, run.True())
|
||||
}
|
||||
|
||||
func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
ticker.Done()
|
||||
}, ticker)
|
||||
defer tw.Stop()
|
||||
tw.SetTimer("any", 3, testStep*4)
|
||||
tw.MoveTimer("any", testStep*2)
|
||||
for i := 0; i < 3; i++ {
|
||||
ticker.Tick()
|
||||
}
|
||||
assert.Nil(t, ticker.Wait(waitTime))
|
||||
assert.True(t, run.True())
|
||||
}
|
||||
|
||||
func TestTimingWheel_RemoveTimer(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
tw.SetTimer("any", 3, testStep)
|
||||
assert.NotPanics(t, func() {
|
||||
tw.RemoveTimer("any")
|
||||
tw.RemoveTimer("none")
|
||||
tw.RemoveTimer(nil)
|
||||
})
|
||||
for i := 0; i < 5; i++ {
|
||||
ticker.Tick()
|
||||
}
|
||||
tw.Stop()
|
||||
}
|
||||
|
||||
func TestTimingWheel_SetTimer(t *testing.T) {
|
||||
tests := []struct {
|
||||
slots int
|
||||
setAt time.Duration
|
||||
}{
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 5,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 7,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 10,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 12,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 7,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 10,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 12,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
var count int32
|
||||
ticker := timex.NewFakeTicker()
|
||||
tick := func() {
|
||||
atomic.AddInt32(&count, 1)
|
||||
ticker.Tick()
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
assert.Equal(t, 1, key.(int))
|
||||
assert.Equal(t, 2, value.(int))
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
assert.Nil(t, err)
|
||||
defer tw.Stop()
|
||||
|
||||
tw.SetTimer(1, 2, testStep*test.setAt)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
assert.Equal(t, int32(test.setAt), actual)
|
||||
return
|
||||
default:
|
||||
tick()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
|
||||
tests := []struct {
|
||||
slots int
|
||||
setAt time.Duration
|
||||
moveAt time.Duration
|
||||
}{
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 3,
|
||||
moveAt: 5,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 3,
|
||||
moveAt: 7,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 3,
|
||||
moveAt: 10,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 3,
|
||||
moveAt: 12,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 5,
|
||||
moveAt: 7,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 5,
|
||||
moveAt: 10,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 5,
|
||||
moveAt: 12,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
var count int32
|
||||
ticker := timex.NewFakeTicker()
|
||||
tick := func() {
|
||||
atomic.AddInt32(&count, 1)
|
||||
ticker.Tick()
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
assert.Nil(t, err)
|
||||
defer tw.Stop()
|
||||
|
||||
tw.SetTimer(1, 2, testStep*test.setAt)
|
||||
tw.MoveTimer(1, testStep*test.moveAt)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
assert.Equal(t, int32(test.moveAt), actual)
|
||||
return
|
||||
default:
|
||||
tick()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
|
||||
tests := []struct {
|
||||
slots int
|
||||
setAt time.Duration
|
||||
moveAt time.Duration
|
||||
moveAgainAt time.Duration
|
||||
}{
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 3,
|
||||
moveAt: 5,
|
||||
moveAgainAt: 10,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 3,
|
||||
moveAt: 7,
|
||||
moveAgainAt: 12,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 3,
|
||||
moveAt: 10,
|
||||
moveAgainAt: 15,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 3,
|
||||
moveAt: 12,
|
||||
moveAgainAt: 17,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 5,
|
||||
moveAt: 7,
|
||||
moveAgainAt: 12,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 5,
|
||||
moveAt: 10,
|
||||
moveAgainAt: 17,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
setAt: 5,
|
||||
moveAt: 12,
|
||||
moveAgainAt: 17,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
var count int32
|
||||
ticker := timex.NewFakeTicker()
|
||||
tick := func() {
|
||||
atomic.AddInt32(&count, 1)
|
||||
ticker.Tick()
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
assert.Nil(t, err)
|
||||
defer tw.Stop()
|
||||
|
||||
tw.SetTimer(1, 2, testStep*test.setAt)
|
||||
tw.MoveTimer(1, testStep*test.moveAt)
|
||||
tw.MoveTimer(1, testStep*test.moveAgainAt)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
assert.Equal(t, int32(test.moveAgainAt), actual)
|
||||
return
|
||||
default:
|
||||
tick()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimingWheel_ElapsedAndSet(t *testing.T) {
|
||||
tests := []struct {
|
||||
slots int
|
||||
elapsed time.Duration
|
||||
setAt time.Duration
|
||||
}{
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 3,
|
||||
setAt: 5,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 3,
|
||||
setAt: 7,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 3,
|
||||
setAt: 10,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 3,
|
||||
setAt: 12,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 5,
|
||||
setAt: 7,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 5,
|
||||
setAt: 10,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 5,
|
||||
setAt: 12,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
var count int32
|
||||
ticker := timex.NewFakeTicker()
|
||||
tick := func() {
|
||||
atomic.AddInt32(&count, 1)
|
||||
ticker.Tick()
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
assert.Nil(t, err)
|
||||
defer tw.Stop()
|
||||
|
||||
for i := 0; i < int(test.elapsed); i++ {
|
||||
tick()
|
||||
}
|
||||
|
||||
tw.SetTimer(1, 2, testStep*test.setAt)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
assert.Equal(t, int32(test.elapsed+test.setAt), actual)
|
||||
return
|
||||
default:
|
||||
tick()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
|
||||
tests := []struct {
|
||||
slots int
|
||||
elapsed time.Duration
|
||||
setAt time.Duration
|
||||
moveAt time.Duration
|
||||
}{
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 3,
|
||||
setAt: 5,
|
||||
moveAt: 10,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 3,
|
||||
setAt: 7,
|
||||
moveAt: 12,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 3,
|
||||
setAt: 10,
|
||||
moveAt: 15,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 3,
|
||||
setAt: 12,
|
||||
moveAt: 16,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 5,
|
||||
setAt: 7,
|
||||
moveAt: 12,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 5,
|
||||
setAt: 10,
|
||||
moveAt: 15,
|
||||
},
|
||||
{
|
||||
slots: 5,
|
||||
elapsed: 5,
|
||||
setAt: 12,
|
||||
moveAt: 17,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
var count int32
|
||||
ticker := timex.NewFakeTicker()
|
||||
tick := func() {
|
||||
atomic.AddInt32(&count, 1)
|
||||
ticker.Tick()
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
assert.Nil(t, err)
|
||||
defer tw.Stop()
|
||||
|
||||
for i := 0; i < int(test.elapsed); i++ {
|
||||
tick()
|
||||
}
|
||||
|
||||
tw.SetTimer(1, 2, testStep*test.setAt)
|
||||
tw.MoveTimer(1, testStep*test.moveAt)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
assert.Equal(t, int32(test.elapsed+test.moveAt), actual)
|
||||
return
|
||||
default:
|
||||
tick()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTimingWheel(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
tw, _ := NewTimingWheel(time.Second, 100, func(k, v interface{}) {})
|
||||
for i := 0; i < b.N; i++ {
|
||||
tw.SetTimer(i, i, time.Second)
|
||||
tw.SetTimer(b.N+i, b.N+i, time.Second)
|
||||
tw.MoveTimer(i, time.Second*time.Duration(i))
|
||||
tw.RemoveTimer(i)
|
||||
}
|
||||
}
|
40
core/conf/config.go
Normal file
@ -0,0 +1,40 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"path"
|
||||
|
||||
"zero/core/mapping"
|
||||
)
|
||||
|
||||
var loaders = map[string]func([]byte, interface{}) error{
|
||||
".json": LoadConfigFromJsonBytes,
|
||||
".yaml": LoadConfigFromYamlBytes,
|
||||
".yml": LoadConfigFromYamlBytes,
|
||||
}
|
||||
|
||||
func LoadConfig(file string, v interface{}) error {
|
||||
if content, err := ioutil.ReadFile(file); err != nil {
|
||||
return err
|
||||
} else if loader, ok := loaders[path.Ext(file)]; ok {
|
||||
return loader(content, v)
|
||||
} else {
|
||||
return fmt.Errorf("unrecoginized file type: %s", file)
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfigFromJsonBytes(content []byte, v interface{}) error {
|
||||
return mapping.UnmarshalJsonBytes(content, v)
|
||||
}
|
||||
|
||||
func LoadConfigFromYamlBytes(content []byte, v interface{}) error {
|
||||
return mapping.UnmarshalYamlBytes(content, v)
|
||||
}
|
||||
|
||||
func MustLoad(path string, v interface{}) {
|
||||
if err := LoadConfig(path, v); err != nil {
|
||||
log.Fatalf("error: config file %s, %s", path, err.Error())
|
||||
}
|
||||
}
|
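A minimal sketch of loading a config file with the helpers above; the struct, json tags and file path are illustrative, and field mapping is delegated to zero/core/mapping:

package main

import (
	"fmt"

	"zero/core/conf"
)

type Config struct {
	Name string `json:"name"`
	Port int    `json:"port"`
}

func main() {
	var c Config
	// the loader is chosen by file extension: .json, .yaml or .yml
	conf.MustLoad("etc/app.yaml", &c) // exits via log.Fatalf on failure
	fmt.Println(c.Name, c.Port)
}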
109
core/conf/properties.go
Normal file
@ -0,0 +1,109 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"zero/core/iox"
|
||||
)
|
||||
|
||||
// PropertyError represents a configuration error message.
|
||||
type PropertyError struct {
|
||||
error
|
||||
message string
|
||||
}
|
||||
|
||||
// Properties interface provides the means to access configuration.
|
||||
type Properties interface {
|
||||
GetString(key string) string
|
||||
SetString(key, value string)
|
||||
GetInt(key string) int
|
||||
SetInt(key string, value int)
|
||||
ToString() string
|
||||
}
|
||||
|
||||
// mapBasedProperties is a key/value pair based configuration backed by a map.
|
||||
type mapBasedProperties struct {
|
||||
properties map[string]string
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// LoadProperties loads the properties from the given file into a Properties
|
||||
// instance, returning an error if the file cannot be read or a line is malformed.
|
||||
func LoadProperties(filename string) (Properties, error) {
|
||||
lines, err := iox.ReadTextLines(filename, iox.WithoutBlank(), iox.OmitWithPrefix("#"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
raw := make(map[string]string)
|
||||
for i := range lines {
|
||||
pair := strings.Split(lines[i], "=")
|
||||
if len(pair) != 2 {
|
||||
// invalid property format
|
||||
return nil, &PropertyError{
|
||||
message: fmt.Sprintf("invalid property format: %s", pair),
|
||||
}
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(pair[0])
|
||||
value := strings.TrimSpace(pair[1])
|
||||
raw[key] = value
|
||||
}
|
||||
|
||||
return &mapBasedProperties{
|
||||
properties: raw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (config *mapBasedProperties) GetString(key string) string {
|
||||
config.lock.RLock()
|
||||
ret := config.properties[key]
|
||||
config.lock.RUnlock()
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (config *mapBasedProperties) SetString(key, value string) {
|
||||
config.lock.Lock()
|
||||
config.properties[key] = value
|
||||
config.lock.Unlock()
|
||||
}
|
||||
|
||||
func (config *mapBasedProperties) GetInt(key string) int {
|
||||
config.lock.RLock()
|
||||
// default 0
|
||||
value, _ := strconv.Atoi(config.properties[key])
|
||||
config.lock.RUnlock()
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
func (config *mapBasedProperties) SetInt(key string, value int) {
|
||||
config.lock.Lock()
|
||||
config.properties[key] = strconv.Itoa(value)
|
||||
config.lock.Unlock()
|
||||
}
|
||||
|
||||
// Dumps the configuration internal map into a string.
|
||||
func (config *mapBasedProperties) ToString() string {
|
||||
config.lock.RLock()
|
||||
ret := fmt.Sprintf("%s", config.properties)
|
||||
config.lock.RUnlock()
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Returns the error message.
|
||||
func (configError *PropertyError) Error() string {
|
||||
return configError.message
|
||||
}
|
||||
|
||||
// NewProperties builds a new, empty properties configuration.
|
||||
func NewProperties() Properties {
|
||||
return &mapBasedProperties{
|
||||
properties: make(map[string]string),
|
||||
}
|
||||
}
|
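A minimal sketch of LoadProperties in use (not part of this commit); app.properties and its keys are made-up examples.

package main

import (
	"fmt"
	"log"

	"zero/core/conf"
)

func main() {
	// app.properties is a hypothetical file with lines like "app.name = demo".
	props, err := conf.LoadProperties("app.properties")
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(props.GetString("app.name")) // "" if the key is absent
	fmt.Println(props.GetInt("app.threads")) // 0 if absent or not numeric
}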
44
core/conf/properties_test.go
Normal file
@ -0,0 +1,44 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"zero/core/fs"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProperties(t *testing.T) {
|
||||
text := `app.name = test
|
||||
|
||||
app.program=app
|
||||
|
||||
# this is comment
|
||||
app.threads = 5`
|
||||
tmpfile, err := fs.TempFilenameWithText(text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
props, err := LoadProperties(tmpfile)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "test", props.GetString("app.name"))
|
||||
assert.Equal(t, "app", props.GetString("app.program"))
|
||||
assert.Equal(t, 5, props.GetInt("app.threads"))
|
||||
}
|
||||
|
||||
func TestSetString(t *testing.T) {
|
||||
key := "a"
|
||||
value := "the value of a"
|
||||
props := NewProperties()
|
||||
props.SetString(key, value)
|
||||
assert.Equal(t, value, props.GetString(key))
|
||||
}
|
||||
|
||||
func TestSetInt(t *testing.T) {
|
||||
key := "a"
|
||||
value := 101
|
||||
props := NewProperties()
|
||||
props.SetInt(key, value)
|
||||
assert.Equal(t, value, props.GetInt(key))
|
||||
}
|
17
core/contextx/deadline.go
Normal file
@ -0,0 +1,17 @@
|
||||
package contextx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ShrinkDeadline(ctx context.Context, timeout time.Duration) (context.Context, func()) {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
leftTime := time.Until(deadline)
|
||||
if leftTime < timeout {
|
||||
timeout = leftTime
|
||||
}
|
||||
}
|
||||
|
||||
return context.WithDeadline(ctx, time.Now().Add(timeout))
|
||||
}
|
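A minimal sketch of ShrinkDeadline (not part of this commit), assuming a caller that wants to cap work at 100ms without exceeding the parent deadline.

package main

import (
	"context"
	"fmt"
	"time"

	"zero/core/contextx"
)

func main() {
	parent, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// The effective deadline is the smaller of the parent deadline and now+100ms.
	ctx, shrinkCancel := contextx.ShrinkDeadline(parent, 100*time.Millisecond)
	defer shrinkCancel()

	deadline, _ := ctx.Deadline()
	fmt.Println("deadline in:", time.Until(deadline))
}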
27
core/contextx/deadline_test.go
Normal file
@ -0,0 +1,27 @@
|
||||
package contextx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShrinkDeadlineLess(t *testing.T) {
|
||||
deadline := time.Now().Add(time.Second)
|
||||
ctx, _ := context.WithDeadline(context.Background(), deadline)
|
||||
ctx, _ = ShrinkDeadline(ctx, time.Minute)
|
||||
dl, ok := ctx.Deadline()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, deadline, dl)
|
||||
}
|
||||
|
||||
func TestShrinkDeadlineMore(t *testing.T) {
|
||||
deadline := time.Now().Add(time.Minute)
|
||||
ctx, _ := context.WithDeadline(context.Background(), deadline)
|
||||
ctx, _ = ShrinkDeadline(ctx, time.Second)
|
||||
dl, ok := ctx.Deadline()
|
||||
assert.True(t, ok)
|
||||
assert.True(t, dl.Before(deadline))
|
||||
}
|
26
core/contextx/unmarshaler.go
Normal file
@ -0,0 +1,26 @@
|
||||
package contextx
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"zero/core/mapping"
|
||||
)
|
||||
|
||||
const contextTagKey = "ctx"
|
||||
|
||||
var unmarshaler = mapping.NewUnmarshaler(contextTagKey)
|
||||
|
||||
type contextValuer struct {
|
||||
context.Context
|
||||
}
|
||||
|
||||
func (cv contextValuer) Value(key string) (interface{}, bool) {
|
||||
v := cv.Context.Value(key)
|
||||
return v, v != nil
|
||||
}
|
||||
|
||||
func For(ctx context.Context, v interface{}) error {
|
||||
return unmarshaler.UnmarshalValuer(contextValuer{
|
||||
Context: ctx,
|
||||
}, v)
|
||||
}
|
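A minimal sketch of contextx.For (not part of this commit), in the spirit of the tests that follow; the struct and keys are made up.

package main

import (
	"context"
	"fmt"

	"zero/core/contextx"
)

type userInfo struct {
	Name string `ctx:"name"`
	Age  int    `ctx:"age,optional"`
}

func main() {
	ctx := context.WithValue(context.Background(), "name", "kevin")

	var u userInfo
	// For fills u from the ctx values; "age" is tagged optional, so it may stay 0.
	if err := contextx.For(ctx, &u); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Printf("%+v\n", u)
}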
58
core/contextx/unmarshaler_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package contextx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUnmarshalContext(t *testing.T) {
|
||||
type Person struct {
|
||||
Name string `ctx:"name"`
|
||||
Age int `ctx:"age"`
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, "name", "kevin")
|
||||
ctx = context.WithValue(ctx, "age", 20)
|
||||
|
||||
var person Person
|
||||
err := For(ctx, &person)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "kevin", person.Name)
|
||||
assert.Equal(t, 20, person.Age)
|
||||
}
|
||||
|
||||
func TestUnmarshalContextWithOptional(t *testing.T) {
|
||||
type Person struct {
|
||||
Name string `ctx:"name"`
|
||||
Age int `ctx:"age,optional"`
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, "name", "kevin")
|
||||
|
||||
var person Person
|
||||
err := For(ctx, &person)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "kevin", person.Name)
|
||||
assert.Equal(t, 0, person.Age)
|
||||
}
|
||||
|
||||
func TestUnmarshalContextWithMissing(t *testing.T) {
|
||||
type Person struct {
|
||||
Name string `ctx:"name"`
|
||||
Age int `ctx:"age"`
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, "name", "kevin")
|
||||
|
||||
var person Person
|
||||
err := For(ctx, &person)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
}
|
28
core/contextx/valueonlycontext.go
Normal file
@ -0,0 +1,28 @@
|
||||
package contextx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type valueOnlyContext struct {
|
||||
context.Context
|
||||
}
|
||||
|
||||
func (valueOnlyContext) Deadline() (deadline time.Time, ok bool) {
|
||||
return
|
||||
}
|
||||
|
||||
func (valueOnlyContext) Done() <-chan struct{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (valueOnlyContext) Err() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValueOnlyFrom(ctx context.Context) context.Context {
|
||||
return valueOnlyContext{
|
||||
Context: ctx,
|
||||
}
|
||||
}
|
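A minimal sketch of ValueOnlyFrom (not part of this commit): detaching work from the caller's cancellation while keeping its values; the "tid" key is illustrative.

package main

import (
	"context"
	"fmt"

	"zero/core/contextx"
)

func main() {
	ctx, cancel := context.WithCancel(context.WithValue(context.Background(), "tid", "abc123"))

	// detached keeps ctx's values but ignores its cancellation and deadline.
	detached := contextx.ValueOnlyFrom(ctx)
	cancel()

	fmt.Println(detached.Value("tid")) // abc123
	fmt.Println(detached.Err())        // <nil>, even though ctx is canceled
}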
54
core/contextx/valueonlycontext_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
package contextx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestContextCancel(t *testing.T) {
|
||||
c := context.WithValue(context.Background(), "key", "value")
|
||||
c1, cancel := context.WithCancel(c)
|
||||
o := ValueOnlyFrom(c1)
|
||||
c2, _ := context.WithCancel(o)
|
||||
contexts := []context.Context{c1, c2}
|
||||
|
||||
for _, c := range contexts {
|
||||
assert.NotNil(t, c.Done())
|
||||
assert.Nil(t, c.Err())
|
||||
|
||||
select {
|
||||
case x := <-c.Done():
|
||||
t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
<-c1.Done()
|
||||
|
||||
assert.Nil(t, o.Err())
|
||||
assert.Equal(t, context.Canceled, c1.Err())
|
||||
assert.NotEqual(t, context.Canceled, c2.Err())
|
||||
}
|
||||
|
||||
func TestContextDeadline(t *testing.T) {
|
||||
c, _ := context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond))
|
||||
o := ValueOnlyFrom(c)
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
case <-o.Done():
|
||||
t.Fatal("ValueOnlyContext: context should not have timed out")
|
||||
}
|
||||
|
||||
c, _ = context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond))
|
||||
o = ValueOnlyFrom(c)
|
||||
c, _ = context.WithDeadline(o, time.Now().Add(20*time.Millisecond))
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("ValueOnlyContext+Deadline: context should have timed out")
|
||||
case <-c.Done():
|
||||
}
|
||||
}
|
40
core/discov/clients.go
Normal file
@ -0,0 +1,40 @@
|
||||
package discov
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"zero/core/discov/internal"
|
||||
)
|
||||
|
||||
const (
|
||||
indexOfKey = iota
|
||||
indexOfId
|
||||
)
|
||||
|
||||
const timeToLive int64 = 10
|
||||
|
||||
var TimeToLive = timeToLive
|
||||
|
||||
func extract(etcdKey string, index int) (string, bool) {
|
||||
if index < 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
fields := strings.FieldsFunc(etcdKey, func(ch rune) bool {
|
||||
return ch == internal.Delimiter
|
||||
})
|
||||
if index >= len(fields) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return fields[index], true
|
||||
}
|
||||
|
||||
func extractId(etcdKey string) (string, bool) {
|
||||
return extract(etcdKey, indexOfId)
|
||||
}
|
||||
|
||||
func makeEtcdKey(key string, id int64) string {
|
||||
return fmt.Sprintf("%s%c%d", key, internal.Delimiter, id)
|
||||
}
|
36
core/discov/clients_test.go
Normal file
@ -0,0 +1,36 @@
|
||||
package discov
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"zero/core/discov/internal"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var mockLock sync.Mutex
|
||||
|
||||
func setMockClient(cli internal.EtcdClient) func() {
|
||||
mockLock.Lock()
|
||||
internal.NewClient = func([]string) (internal.EtcdClient, error) {
|
||||
return cli, nil
|
||||
}
|
||||
return func() {
|
||||
internal.NewClient = internal.DialClient
|
||||
mockLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtract(t *testing.T) {
|
||||
id, ok := extractId("key/123/val")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "123", id)
|
||||
|
||||
_, ok = extract("any", -1)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestMakeKey(t *testing.T) {
|
||||
assert.Equal(t, "key/123", makeEtcdKey("key", 123))
|
||||
}
|
18
core/discov/config.go
Normal file
@ -0,0 +1,18 @@
|
||||
package discov
|
||||
|
||||
import "errors"
|
||||
|
||||
type EtcdConf struct {
|
||||
Hosts []string
|
||||
Key string
|
||||
}
|
||||
|
||||
func (c EtcdConf) Validate() error {
|
||||
if len(c.Hosts) == 0 {
|
||||
return errors.New("empty etcd hosts")
|
||||
} else if len(c.Key) == 0 {
|
||||
return errors.New("empty etcd key")
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
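A minimal sketch of validating an EtcdConf before use (not part of this commit); the endpoint and key are placeholders.

package main

import (
	"fmt"

	"zero/core/discov"
)

func main() {
	conf := discov.EtcdConf{
		Hosts: []string{"127.0.0.1:2379"}, // placeholder etcd endpoint
		Key:   "user.rpc",                 // placeholder service key
	}
	// Validate rejects configs with no hosts or an empty key.
	if err := conf.Validate(); err != nil {
		fmt.Println("bad etcd config:", err)
		return
	}
	fmt.Println("etcd config ok")
}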
46
core/discov/config_test.go
Normal file
@ -0,0 +1,46 @@
|
||||
package discov
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
EtcdConf
|
||||
pass bool
|
||||
}{
|
||||
{
|
||||
EtcdConf: EtcdConf{},
|
||||
pass: false,
|
||||
},
|
||||
{
|
||||
EtcdConf: EtcdConf{
|
||||
Key: "any",
|
||||
},
|
||||
pass: false,
|
||||
},
|
||||
{
|
||||
EtcdConf: EtcdConf{
|
||||
Hosts: []string{"any"},
|
||||
},
|
||||
pass: false,
|
||||
},
|
||||
{
|
||||
EtcdConf: EtcdConf{
|
||||
Hosts: []string{"any"},
|
||||
Key: "key",
|
||||
},
|
||||
pass: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if test.pass {
|
||||
assert.Nil(t, test.EtcdConf.Validate())
|
||||
} else {
|
||||
assert.NotNil(t, test.EtcdConf.Validate())
|
||||
}
|
||||
}
|
||||
}
|
47
core/discov/facade.go
Normal file
@ -0,0 +1,47 @@
|
||||
package discov
|
||||
|
||||
import (
|
||||
"zero/core/discov/internal"
|
||||
"zero/core/lang"
|
||||
)
|
||||
|
||||
type (
|
||||
Facade struct {
|
||||
endpoints []string
|
||||
registry *internal.Registry
|
||||
}
|
||||
|
||||
FacadeListener interface {
|
||||
OnAdd(key, val string)
|
||||
OnDelete(key string)
|
||||
}
|
||||
)
|
||||
|
||||
func NewFacade(endpoints []string) Facade {
|
||||
return Facade{
|
||||
endpoints: endpoints,
|
||||
registry: internal.GetRegistry(),
|
||||
}
|
||||
}
|
||||
|
||||
func (f Facade) Client() internal.EtcdClient {
|
||||
conn, err := f.registry.GetConn(f.endpoints)
|
||||
lang.Must(err)
|
||||
return conn
|
||||
}
|
||||
|
||||
func (f Facade) Monitor(key string, l FacadeListener) {
|
||||
f.registry.Monitor(f.endpoints, key, listenerAdapter{l})
|
||||
}
|
||||
|
||||
type listenerAdapter struct {
|
||||
l FacadeListener
|
||||
}
|
||||
|
||||
func (la listenerAdapter) OnAdd(kv internal.KV) {
|
||||
la.l.OnAdd(kv.Key, kv.Val)
|
||||
}
|
||||
|
||||
func (la listenerAdapter) OnDelete(kv internal.KV) {
|
||||
la.l.OnDelete(kv.Key)
|
||||
}
|
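A minimal sketch of wiring a FacadeListener through the facade above (not part of this commit); the endpoint, key, and printListener are illustrative.

package main

import (
	"fmt"

	"zero/core/discov"
)

// printListener is a hypothetical listener that just logs changes.
type printListener struct{}

func (printListener) OnAdd(key, val string) { fmt.Println("add:", key, "->", val) }
func (printListener) OnDelete(key string)   { fmt.Println("del:", key) }

func main() {
	facade := discov.NewFacade([]string{"127.0.0.1:2379"}) // placeholder endpoint
	// Monitor loads the current values under "user.rpc" and then watches for changes.
	facade.Monitor("user.rpc", printListener{})
	select {} // block so the watch keeps running
}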
103
core/discov/internal/balancer.go
Normal file
@ -0,0 +1,103 @@
|
||||
package internal
|
||||
|
||||
import "sync"
|
||||
|
||||
type (
|
||||
DialFn func(server string) (interface{}, error)
|
||||
CloseFn func(server string, conn interface{}) error
|
||||
|
||||
Balancer interface {
|
||||
AddConn(kv KV) error
|
||||
IsEmpty() bool
|
||||
Next(key ...string) (interface{}, bool)
|
||||
RemoveKey(key string)
|
||||
initialize()
|
||||
setListener(listener Listener)
|
||||
}
|
||||
|
||||
serverConn struct {
|
||||
key string
|
||||
conn interface{}
|
||||
}
|
||||
|
||||
baseBalancer struct {
|
||||
exclusive bool
|
||||
servers map[string][]string
|
||||
mapping map[string]string
|
||||
lock sync.Mutex
|
||||
dialFn DialFn
|
||||
closeFn CloseFn
|
||||
listener Listener
|
||||
}
|
||||
)
|
||||
|
||||
func newBaseBalancer(dialFn DialFn, closeFn CloseFn, exclusive bool) *baseBalancer {
|
||||
return &baseBalancer{
|
||||
exclusive: exclusive,
|
||||
servers: make(map[string][]string),
|
||||
mapping: make(map[string]string),
|
||||
dialFn: dialFn,
|
||||
closeFn: closeFn,
|
||||
}
|
||||
}
|
||||
|
||||
// addKv adds the kv, and returns the keys already associated with the server, if any
|
||||
func (b *baseBalancer) addKv(key, value string) ([]string, bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
keys := b.servers[value]
|
||||
previous := append([]string(nil), keys...)
|
||||
early := len(keys) > 0
|
||||
if b.exclusive && early {
|
||||
for _, each := range keys {
|
||||
b.doRemoveKv(each)
|
||||
}
|
||||
}
|
||||
b.servers[value] = append(b.servers[value], key)
|
||||
b.mapping[key] = value
|
||||
|
||||
if early {
|
||||
return previous, true
|
||||
} else {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseBalancer) doRemoveKv(key string) (server string, keepConn bool) {
|
||||
server, ok := b.mapping[key]
|
||||
if !ok {
|
||||
return "", true
|
||||
}
|
||||
|
||||
delete(b.mapping, key)
|
||||
keys := b.servers[server]
|
||||
remain := keys[:0]
|
||||
|
||||
for _, k := range keys {
|
||||
if k != key {
|
||||
remain = append(remain, k)
|
||||
}
|
||||
}
|
||||
|
||||
if len(remain) > 0 {
|
||||
b.servers[server] = remain
|
||||
return server, true
|
||||
} else {
|
||||
delete(b.servers, server)
|
||||
return server, false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseBalancer) removeKv(key string) (server string, keepConn bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
return b.doRemoveKv(key)
|
||||
}
|
||||
|
||||
func (b *baseBalancer) setListener(listener Listener) {
|
||||
b.lock.Lock()
|
||||
b.listener = listener
|
||||
b.lock.Unlock()
|
||||
}
|
5
core/discov/internal/balancer_test.go
Normal file
@ -0,0 +1,5 @@
|
||||
package internal
|
||||
|
||||
type mockConn struct {
|
||||
server string
|
||||
}
|
152
core/discov/internal/consistentbalancer.go
Normal file
@ -0,0 +1,152 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"zero/core/hash"
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
type consistentBalancer struct {
|
||||
*baseBalancer
|
||||
conns map[string]interface{}
|
||||
buckets *hash.ConsistentHash
|
||||
bucketKey func(KV) string
|
||||
}
|
||||
|
||||
func NewConsistentBalancer(dialFn DialFn, closeFn CloseFn, keyer func(kv KV) string) *consistentBalancer {
|
||||
// exclusive mode is not supported for the consistent balancer, to avoid complexity,
|
||||
// because there are few scenarios for it; use it at your own risk.
|
||||
balancer := &consistentBalancer{
|
||||
conns: make(map[string]interface{}),
|
||||
buckets: hash.NewConsistentHash(),
|
||||
bucketKey: keyer,
|
||||
}
|
||||
balancer.baseBalancer = newBaseBalancer(dialFn, closeFn, false)
|
||||
return balancer
|
||||
}
|
||||
|
||||
func (b *consistentBalancer) AddConn(kv KV) error {
|
||||
// the kv and the conn are not added within one transaction, but it doesn't matter,
|
||||
// we just roll back the kv addition if the dial fails
|
||||
var conn interface{}
|
||||
prev, found := b.addKv(kv.Key, kv.Val)
|
||||
if found {
|
||||
conn = b.handlePrevious(prev)
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
var err error
|
||||
conn, err = b.dialFn(kv.Val)
|
||||
if err != nil {
|
||||
b.removeKv(kv.Key)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
bucketKey := b.bucketKey(kv)
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
b.conns[bucketKey] = conn
|
||||
b.buckets.Add(bucketKey)
|
||||
b.notify(bucketKey)
|
||||
|
||||
logx.Infof("added server, key: %s, server: %s", bucketKey, kv.Val)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *consistentBalancer) getConn(key string) (interface{}, bool) {
|
||||
b.lock.Lock()
|
||||
conn, ok := b.conns[key]
|
||||
b.lock.Unlock()
|
||||
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
func (b *consistentBalancer) handlePrevious(prev []string) interface{} {
|
||||
if len(prev) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
// not exclusive, so any one of the existing connections will do
|
||||
for key, conn := range b.conns {
|
||||
if key == prev[0] {
|
||||
return conn
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *consistentBalancer) initialize() {
|
||||
}
|
||||
|
||||
func (b *consistentBalancer) notify(key string) {
|
||||
if b.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var keys []string
|
||||
var values []string
|
||||
for k := range b.conns {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
for _, v := range b.mapping {
|
||||
values = append(values, v)
|
||||
}
|
||||
|
||||
b.listener.OnUpdate(keys, values, key)
|
||||
}
|
||||
|
||||
func (b *consistentBalancer) RemoveKey(key string) {
|
||||
kv := KV{Key: key}
|
||||
server, keep := b.removeKv(key)
|
||||
kv.Val = server
|
||||
bucketKey := b.bucketKey(kv)
|
||||
b.buckets.Remove(b.bucketKey(kv))
|
||||
|
||||
// wrap the query & removal in a function to make sure the quick lock/unlock
|
||||
conn, ok := func() (interface{}, bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
conn, ok := b.conns[bucketKey]
|
||||
if ok {
|
||||
delete(b.conns, bucketKey)
|
||||
}
|
||||
|
||||
return conn, ok
|
||||
}()
|
||||
if ok && !keep {
|
||||
logx.Infof("removing server, key: %s", kv.Key)
|
||||
if err := b.closeFn(server, conn); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// notify without new key
|
||||
b.notify("")
|
||||
}
|
||||
|
||||
func (b *consistentBalancer) IsEmpty() bool {
|
||||
b.lock.Lock()
|
||||
empty := len(b.conns) == 0
|
||||
b.lock.Unlock()
|
||||
|
||||
return empty
|
||||
}
|
||||
|
||||
func (b *consistentBalancer) Next(keys ...string) (interface{}, bool) {
|
||||
if len(keys) != 1 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
key := keys[0]
|
||||
if node, ok := b.buckets.Get(key); !ok {
|
||||
return nil, false
|
||||
} else {
|
||||
return b.getConn(node.(string))
|
||||
}
|
||||
}
|
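A minimal sketch of exercising the consistent balancer (not part of this commit). Since the package is internal, the snippet is written as if it lived inside package internal, e.g. in a test file; the dial/close functions and keys are stand-ins for real connection handling.

// A sketch that would live inside package internal (e.g. in a test file).
func exampleConsistent() (interface{}, bool) {
	b := NewConsistentBalancer(
		func(server string) (interface{}, error) {
			return "conn-to-" + server, nil // stand-in for a real dial
		},
		func(server string, conn interface{}) error {
			return nil // stand-in for closing the real connection
		},
		func(kv KV) string { return kv.Key }, // bucket by the etcd key
	)

	_ = b.AddConn(KV{Key: "user.rpc/1", Val: "10.0.0.1:8080"})
	_ = b.AddConn(KV{Key: "user.rpc/2", Val: "10.0.0.2:8080"})

	// Requests carrying the same hash key keep hitting the same connection.
	return b.Next("request-affinity-key")
}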
178
core/discov/internal/consistentbalancer_test.go
Normal file
@ -0,0 +1,178 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"zero/core/mathx"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConsistent_addConn(t *testing.T) {
|
||||
b := NewConsistentBalancer(func(server string) (interface{}, error) {
|
||||
return mockConn{
|
||||
server: server,
|
||||
}, nil
|
||||
}, func(server string, conn interface{}) error {
|
||||
return errors.New("error")
|
||||
}, func(kv KV) string {
|
||||
return kv.Key
|
||||
})
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey1",
|
||||
Val: "thevalue",
|
||||
}))
|
||||
assert.EqualValues(t, map[string]interface{}{
|
||||
"thekey1": mockConn{server: "thevalue"},
|
||||
}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{
|
||||
"thevalue": {"thekey1"},
|
||||
}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"thekey1": "thevalue",
|
||||
}, b.mapping)
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey2",
|
||||
Val: "thevalue",
|
||||
}))
|
||||
assert.EqualValues(t, map[string]interface{}{
|
||||
"thekey1": mockConn{server: "thevalue"},
|
||||
"thekey2": mockConn{server: "thevalue"},
|
||||
}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{
|
||||
"thevalue": {"thekey1", "thekey2"},
|
||||
}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"thekey1": "thevalue",
|
||||
"thekey2": "thevalue",
|
||||
}, b.mapping)
|
||||
assert.False(t, b.IsEmpty())
|
||||
|
||||
b.RemoveKey("thekey1")
|
||||
assert.EqualValues(t, map[string]interface{}{
|
||||
"thekey2": mockConn{server: "thevalue"},
|
||||
}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{
|
||||
"thevalue": {"thekey2"},
|
||||
}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"thekey2": "thevalue",
|
||||
}, b.mapping)
|
||||
assert.False(t, b.IsEmpty())
|
||||
|
||||
b.RemoveKey("thekey2")
|
||||
assert.Equal(t, 0, len(b.conns))
|
||||
assert.EqualValues(t, map[string][]string{}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{}, b.mapping)
|
||||
assert.True(t, b.IsEmpty())
|
||||
}
|
||||
|
||||
func TestConsistent_addConnError(t *testing.T) {
|
||||
b := NewConsistentBalancer(func(server string) (interface{}, error) {
|
||||
return nil, errors.New("error")
|
||||
}, func(server string, conn interface{}) error {
|
||||
return nil
|
||||
}, func(kv KV) string {
|
||||
return kv.Key
|
||||
})
|
||||
assert.NotNil(t, b.AddConn(KV{
|
||||
Key: "thekey1",
|
||||
Val: "thevalue",
|
||||
}))
|
||||
assert.Equal(t, 0, len(b.conns))
|
||||
assert.EqualValues(t, map[string][]string{}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{}, b.mapping)
|
||||
}
|
||||
|
||||
func TestConsistent_next(t *testing.T) {
|
||||
b := NewConsistentBalancer(func(server string) (interface{}, error) {
|
||||
return mockConn{
|
||||
server: server,
|
||||
}, nil
|
||||
}, func(server string, conn interface{}) error {
|
||||
return errors.New("error")
|
||||
}, func(kv KV) string {
|
||||
return kv.Key
|
||||
})
|
||||
b.initialize()
|
||||
|
||||
_, ok := b.Next("any")
|
||||
assert.False(t, ok)
|
||||
|
||||
const size = 100
|
||||
for i := 0; i < size; i++ {
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey/" + strconv.Itoa(i),
|
||||
Val: "thevalue/" + strconv.Itoa(i),
|
||||
}))
|
||||
}
|
||||
|
||||
m := make(map[interface{}]int)
|
||||
const total = 10000
|
||||
for i := 0; i < total; i++ {
|
||||
val, ok := b.Next(strconv.Itoa(i))
|
||||
assert.True(t, ok)
|
||||
m[val]++
|
||||
}
|
||||
|
||||
entropy := mathx.CalcEntropy(m, total)
|
||||
assert.Equal(t, size, len(m))
|
||||
assert.True(t, entropy > .95)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
b.RemoveKey("thekey/" + strconv.Itoa(i))
|
||||
}
|
||||
_, ok = b.Next()
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestConsistentBalancer_Listener(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
b := NewConsistentBalancer(func(server string) (interface{}, error) {
|
||||
return mockConn{
|
||||
server: server,
|
||||
}, nil
|
||||
}, func(server string, conn interface{}) error {
|
||||
return nil
|
||||
}, func(kv KV) string {
|
||||
return kv.Key
|
||||
})
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "key1",
|
||||
Val: "val1",
|
||||
}))
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "key2",
|
||||
Val: "val2",
|
||||
}))
|
||||
|
||||
listener := NewMockListener(ctrl)
|
||||
listener.EXPECT().OnUpdate(gomock.Any(), gomock.Any(), "key2").Do(func(keys, vals, _ interface{}) {
|
||||
sort.Strings(keys.([]string))
|
||||
sort.Strings(vals.([]string))
|
||||
assert.EqualValues(t, []string{"key1", "key2"}, keys)
|
||||
assert.EqualValues(t, []string{"val1", "val2"}, vals)
|
||||
})
|
||||
b.setListener(listener)
|
||||
b.notify("key2")
|
||||
}
|
||||
|
||||
func TestConsistentBalancer_remove(t *testing.T) {
|
||||
b := NewConsistentBalancer(func(server string) (interface{}, error) {
|
||||
return mockConn{
|
||||
server: server,
|
||||
}, nil
|
||||
}, func(server string, conn interface{}) error {
|
||||
return nil
|
||||
}, func(kv KV) string {
|
||||
return kv.Key
|
||||
})
|
||||
|
||||
assert.Nil(t, b.handlePrevious(nil))
|
||||
assert.Nil(t, b.handlePrevious([]string{"any"}))
|
||||
}
|
21
core/discov/internal/etcdclient.go
Normal file
@ -0,0 +1,21 @@
|
||||
//go:generate mockgen -package internal -destination etcdclient_mock.go -source etcdclient.go EtcdClient
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type EtcdClient interface {
|
||||
ActiveConnection() *grpc.ClientConn
|
||||
Close() error
|
||||
Ctx() context.Context
|
||||
Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error)
|
||||
Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error)
|
||||
KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error)
|
||||
Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error)
|
||||
Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error)
|
||||
Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan
|
||||
}
|
182
core/discov/internal/etcdclient_mock.go
Normal file
@ -0,0 +1,182 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: etcdclient.go
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
|
||||
import (
|
||||
context "context"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
clientv3 "go.etcd.io/etcd/clientv3"
|
||||
grpc "google.golang.org/grpc"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockEtcdClient is a mock of EtcdClient interface
|
||||
type MockEtcdClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockEtcdClientMockRecorder
|
||||
}
|
||||
|
||||
// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient
|
||||
type MockEtcdClientMockRecorder struct {
|
||||
mock *MockEtcdClient
|
||||
}
|
||||
|
||||
// NewMockEtcdClient creates a new mock instance
|
||||
func NewMockEtcdClient(ctrl *gomock.Controller) *MockEtcdClient {
|
||||
mock := &MockEtcdClient{ctrl: ctrl}
|
||||
mock.recorder = &MockEtcdClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockEtcdClient) EXPECT() *MockEtcdClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// ActiveConnection mocks base method
|
||||
func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ActiveConnection")
|
||||
ret0, _ := ret[0].(*grpc.ClientConn)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ActiveConnection indicates an expected call of ActiveConnection
|
||||
func (mr *MockEtcdClientMockRecorder) ActiveConnection() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveConnection", reflect.TypeOf((*MockEtcdClient)(nil).ActiveConnection))
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockEtcdClient) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockEtcdClientMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEtcdClient)(nil).Close))
|
||||
}
|
||||
|
||||
// Ctx mocks base method
|
||||
func (m *MockEtcdClient) Ctx() context.Context {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Ctx")
|
||||
ret0, _ := ret[0].(context.Context)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Ctx indicates an expected call of Ctx
|
||||
func (mr *MockEtcdClientMockRecorder) Ctx() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ctx", reflect.TypeOf((*MockEtcdClient)(nil).Ctx))
|
||||
}
|
||||
|
||||
// Get mocks base method
|
||||
func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{ctx, key}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Get", varargs...)
|
||||
ret0, _ := ret[0].(*clientv3.GetResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get
|
||||
func (mr *MockEtcdClientMockRecorder) Get(ctx, key interface{}, opts ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{ctx, key}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEtcdClient)(nil).Get), varargs...)
|
||||
}
|
||||
|
||||
// Grant mocks base method
|
||||
func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Grant", ctx, ttl)
|
||||
ret0, _ := ret[0].(*clientv3.LeaseGrantResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Grant indicates an expected call of Grant
|
||||
func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Grant", reflect.TypeOf((*MockEtcdClient)(nil).Grant), ctx, ttl)
|
||||
}
|
||||
|
||||
// KeepAlive mocks base method
|
||||
func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeepAlive", ctx, id)
|
||||
ret0, _ := ret[0].(<-chan *clientv3.LeaseKeepAliveResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// KeepAlive indicates an expected call of KeepAlive
|
||||
func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeepAlive", reflect.TypeOf((*MockEtcdClient)(nil).KeepAlive), ctx, id)
|
||||
}
|
||||
|
||||
// Put mocks base method
|
||||
func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{ctx, key, val}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Put", varargs...)
|
||||
ret0, _ := ret[0].(*clientv3.PutResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Put indicates an expected call of Put
|
||||
func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val interface{}, opts ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{ctx, key, val}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEtcdClient)(nil).Put), varargs...)
|
||||
}
|
||||
|
||||
// Revoke mocks base method
|
||||
func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Revoke", ctx, id)
|
||||
ret0, _ := ret[0].(*clientv3.LeaseRevokeResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Revoke indicates an expected call of Revoke
|
||||
func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id)
|
||||
}
|
||||
|
||||
// Watch mocks base method
|
||||
func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{ctx, key}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Watch", varargs...)
|
||||
ret0, _ := ret[0].(clientv3.WatchChan)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Watch indicates an expected call of Watch
|
||||
func (mr *MockEtcdClientMockRecorder) Watch(ctx, key interface{}, opts ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{ctx, key}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockEtcdClient)(nil).Watch), varargs...)
|
||||
}
|
6
core/discov/internal/listener.go
Normal file
@ -0,0 +1,6 @@
|
||||
//go:generate mockgen -package internal -destination listener_mock.go -source listener.go Listener
|
||||
package internal
|
||||
|
||||
type Listener interface {
|
||||
OnUpdate(keys []string, values []string, newKey string)
|
||||
}
|
45
core/discov/internal/listener_mock.go
Normal file
@ -0,0 +1,45 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: listener.go
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockListener is a mock of Listener interface
|
||||
type MockListener struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockListenerMockRecorder
|
||||
}
|
||||
|
||||
// MockListenerMockRecorder is the mock recorder for MockListener
|
||||
type MockListenerMockRecorder struct {
|
||||
mock *MockListener
|
||||
}
|
||||
|
||||
// NewMockListener creates a new mock instance
|
||||
func NewMockListener(ctrl *gomock.Controller) *MockListener {
|
||||
mock := &MockListener{ctrl: ctrl}
|
||||
mock.recorder = &MockListenerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// OnUpdate mocks base method
|
||||
func (m *MockListener) OnUpdate(keys, values []string, newKey string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnUpdate", keys, values, newKey)
|
||||
}
|
||||
|
||||
// OnUpdate indicates an expected call of OnUpdate
|
||||
func (mr *MockListenerMockRecorder) OnUpdate(keys, values, newKey interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnUpdate", reflect.TypeOf((*MockListener)(nil).OnUpdate), keys, values, newKey)
|
||||
}
|
310
core/discov/internal/registry.go
Normal file
@ -0,0 +1,310 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"zero/core/contextx"
|
||||
"zero/core/lang"
|
||||
"zero/core/logx"
|
||||
"zero/core/syncx"
|
||||
"zero/core/threading"
|
||||
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
)
|
||||
|
||||
var (
|
||||
registryInstance = Registry{
|
||||
clusters: make(map[string]*cluster),
|
||||
}
|
||||
connManager = syncx.NewResourceManager()
|
||||
)
|
||||
|
||||
type Registry struct {
|
||||
clusters map[string]*cluster
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func GetRegistry() *Registry {
|
||||
return ®istryInstance
|
||||
}
|
||||
|
||||
func (r *Registry) getCluster(endpoints []string) *cluster {
|
||||
clusterKey := getClusterKey(endpoints)
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
c, ok := r.clusters[clusterKey]
|
||||
if !ok {
|
||||
c = newCluster(endpoints)
|
||||
r.clusters[clusterKey] = c
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) {
|
||||
return r.getCluster(endpoints).getClient()
|
||||
}
|
||||
|
||||
func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener) error {
|
||||
return r.getCluster(endpoints).monitor(key, l)
|
||||
}
|
||||
|
||||
type cluster struct {
|
||||
endpoints []string
|
||||
key string
|
||||
values map[string]map[string]string
|
||||
listeners map[string][]UpdateListener
|
||||
watchGroup *threading.RoutineGroup
|
||||
done chan lang.PlaceholderType
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func newCluster(endpoints []string) *cluster {
|
||||
return &cluster{
|
||||
endpoints: endpoints,
|
||||
key: getClusterKey(endpoints),
|
||||
values: make(map[string]map[string]string),
|
||||
listeners: make(map[string][]UpdateListener),
|
||||
watchGroup: threading.NewRoutineGroup(),
|
||||
done: make(chan lang.PlaceholderType),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) context(cli EtcdClient) context.Context {
|
||||
return contextx.ValueOnlyFrom(cli.Ctx())
|
||||
}
|
||||
|
||||
func (c *cluster) getClient() (EtcdClient, error) {
|
||||
val, err := connManager.GetResource(c.key, func() (io.Closer, error) {
|
||||
return c.newClient()
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return val.(EtcdClient), nil
|
||||
}
|
||||
|
||||
func (c *cluster) handleChanges(key string, kvs []KV) {
|
||||
var add []KV
|
||||
var remove []KV
|
||||
c.lock.Lock()
|
||||
listeners := append([]UpdateListener(nil), c.listeners[key]...)
|
||||
vals, ok := c.values[key]
|
||||
if !ok {
|
||||
add = kvs
|
||||
vals = make(map[string]string)
|
||||
for _, kv := range kvs {
|
||||
vals[kv.Key] = kv.Val
|
||||
}
|
||||
c.values[key] = vals
|
||||
} else {
|
||||
m := make(map[string]string)
|
||||
for _, kv := range kvs {
|
||||
m[kv.Key] = kv.Val
|
||||
}
|
||||
for k, v := range vals {
|
||||
if val, ok := m[k]; !ok || v != val {
|
||||
remove = append(remove, KV{
|
||||
Key: k,
|
||||
Val: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
for k, v := range m {
|
||||
if val, ok := vals[k]; !ok || v != val {
|
||||
add = append(add, KV{
|
||||
Key: k,
|
||||
Val: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
c.values[key] = m
|
||||
}
|
||||
c.lock.Unlock()
|
||||
|
||||
for _, kv := range add {
|
||||
for _, l := range listeners {
|
||||
l.OnAdd(kv)
|
||||
}
|
||||
}
|
||||
for _, kv := range remove {
|
||||
for _, l := range listeners {
|
||||
l.OnDelete(kv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
|
||||
c.lock.Lock()
|
||||
listeners := append([]UpdateListener(nil), c.listeners[key]...)
|
||||
c.lock.Unlock()
|
||||
|
||||
for _, ev := range events {
|
||||
switch ev.Type {
|
||||
case clientv3.EventTypePut:
|
||||
c.lock.Lock()
|
||||
if vals, ok := c.values[key]; ok {
|
||||
vals[string(ev.Kv.Key)] = string(ev.Kv.Value)
|
||||
} else {
|
||||
c.values[key] = map[string]string{string(ev.Kv.Key): string(ev.Kv.Value)}
|
||||
}
|
||||
c.lock.Unlock()
|
||||
for _, l := range listeners {
|
||||
l.OnAdd(KV{
|
||||
Key: string(ev.Kv.Key),
|
||||
Val: string(ev.Kv.Value),
|
||||
})
|
||||
}
|
||||
case clientv3.EventTypeDelete:
|
||||
if vals, ok := c.values[key]; ok {
|
||||
delete(vals, string(ev.Kv.Key))
|
||||
}
|
||||
for _, l := range listeners {
|
||||
l.OnDelete(KV{
|
||||
Key: string(ev.Kv.Key),
|
||||
Val: string(ev.Kv.Value),
|
||||
})
|
||||
}
|
||||
default:
|
||||
logx.Errorf("Unknown event type: %v", ev.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) load(cli EtcdClient, key string) {
|
||||
var resp *clientv3.GetResponse
|
||||
for {
|
||||
var err error
|
||||
ctx, cancel := context.WithTimeout(c.context(cli), RequestTimeout)
|
||||
resp, err = cli.Get(ctx, makeKeyPrefix(key), clientv3.WithPrefix())
|
||||
cancel()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
logx.Error(err)
|
||||
time.Sleep(coolDownInterval)
|
||||
}
|
||||
|
||||
var kvs []KV
|
||||
c.lock.Lock()
|
||||
for _, ev := range resp.Kvs {
|
||||
kvs = append(kvs, KV{
|
||||
Key: string(ev.Key),
|
||||
Val: string(ev.Value),
|
||||
})
|
||||
}
|
||||
c.lock.Unlock()
|
||||
|
||||
c.handleChanges(key, kvs)
|
||||
}
|
||||
|
||||
func (c *cluster) monitor(key string, l UpdateListener) error {
|
||||
c.lock.Lock()
|
||||
c.listeners[key] = append(c.listeners[key], l)
|
||||
c.lock.Unlock()
|
||||
|
||||
cli, err := c.getClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.load(cli, key)
|
||||
c.watchGroup.Run(func() {
|
||||
c.watch(cli, key)
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *cluster) newClient() (EtcdClient, error) {
|
||||
cli, err := NewClient(c.endpoints)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go c.watchConnState(cli)
|
||||
|
||||
return cli, nil
|
||||
}
|
||||
|
||||
func (c *cluster) reload(cli EtcdClient) {
|
||||
c.lock.Lock()
|
||||
close(c.done)
|
||||
c.watchGroup.Wait()
|
||||
c.done = make(chan lang.PlaceholderType)
|
||||
c.watchGroup = threading.NewRoutineGroup()
|
||||
var keys []string
|
||||
for k := range c.listeners {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
c.lock.Unlock()
|
||||
|
||||
for _, key := range keys {
|
||||
k := key
|
||||
c.watchGroup.Run(func() {
|
||||
c.load(cli, k)
|
||||
c.watch(cli, k)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) watch(cli EtcdClient, key string) {
|
||||
rch := cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix())
|
||||
for {
|
||||
select {
|
||||
case wresp, ok := <-rch:
|
||||
if !ok {
|
||||
logx.Error("etcd monitor chan has been closed")
|
||||
return
|
||||
}
|
||||
if wresp.Canceled {
|
||||
logx.Error("etcd monitor chan has been canceled")
|
||||
return
|
||||
}
|
||||
if wresp.Err() != nil {
|
||||
logx.Error(fmt.Sprintf("etcd monitor chan error: %v", wresp.Err()))
|
||||
return
|
||||
}
|
||||
|
||||
c.handleWatchEvents(key, wresp.Events)
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) watchConnState(cli EtcdClient) {
|
||||
watcher := newStateWatcher()
|
||||
watcher.addListener(func() {
|
||||
go c.reload(cli)
|
||||
})
|
||||
watcher.watch(cli.ActiveConnection())
|
||||
}
|
||||
|
||||
func DialClient(endpoints []string) (EtcdClient, error) {
|
||||
return clientv3.New(clientv3.Config{
|
||||
Endpoints: endpoints,
|
||||
AutoSyncInterval: autoSyncInterval,
|
||||
DialTimeout: DialTimeout,
|
||||
DialKeepAliveTime: dialKeepAliveTime,
|
||||
DialKeepAliveTimeout: DialTimeout,
|
||||
RejectOldCluster: true,
|
||||
})
|
||||
}
|
||||
|
||||
func getClusterKey(endpoints []string) string {
|
||||
sort.Strings(endpoints)
|
||||
return strings.Join(endpoints, endpointsSeparator)
|
||||
}
|
||||
|
||||
func makeKeyPrefix(key string) string {
|
||||
return fmt.Sprintf("%s%c", key, Delimiter)
|
||||
}
|
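A minimal sketch of consuming the registry above (not part of this commit), again written as in-package code because the package is internal; kvCounter is a made-up listener that assumes UpdateListener has the OnAdd(KV)/OnDelete(KV) shape used in this file.

// kvCounter is a hypothetical UpdateListener that just counts changes.
type kvCounter struct {
	adds, deletes int
}

func (l *kvCounter) OnAdd(kv KV)    { l.adds++ }
func (l *kvCounter) OnDelete(kv KV) { l.deletes++ }

func exampleMonitor() error {
	endpoints := []string{"127.0.0.1:2379"} // placeholder etcd endpoint
	// Monitor loads the current values under "user.rpc" and keeps watching them,
	// reloading keys when the underlying connection changes state.
	return GetRegistry().Monitor(endpoints, "user.rpc", &kvCounter{})
}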
245
core/discov/internal/registry_test.go
Normal file
@ -0,0 +1,245 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"zero/core/contextx"
|
||||
"zero/core/stringx"
|
||||
|
||||
"github.com/coreos/etcd/mvcc/mvccpb"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
)
|
||||
|
||||
var mockLock sync.Mutex
|
||||
|
||||
func setMockClient(cli EtcdClient) func() {
|
||||
mockLock.Lock()
|
||||
NewClient = func([]string) (EtcdClient, error) {
|
||||
return cli, nil
|
||||
}
|
||||
return func() {
|
||||
NewClient = DialClient
|
||||
mockLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCluster(t *testing.T) {
|
||||
c1 := GetRegistry().getCluster([]string{"first"})
|
||||
c2 := GetRegistry().getCluster([]string{"second"})
|
||||
c3 := GetRegistry().getCluster([]string{"first"})
|
||||
assert.Equal(t, c1, c3)
|
||||
assert.NotEqual(t, c1, c2)
|
||||
}
|
||||
|
||||
func TestGetClusterKey(t *testing.T) {
|
||||
assert.Equal(t, getClusterKey([]string{"localhost:1234", "remotehost:5678"}),
|
||||
getClusterKey([]string{"remotehost:5678", "localhost:1234"}))
|
||||
}
|
||||
|
||||
func TestCluster_HandleChanges(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
l := NewMockUpdateListener(ctrl)
|
||||
l.EXPECT().OnAdd(KV{
|
||||
Key: "first",
|
||||
Val: "1",
|
||||
})
|
||||
l.EXPECT().OnAdd(KV{
|
||||
Key: "second",
|
||||
Val: "2",
|
||||
})
|
||||
l.EXPECT().OnDelete(KV{
|
||||
Key: "first",
|
||||
Val: "1",
|
||||
})
|
||||
l.EXPECT().OnDelete(KV{
|
||||
Key: "second",
|
||||
Val: "2",
|
||||
})
|
||||
l.EXPECT().OnAdd(KV{
|
||||
Key: "third",
|
||||
Val: "3",
|
||||
})
|
||||
l.EXPECT().OnAdd(KV{
|
||||
Key: "fourth",
|
||||
Val: "4",
|
||||
})
|
||||
c := newCluster([]string{"any"})
|
||||
c.listeners["any"] = []UpdateListener{l}
|
||||
c.handleChanges("any", []KV{
|
||||
{
|
||||
Key: "first",
|
||||
Val: "1",
|
||||
},
|
||||
{
|
||||
Key: "second",
|
||||
Val: "2",
|
||||
},
|
||||
})
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"first": "1",
|
||||
"second": "2",
|
||||
}, c.values["any"])
|
||||
c.handleChanges("any", []KV{
|
||||
{
|
||||
Key: "third",
|
||||
Val: "3",
|
||||
},
|
||||
{
|
||||
Key: "fourth",
|
||||
Val: "4",
|
||||
},
|
||||
})
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"third": "3",
|
||||
"fourth": "4",
|
||||
}, c.values["any"])
|
||||
}
|
||||
|
||||
func TestCluster_Load(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
cli := NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Get(gomock.Any(), "any/", gomock.Any()).Return(&clientv3.GetResponse{
|
||||
Kvs: []*mvccpb.KeyValue{
|
||||
{
|
||||
Key: []byte("hello"),
|
||||
Value: []byte("world"),
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
cli.EXPECT().Ctx().Return(context.Background())
|
||||
c := &cluster{
|
||||
values: make(map[string]map[string]string),
|
||||
}
|
||||
c.load(cli, "any")
|
||||
}
|
||||
|
||||
func TestCluster_Watch(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method int
|
||||
eventType mvccpb.Event_EventType
|
||||
}{
|
||||
{
|
||||
name: "add",
|
||||
eventType: clientv3.EventTypePut,
|
||||
},
|
||||
{
|
||||
name: "delete",
|
||||
eventType: clientv3.EventTypeDelete,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
cli := NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
ch := make(chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch)
|
||||
cli.EXPECT().Ctx().Return(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
c := &cluster{
|
||||
listeners: make(map[string][]UpdateListener),
|
||||
values: make(map[string]map[string]string),
|
||||
}
|
||||
listener := NewMockUpdateListener(ctrl)
|
||||
c.listeners["any"] = []UpdateListener{listener}
|
||||
listener.EXPECT().OnAdd(gomock.Any()).Do(func(kv KV) {
|
||||
assert.Equal(t, "hello", kv.Key)
|
||||
assert.Equal(t, "world", kv.Val)
|
||||
wg.Done()
|
||||
}).MaxTimes(1)
|
||||
listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ interface{}) {
|
||||
wg.Done()
|
||||
}).MaxTimes(1)
|
||||
go c.watch(cli, "any")
|
||||
ch <- clientv3.WatchResponse{
|
||||
Events: []*clientv3.Event{
|
||||
{
|
||||
Type: test.eventType,
|
||||
Kv: &mvccpb.KeyValue{
|
||||
Key: []byte("hello"),
|
||||
Value: []byte("world"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClusterWatch_RespFailures(t *testing.T) {
|
||||
resps := []clientv3.WatchResponse{
|
||||
{
|
||||
Canceled: true,
|
||||
},
|
||||
{
|
||||
// causes resp.Err() to return a non-nil error
|
||||
CompactRevision: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, resp := range resps {
|
||||
t.Run(stringx.Rand(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
cli := NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
ch := make(chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch)
|
||||
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
|
||||
c := new(cluster)
|
||||
go func() {
|
||||
ch <- resp
|
||||
}()
|
||||
c.watch(cli, "any")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClusterWatch_CloseChan(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
cli := NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
ch := make(chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch)
|
||||
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
|
||||
c := new(cluster)
|
||||
go func() {
|
||||
close(ch)
|
||||
}()
|
||||
c.watch(cli, "any")
|
||||
}
|
||||
|
||||
func TestValueOnlyContext(t *testing.T) {
|
||||
ctx := contextx.ValueOnlyFrom(context.Background())
|
||||
ctx.Done()
|
||||
assert.Nil(t, ctx.Err())
|
||||
}
|
||||
|
||||
type mockedSharedCalls struct {
|
||||
fn func() (interface{}, error)
|
||||
}
|
||||
|
||||
func (c mockedSharedCalls) Do(_ string, fn func() (interface{}, error)) (interface{}, error) {
|
||||
return c.fn()
|
||||
}
|
||||
|
||||
func (c mockedSharedCalls) DoEx(_ string, fn func() (interface{}, error)) (interface{}, bool, error) {
|
||||
val, err := c.fn()
|
||||
return val, true, err
|
||||
}
|
148
core/discov/internal/roundrobinbalancer.go
Normal file
@ -0,0 +1,148 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
type roundRobinBalancer struct {
|
||||
*baseBalancer
|
||||
conns []serverConn
|
||||
index int
|
||||
}
|
||||
|
||||
func NewRoundRobinBalancer(dialFn DialFn, closeFn CloseFn, exclusive bool) *roundRobinBalancer {
|
||||
balancer := new(roundRobinBalancer)
|
||||
balancer.baseBalancer = newBaseBalancer(dialFn, closeFn, exclusive)
|
||||
return balancer
|
||||
}
|
||||
|
||||
func (b *roundRobinBalancer) AddConn(kv KV) error {
|
||||
var conn interface{}
|
||||
prev, found := b.addKv(kv.Key, kv.Val)
|
||||
if found {
|
||||
conn = b.handlePrevious(prev, kv.Val)
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
var err error
|
||||
conn, err = b.dialFn(kv.Val)
|
||||
if err != nil {
|
||||
b.removeKv(kv.Key)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
b.conns = append(b.conns, serverConn{
|
||||
key: kv.Key,
|
||||
conn: conn,
|
||||
})
|
||||
b.notify(kv.Key)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *roundRobinBalancer) handlePrevious(prev []string, server string) interface{} {
|
||||
if len(prev) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
if b.exclusive {
|
||||
for _, item := range prev {
|
||||
conns := b.conns[:0]
|
||||
for _, each := range b.conns {
|
||||
if each.key == item {
|
||||
if err := b.closeFn(server, each.conn); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
} else {
|
||||
conns = append(conns, each)
|
||||
}
|
||||
}
|
||||
b.conns = conns
|
||||
}
|
||||
} else {
|
||||
for _, each := range b.conns {
|
||||
if each.key == prev[0] {
|
||||
return each.conn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *roundRobinBalancer) initialize() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
if len(b.conns) > 0 {
|
||||
b.index = rand.Intn(len(b.conns))
|
||||
}
|
||||
}
|
||||
|
||||
func (b *roundRobinBalancer) IsEmpty() bool {
|
||||
b.lock.Lock()
|
||||
empty := len(b.conns) == 0
|
||||
b.lock.Unlock()
|
||||
|
||||
return empty
|
||||
}
|
||||
|
||||
func (b *roundRobinBalancer) Next(...string) (interface{}, bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
if len(b.conns) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
b.index = (b.index + 1) % len(b.conns)
|
||||
return b.conns[b.index].conn, true
|
||||
}
|
||||
|
||||
func (b *roundRobinBalancer) notify(key string) {
|
||||
if b.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// b.servers has the format of map[server][]key
|
||||
var keys []string
|
||||
var values []string
|
||||
for k, v := range b.servers {
|
||||
values = append(values, k)
|
||||
keys = append(keys, v...)
|
||||
}
|
||||
|
||||
b.listener.OnUpdate(keys, values, key)
|
||||
}
|
||||
|
||||
func (b *roundRobinBalancer) RemoveKey(key string) {
|
||||
server, keep := b.removeKv(key)
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
conns := b.conns[:0]
|
||||
for _, conn := range b.conns {
|
||||
if conn.key == key {
|
||||
// other keys are still associated with the conn, so don't close it.
|
||||
if keep {
|
||||
continue
|
||||
}
|
||||
if err := b.closeFn(server, conn.conn); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
} else {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
}
|
||||
b.conns = conns
|
||||
// notify without new key
|
||||
b.notify("")
|
||||
}
|
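A minimal sketch of the round-robin balancer (not part of this commit), written as in-package code; the dial/close functions are stand-ins for real connection handling.

// A sketch that would live inside package internal (e.g. in a test file).
func exampleRoundRobin() {
	b := NewRoundRobinBalancer(
		func(server string) (interface{}, error) {
			return "conn-to-" + server, nil // stand-in for a real dial
		},
		func(server string, conn interface{}) error {
			return nil // stand-in for closing the real connection
		},
		false, // non-exclusive: multiple keys may share one server
	)
	b.initialize() // seeds the starting index randomly

	_ = b.AddConn(KV{Key: "user.rpc/1", Val: "10.0.0.1:8080"})
	_ = b.AddConn(KV{Key: "user.rpc/2", Val: "10.0.0.2:8080"})

	// Successive calls rotate over the registered connections.
	first, _ := b.Next()
	second, _ := b.Next()
	_, _ = first, second
}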
321
core/discov/internal/roundrobinbalancer_test.go
Normal file
@ -0,0 +1,321 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/mathx"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logx.Disable()
|
||||
}
|
||||
|
||||
func TestRoundRobin_addConn(t *testing.T) {
|
||||
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
|
||||
return mockConn{
|
||||
server: server,
|
||||
}, nil
|
||||
}, func(server string, conn interface{}) error {
|
||||
return errors.New("error")
|
||||
}, false)
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey1",
|
||||
Val: "thevalue",
|
||||
}))
|
||||
assert.EqualValues(t, []serverConn{
|
||||
{
|
||||
key: "thekey1",
|
||||
conn: mockConn{server: "thevalue"},
|
||||
},
|
||||
}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{
|
||||
"thevalue": {"thekey1"},
|
||||
}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"thekey1": "thevalue",
|
||||
}, b.mapping)
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey2",
|
||||
Val: "thevalue",
|
||||
}))
|
||||
assert.EqualValues(t, []serverConn{
|
||||
{
|
||||
key: "thekey1",
|
||||
conn: mockConn{server: "thevalue"},
|
||||
},
|
||||
{
|
||||
key: "thekey2",
|
||||
conn: mockConn{server: "thevalue"},
|
||||
},
|
||||
}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{
|
||||
"thevalue": {"thekey1", "thekey2"},
|
||||
}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"thekey1": "thevalue",
|
||||
"thekey2": "thevalue",
|
||||
}, b.mapping)
|
||||
assert.False(t, b.IsEmpty())
|
||||
|
||||
b.RemoveKey("thekey1")
|
||||
assert.EqualValues(t, []serverConn{
|
||||
{
|
||||
key: "thekey2",
|
||||
conn: mockConn{server: "thevalue"},
|
||||
},
|
||||
}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{
|
||||
"thevalue": {"thekey2"},
|
||||
}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"thekey2": "thevalue",
|
||||
}, b.mapping)
|
||||
assert.False(t, b.IsEmpty())
|
||||
|
||||
b.RemoveKey("thekey2")
|
||||
assert.EqualValues(t, []serverConn{}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{}, b.mapping)
|
||||
assert.True(t, b.IsEmpty())
|
||||
}
|
||||
|
||||
func TestRoundRobin_addConnExclusive(t *testing.T) {
|
||||
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
|
||||
return mockConn{
|
||||
server: server,
|
||||
}, nil
|
||||
}, func(server string, conn interface{}) error {
|
||||
return nil
|
||||
}, true)
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey1",
|
||||
Val: "thevalue",
|
||||
}))
|
||||
assert.EqualValues(t, []serverConn{
|
||||
{
|
||||
key: "thekey1",
|
||||
conn: mockConn{server: "thevalue"},
|
||||
},
|
||||
}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{
|
||||
"thevalue": {"thekey1"},
|
||||
}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"thekey1": "thevalue",
|
||||
}, b.mapping)
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey2",
|
||||
Val: "thevalue",
|
||||
}))
|
||||
assert.EqualValues(t, []serverConn{
|
||||
{
|
||||
key: "thekey2",
|
||||
conn: mockConn{server: "thevalue"},
|
||||
},
|
||||
}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{
|
||||
"thevalue": {"thekey2"},
|
||||
}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"thekey2": "thevalue",
|
||||
}, b.mapping)
|
||||
assert.False(t, b.IsEmpty())
|
||||
|
||||
b.RemoveKey("thekey1")
|
||||
b.RemoveKey("thekey2")
|
||||
assert.EqualValues(t, []serverConn{}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{}, b.mapping)
|
||||
assert.True(t, b.IsEmpty())
|
||||
}
|
||||
|
||||
func TestRoundRobin_addConnDupExclusive(t *testing.T) {
|
||||
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
|
||||
return mockConn{
|
||||
server: server,
|
||||
}, nil
|
||||
}, func(server string, conn interface{}) error {
|
||||
return errors.New("error")
|
||||
}, true)
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey1",
|
||||
Val: "thevalue",
|
||||
}))
|
||||
assert.EqualValues(t, []serverConn{
|
||||
{
|
||||
key: "thekey1",
|
||||
conn: mockConn{server: "thevalue"},
|
||||
},
|
||||
}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{
|
||||
"thevalue": {"thekey1"},
|
||||
}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"thekey1": "thevalue",
|
||||
}, b.mapping)
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey",
|
||||
Val: "anothervalue",
|
||||
}))
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey1",
|
||||
Val: "thevalue",
|
||||
}))
|
||||
assert.EqualValues(t, []serverConn{
|
||||
{
|
||||
key: "thekey",
|
||||
conn: mockConn{server: "anothervalue"},
|
||||
},
|
||||
{
|
||||
key: "thekey1",
|
||||
conn: mockConn{server: "thevalue"},
|
||||
},
|
||||
}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{
|
||||
"anothervalue": {"thekey"},
|
||||
"thevalue": {"thekey1"},
|
||||
}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"thekey": "anothervalue",
|
||||
"thekey1": "thevalue",
|
||||
}, b.mapping)
|
||||
assert.False(t, b.IsEmpty())
|
||||
|
||||
b.RemoveKey("thekey")
|
||||
b.RemoveKey("thekey1")
|
||||
assert.EqualValues(t, []serverConn{}, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{}, b.mapping)
|
||||
assert.True(t, b.IsEmpty())
|
||||
}
|
||||
|
||||
func TestRoundRobin_addConnError(t *testing.T) {
|
||||
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
|
||||
return nil, errors.New("error")
|
||||
}, func(server string, conn interface{}) error {
|
||||
return nil
|
||||
}, true)
|
||||
assert.NotNil(t, b.AddConn(KV{
|
||||
Key: "thekey1",
|
||||
Val: "thevalue",
|
||||
}))
|
||||
assert.Nil(t, b.conns)
|
||||
assert.EqualValues(t, map[string][]string{}, b.servers)
|
||||
assert.EqualValues(t, map[string]string{}, b.mapping)
|
||||
}
|
||||
|
||||
func TestRoundRobin_initialize(t *testing.T) {
|
||||
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
|
||||
return mockConn{
|
||||
server: server,
|
||||
}, nil
|
||||
}, func(server string, conn interface{}) error {
|
||||
return nil
|
||||
}, true)
|
||||
for i := 0; i < 100; i++ {
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey/" + strconv.Itoa(i),
|
||||
Val: "thevalue/" + strconv.Itoa(i),
|
||||
}))
|
||||
}
|
||||
|
||||
m := make(map[int]int)
|
||||
const total = 1000
|
||||
for i := 0; i < total; i++ {
|
||||
b.initialize()
|
||||
m[b.index]++
|
||||
}
|
||||
|
||||
mi := make(map[interface{}]int, len(m))
|
||||
for k, v := range m {
|
||||
mi[k] = v
|
||||
}
|
||||
entropy := mathx.CalcEntropy(mi, total)
|
||||
assert.True(t, entropy > .95)
|
||||
}
|
||||
|
||||
func TestRoundRobin_next(t *testing.T) {
|
||||
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
|
||||
return mockConn{
|
||||
server: server,
|
||||
}, nil
|
||||
}, func(server string, conn interface{}) error {
|
||||
return errors.New("error")
|
||||
}, true)
|
||||
const size = 100
|
||||
for i := 0; i < size; i++ {
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "thekey/" + strconv.Itoa(i),
|
||||
Val: "thevalue/" + strconv.Itoa(i),
|
||||
}))
|
||||
}
|
||||
|
||||
m := make(map[interface{}]int)
|
||||
const total = 10000
|
||||
for i := 0; i < total; i++ {
|
||||
val, ok := b.Next()
|
||||
assert.True(t, ok)
|
||||
m[val]++
|
||||
}
|
||||
|
||||
entropy := mathx.CalcEntropy(m, total)
|
||||
assert.Equal(t, size, len(m))
|
||||
assert.True(t, entropy > .95)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
b.RemoveKey("thekey/" + strconv.Itoa(i))
|
||||
}
|
||||
_, ok := b.Next()
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestRoundRobinBalancer_Listener(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
|
||||
return mockConn{
|
||||
server: server,
|
||||
}, nil
|
||||
}, func(server string, conn interface{}) error {
|
||||
return nil
|
||||
}, true)
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "key1",
|
||||
Val: "val1",
|
||||
}))
|
||||
assert.Nil(t, b.AddConn(KV{
|
||||
Key: "key2",
|
||||
Val: "val2",
|
||||
}))
|
||||
|
||||
listener := NewMockListener(ctrl)
|
||||
listener.EXPECT().OnUpdate(gomock.Any(), gomock.Any(), "key2").Do(func(keys, vals, _ interface{}) {
|
||||
sort.Strings(vals.([]string))
|
||||
sort.Strings(keys.([]string))
|
||||
assert.EqualValues(t, []string{"key1", "key2"}, keys)
|
||||
assert.EqualValues(t, []string{"val1", "val2"}, vals)
|
||||
})
|
||||
b.setListener(listener)
|
||||
b.notify("key2")
|
||||
}
|
||||
|
||||
func TestRoundRobinBalancer_remove(t *testing.T) {
|
||||
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
|
||||
return mockConn{
|
||||
server: server,
|
||||
}, nil
|
||||
}, func(server string, conn interface{}) error {
|
||||
return nil
|
||||
}, true)
|
||||
|
||||
assert.Nil(t, b.handlePrevious(nil, "any"))
|
||||
_, ok := b.doRemoveKv("any")
|
||||
assert.True(t, ok)
|
||||
}
|
58
core/discov/internal/statewatcher.go
Normal file
@ -0,0 +1,58 @@
|
||||
//go:generate mockgen -package internal -destination statewatcher_mock.go -source statewatcher.go etcdConn
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/grpc/connectivity"
|
||||
)
|
||||
|
||||
type (
|
||||
etcdConn interface {
|
||||
GetState() connectivity.State
|
||||
WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool
|
||||
}
|
||||
|
||||
stateWatcher struct {
|
||||
disconnected bool
|
||||
currentState connectivity.State
|
||||
listeners []func()
|
||||
lock sync.Mutex
|
||||
}
|
||||
)
|
||||
|
||||
func newStateWatcher() *stateWatcher {
|
||||
return new(stateWatcher)
|
||||
}
|
||||
|
||||
func (sw *stateWatcher) addListener(l func()) {
|
||||
sw.lock.Lock()
|
||||
sw.listeners = append(sw.listeners, l)
|
||||
sw.lock.Unlock()
|
||||
}
|
||||
|
||||
func (sw *stateWatcher) watch(conn etcdConn) {
|
||||
sw.currentState = conn.GetState()
|
||||
for {
|
||||
if conn.WaitForStateChange(context.Background(), sw.currentState) {
|
||||
newState := conn.GetState()
|
||||
sw.lock.Lock()
|
||||
sw.currentState = newState
|
||||
|
||||
switch newState {
|
||||
case connectivity.TransientFailure, connectivity.Shutdown:
|
||||
sw.disconnected = true
|
||||
case connectivity.Ready:
|
||||
if sw.disconnected {
|
||||
sw.disconnected = false
|
||||
for _, l := range sw.listeners {
|
||||
l()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sw.lock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
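A minimal sketch of how the watcher above could be wired up inside this package; it assumes conn satisfies etcdConn, for example a *grpc.ClientConn, which provides GetState and WaitForStateChange:

// a sketch inside package internal; onReconnect is a hypothetical callback
func watchConnState(conn etcdConn, onReconnect func()) {
	watcher := newStateWatcher()
	// listeners fire only after TransientFailure/Shutdown is followed by Ready
	watcher.addListener(onReconnect)
	go watcher.watch(conn)
}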
63
core/discov/internal/statewatcher_mock.go
Normal file
@ -0,0 +1,63 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: statewatcher.go
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
|
||||
import (
|
||||
context "context"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
connectivity "google.golang.org/grpc/connectivity"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MocketcdConn is a mock of etcdConn interface
|
||||
type MocketcdConn struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MocketcdConnMockRecorder
|
||||
}
|
||||
|
||||
// MocketcdConnMockRecorder is the mock recorder for MocketcdConn
|
||||
type MocketcdConnMockRecorder struct {
|
||||
mock *MocketcdConn
|
||||
}
|
||||
|
||||
// NewMocketcdConn creates a new mock instance
|
||||
func NewMocketcdConn(ctrl *gomock.Controller) *MocketcdConn {
|
||||
mock := &MocketcdConn{ctrl: ctrl}
|
||||
mock.recorder = &MocketcdConnMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MocketcdConn) EXPECT() *MocketcdConnMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetState mocks base method
|
||||
func (m *MocketcdConn) GetState() connectivity.State {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetState")
|
||||
ret0, _ := ret[0].(connectivity.State)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetState indicates an expected call of GetState
|
||||
func (mr *MocketcdConnMockRecorder) GetState() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MocketcdConn)(nil).GetState))
|
||||
}
|
||||
|
||||
// WaitForStateChange mocks base method
|
||||
func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "WaitForStateChange", ctx, sourceState)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// WaitForStateChange indicates an expected call of WaitForStateChange
|
||||
func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForStateChange", reflect.TypeOf((*MocketcdConn)(nil).WaitForStateChange), ctx, sourceState)
|
||||
}
|
27
core/discov/internal/statewatcher_test.go
Normal file
@ -0,0 +1,27 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
)
|
||||
|
||||
func TestStateWatcher_watch(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
watcher := newStateWatcher()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
watcher.addListener(func() {
|
||||
wg.Done()
|
||||
})
|
||||
conn := NewMocketcdConn(ctrl)
|
||||
conn.EXPECT().GetState().Return(connectivity.Ready)
|
||||
conn.EXPECT().GetState().Return(connectivity.TransientFailure)
|
||||
conn.EXPECT().GetState().Return(connectivity.Ready).AnyTimes()
|
||||
conn.EXPECT().WaitForStateChange(gomock.Any(), gomock.Any()).Return(true).AnyTimes()
|
||||
go watcher.watch(conn)
|
||||
wg.Wait()
|
||||
}
|
14
core/discov/internal/updatelistener.go
Normal file
@ -0,0 +1,14 @@
|
||||
//go:generate mockgen -package internal -destination updatelistener_mock.go -source updatelistener.go UpdateListener
|
||||
package internal
|
||||
|
||||
type (
|
||||
KV struct {
|
||||
Key string
|
||||
Val string
|
||||
}
|
||||
|
||||
UpdateListener interface {
|
||||
OnAdd(kv KV)
|
||||
OnDelete(kv KV)
|
||||
}
|
||||
)
|
57
core/discov/internal/updatelistener_mock.go
Normal file
@ -0,0 +1,57 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: updatelistener.go
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockUpdateListener is a mock of UpdateListener interface
|
||||
type MockUpdateListener struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUpdateListenerMockRecorder
|
||||
}
|
||||
|
||||
// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener
|
||||
type MockUpdateListenerMockRecorder struct {
|
||||
mock *MockUpdateListener
|
||||
}
|
||||
|
||||
// NewMockUpdateListener creates a new mock instance
|
||||
func NewMockUpdateListener(ctrl *gomock.Controller) *MockUpdateListener {
|
||||
mock := &MockUpdateListener{ctrl: ctrl}
|
||||
mock.recorder = &MockUpdateListenerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockUpdateListener) EXPECT() *MockUpdateListenerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// OnAdd mocks base method
|
||||
func (m *MockUpdateListener) OnAdd(kv KV) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnAdd", kv)
|
||||
}
|
||||
|
||||
// OnAdd indicates an expected call of OnAdd
|
||||
func (mr *MockUpdateListenerMockRecorder) OnAdd(kv interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAdd", reflect.TypeOf((*MockUpdateListener)(nil).OnAdd), kv)
|
||||
}
|
||||
|
||||
// OnDelete mocks base method
|
||||
func (m *MockUpdateListener) OnDelete(kv KV) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnDelete", kv)
|
||||
}
|
||||
|
||||
// OnDelete indicates an expected call of OnDelete
|
||||
func (mr *MockUpdateListenerMockRecorder) OnDelete(kv interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDelete", reflect.TypeOf((*MockUpdateListener)(nil).OnDelete), kv)
|
||||
}
|
19
core/discov/internal/vars.go
Normal file
@ -0,0 +1,19 @@
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
autoSyncInterval = time.Minute
|
||||
coolDownInterval = time.Second
|
||||
dialTimeout = 5 * time.Second
|
||||
dialKeepAliveTime = 5 * time.Second
|
||||
requestTimeout = 3 * time.Second
|
||||
Delimiter = '/'
|
||||
endpointsSeparator = ","
|
||||
)
|
||||
|
||||
var (
|
||||
DialTimeout = dialTimeout
|
||||
RequestTimeout = requestTimeout
|
||||
NewClient = DialClient
|
||||
)
|
4
core/discov/kubernetes/discov-namespace.yaml
Normal file
@ -0,0 +1,4 @@
|
||||
apiVersion: v1
|
||||
kind: Namespace
|
||||
metadata:
|
||||
name: discov
|
378
core/discov/kubernetes/etcd.yaml
Normal file
@ -0,0 +1,378 @@
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: etcd
|
||||
namespace: discov
|
||||
spec:
|
||||
ports:
|
||||
- name: etcd-port
|
||||
port: 2379
|
||||
protocol: TCP
|
||||
targetPort: 2379
|
||||
selector:
|
||||
app: etcd
|
||||
|
||||
---
|
||||
|
||||
apiVersion: v1
|
||||
kind: Pod
|
||||
metadata:
|
||||
labels:
|
||||
app: etcd
|
||||
etcd_node: etcd0
|
||||
name: etcd0
|
||||
namespace: discov
|
||||
spec:
|
||||
containers:
|
||||
- command:
|
||||
- /usr/local/bin/etcd
|
||||
- --name
|
||||
- etcd0
|
||||
- --initial-advertise-peer-urls
|
||||
- http://etcd0:2380
|
||||
- --listen-peer-urls
|
||||
- http://0.0.0.0:2380
|
||||
- --listen-client-urls
|
||||
- http://0.0.0.0:2379
|
||||
- --advertise-client-urls
|
||||
- http://etcd0:2379
|
||||
- --initial-cluster
|
||||
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
|
||||
- --initial-cluster-state
|
||||
- new
|
||||
image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest
|
||||
name: etcd0
|
||||
ports:
|
||||
- containerPort: 2379
|
||||
name: client
|
||||
protocol: TCP
|
||||
- containerPort: 2380
|
||||
name: server
|
||||
protocol: TCP
|
||||
imagePullSecrets:
|
||||
- name: aliyun
|
||||
affinity:
|
||||
podAntiAffinity:
|
||||
requiredDuringSchedulingIgnoredDuringExecution:
|
||||
- labelSelector:
|
||||
matchExpressions:
|
||||
- key: app
|
||||
operator: In
|
||||
values:
|
||||
- etcd
|
||||
topologyKey: "kubernetes.io/hostname"
|
||||
restartPolicy: Always
|
||||
|
||||
---
|
||||
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
labels:
|
||||
etcd_node: etcd0
|
||||
name: etcd0
|
||||
namespace: discov
|
||||
spec:
|
||||
ports:
|
||||
- name: client
|
||||
port: 2379
|
||||
protocol: TCP
|
||||
targetPort: 2379
|
||||
- name: server
|
||||
port: 2380
|
||||
protocol: TCP
|
||||
targetPort: 2380
|
||||
selector:
|
||||
etcd_node: etcd0
|
||||
|
||||
---
|
||||
|
||||
apiVersion: v1
|
||||
kind: Pod
|
||||
metadata:
|
||||
labels:
|
||||
app: etcd
|
||||
etcd_node: etcd1
|
||||
name: etcd1
|
||||
namespace: discov
|
||||
spec:
|
||||
containers:
|
||||
- command:
|
||||
- /usr/local/bin/etcd
|
||||
- --name
|
||||
- etcd1
|
||||
- --initial-advertise-peer-urls
|
||||
- http://etcd1:2380
|
||||
- --listen-peer-urls
|
||||
- http://0.0.0.0:2380
|
||||
- --listen-client-urls
|
||||
- http://0.0.0.0:2379
|
||||
- --advertise-client-urls
|
||||
- http://etcd1:2379
|
||||
- --initial-cluster
|
||||
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
|
||||
- --initial-cluster-state
|
||||
- new
|
||||
image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest
|
||||
name: etcd1
|
||||
ports:
|
||||
- containerPort: 2379
|
||||
name: client
|
||||
protocol: TCP
|
||||
- containerPort: 2380
|
||||
name: server
|
||||
protocol: TCP
|
||||
imagePullSecrets:
|
||||
- name: aliyun
|
||||
affinity:
|
||||
podAntiAffinity:
|
||||
requiredDuringSchedulingIgnoredDuringExecution:
|
||||
- labelSelector:
|
||||
matchExpressions:
|
||||
- key: app
|
||||
operator: In
|
||||
values:
|
||||
- etcd
|
||||
topologyKey: "kubernetes.io/hostname"
|
||||
restartPolicy: Always
|
||||
|
||||
---
|
||||
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
labels:
|
||||
etcd_node: etcd1
|
||||
name: etcd1
|
||||
namespace: discov
|
||||
spec:
|
||||
ports:
|
||||
- name: client
|
||||
port: 2379
|
||||
protocol: TCP
|
||||
targetPort: 2379
|
||||
- name: server
|
||||
port: 2380
|
||||
protocol: TCP
|
||||
targetPort: 2380
|
||||
selector:
|
||||
etcd_node: etcd1
|
||||
|
||||
---
|
||||
|
||||
apiVersion: v1
|
||||
kind: Pod
|
||||
metadata:
|
||||
labels:
|
||||
app: etcd
|
||||
etcd_node: etcd2
|
||||
name: etcd2
|
||||
namespace: discov
|
||||
spec:
|
||||
containers:
|
||||
- command:
|
||||
- /usr/local/bin/etcd
|
||||
- --name
|
||||
- etcd2
|
||||
- --initial-advertise-peer-urls
|
||||
- http://etcd2:2380
|
||||
- --listen-peer-urls
|
||||
- http://0.0.0.0:2380
|
||||
- --listen-client-urls
|
||||
- http://0.0.0.0:2379
|
||||
- --advertise-client-urls
|
||||
- http://etcd2:2379
|
||||
- --initial-cluster
|
||||
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
|
||||
- --initial-cluster-state
|
||||
- new
|
||||
image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest
|
||||
name: etcd2
|
||||
ports:
|
||||
- containerPort: 2379
|
||||
name: client
|
||||
protocol: TCP
|
||||
- containerPort: 2380
|
||||
name: server
|
||||
protocol: TCP
|
||||
imagePullSecrets:
|
||||
- name: aliyun
|
||||
affinity:
|
||||
podAntiAffinity:
|
||||
requiredDuringSchedulingIgnoredDuringExecution:
|
||||
- labelSelector:
|
||||
matchExpressions:
|
||||
- key: app
|
||||
operator: In
|
||||
values:
|
||||
- etcd
|
||||
topologyKey: "kubernetes.io/hostname"
|
||||
restartPolicy: Always
|
||||
|
||||
---
|
||||
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
labels:
|
||||
etcd_node: etcd2
|
||||
name: etcd2
|
||||
namespace: discov
|
||||
spec:
|
||||
ports:
|
||||
- name: client
|
||||
port: 2379
|
||||
protocol: TCP
|
||||
targetPort: 2379
|
||||
- name: server
|
||||
port: 2380
|
||||
protocol: TCP
|
||||
targetPort: 2380
|
||||
selector:
|
||||
etcd_node: etcd2
|
||||
|
||||
---
|
||||
|
||||
apiVersion: v1
|
||||
kind: Pod
|
||||
metadata:
|
||||
labels:
|
||||
app: etcd
|
||||
etcd_node: etcd3
|
||||
name: etcd3
|
||||
namespace: discov
|
||||
spec:
|
||||
containers:
|
||||
- command:
|
||||
- /usr/local/bin/etcd
|
||||
- --name
|
||||
- etcd3
|
||||
- --initial-advertise-peer-urls
|
||||
- http://etcd3:2380
|
||||
- --listen-peer-urls
|
||||
- http://0.0.0.0:2380
|
||||
- --listen-client-urls
|
||||
- http://0.0.0.0:2379
|
||||
- --advertise-client-urls
|
||||
- http://etcd3:2379
|
||||
- --initial-cluster
|
||||
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
|
||||
- --initial-cluster-state
|
||||
- new
|
||||
image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest
|
||||
name: etcd3
|
||||
ports:
|
||||
- containerPort: 2379
|
||||
name: client
|
||||
protocol: TCP
|
||||
- containerPort: 2380
|
||||
name: server
|
||||
protocol: TCP
|
||||
imagePullSecrets:
|
||||
- name: aliyun
|
||||
affinity:
|
||||
podAntiAffinity:
|
||||
requiredDuringSchedulingIgnoredDuringExecution:
|
||||
- labelSelector:
|
||||
matchExpressions:
|
||||
- key: app
|
||||
operator: In
|
||||
values:
|
||||
- etcd
|
||||
topologyKey: "kubernetes.io/hostname"
|
||||
restartPolicy: Always
|
||||
|
||||
---
|
||||
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
labels:
|
||||
etcd_node: etcd3
|
||||
name: etcd3
|
||||
namespace: discov
|
||||
spec:
|
||||
ports:
|
||||
- name: client
|
||||
port: 2379
|
||||
protocol: TCP
|
||||
targetPort: 2379
|
||||
- name: server
|
||||
port: 2380
|
||||
protocol: TCP
|
||||
targetPort: 2380
|
||||
selector:
|
||||
etcd_node: etcd3
|
||||
|
||||
---
|
||||
|
||||
apiVersion: v1
|
||||
kind: Pod
|
||||
metadata:
|
||||
labels:
|
||||
app: etcd
|
||||
etcd_node: etcd4
|
||||
name: etcd4
|
||||
namespace: discov
|
||||
spec:
|
||||
containers:
|
||||
- command:
|
||||
- /usr/local/bin/etcd
|
||||
- --name
|
||||
- etcd4
|
||||
- --initial-advertise-peer-urls
|
||||
- http://etcd4:2380
|
||||
- --listen-peer-urls
|
||||
- http://0.0.0.0:2380
|
||||
- --listen-client-urls
|
||||
- http://0.0.0.0:2379
|
||||
- --advertise-client-urls
|
||||
- http://etcd4:2379
|
||||
- --initial-cluster
|
||||
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
|
||||
- --initial-cluster-state
|
||||
- new
|
||||
image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest
|
||||
name: etcd4
|
||||
ports:
|
||||
- containerPort: 2379
|
||||
name: client
|
||||
protocol: TCP
|
||||
- containerPort: 2380
|
||||
name: server
|
||||
protocol: TCP
|
||||
imagePullSecrets:
|
||||
- name: aliyun
|
||||
affinity:
|
||||
podAntiAffinity:
|
||||
requiredDuringSchedulingIgnoredDuringExecution:
|
||||
- labelSelector:
|
||||
matchExpressions:
|
||||
- key: app
|
||||
operator: In
|
||||
values:
|
||||
- etcd
|
||||
topologyKey: "kubernetes.io/hostname"
|
||||
restartPolicy: Always
|
||||
|
||||
---
|
||||
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
labels:
|
||||
etcd_node: etcd4
|
||||
name: etcd4
|
||||
namespace: discov
|
||||
spec:
|
||||
ports:
|
||||
- name: client
|
||||
port: 2379
|
||||
protocol: TCP
|
||||
targetPort: 2379
|
||||
- name: server
|
||||
port: 2380
|
||||
protocol: TCP
|
||||
targetPort: 2380
|
||||
selector:
|
||||
etcd_node: etcd4
|
143
core/discov/publisher.go
Normal file
@ -0,0 +1,143 @@
|
||||
package discov
|
||||
|
||||
import (
|
||||
"zero/core/discov/internal"
|
||||
"zero/core/lang"
|
||||
"zero/core/logx"
|
||||
"zero/core/proc"
|
||||
"zero/core/syncx"
|
||||
"zero/core/threading"
|
||||
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
)
|
||||
|
||||
type (
|
||||
PublisherOption func(client *Publisher)
|
||||
|
||||
Publisher struct {
|
||||
endpoints []string
|
||||
key string
|
||||
fullKey string
|
||||
id int64
|
||||
value string
|
||||
lease clientv3.LeaseID
|
||||
quit *syncx.DoneChan
|
||||
pauseChan chan lang.PlaceholderType
|
||||
resumeChan chan lang.PlaceholderType
|
||||
}
|
||||
)
|
||||
|
||||
func NewPublisher(endpoints []string, key, value string, opts ...PublisherOption) *Publisher {
|
||||
publisher := &Publisher{
|
||||
endpoints: endpoints,
|
||||
key: key,
|
||||
value: value,
|
||||
quit: syncx.NewDoneChan(),
|
||||
pauseChan: make(chan lang.PlaceholderType),
|
||||
resumeChan: make(chan lang.PlaceholderType),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(publisher)
|
||||
}
|
||||
|
||||
return publisher
|
||||
}
|
||||
|
||||
func (p *Publisher) KeepAlive() error {
|
||||
cli, err := internal.GetRegistry().GetConn(p.endpoints)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.lease, err = p.register(cli)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
proc.AddWrapUpListener(func() {
|
||||
p.Stop()
|
||||
})
|
||||
|
||||
return p.keepAliveAsync(cli)
|
||||
}
|
||||
|
||||
func (p *Publisher) Pause() {
|
||||
p.pauseChan <- lang.Placeholder
|
||||
}
|
||||
|
||||
func (p *Publisher) Resume() {
|
||||
p.resumeChan <- lang.Placeholder
|
||||
}
|
||||
|
||||
func (p *Publisher) Stop() {
|
||||
p.quit.Close()
|
||||
}
|
||||
|
||||
func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
|
||||
ch, err := cli.KeepAlive(cli.Ctx(), p.lease)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
threading.GoSafe(func() {
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
if !ok {
|
||||
p.revoke(cli)
|
||||
if err := p.KeepAlive(); err != nil {
|
||||
logx.Errorf("KeepAlive: %s", err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
case <-p.pauseChan:
|
||||
logx.Infof("paused etcd renew, key: %s, value: %s", p.key, p.value)
|
||||
p.revoke(cli)
|
||||
select {
|
||||
case <-p.resumeChan:
|
||||
if err := p.KeepAlive(); err != nil {
|
||||
logx.Errorf("KeepAlive: %s", err.Error())
|
||||
}
|
||||
return
|
||||
case <-p.quit.Done():
|
||||
return
|
||||
}
|
||||
case <-p.quit.Done():
|
||||
p.revoke(cli)
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, error) {
|
||||
resp, err := client.Grant(client.Ctx(), TimeToLive)
|
||||
if err != nil {
|
||||
return clientv3.NoLease, err
|
||||
}
|
||||
|
||||
lease := resp.ID
|
||||
if p.id > 0 {
|
||||
p.fullKey = makeEtcdKey(p.key, p.id)
|
||||
} else {
|
||||
p.fullKey = makeEtcdKey(p.key, int64(lease))
|
||||
}
|
||||
_, err = client.Put(client.Ctx(), p.fullKey, p.value, clientv3.WithLease(lease))
|
||||
|
||||
return lease, err
|
||||
}
|
||||
|
||||
func (p *Publisher) revoke(cli internal.EtcdClient) {
|
||||
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func WithId(id int64) PublisherOption {
|
||||
return func(publisher *Publisher) {
|
||||
publisher.id = id
|
||||
}
|
||||
}
|
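A minimal usage sketch of the Publisher; the endpoint, key and value below are placeholders, and the imports of zero/core/discov and zero/core/logx are assumed:

// a sketch, not part of the package: register this instance under rpc.hello
func publishExample() {
	pub := discov.NewPublisher([]string{"127.0.0.1:2379"}, "rpc.hello", "192.168.0.2:8080")
	if err := pub.KeepAlive(); err != nil {
		logx.Error(err)
		return
	}
	defer pub.Stop()
	// serve requests; the etcd lease is renewed until Stop is called
}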
151
core/discov/publisher_test.go
Normal file
@ -0,0 +1,151 @@
|
||||
package discov
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"zero/core/discov/internal"
|
||||
"zero/core/logx"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logx.Disable()
|
||||
}
|
||||
|
||||
func TestPublisher_register(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id = 1
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
|
||||
ID: id,
|
||||
}, nil)
|
||||
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any())
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
_, err := pub.register(cli)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestPublisher_registerWithId(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id = 2
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
|
||||
ID: 1,
|
||||
}, nil)
|
||||
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any())
|
||||
pub := NewPublisher(nil, "thekey", "thevalue", WithId(id))
|
||||
_, err := pub.register(cli)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestPublisher_registerError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(nil, errors.New("error"))
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
val, err := pub.register(cli)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, clientv3.NoLease, val)
|
||||
}
|
||||
|
||||
func TestPublisher_revoke(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id clientv3.LeaseID = 1
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().Revoke(gomock.Any(), id)
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
pub.lease = id
|
||||
pub.revoke(cli)
|
||||
}
|
||||
|
||||
func TestPublisher_revokeError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id clientv3.LeaseID = 1
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().Revoke(gomock.Any(), id).Return(nil, errors.New("error"))
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
pub.lease = id
|
||||
pub.revoke(cli)
|
||||
}
|
||||
|
||||
func TestPublisher_keepAliveAsyncError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id clientv3.LeaseID = 1
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id).Return(nil, errors.New("error"))
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
pub.lease = id
|
||||
assert.NotNil(t, pub.keepAliveAsync(cli))
|
||||
}
|
||||
|
||||
func TestPublisher_keepAliveAsyncQuit(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id clientv3.LeaseID = 1
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
|
||||
wg.Done()
|
||||
})
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
pub.lease = id
|
||||
pub.Stop()
|
||||
assert.Nil(t, pub.keepAliveAsync(cli))
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPublisher_keepAliveAsyncPause(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id clientv3.LeaseID = 1
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
|
||||
pub.Stop()
|
||||
wg.Done()
|
||||
})
|
||||
pub.lease = id
|
||||
assert.Nil(t, pub.keepAliveAsync(cli))
|
||||
pub.Pause()
|
||||
wg.Wait()
|
||||
}
|
35
core/discov/renewer.go
Normal file
@ -0,0 +1,35 @@
|
||||
package discov
|
||||
|
||||
import "zero/core/logx"
|
||||
|
||||
type (
|
||||
Renewer interface {
|
||||
Start()
|
||||
Stop()
|
||||
Pause()
|
||||
Resume()
|
||||
}
|
||||
|
||||
etcdRenewer struct {
|
||||
*Publisher
|
||||
}
|
||||
)
|
||||
|
||||
func NewRenewer(endpoints []string, key, value string, renewId int64) Renewer {
|
||||
var publisher *Publisher
|
||||
if renewId > 0 {
|
||||
publisher = NewPublisher(endpoints, key, value, WithId(renewId))
|
||||
} else {
|
||||
publisher = NewPublisher(endpoints, key, value)
|
||||
}
|
||||
|
||||
return &etcdRenewer{
|
||||
Publisher: publisher,
|
||||
}
|
||||
}
|
||||
|
||||
func (sr *etcdRenewer) Start() {
|
||||
if err := sr.KeepAlive(); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
}
|
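The Renewer is just a Publisher behind a Start/Stop/Pause/Resume interface. A short sketch with placeholder endpoints and values, assuming the zero/core/discov import:

// passing 0 as renewId lets the full key be derived from the lease id
renewer := discov.NewRenewer([]string{"127.0.0.1:2379"}, "rpc.hello", "192.168.0.2:8080", 0)
renewer.Start() // registers the key and keeps the lease alive
// renewer.Pause() before a maintenance window, renewer.Resume() afterwards, renewer.Stop() on shutdown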
186
core/discov/subclient.go
Normal file
@ -0,0 +1,186 @@
|
||||
package discov
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"zero/core/discov/internal"
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
const (
|
||||
_ = iota // keyBasedBalance, default
|
||||
idBasedBalance
|
||||
)
|
||||
|
||||
type (
|
||||
Listener internal.Listener
|
||||
|
||||
subClient struct {
|
||||
balancer internal.Balancer
|
||||
lock sync.Mutex
|
||||
cond *sync.Cond
|
||||
listeners []internal.Listener
|
||||
}
|
||||
|
||||
balanceOptions struct {
|
||||
balanceType int
|
||||
}
|
||||
|
||||
BalanceOption func(*balanceOptions)
|
||||
|
||||
RoundRobinSubClient struct {
|
||||
*subClient
|
||||
}
|
||||
|
||||
ConsistentSubClient struct {
|
||||
*subClient
|
||||
}
|
||||
|
||||
BatchConsistentSubClient struct {
|
||||
*ConsistentSubClient
|
||||
}
|
||||
)
|
||||
|
||||
func NewRoundRobinSubClient(endpoints []string, key string, dialFn internal.DialFn, closeFn internal.CloseFn,
|
||||
opts ...SubOption) (*RoundRobinSubClient, error) {
|
||||
var subOpts subOptions
|
||||
for _, opt := range opts {
|
||||
opt(&subOpts)
|
||||
}
|
||||
|
||||
cli, err := newSubClient(endpoints, key, internal.NewRoundRobinBalancer(dialFn, closeFn, subOpts.exclusive))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RoundRobinSubClient{
|
||||
subClient: cli,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewConsistentSubClient(endpoints []string, key string, dialFn internal.DialFn,
|
||||
closeFn internal.CloseFn, opts ...BalanceOption) (*ConsistentSubClient, error) {
|
||||
var balanceOpts balanceOptions
|
||||
for _, opt := range opts {
|
||||
opt(&balanceOpts)
|
||||
}
|
||||
|
||||
var keyer func(internal.KV) string
|
||||
switch balanceOpts.balanceType {
|
||||
case idBasedBalance:
|
||||
keyer = func(kv internal.KV) string {
|
||||
if id, ok := extractId(kv.Key); ok {
|
||||
return id
|
||||
} else {
|
||||
return kv.Key
|
||||
}
|
||||
}
|
||||
default:
|
||||
keyer = func(kv internal.KV) string {
|
||||
return kv.Val
|
||||
}
|
||||
}
|
||||
|
||||
cli, err := newSubClient(endpoints, key, internal.NewConsistentBalancer(dialFn, closeFn, keyer))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ConsistentSubClient{
|
||||
subClient: cli,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewBatchConsistentSubClient(endpoints []string, key string, dialFn internal.DialFn, closeFn internal.CloseFn,
|
||||
opts ...BalanceOption) (*BatchConsistentSubClient, error) {
|
||||
cli, err := NewConsistentSubClient(endpoints, key, dialFn, closeFn, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &BatchConsistentSubClient{
|
||||
ConsistentSubClient: cli,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newSubClient(endpoints []string, key string, balancer internal.Balancer) (*subClient, error) {
|
||||
client := &subClient{
|
||||
balancer: balancer,
|
||||
}
|
||||
client.cond = sync.NewCond(&client.lock)
|
||||
if err := internal.GetRegistry().Monitor(endpoints, key, client); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *subClient) AddListener(listener internal.Listener) {
|
||||
c.lock.Lock()
|
||||
c.listeners = append(c.listeners, listener)
|
||||
c.lock.Unlock()
|
||||
}
|
||||
|
||||
func (c *subClient) OnAdd(kv internal.KV) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if err := c.balancer.AddConn(kv); err != nil {
|
||||
logx.Error(err)
|
||||
} else {
|
||||
c.cond.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *subClient) OnDelete(kv internal.KV) {
|
||||
c.balancer.RemoveKey(kv.Key)
|
||||
}
|
||||
|
||||
func (c *subClient) WaitForServers() {
|
||||
logx.Error("Waiting for alive servers")
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.balancer.IsEmpty() {
|
||||
c.cond.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *subClient) onAdd(keys []string, servers []string, newKey string) {
|
||||
// guarded by the lock held by the caller
|
||||
for _, listener := range c.listeners {
|
||||
listener.OnUpdate(keys, servers, newKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RoundRobinSubClient) Next() (interface{}, bool) {
|
||||
return c.balancer.Next()
|
||||
}
|
||||
|
||||
func (c *ConsistentSubClient) Next(key string) (interface{}, bool) {
|
||||
return c.balancer.Next(key)
|
||||
}
|
||||
|
||||
func (bc *BatchConsistentSubClient) Next(keys []string) (map[interface{}][]string, bool) {
|
||||
if len(keys) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
result := make(map[interface{}][]string)
|
||||
for _, key := range keys {
|
||||
dest, ok := bc.ConsistentSubClient.Next(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
result[dest] = append(result[dest], key)
|
||||
}
|
||||
|
||||
return result, true
|
||||
}
|
||||
|
||||
func BalanceWithId() BalanceOption {
|
||||
return func(opts *balanceOptions) {
|
||||
opts.balanceType = idBasedBalance
|
||||
}
|
||||
}
|
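A sketch of wiring a RoundRobinSubClient; the dial and close callbacks are assumed to have the same shapes used by the balancers (func(server string) (interface{}, error) and func(server string, conn interface{}) error), and gRPC is used purely for illustration:

// a sketch assuming imports of zero/core/discov, zero/core/logx and google.golang.org/grpc
func subExample() {
	cli, err := discov.NewRoundRobinSubClient([]string{"127.0.0.1:2379"}, "rpc.hello",
		func(server string) (interface{}, error) {
			return grpc.Dial(server, grpc.WithInsecure())
		},
		func(server string, conn interface{}) error {
			return conn.(*grpc.ClientConn).Close()
		})
	if err != nil {
		logx.Error(err)
		return
	}

	cli.WaitForServers() // block until at least one server is registered
	if conn, ok := cli.Next(); ok {
		_ = conn.(*grpc.ClientConn) // hand the connection to a gRPC stub
	}
}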
151
core/discov/subscriber.go
Normal file
@ -0,0 +1,151 @@
|
||||
package discov
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"zero/core/discov/internal"
|
||||
)
|
||||
|
||||
type (
|
||||
subOptions struct {
|
||||
exclusive bool
|
||||
}
|
||||
|
||||
SubOption func(opts *subOptions)
|
||||
|
||||
Subscriber struct {
|
||||
items *container
|
||||
}
|
||||
)
|
||||
|
||||
func NewSubscriber(endpoints []string, key string, opts ...SubOption) *Subscriber {
|
||||
var subOpts subOptions
|
||||
for _, opt := range opts {
|
||||
opt(&subOpts)
|
||||
}
|
||||
|
||||
subscriber := &Subscriber{
|
||||
items: newContainer(subOpts.exclusive),
|
||||
}
|
||||
internal.GetRegistry().Monitor(endpoints, key, subscriber.items)
|
||||
|
||||
return subscriber
|
||||
}
|
||||
|
||||
func (s *Subscriber) Values() []string {
|
||||
return s.items.getValues()
|
||||
}
|
||||
|
||||
// exclusive means that a value can be associated with only one key,
|
||||
// which means a later added key removes the keys previously associated with the same value.
|
||||
func Exclusive() SubOption {
|
||||
return func(opts *subOptions) {
|
||||
opts.exclusive = true
|
||||
}
|
||||
}
|
||||
|
||||
type container struct {
|
||||
exclusive bool
|
||||
values map[string][]string
|
||||
mapping map[string]string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func newContainer(exclusive bool) *container {
|
||||
return &container{
|
||||
exclusive: exclusive,
|
||||
values: make(map[string][]string),
|
||||
mapping: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *container) OnAdd(kv internal.KV) {
|
||||
c.addKv(kv.Key, kv.Val)
|
||||
}
|
||||
|
||||
func (c *container) OnDelete(kv internal.KV) {
|
||||
c.removeKey(kv.Key)
|
||||
}
|
||||
|
||||
// addKv adds the kv, and returns the previously associated keys and whether the value already had keys
|
||||
func (c *container) addKv(key, value string) ([]string, bool) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
keys := c.values[value]
|
||||
previous := append([]string(nil), keys...)
|
||||
early := len(keys) > 0
|
||||
if c.exclusive && early {
|
||||
for _, each := range keys {
|
||||
c.doRemoveKey(each)
|
||||
}
|
||||
}
|
||||
c.values[value] = append(c.values[value], key)
|
||||
c.mapping[key] = value
|
||||
|
||||
if early {
|
||||
return previous, true
|
||||
} else {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *container) doRemoveKey(key string) {
|
||||
server, ok := c.mapping[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(c.mapping, key)
|
||||
keys := c.values[server]
|
||||
remain := keys[:0]
|
||||
|
||||
for _, k := range keys {
|
||||
if k != key {
|
||||
remain = append(remain, k)
|
||||
}
|
||||
}
|
||||
|
||||
if len(remain) > 0 {
|
||||
c.values[server] = remain
|
||||
} else {
|
||||
delete(c.values, server)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *container) getValues() []string {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
var vs []string
|
||||
for each := range c.values {
|
||||
vs = append(vs, each)
|
||||
}
|
||||
return vs
|
||||
}
|
||||
|
||||
// removeKey removes the key and drops its mapping to the value
|
||||
func (c *container) removeKey(key string) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.doRemoveKey(key)
|
||||
}
|
||||
|
||||
func (c *container) removeVal(val string) (empty bool) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
for k := range c.values {
|
||||
if k == val {
|
||||
delete(c.values, k)
|
||||
}
|
||||
}
|
||||
for k, v := range c.mapping {
|
||||
if v == val {
|
||||
delete(c.mapping, k)
|
||||
}
|
||||
}
|
||||
|
||||
return len(c.values) == 0
|
||||
}
|
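A minimal sketch of the Subscriber with the Exclusive option; endpoints and the key are placeholders, and the import of zero/core/discov is assumed:

// with Exclusive(), each value keeps only the most recently added key
sub := discov.NewSubscriber([]string{"127.0.0.1:2379"}, "rpc.hello", discov.Exclusive())
for _, val := range sub.Values() {
	fmt.Println(val) // the currently registered server addresses
}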
21
core/errorx/atomicerror.go
Normal file
@ -0,0 +1,21 @@
|
||||
package errorx
|
||||
|
||||
import "sync"
|
||||
|
||||
type AtomicError struct {
|
||||
err error
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (ae *AtomicError) Set(err error) {
|
||||
ae.lock.Lock()
|
||||
ae.err = err
|
||||
ae.lock.Unlock()
|
||||
}
|
||||
|
||||
func (ae *AtomicError) Load() error {
|
||||
ae.lock.Lock()
|
||||
err := ae.err
|
||||
ae.lock.Unlock()
|
||||
return err
|
||||
}
|
21
core/errorx/atomicerror_test.go
Normal file
@ -0,0 +1,21 @@
|
||||
package errorx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var errDummy = errors.New("hello")
|
||||
|
||||
func TestAtomicError(t *testing.T) {
|
||||
var err AtomicError
|
||||
err.Set(errDummy)
|
||||
assert.Equal(t, errDummy, err.Load())
|
||||
}
|
||||
|
||||
func TestAtomicErrorNil(t *testing.T) {
|
||||
var err AtomicError
|
||||
assert.Nil(t, err.Load())
|
||||
}
|
45
core/errorx/batcherror.go
Normal file
@ -0,0 +1,45 @@
|
||||
package errorx
|
||||
|
||||
import "bytes"
|
||||
|
||||
type (
|
||||
BatchError struct {
|
||||
errs errorArray
|
||||
}
|
||||
|
||||
errorArray []error
|
||||
)
|
||||
|
||||
func (be *BatchError) Add(err error) {
|
||||
if err != nil {
|
||||
be.errs = append(be.errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (be *BatchError) Err() error {
|
||||
switch len(be.errs) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
return be.errs[0]
|
||||
default:
|
||||
return be.errs
|
||||
}
|
||||
}
|
||||
|
||||
func (be *BatchError) NotNil() bool {
|
||||
return len(be.errs) > 0
|
||||
}
|
||||
|
||||
func (ea errorArray) Error() string {
|
||||
var buf bytes.Buffer
|
||||
|
||||
for i := range ea {
|
||||
if i > 0 {
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
buf.WriteString(ea[i].Error())
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
48
core/errorx/batcherror_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
package errorx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
err1 = "first error"
|
||||
err2 = "second error"
|
||||
)
|
||||
|
||||
func TestBatchErrorNil(t *testing.T) {
|
||||
var batch BatchError
|
||||
assert.Nil(t, batch.Err())
|
||||
assert.False(t, batch.NotNil())
|
||||
batch.Add(nil)
|
||||
assert.Nil(t, batch.Err())
|
||||
assert.False(t, batch.NotNil())
|
||||
}
|
||||
|
||||
func TestBatchErrorNilFromFunc(t *testing.T) {
|
||||
err := func() error {
|
||||
var be BatchError
|
||||
return be.Err()
|
||||
}()
|
||||
assert.True(t, err == nil)
|
||||
}
|
||||
|
||||
func TestBatchErrorOneError(t *testing.T) {
|
||||
var batch BatchError
|
||||
batch.Add(errors.New(err1))
|
||||
assert.NotNil(t, batch)
|
||||
assert.Equal(t, err1, batch.Err().Error())
|
||||
assert.True(t, batch.NotNil())
|
||||
}
|
||||
|
||||
func TestBatchErrorWithErrors(t *testing.T) {
|
||||
var batch BatchError
|
||||
batch.Add(errors.New(err1))
|
||||
batch.Add(errors.New(err2))
|
||||
assert.NotNil(t, batch)
|
||||
assert.Equal(t, fmt.Sprintf("%s\n%s", err1, err2), batch.Err().Error())
|
||||
assert.True(t, batch.NotNil())
|
||||
}
|
93
core/executors/bulkexecutor.go
Normal file
@ -0,0 +1,93 @@
|
||||
package executors
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultBulkTasks = 1000
|
||||
|
||||
type (
|
||||
BulkOption func(options *bulkOptions)
|
||||
|
||||
BulkExecutor struct {
|
||||
executor *PeriodicalExecutor
|
||||
container *bulkContainer
|
||||
}
|
||||
|
||||
bulkOptions struct {
|
||||
cachedTasks int
|
||||
flushInterval time.Duration
|
||||
}
|
||||
)
|
||||
|
||||
func NewBulkExecutor(execute Execute, opts ...BulkOption) *BulkExecutor {
|
||||
options := newBulkOptions()
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
container := &bulkContainer{
|
||||
execute: execute,
|
||||
maxTasks: options.cachedTasks,
|
||||
}
|
||||
executor := &BulkExecutor{
|
||||
executor: NewPeriodicalExecutor(options.flushInterval, container),
|
||||
container: container,
|
||||
}
|
||||
|
||||
return executor
|
||||
}
|
||||
|
||||
func (be *BulkExecutor) Add(task interface{}) error {
|
||||
be.executor.Add(task)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (be *BulkExecutor) Flush() {
|
||||
be.executor.Flush()
|
||||
}
|
||||
|
||||
func (be *BulkExecutor) Wait() {
|
||||
be.executor.Wait()
|
||||
}
|
||||
|
||||
func WithBulkTasks(tasks int) BulkOption {
|
||||
return func(options *bulkOptions) {
|
||||
options.cachedTasks = tasks
|
||||
}
|
||||
}
|
||||
|
||||
func WithBulkInterval(duration time.Duration) BulkOption {
|
||||
return func(options *bulkOptions) {
|
||||
options.flushInterval = duration
|
||||
}
|
||||
}
|
||||
|
||||
func newBulkOptions() bulkOptions {
|
||||
return bulkOptions{
|
||||
cachedTasks: defaultBulkTasks,
|
||||
flushInterval: defaultFlushInterval,
|
||||
}
|
||||
}
|
||||
|
||||
type bulkContainer struct {
|
||||
tasks []interface{}
|
||||
execute Execute
|
||||
maxTasks int
|
||||
}
|
||||
|
||||
func (bc *bulkContainer) AddTask(task interface{}) bool {
|
||||
bc.tasks = append(bc.tasks, task)
|
||||
return len(bc.tasks) >= bc.maxTasks
|
||||
}
|
||||
|
||||
func (bc *bulkContainer) Execute(tasks interface{}) {
|
||||
vals := tasks.([]interface{})
|
||||
bc.execute(vals)
|
||||
}
|
||||
|
||||
func (bc *bulkContainer) RemoveAll() interface{} {
|
||||
tasks := bc.tasks
|
||||
bc.tasks = nil
|
||||
return tasks
|
||||
}
|
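A short usage sketch of BulkExecutor; insertRows and rows are hypothetical, and the imports of zero/core/executors and time are assumed:

// flush whenever 100 tasks are buffered or one second has passed
be := executors.NewBulkExecutor(func(tasks []interface{}) {
	insertRows(tasks) // hypothetical batch sink
}, executors.WithBulkTasks(100), executors.WithBulkInterval(time.Second))

for _, row := range rows {
	_ = be.Add(row)
}
be.Flush() // force out whatever is still buffered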
113
core/executors/bulkexecutor_test.go
Normal file
@ -0,0 +1,113 @@
|
||||
package executors
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBulkExecutor(t *testing.T) {
|
||||
var values []int
|
||||
var lock sync.Mutex
|
||||
|
||||
executor := NewBulkExecutor(func(items []interface{}) {
|
||||
lock.Lock()
|
||||
values = append(values, len(items))
|
||||
lock.Unlock()
|
||||
}, WithBulkTasks(10), WithBulkInterval(time.Minute))
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
executor.Add(1)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
lock.Lock()
|
||||
assert.True(t, len(values) > 0)
|
||||
// every recorded flush should contain exactly 10 items
|
||||
for i := 0; i < len(values); i++ {
|
||||
assert.Equal(t, 10, values[i])
|
||||
}
|
||||
lock.Unlock()
|
||||
}
|
||||
|
||||
func TestBulkExecutorFlushInterval(t *testing.T) {
|
||||
const (
|
||||
caches = 10
|
||||
size = 5
|
||||
)
|
||||
var wait sync.WaitGroup
|
||||
|
||||
wait.Add(1)
|
||||
executor := NewBulkExecutor(func(items []interface{}) {
|
||||
assert.Equal(t, size, len(items))
|
||||
wait.Done()
|
||||
}, WithBulkTasks(caches), WithBulkInterval(time.Millisecond*100))
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
executor.Add(1)
|
||||
}
|
||||
|
||||
wait.Wait()
|
||||
}
|
||||
|
||||
func TestBulkExecutorEmpty(t *testing.T) {
|
||||
NewBulkExecutor(func(items []interface{}) {
|
||||
assert.Fail(t, "should not be called")
|
||||
}, WithBulkTasks(10), WithBulkInterval(time.Millisecond))
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
}
|
||||
|
||||
func TestBulkExecutorFlush(t *testing.T) {
|
||||
const (
|
||||
caches = 10
|
||||
tasks = 5
|
||||
)
|
||||
|
||||
var wait sync.WaitGroup
|
||||
wait.Add(1)
|
||||
be := NewBulkExecutor(func(items []interface{}) {
|
||||
assert.Equal(t, tasks, len(items))
|
||||
wait.Done()
|
||||
}, WithBulkTasks(caches), WithBulkInterval(time.Minute))
|
||||
for i := 0; i < tasks; i++ {
|
||||
be.Add(1)
|
||||
}
|
||||
be.Flush()
|
||||
wait.Wait()
|
||||
}
|
||||
|
||||
func TestBulkExecutorFlushSlowTasks(t *testing.T) {
|
||||
const total = 1500
|
||||
lock := new(sync.Mutex)
|
||||
result := make([]interface{}, 0, 10000)
|
||||
exec := NewBulkExecutor(func(tasks []interface{}) {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
for _, i := range tasks {
|
||||
result = append(result, i)
|
||||
}
|
||||
}, WithBulkTasks(1000))
|
||||
for i := 0; i < total; i++ {
|
||||
assert.Nil(t, exec.Add(i))
|
||||
}
|
||||
|
||||
exec.Flush()
|
||||
exec.Wait()
|
||||
assert.Equal(t, total, len(result))
|
||||
}
|
||||
|
||||
func BenchmarkBulkExecutor(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
be := NewBulkExecutor(func(tasks []interface{}) {
|
||||
time.Sleep(time.Millisecond * time.Duration(len(tasks)))
|
||||
})
|
||||
for i := 0; i < b.N; i++ {
|
||||
time.Sleep(time.Microsecond * 200)
|
||||
be.Add(1)
|
||||
}
|
||||
be.Flush()
|
||||
}
|
103
core/executors/chunkexecutor.go
Normal file
@ -0,0 +1,103 @@
|
||||
package executors
|
||||
|
||||
import "time"
|
||||
|
||||
const defaultChunkSize = 1024 * 1024 // 1M
|
||||
|
||||
type (
|
||||
ChunkOption func(options *chunkOptions)
|
||||
|
||||
ChunkExecutor struct {
|
||||
executor *PeriodicalExecutor
|
||||
container *chunkContainer
|
||||
}
|
||||
|
||||
chunkOptions struct {
|
||||
chunkSize int
|
||||
flushInterval time.Duration
|
||||
}
|
||||
)
|
||||
|
||||
func NewChunkExecutor(execute Execute, opts ...ChunkOption) *ChunkExecutor {
|
||||
options := newChunkOptions()
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
container := &chunkContainer{
|
||||
execute: execute,
|
||||
maxChunkSize: options.chunkSize,
|
||||
}
|
||||
executor := &ChunkExecutor{
|
||||
executor: NewPeriodicalExecutor(options.flushInterval, container),
|
||||
container: container,
|
||||
}
|
||||
|
||||
return executor
|
||||
}
|
||||
|
||||
func (ce *ChunkExecutor) Add(task interface{}, size int) error {
|
||||
ce.executor.Add(chunk{
|
||||
val: task,
|
||||
size: size,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ce *ChunkExecutor) Flush() {
|
||||
ce.executor.Flush()
|
||||
}
|
||||
|
||||
func (ce *ChunkExecutor) Wait() {
|
||||
ce.executor.Wait()
|
||||
}
|
||||
|
||||
func WithChunkBytes(size int) ChunkOption {
|
||||
return func(options *chunkOptions) {
|
||||
options.chunkSize = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithFlushInterval(duration time.Duration) ChunkOption {
|
||||
return func(options *chunkOptions) {
|
||||
options.flushInterval = duration
|
||||
}
|
||||
}
|
||||
|
||||
func newChunkOptions() chunkOptions {
|
||||
return chunkOptions{
|
||||
chunkSize: defaultChunkSize,
|
||||
flushInterval: defaultFlushInterval,
|
||||
}
|
||||
}
|
||||
|
||||
type chunkContainer struct {
|
||||
tasks []interface{}
|
||||
execute Execute
|
||||
size int
|
||||
maxChunkSize int
|
||||
}
|
||||
|
||||
func (bc *chunkContainer) AddTask(task interface{}) bool {
|
||||
ck := task.(chunk)
|
||||
bc.tasks = append(bc.tasks, ck.val)
|
||||
bc.size += ck.size
|
||||
return bc.size >= bc.maxChunkSize
|
||||
}
|
||||
|
||||
func (bc *chunkContainer) Execute(tasks interface{}) {
|
||||
vals := tasks.([]interface{})
|
||||
bc.execute(vals)
|
||||
}
|
||||
|
||||
func (bc *chunkContainer) RemoveAll() interface{} {
|
||||
tasks := bc.tasks
|
||||
bc.tasks = nil
|
||||
bc.size = 0
|
||||
return tasks
|
||||
}
|
||||
|
||||
type chunk struct {
|
||||
val interface{}
|
||||
size int
|
||||
}
|
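ChunkExecutor differs from BulkExecutor only in its trigger: Add carries a byte size and a flush fires once the accumulated size reaches the chunk threshold. A short sketch, with sendPacket and payload as hypothetical names:

// flush when buffered payloads reach 64KB or every second
ce := executors.NewChunkExecutor(func(tasks []interface{}) {
	sendPacket(tasks) // hypothetical sink
}, executors.WithChunkBytes(64*1024), executors.WithFlushInterval(time.Second))

_ = ce.Add(payload, len(payload)) // payload is a hypothetical []byte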
92
core/executors/chunkexecutor_test.go
Normal file
@ -0,0 +1,92 @@
|
||||
package executors
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestChunkExecutor(t *testing.T) {
|
||||
var values []int
|
||||
var lock sync.Mutex
|
||||
|
||||
executor := NewChunkExecutor(func(items []interface{}) {
|
||||
lock.Lock()
|
||||
values = append(values, len(items))
|
||||
lock.Unlock()
|
||||
}, WithChunkBytes(10), WithFlushInterval(time.Minute))
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
executor.Add(1, 1)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
lock.Lock()
|
||||
assert.True(t, len(values) > 0)
|
||||
// every recorded flush should contain exactly 10 items
|
||||
for i := 0; i < len(values); i++ {
|
||||
assert.Equal(t, 10, values[i])
|
||||
}
|
||||
lock.Unlock()
|
||||
}
|
||||
|
||||
func TestChunkExecutorFlushInterval(t *testing.T) {
|
||||
const (
|
||||
caches = 10
|
||||
size = 5
|
||||
)
|
||||
var wait sync.WaitGroup
|
||||
|
||||
wait.Add(1)
|
||||
executor := NewChunkExecutor(func(items []interface{}) {
|
||||
assert.Equal(t, size, len(items))
|
||||
wait.Done()
|
||||
}, WithChunkBytes(caches), WithFlushInterval(time.Millisecond*100))
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
executor.Add(1, 1)
|
||||
}
|
||||
|
||||
wait.Wait()
|
||||
}
|
||||
|
||||
func TestChunkExecutorEmpty(t *testing.T) {
|
||||
NewChunkExecutor(func(items []interface{}) {
|
||||
assert.Fail(t, "should not be called")
|
||||
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
}
|
||||
|
||||
func TestChunkExecutorFlush(t *testing.T) {
|
||||
const (
|
||||
caches = 10
|
||||
tasks = 5
|
||||
)
|
||||
|
||||
var wait sync.WaitGroup
|
||||
wait.Add(1)
|
||||
be := NewChunkExecutor(func(items []interface{}) {
|
||||
assert.Equal(t, tasks, len(items))
|
||||
wait.Done()
|
||||
}, WithChunkBytes(caches), WithFlushInterval(time.Minute))
|
||||
for i := 0; i < tasks; i++ {
|
||||
be.Add(1, 1)
|
||||
}
|
||||
be.Flush()
|
||||
wait.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkChunkExecutor(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
be := NewChunkExecutor(func(tasks []interface{}) {
|
||||
time.Sleep(time.Millisecond * time.Duration(len(tasks)))
|
||||
})
|
||||
for i := 0; i < b.N; i++ {
|
||||
time.Sleep(time.Microsecond * 200)
|
||||
be.Add(1, 1)
|
||||
}
|
||||
be.Flush()
|
||||
}
|
44
core/executors/delayexecutor.go
Normal file
@ -0,0 +1,44 @@
package executors

import (
	"sync"
	"time"

	"zero/core/threading"
)

type DelayExecutor struct {
	fn        func()
	delay     time.Duration
	triggered bool
	lock      sync.Mutex
}

func NewDelayExecutor(fn func(), delay time.Duration) *DelayExecutor {
	return &DelayExecutor{
		fn:    fn,
		delay: delay,
	}
}

func (de *DelayExecutor) Trigger() {
	de.lock.Lock()
	defer de.lock.Unlock()

	if de.triggered {
		return
	}

	de.triggered = true
	threading.GoSafe(func() {
		timer := time.NewTimer(de.delay)
		defer timer.Stop()
		<-timer.C

		// set triggered to false before calling fn to ensure no triggers are missed.
		de.lock.Lock()
		de.triggered = false
		de.lock.Unlock()
		de.fn()
	})
}
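A short sketch of how the delay executor above can coalesce a burst of triggers into a single deferred call; the callback and delay are illustrative, and the zero/core/executors import path is assumed from the directory layout:

package main

import (
	"fmt"
	"time"

	"zero/core/executors"
)

func main() {
	de := executors.NewDelayExecutor(func() {
		fmt.Println("ran once, about 1s after the first trigger")
	}, time.Second)

	// repeated triggers inside the delay window do not schedule extra runs
	for i := 0; i < 10; i++ {
		de.Trigger()
	}
	time.Sleep(2 * time.Second) // give the deferred call time to fire
}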
21
core/executors/delayexecutor_test.go
Normal file
@ -0,0 +1,21 @@
package executors

import (
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestDelayExecutor(t *testing.T) {
	var count int32
	ex := NewDelayExecutor(func() {
		atomic.AddInt32(&count, 1)
	}, time.Millisecond*10)
	for i := 0; i < 100; i++ {
		ex.Trigger()
	}
	time.Sleep(time.Millisecond * 100)
	assert.Equal(t, int32(1), atomic.LoadInt32(&count))
}
32
core/executors/lessexecutor.go
Normal file
@ -0,0 +1,32 @@
package executors

import (
	"time"

	"zero/core/syncx"
	"zero/core/timex"
)

type LessExecutor struct {
	threshold time.Duration
	lastTime  *syncx.AtomicDuration
}

func NewLessExecutor(threshold time.Duration) *LessExecutor {
	return &LessExecutor{
		threshold: threshold,
		lastTime:  syncx.NewAtomicDuration(),
	}
}

func (le *LessExecutor) DoOrDiscard(execute func()) bool {
	now := timex.Now()
	lastTime := le.lastTime.Load()
	if lastTime == 0 || lastTime+le.threshold < now {
		le.lastTime.Set(now)
		execute()
		return true
	}

	return false
}
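A sketch of using the less executor above to throttle noisy work such as periodic logging, again assuming the zero/core/executors import path; at most one execution per threshold window is let through and the rest are discarded:

package main

import (
	"fmt"
	"time"

	"zero/core/executors"
)

func main() {
	le := executors.NewLessExecutor(time.Second)
	for i := 0; i < 1000; i++ {
		// only the first call inside each 1s window actually runs
		le.DoOrDiscard(func() {
			fmt.Println("heartbeat at", time.Now())
		})
		time.Sleep(10 * time.Millisecond)
	}
}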
27
core/executors/lessexecutor_test.go
Normal file
@ -0,0 +1,27 @@
package executors

import (
	"testing"
	"time"

	"zero/core/timex"

	"github.com/stretchr/testify/assert"
)

func TestLessExecutor_DoOrDiscard(t *testing.T) {
	executor := NewLessExecutor(time.Minute)
	assert.True(t, executor.DoOrDiscard(func() {}))
	assert.False(t, executor.DoOrDiscard(func() {}))
	executor.lastTime.Set(timex.Now() - time.Minute - time.Second*30)
	assert.True(t, executor.DoOrDiscard(func() {}))
	assert.False(t, executor.DoOrDiscard(func() {}))
}

func BenchmarkLessExecutor(b *testing.B) {
	exec := NewLessExecutor(time.Millisecond)
	for i := 0; i < b.N; i++ {
		exec.DoOrDiscard(func() {
		})
	}
}
158
core/executors/periodicalexecutor.go
Normal file
@ -0,0 +1,158 @@
package executors

import (
	"reflect"
	"sync"
	"time"

	"zero/core/proc"
	"zero/core/threading"
	"zero/core/timex"
)

const idleRound = 10

type (
	// A type that satisfies executors.TaskContainer can be used as the underlying
	// container that is used to do periodical executions.
	TaskContainer interface {
		// AddTask adds the task into the container.
		// Returns true if the container needs to be flushed after the addition.
		AddTask(task interface{}) bool
		// Execute handles the collected tasks by the container when flushing.
		Execute(tasks interface{})
		// RemoveAll removes the contained tasks, and returns them.
		RemoveAll() interface{}
	}

	PeriodicalExecutor struct {
		commander chan interface{}
		interval  time.Duration
		container TaskContainer
		waitGroup sync.WaitGroup
		guarded   bool
		newTicker func(duration time.Duration) timex.Ticker
		lock      sync.Mutex
	}
)

func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *PeriodicalExecutor {
	executor := &PeriodicalExecutor{
		// buffer 1 to let the caller go quickly
		commander: make(chan interface{}, 1),
		interval:  interval,
		container: container,
		newTicker: func(d time.Duration) timex.Ticker {
			return timex.NewTicker(interval)
		},
	}
	proc.AddShutdownListener(func() {
		executor.Flush()
	})

	return executor
}

func (pe *PeriodicalExecutor) Add(task interface{}) {
	if vals, ok := pe.addAndCheck(task); ok {
		pe.commander <- vals
	}
}

func (pe *PeriodicalExecutor) Flush() bool {
	return pe.executeTasks(func() interface{} {
		pe.lock.Lock()
		defer pe.lock.Unlock()
		return pe.container.RemoveAll()
	}())
}

func (pe *PeriodicalExecutor) Sync(fn func()) {
	pe.lock.Lock()
	defer pe.lock.Unlock()
	fn()
}

func (pe *PeriodicalExecutor) Wait() {
	pe.waitGroup.Wait()
}

func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
	pe.lock.Lock()
	defer func() {
		var start bool
		if !pe.guarded {
			pe.guarded = true
			start = true
		}
		pe.lock.Unlock()
		if start {
			pe.backgroundFlush()
		}
	}()

	if pe.container.AddTask(task) {
		return pe.container.RemoveAll(), true
	}

	return nil, false
}

func (pe *PeriodicalExecutor) backgroundFlush() {
	threading.GoSafe(func() {
		ticker := pe.newTicker(pe.interval)
		defer ticker.Stop()

		var commanded bool
		last := timex.Now()
		for {
			select {
			case vals := <-pe.commander:
				commanded = true
				pe.executeTasks(vals)
				last = timex.Now()
			case <-ticker.Chan():
				if commanded {
					commanded = false
				} else if pe.Flush() {
					last = timex.Now()
				} else if timex.Since(last) > pe.interval*idleRound {
					pe.lock.Lock()
					pe.guarded = false
					pe.lock.Unlock()

					// flush again to avoid missing tasks
					pe.Flush()
					return
				}
			}
		}
	})
}

func (pe *PeriodicalExecutor) executeTasks(tasks interface{}) bool {
	pe.waitGroup.Add(1)
	defer pe.waitGroup.Done()

	ok := pe.hasTasks(tasks)
	if ok {
		pe.container.Execute(tasks)
	}

	return ok
}

func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool {
	if tasks == nil {
		return false
	}

	val := reflect.ValueOf(tasks)
	switch val.Kind() {
	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
		return val.Len() > 0
	default:
		// unknown type, let caller execute it
		return true
	}
}
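The TaskContainer comment above describes the contract a container must satisfy; below is a hedged sketch of a trivial container plus the executor wiring, assuming the zero/core/executors import path. The container signals that a flush is due by returning true from AddTask, hands its tasks over in RemoveAll, and does the real work in Execute:

package main

import (
	"fmt"
	"time"

	"zero/core/executors"
)

type stringContainer struct {
	tasks []string
}

func (c *stringContainer) AddTask(task interface{}) bool {
	c.tasks = append(c.tasks, task.(string))
	return len(c.tasks) >= 100 // ask for an immediate flush once 100 tasks pile up
}

func (c *stringContainer) Execute(tasks interface{}) {
	fmt.Println("flushing", len(tasks.([]string)), "tasks")
}

func (c *stringContainer) RemoveAll() interface{} {
	tasks := c.tasks
	c.tasks = nil
	return tasks
}

func main() {
	pe := executors.NewPeriodicalExecutor(time.Second, &stringContainer{})
	for i := 0; i < 250; i++ {
		pe.Add(fmt.Sprintf("task-%d", i))
	}
	pe.Flush() // drain the remainder that never reached the size threshold
	pe.Wait()  // wait for in-flight Execute calls to finish
}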
118
core/executors/periodicalexecutor_test.go
Normal file
@ -0,0 +1,118 @@
package executors

import (
	"runtime"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"zero/core/timex"

	"github.com/stretchr/testify/assert"
)

const threshold = 10

type container struct {
	interval time.Duration
	tasks    []int
	execute  func(tasks interface{})
}

func newContainer(interval time.Duration, execute func(tasks interface{})) *container {
	return &container{
		interval: interval,
		execute:  execute,
	}
}

func (c *container) AddTask(task interface{}) bool {
	c.tasks = append(c.tasks, task.(int))
	return len(c.tasks) > threshold
}

func (c *container) Execute(tasks interface{}) {
	if c.execute != nil {
		c.execute(tasks)
	} else {
		time.Sleep(c.interval)
	}
}

func (c *container) RemoveAll() interface{} {
	tasks := c.tasks
	c.tasks = nil
	return tasks
}

func TestPeriodicalExecutor_Sync(t *testing.T) {
	var done int32
	exec := NewPeriodicalExecutor(time.Second, newContainer(time.Millisecond*500, nil))
	exec.Sync(func() {
		atomic.AddInt32(&done, 1)
	})
	assert.Equal(t, int32(1), atomic.LoadInt32(&done))
}

func TestPeriodicalExecutor_QuitGoroutine(t *testing.T) {
	ticker := timex.NewFakeTicker()
	exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
	exec.newTicker = func(d time.Duration) timex.Ticker {
		return ticker
	}
	routines := runtime.NumGoroutine()
	exec.Add(1)
	ticker.Tick()
	ticker.Wait(time.Millisecond * idleRound * 2)
	ticker.Tick()
	ticker.Wait(time.Millisecond * idleRound)
	assert.Equal(t, routines, runtime.NumGoroutine())
}

func TestPeriodicalExecutor_Bulk(t *testing.T) {
	ticker := timex.NewFakeTicker()
	var vals []int
	// avoid data race
	var lock sync.Mutex
	exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks interface{}) {
		t := tasks.([]int)
		for _, each := range t {
			lock.Lock()
			vals = append(vals, each)
			lock.Unlock()
		}
	}))
	exec.newTicker = func(d time.Duration) timex.Ticker {
		return ticker
	}
	for i := 0; i < threshold*10; i++ {
		if i%threshold == 5 {
			time.Sleep(time.Millisecond * idleRound * 2)
		}
		exec.Add(i)
	}
	ticker.Tick()
	ticker.Wait(time.Millisecond * idleRound * 2)
	ticker.Tick()
	ticker.Tick()
	ticker.Wait(time.Millisecond * idleRound)
	var expect []int
	for i := 0; i < threshold*10; i++ {
		expect = append(expect, i)
	}

	lock.Lock()
	assert.EqualValues(t, expect, vals)
	lock.Unlock()
}

// go test -benchtime 10s -bench .
func BenchmarkExecutor(b *testing.B) {
	b.ReportAllocs()

	executor := NewPeriodicalExecutor(time.Second, newContainer(time.Millisecond*500, nil))
	for i := 0; i < b.N; i++ {
		executor.Add(1)
	}
}
7
core/executors/vars.go
Normal file
@ -0,0 +1,7 @@
package executors

import "time"

const defaultFlushInterval = time.Second

type Execute func(tasks []interface{})
84
core/filex/file.go
Normal file
@ -0,0 +1,84 @@
package filex

import (
	"io"
	"os"
)

const bufSize = 1024

func FirstLine(filename string) (string, error) {
	file, err := os.Open(filename)
	if err != nil {
		return "", err
	}
	defer file.Close()

	return firstLine(file)
}

func LastLine(filename string) (string, error) {
	file, err := os.Open(filename)
	if err != nil {
		return "", err
	}
	defer file.Close()

	return lastLine(filename, file)
}

func firstLine(file *os.File) (string, error) {
	var first []byte
	var offset int64
	for {
		buf := make([]byte, bufSize)
		n, err := file.ReadAt(buf, offset)
		if err != nil && err != io.EOF {
			return "", err
		}

		for i := 0; i < n; i++ {
			if buf[i] == '\n' {
				return string(append(first, buf[:i]...)), nil
			}
		}

		first = append(first, buf[:n]...)
		offset += bufSize
	}
}

func lastLine(filename string, file *os.File) (string, error) {
	info, err := os.Stat(filename)
	if err != nil {
		return "", err
	}

	var last []byte
	offset := info.Size()
	for {
		offset -= bufSize
		if offset < 0 {
			offset = 0
		}
		buf := make([]byte, bufSize)
		n, err := file.ReadAt(buf, offset)
		if err != nil && err != io.EOF {
			return "", err
		}

		if buf[n-1] == '\n' {
			buf = buf[:n-1]
			n -= 1
		} else {
			buf = buf[:n]
		}
		for n -= 1; n >= 0; n-- {
			if buf[n] == '\n' {
				return string(append(buf[n+1:], last...)), nil
			}
		}

		last = append(buf, last...)
	}
}
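A small sketch of the exported helpers above, assuming a zero/core/filex import path consistent with the other zero/core packages; they are handy for peeking at the header and the trailing record of a large log or CSV without reading the whole file. The file name is hypothetical:

package main

import (
	"fmt"

	"zero/core/filex"
)

func main() {
	head, err := filex.FirstLine("access.log") // hypothetical file
	if err != nil {
		panic(err)
	}
	tail, err := filex.LastLine("access.log")
	if err != nil {
		panic(err)
	}
	fmt.Println("first:", head)
	fmt.Println("last:", tail)
}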
116
core/filex/file_test.go
Normal file
@ -0,0 +1,116 @@
package filex

import (
	"os"
	"testing"

	"zero/core/fs"

	"github.com/stretchr/testify/assert"
)

const (
	longLine      = `Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.`
	longFirstLine = longLine + "\n" + text
	text          = `first line
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
` + longLine
	textWithLastNewline = `first line
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
` + longLine + "\n"
	shortText = `first line
second line
last line`
	shortTextWithLastNewline = `first line
second line
last line
`
)

func TestFirstLine(t *testing.T) {
	filename, err := fs.TempFilenameWithText(longFirstLine)
	assert.Nil(t, err)
	defer os.Remove(filename)

	val, err := FirstLine(filename)
	assert.Nil(t, err)
	assert.Equal(t, longLine, val)
}

func TestFirstLineShort(t *testing.T) {
	filename, err := fs.TempFilenameWithText(shortText)
	assert.Nil(t, err)
	defer os.Remove(filename)

	val, err := FirstLine(filename)
	assert.Nil(t, err)
	assert.Equal(t, "first line", val)
}

func TestLastLine(t *testing.T) {
	filename, err := fs.TempFilenameWithText(text)
	assert.Nil(t, err)
	defer os.Remove(filename)

	val, err := LastLine(filename)
	assert.Nil(t, err)
	assert.Equal(t, longLine, val)
}

func TestLastLineWithLastNewline(t *testing.T) {
	filename, err := fs.TempFilenameWithText(textWithLastNewline)
	assert.Nil(t, err)
	defer os.Remove(filename)

	val, err := LastLine(filename)
	assert.Nil(t, err)
	assert.Equal(t, longLine, val)
}

func TestLastLineShort(t *testing.T) {
	filename, err := fs.TempFilenameWithText(shortText)
	assert.Nil(t, err)
	defer os.Remove(filename)

	val, err := LastLine(filename)
	assert.Nil(t, err)
	assert.Equal(t, "last line", val)
}

func TestLastLineWithLastNewlineShort(t *testing.T) {
	filename, err := fs.TempFilenameWithText(shortTextWithLastNewline)
	assert.Nil(t, err)
	defer os.Remove(filename)

	val, err := LastLine(filename)
	assert.Nil(t, err)
	assert.Equal(t, "last line", val)
}
105
core/filex/lookup.go
Normal file
@ -0,0 +1,105 @@
package filex

import (
	"io"
	"os"
)

type OffsetRange struct {
	File  string
	Start int64
	Stop  int64
}

func SplitLineChunks(filename string, chunks int) ([]OffsetRange, error) {
	info, err := os.Stat(filename)
	if err != nil {
		return nil, err
	}

	if chunks <= 1 {
		return []OffsetRange{
			{
				File:  filename,
				Start: 0,
				Stop:  info.Size(),
			},
		}, nil
	}

	file, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	var ranges []OffsetRange
	var offset int64
	// add one to avoid the last chunk ending up with too few bytes
	preferSize := info.Size()/int64(chunks) + 1
	for {
		if offset+preferSize >= info.Size() {
			ranges = append(ranges, OffsetRange{
				File:  filename,
				Start: offset,
				Stop:  info.Size(),
			})
			break
		}

		offsetRange, err := nextRange(file, offset, offset+preferSize)
		if err != nil {
			return nil, err
		}

		ranges = append(ranges, offsetRange)
		if offsetRange.Stop < info.Size() {
			offset = offsetRange.Stop
		} else {
			break
		}
	}

	return ranges, nil
}

func nextRange(file *os.File, start, stop int64) (OffsetRange, error) {
	offset, err := skipPartialLine(file, stop)
	if err != nil {
		return OffsetRange{}, err
	}

	return OffsetRange{
		File:  file.Name(),
		Start: start,
		Stop:  offset,
	}, nil
}

func skipPartialLine(file *os.File, offset int64) (int64, error) {
	for {
		skipBuf := make([]byte, bufSize)
		n, err := file.ReadAt(skipBuf, offset)
		if err != nil && err != io.EOF {
			return 0, err
		}
		if n == 0 {
			return 0, io.EOF
		}

		for i := 0; i < n; i++ {
			if skipBuf[i] != '\r' && skipBuf[i] != '\n' {
				offset++
			} else {
				for ; i < n; i++ {
					if skipBuf[i] == '\r' || skipBuf[i] == '\n' {
						offset++
					} else {
						return offset, nil
					}
				}
				return offset, nil
			}
		}
	}
}
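A sketch of the intended pairing of SplitLineChunks with RangeReader (the tests below do the same thing sequentially): split a file on line boundaries into roughly equal ranges, then hand each range to its own goroutine. The file name and worker body are illustrative, and the zero/core/filex import path is assumed:

package main

import (
	"fmt"
	"os"
	"sync"

	"zero/core/filex"
)

func main() {
	const name = "big.txt" // hypothetical input file
	ranges, err := filex.SplitLineChunks(name, 4)
	if err != nil {
		panic(err)
	}

	var wg sync.WaitGroup
	for _, r := range ranges {
		wg.Add(1)
		go func(r filex.OffsetRange) {
			defer wg.Done()
			f, err := os.Open(r.File)
			if err != nil {
				return
			}
			defer f.Close()
			// each goroutine reads only its own [Start, Stop) byte window
			buf := make([]byte, r.Stop-r.Start)
			n, _ := filex.NewRangeReader(f, r.Start, r.Stop).Read(buf)
			fmt.Printf("range [%d, %d): %d bytes\n", r.Start, r.Stop, n)
		}(r)
	}
	wg.Wait()
}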
68
core/filex/lookup_test.go
Normal file
@ -0,0 +1,68 @@
package filex

import (
	"os"
	"testing"

	"zero/core/fs"

	"github.com/stretchr/testify/assert"
)

func TestSplitLineChunks(t *testing.T) {
	const text = `first line
second line
third line
fourth line
fifth line
sixth line
seventh line
`
	fp, err := fs.TempFileWithText(text)
	assert.Nil(t, err)
	defer func() {
		fp.Close()
		os.Remove(fp.Name())
	}()

	offsets, err := SplitLineChunks(fp.Name(), 3)
	assert.Nil(t, err)
	body := make([]byte, 512)
	for _, offset := range offsets {
		reader := NewRangeReader(fp, offset.Start, offset.Stop)
		n, err := reader.Read(body)
		assert.Nil(t, err)
		assert.Equal(t, uint8('\n'), body[n-1])
	}
}

func TestSplitLineChunksNoFile(t *testing.T) {
	_, err := SplitLineChunks("nosuchfile", 2)
	assert.NotNil(t, err)
}

func TestSplitLineChunksFull(t *testing.T) {
	const text = `first line
second line
third line
fourth line
fifth line
sixth line
`
	fp, err := fs.TempFileWithText(text)
	assert.Nil(t, err)
	defer func() {
		fp.Close()
		os.Remove(fp.Name())
	}()

	offsets, err := SplitLineChunks(fp.Name(), 1)
	assert.Nil(t, err)
	body := make([]byte, 512)
	for _, offset := range offsets {
		reader := NewRangeReader(fp, offset.Start, offset.Stop)
		n, err := reader.Read(body)
		assert.Nil(t, err)
		assert.Equal(t, []byte(text), body[:n])
	}
}
28
core/filex/progressscanner.go
Normal file
@ -0,0 +1,28 @@
package filex

import "gopkg.in/cheggaaa/pb.v1"

type (
	Scanner interface {
		Scan() bool
		Text() string
	}

	progressScanner struct {
		Scanner
		bar *pb.ProgressBar
	}
)

func NewProgressScanner(scanner Scanner, bar *pb.ProgressBar) Scanner {
	return &progressScanner{
		Scanner: scanner,
		bar:     bar,
	}
}

func (ps *progressScanner) Text() string {
	s := ps.Scanner.Text()
	ps.bar.Add64(int64(len(s)) + 1) // take newlines into account
	return s
}
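A sketch of wrapping a standard bufio.Scanner with the progress scanner above: bufio.Scanner already provides Scan() and Text(), so it satisfies the Scanner interface, and the pb bar advances as lines are consumed. The file name is hypothetical and the pb bar setup follows the usual cheggaaa/pb.v1 pattern, so treat it as an assumption rather than the package's prescribed wiring:

package main

import (
	"bufio"
	"os"

	"gopkg.in/cheggaaa/pb.v1"

	"zero/core/filex"
)

func main() {
	f, err := os.Open("big.txt") // hypothetical input file
	if err != nil {
		panic(err)
	}
	defer f.Close()

	info, _ := f.Stat()
	bar := pb.New(int(info.Size())).Start()
	scanner := filex.NewProgressScanner(bufio.NewScanner(f), bar)
	for scanner.Scan() {
		_ = scanner.Text() // process the line; progress is updated as a side effect
	}
	bar.Finish()
}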
31
core/filex/progressscanner_test.go
Normal file
@ -0,0 +1,31 @@
package filex

import (
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"gopkg.in/cheggaaa/pb.v1"
)

func TestProgressScanner(t *testing.T) {
	const text = "hello, world"
	bar := pb.New(100)
	var builder strings.Builder
	builder.WriteString(text)
	scanner := NewProgressScanner(&mockedScanner{builder: &builder}, bar)
	assert.True(t, scanner.Scan())
	assert.Equal(t, text, scanner.Text())
}

type mockedScanner struct {
	builder *strings.Builder
}

func (s *mockedScanner) Scan() bool {
	return s.builder.Len() > 0
}

func (s *mockedScanner) Text() string {
	return s.builder.String()
}
43
core/filex/rangereader.go
Normal file
@ -0,0 +1,43 @@
package filex

import (
	"errors"
	"os"
)

type RangeReader struct {
	file  *os.File
	start int64
	stop  int64
}

func NewRangeReader(file *os.File, start, stop int64) *RangeReader {
	return &RangeReader{
		file:  file,
		start: start,
		stop:  stop,
	}
}

func (rr *RangeReader) Read(p []byte) (n int, err error) {
	stat, err := rr.file.Stat()
	if err != nil {
		return 0, err
	}

	if rr.stop < rr.start || rr.start >= stat.Size() {
		return 0, errors.New("exceed file size")
	}

	if rr.stop-rr.start < int64(len(p)) {
		p = p[:rr.stop-rr.start]
	}

	n, err = rr.file.ReadAt(p, rr.start)
	if err != nil {
		return n, err
	}

	rr.start += int64(n)
	return
}
45
core/filex/rangereader_test.go
Normal file
@ -0,0 +1,45 @@
package filex

import (
	"os"
	"testing"

	"github.com/stretchr/testify/assert"

	"zero/core/fs"
)

func TestRangeReader(t *testing.T) {
	const text = `hello
world`
	file, err := fs.TempFileWithText(text)
	assert.Nil(t, err)
	defer func() {
		file.Close()
		os.Remove(file.Name())
	}()

	reader := NewRangeReader(file, 5, 8)
	buf := make([]byte, 10)
	n, err := reader.Read(buf)
	assert.Nil(t, err)
	assert.Equal(t, 3, n)
	assert.Equal(t, `
wo`, string(buf[:n]))
}

func TestRangeReader_OutOfRange(t *testing.T) {
	const text = `hello
world`
	file, err := fs.TempFileWithText(text)
	assert.Nil(t, err)
	defer func() {
		file.Close()
		os.Remove(file.Name())
	}()

	reader := NewRangeReader(file, 50, 8)
	buf := make([]byte, 10)
	_, err = reader.Read(buf)
	assert.NotNil(t, err)
}
8
core/fs/files+polyfill.go
Normal file
@ -0,0 +1,8 @@
// +build windows

package fs

import "os"

func CloseOnExec(*os.File) {
}
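The file above is the Windows no-op; the non-Windows implementation of CloseOnExec is not part of this excerpt, so the following is only a hedged sketch of what such a counterpart typically looks like using the standard syscall API:

// +build !windows

package fs

import (
	"os"
	"syscall"
)

// CloseOnExec marks the file descriptor so that child processes started
// with exec do not inherit it.
func CloseOnExec(file *os.File) {
	if file != nil {
		syscall.CloseOnExec(int(file.Fd()))
	}
}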
Some files were not shown because too many files have changed in this diff.