fix: etcd discovery mechanism on grpc with idle manager (#4589)

Kevin Wan, 2025-01-22 14:01:18 +08:00, committed by GitHub
parent 33011c7ed1
commit bf883101d7
9 changed files with 494 additions and 285 deletions

View File

@@ -5,28 +5,29 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"slices"
 	"sort"
 	"strings"
 	"sync"
 	"time"
 
-	"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
-	clientv3 "go.etcd.io/etcd/client/v3"
-
-	"github.com/zeromicro/go-zero/core/contextx"
 	"github.com/zeromicro/go-zero/core/lang"
 	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/core/mathx"
 	"github.com/zeromicro/go-zero/core/syncx"
 	"github.com/zeromicro/go-zero/core/threading"
+	"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
+	clientv3 "go.etcd.io/etcd/client/v3"
 )
 
+const coolDownDeviation = 0.05
+
 var (
 	registry = Registry{
 		clusters: make(map[string]*cluster),
 	}
-	connManager = syncx.NewResourceManager()
-	errClosed   = errors.New("etcd monitor chan has been closed")
+	connManager      = syncx.NewResourceManager()
+	coolDownUnstable = mathx.NewUnstable(coolDownDeviation)
+	errClosed        = errors.New("etcd monitor chan has been closed")
 )
 
 // A Registry is a registry that manages the etcd client connections.
@@ -42,44 +43,92 @@ func GetRegistry() *Registry {
 
 // GetConn returns an etcd client connection associated with given endpoints.
 func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) {
-	c, _ := r.getCluster(endpoints)
+	c, _ := r.getOrCreateCluster(endpoints)
 	return c.getClient()
 }
 
 // Monitor monitors the key on given etcd endpoints, notify with the given UpdateListener.
-func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener, exactMatch bool) error {
-	c, exists := r.getCluster(endpoints)
+func (r *Registry) Monitor(endpoints []string, key string, exactMatch bool, l UpdateListener) error {
+	wkey := watchKey{
+		key:        key,
+		exactMatch: exactMatch,
+	}
+
+	c, exists := r.getOrCreateCluster(endpoints)
 	// if exists, the existing values should be updated to the listener.
 	if exists {
-		kvs := c.getCurrent(key)
-		for _, kv := range kvs {
-			l.OnAdd(kv)
+		c.lock.Lock()
+		watcher, ok := c.watchers[wkey]
+		if ok {
+			watcher.listeners = append(watcher.listeners, l)
 		}
+		c.lock.Unlock()
+
+		if ok {
+			kvs := c.getCurrent(wkey)
+			for _, kv := range kvs {
+				l.OnAdd(kv)
+			}
+			return nil
+		}
 	}
 
-	return c.monitor(key, l, exactMatch)
+	return c.monitor(wkey, l)
 }
 
-// Unmonitor cancel monitoring of given endpoints and keys, and remove the listener.
-func (r *Registry) Unmonitor(endpoints []string, key string, l UpdateListener) {
+func (r *Registry) Unmonitor(endpoints []string, key string, exactMatch bool, l UpdateListener) {
 	c, exists := r.getCluster(endpoints)
+	// if not exists, return.
 	if !exists {
 		return
 	}
 
-	c.unmonitor(key, l)
+	wkey := watchKey{
+		key:        key,
+		exactMatch: exactMatch,
+	}
+
+	c.lock.Lock()
+	defer c.lock.Unlock()
+	watcher, ok := c.watchers[wkey]
+	if !ok {
+		return
+	}
+
+	for i, listener := range watcher.listeners {
+		if listener == l {
+			watcher.listeners = append(watcher.listeners[:i], watcher.listeners[i+1:]...)
+			break
+		}
+	}
+	if len(watcher.listeners) == 0 {
+		if watcher.cancel != nil {
+			watcher.cancel()
+		}
+		delete(c.watchers, wkey)
+	}
 }
 
-func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) {
+func (r *Registry) getCluster(endpoints []string) (*cluster, bool) {
 	clusterKey := getClusterKey(endpoints)
 	r.lock.RLock()
-	c, exists = r.clusters[clusterKey]
+	c, ok := r.clusters[clusterKey]
 	r.lock.RUnlock()
+
+	return c, ok
+}
+
+func (r *Registry) getOrCreateCluster(endpoints []string) (c *cluster, exists bool) {
+	c, exists = r.getCluster(endpoints)
 	if !exists {
+		clusterKey := getClusterKey(endpoints)
+
 		r.lock.Lock()
 		defer r.lock.Unlock()
+
 		// double-check locking
 		c, exists = r.clusters[clusterKey]
 		if !exists {
@@ -91,34 +140,51 @@ func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) {
 	return
 }
 
-type cluster struct {
-	endpoints  []string
-	key        string
-	values     map[string]map[string]string
-	listeners  map[string][]UpdateListener
-	watchGroup *threading.RoutineGroup
-	done       chan lang.PlaceholderType
-	lock       sync.RWMutex
-	exactMatch bool
-	watchCtx   map[string]context.CancelFunc
-	watchFlag  map[string]bool
-}
+type (
+	watchKey struct {
+		key        string
+		exactMatch bool
+	}
+
+	watchValue struct {
+		listeners []UpdateListener
+		values    map[string]string
+		cancel    context.CancelFunc
+	}
+
+	cluster struct {
+		endpoints  []string
+		key        string
+		watchers   map[watchKey]*watchValue
+		watchGroup *threading.RoutineGroup
+		done       chan lang.PlaceholderType
+		lock       sync.RWMutex
+	}
+)
 
 func newCluster(endpoints []string) *cluster {
 	return &cluster{
 		endpoints:  endpoints,
 		key:        getClusterKey(endpoints),
-		values:     make(map[string]map[string]string),
-		listeners:  make(map[string][]UpdateListener),
+		watchers:   make(map[watchKey]*watchValue),
 		watchGroup: threading.NewRoutineGroup(),
 		done:       make(chan lang.PlaceholderType),
-		watchCtx:   make(map[string]context.CancelFunc),
-		watchFlag:  make(map[string]bool),
 	}
 }
 
-func (c *cluster) context(cli EtcdClient) context.Context {
-	return contextx.ValueOnlyFrom(cli.Ctx())
+func (c *cluster) addListener(key watchKey, l UpdateListener) {
+	c.lock.Lock()
+	defer c.lock.Unlock()
+
+	watcher, ok := c.watchers[key]
+	if ok {
+		watcher.listeners = append(watcher.listeners, l)
+		return
+	}
+
+	val := newWatchValue()
+	val.listeners = []UpdateListener{l}
+	c.watchers[key] = val
 }
 
 func (c *cluster) getClient() (EtcdClient, error) {
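A note on the new data layout: watchKey is a comparable struct, so it can index the watchers map directly, and the same etcd key watched with exactMatch true and with exactMatch false gets two independent watcher entries. A self-contained illustration (the names mirror the patch, but the snippet is a standalone sketch, not code from it):

    package main

    import "fmt"

    // Same shape as the watchKey above: comparable, so usable as a map key.
    type watchKey struct {
        key        string
        exactMatch bool
    }

    func main() {
        watchers := make(map[watchKey]string)
        watchers[watchKey{key: "svc", exactMatch: true}] = "exact watcher"
        watchers[watchKey{key: "svc", exactMatch: false}] = "prefix watcher"

        // Two independent entries for the same etcd key:
        fmt.Println(len(watchers))                                    // 2
        fmt.Println(watchers[watchKey{key: "svc", exactMatch: true}]) // exact watcher
    }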
@@ -132,12 +198,17 @@ func (c *cluster) getClient() (EtcdClient, error) {
 	return val.(EtcdClient), nil
 }
 
-func (c *cluster) getCurrent(key string) []KV {
+func (c *cluster) getCurrent(key watchKey) []KV {
 	c.lock.RLock()
 	defer c.lock.RUnlock()
 
+	watcher, ok := c.watchers[key]
+	if !ok {
+		return nil
+	}
+
 	var kvs []KV
-	for k, v := range c.values[key] {
+	for k, v := range watcher.values {
 		kvs = append(kvs, KV{
 			Key: k,
 			Val: v,
@@ -147,43 +218,23 @@ func (c *cluster) getCurrent(key string) []KV {
 	return kvs
 }
 
-func (c *cluster) handleChanges(key string, kvs []KV) {
-	var add []KV
-	var remove []KV
+func (c *cluster) handleChanges(key watchKey, kvs []KV) {
 	c.lock.Lock()
-	listeners := append([]UpdateListener(nil), c.listeners[key]...)
-	vals, ok := c.values[key]
+	watcher, ok := c.watchers[key]
 	if !ok {
-		add = kvs
-		vals = make(map[string]string)
-		for _, kv := range kvs {
-			vals[kv.Key] = kv.Val
-		}
-		c.values[key] = vals
-	} else {
-		m := make(map[string]string)
-		for _, kv := range kvs {
-			m[kv.Key] = kv.Val
-		}
-		for k, v := range vals {
-			if val, ok := m[k]; !ok || v != val {
-				remove = append(remove, KV{
-					Key: k,
-					Val: v,
-				})
-			}
-		}
-		for k, v := range m {
-			if val, ok := vals[k]; !ok || v != val {
-				add = append(add, KV{
-					Key: k,
-					Val: v,
-				})
-			}
-		}
-		c.values[key] = m
+		c.lock.Unlock()
+		return
 	}
+
+	listeners := append([]UpdateListener(nil), watcher.listeners...)
+	// watcher.values cannot be nil
+	vals := watcher.values
+	newVals := make(map[string]string, len(kvs)+len(vals))
+	for _, kv := range kvs {
+		newVals[kv.Key] = kv.Val
+	}
+	add, remove := calculateChanges(vals, newVals)
+	watcher.values = newVals
 	c.lock.Unlock()
 
 	for _, kv := range add {
@@ -198,20 +249,22 @@ func (c *cluster) handleChanges(key string, kvs []KV) {
 	}
 }
 
-func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
+func (c *cluster) handleWatchEvents(key watchKey, events []*clientv3.Event) {
 	c.lock.RLock()
-	listeners := append([]UpdateListener(nil), c.listeners[key]...)
+	watcher, ok := c.watchers[key]
+	if !ok {
+		c.lock.RUnlock()
+		return
+	}
+
+	listeners := append([]UpdateListener(nil), watcher.listeners...)
 	c.lock.RUnlock()
 
 	for _, ev := range events {
 		switch ev.Type {
 		case clientv3.EventTypePut:
 			c.lock.Lock()
-			if vals, ok := c.values[key]; ok {
-				vals[string(ev.Kv.Key)] = string(ev.Kv.Value)
-			} else {
-				c.values[key] = map[string]string{string(ev.Kv.Key): string(ev.Kv.Value)}
-			}
+			watcher.values[string(ev.Kv.Key)] = string(ev.Kv.Value)
 			c.lock.Unlock()
 			for _, l := range listeners {
 				l.OnAdd(KV{
@@ -221,9 +274,7 @@ func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
 			}
 		case clientv3.EventTypeDelete:
 			c.lock.Lock()
-			if vals, ok := c.values[key]; ok {
-				delete(vals, string(ev.Kv.Key))
-			}
+			delete(watcher.values, string(ev.Kv.Key))
 			c.lock.Unlock()
 			for _, l := range listeners {
 				l.OnDelete(KV{
@@ -237,15 +288,15 @@ func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
 	}
 }
 
-func (c *cluster) load(cli EtcdClient, key string) int64 {
+func (c *cluster) load(cli EtcdClient, key watchKey) int64 {
 	var resp *clientv3.GetResponse
 	for {
 		var err error
-		ctx, cancel := context.WithTimeout(c.context(cli), RequestTimeout)
-		if c.exactMatch {
-			resp, err = cli.Get(ctx, key)
+		ctx, cancel := context.WithTimeout(cli.Ctx(), RequestTimeout)
+		if key.exactMatch {
+			resp, err = cli.Get(ctx, key.key)
 		} else {
-			resp, err = cli.Get(ctx, makeKeyPrefix(key), clientv3.WithPrefix())
+			resp, err = cli.Get(ctx, makeKeyPrefix(key.key), clientv3.WithPrefix())
 		}
 		cancel()
@@ -253,8 +304,8 @@ func (c *cluster) load(cli EtcdClient, key string) int64 {
 			break
 		}
 
-		logx.Errorf("%s, key is %s", err.Error(), key)
-		time.Sleep(coolDownInterval)
+		logx.Errorf("%s, key: %s, exactMatch: %t", err.Error(), key.key, key.exactMatch)
+		time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval))
 	}
 
 	var kvs []KV
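One behavioral change that is easy to miss in this hunk: the retry sleep is now jittered. mathx.NewUnstable(coolDownDeviation) spreads each sleep within roughly ±5% of coolDownInterval, so many watchers that failed at the same moment do not retry against etcd in lockstep. A standalone sketch of the idea (the real implementation is go-zero's core/mathx.Unstable; this function is an illustrative approximation, not the library code):

    package main

    import (
        "fmt"
        "math/rand"
        "time"
    )

    // aroundDuration sketches what AroundDuration does: return base scaled
    // by a random factor in [1-deviation, 1+deviation).
    func aroundDuration(base time.Duration, deviation float64) time.Duration {
        scale := 1 + deviation*(2*rand.Float64()-1)
        return time.Duration(float64(base) * scale)
    }

    func main() {
        for i := 0; i < 3; i++ {
            // with deviation 0.05, prints values near 500ms
            fmt.Println(aroundDuration(500*time.Millisecond, 0.05))
        }
    }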
@@ -270,33 +321,19 @@ func (c *cluster) load(cli EtcdClient, key string) int64 {
 	return resp.Header.Revision
 }
 
-func (c *cluster) monitor(key string, l UpdateListener, exactMatch bool) error {
-	c.lock.Lock()
-	c.listeners[key] = append(c.listeners[key], l)
-	c.exactMatch = exactMatch
-	c.lock.Unlock()
-
-	if !c.watchFlag[key] {
-		cli, err := c.getClient()
-		if err != nil {
-			return err
-		}
-
-		rev := c.load(cli, key)
-		c.watchGroup.Run(func() {
-			c.watch(cli, key, rev)
-		})
+func (c *cluster) monitor(key watchKey, l UpdateListener) error {
+	cli, err := c.getClient()
+	if err != nil {
+		return err
 	}
 
-	return nil
-}
-
-func (c *cluster) unmonitor(key string, l UpdateListener) {
-	c.lock.Lock()
-	defer c.lock.Unlock()
-
-	c.listeners[key] = slices.DeleteFunc(c.listeners[key], func(listener UpdateListener) bool {
-		return l == listener
+	c.addListener(key, l)
+	rev := c.load(cli, key)
+	c.watchGroup.Run(func() {
+		c.watch(cli, key, rev)
 	})
+
+	return nil
 }
 
 func (c *cluster) newClient() (EtcdClient, error) {
@@ -312,17 +349,22 @@ func (c *cluster) newClient() (EtcdClient, error) {
 
 func (c *cluster) reload(cli EtcdClient) {
 	c.lock.Lock()
+	// cancel the previous watches
 	close(c.done)
 	c.watchGroup.Wait()
+	var keys []watchKey
+	for wk, wval := range c.watchers {
+		keys = append(keys, wk)
+		if wval.cancel != nil {
+			wval.cancel()
+		}
+	}
 	c.done = make(chan lang.PlaceholderType)
 	c.watchGroup = threading.NewRoutineGroup()
-	var keys []string
-	for k := range c.listeners {
-		keys = append(keys, k)
-	}
-	c.clearWatch()
 	c.lock.Unlock()
 
+	// start new watches
 	for _, key := range keys {
 		k := key
 		c.watchGroup.Run(func() {
@@ -332,10 +374,9 @@ func (c *cluster) reload(cli EtcdClient) {
 	}
 }
 
-func (c *cluster) watch(cli EtcdClient, key string, rev int64) {
-	ctx := c.addWatch(key, cli)
+func (c *cluster) watch(cli EtcdClient, key watchKey, rev int64) {
 	for {
-		err := c.watchStream(cli, key, rev, ctx)
+		err := c.watchStream(cli, key, rev)
 		if err == nil {
 			return
 		}
@@ -350,21 +391,8 @@ func (c *cluster) watch(cli EtcdClient, key string, rev int64) {
 	}
 }
 
-func (c *cluster) watchStream(cli EtcdClient, key string, rev int64, ctx context.Context) error {
-	var (
-		rch      clientv3.WatchChan
-		ops      []clientv3.OpOption
-		watchKey = key
-	)
-	if !c.exactMatch {
-		watchKey = makeKeyPrefix(key)
-		ops = append(ops, clientv3.WithPrefix())
-	}
-	if rev != 0 {
-		ops = append(ops, clientv3.WithRev(rev+1))
-	}
-
-	rch = cli.Watch(clientv3.WithRequireLeader(ctx), watchKey, ops...)
+func (c *cluster) watchStream(cli EtcdClient, key watchKey, rev int64) error {
+	ctx, rch := c.setupWatch(cli, key, rev)
 
 	for {
 		select {
@@ -380,14 +408,46 @@ func (c *cluster) watchStream(cli EtcdClient, key string, rev int64, ctx context.Context) error {
 			}
 
 			c.handleWatchEvents(key, wresp.Events)
-		case <-c.done:
-			return nil
 		case <-ctx.Done():
 			return nil
+		case <-c.done:
+			return nil
 		}
 	}
 }
 
+func (c *cluster) setupWatch(cli EtcdClient, key watchKey, rev int64) (context.Context, clientv3.WatchChan) {
+	var (
+		rch  clientv3.WatchChan
+		ops  []clientv3.OpOption
+		wkey = key.key
+	)
+
+	if !key.exactMatch {
+		wkey = makeKeyPrefix(key.key)
+		ops = append(ops, clientv3.WithPrefix())
+	}
+	if rev != 0 {
+		ops = append(ops, clientv3.WithRev(rev+1))
+	}
+
+	ctx, cancel := context.WithCancel(cli.Ctx())
+	if watcher, ok := c.watchers[key]; ok {
+		watcher.cancel = cancel
+	} else {
+		val := newWatchValue()
+		val.cancel = cancel
+
+		c.lock.Lock()
+		c.watchers[key] = val
+		c.lock.Unlock()
+	}
+
+	rch = cli.Watch(clientv3.WithRequireLeader(ctx), wkey, ops...)
+
+	return ctx, rch
+}
+
 func (c *cluster) watchConnState(cli EtcdClient) {
 	watcher := newStateWatcher()
 	watcher.addListener(func() {
@@ -396,23 +456,6 @@ func (c *cluster) watchConnState(cli EtcdClient) {
 	watcher.watch(cli.ActiveConnection())
 }
 
-func (c *cluster) addWatch(key string, cli EtcdClient) context.Context {
-	ctx, cancel := context.WithCancel(cli.Ctx())
-	c.lock.Lock()
-	c.watchCtx[key] = cancel
-	c.watchFlag[key] = true
-	c.lock.Unlock()
-
-	return ctx
-}
-
-func (c *cluster) clearWatch() {
-	for _, cancel := range c.watchCtx {
-		cancel()
-	}
-	c.watchCtx = make(map[string]context.CancelFunc)
-	c.watchFlag = make(map[string]bool)
-}
-
 // DialClient dials an etcd cluster with given endpoints.
 func DialClient(endpoints []string) (EtcdClient, error) {
 	cfg := clientv3.Config{
@@ -433,6 +476,28 @@ func DialClient(endpoints []string) (EtcdClient, error) {
 	return clientv3.New(cfg)
 }
 
+func calculateChanges(oldVals, newVals map[string]string) (add, remove []KV) {
+	for k, v := range newVals {
+		if val, ok := oldVals[k]; !ok || v != val {
+			add = append(add, KV{
+				Key: k,
+				Val: v,
+			})
+		}
+	}
+
+	for k, v := range oldVals {
+		if val, ok := newVals[k]; !ok || v != val {
+			remove = append(remove, KV{
+				Key: k,
+				Val: v,
+			})
+		}
+	}
+
+	return add, remove
+}
+
 func getClusterKey(endpoints []string) string {
 	sort.Strings(endpoints)
 	return strings.Join(endpoints, endpointsSeparator)
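The helper factored out above has simple semantics: a key that is new or whose value changed lands in add; a key that disappeared or whose value changed lands in remove, so a changed value is reported as a delete of the old pair plus an add of the new one. A runnable illustration using the function exactly as the patch introduces it:

    package main

    import "fmt"

    type KV struct{ Key, Val string }

    // calculateChanges as introduced by this patch.
    func calculateChanges(oldVals, newVals map[string]string) (add, remove []KV) {
        for k, v := range newVals {
            if val, ok := oldVals[k]; !ok || v != val {
                add = append(add, KV{Key: k, Val: v})
            }
        }
        for k, v := range oldVals {
            if val, ok := newVals[k]; !ok || v != val {
                remove = append(remove, KV{Key: k, Val: v})
            }
        }
        return add, remove
    }

    func main() {
        add, remove := calculateChanges(
            map[string]string{"a": "1", "b": "2"},
            map[string]string{"b": "3", "c": "4"},
        )
        fmt.Println(add)    // [{b 3} {c 4}] in some order
        fmt.Println(remove) // [{a 1} {b 2}] in some order
    }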
@@ -441,3 +506,10 @@ func getClusterKey(endpoints []string) string {
 func makeKeyPrefix(key string) string {
 	return fmt.Sprintf("%s%c", key, Delimiter)
 }
+
+// newWatchValue returns a watchValue that makes sure values are not nil.
+func newWatchValue() *watchValue {
+	return &watchValue{
+		values: make(map[string]string),
+	}
+}

View File

@@ -13,6 +13,7 @@ import (
 	"github.com/zeromicro/go-zero/core/lang"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/stringx"
+	"github.com/zeromicro/go-zero/core/threading"
 	"go.etcd.io/etcd/api/v3/etcdserverpb"
 	"go.etcd.io/etcd/api/v3/mvccpb"
 	clientv3 "go.etcd.io/etcd/client/v3"
@@ -38,9 +39,9 @@ func setMockClient(cli EtcdClient) func() {
 
 func TestGetCluster(t *testing.T) {
 	AddAccount([]string{"first"}, "foo", "bar")
-	c1, _ := GetRegistry().getCluster([]string{"first"})
-	c2, _ := GetRegistry().getCluster([]string{"second"})
-	c3, _ := GetRegistry().getCluster([]string{"first"})
+	c1, _ := GetRegistry().getOrCreateCluster([]string{"first"})
+	c2, _ := GetRegistry().getOrCreateCluster([]string{"second"})
+	c3, _ := GetRegistry().getOrCreateCluster([]string{"first"})
 	assert.Equal(t, c1, c3)
 	assert.NotEqual(t, c1, c2)
 }
@@ -50,6 +51,36 @@ func TestGetClusterKey(t *testing.T) {
 		getClusterKey([]string{"remotehost:5678", "localhost:1234"}))
 }
 
+func TestUnmonitor(t *testing.T) {
+	t.Run("no listener", func(t *testing.T) {
+		reg := &Registry{
+			clusters: map[string]*cluster{},
+		}
+		assert.NotPanics(t, func() {
+			reg.Unmonitor([]string{"any"}, "any", false, nil)
+		})
+	})
+
+	t.Run("no value", func(t *testing.T) {
+		reg := &Registry{
+			clusters: map[string]*cluster{
+				"any": {
+					watchers: map[watchKey]*watchValue{
+						{
+							key: "any",
+						}: {
+							values: map[string]string{},
+						},
+					},
+				},
+			},
+		}
+		assert.NotPanics(t, func() {
+			reg.Unmonitor([]string{"any"}, "another", false, nil)
+		})
+	})
+}
+
 func TestCluster_HandleChanges(t *testing.T) {
 	ctrl := gomock.NewController(t)
 	l := NewMockUpdateListener(ctrl)
@@ -78,8 +109,14 @@ func TestCluster_HandleChanges(t *testing.T) {
 		Val: "4",
 	})
 	c := newCluster([]string{"any"})
-	c.listeners["any"] = []UpdateListener{l}
-	c.handleChanges("any", []KV{
+	key := watchKey{
+		key:        "any",
+		exactMatch: false,
+	}
+	c.watchers[key] = &watchValue{
+		listeners: []UpdateListener{l},
+	}
+	c.handleChanges(key, []KV{
 		{
 			Key: "first",
 			Val: "1",
@@ -92,8 +129,8 @@ func TestCluster_HandleChanges(t *testing.T) {
 	assert.EqualValues(t, map[string]string{
 		"first":  "1",
 		"second": "2",
-	}, c.values["any"])
-	c.handleChanges("any", []KV{
+	}, c.watchers[key].values)
+	c.handleChanges(key, []KV{
 		{
 			Key: "third",
 			Val: "3",
@@ -106,7 +143,7 @@ func TestCluster_HandleChanges(t *testing.T) {
 	assert.EqualValues(t, map[string]string{
 		"third":  "3",
 		"fourth": "4",
-	}, c.values["any"])
+	}, c.watchers[key].values)
 }
 
 func TestCluster_Load(t *testing.T) {
@@ -126,9 +163,11 @@ func TestCluster_Load(t *testing.T) {
 	}, nil)
 	cli.EXPECT().Ctx().Return(context.Background())
 	c := &cluster{
-		values: make(map[string]map[string]string),
+		watchers: make(map[watchKey]*watchValue),
 	}
-	c.load(cli, "any")
+	c.load(cli, watchKey{
+		key: "any",
+	})
 }
 
 func TestCluster_Watch(t *testing.T) {
@@ -160,13 +199,16 @@ func TestCluster_Watch(t *testing.T) {
 			var wg sync.WaitGroup
 			wg.Add(1)
 			c := &cluster{
-				values:    make(map[string]map[string]string),
-				listeners: make(map[string][]UpdateListener),
-				watchCtx:  make(map[string]context.CancelFunc),
-				watchFlag: make(map[string]bool),
+				watchers: make(map[watchKey]*watchValue),
+			}
+			key := watchKey{
+				key: "any",
 			}
 			listener := NewMockUpdateListener(ctrl)
-			c.listeners["any"] = []UpdateListener{listener}
+			c.watchers[key] = &watchValue{
+				listeners: []UpdateListener{listener},
+				values:    make(map[string]string),
+			}
 			listener.EXPECT().OnAdd(gomock.Any()).Do(func(kv KV) {
 				assert.Equal(t, "hello", kv.Key)
 				assert.Equal(t, "world", kv.Val)
@@ -175,7 +217,7 @@ func TestCluster_Watch(t *testing.T) {
 			listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ any) {
 				wg.Done()
 			}).MaxTimes(1)
-			go c.watch(cli, "any", 0)
+			go c.watch(cli, key, 0)
 			ch <- clientv3.WatchResponse{
 				Events: []*clientv3.Event{
 					{
@@ -213,17 +255,111 @@ func TestClusterWatch_RespFailures(t *testing.T) {
 			ch := make(chan clientv3.WatchResponse)
 			cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
 			cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
-			c := newCluster([]string{})
+			c := &cluster{
+				watchers: make(map[watchKey]*watchValue),
+			}
 			c.done = make(chan lang.PlaceholderType)
 			go func() {
 				ch <- resp
 				close(c.done)
 			}()
-			c.watch(cli, "any", 0)
+			key := watchKey{
+				key: "any",
+			}
+			c.watch(cli, key, 0)
 		})
 	}
 }
 
+func TestCluster_getCurrent(t *testing.T) {
+	t.Run("no value", func(t *testing.T) {
+		c := &cluster{
+			watchers: map[watchKey]*watchValue{
+				{
+					key: "any",
+				}: {
+					values: map[string]string{},
+				},
+			},
+		}
+		assert.Nil(t, c.getCurrent(watchKey{
+			key: "another",
+		}))
+	})
+}
+
+func TestCluster_handleWatchEvents(t *testing.T) {
+	t.Run("no value", func(t *testing.T) {
+		c := &cluster{
+			watchers: map[watchKey]*watchValue{
+				{
+					key: "any",
+				}: {
+					values: map[string]string{},
+				},
+			},
+		}
+		assert.NotPanics(t, func() {
+			c.handleWatchEvents(watchKey{
+				key: "another",
+			}, nil)
+		})
+	})
+}
+
+func TestCluster_addListener(t *testing.T) {
+	t.Run("has listener", func(t *testing.T) {
+		c := &cluster{
+			watchers: map[watchKey]*watchValue{
+				{
+					key: "any",
+				}: {
+					listeners: make([]UpdateListener, 0),
+				},
+			},
+		}
+		assert.NotPanics(t, func() {
+			c.addListener(watchKey{
+				key: "any",
+			}, nil)
+		})
+	})
+
+	t.Run("no listener", func(t *testing.T) {
+		c := &cluster{
+			watchers: map[watchKey]*watchValue{
+				{
+					key: "any",
+				}: {
+					listeners: make([]UpdateListener, 0),
+				},
+			},
+		}
+		assert.NotPanics(t, func() {
+			c.addListener(watchKey{
+				key: "another",
+			}, nil)
+		})
+	})
+}
+
+func TestCluster_reload(t *testing.T) {
+	c := &cluster{
+		watchers:   map[watchKey]*watchValue{},
+		watchGroup: threading.NewRoutineGroup(),
+		done:       make(chan lang.PlaceholderType),
+	}
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+	cli := NewMockEtcdClient(ctrl)
+	restore := setMockClient(cli)
+	defer restore()
+	assert.NotPanics(t, func() {
+		c.reload(cli)
+	})
+}
+
 func TestClusterWatch_CloseChan(t *testing.T) {
 	ctrl := gomock.NewController(t)
 	defer ctrl.Finish()
@@ -233,44 +369,17 @@ func TestClusterWatch_CloseChan(t *testing.T) {
 	ch := make(chan clientv3.WatchResponse)
 	cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
 	cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
-	c := newCluster([]string{})
+	c := &cluster{
+		watchers: make(map[watchKey]*watchValue),
+	}
 	c.done = make(chan lang.PlaceholderType)
 	go func() {
 		close(ch)
 		close(c.done)
 	}()
-	c.watch(cli, "any", 0)
-}
-
-func TestClusterWatch_CtxCancel(t *testing.T) {
-	ctrl := gomock.NewController(t)
-	defer ctrl.Finish()
-	cli := NewMockEtcdClient(ctrl)
-	restore := setMockClient(cli)
-	defer restore()
-	ch := make(chan clientv3.WatchResponse)
-	cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
-	ctx, cancelFunc := context.WithCancel(context.Background())
-	cli.EXPECT().Ctx().Return(ctx).AnyTimes()
-	c := newCluster([]string{})
-	c.done = make(chan lang.PlaceholderType)
-	go func() {
-		cancelFunc()
-		close(ch)
-	}()
-	c.watch(cli, "any", 0)
-}
-
-func TestCluster_ClearWatch(t *testing.T) {
-	ctx, cancel := context.WithCancel(context.Background())
-	c := &cluster{
-		watchCtx:  map[string]context.CancelFunc{"foo": cancel},
-		watchFlag: map[string]bool{"foo": true},
-	}
-	c.clearWatch()
-	assert.Equal(t, ctx.Err(), context.Canceled)
-	assert.Equal(t, 0, len(c.watchCtx))
-	assert.Equal(t, 0, len(c.watchFlag))
+	c.watch(cli, watchKey{
+		key: "any",
+	}, 0)
 }
 
 func TestValueOnlyContext(t *testing.T) {
@@ -313,39 +422,59 @@ func TestRegistry_Monitor(t *testing.T) {
 	GetRegistry().lock.Lock()
 	GetRegistry().clusters = map[string]*cluster{
 		getClusterKey(endpoints): {
-			listeners: map[string][]UpdateListener{},
-			values: map[string]map[string]string{
-				"foo": {
-					"bar": "baz",
+			watchers: map[watchKey]*watchValue{
+				watchKey{
+					key:        "foo",
+					exactMatch: true,
+				}: {
+					values: map[string]string{
+						"bar": "baz",
+					},
 				},
 			},
-			watchCtx:  map[string]context.CancelFunc{},
-			watchFlag: map[string]bool{},
 		},
 	}
 	GetRegistry().lock.Unlock()
-	assert.Error(t, GetRegistry().Monitor(endpoints, "foo", new(mockListener), false))
+	assert.Error(t, GetRegistry().Monitor(endpoints, "foo", false, new(mockListener)))
 }
 
 func TestRegistry_Unmonitor(t *testing.T) {
-	l := new(mockListener)
+	svr, err := mockserver.StartMockServers(1)
+	assert.NoError(t, err)
+	svr.StartAt(0)
+
+	_, cancel := context.WithCancel(context.Background())
+	endpoints := []string{svr.Servers[0].Address}
 	GetRegistry().lock.Lock()
 	GetRegistry().clusters = map[string]*cluster{
 		getClusterKey(endpoints): {
-			listeners: map[string][]UpdateListener{"foo": {l}},
-			values: map[string]map[string]string{
-				"foo": {
-					"bar": "baz",
+			watchers: map[watchKey]*watchValue{
+				watchKey{
+					key:        "foo",
+					exactMatch: true,
+				}: {
+					values: map[string]string{
+						"bar": "baz",
+					},
+					cancel: cancel,
 				},
 			},
 		},
 	}
 	GetRegistry().lock.Unlock()
-	assert.Error(t, GetRegistry().Monitor(endpoints, "foo", l, false))
-	assert.Equal(t, 1, len(GetRegistry().clusters[getClusterKey(endpoints)].listeners["foo"]))
-	GetRegistry().Unmonitor(endpoints, "foo", l)
-	assert.Equal(t, 0, len(GetRegistry().clusters[getClusterKey(endpoints)].listeners["foo"]))
+	l := new(mockListener)
+	assert.NoError(t, GetRegistry().Monitor(endpoints, "foo", true, l))
+	watchVals := GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{
+		key:        "foo",
+		exactMatch: true,
+	}]
+	assert.Equal(t, 1, len(watchVals.listeners))
+	GetRegistry().Unmonitor(endpoints, "foo", true, l)
+	watchVals = GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{
+		key:        "foo",
+		exactMatch: true,
+	}]
+	assert.Nil(t, watchVals)
 }
 
 type mockListener struct {

View File

@@ -10,6 +10,7 @@ type (
 	}
 
 	// UpdateListener wraps the OnAdd and OnDelete methods.
+	// The implementation should be thread-safe and idempotent.
 	UpdateListener interface {
 		OnAdd(kv KV)
 		OnDelete(kv KV)
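Since the added comment tightens the contract, here is a minimal sketch of a conforming listener (a hypothetical type, assumed to live in a package that can see KV): overwriting on add and deleting on remove make both callbacks idempotent, and the mutex makes them thread-safe.

    import "sync"

    type storeListener struct {
        lock sync.Mutex
        vals map[string]string
    }

    func (s *storeListener) OnAdd(kv KV) {
        s.lock.Lock()
        defer s.lock.Unlock()
        s.vals[kv.Key] = kv.Val // replaying the same add leaves the same state
    }

    func (s *storeListener) OnDelete(kv KV) {
        s.lock.Lock()
        defer s.lock.Unlock()
        delete(s.vals, kv.Key) // deleting an absent key is a no-op
    }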

View File

@@ -4,7 +4,6 @@ import (
 	"sync"
 	"sync/atomic"
 
-	"github.com/zeromicro/go-zero/core/collection"
 	"github.com/zeromicro/go-zero/core/discov/internal"
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/syncx"
@@ -17,10 +16,9 @@ type (
 	// A Subscriber is used to subscribe the given key on an etcd cluster.
 	Subscriber struct {
 		endpoints  []string
-		key        string
 		exclusive  bool
-		exactMatch bool
+		key        string
+		exactMatch bool
 		items      *container
 	}
 )
@@ -39,7 +37,7 @@ func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) {
 	}
 	sub.items = newContainer(sub.exclusive)
 
-	if err := internal.GetRegistry().Monitor(endpoints, key, sub.items, sub.exactMatch); err != nil {
+	if err := internal.GetRegistry().Monitor(endpoints, key, sub.exactMatch, sub.items); err != nil {
 		return nil, err
 	}
@@ -51,16 +49,16 @@ func (s *Subscriber) AddListener(listener func()) {
 	s.items.addListener(listener)
 }
 
+// Close closes the subscriber.
+func (s *Subscriber) Close() {
+	internal.GetRegistry().Unmonitor(s.endpoints, s.key, s.exactMatch, s.items)
+}
+
 // Values returns all the subscription values.
 func (s *Subscriber) Values() []string {
 	return s.items.getValues()
 }
 
-// Close closes the subscriber.
-func (s *Subscriber) Close() {
-	internal.GetRegistry().Unmonitor(s.endpoints, s.key, s.items)
-}
-
 // Exclusive means that key value can only be 1:1,
 // which means later added value will remove the keys associated with the same value previously.
 func Exclusive() SubOption {
@@ -92,7 +90,7 @@ func WithSubEtcdTLS(certFile, certKeyFile, caFile string, insecureSkipVerify bool) SubOption {
 
 type container struct {
 	exclusive bool
-	values    map[string]*collection.Set
+	values    map[string][]string
 	mapping   map[string]string
 	snapshot  atomic.Value
 	dirty     *syncx.AtomicBool
@@ -103,7 +101,7 @@ type container struct {
 func newContainer(exclusive bool) *container {
 	return &container{
 		exclusive: exclusive,
-		values:    make(map[string]*collection.Set),
+		values:    make(map[string][]string),
 		mapping:   make(map[string]string),
 		dirty:     syncx.ForAtomicBool(true),
 	}
@@ -125,21 +123,15 @@ func (c *container) addKv(key, value string) ([]string, bool) {
 	defer c.lock.Unlock()
 	c.dirty.Set(true)
 
-	if c.values[value] == nil {
-		c.values[value] = collection.NewSet()
-	}
-	keys := c.values[value].KeysStr()
+	keys := c.values[value]
 	previous := append([]string(nil), keys...)
 	early := len(keys) > 0
 	if c.exclusive && early {
 		for _, each := range keys {
 			c.doRemoveKey(each)
 		}
-		if c.values[value] == nil {
-			c.values[value] = collection.NewSet()
-		}
 	}
-	c.values[value].AddStr(key)
+	c.values[value] = append(c.values[value], key)
 	c.mapping[key] = value
 
 	if early {
@@ -162,12 +154,18 @@ func (c *container) doRemoveKey(key string) {
 	}
 
 	delete(c.mapping, key)
-	if c.values[server] == nil {
-		return
-	}
-	c.values[server].Remove(key)
+	keys := c.values[server]
+	remain := keys[:0]
 
-	if c.values[server].Count() == 0 {
+	for _, k := range keys {
+		if k != key {
+			remain = append(remain, k)
+		}
+	}
+
+	if len(remain) > 0 {
+		c.values[server] = remain
+	} else {
 		delete(c.values, server)
 	}
 }
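With Close wired through to Unmonitor, the public API now supports a full subscribe/unsubscribe round trip. A hedged usage sketch (the endpoints and key are illustrative, not from the patch):

    package main

    import (
        "fmt"

        "github.com/zeromicro/go-zero/core/discov"
    )

    func main() {
        sub, err := discov.NewSubscriber([]string{"localhost:2379"}, "services/user")
        if err != nil {
            panic(err)
        }
        // Close removes this subscriber's listener; once a watch key has no
        // listeners left, the registry cancels the underlying etcd watch.
        defer sub.Close()

        sub.AddListener(func() {
            fmt.Println("values changed:", sub.Values())
        })
    }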

View File

@@ -214,18 +214,6 @@ func TestSubscriber(t *testing.T) {
 	assert.Equal(t, int32(1), atomic.LoadInt32(&count))
 }
 
-func TestSubscriberClos(t *testing.T) {
-	l := newContainer(false)
-	sub := &Subscriber{
-		endpoints: []string{"localhost:2379"},
-		key:       "foo",
-		items:     l,
-	}
-	_ = internal.GetRegistry().Monitor(sub.endpoints, sub.key, l, false)
-	sub.Close()
-	assert.Empty(t, sub.items.listeners)
-}
-
 func TestWithSubEtcdAccount(t *testing.T) {
 	endpoints := []string{"localhost:2379"}
 	user := stringx.Rand()
@@ -237,3 +225,28 @@ func TestWithSubEtcdAccount(t *testing.T) {
 	assert.Equal(t, user, account.User)
 	assert.Equal(t, "bar", account.Pass)
 }
+
+func TestWithExactMatch(t *testing.T) {
+	sub := new(Subscriber)
+	WithExactMatch()(sub)
+	sub.items = newContainer(sub.exclusive)
+	var count int32
+	sub.AddListener(func() {
+		atomic.AddInt32(&count, 1)
+	})
+	sub.items.notifyChange()
+	assert.Empty(t, sub.Values())
+	assert.Equal(t, int32(1), atomic.LoadInt32(&count))
+}
+
+func TestSubscriberClose(t *testing.T) {
+	l := newContainer(false)
+	sub := &Subscriber{
+		endpoints: []string{"localhost:12379"},
+		key:       "foo",
+		items:     l,
+	}
+	assert.NotPanics(t, func() {
+		sub.Close()
+	})
+}

View File

@@ -38,9 +38,24 @@ func (b *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
 	sub.AddListener(update)
 	update()
 
-	return &nopResolver{cc: cc, closeFunc: func() { sub.Close() }}, nil
+	return &discovResolver{
+		cc:  cc,
+		sub: sub,
+	}, nil
 }
 
 func (b *discovBuilder) Scheme() string {
 	return DiscovScheme
 }
+
+type discovResolver struct {
+	cc  resolver.ClientConn
+	sub *discov.Subscriber
+}
+
+func (r *discovResolver) Close() {
+	r.sub.Close()
+}
+
+func (r *discovResolver) ResolveNow(_ resolver.ResolveNowOptions) {
+}
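Why a dedicated discovResolver matters here: grpc-go calls Close on a resolver when the ClientConn shuts down, and the channel idleness manager does the same when the connection goes idle, rebuilding the resolver on the next RPC. Forwarding Close to the etcd Subscriber deregisters the listener so a rebuilt resolver can monitor the key again cleanly, which appears to be the idle-manager scenario named in the commit title. For reference, the interface these types satisfy (from google.golang.org/grpc/resolver, abridged):

    type Resolver interface {
        // ResolveNow is a hint to re-resolve; it may be a no-op.
        ResolveNow(ResolveNowOptions)
        // Close closes the resolver.
        Close()
    }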

View File

@@ -28,6 +28,10 @@ type kubeResolver struct {
 	stopCh chan struct{}
 }
 
+func (r *kubeResolver) Close() {
+	close(r.stopCh)
+}
+
 func (r *kubeResolver) ResolveNow(_ resolver.ResolveNowOptions) {}
 
 func (r *kubeResolver) start() {
@@ -36,10 +40,6 @@ func (r *kubeResolver) start() {
 	})
 }
 
-func (r *kubeResolver) Close() {
-	close(r.stopCh)
-}
-
 type kubeBuilder struct{}
 
 func (b *kubeBuilder) Build(target resolver.Target, cc resolver.ClientConn,

View File

@@ -37,14 +37,10 @@ func register() {
 }
 
 type nopResolver struct {
-	cc        resolver.ClientConn
-	closeFunc func()
+	cc resolver.ClientConn
 }
 
 func (r *nopResolver) Close() {
-	if r.closeFunc != nil {
-		r.closeFunc()
-	}
 }
 
 func (r *nopResolver) ResolveNow(_ resolver.ResolveNowOptions) {

View File

@@ -1,7 +1,6 @@
 package internal
 
 import (
-	"github.com/zeromicro/go-zero/core/discov"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -19,20 +18,6 @@ func TestNopResolver(t *testing.T) {
 	})
 }
 
-func TestNopResolver_Close(t *testing.T) {
-	var isChanged bool
-	r := nopResolver{}
-	r.Close()
-	assert.False(t, isChanged)
-	r = nopResolver{
-		closeFunc: func() {
-			isChanged = true
-		},
-	}
-	r.Close()
-	assert.True(t, isChanged)
-}
-
 type mockedClientConn struct {
 	state resolver.State
 	err   error