initial import

This commit is contained in:
kevin 2020-07-26 17:09:05 +08:00
commit 7e3a369a8f
647 changed files with 54754 additions and 0 deletions

1
.dockerignore Normal file
View File

@ -0,0 +1 @@
**/.git

1
.gitattributes vendored Normal file
View File

@ -0,0 +1 @@
* text=auto eol=lf

46
.gitignore vendored Normal file
View File

@ -0,0 +1,46 @@
# Ignore all
*
# Unignore all with extensions
!*.*
!**/Dockerfile
# Unignore all dirs
!*/
!api
.idea
**/.DS_Store
**/logs
!vendor/github.com/songtianyi/rrframework/logs
**/*.pem
**/*.prof
**/*.p12
!Makefile
# gitlab ci
.cache
# chatbot
**/production.json
**/*.corpus.json
**/*.txt
**/*.gob
# example
example/**/*.csv
# hera
service/hera/cli/readdata/intergrationtest/data
service/hera/devkit/ch331/data
service/hera/devkit/ch331/ck
service/hera/cli/replaybeat/etc
# goctl
tools/goctl/api/autogen
vendor/*
/service/hera/cli/dboperation/etc/hera.json
# vim auto backup file
*~
!OWNERS

16
.gitlab-ci.yml Normal file
View File

@ -0,0 +1,16 @@
stages:
- analysis
variables:
GOPATH: '/runner-cache/zero'
GOCACHE: '/runner-cache/zero'
GOPROXY: 'https://goproxy.cn,direct'
analysis:
stage: analysis
image: golang
script:
- go version && go env
- go test -short $(go list ./...) | grep -v "no test"
only:
- merge_requests

43
.golangci.yml Normal file
View File

@ -0,0 +1,43 @@
run:
# concurrency: 6
timeout: 5m
skip-dirs:
- core
- diq
- doc
- dq
- example
- kmq
- kq
- ngin
- rq
- rpcx
# - service
- stash
- tools
linters:
disable-all: true
enable:
- bodyclose
- deadcode
- errcheck
- gosimple
- govet
- ineffassign
- staticcheck
- structcheck
- typecheck
- unused
- varcheck
# - dupl
linters-settings:
issues:
exclude-rules:
- linters:
- staticcheck
text: 'SA1019: (baseresponse.BoolResponse|oldresponse.FormatBadRequestResponse|oldresponse.FormatResponse)|SA5008: unknown JSON option ("optional"|"default=|"range=|"options=)'

161
core/bloom/bloom.go Normal file
View File

@ -0,0 +1,161 @@
package bloom
import (
"errors"
"strconv"
"zero/core/hash"
"zero/core/stores/redis"
)
const (
// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
// maps as k in the error rate table
maps = 14
setScript = `
local key = KEYS[1]
for _, offset in ipairs(ARGV) do
redis.call("setbit", key, offset, 1)
end
`
testScript = `
local key = KEYS[1]
for _, offset in ipairs(ARGV) do
if tonumber(redis.call("getbit", key, offset)) == 0 then
return false
end
end
return true
`
)
var ErrTooLargeOffset = errors.New("too large offset")
type (
BitSetProvider interface {
check([]uint) (bool, error)
set([]uint) error
}
BloomFilter struct {
bits uint
maps uint
bitSet BitSetProvider
}
)
// New create a BloomFilter, store is the backed redis, key is the key for the bloom filter,
// bits is how many bits will be used, maps is how many hashes for each addition.
// best practices:
// elements - means how many actual elements
// when maps = 14, formula: 0.7*(bits/maps), bits = 20*elements, the error rate is 0.000067 < 1e-4
// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
func New(store *redis.Redis, key string, bits uint) *BloomFilter {
return &BloomFilter{
bits: bits,
bitSet: newRedisBitSet(store, key, bits),
}
}
func (f *BloomFilter) Add(data []byte) error {
locations := f.getLocations(data)
err := f.bitSet.set(locations)
if err != nil {
return err
}
return nil
}
func (f *BloomFilter) Exists(data []byte) (bool, error) {
locations := f.getLocations(data)
isSet, err := f.bitSet.check(locations)
if err != nil {
return false, err
}
if !isSet {
return false, nil
}
return true, nil
}
func (f *BloomFilter) getLocations(data []byte) []uint {
locations := make([]uint, maps)
for i := uint(0); i < maps; i++ {
hashValue := hash.Hash(append(data, byte(i)))
locations[i] = uint(hashValue % uint64(f.bits))
}
return locations
}
type redisBitSet struct {
store *redis.Redis
key string
bits uint
}
func newRedisBitSet(store *redis.Redis, key string, bits uint) *redisBitSet {
return &redisBitSet{
store: store,
key: key,
bits: bits,
}
}
func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) {
var args []string
for _, offset := range offsets {
if offset >= r.bits {
return nil, ErrTooLargeOffset
}
args = append(args, strconv.FormatUint(uint64(offset), 10))
}
return args, nil
}
func (r *redisBitSet) check(offsets []uint) (bool, error) {
args, err := r.buildOffsetArgs(offsets)
if err != nil {
return false, err
}
resp, err := r.store.Eval(testScript, []string{r.key}, args)
if err == redis.Nil {
return false, nil
} else if err != nil {
return false, err
}
if exists, ok := resp.(int64); !ok {
return false, nil
} else {
return exists == 1, nil
}
}
func (r *redisBitSet) del() error {
_, err := r.store.Del(r.key)
return err
}
func (r *redisBitSet) expire(seconds int) error {
return r.store.Expire(r.key, seconds)
}
func (r *redisBitSet) set(offsets []uint) error {
args, err := r.buildOffsetArgs(offsets)
if err != nil {
return err
}
_, err = r.store.Eval(setScript, []string{r.key}, args)
if err == redis.Nil {
return nil
} else {
return err
}
}

63
core/bloom/bloom_test.go Normal file
View File

@ -0,0 +1,63 @@
package bloom
import (
"testing"
"zero/core/stores/redis"
"github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert"
)
func TestRedisBitSet_New_Set_Test(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
t.Error("Miniredis could not start")
}
defer s.Close()
store := redis.NewRedis(s.Addr(), redis.NodeType)
bitSet := newRedisBitSet(store, "test_key", 1024)
isSetBefore, err := bitSet.check([]uint{0})
if err != nil {
t.Fatal(err)
}
if isSetBefore {
t.Fatal("Bit should not be set")
}
err = bitSet.set([]uint{512})
if err != nil {
t.Fatal(err)
}
isSetAfter, err := bitSet.check([]uint{512})
if err != nil {
t.Fatal(err)
}
if !isSetAfter {
t.Fatal("Bit should be set")
}
err = bitSet.expire(3600)
if err != nil {
t.Fatal(err)
}
err = bitSet.del()
if err != nil {
t.Fatal(err)
}
}
func TestRedisBitSet_Add(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
t.Error("Miniredis could not start")
}
defer s.Close()
store := redis.NewRedis(s.Addr(), redis.NodeType)
filter := New(store, "test_key", 64)
assert.Nil(t, filter.Add([]byte("hello")))
assert.Nil(t, filter.Add([]byte("world")))
ok, err := filter.Exists([]byte("hello"))
assert.Nil(t, err)
assert.True(t, ok)
}

229
core/breaker/breaker.go Normal file
View File

@ -0,0 +1,229 @@
package breaker
import (
"errors"
"fmt"
"strings"
"sync"
"time"
"zero/core/mathx"
"zero/core/proc"
"zero/core/stat"
"zero/core/stringx"
)
const (
StateClosed State = iota
StateOpen
)
const (
numHistoryReasons = 5
timeFormat = "15:04:05"
)
// ErrServiceUnavailable is returned when the CB state is open
var ErrServiceUnavailable = errors.New("circuit breaker is open")
type (
State = int32
Acceptable func(err error) bool
Breaker interface {
// Name returns the name of the netflixBreaker.
Name() string
// Allow checks if the request is allowed.
// If allowed, a promise will be returned, the caller needs to call promise.Accept()
// on success, or call promise.Reject() on failure.
// If not allow, ErrServiceUnavailable will be returned.
Allow() (Promise, error)
// Do runs the given request if the netflixBreaker accepts it.
// Do returns an error instantly if the netflixBreaker rejects the request.
// If a panic occurs in the request, the netflixBreaker handles it as an error
// and causes the same panic again.
Do(req func() error) error
// DoWithAcceptable runs the given request if the netflixBreaker accepts it.
// Do returns an error instantly if the netflixBreaker rejects the request.
// If a panic occurs in the request, the netflixBreaker handles it as an error
// and causes the same panic again.
// acceptable checks if it's a successful call, even if the err is not nil.
DoWithAcceptable(req func() error, acceptable Acceptable) error
// DoWithFallback runs the given request if the netflixBreaker accepts it.
// DoWithFallback runs the fallback if the netflixBreaker rejects the request.
// If a panic occurs in the request, the netflixBreaker handles it as an error
// and causes the same panic again.
DoWithFallback(req func() error, fallback func(err error) error) error
// DoWithFallbackAcceptable runs the given request if the netflixBreaker accepts it.
// DoWithFallback runs the fallback if the netflixBreaker rejects the request.
// If a panic occurs in the request, the netflixBreaker handles it as an error
// and causes the same panic again.
// acceptable checks if it's a successful call, even if the err is not nil.
DoWithFallbackAcceptable(req func() error, fallback func(err error) error, acceptable Acceptable) error
}
BreakerOption func(breaker *circuitBreaker)
Promise interface {
Accept()
Reject(reason string)
}
internalPromise interface {
Accept()
Reject()
}
circuitBreaker struct {
name string
throttle
}
internalThrottle interface {
allow() (internalPromise, error)
doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
}
throttle interface {
allow() (Promise, error)
doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
}
)
func NewBreaker(opts ...BreakerOption) Breaker {
var b circuitBreaker
for _, opt := range opts {
opt(&b)
}
if len(b.name) == 0 {
b.name = stringx.Rand()
}
b.throttle = newLoggedThrottle(b.name, newGoogleBreaker())
return &b
}
func (cb *circuitBreaker) Allow() (Promise, error) {
return cb.throttle.allow()
}
func (cb *circuitBreaker) Do(req func() error) error {
return cb.throttle.doReq(req, nil, defaultAcceptable)
}
func (cb *circuitBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error {
return cb.throttle.doReq(req, nil, acceptable)
}
func (cb *circuitBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
return cb.throttle.doReq(req, fallback, defaultAcceptable)
}
func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
acceptable Acceptable) error {
return cb.throttle.doReq(req, fallback, acceptable)
}
func (cb *circuitBreaker) Name() string {
return cb.name
}
func WithName(name string) BreakerOption {
return func(b *circuitBreaker) {
b.name = name
}
}
func defaultAcceptable(err error) bool {
return err == nil
}
type loggedThrottle struct {
name string
internalThrottle
errWin *errorWindow
}
func newLoggedThrottle(name string, t internalThrottle) loggedThrottle {
return loggedThrottle{
name: name,
internalThrottle: t,
errWin: new(errorWindow),
}
}
func (lt loggedThrottle) allow() (Promise, error) {
promise, err := lt.internalThrottle.allow()
return promiseWithReason{
promise: promise,
errWin: lt.errWin,
}, lt.logError(err)
}
func (lt loggedThrottle) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool {
accept := acceptable(err)
if !accept {
lt.errWin.add(err.Error())
}
return accept
}))
}
func (lt loggedThrottle) logError(err error) error {
if err == ErrServiceUnavailable {
// if circuit open, not possible to have empty error window
stat.Report(fmt.Sprintf(
"proc(%s/%d), callee: %s, breaker is open and requests dropped\nlast errors:\n%s",
proc.ProcessName(), proc.Pid(), lt.name, lt.errWin))
}
return err
}
type errorWindow struct {
reasons [numHistoryReasons]string
index int
count int
lock sync.Mutex
}
func (ew *errorWindow) add(reason string) {
ew.lock.Lock()
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
ew.index = (ew.index + 1) % numHistoryReasons
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
ew.lock.Unlock()
}
func (ew *errorWindow) String() string {
var builder strings.Builder
ew.lock.Lock()
for i := ew.index + ew.count - 1; i >= ew.index; i-- {
builder.WriteString(ew.reasons[i%numHistoryReasons])
builder.WriteByte('\n')
}
ew.lock.Unlock()
return builder.String()
}
type promiseWithReason struct {
promise internalPromise
errWin *errorWindow
}
func (p promiseWithReason) Accept() {
p.promise.Accept()
}
func (p promiseWithReason) Reject(reason string) {
p.errWin.add(reason)
p.promise.Reject()
}

View File

@ -0,0 +1,44 @@
package breaker
import (
"errors"
"strconv"
"testing"
"zero/core/stat"
"github.com/stretchr/testify/assert"
)
func init() {
stat.SetReporter(nil)
}
func TestCircuitBreaker_Allow(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
_, err := b.Allow()
assert.Nil(t, err)
}
func TestLogReason(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 1000; i++ {
_ = b.Do(func() error {
return errors.New(strconv.Itoa(i))
})
}
errs := b.(*circuitBreaker).throttle.(loggedThrottle).errWin
assert.Equal(t, numHistoryReasons, errs.count)
}
func BenchmarkGoogleBreaker(b *testing.B) {
br := NewBreaker()
for i := 0; i < b.N; i++ {
_ = br.Do(func() error {
return nil
})
}
}

76
core/breaker/breakers.go Normal file
View File

@ -0,0 +1,76 @@
package breaker
import "sync"
var (
lock sync.RWMutex
breakers = make(map[string]Breaker)
)
func Do(name string, req func() error) error {
return do(name, func(b Breaker) error {
return b.Do(req)
})
}
func DoWithAcceptable(name string, req func() error, acceptable Acceptable) error {
return do(name, func(b Breaker) error {
return b.DoWithAcceptable(req, acceptable)
})
}
func DoWithFallback(name string, req func() error, fallback func(err error) error) error {
return do(name, func(b Breaker) error {
return b.DoWithFallback(req, fallback)
})
}
func DoWithFallbackAcceptable(name string, req func() error, fallback func(err error) error,
acceptable Acceptable) error {
return do(name, func(b Breaker) error {
return b.DoWithFallbackAcceptable(req, fallback, acceptable)
})
}
func GetBreaker(name string) Breaker {
lock.RLock()
b, ok := breakers[name]
lock.RUnlock()
if ok {
return b
}
lock.Lock()
defer lock.Unlock()
b = NewBreaker()
breakers[name] = b
return b
}
func NoBreakFor(name string) {
lock.Lock()
breakers[name] = newNoOpBreaker()
lock.Unlock()
}
func do(name string, execute func(b Breaker) error) error {
lock.RLock()
b, ok := breakers[name]
lock.RUnlock()
if ok {
return execute(b)
} else {
lock.Lock()
b, ok = breakers[name]
if ok {
lock.Unlock()
return execute(b)
} else {
b = NewBreaker(WithName(name))
breakers[name] = b
lock.Unlock()
return execute(b)
}
}
}

View File

@ -0,0 +1,115 @@
package breaker
import (
"errors"
"fmt"
"testing"
"zero/core/stat"
"github.com/stretchr/testify/assert"
)
func init() {
stat.SetReporter(nil)
}
func TestBreakersDo(t *testing.T) {
assert.Nil(t, Do("any", func() error {
return nil
}))
errDummy := errors.New("any")
assert.Equal(t, errDummy, Do("any", func() error {
return errDummy
}))
}
func TestBreakersDoWithAcceptable(t *testing.T) {
errDummy := errors.New("anyone")
for i := 0; i < 10000; i++ {
assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error {
return errDummy
}, func(err error) bool {
return err == nil || err == errDummy
}))
}
verify(t, func() bool {
return Do("anyone", func() error {
return nil
}) == nil
})
for i := 0; i < 10000; i++ {
err := DoWithAcceptable("another", func() error {
return errDummy
}, func(err error) bool {
return err == nil
})
assert.True(t, err == errDummy || err == ErrServiceUnavailable)
}
verify(t, func() bool {
return ErrServiceUnavailable == Do("another", func() error {
return nil
})
})
}
func TestBreakersNoBreakerFor(t *testing.T) {
NoBreakFor("any")
errDummy := errors.New("any")
for i := 0; i < 10000; i++ {
assert.Equal(t, errDummy, GetBreaker("any").Do(func() error {
return errDummy
}))
}
assert.Equal(t, nil, Do("any", func() error {
return nil
}))
}
func TestBreakersFallback(t *testing.T) {
errDummy := errors.New("any")
for i := 0; i < 10000; i++ {
err := DoWithFallback("fallback", func() error {
return errDummy
}, func(err error) error {
return nil
})
assert.True(t, err == nil || err == errDummy)
}
verify(t, func() bool {
return ErrServiceUnavailable == Do("fallback", func() error {
return nil
})
})
}
func TestBreakersAcceptableFallback(t *testing.T) {
errDummy := errors.New("any")
for i := 0; i < 10000; i++ {
err := DoWithFallbackAcceptable("acceptablefallback", func() error {
return errDummy
}, func(err error) error {
return nil
}, func(err error) bool {
return err == nil
})
assert.True(t, err == nil || err == errDummy)
}
verify(t, func() bool {
return ErrServiceUnavailable == Do("acceptablefallback", func() error {
return nil
})
})
}
func verify(t *testing.T, fn func() bool) {
var count int
for i := 0; i < 100; i++ {
if fn() {
count++
}
}
assert.True(t, count >= 80, fmt.Sprintf("should be greater than 80, actual %d", count))
}

View File

@ -0,0 +1,125 @@
package breaker
import (
"math"
"sync/atomic"
"time"
"zero/core/collection"
"zero/core/mathx"
)
const (
// 250ms for bucket duration
window = time.Second * 10
buckets = 40
k = 1.5
protection = 5
)
// googleBreaker is a netflixBreaker pattern from google.
// see Client-Side Throttling section in https://landing.google.com/sre/sre-book/chapters/handling-overload/
type googleBreaker struct {
k float64
state int32
stat *collection.RollingWindow
proba *mathx.Proba
}
func newGoogleBreaker() *googleBreaker {
bucketDuration := time.Duration(int64(window) / int64(buckets))
st := collection.NewRollingWindow(buckets, bucketDuration)
return &googleBreaker{
stat: st,
k: k,
state: StateClosed,
proba: mathx.NewProba(),
}
}
func (b *googleBreaker) accept() error {
accepts, total := b.history()
weightedAccepts := b.k * float64(accepts)
// https://landing.google.com/sre/sre-book/chapters/handling-overload/#eq2101
dropRatio := math.Max(0, (float64(total-protection)-weightedAccepts)/float64(total+1))
if dropRatio <= 0 {
if atomic.LoadInt32(&b.state) == StateOpen {
atomic.CompareAndSwapInt32(&b.state, StateOpen, StateClosed)
}
return nil
}
if atomic.LoadInt32(&b.state) == StateClosed {
atomic.CompareAndSwapInt32(&b.state, StateClosed, StateOpen)
}
if b.proba.TrueOnProba(dropRatio) {
return ErrServiceUnavailable
}
return nil
}
func (b *googleBreaker) allow() (internalPromise, error) {
if err := b.accept(); err != nil {
return nil, err
}
return googlePromise{
b: b,
}, nil
}
func (b *googleBreaker) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
if err := b.accept(); err != nil {
if fallback != nil {
return fallback(err)
} else {
return err
}
}
defer func() {
if e := recover(); e != nil {
b.markFailure()
panic(e)
}
}()
err := req()
if acceptable(err) {
b.markSuccess()
} else {
b.markFailure()
}
return err
}
func (b *googleBreaker) markSuccess() {
b.stat.Add(1)
}
func (b *googleBreaker) markFailure() {
b.stat.Add(0)
}
func (b *googleBreaker) history() (accepts int64, total int64) {
b.stat.Reduce(func(b *collection.Bucket) {
accepts += int64(b.Sum)
total += b.Count
})
return
}
type googlePromise struct {
b *googleBreaker
}
func (p googlePromise) Accept() {
p.b.markSuccess()
}
func (p googlePromise) Reject() {
p.b.markFailure()
}

View File

@ -0,0 +1,238 @@
package breaker
import (
"errors"
"math"
"math/rand"
"testing"
"time"
"zero/core/collection"
"zero/core/mathx"
"zero/core/stat"
"github.com/stretchr/testify/assert"
)
const (
testBuckets = 10
testInterval = time.Millisecond * 10
)
func init() {
stat.SetReporter(nil)
}
func getGoogleBreaker() *googleBreaker {
st := collection.NewRollingWindow(testBuckets, testInterval)
return &googleBreaker{
stat: st,
k: 5,
state: StateClosed,
proba: mathx.NewProba(),
}
}
func markSuccessWithDuration(b *googleBreaker, count int, sleep time.Duration) {
for i := 0; i < count; i++ {
b.markSuccess()
time.Sleep(sleep)
}
}
func markFailedWithDuration(b *googleBreaker, count int, sleep time.Duration) {
for i := 0; i < count; i++ {
b.markFailure()
time.Sleep(sleep)
}
}
func TestGoogleBreakerClose(t *testing.T) {
b := getGoogleBreaker()
markSuccess(b, 80)
assert.Nil(t, b.accept())
markSuccess(b, 120)
assert.Nil(t, b.accept())
}
func TestGoogleBreakerOpen(t *testing.T) {
b := getGoogleBreaker()
markSuccess(b, 10)
assert.Nil(t, b.accept())
markFailed(b, 100000)
time.Sleep(testInterval * 2)
verify(t, func() bool {
return b.accept() != nil
})
}
func TestGoogleBreakerFallback(t *testing.T) {
b := getGoogleBreaker()
markSuccess(b, 1)
assert.Nil(t, b.accept())
markFailed(b, 10000)
time.Sleep(testInterval * 2)
verify(t, func() bool {
return b.doReq(func() error {
return errors.New("any")
}, func(err error) error {
return nil
}, defaultAcceptable) == nil
})
}
func TestGoogleBreakerReject(t *testing.T) {
b := getGoogleBreaker()
markSuccess(b, 100)
assert.Nil(t, b.accept())
markFailed(b, 10000)
time.Sleep(testInterval)
assert.Equal(t, ErrServiceUnavailable, b.doReq(func() error {
return ErrServiceUnavailable
}, nil, defaultAcceptable))
}
func TestGoogleBreakerAcceptable(t *testing.T) {
b := getGoogleBreaker()
errAcceptable := errors.New("any")
assert.Equal(t, errAcceptable, b.doReq(func() error {
return errAcceptable
}, nil, func(err error) bool {
return err == errAcceptable
}))
}
func TestGoogleBreakerNotAcceptable(t *testing.T) {
b := getGoogleBreaker()
errAcceptable := errors.New("any")
assert.Equal(t, errAcceptable, b.doReq(func() error {
return errAcceptable
}, nil, func(err error) bool {
return err != errAcceptable
}))
}
func TestGoogleBreakerPanic(t *testing.T) {
b := getGoogleBreaker()
assert.Panics(t, func() {
_ = b.doReq(func() error {
panic("fail")
}, nil, defaultAcceptable)
})
}
func TestGoogleBreakerHalfOpen(t *testing.T) {
b := getGoogleBreaker()
assert.Nil(t, b.accept())
t.Run("accept single failed/accept", func(t *testing.T) {
markFailed(b, 10000)
time.Sleep(testInterval * 2)
verify(t, func() bool {
return b.accept() != nil
})
})
t.Run("accept single failed/allow", func(t *testing.T) {
markFailed(b, 10000)
time.Sleep(testInterval * 2)
verify(t, func() bool {
_, err := b.allow()
return err != nil
})
})
time.Sleep(testInterval * testBuckets)
t.Run("accept single succeed", func(t *testing.T) {
assert.Nil(t, b.accept())
markSuccess(b, 10000)
verify(t, func() bool {
return b.accept() == nil
})
})
}
func TestGoogleBreakerSelfProtection(t *testing.T) {
t.Run("total request < 100", func(t *testing.T) {
b := getGoogleBreaker()
markFailed(b, 4)
time.Sleep(testInterval)
assert.Nil(t, b.accept())
})
t.Run("total request > 100, total < 2 * success", func(t *testing.T) {
b := getGoogleBreaker()
size := rand.Intn(10000)
accepts := int(math.Ceil(float64(size))) + 1
markSuccess(b, accepts)
markFailed(b, size-accepts)
assert.Nil(t, b.accept())
})
}
func TestGoogleBreakerHistory(t *testing.T) {
var b *googleBreaker
var accepts, total int64
sleep := testInterval
t.Run("accepts == total", func(t *testing.T) {
b = getGoogleBreaker()
markSuccessWithDuration(b, 10, sleep/2)
accepts, total = b.history()
assert.Equal(t, int64(10), accepts)
assert.Equal(t, int64(10), total)
})
t.Run("fail == total", func(t *testing.T) {
b = getGoogleBreaker()
markFailedWithDuration(b, 10, sleep/2)
accepts, total = b.history()
assert.Equal(t, int64(0), accepts)
assert.Equal(t, int64(10), total)
})
t.Run("accepts = 1/2 * total, fail = 1/2 * total", func(t *testing.T) {
b = getGoogleBreaker()
markFailedWithDuration(b, 5, sleep/2)
markSuccessWithDuration(b, 5, sleep/2)
accepts, total = b.history()
assert.Equal(t, int64(5), accepts)
assert.Equal(t, int64(10), total)
})
t.Run("auto reset rolling counter", func(t *testing.T) {
b = getGoogleBreaker()
time.Sleep(testInterval * testBuckets)
accepts, total = b.history()
assert.Equal(t, int64(0), accepts)
assert.Equal(t, int64(0), total)
})
}
func BenchmarkGoogleBreakerAllow(b *testing.B) {
breaker := getGoogleBreaker()
b.ResetTimer()
for i := 0; i <= b.N; i++ {
breaker.accept()
if i%2 == 0 {
breaker.markSuccess()
} else {
breaker.markFailure()
}
}
}
func markSuccess(b *googleBreaker, count int) {
for i := 0; i < count; i++ {
p, err := b.allow()
if err != nil {
break
}
p.Accept()
}
}
func markFailed(b *googleBreaker, count int) {
for i := 0; i < count; i++ {
p, err := b.allow()
if err == nil {
p.Reject()
}
}
}

View File

@ -0,0 +1,42 @@
package breaker
const noOpBreakerName = "nopBreaker"
type noOpBreaker struct{}
func newNoOpBreaker() Breaker {
return noOpBreaker{}
}
func (b noOpBreaker) Name() string {
return noOpBreakerName
}
func (b noOpBreaker) Allow() (Promise, error) {
return nopPromise{}, nil
}
func (b noOpBreaker) Do(req func() error) error {
return req()
}
func (b noOpBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error {
return req()
}
func (b noOpBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
return req()
}
func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
acceptable Acceptable) error {
return req()
}
type nopPromise struct{}
func (p nopPromise) Accept() {
}
func (p nopPromise) Reject(reason string) {
}

View File

@ -0,0 +1,38 @@
package breaker
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNopBreaker(t *testing.T) {
b := newNoOpBreaker()
assert.Equal(t, noOpBreakerName, b.Name())
p, err := b.Allow()
assert.Nil(t, err)
p.Accept()
for i := 0; i < 1000; i++ {
p, err := b.Allow()
assert.Nil(t, err)
p.Reject("any")
}
assert.Nil(t, b.Do(func() error {
return nil
}))
assert.Nil(t, b.DoWithAcceptable(func() error {
return nil
}, defaultAcceptable))
errDummy := errors.New("any")
assert.Equal(t, errDummy, b.DoWithFallback(func() error {
return errDummy
}, func(err error) error {
return nil
}))
assert.Equal(t, errDummy, b.DoWithFallbackAcceptable(func() error {
return errDummy
}, func(err error) error {
return nil
}, defaultAcceptable))
}

19
core/cmdline/input.go Normal file
View File

@ -0,0 +1,19 @@
package cmdline
import (
"bufio"
"fmt"
"os"
"strings"
)
func EnterToContinue() {
fmt.Print("Press 'Enter' to continue...")
bufio.NewReader(os.Stdin).ReadBytes('\n')
}
func ReadLine(prompt string) string {
fmt.Print(prompt)
input, _ := bufio.NewReader(os.Stdin).ReadString('\n')
return strings.TrimSpace(input)
}

174
core/codec/aesecb.go Normal file
View File

@ -0,0 +1,174 @@
package codec
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"errors"
"zero/core/logx"
)
var ErrPaddingSize = errors.New("padding size error")
type ecb struct {
b cipher.Block
blockSize int
}
func newECB(b cipher.Block) *ecb {
return &ecb{
b: b,
blockSize: b.BlockSize(),
}
}
type ecbEncrypter ecb
func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
return (*ecbEncrypter)(newECB(b))
}
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
// why we don't return error is because cipher.BlockMode doesn't allow this
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
logx.Error("crypto/cipher: input not full blocks")
return
}
if len(dst) < len(src) {
logx.Error("crypto/cipher: output smaller than input")
return
}
for len(src) > 0 {
x.b.Encrypt(dst, src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}
type ecbDecrypter ecb
func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
return (*ecbDecrypter)(newECB(b))
}
func (x *ecbDecrypter) BlockSize() int {
return x.blockSize
}
// why we don't return error is because cipher.BlockMode doesn't allow this
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
logx.Error("crypto/cipher: input not full blocks")
return
}
if len(dst) < len(src) {
logx.Error("crypto/cipher: output smaller than input")
return
}
for len(src) > 0 {
x.b.Decrypt(dst, src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}
func EcbDecrypt(key, src []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
logx.Errorf("Decrypt key error: % x", key)
return nil, err
}
decrypter := NewECBDecrypter(block)
decrypted := make([]byte, len(src))
decrypter.CryptBlocks(decrypted, src)
return pkcs5Unpadding(decrypted, decrypter.BlockSize())
}
func EcbDecryptBase64(key, src string) (string, error) {
keyBytes, err := getKeyBytes(key)
if err != nil {
return "", err
}
encryptedBytes, err := base64.StdEncoding.DecodeString(src)
if err != nil {
return "", err
}
decryptedBytes, err := EcbDecrypt(keyBytes, encryptedBytes)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(decryptedBytes), nil
}
func EcbEncrypt(key, src []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
logx.Errorf("Encrypt key error: % x", key)
return nil, err
}
padded := pkcs5Padding(src, block.BlockSize())
crypted := make([]byte, len(padded))
encrypter := NewECBEncrypter(block)
encrypter.CryptBlocks(crypted, padded)
return crypted, nil
}
func EcbEncryptBase64(key, src string) (string, error) {
keyBytes, err := getKeyBytes(key)
if err != nil {
return "", err
}
srcBytes, err := base64.StdEncoding.DecodeString(src)
if err != nil {
return "", err
}
encryptedBytes, err := EcbEncrypt(keyBytes, srcBytes)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(encryptedBytes), nil
}
func getKeyBytes(key string) ([]byte, error) {
if len(key) > 32 {
if keyBytes, err := base64.StdEncoding.DecodeString(key); err != nil {
return nil, err
} else {
return keyBytes, nil
}
}
return []byte(key), nil
}
func pkcs5Padding(ciphertext []byte, blockSize int) []byte {
padding := blockSize - len(ciphertext)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padtext...)
}
func pkcs5Unpadding(src []byte, blockSize int) ([]byte, error) {
length := len(src)
unpadding := int(src[length-1])
if unpadding >= length || unpadding > blockSize {
return nil, ErrPaddingSize
}
return src[:length-unpadding], nil
}

88
core/codec/dh.go Normal file
View File

@ -0,0 +1,88 @@
package codec
import (
"crypto/rand"
"errors"
"math/big"
)
// see https://www.zhihu.com/question/29383090/answer/70435297
// see https://www.ietf.org/rfc/rfc3526.txt
// 2048-bit MODP Group
var (
ErrInvalidPriKey = errors.New("invalid private key")
ErrInvalidPubKey = errors.New("invalid public key")
ErrPubKeyOutOfBound = errors.New("public key out of bound")
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
g, _ = new(big.Int).SetString("2", 16)
zero = big.NewInt(0)
)
type DhKey struct {
PriKey *big.Int
PubKey *big.Int
}
func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) {
if pubKey == nil {
return nil, ErrInvalidPubKey
}
if pubKey.Sign() <= 0 && p.Cmp(pubKey) <= 0 {
return nil, ErrPubKeyOutOfBound
}
if priKey == nil {
return nil, ErrInvalidPriKey
}
return new(big.Int).Exp(pubKey, priKey, p), nil
}
func GenerateKey() (*DhKey, error) {
var err error
var x *big.Int
for {
x, err = rand.Int(rand.Reader, p)
if err != nil {
return nil, err
}
if zero.Cmp(x) < 0 {
break
}
}
key := new(DhKey)
key.PriKey = x
key.PubKey = new(big.Int).Exp(g, x, p)
return key, nil
}
func NewPublicKey(bs []byte) *big.Int {
return new(big.Int).SetBytes(bs)
}
func (k *DhKey) Bytes() []byte {
if k.PubKey == nil {
return nil
}
byteLen := (p.BitLen() + 7) >> 3
ret := make([]byte, byteLen)
copyWithLeftPad(ret, k.PubKey.Bytes())
return ret
}
func copyWithLeftPad(dst, src []byte) {
padBytes := len(dst) - len(src)
for i := 0; i < padBytes; i++ {
dst[i] = 0
}
copy(dst[padBytes:], src)
}

73
core/codec/dh_test.go Normal file
View File

@ -0,0 +1,73 @@
package codec
import (
"math/big"
"testing"
"github.com/stretchr/testify/assert"
)
func TestDiffieHellman(t *testing.T) {
key1, err := GenerateKey()
assert.Nil(t, err)
key2, err := GenerateKey()
assert.Nil(t, err)
pubKey1, err := ComputeKey(key1.PubKey, key2.PriKey)
assert.Nil(t, err)
pubKey2, err := ComputeKey(key2.PubKey, key1.PriKey)
assert.Nil(t, err)
assert.Equal(t, pubKey1, pubKey2)
}
func TestDiffieHellman1024(t *testing.T) {
old := p
p, _ = new(big.Int).SetString("F488FD584E49DBCD20B49DE49107366B336C380D451D0F7C88B31C7C5B2D8EF6F3C923C043F0A55B188D8EBB558CB85D38D334FD7C175743A31D186CDE33212CB52AFF3CE1B1294018118D7C84A70A72D686C40319C807297ACA950CD9969FABD00A509B0246D3083D66A45D419F9C7CBD894B221926BAABA25EC355E92F78C7", 16)
defer func() {
p = old
}()
key1, err := GenerateKey()
assert.Nil(t, err)
key2, err := GenerateKey()
assert.Nil(t, err)
pubKey1, err := ComputeKey(key1.PubKey, key2.PriKey)
assert.Nil(t, err)
pubKey2, err := ComputeKey(key2.PubKey, key1.PriKey)
assert.Nil(t, err)
assert.Equal(t, pubKey1, pubKey2)
}
func TestDiffieHellmanMiddleManAttack(t *testing.T) {
key1, err := GenerateKey()
assert.Nil(t, err)
keyMiddle, err := GenerateKey()
assert.Nil(t, err)
key2, err := GenerateKey()
assert.Nil(t, err)
const aesByteLen = 32
pubKey1, err := ComputeKey(keyMiddle.PubKey, key1.PriKey)
assert.Nil(t, err)
src := []byte(`hello, world!`)
encryptedSrc, err := EcbEncrypt(pubKey1.Bytes()[:aesByteLen], src)
assert.Nil(t, err)
pubKeyMiddle, err := ComputeKey(key1.PubKey, keyMiddle.PriKey)
assert.Nil(t, err)
decryptedSrc, err := EcbDecrypt(pubKeyMiddle.Bytes()[:aesByteLen], encryptedSrc)
assert.Nil(t, err)
assert.Equal(t, string(src), string(decryptedSrc))
pubKeyMiddle, err = ComputeKey(key2.PubKey, keyMiddle.PriKey)
assert.Nil(t, err)
encryptedSrc, err = EcbEncrypt(pubKeyMiddle.Bytes()[:aesByteLen], decryptedSrc)
assert.Nil(t, err)
pubKey2, err := ComputeKey(keyMiddle.PubKey, key2.PriKey)
assert.Nil(t, err)
decryptedSrc, err = EcbDecrypt(pubKey2.Bytes()[:aesByteLen], encryptedSrc)
assert.Nil(t, err)
assert.Equal(t, string(src), string(decryptedSrc))
}

33
core/codec/gzip.go Normal file
View File

@ -0,0 +1,33 @@
package codec
import (
"bytes"
"compress/gzip"
"io"
)
func Gzip(bs []byte) []byte {
var b bytes.Buffer
w := gzip.NewWriter(&b)
w.Write(bs)
w.Close()
return b.Bytes()
}
func Gunzip(bs []byte) ([]byte, error) {
r, err := gzip.NewReader(bytes.NewBuffer(bs))
if err != nil {
return nil, err
}
defer r.Close()
var c bytes.Buffer
_, err = io.Copy(&c, r)
if err != nil {
return nil, err
}
return c.Bytes(), nil
}

23
core/codec/gzip_test.go Normal file
View File

@ -0,0 +1,23 @@
package codec
import (
"bytes"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGzip(t *testing.T) {
var buf bytes.Buffer
for i := 0; i < 10000; i++ {
fmt.Fprint(&buf, i)
}
bs := Gzip(buf.Bytes())
actual, err := Gunzip(bs)
assert.Nil(t, err)
assert.True(t, len(bs) < buf.Len())
assert.Equal(t, buf.Bytes(), actual)
}

18
core/codec/hmac.go Normal file
View File

@ -0,0 +1,18 @@
package codec
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"io"
)
func Hmac(key []byte, body string) []byte {
h := hmac.New(sha256.New, key)
io.WriteString(h, body)
return h.Sum(nil)
}
func HmacBase64(key []byte, body string) string {
return base64.StdEncoding.EncodeToString(Hmac(key, body))
}

149
core/codec/rsa.go Normal file
View File

@ -0,0 +1,149 @@
package codec
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"io/ioutil"
)
var (
ErrPrivateKey = errors.New("private key error")
ErrPublicKey = errors.New("failed to parse PEM block containing the public key")
ErrNotRsaKey = errors.New("key type is not RSA")
)
type (
RsaDecrypter interface {
Decrypt(input []byte) ([]byte, error)
DecryptBase64(input string) ([]byte, error)
}
RsaEncrypter interface {
Encrypt(input []byte) ([]byte, error)
}
rsaBase struct {
bytesLimit int
}
rsaDecrypter struct {
rsaBase
privateKey *rsa.PrivateKey
}
rsaEncrypter struct {
rsaBase
publicKey *rsa.PublicKey
}
)
func NewRsaDecrypter(file string) (RsaDecrypter, error) {
content, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
block, _ := pem.Decode(content)
if block == nil {
return nil, ErrPrivateKey
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
return &rsaDecrypter{
rsaBase: rsaBase{
bytesLimit: privateKey.N.BitLen() >> 3,
},
privateKey: privateKey,
}, nil
}
func (r *rsaDecrypter) Decrypt(input []byte) ([]byte, error) {
return r.crypt(input, func(block []byte) ([]byte, error) {
return rsaDecryptBlock(r.privateKey, block)
})
}
func (r *rsaDecrypter) DecryptBase64(input string) ([]byte, error) {
if len(input) == 0 {
return nil, nil
}
base64Decoded, err := base64.StdEncoding.DecodeString(input)
if err != nil {
return nil, err
}
return r.Decrypt(base64Decoded)
}
func NewRsaEncrypter(key []byte) (RsaEncrypter, error) {
block, _ := pem.Decode(key)
if block == nil {
return nil, ErrPublicKey
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, err
}
switch pubKey := pub.(type) {
case *rsa.PublicKey:
return &rsaEncrypter{
rsaBase: rsaBase{
// https://www.ietf.org/rfc/rfc2313.txt
// The length of the data D shall not be more than k-11 octets, which is
// positive since the length k of the modulus is at least 12 octets.
bytesLimit: (pubKey.N.BitLen() >> 3) - 11,
},
publicKey: pubKey,
}, nil
default:
return nil, ErrNotRsaKey
}
}
func (r *rsaEncrypter) Encrypt(input []byte) ([]byte, error) {
return r.crypt(input, func(block []byte) ([]byte, error) {
return rsaEncryptBlock(r.publicKey, block)
})
}
func (r *rsaBase) crypt(input []byte, cryptFn func([]byte) ([]byte, error)) ([]byte, error) {
var result []byte
inputLen := len(input)
for i := 0; i*r.bytesLimit < inputLen; i++ {
start := r.bytesLimit * i
var stop int
if r.bytesLimit*(i+1) > inputLen {
stop = inputLen
} else {
stop = r.bytesLimit * (i + 1)
}
bs, err := cryptFn(input[start:stop])
if err != nil {
return nil, err
}
result = append(result, bs...)
}
return result, nil
}
func rsaDecryptBlock(privateKey *rsa.PrivateKey, block []byte) ([]byte, error) {
return rsa.DecryptPKCS1v15(rand.Reader, privateKey, block)
}
func rsaEncryptBlock(publicKey *rsa.PublicKey, msg []byte) ([]byte, error) {
return rsa.EncryptPKCS1v15(rand.Reader, publicKey, msg)
}

275
core/collection/cache.go Normal file
View File

@ -0,0 +1,275 @@
package collection
import (
"container/list"
"sync"
"sync/atomic"
"time"
"zero/core/logx"
"zero/core/mathx"
"zero/core/syncx"
)
const (
defaultCacheName = "proc"
slots = 300
statInterval = time.Minute
// make the expiry unstable to avoid lots of cached items expire at the same time
// make the unstable expiry to be [0.95, 1.05] * seconds
expiryDeviation = 0.05
)
var emptyLruCache = emptyLru{}
type (
CacheOption func(cache *Cache)
Cache struct {
name string
lock sync.Mutex
data map[string]interface{}
evicts *list.List
expire time.Duration
timingWheel *TimingWheel
lruCache lru
barrier syncx.SharedCalls
unstableExpiry mathx.Unstable
stats *cacheStat
}
)
func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
cache := &Cache{
data: make(map[string]interface{}),
expire: expire,
lruCache: emptyLruCache,
barrier: syncx.NewSharedCalls(),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
}
for _, opt := range opts {
opt(cache)
}
if len(cache.name) == 0 {
cache.name = defaultCacheName
}
cache.stats = newCacheStat(cache.name, cache.size)
timingWheel, err := NewTimingWheel(time.Second, slots, func(k, v interface{}) {
key, ok := k.(string)
if !ok {
return
}
cache.Del(key)
})
if err != nil {
return nil, err
}
cache.timingWheel = timingWheel
return cache, nil
}
func (c *Cache) Del(key string) {
c.lock.Lock()
delete(c.data, key)
c.lruCache.remove(key)
c.lock.Unlock()
c.timingWheel.RemoveTimer(key)
}
func (c *Cache) Get(key string) (interface{}, bool) {
c.lock.Lock()
value, ok := c.data[key]
if ok {
c.lruCache.add(key)
}
c.lock.Unlock()
if ok {
c.stats.IncrementHit()
} else {
c.stats.IncrementMiss()
}
return value, ok
}
func (c *Cache) Set(key string, value interface{}) {
c.lock.Lock()
_, ok := c.data[key]
c.data[key] = value
c.lruCache.add(key)
c.lock.Unlock()
expiry := c.unstableExpiry.AroundDuration(c.expire)
if ok {
c.timingWheel.MoveTimer(key, expiry)
} else {
c.timingWheel.SetTimer(key, value, expiry)
}
}
func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}, error) {
val, fresh, err := c.barrier.DoEx(key, func() (interface{}, error) {
v, e := fetch()
if e != nil {
return nil, e
}
c.Set(key, v)
return v, nil
})
if err != nil {
return nil, err
}
if fresh {
c.stats.IncrementMiss()
return val, nil
} else {
// got the result from previous ongoing query
c.stats.IncrementHit()
}
return val, nil
}
func (c *Cache) onEvict(key string) {
// already locked
delete(c.data, key)
c.timingWheel.RemoveTimer(key)
}
func (c *Cache) size() int {
c.lock.Lock()
defer c.lock.Unlock()
return len(c.data)
}
func WithLimit(limit int) CacheOption {
return func(cache *Cache) {
if limit > 0 {
cache.lruCache = newKeyLru(limit, cache.onEvict)
}
}
}
func WithName(name string) CacheOption {
return func(cache *Cache) {
cache.name = name
}
}
type (
lru interface {
add(key string)
remove(key string)
}
emptyLru struct{}
keyLru struct {
limit int
evicts *list.List
elements map[string]*list.Element
onEvict func(key string)
}
)
func (elru emptyLru) add(string) {
}
func (elru emptyLru) remove(string) {
}
func newKeyLru(limit int, onEvict func(key string)) *keyLru {
return &keyLru{
limit: limit,
evicts: list.New(),
elements: make(map[string]*list.Element),
onEvict: onEvict,
}
}
func (klru *keyLru) add(key string) {
if elem, ok := klru.elements[key]; ok {
klru.evicts.MoveToFront(elem)
return
}
// Add new item
elem := klru.evicts.PushFront(key)
klru.elements[key] = elem
// Verify size not exceeded
if klru.evicts.Len() > klru.limit {
klru.removeOldest()
}
}
func (klru *keyLru) remove(key string) {
if elem, ok := klru.elements[key]; ok {
klru.removeElement(elem)
}
}
func (klru *keyLru) removeOldest() {
elem := klru.evicts.Back()
if elem != nil {
klru.removeElement(elem)
}
}
func (klru *keyLru) removeElement(e *list.Element) {
klru.evicts.Remove(e)
key := e.Value.(string)
delete(klru.elements, key)
klru.onEvict(key)
}
type cacheStat struct {
name string
hit uint64
miss uint64
sizeCallback func() int
}
func newCacheStat(name string, sizeCallback func() int) *cacheStat {
st := &cacheStat{
name: name,
sizeCallback: sizeCallback,
}
go st.statLoop()
return st
}
func (cs *cacheStat) IncrementHit() {
atomic.AddUint64(&cs.hit, 1)
}
func (cs *cacheStat) IncrementMiss() {
atomic.AddUint64(&cs.miss, 1)
}
func (cs *cacheStat) statLoop() {
ticker := time.NewTicker(statInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
hit := atomic.SwapUint64(&cs.hit, 0)
miss := atomic.SwapUint64(&cs.miss, 0)
total := hit + miss
if total == 0 {
continue
}
percent := 100 * float32(hit) / float32(total)
logx.Statf("cache(%s) - qpm: %d, hit_ratio: %.1f%%, elements: %d, hit: %d, miss: %d",
cs.name, total, percent, cs.sizeCallback(), hit, miss)
}
}
}

View File

@ -0,0 +1,139 @@
package collection
import (
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCacheSet(t *testing.T) {
cache, err := NewCache(time.Second*2, WithName("any"))
assert.Nil(t, err)
cache.Set("first", "first element")
cache.Set("second", "second element")
value, ok := cache.Get("first")
assert.True(t, ok)
assert.Equal(t, "first element", value)
value, ok = cache.Get("second")
assert.True(t, ok)
assert.Equal(t, "second element", value)
}
func TestCacheDel(t *testing.T) {
cache, err := NewCache(time.Second * 2)
assert.Nil(t, err)
cache.Set("first", "first element")
cache.Set("second", "second element")
cache.Del("first")
_, ok := cache.Get("first")
assert.False(t, ok)
value, ok := cache.Get("second")
assert.True(t, ok)
assert.Equal(t, "second element", value)
}
func TestCacheTake(t *testing.T) {
cache, err := NewCache(time.Second * 2)
assert.Nil(t, err)
var count int32
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100)
return "first element", nil
})
wg.Done()
}()
}
wg.Wait()
assert.Equal(t, 1, cache.size())
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
}
func TestCacheWithLruEvicts(t *testing.T) {
cache, err := NewCache(time.Minute, WithLimit(3))
assert.Nil(t, err)
cache.Set("first", "first element")
cache.Set("second", "second element")
cache.Set("third", "third element")
cache.Set("fourth", "fourth element")
value, ok := cache.Get("first")
assert.False(t, ok)
value, ok = cache.Get("second")
assert.True(t, ok)
assert.Equal(t, "second element", value)
value, ok = cache.Get("third")
assert.True(t, ok)
assert.Equal(t, "third element", value)
value, ok = cache.Get("fourth")
assert.True(t, ok)
assert.Equal(t, "fourth element", value)
}
func TestCacheWithLruEvicted(t *testing.T) {
cache, err := NewCache(time.Minute, WithLimit(3))
assert.Nil(t, err)
cache.Set("first", "first element")
cache.Set("second", "second element")
cache.Set("third", "third element")
cache.Set("fourth", "fourth element")
value, ok := cache.Get("first")
assert.False(t, ok)
value, ok = cache.Get("second")
assert.True(t, ok)
assert.Equal(t, "second element", value)
cache.Set("fifth", "fifth element")
cache.Set("sixth", "sixth element")
_, ok = cache.Get("third")
assert.False(t, ok)
_, ok = cache.Get("fourth")
assert.False(t, ok)
value, ok = cache.Get("second")
assert.True(t, ok)
assert.Equal(t, "second element", value)
}
func BenchmarkCache(b *testing.B) {
cache, err := NewCache(time.Second*5, WithLimit(100000))
if err != nil {
b.Fatal(err)
}
for i := 0; i < 10000; i++ {
for j := 0; j < 10; j++ {
index := strconv.Itoa(i*10000 + j)
cache.Set("key:"+index, "value:"+index)
}
}
time.Sleep(time.Second * 5)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
for i := 0; i < b.N; i++ {
index := strconv.Itoa(i % 10000)
cache.Get("key:" + index)
if i%100 == 0 {
cache.Set("key1:"+index, "value1:"+index)
}
}
}
})
}

60
core/collection/fifo.go Normal file
View File

@ -0,0 +1,60 @@
package collection
import "sync"
type Queue struct {
lock sync.Mutex
elements []interface{}
size int
head int
tail int
count int
}
func NewQueue(size int) *Queue {
return &Queue{
elements: make([]interface{}, size),
size: size,
}
}
func (q *Queue) Empty() bool {
q.lock.Lock()
empty := q.count == 0
q.lock.Unlock()
return empty
}
func (q *Queue) Put(element interface{}) {
q.lock.Lock()
defer q.lock.Unlock()
if q.head == q.tail && q.count > 0 {
nodes := make([]interface{}, len(q.elements)+q.size)
copy(nodes, q.elements[q.head:])
copy(nodes[len(q.elements)-q.head:], q.elements[:q.head])
q.head = 0
q.tail = len(q.elements)
q.elements = nodes
}
q.elements[q.tail] = element
q.tail = (q.tail + 1) % len(q.elements)
q.count++
}
func (q *Queue) Take() (interface{}, bool) {
q.lock.Lock()
defer q.lock.Unlock()
if q.count == 0 {
return nil, false
}
element := q.elements[q.head]
q.head = (q.head + 1) % len(q.elements)
q.count--
return element, true
}

View File

@ -0,0 +1,63 @@
package collection
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFifo(t *testing.T) {
elements := [][]byte{
[]byte("hello"),
[]byte("world"),
[]byte("again"),
}
queue := NewQueue(8)
for i := range elements {
queue.Put(elements[i])
}
for _, element := range elements {
body, ok := queue.Take()
assert.True(t, ok)
assert.Equal(t, string(element), string(body.([]byte)))
}
}
func TestTakeTooMany(t *testing.T) {
elements := [][]byte{
[]byte("hello"),
[]byte("world"),
[]byte("again"),
}
queue := NewQueue(8)
for i := range elements {
queue.Put(elements[i])
}
for range elements {
queue.Take()
}
assert.True(t, queue.Empty())
_, ok := queue.Take()
assert.False(t, ok)
}
func TestPutMore(t *testing.T) {
elements := [][]byte{
[]byte("hello"),
[]byte("world"),
[]byte("again"),
}
queue := NewQueue(2)
for i := range elements {
queue.Put(elements[i])
}
for _, element := range elements {
body, ok := queue.Take()
assert.True(t, ok)
assert.Equal(t, string(element), string(body.([]byte)))
}
}

35
core/collection/ring.go Normal file
View File

@ -0,0 +1,35 @@
package collection
type Ring struct {
elements []interface{}
index int
}
func NewRing(n int) *Ring {
return &Ring{
elements: make([]interface{}, n),
}
}
func (r *Ring) Add(v interface{}) {
r.elements[r.index%len(r.elements)] = v
r.index++
}
func (r *Ring) Take() []interface{} {
var size int
var start int
if r.index > len(r.elements) {
size = len(r.elements)
start = r.index % len(r.elements)
} else {
size = r.index
}
elements := make([]interface{}, size)
for i := 0; i < size; i++ {
elements[i] = r.elements[(start+i)%len(r.elements)]
}
return elements
}

View File

@ -0,0 +1,25 @@
package collection
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRingLess(t *testing.T) {
ring := NewRing(5)
for i := 0; i < 3; i++ {
ring.Add(i)
}
elements := ring.Take()
assert.ElementsMatch(t, []interface{}{0, 1, 2}, elements)
}
func TestRingMore(t *testing.T) {
ring := NewRing(5)
for i := 0; i < 11; i++ {
ring.Add(i)
}
elements := ring.Take()
assert.ElementsMatch(t, []interface{}{6, 7, 8, 9, 10}, elements)
}

View File

@ -0,0 +1,145 @@
package collection
import (
"sync"
"time"
"zero/core/timex"
)
type (
RollingWindowOption func(rollingWindow *RollingWindow)
RollingWindow struct {
lock sync.RWMutex
size int
win *window
interval time.Duration
offset int
ignoreCurrent bool
lastTime time.Duration
}
)
func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow {
w := &RollingWindow{
size: size,
win: newWindow(size),
interval: interval,
lastTime: timex.Now(),
}
for _, opt := range opts {
opt(w)
}
return w
}
func (rw *RollingWindow) Add(v float64) {
rw.lock.Lock()
defer rw.lock.Unlock()
rw.updateOffset()
rw.win.add(rw.offset, v)
}
func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
rw.lock.RLock()
defer rw.lock.RUnlock()
var diff int
span := rw.span()
// ignore current bucket, because of partial data
if span == 0 && rw.ignoreCurrent {
diff = rw.size - 1
} else {
diff = rw.size - span
}
if diff > 0 {
offset := (rw.offset + span + 1) % rw.size
rw.win.reduce(offset, diff, fn)
}
}
func (rw *RollingWindow) span() int {
offset := int(timex.Since(rw.lastTime) / rw.interval)
if 0 <= offset && offset < rw.size {
return offset
} else {
return rw.size
}
}
func (rw *RollingWindow) updateOffset() {
span := rw.span()
if span > 0 {
offset := rw.offset
// reset expired buckets
start := offset + 1
steps := start + span
var remainder int
if steps > rw.size {
remainder = steps - rw.size
steps = rw.size
}
for i := start; i < steps; i++ {
rw.win.resetBucket(i)
offset = i
}
for i := 0; i < remainder; i++ {
rw.win.resetBucket(i)
offset = i
}
rw.offset = offset
rw.lastTime = timex.Now()
}
}
type Bucket struct {
Sum float64
Count int64
}
func (b *Bucket) add(v float64) {
b.Sum += v
b.Count++
}
func (b *Bucket) reset() {
b.Sum = 0
b.Count = 0
}
type window struct {
buckets []*Bucket
size int
}
func newWindow(size int) *window {
var buckets []*Bucket
for i := 0; i < size; i++ {
buckets = append(buckets, new(Bucket))
}
return &window{
buckets: buckets,
size: size,
}
}
func (w *window) add(offset int, v float64) {
w.buckets[offset%w.size].add(v)
}
func (w *window) reduce(start, count int, fn func(b *Bucket)) {
for i := 0; i < count; i++ {
fn(w.buckets[(start+i)%len(w.buckets)])
}
}
func (w *window) resetBucket(offset int) {
w.buckets[offset].reset()
}
func IgnoreCurrentBucket() RollingWindowOption {
return func(w *RollingWindow) {
w.ignoreCurrent = true
}
}

View File

@ -0,0 +1,133 @@
package collection
import (
"math/rand"
"testing"
"time"
"zero/core/stringx"
"github.com/stretchr/testify/assert"
)
const duration = time.Millisecond * 50
func TestRollingWindowAdd(t *testing.T) {
const size = 3
r := NewRollingWindow(size, duration)
listBuckets := func() []float64 {
var buckets []float64
r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum)
})
return buckets
}
assert.Equal(t, []float64{0, 0, 0}, listBuckets())
r.Add(1)
assert.Equal(t, []float64{0, 0, 1}, listBuckets())
elapse()
r.Add(2)
r.Add(3)
assert.Equal(t, []float64{0, 1, 5}, listBuckets())
elapse()
r.Add(4)
r.Add(5)
r.Add(6)
assert.Equal(t, []float64{1, 5, 15}, listBuckets())
elapse()
r.Add(7)
assert.Equal(t, []float64{5, 15, 7}, listBuckets())
}
func TestRollingWindowReset(t *testing.T) {
const size = 3
r := NewRollingWindow(size, duration, IgnoreCurrentBucket())
listBuckets := func() []float64 {
var buckets []float64
r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum)
})
return buckets
}
r.Add(1)
elapse()
assert.Equal(t, []float64{0, 1}, listBuckets())
elapse()
assert.Equal(t, []float64{1}, listBuckets())
elapse()
assert.Nil(t, listBuckets())
// cross window
r.Add(1)
time.Sleep(duration * 10)
assert.Nil(t, listBuckets())
}
func TestRollingWindowReduce(t *testing.T) {
const size = 4
tests := []struct {
win *RollingWindow
expect float64
}{
{
win: NewRollingWindow(size, duration),
expect: 10,
},
{
win: NewRollingWindow(size, duration, IgnoreCurrentBucket()),
expect: 4,
},
}
for _, test := range tests {
t.Run(stringx.Rand(), func(t *testing.T) {
r := test.win
for x := 0; x < size; x = x + 1 {
for i := 0; i <= x; i++ {
r.Add(float64(i))
}
if x < size-1 {
elapse()
}
}
var result float64
r.Reduce(func(b *Bucket) {
result += b.Sum
})
assert.Equal(t, test.expect, result)
})
}
}
func TestRollingWindowDataRace(t *testing.T) {
const size = 3
r := NewRollingWindow(size, duration)
var stop = make(chan bool)
go func() {
for {
select {
case <-stop:
return
default:
r.Add(float64(rand.Int63()))
time.Sleep(duration / 2)
}
}
}()
go func() {
for {
select {
case <-stop:
return
default:
r.Reduce(func(b *Bucket) {})
}
}
}()
time.Sleep(duration * 5)
close(stop)
}
func elapse() {
time.Sleep(duration)
}

View File

@ -0,0 +1,91 @@
package collection
import "sync"
const (
copyThreshold = 1000
maxDeletion = 10000
)
// SafeMap provides a map alternative to avoid memory leak.
// This implementation is not needed until issue below fixed.
// https://github.com/golang/go/issues/20135
type SafeMap struct {
lock sync.RWMutex
deletionOld int
deletionNew int
dirtyOld map[interface{}]interface{}
dirtyNew map[interface{}]interface{}
}
func NewSafeMap() *SafeMap {
return &SafeMap{
dirtyOld: make(map[interface{}]interface{}),
dirtyNew: make(map[interface{}]interface{}),
}
}
func (m *SafeMap) Del(key interface{}) {
m.lock.Lock()
if _, ok := m.dirtyOld[key]; ok {
delete(m.dirtyOld, key)
m.deletionOld++
} else if _, ok := m.dirtyNew[key]; ok {
delete(m.dirtyNew, key)
m.deletionNew++
}
if m.deletionOld >= maxDeletion && len(m.dirtyOld) < copyThreshold {
for k, v := range m.dirtyOld {
m.dirtyNew[k] = v
}
m.dirtyOld = m.dirtyNew
m.deletionOld = m.deletionNew
m.dirtyNew = make(map[interface{}]interface{})
m.deletionNew = 0
}
if m.deletionNew >= maxDeletion && len(m.dirtyNew) < copyThreshold {
for k, v := range m.dirtyNew {
m.dirtyOld[k] = v
}
m.dirtyNew = make(map[interface{}]interface{})
m.deletionNew = 0
}
m.lock.Unlock()
}
func (m *SafeMap) Get(key interface{}) (interface{}, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
if val, ok := m.dirtyOld[key]; ok {
return val, true
} else {
val, ok := m.dirtyNew[key]
return val, ok
}
}
func (m *SafeMap) Set(key, value interface{}) {
m.lock.Lock()
if m.deletionOld <= maxDeletion {
if _, ok := m.dirtyNew[key]; ok {
delete(m.dirtyNew, key)
m.deletionNew++
}
m.dirtyOld[key] = value
} else {
if _, ok := m.dirtyOld[key]; ok {
delete(m.dirtyOld, key)
m.deletionOld++
}
m.dirtyNew[key] = value
}
m.lock.Unlock()
}
func (m *SafeMap) Size() int {
m.lock.RLock()
size := len(m.dirtyOld) + len(m.dirtyNew)
m.lock.RUnlock()
return size
}

View File

@ -0,0 +1,110 @@
package collection
import (
"testing"
"zero/core/stringx"
"github.com/stretchr/testify/assert"
)
func TestSafeMap(t *testing.T) {
tests := []struct {
size int
exception int
}{
{
100000,
2000,
},
{
100000,
50,
},
}
for _, test := range tests {
t.Run(stringx.Rand(), func(t *testing.T) {
testSafeMapWithParameters(t, test.size, test.exception)
})
}
}
func TestSafeMap_CopyNew(t *testing.T) {
const (
size = 100000
exception1 = 5
exception2 = 500
)
m := NewSafeMap()
for i := 0; i < size; i++ {
m.Set(i, i)
}
for i := 0; i < size; i++ {
if i%exception1 == 0 {
m.Del(i)
}
}
for i := size; i < size<<1; i++ {
m.Set(i, i)
}
for i := size; i < size<<1; i++ {
if i%exception2 != 0 {
m.Del(i)
}
}
for i := 0; i < size; i++ {
val, ok := m.Get(i)
if i%exception1 != 0 {
assert.True(t, ok)
assert.Equal(t, i, val.(int))
} else {
assert.False(t, ok)
}
}
for i := size; i < size<<1; i++ {
val, ok := m.Get(i)
if i%exception2 == 0 {
assert.True(t, ok)
assert.Equal(t, i, val.(int))
} else {
assert.False(t, ok)
}
}
}
func testSafeMapWithParameters(t *testing.T, size, exception int) {
m := NewSafeMap()
for i := 0; i < size; i++ {
m.Set(i, i)
}
for i := 0; i < size; i++ {
if i%exception != 0 {
m.Del(i)
}
}
assert.Equal(t, size/exception, m.Size())
for i := size; i < size<<1; i++ {
m.Set(i, i)
}
for i := size; i < size<<1; i++ {
if i%exception != 0 {
m.Del(i)
}
}
for i := 0; i < size<<1; i++ {
val, ok := m.Get(i)
if i%exception == 0 {
assert.True(t, ok)
assert.Equal(t, i, val.(int))
} else {
assert.False(t, ok)
}
}
}

230
core/collection/set.go Normal file
View File

@ -0,0 +1,230 @@
package collection
import (
"zero/core/lang"
"zero/core/logx"
)
const (
unmanaged = iota
untyped
intType
int64Type
uintType
uint64Type
stringType
)
type Set struct {
data map[interface{}]lang.PlaceholderType
tp int
}
func NewSet() *Set {
return &Set{
data: make(map[interface{}]lang.PlaceholderType),
tp: untyped,
}
}
func NewUnmanagedSet() *Set {
return &Set{
data: make(map[interface{}]lang.PlaceholderType),
tp: unmanaged,
}
}
func (s *Set) Add(i ...interface{}) {
for _, each := range i {
s.add(each)
}
}
func (s *Set) AddInt(ii ...int) {
for _, each := range ii {
s.add(each)
}
}
func (s *Set) AddInt64(ii ...int64) {
for _, each := range ii {
s.add(each)
}
}
func (s *Set) AddUint(ii ...uint) {
for _, each := range ii {
s.add(each)
}
}
func (s *Set) AddUint64(ii ...uint64) {
for _, each := range ii {
s.add(each)
}
}
func (s *Set) AddStr(ss ...string) {
for _, each := range ss {
s.add(each)
}
}
func (s *Set) Contains(i interface{}) bool {
if len(s.data) == 0 {
return false
}
s.validate(i)
_, ok := s.data[i]
return ok
}
func (s *Set) Keys() []interface{} {
var keys []interface{}
for key := range s.data {
keys = append(keys, key)
}
return keys
}
func (s *Set) KeysInt() []int {
var keys []int
for key := range s.data {
if intKey, ok := key.(int); !ok {
continue
} else {
keys = append(keys, intKey)
}
}
return keys
}
func (s *Set) KeysInt64() []int64 {
var keys []int64
for key := range s.data {
if intKey, ok := key.(int64); !ok {
continue
} else {
keys = append(keys, intKey)
}
}
return keys
}
func (s *Set) KeysUint() []uint {
var keys []uint
for key := range s.data {
if intKey, ok := key.(uint); !ok {
continue
} else {
keys = append(keys, intKey)
}
}
return keys
}
func (s *Set) KeysUint64() []uint64 {
var keys []uint64
for key := range s.data {
if intKey, ok := key.(uint64); !ok {
continue
} else {
keys = append(keys, intKey)
}
}
return keys
}
func (s *Set) KeysStr() []string {
var keys []string
for key := range s.data {
if strKey, ok := key.(string); !ok {
continue
} else {
keys = append(keys, strKey)
}
}
return keys
}
func (s *Set) Remove(i interface{}) {
s.validate(i)
delete(s.data, i)
}
func (s *Set) Count() int {
return len(s.data)
}
func (s *Set) add(i interface{}) {
switch s.tp {
case unmanaged:
// do nothing
case untyped:
s.setType(i)
default:
s.validate(i)
}
s.data[i] = lang.Placeholder
}
func (s *Set) setType(i interface{}) {
if s.tp != untyped {
return
}
switch i.(type) {
case int:
s.tp = intType
case int64:
s.tp = int64Type
case uint:
s.tp = uintType
case uint64:
s.tp = uint64Type
case string:
s.tp = stringType
}
}
func (s *Set) validate(i interface{}) {
if s.tp == unmanaged {
return
}
switch i.(type) {
case int:
if s.tp != intType {
logx.Errorf("Error: element is int, but set contains elements with type %d", s.tp)
}
case int64:
if s.tp != int64Type {
logx.Errorf("Error: element is int64, but set contains elements with type %d", s.tp)
}
case uint:
if s.tp != uintType {
logx.Errorf("Error: element is uint, but set contains elements with type %d", s.tp)
}
case uint64:
if s.tp != uint64Type {
logx.Errorf("Error: element is uint64, but set contains elements with type %d", s.tp)
}
case string:
if s.tp != stringType {
logx.Errorf("Error: element is string, but set contains elements with type %d", s.tp)
}
}
}

149
core/collection/set_test.go Normal file
View File

@ -0,0 +1,149 @@
package collection
import (
"sort"
"testing"
"github.com/stretchr/testify/assert"
)
func BenchmarkRawSet(b *testing.B) {
m := make(map[interface{}]struct{})
for i := 0; i < b.N; i++ {
m[i] = struct{}{}
_ = m[i]
}
}
func BenchmarkUnmanagedSet(b *testing.B) {
s := NewUnmanagedSet()
for i := 0; i < b.N; i++ {
s.Add(i)
_ = s.Contains(i)
}
}
func BenchmarkSet(b *testing.B) {
s := NewSet()
for i := 0; i < b.N; i++ {
s.AddInt(i)
_ = s.Contains(i)
}
}
func TestAdd(t *testing.T) {
// given
set := NewUnmanagedSet()
values := []interface{}{1, 2, 3}
// when
set.Add(values...)
// then
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
assert.Equal(t, len(values), len(set.Keys()))
}
func TestAddInt(t *testing.T) {
// given
set := NewSet()
values := []int{1, 2, 3}
// when
set.AddInt(values...)
// then
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
keys := set.KeysInt()
sort.Ints(keys)
assert.EqualValues(t, values, keys)
}
func TestAddInt64(t *testing.T) {
// given
set := NewSet()
values := []int64{1, 2, 3}
// when
set.AddInt64(values...)
// then
assert.True(t, set.Contains(int64(1)) && set.Contains(int64(2)) && set.Contains(int64(3)))
assert.Equal(t, len(values), len(set.KeysInt64()))
}
func TestAddUint(t *testing.T) {
// given
set := NewSet()
values := []uint{1, 2, 3}
// when
set.AddUint(values...)
// then
assert.True(t, set.Contains(uint(1)) && set.Contains(uint(2)) && set.Contains(uint(3)))
assert.Equal(t, len(values), len(set.KeysUint()))
}
func TestAddUint64(t *testing.T) {
// given
set := NewSet()
values := []uint64{1, 2, 3}
// when
set.AddUint64(values...)
// then
assert.True(t, set.Contains(uint64(1)) && set.Contains(uint64(2)) && set.Contains(uint64(3)))
assert.Equal(t, len(values), len(set.KeysUint64()))
}
func TestAddStr(t *testing.T) {
// given
set := NewSet()
values := []string{"1", "2", "3"}
// when
set.AddStr(values...)
// then
assert.True(t, set.Contains("1") && set.Contains("2") && set.Contains("3"))
assert.Equal(t, len(values), len(set.KeysStr()))
}
func TestContainsWithoutElements(t *testing.T) {
// given
set := NewSet()
// then
assert.False(t, set.Contains(1))
}
func TestContainsUnmanagedWithoutElements(t *testing.T) {
// given
set := NewUnmanagedSet()
// then
assert.False(t, set.Contains(1))
}
func TestRemove(t *testing.T) {
// given
set := NewSet()
set.Add([]interface{}{1, 2, 3}...)
// when
set.Remove(2)
// then
assert.True(t, set.Contains(1) && !set.Contains(2) && set.Contains(3))
}
func TestCount(t *testing.T) {
// given
set := NewSet()
set.Add([]interface{}{1, 2, 3}...)
// then
assert.Equal(t, set.Count(), 3)
}

View File

@ -0,0 +1,311 @@
package collection
import (
"container/list"
"fmt"
"time"
"zero/core/lang"
"zero/core/threading"
"zero/core/timex"
)
const drainWorkers = 8
type (
Execute func(key, value interface{})
TimingWheel struct {
interval time.Duration
ticker timex.Ticker
slots []*list.List
timers *SafeMap
tickedPos int
numSlots int
execute Execute
setChannel chan timingEntry
moveChannel chan baseEntry
removeChannel chan interface{}
drainChannel chan func(key, value interface{})
stopChannel chan lang.PlaceholderType
}
timingEntry struct {
baseEntry
value interface{}
circle int
diff int
removed bool
}
baseEntry struct {
delay time.Duration
key interface{}
}
positionEntry struct {
pos int
item *timingEntry
}
timingTask struct {
key interface{}
value interface{}
}
)
func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) {
if interval <= 0 || numSlots <= 0 || execute == nil {
return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", interval, numSlots, execute)
}
return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval))
}
func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute, ticker timex.Ticker) (
*TimingWheel, error) {
tw := &TimingWheel{
interval: interval,
ticker: ticker,
slots: make([]*list.List, numSlots),
timers: NewSafeMap(),
tickedPos: numSlots - 1, // at previous virtual circle
execute: execute,
numSlots: numSlots,
setChannel: make(chan timingEntry),
moveChannel: make(chan baseEntry),
removeChannel: make(chan interface{}),
drainChannel: make(chan func(key, value interface{})),
stopChannel: make(chan lang.PlaceholderType),
}
tw.initSlots()
go tw.run()
return tw, nil
}
func (tw *TimingWheel) Drain(fn func(key, value interface{})) {
tw.drainChannel <- fn
}
func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) {
if delay <= 0 || key == nil {
return
}
tw.moveChannel <- baseEntry{
delay: delay,
key: key,
}
}
func (tw *TimingWheel) RemoveTimer(key interface{}) {
if key == nil {
return
}
tw.removeChannel <- key
}
func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) {
if delay <= 0 || key == nil {
return
}
tw.setChannel <- timingEntry{
baseEntry: baseEntry{
delay: delay,
key: key,
},
value: value,
}
}
func (tw *TimingWheel) Stop() {
close(tw.stopChannel)
}
func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
runner := threading.NewTaskRunner(drainWorkers)
for _, slot := range tw.slots {
for e := slot.Front(); e != nil; {
task := e.Value.(*timingEntry)
next := e.Next()
slot.Remove(e)
e = next
if !task.removed {
runner.Schedule(func() {
fn(task.key, task.value)
})
}
}
}
}
func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos int, circle int) {
steps := int(d / tw.interval)
pos = (tw.tickedPos + steps) % tw.numSlots
circle = (steps - 1) / tw.numSlots
return
}
func (tw *TimingWheel) initSlots() {
for i := 0; i < tw.numSlots; i++ {
tw.slots[i] = list.New()
}
}
func (tw *TimingWheel) moveTask(task baseEntry) {
val, ok := tw.timers.Get(task.key)
if !ok {
return
}
timer := val.(*positionEntry)
if task.delay < tw.interval {
threading.GoSafe(func() {
tw.execute(timer.item.key, timer.item.value)
})
return
}
pos, circle := tw.getPositionAndCircle(task.delay)
if pos >= timer.pos {
timer.item.circle = circle
timer.item.diff = pos - timer.pos
} else if circle > 0 {
circle--
timer.item.circle = circle
timer.item.diff = tw.numSlots + pos - timer.pos
} else {
timer.item.removed = true
newItem := &timingEntry{
baseEntry: task,
value: timer.item.value,
}
tw.slots[pos].PushBack(newItem)
tw.setTimerPosition(pos, newItem)
}
}
func (tw *TimingWheel) onTick() {
tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots
l := tw.slots[tw.tickedPos]
tw.scanAndRunTasks(l)
}
func (tw *TimingWheel) removeTask(key interface{}) {
val, ok := tw.timers.Get(key)
if !ok {
return
}
timer := val.(*positionEntry)
timer.item.removed = true
}
func (tw *TimingWheel) run() {
for {
select {
case <-tw.ticker.Chan():
tw.onTick()
case task := <-tw.setChannel:
tw.setTask(&task)
case key := <-tw.removeChannel:
tw.removeTask(key)
case task := <-tw.moveChannel:
tw.moveTask(task)
case fn := <-tw.drainChannel:
tw.drainAll(fn)
case <-tw.stopChannel:
tw.ticker.Stop()
return
}
}
}
func (tw *TimingWheel) runTasks(tasks []timingTask) {
if len(tasks) == 0 {
return
}
go func() {
for i := range tasks {
threading.RunSafe(func() {
tw.execute(tasks[i].key, tasks[i].value)
})
}
}()
}
func (tw *TimingWheel) scanAndRunTasks(l *list.List) {
var tasks []timingTask
for e := l.Front(); e != nil; {
task := e.Value.(*timingEntry)
if task.removed {
next := e.Next()
l.Remove(e)
tw.timers.Del(task.key)
e = next
continue
} else if task.circle > 0 {
task.circle--
e = e.Next()
continue
} else if task.diff > 0 {
next := e.Next()
l.Remove(e)
// (tw.tickedPos+task.diff)%tw.numSlots
// cannot be the same value of tw.tickedPos
pos := (tw.tickedPos + task.diff) % tw.numSlots
tw.slots[pos].PushBack(task)
tw.setTimerPosition(pos, task)
task.diff = 0
e = next
continue
}
tasks = append(tasks, timingTask{
key: task.key,
value: task.value,
})
next := e.Next()
l.Remove(e)
tw.timers.Del(task.key)
e = next
}
tw.runTasks(tasks)
}
func (tw *TimingWheel) setTask(task *timingEntry) {
if task.delay < tw.interval {
task.delay = tw.interval
}
if val, ok := tw.timers.Get(task.key); ok {
entry := val.(*positionEntry)
entry.item.value = task.value
tw.moveTask(task.baseEntry)
} else {
pos, circle := tw.getPositionAndCircle(task.delay)
task.circle = circle
tw.slots[pos].PushBack(task)
tw.setTimerPosition(pos, task)
}
}
func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) {
if val, ok := tw.timers.Get(task.key); ok {
timer := val.(*positionEntry)
timer.pos = pos
} else {
tw.timers.Set(task.key, &positionEntry{
pos: pos,
item: task,
})
}
}

View File

@ -0,0 +1,593 @@
package collection
import (
"sort"
"sync"
"sync/atomic"
"testing"
"time"
"zero/core/lang"
"zero/core/stringx"
"zero/core/syncx"
"zero/core/timex"
"github.com/stretchr/testify/assert"
)
const (
testStep = time.Minute
waitTime = time.Second
)
func TestNewTimingWheel(t *testing.T) {
_, err := NewTimingWheel(0, 10, func(key, value interface{}) {})
assert.NotNil(t, err)
}
func TestTimingWheel_Drain(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
}, ticker)
defer tw.Stop()
tw.SetTimer("first", 3, testStep*4)
tw.SetTimer("second", 5, testStep*7)
tw.SetTimer("third", 7, testStep*7)
var keys []string
var vals []int
var lock sync.Mutex
var wg sync.WaitGroup
wg.Add(3)
tw.Drain(func(key, value interface{}) {
lock.Lock()
defer lock.Unlock()
keys = append(keys, key.(string))
vals = append(vals, value.(int))
wg.Done()
})
wg.Wait()
sort.Strings(keys)
sort.Ints(vals)
assert.Equal(t, 3, len(keys))
assert.EqualValues(t, []string{"first", "second", "third"}, keys)
assert.EqualValues(t, []int{3, 5, 7}, vals)
var count int
tw.Drain(func(key, value interface{}) {
count++
})
time.Sleep(time.Millisecond * 100)
assert.Equal(t, 0, count)
}
func TestTimingWheel_SetTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
ticker.Done()
}, ticker)
defer tw.Stop()
tw.SetTimer("any", 3, testStep>>1)
ticker.Tick()
assert.Nil(t, ticker.Wait(waitTime))
assert.True(t, run.True())
}
func TestTimingWheel_SetTimerTwice(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 5, v.(int))
ticker.Done()
}, ticker)
defer tw.Stop()
tw.SetTimer("any", 3, testStep*4)
tw.SetTimer("any", 5, testStep*7)
for i := 0; i < 8; i++ {
ticker.Tick()
}
assert.Nil(t, ticker.Wait(waitTime))
assert.True(t, run.True())
}
func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
defer tw.Stop()
assert.NotPanics(t, func() {
tw.SetTimer("any", 3, -testStep)
})
}
func TestTimingWheel_MoveTimer(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
ticker.Done()
}, ticker)
defer tw.Stop()
tw.SetTimer("any", 3, testStep*4)
tw.MoveTimer("any", testStep*7)
tw.MoveTimer("any", -testStep)
tw.MoveTimer("none", testStep)
for i := 0; i < 5; i++ {
ticker.Tick()
}
assert.False(t, run.True())
for i := 0; i < 3; i++ {
ticker.Tick()
}
assert.Nil(t, ticker.Wait(waitTime))
assert.True(t, run.True())
}
func TestTimingWheel_MoveTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
ticker.Done()
}, ticker)
defer tw.Stop()
tw.SetTimer("any", 3, testStep*4)
tw.MoveTimer("any", testStep>>1)
assert.Nil(t, ticker.Wait(waitTime))
assert.True(t, run.True())
}
func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
ticker.Done()
}, ticker)
defer tw.Stop()
tw.SetTimer("any", 3, testStep*4)
tw.MoveTimer("any", testStep*2)
for i := 0; i < 3; i++ {
ticker.Tick()
}
assert.Nil(t, ticker.Wait(waitTime))
assert.True(t, run.True())
}
func TestTimingWheel_RemoveTimer(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
tw.SetTimer("any", 3, testStep)
assert.NotPanics(t, func() {
tw.RemoveTimer("any")
tw.RemoveTimer("none")
tw.RemoveTimer(nil)
})
for i := 0; i < 5; i++ {
ticker.Tick()
}
tw.Stop()
}
func TestTimingWheel_SetTimer(t *testing.T) {
tests := []struct {
slots int
setAt time.Duration
}{
{
slots: 5,
setAt: 5,
},
{
slots: 5,
setAt: 7,
},
{
slots: 5,
setAt: 10,
},
{
slots: 5,
setAt: 12,
},
{
slots: 5,
setAt: 7,
},
{
slots: 5,
setAt: 10,
},
{
slots: 5,
setAt: 12,
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
var count int32
ticker := timex.NewFakeTicker()
tick := func() {
atomic.AddInt32(&count, 1)
ticker.Tick()
time.Sleep(time.Millisecond)
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
assert.Equal(t, 1, key.(int))
assert.Equal(t, 2, value.(int))
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
assert.Nil(t, err)
defer tw.Stop()
tw.SetTimer(1, 2, testStep*test.setAt)
for {
select {
case <-done:
assert.Equal(t, int32(test.setAt), actual)
return
default:
tick()
}
}
})
}
}
func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
tests := []struct {
slots int
setAt time.Duration
moveAt time.Duration
}{
{
slots: 5,
setAt: 3,
moveAt: 5,
},
{
slots: 5,
setAt: 3,
moveAt: 7,
},
{
slots: 5,
setAt: 3,
moveAt: 10,
},
{
slots: 5,
setAt: 3,
moveAt: 12,
},
{
slots: 5,
setAt: 5,
moveAt: 7,
},
{
slots: 5,
setAt: 5,
moveAt: 10,
},
{
slots: 5,
setAt: 5,
moveAt: 12,
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
var count int32
ticker := timex.NewFakeTicker()
tick := func() {
atomic.AddInt32(&count, 1)
ticker.Tick()
time.Sleep(time.Millisecond * 10)
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
assert.Nil(t, err)
defer tw.Stop()
tw.SetTimer(1, 2, testStep*test.setAt)
tw.MoveTimer(1, testStep*test.moveAt)
for {
select {
case <-done:
assert.Equal(t, int32(test.moveAt), actual)
return
default:
tick()
}
}
})
}
}
func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
tests := []struct {
slots int
setAt time.Duration
moveAt time.Duration
moveAgainAt time.Duration
}{
{
slots: 5,
setAt: 3,
moveAt: 5,
moveAgainAt: 10,
},
{
slots: 5,
setAt: 3,
moveAt: 7,
moveAgainAt: 12,
},
{
slots: 5,
setAt: 3,
moveAt: 10,
moveAgainAt: 15,
},
{
slots: 5,
setAt: 3,
moveAt: 12,
moveAgainAt: 17,
},
{
slots: 5,
setAt: 5,
moveAt: 7,
moveAgainAt: 12,
},
{
slots: 5,
setAt: 5,
moveAt: 10,
moveAgainAt: 17,
},
{
slots: 5,
setAt: 5,
moveAt: 12,
moveAgainAt: 17,
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
var count int32
ticker := timex.NewFakeTicker()
tick := func() {
atomic.AddInt32(&count, 1)
ticker.Tick()
time.Sleep(time.Millisecond * 10)
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
assert.Nil(t, err)
defer tw.Stop()
tw.SetTimer(1, 2, testStep*test.setAt)
tw.MoveTimer(1, testStep*test.moveAt)
tw.MoveTimer(1, testStep*test.moveAgainAt)
for {
select {
case <-done:
assert.Equal(t, int32(test.moveAgainAt), actual)
return
default:
tick()
}
}
})
}
}
func TestTimingWheel_ElapsedAndSet(t *testing.T) {
tests := []struct {
slots int
elapsed time.Duration
setAt time.Duration
}{
{
slots: 5,
elapsed: 3,
setAt: 5,
},
{
slots: 5,
elapsed: 3,
setAt: 7,
},
{
slots: 5,
elapsed: 3,
setAt: 10,
},
{
slots: 5,
elapsed: 3,
setAt: 12,
},
{
slots: 5,
elapsed: 5,
setAt: 7,
},
{
slots: 5,
elapsed: 5,
setAt: 10,
},
{
slots: 5,
elapsed: 5,
setAt: 12,
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
var count int32
ticker := timex.NewFakeTicker()
tick := func() {
atomic.AddInt32(&count, 1)
ticker.Tick()
time.Sleep(time.Millisecond * 10)
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
assert.Nil(t, err)
defer tw.Stop()
for i := 0; i < int(test.elapsed); i++ {
tick()
}
tw.SetTimer(1, 2, testStep*test.setAt)
for {
select {
case <-done:
assert.Equal(t, int32(test.elapsed+test.setAt), actual)
return
default:
tick()
}
}
})
}
}
func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
tests := []struct {
slots int
elapsed time.Duration
setAt time.Duration
moveAt time.Duration
}{
{
slots: 5,
elapsed: 3,
setAt: 5,
moveAt: 10,
},
{
slots: 5,
elapsed: 3,
setAt: 7,
moveAt: 12,
},
{
slots: 5,
elapsed: 3,
setAt: 10,
moveAt: 15,
},
{
slots: 5,
elapsed: 3,
setAt: 12,
moveAt: 16,
},
{
slots: 5,
elapsed: 5,
setAt: 7,
moveAt: 12,
},
{
slots: 5,
elapsed: 5,
setAt: 10,
moveAt: 15,
},
{
slots: 5,
elapsed: 5,
setAt: 12,
moveAt: 17,
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
var count int32
ticker := timex.NewFakeTicker()
tick := func() {
atomic.AddInt32(&count, 1)
ticker.Tick()
time.Sleep(time.Millisecond * 10)
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
assert.Nil(t, err)
defer tw.Stop()
for i := 0; i < int(test.elapsed); i++ {
tick()
}
tw.SetTimer(1, 2, testStep*test.setAt)
tw.MoveTimer(1, testStep*test.moveAt)
for {
select {
case <-done:
assert.Equal(t, int32(test.elapsed+test.moveAt), actual)
return
default:
tick()
}
}
})
}
}
func BenchmarkTimingWheel(b *testing.B) {
b.ReportAllocs()
tw, _ := NewTimingWheel(time.Second, 100, func(k, v interface{}) {})
for i := 0; i < b.N; i++ {
tw.SetTimer(i, i, time.Second)
tw.SetTimer(b.N+i, b.N+i, time.Second)
tw.MoveTimer(i, time.Second*time.Duration(i))
tw.RemoveTimer(i)
}
}

40
core/conf/config.go Normal file
View File

@ -0,0 +1,40 @@
package conf
import (
"fmt"
"io/ioutil"
"log"
"path"
"zero/core/mapping"
)
var loaders = map[string]func([]byte, interface{}) error{
".json": LoadConfigFromJsonBytes,
".yaml": LoadConfigFromYamlBytes,
".yml": LoadConfigFromYamlBytes,
}
func LoadConfig(file string, v interface{}) error {
if content, err := ioutil.ReadFile(file); err != nil {
return err
} else if loader, ok := loaders[path.Ext(file)]; ok {
return loader(content, v)
} else {
return fmt.Errorf("unrecoginized file type: %s", file)
}
}
func LoadConfigFromJsonBytes(content []byte, v interface{}) error {
return mapping.UnmarshalJsonBytes(content, v)
}
func LoadConfigFromYamlBytes(content []byte, v interface{}) error {
return mapping.UnmarshalYamlBytes(content, v)
}
func MustLoad(path string, v interface{}) {
if err := LoadConfig(path, v); err != nil {
log.Fatalf("error: config file %s, %s", path, err.Error())
}
}

109
core/conf/properties.go Normal file
View File

@ -0,0 +1,109 @@
package conf
import (
"fmt"
"strconv"
"strings"
"sync"
"zero/core/iox"
)
// PropertyError represents a configuration error message.
type PropertyError struct {
error
message string
}
// Properties interface provides the means to access configuration.
type Properties interface {
GetString(key string) string
SetString(key, value string)
GetInt(key string) int
SetInt(key string, value int)
ToString() string
}
// Properties config is a key/value pair based configuration structure.
type mapBasedProperties struct {
properties map[string]string
lock sync.RWMutex
}
// Loads the properties into a properties configuration instance. May return the
// configuration itself along with an error that indicates if there was a problem loading the configuration.
func LoadProperties(filename string) (Properties, error) {
lines, err := iox.ReadTextLines(filename, iox.WithoutBlank(), iox.OmitWithPrefix("#"))
if err != nil {
return nil, nil
}
raw := make(map[string]string)
for i := range lines {
pair := strings.Split(lines[i], "=")
if len(pair) != 2 {
// invalid property format
return nil, &PropertyError{
message: fmt.Sprintf("invalid property format: %s", pair),
}
}
key := strings.TrimSpace(pair[0])
value := strings.TrimSpace(pair[1])
raw[key] = value
}
return &mapBasedProperties{
properties: raw,
}, nil
}
func (config *mapBasedProperties) GetString(key string) string {
config.lock.RLock()
ret := config.properties[key]
config.lock.RUnlock()
return ret
}
func (config *mapBasedProperties) SetString(key, value string) {
config.lock.Lock()
config.properties[key] = value
config.lock.Unlock()
}
func (config *mapBasedProperties) GetInt(key string) int {
config.lock.RLock()
// default 0
value, _ := strconv.Atoi(config.properties[key])
config.lock.RUnlock()
return value
}
func (config *mapBasedProperties) SetInt(key string, value int) {
config.lock.Lock()
config.properties[key] = strconv.Itoa(value)
config.lock.Unlock()
}
// Dumps the configuration internal map into a string.
func (config *mapBasedProperties) ToString() string {
config.lock.RLock()
ret := fmt.Sprintf("%s", config.properties)
config.lock.RUnlock()
return ret
}
// Returns the error message.
func (configError *PropertyError) Error() string {
return configError.message
}
// Builds a new properties configuration structure
func NewProperties() Properties {
return &mapBasedProperties{
properties: make(map[string]string),
}
}

View File

@ -0,0 +1,44 @@
package conf
import (
"os"
"testing"
"zero/core/fs"
"github.com/stretchr/testify/assert"
)
func TestProperties(t *testing.T) {
text := `app.name = test
app.program=app
# this is comment
app.threads = 5`
tmpfile, err := fs.TempFilenameWithText(text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
props, err := LoadProperties(tmpfile)
assert.Nil(t, err)
assert.Equal(t, "test", props.GetString("app.name"))
assert.Equal(t, "app", props.GetString("app.program"))
assert.Equal(t, 5, props.GetInt("app.threads"))
}
func TestSetString(t *testing.T) {
key := "a"
value := "the value of a"
props := NewProperties()
props.SetString(key, value)
assert.Equal(t, value, props.GetString(key))
}
func TestSetInt(t *testing.T) {
key := "a"
value := 101
props := NewProperties()
props.SetInt(key, value)
assert.Equal(t, value, props.GetInt(key))
}

17
core/contextx/deadline.go Normal file
View File

@ -0,0 +1,17 @@
package contextx
import (
"context"
"time"
)
func ShrinkDeadline(ctx context.Context, timeout time.Duration) (context.Context, func()) {
if deadline, ok := ctx.Deadline(); ok {
leftTime := time.Until(deadline)
if leftTime < timeout {
timeout = leftTime
}
}
return context.WithDeadline(ctx, time.Now().Add(timeout))
}

View File

@ -0,0 +1,27 @@
package contextx
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestShrinkDeadlineLess(t *testing.T) {
deadline := time.Now().Add(time.Second)
ctx, _ := context.WithDeadline(context.Background(), deadline)
ctx, _ = ShrinkDeadline(ctx, time.Minute)
dl, ok := ctx.Deadline()
assert.True(t, ok)
assert.Equal(t, deadline, dl)
}
func TestShrinkDeadlineMore(t *testing.T) {
deadline := time.Now().Add(time.Minute)
ctx, _ := context.WithDeadline(context.Background(), deadline)
ctx, _ = ShrinkDeadline(ctx, time.Second)
dl, ok := ctx.Deadline()
assert.True(t, ok)
assert.True(t, dl.Before(deadline))
}

View File

@ -0,0 +1,26 @@
package contextx
import (
"context"
"zero/core/mapping"
)
const contextTagKey = "ctx"
var unmarshaler = mapping.NewUnmarshaler(contextTagKey)
type contextValuer struct {
context.Context
}
func (cv contextValuer) Value(key string) (interface{}, bool) {
v := cv.Context.Value(key)
return v, v != nil
}
func For(ctx context.Context, v interface{}) error {
return unmarshaler.UnmarshalValuer(contextValuer{
Context: ctx,
}, v)
}

View File

@ -0,0 +1,58 @@
package contextx
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestUnmarshalContext(t *testing.T) {
type Person struct {
Name string `ctx:"name"`
Age int `ctx:"age"`
}
ctx := context.Background()
ctx = context.WithValue(ctx, "name", "kevin")
ctx = context.WithValue(ctx, "age", 20)
var person Person
err := For(ctx, &person)
assert.Nil(t, err)
assert.Equal(t, "kevin", person.Name)
assert.Equal(t, 20, person.Age)
}
func TestUnmarshalContextWithOptional(t *testing.T) {
type Person struct {
Name string `ctx:"name"`
Age int `ctx:"age,optional"`
}
ctx := context.Background()
ctx = context.WithValue(ctx, "name", "kevin")
var person Person
err := For(ctx, &person)
assert.Nil(t, err)
assert.Equal(t, "kevin", person.Name)
assert.Equal(t, 0, person.Age)
}
func TestUnmarshalContextWithMissing(t *testing.T) {
type Person struct {
Name string `ctx:"name"`
Age int `ctx:"age"`
}
ctx := context.Background()
ctx = context.WithValue(ctx, "name", "kevin")
var person Person
err := For(ctx, &person)
assert.NotNil(t, err)
}

View File

@ -0,0 +1,28 @@
package contextx
import (
"context"
"time"
)
type valueOnlyContext struct {
context.Context
}
func (valueOnlyContext) Deadline() (deadline time.Time, ok bool) {
return
}
func (valueOnlyContext) Done() <-chan struct{} {
return nil
}
func (valueOnlyContext) Err() error {
return nil
}
func ValueOnlyFrom(ctx context.Context) context.Context {
return valueOnlyContext{
Context: ctx,
}
}

View File

@ -0,0 +1,54 @@
package contextx
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestContextCancel(t *testing.T) {
c := context.WithValue(context.Background(), "key", "value")
c1, cancel := context.WithCancel(c)
o := ValueOnlyFrom(c1)
c2, _ := context.WithCancel(o)
contexts := []context.Context{c1, c2}
for _, c := range contexts {
assert.NotNil(t, c.Done())
assert.Nil(t, c.Err())
select {
case x := <-c.Done():
t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
default:
}
}
cancel()
<-c1.Done()
assert.Nil(t, o.Err())
assert.Equal(t, context.Canceled, c1.Err())
assert.NotEqual(t, context.Canceled, c2.Err())
}
func TestConextDeadline(t *testing.T) {
c, _ := context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond))
o := ValueOnlyFrom(c)
select {
case <-time.After(100 * time.Millisecond):
case <-o.Done():
t.Fatal("ValueOnlyContext: context should not have timed out")
}
c, _ = context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond))
o = ValueOnlyFrom(c)
c, _ = context.WithDeadline(o, time.Now().Add(20*time.Millisecond))
select {
case <-time.After(100 * time.Millisecond):
t.Fatal("ValueOnlyContext+Deadline: context should have timed out")
case <-c.Done():
}
}

40
core/discov/clients.go Normal file
View File

@ -0,0 +1,40 @@
package discov
import (
"fmt"
"strings"
"zero/core/discov/internal"
)
const (
indexOfKey = iota
indexOfId
)
const timeToLive int64 = 10
var TimeToLive = timeToLive
func extract(etcdKey string, index int) (string, bool) {
if index < 0 {
return "", false
}
fields := strings.FieldsFunc(etcdKey, func(ch rune) bool {
return ch == internal.Delimiter
})
if index >= len(fields) {
return "", false
}
return fields[index], true
}
func extractId(etcdKey string) (string, bool) {
return extract(etcdKey, indexOfId)
}
func makeEtcdKey(key string, id int64) string {
return fmt.Sprintf("%s%c%d", key, internal.Delimiter, id)
}

View File

@ -0,0 +1,36 @@
package discov
import (
"sync"
"testing"
"zero/core/discov/internal"
"github.com/stretchr/testify/assert"
)
var mockLock sync.Mutex
func setMockClient(cli internal.EtcdClient) func() {
mockLock.Lock()
internal.NewClient = func([]string) (internal.EtcdClient, error) {
return cli, nil
}
return func() {
internal.NewClient = internal.DialClient
mockLock.Unlock()
}
}
func TestExtract(t *testing.T) {
id, ok := extractId("key/123/val")
assert.True(t, ok)
assert.Equal(t, "123", id)
_, ok = extract("any", -1)
assert.False(t, ok)
}
func TestMakeKey(t *testing.T) {
assert.Equal(t, "key/123", makeEtcdKey("key", 123))
}

18
core/discov/config.go Normal file
View File

@ -0,0 +1,18 @@
package discov
import "errors"
type EtcdConf struct {
Hosts []string
Key string
}
func (c EtcdConf) Validate() error {
if len(c.Hosts) == 0 {
return errors.New("empty etcd hosts")
} else if len(c.Key) == 0 {
return errors.New("empty etcd key")
} else {
return nil
}
}

View File

@ -0,0 +1,46 @@
package discov
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig(t *testing.T) {
tests := []struct {
EtcdConf
pass bool
}{
{
EtcdConf: EtcdConf{},
pass: false,
},
{
EtcdConf: EtcdConf{
Key: "any",
},
pass: false,
},
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
},
pass: false,
},
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
Key: "key",
},
pass: true,
},
}
for _, test := range tests {
if test.pass {
assert.Nil(t, test.EtcdConf.Validate())
} else {
assert.NotNil(t, test.EtcdConf.Validate())
}
}
}

47
core/discov/facade.go Normal file
View File

@ -0,0 +1,47 @@
package discov
import (
"zero/core/discov/internal"
"zero/core/lang"
)
type (
Facade struct {
endpoints []string
registry *internal.Registry
}
FacadeListener interface {
OnAdd(key, val string)
OnDelete(key string)
}
)
func NewFacade(endpoints []string) Facade {
return Facade{
endpoints: endpoints,
registry: internal.GetRegistry(),
}
}
func (f Facade) Client() internal.EtcdClient {
conn, err := f.registry.GetConn(f.endpoints)
lang.Must(err)
return conn
}
func (f Facade) Monitor(key string, l FacadeListener) {
f.registry.Monitor(f.endpoints, key, listenerAdapter{l})
}
type listenerAdapter struct {
l FacadeListener
}
func (la listenerAdapter) OnAdd(kv internal.KV) {
la.l.OnAdd(kv.Key, kv.Val)
}
func (la listenerAdapter) OnDelete(kv internal.KV) {
la.l.OnDelete(kv.Key)
}

View File

@ -0,0 +1,103 @@
package internal
import "sync"
type (
DialFn func(server string) (interface{}, error)
CloseFn func(server string, conn interface{}) error
Balancer interface {
AddConn(kv KV) error
IsEmpty() bool
Next(key ...string) (interface{}, bool)
RemoveKey(key string)
initialize()
setListener(listener Listener)
}
serverConn struct {
key string
conn interface{}
}
baseBalancer struct {
exclusive bool
servers map[string][]string
mapping map[string]string
lock sync.Mutex
dialFn DialFn
closeFn CloseFn
listener Listener
}
)
func newBaseBalancer(dialFn DialFn, closeFn CloseFn, exclusive bool) *baseBalancer {
return &baseBalancer{
exclusive: exclusive,
servers: make(map[string][]string),
mapping: make(map[string]string),
dialFn: dialFn,
closeFn: closeFn,
}
}
// addKv adds the kv, returns if there are already other keys associate with the server
func (b *baseBalancer) addKv(key, value string) ([]string, bool) {
b.lock.Lock()
defer b.lock.Unlock()
keys := b.servers[value]
previous := append([]string(nil), keys...)
early := len(keys) > 0
if b.exclusive && early {
for _, each := range keys {
b.doRemoveKv(each)
}
}
b.servers[value] = append(b.servers[value], key)
b.mapping[key] = value
if early {
return previous, true
} else {
return nil, false
}
}
func (b *baseBalancer) doRemoveKv(key string) (server string, keepConn bool) {
server, ok := b.mapping[key]
if !ok {
return "", true
}
delete(b.mapping, key)
keys := b.servers[server]
remain := keys[:0]
for _, k := range keys {
if k != key {
remain = append(remain, k)
}
}
if len(remain) > 0 {
b.servers[server] = remain
return server, true
} else {
delete(b.servers, server)
return server, false
}
}
func (b *baseBalancer) removeKv(key string) (server string, keepConn bool) {
b.lock.Lock()
defer b.lock.Unlock()
return b.doRemoveKv(key)
}
func (b *baseBalancer) setListener(listener Listener) {
b.lock.Lock()
b.listener = listener
b.lock.Unlock()
}

View File

@ -0,0 +1,5 @@
package internal
type mockConn struct {
server string
}

View File

@ -0,0 +1,152 @@
package internal
import (
"zero/core/hash"
"zero/core/logx"
)
type consistentBalancer struct {
*baseBalancer
conns map[string]interface{}
buckets *hash.ConsistentHash
bucketKey func(KV) string
}
func NewConsistentBalancer(dialFn DialFn, closeFn CloseFn, keyer func(kv KV) string) *consistentBalancer {
// we don't support exclusive mode for consistent Balancer, to avoid complexity,
// because there are few scenarios, use it on your own risks.
balancer := &consistentBalancer{
conns: make(map[string]interface{}),
buckets: hash.NewConsistentHash(),
bucketKey: keyer,
}
balancer.baseBalancer = newBaseBalancer(dialFn, closeFn, false)
return balancer
}
func (b *consistentBalancer) AddConn(kv KV) error {
// not adding kv and conn within a transaction, but it doesn't matter
// we just rollback the kv addition if dial failed
var conn interface{}
prev, found := b.addKv(kv.Key, kv.Val)
if found {
conn = b.handlePrevious(prev)
}
if conn == nil {
var err error
conn, err = b.dialFn(kv.Val)
if err != nil {
b.removeKv(kv.Key)
return err
}
}
bucketKey := b.bucketKey(kv)
b.lock.Lock()
defer b.lock.Unlock()
b.conns[bucketKey] = conn
b.buckets.Add(bucketKey)
b.notify(bucketKey)
logx.Infof("added server, key: %s, server: %s", bucketKey, kv.Val)
return nil
}
func (b *consistentBalancer) getConn(key string) (interface{}, bool) {
b.lock.Lock()
conn, ok := b.conns[key]
b.lock.Unlock()
return conn, ok
}
func (b *consistentBalancer) handlePrevious(prev []string) interface{} {
if len(prev) == 0 {
return nil
}
b.lock.Lock()
defer b.lock.Unlock()
// if not exclusive, only need to randomly find one connection
for key, conn := range b.conns {
if key == prev[0] {
return conn
}
}
return nil
}
func (b *consistentBalancer) initialize() {
}
func (b *consistentBalancer) notify(key string) {
if b.listener == nil {
return
}
var keys []string
var values []string
for k := range b.conns {
keys = append(keys, k)
}
for _, v := range b.mapping {
values = append(values, v)
}
b.listener.OnUpdate(keys, values, key)
}
func (b *consistentBalancer) RemoveKey(key string) {
kv := KV{Key: key}
server, keep := b.removeKv(key)
kv.Val = server
bucketKey := b.bucketKey(kv)
b.buckets.Remove(b.bucketKey(kv))
// wrap the query & removal in a function to make sure the quick lock/unlock
conn, ok := func() (interface{}, bool) {
b.lock.Lock()
defer b.lock.Unlock()
conn, ok := b.conns[bucketKey]
if ok {
delete(b.conns, bucketKey)
}
return conn, ok
}()
if ok && !keep {
logx.Infof("removing server, key: %s", kv.Key)
if err := b.closeFn(server, conn); err != nil {
logx.Error(err)
}
}
// notify without new key
b.notify("")
}
func (b *consistentBalancer) IsEmpty() bool {
b.lock.Lock()
empty := len(b.conns) == 0
b.lock.Unlock()
return empty
}
func (b *consistentBalancer) Next(keys ...string) (interface{}, bool) {
if len(keys) != 1 {
return nil, false
}
key := keys[0]
if node, ok := b.buckets.Get(key); !ok {
return nil, false
} else {
return b.getConn(node.(string))
}
}

View File

@ -0,0 +1,178 @@
package internal
import (
"errors"
"sort"
"strconv"
"testing"
"zero/core/mathx"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func TestConsistent_addConn(t *testing.T) {
b := NewConsistentBalancer(func(server string) (interface{}, error) {
return mockConn{
server: server,
}, nil
}, func(server string, conn interface{}) error {
return errors.New("error")
}, func(kv KV) string {
return kv.Key
})
assert.Nil(t, b.AddConn(KV{
Key: "thekey1",
Val: "thevalue",
}))
assert.EqualValues(t, map[string]interface{}{
"thekey1": mockConn{server: "thevalue"},
}, b.conns)
assert.EqualValues(t, map[string][]string{
"thevalue": {"thekey1"},
}, b.servers)
assert.EqualValues(t, map[string]string{
"thekey1": "thevalue",
}, b.mapping)
assert.Nil(t, b.AddConn(KV{
Key: "thekey2",
Val: "thevalue",
}))
assert.EqualValues(t, map[string]interface{}{
"thekey1": mockConn{server: "thevalue"},
"thekey2": mockConn{server: "thevalue"},
}, b.conns)
assert.EqualValues(t, map[string][]string{
"thevalue": {"thekey1", "thekey2"},
}, b.servers)
assert.EqualValues(t, map[string]string{
"thekey1": "thevalue",
"thekey2": "thevalue",
}, b.mapping)
assert.False(t, b.IsEmpty())
b.RemoveKey("thekey1")
assert.EqualValues(t, map[string]interface{}{
"thekey2": mockConn{server: "thevalue"},
}, b.conns)
assert.EqualValues(t, map[string][]string{
"thevalue": {"thekey2"},
}, b.servers)
assert.EqualValues(t, map[string]string{
"thekey2": "thevalue",
}, b.mapping)
assert.False(t, b.IsEmpty())
b.RemoveKey("thekey2")
assert.Equal(t, 0, len(b.conns))
assert.EqualValues(t, map[string][]string{}, b.servers)
assert.EqualValues(t, map[string]string{}, b.mapping)
assert.True(t, b.IsEmpty())
}
func TestConsistent_addConnError(t *testing.T) {
b := NewConsistentBalancer(func(server string) (interface{}, error) {
return nil, errors.New("error")
}, func(server string, conn interface{}) error {
return nil
}, func(kv KV) string {
return kv.Key
})
assert.NotNil(t, b.AddConn(KV{
Key: "thekey1",
Val: "thevalue",
}))
assert.Equal(t, 0, len(b.conns))
assert.EqualValues(t, map[string][]string{}, b.servers)
assert.EqualValues(t, map[string]string{}, b.mapping)
}
func TestConsistent_next(t *testing.T) {
b := NewConsistentBalancer(func(server string) (interface{}, error) {
return mockConn{
server: server,
}, nil
}, func(server string, conn interface{}) error {
return errors.New("error")
}, func(kv KV) string {
return kv.Key
})
b.initialize()
_, ok := b.Next("any")
assert.False(t, ok)
const size = 100
for i := 0; i < size; i++ {
assert.Nil(t, b.AddConn(KV{
Key: "thekey/" + strconv.Itoa(i),
Val: "thevalue/" + strconv.Itoa(i),
}))
}
m := make(map[interface{}]int)
const total = 10000
for i := 0; i < total; i++ {
val, ok := b.Next(strconv.Itoa(i))
assert.True(t, ok)
m[val]++
}
entropy := mathx.CalcEntropy(m, total)
assert.Equal(t, size, len(m))
assert.True(t, entropy > .95)
for i := 0; i < size; i++ {
b.RemoveKey("thekey/" + strconv.Itoa(i))
}
_, ok = b.Next()
assert.False(t, ok)
}
func TestConsistentBalancer_Listener(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
b := NewConsistentBalancer(func(server string) (interface{}, error) {
return mockConn{
server: server,
}, nil
}, func(server string, conn interface{}) error {
return nil
}, func(kv KV) string {
return kv.Key
})
assert.Nil(t, b.AddConn(KV{
Key: "key1",
Val: "val1",
}))
assert.Nil(t, b.AddConn(KV{
Key: "key2",
Val: "val2",
}))
listener := NewMockListener(ctrl)
listener.EXPECT().OnUpdate(gomock.Any(), gomock.Any(), "key2").Do(func(keys, vals, _ interface{}) {
sort.Strings(keys.([]string))
sort.Strings(vals.([]string))
assert.EqualValues(t, []string{"key1", "key2"}, keys)
assert.EqualValues(t, []string{"val1", "val2"}, vals)
})
b.setListener(listener)
b.notify("key2")
}
func TestConsistentBalancer_remove(t *testing.T) {
b := NewConsistentBalancer(func(server string) (interface{}, error) {
return mockConn{
server: server,
}, nil
}, func(server string, conn interface{}) error {
return nil
}, func(kv KV) string {
return kv.Key
})
assert.Nil(t, b.handlePrevious(nil))
assert.Nil(t, b.handlePrevious([]string{"any"}))
}

View File

@ -0,0 +1,21 @@
//go:generate mockgen -package internal -destination etcdclient_mock.go -source etcdclient.go EtcdClient
package internal
import (
"context"
"go.etcd.io/etcd/clientv3"
"google.golang.org/grpc"
)
type EtcdClient interface {
ActiveConnection() *grpc.ClientConn
Close() error
Ctx() context.Context
Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error)
Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error)
KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error)
Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error)
Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error)
Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan
}

View File

@ -0,0 +1,182 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: etcdclient.go
// Package internal is a generated GoMock package.
package internal
import (
context "context"
gomock "github.com/golang/mock/gomock"
clientv3 "go.etcd.io/etcd/clientv3"
grpc "google.golang.org/grpc"
reflect "reflect"
)
// MockEtcdClient is a mock of EtcdClient interface
type MockEtcdClient struct {
ctrl *gomock.Controller
recorder *MockEtcdClientMockRecorder
}
// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient
type MockEtcdClientMockRecorder struct {
mock *MockEtcdClient
}
// NewMockEtcdClient creates a new mock instance
func NewMockEtcdClient(ctrl *gomock.Controller) *MockEtcdClient {
mock := &MockEtcdClient{ctrl: ctrl}
mock.recorder = &MockEtcdClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockEtcdClient) EXPECT() *MockEtcdClientMockRecorder {
return m.recorder
}
// ActiveConnection mocks base method
func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ActiveConnection")
ret0, _ := ret[0].(*grpc.ClientConn)
return ret0
}
// ActiveConnection indicates an expected call of ActiveConnection
func (mr *MockEtcdClientMockRecorder) ActiveConnection() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveConnection", reflect.TypeOf((*MockEtcdClient)(nil).ActiveConnection))
}
// Close mocks base method
func (m *MockEtcdClient) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockEtcdClientMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEtcdClient)(nil).Close))
}
// Ctx mocks base method
func (m *MockEtcdClient) Ctx() context.Context {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Ctx")
ret0, _ := ret[0].(context.Context)
return ret0
}
// Ctx indicates an expected call of Ctx
func (mr *MockEtcdClientMockRecorder) Ctx() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ctx", reflect.TypeOf((*MockEtcdClient)(nil).Ctx))
}
// Get mocks base method
func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
m.ctrl.T.Helper()
varargs := []interface{}{ctx, key}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Get", varargs...)
ret0, _ := ret[0].(*clientv3.GetResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get
func (mr *MockEtcdClientMockRecorder) Get(ctx, key interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{ctx, key}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEtcdClient)(nil).Get), varargs...)
}
// Grant mocks base method
func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Grant", ctx, ttl)
ret0, _ := ret[0].(*clientv3.LeaseGrantResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Grant indicates an expected call of Grant
func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Grant", reflect.TypeOf((*MockEtcdClient)(nil).Grant), ctx, ttl)
}
// KeepAlive mocks base method
func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeepAlive", ctx, id)
ret0, _ := ret[0].(<-chan *clientv3.LeaseKeepAliveResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// KeepAlive indicates an expected call of KeepAlive
func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeepAlive", reflect.TypeOf((*MockEtcdClient)(nil).KeepAlive), ctx, id)
}
// Put mocks base method
func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) {
m.ctrl.T.Helper()
varargs := []interface{}{ctx, key, val}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Put", varargs...)
ret0, _ := ret[0].(*clientv3.PutResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Put indicates an expected call of Put
func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{ctx, key, val}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEtcdClient)(nil).Put), varargs...)
}
// Revoke mocks base method
func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Revoke", ctx, id)
ret0, _ := ret[0].(*clientv3.LeaseRevokeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Revoke indicates an expected call of Revoke
func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id)
}
// Watch mocks base method
func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan {
m.ctrl.T.Helper()
varargs := []interface{}{ctx, key}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Watch", varargs...)
ret0, _ := ret[0].(clientv3.WatchChan)
return ret0
}
// Watch indicates an expected call of Watch
func (mr *MockEtcdClientMockRecorder) Watch(ctx, key interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{ctx, key}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockEtcdClient)(nil).Watch), varargs...)
}

View File

@ -0,0 +1,6 @@
//go:generate mockgen -package internal -destination listener_mock.go -source listener.go Listener
package internal
type Listener interface {
OnUpdate(keys []string, values []string, newKey string)
}

View File

@ -0,0 +1,45 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: listener.go
// Package internal is a generated GoMock package.
package internal
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockListener is a mock of Listener interface
type MockListener struct {
ctrl *gomock.Controller
recorder *MockListenerMockRecorder
}
// MockListenerMockRecorder is the mock recorder for MockListener
type MockListenerMockRecorder struct {
mock *MockListener
}
// NewMockListener creates a new mock instance
func NewMockListener(ctrl *gomock.Controller) *MockListener {
mock := &MockListener{ctrl: ctrl}
mock.recorder = &MockListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
return m.recorder
}
// OnUpdate mocks base method
func (m *MockListener) OnUpdate(keys, values []string, newKey string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnUpdate", keys, values, newKey)
}
// OnUpdate indicates an expected call of OnUpdate
func (mr *MockListenerMockRecorder) OnUpdate(keys, values, newKey interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnUpdate", reflect.TypeOf((*MockListener)(nil).OnUpdate), keys, values, newKey)
}

View File

@ -0,0 +1,310 @@
package internal
import (
"context"
"fmt"
"io"
"sort"
"strings"
"sync"
"time"
"zero/core/contextx"
"zero/core/lang"
"zero/core/logx"
"zero/core/syncx"
"zero/core/threading"
"go.etcd.io/etcd/clientv3"
)
var (
registryInstance = Registry{
clusters: make(map[string]*cluster),
}
connManager = syncx.NewResourceManager()
)
type Registry struct {
clusters map[string]*cluster
lock sync.Mutex
}
func GetRegistry() *Registry {
return &registryInstance
}
func (r *Registry) getCluster(endpoints []string) *cluster {
clusterKey := getClusterKey(endpoints)
r.lock.Lock()
defer r.lock.Unlock()
c, ok := r.clusters[clusterKey]
if !ok {
c = newCluster(endpoints)
r.clusters[clusterKey] = c
}
return c
}
func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) {
return r.getCluster(endpoints).getClient()
}
func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener) error {
return r.getCluster(endpoints).monitor(key, l)
}
type cluster struct {
endpoints []string
key string
values map[string]map[string]string
listeners map[string][]UpdateListener
watchGroup *threading.RoutineGroup
done chan lang.PlaceholderType
lock sync.Mutex
}
func newCluster(endpoints []string) *cluster {
return &cluster{
endpoints: endpoints,
key: getClusterKey(endpoints),
values: make(map[string]map[string]string),
listeners: make(map[string][]UpdateListener),
watchGroup: threading.NewRoutineGroup(),
done: make(chan lang.PlaceholderType),
}
}
func (c *cluster) context(cli EtcdClient) context.Context {
return contextx.ValueOnlyFrom(cli.Ctx())
}
func (c *cluster) getClient() (EtcdClient, error) {
val, err := connManager.GetResource(c.key, func() (io.Closer, error) {
return c.newClient()
})
if err != nil {
return nil, err
}
return val.(EtcdClient), nil
}
func (c *cluster) handleChanges(key string, kvs []KV) {
var add []KV
var remove []KV
c.lock.Lock()
listeners := append([]UpdateListener(nil), c.listeners[key]...)
vals, ok := c.values[key]
if !ok {
add = kvs
vals = make(map[string]string)
for _, kv := range kvs {
vals[kv.Key] = kv.Val
}
c.values[key] = vals
} else {
m := make(map[string]string)
for _, kv := range kvs {
m[kv.Key] = kv.Val
}
for k, v := range vals {
if val, ok := m[k]; !ok || v != val {
remove = append(remove, KV{
Key: k,
Val: v,
})
}
}
for k, v := range m {
if val, ok := vals[k]; !ok || v != val {
add = append(add, KV{
Key: k,
Val: v,
})
}
}
c.values[key] = m
}
c.lock.Unlock()
for _, kv := range add {
for _, l := range listeners {
l.OnAdd(kv)
}
}
for _, kv := range remove {
for _, l := range listeners {
l.OnDelete(kv)
}
}
}
func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
c.lock.Lock()
listeners := append([]UpdateListener(nil), c.listeners[key]...)
c.lock.Unlock()
for _, ev := range events {
switch ev.Type {
case clientv3.EventTypePut:
c.lock.Lock()
if vals, ok := c.values[key]; ok {
vals[string(ev.Kv.Key)] = string(ev.Kv.Value)
} else {
c.values[key] = map[string]string{string(ev.Kv.Key): string(ev.Kv.Value)}
}
c.lock.Unlock()
for _, l := range listeners {
l.OnAdd(KV{
Key: string(ev.Kv.Key),
Val: string(ev.Kv.Value),
})
}
case clientv3.EventTypeDelete:
if vals, ok := c.values[key]; ok {
delete(vals, string(ev.Kv.Key))
}
for _, l := range listeners {
l.OnDelete(KV{
Key: string(ev.Kv.Key),
Val: string(ev.Kv.Value),
})
}
default:
logx.Errorf("Unknown event type: %v", ev.Type)
}
}
}
func (c *cluster) load(cli EtcdClient, key string) {
var resp *clientv3.GetResponse
for {
var err error
ctx, cancel := context.WithTimeout(c.context(cli), RequestTimeout)
resp, err = cli.Get(ctx, makeKeyPrefix(key), clientv3.WithPrefix())
cancel()
if err == nil {
break
}
logx.Error(err)
time.Sleep(coolDownInterval)
}
var kvs []KV
c.lock.Lock()
for _, ev := range resp.Kvs {
kvs = append(kvs, KV{
Key: string(ev.Key),
Val: string(ev.Value),
})
}
c.lock.Unlock()
c.handleChanges(key, kvs)
}
func (c *cluster) monitor(key string, l UpdateListener) error {
c.lock.Lock()
c.listeners[key] = append(c.listeners[key], l)
c.lock.Unlock()
cli, err := c.getClient()
if err != nil {
return err
}
c.load(cli, key)
c.watchGroup.Run(func() {
c.watch(cli, key)
})
return nil
}
func (c *cluster) newClient() (EtcdClient, error) {
cli, err := NewClient(c.endpoints)
if err != nil {
return nil, err
}
go c.watchConnState(cli)
return cli, nil
}
func (c *cluster) reload(cli EtcdClient) {
c.lock.Lock()
close(c.done)
c.watchGroup.Wait()
c.done = make(chan lang.PlaceholderType)
c.watchGroup = threading.NewRoutineGroup()
var keys []string
for k := range c.listeners {
keys = append(keys, k)
}
c.lock.Unlock()
for _, key := range keys {
k := key
c.watchGroup.Run(func() {
c.load(cli, k)
c.watch(cli, k)
})
}
}
func (c *cluster) watch(cli EtcdClient, key string) {
rch := cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix())
for {
select {
case wresp, ok := <-rch:
if !ok {
logx.Error("etcd monitor chan has been closed")
return
}
if wresp.Canceled {
logx.Error("etcd monitor chan has been canceled")
return
}
if wresp.Err() != nil {
logx.Error(fmt.Sprintf("etcd monitor chan error: %v", wresp.Err()))
return
}
c.handleWatchEvents(key, wresp.Events)
case <-c.done:
return
}
}
}
func (c *cluster) watchConnState(cli EtcdClient) {
watcher := newStateWatcher()
watcher.addListener(func() {
go c.reload(cli)
})
watcher.watch(cli.ActiveConnection())
}
func DialClient(endpoints []string) (EtcdClient, error) {
return clientv3.New(clientv3.Config{
Endpoints: endpoints,
AutoSyncInterval: autoSyncInterval,
DialTimeout: DialTimeout,
DialKeepAliveTime: dialKeepAliveTime,
DialKeepAliveTimeout: DialTimeout,
RejectOldCluster: true,
})
}
func getClusterKey(endpoints []string) string {
sort.Strings(endpoints)
return strings.Join(endpoints, endpointsSeparator)
}
func makeKeyPrefix(key string) string {
return fmt.Sprintf("%s%c", key, Delimiter)
}

View File

@ -0,0 +1,245 @@
package internal
import (
"context"
"sync"
"testing"
"zero/core/contextx"
"zero/core/stringx"
"github.com/coreos/etcd/mvcc/mvccpb"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"go.etcd.io/etcd/clientv3"
)
var mockLock sync.Mutex
func setMockClient(cli EtcdClient) func() {
mockLock.Lock()
NewClient = func([]string) (EtcdClient, error) {
return cli, nil
}
return func() {
NewClient = DialClient
mockLock.Unlock()
}
}
func TestGetCluster(t *testing.T) {
c1 := GetRegistry().getCluster([]string{"first"})
c2 := GetRegistry().getCluster([]string{"second"})
c3 := GetRegistry().getCluster([]string{"first"})
assert.Equal(t, c1, c3)
assert.NotEqual(t, c1, c2)
}
func TestGetClusterKey(t *testing.T) {
assert.Equal(t, getClusterKey([]string{"localhost:1234", "remotehost:5678"}),
getClusterKey([]string{"remotehost:5678", "localhost:1234"}))
}
func TestCluster_HandleChanges(t *testing.T) {
ctrl := gomock.NewController(t)
l := NewMockUpdateListener(ctrl)
l.EXPECT().OnAdd(KV{
Key: "first",
Val: "1",
})
l.EXPECT().OnAdd(KV{
Key: "second",
Val: "2",
})
l.EXPECT().OnDelete(KV{
Key: "first",
Val: "1",
})
l.EXPECT().OnDelete(KV{
Key: "second",
Val: "2",
})
l.EXPECT().OnAdd(KV{
Key: "third",
Val: "3",
})
l.EXPECT().OnAdd(KV{
Key: "fourth",
Val: "4",
})
c := newCluster([]string{"any"})
c.listeners["any"] = []UpdateListener{l}
c.handleChanges("any", []KV{
{
Key: "first",
Val: "1",
},
{
Key: "second",
Val: "2",
},
})
assert.EqualValues(t, map[string]string{
"first": "1",
"second": "2",
}, c.values["any"])
c.handleChanges("any", []KV{
{
Key: "third",
Val: "3",
},
{
Key: "fourth",
Val: "4",
},
})
assert.EqualValues(t, map[string]string{
"third": "3",
"fourth": "4",
}, c.values["any"])
}
func TestCluster_Load(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cli := NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Get(gomock.Any(), "any/", gomock.Any()).Return(&clientv3.GetResponse{
Kvs: []*mvccpb.KeyValue{
{
Key: []byte("hello"),
Value: []byte("world"),
},
},
}, nil)
cli.EXPECT().Ctx().Return(context.Background())
c := &cluster{
values: make(map[string]map[string]string),
}
c.load(cli, "any")
}
func TestCluster_Watch(t *testing.T) {
tests := []struct {
name string
method int
eventType mvccpb.Event_EventType
}{
{
name: "add",
eventType: clientv3.EventTypePut,
},
{
name: "delete",
eventType: clientv3.EventTypeDelete,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cli := NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch)
cli.EXPECT().Ctx().Return(context.Background())
var wg sync.WaitGroup
wg.Add(1)
c := &cluster{
listeners: make(map[string][]UpdateListener),
values: make(map[string]map[string]string),
}
listener := NewMockUpdateListener(ctrl)
c.listeners["any"] = []UpdateListener{listener}
listener.EXPECT().OnAdd(gomock.Any()).Do(func(kv KV) {
assert.Equal(t, "hello", kv.Key)
assert.Equal(t, "world", kv.Val)
wg.Done()
}).MaxTimes(1)
listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ interface{}) {
wg.Done()
}).MaxTimes(1)
go c.watch(cli, "any")
ch <- clientv3.WatchResponse{
Events: []*clientv3.Event{
{
Type: test.eventType,
Kv: &mvccpb.KeyValue{
Key: []byte("hello"),
Value: []byte("world"),
},
},
},
}
wg.Wait()
})
}
}
func TestClusterWatch_RespFailures(t *testing.T) {
resps := []clientv3.WatchResponse{
{
Canceled: true,
},
{
// cause resp.Err() != nil
CompactRevision: 1,
},
}
for _, resp := range resps {
t.Run(stringx.Rand(), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cli := NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch)
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
c := new(cluster)
go func() {
ch <- resp
}()
c.watch(cli, "any")
})
}
}
func TestClusterWatch_CloseChan(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cli := NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch)
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
c := new(cluster)
go func() {
close(ch)
}()
c.watch(cli, "any")
}
func TestValueOnlyContext(t *testing.T) {
ctx := contextx.ValueOnlyFrom(context.Background())
ctx.Done()
assert.Nil(t, ctx.Err())
}
type mockedSharedCalls struct {
fn func() (interface{}, error)
}
func (c mockedSharedCalls) Do(_ string, fn func() (interface{}, error)) (interface{}, error) {
return c.fn()
}
func (c mockedSharedCalls) DoEx(_ string, fn func() (interface{}, error)) (interface{}, bool, error) {
val, err := c.fn()
return val, true, err
}

View File

@ -0,0 +1,148 @@
package internal
import (
"math/rand"
"time"
"zero/core/logx"
)
type roundRobinBalancer struct {
*baseBalancer
conns []serverConn
index int
}
func NewRoundRobinBalancer(dialFn DialFn, closeFn CloseFn, exclusive bool) *roundRobinBalancer {
balancer := new(roundRobinBalancer)
balancer.baseBalancer = newBaseBalancer(dialFn, closeFn, exclusive)
return balancer
}
func (b *roundRobinBalancer) AddConn(kv KV) error {
var conn interface{}
prev, found := b.addKv(kv.Key, kv.Val)
if found {
conn = b.handlePrevious(prev, kv.Val)
}
if conn == nil {
var err error
conn, err = b.dialFn(kv.Val)
if err != nil {
b.removeKv(kv.Key)
return err
}
}
b.lock.Lock()
defer b.lock.Unlock()
b.conns = append(b.conns, serverConn{
key: kv.Key,
conn: conn,
})
b.notify(kv.Key)
return nil
}
func (b *roundRobinBalancer) handlePrevious(prev []string, server string) interface{} {
if len(prev) == 0 {
return nil
}
b.lock.Lock()
defer b.lock.Unlock()
if b.exclusive {
for _, item := range prev {
conns := b.conns[:0]
for _, each := range b.conns {
if each.key == item {
if err := b.closeFn(server, each.conn); err != nil {
logx.Error(err)
}
} else {
conns = append(conns, each)
}
}
b.conns = conns
}
} else {
for _, each := range b.conns {
if each.key == prev[0] {
return each.conn
}
}
}
return nil
}
func (b *roundRobinBalancer) initialize() {
rand.Seed(time.Now().UnixNano())
if len(b.conns) > 0 {
b.index = rand.Intn(len(b.conns))
}
}
func (b *roundRobinBalancer) IsEmpty() bool {
b.lock.Lock()
empty := len(b.conns) == 0
b.lock.Unlock()
return empty
}
func (b *roundRobinBalancer) Next(...string) (interface{}, bool) {
b.lock.Lock()
defer b.lock.Unlock()
if len(b.conns) == 0 {
return nil, false
}
b.index = (b.index + 1) % len(b.conns)
return b.conns[b.index].conn, true
}
func (b *roundRobinBalancer) notify(key string) {
if b.listener == nil {
return
}
// b.servers has the format of map[conn][]key
var keys []string
var values []string
for k, v := range b.servers {
values = append(values, k)
keys = append(keys, v...)
}
b.listener.OnUpdate(keys, values, key)
}
func (b *roundRobinBalancer) RemoveKey(key string) {
server, keep := b.removeKv(key)
b.lock.Lock()
defer b.lock.Unlock()
conns := b.conns[:0]
for _, conn := range b.conns {
if conn.key == key {
// there are other keys assocated with the conn, don't close the conn.
if keep {
continue
}
if err := b.closeFn(server, conn.conn); err != nil {
logx.Error(err)
}
} else {
conns = append(conns, conn)
}
}
b.conns = conns
// notify without new key
b.notify("")
}

View File

@ -0,0 +1,321 @@
package internal
import (
"errors"
"sort"
"strconv"
"testing"
"zero/core/logx"
"zero/core/mathx"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func init() {
logx.Disable()
}
func TestRoundRobin_addConn(t *testing.T) {
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
return mockConn{
server: server,
}, nil
}, func(server string, conn interface{}) error {
return errors.New("error")
}, false)
assert.Nil(t, b.AddConn(KV{
Key: "thekey1",
Val: "thevalue",
}))
assert.EqualValues(t, []serverConn{
{
key: "thekey1",
conn: mockConn{server: "thevalue"},
},
}, b.conns)
assert.EqualValues(t, map[string][]string{
"thevalue": {"thekey1"},
}, b.servers)
assert.EqualValues(t, map[string]string{
"thekey1": "thevalue",
}, b.mapping)
assert.Nil(t, b.AddConn(KV{
Key: "thekey2",
Val: "thevalue",
}))
assert.EqualValues(t, []serverConn{
{
key: "thekey1",
conn: mockConn{server: "thevalue"},
},
{
key: "thekey2",
conn: mockConn{server: "thevalue"},
},
}, b.conns)
assert.EqualValues(t, map[string][]string{
"thevalue": {"thekey1", "thekey2"},
}, b.servers)
assert.EqualValues(t, map[string]string{
"thekey1": "thevalue",
"thekey2": "thevalue",
}, b.mapping)
assert.False(t, b.IsEmpty())
b.RemoveKey("thekey1")
assert.EqualValues(t, []serverConn{
{
key: "thekey2",
conn: mockConn{server: "thevalue"},
},
}, b.conns)
assert.EqualValues(t, map[string][]string{
"thevalue": {"thekey2"},
}, b.servers)
assert.EqualValues(t, map[string]string{
"thekey2": "thevalue",
}, b.mapping)
assert.False(t, b.IsEmpty())
b.RemoveKey("thekey2")
assert.EqualValues(t, []serverConn{}, b.conns)
assert.EqualValues(t, map[string][]string{}, b.servers)
assert.EqualValues(t, map[string]string{}, b.mapping)
assert.True(t, b.IsEmpty())
}
func TestRoundRobin_addConnExclusive(t *testing.T) {
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
return mockConn{
server: server,
}, nil
}, func(server string, conn interface{}) error {
return nil
}, true)
assert.Nil(t, b.AddConn(KV{
Key: "thekey1",
Val: "thevalue",
}))
assert.EqualValues(t, []serverConn{
{
key: "thekey1",
conn: mockConn{server: "thevalue"},
},
}, b.conns)
assert.EqualValues(t, map[string][]string{
"thevalue": {"thekey1"},
}, b.servers)
assert.EqualValues(t, map[string]string{
"thekey1": "thevalue",
}, b.mapping)
assert.Nil(t, b.AddConn(KV{
Key: "thekey2",
Val: "thevalue",
}))
assert.EqualValues(t, []serverConn{
{
key: "thekey2",
conn: mockConn{server: "thevalue"},
},
}, b.conns)
assert.EqualValues(t, map[string][]string{
"thevalue": {"thekey2"},
}, b.servers)
assert.EqualValues(t, map[string]string{
"thekey2": "thevalue",
}, b.mapping)
assert.False(t, b.IsEmpty())
b.RemoveKey("thekey1")
b.RemoveKey("thekey2")
assert.EqualValues(t, []serverConn{}, b.conns)
assert.EqualValues(t, map[string][]string{}, b.servers)
assert.EqualValues(t, map[string]string{}, b.mapping)
assert.True(t, b.IsEmpty())
}
func TestRoundRobin_addConnDupExclusive(t *testing.T) {
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
return mockConn{
server: server,
}, nil
}, func(server string, conn interface{}) error {
return errors.New("error")
}, true)
assert.Nil(t, b.AddConn(KV{
Key: "thekey1",
Val: "thevalue",
}))
assert.EqualValues(t, []serverConn{
{
key: "thekey1",
conn: mockConn{server: "thevalue"},
},
}, b.conns)
assert.EqualValues(t, map[string][]string{
"thevalue": {"thekey1"},
}, b.servers)
assert.EqualValues(t, map[string]string{
"thekey1": "thevalue",
}, b.mapping)
assert.Nil(t, b.AddConn(KV{
Key: "thekey",
Val: "anothervalue",
}))
assert.Nil(t, b.AddConn(KV{
Key: "thekey1",
Val: "thevalue",
}))
assert.EqualValues(t, []serverConn{
{
key: "thekey",
conn: mockConn{server: "anothervalue"},
},
{
key: "thekey1",
conn: mockConn{server: "thevalue"},
},
}, b.conns)
assert.EqualValues(t, map[string][]string{
"anothervalue": {"thekey"},
"thevalue": {"thekey1"},
}, b.servers)
assert.EqualValues(t, map[string]string{
"thekey": "anothervalue",
"thekey1": "thevalue",
}, b.mapping)
assert.False(t, b.IsEmpty())
b.RemoveKey("thekey")
b.RemoveKey("thekey1")
assert.EqualValues(t, []serverConn{}, b.conns)
assert.EqualValues(t, map[string][]string{}, b.servers)
assert.EqualValues(t, map[string]string{}, b.mapping)
assert.True(t, b.IsEmpty())
}
func TestRoundRobin_addConnError(t *testing.T) {
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
return nil, errors.New("error")
}, func(server string, conn interface{}) error {
return nil
}, true)
assert.NotNil(t, b.AddConn(KV{
Key: "thekey1",
Val: "thevalue",
}))
assert.Nil(t, b.conns)
assert.EqualValues(t, map[string][]string{}, b.servers)
assert.EqualValues(t, map[string]string{}, b.mapping)
}
func TestRoundRobin_initialize(t *testing.T) {
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
return mockConn{
server: server,
}, nil
}, func(server string, conn interface{}) error {
return nil
}, true)
for i := 0; i < 100; i++ {
assert.Nil(t, b.AddConn(KV{
Key: "thekey/" + strconv.Itoa(i),
Val: "thevalue/" + strconv.Itoa(i),
}))
}
m := make(map[int]int)
const total = 1000
for i := 0; i < total; i++ {
b.initialize()
m[b.index]++
}
mi := make(map[interface{}]int, len(m))
for k, v := range m {
mi[k] = v
}
entropy := mathx.CalcEntropy(mi, total)
assert.True(t, entropy > .95)
}
func TestRoundRobin_next(t *testing.T) {
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
return mockConn{
server: server,
}, nil
}, func(server string, conn interface{}) error {
return errors.New("error")
}, true)
const size = 100
for i := 0; i < size; i++ {
assert.Nil(t, b.AddConn(KV{
Key: "thekey/" + strconv.Itoa(i),
Val: "thevalue/" + strconv.Itoa(i),
}))
}
m := make(map[interface{}]int)
const total = 10000
for i := 0; i < total; i++ {
val, ok := b.Next()
assert.True(t, ok)
m[val]++
}
entropy := mathx.CalcEntropy(m, total)
assert.Equal(t, size, len(m))
assert.True(t, entropy > .95)
for i := 0; i < size; i++ {
b.RemoveKey("thekey/" + strconv.Itoa(i))
}
_, ok := b.Next()
assert.False(t, ok)
}
func TestRoundRobinBalancer_Listener(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
return mockConn{
server: server,
}, nil
}, func(server string, conn interface{}) error {
return nil
}, true)
assert.Nil(t, b.AddConn(KV{
Key: "key1",
Val: "val1",
}))
assert.Nil(t, b.AddConn(KV{
Key: "key2",
Val: "val2",
}))
listener := NewMockListener(ctrl)
listener.EXPECT().OnUpdate(gomock.Any(), gomock.Any(), "key2").Do(func(keys, vals, _ interface{}) {
sort.Strings(vals.([]string))
sort.Strings(keys.([]string))
assert.EqualValues(t, []string{"key1", "key2"}, keys)
assert.EqualValues(t, []string{"val1", "val2"}, vals)
})
b.setListener(listener)
b.notify("key2")
}
func TestRoundRobinBalancer_remove(t *testing.T) {
b := NewRoundRobinBalancer(func(server string) (interface{}, error) {
return mockConn{
server: server,
}, nil
}, func(server string, conn interface{}) error {
return nil
}, true)
assert.Nil(t, b.handlePrevious(nil, "any"))
_, ok := b.doRemoveKv("any")
assert.True(t, ok)
}

View File

@ -0,0 +1,58 @@
//go:generate mockgen -package internal -destination statewatcher_mock.go -source statewatcher.go etcdConn
package internal
import (
"context"
"sync"
"google.golang.org/grpc/connectivity"
)
type (
etcdConn interface {
GetState() connectivity.State
WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool
}
stateWatcher struct {
disconnected bool
currentState connectivity.State
listeners []func()
lock sync.Mutex
}
)
func newStateWatcher() *stateWatcher {
return new(stateWatcher)
}
func (sw *stateWatcher) addListener(l func()) {
sw.lock.Lock()
sw.listeners = append(sw.listeners, l)
sw.lock.Unlock()
}
func (sw *stateWatcher) watch(conn etcdConn) {
sw.currentState = conn.GetState()
for {
if conn.WaitForStateChange(context.Background(), sw.currentState) {
newState := conn.GetState()
sw.lock.Lock()
sw.currentState = newState
switch newState {
case connectivity.TransientFailure, connectivity.Shutdown:
sw.disconnected = true
case connectivity.Ready:
if sw.disconnected {
sw.disconnected = false
for _, l := range sw.listeners {
l()
}
}
}
sw.lock.Unlock()
}
}
}

View File

@ -0,0 +1,63 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: statewatcher.go
// Package internal is a generated GoMock package.
package internal
import (
context "context"
gomock "github.com/golang/mock/gomock"
connectivity "google.golang.org/grpc/connectivity"
reflect "reflect"
)
// MocketcdConn is a mock of etcdConn interface
type MocketcdConn struct {
ctrl *gomock.Controller
recorder *MocketcdConnMockRecorder
}
// MocketcdConnMockRecorder is the mock recorder for MocketcdConn
type MocketcdConnMockRecorder struct {
mock *MocketcdConn
}
// NewMocketcdConn creates a new mock instance
func NewMocketcdConn(ctrl *gomock.Controller) *MocketcdConn {
mock := &MocketcdConn{ctrl: ctrl}
mock.recorder = &MocketcdConnMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MocketcdConn) EXPECT() *MocketcdConnMockRecorder {
return m.recorder
}
// GetState mocks base method
func (m *MocketcdConn) GetState() connectivity.State {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetState")
ret0, _ := ret[0].(connectivity.State)
return ret0
}
// GetState indicates an expected call of GetState
func (mr *MocketcdConnMockRecorder) GetState() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MocketcdConn)(nil).GetState))
}
// WaitForStateChange mocks base method
func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WaitForStateChange", ctx, sourceState)
ret0, _ := ret[0].(bool)
return ret0
}
// WaitForStateChange indicates an expected call of WaitForStateChange
func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForStateChange", reflect.TypeOf((*MocketcdConn)(nil).WaitForStateChange), ctx, sourceState)
}

View File

@ -0,0 +1,27 @@
package internal
import (
"sync"
"testing"
"github.com/golang/mock/gomock"
"google.golang.org/grpc/connectivity"
)
func TestStateWatcher_watch(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
watcher := newStateWatcher()
var wg sync.WaitGroup
wg.Add(1)
watcher.addListener(func() {
wg.Done()
})
conn := NewMocketcdConn(ctrl)
conn.EXPECT().GetState().Return(connectivity.Ready)
conn.EXPECT().GetState().Return(connectivity.TransientFailure)
conn.EXPECT().GetState().Return(connectivity.Ready).AnyTimes()
conn.EXPECT().WaitForStateChange(gomock.Any(), gomock.Any()).Return(true).AnyTimes()
go watcher.watch(conn)
wg.Wait()
}

View File

@ -0,0 +1,14 @@
//go:generate mockgen -package internal -destination updatelistener_mock.go -source updatelistener.go UpdateListener
package internal
type (
KV struct {
Key string
Val string
}
UpdateListener interface {
OnAdd(kv KV)
OnDelete(kv KV)
}
)

View File

@ -0,0 +1,57 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: updatelistener.go
// Package internal is a generated GoMock package.
package internal
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockUpdateListener is a mock of UpdateListener interface
type MockUpdateListener struct {
ctrl *gomock.Controller
recorder *MockUpdateListenerMockRecorder
}
// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener
type MockUpdateListenerMockRecorder struct {
mock *MockUpdateListener
}
// NewMockUpdateListener creates a new mock instance
func NewMockUpdateListener(ctrl *gomock.Controller) *MockUpdateListener {
mock := &MockUpdateListener{ctrl: ctrl}
mock.recorder = &MockUpdateListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockUpdateListener) EXPECT() *MockUpdateListenerMockRecorder {
return m.recorder
}
// OnAdd mocks base method
func (m *MockUpdateListener) OnAdd(kv KV) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnAdd", kv)
}
// OnAdd indicates an expected call of OnAdd
func (mr *MockUpdateListenerMockRecorder) OnAdd(kv interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAdd", reflect.TypeOf((*MockUpdateListener)(nil).OnAdd), kv)
}
// OnDelete mocks base method
func (m *MockUpdateListener) OnDelete(kv KV) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnDelete", kv)
}
// OnDelete indicates an expected call of OnDelete
func (mr *MockUpdateListenerMockRecorder) OnDelete(kv interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDelete", reflect.TypeOf((*MockUpdateListener)(nil).OnDelete), kv)
}

View File

@ -0,0 +1,19 @@
package internal
import "time"
const (
autoSyncInterval = time.Minute
coolDownInterval = time.Second
dialTimeout = 5 * time.Second
dialKeepAliveTime = 5 * time.Second
requestTimeout = 3 * time.Second
Delimiter = '/'
endpointsSeparator = ","
)
var (
DialTimeout = dialTimeout
RequestTimeout = requestTimeout
NewClient = DialClient
)

View File

@ -0,0 +1,4 @@
apiVersion: v1
kind: Namespace
metadata:
name: discov

View File

@ -0,0 +1,378 @@
apiVersion: v1
kind: Service
metadata:
name: etcd
namespace: discov
spec:
ports:
- name: etcd-port
port: 2379
protocol: TCP
targetPort: 2379
selector:
app: etcd
---
apiVersion: v1
kind: Pod
metadata:
labels:
app: etcd
etcd_node: etcd0
name: etcd0
namespace: discov
spec:
containers:
- command:
- /usr/local/bin/etcd
- --name
- etcd0
- --initial-advertise-peer-urls
- http://etcd0:2380
- --listen-peer-urls
- http://0.0.0.0:2380
- --listen-client-urls
- http://0.0.0.0:2379
- --advertise-client-urls
- http://etcd0:2379
- --initial-cluster
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
- --initial-cluster-state
- new
image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest
name: etcd0
ports:
- containerPort: 2379
name: client
protocol: TCP
- containerPort: 2380
name: server
protocol: TCP
imagePullSecrets:
- name: aliyun
affinity:
podAntiAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
- labelSelector:
matchExpressions:
- key: app
operator: In
values:
- etcd
topologyKey: "kubernetes.io/hostname"
restartPolicy: Always
---
apiVersion: v1
kind: Service
metadata:
labels:
etcd_node: etcd0
name: etcd0
namespace: discov
spec:
ports:
- name: client
port: 2379
protocol: TCP
targetPort: 2379
- name: server
port: 2380
protocol: TCP
targetPort: 2380
selector:
etcd_node: etcd0
---
apiVersion: v1
kind: Pod
metadata:
labels:
app: etcd
etcd_node: etcd1
name: etcd1
namespace: discov
spec:
containers:
- command:
- /usr/local/bin/etcd
- --name
- etcd1
- --initial-advertise-peer-urls
- http://etcd1:2380
- --listen-peer-urls
- http://0.0.0.0:2380
- --listen-client-urls
- http://0.0.0.0:2379
- --advertise-client-urls
- http://etcd1:2379
- --initial-cluster
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
- --initial-cluster-state
- new
image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest
name: etcd1
ports:
- containerPort: 2379
name: client
protocol: TCP
- containerPort: 2380
name: server
protocol: TCP
imagePullSecrets:
- name: aliyun
affinity:
podAntiAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
- labelSelector:
matchExpressions:
- key: app
operator: In
values:
- etcd
topologyKey: "kubernetes.io/hostname"
restartPolicy: Always
---
apiVersion: v1
kind: Service
metadata:
labels:
etcd_node: etcd1
name: etcd1
namespace: discov
spec:
ports:
- name: client
port: 2379
protocol: TCP
targetPort: 2379
- name: server
port: 2380
protocol: TCP
targetPort: 2380
selector:
etcd_node: etcd1
---
apiVersion: v1
kind: Pod
metadata:
labels:
app: etcd
etcd_node: etcd2
name: etcd2
namespace: discov
spec:
containers:
- command:
- /usr/local/bin/etcd
- --name
- etcd2
- --initial-advertise-peer-urls
- http://etcd2:2380
- --listen-peer-urls
- http://0.0.0.0:2380
- --listen-client-urls
- http://0.0.0.0:2379
- --advertise-client-urls
- http://etcd2:2379
- --initial-cluster
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
- --initial-cluster-state
- new
image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest
name: etcd2
ports:
- containerPort: 2379
name: client
protocol: TCP
- containerPort: 2380
name: server
protocol: TCP
imagePullSecrets:
- name: aliyun
affinity:
podAntiAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
- labelSelector:
matchExpressions:
- key: app
operator: In
values:
- etcd
topologyKey: "kubernetes.io/hostname"
restartPolicy: Always
---
apiVersion: v1
kind: Service
metadata:
labels:
etcd_node: etcd2
name: etcd2
namespace: discov
spec:
ports:
- name: client
port: 2379
protocol: TCP
targetPort: 2379
- name: server
port: 2380
protocol: TCP
targetPort: 2380
selector:
etcd_node: etcd2
---
apiVersion: v1
kind: Pod
metadata:
labels:
app: etcd
etcd_node: etcd3
name: etcd3
namespace: discov
spec:
containers:
- command:
- /usr/local/bin/etcd
- --name
- etcd3
- --initial-advertise-peer-urls
- http://etcd3:2380
- --listen-peer-urls
- http://0.0.0.0:2380
- --listen-client-urls
- http://0.0.0.0:2379
- --advertise-client-urls
- http://etcd3:2379
- --initial-cluster
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
- --initial-cluster-state
- new
image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest
name: etcd3
ports:
- containerPort: 2379
name: client
protocol: TCP
- containerPort: 2380
name: server
protocol: TCP
imagePullSecrets:
- name: aliyun
affinity:
podAntiAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
- labelSelector:
matchExpressions:
- key: app
operator: In
values:
- etcd
topologyKey: "kubernetes.io/hostname"
restartPolicy: Always
---
apiVersion: v1
kind: Service
metadata:
labels:
etcd_node: etcd3
name: etcd3
namespace: discov
spec:
ports:
- name: client
port: 2379
protocol: TCP
targetPort: 2379
- name: server
port: 2380
protocol: TCP
targetPort: 2380
selector:
etcd_node: etcd3
---
apiVersion: v1
kind: Pod
metadata:
labels:
app: etcd
etcd_node: etcd4
name: etcd4
namespace: discov
spec:
containers:
- command:
- /usr/local/bin/etcd
- --name
- etcd4
- --initial-advertise-peer-urls
- http://etcd4:2380
- --listen-peer-urls
- http://0.0.0.0:2380
- --listen-client-urls
- http://0.0.0.0:2379
- --advertise-client-urls
- http://etcd4:2379
- --initial-cluster
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
- --initial-cluster-state
- new
image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest
name: etcd4
ports:
- containerPort: 2379
name: client
protocol: TCP
- containerPort: 2380
name: server
protocol: TCP
imagePullSecrets:
- name: aliyun
affinity:
podAntiAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
- labelSelector:
matchExpressions:
- key: app
operator: In
values:
- etcd
topologyKey: "kubernetes.io/hostname"
restartPolicy: Always
---
apiVersion: v1
kind: Service
metadata:
labels:
etcd_node: etcd4
name: etcd4
namespace: discov
spec:
ports:
- name: client
port: 2379
protocol: TCP
targetPort: 2379
- name: server
port: 2380
protocol: TCP
targetPort: 2380
selector:
etcd_node: etcd4

143
core/discov/publisher.go Normal file
View File

@ -0,0 +1,143 @@
package discov
import (
"zero/core/discov/internal"
"zero/core/lang"
"zero/core/logx"
"zero/core/proc"
"zero/core/syncx"
"zero/core/threading"
"go.etcd.io/etcd/clientv3"
)
type (
PublisherOption func(client *Publisher)
Publisher struct {
endpoints []string
key string
fullKey string
id int64
value string
lease clientv3.LeaseID
quit *syncx.DoneChan
pauseChan chan lang.PlaceholderType
resumeChan chan lang.PlaceholderType
}
)
func NewPublisher(endpoints []string, key, value string, opts ...PublisherOption) *Publisher {
publisher := &Publisher{
endpoints: endpoints,
key: key,
value: value,
quit: syncx.NewDoneChan(),
pauseChan: make(chan lang.PlaceholderType),
resumeChan: make(chan lang.PlaceholderType),
}
for _, opt := range opts {
opt(publisher)
}
return publisher
}
func (p *Publisher) KeepAlive() error {
cli, err := internal.GetRegistry().GetConn(p.endpoints)
if err != nil {
return err
}
p.lease, err = p.register(cli)
if err != nil {
return err
}
proc.AddWrapUpListener(func() {
p.Stop()
})
return p.keepAliveAsync(cli)
}
func (p *Publisher) Pause() {
p.pauseChan <- lang.Placeholder
}
func (p *Publisher) Resume() {
p.resumeChan <- lang.Placeholder
}
func (p *Publisher) Stop() {
p.quit.Close()
}
func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
ch, err := cli.KeepAlive(cli.Ctx(), p.lease)
if err != nil {
return err
}
threading.GoSafe(func() {
for {
select {
case _, ok := <-ch:
if !ok {
p.revoke(cli)
if err := p.KeepAlive(); err != nil {
logx.Errorf("KeepAlive: %s", err.Error())
}
return
}
case <-p.pauseChan:
logx.Infof("paused etcd renew, key: %s, value: %s", p.key, p.value)
p.revoke(cli)
select {
case <-p.resumeChan:
if err := p.KeepAlive(); err != nil {
logx.Errorf("KeepAlive: %s", err.Error())
}
return
case <-p.quit.Done():
return
}
case <-p.quit.Done():
p.revoke(cli)
return
}
}
})
return nil
}
func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, error) {
resp, err := client.Grant(client.Ctx(), TimeToLive)
if err != nil {
return clientv3.NoLease, err
}
lease := resp.ID
if p.id > 0 {
p.fullKey = makeEtcdKey(p.key, p.id)
} else {
p.fullKey = makeEtcdKey(p.key, int64(lease))
}
_, err = client.Put(client.Ctx(), p.fullKey, p.value, clientv3.WithLease(lease))
return lease, err
}
func (p *Publisher) revoke(cli internal.EtcdClient) {
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
logx.Error(err)
}
}
func WithId(id int64) PublisherOption {
return func(publisher *Publisher) {
publisher.id = id
}
}

View File

@ -0,0 +1,151 @@
package discov
import (
"errors"
"sync"
"testing"
"zero/core/discov/internal"
"zero/core/logx"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"go.etcd.io/etcd/clientv3"
)
func init() {
logx.Disable()
}
func TestPublisher_register(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
ID: id,
}, nil)
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any())
pub := NewPublisher(nil, "thekey", "thevalue")
_, err := pub.register(cli)
assert.Nil(t, err)
}
func TestPublisher_registerWithId(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id = 2
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
ID: 1,
}, nil)
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any())
pub := NewPublisher(nil, "thekey", "thevalue", WithId(id))
_, err := pub.register(cli)
assert.Nil(t, err)
}
func TestPublisher_registerError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(nil, errors.New("error"))
pub := NewPublisher(nil, "thekey", "thevalue")
val, err := pub.register(cli)
assert.NotNil(t, err)
assert.Equal(t, clientv3.NoLease, val)
}
func TestPublisher_revoke(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().Revoke(gomock.Any(), id)
pub := NewPublisher(nil, "thekey", "thevalue")
pub.lease = id
pub.revoke(cli)
}
func TestPublisher_revokeError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().Revoke(gomock.Any(), id).Return(nil, errors.New("error"))
pub := NewPublisher(nil, "thekey", "thevalue")
pub.lease = id
pub.revoke(cli)
}
func TestPublisher_keepAliveAsyncError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id).Return(nil, errors.New("error"))
pub := NewPublisher(nil, "thekey", "thevalue")
pub.lease = id
assert.NotNil(t, pub.keepAliveAsync(cli))
}
func TestPublisher_keepAliveAsyncQuit(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
var wg sync.WaitGroup
wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
wg.Done()
})
pub := NewPublisher(nil, "thekey", "thevalue")
pub.lease = id
pub.Stop()
assert.Nil(t, pub.keepAliveAsync(cli))
wg.Wait()
}
func TestPublisher_keepAliveAsyncPause(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
pub := NewPublisher(nil, "thekey", "thevalue")
var wg sync.WaitGroup
wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
pub.Stop()
wg.Done()
})
pub.lease = id
assert.Nil(t, pub.keepAliveAsync(cli))
pub.Pause()
wg.Wait()
}

35
core/discov/renewer.go Normal file
View File

@ -0,0 +1,35 @@
package discov
import "zero/core/logx"
type (
Renewer interface {
Start()
Stop()
Pause()
Resume()
}
etcdRenewer struct {
*Publisher
}
)
func NewRenewer(endpoints []string, key, value string, renewId int64) Renewer {
var publisher *Publisher
if renewId > 0 {
publisher = NewPublisher(endpoints, key, value, WithId(renewId))
} else {
publisher = NewPublisher(endpoints, key, value)
}
return &etcdRenewer{
Publisher: publisher,
}
}
func (sr *etcdRenewer) Start() {
if err := sr.KeepAlive(); err != nil {
logx.Error(err)
}
}

186
core/discov/subclient.go Normal file
View File

@ -0,0 +1,186 @@
package discov
import (
"sync"
"zero/core/discov/internal"
"zero/core/logx"
)
const (
_ = iota // keyBasedBalance, default
idBasedBalance
)
type (
Listener internal.Listener
subClient struct {
balancer internal.Balancer
lock sync.Mutex
cond *sync.Cond
listeners []internal.Listener
}
balanceOptions struct {
balanceType int
}
BalanceOption func(*balanceOptions)
RoundRobinSubClient struct {
*subClient
}
ConsistentSubClient struct {
*subClient
}
BatchConsistentSubClient struct {
*ConsistentSubClient
}
)
func NewRoundRobinSubClient(endpoints []string, key string, dialFn internal.DialFn, closeFn internal.CloseFn,
opts ...SubOption) (*RoundRobinSubClient, error) {
var subOpts subOptions
for _, opt := range opts {
opt(&subOpts)
}
cli, err := newSubClient(endpoints, key, internal.NewRoundRobinBalancer(dialFn, closeFn, subOpts.exclusive))
if err != nil {
return nil, err
}
return &RoundRobinSubClient{
subClient: cli,
}, nil
}
func NewConsistentSubClient(endpoints []string, key string, dialFn internal.DialFn,
closeFn internal.CloseFn, opts ...BalanceOption) (*ConsistentSubClient, error) {
var balanceOpts balanceOptions
for _, opt := range opts {
opt(&balanceOpts)
}
var keyer func(internal.KV) string
switch balanceOpts.balanceType {
case idBasedBalance:
keyer = func(kv internal.KV) string {
if id, ok := extractId(kv.Key); ok {
return id
} else {
return kv.Key
}
}
default:
keyer = func(kv internal.KV) string {
return kv.Val
}
}
cli, err := newSubClient(endpoints, key, internal.NewConsistentBalancer(dialFn, closeFn, keyer))
if err != nil {
return nil, err
}
return &ConsistentSubClient{
subClient: cli,
}, nil
}
func NewBatchConsistentSubClient(endpoints []string, key string, dialFn internal.DialFn, closeFn internal.CloseFn,
opts ...BalanceOption) (*BatchConsistentSubClient, error) {
cli, err := NewConsistentSubClient(endpoints, key, dialFn, closeFn, opts...)
if err != nil {
return nil, err
}
return &BatchConsistentSubClient{
ConsistentSubClient: cli,
}, nil
}
func newSubClient(endpoints []string, key string, balancer internal.Balancer) (*subClient, error) {
client := &subClient{
balancer: balancer,
}
client.cond = sync.NewCond(&client.lock)
if err := internal.GetRegistry().Monitor(endpoints, key, client); err != nil {
return nil, err
}
return client, nil
}
func (c *subClient) AddListener(listener internal.Listener) {
c.lock.Lock()
c.listeners = append(c.listeners, listener)
c.lock.Unlock()
}
func (c *subClient) OnAdd(kv internal.KV) {
c.lock.Lock()
defer c.lock.Unlock()
if err := c.balancer.AddConn(kv); err != nil {
logx.Error(err)
} else {
c.cond.Broadcast()
}
}
func (c *subClient) OnDelete(kv internal.KV) {
c.balancer.RemoveKey(kv.Key)
}
func (c *subClient) WaitForServers() {
logx.Error("Waiting for alive servers")
c.lock.Lock()
defer c.lock.Unlock()
if c.balancer.IsEmpty() {
c.cond.Wait()
}
}
func (c *subClient) onAdd(keys []string, servers []string, newKey string) {
// guarded by locked outside
for _, listener := range c.listeners {
listener.OnUpdate(keys, servers, newKey)
}
}
func (c *RoundRobinSubClient) Next() (interface{}, bool) {
return c.balancer.Next()
}
func (c *ConsistentSubClient) Next(key string) (interface{}, bool) {
return c.balancer.Next(key)
}
func (bc *BatchConsistentSubClient) Next(keys []string) (map[interface{}][]string, bool) {
if len(keys) == 0 {
return nil, false
}
result := make(map[interface{}][]string)
for _, key := range keys {
dest, ok := bc.ConsistentSubClient.Next(key)
if !ok {
return nil, false
}
result[dest] = append(result[dest], key)
}
return result, true
}
func BalanceWithId() BalanceOption {
return func(opts *balanceOptions) {
opts.balanceType = idBasedBalance
}
}

151
core/discov/subscriber.go Normal file
View File

@ -0,0 +1,151 @@
package discov
import (
"sync"
"zero/core/discov/internal"
)
type (
subOptions struct {
exclusive bool
}
SubOption func(opts *subOptions)
Subscriber struct {
items *container
}
)
func NewSubscriber(endpoints []string, key string, opts ...SubOption) *Subscriber {
var subOpts subOptions
for _, opt := range opts {
opt(&subOpts)
}
subscriber := &Subscriber{
items: newContainer(subOpts.exclusive),
}
internal.GetRegistry().Monitor(endpoints, key, subscriber.items)
return subscriber
}
func (s *Subscriber) Values() []string {
return s.items.getValues()
}
// exclusive means that key value can only be 1:1,
// which means later added value will remove the keys associated with the same value previously.
func Exclusive() SubOption {
return func(opts *subOptions) {
opts.exclusive = true
}
}
type container struct {
exclusive bool
values map[string][]string
mapping map[string]string
lock sync.Mutex
}
func newContainer(exclusive bool) *container {
return &container{
exclusive: exclusive,
values: make(map[string][]string),
mapping: make(map[string]string),
}
}
func (c *container) OnAdd(kv internal.KV) {
c.addKv(kv.Key, kv.Val)
}
func (c *container) OnDelete(kv internal.KV) {
c.removeKey(kv.Key)
}
// addKv adds the kv, returns if there are already other keys associate with the value
func (c *container) addKv(key, value string) ([]string, bool) {
c.lock.Lock()
defer c.lock.Unlock()
keys := c.values[value]
previous := append([]string(nil), keys...)
early := len(keys) > 0
if c.exclusive && early {
for _, each := range keys {
c.doRemoveKey(each)
}
}
c.values[value] = append(c.values[value], key)
c.mapping[key] = value
if early {
return previous, true
} else {
return nil, false
}
}
func (c *container) doRemoveKey(key string) {
server, ok := c.mapping[key]
if !ok {
return
}
delete(c.mapping, key)
keys := c.values[server]
remain := keys[:0]
for _, k := range keys {
if k != key {
remain = append(remain, k)
}
}
if len(remain) > 0 {
c.values[server] = remain
} else {
delete(c.values, server)
}
}
func (c *container) getValues() []string {
c.lock.Lock()
defer c.lock.Unlock()
var vs []string
for each := range c.values {
vs = append(vs, each)
}
return vs
}
// removeKey removes the kv, returns true if there are still other keys associate with the value
func (c *container) removeKey(key string) {
c.lock.Lock()
defer c.lock.Unlock()
c.doRemoveKey(key)
}
func (c *container) removeVal(val string) (empty bool) {
c.lock.Lock()
defer c.lock.Unlock()
for k := range c.values {
if k == val {
delete(c.values, k)
}
}
for k, v := range c.mapping {
if v == val {
delete(c.mapping, k)
}
}
return len(c.values) == 0
}

View File

@ -0,0 +1,21 @@
package errorx
import "sync"
type AtomicError struct {
err error
lock sync.Mutex
}
func (ae *AtomicError) Set(err error) {
ae.lock.Lock()
ae.err = err
ae.lock.Unlock()
}
func (ae *AtomicError) Load() error {
ae.lock.Lock()
err := ae.err
ae.lock.Unlock()
return err
}

View File

@ -0,0 +1,21 @@
package errorx
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
var errDummy = errors.New("hello")
func TestAtomicError(t *testing.T) {
var err AtomicError
err.Set(errDummy)
assert.Equal(t, errDummy, err.Load())
}
func TestAtomicErrorNil(t *testing.T) {
var err AtomicError
assert.Nil(t, err.Load())
}

45
core/errorx/batcherror.go Normal file
View File

@ -0,0 +1,45 @@
package errorx
import "bytes"
type (
BatchError struct {
errs errorArray
}
errorArray []error
)
func (be *BatchError) Add(err error) {
if err != nil {
be.errs = append(be.errs, err)
}
}
func (be *BatchError) Err() error {
switch len(be.errs) {
case 0:
return nil
case 1:
return be.errs[0]
default:
return be.errs
}
}
func (be *BatchError) NotNil() bool {
return len(be.errs) > 0
}
func (ea errorArray) Error() string {
var buf bytes.Buffer
for i := range ea {
if i > 0 {
buf.WriteByte('\n')
}
buf.WriteString(ea[i].Error())
}
return buf.String()
}

View File

@ -0,0 +1,48 @@
package errorx
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
const (
err1 = "first error"
err2 = "second error"
)
func TestBatchErrorNil(t *testing.T) {
var batch BatchError
assert.Nil(t, batch.Err())
assert.False(t, batch.NotNil())
batch.Add(nil)
assert.Nil(t, batch.Err())
assert.False(t, batch.NotNil())
}
func TestBatchErrorNilFromFunc(t *testing.T) {
err := func() error {
var be BatchError
return be.Err()
}()
assert.True(t, err == nil)
}
func TestBatchErrorOneError(t *testing.T) {
var batch BatchError
batch.Add(errors.New(err1))
assert.NotNil(t, batch)
assert.Equal(t, err1, batch.Err().Error())
assert.True(t, batch.NotNil())
}
func TestBatchErrorWithErrors(t *testing.T) {
var batch BatchError
batch.Add(errors.New(err1))
batch.Add(errors.New(err2))
assert.NotNil(t, batch)
assert.Equal(t, fmt.Sprintf("%s\n%s", err1, err2), batch.Err().Error())
assert.True(t, batch.NotNil())
}

View File

@ -0,0 +1,93 @@
package executors
import (
"time"
)
const defaultBulkTasks = 1000
type (
BulkOption func(options *bulkOptions)
BulkExecutor struct {
executor *PeriodicalExecutor
container *bulkContainer
}
bulkOptions struct {
cachedTasks int
flushInterval time.Duration
}
)
func NewBulkExecutor(execute Execute, opts ...BulkOption) *BulkExecutor {
options := newBulkOptions()
for _, opt := range opts {
opt(&options)
}
container := &bulkContainer{
execute: execute,
maxTasks: options.cachedTasks,
}
executor := &BulkExecutor{
executor: NewPeriodicalExecutor(options.flushInterval, container),
container: container,
}
return executor
}
func (be *BulkExecutor) Add(task interface{}) error {
be.executor.Add(task)
return nil
}
func (be *BulkExecutor) Flush() {
be.executor.Flush()
}
func (be *BulkExecutor) Wait() {
be.executor.Wait()
}
func WithBulkTasks(tasks int) BulkOption {
return func(options *bulkOptions) {
options.cachedTasks = tasks
}
}
func WithBulkInterval(duration time.Duration) BulkOption {
return func(options *bulkOptions) {
options.flushInterval = duration
}
}
func newBulkOptions() bulkOptions {
return bulkOptions{
cachedTasks: defaultBulkTasks,
flushInterval: defaultFlushInterval,
}
}
type bulkContainer struct {
tasks []interface{}
execute Execute
maxTasks int
}
func (bc *bulkContainer) AddTask(task interface{}) bool {
bc.tasks = append(bc.tasks, task)
return len(bc.tasks) >= bc.maxTasks
}
func (bc *bulkContainer) Execute(tasks interface{}) {
vals := tasks.([]interface{})
bc.execute(vals)
}
func (bc *bulkContainer) RemoveAll() interface{} {
tasks := bc.tasks
bc.tasks = nil
return tasks
}

View File

@ -0,0 +1,113 @@
package executors
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestBulkExecutor(t *testing.T) {
var values []int
var lock sync.Mutex
exeutor := NewBulkExecutor(func(items []interface{}) {
lock.Lock()
values = append(values, len(items))
lock.Unlock()
}, WithBulkTasks(10), WithBulkInterval(time.Minute))
for i := 0; i < 50; i++ {
exeutor.Add(1)
time.Sleep(time.Millisecond)
}
lock.Lock()
assert.True(t, len(values) > 0)
// ignore last value
for i := 0; i < len(values); i++ {
assert.Equal(t, 10, values[i])
}
lock.Unlock()
}
func TestBulkExecutorFlushInterval(t *testing.T) {
const (
caches = 10
size = 5
)
var wait sync.WaitGroup
wait.Add(1)
exeutor := NewBulkExecutor(func(items []interface{}) {
assert.Equal(t, size, len(items))
wait.Done()
}, WithBulkTasks(caches), WithBulkInterval(time.Millisecond*100))
for i := 0; i < size; i++ {
exeutor.Add(1)
}
wait.Wait()
}
func TestBulkExecutorEmpty(t *testing.T) {
NewBulkExecutor(func(items []interface{}) {
assert.Fail(t, "should not called")
}, WithBulkTasks(10), WithBulkInterval(time.Millisecond))
time.Sleep(time.Millisecond * 100)
}
func TestBulkExecutorFlush(t *testing.T) {
const (
caches = 10
tasks = 5
)
var wait sync.WaitGroup
wait.Add(1)
be := NewBulkExecutor(func(items []interface{}) {
assert.Equal(t, tasks, len(items))
wait.Done()
}, WithBulkTasks(caches), WithBulkInterval(time.Minute))
for i := 0; i < tasks; i++ {
be.Add(1)
}
be.Flush()
wait.Wait()
}
func TestBuldExecutorFlushSlowTasks(t *testing.T) {
const total = 1500
lock := new(sync.Mutex)
result := make([]interface{}, 0, 10000)
exec := NewBulkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * 100)
lock.Lock()
defer lock.Unlock()
for _, i := range tasks {
result = append(result, i)
}
}, WithBulkTasks(1000))
for i := 0; i < total; i++ {
assert.Nil(t, exec.Add(i))
}
exec.Flush()
exec.Wait()
assert.Equal(t, total, len(result))
}
func BenchmarkBulkExecutor(b *testing.B) {
b.ReportAllocs()
be := NewBulkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * time.Duration(len(tasks)))
})
for i := 0; i < b.N; i++ {
time.Sleep(time.Microsecond * 200)
be.Add(1)
}
be.Flush()
}

View File

@ -0,0 +1,103 @@
package executors
import "time"
const defaultChunkSize = 1024 * 1024 // 1M
type (
ChunkOption func(options *chunkOptions)
ChunkExecutor struct {
executor *PeriodicalExecutor
container *chunkContainer
}
chunkOptions struct {
chunkSize int
flushInterval time.Duration
}
)
func NewChunkExecutor(execute Execute, opts ...ChunkOption) *ChunkExecutor {
options := newChunkOptions()
for _, opt := range opts {
opt(&options)
}
container := &chunkContainer{
execute: execute,
maxChunkSize: options.chunkSize,
}
executor := &ChunkExecutor{
executor: NewPeriodicalExecutor(options.flushInterval, container),
container: container,
}
return executor
}
func (ce *ChunkExecutor) Add(task interface{}, size int) error {
ce.executor.Add(chunk{
val: task,
size: size,
})
return nil
}
func (ce *ChunkExecutor) Flush() {
ce.executor.Flush()
}
func (ce *ChunkExecutor) Wait() {
ce.executor.Wait()
}
func WithChunkBytes(size int) ChunkOption {
return func(options *chunkOptions) {
options.chunkSize = size
}
}
func WithFlushInterval(duration time.Duration) ChunkOption {
return func(options *chunkOptions) {
options.flushInterval = duration
}
}
func newChunkOptions() chunkOptions {
return chunkOptions{
chunkSize: defaultChunkSize,
flushInterval: defaultFlushInterval,
}
}
type chunkContainer struct {
tasks []interface{}
execute Execute
size int
maxChunkSize int
}
func (bc *chunkContainer) AddTask(task interface{}) bool {
ck := task.(chunk)
bc.tasks = append(bc.tasks, ck.val)
bc.size += ck.size
return bc.size >= bc.maxChunkSize
}
func (bc *chunkContainer) Execute(tasks interface{}) {
vals := tasks.([]interface{})
bc.execute(vals)
}
func (bc *chunkContainer) RemoveAll() interface{} {
tasks := bc.tasks
bc.tasks = nil
bc.size = 0
return tasks
}
type chunk struct {
val interface{}
size int
}

View File

@ -0,0 +1,92 @@
package executors
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestChunkExecutor(t *testing.T) {
var values []int
var lock sync.Mutex
exeutor := NewChunkExecutor(func(items []interface{}) {
lock.Lock()
values = append(values, len(items))
lock.Unlock()
}, WithChunkBytes(10), WithFlushInterval(time.Minute))
for i := 0; i < 50; i++ {
exeutor.Add(1, 1)
time.Sleep(time.Millisecond)
}
lock.Lock()
assert.True(t, len(values) > 0)
// ignore last value
for i := 0; i < len(values); i++ {
assert.Equal(t, 10, values[i])
}
lock.Unlock()
}
func TestChunkExecutorFlushInterval(t *testing.T) {
const (
caches = 10
size = 5
)
var wait sync.WaitGroup
wait.Add(1)
exeutor := NewChunkExecutor(func(items []interface{}) {
assert.Equal(t, size, len(items))
wait.Done()
}, WithChunkBytes(caches), WithFlushInterval(time.Millisecond*100))
for i := 0; i < size; i++ {
exeutor.Add(1, 1)
}
wait.Wait()
}
func TestChunkExecutorEmpty(t *testing.T) {
NewChunkExecutor(func(items []interface{}) {
assert.Fail(t, "should not called")
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
time.Sleep(time.Millisecond * 100)
}
func TestChunkExecutorFlush(t *testing.T) {
const (
caches = 10
tasks = 5
)
var wait sync.WaitGroup
wait.Add(1)
be := NewChunkExecutor(func(items []interface{}) {
assert.Equal(t, tasks, len(items))
wait.Done()
}, WithChunkBytes(caches), WithFlushInterval(time.Minute))
for i := 0; i < tasks; i++ {
be.Add(1, 1)
}
be.Flush()
wait.Wait()
}
func BenchmarkChunkExecutor(b *testing.B) {
b.ReportAllocs()
be := NewChunkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * time.Duration(len(tasks)))
})
for i := 0; i < b.N; i++ {
time.Sleep(time.Microsecond * 200)
be.Add(1, 1)
}
be.Flush()
}

View File

@ -0,0 +1,44 @@
package executors
import (
"sync"
"time"
"zero/core/threading"
)
type DelayExecutor struct {
fn func()
delay time.Duration
triggered bool
lock sync.Mutex
}
func NewDelayExecutor(fn func(), delay time.Duration) *DelayExecutor {
return &DelayExecutor{
fn: fn,
delay: delay,
}
}
func (de *DelayExecutor) Trigger() {
de.lock.Lock()
defer de.lock.Unlock()
if de.triggered {
return
}
de.triggered = true
threading.GoSafe(func() {
timer := time.NewTimer(de.delay)
defer timer.Stop()
<-timer.C
// set triggered to false before calling fn to ensure no triggers are missed.
de.lock.Lock()
de.triggered = false
de.lock.Unlock()
de.fn()
})
}

View File

@ -0,0 +1,21 @@
package executors
import (
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestDelayExecutor(t *testing.T) {
var count int32
ex := NewDelayExecutor(func() {
atomic.AddInt32(&count, 1)
}, time.Millisecond*10)
for i := 0; i < 100; i++ {
ex.Trigger()
}
time.Sleep(time.Millisecond * 100)
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
}

View File

@ -0,0 +1,32 @@
package executors
import (
"time"
"zero/core/syncx"
"zero/core/timex"
)
type LessExecutor struct {
threshold time.Duration
lastTime *syncx.AtomicDuration
}
func NewLessExecutor(threshold time.Duration) *LessExecutor {
return &LessExecutor{
threshold: threshold,
lastTime: syncx.NewAtomicDuration(),
}
}
func (le *LessExecutor) DoOrDiscard(execute func()) bool {
now := timex.Now()
lastTime := le.lastTime.Load()
if lastTime == 0 || lastTime+le.threshold < now {
le.lastTime.Set(now)
execute()
return true
}
return false
}

View File

@ -0,0 +1,27 @@
package executors
import (
"testing"
"time"
"zero/core/timex"
"github.com/stretchr/testify/assert"
)
func TestLessExecutor_DoOrDiscard(t *testing.T) {
executor := NewLessExecutor(time.Minute)
assert.True(t, executor.DoOrDiscard(func() {}))
assert.False(t, executor.DoOrDiscard(func() {}))
executor.lastTime.Set(timex.Now() - time.Minute - time.Second*30)
assert.True(t, executor.DoOrDiscard(func() {}))
assert.False(t, executor.DoOrDiscard(func() {}))
}
func BenchmarkLessExecutor(b *testing.B) {
exec := NewLessExecutor(time.Millisecond)
for i := 0; i < b.N; i++ {
exec.DoOrDiscard(func() {
})
}
}

View File

@ -0,0 +1,158 @@
package executors
import (
"reflect"
"sync"
"time"
"zero/core/proc"
"zero/core/threading"
"zero/core/timex"
)
const idleRound = 10
type (
// A type that satisfies executors.TaskContainer can be used as the underlying
// container that used to do periodical executions.
TaskContainer interface {
// AddTask adds the task into the container.
// Returns true if the container needs to be flushed after the addition.
AddTask(task interface{}) bool
// Execute handles the collected tasks by the container when flushing.
Execute(tasks interface{})
// RemoveAll removes the contained tasks, and return them.
RemoveAll() interface{}
}
PeriodicalExecutor struct {
commander chan interface{}
interval time.Duration
container TaskContainer
waitGroup sync.WaitGroup
guarded bool
newTicker func(duration time.Duration) timex.Ticker
lock sync.Mutex
}
)
func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *PeriodicalExecutor {
executor := &PeriodicalExecutor{
// buffer 1 to let the caller go quickly
commander: make(chan interface{}, 1),
interval: interval,
container: container,
newTicker: func(d time.Duration) timex.Ticker {
return timex.NewTicker(interval)
},
}
proc.AddShutdownListener(func() {
executor.Flush()
})
return executor
}
func (pe *PeriodicalExecutor) Add(task interface{}) {
if vals, ok := pe.addAndCheck(task); ok {
pe.commander <- vals
}
}
func (pe *PeriodicalExecutor) Flush() bool {
return pe.executeTasks(func() interface{} {
pe.lock.Lock()
defer pe.lock.Unlock()
return pe.container.RemoveAll()
}())
}
func (pe *PeriodicalExecutor) Sync(fn func()) {
pe.lock.Lock()
defer pe.lock.Unlock()
fn()
}
func (pe *PeriodicalExecutor) Wait() {
pe.waitGroup.Wait()
}
func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
pe.lock.Lock()
defer func() {
var start bool
if !pe.guarded {
pe.guarded = true
start = true
}
pe.lock.Unlock()
if start {
pe.backgroundFlush()
}
}()
if pe.container.AddTask(task) {
return pe.container.RemoveAll(), true
}
return nil, false
}
func (pe *PeriodicalExecutor) backgroundFlush() {
threading.GoSafe(func() {
ticker := pe.newTicker(pe.interval)
defer ticker.Stop()
var commanded bool
last := timex.Now()
for {
select {
case vals := <-pe.commander:
commanded = true
pe.executeTasks(vals)
last = timex.Now()
case <-ticker.Chan():
if commanded {
commanded = false
} else if pe.Flush() {
last = timex.Now()
} else if timex.Since(last) > pe.interval*idleRound {
pe.lock.Lock()
pe.guarded = false
pe.lock.Unlock()
// flush again to avoid missing tasks
pe.Flush()
return
}
}
}
})
}
func (pe *PeriodicalExecutor) executeTasks(tasks interface{}) bool {
pe.waitGroup.Add(1)
defer pe.waitGroup.Done()
ok := pe.hasTasks(tasks)
if ok {
pe.container.Execute(tasks)
}
return ok
}
func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool {
if tasks == nil {
return false
}
val := reflect.ValueOf(tasks)
switch val.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
return val.Len() > 0
default:
// unknown type, let caller execute it
return true
}
}

View File

@ -0,0 +1,118 @@
package executors
import (
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
"zero/core/timex"
"github.com/stretchr/testify/assert"
)
const threshold = 10
type container struct {
interval time.Duration
tasks []int
execute func(tasks interface{})
}
func newContainer(interval time.Duration, execute func(tasks interface{})) *container {
return &container{
interval: interval,
execute: execute,
}
}
func (c *container) AddTask(task interface{}) bool {
c.tasks = append(c.tasks, task.(int))
return len(c.tasks) > threshold
}
func (c *container) Execute(tasks interface{}) {
if c.execute != nil {
c.execute(tasks)
} else {
time.Sleep(c.interval)
}
}
func (c *container) RemoveAll() interface{} {
tasks := c.tasks
c.tasks = nil
return tasks
}
func TestPeriodicalExecutor_Sync(t *testing.T) {
var done int32
exec := NewPeriodicalExecutor(time.Second, newContainer(time.Millisecond*500, nil))
exec.Sync(func() {
atomic.AddInt32(&done, 1)
})
assert.Equal(t, int32(1), atomic.LoadInt32(&done))
}
func TestPeriodicalExecutor_QuitGoroutine(t *testing.T) {
ticker := timex.NewFakeTicker()
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
exec.newTicker = func(d time.Duration) timex.Ticker {
return ticker
}
routines := runtime.NumGoroutine()
exec.Add(1)
ticker.Tick()
ticker.Wait(time.Millisecond * idleRound * 2)
ticker.Tick()
ticker.Wait(time.Millisecond * idleRound)
assert.Equal(t, routines, runtime.NumGoroutine())
}
func TestPeriodicalExecutor_Bulk(t *testing.T) {
ticker := timex.NewFakeTicker()
var vals []int
// avoid data race
var lock sync.Mutex
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks interface{}) {
t := tasks.([]int)
for _, each := range t {
lock.Lock()
vals = append(vals, each)
lock.Unlock()
}
}))
exec.newTicker = func(d time.Duration) timex.Ticker {
return ticker
}
for i := 0; i < threshold*10; i++ {
if i%threshold == 5 {
time.Sleep(time.Millisecond * idleRound * 2)
}
exec.Add(i)
}
ticker.Tick()
ticker.Wait(time.Millisecond * idleRound * 2)
ticker.Tick()
ticker.Tick()
ticker.Wait(time.Millisecond * idleRound)
var expect []int
for i := 0; i < threshold*10; i++ {
expect = append(expect, i)
}
lock.Lock()
assert.EqualValues(t, expect, vals)
lock.Unlock()
}
// go test -benchtime 10s -bench .
func BenchmarkExecutor(b *testing.B) {
b.ReportAllocs()
executor := NewPeriodicalExecutor(time.Second, newContainer(time.Millisecond*500, nil))
for i := 0; i < b.N; i++ {
executor.Add(1)
}
}

7
core/executors/vars.go Normal file
View File

@ -0,0 +1,7 @@
package executors
import "time"
const defaultFlushInterval = time.Second
type Execute func(tasks []interface{})

84
core/filex/file.go Normal file
View File

@ -0,0 +1,84 @@
package filex
import (
"io"
"os"
)
const bufSize = 1024
func FirstLine(filename string) (string, error) {
file, err := os.Open(filename)
if err != nil {
return "", err
}
defer file.Close()
return firstLine(file)
}
func LastLine(filename string) (string, error) {
file, err := os.Open(filename)
if err != nil {
return "", err
}
defer file.Close()
return lastLine(filename, file)
}
func firstLine(file *os.File) (string, error) {
var first []byte
var offset int64
for {
buf := make([]byte, bufSize)
n, err := file.ReadAt(buf, offset)
if err != nil && err != io.EOF {
return "", err
}
for i := 0; i < n; i++ {
if buf[i] == '\n' {
return string(append(first, buf[:i]...)), nil
}
}
first = append(first, buf[:n]...)
offset += bufSize
}
}
func lastLine(filename string, file *os.File) (string, error) {
info, err := os.Stat(filename)
if err != nil {
return "", err
}
var last []byte
offset := info.Size()
for {
offset -= bufSize
if offset < 0 {
offset = 0
}
buf := make([]byte, bufSize)
n, err := file.ReadAt(buf, offset)
if err != nil && err != io.EOF {
return "", err
}
if buf[n-1] == '\n' {
buf = buf[:n-1]
n -= 1
} else {
buf = buf[:n]
}
for n -= 1; n >= 0; n-- {
if buf[n] == '\n' {
return string(append(buf[n+1:], last...)), nil
}
}
last = append(buf, last...)
}
}

116
core/filex/file_test.go Normal file
View File

@ -0,0 +1,116 @@
package filex
import (
"os"
"testing"
"zero/core/fs"
"github.com/stretchr/testify/assert"
)
const (
longLine = `Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.`
longFirstLine = longLine + "\n" + text
text = `first line
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
` + longLine
textWithLastNewline = `first line
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus.
Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum.
Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna.
Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil!
Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur.
` + longLine + "\n"
shortText = `first line
second line
last line`
shortTextWithLastNewline = `first line
second line
last line
`
)
func TestFirstLine(t *testing.T) {
filename, err := fs.TempFilenameWithText(longFirstLine)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, longLine, val)
}
func TestFirstLineShort(t *testing.T) {
filename, err := fs.TempFilenameWithText(shortText)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, "first line", val)
}
func TestLastLine(t *testing.T) {
filename, err := fs.TempFilenameWithText(text)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, longLine, val)
}
func TestLastLineWithLastNewline(t *testing.T) {
filename, err := fs.TempFilenameWithText(textWithLastNewline)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, longLine, val)
}
func TestLastLineShort(t *testing.T) {
filename, err := fs.TempFilenameWithText(shortText)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, "last line", val)
}
func TestLastLineWithLastNewlineShort(t *testing.T) {
filename, err := fs.TempFilenameWithText(shortTextWithLastNewline)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, "last line", val)
}

105
core/filex/lookup.go Normal file
View File

@ -0,0 +1,105 @@
package filex
import (
"io"
"os"
)
type OffsetRange struct {
File string
Start int64
Stop int64
}
func SplitLineChunks(filename string, chunks int) ([]OffsetRange, error) {
info, err := os.Stat(filename)
if err != nil {
return nil, err
}
if chunks <= 1 {
return []OffsetRange{
{
File: filename,
Start: 0,
Stop: info.Size(),
},
}, nil
}
file, err := os.Open(filename)
if err != nil {
return nil, err
}
defer file.Close()
var ranges []OffsetRange
var offset int64
// avoid the last chunk too few bytes
preferSize := info.Size()/int64(chunks) + 1
for {
if offset+preferSize >= info.Size() {
ranges = append(ranges, OffsetRange{
File: filename,
Start: offset,
Stop: info.Size(),
})
break
}
offsetRange, err := nextRange(file, offset, offset+preferSize)
if err != nil {
return nil, err
}
ranges = append(ranges, offsetRange)
if offsetRange.Stop < info.Size() {
offset = offsetRange.Stop
} else {
break
}
}
return ranges, nil
}
func nextRange(file *os.File, start, stop int64) (OffsetRange, error) {
offset, err := skipPartialLine(file, stop)
if err != nil {
return OffsetRange{}, err
}
return OffsetRange{
File: file.Name(),
Start: start,
Stop: offset,
}, nil
}
func skipPartialLine(file *os.File, offset int64) (int64, error) {
for {
skipBuf := make([]byte, bufSize)
n, err := file.ReadAt(skipBuf, offset)
if err != nil && err != io.EOF {
return 0, err
}
if n == 0 {
return 0, io.EOF
}
for i := 0; i < n; i++ {
if skipBuf[i] != '\r' && skipBuf[i] != '\n' {
offset++
} else {
for ; i < n; i++ {
if skipBuf[i] == '\r' || skipBuf[i] == '\n' {
offset++
} else {
return offset, nil
}
}
return offset, nil
}
}
}
}

68
core/filex/lookup_test.go Normal file
View File

@ -0,0 +1,68 @@
package filex
import (
"os"
"testing"
"zero/core/fs"
"github.com/stretchr/testify/assert"
)
func TestSplitLineChunks(t *testing.T) {
const text = `first line
second line
third line
fourth line
fifth line
sixth line
seventh line
`
fp, err := fs.TempFileWithText(text)
assert.Nil(t, err)
defer func() {
fp.Close()
os.Remove(fp.Name())
}()
offsets, err := SplitLineChunks(fp.Name(), 3)
assert.Nil(t, err)
body := make([]byte, 512)
for _, offset := range offsets {
reader := NewRangeReader(fp, offset.Start, offset.Stop)
n, err := reader.Read(body)
assert.Nil(t, err)
assert.Equal(t, uint8('\n'), body[n-1])
}
}
func TestSplitLineChunksNoFile(t *testing.T) {
_, err := SplitLineChunks("nosuchfile", 2)
assert.NotNil(t, err)
}
func TestSplitLineChunksFull(t *testing.T) {
const text = `first line
second line
third line
fourth line
fifth line
sixth line
`
fp, err := fs.TempFileWithText(text)
assert.Nil(t, err)
defer func() {
fp.Close()
os.Remove(fp.Name())
}()
offsets, err := SplitLineChunks(fp.Name(), 1)
assert.Nil(t, err)
body := make([]byte, 512)
for _, offset := range offsets {
reader := NewRangeReader(fp, offset.Start, offset.Stop)
n, err := reader.Read(body)
assert.Nil(t, err)
assert.Equal(t, []byte(text), body[:n])
}
}

View File

@ -0,0 +1,28 @@
package filex
import "gopkg.in/cheggaaa/pb.v1"
type (
Scanner interface {
Scan() bool
Text() string
}
progressScanner struct {
Scanner
bar *pb.ProgressBar
}
)
func NewProgressScanner(scanner Scanner, bar *pb.ProgressBar) Scanner {
return &progressScanner{
Scanner: scanner,
bar: bar,
}
}
func (ps *progressScanner) Text() string {
s := ps.Scanner.Text()
ps.bar.Add64(int64(len(s)) + 1) // take newlines into account
return s
}

View File

@ -0,0 +1,31 @@
package filex
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/cheggaaa/pb.v1"
)
func TestProgressScanner(t *testing.T) {
const text = "hello, world"
bar := pb.New(100)
var builder strings.Builder
builder.WriteString(text)
scanner := NewProgressScanner(&mockedScanner{builder: &builder}, bar)
assert.True(t, scanner.Scan())
assert.Equal(t, text, scanner.Text())
}
type mockedScanner struct {
builder *strings.Builder
}
func (s *mockedScanner) Scan() bool {
return s.builder.Len() > 0
}
func (s *mockedScanner) Text() string {
return s.builder.String()
}

43
core/filex/rangereader.go Normal file
View File

@ -0,0 +1,43 @@
package filex
import (
"errors"
"os"
)
type RangeReader struct {
file *os.File
start int64
stop int64
}
func NewRangeReader(file *os.File, start, stop int64) *RangeReader {
return &RangeReader{
file: file,
start: start,
stop: stop,
}
}
func (rr *RangeReader) Read(p []byte) (n int, err error) {
stat, err := rr.file.Stat()
if err != nil {
return 0, err
}
if rr.stop < rr.start || rr.start >= stat.Size() {
return 0, errors.New("exceed file size")
}
if rr.stop-rr.start < int64(len(p)) {
p = p[:rr.stop-rr.start]
}
n, err = rr.file.ReadAt(p, rr.start)
if err != nil {
return n, err
}
rr.start += int64(n)
return
}

View File

@ -0,0 +1,45 @@
package filex
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"zero/core/fs"
)
func TestRangeReader(t *testing.T) {
const text = `hello
world`
file, err := fs.TempFileWithText(text)
assert.Nil(t, err)
defer func() {
file.Close()
os.Remove(file.Name())
}()
reader := NewRangeReader(file, 5, 8)
buf := make([]byte, 10)
n, err := reader.Read(buf)
assert.Nil(t, err)
assert.Equal(t, 3, n)
assert.Equal(t, `
wo`, string(buf[:n]))
}
func TestRangeReader_OutOfRange(t *testing.T) {
const text = `hello
world`
file, err := fs.TempFileWithText(text)
assert.Nil(t, err)
defer func() {
file.Close()
os.Remove(file.Name())
}()
reader := NewRangeReader(file, 50, 8)
buf := make([]byte, 10)
_, err = reader.Read(buf)
assert.NotNil(t, err)
}

View File

@ -0,0 +1,8 @@
// +build windows
package fs
import "os"
func CloseOnExec(*os.File) {
}

Some files were not shown because too many files have changed in this diff Show More