commit 7e3a369a8fb7a3cfadd1301718214f038985fa57 Author: kevin Date: Sun Jul 26 17:09:05 2020 +0800 initial import diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..f3b64113 --- /dev/null +++ b/.dockerignore @@ -0,0 +1 @@ +**/.git diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..6313b56c --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto eol=lf diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..eba5e9d1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,46 @@ +# Ignore all +* + +# Unignore all with extensions +!*.* +!**/Dockerfile + +# Unignore all dirs +!*/ +!api + +.idea +**/.DS_Store +**/logs +!vendor/github.com/songtianyi/rrframework/logs +**/*.pem +**/*.prof +**/*.p12 +!Makefile + +# gitlab ci +.cache + +# chatbot +**/production.json +**/*.corpus.json +**/*.txt +**/*.gob + +# example +example/**/*.csv + +# hera +service/hera/cli/readdata/intergrationtest/data +service/hera/devkit/ch331/data +service/hera/devkit/ch331/ck +service/hera/cli/replaybeat/etc +# goctl +tools/goctl/api/autogen + +vendor/* +/service/hera/cli/dboperation/etc/hera.json + +# vim auto backup file +*~ +!OWNERS \ No newline at end of file diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 00000000..a0ff831e --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,16 @@ +stages: +- analysis + +variables: + GOPATH: '/runner-cache/zero' + GOCACHE: '/runner-cache/zero' + GOPROXY: 'https://goproxy.cn,direct' + +analysis: + stage: analysis + image: golang + script: + - go version && go env + - go test -short $(go list ./...) | grep -v "no test" + only: + - merge_requests diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..d726d16a --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,43 @@ +run: + # concurrency: 6 + timeout: 5m + skip-dirs: + - core + - diq + - doc + - dq + - example + - kmq + - kq + - ngin + - rq + - rpcx + # - service + - stash + - tools + + +linters: + disable-all: true + enable: + - bodyclose + - deadcode + - errcheck + - gosimple + - govet + - ineffassign + - staticcheck + - structcheck + - typecheck + - unused + - varcheck +# - dupl + + +linters-settings: + +issues: + exclude-rules: + - linters: + - staticcheck + text: 'SA1019: (baseresponse.BoolResponse|oldresponse.FormatBadRequestResponse|oldresponse.FormatResponse)|SA5008: unknown JSON option ("optional"|"default=|"range=|"options=)' diff --git a/core/bloom/bloom.go b/core/bloom/bloom.go new file mode 100644 index 00000000..6e4027b2 --- /dev/null +++ b/core/bloom/bloom.go @@ -0,0 +1,161 @@ +package bloom + +import ( + "errors" + "strconv" + + "zero/core/hash" + "zero/core/stores/redis" +) + +const ( + // for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html + // maps as k in the error rate table + maps = 14 + setScript = ` +local key = KEYS[1] +for _, offset in ipairs(ARGV) do + redis.call("setbit", key, offset, 1) +end +` + testScript = ` +local key = KEYS[1] +for _, offset in ipairs(ARGV) do + if tonumber(redis.call("getbit", key, offset)) == 0 then + return false + end +end +return true +` +) + +var ErrTooLargeOffset = errors.New("too large offset") + +type ( + BitSetProvider interface { + check([]uint) (bool, error) + set([]uint) error + } + + BloomFilter struct { + bits uint + maps uint + bitSet BitSetProvider + } +) + +// New create a BloomFilter, store is the backed redis, key is the key for the bloom filter, +// bits is how many bits will be used, maps is how many hashes for each 
addition. +// best practices: +// elements - means how many actual elements +// when maps = 14, formula: 0.7*(bits/maps), bits = 20*elements, the error rate is 0.000067 < 1e-4 +// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html +func New(store *redis.Redis, key string, bits uint) *BloomFilter { + return &BloomFilter{ + bits: bits, + bitSet: newRedisBitSet(store, key, bits), + } +} + +func (f *BloomFilter) Add(data []byte) error { + locations := f.getLocations(data) + err := f.bitSet.set(locations) + if err != nil { + return err + } + return nil +} + +func (f *BloomFilter) Exists(data []byte) (bool, error) { + locations := f.getLocations(data) + isSet, err := f.bitSet.check(locations) + if err != nil { + return false, err + } + if !isSet { + return false, nil + } + + return true, nil +} + +func (f *BloomFilter) getLocations(data []byte) []uint { + locations := make([]uint, maps) + for i := uint(0); i < maps; i++ { + hashValue := hash.Hash(append(data, byte(i))) + locations[i] = uint(hashValue % uint64(f.bits)) + } + + return locations +} + +type redisBitSet struct { + store *redis.Redis + key string + bits uint +} + +func newRedisBitSet(store *redis.Redis, key string, bits uint) *redisBitSet { + return &redisBitSet{ + store: store, + key: key, + bits: bits, + } +} + +func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) { + var args []string + + for _, offset := range offsets { + if offset >= r.bits { + return nil, ErrTooLargeOffset + } + + args = append(args, strconv.FormatUint(uint64(offset), 10)) + } + + return args, nil +} + +func (r *redisBitSet) check(offsets []uint) (bool, error) { + args, err := r.buildOffsetArgs(offsets) + if err != nil { + return false, err + } + + resp, err := r.store.Eval(testScript, []string{r.key}, args) + if err == redis.Nil { + return false, nil + } else if err != nil { + return false, err + } + + if exists, ok := resp.(int64); !ok { + return false, nil + } else { + return exists == 1, nil + } +} + +func (r *redisBitSet) del() error { + _, err := r.store.Del(r.key) + return err +} + +func (r *redisBitSet) expire(seconds int) error { + return r.store.Expire(r.key, seconds) +} + +func (r *redisBitSet) set(offsets []uint) error { + args, err := r.buildOffsetArgs(offsets) + if err != nil { + return err + } + + _, err = r.store.Eval(setScript, []string{r.key}, args) + if err == redis.Nil { + return nil + } else { + return err + } +} diff --git a/core/bloom/bloom_test.go b/core/bloom/bloom_test.go new file mode 100644 index 00000000..07d5a173 --- /dev/null +++ b/core/bloom/bloom_test.go @@ -0,0 +1,63 @@ +package bloom + +import ( + "testing" + + "zero/core/stores/redis" + + "github.com/alicebob/miniredis" + "github.com/stretchr/testify/assert" +) + +func TestRedisBitSet_New_Set_Test(t *testing.T) { + s, err := miniredis.Run() + if err != nil { + t.Error("Miniredis could not start") + } + defer s.Close() + + store := redis.NewRedis(s.Addr(), redis.NodeType) + bitSet := newRedisBitSet(store, "test_key", 1024) + isSetBefore, err := bitSet.check([]uint{0}) + if err != nil { + t.Fatal(err) + } + if isSetBefore { + t.Fatal("Bit should not be set") + } + err = bitSet.set([]uint{512}) + if err != nil { + t.Fatal(err) + } + isSetAfter, err := bitSet.check([]uint{512}) + if err != nil { + t.Fatal(err) + } + if !isSetAfter { + t.Fatal("Bit should be set") + } + err = bitSet.expire(3600) + if err != nil { + t.Fatal(err) + } + err = bitSet.del() + if err != nil { + t.Fatal(err) + } +} + +func 
TestRedisBitSet_Add(t *testing.T) { + s, err := miniredis.Run() + if err != nil { + t.Error("Miniredis could not start") + } + defer s.Close() + + store := redis.NewRedis(s.Addr(), redis.NodeType) + filter := New(store, "test_key", 64) + assert.Nil(t, filter.Add([]byte("hello"))) + assert.Nil(t, filter.Add([]byte("world"))) + ok, err := filter.Exists([]byte("hello")) + assert.Nil(t, err) + assert.True(t, ok) +} diff --git a/core/breaker/breaker.go b/core/breaker/breaker.go new file mode 100644 index 00000000..5f9c8636 --- /dev/null +++ b/core/breaker/breaker.go @@ -0,0 +1,229 @@ +package breaker + +import ( + "errors" + "fmt" + "strings" + "sync" + "time" + + "zero/core/mathx" + "zero/core/proc" + "zero/core/stat" + "zero/core/stringx" +) + +const ( + StateClosed State = iota + StateOpen +) + +const ( + numHistoryReasons = 5 + timeFormat = "15:04:05" +) + +// ErrServiceUnavailable is returned when the CB state is open +var ErrServiceUnavailable = errors.New("circuit breaker is open") + +type ( + State = int32 + Acceptable func(err error) bool + + Breaker interface { + // Name returns the name of the netflixBreaker. + Name() string + + // Allow checks if the request is allowed. + // If allowed, a promise will be returned, the caller needs to call promise.Accept() + // on success, or call promise.Reject() on failure. + // If not allow, ErrServiceUnavailable will be returned. + Allow() (Promise, error) + + // Do runs the given request if the netflixBreaker accepts it. + // Do returns an error instantly if the netflixBreaker rejects the request. + // If a panic occurs in the request, the netflixBreaker handles it as an error + // and causes the same panic again. + Do(req func() error) error + + // DoWithAcceptable runs the given request if the netflixBreaker accepts it. + // Do returns an error instantly if the netflixBreaker rejects the request. + // If a panic occurs in the request, the netflixBreaker handles it as an error + // and causes the same panic again. + // acceptable checks if it's a successful call, even if the err is not nil. + DoWithAcceptable(req func() error, acceptable Acceptable) error + + // DoWithFallback runs the given request if the netflixBreaker accepts it. + // DoWithFallback runs the fallback if the netflixBreaker rejects the request. + // If a panic occurs in the request, the netflixBreaker handles it as an error + // and causes the same panic again. + DoWithFallback(req func() error, fallback func(err error) error) error + + // DoWithFallbackAcceptable runs the given request if the netflixBreaker accepts it. + // DoWithFallback runs the fallback if the netflixBreaker rejects the request. + // If a panic occurs in the request, the netflixBreaker handles it as an error + // and causes the same panic again. + // acceptable checks if it's a successful call, even if the err is not nil. 
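	// For contrast with the Do* helpers, the manual Allow flow described in the
	// comments above looks roughly like this (brk is any Breaker value and
	// callRemote a hypothetical remote call):
	//
	//	p, err := brk.Allow()
	//	if err != nil {
	//		return err // breaker is open, the request is dropped
	//	}
	//	if err := callRemote(); err != nil {
	//		p.Reject(err.Error())
	//		return err
	//	}
	//	p.Accept()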
+ DoWithFallbackAcceptable(req func() error, fallback func(err error) error, acceptable Acceptable) error + } + + BreakerOption func(breaker *circuitBreaker) + + Promise interface { + Accept() + Reject(reason string) + } + + internalPromise interface { + Accept() + Reject() + } + + circuitBreaker struct { + name string + throttle + } + + internalThrottle interface { + allow() (internalPromise, error) + doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error + } + + throttle interface { + allow() (Promise, error) + doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error + } +) + +func NewBreaker(opts ...BreakerOption) Breaker { + var b circuitBreaker + for _, opt := range opts { + opt(&b) + } + if len(b.name) == 0 { + b.name = stringx.Rand() + } + b.throttle = newLoggedThrottle(b.name, newGoogleBreaker()) + + return &b +} + +func (cb *circuitBreaker) Allow() (Promise, error) { + return cb.throttle.allow() +} + +func (cb *circuitBreaker) Do(req func() error) error { + return cb.throttle.doReq(req, nil, defaultAcceptable) +} + +func (cb *circuitBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error { + return cb.throttle.doReq(req, nil, acceptable) +} + +func (cb *circuitBreaker) DoWithFallback(req func() error, fallback func(err error) error) error { + return cb.throttle.doReq(req, fallback, defaultAcceptable) +} + +func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error, + acceptable Acceptable) error { + return cb.throttle.doReq(req, fallback, acceptable) +} + +func (cb *circuitBreaker) Name() string { + return cb.name +} + +func WithName(name string) BreakerOption { + return func(b *circuitBreaker) { + b.name = name + } +} + +func defaultAcceptable(err error) bool { + return err == nil +} + +type loggedThrottle struct { + name string + internalThrottle + errWin *errorWindow +} + +func newLoggedThrottle(name string, t internalThrottle) loggedThrottle { + return loggedThrottle{ + name: name, + internalThrottle: t, + errWin: new(errorWindow), + } +} + +func (lt loggedThrottle) allow() (Promise, error) { + promise, err := lt.internalThrottle.allow() + return promiseWithReason{ + promise: promise, + errWin: lt.errWin, + }, lt.logError(err) +} + +func (lt loggedThrottle) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error { + return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool { + accept := acceptable(err) + if !accept { + lt.errWin.add(err.Error()) + } + return accept + })) +} + +func (lt loggedThrottle) logError(err error) error { + if err == ErrServiceUnavailable { + // if circuit open, not possible to have empty error window + stat.Report(fmt.Sprintf( + "proc(%s/%d), callee: %s, breaker is open and requests dropped\nlast errors:\n%s", + proc.ProcessName(), proc.Pid(), lt.name, lt.errWin)) + } + + return err +} + +type errorWindow struct { + reasons [numHistoryReasons]string + index int + count int + lock sync.Mutex +} + +func (ew *errorWindow) add(reason string) { + ew.lock.Lock() + ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason) + ew.index = (ew.index + 1) % numHistoryReasons + ew.count = mathx.MinInt(ew.count+1, numHistoryReasons) + ew.lock.Unlock() +} + +func (ew *errorWindow) String() string { + var builder strings.Builder + + ew.lock.Lock() + for i := ew.index + ew.count - 1; i >= ew.index; i-- { + builder.WriteString(ew.reasons[i%numHistoryReasons]) + 
builder.WriteByte('\n') + } + ew.lock.Unlock() + + return builder.String() +} + +type promiseWithReason struct { + promise internalPromise + errWin *errorWindow +} + +func (p promiseWithReason) Accept() { + p.promise.Accept() +} + +func (p promiseWithReason) Reject(reason string) { + p.errWin.add(reason) + p.promise.Reject() +} diff --git a/core/breaker/breaker_test.go b/core/breaker/breaker_test.go new file mode 100644 index 00000000..d75de200 --- /dev/null +++ b/core/breaker/breaker_test.go @@ -0,0 +1,44 @@ +package breaker + +import ( + "errors" + "strconv" + "testing" + + "zero/core/stat" + + "github.com/stretchr/testify/assert" +) + +func init() { + stat.SetReporter(nil) +} + +func TestCircuitBreaker_Allow(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + _, err := b.Allow() + assert.Nil(t, err) +} + +func TestLogReason(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + + for i := 0; i < 1000; i++ { + _ = b.Do(func() error { + return errors.New(strconv.Itoa(i)) + }) + } + errs := b.(*circuitBreaker).throttle.(loggedThrottle).errWin + assert.Equal(t, numHistoryReasons, errs.count) +} + +func BenchmarkGoogleBreaker(b *testing.B) { + br := NewBreaker() + for i := 0; i < b.N; i++ { + _ = br.Do(func() error { + return nil + }) + } +} diff --git a/core/breaker/breakers.go b/core/breaker/breakers.go new file mode 100644 index 00000000..26760142 --- /dev/null +++ b/core/breaker/breakers.go @@ -0,0 +1,76 @@ +package breaker + +import "sync" + +var ( + lock sync.RWMutex + breakers = make(map[string]Breaker) +) + +func Do(name string, req func() error) error { + return do(name, func(b Breaker) error { + return b.Do(req) + }) +} + +func DoWithAcceptable(name string, req func() error, acceptable Acceptable) error { + return do(name, func(b Breaker) error { + return b.DoWithAcceptable(req, acceptable) + }) +} + +func DoWithFallback(name string, req func() error, fallback func(err error) error) error { + return do(name, func(b Breaker) error { + return b.DoWithFallback(req, fallback) + }) +} + +func DoWithFallbackAcceptable(name string, req func() error, fallback func(err error) error, + acceptable Acceptable) error { + return do(name, func(b Breaker) error { + return b.DoWithFallbackAcceptable(req, fallback, acceptable) + }) +} + +func GetBreaker(name string) Breaker { + lock.RLock() + b, ok := breakers[name] + lock.RUnlock() + if ok { + return b + } + + lock.Lock() + defer lock.Unlock() + + b = NewBreaker() + breakers[name] = b + return b +} + +func NoBreakFor(name string) { + lock.Lock() + breakers[name] = newNoOpBreaker() + lock.Unlock() +} + +func do(name string, execute func(b Breaker) error) error { + lock.RLock() + b, ok := breakers[name] + lock.RUnlock() + if ok { + return execute(b) + } else { + lock.Lock() + b, ok = breakers[name] + if ok { + lock.Unlock() + return execute(b) + } else { + b = NewBreaker(WithName(name)) + breakers[name] = b + lock.Unlock() + return execute(b) + } + } +} diff --git a/core/breaker/breakers_test.go b/core/breaker/breakers_test.go new file mode 100644 index 00000000..b148e47a --- /dev/null +++ b/core/breaker/breakers_test.go @@ -0,0 +1,115 @@ +package breaker + +import ( + "errors" + "fmt" + "testing" + + "zero/core/stat" + + "github.com/stretchr/testify/assert" +) + +func init() { + stat.SetReporter(nil) +} + +func TestBreakersDo(t *testing.T) { + assert.Nil(t, Do("any", func() error { + return nil + })) + + errDummy := errors.New("any") + assert.Equal(t, errDummy, Do("any", func() error { + return errDummy + })) 
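	// The package-level Do/DoWithFallback helpers above lazily create and cache
	// one breaker per name, so call sites only need a key; a minimal sketch
	// (the "user-rpc" name, callUserRpc and errServiceBusy are hypothetical):
	//
	//	err := breaker.DoWithFallback("user-rpc", func() error {
	//		return callUserRpc()
	//	}, func(err error) error {
	//		// degrade gracefully while the breaker is open
	//		return errServiceBusy
	//	})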
+} + +func TestBreakersDoWithAcceptable(t *testing.T) { + errDummy := errors.New("anyone") + for i := 0; i < 10000; i++ { + assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error { + return errDummy + }, func(err error) bool { + return err == nil || err == errDummy + })) + } + verify(t, func() bool { + return Do("anyone", func() error { + return nil + }) == nil + }) + + for i := 0; i < 10000; i++ { + err := DoWithAcceptable("another", func() error { + return errDummy + }, func(err error) bool { + return err == nil + }) + assert.True(t, err == errDummy || err == ErrServiceUnavailable) + } + verify(t, func() bool { + return ErrServiceUnavailable == Do("another", func() error { + return nil + }) + }) +} + +func TestBreakersNoBreakerFor(t *testing.T) { + NoBreakFor("any") + errDummy := errors.New("any") + for i := 0; i < 10000; i++ { + assert.Equal(t, errDummy, GetBreaker("any").Do(func() error { + return errDummy + })) + } + assert.Equal(t, nil, Do("any", func() error { + return nil + })) +} + +func TestBreakersFallback(t *testing.T) { + errDummy := errors.New("any") + for i := 0; i < 10000; i++ { + err := DoWithFallback("fallback", func() error { + return errDummy + }, func(err error) error { + return nil + }) + assert.True(t, err == nil || err == errDummy) + } + verify(t, func() bool { + return ErrServiceUnavailable == Do("fallback", func() error { + return nil + }) + }) +} + +func TestBreakersAcceptableFallback(t *testing.T) { + errDummy := errors.New("any") + for i := 0; i < 10000; i++ { + err := DoWithFallbackAcceptable("acceptablefallback", func() error { + return errDummy + }, func(err error) error { + return nil + }, func(err error) bool { + return err == nil + }) + assert.True(t, err == nil || err == errDummy) + } + verify(t, func() bool { + return ErrServiceUnavailable == Do("acceptablefallback", func() error { + return nil + }) + }) +} + +func verify(t *testing.T, fn func() bool) { + var count int + for i := 0; i < 100; i++ { + if fn() { + count++ + } + } + assert.True(t, count >= 80, fmt.Sprintf("should be greater than 80, actual %d", count)) +} diff --git a/core/breaker/googlebreaker.go b/core/breaker/googlebreaker.go new file mode 100644 index 00000000..b013132e --- /dev/null +++ b/core/breaker/googlebreaker.go @@ -0,0 +1,125 @@ +package breaker + +import ( + "math" + "sync/atomic" + "time" + + "zero/core/collection" + "zero/core/mathx" +) + +const ( + // 250ms for bucket duration + window = time.Second * 10 + buckets = 40 + k = 1.5 + protection = 5 +) + +// googleBreaker is a netflixBreaker pattern from google. 
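// It keeps a rolling count of total and accepted requests over the last
// `window` (10s split into 40 buckets of 250ms) and drops new requests with
// probability
//
//	max(0, (total - protection - k*accepts) / (total + 1))
//
// e.g. with k = 1.5, protection = 5, 100 requests and 50 accepts in the window,
// the drop ratio is (100 - 5 - 75) / 101 ≈ 0.2, so about one in five requests
// is rejected until the success ratio recovers.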
+// see Client-Side Throttling section in https://landing.google.com/sre/sre-book/chapters/handling-overload/ +type googleBreaker struct { + k float64 + state int32 + stat *collection.RollingWindow + proba *mathx.Proba +} + +func newGoogleBreaker() *googleBreaker { + bucketDuration := time.Duration(int64(window) / int64(buckets)) + st := collection.NewRollingWindow(buckets, bucketDuration) + return &googleBreaker{ + stat: st, + k: k, + state: StateClosed, + proba: mathx.NewProba(), + } +} + +func (b *googleBreaker) accept() error { + accepts, total := b.history() + weightedAccepts := b.k * float64(accepts) + // https://landing.google.com/sre/sre-book/chapters/handling-overload/#eq2101 + dropRatio := math.Max(0, (float64(total-protection)-weightedAccepts)/float64(total+1)) + if dropRatio <= 0 { + if atomic.LoadInt32(&b.state) == StateOpen { + atomic.CompareAndSwapInt32(&b.state, StateOpen, StateClosed) + } + return nil + } + + if atomic.LoadInt32(&b.state) == StateClosed { + atomic.CompareAndSwapInt32(&b.state, StateClosed, StateOpen) + } + if b.proba.TrueOnProba(dropRatio) { + return ErrServiceUnavailable + } + + return nil +} + +func (b *googleBreaker) allow() (internalPromise, error) { + if err := b.accept(); err != nil { + return nil, err + } + + return googlePromise{ + b: b, + }, nil +} + +func (b *googleBreaker) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error { + if err := b.accept(); err != nil { + if fallback != nil { + return fallback(err) + } else { + return err + } + } + + defer func() { + if e := recover(); e != nil { + b.markFailure() + panic(e) + } + }() + + err := req() + if acceptable(err) { + b.markSuccess() + } else { + b.markFailure() + } + + return err +} + +func (b *googleBreaker) markSuccess() { + b.stat.Add(1) +} + +func (b *googleBreaker) markFailure() { + b.stat.Add(0) +} + +func (b *googleBreaker) history() (accepts int64, total int64) { + b.stat.Reduce(func(b *collection.Bucket) { + accepts += int64(b.Sum) + total += b.Count + }) + + return +} + +type googlePromise struct { + b *googleBreaker +} + +func (p googlePromise) Accept() { + p.b.markSuccess() +} + +func (p googlePromise) Reject() { + p.b.markFailure() +} diff --git a/core/breaker/googlebreaker_test.go b/core/breaker/googlebreaker_test.go new file mode 100644 index 00000000..ba16eae0 --- /dev/null +++ b/core/breaker/googlebreaker_test.go @@ -0,0 +1,238 @@ +package breaker + +import ( + "errors" + "math" + "math/rand" + "testing" + "time" + + "zero/core/collection" + "zero/core/mathx" + "zero/core/stat" + + "github.com/stretchr/testify/assert" +) + +const ( + testBuckets = 10 + testInterval = time.Millisecond * 10 +) + +func init() { + stat.SetReporter(nil) +} + +func getGoogleBreaker() *googleBreaker { + st := collection.NewRollingWindow(testBuckets, testInterval) + return &googleBreaker{ + stat: st, + k: 5, + state: StateClosed, + proba: mathx.NewProba(), + } +} + +func markSuccessWithDuration(b *googleBreaker, count int, sleep time.Duration) { + for i := 0; i < count; i++ { + b.markSuccess() + time.Sleep(sleep) + } +} + +func markFailedWithDuration(b *googleBreaker, count int, sleep time.Duration) { + for i := 0; i < count; i++ { + b.markFailure() + time.Sleep(sleep) + } +} + +func TestGoogleBreakerClose(t *testing.T) { + b := getGoogleBreaker() + markSuccess(b, 80) + assert.Nil(t, b.accept()) + markSuccess(b, 120) + assert.Nil(t, b.accept()) +} + +func TestGoogleBreakerOpen(t *testing.T) { + b := getGoogleBreaker() + markSuccess(b, 10) + assert.Nil(t, b.accept()) + 
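	// flood the window with failures; with almost no accepts the drop ratio
	// computed in accept() approaches 1, so nearly every request below should
	// be rejected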
markFailed(b, 100000) + time.Sleep(testInterval * 2) + verify(t, func() bool { + return b.accept() != nil + }) +} + +func TestGoogleBreakerFallback(t *testing.T) { + b := getGoogleBreaker() + markSuccess(b, 1) + assert.Nil(t, b.accept()) + markFailed(b, 10000) + time.Sleep(testInterval * 2) + verify(t, func() bool { + return b.doReq(func() error { + return errors.New("any") + }, func(err error) error { + return nil + }, defaultAcceptable) == nil + }) +} + +func TestGoogleBreakerReject(t *testing.T) { + b := getGoogleBreaker() + markSuccess(b, 100) + assert.Nil(t, b.accept()) + markFailed(b, 10000) + time.Sleep(testInterval) + assert.Equal(t, ErrServiceUnavailable, b.doReq(func() error { + return ErrServiceUnavailable + }, nil, defaultAcceptable)) +} + +func TestGoogleBreakerAcceptable(t *testing.T) { + b := getGoogleBreaker() + errAcceptable := errors.New("any") + assert.Equal(t, errAcceptable, b.doReq(func() error { + return errAcceptable + }, nil, func(err error) bool { + return err == errAcceptable + })) +} + +func TestGoogleBreakerNotAcceptable(t *testing.T) { + b := getGoogleBreaker() + errAcceptable := errors.New("any") + assert.Equal(t, errAcceptable, b.doReq(func() error { + return errAcceptable + }, nil, func(err error) bool { + return err != errAcceptable + })) +} + +func TestGoogleBreakerPanic(t *testing.T) { + b := getGoogleBreaker() + assert.Panics(t, func() { + _ = b.doReq(func() error { + panic("fail") + }, nil, defaultAcceptable) + }) +} + +func TestGoogleBreakerHalfOpen(t *testing.T) { + b := getGoogleBreaker() + assert.Nil(t, b.accept()) + t.Run("accept single failed/accept", func(t *testing.T) { + markFailed(b, 10000) + time.Sleep(testInterval * 2) + verify(t, func() bool { + return b.accept() != nil + }) + }) + t.Run("accept single failed/allow", func(t *testing.T) { + markFailed(b, 10000) + time.Sleep(testInterval * 2) + verify(t, func() bool { + _, err := b.allow() + return err != nil + }) + }) + time.Sleep(testInterval * testBuckets) + t.Run("accept single succeed", func(t *testing.T) { + assert.Nil(t, b.accept()) + markSuccess(b, 10000) + verify(t, func() bool { + return b.accept() == nil + }) + }) +} + +func TestGoogleBreakerSelfProtection(t *testing.T) { + t.Run("total request < 100", func(t *testing.T) { + b := getGoogleBreaker() + markFailed(b, 4) + time.Sleep(testInterval) + assert.Nil(t, b.accept()) + }) + t.Run("total request > 100, total < 2 * success", func(t *testing.T) { + b := getGoogleBreaker() + size := rand.Intn(10000) + accepts := int(math.Ceil(float64(size))) + 1 + markSuccess(b, accepts) + markFailed(b, size-accepts) + assert.Nil(t, b.accept()) + }) +} + +func TestGoogleBreakerHistory(t *testing.T) { + var b *googleBreaker + var accepts, total int64 + + sleep := testInterval + t.Run("accepts == total", func(t *testing.T) { + b = getGoogleBreaker() + markSuccessWithDuration(b, 10, sleep/2) + accepts, total = b.history() + assert.Equal(t, int64(10), accepts) + assert.Equal(t, int64(10), total) + }) + + t.Run("fail == total", func(t *testing.T) { + b = getGoogleBreaker() + markFailedWithDuration(b, 10, sleep/2) + accepts, total = b.history() + assert.Equal(t, int64(0), accepts) + assert.Equal(t, int64(10), total) + }) + + t.Run("accepts = 1/2 * total, fail = 1/2 * total", func(t *testing.T) { + b = getGoogleBreaker() + markFailedWithDuration(b, 5, sleep/2) + markSuccessWithDuration(b, 5, sleep/2) + accepts, total = b.history() + assert.Equal(t, int64(5), accepts) + assert.Equal(t, int64(10), total) + }) + + t.Run("auto reset rolling counter", func(t 
*testing.T) { + b = getGoogleBreaker() + time.Sleep(testInterval * testBuckets) + accepts, total = b.history() + assert.Equal(t, int64(0), accepts) + assert.Equal(t, int64(0), total) + }) +} + +func BenchmarkGoogleBreakerAllow(b *testing.B) { + breaker := getGoogleBreaker() + b.ResetTimer() + for i := 0; i <= b.N; i++ { + breaker.accept() + if i%2 == 0 { + breaker.markSuccess() + } else { + breaker.markFailure() + } + } +} + +func markSuccess(b *googleBreaker, count int) { + for i := 0; i < count; i++ { + p, err := b.allow() + if err != nil { + break + } + p.Accept() + } +} + +func markFailed(b *googleBreaker, count int) { + for i := 0; i < count; i++ { + p, err := b.allow() + if err == nil { + p.Reject() + } + } +} diff --git a/core/breaker/nopbreaker.go b/core/breaker/nopbreaker.go new file mode 100644 index 00000000..19c4ffbb --- /dev/null +++ b/core/breaker/nopbreaker.go @@ -0,0 +1,42 @@ +package breaker + +const noOpBreakerName = "nopBreaker" + +type noOpBreaker struct{} + +func newNoOpBreaker() Breaker { + return noOpBreaker{} +} + +func (b noOpBreaker) Name() string { + return noOpBreakerName +} + +func (b noOpBreaker) Allow() (Promise, error) { + return nopPromise{}, nil +} + +func (b noOpBreaker) Do(req func() error) error { + return req() +} + +func (b noOpBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error { + return req() +} + +func (b noOpBreaker) DoWithFallback(req func() error, fallback func(err error) error) error { + return req() +} + +func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error, + acceptable Acceptable) error { + return req() +} + +type nopPromise struct{} + +func (p nopPromise) Accept() { +} + +func (p nopPromise) Reject(reason string) { +} diff --git a/core/breaker/nopbreaker_test.go b/core/breaker/nopbreaker_test.go new file mode 100644 index 00000000..ecca1806 --- /dev/null +++ b/core/breaker/nopbreaker_test.go @@ -0,0 +1,38 @@ +package breaker + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNopBreaker(t *testing.T) { + b := newNoOpBreaker() + assert.Equal(t, noOpBreakerName, b.Name()) + p, err := b.Allow() + assert.Nil(t, err) + p.Accept() + for i := 0; i < 1000; i++ { + p, err := b.Allow() + assert.Nil(t, err) + p.Reject("any") + } + assert.Nil(t, b.Do(func() error { + return nil + })) + assert.Nil(t, b.DoWithAcceptable(func() error { + return nil + }, defaultAcceptable)) + errDummy := errors.New("any") + assert.Equal(t, errDummy, b.DoWithFallback(func() error { + return errDummy + }, func(err error) error { + return nil + })) + assert.Equal(t, errDummy, b.DoWithFallbackAcceptable(func() error { + return errDummy + }, func(err error) error { + return nil + }, defaultAcceptable)) +} diff --git a/core/cmdline/input.go b/core/cmdline/input.go new file mode 100644 index 00000000..0284d99e --- /dev/null +++ b/core/cmdline/input.go @@ -0,0 +1,19 @@ +package cmdline + +import ( + "bufio" + "fmt" + "os" + "strings" +) + +func EnterToContinue() { + fmt.Print("Press 'Enter' to continue...") + bufio.NewReader(os.Stdin).ReadBytes('\n') +} + +func ReadLine(prompt string) string { + fmt.Print(prompt) + input, _ := bufio.NewReader(os.Stdin).ReadString('\n') + return strings.TrimSpace(input) +} diff --git a/core/codec/aesecb.go b/core/codec/aesecb.go new file mode 100644 index 00000000..86be5cce --- /dev/null +++ b/core/codec/aesecb.go @@ -0,0 +1,174 @@ +package codec + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "errors" + + 
"zero/core/logx" +) + +var ErrPaddingSize = errors.New("padding size error") + +type ecb struct { + b cipher.Block + blockSize int +} + +func newECB(b cipher.Block) *ecb { + return &ecb{ + b: b, + blockSize: b.BlockSize(), + } +} + +type ecbEncrypter ecb + +func NewECBEncrypter(b cipher.Block) cipher.BlockMode { + return (*ecbEncrypter)(newECB(b)) +} + +func (x *ecbEncrypter) BlockSize() int { return x.blockSize } + +// why we don't return error is because cipher.BlockMode doesn't allow this +func (x *ecbEncrypter) CryptBlocks(dst, src []byte) { + if len(src)%x.blockSize != 0 { + logx.Error("crypto/cipher: input not full blocks") + return + } + if len(dst) < len(src) { + logx.Error("crypto/cipher: output smaller than input") + return + } + + for len(src) > 0 { + x.b.Encrypt(dst, src[:x.blockSize]) + src = src[x.blockSize:] + dst = dst[x.blockSize:] + } +} + +type ecbDecrypter ecb + +func NewECBDecrypter(b cipher.Block) cipher.BlockMode { + return (*ecbDecrypter)(newECB(b)) +} + +func (x *ecbDecrypter) BlockSize() int { + return x.blockSize +} + +// why we don't return error is because cipher.BlockMode doesn't allow this +func (x *ecbDecrypter) CryptBlocks(dst, src []byte) { + if len(src)%x.blockSize != 0 { + logx.Error("crypto/cipher: input not full blocks") + return + } + if len(dst) < len(src) { + logx.Error("crypto/cipher: output smaller than input") + return + } + + for len(src) > 0 { + x.b.Decrypt(dst, src[:x.blockSize]) + src = src[x.blockSize:] + dst = dst[x.blockSize:] + } +} + +func EcbDecrypt(key, src []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + logx.Errorf("Decrypt key error: % x", key) + return nil, err + } + + decrypter := NewECBDecrypter(block) + decrypted := make([]byte, len(src)) + decrypter.CryptBlocks(decrypted, src) + + return pkcs5Unpadding(decrypted, decrypter.BlockSize()) +} + +func EcbDecryptBase64(key, src string) (string, error) { + keyBytes, err := getKeyBytes(key) + if err != nil { + return "", err + } + + encryptedBytes, err := base64.StdEncoding.DecodeString(src) + if err != nil { + return "", err + } + + decryptedBytes, err := EcbDecrypt(keyBytes, encryptedBytes) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(decryptedBytes), nil +} + +func EcbEncrypt(key, src []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + logx.Errorf("Encrypt key error: % x", key) + return nil, err + } + + padded := pkcs5Padding(src, block.BlockSize()) + crypted := make([]byte, len(padded)) + encrypter := NewECBEncrypter(block) + encrypter.CryptBlocks(crypted, padded) + + return crypted, nil +} + +func EcbEncryptBase64(key, src string) (string, error) { + keyBytes, err := getKeyBytes(key) + if err != nil { + return "", err + } + + srcBytes, err := base64.StdEncoding.DecodeString(src) + if err != nil { + return "", err + } + + encryptedBytes, err := EcbEncrypt(keyBytes, srcBytes) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(encryptedBytes), nil +} + +func getKeyBytes(key string) ([]byte, error) { + if len(key) > 32 { + if keyBytes, err := base64.StdEncoding.DecodeString(key); err != nil { + return nil, err + } else { + return keyBytes, nil + } + } + + return []byte(key), nil +} + +func pkcs5Padding(ciphertext []byte, blockSize int) []byte { + padding := blockSize - len(ciphertext)%blockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(ciphertext, padtext...) 
+} + +func pkcs5Unpadding(src []byte, blockSize int) ([]byte, error) { + length := len(src) + unpadding := int(src[length-1]) + if unpadding >= length || unpadding > blockSize { + return nil, ErrPaddingSize + } + + return src[:length-unpadding], nil +} diff --git a/core/codec/dh.go b/core/codec/dh.go new file mode 100644 index 00000000..51883c31 --- /dev/null +++ b/core/codec/dh.go @@ -0,0 +1,88 @@ +package codec + +import ( + "crypto/rand" + "errors" + "math/big" +) + +// see https://www.zhihu.com/question/29383090/answer/70435297 +// see https://www.ietf.org/rfc/rfc3526.txt +// 2048-bit MODP Group + +var ( + ErrInvalidPriKey = errors.New("invalid private key") + ErrInvalidPubKey = errors.New("invalid public key") + ErrPubKeyOutOfBound = errors.New("public key out of bound") + + p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + g, _ = new(big.Int).SetString("2", 16) + zero = big.NewInt(0) +) + +type DhKey struct { + PriKey *big.Int + PubKey *big.Int +} + +func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) { + if pubKey == nil { + return nil, ErrInvalidPubKey + } + + if pubKey.Sign() <= 0 && p.Cmp(pubKey) <= 0 { + return nil, ErrPubKeyOutOfBound + } + + if priKey == nil { + return nil, ErrInvalidPriKey + } + + return new(big.Int).Exp(pubKey, priKey, p), nil +} + +func GenerateKey() (*DhKey, error) { + var err error + var x *big.Int + + for { + x, err = rand.Int(rand.Reader, p) + if err != nil { + return nil, err + } + + if zero.Cmp(x) < 0 { + break + } + } + + key := new(DhKey) + key.PriKey = x + key.PubKey = new(big.Int).Exp(g, x, p) + + return key, nil +} + +func NewPublicKey(bs []byte) *big.Int { + return new(big.Int).SetBytes(bs) +} + +func (k *DhKey) Bytes() []byte { + if k.PubKey == nil { + return nil + } + + byteLen := (p.BitLen() + 7) >> 3 + ret := make([]byte, byteLen) + copyWithLeftPad(ret, k.PubKey.Bytes()) + + return ret +} + +func copyWithLeftPad(dst, src []byte) { + padBytes := len(dst) - len(src) + for i := 0; i < padBytes; i++ { + dst[i] = 0 + } + copy(dst[padBytes:], src) +} diff --git a/core/codec/dh_test.go b/core/codec/dh_test.go new file mode 100644 index 00000000..56f37412 --- /dev/null +++ b/core/codec/dh_test.go @@ -0,0 +1,73 @@ +package codec + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDiffieHellman(t *testing.T) { + key1, err := GenerateKey() + assert.Nil(t, err) + key2, err := GenerateKey() + assert.Nil(t, err) + + pubKey1, err := ComputeKey(key1.PubKey, key2.PriKey) + assert.Nil(t, err) + pubKey2, err := ComputeKey(key2.PubKey, key1.PriKey) + assert.Nil(t, err) + + assert.Equal(t, pubKey1, pubKey2) +} + +func TestDiffieHellman1024(t *testing.T) { + old := p + p, _ = new(big.Int).SetString("F488FD584E49DBCD20B49DE49107366B336C380D451D0F7C88B31C7C5B2D8EF6F3C923C043F0A55B188D8EBB558CB85D38D334FD7C175743A31D186CDE33212CB52AFF3CE1B1294018118D7C84A70A72D686C40319C807297ACA950CD9969FABD00A509B0246D3083D66A45D419F9C7CBD894B221926BAABA25EC355E92F78C7", 16) + defer func() { + p = old + }() + + key1, err := GenerateKey() + 
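	// same exchange as above, only over the 1024-bit MODP group: each side
	// keeps PriKey secret, hands PubKey to the peer, and ComputeKey(peerPub,
	// ownPriv) must produce the same shared secret on both ends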
assert.Nil(t, err) + key2, err := GenerateKey() + assert.Nil(t, err) + + pubKey1, err := ComputeKey(key1.PubKey, key2.PriKey) + assert.Nil(t, err) + pubKey2, err := ComputeKey(key2.PubKey, key1.PriKey) + assert.Nil(t, err) + + assert.Equal(t, pubKey1, pubKey2) +} + +func TestDiffieHellmanMiddleManAttack(t *testing.T) { + key1, err := GenerateKey() + assert.Nil(t, err) + keyMiddle, err := GenerateKey() + assert.Nil(t, err) + key2, err := GenerateKey() + assert.Nil(t, err) + + const aesByteLen = 32 + pubKey1, err := ComputeKey(keyMiddle.PubKey, key1.PriKey) + assert.Nil(t, err) + src := []byte(`hello, world!`) + encryptedSrc, err := EcbEncrypt(pubKey1.Bytes()[:aesByteLen], src) + assert.Nil(t, err) + pubKeyMiddle, err := ComputeKey(key1.PubKey, keyMiddle.PriKey) + assert.Nil(t, err) + decryptedSrc, err := EcbDecrypt(pubKeyMiddle.Bytes()[:aesByteLen], encryptedSrc) + assert.Nil(t, err) + assert.Equal(t, string(src), string(decryptedSrc)) + + pubKeyMiddle, err = ComputeKey(key2.PubKey, keyMiddle.PriKey) + assert.Nil(t, err) + encryptedSrc, err = EcbEncrypt(pubKeyMiddle.Bytes()[:aesByteLen], decryptedSrc) + assert.Nil(t, err) + pubKey2, err := ComputeKey(keyMiddle.PubKey, key2.PriKey) + assert.Nil(t, err) + decryptedSrc, err = EcbDecrypt(pubKey2.Bytes()[:aesByteLen], encryptedSrc) + assert.Nil(t, err) + assert.Equal(t, string(src), string(decryptedSrc)) +} diff --git a/core/codec/gzip.go b/core/codec/gzip.go new file mode 100644 index 00000000..d93cf072 --- /dev/null +++ b/core/codec/gzip.go @@ -0,0 +1,33 @@ +package codec + +import ( + "bytes" + "compress/gzip" + "io" +) + +func Gzip(bs []byte) []byte { + var b bytes.Buffer + + w := gzip.NewWriter(&b) + w.Write(bs) + w.Close() + + return b.Bytes() +} + +func Gunzip(bs []byte) ([]byte, error) { + r, err := gzip.NewReader(bytes.NewBuffer(bs)) + if err != nil { + return nil, err + } + defer r.Close() + + var c bytes.Buffer + _, err = io.Copy(&c, r) + if err != nil { + return nil, err + } + + return c.Bytes(), nil +} diff --git a/core/codec/gzip_test.go b/core/codec/gzip_test.go new file mode 100644 index 00000000..29a89748 --- /dev/null +++ b/core/codec/gzip_test.go @@ -0,0 +1,23 @@ +package codec + +import ( + "bytes" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGzip(t *testing.T) { + var buf bytes.Buffer + for i := 0; i < 10000; i++ { + fmt.Fprint(&buf, i) + } + + bs := Gzip(buf.Bytes()) + actual, err := Gunzip(bs) + + assert.Nil(t, err) + assert.True(t, len(bs) < buf.Len()) + assert.Equal(t, buf.Bytes(), actual) +} diff --git a/core/codec/hmac.go b/core/codec/hmac.go new file mode 100644 index 00000000..e21819be --- /dev/null +++ b/core/codec/hmac.go @@ -0,0 +1,18 @@ +package codec + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "io" +) + +func Hmac(key []byte, body string) []byte { + h := hmac.New(sha256.New, key) + io.WriteString(h, body) + return h.Sum(nil) +} + +func HmacBase64(key []byte, body string) string { + return base64.StdEncoding.EncodeToString(Hmac(key, body)) +} diff --git a/core/codec/rsa.go b/core/codec/rsa.go new file mode 100644 index 00000000..b2eed9dc --- /dev/null +++ b/core/codec/rsa.go @@ -0,0 +1,149 @@ +package codec + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "errors" + "io/ioutil" +) + +var ( + ErrPrivateKey = errors.New("private key error") + ErrPublicKey = errors.New("failed to parse PEM block containing the public key") + ErrNotRsaKey = errors.New("key type is not RSA") +) + +type ( + RsaDecrypter interface { + 
Decrypt(input []byte) ([]byte, error) + DecryptBase64(input string) ([]byte, error) + } + + RsaEncrypter interface { + Encrypt(input []byte) ([]byte, error) + } + + rsaBase struct { + bytesLimit int + } + + rsaDecrypter struct { + rsaBase + privateKey *rsa.PrivateKey + } + + rsaEncrypter struct { + rsaBase + publicKey *rsa.PublicKey + } +) + +func NewRsaDecrypter(file string) (RsaDecrypter, error) { + content, err := ioutil.ReadFile(file) + if err != nil { + return nil, err + } + + block, _ := pem.Decode(content) + if block == nil { + return nil, ErrPrivateKey + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + + return &rsaDecrypter{ + rsaBase: rsaBase{ + bytesLimit: privateKey.N.BitLen() >> 3, + }, + privateKey: privateKey, + }, nil +} + +func (r *rsaDecrypter) Decrypt(input []byte) ([]byte, error) { + return r.crypt(input, func(block []byte) ([]byte, error) { + return rsaDecryptBlock(r.privateKey, block) + }) +} + +func (r *rsaDecrypter) DecryptBase64(input string) ([]byte, error) { + if len(input) == 0 { + return nil, nil + } + + base64Decoded, err := base64.StdEncoding.DecodeString(input) + if err != nil { + return nil, err + } + + return r.Decrypt(base64Decoded) +} + +func NewRsaEncrypter(key []byte) (RsaEncrypter, error) { + block, _ := pem.Decode(key) + if block == nil { + return nil, ErrPublicKey + } + + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, err + } + + switch pubKey := pub.(type) { + case *rsa.PublicKey: + return &rsaEncrypter{ + rsaBase: rsaBase{ + // https://www.ietf.org/rfc/rfc2313.txt + // The length of the data D shall not be more than k-11 octets, which is + // positive since the length k of the modulus is at least 12 octets. + bytesLimit: (pubKey.N.BitLen() >> 3) - 11, + }, + publicKey: pubKey, + }, nil + default: + return nil, ErrNotRsaKey + } +} + +func (r *rsaEncrypter) Encrypt(input []byte) ([]byte, error) { + return r.crypt(input, func(block []byte) ([]byte, error) { + return rsaEncryptBlock(r.publicKey, block) + }) +} + +func (r *rsaBase) crypt(input []byte, cryptFn func([]byte) ([]byte, error)) ([]byte, error) { + var result []byte + inputLen := len(input) + + for i := 0; i*r.bytesLimit < inputLen; i++ { + start := r.bytesLimit * i + var stop int + if r.bytesLimit*(i+1) > inputLen { + stop = inputLen + } else { + stop = r.bytesLimit * (i + 1) + } + bs, err := cryptFn(input[start:stop]) + if err != nil { + return nil, err + } + + result = append(result, bs...) 
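		// e.g. a 2048-bit key has a 256-byte modulus, so Encrypt handles at most
		// 245 bytes per pass (256 minus the 11-byte PKCS#1 v1.5 overhead) while
		// Decrypt consumes full 256-byte blocks; the pieces are concatenated
		// into result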
+ } + + return result, nil +} + +func rsaDecryptBlock(privateKey *rsa.PrivateKey, block []byte) ([]byte, error) { + return rsa.DecryptPKCS1v15(rand.Reader, privateKey, block) +} + +func rsaEncryptBlock(publicKey *rsa.PublicKey, msg []byte) ([]byte, error) { + return rsa.EncryptPKCS1v15(rand.Reader, publicKey, msg) +} diff --git a/core/collection/cache.go b/core/collection/cache.go new file mode 100644 index 00000000..c2eca2a6 --- /dev/null +++ b/core/collection/cache.go @@ -0,0 +1,275 @@ +package collection + +import ( + "container/list" + "sync" + "sync/atomic" + "time" + + "zero/core/logx" + "zero/core/mathx" + "zero/core/syncx" +) + +const ( + defaultCacheName = "proc" + slots = 300 + statInterval = time.Minute + // make the expiry unstable to avoid lots of cached items expire at the same time + // make the unstable expiry to be [0.95, 1.05] * seconds + expiryDeviation = 0.05 +) + +var emptyLruCache = emptyLru{} + +type ( + CacheOption func(cache *Cache) + + Cache struct { + name string + lock sync.Mutex + data map[string]interface{} + evicts *list.List + expire time.Duration + timingWheel *TimingWheel + lruCache lru + barrier syncx.SharedCalls + unstableExpiry mathx.Unstable + stats *cacheStat + } +) + +func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) { + cache := &Cache{ + data: make(map[string]interface{}), + expire: expire, + lruCache: emptyLruCache, + barrier: syncx.NewSharedCalls(), + unstableExpiry: mathx.NewUnstable(expiryDeviation), + } + + for _, opt := range opts { + opt(cache) + } + + if len(cache.name) == 0 { + cache.name = defaultCacheName + } + cache.stats = newCacheStat(cache.name, cache.size) + + timingWheel, err := NewTimingWheel(time.Second, slots, func(k, v interface{}) { + key, ok := k.(string) + if !ok { + return + } + + cache.Del(key) + }) + if err != nil { + return nil, err + } + + cache.timingWheel = timingWheel + return cache, nil +} + +func (c *Cache) Del(key string) { + c.lock.Lock() + delete(c.data, key) + c.lruCache.remove(key) + c.lock.Unlock() + c.timingWheel.RemoveTimer(key) +} + +func (c *Cache) Get(key string) (interface{}, bool) { + c.lock.Lock() + value, ok := c.data[key] + if ok { + c.lruCache.add(key) + } + c.lock.Unlock() + if ok { + c.stats.IncrementHit() + } else { + c.stats.IncrementMiss() + } + + return value, ok +} + +func (c *Cache) Set(key string, value interface{}) { + c.lock.Lock() + _, ok := c.data[key] + c.data[key] = value + c.lruCache.add(key) + c.lock.Unlock() + + expiry := c.unstableExpiry.AroundDuration(c.expire) + if ok { + c.timingWheel.MoveTimer(key, expiry) + } else { + c.timingWheel.SetTimer(key, value, expiry) + } +} + +func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}, error) { + val, fresh, err := c.barrier.DoEx(key, func() (interface{}, error) { + v, e := fetch() + if e != nil { + return nil, e + } + + c.Set(key, v) + return v, nil + }) + if err != nil { + return nil, err + } + + if fresh { + c.stats.IncrementMiss() + return val, nil + } else { + // got the result from previous ongoing query + c.stats.IncrementHit() + } + + return val, nil +} + +func (c *Cache) onEvict(key string) { + // already locked + delete(c.data, key) + c.timingWheel.RemoveTimer(key) +} + +func (c *Cache) size() int { + c.lock.Lock() + defer c.lock.Unlock() + return len(c.data) +} + +func WithLimit(limit int) CacheOption { + return func(cache *Cache) { + if limit > 0 { + cache.lruCache = newKeyLru(limit, cache.onEvict) + } + } +} + +func WithName(name string) CacheOption { + return func(cache 
*Cache) { + cache.name = name + } +} + +type ( + lru interface { + add(key string) + remove(key string) + } + + emptyLru struct{} + + keyLru struct { + limit int + evicts *list.List + elements map[string]*list.Element + onEvict func(key string) + } +) + +func (elru emptyLru) add(string) { +} + +func (elru emptyLru) remove(string) { +} + +func newKeyLru(limit int, onEvict func(key string)) *keyLru { + return &keyLru{ + limit: limit, + evicts: list.New(), + elements: make(map[string]*list.Element), + onEvict: onEvict, + } +} + +func (klru *keyLru) add(key string) { + if elem, ok := klru.elements[key]; ok { + klru.evicts.MoveToFront(elem) + return + } + + // Add new item + elem := klru.evicts.PushFront(key) + klru.elements[key] = elem + + // Verify size not exceeded + if klru.evicts.Len() > klru.limit { + klru.removeOldest() + } +} + +func (klru *keyLru) remove(key string) { + if elem, ok := klru.elements[key]; ok { + klru.removeElement(elem) + } +} + +func (klru *keyLru) removeOldest() { + elem := klru.evicts.Back() + if elem != nil { + klru.removeElement(elem) + } +} + +func (klru *keyLru) removeElement(e *list.Element) { + klru.evicts.Remove(e) + key := e.Value.(string) + delete(klru.elements, key) + klru.onEvict(key) +} + +type cacheStat struct { + name string + hit uint64 + miss uint64 + sizeCallback func() int +} + +func newCacheStat(name string, sizeCallback func() int) *cacheStat { + st := &cacheStat{ + name: name, + sizeCallback: sizeCallback, + } + go st.statLoop() + return st +} + +func (cs *cacheStat) IncrementHit() { + atomic.AddUint64(&cs.hit, 1) +} + +func (cs *cacheStat) IncrementMiss() { + atomic.AddUint64(&cs.miss, 1) +} + +func (cs *cacheStat) statLoop() { + ticker := time.NewTicker(statInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + hit := atomic.SwapUint64(&cs.hit, 0) + miss := atomic.SwapUint64(&cs.miss, 0) + total := hit + miss + if total == 0 { + continue + } + percent := 100 * float32(hit) / float32(total) + logx.Statf("cache(%s) - qpm: %d, hit_ratio: %.1f%%, elements: %d, hit: %d, miss: %d", + cs.name, total, percent, cs.sizeCallback(), hit, miss) + } + } +} diff --git a/core/collection/cache_test.go b/core/collection/cache_test.go new file mode 100644 index 00000000..b296e378 --- /dev/null +++ b/core/collection/cache_test.go @@ -0,0 +1,139 @@ +package collection + +import ( + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCacheSet(t *testing.T) { + cache, err := NewCache(time.Second*2, WithName("any")) + assert.Nil(t, err) + + cache.Set("first", "first element") + cache.Set("second", "second element") + + value, ok := cache.Get("first") + assert.True(t, ok) + assert.Equal(t, "first element", value) + value, ok = cache.Get("second") + assert.True(t, ok) + assert.Equal(t, "second element", value) +} + +func TestCacheDel(t *testing.T) { + cache, err := NewCache(time.Second * 2) + assert.Nil(t, err) + + cache.Set("first", "first element") + cache.Set("second", "second element") + cache.Del("first") + + _, ok := cache.Get("first") + assert.False(t, ok) + value, ok := cache.Get("second") + assert.True(t, ok) + assert.Equal(t, "second element", value) +} + +func TestCacheTake(t *testing.T) { + cache, err := NewCache(time.Second * 2) + assert.Nil(t, err) + + var count int32 + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + cache.Take("first", func() (interface{}, error) { + atomic.AddInt32(&count, 1) + time.Sleep(time.Millisecond * 100) + return "first 
element", nil + }) + wg.Done() + }() + } + wg.Wait() + + assert.Equal(t, 1, cache.size()) + assert.Equal(t, int32(1), atomic.LoadInt32(&count)) +} + +func TestCacheWithLruEvicts(t *testing.T) { + cache, err := NewCache(time.Minute, WithLimit(3)) + assert.Nil(t, err) + + cache.Set("first", "first element") + cache.Set("second", "second element") + cache.Set("third", "third element") + cache.Set("fourth", "fourth element") + + value, ok := cache.Get("first") + assert.False(t, ok) + value, ok = cache.Get("second") + assert.True(t, ok) + assert.Equal(t, "second element", value) + value, ok = cache.Get("third") + assert.True(t, ok) + assert.Equal(t, "third element", value) + value, ok = cache.Get("fourth") + assert.True(t, ok) + assert.Equal(t, "fourth element", value) +} + +func TestCacheWithLruEvicted(t *testing.T) { + cache, err := NewCache(time.Minute, WithLimit(3)) + assert.Nil(t, err) + + cache.Set("first", "first element") + cache.Set("second", "second element") + cache.Set("third", "third element") + cache.Set("fourth", "fourth element") + + value, ok := cache.Get("first") + assert.False(t, ok) + value, ok = cache.Get("second") + assert.True(t, ok) + assert.Equal(t, "second element", value) + cache.Set("fifth", "fifth element") + cache.Set("sixth", "sixth element") + _, ok = cache.Get("third") + assert.False(t, ok) + _, ok = cache.Get("fourth") + assert.False(t, ok) + value, ok = cache.Get("second") + assert.True(t, ok) + assert.Equal(t, "second element", value) +} + +func BenchmarkCache(b *testing.B) { + cache, err := NewCache(time.Second*5, WithLimit(100000)) + if err != nil { + b.Fatal(err) + } + + for i := 0; i < 10000; i++ { + for j := 0; j < 10; j++ { + index := strconv.Itoa(i*10000 + j) + cache.Set("key:"+index, "value:"+index) + } + } + + time.Sleep(time.Second * 5) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + for i := 0; i < b.N; i++ { + index := strconv.Itoa(i % 10000) + cache.Get("key:" + index) + if i%100 == 0 { + cache.Set("key1:"+index, "value1:"+index) + } + } + } + }) +} diff --git a/core/collection/fifo.go b/core/collection/fifo.go new file mode 100644 index 00000000..be6eb3e2 --- /dev/null +++ b/core/collection/fifo.go @@ -0,0 +1,60 @@ +package collection + +import "sync" + +type Queue struct { + lock sync.Mutex + elements []interface{} + size int + head int + tail int + count int +} + +func NewQueue(size int) *Queue { + return &Queue{ + elements: make([]interface{}, size), + size: size, + } +} + +func (q *Queue) Empty() bool { + q.lock.Lock() + empty := q.count == 0 + q.lock.Unlock() + + return empty +} + +func (q *Queue) Put(element interface{}) { + q.lock.Lock() + defer q.lock.Unlock() + + if q.head == q.tail && q.count > 0 { + nodes := make([]interface{}, len(q.elements)+q.size) + copy(nodes, q.elements[q.head:]) + copy(nodes[len(q.elements)-q.head:], q.elements[:q.head]) + q.head = 0 + q.tail = len(q.elements) + q.elements = nodes + } + + q.elements[q.tail] = element + q.tail = (q.tail + 1) % len(q.elements) + q.count++ +} + +func (q *Queue) Take() (interface{}, bool) { + q.lock.Lock() + defer q.lock.Unlock() + + if q.count == 0 { + return nil, false + } + + element := q.elements[q.head] + q.head = (q.head + 1) % len(q.elements) + q.count-- + + return element, true +} diff --git a/core/collection/fifo_test.go b/core/collection/fifo_test.go new file mode 100644 index 00000000..3e58afef --- /dev/null +++ b/core/collection/fifo_test.go @@ -0,0 +1,63 @@ +package collection + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func 
TestFifo(t *testing.T) { + elements := [][]byte{ + []byte("hello"), + []byte("world"), + []byte("again"), + } + queue := NewQueue(8) + for i := range elements { + queue.Put(elements[i]) + } + + for _, element := range elements { + body, ok := queue.Take() + assert.True(t, ok) + assert.Equal(t, string(element), string(body.([]byte))) + } +} + +func TestTakeTooMany(t *testing.T) { + elements := [][]byte{ + []byte("hello"), + []byte("world"), + []byte("again"), + } + queue := NewQueue(8) + for i := range elements { + queue.Put(elements[i]) + } + + for range elements { + queue.Take() + } + + assert.True(t, queue.Empty()) + _, ok := queue.Take() + assert.False(t, ok) +} + +func TestPutMore(t *testing.T) { + elements := [][]byte{ + []byte("hello"), + []byte("world"), + []byte("again"), + } + queue := NewQueue(2) + for i := range elements { + queue.Put(elements[i]) + } + + for _, element := range elements { + body, ok := queue.Take() + assert.True(t, ok) + assert.Equal(t, string(element), string(body.([]byte))) + } +} diff --git a/core/collection/ring.go b/core/collection/ring.go new file mode 100644 index 00000000..17060c06 --- /dev/null +++ b/core/collection/ring.go @@ -0,0 +1,35 @@ +package collection + +type Ring struct { + elements []interface{} + index int +} + +func NewRing(n int) *Ring { + return &Ring{ + elements: make([]interface{}, n), + } +} + +func (r *Ring) Add(v interface{}) { + r.elements[r.index%len(r.elements)] = v + r.index++ +} + +func (r *Ring) Take() []interface{} { + var size int + var start int + if r.index > len(r.elements) { + size = len(r.elements) + start = r.index % len(r.elements) + } else { + size = r.index + } + + elements := make([]interface{}, size) + for i := 0; i < size; i++ { + elements[i] = r.elements[(start+i)%len(r.elements)] + } + + return elements +} diff --git a/core/collection/ring_test.go b/core/collection/ring_test.go new file mode 100644 index 00000000..182b67f8 --- /dev/null +++ b/core/collection/ring_test.go @@ -0,0 +1,25 @@ +package collection + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRingLess(t *testing.T) { + ring := NewRing(5) + for i := 0; i < 3; i++ { + ring.Add(i) + } + elements := ring.Take() + assert.ElementsMatch(t, []interface{}{0, 1, 2}, elements) +} + +func TestRingMore(t *testing.T) { + ring := NewRing(5) + for i := 0; i < 11; i++ { + ring.Add(i) + } + elements := ring.Take() + assert.ElementsMatch(t, []interface{}{6, 7, 8, 9, 10}, elements) +} diff --git a/core/collection/rollingwindow.go b/core/collection/rollingwindow.go new file mode 100644 index 00000000..c34726fe --- /dev/null +++ b/core/collection/rollingwindow.go @@ -0,0 +1,145 @@ +package collection + +import ( + "sync" + "time" + + "zero/core/timex" +) + +type ( + RollingWindowOption func(rollingWindow *RollingWindow) + + RollingWindow struct { + lock sync.RWMutex + size int + win *window + interval time.Duration + offset int + ignoreCurrent bool + lastTime time.Duration + } +) + +func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow { + w := &RollingWindow{ + size: size, + win: newWindow(size), + interval: interval, + lastTime: timex.Now(), + } + for _, opt := range opts { + opt(w) + } + return w +} + +func (rw *RollingWindow) Add(v float64) { + rw.lock.Lock() + defer rw.lock.Unlock() + rw.updateOffset() + rw.win.add(rw.offset, v) +} + +func (rw *RollingWindow) Reduce(fn func(b *Bucket)) { + rw.lock.RLock() + defer rw.lock.RUnlock() + + var diff int + span := rw.span() + // ignore current 
bucket, because of partial data + if span == 0 && rw.ignoreCurrent { + diff = rw.size - 1 + } else { + diff = rw.size - span + } + if diff > 0 { + offset := (rw.offset + span + 1) % rw.size + rw.win.reduce(offset, diff, fn) + } +} + +func (rw *RollingWindow) span() int { + offset := int(timex.Since(rw.lastTime) / rw.interval) + if 0 <= offset && offset < rw.size { + return offset + } else { + return rw.size + } +} + +func (rw *RollingWindow) updateOffset() { + span := rw.span() + if span > 0 { + offset := rw.offset + // reset expired buckets + start := offset + 1 + steps := start + span + var remainder int + if steps > rw.size { + remainder = steps - rw.size + steps = rw.size + } + for i := start; i < steps; i++ { + rw.win.resetBucket(i) + offset = i + } + for i := 0; i < remainder; i++ { + rw.win.resetBucket(i) + offset = i + } + rw.offset = offset + rw.lastTime = timex.Now() + } +} + +type Bucket struct { + Sum float64 + Count int64 +} + +func (b *Bucket) add(v float64) { + b.Sum += v + b.Count++ +} + +func (b *Bucket) reset() { + b.Sum = 0 + b.Count = 0 +} + +type window struct { + buckets []*Bucket + size int +} + +func newWindow(size int) *window { + var buckets []*Bucket + for i := 0; i < size; i++ { + buckets = append(buckets, new(Bucket)) + } + return &window{ + buckets: buckets, + size: size, + } +} + +func (w *window) add(offset int, v float64) { + w.buckets[offset%w.size].add(v) +} + +func (w *window) reduce(start, count int, fn func(b *Bucket)) { + for i := 0; i < count; i++ { + fn(w.buckets[(start+i)%len(w.buckets)]) + } +} + +func (w *window) resetBucket(offset int) { + w.buckets[offset].reset() +} + +func IgnoreCurrentBucket() RollingWindowOption { + return func(w *RollingWindow) { + w.ignoreCurrent = true + } +} diff --git a/core/collection/rollingwindow_test.go b/core/collection/rollingwindow_test.go new file mode 100644 index 00000000..951117c8 --- /dev/null +++ b/core/collection/rollingwindow_test.go @@ -0,0 +1,133 @@ +package collection + +import ( + "math/rand" + "testing" + "time" + + "zero/core/stringx" + + "github.com/stretchr/testify/assert" +) + +const duration = time.Millisecond * 50 + +func TestRollingWindowAdd(t *testing.T) { + const size = 3 + r := NewRollingWindow(size, duration) + listBuckets := func() []float64 { + var buckets []float64 + r.Reduce(func(b *Bucket) { + buckets = append(buckets, b.Sum) + }) + return buckets + } + assert.Equal(t, []float64{0, 0, 0}, listBuckets()) + r.Add(1) + assert.Equal(t, []float64{0, 0, 1}, listBuckets()) + elapse() + r.Add(2) + r.Add(3) + assert.Equal(t, []float64{0, 1, 5}, listBuckets()) + elapse() + r.Add(4) + r.Add(5) + r.Add(6) + assert.Equal(t, []float64{1, 5, 15}, listBuckets()) + elapse() + r.Add(7) + assert.Equal(t, []float64{5, 15, 7}, listBuckets()) +} + +func TestRollingWindowReset(t *testing.T) { + const size = 3 + r := NewRollingWindow(size, duration, IgnoreCurrentBucket()) + listBuckets := func() []float64 { + var buckets []float64 + r.Reduce(func(b *Bucket) { + buckets = append(buckets, b.Sum) + }) + return buckets + } + r.Add(1) + elapse() + assert.Equal(t, []float64{0, 1}, listBuckets()) + elapse() + assert.Equal(t, []float64{1}, listBuckets()) + elapse() + assert.Nil(t, listBuckets()) + + // cross window + r.Add(1) + time.Sleep(duration * 10) + assert.Nil(t, listBuckets()) +} + +func TestRollingWindowReduce(t *testing.T) { + const size = 4 + tests := []struct { + win *RollingWindow + expect float64 + }{ + { + win: NewRollingWindow(size, duration), + expect: 10, + }, + { + win: NewRollingWindow(size, 
duration, IgnoreCurrentBucket()), + expect: 4, + }, + } + + for _, test := range tests { + t.Run(stringx.Rand(), func(t *testing.T) { + r := test.win + for x := 0; x < size; x = x + 1 { + for i := 0; i <= x; i++ { + r.Add(float64(i)) + } + if x < size-1 { + elapse() + } + } + var result float64 + r.Reduce(func(b *Bucket) { + result += b.Sum + }) + assert.Equal(t, test.expect, result) + }) + } +} + +func TestRollingWindowDataRace(t *testing.T) { + const size = 3 + r := NewRollingWindow(size, duration) + var stop = make(chan bool) + go func() { + for { + select { + case <-stop: + return + default: + r.Add(float64(rand.Int63())) + time.Sleep(duration / 2) + } + } + }() + go func() { + for { + select { + case <-stop: + return + default: + r.Reduce(func(b *Bucket) {}) + } + } + }() + time.Sleep(duration * 5) + close(stop) +} + +func elapse() { + time.Sleep(duration) +} diff --git a/core/collection/safemap.go b/core/collection/safemap.go new file mode 100644 index 00000000..a8547a5f --- /dev/null +++ b/core/collection/safemap.go @@ -0,0 +1,91 @@ +package collection + +import "sync" + +const ( + copyThreshold = 1000 + maxDeletion = 10000 +) + +// SafeMap provides a map alternative to avoid memory leak. +// This implementation is not needed until issue below fixed. +// https://github.com/golang/go/issues/20135 +type SafeMap struct { + lock sync.RWMutex + deletionOld int + deletionNew int + dirtyOld map[interface{}]interface{} + dirtyNew map[interface{}]interface{} +} + +func NewSafeMap() *SafeMap { + return &SafeMap{ + dirtyOld: make(map[interface{}]interface{}), + dirtyNew: make(map[interface{}]interface{}), + } +} + +func (m *SafeMap) Del(key interface{}) { + m.lock.Lock() + if _, ok := m.dirtyOld[key]; ok { + delete(m.dirtyOld, key) + m.deletionOld++ + } else if _, ok := m.dirtyNew[key]; ok { + delete(m.dirtyNew, key) + m.deletionNew++ + } + if m.deletionOld >= maxDeletion && len(m.dirtyOld) < copyThreshold { + for k, v := range m.dirtyOld { + m.dirtyNew[k] = v + } + m.dirtyOld = m.dirtyNew + m.deletionOld = m.deletionNew + m.dirtyNew = make(map[interface{}]interface{}) + m.deletionNew = 0 + } + if m.deletionNew >= maxDeletion && len(m.dirtyNew) < copyThreshold { + for k, v := range m.dirtyNew { + m.dirtyOld[k] = v + } + m.dirtyNew = make(map[interface{}]interface{}) + m.deletionNew = 0 + } + m.lock.Unlock() +} + +func (m *SafeMap) Get(key interface{}) (interface{}, bool) { + m.lock.RLock() + defer m.lock.RUnlock() + + if val, ok := m.dirtyOld[key]; ok { + return val, true + } else { + val, ok := m.dirtyNew[key] + return val, ok + } +} + +func (m *SafeMap) Set(key, value interface{}) { + m.lock.Lock() + if m.deletionOld <= maxDeletion { + if _, ok := m.dirtyNew[key]; ok { + delete(m.dirtyNew, key) + m.deletionNew++ + } + m.dirtyOld[key] = value + } else { + if _, ok := m.dirtyOld[key]; ok { + delete(m.dirtyOld, key) + m.deletionOld++ + } + m.dirtyNew[key] = value + } + m.lock.Unlock() +} + +func (m *SafeMap) Size() int { + m.lock.RLock() + size := len(m.dirtyOld) + len(m.dirtyNew) + m.lock.RUnlock() + return size +} diff --git a/core/collection/safemap_test.go b/core/collection/safemap_test.go new file mode 100644 index 00000000..afce3476 --- /dev/null +++ b/core/collection/safemap_test.go @@ -0,0 +1,110 @@ +package collection + +import ( + "testing" + + "zero/core/stringx" + + "github.com/stretchr/testify/assert" +) + +func TestSafeMap(t *testing.T) { + tests := []struct { + size int + exception int + }{ + { + 100000, + 2000, + }, + { + 100000, + 50, + }, + } + for _, test := range tests { + 
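+ // Both cases delete far more than maxDeletion entries, which pushes SafeMap past
+ // its deletion threshold so later writes switch to the fresh map; that map rotation
+ // is the leak-avoidance behaviour being exercised here.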
t.Run(stringx.Rand(), func(t *testing.T) { + testSafeMapWithParameters(t, test.size, test.exception) + }) + } +} + +func TestSafeMap_CopyNew(t *testing.T) { + const ( + size = 100000 + exception1 = 5 + exception2 = 500 + ) + m := NewSafeMap() + + for i := 0; i < size; i++ { + m.Set(i, i) + } + for i := 0; i < size; i++ { + if i%exception1 == 0 { + m.Del(i) + } + } + + for i := size; i < size<<1; i++ { + m.Set(i, i) + } + for i := size; i < size<<1; i++ { + if i%exception2 != 0 { + m.Del(i) + } + } + + for i := 0; i < size; i++ { + val, ok := m.Get(i) + if i%exception1 != 0 { + assert.True(t, ok) + assert.Equal(t, i, val.(int)) + } else { + assert.False(t, ok) + } + } + for i := size; i < size<<1; i++ { + val, ok := m.Get(i) + if i%exception2 == 0 { + assert.True(t, ok) + assert.Equal(t, i, val.(int)) + } else { + assert.False(t, ok) + } + } +} + +func testSafeMapWithParameters(t *testing.T, size, exception int) { + m := NewSafeMap() + + for i := 0; i < size; i++ { + m.Set(i, i) + } + for i := 0; i < size; i++ { + if i%exception != 0 { + m.Del(i) + } + } + + assert.Equal(t, size/exception, m.Size()) + + for i := size; i < size<<1; i++ { + m.Set(i, i) + } + for i := size; i < size<<1; i++ { + if i%exception != 0 { + m.Del(i) + } + } + + for i := 0; i < size<<1; i++ { + val, ok := m.Get(i) + if i%exception == 0 { + assert.True(t, ok) + assert.Equal(t, i, val.(int)) + } else { + assert.False(t, ok) + } + } +} diff --git a/core/collection/set.go b/core/collection/set.go new file mode 100644 index 00000000..4a1f3ea7 --- /dev/null +++ b/core/collection/set.go @@ -0,0 +1,230 @@ +package collection + +import ( + "zero/core/lang" + "zero/core/logx" +) + +const ( + unmanaged = iota + untyped + intType + int64Type + uintType + uint64Type + stringType +) + +type Set struct { + data map[interface{}]lang.PlaceholderType + tp int +} + +func NewSet() *Set { + return &Set{ + data: make(map[interface{}]lang.PlaceholderType), + tp: untyped, + } +} + +func NewUnmanagedSet() *Set { + return &Set{ + data: make(map[interface{}]lang.PlaceholderType), + tp: unmanaged, + } +} + +func (s *Set) Add(i ...interface{}) { + for _, each := range i { + s.add(each) + } +} + +func (s *Set) AddInt(ii ...int) { + for _, each := range ii { + s.add(each) + } +} + +func (s *Set) AddInt64(ii ...int64) { + for _, each := range ii { + s.add(each) + } +} + +func (s *Set) AddUint(ii ...uint) { + for _, each := range ii { + s.add(each) + } +} + +func (s *Set) AddUint64(ii ...uint64) { + for _, each := range ii { + s.add(each) + } +} + +func (s *Set) AddStr(ss ...string) { + for _, each := range ss { + s.add(each) + } +} + +func (s *Set) Contains(i interface{}) bool { + if len(s.data) == 0 { + return false + } + + s.validate(i) + _, ok := s.data[i] + return ok +} + +func (s *Set) Keys() []interface{} { + var keys []interface{} + + for key := range s.data { + keys = append(keys, key) + } + + return keys +} + +func (s *Set) KeysInt() []int { + var keys []int + + for key := range s.data { + if intKey, ok := key.(int); !ok { + continue + } else { + keys = append(keys, intKey) + } + } + + return keys +} + +func (s *Set) KeysInt64() []int64 { + var keys []int64 + + for key := range s.data { + if intKey, ok := key.(int64); !ok { + continue + } else { + keys = append(keys, intKey) + } + } + + return keys +} + +func (s *Set) KeysUint() []uint { + var keys []uint + + for key := range s.data { + if intKey, ok := key.(uint); !ok { + continue + } else { + keys = append(keys, intKey) + } + } + + return keys +} + +func (s *Set) KeysUint64() []uint64 { 
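+ // As with the other Keys* accessors, only keys whose dynamic type matches are
+ // collected; keys of other types (possible in an unmanaged set) are skipped.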
+ var keys []uint64 + + for key := range s.data { + if intKey, ok := key.(uint64); !ok { + continue + } else { + keys = append(keys, intKey) + } + } + + return keys +} + +func (s *Set) KeysStr() []string { + var keys []string + + for key := range s.data { + if strKey, ok := key.(string); !ok { + continue + } else { + keys = append(keys, strKey) + } + } + + return keys +} + +func (s *Set) Remove(i interface{}) { + s.validate(i) + delete(s.data, i) +} + +func (s *Set) Count() int { + return len(s.data) +} + +func (s *Set) add(i interface{}) { + switch s.tp { + case unmanaged: + // do nothing + case untyped: + s.setType(i) + default: + s.validate(i) + } + s.data[i] = lang.Placeholder +} + +func (s *Set) setType(i interface{}) { + if s.tp != untyped { + return + } + + switch i.(type) { + case int: + s.tp = intType + case int64: + s.tp = int64Type + case uint: + s.tp = uintType + case uint64: + s.tp = uint64Type + case string: + s.tp = stringType + } +} + +func (s *Set) validate(i interface{}) { + if s.tp == unmanaged { + return + } + + switch i.(type) { + case int: + if s.tp != intType { + logx.Errorf("Error: element is int, but set contains elements with type %d", s.tp) + } + case int64: + if s.tp != int64Type { + logx.Errorf("Error: element is int64, but set contains elements with type %d", s.tp) + } + case uint: + if s.tp != uintType { + logx.Errorf("Error: element is uint, but set contains elements with type %d", s.tp) + } + case uint64: + if s.tp != uint64Type { + logx.Errorf("Error: element is uint64, but set contains elements with type %d", s.tp) + } + case string: + if s.tp != stringType { + logx.Errorf("Error: element is string, but set contains elements with type %d", s.tp) + } + } +} diff --git a/core/collection/set_test.go b/core/collection/set_test.go new file mode 100644 index 00000000..0841d5cd --- /dev/null +++ b/core/collection/set_test.go @@ -0,0 +1,149 @@ +package collection + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" +) + +func BenchmarkRawSet(b *testing.B) { + m := make(map[interface{}]struct{}) + for i := 0; i < b.N; i++ { + m[i] = struct{}{} + _ = m[i] + } +} + +func BenchmarkUnmanagedSet(b *testing.B) { + s := NewUnmanagedSet() + for i := 0; i < b.N; i++ { + s.Add(i) + _ = s.Contains(i) + } +} + +func BenchmarkSet(b *testing.B) { + s := NewSet() + for i := 0; i < b.N; i++ { + s.AddInt(i) + _ = s.Contains(i) + } +} + +func TestAdd(t *testing.T) { + // given + set := NewUnmanagedSet() + values := []interface{}{1, 2, 3} + + // when + set.Add(values...) + + // then + assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3)) + assert.Equal(t, len(values), len(set.Keys())) +} + +func TestAddInt(t *testing.T) { + // given + set := NewSet() + values := []int{1, 2, 3} + + // when + set.AddInt(values...) + + // then + assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3)) + keys := set.KeysInt() + sort.Ints(keys) + assert.EqualValues(t, values, keys) +} + +func TestAddInt64(t *testing.T) { + // given + set := NewSet() + values := []int64{1, 2, 3} + + // when + set.AddInt64(values...) + + // then + assert.True(t, set.Contains(int64(1)) && set.Contains(int64(2)) && set.Contains(int64(3))) + assert.Equal(t, len(values), len(set.KeysInt64())) +} + +func TestAddUint(t *testing.T) { + // given + set := NewSet() + values := []uint{1, 2, 3} + + // when + set.AddUint(values...) 
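+ // note: the set became uintType on the first AddUint, so the lookups below must use
+ // uint values; an untyped int key would miss the entry and log a type-mismatch error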
+ + // then + assert.True(t, set.Contains(uint(1)) && set.Contains(uint(2)) && set.Contains(uint(3))) + assert.Equal(t, len(values), len(set.KeysUint())) +} + +func TestAddUint64(t *testing.T) { + // given + set := NewSet() + values := []uint64{1, 2, 3} + + // when + set.AddUint64(values...) + + // then + assert.True(t, set.Contains(uint64(1)) && set.Contains(uint64(2)) && set.Contains(uint64(3))) + assert.Equal(t, len(values), len(set.KeysUint64())) +} + +func TestAddStr(t *testing.T) { + // given + set := NewSet() + values := []string{"1", "2", "3"} + + // when + set.AddStr(values...) + + // then + assert.True(t, set.Contains("1") && set.Contains("2") && set.Contains("3")) + assert.Equal(t, len(values), len(set.KeysStr())) +} + +func TestContainsWithoutElements(t *testing.T) { + // given + set := NewSet() + + // then + assert.False(t, set.Contains(1)) +} + +func TestContainsUnmanagedWithoutElements(t *testing.T) { + // given + set := NewUnmanagedSet() + + // then + assert.False(t, set.Contains(1)) +} + +func TestRemove(t *testing.T) { + // given + set := NewSet() + set.Add([]interface{}{1, 2, 3}...) + + // when + set.Remove(2) + + // then + assert.True(t, set.Contains(1) && !set.Contains(2) && set.Contains(3)) +} + +func TestCount(t *testing.T) { + // given + set := NewSet() + set.Add([]interface{}{1, 2, 3}...) + + // then + assert.Equal(t, set.Count(), 3) +} diff --git a/core/collection/timingwheel.go b/core/collection/timingwheel.go new file mode 100644 index 00000000..7bb97473 --- /dev/null +++ b/core/collection/timingwheel.go @@ -0,0 +1,311 @@ +package collection + +import ( + "container/list" + "fmt" + "time" + + "zero/core/lang" + "zero/core/threading" + "zero/core/timex" +) + +const drainWorkers = 8 + +type ( + Execute func(key, value interface{}) + + TimingWheel struct { + interval time.Duration + ticker timex.Ticker + slots []*list.List + timers *SafeMap + tickedPos int + numSlots int + execute Execute + setChannel chan timingEntry + moveChannel chan baseEntry + removeChannel chan interface{} + drainChannel chan func(key, value interface{}) + stopChannel chan lang.PlaceholderType + } + + timingEntry struct { + baseEntry + value interface{} + circle int + diff int + removed bool + } + + baseEntry struct { + delay time.Duration + key interface{} + } + + positionEntry struct { + pos int + item *timingEntry + } + + timingTask struct { + key interface{} + value interface{} + } +) + +func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) { + if interval <= 0 || numSlots <= 0 || execute == nil { + return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", interval, numSlots, execute) + } + + return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval)) +} + +func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute, ticker timex.Ticker) ( + *TimingWheel, error) { + tw := &TimingWheel{ + interval: interval, + ticker: ticker, + slots: make([]*list.List, numSlots), + timers: NewSafeMap(), + tickedPos: numSlots - 1, // at previous virtual circle + execute: execute, + numSlots: numSlots, + setChannel: make(chan timingEntry), + moveChannel: make(chan baseEntry), + removeChannel: make(chan interface{}), + drainChannel: make(chan func(key, value interface{})), + stopChannel: make(chan lang.PlaceholderType), + } + + tw.initSlots() + go tw.run() + + return tw, nil +} + +func (tw *TimingWheel) Drain(fn func(key, value interface{})) { + tw.drainChannel <- fn +} + +func (tw *TimingWheel) MoveTimer(key 
interface{}, delay time.Duration) { + if delay <= 0 || key == nil { + return + } + + tw.moveChannel <- baseEntry{ + delay: delay, + key: key, + } +} + +func (tw *TimingWheel) RemoveTimer(key interface{}) { + if key == nil { + return + } + + tw.removeChannel <- key +} + +func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) { + if delay <= 0 || key == nil { + return + } + + tw.setChannel <- timingEntry{ + baseEntry: baseEntry{ + delay: delay, + key: key, + }, + value: value, + } +} + +func (tw *TimingWheel) Stop() { + close(tw.stopChannel) +} + +func (tw *TimingWheel) drainAll(fn func(key, value interface{})) { + runner := threading.NewTaskRunner(drainWorkers) + for _, slot := range tw.slots { + for e := slot.Front(); e != nil; { + task := e.Value.(*timingEntry) + next := e.Next() + slot.Remove(e) + e = next + if !task.removed { + runner.Schedule(func() { + fn(task.key, task.value) + }) + } + } + } +} + +func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos int, circle int) { + steps := int(d / tw.interval) + pos = (tw.tickedPos + steps) % tw.numSlots + circle = (steps - 1) / tw.numSlots + + return +} + +func (tw *TimingWheel) initSlots() { + for i := 0; i < tw.numSlots; i++ { + tw.slots[i] = list.New() + } +} + +func (tw *TimingWheel) moveTask(task baseEntry) { + val, ok := tw.timers.Get(task.key) + if !ok { + return + } + + timer := val.(*positionEntry) + if task.delay < tw.interval { + threading.GoSafe(func() { + tw.execute(timer.item.key, timer.item.value) + }) + return + } + + pos, circle := tw.getPositionAndCircle(task.delay) + if pos >= timer.pos { + timer.item.circle = circle + timer.item.diff = pos - timer.pos + } else if circle > 0 { + circle-- + timer.item.circle = circle + timer.item.diff = tw.numSlots + pos - timer.pos + } else { + timer.item.removed = true + newItem := &timingEntry{ + baseEntry: task, + value: timer.item.value, + } + tw.slots[pos].PushBack(newItem) + tw.setTimerPosition(pos, newItem) + } +} + +func (tw *TimingWheel) onTick() { + tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots + l := tw.slots[tw.tickedPos] + tw.scanAndRunTasks(l) +} + +func (tw *TimingWheel) removeTask(key interface{}) { + val, ok := tw.timers.Get(key) + if !ok { + return + } + + timer := val.(*positionEntry) + timer.item.removed = true +} + +func (tw *TimingWheel) run() { + for { + select { + case <-tw.ticker.Chan(): + tw.onTick() + case task := <-tw.setChannel: + tw.setTask(&task) + case key := <-tw.removeChannel: + tw.removeTask(key) + case task := <-tw.moveChannel: + tw.moveTask(task) + case fn := <-tw.drainChannel: + tw.drainAll(fn) + case <-tw.stopChannel: + tw.ticker.Stop() + return + } + } +} + +func (tw *TimingWheel) runTasks(tasks []timingTask) { + if len(tasks) == 0 { + return + } + + go func() { + for i := range tasks { + threading.RunSafe(func() { + tw.execute(tasks[i].key, tasks[i].value) + }) + } + }() +} + +func (tw *TimingWheel) scanAndRunTasks(l *list.List) { + var tasks []timingTask + + for e := l.Front(); e != nil; { + task := e.Value.(*timingEntry) + if task.removed { + next := e.Next() + l.Remove(e) + tw.timers.Del(task.key) + e = next + continue + } else if task.circle > 0 { + task.circle-- + e = e.Next() + continue + } else if task.diff > 0 { + next := e.Next() + l.Remove(e) + // (tw.tickedPos+task.diff)%tw.numSlots + // cannot be the same value of tw.tickedPos + pos := (tw.tickedPos + task.diff) % tw.numSlots + tw.slots[pos].PushBack(task) + tw.setTimerPosition(pos, task) + task.diff = 0 + e = next + continue + } + + tasks = 
append(tasks, timingTask{ + key: task.key, + value: task.value, + }) + next := e.Next() + l.Remove(e) + tw.timers.Del(task.key) + e = next + } + + tw.runTasks(tasks) +} + +func (tw *TimingWheel) setTask(task *timingEntry) { + if task.delay < tw.interval { + task.delay = tw.interval + } + + if val, ok := tw.timers.Get(task.key); ok { + entry := val.(*positionEntry) + entry.item.value = task.value + tw.moveTask(task.baseEntry) + } else { + pos, circle := tw.getPositionAndCircle(task.delay) + task.circle = circle + tw.slots[pos].PushBack(task) + tw.setTimerPosition(pos, task) + } +} + +func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) { + if val, ok := tw.timers.Get(task.key); ok { + timer := val.(*positionEntry) + timer.pos = pos + } else { + tw.timers.Set(task.key, &positionEntry{ + pos: pos, + item: task, + }) + } +} diff --git a/core/collection/timingwheel_test.go b/core/collection/timingwheel_test.go new file mode 100644 index 00000000..569f5ac8 --- /dev/null +++ b/core/collection/timingwheel_test.go @@ -0,0 +1,593 @@ +package collection + +import ( + "sort" + "sync" + "sync/atomic" + "testing" + "time" + + "zero/core/lang" + "zero/core/stringx" + "zero/core/syncx" + "zero/core/timex" + + "github.com/stretchr/testify/assert" +) + +const ( + testStep = time.Minute + waitTime = time.Second +) + +func TestNewTimingWheel(t *testing.T) { + _, err := NewTimingWheel(0, 10, func(key, value interface{}) {}) + assert.NotNil(t, err) +} + +func TestTimingWheel_Drain(t *testing.T) { + ticker := timex.NewFakeTicker() + tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) { + }, ticker) + defer tw.Stop() + tw.SetTimer("first", 3, testStep*4) + tw.SetTimer("second", 5, testStep*7) + tw.SetTimer("third", 7, testStep*7) + var keys []string + var vals []int + var lock sync.Mutex + var wg sync.WaitGroup + wg.Add(3) + tw.Drain(func(key, value interface{}) { + lock.Lock() + defer lock.Unlock() + keys = append(keys, key.(string)) + vals = append(vals, value.(int)) + wg.Done() + }) + wg.Wait() + sort.Strings(keys) + sort.Ints(vals) + assert.Equal(t, 3, len(keys)) + assert.EqualValues(t, []string{"first", "second", "third"}, keys) + assert.EqualValues(t, []int{3, 5, 7}, vals) + var count int + tw.Drain(func(key, value interface{}) { + count++ + }) + time.Sleep(time.Millisecond * 100) + assert.Equal(t, 0, count) +} + +func TestTimingWheel_SetTimerSoon(t *testing.T) { + run := syncx.NewAtomicBool() + ticker := timex.NewFakeTicker() + tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) { + assert.True(t, run.CompareAndSwap(false, true)) + assert.Equal(t, "any", k) + assert.Equal(t, 3, v.(int)) + ticker.Done() + }, ticker) + defer tw.Stop() + tw.SetTimer("any", 3, testStep>>1) + ticker.Tick() + assert.Nil(t, ticker.Wait(waitTime)) + assert.True(t, run.True()) +} + +func TestTimingWheel_SetTimerTwice(t *testing.T) { + run := syncx.NewAtomicBool() + ticker := timex.NewFakeTicker() + tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) { + assert.True(t, run.CompareAndSwap(false, true)) + assert.Equal(t, "any", k) + assert.Equal(t, 5, v.(int)) + ticker.Done() + }, ticker) + defer tw.Stop() + tw.SetTimer("any", 3, testStep*4) + tw.SetTimer("any", 5, testStep*7) + for i := 0; i < 8; i++ { + ticker.Tick() + } + assert.Nil(t, ticker.Wait(waitTime)) + assert.True(t, run.True()) +} + +func TestTimingWheel_SetTimerWrongDelay(t *testing.T) { + ticker := timex.NewFakeTicker() + tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker) 
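+ // SetTimer ignores non-positive delays (and nil keys) up front, so passing
+ // -testStep below must neither panic nor schedule anything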
+ defer tw.Stop() + assert.NotPanics(t, func() { + tw.SetTimer("any", 3, -testStep) + }) +} + +func TestTimingWheel_MoveTimer(t *testing.T) { + run := syncx.NewAtomicBool() + ticker := timex.NewFakeTicker() + tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) { + assert.True(t, run.CompareAndSwap(false, true)) + assert.Equal(t, "any", k) + assert.Equal(t, 3, v.(int)) + ticker.Done() + }, ticker) + defer tw.Stop() + tw.SetTimer("any", 3, testStep*4) + tw.MoveTimer("any", testStep*7) + tw.MoveTimer("any", -testStep) + tw.MoveTimer("none", testStep) + for i := 0; i < 5; i++ { + ticker.Tick() + } + assert.False(t, run.True()) + for i := 0; i < 3; i++ { + ticker.Tick() + } + assert.Nil(t, ticker.Wait(waitTime)) + assert.True(t, run.True()) +} + +func TestTimingWheel_MoveTimerSoon(t *testing.T) { + run := syncx.NewAtomicBool() + ticker := timex.NewFakeTicker() + tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) { + assert.True(t, run.CompareAndSwap(false, true)) + assert.Equal(t, "any", k) + assert.Equal(t, 3, v.(int)) + ticker.Done() + }, ticker) + defer tw.Stop() + tw.SetTimer("any", 3, testStep*4) + tw.MoveTimer("any", testStep>>1) + assert.Nil(t, ticker.Wait(waitTime)) + assert.True(t, run.True()) +} + +func TestTimingWheel_MoveTimerEarlier(t *testing.T) { + run := syncx.NewAtomicBool() + ticker := timex.NewFakeTicker() + tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) { + assert.True(t, run.CompareAndSwap(false, true)) + assert.Equal(t, "any", k) + assert.Equal(t, 3, v.(int)) + ticker.Done() + }, ticker) + defer tw.Stop() + tw.SetTimer("any", 3, testStep*4) + tw.MoveTimer("any", testStep*2) + for i := 0; i < 3; i++ { + ticker.Tick() + } + assert.Nil(t, ticker.Wait(waitTime)) + assert.True(t, run.True()) +} + +func TestTimingWheel_RemoveTimer(t *testing.T) { + ticker := timex.NewFakeTicker() + tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker) + tw.SetTimer("any", 3, testStep) + assert.NotPanics(t, func() { + tw.RemoveTimer("any") + tw.RemoveTimer("none") + tw.RemoveTimer(nil) + }) + for i := 0; i < 5; i++ { + ticker.Tick() + } + tw.Stop() +} + +func TestTimingWheel_SetTimer(t *testing.T) { + tests := []struct { + slots int + setAt time.Duration + }{ + { + slots: 5, + setAt: 5, + }, + { + slots: 5, + setAt: 7, + }, + { + slots: 5, + setAt: 10, + }, + { + slots: 5, + setAt: 12, + }, + { + slots: 5, + setAt: 7, + }, + { + slots: 5, + setAt: 10, + }, + { + slots: 5, + setAt: 12, + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + var count int32 + ticker := timex.NewFakeTicker() + tick := func() { + atomic.AddInt32(&count, 1) + ticker.Tick() + time.Sleep(time.Millisecond) + } + var actual int32 + done := make(chan lang.PlaceholderType) + tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) { + assert.Equal(t, 1, key.(int)) + assert.Equal(t, 2, value.(int)) + actual = atomic.LoadInt32(&count) + close(done) + }, ticker) + assert.Nil(t, err) + defer tw.Stop() + + tw.SetTimer(1, 2, testStep*test.setAt) + + for { + select { + case <-done: + assert.Equal(t, int32(test.setAt), actual) + return + default: + tick() + } + } + }) + } +} + +func TestTimingWheel_SetAndMoveThenStart(t *testing.T) { + tests := []struct { + slots int + setAt time.Duration + moveAt time.Duration + }{ + { + slots: 5, + setAt: 3, + moveAt: 5, + }, + { + slots: 5, + setAt: 3, + moveAt: 7, + }, + { + slots: 5, + setAt: 3, + moveAt: 10, + }, + { + slots: 5, + setAt: 3, + 
moveAt: 12, + }, + { + slots: 5, + setAt: 5, + moveAt: 7, + }, + { + slots: 5, + setAt: 5, + moveAt: 10, + }, + { + slots: 5, + setAt: 5, + moveAt: 12, + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + var count int32 + ticker := timex.NewFakeTicker() + tick := func() { + atomic.AddInt32(&count, 1) + ticker.Tick() + time.Sleep(time.Millisecond * 10) + } + var actual int32 + done := make(chan lang.PlaceholderType) + tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) { + actual = atomic.LoadInt32(&count) + close(done) + }, ticker) + assert.Nil(t, err) + defer tw.Stop() + + tw.SetTimer(1, 2, testStep*test.setAt) + tw.MoveTimer(1, testStep*test.moveAt) + + for { + select { + case <-done: + assert.Equal(t, int32(test.moveAt), actual) + return + default: + tick() + } + } + }) + } +} + +func TestTimingWheel_SetAndMoveTwice(t *testing.T) { + tests := []struct { + slots int + setAt time.Duration + moveAt time.Duration + moveAgainAt time.Duration + }{ + { + slots: 5, + setAt: 3, + moveAt: 5, + moveAgainAt: 10, + }, + { + slots: 5, + setAt: 3, + moveAt: 7, + moveAgainAt: 12, + }, + { + slots: 5, + setAt: 3, + moveAt: 10, + moveAgainAt: 15, + }, + { + slots: 5, + setAt: 3, + moveAt: 12, + moveAgainAt: 17, + }, + { + slots: 5, + setAt: 5, + moveAt: 7, + moveAgainAt: 12, + }, + { + slots: 5, + setAt: 5, + moveAt: 10, + moveAgainAt: 17, + }, + { + slots: 5, + setAt: 5, + moveAt: 12, + moveAgainAt: 17, + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + var count int32 + ticker := timex.NewFakeTicker() + tick := func() { + atomic.AddInt32(&count, 1) + ticker.Tick() + time.Sleep(time.Millisecond * 10) + } + var actual int32 + done := make(chan lang.PlaceholderType) + tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) { + actual = atomic.LoadInt32(&count) + close(done) + }, ticker) + assert.Nil(t, err) + defer tw.Stop() + + tw.SetTimer(1, 2, testStep*test.setAt) + tw.MoveTimer(1, testStep*test.moveAt) + tw.MoveTimer(1, testStep*test.moveAgainAt) + + for { + select { + case <-done: + assert.Equal(t, int32(test.moveAgainAt), actual) + return + default: + tick() + } + } + }) + } +} + +func TestTimingWheel_ElapsedAndSet(t *testing.T) { + tests := []struct { + slots int + elapsed time.Duration + setAt time.Duration + }{ + { + slots: 5, + elapsed: 3, + setAt: 5, + }, + { + slots: 5, + elapsed: 3, + setAt: 7, + }, + { + slots: 5, + elapsed: 3, + setAt: 10, + }, + { + slots: 5, + elapsed: 3, + setAt: 12, + }, + { + slots: 5, + elapsed: 5, + setAt: 7, + }, + { + slots: 5, + elapsed: 5, + setAt: 10, + }, + { + slots: 5, + elapsed: 5, + setAt: 12, + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + var count int32 + ticker := timex.NewFakeTicker() + tick := func() { + atomic.AddInt32(&count, 1) + ticker.Tick() + time.Sleep(time.Millisecond * 10) + } + var actual int32 + done := make(chan lang.PlaceholderType) + tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) { + actual = atomic.LoadInt32(&count) + close(done) + }, ticker) + assert.Nil(t, err) + defer tw.Stop() + + for i := 0; i < int(test.elapsed); i++ { + tick() + } + + tw.SetTimer(1, 2, testStep*test.setAt) + + for { + select { + case <-done: + assert.Equal(t, int32(test.elapsed+test.setAt), actual) + return + default: + tick() + } + } + }) + } +} + +func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) { + tests := []struct { + slots int + 
elapsed time.Duration + setAt time.Duration + moveAt time.Duration + }{ + { + slots: 5, + elapsed: 3, + setAt: 5, + moveAt: 10, + }, + { + slots: 5, + elapsed: 3, + setAt: 7, + moveAt: 12, + }, + { + slots: 5, + elapsed: 3, + setAt: 10, + moveAt: 15, + }, + { + slots: 5, + elapsed: 3, + setAt: 12, + moveAt: 16, + }, + { + slots: 5, + elapsed: 5, + setAt: 7, + moveAt: 12, + }, + { + slots: 5, + elapsed: 5, + setAt: 10, + moveAt: 15, + }, + { + slots: 5, + elapsed: 5, + setAt: 12, + moveAt: 17, + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + var count int32 + ticker := timex.NewFakeTicker() + tick := func() { + atomic.AddInt32(&count, 1) + ticker.Tick() + time.Sleep(time.Millisecond * 10) + } + var actual int32 + done := make(chan lang.PlaceholderType) + tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) { + actual = atomic.LoadInt32(&count) + close(done) + }, ticker) + assert.Nil(t, err) + defer tw.Stop() + + for i := 0; i < int(test.elapsed); i++ { + tick() + } + + tw.SetTimer(1, 2, testStep*test.setAt) + tw.MoveTimer(1, testStep*test.moveAt) + + for { + select { + case <-done: + assert.Equal(t, int32(test.elapsed+test.moveAt), actual) + return + default: + tick() + } + } + }) + } +} + +func BenchmarkTimingWheel(b *testing.B) { + b.ReportAllocs() + + tw, _ := NewTimingWheel(time.Second, 100, func(k, v interface{}) {}) + for i := 0; i < b.N; i++ { + tw.SetTimer(i, i, time.Second) + tw.SetTimer(b.N+i, b.N+i, time.Second) + tw.MoveTimer(i, time.Second*time.Duration(i)) + tw.RemoveTimer(i) + } +} diff --git a/core/conf/config.go b/core/conf/config.go new file mode 100644 index 00000000..6a316d68 --- /dev/null +++ b/core/conf/config.go @@ -0,0 +1,40 @@ +package conf + +import ( + "fmt" + "io/ioutil" + "log" + "path" + + "zero/core/mapping" +) + +var loaders = map[string]func([]byte, interface{}) error{ + ".json": LoadConfigFromJsonBytes, + ".yaml": LoadConfigFromYamlBytes, + ".yml": LoadConfigFromYamlBytes, +} + +func LoadConfig(file string, v interface{}) error { + if content, err := ioutil.ReadFile(file); err != nil { + return err + } else if loader, ok := loaders[path.Ext(file)]; ok { + return loader(content, v) + } else { + return fmt.Errorf("unrecoginized file type: %s", file) + } +} + +func LoadConfigFromJsonBytes(content []byte, v interface{}) error { + return mapping.UnmarshalJsonBytes(content, v) +} + +func LoadConfigFromYamlBytes(content []byte, v interface{}) error { + return mapping.UnmarshalYamlBytes(content, v) +} + +func MustLoad(path string, v interface{}) { + if err := LoadConfig(path, v); err != nil { + log.Fatalf("error: config file %s, %s", path, err.Error()) + } +} diff --git a/core/conf/properties.go b/core/conf/properties.go new file mode 100644 index 00000000..d24f4158 --- /dev/null +++ b/core/conf/properties.go @@ -0,0 +1,109 @@ +package conf + +import ( + "fmt" + "strconv" + "strings" + "sync" + + "zero/core/iox" +) + +// PropertyError represents a configuration error message. +type PropertyError struct { + error + message string +} + +// Properties interface provides the means to access configuration. +type Properties interface { + GetString(key string) string + SetString(key, value string) + GetInt(key string) int + SetInt(key string, value int) + ToString() string +} + +// Properties config is a key/value pair based configuration structure. 
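+// A properties file accepted by LoadProperties looks like the snippet below (illustrative,
+// mirroring the test fixture): blank lines and lines starting with '#' are skipped, and
+// every remaining line must be a single key=value pair, e.g.
+//
+//	app.name = test
+//	# this is a comment
+//	app.threads = 5
+//
+// which is then read back with GetString("app.name") and GetInt("app.threads").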
+type mapBasedProperties struct { + properties map[string]string + lock sync.RWMutex +} + +// Loads the properties into a properties configuration instance. Returns the +// configuration, or an error that indicates if there was a problem loading the configuration. +func LoadProperties(filename string) (Properties, error) { + lines, err := iox.ReadTextLines(filename, iox.WithoutBlank(), iox.OmitWithPrefix("#")) + if err != nil { + return nil, err + } + + raw := make(map[string]string) + for i := range lines { + pair := strings.Split(lines[i], "=") + if len(pair) != 2 { + // invalid property format + return nil, &PropertyError{ + message: fmt.Sprintf("invalid property format: %s", pair), + } + } + + key := strings.TrimSpace(pair[0]) + value := strings.TrimSpace(pair[1]) + raw[key] = value + } + + return &mapBasedProperties{ + properties: raw, + }, nil +} + +func (config *mapBasedProperties) GetString(key string) string { + config.lock.RLock() + ret := config.properties[key] + config.lock.RUnlock() + + return ret +} + +func (config *mapBasedProperties) SetString(key, value string) { + config.lock.Lock() + config.properties[key] = value + config.lock.Unlock() +} + +func (config *mapBasedProperties) GetInt(key string) int { + config.lock.RLock() + // default 0 + value, _ := strconv.Atoi(config.properties[key]) + config.lock.RUnlock() + + return value +} + +func (config *mapBasedProperties) SetInt(key string, value int) { + config.lock.Lock() + config.properties[key] = strconv.Itoa(value) + config.lock.Unlock() +} + +// Dumps the configuration internal map into a string. +func (config *mapBasedProperties) ToString() string { + config.lock.RLock() + ret := fmt.Sprintf("%s", config.properties) + config.lock.RUnlock() + + return ret +} + +// Returns the error message.
+func (configError *PropertyError) Error() string { + return configError.message +} + +// Builds a new properties configuration structure +func NewProperties() Properties { + return &mapBasedProperties{ + properties: make(map[string]string), + } +} diff --git a/core/conf/properties_test.go b/core/conf/properties_test.go new file mode 100644 index 00000000..1d22631b --- /dev/null +++ b/core/conf/properties_test.go @@ -0,0 +1,44 @@ +package conf + +import ( + "os" + "testing" + + "zero/core/fs" + + "github.com/stretchr/testify/assert" +) + +func TestProperties(t *testing.T) { + text := `app.name = test + + app.program=app + + # this is comment + app.threads = 5` + tmpfile, err := fs.TempFilenameWithText(text) + assert.Nil(t, err) + defer os.Remove(tmpfile) + + props, err := LoadProperties(tmpfile) + assert.Nil(t, err) + assert.Equal(t, "test", props.GetString("app.name")) + assert.Equal(t, "app", props.GetString("app.program")) + assert.Equal(t, 5, props.GetInt("app.threads")) +} + +func TestSetString(t *testing.T) { + key := "a" + value := "the value of a" + props := NewProperties() + props.SetString(key, value) + assert.Equal(t, value, props.GetString(key)) +} + +func TestSetInt(t *testing.T) { + key := "a" + value := 101 + props := NewProperties() + props.SetInt(key, value) + assert.Equal(t, value, props.GetInt(key)) +} diff --git a/core/contextx/deadline.go b/core/contextx/deadline.go new file mode 100644 index 00000000..24bb2bac --- /dev/null +++ b/core/contextx/deadline.go @@ -0,0 +1,17 @@ +package contextx + +import ( + "context" + "time" +) + +func ShrinkDeadline(ctx context.Context, timeout time.Duration) (context.Context, func()) { + if deadline, ok := ctx.Deadline(); ok { + leftTime := time.Until(deadline) + if leftTime < timeout { + timeout = leftTime + } + } + + return context.WithDeadline(ctx, time.Now().Add(timeout)) +} diff --git a/core/contextx/deadline_test.go b/core/contextx/deadline_test.go new file mode 100644 index 00000000..c20a374a --- /dev/null +++ b/core/contextx/deadline_test.go @@ -0,0 +1,27 @@ +package contextx + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestShrinkDeadlineLess(t *testing.T) { + deadline := time.Now().Add(time.Second) + ctx, _ := context.WithDeadline(context.Background(), deadline) + ctx, _ = ShrinkDeadline(ctx, time.Minute) + dl, ok := ctx.Deadline() + assert.True(t, ok) + assert.Equal(t, deadline, dl) +} + +func TestShrinkDeadlineMore(t *testing.T) { + deadline := time.Now().Add(time.Minute) + ctx, _ := context.WithDeadline(context.Background(), deadline) + ctx, _ = ShrinkDeadline(ctx, time.Second) + dl, ok := ctx.Deadline() + assert.True(t, ok) + assert.True(t, dl.Before(deadline)) +} diff --git a/core/contextx/unmarshaler.go b/core/contextx/unmarshaler.go new file mode 100644 index 00000000..ed5f9f28 --- /dev/null +++ b/core/contextx/unmarshaler.go @@ -0,0 +1,26 @@ +package contextx + +import ( + "context" + + "zero/core/mapping" +) + +const contextTagKey = "ctx" + +var unmarshaler = mapping.NewUnmarshaler(contextTagKey) + +type contextValuer struct { + context.Context +} + +func (cv contextValuer) Value(key string) (interface{}, bool) { + v := cv.Context.Value(key) + return v, v != nil +} + +func For(ctx context.Context, v interface{}) error { + return unmarshaler.UnmarshalValuer(contextValuer{ + Context: ctx, + }, v) +} diff --git a/core/contextx/unmarshaler_test.go b/core/contextx/unmarshaler_test.go new file mode 100644 index 00000000..df390236 --- /dev/null +++ 
b/core/contextx/unmarshaler_test.go @@ -0,0 +1,58 @@ +package contextx + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnmarshalContext(t *testing.T) { + type Person struct { + Name string `ctx:"name"` + Age int `ctx:"age"` + } + + ctx := context.Background() + ctx = context.WithValue(ctx, "name", "kevin") + ctx = context.WithValue(ctx, "age", 20) + + var person Person + err := For(ctx, &person) + + assert.Nil(t, err) + assert.Equal(t, "kevin", person.Name) + assert.Equal(t, 20, person.Age) +} + +func TestUnmarshalContextWithOptional(t *testing.T) { + type Person struct { + Name string `ctx:"name"` + Age int `ctx:"age,optional"` + } + + ctx := context.Background() + ctx = context.WithValue(ctx, "name", "kevin") + + var person Person + err := For(ctx, &person) + + assert.Nil(t, err) + assert.Equal(t, "kevin", person.Name) + assert.Equal(t, 0, person.Age) +} + +func TestUnmarshalContextWithMissing(t *testing.T) { + type Person struct { + Name string `ctx:"name"` + Age int `ctx:"age"` + } + + ctx := context.Background() + ctx = context.WithValue(ctx, "name", "kevin") + + var person Person + err := For(ctx, &person) + + assert.NotNil(t, err) +} diff --git a/core/contextx/valueonlycontext.go b/core/contextx/valueonlycontext.go new file mode 100644 index 00000000..627697e1 --- /dev/null +++ b/core/contextx/valueonlycontext.go @@ -0,0 +1,28 @@ +package contextx + +import ( + "context" + "time" +) + +type valueOnlyContext struct { + context.Context +} + +func (valueOnlyContext) Deadline() (deadline time.Time, ok bool) { + return +} + +func (valueOnlyContext) Done() <-chan struct{} { + return nil +} + +func (valueOnlyContext) Err() error { + return nil +} + +func ValueOnlyFrom(ctx context.Context) context.Context { + return valueOnlyContext{ + Context: ctx, + } +} diff --git a/core/contextx/valueonlycontext_test.go b/core/contextx/valueonlycontext_test.go new file mode 100644 index 00000000..fda5f6de --- /dev/null +++ b/core/contextx/valueonlycontext_test.go @@ -0,0 +1,54 @@ +package contextx + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestContextCancel(t *testing.T) { + c := context.WithValue(context.Background(), "key", "value") + c1, cancel := context.WithCancel(c) + o := ValueOnlyFrom(c1) + c2, _ := context.WithCancel(o) + contexts := []context.Context{c1, c2} + + for _, c := range contexts { + assert.NotNil(t, c.Done()) + assert.Nil(t, c.Err()) + + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + } + + cancel() + <-c1.Done() + + assert.Nil(t, o.Err()) + assert.Equal(t, context.Canceled, c1.Err()) + assert.NotEqual(t, context.Canceled, c2.Err()) +} + +func TestConextDeadline(t *testing.T) { + c, _ := context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond)) + o := ValueOnlyFrom(c) + select { + case <-time.After(100 * time.Millisecond): + case <-o.Done(): + t.Fatal("ValueOnlyContext: context should not have timed out") + } + + c, _ = context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond)) + o = ValueOnlyFrom(c) + c, _ = context.WithDeadline(o, time.Now().Add(20*time.Millisecond)) + select { + case <-time.After(100 * time.Millisecond): + t.Fatal("ValueOnlyContext+Deadline: context should have timed out") + case <-c.Done(): + } +} diff --git a/core/discov/clients.go b/core/discov/clients.go new file mode 100644 index 00000000..6f33b019 --- /dev/null +++ b/core/discov/clients.go @@ -0,0 +1,40 
@@ +package discov + +import ( + "fmt" + "strings" + + "zero/core/discov/internal" +) + +const ( + indexOfKey = iota + indexOfId +) + +const timeToLive int64 = 10 + +var TimeToLive = timeToLive + +func extract(etcdKey string, index int) (string, bool) { + if index < 0 { + return "", false + } + + fields := strings.FieldsFunc(etcdKey, func(ch rune) bool { + return ch == internal.Delimiter + }) + if index >= len(fields) { + return "", false + } + + return fields[index], true +} + +func extractId(etcdKey string) (string, bool) { + return extract(etcdKey, indexOfId) +} + +func makeEtcdKey(key string, id int64) string { + return fmt.Sprintf("%s%c%d", key, internal.Delimiter, id) +} diff --git a/core/discov/clients_test.go b/core/discov/clients_test.go new file mode 100644 index 00000000..008ff5bf --- /dev/null +++ b/core/discov/clients_test.go @@ -0,0 +1,36 @@ +package discov + +import ( + "sync" + "testing" + + "zero/core/discov/internal" + + "github.com/stretchr/testify/assert" +) + +var mockLock sync.Mutex + +func setMockClient(cli internal.EtcdClient) func() { + mockLock.Lock() + internal.NewClient = func([]string) (internal.EtcdClient, error) { + return cli, nil + } + return func() { + internal.NewClient = internal.DialClient + mockLock.Unlock() + } +} + +func TestExtract(t *testing.T) { + id, ok := extractId("key/123/val") + assert.True(t, ok) + assert.Equal(t, "123", id) + + _, ok = extract("any", -1) + assert.False(t, ok) +} + +func TestMakeKey(t *testing.T) { + assert.Equal(t, "key/123", makeEtcdKey("key", 123)) +} diff --git a/core/discov/config.go b/core/discov/config.go new file mode 100644 index 00000000..e9da8d10 --- /dev/null +++ b/core/discov/config.go @@ -0,0 +1,18 @@ +package discov + +import "errors" + +type EtcdConf struct { + Hosts []string + Key string +} + +func (c EtcdConf) Validate() error { + if len(c.Hosts) == 0 { + return errors.New("empty etcd hosts") + } else if len(c.Key) == 0 { + return errors.New("empty etcd key") + } else { + return nil + } +} diff --git a/core/discov/config_test.go b/core/discov/config_test.go new file mode 100644 index 00000000..5e1733e2 --- /dev/null +++ b/core/discov/config_test.go @@ -0,0 +1,46 @@ +package discov + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfig(t *testing.T) { + tests := []struct { + EtcdConf + pass bool + }{ + { + EtcdConf: EtcdConf{}, + pass: false, + }, + { + EtcdConf: EtcdConf{ + Key: "any", + }, + pass: false, + }, + { + EtcdConf: EtcdConf{ + Hosts: []string{"any"}, + }, + pass: false, + }, + { + EtcdConf: EtcdConf{ + Hosts: []string{"any"}, + Key: "key", + }, + pass: true, + }, + } + + for _, test := range tests { + if test.pass { + assert.Nil(t, test.EtcdConf.Validate()) + } else { + assert.NotNil(t, test.EtcdConf.Validate()) + } + } +} diff --git a/core/discov/facade.go b/core/discov/facade.go new file mode 100644 index 00000000..e9ed9df2 --- /dev/null +++ b/core/discov/facade.go @@ -0,0 +1,47 @@ +package discov + +import ( + "zero/core/discov/internal" + "zero/core/lang" +) + +type ( + Facade struct { + endpoints []string + registry *internal.Registry + } + + FacadeListener interface { + OnAdd(key, val string) + OnDelete(key string) + } +) + +func NewFacade(endpoints []string) Facade { + return Facade{ + endpoints: endpoints, + registry: internal.GetRegistry(), + } +} + +func (f Facade) Client() internal.EtcdClient { + conn, err := f.registry.GetConn(f.endpoints) + lang.Must(err) + return conn +} + +func (f Facade) Monitor(key string, l FacadeListener) { + 
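+ // Monitor registers the listener with the shared registry for the given etcd key;
+ // l.OnAdd / l.OnDelete are invoked as key-value pairs appear or disappear.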
f.registry.Monitor(f.endpoints, key, listenerAdapter{l}) +} + +type listenerAdapter struct { + l FacadeListener +} + +func (la listenerAdapter) OnAdd(kv internal.KV) { + la.l.OnAdd(kv.Key, kv.Val) +} + +func (la listenerAdapter) OnDelete(kv internal.KV) { + la.l.OnDelete(kv.Key) +} diff --git a/core/discov/internal/balancer.go b/core/discov/internal/balancer.go new file mode 100644 index 00000000..d55133e5 --- /dev/null +++ b/core/discov/internal/balancer.go @@ -0,0 +1,103 @@ +package internal + +import "sync" + +type ( + DialFn func(server string) (interface{}, error) + CloseFn func(server string, conn interface{}) error + + Balancer interface { + AddConn(kv KV) error + IsEmpty() bool + Next(key ...string) (interface{}, bool) + RemoveKey(key string) + initialize() + setListener(listener Listener) + } + + serverConn struct { + key string + conn interface{} + } + + baseBalancer struct { + exclusive bool + servers map[string][]string + mapping map[string]string + lock sync.Mutex + dialFn DialFn + closeFn CloseFn + listener Listener + } +) + +func newBaseBalancer(dialFn DialFn, closeFn CloseFn, exclusive bool) *baseBalancer { + return &baseBalancer{ + exclusive: exclusive, + servers: make(map[string][]string), + mapping: make(map[string]string), + dialFn: dialFn, + closeFn: closeFn, + } +} + +// addKv adds the kv, returns if there are already other keys associate with the server +func (b *baseBalancer) addKv(key, value string) ([]string, bool) { + b.lock.Lock() + defer b.lock.Unlock() + + keys := b.servers[value] + previous := append([]string(nil), keys...) + early := len(keys) > 0 + if b.exclusive && early { + for _, each := range keys { + b.doRemoveKv(each) + } + } + b.servers[value] = append(b.servers[value], key) + b.mapping[key] = value + + if early { + return previous, true + } else { + return nil, false + } +} + +func (b *baseBalancer) doRemoveKv(key string) (server string, keepConn bool) { + server, ok := b.mapping[key] + if !ok { + return "", true + } + + delete(b.mapping, key) + keys := b.servers[server] + remain := keys[:0] + + for _, k := range keys { + if k != key { + remain = append(remain, k) + } + } + + if len(remain) > 0 { + b.servers[server] = remain + return server, true + } else { + delete(b.servers, server) + return server, false + } +} + +func (b *baseBalancer) removeKv(key string) (server string, keepConn bool) { + b.lock.Lock() + defer b.lock.Unlock() + + return b.doRemoveKv(key) +} + +func (b *baseBalancer) setListener(listener Listener) { + b.lock.Lock() + b.listener = listener + b.lock.Unlock() +} diff --git a/core/discov/internal/balancer_test.go b/core/discov/internal/balancer_test.go new file mode 100644 index 00000000..466efc46 --- /dev/null +++ b/core/discov/internal/balancer_test.go @@ -0,0 +1,5 @@ +package internal + +type mockConn struct { + server string +} diff --git a/core/discov/internal/consistentbalancer.go b/core/discov/internal/consistentbalancer.go new file mode 100644 index 00000000..9359ed2b --- /dev/null +++ b/core/discov/internal/consistentbalancer.go @@ -0,0 +1,152 @@ +package internal + +import ( + "zero/core/hash" + "zero/core/logx" +) + +type consistentBalancer struct { + *baseBalancer + conns map[string]interface{} + buckets *hash.ConsistentHash + bucketKey func(KV) string +} + +func NewConsistentBalancer(dialFn DialFn, closeFn CloseFn, keyer func(kv KV) string) *consistentBalancer { + // we don't support exclusive mode for consistent Balancer, to avoid complexity, + // because there are few scenarios, use it on your own risks. 
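+ // Each added kv is mapped to a bucket key (via keyer) and placed on a consistent-hash
+ // ring; Next(key) hashes the caller's key onto that ring and returns the connection
+ // dialed for the chosen bucket.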
+ balancer := &consistentBalancer{ + conns: make(map[string]interface{}), + buckets: hash.NewConsistentHash(), + bucketKey: keyer, + } + balancer.baseBalancer = newBaseBalancer(dialFn, closeFn, false) + return balancer +} + +func (b *consistentBalancer) AddConn(kv KV) error { + // not adding kv and conn within a transaction, but it doesn't matter + // we just rollback the kv addition if dial failed + var conn interface{} + prev, found := b.addKv(kv.Key, kv.Val) + if found { + conn = b.handlePrevious(prev) + } + + if conn == nil { + var err error + conn, err = b.dialFn(kv.Val) + if err != nil { + b.removeKv(kv.Key) + return err + } + } + + bucketKey := b.bucketKey(kv) + b.lock.Lock() + defer b.lock.Unlock() + b.conns[bucketKey] = conn + b.buckets.Add(bucketKey) + b.notify(bucketKey) + + logx.Infof("added server, key: %s, server: %s", bucketKey, kv.Val) + + return nil +} + +func (b *consistentBalancer) getConn(key string) (interface{}, bool) { + b.lock.Lock() + conn, ok := b.conns[key] + b.lock.Unlock() + + return conn, ok +} + +func (b *consistentBalancer) handlePrevious(prev []string) interface{} { + if len(prev) == 0 { + return nil + } + + b.lock.Lock() + defer b.lock.Unlock() + + // if not exclusive, only need to randomly find one connection + for key, conn := range b.conns { + if key == prev[0] { + return conn + } + } + + return nil +} + +func (b *consistentBalancer) initialize() { +} + +func (b *consistentBalancer) notify(key string) { + if b.listener == nil { + return + } + + var keys []string + var values []string + for k := range b.conns { + keys = append(keys, k) + } + for _, v := range b.mapping { + values = append(values, v) + } + + b.listener.OnUpdate(keys, values, key) +} + +func (b *consistentBalancer) RemoveKey(key string) { + kv := KV{Key: key} + server, keep := b.removeKv(key) + kv.Val = server + bucketKey := b.bucketKey(kv) + b.buckets.Remove(b.bucketKey(kv)) + + // wrap the query & removal in a function to make sure the quick lock/unlock + conn, ok := func() (interface{}, bool) { + b.lock.Lock() + defer b.lock.Unlock() + + conn, ok := b.conns[bucketKey] + if ok { + delete(b.conns, bucketKey) + } + + return conn, ok + }() + if ok && !keep { + logx.Infof("removing server, key: %s", kv.Key) + if err := b.closeFn(server, conn); err != nil { + logx.Error(err) + } + } + + // notify without new key + b.notify("") +} + +func (b *consistentBalancer) IsEmpty() bool { + b.lock.Lock() + empty := len(b.conns) == 0 + b.lock.Unlock() + + return empty +} + +func (b *consistentBalancer) Next(keys ...string) (interface{}, bool) { + if len(keys) != 1 { + return nil, false + } + + key := keys[0] + if node, ok := b.buckets.Get(key); !ok { + return nil, false + } else { + return b.getConn(node.(string)) + } +} diff --git a/core/discov/internal/consistentbalancer_test.go b/core/discov/internal/consistentbalancer_test.go new file mode 100644 index 00000000..567c1358 --- /dev/null +++ b/core/discov/internal/consistentbalancer_test.go @@ -0,0 +1,178 @@ +package internal + +import ( + "errors" + "sort" + "strconv" + "testing" + + "zero/core/mathx" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestConsistent_addConn(t *testing.T) { + b := NewConsistentBalancer(func(server string) (interface{}, error) { + return mockConn{ + server: server, + }, nil + }, func(server string, conn interface{}) error { + return errors.New("error") + }, func(kv KV) string { + return kv.Key + }) + assert.Nil(t, b.AddConn(KV{ + Key: "thekey1", + Val: "thevalue", + })) + 
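+ // exclusive mode is off, so the later AddConn for "thekey2" (same server value) is
+ // tracked alongside this one: both keys end up under servers["thevalue"], each with
+ // its own conns entry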
assert.EqualValues(t, map[string]interface{}{ + "thekey1": mockConn{server: "thevalue"}, + }, b.conns) + assert.EqualValues(t, map[string][]string{ + "thevalue": {"thekey1"}, + }, b.servers) + assert.EqualValues(t, map[string]string{ + "thekey1": "thevalue", + }, b.mapping) + assert.Nil(t, b.AddConn(KV{ + Key: "thekey2", + Val: "thevalue", + })) + assert.EqualValues(t, map[string]interface{}{ + "thekey1": mockConn{server: "thevalue"}, + "thekey2": mockConn{server: "thevalue"}, + }, b.conns) + assert.EqualValues(t, map[string][]string{ + "thevalue": {"thekey1", "thekey2"}, + }, b.servers) + assert.EqualValues(t, map[string]string{ + "thekey1": "thevalue", + "thekey2": "thevalue", + }, b.mapping) + assert.False(t, b.IsEmpty()) + + b.RemoveKey("thekey1") + assert.EqualValues(t, map[string]interface{}{ + "thekey2": mockConn{server: "thevalue"}, + }, b.conns) + assert.EqualValues(t, map[string][]string{ + "thevalue": {"thekey2"}, + }, b.servers) + assert.EqualValues(t, map[string]string{ + "thekey2": "thevalue", + }, b.mapping) + assert.False(t, b.IsEmpty()) + + b.RemoveKey("thekey2") + assert.Equal(t, 0, len(b.conns)) + assert.EqualValues(t, map[string][]string{}, b.servers) + assert.EqualValues(t, map[string]string{}, b.mapping) + assert.True(t, b.IsEmpty()) +} + +func TestConsistent_addConnError(t *testing.T) { + b := NewConsistentBalancer(func(server string) (interface{}, error) { + return nil, errors.New("error") + }, func(server string, conn interface{}) error { + return nil + }, func(kv KV) string { + return kv.Key + }) + assert.NotNil(t, b.AddConn(KV{ + Key: "thekey1", + Val: "thevalue", + })) + assert.Equal(t, 0, len(b.conns)) + assert.EqualValues(t, map[string][]string{}, b.servers) + assert.EqualValues(t, map[string]string{}, b.mapping) +} + +func TestConsistent_next(t *testing.T) { + b := NewConsistentBalancer(func(server string) (interface{}, error) { + return mockConn{ + server: server, + }, nil + }, func(server string, conn interface{}) error { + return errors.New("error") + }, func(kv KV) string { + return kv.Key + }) + b.initialize() + + _, ok := b.Next("any") + assert.False(t, ok) + + const size = 100 + for i := 0; i < size; i++ { + assert.Nil(t, b.AddConn(KV{ + Key: "thekey/" + strconv.Itoa(i), + Val: "thevalue/" + strconv.Itoa(i), + })) + } + + m := make(map[interface{}]int) + const total = 10000 + for i := 0; i < total; i++ { + val, ok := b.Next(strconv.Itoa(i)) + assert.True(t, ok) + m[val]++ + } + + entropy := mathx.CalcEntropy(m, total) + assert.Equal(t, size, len(m)) + assert.True(t, entropy > .95) + + for i := 0; i < size; i++ { + b.RemoveKey("thekey/" + strconv.Itoa(i)) + } + _, ok = b.Next() + assert.False(t, ok) +} + +func TestConsistentBalancer_Listener(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + b := NewConsistentBalancer(func(server string) (interface{}, error) { + return mockConn{ + server: server, + }, nil + }, func(server string, conn interface{}) error { + return nil + }, func(kv KV) string { + return kv.Key + }) + assert.Nil(t, b.AddConn(KV{ + Key: "key1", + Val: "val1", + })) + assert.Nil(t, b.AddConn(KV{ + Key: "key2", + Val: "val2", + })) + + listener := NewMockListener(ctrl) + listener.EXPECT().OnUpdate(gomock.Any(), gomock.Any(), "key2").Do(func(keys, vals, _ interface{}) { + sort.Strings(keys.([]string)) + sort.Strings(vals.([]string)) + assert.EqualValues(t, []string{"key1", "key2"}, keys) + assert.EqualValues(t, []string{"val1", "val2"}, vals) + }) + b.setListener(listener) + b.notify("key2") +} + +func 
TestConsistentBalancer_remove(t *testing.T) { + b := NewConsistentBalancer(func(server string) (interface{}, error) { + return mockConn{ + server: server, + }, nil + }, func(server string, conn interface{}) error { + return nil + }, func(kv KV) string { + return kv.Key + }) + + assert.Nil(t, b.handlePrevious(nil)) + assert.Nil(t, b.handlePrevious([]string{"any"})) +} diff --git a/core/discov/internal/etcdclient.go b/core/discov/internal/etcdclient.go new file mode 100644 index 00000000..f1f37914 --- /dev/null +++ b/core/discov/internal/etcdclient.go @@ -0,0 +1,21 @@ +//go:generate mockgen -package internal -destination etcdclient_mock.go -source etcdclient.go EtcdClient +package internal + +import ( + "context" + + "go.etcd.io/etcd/clientv3" + "google.golang.org/grpc" +) + +type EtcdClient interface { + ActiveConnection() *grpc.ClientConn + Close() error + Ctx() context.Context + Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) + Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error) + KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error) + Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) + Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) + Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan +} diff --git a/core/discov/internal/etcdclient_mock.go b/core/discov/internal/etcdclient_mock.go new file mode 100644 index 00000000..2f5af11f --- /dev/null +++ b/core/discov/internal/etcdclient_mock.go @@ -0,0 +1,182 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: etcdclient.go + +// Package internal is a generated GoMock package. 
+package internal + +import ( + context "context" + gomock "github.com/golang/mock/gomock" + clientv3 "go.etcd.io/etcd/clientv3" + grpc "google.golang.org/grpc" + reflect "reflect" +) + +// MockEtcdClient is a mock of EtcdClient interface +type MockEtcdClient struct { + ctrl *gomock.Controller + recorder *MockEtcdClientMockRecorder +} + +// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient +type MockEtcdClientMockRecorder struct { + mock *MockEtcdClient +} + +// NewMockEtcdClient creates a new mock instance +func NewMockEtcdClient(ctrl *gomock.Controller) *MockEtcdClient { + mock := &MockEtcdClient{ctrl: ctrl} + mock.recorder = &MockEtcdClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockEtcdClient) EXPECT() *MockEtcdClientMockRecorder { + return m.recorder +} + +// ActiveConnection mocks base method +func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ActiveConnection") + ret0, _ := ret[0].(*grpc.ClientConn) + return ret0 +} + +// ActiveConnection indicates an expected call of ActiveConnection +func (mr *MockEtcdClientMockRecorder) ActiveConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveConnection", reflect.TypeOf((*MockEtcdClient)(nil).ActiveConnection)) +} + +// Close mocks base method +func (m *MockEtcdClient) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockEtcdClientMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEtcdClient)(nil).Close)) +} + +// Ctx mocks base method +func (m *MockEtcdClient) Ctx() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Ctx") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Ctx indicates an expected call of Ctx +func (mr *MockEtcdClientMockRecorder) Ctx() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ctx", reflect.TypeOf((*MockEtcdClient)(nil).Ctx)) +} + +// Get mocks base method +func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, key} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Get", varargs...) + ret0, _ := ret[0].(*clientv3.GetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get +func (mr *MockEtcdClientMockRecorder) Get(ctx, key interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, key}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEtcdClient)(nil).Get), varargs...) 
+} + +// Grant mocks base method +func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Grant", ctx, ttl) + ret0, _ := ret[0].(*clientv3.LeaseGrantResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Grant indicates an expected call of Grant +func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Grant", reflect.TypeOf((*MockEtcdClient)(nil).Grant), ctx, ttl) +} + +// KeepAlive mocks base method +func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "KeepAlive", ctx, id) + ret0, _ := ret[0].(<-chan *clientv3.LeaseKeepAliveResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// KeepAlive indicates an expected call of KeepAlive +func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeepAlive", reflect.TypeOf((*MockEtcdClient)(nil).KeepAlive), ctx, id) +} + +// Put mocks base method +func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, key, val} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Put", varargs...) + ret0, _ := ret[0].(*clientv3.PutResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Put indicates an expected call of Put +func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, key, val}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEtcdClient)(nil).Put), varargs...) +} + +// Revoke mocks base method +func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Revoke", ctx, id) + ret0, _ := ret[0].(*clientv3.LeaseRevokeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Revoke indicates an expected call of Revoke +func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id) +} + +// Watch mocks base method +func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, key} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Watch", varargs...) + ret0, _ := ret[0].(clientv3.WatchChan) + return ret0 +} + +// Watch indicates an expected call of Watch +func (mr *MockEtcdClientMockRecorder) Watch(ctx, key interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, key}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockEtcdClient)(nil).Watch), varargs...) 
+} diff --git a/core/discov/internal/listener.go b/core/discov/internal/listener.go new file mode 100644 index 00000000..2f416d04 --- /dev/null +++ b/core/discov/internal/listener.go @@ -0,0 +1,6 @@ +//go:generate mockgen -package internal -destination listener_mock.go -source listener.go Listener +package internal + +type Listener interface { + OnUpdate(keys []string, values []string, newKey string) +} diff --git a/core/discov/internal/listener_mock.go b/core/discov/internal/listener_mock.go new file mode 100644 index 00000000..7e81e87b --- /dev/null +++ b/core/discov/internal/listener_mock.go @@ -0,0 +1,45 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: listener.go + +// Package internal is a generated GoMock package. +package internal + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockListener is a mock of Listener interface +type MockListener struct { + ctrl *gomock.Controller + recorder *MockListenerMockRecorder +} + +// MockListenerMockRecorder is the mock recorder for MockListener +type MockListenerMockRecorder struct { + mock *MockListener +} + +// NewMockListener creates a new mock instance +func NewMockListener(ctrl *gomock.Controller) *MockListener { + mock := &MockListener{ctrl: ctrl} + mock.recorder = &MockListenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockListener) EXPECT() *MockListenerMockRecorder { + return m.recorder +} + +// OnUpdate mocks base method +func (m *MockListener) OnUpdate(keys, values []string, newKey string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnUpdate", keys, values, newKey) +} + +// OnUpdate indicates an expected call of OnUpdate +func (mr *MockListenerMockRecorder) OnUpdate(keys, values, newKey interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnUpdate", reflect.TypeOf((*MockListener)(nil).OnUpdate), keys, values, newKey) +} diff --git a/core/discov/internal/registry.go b/core/discov/internal/registry.go new file mode 100644 index 00000000..fdb9f249 --- /dev/null +++ b/core/discov/internal/registry.go @@ -0,0 +1,310 @@ +package internal + +import ( + "context" + "fmt" + "io" + "sort" + "strings" + "sync" + "time" + + "zero/core/contextx" + "zero/core/lang" + "zero/core/logx" + "zero/core/syncx" + "zero/core/threading" + + "go.etcd.io/etcd/clientv3" +) + +var ( + registryInstance = Registry{ + clusters: make(map[string]*cluster), + } + connManager = syncx.NewResourceManager() +) + +type Registry struct { + clusters map[string]*cluster + lock sync.Mutex +} + +func GetRegistry() *Registry { + return ®istryInstance +} + +func (r *Registry) getCluster(endpoints []string) *cluster { + clusterKey := getClusterKey(endpoints) + r.lock.Lock() + defer r.lock.Unlock() + c, ok := r.clusters[clusterKey] + if !ok { + c = newCluster(endpoints) + r.clusters[clusterKey] = c + } + + return c +} + +func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) { + return r.getCluster(endpoints).getClient() +} + +func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener) error { + return r.getCluster(endpoints).monitor(key, l) +} + +type cluster struct { + endpoints []string + key string + values map[string]map[string]string + listeners map[string][]UpdateListener + watchGroup *threading.RoutineGroup + done chan lang.PlaceholderType + lock sync.Mutex +} + +func newCluster(endpoints []string) *cluster { + return &cluster{ + endpoints: endpoints, + key: 
getClusterKey(endpoints), + values: make(map[string]map[string]string), + listeners: make(map[string][]UpdateListener), + watchGroup: threading.NewRoutineGroup(), + done: make(chan lang.PlaceholderType), + } +} + +func (c *cluster) context(cli EtcdClient) context.Context { + return contextx.ValueOnlyFrom(cli.Ctx()) +} + +func (c *cluster) getClient() (EtcdClient, error) { + val, err := connManager.GetResource(c.key, func() (io.Closer, error) { + return c.newClient() + }) + if err != nil { + return nil, err + } + + return val.(EtcdClient), nil +} + +func (c *cluster) handleChanges(key string, kvs []KV) { + var add []KV + var remove []KV + c.lock.Lock() + listeners := append([]UpdateListener(nil), c.listeners[key]...) + vals, ok := c.values[key] + if !ok { + add = kvs + vals = make(map[string]string) + for _, kv := range kvs { + vals[kv.Key] = kv.Val + } + c.values[key] = vals + } else { + m := make(map[string]string) + for _, kv := range kvs { + m[kv.Key] = kv.Val + } + for k, v := range vals { + if val, ok := m[k]; !ok || v != val { + remove = append(remove, KV{ + Key: k, + Val: v, + }) + } + } + for k, v := range m { + if val, ok := vals[k]; !ok || v != val { + add = append(add, KV{ + Key: k, + Val: v, + }) + } + } + c.values[key] = m + } + c.lock.Unlock() + + for _, kv := range add { + for _, l := range listeners { + l.OnAdd(kv) + } + } + for _, kv := range remove { + for _, l := range listeners { + l.OnDelete(kv) + } + } +} + +func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) { + c.lock.Lock() + listeners := append([]UpdateListener(nil), c.listeners[key]...) + c.lock.Unlock() + + for _, ev := range events { + switch ev.Type { + case clientv3.EventTypePut: + c.lock.Lock() + if vals, ok := c.values[key]; ok { + vals[string(ev.Kv.Key)] = string(ev.Kv.Value) + } else { + c.values[key] = map[string]string{string(ev.Kv.Key): string(ev.Kv.Value)} + } + c.lock.Unlock() + for _, l := range listeners { + l.OnAdd(KV{ + Key: string(ev.Kv.Key), + Val: string(ev.Kv.Value), + }) + } + case clientv3.EventTypeDelete: + if vals, ok := c.values[key]; ok { + delete(vals, string(ev.Kv.Key)) + } + for _, l := range listeners { + l.OnDelete(KV{ + Key: string(ev.Kv.Key), + Val: string(ev.Kv.Value), + }) + } + default: + logx.Errorf("Unknown event type: %v", ev.Type) + } + } +} + +func (c *cluster) load(cli EtcdClient, key string) { + var resp *clientv3.GetResponse + for { + var err error + ctx, cancel := context.WithTimeout(c.context(cli), RequestTimeout) + resp, err = cli.Get(ctx, makeKeyPrefix(key), clientv3.WithPrefix()) + cancel() + if err == nil { + break + } + + logx.Error(err) + time.Sleep(coolDownInterval) + } + + var kvs []KV + c.lock.Lock() + for _, ev := range resp.Kvs { + kvs = append(kvs, KV{ + Key: string(ev.Key), + Val: string(ev.Value), + }) + } + c.lock.Unlock() + + c.handleChanges(key, kvs) +} + +func (c *cluster) monitor(key string, l UpdateListener) error { + c.lock.Lock() + c.listeners[key] = append(c.listeners[key], l) + c.lock.Unlock() + + cli, err := c.getClient() + if err != nil { + return err + } + + c.load(cli, key) + c.watchGroup.Run(func() { + c.watch(cli, key) + }) + + return nil +} + +func (c *cluster) newClient() (EtcdClient, error) { + cli, err := NewClient(c.endpoints) + if err != nil { + return nil, err + } + + go c.watchConnState(cli) + + return cli, nil +} + +func (c *cluster) reload(cli EtcdClient) { + c.lock.Lock() + close(c.done) + c.watchGroup.Wait() + c.done = make(chan lang.PlaceholderType) + c.watchGroup = threading.NewRoutineGroup() + var 
keys []string + for k := range c.listeners { + keys = append(keys, k) + } + c.lock.Unlock() + + for _, key := range keys { + k := key + c.watchGroup.Run(func() { + c.load(cli, k) + c.watch(cli, k) + }) + } +} + +func (c *cluster) watch(cli EtcdClient, key string) { + rch := cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix()) + for { + select { + case wresp, ok := <-rch: + if !ok { + logx.Error("etcd monitor chan has been closed") + return + } + if wresp.Canceled { + logx.Error("etcd monitor chan has been canceled") + return + } + if wresp.Err() != nil { + logx.Error(fmt.Sprintf("etcd monitor chan error: %v", wresp.Err())) + return + } + + c.handleWatchEvents(key, wresp.Events) + case <-c.done: + return + } + } +} + +func (c *cluster) watchConnState(cli EtcdClient) { + watcher := newStateWatcher() + watcher.addListener(func() { + go c.reload(cli) + }) + watcher.watch(cli.ActiveConnection()) +} + +func DialClient(endpoints []string) (EtcdClient, error) { + return clientv3.New(clientv3.Config{ + Endpoints: endpoints, + AutoSyncInterval: autoSyncInterval, + DialTimeout: DialTimeout, + DialKeepAliveTime: dialKeepAliveTime, + DialKeepAliveTimeout: DialTimeout, + RejectOldCluster: true, + }) +} + +func getClusterKey(endpoints []string) string { + sort.Strings(endpoints) + return strings.Join(endpoints, endpointsSeparator) +} + +func makeKeyPrefix(key string) string { + return fmt.Sprintf("%s%c", key, Delimiter) +} diff --git a/core/discov/internal/registry_test.go b/core/discov/internal/registry_test.go new file mode 100644 index 00000000..14966898 --- /dev/null +++ b/core/discov/internal/registry_test.go @@ -0,0 +1,245 @@ +package internal + +import ( + "context" + "sync" + "testing" + + "zero/core/contextx" + "zero/core/stringx" + + "github.com/coreos/etcd/mvcc/mvccpb" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/clientv3" +) + +var mockLock sync.Mutex + +func setMockClient(cli EtcdClient) func() { + mockLock.Lock() + NewClient = func([]string) (EtcdClient, error) { + return cli, nil + } + return func() { + NewClient = DialClient + mockLock.Unlock() + } +} + +func TestGetCluster(t *testing.T) { + c1 := GetRegistry().getCluster([]string{"first"}) + c2 := GetRegistry().getCluster([]string{"second"}) + c3 := GetRegistry().getCluster([]string{"first"}) + assert.Equal(t, c1, c3) + assert.NotEqual(t, c1, c2) +} + +func TestGetClusterKey(t *testing.T) { + assert.Equal(t, getClusterKey([]string{"localhost:1234", "remotehost:5678"}), + getClusterKey([]string{"remotehost:5678", "localhost:1234"})) +} + +func TestCluster_HandleChanges(t *testing.T) { + ctrl := gomock.NewController(t) + l := NewMockUpdateListener(ctrl) + l.EXPECT().OnAdd(KV{ + Key: "first", + Val: "1", + }) + l.EXPECT().OnAdd(KV{ + Key: "second", + Val: "2", + }) + l.EXPECT().OnDelete(KV{ + Key: "first", + Val: "1", + }) + l.EXPECT().OnDelete(KV{ + Key: "second", + Val: "2", + }) + l.EXPECT().OnAdd(KV{ + Key: "third", + Val: "3", + }) + l.EXPECT().OnAdd(KV{ + Key: "fourth", + Val: "4", + }) + c := newCluster([]string{"any"}) + c.listeners["any"] = []UpdateListener{l} + c.handleChanges("any", []KV{ + { + Key: "first", + Val: "1", + }, + { + Key: "second", + Val: "2", + }, + }) + assert.EqualValues(t, map[string]string{ + "first": "1", + "second": "2", + }, c.values["any"]) + c.handleChanges("any", []KV{ + { + Key: "third", + Val: "3", + }, + { + Key: "fourth", + Val: "4", + }, + }) + assert.EqualValues(t, map[string]string{ + "third": "3", + 
"fourth": "4", + }, c.values["any"]) +} + +func TestCluster_Load(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cli := NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + cli.EXPECT().Get(gomock.Any(), "any/", gomock.Any()).Return(&clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{ + { + Key: []byte("hello"), + Value: []byte("world"), + }, + }, + }, nil) + cli.EXPECT().Ctx().Return(context.Background()) + c := &cluster{ + values: make(map[string]map[string]string), + } + c.load(cli, "any") +} + +func TestCluster_Watch(t *testing.T) { + tests := []struct { + name string + method int + eventType mvccpb.Event_EventType + }{ + { + name: "add", + eventType: clientv3.EventTypePut, + }, + { + name: "delete", + eventType: clientv3.EventTypeDelete, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cli := NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + ch := make(chan clientv3.WatchResponse) + cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch) + cli.EXPECT().Ctx().Return(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + c := &cluster{ + listeners: make(map[string][]UpdateListener), + values: make(map[string]map[string]string), + } + listener := NewMockUpdateListener(ctrl) + c.listeners["any"] = []UpdateListener{listener} + listener.EXPECT().OnAdd(gomock.Any()).Do(func(kv KV) { + assert.Equal(t, "hello", kv.Key) + assert.Equal(t, "world", kv.Val) + wg.Done() + }).MaxTimes(1) + listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ interface{}) { + wg.Done() + }).MaxTimes(1) + go c.watch(cli, "any") + ch <- clientv3.WatchResponse{ + Events: []*clientv3.Event{ + { + Type: test.eventType, + Kv: &mvccpb.KeyValue{ + Key: []byte("hello"), + Value: []byte("world"), + }, + }, + }, + } + wg.Wait() + }) + } +} + +func TestClusterWatch_RespFailures(t *testing.T) { + resps := []clientv3.WatchResponse{ + { + Canceled: true, + }, + { + // cause resp.Err() != nil + CompactRevision: 1, + }, + } + + for _, resp := range resps { + t.Run(stringx.Rand(), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cli := NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + ch := make(chan clientv3.WatchResponse) + cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch) + cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() + c := new(cluster) + go func() { + ch <- resp + }() + c.watch(cli, "any") + }) + } +} + +func TestClusterWatch_CloseChan(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cli := NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + ch := make(chan clientv3.WatchResponse) + cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch) + cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() + c := new(cluster) + go func() { + close(ch) + }() + c.watch(cli, "any") +} + +func TestValueOnlyContext(t *testing.T) { + ctx := contextx.ValueOnlyFrom(context.Background()) + ctx.Done() + assert.Nil(t, ctx.Err()) +} + +type mockedSharedCalls struct { + fn func() (interface{}, error) +} + +func (c mockedSharedCalls) Do(_ string, fn func() (interface{}, error)) (interface{}, error) { + return c.fn() +} + +func (c mockedSharedCalls) DoEx(_ string, fn func() (interface{}, error)) (interface{}, bool, error) { + val, err := c.fn() + return val, true, err +} diff --git a/core/discov/internal/roundrobinbalancer.go 
b/core/discov/internal/roundrobinbalancer.go new file mode 100644 index 00000000..c908efae --- /dev/null +++ b/core/discov/internal/roundrobinbalancer.go @@ -0,0 +1,148 @@ +package internal + +import ( + "math/rand" + "time" + + "zero/core/logx" +) + +type roundRobinBalancer struct { + *baseBalancer + conns []serverConn + index int +} + +func NewRoundRobinBalancer(dialFn DialFn, closeFn CloseFn, exclusive bool) *roundRobinBalancer { + balancer := new(roundRobinBalancer) + balancer.baseBalancer = newBaseBalancer(dialFn, closeFn, exclusive) + return balancer +} + +func (b *roundRobinBalancer) AddConn(kv KV) error { + var conn interface{} + prev, found := b.addKv(kv.Key, kv.Val) + if found { + conn = b.handlePrevious(prev, kv.Val) + } + + if conn == nil { + var err error + conn, err = b.dialFn(kv.Val) + if err != nil { + b.removeKv(kv.Key) + return err + } + } + + b.lock.Lock() + defer b.lock.Unlock() + b.conns = append(b.conns, serverConn{ + key: kv.Key, + conn: conn, + }) + b.notify(kv.Key) + + return nil +} + +func (b *roundRobinBalancer) handlePrevious(prev []string, server string) interface{} { + if len(prev) == 0 { + return nil + } + + b.lock.Lock() + defer b.lock.Unlock() + + if b.exclusive { + for _, item := range prev { + conns := b.conns[:0] + for _, each := range b.conns { + if each.key == item { + if err := b.closeFn(server, each.conn); err != nil { + logx.Error(err) + } + } else { + conns = append(conns, each) + } + } + b.conns = conns + } + } else { + for _, each := range b.conns { + if each.key == prev[0] { + return each.conn + } + } + } + + return nil +} + +func (b *roundRobinBalancer) initialize() { + rand.Seed(time.Now().UnixNano()) + if len(b.conns) > 0 { + b.index = rand.Intn(len(b.conns)) + } +} + +func (b *roundRobinBalancer) IsEmpty() bool { + b.lock.Lock() + empty := len(b.conns) == 0 + b.lock.Unlock() + + return empty +} + +func (b *roundRobinBalancer) Next(...string) (interface{}, bool) { + b.lock.Lock() + defer b.lock.Unlock() + + if len(b.conns) == 0 { + return nil, false + } + + b.index = (b.index + 1) % len(b.conns) + return b.conns[b.index].conn, true +} + +func (b *roundRobinBalancer) notify(key string) { + if b.listener == nil { + return + } + + // b.servers has the format of map[conn][]key + var keys []string + var values []string + for k, v := range b.servers { + values = append(values, k) + keys = append(keys, v...) + } + + b.listener.OnUpdate(keys, values, key) +} + +func (b *roundRobinBalancer) RemoveKey(key string) { + server, keep := b.removeKv(key) + + b.lock.Lock() + defer b.lock.Unlock() + + conns := b.conns[:0] + for _, conn := range b.conns { + if conn.key == key { + // there are other keys assocated with the conn, don't close the conn. 
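+ // keep is reported by removeKv in the base balancer and is true while other keys
+ // still reference the same server, so the shared conn must stay open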
+ if keep { + continue + } + if err := b.closeFn(server, conn.conn); err != nil { + logx.Error(err) + } + } else { + conns = append(conns, conn) + } + } + b.conns = conns + // notify without new key + b.notify("") +} diff --git a/core/discov/internal/roundrobinbalancer_test.go b/core/discov/internal/roundrobinbalancer_test.go new file mode 100644 index 00000000..ef3037e1 --- /dev/null +++ b/core/discov/internal/roundrobinbalancer_test.go @@ -0,0 +1,321 @@ +package internal + +import ( + "errors" + "sort" + "strconv" + "testing" + + "zero/core/logx" + "zero/core/mathx" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func init() { + logx.Disable() +} + +func TestRoundRobin_addConn(t *testing.T) { + b := NewRoundRobinBalancer(func(server string) (interface{}, error) { + return mockConn{ + server: server, + }, nil + }, func(server string, conn interface{}) error { + return errors.New("error") + }, false) + assert.Nil(t, b.AddConn(KV{ + Key: "thekey1", + Val: "thevalue", + })) + assert.EqualValues(t, []serverConn{ + { + key: "thekey1", + conn: mockConn{server: "thevalue"}, + }, + }, b.conns) + assert.EqualValues(t, map[string][]string{ + "thevalue": {"thekey1"}, + }, b.servers) + assert.EqualValues(t, map[string]string{ + "thekey1": "thevalue", + }, b.mapping) + assert.Nil(t, b.AddConn(KV{ + Key: "thekey2", + Val: "thevalue", + })) + assert.EqualValues(t, []serverConn{ + { + key: "thekey1", + conn: mockConn{server: "thevalue"}, + }, + { + key: "thekey2", + conn: mockConn{server: "thevalue"}, + }, + }, b.conns) + assert.EqualValues(t, map[string][]string{ + "thevalue": {"thekey1", "thekey2"}, + }, b.servers) + assert.EqualValues(t, map[string]string{ + "thekey1": "thevalue", + "thekey2": "thevalue", + }, b.mapping) + assert.False(t, b.IsEmpty()) + + b.RemoveKey("thekey1") + assert.EqualValues(t, []serverConn{ + { + key: "thekey2", + conn: mockConn{server: "thevalue"}, + }, + }, b.conns) + assert.EqualValues(t, map[string][]string{ + "thevalue": {"thekey2"}, + }, b.servers) + assert.EqualValues(t, map[string]string{ + "thekey2": "thevalue", + }, b.mapping) + assert.False(t, b.IsEmpty()) + + b.RemoveKey("thekey2") + assert.EqualValues(t, []serverConn{}, b.conns) + assert.EqualValues(t, map[string][]string{}, b.servers) + assert.EqualValues(t, map[string]string{}, b.mapping) + assert.True(t, b.IsEmpty()) +} + +func TestRoundRobin_addConnExclusive(t *testing.T) { + b := NewRoundRobinBalancer(func(server string) (interface{}, error) { + return mockConn{ + server: server, + }, nil + }, func(server string, conn interface{}) error { + return nil + }, true) + assert.Nil(t, b.AddConn(KV{ + Key: "thekey1", + Val: "thevalue", + })) + assert.EqualValues(t, []serverConn{ + { + key: "thekey1", + conn: mockConn{server: "thevalue"}, + }, + }, b.conns) + assert.EqualValues(t, map[string][]string{ + "thevalue": {"thekey1"}, + }, b.servers) + assert.EqualValues(t, map[string]string{ + "thekey1": "thevalue", + }, b.mapping) + assert.Nil(t, b.AddConn(KV{ + Key: "thekey2", + Val: "thevalue", + })) + assert.EqualValues(t, []serverConn{ + { + key: "thekey2", + conn: mockConn{server: "thevalue"}, + }, + }, b.conns) + assert.EqualValues(t, map[string][]string{ + "thevalue": {"thekey2"}, + }, b.servers) + assert.EqualValues(t, map[string]string{ + "thekey2": "thevalue", + }, b.mapping) + assert.False(t, b.IsEmpty()) + + b.RemoveKey("thekey1") + b.RemoveKey("thekey2") + assert.EqualValues(t, []serverConn{}, b.conns) + assert.EqualValues(t, map[string][]string{}, b.servers) + 
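+ // in exclusive mode the key/server mapping should be empty as well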
assert.EqualValues(t, map[string]string{}, b.mapping) + assert.True(t, b.IsEmpty()) +} + +func TestRoundRobin_addConnDupExclusive(t *testing.T) { + b := NewRoundRobinBalancer(func(server string) (interface{}, error) { + return mockConn{ + server: server, + }, nil + }, func(server string, conn interface{}) error { + return errors.New("error") + }, true) + assert.Nil(t, b.AddConn(KV{ + Key: "thekey1", + Val: "thevalue", + })) + assert.EqualValues(t, []serverConn{ + { + key: "thekey1", + conn: mockConn{server: "thevalue"}, + }, + }, b.conns) + assert.EqualValues(t, map[string][]string{ + "thevalue": {"thekey1"}, + }, b.servers) + assert.EqualValues(t, map[string]string{ + "thekey1": "thevalue", + }, b.mapping) + assert.Nil(t, b.AddConn(KV{ + Key: "thekey", + Val: "anothervalue", + })) + assert.Nil(t, b.AddConn(KV{ + Key: "thekey1", + Val: "thevalue", + })) + assert.EqualValues(t, []serverConn{ + { + key: "thekey", + conn: mockConn{server: "anothervalue"}, + }, + { + key: "thekey1", + conn: mockConn{server: "thevalue"}, + }, + }, b.conns) + assert.EqualValues(t, map[string][]string{ + "anothervalue": {"thekey"}, + "thevalue": {"thekey1"}, + }, b.servers) + assert.EqualValues(t, map[string]string{ + "thekey": "anothervalue", + "thekey1": "thevalue", + }, b.mapping) + assert.False(t, b.IsEmpty()) + + b.RemoveKey("thekey") + b.RemoveKey("thekey1") + assert.EqualValues(t, []serverConn{}, b.conns) + assert.EqualValues(t, map[string][]string{}, b.servers) + assert.EqualValues(t, map[string]string{}, b.mapping) + assert.True(t, b.IsEmpty()) +} + +func TestRoundRobin_addConnError(t *testing.T) { + b := NewRoundRobinBalancer(func(server string) (interface{}, error) { + return nil, errors.New("error") + }, func(server string, conn interface{}) error { + return nil + }, true) + assert.NotNil(t, b.AddConn(KV{ + Key: "thekey1", + Val: "thevalue", + })) + assert.Nil(t, b.conns) + assert.EqualValues(t, map[string][]string{}, b.servers) + assert.EqualValues(t, map[string]string{}, b.mapping) +} + +func TestRoundRobin_initialize(t *testing.T) { + b := NewRoundRobinBalancer(func(server string) (interface{}, error) { + return mockConn{ + server: server, + }, nil + }, func(server string, conn interface{}) error { + return nil + }, true) + for i := 0; i < 100; i++ { + assert.Nil(t, b.AddConn(KV{ + Key: "thekey/" + strconv.Itoa(i), + Val: "thevalue/" + strconv.Itoa(i), + })) + } + + m := make(map[int]int) + const total = 1000 + for i := 0; i < total; i++ { + b.initialize() + m[b.index]++ + } + + mi := make(map[interface{}]int, len(m)) + for k, v := range m { + mi[k] = v + } + entropy := mathx.CalcEntropy(mi, total) + assert.True(t, entropy > .95) +} + +func TestRoundRobin_next(t *testing.T) { + b := NewRoundRobinBalancer(func(server string) (interface{}, error) { + return mockConn{ + server: server, + }, nil + }, func(server string, conn interface{}) error { + return errors.New("error") + }, true) + const size = 100 + for i := 0; i < size; i++ { + assert.Nil(t, b.AddConn(KV{ + Key: "thekey/" + strconv.Itoa(i), + Val: "thevalue/" + strconv.Itoa(i), + })) + } + + m := make(map[interface{}]int) + const total = 10000 + for i := 0; i < total; i++ { + val, ok := b.Next() + assert.True(t, ok) + m[val]++ + } + + entropy := mathx.CalcEntropy(m, total) + assert.Equal(t, size, len(m)) + assert.True(t, entropy > .95) + + for i := 0; i < size; i++ { + b.RemoveKey("thekey/" + strconv.Itoa(i)) + } + _, ok := b.Next() + assert.False(t, ok) +} + +func TestRoundRobinBalancer_Listener(t *testing.T) { + ctrl := gomock.NewController(t) 
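+ // Finish verifies that all expected calls on the mock listener were made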
+ defer ctrl.Finish() + b := NewRoundRobinBalancer(func(server string) (interface{}, error) { + return mockConn{ + server: server, + }, nil + }, func(server string, conn interface{}) error { + return nil + }, true) + assert.Nil(t, b.AddConn(KV{ + Key: "key1", + Val: "val1", + })) + assert.Nil(t, b.AddConn(KV{ + Key: "key2", + Val: "val2", + })) + + listener := NewMockListener(ctrl) + listener.EXPECT().OnUpdate(gomock.Any(), gomock.Any(), "key2").Do(func(keys, vals, _ interface{}) { + sort.Strings(vals.([]string)) + sort.Strings(keys.([]string)) + assert.EqualValues(t, []string{"key1", "key2"}, keys) + assert.EqualValues(t, []string{"val1", "val2"}, vals) + }) + b.setListener(listener) + b.notify("key2") +} + +func TestRoundRobinBalancer_remove(t *testing.T) { + b := NewRoundRobinBalancer(func(server string) (interface{}, error) { + return mockConn{ + server: server, + }, nil + }, func(server string, conn interface{}) error { + return nil + }, true) + + assert.Nil(t, b.handlePrevious(nil, "any")) + _, ok := b.doRemoveKv("any") + assert.True(t, ok) +} diff --git a/core/discov/internal/statewatcher.go b/core/discov/internal/statewatcher.go new file mode 100644 index 00000000..d1b727fa --- /dev/null +++ b/core/discov/internal/statewatcher.go @@ -0,0 +1,58 @@ +//go:generate mockgen -package internal -destination statewatcher_mock.go -source statewatcher.go etcdConn +package internal + +import ( + "context" + "sync" + + "google.golang.org/grpc/connectivity" +) + +type ( + etcdConn interface { + GetState() connectivity.State + WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool + } + + stateWatcher struct { + disconnected bool + currentState connectivity.State + listeners []func() + lock sync.Mutex + } +) + +func newStateWatcher() *stateWatcher { + return new(stateWatcher) +} + +func (sw *stateWatcher) addListener(l func()) { + sw.lock.Lock() + sw.listeners = append(sw.listeners, l) + sw.lock.Unlock() +} + +func (sw *stateWatcher) watch(conn etcdConn) { + sw.currentState = conn.GetState() + for { + if conn.WaitForStateChange(context.Background(), sw.currentState) { + newState := conn.GetState() + sw.lock.Lock() + sw.currentState = newState + + switch newState { + case connectivity.TransientFailure, connectivity.Shutdown: + sw.disconnected = true + case connectivity.Ready: + if sw.disconnected { + sw.disconnected = false + for _, l := range sw.listeners { + l() + } + } + } + + sw.lock.Unlock() + } + } +} diff --git a/core/discov/internal/statewatcher_mock.go b/core/discov/internal/statewatcher_mock.go new file mode 100644 index 00000000..7c0ebe3d --- /dev/null +++ b/core/discov/internal/statewatcher_mock.go @@ -0,0 +1,63 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: statewatcher.go + +// Package internal is a generated GoMock package. 
+package internal + +import ( + context "context" + gomock "github.com/golang/mock/gomock" + connectivity "google.golang.org/grpc/connectivity" + reflect "reflect" +) + +// MocketcdConn is a mock of etcdConn interface +type MocketcdConn struct { + ctrl *gomock.Controller + recorder *MocketcdConnMockRecorder +} + +// MocketcdConnMockRecorder is the mock recorder for MocketcdConn +type MocketcdConnMockRecorder struct { + mock *MocketcdConn +} + +// NewMocketcdConn creates a new mock instance +func NewMocketcdConn(ctrl *gomock.Controller) *MocketcdConn { + mock := &MocketcdConn{ctrl: ctrl} + mock.recorder = &MocketcdConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MocketcdConn) EXPECT() *MocketcdConnMockRecorder { + return m.recorder +} + +// GetState mocks base method +func (m *MocketcdConn) GetState() connectivity.State { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetState") + ret0, _ := ret[0].(connectivity.State) + return ret0 +} + +// GetState indicates an expected call of GetState +func (mr *MocketcdConnMockRecorder) GetState() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MocketcdConn)(nil).GetState)) +} + +// WaitForStateChange mocks base method +func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WaitForStateChange", ctx, sourceState) + ret0, _ := ret[0].(bool) + return ret0 +} + +// WaitForStateChange indicates an expected call of WaitForStateChange +func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForStateChange", reflect.TypeOf((*MocketcdConn)(nil).WaitForStateChange), ctx, sourceState) +} diff --git a/core/discov/internal/statewatcher_test.go b/core/discov/internal/statewatcher_test.go new file mode 100644 index 00000000..64761dbf --- /dev/null +++ b/core/discov/internal/statewatcher_test.go @@ -0,0 +1,27 @@ +package internal + +import ( + "sync" + "testing" + + "github.com/golang/mock/gomock" + "google.golang.org/grpc/connectivity" +) + +func TestStateWatcher_watch(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + watcher := newStateWatcher() + var wg sync.WaitGroup + wg.Add(1) + watcher.addListener(func() { + wg.Done() + }) + conn := NewMocketcdConn(ctrl) + conn.EXPECT().GetState().Return(connectivity.Ready) + conn.EXPECT().GetState().Return(connectivity.TransientFailure) + conn.EXPECT().GetState().Return(connectivity.Ready).AnyTimes() + conn.EXPECT().WaitForStateChange(gomock.Any(), gomock.Any()).Return(true).AnyTimes() + go watcher.watch(conn) + wg.Wait() +} diff --git a/core/discov/internal/updatelistener.go b/core/discov/internal/updatelistener.go new file mode 100644 index 00000000..b1bbcd3e --- /dev/null +++ b/core/discov/internal/updatelistener.go @@ -0,0 +1,14 @@ +//go:generate mockgen -package internal -destination updatelistener_mock.go -source updatelistener.go UpdateListener +package internal + +type ( + KV struct { + Key string + Val string + } + + UpdateListener interface { + OnAdd(kv KV) + OnDelete(kv KV) + } +) diff --git a/core/discov/internal/updatelistener_mock.go b/core/discov/internal/updatelistener_mock.go new file mode 100644 index 00000000..392faa78 --- /dev/null +++ b/core/discov/internal/updatelistener_mock.go @@ -0,0 +1,57 @@ +// Code 
generated by MockGen. DO NOT EDIT. +// Source: updatelistener.go + +// Package internal is a generated GoMock package. +package internal + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockUpdateListener is a mock of UpdateListener interface +type MockUpdateListener struct { + ctrl *gomock.Controller + recorder *MockUpdateListenerMockRecorder +} + +// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener +type MockUpdateListenerMockRecorder struct { + mock *MockUpdateListener +} + +// NewMockUpdateListener creates a new mock instance +func NewMockUpdateListener(ctrl *gomock.Controller) *MockUpdateListener { + mock := &MockUpdateListener{ctrl: ctrl} + mock.recorder = &MockUpdateListenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockUpdateListener) EXPECT() *MockUpdateListenerMockRecorder { + return m.recorder +} + +// OnAdd mocks base method +func (m *MockUpdateListener) OnAdd(kv KV) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnAdd", kv) +} + +// OnAdd indicates an expected call of OnAdd +func (mr *MockUpdateListenerMockRecorder) OnAdd(kv interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAdd", reflect.TypeOf((*MockUpdateListener)(nil).OnAdd), kv) +} + +// OnDelete mocks base method +func (m *MockUpdateListener) OnDelete(kv KV) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnDelete", kv) +} + +// OnDelete indicates an expected call of OnDelete +func (mr *MockUpdateListenerMockRecorder) OnDelete(kv interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDelete", reflect.TypeOf((*MockUpdateListener)(nil).OnDelete), kv) +} diff --git a/core/discov/internal/vars.go b/core/discov/internal/vars.go new file mode 100644 index 00000000..719612e6 --- /dev/null +++ b/core/discov/internal/vars.go @@ -0,0 +1,19 @@ +package internal + +import "time" + +const ( + autoSyncInterval = time.Minute + coolDownInterval = time.Second + dialTimeout = 5 * time.Second + dialKeepAliveTime = 5 * time.Second + requestTimeout = 3 * time.Second + Delimiter = '/' + endpointsSeparator = "," +) + +var ( + DialTimeout = dialTimeout + RequestTimeout = requestTimeout + NewClient = DialClient +) diff --git a/core/discov/kubernetes/discov-namespace.yaml b/core/discov/kubernetes/discov-namespace.yaml new file mode 100644 index 00000000..16b397bc --- /dev/null +++ b/core/discov/kubernetes/discov-namespace.yaml @@ -0,0 +1,4 @@ +apiVersion: v1 +kind: Namespace +metadata: + name: discov diff --git a/core/discov/kubernetes/etcd.yaml b/core/discov/kubernetes/etcd.yaml new file mode 100644 index 00000000..dcfaef7c --- /dev/null +++ b/core/discov/kubernetes/etcd.yaml @@ -0,0 +1,378 @@ +apiVersion: v1 +kind: Service +metadata: + name: etcd + namespace: discov +spec: + ports: + - name: etcd-port + port: 2379 + protocol: TCP + targetPort: 2379 + selector: + app: etcd + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + etcd_node: etcd0 + name: etcd0 + namespace: discov +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - etcd0 + - --initial-advertise-peer-urls + - http://etcd0:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://etcd0:2379 + - --initial-cluster + - 
etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 + - --initial-cluster-state + - new + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest + name: etcd0 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + imagePullSecrets: + - name: aliyun + affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: app + operator: In + values: + - etcd + topologyKey: "kubernetes.io/hostname" + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + etcd_node: etcd0 + name: etcd0 + namespace: discov +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + etcd_node: etcd0 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + etcd_node: etcd1 + name: etcd1 + namespace: discov +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - etcd1 + - --initial-advertise-peer-urls + - http://etcd1:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://etcd1:2379 + - --initial-cluster + - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 + - --initial-cluster-state + - new + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest + name: etcd1 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + imagePullSecrets: + - name: aliyun + affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: app + operator: In + values: + - etcd + topologyKey: "kubernetes.io/hostname" + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + etcd_node: etcd1 + name: etcd1 + namespace: discov +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + etcd_node: etcd1 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + etcd_node: etcd2 + name: etcd2 + namespace: discov +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - etcd2 + - --initial-advertise-peer-urls + - http://etcd2:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://etcd2:2379 + - --initial-cluster + - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 + - --initial-cluster-state + - new + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest + name: etcd2 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + imagePullSecrets: + - name: aliyun + affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: app + operator: In + values: + - etcd + topologyKey: "kubernetes.io/hostname" + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + etcd_node: etcd2 + name: etcd2 + namespace: discov +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + 
targetPort: 2380 + selector: + etcd_node: etcd2 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + etcd_node: etcd3 + name: etcd3 + namespace: discov +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - etcd3 + - --initial-advertise-peer-urls + - http://etcd3:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://etcd3:2379 + - --initial-cluster + - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 + - --initial-cluster-state + - new + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest + name: etcd3 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + imagePullSecrets: + - name: aliyun + affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: app + operator: In + values: + - etcd + topologyKey: "kubernetes.io/hostname" + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + etcd_node: etcd3 + name: etcd3 + namespace: discov +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + etcd_node: etcd3 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + etcd_node: etcd4 + name: etcd4 + namespace: discov +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - etcd4 + - --initial-advertise-peer-urls + - http://etcd4:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://etcd4:2379 + - --initial-cluster + - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 + - --initial-cluster-state + - new + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest + name: etcd4 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + imagePullSecrets: + - name: aliyun + affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: app + operator: In + values: + - etcd + topologyKey: "kubernetes.io/hostname" + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + etcd_node: etcd4 + name: etcd4 + namespace: discov +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + etcd_node: etcd4 diff --git a/core/discov/publisher.go b/core/discov/publisher.go new file mode 100644 index 00000000..955f5f9c --- /dev/null +++ b/core/discov/publisher.go @@ -0,0 +1,143 @@ +package discov + +import ( + "zero/core/discov/internal" + "zero/core/lang" + "zero/core/logx" + "zero/core/proc" + "zero/core/syncx" + "zero/core/threading" + + "go.etcd.io/etcd/clientv3" +) + +type ( + PublisherOption func(client *Publisher) + + Publisher struct { + endpoints []string + key string + fullKey string + id int64 + value string + lease clientv3.LeaseID + quit *syncx.DoneChan + pauseChan chan lang.PlaceholderType + resumeChan chan lang.PlaceholderType + } +) + +func NewPublisher(endpoints []string, key, value string, opts ...PublisherOption) *Publisher { + publisher := &Publisher{ + endpoints: endpoints, + key: key, + value: 
value, + quit: syncx.NewDoneChan(), + pauseChan: make(chan lang.PlaceholderType), + resumeChan: make(chan lang.PlaceholderType), + } + + for _, opt := range opts { + opt(publisher) + } + + return publisher +} + +func (p *Publisher) KeepAlive() error { + cli, err := internal.GetRegistry().GetConn(p.endpoints) + if err != nil { + return err + } + + p.lease, err = p.register(cli) + if err != nil { + return err + } + + proc.AddWrapUpListener(func() { + p.Stop() + }) + + return p.keepAliveAsync(cli) +} + +func (p *Publisher) Pause() { + p.pauseChan <- lang.Placeholder +} + +func (p *Publisher) Resume() { + p.resumeChan <- lang.Placeholder +} + +func (p *Publisher) Stop() { + p.quit.Close() +} + +func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error { + ch, err := cli.KeepAlive(cli.Ctx(), p.lease) + if err != nil { + return err + } + + threading.GoSafe(func() { + for { + select { + case _, ok := <-ch: + if !ok { + p.revoke(cli) + if err := p.KeepAlive(); err != nil { + logx.Errorf("KeepAlive: %s", err.Error()) + } + return + } + case <-p.pauseChan: + logx.Infof("paused etcd renew, key: %s, value: %s", p.key, p.value) + p.revoke(cli) + select { + case <-p.resumeChan: + if err := p.KeepAlive(); err != nil { + logx.Errorf("KeepAlive: %s", err.Error()) + } + return + case <-p.quit.Done(): + return + } + case <-p.quit.Done(): + p.revoke(cli) + return + } + } + }) + + return nil +} + +func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, error) { + resp, err := client.Grant(client.Ctx(), TimeToLive) + if err != nil { + return clientv3.NoLease, err + } + + lease := resp.ID + if p.id > 0 { + p.fullKey = makeEtcdKey(p.key, p.id) + } else { + p.fullKey = makeEtcdKey(p.key, int64(lease)) + } + _, err = client.Put(client.Ctx(), p.fullKey, p.value, clientv3.WithLease(lease)) + + return lease, err +} + +func (p *Publisher) revoke(cli internal.EtcdClient) { + if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil { + logx.Error(err) + } +} + +func WithId(id int64) PublisherOption { + return func(publisher *Publisher) { + publisher.id = id + } +} diff --git a/core/discov/publisher_test.go b/core/discov/publisher_test.go new file mode 100644 index 00000000..350f6154 --- /dev/null +++ b/core/discov/publisher_test.go @@ -0,0 +1,151 @@ +package discov + +import ( + "errors" + "sync" + "testing" + + "zero/core/discov/internal" + "zero/core/logx" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/clientv3" +) + +func init() { + logx.Disable() +} + +func TestPublisher_register(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + const id = 1 + cli := internal.NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + cli.EXPECT().Ctx().AnyTimes() + cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{ + ID: id, + }, nil) + cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any()) + pub := NewPublisher(nil, "thekey", "thevalue") + _, err := pub.register(cli) + assert.Nil(t, err) +} + +func TestPublisher_registerWithId(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + const id = 2 + cli := internal.NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + cli.EXPECT().Ctx().AnyTimes() + cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{ + ID: 1, + }, nil) + cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any()) + pub := NewPublisher(nil, "thekey", 
"thevalue", WithId(id)) + _, err := pub.register(cli) + assert.Nil(t, err) +} + +func TestPublisher_registerError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cli := internal.NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + cli.EXPECT().Ctx().AnyTimes() + cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(nil, errors.New("error")) + pub := NewPublisher(nil, "thekey", "thevalue") + val, err := pub.register(cli) + assert.NotNil(t, err) + assert.Equal(t, clientv3.NoLease, val) +} + +func TestPublisher_revoke(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + const id clientv3.LeaseID = 1 + cli := internal.NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + cli.EXPECT().Ctx().AnyTimes() + cli.EXPECT().Revoke(gomock.Any(), id) + pub := NewPublisher(nil, "thekey", "thevalue") + pub.lease = id + pub.revoke(cli) +} + +func TestPublisher_revokeError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + const id clientv3.LeaseID = 1 + cli := internal.NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + cli.EXPECT().Ctx().AnyTimes() + cli.EXPECT().Revoke(gomock.Any(), id).Return(nil, errors.New("error")) + pub := NewPublisher(nil, "thekey", "thevalue") + pub.lease = id + pub.revoke(cli) +} + +func TestPublisher_keepAliveAsyncError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + const id clientv3.LeaseID = 1 + cli := internal.NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + cli.EXPECT().Ctx().AnyTimes() + cli.EXPECT().KeepAlive(gomock.Any(), id).Return(nil, errors.New("error")) + pub := NewPublisher(nil, "thekey", "thevalue") + pub.lease = id + assert.NotNil(t, pub.keepAliveAsync(cli)) +} + +func TestPublisher_keepAliveAsyncQuit(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + const id clientv3.LeaseID = 1 + cli := internal.NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + cli.EXPECT().Ctx().AnyTimes() + cli.EXPECT().KeepAlive(gomock.Any(), id) + var wg sync.WaitGroup + wg.Add(1) + cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) { + wg.Done() + }) + pub := NewPublisher(nil, "thekey", "thevalue") + pub.lease = id + pub.Stop() + assert.Nil(t, pub.keepAliveAsync(cli)) + wg.Wait() +} + +func TestPublisher_keepAliveAsyncPause(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + const id clientv3.LeaseID = 1 + cli := internal.NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + cli.EXPECT().Ctx().AnyTimes() + cli.EXPECT().KeepAlive(gomock.Any(), id) + pub := NewPublisher(nil, "thekey", "thevalue") + var wg sync.WaitGroup + wg.Add(1) + cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) { + pub.Stop() + wg.Done() + }) + pub.lease = id + assert.Nil(t, pub.keepAliveAsync(cli)) + pub.Pause() + wg.Wait() +} diff --git a/core/discov/renewer.go b/core/discov/renewer.go new file mode 100644 index 00000000..c0d56ba5 --- /dev/null +++ b/core/discov/renewer.go @@ -0,0 +1,35 @@ +package discov + +import "zero/core/logx" + +type ( + Renewer interface { + Start() + Stop() + Pause() + Resume() + } + + etcdRenewer struct { + *Publisher + } +) + +func NewRenewer(endpoints []string, key, value string, renewId int64) Renewer { + var publisher *Publisher + if renewId > 0 { + publisher = NewPublisher(endpoints, key, value, WithId(renewId)) + } else { + publisher = NewPublisher(endpoints, key, value) + } + + 
return &etcdRenewer{ + Publisher: publisher, + } +} + +func (sr *etcdRenewer) Start() { + if err := sr.KeepAlive(); err != nil { + logx.Error(err) + } +} diff --git a/core/discov/subclient.go b/core/discov/subclient.go new file mode 100644 index 00000000..c3290d30 --- /dev/null +++ b/core/discov/subclient.go @@ -0,0 +1,186 @@ +package discov + +import ( + "sync" + + "zero/core/discov/internal" + "zero/core/logx" +) + +const ( + _ = iota // keyBasedBalance, default + idBasedBalance +) + +type ( + Listener internal.Listener + + subClient struct { + balancer internal.Balancer + lock sync.Mutex + cond *sync.Cond + listeners []internal.Listener + } + + balanceOptions struct { + balanceType int + } + + BalanceOption func(*balanceOptions) + + RoundRobinSubClient struct { + *subClient + } + + ConsistentSubClient struct { + *subClient + } + + BatchConsistentSubClient struct { + *ConsistentSubClient + } +) + +func NewRoundRobinSubClient(endpoints []string, key string, dialFn internal.DialFn, closeFn internal.CloseFn, + opts ...SubOption) (*RoundRobinSubClient, error) { + var subOpts subOptions + for _, opt := range opts { + opt(&subOpts) + } + + cli, err := newSubClient(endpoints, key, internal.NewRoundRobinBalancer(dialFn, closeFn, subOpts.exclusive)) + if err != nil { + return nil, err + } + + return &RoundRobinSubClient{ + subClient: cli, + }, nil +} + +func NewConsistentSubClient(endpoints []string, key string, dialFn internal.DialFn, + closeFn internal.CloseFn, opts ...BalanceOption) (*ConsistentSubClient, error) { + var balanceOpts balanceOptions + for _, opt := range opts { + opt(&balanceOpts) + } + + var keyer func(internal.KV) string + switch balanceOpts.balanceType { + case idBasedBalance: + keyer = func(kv internal.KV) string { + if id, ok := extractId(kv.Key); ok { + return id + } else { + return kv.Key + } + } + default: + keyer = func(kv internal.KV) string { + return kv.Val + } + } + + cli, err := newSubClient(endpoints, key, internal.NewConsistentBalancer(dialFn, closeFn, keyer)) + if err != nil { + return nil, err + } + + return &ConsistentSubClient{ + subClient: cli, + }, nil +} + +func NewBatchConsistentSubClient(endpoints []string, key string, dialFn internal.DialFn, closeFn internal.CloseFn, + opts ...BalanceOption) (*BatchConsistentSubClient, error) { + cli, err := NewConsistentSubClient(endpoints, key, dialFn, closeFn, opts...) 
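+ // the batch client reuses the consistent client's placement; construction errors are returned as-is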
+ if err != nil { + return nil, err + } + + return &BatchConsistentSubClient{ + ConsistentSubClient: cli, + }, nil +} + +func newSubClient(endpoints []string, key string, balancer internal.Balancer) (*subClient, error) { + client := &subClient{ + balancer: balancer, + } + client.cond = sync.NewCond(&client.lock) + if err := internal.GetRegistry().Monitor(endpoints, key, client); err != nil { + return nil, err + } + + return client, nil +} + +func (c *subClient) AddListener(listener internal.Listener) { + c.lock.Lock() + c.listeners = append(c.listeners, listener) + c.lock.Unlock() +} + +func (c *subClient) OnAdd(kv internal.KV) { + c.lock.Lock() + defer c.lock.Unlock() + + if err := c.balancer.AddConn(kv); err != nil { + logx.Error(err) + } else { + c.cond.Broadcast() + } +} + +func (c *subClient) OnDelete(kv internal.KV) { + c.balancer.RemoveKey(kv.Key) +} + +func (c *subClient) WaitForServers() { + logx.Error("Waiting for alive servers") + c.lock.Lock() + defer c.lock.Unlock() + + if c.balancer.IsEmpty() { + c.cond.Wait() + } +} + +func (c *subClient) onAdd(keys []string, servers []string, newKey string) { + // guarded by locked outside + for _, listener := range c.listeners { + listener.OnUpdate(keys, servers, newKey) + } +} + +func (c *RoundRobinSubClient) Next() (interface{}, bool) { + return c.balancer.Next() +} + +func (c *ConsistentSubClient) Next(key string) (interface{}, bool) { + return c.balancer.Next(key) +} + +func (bc *BatchConsistentSubClient) Next(keys []string) (map[interface{}][]string, bool) { + if len(keys) == 0 { + return nil, false + } + + result := make(map[interface{}][]string) + for _, key := range keys { + dest, ok := bc.ConsistentSubClient.Next(key) + if !ok { + return nil, false + } + + result[dest] = append(result[dest], key) + } + + return result, true +} + +func BalanceWithId() BalanceOption { + return func(opts *balanceOptions) { + opts.balanceType = idBasedBalance + } +} diff --git a/core/discov/subscriber.go b/core/discov/subscriber.go new file mode 100644 index 00000000..d5a295fc --- /dev/null +++ b/core/discov/subscriber.go @@ -0,0 +1,151 @@ +package discov + +import ( + "sync" + + "zero/core/discov/internal" +) + +type ( + subOptions struct { + exclusive bool + } + + SubOption func(opts *subOptions) + + Subscriber struct { + items *container + } +) + +func NewSubscriber(endpoints []string, key string, opts ...SubOption) *Subscriber { + var subOpts subOptions + for _, opt := range opts { + opt(&subOpts) + } + + subscriber := &Subscriber{ + items: newContainer(subOpts.exclusive), + } + internal.GetRegistry().Monitor(endpoints, key, subscriber.items) + + return subscriber +} + +func (s *Subscriber) Values() []string { + return s.items.getValues() +} + +// exclusive means that key value can only be 1:1, +// which means later added value will remove the keys associated with the same value previously. 
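+//
+// A minimal usage sketch (the endpoints and key below are placeholders):
+//
+//	sub := NewSubscriber([]string{"127.0.0.1:2379"}, "service.key", Exclusive())
+//	values := sub.Values() // the values currently published under the key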
+func Exclusive() SubOption { + return func(opts *subOptions) { + opts.exclusive = true + } +} + +type container struct { + exclusive bool + values map[string][]string + mapping map[string]string + lock sync.Mutex +} + +func newContainer(exclusive bool) *container { + return &container{ + exclusive: exclusive, + values: make(map[string][]string), + mapping: make(map[string]string), + } +} + +func (c *container) OnAdd(kv internal.KV) { + c.addKv(kv.Key, kv.Val) +} + +func (c *container) OnDelete(kv internal.KV) { + c.removeKey(kv.Key) +} + +// addKv adds the kv, returns if there are already other keys associate with the value +func (c *container) addKv(key, value string) ([]string, bool) { + c.lock.Lock() + defer c.lock.Unlock() + + keys := c.values[value] + previous := append([]string(nil), keys...) + early := len(keys) > 0 + if c.exclusive && early { + for _, each := range keys { + c.doRemoveKey(each) + } + } + c.values[value] = append(c.values[value], key) + c.mapping[key] = value + + if early { + return previous, true + } else { + return nil, false + } +} + +func (c *container) doRemoveKey(key string) { + server, ok := c.mapping[key] + if !ok { + return + } + + delete(c.mapping, key) + keys := c.values[server] + remain := keys[:0] + + for _, k := range keys { + if k != key { + remain = append(remain, k) + } + } + + if len(remain) > 0 { + c.values[server] = remain + } else { + delete(c.values, server) + } +} + +func (c *container) getValues() []string { + c.lock.Lock() + defer c.lock.Unlock() + + var vs []string + for each := range c.values { + vs = append(vs, each) + } + return vs +} + +// removeKey removes the kv, returns true if there are still other keys associate with the value +func (c *container) removeKey(key string) { + c.lock.Lock() + defer c.lock.Unlock() + + c.doRemoveKey(key) +} + +func (c *container) removeVal(val string) (empty bool) { + c.lock.Lock() + defer c.lock.Unlock() + + for k := range c.values { + if k == val { + delete(c.values, k) + } + } + for k, v := range c.mapping { + if v == val { + delete(c.mapping, k) + } + } + + return len(c.values) == 0 +} diff --git a/core/errorx/atomicerror.go b/core/errorx/atomicerror.go new file mode 100644 index 00000000..76f874e5 --- /dev/null +++ b/core/errorx/atomicerror.go @@ -0,0 +1,21 @@ +package errorx + +import "sync" + +type AtomicError struct { + err error + lock sync.Mutex +} + +func (ae *AtomicError) Set(err error) { + ae.lock.Lock() + ae.err = err + ae.lock.Unlock() +} + +func (ae *AtomicError) Load() error { + ae.lock.Lock() + err := ae.err + ae.lock.Unlock() + return err +} diff --git a/core/errorx/atomicerror_test.go b/core/errorx/atomicerror_test.go new file mode 100644 index 00000000..fa19cdb2 --- /dev/null +++ b/core/errorx/atomicerror_test.go @@ -0,0 +1,21 @@ +package errorx + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +var errDummy = errors.New("hello") + +func TestAtomicError(t *testing.T) { + var err AtomicError + err.Set(errDummy) + assert.Equal(t, errDummy, err.Load()) +} + +func TestAtomicErrorNil(t *testing.T) { + var err AtomicError + assert.Nil(t, err.Load()) +} diff --git a/core/errorx/batcherror.go b/core/errorx/batcherror.go new file mode 100644 index 00000000..dc83ab36 --- /dev/null +++ b/core/errorx/batcherror.go @@ -0,0 +1,45 @@ +package errorx + +import "bytes" + +type ( + BatchError struct { + errs errorArray + } + + errorArray []error +) + +func (be *BatchError) Add(err error) { + if err != nil { + be.errs = append(be.errs, err) + } +} + +func (be 
*BatchError) Err() error { + switch len(be.errs) { + case 0: + return nil + case 1: + return be.errs[0] + default: + return be.errs + } +} + +func (be *BatchError) NotNil() bool { + return len(be.errs) > 0 +} + +func (ea errorArray) Error() string { + var buf bytes.Buffer + + for i := range ea { + if i > 0 { + buf.WriteByte('\n') + } + buf.WriteString(ea[i].Error()) + } + + return buf.String() +} diff --git a/core/errorx/batcherror_test.go b/core/errorx/batcherror_test.go new file mode 100644 index 00000000..ae5c8c3e --- /dev/null +++ b/core/errorx/batcherror_test.go @@ -0,0 +1,48 @@ +package errorx + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + err1 = "first error" + err2 = "second error" +) + +func TestBatchErrorNil(t *testing.T) { + var batch BatchError + assert.Nil(t, batch.Err()) + assert.False(t, batch.NotNil()) + batch.Add(nil) + assert.Nil(t, batch.Err()) + assert.False(t, batch.NotNil()) +} + +func TestBatchErrorNilFromFunc(t *testing.T) { + err := func() error { + var be BatchError + return be.Err() + }() + assert.True(t, err == nil) +} + +func TestBatchErrorOneError(t *testing.T) { + var batch BatchError + batch.Add(errors.New(err1)) + assert.NotNil(t, batch) + assert.Equal(t, err1, batch.Err().Error()) + assert.True(t, batch.NotNil()) +} + +func TestBatchErrorWithErrors(t *testing.T) { + var batch BatchError + batch.Add(errors.New(err1)) + batch.Add(errors.New(err2)) + assert.NotNil(t, batch) + assert.Equal(t, fmt.Sprintf("%s\n%s", err1, err2), batch.Err().Error()) + assert.True(t, batch.NotNil()) +} diff --git a/core/executors/bulkexecutor.go b/core/executors/bulkexecutor.go new file mode 100644 index 00000000..b7ca0232 --- /dev/null +++ b/core/executors/bulkexecutor.go @@ -0,0 +1,93 @@ +package executors + +import ( + "time" +) + +const defaultBulkTasks = 1000 + +type ( + BulkOption func(options *bulkOptions) + + BulkExecutor struct { + executor *PeriodicalExecutor + container *bulkContainer + } + + bulkOptions struct { + cachedTasks int + flushInterval time.Duration + } +) + +func NewBulkExecutor(execute Execute, opts ...BulkOption) *BulkExecutor { + options := newBulkOptions() + for _, opt := range opts { + opt(&options) + } + + container := &bulkContainer{ + execute: execute, + maxTasks: options.cachedTasks, + } + executor := &BulkExecutor{ + executor: NewPeriodicalExecutor(options.flushInterval, container), + container: container, + } + + return executor +} + +func (be *BulkExecutor) Add(task interface{}) error { + be.executor.Add(task) + return nil +} + +func (be *BulkExecutor) Flush() { + be.executor.Flush() +} + +func (be *BulkExecutor) Wait() { + be.executor.Wait() +} + +func WithBulkTasks(tasks int) BulkOption { + return func(options *bulkOptions) { + options.cachedTasks = tasks + } +} + +func WithBulkInterval(duration time.Duration) BulkOption { + return func(options *bulkOptions) { + options.flushInterval = duration + } +} + +func newBulkOptions() bulkOptions { + return bulkOptions{ + cachedTasks: defaultBulkTasks, + flushInterval: defaultFlushInterval, + } +} + +type bulkContainer struct { + tasks []interface{} + execute Execute + maxTasks int +} + +func (bc *bulkContainer) AddTask(task interface{}) bool { + bc.tasks = append(bc.tasks, task) + return len(bc.tasks) >= bc.maxTasks +} + +func (bc *bulkContainer) Execute(tasks interface{}) { + vals := tasks.([]interface{}) + bc.execute(vals) +} + +func (bc *bulkContainer) RemoveAll() interface{} { + tasks := bc.tasks + bc.tasks = nil + return tasks +} diff --git 
a/core/executors/bulkexecutor_test.go b/core/executors/bulkexecutor_test.go new file mode 100644 index 00000000..7645a257 --- /dev/null +++ b/core/executors/bulkexecutor_test.go @@ -0,0 +1,113 @@ +package executors + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBulkExecutor(t *testing.T) { + var values []int + var lock sync.Mutex + + exeutor := NewBulkExecutor(func(items []interface{}) { + lock.Lock() + values = append(values, len(items)) + lock.Unlock() + }, WithBulkTasks(10), WithBulkInterval(time.Minute)) + + for i := 0; i < 50; i++ { + exeutor.Add(1) + time.Sleep(time.Millisecond) + } + + lock.Lock() + assert.True(t, len(values) > 0) + // ignore last value + for i := 0; i < len(values); i++ { + assert.Equal(t, 10, values[i]) + } + lock.Unlock() +} + +func TestBulkExecutorFlushInterval(t *testing.T) { + const ( + caches = 10 + size = 5 + ) + var wait sync.WaitGroup + + wait.Add(1) + exeutor := NewBulkExecutor(func(items []interface{}) { + assert.Equal(t, size, len(items)) + wait.Done() + }, WithBulkTasks(caches), WithBulkInterval(time.Millisecond*100)) + + for i := 0; i < size; i++ { + exeutor.Add(1) + } + + wait.Wait() +} + +func TestBulkExecutorEmpty(t *testing.T) { + NewBulkExecutor(func(items []interface{}) { + assert.Fail(t, "should not called") + }, WithBulkTasks(10), WithBulkInterval(time.Millisecond)) + time.Sleep(time.Millisecond * 100) +} + +func TestBulkExecutorFlush(t *testing.T) { + const ( + caches = 10 + tasks = 5 + ) + + var wait sync.WaitGroup + wait.Add(1) + be := NewBulkExecutor(func(items []interface{}) { + assert.Equal(t, tasks, len(items)) + wait.Done() + }, WithBulkTasks(caches), WithBulkInterval(time.Minute)) + for i := 0; i < tasks; i++ { + be.Add(1) + } + be.Flush() + wait.Wait() +} + +func TestBuldExecutorFlushSlowTasks(t *testing.T) { + const total = 1500 + lock := new(sync.Mutex) + result := make([]interface{}, 0, 10000) + exec := NewBulkExecutor(func(tasks []interface{}) { + time.Sleep(time.Millisecond * 100) + lock.Lock() + defer lock.Unlock() + for _, i := range tasks { + result = append(result, i) + } + }, WithBulkTasks(1000)) + for i := 0; i < total; i++ { + assert.Nil(t, exec.Add(i)) + } + + exec.Flush() + exec.Wait() + assert.Equal(t, total, len(result)) +} + +func BenchmarkBulkExecutor(b *testing.B) { + b.ReportAllocs() + + be := NewBulkExecutor(func(tasks []interface{}) { + time.Sleep(time.Millisecond * time.Duration(len(tasks))) + }) + for i := 0; i < b.N; i++ { + time.Sleep(time.Microsecond * 200) + be.Add(1) + } + be.Flush() +} diff --git a/core/executors/chunkexecutor.go b/core/executors/chunkexecutor.go new file mode 100644 index 00000000..226e0e8a --- /dev/null +++ b/core/executors/chunkexecutor.go @@ -0,0 +1,103 @@ +package executors + +import "time" + +const defaultChunkSize = 1024 * 1024 // 1M + +type ( + ChunkOption func(options *chunkOptions) + + ChunkExecutor struct { + executor *PeriodicalExecutor + container *chunkContainer + } + + chunkOptions struct { + chunkSize int + flushInterval time.Duration + } +) + +func NewChunkExecutor(execute Execute, opts ...ChunkOption) *ChunkExecutor { + options := newChunkOptions() + for _, opt := range opts { + opt(&options) + } + + container := &chunkContainer{ + execute: execute, + maxChunkSize: options.chunkSize, + } + executor := &ChunkExecutor{ + executor: NewPeriodicalExecutor(options.flushInterval, container), + container: container, + } + + return executor +} + +func (ce *ChunkExecutor) Add(task interface{}, size int) error { + 
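+	// wrap the task together with its byte size so the container can accumulate the chunk size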
ce.executor.Add(chunk{ + val: task, + size: size, + }) + return nil +} + +func (ce *ChunkExecutor) Flush() { + ce.executor.Flush() +} + +func (ce *ChunkExecutor) Wait() { + ce.executor.Wait() +} + +func WithChunkBytes(size int) ChunkOption { + return func(options *chunkOptions) { + options.chunkSize = size + } +} + +func WithFlushInterval(duration time.Duration) ChunkOption { + return func(options *chunkOptions) { + options.flushInterval = duration + } +} + +func newChunkOptions() chunkOptions { + return chunkOptions{ + chunkSize: defaultChunkSize, + flushInterval: defaultFlushInterval, + } +} + +type chunkContainer struct { + tasks []interface{} + execute Execute + size int + maxChunkSize int +} + +func (bc *chunkContainer) AddTask(task interface{}) bool { + ck := task.(chunk) + bc.tasks = append(bc.tasks, ck.val) + bc.size += ck.size + return bc.size >= bc.maxChunkSize +} + +func (bc *chunkContainer) Execute(tasks interface{}) { + vals := tasks.([]interface{}) + bc.execute(vals) +} + +func (bc *chunkContainer) RemoveAll() interface{} { + tasks := bc.tasks + bc.tasks = nil + bc.size = 0 + return tasks +} + +type chunk struct { + val interface{} + size int +} diff --git a/core/executors/chunkexecutor_test.go b/core/executors/chunkexecutor_test.go new file mode 100644 index 00000000..9820c597 --- /dev/null +++ b/core/executors/chunkexecutor_test.go @@ -0,0 +1,92 @@ +package executors + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestChunkExecutor(t *testing.T) { + var values []int + var lock sync.Mutex + + exeutor := NewChunkExecutor(func(items []interface{}) { + lock.Lock() + values = append(values, len(items)) + lock.Unlock() + }, WithChunkBytes(10), WithFlushInterval(time.Minute)) + + for i := 0; i < 50; i++ { + exeutor.Add(1, 1) + time.Sleep(time.Millisecond) + } + + lock.Lock() + assert.True(t, len(values) > 0) + // ignore last value + for i := 0; i < len(values); i++ { + assert.Equal(t, 10, values[i]) + } + lock.Unlock() +} + +func TestChunkExecutorFlushInterval(t *testing.T) { + const ( + caches = 10 + size = 5 + ) + var wait sync.WaitGroup + + wait.Add(1) + exeutor := NewChunkExecutor(func(items []interface{}) { + assert.Equal(t, size, len(items)) + wait.Done() + }, WithChunkBytes(caches), WithFlushInterval(time.Millisecond*100)) + + for i := 0; i < size; i++ { + exeutor.Add(1, 1) + } + + wait.Wait() +} + +func TestChunkExecutorEmpty(t *testing.T) { + NewChunkExecutor(func(items []interface{}) { + assert.Fail(t, "should not called") + }, WithChunkBytes(10), WithFlushInterval(time.Millisecond)) + time.Sleep(time.Millisecond * 100) +} + +func TestChunkExecutorFlush(t *testing.T) { + const ( + caches = 10 + tasks = 5 + ) + + var wait sync.WaitGroup + wait.Add(1) + be := NewChunkExecutor(func(items []interface{}) { + assert.Equal(t, tasks, len(items)) + wait.Done() + }, WithChunkBytes(caches), WithFlushInterval(time.Minute)) + for i := 0; i < tasks; i++ { + be.Add(1, 1) + } + be.Flush() + wait.Wait() +} + +func BenchmarkChunkExecutor(b *testing.B) { + b.ReportAllocs() + + be := NewChunkExecutor(func(tasks []interface{}) { + time.Sleep(time.Millisecond * time.Duration(len(tasks))) + }) + for i := 0; i < b.N; i++ { + time.Sleep(time.Microsecond * 200) + be.Add(1, 1) + } + be.Flush() +} diff --git a/core/executors/delayexecutor.go b/core/executors/delayexecutor.go new file mode 100644 index 00000000..fe15f941 --- /dev/null +++ b/core/executors/delayexecutor.go @@ -0,0 +1,44 @@ +package executors + +import ( + "sync" + "time" + + 
"zero/core/threading" +) + +type DelayExecutor struct { + fn func() + delay time.Duration + triggered bool + lock sync.Mutex +} + +func NewDelayExecutor(fn func(), delay time.Duration) *DelayExecutor { + return &DelayExecutor{ + fn: fn, + delay: delay, + } +} + +func (de *DelayExecutor) Trigger() { + de.lock.Lock() + defer de.lock.Unlock() + + if de.triggered { + return + } + + de.triggered = true + threading.GoSafe(func() { + timer := time.NewTimer(de.delay) + defer timer.Stop() + <-timer.C + + // set triggered to false before calling fn to ensure no triggers are missed. + de.lock.Lock() + de.triggered = false + de.lock.Unlock() + de.fn() + }) +} diff --git a/core/executors/delayexecutor_test.go b/core/executors/delayexecutor_test.go new file mode 100644 index 00000000..67a779dd --- /dev/null +++ b/core/executors/delayexecutor_test.go @@ -0,0 +1,21 @@ +package executors + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDelayExecutor(t *testing.T) { + var count int32 + ex := NewDelayExecutor(func() { + atomic.AddInt32(&count, 1) + }, time.Millisecond*10) + for i := 0; i < 100; i++ { + ex.Trigger() + } + time.Sleep(time.Millisecond * 100) + assert.Equal(t, int32(1), atomic.LoadInt32(&count)) +} diff --git a/core/executors/lessexecutor.go b/core/executors/lessexecutor.go new file mode 100644 index 00000000..826786b1 --- /dev/null +++ b/core/executors/lessexecutor.go @@ -0,0 +1,32 @@ +package executors + +import ( + "time" + + "zero/core/syncx" + "zero/core/timex" +) + +type LessExecutor struct { + threshold time.Duration + lastTime *syncx.AtomicDuration +} + +func NewLessExecutor(threshold time.Duration) *LessExecutor { + return &LessExecutor{ + threshold: threshold, + lastTime: syncx.NewAtomicDuration(), + } +} + +func (le *LessExecutor) DoOrDiscard(execute func()) bool { + now := timex.Now() + lastTime := le.lastTime.Load() + if lastTime == 0 || lastTime+le.threshold < now { + le.lastTime.Set(now) + execute() + return true + } + + return false +} diff --git a/core/executors/lessexecutor_test.go b/core/executors/lessexecutor_test.go new file mode 100644 index 00000000..361b07df --- /dev/null +++ b/core/executors/lessexecutor_test.go @@ -0,0 +1,27 @@ +package executors + +import ( + "testing" + "time" + + "zero/core/timex" + + "github.com/stretchr/testify/assert" +) + +func TestLessExecutor_DoOrDiscard(t *testing.T) { + executor := NewLessExecutor(time.Minute) + assert.True(t, executor.DoOrDiscard(func() {})) + assert.False(t, executor.DoOrDiscard(func() {})) + executor.lastTime.Set(timex.Now() - time.Minute - time.Second*30) + assert.True(t, executor.DoOrDiscard(func() {})) + assert.False(t, executor.DoOrDiscard(func() {})) +} + +func BenchmarkLessExecutor(b *testing.B) { + exec := NewLessExecutor(time.Millisecond) + for i := 0; i < b.N; i++ { + exec.DoOrDiscard(func() { + }) + } +} diff --git a/core/executors/periodicalexecutor.go b/core/executors/periodicalexecutor.go new file mode 100644 index 00000000..d5372bf6 --- /dev/null +++ b/core/executors/periodicalexecutor.go @@ -0,0 +1,158 @@ +package executors + +import ( + "reflect" + "sync" + "time" + + "zero/core/proc" + "zero/core/threading" + "zero/core/timex" +) + +const idleRound = 10 + +type ( + // A type that satisfies executors.TaskContainer can be used as the underlying + // container that used to do periodical executions. + TaskContainer interface { + // AddTask adds the task into the container. + // Returns true if the container needs to be flushed after the addition. 
+ AddTask(task interface{}) bool + // Execute handles the collected tasks by the container when flushing. + Execute(tasks interface{}) + // RemoveAll removes the contained tasks, and return them. + RemoveAll() interface{} + } + + PeriodicalExecutor struct { + commander chan interface{} + interval time.Duration + container TaskContainer + waitGroup sync.WaitGroup + guarded bool + newTicker func(duration time.Duration) timex.Ticker + lock sync.Mutex + } +) + +func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *PeriodicalExecutor { + executor := &PeriodicalExecutor{ + // buffer 1 to let the caller go quickly + commander: make(chan interface{}, 1), + interval: interval, + container: container, + newTicker: func(d time.Duration) timex.Ticker { + return timex.NewTicker(interval) + }, + } + proc.AddShutdownListener(func() { + executor.Flush() + }) + + return executor +} + +func (pe *PeriodicalExecutor) Add(task interface{}) { + if vals, ok := pe.addAndCheck(task); ok { + pe.commander <- vals + } +} + +func (pe *PeriodicalExecutor) Flush() bool { + return pe.executeTasks(func() interface{} { + pe.lock.Lock() + defer pe.lock.Unlock() + return pe.container.RemoveAll() + }()) +} + +func (pe *PeriodicalExecutor) Sync(fn func()) { + pe.lock.Lock() + defer pe.lock.Unlock() + fn() +} + +func (pe *PeriodicalExecutor) Wait() { + pe.waitGroup.Wait() +} + +func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) { + pe.lock.Lock() + defer func() { + var start bool + if !pe.guarded { + pe.guarded = true + start = true + } + pe.lock.Unlock() + if start { + pe.backgroundFlush() + } + }() + + if pe.container.AddTask(task) { + return pe.container.RemoveAll(), true + } + + return nil, false +} + +func (pe *PeriodicalExecutor) backgroundFlush() { + threading.GoSafe(func() { + ticker := pe.newTicker(pe.interval) + defer ticker.Stop() + + var commanded bool + last := timex.Now() + for { + select { + case vals := <-pe.commander: + commanded = true + pe.executeTasks(vals) + last = timex.Now() + case <-ticker.Chan(): + if commanded { + commanded = false + } else if pe.Flush() { + last = timex.Now() + } else if timex.Since(last) > pe.interval*idleRound { + pe.lock.Lock() + pe.guarded = false + pe.lock.Unlock() + + // flush again to avoid missing tasks + pe.Flush() + return + } + } + } + }) +} + +func (pe *PeriodicalExecutor) executeTasks(tasks interface{}) bool { + pe.waitGroup.Add(1) + defer pe.waitGroup.Done() + + ok := pe.hasTasks(tasks) + if ok { + pe.container.Execute(tasks) + } + + return ok +} + +func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool { + if tasks == nil { + return false + } + + val := reflect.ValueOf(tasks) + switch val.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice: + return val.Len() > 0 + default: + // unknown type, let caller execute it + return true + } +} diff --git a/core/executors/periodicalexecutor_test.go b/core/executors/periodicalexecutor_test.go new file mode 100644 index 00000000..163a3540 --- /dev/null +++ b/core/executors/periodicalexecutor_test.go @@ -0,0 +1,118 @@ +package executors + +import ( + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "zero/core/timex" + + "github.com/stretchr/testify/assert" +) + +const threshold = 10 + +type container struct { + interval time.Duration + tasks []int + execute func(tasks interface{}) +} + +func newContainer(interval time.Duration, execute func(tasks interface{})) *container { + return &container{ + interval: interval, + execute: execute, + } +} + 
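+// examplePeriodicalExecutorUsage is a hedged sketch (not one of the original tests) showing
+// the typical call pattern: tasks are buffered until the container reports it is full or the
+// background ticker fires.
+func examplePeriodicalExecutorUsage() {
+	exec := NewPeriodicalExecutor(time.Second, newContainer(time.Millisecond, nil))
+	exec.Add(1)  // buffered; flushed when AddTask reports full or on the next tick
+	exec.Flush() // force an immediate flush of buffered tasks
+	exec.Wait()  // wait for in-flight Execute calls to finish
+}
+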
+func (c *container) AddTask(task interface{}) bool { + c.tasks = append(c.tasks, task.(int)) + return len(c.tasks) > threshold +} + +func (c *container) Execute(tasks interface{}) { + if c.execute != nil { + c.execute(tasks) + } else { + time.Sleep(c.interval) + } +} + +func (c *container) RemoveAll() interface{} { + tasks := c.tasks + c.tasks = nil + return tasks +} + +func TestPeriodicalExecutor_Sync(t *testing.T) { + var done int32 + exec := NewPeriodicalExecutor(time.Second, newContainer(time.Millisecond*500, nil)) + exec.Sync(func() { + atomic.AddInt32(&done, 1) + }) + assert.Equal(t, int32(1), atomic.LoadInt32(&done)) +} + +func TestPeriodicalExecutor_QuitGoroutine(t *testing.T) { + ticker := timex.NewFakeTicker() + exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil)) + exec.newTicker = func(d time.Duration) timex.Ticker { + return ticker + } + routines := runtime.NumGoroutine() + exec.Add(1) + ticker.Tick() + ticker.Wait(time.Millisecond * idleRound * 2) + ticker.Tick() + ticker.Wait(time.Millisecond * idleRound) + assert.Equal(t, routines, runtime.NumGoroutine()) +} + +func TestPeriodicalExecutor_Bulk(t *testing.T) { + ticker := timex.NewFakeTicker() + var vals []int + // avoid data race + var lock sync.Mutex + exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks interface{}) { + t := tasks.([]int) + for _, each := range t { + lock.Lock() + vals = append(vals, each) + lock.Unlock() + } + })) + exec.newTicker = func(d time.Duration) timex.Ticker { + return ticker + } + for i := 0; i < threshold*10; i++ { + if i%threshold == 5 { + time.Sleep(time.Millisecond * idleRound * 2) + } + exec.Add(i) + } + ticker.Tick() + ticker.Wait(time.Millisecond * idleRound * 2) + ticker.Tick() + ticker.Tick() + ticker.Wait(time.Millisecond * idleRound) + var expect []int + for i := 0; i < threshold*10; i++ { + expect = append(expect, i) + } + + lock.Lock() + assert.EqualValues(t, expect, vals) + lock.Unlock() +} + +// go test -benchtime 10s -bench . +func BenchmarkExecutor(b *testing.B) { + b.ReportAllocs() + + executor := NewPeriodicalExecutor(time.Second, newContainer(time.Millisecond*500, nil)) + for i := 0; i < b.N; i++ { + executor.Add(1) + } +} diff --git a/core/executors/vars.go b/core/executors/vars.go new file mode 100644 index 00000000..7d91c15c --- /dev/null +++ b/core/executors/vars.go @@ -0,0 +1,7 @@ +package executors + +import "time" + +const defaultFlushInterval = time.Second + +type Execute func(tasks []interface{}) diff --git a/core/filex/file.go b/core/filex/file.go new file mode 100644 index 00000000..af0794bb --- /dev/null +++ b/core/filex/file.go @@ -0,0 +1,84 @@ +package filex + +import ( + "io" + "os" +) + +const bufSize = 1024 + +func FirstLine(filename string) (string, error) { + file, err := os.Open(filename) + if err != nil { + return "", err + } + defer file.Close() + + return firstLine(file) +} + +func LastLine(filename string) (string, error) { + file, err := os.Open(filename) + if err != nil { + return "", err + } + defer file.Close() + + return lastLine(filename, file) +} + +func firstLine(file *os.File) (string, error) { + var first []byte + var offset int64 + for { + buf := make([]byte, bufSize) + n, err := file.ReadAt(buf, offset) + if err != nil && err != io.EOF { + return "", err + } + + for i := 0; i < n; i++ { + if buf[i] == '\n' { + return string(append(first, buf[:i]...)), nil + } + } + + first = append(first, buf[:n]...) 
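+		// no newline in this chunk yet, keep the bytes read so far and move to the next chunk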
+ offset += bufSize + } +} + +func lastLine(filename string, file *os.File) (string, error) { + info, err := os.Stat(filename) + if err != nil { + return "", err + } + + var last []byte + offset := info.Size() + for { + offset -= bufSize + if offset < 0 { + offset = 0 + } + buf := make([]byte, bufSize) + n, err := file.ReadAt(buf, offset) + if err != nil && err != io.EOF { + return "", err + } + + if buf[n-1] == '\n' { + buf = buf[:n-1] + n -= 1 + } else { + buf = buf[:n] + } + for n -= 1; n >= 0; n-- { + if buf[n] == '\n' { + return string(append(buf[n+1:], last...)), nil + } + } + + last = append(buf, last...) + } +} diff --git a/core/filex/file_test.go b/core/filex/file_test.go new file mode 100644 index 00000000..7906ef71 --- /dev/null +++ b/core/filex/file_test.go @@ -0,0 +1,116 @@ +package filex + +import ( + "os" + "testing" + + "zero/core/fs" + + "github.com/stretchr/testify/assert" +) + +const ( + longLine = `Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.Quid securi etiam tamquam eu fugiat nulla pariatur. Nec dubitamus multa iter quae et nos invenerat. Non equidem invideo, miror magis posuere velit aliquet. Integer legentibus erat a ante historiarum dapibus. Prima luce, cum quibus mons aliud consensu ab eo.` + longFirstLine = longLine + "\n" + text + text = `first line +Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus. +Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum. +Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna. +Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil! +Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur. 
+Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus. +Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum. +Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna. +Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil! +Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur. +Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus. +Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum. +Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna. +Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil! +Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur. +` + longLine + textWithLastNewline = `first line +Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus. +Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum. +Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna. +Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. 
Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil! +Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur. +Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus. +Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum. +Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna. +Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil! +Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur. +Cum sociis natoque penatibus et magnis dis parturient. Phasellus laoreet lorem vel dolor tempus vehicula. Vivamus sagittis lacus vel augue laoreet rutrum faucibus. Integer legentibus erat a ante historiarum dapibus. +Quisque ut dolor gravida, placerat libero vel, euismod. Quam temere in vitiis, legem sancimus haerentia. Qui ipsorum lingua Celtae, nostra Galli appellantur. Quis aute iure reprehenderit in voluptate velit esse. Fabio vel iudice vincam, sunt in culpa qui officia. Cras mattis iudicium purus sit amet fermentum. +Quo usque tandem abutere, Catilina, patientia nostra? Gallia est omnis divisa in partes tres, quarum. Quam diu etiam furor iste tuus nos eludet? Quid securi etiam tamquam eu fugiat nulla pariatur. Curabitur blandit tempus ardua ridiculus sed magna. +Magna pars studiorum, prodita quaerimus. Cum ceteris in veneratione tui montes, nascetur mus. Morbi odio eros, volutpat ut pharetra vitae, lobortis sed nibh. Plura mihi bona sunt, inclinet, amari petere vellent. Idque Caesaris facere voluntate liceret: sese habere. Tu quoque, Brute, fili mi, nihil timor populi, nihil! +Tityre, tu patulae recubans sub tegmine fagi dolor. Inmensae subtilitatis, obscuris et malesuada fames. Quae vero auctorem tractata ab fiducia dicuntur. 
+` + longLine + "\n" + shortText = `first line +second line +last line` + shortTextWithLastNewline = `first line +second line +last line +` +) + +func TestFirstLine(t *testing.T) { + filename, err := fs.TempFilenameWithText(longFirstLine) + assert.Nil(t, err) + defer os.Remove(filename) + + val, err := FirstLine(filename) + assert.Nil(t, err) + assert.Equal(t, longLine, val) +} + +func TestFirstLineShort(t *testing.T) { + filename, err := fs.TempFilenameWithText(shortText) + assert.Nil(t, err) + defer os.Remove(filename) + + val, err := FirstLine(filename) + assert.Nil(t, err) + assert.Equal(t, "first line", val) +} + +func TestLastLine(t *testing.T) { + filename, err := fs.TempFilenameWithText(text) + assert.Nil(t, err) + defer os.Remove(filename) + + val, err := LastLine(filename) + assert.Nil(t, err) + assert.Equal(t, longLine, val) +} + +func TestLastLineWithLastNewline(t *testing.T) { + filename, err := fs.TempFilenameWithText(textWithLastNewline) + assert.Nil(t, err) + defer os.Remove(filename) + + val, err := LastLine(filename) + assert.Nil(t, err) + assert.Equal(t, longLine, val) +} + +func TestLastLineShort(t *testing.T) { + filename, err := fs.TempFilenameWithText(shortText) + assert.Nil(t, err) + defer os.Remove(filename) + + val, err := LastLine(filename) + assert.Nil(t, err) + assert.Equal(t, "last line", val) +} + +func TestLastLineWithLastNewlineShort(t *testing.T) { + filename, err := fs.TempFilenameWithText(shortTextWithLastNewline) + assert.Nil(t, err) + defer os.Remove(filename) + + val, err := LastLine(filename) + assert.Nil(t, err) + assert.Equal(t, "last line", val) +} diff --git a/core/filex/lookup.go b/core/filex/lookup.go new file mode 100644 index 00000000..9dfc5c68 --- /dev/null +++ b/core/filex/lookup.go @@ -0,0 +1,105 @@ +package filex + +import ( + "io" + "os" +) + +type OffsetRange struct { + File string + Start int64 + Stop int64 +} + +func SplitLineChunks(filename string, chunks int) ([]OffsetRange, error) { + info, err := os.Stat(filename) + if err != nil { + return nil, err + } + + if chunks <= 1 { + return []OffsetRange{ + { + File: filename, + Start: 0, + Stop: info.Size(), + }, + }, nil + } + + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer file.Close() + + var ranges []OffsetRange + var offset int64 + // avoid the last chunk too few bytes + preferSize := info.Size()/int64(chunks) + 1 + for { + if offset+preferSize >= info.Size() { + ranges = append(ranges, OffsetRange{ + File: filename, + Start: offset, + Stop: info.Size(), + }) + break + } + + offsetRange, err := nextRange(file, offset, offset+preferSize) + if err != nil { + return nil, err + } + + ranges = append(ranges, offsetRange) + if offsetRange.Stop < info.Size() { + offset = offsetRange.Stop + } else { + break + } + } + + return ranges, nil +} + +func nextRange(file *os.File, start, stop int64) (OffsetRange, error) { + offset, err := skipPartialLine(file, stop) + if err != nil { + return OffsetRange{}, err + } + + return OffsetRange{ + File: file.Name(), + Start: start, + Stop: offset, + }, nil +} + +func skipPartialLine(file *os.File, offset int64) (int64, error) { + for { + skipBuf := make([]byte, bufSize) + n, err := file.ReadAt(skipBuf, offset) + if err != nil && err != io.EOF { + return 0, err + } + if n == 0 { + return 0, io.EOF + } + + for i := 0; i < n; i++ { + if skipBuf[i] != '\r' && skipBuf[i] != '\n' { + offset++ + } else { + for ; i < n; i++ { + if skipBuf[i] == '\r' || skipBuf[i] == '\n' { + offset++ + } else { + return offset, nil + } + } + 
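+				// this read ended while still consuming line-break bytes; offset already points past them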
return offset, nil + } + } + } +} diff --git a/core/filex/lookup_test.go b/core/filex/lookup_test.go new file mode 100644 index 00000000..8243f042 --- /dev/null +++ b/core/filex/lookup_test.go @@ -0,0 +1,68 @@ +package filex + +import ( + "os" + "testing" + + "zero/core/fs" + + "github.com/stretchr/testify/assert" +) + +func TestSplitLineChunks(t *testing.T) { + const text = `first line +second line +third line +fourth line +fifth line +sixth line +seventh line +` + fp, err := fs.TempFileWithText(text) + assert.Nil(t, err) + defer func() { + fp.Close() + os.Remove(fp.Name()) + }() + + offsets, err := SplitLineChunks(fp.Name(), 3) + assert.Nil(t, err) + body := make([]byte, 512) + for _, offset := range offsets { + reader := NewRangeReader(fp, offset.Start, offset.Stop) + n, err := reader.Read(body) + assert.Nil(t, err) + assert.Equal(t, uint8('\n'), body[n-1]) + } +} + +func TestSplitLineChunksNoFile(t *testing.T) { + _, err := SplitLineChunks("nosuchfile", 2) + assert.NotNil(t, err) +} + +func TestSplitLineChunksFull(t *testing.T) { + const text = `first line +second line +third line +fourth line +fifth line +sixth line +` + fp, err := fs.TempFileWithText(text) + assert.Nil(t, err) + defer func() { + fp.Close() + os.Remove(fp.Name()) + }() + + offsets, err := SplitLineChunks(fp.Name(), 1) + assert.Nil(t, err) + body := make([]byte, 512) + for _, offset := range offsets { + reader := NewRangeReader(fp, offset.Start, offset.Stop) + n, err := reader.Read(body) + assert.Nil(t, err) + assert.Equal(t, []byte(text), body[:n]) + } +} diff --git a/core/filex/progressscanner.go b/core/filex/progressscanner.go new file mode 100644 index 00000000..be0e8611 --- /dev/null +++ b/core/filex/progressscanner.go @@ -0,0 +1,28 @@ +package filex + +import "gopkg.in/cheggaaa/pb.v1" + +type ( + Scanner interface { + Scan() bool + Text() string + } + + progressScanner struct { + Scanner + bar *pb.ProgressBar + } +) + +func NewProgressScanner(scanner Scanner, bar *pb.ProgressBar) Scanner { + return &progressScanner{ + Scanner: scanner, + bar: bar, + } +} + +func (ps *progressScanner) Text() string { + s := ps.Scanner.Text() + ps.bar.Add64(int64(len(s)) + 1) // take newlines into account + return s +} diff --git a/core/filex/progressscanner_test.go b/core/filex/progressscanner_test.go new file mode 100644 index 00000000..b6b2d384 --- /dev/null +++ b/core/filex/progressscanner_test.go @@ -0,0 +1,31 @@ +package filex + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/cheggaaa/pb.v1" +) + +func TestProgressScanner(t *testing.T) { + const text = "hello, world" + bar := pb.New(100) + var builder strings.Builder + builder.WriteString(text) + scanner := NewProgressScanner(&mockedScanner{builder: &builder}, bar) + assert.True(t, scanner.Scan()) + assert.Equal(t, text, scanner.Text()) +} + +type mockedScanner struct { + builder *strings.Builder +} + +func (s *mockedScanner) Scan() bool { + return s.builder.Len() > 0 +} + +func (s *mockedScanner) Text() string { + return s.builder.String() +} diff --git a/core/filex/rangereader.go b/core/filex/rangereader.go new file mode 100644 index 00000000..74f6fcbe --- /dev/null +++ b/core/filex/rangereader.go @@ -0,0 +1,43 @@ +package filex + +import ( + "errors" + "os" +) + +type RangeReader struct { + file *os.File + start int64 + stop int64 +} + +func NewRangeReader(file *os.File, start, stop int64) *RangeReader { + return &RangeReader{ + file: file, + start: start, + stop: stop, + } +} + +func (rr *RangeReader) Read(p []byte) (n int, err 
error) { + stat, err := rr.file.Stat() + if err != nil { + return 0, err + } + + if rr.stop < rr.start || rr.start >= stat.Size() { + return 0, errors.New("exceed file size") + } + + if rr.stop-rr.start < int64(len(p)) { + p = p[:rr.stop-rr.start] + } + + n, err = rr.file.ReadAt(p, rr.start) + if err != nil { + return n, err + } + + rr.start += int64(n) + return +} diff --git a/core/filex/rangereader_test.go b/core/filex/rangereader_test.go new file mode 100644 index 00000000..3d773973 --- /dev/null +++ b/core/filex/rangereader_test.go @@ -0,0 +1,45 @@ +package filex + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "zero/core/fs" +) + +func TestRangeReader(t *testing.T) { + const text = `hello +world` + file, err := fs.TempFileWithText(text) + assert.Nil(t, err) + defer func() { + file.Close() + os.Remove(file.Name()) + }() + + reader := NewRangeReader(file, 5, 8) + buf := make([]byte, 10) + n, err := reader.Read(buf) + assert.Nil(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, ` +wo`, string(buf[:n])) +} + +func TestRangeReader_OutOfRange(t *testing.T) { + const text = `hello +world` + file, err := fs.TempFileWithText(text) + assert.Nil(t, err) + defer func() { + file.Close() + os.Remove(file.Name()) + }() + + reader := NewRangeReader(file, 50, 8) + buf := make([]byte, 10) + _, err = reader.Read(buf) + assert.NotNil(t, err) +} diff --git a/core/fs/files+polyfill.go b/core/fs/files+polyfill.go new file mode 100644 index 00000000..fef5a026 --- /dev/null +++ b/core/fs/files+polyfill.go @@ -0,0 +1,8 @@ +// +build windows + +package fs + +import "os" + +func CloseOnExec(*os.File) { +} diff --git a/core/fs/files.go b/core/fs/files.go new file mode 100644 index 00000000..b0cb870c --- /dev/null +++ b/core/fs/files.go @@ -0,0 +1,14 @@ +// +build linux darwin + +package fs + +import ( + "os" + "syscall" +) + +func CloseOnExec(file *os.File) { + if file != nil { + syscall.CloseOnExec(int(file.Fd())) + } +} diff --git a/core/fs/temps.go b/core/fs/temps.go new file mode 100644 index 00000000..78a00e1d --- /dev/null +++ b/core/fs/temps.go @@ -0,0 +1,42 @@ +package fs + +import ( + "io/ioutil" + "os" + + "zero/core/hash" +) + +// TempFileWithText creates the temporary file with the given content, +// and returns the opened *os.File instance. +// The file is kept as open, the caller should close the file handle, +// and remove the file by name. +func TempFileWithText(text string) (*os.File, error) { + tmpfile, err := ioutil.TempFile(os.TempDir(), hash.Md5Hex([]byte(text))) + if err != nil { + return nil, err + } + + if err := ioutil.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil { + return nil, err + } + + return tmpfile, nil +} + +// TempFilenameWithText creates the file with the given content, +// and returns the filename (full path). +// The caller should remove the file after use. 
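+//
+// A minimal usage sketch (the content is illustrative):
+//
+//	name, err := TempFilenameWithText("hello")
+//	if err == nil {
+//		defer os.Remove(name)
+//	}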
+func TempFilenameWithText(text string) (string, error) { + tmpfile, err := TempFileWithText(text) + if err != nil { + return "", err + } + + filename := tmpfile.Name() + if err = tmpfile.Close(); err != nil { + return "", err + } + + return filename, nil +} diff --git a/core/fx/fn.go b/core/fx/fn.go new file mode 100644 index 00000000..67067984 --- /dev/null +++ b/core/fx/fn.go @@ -0,0 +1,357 @@ +package fx + +import ( + "sort" + "sync" + + "zero/core/collection" + "zero/core/lang" + "zero/core/threading" +) + +const ( + defaultWorkers = 16 + minWorkers = 1 +) + +type ( + rxOptions struct { + unlimitedWorkers bool + workers int + } + + FilterFunc func(item interface{}) bool + ForAllFunc func(pipe <-chan interface{}) + ForEachFunc func(item interface{}) + GenerateFunc func(source chan<- interface{}) + KeyFunc func(item interface{}) interface{} + LessFunc func(a, b interface{}) bool + MapFunc func(item interface{}) interface{} + Option func(opts *rxOptions) + ParallelFunc func(item interface{}) + ReduceFunc func(pipe <-chan interface{}) (interface{}, error) + WalkFunc func(item interface{}, pipe chan<- interface{}) + + Stream struct { + source <-chan interface{} + } +) + +// From constructs a Stream from the given GenerateFunc. +func From(generate GenerateFunc) Stream { + source := make(chan interface{}) + + threading.GoSafe(func() { + defer close(source) + generate(source) + }) + + return Range(source) +} + +// Just converts the given arbitary items to a Stream. +func Just(items ...interface{}) Stream { + source := make(chan interface{}, len(items)) + for _, item := range items { + source <- item + } + close(source) + + return Range(source) +} + +// Range converts the given channel to a Stream. +func Range(source <-chan interface{}) Stream { + return Stream{ + source: source, + } +} + +// Buffer buffers the items into a queue with size n. +func (p Stream) Buffer(n int) Stream { + if n < 0 { + n = 0 + } + + source := make(chan interface{}, n) + go func() { + for item := range p.source { + source <- item + } + close(source) + }() + + return Range(source) +} + +// Distinct removes the duplicated items base on the given KeyFunc. +func (p Stream) Distinct(fn KeyFunc) Stream { + source := make(chan interface{}) + + threading.GoSafe(func() { + defer close(source) + + keys := make(map[interface{}]lang.PlaceholderType) + for item := range p.source { + key := fn(item) + if _, ok := keys[key]; !ok { + source <- item + keys[key] = lang.Placeholder + } + } + }) + + return Range(source) +} + +// Done waits all upstreaming operations to be done. +func (p Stream) Done() { + for range p.source { + } +} + +// Filter filters the items by the given FilterFunc. +func (p Stream) Filter(fn FilterFunc, opts ...Option) Stream { + return p.Walk(func(item interface{}, pipe chan<- interface{}) { + if fn(item) { + pipe <- item + } + }, opts...) +} + +// ForAll handles the streaming elements from the source and no later streams. +func (p Stream) ForAll(fn ForAllFunc) { + fn(p.source) +} + +// ForEach seals the Stream with the ForEachFunc on each item, no successive operations. +func (p Stream) ForEach(fn ForEachFunc) { + for item := range p.source { + fn(item) + } +} + +// Group groups the elements into different groups based on their keys. 
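+//
+// A sketch with illustrative values; each emitted item is a []interface{} holding one group:
+//
+//	Just(10, 11, 20, 21).Group(func(item interface{}) interface{} {
+//		return item.(int) / 10
+//	}).ForEach(func(group interface{}) {
+//		// group is []interface{}{10, 11} or []interface{}{20, 21}
+//	})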
+func (p Stream) Group(fn KeyFunc) Stream { + groups := make(map[interface{}][]interface{}) + for item := range p.source { + key := fn(item) + groups[key] = append(groups[key], item) + } + + source := make(chan interface{}) + go func() { + for _, group := range groups { + source <- group + } + close(source) + }() + + return Range(source) +} + +func (p Stream) Head(n int64) Stream { + source := make(chan interface{}) + + go func() { + for item := range p.source { + n-- + if n >= 0 { + source <- item + } + if n == 0 { + // let successive method go ASAP even we have more items to skip + // why we don't just break the loop, because if break, + // this former goroutine will block forever, which will cause goroutine leak. + close(source) + } + } + if n > 0 { + close(source) + } + }() + + return Range(source) +} + +// Maps converts each item to another corresponding item, which means it's a 1:1 model. +func (p Stream) Map(fn MapFunc, opts ...Option) Stream { + return p.Walk(func(item interface{}, pipe chan<- interface{}) { + pipe <- fn(item) + }, opts...) +} + +// Merge merges all the items into a slice and generates a new stream. +func (p Stream) Merge() Stream { + var items []interface{} + for item := range p.source { + items = append(items, item) + } + + source := make(chan interface{}, 1) + source <- items + close(source) + + return Range(source) +} + +// Parallel applies the given ParallenFunc to each item concurrently with given number of workers. +func (p Stream) Parallel(fn ParallelFunc, opts ...Option) { + p.Walk(func(item interface{}, pipe chan<- interface{}) { + fn(item) + }, opts...).Done() +} + +// Reduce is a utility method to let the caller deal with the underlying channel. +func (p Stream) Reduce(fn ReduceFunc) (interface{}, error) { + return fn(p.source) +} + +// Reverse reverses the elements in the stream. +func (p Stream) Reverse() Stream { + var items []interface{} + for item := range p.source { + items = append(items, item) + } + // reverse, official method + for i := len(items)/2 - 1; i >= 0; i-- { + opp := len(items) - 1 - i + items[i], items[opp] = items[opp], items[i] + } + + return Just(items...) +} + +// Sort sorts the items from the underlying source. +func (p Stream) Sort(less LessFunc) Stream { + var items []interface{} + for item := range p.source { + items = append(items, item) + } + sort.Slice(items, func(i, j int) bool { + return less(items[i], items[j]) + }) + + return Just(items...) +} + +func (p Stream) Tail(n int64) Stream { + source := make(chan interface{}) + + go func() { + ring := collection.NewRing(int(n)) + for item := range p.source { + ring.Add(item) + } + for _, item := range ring.Take() { + source <- item + } + close(source) + }() + + return Range(source) +} + +// Walk lets the callers handle each item, the caller may write zero, one or more items base on the given item. +func (p Stream) Walk(fn WalkFunc, opts ...Option) Stream { + option := buildOptions(opts...) 
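+	// dispatch to either one goroutine per item or a bounded worker pool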
+ if option.unlimitedWorkers { + return p.walkUnlimited(fn, option) + } else { + return p.walkLimited(fn, option) + } +} + +func (p Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream { + pipe := make(chan interface{}, option.workers) + + go func() { + var wg sync.WaitGroup + pool := make(chan lang.PlaceholderType, option.workers) + + for { + pool <- lang.Placeholder + item, ok := <-p.source + if !ok { + <-pool + break + } + + wg.Add(1) + // better to safely run caller defined method + threading.GoSafe(func() { + defer func() { + wg.Done() + <-pool + }() + + fn(item, pipe) + }) + } + + wg.Wait() + close(pipe) + }() + + return Range(pipe) +} + +func (p Stream) walkUnlimited(fn WalkFunc, option *rxOptions) Stream { + pipe := make(chan interface{}, defaultWorkers) + + go func() { + var wg sync.WaitGroup + + for { + item, ok := <-p.source + if !ok { + break + } + + wg.Add(1) + // better to safely run caller defined method + threading.GoSafe(func() { + defer wg.Done() + fn(item, pipe) + }) + } + + wg.Wait() + close(pipe) + }() + + return Range(pipe) +} + +// UnlimitedWorkers lets the caller to use as many workers as the tasks. +func UnlimitedWorkers() Option { + return func(opts *rxOptions) { + opts.unlimitedWorkers = true + } +} + +// WithWorkers lets the caller to customize the concurrent workers. +func WithWorkers(workers int) Option { + return func(opts *rxOptions) { + if workers < minWorkers { + opts.workers = minWorkers + } else { + opts.workers = workers + } + } +} + +func buildOptions(opts ...Option) *rxOptions { + options := newOptions() + for _, opt := range opts { + opt(options) + } + + return options +} + +func newOptions() *rxOptions { + return &rxOptions{ + workers: defaultWorkers, + } +} diff --git a/core/fx/fn_test.go b/core/fx/fn_test.go new file mode 100644 index 00000000..aed5d01f --- /dev/null +++ b/core/fx/fn_test.go @@ -0,0 +1,293 @@ +package fx + +import ( + "io/ioutil" + "log" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "zero/core/stringx" + + "github.com/stretchr/testify/assert" +) + +func TestBuffer(t *testing.T) { + const N = 5 + var count int32 + var wait sync.WaitGroup + wait.Add(1) + From(func(source chan<- interface{}) { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for i := 0; i < 2*N; i++ { + select { + case source <- i: + atomic.AddInt32(&count, 1) + case <-ticker.C: + wait.Done() + return + } + } + }).Buffer(N).ForAll(func(pipe <-chan interface{}) { + wait.Wait() + // why N+1, because take one more to wait for sending into the channel + assert.Equal(t, int32(N+1), atomic.LoadInt32(&count)) + }) +} + +func TestBufferNegative(t *testing.T) { + var result int + Just(1, 2, 3, 4).Buffer(-1).Reduce(func(pipe <-chan interface{}) (interface{}, error) { + for item := range pipe { + result += item.(int) + } + return result, nil + }) + assert.Equal(t, 10, result) +} + +func TestDone(t *testing.T) { + var count int32 + Just(1, 2, 3).Walk(func(item interface{}, pipe chan<- interface{}) { + time.Sleep(time.Millisecond * 100) + atomic.AddInt32(&count, int32(item.(int))) + }).Done() + assert.Equal(t, int32(6), count) +} + +func TestJust(t *testing.T) { + var result int + Just(1, 2, 3, 4).Reduce(func(pipe <-chan interface{}) (interface{}, error) { + for item := range pipe { + result += item.(int) + } + return result, nil + }) + assert.Equal(t, 10, result) +} + +func TestDistinct(t *testing.T) { + var result int + Just(4, 1, 3, 2, 3, 4).Distinct(func(item interface{}) interface{} { + return item + }).Reduce(func(pipe <-chan 
interface{}) (interface{}, error) { + for item := range pipe { + result += item.(int) + } + return result, nil + }) + assert.Equal(t, 10, result) +} + +func TestFilter(t *testing.T) { + var result int + Just(1, 2, 3, 4).Filter(func(item interface{}) bool { + return item.(int)%2 == 0 + }).Reduce(func(pipe <-chan interface{}) (interface{}, error) { + for item := range pipe { + result += item.(int) + } + return result, nil + }) + assert.Equal(t, 6, result) +} + +func TestForAll(t *testing.T) { + var result int + Just(1, 2, 3, 4).Filter(func(item interface{}) bool { + return item.(int)%2 == 0 + }).ForAll(func(pipe <-chan interface{}) { + for item := range pipe { + result += item.(int) + } + }) + assert.Equal(t, 6, result) +} + +func TestGroup(t *testing.T) { + var groups [][]int + Just(10, 11, 20, 21).Group(func(item interface{}) interface{} { + v := item.(int) + return v / 10 + }).ForEach(func(item interface{}) { + v := item.([]interface{}) + var group []int + for _, each := range v { + group = append(group, each.(int)) + } + groups = append(groups, group) + }) + + assert.Equal(t, 2, len(groups)) + for _, group := range groups { + assert.Equal(t, 2, len(group)) + assert.True(t, group[0]/10 == group[1]/10) + } +} + +func TestHead(t *testing.T) { + var result int + Just(1, 2, 3, 4).Head(2).Reduce(func(pipe <-chan interface{}) (interface{}, error) { + for item := range pipe { + result += item.(int) + } + return result, nil + }) + assert.Equal(t, 3, result) +} + +func TestHeadMore(t *testing.T) { + var result int + Just(1, 2, 3, 4).Head(6).Reduce(func(pipe <-chan interface{}) (interface{}, error) { + for item := range pipe { + result += item.(int) + } + return result, nil + }) + assert.Equal(t, 10, result) +} + +func TestMap(t *testing.T) { + log.SetOutput(ioutil.Discard) + + tests := []struct { + mapper MapFunc + expect int + }{ + { + mapper: func(item interface{}) interface{} { + v := item.(int) + return v * v + }, + expect: 30, + }, + { + mapper: func(item interface{}) interface{} { + v := item.(int) + if v%2 == 0 { + return 0 + } + return v * v + }, + expect: 10, + }, + { + mapper: func(item interface{}) interface{} { + v := item.(int) + if v%2 == 0 { + panic(v) + } + return v * v + }, + expect: 10, + }, + } + + // Map(...) 
works even WithWorkers(0) + for i, test := range tests { + t.Run(stringx.Rand(), func(t *testing.T) { + var result int + var workers int + if i%2 == 0 { + workers = 0 + } else { + workers = runtime.NumCPU() + } + From(func(source chan<- interface{}) { + for i := 1; i < 5; i++ { + source <- i + } + }).Map(test.mapper, WithWorkers(workers)).Reduce( + func(pipe <-chan interface{}) (interface{}, error) { + for item := range pipe { + result += item.(int) + } + return result, nil + }) + + assert.Equal(t, test.expect, result) + }) + } +} + +func TestMerge(t *testing.T) { + Just(1, 2, 3, 4).Merge().ForEach(func(item interface{}) { + assert.ElementsMatch(t, []interface{}{1, 2, 3, 4}, item.([]interface{})) + }) +} + +func TestParallelJust(t *testing.T) { + var count int32 + Just(1, 2, 3).Parallel(func(item interface{}) { + time.Sleep(time.Millisecond * 100) + atomic.AddInt32(&count, int32(item.(int))) + }, UnlimitedWorkers()) + assert.Equal(t, int32(6), count) +} + +func TestReverse(t *testing.T) { + Just(1, 2, 3, 4).Reverse().Merge().ForEach(func(item interface{}) { + assert.ElementsMatch(t, []interface{}{4, 3, 2, 1}, item.([]interface{})) + }) +} + +func TestSort(t *testing.T) { + var prev int + Just(5, 3, 7, 1, 9, 6, 4, 8, 2).Sort(func(a, b interface{}) bool { + return a.(int) < b.(int) + }).ForEach(func(item interface{}) { + next := item.(int) + assert.True(t, prev < next) + prev = next + }) +} + +func TestTail(t *testing.T) { + var result int + Just(1, 2, 3, 4).Tail(2).Reduce(func(pipe <-chan interface{}) (interface{}, error) { + for item := range pipe { + result += item.(int) + } + return result, nil + }) + assert.Equal(t, 7, result) +} + +func TestWalk(t *testing.T) { + var result int + Just(1, 2, 3, 4, 5).Walk(func(item interface{}, pipe chan<- interface{}) { + if item.(int)%2 != 0 { + pipe <- item + } + }, UnlimitedWorkers()).ForEach(func(item interface{}) { + result += item.(int) + }) + assert.Equal(t, 9, result) +} + +func BenchmarkMapReduce(b *testing.B) { + b.ReportAllocs() + + mapper := func(v interface{}) interface{} { + return v.(int64) * v.(int64) + } + reducer := func(input <-chan interface{}) (interface{}, error) { + var result int64 + for v := range input { + result += v.(int64) + } + return result, nil + } + + for i := 0; i < b.N; i++ { + From(func(input chan<- interface{}) { + for j := 0; j < 2; j++ { + input <- int64(j) + } + }).Map(mapper).Reduce(reducer) + } +} diff --git a/core/fx/parallel.go b/core/fx/parallel.go new file mode 100644 index 00000000..89c91142 --- /dev/null +++ b/core/fx/parallel.go @@ -0,0 +1,11 @@ +package fx + +import "zero/core/threading" + +func Parallel(fns ...func()) { + group := threading.NewRoutineGroup() + for _, fn := range fns { + group.RunSafe(fn) + } + group.Wait() +} diff --git a/core/fx/parallel_test.go b/core/fx/parallel_test.go new file mode 100644 index 00000000..eed5521a --- /dev/null +++ b/core/fx/parallel_test.go @@ -0,0 +1,24 @@ +package fx + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParallel(t *testing.T) { + var count int32 + Parallel(func() { + time.Sleep(time.Millisecond * 100) + atomic.AddInt32(&count, 1) + }, func() { + time.Sleep(time.Millisecond * 100) + atomic.AddInt32(&count, 2) + }, func() { + time.Sleep(time.Millisecond * 100) + atomic.AddInt32(&count, 3) + }) + assert.Equal(t, int32(6), count) +} diff --git a/core/fx/retry.go b/core/fx/retry.go new file mode 100644 index 00000000..338c2da2 --- /dev/null +++ b/core/fx/retry.go @@ -0,0 +1,43 @@ +package fx + 
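+// Usage sketch (illustrative only, not part of the original commit): DoWithRetries,
+// defined below, keeps calling fn until it succeeds or the configured number of
+// attempts is exhausted, and returns the collected errors if every attempt fails.
+// pingUpstream is a hypothetical caller-side function used purely for illustration:
+//
+//	err := DoWithRetries(func() error {
+//		return pingUpstream()
+//	}, WithRetries(5))
+//	if err != nil {
+//		// all 5 attempts failed; err aggregates each attempt's error
+//	}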
+import "zero/core/errorx" + +const defaultRetryTimes = 3 + +type ( + RetryOption func(*retryOptions) + + retryOptions struct { + times int + } +) + +func DoWithRetries(fn func() error, opts ...RetryOption) error { + var options = newRetryOptions() + for _, opt := range opts { + opt(options) + } + + var berr errorx.BatchError + for i := 0; i < options.times; i++ { + if err := fn(); err != nil { + berr.Add(err) + } else { + return nil + } + } + + return berr.Err() +} + +func WithRetries(times int) RetryOption { + return func(options *retryOptions) { + options.times = times + } +} + +func newRetryOptions() *retryOptions { + return &retryOptions{ + times: defaultRetryTimes, + } +} diff --git a/core/fx/retry_test.go b/core/fx/retry_test.go new file mode 100644 index 00000000..73f8b408 --- /dev/null +++ b/core/fx/retry_test.go @@ -0,0 +1,42 @@ +package fx + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRetry(t *testing.T) { + assert.NotNil(t, DoWithRetries(func() error { + return errors.New("any") + })) + + var times int + assert.Nil(t, DoWithRetries(func() error { + times++ + if times == defaultRetryTimes { + return nil + } + return errors.New("any") + })) + + times = 0 + assert.NotNil(t, DoWithRetries(func() error { + times++ + if times == defaultRetryTimes+1 { + return nil + } + return errors.New("any") + })) + + var total = 2 * defaultRetryTimes + times = 0 + assert.Nil(t, DoWithRetries(func() error { + times++ + if times == total { + return nil + } + return errors.New("any") + }, WithRetries(total))) +} diff --git a/core/fx/timeout.go b/core/fx/timeout.go new file mode 100644 index 00000000..2cd1b6c7 --- /dev/null +++ b/core/fx/timeout.go @@ -0,0 +1,49 @@ +package fx + +import ( + "context" + "time" +) + +var ( + ErrCanceled = context.Canceled + ErrTimeout = context.DeadlineExceeded +) + +type FxOption func() context.Context + +func DoWithTimeout(fn func() error, timeout time.Duration, opts ...FxOption) error { + parentCtx := context.Background() + for _, opt := range opts { + parentCtx = opt() + } + ctx, cancel := context.WithTimeout(parentCtx, timeout) + defer cancel() + + done := make(chan error) + panicChan := make(chan interface{}, 1) + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + done <- fn() + close(done) + }() + + select { + case p := <-panicChan: + panic(p) + case err := <-done: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +func WithContext(ctx context.Context) FxOption { + return func() context.Context { + return ctx + } +} diff --git a/core/fx/timeout_test.go b/core/fx/timeout_test.go new file mode 100644 index 00000000..bb944ac9 --- /dev/null +++ b/core/fx/timeout_test.go @@ -0,0 +1,43 @@ +package fx + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestWithPanic(t *testing.T) { + assert.Panics(t, func() { + _ = DoWithTimeout(func() error { + panic("hello") + }, time.Millisecond*50) + }) +} + +func TestWithTimeout(t *testing.T) { + assert.Equal(t, ErrTimeout, DoWithTimeout(func() error { + time.Sleep(time.Millisecond * 50) + return nil + }, time.Millisecond)) +} + +func TestWithoutTimeout(t *testing.T) { + assert.Nil(t, DoWithTimeout(func() error { + return nil + }, time.Millisecond*50)) +} + +func TestWithCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + err := DoWithTimeout(func() error { + time.Sleep(time.Minute) + return nil + 
}, time.Second, WithContext(ctx)) + assert.Equal(t, ErrCanceled, err) +} diff --git a/core/hash/consistenthash.go b/core/hash/consistenthash.go new file mode 100644 index 00000000..ece56559 --- /dev/null +++ b/core/hash/consistenthash.go @@ -0,0 +1,180 @@ +package hash + +import ( + "fmt" + "sort" + "strconv" + "sync" + + "zero/core/lang" + "zero/core/mapping" +) + +const ( + TopWeight = 100 + + minReplicas = 100 + prime = 16777619 +) + +type ( + HashFunc func(data []byte) uint64 + + ConsistentHash struct { + hashFunc HashFunc + replicas int + keys []uint64 + ring map[uint64][]interface{} + nodes map[string]lang.PlaceholderType + lock sync.RWMutex + } +) + +func NewConsistentHash() *ConsistentHash { + return NewCustomConsistentHash(minReplicas, Hash) +} + +func NewCustomConsistentHash(replicas int, fn HashFunc) *ConsistentHash { + if replicas < minReplicas { + replicas = minReplicas + } + + if fn == nil { + fn = Hash + } + + return &ConsistentHash{ + hashFunc: fn, + replicas: replicas, + ring: make(map[uint64][]interface{}), + nodes: make(map[string]lang.PlaceholderType), + } +} + +// Add adds the node with the number of h.replicas, +// the later call will overwrite the replicas of the former calls. +func (h *ConsistentHash) Add(node interface{}) { + h.AddWithReplicas(node, h.replicas) +} + +// AddWithReplicas adds the node with the number of replicas, +// replicas will be truncated to h.replicas if it's larger than h.replicas, +// the later call will overwrite the replicas of the former calls. +func (h *ConsistentHash) AddWithReplicas(node interface{}, replicas int) { + h.Remove(node) + + if replicas > h.replicas { + replicas = h.replicas + } + + nodeRepr := repr(node) + h.lock.Lock() + defer h.lock.Unlock() + h.addNode(nodeRepr) + + for i := 0; i < replicas; i++ { + hash := h.hashFunc([]byte(nodeRepr + strconv.Itoa(i))) + h.keys = append(h.keys, hash) + h.ring[hash] = append(h.ring[hash], node) + } + + sort.Slice(h.keys, func(i int, j int) bool { + return h.keys[i] < h.keys[j] + }) +} + +// AddWithWeight adds the node with weight, the weight can be 1 to 100, indicates the percent, +// the later call will overwrite the replicas of the former calls. +func (h *ConsistentHash) AddWithWeight(node interface{}, weight int) { + // don't need to make sure weight not larger than TopWeight, + // because AddWithReplicas makes sure replicas cannot be larger than h.replicas + replicas := h.replicas * weight / TopWeight + h.AddWithReplicas(node, replicas) +} + +func (h *ConsistentHash) Get(v interface{}) (interface{}, bool) { + h.lock.RLock() + defer h.lock.RUnlock() + + if len(h.ring) == 0 { + return nil, false + } + + hash := h.hashFunc([]byte(repr(v))) + index := sort.Search(len(h.keys), func(i int) bool { + return h.keys[i] >= hash + }) % len(h.keys) + + nodes := h.ring[h.keys[index]] + switch len(nodes) { + case 0: + return nil, false + case 1: + return nodes[0], true + default: + innerIndex := h.hashFunc([]byte(innerRepr(v))) + pos := int(innerIndex % uint64(len(nodes))) + return nodes[pos], true + } +} + +func (h *ConsistentHash) Remove(node interface{}) { + nodeRepr := repr(node) + + h.lock.Lock() + defer h.lock.Unlock() + + if !h.containsNode(nodeRepr) { + return + } + + for i := 0; i < h.replicas; i++ { + hash := h.hashFunc([]byte(nodeRepr + strconv.Itoa(i))) + index := sort.Search(len(h.keys), func(i int) bool { + return h.keys[i] >= hash + }) + if index < len(h.keys) { + h.keys = append(h.keys[:index], h.keys[index+1:]...) 
+ } + h.removeRingNode(hash, nodeRepr) + } + + h.removeNode(nodeRepr) +} + +func (h *ConsistentHash) removeRingNode(hash uint64, nodeRepr string) { + if nodes, ok := h.ring[hash]; ok { + newNodes := nodes[:0] + for _, x := range nodes { + if repr(x) != nodeRepr { + newNodes = append(newNodes, x) + } + } + if len(newNodes) > 0 { + h.ring[hash] = newNodes + } else { + delete(h.ring, hash) + } + } +} + +func (h *ConsistentHash) addNode(nodeRepr string) { + h.nodes[nodeRepr] = lang.Placeholder +} + +func (h *ConsistentHash) containsNode(nodeRepr string) bool { + _, ok := h.nodes[nodeRepr] + return ok +} + +func (h *ConsistentHash) removeNode(nodeRepr string) { + delete(h.nodes, nodeRepr) +} + +func innerRepr(node interface{}) string { + return fmt.Sprintf("%d:%v", prime, node) +} + +func repr(node interface{}) string { + return mapping.Repr(node) +} diff --git a/core/hash/consistenthash_test.go b/core/hash/consistenthash_test.go new file mode 100644 index 00000000..de344c63 --- /dev/null +++ b/core/hash/consistenthash_test.go @@ -0,0 +1,182 @@ +package hash + +import ( + "fmt" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "zero/core/mathx" +) + +const ( + keySize = 20 + requestSize = 1000 +) + +func BenchmarkConsistentHashGet(b *testing.B) { + ch := NewConsistentHash() + for i := 0; i < keySize; i++ { + ch.Add("localhost:" + strconv.Itoa(i)) + } + + for i := 0; i < b.N; i++ { + ch.Get(i) + } +} + +func TestConsistentHash(t *testing.T) { + ch := NewCustomConsistentHash(0, nil) + val, ok := ch.Get("any") + assert.False(t, ok) + assert.Nil(t, val) + + for i := 0; i < keySize; i++ { + ch.AddWithReplicas("localhost:"+strconv.Itoa(i), minReplicas<<1) + } + + keys := make(map[string]int) + for i := 0; i < requestSize; i++ { + key, ok := ch.Get(requestSize + i) + assert.True(t, ok) + keys[key.(string)]++ + } + + mi := make(map[interface{}]int, len(keys)) + for k, v := range keys { + mi[k] = v + } + entropy := mathx.CalcEntropy(mi, requestSize) + assert.True(t, entropy > .95) +} + +func TestConsistentHashIncrementalTransfer(t *testing.T) { + prefix := "anything" + create := func() *ConsistentHash { + ch := NewConsistentHash() + for i := 0; i < keySize; i++ { + ch.Add(prefix + strconv.Itoa(i)) + } + return ch + } + + originCh := create() + keys := make(map[int]string, requestSize) + for i := 0; i < requestSize; i++ { + key, ok := originCh.Get(requestSize + i) + assert.True(t, ok) + assert.NotNil(t, key) + keys[i] = key.(string) + } + + node := fmt.Sprintf("%s%d", prefix, keySize) + for i := 0; i < 10; i++ { + laterCh := create() + laterCh.AddWithWeight(node, 10*(i+1)) + + for i := 0; i < requestSize; i++ { + key, ok := laterCh.Get(requestSize + i) + assert.True(t, ok) + assert.NotNil(t, key) + value := key.(string) + assert.True(t, value == keys[i] || value == node) + } + } +} + +func TestConsistentHashTransferOnFailure(t *testing.T) { + index := 41 + keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index) + var transferred int + for k, v := range newKeys { + if v != keys[k] { + transferred++ + } + } + + ratio := float32(transferred) / float32(requestSize) + assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio)) +} + +func TestConsistentHashLeastTransferOnFailure(t *testing.T) { + prefix := "localhost:" + index := 41 + keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index) + for k, v := range keys { + newV := newKeys[k] + if v != prefix+strconv.Itoa(index) { + assert.Equal(t, v, newV) + } + } +} + +func TestConsistentHash_Remove(t 
*testing.T) { + ch := NewConsistentHash() + ch.Add("first") + ch.Add("second") + ch.Remove("first") + for i := 0; i < 100; i++ { + val, ok := ch.Get(i) + assert.True(t, ok) + assert.Equal(t, "second", val) + } +} + +func TestConsistentHash_RemoveInterface(t *testing.T) { + const key = "any" + ch := NewConsistentHash() + node1 := newMockNode(key, 1) + node2 := newMockNode(key, 2) + ch.AddWithWeight(node1, 80) + ch.AddWithWeight(node2, 50) + assert.Equal(t, 1, len(ch.nodes)) + node, ok := ch.Get(1) + assert.True(t, ok) + assert.Equal(t, key, node.(*MockNode).Addr) + assert.Equal(t, 2, node.(*MockNode).Id) +} + +func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[int]string, map[int]string) { + ch := NewConsistentHash() + for i := 0; i < keySize; i++ { + ch.Add(prefix + strconv.Itoa(i)) + } + + keys := make(map[int]string, requestSize) + for i := 0; i < requestSize; i++ { + key, ok := ch.Get(requestSize + i) + assert.True(t, ok) + assert.NotNil(t, key) + keys[i] = key.(string) + } + + remove := fmt.Sprintf("%s%d", prefix, index) + ch.Remove(remove) + newKeys := make(map[int]string, requestSize) + for i := 0; i < requestSize; i++ { + key, ok := ch.Get(requestSize + i) + assert.True(t, ok) + assert.NotNil(t, key) + assert.NotEqual(t, remove, key) + newKeys[i] = key.(string) + } + + return keys, newKeys +} + +type MockNode struct { + Addr string + Id int +} + +func newMockNode(addr string, id int) *MockNode { + return &MockNode{ + Addr: addr, + Id: id, + } +} + +func (n *MockNode) String() string { + return n.Addr +} diff --git a/core/hash/hash.go b/core/hash/hash.go new file mode 100644 index 00000000..3cc562fd --- /dev/null +++ b/core/hash/hash.go @@ -0,0 +1,22 @@ +package hash + +import ( + "crypto/md5" + "fmt" + + "github.com/spaolacci/murmur3" +) + +func Hash(data []byte) uint64 { + return murmur3.Sum64(data) +} + +func Md5(data []byte) []byte { + digest := md5.New() + digest.Write(data) + return digest.Sum(nil) +} + +func Md5Hex(data []byte) string { + return fmt.Sprintf("%x", Md5(data)) +} diff --git a/core/hash/hash_test.go b/core/hash/hash_test.go new file mode 100644 index 00000000..02a760ca --- /dev/null +++ b/core/hash/hash_test.go @@ -0,0 +1,42 @@ +package hash + +import ( + "crypto/md5" + "fmt" + "hash/fnv" + "math/big" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + text = "hello, world!\n" + md5Digest = "910c8bc73110b0cd1bc5d2bcae782511" +) + +func TestMd5(t *testing.T) { + actual := fmt.Sprintf("%x", Md5([]byte(text))) + assert.Equal(t, md5Digest, actual) +} + +func BenchmarkHashFnv(b *testing.B) { + for i := 0; i < b.N; i++ { + h := fnv.New32() + new(big.Int).SetBytes(h.Sum([]byte(text))).Int64() + } +} + +func BenchmarkHashMd5(b *testing.B) { + for i := 0; i < b.N; i++ { + h := md5.New() + bytes := h.Sum([]byte(text)) + new(big.Int).SetBytes(bytes).Int64() + } +} + +func BenchmarkMurmur3(b *testing.B) { + for i := 0; i < b.N; i++ { + Hash([]byte(text)) + } +} diff --git a/core/httphandler/authhandler.go b/core/httphandler/authhandler.go new file mode 100644 index 00000000..5e9b3e75 --- /dev/null +++ b/core/httphandler/authhandler.go @@ -0,0 +1,129 @@ +package httphandler + +import ( + "context" + "net/http" + "net/http/httputil" + + "zero/core/httpsecurity" + "zero/core/logx" + + "github.com/dgrijalva/jwt-go" +) + +const ( + jwtAudience = "aud" + jwtExpire = "exp" + jwtId = "jti" + jwtIssueAt = "iat" + jwtIssuer = "iss" + jwtNotBefore = "nbf" + jwtSubject = "sub" +) + +type ( + AuthorizeOptions struct { + PrevSecret string + 
Callback UnauthorizedCallback + } + + UnauthorizedCallback func(w http.ResponseWriter, r *http.Request, err error) + AuthorizeOption func(opts *AuthorizeOptions) +) + +func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.Handler { + var authOpts AuthorizeOptions + for _, opt := range opts { + opt(&authOpts) + } + + parser := httpsecurity.NewTokenParser() + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token, err := parser.ParseToken(r, secret, authOpts.PrevSecret) + if err != nil { + unauthorized(w, r, err, authOpts.Callback) + return + } + + if !token.Valid { + unauthorized(w, r, err, authOpts.Callback) + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + unauthorized(w, r, err, authOpts.Callback) + return + } + + ctx := r.Context() + for k, v := range claims { + switch k { + case jwtAudience, jwtExpire, jwtId, jwtIssueAt, jwtIssuer, jwtNotBefore, jwtSubject: + // ignore the standard claims + default: + ctx = context.WithValue(ctx, k, v) + } + } + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func WithPrevSecret(secret string) AuthorizeOption { + return func(opts *AuthorizeOptions) { + opts.PrevSecret = secret + } +} + +func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption { + return func(opts *AuthorizeOptions) { + opts.Callback = callback + } +} + +func detailAuthLog(r *http.Request, reason string) { + // discard dump error, only for debug purpose + details, _ := httputil.DumpRequest(r, true) + logx.Errorf("authorize failed: %s\n=> %+v", reason, string(details)) +} + +func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) { + writer := newGuardedResponseWriter(w) + + detailAuthLog(r, err.Error()) + if callback != nil { + callback(writer, r, err) + } + writer.WriteHeader(http.StatusUnauthorized) +} + +type guardedResponseWriter struct { + writer http.ResponseWriter + wroteHeader bool +} + +func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter { + return &guardedResponseWriter{ + writer: w, + } +} + +func (grw *guardedResponseWriter) Header() http.Header { + return grw.writer.Header() +} + +func (grw *guardedResponseWriter) Write(body []byte) (int, error) { + return grw.writer.Write(body) +} + +func (grw *guardedResponseWriter) WriteHeader(statusCode int) { + if grw.wroteHeader { + return + } + + grw.wroteHeader = true + grw.writer.WriteHeader(statusCode) +} diff --git a/core/httphandler/authhandler_test.go b/core/httphandler/authhandler_test.go new file mode 100644 index 00000000..7e89a875 --- /dev/null +++ b/core/httphandler/authhandler_test.go @@ -0,0 +1,91 @@ +package httphandler + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/dgrijalva/jwt-go" + "github.com/stretchr/testify/assert" +) + +func TestAuthHandlerFailed(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + handler := Authorize("B63F477D-BBA3-4E52-96D3-C0034C27694A", WithUnauthorizedCallback( + func(w http.ResponseWriter, r *http.Request, err error) { + w.Header().Set("X-Test", "test") + w.WriteHeader(http.StatusUnauthorized) + _, err = w.Write([]byte("content")) + assert.Nil(t, err) + }))( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusUnauthorized, resp.Code) +} + +func TestAuthHandler(t 
*testing.T) { + const key = "B63F477D-BBA3-4E52-96D3-C0034C27694A" + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + token, err := buildToken(key, map[string]interface{}{ + "key": "value", + }, 3600) + assert.Nil(t, err) + req.Header.Set("Authorization", "Bearer "+token) + handler := Authorize(key)( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "test") + _, err := w.Write([]byte("content")) + assert.Nil(t, err) + })) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "test", resp.Header().Get("X-Test")) + assert.Equal(t, "content", resp.Body.String()) +} + +func TestAuthHandlerWithPrevSecret(t *testing.T) { + const ( + key = "14F17379-EB8F-411B-8F12-6929002DCA76" + prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A" + ) + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + token, err := buildToken(key, map[string]interface{}{ + "key": "value", + }, 3600) + assert.Nil(t, err) + req.Header.Set("Authorization", "Bearer "+token) + handler := Authorize(key, WithPrevSecret(prevKey))( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "test") + _, err := w.Write([]byte("content")) + assert.Nil(t, err) + })) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "test", resp.Header().Get("X-Test")) + assert.Equal(t, "content", resp.Body.String()) +} + +func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) { + now := time.Now().Unix() + claims := make(jwt.MapClaims) + claims["exp"] = now + seconds + claims["iat"] = now + for k, v := range payloads { + claims[k] = v + } + + token := jwt.New(jwt.SigningMethodHS256) + token.Claims = claims + + return token.SignedString([]byte(secretKey)) +} diff --git a/core/httphandler/breakerhandler.go b/core/httphandler/breakerhandler.go new file mode 100644 index 00000000..9c1325b2 --- /dev/null +++ b/core/httphandler/breakerhandler.go @@ -0,0 +1,41 @@ +package httphandler + +import ( + "fmt" + "net/http" + "strings" + + "zero/core/breaker" + "zero/core/httphandler/internal" + "zero/core/httpx" + "zero/core/logx" + "zero/core/stat" +) + +const breakerSeparator = "://" + +func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handler) http.Handler { + brk := breaker.NewBreaker(breaker.WithName(strings.Join([]string{method, path}, breakerSeparator))) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + promise, err := brk.Allow() + if err != nil { + metrics.AddDrop() + logx.Errorf("[http] dropped, %s - %s - %s", + r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent()) + w.WriteHeader(http.StatusServiceUnavailable) + return + } + + cw := &internal.WithCodeResponseWriter{Writer: w} + defer func() { + if cw.Code < http.StatusInternalServerError { + promise.Accept() + } else { + promise.Reject(fmt.Sprintf("%d %s", cw.Code, http.StatusText(cw.Code))) + } + }() + next.ServeHTTP(cw, r) + }) + } +} diff --git a/core/httphandler/breakerhandler_test.go b/core/httphandler/breakerhandler_test.go new file mode 100644 index 00000000..230e2231 --- /dev/null +++ b/core/httphandler/breakerhandler_test.go @@ -0,0 +1,102 @@ +package httphandler + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "zero/core/logx" + "zero/core/stat" + + 
"github.com/stretchr/testify/assert" +) + +func init() { + logx.Disable() + stat.SetReporter(nil) +} + +func TestBreakerHandlerAccept(t *testing.T) { + metrics := stat.NewMetrics("unit-test") + breakerHandler := BreakerHandler(http.MethodGet, "/", metrics) + handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "test") + _, err := w.Write([]byte("content")) + assert.Nil(t, err) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set("X-Test", "test") + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "test", resp.Header().Get("X-Test")) + assert.Equal(t, "content", resp.Body.String()) +} + +func TestBreakerHandlerFail(t *testing.T) { + metrics := stat.NewMetrics("unit-test") + breakerHandler := BreakerHandler(http.MethodGet, "/", metrics) + handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusBadGateway, resp.Code) +} + +func TestBreakerHandler_4XX(t *testing.T) { + metrics := stat.NewMetrics("unit-test") + breakerHandler := BreakerHandler(http.MethodGet, "/", metrics) + handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + })) + + for i := 0; i < 1000; i++ { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + } + + const tries = 100 + var pass int + for i := 0; i < tries; i++ { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + if resp.Code == http.StatusBadRequest { + pass++ + } + } + + assert.Equal(t, tries, pass) +} + +func TestBreakerHandlerReject(t *testing.T) { + metrics := stat.NewMetrics("unit-test") + breakerHandler := BreakerHandler(http.MethodGet, "/", metrics) + handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + + for i := 0; i < 1000; i++ { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + } + + var drops int + for i := 0; i < 100; i++ { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + if resp.Code == http.StatusServiceUnavailable { + drops++ + } + } + + assert.True(t, drops >= 80, fmt.Sprintf("expected to be greater than 80, but got %d", drops)) +} diff --git a/core/httphandler/contentsecurityhandler.go b/core/httphandler/contentsecurityhandler.go new file mode 100644 index 00000000..404fc6cb --- /dev/null +++ b/core/httphandler/contentsecurityhandler.go @@ -0,0 +1,61 @@ +package httphandler + +import ( + "net/http" + "time" + + "zero/core/codec" + "zero/core/httphandler/internal" + "zero/core/httpx" + "zero/core/logx" +) + +const contentSecurity = "X-Content-Security" + +type UnsignedCallback func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) + +func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration, + strict bool, callbacks ...UnsignedCallback) func(http.Handler) http.Handler { + 
if len(callbacks) == 0 { + callbacks = append(callbacks, handleVerificationFailure) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut: + header, err := internal.ParseContentSecurity(decrypters, r) + if err != nil { + logx.Infof("Signature parse failed, X-Content-Security: %s, error: %s", + r.Header.Get(contentSecurity), err.Error()) + executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks) + } else if code := internal.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass { + logx.Infof("Signature verification failed, X-Content-Security: %s", + r.Header.Get(contentSecurity)) + executeCallbacks(w, r, next, strict, code, callbacks) + } else if r.ContentLength > 0 && header.Encrypted() { + CryptionHandler(header.Key)(next).ServeHTTP(w, r) + } else { + next.ServeHTTP(w, r) + } + default: + next.ServeHTTP(w, r) + } + }) + } +} + +func executeCallbacks(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, + code int, callbacks []UnsignedCallback) { + for _, callback := range callbacks { + callback(w, r, next, strict, code) + } +} + +func handleVerificationFailure(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) { + if strict { + w.WriteHeader(http.StatusUnauthorized) + } else { + next.ServeHTTP(w, r) + } +} diff --git a/core/httphandler/contentsecurityhandler_test.go b/core/httphandler/contentsecurityhandler_test.go new file mode 100644 index 00000000..b22dfeff --- /dev/null +++ b/core/httphandler/contentsecurityhandler_test.go @@ -0,0 +1,388 @@ +package httphandler + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strconv" + "strings" + "testing" + "time" + + "zero/core/codec" + "zero/core/httpx" + + "github.com/stretchr/testify/assert" +) + +const timeDiff = time.Hour * 2 * 24 + +var ( + fingerprint = "12345" + pubKey = []byte(`-----BEGIN PUBLIC KEY----- +MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE +eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH +miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR +my47YlhspwszKdRP+wIDAQAB +-----END PUBLIC KEY-----`) + priKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQD7bq4FLG0ctccbEFEsUBuRxkjEeJ5U+0CAEjJk20V9/u2Fu76i +1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVHmiYbRgh5Fy6336KepLCtCmV/ +r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwRmy47YlhspwszKdRP+wIDAQAB +AoGBANs1qf7UtuSbD1ZnKX5K8V5s07CHwPMygw+lzc3k5ndtNUStZQ2vnAaBXHyH +Nm4lJ4AI2mhQ39jQB/1TyP1uAzvpLhT60fRybEq9zgJ/81Gm9bnaEpFJ9bP2bBrY +J0jbaTMfbzL/PJFl3J3RGMR40C76h5yRYSnOpMoMiKWnJqrhAkEA/zCOkR+34Pk0 +Yo3sIP4ranY6AAvwacgNaui4ll5xeYwv3iLOQvPlpxIxFHKXEY0klNNyjjXqgYjP +cOenqtt6UwJBAPw7EYuteVHvHvQVuTbKAaYHcOrp4nFeZF3ndFfl0w2dwGhfzcXO +ROyd5dNQCuCWRo8JBpjG6PFyzezayF4KLrkCQCGditoxHG7FRRJKcbVy5dMzWbaR +3AyDLslLeK1OKZKCVffkC9mj+TeF3PM9mQrV1eDI7ckv7wE7PWA5E8wc90MCQEOV +MCZU3OTvRUPxbicYCUkLRV4sPNhTimD+21WR5vMHCb7trJ0Ln7wmsqXkFIYIve8l +Y/cblN7c/AAyvu0znUECQA318nPldsxR6+H8HTS3uEbkL4UJdjQJHsvTwKxAw5qc +moKExvRlN0zmGGuArKcqS38KG7PXZMrUv3FXPdp6BDQ= +-----END RSA PRIVATE KEY-----`) + key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D") +) + +type requestSettings struct { + method string + url string + body io.Reader + strict bool + crypt bool + requestUri string + timestamp int64 + fingerprint string + missHeader 
bool + signature string +} + +func init() { + log.SetOutput(ioutil.Discard) +} + +func TestContentSecurityHandler(t *testing.T) { + tests := []struct { + method string + url string + body string + strict bool + crypt bool + requestUri string + timestamp int64 + fingerprint string + missHeader bool + signature string + statusCode int + }{ + { + method: http.MethodGet, + url: "http://localhost/a/b?c=d&e=f", + strict: true, + crypt: false, + }, + { + method: http.MethodPost, + url: "http://localhost/a/b?c=d&e=f", + body: "hello", + strict: true, + crypt: false, + }, + { + method: http.MethodGet, + url: "http://localhost/a/b?c=d&e=f", + strict: true, + crypt: true, + }, + { + method: http.MethodPost, + url: "http://localhost/a/b?c=d&e=f", + body: "hello", + strict: true, + crypt: true, + }, + { + method: http.MethodGet, + url: "http://localhost/a/b?c=d&e=f", + strict: true, + crypt: true, + timestamp: time.Now().Add(timeDiff).Unix(), + statusCode: http.StatusUnauthorized, + }, + { + method: http.MethodPost, + url: "http://localhost/a/b?c=d&e=f", + body: "hello", + strict: true, + crypt: true, + timestamp: time.Now().Add(-timeDiff).Unix(), + statusCode: http.StatusUnauthorized, + }, + { + method: http.MethodPost, + url: "http://remotehost/", + body: "hello", + strict: true, + crypt: true, + requestUri: "http://localhost/a/b?c=d&e=f", + }, + { + method: http.MethodPost, + url: "http://localhost/a/b?c=d&e=f", + body: "hello", + strict: false, + crypt: true, + fingerprint: "badone", + }, + { + method: http.MethodPost, + url: "http://localhost/a/b?c=d&e=f", + body: "hello", + strict: true, + crypt: true, + timestamp: time.Now().Add(-timeDiff).Unix(), + fingerprint: "badone", + statusCode: http.StatusUnauthorized, + }, + { + method: http.MethodPost, + url: "http://localhost/a/b?c=d&e=f", + body: "hello", + strict: true, + crypt: true, + missHeader: true, + statusCode: http.StatusUnauthorized, + }, + { + method: http.MethodHead, + url: "http://localhost/a/b?c=d&e=f", + strict: true, + crypt: false, + }, + { + method: http.MethodGet, + url: "http://localhost/a/b?c=d&e=f", + strict: true, + crypt: false, + signature: "badone", + statusCode: http.StatusUnauthorized, + }, + } + + for _, test := range tests { + t.Run(test.url, func(t *testing.T) { + if test.statusCode == 0 { + test.statusCode = http.StatusOK + } + if len(test.fingerprint) == 0 { + test.fingerprint = fingerprint + } + if test.timestamp == 0 { + test.timestamp = time.Now().Unix() + } + + func() { + keyFile, err := createTempFile(priKey) + defer os.Remove(keyFile) + + assert.Nil(t, err) + decrypter, err := codec.NewRsaDecrypter(keyFile) + assert.Nil(t, err) + contentSecurityHandler := ContentSecurityHandler(map[string]codec.RsaDecrypter{ + fingerprint: decrypter, + }, time.Hour, test.strict) + handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + + var reader io.Reader + if len(test.body) > 0 { + reader = strings.NewReader(test.body) + } + setting := requestSettings{ + method: test.method, + url: test.url, + body: reader, + strict: test.strict, + crypt: test.crypt, + requestUri: test.requestUri, + timestamp: test.timestamp, + fingerprint: test.fingerprint, + missHeader: test.missHeader, + signature: test.signature, + } + req, err := buildRequest(setting) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, test.statusCode, resp.Code) + }() + }) + } +} + +func TestContentSecurityHandler_UnsignedCallback(t *testing.T) { + keyFile, err := createTempFile(priKey) + 
defer os.Remove(keyFile) + + assert.Nil(t, err) + decrypter, err := codec.NewRsaDecrypter(keyFile) + assert.Nil(t, err) + contentSecurityHandler := ContentSecurityHandler( + map[string]codec.RsaDecrypter{ + fingerprint: decrypter, + }, + time.Hour, + true, + func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) { + w.WriteHeader(http.StatusOK) + }) + handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + setting := requestSettings{ + method: http.MethodGet, + url: "http://localhost/a/b?c=d&e=f", + signature: "badone", + } + req, err := buildRequest(setting) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} + +func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) { + keyFile, err := createTempFile(priKey) + defer os.Remove(keyFile) + + assert.Nil(t, err) + decrypter, err := codec.NewRsaDecrypter(keyFile) + assert.Nil(t, err) + contentSecurityHandler := ContentSecurityHandler( + map[string]codec.RsaDecrypter{ + fingerprint: decrypter, + }, + time.Hour, + true, + func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) { + assert.Equal(t, httpx.CodeSignatureWrongTime, code) + w.WriteHeader(http.StatusOK) + }) + handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + var reader io.Reader + reader = strings.NewReader("hello") + setting := requestSettings{ + method: http.MethodPost, + url: "http://localhost/a/b?c=d&e=f", + body: reader, + strict: true, + crypt: true, + timestamp: time.Now().Add(time.Hour * 24 * 365).Unix(), + fingerprint: fingerprint, + } + req, err := buildRequest(setting) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} + +func buildRequest(rs requestSettings) (*http.Request, error) { + var bodyStr string + var err error + + if rs.crypt && rs.body != nil { + var buf bytes.Buffer + io.Copy(&buf, rs.body) + bodyBytes, err := codec.EcbEncrypt(key, buf.Bytes()) + if err != nil { + return nil, err + } + bodyStr = base64.StdEncoding.EncodeToString(bodyBytes) + } + + r := httptest.NewRequest(rs.method, rs.url, strings.NewReader(bodyStr)) + if len(rs.signature) == 0 { + sha := sha256.New() + sha.Write([]byte(bodyStr)) + bodySign := fmt.Sprintf("%x", sha.Sum(nil)) + var path string + var query string + if len(rs.requestUri) > 0 { + if u, err := url.Parse(rs.requestUri); err != nil { + return nil, err + } else { + path = u.Path + query = u.RawQuery + } + } else { + path = r.URL.Path + query = r.URL.RawQuery + } + contentOfSign := strings.Join([]string{ + strconv.FormatInt(rs.timestamp, 10), + rs.method, + path, + query, + bodySign, + }, "\n") + rs.signature = codec.HmacBase64([]byte(key), contentOfSign) + } + + var mode string + if rs.crypt { + mode = "1" + } else { + mode = "0" + } + content := strings.Join([]string{ + "version=v1", + "type=" + mode, + fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)), + "time=" + strconv.FormatInt(rs.timestamp, 10), + }, "; ") + + encrypter, err := codec.NewRsaEncrypter([]byte(pubKey)) + if err != nil { + log.Fatal(err) + } + + output, err := encrypter.Encrypt([]byte(content)) + if err != nil { + log.Fatal(err) + } + + encryptedContent := base64.StdEncoding.EncodeToString(output) + if !rs.missHeader { + r.Header.Set(httpx.ContentSecurity, strings.Join([]string{ + fmt.Sprintf("key=%s", rs.fingerprint), + "secret=" + encryptedContent, + 
"signature=" + rs.signature, + }, "; ")) + } + if len(rs.requestUri) > 0 { + r.Header.Set("X-Request-Uri", rs.requestUri) + } + + return r, nil +} + +func createTempFile(body []byte) (string, error) { + tmpFile, err := ioutil.TempFile(os.TempDir(), "go-unit-*.tmp") + if err != nil { + return "", err + } else { + tmpFile.Close() + } + + err = ioutil.WriteFile(tmpFile.Name(), body, os.ModePerm) + if err != nil { + return "", err + } + + return tmpFile.Name(), nil +} diff --git a/core/httphandler/cryptionhandler.go b/core/httphandler/cryptionhandler.go new file mode 100644 index 00000000..7aeccd5f --- /dev/null +++ b/core/httphandler/cryptionhandler.go @@ -0,0 +1,101 @@ +package httphandler + +import ( + "bytes" + "encoding/base64" + "io" + "io/ioutil" + "net/http" + + "zero/core/codec" + "zero/core/logx" +) + +const maxBytes = 1 << 20 // 1 MiB + +func CryptionHandler(key []byte) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cw := newCryptionResponseWriter(w) + defer cw.flush(key) + + if r.ContentLength <= 0 { + next.ServeHTTP(cw, r) + return + } + + if err := decryptBody(key, r); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + next.ServeHTTP(cw, r) + }) + } +} + +func decryptBody(key []byte, r *http.Request) error { + content, err := ioutil.ReadAll(io.LimitReader(r.Body, maxBytes)) + if err != nil { + return err + } + + content, err = base64.StdEncoding.DecodeString(string(content)) + if err != nil { + return err + } + + output, err := codec.EcbDecrypt(key, content) + if err != nil { + return err + } + + var buf bytes.Buffer + buf.Write(output) + r.Body = ioutil.NopCloser(&buf) + + return nil +} + +type cryptionResponseWriter struct { + http.ResponseWriter + buf *bytes.Buffer +} + +func newCryptionResponseWriter(w http.ResponseWriter) *cryptionResponseWriter { + return &cryptionResponseWriter{ + ResponseWriter: w, + buf: new(bytes.Buffer), + } +} + +func (w *cryptionResponseWriter) Header() http.Header { + return w.ResponseWriter.Header() +} + +func (w *cryptionResponseWriter) Write(p []byte) (int, error) { + return w.buf.Write(p) +} + +func (w *cryptionResponseWriter) WriteHeader(statusCode int) { + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *cryptionResponseWriter) flush(key []byte) { + if w.buf.Len() == 0 { + return + } + + content, err := codec.EcbEncrypt(key, w.buf.Bytes()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + body := base64.StdEncoding.EncodeToString(content) + if n, err := io.WriteString(w.ResponseWriter, body); err != nil { + logx.Errorf("write response failed, error: %s", err) + } else if n < len(content) { + logx.Errorf("actual bytes: %d, written bytes: %d", len(content), n) + } +} diff --git a/core/httphandler/cryptionhandler_test.go b/core/httphandler/cryptionhandler_test.go new file mode 100644 index 00000000..caaeaa51 --- /dev/null +++ b/core/httphandler/cryptionhandler_test.go @@ -0,0 +1,90 @@ +package httphandler + +import ( + "bytes" + "encoding/base64" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "testing" + + "zero/core/codec" + + "github.com/stretchr/testify/assert" +) + +const ( + reqText = "ping" + respText = "pong" +) + +var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`) + +func init() { + log.SetOutput(ioutil.Discard) +} + +func TestCryptionHandlerGet(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/any", nil) + handler := 
CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte(respText)) + w.Header().Set("X-Test", "test") + assert.Nil(t, err) + })) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + expect, err := codec.EcbEncrypt(aesKey, []byte(respText)) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, "test", recorder.Header().Get("X-Test")) + assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String()) +} + +func TestCryptionHandlerPost(t *testing.T) { + var buf bytes.Buffer + enc, err := codec.EcbEncrypt(aesKey, []byte(reqText)) + assert.Nil(t, err) + buf.WriteString(base64.StdEncoding.EncodeToString(enc)) + + req := httptest.NewRequest(http.MethodPost, "/any", &buf) + handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + assert.Nil(t, err) + assert.Equal(t, reqText, string(body)) + + w.Write([]byte(respText)) + })) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + expect, err := codec.EcbEncrypt(aesKey, []byte(respText)) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String()) +} + +func TestCryptionHandlerPostBadEncryption(t *testing.T) { + var buf bytes.Buffer + enc, err := codec.EcbEncrypt(aesKey, []byte(reqText)) + assert.Nil(t, err) + buf.Write(enc) + + req := httptest.NewRequest(http.MethodPost, "/any", &buf) + handler := CryptionHandler(aesKey)(nil) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestCryptionHandlerWriteHeader(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/any", nil) + handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusServiceUnavailable, recorder.Code) +} diff --git a/core/httphandler/gunziphandler.go b/core/httphandler/gunziphandler.go new file mode 100644 index 00000000..ede0209c --- /dev/null +++ b/core/httphandler/gunziphandler.go @@ -0,0 +1,27 @@ +package httphandler + +import ( + "compress/gzip" + "net/http" + "strings" + + "zero/core/httpx" +) + +const gzipEncoding = "gzip" + +func GunzipHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.Header.Get(httpx.ContentEncoding), gzipEncoding) { + reader, err := gzip.NewReader(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + r.Body = reader + } + + next.ServeHTTP(w, r) + }) +} diff --git a/core/httphandler/gunziphandler_test.go b/core/httphandler/gunziphandler_test.go new file mode 100644 index 00000000..824c9fcd --- /dev/null +++ b/core/httphandler/gunziphandler_test.go @@ -0,0 +1,66 @@ +package httphandler + +import ( + "bytes" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "zero/core/codec" + "zero/core/httpx" + + "github.com/stretchr/testify/assert" +) + +func TestGunzipHandler(t *testing.T) { + const message = "hello world" + var wg sync.WaitGroup + wg.Add(1) + handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + assert.Nil(t, err) + 
assert.Equal(t, string(body), message) + wg.Done() + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost", + bytes.NewReader(codec.Gzip([]byte(message)))) + req.Header.Set(httpx.ContentEncoding, gzipEncoding) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + wg.Wait() +} + +func TestGunzipHandler_NoGzip(t *testing.T) { + const message = "hello world" + var wg sync.WaitGroup + wg.Add(1) + handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + assert.Nil(t, err) + assert.Equal(t, string(body), message) + wg.Done() + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost", + strings.NewReader(message)) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + wg.Wait() +} + +func TestGunzipHandler_NoGzipButTelling(t *testing.T) { + const message = "hello world" + handler := GunzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + req := httptest.NewRequest(http.MethodPost, "http://localhost", + strings.NewReader(message)) + req.Header.Set(httpx.ContentEncoding, gzipEncoding) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusBadRequest, resp.Code) +} diff --git a/core/httphandler/internal/contentsecurity.go b/core/httphandler/internal/contentsecurity.go new file mode 100644 index 00000000..e6f97944 --- /dev/null +++ b/core/httphandler/internal/contentsecurity.go @@ -0,0 +1,147 @@ +package internal + +import ( + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "zero/core/codec" + "zero/core/httpx" + "zero/core/iox" + "zero/core/logx" +) + +const ( + requestUriHeader = "X-Request-Uri" + signatureField = "signature" + timeField = "time" +) + +var ( + ErrInvalidContentType = errors.New("invalid content type") + ErrInvalidHeader = errors.New("invalid X-Content-Security header") + ErrInvalidKey = errors.New("invalid key") + ErrInvalidPublicKey = errors.New("invalid public key") + ErrInvalidSecret = errors.New("invalid secret") +) + +type ContentSecurityHeader struct { + Key []byte + Timestamp string + ContentType int + Signature string +} + +func (h *ContentSecurityHeader) Encrypted() bool { + return h.ContentType == httpx.CryptionType +} + +func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) ( + *ContentSecurityHeader, error) { + contentSecurity := r.Header.Get(httpx.ContentSecurity) + attrs := httpx.ParseHeader(contentSecurity) + fingerprint := attrs[httpx.KeyField] + secret := attrs[httpx.SecretField] + signature := attrs[signatureField] + + if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 { + return nil, ErrInvalidHeader + } + + decrypter, ok := decrypters[fingerprint] + if !ok { + return nil, ErrInvalidPublicKey + } + + decryptedSecret, err := decrypter.DecryptBase64(secret) + if err != nil { + return nil, ErrInvalidSecret + } + + attrs = httpx.ParseHeader(string(decryptedSecret)) + base64Key := attrs[httpx.KeyField] + timestamp := attrs[timeField] + contentType := attrs[httpx.TypeField] + + key, err := base64.StdEncoding.DecodeString(base64Key) + if err != nil { + return nil, ErrInvalidKey + } + + cType, err := strconv.Atoi(contentType) + if err != nil { + return nil, ErrInvalidContentType + } + + return &ContentSecurityHeader{ + Key: key, + Timestamp: timestamp, + ContentType: cType, + 
Signature: signature, + }, nil +} + +func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int { + seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64) + if err != nil { + return httpx.CodeSignatureInvalidHeader + } + + now := time.Now().Unix() + toleranceSeconds := int64(tolerance.Seconds()) + if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds { + return httpx.CodeSignatureWrongTime + } + + reqPath, reqQuery := getPathQuery(r) + signContent := strings.Join([]string{ + securityHeader.Timestamp, + r.Method, + reqPath, + reqQuery, + computeBodySignature(r), + }, "\n") + actualSignature := codec.HmacBase64(securityHeader.Key, signContent) + + passed := securityHeader.Signature == actualSignature + if !passed { + logx.Infof("signature different, expect: %s, actual: %s", + securityHeader.Signature, actualSignature) + } + + if passed { + return httpx.CodeSignaturePass + } else { + return httpx.CodeSignatureInvalidToken + } +} + +func computeBodySignature(r *http.Request) string { + var dup io.ReadCloser + r.Body, dup = iox.DupReadCloser(r.Body) + sha := sha256.New() + io.Copy(sha, r.Body) + r.Body = dup + return fmt.Sprintf("%x", sha.Sum(nil)) +} + +func getPathQuery(r *http.Request) (string, string) { + requestUri := r.Header.Get(requestUriHeader) + if len(requestUri) == 0 { + return r.URL.Path, r.URL.RawQuery + } + + uri, err := url.Parse(requestUri) + if err != nil { + return r.URL.Path, r.URL.RawQuery + } + + return uri.Path, uri.RawQuery +} diff --git a/core/httphandler/internal/withcoderesponsewriter.go b/core/httphandler/internal/withcoderesponsewriter.go new file mode 100644 index 00000000..5631cc22 --- /dev/null +++ b/core/httphandler/internal/withcoderesponsewriter.go @@ -0,0 +1,21 @@ +package internal + +import "net/http" + +type WithCodeResponseWriter struct { + Writer http.ResponseWriter + Code int +} + +func (w *WithCodeResponseWriter) Header() http.Header { + return w.Writer.Header() +} + +func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) { + return w.Writer.Write(bytes) +} + +func (w *WithCodeResponseWriter) WriteHeader(code int) { + w.Writer.WriteHeader(code) + w.Code = code +} diff --git a/core/httphandler/loghandler.go b/core/httphandler/loghandler.go new file mode 100644 index 00000000..49e02a83 --- /dev/null +++ b/core/httphandler/loghandler.go @@ -0,0 +1,166 @@ +package httphandler + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httputil" + "time" + + "zero/core/httplog" + "zero/core/httpx" + "zero/core/iox" + "zero/core/logx" + "zero/core/timex" + "zero/core/utils" +) + +const slowThreshold = time.Millisecond * 500 + +type LoggedResponseWriter struct { + w http.ResponseWriter + r *http.Request + code int +} + +func (w *LoggedResponseWriter) Header() http.Header { + return w.w.Header() +} + +func (w *LoggedResponseWriter) Write(bytes []byte) (int, error) { + return w.w.Write(bytes) +} + +func (w *LoggedResponseWriter) WriteHeader(code int) { + w.w.WriteHeader(code) + w.code = code +} + +func LogHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + timer := utils.NewElapsedTimer() + logs := new(httplog.LogCollector) + lrw := LoggedResponseWriter{ + w: w, + r: r, + code: http.StatusOK, + } + + var dup io.ReadCloser + r.Body, dup = iox.DupReadCloser(r.Body) + next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), httplog.LogContext, logs))) + r.Body = dup + logBrief(r, lrw.code, timer, 
logs) + }) +} + +type DetailLoggedResponseWriter struct { + writer *LoggedResponseWriter + buf *bytes.Buffer +} + +func newDetailLoggedResponseWriter(writer *LoggedResponseWriter, buf *bytes.Buffer) *DetailLoggedResponseWriter { + return &DetailLoggedResponseWriter{ + writer: writer, + buf: buf, + } +} + +func (w *DetailLoggedResponseWriter) Header() http.Header { + return w.writer.Header() +} + +func (w *DetailLoggedResponseWriter) Write(bs []byte) (int, error) { + w.buf.Write(bs) + return w.writer.Write(bs) +} + +func (w *DetailLoggedResponseWriter) WriteHeader(code int) { + w.writer.WriteHeader(code) +} + +func DetailedLogHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + timer := utils.NewElapsedTimer() + var buf bytes.Buffer + lrw := newDetailLoggedResponseWriter(&LoggedResponseWriter{ + w: w, + r: r, + code: http.StatusOK, + }, &buf) + + var dup io.ReadCloser + r.Body, dup = iox.DupReadCloser(r.Body) + logs := new(httplog.LogCollector) + next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), httplog.LogContext, logs))) + r.Body = dup + logDetails(r, lrw, timer, logs) + }) +} + +func dumpRequest(r *http.Request) string { + reqContent, err := httputil.DumpRequest(r, true) + if err != nil { + return err.Error() + } else { + return string(reqContent) + } +} + +func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *httplog.LogCollector) { + var buf bytes.Buffer + duration := timer.Duration() + buf.WriteString(fmt.Sprintf("%d - %s - %s - %s - %s", + code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration))) + if duration > slowThreshold { + logx.Slowf("[HTTP] %d - %s - %s - %s - slowcall(%s)", + code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration)) + } + + ok := isOkResponse(code) + if !ok { + buf.WriteString(fmt.Sprintf("\n%s", dumpRequest(r))) + } + + body := logs.Flush() + if len(body) > 0 { + buf.WriteString(fmt.Sprintf("\n%s", body)) + } + + if ok { + logx.Info(buf.String()) + } else { + logx.Error(buf.String()) + } +} + +func logDetails(r *http.Request, response *DetailLoggedResponseWriter, timer *utils.ElapsedTimer, + logs *httplog.LogCollector) { + var buf bytes.Buffer + duration := timer.Duration() + buf.WriteString(fmt.Sprintf("%d - %s - %s\n=> %s\n", + response.writer.code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r))) + if duration > slowThreshold { + logx.Slowf("[HTTP] %d - %s - slowcall(%s)\n=> %s\n", + response.writer.code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r)) + } + + body := logs.Flush() + if len(body) > 0 { + buf.WriteString(fmt.Sprintf("%s\n", body)) + } + + respBuf := response.buf.Bytes() + if len(respBuf) > 0 { + buf.WriteString(fmt.Sprintf("<= %s", respBuf)) + } + + logx.Info(buf.String()) +} + +func isOkResponse(code int) bool { + // not server error + return code < http.StatusInternalServerError +} diff --git a/core/httphandler/loghandler_test.go b/core/httphandler/loghandler_test.go new file mode 100644 index 00000000..aee7f520 --- /dev/null +++ b/core/httphandler/loghandler_test.go @@ -0,0 +1,74 @@ +package httphandler + +import ( + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "zero/core/httplog" +) + +func init() { + log.SetOutput(ioutil.Discard) +} + +func TestLogHandler(t *testing.T) { + handlers := []func(handler http.Handler) http.Handler{ + LogHandler, + DetailedLogHandler, + } + + 
for _, logHandler := range handlers { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Context().Value(httplog.LogContext).(*httplog.LogCollector).Append("anything") + w.Header().Set("X-Test", "test") + w.WriteHeader(http.StatusServiceUnavailable) + _, err := w.Write([]byte("content")) + assert.Nil(t, err) + })) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) + assert.Equal(t, "test", resp.Header().Get("X-Test")) + assert.Equal(t, "content", resp.Body.String()) + } +} + +func TestLogHandlerSlow(t *testing.T) { + handlers := []func(handler http.Handler) http.Handler{ + LogHandler, + DetailedLogHandler, + } + + for _, logHandler := range handlers { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(slowThreshold + time.Millisecond*50) + })) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + } +} + +func BenchmarkLogHandler(b *testing.B) { + b.ReportAllocs() + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < b.N; i++ { + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + } +} diff --git a/core/httphandler/maxbyteshandler.go b/core/httphandler/maxbyteshandler.go new file mode 100644 index 00000000..bcb534d1 --- /dev/null +++ b/core/httphandler/maxbyteshandler.go @@ -0,0 +1,27 @@ +package httphandler + +import ( + "net/http" + + "zero/core/httplog" +) + +func MaxBytesHandler(n int64) func(http.Handler) http.Handler { + if n <= 0 { + return func(next http.Handler) http.Handler { + return next + } + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ContentLength > n { + httplog.Errorf(r, "request entity too large, limit is %d, but got %d, rejected with code %d", + n, r.ContentLength, http.StatusRequestEntityTooLarge) + w.WriteHeader(http.StatusRequestEntityTooLarge) + } else { + next.ServeHTTP(w, r) + } + }) + } +} diff --git a/core/httphandler/maxbyteshandler_test.go b/core/httphandler/maxbyteshandler_test.go new file mode 100644 index 00000000..760a6954 --- /dev/null +++ b/core/httphandler/maxbyteshandler_test.go @@ -0,0 +1,37 @@ +package httphandler + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMaxBytesHandler(t *testing.T) { + maxb := MaxBytesHandler(10) + handler := maxb(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + req := httptest.NewRequest(http.MethodPost, "http://localhost", + bytes.NewBufferString("123456789012345")) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusRequestEntityTooLarge, resp.Code) + + req = httptest.NewRequest(http.MethodPost, "http://localhost", bytes.NewBufferString("12345")) + resp = httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} + +func TestMaxBytesHandlerNoLimit(t *testing.T) { + maxb := MaxBytesHandler(-1) + handler := maxb(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + req := httptest.NewRequest(http.MethodPost, 
"http://localhost", + bytes.NewBufferString("123456789012345")) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} diff --git a/core/httphandler/maxconnshandler.go b/core/httphandler/maxconnshandler.go new file mode 100644 index 00000000..c96c53ef --- /dev/null +++ b/core/httphandler/maxconnshandler.go @@ -0,0 +1,37 @@ +package httphandler + +import ( + "net/http" + + "zero/core/httplog" + "zero/core/logx" + "zero/core/syncx" +) + +func MaxConns(n int) func(http.Handler) http.Handler { + if n <= 0 { + return func(next http.Handler) http.Handler { + return next + } + } + + return func(next http.Handler) http.Handler { + latchLimiter := syncx.NewLimit(n) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if latchLimiter.TryBorrow() { + defer func() { + if err := latchLimiter.Return(); err != nil { + logx.Error(err) + } + }() + + next.ServeHTTP(w, r) + } else { + httplog.Errorf(r, "Concurrent connections over %d, rejected with code %d", + n, http.StatusServiceUnavailable) + w.WriteHeader(http.StatusServiceUnavailable) + } + }) + } +} diff --git a/core/httphandler/maxconnshandler_test.go b/core/httphandler/maxconnshandler_test.go new file mode 100644 index 00000000..979cb15b --- /dev/null +++ b/core/httphandler/maxconnshandler_test.go @@ -0,0 +1,80 @@ +package httphandler + +import ( + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "zero/core/lang" + + "github.com/stretchr/testify/assert" +) + +const conns = 4 + +func init() { + log.SetOutput(ioutil.Discard) +} + +func TestMaxConnsHandler(t *testing.T) { + var waitGroup sync.WaitGroup + waitGroup.Add(conns) + done := make(chan lang.PlaceholderType) + defer close(done) + + maxConns := MaxConns(conns) + handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + waitGroup.Done() + <-done + })) + + for i := 0; i < conns; i++ { + go func() { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + handler.ServeHTTP(httptest.NewRecorder(), req) + }() + } + + waitGroup.Wait() + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) +} + +func TestWithoutMaxConnsHandler(t *testing.T) { + const ( + key = "block" + value = "1" + ) + var waitGroup sync.WaitGroup + waitGroup.Add(conns) + done := make(chan lang.PlaceholderType) + defer close(done) + + maxConns := MaxConns(0) + handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + val := r.Header.Get(key) + if val == value { + waitGroup.Done() + <-done + } + })) + + for i := 0; i < conns; i++ { + go func() { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set(key, value) + handler.ServeHTTP(httptest.NewRecorder(), req) + }() + } + + waitGroup.Wait() + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} diff --git a/core/httphandler/metrichandler.go b/core/httphandler/metrichandler.go new file mode 100644 index 00000000..61ffb9fb --- /dev/null +++ b/core/httphandler/metrichandler.go @@ -0,0 +1,23 @@ +package httphandler + +import ( + "net/http" + + "zero/core/stat" + "zero/core/timex" +) + +func MetricHandler(metrics *stat.Metrics) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + startTime := timex.Now() + defer func() { + metrics.Add(stat.Task{ + Duration: timex.Since(startTime), + }) + }() + + next.ServeHTTP(w, r) + }) + } +} diff --git a/core/httphandler/metrichandler_test.go b/core/httphandler/metrichandler_test.go new file mode 100644 index 00000000..1ad50360 --- /dev/null +++ b/core/httphandler/metrichandler_test.go @@ -0,0 +1,24 @@ +package httphandler + +import ( + "net/http" + "net/http/httptest" + "testing" + + "zero/core/stat" + + "github.com/stretchr/testify/assert" +) + +func TestMetricHandler(t *testing.T) { + metrics := stat.NewMetrics("unit-test") + metricHandler := MetricHandler(metrics) + handler := metricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} diff --git a/core/httphandler/prommetrichandler.go b/core/httphandler/prommetrichandler.go new file mode 100644 index 00000000..f15fa3d8 --- /dev/null +++ b/core/httphandler/prommetrichandler.go @@ -0,0 +1,47 @@ +package httphandler + +import ( + "net/http" + "strconv" + "time" + + "zero/core/httphandler/internal" + "zero/core/metric" + "zero/core/timex" +) + +const serverNamespace = "http_server" + +var ( + metricServerReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{ + Namespace: serverNamespace, + Subsystem: "requests", + Name: "duration_ms", + Help: "http server requests duration(ms).", + Labels: []string{"path"}, + Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000}, + }) + + metricServerReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{ + Namespace: serverNamespace, + Subsystem: "requests", + Name: "code_total", + Help: "http server requests error count.", + Labels: []string{"path", "code"}, + }) +) + +func PromMetricHandler(path string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + startTime := timex.Now() + cw := &internal.WithCodeResponseWriter{Writer: w} + defer func() { + metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path) + metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code)) + }() + + next.ServeHTTP(cw, r) + }) + } +} diff --git a/core/httphandler/prommetrichandler_test.go b/core/httphandler/prommetrichandler_test.go new file mode 100644 index 00000000..156419ea --- /dev/null +++ b/core/httphandler/prommetrichandler_test.go @@ -0,0 +1,21 @@ +package httphandler + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPromMetricHandler(t *testing.T) { + promMetricHandler := PromMetricHandler("/user/login") + handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} diff --git a/core/httphandler/recoverhandler.go b/core/httphandler/recoverhandler.go new file mode 100644 index 00000000..aca99c86 --- /dev/null +++ b/core/httphandler/recoverhandler.go @@ -0,0 +1,22 @@ +package httphandler + +import ( + "fmt" + "net/http" + "runtime/debug" + + "zero/core/httplog" +) + +func RecoverHandler(next http.Handler) http.Handler 
{ + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if result := recover(); result != nil { + httplog.Error(r, fmt.Sprintf("%v\n%s", result, debug.Stack())) + w.WriteHeader(http.StatusInternalServerError) + } + }() + + next.ServeHTTP(w, r) + }) +} diff --git a/core/httphandler/recoverhandler_test.go b/core/httphandler/recoverhandler_test.go new file mode 100644 index 00000000..1281bdec --- /dev/null +++ b/core/httphandler/recoverhandler_test.go @@ -0,0 +1,36 @@ +package httphandler + +import ( + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func init() { + log.SetOutput(ioutil.Discard) +} + +func TestWithPanic(t *testing.T) { + handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("whatever") + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusInternalServerError, resp.Code) +} + +func TestWithoutPanic(t *testing.T) { + handler := RecoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} diff --git a/core/httphandler/sheddinghandler.go b/core/httphandler/sheddinghandler.go new file mode 100644 index 00000000..6fac2c20 --- /dev/null +++ b/core/httphandler/sheddinghandler.go @@ -0,0 +1,63 @@ +package httphandler + +import ( + "net/http" + "sync" + + "zero/core/httphandler/internal" + "zero/core/httpx" + "zero/core/load" + "zero/core/logx" + "zero/core/stat" +) + +const serviceType = "api" + +var ( + sheddingStat *load.SheddingStat + lock sync.Mutex +) + +func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Handler) http.Handler { + if shedder == nil { + return func(next http.Handler) http.Handler { + return next + } + } + + ensureSheddingStat() + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sheddingStat.IncrementTotal() + promise, err := shedder.Allow() + if err != nil { + metrics.AddDrop() + sheddingStat.IncrementDrop() + logx.Errorf("[http] dropped, %s - %s - %s", + r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent()) + w.WriteHeader(http.StatusServiceUnavailable) + return + } + + cw := &internal.WithCodeResponseWriter{Writer: w} + defer func() { + if cw.Code == http.StatusServiceUnavailable { + promise.Fail() + } else { + sheddingStat.IncrementPass() + promise.Pass() + } + }() + next.ServeHTTP(cw, r) + }) + } +} + +func ensureSheddingStat() { + lock.Lock() + if sheddingStat == nil { + sheddingStat = load.NewSheddingStat(serviceType) + } + lock.Unlock() +} diff --git a/core/httphandler/sheddinghandler_test.go b/core/httphandler/sheddinghandler_test.go new file mode 100644 index 00000000..9eb741f7 --- /dev/null +++ b/core/httphandler/sheddinghandler_test.go @@ -0,0 +1,105 @@ +package httphandler + +import ( + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "testing" + + "zero/core/load" + "zero/core/stat" + + "github.com/stretchr/testify/assert" +) + +func init() { + log.SetOutput(ioutil.Discard) +} + +func TestSheddingHandlerAccept(t *testing.T) { + metrics := stat.NewMetrics("unit-test") + shedder := mockShedder{ + allow: true, + } + sheddingHandler := SheddingHandler(shedder, metrics) + handler := 
sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "test") + _, err := w.Write([]byte("content")) + assert.Nil(t, err) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set("X-Test", "test") + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "test", resp.Header().Get("X-Test")) + assert.Equal(t, "content", resp.Body.String()) +} + +func TestSheddingHandlerFail(t *testing.T) { + metrics := stat.NewMetrics("unit-test") + shedder := mockShedder{ + allow: true, + } + sheddingHandler := SheddingHandler(shedder, metrics) + handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) +} + +func TestSheddingHandlerReject(t *testing.T) { + metrics := stat.NewMetrics("unit-test") + shedder := mockShedder{ + allow: false, + } + sheddingHandler := SheddingHandler(shedder, metrics) + handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) +} + +func TestSheddingHandlerNoShedding(t *testing.T) { + metrics := stat.NewMetrics("unit-test") + sheddingHandler := SheddingHandler(nil, metrics) + handler := sheddingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} + +type mockShedder struct { + allow bool +} + +func (s mockShedder) Allow() (load.Promise, error) { + if s.allow { + return mockPromise{}, nil + } else { + return nil, load.ErrServiceOverloaded + } +} + +type mockPromise struct { +} + +func (p mockPromise) Pass() { +} + +func (p mockPromise) Fail() { +} diff --git a/core/httphandler/timeouthandler.go b/core/httphandler/timeouthandler.go new file mode 100644 index 00000000..f57d5ead --- /dev/null +++ b/core/httphandler/timeouthandler.go @@ -0,0 +1,18 @@ +package httphandler + +import ( + "net/http" + "time" +) + +const reason = "Request Timeout" + +func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + if duration > 0 { + return http.TimeoutHandler(next, duration, reason) + } else { + return next + } + } +} diff --git a/core/httphandler/timeouthandler_test.go b/core/httphandler/timeouthandler_test.go new file mode 100644 index 00000000..4f4d086c --- /dev/null +++ b/core/httphandler/timeouthandler_test.go @@ -0,0 +1,52 @@ +package httphandler + +import ( + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func init() { + log.SetOutput(ioutil.Discard) +} + +func TestTimeout(t *testing.T) { + timeoutHandler := TimeoutHandler(time.Millisecond) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Minute) + })) + + req := httptest.NewRequest(http.MethodGet, 
"http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) +} + +func TestWithinTimeout(t *testing.T) { + timeoutHandler := TimeoutHandler(time.Second) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Millisecond) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} + +func TestWithoutTimeout(t *testing.T) { + timeoutHandler := TimeoutHandler(0) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} diff --git a/core/httphandler/tracinghandler.go b/core/httphandler/tracinghandler.go new file mode 100644 index 00000000..49456660 --- /dev/null +++ b/core/httphandler/tracinghandler.go @@ -0,0 +1,25 @@ +package httphandler + +import ( + "net/http" + + "zero/core/logx" + "zero/core/sysx" + "zero/core/trace" +) + +func TracingHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + carrier, err := trace.Extract(trace.HttpFormat, r.Header) + // ErrInvalidCarrier means no trace id was set in http header + if err != nil && err != trace.ErrInvalidCarrier { + logx.Error(err) + } + + ctx, span := trace.StartServerSpan(r.Context(), carrier, sysx.Hostname(), r.RequestURI) + defer span.Finish() + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) +} diff --git a/core/httphandler/tracinghandler_test.go b/core/httphandler/tracinghandler_test.go new file mode 100644 index 00000000..f1261f02 --- /dev/null +++ b/core/httphandler/tracinghandler_test.go @@ -0,0 +1,25 @@ +package httphandler + +import ( + "net/http" + "net/http/httptest" + "testing" + + "zero/core/trace/tracespec" + + "github.com/stretchr/testify/assert" +) + +func TestTracingHandler(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set("X-Trace-ID", "theid") + handler := TracingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + span, ok := r.Context().Value(tracespec.TracingKey).(tracespec.Trace) + assert.True(t, ok) + assert.Equal(t, "theid", span.TraceId()) + })) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} diff --git a/core/httplog/log.go b/core/httplog/log.go new file mode 100644 index 00000000..76ac50ed --- /dev/null +++ b/core/httplog/log.go @@ -0,0 +1,84 @@ +package httplog + +import ( + "bytes" + "fmt" + "net/http" + "sync" + + "zero/core/httpx" + "zero/core/logx" +) + +const LogContext = "request_logs" + +type LogCollector struct { + Messages []string + lock sync.Mutex +} + +func (lc *LogCollector) Append(msg string) { + lc.lock.Lock() + lc.Messages = append(lc.Messages, msg) + lc.lock.Unlock() +} + +func (lc *LogCollector) Flush() string { + var buffer bytes.Buffer + + start := true + for _, message := range lc.takeAll() { + if start { + start = false + } else { + buffer.WriteByte('\n') + } + buffer.WriteString(message) + } + + return buffer.String() +} + +func (lc *LogCollector) takeAll() []string { + lc.lock.Lock() + messages := lc.Messages + lc.Messages = nil + lc.lock.Unlock() + + return 
messages +} + +func Error(r *http.Request, v ...interface{}) { + logx.ErrorCaller(1, format(r, v...)) +} + +func Errorf(r *http.Request, format string, v ...interface{}) { + logx.ErrorCaller(1, formatf(r, format, v...)) +} + +func Info(r *http.Request, v ...interface{}) { + appendLog(r, format(r, v...)) +} + +func Infof(r *http.Request, format string, v ...interface{}) { + appendLog(r, formatf(r, format, v...)) +} + +func appendLog(r *http.Request, message string) { + logs := r.Context().Value(LogContext) + if logs != nil { + logs.(*LogCollector).Append(message) + } +} + +func format(r *http.Request, v ...interface{}) string { + return formatWithReq(r, fmt.Sprint(v...)) +} + +func formatf(r *http.Request, format string, v ...interface{}) string { + return formatWithReq(r, fmt.Sprintf(format, v...)) +} + +func formatWithReq(r *http.Request, v string) string { + return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v) +} diff --git a/core/httplog/log_test.go b/core/httplog/log_test.go new file mode 100644 index 00000000..98d1858d --- /dev/null +++ b/core/httplog/log_test.go @@ -0,0 +1,38 @@ +package httplog + +import ( + "context" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInfo(t *testing.T) { + collector := new(LogCollector) + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req = req.WithContext(context.WithValue(req.Context(), LogContext, collector)) + Info(req, "first") + Infof(req, "second %s", "third") + val := collector.Flush() + assert.True(t, strings.Contains(val, "first")) + assert.True(t, strings.Contains(val, "second")) + assert.True(t, strings.Contains(val, "third")) + assert.True(t, strings.Contains(val, "\n")) +} + +func TestError(t *testing.T) { + var writer strings.Builder + log.SetOutput(&writer) + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + Error(req, "first") + Errorf(req, "second %s", "third") + val := writer.String() + assert.True(t, strings.Contains(val, "first")) + assert.True(t, strings.Contains(val, "second")) + assert.True(t, strings.Contains(val, "third")) + assert.True(t, strings.Contains(val, "\n")) +} diff --git a/core/httprouter/patrouter.go b/core/httprouter/patrouter.go new file mode 100644 index 00000000..c8728d17 --- /dev/null +++ b/core/httprouter/patrouter.go @@ -0,0 +1,115 @@ +package httprouter + +import ( + "context" + "net/http" + "path" + "strings" + + "zero/core/search" +) + +const ( + allowHeader = "Allow" + allowMethodSeparator = ", " + pathVars = "pathVars" +) + +type PatRouter struct { + trees map[string]*search.Tree + notFound http.Handler +} + +func NewPatRouter() Router { + return &PatRouter{ + trees: make(map[string]*search.Tree), + } +} + +func (pr *PatRouter) Handle(method, reqPath string, handler http.Handler) error { + if !validMethod(method) { + return ErrInvalidMethod + } + + if len(reqPath) == 0 || reqPath[0] != '/' { + return ErrInvalidPath + } + + cleanPath := path.Clean(reqPath) + if tree, ok := pr.trees[method]; ok { + return tree.Add(cleanPath, handler) + } else { + tree = search.NewTree() + pr.trees[method] = tree + return tree.Add(cleanPath, handler) + } +} + +func (pr *PatRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + reqPath := path.Clean(r.URL.Path) + if tree, ok := pr.trees[r.Method]; ok { + if result, ok := tree.Search(reqPath); ok { + if len(result.Params) > 0 { + r = r.WithContext(context.WithValue(r.Context(), pathVars, result.Params)) + } + 
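			// A route matched: any path parameters were stashed in the request
			// context under pathVars above (readable via Vars), so the matched
			// handler can now be invoked directly.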
result.Item.(http.Handler).ServeHTTP(w, r) + return + } + } + + if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok { + w.Header().Set(allowHeader, allow) + w.WriteHeader(http.StatusMethodNotAllowed) + } else { + pr.handleNotFound(w, r) + } +} + +func (pr *PatRouter) SetNotFoundHandler(handler http.Handler) { + pr.notFound = handler +} + +func (pr *PatRouter) handleNotFound(w http.ResponseWriter, r *http.Request) { + if pr.notFound != nil { + pr.notFound.ServeHTTP(w, r) + } else { + http.NotFound(w, r) + } +} + +func (pr *PatRouter) methodNotAllowed(method, path string) (string, bool) { + var allows []string + + for treeMethod, tree := range pr.trees { + if treeMethod == method { + continue + } + + _, ok := tree.Search(path) + if ok { + allows = append(allows, treeMethod) + } + } + + if len(allows) > 0 { + return strings.Join(allows, allowMethodSeparator), true + } else { + return "", false + } +} + +func Vars(r *http.Request) map[string]string { + vars, ok := r.Context().Value(pathVars).(map[string]string) + if ok { + return vars + } + + return nil +} + +func validMethod(method string) bool { + return method == http.MethodDelete || method == http.MethodGet || + method == http.MethodHead || method == http.MethodOptions || + method == http.MethodPatch || method == http.MethodPost || + method == http.MethodPut +} diff --git a/core/httprouter/patrouter_test.go b/core/httprouter/patrouter_test.go new file mode 100644 index 00000000..cd075afa --- /dev/null +++ b/core/httprouter/patrouter_test.go @@ -0,0 +1,120 @@ +package httprouter + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +type mockedResponseWriter struct { + code int +} + +func (m *mockedResponseWriter) Header() http.Header { + return http.Header{} +} + +func (m *mockedResponseWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +func (m *mockedResponseWriter) WriteHeader(code int) { + m.code = code +} + +func TestPatRouterHandleErrors(t *testing.T) { + tests := []struct { + method string + path string + err error + }{ + {"FAKE", "", ErrInvalidMethod}, + {"GET", "", ErrInvalidPath}, + } + + for _, test := range tests { + t.Run(test.method, func(t *testing.T) { + router := NewPatRouter() + err := router.Handle(test.method, test.path, nil) + assert.Error(t, ErrInvalidMethod, err) + }) + } +} + +func TestPatRouterNotFound(t *testing.T) { + var notFound bool + router := NewPatRouter() + router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + notFound = true + })) + router.Handle(http.MethodGet, "/a/b", nil) + r, _ := http.NewRequest(http.MethodGet, "/b/c", nil) + w := new(mockedResponseWriter) + router.ServeHTTP(w, r) + assert.True(t, notFound) +} + +func TestPatRouter(t *testing.T) { + tests := []struct { + method string + path string + expect bool + code int + err error + }{ + // we don't explicitly set status code, framework will do it. 
+ {http.MethodGet, "/a/b", true, 0, nil}, + {http.MethodGet, "/a/b/", true, 0, nil}, + {http.MethodGet, "/a/b?a=b", true, 0, nil}, + {http.MethodGet, "/a/b/?a=b", true, 0, nil}, + {http.MethodGet, "/a/b/c?a=b", true, 0, nil}, + {http.MethodGet, "/b/d", false, http.StatusNotFound, nil}, + } + + for _, test := range tests { + t.Run(test.method+":"+test.path, func(t *testing.T) { + routed := false + router := NewPatRouter() + err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + routed = true + assert.Equal(t, 1, len(Vars(r))) + })) + assert.Nil(t, err) + err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + routed = true + assert.Nil(t, Vars(r)) + })) + assert.Nil(t, err) + err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + routed = true + })) + assert.Nil(t, err) + + w := new(mockedResponseWriter) + r, _ := http.NewRequest(test.method, test.path, nil) + router.ServeHTTP(w, r) + assert.Equal(t, test.expect, routed) + assert.Equal(t, test.code, w.code) + + if test.code == 0 { + r, _ = http.NewRequest(http.MethodPut, test.path, nil) + router.ServeHTTP(w, r) + assert.Equal(t, http.StatusMethodNotAllowed, w.code) + } + }) + } +} + +func BenchmarkPatRouter(b *testing.B) { + b.ReportAllocs() + + router := NewPatRouter() + router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + w := &mockedResponseWriter{} + r, _ := http.NewRequest(http.MethodGet, "/api/a/b", nil) + for i := 0; i < b.N; i++ { + router.ServeHTTP(w, r) + } +} diff --git a/core/httprouter/router.go b/core/httprouter/router.go new file mode 100644 index 00000000..d641c139 --- /dev/null +++ b/core/httprouter/router.go @@ -0,0 +1,24 @@ +package httprouter + +import ( + "errors" + "net/http" +) + +var ( + ErrInvalidMethod = errors.New("not a valid http method") + ErrInvalidPath = errors.New("path must begin with '/'") +) + +type ( + Route struct { + Path string + Handler http.HandlerFunc + } + + Router interface { + http.Handler + Handle(method string, path string, handler http.Handler) error + SetNotFoundHandler(handler http.Handler) + } +) diff --git a/core/httpsecurity/tokenparser.go b/core/httpsecurity/tokenparser.go new file mode 100644 index 00000000..32cb11d7 --- /dev/null +++ b/core/httpsecurity/tokenparser.go @@ -0,0 +1,122 @@ +package httpsecurity + +import ( + "net/http" + "sync" + "sync/atomic" + "time" + + "zero/core/timex" + + "github.com/dgrijalva/jwt-go" + "github.com/dgrijalva/jwt-go/request" +) + +const claimHistoryResetDuration = time.Hour * 24 + +type ( + ParseOption func(parser *TokenParser) + + TokenParser struct { + resetTime time.Duration + resetDuration time.Duration + history sync.Map + } +) + +func NewTokenParser(opts ...ParseOption) *TokenParser { + parser := &TokenParser{ + resetTime: timex.Now(), + resetDuration: claimHistoryResetDuration, + } + + for _, opt := range opts { + opt(parser) + } + + return parser +} + +func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*jwt.Token, error) { + var token *jwt.Token + var err error + + if len(prevSecret) > 0 { + count := tp.loadCount(secret) + prevCount := tp.loadCount(prevSecret) + + var first, second string + if count > prevCount { + first = secret + second = prevSecret + } else { + first = prevSecret + second = secret + } + + token, err = tp.doParseToken(r, first) + if err != nil { + token, err = 
tp.doParseToken(r, second) + if err != nil { + return nil, err + } else { + tp.incrementCount(second) + } + } else { + tp.incrementCount(first) + } + } else { + token, err = tp.doParseToken(r, secret) + if err != nil { + return nil, err + } + } + + return token, nil +} + +func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) { + return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor, + func(token *jwt.Token) (interface{}, error) { + return []byte(secret), nil + }, request.WithParser(newParser())) +} + +func (tp *TokenParser) incrementCount(secret string) { + now := timex.Now() + if tp.resetTime+tp.resetDuration < now { + tp.history.Range(func(key, value interface{}) bool { + tp.history.Delete(key) + return true + }) + } + + value, ok := tp.history.Load(secret) + if ok { + atomic.AddUint64(value.(*uint64), 1) + } else { + var count uint64 = 1 + tp.history.Store(secret, &count) + } +} + +func (tp *TokenParser) loadCount(secret string) uint64 { + value, ok := tp.history.Load(secret) + if ok { + return *value.(*uint64) + } + + return 0 +} + +func WithResetDuration(duration time.Duration) ParseOption { + return func(parser *TokenParser) { + parser.resetDuration = duration + } +} + +func newParser() *jwt.Parser { + return &jwt.Parser{ + UseJSONNumber: true, + } +} diff --git a/core/httpsecurity/tokenparser_test.go b/core/httpsecurity/tokenparser_test.go new file mode 100644 index 00000000..edeb74cb --- /dev/null +++ b/core/httpsecurity/tokenparser_test.go @@ -0,0 +1,87 @@ +package httpsecurity + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/dgrijalva/jwt-go" + "github.com/stretchr/testify/assert" + + "zero/core/timex" +) + +func TestTokenParser(t *testing.T) { + const ( + key = "14F17379-EB8F-411B-8F12-6929002DCA76" + prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A" + ) + keys := []struct { + key string + prevKey string + }{ + { + key, + prevKey, + }, + { + key, + "", + }, + } + + for _, pair := range keys { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + token, err := buildToken(key, map[string]interface{}{ + "key": "value", + }, 3600) + assert.Nil(t, err) + req.Header.Set("Authorization", "Bearer "+token) + + parser := NewTokenParser(WithResetDuration(time.Minute)) + tok, err := parser.ParseToken(req, pair.key, pair.prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + } +} + +func TestTokenParser_Expired(t *testing.T) { + const ( + key = "14F17379-EB8F-411B-8F12-6929002DCA76" + prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A" + ) + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + token, err := buildToken(key, map[string]interface{}{ + "key": "value", + }, 3600) + assert.Nil(t, err) + req.Header.Set("Authorization", "Bearer "+token) + + parser := NewTokenParser(WithResetDuration(time.Second)) + tok, err := parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + parser.resetTime = timex.Now() - time.Hour + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) +} + +func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) { + now := time.Now().Unix() + claims := make(jwt.MapClaims) + claims["exp"] = now + seconds + 
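	// "exp" above is the standard JWT expiry claim; "iat" below is issued-at,
	// so the generated test token stays valid for `seconds` seconds from now.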
claims["iat"] = now + for k, v := range payloads { + claims[k] = v + } + + token := jwt.New(jwt.SigningMethodHS256) + token.Claims = claims + + return token.SignedString([]byte(secretKey)) +} diff --git a/core/httpserver/server.go b/core/httpserver/server.go new file mode 100644 index 00000000..9bf431c3 --- /dev/null +++ b/core/httpserver/server.go @@ -0,0 +1,40 @@ +package httpserver + +import ( + "crypto/tls" + "fmt" + "net/http" +) + +func StartHttp(host string, port int, handler http.Handler) error { + addr := fmt.Sprintf("%s:%d", host, port) + server := buildHttpServer(addr, handler) + return StartServer(server) +} + +func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler) error { + addr := fmt.Sprintf("%s:%d", host, port) + if server, err := buildHttpsServer(addr, handler, certFile, keyFile); err != nil { + return err + } else { + return StartServer(server) + } +} + +func buildHttpServer(addr string, handler http.Handler) *http.Server { + return &http.Server{Addr: addr, Handler: handler} +} + +func buildHttpsServer(addr string, handler http.Handler, certFile, keyFile string) (*http.Server, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + config := tls.Config{Certificates: []tls.Certificate{cert}} + return &http.Server{ + Addr: addr, + Handler: handler, + TLSConfig: &config, + }, nil +} diff --git a/core/httpserver/starter.go b/core/httpserver/starter.go new file mode 100644 index 00000000..2cea7362 --- /dev/null +++ b/core/httpserver/starter.go @@ -0,0 +1,16 @@ +package httpserver + +import ( + "context" + "net/http" + + "zero/core/proc" +) + +func StartServer(srv *http.Server) error { + proc.AddWrapUpListener(func() { + srv.Shutdown(context.Background()) + }) + + return srv.ListenAndServe() +} diff --git a/core/httpx/constants.go b/core/httpx/constants.go new file mode 100644 index 00000000..3df99c8d --- /dev/null +++ b/core/httpx/constants.go @@ -0,0 +1,19 @@ +package httpx + +const ( + ApplicationJson = "application/json" + ContentEncoding = "Content-Encoding" + ContentSecurity = "X-Content-Security" + ContentType = "Content-Type" + KeyField = "key" + SecretField = "secret" + TypeField = "type" + CryptionType = 1 +) + +const ( + CodeSignaturePass = iota + CodeSignatureInvalidHeader + CodeSignatureWrongTime + CodeSignatureInvalidToken +) diff --git a/core/httpx/requests.go b/core/httpx/requests.go new file mode 100644 index 00000000..b2cb9623 --- /dev/null +++ b/core/httpx/requests.go @@ -0,0 +1,124 @@ +package httpx + +import ( + "errors" + "io" + "net/http" + "strings" + + "zero/core/httprouter" + "zero/core/mapping" +) + +const ( + multipartFormData = "multipart/form-data" + xForwardFor = "X-Forward-For" + formKey = "form" + pathKey = "path" + emptyJson = "{}" + maxMemory = 32 << 20 // 32MB + maxBodyLen = 8 << 20 // 8MB + separator = ";" + tokensInAttribute = 2 +) + +var ( + ErrBodylessRequest = errors.New("not a POST|PUT|PATCH request") + + formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues()) + pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues()) +) + +// Returns the peer address, supports X-Forward-For +func GetRemoteAddr(r *http.Request) string { + v := r.Header.Get(xForwardFor) + if len(v) > 0 { + return v + } + return r.RemoteAddr +} + +func Parse(r *http.Request, v interface{}) error { + if err := ParsePath(r, v); err != nil { + return err + } + + if err := ParseForm(r, v); err != nil { + return err + } + + return ParseJsonBody(r, v) +} 
+ +// Parses the form request. +func ParseForm(r *http.Request, v interface{}) error { + if strings.Index(r.Header.Get(ContentType), multipartFormData) != -1 { + if err := r.ParseMultipartForm(maxMemory); err != nil { + return err + } + } else { + if err := r.ParseForm(); err != nil { + return err + } + } + + params := make(map[string]interface{}, len(r.Form)) + for name := range r.Form { + formValue := r.Form.Get(name) + if len(formValue) > 0 { + params[name] = formValue + } + } + + return formUnmarshaler.Unmarshal(params, v) +} + +func ParseHeader(headerValue string) map[string]string { + ret := make(map[string]string) + fields := strings.Split(headerValue, separator) + + for _, field := range fields { + field = strings.TrimSpace(field) + if len(field) == 0 { + continue + } + + kv := strings.SplitN(field, "=", tokensInAttribute) + if len(kv) != tokensInAttribute { + continue + } + + ret[kv[0]] = kv[1] + } + + return ret +} + +// Parses the post request which contains json in body. +func ParseJsonBody(r *http.Request, v interface{}) error { + var reader io.Reader + + if withJsonBody(r) { + reader = io.LimitReader(r.Body, maxBodyLen) + } else { + reader = strings.NewReader(emptyJson) + } + + return mapping.UnmarshalJsonReader(reader, v) +} + +// Parses the symbols reside in url path. +// Like http://localhost/bag/:name +func ParsePath(r *http.Request, v interface{}) error { + vars := httprouter.Vars(r) + m := make(map[string]interface{}, len(vars)) + for k, v := range vars { + m[k] = v + } + + return pathUnmarshaler.Unmarshal(m, v) +} + +func withJsonBody(r *http.Request) bool { + return r.ContentLength > 0 && strings.Index(r.Header.Get(ContentType), ApplicationJson) != -1 +} diff --git a/core/httpx/requests_test.go b/core/httpx/requests_test.go new file mode 100644 index 00000000..379f5508 --- /dev/null +++ b/core/httpx/requests_test.go @@ -0,0 +1,1032 @@ +package httpx + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "zero/core/httprouter" + + "github.com/stretchr/testify/assert" +) + +const ( + applicationJsonWithUtf8 = "application/json; charset=utf-8" + contentLength = "Content-Length" +) + +func TestGetRemoteAddr(t *testing.T) { + host := "8.8.8.8" + r, err := http.NewRequest(http.MethodGet, "/", strings.NewReader("")) + assert.Nil(t, err) + + r.Header.Set(xForwardFor, host) + assert.Equal(t, host, GetRemoteAddr(r)) +} + +func TestParseForm(t *testing.T) { + var v struct { + Name string `form:"name"` + Age int `form:"age"` + Percent float64 `form:"percent,optional"` + } + + r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", nil) + assert.Nil(t, err) + + err = Parse(r, &v) + assert.Nil(t, err) + assert.Equal(t, "hello", v.Name) + assert.Equal(t, 18, v.Age) + assert.Equal(t, 3.4, v.Percent) +} + +func TestParseFormOutOfRange(t *testing.T) { + var v struct { + Age int `form:"age,range=[10:20)"` + } + + tests := []struct { + url string + pass bool + }{ + { + url: "http://hello.com/a?age=5", + pass: false, + }, + { + url: "http://hello.com/a?age=10", + pass: true, + }, + { + url: "http://hello.com/a?age=15", + pass: true, + }, + { + url: "http://hello.com/a?age=20", + pass: false, + }, + { + url: "http://hello.com/a?age=28", + pass: false, + }, + } + + for _, test := range tests { + r, err := http.NewRequest(http.MethodGet, test.url, nil) + assert.Nil(t, err) + + err = Parse(r, &v) + if test.pass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + } +} + +func 
TestParseMultipartForm(t *testing.T) { + var v struct { + Name string `form:"name"` + Age int `form:"age"` + } + + body := strings.Replace(`----------------------------220477612388154780019383 +Content-Disposition: form-data; name="name" + +kevin +----------------------------220477612388154780019383 +Content-Disposition: form-data; name="age" + +18 +----------------------------220477612388154780019383--`, "\n", "\r\n", -1) + + r := httptest.NewRequest(http.MethodPost, "http://localhost:3333/", strings.NewReader(body)) + r.Header.Set(ContentType, "multipart/form-data; boundary=--------------------------220477612388154780019383") + + err := Parse(r, &v) + assert.Nil(t, err) + assert.Equal(t, "kevin", v.Name) + assert.Equal(t, 18, v.Age) +} + +func TestParseRequired(t *testing.T) { + v := struct { + Name string `form:"name"` + Percent float64 `form:"percent"` + }{} + + r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello", nil) + assert.Nil(t, err) + + err = Parse(r, &v) + assert.NotNil(t, err) +} + +func TestParseSlice(t *testing.T) { + body := `names=%5B%22first%22%2C%22second%22%5D` + reader := strings.NewReader(body) + r, err := http.NewRequest(http.MethodPost, "http://hello.com/", reader) + assert.Nil(t, err) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + v := struct { + Names []string `form:"names"` + }{} + + err = Parse(r, &v) + assert.Nil(t, err) + assert.Equal(t, 2, len(v.Names)) + assert.Equal(t, "first", v.Names[0]) + assert.Equal(t, "second", v.Names[1]) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseJsonPost(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", + bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) + assert.Nil(t, err) + r.Header.Set(ContentType, ApplicationJson) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func( + w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d:%s:%d:%s:%d", v.Name, v.Year, + v.Nickname, v.Zipcode, v.Location, v.Time)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "kevin:2017:whatever:200000:shanghai:20170912", rr.Body.String()) +} + +func TestParseJsonPostWithIntSlice(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017", + bytes.NewBufferString(`{"ages": [1, 2], "years": [3, 4]}`)) + assert.Nil(t, err) + r.Header.Set(ContentType, ApplicationJson) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func( + w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Ages []int `json:"ages"` + Years []int64 `json:"years"` + }{} + + err = Parse(r, &v) + assert.Nil(t, err) + assert.ElementsMatch(t, []int{1, 2}, v.Ages) + assert.ElementsMatch(t, []int64{3, 4}, v.Years) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + 
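	// The body above is the url-encoded form of names=["first","second"];
	// serving the request runs the Parse assertions inside the handler.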
router.ServeHTTP(rr, r) +} + +func TestParseJsonPostError(t *testing.T) { + payload := `[{"abcd": "cdef"}]` + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", + bytes.NewBufferString(payload)) + assert.Nil(t, err) + r.Header.Set(ContentType, ApplicationJson) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseJsonPostInvalidRequest(t *testing.T) { + payload := `{"ages": ["cdef"]}` + r, err := http.NewRequest(http.MethodPost, "http://hello.com/", + bytes.NewBufferString(payload)) + assert.Nil(t, err) + r.Header.Set(ContentType, ApplicationJson) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Ages []int `json:"ages"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseJsonPostRequired(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017", + bytes.NewBufferString(`{"location": "shanghai"`)) + assert.Nil(t, err) + r.Header.Set(ContentType, ApplicationJson) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParsePath(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", nil) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + }{} + + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s in %d", v.Name, v.Year)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "kevin in 2017", rr.Body.String()) +} + +func TestParsePathRequired(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin", nil) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodGet, "/:name/", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseQuery(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w 
http.ResponseWriter, r *http.Request) { + v := struct { + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + }{} + + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "whatever:200000", rr.Body.String()) +} + +func TestParseQueryRequired(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever", nil) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + v := struct { + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseOptional(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode,optional"` + }{} + + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "whatever:0", rr.Body.String()) +} + +func TestParseNestedInRequestEmpty(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017", bytes.NewBufferString("{}")) + assert.Nil(t, err) + + type ( + Request struct { + Name string `path:"name"` + Year int `path:"year"` + } + + Audio struct { + Volume int `json:"volume"` + } + + WrappedRequest struct { + Request + Audio Audio `json:"audio,optional"` + } + ) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + var v WrappedRequest + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Name, v.Year)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "kevin:2017", rr.Body.String()) +} + +func TestParsePtrInRequest(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017", + bytes.NewBufferString(`{"audio": {"volume": 100}}`)) + assert.Nil(t, err) + r.Header.Set(ContentType, ApplicationJson) + + type ( + Request struct { + Name string `path:"name"` + Year int `path:"year"` + } + + Audio struct { + Volume int `json:"volume"` + } + + WrappedRequest struct { + Request + Audio *Audio `json:"audio,optional"` + } + ) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + var v WrappedRequest + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d:%d", v.Name, v.Year, v.Audio.Volume)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "kevin:2017:100", rr.Body.String()) +} + +func TestParsePtrInRequestEmpty(t *testing.T) { + r, err := 
http.NewRequest(http.MethodPost, "http://hello.com/kevin", bytes.NewBufferString("{}")) + assert.Nil(t, err) + + type ( + Audio struct { + Volume int `json:"volume"` + } + + WrappedRequest struct { + Audio *Audio `json:"audio,optional"` + } + ) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/kevin", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + var v WrappedRequest + err = Parse(r, &v) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseQueryOptional(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode,optional"` + }{} + + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "whatever:0", rr.Body.String()) +} + +func TestParse(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + }{} + + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d:%s:%d", v.Name, v.Year, v.Nickname, v.Zipcode)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "kevin:2017:whatever:200000", rr.Body.String()) +} + +func TestParseWrappedRequest(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", nil) + assert.Nil(t, err) + + type ( + Request struct { + Name string `path:"name"` + Year int `path:"year"` + } + + WrappedRequest struct { + Request + } + ) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + var v WrappedRequest + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Name, v.Year)) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "kevin:2017", rr.Body.String()) +} + +func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", nil) + assert.Nil(t, err) + r.Header.Set(ContentType, applicationJsonWithUtf8) + + type ( + Request struct { + Name string `path:"name"` + Year int `path:"year"` + } + + WrappedRequest struct { + Request + } + ) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + var v WrappedRequest + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Name, v.Year)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, 
"kevin:2017", rr.Body.String()) +} + +func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) { + r, err := http.NewRequest(http.MethodHead, "http://hello.com/kevin/2017", nil) + assert.Nil(t, err) + r.Header.Set(ContentType, applicationJsonWithUtf8) + + type ( + Request struct { + Name string `path:"name"` + Year int `path:"year"` + } + + WrappedRequest struct { + Request + } + ) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodHead, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + var v WrappedRequest + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Name, v.Year)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "kevin:2017", rr.Body.String()) +} + +func TestParseWrappedRequestPtr(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", nil) + assert.Nil(t, err) + + type ( + Request struct { + Name string `path:"name"` + Year int `path:"year"` + } + + WrappedRequest struct { + *Request + } + ) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + var v WrappedRequest + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Name, v.Year)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "kevin:2017", rr.Body.String()) +} + +func TestParseWithAll(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", + bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) + assert.Nil(t, err) + r.Header.Set(ContentType, ApplicationJson) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d:%s:%d:%s:%d", v.Name, v.Year, + v.Nickname, v.Zipcode, v.Location, v.Time)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "kevin:2017:whatever:200000:shanghai:20170912", rr.Body.String()) +} + +func TestParseWithAllUtf8(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", + bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) + assert.Nil(t, err) + r.Header.Set(ContentType, applicationJsonWithUtf8) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d:%s:%d:%s:%d", v.Name, v.Year, + v.Nickname, v.Zipcode, v.Location, v.Time)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + 
assert.Equal(t, "kevin:2017:whatever:200000:shanghai:20170912", rr.Body.String()) +} + +func TestParseWithMissingForm(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever", + bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + assert.Equal(t, "field zipcode is not set", err.Error()) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseWithMissingAllForms(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017", + bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseWithMissingJson(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", + bytes.NewBufferString(`{"location": "shanghai"}`)) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.NotEqual(t, io.EOF, err) + assert.NotNil(t, Parse(r, &v)) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseWithMissingAllJsons(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.NotEqual(t, io.EOF, err) + assert.NotNil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseWithMissingPath(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/2017?nickname=whatever&zipcode=200000", + bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string 
`path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + assert.Equal(t, "field name is not set", err.Error()) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseWithMissingAllPaths(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/?nickname=whatever&zipcode=200000", + bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) + assert.Nil(t, err) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseGetWithContentLengthHeader(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) + assert.Nil(t, err) + r.Header.Set(ContentType, ApplicationJson) + r.Header.Set(contentLength, "1024") + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Location string `json:"location"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseJsonPostWithTypeMismatch(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", + bytes.NewBufferString(`{"time": "20170912"}`)) + assert.Nil(t, err) + r.Header.Set(ContentType, applicationJsonWithUtf8) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + Time int64 `json:"time"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func TestParseJsonPostWithInt2String(t *testing.T) { + r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017", + bytes.NewBufferString(`{"time": 20170912}`)) + assert.Nil(t, err) + r.Header.Set(ContentType, applicationJsonWithUtf8) + + router := httprouter.NewPatRouter() + err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Name string `path:"name"` + Year int `path:"year"` + Time string `json:"time"` + }{} + + err = Parse(r, &v) + assert.NotNil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) +} + +func BenchmarkParseRaw(b *testing.B) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", nil) + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + v := struct { + Name string 
`form:"name"` + Age int `form:"age"` + Percent float64 `form:"percent,optional"` + }{} + + v.Name = r.FormValue("name") + v.Age, err = strconv.Atoi(r.FormValue("age")) + if err != nil { + b.Fatal(err) + } + v.Percent, err = strconv.ParseFloat(r.FormValue("percent"), 64) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkParseAuto(b *testing.B) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", nil) + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + v := struct { + Name string `form:"name"` + Age int `form:"age"` + Percent float64 `form:"percent,optional"` + }{} + + if err = Parse(r, &v); err != nil { + b.Fatal(err) + } + } +} diff --git a/core/httpx/responses.go b/core/httpx/responses.go new file mode 100644 index 00000000..d822640a --- /dev/null +++ b/core/httpx/responses.go @@ -0,0 +1,29 @@ +package httpx + +import ( + "encoding/json" + "net/http" + + "zero/core/logx" +) + +func OkJson(w http.ResponseWriter, v interface{}) { + WriteJson(w, http.StatusOK, v) +} + +func WriteJson(w http.ResponseWriter, code int, v interface{}) { + w.Header().Set(ContentType, ApplicationJson) + w.WriteHeader(code) + + if bs, err := json.Marshal(v); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } else if n, err := w.Write(bs); err != nil { + // http.ErrHandlerTimeout has been handled by http.TimeoutHandler, + // so it's ignored here. + if err != http.ErrHandlerTimeout { + logx.Errorf("write response failed, error: %s", err) + } + } else if n < len(bs) { + logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n) + } +} diff --git a/core/httpx/responses_test.go b/core/httpx/responses_test.go new file mode 100644 index 00000000..e2b3cc6f --- /dev/null +++ b/core/httpx/responses_test.go @@ -0,0 +1,78 @@ +package httpx + +import ( + "net/http" + "strings" + "testing" + + "zero/core/logx" + + "github.com/stretchr/testify/assert" +) + +type message struct { + Name string `json:"name"` +} + +func init() { + logx.Disable() +} + +func TestOkJson(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + msg := message{Name: "anyone"} + OkJson(&w, msg) + assert.Equal(t, http.StatusOK, w.code) + assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String()) +} + +func TestWriteJsonTimeout(t *testing.T) { + // only log it and ignore + w := tracedResponseWriter{ + headers: make(map[string][]string), + timeout: true, + } + msg := message{Name: "anyone"} + WriteJson(&w, http.StatusOK, msg) + assert.Equal(t, http.StatusOK, w.code) +} + +func TestWriteJsonLessWritten(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + lessWritten: true, + } + msg := message{Name: "anyone"} + WriteJson(&w, http.StatusOK, msg) + assert.Equal(t, http.StatusOK, w.code) +} + +type tracedResponseWriter struct { + headers map[string][]string + builder strings.Builder + code int + lessWritten bool + timeout bool +} + +func (w *tracedResponseWriter) Header() http.Header { + return w.headers +} + +func (w *tracedResponseWriter) Write(bytes []byte) (n int, err error) { + if w.timeout { + return 0, http.ErrHandlerTimeout + } + + n, err = w.builder.Write(bytes) + if w.lessWritten { + n -= 1 + } + return +} + +func (w *tracedResponseWriter) WriteHeader(code int) { + w.code = code +} diff --git a/core/iox/bufferpool.go b/core/iox/bufferpool.go new file mode 100644 index 00000000..3f3c9102 --- /dev/null +++ b/core/iox/bufferpool.go @@ -0,0 +1,34 @@ +package iox + +import ( + "bytes" + 
"sync" +) + +type BufferPool struct { + capability int + pool *sync.Pool +} + +func NewBufferPool(capability int) *BufferPool { + return &BufferPool{ + capability: capability, + pool: &sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, + }, + } +} + +func (bp *BufferPool) Get() *bytes.Buffer { + buf := bp.pool.Get().(*bytes.Buffer) + buf.Reset() + return buf +} + +func (bp *BufferPool) Put(buf *bytes.Buffer) { + if buf.Cap() < bp.capability { + bp.pool.Put(buf) + } +} diff --git a/core/iox/bufferpool_test.go b/core/iox/bufferpool_test.go new file mode 100644 index 00000000..254e5efd --- /dev/null +++ b/core/iox/bufferpool_test.go @@ -0,0 +1,15 @@ +package iox + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBufferPool(t *testing.T) { + capacity := 1024 + pool := NewBufferPool(capacity) + pool.Put(bytes.NewBuffer(make([]byte, 0, 2*capacity))) + assert.True(t, pool.Get().Cap() <= capacity) +} diff --git a/core/iox/nopcloser.go b/core/iox/nopcloser.go new file mode 100644 index 00000000..2fefc444 --- /dev/null +++ b/core/iox/nopcloser.go @@ -0,0 +1,15 @@ +package iox + +import "io" + +type nopCloser struct { + io.Writer +} + +func (nopCloser) Close() error { + return nil +} + +func NopCloser(w io.Writer) io.WriteCloser { + return nopCloser{w} +} diff --git a/core/iox/read.go b/core/iox/read.go new file mode 100644 index 00000000..6c8a76a2 --- /dev/null +++ b/core/iox/read.go @@ -0,0 +1,103 @@ +package iox + +import ( + "bufio" + "bytes" + "io" + "io/ioutil" + "os" + "strings" +) + +type ( + textReadOptions struct { + keepSpace bool + withoutBlanks bool + omitPrefix string + } + + TextReadOption func(*textReadOptions) +) + +// The first returned reader needs to be read first, because the content +// read from it will be written to the underlying buffer of the second reader. 
+func DupReadCloser(reader io.ReadCloser) (io.ReadCloser, io.ReadCloser) { + var buf bytes.Buffer + tee := io.TeeReader(reader, &buf) + return ioutil.NopCloser(tee), ioutil.NopCloser(&buf) +} + +func KeepSpace() TextReadOption { + return func(o *textReadOptions) { + o.keepSpace = true + } +} + +// ReadBytes reads exactly the bytes with the length of len(buf) +func ReadBytes(reader io.Reader, buf []byte) error { + var got int + + for got < len(buf) { + n, err := reader.Read(buf[got:]) + if err != nil { + return err + } + + got += n + } + + return nil +} + +func ReadText(filename string) (string, error) { + content, err := ioutil.ReadFile(filename) + if err != nil { + return "", err + } + + return strings.TrimSpace(string(content)), nil +} + +func ReadTextLines(filename string, opts ...TextReadOption) ([]string, error) { + var readOpts textReadOptions + for _, opt := range opts { + opt(&readOpts) + } + + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer file.Close() + + var lines []string + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if !readOpts.keepSpace { + line = strings.TrimSpace(line) + } + if readOpts.withoutBlanks && len(line) == 0 { + continue + } + if len(readOpts.omitPrefix) > 0 && strings.HasPrefix(line, readOpts.omitPrefix) { + continue + } + + lines = append(lines, line) + } + + return lines, scanner.Err() +} + +func WithoutBlank() TextReadOption { + return func(o *textReadOptions) { + o.withoutBlanks = true + } +} + +func OmitWithPrefix(prefix string) TextReadOption { + return func(o *textReadOptions) { + o.omitPrefix = prefix + } +} diff --git a/core/iox/read_test.go b/core/iox/read_test.go new file mode 100644 index 00000000..51bf78a5 --- /dev/null +++ b/core/iox/read_test.go @@ -0,0 +1,142 @@ +package iox + +import ( + "bytes" + "io" + "io/ioutil" + "os" + "testing" + "time" + + "zero/core/fs" + "zero/core/stringx" + + "github.com/stretchr/testify/assert" +) + +func TestReadText(t *testing.T) { + tests := []struct { + input string + expect string + }{ + { + input: `a`, + expect: `a`, + }, { + input: `a +`, + expect: `a`, + }, { + input: `a +b`, + expect: `a +b`, + }, { + input: `a +b +`, + expect: `a +b`, + }, + } + + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + tmpfile, err := fs.TempFilenameWithText(test.input) + assert.Nil(t, err) + defer os.Remove(tmpfile) + + content, err := ReadText(tmpfile) + assert.Nil(t, err) + assert.Equal(t, test.expect, content) + }) + } +} + +func TestReadTextLines(t *testing.T) { + text := `1 + + 2 + + #a + 3` + + tmpfile, err := fs.TempFilenameWithText(text) + assert.Nil(t, err) + defer os.Remove(tmpfile) + + tests := []struct { + options []TextReadOption + expectLines int + }{ + { + nil, + 6, + }, { + []TextReadOption{KeepSpace(), OmitWithPrefix("#")}, + 6, + }, { + []TextReadOption{WithoutBlank()}, + 4, + }, { + []TextReadOption{OmitWithPrefix("#")}, + 5, + }, { + []TextReadOption{WithoutBlank(), OmitWithPrefix("#")}, + 3, + }, + } + + for _, test := range tests { + t.Run(stringx.Rand(), func(t *testing.T) { + lines, err := ReadTextLines(tmpfile, test.options...) 
+ assert.Nil(t, err) + assert.Equal(t, test.expectLines, len(lines)) + }) + } +} + +func TestDupReadCloser(t *testing.T) { + input := "hello" + reader := ioutil.NopCloser(bytes.NewBufferString(input)) + r1, r2 := DupReadCloser(reader) + verify := func(r io.Reader) { + output, err := ioutil.ReadAll(r) + assert.Nil(t, err) + assert.Equal(t, input, string(output)) + } + + verify(r1) + verify(r2) +} + +func TestReadBytes(t *testing.T) { + reader := ioutil.NopCloser(bytes.NewBufferString("helloworld")) + buf := make([]byte, 5) + err := ReadBytes(reader, buf) + assert.Nil(t, err) + assert.Equal(t, "hello", string(buf)) +} + +func TestReadBytesNotEnough(t *testing.T) { + reader := ioutil.NopCloser(bytes.NewBufferString("hell")) + buf := make([]byte, 5) + err := ReadBytes(reader, buf) + assert.Equal(t, io.EOF, err) +} + +func TestReadBytesChunks(t *testing.T) { + buf := make([]byte, 5) + reader, writer := io.Pipe() + + go func() { + for i := 0; i < 10; i++ { + writer.Write([]byte{'a'}) + time.Sleep(10 * time.Millisecond) + } + }() + + err := ReadBytes(reader, buf) + assert.Nil(t, err) + assert.Equal(t, "aaaaa", string(buf)) +} diff --git a/core/iox/textfile.go b/core/iox/textfile.go new file mode 100644 index 00000000..1cc2cb0b --- /dev/null +++ b/core/iox/textfile.go @@ -0,0 +1,39 @@ +package iox + +import ( + "bytes" + "io" + "os" +) + +const bufSize = 32 * 1024 + +func CountLines(file string) (int, error) { + f, err := os.Open(file) + if err != nil { + return 0, err + } + defer f.Close() + + var noEol bool + buf := make([]byte, bufSize) + count := 0 + lineSep := []byte{'\n'} + + for { + c, err := f.Read(buf) + count += bytes.Count(buf[:c], lineSep) + + switch { + case err == io.EOF: + if noEol { + count++ + } + return count, nil + case err != nil: + return count, err + } + + noEol = c > 0 && buf[c-1] != '\n' + } +} diff --git a/core/iox/textfile_test.go b/core/iox/textfile_test.go new file mode 100644 index 00000000..2c5f58e3 --- /dev/null +++ b/core/iox/textfile_test.go @@ -0,0 +1,27 @@ +package iox + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCountLines(t *testing.T) { + const val = `1 +2 +3 +4` + file, err := ioutil.TempFile(os.TempDir(), "test-") + if err != nil { + t.Fatal(err) + } + defer os.Remove(file.Name()) + + file.WriteString(val) + file.Close() + lines, err := CountLines(file.Name()) + assert.Nil(t, err) + assert.Equal(t, 4, lines) +} diff --git a/core/iox/textlinescanner.go b/core/iox/textlinescanner.go new file mode 100644 index 00000000..ade9049b --- /dev/null +++ b/core/iox/textlinescanner.go @@ -0,0 +1,42 @@ +package iox + +import ( + "bufio" + "io" + "strings" +) + +type TextLineScanner struct { + reader *bufio.Reader + hasNext bool + line string + err error +} + +func NewTextLineScanner(reader io.Reader) *TextLineScanner { + return &TextLineScanner{ + reader: bufio.NewReader(reader), + hasNext: true, + } +} + +func (scanner *TextLineScanner) Scan() bool { + if !scanner.hasNext { + return false + } + + line, err := scanner.reader.ReadString('\n') + scanner.line = strings.TrimRight(line, "\n") + if err == io.EOF { + scanner.hasNext = false + return true + } else if err != nil { + scanner.err = err + return false + } + return true +} + +func (scanner *TextLineScanner) Line() (string, error) { + return scanner.line, scanner.err +} diff --git a/core/iox/textlinescanner_test.go b/core/iox/textlinescanner_test.go new file mode 100644 index 00000000..621a36cc --- /dev/null +++ b/core/iox/textlinescanner_test.go @@ -0,0 +1,24 @@ 
+package iox + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestScanner(t *testing.T) { + const val = `1 +2 +3 +4` + reader := strings.NewReader(val) + scanner := NewTextLineScanner(reader) + var lines []string + for scanner.Scan() { + line, err := scanner.Line() + assert.Nil(t, err) + lines = append(lines, line) + } + assert.EqualValues(t, []string{"1", "2", "3", "4"}, lines) +} diff --git a/core/jsontype/time.go b/core/jsontype/time.go new file mode 100644 index 00000000..d55ec162 --- /dev/null +++ b/core/jsontype/time.go @@ -0,0 +1,38 @@ +package jsontype + +import ( + "encoding/json" + "time" + + "github.com/globalsign/mgo/bson" +) + +type MilliTime struct { + time.Time +} + +func (mt MilliTime) MarshalJSON() ([]byte, error) { + return json.Marshal(mt.Milli()) +} + +func (mt *MilliTime) UnmarshalJSON(data []byte) error { + var milli int64 + if err := json.Unmarshal(data, &milli); err != nil { + return err + } else { + mt.Time = time.Unix(0, milli*int64(time.Millisecond)) + return nil + } +} + +func (mt MilliTime) GetBSON() (interface{}, error) { + return mt.Time, nil +} + +func (mt *MilliTime) SetBSON(raw bson.Raw) error { + return raw.Unmarshal(&mt.Time) +} + +func (mt MilliTime) Milli() int64 { + return mt.UnixNano() / int64(time.Millisecond) +} diff --git a/core/jsontype/time_test.go b/core/jsontype/time_test.go new file mode 100644 index 00000000..5d675fea --- /dev/null +++ b/core/jsontype/time_test.go @@ -0,0 +1,108 @@ +package jsontype + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestMilliTime_GetBSON(t *testing.T) { + tests := []struct { + name string + tm time.Time + }{ + { + name: "now", + tm: time.Now(), + }, + { + name: "future", + tm: time.Now().Add(time.Hour), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := MilliTime{test.tm}.GetBSON() + assert.Nil(t, err) + assert.Equal(t, test.tm, got) + }) + } +} + +func TestMilliTime_MarshalJSON(t *testing.T) { + tests := []struct { + name string + tm time.Time + }{ + { + name: "now", + tm: time.Now(), + }, + { + name: "future", + tm: time.Now().Add(time.Hour), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + b, err := MilliTime{test.tm}.MarshalJSON() + assert.Nil(t, err) + assert.Equal(t, strconv.FormatInt(test.tm.UnixNano()/1e6, 10), string(b)) + }) + } +} + +func TestMilliTime_Milli(t *testing.T) { + tests := []struct { + name string + tm time.Time + }{ + { + name: "now", + tm: time.Now(), + }, + { + name: "future", + tm: time.Now().Add(time.Hour), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + n := MilliTime{test.tm}.Milli() + assert.Equal(t, test.tm.UnixNano()/1e6, n) + }) + } +} + +func TestMilliTime_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + tm time.Time + }{ + { + name: "now", + tm: time.Now(), + }, + { + name: "future", + tm: time.Now().Add(time.Hour), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var mt MilliTime + s := strconv.FormatInt(test.tm.UnixNano()/1e6, 10) + err := mt.UnmarshalJSON([]byte(s)) + assert.Nil(t, err) + s1, err := mt.MarshalJSON() + assert.Nil(t, err) + assert.Equal(t, s, string(s1)) + }) + } +} diff --git a/core/jsonx/json.go b/core/jsonx/json.go new file mode 100644 index 00000000..a1f5f7ef --- /dev/null +++ b/core/jsonx/json.go @@ -0,0 +1,51 @@ +package jsonx + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "strings" +) + 
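+// The Unmarshal* helpers below decode with Decoder.UseNumber, so numeric JSON
+// values are kept as json.Number instead of float64 and large integers do not
+// lose precision. A rough sketch of the difference (values are illustrative):
+//
+//	var v map[string]interface{}
+//	_ = json.Unmarshal([]byte(`{"id": 9007199254740993}`), &v) // id becomes float64, precision lost
+//	_ = Unmarshal([]byte(`{"id": 9007199254740993}`), &v)      // id stays json.Number("9007199254740993")
+//
+// On failure, formatError wraps the error together with the offending payload.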
+func Marshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +func Unmarshal(data []byte, v interface{}) error { + decoder := json.NewDecoder(bytes.NewReader(data)) + if err := unmarshalUseNumber(decoder, v); err != nil { + return formatError(string(data), err) + } + + return nil +} + +func UnmarshalFromString(str string, v interface{}) error { + decoder := json.NewDecoder(strings.NewReader(str)) + if err := unmarshalUseNumber(decoder, v); err != nil { + return formatError(str, err) + } + + return nil +} + +func UnmarshalFromReader(reader io.Reader, v interface{}) error { + var buf strings.Builder + teeReader := io.TeeReader(reader, &buf) + decoder := json.NewDecoder(teeReader) + if err := unmarshalUseNumber(decoder, v); err != nil { + return formatError(buf.String(), err) + } + + return nil +} + +func unmarshalUseNumber(decoder *json.Decoder, v interface{}) error { + decoder.UseNumber() + return decoder.Decode(v) +} + +func formatError(v string, err error) error { + return fmt.Errorf("string: `%s`, error: `%s`", v, err.Error()) +} diff --git a/core/lang/lang.go b/core/lang/lang.go new file mode 100644 index 00000000..cf02ad6e --- /dev/null +++ b/core/lang/lang.go @@ -0,0 +1,16 @@ +package lang + +import "log" + +var Placeholder PlaceholderType + +type ( + GenericType = interface{} + PlaceholderType = struct{} +) + +func Must(err error) { + if err != nil { + log.Fatal(err) + } +} diff --git a/core/lang/lang_test.go b/core/lang/lang_test.go new file mode 100644 index 00000000..6b3c6417 --- /dev/null +++ b/core/lang/lang_test.go @@ -0,0 +1,7 @@ +package lang + +import "testing" + +func TestMust(t *testing.T) { + Must(nil) +} diff --git a/core/limit/periodlimit.go b/core/limit/periodlimit.go new file mode 100644 index 00000000..e957b655 --- /dev/null +++ b/core/limit/periodlimit.go @@ -0,0 +1,109 @@ +package limit + +import ( + "errors" + "strconv" + "time" + + "zero/core/stores/redis" +) + +const ( + // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key + periodScript = `local limit = tonumber(ARGV[1]) +local window = tonumber(ARGV[2]) +local current = redis.call("INCRBY", KEYS[1], 1) +if current == 1 then + redis.call("expire", KEYS[1], window) + return 1 +elseif current < limit then + return 1 +elseif current == limit then + return 2 +else + return 0 +end` + zoneDiff = 3600 * 8 // GMT+8 for our services +) + +const ( + Unknown = iota + Allowed + HitQuota + OverQuota + + internalOverQuota = 0 + internalAllowed = 1 + internalHitQuota = 2 +) + +var ErrUnknownCode = errors.New("unknown status code") + +type ( + LimitOption func(l *PeriodLimit) + + PeriodLimit struct { + period int + quota int + limitStore *redis.Redis + keyPrefix string + align bool + } +) + +func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string, + opts ...LimitOption) *PeriodLimit { + limiter := &PeriodLimit{ + period: period, + quota: quota, + limitStore: limitStore, + keyPrefix: keyPrefix, + } + + for _, opt := range opts { + opt(limiter) + } + + return limiter +} + +func (h *PeriodLimit) Take(key string) (int, error) { + resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{ + strconv.Itoa(h.quota), + strconv.Itoa(h.calcExpireSeconds()), + }) + if err != nil { + return Unknown, err + } + + code, ok := resp.(int64) + if !ok { + return Unknown, ErrUnknownCode + } + + switch code { + case internalOverQuota: + return OverQuota, nil + case internalAllowed: + return Allowed, nil + case internalHitQuota: + return 
HitQuota, nil + default: + return Unknown, ErrUnknownCode + } +} + +func (h *PeriodLimit) calcExpireSeconds() int { + if h.align { + unix := time.Now().Unix() + zoneDiff + return h.period - int(unix%int64(h.period)) + } else { + return h.period + } +} + +func Align() LimitOption { + return func(l *PeriodLimit) { + l.align = true + } +} diff --git a/core/limit/periodlimit_test.go b/core/limit/periodlimit_test.go new file mode 100644 index 00000000..8d16c974 --- /dev/null +++ b/core/limit/periodlimit_test.go @@ -0,0 +1,68 @@ +package limit + +import ( + "testing" + + "zero/core/stores/redis" + + "github.com/alicebob/miniredis" + "github.com/stretchr/testify/assert" +) + +func TestPeriodLimit_Take(t *testing.T) { + testPeriodLimit(t) +} + +func TestPeriodLimit_TakeWithAlign(t *testing.T) { + testPeriodLimit(t, Align()) +} + +func TestPeriodLimit_RedisUnavailable(t *testing.T) { + s, err := miniredis.Run() + assert.Nil(t, err) + + const ( + seconds = 1 + total = 100 + quota = 5 + ) + l := NewPeriodLimit(seconds, quota, redis.NewRedis(s.Addr(), redis.NodeType), "periodlimit") + s.Close() + val, err := l.Take("first") + assert.NotNil(t, err) + assert.Equal(t, 0, val) +} + +func testPeriodLimit(t *testing.T, opts ...LimitOption) { + s, err := miniredis.Run() + assert.Nil(t, err) + defer s.Close() + + const ( + seconds = 1 + total = 100 + quota = 5 + ) + l := NewPeriodLimit(seconds, quota, redis.NewRedis(s.Addr(), redis.NodeType), "periodlimit", opts...) + var allowed, hitQuota, overQuota int + for i := 0; i < total; i++ { + val, err := l.Take("first") + if err != nil { + t.Error(err) + } + switch val { + case Allowed: + allowed++ + case HitQuota: + hitQuota++ + case OverQuota: + overQuota++ + default: + t.Error("unknown status") + } + } + + assert.Equal(t, quota-1, allowed) + assert.Equal(t, 1, hitQuota) + assert.Equal(t, total-quota, overQuota) +} diff --git a/core/limit/tokenlimit.go b/core/limit/tokenlimit.go new file mode 100644 index 00000000..03b15ac9 --- /dev/null +++ b/core/limit/tokenlimit.go @@ -0,0 +1,166 @@ +package limit + +import ( + "fmt" + "strconv" + "sync" + "sync/atomic" + "time" + + "zero/core/logx" + "zero/core/stores/redis" + + xrate "golang.org/x/time/rate" +) + +const ( + // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key + // KEYS[1] as tokens_key + // KEYS[2] as timestamp_key + script = `local rate = tonumber(ARGV[1]) +local capacity = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) +local requested = tonumber(ARGV[4]) +local fill_time = capacity/rate +local ttl = math.floor(fill_time*2) +local last_tokens = tonumber(redis.call("get", KEYS[1])) +if last_tokens == nil then + last_tokens = capacity +end + +local last_refreshed = tonumber(redis.call("get", KEYS[2])) +if last_refreshed == nil then + last_refreshed = 0 +end + +local delta = math.max(0, now-last_refreshed) +local filled_tokens = math.min(capacity, last_tokens+(delta*rate)) +local allowed = filled_tokens >= requested +local new_tokens = filled_tokens +if allowed then + new_tokens = filled_tokens - requested +end + +redis.call("setex", KEYS[1], ttl, new_tokens) +redis.call("setex", KEYS[2], ttl, now) + +return allowed` + tokenFormat = "{%s}.tokens" + timestampFormat = "{%s}.ts" + pingInterval = time.Millisecond * 100 +) + +// A TokenLimiter controls how frequently events are allowed to happen with in one second. 
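+// The script implements a token bucket: with rate=5 and burst=10 the bucket
+// refills at 5 tokens per second up to a cap of 10, fill_time is 10/5 = 2s and
+// both redis keys expire after floor(2*2) = 4s of inactivity. A usage sketch
+// (the key name is illustrative):
+//
+//	limiter := NewTokenLimiter(5, 10, store, "login")
+//	if limiter.Allow() {
+//		// handle the request
+//	}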
+type TokenLimiter struct { + rate int + burst int + store *redis.Redis + tokenKey string + timestampKey string + rescueLock sync.Mutex + redisAlive uint32 + rescueLimiter *xrate.Limiter + monitorStarted bool +} + +// NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits +// bursts of at most burst tokens. +func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter { + tokenKey := fmt.Sprintf(tokenFormat, key) + timestampKey := fmt.Sprintf(timestampFormat, key) + + return &TokenLimiter{ + rate: rate, + burst: burst, + store: store, + tokenKey: tokenKey, + timestampKey: timestampKey, + redisAlive: 1, + rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst), + } +} + +// Allow is shorthand for AllowN(time.Now(), 1). +func (lim *TokenLimiter) Allow() bool { + return lim.AllowN(time.Now(), 1) +} + +// AllowN reports whether n events may happen at time now. +// Use this method if you intend to drop / skip events that exceed the rate rate. +// Otherwise use Reserve or Wait. +func (lim *TokenLimiter) AllowN(now time.Time, n int) bool { + return lim.reserveN(now, n) +} + +func (lim *TokenLimiter) reserveN(now time.Time, n int) bool { + if atomic.LoadUint32(&lim.redisAlive) == 0 { + return lim.rescueLimiter.AllowN(now, n) + } + + resp, err := lim.store.Eval( + script, + []string{ + lim.tokenKey, + lim.timestampKey, + }, + []string{ + strconv.Itoa(lim.rate), + strconv.Itoa(lim.burst), + strconv.FormatInt(now.Unix(), 10), + strconv.Itoa(n), + }) + // redis allowed == false + // Lua boolean false -> r Nil bulk reply + if err == redis.Nil { + return false + } else if err != nil { + logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err) + lim.startMonitor() + return lim.rescueLimiter.AllowN(now, n) + } + + code, ok := resp.(int64) + if !ok { + logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp) + lim.startMonitor() + return lim.rescueLimiter.AllowN(now, n) + } + + // redis allowed == true + // Lua boolean true -> r integer reply with value of 1 + return code == 1 +} + +func (lim *TokenLimiter) startMonitor() { + lim.rescueLock.Lock() + defer lim.rescueLock.Unlock() + + if lim.monitorStarted { + return + } + + lim.monitorStarted = true + atomic.StoreUint32(&lim.redisAlive, 0) + + go lim.waitForRedis() +} + +func (lim *TokenLimiter) waitForRedis() { + ticker := time.NewTicker(pingInterval) + defer func() { + ticker.Stop() + lim.rescueLock.Lock() + lim.monitorStarted = false + lim.rescueLock.Unlock() + }() + + for { + select { + case <-ticker.C: + if lim.store.Ping() { + atomic.StoreUint32(&lim.redisAlive, 1) + return + } + } + } +} diff --git a/core/limit/tokenlimit_test.go b/core/limit/tokenlimit_test.go new file mode 100644 index 00000000..084c03b3 --- /dev/null +++ b/core/limit/tokenlimit_test.go @@ -0,0 +1,88 @@ +package limit + +import ( + "testing" + "time" + + "zero/core/logx" + "zero/core/stores/redis" + + "github.com/alicebob/miniredis" + "github.com/stretchr/testify/assert" +) + +func init() { + logx.Disable() +} + +func TestTokenLimit_Rescue(t *testing.T) { + s, err := miniredis.Run() + assert.Nil(t, err) + + const ( + total = 100 + rate = 5 + burst = 10 + ) + l := NewTokenLimiter(rate, burst, redis.NewRedis(s.Addr(), redis.NodeType), "tokenlimit") + s.Close() + + var allowed int + for i := 0; i < total; i++ { + time.Sleep(time.Second / time.Duration(total)) + if i == total>>1 { + assert.Nil(t, s.Restart()) + } + if l.Allow() { + allowed++ + } + + 
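+ // While redis is unreachable, reserveN falls back to the in-process
+ // rescueLimiter, and startMonitor keeps pinging redis every pingInterval
+ // (100ms) until it answers again, so allowed still reaches at least
+ // burst+rate even though the store is closed and only restarted halfway
+ // through the loop.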
// make sure start monitor more than once doesn't matter + l.startMonitor() + } + + assert.True(t, allowed >= burst+rate) +} + +func TestTokenLimit_Take(t *testing.T) { + s, err := miniredis.Run() + assert.Nil(t, err) + defer s.Close() + + const ( + total = 100 + rate = 5 + burst = 10 + ) + l := NewTokenLimiter(rate, burst, redis.NewRedis(s.Addr(), redis.NodeType), "tokenlimit") + var allowed int + for i := 0; i < total; i++ { + time.Sleep(time.Second / time.Duration(total)) + if l.Allow() { + allowed++ + } + } + + assert.True(t, allowed >= burst+rate) +} + +func TestTokenLimit_TakeBurst(t *testing.T) { + s, err := miniredis.Run() + assert.Nil(t, err) + defer s.Close() + + const ( + total = 100 + rate = 5 + burst = 10 + ) + l := NewTokenLimiter(rate, burst, redis.NewRedis(s.Addr(), redis.NodeType), "tokenlimit") + var allowed int + for i := 0; i < total; i++ { + if l.Allow() { + allowed++ + } + } + + assert.True(t, allowed >= burst) +} diff --git a/core/load/adaptiveshedder.go b/core/load/adaptiveshedder.go new file mode 100644 index 00000000..db45253d --- /dev/null +++ b/core/load/adaptiveshedder.go @@ -0,0 +1,248 @@ +package load + +import ( + "errors" + "fmt" + "math" + "sync/atomic" + "time" + + "zero/core/collection" + "zero/core/logx" + "zero/core/stat" + "zero/core/syncx" + "zero/core/timex" +) + +const ( + defaultBuckets = 50 + defaultWindow = time.Second * 5 + // using 1000m notation, 900m is like 80%, keep it as var for unit test + defaultCpuThreshold = 900 + defaultMinRt = float64(time.Second / time.Millisecond) + // moving average hyperparameter beta for calculating requests on the fly + flyingBeta = 0.9 + coolOffDuration = time.Second +) + +var ( + ErrServiceOverloaded = errors.New("service overloaded") + + // default to be enabled + enabled = syncx.ForAtomicBool(true) + // make it a variable for unit test + systemOverloadChecker = func(cpuThreshold int64) bool { + return stat.CpuUsage() >= cpuThreshold + } +) + +type ( + Promise interface { + Pass() + Fail() + } + + Shedder interface { + Allow() (Promise, error) + } + + ShedderOption func(opts *shedderOptions) + + shedderOptions struct { + window time.Duration + buckets int + cpuThreshold int64 + } + + adaptiveShedder struct { + cpuThreshold int64 + windows int64 + flying int64 + avgFlying float64 + avgFlyingLock syncx.SpinLock + dropTime *syncx.AtomicDuration + droppedRecently *syncx.AtomicBool + passCounter *collection.RollingWindow + rtCounter *collection.RollingWindow + } +) + +func Disable() { + enabled.Set(false) +} + +func NewAdaptiveShedder(opts ...ShedderOption) Shedder { + if !enabled.True() { + return newNopShedder() + } + + options := shedderOptions{ + window: defaultWindow, + buckets: defaultBuckets, + cpuThreshold: defaultCpuThreshold, + } + for _, opt := range opts { + opt(&options) + } + bucketDuration := options.window / time.Duration(options.buckets) + return &adaptiveShedder{ + cpuThreshold: options.cpuThreshold, + windows: int64(time.Second / bucketDuration), + dropTime: syncx.NewAtomicDuration(), + droppedRecently: syncx.NewAtomicBool(), + passCounter: collection.NewRollingWindow(options.buckets, bucketDuration, + collection.IgnoreCurrentBucket()), + rtCounter: collection.NewRollingWindow(options.buckets, bucketDuration, + collection.IgnoreCurrentBucket()), + } +} + +func (as *adaptiveShedder) Allow() (Promise, error) { + if as.shouldDrop() { + as.dropTime.Set(timex.Now()) + as.droppedRecently.Set(true) + + return nil, ErrServiceOverloaded + } + + as.addFlying(1) + + return &promise{ + start: timex.Now(), 
+ shedder: as, + }, nil +} + +func (as *adaptiveShedder) addFlying(delta int64) { + flying := atomic.AddInt64(&as.flying, delta) + // update avgFlying when the request is finished. + // this strategy makes avgFlying have a little bit lag against flying, and smoother. + // when the flying requests increase rapidly, avgFlying increase slower, accept more requests. + // when the flying requests drop rapidly, avgFlying drop slower, accept less requests. + // it makes the service to serve as more requests as possible. + if delta < 0 { + as.avgFlyingLock.Lock() + as.avgFlying = as.avgFlying*flyingBeta + float64(flying)*(1-flyingBeta) + as.avgFlyingLock.Unlock() + } +} + +func (as *adaptiveShedder) highThru() bool { + as.avgFlyingLock.Lock() + avgFlying := as.avgFlying + as.avgFlyingLock.Unlock() + maxFlight := as.maxFlight() + return int64(avgFlying) > maxFlight && atomic.LoadInt64(&as.flying) > maxFlight +} + +func (as *adaptiveShedder) maxFlight() int64 { + // windows = buckets per second + // maxQPS = maxPASS * windows + // minRT = min average response time in milliseconds + // maxQPS * minRT / milliseconds_per_second + return int64(math.Max(1, float64(as.maxPass()*as.windows)*(as.minRt()/1e3))) +} + +func (as *adaptiveShedder) maxPass() int64 { + var result float64 = 1 + + as.passCounter.Reduce(func(b *collection.Bucket) { + if b.Sum > result { + result = b.Sum + } + }) + + return int64(result) +} + +func (as *adaptiveShedder) minRt() float64 { + var result = defaultMinRt + + as.rtCounter.Reduce(func(b *collection.Bucket) { + if b.Count <= 0 { + return + } + + avg := math.Round(b.Sum / float64(b.Count)) + if avg < result { + result = avg + } + }) + + return result +} + +func (as *adaptiveShedder) shouldDrop() bool { + if as.systemOverloaded() || as.stillHot() { + if as.highThru() { + flying := atomic.LoadInt64(&as.flying) + as.avgFlyingLock.Lock() + avgFlying := as.avgFlying + as.avgFlyingLock.Unlock() + msg := fmt.Sprintf( + "dropreq, cpu: %d, maxPass: %d, minRt: %.2f, hot: %t, flying: %d, avgFlying: %.2f", + stat.CpuUsage(), as.maxPass(), as.minRt(), as.stillHot(), flying, avgFlying) + logx.Error(msg) + stat.Report(msg) + return true + } + } + + return false +} + +func (as *adaptiveShedder) stillHot() bool { + if !as.droppedRecently.True() { + return false + } + + dropTime := as.dropTime.Load() + if dropTime == 0 { + return false + } + + hot := timex.Since(dropTime) < coolOffDuration + if !hot { + as.droppedRecently.Set(false) + } + + return hot +} + +func (as *adaptiveShedder) systemOverloaded() bool { + return systemOverloadChecker(as.cpuThreshold) +} + +func WithBuckets(buckets int) ShedderOption { + return func(opts *shedderOptions) { + opts.buckets = buckets + } +} + +func WithCpuThreshold(threshold int64) ShedderOption { + return func(opts *shedderOptions) { + opts.cpuThreshold = threshold + } +} + +func WithWindow(window time.Duration) ShedderOption { + return func(opts *shedderOptions) { + opts.window = window + } +} + +type promise struct { + start time.Duration + shedder *adaptiveShedder +} + +func (p *promise) Fail() { + p.shedder.addFlying(-1) +} + +func (p *promise) Pass() { + rt := float64(timex.Since(p.start)) / float64(time.Millisecond) + p.shedder.addFlying(-1) + p.shedder.rtCounter.Add(math.Ceil(rt)) + p.shedder.passCounter.Add(1) +} diff --git a/core/load/adaptiveshedder_test.go b/core/load/adaptiveshedder_test.go new file mode 100644 index 00000000..449640ed --- /dev/null +++ b/core/load/adaptiveshedder_test.go @@ -0,0 +1,205 @@ +package load + +import ( + "math/rand" + 
"sync" + "sync/atomic" + "testing" + "time" + + "zero/core/collection" + "zero/core/logx" + "zero/core/mathx" + "zero/core/stat" + "zero/core/syncx" + + "github.com/stretchr/testify/assert" +) + +const ( + buckets = 10 + bucketDuration = time.Millisecond * 50 +) + +func init() { + stat.SetReporter(nil) +} + +func TestAdaptiveShedder(t *testing.T) { + shedder := NewAdaptiveShedder(WithWindow(bucketDuration), WithBuckets(buckets), WithCpuThreshold(100)) + var wg sync.WaitGroup + var drop int64 + proba := mathx.NewProba() + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 30; i++ { + promise, err := shedder.Allow() + if err != nil { + atomic.AddInt64(&drop, 1) + } else { + count := rand.Intn(5) + time.Sleep(time.Millisecond * time.Duration(count)) + if proba.TrueOnProba(0.01) { + promise.Fail() + } else { + promise.Pass() + } + } + } + }() + } + wg.Wait() +} + +func TestAdaptiveShedderMaxPass(t *testing.T) { + passCounter := newRollingWindow() + for i := 1; i <= 10; i++ { + passCounter.Add(float64(i * 100)) + time.Sleep(bucketDuration) + } + shedder := &adaptiveShedder{ + passCounter: passCounter, + droppedRecently: syncx.NewAtomicBool(), + } + assert.Equal(t, int64(1000), shedder.maxPass()) + + // default max pass is equal to 1. + passCounter = newRollingWindow() + shedder = &adaptiveShedder{ + passCounter: passCounter, + droppedRecently: syncx.NewAtomicBool(), + } + assert.Equal(t, int64(1), shedder.maxPass()) +} + +func TestAdaptiveShedderMinRt(t *testing.T) { + rtCounter := newRollingWindow() + for i := 0; i < 10; i++ { + if i > 0 { + time.Sleep(bucketDuration) + } + for j := i*10 + 1; j <= i*10+10; j++ { + rtCounter.Add(float64(j)) + } + } + shedder := &adaptiveShedder{ + rtCounter: rtCounter, + } + assert.Equal(t, float64(6), shedder.minRt()) + + // default max min rt is equal to maxFloat64. 
+ rtCounter = newRollingWindow() + shedder = &adaptiveShedder{ + rtCounter: rtCounter, + droppedRecently: syncx.NewAtomicBool(), + } + assert.Equal(t, defaultMinRt, shedder.minRt()) +} + +func TestAdaptiveShedderMaxFlight(t *testing.T) { + passCounter := newRollingWindow() + rtCounter := newRollingWindow() + for i := 0; i < 10; i++ { + if i > 0 { + time.Sleep(bucketDuration) + } + passCounter.Add(float64((i + 1) * 100)) + for j := i*10 + 1; j <= i*10+10; j++ { + rtCounter.Add(float64(j)) + } + } + shedder := &adaptiveShedder{ + passCounter: passCounter, + rtCounter: rtCounter, + windows: buckets, + droppedRecently: syncx.NewAtomicBool(), + } + assert.Equal(t, int64(54), shedder.maxFlight()) +} + +func TestAdaptiveShedderShouldDrop(t *testing.T) { + logx.Disable() + passCounter := newRollingWindow() + rtCounter := newRollingWindow() + for i := 0; i < 10; i++ { + if i > 0 { + time.Sleep(bucketDuration) + } + passCounter.Add(float64((i + 1) * 100)) + for j := i*10 + 1; j <= i*10+10; j++ { + rtCounter.Add(float64(j)) + } + } + shedder := &adaptiveShedder{ + passCounter: passCounter, + rtCounter: rtCounter, + windows: buckets, + droppedRecently: syncx.NewAtomicBool(), + } + // cpu >= 800, inflight < maxPass + systemOverloadChecker = func(int64) bool { + return true + } + shedder.avgFlying = 50 + assert.False(t, shedder.shouldDrop()) + + // cpu >= 800, inflight > maxPass + shedder.avgFlying = 80 + shedder.flying = 50 + assert.False(t, shedder.shouldDrop()) + + // cpu >= 800, inflight > maxPass + shedder.avgFlying = 80 + shedder.flying = 80 + assert.True(t, shedder.shouldDrop()) + + // cpu < 800, inflight > maxPass + systemOverloadChecker = func(int64) bool { + return false + } + shedder.avgFlying = 80 + assert.False(t, shedder.shouldDrop()) +} + +func BenchmarkAdaptiveShedder_Allow(b *testing.B) { + logx.Disable() + + bench := func(b *testing.B) { + var shedder = NewAdaptiveShedder() + proba := mathx.NewProba() + for i := 0; i < 6000; i++ { + p, err := shedder.Allow() + if err == nil { + time.Sleep(time.Millisecond) + if proba.TrueOnProba(0.01) { + p.Fail() + } else { + p.Pass() + } + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + p, err := shedder.Allow() + if err == nil { + p.Pass() + } + } + } + + systemOverloadChecker = func(int64) bool { + return true + } + b.Run("high load", bench) + systemOverloadChecker = func(int64) bool { + return false + } + b.Run("low load", bench) +} + +func newRollingWindow() *collection.RollingWindow { + return collection.NewRollingWindow(buckets, bucketDuration, collection.IgnoreCurrentBucket()) +} diff --git a/core/load/nopshedder.go b/core/load/nopshedder.go new file mode 100644 index 00000000..9de5e51c --- /dev/null +++ b/core/load/nopshedder.go @@ -0,0 +1,21 @@ +package load + +type nopShedder struct { +} + +func newNopShedder() Shedder { + return nopShedder{} +} + +func (s nopShedder) Allow() (Promise, error) { + return nopPromise{}, nil +} + +type nopPromise struct { +} + +func (p nopPromise) Pass() { +} + +func (p nopPromise) Fail() { +} diff --git a/core/load/nopshedder_test.go b/core/load/nopshedder_test.go new file mode 100644 index 00000000..a731294f --- /dev/null +++ b/core/load/nopshedder_test.go @@ -0,0 +1,21 @@ +package load + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNopShedder(t *testing.T) { + Disable() + shedder := NewAdaptiveShedder() + for i := 0; i < 1000; i++ { + p, err := shedder.Allow() + assert.Nil(t, err) + p.Fail() + } + + p, err := shedder.Allow() + assert.Nil(t, err) + p.Pass() +} diff 
--git a/core/load/sheddergroup.go b/core/load/sheddergroup.go new file mode 100644 index 00000000..a8b76108 --- /dev/null +++ b/core/load/sheddergroup.go @@ -0,0 +1,36 @@ +package load + +import ( + "io" + + "zero/core/syncx" +) + +type ShedderGroup struct { + options []ShedderOption + manager *syncx.ResourceManager +} + +func NewShedderGroup(opts ...ShedderOption) *ShedderGroup { + return &ShedderGroup{ + options: opts, + manager: syncx.NewResourceManager(), + } +} + +func (g *ShedderGroup) GetShedder(key string) Shedder { + shedder, _ := g.manager.GetResource(key, func() (closer io.Closer, e error) { + return nopCloser{ + Shedder: NewAdaptiveShedder(g.options...), + }, nil + }) + return shedder.(Shedder) +} + +type nopCloser struct { + Shedder +} + +func (c nopCloser) Close() error { + return nil +} diff --git a/core/load/sheddergroup_test.go b/core/load/sheddergroup_test.go new file mode 100644 index 00000000..f58f1588 --- /dev/null +++ b/core/load/sheddergroup_test.go @@ -0,0 +1,15 @@ +package load + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGroup(t *testing.T) { + group := NewShedderGroup() + t.Run("get", func(t *testing.T) { + limiter := group.GetShedder("test") + assert.NotNil(t, limiter) + }) +} diff --git a/core/load/sheddingstat.go b/core/load/sheddingstat.go new file mode 100644 index 00000000..f61e1248 --- /dev/null +++ b/core/load/sheddingstat.go @@ -0,0 +1,68 @@ +package load + +import ( + "sync/atomic" + "time" + + "zero/core/logx" + "zero/core/stat" +) + +type ( + SheddingStat struct { + name string + total int64 + pass int64 + drop int64 + } + + snapshot struct { + Total int64 + Pass int64 + Drop int64 + } +) + +func NewSheddingStat(name string) *SheddingStat { + st := &SheddingStat{ + name: name, + } + go st.run() + return st +} + +func (s *SheddingStat) IncrementTotal() { + atomic.AddInt64(&s.total, 1) +} + +func (s *SheddingStat) IncrementPass() { + atomic.AddInt64(&s.pass, 1) +} + +func (s *SheddingStat) IncrementDrop() { + atomic.AddInt64(&s.drop, 1) +} + +func (s *SheddingStat) reset() snapshot { + return snapshot{ + Total: atomic.SwapInt64(&s.total, 0), + Pass: atomic.SwapInt64(&s.pass, 0), + Drop: atomic.SwapInt64(&s.drop, 0), + } +} + +func (s *SheddingStat) run() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + for range ticker.C { + c := stat.CpuUsage() + st := s.reset() + if st.Drop == 0 { + logx.Statf("(%s) shedding_stat [1m], cpu: %d, total: %d, pass: %d, drop: %d", + s.name, c, st.Total, st.Pass, st.Drop) + } else { + logx.Statf("(%s) shedding_stat_drop [1m], cpu: %d, total: %d, pass: %d, drop: %d", + s.name, c, st.Total, st.Pass, st.Drop) + } + } +} diff --git a/core/load/sheddingstat_test.go b/core/load/sheddingstat_test.go new file mode 100644 index 00000000..351eefc0 --- /dev/null +++ b/core/load/sheddingstat_test.go @@ -0,0 +1,24 @@ +package load + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSheddingStat(t *testing.T) { + st := NewSheddingStat("any") + for i := 0; i < 3; i++ { + st.IncrementTotal() + } + for i := 0; i < 5; i++ { + st.IncrementPass() + } + for i := 0; i < 7; i++ { + st.IncrementDrop() + } + result := st.reset() + assert.Equal(t, int64(3), result.Total) + assert.Equal(t, int64(5), result.Pass) + assert.Equal(t, int64(7), result.Drop) +} diff --git a/core/logx/config.go b/core/logx/config.go new file mode 100644 index 00000000..87d02684 --- /dev/null +++ b/core/logx/config.go @@ -0,0 +1,11 @@ +package logx + +type LogConf struct { + ServiceName string 
`json:",optional"` + Mode string `json:",default=console,options=console|file|volume"` + Path string `json:",default=logs"` + Level string `json:",default=info,options=info|error|severe"` + Compress bool `json:",optional"` + KeepDays int `json:",optional"` + StackCooldownMillis int `json:",default=100"` +} diff --git a/core/logx/customlogger.go b/core/logx/customlogger.go new file mode 100644 index 00000000..ad043b53 --- /dev/null +++ b/core/logx/customlogger.go @@ -0,0 +1,62 @@ +package logx + +import ( + "fmt" + "io" + "time" + + "zero/core/timex" +) + +const customCallerDepth = 3 + +type customLog logEntry + +func WithDuration(d time.Duration) Logger { + return customLog{ + Duration: timex.ReprOfDuration(d), + } +} + +func (l customLog) Error(v ...interface{}) { + if shouldLog(ErrorLevel) { + l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), customCallerDepth)) + } +} + +func (l customLog) Errorf(format string, v ...interface{}) { + if shouldLog(ErrorLevel) { + l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), customCallerDepth)) + } +} + +func (l customLog) Info(v ...interface{}) { + if shouldLog(InfoLevel) { + l.write(infoLog, levelInfo, fmt.Sprint(v...)) + } +} + +func (l customLog) Infof(format string, v ...interface{}) { + if shouldLog(InfoLevel) { + l.write(infoLog, levelInfo, fmt.Sprintf(format, v...)) + } +} + +func (l customLog) Slow(v ...interface{}) { + if shouldLog(ErrorLevel) { + l.write(slowLog, levelSlow, fmt.Sprint(v...)) + } +} + +func (l customLog) Slowf(format string, v ...interface{}) { + if shouldLog(ErrorLevel) { + l.write(slowLog, levelSlow, fmt.Sprintf(format, v...)) + } +} + +func (l customLog) write(writer io.Writer, level, content string) { + l.Timestamp = getTimestamp() + l.Level = level + l.Content = content + outputJson(writer, logEntry(l)) +} diff --git a/core/logx/lesslogger.go b/core/logx/lesslogger.go new file mode 100644 index 00000000..15c878a6 --- /dev/null +++ b/core/logx/lesslogger.go @@ -0,0 +1,23 @@ +package logx + +type LessLogger struct { + *limitedExecutor +} + +func NewLessLogger(milliseconds int) *LessLogger { + return &LessLogger{ + limitedExecutor: newLimitedExecutor(milliseconds), + } +} + +func (logger *LessLogger) Error(v ...interface{}) { + logger.logOrDiscard(func() { + Error(v...) + }) +} + +func (logger *LessLogger) Errorf(format string, v ...interface{}) { + logger.logOrDiscard(func() { + Errorf(format, v...) 
+ }) +} diff --git a/core/logx/lesslogger_test.go b/core/logx/lesslogger_test.go new file mode 100644 index 00000000..32f25fdb --- /dev/null +++ b/core/logx/lesslogger_test.go @@ -0,0 +1,31 @@ +package logx + +import ( + "log" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLessLogger_Error(t *testing.T) { + var builder strings.Builder + log.SetOutput(&builder) + l := NewLessLogger(500) + for i := 0; i < 100; i++ { + l.Error("hello") + } + + assert.Equal(t, 1, strings.Count(builder.String(), "\n")) +} + +func TestLessLogger_Errorf(t *testing.T) { + var builder strings.Builder + log.SetOutput(&builder) + l := NewLessLogger(500) + for i := 0; i < 100; i++ { + l.Errorf("hello") + } + + assert.Equal(t, 1, strings.Count(builder.String(), "\n")) +} diff --git a/core/logx/lesswriter.go b/core/logx/lesswriter.go new file mode 100644 index 00000000..4fec9abd --- /dev/null +++ b/core/logx/lesswriter.go @@ -0,0 +1,22 @@ +package logx + +import "io" + +type lessWriter struct { + *limitedExecutor + writer io.Writer +} + +func NewLessWriter(writer io.Writer, milliseconds int) *lessWriter { + return &lessWriter{ + limitedExecutor: newLimitedExecutor(milliseconds), + writer: writer, + } +} + +func (w *lessWriter) Write(p []byte) (n int, err error) { + w.logOrDiscard(func() { + w.writer.Write(p) + }) + return len(p), nil +} diff --git a/core/logx/lesswriter_test.go b/core/logx/lesswriter_test.go new file mode 100644 index 00000000..2b72d011 --- /dev/null +++ b/core/logx/lesswriter_test.go @@ -0,0 +1,19 @@ +package logx + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLessWriter(t *testing.T) { + var builder strings.Builder + w := NewLessWriter(&builder, 500) + for i := 0; i < 100; i++ { + _, err := w.Write([]byte("hello")) + assert.Nil(t, err) + } + + assert.Equal(t, "hello", builder.String()) +} diff --git a/core/logx/limitedexecutor.go b/core/logx/limitedexecutor.go new file mode 100644 index 00000000..df128fb5 --- /dev/null +++ b/core/logx/limitedexecutor.go @@ -0,0 +1,42 @@ +package logx + +import ( + "sync/atomic" + "time" + + "zero/core/syncx" + "zero/core/timex" +) + +type limitedExecutor struct { + threshold time.Duration + lastTime *syncx.AtomicDuration + discarded uint32 +} + +func newLimitedExecutor(milliseconds int) *limitedExecutor { + return &limitedExecutor{ + threshold: time.Duration(milliseconds) * time.Millisecond, + lastTime: syncx.NewAtomicDuration(), + } +} + +func (le *limitedExecutor) logOrDiscard(execute func()) { + if le == nil || le.threshold <= 0 { + execute() + return + } + + now := timex.Now() + if now-le.lastTime.Load() <= le.threshold { + atomic.AddUint32(&le.discarded, 1) + } else { + le.lastTime.Set(now) + discarded := atomic.SwapUint32(&le.discarded, 0) + if discarded > 0 { + Errorf("Discarded %d error messages", discarded) + } + + execute() + } +} diff --git a/core/logx/logs.go b/core/logx/logs.go new file mode 100644 index 00000000..4f038a16 --- /dev/null +++ b/core/logx/logs.go @@ -0,0 +1,481 @@ +package logx + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "os" + "path" + "runtime" + "runtime/debug" + "strconv" + "strings" + "sync" + "sync/atomic" + + "zero/core/iox" + "zero/core/lang" + "zero/core/sysx" + "zero/core/timex" +) + +const ( + // InfoLevel logs everything + InfoLevel = iota + // ErrorLevel includes errors, slows, stacks + ErrorLevel + // SevereLevel only log severe messages + SevereLevel +) + +const ( + timeFormat = "2006-01-02T15:04:05.000Z07" + + 
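+ // In file and volume mode each stream gets its own file under Path, named with
+ // the constants below: info -> access.log, error and stack -> error.log,
+ // severe -> severe.log, slow -> slow.log, stat -> stat.log; in console mode the
+ // info and stat streams go to stdout and the rest to stderr.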
accessFilename = "access.log" + errorFilename = "error.log" + severeFilename = "severe.log" + slowFilename = "slow.log" + statFilename = "stat.log" + + consoleMode = "console" + volumeMode = "volume" + + levelInfo = "info" + levelError = "error" + levelSevere = "severe" + levelSlow = "slow" + levelStat = "stat" + + backupFileDelimiter = "-" + callerInnerDepth = 5 + flags = 0x0 +) + +var ( + ErrLogPathNotSet = errors.New("log path must be set") + ErrLogNotInitialized = errors.New("log not initialized") + ErrLogServiceNameNotSet = errors.New("log service name must be set") + + writeConsole bool + logLevel uint32 + infoLog io.WriteCloser + errorLog io.WriteCloser + severeLog io.WriteCloser + slowLog io.WriteCloser + statLog io.WriteCloser + stackLog io.Writer + + once sync.Once + initialized uint32 + options logOptions +) + +type ( + logEntry struct { + Timestamp string `json:"@timestamp"` + Level string `json:"level"` + Duration string `json:"duration,omitempty"` + Content string `json:"content"` + } + + logOptions struct { + gzipEnabled bool + logStackCooldownMills int + keepDays int + } + + LogOption func(options *logOptions) + + Logger interface { + Error(...interface{}) + Errorf(string, ...interface{}) + Info(...interface{}) + Infof(string, ...interface{}) + Slow(...interface{}) + Slowf(string, ...interface{}) + } +) + +func MustSetup(c LogConf) { + lang.Must(SetUp(c)) +} + +// SetUp sets up the logx. If already set up, just return nil. +// we allow SetUp to be called multiple times, because for example +// we need to allow different service frameworks to initialize logx respectively. +// the same logic for SetUp +func SetUp(c LogConf) error { + switch c.Mode { + case consoleMode: + setupWithConsole(c) + return nil + case volumeMode: + return setupWithVolume(c) + default: + return setupWithFiles(c) + } +} + +func Close() error { + if writeConsole { + return nil + } + + if atomic.LoadUint32(&initialized) == 0 { + return ErrLogNotInitialized + } + + atomic.StoreUint32(&initialized, 0) + + if infoLog != nil { + if err := infoLog.Close(); err != nil { + return err + } + } + + if errorLog != nil { + if err := errorLog.Close(); err != nil { + return err + } + } + + if severeLog != nil { + if err := severeLog.Close(); err != nil { + return err + } + } + + if slowLog != nil { + if err := slowLog.Close(); err != nil { + return err + } + } + + if statLog != nil { + if err := statLog.Close(); err != nil { + return err + } + } + + return nil +} + +func Disable() { + once.Do(func() { + atomic.StoreUint32(&initialized, 1) + + infoLog = iox.NopCloser(ioutil.Discard) + errorLog = iox.NopCloser(ioutil.Discard) + severeLog = iox.NopCloser(ioutil.Discard) + slowLog = iox.NopCloser(ioutil.Discard) + statLog = iox.NopCloser(ioutil.Discard) + stackLog = ioutil.Discard + }) +} + +func Error(v ...interface{}) { + ErrorCaller(1, v...) +} + +func Errorf(format string, v ...interface{}) { + ErrorCallerf(1, format, v...) 
+} + +func ErrorCaller(callDepth int, v ...interface{}) { + errorSync(fmt.Sprint(v...), callDepth+callerInnerDepth) +} + +func ErrorCallerf(callDepth int, format string, v ...interface{}) { + errorSync(fmt.Sprintf(format, v...), callDepth+callerInnerDepth) +} + +func ErrorStack(v ...interface{}) { + // there is newline in stack string + stackSync(fmt.Sprint(v...)) +} + +func ErrorStackf(format string, v ...interface{}) { + // there is newline in stack string + stackSync(fmt.Sprintf(format, v...)) +} + +func Info(v ...interface{}) { + infoSync(fmt.Sprint(v...)) +} + +func Infof(format string, v ...interface{}) { + infoSync(fmt.Sprintf(format, v...)) +} + +func SetLevel(level uint32) { + atomic.StoreUint32(&logLevel, level) +} + +func Severe(v ...interface{}) { + severeSync(fmt.Sprint(v...)) +} + +func Severef(format string, v ...interface{}) { + severeSync(fmt.Sprintf(format, v...)) +} + +func Slow(v ...interface{}) { + slowSync(fmt.Sprint(v...)) +} + +func Slowf(format string, v ...interface{}) { + slowSync(fmt.Sprintf(format, v...)) +} + +func Stat(v ...interface{}) { + statSync(fmt.Sprint(v...)) +} + +func Statf(format string, v ...interface{}) { + statSync(fmt.Sprintf(format, v...)) +} + +func WithCooldownMillis(millis int) LogOption { + return func(opts *logOptions) { + opts.logStackCooldownMills = millis + } +} + +func WithKeepDays(days int) LogOption { + return func(opts *logOptions) { + opts.keepDays = days + } +} + +func WithGzip() LogOption { + return func(opts *logOptions) { + opts.gzipEnabled = true + } +} + +func createOutput(path string) (io.WriteCloser, error) { + if len(path) == 0 { + return nil, ErrLogPathNotSet + } + + return NewLogger(path, DefaultRotateRule(path, backupFileDelimiter, options.keepDays, + options.gzipEnabled), options.gzipEnabled) +} + +func errorSync(msg string, callDepth int) { + if shouldLog(ErrorLevel) { + outputError(errorLog, msg, callDepth) + } +} + +func formatWithCaller(msg string, callDepth int) string { + var buf strings.Builder + + caller := getCaller(callDepth) + if len(caller) > 0 { + buf.WriteString(caller) + buf.WriteByte(' ') + } + + buf.WriteString(msg) + + return buf.String() +} + +func getCaller(callDepth int) string { + var buf strings.Builder + + _, file, line, ok := runtime.Caller(callDepth) + if ok { + short := file + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + short = file[i+1:] + break + } + } + buf.WriteString(short) + buf.WriteByte(':') + buf.WriteString(strconv.Itoa(line)) + } + + return buf.String() +} + +func getTimestamp() string { + return timex.Time().Format(timeFormat) +} + +func handleOptions(opts []LogOption) { + for _, opt := range opts { + opt(&options) + } +} + +func infoSync(msg string) { + if shouldLog(InfoLevel) { + output(infoLog, levelInfo, msg) + } +} + +func output(writer io.Writer, level, msg string) { + info := logEntry{ + Timestamp: getTimestamp(), + Level: level, + Content: msg, + } + outputJson(writer, info) +} + +func outputError(writer io.Writer, msg string, callDepth int) { + content := formatWithCaller(msg, callDepth) + output(writer, levelError, content) +} + +func outputJson(writer io.Writer, info interface{}) { + if content, err := json.Marshal(info); err != nil { + log.Println(err.Error()) + } else if atomic.LoadUint32(&initialized) == 0 || writer == nil { + log.Println(string(content)) + } else { + writer.Write(append(content, '\n')) + } +} + +func setupLogLevel(c LogConf) { + switch c.Level { + case levelInfo: + SetLevel(InfoLevel) + case levelError: + SetLevel(ErrorLevel) + 
case levelSevere: + SetLevel(SevereLevel) + } +} + +func setupWithConsole(c LogConf) { + once.Do(func() { + atomic.StoreUint32(&initialized, 1) + writeConsole = true + setupLogLevel(c) + + infoLog = newLogWriter(log.New(os.Stdout, "", flags)) + errorLog = newLogWriter(log.New(os.Stderr, "", flags)) + severeLog = newLogWriter(log.New(os.Stderr, "", flags)) + slowLog = newLogWriter(log.New(os.Stderr, "", flags)) + stackLog = NewLessWriter(errorLog, options.logStackCooldownMills) + statLog = infoLog + }) +} + +func setupWithFiles(c LogConf) error { + var opts []LogOption + var err error + + if len(c.Path) == 0 { + return ErrLogPathNotSet + } + + opts = append(opts, WithCooldownMillis(c.StackCooldownMillis)) + if c.Compress { + opts = append(opts, WithGzip()) + } + if c.KeepDays > 0 { + opts = append(opts, WithKeepDays(c.KeepDays)) + } + + accessFile := path.Join(c.Path, accessFilename) + errorFile := path.Join(c.Path, errorFilename) + severeFile := path.Join(c.Path, severeFilename) + slowFile := path.Join(c.Path, slowFilename) + statFile := path.Join(c.Path, statFilename) + + once.Do(func() { + atomic.StoreUint32(&initialized, 1) + handleOptions(opts) + setupLogLevel(c) + + if infoLog, err = createOutput(accessFile); err != nil { + return + } + + if errorLog, err = createOutput(errorFile); err != nil { + return + } + + if severeLog, err = createOutput(severeFile); err != nil { + return + } + + if slowLog, err = createOutput(slowFile); err != nil { + return + } + + if statLog, err = createOutput(statFile); err != nil { + return + } + + stackLog = NewLessWriter(errorLog, options.logStackCooldownMills) + }) + + return err +} + +func setupWithVolume(c LogConf) error { + if len(c.ServiceName) == 0 { + return ErrLogServiceNameNotSet + } + + c.Path = path.Join(c.Path, c.ServiceName, sysx.Hostname()) + return setupWithFiles(c) +} + +func severeSync(msg string) { + if shouldLog(SevereLevel) { + output(severeLog, levelSevere, fmt.Sprintf("%s\n%s", msg, string(debug.Stack()))) + } +} + +func shouldLog(level uint32) bool { + return atomic.LoadUint32(&logLevel) <= level +} + +func slowSync(msg string) { + if shouldLog(ErrorLevel) { + output(slowLog, levelSlow, msg) + } +} + +func stackSync(msg string) { + if shouldLog(ErrorLevel) { + output(stackLog, levelError, fmt.Sprintf("%s\n%s", msg, string(debug.Stack()))) + } +} + +func statSync(msg string) { + if shouldLog(InfoLevel) { + output(statLog, levelStat, msg) + } +} + +type logWriter struct { + logger *log.Logger +} + +func newLogWriter(logger *log.Logger) logWriter { + return logWriter{ + logger: logger, + } +} + +func (lw logWriter) Close() error { + return nil +} + +func (lw logWriter) Write(data []byte) (int, error) { + lw.logger.Print(string(data)) + return len(data), nil +} diff --git a/core/logx/logs_test.go b/core/logx/logs_test.go new file mode 100644 index 00000000..cb3e6d8a --- /dev/null +++ b/core/logx/logs_test.go @@ -0,0 +1,251 @@ +package logx + +import ( + "encoding/json" + "fmt" + "io" + "io/ioutil" + "log" + "runtime" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var ( + s = []byte("Sending #11 notification (id: 1451875113812010473) in #1 connection") + pool = make(chan []byte, 1) +) + +type mockWriter struct { + builder strings.Builder +} + +func (mw *mockWriter) Write(data []byte) (int, error) { + return mw.builder.Write(data) +} + +func (mw *mockWriter) Close() error { + return nil +} + +func (mw *mockWriter) Reset() { + mw.builder.Reset() +} + +func (mw *mockWriter) Contains(text 
string) bool { + return strings.Index(mw.builder.String(), text) > -1 +} + +func TestFileLineFileMode(t *testing.T) { + writer := new(mockWriter) + errorLog = writer + atomic.StoreUint32(&initialized, 1) + file, line := getFileLine() + Error("anything") + assert.True(t, writer.Contains(fmt.Sprintf("%s:%d", file, line+1))) + + writer.Reset() + file, line = getFileLine() + Errorf("anything %s", "format") + assert.True(t, writer.Contains(fmt.Sprintf("%s:%d", file, line+1))) +} + +func TestFileLineConsoleMode(t *testing.T) { + writer := new(mockWriter) + writeConsole = true + errorLog = newLogWriter(log.New(writer, "[ERROR] ", flags)) + atomic.StoreUint32(&initialized, 1) + file, line := getFileLine() + Error("anything") + assert.True(t, writer.Contains(fmt.Sprintf("%s:%d", file, line+1))) + + writer.Reset() + file, line = getFileLine() + Errorf("anything %s", "format") + assert.True(t, writer.Contains(fmt.Sprintf("%s:%d", file, line+1))) +} + +func TestStructedLogInfo(t *testing.T) { + doTestStructedLog(t, levelInfo, func(writer io.WriteCloser) { + infoLog = writer + }, func(v ...interface{}) { + Info(v...) + }) +} + +func TestStructedLogSlow(t *testing.T) { + doTestStructedLog(t, levelSlow, func(writer io.WriteCloser) { + slowLog = writer + }, func(v ...interface{}) { + Slow(v...) + }) +} + +func TestStructedLogWithDuration(t *testing.T) { + const message = "hello there" + writer := new(mockWriter) + infoLog = writer + atomic.StoreUint32(&initialized, 1) + WithDuration(time.Second).Info(message) + var entry logEntry + if err := json.Unmarshal([]byte(writer.builder.String()), &entry); err != nil { + t.Error(err) + } + assert.Equal(t, levelInfo, entry.Level) + assert.Equal(t, message, entry.Content) + assert.Equal(t, "1000.0ms", entry.Duration) +} + +func TestSetLevel(t *testing.T) { + SetLevel(ErrorLevel) + const message = "hello there" + writer := new(mockWriter) + infoLog = writer + atomic.StoreUint32(&initialized, 1) + Info(message) + assert.Equal(t, 0, writer.builder.Len()) +} + +func TestSetLevelTwiceWithMode(t *testing.T) { + testModes := []string{ + "mode", + "console", + "volumn", + } + for _, mode := range testModes { + testSetLevelTwiceWithMode(t, mode) + } +} + +func TestSetLevelWithDuration(t *testing.T) { + SetLevel(ErrorLevel) + const message = "hello there" + writer := new(mockWriter) + infoLog = writer + atomic.StoreUint32(&initialized, 1) + WithDuration(time.Second).Info(message) + assert.Equal(t, 0, writer.builder.Len()) +} + +func BenchmarkCopyByteSliceAppend(b *testing.B) { + for i := 0; i < b.N; i++ { + var buf []byte + buf = append(buf, getTimestamp()...) + buf = append(buf, ' ') + buf = append(buf, s...) 
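+		// the assembled buffer is deliberately discarded below; only the cost of the appends is measured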
+ _ = buf + } +} + +func BenchmarkCopyByteSliceAllocExactly(b *testing.B) { + for i := 0; i < b.N; i++ { + now := []byte(getTimestamp()) + buf := make([]byte, len(now)+1+len(s)) + n := copy(buf, now) + buf[n] = ' ' + copy(buf[n+1:], s) + } +} + +func BenchmarkCopyByteSlice(b *testing.B) { + var buf []byte + for i := 0; i < b.N; i++ { + buf = make([]byte, len(s)) + copy(buf, s) + } + fmt.Fprint(ioutil.Discard, buf) +} + +func BenchmarkCopyOnWriteByteSlice(b *testing.B) { + var buf []byte + for i := 0; i < b.N; i++ { + size := len(s) + buf = s[:size:size] + } + fmt.Fprint(ioutil.Discard, buf) +} + +func BenchmarkCacheByteSlice(b *testing.B) { + for i := 0; i < b.N; i++ { + dup := fetch() + copy(dup, s) + put(dup) + } +} + +func BenchmarkLogs(b *testing.B) { + b.ReportAllocs() + + log.SetOutput(ioutil.Discard) + for i := 0; i < b.N; i++ { + Info(i) + } +} + +func fetch() []byte { + select { + case b := <-pool: + return b + default: + } + return make([]byte, 4096) +} + +func getFileLine() (string, int) { + _, file, line, _ := runtime.Caller(1) + short := file + + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + short = file[i+1:] + break + } + } + + return short, line +} + +func put(b []byte) { + select { + case pool <- b: + default: + } +} + +func doTestStructedLog(t *testing.T, level string, setup func(writer io.WriteCloser), + write func(...interface{})) { + const message = "hello there" + writer := new(mockWriter) + setup(writer) + atomic.StoreUint32(&initialized, 1) + write(message) + var entry logEntry + if err := json.Unmarshal([]byte(writer.builder.String()), &entry); err != nil { + t.Error(err) + } + assert.Equal(t, level, entry.Level) + assert.Equal(t, message, entry.Content) +} + +func testSetLevelTwiceWithMode(t *testing.T, mode string) { + SetUp(LogConf{ + Mode: mode, + Level: "error", + Path: "/dev/null", + }) + SetUp(LogConf{ + Mode: mode, + Level: "info", + Path: "/dev/null", + }) + const message = "hello there" + writer := new(mockWriter) + infoLog = writer + atomic.StoreUint32(&initialized, 1) + Info(message) + assert.Equal(t, 0, writer.builder.Len()) +} diff --git a/core/logx/rotatelogger.go b/core/logx/rotatelogger.go new file mode 100644 index 00000000..ccf3f7b9 --- /dev/null +++ b/core/logx/rotatelogger.go @@ -0,0 +1,315 @@ +package logx + +import ( + "compress/gzip" + "errors" + "fmt" + "io" + "log" + "os" + "path" + "path/filepath" + "strings" + "sync" + "time" + + "zero/core/fs" + "zero/core/lang" + "zero/core/timex" +) + +const ( + dateFormat = "2006-01-02" + hoursPerDay = 24 + bufferSize = 100 + defaultDirMode = 0755 + defaultFileMode = 0600 +) + +var ErrLogFileClosed = errors.New("error: log file closed") + +type ( + RotateRule interface { + BackupFileName() string + MarkRotated() + OutdatedFiles() []string + ShallRotate() bool + } + + RotateLogger struct { + filename string + backup string + fp *os.File + channel chan []byte + done chan lang.PlaceholderType + rule RotateRule + compress bool + keepDays int + // can't use threading.RoutineGroup because of cycle import + waitGroup sync.WaitGroup + closeOnce sync.Once + } + + DailyRotateRule struct { + rotatedTime string + filename string + delimiter string + days int + gzip bool + } +) + +func DefaultRotateRule(filename, delimiter string, days int, gzip bool) RotateRule { + return &DailyRotateRule{ + rotatedTime: getNowDate(), + filename: filename, + delimiter: delimiter, + days: days, + gzip: gzip, + } +} + +func (r *DailyRotateRule) BackupFileName() string { + return fmt.Sprintf("%s%s%s", r.filename, 
r.delimiter, getNowDate()) +} + +func (r *DailyRotateRule) MarkRotated() { + r.rotatedTime = getNowDate() +} + +func (r *DailyRotateRule) OutdatedFiles() []string { + if r.days <= 0 { + return nil + } + + var pattern string + if r.gzip { + pattern = fmt.Sprintf("%s%s*.gz", r.filename, r.delimiter) + } else { + pattern = fmt.Sprintf("%s%s*", r.filename, r.delimiter) + } + + files, err := filepath.Glob(pattern) + if err != nil { + Errorf("failed to delete outdated log files, error: %s", err) + return nil + } + + var buf strings.Builder + boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(dateFormat) + fmt.Fprintf(&buf, "%s%s%s", r.filename, r.delimiter, boundary) + if r.gzip { + buf.WriteString(".gz") + } + boundaryFile := buf.String() + + var outdates []string + for _, file := range files { + if file < boundaryFile { + outdates = append(outdates, file) + } + } + + return outdates +} + +func (r *DailyRotateRule) ShallRotate() bool { + return len(r.rotatedTime) > 0 && getNowDate() != r.rotatedTime +} + +func NewLogger(filename string, rule RotateRule, compress bool) (*RotateLogger, error) { + l := &RotateLogger{ + filename: filename, + channel: make(chan []byte, bufferSize), + done: make(chan lang.PlaceholderType), + rule: rule, + compress: compress, + } + if err := l.init(); err != nil { + return nil, err + } + + l.startWorker() + return l, nil +} + +func (l *RotateLogger) Close() error { + var err error + + l.closeOnce.Do(func() { + close(l.done) + l.waitGroup.Wait() + + if err = l.fp.Sync(); err != nil { + return + } + + err = l.fp.Close() + }) + + return err +} + +func (l *RotateLogger) Write(data []byte) (int, error) { + select { + case l.channel <- data: + return len(data), nil + case <-l.done: + log.Println(string(data)) + return 0, ErrLogFileClosed + } +} + +func (l *RotateLogger) getBackupFilename() string { + if len(l.backup) == 0 { + return l.rule.BackupFileName() + } else { + return l.backup + } +} + +func (l *RotateLogger) init() error { + l.backup = l.rule.BackupFileName() + + if _, err := os.Stat(l.filename); err != nil { + basePath := path.Dir(l.filename) + if _, err = os.Stat(basePath); err != nil { + if err = os.MkdirAll(basePath, defaultDirMode); err != nil { + return err + } + } + + if l.fp, err = os.Create(l.filename); err != nil { + return err + } + } else if l.fp, err = os.OpenFile(l.filename, os.O_APPEND|os.O_WRONLY, defaultFileMode); err != nil { + return err + } + + fs.CloseOnExec(l.fp) + + return nil +} + +func (l *RotateLogger) maybeCompressFile(file string) { + if l.compress { + defer func() { + if r := recover(); r != nil { + ErrorStack(r) + } + }() + compressLogFile(file) + } +} + +func (l *RotateLogger) maybeDeleteOutdatedFiles() { + files := l.rule.OutdatedFiles() + for _, file := range files { + if err := os.Remove(file); err != nil { + Errorf("failed to remove outdated file: %s", file) + } + } +} + +func (l *RotateLogger) postRotate(file string) { + go func() { + // we cannot use threading.GoSafe here, because of import cycle. 
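+		// compression and removal of outdated files happen off the write path, so rotation stays cheap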
+ l.maybeCompressFile(file) + l.maybeDeleteOutdatedFiles() + }() +} + +func (l *RotateLogger) rotate() error { + if l.fp != nil { + err := l.fp.Close() + l.fp = nil + if err != nil { + return err + } + } + + _, err := os.Stat(l.filename) + if err == nil && len(l.backup) > 0 { + backupFilename := l.getBackupFilename() + err = os.Rename(l.filename, backupFilename) + if err != nil { + return err + } + + l.postRotate(backupFilename) + } + + l.backup = l.rule.BackupFileName() + if l.fp, err = os.Create(l.filename); err == nil { + fs.CloseOnExec(l.fp) + } + + return err +} + +func (l *RotateLogger) startWorker() { + l.waitGroup.Add(1) + + go func() { + defer l.waitGroup.Done() + + for { + select { + case event := <-l.channel: + l.write(event) + case <-l.done: + return + } + } + }() +} + +func (l *RotateLogger) write(v []byte) { + if l.rule.ShallRotate() { + if err := l.rotate(); err != nil { + log.Println(err) + } else { + l.rule.MarkRotated() + } + } + if l.fp != nil { + l.fp.Write(v) + } +} + +func compressLogFile(file string) { + start := timex.Now() + Infof("compressing log file: %s", file) + if err := gzipFile(file); err != nil { + Errorf("compress error: %s", err) + } else { + Infof("compressed log file: %s, took %s", file, timex.Since(start)) + } +} + +func getNowDate() string { + return time.Now().Format(dateFormat) +} + +func gzipFile(file string) error { + in, err := os.Open(file) + if err != nil { + return err + } + defer in.Close() + + out, err := os.Create(fmt.Sprintf("%s.gz", file)) + if err != nil { + return err + } + defer out.Close() + + w := gzip.NewWriter(out) + if _, err = io.Copy(w, in); err != nil { + return err + } else if err = w.Close(); err != nil { + return err + } + + return os.Remove(file) +} diff --git a/core/logx/syslog.go b/core/logx/syslog.go new file mode 100644 index 00000000..42ca70ff --- /dev/null +++ b/core/logx/syslog.go @@ -0,0 +1,15 @@ +package logx + +import "log" + +type redirector struct{} + +// CollectSysLog redirects system log into logx info +func CollectSysLog() { + log.SetOutput(new(redirector)) +} + +func (r *redirector) Write(p []byte) (n int, err error) { + Info(string(p)) + return len(p), nil +} diff --git a/core/logx/syslog_test.go b/core/logx/syslog_test.go new file mode 100644 index 00000000..5fbd0c65 --- /dev/null +++ b/core/logx/syslog_test.go @@ -0,0 +1,48 @@ +package logx + +import ( + "encoding/json" + "log" + "strings" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +const testlog = "Stay hungry, stay foolish." 
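+
+// A minimal usage sketch for reference (not part of the tests below): once
+// CollectSysLog has swapped the standard logger's output for the redirector,
+// anything written through the "log" package is recorded as a logx info entry:
+//
+//	CollectSysLog()
+//	log.Print("captured as a logx info entry")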
+ +func TestCollectSysLog(t *testing.T) { + CollectSysLog() + content := getContent(captureOutput(func() { + log.Printf(testlog) + })) + assert.True(t, strings.Contains(content, testlog)) +} + +func TestRedirector(t *testing.T) { + var r redirector + content := getContent(captureOutput(func() { + r.Write([]byte(testlog)) + })) + assert.Equal(t, testlog, content) +} + +func captureOutput(f func()) string { + atomic.StoreUint32(&initialized, 1) + writer := new(mockWriter) + infoLog = writer + + prevLevel := logLevel + logLevel = InfoLevel + f() + logLevel = prevLevel + + return writer.builder.String() +} + +func getContent(jsonStr string) string { + var entry logEntry + json.Unmarshal([]byte(jsonStr), &entry) + return entry.Content +} diff --git a/core/logx/tracelog.go b/core/logx/tracelog.go new file mode 100644 index 00000000..268141b7 --- /dev/null +++ b/core/logx/tracelog.go @@ -0,0 +1,85 @@ +package logx + +import ( + "context" + "fmt" + "io" + + "zero/core/trace/tracespec" +) + +type tracingEntry struct { + logEntry + Trace string `json:"trace,omitempty"` + Span string `json:"span,omitempty"` + ctx context.Context `json:"-"` +} + +func (l tracingEntry) Error(v ...interface{}) { + if shouldLog(ErrorLevel) { + l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), customCallerDepth)) + } +} + +func (l tracingEntry) Errorf(format string, v ...interface{}) { + if shouldLog(ErrorLevel) { + l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), customCallerDepth)) + } +} + +func (l tracingEntry) Info(v ...interface{}) { + if shouldLog(InfoLevel) { + l.write(infoLog, levelInfo, fmt.Sprint(v...)) + } +} + +func (l tracingEntry) Infof(format string, v ...interface{}) { + if shouldLog(InfoLevel) { + l.write(infoLog, levelInfo, fmt.Sprintf(format, v...)) + } +} + +func (l tracingEntry) Slow(v ...interface{}) { + if shouldLog(ErrorLevel) { + l.write(slowLog, levelSlow, fmt.Sprint(v...)) + } +} + +func (l tracingEntry) Slowf(format string, v ...interface{}) { + if shouldLog(ErrorLevel) { + l.write(slowLog, levelSlow, fmt.Sprintf(format, v...)) + } +} + +func (l tracingEntry) write(writer io.Writer, level, content string) { + l.Timestamp = getTimestamp() + l.Level = level + l.Content = content + l.Trace = traceIdFromContext(l.ctx) + l.Span = spanIdFromContext(l.ctx) + outputJson(writer, l) +} + +func WithContext(ctx context.Context) Logger { + return tracingEntry{ + ctx: ctx, + } +} + +func spanIdFromContext(ctx context.Context) string { + t, ok := ctx.Value(tracespec.TracingKey).(tracespec.Trace) + if !ok { + return "" + } + + return t.SpanId() +} + +func traceIdFromContext(ctx context.Context) string { + t, ok := ctx.Value(tracespec.TracingKey).(tracespec.Trace) + if !ok { + return "" + } + + return t.TraceId() +} diff --git a/core/logx/tracelog_test.go b/core/logx/tracelog_test.go new file mode 100644 index 00000000..ca243cf6 --- /dev/null +++ b/core/logx/tracelog_test.go @@ -0,0 +1,50 @@ +package logx + +import ( + "context" + "strings" + "testing" + + "zero/core/trace/tracespec" + + "github.com/stretchr/testify/assert" +) + +const ( + mockTraceId = "mock-trace-id" + mockSpanId = "mock-span-id" +) + +var mock tracespec.Trace = new(mockTrace) + +func TestTraceLog(t *testing.T) { + var buf strings.Builder + ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock) + WithContext(ctx).(tracingEntry).write(&buf, levelInfo, testlog) + assert.True(t, strings.Contains(buf.String(), mockTraceId)) + assert.True(t, strings.Contains(buf.String(), mockSpanId)) +} 
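+
+// An illustrative call-site sketch (ctx and path are example names, not part
+// of this package): handlers pull the trace out of the request context with
+// WithContext so that every entry carries the trace and span ids:
+//
+//	logger := WithContext(ctx)
+//	logger.Infof("handled %s", path)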
+ +type mockTrace struct{} + +func (t mockTrace) TraceId() string { + return mockTraceId +} + +func (t mockTrace) SpanId() string { + return mockSpanId +} + +func (t mockTrace) Finish() { +} + +func (t mockTrace) Fork(ctx context.Context, serviceName, operationName string) (context.Context, tracespec.Trace) { + return nil, nil +} + +func (t mockTrace) Follow(ctx context.Context, serviceName, operationName string) (context.Context, tracespec.Trace) { + return nil, nil +} + +func (t mockTrace) Visit(fn func(key string, val string) bool) { +} diff --git a/core/mapping/fieldoptions.go b/core/mapping/fieldoptions.go new file mode 100644 index 00000000..075753db --- /dev/null +++ b/core/mapping/fieldoptions.go @@ -0,0 +1,105 @@ +package mapping + +import "fmt" + +const notSymbol = '!' + +type ( + // use context and OptionalDep option to determine the value of Optional + // nothing to do with context.Context + fieldOptionsWithContext struct { + FromString bool + Optional bool + Options []string + Default string + Range *numberRange + } + + fieldOptions struct { + fieldOptionsWithContext + OptionalDep string + } + + numberRange struct { + left float64 + leftInclude bool + right float64 + rightInclude bool + } +) + +func (o *fieldOptionsWithContext) fromString() bool { + return o != nil && o.FromString +} + +func (o *fieldOptionsWithContext) getDefault() (string, bool) { + if o == nil { + return "", false + } else { + return o.Default, len(o.Default) > 0 + } +} + +func (o *fieldOptionsWithContext) optional() bool { + return o != nil && o.Optional +} + +func (o *fieldOptionsWithContext) options() []string { + if o == nil { + return nil + } + + return o.Options +} + +func (o *fieldOptions) optionalDep() string { + if o == nil { + return "" + } else { + return o.OptionalDep + } +} + +func (o *fieldOptions) toOptionsWithContext(key string, m Valuer, fullName string) ( + *fieldOptionsWithContext, error) { + var optional bool + if o.optional() { + dep := o.optionalDep() + if len(dep) == 0 { + optional = true + } else if dep[0] == notSymbol { + dep = dep[1:] + if len(dep) == 0 { + return nil, fmt.Errorf("wrong optional value for %q in %q", key, fullName) + } + + _, baseOn := m.Value(dep) + _, selfOn := m.Value(key) + if baseOn == selfOn { + return nil, fmt.Errorf("set value for either %q or %q in %q", dep, key, fullName) + } else { + optional = baseOn + } + } else { + _, baseOn := m.Value(dep) + _, selfOn := m.Value(key) + if baseOn != selfOn { + return nil, fmt.Errorf("values for %q and %q should be both provided or both not in %q", + dep, key, fullName) + } else { + optional = !baseOn + } + } + } + + if o.fieldOptionsWithContext.Optional == optional { + return &o.fieldOptionsWithContext, nil + } else { + return &fieldOptionsWithContext{ + FromString: o.FromString, + Optional: optional, + Options: o.Options, + Default: o.Default, + }, nil + } +} diff --git a/core/mapping/jsonunmarshaler.go b/core/mapping/jsonunmarshaler.go new file mode 100644 index 00000000..af7ed304 --- /dev/null +++ b/core/mapping/jsonunmarshaler.go @@ -0,0 +1,37 @@ +package mapping + +import ( + "io" + + "zero/core/jsonx" +) + +const jsonTagKey = "json" + +var jsonUnmarshaler = NewUnmarshaler(jsonTagKey) + +func UnmarshalJsonBytes(content []byte, v interface{}) error { + return unmarshalJsonBytes(content, v, jsonUnmarshaler) +} + +func UnmarshalJsonReader(reader io.Reader, v interface{}) error { + return unmarshalJsonReader(reader, v, jsonUnmarshaler) +} + +func unmarshalJsonBytes(content []byte, v interface{}, unmarshaler 
*Unmarshaler) error { + var m map[string]interface{} + if err := jsonx.Unmarshal(content, &m); err != nil { + return err + } + + return unmarshaler.Unmarshal(m, v) +} + +func unmarshalJsonReader(reader io.Reader, v interface{}, unmarshaler *Unmarshaler) error { + var m map[string]interface{} + if err := jsonx.UnmarshalFromReader(reader, &m); err != nil { + return err + } + + return unmarshaler.Unmarshal(m, v) +} diff --git a/core/mapping/jsonunmarshaler_test.go b/core/mapping/jsonunmarshaler_test.go new file mode 100644 index 00000000..fc561fd4 --- /dev/null +++ b/core/mapping/jsonunmarshaler_test.go @@ -0,0 +1,873 @@ +package mapping + +import ( + "bytes" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnmarshalBytes(t *testing.T) { + var c struct { + Name string + } + content := []byte(`{"Name": "liao"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "liao", c.Name) +} + +func TestUnmarshalBytesOptional(t *testing.T) { + var c struct { + Name string + Age int `json:",optional"` + } + content := []byte(`{"Name": "liao"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "liao", c.Name) +} + +func TestUnmarshalBytesOptionalDefault(t *testing.T) { + var c struct { + Name string + Age int `json:",optional,default=1"` + } + content := []byte(`{"Name": "liao"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "liao", c.Name) + assert.Equal(t, 1, c.Age) +} + +func TestUnmarshalBytesDefaultOptional(t *testing.T) { + var c struct { + Name string + Age int `json:",default=1,optional"` + } + content := []byte(`{"Name": "liao"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "liao", c.Name) + assert.Equal(t, 1, c.Age) +} + +func TestUnmarshalBytesDefault(t *testing.T) { + var c struct { + Name string `json:",default=liao"` + } + content := []byte(`{}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "liao", c.Name) +} + +func TestUnmarshalBytesBool(t *testing.T) { + var c struct { + Great bool + } + content := []byte(`{"Great": true}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.True(t, c.Great) +} + +func TestUnmarshalBytesInt(t *testing.T) { + var c struct { + Age int + } + content := []byte(`{"Age": 1}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, 1, c.Age) +} + +func TestUnmarshalBytesUint(t *testing.T) { + var c struct { + Age uint + } + content := []byte(`{"Age": 1}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, uint(1), c.Age) +} + +func TestUnmarshalBytesFloat(t *testing.T) { + var c struct { + Age float32 + } + content := []byte(`{"Age": 1.5}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, float32(1.5), c.Age) +} + +func TestUnmarshalBytesMustInOptional(t *testing.T) { + var c struct { + Inner struct { + There string + Must string + Optional string `json:",optional"` + } `json:",optional"` + } + content := []byte(`{}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesMustInOptionalMissedPart(t *testing.T) { + var c struct { + Inner struct { + There string + Must string + Optional string `json:",optional"` + } `json:",optional"` + } + content := []byte(`{"Inner": {"There": "sure"}}`) + + assert.NotNil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesMustInOptionalOnlyOptionalFilled(t *testing.T) { + var c struct { + Inner struct { + There string + Must string + Optional string `json:",optional"` + } 
`json:",optional"` + } + content := []byte(`{"Inner": {"Optional": "sure"}}`) + + assert.NotNil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesNil(t *testing.T) { + var c struct { + Int int64 `json:"int,optional"` + } + content := []byte(`{"int":null}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, int64(0), c.Int) +} + +func TestUnmarshalBytesNilSlice(t *testing.T) { + var c struct { + Ints []int64 `json:"ints"` + } + content := []byte(`{"ints":[null]}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, 0, len(c.Ints)) +} + +func TestUnmarshalBytesPartial(t *testing.T) { + var c struct { + Name string + Age float32 + } + content := []byte(`{"Age": 1.5}`) + + assert.NotNil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesStruct(t *testing.T) { + var c struct { + Inner struct { + Name string + } + } + content := []byte(`{"Inner": {"Name": "liao"}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "liao", c.Inner.Name) +} + +func TestUnmarshalBytesStructOptional(t *testing.T) { + var c struct { + Inner struct { + Name string + Age int `json:",optional"` + } + } + content := []byte(`{"Inner": {"Name": "liao"}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "liao", c.Inner.Name) +} + +func TestUnmarshalBytesStructPtr(t *testing.T) { + var c struct { + Inner *struct { + Name string + } + } + content := []byte(`{"Inner": {"Name": "liao"}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "liao", c.Inner.Name) +} + +func TestUnmarshalBytesStructPtrOptional(t *testing.T) { + var c struct { + Inner *struct { + Name string + Age int `json:",optional"` + } + } + content := []byte(`{"Inner": {"Name": "liao"}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesStructPtrDefault(t *testing.T) { + var c struct { + Inner *struct { + Name string + Age int `json:",default=4"` + } + } + content := []byte(`{"Inner": {"Name": "liao"}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "liao", c.Inner.Name) + assert.Equal(t, 4, c.Inner.Age) +} + +func TestUnmarshalBytesSliceString(t *testing.T) { + var c struct { + Names []string + } + content := []byte(`{"Names": ["liao", "chaoxin"]}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + + want := []string{"liao", "chaoxin"} + if !reflect.DeepEqual(c.Names, want) { + t.Fatalf("want %q, got %q", c.Names, want) + } +} + +func TestUnmarshalBytesSliceStringOptional(t *testing.T) { + var c struct { + Names []string + Age []int `json:",optional"` + } + content := []byte(`{"Names": ["liao", "chaoxin"]}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + + want := []string{"liao", "chaoxin"} + if !reflect.DeepEqual(c.Names, want) { + t.Fatalf("want %q, got %q", c.Names, want) + } +} + +func TestUnmarshalBytesSliceStruct(t *testing.T) { + var c struct { + People []struct { + Name string + Age int + } + } + content := []byte(`{"People": [{"Name": "liao", "Age": 1}, {"Name": "chaoxin", "Age": 2}]}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + + want := []struct { + Name string + Age int + }{ + {"liao", 1}, + {"chaoxin", 2}, + } + if !reflect.DeepEqual(c.People, want) { + t.Fatalf("want %q, got %q", c.People, want) + } +} + +func TestUnmarshalBytesSliceStructOptional(t *testing.T) { + var c struct { + People []struct { + Name string + Age int + Emails []string `json:",optional"` + } + } + content := []byte(`{"People": [{"Name": "liao", "Age": 1}, {"Name": "chaoxin", 
"Age": 2}]}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + + want := []struct { + Name string + Age int + Emails []string `json:",optional"` + }{ + {"liao", 1, nil}, + {"chaoxin", 2, nil}, + } + if !reflect.DeepEqual(c.People, want) { + t.Fatalf("want %q, got %q", c.People, want) + } +} + +func TestUnmarshalBytesSliceStructPtr(t *testing.T) { + var c struct { + People []*struct { + Name string + Age int + } + } + content := []byte(`{"People": [{"Name": "liao", "Age": 1}, {"Name": "chaoxin", "Age": 2}]}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + + want := []*struct { + Name string + Age int + }{ + {"liao", 1}, + {"chaoxin", 2}, + } + if !reflect.DeepEqual(c.People, want) { + t.Fatalf("want %v, got %v", c.People, want) + } +} + +func TestUnmarshalBytesSliceStructPtrOptional(t *testing.T) { + var c struct { + People []*struct { + Name string + Age int + Emails []string `json:",optional"` + } + } + content := []byte(`{"People": [{"Name": "liao", "Age": 1}, {"Name": "chaoxin", "Age": 2}]}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + + want := []*struct { + Name string + Age int + Emails []string `json:",optional"` + }{ + {"liao", 1, nil}, + {"chaoxin", 2, nil}, + } + if !reflect.DeepEqual(c.People, want) { + t.Fatalf("want %v, got %v", c.People, want) + } +} + +func TestUnmarshalBytesSliceStructPtrPartial(t *testing.T) { + var c struct { + People []*struct { + Name string + Age int + Email string + } + } + content := []byte(`{"People": [{"Name": "liao", "Age": 1}, {"Name": "chaoxin", "Age": 2}]}`) + + assert.NotNil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesSliceStructPtrDefault(t *testing.T) { + var c struct { + People []*struct { + Name string + Age int + Email string `json:",default=chaoxin@liao.com"` + } + } + content := []byte(`{"People": [{"Name": "liao", "Age": 1}, {"Name": "chaoxin", "Age": 2}]}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + + want := []*struct { + Name string + Age int + Email string + }{ + {"liao", 1, "chaoxin@liao.com"}, + {"chaoxin", 2, "chaoxin@liao.com"}, + } + + for i := range c.People { + actual := c.People[i] + expect := want[i] + assert.Equal(t, expect.Age, actual.Age) + assert.Equal(t, expect.Email, actual.Email) + assert.Equal(t, expect.Name, actual.Name) + } +} + +func TestUnmarshalBytesSliceStringPartial(t *testing.T) { + var c struct { + Names []string + Age int + } + content := []byte(`{"Age": 1}`) + + assert.NotNil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesSliceStructPartial(t *testing.T) { + var c struct { + Group string + People []struct { + Name string + Age int + } + } + content := []byte(`{"Group": "chaoxin"}`) + + assert.NotNil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesInnerAnonymousPartial(t *testing.T) { + type ( + Deep struct { + A string + B string `json:",optional"` + } + Inner struct { + Deep + InnerV string `json:",optional"` + } + ) + + var c struct { + Value Inner `json:",optional"` + } + content := []byte(`{"Value": {"InnerV": "chaoxin"}}`) + + assert.NotNil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesStructPartial(t *testing.T) { + var c struct { + Group string + Person struct { + Name string + Age int + } + } + content := []byte(`{"Group": "chaoxin"}`) + + assert.NotNil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesEmptyMap(t *testing.T) { + var c struct { + Persons map[string]int `json:",optional"` + } + content := []byte(`{"Persons": {}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) 
+ assert.Empty(t, c.Persons) +} + +func TestUnmarshalBytesMap(t *testing.T) { + var c struct { + Persons map[string]int + } + content := []byte(`{"Persons": {"first": 1, "second": 2}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, 2, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"]) + assert.Equal(t, 2, c.Persons["second"]) +} + +func TestUnmarshalBytesMapStruct(t *testing.T) { + var c struct { + Persons map[string]struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`{"Persons": {"first": {"Id": 1, "name": "kevin"}}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"].Id) + assert.Equal(t, "kevin", c.Persons["first"].Name) +} + +func TestUnmarshalBytesMapStructPtr(t *testing.T) { + var c struct { + Persons map[string]*struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`{"Persons": {"first": {"Id": 1, "name": "kevin"}}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"].Id) + assert.Equal(t, "kevin", c.Persons["first"].Name) +} + +func TestUnmarshalBytesMapStructMissingPartial(t *testing.T) { + var c struct { + Persons map[string]*struct { + Id int + Name string + } + } + content := []byte(`{"Persons": {"first": {"Id": 1}}}`) + + assert.NotNil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesMapStructOptional(t *testing.T) { + var c struct { + Persons map[string]*struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`{"Persons": {"first": {"Id": 1}}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"].Id) +} + +func TestUnmarshalBytesMapEmptyStructSlice(t *testing.T) { + var c struct { + Persons map[string][]struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`{"Persons": {"first": []}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Empty(t, c.Persons["first"]) +} + +func TestUnmarshalBytesMapStructSlice(t *testing.T) { + var c struct { + Persons map[string][]struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`{"Persons": {"first": [{"Id": 1, "name": "kevin"}]}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"][0].Id) + assert.Equal(t, "kevin", c.Persons["first"][0].Name) +} + +func TestUnmarshalBytesMapEmptyStructPtrSlice(t *testing.T) { + var c struct { + Persons map[string][]*struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`{"Persons": {"first": []}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Empty(t, c.Persons["first"]) +} + +func TestUnmarshalBytesMapStructPtrSlice(t *testing.T) { + var c struct { + Persons map[string][]*struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`{"Persons": {"first": [{"Id": 1, "name": "kevin"}]}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"][0].Id) + assert.Equal(t, "kevin", c.Persons["first"][0].Name) +} + +func TestUnmarshalBytesMapStructPtrSliceMissingPartial(t *testing.T) { + var c struct { + Persons map[string][]*struct { + Id int + Name string + } + } + content := []byte(`{"Persons": 
{"first": [{"Id": 1}]}}`) + + assert.NotNil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalBytesMapStructPtrSliceOptional(t *testing.T) { + var c struct { + Persons map[string][]*struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`{"Persons": {"first": [{"Id": 1}]}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"][0].Id) +} + +func TestUnmarshalStructOptional(t *testing.T) { + var c struct { + Name string + Etcd struct { + Hosts []string + Key string + } `json:",optional"` + } + content := []byte(`{"Name": "kevin"}`) + + err := UnmarshalJsonBytes(content, &c) + assert.Nil(t, err) + assert.Equal(t, "kevin", c.Name) +} + +func TestUnmarshalStructLowerCase(t *testing.T) { + var c struct { + Name string + Etcd struct { + Key string + } `json:"etcd"` + } + content := []byte(`{"Name": "kevin", "etcd": {"Key": "the key"}}`) + + err := UnmarshalJsonBytes(content, &c) + assert.Nil(t, err) + assert.Equal(t, "kevin", c.Name) + assert.Equal(t, "the key", c.Etcd.Key) +} + +func TestUnmarshalWithStructAllOptionalWithEmpty(t *testing.T) { + var c struct { + Inner struct { + Optional string `json:",optional"` + } + Else string + } + content := []byte(`{"Else": "sure", "Inner": {}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalWithStructAllOptionalPtr(t *testing.T) { + var c struct { + Inner *struct { + Optional string `json:",optional"` + } + Else string + } + content := []byte(`{"Else": "sure", "Inner": {}}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalWithStructOptional(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + In Inner `json:",optional"` + Else string + } + content := []byte(`{"Else": "sure"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, "", c.In.Must) +} + +func TestUnmarshalWithStructPtrOptional(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + In *Inner `json:",optional"` + Else string + } + content := []byte(`{"Else": "sure"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Nil(t, c.In) +} + +func TestUnmarshalWithStructAllOptionalAnonymous(t *testing.T) { + type Inner struct { + Optional string `json:",optional"` + } + + var c struct { + Inner + Else string + } + content := []byte(`{"Else": "sure"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalWithStructAllOptionalAnonymousPtr(t *testing.T) { + type Inner struct { + Optional string `json:",optional"` + } + + var c struct { + *Inner + Else string + } + content := []byte(`{"Else": "sure"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) +} + +func TestUnmarshalWithStructAllOptionalProvoidedAnonymous(t *testing.T) { + type Inner struct { + Optional string `json:",optional"` + } + + var c struct { + Inner + Else string + } + content := []byte(`{"Else": "sure", "Optional": "optional"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, "optional", c.Optional) +} + +func TestUnmarshalWithStructAllOptionalProvoidedAnonymousPtr(t *testing.T) { + type Inner struct { + Optional string `json:",optional"` + } + + var c struct { + *Inner + Else string + } + content := []byte(`{"Else": "sure", "Optional": "optional"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, 
"optional", c.Optional) +} + +func TestUnmarshalWithStructAnonymous(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + Inner + Else string + } + content := []byte(`{"Else": "sure", "Must": "must"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, "must", c.Must) +} + +func TestUnmarshalWithStructAnonymousPtr(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + *Inner + Else string + } + content := []byte(`{"Else": "sure", "Must": "must"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, "must", c.Must) +} + +func TestUnmarshalWithStructAnonymousOptional(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + Inner `json:",optional"` + Else string + } + content := []byte(`{"Else": "sure"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, "", c.Must) +} + +func TestUnmarshalWithStructPtrAnonymousOptional(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + *Inner `json:",optional"` + Else string + } + content := []byte(`{"Else": "sure"}`) + + assert.Nil(t, UnmarshalJsonBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Nil(t, c.Inner) +} + +func TestUnmarshalWithZeroValues(t *testing.T) { + type inner struct { + False bool `json:"no"` + Int int `json:"int"` + String string `json:"string"` + } + content := []byte(`{"no": false, "int": 0, "string": ""}`) + reader := bytes.NewReader(content) + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalJsonReader(reader, &in)) + ast.False(in.False) + ast.Equal(0, in.Int) + ast.Equal("", in.String) +} + +func TestUnmarshalBytesError(t *testing.T) { + payload := `[{"abcd": "cdef"}]` + var v struct { + Any string + } + + err := UnmarshalJsonBytes([]byte(payload), &v) + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), payload)) +} + +func TestUnmarshalReaderError(t *testing.T) { + payload := `[{"abcd": "cdef"}]` + reader := strings.NewReader(payload) + var v struct { + Any string + } + + err := UnmarshalJsonReader(reader, &v) + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), payload)) +} diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go new file mode 100644 index 00000000..d7f2d7a8 --- /dev/null +++ b/core/mapping/unmarshaler.go @@ -0,0 +1,737 @@ +package mapping + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "sync/atomic" + "time" + + "zero/core/jsonx" + "zero/core/lang" + "zero/core/stringx" +) + +const ( + defaultKeyName = "key" + delimiter = '.' 
+) + +var ( + errTypeMismatch = errors.New("type mismatch") + errValueNotSettable = errors.New("value is not settable") + keyUnmarshaler = NewUnmarshaler(defaultKeyName) + cacheKeys atomic.Value + cacheKeysLock sync.Mutex + durationType = reflect.TypeOf(time.Duration(0)) + emptyMap = map[string]interface{}{} + emptyValue = reflect.ValueOf(lang.Placeholder) +) + +type ( + Unmarshaler struct { + key string + opts unmarshalOptions + } + + unmarshalOptions struct { + fromString bool + } + + keyCache map[string][]string + UnmarshalOption func(*unmarshalOptions) +) + +func init() { + cacheKeys.Store(make(keyCache)) +} + +func NewUnmarshaler(key string, opts ...UnmarshalOption) *Unmarshaler { + unmarshaler := Unmarshaler{ + key: key, + } + + for _, opt := range opts { + opt(&unmarshaler.opts) + } + + return &unmarshaler +} + +func UnmarshalKey(m map[string]interface{}, v interface{}) error { + return keyUnmarshaler.Unmarshal(m, v) +} + +func (u *Unmarshaler) Unmarshal(m map[string]interface{}, v interface{}) error { + return u.UnmarshalValuer(MapValuer(m), v) +} + +func (u *Unmarshaler) UnmarshalValuer(m Valuer, v interface{}) error { + return u.unmarshalWithFullName(m, v, "") +} + +func (u *Unmarshaler) unmarshalWithFullName(m Valuer, v interface{}, fullName string) error { + rv := reflect.ValueOf(v) + if err := ValidatePtr(&rv); err != nil { + return err + } + + rte := reflect.TypeOf(v).Elem() + rve := rv.Elem() + numFields := rte.NumField() + for i := 0; i < numFields; i++ { + field := rte.Field(i) + if usingDifferentKeys(u.key, field) { + continue + } + + if err := u.processField(field, rve.Field(i), m, fullName); err != nil { + return err + } + } + + return nil +} + +func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value reflect.Value, + m Valuer, fullName string) error { + key, options, err := u.parseOptionsWithContext(field, m, fullName) + if err != nil { + return err + } + + if _, hasValue := getValue(m, key); hasValue { + return fmt.Errorf("fields of %s can't be wrapped inside, because it's anonymous", key) + } + + if options.optional() { + return u.processAnonymousFieldOptional(field, value, key, m, fullName) + } else { + return u.processAnonymousFieldRequired(field, value, m, fullName) + } +} + +func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value, + key string, m Valuer, fullName string) error { + var filled bool + var required int + var requiredFilled int + var indirectValue reflect.Value + fieldType := Deref(field.Type) + + for i := 0; i < fieldType.NumField(); i++ { + subField := fieldType.Field(i) + fieldKey, fieldOpts, err := u.parseOptionsWithContext(subField, m, fullName) + if err != nil { + return err + } + + _, hasValue := getValue(m, fieldKey) + if hasValue { + if !filled { + filled = true + maybeNewValue(field, value) + indirectValue = reflect.Indirect(value) + + } + if err = u.processField(subField, indirectValue.Field(i), m, fullName); err != nil { + return err + } + } + if !fieldOpts.optional() { + required++ + if hasValue { + requiredFilled++ + } + } + } + + if filled && required != requiredFilled { + return fmt.Errorf("%s is not fully set", key) + } + + return nil +} + +func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, value reflect.Value, + m Valuer, fullName string) error { + maybeNewValue(field, value) + fieldType := Deref(field.Type) + indirectValue := reflect.Indirect(value) + + for i := 0; i < fieldType.NumField(); i++ { + if err := u.processField(fieldType.Field(i), 
indirectValue.Field(i), m, fullName); err != nil { + return err + } + } + + return nil +} + +func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value, m Valuer, + fullName string) error { + if usingDifferentKeys(u.key, field) { + return nil + } + + if field.Anonymous { + return u.processAnonymousField(field, value, m, fullName) + } else { + return u.processNamedField(field, value, m, fullName) + } +} + +func (u *Unmarshaler) processFieldNotFromString(field reflect.StructField, value reflect.Value, + mapValue interface{}, opts *fieldOptionsWithContext, fullName string) error { + fieldType := field.Type + derefedFieldType := Deref(fieldType) + typeKind := derefedFieldType.Kind() + valueKind := reflect.TypeOf(mapValue).Kind() + + switch { + case valueKind == reflect.Map && typeKind == reflect.Struct: + return u.processFieldStruct(field, value, mapValue, fullName) + case valueKind == reflect.String && typeKind == reflect.Slice: + return u.fillSliceFromString(fieldType, value, mapValue, fullName) + case valueKind == reflect.String && derefedFieldType == durationType: + return fillDurationValue(fieldType.Kind(), value, mapValue.(string)) + default: + return u.processFieldPrimitive(field, value, mapValue, opts, fullName) + } +} + +func (u *Unmarshaler) processFieldPrimitive(field reflect.StructField, value reflect.Value, + mapValue interface{}, opts *fieldOptionsWithContext, fullName string) error { + fieldType := field.Type + typeKind := Deref(fieldType).Kind() + valueKind := reflect.TypeOf(mapValue).Kind() + + switch { + case typeKind == reflect.Slice && valueKind == reflect.Slice: + return u.fillSlice(fieldType, value, mapValue) + case typeKind == reflect.Map && valueKind == reflect.Map: + return u.fillMap(field, value, mapValue) + default: + switch v := mapValue.(type) { + case json.Number: + return u.processFieldPrimitiveWithJsonNumber(field, value, v, opts, fullName) + default: + if typeKind == valueKind { + if err := validateValueInOptions(opts.options(), mapValue); err != nil { + return err + } + + return fillWithSameType(field, value, mapValue, opts) + } + } + } + + return newTypeMismatchError(fullName) +} + +func (u *Unmarshaler) processFieldPrimitiveWithJsonNumber(field reflect.StructField, value reflect.Value, + v json.Number, opts *fieldOptionsWithContext, fullName string) error { + fieldType := field.Type + fieldKind := fieldType.Kind() + typeKind := Deref(fieldType).Kind() + + if err := validateJsonNumberRange(v, opts); err != nil { + return err + } + + if fieldKind == reflect.Ptr { + value = value.Elem() + } + + switch typeKind { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + iValue, err := v.Int64() + if err != nil { + return err + } + + value.SetInt(iValue) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + iValue, err := v.Int64() + if err != nil { + return err + } + + value.SetUint(uint64(iValue)) + case reflect.Float32, reflect.Float64: + fValue, err := v.Float64() + if err != nil { + return err + } + + value.SetFloat(fValue) + default: + return newTypeMismatchError(fullName) + } + + return nil +} + +func (u *Unmarshaler) processFieldStruct(field reflect.StructField, value reflect.Value, + mapValue interface{}, fullName string) error { + convertedValue, ok := mapValue.(map[string]interface{}) + if !ok { + valueKind := reflect.TypeOf(mapValue).Kind() + return fmt.Errorf("error: field: %s, expect map[string]interface{}, actual %v", fullName, valueKind) + } + + return 
u.processFieldStructWithMap(field, value, MapValuer(convertedValue), fullName) +} + +func (u *Unmarshaler) processFieldStructWithMap(field reflect.StructField, value reflect.Value, + m Valuer, fullName string) error { + if field.Type.Kind() == reflect.Ptr { + baseType := Deref(field.Type) + target := reflect.New(baseType).Elem() + if err := u.unmarshalWithFullName(m, target.Addr().Interface(), fullName); err != nil { + return err + } + + value.Set(target.Addr()) + } else if err := u.unmarshalWithFullName(m, value.Addr().Interface(), fullName); err != nil { + return err + } + + return nil +} + +func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect.Value, + m Valuer, fullName string) error { + key, opts, err := u.parseOptionsWithContext(field, m, fullName) + if err != nil { + return err + } + + fullName = join(fullName, key) + mapValue, hasValue := getValue(m, key) + if hasValue { + return u.processNamedFieldWithValue(field, value, mapValue, key, opts, fullName) + } else { + return u.processNamedFieldWithoutValue(field, value, opts, fullName) + } +} + +func (u *Unmarshaler) processNamedFieldWithValue(field reflect.StructField, value reflect.Value, + mapValue interface{}, key string, opts *fieldOptionsWithContext, fullName string) error { + if mapValue == nil { + if opts.optional() { + return nil + } else { + return fmt.Errorf("field %s mustn't be nil", key) + } + } + + maybeNewValue(field, value) + + fieldKind := Deref(field.Type).Kind() + switch fieldKind { + case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: + return u.processFieldNotFromString(field, value, mapValue, opts, fullName) + default: + if u.opts.fromString || opts.fromString() { + valueKind := reflect.TypeOf(mapValue).Kind() + if valueKind != reflect.String { + return fmt.Errorf("error: the value in map is not string, but %s", valueKind) + } + + options := opts.options() + if len(options) > 0 { + if !stringx.Contains(options, mapValue.(string)) { + return fmt.Errorf(`error: value "%s" for field "%s" is not defined in opts "%v"`, + mapValue, key, options) + } + } + + return fillPrimitive(field.Type, value, mapValue, opts, fullName) + } + + return u.processFieldNotFromString(field, value, mapValue, opts, fullName) + } +} + +func (u *Unmarshaler) processNamedFieldWithoutValue(field reflect.StructField, value reflect.Value, + opts *fieldOptionsWithContext, fullName string) error { + derefedType := Deref(field.Type) + fieldKind := derefedType.Kind() + if defaultValue, ok := opts.getDefault(); ok { + if field.Type.Kind() == reflect.Ptr { + maybeNewValue(field, value) + value = value.Elem() + } + if derefedType == durationType { + return fillDurationValue(fieldKind, value, defaultValue) + } + return setValue(fieldKind, value, defaultValue) + } + + switch fieldKind { + case reflect.Array, reflect.Map, reflect.Slice: + if !opts.optional() { + return u.processFieldNotFromString(field, value, emptyMap, opts, fullName) + } + case reflect.Struct: + if !opts.optional() { + required, err := structValueRequired(u.key, derefedType) + if err != nil { + return err + } + if required { + return fmt.Errorf("%q is not set", fullName) + } + return u.processFieldNotFromString(field, value, emptyMap, opts, fullName) + } + default: + if !opts.optional() { + return newInitError(fullName) + } + } + + return nil +} + +func (u *Unmarshaler) fillMap(field reflect.StructField, value reflect.Value, mapValue interface{}) error { + if !value.CanSet() { + return errValueNotSettable + } + + fieldKeyType := field.Type.Key() + 
fieldElemType := field.Type.Elem() + targetValue, err := u.generateMap(fieldKeyType, fieldElemType, mapValue) + if err != nil { + return err + } + + value.Set(targetValue) + return nil +} + +func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, mapValue interface{}) error { + if !value.CanSet() { + return errValueNotSettable + } + + baseType := fieldType.Elem() + baseKind := baseType.Kind() + dereffedBaseType := Deref(baseType) + dereffedBaseKind := dereffedBaseType.Kind() + refValue := reflect.ValueOf(mapValue) + conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap()) + + var valid bool + for i := 0; i < refValue.Len(); i++ { + ithValue := refValue.Index(i).Interface() + if ithValue == nil { + continue + } + + valid = true + switch dereffedBaseKind { + case reflect.Struct: + target := reflect.New(dereffedBaseType) + if err := u.Unmarshal(ithValue.(map[string]interface{}), target.Interface()); err != nil { + return err + } + + if baseKind == reflect.Ptr { + conv.Index(i).Set(target) + } else { + conv.Index(i).Set(target.Elem()) + } + default: + if err := u.fillSliceValue(conv, i, dereffedBaseKind, ithValue); err != nil { + return err + } + } + } + + if valid { + value.Set(conv) + } + + return nil +} + +func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.Value, + mapValue interface{}, fullName string) error { + var slice []interface{} + if err := jsonx.UnmarshalFromString(mapValue.(string), &slice); err != nil { + return err + } + + baseFieldType := Deref(fieldType.Elem()) + baseFieldKind := baseFieldType.Kind() + conv := reflect.MakeSlice(reflect.SliceOf(baseFieldType), len(slice), cap(slice)) + + for i := 0; i < len(slice); i++ { + if err := u.fillSliceValue(conv, i, baseFieldKind, slice[i]); err != nil { + return err + } + } + + value.Set(conv) + return nil +} + +func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int, baseKind reflect.Kind, value interface{}) error { + switch v := value.(type) { + case json.Number: + return setValue(baseKind, slice.Index(index), v.String()) + default: + // don't need to consider the difference between int, int8, int16, int32, int64, + // uint, uint8, uint16, uint32, uint64, because they're handled as json.Number. 
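+		// for other kinds the incoming value must already have the exact kind
+		// of the slice element; anything else is reported as errTypeMismatch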
+ if slice.Index(index).Kind() != reflect.TypeOf(value).Kind() { + return errTypeMismatch + } + + slice.Index(index).Set(reflect.ValueOf(value)) + return nil + } +} + +func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue interface{}) (reflect.Value, error) { + mapType := reflect.MapOf(keyType, elemType) + valueType := reflect.TypeOf(mapValue) + if mapType == valueType { + return reflect.ValueOf(mapValue), nil + } + + refValue := reflect.ValueOf(mapValue) + targetValue := reflect.MakeMapWithSize(mapType, refValue.Len()) + fieldElemKind := elemType.Kind() + dereffedElemType := Deref(elemType) + dereffedElemKind := dereffedElemType.Kind() + + for _, key := range refValue.MapKeys() { + keythValue := refValue.MapIndex(key) + keythData := keythValue.Interface() + + switch dereffedElemKind { + case reflect.Slice: + target := reflect.New(dereffedElemType) + if err := u.fillSlice(elemType, target.Elem(), keythData); err != nil { + return emptyValue, err + } + + targetValue.SetMapIndex(key, target.Elem()) + case reflect.Struct: + keythMap, ok := keythData.(map[string]interface{}) + if !ok { + return emptyValue, errTypeMismatch + } + + target := reflect.New(dereffedElemType) + if err := u.Unmarshal(keythMap, target.Interface()); err != nil { + return emptyValue, err + } + + if fieldElemKind == reflect.Ptr { + targetValue.SetMapIndex(key, target) + } else { + targetValue.SetMapIndex(key, target.Elem()) + } + case reflect.Map: + keythMap, ok := keythData.(map[string]interface{}) + if !ok { + return emptyValue, errTypeMismatch + } + + innerValue, err := u.generateMap(elemType.Key(), elemType.Elem(), keythMap) + if err != nil { + return emptyValue, err + } + + targetValue.SetMapIndex(key, innerValue) + default: + switch v := keythData.(type) { + case string: + targetValue.SetMapIndex(key, reflect.ValueOf(v)) + case json.Number: + target := reflect.New(dereffedElemType) + if err := setValue(dereffedElemKind, target.Elem(), v.String()); err != nil { + return emptyValue, err + } + + targetValue.SetMapIndex(key, target.Elem()) + default: + targetValue.SetMapIndex(key, keythValue) + } + } + } + + return targetValue, nil +} + +func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Valuer, fullName string) ( + string, *fieldOptionsWithContext, error) { + key, options, err := parseKeyAndOptions(u.key, field) + if err != nil { + return "", nil, err + } else if options == nil { + return key, nil, nil + } + + optsWithContext, err := options.toOptionsWithContext(key, m, fullName) + if err != nil { + return "", nil, err + } + + return key, optsWithContext, nil +} + +func WithStringValues() UnmarshalOption { + return func(opt *unmarshalOptions) { + opt.fromString = true + } +} + +func fillDurationValue(fieldKind reflect.Kind, value reflect.Value, dur string) error { + d, err := time.ParseDuration(dur) + if err != nil { + return err + } + + if fieldKind == reflect.Ptr { + value.Elem().Set(reflect.ValueOf(d)) + } else { + value.Set(reflect.ValueOf(d)) + } + + return nil +} + +func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue interface{}, + opts *fieldOptionsWithContext, fullName string) error { + if !value.CanSet() { + return errValueNotSettable + } + + baseType := Deref(fieldType) + if fieldType.Kind() == reflect.Ptr { + target := reflect.New(baseType).Elem() + switch mapValue.(type) { + case string, json.Number: + value.Set(target.Addr()) + value = target + } + } + + switch v := mapValue.(type) { + case string: + return 
validateAndSetValue(baseType.Kind(), value, v, opts) + case json.Number: + if err := validateJsonNumberRange(v, opts); err != nil { + return err + } + return setValue(baseType.Kind(), value, v.String()) + default: + return newTypeMismatchError(fullName) + } +} + +func fillWithSameType(field reflect.StructField, value reflect.Value, mapValue interface{}, + opts *fieldOptionsWithContext) error { + if !value.CanSet() { + return errValueNotSettable + } + + if err := validateValueRange(mapValue, opts); err != nil { + return err + } + + if field.Type.Kind() == reflect.Ptr { + baseType := Deref(field.Type) + target := reflect.New(baseType).Elem() + target.Set(reflect.ValueOf(mapValue)) + value.Set(target.Addr()) + } else { + value.Set(reflect.ValueOf(mapValue)) + } + + return nil +} + +// getValue gets the value for the specific key, the key can be in the format of parentKey.childKey +func getValue(m Valuer, key string) (interface{}, bool) { + keys := readKeys(key) + return getValueWithChainedKeys(m, keys) +} + +func getValueWithChainedKeys(m Valuer, keys []string) (interface{}, bool) { + if len(keys) == 1 { + v, ok := m.Value(keys[0]) + return v, ok + } else if len(keys) > 1 { + if v, ok := m.Value(keys[0]); ok { + if nextm, ok := v.(map[string]interface{}); ok { + return getValueWithChainedKeys(MapValuer(nextm), keys[1:]) + } + } + } + + return nil, false +} + +func insertKeys(key string, cache []string) { + cacheKeysLock.Lock() + defer cacheKeysLock.Unlock() + + keys := cacheKeys.Load().(keyCache) + // copy the contents into the new map, to guarantee the old map is immutable + newKeys := make(keyCache) + for k, v := range keys { + newKeys[k] = v + } + newKeys[key] = cache + cacheKeys.Store(newKeys) +} + +func join(elem ...string) string { + var builder strings.Builder + + var fillSep bool + for _, e := range elem { + if len(e) == 0 { + continue + } + + if fillSep { + builder.WriteByte(delimiter) + } else { + fillSep = true + } + + builder.WriteString(e) + } + + return builder.String() +} + +func newInitError(name string) error { + return fmt.Errorf("field %s is not set", name) +} + +func newTypeMismatchError(name string) error { + return fmt.Errorf("error: type mismatch for field %s", name) +} + +func readKeys(key string) []string { + cache := cacheKeys.Load().(keyCache) + if keys, ok := cache[key]; ok { + return keys + } + + keys := strings.FieldsFunc(key, func(c rune) bool { + return c == delimiter + }) + insertKeys(key, keys) + + return keys +} diff --git a/core/mapping/unmarshaler_test.go b/core/mapping/unmarshaler_test.go new file mode 100644 index 00000000..72ea2840 --- /dev/null +++ b/core/mapping/unmarshaler_test.go @@ -0,0 +1,2469 @@ +package mapping + +import ( + "encoding/json" + "strconv" + "testing" + "time" + + "zero/core/stringx" + + "github.com/stretchr/testify/assert" +) + +// because json.Number doesn't support strconv.ParseUint(...), +// so we only can test to 62 bits. 
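+// (uint fields are filled from json.Number through Int64 plus a uint64
+// conversion in processFieldPrimitiveWithJsonNumber, so 1<<62 is the largest
+// power of two that fits)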
+const maxUintBitsToTest = 62 + +func TestUnmarshalWithoutTagName(t *testing.T) { + type inner struct { + Optional bool `key:",optional"` + } + m := map[string]interface{}{ + "Optional": true, + } + + var in inner + assert.Nil(t, UnmarshalKey(m, &in)) + assert.True(t, in.Optional) +} + +func TestUnmarshalBool(t *testing.T) { + type inner struct { + True bool `key:"yes"` + False bool `key:"no"` + TrueFromOne bool `key:"yesone,string"` + FalseFromZero bool `key:"nozero,string"` + TrueFromTrue bool `key:"yestrue,string"` + FalseFromFalse bool `key:"nofalse,string"` + DefaultTrue bool `key:"defaulttrue,default=1"` + Optional bool `key:"optional,optional"` + } + m := map[string]interface{}{ + "yes": true, + "no": false, + "yesone": "1", + "nozero": "0", + "yestrue": "true", + "nofalse": "false", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.True(in.True) + ast.False(in.False) + ast.True(in.TrueFromOne) + ast.False(in.FalseFromZero) + ast.True(in.TrueFromTrue) + ast.False(in.FalseFromFalse) + ast.True(in.DefaultTrue) +} + +func TestUnmarshalDuration(t *testing.T) { + type inner struct { + Duration time.Duration `key:"duration"` + LessDuration time.Duration `key:"less"` + MoreDuration time.Duration `key:"more"` + } + m := map[string]interface{}{ + "duration": "5s", + "less": "100ms", + "more": "24h", + } + var in inner + assert.Nil(t, UnmarshalKey(m, &in)) + assert.Equal(t, time.Second*5, in.Duration) + assert.Equal(t, time.Millisecond*100, in.LessDuration) + assert.Equal(t, time.Hour*24, in.MoreDuration) +} + +func TestUnmarshalDurationDefault(t *testing.T) { + type inner struct { + Int int `key:"int"` + Duration time.Duration `key:"duration,default=5s"` + } + m := map[string]interface{}{ + "int": 5, + } + var in inner + assert.Nil(t, UnmarshalKey(m, &in)) + assert.Equal(t, 5, in.Int) + assert.Equal(t, time.Second*5, in.Duration) +} + +func TestUnmarshalDurationPtr(t *testing.T) { + type inner struct { + Duration *time.Duration `key:"duration"` + } + m := map[string]interface{}{ + "duration": "5s", + } + var in inner + assert.Nil(t, UnmarshalKey(m, &in)) + assert.Equal(t, time.Second*5, *in.Duration) +} + +func TestUnmarshalDurationPtrDefault(t *testing.T) { + type inner struct { + Int int `key:"int"` + Value *int `key:",default=5"` + Duration *time.Duration `key:"duration,default=5s"` + } + m := map[string]interface{}{ + "int": 5, + } + var in inner + assert.Nil(t, UnmarshalKey(m, &in)) + assert.Equal(t, 5, in.Int) + assert.Equal(t, 5, *in.Value) + assert.Equal(t, time.Second*5, *in.Duration) +} + +func TestUnmarshalInt(t *testing.T) { + type inner struct { + Int int `key:"int"` + IntFromStr int `key:"intstr,string"` + Int8 int8 `key:"int8"` + Int8FromStr int8 `key:"int8str,string"` + Int16 int16 `key:"int16"` + Int16FromStr int16 `key:"int16str,string"` + Int32 int32 `key:"int32"` + Int32FromStr int32 `key:"int32str,string"` + Int64 int64 `key:"int64"` + Int64FromStr int64 `key:"int64str,string"` + DefaultInt int64 `key:"defaultint,default=11"` + Optional int `key:"optional,optional"` + } + m := map[string]interface{}{ + "int": 1, + "intstr": "2", + "int8": int8(3), + "int8str": "4", + "int16": int16(5), + "int16str": "6", + "int32": int32(7), + "int32str": "8", + "int64": int64(9), + "int64str": "10", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(1, in.Int) + ast.Equal(2, in.IntFromStr) + ast.Equal(int8(3), in.Int8) + ast.Equal(int8(4), in.Int8FromStr) + ast.Equal(int16(5), in.Int16) + ast.Equal(int16(6), 
in.Int16FromStr) + ast.Equal(int32(7), in.Int32) + ast.Equal(int32(8), in.Int32FromStr) + ast.Equal(int64(9), in.Int64) + ast.Equal(int64(10), in.Int64FromStr) + ast.Equal(int64(11), in.DefaultInt) +} + +func TestUnmarshalIntPtr(t *testing.T) { + type inner struct { + Int *int `key:"int"` + } + m := map[string]interface{}{ + "int": 1, + } + + var in inner + assert.Nil(t, UnmarshalKey(m, &in)) + assert.NotNil(t, in.Int) + assert.Equal(t, 1, *in.Int) +} + +func TestUnmarshalIntWithDefault(t *testing.T) { + type inner struct { + Int int `key:"int,default=5"` + } + m := map[string]interface{}{ + "int": 1, + } + + var in inner + assert.Nil(t, UnmarshalKey(m, &in)) + assert.Equal(t, 1, in.Int) +} + +func TestUnmarshalUint(t *testing.T) { + type inner struct { + Uint uint `key:"uint"` + UintFromStr uint `key:"uintstr,string"` + Uint8 uint8 `key:"uint8"` + Uint8FromStr uint8 `key:"uint8str,string"` + Uint16 uint16 `key:"uint16"` + Uint16FromStr uint16 `key:"uint16str,string"` + Uint32 uint32 `key:"uint32"` + Uint32FromStr uint32 `key:"uint32str,string"` + Uint64 uint64 `key:"uint64"` + Uint64FromStr uint64 `key:"uint64str,string"` + DefaultUint uint `key:"defaultuint,default=11"` + Optional uint `key:"optional,optional"` + } + m := map[string]interface{}{ + "uint": uint(1), + "uintstr": "2", + "uint8": uint8(3), + "uint8str": "4", + "uint16": uint16(5), + "uint16str": "6", + "uint32": uint32(7), + "uint32str": "8", + "uint64": uint64(9), + "uint64str": "10", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(uint(1), in.Uint) + ast.Equal(uint(2), in.UintFromStr) + ast.Equal(uint8(3), in.Uint8) + ast.Equal(uint8(4), in.Uint8FromStr) + ast.Equal(uint16(5), in.Uint16) + ast.Equal(uint16(6), in.Uint16FromStr) + ast.Equal(uint32(7), in.Uint32) + ast.Equal(uint32(8), in.Uint32FromStr) + ast.Equal(uint64(9), in.Uint64) + ast.Equal(uint64(10), in.Uint64FromStr) + ast.Equal(uint(11), in.DefaultUint) +} + +func TestUnmarshalFloat(t *testing.T) { + type inner struct { + Float32 float32 `key:"float32"` + Float32Str float32 `key:"float32str,string"` + Float64 float64 `key:"float64"` + Float64Str float64 `key:"float64str,string"` + DefaultFloat float32 `key:"defaultfloat,default=5.5"` + Optional float32 `key:",optional"` + } + m := map[string]interface{}{ + "float32": float32(1.5), + "float32str": "2.5", + "float64": float64(3.5), + "float64str": "4.5", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(float32(1.5), in.Float32) + ast.Equal(float32(2.5), in.Float32Str) + ast.Equal(3.5, in.Float64) + ast.Equal(4.5, in.Float64Str) + ast.Equal(float32(5.5), in.DefaultFloat) +} + +func TestUnmarshalInt64Slice(t *testing.T) { + var v struct { + Ages []int64 `key:"ages"` + } + m := map[string]interface{}{ + "ages": []int64{1, 2}, + } + + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &v)) + ast.ElementsMatch([]int64{1, 2}, v.Ages) +} + +func TestUnmarshalIntSlice(t *testing.T) { + var v struct { + Ages []int `key:"ages"` + } + m := map[string]interface{}{ + "ages": []int{1, 2}, + } + + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &v)) + ast.ElementsMatch([]int{1, 2}, v.Ages) +} + +func TestUnmarshalString(t *testing.T) { + type inner struct { + Name string `key:"name"` + NameStr string `key:"namestr,string"` + NotPresent string `key:",optional"` + NotPresentWithTag string `key:"notpresent,optional"` + DefaultString string `key:"defaultstring,default=hello"` + Optional string `key:",optional"` + } + m := map[string]interface{}{ + "name": 
"kevin", + "namestr": "namewithstring", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal("kevin", in.Name) + ast.Equal("namewithstring", in.NameStr) + ast.Empty(in.NotPresent) + ast.Empty(in.NotPresentWithTag) + ast.Equal("hello", in.DefaultString) +} + +func TestUnmarshalStringWithMissing(t *testing.T) { + type inner struct { + Name string `key:"name"` + } + m := map[string]interface{}{} + + var in inner + assert.NotNil(t, UnmarshalKey(m, &in)) +} + +func TestUnmarshalStringSliceFromString(t *testing.T) { + var v struct { + Names []string `key:"names"` + } + m := map[string]interface{}{ + "names": `["first", "second"]`, + } + + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &v)) + ast.Equal(2, len(v.Names)) + ast.Equal("first", v.Names[0]) + ast.Equal("second", v.Names[1]) +} + +func TestUnmarshalIntSliceFromString(t *testing.T) { + var v struct { + Values []int `key:"values"` + } + m := map[string]interface{}{ + "values": `[1, 2]`, + } + + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &v)) + ast.Equal(2, len(v.Values)) + ast.Equal(1, v.Values[0]) + ast.Equal(2, v.Values[1]) +} + +func TestUnmarshalStruct(t *testing.T) { + type address struct { + City string `key:"city"` + ZipCode int `key:"zipcode,string"` + DefaultString string `key:"defaultstring,default=hello"` + Optional string `key:",optional"` + } + type inner struct { + Name string `key:"name"` + Address address `key:"address"` + } + m := map[string]interface{}{ + "name": "kevin", + "address": map[string]interface{}{ + "city": "shanghai", + "zipcode": "200000", + }, + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal("kevin", in.Name) + ast.Equal("shanghai", in.Address.City) + ast.Equal(200000, in.Address.ZipCode) + ast.Equal("hello", in.Address.DefaultString) +} + +func TestUnmarshalStructOptionalDepends(t *testing.T) { + type address struct { + City string `key:"city"` + Optional string `key:",optional"` + OptionalDepends string `key:",optional=Optional"` + } + type inner struct { + Name string `key:"name"` + Address address `key:"address"` + } + + tests := []struct { + input map[string]string + pass bool + }{ + { + pass: true, + }, + { + input: map[string]string{ + "OptionalDepends": "b", + }, + pass: false, + }, + { + input: map[string]string{ + "Optional": "a", + }, + pass: false, + }, + { + input: map[string]string{ + "Optional": "a", + "OptionalDepends": "b", + }, + pass: true, + }, + } + + for _, test := range tests { + t.Run(stringx.Rand(), func(t *testing.T) { + m := map[string]interface{}{ + "name": "kevin", + "address": map[string]interface{}{ + "city": "shanghai", + }, + } + for k, v := range test.input { + m["address"].(map[string]interface{})[k] = v + } + + var in inner + ast := assert.New(t) + if test.pass { + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal("kevin", in.Name) + ast.Equal("shanghai", in.Address.City) + ast.Equal(test.input["Optional"], in.Address.Optional) + ast.Equal(test.input["OptionalDepends"], in.Address.OptionalDepends) + } else { + ast.NotNil(UnmarshalKey(m, &in)) + } + }) + } +} + +func TestUnmarshalStructOptionalDependsNot(t *testing.T) { + type address struct { + City string `key:"city"` + Optional string `key:",optional"` + OptionalDepends string `key:",optional=!Optional"` + } + type inner struct { + Name string `key:"name"` + Address address `key:"address"` + } + + tests := []struct { + input map[string]string + pass bool + }{ + { + input: map[string]string{}, + pass: false, + }, + { + input: map[string]string{ + 
"Optional": "a", + "OptionalDepends": "b", + }, + pass: false, + }, + { + input: map[string]string{ + "Optional": "a", + }, + pass: true, + }, + { + input: map[string]string{ + "OptionalDepends": "b", + }, + pass: true, + }, + } + + for _, test := range tests { + t.Run(stringx.Rand(), func(t *testing.T) { + m := map[string]interface{}{ + "name": "kevin", + "address": map[string]interface{}{ + "city": "shanghai", + }, + } + for k, v := range test.input { + m["address"].(map[string]interface{})[k] = v + } + + var in inner + ast := assert.New(t) + if test.pass { + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal("kevin", in.Name) + ast.Equal("shanghai", in.Address.City) + ast.Equal(test.input["Optional"], in.Address.Optional) + ast.Equal(test.input["OptionalDepends"], in.Address.OptionalDepends) + } else { + ast.NotNil(UnmarshalKey(m, &in)) + } + }) + } +} + +func TestUnmarshalStructOptionalDependsNotErrorDetails(t *testing.T) { + type address struct { + Optional string `key:",optional"` + OptionalDepends string `key:",optional=!Optional"` + } + type inner struct { + Name string `key:"name"` + Address address `key:"address"` + } + + m := map[string]interface{}{ + "name": "kevin", + } + + var in inner + err := UnmarshalKey(m, &in) + assert.NotNil(t, err) +} + +func TestUnmarshalStructOptionalDependsNotNested(t *testing.T) { + type address struct { + Optional string `key:",optional"` + OptionalDepends string `key:",optional=!Optional"` + } + type combo struct { + Name string `key:"name,optional"` + Address address `key:"address"` + } + type inner struct { + Name string `key:"name"` + Combo combo `key:"combo"` + } + + m := map[string]interface{}{ + "name": "kevin", + } + + var in inner + err := UnmarshalKey(m, &in) + assert.NotNil(t, err) +} + +func TestUnmarshalStructOptionalNestedDifferentKey(t *testing.T) { + type address struct { + Optional string `dkey:",optional"` + OptionalDepends string `key:",optional"` + } + type combo struct { + Name string `key:"name,optional"` + Address address `key:"address"` + } + type inner struct { + Name string `key:"name"` + Combo combo `key:"combo"` + } + + m := map[string]interface{}{ + "name": "kevin", + } + + var in inner + assert.NotNil(t, UnmarshalKey(m, &in)) +} + +func TestUnmarshalStructOptionalDependsNotEnoughValue(t *testing.T) { + type address struct { + Optional string `key:",optional"` + OptionalDepends string `key:",optional=!"` + } + type inner struct { + Name string `key:"name"` + Address address `key:"address"` + } + + m := map[string]interface{}{ + "name": "kevin", + "address": map[string]interface{}{}, + } + + var in inner + err := UnmarshalKey(m, &in) + assert.NotNil(t, err) +} + +func TestUnmarshalAnonymousStructOptionalDepends(t *testing.T) { + type AnonAddress struct { + City string `key:"city"` + Optional string `key:",optional"` + OptionalDepends string `key:",optional=Optional"` + } + type inner struct { + Name string `key:"name"` + AnonAddress + } + + tests := []struct { + input map[string]string + pass bool + }{ + { + pass: true, + }, + { + input: map[string]string{ + "OptionalDepends": "b", + }, + pass: false, + }, + { + input: map[string]string{ + "Optional": "a", + }, + pass: false, + }, + { + input: map[string]string{ + "Optional": "a", + "OptionalDepends": "b", + }, + pass: true, + }, + } + + for _, test := range tests { + t.Run(stringx.Rand(), func(t *testing.T) { + m := map[string]interface{}{ + "name": "kevin", + "city": "shanghai", + } + for k, v := range test.input { + m[k] = v + } + + var in inner + ast := assert.New(t) + if 
test.pass { + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal("kevin", in.Name) + ast.Equal("shanghai", in.City) + ast.Equal(test.input["Optional"], in.Optional) + ast.Equal(test.input["OptionalDepends"], in.OptionalDepends) + } else { + ast.NotNil(UnmarshalKey(m, &in)) + } + }) + } +} + +func TestUnmarshalStructPtr(t *testing.T) { + type address struct { + City string `key:"city"` + ZipCode int `key:"zipcode,string"` + DefaultString string `key:"defaultstring,default=hello"` + Optional string `key:",optional"` + } + type inner struct { + Name string `key:"name"` + Address *address `key:"address"` + } + m := map[string]interface{}{ + "name": "kevin", + "address": map[string]interface{}{ + "city": "shanghai", + "zipcode": "200000", + }, + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal("kevin", in.Name) + ast.Equal("shanghai", in.Address.City) + ast.Equal(200000, in.Address.ZipCode) + ast.Equal("hello", in.Address.DefaultString) +} + +func TestUnmarshalWithStringIgnored(t *testing.T) { + type inner struct { + True bool `key:"yes"` + False bool `key:"no"` + Int int `key:"int"` + Int8 int8 `key:"int8"` + Int16 int16 `key:"int16"` + Int32 int32 `key:"int32"` + Int64 int64 `key:"int64"` + Uint uint `key:"uint"` + Uint8 uint8 `key:"uint8"` + Uint16 uint16 `key:"uint16"` + Uint32 uint32 `key:"uint32"` + Uint64 uint64 `key:"uint64"` + Float32 float32 `key:"float32"` + Float64 float64 `key:"float64"` + } + m := map[string]interface{}{ + "yes": "1", + "no": "0", + "int": "1", + "int8": "3", + "int16": "5", + "int32": "7", + "int64": "9", + "uint": "1", + "uint8": "3", + "uint16": "5", + "uint32": "7", + "uint64": "9", + "float32": "1.5", + "float64": "3.5", + } + + var in inner + um := NewUnmarshaler("key", WithStringValues()) + ast := assert.New(t) + ast.Nil(um.Unmarshal(m, &in)) + ast.True(in.True) + ast.False(in.False) + ast.Equal(1, in.Int) + ast.Equal(int8(3), in.Int8) + ast.Equal(int16(5), in.Int16) + ast.Equal(int32(7), in.Int32) + ast.Equal(int64(9), in.Int64) + ast.Equal(uint(1), in.Uint) + ast.Equal(uint8(3), in.Uint8) + ast.Equal(uint16(5), in.Uint16) + ast.Equal(uint32(7), in.Uint32) + ast.Equal(uint64(9), in.Uint64) + ast.Equal(float32(1.5), in.Float32) + ast.Equal(3.5, in.Float64) +} + +func TestUnmarshalJsonNumberInt64(t *testing.T) { + for i := 0; i <= maxUintBitsToTest; i++ { + var intValue int64 = 1 << uint(i) + strValue := strconv.FormatInt(intValue, 10) + var number = json.Number(strValue) + m := map[string]interface{}{ + "Id": number, + } + var v struct { + Id int64 + } + assert.Nil(t, UnmarshalKey(m, &v)) + assert.Equal(t, intValue, v.Id) + } +} + +func TestUnmarshalJsonNumberUint64(t *testing.T) { + for i := 0; i <= maxUintBitsToTest; i++ { + var intValue uint64 = 1 << uint(i) + strValue := strconv.FormatUint(intValue, 10) + var number = json.Number(strValue) + m := map[string]interface{}{ + "Id": number, + } + var v struct { + Id uint64 + } + assert.Nil(t, UnmarshalKey(m, &v)) + assert.Equal(t, intValue, v.Id) + } +} + +func TestUnmarshalJsonNumberUint64Ptr(t *testing.T) { + for i := 0; i <= maxUintBitsToTest; i++ { + var intValue uint64 = 1 << uint(i) + strValue := strconv.FormatUint(intValue, 10) + var number = json.Number(strValue) + m := map[string]interface{}{ + "Id": number, + } + var v struct { + Id *uint64 + } + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &v)) + ast.NotNil(v.Id) + ast.Equal(intValue, *v.Id) + } +} + +func TestUnmarshalMapOfInt(t *testing.T) { + m := map[string]interface{}{ + "Ids": map[string]bool{"first": true}, + } 
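+	// the source value already has the exact target type map[string]bool,
+	// so generateMap can return it directly instead of rebuilding it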
+ var v struct { + Ids map[string]bool + } + assert.Nil(t, UnmarshalKey(m, &v)) + assert.True(t, v.Ids["first"]) +} + +func TestUnmarshalMapOfStructError(t *testing.T) { + m := map[string]interface{}{ + "Ids": map[string]interface{}{"first": "second"}, + } + var v struct { + Ids map[string]struct { + Name string + } + } + assert.NotNil(t, UnmarshalKey(m, &v)) +} + +func TestUnmarshalSlice(t *testing.T) { + m := map[string]interface{}{ + "Ids": []interface{}{"first", "second"}, + } + var v struct { + Ids []string + } + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &v)) + ast.Equal(2, len(v.Ids)) + ast.Equal("first", v.Ids[0]) + ast.Equal("second", v.Ids[1]) +} + +func TestUnmarshalSliceOfStruct(t *testing.T) { + m := map[string]interface{}{ + "Ids": []map[string]interface{}{ + { + "First": 1, + "Second": 2, + }, + }, + } + var v struct { + Ids []struct { + First int + Second int + } + } + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &v)) + ast.Equal(1, len(v.Ids)) + ast.Equal(1, v.Ids[0].First) + ast.Equal(2, v.Ids[0].Second) +} + +func TestUnmarshalWithStringOptionsCorrect(t *testing.T) { + type inner struct { + Value string `key:"value,options=first|second"` + Correct string `key:"correct,options=1|2"` + } + m := map[string]interface{}{ + "value": "first", + "correct": "2", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal("first", in.Value) + ast.Equal("2", in.Correct) +} + +func TestUnmarshalStringOptionsWithStringOptionsNotString(t *testing.T) { + type inner struct { + Value string `key:"value,options=first|second"` + Correct string `key:"correct,options=1|2"` + } + m := map[string]interface{}{ + "value": "first", + "correct": 2, + } + + var in inner + unmarshaler := NewUnmarshaler(defaultKeyName, WithStringValues()) + ast := assert.New(t) + ast.NotNil(unmarshaler.Unmarshal(m, &in)) +} + +func TestUnmarshalStringOptionsWithStringOptions(t *testing.T) { + type inner struct { + Value string `key:"value,options=first|second"` + Correct string `key:"correct,options=1|2"` + } + m := map[string]interface{}{ + "value": "first", + "correct": "2", + } + + var in inner + unmarshaler := NewUnmarshaler(defaultKeyName, WithStringValues()) + ast := assert.New(t) + ast.Nil(unmarshaler.Unmarshal(m, &in)) + ast.Equal("first", in.Value) + ast.Equal("2", in.Correct) +} + +func TestUnmarshalStringOptionsWithStringOptionsPtr(t *testing.T) { + type inner struct { + Value *string `key:"value,options=first|second"` + Correct *int `key:"correct,options=1|2"` + } + m := map[string]interface{}{ + "value": "first", + "correct": "2", + } + + var in inner + unmarshaler := NewUnmarshaler(defaultKeyName, WithStringValues()) + ast := assert.New(t) + ast.Nil(unmarshaler.Unmarshal(m, &in)) + ast.True(*in.Value == "first") + ast.True(*in.Correct == 2) +} + +func TestUnmarshalStringOptionsWithStringOptionsIncorrect(t *testing.T) { + type inner struct { + Value string `key:"value,options=first|second"` + Correct string `key:"correct,options=1|2"` + } + m := map[string]interface{}{ + "value": "third", + "correct": "2", + } + + var in inner + unmarshaler := NewUnmarshaler(defaultKeyName, WithStringValues()) + ast := assert.New(t) + ast.NotNil(unmarshaler.Unmarshal(m, &in)) +} + +func TestUnmarshalWithStringOptionsIncorrect(t *testing.T) { + type inner struct { + Value string `key:"value,options=first|second"` + Incorrect string `key:"incorrect,options=1|2"` + } + m := map[string]interface{}{ + "value": "first", + "incorrect": "3", + } + + var in inner + assert.NotNil(t, 
UnmarshalKey(m, &in)) +} + +func TestUnmarshalWithIntOptionsCorrect(t *testing.T) { + type inner struct { + Value string `key:"value,options=first|second"` + Number int `key:"number,options=1|2"` + } + m := map[string]interface{}{ + "value": "first", + "number": 2, + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal("first", in.Value) + ast.Equal(2, in.Number) +} + +func TestUnmarshalWithIntOptionsCorrectPtr(t *testing.T) { + type inner struct { + Value *string `key:"value,options=first|second"` + Number *int `key:"number,options=1|2"` + } + m := map[string]interface{}{ + "value": "first", + "number": 2, + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.True(*in.Value == "first") + ast.True(*in.Number == 2) +} + +func TestUnmarshalWithIntOptionsIncorrect(t *testing.T) { + type inner struct { + Value string `key:"value,options=first|second"` + Incorrect int `key:"incorrect,options=1|2"` + } + m := map[string]interface{}{ + "value": "first", + "incorrect": 3, + } + + var in inner + assert.NotNil(t, UnmarshalKey(m, &in)) +} + +func TestUnmarshalWithUintOptionsCorrect(t *testing.T) { + type inner struct { + Value string `key:"value,options=first|second"` + Number uint `key:"number,options=1|2"` + } + m := map[string]interface{}{ + "value": "first", + "number": uint(2), + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal("first", in.Value) + ast.Equal(uint(2), in.Number) +} + +func TestUnmarshalWithUintOptionsIncorrect(t *testing.T) { + type inner struct { + Value string `key:"value,options=first|second"` + Incorrect uint `key:"incorrect,options=1|2"` + } + m := map[string]interface{}{ + "value": "first", + "incorrect": uint(3), + } + + var in inner + assert.NotNil(t, UnmarshalKey(m, &in)) +} + +func TestUnmarshalWithOptionsAndDefault(t *testing.T) { + type inner struct { + Value string `key:"value,options=first|second|third,default=second"` + } + m := map[string]interface{}{} + + var in inner + assert.Nil(t, UnmarshalKey(m, &in)) + assert.Equal(t, "second", in.Value) +} + +func TestUnmarshalWithOptionsAndSet(t *testing.T) { + type inner struct { + Value string `key:"value,options=first|second|third,default=second"` + } + m := map[string]interface{}{ + "value": "first", + } + + var in inner + assert.Nil(t, UnmarshalKey(m, &in)) + assert.Equal(t, "first", in.Value) +} + +func TestUnmarshalNestedKey(t *testing.T) { + var c struct { + Id int `json:"Persons.first.Id"` + } + m := map[string]interface{}{ + "Persons": map[string]interface{}{ + "first": map[string]interface{}{ + "Id": 1, + }, + }, + } + + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &c)) + assert.Equal(t, 1, c.Id) +} + +func TestUnmarhsalNestedKeyArray(t *testing.T) { + var c struct { + First []struct { + Id int + } `json:"Persons.first"` + } + m := map[string]interface{}{ + "Persons": map[string]interface{}{ + "first": []map[string]interface{}{ + {"Id": 1}, + {"Id": 2}, + }, + }, + } + + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &c)) + assert.Equal(t, 2, len(c.First)) + assert.Equal(t, 1, c.First[0].Id) +} + +func TestUnmarshalAnonymousOptionalRequiredProvided(t *testing.T) { + type ( + Foo struct { + Value string `json:"v"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{ + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousOptionalRequiredMissed(t *testing.T) { + type ( + 
Foo struct { + Value string `json:"v"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{} + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.True(t, len(b.Value) == 0) +} + +func TestUnmarshalAnonymousOptionalOptionalProvided(t *testing.T) { + type ( + Foo struct { + Value string `json:"v,optional"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{ + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousOptionalOptionalMissed(t *testing.T) { + type ( + Foo struct { + Value string `json:"v,optional"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{} + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.True(t, len(b.Value) == 0) +} + +func TestUnmarshalAnonymousOptionalRequiredBothProvided(t *testing.T) { + type ( + Foo struct { + Name string `json:"n"` + Value string `json:"v"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{ + "n": "kevin", + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "kevin", b.Name) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousOptionalRequiredOneProvidedOneMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n"` + Value string `json:"v"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{ + "v": "anything", + } + + var b Bar + assert.NotNil(t, NewUnmarshaler("json").Unmarshal(m, &b)) +} + +func TestUnmarshalAnonymousOptionalRequiredBothMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n"` + Value string `json:"v"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{} + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.True(t, len(b.Name) == 0) + assert.True(t, len(b.Value) == 0) +} + +func TestUnmarshalAnonymousOptionalOneRequiredOneOptionalBothProvided(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{ + "n": "kevin", + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "kevin", b.Name) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousOptionalOneRequiredOneOptionalBothMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{} + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.True(t, len(b.Name) == 0) + assert.True(t, len(b.Value) == 0) +} + +func TestUnmarshalAnonymousOptionalOneRequiredOneOptionalRequiredProvidedOptionalMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{ + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.True(t, len(b.Name) == 0) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousOptionalOneRequiredOneOptionalRequiredMissedOptionalProvided(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v"` + } + + Bar 
struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{ + "n": "anything", + } + + var b Bar + assert.NotNil(t, NewUnmarshaler("json").Unmarshal(m, &b)) +} + +func TestUnmarshalAnonymousOptionalBothOptionalBothProvided(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v,optional"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{ + "n": "kevin", + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "kevin", b.Name) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousOptionalBothOptionalOneProvidedOneMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v,optional"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{ + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.True(t, len(b.Name) == 0) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousOptionalBothOptionalBothMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v,optional"` + } + + Bar struct { + Foo `json:",optional"` + } + ) + m := map[string]interface{}{} + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.True(t, len(b.Name) == 0) + assert.True(t, len(b.Value) == 0) +} + +func TestUnmarshalAnonymousRequiredProvided(t *testing.T) { + type ( + Foo struct { + Value string `json:"v"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{ + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousRequiredMissed(t *testing.T) { + type ( + Foo struct { + Value string `json:"v"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{} + + var b Bar + assert.NotNil(t, NewUnmarshaler("json").Unmarshal(m, &b)) +} + +func TestUnmarshalAnonymousOptionalProvided(t *testing.T) { + type ( + Foo struct { + Value string `json:"v,optional"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{ + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousOptionalMissed(t *testing.T) { + type ( + Foo struct { + Value string `json:"v,optional"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{} + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.True(t, len(b.Value) == 0) +} + +func TestUnmarshalAnonymousRequiredBothProvided(t *testing.T) { + type ( + Foo struct { + Name string `json:"n"` + Value string `json:"v"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{ + "n": "kevin", + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "kevin", b.Name) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousRequiredOneProvidedOneMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n"` + Value string `json:"v"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{ + "v": "anything", + } + + var b Bar + assert.NotNil(t, NewUnmarshaler("json").Unmarshal(m, &b)) +} + +func TestUnmarshalAnonymousRequiredBothMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n"` + Value string `json:"v"` + } + + Bar struct { + Foo + } + ) + m := 
map[string]interface{}{ + "v": "anything", + } + + var b Bar + assert.NotNil(t, NewUnmarshaler("json").Unmarshal(m, &b)) +} + +func TestUnmarshalAnonymousOneRequiredOneOptionalBothProvided(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{ + "n": "kevin", + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "kevin", b.Name) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousOneRequiredOneOptionalBothMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{} + + var b Bar + assert.NotNil(t, NewUnmarshaler("json").Unmarshal(m, &b)) +} + +func TestUnmarshalAnonymousOneRequiredOneOptionalRequiredProvidedOptionalMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{ + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.True(t, len(b.Name) == 0) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousOneRequiredOneOptionalRequiredMissedOptionalProvided(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{ + "n": "anything", + } + + var b Bar + assert.NotNil(t, NewUnmarshaler("json").Unmarshal(m, &b)) +} + +func TestUnmarshalAnonymousBothOptionalBothProvided(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v,optional"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{ + "n": "kevin", + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "kevin", b.Name) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousBothOptionalOneProvidedOneMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v,optional"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{ + "v": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.True(t, len(b.Name) == 0) + assert.Equal(t, "anything", b.Value) +} + +func TestUnmarshalAnonymousBothOptionalBothMissed(t *testing.T) { + type ( + Foo struct { + Name string `json:"n,optional"` + Value string `json:"v,optional"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{} + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.True(t, len(b.Name) == 0) + assert.True(t, len(b.Value) == 0) +} + +func TestUnmarshalAnonymousWrappedToMuch(t *testing.T) { + type ( + Foo struct { + Name string `json:"n"` + Value string `json:"v"` + } + + Bar struct { + Foo + } + ) + m := map[string]interface{}{ + "Foo": map[string]interface{}{ + "n": "name", + "v": "anything", + }, + } + + var b Bar + assert.NotNil(t, NewUnmarshaler("json").Unmarshal(m, &b)) +} + +func TestUnmarshalWrappedObject(t *testing.T) { + type ( + Foo struct { + Value string `json:"v"` + } + + Bar struct { + Inner Foo + } + ) + m := map[string]interface{}{ + "Inner": map[string]interface{}{ + "v": "anything", + }, + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "anything", b.Inner.Value) +} + +func 
TestUnmarshalWrappedObjectOptional(t *testing.T) { + type ( + Foo struct { + Hosts []string + Key string + } + + Bar struct { + Inner Foo `json:",optional"` + Name string + } + ) + m := map[string]interface{}{ + "Name": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "anything", b.Name) +} + +func TestUnmarshalWrappedObjectOptionalFilled(t *testing.T) { + type ( + Foo struct { + Hosts []string + Key string + } + + Bar struct { + Inner Foo `json:",optional"` + Name string + } + ) + hosts := []string{"1", "2"} + m := map[string]interface{}{ + "Inner": map[string]interface{}{ + "Hosts": hosts, + "Key": "key", + }, + "Name": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.EqualValues(t, hosts, b.Inner.Hosts) + assert.Equal(t, "key", b.Inner.Key) + assert.Equal(t, "anything", b.Name) +} + +func TestUnmarshalWrappedNamedObjectOptional(t *testing.T) { + type ( + Foo struct { + Host string + Key string + } + + Bar struct { + Inner Foo `json:",optional"` + Name string + } + ) + m := map[string]interface{}{ + "Inner": map[string]interface{}{ + "Host": "thehost", + "Key": "thekey", + }, + "Name": "anything", + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "thehost", b.Inner.Host) + assert.Equal(t, "thekey", b.Inner.Key) + assert.Equal(t, "anything", b.Name) +} + +func TestUnmarshalWrappedObjectNamedPtr(t *testing.T) { + type ( + Foo struct { + Value string `json:"v"` + } + + Bar struct { + Inner *Foo `json:"foo,optional"` + } + ) + m := map[string]interface{}{ + "foo": map[string]interface{}{ + "v": "anything", + }, + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "anything", b.Inner.Value) +} + +func TestUnmarshalWrappedObjectPtr(t *testing.T) { + type ( + Foo struct { + Value string `json:"v"` + } + + Bar struct { + Inner *Foo + } + ) + m := map[string]interface{}{ + "Inner": map[string]interface{}{ + "v": "anything", + }, + } + + var b Bar + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &b)) + assert.Equal(t, "anything", b.Inner.Value) +} + +func TestUnmarshalInt2String(t *testing.T) { + type inner struct { + Int string `key:"int"` + } + m := map[string]interface{}{ + "int": 123, + } + + var in inner + assert.NotNil(t, UnmarshalKey(m, &in)) +} + +func TestUnmarshalZeroValues(t *testing.T) { + type inner struct { + False bool `key:"no"` + Int int `key:"int"` + String string `key:"string"` + } + m := map[string]interface{}{ + "no": false, + "int": 0, + "string": "", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.False(in.False) + ast.Equal(0, in.Int) + ast.Equal("", in.String) +} + +func TestUnmarshalUsingDifferentKeys(t *testing.T) { + type inner struct { + False bool `key:"no"` + Int int `key:"int"` + String string `bson:"string"` + } + m := map[string]interface{}{ + "no": false, + "int": 9, + "string": "value", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.False(in.False) + ast.Equal(9, in.Int) + ast.True(len(in.String) == 0) +} + +func TestUnmarshalNumberRangeInt(t *testing.T) { + type inner struct { + Value1 int `key:"value1,range=[1:]"` + Value2 int8 `key:"value2,range=[1:5]"` + Value3 int16 `key:"value3,range=[1:5]"` + Value4 int32 `key:"value4,range=[1:5]"` + Value5 int64 `key:"value5,range=[1:5]"` + Value6 uint `key:"value6,range=[:5]"` + Value8 uint8 `key:"value8,range=[1:5],string"` + Value9 uint16 
`key:"value9,range=[1:5],string"` + Value10 uint32 `key:"value10,range=[1:5],string"` + Value11 uint64 `key:"value11,range=[1:5],string"` + } + m := map[string]interface{}{ + "value1": 10, + "value2": int8(1), + "value3": int16(2), + "value4": int32(4), + "value5": int64(5), + "value6": uint(0), + "value8": "1", + "value9": "2", + "value10": "4", + "value11": "5", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(10, in.Value1) + ast.Equal(int8(1), in.Value2) + ast.Equal(int16(2), in.Value3) + ast.Equal(int32(4), in.Value4) + ast.Equal(int64(5), in.Value5) + ast.Equal(uint(0), in.Value6) + ast.Equal(uint8(1), in.Value8) + ast.Equal(uint16(2), in.Value9) + ast.Equal(uint32(4), in.Value10) + ast.Equal(uint64(5), in.Value11) +} + +func TestUnmarshalNumberRangeJsonNumber(t *testing.T) { + type inner struct { + Value3 uint `key:"value3,range=(1:5]"` + Value4 uint8 `key:"value4,range=(1:5]"` + Value5 uint16 `key:"value5,range=(1:5]"` + } + m := map[string]interface{}{ + "value3": json.Number("2"), + "value4": json.Number("4"), + "value5": json.Number("5"), + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(uint(2), in.Value3) + ast.Equal(uint8(4), in.Value4) + ast.Equal(uint16(5), in.Value5) + + type inner1 struct { + Value int `key:"value,range=(1:5]"` + } + m = map[string]interface{}{ + "value": json.Number("a"), + } + + var in1 inner1 + ast.NotNil(UnmarshalKey(m, &in1)) +} + +func TestUnmarshalNumberRangeIntLeftExclude(t *testing.T) { + type inner struct { + Value3 uint `key:"value3,range=(1:5]"` + Value4 uint32 `key:"value4,default=4,range=(1:5]"` + Value5 uint64 `key:"value5,range=(1:5]"` + Value9 int `key:"value9,range=(1:5],string"` + Value10 int `key:"value10,range=(1:5],string"` + Value11 int `key:"value11,range=(1:5],string"` + } + m := map[string]interface{}{ + "value3": uint(2), + "value4": uint32(4), + "value5": uint64(5), + "value9": "2", + "value10": "4", + "value11": "5", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(uint(2), in.Value3) + ast.Equal(uint32(4), in.Value4) + ast.Equal(uint64(5), in.Value5) + ast.Equal(2, in.Value9) + ast.Equal(4, in.Value10) + ast.Equal(5, in.Value11) +} + +func TestUnmarshalNumberRangeIntRightExclude(t *testing.T) { + type inner struct { + Value2 uint `key:"value2,range=[1:5)"` + Value3 uint8 `key:"value3,range=[1:5)"` + Value4 uint16 `key:"value4,range=[1:5)"` + Value8 int `key:"value8,range=[1:5),string"` + Value9 int `key:"value9,range=[1:5),string"` + Value10 int `key:"value10,range=[1:5),string"` + } + m := map[string]interface{}{ + "value2": uint(1), + "value3": uint8(2), + "value4": uint16(4), + "value8": "1", + "value9": "2", + "value10": "4", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(uint(1), in.Value2) + ast.Equal(uint8(2), in.Value3) + ast.Equal(uint16(4), in.Value4) + ast.Equal(1, in.Value8) + ast.Equal(2, in.Value9) + ast.Equal(4, in.Value10) +} + +func TestUnmarshalNumberRangeIntExclude(t *testing.T) { + type inner struct { + Value3 int `key:"value3,range=(1:5)"` + Value4 int `key:"value4,range=(1:5)"` + Value9 int `key:"value9,range=(1:5),string"` + Value10 int `key:"value10,range=(1:5),string"` + } + m := map[string]interface{}{ + "value3": 2, + "value4": 4, + "value9": "2", + "value10": "4", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(2, in.Value3) + ast.Equal(4, in.Value4) + ast.Equal(2, in.Value9) + ast.Equal(4, in.Value10) +} 
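+
+// Illustrative sketch, not part of the original suite: the conf struct and its
+// field and key names below are invented for the example. It shows how the
+// range option reads for callers of UnmarshalKey: '[' and ']' include an
+// endpoint, '(' and ')' exclude it, and a missing bound leaves that side open.
+// Out-of-range values are rejected with errNumberRange, as the tests below verify.
+func TestUnmarshalNumberRangeIllustrativeSketch(t *testing.T) {
+	type conf struct {
+		Port   int     `key:"port,range=[1:65535]"`
+		Weight float64 `key:"weight,optional,range=(0:1]"`
+		Count  int     `key:"count,range=[0:]"`
+	}
+	m := map[string]interface{}{
+		"port":   8080,
+		"weight": 0.5,
+		"count":  3,
+	}
+
+	var c conf
+	ast := assert.New(t)
+	ast.Nil(UnmarshalKey(m, &c))
+	ast.Equal(8080, c.Port)
+	ast.Equal(0.5, c.Weight)
+	ast.Equal(3, c.Count)
+}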
+ +func TestUnmarshalNumberRangeIntOutOfRange(t *testing.T) { + type inner1 struct { + Value int64 `key:"value,default=3,range=(1:5)"` + } + + var in1 inner1 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": int64(1), + }, &in1)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": int64(0), + }, &in1)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": int64(5), + }, &in1)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": json.Number("6"), + }, &in1)) + + type inner2 struct { + Value int64 `key:"value,optional,range=[1:5)"` + } + + var in2 inner2 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": int64(0), + }, &in2)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": int64(5), + }, &in2)) + + type inner3 struct { + Value int64 `key:"value,range=(1:5]"` + } + + var in3 inner3 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": int64(1), + }, &in3)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": int64(6), + }, &in3)) + + type inner4 struct { + Value int64 `key:"value,range=[1:5]"` + } + + var in4 inner4 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": int64(0), + }, &in4)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": int64(6), + }, &in4)) +} + +func TestUnmarshalNumberRangeFloat(t *testing.T) { + type inner struct { + Value2 float32 `key:"value2,range=[1:5]"` + Value3 float32 `key:"value3,range=[1:5]"` + Value4 float64 `key:"value4,range=[1:5]"` + Value5 float64 `key:"value5,range=[1:5]"` + Value8 float64 `key:"value8,range=[1:5],string"` + Value9 float64 `key:"value9,range=[1:5],string"` + Value10 float64 `key:"value10,range=[1:5],string"` + Value11 float64 `key:"value11,range=[1:5],string"` + } + m := map[string]interface{}{ + "value2": float32(1), + "value3": float32(2), + "value4": float64(4), + "value5": float64(5), + "value8": "1", + "value9": "2", + "value10": "4", + "value11": "5", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(float32(1), in.Value2) + ast.Equal(float32(2), in.Value3) + ast.Equal(float64(4), in.Value4) + ast.Equal(float64(5), in.Value5) + ast.Equal(float64(1), in.Value8) + ast.Equal(float64(2), in.Value9) + ast.Equal(float64(4), in.Value10) + ast.Equal(float64(5), in.Value11) +} + +func TestUnmarshalNumberRangeFloatLeftExclude(t *testing.T) { + type inner struct { + Value3 float64 `key:"value3,range=(1:5]"` + Value4 float64 `key:"value4,range=(1:5]"` + Value5 float64 `key:"value5,range=(1:5]"` + Value9 float64 `key:"value9,range=(1:5],string"` + Value10 float64 `key:"value10,range=(1:5],string"` + Value11 float64 `key:"value11,range=(1:5],string"` + } + m := map[string]interface{}{ + "value3": float64(2), + "value4": float64(4), + "value5": float64(5), + "value9": "2", + "value10": "4", + "value11": "5", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(float64(2), in.Value3) + ast.Equal(float64(4), in.Value4) + ast.Equal(float64(5), in.Value5) + ast.Equal(float64(2), in.Value9) + ast.Equal(float64(4), in.Value10) + ast.Equal(float64(5), in.Value11) +} + +func TestUnmarshalNumberRangeFloatRightExclude(t *testing.T) { + type inner struct { + Value2 float64 `key:"value2,range=[1:5)"` + Value3 float64 `key:"value3,range=[1:5)"` + Value4 float64 `key:"value4,range=[1:5)"` + 
Value8 float64 `key:"value8,range=[1:5),string"` + Value9 float64 `key:"value9,range=[1:5),string"` + Value10 float64 `key:"value10,range=[1:5),string"` + } + m := map[string]interface{}{ + "value2": float64(1), + "value3": float64(2), + "value4": float64(4), + "value8": "1", + "value9": "2", + "value10": "4", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(float64(1), in.Value2) + ast.Equal(float64(2), in.Value3) + ast.Equal(float64(4), in.Value4) + ast.Equal(float64(1), in.Value8) + ast.Equal(float64(2), in.Value9) + ast.Equal(float64(4), in.Value10) +} + +func TestUnmarshalNumberRangeFloatExclude(t *testing.T) { + type inner struct { + Value3 float64 `key:"value3,range=(1:5)"` + Value4 float64 `key:"value4,range=(1:5)"` + Value9 float64 `key:"value9,range=(1:5),string"` + Value10 float64 `key:"value10,range=(1:5),string"` + } + m := map[string]interface{}{ + "value3": float64(2), + "value4": float64(4), + "value9": "2", + "value10": "4", + } + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalKey(m, &in)) + ast.Equal(float64(2), in.Value3) + ast.Equal(float64(4), in.Value4) + ast.Equal(float64(2), in.Value9) + ast.Equal(float64(4), in.Value10) +} + +func TestUnmarshalNumberRangeFloatOutOfRange(t *testing.T) { + type inner1 struct { + Value float64 `key:"value,range=(1:5)"` + } + + var in1 inner1 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": float64(1), + }, &in1)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": float64(0), + }, &in1)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": float64(5), + }, &in1)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": json.Number("6"), + }, &in1)) + + type inner2 struct { + Value float64 `key:"value,range=[1:5)"` + } + + var in2 inner2 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": float64(0), + }, &in2)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": float64(5), + }, &in2)) + + type inner3 struct { + Value float64 `key:"value,range=(1:5]"` + } + + var in3 inner3 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": float64(1), + }, &in3)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": float64(6), + }, &in3)) + + type inner4 struct { + Value float64 `key:"value,range=[1:5]"` + } + + var in4 inner4 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": float64(0), + }, &in4)) + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "value": float64(6), + }, &in4)) +} + +func TestUnmarshalRangeError(t *testing.T) { + type inner1 struct { + Value int `key:",range="` + } + var in1 inner1 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "Value": 1, + }, &in1)) + + type inner2 struct { + Value int `key:",range=["` + } + var in2 inner2 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "Value": 1, + }, &in2)) + + type inner3 struct { + Value int `key:",range=[:"` + } + var in3 inner3 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "Value": 1, + }, &in3)) + + type inner4 struct { + Value int `key:",range=[:]"` + } + var in4 inner4 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "Value": 1, + }, &in4)) + + type inner5 struct { + Value int `key:",range={:]"` + } + var in5 inner5 + assert.Equal(t, errNumberRange, 
UnmarshalKey(map[string]interface{}{ + "Value": 1, + }, &in5)) + + type inner6 struct { + Value int `key:",range=[:}"` + } + var in6 inner6 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "Value": 1, + }, &in6)) + + type inner7 struct { + Value int `key:",range=[]"` + } + var in7 inner7 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "Value": 1, + }, &in7)) + + type inner8 struct { + Value int `key:",range=[a:]"` + } + var in8 inner8 + assert.NotNil(t, UnmarshalKey(map[string]interface{}{ + "Value": 1, + }, &in8)) + + type inner9 struct { + Value int `key:",range=[:a]"` + } + var in9 inner9 + assert.NotNil(t, UnmarshalKey(map[string]interface{}{ + "Value": 1, + }, &in9)) + + type inner10 struct { + Value int `key:",range"` + } + var in10 inner10 + assert.NotNil(t, UnmarshalKey(map[string]interface{}{ + "Value": 1, + }, &in10)) + + type inner11 struct { + Value int `key:",range=[1,2]"` + } + var in11 inner11 + assert.Equal(t, errNumberRange, UnmarshalKey(map[string]interface{}{ + "Value": "a", + }, &in11)) +} + +func TestUnmarshalNestedMap(t *testing.T) { + var c struct { + Anything map[string]map[string]string `json:"anything"` + } + m := map[string]interface{}{ + "anything": map[string]map[string]interface{}{ + "inner": { + "id": "1", + "name": "any", + }, + }, + } + + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &c)) + assert.Equal(t, "1", c.Anything["inner"]["id"]) +} + +func TestUnmarshalNestedMapMismatch(t *testing.T) { + var c struct { + Anything map[string]map[string]map[string]string `json:"anything"` + } + m := map[string]interface{}{ + "anything": map[string]map[string]interface{}{ + "inner": { + "name": "any", + }, + }, + } + + assert.NotNil(t, NewUnmarshaler("json").Unmarshal(m, &c)) +} + +func TestUnmarshalNestedMapSimple(t *testing.T) { + var c struct { + Anything map[string]string `json:"anything"` + } + m := map[string]interface{}{ + "anything": map[string]interface{}{ + "id": "1", + "name": "any", + }, + } + + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &c)) + assert.Equal(t, "1", c.Anything["id"]) +} + +func TestUnmarshalNestedMapSimpleTypeMatch(t *testing.T) { + var c struct { + Anything map[string]string `json:"anything"` + } + m := map[string]interface{}{ + "anything": map[string]string{ + "id": "1", + "name": "any", + }, + } + + assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &c)) + assert.Equal(t, "1", c.Anything["id"]) +} + +func BenchmarkUnmarshalString(b *testing.B) { + type inner struct { + Value string `key:"value"` + } + m := map[string]interface{}{ + "value": "first", + } + + for i := 0; i < b.N; i++ { + var in inner + if err := UnmarshalKey(m, &in); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkUnmarshalStruct(b *testing.B) { + b.ReportAllocs() + + m := map[string]interface{}{ + "Ids": []map[string]interface{}{ + { + "First": 1, + "Second": 2, + }, + }, + } + + for i := 0; i < b.N; i++ { + var v struct { + Ids []struct { + First int + Second int + } + } + if err := UnmarshalKey(m, &v); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkMapToStruct(b *testing.B) { + data := map[string]interface{}{ + "valid": "1", + "age": "5", + "name": "liao", + } + type anonymous struct { + Valid bool + Age int + Name string + } + + for i := 0; i < b.N; i++ { + var an anonymous + if valid, ok := data["valid"]; ok { + an.Valid = valid == "1" + } + if age, ok := data["age"]; ok { + ages, _ := age.(string) + an.Age, _ = strconv.Atoi(ages) + } + if name, ok := data["name"]; ok { + names, _ := 
name.(string) + an.Name = names + } + } +} + +func BenchmarkUnmarshal(b *testing.B) { + data := map[string]interface{}{ + "valid": "1", + "age": "5", + "name": "liao", + } + type anonymous struct { + Valid bool `key:"valid,string"` + Age int `key:"age,string"` + Name string `key:"name"` + } + + for i := 0; i < b.N; i++ { + var an anonymous + UnmarshalKey(data, &an) + } +} diff --git a/core/mapping/utils.go b/core/mapping/utils.go new file mode 100644 index 00000000..fb9c67e9 --- /dev/null +++ b/core/mapping/utils.go @@ -0,0 +1,513 @@ +package mapping + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "reflect" + "strconv" + "strings" + "sync" + + "zero/core/stringx" +) + +const ( + defaultOption = "default" + stringOption = "string" + optionalOption = "optional" + optionsOption = "options" + rangeOption = "range" + optionSeparator = "|" + equalToken = "=" +) + +var ( + errUnsupportedType = errors.New("unsupported type on setting field value") + errNumberRange = errors.New("wrong number range setting") + optionsCache = make(map[string]optionsCacheValue) + cacheLock sync.RWMutex + structRequiredCache = make(map[reflect.Type]requiredCacheValue) + structCacheLock sync.RWMutex +) + +type ( + optionsCacheValue struct { + key string + options *fieldOptions + err error + } + + requiredCacheValue struct { + required bool + err error + } +) + +func Deref(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + return t +} + +func Repr(v interface{}) string { + if v == nil { + return "" + } + + // if func (v *Type) String() string, we can't use Elem() + switch vt := v.(type) { + case fmt.Stringer: + return vt.String() + } + + val := reflect.ValueOf(v) + if val.Kind() == reflect.Ptr && !val.IsNil() { + val = val.Elem() + } + + switch vt := val.Interface().(type) { + case bool: + return strconv.FormatBool(vt) + case error: + return vt.Error() + case float32: + return strconv.FormatFloat(float64(vt), 'f', -1, 32) + case float64: + return strconv.FormatFloat(vt, 'f', -1, 64) + case fmt.Stringer: + return vt.String() + case int: + return strconv.Itoa(vt) + case int8: + return strconv.Itoa(int(vt)) + case int16: + return strconv.Itoa(int(vt)) + case int32: + return strconv.Itoa(int(vt)) + case int64: + return strconv.FormatInt(vt, 10) + case string: + return vt + case uint: + return strconv.FormatUint(uint64(vt), 10) + case uint8: + return strconv.FormatUint(uint64(vt), 10) + case uint16: + return strconv.FormatUint(uint64(vt), 10) + case uint32: + return strconv.FormatUint(uint64(vt), 10) + case uint64: + return strconv.FormatUint(vt, 10) + case []byte: + return string(vt) + default: + return fmt.Sprint(val.Interface()) + } +} + +func ValidatePtr(v *reflect.Value) error { + // sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr, + // panic otherwise + if !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() { + return fmt.Errorf("not a valid pointer: %v", v) + } + + return nil +} + +func convertType(kind reflect.Kind, str string) (interface{}, error) { + switch kind { + case reflect.Bool: + return str == "1" || strings.ToLower(str) == "true", nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if intValue, err := strconv.ParseInt(str, 10, 64); err != nil { + return 0, fmt.Errorf("the value %q cannot parsed as int", str) + } else { + return intValue, nil + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if uintValue, err := strconv.ParseUint(str, 10, 64); err != 
nil { + return 0, fmt.Errorf("the value %q cannot parsed as uint", str) + } else { + return uintValue, nil + } + case reflect.Float32, reflect.Float64: + if floatValue, err := strconv.ParseFloat(str, 64); err != nil { + return 0, fmt.Errorf("the value %q cannot parsed as float", str) + } else { + return floatValue, nil + } + case reflect.String: + return str, nil + default: + return nil, errUnsupportedType + } +} + +func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fieldOptions, error) { + segments := strings.Split(value, ",") + key := strings.TrimSpace(segments[0]) + options := segments[1:] + + if len(options) > 0 { + var fieldOpts fieldOptions + + for _, segment := range options { + option := strings.TrimSpace(segment) + switch { + case option == stringOption: + fieldOpts.FromString = true + case strings.HasPrefix(option, optionalOption): + segs := strings.Split(option, equalToken) + switch len(segs) { + case 1: + fieldOpts.Optional = true + case 2: + fieldOpts.Optional = true + fieldOpts.OptionalDep = segs[1] + default: + return "", nil, fmt.Errorf("field %s has wrong optional", field.Name) + } + case option == optionalOption: + fieldOpts.Optional = true + case strings.HasPrefix(option, optionsOption): + segs := strings.Split(option, equalToken) + if len(segs) != 2 { + return "", nil, fmt.Errorf("field %s has wrong options", field.Name) + } else { + fieldOpts.Options = strings.Split(segs[1], optionSeparator) + } + case strings.HasPrefix(option, defaultOption): + segs := strings.Split(option, equalToken) + if len(segs) != 2 { + return "", nil, fmt.Errorf("field %s has wrong default option", field.Name) + } else { + fieldOpts.Default = strings.TrimSpace(segs[1]) + } + case strings.HasPrefix(option, rangeOption): + segs := strings.Split(option, equalToken) + if len(segs) != 2 { + return "", nil, fmt.Errorf("field %s has wrong range", field.Name) + } + if nr, err := parseNumberRange(segs[1]); err != nil { + return "", nil, err + } else { + fieldOpts.Range = nr + } + } + } + + return key, &fieldOpts, nil + } + + return key, nil, nil +} + +func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) { + numFields := tp.NumField() + for i := 0; i < numFields; i++ { + childField := tp.Field(i) + if usingDifferentKeys(tag, childField) { + return true, nil + } + + _, opts, err := parseKeyAndOptions(tag, childField) + if err != nil { + return false, err + } + + if opts == nil { + if childField.Type.Kind() != reflect.Struct { + return true, nil + } + + if required, err := implicitValueRequiredStruct(tag, childField.Type); err != nil { + return false, err + } else if required { + return true, nil + } + } else if !opts.Optional && len(opts.Default) == 0 { + return true, nil + } else if len(opts.OptionalDep) > 0 && opts.OptionalDep[0] == notSymbol { + return true, nil + } + } + + return false, nil +} + +func maybeNewValue(field reflect.StructField, value reflect.Value) { + if field.Type.Kind() == reflect.Ptr && value.IsNil() { + value.Set(reflect.New(value.Type().Elem())) + } +} + +// don't modify returned fieldOptions, it's cached and shared among different calls. 
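+// Editorial illustration (not part of the original source): the options parsed by
+// doParseKeyAndOptions above are normally combined in struct tags, assuming the
+// "key" tag name used throughout this package, e.g.
+//
+//	type exampleConf struct {
+//		Host string `key:"host"`
+//		Port int    `key:"port,default=8080,range=[1:65535]"`
+//		Mode string `key:"mode,options=dev|prod"`
+//		Note string `key:"note,optional"`
+//	}
+//
+// The first comma-separated segment is the key; each remaining segment is folded
+// into fieldOptions (FromString, Optional/OptionalDep, Options, Default, Range).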
+func parseKeyAndOptions(tagName string, field reflect.StructField) (string, *fieldOptions, error) { + value := field.Tag.Get(tagName) + if len(value) == 0 { + return field.Name, nil, nil + } + + cacheLock.RLock() + cache, ok := optionsCache[value] + cacheLock.RUnlock() + if ok { + return stringx.TakeOne(cache.key, field.Name), cache.options, cache.err + } + + key, options, err := doParseKeyAndOptions(field, value) + cacheLock.Lock() + optionsCache[value] = optionsCacheValue{ + key: key, + options: options, + err: err, + } + cacheLock.Unlock() + + return stringx.TakeOne(key, field.Name), options, err +} + +// support below notations: +// [:5] (:5] [:5) (:5) +// [1:] [1:) (1:] (1:) +// [1:5] [1:5) (1:5] (1:5) +func parseNumberRange(str string) (*numberRange, error) { + if len(str) == 0 { + return nil, errNumberRange + } + + var leftInclude bool + switch str[0] { + case '[': + leftInclude = true + case '(': + leftInclude = false + default: + return nil, errNumberRange + } + + str = str[1:] + if len(str) == 0 { + return nil, errNumberRange + } + + var rightInclude bool + switch str[len(str)-1] { + case ']': + rightInclude = true + case ')': + rightInclude = false + default: + return nil, errNumberRange + } + + str = str[:len(str)-1] + fields := strings.Split(str, ":") + if len(fields) != 2 { + return nil, errNumberRange + } + + if len(fields[0]) == 0 && len(fields[1]) == 0 { + return nil, errNumberRange + } + + var left float64 + if len(fields[0]) > 0 { + var err error + if left, err = strconv.ParseFloat(fields[0], 64); err != nil { + return nil, err + } + } else { + left = -math.MaxFloat64 + } + + var right float64 + if len(fields[1]) > 0 { + var err error + if right, err = strconv.ParseFloat(fields[1], 64); err != nil { + return nil, err + } + } else { + right = math.MaxFloat64 + } + + return &numberRange{ + left: left, + leftInclude: leftInclude, + right: right, + rightInclude: rightInclude, + }, nil +} + +func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interface{}) error { + switch kind { + case reflect.Bool: + value.SetBool(v.(bool)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + value.SetInt(v.(int64)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + value.SetUint(v.(uint64)) + case reflect.Float32, reflect.Float64: + value.SetFloat(v.(float64)) + case reflect.String: + value.SetString(v.(string)) + default: + return errUnsupportedType + } + + return nil +} + +func setValue(kind reflect.Kind, value reflect.Value, str string) error { + if !value.CanSet() { + return errValueNotSettable + } + + v, err := convertType(kind, str) + if err != nil { + return err + } + + return setMatchedPrimitiveValue(kind, value, v) +} + +func structValueRequired(tag string, tp reflect.Type) (bool, error) { + structCacheLock.RLock() + val, ok := structRequiredCache[tp] + structCacheLock.RUnlock() + if ok { + return val.required, val.err + } + + required, err := implicitValueRequiredStruct(tag, tp) + structCacheLock.Lock() + structRequiredCache[tp] = requiredCacheValue{ + required: required, + err: err, + } + structCacheLock.Unlock() + + return required, err +} + +func toFloat64(v interface{}) (float64, bool) { + switch val := v.(type) { + case int: + return float64(val), true + case int8: + return float64(val), true + case int16: + return float64(val), true + case int32: + return float64(val), true + case int64: + return float64(val), true + case uint: + return float64(val), true + case uint8: + return float64(val), 
true + case uint16: + return float64(val), true + case uint32: + return float64(val), true + case uint64: + return float64(val), true + case float32: + return float64(val), true + case float64: + return val, true + default: + return 0, false + } +} + +func usingDifferentKeys(key string, field reflect.StructField) bool { + if len(field.Tag) > 0 { + if _, ok := field.Tag.Lookup(key); !ok { + return true + } + } + + return false +} + +func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string, opts *fieldOptionsWithContext) error { + if !value.CanSet() { + return errValueNotSettable + } + + v, err := convertType(kind, str) + if err != nil { + return err + } + + if err := validateValueRange(v, opts); err != nil { + return err + } + + return setMatchedPrimitiveValue(kind, value, v) +} + +func validateJsonNumberRange(v json.Number, opts *fieldOptionsWithContext) error { + if opts == nil || opts.Range == nil { + return nil + } + + fv, err := v.Float64() + if err != nil { + return err + } + + return validateNumberRange(fv, opts.Range) +} + +func validateNumberRange(fv float64, nr *numberRange) error { + if nr == nil { + return nil + } + + if (nr.leftInclude && fv < nr.left) || (!nr.leftInclude && fv <= nr.left) { + return errNumberRange + } + + if (nr.rightInclude && fv > nr.right) || (!nr.rightInclude && fv >= nr.right) { + return errNumberRange + } + + return nil +} + +func validateValueInOptions(options []string, value interface{}) error { + if len(options) > 0 { + switch v := value.(type) { + case string: + if !stringx.Contains(options, v) { + return fmt.Errorf(`error: value "%s" is not defined in options "%v"`, v, options) + } + default: + if !stringx.Contains(options, Repr(v)) { + return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, value, options) + } + } + } + + return nil +} + +func validateValueRange(mapValue interface{}, opts *fieldOptionsWithContext) error { + if opts == nil || opts.Range == nil { + return nil + } + + fv, ok := toFloat64(mapValue) + if !ok { + return errNumberRange + } + + return validateNumberRange(fv, opts.Range) +} diff --git a/core/mapping/utils_test.go b/core/mapping/utils_test.go new file mode 100644 index 00000000..6e56929c --- /dev/null +++ b/core/mapping/utils_test.go @@ -0,0 +1,295 @@ +package mapping + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +const testTagName = "key" + +type Foo struct { + Str string + StrWithTag string `key:"stringwithtag"` + StrWithTagAndOption string `key:"stringwithtag,string"` +} + +func TestDeferInt(t *testing.T) { + var i = 1 + var s = "hello" + number := struct { + f float64 + }{ + f: 6.4, + } + cases := []struct { + t reflect.Type + expect reflect.Kind + }{ + { + t: reflect.TypeOf(i), + expect: reflect.Int, + }, + { + t: reflect.TypeOf(&i), + expect: reflect.Int, + }, + { + t: reflect.TypeOf(s), + expect: reflect.String, + }, + { + t: reflect.TypeOf(&s), + expect: reflect.String, + }, + { + t: reflect.TypeOf(number.f), + expect: reflect.Float64, + }, + { + t: reflect.TypeOf(&number.f), + expect: reflect.Float64, + }, + } + + for _, each := range cases { + t.Run(each.t.String(), func(t *testing.T) { + assert.Equal(t, each.expect, Deref(each.t).Kind()) + }) + } +} + +func TestParseKeyAndOptionWithoutTag(t *testing.T) { + var foo Foo + rte := reflect.TypeOf(&foo).Elem() + field, _ := rte.FieldByName("Str") + key, options, err := parseKeyAndOptions(testTagName, field) + assert.Nil(t, err) + assert.Equal(t, "Str", key) + assert.Nil(t, options) +} + +func 
TestParseKeyAndOptionWithTagWithoutOption(t *testing.T) { + var foo Foo + rte := reflect.TypeOf(&foo).Elem() + field, _ := rte.FieldByName("StrWithTag") + key, options, err := parseKeyAndOptions(testTagName, field) + assert.Nil(t, err) + assert.Equal(t, "stringwithtag", key) + assert.Nil(t, options) +} + +func TestParseKeyAndOptionWithTagAndOption(t *testing.T) { + var foo Foo + rte := reflect.TypeOf(&foo).Elem() + field, _ := rte.FieldByName("StrWithTagAndOption") + key, options, err := parseKeyAndOptions(testTagName, field) + assert.Nil(t, err) + assert.Equal(t, "stringwithtag", key) + assert.True(t, options.FromString) +} + +func TestValidatePtrWithNonPtr(t *testing.T) { + var foo string + rve := reflect.ValueOf(foo) + assert.NotNil(t, ValidatePtr(&rve)) +} + +func TestValidatePtrWithPtr(t *testing.T) { + var foo string + rve := reflect.ValueOf(&foo) + assert.Nil(t, ValidatePtr(&rve)) +} + +func TestValidatePtrWithNilPtr(t *testing.T) { + var foo *string + rve := reflect.ValueOf(foo) + assert.NotNil(t, ValidatePtr(&rve)) +} + +func TestValidatePtrWithZeroValue(t *testing.T) { + var s string + e := reflect.Zero(reflect.TypeOf(s)) + assert.NotNil(t, ValidatePtr(&e)) +} + +func TestSetValueNotSettable(t *testing.T) { + var i int + assert.NotNil(t, setValue(reflect.Int, reflect.ValueOf(i), "1")) +} + +func TestParseKeyAndOptionsErrors(t *testing.T) { + type Bar struct { + OptionsValue string `key:",options=a=b"` + DefaultValue string `key:",default=a=b"` + } + + var bar Bar + _, _, err := parseKeyAndOptions("key", reflect.TypeOf(&bar).Elem().Field(0)) + assert.NotNil(t, err) + _, _, err = parseKeyAndOptions("key", reflect.TypeOf(&bar).Elem().Field(1)) + assert.NotNil(t, err) +} + +func TestSetValueFormatErrors(t *testing.T) { + type Bar struct { + IntValue int + UintValue uint + FloatValue float32 + MapValue map[string]interface{} + } + + var bar Bar + tests := []struct { + kind reflect.Kind + target reflect.Value + value string + }{ + { + kind: reflect.Int, + target: reflect.ValueOf(&bar.IntValue).Elem(), + value: "a", + }, + { + kind: reflect.Uint, + target: reflect.ValueOf(&bar.UintValue).Elem(), + value: "a", + }, + { + kind: reflect.Float32, + target: reflect.ValueOf(&bar.FloatValue).Elem(), + value: "a", + }, + { + kind: reflect.Map, + target: reflect.ValueOf(&bar.MapValue).Elem(), + }, + } + + for _, test := range tests { + t.Run(test.kind.String(), func(t *testing.T) { + err := setValue(test.kind, test.target, test.value) + assert.NotEqual(t, errValueNotSettable, err) + assert.NotNil(t, err) + }) + } +} + +func TestRepr(t *testing.T) { + var ( + f32 float32 = 1.1 + f64 = 2.2 + i8 int8 = 1 + i16 int16 = 2 + i32 int32 = 3 + i64 int64 = 4 + u8 uint8 = 5 + u16 uint16 = 6 + u32 uint32 = 7 + u64 uint64 = 8 + ) + tests := []struct { + v interface{} + expect string + }{ + { + nil, + "", + }, + { + mockStringable{}, + "mocked", + }, + { + new(mockStringable), + "mocked", + }, + { + newMockPtr(), + "mockptr", + }, + { + true, + "true", + }, + { + false, + "false", + }, + { + f32, + "1.1", + }, + { + f64, + "2.2", + }, + { + i8, + "1", + }, + { + i16, + "2", + }, + { + i32, + "3", + }, + { + i64, + "4", + }, + { + u8, + "5", + }, + { + u16, + "6", + }, + { + u32, + "7", + }, + { + u64, + "8", + }, + { + []byte(`abcd`), + "abcd", + }, + { + mockOpacity{val: 1}, + "{1}", + }, + } + + for _, test := range tests { + t.Run(test.expect, func(t *testing.T) { + assert.Equal(t, test.expect, Repr(test.v)) + }) + } +} + +type mockStringable struct{} + +func (m mockStringable) String() string { + return 
"mocked" +} + +type mockPtr struct{} + +func newMockPtr() *mockPtr { + return new(mockPtr) +} + +func (m *mockPtr) String() string { + return "mockptr" +} + +type mockOpacity struct { + val int +} diff --git a/core/mapping/valuer.go b/core/mapping/valuer.go new file mode 100644 index 00000000..07dcdfcb --- /dev/null +++ b/core/mapping/valuer.go @@ -0,0 +1,14 @@ +package mapping + +type ( + Valuer interface { + Value(key string) (interface{}, bool) + } + + MapValuer map[string]interface{} +) + +func (mv MapValuer) Value(key string) (interface{}, bool) { + v, ok := mv[key] + return v, ok +} diff --git a/core/mapping/yamlunmarshaler.go b/core/mapping/yamlunmarshaler.go new file mode 100644 index 00000000..6452d41b --- /dev/null +++ b/core/mapping/yamlunmarshaler.go @@ -0,0 +1,95 @@ +package mapping + +import ( + "encoding/json" + "errors" + "io" + "io/ioutil" + + "gopkg.in/yaml.v2" +) + +// To make .json & .yaml consistent, we just use json as the tag key. +const yamlTagKey = "json" + +var ( + ErrUnsupportedType = errors.New("only map-like configs are suported") + + yamlUnmarshaler = NewUnmarshaler(yamlTagKey) +) + +func UnmarshalYamlBytes(content []byte, v interface{}) error { + return unmarshalYamlBytes(content, v, yamlUnmarshaler) +} + +func UnmarshalYamlReader(reader io.Reader, v interface{}) error { + return unmarshalYamlReader(reader, v, yamlUnmarshaler) +} + +func unmarshalYamlBytes(content []byte, v interface{}, unmarshaler *Unmarshaler) error { + var o interface{} + if err := yamlUnmarshal(content, &o); err != nil { + return err + } + + if m, ok := o.(map[string]interface{}); ok { + return unmarshaler.Unmarshal(m, v) + } else { + return ErrUnsupportedType + } +} + +func unmarshalYamlReader(reader io.Reader, v interface{}, unmarshaler *Unmarshaler) error { + content, err := ioutil.ReadAll(reader) + if err != nil { + return err + } + + return unmarshalYamlBytes(content, v, unmarshaler) +} + +// yamlUnmarshal YAML to map[string]interface{} instead of map[interface{}]interface{}. 
+func yamlUnmarshal(in []byte, out interface{}) error { + var res interface{} + if err := yaml.Unmarshal(in, &res); err != nil { + return err + } + + *out.(*interface{}) = cleanupMapValue(res) + return nil +} + +func cleanupInterfaceMap(in map[interface{}]interface{}) map[string]interface{} { + res := make(map[string]interface{}) + for k, v := range in { + res[Repr(k)] = cleanupMapValue(v) + } + return res +} + +func cleanupInterfaceNumber(in interface{}) json.Number { + return json.Number(Repr(in)) +} + +func cleanupInterfaceSlice(in []interface{}) []interface{} { + res := make([]interface{}, len(in)) + for i, v := range in { + res[i] = cleanupMapValue(v) + } + return res +} + +func cleanupMapValue(v interface{}) interface{} { + switch v := v.(type) { + case []interface{}: + return cleanupInterfaceSlice(v) + case map[interface{}]interface{}: + return cleanupInterfaceMap(v) + case bool, string: + return v + case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64, float32, float64: + return cleanupInterfaceNumber(v) + default: + return Repr(v) + } +} diff --git a/core/mapping/yamlunmarshaler_test.go b/core/mapping/yamlunmarshaler_test.go new file mode 100644 index 00000000..b9197c1d --- /dev/null +++ b/core/mapping/yamlunmarshaler_test.go @@ -0,0 +1,920 @@ +package mapping + +import ( + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnmarshalYamlBytes(t *testing.T) { + var c struct { + Name string + } + content := []byte(`Name: liao`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "liao", c.Name) +} + +func TestUnmarshalYamlBytesOptional(t *testing.T) { + var c struct { + Name string + Age int `json:",optional"` + } + content := []byte(`Name: liao`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "liao", c.Name) +} + +func TestUnmarshalYamlBytesOptionalDefault(t *testing.T) { + var c struct { + Name string + Age int `json:",optional,default=1"` + } + content := []byte(`Name: liao`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "liao", c.Name) + assert.Equal(t, 1, c.Age) +} + +func TestUnmarshalYamlBytesDefaultOptional(t *testing.T) { + var c struct { + Name string + Age int `json:",default=1,optional"` + } + content := []byte(`Name: liao`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "liao", c.Name) + assert.Equal(t, 1, c.Age) +} + +func TestUnmarshalYamlBytesDefault(t *testing.T) { + var c struct { + Name string `json:",default=liao"` + } + content := []byte(`{}`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "liao", c.Name) +} + +func TestUnmarshalYamlBytesBool(t *testing.T) { + var c struct { + Great bool + } + content := []byte(`Great: true`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.True(t, c.Great) +} + +func TestUnmarshalYamlBytesInt(t *testing.T) { + var c struct { + Age int + } + content := []byte(`Age: 1`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, 1, c.Age) +} + +func TestUnmarshalYamlBytesUint(t *testing.T) { + var c struct { + Age uint + } + content := []byte(`Age: 1`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, uint(1), c.Age) +} + +func TestUnmarshalYamlBytesFloat(t *testing.T) { + var c struct { + Age float32 + } + content := []byte(`Age: 1.5`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, float32(1.5), c.Age) +} + +func TestUnmarshalYamlBytesMustInOptional(t *testing.T) { + var c struct { + Inner struct { + There 
string + Must string + Optional string `json:",optional"` + } `json:",optional"` + } + content := []byte(`{}`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesMustInOptionalMissedPart(t *testing.T) { + var c struct { + Inner struct { + There string + Must string + Optional string `json:",optional"` + } `json:",optional"` + } + content := []byte(`Inner: + There: sure`) + + assert.NotNil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesMustInOptionalOnlyOptionalFilled(t *testing.T) { + var c struct { + Inner struct { + There string + Must string + Optional string `json:",optional"` + } `json:",optional"` + } + content := []byte(`Inner: + Optional: sure`) + + assert.NotNil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesPartial(t *testing.T) { + var c struct { + Name string + Age float32 + } + content := []byte(`Age: 1.5`) + + assert.NotNil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesStruct(t *testing.T) { + var c struct { + Inner struct { + Name string + } + } + content := []byte(`Inner: + Name: liao`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "liao", c.Inner.Name) +} + +func TestUnmarshalYamlBytesStructOptional(t *testing.T) { + var c struct { + Inner struct { + Name string + Age int `json:",optional"` + } + } + content := []byte(`Inner: + Name: liao`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "liao", c.Inner.Name) +} + +func TestUnmarshalYamlBytesStructPtr(t *testing.T) { + var c struct { + Inner *struct { + Name string + } + } + content := []byte(`Inner: + Name: liao`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "liao", c.Inner.Name) +} + +func TestUnmarshalYamlBytesStructPtrOptional(t *testing.T) { + var c struct { + Inner *struct { + Name string + Age int `json:",optional"` + } + } + content := []byte(`Inner: + Name: liao`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesStructPtrDefault(t *testing.T) { + var c struct { + Inner *struct { + Name string + Age int `json:",default=4"` + } + } + content := []byte(`Inner: + Name: liao`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "liao", c.Inner.Name) + assert.Equal(t, 4, c.Inner.Age) +} + +func TestUnmarshalYamlBytesSliceString(t *testing.T) { + var c struct { + Names []string + } + content := []byte(`Names: +- liao +- chaoxin`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + + want := []string{"liao", "chaoxin"} + if !reflect.DeepEqual(c.Names, want) { + t.Fatalf("want %q, got %q", c.Names, want) + } +} + +func TestUnmarshalYamlBytesSliceStringOptional(t *testing.T) { + var c struct { + Names []string + Age []int `json:",optional"` + } + content := []byte(`Names: +- liao +- chaoxin`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + + want := []string{"liao", "chaoxin"} + if !reflect.DeepEqual(c.Names, want) { + t.Fatalf("want %q, got %q", c.Names, want) + } +} + +func TestUnmarshalYamlBytesSliceStruct(t *testing.T) { + var c struct { + People []struct { + Name string + Age int + } + } + content := []byte(`People: +- Name: liao + Age: 1 +- Name: chaoxin + Age: 2`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + + want := []struct { + Name string + Age int + }{ + {"liao", 1}, + {"chaoxin", 2}, + } + if !reflect.DeepEqual(c.People, want) { + t.Fatalf("want %q, got %q", c.People, want) + } +} + +func TestUnmarshalYamlBytesSliceStructOptional(t *testing.T) { + var c struct { + People []struct { 
+ Name string + Age int + Emails []string `json:",optional"` + } + } + content := []byte(`People: +- Name: liao + Age: 1 +- Name: chaoxin + Age: 2`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + + want := []struct { + Name string + Age int + Emails []string `json:",optional"` + }{ + {"liao", 1, nil}, + {"chaoxin", 2, nil}, + } + if !reflect.DeepEqual(c.People, want) { + t.Fatalf("want %q, got %q", c.People, want) + } +} + +func TestUnmarshalYamlBytesSliceStructPtr(t *testing.T) { + var c struct { + People []*struct { + Name string + Age int + } + } + content := []byte(`People: +- Name: liao + Age: 1 +- Name: chaoxin + Age: 2`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + + want := []*struct { + Name string + Age int + }{ + {"liao", 1}, + {"chaoxin", 2}, + } + if !reflect.DeepEqual(c.People, want) { + t.Fatalf("want %v, got %v", c.People, want) + } +} + +func TestUnmarshalYamlBytesSliceStructPtrOptional(t *testing.T) { + var c struct { + People []*struct { + Name string + Age int + Emails []string `json:",optional"` + } + } + content := []byte(`People: +- Name: liao + Age: 1 +- Name: chaoxin + Age: 2`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + + want := []*struct { + Name string + Age int + Emails []string `json:",optional"` + }{ + {"liao", 1, nil}, + {"chaoxin", 2, nil}, + } + if !reflect.DeepEqual(c.People, want) { + t.Fatalf("want %v, got %v", c.People, want) + } +} + +func TestUnmarshalYamlBytesSliceStructPtrPartial(t *testing.T) { + var c struct { + People []*struct { + Name string + Age int + Email string + } + } + content := []byte(`People: +- Name: liao + Age: 1 +- Name: chaoxin + Age: 2`) + + assert.NotNil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesSliceStructPtrDefault(t *testing.T) { + var c struct { + People []*struct { + Name string + Age int + Email string `json:",default=chaoxin@liao.com"` + } + } + content := []byte(`People: +- Name: liao + Age: 1 +- Name: chaoxin + Age: 2`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + + want := []*struct { + Name string + Age int + Email string + }{ + {"liao", 1, "chaoxin@liao.com"}, + {"chaoxin", 2, "chaoxin@liao.com"}, + } + + for i := range c.People { + actual := c.People[i] + expect := want[i] + assert.Equal(t, expect.Age, actual.Age) + assert.Equal(t, expect.Email, actual.Email) + assert.Equal(t, expect.Name, actual.Name) + } +} + +func TestUnmarshalYamlBytesSliceStringPartial(t *testing.T) { + var c struct { + Names []string + Age int + } + content := []byte(`Age: 1`) + + assert.NotNil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesSliceStructPartial(t *testing.T) { + var c struct { + Group string + People []struct { + Name string + Age int + } + } + content := []byte(`Group: chaoxin`) + + assert.NotNil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesInnerAnonymousPartial(t *testing.T) { + type ( + Deep struct { + A string + B string `json:",optional"` + } + Inner struct { + Deep + InnerV string `json:",optional"` + } + ) + + var c struct { + Value Inner `json:",optional"` + } + content := []byte(`Value: + InnerV: chaoxin`) + + assert.NotNil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesStructPartial(t *testing.T) { + var c struct { + Group string + Person struct { + Name string + Age int + } + } + content := []byte(`Group: chaoxin`) + + assert.NotNil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesEmptyMap(t *testing.T) { + var c struct { + Persons map[string]int `json:",optional"` + } + 
content := []byte(`{}`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Empty(t, c.Persons) +} + +func TestUnmarshalYamlBytesMap(t *testing.T) { + var c struct { + Persons map[string]int + } + content := []byte(`Persons: + first: 1 + second: 2`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, 2, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"]) + assert.Equal(t, 2, c.Persons["second"]) +} + +func TestUnmarshalYamlBytesMapStruct(t *testing.T) { + var c struct { + Persons map[string]struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`Persons: + first: + Id: 1 + name: kevin`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"].Id) + assert.Equal(t, "kevin", c.Persons["first"].Name) +} + +func TestUnmarshalYamlBytesMapStructPtr(t *testing.T) { + var c struct { + Persons map[string]*struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`Persons: + first: + Id: 1 + name: kevin`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"].Id) + assert.Equal(t, "kevin", c.Persons["first"].Name) +} + +func TestUnmarshalYamlBytesMapStructMissingPartial(t *testing.T) { + var c struct { + Persons map[string]*struct { + Id int + Name string + } + } + content := []byte(`Persons: + first: + Id: 1`) + + assert.NotNil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesMapStructOptional(t *testing.T) { + var c struct { + Persons map[string]*struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`Persons: + first: + Id: 1`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"].Id) +} + +func TestUnmarshalYamlBytesMapStructSlice(t *testing.T) { + var c struct { + Persons map[string][]struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`Persons: + first: + - Id: 1 + name: kevin`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"][0].Id) + assert.Equal(t, "kevin", c.Persons["first"][0].Name) +} + +func TestUnmarshalYamlBytesMapEmptyStructSlice(t *testing.T) { + var c struct { + Persons map[string][]struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`Persons: + first: []`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Empty(t, c.Persons["first"]) +} + +func TestUnmarshalYamlBytesMapStructPtrSlice(t *testing.T) { + var c struct { + Persons map[string][]*struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`Persons: + first: + - Id: 1 + name: kevin`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"][0].Id) + assert.Equal(t, "kevin", c.Persons["first"][0].Name) +} + +func TestUnmarshalYamlBytesMapEmptyStructPtrSlice(t *testing.T) { + var c struct { + Persons map[string][]*struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`Persons: + first: []`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Empty(t, c.Persons["first"]) +} + +func TestUnmarshalYamlBytesMapStructPtrSliceMissingPartial(t *testing.T) { + var c struct { + Persons map[string][]*struct { + Id int + Name string + } + } + 
content := []byte(`Persons: + first: + - Id: 1`) + + assert.NotNil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlBytesMapStructPtrSliceOptional(t *testing.T) { + var c struct { + Persons map[string][]*struct { + Id int + Name string `json:"name,optional"` + } + } + content := []byte(`Persons: + first: + - Id: 1`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, 1, len(c.Persons)) + assert.Equal(t, 1, c.Persons["first"][0].Id) +} + +func TestUnmarshalYamlStructOptional(t *testing.T) { + var c struct { + Name string + Etcd struct { + Hosts []string + Key string + } `json:",optional"` + } + content := []byte(`Name: kevin`) + + err := UnmarshalYamlBytes(content, &c) + assert.Nil(t, err) + assert.Equal(t, "kevin", c.Name) +} + +func TestUnmarshalYamlStructLowerCase(t *testing.T) { + var c struct { + Name string + Etcd struct { + Key string + } `json:"etcd"` + } + content := []byte(`Name: kevin +etcd: + Key: the key`) + + err := UnmarshalYamlBytes(content, &c) + assert.Nil(t, err) + assert.Equal(t, "kevin", c.Name) + assert.Equal(t, "the key", c.Etcd.Key) +} + +func TestUnmarshalYamlWithStructAllOptionalWithEmpty(t *testing.T) { + var c struct { + Inner struct { + Optional string `json:",optional"` + } + Else string + } + content := []byte(`Else: sure`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlWithStructAllOptionalPtr(t *testing.T) { + var c struct { + Inner *struct { + Optional string `json:",optional"` + } + Else string + } + content := []byte(`Else: sure`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlWithStructOptional(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + In Inner `json:",optional"` + Else string + } + content := []byte(`Else: sure`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, "", c.In.Must) +} + +func TestUnmarshalYamlWithStructPtrOptional(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + In *Inner `json:",optional"` + Else string + } + content := []byte(`Else: sure`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Nil(t, c.In) +} + +func TestUnmarshalYamlWithStructAllOptionalAnonymous(t *testing.T) { + type Inner struct { + Optional string `json:",optional"` + } + + var c struct { + Inner + Else string + } + content := []byte(`Else: sure`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlWithStructAllOptionalAnonymousPtr(t *testing.T) { + type Inner struct { + Optional string `json:",optional"` + } + + var c struct { + *Inner + Else string + } + content := []byte(`Else: sure`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) +} + +func TestUnmarshalYamlWithStructAllOptionalProvoidedAnonymous(t *testing.T) { + type Inner struct { + Optional string `json:",optional"` + } + + var c struct { + Inner + Else string + } + content := []byte(`Else: sure +Optional: optional`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, "optional", c.Optional) +} + +func TestUnmarshalYamlWithStructAllOptionalProvoidedAnonymousPtr(t *testing.T) { + type Inner struct { + Optional string `json:",optional"` + } + + var c struct { + *Inner + Else string + } + content := []byte(`Else: sure +Optional: optional`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, "optional", c.Optional) +} + +func 
TestUnmarshalYamlWithStructAnonymous(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + Inner + Else string + } + content := []byte(`Else: sure +Must: must`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, "must", c.Must) +} + +func TestUnmarshalYamlWithStructAnonymousPtr(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + *Inner + Else string + } + content := []byte(`Else: sure +Must: must`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, "must", c.Must) +} + +func TestUnmarshalYamlWithStructAnonymousOptional(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + Inner `json:",optional"` + Else string + } + content := []byte(`Else: sure`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Equal(t, "", c.Must) +} + +func TestUnmarshalYamlWithStructPtrAnonymousOptional(t *testing.T) { + type Inner struct { + Must string + } + + var c struct { + *Inner `json:",optional"` + Else string + } + content := []byte(`Else: sure`) + + assert.Nil(t, UnmarshalYamlBytes(content, &c)) + assert.Equal(t, "sure", c.Else) + assert.Nil(t, c.Inner) +} + +func TestUnmarshalYamlWithZeroValues(t *testing.T) { + type inner struct { + False bool `json:"negative"` + Int int `json:"int"` + String string `json:"string"` + } + content := []byte(`negative: false +int: 0 +string: ""`) + + var in inner + ast := assert.New(t) + ast.Nil(UnmarshalYamlBytes(content, &in)) + ast.False(in.False) + ast.Equal(0, in.Int) + ast.Equal("", in.String) +} + +func TestUnmarshalYamlBytesError(t *testing.T) { + payload := `abcd: +- cdef` + var v struct { + Any []string `json:"abcd"` + } + + err := UnmarshalYamlBytes([]byte(payload), &v) + assert.Nil(t, err) + assert.Equal(t, 1, len(v.Any)) + assert.Equal(t, "cdef", v.Any[0]) +} + +func TestUnmarshalYamlReaderError(t *testing.T) { + payload := `abcd: cdef` + reader := strings.NewReader(payload) + var v struct { + Any string + } + + err := UnmarshalYamlReader(reader, &v) + assert.NotNil(t, err) +} diff --git a/core/mapreduce/mapreduce.go b/core/mapreduce/mapreduce.go new file mode 100644 index 00000000..c2c6ee95 --- /dev/null +++ b/core/mapreduce/mapreduce.go @@ -0,0 +1,264 @@ +package mapreduce + +import ( + "errors" + "fmt" + "sync" + + "zero/core/errorx" + "zero/core/lang" + "zero/core/syncx" + "zero/core/threading" +) + +const ( + defaultWorkers = 16 + minWorkers = 1 +) + +var ErrCancelWithNil = errors.New("mapreduce cancelled with nil") + +type ( + GenerateFunc func(source chan<- interface{}) + MapFunc func(item interface{}, writer Writer) + VoidMapFunc func(item interface{}) + MapperFunc func(item interface{}, writer Writer, cancel func(error)) + ReducerFunc func(pipe <-chan interface{}, writer Writer, cancel func(error)) + VoidReducerFunc func(pipe <-chan interface{}, cancel func(error)) + Option func(opts *mapReduceOptions) + + mapReduceOptions struct { + workers int + } + + Writer interface { + Write(v interface{}) + } +) + +func Finish(fns ...func() error) error { + if len(fns) == 0 { + return nil + } + + return MapReduceVoid(func(source chan<- interface{}) { + for _, fn := range fns { + source <- fn + } + }, func(item interface{}, writer Writer, cancel func(error)) { + fn := item.(func() error) + if err := fn(); err != nil { + cancel(err) + } + }, func(pipe <-chan interface{}, cancel func(error)) { + drain(pipe) + }, WithWorkers(len(fns))) +} + +func 
FinishVoid(fns ...func()) { + if len(fns) == 0 { + return + } + + MapVoid(func(source chan<- interface{}) { + for _, fn := range fns { + source <- fn + } + }, func(item interface{}) { + fn := item.(func()) + fn() + }, WithWorkers(len(fns))) +} + +func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} { + options := buildOptions(opts...) + source := buildSource(generate) + collector := make(chan interface{}, options.workers) + done := syncx.NewDoneChan() + + go mapDispatcher(mapper, source, collector, done.Done(), options.workers) + + return collector +} + +func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) { + source := buildSource(generate) + return MapReduceWithSource(source, mapper, reducer, opts...) +} + +func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc, + opts ...Option) (interface{}, error) { + options := buildOptions(opts...) + output := make(chan interface{}) + collector := make(chan interface{}, options.workers) + done := syncx.NewDoneChan() + writer := newGuardedWriter(output, done.Done()) + var retErr errorx.AtomicError + cancel := once(func(err error) { + if err != nil { + retErr.Set(err) + } else { + retErr.Set(ErrCancelWithNil) + } + + drain(source) + done.Close() + close(output) + }) + + go func() { + defer func() { + if r := recover(); r != nil { + cancel(fmt.Errorf("%v", r)) + } + }() + reducer(collector, writer, cancel) + }() + go mapperDispatcher(mapper, source, collector, done.Done(), cancel, options.workers) + + value, ok := <-output + if err := retErr.Load(); err != nil { + return nil, err + } else if ok { + return value, nil + } else { + return nil, nil + } +} + +func MapReduceVoid(generator GenerateFunc, mapper MapperFunc, reducer VoidReducerFunc, opts ...Option) error { + _, err := MapReduce(generator, mapper, func(input <-chan interface{}, writer Writer, cancel func(error)) { + reducer(input, cancel) + // We need to write a placeholder to let MapReduce to continue on reducer done, + // otherwise, all goroutines are waiting. The placeholder will be discarded by MapReduce. + writer.Write(lang.Placeholder) + }, opts...) + return err +} + +func MapVoid(generate GenerateFunc, mapper VoidMapFunc, opts ...Option) { + drain(Map(generate, func(item interface{}, writer Writer) { + mapper(item) + }, opts...)) +} + +func WithWorkers(workers int) Option { + return func(opts *mapReduceOptions) { + if workers < minWorkers { + opts.workers = minWorkers + } else { + opts.workers = workers + } + } +} + +func buildOptions(opts ...Option) *mapReduceOptions { + options := newOptions() + for _, opt := range opts { + opt(options) + } + + return options +} + +func buildSource(generate GenerateFunc) chan interface{} { + source := make(chan interface{}) + threading.GoSafe(func() { + defer close(source) + generate(source) + }) + + return source +} + +// drain drains the channel. 
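+// Editorial usage sketch for the MapReduce API above (illustrative only, not part
+// of the original source); it mirrors the pattern exercised by the package tests:
+//
+//	sum, err := MapReduce(func(source chan<- interface{}) {
+//		for i := 1; i <= 3; i++ {
+//			source <- i
+//		}
+//	}, func(item interface{}, writer Writer, cancel func(error)) {
+//		writer.Write(item.(int) * item.(int)) // map: square each item
+//	}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
+//		total := 0
+//		for v := range pipe {
+//			total += v.(int)
+//		}
+//		writer.Write(total) // reduce: sum of squares
+//	})
+//	// sum == 14 (1+4+9), err == nil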
+func drain(channel <-chan interface{}) { + // drain the channel + for range channel { + } +} + +func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- interface{}, + done <-chan lang.PlaceholderType, workers int) { + var wg sync.WaitGroup + defer func() { + wg.Wait() + close(collector) + }() + + pool := make(chan lang.PlaceholderType, workers) + writer := newGuardedWriter(collector, done) + for { + select { + case <-done: + return + case pool <- lang.Placeholder: + item, ok := <-input + if !ok { + <-pool + return + } + + wg.Add(1) + // better to safely run caller defined method + threading.GoSafe(func() { + defer func() { + wg.Done() + <-pool + }() + + mapper(item, writer) + }) + } + } +} + +func mapDispatcher(mapper MapFunc, input <-chan interface{}, collector chan<- interface{}, + done <-chan lang.PlaceholderType, workers int) { + executeMappers(func(item interface{}, writer Writer) { + mapper(item, writer) + }, input, collector, done, workers) +} + +func mapperDispatcher(mapper MapperFunc, input <-chan interface{}, collector chan<- interface{}, + done <-chan lang.PlaceholderType, cancel func(error), workers int) { + executeMappers(func(item interface{}, writer Writer) { + mapper(item, writer, cancel) + }, input, collector, done, workers) +} + +func newOptions() *mapReduceOptions { + return &mapReduceOptions{ + workers: defaultWorkers, + } +} + +func once(fn func(error)) func(error) { + once := new(sync.Once) + return func(err error) { + once.Do(func() { + fn(err) + }) + } +} + +type guardedWriter struct { + channel chan<- interface{} + done <-chan lang.PlaceholderType +} + +func newGuardedWriter(channel chan<- interface{}, done <-chan lang.PlaceholderType) guardedWriter { + return guardedWriter{ + channel: channel, + done: done, + } +} + +func (gw guardedWriter) Write(v interface{}) { + select { + case <-gw.done: + return + default: + gw.channel <- v + } +} diff --git a/core/mapreduce/mapreduce_test.go b/core/mapreduce/mapreduce_test.go new file mode 100644 index 00000000..153ff38e --- /dev/null +++ b/core/mapreduce/mapreduce_test.go @@ -0,0 +1,403 @@ +package mapreduce + +import ( + "errors" + "io/ioutil" + "log" + "runtime" + "sync/atomic" + "testing" + "time" + + "zero/core/stringx" + "zero/core/syncx" + + "github.com/stretchr/testify/assert" +) + +var errDummy = errors.New("dummy") + +func init() { + log.SetOutput(ioutil.Discard) +} + +func TestFinish(t *testing.T) { + var total uint32 + err := Finish(func() error { + atomic.AddUint32(&total, 2) + return nil + }, func() error { + atomic.AddUint32(&total, 3) + return nil + }, func() error { + atomic.AddUint32(&total, 5) + return nil + }) + + assert.Equal(t, uint32(10), atomic.LoadUint32(&total)) + assert.Nil(t, err) +} + +func TestFinishNone(t *testing.T) { + assert.Nil(t, Finish()) +} + +func TestFinishVoidNone(t *testing.T) { + FinishVoid() +} + +func TestFinishErr(t *testing.T) { + var total uint32 + err := Finish(func() error { + atomic.AddUint32(&total, 2) + return nil + }, func() error { + atomic.AddUint32(&total, 3) + return errDummy + }, func() error { + atomic.AddUint32(&total, 5) + return nil + }) + + assert.Equal(t, errDummy, err) +} + +func TestFinishVoid(t *testing.T) { + var total uint32 + FinishVoid(func() { + atomic.AddUint32(&total, 2) + }, func() { + atomic.AddUint32(&total, 3) + }, func() { + atomic.AddUint32(&total, 5) + }) + + assert.Equal(t, uint32(10), atomic.LoadUint32(&total)) +} + +func TestMap(t *testing.T) { + tests := []struct { + mapper MapFunc + expect int + }{ + { + mapper: 
func(item interface{}, writer Writer) { + v := item.(int) + writer.Write(v * v) + }, + expect: 30, + }, + { + mapper: func(item interface{}, writer Writer) { + v := item.(int) + if v%2 == 0 { + return + } + writer.Write(v * v) + }, + expect: 10, + }, + { + mapper: func(item interface{}, writer Writer) { + v := item.(int) + if v%2 == 0 { + panic(v) + } + writer.Write(v * v) + }, + expect: 10, + }, + } + + for _, test := range tests { + t.Run(stringx.Rand(), func(t *testing.T) { + channel := Map(func(source chan<- interface{}) { + for i := 1; i < 5; i++ { + source <- i + } + }, test.mapper, WithWorkers(-1)) + + var result int + for v := range channel { + result += v.(int) + } + + assert.Equal(t, test.expect, result) + }) + } +} + +func TestMapReduce(t *testing.T) { + tests := []struct { + mapper MapperFunc + reducer ReducerFunc + expectErr error + expectValue interface{} + }{ + { + expectErr: nil, + expectValue: 30, + }, + { + mapper: func(item interface{}, writer Writer, cancel func(error)) { + v := item.(int) + if v%3 == 0 { + cancel(errDummy) + } + writer.Write(v * v) + }, + expectErr: errDummy, + }, + { + mapper: func(item interface{}, writer Writer, cancel func(error)) { + v := item.(int) + if v%3 == 0 { + cancel(nil) + } + writer.Write(v * v) + }, + expectErr: ErrCancelWithNil, + expectValue: nil, + }, + { + reducer: func(pipe <-chan interface{}, writer Writer, cancel func(error)) { + var result int + for item := range pipe { + result += item.(int) + if result > 10 { + cancel(errDummy) + } + } + writer.Write(result) + }, + expectErr: errDummy, + }, + } + + for _, test := range tests { + t.Run(stringx.Rand(), func(t *testing.T) { + if test.mapper == nil { + test.mapper = func(item interface{}, writer Writer, cancel func(error)) { + v := item.(int) + writer.Write(v * v) + } + } + if test.reducer == nil { + test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) { + var result int + for item := range pipe { + result += item.(int) + } + writer.Write(result) + } + } + value, err := MapReduce(func(source chan<- interface{}) { + for i := 1; i < 5; i++ { + source <- i + } + }, test.mapper, test.reducer, WithWorkers(runtime.NumCPU())) + + assert.Equal(t, test.expectErr, err) + assert.Equal(t, test.expectValue, value) + }) + } +} + +func TestMapReduceVoid(t *testing.T) { + var value uint32 + tests := []struct { + mapper MapperFunc + reducer VoidReducerFunc + expectValue uint32 + expectErr error + }{ + { + expectValue: 30, + expectErr: nil, + }, + { + mapper: func(item interface{}, writer Writer, cancel func(error)) { + v := item.(int) + if v%3 == 0 { + cancel(errDummy) + } + writer.Write(v * v) + }, + expectErr: errDummy, + }, + { + mapper: func(item interface{}, writer Writer, cancel func(error)) { + v := item.(int) + if v%3 == 0 { + cancel(nil) + } + writer.Write(v * v) + }, + expectErr: ErrCancelWithNil, + }, + { + reducer: func(pipe <-chan interface{}, cancel func(error)) { + for item := range pipe { + result := atomic.AddUint32(&value, uint32(item.(int))) + if result > 10 { + cancel(errDummy) + } + } + }, + expectErr: errDummy, + }, + } + + for _, test := range tests { + t.Run(stringx.Rand(), func(t *testing.T) { + atomic.StoreUint32(&value, 0) + + if test.mapper == nil { + test.mapper = func(item interface{}, writer Writer, cancel func(error)) { + v := item.(int) + writer.Write(v * v) + } + } + if test.reducer == nil { + test.reducer = func(pipe <-chan interface{}, cancel func(error)) { + for item := range pipe { + atomic.AddUint32(&value, uint32(item.(int))) + } + 
} + } + err := MapReduceVoid(func(source chan<- interface{}) { + for i := 1; i < 5; i++ { + source <- i + } + }, test.mapper, test.reducer) + + assert.Equal(t, test.expectErr, err) + if err == nil { + assert.Equal(t, test.expectValue, atomic.LoadUint32(&value)) + } + }) + } +} + +func TestMapReduceVoidWithDelay(t *testing.T) { + var result []int + err := MapReduceVoid(func(source chan<- interface{}) { + source <- 0 + source <- 1 + }, func(item interface{}, writer Writer, cancel func(error)) { + i := item.(int) + if i == 0 { + time.Sleep(time.Millisecond * 50) + } + writer.Write(i) + }, func(pipe <-chan interface{}, cancel func(error)) { + for item := range pipe { + i := item.(int) + result = append(result, i) + } + }) + assert.Nil(t, err) + assert.Equal(t, 2, len(result)) + assert.Equal(t, 1, result[0]) + assert.Equal(t, 0, result[1]) +} + +func TestMapVoid(t *testing.T) { + const tasks = 1000 + var count uint32 + MapVoid(func(source chan<- interface{}) { + for i := 0; i < tasks; i++ { + source <- i + } + }, func(item interface{}) { + atomic.AddUint32(&count, 1) + }) + + assert.Equal(t, tasks, int(count)) +} + +func TestMapReducePanic(t *testing.T) { + v, err := MapReduce(func(source chan<- interface{}) { + source <- 0 + source <- 1 + }, func(item interface{}, writer Writer, cancel func(error)) { + i := item.(int) + writer.Write(i) + }, func(pipe <-chan interface{}, writer Writer, cancel func(error)) { + for range pipe { + panic("panic") + } + }) + assert.Nil(t, v) + assert.NotNil(t, err) + assert.Equal(t, "panic", err.Error()) +} + +func TestMapReduceVoidCancel(t *testing.T) { + var result []int + err := MapReduceVoid(func(source chan<- interface{}) { + source <- 0 + source <- 1 + }, func(item interface{}, writer Writer, cancel func(error)) { + i := item.(int) + if i == 1 { + cancel(errors.New("anything")) + } + writer.Write(i) + }, func(pipe <-chan interface{}, cancel func(error)) { + for item := range pipe { + i := item.(int) + result = append(result, i) + } + }) + assert.NotNil(t, err) + assert.Equal(t, "anything", err.Error()) +} + +func TestMapReduceVoidCancelWithRemains(t *testing.T) { + var done syncx.AtomicBool + var result []int + err := MapReduceVoid(func(source chan<- interface{}) { + for i := 0; i < defaultWorkers*2; i++ { + source <- i + } + done.Set(true) + }, func(item interface{}, writer Writer, cancel func(error)) { + i := item.(int) + if i == defaultWorkers/2 { + cancel(errors.New("anything")) + } + writer.Write(i) + }, func(pipe <-chan interface{}, cancel func(error)) { + for item := range pipe { + i := item.(int) + result = append(result, i) + } + }) + assert.NotNil(t, err) + assert.Equal(t, "anything", err.Error()) + assert.True(t, done.True()) +} + +func BenchmarkMapReduce(b *testing.B) { + b.ReportAllocs() + + mapper := func(v interface{}, writer Writer, cancel func(error)) { + writer.Write(v.(int64) * v.(int64)) + } + reducer := func(input <-chan interface{}, writer Writer, cancel func(error)) { + var result int64 + for v := range input { + result += v.(int64) + } + writer.Write(result) + } + + for i := 0; i < b.N; i++ { + MapReduce(func(input chan<- interface{}) { + for j := 0; j < 2; j++ { + input <- int64(j) + } + }, mapper, reducer) + } +} diff --git a/core/mathx/entropy.go b/core/mathx/entropy.go new file mode 100644 index 00000000..0b3c7f22 --- /dev/null +++ b/core/mathx/entropy.go @@ -0,0 +1,14 @@ +package mathx + +import "math" + +func CalcEntropy(m map[interface{}]int, total int) float64 { + var entropy float64 + + for _, v := range m { + proba := 
float64(v) / float64(total) + entropy -= proba * math.Log2(proba) + } + + return entropy / math.Log2(float64(len(m))) +} diff --git a/core/mathx/entropy_test.go b/core/mathx/entropy_test.go new file mode 100644 index 00000000..87130cff --- /dev/null +++ b/core/mathx/entropy_test.go @@ -0,0 +1,17 @@ +package mathx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCalcEntropy(t *testing.T) { + const total = 1000 + const count = 100 + m := make(map[interface{}]int, total) + for i := 0; i < total; i++ { + m[i] = count + } + assert.True(t, CalcEntropy(m, total*count) > .99) +} diff --git a/core/mathx/int.go b/core/mathx/int.go new file mode 100644 index 00000000..58499618 --- /dev/null +++ b/core/mathx/int.go @@ -0,0 +1,17 @@ +package mathx + +func MaxInt(a, b int) int { + if a > b { + return a + } else { + return b + } +} + +func MinInt(a, b int) int { + if a < b { + return a + } else { + return b + } +} diff --git a/core/mathx/int_test.go b/core/mathx/int_test.go new file mode 100644 index 00000000..04593f64 --- /dev/null +++ b/core/mathx/int_test.go @@ -0,0 +1,71 @@ +package mathx + +import ( + "testing" + + "zero/core/stringx" + + "github.com/stretchr/testify/assert" +) + +func TestMaxInt(t *testing.T) { + cases := []struct { + a int + b int + expect int + }{ + { + a: 0, + b: 1, + expect: 1, + }, + { + a: 0, + b: -1, + expect: 0, + }, + { + a: 1, + b: 1, + expect: 1, + }, + } + + for _, each := range cases { + t.Run(stringx.Rand(), func(t *testing.T) { + actual := MaxInt(each.a, each.b) + assert.Equal(t, each.expect, actual) + }) + } +} + +func TestMinInt(t *testing.T) { + cases := []struct { + a int + b int + expect int + }{ + { + a: 0, + b: 1, + expect: 0, + }, + { + a: 0, + b: -1, + expect: -1, + }, + { + a: 1, + b: 1, + expect: 1, + }, + } + + for _, each := range cases { + t.Run(stringx.Rand(), func(t *testing.T) { + actual := MinInt(each.a, each.b) + assert.Equal(t, each.expect, actual) + }) + } +} diff --git a/core/mathx/proba.go b/core/mathx/proba.go new file mode 100644 index 00000000..9db6f9d9 --- /dev/null +++ b/core/mathx/proba.go @@ -0,0 +1,26 @@ +package mathx + +import ( + "math/rand" + "sync" + "time" +) + +type Proba struct { + // rand.New(...) 
returns a non thread safe object + r *rand.Rand + lock sync.Mutex +} + +func NewProba() *Proba { + return &Proba{ + r: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +func (p *Proba) TrueOnProba(proba float64) (truth bool) { + p.lock.Lock() + truth = p.r.Float64() < proba + p.lock.Unlock() + return +} diff --git a/core/mathx/proba_test.go b/core/mathx/proba_test.go new file mode 100644 index 00000000..e87ed938 --- /dev/null +++ b/core/mathx/proba_test.go @@ -0,0 +1,24 @@ +package mathx + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTrueOnProba(t *testing.T) { + const proba = math.Pi / 10 + const total = 100000 + const epsilon = 0.05 + var count int + p := NewProba() + for i := 0; i < total; i++ { + if p.TrueOnProba(proba) { + count++ + } + } + + ratio := float64(count) / float64(total) + assert.InEpsilon(t, proba, ratio, epsilon) +} diff --git a/core/mathx/unstable.go b/core/mathx/unstable.go new file mode 100644 index 00000000..198568d7 --- /dev/null +++ b/core/mathx/unstable.go @@ -0,0 +1,41 @@ +package mathx + +import ( + "math/rand" + "sync" + "time" +) + +type Unstable struct { + deviation float64 + r *rand.Rand + lock *sync.Mutex +} + +func NewUnstable(deviation float64) Unstable { + if deviation < 0 { + deviation = 0 + } + if deviation > 1 { + deviation = 1 + } + return Unstable{ + deviation: deviation, + r: rand.New(rand.NewSource(time.Now().UnixNano())), + lock: new(sync.Mutex), + } +} + +func (u Unstable) AroundDuration(base time.Duration) time.Duration { + u.lock.Lock() + val := time.Duration((1 + u.deviation - 2*u.deviation*u.r.Float64()) * float64(base)) + u.lock.Unlock() + return val +} + +func (u Unstable) AroundInt(base int64) int64 { + u.lock.Lock() + val := int64((1 + u.deviation - 2*u.deviation*u.r.Float64()) * float64(base)) + u.lock.Unlock() + return val +} diff --git a/core/mathx/unstable_test.go b/core/mathx/unstable_test.go new file mode 100644 index 00000000..e9bb53c0 --- /dev/null +++ b/core/mathx/unstable_test.go @@ -0,0 +1,71 @@ +package mathx + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestUnstable_AroundDuration(t *testing.T) { + unstable := NewUnstable(0.05) + for i := 0; i < 1000; i++ { + val := unstable.AroundDuration(time.Second) + assert.True(t, float64(time.Second)*0.95 <= float64(val)) + assert.True(t, float64(val) <= float64(time.Second)*1.05) + } +} + +func TestUnstable_AroundInt(t *testing.T) { + const target = 10000 + unstable := NewUnstable(0.05) + for i := 0; i < 1000; i++ { + val := unstable.AroundInt(target) + assert.True(t, float64(target)*0.95 <= float64(val)) + assert.True(t, float64(val) <= float64(target)*1.05) + } +} + +func TestUnstable_AroundIntLarge(t *testing.T) { + const target int64 = 10000 + unstable := NewUnstable(5) + for i := 0; i < 1000; i++ { + val := unstable.AroundInt(target) + assert.True(t, 0 <= val) + assert.True(t, val <= 2*target) + } +} + +func TestUnstable_AroundIntNegative(t *testing.T) { + const target int64 = 10000 + unstable := NewUnstable(-0.05) + for i := 0; i < 1000; i++ { + val := unstable.AroundInt(target) + assert.Equal(t, target, val) + } +} + +func TestUnstable_Distribution(t *testing.T) { + const ( + seconds = 10000 + total = 10000 + ) + + m := make(map[int]int) + expiry := NewUnstable(0.05) + for i := 0; i < total; i++ { + val := int(expiry.AroundInt(seconds)) + m[val]++ + } + + _, ok := m[0] + assert.False(t, ok) + + mi := make(map[interface{}]int, len(m)) + for k, v := range m { + mi[k] = v + } + entropy := 
CalcEntropy(mi, total) + assert.True(t, len(m) > 1) + assert.True(t, entropy > 0.95) +} diff --git a/core/metric/counter.go b/core/metric/counter.go new file mode 100644 index 00000000..33cf5b9c --- /dev/null +++ b/core/metric/counter.go @@ -0,0 +1,55 @@ +package metric + +import ( + "zero/core/proc" + + prom "github.com/prometheus/client_golang/prometheus" +) + +type ( + CounterVecOpts VectorOpts + + CounterVec interface { + Inc(lables ...string) + Add(v float64, labels ...string) + close() bool + } + + promCounterVec struct { + counter *prom.CounterVec + } +) + +func NewCounterVec(cfg *CounterVecOpts) CounterVec { + if cfg == nil { + return nil + } + + vec := prom.NewCounterVec(prom.CounterOpts{ + Namespace: cfg.Namespace, + Subsystem: cfg.Subsystem, + Name: cfg.Name, + Help: cfg.Help, + }, cfg.Labels) + prom.MustRegister(vec) + cv := &promCounterVec{ + counter: vec, + } + proc.AddShutdownListener(func() { + cv.close() + }) + + return cv +} + +func (cv *promCounterVec) Inc(labels ...string) { + cv.counter.WithLabelValues(labels...).Inc() +} + +func (cv *promCounterVec) Add(v float64, lables ...string) { + cv.counter.WithLabelValues(lables...).Add(v) +} + +func (cv *promCounterVec) close() bool { + return prom.Unregister(cv.counter) +} diff --git a/core/metric/counter_test.go b/core/metric/counter_test.go new file mode 100644 index 00000000..d8612fd6 --- /dev/null +++ b/core/metric/counter_test.go @@ -0,0 +1,53 @@ +package metric + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" +) + +func TestNewCounterVec(t *testing.T) { + counterVec := NewCounterVec(&CounterVecOpts{ + Namespace: "http_server", + Subsystem: "requests", + Name: "total", + Help: "rpc client requests error count.", + }) + defer counterVec.close() + counterVecNil := NewCounterVec(nil) + assert.NotNil(t, counterVec) + assert.Nil(t, counterVecNil) +} + +func TestCounterIncr(t *testing.T) { + counterVec := NewCounterVec(&CounterVecOpts{ + Namespace: "http_client", + Subsystem: "call", + Name: "code_total", + Help: "http client requests error count.", + Labels: []string{"path", "code"}, + }) + defer counterVec.close() + cv, _ := counterVec.(*promCounterVec) + cv.Inc("/Users", "500") + cv.Inc("/Users", "500") + r := testutil.ToFloat64(cv.counter) + assert.Equal(t, float64(2), r) +} + +func TestCounterAdd(t *testing.T) { + counterVec := NewCounterVec(&CounterVecOpts{ + Namespace: "rpc_server", + Subsystem: "requests", + Name: "err_total", + Help: "rpc client requests error count.", + Labels: []string{"method", "code"}, + }) + defer counterVec.close() + cv, _ := counterVec.(*promCounterVec) + cv.Add(11, "/Users", "500") + cv.Add(22, "/Users", "500") + r := testutil.ToFloat64(cv.counter) + assert.Equal(t, float64(33), r) +} diff --git a/core/metric/gauge.go b/core/metric/gauge.go new file mode 100644 index 00000000..3951a782 --- /dev/null +++ b/core/metric/gauge.go @@ -0,0 +1,61 @@ +package metric + +import ( + "zero/core/proc" + + prom "github.com/prometheus/client_golang/prometheus" +) + +type ( + GaugeVecOpts VectorOpts + + GuageVec interface { + Set(v float64, labels ...string) + Inc(labels ...string) + Add(v float64, labels ...string) + close() bool + } + + promGuageVec struct { + gauge *prom.GaugeVec + } +) + +func NewGaugeVec(cfg *GaugeVecOpts) GuageVec { + if cfg == nil { + return nil + } + + vec := prom.NewGaugeVec( + prom.GaugeOpts{ + Namespace: cfg.Namespace, + Subsystem: cfg.Subsystem, + Name: cfg.Name, + Help: cfg.Help, + }, cfg.Labels) + 
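+	// register with the default Prometheus registry; MustRegister panics if a collector with the same descriptor is already registered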
prom.MustRegister(vec) + gv := &promGuageVec{ + gauge: vec, + } + proc.AddShutdownListener(func() { + gv.close() + }) + + return gv +} + +func (gv *promGuageVec) Inc(labels ...string) { + gv.gauge.WithLabelValues(labels...).Inc() +} + +func (gv *promGuageVec) Add(v float64, lables ...string) { + gv.gauge.WithLabelValues(lables...).Add(v) +} + +func (gv *promGuageVec) Set(v float64, lables ...string) { + gv.gauge.WithLabelValues(lables...).Set(v) +} + +func (gv *promGuageVec) close() bool { + return prom.Unregister(gv.gauge) +} diff --git a/core/metric/gauge_test.go b/core/metric/gauge_test.go new file mode 100644 index 00000000..3a504c15 --- /dev/null +++ b/core/metric/gauge_test.go @@ -0,0 +1,68 @@ +package metric + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" +) + +func TestNewGaugeVec(t *testing.T) { + gaugeVec := NewGaugeVec(&GaugeVecOpts{ + Namespace: "rpc_server", + Subsystem: "requests", + Name: "duration", + Help: "rpc server requests duration(ms).", + }) + defer gaugeVec.close() + gaugeVecNil := NewGaugeVec(nil) + assert.NotNil(t, gaugeVec) + assert.Nil(t, gaugeVecNil) +} + +func TestGaugeInc(t *testing.T) { + gaugeVec := NewGaugeVec(&GaugeVecOpts{ + Namespace: "rpc_client2", + Subsystem: "requests", + Name: "duration_ms", + Help: "rpc server requests duration(ms).", + Labels: []string{"path"}, + }) + defer gaugeVec.close() + gv, _ := gaugeVec.(*promGuageVec) + gv.Inc("/users") + gv.Inc("/users") + r := testutil.ToFloat64(gv.gauge) + assert.Equal(t, float64(2), r) +} + +func TestGaugeAdd(t *testing.T) { + gaugeVec := NewGaugeVec(&GaugeVecOpts{ + Namespace: "rpc_client", + Subsystem: "request", + Name: "duration_ms", + Help: "rpc server requests duration(ms).", + Labels: []string{"path"}, + }) + defer gaugeVec.close() + gv, _ := gaugeVec.(*promGuageVec) + gv.Add(-10, "/classroom") + gv.Add(30, "/classroom") + r := testutil.ToFloat64(gv.gauge) + assert.Equal(t, float64(20), r) +} + +func TestGaugeSet(t *testing.T) { + gaugeVec := NewGaugeVec(&GaugeVecOpts{ + Namespace: "http_client", + Subsystem: "request", + Name: "duration_ms", + Help: "rpc server requests duration(ms).", + Labels: []string{"path"}, + }) + gaugeVec.close() + gv, _ := gaugeVec.(*promGuageVec) + gv.Set(666, "/users") + r := testutil.ToFloat64(gv.gauge) + assert.Equal(t, float64(666), r) +} diff --git a/core/metric/histogram.go b/core/metric/histogram.go new file mode 100644 index 00000000..5fc0b820 --- /dev/null +++ b/core/metric/histogram.go @@ -0,0 +1,58 @@ +package metric + +import ( + "zero/core/proc" + + prom "github.com/prometheus/client_golang/prometheus" +) + +type ( + HistogramVecOpts struct { + Namespace string + Subsystem string + Name string + Help string + Labels []string + Buckets []float64 + } + + HistogramVec interface { + Observe(v int64, lables ...string) + close() bool + } + + promHistogramVec struct { + histogram *prom.HistogramVec + } +) + +func NewHistogramVec(cfg *HistogramVecOpts) HistogramVec { + if cfg == nil { + return nil + } + + vec := prom.NewHistogramVec(prom.HistogramOpts{ + Namespace: cfg.Namespace, + Subsystem: cfg.Subsystem, + Name: cfg.Name, + Help: cfg.Help, + Buckets: cfg.Buckets, + }, cfg.Labels) + prom.MustRegister(vec) + hv := &promHistogramVec{ + histogram: vec, + } + proc.AddShutdownListener(func() { + hv.close() + }) + + return hv +} + +func (hv *promHistogramVec) Observe(v int64, labels ...string) { + hv.histogram.WithLabelValues(labels...).Observe(float64(v)) +} + +func (hv *promHistogramVec) 
close() bool { + return prom.Unregister(hv.histogram) +} diff --git a/core/metric/histogram_test.go b/core/metric/histogram_test.go new file mode 100644 index 00000000..e8328649 --- /dev/null +++ b/core/metric/histogram_test.go @@ -0,0 +1,49 @@ +package metric + +import ( + "strings" + "testing" + + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" +) + +func TestNewHistogramVec(t *testing.T) { + histogramVec := NewHistogramVec(&HistogramVecOpts{ + Name: "duration_ms", + Help: "rpc server requests duration(ms).", + Buckets: []float64{1, 2, 3}, + }) + defer histogramVec.close() + histogramVecNil := NewHistogramVec(nil) + assert.NotNil(t, histogramVec) + assert.Nil(t, histogramVecNil) +} + +func TestHistogramObserve(t *testing.T) { + histogramVec := NewHistogramVec(&HistogramVecOpts{ + Name: "counts", + Help: "rpc server requests duration(ms).", + Buckets: []float64{1, 2, 3}, + Labels: []string{"method"}, + }) + defer histogramVec.close() + hv, _ := histogramVec.(*promHistogramVec) + hv.Observe(2, "/Users") + + metadata := ` + # HELP counts rpc server requests duration(ms). + # TYPE counts histogram +` + val := ` + counts_bucket{method="/Users",le="1"} 0 + counts_bucket{method="/Users",le="2"} 1 + counts_bucket{method="/Users",le="3"} 1 + counts_bucket{method="/Users",le="+Inf"} 1 + counts_sum{method="/Users"} 2 + counts_count{method="/Users"} 1 +` + + err := testutil.CollectAndCompare(hv.histogram, strings.NewReader(metadata+val)) + assert.Nil(t, err) +} diff --git a/core/metric/metric.go b/core/metric/metric.go new file mode 100644 index 00000000..0577f309 --- /dev/null +++ b/core/metric/metric.go @@ -0,0 +1,10 @@ +package metric + +// VectorOpts general configuration +type VectorOpts struct { + Namespace string + Subsystem string + Name string + Help string + Labels []string +} diff --git a/core/naming/namer.go b/core/naming/namer.go new file mode 100644 index 00000000..a1dd53f7 --- /dev/null +++ b/core/naming/namer.go @@ -0,0 +1,5 @@ +package naming + +type Namer interface { + Name() string +} diff --git a/core/netx/ip.go b/core/netx/ip.go new file mode 100644 index 00000000..172e246e --- /dev/null +++ b/core/netx/ip.go @@ -0,0 +1,39 @@ +package netx + +import "net" + +func InternalIp() string { + infs, err := net.Interfaces() + if err != nil { + return "" + } + + for _, inf := range infs { + if isEthDown(inf.Flags) || isLoopback(inf.Flags) { + continue + } + + addrs, err := inf.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + return ipnet.IP.String() + } + } + } + } + + return "" +} + +func isEthDown(f net.Flags) bool { + return f&net.FlagUp != net.FlagUp +} + +func isLoopback(f net.Flags) bool { + return f&net.FlagLoopback == net.FlagLoopback +} diff --git a/core/netx/ip_test.go b/core/netx/ip_test.go new file mode 100644 index 00000000..3d13ea52 --- /dev/null +++ b/core/netx/ip_test.go @@ -0,0 +1,11 @@ +package netx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInternalIp(t *testing.T) { + assert.True(t, len(InternalIp()) > 0) +} diff --git a/core/proc/env.go b/core/proc/env.go new file mode 100644 index 00000000..f33dad04 --- /dev/null +++ b/core/proc/env.go @@ -0,0 +1,43 @@ +package proc + +import ( + "os" + "strconv" + "sync" +) + +var ( + envs = make(map[string]string) + envLock sync.RWMutex +) + +func Env(name string) string { + envLock.RLock() + val, ok := envs[name] + 
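+	// serve from the process-local cache when possible; a miss falls through to os.Getenv and the result is memoized below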
envLock.RUnlock() + + if ok { + return val + } + + val = os.Getenv(name) + envLock.Lock() + envs[name] = val + envLock.Unlock() + + return val +} + +func EnvInt(name string) (int, bool) { + val := Env(name) + if len(val) == 0 { + return 0, false + } + + n, err := strconv.Atoi(val) + if err != nil { + return 0, false + } + + return n, true +} diff --git a/core/proc/env_test.go b/core/proc/env_test.go new file mode 100644 index 00000000..1187ff25 --- /dev/null +++ b/core/proc/env_test.go @@ -0,0 +1,34 @@ +package proc + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEnv(t *testing.T) { + assert.True(t, len(Env("any")) == 0) + envLock.RLock() + val, ok := envs["any"] + envLock.RUnlock() + assert.True(t, len(val) == 0) + assert.True(t, ok) + assert.True(t, len(Env("any")) == 0) +} + +func TestEnvInt(t *testing.T) { + val, ok := EnvInt("any") + assert.Equal(t, 0, val) + assert.False(t, ok) + err := os.Setenv("anyInt", "10") + assert.Nil(t, err) + val, ok = EnvInt("anyInt") + assert.Equal(t, 10, val) + assert.True(t, ok) + err = os.Setenv("anyString", "a") + assert.Nil(t, err) + val, ok = EnvInt("anyString") + assert.Equal(t, 0, val) + assert.False(t, ok) +} diff --git a/core/proc/goroutines+polyfill.go b/core/proc/goroutines+polyfill.go new file mode 100644 index 00000000..b9bb1edc --- /dev/null +++ b/core/proc/goroutines+polyfill.go @@ -0,0 +1,6 @@ +// +build windows + +package proc + +func dumpGoroutines() { +} diff --git a/core/proc/goroutines.go b/core/proc/goroutines.go new file mode 100644 index 00000000..97afe752 --- /dev/null +++ b/core/proc/goroutines.go @@ -0,0 +1,35 @@ +// +build linux darwin + +package proc + +import ( + "fmt" + "os" + "path" + "runtime/pprof" + "syscall" + "time" + + "zero/core/logx" +) + +const ( + goroutineProfile = "goroutine" + debugLevel = 2 +) + +func dumpGoroutines() { + command := path.Base(os.Args[0]) + pid := syscall.Getpid() + dumpFile := path.Join(os.TempDir(), fmt.Sprintf("%s-%d-goroutines-%s.dump", + command, pid, time.Now().Format(timeFormat))) + + logx.Infof("Got dump goroutine signal, printing goroutine profile to %s", dumpFile) + + if f, err := os.Create(dumpFile); err != nil { + logx.Errorf("Failed to dump goroutine profile, error: %v", err) + } else { + defer f.Close() + pprof.Lookup(goroutineProfile).WriteTo(f, debugLevel) + } +} diff --git a/core/proc/process.go b/core/proc/process.go new file mode 100644 index 00000000..15a9eb4c --- /dev/null +++ b/core/proc/process.go @@ -0,0 +1,24 @@ +package proc + +import ( + "os" + "path/filepath" +) + +var ( + procName string + pid int +) + +func init() { + procName = filepath.Base(os.Args[0]) + pid = os.Getpid() +} + +func Pid() int { + return pid +} + +func ProcessName() string { + return procName +} diff --git a/core/proc/process_test.go b/core/proc/process_test.go new file mode 100644 index 00000000..cb1fca2c --- /dev/null +++ b/core/proc/process_test.go @@ -0,0 +1,15 @@ +package proc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProcessName(t *testing.T) { + assert.True(t, len(ProcessName()) > 0) +} + +func TestPid(t *testing.T) { + assert.True(t, Pid() > 0) +} diff --git a/core/proc/profile+polyfill.go b/core/proc/profile+polyfill.go new file mode 100644 index 00000000..6b3c2773 --- /dev/null +++ b/core/proc/profile+polyfill.go @@ -0,0 +1,7 @@ +// +build windows + +package proc + +func StartProfile() Stopper { + return noopStopper +} diff --git a/core/proc/profile.go b/core/proc/profile.go new file mode 100644 index 
00000000..05b707b8 --- /dev/null +++ b/core/proc/profile.go @@ -0,0 +1,205 @@ +// +build linux darwin + +package proc + +import ( + "fmt" + "os" + "os/signal" + "path" + "runtime" + "runtime/pprof" + "runtime/trace" + "sync/atomic" + "syscall" + "time" + + "zero/core/logx" +) + +// DefaultMemProfileRate is the default memory profiling rate. +// See also http://golang.org/pkg/runtime/#pkg-variables +const DefaultMemProfileRate = 4096 + +// started is non zero if a profile is running. +var started uint32 + +// Profile represents an active profiling session. +type Profile struct { + // path holds the base path where various profiling files are written. + // If blank, the base path will be generated by ioutil.TempDir. + path string + + // closers holds cleanup functions that run after each profile + closers []func() + + // stopped records if a call to profile.Stop has been made + stopped uint32 +} + +func (p *Profile) close() { + for _, closer := range p.closers { + closer() + } +} + +func (p *Profile) startBlockProfile() { + fn := createDumpFile("block") + f, err := os.Create(fn) + if err != nil { + logx.Errorf("profile: could not create block profile %q: %v", fn, err) + return + } + + runtime.SetBlockProfileRate(1) + logx.Infof("profile: block profiling enabled, %s", fn) + p.closers = append(p.closers, func() { + pprof.Lookup("block").WriteTo(f, 0) + f.Close() + runtime.SetBlockProfileRate(0) + logx.Infof("profile: block profiling disabled, %s", fn) + }) +} + +func (p *Profile) startCpuProfile() { + fn := createDumpFile("cpu") + f, err := os.Create(fn) + if err != nil { + logx.Errorf("profile: could not create cpu profile %q: %v", fn, err) + return + } + + logx.Infof("profile: cpu profiling enabled, %s", fn) + pprof.StartCPUProfile(f) + p.closers = append(p.closers, func() { + pprof.StopCPUProfile() + f.Close() + logx.Infof("profile: cpu profiling disabled, %s", fn) + }) +} + +func (p *Profile) startMemProfile() { + fn := createDumpFile("mem") + f, err := os.Create(fn) + if err != nil { + logx.Errorf("profile: could not create memory profile %q: %v", fn, err) + return + } + + old := runtime.MemProfileRate + runtime.MemProfileRate = DefaultMemProfileRate + logx.Infof("profile: memory profiling enabled (rate %d), %s", runtime.MemProfileRate, fn) + p.closers = append(p.closers, func() { + pprof.Lookup("heap").WriteTo(f, 0) + f.Close() + runtime.MemProfileRate = old + logx.Infof("profile: memory profiling disabled, %s", fn) + }) +} + +func (p *Profile) startMutexProfile() { + fn := createDumpFile("mutex") + f, err := os.Create(fn) + if err != nil { + logx.Errorf("profile: could not create mutex profile %q: %v", fn, err) + return + } + + runtime.SetMutexProfileFraction(1) + logx.Infof("profile: mutex profiling enabled, %s", fn) + p.closers = append(p.closers, func() { + if mp := pprof.Lookup("mutex"); mp != nil { + mp.WriteTo(f, 0) + } + f.Close() + runtime.SetMutexProfileFraction(0) + logx.Infof("profile: mutex profiling disabled, %s", fn) + }) +} + +func (p *Profile) startThreadCreateProfile() { + fn := createDumpFile("threadcreate") + f, err := os.Create(fn) + if err != nil { + logx.Errorf("profile: could not create threadcreate profile %q: %v", fn, err) + return + } + + logx.Infof("profile: threadcreate profiling enabled, %s", fn) + p.closers = append(p.closers, func() { + if mp := pprof.Lookup("threadcreate"); mp != nil { + mp.WriteTo(f, 0) + } + f.Close() + logx.Infof("profile: threadcreate profiling disabled, %s", fn) + }) +} + +func (p *Profile) startTraceProfile() { + fn := 
createDumpFile("trace") + f, err := os.Create(fn) + if err != nil { + logx.Errorf("profile: could not create trace output file %q: %v", fn, err) + return + } + + if err := trace.Start(f); err != nil { + logx.Errorf("profile: could not start trace: %v", err) + return + } + + logx.Infof("profile: trace enabled, %s", fn) + p.closers = append(p.closers, func() { + trace.Stop() + logx.Infof("profile: trace disabled, %s", fn) + }) +} + +// Stop stops the profile and flushes any unwritten data. +func (p *Profile) Stop() { + if !atomic.CompareAndSwapUint32(&p.stopped, 0, 1) { + // someone has already called close + return + } + p.close() + atomic.StoreUint32(&started, 0) +} + +// Start starts a new profiling session. +// The caller should call the Stop method on the value returned +// to cleanly stop profiling. +func StartProfile() Stopper { + if !atomic.CompareAndSwapUint32(&started, 0, 1) { + logx.Error("profile: Start() already called") + return noopStopper + } + + var prof Profile + prof.startCpuProfile() + prof.startMemProfile() + prof.startMutexProfile() + prof.startBlockProfile() + prof.startTraceProfile() + prof.startThreadCreateProfile() + + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGINT) + <-c + + logx.Info("profile: caught interrupt, stopping profiles") + prof.Stop() + + signal.Reset() + syscall.Kill(os.Getpid(), syscall.SIGINT) + }() + + return &prof +} + +func createDumpFile(kind string) string { + command := path.Base(os.Args[0]) + pid := syscall.Getpid() + return path.Join(os.TempDir(), fmt.Sprintf("%s-%d-%s-%s.pprof", + command, pid, kind, time.Now().Format(timeFormat))) +} diff --git a/core/proc/shutdown+polyfill.go b/core/proc/shutdown+polyfill.go new file mode 100644 index 00000000..f853dd5d --- /dev/null +++ b/core/proc/shutdown+polyfill.go @@ -0,0 +1,16 @@ +// +build windows + +package proc + +import "time" + +func AddShutdownListener(fn func()) func() { + return nil +} + +func AddWrapUpListener(fn func()) func() { + return nil +} + +func SetTimeoutToForceQuit(duration time.Duration) { +} diff --git a/core/proc/shutdown.go b/core/proc/shutdown.go new file mode 100644 index 00000000..dfad0536 --- /dev/null +++ b/core/proc/shutdown.go @@ -0,0 +1,81 @@ +// +build linux darwin + +package proc + +import ( + "os" + "os/signal" + "sync" + "syscall" + "time" + + "zero/core/logx" +) + +const ( + wrapUpTime = time.Second + // why we use 5500 milliseconds is because most of our queue are blocking mode with 5 seconds + waitTime = 5500 * time.Millisecond +) + +var ( + wrapUpListeners = new(listenerManager) + shutdownListeners = new(listenerManager) + delayTimeBeforeForceQuit = waitTime +) + +func AddShutdownListener(fn func()) (waitForCalled func()) { + return shutdownListeners.addListener(fn) +} + +func AddWrapUpListener(fn func()) (waitForCalled func()) { + return wrapUpListeners.addListener(fn) +} + +func SetTimeoutToForceQuit(duration time.Duration) { + delayTimeBeforeForceQuit = duration +} + +func gracefulStop(signals chan os.Signal) { + signal.Stop(signals) + + logx.Info("Got signal SIGTERM, shutting down...") + wrapUpListeners.notifyListeners() + + time.Sleep(wrapUpTime) + shutdownListeners.notifyListeners() + + time.Sleep(delayTimeBeforeForceQuit - wrapUpTime) + logx.Infof("Still alive after %v, going to force kill the process...", delayTimeBeforeForceQuit) + syscall.Kill(syscall.Getpid(), syscall.SIGTERM) +} + +type listenerManager struct { + lock sync.Mutex + waitGroup sync.WaitGroup + listeners []func() +} + +func (lm *listenerManager) 
addListener(fn func()) (waitForCalled func()) { + lm.waitGroup.Add(1) + + lm.lock.Lock() + lm.listeners = append(lm.listeners, func() { + defer lm.waitGroup.Done() + fn() + }) + lm.lock.Unlock() + + return func() { + lm.waitGroup.Wait() + } +} + +func (lm *listenerManager) notifyListeners() { + lm.lock.Lock() + defer lm.lock.Unlock() + + for _, listener := range lm.listeners { + listener() + } +} diff --git a/core/proc/signals.go b/core/proc/signals.go new file mode 100644 index 00000000..542ef052 --- /dev/null +++ b/core/proc/signals.go @@ -0,0 +1,42 @@ +// +build linux darwin + +package proc + +import ( + "os" + "os/signal" + "syscall" + + "zero/core/logx" +) + +const timeFormat = "0102150405" + +func init() { + go func() { + var profiler Stopper + + // https://golang.org/pkg/os/signal/#Notify + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGUSR1, syscall.SIGUSR2, syscall.SIGTERM) + + for { + v := <-signals + switch v { + case syscall.SIGUSR1: + dumpGoroutines() + case syscall.SIGUSR2: + if profiler == nil { + profiler = StartProfile() + } else { + profiler.Stop() + profiler = nil + } + case syscall.SIGTERM: + gracefulStop(signals) + default: + logx.Error("Got unregistered signal:", v) + } + } + }() +} diff --git a/core/proc/stopper.go b/core/proc/stopper.go new file mode 100644 index 00000000..1e05f015 --- /dev/null +++ b/core/proc/stopper.go @@ -0,0 +1,14 @@ +package proc + +var noopStopper nilStopper + +type ( + Stopper interface { + Stop() + } + + nilStopper struct{} +) + +func (ns nilStopper) Stop() { +} diff --git a/core/prof/profilecenter.go b/core/prof/profilecenter.go new file mode 100644 index 00000000..6de7cbd8 --- /dev/null +++ b/core/prof/profilecenter.go @@ -0,0 +1,118 @@ +package prof + +import ( + "bytes" + "strconv" + "sync" + "sync/atomic" + "time" + + "zero/core/logx" + "zero/core/threading" + + "github.com/olekukonko/tablewriter" +) + +type ( + Slot struct { + lifecount int64 + lastcount int64 + lifecycle int64 + lastcycle int64 + } + + ProfileCenter struct { + lock sync.RWMutex + slots map[string]*Slot + } +) + +const flushInterval = 5 * time.Minute + +var ( + profileCenter = &ProfileCenter{ + slots: make(map[string]*Slot), + } + once sync.Once +) + +func report(name string, duration time.Duration) { + updated := func() bool { + profileCenter.lock.RLock() + defer profileCenter.lock.RUnlock() + + slot, ok := profileCenter.slots[name] + if ok { + atomic.AddInt64(&slot.lifecount, 1) + atomic.AddInt64(&slot.lastcount, 1) + atomic.AddInt64(&slot.lifecycle, int64(duration)) + atomic.AddInt64(&slot.lastcycle, int64(duration)) + } + return ok + }() + + if !updated { + func() { + profileCenter.lock.Lock() + defer profileCenter.lock.Unlock() + + profileCenter.slots[name] = &Slot{ + lifecount: 1, + lastcount: 1, + lifecycle: int64(duration), + lastcycle: int64(duration), + } + }() + } + + once.Do(flushRepeatly) +} + +func flushRepeatly() { + threading.GoSafe(func() { + for { + time.Sleep(flushInterval) + logx.Stat(generateReport()) + } + }) +} + +func generateReport() string { + var buffer bytes.Buffer + buffer.WriteString("Profiling report\n") + var data [][]string + calcFn := func(total, count int64) string { + if count == 0 { + return "-" + } else { + return (time.Duration(total) / time.Duration(count)).String() + } + } + + func() { + profileCenter.lock.Lock() + defer profileCenter.lock.Unlock() + + for key, slot := range profileCenter.slots { + data = append(data, []string{ + key, + strconv.FormatInt(slot.lifecount, 10), + calcFn(slot.lifecycle, 
slot.lifecount), + strconv.FormatInt(slot.lastcount, 10), + calcFn(slot.lastcycle, slot.lastcount), + }) + + // reset the data for last cycle + slot.lastcount = 0 + slot.lastcycle = 0 + } + }() + + table := tablewriter.NewWriter(&buffer) + table.SetHeader([]string{"QUEUE", "LIFECOUNT", "LIFECYCLE", "LASTCOUNT", "LASTCYCLE"}) + table.SetBorder(false) + table.AppendBulk(data) + table.Render() + + return buffer.String() +} diff --git a/core/prof/profiler.go b/core/prof/profiler.go new file mode 100644 index 00000000..e4396991 --- /dev/null +++ b/core/prof/profiler.go @@ -0,0 +1,58 @@ +package prof + +import "zero/core/utils" + +type ( + ProfilePoint struct { + *utils.ElapsedTimer + } + + Profiler interface { + Start() ProfilePoint + Report(name string, point ProfilePoint) + } + + RealProfiler struct{} + + NullProfiler struct{} +) + +var profiler = newNullProfiler() + +func EnableProfiling() { + profiler = newRealProfiler() +} + +func Start() ProfilePoint { + return profiler.Start() +} + +func Report(name string, point ProfilePoint) { + profiler.Report(name, point) +} + +func newRealProfiler() Profiler { + return &RealProfiler{} +} + +func (rp *RealProfiler) Start() ProfilePoint { + return ProfilePoint{ + ElapsedTimer: utils.NewElapsedTimer(), + } +} + +func (rp *RealProfiler) Report(name string, point ProfilePoint) { + duration := point.Duration() + report(name, duration) +} + +func newNullProfiler() Profiler { + return &NullProfiler{} +} + +func (np *NullProfiler) Start() ProfilePoint { + return ProfilePoint{} +} + +func (np *NullProfiler) Report(string, ProfilePoint) { +} diff --git a/core/prometheus/agent.go b/core/prometheus/agent.go new file mode 100644 index 00000000..6e9ce65d --- /dev/null +++ b/core/prometheus/agent.go @@ -0,0 +1,31 @@ +package prometheus + +import ( + "fmt" + "net/http" + "sync" + + "zero/core/logx" + "zero/core/threading" + + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +var once sync.Once + +func StartAgent(c Config) { + once.Do(func() { + if len(c.Host) == 0 { + return + } + + threading.GoSafe(func() { + http.Handle(c.Path, promhttp.Handler()) + addr := fmt.Sprintf("%s:%d", c.Host, c.Port) + logx.Infof("Starting prometheus agent at %s", addr) + if err := http.ListenAndServe(addr, nil); err != nil { + logx.Error(err) + } + }) + }) +} diff --git a/core/prometheus/config.go b/core/prometheus/config.go new file mode 100644 index 00000000..29887bbd --- /dev/null +++ b/core/prometheus/config.go @@ -0,0 +1,7 @@ +package prometheus + +type Config struct { + Host string `json:",optional"` + Port int `json:",default=9101"` + Path string `json:",default=/metrics"` +} diff --git a/core/queue/balancedqueuepusher.go b/core/queue/balancedqueuepusher.go new file mode 100644 index 00000000..0c53c9e9 --- /dev/null +++ b/core/queue/balancedqueuepusher.go @@ -0,0 +1,44 @@ +package queue + +import ( + "errors" + "sync/atomic" + + "zero/core/logx" +) + +var ErrNoAvailablePusher = errors.New("no available pusher") + +type BalancedQueuePusher struct { + name string + pushers []QueuePusher + index uint64 +} + +func NewBalancedQueuePusher(pushers []QueuePusher) QueuePusher { + return &BalancedQueuePusher{ + name: generateName(pushers), + pushers: pushers, + } +} + +func (pusher *BalancedQueuePusher) Name() string { + return pusher.name +} + +func (pusher *BalancedQueuePusher) Push(message string) error { + size := len(pusher.pushers) + + for i := 0; i < size; i++ { + index := atomic.AddUint64(&pusher.index, 1) % uint64(size) + target := pusher.pushers[index] + + if err := 
target.Push(message); err != nil { + logx.Error(err) + } else { + return nil + } + } + + return ErrNoAvailablePusher +} diff --git a/core/queue/balancedqueuepusher_test.go b/core/queue/balancedqueuepusher_test.go new file mode 100644 index 00000000..89aaa3bb --- /dev/null +++ b/core/queue/balancedqueuepusher_test.go @@ -0,0 +1,43 @@ +package queue + +import ( + "fmt" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBalancedQueuePusher(t *testing.T) { + const numPushers = 100 + var pushers []QueuePusher + var mockedPushers []*mockedPusher + for i := 0; i < numPushers; i++ { + p := &mockedPusher{ + name: "pusher:" + strconv.Itoa(i), + } + pushers = append(pushers, p) + mockedPushers = append(mockedPushers, p) + } + + pusher := NewBalancedQueuePusher(pushers) + assert.True(t, len(pusher.Name()) > 0) + + for i := 0; i < numPushers*1000; i++ { + assert.Nil(t, pusher.Push("item")) + } + + var counts []int + for _, p := range mockedPushers { + counts = append(counts, p.count) + } + mean := calcMean(counts) + variance := calcVariance(mean, counts) + assert.True(t, variance < 100, fmt.Sprintf("too big variance - %.2f", variance)) +} + +func TestBalancedQueuePusher_NoAvailable(t *testing.T) { + pusher := NewBalancedQueuePusher(nil) + assert.True(t, len(pusher.Name()) == 0) + assert.Equal(t, ErrNoAvailablePusher, pusher.Push("item")) +} diff --git a/core/queue/consumer.go b/core/queue/consumer.go new file mode 100644 index 00000000..8f12d97f --- /dev/null +++ b/core/queue/consumer.go @@ -0,0 +1,10 @@ +package queue + +type ( + Consumer interface { + Consume(string) error + OnEvent(event interface{}) + } + + ConsumerFactory func() (Consumer, error) +) diff --git a/core/queue/messagequeue.go b/core/queue/messagequeue.go new file mode 100644 index 00000000..569ae566 --- /dev/null +++ b/core/queue/messagequeue.go @@ -0,0 +1,6 @@ +package queue + +type MessageQueue interface { + Start() + Stop() +} diff --git a/core/queue/multiqueuepusher.go b/core/queue/multiqueuepusher.go new file mode 100644 index 00000000..f57b98af --- /dev/null +++ b/core/queue/multiqueuepusher.go @@ -0,0 +1,31 @@ +package queue + +import "zero/core/errorx" + +type MultiQueuePusher struct { + name string + pushers []QueuePusher +} + +func NewMultiQueuePusher(pushers []QueuePusher) QueuePusher { + return &MultiQueuePusher{ + name: generateName(pushers), + pushers: pushers, + } +} + +func (pusher *MultiQueuePusher) Name() string { + return pusher.name +} + +func (pusher *MultiQueuePusher) Push(message string) error { + var batchError errorx.BatchError + + for _, each := range pusher.pushers { + if err := each.Push(message); err != nil { + batchError.Add(err) + } + } + + return batchError.Err() +} diff --git a/core/queue/multiqueuepusher_test.go b/core/queue/multiqueuepusher_test.go new file mode 100644 index 00000000..8af5200c --- /dev/null +++ b/core/queue/multiqueuepusher_test.go @@ -0,0 +1,39 @@ +package queue + +import ( + "fmt" + "math" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMultiQueuePusher(t *testing.T) { + const numPushers = 100 + var pushers []QueuePusher + var mockedPushers []*mockedPusher + for i := 0; i < numPushers; i++ { + p := &mockedPusher{ + name: "pusher:" + strconv.Itoa(i), + } + pushers = append(pushers, p) + mockedPushers = append(mockedPushers, p) + } + + pusher := NewMultiQueuePusher(pushers) + assert.True(t, len(pusher.Name()) > 0) + + for i := 0; i < 1000; i++ { + _ = pusher.Push("item") + } + + var counts []int + for _, p := range 
mockedPushers { + counts = append(counts, p.count) + } + mean := calcMean(counts) + variance := calcVariance(mean, counts) + assert.True(t, math.Abs(mean-1000*(1-failProba)) < 10) + assert.True(t, variance < 100, fmt.Sprintf("too big variance - %.2f", variance)) +} diff --git a/core/queue/producer.go b/core/queue/producer.go new file mode 100644 index 00000000..c0ca935d --- /dev/null +++ b/core/queue/producer.go @@ -0,0 +1,15 @@ +package queue + +type ( + Producer interface { + AddListener(listener ProduceListener) + Produce() (string, bool) + } + + ProduceListener interface { + OnProducerPause() + OnProducerResume() + } + + ProducerFactory func() (Producer, error) +) diff --git a/core/queue/queue.go b/core/queue/queue.go new file mode 100644 index 00000000..40905716 --- /dev/null +++ b/core/queue/queue.go @@ -0,0 +1,239 @@ +package queue + +import ( + "runtime" + "sync" + "sync/atomic" + "time" + + "zero/core/logx" + "zero/core/rescue" + "zero/core/stat" + "zero/core/threading" + "zero/core/timex" +) + +const queueName = "queue" + +type ( + Queue struct { + name string + metrics *stat.Metrics + producerFactory ProducerFactory + producerRoutineGroup *threading.RoutineGroup + consumerFactory ConsumerFactory + consumerRoutineGroup *threading.RoutineGroup + producerCount int + consumerCount int + active int32 + channel chan string + quit chan struct{} + listeners []QueueListener + eventLock sync.Mutex + eventChannels []chan interface{} + } + + QueueListener interface { + OnPause() + OnResume() + } + + QueuePoller interface { + Name() string + Poll() string + } + + QueuePusher interface { + Name() string + Push(string) error + } +) + +func NewQueue(producerFactory ProducerFactory, consumerFactory ConsumerFactory) *Queue { + queue := &Queue{ + metrics: stat.NewMetrics(queueName), + producerFactory: producerFactory, + producerRoutineGroup: threading.NewRoutineGroup(), + consumerFactory: consumerFactory, + consumerRoutineGroup: threading.NewRoutineGroup(), + producerCount: runtime.NumCPU(), + consumerCount: runtime.NumCPU() << 1, + channel: make(chan string), + quit: make(chan struct{}), + } + queue.SetName(queueName) + + return queue +} + +func (queue *Queue) AddListener(listener QueueListener) { + queue.listeners = append(queue.listeners, listener) +} + +func (queue *Queue) Broadcast(message interface{}) { + go func() { + queue.eventLock.Lock() + defer queue.eventLock.Unlock() + + for _, channel := range queue.eventChannels { + channel <- message + } + }() +} + +func (queue *Queue) SetName(name string) { + queue.name = name + queue.metrics.SetName(name) +} + +func (queue *Queue) SetNumConsumer(count int) { + queue.consumerCount = count +} + +func (queue *Queue) SetNumProducer(count int) { + queue.producerCount = count +} + +func (queue *Queue) Start() { + queue.startProducers(queue.producerCount) + queue.startConsumers(queue.consumerCount) + + queue.producerRoutineGroup.Wait() + close(queue.channel) + queue.consumerRoutineGroup.Wait() +} + +func (queue *Queue) Stop() { + close(queue.quit) +} + +func (queue *Queue) consume(eventChan chan interface{}) { + var consumer Consumer + + for { + var err error + if consumer, err = queue.consumerFactory(); err != nil { + logx.Errorf("Error on creating consumer: %v", err) + time.Sleep(time.Second) + } else { + break + } + } + + for { + select { + case message, ok := <-queue.channel: + if ok { + queue.consumeOne(consumer, message) + } else { + logx.Info("Task channel was closed, quitting consumer...") + return + } + case event := <-eventChan: + 
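+			// messages sent via Broadcast arrive on this consumer's event channel and are handed to OnEvent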
consumer.OnEvent(event) + } + } +} + +func (queue *Queue) consumeOne(consumer Consumer, message string) { + threading.RunSafe(func() { + startTime := timex.Now() + defer func() { + duration := timex.Since(startTime) + queue.metrics.Add(stat.Task{ + Duration: duration, + }) + logx.WithDuration(duration).Infof("%s", message) + }() + + if err := consumer.Consume(message); err != nil { + logx.Errorf("Error occurred while consuming %v: %v", message, err) + } + }) +} + +func (queue *Queue) pause() { + for _, listener := range queue.listeners { + listener.OnPause() + } +} + +func (queue *Queue) produce() { + var producer Producer + + for { + var err error + if producer, err = queue.producerFactory(); err != nil { + logx.Errorf("Error on creating producer: %v", err) + time.Sleep(time.Second) + } else { + break + } + } + + atomic.AddInt32(&queue.active, 1) + producer.AddListener(routineListener{ + queue: queue, + }) + + for { + select { + case <-queue.quit: + logx.Info("Quitting producer") + return + default: + if v, ok := queue.produceOne(producer); ok { + queue.channel <- v + } + } + } +} + +func (queue *Queue) produceOne(producer Producer) (string, bool) { + // avoid panic quit the producer, just log it and continue + defer rescue.Recover() + + return producer.Produce() +} + +func (queue *Queue) resume() { + for _, listener := range queue.listeners { + listener.OnResume() + } +} + +func (queue *Queue) startConsumers(number int) { + for i := 0; i < number; i++ { + eventChan := make(chan interface{}) + queue.eventLock.Lock() + queue.eventChannels = append(queue.eventChannels, eventChan) + queue.eventLock.Unlock() + queue.consumerRoutineGroup.Run(func() { + queue.consume(eventChan) + }) + } +} + +func (queue *Queue) startProducers(number int) { + for i := 0; i < number; i++ { + queue.producerRoutineGroup.Run(func() { + queue.produce() + }) + } +} + +type routineListener struct { + queue *Queue +} + +func (rl routineListener) OnProducerPause() { + if atomic.AddInt32(&rl.queue.active, -1) <= 0 { + rl.queue.pause() + } +} + +func (rl routineListener) OnProducerResume() { + if atomic.AddInt32(&rl.queue.active, 1) == 1 { + rl.queue.resume() + } +} diff --git a/core/queue/queue_test.go b/core/queue/queue_test.go new file mode 100644 index 00000000..32279094 --- /dev/null +++ b/core/queue/queue_test.go @@ -0,0 +1,94 @@ +package queue + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const ( + consumers = 4 + rounds = 100 +) + +func TestQueue(t *testing.T) { + producer := newMockedProducer(rounds) + consumer := newMockedConsumer() + consumer.wait.Add(consumers) + q := NewQueue(func() (Producer, error) { + return producer, nil + }, func() (Consumer, error) { + return consumer, nil + }) + q.AddListener(new(mockedListener)) + q.SetName("mockqueue") + q.SetNumConsumer(consumers) + q.SetNumProducer(1) + q.pause() + q.resume() + go func() { + producer.wait.Wait() + q.Stop() + }() + q.Start() + assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count)) +} + +type mockedConsumer struct { + count int32 + events int32 + wait sync.WaitGroup +} + +func newMockedConsumer() *mockedConsumer { + return new(mockedConsumer) +} + +func (c *mockedConsumer) Consume(string) error { + atomic.AddInt32(&c.count, 1) + return nil +} + +func (c *mockedConsumer) OnEvent(interface{}) { + if atomic.AddInt32(&c.events, 1) <= consumers { + c.wait.Done() + } +} + +type mockedProducer struct { + total int32 + count int32 + wait sync.WaitGroup +} + +func newMockedProducer(total 
int32) *mockedProducer { + p := new(mockedProducer) + p.total = total + p.wait.Add(int(total)) + return p +} + +func (p *mockedProducer) AddListener(listener ProduceListener) { +} + +func (p *mockedProducer) Produce() (string, bool) { + if atomic.AddInt32(&p.count, 1) <= p.total { + p.wait.Done() + return "item", true + } else { + time.Sleep(time.Second) + return "", false + } +} + +type mockedListener struct { +} + +func (l *mockedListener) OnPause() { +} + +func (l *mockedListener) OnResume() { +} diff --git a/core/queue/util.go b/core/queue/util.go new file mode 100644 index 00000000..f4377f12 --- /dev/null +++ b/core/queue/util.go @@ -0,0 +1,12 @@ +package queue + +import "strings" + +func generateName(pushers []QueuePusher) string { + names := make([]string, len(pushers)) + for i, pusher := range pushers { + names[i] = pusher.Name() + } + + return strings.Join(names, ",") +} diff --git a/core/queue/util_test.go b/core/queue/util_test.go new file mode 100644 index 00000000..6537866e --- /dev/null +++ b/core/queue/util_test.go @@ -0,0 +1,78 @@ +package queue + +import ( + "errors" + "math" + "testing" + + "zero/core/logx" + "zero/core/mathx" + + "github.com/stretchr/testify/assert" +) + +var ( + proba = mathx.NewProba() + failProba = 0.01 +) + +func init() { + logx.Disable() +} + +func TestGenerateName(t *testing.T) { + pushers := []QueuePusher{ + &mockedPusher{name: "first"}, + &mockedPusher{name: "second"}, + &mockedPusher{name: "third"}, + } + + assert.Equal(t, "first,second,third", generateName(pushers)) +} + +func TestGenerateNameNil(t *testing.T) { + var pushers []QueuePusher + assert.Equal(t, "", generateName(pushers)) +} + +func calcMean(vals []int) float64 { + if len(vals) == 0 { + return 0 + } + + var result float64 + for _, val := range vals { + result += float64(val) + } + return result / float64(len(vals)) +} + +func calcVariance(mean float64, vals []int) float64 { + if len(vals) == 0 { + return 0 + } + + var result float64 + for _, val := range vals { + result += math.Pow(float64(val)-mean, 2) + } + return result / float64(len(vals)) +} + +type mockedPusher struct { + name string + count int +} + +func (p *mockedPusher) Name() string { + return p.name +} + +func (p *mockedPusher) Push(s string) error { + if proba.TrueOnProba(failProba) { + return errors.New("dummy") + } + + p.count++ + return nil +} diff --git a/core/redisqueue/conf.go b/core/redisqueue/conf.go new file mode 100644 index 00000000..4661260c --- /dev/null +++ b/core/redisqueue/conf.go @@ -0,0 +1,19 @@ +package redisqueue + +import ( + "zero/core/queue" + "zero/core/stores/redis" +) + +type RedisKeyConf struct { + redis.RedisConf + Key string `json:",optional"` +} + +func (rkc RedisKeyConf) NewProducer(opts ...ProducerOption) (queue.Producer, error) { + return newProducer(rkc.NewRedis(), rkc.Key, opts...) +} + +func (rkc RedisKeyConf) NewPusher(opts ...PusherOption) queue.QueuePusher { + return NewPusher(rkc.NewRedis(), rkc.Key, opts...) 
+} diff --git a/core/redisqueue/message.go b/core/redisqueue/message.go new file mode 100644 index 00000000..75bbb9d0 --- /dev/null +++ b/core/redisqueue/message.go @@ -0,0 +1,6 @@ +package redisqueue + +type TimedMessage struct { + Time int64 `json:"time"` + Payload string `json:"payload"` +} diff --git a/core/redisqueue/redisqueue_test.go b/core/redisqueue/redisqueue_test.go new file mode 100644 index 00000000..fbce6a62 --- /dev/null +++ b/core/redisqueue/redisqueue_test.go @@ -0,0 +1,82 @@ +package redisqueue + +import ( + "strconv" + "sync" + "testing" + + "zero/core/logx" + "zero/core/queue" + "zero/core/stores/redis" + + "github.com/alicebob/miniredis" + "github.com/stretchr/testify/assert" +) + +func init() { + logx.Disable() +} + +func TestRedisQueue(t *testing.T) { + const ( + total = 1000 + key = "queue" + ) + r, err := miniredis.Run() + assert.Nil(t, err) + + c := RedisKeyConf{ + RedisConf: redis.RedisConf{ + Host: r.Addr(), + Type: redis.NodeType, + }, + Key: key, + } + + pusher := NewPusher(c.NewRedis(), key, WithTime()) + assert.True(t, len(pusher.Name()) > 0) + for i := 0; i < total; i++ { + err := pusher.Push(strconv.Itoa(i)) + assert.Nil(t, err) + } + + consumer := new(mockedConsumer) + consumer.wait.Add(total) + q := queue.NewQueue(func() (queue.Producer, error) { + return c.NewProducer(TimeSensitive(5)) + }, func() (queue.Consumer, error) { + return consumer, nil + }) + q.SetNumProducer(1) + q.SetNumConsumer(1) + go func() { + q.Start() + }() + consumer.wait.Wait() + q.Stop() + + var expect int + for i := 0; i < total; i++ { + expect ^= i + } + assert.Equal(t, expect, consumer.xor) +} + +type mockedConsumer struct { + wait sync.WaitGroup + xor int +} + +func (c *mockedConsumer) Consume(s string) error { + val, err := strconv.Atoi(s) + if err != nil { + return err + } + + c.xor ^= val + c.wait.Done() + return nil +} + +func (c *mockedConsumer) OnEvent(event interface{}) { +} diff --git a/core/redisqueue/redisqueueproducer.go b/core/redisqueue/redisqueueproducer.go new file mode 100644 index 00000000..6496353f --- /dev/null +++ b/core/redisqueue/redisqueueproducer.go @@ -0,0 +1,166 @@ +package redisqueue + +import ( + "fmt" + "sync" + "time" + + "zero/core/jsonx" + "zero/core/logx" + "zero/core/queue" + "zero/core/stores/redis" +) + +const ( + logIntervalMillis = 1000 + retryRedisInterval = time.Second +) + +type ( + ProducerOption func(p queue.Producer) queue.Producer + + RedisQueueProducer struct { + name string + store *redis.Redis + key string + redisNode redis.ClosableNode + listeners []queue.ProduceListener + } +) + +func NewProducerFactory(store *redis.Redis, key string, opts ...ProducerOption) queue.ProducerFactory { + return func() (queue.Producer, error) { + return newProducer(store, key, opts...) 
+ } +} + +func (p *RedisQueueProducer) AddListener(listener queue.ProduceListener) { + p.listeners = append(p.listeners, listener) +} + +func (p *RedisQueueProducer) Name() string { + return p.name +} + +func (p *RedisQueueProducer) Produce() (string, bool) { + lessLogger := logx.NewLessLogger(logIntervalMillis) + + for { + value, ok, err := p.store.BlpopEx(p.redisNode, p.key) + if err == nil { + return value, ok + } else if err == redis.Nil { + // timed out without elements popped + continue + } else { + lessLogger.Errorf("Error on blpop: %v", err) + p.waitForRedisAvailable() + } + } +} + +func newProducer(store *redis.Redis, key string, opts ...ProducerOption) (queue.Producer, error) { + redisNode, err := redis.CreateBlockingNode(store) + if err != nil { + return nil, err + } + + var producer queue.Producer = &RedisQueueProducer{ + name: fmt.Sprintf("%s/%s/%s", store.Type, store.Addr, key), + store: store, + key: key, + redisNode: redisNode, + } + + for _, opt := range opts { + producer = opt(producer) + } + + return producer, nil +} + +func (p *RedisQueueProducer) resetRedisConnection() error { + if p.redisNode != nil { + p.redisNode.Close() + p.redisNode = nil + } + + redisNode, err := redis.CreateBlockingNode(p.store) + if err != nil { + return err + } + + p.redisNode = redisNode + return nil +} + +func (p *RedisQueueProducer) waitForRedisAvailable() { + var paused bool + var pauseOnce sync.Once + + for { + if err := p.resetRedisConnection(); err != nil { + pauseOnce.Do(func() { + paused = true + for _, listener := range p.listeners { + listener.OnProducerPause() + } + }) + logx.Errorf("Error occurred while connect to redis: %v", err) + time.Sleep(retryRedisInterval) + } else { + break + } + } + + if paused { + for _, listener := range p.listeners { + listener.OnProducerResume() + } + } +} + +func TimeSensitive(seconds int64) ProducerOption { + return func(p queue.Producer) queue.Producer { + if seconds > 0 { + return autoDropQueueProducer{ + seconds: seconds, + producer: p, + } + } + + return p + } +} + +type autoDropQueueProducer struct { + seconds int64 // seconds before to drop + producer queue.Producer +} + +func (p autoDropQueueProducer) AddListener(listener queue.ProduceListener) { + p.producer.AddListener(listener) +} + +func (p autoDropQueueProducer) Produce() (string, bool) { + lessLogger := logx.NewLessLogger(logIntervalMillis) + + for { + content, ok := p.producer.Produce() + if !ok { + return "", false + } + + var timedMsg TimedMessage + if err := jsonx.UnmarshalFromString(content, &timedMsg); err != nil { + lessLogger.Errorf("invalid timedMessage: %s, error: %s", content, err.Error()) + continue + } + + if timedMsg.Time+p.seconds < time.Now().Unix() { + lessLogger.Errorf("expired timedMessage: %s", content) + } + + return timedMsg.Payload, true + } +} diff --git a/core/redisqueue/redisqueuepusher.go b/core/redisqueue/redisqueuepusher.go new file mode 100644 index 00000000..79008af9 --- /dev/null +++ b/core/redisqueue/redisqueuepusher.go @@ -0,0 +1,78 @@ +package redisqueue + +import ( + "fmt" + "time" + + "zero/core/jsonx" + "zero/core/logx" + "zero/core/queue" + "zero/core/stores/redis" +) + +type ( + PusherOption func(p queue.QueuePusher) queue.QueuePusher + + RedisQueuePusher struct { + name string + store *redis.Redis + key string + } +) + +func NewPusher(store *redis.Redis, key string, opts ...PusherOption) queue.QueuePusher { + var pusher queue.QueuePusher = &RedisQueuePusher{ + name: fmt.Sprintf("%s/%s/%s", store.Type, store.Addr, key), + store: store, + key: key, 
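+		// name takes the form "<node type>/<addr>/<key>", which generateName in core/queue joins when pushers are combined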
+ } + + for _, opt := range opts { + pusher = opt(pusher) + } + + return pusher +} + +func (saver *RedisQueuePusher) Name() string { + return saver.name +} + +func (saver *RedisQueuePusher) Push(message string) error { + _, err := saver.store.Rpush(saver.key, message) + if nil != err { + return err + } + + logx.Infof("<= %s", message) + return nil +} + +func WithTime() PusherOption { + return func(p queue.QueuePusher) queue.QueuePusher { + return timedQueuePusher{ + pusher: p, + } + } +} + +type timedQueuePusher struct { + pusher queue.QueuePusher +} + +func (p timedQueuePusher) Name() string { + return p.pusher.Name() +} + +func (p timedQueuePusher) Push(message string) error { + tm := TimedMessage{ + Time: time.Now().Unix(), + Payload: message, + } + + if content, err := jsonx.Marshal(tm); err != nil { + return err + } else { + return p.pusher.Push(string(content)) + } +} diff --git a/core/rescue/recover.go b/core/rescue/recover.go new file mode 100644 index 00000000..3d63a2a5 --- /dev/null +++ b/core/rescue/recover.go @@ -0,0 +1,13 @@ +package rescue + +import "zero/core/logx" + +func Recover(cleanups ...func()) { + for _, cleanup := range cleanups { + cleanup() + } + + if p := recover(); p != nil { + logx.ErrorStack(p) + } +} diff --git a/core/rescue/recover_test.go b/core/rescue/recover_test.go new file mode 100644 index 00000000..4316501d --- /dev/null +++ b/core/rescue/recover_test.go @@ -0,0 +1,28 @@ +package rescue + +import ( + "sync/atomic" + "testing" + + "zero/core/logx" + + "github.com/stretchr/testify/assert" +) + +func init() { + logx.Disable() +} + +func TestRescue(t *testing.T) { + var count int32 + assert.NotPanics(t, func() { + defer Recover(func() { + atomic.AddInt32(&count, 2) + }, func() { + atomic.AddInt32(&count, 3) + }) + + panic("hello") + }) + assert.Equal(t, int32(5), atomic.LoadInt32(&count)) +} diff --git a/core/rpc/chainclientinterceptors.go b/core/rpc/chainclientinterceptors.go new file mode 100644 index 00000000..d2f04c0e --- /dev/null +++ b/core/rpc/chainclientinterceptors.go @@ -0,0 +1,83 @@ +package rpc + +import ( + "context" + + "google.golang.org/grpc" +) + +func WithStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.DialOption { + return grpc.WithStreamInterceptor(chainStreamClientInterceptors(interceptors...)) +} + +func WithUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.DialOption { + return grpc.WithUnaryInterceptor(chainUnaryClientInterceptors(interceptors...)) +} + +func chainStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor { + switch len(interceptors) { + case 0: + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, + streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return streamer(ctx, desc, cc, method, opts...) + } + case 1: + return interceptors[0] + default: + last := len(interceptors) - 1 + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, + method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + var chainStreamer grpc.Streamer + var current int + + chainStreamer = func(curCtx context.Context, curDesc *grpc.StreamDesc, curCc *grpc.ClientConn, + curMethod string, curOpts ...grpc.CallOption) (grpc.ClientStream, error) { + if current == last { + return streamer(curCtx, curDesc, curCc, curMethod, curOpts...) 
+ } + + current++ + clientStream, err := interceptors[current](curCtx, curDesc, curCc, curMethod, chainStreamer, curOpts...) + current-- + + return clientStream, err + } + + return interceptors[0](ctx, desc, cc, method, chainStreamer, opts...) + } + } +} + +func chainUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor { + switch len(interceptors) { + case 0: + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + return invoker(ctx, method, req, reply, cc, opts...) + } + case 1: + return interceptors[0] + default: + last := len(interceptors) - 1 + return func(ctx context.Context, method string, req, reply interface{}, + cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + var chainInvoker grpc.UnaryInvoker + var current int + + chainInvoker = func(curCtx context.Context, curMethod string, curReq, curReply interface{}, + curCc *grpc.ClientConn, curOpts ...grpc.CallOption) error { + if current == last { + return invoker(curCtx, curMethod, curReq, curReply, curCc, curOpts...) + } + + current++ + err := interceptors[current](curCtx, curMethod, curReq, curReply, curCc, chainInvoker, curOpts...) + current-- + + return err + } + + return interceptors[0](ctx, method, req, reply, cc, chainInvoker, opts...) + } + } +} diff --git a/core/rpc/chainserverinterceptors.go b/core/rpc/chainserverinterceptors.go new file mode 100644 index 00000000..9c3eec99 --- /dev/null +++ b/core/rpc/chainserverinterceptors.go @@ -0,0 +1,81 @@ +package rpc + +import ( + "context" + + "google.golang.org/grpc" +) + +func WithStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption { + return grpc.StreamInterceptor(chainStreamServerInterceptors(interceptors...)) +} + +func WithUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption { + return grpc.UnaryInterceptor(chainUnaryServerInterceptors(interceptors...)) +} + +func chainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor { + switch len(interceptors) { + case 0: + return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, + handler grpc.StreamHandler) error { + return handler(srv, stream) + } + case 1: + return interceptors[0] + default: + last := len(interceptors) - 1 + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, + handler grpc.StreamHandler) error { + var chainHandler grpc.StreamHandler + var current int + + chainHandler = func(curSrv interface{}, curStream grpc.ServerStream) error { + if current == last { + return handler(curSrv, curStream) + } + + current++ + err := interceptors[current](curSrv, curStream, info, chainHandler) + current-- + + return err + } + + return interceptors[0](srv, stream, info, chainHandler) + } + } +} + +func chainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { + switch len(interceptors) { + case 0: + return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( + interface{}, error) { + return handler(ctx, req) + } + case 1: + return interceptors[0] + default: + last := len(interceptors) - 1 + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( + interface{}, error) { + var chainHandler grpc.UnaryHandler + var current int + + chainHandler = func(curCtx 
context.Context, curReq interface{}) (interface{}, error) { + if current == last { + return handler(curCtx, curReq) + } + + current++ + resp, err := interceptors[current](curCtx, curReq, info, chainHandler) + current-- + + return resp, err + } + + return interceptors[0](ctx, req, info, chainHandler) + } + } +} diff --git a/core/rpc/client.go b/core/rpc/client.go new file mode 100644 index 00000000..e88386aa --- /dev/null +++ b/core/rpc/client.go @@ -0,0 +1,71 @@ +package rpc + +import ( + "context" + "fmt" + "time" + + "zero/core/rpc/clientinterceptors" + + "google.golang.org/grpc" +) + +const dialTimeout = time.Second * 3 + +type ( + ClientOptions struct { + Timeout time.Duration + DialOptions []grpc.DialOption + } + + ClientOption func(options *ClientOptions) + + Client interface { + Next() (*grpc.ClientConn, bool) + } +) + +func WithDialOption(opt grpc.DialOption) ClientOption { + return func(options *ClientOptions) { + options.DialOptions = append(options.DialOptions, opt) + } +} + +func WithTimeout(timeout time.Duration) ClientOption { + return func(options *ClientOptions) { + options.Timeout = timeout + } +} + +func buildDialOptions(opts ...ClientOption) []grpc.DialOption { + var clientOptions ClientOptions + for _, opt := range opts { + opt(&clientOptions) + } + + options := []grpc.DialOption{ + grpc.WithInsecure(), + grpc.WithBlock(), + WithUnaryClientInterceptors( + clientinterceptors.BreakerInterceptor, + clientinterceptors.DurationInterceptor, + clientinterceptors.PromMetricInterceptor, + clientinterceptors.TimeoutInterceptor(clientOptions.Timeout), + clientinterceptors.TracingInterceptor, + ), + } + + return append(options, clientOptions.DialOptions...) +} + +func dial(server string, opts ...ClientOption) (*grpc.ClientConn, error) { + options := buildDialOptions(opts...) + timeCtx, cancel := context.WithTimeout(context.Background(), dialTimeout) + defer cancel() + conn, err := grpc.DialContext(timeCtx, server, options...) + if err != nil { + return nil, fmt.Errorf("rpc dial: %s, error: %s", server, err.Error()) + } + + return conn, nil +} diff --git a/core/rpc/clientinterceptors/breakerinterceptor.go b/core/rpc/clientinterceptors/breakerinterceptor.go new file mode 100644 index 00000000..43233520 --- /dev/null +++ b/core/rpc/clientinterceptors/breakerinterceptor.go @@ -0,0 +1,29 @@ +package clientinterceptors + +import ( + "context" + "path" + + "zero/core/breaker" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func acceptable(err error) bool { + switch status.Code(err) { + case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss: + return false + default: + return true + } +} + +func BreakerInterceptor(ctx context.Context, method string, req, reply interface{}, + cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + breakerName := path.Join(cc.Target(), method) + return breaker.DoWithAcceptable(breakerName, func() error { + return invoker(ctx, method, req, reply, cc, opts...) 
+ }, acceptable) +} diff --git a/core/rpc/clientinterceptors/breakerinterceptor_test.go b/core/rpc/clientinterceptors/breakerinterceptor_test.go new file mode 100644 index 00000000..b134dca6 --- /dev/null +++ b/core/rpc/clientinterceptors/breakerinterceptor_test.go @@ -0,0 +1,51 @@ +package clientinterceptors + +import ( + "testing" + + "zero/core/breaker" + "zero/core/stat" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func init() { + stat.SetReporter(nil) +} + +type mockError struct { + st *status.Status +} + +func (m mockError) GRPCStatus() *status.Status { + return m.st +} + +func (m mockError) Error() string { + return "mocked error" +} + +func TestBreakerInterceptorNotFound(t *testing.T) { + err := mockError{st: status.New(codes.NotFound, "any")} + for i := 0; i < 1000; i++ { + assert.Equal(t, err, breaker.DoWithAcceptable("call", func() error { + return err + }, acceptable)) + } +} + +func TestBreakerInterceptorDeadlineExceeded(t *testing.T) { + err := mockError{st: status.New(codes.DeadlineExceeded, "any")} + errs := make(map[error]int) + for i := 0; i < 1000; i++ { + e := breaker.DoWithAcceptable("call", func() error { + return err + }, acceptable) + errs[e]++ + } + assert.Equal(t, 2, len(errs)) + assert.True(t, errs[err] > 0) + assert.True(t, errs[breaker.ErrServiceUnavailable] > 0) +} diff --git a/core/rpc/clientinterceptors/durationinterceptor.go b/core/rpc/clientinterceptors/durationinterceptor.go new file mode 100644 index 00000000..4062ca57 --- /dev/null +++ b/core/rpc/clientinterceptors/durationinterceptor.go @@ -0,0 +1,31 @@ +package clientinterceptors + +import ( + "context" + "path" + "time" + + "zero/core/logx" + "zero/core/timex" + + "google.golang.org/grpc" +) + +const slowThreshold = time.Millisecond * 500 + +func DurationInterceptor(ctx context.Context, method string, req, reply interface{}, + cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + serverName := path.Join(cc.Target(), method) + start := timex.Now() + err := invoker(ctx, method, req, reply, cc, opts...) 
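+	// failed calls are logged below together with their elapsed time; successful
+	// calls are only logged as slow calls when they exceed slowThreshold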
+ if err != nil { + logx.WithDuration(timex.Since(start)).Infof("fail - %s - %v - %s", serverName, req, err.Error()) + } else { + elapsed := timex.Since(start) + if elapsed > slowThreshold { + logx.WithDuration(elapsed).Slowf("[RPC] ok - slowcall - %s - %v - %v", serverName, req, reply) + } + } + + return err +} diff --git a/core/rpc/clientinterceptors/prommetricinterceptor.go b/core/rpc/clientinterceptors/prommetricinterceptor.go new file mode 100644 index 00000000..43a2c160 --- /dev/null +++ b/core/rpc/clientinterceptors/prommetricinterceptor.go @@ -0,0 +1,43 @@ +package clientinterceptors + +import ( + "context" + "strconv" + "time" + + "zero/core/metric" + "zero/core/timex" + + "google.golang.org/grpc" + "google.golang.org/grpc/status" +) + +const clientNamespace = "rpc_client" + +var ( + metricClientReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{ + Namespace: clientNamespace, + Subsystem: "requests", + Name: "duration_ms", + Help: "rpc client requests duration(ms).", + Labels: []string{"method"}, + Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000}, + }) + + metricClientReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{ + Namespace: clientNamespace, + Subsystem: "requests", + Name: "code_total", + Help: "rpc client requests code count.", + Labels: []string{"method", "code"}, + }) +) + +func PromMetricInterceptor(ctx context.Context, method string, req, reply interface{}, + cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + startTime := timex.Now() + err := invoker(ctx, method, req, reply, cc, opts...) + metricClientReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), method) + metricClientReqCodeTotal.Inc(method, strconv.Itoa(int(status.Code(err)))) + return err +} diff --git a/core/rpc/clientinterceptors/timeoutinterceptor.go b/core/rpc/clientinterceptors/timeoutinterceptor.go new file mode 100644 index 00000000..27ff6dc6 --- /dev/null +++ b/core/rpc/clientinterceptors/timeoutinterceptor.go @@ -0,0 +1,25 @@ +package clientinterceptors + +import ( + "context" + "time" + + "zero/core/contextx" + + "google.golang.org/grpc" +) + +const defaultTimeout = time.Second * 2 + +func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor { + if timeout <= 0 { + timeout = defaultTimeout + } + + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx, cancel := contextx.ShrinkDeadline(ctx, timeout) + defer cancel() + return invoker(ctx, method, req, reply, cc, opts...) + } +} diff --git a/core/rpc/clientinterceptors/tracinginterceptor.go b/core/rpc/clientinterceptors/tracinginterceptor.go new file mode 100644 index 00000000..1d41d15a --- /dev/null +++ b/core/rpc/clientinterceptors/tracinginterceptor.go @@ -0,0 +1,25 @@ +package clientinterceptors + +import ( + "context" + + "zero/core/trace" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +func TracingInterceptor(ctx context.Context, method string, req, reply interface{}, + cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx, span := trace.StartClientSpan(ctx, cc.Target(), method) + defer span.Finish() + + var pairs []string + span.Visit(func(key, val string) bool { + pairs = append(pairs, key, val) + return true + }) + ctx = metadata.AppendToOutgoingContext(ctx, pairs...) + + return invoker(ctx, method, req, reply, cc, opts...) 
+} diff --git a/core/rpc/directclient.go b/core/rpc/directclient.go new file mode 100644 index 00000000..3302a077 --- /dev/null +++ b/core/rpc/directclient.go @@ -0,0 +1,32 @@ +package rpc + +import ( + "google.golang.org/grpc" + "google.golang.org/grpc/balancer/roundrobin" + "google.golang.org/grpc/connectivity" +) + +type DirectClient struct { + conn *grpc.ClientConn +} + +func NewDirectClient(server string, opts ...ClientOption) (*DirectClient, error) { + opts = append(opts, WithDialOption(grpc.WithBalancerName(roundrobin.Name))) + conn, err := dial(server, opts...) + if err != nil { + return nil, err + } + + return &DirectClient{ + conn: conn, + }, nil +} + +func (c *DirectClient) Next() (*grpc.ClientConn, bool) { + state := c.conn.GetState() + if state == connectivity.Ready { + return c.conn, true + } else { + return nil, false + } +} diff --git a/core/rpc/rpclogger.go b/core/rpc/rpclogger.go new file mode 100644 index 00000000..86e2fcab --- /dev/null +++ b/core/rpc/rpclogger.go @@ -0,0 +1,74 @@ +package rpc + +import ( + "sync" + + "zero/core/logx" + + "google.golang.org/grpc/grpclog" +) + +// because grpclog.errorLog is not exported, we need to define our own. +const errorLevel = 2 + +var once sync.Once + +type Logger struct{} + +func InitLogger() { + once.Do(func() { + grpclog.SetLoggerV2(new(Logger)) + }) +} + +func (l *Logger) Error(args ...interface{}) { + logx.Error(args...) +} + +func (l *Logger) Errorf(format string, args ...interface{}) { + logx.Errorf(format, args...) +} + +func (l *Logger) Errorln(args ...interface{}) { + logx.Error(args...) +} + +func (l *Logger) Fatal(args ...interface{}) { + logx.Error(args...) +} + +func (l *Logger) Fatalf(format string, args ...interface{}) { + logx.Errorf(format, args...) +} + +func (l *Logger) Fatalln(args ...interface{}) { + logx.Error(args...) 
+} + +func (l *Logger) Info(args ...interface{}) { + // ignore builtin grpc info +} + +func (l *Logger) Infoln(args ...interface{}) { + // ignore builtin grpc info +} + +func (l *Logger) Infof(format string, args ...interface{}) { + // ignore builtin grpc info +} + +func (l *Logger) V(v int) bool { + return v >= errorLevel +} + +func (l *Logger) Warning(args ...interface{}) { + // ignore builtin grpc warning +} + +func (l *Logger) Warningln(args ...interface{}) { + // ignore builtin grpc warning +} + +func (l *Logger) Warningf(format string, args ...interface{}) { + // ignore builtin grpc warning +} diff --git a/core/rpc/rpcpubserver.go b/core/rpc/rpcpubserver.go new file mode 100644 index 00000000..e7848bb3 --- /dev/null +++ b/core/rpc/rpcpubserver.go @@ -0,0 +1,29 @@ +package rpc + +import "zero/core/discov" + +func NewRpcPubServer(etcdEndpoints []string, etcdKey, listenOn string, opts ...ServerOption) (Server, error) { + registerEtcd := func() error { + pubClient := discov.NewPublisher(etcdEndpoints, etcdKey, listenOn) + return pubClient.KeepAlive() + } + server := keepAliveServer{ + registerEtcd: registerEtcd, + Server: NewRpcServer(listenOn, opts...), + } + + return server, nil +} + +type keepAliveServer struct { + registerEtcd func() error + Server +} + +func (ags keepAliveServer) Start(fn RegisterFn) error { + if err := ags.registerEtcd(); err != nil { + return err + } + + return ags.Server.Start(fn) +} diff --git a/core/rpc/rpcserver.go b/core/rpc/rpcserver.go new file mode 100644 index 00000000..6305778d --- /dev/null +++ b/core/rpc/rpcserver.go @@ -0,0 +1,82 @@ +package rpc + +import ( + "net" + + "zero/core/proc" + "zero/core/rpc/serverinterceptors" + "zero/core/stat" + + "google.golang.org/grpc" +) + +type ( + ServerOption func(options *rpcServerOptions) + + rpcServerOptions struct { + metrics *stat.Metrics + } + + rpcServer struct { + *baseRpcServer + } +) + +func init() { + InitLogger() +} + +func NewRpcServer(address string, opts ...ServerOption) Server { + var options rpcServerOptions + for _, opt := range opts { + opt(&options) + } + if options.metrics == nil { + options.metrics = stat.NewMetrics(address) + } + + return &rpcServer{ + baseRpcServer: newBaseRpcServer(address, options.metrics), + } +} + +func (s *rpcServer) SetName(name string) { + s.baseRpcServer.SetName(name) +} + +func (s *rpcServer) Start(register RegisterFn) error { + lis, err := net.Listen("tcp", s.address) + if err != nil { + return err + } + + unaryInterceptors := []grpc.UnaryServerInterceptor{ + serverinterceptors.UnaryCrashInterceptor(), + serverinterceptors.UnaryStatInterceptor(s.metrics), + serverinterceptors.UnaryPromMetricInterceptor(), + } + unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...) + streamInterceptors := []grpc.StreamServerInterceptor{ + serverinterceptors.StreamCrashInterceptor, + } + streamInterceptors = append(streamInterceptors, s.streamInterceptors...) + options := append(s.options, WithUnaryServerInterceptors(unaryInterceptors...), + WithStreamServerInterceptors(streamInterceptors...)) + server := grpc.NewServer(options...) 
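+	// the options above chain the built-in interceptors (crash/stat/prometheus for
+	// unary, crash for stream) ahead of any interceptors added through
+	// AddUnaryInterceptors/AddStreamInterceptors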
+ register(server) + // we need to make sure all others are wrapped up + // so we do graceful stop at shutdown phase instead of wrap up phase + shutdownCalled := proc.AddShutdownListener(func() { + server.GracefulStop() + }) + err = server.Serve(lis) + shutdownCalled() + + return err +} + +func WithMetrics(metrics *stat.Metrics) ServerOption { + return func(options *rpcServerOptions) { + options.metrics = metrics + } +} diff --git a/core/rpc/rpcsubclient.go b/core/rpc/rpcsubclient.go new file mode 100644 index 00000000..7d0c1436 --- /dev/null +++ b/core/rpc/rpcsubclient.go @@ -0,0 +1,102 @@ +package rpc + +import ( + "time" + + "zero/core/discov" + "zero/core/logx" + "zero/core/threading" + + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" +) + +const ( + coolOffTime = time.Second * 5 + retryTimes = 3 +) + +type ( + RoundRobinSubClient struct { + *discov.RoundRobinSubClient + } + + ConsistentSubClient struct { + *discov.ConsistentSubClient + } +) + +func NewRoundRobinRpcClient(endpoints []string, key string, opts ...ClientOption) (*RoundRobinSubClient, error) { + subClient, err := discov.NewRoundRobinSubClient(endpoints, key, func(server string) (interface{}, error) { + return dial(server, opts...) + }, func(server string, conn interface{}) error { + return closeConn(conn.(*grpc.ClientConn)) + }, discov.Exclusive()) + if err != nil { + return nil, err + } else { + return &RoundRobinSubClient{subClient}, nil + } +} + +func NewConsistentRpcClient(endpoints []string, key string, opts ...ClientOption) (*ConsistentSubClient, error) { + subClient, err := discov.NewConsistentSubClient(endpoints, key, func(server string) (interface{}, error) { + return dial(server, opts...) + }, func(server string, conn interface{}) error { + return closeConn(conn.(*grpc.ClientConn)) + }) + if err != nil { + return nil, err + } else { + return &ConsistentSubClient{subClient}, nil + } +} + +func (cli *RoundRobinSubClient) Next() (*grpc.ClientConn, bool) { + return next(func() (interface{}, bool) { + return cli.RoundRobinSubClient.Next() + }) +} + +func (cli *ConsistentSubClient) Next(key string) (*grpc.ClientConn, bool) { + return next(func() (interface{}, bool) { + return cli.ConsistentSubClient.Next(key) + }) +} + +func closeConn(conn *grpc.ClientConn) error { + // why to close the conn asynchronously is because maybe another goroutine + // is using the same conn, we can wait the coolOffTime to let the other + // goroutine to finish using the conn. + // after the conn unregistered, the balancer will not assign the conn, + // but maybe the already assigned tasks are still using it. 
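+	// the actual Close is therefore delayed by coolOffTime (5 seconds) in a
+	// background goroutine, giving in-flight calls on this conn time to finish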
+ threading.GoSafe(func() { + time.Sleep(coolOffTime) + if err := conn.Close(); err != nil { + logx.Error(err) + } + }) + + return nil +} + +func next(nextFn func() (interface{}, bool)) (*grpc.ClientConn, bool) { + for i := 0; i < retryTimes; i++ { + v, ok := nextFn() + if !ok { + break + } + + conn, yes := v.(*grpc.ClientConn) + if !yes { + break + } + + switch conn.GetState() { + case connectivity.Ready: + return conn, true + } + } + + return nil, false +} diff --git a/core/rpc/rrclient.go b/core/rpc/rrclient.go new file mode 100644 index 00000000..243c44e5 --- /dev/null +++ b/core/rpc/rrclient.go @@ -0,0 +1,40 @@ +package rpc + +import ( + "math/rand" + "sync" + "time" + + "google.golang.org/grpc" +) + +type RRClient struct { + conns []*grpc.ClientConn + index int + lock sync.Mutex +} + +func NewRRClient(endpoints []string) (*RRClient, error) { + var conns []*grpc.ClientConn + for _, endpoint := range endpoints { + conn, err := dial(endpoint) + if err != nil { + return nil, err + } + + conns = append(conns, conn) + } + + rand.Seed(time.Now().UnixNano()) + return &RRClient{ + conns: conns, + index: rand.Intn(len(conns)), + }, nil +} + +func (c *RRClient) Next() *grpc.ClientConn { + c.lock.Lock() + defer c.lock.Unlock() + c.index = (c.index + 1) % len(c.conns) + return c.conns[c.index] +} diff --git a/core/rpc/server.go b/core/rpc/server.go new file mode 100644 index 00000000..c03b5a30 --- /dev/null +++ b/core/rpc/server.go @@ -0,0 +1,50 @@ +package rpc + +import ( + "zero/core/stat" + + "google.golang.org/grpc" +) + +type ( + RegisterFn func(*grpc.Server) + + Server interface { + AddOptions(options ...grpc.ServerOption) + AddStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) + AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) + SetName(string) + Start(register RegisterFn) error + } + + baseRpcServer struct { + address string + metrics *stat.Metrics + options []grpc.ServerOption + streamInterceptors []grpc.StreamServerInterceptor + unaryInterceptors []grpc.UnaryServerInterceptor + } +) + +func newBaseRpcServer(address string, metrics *stat.Metrics) *baseRpcServer { + return &baseRpcServer{ + address: address, + metrics: metrics, + } +} + +func (s *baseRpcServer) AddOptions(options ...grpc.ServerOption) { + s.options = append(s.options, options...) +} + +func (s *baseRpcServer) AddStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) { + s.streamInterceptors = append(s.streamInterceptors, interceptors...) +} + +func (s *baseRpcServer) AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) { + s.unaryInterceptors = append(s.unaryInterceptors, interceptors...) 
+} + +func (s *baseRpcServer) SetName(name string) { + s.metrics.SetName(name) +} diff --git a/core/rpc/serverinterceptors/crashinterceptor.go b/core/rpc/serverinterceptors/crashinterceptor.go new file mode 100644 index 00000000..9ec567ed --- /dev/null +++ b/core/rpc/serverinterceptors/crashinterceptor.go @@ -0,0 +1,43 @@ +package serverinterceptors + +import ( + "context" + "runtime/debug" + + "zero/core/logx" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func StreamCrashInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, + handler grpc.StreamHandler) (err error) { + defer handleCrash(func(r interface{}) { + err = toPanicError(r) + }) + + return handler(srv, stream) +} + +func UnaryCrashInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (resp interface{}, err error) { + defer handleCrash(func(r interface{}) { + err = toPanicError(r) + }) + + return handler(ctx, req) + } +} + +func handleCrash(handler func(interface{})) { + if r := recover(); r != nil { + handler(r) + } +} + +func toPanicError(r interface{}) error { + logx.Errorf("%+v %s", r, debug.Stack()) + return status.Errorf(codes.Internal, "panic: %v", r) +} diff --git a/core/rpc/serverinterceptors/prommetricinterceptor.go b/core/rpc/serverinterceptors/prommetricinterceptor.go new file mode 100644 index 00000000..e3c535a5 --- /dev/null +++ b/core/rpc/serverinterceptors/prommetricinterceptor.go @@ -0,0 +1,45 @@ +package serverinterceptors + +import ( + "context" + "strconv" + "time" + + "zero/core/metric" + "zero/core/timex" + + "google.golang.org/grpc" + "google.golang.org/grpc/status" +) + +const serverNamespace = "rpc_server" + +var ( + metricServerReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{ + Namespace: serverNamespace, + Subsystem: "requests", + Name: "duration_ms", + Help: "rpc server requests duration(ms).", + Labels: []string{"method"}, + Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000}, + }) + + metricServerReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{ + Namespace: serverNamespace, + Subsystem: "requests", + Name: "code_total", + Help: "rpc server requests code count.", + Labels: []string{"method", "code"}, + }) +) + +func UnaryPromMetricInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + startTime := timex.Now() + resp, err := handler(ctx, req) + metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), info.FullMethod) + metricServerReqCodeTotal.Inc(info.FullMethod, strconv.Itoa(int(status.Code(err)))) + return resp, err + } + +} diff --git a/core/rpc/serverinterceptors/sheddinginterceptor.go b/core/rpc/serverinterceptors/sheddinginterceptor.go new file mode 100644 index 00000000..a2c266e4 --- /dev/null +++ b/core/rpc/serverinterceptors/sheddinginterceptor.go @@ -0,0 +1,53 @@ +package serverinterceptors + +import ( + "context" + "sync" + + "zero/core/load" + "zero/core/stat" + + "google.golang.org/grpc" +) + +const serviceType = "rpc" + +var ( + sheddingStat *load.SheddingStat + lock sync.Mutex +) + +func UnarySheddingInterceptor(shedder load.Shedder, metrics *stat.Metrics) grpc.UnaryServerInterceptor { + ensureSheddingStat() + + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (val interface{}, err error) { 
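+		// every request increments the total counter; a rejected request is counted
+		// as a drop on both the metrics and the shedding stat, while an accepted one
+		// later reports Pass back to the shedder, or Fail if the handler returned
+		// context.DeadlineExceeded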
+ sheddingStat.IncrementTotal() + var promise load.Promise + promise, err = shedder.Allow() + if err != nil { + metrics.AddDrop() + sheddingStat.IncrementDrop() + return + } + + defer func() { + if err == context.DeadlineExceeded { + promise.Fail() + } else { + sheddingStat.IncrementPass() + promise.Pass() + } + }() + + return handler(ctx, req) + } +} + +func ensureSheddingStat() { + lock.Lock() + if sheddingStat == nil { + sheddingStat = load.NewSheddingStat(serviceType) + } + lock.Unlock() +} diff --git a/core/rpc/serverinterceptors/statinterceptor.go b/core/rpc/serverinterceptors/statinterceptor.go new file mode 100644 index 00000000..b2a5c83b --- /dev/null +++ b/core/rpc/serverinterceptors/statinterceptor.go @@ -0,0 +1,52 @@ +package serverinterceptors + +import ( + "context" + "encoding/json" + "time" + + "zero/core/logx" + "zero/core/stat" + "zero/core/timex" + + "google.golang.org/grpc" + "google.golang.org/grpc/peer" +) + +const serverSlowThreshold = time.Millisecond * 500 + +func UnaryStatInterceptor(metrics *stat.Metrics) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (resp interface{}, err error) { + defer handleCrash(func(r interface{}) { + err = toPanicError(r) + }) + + startTime := timex.Now() + defer func() { + duration := timex.Since(startTime) + metrics.Add(stat.Task{ + Duration: duration, + }) + logDuration(ctx, info.FullMethod, req, duration) + }() + + return handler(ctx, req) + } +} + +func logDuration(ctx context.Context, method string, req interface{}, duration time.Duration) { + var addr string + client, ok := peer.FromContext(ctx) + if ok { + addr = client.Addr.String() + } + content, err := json.Marshal(req) + if err != nil { + logx.Errorf("%s - %s", addr, err.Error()) + } else if duration > serverSlowThreshold { + logx.WithDuration(duration).Slowf("[RPC] slowcall - %s - %s - %s", addr, method, string(content)) + } else { + logx.WithDuration(duration).Infof("%s - %s - %s", addr, method, string(content)) + } +} diff --git a/core/rpc/serverinterceptors/timeoutinterceptor.go b/core/rpc/serverinterceptors/timeoutinterceptor.go new file mode 100644 index 00000000..26029264 --- /dev/null +++ b/core/rpc/serverinterceptors/timeoutinterceptor.go @@ -0,0 +1,19 @@ +package serverinterceptors + +import ( + "context" + "time" + + "zero/core/contextx" + + "google.golang.org/grpc" +) + +func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (resp interface{}, err error) { + ctx, cancel := contextx.ShrinkDeadline(ctx, timeout) + defer cancel() + return handler(ctx, req) + } +} diff --git a/core/rpc/serverinterceptors/tracinginterceptor.go b/core/rpc/serverinterceptors/tracinginterceptor.go new file mode 100644 index 00000000..595d3dcd --- /dev/null +++ b/core/rpc/serverinterceptors/tracinginterceptor.go @@ -0,0 +1,29 @@ +package serverinterceptors + +import ( + "context" + + "zero/core/trace" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +func UnaryTracingInterceptor(serviceName string) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (resp interface{}, err error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return handler(ctx, req) + } + + carrier, err := trace.Extract(trace.GrpcFormat, md) + if err != nil { + return 
handler(ctx, req) + } + + ctx, span := trace.StartServerSpan(ctx, carrier, serviceName, info.FullMethod) + defer span.Finish() + return handler(ctx, req) + } +} diff --git a/core/search/searchtree.go b/core/search/searchtree.go new file mode 100644 index 00000000..0163b37d --- /dev/null +++ b/core/search/searchtree.go @@ -0,0 +1,199 @@ +package search + +import "errors" + +const ( + colon = ':' + slash = '/' +) + +var ( + ErrDupItem = errors.New("duplicated item") + ErrDupSlash = errors.New("duplicated slash") + ErrEmptyItem = errors.New("empty item") + ErrInvalidState = errors.New("search tree is in an invalid state") + ErrNotFromRoot = errors.New("path should start with /") + + NotFound Result +) + +type ( + innerResult struct { + key string + value string + named bool + found bool + } + + node struct { + item interface{} + children [2]map[string]*node + } + + Tree struct { + root *node + } + + Result struct { + Item interface{} + Params map[string]string + } +) + +func NewTree() *Tree { + return &Tree{ + root: newNode(nil), + } +} + +func (t *Tree) Add(route string, item interface{}) error { + if len(route) == 0 || route[0] != slash { + return ErrNotFromRoot + } + + if item == nil { + return ErrEmptyItem + } + + return add(t.root, route[1:], item) +} + +func (t *Tree) Search(route string) (Result, bool) { + if len(route) == 0 || route[0] != slash { + return NotFound, false + } + + var result Result + ok := t.next(t.root, route[1:], &result) + return result, ok +} + +func (t *Tree) next(n *node, route string, result *Result) bool { + if len(route) == 0 && n.item != nil { + result.Item = n.item + return true + } + + for i := range route { + if route[i] == slash { + token := route[:i] + for _, children := range n.children { + for k, v := range children { + if r := match(k, token); r.found { + if t.next(v, route[i+1:], result) { + if r.named { + addParam(result, r.key, r.value) + } + + return true + } + } + } + } + + return false + } + } + + for _, children := range n.children { + for k, v := range children { + if r := match(k, route); r.found && v.item != nil { + result.Item = v.item + if r.named { + addParam(result, r.key, r.value) + } + + return true + } + } + } + + return false +} + +func (nd *node) getChildren(route string) map[string]*node { + if len(route) > 0 && route[0] == colon { + return nd.children[1] + } else { + return nd.children[0] + } +} + +func add(nd *node, route string, item interface{}) error { + if len(route) == 0 { + if nd.item != nil { + return ErrDupItem + } + + nd.item = item + return nil + } + + if route[0] == slash { + return ErrDupSlash + } + + for i := range route { + if route[i] == slash { + token := route[:i] + children := nd.getChildren(token) + if child, ok := children[token]; ok { + if child != nil { + return add(child, route[i+1:], item) + } else { + return ErrInvalidState + } + } else { + child := newNode(nil) + children[token] = child + return add(child, route[i+1:], item) + } + } + } + + children := nd.getChildren(route) + if child, ok := children[route]; ok { + if child.item != nil { + return ErrDupItem + } + + child.item = item + } else { + children[route] = newNode(item) + } + + return nil +} + +func addParam(result *Result, k, v string) { + if result.Params == nil { + result.Params = make(map[string]string) + } + + result.Params[k] = v +} + +func match(pat, token string) innerResult { + if pat[0] == colon { + return innerResult{ + key: pat[1:], + value: token, + named: true, + found: true, + } + } + + return innerResult{ + found: pat == token, + } 
+} + +func newNode(item interface{}) *node { + return &node{ + item: item, + children: [2]map[string]*node{ + make(map[string]*node), + make(map[string]*node), + }, + } +} diff --git a/core/search/searchtree_debug.go b/core/search/searchtree_debug.go new file mode 100644 index 00000000..1ef7418c --- /dev/null +++ b/core/search/searchtree_debug.go @@ -0,0 +1,32 @@ +// +build debug + +package search + +import "fmt" + +func (t *Tree) Print() { + if t.root.item == nil { + fmt.Println("/") + } else { + fmt.Printf("/:%#v\n", t.root.item) + } + printNode(t.root, 1) +} + +func printNode(n *node, depth int) { + indent := make([]byte, depth) + for i := 0; i < len(indent); i++ { + indent[i] = '\t' + } + + for _, children := range n.children { + for k, v := range children { + if v.item == nil { + fmt.Printf("%s%s\n", string(indent), k) + } else { + fmt.Printf("%s%s:%#v\n", string(indent), k, v.item) + } + printNode(v, depth+1) + } + } +} diff --git a/core/search/searchtree_test.go b/core/search/searchtree_test.go new file mode 100644 index 00000000..763d3abf --- /dev/null +++ b/core/search/searchtree_test.go @@ -0,0 +1,187 @@ +package search + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type mockedRoute struct { + route string + value int +} + +func TestSearch(t *testing.T) { + routes := []mockedRoute{ + {"/", 1}, + {"/api", 2}, + {"/img", 3}, + {"/:layer1", 4}, + {"/api/users", 5}, + {"/img/jpgs", 6}, + {"/img/jpgs", 7}, + {"/api/:layer2", 8}, + {"/:layer1/:layer2", 9}, + {"/:layer1/:layer2/users", 10}, + } + + tests := []struct { + query string + expect int + params map[string]string + contains bool + }{ + { + query: "", + contains: false, + }, + { + query: "/", + expect: 1, + contains: true, + }, + { + query: "/wildcard", + expect: 4, + params: map[string]string{ + "layer1": "wildcard", + }, + contains: true, + }, + { + query: "/wildcard/", + expect: 4, + params: map[string]string{ + "layer1": "wildcard", + }, + contains: true, + }, + { + query: "/a/b/c", + contains: false, + }, + { + query: "/a/b", + expect: 9, + params: map[string]string{ + "layer1": "a", + "layer2": "b", + }, + contains: true, + }, + { + query: "/a/b/", + expect: 9, + params: map[string]string{ + "layer1": "a", + "layer2": "b", + }, + contains: true, + }, + { + query: "/a/b/users", + expect: 10, + params: map[string]string{ + "layer1": "a", + "layer2": "b", + }, + contains: true, + }, + } + + for _, test := range tests { + t.Run(test.query, func(t *testing.T) { + tree := NewTree() + for _, r := range routes { + tree.Add(r.route, r.value) + } + result, ok := tree.Search(test.query) + assert.Equal(t, test.contains, ok) + if ok { + actual := result.Item.(int) + assert.EqualValues(t, test.params, result.Params) + assert.Equal(t, test.expect, actual) + } + }) + } +} + +func TestStrictSearch(t *testing.T) { + routes := []mockedRoute{ + {"/api/users", 1}, + {"/api/:layer", 2}, + } + query := "/api/users" + + tree := NewTree() + for _, r := range routes { + tree.Add(r.route, r.value) + } + + for i := 0; i < 1000; i++ { + result, ok := tree.Search(query) + assert.True(t, ok) + assert.Equal(t, 1, result.Item.(int)) + } +} + +func TestStrictSearchSibling(t *testing.T) { + routes := []mockedRoute{ + {"/api/:user/profile/name", 1}, + {"/api/:user/profile", 2}, + {"/api/:user/name", 3}, + {"/api/:layer", 4}, + } + query := "/api/123/name" + + tree := NewTree() + for _, r := range routes { + tree.Add(r.route, r.value) + } + + for i := 0; i < 1000; i++ { + result, ok := tree.Search(query) + assert.True(t, ok) + 
assert.Equal(t, 3, result.Item.(int)) + } +} + +func TestAddDuplicate(t *testing.T) { + tree := NewTree() + err := tree.Add("/a/b", 1) + assert.Nil(t, err) + err = tree.Add("/a/b", 2) + assert.Equal(t, ErrDupItem, err) + err = tree.Add("/a/b/", 2) + assert.Equal(t, ErrDupItem, err) +} + +func TestPlain(t *testing.T) { + tree := NewTree() + err := tree.Add("/a/b", 1) + assert.Nil(t, err) + err = tree.Add("/a/c", 2) + assert.Nil(t, err) + _, ok := tree.Search("/a/d") + assert.False(t, ok) +} + +func TestSearchWithDoubleSlashes(t *testing.T) { + tree := NewTree() + err := tree.Add("//a", 1) + assert.Error(t, ErrDupSlash, err) +} + +func TestSearchInvalidRoute(t *testing.T) { + tree := NewTree() + err := tree.Add("", 1) + assert.Equal(t, ErrNotFromRoot, err) + err = tree.Add("bad", 1) + assert.Equal(t, ErrNotFromRoot, err) +} + +func TestSearchInvalidItem(t *testing.T) { + tree := NewTree() + err := tree.Add("/", nil) + assert.Equal(t, ErrEmptyItem, err) +} diff --git a/core/service/serviceconf.go b/core/service/serviceconf.go new file mode 100644 index 00000000..b7d5416c --- /dev/null +++ b/core/service/serviceconf.go @@ -0,0 +1,56 @@ +package service + +import ( + "log" + + "zero/core/load" + "zero/core/logx" + "zero/core/prometheus" + "zero/core/stat" +) + +const ( + DevMode = "dev" + TestMode = "test" + PreMode = "pre" + ProMode = "pro" +) + +type ServiceConf struct { + Name string + Log logx.LogConf + Mode string `json:",default=pro,options=dev|test|pre|pro"` + MetricsUrl string `json:",optional"` + Prometheus prometheus.Config `json:",optional"` +} + +func (sc ServiceConf) MustSetUp() { + if err := sc.SetUp(); err != nil { + log.Fatal(err) + } +} + +func (sc ServiceConf) SetUp() error { + if len(sc.Log.ServiceName) == 0 { + sc.Log.ServiceName = sc.Name + } + if err := logx.SetUp(sc.Log); err != nil { + return err + } + + sc.initMode() + prometheus.StartAgent(sc.Prometheus) + if len(sc.MetricsUrl) > 0 { + stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl)) + } + + return nil +} + +func (sc ServiceConf) initMode() { + switch sc.Mode { + case DevMode, TestMode, PreMode: + load.Disable() + stat.SetReporter(nil) + } +} diff --git a/core/service/servicegroup.go b/core/service/servicegroup.go new file mode 100644 index 00000000..11d924e3 --- /dev/null +++ b/core/service/servicegroup.go @@ -0,0 +1,107 @@ +package service + +import ( + "log" + + "zero/core/proc" + "zero/core/syncx" + "zero/core/threading" +) + +type ( + Starter interface { + Start() + } + + Stopper interface { + Stop() + } + + Service interface { + Starter + Stopper + } + + ServiceGroup struct { + services []Service + stopOnce func() + } +) + +func NewServiceGroup() *ServiceGroup { + sg := new(ServiceGroup) + sg.stopOnce = syncx.Once(sg.doStop) + return sg +} + +func (sg *ServiceGroup) Add(service Service) { + sg.services = append(sg.services, service) +} + +// There should not be any logic code after calling this method, because this method is a blocking one. +// Also, quitting this method will close the logx output. 
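+// A minimal usage sketch (illustrative; svc is a placeholder for anything
+// implementing Starter and Stopper):
+//
+//	group := NewServiceGroup()
+//	group.Add(svc)
+//	defer group.Stop()
+//	group.Start() // blocks until every service's Start returns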
+func (sg *ServiceGroup) Start() { + proc.AddShutdownListener(func() { + log.Println("Shutting down...") + sg.stopOnce() + }) + + sg.doStart() +} + +func (sg *ServiceGroup) Stop() { + sg.stopOnce() +} + +func (sg *ServiceGroup) doStart() { + routineGroup := threading.NewRoutineGroup() + + for i := range sg.services { + service := sg.services[i] + routineGroup.RunSafe(func() { + service.Start() + }) + } + + routineGroup.Wait() +} + +func (sg *ServiceGroup) doStop() { + for _, service := range sg.services { + service.Stop() + } +} + +func WithStart(start func()) Service { + return startOnlyService{ + start: start, + } +} + +func WithStarter(start Starter) Service { + return starterOnlyService{ + Starter: start, + } +} + +type ( + stopper struct { + } + + startOnlyService struct { + start func() + stopper + } + + starterOnlyService struct { + Starter + stopper + } +) + +func (s stopper) Stop() { +} + +func (s startOnlyService) Start() { + s.start() +} diff --git a/core/service/servicegroup_test.go b/core/service/servicegroup_test.go new file mode 100644 index 00000000..00defaa1 --- /dev/null +++ b/core/service/servicegroup_test.go @@ -0,0 +1,126 @@ +package service + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +var ( + number = 1 + mutex sync.Mutex + done = make(chan struct{}) +) + +type mockedService struct { + quit chan struct{} + multiplier int +} + +func newMockedService(multiplier int) *mockedService { + return &mockedService{ + quit: make(chan struct{}), + multiplier: multiplier, + } +} + +func (s *mockedService) Start() { + mutex.Lock() + number = number * s.multiplier + mutex.Unlock() + done <- struct{}{} + <-s.quit +} + +func (s *mockedService) Stop() { + close(s.quit) +} + +func TestServiceGroup(t *testing.T) { + multipliers := []int{2, 3, 5, 7} + want := 1 + + group := NewServiceGroup() + for _, multiplier := range multipliers { + want *= multiplier + service := newMockedService(multiplier) + group.Add(service) + } + + go group.Start() + + for i := 0; i < len(multipliers); i++ { + <-done + } + + group.Stop() + + mutex.Lock() + defer mutex.Unlock() + assert.Equal(t, want, number) +} + +func TestServiceGroup_WithStart(t *testing.T) { + multipliers := []int{2, 3, 5, 7} + want := 1 + + var wait sync.WaitGroup + var lock sync.Mutex + wait.Add(len(multipliers)) + group := NewServiceGroup() + for _, multiplier := range multipliers { + var mul = multiplier + group.Add(WithStart(func() { + lock.Lock() + want *= mul + lock.Unlock() + wait.Done() + })) + } + + go group.Start() + wait.Wait() + group.Stop() + + lock.Lock() + defer lock.Unlock() + assert.Equal(t, 210, want) +} + +func TestServiceGroup_WithStarter(t *testing.T) { + multipliers := []int{2, 3, 5, 7} + want := 1 + + var wait sync.WaitGroup + var lock sync.Mutex + wait.Add(len(multipliers)) + group := NewServiceGroup() + for _, multiplier := range multipliers { + var mul = multiplier + group.Add(WithStarter(mockedStarter{ + fn: func() { + lock.Lock() + want *= mul + lock.Unlock() + wait.Done() + }, + })) + } + + go group.Start() + wait.Wait() + group.Stop() + + lock.Lock() + defer lock.Unlock() + assert.Equal(t, 210, want) +} + +type mockedStarter struct { + fn func() +} + +func (s mockedStarter) Start() { + s.fn() +} diff --git a/core/stat/alert.go b/core/stat/alert.go new file mode 100644 index 00000000..7b890b27 --- /dev/null +++ b/core/stat/alert.go @@ -0,0 +1,70 @@ +// +build linux + +package stat + +import ( + "flag" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "zero/core/executors" 
+ "zero/core/proc" + "zero/core/sysx" + "zero/core/timex" + "zero/core/utils" +) + +const ( + clusterNameKey = "CLUSTER_NAME" + testEnv = "test.v" + timeFormat = "2006-01-02 15:04:05" +) + +var ( + reporter = utils.Report + lock sync.RWMutex + lessExecutor = executors.NewLessExecutor(time.Minute * 5) + dropped int32 + clusterName = proc.Env(clusterNameKey) +) + +func init() { + if flag.Lookup(testEnv) != nil { + SetReporter(nil) + } +} + +func Report(msg string) { + lock.RLock() + fn := reporter + lock.RUnlock() + + if fn != nil { + reported := lessExecutor.DoOrDiscard(func() { + var builder strings.Builder + fmt.Fprintf(&builder, "%s\n", timex.Time().Format(timeFormat)) + if len(clusterName) > 0 { + fmt.Fprintf(&builder, "cluster: %s\n", clusterName) + } + fmt.Fprintf(&builder, "host: %s\n", sysx.Hostname()) + dp := atomic.SwapInt32(&dropped, 0) + if dp > 0 { + fmt.Fprintf(&builder, "dropped: %d\n", dp) + } + builder.WriteString(strings.TrimSpace(msg)) + fn(builder.String()) + }) + if !reported { + atomic.AddInt32(&dropped, 1) + } + } +} + +func SetReporter(fn func(string)) { + lock.Lock() + defer lock.Unlock() + reporter = fn +} diff --git a/core/stat/alert_polyfill.go b/core/stat/alert_polyfill.go new file mode 100644 index 00000000..cef947bc --- /dev/null +++ b/core/stat/alert_polyfill.go @@ -0,0 +1,9 @@ +// +build !linux + +package stat + +func Report(string) { +} + +func SetReporter(func(string)) { +} diff --git a/core/stat/internal/cgroup_linux.go b/core/stat/internal/cgroup_linux.go new file mode 100644 index 00000000..13fb7251 --- /dev/null +++ b/core/stat/internal/cgroup_linux.go @@ -0,0 +1,168 @@ +package internal + +import ( + "fmt" + "os" + "path" + "strconv" + "strings" + + "zero/core/iox" + "zero/core/lang" +) + +const cgroupDir = "/sys/fs/cgroup" + +type cgroup struct { + cgroups map[string]string +} + +func (c *cgroup) acctUsageAllCpus() (uint64, error) { + data, err := iox.ReadText(path.Join(c.cgroups["cpuacct"], "cpuacct.usage")) + if err != nil { + return 0, err + } + + return parseUint(string(data)) +} + +func (c *cgroup) acctUsagePerCpu() ([]uint64, error) { + data, err := iox.ReadText(path.Join(c.cgroups["cpuacct"], "cpuacct.usage_percpu")) + if err != nil { + return nil, err + } + + var usage []uint64 + for _, v := range strings.Fields(string(data)) { + u, err := parseUint(v) + if err != nil { + return nil, err + } + + usage = append(usage, u) + } + + return usage, nil +} + +func (c *cgroup) cpuQuotaUs() (int64, error) { + data, err := iox.ReadText(path.Join(c.cgroups["cpu"], "cpu.cfs_quota_us")) + if err != nil { + return 0, err + } + + return strconv.ParseInt(string(data), 10, 64) +} + +func (c *cgroup) cpuPeriodUs() (uint64, error) { + data, err := iox.ReadText(path.Join(c.cgroups["cpu"], "cpu.cfs_period_us")) + if err != nil { + return 0, err + } + + return parseUint(string(data)) +} + +func (c *cgroup) cpus() ([]uint64, error) { + data, err := iox.ReadText(path.Join(c.cgroups["cpuset"], "cpuset.cpus")) + if err != nil { + return nil, err + } + + return parseUints(string(data)) +} + +func currentCgroup() (*cgroup, error) { + cgroupFile := fmt.Sprintf("/proc/%d/cgroup", os.Getpid()) + lines, err := iox.ReadTextLines(cgroupFile, iox.WithoutBlank()) + if err != nil { + return nil, err + } + + cgroups := make(map[string]string) + for _, line := range lines { + cols := strings.Split(line, ":") + if len(cols) != 3 { + return nil, fmt.Errorf("invalid cgroup line: %s", line) + } + + subsys := cols[1] + // only read cpu staff + if !strings.HasPrefix(subsys, "cpu") { + 
continue + } + + cgroups[subsys] = path.Join(cgroupDir, subsys) + if strings.Contains(subsys, ",") { + for _, k := range strings.Split(subsys, ",") { + cgroups[k] = path.Join(cgroupDir, k) + } + } + } + + return &cgroup{ + cgroups: cgroups, + }, nil +} + +func parseUint(s string) (uint64, error) { + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + if err.(*strconv.NumError).Err == strconv.ErrRange { + return 0, nil + } else { + return 0, fmt.Errorf("cgroup: bad int format: %s", s) + } + } else { + if v < 0 { + return 0, nil + } else { + return uint64(v), nil + } + } +} + +func parseUints(val string) ([]uint64, error) { + if val == "" { + return nil, nil + } + + ints := make(map[uint64]lang.PlaceholderType) + cols := strings.Split(val, ",") + for _, r := range cols { + if strings.Contains(r, "-") { + fields := strings.SplitN(r, "-", 2) + min, err := parseUint(fields[0]) + if err != nil { + return nil, fmt.Errorf("cgroup: bad int list format: %s", val) + } + + max, err := parseUint(fields[1]) + if err != nil { + return nil, fmt.Errorf("cgroup: bad int list format: %s", val) + } + + if max < min { + return nil, fmt.Errorf("cgroup: bad int list format: %s", val) + } + + for i := min; i <= max; i++ { + ints[i] = lang.Placeholder + } + } else { + v, err := parseUint(r) + if err != nil { + return nil, err + } + + ints[v] = lang.Placeholder + } + } + + var sets []uint64 + for k := range ints { + sets = append(sets, k) + } + + return sets, nil +} diff --git a/core/stat/internal/cpu_linux.go b/core/stat/internal/cpu_linux.go new file mode 100644 index 00000000..b8437bfc --- /dev/null +++ b/core/stat/internal/cpu_linux.go @@ -0,0 +1,148 @@ +package internal + +import ( + "errors" + "fmt" + "strings" + "time" + + "zero/core/iox" + "zero/core/lang" +) + +const ( + cpuTicks = 100 + cpuFields = 8 +) + +var ( + preSystem uint64 + preTotal uint64 + quota float64 + cores uint64 +) + +func init() { + cpus, err := perCpuUsage() + lang.Must(err) + cores = uint64(len(cpus)) + + sets, err := cpuSets() + lang.Must(err) + quota = float64(len(sets)) + cq, err := cpuQuota() + if err == nil { + if cq != -1 { + period, err := cpuPeriod() + lang.Must(err) + + limit := float64(cq) / float64(period) + if limit < quota { + quota = limit + } + } + } + + preSystem, err = systemCpuUsage() + lang.Must(err) + + preTotal, err = totalCpuUsage() + lang.Must(err) +} + +func RefreshCpu() uint64 { + total, err := totalCpuUsage() + if err != nil { + return 0 + } + system, err := systemCpuUsage() + if err != nil { + return 0 + } + + var usage uint64 + cpuDelta := total - preTotal + systemDelta := system - preSystem + if cpuDelta > 0 && systemDelta > 0 { + usage = uint64(float64(cpuDelta*cores*1e3) / (float64(systemDelta) * quota)) + } + preSystem = system + preTotal = total + + return usage +} + +func cpuQuota() (int64, error) { + cg, err := currentCgroup() + if err != nil { + return 0, err + } + + return cg.cpuQuotaUs() +} + +func cpuPeriod() (uint64, error) { + cg, err := currentCgroup() + if err != nil { + return 0, err + } + + return cg.cpuPeriodUs() +} + +func cpuSets() ([]uint64, error) { + cg, err := currentCgroup() + if err != nil { + return nil, err + } + + return cg.cpus() +} + +func perCpuUsage() ([]uint64, error) { + cg, err := currentCgroup() + if err != nil { + return nil, err + } + + return cg.acctUsagePerCpu() +} + +func systemCpuUsage() (uint64, error) { + lines, err := iox.ReadTextLines("/proc/stat", iox.WithoutBlank()) + if err != nil { + return 0, err + } + + for _, line := range lines { + fields := 
strings.Fields(line) + if fields[0] == "cpu" { + if len(fields) < cpuFields { + return 0, fmt.Errorf("bad format of cpu stats") + } + + var totalClockTicks uint64 + for _, i := range fields[1:cpuFields] { + v, err := parseUint(i) + if err != nil { + return 0, err + } + + totalClockTicks += v + } + + return (totalClockTicks * uint64(time.Second)) / cpuTicks, nil + } + } + + return 0, errors.New("bad stats format") +} + +func totalCpuUsage() (usage uint64, err error) { + var cg *cgroup + if cg, err = currentCgroup(); err != nil { + return + } + + return cg.acctUsageAllCpus() +} diff --git a/core/stat/internal/cpu_linux_test.go b/core/stat/internal/cpu_linux_test.go new file mode 100644 index 00000000..c126e92c --- /dev/null +++ b/core/stat/internal/cpu_linux_test.go @@ -0,0 +1,9 @@ +package internal + +import "testing" + +func BenchmarkRefreshCpu(b *testing.B) { + for i := 0; i < b.N; i++ { + RefreshCpu() + } +} diff --git a/core/stat/internal/cpu_other.go b/core/stat/internal/cpu_other.go new file mode 100644 index 00000000..ddd6165d --- /dev/null +++ b/core/stat/internal/cpu_other.go @@ -0,0 +1,7 @@ +// +build !linux + +package internal + +func RefreshCpu() uint64 { + return 0 +} diff --git a/core/stat/metrics.go b/core/stat/metrics.go new file mode 100644 index 00000000..176ce4e5 --- /dev/null +++ b/core/stat/metrics.go @@ -0,0 +1,210 @@ +package stat + +import ( + "os" + "sync" + "time" + + "zero/core/executors" + "zero/core/logx" +) + +var ( + LogInterval = time.Minute + + writerLock sync.Mutex + reportWriter Writer = nil +) + +type ( + Writer interface { + Write(report *StatReport) error + } + + StatReport struct { + Name string `json:"name"` + Timestamp int64 `json:"tm"` + Pid int `json:"pid"` + ReqsPerSecond float32 `json:"qps"` + Drops int `json:"drops"` + Average float32 `json:"avg"` + Median float32 `json:"med"` + Top90th float32 `json:"t90"` + Top99th float32 `json:"t99"` + Top99p9th float32 `json:"t99p9"` + } + + Metrics struct { + executor *executors.PeriodicalExecutor + container *metricsContainer + } +) + +func SetReportWriter(writer Writer) { + writerLock.Lock() + reportWriter = writer + writerLock.Unlock() +} + +func NewMetrics(name string) *Metrics { + container := &metricsContainer{ + name: name, + pid: os.Getpid(), + } + + return &Metrics{ + executor: executors.NewPeriodicalExecutor(LogInterval, container), + container: container, + } +} + +func (m *Metrics) Add(task Task) { + m.executor.Add(task) +} + +func (m *Metrics) AddDrop() { + m.executor.Add(Task{ + Drop: true, + }) +} + +func (m *Metrics) SetName(name string) { + m.executor.Sync(func() { + m.container.name = name + }) +} + +type ( + tasksDurationPair struct { + tasks []Task + duration time.Duration + drops int + } + + metricsContainer struct { + name string + pid int + tasks []Task + duration time.Duration + drops int + } +) + +func (c *metricsContainer) AddTask(v interface{}) bool { + if task, ok := v.(Task); ok { + if task.Drop { + c.drops++ + } else { + c.tasks = append(c.tasks, task) + c.duration += task.Duration + } + } + + return false +} + +func (c *metricsContainer) Execute(v interface{}) { + pair := v.(tasksDurationPair) + tasks := pair.tasks + duration := pair.duration + drops := pair.drops + size := len(tasks) + report := &StatReport{ + Name: c.name, + Timestamp: time.Now().Unix(), + Pid: c.pid, + ReqsPerSecond: float32(size) / float32(LogInterval/time.Second), + Drops: drops, + } + + if size > 0 { + report.Average = float32(duration/time.Millisecond) / float32(size) + + fiftyPercent := size >> 1 + if 
fiftyPercent > 0 { + top50pTasks := topK(tasks, fiftyPercent) + medianTask := top50pTasks[0] + report.Median = float32(medianTask.Duration) / float32(time.Millisecond) + tenPercent := fiftyPercent / 5 + if tenPercent > 0 { + top10pTasks := topK(tasks, tenPercent) + task90th := top10pTasks[0] + report.Top90th = float32(task90th.Duration) / float32(time.Millisecond) + onePercent := tenPercent / 10 + if onePercent > 0 { + top1pTasks := topK(top10pTasks, onePercent) + task99th := top1pTasks[0] + report.Top99th = float32(task99th.Duration) / float32(time.Millisecond) + pointOnePercent := onePercent / 10 + if pointOnePercent > 0 { + topPointOneTasks := topK(top1pTasks, pointOnePercent) + task99Point9th := topPointOneTasks[0] + report.Top99p9th = float32(task99Point9th.Duration) / float32(time.Millisecond) + } else { + report.Top99p9th = getTopDuration(top1pTasks) + } + } else { + mostDuration := getTopDuration(top10pTasks) + report.Top99th = mostDuration + report.Top99p9th = mostDuration + } + } else { + mostDuration := getTopDuration(tasks) + report.Top90th = mostDuration + report.Top99th = mostDuration + report.Top99p9th = mostDuration + } + } else { + mostDuration := getTopDuration(tasks) + report.Median = mostDuration + report.Top90th = mostDuration + report.Top99th = mostDuration + report.Top99p9th = mostDuration + } + } + + log(report) +} + +func (c *metricsContainer) RemoveAll() interface{} { + tasks := c.tasks + duration := c.duration + drops := c.drops + c.tasks = nil + c.duration = 0 + c.drops = 0 + + return tasksDurationPair{ + tasks: tasks, + duration: duration, + drops: drops, + } +} + +func getTopDuration(tasks []Task) float32 { + top := topK(tasks, 1) + if len(top) < 1 { + return 0 + } else { + return float32(top[0].Duration) / float32(time.Millisecond) + } +} + +func log(report *StatReport) { + writeReport(report) + logx.Statf("(%s) - qps: %.1f/s, drops: %d, avg time: %.1fms, med: %.1fms, "+ + "90th: %.1fms, 99th: %.1fms, 99.9th: %.1fms", + report.Name, report.ReqsPerSecond, report.Drops, report.Average, report.Median, + report.Top90th, report.Top99th, report.Top99p9th) +} + +func writeReport(report *StatReport) { + writerLock.Lock() + defer writerLock.Unlock() + + if reportWriter != nil { + if err := reportWriter.Write(report); err != nil { + logx.Error(err) + } + } +} diff --git a/core/stat/remotewriter.go b/core/stat/remotewriter.go new file mode 100644 index 00000000..b6119575 --- /dev/null +++ b/core/stat/remotewriter.go @@ -0,0 +1,48 @@ +package stat + +import ( + "bytes" + "encoding/json" + "errors" + "net/http" + "time" + + "zero/core/logx" +) + +const httpTimeout = time.Second * 5 + +var ErrWriteFailed = errors.New("submit failed") + +type RemoteWriter struct { + endpoint string +} + +func NewRemoteWriter(endpoint string) Writer { + return &RemoteWriter{ + endpoint: endpoint, + } +} + +func (rw *RemoteWriter) Write(report *StatReport) error { + bs, err := json.Marshal(report) + if err != nil { + return err + } + + client := &http.Client{ + Timeout: httpTimeout, + } + resp, err := client.Post(rw.endpoint, "application/json", bytes.NewBuffer(bs)) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + logx.Errorf("write report failed, code: %d, reason: %s", resp.StatusCode, resp.Status) + return ErrWriteFailed + } + + return nil +} diff --git a/core/stat/task.go b/core/stat/task.go new file mode 100644 index 00000000..41c1ee22 --- /dev/null +++ b/core/stat/task.go @@ -0,0 +1,9 @@ +package stat + +import "time" + +type Task 
struct { + Drop bool + Duration time.Duration + Description string +} diff --git a/core/stat/topk.go b/core/stat/topk.go new file mode 100644 index 00000000..6457960d --- /dev/null +++ b/core/stat/topk.go @@ -0,0 +1,45 @@ +package stat + +import "container/heap" + +type taskHeap []Task + +func (h *taskHeap) Len() int { + return len(*h) +} + +func (h *taskHeap) Less(i, j int) bool { + return (*h)[i].Duration < (*h)[j].Duration +} + +func (h *taskHeap) Swap(i, j int) { + (*h)[i], (*h)[j] = (*h)[j], (*h)[i] +} + +func (h *taskHeap) Push(x interface{}) { + *h = append(*h, x.(Task)) +} + +func (h *taskHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + +func topK(all []Task, k int) []Task { + h := new(taskHeap) + heap.Init(h) + + for _, each := range all { + if h.Len() < k { + heap.Push(h, each) + } else if (*h)[0].Duration < each.Duration { + heap.Pop(h) + heap.Push(h, each) + } + } + + return *h +} diff --git a/core/stat/topk_test.go b/core/stat/topk_test.go new file mode 100644 index 00000000..06d2e140 --- /dev/null +++ b/core/stat/topk_test.go @@ -0,0 +1,62 @@ +package stat + +import ( + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const ( + numSamples = 10000 + topNum = 100 +) + +var samples []Task + +func init() { + for i := 0; i < numSamples; i++ { + task := Task{ + Duration: time.Duration(rand.Int63()), + } + samples = append(samples, task) + } +} + +func TestTopK(t *testing.T) { + tasks := []Task{ + {false, 1, "a"}, + {false, 4, "a"}, + {false, 2, "a"}, + {false, 5, "a"}, + {false, 9, "a"}, + {false, 10, "a"}, + {false, 12, "a"}, + {false, 3, "a"}, + {false, 6, "a"}, + {false, 11, "a"}, + {false, 8, "a"}, + } + + result := topK(tasks, 3) + if len(result) != 3 { + t.Fail() + } + + set := make(map[time.Duration]struct{}) + for _, each := range result { + set[each.Duration] = struct{}{} + } + + for _, v := range []time.Duration{10, 11, 12} { + _, ok := set[v] + assert.True(t, ok) + } +} + +func BenchmarkTopkHeap(b *testing.B) { + for i := 0; i < b.N; i++ { + topK(samples, topNum) + } +} diff --git a/core/stat/usage.go b/core/stat/usage.go new file mode 100644 index 00000000..06b83653 --- /dev/null +++ b/core/stat/usage.go @@ -0,0 +1,60 @@ +package stat + +import ( + "runtime" + "sync/atomic" + "time" + + "zero/core/logx" + "zero/core/stat/internal" + "zero/core/threading" +) + +const ( + // 250ms and 0.95 as beta will count the average cpu load for past 5 seconds + cpuRefreshInterval = time.Millisecond * 250 + allRefreshInterval = time.Minute + // moving average beta hyperparameter + beta = 0.95 +) + +var cpuUsage int64 + +func init() { + go func() { + cpuTicker := time.NewTicker(cpuRefreshInterval) + defer cpuTicker.Stop() + allTicker := time.NewTicker(allRefreshInterval) + defer allTicker.Stop() + + for { + select { + case <-cpuTicker.C: + threading.RunSafe(func() { + curUsage := internal.RefreshCpu() + prevUsage := atomic.LoadInt64(&cpuUsage) + // cpu = cpuᵗ⁻¹ * beta + cpuᵗ * (1 - beta) + usage := int64(float64(prevUsage)*beta + float64(curUsage)*(1-beta)) + atomic.StoreInt64(&cpuUsage, usage) + }) + case <-allTicker.C: + printUsage() + } + } + }() +} + +func CpuUsage() int64 { + return atomic.LoadInt64(&cpuUsage) +} + +func bToMb(b uint64) float32 { + return float32(b) / 1024 / 1024 +} + +func printUsage() { + var m runtime.MemStats + runtime.ReadMemStats(&m) + logx.Statf("CPU: %dm, MEMORY: Alloc=%.1fMi, TotalAlloc=%.1fMi, Sys=%.1fMi, NumGC=%d", + CpuUsage(), bToMb(m.Alloc), bToMb(m.TotalAlloc), 
bToMb(m.Sys), m.NumGC) +} diff --git a/core/stores/cache/cacheconf.go b/core/stores/cache/cacheconf.go new file mode 100644 index 00000000..7c706897 --- /dev/null +++ b/core/stores/cache/cacheconf.go @@ -0,0 +1,5 @@ +package cache + +import "zero/core/stores/internal" + +type CacheConf = internal.ClusterConf diff --git a/core/stores/cache/cacheopt.go b/core/stores/cache/cacheopt.go new file mode 100644 index 00000000..ef92a243 --- /dev/null +++ b/core/stores/cache/cacheopt.go @@ -0,0 +1,21 @@ +package cache + +import ( + "time" + + "zero/core/stores/internal" +) + +type Option = internal.Option + +func WithExpiry(expiry time.Duration) Option { + return func(o *internal.Options) { + o.Expiry = expiry + } +} + +func WithNotFoundExpiry(expiry time.Duration) Option { + return func(o *internal.Options) { + o.NotFoundExpiry = expiry + } +} diff --git a/core/stores/clickhouse/clickhouse.go b/core/stores/clickhouse/clickhouse.go new file mode 100644 index 00000000..56d6d808 --- /dev/null +++ b/core/stores/clickhouse/clickhouse.go @@ -0,0 +1,13 @@ +package clickhouse + +import ( + "zero/core/stores/sqlx" + + _ "github.com/kshvakov/clickhouse" +) + +const clickHouseDriverName = "clickhouse" + +func New(datasource string, opts ...sqlx.SqlOption) sqlx.SqlConn { + return sqlx.NewSqlConn(clickHouseDriverName, datasource, opts...) +} diff --git a/core/stores/internal/cache.go b/core/stores/internal/cache.go new file mode 100644 index 00000000..5a7bcdec --- /dev/null +++ b/core/stores/internal/cache.go @@ -0,0 +1,129 @@ +package internal + +import ( + "fmt" + "log" + "time" + + "zero/core/errorx" + "zero/core/hash" + "zero/core/syncx" +) + +type ( + Cache interface { + DelCache(keys ...string) error + GetCache(key string, v interface{}) error + SetCache(key string, v interface{}) error + SetCacheWithExpire(key string, v interface{}, expire time.Duration) error + Take(v interface{}, key string, query func(v interface{}) error) error + TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error + } + + cacheCluster struct { + dispatcher *hash.ConsistentHash + errNotFound error + } +) + +func NewCache(c ClusterConf, barrier syncx.SharedCalls, st *CacheStat, errNotFound error, + opts ...Option) Cache { + if len(c) == 0 || TotalWeights(c) <= 0 { + log.Fatal("no cache nodes") + } + + if len(c) == 1 { + return NewCacheNode(c[0].NewRedis(), barrier, st, errNotFound, opts...) + } + + dispatcher := hash.NewConsistentHash() + for _, node := range c { + cn := NewCacheNode(node.NewRedis(), barrier, st, errNotFound, opts...) 
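+ // register each wrapped cache node on the consistent hash ring with its configured weight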
+ dispatcher.AddWithWeight(cn, node.Weight) + } + + return cacheCluster{ + dispatcher: dispatcher, + errNotFound: errNotFound, + } +} + +func (cc cacheCluster) DelCache(keys ...string) error { + switch len(keys) { + case 0: + return nil + case 1: + key := keys[0] + c, ok := cc.dispatcher.Get(key) + if !ok { + return cc.errNotFound + } + + return c.(Cache).DelCache(key) + default: + var be errorx.BatchError + nodes := make(map[interface{}][]string) + for _, key := range keys { + c, ok := cc.dispatcher.Get(key) + if !ok { + be.Add(fmt.Errorf("key %q not found", key)) + continue + } + + nodes[c] = append(nodes[c], key) + } + for c, ks := range nodes { + if err := c.(Cache).DelCache(ks...); err != nil { + be.Add(err) + } + } + + return be.Err() + } +} + +func (cc cacheCluster) GetCache(key string, v interface{}) error { + c, ok := cc.dispatcher.Get(key) + if !ok { + return cc.errNotFound + } + + return c.(Cache).GetCache(key, v) +} + +func (cc cacheCluster) SetCache(key string, v interface{}) error { + c, ok := cc.dispatcher.Get(key) + if !ok { + return cc.errNotFound + } + + return c.(Cache).SetCache(key, v) +} + +func (cc cacheCluster) SetCacheWithExpire(key string, v interface{}, expire time.Duration) error { + c, ok := cc.dispatcher.Get(key) + if !ok { + return cc.errNotFound + } + + return c.(Cache).SetCacheWithExpire(key, v, expire) +} + +func (cc cacheCluster) Take(v interface{}, key string, query func(v interface{}) error) error { + c, ok := cc.dispatcher.Get(key) + if !ok { + return cc.errNotFound + } + + return c.(Cache).Take(v, key, query) +} + +func (cc cacheCluster) TakeWithExpire(v interface{}, key string, + query func(v interface{}, expire time.Duration) error) error { + c, ok := cc.dispatcher.Get(key) + if !ok { + return cc.errNotFound + } + + return c.(Cache).TakeWithExpire(v, key, query) +} diff --git a/core/stores/internal/cache_test.go b/core/stores/internal/cache_test.go new file mode 100644 index 00000000..e7b0a9b0 --- /dev/null +++ b/core/stores/internal/cache_test.go @@ -0,0 +1,201 @@ +package internal + +import ( + "encoding/json" + "fmt" + "math" + "strconv" + "testing" + "time" + + "zero/core/errorx" + "zero/core/hash" + "zero/core/stores/redis" + "zero/core/syncx" + + "github.com/alicebob/miniredis" + "github.com/stretchr/testify/assert" +) + +type mockedNode struct { + vals map[string][]byte + errNotFound error +} + +func (mc *mockedNode) DelCache(keys ...string) error { + var be errorx.BatchError + for _, key := range keys { + if _, ok := mc.vals[key]; !ok { + be.Add(mc.errNotFound) + } else { + delete(mc.vals, key) + } + } + return be.Err() +} + +func (mc *mockedNode) GetCache(key string, v interface{}) error { + bs, ok := mc.vals[key] + if ok { + return json.Unmarshal(bs, v) + } + + return mc.errNotFound +} + +func (mc *mockedNode) SetCache(key string, v interface{}) error { + data, err := json.Marshal(v) + if err != nil { + return err + } + + mc.vals[key] = data + return nil +} + +func (mc *mockedNode) SetCacheWithExpire(key string, v interface{}, expire time.Duration) error { + return mc.SetCache(key, v) +} + +func (mc *mockedNode) Take(v interface{}, key string, query func(v interface{}) error) error { + if _, ok := mc.vals[key]; ok { + return mc.GetCache(key, v) + } + + if err := query(v); err != nil { + return err + } + + return mc.SetCache(key, v) +} + +func (mc *mockedNode) TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error { + return mc.Take(v, key, func(v interface{}) error { + return query(v, 0) + }) 
+} + +func TestCache_SetDel(t *testing.T) { + const total = 1000 + r1 := miniredis.NewMiniRedis() + assert.Nil(t, r1.Start()) + defer r1.Close() + r2 := miniredis.NewMiniRedis() + assert.Nil(t, r2.Start()) + defer r2.Close() + conf := ClusterConf{ + { + RedisConf: redis.RedisConf{ + Host: r1.Addr(), + Type: redis.NodeType, + }, + Weight: 100, + }, + { + RedisConf: redis.RedisConf{ + Host: r2.Addr(), + Type: redis.NodeType, + }, + Weight: 100, + }, + } + c := NewCache(conf, syncx.NewSharedCalls(), NewCacheStat("mock"), errPlaceholder) + for i := 0; i < total; i++ { + if i%2 == 0 { + assert.Nil(t, c.SetCache(fmt.Sprintf("key/%d", i), i)) + } else { + assert.Nil(t, c.SetCacheWithExpire(fmt.Sprintf("key/%d", i), i, 0)) + } + } + for i := 0; i < total; i++ { + var v int + assert.Nil(t, c.GetCache(fmt.Sprintf("key/%d", i), &v)) + assert.Equal(t, i, v) + } + for i := 0; i < total; i++ { + assert.Nil(t, c.DelCache(fmt.Sprintf("key/%d", i))) + } + for i := 0; i < total; i++ { + var v int + assert.Equal(t, errPlaceholder, c.GetCache(fmt.Sprintf("key/%d", i), &v)) + assert.Equal(t, 0, v) + } +} + +func TestCache_Balance(t *testing.T) { + const ( + numNodes = 100 + total = 10000 + ) + dispatcher := hash.NewConsistentHash() + maps := make([]map[string][]byte, numNodes) + for i := 0; i < numNodes; i++ { + maps[i] = map[string][]byte{ + strconv.Itoa(i): []byte(strconv.Itoa(i)), + } + } + for i := 0; i < numNodes; i++ { + dispatcher.AddWithWeight(&mockedNode{ + vals: maps[i], + errNotFound: errPlaceholder, + }, 100) + } + + c := cacheCluster{ + dispatcher: dispatcher, + errNotFound: errPlaceholder, + } + for i := 0; i < total; i++ { + assert.Nil(t, c.SetCache(strconv.Itoa(i), i)) + } + + counts := make(map[int]int) + for i, m := range maps { + counts[i] = len(m) + } + entropy := calcEntropy(counts, total) + assert.True(t, len(counts) > 1) + assert.True(t, entropy > .95, fmt.Sprintf("entropy should be greater than 0.95, but got %.2f", entropy)) + + for i := 0; i < total; i++ { + var v int + assert.Nil(t, c.GetCache(strconv.Itoa(i), &v)) + assert.Equal(t, i, v) + } + + for i := 0; i < total/10; i++ { + assert.Nil(t, c.DelCache(strconv.Itoa(i*10), strconv.Itoa(i*10+1), strconv.Itoa(i*10+2))) + assert.Nil(t, c.DelCache(strconv.Itoa(i*10+9))) + } + + var count int + for i := 0; i < total/10; i++ { + var val int + if i%2 == 0 { + assert.Nil(t, c.Take(&val, strconv.Itoa(i*10), func(v interface{}) error { + *v.(*int) = i + count++ + return nil + })) + } else { + assert.Nil(t, c.TakeWithExpire(&val, strconv.Itoa(i*10), func(v interface{}, expire time.Duration) error { + *v.(*int) = i + count++ + return nil + })) + } + assert.Equal(t, i, val) + } + assert.Equal(t, total/10, count) +} + +func calcEntropy(m map[int]int, total int) float64 { + var entropy float64 + + for _, v := range m { + proba := float64(v) / float64(total) + entropy -= proba * math.Log2(proba) + } + + return entropy / math.Log2(float64(len(m))) +} diff --git a/core/stores/internal/cachenode.go b/core/stores/internal/cachenode.go new file mode 100644 index 00000000..e74a750f --- /dev/null +++ b/core/stores/internal/cachenode.go @@ -0,0 +1,208 @@ +package internal + +import ( + "encoding/json" + "errors" + "fmt" + "math/rand" + "sync" + "time" + + "zero/core/logx" + "zero/core/mathx" + "zero/core/stat" + "zero/core/stores/redis" + "zero/core/syncx" +) + +const ( + notFoundPlaceholder = "*" + // make the expiry unstable to avoid lots of cached items expire at the same time + // make the unstable expiry to be [0.95, 1.05] * seconds + expiryDeviation = 
0.05 +) + +// indicates there is no such value associate with the key +var errPlaceholder = errors.New("placeholder") + +type cacheNode struct { + rds *redis.Redis + expiry time.Duration + notFoundExpiry time.Duration + barrier syncx.SharedCalls + r *rand.Rand + lock *sync.Mutex + unstableExpiry mathx.Unstable + stat *CacheStat + errNotFound error +} + +func NewCacheNode(rds *redis.Redis, barrier syncx.SharedCalls, st *CacheStat, + errNotFound error, opts ...Option) Cache { + o := newOptions(opts...) + return cacheNode{ + rds: rds, + expiry: o.Expiry, + notFoundExpiry: o.NotFoundExpiry, + barrier: barrier, + r: rand.New(rand.NewSource(time.Now().UnixNano())), + lock: new(sync.Mutex), + unstableExpiry: mathx.NewUnstable(expiryDeviation), + stat: st, + errNotFound: errNotFound, + } +} + +func (c cacheNode) DelCache(keys ...string) error { + if len(keys) == 0 { + return nil + } + + if _, err := c.rds.Del(keys...); err != nil { + logx.Errorf("failed to clear cache with keys: %q, error: %v", formatKeys(keys), err) + c.asyncRetryDelCache(keys...) + } + + return nil +} + +func (c cacheNode) GetCache(key string, v interface{}) error { + if err := c.doGetCache(key, v); err == errPlaceholder { + return c.errNotFound + } else { + return err + } +} + +func (c cacheNode) SetCache(key string, v interface{}) error { + return c.SetCacheWithExpire(key, v, c.aroundDuration(c.expiry)) +} + +func (c cacheNode) SetCacheWithExpire(key string, v interface{}, expire time.Duration) error { + data, err := json.Marshal(v) + if err != nil { + return err + } + + return c.rds.Setex(key, string(data), int(expire.Seconds())) +} + +func (c cacheNode) String() string { + return c.rds.Addr +} + +func (c cacheNode) Take(v interface{}, key string, query func(v interface{}) error) error { + return c.doTake(v, key, query, func(v interface{}) error { + return c.SetCache(key, v) + }) +} + +func (c cacheNode) TakeWithExpire(v interface{}, key string, + query func(v interface{}, expire time.Duration) error) error { + expire := c.aroundDuration(c.expiry) + return c.doTake(v, key, func(v interface{}) error { + return query(v, expire) + }, func(v interface{}) error { + return c.SetCacheWithExpire(key, v, expire) + }) +} + +func (c cacheNode) aroundDuration(duration time.Duration) time.Duration { + return c.unstableExpiry.AroundDuration(duration) +} + +func (c cacheNode) asyncRetryDelCache(keys ...string) { + AddCleanTask(func() error { + _, err := c.rds.Del(keys...) + return err + }, keys...) +} + +func (c cacheNode) doGetCache(key string, v interface{}) error { + c.stat.IncrementTotal() + data, err := c.rds.Get(key) + if err != nil { + c.stat.IncrementMiss() + return err + } + + if len(data) == 0 { + c.stat.IncrementMiss() + return c.errNotFound + } + + c.stat.IncrementHit() + if data == notFoundPlaceholder { + return errPlaceholder + } + + return c.processCache(key, data, v) +} + +func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) error, + cacheVal func(v interface{}) error) error { + val, fresh, err := c.barrier.DoEx(key, func() (interface{}, error) { + if err := c.doGetCache(key, v); err != nil { + if err == errPlaceholder { + return nil, c.errNotFound + } else if err != c.errNotFound { + // why we just return the error instead of query from db, + // because we don't allow the disaster pass to the dbs. + // fail fast, in case we bring down the dbs. 
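+ // (a real redis error, i.e. anything other than a cache miss, is returned as-is instead of falling back to the database)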
+ return nil, err + } + + if err = query(v); err == c.errNotFound { + if err = c.setCacheWithNotFound(key); err != nil { + logx.Error(err) + } + + return nil, c.errNotFound + } else if err != nil { + c.stat.IncrementDbFails() + return nil, err + } + + if err = cacheVal(v); err != nil { + logx.Error(err) + } + } + + return json.Marshal(v) + }) + if err != nil { + return err + } + if fresh { + return nil + } else { + // got the result from previous ongoing query + c.stat.IncrementTotal() + c.stat.IncrementHit() + } + + return json.Unmarshal(val.([]byte), v) +} + +func (c cacheNode) processCache(key string, data string, v interface{}) error { + err := json.Unmarshal([]byte(data), v) + if err == nil { + return nil + } + + report := fmt.Sprintf("unmarshal cache, node: %s, key: %s, value: %s, error: %v", + c.rds.Addr, key, data, err) + logx.Error(report) + stat.Report(report) + if _, e := c.rds.Del(key); e != nil { + logx.Errorf("delete invalid cache, node: %s, key: %s, value: %s, error: %v", + c.rds.Addr, key, data, e) + } + + // returns errNotFound to reload the value by the given queryFn + return c.errNotFound +} + +func (c cacheNode) setCacheWithNotFound(key string) error { + return c.rds.Setex(key, notFoundPlaceholder, int(c.aroundDuration(c.notFoundExpiry).Seconds())) +} diff --git a/core/stores/internal/cachenode_test.go b/core/stores/internal/cachenode_test.go new file mode 100644 index 00000000..1a8c6405 --- /dev/null +++ b/core/stores/internal/cachenode_test.go @@ -0,0 +1,66 @@ +package internal + +import ( + "errors" + "math/rand" + "sync" + "testing" + "time" + + "zero/core/logx" + "zero/core/mathx" + "zero/core/stat" + "zero/core/stores/redis" + + "github.com/alicebob/miniredis" + "github.com/stretchr/testify/assert" +) + +func init() { + logx.Disable() + stat.SetReporter(nil) +} + +func TestCacheNode_DelCache(t *testing.T) { + s, err := miniredis.Run() + assert.Nil(t, err) + defer s.Close() + + cn := cacheNode{ + rds: redis.NewRedis(s.Addr(), redis.NodeType), + r: rand.New(rand.NewSource(time.Now().UnixNano())), + lock: new(sync.Mutex), + unstableExpiry: mathx.NewUnstable(expiryDeviation), + stat: NewCacheStat("any"), + errNotFound: errors.New("any"), + } + assert.Nil(t, cn.DelCache()) + assert.Nil(t, cn.DelCache([]string{}...)) + assert.Nil(t, cn.DelCache(make([]string, 0)...)) + cn.SetCache("first", "one") + assert.Nil(t, cn.DelCache("first")) + cn.SetCache("first", "one") + cn.SetCache("second", "two") + assert.Nil(t, cn.DelCache("first", "second")) +} + +func TestCacheNode_InvalidCache(t *testing.T) { + s, err := miniredis.Run() + assert.Nil(t, err) + defer s.Close() + + cn := cacheNode{ + rds: redis.NewRedis(s.Addr(), redis.NodeType), + r: rand.New(rand.NewSource(time.Now().UnixNano())), + lock: new(sync.Mutex), + unstableExpiry: mathx.NewUnstable(expiryDeviation), + stat: NewCacheStat("any"), + errNotFound: errors.New("any"), + } + s.Set("any", "value") + var str string + assert.NotNil(t, cn.GetCache("any", &str)) + assert.Equal(t, "", str) + _, err = s.Get("any") + assert.Equal(t, miniredis.ErrKeyNotFound, err) +} diff --git a/core/stores/internal/cacheopt.go b/core/stores/internal/cacheopt.go new file mode 100644 index 00000000..908c0bc0 --- /dev/null +++ b/core/stores/internal/cacheopt.go @@ -0,0 +1,33 @@ +package internal + +import "time" + +const ( + defaultExpiry = time.Hour * 24 * 7 + defaultNotFoundExpiry = time.Minute +) + +type ( + Options struct { + Expiry time.Duration + NotFoundExpiry time.Duration + } + + Option func(o *Options) +) + +func newOptions(opts 
...Option) Options { + var o Options + for _, opt := range opts { + opt(&o) + } + + if o.Expiry <= 0 { + o.Expiry = defaultExpiry + } + if o.NotFoundExpiry <= 0 { + o.NotFoundExpiry = defaultNotFoundExpiry + } + + return o +} diff --git a/core/stores/internal/cachestat.go b/core/stores/internal/cachestat.go new file mode 100644 index 00000000..238ded6f --- /dev/null +++ b/core/stores/internal/cachestat.go @@ -0,0 +1,67 @@ +package internal + +import ( + "sync/atomic" + "time" + + "zero/core/logx" +) + +const statInterval = time.Minute + +type CacheStat struct { + name string + // export the fields to let the unit tests working, + // reside in internal package, doesn't matter. + Total uint64 + Hit uint64 + Miss uint64 + DbFails uint64 +} + +func NewCacheStat(name string) *CacheStat { + ret := &CacheStat{ + name: name, + } + go ret.statLoop() + + return ret +} + +func (cs *CacheStat) IncrementTotal() { + atomic.AddUint64(&cs.Total, 1) +} + +func (cs *CacheStat) IncrementHit() { + atomic.AddUint64(&cs.Hit, 1) +} + +func (cs *CacheStat) IncrementMiss() { + atomic.AddUint64(&cs.Miss, 1) +} + +func (cs *CacheStat) IncrementDbFails() { + atomic.AddUint64(&cs.DbFails, 1) +} + +func (cs *CacheStat) statLoop() { + ticker := time.NewTicker(statInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + total := atomic.SwapUint64(&cs.Total, 0) + if total == 0 { + continue + } + + hit := atomic.SwapUint64(&cs.Hit, 0) + percent := 100 * float32(hit) / float32(total) + miss := atomic.SwapUint64(&cs.Miss, 0) + dbf := atomic.SwapUint64(&cs.DbFails, 0) + logx.Statf("dbcache(%s) - qpm: %d, hit_ratio: %.1f%%, hit: %d, miss: %d, db_fails: %d", + cs.name, total, percent, hit, miss, dbf) + } + } +} diff --git a/core/stores/internal/cleaner.go b/core/stores/internal/cleaner.go new file mode 100644 index 00000000..1dfc1fe1 --- /dev/null +++ b/core/stores/internal/cleaner.go @@ -0,0 +1,85 @@ +package internal + +import ( + "fmt" + "time" + + "zero/core/collection" + "zero/core/lang" + "zero/core/logx" + "zero/core/proc" + "zero/core/stat" + "zero/core/stringx" + "zero/core/threading" +) + +const ( + timingWheelSlots = 300 + cleanWorkers = 5 + taskKeyLen = 8 +) + +var ( + timingWheel *collection.TimingWheel + taskRunner = threading.NewTaskRunner(cleanWorkers) +) + +type delayTask struct { + delay time.Duration + task func() error + keys []string +} + +func init() { + var err error + timingWheel, err = collection.NewTimingWheel(time.Second, timingWheelSlots, clean) + lang.Must(err) + + proc.AddShutdownListener(func() { + timingWheel.Drain(clean) + }) +} + +func AddCleanTask(task func() error, keys ...string) { + timingWheel.SetTimer(stringx.Randn(taskKeyLen), delayTask{ + delay: time.Second, + task: task, + keys: keys, + }, time.Second) +} + +func clean(key, value interface{}) { + taskRunner.Schedule(func() { + dt := value.(delayTask) + err := dt.task() + if err == nil { + return + } + + next, ok := nextDelay(dt.delay) + if ok { + dt.delay = next + timingWheel.SetTimer(key, dt, next) + } else { + msg := fmt.Sprintf("retried but failed to clear cache with keys: %q, error: %v", + formatKeys(dt.keys), err) + logx.Error(msg) + stat.Report(msg) + } + }) +} + +func nextDelay(delay time.Duration) (time.Duration, bool) { + switch delay { + case time.Second: + return time.Second * 5, true + case time.Second * 5: + return time.Minute, true + case time.Minute: + return time.Minute * 5, true + case time.Minute * 5: + return time.Hour, true + default: + return 0, false + } +} diff --git 
a/core/stores/internal/config.go b/core/stores/internal/config.go new file mode 100644 index 00000000..63353244 --- /dev/null +++ b/core/stores/internal/config.go @@ -0,0 +1,12 @@ +package internal + +import "zero/core/stores/redis" + +type ( + ClusterConf []NodeConf + + NodeConf struct { + redis.RedisConf + Weight int `json:",default=100"` + } +) diff --git a/core/stores/internal/util.go b/core/stores/internal/util.go new file mode 100644 index 00000000..bf77e81a --- /dev/null +++ b/core/stores/internal/util.go @@ -0,0 +1,22 @@ +package internal + +import "strings" + +const keySeparator = "," + +func TotalWeights(c []NodeConf) int { + var weights int + + for _, node := range c { + if node.Weight < 0 { + node.Weight = 0 + } + weights += node.Weight + } + + return weights +} + +func formatKeys(keys []string) string { + return strings.Join(keys, keySeparator) +} diff --git a/core/stores/kv/config.go b/core/stores/kv/config.go new file mode 100644 index 00000000..38361574 --- /dev/null +++ b/core/stores/kv/config.go @@ -0,0 +1,5 @@ +package kv + +import "zero/core/stores/internal" + +type KvConf = internal.ClusterConf diff --git a/core/stores/kv/store.go b/core/stores/kv/store.go new file mode 100644 index 00000000..02967b1e --- /dev/null +++ b/core/stores/kv/store.go @@ -0,0 +1,653 @@ +package kv + +import ( + "errors" + "log" + + "zero/core/errorx" + "zero/core/hash" + "zero/core/stores/internal" + "zero/core/stores/redis" +) + +var ErrNoRedisNode = errors.New("no redis node") + +type ( + Store interface { + Del(keys ...string) (int, error) + Eval(script string, key string, args ...interface{}) (interface{}, error) + Exists(key string) (bool, error) + Expire(key string, seconds int) error + Expireat(key string, expireTime int64) error + Get(key string) (string, error) + Hdel(key, field string) (bool, error) + Hexists(key, field string) (bool, error) + Hget(key, field string) (string, error) + Hgetall(key string) (map[string]string, error) + Hincrby(key, field string, increment int) (int, error) + Hkeys(key string) ([]string, error) + Hlen(key string) (int, error) + Hmget(key string, fields ...string) ([]string, error) + Hset(key, field, value string) error + Hsetnx(key, field, value string) (bool, error) + Hmset(key string, fieldsAndValues map[string]string) error + Hvals(key string) ([]string, error) + Incr(key string) (int64, error) + Incrby(key string, increment int64) (int64, error) + Llen(key string) (int, error) + Lpop(key string) (string, error) + Lpush(key string, values ...interface{}) (int, error) + Lrange(key string, start int, stop int) ([]string, error) + Lrem(key string, count int, value string) (int, error) + Persist(key string) (bool, error) + Pfadd(key string, values ...interface{}) (bool, error) + Pfcount(key string) (int64, error) + Rpush(key string, values ...interface{}) (int, error) + Sadd(key string, values ...interface{}) (int, error) + Scard(key string) (int64, error) + Set(key string, value string) error + Setex(key, value string, seconds int) error + Setnx(key, value string) (bool, error) + SetnxEx(key, value string, seconds int) (bool, error) + Sismember(key string, value interface{}) (bool, error) + Smembers(key string) ([]string, error) + Spop(key string) (string, error) + Srandmember(key string, count int) ([]string, error) + Srem(key string, values ...interface{}) (int, error) + Sscan(key string, cursor uint64, match string, count int64) (keys []string, cur uint64, err error) + Ttl(key string) (int, error) + Zadd(key string, score int64, value string) (bool, 
error) + Zadds(key string, ps ...redis.Pair) (int64, error) + Zcard(key string) (int, error) + Zcount(key string, start, stop int64) (int, error) + Zincrby(key string, increment int64, field string) (int64, error) + Zrange(key string, start, stop int64) ([]string, error) + ZrangeWithScores(key string, start, stop int64) ([]redis.Pair, error) + ZrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error) + ZrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ([]redis.Pair, error) + Zrank(key, field string) (int64, error) + Zrem(key string, values ...interface{}) (int, error) + Zremrangebyrank(key string, start, stop int64) (int, error) + Zremrangebyscore(key string, start, stop int64) (int, error) + Zrevrange(key string, start, stop int64) ([]string, error) + ZrevrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error) + ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ([]redis.Pair, error) + Zscore(key string, value string) (int64, error) + } + + clusterStore struct { + dispatcher *hash.ConsistentHash + } +) + +func NewStore(c KvConf) Store { + if len(c) == 0 || internal.TotalWeights(c) <= 0 { + log.Fatal("no cache nodes") + } + + // even if only one node, we chose to use consistent hash, + // because Store and redis.Redis has different methods. + dispatcher := hash.NewConsistentHash() + for _, node := range c { + cn := node.NewRedis() + dispatcher.AddWithWeight(cn, node.Weight) + } + + return clusterStore{ + dispatcher: dispatcher, + } +} + +func (cs clusterStore) Del(keys ...string) (int, error) { + var val int + var be errorx.BatchError + + for _, key := range keys { + node, e := cs.getRedis(key) + if e != nil { + be.Add(e) + continue + } + + if v, e := node.Del(key); e != nil { + be.Add(e) + } else { + val += v + } + } + + return val, be.Err() +} + +func (cs clusterStore) Eval(script string, key string, args ...interface{}) (interface{}, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.Eval(script, []string{key}, args...) 
+} + +func (cs clusterStore) Exists(key string) (bool, error) { + node, err := cs.getRedis(key) + if err != nil { + return false, err + } + + return node.Exists(key) +} + +func (cs clusterStore) Expire(key string, seconds int) error { + node, err := cs.getRedis(key) + if err != nil { + return err + } + + return node.Expire(key, seconds) +} + +func (cs clusterStore) Expireat(key string, expireTime int64) error { + node, err := cs.getRedis(key) + if err != nil { + return err + } + + return node.Expireat(key, expireTime) +} + +func (cs clusterStore) Get(key string) (string, error) { + node, err := cs.getRedis(key) + if err != nil { + return "", err + } + + return node.Get(key) +} + +func (cs clusterStore) Hdel(key, field string) (bool, error) { + node, err := cs.getRedis(key) + if err != nil { + return false, err + } + + return node.Hdel(key, field) +} + +func (cs clusterStore) Hexists(key, field string) (bool, error) { + node, err := cs.getRedis(key) + if err != nil { + return false, err + } + + return node.Hexists(key, field) +} + +func (cs clusterStore) Hget(key, field string) (string, error) { + node, err := cs.getRedis(key) + if err != nil { + return "", err + } + + return node.Hget(key, field) +} + +func (cs clusterStore) Hgetall(key string) (map[string]string, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.Hgetall(key) +} + +func (cs clusterStore) Hincrby(key, field string, increment int) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Hincrby(key, field, increment) +} + +func (cs clusterStore) Hkeys(key string) ([]string, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.Hkeys(key) +} + +func (cs clusterStore) Hlen(key string) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Hlen(key) +} + +func (cs clusterStore) Hmget(key string, fields ...string) ([]string, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.Hmget(key, fields...) 
+} + +func (cs clusterStore) Hset(key, field, value string) error { + node, err := cs.getRedis(key) + if err != nil { + return err + } + + return node.Hset(key, field, value) +} + +func (cs clusterStore) Hsetnx(key, field, value string) (bool, error) { + node, err := cs.getRedis(key) + if err != nil { + return false, err + } + + return node.Hsetnx(key, field, value) +} + +func (cs clusterStore) Hmset(key string, fieldsAndValues map[string]string) error { + node, err := cs.getRedis(key) + if err != nil { + return err + } + + return node.Hmset(key, fieldsAndValues) +} + +func (cs clusterStore) Hvals(key string) ([]string, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.Hvals(key) +} + +func (cs clusterStore) Incr(key string) (int64, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Incr(key) +} + +func (cs clusterStore) Incrby(key string, increment int64) (int64, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Incrby(key, increment) +} + +func (cs clusterStore) Llen(key string) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Llen(key) +} + +func (cs clusterStore) Lpop(key string) (string, error) { + node, err := cs.getRedis(key) + if err != nil { + return "", err + } + + return node.Lpop(key) +} + +func (cs clusterStore) Lpush(key string, values ...interface{}) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Lpush(key, values...) +} + +func (cs clusterStore) Lrange(key string, start int, stop int) ([]string, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.Lrange(key, start, stop) +} + +func (cs clusterStore) Lrem(key string, count int, value string) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Lrem(key, count, value) +} + +func (cs clusterStore) Persist(key string) (bool, error) { + node, err := cs.getRedis(key) + if err != nil { + return false, err + } + + return node.Persist(key) +} + +func (cs clusterStore) Pfadd(key string, values ...interface{}) (bool, error) { + node, err := cs.getRedis(key) + if err != nil { + return false, err + } + + return node.Pfadd(key, values...) +} + +func (cs clusterStore) Pfcount(key string) (int64, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Pfcount(key) +} + +func (cs clusterStore) Rpush(key string, values ...interface{}) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Rpush(key, values...) +} + +func (cs clusterStore) Sadd(key string, values ...interface{}) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Sadd(key, values...) 
+} + +func (cs clusterStore) Scard(key string) (int64, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Scard(key) +} + +func (cs clusterStore) Set(key string, value string) error { + node, err := cs.getRedis(key) + if err != nil { + return err + } + + return node.Set(key, value) +} + +func (cs clusterStore) Setex(key, value string, seconds int) error { + node, err := cs.getRedis(key) + if err != nil { + return err + } + + return node.Setex(key, value, seconds) +} + +func (cs clusterStore) Setnx(key, value string) (bool, error) { + node, err := cs.getRedis(key) + if err != nil { + return false, err + } + + return node.Setnx(key, value) +} + +func (cs clusterStore) SetnxEx(key, value string, seconds int) (bool, error) { + node, err := cs.getRedis(key) + if err != nil { + return false, err + } + + return node.SetnxEx(key, value, seconds) +} + +func (cs clusterStore) Sismember(key string, value interface{}) (bool, error) { + node, err := cs.getRedis(key) + if err != nil { + return false, err + } + + return node.Sismember(key, value) +} + +func (cs clusterStore) Smembers(key string) ([]string, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.Smembers(key) +} + +func (cs clusterStore) Spop(key string) (string, error) { + node, err := cs.getRedis(key) + if err != nil { + return "", err + } + + return node.Spop(key) +} + +func (cs clusterStore) Srandmember(key string, count int) ([]string, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.Srandmember(key, count) +} + +func (cs clusterStore) Srem(key string, values ...interface{}) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Srem(key, values...) +} + +func (cs clusterStore) Sscan(key string, cursor uint64, match string, count int64) ( + keys []string, cur uint64, err error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, 0, err + } + + return node.Sscan(key, cursor, match, count) +} + +func (cs clusterStore) Ttl(key string) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Ttl(key) +} + +func (cs clusterStore) Zadd(key string, score int64, value string) (bool, error) { + node, err := cs.getRedis(key) + if err != nil { + return false, err + } + + return node.Zadd(key, score, value) +} + +func (cs clusterStore) Zadds(key string, ps ...redis.Pair) (int64, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Zadds(key, ps...) 
+} + +func (cs clusterStore) Zcard(key string) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Zcard(key) +} + +func (cs clusterStore) Zcount(key string, start, stop int64) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Zcount(key, start, stop) +} + +func (cs clusterStore) Zincrby(key string, increment int64, field string) (int64, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Zincrby(key, increment, field) +} + +func (cs clusterStore) Zrank(key, field string) (int64, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Zrank(key, field) +} + +func (cs clusterStore) Zrange(key string, start, stop int64) ([]string, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.Zrange(key, start, stop) +} + +func (cs clusterStore) ZrangeWithScores(key string, start, stop int64) ([]redis.Pair, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.ZrangeWithScores(key, start, stop) +} + +func (cs clusterStore) ZrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.ZrangebyscoreWithScores(key, start, stop) +} + +func (cs clusterStore) ZrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ( + []redis.Pair, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.ZrangebyscoreWithScoresAndLimit(key, start, stop, page, size) +} + +func (cs clusterStore) Zrem(key string, values ...interface{}) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Zrem(key, values...) 
+} + +func (cs clusterStore) Zremrangebyrank(key string, start, stop int64) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Zremrangebyrank(key, start, stop) +} + +func (cs clusterStore) Zremrangebyscore(key string, start, stop int64) (int, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Zremrangebyscore(key, start, stop) +} + +func (cs clusterStore) Zrevrange(key string, start, stop int64) ([]string, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.Zrevrange(key, start, stop) +} + +func (cs clusterStore) ZrevrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.ZrevrangebyscoreWithScores(key, start, stop) +} + +func (cs clusterStore) ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ( + []redis.Pair, error) { + node, err := cs.getRedis(key) + if err != nil { + return nil, err + } + + return node.ZrevrangebyscoreWithScoresAndLimit(key, start, stop, page, size) +} + +func (cs clusterStore) Zscore(key string, value string) (int64, error) { + node, err := cs.getRedis(key) + if err != nil { + return 0, err + } + + return node.Zscore(key, value) +} + +func (cs clusterStore) getRedis(key string) (*redis.Redis, error) { + if val, ok := cs.dispatcher.Get(key); !ok { + return nil, ErrNoRedisNode + } else { + return val.(*redis.Redis), nil + } +} diff --git a/core/stores/kv/store_test.go b/core/stores/kv/store_test.go new file mode 100644 index 00000000..f16c6316 --- /dev/null +++ b/core/stores/kv/store_test.go @@ -0,0 +1,498 @@ +package kv + +import ( + "testing" + "time" + + "zero/core/stores/internal" + "zero/core/stores/redis" + "zero/core/stringx" + + "github.com/alicebob/miniredis" + "github.com/stretchr/testify/assert" +) + +var s1, _ = miniredis.Run() +var s2, _ = miniredis.Run() + +func TestRedis_Exists(t *testing.T) { + runOnCluster(t, func(client Store) { + ok, err := client.Exists("a") + assert.Nil(t, err) + assert.False(t, ok) + assert.Nil(t, client.Set("a", "b")) + ok, err = client.Exists("a") + assert.Nil(t, err) + assert.True(t, ok) + }) +} + +func TestRedis_Eval(t *testing.T) { + runOnCluster(t, func(client Store) { + _, err := client.Eval(`redis.call("EXISTS", KEYS[1])`, "notexist") + assert.Equal(t, redis.Nil, err) + err = client.Set("key1", "value1") + assert.Nil(t, err) + _, err = client.Eval(`redis.call("EXISTS", KEYS[1])`, "key1") + assert.Equal(t, redis.Nil, err) + val, err := client.Eval(`return redis.call("EXISTS", KEYS[1])`, "key1") + assert.Nil(t, err) + assert.Equal(t, int64(1), val) + }) +} + +func TestRedis_Hgetall(t *testing.T) { + runOnCluster(t, func(client Store) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + vals, err := client.Hgetall("a") + assert.Nil(t, err) + assert.EqualValues(t, map[string]string{ + "aa": "aaa", + "bb": "bbb", + }, vals) + }) +} + +func TestRedis_Hvals(t *testing.T) { + runOnCluster(t, func(client Store) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + vals, err := client.Hvals("a") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"aaa", "bbb"}, vals) + }) +} + +func TestRedis_Hsetnx(t *testing.T) { + runOnCluster(t, func(client Store) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + ok, err := 
client.Hsetnx("a", "bb", "ccc") + assert.Nil(t, err) + assert.False(t, ok) + ok, err = client.Hsetnx("a", "dd", "ddd") + assert.Nil(t, err) + assert.True(t, ok) + vals, err := client.Hvals("a") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"aaa", "bbb", "ddd"}, vals) + }) +} + +func TestRedis_HdelHlen(t *testing.T) { + runOnCluster(t, func(client Store) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + num, err := client.Hlen("a") + assert.Nil(t, err) + assert.Equal(t, 2, num) + val, err := client.Hdel("a", "aa") + assert.Nil(t, err) + assert.True(t, val) + vals, err := client.Hvals("a") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"bbb"}, vals) + }) +} + +func TestRedis_HIncrBy(t *testing.T) { + runOnCluster(t, func(client Store) { + val, err := client.Hincrby("key", "field", 2) + assert.Nil(t, err) + assert.Equal(t, 2, val) + val, err = client.Hincrby("key", "field", 3) + assert.Nil(t, err) + assert.Equal(t, 5, val) + }) +} + +func TestRedis_Hkeys(t *testing.T) { + runOnCluster(t, func(client Store) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + vals, err := client.Hkeys("a") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"aa", "bb"}, vals) + }) +} + +func TestRedis_Hmget(t *testing.T) { + runOnCluster(t, func(client Store) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + vals, err := client.Hmget("a", "aa", "bb") + assert.Nil(t, err) + assert.EqualValues(t, []string{"aaa", "bbb"}, vals) + vals, err = client.Hmget("a", "aa", "no", "bb") + assert.Nil(t, err) + assert.EqualValues(t, []string{"aaa", "", "bbb"}, vals) + }) +} + +func TestRedis_Hmset(t *testing.T) { + runOnCluster(t, func(client Store) { + assert.Nil(t, client.Hmset("a", map[string]string{ + "aa": "aaa", + "bb": "bbb", + })) + vals, err := client.Hmget("a", "aa", "bb") + assert.Nil(t, err) + assert.EqualValues(t, []string{"aaa", "bbb"}, vals) + }) +} + +func TestRedis_Incr(t *testing.T) { + runOnCluster(t, func(client Store) { + val, err := client.Incr("a") + assert.Nil(t, err) + assert.Equal(t, int64(1), val) + val, err = client.Incr("a") + assert.Nil(t, err) + assert.Equal(t, int64(2), val) + }) +} + +func TestRedis_IncrBy(t *testing.T) { + runOnCluster(t, func(client Store) { + val, err := client.Incrby("a", 2) + assert.Nil(t, err) + assert.Equal(t, int64(2), val) + val, err = client.Incrby("a", 3) + assert.Nil(t, err) + assert.Equal(t, int64(5), val) + }) +} + +func TestRedis_List(t *testing.T) { + runOnCluster(t, func(client Store) { + val, err := client.Lpush("key", "value1", "value2") + assert.Nil(t, err) + assert.Equal(t, 2, val) + val, err = client.Rpush("key", "value3", "value4") + assert.Nil(t, err) + assert.Equal(t, 4, val) + val, err = client.Llen("key") + assert.Nil(t, err) + assert.Equal(t, 4, val) + vals, err := client.Lrange("key", 0, 10) + assert.Nil(t, err) + assert.EqualValues(t, []string{"value2", "value1", "value3", "value4"}, vals) + v, err := client.Lpop("key") + assert.Nil(t, err) + assert.Equal(t, "value2", v) + val, err = client.Lpush("key", "value1", "value2") + assert.Nil(t, err) + assert.Equal(t, 5, val) + val, err = client.Rpush("key", "value3", "value3") + assert.Nil(t, err) + assert.Equal(t, 7, val) + n, err := client.Lrem("key", 2, "value1") + assert.Nil(t, err) + assert.Equal(t, 2, n) + vals, err = client.Lrange("key", 0, 10) + assert.Nil(t, err) + assert.EqualValues(t, []string{"value2", "value3", "value4", 
"value3", "value3"}, vals) + n, err = client.Lrem("key", -2, "value3") + assert.Nil(t, err) + assert.Equal(t, 2, n) + vals, err = client.Lrange("key", 0, 10) + assert.Nil(t, err) + assert.EqualValues(t, []string{"value2", "value3", "value4"}, vals) + }) +} + +func TestRedis_Persist(t *testing.T) { + runOnCluster(t, func(client Store) { + ok, err := client.Persist("key") + assert.Nil(t, err) + assert.False(t, ok) + err = client.Set("key", "value") + assert.Nil(t, err) + ok, err = client.Persist("key") + assert.Nil(t, err) + assert.False(t, ok) + err = client.Expire("key", 5) + ok, err = client.Persist("key") + assert.Nil(t, err) + assert.True(t, ok) + err = client.Expireat("key", time.Now().Unix()+5) + ok, err = client.Persist("key") + assert.Nil(t, err) + assert.True(t, ok) + }) +} + +func TestRedis_Sscan(t *testing.T) { + runOnCluster(t, func(client Store) { + key := "list" + var list []string + for i := 0; i < 1550; i++ { + list = append(list, stringx.Randn(i)) + } + lens, err := client.Sadd(key, list) + assert.Nil(t, err) + assert.Equal(t, lens, 1550) + + var cursor uint64 = 0 + sum := 0 + for { + keys, next, err := client.Sscan(key, cursor, "", 100) + assert.Nil(t, err) + sum += len(keys) + if next == 0 { + break + } + cursor = next + } + + assert.Equal(t, sum, 1550) + _, err = client.Del(key) + assert.Nil(t, err) + }) +} + +func TestRedis_Set(t *testing.T) { + runOnCluster(t, func(client Store) { + num, err := client.Sadd("key", 1, 2, 3, 4) + assert.Nil(t, err) + assert.Equal(t, 4, num) + val, err := client.Scard("key") + assert.Nil(t, err) + assert.Equal(t, int64(4), val) + ok, err := client.Sismember("key", 2) + assert.Nil(t, err) + assert.True(t, ok) + num, err = client.Srem("key", 3, 4) + assert.Nil(t, err) + assert.Equal(t, 2, num) + vals, err := client.Smembers("key") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"1", "2"}, vals) + members, err := client.Srandmember("key", 1) + assert.Nil(t, err) + assert.Len(t, members, 1) + assert.Contains(t, []string{"1", "2"}, members[0]) + member, err := client.Spop("key") + assert.Nil(t, err) + assert.Contains(t, []string{"1", "2"}, member) + vals, err = client.Smembers("key") + assert.Nil(t, err) + assert.NotContains(t, vals, member) + num, err = client.Sadd("key1", 1, 2, 3, 4) + assert.Nil(t, err) + assert.Equal(t, 4, num) + num, err = client.Sadd("key2", 2, 3, 4, 5) + assert.Nil(t, err) + assert.Equal(t, 4, num) + }) +} + +func TestRedis_SetGetDel(t *testing.T) { + runOnCluster(t, func(client Store) { + err := client.Set("hello", "world") + assert.Nil(t, err) + val, err := client.Get("hello") + assert.Nil(t, err) + assert.Equal(t, "world", val) + ret, err := client.Del("hello") + assert.Nil(t, err) + assert.Equal(t, 1, ret) + }) +} + +func TestRedis_SetExNx(t *testing.T) { + runOnCluster(t, func(client Store) { + err := client.Setex("hello", "world", 5) + assert.Nil(t, err) + ok, err := client.Setnx("hello", "newworld") + assert.Nil(t, err) + assert.False(t, ok) + ok, err = client.Setnx("newhello", "newworld") + assert.Nil(t, err) + assert.True(t, ok) + val, err := client.Get("hello") + assert.Nil(t, err) + assert.Equal(t, "world", val) + val, err = client.Get("newhello") + assert.Nil(t, err) + assert.Equal(t, "newworld", val) + ttl, err := client.Ttl("hello") + assert.Nil(t, err) + assert.True(t, ttl > 0) + ok, err = client.SetnxEx("newhello", "newworld", 5) + assert.Nil(t, err) + assert.False(t, ok) + num, err := client.Del("newhello") + assert.Nil(t, err) + assert.Equal(t, 1, num) + ok, err = client.SetnxEx("newhello", 
"newworld", 5) + assert.Nil(t, err) + assert.True(t, ok) + val, err = client.Get("newhello") + assert.Nil(t, err) + assert.Equal(t, "newworld", val) + }) +} + +func TestRedis_SetGetDelHashField(t *testing.T) { + runOnCluster(t, func(client Store) { + err := client.Hset("key", "field", "value") + assert.Nil(t, err) + val, err := client.Hget("key", "field") + assert.Nil(t, err) + assert.Equal(t, "value", val) + ok, err := client.Hexists("key", "field") + assert.Nil(t, err) + assert.True(t, ok) + ret, err := client.Hdel("key", "field") + assert.Nil(t, err) + assert.True(t, ret) + ok, err = client.Hexists("key", "field") + assert.Nil(t, err) + assert.False(t, ok) + }) +} + +func TestRedis_SortedSet(t *testing.T) { + runOnCluster(t, func(client Store) { + ok, err := client.Zadd("key", 1, "value1") + assert.Nil(t, err) + assert.True(t, ok) + ok, err = client.Zadd("key", 2, "value1") + assert.Nil(t, err) + assert.False(t, ok) + val, err := client.Zscore("key", "value1") + assert.Nil(t, err) + assert.Equal(t, int64(2), val) + val, err = client.Zincrby("key", 3, "value1") + assert.Nil(t, err) + assert.Equal(t, int64(5), val) + val, err = client.Zscore("key", "value1") + assert.Nil(t, err) + assert.Equal(t, int64(5), val) + ok, err = client.Zadd("key", 6, "value2") + assert.Nil(t, err) + assert.True(t, ok) + ok, err = client.Zadd("key", 7, "value3") + assert.Nil(t, err) + assert.True(t, ok) + rank, err := client.Zrank("key", "value2") + assert.Nil(t, err) + assert.Equal(t, int64(1), rank) + rank, err = client.Zrank("key", "value4") + assert.Equal(t, redis.Nil, err) + num, err := client.Zrem("key", "value2", "value3") + assert.Nil(t, err) + assert.Equal(t, 2, num) + ok, err = client.Zadd("key", 6, "value2") + assert.Nil(t, err) + assert.True(t, ok) + ok, err = client.Zadd("key", 7, "value3") + assert.Nil(t, err) + assert.True(t, ok) + ok, err = client.Zadd("key", 8, "value4") + assert.Nil(t, err) + assert.True(t, ok) + num, err = client.Zremrangebyscore("key", 6, 7) + assert.Nil(t, err) + assert.Equal(t, 2, num) + ok, err = client.Zadd("key", 6, "value2") + assert.Nil(t, err) + assert.True(t, ok) + ok, err = client.Zadd("key", 7, "value3") + assert.Nil(t, err) + assert.True(t, ok) + num, err = client.Zcount("key", 6, 7) + assert.Nil(t, err) + assert.Equal(t, 2, num) + num, err = client.Zremrangebyrank("key", 1, 2) + assert.Nil(t, err) + assert.Equal(t, 2, num) + card, err := client.Zcard("key") + assert.Nil(t, err) + assert.Equal(t, 2, card) + vals, err := client.Zrange("key", 0, -1) + assert.Nil(t, err) + assert.EqualValues(t, []string{"value1", "value4"}, vals) + vals, err = client.Zrevrange("key", 0, -1) + assert.Nil(t, err) + assert.EqualValues(t, []string{"value4", "value1"}, vals) + pairs, err := client.ZrangeWithScores("key", 0, -1) + assert.Nil(t, err) + assert.EqualValues(t, []redis.Pair{ + { + Key: "value1", + Score: 5, + }, + { + Key: "value4", + Score: 8, + }, + }, pairs) + pairs, err = client.ZrangebyscoreWithScores("key", 5, 8) + assert.Nil(t, err) + assert.EqualValues(t, []redis.Pair{ + { + Key: "value1", + Score: 5, + }, + { + Key: "value4", + Score: 8, + }, + }, pairs) + pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1) + assert.Nil(t, err) + assert.EqualValues(t, []redis.Pair{ + { + Key: "value4", + Score: 8, + }, + }, pairs) + pairs, err = client.ZrevrangebyscoreWithScores("key", 5, 8) + assert.Nil(t, err) + assert.EqualValues(t, []redis.Pair{ + { + Key: "value4", + Score: 8, + }, + { + Key: "value1", + Score: 5, + }, + }, pairs) + pairs, err = 
client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1) + assert.Nil(t, err) + assert.EqualValues(t, []redis.Pair{ + { + Key: "value1", + Score: 5, + }, + }, pairs) + }) +} + +func runOnCluster(t *testing.T, fn func(cluster Store)) { + s1.FlushAll() + s2.FlushAll() + + store := NewStore([]internal.NodeConf{ + { + RedisConf: redis.RedisConf{ + Host: s1.Addr(), + Type: redis.NodeType, + }, + Weight: 100, + }, + { + RedisConf: redis.RedisConf{ + Host: s2.Addr(), + Type: redis.NodeType, + }, + Weight: 100, + }, + }) + + fn(store) +} diff --git a/core/stores/mongo/bulkinserter.go b/core/stores/mongo/bulkinserter.go new file mode 100644 index 00000000..eb4c3a03 --- /dev/null +++ b/core/stores/mongo/bulkinserter.go @@ -0,0 +1,87 @@ +package mongo + +import ( + "time" + + "zero/core/executors" + "zero/core/logx" + + "github.com/globalsign/mgo" +) + +const ( + flushInterval = time.Second + maxBulkRows = 1000 +) + +type ( + ResultHandler func(*mgo.BulkResult, error) + + BulkInserter struct { + executor *executors.PeriodicalExecutor + inserter *dbInserter + } +) + +func NewBulkInserter(session *mgo.Session, dbName string, collectionNamer func() string) *BulkInserter { + inserter := &dbInserter{ + session: session, + dbName: dbName, + collectionNamer: collectionNamer, + } + + return &BulkInserter{ + executor: executors.NewPeriodicalExecutor(flushInterval, inserter), + inserter: inserter, + } +} + +func (bi *BulkInserter) Flush() { + bi.executor.Flush() +} + +func (bi *BulkInserter) Insert(doc interface{}) { + bi.executor.Add(doc) +} + +func (bi *BulkInserter) SetResultHandler(handler ResultHandler) { + bi.executor.Sync(func() { + bi.inserter.resultHandler = handler + }) +} + +type dbInserter struct { + session *mgo.Session + dbName string + collectionNamer func() string + documents []interface{} + resultHandler ResultHandler +} + +func (in *dbInserter) AddTask(doc interface{}) bool { + in.documents = append(in.documents, doc) + return len(in.documents) >= maxBulkRows +} + +func (in *dbInserter) Execute(objs interface{}) { + docs := objs.([]interface{}) + if len(docs) == 0 { + return + } + + bulk := in.session.DB(in.dbName).C(in.collectionNamer()).Bulk() + bulk.Insert(docs...) 
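+ // unordered mode lets the remaining documents proceed even if some inserts fail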
+ bulk.Unordered() + result, err := bulk.Run() + if in.resultHandler != nil { + in.resultHandler(result, err) + } else if err != nil { + logx.Error(err) + } +} + +func (in *dbInserter) RemoveAll() interface{} { + documents := in.documents + in.documents = nil + return documents +} diff --git a/core/stores/mongo/collection.go b/core/stores/mongo/collection.go new file mode 100644 index 00000000..ba33e52a --- /dev/null +++ b/core/stores/mongo/collection.go @@ -0,0 +1,238 @@ +package mongo + +import ( + "encoding/json" + "time" + + "zero/core/breaker" + "zero/core/logx" + "zero/core/timex" + + "github.com/globalsign/mgo" +) + +const slowThreshold = time.Millisecond * 500 + +var ErrNotFound = mgo.ErrNotFound + +type ( + Collection interface { + Find(query interface{}) Query + FindId(id interface{}) Query + Insert(docs ...interface{}) error + Pipe(pipeline interface{}) Pipe + Remove(selector interface{}) error + RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) + RemoveId(id interface{}) error + Update(selector, update interface{}) error + UpdateId(id, update interface{}) error + Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) + } + + decoratedCollection struct { + *mgo.Collection + brk breaker.Breaker + } + + keepablePromise struct { + promise breaker.Promise + log func(error) + } +) + +func newCollection(collection *mgo.Collection) Collection { + return &decoratedCollection{ + Collection: collection, + brk: breaker.NewBreaker(), + } +} + +func (c *decoratedCollection) Find(query interface{}) Query { + promise, err := c.brk.Allow() + if err != nil { + return rejectedQuery{} + } + + startTime := timex.Now() + return promisedQuery{ + Query: c.Collection.Find(query), + promise: keepablePromise{ + promise: promise, + log: func(err error) { + duration := timex.Since(startTime) + c.logDuration("find", duration, err, query) + }, + }, + } +} + +func (c *decoratedCollection) FindId(id interface{}) Query { + promise, err := c.brk.Allow() + if err != nil { + return rejectedQuery{} + } + + startTime := timex.Now() + return promisedQuery{ + Query: c.Collection.FindId(id), + promise: keepablePromise{ + promise: promise, + log: func(err error) { + duration := timex.Since(startTime) + c.logDuration("findId", duration, err, id) + }, + }, + } +} + +func (c *decoratedCollection) Insert(docs ...interface{}) (err error) { + return c.brk.DoWithAcceptable(func() error { + startTime := timex.Now() + defer func() { + duration := timex.Since(startTime) + c.logDuration("insert", duration, err, docs...) + }() + + return c.Collection.Insert(docs...) 
+ }, acceptable) +} + +func (c *decoratedCollection) Pipe(pipeline interface{}) Pipe { + promise, err := c.brk.Allow() + if err != nil { + return rejectedPipe{} + } + + startTime := timex.Now() + return promisedPipe{ + Pipe: c.Collection.Pipe(pipeline), + promise: keepablePromise{ + promise: promise, + log: func(err error) { + duration := timex.Since(startTime) + c.logDuration("pipe", duration, err, pipeline) + }, + }, + } +} + +func (c *decoratedCollection) Remove(selector interface{}) (err error) { + return c.brk.DoWithAcceptable(func() error { + startTime := timex.Now() + defer func() { + duration := timex.Since(startTime) + c.logDuration("remove", duration, err, selector) + }() + + return c.Collection.Remove(selector) + }, acceptable) +} + +func (c *decoratedCollection) RemoveAll(selector interface{}) (info *mgo.ChangeInfo, err error) { + err = c.brk.DoWithAcceptable(func() error { + startTime := timex.Now() + defer func() { + duration := timex.Since(startTime) + c.logDuration("removeAll", duration, err, selector) + }() + + info, err = c.Collection.RemoveAll(selector) + return err + }, acceptable) + + return +} + +func (c *decoratedCollection) RemoveId(id interface{}) (err error) { + return c.brk.DoWithAcceptable(func() error { + startTime := timex.Now() + defer func() { + duration := timex.Since(startTime) + c.logDuration("removeId", duration, err, id) + }() + + return c.Collection.RemoveId(id) + }, acceptable) +} + +func (c *decoratedCollection) Update(selector, update interface{}) (err error) { + return c.brk.DoWithAcceptable(func() error { + startTime := timex.Now() + defer func() { + duration := timex.Since(startTime) + c.logDuration("update", duration, err, selector, update) + }() + + return c.Collection.Update(selector, update) + }, acceptable) +} + +func (c *decoratedCollection) UpdateId(id, update interface{}) (err error) { + return c.brk.DoWithAcceptable(func() error { + startTime := timex.Now() + defer func() { + duration := timex.Since(startTime) + c.logDuration("updateId", duration, err, id, update) + }() + + return c.Collection.UpdateId(id, update) + }, acceptable) +} + +func (c *decoratedCollection) Upsert(selector, update interface{}) (info *mgo.ChangeInfo, err error) { + err = c.brk.DoWithAcceptable(func() error { + startTime := timex.Now() + defer func() { + duration := timex.Since(startTime) + c.logDuration("upsert", duration, err, selector, update) + }() + + info, err = c.Collection.Upsert(selector, update) + return err + }, acceptable) + + return +} + +func (c *decoratedCollection) logDuration(method string, duration time.Duration, err error, docs ...interface{}) { + content, e := json.Marshal(docs) + if e != nil { + logx.Error(err) + } else if err != nil { + if duration > slowThreshold { + logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - fail(%s) - %s", + c.FullName, method, err.Error(), string(content)) + } else { + logx.WithDuration(duration).Infof("mongo(%s) - %s - fail(%s) - %s", + c.FullName, method, err.Error(), string(content)) + } + } else { + if duration > slowThreshold { + logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - ok - %s", + c.FullName, method, string(content)) + } else { + logx.WithDuration(duration).Infof("mongo(%s) - %s - ok - %s", c.FullName, method, string(content)) + } + } +} + +func (p keepablePromise) accept(err error) error { + p.promise.Accept() + p.log(err) + return err +} + +func (p keepablePromise) keep(err error) error { + if acceptable(err) { + p.promise.Accept() + } else { + 
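+		// Non-acceptable errors reject the promise so the breaker records a failure;
+		// nil and mgo.ErrNotFound were already accepted above.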
p.promise.Reject(err.Error()) + } + + p.log(err) + return err +} + +func acceptable(err error) bool { + return err == nil || err == mgo.ErrNotFound +} diff --git a/core/stores/mongo/collection_test.go b/core/stores/mongo/collection_test.go new file mode 100644 index 00000000..65ccaa08 --- /dev/null +++ b/core/stores/mongo/collection_test.go @@ -0,0 +1,71 @@ +package mongo + +import ( + "errors" + "testing" + + "github.com/globalsign/mgo" + "github.com/stretchr/testify/assert" + + "zero/core/stringx" +) + +func TestKeepPromise_accept(t *testing.T) { + p := new(mockPromise) + kp := keepablePromise{ + promise: p, + log: func(error) {}, + } + assert.Nil(t, kp.accept(nil)) + assert.Equal(t, mgo.ErrNotFound, kp.accept(mgo.ErrNotFound)) +} + +func TestKeepPromise_keep(t *testing.T) { + tests := []struct { + err error + accepted bool + reason string + }{ + { + err: nil, + accepted: true, + reason: "", + }, + { + err: mgo.ErrNotFound, + accepted: true, + reason: "", + }, + { + err: errors.New("any"), + accepted: false, + reason: "any", + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + p := new(mockPromise) + kp := keepablePromise{ + promise: p, + log: func(error) {}, + } + assert.Equal(t, test.err, kp.keep(test.err)) + assert.Equal(t, test.accepted, p.accepted) + assert.Equal(t, test.reason, p.reason) + }) + } +} + +type mockPromise struct { + accepted bool + reason string +} + +func (p *mockPromise) Accept() { + p.accepted = true +} + +func (p *mockPromise) Reject(reason string) { + p.reason = reason +} diff --git a/core/stores/mongo/iter.go b/core/stores/mongo/iter.go new file mode 100644 index 00000000..31c38ab8 --- /dev/null +++ b/core/stores/mongo/iter.go @@ -0,0 +1,96 @@ +//go:generate mockgen -package mongo -destination iter_mock.go -source iter.go Iter +package mongo + +import ( + "zero/core/breaker" + + "github.com/globalsign/mgo/bson" +) + +type ( + Iter interface { + All(result interface{}) error + Close() error + Done() bool + Err() error + For(result interface{}, f func() error) error + Next(result interface{}) bool + State() (int64, []bson.Raw) + Timeout() bool + } + + ClosableIter struct { + Iter + Cleanup func() + } + + promisedIter struct { + Iter + promise keepablePromise + } + + rejectedIter struct{} +) + +func (i promisedIter) All(result interface{}) error { + return i.promise.keep(i.Iter.All(result)) +} + +func (i promisedIter) Close() error { + return i.promise.keep(i.Iter.Close()) +} + +func (i promisedIter) Err() error { + return i.Iter.Err() +} + +func (i promisedIter) For(result interface{}, f func() error) error { + var ferr error + err := i.Iter.For(result, func() error { + ferr = f() + return ferr + }) + if ferr == err { + return i.promise.accept(err) + } + + return i.promise.keep(err) +} + +func (it *ClosableIter) Close() error { + err := it.Iter.Close() + it.Cleanup() + return err +} + +func (i rejectedIter) All(result interface{}) error { + return breaker.ErrServiceUnavailable +} + +func (i rejectedIter) Close() error { + return breaker.ErrServiceUnavailable +} + +func (i rejectedIter) Done() bool { + return false +} + +func (i rejectedIter) Err() error { + return breaker.ErrServiceUnavailable +} + +func (i rejectedIter) For(result interface{}, f func() error) error { + return breaker.ErrServiceUnavailable +} + +func (i rejectedIter) Next(result interface{}) bool { + return false +} + +func (i rejectedIter) State() (int64, []bson.Raw) { + return 0, nil +} + +func (i rejectedIter) Timeout() bool { + return false +} diff --git 
a/core/stores/mongo/iter_mock.go b/core/stores/mongo/iter_mock.go new file mode 100644 index 00000000..dbc851ce --- /dev/null +++ b/core/stores/mongo/iter_mock.go @@ -0,0 +1,147 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: iter.go + +// Package mongo is a generated GoMock package. +package mongo + +import ( + bson "github.com/globalsign/mgo/bson" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockIter is a mock of Iter interface +type MockIter struct { + ctrl *gomock.Controller + recorder *MockIterMockRecorder +} + +// MockIterMockRecorder is the mock recorder for MockIter +type MockIterMockRecorder struct { + mock *MockIter +} + +// NewMockIter creates a new mock instance +func NewMockIter(ctrl *gomock.Controller) *MockIter { + mock := &MockIter{ctrl: ctrl} + mock.recorder = &MockIterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockIter) EXPECT() *MockIterMockRecorder { + return m.recorder +} + +// All mocks base method +func (m *MockIter) All(result interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "All", result) + ret0, _ := ret[0].(error) + return ret0 +} + +// All indicates an expected call of All +func (mr *MockIterMockRecorder) All(result interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "All", reflect.TypeOf((*MockIter)(nil).All), result) +} + +// Close mocks base method +func (m *MockIter) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockIterMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockIter)(nil).Close)) +} + +// Done mocks base method +func (m *MockIter) Done() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Done") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Done indicates an expected call of Done +func (mr *MockIterMockRecorder) Done() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Done", reflect.TypeOf((*MockIter)(nil).Done)) +} + +// Err mocks base method +func (m *MockIter) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err +func (mr *MockIterMockRecorder) Err() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockIter)(nil).Err)) +} + +// For mocks base method +func (m *MockIter) For(result interface{}, f func() error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "For", result, f) + ret0, _ := ret[0].(error) + return ret0 +} + +// For indicates an expected call of For +func (mr *MockIterMockRecorder) For(result, f interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "For", reflect.TypeOf((*MockIter)(nil).For), result, f) +} + +// Next mocks base method +func (m *MockIter) Next(result interface{}) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next", result) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next +func (mr *MockIterMockRecorder) Next(result interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockIter)(nil).Next), result) +} + +// 
State mocks base method +func (m *MockIter) State() (int64, []bson.Raw) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "State") + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].([]bson.Raw) + return ret0, ret1 +} + +// State indicates an expected call of State +func (mr *MockIterMockRecorder) State() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockIter)(nil).State)) +} + +// Timeout mocks base method +func (m *MockIter) Timeout() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Timeout") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Timeout indicates an expected call of Timeout +func (mr *MockIterMockRecorder) Timeout() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Timeout", reflect.TypeOf((*MockIter)(nil).Timeout)) +} diff --git a/core/stores/mongo/iter_test.go b/core/stores/mongo/iter_test.go new file mode 100644 index 00000000..70deb867 --- /dev/null +++ b/core/stores/mongo/iter_test.go @@ -0,0 +1,265 @@ +package mongo + +import ( + "errors" + "testing" + + "zero/core/breaker" + "zero/core/stringx" + "zero/core/syncx" + + "github.com/globalsign/mgo" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestClosableIter_Close(t *testing.T) { + errs := []error{ + nil, + mgo.ErrNotFound, + } + + for _, err := range errs { + t.Run(stringx.RandId(), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cleaned := syncx.NewAtomicBool() + iter := NewMockIter(ctrl) + iter.EXPECT().Close().Return(err) + ci := ClosableIter{ + Iter: iter, + Cleanup: func() { + cleaned.Set(true) + }, + } + assert.Equal(t, err, ci.Close()) + assert.True(t, cleaned.True()) + }) + } +} + +func TestPromisedIter_AllAndClose(t *testing.T) { + tests := []struct { + err error + accepted bool + reason string + }{ + { + err: nil, + accepted: true, + reason: "", + }, + { + err: mgo.ErrNotFound, + accepted: true, + reason: "", + }, + { + err: errors.New("any"), + accepted: false, + reason: "any", + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + iter := NewMockIter(ctrl) + iter.EXPECT().All(gomock.Any()).Return(test.err) + promise := new(mockPromise) + pi := promisedIter{ + Iter: iter, + promise: keepablePromise{ + promise: promise, + log: func(error) {}, + }, + } + assert.Equal(t, test.err, pi.All(nil)) + assert.Equal(t, test.accepted, promise.accepted) + assert.Equal(t, test.reason, promise.reason) + }) + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + iter := NewMockIter(ctrl) + iter.EXPECT().Close().Return(test.err) + promise := new(mockPromise) + pi := promisedIter{ + Iter: iter, + promise: keepablePromise{ + promise: promise, + log: func(error) {}, + }, + } + assert.Equal(t, test.err, pi.Close()) + assert.Equal(t, test.accepted, promise.accepted) + assert.Equal(t, test.reason, promise.reason) + }) + } +} + +func TestPromisedIter_Err(t *testing.T) { + errs := []error{ + nil, + mgo.ErrNotFound, + } + + for _, err := range errs { + t.Run(stringx.RandId(), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + iter := NewMockIter(ctrl) + iter.EXPECT().Err().Return(err) + promise := new(mockPromise) + pi := promisedIter{ + Iter: iter, + promise: keepablePromise{ + promise: promise, + log: func(error) {}, + }, + } + 
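+			// Err is a plain passthrough: unlike All/Close/For it neither accepts nor
+			// rejects the promise, so only the returned error is asserted here.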
assert.Equal(t, err, pi.Err()) + }) + } +} + +func TestPromisedIter_For(t *testing.T) { + tests := []struct { + err error + accepted bool + reason string + }{ + { + err: nil, + accepted: true, + reason: "", + }, + { + err: mgo.ErrNotFound, + accepted: true, + reason: "", + }, + { + err: errors.New("any"), + accepted: false, + reason: "any", + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + iter := NewMockIter(ctrl) + iter.EXPECT().For(gomock.Any(), gomock.Any()).Return(test.err) + promise := new(mockPromise) + pi := promisedIter{ + Iter: iter, + promise: keepablePromise{ + promise: promise, + log: func(error) {}, + }, + } + assert.Equal(t, test.err, pi.For(nil, nil)) + assert.Equal(t, test.accepted, promise.accepted) + assert.Equal(t, test.reason, promise.reason) + }) + } +} + +func TestRejectedIter_All(t *testing.T) { + assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).All(nil)) +} + +func TestRejectedIter_Close(t *testing.T) { + assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Close()) +} + +func TestRejectedIter_Done(t *testing.T) { + assert.False(t, new(rejectedIter).Done()) +} + +func TestRejectedIter_Err(t *testing.T) { + assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Err()) +} + +func TestRejectedIter_For(t *testing.T) { + assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).For(nil, nil)) +} + +func TestRejectedIter_Next(t *testing.T) { + assert.False(t, new(rejectedIter).Next(nil)) +} + +func TestRejectedIter_State(t *testing.T) { + n, raw := new(rejectedIter).State() + assert.Equal(t, int64(0), n) + assert.Nil(t, raw) +} + +func TestRejectedIter_Timeout(t *testing.T) { + assert.False(t, new(rejectedIter).Timeout()) +} + +func TestIter_Done(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + iter := NewMockIter(ctrl) + iter.EXPECT().Done().Return(true) + ci := ClosableIter{ + Iter: iter, + Cleanup: nil, + } + assert.True(t, ci.Done()) +} + +func TestIter_Next(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + iter := NewMockIter(ctrl) + iter.EXPECT().Next(gomock.Any()).Return(true) + ci := ClosableIter{ + Iter: iter, + Cleanup: nil, + } + assert.True(t, ci.Next(nil)) +} + +func TestIter_State(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + iter := NewMockIter(ctrl) + iter.EXPECT().State().Return(int64(1), nil) + ci := ClosableIter{ + Iter: iter, + Cleanup: nil, + } + n, raw := ci.State() + assert.Equal(t, int64(1), n) + assert.Nil(t, raw) +} + +func TestIter_Timeout(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + iter := NewMockIter(ctrl) + iter.EXPECT().Timeout().Return(true) + ci := ClosableIter{ + Iter: iter, + Cleanup: nil, + } + assert.True(t, ci.Timeout()) +} diff --git a/core/stores/mongo/model.go b/core/stores/mongo/model.go new file mode 100644 index 00000000..9088896d --- /dev/null +++ b/core/stores/mongo/model.go @@ -0,0 +1,164 @@ +package mongo + +import ( + "log" + "time" + + "github.com/globalsign/mgo" +) + +type ( + options struct { + timeout time.Duration + } + + Option func(opts *options) + + Model struct { + session *concurrentSession + db *mgo.Database + collection string + opts []Option + } +) + +func MustNewModel(url, database, collection string, opts ...Option) *Model { + model, err := NewModel(url, database, collection, opts...) 
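+	// MustNewModel exits via log.Fatal when the session cannot be created; use
+	// NewModel below to handle the error. Illustrative sketch (URL and names are
+	// examples only, not part of this package):
+	//   m := mongo.MustNewModel("mongodb://localhost:27017", "mydb", "users", mongo.WithTimeout(time.Second))
+	//   err := m.Insert(bson.M{"name": "foo"})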
+ if err != nil { + log.Fatal(err) + } + + return model +} + +func NewModel(url, database, collection string, opts ...Option) (*Model, error) { + session, err := getConcurrentSession(url) + if err != nil { + return nil, err + } + + return &Model{ + session: session, + db: session.DB(database), + collection: collection, + opts: opts, + }, nil +} + +func (mm *Model) Find(query interface{}) (Query, error) { + return mm.query(func(c Collection) Query { + return c.Find(query) + }) +} + +func (mm *Model) FindId(id interface{}) (Query, error) { + return mm.query(func(c Collection) Query { + return c.FindId(id) + }) +} + +func (mm *Model) GetCollection(session *mgo.Session) Collection { + return newCollection(mm.db.C(mm.collection).With(session)) +} + +func (mm *Model) Insert(docs ...interface{}) error { + return mm.execute(func(c Collection) error { + return c.Insert(docs...) + }) +} + +func (mm *Model) Pipe(pipeline interface{}) (Pipe, error) { + return mm.pipe(func(c Collection) Pipe { + return c.Pipe(pipeline) + }) +} + +func (mm *Model) PutSession(session *mgo.Session) { + mm.session.putSession(session) +} + +func (mm *Model) Remove(selector interface{}) error { + return mm.execute(func(c Collection) error { + return c.Remove(selector) + }) +} + +func (mm *Model) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) { + return mm.change(func(c Collection) (*mgo.ChangeInfo, error) { + return c.RemoveAll(selector) + }) +} + +func (mm *Model) RemoveId(id interface{}) error { + return mm.execute(func(c Collection) error { + return c.RemoveId(id) + }) +} + +func (mm *Model) TakeSession() (*mgo.Session, error) { + return mm.session.takeSession(mm.opts...) +} + +func (mm *Model) Update(selector, update interface{}) error { + return mm.execute(func(c Collection) error { + return c.Update(selector, update) + }) +} + +func (mm *Model) UpdateId(id, update interface{}) error { + return mm.execute(func(c Collection) error { + return c.UpdateId(id, update) + }) +} + +func (mm *Model) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) { + return mm.change(func(c Collection) (*mgo.ChangeInfo, error) { + return c.Upsert(selector, update) + }) +} + +func (mm *Model) change(fn func(c Collection) (*mgo.ChangeInfo, error)) (*mgo.ChangeInfo, error) { + session, err := mm.TakeSession() + if err != nil { + return nil, err + } + defer mm.PutSession(session) + + return fn(mm.GetCollection(session)) +} + +func (mm *Model) execute(fn func(c Collection) error) error { + session, err := mm.TakeSession() + if err != nil { + return err + } + defer mm.PutSession(session) + + return fn(mm.GetCollection(session)) +} + +func (mm *Model) pipe(fn func(c Collection) Pipe) (Pipe, error) { + session, err := mm.TakeSession() + if err != nil { + return nil, err + } + defer mm.PutSession(session) + + return fn(mm.GetCollection(session)), nil +} + +func (mm *Model) query(fn func(c Collection) Query) (Query, error) { + session, err := mm.TakeSession() + if err != nil { + return nil, err + } + defer mm.PutSession(session) + + return fn(mm.GetCollection(session)), nil +} + +func WithTimeout(timeout time.Duration) Option { + return func(opts *options) { + opts.timeout = timeout + } +} diff --git a/core/stores/mongo/pipe.go b/core/stores/mongo/pipe.go new file mode 100644 index 00000000..c54b16ab --- /dev/null +++ b/core/stores/mongo/pipe.go @@ -0,0 +1,100 @@ +package mongo + +import ( + "time" + + "zero/core/breaker" + + "github.com/globalsign/mgo" +) + +type ( + Pipe interface { + All(result interface{}) error + 
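+		// The remaining methods mirror *mgo.Pipe; the chainable ones on promisedPipe
+		// return the wrapper itself so breaker bookkeeping is preserved across the chain.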
AllowDiskUse() Pipe + Batch(n int) Pipe + Collation(collation *mgo.Collation) Pipe + Explain(result interface{}) error + Iter() Iter + One(result interface{}) error + SetMaxTime(d time.Duration) Pipe + } + + promisedPipe struct { + *mgo.Pipe + promise keepablePromise + } + + rejectedPipe struct{} +) + +func (p promisedPipe) All(result interface{}) error { + return p.promise.keep(p.Pipe.All(result)) +} + +func (p promisedPipe) AllowDiskUse() Pipe { + p.Pipe.AllowDiskUse() + return p +} + +func (p promisedPipe) Batch(n int) Pipe { + p.Pipe.Batch(n) + return p +} + +func (p promisedPipe) Collation(collation *mgo.Collation) Pipe { + p.Pipe.Collation(collation) + return p +} + +func (p promisedPipe) Explain(result interface{}) error { + return p.promise.keep(p.Pipe.Explain(result)) +} + +func (p promisedPipe) Iter() Iter { + return promisedIter{ + Iter: p.Pipe.Iter(), + promise: p.promise, + } +} + +func (p promisedPipe) One(result interface{}) error { + return p.promise.keep(p.Pipe.One(result)) +} + +func (p promisedPipe) SetMaxTime(d time.Duration) Pipe { + p.Pipe.SetMaxTime(d) + return p +} + +func (p rejectedPipe) All(result interface{}) error { + return breaker.ErrServiceUnavailable +} + +func (p rejectedPipe) AllowDiskUse() Pipe { + return p +} + +func (p rejectedPipe) Batch(n int) Pipe { + return p +} + +func (p rejectedPipe) Collation(collation *mgo.Collation) Pipe { + return p +} + +func (p rejectedPipe) Explain(result interface{}) error { + return breaker.ErrServiceUnavailable +} + +func (p rejectedPipe) Iter() Iter { + return rejectedIter{} +} + +func (p rejectedPipe) One(result interface{}) error { + return breaker.ErrServiceUnavailable +} + +func (p rejectedPipe) SetMaxTime(d time.Duration) Pipe { + return p +} diff --git a/core/stores/mongo/pipe_test.go b/core/stores/mongo/pipe_test.go new file mode 100644 index 00000000..640ab8c7 --- /dev/null +++ b/core/stores/mongo/pipe_test.go @@ -0,0 +1,45 @@ +package mongo + +import ( + "testing" + + "zero/core/breaker" + + "github.com/stretchr/testify/assert" +) + +func TestRejectedPipe_All(t *testing.T) { + assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).All(nil)) +} + +func TestRejectedPipe_AllowDiskUse(t *testing.T) { + var p rejectedPipe + assert.Equal(t, p, p.AllowDiskUse()) +} + +func TestRejectedPipe_Batch(t *testing.T) { + var p rejectedPipe + assert.Equal(t, p, p.Batch(1)) +} + +func TestRejectedPipe_Collation(t *testing.T) { + var p rejectedPipe + assert.Equal(t, p, p.Collation(nil)) +} + +func TestRejectedPipe_Explain(t *testing.T) { + assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).Explain(nil)) +} + +func TestRejectedPipe_Iter(t *testing.T) { + assert.EqualValues(t, rejectedIter{}, new(rejectedPipe).Iter()) +} + +func TestRejectedPipe_One(t *testing.T) { + assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).One(nil)) +} + +func TestRejectedPipe_SetMaxTime(t *testing.T) { + var p rejectedPipe + assert.Equal(t, p, p.SetMaxTime(0)) +} diff --git a/core/stores/mongo/query.go b/core/stores/mongo/query.go new file mode 100644 index 00000000..d3e23457 --- /dev/null +++ b/core/stores/mongo/query.go @@ -0,0 +1,285 @@ +package mongo + +import ( + "time" + + "zero/core/breaker" + + "github.com/globalsign/mgo" +) + +type ( + Query interface { + All(result interface{}) error + Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) + Batch(n int) Query + Collation(collation *mgo.Collation) Query + Comment(comment string) Query + Count() (int, error) + Distinct(key string, result 
interface{}) error + Explain(result interface{}) error + For(result interface{}, f func() error) error + Hint(indexKey ...string) Query + Iter() Iter + Limit(n int) Query + LogReplay() Query + MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) + One(result interface{}) error + Prefetch(p float64) Query + Select(selector interface{}) Query + SetMaxScan(n int) Query + SetMaxTime(d time.Duration) Query + Skip(n int) Query + Snapshot() Query + Sort(fields ...string) Query + Tail(timeout time.Duration) Iter + } + + promisedQuery struct { + *mgo.Query + promise keepablePromise + } + + rejectedQuery struct{} +) + +func (q promisedQuery) All(result interface{}) error { + return q.promise.keep(q.Query.All(result)) +} + +func (q promisedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) { + info, err := q.Query.Apply(change, result) + return info, q.promise.keep(err) +} + +func (q promisedQuery) Batch(n int) Query { + return promisedQuery{ + Query: q.Query.Batch(n), + promise: q.promise, + } +} + +func (q promisedQuery) Collation(collation *mgo.Collation) Query { + return promisedQuery{ + Query: q.Query.Collation(collation), + promise: q.promise, + } +} + +func (q promisedQuery) Comment(comment string) Query { + return promisedQuery{ + Query: q.Query.Comment(comment), + promise: q.promise, + } +} + +func (q promisedQuery) Count() (int, error) { + v, err := q.Query.Count() + return v, q.promise.keep(err) +} + +func (q promisedQuery) Distinct(key string, result interface{}) error { + return q.promise.keep(q.Query.Distinct(key, result)) +} + +func (q promisedQuery) Explain(result interface{}) error { + return q.promise.keep(q.Query.Explain(result)) +} + +func (q promisedQuery) For(result interface{}, f func() error) error { + var ferr error + err := q.Query.For(result, func() error { + ferr = f() + return ferr + }) + if ferr == err { + return q.promise.accept(err) + } + + return q.promise.keep(err) +} + +func (q promisedQuery) Hint(indexKey ...string) Query { + return promisedQuery{ + Query: q.Query.Hint(indexKey...), + promise: q.promise, + } +} + +func (q promisedQuery) Iter() Iter { + return promisedIter{ + Iter: q.Query.Iter(), + promise: q.promise, + } +} + +func (q promisedQuery) Limit(n int) Query { + return promisedQuery{ + Query: q.Query.Limit(n), + promise: q.promise, + } +} + +func (q promisedQuery) LogReplay() Query { + return promisedQuery{ + Query: q.Query.LogReplay(), + promise: q.promise, + } +} + +func (q promisedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) { + info, err := q.Query.MapReduce(job, result) + return info, q.promise.keep(err) +} + +func (q promisedQuery) One(result interface{}) error { + return q.promise.keep(q.Query.One(result)) +} + +func (q promisedQuery) Prefetch(p float64) Query { + return promisedQuery{ + Query: q.Query.Prefetch(p), + promise: q.promise, + } +} + +func (q promisedQuery) Select(selector interface{}) Query { + return promisedQuery{ + Query: q.Query.Select(selector), + promise: q.promise, + } +} + +func (q promisedQuery) SetMaxScan(n int) Query { + return promisedQuery{ + Query: q.Query.SetMaxScan(n), + promise: q.promise, + } +} + +func (q promisedQuery) SetMaxTime(d time.Duration) Query { + return promisedQuery{ + Query: q.Query.SetMaxTime(d), + promise: q.promise, + } +} + +func (q promisedQuery) Skip(n int) Query { + return promisedQuery{ + Query: q.Query.Skip(n), + promise: q.promise, + } +} + +func (q promisedQuery) Snapshot() Query { + return 
promisedQuery{ + Query: q.Query.Snapshot(), + promise: q.promise, + } +} + +func (q promisedQuery) Sort(fields ...string) Query { + return promisedQuery{ + Query: q.Query.Sort(fields...), + promise: q.promise, + } +} + +func (q promisedQuery) Tail(timeout time.Duration) Iter { + return promisedIter{ + Iter: q.Query.Tail(timeout), + promise: q.promise, + } +} + +func (q rejectedQuery) All(result interface{}) error { + return breaker.ErrServiceUnavailable +} + +func (q rejectedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) { + return nil, breaker.ErrServiceUnavailable +} + +func (q rejectedQuery) Batch(n int) Query { + return q +} + +func (q rejectedQuery) Collation(collation *mgo.Collation) Query { + return q +} + +func (q rejectedQuery) Comment(comment string) Query { + return q +} + +func (q rejectedQuery) Count() (int, error) { + return 0, breaker.ErrServiceUnavailable +} + +func (q rejectedQuery) Distinct(key string, result interface{}) error { + return breaker.ErrServiceUnavailable +} + +func (q rejectedQuery) Explain(result interface{}) error { + return breaker.ErrServiceUnavailable +} + +func (q rejectedQuery) For(result interface{}, f func() error) error { + return breaker.ErrServiceUnavailable +} + +func (q rejectedQuery) Hint(indexKey ...string) Query { + return q +} + +func (q rejectedQuery) Iter() Iter { + return rejectedIter{} +} + +func (q rejectedQuery) Limit(n int) Query { + return q +} + +func (q rejectedQuery) LogReplay() Query { + return q +} + +func (q rejectedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) { + return nil, breaker.ErrServiceUnavailable +} + +func (q rejectedQuery) One(result interface{}) error { + return breaker.ErrServiceUnavailable +} + +func (q rejectedQuery) Prefetch(p float64) Query { + return q +} + +func (q rejectedQuery) Select(selector interface{}) Query { + return q +} + +func (q rejectedQuery) SetMaxScan(n int) Query { + return q +} + +func (q rejectedQuery) SetMaxTime(d time.Duration) Query { + return q +} + +func (q rejectedQuery) Skip(n int) Query { + return q +} + +func (q rejectedQuery) Snapshot() Query { + return q +} + +func (q rejectedQuery) Sort(fields ...string) Query { + return q +} + +func (q rejectedQuery) Tail(timeout time.Duration) Iter { + return rejectedIter{} +} diff --git a/core/stores/mongo/query_test.go b/core/stores/mongo/query_test.go new file mode 100644 index 00000000..9b05b258 --- /dev/null +++ b/core/stores/mongo/query_test.go @@ -0,0 +1,121 @@ +package mongo + +import ( + "testing" + + "zero/core/breaker" + + "github.com/globalsign/mgo" + "github.com/stretchr/testify/assert" +) + +func Test_rejectedQuery_All(t *testing.T) { + assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).All(nil)) +} + +func Test_rejectedQuery_Apply(t *testing.T) { + info, err := new(rejectedQuery).Apply(mgo.Change{}, nil) + assert.Equal(t, breaker.ErrServiceUnavailable, err) + assert.Nil(t, info) +} + +func Test_rejectedQuery_Batch(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.Batch(1)) +} + +func Test_rejectedQuery_Collation(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.Collation(nil)) +} + +func Test_rejectedQuery_Comment(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.Comment("")) +} + +func Test_rejectedQuery_Count(t *testing.T) { + n, err := new(rejectedQuery).Count() + assert.Equal(t, breaker.ErrServiceUnavailable, err) + assert.Equal(t, 0, n) +} + +func Test_rejectedQuery_Distinct(t *testing.T) { + assert.Equal(t, 
breaker.ErrServiceUnavailable, new(rejectedQuery).Distinct("", nil)) +} + +func Test_rejectedQuery_Explain(t *testing.T) { + assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).Explain(nil)) +} + +func Test_rejectedQuery_For(t *testing.T) { + assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).For(nil, nil)) +} + +func Test_rejectedQuery_Hint(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.Hint()) +} + +func Test_rejectedQuery_Iter(t *testing.T) { + assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Iter()) +} + +func Test_rejectedQuery_Limit(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.Limit(1)) +} + +func Test_rejectedQuery_LogReplay(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.LogReplay()) +} + +func Test_rejectedQuery_MapReduce(t *testing.T) { + info, err := new(rejectedQuery).MapReduce(nil, nil) + assert.Equal(t, breaker.ErrServiceUnavailable, err) + assert.Nil(t, info) +} + +func Test_rejectedQuery_One(t *testing.T) { + assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).One(nil)) +} + +func Test_rejectedQuery_Prefetch(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.Prefetch(1)) +} + +func Test_rejectedQuery_Select(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.Select(nil)) +} + +func Test_rejectedQuery_SetMaxScan(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.SetMaxScan(0)) +} + +func Test_rejectedQuery_SetMaxTime(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.SetMaxTime(0)) +} + +func Test_rejectedQuery_Skip(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.Skip(0)) +} + +func Test_rejectedQuery_Snapshot(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.Snapshot()) +} + +func Test_rejectedQuery_Sort(t *testing.T) { + var q rejectedQuery + assert.Equal(t, q, q.Sort()) +} + +func Test_rejectedQuery_Tail(t *testing.T) { + assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Tail(0)) +} diff --git a/core/stores/mongo/sessionmanager.go b/core/stores/mongo/sessionmanager.go new file mode 100644 index 00000000..819c16fc --- /dev/null +++ b/core/stores/mongo/sessionmanager.go @@ -0,0 +1,73 @@ +package mongo + +import ( + "io" + "time" + + "zero/core/logx" + "zero/core/syncx" + + "github.com/globalsign/mgo" +) + +const ( + defaultConcurrency = 50 + defaultTimeout = time.Second +) + +var sessionManager = syncx.NewResourceManager() + +type concurrentSession struct { + *mgo.Session + limit syncx.TimeoutLimit +} + +func (cs *concurrentSession) Close() error { + cs.Session.Close() + return nil +} + +func getConcurrentSession(url string) (*concurrentSession, error) { + val, err := sessionManager.GetResource(url, func() (io.Closer, error) { + mgoSession, err := mgo.Dial(url) + if err != nil { + return nil, err + } + + concurrentSess := &concurrentSession{ + Session: mgoSession, + limit: syncx.NewTimeoutLimit(defaultConcurrency), + } + + return concurrentSess, nil + }) + if err != nil { + return nil, err + } + + return val.(*concurrentSession), nil +} + +func (cs *concurrentSession) putSession(session *mgo.Session) { + if err := cs.limit.Return(); err != nil { + logx.Error(err) + } + + // anyway, we need to close the session + session.Close() +} + +func (cs *concurrentSession) takeSession(opts ...Option) (*mgo.Session, error) { + o := &options{ + timeout: defaultTimeout, + } + for _, opt := range opts { + opt(o) + } + + if err := cs.limit.Borrow(o.timeout); err != nil { + return nil, err + } else { + return cs.Copy(), nil + } 
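+	// The pool is bounded by a TimeoutLimit of defaultConcurrency tokens: Borrow either
+	// succeeds within o.timeout or returns an error, and each successful take hands out
+	// a Copy of the master session that putSession later closes, returning the token.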
+} diff --git a/core/stores/mongo/util.go b/core/stores/mongo/util.go new file mode 100644 index 00000000..0db2c035 --- /dev/null +++ b/core/stores/mongo/util.go @@ -0,0 +1,9 @@ +package mongo + +import "strings" + +const mongoAddrSep = "," + +func FormatAddr(hosts []string) string { + return strings.Join(hosts, mongoAddrSep) +} diff --git a/core/stores/mongo/utils_test.go b/core/stores/mongo/utils_test.go new file mode 100644 index 00000000..b9095701 --- /dev/null +++ b/core/stores/mongo/utils_test.go @@ -0,0 +1,35 @@ +package mongo + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFormatAddrs(t *testing.T) { + tests := []struct { + addrs []string + expect string + }{ + { + addrs: []string{"a", "b"}, + expect: "a,b", + }, + { + addrs: []string{"a", "b", "c"}, + expect: "a,b,c", + }, + { + addrs: []string{}, + expect: "", + }, + { + addrs: nil, + expect: "", + }, + } + + for _, test := range tests { + assert.Equal(t, test.expect, FormatAddr(test.addrs)) + } +} diff --git a/core/stores/mongoc/cachedcollection.go b/core/stores/mongoc/cachedcollection.go new file mode 100644 index 00000000..e4faf77e --- /dev/null +++ b/core/stores/mongoc/cachedcollection.go @@ -0,0 +1,171 @@ +package mongoc + +import ( + "zero/core/stores/internal" + "zero/core/stores/mongo" + "zero/core/syncx" + + "github.com/globalsign/mgo" +) + +var ( + ErrNotFound = mgo.ErrNotFound + + // can't use one SharedCalls per conn, because multiple conns may share the same cache key. + sharedCalls = syncx.NewSharedCalls() + stats = internal.NewCacheStat("mongoc") +) + +type ( + QueryOption func(query mongo.Query) mongo.Query + + cachedCollection struct { + collection mongo.Collection + cache internal.Cache + } +) + +func newCollection(collection mongo.Collection, c internal.Cache) *cachedCollection { + return &cachedCollection{ + collection: collection, + cache: c, + } +} + +func (c *cachedCollection) Count(query interface{}) (int, error) { + return c.collection.Find(query).Count() +} + +func (c *cachedCollection) DelCache(keys ...string) error { + return c.cache.DelCache(keys...) +} + +func (c *cachedCollection) GetCache(key string, v interface{}) error { + return c.cache.GetCache(key, v) +} + +func (c *cachedCollection) FindAllNoCache(v interface{}, query interface{}, opts ...QueryOption) error { + q := c.collection.Find(query) + for _, opt := range opts { + q = opt(q) + } + return q.All(v) +} + +func (c *cachedCollection) FindOne(v interface{}, key string, query interface{}) error { + return c.cache.Take(v, key, func(v interface{}) error { + q := c.collection.Find(query) + return q.One(v) + }) +} + +func (c *cachedCollection) FindOneNoCache(v interface{}, query interface{}) error { + q := c.collection.Find(query) + return q.One(v) +} + +func (c *cachedCollection) FindOneId(v interface{}, key string, id interface{}) error { + return c.cache.Take(v, key, func(v interface{}) error { + q := c.collection.FindId(id) + return q.One(v) + }) +} + +func (c *cachedCollection) FindOneIdNoCache(v interface{}, id interface{}) error { + q := c.collection.FindId(id) + return q.One(v) +} + +func (c *cachedCollection) Insert(docs ...interface{}) error { + return c.collection.Insert(docs...) +} + +func (c *cachedCollection) Pipe(pipeline interface{}) mongo.Pipe { + return c.collection.Pipe(pipeline) +} + +func (c *cachedCollection) Remove(selector interface{}, keys ...string) error { + if err := c.RemoveNoCache(selector); err != nil { + return err + } + + return c.DelCache(keys...) 
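+	// The write helpers (Remove/RemoveAll/RemoveId/Update/UpdateId/Upsert) all follow
+	// this pattern: run the uncached operation first, then drop the supplied cache keys
+	// so stale entries cannot outlive the write.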
+} + +func (c *cachedCollection) RemoveNoCache(selector interface{}) error { + return c.collection.Remove(selector) +} + +func (c *cachedCollection) RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error) { + info, err := c.RemoveAllNoCache(selector) + if err != nil { + return nil, err + } + + if err := c.DelCache(keys...); err != nil { + return nil, err + } + + return info, nil +} + +func (c *cachedCollection) RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error) { + return c.collection.RemoveAll(selector) +} + +func (c *cachedCollection) RemoveId(id interface{}, keys ...string) error { + if err := c.RemoveIdNoCache(id); err != nil { + return err + } + + return c.DelCache(keys...) +} + +func (c *cachedCollection) RemoveIdNoCache(id interface{}) error { + return c.collection.RemoveId(id) +} + +func (c *cachedCollection) SetCache(key string, v interface{}) error { + return c.cache.SetCache(key, v) +} + +func (c *cachedCollection) Update(selector, update interface{}, keys ...string) error { + if err := c.UpdateNoCache(selector, update); err != nil { + return err + } + + return c.DelCache(keys...) +} + +func (c *cachedCollection) UpdateNoCache(selector, update interface{}) error { + return c.collection.Update(selector, update) +} + +func (c *cachedCollection) UpdateId(id, update interface{}, keys ...string) error { + if err := c.UpdateIdNoCache(id, update); err != nil { + return err + } + + return c.DelCache(keys...) +} + +func (c *cachedCollection) UpdateIdNoCache(id, update interface{}) error { + return c.collection.UpdateId(id, update) +} + +func (c *cachedCollection) Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error) { + info, err := c.UpsertNoCache(selector, update) + if err != nil { + return nil, err + } + + if err := c.DelCache(keys...); err != nil { + return nil, err + } + + return info, nil +} + +func (c *cachedCollection) UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error) { + return c.collection.Upsert(selector, update) +} diff --git a/core/stores/mongoc/cachedcollection_test.go b/core/stores/mongoc/cachedcollection_test.go new file mode 100644 index 00000000..11c6f682 --- /dev/null +++ b/core/stores/mongoc/cachedcollection_test.go @@ -0,0 +1,300 @@ +package mongoc + +import ( + "errors" + "io/ioutil" + "log" + "os" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "zero/core/stat" + "zero/core/stores/internal" + "zero/core/stores/mongo" + "zero/core/stores/redis" + + "github.com/alicebob/miniredis" + "github.com/globalsign/mgo" + "github.com/globalsign/mgo/bson" + "github.com/stretchr/testify/assert" +) + +func init() { + stat.SetReporter(nil) +} + +func TestStat(t *testing.T) { + resetStats() + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + r := redis.NewRedis(s.Addr(), redis.NodeType) + cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound) + c := newCollection(dummyConn{}, cach) + + for i := 0; i < 10; i++ { + var str string + if err = c.cache.Take(&str, "name", func(v interface{}) error { + *v.(*string) = "zero" + return nil + }); err != nil { + t.Error(err) + } + } + + assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total)) + assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit)) +} + +func TestStatCacheFails(t *testing.T) { + resetStats() + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stdout) + + r := redis.NewRedis("localhost:59999", redis.NodeType) + cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound) + c := 
newCollection(dummyConn{}, cach) + + for i := 0; i < 20; i++ { + var str string + err := c.FindOne(&str, "name", bson.M{}) + assert.NotNil(t, err) + } + + assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total)) + assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit)) + assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Miss)) + assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.DbFails)) +} + +func TestStatDbFails(t *testing.T) { + resetStats() + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + r := redis.NewRedis(s.Addr(), redis.NodeType) + cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound) + c := newCollection(dummyConn{}, cach) + + for i := 0; i < 20; i++ { + var str string + err := c.cache.Take(&str, "name", func(v interface{}) error { + return errors.New("db failed") + }) + assert.NotNil(t, err) + } + + assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total)) + assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit)) + assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.DbFails)) +} + +func TestStatFromMemory(t *testing.T) { + resetStats() + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + r := redis.NewRedis(s.Addr(), redis.NodeType) + cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound) + c := newCollection(dummyConn{}, cach) + + var all sync.WaitGroup + var wait sync.WaitGroup + all.Add(10) + wait.Add(4) + go func() { + var str string + if err := c.cache.Take(&str, "name", func(v interface{}) error { + *v.(*string) = "zero" + return nil + }); err != nil { + t.Error(err) + } + wait.Wait() + runtime.Gosched() + all.Done() + }() + + for i := 0; i < 4; i++ { + go func() { + var str string + wait.Done() + if err := c.cache.Take(&str, "name", func(v interface{}) error { + *v.(*string) = "zero" + return nil + }); err != nil { + t.Error(err) + } + all.Done() + }() + } + for i := 0; i < 5; i++ { + go func() { + var str string + if err := c.cache.Take(&str, "name", func(v interface{}) error { + *v.(*string) = "zero" + return nil + }); err != nil { + t.Error(err) + } + all.Done() + }() + } + all.Wait() + + assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total)) + assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit)) +} + +func resetStats() { + atomic.StoreUint64(&stats.Total, 0) + atomic.StoreUint64(&stats.Hit, 0) + atomic.StoreUint64(&stats.Miss, 0) + atomic.StoreUint64(&stats.DbFails, 0) +} + +type dummyConn struct { +} + +func (c dummyConn) Find(query interface{}) mongo.Query { + return dummyQuery{} +} + +func (c dummyConn) FindId(id interface{}) mongo.Query { + return dummyQuery{} +} + +func (c dummyConn) Insert(docs ...interface{}) error { + return nil +} + +func (c dummyConn) Remove(selector interface{}) error { + return nil +} + +func (dummyConn) Pipe(pipeline interface{}) mongo.Pipe { + return nil +} + +func (c dummyConn) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) { + return nil, nil +} + +func (c dummyConn) RemoveId(id interface{}) error { + return nil +} + +func (c dummyConn) Update(selector, update interface{}) error { + return nil +} + +func (c dummyConn) UpdateId(id, update interface{}) error { + return nil +} +func (c dummyConn) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) { + return nil, nil +} + +type dummyQuery struct { +} + +func (d dummyQuery) All(result interface{}) error { + return nil +} + +func (d dummyQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) { + return nil, nil +} + +func (d dummyQuery) Count() (int, error) 
{ + return 0, nil +} + +func (d dummyQuery) Distinct(key string, result interface{}) error { + return nil +} + +func (d dummyQuery) Explain(result interface{}) error { + return nil +} + +func (d dummyQuery) For(result interface{}, f func() error) error { + return nil +} + +func (d dummyQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) { + return nil, nil +} + +func (d dummyQuery) One(result interface{}) error { + return nil +} + +func (d dummyQuery) Batch(n int) mongo.Query { + return d +} + +func (d dummyQuery) Collation(collation *mgo.Collation) mongo.Query { + return d +} + +func (d dummyQuery) Comment(comment string) mongo.Query { + return d +} + +func (d dummyQuery) Hint(indexKey ...string) mongo.Query { + return d +} + +func (d dummyQuery) Iter() mongo.Iter { + return &mgo.Iter{} +} + +func (d dummyQuery) Limit(n int) mongo.Query { + return d +} + +func (d dummyQuery) LogReplay() mongo.Query { + return d +} + +func (d dummyQuery) Prefetch(p float64) mongo.Query { + return d +} + +func (d dummyQuery) Select(selector interface{}) mongo.Query { + return d +} + +func (d dummyQuery) SetMaxScan(n int) mongo.Query { + return d +} + +func (d dummyQuery) SetMaxTime(duration time.Duration) mongo.Query { + return d +} + +func (d dummyQuery) Skip(n int) mongo.Query { + return d +} + +func (d dummyQuery) Snapshot() mongo.Query { + return d +} + +func (d dummyQuery) Sort(fields ...string) mongo.Query { + return d +} + +func (d dummyQuery) Tail(timeout time.Duration) mongo.Iter { + return &mgo.Iter{} +} diff --git a/core/stores/mongoc/cachedmodel.go b/core/stores/mongoc/cachedmodel.go new file mode 100644 index 00000000..8b880ea7 --- /dev/null +++ b/core/stores/mongoc/cachedmodel.go @@ -0,0 +1,243 @@ +package mongoc + +import ( + "log" + + "zero/core/stores/cache" + "zero/core/stores/internal" + "zero/core/stores/mongo" + "zero/core/stores/redis" + + "github.com/globalsign/mgo" +) + +type Model struct { + *mongo.Model + cache internal.Cache + generateCollection func(*mgo.Session) *cachedCollection +} + +func MustNewNodeModel(url, database, collection string, rds *redis.Redis, opts ...cache.Option) *Model { + model, err := NewNodeModel(url, database, collection, rds, opts...) + if err != nil { + log.Fatal(err) + } + + return model +} + +func MustNewModel(url, database, collection string, c cache.CacheConf, opts ...cache.Option) *Model { + model, err := NewModel(url, database, collection, c, opts...) + if err != nil { + log.Fatal(err) + } + + return model +} + +func NewNodeModel(url, database, collection string, rds *redis.Redis, opts ...cache.Option) (*Model, error) { + c := internal.NewCacheNode(rds, sharedCalls, stats, mgo.ErrNotFound, opts...) + return createModel(url, database, collection, c, func(collection mongo.Collection) *cachedCollection { + return newCollection(collection, c) + }) +} + +func NewModel(url, database, collection string, conf cache.CacheConf, opts ...cache.Option) (*Model, error) { + c := internal.NewCache(conf, sharedCalls, stats, mgo.ErrNotFound, opts...) + return createModel(url, database, collection, c, func(collection mongo.Collection) *cachedCollection { + return newCollection(collection, c) + }) +} + +func (mm *Model) Count(query interface{}) (int, error) { + return mm.executeInt(func(c *cachedCollection) (int, error) { + return c.Count(query) + }) +} + +func (mm *Model) DelCache(keys ...string) error { + return mm.cache.DelCache(keys...) 
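+	// DelCache/GetCache/SetCache expose the shared cache directly for callers that
+	// manage keys by hand; FindOne/FindOneId below cache results automatically.
+	// Illustrative sketch (URL, database and key layout are examples only):
+	//   m := mongoc.MustNewNodeModel("mongodb://localhost:27017", "mydb", "users", rds)
+	//   var u User
+	//   err := m.FindOneId(&u, "cache:users:id:"+id, bson.ObjectIdHex(id))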
+} + +func (mm *Model) GetCache(key string, v interface{}) error { + return mm.cache.GetCache(key, v) +} + +func (mm *Model) GetCollection(session *mgo.Session) *cachedCollection { + return mm.generateCollection(session) +} + +func (mm *Model) FindAllNoCache(v interface{}, query interface{}, opts ...QueryOption) error { + return mm.execute(func(c *cachedCollection) error { + return c.FindAllNoCache(v, query, opts...) + }) +} + +func (mm *Model) FindOne(v interface{}, key string, query interface{}) error { + return mm.execute(func(c *cachedCollection) error { + return c.FindOne(v, key, query) + }) +} + +func (mm *Model) FindOneNoCache(v interface{}, query interface{}) error { + return mm.execute(func(c *cachedCollection) error { + return c.FindOneNoCache(v, query) + }) +} + +func (mm *Model) FindOneId(v interface{}, key string, id interface{}) error { + return mm.execute(func(c *cachedCollection) error { + return c.FindOneId(v, key, id) + }) +} + +func (mm *Model) FindOneIdNoCache(v interface{}, id interface{}) error { + return mm.execute(func(c *cachedCollection) error { + return c.FindOneIdNoCache(v, id) + }) +} + +func (mm *Model) Insert(docs ...interface{}) error { + return mm.execute(func(c *cachedCollection) error { + return c.Insert(docs...) + }) +} + +func (mm *Model) Pipe(pipeline interface{}) (mongo.Pipe, error) { + return mm.pipe(func(c *cachedCollection) mongo.Pipe { + return c.Pipe(pipeline) + }) +} + +func (mm *Model) Remove(selector interface{}, keys ...string) error { + return mm.execute(func(c *cachedCollection) error { + return c.Remove(selector, keys...) + }) +} + +func (mm *Model) RemoveNoCache(selector interface{}) error { + return mm.execute(func(c *cachedCollection) error { + return c.RemoveNoCache(selector) + }) +} + +func (mm *Model) RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error) { + return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) { + return c.RemoveAll(selector, keys...) + }) +} + +func (mm *Model) RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error) { + return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) { + return c.RemoveAllNoCache(selector) + }) +} + +func (mm *Model) RemoveId(id interface{}, keys ...string) error { + return mm.execute(func(c *cachedCollection) error { + return c.RemoveId(id, keys...) + }) +} + +func (mm *Model) RemoveIdNoCache(id interface{}) error { + return mm.execute(func(c *cachedCollection) error { + return c.RemoveIdNoCache(id) + }) +} + +func (mm *Model) SetCache(key string, v interface{}) error { + return mm.cache.SetCache(key, v) +} + +func (mm *Model) Update(selector, update interface{}, keys ...string) error { + return mm.execute(func(c *cachedCollection) error { + return c.Update(selector, update, keys...) + }) +} + +func (mm *Model) UpdateNoCache(selector, update interface{}) error { + return mm.execute(func(c *cachedCollection) error { + return c.UpdateNoCache(selector, update) + }) +} + +func (mm *Model) UpdateId(id, update interface{}, keys ...string) error { + return mm.execute(func(c *cachedCollection) error { + return c.UpdateId(id, update, keys...) + }) +} + +func (mm *Model) UpdateIdNoCache(id, update interface{}) error { + return mm.execute(func(c *cachedCollection) error { + return c.UpdateIdNoCache(id, update) + }) +} + +func (mm *Model) Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error) { + return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) { + return c.Upsert(selector, update, keys...) 
+ }) +} + +func (mm *Model) UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error) { + return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) { + return c.UpsertNoCache(selector, update) + }) +} + +func (mm *Model) change(fn func(c *cachedCollection) (*mgo.ChangeInfo, error)) (*mgo.ChangeInfo, error) { + session, err := mm.TakeSession() + if err != nil { + return nil, err + } + defer mm.PutSession(session) + + return fn(mm.GetCollection(session)) +} + +func (mm *Model) execute(fn func(c *cachedCollection) error) error { + session, err := mm.TakeSession() + if err != nil { + return err + } + defer mm.PutSession(session) + + return fn(mm.GetCollection(session)) +} + +func (mm *Model) executeInt(fn func(c *cachedCollection) (int, error)) (int, error) { + session, err := mm.TakeSession() + if err != nil { + return 0, err + } + defer mm.PutSession(session) + + return fn(mm.GetCollection(session)) +} + +func (mm *Model) pipe(fn func(c *cachedCollection) mongo.Pipe) (mongo.Pipe, error) { + session, err := mm.TakeSession() + if err != nil { + return nil, err + } + defer mm.PutSession(session) + + return fn(mm.GetCollection(session)), nil +} + +func createModel(url, database, collection string, c internal.Cache, + create func(mongo.Collection) *cachedCollection) (*Model, error) { + model, err := mongo.NewModel(url, database, collection) + if err != nil { + return nil, err + } + + return &Model{ + Model: model, + cache: c, + generateCollection: func(session *mgo.Session) *cachedCollection { + collection := model.GetCollection(session) + return create(collection) + }, + }, nil +} diff --git a/core/stores/postgres/postgresql.go b/core/stores/postgres/postgresql.go new file mode 100644 index 00000000..820862b8 --- /dev/null +++ b/core/stores/postgres/postgresql.go @@ -0,0 +1,13 @@ +package postgres + +import ( + "zero/core/stores/sqlx" + + _ "github.com/lib/pq" +) + +const postgreDriverName = "postgres" + +func NewPostgre(datasource string, opts ...sqlx.SqlOption) sqlx.SqlConn { + return sqlx.NewSqlConn(postgreDriverName, datasource, opts...) 
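+	// The datasource string is handed through sqlx to database/sql with the lib/pq
+	// driver registered above; an illustrative DSN (values are placeholders only):
+	//   "postgres://user:pass@127.0.0.1:5432/mydb?sslmode=disable"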
+} diff --git a/core/stores/redis/conf.go b/core/stores/redis/conf.go new file mode 100644 index 00000000..f2fc21c5 --- /dev/null +++ b/core/stores/redis/conf.go @@ -0,0 +1,50 @@ +package redis + +import "errors" + +var ( + ErrEmptyHost = errors.New("empty redis host") + ErrEmptyType = errors.New("empty redis type") + ErrEmptyKey = errors.New("empty redis key") +) + +type ( + RedisConf struct { + Host string + Type string `json:",default=node,options=node|cluster"` + Pass string `json:",optional"` + } + + RedisKeyConf struct { + RedisConf + Key string `json:",optional"` + } +) + +func (rc RedisConf) NewRedis() *Redis { + return NewRedis(rc.Host, rc.Type, rc.Pass) +} + +func (rc RedisConf) Validate() error { + if len(rc.Host) == 0 { + return ErrEmptyHost + } + + if len(rc.Type) == 0 { + return ErrEmptyType + } + + return nil +} + +func (rkc RedisKeyConf) Validate() error { + if err := rkc.RedisConf.Validate(); err != nil { + return err + } + + if len(rkc.Key) == 0 { + return ErrEmptyKey + } + + return nil +} diff --git a/core/stores/redis/process.go b/core/stores/redis/process.go new file mode 100644 index 00000000..141257c9 --- /dev/null +++ b/core/stores/redis/process.go @@ -0,0 +1,33 @@ +package redis + +import ( + "strings" + + "zero/core/logx" + "zero/core/mapping" + "zero/core/timex" + + red "github.com/go-redis/redis" +) + +func process(proc func(red.Cmder) error) func(red.Cmder) error { + return func(cmd red.Cmder) error { + start := timex.Now() + + defer func() { + duration := timex.Since(start) + if duration > slowThreshold { + var buf strings.Builder + for i, arg := range cmd.Args() { + if i > 0 { + buf.WriteByte(' ') + } + buf.WriteString(mapping.Repr(arg)) + } + logx.WithDuration(duration).Slowf("[REDIS] slowcall on executing: %s", buf.String()) + } + }() + + return proc(cmd) + } +} diff --git a/core/stores/redis/redis.go b/core/stores/redis/redis.go new file mode 100644 index 00000000..6c052794 --- /dev/null +++ b/core/stores/redis/redis.go @@ -0,0 +1,1339 @@ +package redis + +import ( + "errors" + "fmt" + "strconv" + "time" + + "zero/core/breaker" + "zero/core/mapping" + + red "github.com/go-redis/redis" +) + +const ( + ClusterType = "cluster" + NodeType = "node" + Nil = red.Nil + + blockingQueryTimeout = 5 * time.Second + readWriteTimeout = 2 * time.Second + + slowThreshold = time.Millisecond * 100 +) + +var ErrNilNode = errors.New("nil redis node") + +type ( + Pair struct { + Key string + Score int64 + } + + // thread-safe + Redis struct { + Addr string + Type string + Pass string + brk breaker.Breaker + } + + RedisNode interface { + red.Cmdable + } + + Pipeliner = red.Pipeliner + + // Z represents sorted set member. 
+ Z = red.Z + + IntCmd = red.IntCmd + FloatCmd = red.FloatCmd +) + +func NewRedis(redisAddr, redisType string, redisPass ...string) *Redis { + var pass string + for _, v := range redisPass { + pass = v + } + + return &Redis{ + Addr: redisAddr, + Type: redisType, + Pass: pass, + brk: breaker.NewBreaker(), + } +} + +// Use passed in redis connection to execute blocking queries +// Doesn't benefit from pooling redis connections of blocking queries +func (s *Redis) Blpop(redisNode RedisNode, key string) (string, error) { + if redisNode == nil { + return "", ErrNilNode + } + + vals, err := redisNode.BLPop(blockingQueryTimeout, key).Result() + if err != nil { + return "", err + } + + if len(vals) < 2 { + return "", fmt.Errorf("no value on key: %s", key) + } else { + return vals[1], nil + } +} + +func (s *Redis) BlpopEx(redisNode RedisNode, key string) (string, bool, error) { + if redisNode == nil { + return "", false, ErrNilNode + } + + vals, err := redisNode.BLPop(blockingQueryTimeout, key).Result() + if err != nil { + return "", false, err + } + + if len(vals) < 2 { + return "", false, fmt.Errorf("no value on key: %s", key) + } else { + return vals[1], true, nil + } +} + +func (s *Redis) Del(keys ...string) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.Del(keys...).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Eval(script string, keys []string, args ...interface{}) (val interface{}, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.Eval(script, keys, args...).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Exists(key string) (val bool, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.Exists(key).Result(); err != nil { + return err + } else { + val = v == 1 + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Expire(key string, seconds int) error { + return s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + return conn.Expire(key, time.Duration(seconds)*time.Second).Err() + }, acceptable) +} + +func (s *Redis) Expireat(key string, expireTime int64) error { + return s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + return conn.ExpireAt(key, time.Unix(expireTime, 0)).Err() + }, acceptable) +} + +func (s *Redis) Get(key string) (val string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if val, err = conn.Get(key).Result(); err == red.Nil { + return nil + } else if err != nil { + return err + } else { + return nil + } + }, acceptable) + + return +} + +func (s *Redis) GetBit(key string, offset int64) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.GetBit(key, offset).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Hdel(key, field string) (val bool, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.HDel(key, field).Result(); err != nil { + 
return err + } else { + val = v == 1 + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Hexists(key, field string) (val bool, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.HExists(key, field).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Hget(key, field string) (val string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.HGet(key, field).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Hgetall(key string) (val map[string]string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.HGetAll(key).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Hincrby(key, field string, increment int) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.HIncrBy(key, field, int64(increment)).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Hkeys(key string) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.HKeys(key).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Hlen(key string) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.HLen(key).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Hmget(key string, fields ...string) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.HMGet(key, fields...).Result(); err != nil { + return err + } else { + val = toStrings(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Hset(key, field, value string) error { + return s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + return conn.HSet(key, field, value).Err() + }, acceptable) +} + +func (s *Redis) Hsetnx(key, field, value string) (val bool, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.HSetNX(key, field, value).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Hmset(key string, fieldsAndValues map[string]string) error { + return s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + vals := make(map[string]interface{}, len(fieldsAndValues)) + for k, v := range fieldsAndValues { + vals[k] = v + } + + return conn.HMSet(key, vals).Err() + }, acceptable) +} + +func (s *Redis) Hvals(key string) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.HVals(key).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Incr(key string) (val int64, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.Incr(key).Result() + return err + }, acceptable) + + 
return +} + +func (s *Redis) Incrby(key string, increment int64) (val int64, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.IncrBy(key, int64(increment)).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Keys(pattern string) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.Keys(pattern).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Llen(key string) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.LLen(key).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Lpop(key string) (val string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.LPop(key).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Lpush(key string, values ...interface{}) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.LPush(key, values...).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Lrange(key string, start int, stop int) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.LRange(key, int64(start), int64(stop)).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Lrem(key string, count int, value string) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.LRem(key, int64(count), value).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Mget(keys ...string) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.MGet(keys...).Result(); err != nil { + return err + } else { + val = toStrings(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Persist(key string) (val bool, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.Persist(key).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Pfadd(key string, values ...interface{}) (val bool, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.PFAdd(key, values...).Result(); err != nil { + return err + } else { + val = v == 1 + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Pfcount(key string) (val int64, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.PFCount(key).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Pfmerge(dest string, keys ...string) error { + return s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + _, err = conn.PFMerge(dest, keys...).Result() + return 
err + }, acceptable) +} + +func (s *Redis) Ping() (val bool) { + // ignore error, error means false + _ = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + val = false + return nil + } + + if v, err := conn.Ping().Result(); err != nil { + val = false + return nil + } else { + val = v == "PONG" + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Pipelined(fn func(Pipeliner) error) (err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + _, err = conn.Pipelined(fn) + return err + + }, acceptable) + + return +} + +func (s *Redis) Rpush(key string, values ...interface{}) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.RPush(key, values...).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Sadd(key string, values ...interface{}) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.SAdd(key, values...).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Scan(cursor uint64, match string, count int64) (keys []string, cur uint64, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + keys, cur, err = conn.Scan(cursor, match, count).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) SetBit(key string, offset int64, value int) error { + return s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + _, err = conn.SetBit(key, offset, value).Result() + return err + }, acceptable) +} + +func (s *Redis) Sscan(key string, cursor uint64, match string, count int64) (keys []string, cur uint64, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + keys, cur, err = conn.SScan(key, cursor, match, count).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Scard(key string) (val int64, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.SCard(key).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Set(key string, value string) error { + return s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + return conn.Set(key, value, 0).Err() + }, acceptable) +} + +func (s *Redis) Setex(key, value string, seconds int) error { + return s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + return conn.Set(key, value, time.Duration(seconds)*time.Second).Err() + }, acceptable) +} + +func (s *Redis) Setnx(key, value string) (val bool, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.SetNX(key, value, 0).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) SetnxEx(key, value string, seconds int) (val bool, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.SetNX(key, value, time.Duration(seconds)*time.Second).Result() + return err 
+ }, acceptable) + + return +} + +func (s *Redis) Sismember(key string, value interface{}) (val bool, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + val, err = conn.SIsMember(key, value).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Srem(key string, values ...interface{}) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.SRem(key, values...).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Smembers(key string) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.SMembers(key).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Spop(key string) (val string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.SPop(key).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Srandmember(key string, count int) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.SRandMemberN(key, int64(count)).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Sunion(keys ...string) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.SUnion(keys...).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Sunionstore(destination string, keys ...string) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.SUnionStore(destination, keys...).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Sdiff(keys ...string) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.SDiff(keys...).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Sdiffstore(destination string, keys ...string) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.SDiffStore(destination, keys...).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Ttl(key string) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if duration, err := conn.TTL(key).Result(); err != nil { + return err + } else { + val = int(duration / time.Second) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Zadd(key string, score int64, value string) (val bool, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZAdd(key, red.Z{ + Score: float64(score), + Member: value, + }).Result(); err != nil { + return err + } else { + val = v == 1 + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Zadds(key string, ps ...Pair) (val int64, err error) { + err = 
s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + var zs []red.Z + for _, p := range ps { + z := red.Z{Score: float64(p.Score), Member: p.Key} + zs = append(zs, z) + } + + if v, err := conn.ZAdd(key, zs...).Result(); err != nil { + return err + } else { + val = v + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Zcard(key string) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZCard(key).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Zcount(key string, start, stop int64) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZCount(key, strconv.FormatInt(start, 10), + strconv.FormatInt(stop, 10)).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Zincrby(key string, increment int64, field string) (val int64, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZIncrBy(key, float64(increment), field).Result(); err != nil { + return err + } else { + val = int64(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Zscore(key string, value string) (val int64, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZScore(key, value).Result(); err != nil { + return err + } else { + val = int64(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Zrank(key, field string) (val int64, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.ZRank(key, field).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) Zrem(key string, values ...interface{}) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZRem(key, values...).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Zremrangebyscore(key string, start, stop int64) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZRemRangeByScore(key, strconv.FormatInt(start, 10), + strconv.FormatInt(stop, 10)).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Zremrangebyrank(key string, start, stop int64) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZRemRangeByRank(key, start, stop).Result(); err != nil { + return err + } else { + val = int(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Zrange(key string, start, stop int64) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.ZRange(key, start, stop).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) ZrangeWithScores(key string, start, stop int64) (val []Pair, err error) 
{ + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZRangeWithScores(key, start, stop).Result(); err != nil { + return err + } else { + val = toPairs(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) ZRevRangeWithScores(key string, start, stop int64) (val []Pair, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZRevRangeWithScores(key, start, stop).Result(); err != nil { + return err + } else { + val = toPairs(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) ZrangebyscoreWithScores(key string, start, stop int64) (val []Pair, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZRangeByScoreWithScores(key, red.ZRangeBy{ + Min: strconv.FormatInt(start, 10), + Max: strconv.FormatInt(stop, 10), + }).Result(); err != nil { + return err + } else { + val = toPairs(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) ZrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ( + val []Pair, err error) { + err = s.brk.DoWithAcceptable(func() error { + if size <= 0 { + return nil + } + + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZRangeByScoreWithScores(key, red.ZRangeBy{ + Min: strconv.FormatInt(start, 10), + Max: strconv.FormatInt(stop, 10), + Offset: int64(page * size), + Count: int64(size), + }).Result(); err != nil { + return err + } else { + val = toPairs(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) Zrevrange(key string, start, stop int64) (val []string, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = conn.ZRevRange(key, start, stop).Result() + return err + }, acceptable) + + return +} + +func (s *Redis) ZrevrangebyscoreWithScores(key string, start, stop int64) (val []Pair, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZRevRangeByScoreWithScores(key, red.ZRangeBy{ + Min: strconv.FormatInt(start, 10), + Max: strconv.FormatInt(stop, 10), + }).Result(); err != nil { + return err + } else { + val = toPairs(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ( + val []Pair, err error) { + err = s.brk.DoWithAcceptable(func() error { + if size <= 0 { + return nil + } + + conn, err := getRedis(s) + if err != nil { + return err + } + + if v, err := conn.ZRevRangeByScoreWithScores(key, red.ZRangeBy{ + Min: strconv.FormatInt(start, 10), + Max: strconv.FormatInt(stop, 10), + Offset: int64(page * size), + Count: int64(size), + }).Result(); err != nil { + return err + } else { + val = toPairs(v) + return nil + } + }, acceptable) + + return +} + +func (s *Redis) String() string { + return s.Addr +} + +func (s *Redis) scriptLoad(script string) (string, error) { + conn, err := getRedis(s) + if err != nil { + return "", err + } + + return conn.ScriptLoad(script).Result() +} + +func acceptable(err error) bool { + return err == nil || err == red.Nil +} + +func getRedis(r *Redis) (RedisNode, error) { + switch r.Type { + case ClusterType: + return getCluster(r.Addr, r.Pass) + case NodeType: + return getClient(r.Addr, r.Pass) + default: + 
return nil, fmt.Errorf("redis type '%s' is not supported", r.Type) + } +} + +func toPairs(vals []red.Z) []Pair { + pairs := make([]Pair, len(vals)) + for i, val := range vals { + switch member := val.Member.(type) { + case string: + pairs[i] = Pair{ + Key: member, + Score: int64(val.Score), + } + default: + pairs[i] = Pair{ + Key: mapping.Repr(val.Member), + Score: int64(val.Score), + } + } + } + return pairs +} + +func toStrings(vals []interface{}) []string { + ret := make([]string, len(vals)) + for i, val := range vals { + if val == nil { + ret[i] = "" + } else { + switch val := val.(type) { + case string: + ret[i] = val + default: + ret[i] = mapping.Repr(val) + } + } + } + return ret +} diff --git a/core/stores/redis/redis_test.go b/core/stores/redis/redis_test.go new file mode 100644 index 00000000..e1a3a5aa --- /dev/null +++ b/core/stores/redis/redis_test.go @@ -0,0 +1,580 @@ +package redis + +import ( + "errors" + "io" + "testing" + "time" + + "github.com/alicebob/miniredis" + "github.com/stretchr/testify/assert" +) + +func TestRedis_Exists(t *testing.T) { + runOnRedis(t, func(client *Redis) { + ok, err := client.Exists("a") + assert.Nil(t, err) + assert.False(t, ok) + assert.Nil(t, client.Set("a", "b")) + ok, err = client.Exists("a") + assert.Nil(t, err) + assert.True(t, ok) + }) +} + +func TestRedis_Eval(t *testing.T) { + runOnRedis(t, func(client *Redis) { + _, err := client.Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"}) + assert.Equal(t, Nil, err) + err = client.Set("key1", "value1") + assert.Nil(t, err) + _, err = client.Eval(`redis.call("EXISTS", KEYS[1])`, []string{"key1"}) + assert.Equal(t, Nil, err) + val, err := client.Eval(`return redis.call("EXISTS", KEYS[1])`, []string{"key1"}) + assert.Nil(t, err) + assert.Equal(t, int64(1), val) + }) +} + +func TestRedis_Hgetall(t *testing.T) { + runOnRedis(t, func(client *Redis) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + vals, err := client.Hgetall("a") + assert.Nil(t, err) + assert.EqualValues(t, map[string]string{ + "aa": "aaa", + "bb": "bbb", + }, vals) + }) +} + +func TestRedis_Hvals(t *testing.T) { + runOnRedis(t, func(client *Redis) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + vals, err := client.Hvals("a") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"aaa", "bbb"}, vals) + }) +} + +func TestRedis_Hsetnx(t *testing.T) { + runOnRedis(t, func(client *Redis) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + ok, err := client.Hsetnx("a", "bb", "ccc") + assert.Nil(t, err) + assert.False(t, ok) + ok, err = client.Hsetnx("a", "dd", "ddd") + assert.Nil(t, err) + assert.True(t, ok) + vals, err := client.Hvals("a") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"aaa", "bbb", "ddd"}, vals) + }) +} + +func TestRedis_HdelHlen(t *testing.T) { + runOnRedis(t, func(client *Redis) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + num, err := client.Hlen("a") + assert.Nil(t, err) + assert.Equal(t, 2, num) + val, err := client.Hdel("a", "aa") + assert.Nil(t, err) + assert.True(t, val) + vals, err := client.Hvals("a") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"bbb"}, vals) + }) +} + +func TestRedis_HIncrBy(t *testing.T) { + runOnRedis(t, func(client *Redis) { + val, err := client.Hincrby("key", "field", 2) + assert.Nil(t, err) + assert.Equal(t, 2, val) + val, err = client.Hincrby("key", "field", 3) + 
assert.Nil(t, err) + assert.Equal(t, 5, val) + }) +} + +func TestRedis_Hkeys(t *testing.T) { + runOnRedis(t, func(client *Redis) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + vals, err := client.Hkeys("a") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"aa", "bb"}, vals) + }) +} + +func TestRedis_Hmget(t *testing.T) { + runOnRedis(t, func(client *Redis) { + assert.Nil(t, client.Hset("a", "aa", "aaa")) + assert.Nil(t, client.Hset("a", "bb", "bbb")) + vals, err := client.Hmget("a", "aa", "bb") + assert.Nil(t, err) + assert.EqualValues(t, []string{"aaa", "bbb"}, vals) + vals, err = client.Hmget("a", "aa", "no", "bb") + assert.Nil(t, err) + assert.EqualValues(t, []string{"aaa", "", "bbb"}, vals) + }) +} + +func TestRedis_Hmset(t *testing.T) { + runOnRedis(t, func(client *Redis) { + assert.Nil(t, client.Hmset("a", map[string]string{ + "aa": "aaa", + "bb": "bbb", + })) + vals, err := client.Hmget("a", "aa", "bb") + assert.Nil(t, err) + assert.EqualValues(t, []string{"aaa", "bbb"}, vals) + }) +} + +func TestRedis_Incr(t *testing.T) { + runOnRedis(t, func(client *Redis) { + val, err := client.Incr("a") + assert.Nil(t, err) + assert.Equal(t, int64(1), val) + val, err = client.Incr("a") + assert.Nil(t, err) + assert.Equal(t, int64(2), val) + }) +} + +func TestRedis_IncrBy(t *testing.T) { + runOnRedis(t, func(client *Redis) { + val, err := client.Incrby("a", 2) + assert.Nil(t, err) + assert.Equal(t, int64(2), val) + val, err = client.Incrby("a", 3) + assert.Nil(t, err) + assert.Equal(t, int64(5), val) + }) +} + +func TestRedis_Keys(t *testing.T) { + runOnRedis(t, func(client *Redis) { + err := client.Set("key1", "value1") + assert.Nil(t, err) + err = client.Set("key2", "value2") + assert.Nil(t, err) + keys, err := client.Keys("*") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"key1", "key2"}, keys) + }) +} + +func TestRedis_List(t *testing.T) { + runOnRedis(t, func(client *Redis) { + val, err := client.Lpush("key", "value1", "value2") + assert.Nil(t, err) + assert.Equal(t, 2, val) + val, err = client.Rpush("key", "value3", "value4") + assert.Nil(t, err) + assert.Equal(t, 4, val) + val, err = client.Llen("key") + assert.Nil(t, err) + assert.Equal(t, 4, val) + vals, err := client.Lrange("key", 0, 10) + assert.Nil(t, err) + assert.EqualValues(t, []string{"value2", "value1", "value3", "value4"}, vals) + v, err := client.Lpop("key") + assert.Nil(t, err) + assert.Equal(t, "value2", v) + val, err = client.Lpush("key", "value1", "value2") + assert.Nil(t, err) + assert.Equal(t, 5, val) + val, err = client.Rpush("key", "value3", "value3") + assert.Nil(t, err) + assert.Equal(t, 7, val) + n, err := client.Lrem("key", 2, "value1") + assert.Nil(t, err) + assert.Equal(t, 2, n) + vals, err = client.Lrange("key", 0, 10) + assert.Nil(t, err) + assert.EqualValues(t, []string{"value2", "value3", "value4", "value3", "value3"}, vals) + n, err = client.Lrem("key", -2, "value3") + assert.Nil(t, err) + assert.Equal(t, 2, n) + vals, err = client.Lrange("key", 0, 10) + assert.Nil(t, err) + assert.EqualValues(t, []string{"value2", "value3", "value4"}, vals) + }) +} + +func TestRedis_Mget(t *testing.T) { + runOnRedis(t, func(client *Redis) { + err := client.Set("key1", "value1") + assert.Nil(t, err) + err = client.Set("key2", "value2") + assert.Nil(t, err) + vals, err := client.Mget("key1", "key0", "key2", "key3") + assert.Nil(t, err) + assert.EqualValues(t, []string{"value1", "", "value2", ""}, vals) + }) +} + +func TestRedis_SetBit(t *testing.T) { + 
runOnRedis(t, func(client *Redis) { + err := client.SetBit("key", 1, 1) + assert.Nil(t, err) + }) +} + +func TestRedis_GetBit(t *testing.T) { + runOnRedis(t, func(client *Redis) { + err := client.SetBit("key", 2, 1) + assert.Nil(t, err) + val, err := client.GetBit("key", 2) + assert.Nil(t, err) + assert.Equal(t, 1, val) + }) +} + +func TestRedis_Persist(t *testing.T) { + runOnRedis(t, func(client *Redis) { + ok, err := client.Persist("key") + assert.Nil(t, err) + assert.False(t, ok) + err = client.Set("key", "value") + assert.Nil(t, err) + ok, err = client.Persist("key") + assert.Nil(t, err) + assert.False(t, ok) + err = client.Expire("key", 5) + ok, err = client.Persist("key") + assert.Nil(t, err) + assert.True(t, ok) + err = client.Expireat("key", time.Now().Unix()+5) + ok, err = client.Persist("key") + assert.Nil(t, err) + assert.True(t, ok) + }) +} + +func TestRedis_Ping(t *testing.T) { + runOnRedis(t, func(client *Redis) { + ok := client.Ping() + assert.True(t, ok) + }) +} + +func TestRedis_Scan(t *testing.T) { + runOnRedis(t, func(client *Redis) { + err := client.Set("key1", "value1") + assert.Nil(t, err) + err = client.Set("key2", "value2") + assert.Nil(t, err) + keys, _, err := client.Scan(0, "*", 100) + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"key1", "key2"}, keys) + }) +} + +func TestRedis_Sscan(t *testing.T) { + runOnRedis(t, func(client *Redis) { + key := "list" + var list []string + for i := 0; i < 1550; i++ { + list = append(list, randomStr(i)) + } + lens, err := client.Sadd(key, list) + assert.Nil(t, err) + assert.Equal(t, lens, 1550) + + var cursor uint64 = 0 + sum := 0 + for { + keys, next, err := client.Sscan(key, cursor, "", 100) + assert.Nil(t, err) + sum += len(keys) + if next == 0 { + break + } + cursor = next + } + + assert.Equal(t, sum, 1550) + _, err = client.Del(key) + assert.Nil(t, err) + }) +} + +func TestRedis_Set(t *testing.T) { + runOnRedis(t, func(client *Redis) { + num, err := client.Sadd("key", 1, 2, 3, 4) + assert.Nil(t, err) + assert.Equal(t, 4, num) + val, err := client.Scard("key") + assert.Nil(t, err) + assert.Equal(t, int64(4), val) + ok, err := client.Sismember("key", 2) + assert.Nil(t, err) + assert.True(t, ok) + num, err = client.Srem("key", 3, 4) + assert.Nil(t, err) + assert.Equal(t, 2, num) + vals, err := client.Smembers("key") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"1", "2"}, vals) + members, err := client.Srandmember("key", 1) + assert.Nil(t, err) + assert.Len(t, members, 1) + assert.Contains(t, []string{"1", "2"}, members[0]) + member, err := client.Spop("key") + assert.Nil(t, err) + assert.Contains(t, []string{"1", "2"}, member) + vals, err = client.Smembers("key") + assert.Nil(t, err) + assert.NotContains(t, vals, member) + num, err = client.Sadd("key1", 1, 2, 3, 4) + assert.Nil(t, err) + assert.Equal(t, 4, num) + num, err = client.Sadd("key2", 2, 3, 4, 5) + assert.Nil(t, err) + assert.Equal(t, 4, num) + vals, err = client.Sunion("key1", "key2") + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"1", "2", "3", "4", "5"}, vals) + num, err = client.Sunionstore("key3", "key1", "key2") + assert.Nil(t, err) + assert.Equal(t, 5, num) + vals, err = client.Sdiff("key1", "key2") + assert.Nil(t, err) + assert.EqualValues(t, []string{"1"}, vals) + num, err = client.Sdiffstore("key4", "key1", "key2") + assert.Nil(t, err) + assert.Equal(t, 1, num) + }) +} + +func TestRedis_SetGetDel(t *testing.T) { + runOnRedis(t, func(client *Redis) { + err := client.Set("hello", "world") + assert.Nil(t, err) + val, err := 
client.Get("hello") + assert.Nil(t, err) + assert.Equal(t, "world", val) + ret, err := client.Del("hello") + assert.Nil(t, err) + assert.Equal(t, 1, ret) + }) +} + +func TestRedis_SetExNx(t *testing.T) { + runOnRedis(t, func(client *Redis) { + err := client.Setex("hello", "world", 5) + assert.Nil(t, err) + ok, err := client.Setnx("hello", "newworld") + assert.Nil(t, err) + assert.False(t, ok) + ok, err = client.Setnx("newhello", "newworld") + assert.Nil(t, err) + assert.True(t, ok) + val, err := client.Get("hello") + assert.Nil(t, err) + assert.Equal(t, "world", val) + val, err = client.Get("newhello") + assert.Nil(t, err) + assert.Equal(t, "newworld", val) + ttl, err := client.Ttl("hello") + assert.Nil(t, err) + assert.True(t, ttl > 0) + ok, err = client.SetnxEx("newhello", "newworld", 5) + assert.Nil(t, err) + assert.False(t, ok) + num, err := client.Del("newhello") + assert.Nil(t, err) + assert.Equal(t, 1, num) + ok, err = client.SetnxEx("newhello", "newworld", 5) + assert.Nil(t, err) + assert.True(t, ok) + val, err = client.Get("newhello") + assert.Nil(t, err) + assert.Equal(t, "newworld", val) + }) +} + +func TestRedis_SetGetDelHashField(t *testing.T) { + runOnRedis(t, func(client *Redis) { + err := client.Hset("key", "field", "value") + assert.Nil(t, err) + val, err := client.Hget("key", "field") + assert.Nil(t, err) + assert.Equal(t, "value", val) + ok, err := client.Hexists("key", "field") + assert.Nil(t, err) + assert.True(t, ok) + ret, err := client.Hdel("key", "field") + assert.Nil(t, err) + assert.True(t, ret) + ok, err = client.Hexists("key", "field") + assert.Nil(t, err) + assert.False(t, ok) + }) +} + +func TestRedis_SortedSet(t *testing.T) { + runOnRedis(t, func(client *Redis) { + ok, err := client.Zadd("key", 1, "value1") + assert.Nil(t, err) + assert.True(t, ok) + ok, err = client.Zadd("key", 2, "value1") + assert.Nil(t, err) + assert.False(t, ok) + val, err := client.Zscore("key", "value1") + assert.Nil(t, err) + assert.Equal(t, int64(2), val) + val, err = client.Zincrby("key", 3, "value1") + assert.Nil(t, err) + assert.Equal(t, int64(5), val) + val, err = client.Zscore("key", "value1") + assert.Nil(t, err) + assert.Equal(t, int64(5), val) + ok, err = client.Zadd("key", 6, "value2") + assert.Nil(t, err) + assert.True(t, ok) + ok, err = client.Zadd("key", 7, "value3") + assert.Nil(t, err) + assert.True(t, ok) + rank, err := client.Zrank("key", "value2") + assert.Nil(t, err) + assert.Equal(t, int64(1), rank) + rank, err = client.Zrank("key", "value4") + assert.Equal(t, Nil, err) + num, err := client.Zrem("key", "value2", "value3") + assert.Nil(t, err) + assert.Equal(t, 2, num) + ok, err = client.Zadd("key", 6, "value2") + assert.Nil(t, err) + assert.True(t, ok) + ok, err = client.Zadd("key", 7, "value3") + assert.Nil(t, err) + assert.True(t, ok) + ok, err = client.Zadd("key", 8, "value4") + assert.Nil(t, err) + assert.True(t, ok) + num, err = client.Zremrangebyscore("key", 6, 7) + assert.Nil(t, err) + assert.Equal(t, 2, num) + ok, err = client.Zadd("key", 6, "value2") + assert.Nil(t, err) + assert.True(t, ok) + ok, err = client.Zadd("key", 7, "value3") + assert.Nil(t, err) + assert.True(t, ok) + num, err = client.Zcount("key", 6, 7) + assert.Nil(t, err) + assert.Equal(t, 2, num) + num, err = client.Zremrangebyrank("key", 1, 2) + assert.Nil(t, err) + assert.Equal(t, 2, num) + card, err := client.Zcard("key") + assert.Nil(t, err) + assert.Equal(t, 2, card) + vals, err := client.Zrange("key", 0, -1) + assert.Nil(t, err) + assert.EqualValues(t, []string{"value1", "value4"}, 
vals) + vals, err = client.Zrevrange("key", 0, -1) + assert.Nil(t, err) + assert.EqualValues(t, []string{"value4", "value1"}, vals) + pairs, err := client.ZrangeWithScores("key", 0, -1) + assert.Nil(t, err) + assert.EqualValues(t, []Pair{ + { + Key: "value1", + Score: 5, + }, + { + Key: "value4", + Score: 8, + }, + }, pairs) + pairs, err = client.ZrangebyscoreWithScores("key", 5, 8) + assert.Nil(t, err) + assert.EqualValues(t, []Pair{ + { + Key: "value1", + Score: 5, + }, + { + Key: "value4", + Score: 8, + }, + }, pairs) + pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1) + assert.Nil(t, err) + assert.EqualValues(t, []Pair{ + { + Key: "value4", + Score: 8, + }, + }, pairs) + pairs, err = client.ZrevrangebyscoreWithScores("key", 5, 8) + assert.Nil(t, err) + assert.EqualValues(t, []Pair{ + { + Key: "value4", + Score: 8, + }, + { + Key: "value1", + Score: 5, + }, + }, pairs) + pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1) + assert.Nil(t, err) + assert.EqualValues(t, []Pair{ + { + Key: "value1", + Score: 5, + }, + }, pairs) + }) +} + +func TestRedis_Pipelined(t *testing.T) { + runOnRedis(t, func(client *Redis) { + err := client.Pipelined( + func(pipe Pipeliner) error { + pipe.Incr("pipelined_counter") + pipe.Expire("pipelined_counter", time.Hour) + pipe.ZAdd("zadd", Z{Score: 12, Member: "zadd"}) + return nil + }, + ) + assert.Nil(t, err) + ttl, err := client.Ttl("pipelined_counter") + assert.Nil(t, err) + assert.Equal(t, 3600, ttl) + value, err := client.Get("pipelined_counter") + assert.Nil(t, err) + assert.Equal(t, "1", value) + score, err := client.Zscore("zadd", "zadd") + assert.Equal(t, int64(12), score) + }) +} + +func runOnRedis(t *testing.T, fn func(client *Redis)) { + s, err := miniredis.Run() + assert.Nil(t, err) + defer func() { + client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) { + return nil, errors.New("should already exist") + }) + if err != nil { + t.Error(err) + } + + client.Close() + }() + + fn(NewRedis(s.Addr(), NodeType)) +} diff --git a/core/stores/redis/redisblockingnode.go b/core/stores/redis/redisblockingnode.go new file mode 100644 index 00000000..3f377bb4 --- /dev/null +++ b/core/stores/redis/redisblockingnode.go @@ -0,0 +1,66 @@ +package redis + +import ( + "fmt" + + "zero/core/logx" + + red "github.com/go-redis/redis" +) + +type ClosableNode interface { + RedisNode + Close() +} + +func CreateBlockingNode(r *Redis) (ClosableNode, error) { + timeout := readWriteTimeout + blockingQueryTimeout + + switch r.Type { + case NodeType: + client := red.NewClient(&red.Options{ + Addr: r.Addr, + Password: r.Pass, + DB: defaultDatabase, + MaxRetries: maxRetries, + PoolSize: 1, + MinIdleConns: 1, + ReadTimeout: timeout, + }) + return &clientBridge{client}, nil + case ClusterType: + client := red.NewClusterClient(&red.ClusterOptions{ + Addrs: []string{r.Addr}, + Password: r.Pass, + MaxRetries: maxRetries, + PoolSize: 1, + MinIdleConns: 1, + ReadTimeout: timeout, + }) + return &clusterBridge{client}, nil + default: + return nil, fmt.Errorf("unknown redis type: %s", r.Type) + } +} + +type ( + clientBridge struct { + *red.Client + } + + clusterBridge struct { + *red.ClusterClient + } +) + +func (bridge *clientBridge) Close() { + if err := bridge.Client.Close(); err != nil { + logx.Errorf("Error occurred on close redis client: %s", err) + } +} + +func (bridge *clusterBridge) Close() { + if err := bridge.ClusterClient.Close(); err != nil { + logx.Errorf("Error occurred on close redis cluster: %s", err) + } +} 
diff --git a/core/stores/redis/redisclientmanager.go b/core/stores/redis/redisclientmanager.go new file mode 100644 index 00000000..7a0cff37 --- /dev/null +++ b/core/stores/redis/redisclientmanager.go @@ -0,0 +1,36 @@ +package redis + +import ( + "io" + + "zero/core/syncx" + + red "github.com/go-redis/redis" +) + +const ( + defaultDatabase = 0 + maxRetries = 3 + idleConns = 8 +) + +var clientManager = syncx.NewResourceManager() + +func getClient(server, pass string) (*red.Client, error) { + val, err := clientManager.GetResource(server, func() (io.Closer, error) { + store := red.NewClient(&red.Options{ + Addr: server, + Password: pass, + DB: defaultDatabase, + MaxRetries: maxRetries, + MinIdleConns: idleConns, + }) + store.WrapProcess(process) + return store, nil + }) + if err != nil { + return nil, err + } + + return val.(*red.Client), nil +} diff --git a/core/stores/redis/redisclustermanager.go b/core/stores/redis/redisclustermanager.go new file mode 100644 index 00000000..f11f1ab7 --- /dev/null +++ b/core/stores/redis/redisclustermanager.go @@ -0,0 +1,30 @@ +package redis + +import ( + "io" + + "zero/core/syncx" + + red "github.com/go-redis/redis" +) + +var clusterManager = syncx.NewResourceManager() + +func getCluster(server, pass string) (*red.ClusterClient, error) { + val, err := clusterManager.GetResource(server, func() (io.Closer, error) { + store := red.NewClusterClient(&red.ClusterOptions{ + Addrs: []string{server}, + Password: pass, + MaxRetries: maxRetries, + MinIdleConns: idleConns, + }) + store.WrapProcess(process) + + return store, nil + }) + if err != nil { + return nil, err + } + + return val.(*red.ClusterClient), nil +} diff --git a/core/stores/redis/redislock.go b/core/stores/redis/redislock.go new file mode 100644 index 00000000..3d2c718e --- /dev/null +++ b/core/stores/redis/redislock.go @@ -0,0 +1,96 @@ +package redis + +import ( + "math/rand" + "strconv" + "sync/atomic" + "time" + + "zero/core/logx" + + red "github.com/go-redis/redis" +) + +const ( + letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + lockCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then + redis.call("SET", KEYS[1], ARGV[1], "PX", ARGV[2]) + return "OK" +else + return redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2]) +end` + delCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then + return redis.call("DEL", KEYS[1]) +else + return 0 +end` + randomLen = 16 + tolerance = 500 // milliseconds + millisPerSecond = 1000 +) + +type RedisLock struct { + store *Redis + seconds uint32 + key string + id string +} + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func NewRedisLock(store *Redis, key string) *RedisLock { + return &RedisLock{ + store: store, + key: key, + id: randomStr(randomLen), + } +} + +func (rl *RedisLock) Acquire() (bool, error) { + seconds := atomic.LoadUint32(&rl.seconds) + resp, err := rl.store.Eval(lockCommand, []string{rl.key}, []string{ + rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance)}) + if err == red.Nil { + return false, nil + } else if err != nil { + logx.Errorf("Error on acquiring lock for %s, %s", rl.key, err.Error()) + return false, err + } else if resp == nil { + return false, nil + } + + reply, ok := resp.(string) + if ok && reply == "OK" { + return true, nil + } else { + logx.Errorf("Unknown reply when acquiring lock for %s: %v", rl.key, resp) + return false, nil + } +} + +func (rl *RedisLock) Release() (bool, error) { + resp, err := rl.store.Eval(delCommand, []string{rl.key}, []string{rl.id}) + if err != nil { + return 
false, err + } + + if reply, ok := resp.(int64); !ok { + return false, nil + } else { + return reply == 1, nil + } +} + +func (rl *RedisLock) SetExpire(seconds int) { + atomic.StoreUint32(&rl.seconds, uint32(seconds)) +} + +func randomStr(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} diff --git a/core/stores/redis/redislock_test.go b/core/stores/redis/redislock_test.go new file mode 100644 index 00000000..fd36d588 --- /dev/null +++ b/core/stores/redis/redislock_test.go @@ -0,0 +1,34 @@ +package redis + +import ( + "testing" + + "zero/core/stringx" + + "github.com/stretchr/testify/assert" +) + +func TestRedisLock(t *testing.T) { + runOnRedis(t, func(client *Redis) { + key := stringx.Rand() + firstLock := NewRedisLock(client, key) + firstLock.SetExpire(5) + firstAcquire, err := firstLock.Acquire() + assert.Nil(t, err) + assert.True(t, firstAcquire) + + secondLock := NewRedisLock(client, key) + secondLock.SetExpire(5) + againAcquire, err := secondLock.Acquire() + assert.Nil(t, err) + assert.False(t, againAcquire) + + release, err := firstLock.Release() + assert.Nil(t, err) + assert.True(t, release) + + endAcquire, err := secondLock.Acquire() + assert.Nil(t, err) + assert.True(t, endAcquire) + }) +} diff --git a/core/stores/redis/scriptcache.go b/core/stores/redis/scriptcache.go new file mode 100644 index 00000000..2e38f272 --- /dev/null +++ b/core/stores/redis/scriptcache.go @@ -0,0 +1,48 @@ +package redis + +import ( + "sync" + "sync/atomic" +) + +var ( + once sync.Once + lock sync.Mutex + instance *ScriptCache +) + +type ( + Map map[string]string + + ScriptCache struct { + atomic.Value + } +) + +func GetScriptCache() *ScriptCache { + once.Do(func() { + instance = &ScriptCache{} + instance.Store(make(Map)) + }) + + return instance +} + +func (sc *ScriptCache) GetSha(script string) (string, bool) { + cache := sc.Load().(Map) + ret, ok := cache[script] + return ret, ok +} + +func (sc *ScriptCache) SetSha(script, sha string) { + lock.Lock() + defer lock.Unlock() + + cache := sc.Load().(Map) + newCache := make(Map) + for k, v := range cache { + newCache[k] = v + } + newCache[script] = sha + sc.Store(newCache) +} diff --git a/core/stores/sqlc/cachedsql.go b/core/stores/sqlc/cachedsql.go new file mode 100644 index 00000000..da647d7d --- /dev/null +++ b/core/stores/sqlc/cachedsql.go @@ -0,0 +1,122 @@ +package sqlc + +import ( + "database/sql" + "time" + + "zero/core/stores/cache" + "zero/core/stores/internal" + "zero/core/stores/redis" + "zero/core/stores/sqlx" + "zero/core/syncx" +) + +// see doc/sql-cache.md +const cacheSafeGapBetweenIndexAndPrimary = time.Second * 5 + +var ( + ErrNotFound = sqlx.ErrNotFound + + // can't use one SharedCalls per conn, because multiple conns may share the same cache key. 
+ exclusiveCalls = syncx.NewSharedCalls() + stats = internal.NewCacheStat("sqlc") +) + +type ( + ExecFn func(conn sqlx.SqlConn) (sql.Result, error) + IndexQueryFn func(conn sqlx.SqlConn, v interface{}) (interface{}, error) + PrimaryQueryFn func(conn sqlx.SqlConn, v, primary interface{}) error + QueryFn func(conn sqlx.SqlConn, v interface{}) error + + CachedConn struct { + db sqlx.SqlConn + cache internal.Cache + } +) + +func NewNodeConn(db sqlx.SqlConn, rds *redis.Redis, opts ...cache.Option) CachedConn { + return CachedConn{ + db: db, + cache: internal.NewCacheNode(rds, exclusiveCalls, stats, sql.ErrNoRows, opts...), + } +} + +func NewConn(db sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) CachedConn { + return CachedConn{ + db: db, + cache: internal.NewCache(c, exclusiveCalls, stats, sql.ErrNoRows, opts...), + } +} + +func (cc CachedConn) DelCache(keys ...string) error { + return cc.cache.DelCache(keys...) +} + +func (cc CachedConn) GetCache(key string, v interface{}) error { + return cc.cache.GetCache(key, v) +} + +func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) { + res, err := exec(cc.db) + if err != nil { + return nil, err + } + + if err := cc.DelCache(keys...); err != nil { + return nil, err + } + + return res, nil +} + +func (cc CachedConn) ExecNoCache(q string, args ...interface{}) (sql.Result, error) { + return cc.db.Exec(q, args...) +} + +func (cc CachedConn) QueryRow(v interface{}, key string, query QueryFn) error { + return cc.cache.Take(v, key, func(v interface{}) error { + return query(cc.db, v) + }) +} + +func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary interface{}) string, + indexQuery IndexQueryFn, primaryQuery PrimaryQueryFn) error { + var primaryKey interface{} + var found bool + if err := cc.cache.TakeWithExpire(&primaryKey, key, func(val interface{}, expire time.Duration) (err error) { + primaryKey, err = indexQuery(cc.db, v) + if err != nil { + return + } + + found = true + return cc.cache.SetCacheWithExpire(keyer(primaryKey), v, expire+cacheSafeGapBetweenIndexAndPrimary) + }); err != nil { + return err + } + + if found { + return nil + } + + return cc.cache.Take(v, keyer(primaryKey), func(v interface{}) error { + return primaryQuery(cc.db, v, primaryKey) + }) +} + +func (cc CachedConn) QueryRowNoCache(v interface{}, q string, args ...interface{}) error { + return cc.db.QueryRow(v, q, args...) +} + +// QueryRowsNoCache doesn't use cache, because it might cause consistency problem. +func (cc CachedConn) QueryRowsNoCache(v interface{}, q string, args ...interface{}) error { + return cc.db.QueryRows(v, q, args...) 
+} + +func (cc CachedConn) SetCache(key string, v interface{}) error { + return cc.cache.SetCache(key, v) +} + +func (cc CachedConn) Transact(fn func(sqlx.Session) error) error { + return cc.db.Transact(fn) +} diff --git a/core/stores/sqlc/cachedsql_test.go b/core/stores/sqlc/cachedsql_test.go new file mode 100644 index 00000000..2665d90d --- /dev/null +++ b/core/stores/sqlc/cachedsql_test.go @@ -0,0 +1,508 @@ +package sqlc + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "log" + "os" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "zero/core/logx" + "zero/core/stat" + "zero/core/stores/cache" + "zero/core/stores/redis" + "zero/core/stores/sqlx" + + "github.com/alicebob/miniredis" + "github.com/stretchr/testify/assert" +) + +func init() { + logx.Disable() + stat.SetReporter(nil) +} + +func TestCachedConn_GetCache(t *testing.T) { + resetStats() + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) + var value string + err = c.GetCache("any", &value) + assert.Equal(t, ErrNotFound, err) + s.Set("any", `"value"`) + err = c.GetCache("any", &value) + assert.Nil(t, err) + assert.Equal(t, "value", value) +} + +func TestStat(t *testing.T) { + resetStats() + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) + + for i := 0; i < 10; i++ { + var str string + err = c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error { + *v.(*string) = "zero" + return nil + }) + if err != nil { + t.Error(err) + } + } + + assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total)) + assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit)) +} + +func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) { + resetStats() + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) + + var str string + err = c.QueryRowIndex(&str, "index", func(s interface{}) string { + return fmt.Sprintf("%s/1234", s) + }, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) { + *v.(*string) = "zero" + return "primary", nil + }, func(conn sqlx.SqlConn, v, pri interface{}) error { + assert.Equal(t, "primary", pri) + *v.(*string) = "xin" + return nil + }) + assert.Nil(t, err) + assert.Equal(t, "zero", str) + val, err := r.Get("index") + assert.Nil(t, err) + assert.Equal(t, `"primary"`, val) + val, err = r.Get("primary/1234") + assert.Nil(t, err) + assert.Equal(t, `"zero"`, val) +} + +func TestCachedConn_QueryRowIndex_HasCache(t *testing.T) { + resetStats() + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10), + cache.WithNotFoundExpiry(time.Second)) + + var str string + r.Set("index", `"primary"`) + err = c.QueryRowIndex(&str, "index", func(s interface{}) string { + return fmt.Sprintf("%s/1234", s) + }, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) { + assert.Fail(t, "should not go here") + return "primary", nil + }, func(conn sqlx.SqlConn, v, primary interface{}) error { + *v.(*string) = "xin" + assert.Equal(t, "primary", primary) + return nil + }) + assert.Nil(t, err) + assert.Equal(t, "xin", str) + val, err := r.Get("index") + assert.Nil(t, 
err) + assert.Equal(t, `"primary"`, val) + val, err = r.Get("primary/1234") + assert.Nil(t, err) + assert.Equal(t, `"xin"`, val) +} + +func TestCachedConn_QueryRowIndex_HasWrongCache(t *testing.T) { + caches := map[string]string{ + "index": "primary", + "primary/1234": "xin", + } + + for k, v := range caches { + t.Run(k+"/"+v, func(t *testing.T) { + resetStats() + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10), + cache.WithNotFoundExpiry(time.Second)) + + var str string + r.Set(k, v) + err = c.QueryRowIndex(&str, "index", func(s interface{}) string { + return fmt.Sprintf("%s/1234", s) + }, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) { + *v.(*string) = "xin" + return "primary", nil + }, func(conn sqlx.SqlConn, v, primary interface{}) error { + *v.(*string) = "xin" + assert.Equal(t, "primary", primary) + return nil + }) + assert.Nil(t, err) + assert.Equal(t, "xin", str) + val, err := r.Get("index") + assert.Nil(t, err) + assert.Equal(t, `"primary"`, val) + val, err = r.Get("primary/1234") + assert.Nil(t, err) + assert.Equal(t, `"xin"`, val) + }) + } +} + +func TestStatCacheFails(t *testing.T) { + resetStats() + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stdout) + + r := redis.NewRedis("localhost:59999", redis.NodeType) + c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) + + for i := 0; i < 20; i++ { + var str string + err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error { + return errors.New("db failed") + }) + assert.NotNil(t, err) + } + + assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total)) + assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit)) + assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Miss)) + assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.DbFails)) +} + +func TestStatDbFails(t *testing.T) { + resetStats() + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) + + for i := 0; i < 20; i++ { + var str string + err = c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error { + return errors.New("db failed") + }) + assert.NotNil(t, err) + } + + assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total)) + assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit)) + assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.DbFails)) +} + +func TestStatFromMemory(t *testing.T) { + resetStats() + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) + + var all sync.WaitGroup + var wait sync.WaitGroup + all.Add(10) + wait.Add(4) + go func() { + var str string + err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error { + *v.(*string) = "zero" + return nil + }) + if err != nil { + t.Error(err) + } + wait.Wait() + runtime.Gosched() + all.Done() + }() + + for i := 0; i < 4; i++ { + go func() { + var str string + wait.Done() + err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error { + *v.(*string) = "zero" + return nil + }) + if err != nil { + t.Error(err) + } + all.Done() + }() + } + for i := 0; i < 5; i++ { + go func() { + var str string + err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error { + *v.(*string) = "zero" + return nil + 
}) + if err != nil { + t.Error(err) + } + all.Done() + }() + } + all.Wait() + + assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total)) + assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit)) +} + +func TestCachedConnQueryRow(t *testing.T) { + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + const ( + key = "user" + value = "any" + ) + var conn trackedConn + var user string + var ran bool + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30)) + err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error { + ran = true + user = value + return nil + }) + assert.Nil(t, err) + actualValue, err := s.Get(key) + assert.Nil(t, err) + var actual string + assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual)) + assert.Equal(t, value, actual) + assert.Equal(t, value, user) + assert.True(t, ran) +} + +func TestCachedConnQueryRowFromCache(t *testing.T) { + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + const ( + key = "user" + value = "any" + ) + var conn trackedConn + var user string + var ran bool + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30)) + assert.Nil(t, c.SetCache(key, value)) + err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error { + ran = true + user = value + return nil + }) + assert.Nil(t, err) + actualValue, err := s.Get(key) + assert.Nil(t, err) + var actual string + assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual)) + assert.Equal(t, value, actual) + assert.Equal(t, value, user) + assert.False(t, ran) +} + +func TestQueryRowNotFound(t *testing.T) { + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + const key = "user" + var conn trackedConn + var user string + var ran int + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30)) + for i := 0; i < 20; i++ { + err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error { + ran++ + return sql.ErrNoRows + }) + assert.Exactly(t, sqlx.ErrNotFound, err) + } + assert.Equal(t, 1, ran) +} + +func TestCachedConnExec(t *testing.T) { + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + var conn trackedConn + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10)) + _, err = c.ExecNoCache("delete from user_table where id='kevin'") + assert.Nil(t, err) + assert.True(t, conn.execValue) +} + +func TestCachedConnExecDropCache(t *testing.T) { + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + const ( + key = "user" + value = "any" + ) + var conn trackedConn + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30)) + assert.Nil(t, c.SetCache(key, value)) + _, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) { + return conn.Exec("delete from user_table where id='kevin'") + }, key) + assert.Nil(t, err) + assert.True(t, conn.execValue) + _, err = s.Get(key) + assert.Exactly(t, miniredis.ErrKeyNotFound, err) +} + +func TestCachedConnExecDropCacheFailed(t *testing.T) { + const key = "user" + var conn trackedConn + r := redis.NewRedis("anyredis:8888", redis.NodeType) + c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10)) + _, err := c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) { + return conn.Exec("delete from user_table where id='kevin'") + }, key) + // async background clean, retry logic + 
assert.Nil(t, err) +} + +func TestCachedConnQueryRows(t *testing.T) { + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + var conn trackedConn + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10)) + var users []string + err = c.QueryRowsNoCache(&users, "select user from user_table where id='kevin'") + assert.Nil(t, err) + assert.True(t, conn.queryRowsValue) +} + +func TestCachedConnTransact(t *testing.T) { + s, err := miniredis.Run() + if err != nil { + t.Error(err) + } + + var conn trackedConn + r := redis.NewRedis(s.Addr(), redis.NodeType) + c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10)) + err = c.Transact(func(session sqlx.Session) error { + return nil + }) + assert.Nil(t, err) + assert.True(t, conn.transactValue) +} + +func resetStats() { + atomic.StoreUint64(&stats.Total, 0) + atomic.StoreUint64(&stats.Hit, 0) + atomic.StoreUint64(&stats.Miss, 0) + atomic.StoreUint64(&stats.DbFails, 0) +} + +type dummySqlConn struct { +} + +func (d dummySqlConn) Exec(query string, args ...interface{}) (sql.Result, error) { + return nil, nil +} + +func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) { + return nil, nil +} + +func (d dummySqlConn) QueryRow(v interface{}, query string, args ...interface{}) error { + return nil +} + +func (d dummySqlConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error { + return nil +} + +func (d dummySqlConn) QueryRows(v interface{}, query string, args ...interface{}) error { + return nil +} + +func (d dummySqlConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error { + return nil +} + +func (d dummySqlConn) Transact(func(session sqlx.Session) error) error { + return nil +} + +type trackedConn struct { + dummySqlConn + execValue bool + queryRowsValue bool + transactValue bool +} + +func (c *trackedConn) Exec(query string, args ...interface{}) (sql.Result, error) { + c.execValue = true + return c.dummySqlConn.Exec(query, args...) +} + +func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}) error { + c.queryRowsValue = true + return c.dummySqlConn.QueryRows(v, query, args...) 
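
A minimal usage sketch (not part of this commit) for the cached connection exercised by the tests above: QueryRow is a read-through against redis with a database fallback, and Exec invalidates the listed cache keys after a write. QueryRowIndex, also tested above, adds an index-key to primary-key indirection on top of the same pattern. The import paths for the cached-connection package (assumed here as sqlc) and the cache-option package are not shown in this diff and are taken from the repository layout; the DSN, redis address, and table are placeholders.

package main

import (
	"database/sql"
	"fmt"
	"time"

	"zero/core/stores/cache" // assumed path for cache.WithExpiry
	"zero/core/stores/redis"
	"zero/core/stores/sqlc" // assumed package of NewNodeConn
	"zero/core/stores/sqlx"
)

func main() {
	mysql := sqlx.NewMysql("user:pass@tcp(127.0.0.1:3306)/appdb") // placeholder DSN
	rds := redis.NewRedis("127.0.0.1:6379", redis.NodeType)
	conn := sqlc.NewNodeConn(mysql, rds, cache.WithExpiry(time.Hour))

	// Read-through: the closure runs only on a cache miss and its result is
	// cached under the given key; sql.ErrNoRows is cached as a not-found
	// placeholder so repeated misses skip the database (see TestQueryRowNotFound).
	var name string
	err := conn.QueryRow(&name, "user#kevin", func(db sqlx.SqlConn, v interface{}) error {
		return db.QueryRow(v, "select name from user_table where id=?", "kevin")
	})
	fmt.Println(name, err)

	// Write path: Exec runs the statement, then deletes the listed cache keys
	// so the next read repopulates them.
	_, err = conn.Exec(func(db sqlx.SqlConn) (sql.Result, error) {
		return db.Exec("update user_table set name=? where id=?", "kevin", "kevin")
	}, "user#kevin")
	fmt.Println(err)
}
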
+} + +func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error { + c.transactValue = true + return c.dummySqlConn.Transact(fn) +} diff --git a/core/stores/sqlx/bulkinserter.go b/core/stores/sqlx/bulkinserter.go new file mode 100644 index 00000000..88ece15e --- /dev/null +++ b/core/stores/sqlx/bulkinserter.go @@ -0,0 +1,187 @@ +package sqlx + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "zero/core/executors" + "zero/core/logx" + "zero/core/stringx" +) + +const ( + flushInterval = time.Second + maxBulkRows = 1000 + valuesKeyword = "values" +) + +var emptyBulkStmt bulkStmt + +type ( + ResultHandler func(sql.Result, error) + + BulkInserter struct { + executor *executors.PeriodicalExecutor + inserter *dbInserter + stmt bulkStmt + } + + bulkStmt struct { + prefix string + valueFormat string + suffix string + } +) + +func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) { + bkStmt, err := parseInsertStmt(stmt) + if err != nil { + return nil, err + } + + inserter := &dbInserter{ + sqlConn: sqlConn, + stmt: bkStmt, + } + + return &BulkInserter{ + executor: executors.NewPeriodicalExecutor(flushInterval, inserter), + inserter: inserter, + stmt: bkStmt, + }, nil +} + +func (bi *BulkInserter) Flush() { + bi.executor.Flush() +} + +func (bi *BulkInserter) Insert(args ...interface{}) error { + value, err := format(bi.stmt.valueFormat, args...) + if err != nil { + return err + } + + bi.executor.Add(value) + + return nil +} + +func (bi *BulkInserter) SetResultHandler(handler ResultHandler) { + bi.executor.Sync(func() { + bi.inserter.resultHandler = handler + }) +} + +func (bi *BulkInserter) UpdateOrDelete(fn func()) { + bi.executor.Flush() + fn() +} + +func (bi *BulkInserter) UpdateStmt(stmt string) error { + bkStmt, err := parseInsertStmt(stmt) + if err != nil { + return err + } + + bi.executor.Flush() + bi.executor.Sync(func() { + bi.inserter.stmt = bkStmt + }) + + return nil +} + +type dbInserter struct { + sqlConn SqlConn + stmt bulkStmt + values []string + resultHandler ResultHandler +} + +func (in *dbInserter) AddTask(task interface{}) bool { + in.values = append(in.values, task.(string)) + return len(in.values) >= maxBulkRows +} + +func (in *dbInserter) Execute(bulk interface{}) { + values := bulk.([]string) + if len(values) == 0 { + return + } + + stmtWithoutValues := in.stmt.prefix + valuesStr := strings.Join(values, ", ") + stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ") + if len(in.stmt.suffix) > 0 { + stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ") + } + result, err := in.sqlConn.Exec(stmt) + if in.resultHandler != nil { + in.resultHandler(result, err) + } else if err != nil { + logx.Errorf("sql: %s, error: %s", stmt, err) + } +} + +func (in *dbInserter) RemoveAll() interface{} { + values := in.values + in.values = nil + return values +} + +func parseInsertStmt(stmt string) (bulkStmt, error) { + lower := strings.ToLower(stmt) + pos := strings.Index(lower, valuesKeyword) + if pos <= 0 { + return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt) + } + + var columns int + right := strings.LastIndexByte(lower[:pos], ')') + if right > 0 { + left := strings.LastIndexByte(lower[:right], '(') + if left > 0 { + values := lower[left+1 : right] + values = stringx.Filter(values, func(r rune) bool { + return r == ' ' || r == '\t' || r == '\r' || r == '\n' + }) + fields := strings.FieldsFunc(values, func(r rune) bool { + return r == ',' + }) + columns = len(fields) + } + } + + var variables int + var valueFormat string + var suffix 
string + left := strings.IndexByte(lower[pos:], '(') + if left > 0 { + right = strings.IndexByte(lower[pos+left:], ')') + if right > 0 { + values := lower[pos+left : pos+left+right] + for _, x := range values { + if x == '?' { + variables++ + } + } + valueFormat = stmt[pos+left : pos+left+right+1] + suffix = strings.TrimSpace(stmt[pos+left+right+1:]) + } + } + + if variables == 0 { + return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt) + } + if columns > 0 && columns != variables { + return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt) + } + + return bulkStmt{ + prefix: stmt[:pos+len(valuesKeyword)], + valueFormat: valueFormat, + suffix: suffix, + }, nil +} diff --git a/core/stores/sqlx/bulkinserter_test.go b/core/stores/sqlx/bulkinserter_test.go new file mode 100644 index 00000000..f0c2c287 --- /dev/null +++ b/core/stores/sqlx/bulkinserter_test.go @@ -0,0 +1,98 @@ +package sqlx + +import ( + "database/sql" + "strconv" + "testing" + + "zero/core/logx" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) + +type mockedConn struct { + query string + args []interface{} +} + +func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) { + c.query = query + c.args = args + return nil, nil +} + +func (c *mockedConn) Prepare(query string) (StmtSession, error) { + panic("should not called") +} + +func (c *mockedConn) QueryRow(v interface{}, query string, args ...interface{}) error { + panic("should not called") +} + +func (c *mockedConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error { + panic("should not called") +} + +func (c *mockedConn) QueryRows(v interface{}, query string, args ...interface{}) error { + panic("should not called") +} + +func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error { + panic("should not called") +} + +func (c *mockedConn) Transact(func(session Session) error) error { + panic("should not called") +} + +func TestBulkInserter(t *testing.T) { + runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var conn mockedConn + inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`) + assert.Nil(t, err) + for i := 0; i < 5; i++ { + assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i)) + } + inserter.Flush() + assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+ + `('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+ + `('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`, + conn.query) + assert.Nil(t, conn.args) + }) +} + +func TestBulkInserterSuffix(t *testing.T) { + runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var conn mockedConn + inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+ + `(?, ?, ?) 
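
A usage sketch (not part of this commit) for the BulkInserter defined in bulkinserter.go above. Each Insert only formats and buffers one VALUES tuple; the combined statement is executed when 1000 rows accumulate or the one-second timer fires, and Flush forces the buffered rows out. The sqlx import path follows this repository's layout; the DSN and table are placeholders.

package main

import (
	"database/sql"
	"log"

	"zero/core/stores/sqlx"
)

func main() {
	conn := sqlx.NewMysql("user:pass@tcp(127.0.0.1:3306)/appdb") // placeholder DSN
	inserter, err := sqlx.NewBulkInserter(conn,
		"INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)")
	if err != nil {
		log.Fatal(err)
	}

	// Optional: observe the result of each batched INSERT instead of the
	// default error logging.
	inserter.SetResultHandler(func(result sql.Result, err error) {
		if err != nil {
			log.Printf("bulk insert failed: %v", err)
		}
	})

	// Buffer a few rows; they are written as one multi-row INSERT.
	for i := 0; i < 5; i++ {
		if err := inserter.Insert("class_a", "user_b", i); err != nil {
			log.Fatal(err)
		}
	}

	// Force the buffered rows out before exiting.
	inserter.Flush()
}
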
ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`) + assert.Nil(t, err) + for i := 0; i < 5; i++ { + assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i)) + } + inserter.Flush() + assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+ + `('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+ + `('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`, + conn.query) + assert.Nil(t, conn.args) + }) +} + +func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) { + logx.Disable() + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + fn(db, mock) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/core/stores/sqlx/mysql.go b/core/stores/sqlx/mysql.go new file mode 100644 index 00000000..c621415e --- /dev/null +++ b/core/stores/sqlx/mysql.go @@ -0,0 +1,37 @@ +package sqlx + +import "github.com/go-sql-driver/mysql" + +const ( + mysqlDriverName = "mysql" + duplicateEntryCode uint16 = 1062 +) + +func NewMysql(datasource string, opts ...SqlOption) SqlConn { + opts = append(opts, withMysqlAcceptable()) + return NewSqlConn(mysqlDriverName, datasource, opts...) +} + +func mysqlAcceptable(err error) bool { + if err == nil { + return true + } + + myerr, ok := err.(*mysql.MySQLError) + if !ok { + return false + } + + switch myerr.Number { + case duplicateEntryCode: + return true + default: + return false + } +} + +func withMysqlAcceptable() SqlOption { + return func(conn *commonSqlConn) { + conn.accept = mysqlAcceptable + } +} diff --git a/core/stores/sqlx/mysql_test.go b/core/stores/sqlx/mysql_test.go new file mode 100644 index 00000000..5efe98c0 --- /dev/null +++ b/core/stores/sqlx/mysql_test.go @@ -0,0 +1,56 @@ +package sqlx + +import ( + "testing" + + "zero/core/breaker" + "zero/core/logx" + "zero/core/stat" + + "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" +) + +func init() { + stat.SetReporter(nil) +} + +func TestBreakerOnDuplicateEntry(t *testing.T) { + logx.Disable() + + err := tryOnDuplicateEntryError(t, mysqlAcceptable) + assert.Equal(t, duplicateEntryCode, err.(*mysql.MySQLError).Number) +} + +func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) { + logx.Disable() + + var found bool + for i := 0; i < 100; i++ { + if tryOnDuplicateEntryError(t, nil) == breaker.ErrServiceUnavailable { + found = true + } + } + assert.True(t, found) +} + +func tryOnDuplicateEntryError(t *testing.T, accept func(error) bool) error { + logx.Disable() + + conn := commonSqlConn{ + brk: breaker.NewBreaker(), + accept: accept, + } + for i := 0; i < 1000; i++ { + assert.NotNil(t, conn.brk.DoWithAcceptable(func() error { + return &mysql.MySQLError{ + Number: duplicateEntryCode, + } + }, conn.acceptable)) + } + return conn.brk.DoWithAcceptable(func() error { + return &mysql.MySQLError{ + Number: duplicateEntryCode, + } + }, conn.acceptable) +} diff --git a/core/stores/sqlx/orm.go b/core/stores/sqlx/orm.go new file mode 100644 index 00000000..83a3a4bd --- /dev/null +++ b/core/stores/sqlx/orm.go @@ -0,0 +1,254 @@ +package sqlx + +import ( + "errors" + "reflect" + "strings" + + "zero/core/mapping" +) + +const tagName = "db" + +var ( + ErrNotMatchDestination = errors.New("not matching destination to scan") + ErrNotReadableValue = errors.New("value not 
addressable or interfaceable") + ErrNotSettable = errors.New("passed in variable is not settable") + ErrUnsupportedValueType = errors.New("unsupported unmarshal type") +) + +type rowsScanner interface { + Columns() ([]string, error) + Err() error + Next() bool + Scan(v ...interface{}) error +} + +func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) { + rt := mapping.Deref(v.Type()) + size := rt.NumField() + result := make(map[string]interface{}, size) + + for i := 0; i < size; i++ { + key := parseTagName(rt.Field(i)) + if len(key) == 0 { + return nil, nil + } + + valueField := reflect.Indirect(v).Field(i) + switch valueField.Kind() { + case reflect.Ptr: + if !valueField.CanInterface() { + return nil, ErrNotReadableValue + } + if valueField.IsNil() { + baseValueType := mapping.Deref(valueField.Type()) + valueField.Set(reflect.New(baseValueType)) + } + result[key] = valueField.Interface() + default: + if !valueField.CanAddr() || !valueField.Addr().CanInterface() { + return nil, ErrNotReadableValue + } + result[key] = valueField.Addr().Interface() + } + } + + return result, nil +} + +func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) { + fields := unwrapFields(v) + if strict && len(columns) < len(fields) { + return nil, ErrNotMatchDestination + } + + taggedMap, err := getTaggedFieldValueMap(v) + if err != nil { + return nil, err + } + + values := make([]interface{}, len(columns)) + if len(taggedMap) == 0 { + for i := 0; i < len(values); i++ { + valueField := fields[i] + switch valueField.Kind() { + case reflect.Ptr: + if !valueField.CanInterface() { + return nil, ErrNotReadableValue + } + if valueField.IsNil() { + baseValueType := mapping.Deref(valueField.Type()) + valueField.Set(reflect.New(baseValueType)) + } + values[i] = valueField.Interface() + default: + if !valueField.CanAddr() || !valueField.Addr().CanInterface() { + return nil, ErrNotReadableValue + } + values[i] = valueField.Addr().Interface() + } + } + } else { + for i, column := range columns { + if tagged, ok := taggedMap[column]; ok { + values[i] = tagged + } else { + var anonymous interface{} + values[i] = &anonymous + } + } + } + + return values, nil +} + +func parseTagName(field reflect.StructField) string { + key := field.Tag.Get(tagName) + if len(key) == 0 { + return "" + } else { + options := strings.Split(key, ",") + return options[0] + } +} + +func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error { + if !scanner.Next() { + if err := scanner.Err(); err != nil { + return err + } + return ErrNotFound + } + + rv := reflect.ValueOf(v) + if err := mapping.ValidatePtr(&rv); err != nil { + return err + } + + rte := reflect.TypeOf(v).Elem() + rve := rv.Elem() + switch rte.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, + reflect.String: + if rve.CanSet() { + return scanner.Scan(v) + } else { + return ErrNotSettable + } + case reflect.Struct: + columns, err := scanner.Columns() + if err != nil { + return err + } + if values, err := mapStructFieldsIntoSlice(rve, columns, strict); err != nil { + return err + } else { + return scanner.Scan(values...) 
+ } + default: + return ErrUnsupportedValueType + } +} + +func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error { + rv := reflect.ValueOf(v) + if err := mapping.ValidatePtr(&rv); err != nil { + return err + } + + rt := reflect.TypeOf(v) + rte := rt.Elem() + rve := rv.Elem() + switch rte.Kind() { + case reflect.Slice: + if rve.CanSet() { + ptr := rte.Elem().Kind() == reflect.Ptr + appendFn := func(item reflect.Value) { + if ptr { + rve.Set(reflect.Append(rve, item)) + } else { + rve.Set(reflect.Append(rve, reflect.Indirect(item))) + } + } + fillFn := func(value interface{}) error { + if rve.CanSet() { + if err := scanner.Scan(value); err != nil { + return err + } else { + appendFn(reflect.ValueOf(value)) + return nil + } + } + return ErrNotSettable + } + + base := mapping.Deref(rte.Elem()) + switch base.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, + reflect.String: + for scanner.Next() { + value := reflect.New(base) + if err := fillFn(value.Interface()); err != nil { + return err + } + } + case reflect.Struct: + columns, err := scanner.Columns() + if err != nil { + return err + } + + for scanner.Next() { + value := reflect.New(base) + if values, err := mapStructFieldsIntoSlice(value, columns, strict); err != nil { + return err + } else { + if err := scanner.Scan(values...); err != nil { + return err + } else { + appendFn(value) + } + } + } + default: + return ErrUnsupportedValueType + } + + return nil + } else { + return ErrNotSettable + } + default: + return ErrUnsupportedValueType + } +} + +func unwrapFields(v reflect.Value) []reflect.Value { + var fields []reflect.Value + indirect := reflect.Indirect(v) + + for i := 0; i < indirect.NumField(); i++ { + child := indirect.Field(i) + if child.Kind() == reflect.Ptr && child.IsNil() { + baseValueType := mapping.Deref(child.Type()) + child.Set(reflect.New(baseValueType)) + } + + child = reflect.Indirect(child) + childType := indirect.Type().Field(i) + if child.Kind() == reflect.Struct && childType.Anonymous { + fields = append(fields, unwrapFields(child)...) 
+ } else { + fields = append(fields, child) + } + } + + return fields +} diff --git a/core/stores/sqlx/orm_test.go b/core/stores/sqlx/orm_test.go new file mode 100644 index 00000000..e35b729c --- /dev/null +++ b/core/stores/sqlx/orm_test.go @@ -0,0 +1,973 @@ +package sqlx + +import ( + "database/sql" + "testing" + + "zero/core/logx" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) + +func TestUnmarshalRowBool(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value bool + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.True(t, value) + }) +} + +func TestUnmarshalRowInt(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value int + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, 2, value) + }) +} + +func TestUnmarshalRowInt8(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value int8 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, int8(3), value) + }) +} + +func TestUnmarshalRowInt16(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value int16 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.Equal(t, int16(4), value) + }) +} + +func TestUnmarshalRowInt32(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value int32 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.Equal(t, int32(5), value) + }) +} + +func TestUnmarshalRowInt64(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value int64 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, int64(6), value) + }) +} + +func TestUnmarshalRowUint(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value uint + 
assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, uint(2), value) + }) +} + +func TestUnmarshalRowUint8(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value uint8 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, uint8(3), value) + }) +} + +func TestUnmarshalRowUint16(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value uint16 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, uint16(4), value) + }) +} + +func TestUnmarshalRowUint32(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value uint32 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, uint32(5), value) + }) +} + +func TestUnmarshalRowUint64(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value uint64 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, uint16(6), value) + }) +} + +func TestUnmarshalRowFloat32(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value float32 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, float32(7), value) + }) +} + +func TestUnmarshalRowFloat64(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value float64 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, float64(8), value) + }) +} + +func TestUnmarshalRowString(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + const expect = "hello" + rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect) + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value string + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&value, rows, true) + }, "select value from 
users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowStruct(t *testing.T) { + var value = new(struct { + Name string + Age int + }) + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(value, rows, true) + }, "select name, age from users where user=?", "anyone")) + assert.Equal(t, "liao", value.Name) + assert.Equal(t, 5, value.Age) + }) +} + +func TestUnmarshalRowStructWithTags(t *testing.T) { + var value = new(struct { + Age int `db:"age"` + Name string `db:"name"` + }) + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(value, rows, true) + }, "select name, age from users where user=?", "anyone")) + assert.Equal(t, "liao", value.Name) + assert.Equal(t, 5, value.Age) + }) +} + +func TestUnmarshalRowsBool(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []bool{true, false} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []bool + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsInt(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []int{2, 3} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []int + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsInt8(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []int8{2, 3} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []int8 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsInt16(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []int16{2, 3} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []int16 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsInt32(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []int32{2, 3} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where 
user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []int32 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsInt64(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []int64{2, 3} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []int64 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsUint(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []uint{2, 3} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []uint + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsUint8(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []uint8{2, 3} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []uint8 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsUint16(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []uint16{2, 3} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []uint16 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsUint32(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []uint32{2, 3} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []uint32 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsUint64(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []uint64{2, 3} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []uint64 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsFloat32(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []float32{2, 3} + rs := 
sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []float32 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsFloat64(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []float64{2, 3} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []float64 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsString(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []string{"hello", "world"} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []string + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsBoolPtr(t *testing.T) { + yes := true + no := false + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*bool{&yes, &no} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*bool + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsIntPtr(t *testing.T) { + two := 2 + three := 3 + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*int{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*int + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsInt8Ptr(t *testing.T) { + two := int8(2) + three := int8(3) + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*int8{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*int8 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsInt16Ptr(t *testing.T) { + two := int16(2) + three := int16(3) + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*int16{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*int16 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return 
unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsInt32Ptr(t *testing.T) { + two := int32(2) + three := int32(3) + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*int32{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*int32 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsInt64Ptr(t *testing.T) { + two := int64(2) + three := int64(3) + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*int64{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*int64 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsUintPtr(t *testing.T) { + two := uint(2) + three := uint(3) + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*uint{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*uint + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsUint8Ptr(t *testing.T) { + two := uint8(2) + three := uint8(3) + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*uint8{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*uint8 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsUint16Ptr(t *testing.T) { + two := uint16(2) + three := uint16(3) + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*uint16{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*uint16 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsUint32Ptr(t *testing.T) { + two := uint32(2) + three := uint32(3) + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*uint32{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*uint32 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + 
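
The tag mapping covered by the tests above is what callers see through SqlConn.QueryRow and QueryRows; a small sketch (not part of this commit), assuming the zero/core/stores/sqlx import path, a placeholder DSN, and a hypothetical users table:

package main

import (
	"fmt"
	"log"

	"zero/core/stores/sqlx"
)

// User maps columns to fields via `db` tags, as exercised by orm.go above.
type User struct {
	ID   int64  `db:"id"`
	Name string `db:"name"`
	Age  int    `db:"age"`
}

func main() {
	conn := sqlx.NewMysql("user:pass@tcp(127.0.0.1:3306)/appdb") // placeholder DSN

	var u User
	// QueryRow scans exactly one row and returns sqlx.ErrNotFound
	// when the result set is empty.
	if err := conn.QueryRow(&u, "select id, name, age from users where id=?", 1); err != nil {
		log.Fatal(err)
	}
	fmt.Println(u)

	var users []User
	// QueryRows appends one element per row into the destination slice.
	if err := conn.QueryRows(&users, "select id, name, age from users where age>?", 18); err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(users))
}
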
+func TestUnmarshalRowsUint64Ptr(t *testing.T) { + two := uint64(2) + three := uint64(3) + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*uint64{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*uint64 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsFloat32Ptr(t *testing.T) { + two := float32(2) + three := float32(3) + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*float32{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*float32 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsFloat64Ptr(t *testing.T) { + two := float64(2) + three := float64(3) + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*float64{&two, &three} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*float64 + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsStringPtr(t *testing.T) { + hello := "hello" + world := "world" + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + var expect = []*string{&hello, &world} + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value []*string + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select value from users where user=?", "anyone")) + assert.EqualValues(t, expect, value) + }) +} + +func TestUnmarshalRowsStruct(t *testing.T) { + var expect = []struct { + Name string + Age int64 + }{ + { + Name: "first", + Age: 2, + }, + { + Name: "second", + Age: 3, + }, + } + var value []struct { + Name string + Age int64 + } + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, age from users where user=?", "anyone")) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.Age, value[i].Age) + } + }) +} + +func TestUnmarshalRowsStructWithNullStringType(t *testing.T) { + var expect = []struct { + Name string + NullString sql.NullString + }{ + { + Name: "first", + NullString: sql.NullString{ + String: "firstnullstring", + Valid: true, + }, + }, + { + Name: "second", + NullString: sql.NullString{ + String: "", + Valid: false, + }, + }, + } + var value []struct { + Name string `db:"name"` + NullString sql.NullString `db:"value"` + } + + runOrmTest(t, func(db *sql.DB, mock 
sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "value"}).AddRow( + "first", "firstnullstring").AddRow("second", nil) + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, age from users where user=?", "anyone")) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.NullString.String, value[i].NullString.String) + assert.Equal(t, each.NullString.Valid, value[i].NullString.Valid) + } + }) +} + +func TestUnmarshalRowsStructWithTags(t *testing.T) { + var expect = []struct { + Name string + Age int64 + }{ + { + Name: "first", + Age: 2, + }, + { + Name: "second", + Age: 3, + }, + } + var value []struct { + Age int64 `db:"age"` + Name string `db:"name"` + } + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, age from users where user=?", "anyone")) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.Age, value[i].Age) + } + }) +} + +func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) { + type Embed struct { + Value int64 `db:"value"` + } + + var expect = []struct { + Name string + Age int64 + Value int64 + }{ + { + Name: "first", + Age: 2, + Value: 3, + }, + { + Name: "second", + Age: 3, + Value: 4, + }, + } + var value []struct { + Name string `db:"name"` + Age int64 `db:"age"` + Embed + } + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, age, value from users where user=?", "anyone")) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.Age, value[i].Age) + assert.Equal(t, each.Value, value[i].Value) + } + }) +} + +func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T) { + type Embed struct { + Value int64 `db:"value"` + } + + var expect = []struct { + Name string + Age int64 + Value int64 + }{ + { + Name: "first", + Age: 2, + Value: 3, + }, + { + Name: "second", + Age: 3, + Value: 4, + }, + } + var value []struct { + Name string `db:"name"` + Age int64 `db:"age"` + *Embed + } + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, age, value from users where user=?", "anyone")) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.Age, value[i].Age) + assert.Equal(t, each.Value, value[i].Value) + } + }) +} + +func TestUnmarshalRowsStructPtr(t *testing.T) { + var expect = []*struct { + Name string + Age int64 + }{ + { + Name: "first", + Age: 2, + }, + { + Name: "second", 
+ Age: 3, + }, + } + var value []*struct { + Name string + Age int64 + } + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, age from users where user=?", "anyone")) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.Age, value[i].Age) + } + }) +} + +func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) { + var expect = []*struct { + Name string + Age int64 + }{ + { + Name: "first", + Age: 2, + }, + { + Name: "second", + Age: 3, + }, + } + var value []*struct { + Age int64 `db:"age"` + Name string `db:"name"` + } + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, age from users where user=?", "anyone")) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.Age, value[i].Age) + } + }) +} + +func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) { + var expect = []*struct { + Name string + Age int64 + }{ + { + Name: "first", + Age: 2, + }, + { + Name: "second", + Age: 3, + }, + } + var value []*struct { + Age *int64 `db:"age"` + Name string `db:"name"` + } + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, age from users where user=?", "anyone")) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.Age, *value[i].Age) + } + }) +} + +func TestCommonSqlConn_QueryRowOptional(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var r struct { + User string `db:"user"` + Age int `db:"age"` + } + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(&r, rows, false) + }, "select age from users where user=?", "anyone")) + assert.Empty(t, r.User) + assert.Equal(t, 5, r.Age) + }) +} + +func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) { + logx.Disable() + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + fn(db, mock) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go new file mode 100644 index 00000000..26b256cd --- /dev/null +++ b/core/stores/sqlx/sqlconn.go @@ -0,0 +1,204 @@ +package sqlx + +import ( + "database/sql" + + "zero/core/breaker" +) + +var ErrNotFound = sql.ErrNoRows + +type ( + // Session stands for raw connections or transaction sessions + Session interface { + 
Exec(query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (StmtSession, error) + QueryRow(v interface{}, query string, args ...interface{}) error + QueryRowPartial(v interface{}, query string, args ...interface{}) error + QueryRows(v interface{}, query string, args ...interface{}) error + QueryRowsPartial(v interface{}, query string, args ...interface{}) error + } + + // SqlConn only stands for raw connections, so Transact method can be called. + SqlConn interface { + Session + Transact(func(session Session) error) error + } + + SqlOption func(*commonSqlConn) + + StmtSession interface { + Close() error + Exec(args ...interface{}) (sql.Result, error) + QueryRow(v interface{}, args ...interface{}) error + QueryRowPartial(v interface{}, args ...interface{}) error + QueryRows(v interface{}, args ...interface{}) error + QueryRowsPartial(v interface{}, args ...interface{}) error + } + + // thread-safe + // Because CORBA doesn't support PREPARE, so we need to combine the + // query arguments into one string and do underlying query without arguments + commonSqlConn struct { + driverName string + datasource string + beginTx beginnable + brk breaker.Breaker + accept func(error) bool + } + + sessionConn interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + } + + statement struct { + stmt *sql.Stmt + } + + stmtConn interface { + Exec(args ...interface{}) (sql.Result, error) + Query(args ...interface{}) (*sql.Rows, error) + } +) + +func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn { + conn := &commonSqlConn{ + driverName: driverName, + datasource: datasource, + beginTx: begin, + brk: breaker.NewBreaker(), + } + for _, opt := range opts { + opt(conn) + } + + return conn +} + +func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) { + err = db.brk.DoWithAcceptable(func() error { + var conn *sql.DB + conn, err = getSqlConn(db.driverName, db.datasource) + if err != nil { + logInstanceError(db.datasource, err) + return err + } + + result, err = exec(conn, q, args...) + return err + }, db.acceptable) + + return +} + +func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) { + err = db.brk.DoWithAcceptable(func() error { + var conn *sql.DB + conn, err = getSqlConn(db.driverName, db.datasource) + if err != nil { + logInstanceError(db.datasource, err) + return err + } + + if st, err := conn.Prepare(query); err != nil { + return err + } else { + stmt = statement{ + stmt: st, + } + return nil + } + }, db.acceptable) + + return +} + +func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error { + return db.queryRows(func(rows *sql.Rows) error { + return unmarshalRow(v, rows, true) + }, q, args...) +} + +func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error { + return db.queryRows(func(rows *sql.Rows) error { + return unmarshalRow(v, rows, false) + }, q, args...) +} + +func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error { + return db.queryRows(func(rows *sql.Rows) error { + return unmarshalRows(v, rows, true) + }, q, args...) +} + +func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { + return db.queryRows(func(rows *sql.Rows) error { + return unmarshalRows(v, rows, false) + }, q, args...) 
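
A sketch (not part of this commit) of the Prepare/StmtSession path defined above, for reusing one prepared statement across several executions; import path and DSN are assumed as in the earlier sketches.

package main

import (
	"fmt"
	"log"

	"zero/core/stores/sqlx"
)

func main() {
	conn := sqlx.NewMysql("user:pass@tcp(127.0.0.1:3306)/appdb") // placeholder DSN

	// Prepare returns a StmtSession backed by a *sql.Stmt; Close releases it.
	stmt, err := conn.Prepare("select name from users where id=?")
	if err != nil {
		log.Fatal(err)
	}
	defer stmt.Close()

	for _, id := range []int64{1, 2, 3} {
		var name string
		// QueryRow scans a single row into name and returns
		// sqlx.ErrNotFound if the id is absent.
		if err := stmt.QueryRow(&name, id); err != nil {
			log.Printf("id %d: %v", id, err)
			continue
		}
		fmt.Println(id, name)
	}
}
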
+} + +func (db *commonSqlConn) Transact(fn func(Session) error) error { + return db.brk.DoWithAcceptable(func() error { + return transact(db, db.beginTx, fn) + }, db.acceptable) +} + +func (db *commonSqlConn) acceptable(err error) bool { + ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone + if db.accept == nil { + return ok + } else { + return ok || db.accept(err) + } +} + +func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error { + var qerr error + return db.brk.DoWithAcceptable(func() error { + conn, err := getSqlConn(db.driverName, db.datasource) + if err != nil { + logInstanceError(db.datasource, err) + return err + } + + return query(conn, func(rows *sql.Rows) error { + qerr = scanner(rows) + return qerr + }, q, args...) + }, func(err error) bool { + return qerr == err || db.acceptable(err) + }) +} + +func (s statement) Close() error { + return s.stmt.Close() +} + +func (s statement) Exec(args ...interface{}) (sql.Result, error) { + return execStmt(s.stmt, args...) +} + +func (s statement) QueryRow(v interface{}, args ...interface{}) error { + return queryStmt(s.stmt, func(rows *sql.Rows) error { + return unmarshalRow(v, rows, true) + }, args...) +} + +func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error { + return queryStmt(s.stmt, func(rows *sql.Rows) error { + return unmarshalRow(v, rows, false) + }, args...) +} + +func (s statement) QueryRows(v interface{}, args ...interface{}) error { + return queryStmt(s.stmt, func(rows *sql.Rows) error { + return unmarshalRows(v, rows, true) + }, args...) +} + +func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error { + return queryStmt(s.stmt, func(rows *sql.Rows) error { + return unmarshalRows(v, rows, false) + }, args...) +} diff --git a/core/stores/sqlx/sqlmanager.go b/core/stores/sqlx/sqlmanager.go new file mode 100644 index 00000000..fa18e0f8 --- /dev/null +++ b/core/stores/sqlx/sqlmanager.go @@ -0,0 +1,74 @@ +package sqlx + +import ( + "database/sql" + "io" + "sync" + "time" + + "zero/core/syncx" +) + +const ( + maxIdleConns = 64 + maxOpenConns = 64 + maxLifetime = time.Minute +) + +var connManager = syncx.NewResourceManager() + +type pingedDB struct { + *sql.DB + once sync.Once +} + +func getCachedSqlConn(driverName, server string) (*pingedDB, error) { + val, err := connManager.GetResource(server, func() (io.Closer, error) { + conn, err := newDBConnection(driverName, server) + if err != nil { + return nil, err + } + + return &pingedDB{ + DB: conn, + }, nil + }) + if err != nil { + return nil, err + } + + return val.(*pingedDB), nil +} + +func getSqlConn(driverName, server string) (*sql.DB, error) { + pdb, err := getCachedSqlConn(driverName, server) + if err != nil { + return nil, err + } + + pdb.once.Do(func() { + err = pdb.Ping() + }) + if err != nil { + return nil, err + } + + return pdb.DB, nil +} + +func newDBConnection(driverName, datasource string) (*sql.DB, error) { + conn, err := sql.Open(driverName, datasource) + if err != nil { + return nil, err + } + + // we need to do this until the issue https://github.com/golang/go/issues/9851 get fixed + // discussed here https://github.com/go-sql-driver/mysql/issues/257 + // if the discussed SetMaxIdleTimeout methods added, we'll change this behavior + // 8 means we can't have more than 8 goroutines to concurrently access the same database. 
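	// editor's note: the constants defined above set maxIdleConns and
	// maxOpenConns to 64, so the effective cap on concurrent connections
	// per datasource is 64; the "8" in the comment above does not match them.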
+ conn.SetMaxIdleConns(maxIdleConns) + conn.SetMaxOpenConns(maxOpenConns) + conn.SetConnMaxLifetime(maxLifetime) + + return conn, nil +} diff --git a/core/stores/sqlx/stmt.go b/core/stores/sqlx/stmt.go new file mode 100644 index 00000000..ed3dc530 --- /dev/null +++ b/core/stores/sqlx/stmt.go @@ -0,0 +1,92 @@ +package sqlx + +import ( + "database/sql" + "fmt" + "time" + + "zero/core/logx" + "zero/core/timex" +) + +const slowThreshold = time.Millisecond * 500 + +func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) { + stmt, err := format(q, args...) + if err != nil { + return nil, err + } + + startTime := timex.Now() + result, err := conn.Exec(q, args...) + duration := timex.Since(startTime) + if duration > slowThreshold { + logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt) + } else { + logx.WithDuration(duration).Infof("sql exec: %s", stmt) + } + if err != nil { + logSqlError(stmt, err) + } + + return result, err +} + +func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) { + stmt := fmt.Sprint(args...) + startTime := timex.Now() + result, err := conn.Exec(args...) + duration := timex.Since(startTime) + if duration > slowThreshold { + logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt) + } else { + logx.WithDuration(duration).Infof("sql execStmt: %s", stmt) + } + if err != nil { + logSqlError(stmt, err) + } + + return result, err +} + +func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { + stmt, err := format(q, args...) + if err != nil { + return err + } + + startTime := timex.Now() + rows, err := conn.Query(q, args...) + duration := timex.Since(startTime) + if duration > slowThreshold { + logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt) + } else { + logx.WithDuration(duration).Infof("sql query: %s", stmt) + } + if err != nil { + logSqlError(stmt, err) + return err + } + defer rows.Close() + + return scanner(rows) +} + +func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, args ...interface{}) error { + stmt := fmt.Sprint(args...) + startTime := timex.Now() + rows, err := conn.Query(args...) + duration := timex.Since(startTime) + if duration > slowThreshold { + logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt) + } else { + logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt) + } + if err != nil { + logSqlError(stmt, err) + return err + } + defer rows.Close() + + return scanner(rows) +} diff --git a/core/stores/sqlx/tx.go b/core/stores/sqlx/tx.go new file mode 100644 index 00000000..4fe231d0 --- /dev/null +++ b/core/stores/sqlx/tx.go @@ -0,0 +1,103 @@ +package sqlx + +import ( + "database/sql" + "fmt" +) + +type ( + beginnable func(*sql.DB) (trans, error) + + trans interface { + Session + Commit() error + Rollback() error + } + + txSession struct { + *sql.Tx + } +) + +func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) { + return exec(t.Tx, q, args...) +} + +func (t txSession) Prepare(q string) (StmtSession, error) { + if stmt, err := t.Tx.Prepare(q); err != nil { + return nil, err + } else { + return statement{ + stmt: stmt, + }, nil + } +} + +func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error { + return query(t.Tx, func(rows *sql.Rows) error { + return unmarshalRow(v, rows, true) + }, q, args...) 
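+    // the txSession methods reuse the same exec/query helpers as commonSqlConn,
+    // so the slowThreshold (500ms) slow-query logging also applies inside transactions.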
+} + +func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error { + return query(t.Tx, func(rows *sql.Rows) error { + return unmarshalRow(v, rows, false) + }, q, args...) +} + +func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error { + return query(t.Tx, func(rows *sql.Rows) error { + return unmarshalRows(v, rows, true) + }, q, args...) +} + +func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { + return query(t.Tx, func(rows *sql.Rows) error { + return unmarshalRows(v, rows, false) + }, q, args...) +} + +func begin(db *sql.DB) (trans, error) { + if tx, err := db.Begin(); err != nil { + return nil, err + } else { + return txSession{ + Tx: tx, + }, nil + } +} + +func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) { + conn, err := getSqlConn(db.driverName, db.datasource) + if err != nil { + logInstanceError(db.datasource, err) + return err + } + + return transactOnConn(conn, b, fn) +} + +func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) { + var tx trans + tx, err = b(conn) + if err != nil { + return + } + defer func() { + if p := recover(); p != nil { + if e := tx.Rollback(); e != nil { + err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e) + } else { + err = fmt.Errorf("recoveer from %#v", p) + } + } else if err != nil { + if e := tx.Rollback(); e != nil { + err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e) + } + } else { + err = tx.Commit() + } + }() + + return fn(tx) +} diff --git a/core/stores/sqlx/tx_test.go b/core/stores/sqlx/tx_test.go new file mode 100644 index 00000000..72ac5f17 --- /dev/null +++ b/core/stores/sqlx/tx_test.go @@ -0,0 +1,76 @@ +package sqlx + +import ( + "database/sql" + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + mockCommit = 1 + mockRollback = 2 +) + +type mockTx struct { + status int +} + +func (mt *mockTx) Commit() error { + mt.status |= mockCommit + return nil +} + +func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) { + return nil, nil +} + +func (mt *mockTx) Prepare(query string) (StmtSession, error) { + return nil, nil +} + +func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error { + return nil +} + +func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error { + return nil +} + +func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error { + return nil +} + +func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { + return nil +} + +func (mt *mockTx) Rollback() error { + mt.status |= mockRollback + return nil +} + +func beginMock(mock *mockTx) beginnable { + return func(*sql.DB) (trans, error) { + return mock, nil + } +} + +func TestTransactCommit(t *testing.T) { + mock := &mockTx{} + err := transactOnConn(nil, beginMock(mock), func(Session) error { + return nil + }) + assert.Equal(t, mockCommit, mock.status) + assert.Nil(t, err) +} + +func TestTransactRollback(t *testing.T) { + mock := &mockTx{} + err := transactOnConn(nil, beginMock(mock), func(Session) error { + return errors.New("rollback") + }) + assert.Equal(t, mockRollback, mock.status) + assert.NotNil(t, err) +} diff --git a/core/stores/sqlx/utils.go b/core/stores/sqlx/utils.go new file mode 100644 index 00000000..8ec78539 --- /dev/null +++ b/core/stores/sqlx/utils.go @@ -0,0 +1,101 @@ +package sqlx + +import ( + "fmt" + "strings" + + "zero/core/logx" + 
"zero/core/mapping" +) + +func desensitize(datasource string) string { + // remove account + pos := strings.LastIndex(datasource, "@") + if 0 <= pos && pos+1 < len(datasource) { + datasource = datasource[pos+1:] + } + + return datasource +} + +func escape(input string) string { + var b strings.Builder + + for _, ch := range input { + switch ch { + case '\x00': + b.WriteString(`\x00`) + case '\r': + b.WriteString(`\r`) + case '\n': + b.WriteString(`\n`) + case '\\': + b.WriteString(`\\`) + case '\'': + b.WriteString(`\'`) + case '"': + b.WriteString(`\"`) + case '\x1a': + b.WriteString(`\x1a`) + default: + b.WriteRune(ch) + } + } + + return b.String() +} + +func format(query string, args ...interface{}) (string, error) { + numArgs := len(args) + if numArgs == 0 { + return query, nil + } + + var b strings.Builder + argIndex := 0 + + for _, ch := range query { + if ch == '?' { + if argIndex >= numArgs { + return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex) + } + + arg := args[argIndex] + argIndex++ + + switch v := arg.(type) { + case bool: + if v { + b.WriteByte('1') + } else { + b.WriteByte('0') + } + case string: + b.WriteByte('\'') + b.WriteString(escape(v)) + b.WriteByte('\'') + default: + b.WriteString(mapping.Repr(v)) + } + } else { + b.WriteRune(ch) + } + } + + if argIndex < numArgs { + return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex) + } + + return b.String(), nil +} + +func logInstanceError(datasource string, err error) { + datasource = desensitize(datasource) + logx.Errorf("Error on getting sql instance of %s: %v", datasource, err) +} + +func logSqlError(stmt string, err error) { + if err != nil && err != ErrNotFound { + logx.Errorf("stmt: %s, error: %s", stmt, err.Error()) + } +} diff --git a/core/stores/sqlx/utils_test.go b/core/stores/sqlx/utils_test.go new file mode 100644 index 00000000..cb1d8619 --- /dev/null +++ b/core/stores/sqlx/utils_test.go @@ -0,0 +1,30 @@ +package sqlx + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEscape(t *testing.T) { + s := "a\x00\n\r\\'\"\x1ab" + + out := escape(s) + + assert.Equal(t, `a\x00\n\r\\\'\"\x1ab`, out) +} + +func TestDesensitize(t *testing.T) { + datasource := "user:pass@tcp(111.222.333.44:3306)/any_table?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai" + datasource = desensitize(datasource) + assert.False(t, strings.Contains(datasource, "user")) + assert.False(t, strings.Contains(datasource, "pass")) + assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)")) +} + +func TestDesensitize_WithoutAccount(t *testing.T) { + datasource := "tcp(111.222.333.44:3306)/any_table?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai" + datasource = desensitize(datasource) + assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)")) +} diff --git a/core/stringx/node.go b/core/stringx/node.go new file mode 100644 index 00000000..39d7a6d2 --- /dev/null +++ b/core/stringx/node.go @@ -0,0 +1,32 @@ +package stringx + +type node struct { + children map[rune]*node + end bool +} + +func (n *node) add(word string) { + chars := []rune(word) + if len(chars) == 0 { + return + } + + nd := n + for _, char := range chars { + if nd.children == nil { + child := new(node) + nd.children = map[rune]*node{ + char: child, + } + nd = child + } else if child, ok := nd.children[char]; ok { + nd = child + } else { + child := new(node) + nd.children[char] = child + nd = child + } + } + + nd.end = true +} diff --git a/core/stringx/random.go 
b/core/stringx/random.go
new file mode 100644
index 00000000..b00a9089
--- /dev/null
+++ b/core/stringx/random.go
@@ -0,0 +1,79 @@
+package stringx
+
+import (
+    crand "crypto/rand"
+    "fmt"
+    "math/rand"
+    "sync"
+    "time"
+)
+
+const (
+    letterBytes    = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+    letterIdxBits  = 6 // 6 bits to represent a letter index
+    idLen          = 8
+    defaultRandLen = 8
+    letterIdxMask  = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
+    letterIdxMax   = 63 / letterIdxBits   // # of letter indices fitting in 63 bits
+)
+
+var src = newLockedSource(time.Now().UnixNano())
+
+type lockedSource struct {
+    source rand.Source
+    lock   sync.Mutex
+}
+
+func newLockedSource(seed int64) *lockedSource {
+    return &lockedSource{
+        source: rand.NewSource(seed),
+    }
+}
+
+func (ls *lockedSource) Int63() int64 {
+    ls.lock.Lock()
+    defer ls.lock.Unlock()
+
+    return ls.source.Int63()
+}
+
+func (ls *lockedSource) Seed(seed int64) {
+    ls.lock.Lock()
+    defer ls.lock.Unlock()
+
+    ls.source.Seed(seed)
+}
+
+func Rand() string {
+    return Randn(defaultRandLen)
+}
+
+func RandId() string {
+    b := make([]byte, idLen)
+    if _, err := crand.Read(b); err != nil {
+        return Randn(idLen)
+    }
+
+    return fmt.Sprintf("%x%x%x%x", b[0:2], b[2:4], b[4:6], b[6:8])
+}
+
+func Randn(n int) string {
+    b := make([]byte, n)
+    for i, cache, remain := n-1, src.Int63(), letterIdxMax; i >= 0; {
+        if remain == 0 {
+            cache, remain = src.Int63(), letterIdxMax
+        }
+        if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
+            b[i] = letterBytes[idx]
+            i--
+        }
+        cache >>= letterIdxBits
+        remain--
+    }
+
+    return string(b)
+}
+
+func Seed(seed int64) {
+    src.Seed(seed)
+}
diff --git a/core/stringx/random_test.go b/core/stringx/random_test.go
new file mode 100644
index 00000000..bfc14b9c
--- /dev/null
+++ b/core/stringx/random_test.go
@@ -0,0 +1,23 @@
+package stringx
+
+import (
+    "testing"
+    "time"
+
+    "github.com/stretchr/testify/assert"
+)
+
+func TestRand(t *testing.T) {
+    Seed(time.Now().UnixNano())
+    assert.True(t, len(Rand()) > 0)
+    assert.True(t, len(RandId()) > 0)
+
+    const size = 10
+    assert.True(t, len(Randn(size)) == size)
+}
+
+func BenchmarkRandString(b *testing.B) {
+    for i := 0; i < b.N; i++ {
+        _ = Randn(10)
+    }
+}
diff --git a/core/stringx/replacer.go b/core/stringx/replacer.go
new file mode 100644
index 00000000..181933e4
--- /dev/null
+++ b/core/stringx/replacer.go
@@ -0,0 +1,77 @@
+package stringx
+
+import "strings"
+
+type (
+    Replacer interface {
+        Replace(text string) string
+    }
+
+    replacer struct {
+        node
+        mapping map[string]string
+    }
+)
+
+func NewReplacer(mapping map[string]string) Replacer {
+    var rep = &replacer{
+        mapping: mapping,
+    }
+    for k := range mapping {
+        rep.add(k)
+    }
+
+    return rep
+}
+
+func (r *replacer) Replace(text string) string {
+    var builder strings.Builder
+    var chars = []rune(text)
+    var size = len(chars)
+    var start = -1
+
+    for i := 0; i < size; i++ {
+        child, ok := r.children[chars[i]]
+        if !ok {
+            builder.WriteRune(chars[i])
+            continue
+        }
+
+        if start < 0 {
+            start = i
+        }
+        var end = -1
+        if child.end {
+            end = i + 1
+        }
+
+        var j = i + 1
+        for ; j < size; j++ {
+            grandchild, ok := child.children[chars[j]]
+            if !ok {
+                break
+            }
+
+            child = grandchild
+            if child.end {
+                end = j + 1
+                i = j
+            }
+        }
+
+        if end > 0 {
+            i = j - 1
+            builder.WriteString(r.mapping[string(chars[start:end])])
+        } else {
+            if j < size {
+                end = j + 1
+            } else {
+                end = size
+            }
+            builder.WriteRune(chars[i])
+        }
+        start = -1
+    }
+
+    return builder.String()
+}
diff --git a/core/stringx/replacer_test.go b/core/stringx/replacer_test.go
new file mode 100644
index 00000000..c739b2a7
--- /dev/null
+++ b/core/stringx/replacer_test.go
@@ -0,0 +1,44 @@
+package stringx
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+)
+
+func TestReplacer_Replace(t *testing.T) {
+    var mapping = map[string]string{
+        "一二三四": "1234",
+        "二三":   "23",
+        "二":    "2",
+    }
+    assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五"))
+}
+
+func TestReplacer_ReplaceSingleChar(t *testing.T) {
+    var mapping = map[string]string{
+        "二": "2",
+    }
+    assert.Equal(t, "零一2三四五", NewReplacer(mapping).Replace("零一二三四五"))
+}
+
+func TestReplacer_ReplaceExceedRange(t *testing.T) {
+    var mapping = map[string]string{
+        "二三四五六": "23456",
+    }
+    assert.Equal(t, "零一二三四五", NewReplacer(mapping).Replace("零一二三四五"))
+}
+
+func TestReplacer_ReplacePartialMatch(t *testing.T) {
+    var mapping = map[string]string{
+        "二三四七": "2347",
+    }
+    assert.Equal(t, "零一二三四五",
NewReplacer(mapping).Replace("零一二三四五")) +} + +func TestReplacer_ReplaceMultiMatches(t *testing.T) { + var mapping = map[string]string{ + "二三": "23", + } + assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五")) +} diff --git a/core/stringx/strings.go b/core/stringx/strings.go new file mode 100644 index 00000000..154eb38f --- /dev/null +++ b/core/stringx/strings.go @@ -0,0 +1,131 @@ +package stringx + +import ( + "errors" + + "zero/core/lang" +) + +var ( + ErrInvalidStartPosition = errors.New("start position is invalid") + ErrInvalidStopPosition = errors.New("stop position is invalid") +) + +func Contains(list []string, str string) bool { + for _, each := range list { + if each == str { + return true + } + } + + return false +} + +func Filter(s string, filter func(r rune) bool) string { + var n int + chars := []rune(s) + for i, x := range chars { + if n < i { + chars[n] = x + } + if !filter(x) { + n++ + } + } + + return string(chars[:n]) +} + +func HasEmpty(args ...string) bool { + for _, arg := range args { + if len(arg) == 0 { + return true + } + } + + return false +} + +func NotEmpty(args ...string) bool { + return !HasEmpty(args...) +} + +func Remove(strings []string, strs ...string) []string { + out := append([]string(nil), strings...) + + for _, str := range strs { + var n int + for _, v := range out { + if v != str { + out[n] = v + n++ + } + } + out = out[:n] + } + + return out +} + +func Reverse(s string) string { + runes := []rune(s) + + for from, to := 0, len(runes)-1; from < to; from, to = from+1, to-1 { + runes[from], runes[to] = runes[to], runes[from] + } + + return string(runes) +} + +// Substr returns runes between start and stop [start, stop) regardless of the chars are ascii or utf8 +func Substr(str string, start int, stop int) (string, error) { + rs := []rune(str) + length := len(rs) + + if start < 0 || start > length { + return "", ErrInvalidStartPosition + } + + if stop < 0 || stop > length { + return "", ErrInvalidStopPosition + } + + return string(rs[start:stop]), nil +} + +func TakeOne(valid, or string) string { + if len(valid) > 0 { + return valid + } else { + return or + } +} + +func TakeWithPriority(fns ...func() string) string { + for _, fn := range fns { + val := fn() + if len(val) > 0 { + return val + } + } + + return "" +} + +func Union(first, second []string) []string { + set := make(map[string]lang.PlaceholderType) + + for _, each := range first { + set[each] = lang.Placeholder + } + for _, each := range second { + set[each] = lang.Placeholder + } + + merged := make([]string, 0, len(set)) + for k := range set { + merged = append(merged, k) + } + + return merged +} diff --git a/core/stringx/strings_test.go b/core/stringx/strings_test.go new file mode 100644 index 00000000..184fceb0 --- /dev/null +++ b/core/stringx/strings_test.go @@ -0,0 +1,336 @@ +package stringx + +import ( + "path" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNotEmpty(t *testing.T) { + cases := []struct { + args []string + expect bool + }{ + { + args: []string{"a", "b", "c"}, + expect: true, + }, + { + args: []string{"a", "", "c"}, + expect: false, + }, + { + args: []string{"a"}, + expect: true, + }, + { + args: []string{""}, + expect: false, + }, + { + args: []string{}, + expect: true, + }, + } + + for _, each := range cases { + t.Run(path.Join(each.args...), func(t *testing.T) { + assert.Equal(t, each.expect, NotEmpty(each.args...)) + }) + } +} + +func TestContainsString(t *testing.T) { + cases := []struct { + slice []string + value string + 
expect bool + }{ + {[]string{"1"}, "1", true}, + {[]string{"1"}, "2", false}, + {[]string{"1", "2"}, "1", true}, + {[]string{"1", "2"}, "3", false}, + {nil, "3", false}, + {nil, "", false}, + } + + for _, each := range cases { + t.Run(path.Join(each.slice...), func(t *testing.T) { + actual := Contains(each.slice, each.value) + assert.Equal(t, each.expect, actual) + }) + } +} + +func TestFilter(t *testing.T) { + cases := []struct { + input string + ignores []rune + expect string + }{ + {``, nil, ``}, + {`abcd`, nil, `abcd`}, + {`ab,cd,ef`, []rune{','}, `abcdef`}, + {`ab, cd,ef`, []rune{',', ' '}, `abcdef`}, + {`ab, cd, ef`, []rune{',', ' '}, `abcdef`}, + {`ab, cd, ef, `, []rune{',', ' '}, `abcdef`}, + } + + for _, each := range cases { + t.Run(each.input, func(t *testing.T) { + actual := Filter(each.input, func(r rune) bool { + for _, x := range each.ignores { + if x == r { + return true + } + } + return false + }) + assert.Equal(t, each.expect, actual) + }) + } +} + +func TestRemove(t *testing.T) { + cases := []struct { + input []string + remove []string + expect []string + }{ + { + input: []string{"a", "b", "a", "c"}, + remove: []string{"a", "b"}, + expect: []string{"c"}, + }, + { + input: []string{"b", "c"}, + remove: []string{"a"}, + expect: []string{"b", "c"}, + }, + { + input: []string{"b", "a", "c"}, + remove: []string{"a"}, + expect: []string{"b", "c"}, + }, + { + input: []string{}, + remove: []string{"a"}, + expect: []string{}, + }, + } + + for _, each := range cases { + t.Run(path.Join(each.input...), func(t *testing.T) { + assert.ElementsMatch(t, each.expect, Remove(each.input, each.remove...)) + }) + } +} + +func TestReverse(t *testing.T) { + cases := []struct { + input string + expect string + }{ + { + input: "abcd", + expect: "dcba", + }, + { + input: "", + expect: "", + }, + { + input: "我爱中国", + expect: "国中爱我", + }, + } + + for _, each := range cases { + t.Run(each.input, func(t *testing.T) { + assert.Equal(t, each.expect, Reverse(each.input)) + }) + } +} + +func TestSubstr(t *testing.T) { + cases := []struct { + input string + start int + stop int + err error + expect string + }{ + { + input: "abcdefg", + start: 1, + stop: 4, + expect: "bcd", + }, + { + input: "我爱中国3000遍,even more", + start: 1, + stop: 9, + expect: "爱中国3000遍", + }, + { + input: "abcdefg", + start: -1, + stop: 4, + err: ErrInvalidStartPosition, + expect: "", + }, + { + input: "abcdefg", + start: 100, + stop: 4, + err: ErrInvalidStartPosition, + expect: "", + }, + { + input: "abcdefg", + start: 1, + stop: -1, + err: ErrInvalidStopPosition, + expect: "", + }, + { + input: "abcdefg", + start: 1, + stop: 100, + err: ErrInvalidStopPosition, + expect: "", + }, + } + + for _, each := range cases { + t.Run(each.input, func(t *testing.T) { + val, err := Substr(each.input, each.start, each.stop) + assert.Equal(t, each.err, err) + if err == nil { + assert.Equal(t, each.expect, val) + } + }) + } +} + +func TestTakeOne(t *testing.T) { + cases := []struct { + valid string + or string + expect string + }{ + {"", "", ""}, + {"", "1", "1"}, + {"1", "", "1"}, + {"1", "2", "1"}, + } + + for _, each := range cases { + t.Run(each.valid, func(t *testing.T) { + actual := TakeOne(each.valid, each.or) + assert.Equal(t, each.expect, actual) + }) + } +} + +func TestTakeWithPriority(t *testing.T) { + tests := []struct { + fns []func() string + expect string + }{ + { + fns: []func() string{ + func() string { + return "first" + }, + func() string { + return "second" + }, + func() string { + return "third" + }, + }, + expect: "first", + }, 
+ { + fns: []func() string{ + func() string { + return "" + }, + func() string { + return "second" + }, + func() string { + return "third" + }, + }, + expect: "second", + }, + { + fns: []func() string{ + func() string { + return "" + }, + func() string { + return "" + }, + func() string { + return "third" + }, + }, + expect: "third", + }, + { + fns: []func() string{ + func() string { + return "" + }, + func() string { + return "" + }, + func() string { + return "" + }, + }, + expect: "", + }, + } + + for _, test := range tests { + t.Run(RandId(), func(t *testing.T) { + val := TakeWithPriority(test.fns...) + assert.Equal(t, test.expect, val) + }) + } +} + +func TestUnion(t *testing.T) { + first := []string{ + "one", + "two", + "three", + } + second := []string{ + "zero", + "two", + "three", + "four", + } + union := Union(first, second) + contains := func(v string) bool { + for _, each := range union { + if v == each { + return true + } + } + + return false + } + assert.Equal(t, 5, len(union)) + assert.True(t, contains("zero")) + assert.True(t, contains("one")) + assert.True(t, contains("two")) + assert.True(t, contains("three")) + assert.True(t, contains("four")) +} diff --git a/core/stringx/trie.go b/core/stringx/trie.go new file mode 100644 index 00000000..ddc2872e --- /dev/null +++ b/core/stringx/trie.go @@ -0,0 +1,119 @@ +package stringx + +import "zero/core/lang" + +type ( + Trie interface { + Filter(text string) (string, []string, bool) + FindKeywords(text string) []string + } + + trieNode struct { + node + } + + scope struct { + start int + stop int + } +) + +func NewTrie(words []string) Trie { + n := new(trieNode) + for _, word := range words { + n.add(word) + } + + return n +} + +func (n *trieNode) Filter(text string) (sentence string, keywords []string, found bool) { + chars := []rune(text) + if len(chars) == 0 { + return text, nil, false + } + + scopes := n.findKeywordScopes(chars) + keywords = n.collectKeywords(chars, scopes) + + for _, match := range scopes { + // we don't care about overlaps, not bringing a performance improvement + n.replaceWithAsterisk(chars, match.start, match.stop) + } + + return string(chars), keywords, len(keywords) > 0 +} + +func (n *trieNode) FindKeywords(text string) []string { + chars := []rune(text) + if len(chars) == 0 { + return nil + } + + scopes := n.findKeywordScopes(chars) + return n.collectKeywords(chars, scopes) +} + +func (n *trieNode) collectKeywords(chars []rune, scopes []scope) []string { + set := make(map[string]lang.PlaceholderType) + for _, v := range scopes { + set[string(chars[v.start:v.stop])] = lang.Placeholder + } + + var i int + keywords := make([]string, len(set)) + for k := range set { + keywords[i] = k + i++ + } + + return keywords +} + +func (n *trieNode) findKeywordScopes(chars []rune) []scope { + var scopes []scope + size := len(chars) + start := -1 + + for i := 0; i < size; i++ { + child, ok := n.children[chars[i]] + if !ok { + continue + } + + if start < 0 { + start = i + } + if child.end { + scopes = append(scopes, scope{ + start: start, + stop: i + 1, + }) + } + + for j := i + 1; j < size; j++ { + grandchild, ok := child.children[chars[j]] + if !ok { + break + } + + child = grandchild + if child.end { + scopes = append(scopes, scope{ + start: start, + stop: j + 1, + }) + } + } + + start = -1 + } + + return scopes +} + +func (n *trieNode) replaceWithAsterisk(chars []rune, start, stop int) { + for i := start; i < stop; i++ { + chars[i] = '*' + } +} diff --git a/core/stringx/trie_test.go b/core/stringx/trie_test.go new 
file mode 100644 index 00000000..ca13e914 --- /dev/null +++ b/core/stringx/trie_test.go @@ -0,0 +1,164 @@ +package stringx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTrie(t *testing.T) { + tests := []struct { + input string + output string + keywords []string + found bool + }{ + { + input: "日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演", + output: "日本****兼电视、电影演员。*****女优是xx出道, ******们最精彩的表演是******表演", + keywords: []string{ + "AV演员", + "苍井空", + "AV", + "日本AV女优", + "AV演员色情", + }, + found: true, + }, + { + input: "完全和谐的文本完全和谐的文本", + output: "完全和谐的文本完全和谐的文本", + keywords: nil, + found: false, + }, + { + input: "就一个字不对", + output: "就*个字不对", + keywords: []string{ + "一", + }, + found: true, + }, + { + input: "就一对, AV", + output: "就*对, **", + keywords: []string{ + "一", + "AV", + }, + found: true, + }, + { + input: "就一不对, AV", + output: "就**对, **", + keywords: []string{ + "一", + "一不", + "AV", + }, + found: true, + }, + { + input: "就对, AV", + output: "就对, **", + keywords: []string{ + "AV", + }, + found: true, + }, + { + input: "就对, 一不", + output: "就对, **", + keywords: []string{ + "一", + "一不", + }, + found: true, + }, + { + input: "", + output: "", + keywords: nil, + found: false, + }, + } + + trie := NewTrie([]string{ + "", // no hurts for empty keywords + "一", + "一不", + "AV", + "AV演员", + "苍井空", + "AV演员色情", + "日本AV女优", + }) + + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + output, keywords, ok := trie.Filter(test.input) + assert.Equal(t, test.found, ok) + assert.Equal(t, test.output, output) + assert.ElementsMatch(t, test.keywords, keywords) + keywords = trie.FindKeywords(test.input) + assert.ElementsMatch(t, test.keywords, keywords) + }) + } +} + +func TestTrieSingleWord(t *testing.T) { + trie := NewTrie([]string{ + "闹", + }) + output, keywords, ok := trie.Filter("今晚真热闹") + assert.ElementsMatch(t, []string{"闹"}, keywords) + assert.True(t, ok) + assert.Equal(t, "今晚真热*", output) +} + +func TestTrieOverlap(t *testing.T) { + trie := NewTrie([]string{ + "一二三四五", + "二三四五六七八", + }) + output, keywords, ok := trie.Filter("零一二三四五六七八九十") + assert.ElementsMatch(t, []string{ + "一二三四五", + "二三四五六七八", + }, keywords) + assert.True(t, ok) + assert.Equal(t, "零********九十", output) +} + +func TestTrieNested(t *testing.T) { + trie := NewTrie([]string{ + "一二三", + "一二三四五", + "一二三四五六七八", + }) + output, keywords, ok := trie.Filter("零一二三四五六七八九十") + assert.ElementsMatch(t, []string{ + "一二三", + "一二三四五", + "一二三四五六七八", + }, keywords) + assert.True(t, ok) + assert.Equal(t, "零********九十", output) +} + +func BenchmarkTrie(b *testing.B) { + b.ReportAllocs() + + trie := NewTrie([]string{ + "A", + "AV", + "AV演员", + "苍井空", + "AV演员色情", + "日本AV女优", + }) + + for i := 0; i < b.N; i++ { + trie.Filter("日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演") + } +} diff --git a/core/syncx/atomicbool.go b/core/syncx/atomicbool.go new file mode 100644 index 00000000..28710df5 --- /dev/null +++ b/core/syncx/atomicbool.go @@ -0,0 +1,38 @@ +package syncx + +import "sync/atomic" + +type AtomicBool uint32 + +func NewAtomicBool() *AtomicBool { + return new(AtomicBool) +} + +func ForAtomicBool(val bool) *AtomicBool { + b := NewAtomicBool() + b.Set(val) + return b +} + +func (b *AtomicBool) CompareAndSwap(old, val bool) bool { + var ov, nv uint32 + if old { + ov = 1 + } + if val { + nv = 1 + } + return atomic.CompareAndSwapUint32((*uint32)(b), ov, nv) +} + +func (b *AtomicBool) Set(v bool) { + if v { + atomic.StoreUint32((*uint32)(b), 1) + } else { + atomic.StoreUint32((*uint32)(b), 0) + 
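+        // the underlying representation is a plain uint32: 1 means true, 0 means false,
+        // so Set, True and CompareAndSwap all go through sync/atomic uint32 operations.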
} +} + +func (b *AtomicBool) True() bool { + return atomic.LoadUint32((*uint32)(b)) == 1 +} diff --git a/core/syncx/atomicbool_test.go b/core/syncx/atomicbool_test.go new file mode 100644 index 00000000..f1f8557e --- /dev/null +++ b/core/syncx/atomicbool_test.go @@ -0,0 +1,27 @@ +package syncx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAtomicBool(t *testing.T) { + val := ForAtomicBool(true) + assert.True(t, val.True()) + val.Set(false) + assert.False(t, val.True()) + val.Set(true) + assert.True(t, val.True()) + val.Set(false) + assert.False(t, val.True()) + ok := val.CompareAndSwap(false, true) + assert.True(t, ok) + assert.True(t, val.True()) + ok = val.CompareAndSwap(true, false) + assert.True(t, ok) + assert.False(t, val.True()) + ok = val.CompareAndSwap(true, false) + assert.False(t, ok) + assert.False(t, val.True()) +} diff --git a/core/syncx/atomicduration.go b/core/syncx/atomicduration.go new file mode 100644 index 00000000..83c1ed4b --- /dev/null +++ b/core/syncx/atomicduration.go @@ -0,0 +1,30 @@ +package syncx + +import ( + "sync/atomic" + "time" +) + +type AtomicDuration int64 + +func NewAtomicDuration() *AtomicDuration { + return new(AtomicDuration) +} + +func ForAtomicDuration(val time.Duration) *AtomicDuration { + d := NewAtomicDuration() + d.Set(val) + return d +} + +func (d *AtomicDuration) CompareAndSwap(old, val time.Duration) bool { + return atomic.CompareAndSwapInt64((*int64)(d), int64(old), int64(val)) +} + +func (d *AtomicDuration) Load() time.Duration { + return time.Duration(atomic.LoadInt64((*int64)(d))) +} + +func (d *AtomicDuration) Set(val time.Duration) { + atomic.StoreInt64((*int64)(d), int64(val)) +} diff --git a/core/syncx/atomicduration_test.go b/core/syncx/atomicduration_test.go new file mode 100644 index 00000000..8165e136 --- /dev/null +++ b/core/syncx/atomicduration_test.go @@ -0,0 +1,19 @@ +package syncx + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAtomicDuration(t *testing.T) { + d := ForAtomicDuration(time.Duration(100)) + assert.Equal(t, time.Duration(100), d.Load()) + d.Set(time.Duration(200)) + assert.Equal(t, time.Duration(200), d.Load()) + assert.True(t, d.CompareAndSwap(time.Duration(200), time.Duration(300))) + assert.Equal(t, time.Duration(300), d.Load()) + assert.False(t, d.CompareAndSwap(time.Duration(200), time.Duration(400))) + assert.Equal(t, time.Duration(300), d.Load()) +} diff --git a/core/syncx/atomicfloat64.go b/core/syncx/atomicfloat64.go new file mode 100644 index 00000000..35dd41cf --- /dev/null +++ b/core/syncx/atomicfloat64.go @@ -0,0 +1,40 @@ +package syncx + +import ( + "math" + "sync/atomic" +) + +type AtomicFloat64 uint64 + +func NewAtomicFloat64() *AtomicFloat64 { + return new(AtomicFloat64) +} + +func ForAtomicFloat64(val float64) *AtomicFloat64 { + f := NewAtomicFloat64() + f.Set(val) + return f +} + +func (f *AtomicFloat64) Add(val float64) float64 { + for { + old := f.Load() + nv := old + val + if f.CompareAndSwap(old, nv) { + return nv + } + } +} + +func (f *AtomicFloat64) CompareAndSwap(old, val float64) bool { + return atomic.CompareAndSwapUint64((*uint64)(f), math.Float64bits(old), math.Float64bits(val)) +} + +func (f *AtomicFloat64) Load() float64 { + return math.Float64frombits(atomic.LoadUint64((*uint64)(f))) +} + +func (f *AtomicFloat64) Set(val float64) { + atomic.StoreUint64((*uint64)(f), math.Float64bits(val)) +} diff --git a/core/syncx/atomicfloat64_test.go b/core/syncx/atomicfloat64_test.go new file mode 100644 index 
00000000..c3c5fa19 --- /dev/null +++ b/core/syncx/atomicfloat64_test.go @@ -0,0 +1,24 @@ +package syncx + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAtomicFloat64(t *testing.T) { + f := ForAtomicFloat64(100) + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + for i := 0; i < 100; i++ { + f.Add(1) + } + wg.Done() + }() + } + wg.Wait() + assert.Equal(t, float64(600), f.Load()) +} diff --git a/core/syncx/barrier.go b/core/syncx/barrier.go new file mode 100644 index 00000000..9ee2d5ff --- /dev/null +++ b/core/syncx/barrier.go @@ -0,0 +1,13 @@ +package syncx + +import "sync" + +type Barrier struct { + lock sync.Mutex +} + +func (b *Barrier) Guard(fn func()) { + b.lock.Lock() + defer b.lock.Unlock() + fn() +} diff --git a/core/syncx/barrier_test.go b/core/syncx/barrier_test.go new file mode 100644 index 00000000..c763f1b3 --- /dev/null +++ b/core/syncx/barrier_test.go @@ -0,0 +1,31 @@ +package syncx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBarrier_Guard(t *testing.T) { + const total = 10000 + var barrier Barrier + var count int + for i := 0; i < total; i++ { + barrier.Guard(func() { + count++ + }) + } + assert.Equal(t, total, count) +} + +func TestBarrierPtr_Guard(t *testing.T) { + const total = 10000 + barrier := new(Barrier) + var count int + for i := 0; i < total; i++ { + barrier.Guard(func() { + count++ + }) + } + assert.Equal(t, total, count) +} diff --git a/core/syncx/cond.go b/core/syncx/cond.go new file mode 100644 index 00000000..890e6f84 --- /dev/null +++ b/core/syncx/cond.go @@ -0,0 +1,47 @@ +package syncx + +import ( + "time" + + "zero/core/lang" + "zero/core/timex" +) + +type Cond struct { + signal chan lang.PlaceholderType +} + +func NewCond() *Cond { + return &Cond{ + signal: make(chan lang.PlaceholderType), + } +} + +// WaitWithTimeout wait for signal return remain wait time or timed out +func (cond *Cond) WaitWithTimeout(timeout time.Duration) (time.Duration, bool) { + timer := time.NewTimer(timeout) + defer timer.Stop() + + begin := timex.Now() + select { + case <-cond.signal: + elapsed := timex.Since(begin) + remainTimeout := timeout - elapsed + return remainTimeout, true + case <-timer.C: + return 0, false + } +} + +// Wait for signal +func (cond *Cond) Wait() { + <-cond.signal +} + +// Signal wakes one goroutine waiting on c, if there is any. 
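+// Signal is non-blocking: if no goroutine is waiting at that moment, the signal is dropped.
+// A minimal usage sketch (illustrative only, not part of this file): one goroutine
+// blocks in WaitWithTimeout while another wakes it with Signal:
+//
+//    cond := NewCond()
+//    go func() { _, ok := cond.WaitWithTimeout(time.Second); _ = ok }()
+//    cond.Signal()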
+func (cond *Cond) Signal() { + select { + case cond.signal <- lang.Placeholder: + default: + } +} diff --git a/core/syncx/cond_test.go b/core/syncx/cond_test.go new file mode 100644 index 00000000..04ee90d1 --- /dev/null +++ b/core/syncx/cond_test.go @@ -0,0 +1,73 @@ +package syncx + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTimeoutCondWait(t *testing.T) { + var wait sync.WaitGroup + cond := NewCond() + wait.Add(2) + go func() { + cond.Wait() + wait.Done() + }() + time.Sleep(time.Duration(50) * time.Millisecond) + go func() { + cond.Signal() + wait.Done() + }() + wait.Wait() +} + +func TestTimeoutCondWaitTimeout(t *testing.T) { + var wait sync.WaitGroup + cond := NewCond() + wait.Add(1) + go func() { + cond.WaitWithTimeout(time.Duration(500) * time.Millisecond) + wait.Done() + }() + wait.Wait() +} + +func TestTimeoutCondWaitTimeoutRemain(t *testing.T) { + var wait sync.WaitGroup + cond := NewCond() + wait.Add(2) + ch := make(chan time.Duration, 1) + defer close(ch) + timeout := time.Duration(2000) * time.Millisecond + go func() { + remainTimeout, _ := cond.WaitWithTimeout(timeout) + ch <- remainTimeout + wait.Done() + }() + sleep(200) + go func() { + cond.Signal() + wait.Done() + }() + wait.Wait() + remainTimeout := <-ch + assert.True(t, remainTimeout < timeout, "expect remainTimeout %v < %v", remainTimeout, timeout) + assert.True(t, remainTimeout >= time.Duration(200)*time.Millisecond, + "expect remainTimeout %v >= 200 millisecond", remainTimeout) +} + +func TestSignalNoWait(t *testing.T) { + cond := NewCond() + cond.Signal() +} + +func sleep(millisecond int) { + time.Sleep(time.Duration(millisecond) * time.Millisecond) +} + +func currentTimeMillis() int64 { + return time.Now().UnixNano() / int64(time.Millisecond) +} diff --git a/core/syncx/donechan.go b/core/syncx/donechan.go new file mode 100644 index 00000000..d49fccea --- /dev/null +++ b/core/syncx/donechan.go @@ -0,0 +1,28 @@ +package syncx + +import ( + "sync" + + "zero/core/lang" +) + +type DoneChan struct { + done chan lang.PlaceholderType + once sync.Once +} + +func NewDoneChan() *DoneChan { + return &DoneChan{ + done: make(chan lang.PlaceholderType), + } +} + +func (dc *DoneChan) Close() { + dc.once.Do(func() { + close(dc.done) + }) +} + +func (dc *DoneChan) Done() chan lang.PlaceholderType { + return dc.done +} diff --git a/core/syncx/donechan_test.go b/core/syncx/donechan_test.go new file mode 100644 index 00000000..3f7520b9 --- /dev/null +++ b/core/syncx/donechan_test.go @@ -0,0 +1,33 @@ +package syncx + +import ( + "sync" + "testing" +) + +func TestDoneChanClose(t *testing.T) { + doneChan := NewDoneChan() + + for i := 0; i < 5; i++ { + doneChan.Close() + } +} + +func TestDoneChanDone(t *testing.T) { + var waitGroup sync.WaitGroup + doneChan := NewDoneChan() + + waitGroup.Add(1) + go func() { + select { + case <-doneChan.Done(): + waitGroup.Done() + } + }() + + for i := 0; i < 5; i++ { + doneChan.Close() + } + + waitGroup.Wait() +} diff --git a/core/syncx/immutableresource.go b/core/syncx/immutableresource.go new file mode 100644 index 00000000..f3475cee --- /dev/null +++ b/core/syncx/immutableresource.go @@ -0,0 +1,77 @@ +package syncx + +import ( + "sync" + "time" + + "zero/core/timex" +) + +const defaultRefreshInterval = time.Second + +type ( + ImmutableResourceOption func(resource *ImmutableResource) + + ImmutableResource struct { + fetch func() (interface{}, error) + resource interface{} + err error + lock sync.RWMutex + refreshInterval time.Duration + lastTime 
*AtomicDuration + } +) + +func NewImmutableResource(fn func() (interface{}, error), opts ...ImmutableResourceOption) *ImmutableResource { + // cannot use executors.LessExecutor because of cycle imports + ir := ImmutableResource{ + fetch: fn, + refreshInterval: defaultRefreshInterval, + lastTime: NewAtomicDuration(), + } + for _, opt := range opts { + opt(&ir) + } + return &ir +} + +func (ir *ImmutableResource) Get() (interface{}, error) { + ir.lock.RLock() + resource := ir.resource + ir.lock.RUnlock() + if resource != nil { + return resource, nil + } + + ir.maybeRefresh(func() { + res, err := ir.fetch() + ir.lock.Lock() + if err != nil { + ir.err = err + } else { + ir.resource, ir.err = res, nil + } + ir.lock.Unlock() + }) + + ir.lock.RLock() + resource, err := ir.resource, ir.err + ir.lock.RUnlock() + return resource, err +} + +func (ir *ImmutableResource) maybeRefresh(execute func()) { + now := timex.Now() + lastTime := ir.lastTime.Load() + if lastTime == 0 || lastTime+ir.refreshInterval < now { + ir.lastTime.Set(now) + execute() + } +} + +// Set interval to 0 to enforce refresh every time if not succeeded. default is time.Second. +func WithRefreshIntervalOnFailure(interval time.Duration) ImmutableResourceOption { + return func(resource *ImmutableResource) { + resource.refreshInterval = interval + } +} diff --git a/core/syncx/immutableresource_test.go b/core/syncx/immutableresource_test.go new file mode 100644 index 00000000..e76ecb48 --- /dev/null +++ b/core/syncx/immutableresource_test.go @@ -0,0 +1,78 @@ +package syncx + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestImmutableResource(t *testing.T) { + var count int + r := NewImmutableResource(func() (interface{}, error) { + count++ + return "hello", nil + }) + + res, err := r.Get() + assert.Equal(t, "hello", res) + assert.Equal(t, 1, count) + assert.Nil(t, err) + + // again + res, err = r.Get() + assert.Equal(t, "hello", res) + assert.Equal(t, 1, count) + assert.Nil(t, err) +} + +func TestImmutableResourceError(t *testing.T) { + var count int + r := NewImmutableResource(func() (interface{}, error) { + count++ + return nil, errors.New("any") + }) + + res, err := r.Get() + assert.Nil(t, res) + assert.NotNil(t, err) + assert.Equal(t, "any", err.Error()) + assert.Equal(t, 1, count) + + // again + res, err = r.Get() + assert.Nil(t, res) + assert.NotNil(t, err) + assert.Equal(t, "any", err.Error()) + assert.Equal(t, 1, count) + + r.refreshInterval = 0 + time.Sleep(time.Millisecond) + res, err = r.Get() + assert.Nil(t, res) + assert.NotNil(t, err) + assert.Equal(t, "any", err.Error()) + assert.Equal(t, 2, count) +} + +func TestImmutableResourceErrorRefreshAlways(t *testing.T) { + var count int + r := NewImmutableResource(func() (interface{}, error) { + count++ + return nil, errors.New("any") + }, WithRefreshIntervalOnFailure(0)) + + res, err := r.Get() + assert.Nil(t, res) + assert.NotNil(t, err) + assert.Equal(t, "any", err.Error()) + assert.Equal(t, 1, count) + + // again + res, err = r.Get() + assert.Nil(t, res) + assert.NotNil(t, err) + assert.Equal(t, "any", err.Error()) + assert.Equal(t, 2, count) +} diff --git a/core/syncx/limit.go b/core/syncx/limit.go new file mode 100644 index 00000000..4a99126b --- /dev/null +++ b/core/syncx/limit.go @@ -0,0 +1,42 @@ +package syncx + +import ( + "errors" + + "zero/core/lang" +) + +var ErrReturn = errors.New("discarding limited token, resource pool is full, someone returned multiple times") + +type Limit struct { + pool chan lang.PlaceholderType 
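+    // pool works as a token bucket: Borrow sends a placeholder into the channel,
+    // Return receives one back, so the capacity set by NewLimit bounds the number
+    // of outstanding borrows.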
+} + +func NewLimit(n int) Limit { + return Limit{ + pool: make(chan lang.PlaceholderType, n), + } +} + +func (l Limit) Borrow() { + l.pool <- lang.Placeholder +} + +// Return returns the borrowed resource, returns error only if returned more than borrowed. +func (l Limit) Return() error { + select { + case <-l.pool: + return nil + default: + return ErrReturn + } +} + +func (l Limit) TryBorrow() bool { + select { + case l.pool <- lang.Placeholder: + return true + default: + return false + } +} diff --git a/core/syncx/limit_test.go b/core/syncx/limit_test.go new file mode 100644 index 00000000..338e23c5 --- /dev/null +++ b/core/syncx/limit_test.go @@ -0,0 +1,17 @@ +package syncx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLimit(t *testing.T) { + limit := NewLimit(2) + limit.Borrow() + assert.True(t, limit.TryBorrow()) + assert.False(t, limit.TryBorrow()) + assert.Nil(t, limit.Return()) + assert.Nil(t, limit.Return()) + assert.Equal(t, ErrReturn, limit.Return()) +} diff --git a/core/syncx/lockedcalls.go b/core/syncx/lockedcalls.go new file mode 100644 index 00000000..e2dd0fac --- /dev/null +++ b/core/syncx/lockedcalls.go @@ -0,0 +1,56 @@ +package syncx + +import "sync" + +type ( + // LockedCalls makes sure the calls with the same key to be called sequentially. + // For example, A called F, before it's done, B called F, then B's call would not blocked, + // after A's call finished, B's call got executed. + // The calls with the same key are independent, not sharing the returned values. + // A ------->calls F with key and executes<------->returns + // B ------------------>calls F with key<--------->executes<---->returns + LockedCalls interface { + Do(key string, fn func() (interface{}, error)) (interface{}, error) + } + + lockedGroup struct { + mu sync.Mutex + m map[string]*sync.WaitGroup + } +) + +func NewLockedCalls() LockedCalls { + return &lockedGroup{ + m: make(map[string]*sync.WaitGroup), + } +} + +func (lg *lockedGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) { +begin: + lg.mu.Lock() + if wg, ok := lg.m[key]; ok { + lg.mu.Unlock() + wg.Wait() + goto begin + } + + return lg.makeCall(key, fn) +} + +func (lg *lockedGroup) makeCall(key string, fn func() (interface{}, error)) (interface{}, error) { + var wg sync.WaitGroup + wg.Add(1) + lg.m[key] = &wg + lg.mu.Unlock() + + defer func() { + // delete key first, done later. 
can't reverse the order, because if reverse, + // another Do call might wg.Wait() without get notified with wg.Done() + lg.mu.Lock() + delete(lg.m, key) + lg.mu.Unlock() + wg.Done() + }() + + return fn() +} diff --git a/core/syncx/lockedcalls_test.go b/core/syncx/lockedcalls_test.go new file mode 100644 index 00000000..8e812807 --- /dev/null +++ b/core/syncx/lockedcalls_test.go @@ -0,0 +1,82 @@ +package syncx + +import ( + "errors" + "fmt" + "sync" + "testing" + "time" +) + +func TestLockedCallDo(t *testing.T) { + g := NewLockedCalls() + v, err := g.Do("key", func() (interface{}, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestLockedCallDoErr(t *testing.T) { + g := NewLockedCalls() + someErr := errors.New("some error") + v, err := g.Do("key", func() (interface{}, error) { + return nil, someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr", err) + } + if v != nil { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestLockedCallDoDupSuppress(t *testing.T) { + g := NewLockedCalls() + c := make(chan string) + var calls int + fn := func() (interface{}, error) { + calls++ + ret := calls + <-c + calls-- + return ret, nil + } + + const n = 10 + var results []int + var lock sync.Mutex + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + v, err := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + } + + lock.Lock() + results = append(results, v.(int)) + lock.Unlock() + wg.Done() + }() + } + time.Sleep(100 * time.Millisecond) // let goroutines above block + for i := 0; i < n; i++ { + c <- "bar" + } + wg.Wait() + + lock.Lock() + defer lock.Unlock() + + for _, item := range results { + if item != 1 { + t.Errorf("number of calls = %d; want 1", item) + } + } +} diff --git a/core/syncx/managedresource.go b/core/syncx/managedresource.go new file mode 100644 index 00000000..f231d671 --- /dev/null +++ b/core/syncx/managedresource.go @@ -0,0 +1,44 @@ +package syncx + +import "sync" + +type ManagedResource struct { + resource interface{} + lock sync.RWMutex + generate func() interface{} + equals func(a, b interface{}) bool +} + +func NewManagedResource(generate func() interface{}, equals func(a, b interface{}) bool) *ManagedResource { + return &ManagedResource{ + generate: generate, + equals: equals, + } +} + +func (mr *ManagedResource) MarkBroken(resource interface{}) { + mr.lock.Lock() + defer mr.lock.Unlock() + + if mr.equals(mr.resource, resource) { + mr.resource = nil + } +} + +func (mr *ManagedResource) Take() interface{} { + mr.lock.RLock() + resource := mr.resource + mr.lock.RUnlock() + + if resource != nil { + return resource + } + + mr.lock.Lock() + defer mr.lock.Unlock() + // maybe another Take() call already generated the resource. 
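+    // classic double-checked locking: the fast path above only held the read lock,
+    // here we hold the write lock and re-check before generating a new resource.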
+ if mr.resource == nil { + mr.resource = mr.generate() + } + return mr.resource +} diff --git a/core/syncx/managedresource_test.go b/core/syncx/managedresource_test.go new file mode 100644 index 00000000..9f8b15df --- /dev/null +++ b/core/syncx/managedresource_test.go @@ -0,0 +1,22 @@ +package syncx + +import ( + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestManagedResource(t *testing.T) { + var count int32 + resource := NewManagedResource(func() interface{} { + return atomic.AddInt32(&count, 1) + }, func(a, b interface{}) bool { + return a == b + }) + + assert.Equal(t, resource.Take(), resource.Take()) + old := resource.Take() + resource.MarkBroken(old) + assert.NotEqual(t, old, resource.Take()) +} diff --git a/core/syncx/once.go b/core/syncx/once.go new file mode 100644 index 00000000..1a51f9bd --- /dev/null +++ b/core/syncx/once.go @@ -0,0 +1,10 @@ +package syncx + +import "sync" + +func Once(fn func()) func() { + once := new(sync.Once) + return func() { + once.Do(fn) + } +} diff --git a/core/syncx/once_test.go b/core/syncx/once_test.go new file mode 100644 index 00000000..b89d701a --- /dev/null +++ b/core/syncx/once_test.go @@ -0,0 +1,20 @@ +package syncx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOnce(t *testing.T) { + var v int + add := Once(func() { + v++ + }) + + for i := 0; i < 5; i++ { + add() + } + + assert.Equal(t, 1, v) +} diff --git a/core/syncx/onceguard.go b/core/syncx/onceguard.go new file mode 100644 index 00000000..a4f232eb --- /dev/null +++ b/core/syncx/onceguard.go @@ -0,0 +1,15 @@ +package syncx + +import "sync/atomic" + +type OnceGuard struct { + done uint32 +} + +func (og *OnceGuard) Taken() bool { + return atomic.LoadUint32(&og.done) == 1 +} + +func (og *OnceGuard) Take() bool { + return atomic.CompareAndSwapUint32(&og.done, 0, 1) +} diff --git a/core/syncx/onceguard_test.go b/core/syncx/onceguard_test.go new file mode 100644 index 00000000..dac7aa36 --- /dev/null +++ b/core/syncx/onceguard_test.go @@ -0,0 +1,17 @@ +package syncx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOnceGuard(t *testing.T) { + var guard OnceGuard + + assert.False(t, guard.Taken()) + assert.True(t, guard.Take()) + assert.True(t, guard.Taken()) + assert.False(t, guard.Take()) + assert.True(t, guard.Taken()) +} diff --git a/core/syncx/pool.go b/core/syncx/pool.go new file mode 100644 index 00000000..c1414e09 --- /dev/null +++ b/core/syncx/pool.go @@ -0,0 +1,98 @@ +package syncx + +import ( + "sync" + "time" + + "zero/core/timex" +) + +type ( + PoolOption func(*Pool) + + node struct { + item interface{} + next *node + lastUsed time.Duration + } + + Pool struct { + limit int + created int + maxAge time.Duration + lock sync.Locker + cond *sync.Cond + head *node + create func() interface{} + destroy func(interface{}) + } +) + +func NewPool(n int, create func() interface{}, destroy func(interface{}), opts ...PoolOption) *Pool { + if n <= 0 { + panic("pool size can't be negative or zero") + } + + lock := new(sync.Mutex) + pool := &Pool{ + limit: n, + lock: lock, + cond: sync.NewCond(lock), + create: create, + destroy: destroy, + } + + for _, opt := range opts { + opt(pool) + } + + return pool +} + +func (p *Pool) Get() interface{} { + p.lock.Lock() + defer p.lock.Unlock() + + for { + if p.head != nil { + head := p.head + p.head = head.next + if p.maxAge > 0 && head.lastUsed+p.maxAge < timex.Now() { + p.created-- + p.destroy(head.item) + continue + } else { + return head.item + } + } + + if 
p.created < p.limit { + p.created++ + return p.create() + } + + p.cond.Wait() + } +} + +func (p *Pool) Put(x interface{}) { + if x == nil { + return + } + + p.lock.Lock() + defer p.lock.Unlock() + + p.head = &node{ + item: x, + next: p.head, + lastUsed: timex.Now(), + } + p.cond.Signal() +} + +func WithMaxAge(duration time.Duration) PoolOption { + return func(pool *Pool) { + pool.maxAge = duration + } +} diff --git a/core/syncx/pool_test.go b/core/syncx/pool_test.go new file mode 100644 index 00000000..cd9fdc63 --- /dev/null +++ b/core/syncx/pool_test.go @@ -0,0 +1,111 @@ +package syncx + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "zero/core/lang" + + "github.com/stretchr/testify/assert" +) + +const limit = 10 + +func TestPoolGet(t *testing.T) { + stack := NewPool(limit, create, destroy) + ch := make(chan lang.PlaceholderType) + + for i := 0; i < limit; i++ { + go func() { + v := stack.Get() + if v.(int) != 1 { + t.Fatal("unmatch value") + } + ch <- lang.Placeholder + }() + + select { + case <-ch: + case <-time.After(time.Second): + t.Fail() + } + } +} + +func TestPoolPopTooMany(t *testing.T) { + stack := NewPool(limit, create, destroy) + ch := make(chan lang.PlaceholderType, 1) + + for i := 0; i < limit; i++ { + var wait sync.WaitGroup + wait.Add(1) + go func() { + stack.Get() + ch <- lang.Placeholder + wait.Done() + }() + + wait.Wait() + select { + case <-ch: + default: + t.Fail() + } + } + + var waitGroup, pushWait sync.WaitGroup + waitGroup.Add(1) + pushWait.Add(1) + go func() { + pushWait.Done() + stack.Get() + waitGroup.Done() + }() + + pushWait.Wait() + stack.Put(1) + waitGroup.Wait() +} + +func TestPoolPopFirst(t *testing.T) { + var value int32 + stack := NewPool(limit, func() interface{} { + return atomic.AddInt32(&value, 1) + }, destroy) + + for i := 0; i < 100; i++ { + v := stack.Get().(int32) + assert.Equal(t, 1, int(v)) + stack.Put(v) + } +} + +func TestPoolWithMaxAge(t *testing.T) { + var value int32 + stack := NewPool(limit, func() interface{} { + return atomic.AddInt32(&value, 1) + }, destroy, WithMaxAge(time.Millisecond)) + + v1 := stack.Get().(int32) + // put nil should not matter + stack.Put(nil) + stack.Put(v1) + time.Sleep(time.Millisecond * 10) + v2 := stack.Get().(int32) + assert.NotEqual(t, v1, v2) +} + +func TestNewPoolPanics(t *testing.T) { + assert.Panics(t, func() { + NewPool(0, create, destroy) + }) +} + +func create() interface{} { + return 1 +} + +func destroy(_ interface{}) { +} diff --git a/core/syncx/refresource.go b/core/syncx/refresource.go new file mode 100644 index 00000000..f87fc61f --- /dev/null +++ b/core/syncx/refresource.go @@ -0,0 +1,48 @@ +package syncx + +import ( + "errors" + "sync" +) + +var ErrUseOfCleaned = errors.New("using a cleaned resource") + +type RefResource struct { + lock sync.Mutex + ref int32 + cleaned bool + clean func() +} + +func NewRefResource(clean func()) *RefResource { + return &RefResource{ + clean: clean, + } +} + +func (r *RefResource) Use() error { + r.lock.Lock() + defer r.lock.Unlock() + + if r.cleaned { + return ErrUseOfCleaned + } + + r.ref++ + return nil +} + +func (r *RefResource) Clean() { + r.lock.Lock() + defer r.lock.Unlock() + + if r.cleaned { + return + } + + r.ref-- + if r.ref == 0 { + r.cleaned = true + r.clean() + } +} diff --git a/core/syncx/refresource_test.go b/core/syncx/refresource_test.go new file mode 100644 index 00000000..1cc16882 --- /dev/null +++ b/core/syncx/refresource_test.go @@ -0,0 +1,27 @@ +package syncx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) 
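+// RefResource reference-counts a shared resource: every Use() adds a reference,
+// every Clean() drops one, and the clean func runs once the count reaches zero,
+// after which Use() returns ErrUseOfCleaned. A minimal sketch (illustrative only):
+//
+//    res := NewRefResource(func() { /* release the underlying resource */ })
+//    _ = res.Use()
+//    defer res.Clean()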
+ +func TestRefCleaner(t *testing.T) { + var count int + clean := func() { + count += 1 + } + + cleaner := NewRefResource(clean) + err := cleaner.Use() + assert.Nil(t, err) + err = cleaner.Use() + assert.Nil(t, err) + cleaner.Clean() + cleaner.Clean() + assert.Equal(t, 1, count) + cleaner.Clean() + cleaner.Clean() + assert.Equal(t, 1, count) + assert.Equal(t, ErrUseOfCleaned, cleaner.Use()) +} diff --git a/core/syncx/resourcemanager.go b/core/syncx/resourcemanager.go new file mode 100644 index 00000000..22b49504 --- /dev/null +++ b/core/syncx/resourcemanager.go @@ -0,0 +1,62 @@ +package syncx + +import ( + "io" + "sync" + + "zero/core/errorx" +) + +type ResourceManager struct { + resources map[string]io.Closer + sharedCalls SharedCalls + lock sync.RWMutex +} + +func NewResourceManager() *ResourceManager { + return &ResourceManager{ + resources: make(map[string]io.Closer), + sharedCalls: NewSharedCalls(), + } +} + +func (manager *ResourceManager) Close() error { + manager.lock.Lock() + defer manager.lock.Unlock() + + var be errorx.BatchError + for _, resource := range manager.resources { + if err := resource.Close(); err != nil { + be.Add(err) + } + } + + return be.Err() +} + +func (manager *ResourceManager) GetResource(key string, create func() (io.Closer, error)) (io.Closer, error) { + val, err := manager.sharedCalls.Do(key, func() (interface{}, error) { + manager.lock.RLock() + resource, ok := manager.resources[key] + manager.lock.RUnlock() + if ok { + return resource, nil + } + + resource, err := create() + if err != nil { + return nil, err + } + + manager.lock.Lock() + manager.resources[key] = resource + manager.lock.Unlock() + + return resource, nil + }) + if err != nil { + return nil, err + } + + return val.(io.Closer), nil +} diff --git a/core/syncx/resourcemanager_test.go b/core/syncx/resourcemanager_test.go new file mode 100644 index 00000000..2debc84c --- /dev/null +++ b/core/syncx/resourcemanager_test.go @@ -0,0 +1,46 @@ +package syncx + +import ( + "errors" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +type dummyResource struct { + age int +} + +func (dr *dummyResource) Close() error { + return errors.New("close") +} + +func TestResourceManager_GetResource(t *testing.T) { + manager := NewResourceManager() + defer manager.Close() + + var age int + for i := 0; i < 10; i++ { + val, err := manager.GetResource("key", func() (io.Closer, error) { + age++ + return &dummyResource{ + age: age, + }, nil + }) + assert.Nil(t, err) + assert.Equal(t, 1, val.(*dummyResource).age) + } +} + +func TestResourceManager_GetResourceError(t *testing.T) { + manager := NewResourceManager() + defer manager.Close() + + for i := 0; i < 10; i++ { + _, err := manager.GetResource("key", func() (io.Closer, error) { + return nil, errors.New("fail") + }) + assert.NotNil(t, err) + } +} diff --git a/core/syncx/sharedcalls.go b/core/syncx/sharedcalls.go new file mode 100644 index 00000000..18344eb6 --- /dev/null +++ b/core/syncx/sharedcalls.go @@ -0,0 +1,76 @@ +package syncx + +import "sync" + +type ( + // SharedCalls lets the concurrent calls with the same key to share the call result. + // For example, A called F, before it's done, B called F. Then B would not execute F, + // and shared the result returned by F which called by A. + // The calls with the same key are dependent, concurrent calls share the returned values. 
+ // A ------->calls F with key<------------------->returns val + // B --------------------->calls F with key------>returns val + SharedCalls interface { + Do(key string, fn func() (interface{}, error)) (interface{}, error) + DoEx(key string, fn func() (interface{}, error)) (interface{}, bool, error) + } + + call struct { + wg sync.WaitGroup + val interface{} + err error + } + + sharedGroup struct { + calls map[string]*call + lock sync.Mutex + } +) + +func NewSharedCalls() SharedCalls { + return &sharedGroup{ + calls: make(map[string]*call), + } +} + +func (g *sharedGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) { + g.lock.Lock() + if c, ok := g.calls[key]; ok { + g.lock.Unlock() + c.wg.Wait() + return c.val, c.err + } + + c := g.makeCall(key, fn) + return c.val, c.err +} + +func (g *sharedGroup) DoEx(key string, fn func() (interface{}, error)) (val interface{}, fresh bool, err error) { + g.lock.Lock() + if c, ok := g.calls[key]; ok { + g.lock.Unlock() + c.wg.Wait() + return c.val, false, c.err + } + + c := g.makeCall(key, fn) + return c.val, true, c.err +} + +func (g *sharedGroup) makeCall(key string, fn func() (interface{}, error)) *call { + c := new(call) + c.wg.Add(1) + g.calls[key] = c + g.lock.Unlock() + + defer func() { + // delete key first, done later. can't reverse the order, because if reverse, + // another Do call might wg.Wait() without get notified with wg.Done() + g.lock.Lock() + delete(g.calls, key) + g.lock.Unlock() + c.wg.Done() + }() + + c.val, c.err = fn() + return c +} diff --git a/core/syncx/sharedcalls_test.go b/core/syncx/sharedcalls_test.go new file mode 100644 index 00000000..d3c2a2a6 --- /dev/null +++ b/core/syncx/sharedcalls_test.go @@ -0,0 +1,108 @@ +package syncx + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestExclusiveCallDo(t *testing.T) { + g := NewSharedCalls() + v, err := g.Do("key", func() (interface{}, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestExclusiveCallDoErr(t *testing.T) { + g := NewSharedCalls() + someErr := errors.New("some error") + v, err := g.Do("key", func() (interface{}, error) { + return nil, someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr", err) + } + if v != nil { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestExclusiveCallDoDupSuppress(t *testing.T) { + g := NewSharedCalls() + c := make(chan string) + var calls int32 + fn := func() (interface{}, error) { + atomic.AddInt32(&calls, 1) + return <-c, nil + } + + const n = 10 + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + v, err := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + } + if v.(string) != "bar" { + t.Errorf("got %q; want %q", v, "bar") + } + wg.Done() + }() + } + time.Sleep(100 * time.Millisecond) // let goroutines above block + c <- "bar" + wg.Wait() + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("number of calls = %d; want 1", got) + } +} + +func TestExclusiveCallDoExDupSuppress(t *testing.T) { + g := NewSharedCalls() + c := make(chan string) + var calls int32 + fn := func() (interface{}, error) { + atomic.AddInt32(&calls, 1) + return <-c, nil + } + + const n = 10 + var wg sync.WaitGroup + var freshes int32 + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + v, fresh, err := g.DoEx("key", fn) + if err != nil { + 
t.Errorf("Do error: %v", err) + } + if fresh { + atomic.AddInt32(&freshes, 1) + } + if v.(string) != "bar" { + t.Errorf("got %q; want %q", v, "bar") + } + wg.Done() + }() + } + time.Sleep(100 * time.Millisecond) // let goroutines above block + c <- "bar" + wg.Wait() + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("number of calls = %d; want 1", got) + } + if got := atomic.LoadInt32(&freshes); got != 1 { + t.Errorf("freshes = %d; want 1", got) + } +} diff --git a/core/syncx/spinlock.go b/core/syncx/spinlock.go new file mode 100644 index 00000000..c4966d2e --- /dev/null +++ b/core/syncx/spinlock.go @@ -0,0 +1,24 @@ +package syncx + +import ( + "runtime" + "sync/atomic" +) + +type SpinLock struct { + lock uint32 +} + +func (sl *SpinLock) Lock() { + for !sl.TryLock() { + runtime.Gosched() + } +} + +func (sl *SpinLock) TryLock() bool { + return atomic.CompareAndSwapUint32(&sl.lock, 0, 1) +} + +func (sl *SpinLock) Unlock() { + atomic.StoreUint32(&sl.lock, 0) +} diff --git a/core/syncx/spinlock_test.go b/core/syncx/spinlock_test.go new file mode 100644 index 00000000..5ca99d84 --- /dev/null +++ b/core/syncx/spinlock_test.go @@ -0,0 +1,41 @@ +package syncx + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTryLock(t *testing.T) { + var lock SpinLock + assert.True(t, lock.TryLock()) + assert.False(t, lock.TryLock()) + lock.Unlock() + assert.True(t, lock.TryLock()) +} + +func TestSpinLock(t *testing.T) { + var lock SpinLock + lock.Lock() + assert.False(t, lock.TryLock()) + lock.Unlock() + assert.True(t, lock.TryLock()) +} + +func TestSpinLockRace(t *testing.T) { + var lock SpinLock + lock.Lock() + var wait sync.WaitGroup + wait.Add(1) + go func() { + lock.Lock() + lock.Unlock() + wait.Done() + }() + time.Sleep(time.Millisecond * 100) + lock.Unlock() + wait.Wait() + assert.True(t, lock.TryLock()) +} diff --git a/core/syncx/timeoutlimit.go b/core/syncx/timeoutlimit.go new file mode 100644 index 00000000..802c9629 --- /dev/null +++ b/core/syncx/timeoutlimit.go @@ -0,0 +1,51 @@ +package syncx + +import ( + "errors" + "time" +) + +var ErrTimeout = errors.New("borrow timeout") + +type TimeoutLimit struct { + limit Limit + cond *Cond +} + +func NewTimeoutLimit(n int) TimeoutLimit { + return TimeoutLimit{ + limit: NewLimit(n), + cond: NewCond(), + } +} + +func (l TimeoutLimit) Borrow(timeout time.Duration) error { + if l.TryBorrow() { + return nil + } + + var ok bool + for { + timeout, ok = l.cond.WaitWithTimeout(timeout) + if ok && l.TryBorrow() { + return nil + } + + if timeout <= 0 { + return ErrTimeout + } + } +} + +func (l TimeoutLimit) Return() error { + if err := l.limit.Return(); err != nil { + return err + } + + l.cond.Signal() + return nil +} + +func (l TimeoutLimit) TryBorrow() bool { + return l.limit.TryBorrow() +} diff --git a/core/syncx/timeoutlimit_test.go b/core/syncx/timeoutlimit_test.go new file mode 100644 index 00000000..a1e54704 --- /dev/null +++ b/core/syncx/timeoutlimit_test.go @@ -0,0 +1,33 @@ +package syncx + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTimeoutLimit(t *testing.T) { + limit := NewTimeoutLimit(2) + assert.Nil(t, limit.Borrow(time.Millisecond*200)) + assert.Nil(t, limit.Borrow(time.Millisecond*200)) + var wait1, wait2, wait3 sync.WaitGroup + wait1.Add(1) + wait2.Add(1) + wait3.Add(1) + go func() { + wait1.Wait() + wait2.Done() + assert.Nil(t, limit.Return()) + wait3.Done() + }() + wait1.Done() + wait2.Wait() + assert.Nil(t, limit.Borrow(time.Second)) + 
wait3.Wait() + assert.Equal(t, ErrTimeout, limit.Borrow(time.Millisecond*100)) + assert.Nil(t, limit.Return()) + assert.Nil(t, limit.Return()) + assert.Equal(t, ErrReturn, limit.Return()) +} diff --git a/core/sysx/automaxprocs.go b/core/sysx/automaxprocs.go new file mode 100644 index 00000000..820822ab --- /dev/null +++ b/core/sysx/automaxprocs.go @@ -0,0 +1,8 @@ +package sysx + +import "go.uber.org/automaxprocs/maxprocs" + +// Automatically set GOMAXPROCS to match Linux container CPU quota. +func init() { + maxprocs.Set(maxprocs.Logger(nil)) +} diff --git a/core/sysx/host.go b/core/sysx/host.go new file mode 100644 index 00000000..cc46902a --- /dev/null +++ b/core/sysx/host.go @@ -0,0 +1,19 @@ +package sysx + +import ( + "os" + + "zero/core/lang" +) + +var hostname string + +func init() { + var err error + hostname, err = os.Hostname() + lang.Must(err) +} + +func Hostname() string { + return hostname +} diff --git a/core/sysx/host_test.go b/core/sysx/host_test.go new file mode 100644 index 00000000..ae426b82 --- /dev/null +++ b/core/sysx/host_test.go @@ -0,0 +1,11 @@ +package sysx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHostname(t *testing.T) { + assert.True(t, len(Hostname()) > 0) +} diff --git a/core/threading/routinegroup.go b/core/threading/routinegroup.go new file mode 100644 index 00000000..07ca61f9 --- /dev/null +++ b/core/threading/routinegroup.go @@ -0,0 +1,37 @@ +package threading + +import "sync" + +type RoutineGroup struct { + waitGroup sync.WaitGroup +} + +func NewRoutineGroup() *RoutineGroup { + return new(RoutineGroup) +} + +// Don't reference the variables from outside, +// because outside variables can be changed by other goroutines +func (g *RoutineGroup) Run(fn func()) { + g.waitGroup.Add(1) + + go func() { + defer g.waitGroup.Done() + fn() + }() +} + +// Don't reference the variables from outside, +// because outside variables can be changed by other goroutines +func (g *RoutineGroup) RunSafe(fn func()) { + g.waitGroup.Add(1) + + GoSafe(func() { + defer g.waitGroup.Done() + fn() + }) +} + +func (g *RoutineGroup) Wait() { + g.waitGroup.Wait() +} diff --git a/core/threading/routinegroup_test.go b/core/threading/routinegroup_test.go new file mode 100644 index 00000000..454b3623 --- /dev/null +++ b/core/threading/routinegroup_test.go @@ -0,0 +1,45 @@ +package threading + +import ( + "io/ioutil" + "log" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRoutineGroupRun(t *testing.T) { + var count int32 + group := NewRoutineGroup() + for i := 0; i < 3; i++ { + group.Run(func() { + atomic.AddInt32(&count, 1) + }) + } + + group.Wait() + + assert.Equal(t, int32(3), count) +} + +func TestRoutingGroupRunSafe(t *testing.T) { + log.SetOutput(ioutil.Discard) + + var count int32 + group := NewRoutineGroup() + var once sync.Once + for i := 0; i < 3; i++ { + group.RunSafe(func() { + once.Do(func() { + panic("") + }) + atomic.AddInt32(&count, 1) + }) + } + + group.Wait() + + assert.Equal(t, int32(2), count) +} diff --git a/core/threading/routines.go b/core/threading/routines.go new file mode 100644 index 00000000..e41a5204 --- /dev/null +++ b/core/threading/routines.go @@ -0,0 +1,31 @@ +package threading + +import ( + "bytes" + "runtime" + "strconv" + + "zero/core/rescue" +) + +func GoSafe(fn func()) { + go RunSafe(fn) +} + +// Only for debug, never use it in production +func RoutineId() uint64 { + b := make([]byte, 64) + b = b[:runtime.Stack(b, false)] + b = bytes.TrimPrefix(b, []byte("goroutine ")) + b = 
b[:bytes.IndexByte(b, ' ')] + // if error, just return 0 + n, _ := strconv.ParseUint(string(b), 10, 64) + + return n +} + +func RunSafe(fn func()) { + defer rescue.Recover() + + fn() +} diff --git a/core/threading/routines_test.go b/core/threading/routines_test.go new file mode 100644 index 00000000..55f76a2e --- /dev/null +++ b/core/threading/routines_test.go @@ -0,0 +1,37 @@ +package threading + +import ( + "io/ioutil" + "log" + "testing" + + "zero/core/lang" + + "github.com/stretchr/testify/assert" +) + +func TestRoutineId(t *testing.T) { + assert.True(t, RoutineId() > 0) +} + +func TestRunSafe(t *testing.T) { + log.SetOutput(ioutil.Discard) + + i := 0 + + defer func() { + assert.Equal(t, 1, i) + }() + + ch := make(chan lang.PlaceholderType) + go RunSafe(func() { + defer func() { + ch <- lang.Placeholder + }() + + panic("panic") + }) + + <-ch + i++ +} diff --git a/core/threading/taskrunner.go b/core/threading/taskrunner.go new file mode 100644 index 00000000..e119db22 --- /dev/null +++ b/core/threading/taskrunner.go @@ -0,0 +1,28 @@ +package threading + +import ( + "zero/core/lang" + "zero/core/rescue" +) + +type TaskRunner struct { + limitChan chan lang.PlaceholderType +} + +func NewTaskRunner(concurrency int) *TaskRunner { + return &TaskRunner{ + limitChan: make(chan lang.PlaceholderType, concurrency), + } +} + +func (rp *TaskRunner) Schedule(task func()) { + rp.limitChan <- lang.Placeholder + + go func() { + defer rescue.Recover(func() { + <-rp.limitChan + }) + + task() + }() +} diff --git a/core/threading/taskrunner_test.go b/core/threading/taskrunner_test.go new file mode 100644 index 00000000..81cefc82 --- /dev/null +++ b/core/threading/taskrunner_test.go @@ -0,0 +1,37 @@ +package threading + +import ( + "runtime" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRoutinePool(t *testing.T) { + times := 100 + pool := NewTaskRunner(runtime.NumCPU()) + + var counter int32 + var waitGroup sync.WaitGroup + for i := 0; i < times; i++ { + waitGroup.Add(1) + pool.Schedule(func() { + atomic.AddInt32(&counter, 1) + waitGroup.Done() + }) + } + + waitGroup.Wait() + + assert.Equal(t, times, int(counter)) +} + +func BenchmarkRoutinePool(b *testing.B) { + queue := NewTaskRunner(runtime.NumCPU()) + for i := 0; i < b.N; i++ { + queue.Schedule(func() { + }) + } +} diff --git a/core/threading/workergroup.go b/core/threading/workergroup.go new file mode 100644 index 00000000..a5100820 --- /dev/null +++ b/core/threading/workergroup.go @@ -0,0 +1,21 @@ +package threading + +type WorkerGroup struct { + job func() + workers int +} + +func NewWorkerGroup(job func(), workers int) WorkerGroup { + return WorkerGroup{ + job: job, + workers: workers, + } +} + +func (wg WorkerGroup) Start() { + group := NewRoutineGroup() + for i := 0; i < wg.workers; i++ { + group.RunSafe(wg.job) + } + group.Wait() +} diff --git a/core/threading/workergroup_test.go b/core/threading/workergroup_test.go new file mode 100644 index 00000000..a2e978ec --- /dev/null +++ b/core/threading/workergroup_test.go @@ -0,0 +1,28 @@ +package threading + +import ( + "fmt" + "runtime" + "sync" + "testing" + + "zero/core/lang" + + "github.com/stretchr/testify/assert" +) + +func TestWorkerGroup(t *testing.T) { + m := make(map[string]lang.PlaceholderType) + var lock sync.Mutex + var wg sync.WaitGroup + wg.Add(runtime.NumCPU()) + group := NewWorkerGroup(func() { + lock.Lock() + m[fmt.Sprint(RoutineId())] = lang.Placeholder + lock.Unlock() + wg.Done() + }, runtime.NumCPU()) + go group.Start() + wg.Wait() + 
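	// A small sketch of the warning on RoutineGroup.Run/RunSafe above ("don't
	// reference the variables from outside"): copy a changing loop variable into a
	// local before capturing it, otherwise the goroutines race with the loop
	// (process is a hypothetical worker).
	//
	//	group := NewRoutineGroup()
	//	for i := 0; i < 3; i++ {
	//		i := i // shadow the loop variable so each goroutine gets its own copy
	//		group.Run(func() { process(i) })
	//	}
	//	group.Wait()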
assert.Equal(t, runtime.NumCPU(), len(m)) +} diff --git a/core/timex/relativetime.go b/core/timex/relativetime.go new file mode 100644 index 00000000..b11aeb4b --- /dev/null +++ b/core/timex/relativetime.go @@ -0,0 +1,18 @@ +package timex + +import "time" + +// Use the long enough past time as start time, in case timex.Now() - lastTime equals 0. +var initTime = time.Now().AddDate(-1, -1, -1) + +func Now() time.Duration { + return time.Since(initTime) +} + +func Since(d time.Duration) time.Duration { + return time.Since(initTime) - d +} + +func Time() time.Time { + return initTime.Add(Now()) +} diff --git a/core/timex/relativetime_test.go b/core/timex/relativetime_test.go new file mode 100644 index 00000000..fceb387b --- /dev/null +++ b/core/timex/relativetime_test.go @@ -0,0 +1,25 @@ +package timex + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRelativeTime(t *testing.T) { + time.Sleep(time.Millisecond) + now := Now() + assert.True(t, now > 0) + time.Sleep(time.Millisecond) + assert.True(t, Since(now) > 0) +} + +func TestRelativeTime_Time(t *testing.T) { + diff := Time().Sub(time.Now()) + if diff > 0 { + assert.True(t, diff < time.Second) + } else { + assert.True(t, -diff < time.Second) + } +} diff --git a/core/timex/repr.go b/core/timex/repr.go new file mode 100644 index 00000000..519423ed --- /dev/null +++ b/core/timex/repr.go @@ -0,0 +1,10 @@ +package timex + +import ( + "fmt" + "time" +) + +func ReprOfDuration(duration time.Duration) string { + return fmt.Sprintf("%.1fms", float32(duration)/float32(time.Millisecond)) +} diff --git a/core/timex/repr_test.go b/core/timex/repr_test.go new file mode 100644 index 00000000..f7874187 --- /dev/null +++ b/core/timex/repr_test.go @@ -0,0 +1,14 @@ +package timex + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestReprOfDuration(t *testing.T) { + assert.Equal(t, "1000.0ms", ReprOfDuration(time.Second)) + assert.Equal(t, "1111.6ms", ReprOfDuration( + time.Second+time.Millisecond*111+time.Microsecond*555)) +} diff --git a/core/timex/ticker.go b/core/timex/ticker.go new file mode 100644 index 00000000..2b7458e4 --- /dev/null +++ b/core/timex/ticker.go @@ -0,0 +1,73 @@ +package timex + +import ( + "errors" + "time" + + "zero/core/lang" +) + +type ( + Ticker interface { + Chan() <-chan time.Time + Stop() + } + + FakeTicker interface { + Ticker + Done() + Tick() + Wait(d time.Duration) error + } + + fakeTicker struct { + c chan time.Time + done chan lang.PlaceholderType + } + + realTicker struct { + *time.Ticker + } +) + +func NewTicker(d time.Duration) Ticker { + return &realTicker{ + Ticker: time.NewTicker(d), + } +} + +func (rt *realTicker) Chan() <-chan time.Time { + return rt.C +} + +func NewFakeTicker() FakeTicker { + return &fakeTicker{ + c: make(chan time.Time, 1), + done: make(chan lang.PlaceholderType, 1), + } +} + +func (ft *fakeTicker) Chan() <-chan time.Time { + return ft.c +} + +func (ft *fakeTicker) Done() { + ft.done <- lang.Placeholder +} + +func (ft *fakeTicker) Stop() { + close(ft.c) +} + +func (ft *fakeTicker) Tick() { + ft.c <- Time() +} + +func (ft *fakeTicker) Wait(d time.Duration) error { + select { + case <-time.After(d): + return errors.New("timeout") + case <-ft.done: + return nil + } +} diff --git a/core/timex/ticker_test.go b/core/timex/ticker_test.go new file mode 100644 index 00000000..5f10413e --- /dev/null +++ b/core/timex/ticker_test.go @@ -0,0 +1,46 @@ +package timex + +import ( + "sync/atomic" + "testing" + "time" + + 
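	// A minimal sketch of the relative-time helpers in relativetime.go, as seen by a
	// caller importing zero/core/timex (doSomething is a hypothetical function):
	// Now() measures from a start time fixed more than a year in the past, so, per
	// the comment there, a freshly taken value is always far from zero and a zero
	// duration can keep meaning "unset".
	//
	//	start := timex.Now()
	//	doSomething()
	//	elapsed := timex.Since(start) // time.Duration spent in doSomething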
"github.com/stretchr/testify/assert" +) + +func TestRealTickerDoTick(t *testing.T) { + ticker := NewTicker(time.Millisecond * 10) + defer ticker.Stop() + var count int + for range ticker.Chan() { + count++ + if count > 5 { + break + } + } +} + +func TestFakeTicker(t *testing.T) { + const total = 5 + ticker := NewFakeTicker() + defer ticker.Stop() + + var count int32 + go func() { + for { + select { + case <-ticker.Chan(): + if atomic.AddInt32(&count, 1) == total { + ticker.Done() + } + } + } + }() + + for i := 0; i < 5; i++ { + ticker.Tick() + } + + assert.Nil(t, ticker.Wait(time.Second)) + assert.Equal(t, int32(total), atomic.LoadInt32(&count)) +} diff --git a/core/trace/carrier.go b/core/trace/carrier.go new file mode 100644 index 00000000..0970656b --- /dev/null +++ b/core/trace/carrier.go @@ -0,0 +1,41 @@ +package trace + +import ( + "errors" + "net/http" + "strings" +) + +var ErrInvalidCarrier = errors.New("invalid carrier") + +type ( + Carrier interface { + Get(key string) string + Set(key, value string) + } + + httpCarrier http.Header + // grpc metadata takes keys as case insensitive + grpcCarrier map[string][]string +) + +func (h httpCarrier) Get(key string) string { + return http.Header(h).Get(key) +} + +func (h httpCarrier) Set(key, val string) { + http.Header(h).Set(key, val) +} + +func (g grpcCarrier) Get(key string) string { + if vals, ok := g[strings.ToLower(key)]; ok && len(vals) > 0 { + return vals[0] + } else { + return "" + } +} + +func (g grpcCarrier) Set(key, val string) { + key = strings.ToLower(key) + g[key] = append(g[key], val) +} diff --git a/core/trace/carrier_test.go b/core/trace/carrier_test.go new file mode 100644 index 00000000..ce7a29ab --- /dev/null +++ b/core/trace/carrier_test.go @@ -0,0 +1,59 @@ +package trace + +import ( + "net/http" + "net/http/httptest" + "testing" + + "zero/core/stringx" + + "github.com/stretchr/testify/assert" +) + +func TestHttpCarrier(t *testing.T) { + tests := []map[string]string{ + {}, + { + "first": "a", + "second": "b", + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + carrier := httpCarrier(req.Header) + for k, v := range test { + carrier.Set(k, v) + } + for k, v := range test { + assert.Equal(t, v, carrier.Get(k)) + } + assert.Equal(t, "", carrier.Get("none")) + }) + } +} + +func TestGrpcCarrier(t *testing.T) { + tests := []map[string]string{ + {}, + { + "first": "a", + "second": "b", + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + m := make(map[string][]string) + carrier := grpcCarrier(m) + for k, v := range test { + carrier.Set(k, v) + } + for k, v := range test { + assert.Equal(t, v, carrier.Get(k)) + } + assert.Equal(t, "", carrier.Get("none")) + }) + } +} diff --git a/core/trace/constants.go b/core/trace/constants.go new file mode 100644 index 00000000..e840de6b --- /dev/null +++ b/core/trace/constants.go @@ -0,0 +1,6 @@ +package trace + +const ( + traceIdKey = "X-Trace-ID" + spanIdKey = "X-Span-ID" +) diff --git a/core/trace/noop.go b/core/trace/noop.go new file mode 100644 index 00000000..fd048e3b --- /dev/null +++ b/core/trace/noop.go @@ -0,0 +1,33 @@ +package trace + +import ( + "context" + + "zero/core/trace/tracespec" +) + +var emptyNoopSpan = noopSpan{} + +type noopSpan struct{} + +func (s noopSpan) Finish() { +} + +func (s noopSpan) Follow(ctx context.Context, serviceName, operationName string) (context.Context, tracespec.Trace) { + return ctx, emptyNoopSpan +} 
+ +func (s noopSpan) Fork(ctx context.Context, serviceName, operationName string) (context.Context, tracespec.Trace) { + return ctx, emptyNoopSpan +} + +func (s noopSpan) SpanId() string { + return "" +} + +func (s noopSpan) TraceId() string { + return "" +} + +func (s noopSpan) Visit(fn func(key, val string) bool) { +} diff --git a/core/trace/noop_test.go b/core/trace/noop_test.go new file mode 100644 index 00000000..10bde427 --- /dev/null +++ b/core/trace/noop_test.go @@ -0,0 +1,32 @@ +package trace + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNoopSpan_Fork(t *testing.T) { + ctx, span := emptyNoopSpan.Fork(context.Background(), "", "") + assert.Equal(t, emptyNoopSpan, span) + assert.Equal(t, context.Background(), ctx) +} + +func TestNoopSpan_Follow(t *testing.T) { + ctx, span := emptyNoopSpan.Follow(context.Background(), "", "") + assert.Equal(t, emptyNoopSpan, span) + assert.Equal(t, context.Background(), ctx) +} + +func TestNoopSpan(t *testing.T) { + emptyNoopSpan.Visit(func(key, val string) bool { + assert.Fail(t, "should not go here") + return true + }) + + ctx, span := emptyNoopSpan.Follow(context.Background(), "", "") + assert.Equal(t, context.Background(), ctx) + assert.Equal(t, "", span.TraceId()) + assert.Equal(t, "", span.SpanId()) +} diff --git a/core/trace/propagator.go b/core/trace/propagator.go new file mode 100644 index 00000000..39b18479 --- /dev/null +++ b/core/trace/propagator.go @@ -0,0 +1,85 @@ +package trace + +import ( + "net/http" + + "google.golang.org/grpc/metadata" +) + +const ( + HttpFormat = iota + GrpcFormat +) + +var ( + emptyHttpPropagator httpPropagator + emptyGrpcPropagator grpcPropagator +) + +type ( + Propagator interface { + Extract(carrier interface{}) (Carrier, error) + Inject(carrier interface{}) (Carrier, error) + } + + httpPropagator struct{} + grpcPropagator struct{} +) + +func (h httpPropagator) Extract(carrier interface{}) (Carrier, error) { + if c, ok := carrier.(http.Header); !ok { + return nil, ErrInvalidCarrier + } else { + return httpCarrier(c), nil + } +} + +func (h httpPropagator) Inject(carrier interface{}) (Carrier, error) { + if c, ok := carrier.(http.Header); ok { + return httpCarrier(c), nil + } else { + return nil, ErrInvalidCarrier + } +} + +func (g grpcPropagator) Extract(carrier interface{}) (Carrier, error) { + if c, ok := carrier.(metadata.MD); ok { + return grpcCarrier(c), nil + } else { + return nil, ErrInvalidCarrier + } +} + +func (g grpcPropagator) Inject(carrier interface{}) (Carrier, error) { + if c, ok := carrier.(metadata.MD); ok { + return grpcCarrier(c), nil + } else { + return nil, ErrInvalidCarrier + } +} + +func Extract(format, carrier interface{}) (Carrier, error) { + switch v := format.(type) { + case int: + if v == HttpFormat { + return emptyHttpPropagator.Extract(carrier) + } else if v == GrpcFormat { + return emptyGrpcPropagator.Extract(carrier) + } + } + + return nil, ErrInvalidCarrier +} + +func Inject(format, carrier interface{}) (Carrier, error) { + switch v := format.(type) { + case int: + if v == HttpFormat { + return emptyHttpPropagator.Inject(carrier) + } else if v == GrpcFormat { + return emptyGrpcPropagator.Inject(carrier) + } + } + + return nil, ErrInvalidCarrier +} diff --git a/core/trace/propagator_test.go b/core/trace/propagator_test.go new file mode 100644 index 00000000..a814e1b4 --- /dev/null +++ b/core/trace/propagator_test.go @@ -0,0 +1,68 @@ +package trace + +import ( + "net/http" + "net/http/httptest" + "testing" + + 
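	// A minimal sketch of how the propagators above are meant to be wired around an
	// HTTP hop, using only the Extract/Inject functions and the X-Trace-ID/X-Span-ID
	// headers defined in this package (r and req are hypothetical *http.Request
	// values, span a tracespec.Trace):
	//
	//	// server side: read the incoming trace identifiers
	//	carrier, err := trace.Extract(trace.HttpFormat, r.Header)
	//
	//	// client side: stamp the identifiers onto an outgoing request
	//	carrier, err = trace.Inject(trace.HttpFormat, req.Header)
	//	carrier.Set("X-Trace-ID", span.TraceId())
	//	carrier.Set("X-Span-ID", span.SpanId())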
"github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" +) + +func TestHttpPropagator_Extract(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set(traceIdKey, "trace") + req.Header.Set(spanIdKey, "span") + carrier, err := Extract(HttpFormat, req.Header) + assert.Nil(t, err) + assert.Equal(t, "trace", carrier.Get(traceIdKey)) + assert.Equal(t, "span", carrier.Get(spanIdKey)) + + carrier, err = Extract(HttpFormat, req) + assert.Equal(t, ErrInvalidCarrier, err) +} + +func TestHttpPropagator_Inject(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set(traceIdKey, "trace") + req.Header.Set(spanIdKey, "span") + carrier, err := Inject(HttpFormat, req.Header) + assert.Nil(t, err) + assert.Equal(t, "trace", carrier.Get(traceIdKey)) + assert.Equal(t, "span", carrier.Get(spanIdKey)) + + carrier, err = Inject(HttpFormat, req) + assert.Equal(t, ErrInvalidCarrier, err) +} + +func TestGrpcPropagator_Extract(t *testing.T) { + md := metadata.New(map[string]string{ + traceIdKey: "trace", + spanIdKey: "span", + }) + carrier, err := Extract(GrpcFormat, md) + assert.Nil(t, err) + assert.Equal(t, "trace", carrier.Get(traceIdKey)) + assert.Equal(t, "span", carrier.Get(spanIdKey)) + + carrier, err = Extract(GrpcFormat, 1) + assert.Equal(t, ErrInvalidCarrier, err) + carrier, err = Extract(nil, 1) + assert.Equal(t, ErrInvalidCarrier, err) +} + +func TestGrpcPropagator_Inject(t *testing.T) { + md := metadata.New(map[string]string{ + traceIdKey: "trace", + spanIdKey: "span", + }) + carrier, err := Inject(GrpcFormat, md) + assert.Nil(t, err) + assert.Equal(t, "trace", carrier.Get(traceIdKey)) + assert.Equal(t, "span", carrier.Get(spanIdKey)) + + carrier, err = Inject(GrpcFormat, 1) + assert.Equal(t, ErrInvalidCarrier, err) + carrier, err = Inject(nil, 1) + assert.Equal(t, ErrInvalidCarrier, err) +} diff --git a/core/trace/span.go b/core/trace/span.go new file mode 100644 index 00000000..83bfc0b4 --- /dev/null +++ b/core/trace/span.go @@ -0,0 +1,144 @@ +package trace + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "zero/core/stringx" + "zero/core/timex" + "zero/core/trace/tracespec" +) + +const ( + initSpanId = "0" + clientFlag = "client" + serverFlag = "server" + spanSepRune = '.' 
+ timeFormat = "2006-01-02 15:04:05.000" +) + +var spanSep = string([]byte{spanSepRune}) + +type Span struct { + ctx spanContext + serviceName string + operationName string + startTime time.Time + flag string + children int +} + +func newServerSpan(carrier Carrier, serviceName, operationName string) tracespec.Trace { + traceId := stringx.TakeWithPriority(func() string { + if carrier != nil { + return carrier.Get(traceIdKey) + } + return "" + }, func() string { + return stringx.RandId() + }) + spanId := stringx.TakeWithPriority(func() string { + if carrier != nil { + return carrier.Get(spanIdKey) + } + return "" + }, func() string { + return initSpanId + }) + + return &Span{ + ctx: spanContext{ + traceId: traceId, + spanId: spanId, + }, + serviceName: serviceName, + operationName: operationName, + startTime: timex.Time(), + flag: serverFlag, + } +} + +func (s *Span) Finish() { +} + +func (s *Span) Follow(ctx context.Context, serviceName, operationName string) (context.Context, tracespec.Trace) { + span := &Span{ + ctx: spanContext{ + traceId: s.ctx.traceId, + spanId: s.followSpanId(), + }, + serviceName: serviceName, + operationName: operationName, + startTime: timex.Time(), + flag: s.flag, + } + return context.WithValue(ctx, tracespec.TracingKey, span), span +} + +func (s *Span) Fork(ctx context.Context, serviceName, operationName string) (context.Context, tracespec.Trace) { + span := &Span{ + ctx: spanContext{ + traceId: s.ctx.traceId, + spanId: s.forkSpanId(), + }, + serviceName: serviceName, + operationName: operationName, + startTime: timex.Time(), + flag: clientFlag, + } + return context.WithValue(ctx, tracespec.TracingKey, span), span +} + +func (s *Span) SpanId() string { + return s.ctx.SpanId() +} + +func (s *Span) TraceId() string { + return s.ctx.TraceId() +} + +func (s *Span) Visit(fn func(key, val string) bool) { + s.ctx.Visit(fn) +} + +func (s *Span) forkSpanId() string { + s.children++ + return fmt.Sprintf("%s.%d", s.ctx.spanId, s.children) +} + +func (s *Span) followSpanId() string { + fields := strings.FieldsFunc(s.ctx.spanId, func(r rune) bool { + return r == spanSepRune + }) + if len(fields) == 0 { + return s.ctx.spanId + } + + last := fields[len(fields)-1] + val, err := strconv.Atoi(last) + if err != nil { + return s.ctx.spanId + } + + last = strconv.Itoa(val + 1) + fields[len(fields)-1] = last + + return strings.Join(fields, spanSep) +} + +func StartClientSpan(ctx context.Context, serviceName, operationName string) (context.Context, tracespec.Trace) { + if span, ok := ctx.Value(tracespec.TracingKey).(*Span); ok { + return span.Fork(ctx, serviceName, operationName) + } + + return ctx, emptyNoopSpan +} + +func StartServerSpan(ctx context.Context, carrier Carrier, serviceName, operationName string) ( + context.Context, tracespec.Trace) { + span := newServerSpan(carrier, serviceName, operationName) + return context.WithValue(ctx, tracespec.TracingKey, span), span +} diff --git a/core/trace/span_test.go b/core/trace/span_test.go new file mode 100644 index 00000000..f86de421 --- /dev/null +++ b/core/trace/span_test.go @@ -0,0 +1,140 @@ +package trace + +import ( + "context" + "testing" + + "zero/core/stringx" + "zero/core/trace/tracespec" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" +) + +func TestClientSpan(t *testing.T) { + span := newServerSpan(nil, "service", "operation") + ctx := context.WithValue(context.Background(), tracespec.TracingKey, span) + ctx, span = StartClientSpan(ctx, "entrance", "operation") + defer span.Finish() + 
assert.Equal(t, span, ctx.Value(tracespec.TracingKey)) + + const serviceName = "authorization" + const operationName = "verification" + ctx, childSpan := span.Fork(ctx, serviceName, operationName) + defer childSpan.Finish() + + assert.Equal(t, childSpan, ctx.Value(tracespec.TracingKey)) + assert.Equal(t, getSpan(span).TraceId(), getSpan(childSpan).TraceId()) + assert.Equal(t, "0.1.1", getSpan(childSpan).SpanId()) + assert.Equal(t, serviceName, childSpan.(*Span).serviceName) + assert.Equal(t, operationName, childSpan.(*Span).operationName) + assert.Equal(t, clientFlag, childSpan.(*Span).flag) +} + +func TestClientSpan_WithoutTrace(t *testing.T) { + ctx, span := StartClientSpan(context.Background(), "entrance", "operation") + defer span.Finish() + assert.Equal(t, emptyNoopSpan, span) + assert.Equal(t, context.Background(), ctx) +} + +func TestServerSpan(t *testing.T) { + ctx, span := StartServerSpan(context.Background(), nil, "service", "operation") + defer span.Finish() + assert.Equal(t, span, ctx.Value(tracespec.TracingKey)) + + const serviceName = "authorization" + const operationName = "verification" + ctx, childSpan := span.Fork(ctx, serviceName, operationName) + defer childSpan.Finish() + + assert.Equal(t, childSpan, ctx.Value(tracespec.TracingKey)) + assert.Equal(t, getSpan(span).TraceId(), getSpan(childSpan).TraceId()) + assert.Equal(t, "0.1", getSpan(childSpan).SpanId()) + assert.Equal(t, serviceName, childSpan.(*Span).serviceName) + assert.Equal(t, operationName, childSpan.(*Span).operationName) + assert.Equal(t, clientFlag, childSpan.(*Span).flag) +} + +func TestServerSpan_WithCarrier(t *testing.T) { + md := metadata.New(map[string]string{ + traceIdKey: "a", + spanIdKey: "0.1", + }) + ctx, span := StartServerSpan(context.Background(), grpcCarrier(md), "service", "operation") + defer span.Finish() + assert.Equal(t, span, ctx.Value(tracespec.TracingKey)) + + const serviceName = "authorization" + const operationName = "verification" + ctx, childSpan := span.Fork(ctx, serviceName, operationName) + defer childSpan.Finish() + + assert.Equal(t, childSpan, ctx.Value(tracespec.TracingKey)) + assert.Equal(t, getSpan(span).TraceId(), getSpan(childSpan).TraceId()) + assert.Equal(t, "0.1.1", getSpan(childSpan).SpanId()) + assert.Equal(t, serviceName, childSpan.(*Span).serviceName) + assert.Equal(t, operationName, childSpan.(*Span).operationName) + assert.Equal(t, clientFlag, childSpan.(*Span).flag) +} + +func TestSpan_Follow(t *testing.T) { + tests := []struct { + span string + expectSpan string + }{ + { + "0.1", + "0.2", + }, + { + "0", + "1", + }, + { + "a", + "a", + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + md := metadata.New(map[string]string{ + traceIdKey: "a", + spanIdKey: test.span, + }) + ctx, span := StartServerSpan(context.Background(), grpcCarrier(md), + "service", "operation") + defer span.Finish() + assert.Equal(t, span, ctx.Value(tracespec.TracingKey)) + + const serviceName = "authorization" + const operationName = "verification" + ctx, childSpan := span.Follow(ctx, serviceName, operationName) + defer childSpan.Finish() + + assert.Equal(t, childSpan, ctx.Value(tracespec.TracingKey)) + assert.Equal(t, getSpan(span).TraceId(), getSpan(childSpan).TraceId()) + assert.Equal(t, test.expectSpan, getSpan(childSpan).SpanId()) + assert.Equal(t, serviceName, childSpan.(*Span).serviceName) + assert.Equal(t, operationName, childSpan.(*Span).operationName) + assert.Equal(t, span.(*Span).flag, childSpan.(*Span).flag) + }) + } +} + +func 
TestSpan_Visit(t *testing.T) { + var run bool + span := newServerSpan(nil, "service", "operation") + span.Visit(func(key, val string) bool { + assert.True(t, len(key) > 0) + assert.True(t, len(val) > 0) + run = true + return true + }) + assert.True(t, run) +} + +func getSpan(span tracespec.Trace) tracespec.Trace { + return span.(*Span) +} diff --git a/core/trace/spancontext.go b/core/trace/spancontext.go new file mode 100644 index 00000000..dd78d971 --- /dev/null +++ b/core/trace/spancontext.go @@ -0,0 +1,19 @@ +package trace + +type spanContext struct { + traceId string + spanId string +} + +func (sc spanContext) TraceId() string { + return sc.traceId +} + +func (sc spanContext) SpanId() string { + return sc.spanId +} + +func (sc spanContext) Visit(fn func(key, val string) bool) { + fn(traceIdKey, sc.traceId) + fn(spanIdKey, sc.spanId) +} diff --git a/core/trace/tracespec/spancontext.go b/core/trace/tracespec/spancontext.go new file mode 100644 index 00000000..4c140837 --- /dev/null +++ b/core/trace/tracespec/spancontext.go @@ -0,0 +1,7 @@ +package tracespec + +type SpanContext interface { + TraceId() string + SpanId() string + Visit(fn func(key, val string) bool) +} diff --git a/core/trace/tracespec/trace.go b/core/trace/tracespec/trace.go new file mode 100644 index 00000000..74db6ff7 --- /dev/null +++ b/core/trace/tracespec/trace.go @@ -0,0 +1,10 @@ +package tracespec + +import "context" + +type Trace interface { + SpanContext + Finish() + Fork(ctx context.Context, serviceName, operationName string) (context.Context, Trace) + Follow(ctx context.Context, serviceName, operationName string) (context.Context, Trace) +} diff --git a/core/trace/tracespec/vars.go b/core/trace/tracespec/vars.go new file mode 100644 index 00000000..35a2ee18 --- /dev/null +++ b/core/trace/tracespec/vars.go @@ -0,0 +1,3 @@ +package tracespec + +const TracingKey = "X-Trace" diff --git a/core/utils/report.go b/core/utils/report.go new file mode 100644 index 00000000..e3a11c39 --- /dev/null +++ b/core/utils/report.go @@ -0,0 +1,5 @@ +package utils + +func Report(content string) { + // TODO: implement the report method +} diff --git a/core/utils/times.go b/core/utils/times.go new file mode 100644 index 00000000..b799f511 --- /dev/null +++ b/core/utils/times.go @@ -0,0 +1,38 @@ +package utils + +import ( + "fmt" + "time" + + "zero/core/timex" +) + +type ElapsedTimer struct { + start time.Duration +} + +func NewElapsedTimer() *ElapsedTimer { + return &ElapsedTimer{ + start: timex.Now(), + } +} + +func (et *ElapsedTimer) Duration() time.Duration { + return timex.Since(et.start) +} + +func (et *ElapsedTimer) Elapsed() string { + return timex.Since(et.start).String() +} + +func (et *ElapsedTimer) ElapsedMs() string { + return fmt.Sprintf("%.1fms", float32(timex.Since(et.start))/float32(time.Millisecond)) +} + +func CurrentMicros() int64 { + return time.Now().UnixNano() / int64(time.Microsecond) +} + +func CurrentMillis() int64 { + return time.Now().UnixNano() / int64(time.Millisecond) +} diff --git a/core/utils/times_test.go b/core/utils/times_test.go new file mode 100644 index 00000000..2fa8f0f2 --- /dev/null +++ b/core/utils/times_test.go @@ -0,0 +1,40 @@ +package utils + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const sleepInterval = time.Millisecond * 10 + +func TestElapsedTimer_Duration(t *testing.T) { + timer := NewElapsedTimer() + time.Sleep(sleepInterval) + assert.True(t, timer.Duration() >= sleepInterval) +} + +func TestElapsedTimer_Elapsed(t *testing.T) { + timer := 
NewElapsedTimer() + time.Sleep(sleepInterval) + duration, err := time.ParseDuration(timer.Elapsed()) + assert.Nil(t, err) + assert.True(t, duration >= sleepInterval) +} + +func TestElapsedTimer_ElapsedMs(t *testing.T) { + timer := NewElapsedTimer() + time.Sleep(sleepInterval) + duration, err := time.ParseDuration(timer.ElapsedMs()) + assert.Nil(t, err) + assert.True(t, duration >= sleepInterval) +} + +func TestCurrent(t *testing.T) { + currentMillis := CurrentMillis() + currentMicros := CurrentMicros() + assert.True(t, currentMillis > 0) + assert.True(t, currentMicros > 0) + assert.True(t, currentMillis*1000 <= currentMicros) +} diff --git a/core/utils/uuid.go b/core/utils/uuid.go new file mode 100644 index 00000000..237235db --- /dev/null +++ b/core/utils/uuid.go @@ -0,0 +1,7 @@ +package utils + +import "github.com/google/uuid" + +func NewUuid() string { + return uuid.New().String() +} diff --git a/core/utils/uuid_test.go b/core/utils/uuid_test.go new file mode 100644 index 00000000..27af5c17 --- /dev/null +++ b/core/utils/uuid_test.go @@ -0,0 +1,11 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUuid(t *testing.T) { + assert.Equal(t, 36, len(NewUuid())) +} diff --git a/core/utils/version.go b/core/utils/version.go new file mode 100644 index 00000000..bb6b4b77 --- /dev/null +++ b/core/utils/version.go @@ -0,0 +1,37 @@ +package utils + +import ( + "strconv" + "strings" +) + +// returns -1 if the first version is lower than the second, 0 if they are equal, and 1 if the second is lower. +func CompareVersions(a, b string) int { + as := strings.Split(a, ".") + bs := strings.Split(b, ".") + var loop int + if len(as) > len(bs) { + loop = len(as) + } else { + loop = len(bs) + } + + for i := 0; i < loop; i++ { + var x, y string + if len(as) > i { + x = as[i] + } + if len(bs) > i { + y = bs[i] + } + xi, _ := strconv.Atoi(x) + yi, _ := strconv.Atoi(y) + if xi > yi { + return 1 + } else if xi < yi { + return -1 + } + } + + return 0 +} diff --git a/core/utils/version_test.go b/core/utils/version_test.go new file mode 100644 index 00000000..b751ced8 --- /dev/null +++ b/core/utils/version_test.go @@ -0,0 +1,30 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCompareVersions(t *testing.T) { + cases := []struct { + ver1 string + ver2 string + out int + }{ + {"1", "1.0.1", -1}, + {"1.0.1", "1.0.2", -1}, + {"1.0.3", "1.1", -1}, + {"1.1", "1.1.1", -1}, + {"1.3.2", "1.2", 1}, + {"1.1.1", "1.1.1", 0}, + {"1.1.0", "1.1", 0}, + } + + for _, each := range cases { + t.Run(each.ver1, func(t *testing.T) { + actual := CompareVersions(each.ver1, each.ver2) + assert.Equal(t, each.out, actual) + }) + } +} diff --git a/doc/breaker.md b/doc/breaker.md new file mode 100644 index 00000000..52eece74 --- /dev/null +++ b/doc/breaker.md @@ -0,0 +1,4 @@ +# 熔断机制设计 + +## 设计目的 +* 依赖的服务出现大规模故障时,调用方应该尽可能少调用,降低故障服务的压力,使之尽快恢复服务 \ No newline at end of file diff --git a/doc/images/shedding_flying.jpg b/doc/images/shedding_flying.jpg new file mode 100644 index 00000000..2ca38834 Binary files /dev/null and b/doc/images/shedding_flying.jpg differ diff --git a/doc/kubernetes_setup.md b/doc/kubernetes_setup.md new file mode 100644 index 00000000..a413dd53 --- /dev/null +++ b/doc/kubernetes_setup.md @@ -0,0 +1,142 @@ +# kubernetes集群搭建(centos7) + +* 修改每台主机的hostname,如果需要的话 + * `hostname ` + * 修改/etc/hostname + +* 选择一台机器安装ansible,为了便于从一台机器上操作所有机器 + * 安装zsh & oh-my-zsh,为了更方便的使用命令行(可选) + + ``` + yum install -y zsh + yum install -y git + sh -c "$(curl 
-fsSL https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" + ``` + + * `yum install -y ansible` + * 解决错误 `RequestsDependencyWarning: urllib3 (1.22) or chardet (2.2.1) doesn't match a supported version` + + ``` + pip uninstall -y urllib3 + pip uninstall -y chardet + pip install requests + ``` + + * 禁用command_warnings,在/etc/ansible/ansible.cfg里将`command_warnings = False`前面的#去掉 + * 将所有机器的内网ip按照分组增加到/etc/ansible/hosts,如下: + + ``` + [master] + 172.20.102.[208:210] + + [node] + 172.20.102.[211:212] + ``` + + * 用root账号通过ssh-keygen生成内网无需密码root登录其它服务器,使用默认选项 + * 用ssh-copy-id将生成的id_rsa.pub传送到所有主机的authorized_hosts里,包括本机,如: + + `ssh-copy-id root@172.20.102.208` + * 验证ansible是否可以登录所有服务器,如下: + + ``` + [root@172 ~]# ansible all -m ping -u root + 172.20.102.208 | SUCCESS => { + "changed": false, + "ping": "pong" + } + ... + ``` + +* 更新所有服务器 + + `ansible all -u root -m shell -a "yum update -y"` + +* 所有服务器上安装docker + + ``` + ansible all -u root -m shell -a "yum remove docker docker-client docker-client-latest docker-core docker-latest docker-latest-logrotate docker-logrotate docker-selinux docker-engine-selinux docker-engine" + ansible all -u root -m shell -a "yum install -y yum-utils" + ansible all -u root -m shell -a "yum install -y device-mapper-persistent-data" + ansible all -u root -m shell -a "yum install -y lvm2" + ansible all -u root -m shell -a "yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo" + ansible all -u root -m shell -a "yum install -y docker-ce" + ansible all -u root -m shell -a "systemctl enable docker" + ansible all -u root -m shell -a "systemctl start docker" + ``` + +* 每台机器上添加阿里云的kubernetes repo + + ``` + # cat k8srepo.yaml + cat < /etc/yum.repos.d/kubernetes.repo + [kubernetes] + name=Kubernetes + baseurl=http://mirrors.aliyun.com/kubernetes/yum/repos/kubernetes-el7-x86_64 + enabled=1 + gpgcheck=0 + repo_gpgcheck=0 + gpgkey=http://mirrors.aliyun.com/kubernetes/yum/doc/yum-key.gpg + http://mirrors.aliyun.com/kubernetes/yum/doc/rpm-package-key.gpg + EOF + + # ansible-playbook k8srepo.yaml + ``` + +* 安装kubelet, kubeadm, kubectl, ipvsadm + + `ansible all -u root -m shell -a "yum install -y kubelet kubeadm kubectl ipvsadm"` + +* 禁用所有服务器上的swap + + `ansible all -u root -m shell -a "swapoff -a"` + +* 允许所有服务器进行转发,因为k8s的NodePort需要在所有服务器之间进行转发 + + `ansible all -u root -m shell -a "iptables -P FORWARD ACCEPT"` + +* 由于k8s.gcr.io不能访问,需要从本机科学上网docker pull如下几个image + + ``` + k8s.gcr.io/kube-proxy-amd64:v1.11.1 + k8s.gcr.io/kube-controller-manager-amd64:v1.11.1 + k8s.gcr.io/kube-scheduler-amd64:v1.11.1 + k8s.gcr.io/kube-apiserver-amd64:v1.11.1 + k8s.gcr.io/coredns:1.1.3 + k8s.gcr.io/etcd-amd64:3.2.18 + k8s.gcr.io/pause:3.1 + ``` + + 通过命令一次完成拉取 + + `while IFS= read -r line; do docker pull $line; done` + + 然后上传到一台服务器 + + `while IFS= read -r line; do docker save | pv | ssh @ "docker load"; done` + + 同步到所有k8s服务器,其中$i是为了匹配所有内网ip + + `while IFS= read -r line; do for ((i=208;i<213;i++)); do docker save $line | ssh root@172.20.102.$i "docker load"; done; done` + +* 在一台master服务器上初始化集群 + + `kubeadm init --api-advertise-addresses <本机内网ip> --kubernetes-version=v1.11.1` + + 注意最后的`kubeadm join`一行,用来在其它服务器加入集群(稍后用) + + 初始化配置,master上执行如下命令 + + ``` + mkdir -p $HOME/.kube + sudo cp -i /etc/kubernetes/admin.conf $HOME/.kube/config + sudo chown $(id -u):$(id -g) $HOME/.kube/config + ``` + + 添加calico网络 + + `kubectl apply -f https://docs.projectcalico.org/v3.1/getting-started/kubernetes/installation/hosted/kubeadm/1.7/calico.yaml` + +* 
Run the `kubeadm join` command obtained on the master from all the other servers; it contains the token needed to join the cluster
+
+* Run `kubectl get nodes` to verify that the cluster came up successfully
diff --git a/doc/loadshedding.md b/doc/loadshedding.md
new file mode 100644
index 00000000..b35239c9
--- /dev/null
+++ b/doc/loadshedding.md
@@ -0,0 +1,46 @@
+# Adaptive Load Shedding Design
+
+## Goals
+* Keep the system from being dragged down by excessive requests
+* Provide the highest possible throughput while keeping the system stable
+
+## Design considerations
+* How to measure system load
+  * When running in a VM or container, the load has to be read from the cgroup
+  * 1000m represents 100% CPU; 800m is the recommended threshold for high load
+* Keep the overhead as small as possible and do not noticeably increase RT
+* Failures of the DB or cache the service depends on are out of scope; they are handled by the circuit breaker
+
+## Mechanism
+* CPU load is smoothed with a moving average to damp the jitter of instantaneous readings; see the references for moving averages
+  * A moving average approximates the mean of the previous N consecutive samples, where N is controlled by the hyperparameter beta
+  * Load shedding is triggered when the CPU load exceeds the configured threshold
+* A sliding time window records the QPS and RT (response time) of the recent past
+  * The window spans 5 seconds with 50 buckets; each bucket holds the requests of a 100ms slice, and the buckets are reused in a ring, the newest overwriting the oldest
+  * When computing maxQPS and minRT, the newest, still-filling bucket is skipped, because it may contain only a handful of requests whose RT happens to be an unrepresentative minimum; with the 50-bucket configuration above, only 49 buckets are actually counted
+* A request is rejected only when all of the following conditions hold
+  1. The current CPU load exceeds the preset threshold, or less than 1 second has passed since the last rejection (the cool-down period). The cool-down keeps load from being pushed right back up the moment it drops, which would otherwise cause back-and-forth oscillation
+  2. `averageFlying > max(1, QPS*minRT/1e3)`
+     * averageFlying = MovingAverage(flying)
+     * When computing MovingAverage(flying), the hyperparameter beta defaults to 0.9, which roughly averages the previous ten flying samples
+     * There are three possible points to sample flying:
+       1. Update averageFlying once when a request starts, the orange curve in the figure
+       2. Update averageFlying once when a request finishes, the green curve in the figure
+       3. Update averageFlying both when a request starts and when it finishes
+
+       We use the second one, which suppresses jitter best, as the figure shows:
+       ![comparison of flying strategies](images/shedding_flying.jpg)
+     * QPS = maxPass * bucketsPerSecond
+       * maxPass is the number of accepted requests in each valid bucket
+       * bucketsPerSecond is the number of buckets per second
+     * 1e3 is 1000 milliseconds; minRT is also in milliseconds, so QPS*minRT/1e3 is the average number of concurrent in-flight requests (for example, QPS = 500 and minRT = 20ms allow max(1, 500*20/1e3) = 10 in-flight requests before shedding starts)
+
+## Using load shedding
+* An optional setting has been added to the ngin and rpcx frameworks
+  * CpuThreshold: setting it to a value greater than 0 enables automatic load shedding for that service
+* When a request is dropped, the error log contains the keyword `dropreq`
+
+## References
+* [Moving average](https://www.cnblogs.com/wuliytTaotao/p/9479958.html)
+* [Sentinel adaptive flow control](https://github.com/alibaba/Sentinel/wiki/%E7%B3%BB%E7%BB%9F%E8%87%AA%E9%80%82%E5%BA%94%E9%99%90%E6%B5%81)
+* [Kratos adaptive rate limiting](https://github.com/bilibili/kratos/blob/master/doc/wiki-cn/ratelimit.md)
\ No newline at end of file
diff --git a/doc/mapreduce.md b/doc/mapreduce.md
new file mode 100644
index 00000000..4ce44150
--- /dev/null
+++ b/doc/mapreduce.md
@@ -0,0 +1,29 @@
+# mapreduce usage
+
+## Map
+
+> channel is the return value of Map
+
+Since Map runs concurrently, if the returned channel is not consumed with range or drained, the code inside Map may still be writing to it while the caller reads the result, which can lead to incomplete data or a `concurrent read write` error
+
+* To collect the results produced by Map, consume the channel like this
+
+  ```
+  for v := range channel {
+      // v is of type interface{}
+  }
+  ```
+
+* If the results are not needed, mapreduce.Drain must be called explicitly, e.g.
+
+  ```
+  mapreduce.Drain(channel)
+  ```
+
+## MapReduce
+
+* cancel may be called inside the mapper and the reducer; once cancel is called, the return value is `nil, false`
+* If the mapper does not write an item to the writer, that item is not collected by reduce
+* If the mapper panics while handling an item, that item is not collected by reduce either
+* reduce runs single-threaded; everything the mappers emit is processed serially there
+* If reduce does not write to the writer, or panics, the result is `nil, false`
\ No newline at end of file
diff --git a/doc/periodicalexecutor.md b/doc/periodicalexecutor.md
new file mode 100644
index 00000000..bec04c35
--- /dev/null
+++ b/doc/periodicalexecutor.md
@@ -0,0 +1,15 @@
+# PeriodicalExecutor Design
+
+# Adding a task
+
+* If there is no pending task
+  * Add the task and start the timer
+* If there are pending tasks
+  * Add the task and check whether the buffer limit has been reached
+    * If it has, execute all buffered tasks
+    * If not, just add the task
+
+# Timer expiration
+
+* Flush and execute all buffered tasks
+* Wait another N timer periods; if no new task arrives during the wait, exit
\ No newline at end of file
diff --git a/doc/rpc.md b/doc/rpc.md
new file mode 100644
index 00000000..6076b393
--- /dev/null
+++ b/doc/rpc.md
@@ -0,0 +1,13 @@
+# RPC design conventions
+
+* Directory layout
+  * Files under service/remote are grouped by the module the service belongs to; for example, the user profile interface lives at:
+
+    `service/remote/user/profile.proto`
+
+  * The generated profile.pb.go is placed in the same directory, and profile.proto must declare `package user;`
+
+* Error handling
+  * Returned errors must be defined with status.Error(code, desc)
+  * code is of type codes.Code; use the codes grpc already defines whenever possible
+  * codes.DeadlineExceeded, codes.Internal, codes.Unavailable and codes.DataLoss errors trip the circuit breaker automatically
\ No newline at end of file
diff --git 
a/doc/sql-cache.md b/doc/sql-cache.md new file mode 100644 index 00000000..ca8a9cd7 --- /dev/null +++ b/doc/sql-cache.md @@ -0,0 +1,24 @@ +# DB缓存机制 + +## QueryRowIndex + +* 没有查询条件到Primary映射的缓存 + * 通过查询条件到DB去查询行记录,然后 + * **把Primary到行记录的缓存写到redis里** + * **把查询条件到Primary的映射保存到redis里**,*框架的Take方法自动做了* + * 可能的过期顺序 + * 查询条件到Primary的映射缓存未过期 + * Primary到行记录的缓存未过期 + * 直接返回缓存行记录 + * Primary到行记录的缓存已过期 + * 通过Primary到DB获取行记录,并写入缓存 + * 此时存在的问题是,查询条件到Primary的缓存可能已经快要过期了,短时间内的查询又会触发一次数据库查询 + * 要避免这个问题,可以让**上面粗体部分**第一个过期时间略长于第二个,比如5秒 + * 查询条件到Primary的映射缓存已过期,不管Primary到行记录的缓存是否过期 + * 查询条件到Primary的映射会被重新获取,获取过程中会自动写入新的Primary到行记录的缓存,这样两种缓存的过期时间都是刚刚设置 +* 有查询条件到Primary映射的缓存 + * 没有Primary到行记录的缓存 + * 通过Primary到DB查询行记录,并写入缓存 + * 有Primary到行记录的缓存 + * 直接返回缓存结果 + diff --git a/dq/config.go b/dq/config.go new file mode 100644 index 00000000..2eabbd83 --- /dev/null +++ b/dq/config.go @@ -0,0 +1,15 @@ +package dq + +import "zero/core/stores/redis" + +type ( + Beanstalk struct { + Endpoint string + Tube string + } + + DqConf struct { + Beanstalks []Beanstalk + Redis redis.RedisConf + } +) diff --git a/dq/connection.go b/dq/connection.go new file mode 100644 index 00000000..76bae6b4 --- /dev/null +++ b/dq/connection.go @@ -0,0 +1,65 @@ +package dq + +import ( + "sync" + + "github.com/beanstalkd/beanstalk" +) + +type connection struct { + lock sync.RWMutex + endpoint string + tube string + conn *beanstalk.Conn +} + +func newConnection(endpint, tube string) *connection { + return &connection{ + endpoint: endpint, + tube: tube, + } +} + +func (c *connection) Close() error { + c.lock.Lock() + conn := c.conn + c.conn = nil + defer c.lock.Unlock() + + if conn != nil { + return conn.Close() + } + + return nil +} + +func (c *connection) get() (*beanstalk.Conn, error) { + c.lock.RLock() + conn := c.conn + c.lock.RUnlock() + if conn != nil { + return conn, nil + } + + c.lock.Lock() + defer c.lock.Unlock() + + var err error + c.conn, err = beanstalk.Dial("tcp", c.endpoint) + if err != nil { + return nil, err + } + + c.conn.Tube.Name = c.tube + return c.conn, err +} + +func (c *connection) reset() { + c.lock.Lock() + defer c.lock.Unlock() + + if c.conn != nil { + c.conn.Close() + c.conn = nil + } +} diff --git a/dq/consumer.go b/dq/consumer.go new file mode 100644 index 00000000..861a3916 --- /dev/null +++ b/dq/consumer.go @@ -0,0 +1,100 @@ +package dq + +import ( + "strconv" + "time" + + "zero/core/hash" + "zero/core/logx" + "zero/core/service" + "zero/core/stores/redis" +) + +const ( + expiration = 3600 // seconds + guardValue = "1" + tolerance = time.Minute * 30 +) + +var maxCheckBytes = getMaxTimeLen() + +type ( + Consume func(body []byte) + + Consumer interface { + Consume(consume Consume) + } + + consumerCluster struct { + nodes []*consumerNode + red *redis.Redis + } +) + +func NewConsumer(c DqConf) Consumer { + var nodes []*consumerNode + for _, node := range c.Beanstalks { + nodes = append(nodes, newConsumerNode(node.Endpoint, node.Tube)) + } + return &consumerCluster{ + nodes: nodes, + red: c.Redis.NewRedis(), + } +} + +func (c *consumerCluster) Consume(consume Consume) { + guardedConsume := func(body []byte) { + key := hash.Md5Hex(body) + body, ok := c.unwrap(body) + if !ok { + logx.Errorf("discarded: %q", string(body)) + return + } + + ok, err := c.red.SetnxEx(key, guardValue, expiration) + if err != nil { + logx.Error(err) + } else if ok { + consume(body) + } + } + + group := service.NewServiceGroup() + for _, node := range c.nodes { + group.Add(consumeService{ + c: node, + consume: guardedConsume, + }) + } + group.Start() 
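	// A note on the wire format handled above: the producer side frames every job as
	// "<unix-nano>/<payload>" (see producerCluster.wrap), so unwrap strips the
	// timestamp and discards jobs whose scheduled time is more than the 30-minute
	// tolerance in the past, while the SetnxEx guard keyed on the MD5 of the wrapped
	// job ensures a job replicated to several beanstalk nodes is consumed at most
	// once within the one-hour expiration window.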
+} + +func (c *consumerCluster) unwrap(body []byte) ([]byte, bool) { + var pos = -1 + for i := 0; i < maxCheckBytes; i++ { + if body[i] == timeSep { + pos = i + break + } + } + if pos < 0 { + return nil, false + } + + val, err := strconv.ParseInt(string(body[:pos]), 10, 64) + if err != nil { + logx.Error(err) + return nil, false + } + + t := time.Unix(0, val) + if t.Add(tolerance).Before(time.Now()) { + return nil, false + } + + return body[pos+1:], true +} + +func getMaxTimeLen() int { + return len(strconv.FormatInt(time.Now().UnixNano(), 10)) + 2 +} diff --git a/dq/consumernode.go b/dq/consumernode.go new file mode 100644 index 00000000..ec90f738 --- /dev/null +++ b/dq/consumernode.go @@ -0,0 +1,95 @@ +package dq + +import ( + "time" + + "zero/core/logx" + "zero/core/syncx" + + "github.com/beanstalkd/beanstalk" +) + +type ( + consumerNode struct { + conn *connection + tube string + on *syncx.AtomicBool + } + + consumeService struct { + c *consumerNode + consume Consume + } +) + +func newConsumerNode(endpoint, tube string) *consumerNode { + return &consumerNode{ + conn: newConnection(endpoint, tube), + tube: tube, + on: syncx.ForAtomicBool(true), + } +} + +func (c *consumerNode) dispose() { + c.on.Set(false) +} + +func (c *consumerNode) consumeEvents(consume Consume) { + for c.on.True() { + conn, err := c.conn.get() + if err != nil { + logx.Error(err) + time.Sleep(time.Second) + continue + } + + // because getting conn takes at most one second, reserve tasks at most 5 seconds, + // if don't check on/off here, the conn might not be closed due to + // graceful shutdon waits at most 5.5 seconds. + if !c.on.True() { + break + } + + conn.Tube.Name = c.tube + conn.TubeSet.Name[c.tube] = true + id, body, err := conn.Reserve(reserveTimeout) + if err == nil { + conn.Delete(id) + consume(body) + continue + } + + // the error can only be beanstalk.NameError or beanstalk.ConnError + switch cerr := err.(type) { + case beanstalk.ConnError: + switch cerr.Err { + case beanstalk.ErrTimeout: + // timeout error on timeout, just continue the loop + case beanstalk.ErrBadChar, beanstalk.ErrBadFormat, beanstalk.ErrBuried, beanstalk.ErrDeadline, + beanstalk.ErrDraining, beanstalk.ErrEmpty, beanstalk.ErrInternal, beanstalk.ErrJobTooBig, + beanstalk.ErrNoCRLF, beanstalk.ErrNotFound, beanstalk.ErrNotIgnored, beanstalk.ErrTooLong: + // won't reset + logx.Error(err) + default: + // beanstalk.ErrOOM, beanstalk.ErrUnknown and other errors + logx.Error(err) + c.conn.reset() + time.Sleep(time.Second) + } + default: + logx.Error(err) + } + } + + if err := c.conn.Close(); err != nil { + logx.Error(err) + } +} + +func (cs consumeService) Start() { + cs.c.consumeEvents(cs.consume) +} + +func (cs consumeService) Stop() { + cs.c.dispose() +} diff --git a/dq/producer.go b/dq/producer.go new file mode 100644 index 00000000..3b1efd23 --- /dev/null +++ b/dq/producer.go @@ -0,0 +1,156 @@ +package dq + +import ( + "bytes" + "log" + "math/rand" + "strconv" + "strings" + "time" + + "zero/core/errorx" + "zero/core/fx" + "zero/core/logx" +) + +const ( + replicaNodes = 3 + minWrittenNodes = 2 +) + +type ( + Producer interface { + At(body []byte, at time.Time) (string, error) + Close() error + Delay(body []byte, delay time.Duration) (string, error) + Revoke(ids string) error + } + + producerCluster struct { + nodes []Producer + } +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func NewProducer(beanstalks []Beanstalk) Producer { + if len(beanstalks) < minWrittenNodes { + log.Fatalf("nodes must be equal or greater than %d", 
minWrittenNodes) + } + + var nodes []Producer + for _, node := range beanstalks { + nodes = append(nodes, NewProducerNode(node.Endpoint, node.Tube)) + } + return &producerCluster{nodes: nodes} +} + +func (p *producerCluster) At(body []byte, at time.Time) (string, error) { + return p.insert(func(node Producer) (string, error) { + return node.At(p.wrap(body, at), at) + }) +} + +func (p *producerCluster) Close() error { + var be errorx.BatchError + for _, node := range p.nodes { + if err := node.Close(); err != nil { + be.Add(err) + } + } + return be.Err() +} + +func (p *producerCluster) Delay(body []byte, delay time.Duration) (string, error) { + return p.insert(func(node Producer) (string, error) { + return node.Delay(p.wrap(body, time.Now().Add(delay)), delay) + }) +} + +func (p *producerCluster) Revoke(ids string) error { + var be errorx.BatchError + + fx.From(func(source chan<- interface{}) { + for _, node := range p.nodes { + source <- node + } + }).Map(func(item interface{}) interface{} { + node := item.(Producer) + return node.Revoke(ids) + }).ForEach(func(item interface{}) { + if item != nil { + be.Add(item.(error)) + } + }) + + return be.Err() +} + +func (p *producerCluster) cloneNodes() []Producer { + return append([]Producer(nil), p.nodes...) +} + +func (p *producerCluster) getWriteNodes() []Producer { + if len(p.nodes) <= replicaNodes { + return p.nodes + } + + nodes := p.cloneNodes() + rand.Shuffle(len(nodes), func(i, j int) { + nodes[i], nodes[j] = nodes[j], nodes[i] + }) + return nodes[:replicaNodes] +} + +func (p *producerCluster) insert(fn func(node Producer) (string, error)) (string, error) { + type idErr struct { + id string + err error + } + var ret []idErr + fx.From(func(source chan<- interface{}) { + for _, node := range p.getWriteNodes() { + source <- node + } + }).Map(func(item interface{}) interface{} { + node := item.(Producer) + id, err := fn(node) + return idErr{ + id: id, + err: err, + } + }).ForEach(func(item interface{}) { + ret = append(ret, item.(idErr)) + }) + + var ids []string + var be errorx.BatchError + for _, val := range ret { + if val.err != nil { + be.Add(val.err) + } else { + ids = append(ids, val.id) + } + } + + jointId := strings.Join(ids, idSep) + if len(ids) >= minWrittenNodes { + return jointId, nil + } + + if err := p.Revoke(jointId); err != nil { + logx.Error(err) + } + + return "", be.Err() +} + +func (p *producerCluster) wrap(body []byte, at time.Time) []byte { + var builder bytes.Buffer + builder.WriteString(strconv.FormatInt(at.UnixNano(), 10)) + builder.WriteByte(timeSep) + builder.Write(body) + return builder.Bytes() +} diff --git a/dq/producernode.go b/dq/producernode.go new file mode 100644 index 00000000..d098c1d3 --- /dev/null +++ b/dq/producernode.go @@ -0,0 +1,98 @@ +package dq + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/beanstalkd/beanstalk" +) + +var ErrTimeBeforeNow = errors.New("can't schedule task to past time") + +type producerNode struct { + endpoint string + tube string + conn *connection +} + +func NewProducerNode(endpoint, tube string) Producer { + return &producerNode{ + endpoint: endpoint, + tube: tube, + conn: newConnection(endpoint, tube), + } +} + +func (p *producerNode) At(body []byte, at time.Time) (string, error) { + now := time.Now() + if at.Before(now) { + return "", ErrTimeBeforeNow + } + + duration := at.Sub(now) + return p.Delay(body, duration) +} + +func (p *producerNode) Close() error { + return p.conn.Close() +} + +func (p *producerNode) Delay(body []byte, delay 
time.Duration) (string, error) { + conn, err := p.conn.get() + if err != nil { + return "", err + } + + id, err := conn.Put(body, PriNormal, delay, defaultTimeToRun) + if err == nil { + return fmt.Sprintf("%s/%s/%d", p.endpoint, p.tube, id), nil + } + + // the error can only be beanstalk.NameError or beanstalk.ConnError + // just return when the error is beanstalk.NameError, don't reset + switch cerr := err.(type) { + case beanstalk.ConnError: + switch cerr.Err { + case beanstalk.ErrBadChar, beanstalk.ErrBadFormat, beanstalk.ErrBuried, beanstalk.ErrDeadline, + beanstalk.ErrDraining, beanstalk.ErrEmpty, beanstalk.ErrInternal, beanstalk.ErrJobTooBig, + beanstalk.ErrNoCRLF, beanstalk.ErrNotFound, beanstalk.ErrNotIgnored, beanstalk.ErrTooLong: + // won't reset + default: + // beanstalk.ErrOOM, beanstalk.ErrTimeout, beanstalk.ErrUnknown and other errors + p.conn.reset() + } + } + + return "", err +} + +func (p *producerNode) Revoke(jointId string) error { + ids := strings.Split(jointId, idSep) + for _, id := range ids { + fields := strings.Split(id, "/") + if len(fields) < 3 { + continue + } + if fields[0] != p.endpoint || fields[1] != p.tube { + continue + } + + conn, err := p.conn.get() + if err != nil { + return err + } + + n, err := strconv.ParseUint(fields[2], 10, 64) + if err != nil { + return err + } + + return conn.Delete(n) + } + + // if not in this beanstalk, ignore + return nil +} diff --git a/dq/vars.go b/dq/vars.go new file mode 100644 index 00000000..667f26dd --- /dev/null +++ b/dq/vars.go @@ -0,0 +1,15 @@ +package dq + +import "time" + +const ( + PriHigh = 1 + PriNormal = 2 + PriLow = 3 + + defaultTimeToRun = time.Second * 5 + reserveTimeout = time.Second * 5 + + idSep = "," + timeSep = '/' +) diff --git a/example/beanstalk/consumer/consumer.go b/example/beanstalk/consumer/consumer.go new file mode 100644 index 00000000..13321953 --- /dev/null +++ b/example/beanstalk/consumer/consumer.go @@ -0,0 +1,42 @@ +package main + +import ( + "fmt" + + "zero/core/stores/redis" + "zero/dq" +) + +func main() { + consumer := dq.NewConsumer(dq.DqConf{ + Beanstalks: []dq.Beanstalk{ + { + Endpoint: "localhost:11300", + Tube: "tube", + }, + { + Endpoint: "localhost:11301", + Tube: "tube", + }, + { + Endpoint: "localhost:11302", + Tube: "tube", + }, + { + Endpoint: "localhost:11303", + Tube: "tube", + }, + { + Endpoint: "localhost:11304", + Tube: "tube", + }, + }, + Redis: redis.RedisConf{ + Host: "localhost:6379", + Type: redis.NodeType, + }, + }) + consumer.Consume(func(body []byte) { + fmt.Println(string(body)) + }) +} diff --git a/example/beanstalk/producer/producer.go b/example/beanstalk/producer/producer.go new file mode 100644 index 00000000..7d614526 --- /dev/null +++ b/example/beanstalk/producer/producer.go @@ -0,0 +1,40 @@ +package main + +import ( + "fmt" + "strconv" + "time" + + "zero/dq" +) + +func main() { + producer := dq.NewProducer([]dq.Beanstalk{ + { + Endpoint: "localhost:11300", + Tube: "tube", + }, + { + Endpoint: "localhost:11301", + Tube: "tube", + }, + { + Endpoint: "localhost:11302", + Tube: "tube", + }, + { + Endpoint: "localhost:11303", + Tube: "tube", + }, + { + Endpoint: "localhost:11304", + Tube: "tube", + }, + }) + for i := 0; i < 5; i++ { + _, err := producer.At([]byte(strconv.Itoa(i)), time.Now().Add(time.Second*10)) + if err != nil { + fmt.Println(err) + } + } +} diff --git a/example/bloom/bloom.go b/example/bloom/bloom.go new file mode 100644 index 00000000..b76288f0 --- /dev/null +++ b/example/bloom/bloom.go @@ -0,0 +1,18 @@ +package main + +import ( + "fmt" + + 
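+	// bloom is the Redis-backed bloom filter; redis provides its backing store client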
"zero/core/bloom" + "zero/core/stores/redis" +) + +func main() { + store := redis.NewRedis("localhost:6379", "node") + filter := bloom.New(store, "testbloom", 64) + filter.Add([]byte("kevin")) + filter.Add([]byte("wan")) + fmt.Println(filter.Exists([]byte("kevin"))) + fmt.Println(filter.Exists([]byte("wan"))) + fmt.Println(filter.Exists([]byte("nothing"))) +} diff --git a/example/breaker/main.go b/example/breaker/main.go new file mode 100644 index 00000000..178b3da4 --- /dev/null +++ b/example/breaker/main.go @@ -0,0 +1,139 @@ +package main + +import ( + "fmt" + "math/rand" + "os" + "sync" + "sync/atomic" + "time" + + "zero/core/breaker" + "zero/core/lang" + + "gopkg.in/cheggaaa/pb.v1" +) + +const ( + duration = time.Minute * 5 + breakRange = 20 + workRange = 50 + requestInterval = time.Millisecond + // multiply to make it visible in plot + stateFator = float64(time.Second/requestInterval) / 2 +) + +type ( + server struct { + state int32 + } + + metric struct { + calls int64 + } +) + +func (m *metric) addCall() { + atomic.AddInt64(&m.calls, 1) +} + +func (m *metric) reset() int64 { + return atomic.SwapInt64(&m.calls, 0) +} + +func newServer() *server { + return &server{} +} + +func (s *server) serve(m *metric) bool { + m.addCall() + return atomic.LoadInt32(&s.state) == 1 +} + +func (s *server) start() { + go func() { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + var state int32 + for { + var v int32 + if state == 0 { + v = r.Int31n(breakRange) + } else { + v = r.Int31n(workRange) + } + time.Sleep(time.Second * time.Duration(v+1)) + state ^= 1 + atomic.StoreInt32(&s.state, state) + } + }() +} + +func runBreaker(s *server, br breaker.Breaker, duration time.Duration, m *metric) { + ticker := time.NewTicker(requestInterval) + defer ticker.Stop() + done := make(chan lang.PlaceholderType) + + go func() { + time.Sleep(duration) + close(done) + }() + + for { + select { + case <-ticker.C: + _ = br.Do(func() error { + if s.serve(m) { + return nil + } else { + return breaker.ErrServiceUnavailable + } + }) + case <-done: + return + } + } +} + +func main() { + srv := newServer() + srv.start() + + gb := breaker.NewBreaker() + fp, err := os.Create("result.csv") + lang.Must(err) + defer fp.Close() + fmt.Fprintln(fp, "seconds,state,googleCalls,netflixCalls") + + var gm, nm metric + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + var seconds int + for range ticker.C { + seconds++ + gcalls := gm.reset() + ncalls := nm.reset() + fmt.Fprintf(fp, "%d,%.2f,%d,%d\n", + seconds, float64(atomic.LoadInt32(&srv.state))*stateFator, gcalls, ncalls) + } + }() + + var waitGroup sync.WaitGroup + waitGroup.Add(1) + go func() { + runBreaker(srv, gb, duration, &gm) + waitGroup.Done() + }() + + go func() { + bar := pb.New(int(duration / time.Second)).Start() + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for range ticker.C { + bar.Increment() + } + bar.Finish() + }() + + waitGroup.Wait() +} diff --git a/example/breaker/plot.py b/example/breaker/plot.py new file mode 100644 index 00000000..62c9911a --- /dev/null +++ b/example/breaker/plot.py @@ -0,0 +1,15 @@ +import click +import pandas as pd +import matplotlib.pyplot as plt + + +@click.command() +@click.option("--csv", default="result.csv") +def main(csv): + df = pd.read_csv(csv, index_col="seconds") + df.plot() + plt.show() + + +if __name__ == "__main__": + main() diff --git a/example/clickhouse/ch.go b/example/clickhouse/ch.go new file mode 100644 index 00000000..fac3490d --- /dev/null +++ b/example/clickhouse/ch.go @@ 
-0,0 +1,65 @@ +package main + +import ( + "log" + "time" + + "zero/core/stores/clickhouse" + "zero/core/stores/sqlx" +) + +func main() { + conn := clickhouse.New("tcp://127.0.0.1:9000") + _, err := conn.Exec(` + CREATE TABLE IF NOT EXISTS example ( + country_code FixedString(2), + os_id UInt8, + browser_id UInt8, + categories Array(Int16), + action_day Date, + action_time DateTime + ) engine=Memory + `) + if err != nil { + log.Fatal(err) + } + + conn.Transact(func(session sqlx.Session) error { + stmt, err := session.Prepare("INSERT INTO example (country_code, os_id, browser_id, categories, action_day, action_time) VALUES (?, ?, ?, ?, ?, ?)") + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + + for i := 0; i < 10; i++ { + _, err := stmt.Exec("RU", 10+i, 100+i, []int16{1, 2, 3}, time.Now(), time.Now()) + if err != nil { + log.Fatal(err) + } + } + + return nil + }) + + var items []struct { + CountryCode string `db:"country_code"` + OsID uint8 `db:"os_id"` + BrowserID uint8 `db:"browser_id"` + Categories []int16 `db:"categories"` + ActionTime time.Time `db:"action_time"` + } + + err = conn.QueryRows(&items, "SELECT country_code, os_id, browser_id, categories, action_time FROM example") + if err != nil { + log.Fatal(err) + } + + for _, item := range items { + log.Printf("country: %s, os: %d, browser: %d, categories: %v, action_time: %s", + item.CountryCode, item.OsID, item.BrowserID, item.Categories, item.ActionTime) + } + + if _, err := conn.Exec("DROP TABLE example"); err != nil { + log.Fatal(err) + } +} diff --git a/example/config/loadfromyaml/date.yml b/example/config/loadfromyaml/date.yml new file mode 100644 index 00000000..a28fb2fd --- /dev/null +++ b/example/config/loadfromyaml/date.yml @@ -0,0 +1,2 @@ +#date: "2019-06-20 00:00:00" +date: "2019-06-19T16:00:00Z" \ No newline at end of file diff --git a/example/config/loadfromyaml/main.go b/example/config/loadfromyaml/main.go new file mode 100644 index 00000000..c61b43e9 --- /dev/null +++ b/example/config/loadfromyaml/main.go @@ -0,0 +1,21 @@ +package main + +import ( + "time" + + "zero/core/conf" + "zero/core/logx" +) + +type TimeHolder struct { + Date time.Time `json:"date"` +} + +func main() { + th := &TimeHolder{} + err := conf.LoadConfig("./date.yml", th) + if err != nil { + logx.Error(err) + } + logx.Infof("%+v", th) +} diff --git a/example/etcd/demo/Dockerfile b/example/etcd/demo/Dockerfile new file mode 100644 index 00000000..d4675b85 --- /dev/null +++ b/example/etcd/demo/Dockerfile @@ -0,0 +1,27 @@ +FROM golang:alpine AS builder + +LABEL stage=gobuilder + +ENV CGO_ENABLED 0 +ENV GOOS linux +ENV GOPROXY https://goproxy.cn,direct + +RUN apk add upx + +WORKDIR $GOPATH/src/zero +COPY . . +RUN go build -ldflags="-s -w" -o /app/etcdmon example/etcd/demo/etcdmon.go +RUN upx -q /app/etcdmon + + +FROM alpine + +RUN apk update --no-cache +RUN apk add --no-cache ca-certificates +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai + +WORKDIR /app +COPY --from=builder /app/etcdmon /app/etcdmon + +CMD ["./etcdmon"] diff --git a/example/etcd/demo/Makefile b/example/etcd/demo/Makefile new file mode 100644 index 00000000..06b86926 --- /dev/null +++ b/example/etcd/demo/Makefile @@ -0,0 +1,13 @@ +version := v$(shell /bin/date "+%y%m%d%H%M%S") + +build: + #docker pull alpine + #docker pull golang:alpine + cd $(GOPATH)/src/zero && docker build -t registry.cn-hangzhou.aliyuncs.com/xapp/etcdmon:$(version) . 
-f example/etcd/demo/Dockerfile + #docker image prune --filter label=stage=gobuilder -f + +push: build + docker push registry.cn-hangzhou.aliyuncs.com/xapp/etcdmon:$(version) + +deploy: push + kubectl -n xx-xiaoheiban set image deployment/etcdmon-deployment etcdmon=registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcdmon:$(version) diff --git a/example/etcd/demo/etcdmon.go b/example/etcd/demo/etcdmon.go new file mode 100644 index 00000000..b623ab84 --- /dev/null +++ b/example/etcd/demo/etcdmon.go @@ -0,0 +1,169 @@ +package main + +import ( + "context" + "fmt" + "strings" + "time" + + "zero/core/discov" + "zero/core/logx" + "zero/core/proc" + "zero/core/syncx" + + "go.etcd.io/etcd/clientv3" +) + +var ( + endpoints []string + keys = []string{ + "user.rpc", + "classroom.rpc", + } + vals = make(map[string]map[string]string) + barrier syncx.Barrier +) + +type listener struct { + key string +} + +func init() { + cluster := proc.Env("ETCD_CLUSTER") + if len(cluster) > 0 { + endpoints = strings.Split(cluster, ",") + } else { + endpoints = []string{"localhost:2379"} + } +} + +func (l listener) OnAdd(key, val string) { + fmt.Printf("add, key: %s, val: %s\n", key, val) + barrier.Guard(func() { + if m, ok := vals[l.key]; ok { + m[key] = val + } else { + vals[l.key] = map[string]string{key: val} + } + }) +} + +func (l listener) OnDelete(key string) { + fmt.Printf("del, key: %s\n", key) + barrier.Guard(func() { + if m, ok := vals[l.key]; ok { + delete(m, key) + } + }) +} + +func load(cli *clientv3.Client, key string) (map[string]string, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + resp, err := cli.Get(ctx, key, clientv3.WithPrefix()) + cancel() + if err != nil { + return nil, err + } + + ret := make(map[string]string) + for _, ev := range resp.Kvs { + ret[string(ev.Key)] = string(ev.Value) + } + + return ret, nil +} + +func loadAll(cli *clientv3.Client) (map[string]map[string]string, error) { + ret := make(map[string]map[string]string) + for _, key := range keys { + m, err := load(cli, key) + if err != nil { + return nil, err + } + + ret[key] = m + } + + return ret, nil +} + +func compare(a, b map[string]map[string]string) bool { + if len(a) != len(b) { + return false + } + + for k := range a { + av := a[k] + bv := b[k] + if len(av) != len(bv) { + return false + } + + for kk := range av { + if av[kk] != bv[kk] { + return false + } + } + } + + return true +} + +func serializeMap(m map[string]map[string]string, prefix string) string { + var builder strings.Builder + for k, v := range m { + fmt.Fprintf(&builder, "%s%s:\n", prefix, k) + for kk, vv := range v { + fmt.Fprintf(&builder, "%s\t%s: %s\n", prefix, kk, vv) + } + } + return builder.String() +} + +func main() { + registry := discov.NewFacade(endpoints) + for _, key := range keys { + registry.Monitor(key, listener{key: key}) + } + + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + expect, err := loadAll(registry.Client().(*clientv3.Client)) + if err != nil { + fmt.Println("[ETCD-test] can't load current keys") + continue + } + + check := func() bool { + var match bool + barrier.Guard(func() { + match = compare(expect, vals) + }) + if match { + logx.Info("match") + } + return match + } + if check() { + continue + } + + time.AfterFunc(time.Second*5, func() { + if check() { + return + } + + var builder strings.Builder + builder.WriteString(fmt.Sprintf("expect:\n%s\n", serializeMap(expect, "\t"))) + barrier.Guard(func() { + builder.WriteString(fmt.Sprintf("actual:\n%s\n", 
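+					// vals is serialized under the barrier, so the printed snapshot is
+					// consistent with the OnAdd/OnDelete watcher callbacks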
serializeMap(vals, "\t"))) + }) + fmt.Println(builder.String()) + }) + } + } +} diff --git a/example/etcd/demo/etcdmon.yaml b/example/etcd/demo/etcdmon.yaml new file mode 100644 index 00000000..0aef9940 --- /dev/null +++ b/example/etcd/demo/etcdmon.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Pod +metadata: + name: etcdmon + namespace: discov +spec: + containers: + - name: etcdmon + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcdmon:v200620093045 + imagePullPolicy: Always + env: + - name: ETCD_CLUSTER + value: etcd.discov:2379 + imagePullSecrets: + - name: aliyun diff --git a/example/etcd/discov-namespace.yaml b/example/etcd/discov-namespace.yaml new file mode 100644 index 00000000..e428b332 --- /dev/null +++ b/example/etcd/discov-namespace.yaml @@ -0,0 +1,4 @@ +apiVersion: v1 +kind: Namespace +metadata: + name: discovery diff --git a/example/etcd/etcd.yaml b/example/etcd/etcd.yaml new file mode 100644 index 00000000..4277e6bc --- /dev/null +++ b/example/etcd/etcd.yaml @@ -0,0 +1,378 @@ +apiVersion: v1 +kind: Service +metadata: + name: discov + namespace: discovery +spec: + ports: + - name: discov-port + port: 2379 + protocol: TCP + targetPort: 2379 + selector: + app: discov + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: discov + discov_node: discov0 + name: discov0 + namespace: discovery +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - discov0 + - --initial-advertise-peer-urls + - http://discov0:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://discov0:2379 + - --initial-cluster + - discov0=http://discov0:2380,discov1=http://discov1:2380,discov2=http://discov2:2380,discov3=http://discov3:2380,discov4=http://discov4:2380 + - --initial-cluster-state + - new + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest + name: discov0 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + imagePullSecrets: + - name: aliyun + affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: app + operator: In + values: + - discov + topologyKey: "kubernetes.io/hostname" + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + discov_node: discov0 + name: discov0 + namespace: discovery +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + discov_node: discov0 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: discov + discov_node: discov1 + name: discov1 + namespace: discovery +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - discov1 + - --initial-advertise-peer-urls + - http://discov1:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://discov1:2379 + - --initial-cluster + - discov0=http://discov0:2380,discov1=http://discov1:2380,discov2=http://discov2:2380,discov3=http://discov3:2380,discov4=http://discov4:2380 + - --initial-cluster-state + - new + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest + name: discov1 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + imagePullSecrets: + - name: aliyun + affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: 
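+          # anti-affinity: require each discov etcd member to be scheduled on a
+          # different node (keyed by kubernetes.io/hostname)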
+ matchExpressions: + - key: app + operator: In + values: + - discov + topologyKey: "kubernetes.io/hostname" + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + discov_node: discov1 + name: discov1 + namespace: discovery +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + discov_node: discov1 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: discov + discov_node: discov2 + name: discov2 + namespace: discovery +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - discov2 + - --initial-advertise-peer-urls + - http://discov2:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://discov2:2379 + - --initial-cluster + - discov0=http://discov0:2380,discov1=http://discov1:2380,discov2=http://discov2:2380,discov3=http://discov3:2380,discov4=http://discov4:2380 + - --initial-cluster-state + - new + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest + name: discov2 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + imagePullSecrets: + - name: aliyun + affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: app + operator: In + values: + - discov + topologyKey: "kubernetes.io/hostname" + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + discov_node: discov2 + name: discov2 + namespace: discovery +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + discov_node: discov2 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: discov + discov_node: discov3 + name: discov3 + namespace: discovery +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - discov3 + - --initial-advertise-peer-urls + - http://discov3:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://discov3:2379 + - --initial-cluster + - discov0=http://discov0:2380,discov1=http://discov1:2380,discov2=http://discov2:2380,discov3=http://discov3:2380,discov4=http://discov4:2380 + - --initial-cluster-state + - new + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest + name: discov3 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + imagePullSecrets: + - name: aliyun + affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: app + operator: In + values: + - discov + topologyKey: "kubernetes.io/hostname" + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + discov_node: discov3 + name: discov3 + namespace: discovery +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + discov_node: discov3 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: discov + discov_node: discov4 + name: discov4 + namespace: discovery +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - discov4 + - --initial-advertise-peer-urls + - http://discov4:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - 
--listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://discov4:2379 + - --initial-cluster + - discov0=http://discov0:2380,discov1=http://discov1:2380,discov2=http://discov2:2380,discov3=http://discov3:2380,discov4=http://discov4:2380 + - --initial-cluster-state + - new + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/etcd:latest + name: discov4 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + imagePullSecrets: + - name: aliyun + affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: app + operator: In + values: + - discov + topologyKey: "kubernetes.io/hostname" + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + discov_node: discov4 + name: discov4 + namespace: discovery +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + discov_node: discov4 diff --git a/example/etcd/pub/Dockerfile b/example/etcd/pub/Dockerfile new file mode 100644 index 00000000..f290fe17 --- /dev/null +++ b/example/etcd/pub/Dockerfile @@ -0,0 +1,22 @@ +FROM golang:1.13-alpine AS builder + +LABEL stage=gobuilder + +ENV CGO_ENABLED 0 +ENV GOOS linux +ENV GOPROXY https://goproxy.cn,direct + +WORKDIR $GOPATH/src/zero +COPY . . +RUN go build -ldflags="-s -w" -o /app/pub example/etcd/pub/pub.go + + +FROM alpine + +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai + +WORKDIR /app +COPY --from=builder /app/pub /app/pub + +CMD ["./pub"] diff --git a/example/etcd/pub/Makefile b/example/etcd/pub/Makefile new file mode 100644 index 00000000..7bfa4f87 --- /dev/null +++ b/example/etcd/pub/Makefile @@ -0,0 +1,11 @@ +version := v$(shell /bin/date "+%y%m%d%H%M%S") + +build: + cd $(GOPATH)/src/zero && docker build -t registry.cn-hangzhou.aliyuncs.com/xapp/pub:$(version) . 
-f example/etcd/pub/Dockerfile + docker image prune --filter label=stage=gobuilder -f + +push: build + docker push registry.cn-hangzhou.aliyuncs.com/xapp/pub:$(version) + +deploy: push + kubectl -n adhoc set image deployment/pub-deployment pub=registry-vpc.cn-hangzhou.aliyuncs.com/xapp/pub:$(version) diff --git a/example/etcd/pub/pub.go b/example/etcd/pub/pub.go new file mode 100644 index 00000000..870e4b9e --- /dev/null +++ b/example/etcd/pub/pub.go @@ -0,0 +1,27 @@ +package main + +import ( + "flag" + "fmt" + "log" + "time" + + "zero/core/discov" +) + +var value = flag.String("v", "value", "the value") + +func main() { + flag.Parse() + + client := discov.NewPublisher([]string{"etcd.discovery:2379"}, "028F2C35852D", *value) + if err := client.KeepAlive(); err != nil { + log.Fatal(err) + } + defer client.Stop() + + for { + time.Sleep(time.Second) + fmt.Println(*value) + } +} diff --git a/example/etcd/pub/pub.yaml b/example/etcd/pub/pub.yaml new file mode 100644 index 00000000..02de5c48 --- /dev/null +++ b/example/etcd/pub/pub.yaml @@ -0,0 +1,26 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: pub-deployment + namespace: adhoc + labels: + app: pub +spec: + replicas: 1 + selector: + matchLabels: + app: pub + template: + metadata: + labels: + app: pub + spec: + containers: + - name: pub + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/pub:v200213172101 + command: + - /app/pub + - -v + - ccc + imagePullSecrets: + - name: aliyun diff --git a/example/etcd/sub/Dockerfile b/example/etcd/sub/Dockerfile new file mode 100644 index 00000000..9e81330f --- /dev/null +++ b/example/etcd/sub/Dockerfile @@ -0,0 +1,22 @@ +FROM golang:1.13-alpine AS builder + +LABEL stage=gobuilder + +ENV CGO_ENABLED 0 +ENV GOOS linux +ENV GOPROXY https://goproxy.cn,direct + +WORKDIR $GOPATH/src/zero +COPY . . +RUN go build -ldflags="-s -w" -o /app/sub example/etcd/sub/sub.go + + +FROM alpine + +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai + +WORKDIR /app +COPY --from=builder /app/sub /app/sub + +CMD ["./sub"] diff --git a/example/etcd/sub/Makefile b/example/etcd/sub/Makefile new file mode 100644 index 00000000..a16f5015 --- /dev/null +++ b/example/etcd/sub/Makefile @@ -0,0 +1,11 @@ +version := v$(shell /bin/date "+%y%m%d%H%M%S") + +build: + cd $(GOPATH)/src/zero && docker build -t registry.cn-hangzhou.aliyuncs.com/xapp/sub:$(version) . 
-f example/etcd/sub/Dockerfile + docker image prune --filter label=stage=gobuilder -f + +push: build + docker push registry.cn-hangzhou.aliyuncs.com/xapp/sub:$(version) + +deploy: push + kubectl -n adhoc set image deployment/sub-deployment sub=registry-vpc.cn-hangzhou.aliyuncs.com/xapp/sub:$(version) diff --git a/example/etcd/sub/sub.go b/example/etcd/sub/sub.go new file mode 100644 index 00000000..13cbcd72 --- /dev/null +++ b/example/etcd/sub/sub.go @@ -0,0 +1,21 @@ +package main + +import ( + "fmt" + "time" + + "zero/core/discov" +) + +func main() { + sub := discov.NewSubscriber([]string{"etcd.discovery:2379"}, "028F2C35852D", discov.Exclusive()) + ticker := time.NewTicker(time.Second * 3) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + fmt.Println("values:", sub.Values()) + } + } +} diff --git a/example/etcd/sub/sub.yaml b/example/etcd/sub/sub.yaml new file mode 100644 index 00000000..ae5af089 --- /dev/null +++ b/example/etcd/sub/sub.yaml @@ -0,0 +1,16 @@ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: sub + name: sub + namespace: adhoc +spec: + containers: + - command: + - /app/sub + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/sub:v200213220509 + name: sub + imagePullSecrets: + - name: aliyun + restartPolicy: Always diff --git a/example/filex/pread.go b/example/filex/pread.go new file mode 100644 index 00000000..1377a8d0 --- /dev/null +++ b/example/filex/pread.go @@ -0,0 +1,140 @@ +package main + +import ( + "bufio" + "errors" + "flag" + "fmt" + "log" + "os" + "runtime" + "strconv" + "strings" + "time" + + "zero/core/filex" + "zero/core/fx" + "zero/core/logx" + + "gopkg.in/cheggaaa/pb.v1" +) + +var ( + file = flag.String("f", "", "the input file") + concurrent = flag.Int("c", runtime.NumCPU(), "concurrent goroutines") + wordVecDic TXDictionary +) + +type ( + Vector []float64 + + TXDictionary struct { + EmbeddingCount int64 + Dim int64 + Dict map[string]Vector + } + + pair struct { + key string + vec Vector + } +) + +func FastLoad(filename string) error { + if filename == "" { + return errors.New("no available dictionary") + } + + now := time.Now() + defer func() { + logx.Infof("article2vec init dictionary end used %v", time.Since(now)) + }() + + dicFile, err := os.Open(filename) + if err != nil { + return err + } + defer dicFile.Close() + + header, err := filex.FirstLine(filename) + if err != nil { + return err + } + + total := strings.Split(header, " ") + wordVecDic.EmbeddingCount, err = strconv.ParseInt(total[0], 10, 64) + if err != nil { + return err + } + + wordVecDic.Dim, err = strconv.ParseInt(total[1], 10, 64) + if err != nil { + return err + } + + wordVecDic.Dict = make(map[string]Vector, wordVecDic.EmbeddingCount) + + ranges, err := filex.SplitLineChunks(filename, *concurrent) + if err != nil { + return err + } + + info, err := os.Stat(filename) + if err != nil { + return err + } + + bar := pb.New64(info.Size()).SetUnits(pb.U_BYTES).Start() + fx.From(func(source chan<- interface{}) { + for _, each := range ranges { + source <- each + } + }).Walk(func(item interface{}, pipe chan<- interface{}) { + offsetRange := item.(filex.OffsetRange) + scanner := bufio.NewScanner(filex.NewRangeReader(dicFile, offsetRange.Start, offsetRange.Stop)) + scanner.Buffer([]byte{}, 1<<20) + reader := filex.NewProgressScanner(scanner, bar) + if offsetRange.Start == 0 { + // skip header + reader.Scan() + } + for reader.Scan() { + text := reader.Text() + elements := strings.Split(text, " ") + vec := make(Vector, wordVecDic.Dim) + for i, ele := range elements { + if i == 0 { 
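+				// elements[0] is the word itself, not a vector component; skip it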
+ continue + } + + v, err := strconv.ParseFloat(ele, 64) + if err != nil { + return + } + + vec[i-1] = v + } + pipe <- pair{ + key: elements[0], + vec: vec, + } + } + }).ForEach(func(item interface{}) { + p := item.(pair) + wordVecDic.Dict[p.key] = p.vec + }) + + return nil +} + +func main() { + flag.Parse() + + start := time.Now() + if err := FastLoad(*file); err != nil { + log.Fatal(err) + } + + fmt.Println(len(wordVecDic.Dict)) + fmt.Println(time.Since(start)) +} diff --git a/example/fx/fx_test.go b/example/fx/fx_test.go new file mode 100644 index 00000000..62a115db --- /dev/null +++ b/example/fx/fx_test.go @@ -0,0 +1,25 @@ +package main + +import ( + "testing" + + "zero/core/fx" +) + +func BenchmarkFx(b *testing.B) { + type Mixed struct { + Name string + Age int + Gender int + } + for i := 0; i < b.N; i++ { + var mx Mixed + fx.Parallel(func() { + mx.Name = "hello" + }, func() { + mx.Age = 20 + }, func() { + mx.Gender = 1 + }) + } +} diff --git a/example/fx/square.go b/example/fx/square.go new file mode 100644 index 00000000..640371e8 --- /dev/null +++ b/example/fx/square.go @@ -0,0 +1,36 @@ +package main + +import ( + "fmt" + + "zero/core/fx" +) + +func main() { + result, err := fx.From(func(source chan<- interface{}) { + for i := 0; i < 10; i++ { + source <- i + source <- i + } + }).Map(func(item interface{}) interface{} { + i := item.(int) + return i * i + }).Filter(func(item interface{}) bool { + i := item.(int) + return i%2 == 0 + }).Distinct(func(item interface{}) interface{} { + return item + }).Reduce(func(pipe <-chan interface{}) (interface{}, error) { + var result int + for item := range pipe { + i := item.(int) + result += i + } + return result, nil + }) + if err != nil { + fmt.Println(err) + } else { + fmt.Println(result) + } +} diff --git a/example/graceful/dns/api/Dockerfile b/example/graceful/dns/api/Dockerfile new file mode 100644 index 00000000..02f82ca1 --- /dev/null +++ b/example/graceful/dns/api/Dockerfile @@ -0,0 +1,26 @@ +FROM golang:1.13 AS builder + +ENV CGO_ENABLED 0 +ENV GOOS linux + +RUN apt-get update +RUN apt-get install -y apt-utils upx + +WORKDIR $GOPATH/src/zero +COPY . . +RUN go build -ldflags="-s -w" -o /app/graceful example/graceful/dns/api/graceful.go +RUN upx /app/graceful + + +FROM alpine + +RUN apk update --no-cache +RUN apk add --no-cache ca-certificates +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai + +WORKDIR /app +COPY --from=builder /app/graceful /app/graceful +COPY example/graceful/dns/api/etc/graceful-api.json /app/etc/config.json + +CMD ["./graceful", "-f", "etc/config.json"] diff --git a/example/graceful/dns/api/Makefile b/example/graceful/dns/api/Makefile new file mode 100644 index 00000000..1e36f6e0 --- /dev/null +++ b/example/graceful/dns/api/Makefile @@ -0,0 +1,11 @@ +version := v$(shell /bin/date "+%y%m%d%H%M%S") + +build: + docker pull alpine + cd $(GOPATH)/src/zero && docker build -t registry.cn-hangzhou.aliyuncs.com/xapp/graceful:$(version) . 
-f example/graceful/dns/api/Dockerfile + +push: build + docker push registry.cn-hangzhou.aliyuncs.com/xapp/graceful:$(version) + +deploy: push + kubectl -n kevin set image deployment/graceful-deployment graceful=registry-vpc.cn-hangzhou.aliyuncs.com/xapp/graceful:$(version) diff --git a/example/graceful/dns/api/config/config.go b/example/graceful/dns/api/config/config.go new file mode 100644 index 00000000..6ae5af47 --- /dev/null +++ b/example/graceful/dns/api/config/config.go @@ -0,0 +1,11 @@ +package config + +import ( + "zero/ngin" + "zero/rpcx" +) + +type Config struct { + ngin.NgConf + Rpc rpcx.RpcClientConf +} diff --git a/example/graceful/dns/api/etc/graceful-api.json b/example/graceful/dns/api/etc/graceful-api.json new file mode 100644 index 00000000..ea8d1437 --- /dev/null +++ b/example/graceful/dns/api/etc/graceful-api.json @@ -0,0 +1,9 @@ +{ + "Name": "graceful-api", + "Host": "0.0.0.0", + "Port": 8888, + "MaxConns": 1000000, + "Rpc": { + "Server": "dns:///gracefulrpc:3456" + } +} \ No newline at end of file diff --git a/example/graceful/dns/api/graceful.api b/example/graceful/dns/api/graceful.api new file mode 100644 index 00000000..be4323d3 --- /dev/null +++ b/example/graceful/dns/api/graceful.api @@ -0,0 +1,11 @@ +type Response { + Host string `json:"host"` + Time int64 `json:"time"` +} + +service graceful-api { + @server( + handler: GracefulHandler + ) + get /api/graceful() returns(Response) +} \ No newline at end of file diff --git a/example/graceful/dns/api/graceful.go b/example/graceful/dns/api/graceful.go new file mode 100644 index 00000000..a9c0efea --- /dev/null +++ b/example/graceful/dns/api/graceful.go @@ -0,0 +1,32 @@ +package main + +import ( + "flag" + + "zero/core/conf" + "zero/example/graceful/dns/api/config" + "zero/example/graceful/dns/api/handler" + "zero/example/graceful/dns/api/svc" + "zero/ngin" + "zero/rpcx" +) + +var configFile = flag.String("f", "etc/graceful-api.json", "the config file") + +func main() { + flag.Parse() + + var c config.Config + conf.MustLoad(*configFile, &c) + + client := rpcx.MustNewClient(c.Rpc) + ctx := &svc.ServiceContext{ + Client: client, + } + + engine := ngin.MustNewEngine(c.NgConf) + defer engine.Stop() + + handler.RegisterHandlers(engine, ctx) + engine.Start() +} diff --git a/example/graceful/dns/api/graceful.yaml b/example/graceful/dns/api/graceful.yaml new file mode 100644 index 00000000..15c7e795 --- /dev/null +++ b/example/graceful/dns/api/graceful.yaml @@ -0,0 +1,42 @@ +apiVersion: v1 +kind: Service +metadata: + name: graceful + namespace: kevin +spec: + selector: + app: graceful + type: ClusterIP + ports: + - name: graceful-port + port: 3333 + targetPort: 8888 + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: graceful-deployment + namespace: kevin + labels: + app: graceful +spec: + replicas: 3 + selector: + matchLabels: + app: graceful + template: + metadata: + labels: + app: graceful + spec: + containers: + - name: graceful + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/graceful:v191022133857 + imagePullPolicy: Always + ports: + - containerPort: 8888 + imagePullSecrets: + - name: aliyun + diff --git a/example/graceful/dns/api/handler/gracefulhandler.go b/example/graceful/dns/api/handler/gracefulhandler.go new file mode 100644 index 00000000..43268d68 --- /dev/null +++ b/example/graceful/dns/api/handler/gracefulhandler.go @@ -0,0 +1,49 @@ +package handler + +import ( + "context" + "fmt" + "net/http" + "os" + "time" + + "zero/core/executors" + "zero/core/httpx" + "zero/core/logx" + 
"zero/example/graceful/dns/api/svc" + "zero/example/graceful/dns/api/types" + "zero/example/graceful/dns/rpc/graceful" +) + +func gracefulHandler(ctx *svc.ServiceContext) http.HandlerFunc { + logger := executors.NewLessExecutor(time.Second) + return func(w http.ResponseWriter, r *http.Request) { + var resp types.Response + + conn, ok := ctx.Client.Next() + if !ok { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + host, err := os.Hostname() + if err != nil { + http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) + return + } + + client := graceful.NewGraceServiceClient(conn) + rp, err := client.Grace(context.Background(), &graceful.Request{From: host}) + if err != nil { + logx.Error(err) + http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) + return + } + + resp.Host = rp.Host + logger.DoOrDiscard(func() { + fmt.Printf("%s from host: %s\n", time.Now().Format("15:04:05"), rp.Host) + }) + httpx.OkJson(w, resp) + } +} diff --git a/example/graceful/dns/api/handler/routes.go b/example/graceful/dns/api/handler/routes.go new file mode 100644 index 00000000..6d2dca1c --- /dev/null +++ b/example/graceful/dns/api/handler/routes.go @@ -0,0 +1,19 @@ +// DO NOT EDIT, generated by goctl +package handler + +import ( + "net/http" + + "zero/example/graceful/dns/api/svc" + "zero/ngin" +) + +func RegisterHandlers(engine *ngin.Engine, ctx *svc.ServiceContext) { + engine.AddRoutes([]ngin.Route{ + { + Method: http.MethodGet, + Path: "/api/graceful", + Handler: gracefulHandler(ctx), + }, + }) +} diff --git a/example/graceful/dns/api/svc/servicecontext.go b/example/graceful/dns/api/svc/servicecontext.go new file mode 100644 index 00000000..6efedfd0 --- /dev/null +++ b/example/graceful/dns/api/svc/servicecontext.go @@ -0,0 +1,7 @@ +package svc + +import "zero/rpcx" + +type ServiceContext struct { + Client *rpcx.RpcClient +} diff --git a/example/graceful/dns/api/types/types.go b/example/graceful/dns/api/types/types.go new file mode 100644 index 00000000..5143ec78 --- /dev/null +++ b/example/graceful/dns/api/types/types.go @@ -0,0 +1,7 @@ +// DO NOT EDIT, generated by goctl +package types + +type Response struct { + Host string `json:"host"` + Time int64 `json:"time"` +} diff --git a/example/graceful/dns/rpc/Dockerfile b/example/graceful/dns/rpc/Dockerfile new file mode 100644 index 00000000..9ad2e511 --- /dev/null +++ b/example/graceful/dns/rpc/Dockerfile @@ -0,0 +1,22 @@ +FROM golang:1.13 AS builder + +ENV CGO_ENABLED 0 +ENV GOOS linux + +WORKDIR $GOPATH/src/zero +COPY . . +RUN go build -ldflags="-s -w" -o /app/gracefulrpc example/graceful/dns/rpc/gracefulrpc.go + + +FROM alpine + +RUN apk update --no-cache +RUN apk add --no-cache ca-certificates +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai + +WORKDIR /app +COPY --from=builder /app/gracefulrpc /app/gracefulrpc +COPY example/graceful/dns/rpc/etc/config.json /app/etc/config.json + +CMD ["./gracefulrpc", "-f", "etc/config.json"] diff --git a/example/graceful/dns/rpc/Makefile b/example/graceful/dns/rpc/Makefile new file mode 100644 index 00000000..dd3dd344 --- /dev/null +++ b/example/graceful/dns/rpc/Makefile @@ -0,0 +1,11 @@ +version := v$(shell /bin/date "+%y%m%d%H%M%S") + +build: + docker pull alpine + cd $(GOPATH)/src/zero && docker build -t registry.cn-hangzhou.aliyuncs.com/xapp/gracefulrpc:$(version) . 
-f example/graceful/dns/rpc/Dockerfile + +push: build + docker push registry.cn-hangzhou.aliyuncs.com/xapp/gracefulrpc:$(version) + +deploy: push + kubectl -n kevin set image deployment/gracefulrpc-deployment gracefulrpc=registry-vpc.cn-hangzhou.aliyuncs.com/xapp/gracefulrpc:$(version) diff --git a/example/graceful/dns/rpc/etc/config.json b/example/graceful/dns/rpc/etc/config.json new file mode 100644 index 00000000..fb520d97 --- /dev/null +++ b/example/graceful/dns/rpc/etc/config.json @@ -0,0 +1,4 @@ +{ + "Name": "rpc.grace", + "ListenOn": "0.0.0.0:3456" +} diff --git a/example/graceful/dns/rpc/graceful.proto b/example/graceful/dns/rpc/graceful.proto new file mode 100644 index 00000000..1c6e465f --- /dev/null +++ b/example/graceful/dns/rpc/graceful.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package graceful; + +message Request { + string from = 1; +} + +message Response { + string host = 2; +} + +service GraceService { + rpc grace(Request) returns(Response); +} \ No newline at end of file diff --git a/example/graceful/dns/rpc/graceful/graceful.pb.go b/example/graceful/dns/rpc/graceful/graceful.pb.go new file mode 100644 index 00000000..611ecb0b --- /dev/null +++ b/example/graceful/dns/rpc/graceful/graceful.pb.go @@ -0,0 +1,159 @@ +// Code generated by protoc-gen-go. +// source: graceful.proto +// DO NOT EDIT! + +/* +Package graceful is a generated protocol buffer package. + +It is generated from these files: + graceful.proto + +It has these top-level messages: + Request + Response +*/ +package graceful + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type Request struct { + From string `protobuf:"bytes,1,opt,name=from" json:"from,omitempty"` +} + +func (m *Request) Reset() { *m = Request{} } +func (m *Request) String() string { return proto.CompactTextString(m) } +func (*Request) ProtoMessage() {} +func (*Request) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *Request) GetFrom() string { + if m != nil { + return m.From + } + return "" +} + +type Response struct { + Host string `protobuf:"bytes,2,opt,name=host" json:"host,omitempty"` +} + +func (m *Response) Reset() { *m = Response{} } +func (m *Response) String() string { return proto.CompactTextString(m) } +func (*Response) ProtoMessage() {} +func (*Response) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *Response) GetHost() string { + if m != nil { + return m.Host + } + return "" +} + +func init() { + proto.RegisterType((*Request)(nil), "graceful.Request") + proto.RegisterType((*Response)(nil), "graceful.Response") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. 
+const _ = grpc.SupportPackageIsVersion4 + +// Client API for GraceService service + +type GraceServiceClient interface { + Grace(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) +} + +type graceServiceClient struct { + cc *grpc.ClientConn +} + +func NewGraceServiceClient(cc *grpc.ClientConn) GraceServiceClient { + return &graceServiceClient{cc} +} + +func (c *graceServiceClient) Grace(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) { + out := new(Response) + err := grpc.Invoke(ctx, "/graceful.GraceService/grace", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for GraceService service + +type GraceServiceServer interface { + Grace(context.Context, *Request) (*Response, error) +} + +func RegisterGraceServiceServer(s *grpc.Server, srv GraceServiceServer) { + s.RegisterService(&_GraceService_serviceDesc, srv) +} + +func _GraceService_Grace_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Request) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(GraceServiceServer).Grace(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/graceful.GraceService/Grace", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(GraceServiceServer).Grace(ctx, req.(*Request)) + } + return interceptor(ctx, in, info, handler) +} + +var _GraceService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "graceful.GraceService", + HandlerType: (*GraceServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "grace", + Handler: _GraceService_Grace_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "graceful.proto", +} + +func init() { proto.RegisterFile("graceful.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 134 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x4b, 0x2f, 0x4a, 0x4c, + 0x4e, 0x4d, 0x2b, 0xcd, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x80, 0xf1, 0x95, 0x64, + 0xb9, 0xd8, 0x83, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0x84, 0xb8, 0x58, 0xd2, 0x8a, 0xf2, + 0x73, 0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0xc0, 0x6c, 0x25, 0x39, 0x2e, 0x8e, 0xa0, 0xd4, + 0xe2, 0x82, 0xfc, 0xbc, 0xe2, 0x54, 0x90, 0x7c, 0x46, 0x7e, 0x71, 0x89, 0x04, 0x13, 0x44, 0x1e, + 0xc4, 0x36, 0xb2, 0xe3, 0xe2, 0x71, 0x07, 0x19, 0x15, 0x9c, 0x5a, 0x54, 0x96, 0x99, 0x9c, 0x2a, + 0xa4, 0xc7, 0xc5, 0x0a, 0x36, 0x5a, 0x48, 0x50, 0x0f, 0x6e, 0x25, 0xd4, 0x7c, 0x29, 0x21, 0x64, + 0x21, 0x88, 0x99, 0x49, 0x6c, 0x60, 0xf7, 0x18, 0x03, 0x02, 0x00, 0x00, 0xff, 0xff, 0x81, 0x87, + 0xc8, 0xc1, 0xa1, 0x00, 0x00, 0x00, +} diff --git a/example/graceful/dns/rpc/gracefulrpc.go b/example/graceful/dns/rpc/gracefulrpc.go new file mode 100644 index 00000000..d1515127 --- /dev/null +++ b/example/graceful/dns/rpc/gracefulrpc.go @@ -0,0 +1,51 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "time" + + "zero/core/conf" + "zero/example/graceful/dns/rpc/graceful" + "zero/rpcx" + + "google.golang.org/grpc" +) + +var configFile = flag.String("f", "etc/config.json", "the config file") + +type GracefulServer struct{} + +func NewGracefulServer() *GracefulServer { + return &GracefulServer{} +} + +func (gs *GracefulServer) Grace(ctx context.Context, req *graceful.Request) (*graceful.Response, error) { + fmt.Println("=>", req) + + 
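+	// short sleep, presumably to keep some work in flight for the graceful-shutdown demo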
time.Sleep(time.Millisecond * 10) + + hostname, err := os.Hostname() + if err != nil { + return nil, err + } + + return &graceful.Response{ + Host: hostname, + }, nil +} + +func main() { + flag.Parse() + + var c rpcx.RpcServerConf + conf.MustLoad(*configFile, &c) + + server := rpcx.MustNewServer(c, func(grpcServer *grpc.Server) { + graceful.RegisterGraceServiceServer(grpcServer, NewGracefulServer()) + }) + defer server.Stop() + server.Start() +} diff --git a/example/graceful/dns/rpc/gracefulrpc.yaml b/example/graceful/dns/rpc/gracefulrpc.yaml new file mode 100644 index 00000000..61ea3595 --- /dev/null +++ b/example/graceful/dns/rpc/gracefulrpc.yaml @@ -0,0 +1,46 @@ +apiVersion: v1 +kind: Service +metadata: + name: gracefulrpc + namespace: kevin +spec: + selector: + app: gracefulrpc + type: ClusterIP + clusterIP: None + ports: + - name: gracefulrpc-port + port: 3456 + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: gracefulrpc-deployment + namespace: kevin + labels: + app: gracefulrpc +spec: + replicas: 3 + selector: + matchLabels: + app: gracefulrpc + template: + metadata: + labels: + app: gracefulrpc + spec: + containers: + - name: gracefulrpc + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/gracefulrpc:v191022143425 + imagePullPolicy: Always + ports: + - containerPort: 3456 + env: + - name: POD_IP + valueFrom: + fieldRef: + fieldPath: status.podIP + imagePullSecrets: + - name: aliyun diff --git a/example/graceful/etcd/api/Dockerfile b/example/graceful/etcd/api/Dockerfile new file mode 100644 index 00000000..e096b462 --- /dev/null +++ b/example/graceful/etcd/api/Dockerfile @@ -0,0 +1,28 @@ +FROM golang:alpine AS builder + +LABEL stage=gobuilder + +ENV CGO_ENABLED 0 +ENV GOOS linux + +RUN apk update +RUN apk add upx + +WORKDIR $GOPATH/src/zero +COPY . . +RUN go build -ldflags="-s -w" -o /app/graceful example/graceful/etcd/api/graceful.go +RUN upx /app/graceful + + +FROM alpine + +RUN apk update --no-cache +RUN apk add --no-cache ca-certificates +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai + +WORKDIR /app +COPY --from=builder /app/graceful /app/graceful +COPY example/graceful/etcd/api/etc/graceful-api.json /app/etc/config.json + +CMD ["./graceful", "-f", "etc/config.json"] diff --git a/example/graceful/etcd/api/Makefile b/example/graceful/etcd/api/Makefile new file mode 100644 index 00000000..6db78177 --- /dev/null +++ b/example/graceful/etcd/api/Makefile @@ -0,0 +1,13 @@ +version := v$(shell /bin/date "+%y%m%d%H%M%S") + +build: + docker pull alpine + docker pull golang:alpine + cd $(GOPATH)/src/zero && docker build -t registry.cn-hangzhou.aliyuncs.com/xapp/graceful:$(version) . 
-f example/graceful/etcd/api/Dockerfile + docker image prune --filter label=stage=gobuilder -f + +push: build + docker push registry.cn-hangzhou.aliyuncs.com/xapp/graceful:$(version) + +deploy: push + kubectl -n kevin set image deployment/graceful-deployment graceful=registry-vpc.cn-hangzhou.aliyuncs.com/xapp/graceful:$(version) diff --git a/example/graceful/etcd/api/config/config.go b/example/graceful/etcd/api/config/config.go new file mode 100644 index 00000000..6ae5af47 --- /dev/null +++ b/example/graceful/etcd/api/config/config.go @@ -0,0 +1,11 @@ +package config + +import ( + "zero/ngin" + "zero/rpcx" +) + +type Config struct { + ngin.NgConf + Rpc rpcx.RpcClientConf +} diff --git a/example/graceful/etcd/api/etc/graceful-api.json b/example/graceful/etcd/api/etc/graceful-api.json new file mode 100644 index 00000000..c46bd216 --- /dev/null +++ b/example/graceful/etcd/api/etc/graceful-api.json @@ -0,0 +1,12 @@ +{ + "Name": "graceful-api", + "Host": "0.0.0.0", + "Port": 8888, + "MaxConns": 1000000, + "Rpc": { + "Etcd": { + "Hosts": ["etcd.discov:2379"], + "Key": "rpcx" + } + } +} \ No newline at end of file diff --git a/example/graceful/etcd/api/graceful.api b/example/graceful/etcd/api/graceful.api new file mode 100644 index 00000000..be4323d3 --- /dev/null +++ b/example/graceful/etcd/api/graceful.api @@ -0,0 +1,11 @@ +type Response { + Host string `json:"host"` + Time int64 `json:"time"` +} + +service graceful-api { + @server( + handler: GracefulHandler + ) + get /api/graceful() returns(Response) +} \ No newline at end of file diff --git a/example/graceful/etcd/api/graceful.go b/example/graceful/etcd/api/graceful.go new file mode 100644 index 00000000..539ec141 --- /dev/null +++ b/example/graceful/etcd/api/graceful.go @@ -0,0 +1,32 @@ +package main + +import ( + "flag" + + "zero/core/conf" + "zero/example/graceful/etcd/api/config" + "zero/example/graceful/etcd/api/handler" + "zero/example/graceful/etcd/api/svc" + "zero/ngin" + "zero/rpcx" +) + +var configFile = flag.String("f", "etc/graceful-api.json", "the config file") + +func main() { + flag.Parse() + + var c config.Config + conf.MustLoad(*configFile, &c) + + client := rpcx.MustNewClient(c.Rpc) + ctx := &svc.ServiceContext{ + Client: client, + } + + engine := ngin.MustNewEngine(c.NgConf) + defer engine.Stop() + + handler.RegisterHandlers(engine, ctx) + engine.Start() +} diff --git a/example/graceful/etcd/api/graceful.yaml b/example/graceful/etcd/api/graceful.yaml new file mode 100644 index 00000000..cc6925e8 --- /dev/null +++ b/example/graceful/etcd/api/graceful.yaml @@ -0,0 +1,42 @@ +apiVersion: v1 +kind: Service +metadata: + name: graceful + namespace: kevin +spec: + selector: + app: graceful + type: ClusterIP + ports: + - name: graceful-port + port: 3333 + targetPort: 8888 + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: graceful-deployment + namespace: kevin + labels: + app: graceful +spec: + replicas: 3 + selector: + matchLabels: + app: graceful + template: + metadata: + labels: + app: graceful + spec: + containers: + - name: graceful + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/graceful:v191031145905 + imagePullPolicy: Always + ports: + - containerPort: 8888 + imagePullSecrets: + - name: aliyun + diff --git a/example/graceful/etcd/api/handler/gracefulhandler.go b/example/graceful/etcd/api/handler/gracefulhandler.go new file mode 100644 index 00000000..a875016b --- /dev/null +++ b/example/graceful/etcd/api/handler/gracefulhandler.go @@ -0,0 +1,49 @@ +package handler + +import ( + "context" + "fmt" + 
"net/http" + "os" + "time" + + "zero/core/executors" + "zero/core/httpx" + "zero/core/logx" + "zero/example/graceful/etcd/api/svc" + "zero/example/graceful/etcd/api/types" + "zero/example/graceful/etcd/rpc/graceful" +) + +func gracefulHandler(ctx *svc.ServiceContext) http.HandlerFunc { + logger := executors.NewLessExecutor(time.Second) + return func(w http.ResponseWriter, r *http.Request) { + var resp types.Response + + conn, ok := ctx.Client.Next() + if !ok { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + host, err := os.Hostname() + if err != nil { + http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) + return + } + + client := graceful.NewGraceServiceClient(conn) + rp, err := client.Grace(context.Background(), &graceful.Request{From: host}) + if err != nil { + logx.Error(err) + http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) + return + } + + resp.Host = rp.Host + logger.DoOrDiscard(func() { + fmt.Printf("%s from host: %s\n", time.Now().Format("15:04:05"), rp.Host) + }) + httpx.OkJson(w, resp) + } +} diff --git a/example/graceful/etcd/api/handler/routes.go b/example/graceful/etcd/api/handler/routes.go new file mode 100644 index 00000000..a519075b --- /dev/null +++ b/example/graceful/etcd/api/handler/routes.go @@ -0,0 +1,19 @@ +// DO NOT EDIT, generated by goctl +package handler + +import ( + "net/http" + + "zero/example/graceful/etcd/api/svc" + "zero/ngin" +) + +func RegisterHandlers(engine *ngin.Engine, ctx *svc.ServiceContext) { + engine.AddRoutes([]ngin.Route{ + { + Method: http.MethodGet, + Path: "/api/graceful", + Handler: gracefulHandler(ctx), + }, + }) +} diff --git a/example/graceful/etcd/api/svc/servicecontext.go b/example/graceful/etcd/api/svc/servicecontext.go new file mode 100644 index 00000000..6efedfd0 --- /dev/null +++ b/example/graceful/etcd/api/svc/servicecontext.go @@ -0,0 +1,7 @@ +package svc + +import "zero/rpcx" + +type ServiceContext struct { + Client *rpcx.RpcClient +} diff --git a/example/graceful/etcd/api/types/types.go b/example/graceful/etcd/api/types/types.go new file mode 100644 index 00000000..5143ec78 --- /dev/null +++ b/example/graceful/etcd/api/types/types.go @@ -0,0 +1,7 @@ +// DO NOT EDIT, generated by goctl +package types + +type Response struct { + Host string `json:"host"` + Time int64 `json:"time"` +} diff --git a/example/graceful/etcd/discov/discov-namespace.yaml b/example/graceful/etcd/discov/discov-namespace.yaml new file mode 100644 index 00000000..16b397bc --- /dev/null +++ b/example/graceful/etcd/discov/discov-namespace.yaml @@ -0,0 +1,4 @@ +apiVersion: v1 +kind: Namespace +metadata: + name: discov diff --git a/example/graceful/etcd/discov/etcd.yaml b/example/graceful/etcd/discov/etcd.yaml new file mode 100644 index 00000000..d3b2878d --- /dev/null +++ b/example/graceful/etcd/discov/etcd.yaml @@ -0,0 +1,319 @@ +apiVersion: v1 +kind: Service +metadata: + name: etcd + namespace: discov +spec: + ports: + - name: etcd-port + port: 2379 + protocol: TCP + targetPort: 2379 + selector: + app: etcd + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + etcd_node: etcd0 + name: etcd0 + namespace: discov +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - etcd0 + - --initial-advertise-peer-urls + - http://etcd0:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://etcd0:2379 + - --initial-cluster + - 
etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 + - --initial-cluster-state + - new + image: quay.io/coreos/etcd:latest + name: etcd0 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + etcd_node: etcd0 + name: etcd0 + namespace: discov +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + etcd_node: etcd0 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + etcd_node: etcd1 + name: etcd1 + namespace: discov +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - etcd1 + - --initial-advertise-peer-urls + - http://etcd1:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://etcd1:2379 + - --initial-cluster + - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 + - --initial-cluster-state + - new + image: quay.io/coreos/etcd:latest + name: etcd1 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + etcd_node: etcd1 + name: etcd1 + namespace: discov +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + etcd_node: etcd1 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + etcd_node: etcd2 + name: etcd2 + namespace: discov +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - etcd2 + - --initial-advertise-peer-urls + - http://etcd2:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://etcd2:2379 + - --initial-cluster + - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 + - --initial-cluster-state + - new + image: quay.io/coreos/etcd:latest + name: etcd2 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + etcd_node: etcd2 + name: etcd2 + namespace: discov +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + etcd_node: etcd2 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + etcd_node: etcd3 + name: etcd3 + namespace: discov +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - etcd3 + - --initial-advertise-peer-urls + - http://etcd3:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://etcd3:2379 + - --initial-cluster + - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 + - --initial-cluster-state + - new + image: quay.io/coreos/etcd:latest + name: etcd3 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + restartPolicy: 
Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + etcd_node: etcd3 + name: etcd3 + namespace: discov +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + etcd_node: etcd3 + +--- + +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + etcd_node: etcd4 + name: etcd4 + namespace: discov +spec: + containers: + - command: + - /usr/local/bin/etcd + - --name + - etcd4 + - --initial-advertise-peer-urls + - http://etcd4:2380 + - --listen-peer-urls + - http://0.0.0.0:2380 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://etcd4:2379 + - --initial-cluster + - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 + - --initial-cluster-state + - new + image: quay.io/coreos/etcd:latest + name: etcd4 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + restartPolicy: Always + +--- + +apiVersion: v1 +kind: Service +metadata: + labels: + etcd_node: etcd4 + name: etcd4 + namespace: discov +spec: + ports: + - name: client + port: 2379 + protocol: TCP + targetPort: 2379 + - name: server + port: 2380 + protocol: TCP + targetPort: 2380 + selector: + etcd_node: etcd4 + diff --git a/example/graceful/etcd/rpc/Dockerfile b/example/graceful/etcd/rpc/Dockerfile new file mode 100644 index 00000000..9d27948c --- /dev/null +++ b/example/graceful/etcd/rpc/Dockerfile @@ -0,0 +1,24 @@ +FROM golang:alpine AS builder + +LABEL stage=gobuilder + +ENV CGO_ENABLED 0 +ENV GOOS linux + +WORKDIR $GOPATH/src/zero +COPY . . +RUN go build -ldflags="-s -w" -o /app/gracefulrpc example/graceful/etcd/rpc/gracefulrpc.go + + +FROM alpine + +RUN apk update --no-cache +RUN apk add --no-cache ca-certificates +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai + +WORKDIR /app +COPY --from=builder /app/gracefulrpc /app/gracefulrpc +COPY example/graceful/etcd/rpc/etc/graceful-rpc.json /app/etc/config.json + +CMD ["./gracefulrpc", "-f", "etc/config.json"] diff --git a/example/graceful/etcd/rpc/Makefile b/example/graceful/etcd/rpc/Makefile new file mode 100644 index 00000000..c15956c7 --- /dev/null +++ b/example/graceful/etcd/rpc/Makefile @@ -0,0 +1,13 @@ +version := v$(shell /bin/date "+%y%m%d%H%M%S") + +build: + docker pull alpine + docker pull golang:alpine + cd $(GOPATH)/src/zero && docker build -t registry.cn-hangzhou.aliyuncs.com/xapp/gracefulrpc:$(version) . 
-f example/graceful/etcd/rpc/Dockerfile + docker image prune --filter label=stage=gobuilder -f + +push: build + docker push registry.cn-hangzhou.aliyuncs.com/xapp/gracefulrpc:$(version) + +deploy: push + kubectl -n kevin set image deployment/gracefulrpc-deployment gracefulrpc=registry-vpc.cn-hangzhou.aliyuncs.com/xapp/gracefulrpc:$(version) diff --git a/example/graceful/etcd/rpc/etc/graceful-rpc.json b/example/graceful/etcd/rpc/etc/graceful-rpc.json new file mode 100644 index 00000000..98614e62 --- /dev/null +++ b/example/graceful/etcd/rpc/etc/graceful-rpc.json @@ -0,0 +1,8 @@ +{ + "Name": "rpc.grace", + "ListenOn": "0.0.0.0:3456", + "Etcd": { + "Hosts": ["etcd.discov:2379"], + "Key": "rpcx" + } +} diff --git a/example/graceful/etcd/rpc/graceful.proto b/example/graceful/etcd/rpc/graceful.proto new file mode 100644 index 00000000..1c6e465f --- /dev/null +++ b/example/graceful/etcd/rpc/graceful.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package graceful; + +message Request { + string from = 1; +} + +message Response { + string host = 2; +} + +service GraceService { + rpc grace(Request) returns(Response); +} \ No newline at end of file diff --git a/example/graceful/etcd/rpc/graceful/graceful.pb.go b/example/graceful/etcd/rpc/graceful/graceful.pb.go new file mode 100644 index 00000000..611ecb0b --- /dev/null +++ b/example/graceful/etcd/rpc/graceful/graceful.pb.go @@ -0,0 +1,159 @@ +// Code generated by protoc-gen-go. +// source: graceful.proto +// DO NOT EDIT! + +/* +Package graceful is a generated protocol buffer package. + +It is generated from these files: + graceful.proto + +It has these top-level messages: + Request + Response +*/ +package graceful + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type Request struct { + From string `protobuf:"bytes,1,opt,name=from" json:"from,omitempty"` +} + +func (m *Request) Reset() { *m = Request{} } +func (m *Request) String() string { return proto.CompactTextString(m) } +func (*Request) ProtoMessage() {} +func (*Request) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *Request) GetFrom() string { + if m != nil { + return m.From + } + return "" +} + +type Response struct { + Host string `protobuf:"bytes,2,opt,name=host" json:"host,omitempty"` +} + +func (m *Response) Reset() { *m = Response{} } +func (m *Response) String() string { return proto.CompactTextString(m) } +func (*Response) ProtoMessage() {} +func (*Response) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *Response) GetHost() string { + if m != nil { + return m.Host + } + return "" +} + +func init() { + proto.RegisterType((*Request)(nil), "graceful.Request") + proto.RegisterType((*Response)(nil), "graceful.Response") +} + +// Reference imports to suppress errors if they are not otherwise used. 
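The generated GraceService client that follows is the same one gracefulHandler drives above through the rpcx connection pool. For a quick single-pod smoke test it can also be driven over a plain grpc-go connection; a minimal sketch, assuming a local endpoint on port 3456 (matching graceful-rpc.json) and an insecure dial — this file is illustrative only and not part of the commit:

// smoke.go — hypothetical ad-hoc client for the graceful rpc service.
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"zero/example/graceful/etcd/rpc/graceful"

	"google.golang.org/grpc"
)

func main() {
	// Assumed address: the service listens on 0.0.0.0:3456 per graceful-rpc.json;
	// in the deployed setup the endpoint would instead be resolved via the etcd key "rpcx".
	conn, err := grpc.Dial("localhost:3456", grpc.WithInsecure())
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	client := graceful.NewGraceServiceClient(conn)
	resp, err := client.Grace(ctx, &graceful.Request{From: "smoke-test"})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("served by:", resp.Host)
}
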
+var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for GraceService service + +type GraceServiceClient interface { + Grace(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) +} + +type graceServiceClient struct { + cc *grpc.ClientConn +} + +func NewGraceServiceClient(cc *grpc.ClientConn) GraceServiceClient { + return &graceServiceClient{cc} +} + +func (c *graceServiceClient) Grace(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) { + out := new(Response) + err := grpc.Invoke(ctx, "/graceful.GraceService/grace", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for GraceService service + +type GraceServiceServer interface { + Grace(context.Context, *Request) (*Response, error) +} + +func RegisterGraceServiceServer(s *grpc.Server, srv GraceServiceServer) { + s.RegisterService(&_GraceService_serviceDesc, srv) +} + +func _GraceService_Grace_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Request) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(GraceServiceServer).Grace(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/graceful.GraceService/Grace", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(GraceServiceServer).Grace(ctx, req.(*Request)) + } + return interceptor(ctx, in, info, handler) +} + +var _GraceService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "graceful.GraceService", + HandlerType: (*GraceServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "grace", + Handler: _GraceService_Grace_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "graceful.proto", +} + +func init() { proto.RegisterFile("graceful.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 134 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x4b, 0x2f, 0x4a, 0x4c, + 0x4e, 0x4d, 0x2b, 0xcd, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x80, 0xf1, 0x95, 0x64, + 0xb9, 0xd8, 0x83, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0x84, 0xb8, 0x58, 0xd2, 0x8a, 0xf2, + 0x73, 0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0xc0, 0x6c, 0x25, 0x39, 0x2e, 0x8e, 0xa0, 0xd4, + 0xe2, 0x82, 0xfc, 0xbc, 0xe2, 0x54, 0x90, 0x7c, 0x46, 0x7e, 0x71, 0x89, 0x04, 0x13, 0x44, 0x1e, + 0xc4, 0x36, 0xb2, 0xe3, 0xe2, 0x71, 0x07, 0x19, 0x15, 0x9c, 0x5a, 0x54, 0x96, 0x99, 0x9c, 0x2a, + 0xa4, 0xc7, 0xc5, 0x0a, 0x36, 0x5a, 0x48, 0x50, 0x0f, 0x6e, 0x25, 0xd4, 0x7c, 0x29, 0x21, 0x64, + 0x21, 0x88, 0x99, 0x49, 0x6c, 0x60, 0xf7, 0x18, 0x03, 0x02, 0x00, 0x00, 0xff, 0xff, 0x81, 0x87, + 0xc8, 0xc1, 0xa1, 0x00, 0x00, 0x00, +} diff --git a/example/graceful/etcd/rpc/gracefulrpc-env.yaml b/example/graceful/etcd/rpc/gracefulrpc-env.yaml new file mode 100644 index 00000000..9de5207f --- /dev/null +++ b/example/graceful/etcd/rpc/gracefulrpc-env.yaml @@ -0,0 +1,30 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: gracefulrpc-deployment + namespace: kevin + labels: + app: gracefulrpc +spec: + replicas: 9 + selector: + matchLabels: + app: gracefulrpc + template: + metadata: + labels: + app: gracefulrpc + spec: + containers: + - name: 
gracefulrpc + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/gracefulrpc:v191031144304 + imagePullPolicy: Always + ports: + - containerPort: 3456 + env: + - name: POD_IP + valueFrom: + fieldRef: + fieldPath: status.podIP + imagePullSecrets: + - name: aliyun diff --git a/example/graceful/etcd/rpc/gracefulrpc-headless.yaml b/example/graceful/etcd/rpc/gracefulrpc-headless.yaml new file mode 100644 index 00000000..b635ccea --- /dev/null +++ b/example/graceful/etcd/rpc/gracefulrpc-headless.yaml @@ -0,0 +1,41 @@ + apiVersion: v1 + kind: Service + metadata: + name: gracefulrpc + namespace: kevin + spec: + selector: + app: gracefulrpc + type: ClusterIP + clusterIP: None + ports: + - name: gracefulrpc-port + port: 3456 + + --- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: gracefulrpc-deployment + namespace: kevin + labels: + app: gracefulrpc +spec: + replicas: 9 + selector: + matchLabels: + app: gracefulrpc + template: + metadata: + labels: + app: gracefulrpc + spec: + containers: + - name: gracefulrpc + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/gracefulrpc:v191031144304 + imagePullPolicy: Always + ports: + - containerPort: 3456 + imagePullSecrets: + - name: aliyun diff --git a/example/graceful/etcd/rpc/gracefulrpc.go b/example/graceful/etcd/rpc/gracefulrpc.go new file mode 100644 index 00000000..330809cf --- /dev/null +++ b/example/graceful/etcd/rpc/gracefulrpc.go @@ -0,0 +1,51 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "time" + + "zero/core/conf" + "zero/example/graceful/etcd/rpc/graceful" + "zero/rpcx" + + "google.golang.org/grpc" +) + +var configFile = flag.String("f", "etc/config.json", "the config file") + +type GracefulServer struct{} + +func NewGracefulServer() *GracefulServer { + return &GracefulServer{} +} + +func (gs *GracefulServer) Grace(ctx context.Context, req *graceful.Request) (*graceful.Response, error) { + fmt.Println("=>", req) + + time.Sleep(time.Millisecond * 10) + + hostname, err := os.Hostname() + if err != nil { + return nil, err + } + + return &graceful.Response{ + Host: hostname, + }, nil +} + +func main() { + flag.Parse() + + var c rpcx.RpcServerConf + conf.MustLoad(*configFile, &c) + + server := rpcx.MustNewServer(c, func(grpcServer *grpc.Server) { + graceful.RegisterGraceServiceServer(grpcServer, NewGracefulServer()) + }) + defer server.Stop() + server.Start() +} diff --git a/example/graceful/etcd/rpc/gracefulrpc.yaml b/example/graceful/etcd/rpc/gracefulrpc.yaml new file mode 100644 index 00000000..862dab1a --- /dev/null +++ b/example/graceful/etcd/rpc/gracefulrpc.yaml @@ -0,0 +1,25 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: gracefulrpc-deployment + namespace: kevin + labels: + app: gracefulrpc +spec: + replicas: 9 + selector: + matchLabels: + app: gracefulrpc + template: + metadata: + labels: + app: gracefulrpc + spec: + containers: + - name: gracefulrpc + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/gracefulrpc:v191031144304 + imagePullPolicy: Always + ports: + - containerPort: 3456 + imagePullSecrets: + - name: aliyun diff --git a/example/http/breaker/client/client.go b/example/http/breaker/client/client.go new file mode 100644 index 00000000..75c70703 --- /dev/null +++ b/example/http/breaker/client/client.go @@ -0,0 +1,170 @@ +package main + +import ( + "flag" + "fmt" + "net/http" + "os" + "sync" + "time" + + "zero/core/lang" + "zero/core/threading" + + "gopkg.in/cheggaaa/pb.v1" +) + +var ( + freq = flag.Int("freq", 100, "frequence") + duration = flag.String("duration", "10s", "duration") +) + 
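The breaker client below drives the /good and /heavy endpoints at the configured frequency and buckets each response by status code. The drop ratio it writes to result.csv is the count of rejected (503) responses over all completed responses, with transport errors excluded; a small hedged sketch makes that formula explicit (the helper name is hypothetical, and counting refers to the struct defined just below):

// dropRatio returns the share of shed (HTTP 503) responses among all
// requests that received a response; transport errors are not counted.
// Hypothetical helper for illustration only, not part of this commit.
func dropRatio(c counting) float32 {
	total := c.ok + c.fail + c.reject + c.unknown
	if total == 0 {
		return 0
	}
	return float32(c.reject) / float32(total)
}
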
+type ( + counting struct { + ok int + fail int + reject int + errs int + unknown int + } + + metric struct { + counting + lock sync.Mutex + } +) + +func (m *metric) addOk() { + m.lock.Lock() + m.ok++ + m.lock.Unlock() +} + +func (m *metric) addFail() { + m.lock.Lock() + m.fail++ + m.lock.Unlock() +} + +func (m *metric) addReject() { + m.lock.Lock() + m.reject++ + m.lock.Unlock() +} + +func (m *metric) addErrs() { + m.lock.Lock() + m.errs++ + m.lock.Unlock() +} + +func (m *metric) addUnknown() { + m.lock.Lock() + m.unknown++ + m.lock.Unlock() +} + +func (m *metric) reset() counting { + m.lock.Lock() + result := counting{ + ok: m.ok, + fail: m.fail, + reject: m.reject, + errs: m.errs, + unknown: m.unknown, + } + + m.ok = 0 + m.fail = 0 + m.reject = 0 + m.errs = 0 + m.unknown = 0 + m.lock.Unlock() + + return result +} + +func runRequests(url string, frequence int, metrics *metric, done <-chan lang.PlaceholderType) { + ticker := time.NewTicker(time.Second / time.Duration(frequence)) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + go func() { + resp, err := http.Get(url) + if err != nil { + metrics.addErrs() + return + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + metrics.addOk() + case http.StatusInternalServerError: + metrics.addFail() + case http.StatusServiceUnavailable: + metrics.addReject() + default: + metrics.addUnknown() + } + }() + case <-done: + return + } + } +} + +func main() { + flag.Parse() + + fp, err := os.Create("result.csv") + lang.Must(err) + defer fp.Close() + fmt.Fprintln(fp, "seconds,goodOk,goodFail,goodReject,goodErrs,goodUnknowns,goodDropRatio,"+ + "heavyOk,heavyFail,heavyReject,heavyErrs,heavyUnknowns,heavyDropRatio") + + var gm, hm metric + dur, err := time.ParseDuration(*duration) + lang.Must(err) + done := make(chan lang.PlaceholderType) + group := threading.NewRoutineGroup() + group.RunSafe(func() { + runRequests("http://localhost:8080/heavy", *freq, &hm, done) + }) + group.RunSafe(func() { + runRequests("http://localhost:8080/good", *freq, &gm, done) + }) + + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + var seconds int + for range ticker.C { + seconds++ + g := gm.reset() + h := hm.reset() + fmt.Fprintf(fp, "%d,%d,%d,%d,%d,%d,%.1f,%d,%d,%d,%d,%d,%.1f\n", + seconds, g.ok, g.fail, g.reject, g.errs, g.unknown, + float32(g.reject)/float32(g.ok+g.fail+g.reject+g.unknown), + h.ok, h.fail, h.reject, h.errs, h.unknown, + float32(h.reject)/float32(h.ok+h.fail+h.reject+h.unknown)) + } + }() + + go func() { + bar := pb.New(int(dur / time.Second)).Start() + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for range ticker.C { + bar.Increment() + } + bar.Finish() + }() + + <-time.After(dur) + close(done) + group.Wait() + time.Sleep(time.Millisecond * 900) +} diff --git a/example/http/breaker/good.sh b/example/http/breaker/good.sh new file mode 100644 index 00000000..4f93a398 --- /dev/null +++ b/example/http/breaker/good.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +hey -z 60s http://localhost:8080/good diff --git a/example/http/breaker/heavy.sh b/example/http/breaker/heavy.sh new file mode 100644 index 00000000..92b34a37 --- /dev/null +++ b/example/http/breaker/heavy.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +hey -z 60s http://localhost:8080/heavy diff --git a/example/http/breaker/server.go b/example/http/breaker/server.go new file mode 100644 index 00000000..dee0df39 --- /dev/null +++ b/example/http/breaker/server.go @@ -0,0 +1,59 @@ +package main + +import ( + "net/http" + "runtime" + "time" + 
"zero/core/logx" + "zero/core/service" + "zero/core/stat" + "zero/core/syncx" + "zero/ngin" +) + +func main() { + logx.Disable() + stat.SetReporter(nil) + server := ngin.MustNewEngine(ngin.NgConf{ + ServiceConf: service.ServiceConf{ + Name: "breaker", + Log: logx.LogConf{ + Mode: "console", + }, + }, + Host: "0.0.0.0", + Port: 8080, + MaxConns: 1000, + Timeout: 3000, + }) + latch := syncx.NewLimit(10) + server.AddRoute(ngin.Route{ + Method: http.MethodGet, + Path: "/heavy", + Handler: func(w http.ResponseWriter, r *http.Request) { + if latch.TryBorrow() { + defer latch.Return() + runtime.LockOSThread() + defer runtime.UnlockOSThread() + begin := time.Now() + for { + if time.Now().Sub(begin) > time.Millisecond*50 { + break + } + } + } else { + w.WriteHeader(http.StatusInternalServerError) + } + }, + }) + server.AddRoute(ngin.Route{ + Method: http.MethodGet, + Path: "/good", + Handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }, + }) + defer server.Stop() + server.Start() +} diff --git a/example/http/breaker/start.sh b/example/http/breaker/start.sh new file mode 100644 index 00000000..d9874ea8 --- /dev/null +++ b/example/http/breaker/start.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +GOOS=linux go build -ldflags="-s -w" server.go +docker run --rm -it --cpus=1 -p 8080:8080 -v `pwd`:/app -w /app alpine /app/server +rm -f server diff --git a/example/http/crypt/crypt.go b/example/http/crypt/crypt.go new file mode 100644 index 00000000..6bb47f5b --- /dev/null +++ b/example/http/crypt/crypt.go @@ -0,0 +1,56 @@ +package main + +import ( + "fmt" + "log" + + "zero/core/codec" +) + +const ( + pubKey = `-----BEGIN PUBLIC KEY----- +MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE +eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH +miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR +my47YlhspwszKdRP+wIDAQAB +-----END PUBLIC KEY-----` + body = "hello" +) + +var key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D") + +func main() { + encrypter, err := codec.NewRsaEncrypter([]byte(pubKey)) + if err != nil { + log.Fatal(err) + } + + decrypter, err := codec.NewRsaDecrypter("private.pem") + if err != nil { + log.Fatal(err) + } + + output, err := encrypter.Encrypt([]byte(body)) + if err != nil { + log.Fatal(err) + } + + actual, err := decrypter.Decrypt(output) + if err != nil { + log.Fatal(err) + } + + fmt.Println(actual) + + out, err := codec.EcbEncrypt(key, []byte(body)) + if err != nil { + log.Fatal(err) + } + + ret, err := codec.EcbDecrypt(key, out) + if err != nil { + log.Fatal(err) + } + + fmt.Println(string(ret)) +} diff --git a/example/http/demo/main.go b/example/http/demo/main.go new file mode 100644 index 00000000..93fb986a --- /dev/null +++ b/example/http/demo/main.go @@ -0,0 +1,70 @@ +package main + +import ( + "flag" + "net/http" + + "zero/core/httpx" + "zero/core/logx" + "zero/core/service" + "zero/ngin" +) + +var ( + port = flag.Int("port", 3333, "the port to listen") + timeout = flag.Int64("timeout", 0, "timeout of milliseconds") +) + +type Request struct { + User string `form:"user,options=a|b"` +} + +func first(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("X-Middleware", "first") + next(w, r) + } +} + +func second(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("X-Middleware", "second") + next(w, r) + } +} + +func handle(w http.ResponseWriter, r *http.Request) { + var req Request + err := 
httpx.Parse(r, &req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + httpx.OkJson(w, "helllo, "+req.User) +} + +func main() { + flag.Parse() + + engine := ngin.MustNewEngine(ngin.NgConf{ + ServiceConf: service.ServiceConf{ + Log: logx.LogConf{ + Mode: "console", + }, + }, + Port: *port, + Timeout: *timeout, + MaxConns: 500, + }) + defer engine.Stop() + + engine.Use(first) + engine.Use(second) + engine.AddRoute(ngin.Route{ + Method: http.MethodGet, + Path: "/", + Handler: handle, + }) + engine.Start() +} diff --git a/example/http/post/main.go b/example/http/post/main.go new file mode 100644 index 00000000..4a8fc175 --- /dev/null +++ b/example/http/post/main.go @@ -0,0 +1,65 @@ +package main + +import ( + "flag" + "fmt" + "net/http" + + "zero/core/httpx" + "zero/core/logx" + "zero/core/service" + "zero/ngin" +) + +var ( + port = flag.Int("port", 3333, "the port to listen") + timeout = flag.Int64("timeout", 0, "timeout of milliseconds") +) + +type Request struct { + User string `json:"user"` +} + +func handleGet(w http.ResponseWriter, r *http.Request) { +} + +func handlePost(w http.ResponseWriter, r *http.Request) { + var req Request + err := httpx.Parse(r, &req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + httpx.OkJson(w, fmt.Sprintf("Content-Length: %d, UserLen: %d", r.ContentLength, len(req.User))) +} + +func main() { + flag.Parse() + + engine := ngin.MustNewEngine(ngin.NgConf{ + ServiceConf: service.ServiceConf{ + Log: logx.LogConf{ + Mode: "console", + }, + }, + Port: *port, + Timeout: *timeout, + MaxConns: 500, + MaxBytes: 50, + CpuThreshold: 500, + }) + defer engine.Stop() + + engine.AddRoute(ngin.Route{ + Method: http.MethodGet, + Path: "/", + Handler: handleGet, + }) + engine.AddRoute(ngin.Route{ + Method: http.MethodPost, + Path: "/", + Handler: handlePost, + }) + engine.Start() +} diff --git a/example/http/shedding/Dockerfile b/example/http/shedding/Dockerfile new file mode 100644 index 00000000..b92e6244 --- /dev/null +++ b/example/http/shedding/Dockerfile @@ -0,0 +1,11 @@ +FROM alpine + +RUN apk update --no-cache +RUN apk add --no-cache ca-certificates +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai + +WORKDIR /app +COPY main /app/main + +CMD ["./main"] diff --git a/example/http/shedding/main.go b/example/http/shedding/main.go new file mode 100644 index 00000000..887ae854 --- /dev/null +++ b/example/http/shedding/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "flag" + "math" + "net/http" + "time" + + "zero/core/httpx" + "zero/core/logx" + "zero/core/service" + "zero/ngin" +) + +var ( + port = flag.Int("port", 3333, "the port to listen") + timeout = flag.Int64("timeout", 1000, "timeout of milliseconds") + cpu = flag.Int64("cpu", 500, "cpu threshold") +) + +type Request struct { + User string `form:"user,optional"` +} + +func handle(w http.ResponseWriter, r *http.Request) { + var req Request + err := httpx.Parse(r, &req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var result float64 + for i := 0; i < 30000; i++ { + result += math.Sqrt(float64(i)) + } + time.Sleep(time.Millisecond * 5) + httpx.OkJson(w, result) +} + +func main() { + flag.Parse() + + logx.Disable() + engine := ngin.MustNewEngine(ngin.NgConf{ + ServiceConf: service.ServiceConf{ + Log: logx.LogConf{ + Mode: "console", + }, + }, + Port: *port, + Timeout: *timeout, + CpuThreshold: *cpu, + }) + defer engine.Stop() + + engine.AddRoute(ngin.Route{ + Method: http.MethodGet, + Path: "/", + 
Handler: handle, + }) + engine.Start() +} diff --git a/example/http/signature/client/client.go b/example/http/signature/client/client.go new file mode 100644 index 00000000..11c84c2e --- /dev/null +++ b/example/http/signature/client/client.go @@ -0,0 +1,113 @@ +package main + +import ( + "crypto/hmac" + "crypto/md5" + "crypto/sha256" + "encoding/base64" + "flag" + "fmt" + "io" + "log" + "net/http" + "os" + "strconv" + "strings" + "time" + + "zero/core/codec" +) + +const pubKey = `-----BEGIN PUBLIC KEY----- +MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE +eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH +miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR +my47YlhspwszKdRP+wIDAQAB +-----END PUBLIC KEY-----` + +var ( + crypt = flag.Bool("crypt", false, "encrypt body or not") + key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D") +) + +func fingerprint(key string) string { + h := md5.New() + io.WriteString(h, key) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func hs256(key []byte, body string) string { + h := hmac.New(sha256.New, key) + io.WriteString(h, body) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func main() { + flag.Parse() + + var err error + body := "hello world!" + if *crypt { + bodyBytes, err := codec.EcbEncrypt(key, []byte(body)) + if err != nil { + log.Fatal(err) + } + body = base64.StdEncoding.EncodeToString(bodyBytes) + } + + r, err := http.NewRequest(http.MethodPost, "http://localhost:3333/a/b?c=first&d=second", strings.NewReader(body)) + if err != nil { + log.Fatal(err) + } + + timestamp := time.Now().Unix() + sha := sha256.New() + sha.Write([]byte(body)) + bodySign := fmt.Sprintf("%x", sha.Sum(nil)) + contentOfSign := strings.Join([]string{ + strconv.FormatInt(timestamp, 10), + http.MethodPost, + r.URL.Path, + r.URL.RawQuery, + bodySign, + }, "\n") + sign := hs256(key, contentOfSign) + var mode string + if *crypt { + mode = "1" + } else { + mode = "0" + } + content := strings.Join([]string{ + "version=v1", + "type=" + mode, + fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)), + "time=" + strconv.FormatInt(timestamp, 10), + }, "; ") + + encrypter, err := codec.NewRsaEncrypter([]byte(pubKey)) + if err != nil { + log.Fatal(err) + } + + output, err := encrypter.Encrypt([]byte(content)) + if err != nil { + log.Fatal(err) + } + + encryptedContent := base64.StdEncoding.EncodeToString(output) + r.Header.Set("X-Content-Security", strings.Join([]string{ + fmt.Sprintf("key=%s", fingerprint(pubKey)), + "secret=" + encryptedContent, + "signature=" + sign, + }, "; ")) + client := &http.Client{} + resp, err := client.Do(r) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + fmt.Println(resp.Status) + io.Copy(os.Stdout, resp.Body) +} diff --git a/example/http/signature/server/server.go b/example/http/signature/server/server.go new file mode 100644 index 00000000..5833ac5b --- /dev/null +++ b/example/http/signature/server/server.go @@ -0,0 +1,59 @@ +package main + +import ( + "flag" + "io" + "net/http" + + "zero/core/httpx" + "zero/core/logx" + "zero/core/service" + "zero/ngin" +) + +var keyPem = flag.String("prikey", "private.pem", "the private key file") + +type Request struct { + User string `form:"user,optional"` +} + +func handle(w http.ResponseWriter, r *http.Request) { + var req Request + err := httpx.Parse(r, &req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + io.Copy(w, r.Body) +} + +func main() { + flag.Parse() + + engine := 
ngin.MustNewEngine(ngin.NgConf{ + ServiceConf: service.ServiceConf{ + Log: logx.LogConf{ + Path: "logs", + }, + }, + Port: 3333, + Signature: ngin.SignatureConf{ + Strict: true, + PrivateKeys: []ngin.PrivateKeyConf{ + { + Fingerprint: "bvw8YlnSqb+PoMf3MBbLdQ==", + KeyFile: *keyPem, + }, + }, + }, + }) + defer engine.Stop() + + engine.AddRoute(ngin.Route{ + Method: http.MethodPost, + Path: "/a/b", + Handler: handle, + }) + engine.Start() +} diff --git a/example/jobqueue/jobqueue.go b/example/jobqueue/jobqueue.go new file mode 100644 index 00000000..71b6a3d3 --- /dev/null +++ b/example/jobqueue/jobqueue.go @@ -0,0 +1,11 @@ +package main + +import "zero/core/threading" + +func main() { + q := threading.NewTaskRunner(5) + q.Schedule(func() { + panic("hello") + }) + select {} +} diff --git a/example/json/acceptance/main.go b/example/json/acceptance/main.go new file mode 100644 index 00000000..ffe2f290 --- /dev/null +++ b/example/json/acceptance/main.go @@ -0,0 +1,41 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + + jsonx "github.com/segmentio/encoding/json" +) + +type A struct { + AA string `json:"aa,omitempty"` +} + +type B struct { + *A + BB string `json:"bb,omitempty"` +} + +func main() { + var b B + b.BB = "b" + b.A = new(A) + b.A.AA = "" + + fmt.Println("github.com/segmentio/encoding/json") + data, err := jsonx.Marshal(b) + if err != nil { + log.Fatal(err) + } + fmt.Println(string(data)) + fmt.Println() + + fmt.Println("encoding/json") + data, err = json.Marshal(b) + if err != nil { + log.Fatal(err) + } + + fmt.Println(string(data)) +} diff --git a/example/json/bench_test.go b/example/json/bench_test.go new file mode 100644 index 00000000..e6fd2b0e --- /dev/null +++ b/example/json/bench_test.go @@ -0,0 +1,74 @@ +package testjson + +import ( + "encoding/json" + "testing" + + jsoniter "github.com/json-iterator/go" + segment "github.com/segmentio/encoding/json" +) + +const input = `{"@timestamp":"2020-02-12T14:02:10.849Z","@metadata":{"beat":"filebeat","type":"doc","version":"6.1.1","topic":"k8slog"},"index":"k8slog","offset":908739,"stream":"stdout","topic":"k8slog","k8s_container_name":"shield-rpc","k8s_pod_namespace":"xx-xiaoheiban","stage":"gray","prospector":{"type":"log"},"k8s_node_name":"cn-hangzhou.i-bp15w8irul9hmm3l9mxz","beat":{"name":"log-pilot-7s6qf","hostname":"log-pilot-7s6qf","version":"6.1.1"},"source":"/host/var/lib/docker/containers/4e6dca76f3e38fb8b39631e9bb3a19f9150cc82b1dab84f71d4622a08db20bfb/4e6dca76f3e38fb8b39631e9bb3a19f9150cc82b1dab84f71d4622a08db20bfb-json.log","level":"info","duration":"39.425µs","content":"172.25.5.167:49976 - /remoteshield.Filter/Filter - {\"sentence\":\"王XX2月12日作业\"}","k8s_pod":"shield-rpc-57c9dc6797-55skf","docker_container":"k8s_shield-rpc_shield-rpc-57c9dc6797-55skf_xx-xiaoheiban_a8341ba0-30ee-11ea-8ac4-00163e0fb3ef_0"}` + +func BenchmarkStdJsonMarshal(b *testing.B) { + m := make(map[string]interface{}) + if err := json.Unmarshal([]byte(input), &m); err != nil { + b.FailNow() + } + for i := 0; i < b.N; i++ { + if _, err := json.Marshal(m); err != nil { + b.FailNow() + } + } +} + +func BenchmarkJsonIteratorMarshal(b *testing.B) { + m := make(map[string]interface{}) + if err := jsoniter.Unmarshal([]byte(input), &m); err != nil { + b.FailNow() + } + for i := 0; i < b.N; i++ { + if _, err := jsoniter.Marshal(m); err != nil { + b.FailNow() + } + } +} + +func BenchmarkSegmentioMarshal(b *testing.B) { + m := make(map[string]interface{}) + if err := segment.Unmarshal([]byte(input), &m); err != nil { + b.FailNow() + } + for i := 0; i 
< b.N; i++ { + if _, err := jsoniter.Marshal(m); err != nil { + b.FailNow() + } + } +} + +func BenchmarkStdJsonUnmarshal(b *testing.B) { + for i := 0; i < b.N; i++ { + m := make(map[string]interface{}) + if err := json.Unmarshal([]byte(input), &m); err != nil { + b.FailNow() + } + } +} + +func BenchmarkJsonIteratorUnmarshal(b *testing.B) { + for i := 0; i < b.N; i++ { + m := make(map[string]interface{}) + if err := jsoniter.Unmarshal([]byte(input), &m); err != nil { + b.FailNow() + } + } +} + +func BenchmarkSegmentioUnmarshal(b *testing.B) { + for i := 0; i < b.N; i++ { + m := make(map[string]interface{}) + if err := segment.Unmarshal([]byte(input), &m); err != nil { + b.FailNow() + } + } +} diff --git a/example/json/testmarshal_test.go b/example/json/testmarshal_test.go new file mode 100644 index 00000000..f3199aac --- /dev/null +++ b/example/json/testmarshal_test.go @@ -0,0 +1,31 @@ +package testjson + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMarshal(t *testing.T) { + type A struct { + A string `json:"a"` + AA string `json:"aa"` + } + type B struct { + A // can't be A A, or A `json...` + B string `json:"b"` + } + type C struct { + A `json:"a"` + C string `json:"c"` + } + a := A{A: "a", AA: "aa"} + b := B{A: a, B: "b"} + c := C{A: a, C: "c"} + + bstr, _ := json.Marshal(b) + cstr, _ := json.Marshal(c) + assert.Equal(t, `{"a":"a","aa":"aa","b":"b"}`, string(bstr)) + assert.Equal(t, `{"a":{"a":"a","aa":"aa"},"c":"c"}`, string(cstr)) +} diff --git a/example/jwt/user/user.go b/example/jwt/user/user.go new file mode 100644 index 00000000..d36caac0 --- /dev/null +++ b/example/jwt/user/user.go @@ -0,0 +1,257 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "strings" + "time" + + "zero/core/conf" + "zero/core/httpx" + "zero/ngin" + + "github.com/dgrijalva/jwt-go" + "github.com/dgrijalva/jwt-go/request" +) + +const jwtUserField = "user" + +type ( + Config struct { + ngin.NgConf + AccessSecret string + AccessExpire int64 `json:",default=1209600"` // 2 weeks + RefreshSecret string + RefreshExpire int64 `json:",default=2419200"` // 4 weeks + RefreshAfter int64 `json:",default=604800"` // 1 week + } + + TokenOptions struct { + AccessSecret string + AccessExpire int64 + RefreshSecret string + RefreshExpire int64 + RefreshAfter int64 + Fields map[string]interface{} + } + + Tokens struct { + // Access token to access the apis + AccessToken string `json:"access_token"` + // Access token expire time, generated like: time.Now().Add(time.Day*14).Unix() + AccessExpire int64 `json:"access_expire"` + // Refresh token, use this to refresh the token + RefreshToken string `json:"refresh_token"` + // Refresh token expire time, generated like: time.Now().Add(time.Month).Unix() + RefreshExpire int64 `json:"refresh_expire"` + // Recommended time to refresh the access token + RefreshAfter int64 `json:"refresh_after"` + } + + UserCredentials struct { + Username string `json:"username"` + Password string `json:"password"` + } + + User struct { + ID int `json:"id"` + Name string `json:"name"` + Username string `json:"username"` + Password string `json:"password"` + } + + Response struct { + Data string `json:"data"` + } + + Token struct { + Token string `json:"token"` + } + + AuthRequest struct { + User string `json:"u"` + } +) + +func main() { + var c Config + conf.MustLoad("user.json", &c) + + engine, err := ngin.NewEngine(c.NgConf) + if err != nil { + log.Fatal(err) + } + defer engine.Stop() + + engine.AddRoute(ngin.Route{ + Method: 
http.MethodPost, + Path: "/login", + Handler: LoginHandler(c), + }) + engine.AddRoute(ngin.Route{ + Method: http.MethodGet, + Path: "/resource", + Handler: ProtectedHandler, + }, ngin.WithJwt(c.AccessSecret)) + engine.AddRoute(ngin.Route{ + Method: http.MethodPost, + Path: "/refresh", + Handler: RefreshHandler(c), + }, ngin.WithJwt(c.RefreshSecret)) + + fmt.Println("Now listening...") + engine.Start() +} + +func RefreshHandler(c Config) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var authReq AuthRequest + + if err := httpx.Parse(r, &authReq); err != nil { + w.WriteHeader(http.StatusBadRequest) + fmt.Println(err) + return + } + + token, err := request.ParseFromRequest(r, request.AuthorizationHeaderExtractor, + func(token *jwt.Token) (interface{}, error) { + return []byte(c.RefreshSecret), nil + }) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + fmt.Println("Unauthorized access to this resource") + return + } + + if !token.Valid { + w.WriteHeader(http.StatusUnauthorized) + fmt.Println("Token is not valid") + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + w.WriteHeader(http.StatusBadRequest) + fmt.Println("not a valid jwt.MapClaims") + return + } + + user, ok := claims[jwtUserField] + if !ok { + w.WriteHeader(http.StatusBadRequest) + fmt.Println("no user info in fresh token") + return + } + + userStr, ok := user.(string) + if !ok || authReq.User != userStr { + w.WriteHeader(http.StatusBadRequest) + fmt.Println("user info not match in query and fresh token") + return + } + + respond(w, c, userStr) + } +} + +func ProtectedHandler(w http.ResponseWriter, r *http.Request) { + response := Response{"Gained access to protected resource"} + JsonResponse(response, w) +} + +func LoginHandler(c Config) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var user UserCredentials + + if err := httpx.Parse(r, &user); err != nil { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, "Error in request") + return + } + + if strings.ToLower(user.Username) != "someone" { + if user.Password != "p@ssword" { + w.WriteHeader(http.StatusForbidden) + fmt.Println("Error logging in") + fmt.Fprint(w, "Invalid credentials") + return + } + } + + respond(w, c, user.Username) + } +} + +func JsonResponse(response interface{}, w http.ResponseWriter) { + content, err := json.Marshal(response) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + w.Write(content) +} + +type () + +func buildTokens(opt TokenOptions) (Tokens, error) { + var tokens Tokens + + accessToken, err := genToken(opt.AccessSecret, opt.Fields, opt.AccessExpire) + if err != nil { + return tokens, err + } + + refreshToken, err := genToken(opt.RefreshSecret, opt.Fields, opt.RefreshExpire) + if err != nil { + return tokens, err + } + + now := time.Now().Unix() + tokens.AccessToken = accessToken + tokens.AccessExpire = now + opt.AccessExpire + tokens.RefreshAfter = now + opt.RefreshAfter + tokens.RefreshToken = refreshToken + tokens.RefreshExpire = now + opt.RefreshExpire + + return tokens, nil +} + +func genToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) { + now := time.Now().Unix() + claims := make(jwt.MapClaims) + claims["exp"] = now + seconds + claims["iat"] = now + for k, v := range payloads { + claims[k] = v + } + + token := jwt.New(jwt.SigningMethodHS256) + token.Claims = claims + + return 
token.SignedString([]byte(secretKey)) +} + +func respond(w http.ResponseWriter, c Config, user string) { + tokens, err := buildTokens(TokenOptions{ + AccessSecret: c.AccessSecret, + AccessExpire: c.AccessExpire, + RefreshSecret: c.RefreshSecret, + RefreshExpire: c.RefreshExpire, + RefreshAfter: c.RefreshAfter, + Fields: map[string]interface{}{ + jwtUserField: user, + }, + }) + if err != nil { + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Println(err) + return + } + + httpx.OkJson(w, tokens) +} diff --git a/example/jwt/user/user.json b/example/jwt/user/user.json new file mode 100644 index 00000000..81df6d81 --- /dev/null +++ b/example/jwt/user/user.json @@ -0,0 +1,10 @@ +{ + "Name": "example.user", + "Host": "localhost", + "Port": 8080, + "AccessSecret": "B63F477D-BBA3-4E52-96D3-C0034C27694A", + "AccessExpire": 1800, + "RefreshSecret": "14F17379-EB8F-411B-8F12-6929002DCA76", + "RefreshExpire": 3600, + "RefreshAfter": 600 +} diff --git a/example/kmq/consumer/config.json b/example/kmq/consumer/config.json new file mode 100644 index 00000000..e5d36c2c --- /dev/null +++ b/example/kmq/consumer/config.json @@ -0,0 +1,12 @@ +{ + "Name": "kmq", + "Brokers": [ + "172.16.56.64:19092", + "172.16.56.65:19092", + "172.16.56.66:19092" + ], + "Group": "adhoc", + "Topic": "kevin", + "Offset": "first", + "NumProducers": 1 +} \ No newline at end of file diff --git a/example/kmq/consumer/queue.go b/example/kmq/consumer/queue.go new file mode 100644 index 00000000..fed25418 --- /dev/null +++ b/example/kmq/consumer/queue.go @@ -0,0 +1,20 @@ +package main + +import ( + "fmt" + + "zero/core/conf" + "zero/kq" +) + +func main() { + var c kq.KqConf + conf.MustLoad("config.json", &c) + + q := kq.MustNewQueue(c, kq.WithHandle(func(k, v string) error { + fmt.Printf("=> %s\n", v) + return nil + })) + defer q.Stop() + q.Start() +} diff --git a/example/kmq/producer/produce.go b/example/kmq/producer/produce.go new file mode 100644 index 00000000..006883b2 --- /dev/null +++ b/example/kmq/producer/produce.go @@ -0,0 +1,51 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "math/rand" + "strconv" + "time" + + "zero/core/cmdline" + "zero/kq" +) + +type message struct { + Key string `json:"key"` + Value string `json:"value"` + Payload string `json:"message"` +} + +func main() { + pusher := kq.NewPusher([]string{ + "172.16.56.64:19092", + "172.16.56.65:19092", + "172.16.56.66:19092", + }, "kevin") + + ticker := time.NewTicker(time.Millisecond) + for round := 0; round < 3; round++ { + select { + case <-ticker.C: + count := rand.Intn(100) + m := message{ + Key: strconv.FormatInt(time.Now().UnixNano(), 10), + Value: fmt.Sprintf("%d,%d", round, count), + Payload: fmt.Sprintf("%d,%d", round, count), + } + body, err := json.Marshal(m) + if err != nil { + log.Fatal(err) + } + + fmt.Println(string(body)) + if err := pusher.Push(string(body)); err != nil { + log.Fatal(err) + } + } + } + + cmdline.EnterToContinue() +} diff --git a/example/limit/period/periodlimit.go b/example/limit/period/periodlimit.go new file mode 100644 index 00000000..69a2af7f --- /dev/null +++ b/example/limit/period/periodlimit.go @@ -0,0 +1,66 @@ +package main + +import ( + "flag" + "fmt" + "log" + "runtime" + "strconv" + "sync" + "sync/atomic" + "time" + + "zero/core/limit" + "zero/core/stores/redis" +) + +const seconds = 5 + +var ( + rdx = flag.String("redis", "localhost:6379", "the redis, default localhost:6379") + rdxType = flag.String("redisType", "node", "the redis type, default node") + rdxPass = flag.String("redisPass", "", "the redis 
password") + rdxKey = flag.String("redisKey", "rate", "the redis key, default rate") + threads = flag.Int("threads", runtime.NumCPU(), "the concurrent threads, default to cores") +) + +func main() { + flag.Parse() + + store := redis.NewRedis(*rdx, *rdxType, *rdxPass) + fmt.Println(store.Ping()) + lmt := limit.NewPeriodLimit(seconds, 5, store, *rdxKey) + timer := time.NewTimer(time.Second * seconds) + quit := make(chan struct{}) + defer timer.Stop() + go func() { + <-timer.C + close(quit) + }() + + var allowed, denied int32 + var wait sync.WaitGroup + for i := 0; i < *threads; i++ { + wait.Add(1) + go func() { + for { + select { + case <-quit: + wait.Done() + return + default: + if v, err := lmt.Take(strconv.FormatInt(int64(i), 10)); err == nil && v == limit.Allowed { + atomic.AddInt32(&allowed, 1) + } else if err != nil { + log.Fatal(err) + } else { + atomic.AddInt32(&denied, 1) + } + } + } + }() + } + + wait.Wait() + fmt.Printf("allowed: %d, denied: %d, qps: %d\n", allowed, denied, (allowed+denied)/seconds) +} diff --git a/example/limit/token/tokenlimit.go b/example/limit/token/tokenlimit.go new file mode 100644 index 00000000..a1059fdd --- /dev/null +++ b/example/limit/token/tokenlimit.go @@ -0,0 +1,66 @@ +package main + +import ( + "flag" + "fmt" + "runtime" + "sync" + "sync/atomic" + "time" + + "zero/core/limit" + "zero/core/stores/redis" +) + +const ( + burst = 100 + rate = 100 + seconds = 5 +) + +var ( + rdx = flag.String("redis", "localhost:6379", "the redis, default localhost:6379") + rdxType = flag.String("redisType", "node", "the redis type, default node") + rdxKey = flag.String("redisKey", "rate", "the redis key, default rate") + rdxPass = flag.String("redisPass", "", "the redis password") + threads = flag.Int("threads", runtime.NumCPU(), "the concurrent threads, default to cores") +) + +func main() { + flag.Parse() + + store := redis.NewRedis(*rdx, *rdxType, *rdxPass) + fmt.Println(store.Ping()) + limit := limit.NewTokenLimiter(rate, burst, store, *rdxKey) + timer := time.NewTimer(time.Second * seconds) + quit := make(chan struct{}) + defer timer.Stop() + go func() { + <-timer.C + close(quit) + }() + + var allowed, denied int32 + var wait sync.WaitGroup + for i := 0; i < *threads; i++ { + wait.Add(1) + go func() { + for { + select { + case <-quit: + wait.Done() + return + default: + if limit.Allow() { + atomic.AddInt32(&allowed, 1) + } else { + atomic.AddInt32(&denied, 1) + } + } + } + }() + } + + wait.Wait() + fmt.Printf("allowed: %d, denied: %d, qps: %d\n", allowed, denied, (allowed+denied)/seconds) +} diff --git a/example/load/main.go b/example/load/main.go new file mode 100644 index 00000000..445f5ffd --- /dev/null +++ b/example/load/main.go @@ -0,0 +1,149 @@ +package main + +import ( + "flag" + "fmt" + "io" + "math" + "math/rand" + "os" + "sync" + "sync/atomic" + "time" + + "zero/core/collection" + "zero/core/executors" + "zero/core/lang" + "zero/core/syncx" + + "gopkg.in/cheggaaa/pb.v1" +) + +const ( + beta = 0.9 + total = 400 + interval = time.Second + factor = 5 +) + +var ( + seconds = flag.Int("d", 400, "duration to go") + flying uint64 + avgFlyingAggressive float64 + aggressiveLock syncx.SpinLock + avgFlyingLazy float64 + lazyLock syncx.SpinLock + avgFlyingBoth float64 + bothLock syncx.SpinLock + lessWriter *executors.LessExecutor + passCounter = collection.NewRollingWindow(50, time.Millisecond*100) + rtCounter = collection.NewRollingWindow(50, time.Millisecond*100) + index int32 +) + +func main() { + flag.Parse() + + // only log 100 records + lessWriter = 
executors.NewLessExecutor(interval * total / 100) + + fp, err := os.Create("result.csv") + lang.Must(err) + defer fp.Close() + fmt.Fprintln(fp, "second,maxFlight,flying,agressiveAvgFlying,lazyAvgFlying,bothAvgFlying") + + ticker := time.NewTicker(interval) + defer ticker.Stop() + bar := pb.New(*seconds * 2).Start() + var waitGroup sync.WaitGroup + batchRequests := func(i int) { + <-ticker.C + requests := (i + 1) * factor + func() { + it := time.NewTicker(interval / time.Duration(requests)) + defer it.Stop() + for j := 0; j < requests; j++ { + <-it.C + waitGroup.Add(1) + go func() { + issueRequest(fp, atomic.AddInt32(&index, 1)) + waitGroup.Done() + }() + } + bar.Increment() + }() + } + for i := 0; i < *seconds; i++ { + batchRequests(i) + } + for i := *seconds; i > 0; i-- { + batchRequests(i) + } + bar.Finish() + waitGroup.Wait() +} + +func issueRequest(writer io.Writer, idx int32) { + v := atomic.AddUint64(&flying, 1) + aggressiveLock.Lock() + af := avgFlyingAggressive*beta + float64(v)*(1-beta) + avgFlyingAggressive = af + aggressiveLock.Unlock() + bothLock.Lock() + bf := avgFlyingBoth*beta + float64(v)*(1-beta) + avgFlyingBoth = bf + bothLock.Unlock() + duration := time.Millisecond * time.Duration(rand.Int63n(10)+1) + job(duration) + passCounter.Add(1) + rtCounter.Add(float64(duration) / float64(time.Millisecond)) + v1 := atomic.AddUint64(&flying, ^uint64(0)) + lazyLock.Lock() + lf := avgFlyingLazy*beta + float64(v1)*(1-beta) + avgFlyingLazy = lf + lazyLock.Unlock() + bothLock.Lock() + bf = avgFlyingBoth*beta + float64(v1)*(1-beta) + avgFlyingBoth = bf + bothLock.Unlock() + lessWriter.DoOrDiscard(func() { + fmt.Fprintf(writer, "%d,%d,%d,%.2f,%.2f,%.2f\n", idx, maxFlight(), v, af, lf, bf) + }) +} + +func job(duration time.Duration) { + time.Sleep(duration) +} + +func maxFlight() int64 { + return int64(math.Max(1, float64(maxPass()*10)*(minRt()/1e3))) +} + +func maxPass() int64 { + var result float64 = 1 + + passCounter.Reduce(func(b *collection.Bucket) { + if b.Sum > result { + result = b.Sum + } + }) + + return int64(result) +} + +func minRt() float64 { + var result float64 = 1000 + + rtCounter.Reduce(func(b *collection.Bucket) { + if b.Count <= 0 { + return + } + + avg := math.Round(b.Sum / float64(b.Count)) + if avg < result { + result = avg + } + }) + + return result +} diff --git a/example/load/plot.py b/example/load/plot.py new file mode 100644 index 00000000..c46a4d1c --- /dev/null +++ b/example/load/plot.py @@ -0,0 +1,14 @@ +import click +import pandas as pd +import matplotlib.pyplot as plt + +@click.command() +@click.option("--csv", default="result.csv") +def main(csv): + df = pd.read_csv(csv, index_col="second") + df.drop(["agressiveAvgFlying", "bothAvgFlying"], axis=1, inplace=True) + df.plot() + plt.show() + +if __name__ == "__main__": + main() diff --git a/example/load/simulate/client/main.go b/example/load/simulate/client/main.go new file mode 100644 index 00000000..025e5c6e --- /dev/null +++ b/example/load/simulate/client/main.go @@ -0,0 +1,95 @@ +package main + +import ( + "errors" + "flag" + "fmt" + "io" + "net/http" + "os" + "sync/atomic" + "time" + + "zero/core/fx" + "zero/core/lang" +) + +var ( + errServiceUnavailable = errors.New("service unavailable") + total int64 + pass int64 + fail int64 + drop int64 + seconds int64 = 1 +) + +func main() { + flag.Parse() + + fp, err := os.Create("result.csv") + lang.Must(err) + defer fp.Close() + fmt.Fprintln(fp, "seconds,total,pass,fail,drop") + + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for 
range ticker.C { + reset(fp) + } + }() + + for i := 0; ; i++ { + it := time.NewTicker(time.Second / time.Duration(atomic.LoadInt64(&seconds))) + func() { + for j := 0; j < int(seconds); j++ { + <-it.C + go issueRequest() + } + }() + it.Stop() + + cur := atomic.AddInt64(&seconds, 1) + fmt.Println(cur) + } +} + +func issueRequest() { + atomic.AddInt64(&total, 1) + err := fx.DoWithTimeout(func() error { + return job() + }, time.Second) + switch err { + case nil: + atomic.AddInt64(&pass, 1) + case errServiceUnavailable: + atomic.AddInt64(&drop, 1) + default: + atomic.AddInt64(&fail, 1) + } +} + +func job() error { + resp, err := http.Get("http://localhost:3333/") + if err != nil { + return err + } + + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusOK: + return nil + default: + return errServiceUnavailable + } +} + +func reset(writer io.Writer) { + fmt.Fprintf(writer, "%d,%d,%d,%d,%d\n", + atomic.LoadInt64(&seconds), + atomic.SwapInt64(&total, 0), + atomic.SwapInt64(&pass, 0), + atomic.SwapInt64(&fail, 0), + atomic.SwapInt64(&drop, 0), + ) +} diff --git a/example/load/simulate/client/plot.py b/example/load/simulate/client/plot.py new file mode 100644 index 00000000..8b90e942 --- /dev/null +++ b/example/load/simulate/client/plot.py @@ -0,0 +1,13 @@ +import click +import pandas as pd +import matplotlib.pyplot as plt + +@click.command() +@click.option("--csv", default="result.csv") +def main(csv): + df = pd.read_csv(csv, index_col="seconds") + df.plot() + plt.show() + +if __name__ == "__main__": + main() diff --git a/example/load/simulate/cpu/Dockerfile b/example/load/simulate/cpu/Dockerfile new file mode 100644 index 00000000..a03fb8f7 --- /dev/null +++ b/example/load/simulate/cpu/Dockerfile @@ -0,0 +1,26 @@ +FROM golang:alpine AS builder + +LABEL stage=gobuilder + +ENV CGO_ENABLED 0 +ENV GOOS linux +ENV GOPROXY https://goproxy.cn,direct + +WORKDIR $GOPATH/src/zero +COPY . . +RUN go build -ldflags="-s -w" -o /app/main example/load/simulate/cpu/main.go + + +FROM alpine + +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai + +RUN apk add git +RUN go get github.com/vikyd/go-cpu-load + +RUN mkdir /app +COPY --from=builder /app/main /app/main + +WORKDIR /app +CMD ["/app/main"] diff --git a/example/load/simulate/cpu/Makefile b/example/load/simulate/cpu/Makefile new file mode 100644 index 00000000..3df40e86 --- /dev/null +++ b/example/load/simulate/cpu/Makefile @@ -0,0 +1,13 @@ +version := v1 + +build: + cd $(GOPATH)/src/zero && docker build -t registry.cn-hangzhou.aliyuncs.com/xapp/shedding:$(version) . -f example/load/simulate/cpu/Dockerfile + +push: build + docker push registry.cn-hangzhou.aliyuncs.com/xapp/shedding:$(version) + +deploy: push + kubectl apply -f shedding.yaml + +clean: + kubectl delete -f shedding.yaml diff --git a/example/load/simulate/cpu/cpu-accuracy.md b/example/load/simulate/cpu/cpu-accuracy.md new file mode 100644 index 00000000..0a7ac96b --- /dev/null +++ b/example/load/simulate/cpu/cpu-accuracy.md @@ -0,0 +1,28 @@ +# cpu监控准确度测试 + +1. 启动测试pod + + `make deploy` + +2. 通过`kubectl get po -n adhoc`确认`sheeding` pod已经成功运行,通过如下命令进入pod + + `kubectl exec -it -n adhoc shedding -- sh` + +3. 启动负载 + + `/app # go-cpu-load -p 50 -c 1` + + 默认`go-cpu-load`是对每个core加上负载的,所以测试里指定了`1000m`,等同于1 core,我们指定`-c 1`让测试更具有可读性 + + `-p`可以多换几个值测试 + +4. 
验证测试准确性 + + `kubectl logs -f -n adhoc shedding` + + 可以看到日志中的`CPU`报告,`1000m`表示`100%`,如果看到`500m`则表示`50%`,每分钟输出一次 + + `watch -n 5 kubectl top pod -n adhoc` + + 可以看到`kubectl`报告的`CPU`使用率,两者进行对比,即可知道是否准确 + diff --git a/example/load/simulate/cpu/main.go b/example/load/simulate/cpu/main.go new file mode 100644 index 00000000..f2a7f4f3 --- /dev/null +++ b/example/load/simulate/cpu/main.go @@ -0,0 +1,7 @@ +package main + +import _ "zero/core/stat" + +func main() { + select {} +} diff --git a/example/load/simulate/cpu/shedding.yaml b/example/load/simulate/cpu/shedding.yaml new file mode 100644 index 00000000..8089e2f5 --- /dev/null +++ b/example/load/simulate/cpu/shedding.yaml @@ -0,0 +1,17 @@ +apiVersion: v1 +kind: Pod +metadata: + name: shedding + namespace: adhoc +spec: + containers: + - name: shedding + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/shedding:v1 + imagePullPolicy: Always + resources: + requests: + cpu: 1000m + limits: + cpu: 1000m + imagePullSecrets: + - name: aliyun diff --git a/example/load/simulate/server/server.go b/example/load/simulate/server/server.go new file mode 100644 index 00000000..0a2c4776 --- /dev/null +++ b/example/load/simulate/server/server.go @@ -0,0 +1,71 @@ +package main + +import ( + "fmt" + "net/http" + "runtime" + "time" + + "zero/core/fx" + "zero/core/logx" + "zero/core/service" + "zero/core/stat" + "zero/ngin" +) + +const duration = time.Millisecond + +func main() { + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for range ticker.C { + fmt.Printf("cpu: %d\n", stat.CpuUsage()) + } + }() + + logx.Disable() + engine := ngin.MustNewEngine(ngin.NgConf{ + ServiceConf: service.ServiceConf{ + Log: logx.LogConf{ + Mode: "console", + }, + }, + Host: "0.0.0.0", + Port: 3333, + CpuThreshold: 800, + }) + defer engine.Stop() + engine.AddRoute(ngin.Route{ + Method: http.MethodGet, + Path: "/", + Handler: func(w http.ResponseWriter, r *http.Request) { + if err := fx.DoWithTimeout(func() error { + job(duration) + return nil + }, time.Millisecond*100); err != nil { + w.WriteHeader(http.StatusServiceUnavailable) + } + }, + }) + engine.Start() +} + +func job(duration time.Duration) { + done := make(chan int) + + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for { + select { + case <-done: + return + default: + } + } + }() + } + + time.Sleep(duration) + close(done) +} diff --git a/example/logging/logging.go b/example/logging/logging.go new file mode 100644 index 00000000..380d7275 --- /dev/null +++ b/example/logging/logging.go @@ -0,0 +1,30 @@ +package main + +import ( + "time" + + "zero/core/logx" +) + +func foo() { + logx.WithDuration(time.Second).Error("world") +} + +func main() { + c := logx.LogConf{ + Mode: "console", + Path: "logs", + } + logx.MustSetup(c) + defer logx.Close() + logx.Info("info") + logx.Error("error") + logx.ErrorStack("hello") + logx.Errorf("%s and %s", "hello", "world") + logx.Severef("%s severe %s", "hello", "world") + logx.Slowf("%s slow %s", "hello", "world") + logx.Statf("%s stat %s", "hello", "world") + logx.WithDuration(time.Minute + time.Second).Info("hello") + logx.WithDuration(time.Minute + time.Second).Error("hello") + foo() +} diff --git a/example/logging/redirector/main.go b/example/logging/redirector/main.go new file mode 100644 index 00000000..57074011 --- /dev/null +++ b/example/logging/redirector/main.go @@ -0,0 +1,20 @@ +package main + +import ( + "fmt" + "time" + + "zero/core/logx" +) + +func main() { + logx.MustSetup(logx.LogConf{ + Mode: "console", + }) + logx.CollectSysLog() + + line := 
"asdkg" + logx.Info(line) + fmt.Print(line) + time.Sleep(time.Second) +} diff --git a/example/mapreduce/countfunc/countfunc.go b/example/mapreduce/countfunc/countfunc.go new file mode 100644 index 00000000..beeff375 --- /dev/null +++ b/example/mapreduce/countfunc/countfunc.go @@ -0,0 +1,128 @@ +package main + +import ( + "bufio" + "errors" + "flag" + "fmt" + "io" + "log" + "os" + "path" + "path/filepath" + "strings" + "sync/atomic" + "time" + + "zero/core/mapreduce" + + "github.com/google/gops/agent" +) + +var ( + dir = flag.String("d", "", "dir to enumerate") + stopOnFile = flag.String("s", "", "stop when got file") + maxFiles = flag.Int("m", 0, "at most files to process") + mode = flag.String("mode", "", "simulate mode, can be return|panic") + count uint32 +) + +func enumerateLines(filename string) chan string { + output := make(chan string) + go func() { + file, err := os.Open(filename) + if err != nil { + return + } + defer file.Close() + + reader := bufio.NewReader(file) + for { + line, err := reader.ReadString('\n') + if err == io.EOF { + break + } + + if !strings.HasPrefix(line, "#") { + output <- line + } + } + close(output) + }() + return output +} + +func mapper(filename interface{}, writer mapreduce.Writer, cancel func(error)) { + if len(*stopOnFile) > 0 && path.Base(filename.(string)) == *stopOnFile { + fmt.Printf("Stop on file: %s\n", *stopOnFile) + cancel(errors.New("stop on file")) + return + } + + var result int + for line := range enumerateLines(filename.(string)) { + if strings.HasPrefix(strings.TrimSpace(line), "func") { + result++ + } + } + + switch *mode { + case "return": + if atomic.AddUint32(&count, 1)%10 == 0 { + return + } + case "panic": + if atomic.AddUint32(&count, 1)%10 == 0 { + panic("wow") + } + } + + writer.Write(result) +} + +func reducer(input <-chan interface{}, writer mapreduce.Writer, cancel func(error)) { + var result int + + for count := range input { + v := count.(int) + if *maxFiles > 0 && result >= *maxFiles { + fmt.Printf("Reached max files: %d\n", *maxFiles) + cancel(errors.New("max files reached")) + return + } + result += v + } + + writer.Write(result) +} + +func main() { + if err := agent.Listen(agent.Options{}); err != nil { + log.Fatal(err) + } + + flag.Parse() + + if len(*dir) == 0 { + flag.Usage() + } + + fmt.Println("Processing, please wait...") + + start := time.Now() + result, err := mapreduce.MapReduce(func(source chan<- interface{}) { + filepath.Walk(*dir, func(fpath string, f os.FileInfo, err error) error { + if !f.IsDir() && path.Ext(fpath) == ".go" { + source <- fpath + } + return nil + }) + }, mapper, reducer) + if err != nil { + fmt.Println(err) + } else { + fmt.Println(result) + fmt.Println("Elapsed:", time.Since(start)) + fmt.Println("Done") + } +} diff --git a/example/mapreduce/finishvoid/finishvoid.go b/example/mapreduce/finishvoid/finishvoid.go new file mode 100644 index 00000000..854b2c01 --- /dev/null +++ b/example/mapreduce/finishvoid/finishvoid.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "time" + + "zero/core/mapreduce" + "zero/core/timex" +) + +func main() { + start := timex.Now() + + mapreduce.FinishVoid(func() { + time.Sleep(time.Second) + }, func() { + time.Sleep(time.Second * 5) + }, func() { + time.Sleep(time.Second * 10) + }, func() { + time.Sleep(time.Second * 6) + }, func() { + if err := mapreduce.Finish(func() error { + time.Sleep(time.Second) + return nil + }, func() error { + time.Sleep(time.Second * 10) + return nil + }); err != nil { + fmt.Println(err) + } + }) + + 
fmt.Println(timex.Since(start)) +} diff --git a/example/mapreduce/flatmap/flatmap.go b/example/mapreduce/flatmap/flatmap.go new file mode 100644 index 00000000..afed5bf5 --- /dev/null +++ b/example/mapreduce/flatmap/flatmap.go @@ -0,0 +1,31 @@ +package main + +import ( + "fmt" + + "zero/core/mapreduce" +) + +var ( + persons = []string{"john", "mary", "alice", "bob"} + friends = map[string][]string{ + "john": {"harry", "hermione", "ron"}, + "mary": {"sam", "frodo"}, + "alice": {}, + "bob": {"jamie", "tyrion", "cersei"}, + } +) + +func main() { + var allFriends []string + for v := range mapreduce.Map(func(source chan<- interface{}) { + for _, each := range persons { + source <- each + } + }, func(item interface{}, writer mapreduce.Writer) { + writer.Write(friends[item.(string)]) + }, mapreduce.WithWorkers(100)) { + allFriends = append(allFriends, v.([]string)...) + } + fmt.Println(allFriends) +} diff --git a/example/mapreduce/goroutineleak/leak.go b/example/mapreduce/goroutineleak/leak.go new file mode 100644 index 00000000..0a6e43dc --- /dev/null +++ b/example/mapreduce/goroutineleak/leak.go @@ -0,0 +1,69 @@ +package main + +import ( + "errors" + "fmt" + "os" + "runtime" + "runtime/pprof" + "time" + + "zero/core/lang" + "zero/core/logx" + "zero/core/mapreduce" + "zero/core/proc" +) + +func dumpGoroutines() { + dumpFile := "goroutines.dump" + logx.Infof("Got dump goroutine signal, printing goroutine profile to %s", dumpFile) + + if f, err := os.Create(dumpFile); err != nil { + logx.Errorf("Failed to dump goroutine profile, error: %v", err) + } else { + defer f.Close() + pprof.Lookup("goroutine").WriteTo(f, 2) + } +} + +func main() { + profiler := proc.StartProfile() + defer profiler.Stop() + + done := make(chan lang.PlaceholderType) + go func() { + for { + time.Sleep(time.Second) + fmt.Println(runtime.NumGoroutine()) + } + }() + go func() { + time.Sleep(time.Minute) + dumpGoroutines() + close(done) + }() + for { + select { + case <-done: + return + default: + mapreduce.MapReduce(func(source chan<- interface{}) { + for i := 0; i < 100; i++ { + source <- i + } + }, func(item interface{}, writer mapreduce.Writer, cancel func(error)) { + if item.(int) == 40 { + cancel(errors.New("any")) + return + } + writer.Write(item) + }, func(pipe <-chan interface{}, writer mapreduce.Writer, cancel func(error)) { + list := make([]int, 0) + for p := range pipe { + list = append(list, p.(int)) + } + writer.Write(list) + }) + } + } +} diff --git a/example/mapreduce/irregular/irregular.go b/example/mapreduce/irregular/irregular.go new file mode 100644 index 00000000..3cafa6e1 --- /dev/null +++ b/example/mapreduce/irregular/irregular.go @@ -0,0 +1,28 @@ +package main + +import ( + "fmt" + "time" + + "zero/core/mapreduce" +) + +func main() { + mapreduce.MapReduceVoid(func(source chan<- interface{}) { + for i := 0; i < 10; i++ { + source <- i + } + }, func(item interface{}, writer mapreduce.Writer, cancel func(error)) { + i := item.(int) + if i == 0 { + time.Sleep(10 * time.Second) + } else { + time.Sleep(5 * time.Second) + } + writer.Write(i) + }, func(pipe <-chan interface{}, cancel func(error)) { + for i := range pipe { + fmt.Println(i) + } + }) +} diff --git a/example/mongo/time.go b/example/mongo/time.go new file mode 100644 index 00000000..84390886 --- /dev/null +++ b/example/mongo/time.go @@ -0,0 +1,37 @@ +package main + +import ( + "fmt" + "log" + "time" + + "zero/core/stores/mongo" + + "github.com/globalsign/mgo" + "github.com/globalsign/mgo/bson" +) + +type Roster struct { + Id bson.ObjectId `bson:"_id"` 
+ CreateTime time.Time `bson:"createTime"` + Classroom mgo.DBRef `bson:"classroom"` + Member mgo.DBRef `bson:"member"` + DisplayName string `bson:"displayName"` +} + +func main() { + model := mongo.MustNewModel("localhost:27017", "blackboard", "roster") + for i := 0; i < 1000; i++ { + session, err := model.TakeSession() + if err != nil { + log.Fatal(err) + } + + var roster Roster + filter := bson.M{"_id": bson.ObjectIdHex("587353380cf2d7273d183f9e")} + fmt.Println(model.GetCollection(session).Find(filter).One(&roster)) + model.PutSession(session) + } + + time.Sleep(time.Hour) +} diff --git a/example/periodicalexecutor/pe.go b/example/periodicalexecutor/pe.go new file mode 100644 index 00000000..b4146cdb --- /dev/null +++ b/example/periodicalexecutor/pe.go @@ -0,0 +1,17 @@ +package main + +import ( + "time" + + "zero/core/executors" +) + +func main() { + exeutor := executors.NewBulkExecutor(func(items []interface{}) { + println(len(items)) + }, executors.WithBulkTasks(10)) + for { + exeutor.Add(1) + time.Sleep(time.Millisecond * 90) + } +} diff --git a/example/pool/pool.go b/example/pool/pool.go new file mode 100644 index 00000000..85b76823 --- /dev/null +++ b/example/pool/pool.go @@ -0,0 +1,53 @@ +package main + +import ( + "bufio" + "fmt" + "os" + "sync" + "sync/atomic" + "time" + + "zero/core/lang" + "zero/core/syncx" +) + +func main() { + var count int32 + var consumed int32 + pool := syncx.NewPool(80, func() interface{} { + fmt.Printf("+ %d\n", atomic.AddInt32(&count, 1)) + return 1 + }, func(interface{}) { + fmt.Printf("- %d\n", atomic.AddInt32(&count, -1)) + }, syncx.WithMaxAge(time.Second)) + + var waitGroup sync.WaitGroup + quit := make(chan lang.PlaceholderType) + waitGroup.Add(100) + for i := 0; i < 100; i++ { + go func() { + defer func() { + waitGroup.Done() + fmt.Println("routine quit") + }() + + for { + select { + case <-quit: + return + default: + x := pool.Get().(int) + atomic.AddInt32(&consumed, 1) + pool.Put(x) + } + } + }() + } + + bufio.NewReader(os.Stdin).ReadLine() + close(quit) + fmt.Println("quitted") + waitGroup.Wait() + fmt.Printf("consumed %d\n", atomic.LoadInt32(&consumed)) +} diff --git a/example/queue/poll/poller.go b/example/queue/poll/poller.go new file mode 100644 index 00000000..44474cd8 --- /dev/null +++ b/example/queue/poll/poller.go @@ -0,0 +1,86 @@ +package main + +import ( + "flag" + "fmt" + "log" + "sync" + "time" + + "zero/core/discov" + "zero/core/lang" + "zero/core/logx" + "zero/core/service" + "zero/core/stores/redis" + "zero/rq" +) + +var ( + redisHost = flag.String("redis", "localhost:6379", "") + redisType = flag.String("type", "node", "") + redisKey = flag.String("key", "queue", "") + producers = flag.Int("producers", 1, "") + dropBefore = flag.Int64("drop", 0, "messages before seconds to drop") +) + +type Consumer struct { + lock sync.Mutex + resources map[string]interface{} +} + +func NewConsumer() *Consumer { + return &Consumer{ + resources: make(map[string]interface{}), + } +} + +func (c *Consumer) Consume(msg string) error { + fmt.Println("=>", msg) + c.lock.Lock() + defer c.lock.Unlock() + + c.resources[msg] = lang.Placeholder + + return nil +} + +func (c *Consumer) OnEvent(event interface{}) { + fmt.Printf("event: %+v\n", event) +} + +func main() { + flag.Parse() + + consumer := NewConsumer() + q, err := rq.NewMessageQueue(rq.RmqConf{ + ServiceConf: service.ServiceConf{ + Name: "queue", + Log: logx.LogConf{ + Path: "logs", + KeepDays: 3, + Compress: true, + }, + }, + Redis: redis.RedisKeyConf{ + RedisConf: redis.RedisConf{ + Host: 
*redisHost, + Type: *redisType, + }, + Key: *redisKey, + }, + Etcd: discov.EtcdConf{ + Hosts: []string{ + "localhost:2379", + }, + Key: "queue", + }, + DropBefore: *dropBefore, + NumProducers: *producers, + }, rq.WithHandler(consumer), rq.WithRenewId(time.Now().UnixNano())) + if err != nil { + log.Fatal(err) + } + defer q.Stop() + + q.Start() +} diff --git a/example/queue/push/pusher.go b/example/queue/push/pusher.go new file mode 100644 index 00000000..da224271 --- /dev/null +++ b/example/queue/push/pusher.go @@ -0,0 +1,31 @@ +package main + +import ( + "log" + "strconv" + "time" + + "zero/core/discov" + "zero/rq" + + "github.com/google/gops/agent" +) + +func main() { + if err := agent.Listen(agent.Options{}); err != nil { + log.Fatal(err) + } + + pusher, err := rq.NewPusher([]string{"localhost:2379"}, "queue", rq.WithConsistentStrategy( + func(msg string) (string, string, error) { + return msg, msg, nil + }, discov.BalanceWithId()), rq.WithServerSensitive()) + if err != nil { + log.Fatal(err) + } + + for i := 0; ; i++ { + pusher.Push(strconv.Itoa(i)) + time.Sleep(time.Second) + } +} diff --git a/example/redis/cluster.go b/example/redis/cluster.go new file mode 100644 index 00000000..5cf76859 --- /dev/null +++ b/example/redis/cluster.go @@ -0,0 +1,62 @@ +package main + +import ( + "flag" + "log" + + "zero/core/logx" + "zero/core/queue" + "zero/core/service" + "zero/core/stores/redis" + "zero/rq" +) + +var ( + host = flag.String("s", "10.24.232.63:7002", "server address") + mode = flag.String("m", "queue", "cluster test mode") +) + +type bridgeHandler struct { + pusher queue.QueuePusher +} + +func newBridgeHandler() rq.ConsumeHandler { + return bridgeHandler{} +} + +func (h bridgeHandler) Consume(str string) error { + logx.Info("=>", str) + return nil +} + +func main() { + flag.Parse() + + if *mode == "queue" { + mq, err := rq.NewMessageQueue(rq.RmqConf{ + ServiceConf: service.ServiceConf{ + Log: logx.LogConf{ + Path: "logs", + }, + }, + Redis: redis.RedisKeyConf{ + RedisConf: redis.RedisConf{ + Host: *host, + Type: "cluster", + }, + Key: "notexist", + }, + NumProducers: 1, + }, rq.WithHandler(newBridgeHandler())) + if err != nil { + log.Fatal(err) + } + defer mq.Stop() + + mq.Start() + } else { + rds := redis.NewRedis(*host, "cluster") + rds.Llen("notexist") + select {} + } +} diff --git a/example/rpc/client/direct/Dockerfile b/example/rpc/client/direct/Dockerfile new file mode 100644 index 00000000..f1380bd9 --- /dev/null +++ b/example/rpc/client/direct/Dockerfile @@ -0,0 +1,22 @@ +FROM golang:1.13-alpine AS builder + +LABEL stage=gobuilder + +ENV CGO_ENABLED 0 +ENV GOOS linux +ENV GOPROXY https://goproxy.cn,direct + +WORKDIR $GOPATH/src/zero +COPY . . +RUN go build -ldflags="-s -w" -o /app/unarydirect example/rpc/client/direct/client.go + + +FROM alpine + +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai + +WORKDIR /app +COPY --from=builder /app/unarydirect /app/unarydirect + +CMD ["./unarydirect"] diff --git a/example/rpc/client/direct/Makefile b/example/rpc/client/direct/Makefile new file mode 100644 index 00000000..2c8abca6 --- /dev/null +++ b/example/rpc/client/direct/Makefile @@ -0,0 +1,10 @@ +version := v$(shell /bin/date "+%y%m%d%H%M%S") + +build: + cd $(GOPATH)/src/zero && docker build -t registry.cn-hangzhou.aliyuncs.com/xapp/unarydirect:$(version) . 
-f example/rpc/client/direct/Dockerfile + +push: build + docker push registry.cn-hangzhou.aliyuncs.com/xapp/unarydirect:$(version) + +deploy: push + kubectl -n adhoc set image deployment/unarydirect-deployment unarydirect=registry-vpc.cn-hangzhou.aliyuncs.com/xapp/unarydirect:$(version) diff --git a/example/rpc/client/direct/client.go b/example/rpc/client/direct/client.go new file mode 100644 index 00000000..e056ee7b --- /dev/null +++ b/example/rpc/client/direct/client.go @@ -0,0 +1,50 @@ +package main + +import ( + "context" + "flag" + "fmt" + "time" + + "zero/core/discov" + "zero/example/rpc/remote/unary" + "zero/rpcx" +) + +const timeFormat = "15:04:05" + +func main() { + flag.Parse() + + client := rpcx.MustNewClient(rpcx.RpcClientConf{ + Etcd: discov.EtcdConf{ + Hosts: []string{"localhost:2379"}, + Key: "rpcx", + }, + }) + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + conn, ok := client.Next() + if !ok { + time.Sleep(time.Second) + break + } + + greet := unary.NewGreeterClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + resp, err := greet.Greet(ctx, &unary.Request{ + Name: "kevin", + }) + if err != nil { + fmt.Printf("%s X %s\n", time.Now().Format(timeFormat), err.Error()) + } else { + fmt.Printf("%s => %s\n", time.Now().Format(timeFormat), resp.Greet) + } + cancel() + } + } +} diff --git a/example/rpc/client/direct/unarydirect.yaml b/example/rpc/client/direct/unarydirect.yaml new file mode 100644 index 00000000..57358f73 --- /dev/null +++ b/example/rpc/client/direct/unarydirect.yaml @@ -0,0 +1,23 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: unarydirect-deployment + namespace: adhoc + labels: + app: unarydirect +spec: + replicas: 1 + selector: + matchLabels: + app: unarydirect + template: + metadata: + labels: + app: unarydirect + spec: + containers: + - name: unarydirect + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/unarydirect:v1 + imagePullPolicy: Always + imagePullSecrets: + - name: aliyun diff --git a/example/rpc/client/stream/client.go b/example/rpc/client/stream/client.go new file mode 100644 index 00000000..d303ccaf --- /dev/null +++ b/example/rpc/client/stream/client.go @@ -0,0 +1,59 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + + "zero/core/discov" + "zero/example/rpc/remote/stream" + "zero/rpcx" +) + +const name = "kevin" + +var key = flag.String("key", "rpcx", "the key on etcd") + +func main() { + flag.Parse() + + client, err := rpcx.NewClientNoAuth(discov.EtcdConf{ + Hosts: []string{"localhost:2379"}, + Key: *key, + }) + if err != nil { + log.Fatal(err) + } + + conn, ok := client.Next() + if !ok { + log.Fatal("no server") + } + + greet := stream.NewStreamGreeterClient(conn) + stm, err := greet.Greet(context.Background()) + if err != nil { + log.Fatal(err) + } + + go func() { + for { + resp, err := stm.Recv() + if err != nil { + log.Fatal(err) + } + + fmt.Println("=>", resp.Greet) + } + }() + + for i := 0; i < 3; i++ { + fmt.Println("<=", name) + if err = stm.Send(&stream.StreamReq{ + Name: name, + }); err != nil { + log.Fatal(err) + } + } +} diff --git a/example/rpc/client/unary/client.go b/example/rpc/client/unary/client.go new file mode 100644 index 00000000..d7873419 --- /dev/null +++ b/example/rpc/client/unary/client.go @@ -0,0 +1,44 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "time" + + "zero/core/conf" + "zero/example/rpc/remote/unary" + "zero/rpcx" +) + +var configFile = flag.String("f", "config.json", "the 
config file") + +func main() { + flag.Parse() + + var c rpcx.RpcClientConf + conf.MustLoad(*configFile, &c) + client := rpcx.MustNewClient(c) + ticker := time.NewTicker(time.Millisecond * 500) + defer ticker.Stop() + for { + select { + case <-ticker.C: + conn, ok := client.Next() + if !ok { + log.Fatal("no server") + } + + greet := unary.NewGreeterClient(conn) + resp, err := greet.Greet(context.Background(), &unary.Request{ + Name: "kevin", + }) + if err != nil { + fmt.Println("X", err.Error()) + } else { + fmt.Println("=>", resp.Greet) + } + } + } +} diff --git a/example/rpc/client/unary/config.json b/example/rpc/client/unary/config.json new file mode 100644 index 00000000..19364bfc --- /dev/null +++ b/example/rpc/client/unary/config.json @@ -0,0 +1,5 @@ +{ + "Server": "localhost:3457", + "App": "adhoc", + "Token": "E0459CF7-EA85-4E0C-BB48-C81448811511" +} diff --git a/example/rpc/client/unary/config_etcd.json b/example/rpc/client/unary/config_etcd.json new file mode 100644 index 00000000..31b18e0e --- /dev/null +++ b/example/rpc/client/unary/config_etcd.json @@ -0,0 +1,8 @@ +{ + "Hosts": [ + "127.0.0.1:2379" + ], + "Key": "sms", + "App": "adhoc", + "Token": "E0459CF7-EA85-4E0C-BB48-C81448811511" +} diff --git a/example/rpc/proxy/Dockerfile b/example/rpc/proxy/Dockerfile new file mode 100644 index 00000000..a43ae4a2 --- /dev/null +++ b/example/rpc/proxy/Dockerfile @@ -0,0 +1,16 @@ +FROM golang:1.11 AS builder + +ENV CGO_ENABLED 0 +ENV GOOS linux + +WORKDIR $GOPATH/src/zero +COPY . . +RUN go build -ldflags="-s -w" -o /app/unaryproxy example/rpc/proxy/proxy.go + + +FROM alpine + +WORKDIR /app +COPY --from=builder /app/unaryproxy /app/unaryproxy + +CMD ["./unaryproxy"] diff --git a/example/rpc/proxy/proxy.go b/example/rpc/proxy/proxy.go new file mode 100644 index 00000000..e95c2d18 --- /dev/null +++ b/example/rpc/proxy/proxy.go @@ -0,0 +1,50 @@ +package main + +import ( + "context" + "flag" + + "zero/core/logx" + "zero/core/service" + "zero/example/rpc/remote/unary" + "zero/rpcx" + + "google.golang.org/grpc" +) + +var ( + listen = flag.String("listen", "0.0.0.0:3456", "the address to listen on") + server = flag.String("server", "dns:///unaryserver:3456", "the backend service") +) + +type GreetServer struct { + *rpcx.RpcProxy +} + +func (s *GreetServer) Greet(ctx context.Context, req *unary.Request) (*unary.Response, error) { + conn, err := s.TakeConn(ctx) + if err != nil { + return nil, err + } + + remote := unary.NewGreeterClient(conn) + return remote.Greet(ctx, req) +} + +func main() { + flag.Parse() + + proxy := rpcx.MustNewServer(rpcx.RpcServerConf{ + ServiceConf: service.ServiceConf{ + Log: logx.LogConf{ + Mode: "console", + }, + }, + ListenOn: *listen, + }, func(grpcServer *grpc.Server) { + unary.RegisterGreeterServer(grpcServer, &GreetServer{ + RpcProxy: rpcx.NewRpcProxy(*server), + }) + }) + proxy.Start() +} diff --git a/example/rpc/proxy/unaryproxy.yaml b/example/rpc/proxy/unaryproxy.yaml new file mode 100644 index 00000000..d9526cdc --- /dev/null +++ b/example/rpc/proxy/unaryproxy.yaml @@ -0,0 +1,46 @@ +apiVersion: v1 +kind: Service +metadata: + name: unaryproxy + namespace: kevin +spec: + selector: + app: unaryproxy + ports: + - name: unaryproxy-port + port: 3456 + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: unaryproxy-deployment + namespace: kevin + labels: + app: unaryproxy +spec: + replicas: 3 + selector: + matchLabels: + app: unaryproxy + template: + metadata: + labels: + app: unaryproxy + spec: + containers: + - name: unaryproxy + image: 
registry-vpc.cn-hangzhou.aliyuncs.com/xapp/unaryproxy:v1 + imagePullPolicy: Always + ports: + - containerPort: 3456 + volumeMounts: + - name: timezone + mountPath: /etc/localtime + imagePullSecrets: + - name: aliyun + volumes: + - name: timezone + hostPath: + path: /usr/share/zoneinfo/Asia/Shanghai diff --git a/example/rpc/remote/stream/greet.pb.go b/example/rpc/remote/stream/greet.pb.go new file mode 100644 index 00000000..9df9855d --- /dev/null +++ b/example/rpc/remote/stream/greet.pb.go @@ -0,0 +1,190 @@ +// Code generated by protoc-gen-go. +// source: greet.proto +// DO NOT EDIT! + +/* +Package stream is a generated protocol buffer package. + +It is generated from these files: + greet.proto + +It has these top-level messages: + StreamReq + StreamResp +*/ +package stream + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type StreamReq struct { + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` +} + +func (m *StreamReq) Reset() { *m = StreamReq{} } +func (m *StreamReq) String() string { return proto.CompactTextString(m) } +func (*StreamReq) ProtoMessage() {} +func (*StreamReq) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *StreamReq) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +type StreamResp struct { + Greet string `protobuf:"bytes,1,opt,name=greet" json:"greet,omitempty"` +} + +func (m *StreamResp) Reset() { *m = StreamResp{} } +func (m *StreamResp) String() string { return proto.CompactTextString(m) } +func (*StreamResp) ProtoMessage() {} +func (*StreamResp) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *StreamResp) GetGreet() string { + if m != nil { + return m.Greet + } + return "" +} + +func init() { + proto.RegisterType((*StreamReq)(nil), "stream.StreamReq") + proto.RegisterType((*StreamResp)(nil), "stream.StreamResp") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for StreamGreeter service + +type StreamGreeterClient interface { + Greet(ctx context.Context, opts ...grpc.CallOption) (StreamGreeter_GreetClient, error) +} + +type streamGreeterClient struct { + cc *grpc.ClientConn +} + +func NewStreamGreeterClient(cc *grpc.ClientConn) StreamGreeterClient { + return &streamGreeterClient{cc} +} + +func (c *streamGreeterClient) Greet(ctx context.Context, opts ...grpc.CallOption) (StreamGreeter_GreetClient, error) { + stream, err := grpc.NewClientStream(ctx, &_StreamGreeter_serviceDesc.Streams[0], c.cc, "/stream.StreamGreeter/greet", opts...) 
+ if err != nil { + return nil, err + } + x := &streamGreeterGreetClient{stream} + return x, nil +} + +type StreamGreeter_GreetClient interface { + Send(*StreamReq) error + Recv() (*StreamResp, error) + grpc.ClientStream +} + +type streamGreeterGreetClient struct { + grpc.ClientStream +} + +func (x *streamGreeterGreetClient) Send(m *StreamReq) error { + return x.ClientStream.SendMsg(m) +} + +func (x *streamGreeterGreetClient) Recv() (*StreamResp, error) { + m := new(StreamResp) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// Server API for StreamGreeter service + +type StreamGreeterServer interface { + Greet(StreamGreeter_GreetServer) error +} + +func RegisterStreamGreeterServer(s *grpc.Server, srv StreamGreeterServer) { + s.RegisterService(&_StreamGreeter_serviceDesc, srv) +} + +func _StreamGreeter_Greet_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(StreamGreeterServer).Greet(&streamGreeterGreetServer{stream}) +} + +type StreamGreeter_GreetServer interface { + Send(*StreamResp) error + Recv() (*StreamReq, error) + grpc.ServerStream +} + +type streamGreeterGreetServer struct { + grpc.ServerStream +} + +func (x *streamGreeterGreetServer) Send(m *StreamResp) error { + return x.ServerStream.SendMsg(m) +} + +func (x *streamGreeterGreetServer) Recv() (*StreamReq, error) { + m := new(StreamReq) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _StreamGreeter_serviceDesc = grpc.ServiceDesc{ + ServiceName: "stream.StreamGreeter", + HandlerType: (*StreamGreeterServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "greet", + Handler: _StreamGreeter_Greet_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "greet.proto", +} + +func init() { proto.RegisterFile("greet.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 128 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x4e, 0x2f, 0x4a, 0x4d, + 0x2d, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0x2e, 0x29, 0x4a, 0x4d, 0xcc, 0x55, + 0x92, 0xe7, 0xe2, 0x0c, 0x06, 0xb3, 0x82, 0x52, 0x0b, 0x85, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, + 0x53, 0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0xc0, 0x6c, 0x25, 0x25, 0x2e, 0x2e, 0x98, 0x82, + 0xe2, 0x02, 0x21, 0x11, 0x2e, 0x56, 0xb0, 0x29, 0x50, 0x25, 0x10, 0x8e, 0x91, 0x33, 0x17, 0x2f, + 0x44, 0x8d, 0x3b, 0x88, 0x9b, 0x5a, 0x24, 0x64, 0x04, 0x55, 0x26, 0x24, 0xa8, 0x07, 0xb1, 0x47, + 0x0f, 0x6e, 0x89, 0x94, 0x10, 0xba, 0x50, 0x71, 0x81, 0x06, 0xa3, 0x01, 0x63, 0x12, 0x1b, 0xd8, + 0x61, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0xbc, 0x34, 0x15, 0xe8, 0xa7, 0x00, 0x00, 0x00, +} diff --git a/example/rpc/remote/stream/greet.proto b/example/rpc/remote/stream/greet.proto new file mode 100644 index 00000000..fd7c1611 --- /dev/null +++ b/example/rpc/remote/stream/greet.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package stream; + +message StreamReq { + string name = 1; +} + +message StreamResp { + string greet = 1; +} + +service StreamGreeter { + rpc greet(stream StreamReq) returns (stream StreamResp); +} \ No newline at end of file diff --git a/example/rpc/remote/unary/greet.pb.go b/example/rpc/remote/unary/greet.pb.go new file mode 100644 index 00000000..76376af8 --- /dev/null +++ b/example/rpc/remote/unary/greet.pb.go @@ -0,0 +1,158 @@ +// Code generated by protoc-gen-go. +// source: greet.proto +// DO NOT EDIT! 
+ +/* +Package unary is a generated protocol buffer package. + +It is generated from these files: + greet.proto + +It has these top-level messages: + Request + Response +*/ +package unary + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type Request struct { + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` +} + +func (m *Request) Reset() { *m = Request{} } +func (m *Request) String() string { return proto.CompactTextString(m) } +func (*Request) ProtoMessage() {} +func (*Request) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *Request) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +type Response struct { + Greet string `protobuf:"bytes,1,opt,name=greet" json:"greet,omitempty"` +} + +func (m *Response) Reset() { *m = Response{} } +func (m *Response) String() string { return proto.CompactTextString(m) } +func (*Response) ProtoMessage() {} +func (*Response) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *Response) GetGreet() string { + if m != nil { + return m.Greet + } + return "" +} + +func init() { + proto.RegisterType((*Request)(nil), "unary.Request") + proto.RegisterType((*Response)(nil), "unary.Response") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for Greeter service + +type GreeterClient interface { + Greet(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) +} + +type greeterClient struct { + cc *grpc.ClientConn +} + +func NewGreeterClient(cc *grpc.ClientConn) GreeterClient { + return &greeterClient{cc} +} + +func (c *greeterClient) Greet(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) { + out := new(Response) + err := grpc.Invoke(ctx, "/unary.Greeter/greet", in, out, c.cc, opts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +// Server API for Greeter service + +type GreeterServer interface { + Greet(context.Context, *Request) (*Response, error) +} + +func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) { + s.RegisterService(&_Greeter_serviceDesc, srv) +} + +func _Greeter_Greet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Request) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(GreeterServer).Greet(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/unary.Greeter/Greet", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(GreeterServer).Greet(ctx, req.(*Request)) + } + return interceptor(ctx, in, info, handler) +} + +var _Greeter_serviceDesc = grpc.ServiceDesc{ + ServiceName: "unary.Greeter", + HandlerType: (*GreeterServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "greet", + Handler: _Greeter_Greet_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "greet.proto", +} + +func init() { proto.RegisterFile("greet.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 126 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x4e, 0x2f, 0x4a, 0x4d, + 0x2d, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2d, 0xcd, 0x4b, 0x2c, 0xaa, 0x54, 0x92, + 0xe5, 0x62, 0x0f, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0x12, 0xe2, 0x62, 0xc9, 0x4b, 0xcc, + 0x4d, 0x95, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x02, 0xb3, 0x95, 0x14, 0xb8, 0x38, 0x82, 0x52, + 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x85, 0x44, 0xb8, 0x58, 0xc1, 0x06, 0x40, 0x15, 0x40, 0x38, + 0x46, 0xc6, 0x5c, 0xec, 0xee, 0x20, 0x46, 0x6a, 0x91, 0x90, 0x06, 0x54, 0x81, 0x10, 0x9f, 0x1e, + 0xd8, 0x70, 0x3d, 0xa8, 0xc9, 0x52, 0xfc, 0x70, 0x3e, 0xc4, 0xa8, 0x24, 0x36, 0xb0, 0x1b, 0x8c, + 0x01, 0x01, 0x00, 0x00, 0xff, 0xff, 0xb8, 0x6d, 0x30, 0xb0, 0x92, 0x00, 0x00, 0x00, +} diff --git a/example/rpc/remote/unary/greet.proto b/example/rpc/remote/unary/greet.proto new file mode 100644 index 00000000..ae642244 --- /dev/null +++ b/example/rpc/remote/unary/greet.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package unary; + +message Request { + string name = 1; +} + +message Response { + string greet = 1; +} + +service Greeter { + rpc greet(Request) returns (Response); +} + + diff --git a/example/rpc/server/stream/etc/config.json b/example/rpc/server/stream/etc/config.json new file mode 100644 index 00000000..20c3688e --- /dev/null +++ b/example/rpc/server/stream/etc/config.json @@ -0,0 +1,17 @@ +{ + "Name":"test", + "MetricsUrl": "http://localhost:2222/add", + "ListenOn": "localhost:3456", + "Etcd": { + "Hosts": [ + "localhost:2379" + ], + "Key": "rpcx" + }, + "Redis": { + "Host": "localhost:6379", + "Type": "node", + "Key": "apps" + }, + "Auth": false +} diff --git a/example/rpc/server/stream/server.go b/example/rpc/server/stream/server.go new file mode 100644 index 00000000..1c49c891 --- /dev/null +++ b/example/rpc/server/stream/server.go @@ -0,0 +1,50 @@ +package main + +import ( + "fmt" + "io" + + "zero/core/conf" + "zero/example/rpc/remote/stream" + "zero/rpcx" + + "google.golang.org/grpc" +) + +type StreamGreetServer int + +func (gs StreamGreetServer) Greet(s stream.StreamGreeter_GreetServer) error { + ctx := s.Context() + for { + select { + case <-ctx.Done(): + fmt.Println("cancelled by 
client") + return ctx.Err() + default: + req, err := s.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + + fmt.Println("=>", req.Name) + greet := "hello, " + req.Name + fmt.Println("<=", greet) + s.Send(&stream.StreamResp{ + Greet: greet, + }) + } + } +} + +func main() { + var c rpcx.RpcServerConf + conf.MustLoad("etc/config.json", &c) + + server := rpcx.MustNewServer(c, func(grpcServer *grpc.Server) { + stream.RegisterStreamGreeterServer(grpcServer, StreamGreetServer(0)) + }) + server.Start() +} diff --git a/example/rpc/server/unary/Dockerfile b/example/rpc/server/unary/Dockerfile new file mode 100644 index 00000000..ad4e0e30 --- /dev/null +++ b/example/rpc/server/unary/Dockerfile @@ -0,0 +1,23 @@ +FROM golang:1.13-alpine AS builder + +LABEL stage=gobuilder + +ENV CGO_ENABLED 0 +ENV GOOS linux +ENV GOPROXY https://goproxy.cn,direct + +WORKDIR $GOPATH/src/zero +COPY . . +RUN go build -ldflags="-s -w" -o /app/unaryserver example/rpc/server/unary/server.go + + +FROM alpine + +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai + +WORKDIR /app +COPY --from=builder /app/unaryserver /app/unaryserver +COPY example/rpc/server/unary/etc/k8s.json /app/ + +CMD ["./unaryserver", "-f", "k8s.json"] diff --git a/example/rpc/server/unary/Makefile b/example/rpc/server/unary/Makefile new file mode 100644 index 00000000..1ed2aea0 --- /dev/null +++ b/example/rpc/server/unary/Makefile @@ -0,0 +1,11 @@ +version := v1 + +build: + cd $(GOPATH)/src/zero && docker build -t registry.cn-hangzhou.aliyuncs.com/xapp/unaryserver:$(version) . -f example/rpc/server/unary/Dockerfile + docker image prune --filter label=stage=gobuilder -f + +push: build + docker push registry.cn-hangzhou.aliyuncs.com/xapp/unaryserver:$(version) + +deploy: push + kubectl -n adhoc set image deployment/unaryserver-deployment unaryserver=registry-vpc.cn-hangzhou.aliyuncs.com/xapp/unaryserver:$(version) diff --git a/example/rpc/server/unary/etc/config.json b/example/rpc/server/unary/etc/config.json new file mode 100644 index 00000000..7a01b7fb --- /dev/null +++ b/example/rpc/server/unary/etc/config.json @@ -0,0 +1,13 @@ +{ + "Name": "rpc.unary", + "Log": { + "Mode": "volume" + }, + "ListenOn": "localhost:3456", + "Etcd": { + "Hosts": [ + "localhost:2379" + ], + "Key": "rpcx" + } +} diff --git a/example/rpc/server/unary/etc/config1.json b/example/rpc/server/unary/etc/config1.json new file mode 100644 index 00000000..6af30d6b --- /dev/null +++ b/example/rpc/server/unary/etc/config1.json @@ -0,0 +1,17 @@ +{ + "Name": "rpc.unary", + "MetricsUrl": "http://localhost:2222/add", + "ListenOn": "localhost:3457", + "Auth": false, + "Etcd": { + "Hosts": [ + "localhost:2379" + ], + "Key": "rpcx" + }, + "Redis": { + "Host": "localhost:6379", + "Type": "node", + "Key": "apps" + } +} diff --git a/example/rpc/server/unary/etc/k8s.json b/example/rpc/server/unary/etc/k8s.json new file mode 100644 index 00000000..1fb61444 --- /dev/null +++ b/example/rpc/server/unary/etc/k8s.json @@ -0,0 +1,12 @@ +{ + "Name": "rpc.unary", + "ListenOn": "0.0.0.0:3456", + "Auth": false, + "Etcd": { + "Hosts": [ + "etcd.discov:2379" + ], + "Key": "rpcx" + }, + "Timeout": 500 +} diff --git a/example/rpc/server/unary/server.go b/example/rpc/server/unary/server.go new file mode 100644 index 00000000..145658e3 --- /dev/null +++ b/example/rpc/server/unary/server.go @@ -0,0 +1,55 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "sync" + "time" + + "zero/core/conf" + "zero/example/rpc/remote/unary" + "zero/rpcx" + + "google.golang.org/grpc" +) 
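+
+// This example server answers every Greet request with the pod's host name,
+// so a client that keeps calling it can watch requests being balanced across
+// the replicas registered in etcd.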
+ +var configFile = flag.String("f", "etc/config.json", "the config file") + +type GreetServer struct { + lock sync.Mutex + alive bool + downTime time.Time +} + +func NewGreetServer() *GreetServer { + return &GreetServer{ + alive: true, + } +} + +func (gs *GreetServer) Greet(ctx context.Context, req *unary.Request) (*unary.Response, error) { + fmt.Println("=>", req) + + hostname, err := os.Hostname() + if err != nil { + return nil, err + } + + return &unary.Response{ + Greet: "hello from " + hostname, + }, nil +} + +func main() { + flag.Parse() + + var c rpcx.RpcServerConf + conf.MustLoad(*configFile, &c) + + server := rpcx.MustNewServer(c, func(grpcServer *grpc.Server) { + unary.RegisterGreeterServer(grpcServer, NewGreetServer()) + }) + server.Start() +} diff --git a/example/rpc/server/unary/unaryserver.yaml b/example/rpc/server/unary/unaryserver.yaml new file mode 100644 index 00000000..c5716b4b --- /dev/null +++ b/example/rpc/server/unary/unaryserver.yaml @@ -0,0 +1,25 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: unaryserver-deployment + namespace: adhoc + labels: + app: unaryserver +spec: + replicas: 3 + selector: + matchLabels: + app: unaryserver + template: + metadata: + labels: + app: unaryserver + spec: + containers: + - name: unaryserver + image: registry-vpc.cn-hangzhou.aliyuncs.com/xapp/unaryserver:v1 + imagePullPolicy: Always + ports: + - containerPort: 3456 + imagePullSecrets: + - name: aliyun diff --git a/example/signal/main.go b/example/signal/main.go new file mode 100644 index 00000000..473c81b9 --- /dev/null +++ b/example/signal/main.go @@ -0,0 +1,70 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "path" + "runtime/pprof" + "syscall" + "time" + + "zero/core/cmdline" + "zero/core/logx" +) + +const ( + goroutineProfile = "goroutine" + debugLevel = 2 + timeFormat = "0102150405" +) + +func init() { + go func() { + // https://golang.org/pkg/os/signal/#Notify + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGUSR1, syscall.SIGTERM) + + for { + v := <-signals + switch v { + case syscall.SIGUSR1: + dumpGoroutines() + case syscall.SIGTERM: + gracefulStop(signals) + default: + logx.Error("Got unregistered signal:", v) + } + } + }() +} + +func dumpGoroutines() { + command := path.Base(os.Args[0]) + pid := syscall.Getpid() + dumpFile := path.Join(os.TempDir(), fmt.Sprintf("%s-%d-goroutines-%s.dump", + command, pid, time.Now().Format(timeFormat))) + + logx.Infof("Got dump goroutine signal, printing goroutine profile to %s", dumpFile) + + if f, err := os.Create(dumpFile); err != nil { + logx.Errorf("Failed to dump goroutine profile, error: %v", err) + } else { + defer f.Close() + pprof.Lookup(goroutineProfile).WriteTo(f, debugLevel) + } +} + +func gracefulStop(signals chan os.Signal) { + signal.Stop(signals) + + logx.Info("Got signal SIGTERM, shutting down...") + + time.Sleep(time.Second * 5) + logx.Infof("Still alive after %v, going to force kill the process...", time.Second*5) + syscall.Kill(syscall.Getpid(), syscall.SIGTERM) +} + +func main() { + cmdline.EnterToContinue() +} diff --git a/example/siphash/sharding.go b/example/siphash/sharding.go new file mode 100644 index 00000000..2e69e321 --- /dev/null +++ b/example/siphash/sharding.go @@ -0,0 +1,8 @@ +package sharding + +import "github.com/dchest/siphash" + +func sharding(token string) uint64 { + sum := siphash.Hash(0, 0, []byte(token)) + return sum % 3 +} diff --git a/example/siphash/sharding_test.go b/example/siphash/sharding_test.go new file mode 100644 index 00000000..1816992c 
--- /dev/null +++ b/example/siphash/sharding_test.go @@ -0,0 +1,52 @@ +package sharding + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSiphash64(t *testing.T) { + users := [][]string{ + { + "5a4b7347200a6e0c185d6101", + "5b74c444acdd315c509b78fe", + "5c03e009a496130c2d9bc970", + "5c6ab5a74867f267d560dd9f", + "5b80a2b28be129507d176284", + "5a4b7347200a6e0c185d6101", + "5b74c444acdd315c509b78fe", + "5c03e009a496130c2d9bc970", + "5c6ab5a74867f267d560dd9f", + "5b80a2b28be129507d176284", + "5b8d157aacdd313508a892f2", + "5bf942b4a496130c2d9b7378", + "5c7fc28cd065f17f9edd3698", + "5bf40bd22c64fc5ea63a5174", + }, + { + "5b839929acdd31271f03ded5", + "5bc9e28e2c64fc1a69a28e36", + "5b935d96a49613677b90b589", + "5b97acb2a49613677b910f47", + "5c902f3aff5be73689b4b522", + }, + { + "5cdbee881a722f0001b9ce99", + "", + "5caca58f53add40001c20aaa", + "5beee68520c25041544e353a", + "5b0b957d0179b05769cbecde", + "5bbf45940ab7b7589aa1025f", + "5ac63009200a6e79cadf5175", + "5c94ed250ab7b7386c294662", + "5b9f8ccb2c64fc5832e47d3f", + }, + } + + for shard, ids := range users { + for _, id := range ids { + assert.Equal(t, uint64(shard), sharding(id)) + } + } +} diff --git a/example/sqlc/user.go b/example/sqlc/user.go new file mode 100644 index 00000000..84f892fd --- /dev/null +++ b/example/sqlc/user.go @@ -0,0 +1,142 @@ +package main + +import ( + "database/sql" + "fmt" + + "zero/core/stores/cache" + "zero/core/stores/sqlc" + "zero/core/stores/sqlx" + "zero/kq" +) + +var ( + userRows = "id, mobile, name, sex" + + cacheUserMobilePrefix = "cache#user#mobile#" + cacheUserIdPrefix = "cache#user#id#" + + ErrNotFound = sqlc.ErrNotFound +) + +type ( + User struct { + Id int64 `db:"id" json:"id,omitempty"` + Mobile string `db:"mobile" json:"mobile,omitempty"` + Name string `db:"name" json:"name,omitempty"` + Sex int `db:"sex" json:"sex,omitempty"` + } + + UserModel struct { + sqlc.CachedConn + // sqlx.SqlConn + table string + + // kafka use kq not kmq + push *kq.Pusher + } +) + +func NewUserModel(db sqlx.SqlConn, c cache.CacheConf, table string, pusher *kq.Pusher) *UserModel { + return &UserModel{ + CachedConn: sqlc.NewConn(db, c), + table: table, + push: pusher, + } +} + +func (um *UserModel) FindOne(id int64) (*User, error) { + key := fmt.Sprintf("%s%d", cacheUserIdPrefix, id) + var user User + err := um.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error { + query := fmt.Sprintf("SELECT %s FROM user WHERE id=?", userRows) + return conn.QueryRow(v, query, id) + }) + switch err { + case nil: + return &user, nil + case sqlc.ErrNotFound: + return nil, ErrNotFound + default: + return nil, err + } +} + +func (um *UserModel) FindByMobile(mobile string) (*User, error) { + var user User + key := fmt.Sprintf("%s%s", cacheUserMobilePrefix, mobile) + err := um.QueryRowIndex(&user, key, func(primary interface{}) string { + return fmt.Sprintf("%s%d", cacheUserIdPrefix, primary.(int64)) + }, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) { + query := fmt.Sprintf("SELECT %s FROM user WHERE mobile=?", userRows) + if err := conn.QueryRow(&user, query, mobile); err != nil { + return nil, err + } + return user.Id, nil + }, func(conn sqlx.SqlConn, v interface{}, primary interface{}) error { + return conn.QueryRow(v, "SELECT * FROM user WHERE id=?", primary) + }) + switch err { + case nil: + return &user, nil + case sqlc.ErrNotFound: + return nil, ErrNotFound + default: + return nil, err + } +} + +// Count for no cache +func (um *UserModel) Count() (int64, error) { + var count int64 + 
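+	// Aggregate queries such as count are not cached; QueryRowNoCache always
+	// goes straight to the database, unlike the key-based lookups above.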
err := um.QueryRowNoCache(&count, "SELECT count(1) FROM user") + if err != nil { + return 0, err + } + return count, nil +} + +// Query rows +func (um *UserModel) FindByName(name string) ([]*User, error) { + var users []*User + query := fmt.Sprintf("SELECT %s FROM user WHERE name=?", userRows) + err := um.QueryRowsNoCache(&userRows, query, name) + if err != nil { + return nil, err + } + return users, nil +} + +func (um *UserModel) UpdateSexById(sex int, id int64) error { + key := fmt.Sprintf("%s%d", cacheUserIdPrefix, id) + _, err := um.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { + query := fmt.Sprintf("UPDATE user SET sex=? WHERE id=?") + return conn.Exec(query, sex, id) + }, key) + return err +} + +func (um *UserModel) UpdateMobileById(mobile string, id int64) error { + idKey := fmt.Sprintf("%s%d", cacheUserIdPrefix, id) + mobileKey := fmt.Sprintf("%s%s", cacheUserMobilePrefix, mobile) + _, err := um.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { + query := fmt.Sprintf("UPDATE user SET mobile=? WHERE id=?") + return conn.Exec(query, mobile, id) + }, idKey, mobileKey) + return err +} + +func (um *UserModel) Update(u *User) error { + oldUser, err := um.FindOne(u.Id) + if err != nil { + return err + } + + idKey := fmt.Sprintf("%s%d", cacheUserIdPrefix, oldUser.Id) + mobileKey := fmt.Sprintf("%s%s", cacheUserMobilePrefix, oldUser.Mobile) + _, err = um.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { + query := fmt.Sprintf("UPDATE user SET mobile=?, name=?, sex=? WHERE id=?") + return conn.Exec(query, u.Mobile, u.Name, u.Sex, u.Id) + }, idKey, mobileKey) + return err +} diff --git a/example/stat/Dockerfile b/example/stat/Dockerfile new file mode 100644 index 00000000..2aefe97f --- /dev/null +++ b/example/stat/Dockerfile @@ -0,0 +1,8 @@ +FROM alpine + +RUN mkdir /app +COPY cpu /app/ + +WORKDIR /app + +CMD ["/app/cpu"] diff --git a/example/stat/main.go b/example/stat/main.go new file mode 100644 index 00000000..07fa3297 --- /dev/null +++ b/example/stat/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "fmt" + "runtime" + "time" + + "zero/core/stat" +) + +func main() { + fmt.Println(runtime.NumCPU()) + for i := 0; i < runtime.NumCPU()+10; i++ { + go func() { + for { + select { + default: + time.Sleep(time.Microsecond) + } + } + }() + } + + ticker := time.NewTicker(time.Second * 5) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + percent := stat.CpuUsage() + fmt.Println("cpu:", percent) + } + } +} diff --git a/example/timingwheel/leak/leak.go b/example/timingwheel/leak/leak.go new file mode 100644 index 00000000..5c6b8c92 --- /dev/null +++ b/example/timingwheel/leak/leak.go @@ -0,0 +1,60 @@ +package main + +import ( + "flag" + "fmt" + "log" + "runtime" + "strconv" + "time" + + "zero/core/cmdline" + "zero/core/collection" + "zero/core/proc" +) + +const numItems = 1000000 + +var round = flag.Int("r", 3, "rounds to go") + +func main() { + defer proc.StartProfile().Stop() + + flag.Parse() + + fmt.Println(getMemUsage()) + for i := 0; i < *round; i++ { + do() + } + cmdline.EnterToContinue() +} + +func do() { + tw, err := collection.NewTimingWheel(time.Second, 100, execute) + if err != nil { + log.Fatal(err) + } + + for i := 0; i < numItems; i++ { + key := strconv.Itoa(i) + tw.SetTimer(key, key, time.Second*5) + } + + fmt.Println(getMemUsage()) +} + +func execute(k, v interface{}) { +} + +func getMemUsage() string { + runtime.GC() + + var m runtime.MemStats + runtime.ReadMemStats(&m) + // For more info, see: 
https://golang.org/pkg/runtime/#MemStats + return fmt.Sprintf("Heap Alloc = %dMiB", toMiB(m.HeapAlloc)) +} + +func toMiB(b uint64) uint64 { + return b / 1024 / 1024 +} diff --git a/example/timingwheel/main.go b/example/timingwheel/main.go new file mode 100644 index 00000000..79814a83 --- /dev/null +++ b/example/timingwheel/main.go @@ -0,0 +1,78 @@ +package main + +import ( + "flag" + "fmt" + "log" + "runtime" + "sync/atomic" + "time" + + "zero/core/collection" +) + +const interval = time.Minute + +var traditional = flag.Bool("traditional", false, "enable traditional mode") + +func main() { + flag.Parse() + + go func() { + ticker := time.NewTicker(time.Second * 5) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + fmt.Printf("goroutines: %d\n", runtime.NumGoroutine()) + } + } + }() + + if *traditional { + traditionalMode() + } else { + timingWheelMode() + } +} + +func timingWheelMode() { + var count uint64 + tw, err := collection.NewTimingWheel(time.Second, 600, func(key, value interface{}) { + job(&count) + }) + if err != nil { + log.Fatal(err) + } + + defer tw.Stop() + for i := 0; ; i++ { + tw.SetTimer(i, i, interval) + time.Sleep(time.Millisecond) + } +} + +func traditionalMode() { + var count uint64 + for { + go func() { + timer := time.NewTimer(interval) + defer timer.Stop() + + select { + case <-timer.C: + job(&count) + } + }() + + time.Sleep(time.Millisecond) + } +} + +func job(count *uint64) { + v := atomic.AddUint64(count, 1) + if v%1000 == 0 { + fmt.Println(v) + } +} diff --git a/example/tracing/edge/config.json b/example/tracing/edge/config.json new file mode 100644 index 00000000..6d500abd --- /dev/null +++ b/example/tracing/edge/config.json @@ -0,0 +1,3 @@ +{ + "Server": "localhost:3456" +} diff --git a/example/tracing/edge/main.go b/example/tracing/edge/main.go new file mode 100644 index 00000000..20e7b164 --- /dev/null +++ b/example/tracing/edge/main.go @@ -0,0 +1,61 @@ +package main + +import ( + "flag" + "log" + "net/http" + + "zero/core/conf" + "zero/core/httpx" + "zero/core/logx" + "zero/core/service" + "zero/example/tracing/remote/portal" + "zero/ngin" + "zero/rpcx" +) + +var ( + configFile = flag.String("f", "config.json", "the config file") + client *rpcx.RpcClient +) + +func handle(w http.ResponseWriter, r *http.Request) { + conn, ok := client.Next() + if !ok { + log.Fatal("no server") + } + + greet := portal.NewPortalClient(conn) + resp, err := greet.Portal(r.Context(), &portal.PortalRequest{ + Name: "kevin", + }) + if err != nil { + httpx.WriteJson(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } else { + httpx.OkJson(w, resp.Response) + } +} + +func main() { + flag.Parse() + + var c rpcx.RpcClientConf + conf.MustLoad(*configFile, &c) + client = rpcx.MustNewClient(c) + engine := ngin.MustNewEngine(ngin.NgConf{ + ServiceConf: service.ServiceConf{ + Log: logx.LogConf{ + Mode: "console", + }, + }, + Port: 3333, + }) + defer engine.Stop() + + engine.AddRoute(ngin.Route{ + Method: http.MethodGet, + Path: "/", + Handler: handle, + }) + engine.Start() +} diff --git a/example/tracing/portal/etc/config.json b/example/tracing/portal/etc/config.json new file mode 100644 index 00000000..201470e4 --- /dev/null +++ b/example/tracing/portal/etc/config.json @@ -0,0 +1,19 @@ +{ + "Name": "portal.rpc", + "ListenOn": "localhost:3456", + "Etcd": { + "Hosts": [ + "localhost:2379" + ], + "Key": "portal" + }, + "UserRpc": { + "Etcd": { + "Hosts": [ + "localhost:2379" + ], + "Key": "user" + } + }, + "Timeout": 500 +} diff --git 
a/example/tracing/portal/server.go b/example/tracing/portal/server.go new file mode 100644 index 00000000..340a7309 --- /dev/null +++ b/example/tracing/portal/server.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + "errors" + "flag" + + "zero/core/conf" + "zero/example/tracing/remote/portal" + "zero/example/tracing/remote/user" + "zero/rpcx" + + "google.golang.org/grpc" +) + +var configFile = flag.String("f", "etc/config.json", "the config file") + +type ( + Config struct { + rpcx.RpcServerConf + UserRpc rpcx.RpcClientConf + } + + PortalServer struct { + userRpc *rpcx.RpcClient + } +) + +func NewPortalServer(client *rpcx.RpcClient) *PortalServer { + return &PortalServer{ + userRpc: client, + } +} + +func (gs *PortalServer) Portal(ctx context.Context, req *portal.PortalRequest) (*portal.PortalResponse, error) { + conn, ok := gs.userRpc.Next() + if !ok { + return nil, errors.New("internal error") + } + + greet := user.NewUserClient(conn) + resp, err := greet.GetGrade(ctx, &user.UserRequest{ + Name: req.Name, + }) + if err != nil { + return &portal.PortalResponse{ + Response: err.Error(), + }, nil + } else { + return &portal.PortalResponse{ + Response: resp.Response, + }, nil + } +} + +func main() { + flag.Parse() + + var c Config + conf.MustLoad(*configFile, &c) + + client := rpcx.MustNewClient(c.UserRpc) + server := rpcx.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) { + portal.RegisterPortalServer(grpcServer, NewPortalServer(client)) + }) + server.Start() +} diff --git a/example/tracing/remote/portal/portal.pb.go b/example/tracing/remote/portal/portal.pb.go new file mode 100644 index 00000000..4dd7ba4d --- /dev/null +++ b/example/tracing/remote/portal/portal.pb.go @@ -0,0 +1,158 @@ +// Code generated by protoc-gen-go. +// source: portal.proto +// DO NOT EDIT! + +/* +Package portal is a generated protocol buffer package. + +It is generated from these files: + portal.proto + +It has these top-level messages: + PortalRequest + PortalResponse +*/ +package portal + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. 
+const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type PortalRequest struct { + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` +} + +func (m *PortalRequest) Reset() { *m = PortalRequest{} } +func (m *PortalRequest) String() string { return proto.CompactTextString(m) } +func (*PortalRequest) ProtoMessage() {} +func (*PortalRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *PortalRequest) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +type PortalResponse struct { + Response string `protobuf:"bytes,1,opt,name=response" json:"response,omitempty"` +} + +func (m *PortalResponse) Reset() { *m = PortalResponse{} } +func (m *PortalResponse) String() string { return proto.CompactTextString(m) } +func (*PortalResponse) ProtoMessage() {} +func (*PortalResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *PortalResponse) GetResponse() string { + if m != nil { + return m.Response + } + return "" +} + +func init() { + proto.RegisterType((*PortalRequest)(nil), "portal.PortalRequest") + proto.RegisterType((*PortalResponse)(nil), "portal.PortalResponse") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for Portal service + +type PortalClient interface { + Portal(ctx context.Context, in *PortalRequest, opts ...grpc.CallOption) (*PortalResponse, error) +} + +type portalClient struct { + cc *grpc.ClientConn +} + +func NewPortalClient(cc *grpc.ClientConn) PortalClient { + return &portalClient{cc} +} + +func (c *portalClient) Portal(ctx context.Context, in *PortalRequest, opts ...grpc.CallOption) (*PortalResponse, error) { + out := new(PortalResponse) + err := grpc.Invoke(ctx, "/portal.Portal/Portal", in, out, c.cc, opts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +// Server API for Portal service + +type PortalServer interface { + Portal(context.Context, *PortalRequest) (*PortalResponse, error) +} + +func RegisterPortalServer(s *grpc.Server, srv PortalServer) { + s.RegisterService(&_Portal_serviceDesc, srv) +} + +func _Portal_Portal_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PortalRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(PortalServer).Portal(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/portal.Portal/Portal", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(PortalServer).Portal(ctx, req.(*PortalRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _Portal_serviceDesc = grpc.ServiceDesc{ + ServiceName: "portal.Portal", + HandlerType: (*PortalServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Portal", + Handler: _Portal_Portal_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "portal.proto", +} + +func init() { proto.RegisterFile("portal.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 122 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x29, 0xc8, 0x2f, 0x2a, + 0x49, 0xcc, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x83, 0xf0, 0x94, 0x94, 0xb9, 0x78, + 0x03, 0xc0, 0xac, 0xa0, 0xd4, 0xc2, 0xd2, 0xd4, 0xe2, 0x12, 0x21, 0x21, 0x2e, 0x96, 0xbc, 0xc4, + 0xdc, 0x54, 0x09, 0x46, 0x05, 0x46, 0x0d, 0xce, 0x20, 0x30, 0x5b, 0x49, 0x87, 0x8b, 0x0f, 0xa6, + 0xa8, 0xb8, 0x20, 0x3f, 0xaf, 0x38, 0x55, 0x48, 0x8a, 0x8b, 0xa3, 0x08, 0xca, 0x86, 0xaa, 0x84, + 0xf3, 0x8d, 0x1c, 0xb9, 0xd8, 0x20, 0xaa, 0x85, 0xcc, 0xe1, 0x2c, 0x51, 0x3d, 0xa8, 0xed, 0x28, + 0x96, 0x49, 0x89, 0xa1, 0x0b, 0x43, 0x8c, 0x48, 0x62, 0x03, 0x3b, 0xd2, 0x18, 0x10, 0x00, 0x00, + 0xff, 0xff, 0xce, 0x05, 0x16, 0xd0, 0xb4, 0x00, 0x00, 0x00, +} diff --git a/example/tracing/remote/portal/portal.proto b/example/tracing/remote/portal/portal.proto new file mode 100644 index 00000000..82cdee3d --- /dev/null +++ b/example/tracing/remote/portal/portal.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package portal; + +message PortalRequest { + string name = 1; +} + +message PortalResponse { + string response = 1; +} + +service Portal { + rpc Portal(PortalRequest) returns (PortalResponse); +} diff --git a/example/tracing/remote/user/user.pb.go b/example/tracing/remote/user/user.pb.go new file mode 100644 index 00000000..40965651 --- /dev/null +++ b/example/tracing/remote/user/user.pb.go @@ -0,0 +1,159 @@ +// Code generated by protoc-gen-go. +// source: user.proto +// DO NOT EDIT! + +/* +Package user is a generated protocol buffer package. + +It is generated from these files: + user.proto + +It has these top-level messages: + UserRequest + UserResponse +*/ +package user + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. 
+// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type UserRequest struct { + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` +} + +func (m *UserRequest) Reset() { *m = UserRequest{} } +func (m *UserRequest) String() string { return proto.CompactTextString(m) } +func (*UserRequest) ProtoMessage() {} +func (*UserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *UserRequest) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +type UserResponse struct { + Response string `protobuf:"bytes,1,opt,name=response" json:"response,omitempty"` +} + +func (m *UserResponse) Reset() { *m = UserResponse{} } +func (m *UserResponse) String() string { return proto.CompactTextString(m) } +func (*UserResponse) ProtoMessage() {} +func (*UserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *UserResponse) GetResponse() string { + if m != nil { + return m.Response + } + return "" +} + +func init() { + proto.RegisterType((*UserRequest)(nil), "user.UserRequest") + proto.RegisterType((*UserResponse)(nil), "user.UserResponse") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for User service + +type UserClient interface { + GetGrade(ctx context.Context, in *UserRequest, opts ...grpc.CallOption) (*UserResponse, error) +} + +type userClient struct { + cc *grpc.ClientConn +} + +func NewUserClient(cc *grpc.ClientConn) UserClient { + return &userClient{cc} +} + +func (c *userClient) GetGrade(ctx context.Context, in *UserRequest, opts ...grpc.CallOption) (*UserResponse, error) { + out := new(UserResponse) + err := grpc.Invoke(ctx, "/user.User/GetGrade", in, out, c.cc, opts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +// Server API for User service + +type UserServer interface { + GetGrade(context.Context, *UserRequest) (*UserResponse, error) +} + +func RegisterUserServer(s *grpc.Server, srv UserServer) { + s.RegisterService(&_User_serviceDesc, srv) +} + +func _User_GetGrade_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(UserRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(UserServer).GetGrade(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/user.User/GetGrade", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(UserServer).GetGrade(ctx, req.(*UserRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _User_serviceDesc = grpc.ServiceDesc{ + ServiceName: "user.User", + HandlerType: (*UserServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetGrade", + Handler: _User_GetGrade_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "user.proto", +} + +func init() { proto.RegisterFile("user.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 131 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x2d, 0x4e, 0x2d, + 0xd2, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x01, 0xb1, 0x95, 0x14, 0xb9, 0xb8, 0x43, 0x8b, + 0x53, 0x8b, 0x82, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, + 0x53, 0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0xc0, 0x6c, 0x25, 0x2d, 0x2e, 0x1e, 0x88, 0x92, + 0xe2, 0x82, 0xfc, 0xbc, 0xe2, 0x54, 0x21, 0x29, 0x2e, 0x8e, 0x22, 0x28, 0x1b, 0xaa, 0x0e, 0xce, + 0x37, 0xb2, 0xe4, 0x62, 0x01, 0xa9, 0x15, 0x32, 0xe4, 0xe2, 0x70, 0x4f, 0x2d, 0x71, 0x2f, 0x4a, + 0x4c, 0x49, 0x15, 0x12, 0xd4, 0x03, 0xdb, 0x8a, 0x64, 0x8d, 0x94, 0x10, 0xb2, 0x10, 0x44, 0x6b, + 0x12, 0x1b, 0xd8, 0x59, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x99, 0xfd, 0x8f, 0x70, 0xa4, + 0x00, 0x00, 0x00, +} diff --git a/example/tracing/remote/user/user.proto b/example/tracing/remote/user/user.proto new file mode 100644 index 00000000..ae2390c7 --- /dev/null +++ b/example/tracing/remote/user/user.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package user; + +message UserRequest { + string name = 1; +} + +message UserResponse { + string response = 1; +} + +service User { + rpc GetGrade(UserRequest) returns (UserResponse); +} + + diff --git a/example/tracing/user/etc/config.json b/example/tracing/user/etc/config.json new file mode 100644 index 00000000..a1ee98cd --- /dev/null +++ b/example/tracing/user/etc/config.json @@ -0,0 +1,11 @@ +{ + "Name": "user.rpc", + "ListenOn": "localhost:3457", + "Etcd": { + "Hosts": [ + "localhost:2379" + ], + "Key": "user" + }, + "Timeout": 500 +} diff --git a/example/tracing/user/server.go b/example/tracing/user/server.go new file mode 100644 index 00000000..f231cab1 --- /dev/null +++ b/example/tracing/user/server.go @@ -0,0 +1,55 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "sync" + "time" + + "zero/core/conf" + "zero/example/tracing/remote/user" + "zero/rpcx" + + "google.golang.org/grpc" +) + +var configFile = flag.String("f", "etc/config.json", "the config file") + +type UserServer struct { + lock sync.Mutex + alive bool + downTime time.Time +} + +func NewUserServer() *UserServer { + return &UserServer{ + alive: true, + } +} + +func (gs *UserServer) 
GetGrade(ctx context.Context, req *user.UserRequest) (*user.UserResponse, error) { + fmt.Println("=>", req) + + hostname, err := os.Hostname() + if err != nil { + return nil, err + } + + return &user.UserResponse{ + Response: "hello from " + hostname, + }, nil +} + +func main() { + flag.Parse() + + var c rpcx.RpcServerConf + conf.MustLoad(*configFile, &c) + + server := rpcx.MustNewServer(c, func(grpcServer *grpc.Server) { + user.RegisterUserServer(grpcServer, NewUserServer()) + }) + server.Start() +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..ae102e69 --- /dev/null +++ b/go.mod @@ -0,0 +1,84 @@ +module zero + +go 1.14 + +require ( + github.com/DATA-DOG/go-sqlmock v1.4.1 + github.com/DataDog/zstd v1.4.5 // indirect + github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 // indirect + github.com/alicebob/miniredis v2.5.0+incompatible + github.com/beanstalkd/beanstalk v0.0.0-20200229072127-2b7b37f17578 + github.com/beanstalkd/go-beanstalk v0.0.0-20200229072127-2b7b37f17578 // indirect + github.com/coreos/bbolt v1.3.1-coreos.6 // indirect + github.com/coreos/etcd v3.3.18+incompatible + github.com/coreos/go-semver v0.2.0 // indirect + github.com/coreos/go-systemd v0.0.0-20181031085051-9002847aa142 // indirect + github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf // indirect + github.com/dchest/siphash v1.2.1 + github.com/dgrijalva/jwt-go v3.2.0+incompatible + github.com/fatih/color v1.9.0 // indirect + github.com/fortytw2/leaktest v1.3.0 // indirect + github.com/frankban/quicktest v1.7.2 // indirect + github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 + github.com/go-redis/redis v6.15.7+incompatible + github.com/go-sql-driver/mysql v1.5.0 + github.com/gogo/protobuf v1.3.1 // indirect + github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect + github.com/golang/mock v1.4.3 + github.com/golang/protobuf v1.4.0 + github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf // indirect + github.com/gomodule/redigo v2.0.0+incompatible // indirect + github.com/google/btree v1.0.0 // indirect + github.com/google/gops v0.3.7 + github.com/google/uuid v1.1.1 + github.com/gorilla/websocket v1.4.2 // indirect + github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4 // indirect + github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect + github.com/grpc-ecosystem/grpc-gateway v1.14.3 // indirect + github.com/jonboulle/clockwork v0.1.0 // indirect + github.com/json-iterator/go v1.1.9 + github.com/justinas/alice v1.2.0 + github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect + github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect + github.com/kr/pretty v0.2.0 // indirect + github.com/kshvakov/clickhouse v1.3.11 + github.com/lib/pq v1.3.0 + github.com/mailru/easyjson v0.7.1 // indirect + github.com/mattn/go-colorable v0.1.6 // indirect + github.com/mattn/go-runewidth v0.0.9 // indirect + github.com/olekukonko/tablewriter v0.0.4 + github.com/olivere/elastic v6.2.30+incompatible + github.com/onsi/ginkgo v1.7.0 // indirect + github.com/onsi/gomega v1.5.0 // indirect + github.com/pierrec/lz4 v2.5.1+incompatible // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/prometheus/client_golang v1.5.1 + github.com/segmentio/encoding v0.1.12 + github.com/segmentio/kafka-go v0.3.5 + github.com/soheilhy/cmux v0.1.4 // indirect + github.com/spaolacci/murmur3 v1.1.0 + github.com/stretchr/testify v1.5.1 + github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 
// indirect + github.com/xdg/stringprep v1.0.1-0.20180714160509-73f8eece6fdc // indirect + github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect + github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb // indirect + go.etcd.io/etcd v3.3.17+incompatible + go.uber.org/automaxprocs v1.3.0 + go.uber.org/multierr v1.4.0 // indirect + go.uber.org/zap v1.12.0 // indirect + golang.org/x/crypto v0.0.0-20200406173513-056763e48d71 // indirect + golang.org/x/lint v0.0.0-20200302205851-738671d3881b // indirect + golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e + golang.org/x/sys v0.0.0-20200413165638-669c56c373c4 // indirect + golang.org/x/text v0.3.2 // indirect + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 + golang.org/x/tools v0.0.0-20200410132612-ae9902aceb98 // indirect + google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1 // indirect + google.golang.org/grpc v1.26.0 + gopkg.in/cheggaaa/pb.v1 v1.0.28 + gopkg.in/yaml.v2 v2.2.8 + honnef.co/go/tools v0.0.1-2020.1.4 // indirect + sigs.k8s.io/yaml v1.2.0 // indirect +) + +replace google.golang.org/grpc => google.golang.org/grpc v1.25.1 diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..523eb038 --- /dev/null +++ b/go.sum @@ -0,0 +1,396 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.4.1 h1:ThlnYciV1iM/V0OSF/dtkqWb6xo5qITT1TJBG1MRDJM= +github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= +github.com/DataDog/zstd v1.4.0/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= +github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ= +github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= +github.com/StackExchange/wmi v0.0.0-20170410192909-ea383cf3ba6e/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 h1:45bxf7AZMwWcqkLzDAQugVEwedisr5nRJ1r+7LYnv0U= +github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis v2.5.0+incompatible h1:yBHoLpsyjupjz3NL3MhKMVkR41j82Yjf3KFv7ApYzUI= +github.com/alicebob/miniredis v2.5.0+incompatible/go.mod h1:8HZjEj4yU0dwhYHky+DxYx+6BMjkBbe5ONFIF1MXffk= +github.com/antihax/optional v0.0.0-20180407024304-ca021399b1a6/go.mod h1:V8iCPQYkqmusNa815XgQio277wI47sdRh1dUOLdyC6Q= +github.com/beanstalkd/beanstalk v0.0.0-20200229072127-2b7b37f17578 h1:9bjGO11r2d8O5uPJDrA93RR6uPzzIE4pPvjYn+k/ej8= +github.com/beanstalkd/beanstalk v0.0.0-20200229072127-2b7b37f17578/go.mod h1:WFv1+FwOgzmP6by3Dp6MAQpwyGl/JZxR2l1kf14rjFU= +github.com/beanstalkd/go-beanstalk v0.0.0-20200229072127-2b7b37f17578 h1:xdUBa6pQOvMgjhnVhp4gFTKGlpO/wLa5Qw5lBEGRqsU= +github.com/beanstalkd/go-beanstalk 
v0.0.0-20200229072127-2b7b37f17578/go.mod h1:Q3f6RCbUHp8RHSfBiPUZBojK76rir8Rl+KINuz2/sYs= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bkaradzic/go-lz4 v1.0.0 h1:RXc4wYsyz985CkXXeX04y4VnZFGG8Rd43pRaHsOXAKk= +github.com/bkaradzic/go-lz4 v1.0.0/go.mod h1:0YdlkowM3VswSROI7qDxhRvJ3sLhlFrRRwjwegp5jy4= +github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg= +github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= +github.com/coreos/bbolt v1.3.1-coreos.6 h1:uTXKg9gY70s9jMAKdfljFQcuh4e/BXOM+V+d00KFj3A= +github.com/coreos/bbolt v1.3.1-coreos.6/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= +github.com/coreos/etcd v3.3.18+incompatible h1:Zz1aXgDrFFi1nadh58tA9ktt06cmPTwNNP3dXwIq1lE= +github.com/coreos/etcd v3.3.18+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/coreos/go-semver v0.2.0 h1:3Jm3tLmsgAYcjC+4Up7hJrFBPr+n7rAqYeSw/SZazuY= +github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/coreos/go-systemd v0.0.0-20181031085051-9002847aa142 h1:3jFq2xL4ZajGK4aZY8jz+DAF0FHjI51BXjjSwCzS1Dk= +github.com/coreos/go-systemd v0.0.0-20181031085051-9002847aa142/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf h1:CAKfRE2YtTUIjjh1bkBtyYFaUT/WmOqsJjgtihT0vMI= +github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dchest/siphash v1.2.1 h1:4cLinnzVJDKxTCl9B01807Yiy+W7ZzVHj/KIroQRvT4= +github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod 
h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= +github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= +github.com/frankban/quicktest v1.7.2 h1:2QxQoC1TS09S7fhCPsrvqYdvP1H5M1P1ih5ABm3BTYk= +github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o= +github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 h1:DujepqpGd1hyOd7aW59XpK7Qymp8iy83xq74fLr21is= +github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-ole/go-ole v1.2.1 h1:2lOsA72HgjxAuMlKpFiCbHTvu44PIVkZ5hqm3RSdI/E= +github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8= +github.com/go-redis/redis v6.15.7+incompatible h1:3skhDh95XQMpnqeqNftPkQD9jL9e5e36z/1SUm6dy1U= +github.com/go-redis/redis v6.15.7+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 
+github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0 h1:oOuy+ugB+P/kBdUnG5QaMXSIyJ1q38wWSojYCb3z5VQ= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf h1:gFVkHXmVAhEbxZVDln5V9GKrLaluNoFHDbrZwAWZgws= +github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= +github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= +github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gops v0.3.7 h1:KtVAagOM0FIq+02DiQrKBTnLhYpWBMowaufcj+W1Exw= +github.com/google/gops v0.3.7/go.mod h1:bj0cwMmX1X4XIJFTjR99R5sCxNssNJ8HebFNvoQlmgY= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4 h1:z53tR0945TRRQO/fLEVPI6SMv7ZflF0TEaTAoU7tOzg= +github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= +github.com/grpc-ecosystem/grpc-gateway v1.14.3 h1:OCJlWkOUoTnl0neNGlf4fUm3TmbEtguw7vR+nGtnDjY= +github.com/grpc-ecosystem/grpc-gateway v1.14.3/go.mod h1:6CwZWGDSPRJidgKAtJVvND6soZe6fT7iteq8wDPdhb0= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo= +github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/json-iterator/go 
v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo= +github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA= +github.com/kardianos/osext v0.0.0-20170510131534-ae77be60afb1/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 h1:iQTw/8FWTuc7uiaSepXwyf3o52HaUYcV+Tu66S3F5GA= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/keybase/go-ps v0.0.0-20161005175911-668c8856d999 h1:2d+FLQbz4xRTi36DO1qYNUwfORax9XcQ0jhbO81Vago= +github.com/keybase/go-ps v0.0.0-20161005175911-668c8856d999/go.mod h1:hY+WOq6m2FpbvyrI93sMaypsttvaIL5nhVR92dTMUcQ= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= +github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kshvakov/clickhouse v1.3.11 h1:dtzTJY0fCA+MWkLyuKZaNPkmSwdX4gh8+Klic9NB1Lw= +github.com/kshvakov/clickhouse v1.3.11/go.mod h1:/SVBAcqF3u7rxQ9sTWCZwf8jzzvxiZGeQvtmSF2BBEc= +github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= +github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mailru/easyjson v0.7.1 h1:mdxE1MF9o53iCb2Ghj1VfWvh7ZOwHpnVG/xwXrV90U8= +github.com/mailru/easyjson v0.7.1/go.mod h1:KAzv3t3aY1NaHWoQz1+4F1ccyAH66Jk7yos7ldAVICs= +github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-runewidth v0.0.7 
h1:Ei8KR0497xHyKJPAv59M1dkC+rOZCMBJ+t3fZ+twI54= +github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= +github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/olekukonko/tablewriter v0.0.4 h1:vHD/YYe1Wolo78koG299f7V/VAS08c6IpCLn+Ejf/w8= +github.com/olekukonko/tablewriter v0.0.4/go.mod h1:zq6QwlOf5SlnkVbMSr5EoBv3636FWnp+qbPhuoO21uA= +github.com/olivere/elastic v6.2.30+incompatible h1:9JdhoNFfUF809qM1S5WLz3CZaxazd/mDty9XXwDRz4Q= +github.com/olivere/elastic v6.2.30+incompatible/go.mod h1:J+q1zQJTgAz9woqsbVRqGeB5G1iqDKVBWLNSYW8yfJ8= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.5.0 h1:izbySO9zDPmjJ8rDjLvkA2zJHIo+HkYXHnf7eN7SSyo= +github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pierrec/lz4 v2.5.1+incompatible h1:Yq0up0149Hh5Ekhm/91lgkZuD1ZDnXNM26bycpTzYBM= +github.com/pierrec/lz4 v2.5.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.5.1 h1:bdHYieyGlH+6OLEk2YQha8THib30KP0/yD0YH9m6xcA= +github.com/prometheus/client_golang v1.5.1/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= 
+github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= +github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.9.1 h1:KOMtN28tlbam3/7ZKEYKHhKoJZYYj3gMH4uc62x7X7U= +github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8bs7vj7HSQ4= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.0.8 h1:+fpWZdT24pJBiqJdAwYBjPSk+5YmQzYNPYzQsdzLkt8= +github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/segmentio/encoding v0.1.12 h1:SwIDXReTDnlYqOcLachzJEczAEihST7Mx7nGlAWCJ3Q= +github.com/segmentio/encoding v0.1.12/go.mod h1:RWhr02uzMB9gQC1x+MfYxedtmBibb9cZ6Vv9VxRSSbw= +github.com/segmentio/kafka-go v0.3.5 h1:2JVT1inno7LxEASWj+HflHh5sWGfM0gkRiLAxkXhGG4= +github.com/segmentio/kafka-go v0.3.5/go.mod h1:OT5KXBPbaJJTcvokhWR2KFmm0niEx3mnccTwjmLvSi4= +github.com/shirou/gopsutil v0.0.0-20180427012116-c95755e4bcd7/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U= +github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4/go.mod h1:qsXQc7+bwAM3Q1u/4XEfrquwF8Lw7D7y5cD8CuHnfIc= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/soheilhy/cmux v0.1.4 h1:0HKaf1o97UwFjHH9o5XsHUOF+tqmdA7KEzXLpiyaw0E= +github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= +github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= +github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 h1:lYIiVDtZnyTWlNwiAxLj0bbpTcx1BWCFhXjfsvmPdNc= 
+github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= +github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= +github.com/xdg/stringprep v1.0.1-0.20180714160509-73f8eece6fdc h1:vIp1tjhVogU0yBy7w96P027ewvNPeH6gzuNcoc+NReU= +github.com/xdg/stringprep v1.0.1-0.20180714160509-73f8eece6fdc/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= +github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5Qo6v2eYzo7kUS51QINcR5jNpbZS8= +github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6 h1:YdYsPAZ2pC6Tow/nPZOPQ96O3hm/ToAkGsPLzedXERk= +github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6/go.mod h1:ce1O1j6UtZfjr22oyGxGLbauSBp2YVXpARAosm7dHBg= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb h1:ZkM6LRnq40pR1Ox0hTHlnpkcOTuFIDQpZ1IN8rKKhX0= +github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb/go.mod h1:gqRgreBUhTSL0GeU64rtZ3Uq3wtjOa/TB2YfrtkCbVQ= +go.etcd.io/etcd v3.3.17+incompatible h1:g8iRku1SID8QAW8cDlV0L/PkZlw63LSiYEHYHoE6j/s= +go.etcd.io/etcd v3.3.17+incompatible/go.mod h1:yaeTdrJi5lOmYerz05bd8+V7KubZs8YSFZfzsF9A6aI= +go.uber.org/atomic v1.5.0 h1:OI5t8sDa1Or+q8AeE+yKeB/SDYioSHAgcVljj9JIETY= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/automaxprocs v1.3.0 h1:II28aZoGdaglS5vVNnspf28lnZpXScxtIozx1lAjdb0= +go.uber.org/automaxprocs v1.3.0/go.mod h1:9CWT6lKIep8U41DDaPiH6eFscnTyjfTANNQNx6LrIcA= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.4.0 h1:f3WCSC2KzAcBXGATIxAB1E2XuCpNU255wNKZ505qi3E= +go.uber.org/multierr v1.4.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.12.0 h1:dySoUQPFBGj6xwjmBzageVL8jGi8uxc6bEmJQjA06bw= +go.uber.org/zap v1.12.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200406173513-056763e48d71 h1:DOmugCavvUtnUD114C1Wh+UgTgQZ4pMLzXxi1pSt+/Y= +golang.org/x/crypto v0.0.0-20200406173513-056763e48d71/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= 
+golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.2.0 h1:KU7oHjnv3XNWfa5COkzUifxZmxp1TyI7ImMXqFxLwvQ= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191002035440-2ec189313ef0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e h1:3G+cUijn7XD+S4eJFddp53Pv7+slrESplyjG25HgL+k= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20171017063910-8dbc5d05d6ed/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200413165638-669c56c373c4 h1:opSr2sbRXk5X5/givKrrKj9HXxFpW2sdCiP8MJSKLQY= +golang.org/x/sys v0.0.0-20200413165638-669c56c373c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200410132612-ae9902aceb98 h1:ibc1eDGW5ajwA4qzFTj0WHlD9eofMe1gAre+A0a3Vhs= +golang.org/x/tools v0.0.0-20200410132612-ae9902aceb98/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors 
v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190927181202-20e1ac93f88c/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= +google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1 h1:aQktFqmDE2yjveXJlVIfslDFmFnUXSqG0i6KRcJAeMc= +google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/grpc v1.25.1 h1:wdKvqQk7IttEw92GoRyKG2IDrUIpgpj6H6m81yfeMW0= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0 h1:qdOKuR/EIArgaWNjetjgTzgVTAZ+S/WXVrq9HW9zimw= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/cheggaaa/pb.v1 v1.0.28 h1:n1tBJnnK2r7g9OW2btFH91V92STTUevLXYFb8gy9EMk= +gopkg.in/cheggaaa/pb.v1 v1.0.28/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.5 h1:ymVxjfMaHvXD8RqPRmzHHsB3VvucivSkIAvJFDI5O3c= +gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 
+honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= +honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +rsc.io/goversion v1.0.0 h1:/IhXBiai89TyuerPquiZZ39IQkTfAUbZB2awsyYZ/2c= +rsc.io/goversion v1.0.0/go.mod h1:Eih9y/uIBS3ulggl7KNJ09xGSLcuNaLgmvvqa07sgfo= +rsc.io/quote/v3 v3.1.0 h1:9JKUTTIUgS6kzR9mK1YuGKv6Nl+DijDNIc0ghT58FaY= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0 h1:7uVkIFmeBqHfdjD+gZwtXXI+RODJ2Wc4O7MPEh/QiW4= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= +sigs.k8s.io/yaml v1.2.0 h1:kr/MCeFWJWTwyaHoR9c8EjH9OumOmoF9YGiZd7lFm/Q= +sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= diff --git a/kq/config.go b/kq/config.go new file mode 100644 index 00000000..773887b4 --- /dev/null +++ b/kq/config.go @@ -0,0 +1,21 @@ +package kq + +import "zero/core/service" + +const ( + firstOffset = "first" + lastOffset = "last" +) + +type KqConf struct { + service.ServiceConf + Brokers []string + Group string + Topic string + Offset string `json:",options=first|last,default=last"` + NumConns int `json:",default=1"` + NumProducers int `json:",default=8"` + NumConsumers int `json:",default=8"` + MinBytes int `json:",default=10240"` // 10K + MaxBytes int `json:",default=10485760"` // 10M +} diff --git a/kq/pusher.go b/kq/pusher.go new file mode 100644 index 00000000..f95bab0f --- /dev/null +++ b/kq/pusher.go @@ -0,0 +1,101 @@ +package kq + +import ( + "context" + "strconv" + "time" + + "zero/core/executors" + "zero/core/logx" + + "github.com/segmentio/kafka-go" + "github.com/segmentio/kafka-go/snappy" +) + +type ( + PushOption func(options *chunkOptions) + + Pusher struct { + produer *kafka.Writer + topic string + executor *executors.ChunkExecutor + } + + chunkOptions struct { + chunkSize int + flushInterval time.Duration + } +) + +func NewPusher(addrs []string, topic string, opts ...PushOption) *Pusher { + producer := kafka.NewWriter(kafka.WriterConfig{ + Brokers: addrs, + Topic: topic, + Balancer: &kafka.LeastBytes{}, + CompressionCodec: snappy.NewCompressionCodec(), + }) + + pusher := &Pusher{ + produer: producer, + topic: topic, + } + pusher.executor = executors.NewChunkExecutor(func(tasks []interface{}) { + chunk := make([]kafka.Message, len(tasks)) + for i := range tasks { + chunk[i] = tasks[i].(kafka.Message) + } + if err := pusher.produer.WriteMessages(context.Background(), chunk...); err != nil { + logx.Error(err) + } + }, newOptions(opts)...) 
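For reference, a consumer config for the KqConf above might look like the JSON below; it uses only fields declared in KqConf plus the service Name, other ServiceConf fields are omitted, and the broker address, group, and topic are placeholders.

{
    "Name": "kq.consumer",
    "Brokers": ["localhost:9092"],
    "Group": "group-1",
    "Topic": "topic-1",
    "Offset": "first",
    "NumConsumers": 8
}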
+ + return pusher +} + +func (p *Pusher) Close() error { + return p.produer.Close() +} + +func (p *Pusher) Name() string { + return p.topic +} + +func (p *Pusher) Push(v string) error { + msg := kafka.Message{ + Key: []byte(strconv.FormatInt(time.Now().UnixNano(), 10)), + Value: []byte(v), + } + if p.executor != nil { + return p.executor.Add(msg, len(v)) + } else { + return p.produer.WriteMessages(context.Background(), msg) + } +} + +func WithChunkSize(chunkSize int) PushOption { + return func(options *chunkOptions) { + options.chunkSize = chunkSize + } +} + +func WithFlushInterval(interval time.Duration) PushOption { + return func(options *chunkOptions) { + options.flushInterval = interval + } +} + +func newOptions(opts []PushOption) []executors.ChunkOption { + var options chunkOptions + for _, opt := range opts { + opt(&options) + } + + var chunkOpts []executors.ChunkOption + if options.chunkSize > 0 { + chunkOpts = append(chunkOpts, executors.WithChunkBytes(options.chunkSize)) + } + if options.flushInterval > 0 { + chunkOpts = append(chunkOpts, executors.WithFlushInterval(options.flushInterval)) + } + return chunkOpts +} diff --git a/kq/queue.go b/kq/queue.go new file mode 100644 index 00000000..823ceda4 --- /dev/null +++ b/kq/queue.go @@ -0,0 +1,229 @@ +package kq + +import ( + "context" + "io" + "log" + "time" + + "zero/core/logx" + "zero/core/queue" + "zero/core/service" + "zero/core/stat" + "zero/core/threading" + "zero/core/timex" + + "github.com/segmentio/kafka-go" + _ "github.com/segmentio/kafka-go/gzip" + _ "github.com/segmentio/kafka-go/lz4" + _ "github.com/segmentio/kafka-go/snappy" +) + +const ( + defaultCommitInterval = time.Second + defaultMaxWait = time.Second +) + +type ( + ConsumeHandle func(key, value string) error + + ConsumeHandler interface { + Consume(key, value string) error + } + + queueOptions struct { + commitInterval time.Duration + maxWait time.Duration + metrics *stat.Metrics + } + + QueueOption func(*queueOptions) + + kafkaQueue struct { + c KqConf + consumer *kafka.Reader + handler ConsumeHandler + channel chan kafka.Message + producerRoutines *threading.RoutineGroup + consumerRoutines *threading.RoutineGroup + metrics *stat.Metrics + } + + kafkaQueues struct { + queues []queue.MessageQueue + group *service.ServiceGroup + } +) + +func MustNewQueue(c KqConf, handler ConsumeHandler, opts ...QueueOption) queue.MessageQueue { + q, err := NewQueue(c, handler, opts...) 
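A minimal producing sketch for the Pusher above; the broker address and topic are placeholders, and the chunk options are optional.

package main

import (
	"log"
	"time"

	"zero/kq"
)

func main() {
	// Batch writes into roughly 1MB chunks and flush at least once per second.
	pusher := kq.NewPusher([]string{"localhost:9092"}, "topic-1",
		kq.WithChunkSize(1<<20), kq.WithFlushInterval(time.Second))
	defer pusher.Close()

	if err := pusher.Push(`{"hello":"world"}`); err != nil {
		log.Fatal(err)
	}

	// Pushes go through the chunk executor, so a short-lived program needs to
	// give it a moment to flush; a long-running service would not.
	time.Sleep(2 * time.Second)
}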
+ if err != nil { + log.Fatal(err) + } + + return q +} + +func NewQueue(c KqConf, handler ConsumeHandler, opts ...QueueOption) (queue.MessageQueue, error) { + if err := c.SetUp(); err != nil { + return nil, err + } + + var options queueOptions + for _, opt := range opts { + opt(&options) + } + ensureQueueOptions(c, &options) + + if c.NumConns < 1 { + c.NumConns = 1 + } + q := kafkaQueues{ + group: service.NewServiceGroup(), + } + for i := 0; i < c.NumConns; i++ { + q.queues = append(q.queues, newKafkaQueue(c, handler, options)) + } + + return q, nil +} + +func newKafkaQueue(c KqConf, handler ConsumeHandler, options queueOptions) queue.MessageQueue { + var offset int64 + if c.Offset == firstOffset { + offset = kafka.FirstOffset + } else { + offset = kafka.LastOffset + } + consumer := kafka.NewReader(kafka.ReaderConfig{ + Brokers: c.Brokers, + GroupID: c.Group, + Topic: c.Topic, + StartOffset: offset, + MinBytes: c.MinBytes, // 10KB + MaxBytes: c.MaxBytes, // 10MB + MaxWait: options.maxWait, + CommitInterval: options.commitInterval, + }) + + return &kafkaQueue{ + c: c, + consumer: consumer, + handler: handler, + channel: make(chan kafka.Message), + producerRoutines: threading.NewRoutineGroup(), + consumerRoutines: threading.NewRoutineGroup(), + metrics: options.metrics, + } +} + +func (q *kafkaQueue) Start() { + q.startConsumers() + q.startProducers() + + q.producerRoutines.Wait() + close(q.channel) + q.consumerRoutines.Wait() +} + +func (q *kafkaQueue) Stop() { + q.consumer.Close() + logx.Close() +} + +func (q *kafkaQueue) consumeOne(key, val string) error { + startTime := timex.Now() + err := q.handler.Consume(key, val) + q.metrics.Add(stat.Task{ + Duration: timex.Since(startTime), + }) + return err +} + +func (q *kafkaQueue) startConsumers() { + for i := 0; i < q.c.NumConsumers; i++ { + q.consumerRoutines.Run(func() { + for msg := range q.channel { + if err := q.consumeOne(string(msg.Key), string(msg.Value)); err != nil { + logx.Errorf("Error on consuming: %s, error: %v", string(msg.Value), err) + } + } + }) + } +} + +func (q *kafkaQueue) startProducers() { + for i := 0; i < q.c.NumProducers; i++ { + q.producerRoutines.Run(func() { + for { + msg, err := q.consumer.ReadMessage(context.Background()) + // io.EOF means consumer closed + // io.ErrClosedPipe means committing messages on the consumer, + // kafka will refire the messages on uncommitted messages, ignore + if err == io.EOF || err == io.ErrClosedPipe { + return + } + if err != nil { + logx.Errorf("Error on reading mesage, %q", err.Error()) + continue + } + q.channel <- msg + } + }) + } +} + +func (q kafkaQueues) Start() { + for _, each := range q.queues { + q.group.Add(each) + } + q.group.Start() +} + +func (q kafkaQueues) Stop() { + q.group.Stop() +} + +func WithCommitInterval(interval time.Duration) QueueOption { + return func(options *queueOptions) { + options.commitInterval = interval + } +} + +func WithHandle(handle ConsumeHandle) ConsumeHandler { + return innerConsumeHandler{ + handle: handle, + } +} + +func WithMaxWait(wait time.Duration) QueueOption { + return func(options *queueOptions) { + options.maxWait = wait + } +} + +func WithMetrics(metrics *stat.Metrics) QueueOption { + return func(options *queueOptions) { + options.metrics = metrics + } +} + +type innerConsumeHandler struct { + handle ConsumeHandle +} + +func (ch innerConsumeHandler) Consume(k, v string) error { + return ch.handle(k, v) +} + +func ensureQueueOptions(c KqConf, options *queueOptions) { + if options.commitInterval == 0 { + options.commitInterval = 
defaultCommitInterval + } + if options.maxWait == 0 { + options.maxWait = defaultMaxWait + } + if options.metrics == nil { + options.metrics = stat.NewMetrics(c.Name) + } +} diff --git a/ngin/config.go b/ngin/config.go new file mode 100644 index 00000000..ebbc1f87 --- /dev/null +++ b/ngin/config.go @@ -0,0 +1,33 @@ +package ngin + +import ( + "time" + + "zero/core/service" +) + +type ( + PrivateKeyConf struct { + Fingerprint string + KeyFile string + } + + SignatureConf struct { + Strict bool `json:",default=false"` + Expiry time.Duration `json:",default=1h"` + PrivateKeys []PrivateKeyConf + } + + NgConf struct { + service.ServiceConf + Host string `json:",default=0.0.0.0"` + Port int + Verbose bool `json:",optional"` + MaxConns int `json:",default=10000"` + MaxBytes int64 `json:",default=1048576,range=[0:8388608]"` + // milliseconds + Timeout int64 `json:",default=3000"` + CpuThreshold int64 `json:",default=900,range=[0:1000]"` + Signature SignatureConf `json:",optional"` + } +) diff --git a/ngin/etc/config.json b/ngin/etc/config.json new file mode 100644 index 00000000..1b684709 --- /dev/null +++ b/ngin/etc/config.json @@ -0,0 +1,13 @@ +{ + "Name": "nging", + "Log": { + "Access": "logs/access.log", + "Error": "logs/error.log", + "Stat": "logs/stat.log" + }, + "Host": "127.0.0.1", + "Port": 1111, + "Timeout": 1000, + "Verbose": 0, + "Develop": 1 +} diff --git a/ngin/ngin.go b/ngin/ngin.go new file mode 100644 index 00000000..efbe5b78 --- /dev/null +++ b/ngin/ngin.go @@ -0,0 +1,170 @@ +package ngin + +import ( + "log" + "net/http" + + "zero/core/httphandler" + "zero/core/httprouter" + "zero/core/logx" +) + +type ( + runOptions struct { + start func(*server) error + } + + RunOption func(*Engine) + + Engine struct { + srv *server + opts runOptions + } +) + +func MustNewEngine(c NgConf, opts ...RunOption) *Engine { + engine, err := NewEngine(c, opts...) + if err != nil { + log.Fatal(err) + } + + return engine +} + +func NewEngine(c NgConf, opts ...RunOption) (*Engine, error) { + if err := c.SetUp(); err != nil { + return nil, err + } + + engine := &Engine{ + srv: newServer(c), + opts: runOptions{ + start: func(srv *server) error { + return srv.Start() + }, + }, + } + + for _, opt := range opts { + opt(engine) + } + + return engine, nil +} + +func (e *Engine) AddRoutes(rs []Route, opts ...RouteOption) { + r := featuredRoutes{ + routes: rs, + } + for _, opt := range opts { + opt(&r) + } + e.srv.AddRoutes(r) +} + +func (e *Engine) AddRoute(r Route, opts ...RouteOption) { + e.AddRoutes([]Route{r}, opts...) +} + +func (e *Engine) Start() { + handleError(e.opts.start(e.srv)) +} + +func (e *Engine) Stop() { + logx.Close() +} + +func (e *Engine) Use(middleware Middleware) { + e.srv.use(middleware) +} + +func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { + return func(handle http.HandlerFunc) http.HandlerFunc { + return handler(handle).ServeHTTP + } +} + +func WithJwt(secret string) RouteOption { + return func(r *featuredRoutes) { + validateSecret(secret) + r.jwt.enabled = true + r.jwt.secret = secret + } +} + +func WithJwtTransition(secret, prevSecret string) RouteOption { + return func(r *featuredRoutes) { + // why not validate prevSecret, because prevSecret is an already used one, + // even it not meet our requirement, we still need to allow the transition. 
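A minimal wiring sketch for the Engine above, assuming an NgConf loaded from a config.json like the one shown; the route path and handler body are illustrative only.

package main

import (
	"flag"
	"io"
	"net/http"

	"zero/core/conf"
	"zero/ngin"
)

var configFile = flag.String("f", "etc/config.json", "the config file")

func main() {
	flag.Parse()

	var c ngin.NgConf
	conf.MustLoad(*configFile, &c)

	engine := ngin.MustNewEngine(c)
	defer engine.Stop()

	// Register a simple GET route; middleware could be added via engine.Use
	// or ngin.WithMiddleware as shown below.
	engine.AddRoute(ngin.Route{
		Method: http.MethodGet,
		Path:   "/ping",
		Handler: func(w http.ResponseWriter, r *http.Request) {
			io.WriteString(w, "pong")
		},
	})

	engine.Start()
}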
+ validateSecret(secret) + r.jwt.enabled = true + r.jwt.secret = secret + r.jwt.prevSecret = prevSecret + } +} + +func WithMiddleware(middleware Middleware, rs ...Route) []Route { + routes := make([]Route, len(rs)) + + for i := range rs { + route := rs[i] + routes[i] = Route{ + Method: route.Method, + Path: route.Path, + Handler: middleware(route.Handler), + } + } + + return routes +} + +func WithPriority() RouteOption { + return func(r *featuredRoutes) { + r.priority = true + } +} + +func WithRouter(router httprouter.Router) RunOption { + return func(engine *Engine) { + engine.opts.start = func(srv *server) error { + return srv.StartWithRouter(router) + } + } +} + +func WithSignature(signature SignatureConf) RouteOption { + return func(r *featuredRoutes) { + r.signature.enabled = true + r.signature.Strict = signature.Strict + r.signature.Expiry = signature.Expiry + r.signature.PrivateKeys = signature.PrivateKeys + } +} + +func WithUnauthorizedCallback(callback httphandler.UnauthorizedCallback) RunOption { + return func(engine *Engine) { + engine.srv.SetUnauthorizedCallback(callback) + } +} + +func WithUnsignedCallback(callback httphandler.UnsignedCallback) RunOption { + return func(engine *Engine) { + engine.srv.SetUnsignedCallback(callback) + } +} + +func handleError(err error) { + // ErrServerClosed means the server is closed manually + if err == nil || err == http.ErrServerClosed { + return + } + + logx.Error(err) + panic(err) +} + +func validateSecret(secret string) { + if len(secret) < 8 { + panic("secret's length can't be less than 8") + } +} diff --git a/ngin/ngin_test.go b/ngin/ngin_test.go new file mode 100644 index 00000000..3834c951 --- /dev/null +++ b/ngin/ngin_test.go @@ -0,0 +1,71 @@ +package ngin + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "zero/core/httprouter" + "zero/core/httpx" + + "github.com/stretchr/testify/assert" +) + +func TestWithMiddleware(t *testing.T) { + m := make(map[string]string) + router := httprouter.NewPatRouter() + handler := func(w http.ResponseWriter, r *http.Request) { + var v struct { + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + } + + err := httpx.Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode)) + assert.Nil(t, err) + } + rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var v struct { + Name string `path:"name"` + Year string `path:"year"` + } + assert.Nil(t, httpx.ParsePath(r, &v)) + m[v.Name] = v.Year + next.ServeHTTP(w, r) + } + }, Route{ + Method: http.MethodGet, + Path: "/first/:name/:year", + Handler: handler, + }, Route{ + Method: http.MethodGet, + Path: "/second/:name/:year", + Handler: handler, + }) + + urls := []string{ + "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000", + "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000", + } + for _, route := range rs { + assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler)) + } + for _, url := range urls { + r, err := http.NewRequest(http.MethodGet, url, nil) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "whatever:200000", rr.Body.String()) + } + + assert.EqualValues(t, map[string]string{ + "kevin": "2017", + "wan": "2020", + }, m) +} diff --git a/ngin/server.go b/ngin/server.go new file mode 100644 index 00000000..6caec1d6 --- /dev/null +++ b/ngin/server.go @@ -0,0 +1,218 @@ +package ngin + +import ( + 
"errors" + "fmt" + "net/http" + "time" + + "zero/core/codec" + "zero/core/httphandler" + "zero/core/httprouter" + "zero/core/httpserver" + "zero/core/load" + "zero/core/stat" + + "github.com/justinas/alice" +) + +// use 1000m to represent 100% +const topCpuUsage = 1000 + +var ErrSignatureConfig = errors.New("bad config for Signature") + +type ( + Middleware func(next http.HandlerFunc) http.HandlerFunc + + server struct { + conf NgConf + routes []featuredRoutes + unauthorizedCallback httphandler.UnauthorizedCallback + unsignedCallback httphandler.UnsignedCallback + middlewares []Middleware + shedder load.Shedder + priorityShedder load.Shedder + } +) + +func newServer(c NgConf) *server { + srv := &server{ + conf: c, + } + if c.CpuThreshold > 0 { + srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold)) + srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold( + (c.CpuThreshold + topCpuUsage) >> 1)) + } + + return srv +} + +func (s *server) AddRoutes(r featuredRoutes) { + s.routes = append(s.routes, r) +} + +func (s *server) SetUnauthorizedCallback(callback httphandler.UnauthorizedCallback) { + s.unauthorizedCallback = callback +} + +func (s *server) SetUnsignedCallback(callback httphandler.UnsignedCallback) { + s.unsignedCallback = callback +} + +func (s *server) Start() error { + return s.StartWithRouter(httprouter.NewPatRouter()) +} + +func (s *server) StartWithRouter(router httprouter.Router) error { + if err := s.bindRoutes(router); err != nil { + return err + } + + return httpserver.StartHttp(s.conf.Host, s.conf.Port, router) +} + +func (s *server) appendAuthHandler(fr featuredRoutes, chain alice.Chain, + verifier func(alice.Chain) alice.Chain) alice.Chain { + if fr.jwt.enabled { + if len(fr.jwt.prevSecret) == 0 { + chain = chain.Append(httphandler.Authorize(fr.jwt.secret, + httphandler.WithUnauthorizedCallback(s.unauthorizedCallback))) + } else { + chain = chain.Append(httphandler.Authorize(fr.jwt.secret, + httphandler.WithPrevSecret(fr.jwt.prevSecret), + httphandler.WithUnauthorizedCallback(s.unauthorizedCallback))) + } + } + + return verifier(chain) +} + +func (s *server) bindFeaturedRoutes(router httprouter.Router, fr featuredRoutes, metrics *stat.Metrics) error { + verifier, err := s.signatureVerifier(fr.signature) + if err != nil { + return err + } + + for _, route := range fr.routes { + if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil { + return err + } + } + + return nil +} + +func (s *server) bindRoute(fr featuredRoutes, router httprouter.Router, metrics *stat.Metrics, + route Route, verifier func(chain alice.Chain) alice.Chain) error { + chain := alice.New( + httphandler.TracingHandler, + s.getLogHandler(), + httphandler.MaxConns(s.conf.MaxConns), + httphandler.BreakerHandler(route.Method, route.Path, metrics), + httphandler.SheddingHandler(s.getShedder(fr.priority), metrics), + httphandler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond), + httphandler.RecoverHandler, + httphandler.MetricHandler(metrics), + httphandler.PromMetricHandler(route.Path), + httphandler.MaxBytesHandler(s.conf.MaxBytes), + httphandler.GunzipHandler, + ) + chain = s.appendAuthHandler(fr, chain, verifier) + + for _, middleware := range s.middlewares { + chain = chain.Append(convertMiddleware(middleware)) + } + handle := chain.ThenFunc(route.Handler) + + return router.Handle(route.Method, route.Path, handle) +} + +func (s *server) bindRoutes(router httprouter.Router) error { + metrics := s.createMetrics() + + for _, fr := range 
s.routes { + if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil { + return err + } + } + + return nil +} + +func (s *server) createMetrics() *stat.Metrics { + var metrics *stat.Metrics + + if len(s.conf.Name) > 0 { + metrics = stat.NewMetrics(s.conf.Name) + } else { + metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port)) + } + + return metrics +} + +func (s *server) getLogHandler() func(http.Handler) http.Handler { + if s.conf.Verbose { + return httphandler.DetailedLogHandler + } else { + return httphandler.LogHandler + } +} + +func (s *server) getShedder(priority bool) load.Shedder { + if priority && s.priorityShedder != nil { + return s.priorityShedder + } + return s.shedder +} + +func (s *server) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) { + if !signature.enabled { + return func(chain alice.Chain) alice.Chain { + return chain + }, nil + } + + if len(signature.PrivateKeys) == 0 { + if signature.Strict { + return nil, ErrSignatureConfig + } else { + return func(chain alice.Chain) alice.Chain { + return chain + }, nil + } + } + + decrypters := make(map[string]codec.RsaDecrypter) + for _, key := range signature.PrivateKeys { + fingerprint := key.Fingerprint + file := key.KeyFile + decrypter, err := codec.NewRsaDecrypter(file) + if err != nil { + return nil, err + } + + decrypters[fingerprint] = decrypter + } + + return func(chain alice.Chain) alice.Chain { + if s.unsignedCallback != nil { + return chain.Append(httphandler.ContentSecurityHandler( + decrypters, signature.Expiry, signature.Strict, s.unsignedCallback)) + } else { + return chain.Append(httphandler.ContentSecurityHandler( + decrypters, signature.Expiry, signature.Strict)) + } + }, nil +} + +func (s *server) use(middleware Middleware) { + s.middlewares = append(s.middlewares, middleware) +} + +func convertMiddleware(ware Middleware) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(ware(next.ServeHTTP)) + } +} diff --git a/ngin/types.go b/ngin/types.go new file mode 100644 index 00000000..f237b8ed --- /dev/null +++ b/ngin/types.go @@ -0,0 +1,31 @@ +package ngin + +import "net/http" + +type ( + Route struct { + Method string + Path string + Handler http.HandlerFunc + } + + jwtSetting struct { + enabled bool + secret string + prevSecret string + } + + signatureSetting struct { + SignatureConf + enabled bool + } + + featuredRoutes struct { + priority bool + jwt jwtSetting + signature signatureSetting + routes []Route + } + + RouteOption func(r *featuredRoutes) +) diff --git a/rpcx/auth/auth.go b/rpcx/auth/auth.go new file mode 100644 index 00000000..44060a53 --- /dev/null +++ b/rpcx/auth/auth.go @@ -0,0 +1,74 @@ +package auth + +import ( + "context" + "time" + + "zero/core/collection" + "zero/core/stores/redis" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +const defaultExpiration = 5 * time.Minute + +type Authenticator struct { + store *redis.Redis + key string + cache *collection.Cache + strict bool +} + +func NewAuthenticator(store *redis.Redis, key string, strict bool) (*Authenticator, error) { + cache, err := collection.NewCache(defaultExpiration) + if err != nil { + return nil, err + } + + return &Authenticator{ + store: store, + key: key, + cache: cache, + strict: strict, + }, nil +} + +func (a *Authenticator) Authenticate(ctx context.Context) error { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return 
status.Error(codes.Unauthenticated, missingMetadata) + } + + apps, tokens := md[appKey], md[tokenKey] + if len(apps) == 0 || len(tokens) == 0 { + return status.Error(codes.Unauthenticated, missingMetadata) + } + + app, token := apps[0], tokens[0] + if len(app) == 0 || len(token) == 0 { + return status.Error(codes.Unauthenticated, missingMetadata) + } + + return a.validate(app, token) +} + +func (a *Authenticator) validate(app, token string) error { + expect, err := a.cache.Take(app, func() (interface{}, error) { + return a.store.Hget(a.key, app) + }) + if err != nil { + if a.strict { + return status.Error(codes.Internal, err.Error()) + } else { + return nil + } + } + + if token != expect { + return status.Error(codes.Unauthenticated, accessDenied) + } + + return nil +} diff --git a/rpcx/auth/credential.go b/rpcx/auth/credential.go new file mode 100644 index 00000000..5855113f --- /dev/null +++ b/rpcx/auth/credential.go @@ -0,0 +1,47 @@ +package auth + +import ( + "context" + + "google.golang.org/grpc/metadata" +) + +type Credential struct { + App string + Token string +} + +func (c *Credential) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { + return map[string]string{ + appKey: c.App, + tokenKey: c.Token, + }, nil +} + +func (c *Credential) RequireTransportSecurity() bool { + return false +} + +func ParseCredential(ctx context.Context) Credential { + var credential Credential + + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return credential + } + + apps, tokens := md[appKey], md[tokenKey] + if len(apps) == 0 || len(tokens) == 0 { + return credential + } + + app, token := apps[0], tokens[0] + if len(app) == 0 || len(token) == 0 { + return credential + } + + credential.App = app + credential.Token = token + + return credential +} diff --git a/rpcx/auth/vars.go b/rpcx/auth/vars.go new file mode 100644 index 00000000..6d85c5c3 --- /dev/null +++ b/rpcx/auth/vars.go @@ -0,0 +1,9 @@ +package auth + +const ( + appKey = "app" + tokenKey = "token" + + accessDenied = "access denied" + missingMetadata = "app/token required" +) diff --git a/rpcx/client.go b/rpcx/client.go new file mode 100644 index 00000000..5b1540c6 --- /dev/null +++ b/rpcx/client.go @@ -0,0 +1,69 @@ +package rpcx + +import ( + "log" + "time" + + "zero/core/discov" + "zero/core/rpc" + "zero/rpcx/auth" + + "google.golang.org/grpc" +) + +type RpcClient struct { + client rpc.Client +} + +func MustNewClient(c RpcClientConf, options ...rpc.ClientOption) *RpcClient { + cli, err := NewClient(c, options...) + if err != nil { + log.Fatal(err) + } + + return cli +} + +func NewClient(c RpcClientConf, options ...rpc.ClientOption) (*RpcClient, error) { + var opts []rpc.ClientOption + if c.HasCredential() { + opts = append(opts, rpc.WithDialOption(grpc.WithPerRPCCredentials(&auth.Credential{ + App: c.App, + Token: c.Token, + }))) + } + if c.Timeout > 0 { + opts = append(opts, rpc.WithTimeout(time.Duration(c.Timeout)*time.Millisecond)) + } + opts = append(opts, options...) + + var client rpc.Client + var err error + if len(c.Server) > 0 { + client, err = rpc.NewDirectClient(c.Server, opts...) + } else if err = c.Etcd.Validate(); err == nil { + client, err = rpc.NewRoundRobinRpcClient(c.Etcd.Hosts, c.Etcd.Key, opts...) 
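+ // In this branch the server list comes from etcd discovery and Next() rotates
+ // over the discovered connections. A hedged caller sketch (hosts and key are illustrative):
+ //   cli := MustNewClient(NewEtcdClientConf([]string{"localhost:2379"}, "rpcx", "", ""))
+ //   conn, ok := cli.Next()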
+ } + if err != nil { + return nil, err + } + + return &RpcClient{ + client: client, + }, nil +} + +func NewClientNoAuth(c discov.EtcdConf) (*RpcClient, error) { + client, err := rpc.NewRoundRobinRpcClient(c.Hosts, c.Key) + if err != nil { + return nil, err + } + + return &RpcClient{ + client: client, + }, nil +} + +func (rc *RpcClient) Next() (*grpc.ClientConn, bool) { + return rc.client.Next() +} diff --git a/rpcx/config.go b/rpcx/config.go new file mode 100644 index 00000000..46b91c03 --- /dev/null +++ b/rpcx/config.go @@ -0,0 +1,67 @@ +package rpcx + +import ( + "zero/core/discov" + "zero/core/service" + "zero/core/stores/redis" +) + +type ( + RpcServerConf struct { + service.ServiceConf + ListenOn string + Etcd discov.EtcdConf `json:",optional"` + Auth bool `json:",optional"` + Redis redis.RedisKeyConf `json:",optional"` + StrictControl bool `json:",optional"` + // pending forever is not allowed + // never set it to 0, if zero, the underlying will set to 2s automatically + Timeout int64 `json:",default=2000"` + CpuThreshold int64 `json:",default=900,range=[0:1000]"` + } + + RpcClientConf struct { + Etcd discov.EtcdConf `json:",optional"` + Server string `json:",optional=!Etcd"` + App string `json:",optional"` + Token string `json:",optional"` + Timeout int64 `json:",optional"` + } +) + +func NewDirectClientConf(server, app, token string) RpcClientConf { + return RpcClientConf{ + Server: server, + App: app, + Token: token, + } +} + +func NewEtcdClientConf(hosts []string, key, app, token string) RpcClientConf { + return RpcClientConf{ + Etcd: discov.EtcdConf{ + Hosts: hosts, + Key: key, + }, + App: app, + Token: token, + } +} + +func (sc RpcServerConf) HasEtcd() bool { + return len(sc.Etcd.Hosts) > 0 && len(sc.Etcd.Key) > 0 +} + +func (sc RpcServerConf) Validate() error { + if sc.Auth { + if err := sc.Redis.Validate(); err != nil { + return err + } + } + + return nil +} + +func (cc RpcClientConf) HasCredential() bool { + return len(cc.App) > 0 && len(cc.Token) > 0 +} diff --git a/rpcx/etc/config.json b/rpcx/etc/config.json new file mode 100644 index 00000000..c4707aef --- /dev/null +++ b/rpcx/etc/config.json @@ -0,0 +1,20 @@ +{ + "Log": { + "Access": "logs/access.log", + "Error": "logs/error.log", + "Stat": "logs/stat.log" + }, + "MetricsUrl": "http://localhost:2222/add", + "ListenOn": "localhost:3456", + "Etcd": { + "Hosts": [ + "localhost:2379" + ], + "Key": "rpcx" + }, + "Redis": { + "Host": "localhost:6379", + "Type": "node", + "Key": "apps" + } +} diff --git a/rpcx/interceptors/authinterceptor.go b/rpcx/interceptors/authinterceptor.go new file mode 100644 index 00000000..efac6d98 --- /dev/null +++ b/rpcx/interceptors/authinterceptor.go @@ -0,0 +1,31 @@ +package interceptors + +import ( + "context" + + "zero/rpcx/auth" + + "google.golang.org/grpc" +) + +func StreamAuthorizeInterceptor(authenticator *auth.Authenticator) grpc.StreamServerInterceptor { + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, + handler grpc.StreamHandler) error { + if err := authenticator.Authenticate(stream.Context()); err != nil { + return err + } + + return handler(srv, stream) + } +} + +func UnaryAuthorizeInterceptor(authenticator *auth.Authenticator) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (interface{}, error) { + if err := authenticator.Authenticate(ctx); err != nil { + return nil, err + } + + return handler(ctx, req) + } +} diff --git a/rpcx/proxy.go b/rpcx/proxy.go new 
file mode 100644 index 00000000..f720d49c --- /dev/null +++ b/rpcx/proxy.go @@ -0,0 +1,66 @@ +package rpcx + +import ( + "context" + "sync" + + "zero/core/rpc" + "zero/core/syncx" + "zero/rpcx/auth" + + "google.golang.org/grpc" +) + +type RpcProxy struct { + backend string + clients map[string]*RpcClient + options []rpc.ClientOption + sharedCalls syncx.SharedCalls + lock sync.Mutex +} + +func NewRpcProxy(backend string, opts ...rpc.ClientOption) *RpcProxy { + return &RpcProxy{ + backend: backend, + clients: make(map[string]*RpcClient), + options: opts, + sharedCalls: syncx.NewSharedCalls(), + } +} + +func (p *RpcProxy) TakeConn(ctx context.Context) (*grpc.ClientConn, error) { + cred := auth.ParseCredential(ctx) + key := cred.App + "/" + cred.Token + val, err := p.sharedCalls.Do(key, func() (interface{}, error) { + p.lock.Lock() + client, ok := p.clients[key] + p.lock.Unlock() + if ok { + return client, nil + } + + client, err := NewClient(RpcClientConf{ + Server: p.backend, + App: cred.App, + Token: cred.Token, + }, p.options...) + if err != nil { + return nil, err + } + + p.lock.Lock() + p.clients[key] = client + p.lock.Unlock() + return client, nil + }) + if err != nil { + return nil, err + } + + conn, ok := val.(*RpcClient).Next() + if !ok { + return nil, grpc.ErrServerStopped + } + + return conn, nil +} diff --git a/rpcx/server.go b/rpcx/server.go new file mode 100644 index 00000000..ffa28df4 --- /dev/null +++ b/rpcx/server.go @@ -0,0 +1,126 @@ +package rpcx + +import ( + "log" + "os" + "strings" + "time" + + "zero/core/load" + "zero/core/logx" + "zero/core/netx" + "zero/core/rpc" + "zero/core/rpc/serverinterceptors" + "zero/core/stat" + "zero/rpcx/auth" + "zero/rpcx/interceptors" +) + +const envPodIp = "POD_IP" + +type RpcServer struct { + server rpc.Server + register rpc.RegisterFn +} + +func MustNewServer(c RpcServerConf, register rpc.RegisterFn) *RpcServer { + server, err := NewServer(c, register) + if err != nil { + log.Fatal(err) + } + + return server +} + +func NewServer(c RpcServerConf, register rpc.RegisterFn) (*RpcServer, error) { + var err error + if err = c.Validate(); err != nil { + return nil, err + } + + var server rpc.Server + metrics := stat.NewMetrics(c.ListenOn) + if c.HasEtcd() { + listenOn := figureOutListenOn(c.ListenOn) + server, err = rpc.NewRpcPubServer(c.Etcd.Hosts, c.Etcd.Key, listenOn, rpc.WithMetrics(metrics)) + if err != nil { + return nil, err + } + } else { + server = rpc.NewRpcServer(c.ListenOn, rpc.WithMetrics(metrics)) + } + + server.SetName(c.Name) + if err = setupInterceptors(server, c, metrics); err != nil { + return nil, err + } + + rpcServer := &RpcServer{ + server: server, + register: register, + } + if err = c.SetUp(); err != nil { + return nil, err + } + + return rpcServer, nil +} + +func (rs *RpcServer) Start() { + if err := rs.server.Start(rs.register); err != nil { + logx.Error(err) + panic(err) + } +} + +func (rs *RpcServer) Stop() { + logx.Close() +} + +func figureOutListenOn(listenOn string) string { + fields := strings.Split(listenOn, ":") + if len(fields) == 0 { + return listenOn + } + + host := fields[0] + if len(host) > 0 && host != "0.0.0.0" { + return listenOn + } + + ip := os.Getenv(envPodIp) + if len(ip) == 0 { + ip = netx.InternalIp() + } + if len(ip) == 0 { + return listenOn + } else { + return strings.Join(append([]string{ip}, fields[1:]...), ":") + } +} + +func setupInterceptors(server rpc.Server, c RpcServerConf, metrics *stat.Metrics) error { + if c.CpuThreshold > 0 { + shedder := 
load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold)) + server.AddUnaryInterceptors(serverinterceptors.UnarySheddingInterceptor(shedder, metrics)) + } + + if c.Timeout > 0 { + server.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor( + time.Duration(c.Timeout) * time.Millisecond)) + } + + server.AddUnaryInterceptors(serverinterceptors.UnaryTracingInterceptor(c.Name)) + + if c.Auth { + authenticator, err := auth.NewAuthenticator(c.Redis.NewRedis(), c.Redis.Key, c.StrictControl) + if err != nil { + return err + } + + server.AddStreamInterceptors(interceptors.StreamAuthorizeInterceptor(authenticator)) + server.AddUnaryInterceptors(interceptors.UnaryAuthorizeInterceptor(authenticator)) + } + + return nil +} diff --git a/rq/config.go b/rq/config.go new file mode 100644 index 00000000..0a4223ae --- /dev/null +++ b/rq/config.go @@ -0,0 +1,18 @@ +package rq + +import ( + "zero/core/discov" + "zero/core/service" + "zero/core/stores/redis" +) + +type RmqConf struct { + service.ServiceConf + Redis redis.RedisKeyConf + Etcd discov.EtcdConf `json:",optional"` + NumProducers int `json:",optional"` + NumConsumers int `json:",optional"` + Timeout int64 `json:",optional"` + DropBefore int64 `json:",optional"` + ServerSensitive bool `json:",default=false"` +} diff --git a/rq/constant/const.go b/rq/constant/const.go new file mode 100644 index 00000000..ae4d22e9 --- /dev/null +++ b/rq/constant/const.go @@ -0,0 +1,7 @@ +package constant + +const ( + Delimeter = "/" + ServerSensitivePrefix = '*' + TimedQueueType = "timed" +) diff --git a/rq/etc/config.json b/rq/etc/config.json new file mode 100644 index 00000000..f8951987 --- /dev/null +++ b/rq/etc/config.json @@ -0,0 +1,19 @@ +{ + "Log": { + "Access": "logs/access.log", + "Error": "logs/error.log", + "Stat": "logs/stat.log" + }, + "MetricsUrl": "http://localhost:2222/add", + "Redis": { + "Host": "localhost:6379", + "Type": "node", + "Key": "reqs" + }, + "Etcd": { + "Hosts": [ + "localhost:2379" + ], + "EtcdKey": "rq" + } +} diff --git a/rq/hashchange.go b/rq/hashchange.go new file mode 100644 index 00000000..2f68b5e2 --- /dev/null +++ b/rq/hashchange.go @@ -0,0 +1,39 @@ +package rq + +import ( + "math/rand" + + "zero/core/hash" +) + +type HashChange struct { + id int64 + oldHash *hash.ConsistentHash + newHash *hash.ConsistentHash +} + +func NewHashChange(oldHash, newHash *hash.ConsistentHash) HashChange { + return HashChange{ + id: rand.Int63(), + oldHash: oldHash, + newHash: newHash, + } +} + +func (hc HashChange) GetId() int64 { + return hc.id +} + +func (hc HashChange) ShallEvict(key interface{}) bool { + oldTarget, oldOk := hc.oldHash.Get(key) + if !oldOk { + return false + } + + newTarget, newOk := hc.newHash.Get(key) + if !newOk { + return false + } + + return oldTarget != newTarget +} diff --git a/rq/pusher.go b/rq/pusher.go new file mode 100644 index 00000000..14fe4d79 --- /dev/null +++ b/rq/pusher.go @@ -0,0 +1,446 @@ +package rq + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "zero/core/discov" + "zero/core/errorx" + "zero/core/jsonx" + "zero/core/lang" + "zero/core/logx" + "zero/core/queue" + "zero/core/redisqueue" + "zero/core/stores/redis" + "zero/core/threading" + "zero/rq/constant" + "zero/rq/update" +) + +const ( + retryTimes = 3 + etcdRedisFields = 4 +) + +var ErrPusherTypeError = errors.New("not a QueuePusher instance") + +type ( + KeyFn func(string) (key, payload string, err error) + KeysFn func(string) (ctx context.Context, keys []string, err error) + AssembleFn func(context.Context, 
[]string) (payload string, err error) + PusherOption func(*Pusher) error + + // just push once or do it retryTimes, it's a choice. + // because only when at least a server is alive, and + // pushing to the server failed, we'll return with an error + // if waken up, but the server is going down very quickly, + // we're going to wait again. so it's safe to push once. + pushStrategy interface { + addListener(listener discov.Listener) + push(string) error + } + + batchConsistentStrategy struct { + keysFn KeysFn + assembleFn AssembleFn + subClient *discov.BatchConsistentSubClient + } + + consistentStrategy struct { + keyFn KeyFn + subClient *discov.ConsistentSubClient + } + + roundRobinStrategy struct { + subClient *discov.RoundRobinSubClient + } + + serverListener struct { + updater *update.IncrementalUpdater + } + + Pusher struct { + name string + endpoints []string + key string + failovers sync.Map + strategy pushStrategy + serverSensitive bool + } +) + +func NewPusher(endpoints []string, key string, opts ...PusherOption) (*Pusher, error) { + pusher := &Pusher{ + name: getName(key), + endpoints: endpoints, + key: key, + } + + if len(opts) == 0 { + opts = []PusherOption{WithRoundRobinStrategy()} + } + + for _, opt := range opts { + if err := opt(pusher); err != nil { + return nil, err + } + } + + if pusher.serverSensitive { + listener := new(serverListener) + listener.updater = update.NewIncrementalUpdater(listener.update) + pusher.strategy.addListener(listener) + } + + return pusher, nil +} + +func (pusher *Pusher) Name() string { + return pusher.name +} + +func (pusher *Pusher) Push(message string) error { + return pusher.strategy.push(message) +} + +func (pusher *Pusher) close(server string, conn interface{}) error { + logx.Errorf("dropped redis node: %s", server) + + return pusher.failover(server) +} + +func (pusher *Pusher) dial(server string) (interface{}, error) { + pusher.failovers.Delete(server) + + p, err := newPusher(server) + if err != nil { + return nil, err + } + + logx.Infof("new redis node: %s", server) + + return p, nil +} + +func (pusher *Pusher) failover(server string) error { + pusher.failovers.Store(server, lang.Placeholder) + + rds, key, option, err := newRedisWithKey(server) + if err != nil { + return err + } + + threading.GoSafe(func() { + defer pusher.failovers.Delete(server) + + for { + _, ok := pusher.failovers.Load(server) + if !ok { + logx.Infof("redis queue (%s) revived", server) + return + } + + message, err := rds.Lpop(key) + if err != nil { + logx.Error(err) + return + } + + if len(message) == 0 { + logx.Infof("repush redis queue (%s) done", server) + return + } + + if option == constant.TimedQueueType { + message, err = unwrapTimedMessage(message) + if err != nil { + logx.Errorf("invalid timedMessage: %s, error: %s", message, err.Error()) + return + } + } + + if err = pusher.strategy.push(message); err != nil { + logx.Error(err) + return + } + } + }) + + return nil +} + +func UnmarshalPusher(server string) (queue.QueuePusher, error) { + store, key, option, err := newRedisWithKey(server) + if err != nil { + return nil, err + } + + if option == constant.TimedQueueType { + return redisqueue.NewPusher(store, key, redisqueue.WithTime()), nil + } + + return redisqueue.NewPusher(store, key), nil +} + +func WithBatchConsistentStrategy(keysFn KeysFn, assembleFn AssembleFn, opts ...discov.BalanceOption) PusherOption { + return func(pusher *Pusher) error { + subClient, err := discov.NewBatchConsistentSubClient(pusher.endpoints, pusher.key, pusher.dial, + pusher.close, 
opts...) + if err != nil { + return err + } + + pusher.strategy = batchConsistentStrategy{ + keysFn: keysFn, + assembleFn: assembleFn, + subClient: subClient, + } + + return nil + } +} + +func WithConsistentStrategy(keyFn KeyFn, opts ...discov.BalanceOption) PusherOption { + return func(pusher *Pusher) error { + subClient, err := discov.NewConsistentSubClient(pusher.endpoints, pusher.key, pusher.dial, pusher.close, opts...) + if err != nil { + return err + } + + pusher.strategy = consistentStrategy{ + keyFn: keyFn, + subClient: subClient, + } + + return nil + } +} + +func WithRoundRobinStrategy() PusherOption { + return func(pusher *Pusher) error { + subClient, err := discov.NewRoundRobinSubClient(pusher.endpoints, pusher.key, pusher.dial, pusher.close) + if err != nil { + return err + } + + pusher.strategy = roundRobinStrategy{ + subClient: subClient, + } + + return nil + } +} + +func WithServerSensitive() PusherOption { + return func(pusher *Pusher) error { + pusher.serverSensitive = true + return nil + } +} + +func (bcs batchConsistentStrategy) addListener(listener discov.Listener) { + bcs.subClient.AddListener(listener) +} + +func (bcs batchConsistentStrategy) balance(keys []string) map[interface{}][]string { + // we need to make sure the servers are available, otherwise wait forever + for { + if mapping, ok := bcs.subClient.Next(keys); ok { + return mapping + } else { + bcs.subClient.WaitForServers() + // make sure we don't flood logs too much in extreme conditions + time.Sleep(time.Second) + } + } +} + +func (bcs batchConsistentStrategy) push(message string) error { + ctx, keys, err := bcs.keysFn(message) + if err != nil { + return err + } + + var batchError errorx.BatchError + mapping := bcs.balance(keys) + for conn, connKeys := range mapping { + payload, err := bcs.assembleFn(ctx, connKeys) + if err != nil { + batchError.Add(err) + continue + } + + for i := 0; i < retryTimes; i++ { + if err = bcs.pushOnce(conn, payload); err != nil { + batchError.Add(err) + } else { + break + } + } + } + + return batchError.Err() +} + +func (bcs batchConsistentStrategy) pushOnce(server interface{}, payload string) error { + pusher, ok := server.(queue.QueuePusher) + if ok { + return pusher.Push(payload) + } else { + return ErrPusherTypeError + } +} + +func (cs consistentStrategy) addListener(listener discov.Listener) { + cs.subClient.AddListener(listener) +} + +func (cs consistentStrategy) push(message string) error { + var batchError errorx.BatchError + + key, payload, err := cs.keyFn(message) + if err != nil { + return err + } + + for i := 0; i < retryTimes; i++ { + if err = cs.pushOnce(key, payload); err != nil { + batchError.Add(err) + } else { + return nil + } + } + + return batchError.Err() +} + +func (cs consistentStrategy) pushOnce(key, payload string) error { + // we need to make sure the servers are available, otherwise wait forever + for { + if server, ok := cs.subClient.Next(key); ok { + pusher, ok := server.(queue.QueuePusher) + if ok { + return pusher.Push(payload) + } else { + return ErrPusherTypeError + } + } else { + cs.subClient.WaitForServers() + // make sure we don't flood logs too much in extreme conditions + time.Sleep(time.Second) + } + } +} + +func (rrs roundRobinStrategy) addListener(listener discov.Listener) { + rrs.subClient.AddListener(listener) +} + +func (rrs roundRobinStrategy) push(message string) error { + var batchError errorx.BatchError + + for i := 0; i < retryTimes; i++ { + if err := rrs.pushOnce(message); err != nil { + batchError.Add(err) + } else { + return 
nil + } + } + + return batchError.Err() +} + +func (rrs roundRobinStrategy) pushOnce(message string) error { + if server, ok := rrs.subClient.Next(); ok { + pusher, ok := server.(queue.QueuePusher) + if ok { + return pusher.Push(message) + } else { + return ErrPusherTypeError + } + } else { + rrs.subClient.WaitForServers() + return rrs.pushOnce(message) + } +} + +func getName(key string) string { + return fmt.Sprintf("etcd:%s", key) +} + +func newPusher(server string) (queue.QueuePusher, error) { + if rds, key, option, err := newRedisWithKey(server); err != nil { + return nil, err + } else if option == constant.TimedQueueType { + return redisqueue.NewPusher(rds, key, redisqueue.WithTime()), nil + } else { + return redisqueue.NewPusher(rds, key), nil + } +} + +func newRedisWithKey(server string) (rds *redis.Redis, key, option string, err error) { + fields := strings.Split(server, constant.Delimeter) + if len(fields) < etcdRedisFields { + err = fmt.Errorf("wrong redis queue: %s, should be ip:port/type/password/key/[option]", server) + return + } + + addr := fields[0] + tp := fields[1] + pass := fields[2] + key = fields[3] + + if len(fields) > etcdRedisFields { + option = fields[4] + } + + rds = redis.NewRedis(addr, tp, pass) + return +} + +func (sl *serverListener) OnUpdate(keys []string, servers []string, newKey string) { + sl.updater.Update(keys, servers, newKey) +} + +func (sl *serverListener) OnReload() { + sl.updater.Update(nil, nil, "") +} + +func (sl *serverListener) update(change update.ServerChange) { + content, err := change.Marshal() + if err != nil { + logx.Error(err) + } + + if err = broadcast(change.Servers, content); err != nil { + logx.Error(err) + } +} + +func broadcast(servers []string, message string) error { + var be errorx.BatchError + + for _, server := range servers { + q, err := UnmarshalPusher(server) + if err != nil { + be.Add(err) + } else { + q.Push(message) + } + } + + return be.Err() +} + +func unwrapTimedMessage(message string) (string, error) { + var tm redisqueue.TimedMessage + if err := jsonx.UnmarshalFromString(message, &tm); err != nil { + return "", err + } + + return tm.Payload, nil +} diff --git a/rq/queue.go b/rq/queue.go new file mode 100644 index 00000000..e6a093a0 --- /dev/null +++ b/rq/queue.go @@ -0,0 +1,339 @@ +package rq + +import ( + "errors" + "fmt" + "log" + "strings" + "sync" + "time" + + "zero/core/discov" + "zero/core/logx" + "zero/core/queue" + "zero/core/redisqueue" + "zero/core/service" + "zero/core/stores/redis" + "zero/core/stringx" + "zero/core/threading" + "zero/rq/constant" + "zero/rq/update" +) + +const keyLen = 6 + +var ( + ErrTimeout = errors.New("timeout error") + + eventHandlerPlaceholder = dummyEventHandler(0) +) + +type ( + ConsumeHandle func(string) error + + ConsumeHandler interface { + Consume(string) error + } + + EventHandler interface { + OnEvent(event interface{}) + } + + QueueOption func(queue *MessageQueue) + + queueOptions struct { + renewId int64 + } + + MessageQueue struct { + c RmqConf + redisQueue *queue.Queue + consumerFactory queue.ConsumerFactory + options queueOptions + eventLock sync.Mutex + lastEvent string + } +) + +func MustNewMessageQueue(c RmqConf, factory queue.ConsumerFactory, opts ...QueueOption) queue.MessageQueue { + q, err := NewMessageQueue(c, factory, opts...) 
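+ // As with kq.MustNewQueue, any configuration or setup error aborts the process;
+ // call NewMessageQueue directly when the caller should handle the error itself.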
+ if err != nil { + log.Fatal(err) + } + + return q +} + +func NewMessageQueue(c RmqConf, factory queue.ConsumerFactory, opts ...QueueOption) (queue.MessageQueue, error) { + if err := c.SetUp(); err != nil { + return nil, err + } + + q := &MessageQueue{ + c: c, + } + + if len(q.c.Redis.Key) == 0 { + if len(q.c.Name) == 0 { + q.c.Redis.Key = stringx.Randn(keyLen) + } else { + q.c.Redis.Key = fmt.Sprintf("%s-%s", q.c.Name, stringx.Randn(keyLen)) + } + } + if q.c.Timeout > 0 { + factory = wrapWithTimeout(factory, time.Duration(q.c.Timeout)*time.Millisecond) + } + factory = wrapWithServerSensitive(q, factory) + q.consumerFactory = factory + q.redisQueue = q.buildQueue() + + for _, opt := range opts { + opt(q) + } + + return q, nil +} + +func (q *MessageQueue) Start() { + serviceGroup := service.NewServiceGroup() + serviceGroup.Add(q.redisQueue) + q.maybeAppendRenewer(serviceGroup, q.redisQueue) + serviceGroup.Start() +} + +func (q *MessageQueue) Stop() { + logx.Close() +} + +func (q *MessageQueue) buildQueue() *queue.Queue { + inboundStore := redis.NewRedis(q.c.Redis.Host, q.c.Redis.Type, q.c.Redis.Pass) + producerFactory := redisqueue.NewProducerFactory(inboundStore, q.c.Redis.Key, + redisqueue.TimeSensitive(q.c.DropBefore)) + mq := queue.NewQueue(producerFactory, q.consumerFactory) + + if len(q.c.Name) > 0 { + mq.SetName(q.c.Name) + } + if q.c.NumConsumers > 0 { + mq.SetNumConsumer(q.c.NumConsumers) + } + if q.c.NumProducers > 0 { + mq.SetNumProducer(q.c.NumProducers) + } + + return mq +} + +func (q *MessageQueue) compareAndSetEvent(event string) bool { + q.eventLock.Lock() + defer q.eventLock.Unlock() + + if q.lastEvent == event { + return false + } + + q.lastEvent = event + return true +} + +func (q *MessageQueue) maybeAppendRenewer(group *service.ServiceGroup, mq *queue.Queue) { + if len(q.c.Etcd.Hosts) > 0 || len(q.c.Etcd.Key) > 0 { + etcdValue := MarshalQueue(q.c.Redis) + if q.c.DropBefore > 0 { + etcdValue = strings.Join([]string{etcdValue, constant.TimedQueueType}, constant.Delimeter) + } + keepAliver := discov.NewRenewer(q.c.Etcd.Hosts, q.c.Etcd.Key, etcdValue, q.options.renewId) + mq.AddListener(pauseResumeHandler{ + Renewer: keepAliver, + }) + group.Add(keepAliver) + } +} + +func MarshalQueue(rds redis.RedisKeyConf) string { + return strings.Join([]string{ + rds.Host, + rds.Type, + rds.Pass, + rds.Key, + }, constant.Delimeter) +} + +func WithHandle(handle ConsumeHandle) queue.ConsumerFactory { + return WithHandler(innerConsumerHandler{handle}) +} + +func WithHandler(handler ConsumeHandler, eventHandlers ...EventHandler) queue.ConsumerFactory { + return func() (queue.Consumer, error) { + if len(eventHandlers) < 1 { + return eventConsumer{ + consumeHandler: handler, + eventHandler: eventHandlerPlaceholder, + }, nil + } else { + return eventConsumer{ + consumeHandler: handler, + eventHandler: eventHandlers[0], + }, nil + } + } +} + +func WithHandlerFactory(factory func() (ConsumeHandler, error)) queue.ConsumerFactory { + return func() (queue.Consumer, error) { + if handler, err := factory(); err != nil { + return nil, err + } else { + return eventlessHandler{handler}, nil + } + } +} + +func WithRenewId(id int64) QueueOption { + return func(mq *MessageQueue) { + mq.options.renewId = id + } +} + +func wrapWithServerSensitive(mq *MessageQueue, factory queue.ConsumerFactory) queue.ConsumerFactory { + return func() (queue.Consumer, error) { + consumer, err := factory() + if err != nil { + return nil, err + } + + return &serverSensitiveConsumer{ + mq: mq, + consumer: consumer, + }, nil + } 
+} + +func wrapWithTimeout(factory queue.ConsumerFactory, dt time.Duration) queue.ConsumerFactory { + return func() (queue.Consumer, error) { + consumer, err := factory() + if err != nil { + return nil, err + } + + return &timeoutConsumer{ + consumer: consumer, + dt: dt, + timer: time.NewTimer(dt), + }, nil + } +} + +type innerConsumerHandler struct { + handle ConsumeHandle +} + +func (h innerConsumerHandler) Consume(v string) error { + return h.handle(v) +} + +type serverSensitiveConsumer struct { + mq *MessageQueue + consumer queue.Consumer +} + +func (c *serverSensitiveConsumer) Consume(msg string) error { + if update.IsServerChange(msg) { + change, err := update.UnmarshalServerChange(msg) + if err != nil { + return err + } + + code := change.GetCode() + if !c.mq.compareAndSetEvent(code) { + return nil + } + + oldHash := change.CreatePrevHash() + newHash := change.CreateCurrentHash() + hashChange := NewHashChange(oldHash, newHash) + c.mq.redisQueue.Broadcast(hashChange) + + return nil + } + + return c.consumer.Consume(msg) +} + +func (c *serverSensitiveConsumer) OnEvent(event interface{}) { + c.consumer.OnEvent(event) +} + +type timeoutConsumer struct { + consumer queue.Consumer + dt time.Duration + timer *time.Timer +} + +func (c *timeoutConsumer) Consume(msg string) error { + done := make(chan error) + threading.GoSafe(func() { + if err := c.consumer.Consume(msg); err != nil { + done <- err + } + close(done) + }) + + c.timer.Reset(c.dt) + select { + case err, ok := <-done: + c.timer.Stop() + if ok { + return err + } else { + return nil + } + case <-c.timer.C: + return ErrTimeout + } +} + +func (c *timeoutConsumer) OnEvent(event interface{}) { + c.consumer.OnEvent(event) +} + +type pauseResumeHandler struct { + discov.Renewer +} + +func (pr pauseResumeHandler) OnPause() { + pr.Pause() +} + +func (pr pauseResumeHandler) OnResume() { + pr.Resume() +} + +type eventConsumer struct { + consumeHandler ConsumeHandler + eventHandler EventHandler +} + +func (ec eventConsumer) Consume(msg string) error { + return ec.consumeHandler.Consume(msg) +} + +func (ec eventConsumer) OnEvent(event interface{}) { + ec.eventHandler.OnEvent(event) +} + +type eventlessHandler struct { + handler ConsumeHandler +} + +func (h eventlessHandler) Consume(msg string) error { + return h.handler.Consume(msg) +} + +func (h eventlessHandler) OnEvent(event interface{}) { +} + +type dummyEventHandler int + +func (eh dummyEventHandler) OnEvent(event interface{}) { +} diff --git a/rq/queue_test.go b/rq/queue_test.go new file mode 100644 index 00000000..8722048b --- /dev/null +++ b/rq/queue_test.go @@ -0,0 +1,62 @@ +package rq + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestQueueWithTimeout(t *testing.T) { + consumer, err := wrapWithTimeout(WithHandle(func(string) error { + time.Sleep(time.Minute) + return nil + }), 100)() + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, ErrTimeout, consumer.Consume("any")) +} + +func TestQueueWithoutTimeout(t *testing.T) { + consumer, err := wrapWithTimeout(WithHandle(func(string) error { + return nil + }), 3600000)() + if err != nil { + t.Fatal(err) + } + + assert.Nil(t, consumer.Consume("any")) +} + +func BenchmarkQueue(b *testing.B) { + b.ReportAllocs() + + consumer, err := WithHandle(func(string) error { + return nil + })() + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + consumer.Consume(strconv.Itoa(i)) + } +} + +func BenchmarkQueueWithTimeout(b *testing.B) { + b.ReportAllocs() + + consumer, err := 
wrapWithTimeout(WithHandle(func(string) error { + return nil + }), 1000)() + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + consumer.Consume(strconv.Itoa(i)) + } +} diff --git a/rq/update/incrementalupdater.go b/rq/update/incrementalupdater.go new file mode 100644 index 00000000..d8187095 --- /dev/null +++ b/rq/update/incrementalupdater.go @@ -0,0 +1,179 @@ +package update + +import ( + "sync" + "time" + + "zero/core/hash" + "zero/core/stringx" +) + +const ( + incrementalStep = 5 + stepDuration = time.Second * 3 +) + +type ( + updateEvent struct { + keys []string + newKey string + servers []string + } + + UpdateFunc func(change ServerChange) + + IncrementalUpdater struct { + lock sync.Mutex + started bool + taskChan chan updateEvent + updates ServerChange + updateFn UpdateFunc + pendingEvents []updateEvent + } +) + +func NewIncrementalUpdater(updateFn UpdateFunc) *IncrementalUpdater { + return &IncrementalUpdater{ + taskChan: make(chan updateEvent), + updates: ServerChange{ + Current: Snapshot{ + Keys: make([]string, 0), + WeightedKeys: make([]weightedKey, 0), + }, + Servers: make([]string, 0), + }, + updateFn: updateFn, + } +} + +func (ru *IncrementalUpdater) Update(keys []string, servers []string, newKey string) { + ru.lock.Lock() + defer ru.lock.Unlock() + + if !ru.started { + go ru.run() + ru.started = true + } + + ru.taskChan <- updateEvent{ + keys: keys, + newKey: newKey, + servers: servers, + } +} + +// Return true if incremental update is done +func (ru *IncrementalUpdater) advance() bool { + previous := ru.updates.Current + keys := make([]string, 0) + weightedKeys := make([]weightedKey, 0) + servers := ru.updates.Servers + for _, key := range ru.updates.Current.Keys { + keys = append(keys, key) + } + for _, wkey := range ru.updates.Current.WeightedKeys { + weight := wkey.Weight + incrementalStep + if weight >= hash.TopWeight { + keys = append(keys, wkey.Key) + } else { + weightedKeys = append(weightedKeys, weightedKey{ + Key: wkey.Key, + Weight: weight, + }) + } + } + + for _, event := range ru.pendingEvents { + // ignore reload events + if len(event.newKey) == 0 || len(event.servers) == 0 { + continue + } + + // anyway, add the servers, just to avoid missing notify any server + servers = stringx.Union(servers, event.servers) + if keyExists(keys, weightedKeys, event.newKey) { + continue + } + + weightedKeys = append(weightedKeys, weightedKey{ + Key: event.newKey, + Weight: incrementalStep, + }) + } + + // clear pending events + ru.pendingEvents = ru.pendingEvents[:0] + + change := ServerChange{ + Previous: previous, + Current: Snapshot{ + Keys: keys, + WeightedKeys: weightedKeys, + }, + Servers: servers, + } + ru.updates = change + ru.updateFn(change) + + return len(weightedKeys) == 0 +} + +func (ru *IncrementalUpdater) run() { + defer func() { + ru.lock.Lock() + ru.started = false + ru.lock.Unlock() + }() + + ticker := time.NewTicker(stepDuration) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if ru.advance() { + return + } + case event := <-ru.taskChan: + ru.updateKeys(event) + } + } +} + +func (ru *IncrementalUpdater) updateKeys(event updateEvent) { + isWeightedKey := func(key string) bool { + for _, wkey := range ru.updates.Current.WeightedKeys { + if wkey.Key == key { + return true + } + } + + return false + } + + keys := make([]string, 0, len(event.keys)) + for _, key := range event.keys { + if !isWeightedKey(key) { + keys = append(keys, key) + } + } + + ru.updates.Current.Keys = keys + ru.pendingEvents = append(ru.pendingEvents, event) 
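+ // The queued event is folded into the snapshot by advance() on the next
+ // stepDuration tick: new keys start at weight incrementalStep and are promoted
+ // to plain keys once their weight reaches hash.TopWeight.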
+} + +func keyExists(keys []string, weightedKeys []weightedKey, key string) bool { + for _, each := range keys { + if key == each { + return true + } + } + + for _, wkey := range weightedKeys { + if wkey.Key == key { + return true + } + } + + return false +} diff --git a/rq/update/serverchange.go b/rq/update/serverchange.go new file mode 100644 index 00000000..44ea40c7 --- /dev/null +++ b/rq/update/serverchange.go @@ -0,0 +1,106 @@ +package update + +import ( + "crypto/md5" + "errors" + "fmt" + "io" + "sort" + + "zero/core/hash" + "zero/core/jsonx" + "zero/rq/constant" +) + +var ErrInvalidServerChange = errors.New("not a server change message") + +type ( + weightedKey struct { + Key string + Weight int + } + + Snapshot struct { + Keys []string + WeightedKeys []weightedKey + } + + ServerChange struct { + Previous Snapshot + Current Snapshot + Servers []string + } +) + +func (s Snapshot) GetCode() string { + keys := append([]string(nil), s.Keys...) + sort.Strings(keys) + weightedKeys := append([]weightedKey(nil), s.WeightedKeys...) + sort.SliceStable(weightedKeys, func(i, j int) bool { + return weightedKeys[i].Key < weightedKeys[j].Key + }) + + digest := md5.New() + for _, key := range keys { + io.WriteString(digest, fmt.Sprintf("%s\n", key)) + } + for _, wkey := range weightedKeys { + io.WriteString(digest, fmt.Sprintf("%s:%d\n", wkey.Key, wkey.Weight)) + } + + return fmt.Sprintf("%x", digest.Sum(nil)) +} + +func (sc ServerChange) CreateCurrentHash() *hash.ConsistentHash { + curHash := hash.NewConsistentHash() + + for _, key := range sc.Current.Keys { + curHash.Add(key) + } + for _, wkey := range sc.Current.WeightedKeys { + curHash.AddWithWeight(wkey.Key, wkey.Weight) + } + + return curHash +} + +func (sc ServerChange) CreatePrevHash() *hash.ConsistentHash { + prevHash := hash.NewConsistentHash() + + for _, key := range sc.Previous.Keys { + prevHash.Add(key) + } + for _, wkey := range sc.Previous.WeightedKeys { + prevHash.AddWithWeight(wkey.Key, wkey.Weight) + } + + return prevHash +} + +func (sc ServerChange) GetCode() string { + return sc.Current.GetCode() +} + +func IsServerChange(message string) bool { + return len(message) > 0 && message[0] == constant.ServerSensitivePrefix +} + +func (sc ServerChange) Marshal() (string, error) { + body, err := jsonx.Marshal(sc) + if err != nil { + return "", err + } + + return string(append([]byte{constant.ServerSensitivePrefix}, body...)), nil +} + +func UnmarshalServerChange(body string) (ServerChange, error) { + if len(body) == 0 { + return ServerChange{}, ErrInvalidServerChange + } + + var change ServerChange + err := jsonx.UnmarshalFromString(body[1:], &change) + + return change, err +} diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 00000000..08c4b4a9 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,13 @@ +sonar.host.url=https://sonar.xiaoheiban.cn/ +sonar.login=admin +sonar.password=blackboard +sonar.projectKey=zero +sonar.projectName=zero +sonar.sources=. +sonar.sources.inclusions=**/**.go +sonar.exclusions=**/*_test.go,**/vendor/** +sonar.tests=. 
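+# The Go test report referenced below (report.json) is expected to be generated
+# beforehand, e.g. with: go test -json ./... > report.json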
+sonar.test.inclusions=**/*_test.go +sonar.test.exclusions=**/vendor/** +sonar.sourceEncoding=UTF-8 +sonar.go.tests.reportPaths=report.json diff --git a/stash/config/config.go b/stash/config/config.go new file mode 100644 index 00000000..e0a0d507 --- /dev/null +++ b/stash/config/config.go @@ -0,0 +1,41 @@ +package config + +import ( + "time" + + "zero/kq" +) + +type ( + Condition struct { + Key string + Value string + Type string `json:",default=match,options=match|contains"` + Op string `json:",default=and,options=and|or"` + } + + ElasticSearchConf struct { + Hosts []string + DailyIndexPrefix string + TimeZone string `json:",optional"` + MaxChunkBytes int `json:",default=1048576"` + Compress bool `json:",default=false"` + } + + Filter struct { + Action string `json:",options=drop|remove_field"` + Conditions []Condition `json:",optional"` + Fields []string `json:",optional"` + } + + Config struct { + Input struct { + Kafka kq.KqConf + } + Filters []Filter + Output struct { + ElasticSearch ElasticSearchConf + } + GracePeriod time.Duration `json:",default=10s"` + } +) diff --git a/stash/es/index.go b/stash/es/index.go new file mode 100644 index 00000000..ed150f6f --- /dev/null +++ b/stash/es/index.go @@ -0,0 +1,82 @@ +package es + +import ( + "context" + "sync" + "time" + + "zero/core/fx" + "zero/core/logx" + "zero/core/syncx" + + "github.com/olivere/elastic" +) + +const sharedCallsKey = "ensureIndex" + +type ( + IndexFormat func(time.Time) string + IndexFunc func() string + + Index struct { + client *elastic.Client + indexFormat IndexFormat + index string + lock sync.RWMutex + sharedCalls syncx.SharedCalls + } +) + +func NewIndex(client *elastic.Client, indexFormat IndexFormat) *Index { + return &Index{ + client: client, + indexFormat: indexFormat, + sharedCalls: syncx.NewSharedCalls(), + } +} + +func (idx *Index) GetIndex(t time.Time) string { + index := idx.indexFormat(t) + if err := idx.ensureIndex(index); err != nil { + logx.Error(err) + } + return index +} + +func (idx *Index) ensureIndex(index string) error { + idx.lock.RLock() + if index == idx.index { + idx.lock.RUnlock() + return nil + } + idx.lock.RUnlock() + + _, err := idx.sharedCalls.Do(sharedCallsKey, func() (i interface{}, err error) { + idx.lock.Lock() + defer idx.lock.Unlock() + + existsService := elastic.NewIndicesExistsService(idx.client) + existsService.Index([]string{index}) + exist, err := existsService.Do(context.Background()) + if err != nil { + return nil, err + } + if exist { + idx.index = index + return nil, nil + } + + createService := idx.client.CreateIndex(index) + if err := fx.DoWithRetries(func() error { + // is it necessary to check the result? 
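+ // Only the returned error matters here: creation runs inside fx.DoWithRetries,
+ // so transient failures are retried and the response body is discarded.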
+ _, err := createService.Do(context.Background()) + return err + }); err != nil { + return nil, err + } + + idx.index = index + return nil, nil + }) + return err +} diff --git a/stash/es/writer.go b/stash/es/writer.go new file mode 100644 index 00000000..ac4e0fa6 --- /dev/null +++ b/stash/es/writer.go @@ -0,0 +1,65 @@ +package es + +import ( + "context" + "time" + + "zero/core/executors" + "zero/core/logx" + "zero/stash/config" + + "github.com/olivere/elastic" +) + +const docType = "doc" + +type ( + Writer struct { + client *elastic.Client + indexer *Index + inserter *executors.ChunkExecutor + } + + valueWithTime struct { + t time.Time + val string + } +) + +func NewWriter(c config.ElasticSearchConf, indexer *Index) (*Writer, error) { + client, err := elastic.NewClient( + elastic.SetSniff(false), + elastic.SetURL(c.Hosts...), + elastic.SetGzip(c.Compress), + ) + if err != nil { + return nil, err + } + + writer := Writer{ + client: client, + indexer: indexer, + } + writer.inserter = executors.NewChunkExecutor(writer.execute, executors.WithChunkBytes(c.MaxChunkBytes)) + return &writer, nil +} + +func (w *Writer) Write(t time.Time, val string) error { + return w.inserter.Add(valueWithTime{ + t: t, + val: val, + }, len(val)) +} + +func (w *Writer) execute(vals []interface{}) { + var bulk = w.client.Bulk() + for _, val := range vals { + pair := val.(valueWithTime) + req := elastic.NewBulkIndexRequest().Index(w.indexer.GetIndex(pair.t)).Type(docType).Doc(pair.val) + bulk.Add(req) + } + _, err := bulk.Do(context.Background()) + if err != nil { + logx.Error(err) + } +} diff --git a/stash/etc/config.json b/stash/etc/config.json new file mode 100644 index 00000000..35a610d2 --- /dev/null +++ b/stash/etc/config.json @@ -0,0 +1,75 @@ +{ + "Input": { + "Kafka": { + "Name": "easystash", + "Brokers": [ + "172.16.186.156:19092", + "172.16.186.157:19092", + "172.16.186.158:19092", + "172.16.186.159:19092", + "172.16.186.160:19092", + "172.16.186.161:19092" + ], + "Topic": "k8slog", + "Group": "pro", + "NumProducers": 16, + "MetricsUrl": "http://localhost:2222/add" + } + }, + "Filters": [ + { + "Action": "drop", + "Conditions": [ + { + "Key": "k8s_container_name", + "Value": "-rpc", + "Type": "contains" + }, + { + "Key": "level", + "Value": "info", + "Type": "match", + "Op": "and" + } + ] + }, + { + "Action": "remove_field", + "Fields": [ + "message", + "_source", + "_type", + "_score", + "_id", + "@version", + "topic", + "index", + "beat", + "docker_container", + "offset", + "prospector", + "source", + "stream" + ] + } + ], + "Output": { + "ElasticSearch": { + "Hosts": [ + "172.16.141.14:9200", + "172.16.141.15:9200", + "172.16.141.16:9200", + "172.16.141.17:9200", + "172.16.140.195:9200", + "172.16.140.196:9200", + "172.16.140.197:9200", + "172.16.140.198:9200", + "172.16.140.199:9200", + "172.16.140.200:9200", + "172.16.140.201:9200", + "172.16.140.202:9200" + ], + "DailyIndexPrefix": "k8s_pro-" + } + } +} \ No newline at end of file diff --git a/stash/filter/filters.go b/stash/filter/filters.go new file mode 100644 index 00000000..c1a59f30 --- /dev/null +++ b/stash/filter/filters.go @@ -0,0 +1,104 @@ +package filter + +import ( + "strings" + + "zero/stash/config" + + "github.com/globalsign/mgo/bson" +) + +const ( + filterDrop = "drop" + filterRemoveFields = "remove_field" + opAnd = "and" + opOr = "or" + typeContains = "contains" + typeMatch = "match" +) + +type FilterFunc func(map[string]interface{}) map[string]interface{} + +func CreateFilters(c config.Config) []FilterFunc { + var filters []FilterFunc 
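+ // Filters run in configuration order; a drop filter discards an event by
+ // returning nil, which the message handler treats as consumed-and-skipped.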
+ + for _, f := range c.Filters { + switch f.Action { + case filterDrop: + filters = append(filters, DropFilter(f.Conditions)) + case filterRemoveFields: + filters = append(filters, RemoveFieldFilter(f.Fields)) + } + } + + return filters +} + +func DropFilter(conds []config.Condition) FilterFunc { + return func(m map[string]interface{}) map[string]interface{} { + var qualify bool + for _, cond := range conds { + var qualifyOnce bool + switch cond.Type { + case typeMatch: + qualifyOnce = cond.Value == m[cond.Key] + case typeContains: + if val, ok := m[cond.Key].(string); ok { + qualifyOnce = strings.Contains(val, cond.Value) + } + } + + switch cond.Op { + case opAnd: + if !qualifyOnce { + return m + } else { + qualify = true + } + case opOr: + if qualifyOnce { + qualify = true + } + } + } + + if qualify { + return nil + } else { + return m + } + } +} + +func RemoveFieldFilter(fields []string) FilterFunc { + return func(m map[string]interface{}) map[string]interface{} { + for _, field := range fields { + delete(m, field) + } + return m + } +} + +func AddUriFieldFilter(inField, outFirld string) FilterFunc { + return func(m map[string]interface{}) map[string]interface{} { + if val, ok := m[inField].(string); ok { + var datas []string + idx := strings.Index(val, "?") + if idx < 0 { + datas = strings.Split(val, "/") + } else { + datas = strings.Split(val[:idx], "/") + } + + for i, data := range datas { + if bson.IsObjectIdHex(data) { + datas[i] = "*" + } + } + + m[outFirld] = strings.Join(datas, "/") + } + + return m + } +} diff --git a/stash/handler/handler.go b/stash/handler/handler.go new file mode 100644 index 00000000..c278542f --- /dev/null +++ b/stash/handler/handler.go @@ -0,0 +1,64 @@ +package handler + +import ( + "time" + + "zero/stash/es" + "zero/stash/filter" + + jsoniter "github.com/json-iterator/go" +) + +const ( + timestampFormat = "2006-01-02T15:04:05.000Z" + timestampKey = "@timestamp" +) + +type MessageHandler struct { + writer *es.Writer + filters []filter.FilterFunc +} + +func NewHandler(writer *es.Writer) *MessageHandler { + return &MessageHandler{ + writer: writer, + } +} + +func (mh *MessageHandler) AddFilters(filters ...filter.FilterFunc) { + for _, f := range filters { + mh.filters = append(mh.filters, f) + } +} + +func (mh *MessageHandler) Consume(_, val string) error { + m := make(map[string]interface{}) + if err := jsoniter.Unmarshal([]byte(val), &m); err != nil { + return err + } + + for _, proc := range mh.filters { + if m = proc(m); m == nil { + return nil + } + } + + bs, err := jsoniter.Marshal(m) + if err != nil { + return err + } + + return mh.writer.Write(mh.getTime(m), string(bs)) +} + +func (mh *MessageHandler) getTime(m map[string]interface{}) time.Time { + if ti, ok := m[timestampKey]; ok { + if ts, ok := ti.(string); ok { + if t, err := time.Parse(timestampFormat, ts); err == nil { + return t + } + } + } + + return time.Now() +} diff --git a/stash/stash.go b/stash/stash.go new file mode 100644 index 00000000..b0eddf9c --- /dev/null +++ b/stash/stash.go @@ -0,0 +1,57 @@ +package main + +import ( + "flag" + "time" + + "zero/core/conf" + "zero/core/lang" + "zero/core/proc" + "zero/kq" + "zero/stash/config" + "zero/stash/es" + "zero/stash/filter" + "zero/stash/handler" + + "github.com/olivere/elastic" +) + +const dateFormat = "2006.01.02" + +var configFile = flag.String("f", "etc/config.json", "Specify the config file") + +func main() { + flag.Parse() + + var c config.Config + conf.MustLoad(*configFile, &c) + proc.SetTimeoutToForceQuit(c.GracePeriod) + + 
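+ // GracePeriod (10s by default) caps how long graceful shutdown may take before
+ // the process is forced to quit.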
client, err := elastic.NewClient( + elastic.SetSniff(false), + elastic.SetURL(c.Output.ElasticSearch.Hosts...), + ) + lang.Must(err) + + indexFormat := c.Output.ElasticSearch.DailyIndexPrefix + dateFormat + var loc *time.Location + if len(c.Output.ElasticSearch.TimeZone) > 0 { + loc, err = time.LoadLocation(c.Output.ElasticSearch.TimeZone) + lang.Must(err) + } else { + loc = time.Local + } + indexer := es.NewIndex(client, func(t time.Time) string { + return t.In(loc).Format(indexFormat) + }) + + filters := filter.CreateFilters(c) + writer, err := es.NewWriter(c.Output.ElasticSearch, indexer) + lang.Must(err) + + handle := handler.NewHandler(writer) + handle.AddFilters(filters...) + handle.AddFilters(filter.AddUriFieldFilter("url", "uri")) + q := kq.MustNewQueue(c.Input.Kafka, handle) + q.Start() +}
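+
+// For reference, a minimal standalone kq consumer sketch (illustrative only, not
+// part of this commit). It assumes a KqConf JSON similar to the Input.Kafka block
+// in stash/etc/config.json; the config path below is hypothetical.
+//
+//	package main
+//
+//	import (
+//		"fmt"
+//
+//		"zero/core/conf"
+//		"zero/kq"
+//	)
+//
+//	func main() {
+//		var c kq.KqConf
+//		conf.MustLoad("etc/kq.json", &c) // hypothetical config file
+//		q := kq.MustNewQueue(c, kq.WithHandle(func(k, v string) error {
+//			// print each Kafka message; returning an error gets it logged by the queue
+//			fmt.Printf("key=%s value=%s\n", k, v)
+//			return nil
+//		}))
+//		q.Start() // blocks: producers read from Kafka, consumers drain the channel
+//	}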