diff --git a/core/discov/internal/registry.go b/core/discov/internal/registry.go index bc26a588..f440adc3 100644 --- a/core/discov/internal/registry.go +++ b/core/discov/internal/registry.go @@ -5,28 +5,29 @@ import ( "errors" "fmt" "io" - "slices" "sort" "strings" "sync" "time" - "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" - clientv3 "go.etcd.io/etcd/client/v3" - - "github.com/zeromicro/go-zero/core/contextx" "github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/mathx" "github.com/zeromicro/go-zero/core/syncx" "github.com/zeromicro/go-zero/core/threading" + "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" + clientv3 "go.etcd.io/etcd/client/v3" ) +const coolDownDeviation = 0.05 + var ( registry = Registry{ clusters: make(map[string]*cluster), } - connManager = syncx.NewResourceManager() - errClosed = errors.New("etcd monitor chan has been closed") + connManager = syncx.NewResourceManager() + coolDownUnstable = mathx.NewUnstable(coolDownDeviation) + errClosed = errors.New("etcd monitor chan has been closed") ) // A Registry is a registry that manages the etcd client connections. @@ -42,44 +43,92 @@ func GetRegistry() *Registry { // GetConn returns an etcd client connection associated with given endpoints. func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) { - c, _ := r.getCluster(endpoints) + c, _ := r.getOrCreateCluster(endpoints) return c.getClient() } // Monitor monitors the key on given etcd endpoints, notify with the given UpdateListener. -func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener, exactMatch bool) error { - c, exists := r.getCluster(endpoints) +func (r *Registry) Monitor(endpoints []string, key string, exactMatch bool, l UpdateListener) error { + wkey := watchKey{ + key: key, + exactMatch: exactMatch, + } + + c, exists := r.getOrCreateCluster(endpoints) // if exists, the existing values should be updated to the listener. if exists { - kvs := c.getCurrent(key) - for _, kv := range kvs { - l.OnAdd(kv) + c.lock.Lock() + watcher, ok := c.watchers[wkey] + if ok { + watcher.listeners = append(watcher.listeners, l) + } + c.lock.Unlock() + + if ok { + kvs := c.getCurrent(wkey) + for _, kv := range kvs { + l.OnAdd(kv) + } + + return nil } } - return c.monitor(key, l, exactMatch) + return c.monitor(wkey, l) } -// Unmonitor cancel monitoring of given endpoints and keys, and remove the listener. -func (r *Registry) Unmonitor(endpoints []string, key string, l UpdateListener) { +func (r *Registry) Unmonitor(endpoints []string, key string, exactMatch bool, l UpdateListener) { c, exists := r.getCluster(endpoints) - // if not exists, return. if !exists { return } - c.unmonitor(key, l) + wkey := watchKey{ + key: key, + exactMatch: exactMatch, + } + + c.lock.Lock() + defer c.lock.Unlock() + + watcher, ok := c.watchers[wkey] + if !ok { + return + } + + for i, listener := range watcher.listeners { + if listener == l { + watcher.listeners = append(watcher.listeners[:i], watcher.listeners[i+1:]...) + break + } + } + + if len(watcher.listeners) == 0 { + if watcher.cancel != nil { + watcher.cancel() + } + delete(c.watchers, wkey) + } } -func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) { +func (r *Registry) getCluster(endpoints []string) (*cluster, bool) { clusterKey := getClusterKey(endpoints) + r.lock.RLock() - c, exists = r.clusters[clusterKey] + c, ok := r.clusters[clusterKey] r.lock.RUnlock() + return c, ok +} + +func (r *Registry) getOrCreateCluster(endpoints []string) (c *cluster, exists bool) { + c, exists = r.getCluster(endpoints) if !exists { + clusterKey := getClusterKey(endpoints) + r.lock.Lock() defer r.lock.Unlock() + // double-check locking c, exists = r.clusters[clusterKey] if !exists { @@ -91,34 +140,51 @@ func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) { return } -type cluster struct { - endpoints []string - key string - values map[string]map[string]string - listeners map[string][]UpdateListener - watchGroup *threading.RoutineGroup - done chan lang.PlaceholderType - lock sync.RWMutex - exactMatch bool - watchCtx map[string]context.CancelFunc - watchFlag map[string]bool -} +type ( + watchKey struct { + key string + exactMatch bool + } + + watchValue struct { + listeners []UpdateListener + values map[string]string + cancel context.CancelFunc + } + + cluster struct { + endpoints []string + key string + watchers map[watchKey]*watchValue + watchGroup *threading.RoutineGroup + done chan lang.PlaceholderType + lock sync.RWMutex + } +) func newCluster(endpoints []string) *cluster { return &cluster{ endpoints: endpoints, key: getClusterKey(endpoints), - values: make(map[string]map[string]string), - listeners: make(map[string][]UpdateListener), + watchers: make(map[watchKey]*watchValue), watchGroup: threading.NewRoutineGroup(), done: make(chan lang.PlaceholderType), - watchCtx: make(map[string]context.CancelFunc), - watchFlag: make(map[string]bool), } } -func (c *cluster) context(cli EtcdClient) context.Context { - return contextx.ValueOnlyFrom(cli.Ctx()) +func (c *cluster) addListener(key watchKey, l UpdateListener) { + c.lock.Lock() + defer c.lock.Unlock() + + watcher, ok := c.watchers[key] + if ok { + watcher.listeners = append(watcher.listeners, l) + return + } + + val := newWatchValue() + val.listeners = []UpdateListener{l} + c.watchers[key] = val } func (c *cluster) getClient() (EtcdClient, error) { @@ -132,12 +198,17 @@ func (c *cluster) getClient() (EtcdClient, error) { return val.(EtcdClient), nil } -func (c *cluster) getCurrent(key string) []KV { +func (c *cluster) getCurrent(key watchKey) []KV { c.lock.RLock() defer c.lock.RUnlock() + watcher, ok := c.watchers[key] + if !ok { + return nil + } + var kvs []KV - for k, v := range c.values[key] { + for k, v := range watcher.values { kvs = append(kvs, KV{ Key: k, Val: v, @@ -147,43 +218,23 @@ func (c *cluster) getCurrent(key string) []KV { return kvs } -func (c *cluster) handleChanges(key string, kvs []KV) { - var add []KV - var remove []KV - +func (c *cluster) handleChanges(key watchKey, kvs []KV) { c.lock.Lock() - listeners := append([]UpdateListener(nil), c.listeners[key]...) - vals, ok := c.values[key] + watcher, ok := c.watchers[key] if !ok { - add = kvs - vals = make(map[string]string) - for _, kv := range kvs { - vals[kv.Key] = kv.Val - } - c.values[key] = vals - } else { - m := make(map[string]string) - for _, kv := range kvs { - m[kv.Key] = kv.Val - } - for k, v := range vals { - if val, ok := m[k]; !ok || v != val { - remove = append(remove, KV{ - Key: k, - Val: v, - }) - } - } - for k, v := range m { - if val, ok := vals[k]; !ok || v != val { - add = append(add, KV{ - Key: k, - Val: v, - }) - } - } - c.values[key] = m + c.lock.Unlock() + return } + + listeners := append([]UpdateListener(nil), watcher.listeners...) + // watcher.values cannot be nil + vals := watcher.values + newVals := make(map[string]string, len(kvs)+len(vals)) + for _, kv := range kvs { + newVals[kv.Key] = kv.Val + } + add, remove := calculateChanges(vals, newVals) + watcher.values = newVals c.lock.Unlock() for _, kv := range add { @@ -198,20 +249,22 @@ func (c *cluster) handleChanges(key string, kvs []KV) { } } -func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) { +func (c *cluster) handleWatchEvents(key watchKey, events []*clientv3.Event) { c.lock.RLock() - listeners := append([]UpdateListener(nil), c.listeners[key]...) + watcher, ok := c.watchers[key] + if !ok { + c.lock.RUnlock() + return + } + + listeners := append([]UpdateListener(nil), watcher.listeners...) c.lock.RUnlock() for _, ev := range events { switch ev.Type { case clientv3.EventTypePut: c.lock.Lock() - if vals, ok := c.values[key]; ok { - vals[string(ev.Kv.Key)] = string(ev.Kv.Value) - } else { - c.values[key] = map[string]string{string(ev.Kv.Key): string(ev.Kv.Value)} - } + watcher.values[string(ev.Kv.Key)] = string(ev.Kv.Value) c.lock.Unlock() for _, l := range listeners { l.OnAdd(KV{ @@ -221,9 +274,7 @@ func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) { } case clientv3.EventTypeDelete: c.lock.Lock() - if vals, ok := c.values[key]; ok { - delete(vals, string(ev.Kv.Key)) - } + delete(watcher.values, string(ev.Kv.Key)) c.lock.Unlock() for _, l := range listeners { l.OnDelete(KV{ @@ -237,15 +288,15 @@ func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) { } } -func (c *cluster) load(cli EtcdClient, key string) int64 { +func (c *cluster) load(cli EtcdClient, key watchKey) int64 { var resp *clientv3.GetResponse for { var err error - ctx, cancel := context.WithTimeout(c.context(cli), RequestTimeout) - if c.exactMatch { - resp, err = cli.Get(ctx, key) + ctx, cancel := context.WithTimeout(cli.Ctx(), RequestTimeout) + if key.exactMatch { + resp, err = cli.Get(ctx, key.key) } else { - resp, err = cli.Get(ctx, makeKeyPrefix(key), clientv3.WithPrefix()) + resp, err = cli.Get(ctx, makeKeyPrefix(key.key), clientv3.WithPrefix()) } cancel() @@ -253,8 +304,8 @@ func (c *cluster) load(cli EtcdClient, key string) int64 { break } - logx.Errorf("%s, key is %s", err.Error(), key) - time.Sleep(coolDownInterval) + logx.Errorf("%s, key: %s, exactMatch: %t", err.Error(), key.key, key.exactMatch) + time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval)) } var kvs []KV @@ -270,33 +321,19 @@ func (c *cluster) load(cli EtcdClient, key string) int64 { return resp.Header.Revision } -func (c *cluster) monitor(key string, l UpdateListener, exactMatch bool) error { - c.lock.Lock() - c.listeners[key] = append(c.listeners[key], l) - c.exactMatch = exactMatch - c.lock.Unlock() - - if !c.watchFlag[key] { - cli, err := c.getClient() - if err != nil { - return err - } - - rev := c.load(cli, key) - c.watchGroup.Run(func() { - c.watch(cli, key, rev) - }) +func (c *cluster) monitor(key watchKey, l UpdateListener) error { + cli, err := c.getClient() + if err != nil { + return err } - return nil -} - -func (c *cluster) unmonitor(key string, l UpdateListener) { - c.lock.Lock() - defer c.lock.Unlock() - c.listeners[key] = slices.DeleteFunc(c.listeners[key], func(listener UpdateListener) bool { - return l == listener + c.addListener(key, l) + rev := c.load(cli, key) + c.watchGroup.Run(func() { + c.watch(cli, key, rev) }) + + return nil } func (c *cluster) newClient() (EtcdClient, error) { @@ -312,17 +349,22 @@ func (c *cluster) newClient() (EtcdClient, error) { func (c *cluster) reload(cli EtcdClient) { c.lock.Lock() + // cancel the previous watches close(c.done) c.watchGroup.Wait() + var keys []watchKey + for wk, wval := range c.watchers { + keys = append(keys, wk) + if wval.cancel != nil { + wval.cancel() + } + } + c.done = make(chan lang.PlaceholderType) c.watchGroup = threading.NewRoutineGroup() - var keys []string - for k := range c.listeners { - keys = append(keys, k) - } - c.clearWatch() c.lock.Unlock() + // start new watches for _, key := range keys { k := key c.watchGroup.Run(func() { @@ -332,10 +374,9 @@ func (c *cluster) reload(cli EtcdClient) { } } -func (c *cluster) watch(cli EtcdClient, key string, rev int64) { - ctx := c.addWatch(key, cli) +func (c *cluster) watch(cli EtcdClient, key watchKey, rev int64) { for { - err := c.watchStream(cli, key, rev, ctx) + err := c.watchStream(cli, key, rev) if err == nil { return } @@ -350,21 +391,8 @@ func (c *cluster) watch(cli EtcdClient, key string, rev int64) { } } -func (c *cluster) watchStream(cli EtcdClient, key string, rev int64, ctx context.Context) error { - var ( - rch clientv3.WatchChan - ops []clientv3.OpOption - watchKey = key - ) - if !c.exactMatch { - watchKey = makeKeyPrefix(key) - ops = append(ops, clientv3.WithPrefix()) - } - if rev != 0 { - ops = append(ops, clientv3.WithRev(rev+1)) - } - - rch = cli.Watch(clientv3.WithRequireLeader(ctx), watchKey, ops...) +func (c *cluster) watchStream(cli EtcdClient, key watchKey, rev int64) error { + ctx, rch := c.setupWatch(cli, key, rev) for { select { @@ -380,14 +408,46 @@ func (c *cluster) watchStream(cli EtcdClient, key string, rev int64, ctx context } c.handleWatchEvents(key, wresp.Events) - case <-c.done: - return nil case <-ctx.Done(): return nil + case <-c.done: + return nil } } } +func (c *cluster) setupWatch(cli EtcdClient, key watchKey, rev int64) (context.Context, clientv3.WatchChan) { + var ( + rch clientv3.WatchChan + ops []clientv3.OpOption + wkey = key.key + ) + + if !key.exactMatch { + wkey = makeKeyPrefix(key.key) + ops = append(ops, clientv3.WithPrefix()) + } + if rev != 0 { + ops = append(ops, clientv3.WithRev(rev+1)) + } + + ctx, cancel := context.WithCancel(cli.Ctx()) + if watcher, ok := c.watchers[key]; ok { + watcher.cancel = cancel + } else { + val := newWatchValue() + val.cancel = cancel + + c.lock.Lock() + c.watchers[key] = val + c.lock.Unlock() + } + + rch = cli.Watch(clientv3.WithRequireLeader(ctx), wkey, ops...) + + return ctx, rch +} + func (c *cluster) watchConnState(cli EtcdClient) { watcher := newStateWatcher() watcher.addListener(func() { @@ -396,23 +456,6 @@ func (c *cluster) watchConnState(cli EtcdClient) { watcher.watch(cli.ActiveConnection()) } -func (c *cluster) addWatch(key string, cli EtcdClient) context.Context { - ctx, cancel := context.WithCancel(cli.Ctx()) - c.lock.Lock() - c.watchCtx[key] = cancel - c.watchFlag[key] = true - c.lock.Unlock() - return ctx -} - -func (c *cluster) clearWatch() { - for _, cancel := range c.watchCtx { - cancel() - } - c.watchCtx = make(map[string]context.CancelFunc) - c.watchFlag = make(map[string]bool) -} - // DialClient dials an etcd cluster with given endpoints. func DialClient(endpoints []string) (EtcdClient, error) { cfg := clientv3.Config{ @@ -433,6 +476,28 @@ func DialClient(endpoints []string) (EtcdClient, error) { return clientv3.New(cfg) } +func calculateChanges(oldVals, newVals map[string]string) (add, remove []KV) { + for k, v := range newVals { + if val, ok := oldVals[k]; !ok || v != val { + add = append(add, KV{ + Key: k, + Val: v, + }) + } + } + + for k, v := range oldVals { + if val, ok := newVals[k]; !ok || v != val { + remove = append(remove, KV{ + Key: k, + Val: v, + }) + } + } + + return add, remove +} + func getClusterKey(endpoints []string) string { sort.Strings(endpoints) return strings.Join(endpoints, endpointsSeparator) @@ -441,3 +506,10 @@ func getClusterKey(endpoints []string) string { func makeKeyPrefix(key string) string { return fmt.Sprintf("%s%c", key, Delimiter) } + +// NewClient returns a watchValue that make sure values are not nil. +func newWatchValue() *watchValue { + return &watchValue{ + values: make(map[string]string), + } +} diff --git a/core/discov/internal/registry_test.go b/core/discov/internal/registry_test.go index e25b9757..2f933364 100644 --- a/core/discov/internal/registry_test.go +++ b/core/discov/internal/registry_test.go @@ -13,6 +13,7 @@ import ( "github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stringx" + "github.com/zeromicro/go-zero/core/threading" "go.etcd.io/etcd/api/v3/etcdserverpb" "go.etcd.io/etcd/api/v3/mvccpb" clientv3 "go.etcd.io/etcd/client/v3" @@ -38,9 +39,9 @@ func setMockClient(cli EtcdClient) func() { func TestGetCluster(t *testing.T) { AddAccount([]string{"first"}, "foo", "bar") - c1, _ := GetRegistry().getCluster([]string{"first"}) - c2, _ := GetRegistry().getCluster([]string{"second"}) - c3, _ := GetRegistry().getCluster([]string{"first"}) + c1, _ := GetRegistry().getOrCreateCluster([]string{"first"}) + c2, _ := GetRegistry().getOrCreateCluster([]string{"second"}) + c3, _ := GetRegistry().getOrCreateCluster([]string{"first"}) assert.Equal(t, c1, c3) assert.NotEqual(t, c1, c2) } @@ -50,6 +51,36 @@ func TestGetClusterKey(t *testing.T) { getClusterKey([]string{"remotehost:5678", "localhost:1234"})) } +func TestUnmonitor(t *testing.T) { + t.Run("no listener", func(t *testing.T) { + reg := &Registry{ + clusters: map[string]*cluster{}, + } + assert.NotPanics(t, func() { + reg.Unmonitor([]string{"any"}, "any", false, nil) + }) + }) + + t.Run("no value", func(t *testing.T) { + reg := &Registry{ + clusters: map[string]*cluster{ + "any": { + watchers: map[watchKey]*watchValue{ + { + key: "any", + }: { + values: map[string]string{}, + }, + }, + }, + }, + } + assert.NotPanics(t, func() { + reg.Unmonitor([]string{"any"}, "another", false, nil) + }) + }) +} + func TestCluster_HandleChanges(t *testing.T) { ctrl := gomock.NewController(t) l := NewMockUpdateListener(ctrl) @@ -78,8 +109,14 @@ func TestCluster_HandleChanges(t *testing.T) { Val: "4", }) c := newCluster([]string{"any"}) - c.listeners["any"] = []UpdateListener{l} - c.handleChanges("any", []KV{ + key := watchKey{ + key: "any", + exactMatch: false, + } + c.watchers[key] = &watchValue{ + listeners: []UpdateListener{l}, + } + c.handleChanges(key, []KV{ { Key: "first", Val: "1", @@ -92,8 +129,8 @@ func TestCluster_HandleChanges(t *testing.T) { assert.EqualValues(t, map[string]string{ "first": "1", "second": "2", - }, c.values["any"]) - c.handleChanges("any", []KV{ + }, c.watchers[key].values) + c.handleChanges(key, []KV{ { Key: "third", Val: "3", @@ -106,7 +143,7 @@ func TestCluster_HandleChanges(t *testing.T) { assert.EqualValues(t, map[string]string{ "third": "3", "fourth": "4", - }, c.values["any"]) + }, c.watchers[key].values) } func TestCluster_Load(t *testing.T) { @@ -126,9 +163,11 @@ func TestCluster_Load(t *testing.T) { }, nil) cli.EXPECT().Ctx().Return(context.Background()) c := &cluster{ - values: make(map[string]map[string]string), + watchers: make(map[watchKey]*watchValue), } - c.load(cli, "any") + c.load(cli, watchKey{ + key: "any", + }) } func TestCluster_Watch(t *testing.T) { @@ -160,13 +199,16 @@ func TestCluster_Watch(t *testing.T) { var wg sync.WaitGroup wg.Add(1) c := &cluster{ - values: make(map[string]map[string]string), - listeners: make(map[string][]UpdateListener), - watchCtx: make(map[string]context.CancelFunc), - watchFlag: make(map[string]bool), + watchers: make(map[watchKey]*watchValue), + } + key := watchKey{ + key: "any", } listener := NewMockUpdateListener(ctrl) - c.listeners["any"] = []UpdateListener{listener} + c.watchers[key] = &watchValue{ + listeners: []UpdateListener{listener}, + values: make(map[string]string), + } listener.EXPECT().OnAdd(gomock.Any()).Do(func(kv KV) { assert.Equal(t, "hello", kv.Key) assert.Equal(t, "world", kv.Val) @@ -175,7 +217,7 @@ func TestCluster_Watch(t *testing.T) { listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ any) { wg.Done() }).MaxTimes(1) - go c.watch(cli, "any", 0) + go c.watch(cli, key, 0) ch <- clientv3.WatchResponse{ Events: []*clientv3.Event{ { @@ -213,17 +255,111 @@ func TestClusterWatch_RespFailures(t *testing.T) { ch := make(chan clientv3.WatchResponse) cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() - c := newCluster([]string{}) + c := &cluster{ + watchers: make(map[watchKey]*watchValue), + } c.done = make(chan lang.PlaceholderType) go func() { ch <- resp close(c.done) }() - c.watch(cli, "any", 0) + key := watchKey{ + key: "any", + } + c.watch(cli, key, 0) }) } } +func TestCluster_getCurrent(t *testing.T) { + t.Run("no value", func(t *testing.T) { + c := &cluster{ + watchers: map[watchKey]*watchValue{ + { + key: "any", + }: { + values: map[string]string{}, + }, + }, + } + assert.Nil(t, c.getCurrent(watchKey{ + key: "another", + })) + }) +} + +func TestCluster_handleWatchEvents(t *testing.T) { + t.Run("no value", func(t *testing.T) { + c := &cluster{ + watchers: map[watchKey]*watchValue{ + { + key: "any", + }: { + values: map[string]string{}, + }, + }, + } + assert.NotPanics(t, func() { + c.handleWatchEvents(watchKey{ + key: "another", + }, nil) + }) + }) +} + +func TestCluster_addListener(t *testing.T) { + t.Run("has listener", func(t *testing.T) { + c := &cluster{ + watchers: map[watchKey]*watchValue{ + { + key: "any", + }: { + listeners: make([]UpdateListener, 0), + }, + }, + } + assert.NotPanics(t, func() { + c.addListener(watchKey{ + key: "any", + }, nil) + }) + }) + + t.Run("no listener", func(t *testing.T) { + c := &cluster{ + watchers: map[watchKey]*watchValue{ + { + key: "any", + }: { + listeners: make([]UpdateListener, 0), + }, + }, + } + assert.NotPanics(t, func() { + c.addListener(watchKey{ + key: "another", + }, nil) + }) + }) +} + +func TestCluster_reload(t *testing.T) { + c := &cluster{ + watchers: map[watchKey]*watchValue{}, + watchGroup: threading.NewRoutineGroup(), + done: make(chan lang.PlaceholderType), + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cli := NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + assert.NotPanics(t, func() { + c.reload(cli) + }) +} + func TestClusterWatch_CloseChan(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -233,44 +369,17 @@ func TestClusterWatch_CloseChan(t *testing.T) { ch := make(chan clientv3.WatchResponse) cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() - c := newCluster([]string{}) + c := &cluster{ + watchers: make(map[watchKey]*watchValue), + } c.done = make(chan lang.PlaceholderType) go func() { close(ch) close(c.done) }() - c.watch(cli, "any", 0) -} - -func TestClusterWatch_CtxCancel(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - cli := NewMockEtcdClient(ctrl) - restore := setMockClient(cli) - defer restore() - ch := make(chan clientv3.WatchResponse) - cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() - ctx, cancelFunc := context.WithCancel(context.Background()) - cli.EXPECT().Ctx().Return(ctx).AnyTimes() - c := newCluster([]string{}) - c.done = make(chan lang.PlaceholderType) - go func() { - cancelFunc() - close(ch) - }() - c.watch(cli, "any", 0) -} - -func TestCluster_ClearWatch(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - c := &cluster{ - watchCtx: map[string]context.CancelFunc{"foo": cancel}, - watchFlag: map[string]bool{"foo": true}, - } - c.clearWatch() - assert.Equal(t, ctx.Err(), context.Canceled) - assert.Equal(t, 0, len(c.watchCtx)) - assert.Equal(t, 0, len(c.watchFlag)) + c.watch(cli, watchKey{ + key: "any", + }, 0) } func TestValueOnlyContext(t *testing.T) { @@ -313,39 +422,59 @@ func TestRegistry_Monitor(t *testing.T) { GetRegistry().lock.Lock() GetRegistry().clusters = map[string]*cluster{ getClusterKey(endpoints): { - listeners: map[string][]UpdateListener{}, - values: map[string]map[string]string{ - "foo": { - "bar": "baz", + watchers: map[watchKey]*watchValue{ + watchKey{ + key: "foo", + exactMatch: true, + }: { + values: map[string]string{ + "bar": "baz", + }, }, }, - watchCtx: map[string]context.CancelFunc{}, - watchFlag: map[string]bool{}, }, } GetRegistry().lock.Unlock() - assert.Error(t, GetRegistry().Monitor(endpoints, "foo", new(mockListener), false)) + assert.Error(t, GetRegistry().Monitor(endpoints, "foo", false, new(mockListener))) } func TestRegistry_Unmonitor(t *testing.T) { - l := new(mockListener) + svr, err := mockserver.StartMockServers(1) + assert.NoError(t, err) + svr.StartAt(0) + + _, cancel := context.WithCancel(context.Background()) + endpoints := []string{svr.Servers[0].Address} GetRegistry().lock.Lock() GetRegistry().clusters = map[string]*cluster{ getClusterKey(endpoints): { - listeners: map[string][]UpdateListener{"foo": {l}}, - values: map[string]map[string]string{ - "foo": { - "bar": "baz", + watchers: map[watchKey]*watchValue{ + watchKey{ + key: "foo", + exactMatch: true, + }: { + values: map[string]string{ + "bar": "baz", + }, + cancel: cancel, }, }, }, } GetRegistry().lock.Unlock() l := new(mockListener) - assert.Error(t, GetRegistry().Monitor(endpoints, "foo", l, false)) - assert.Equal(t, 1, len(GetRegistry().clusters[getClusterKey(endpoints)].listeners["foo"])) - GetRegistry().Unmonitor(endpoints, "foo", l) - assert.Equal(t, 0, len(GetRegistry().clusters[getClusterKey(endpoints)].listeners["foo"])) + assert.NoError(t, GetRegistry().Monitor(endpoints, "foo", true, l)) + watchVals := GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{ + key: "foo", + exactMatch: true, + }] + assert.Equal(t, 1, len(watchVals.listeners)) + GetRegistry().Unmonitor(endpoints, "foo", true, l) + watchVals = GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{ + key: "foo", + exactMatch: true, + }] + assert.Nil(t, watchVals) } type mockListener struct { diff --git a/core/discov/internal/updatelistener.go b/core/discov/internal/updatelistener.go index 535ceda1..5bc288af 100644 --- a/core/discov/internal/updatelistener.go +++ b/core/discov/internal/updatelistener.go @@ -10,6 +10,7 @@ type ( } // UpdateListener wraps the OnAdd and OnDelete methods. + // The implementation should be thread-safe and idempotent. UpdateListener interface { OnAdd(kv KV) OnDelete(kv KV) diff --git a/core/discov/subscriber.go b/core/discov/subscriber.go index 3aeed274..7cb7fd51 100644 --- a/core/discov/subscriber.go +++ b/core/discov/subscriber.go @@ -4,7 +4,6 @@ import ( "sync" "sync/atomic" - "github.com/zeromicro/go-zero/core/collection" "github.com/zeromicro/go-zero/core/discov/internal" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/syncx" @@ -17,10 +16,9 @@ type ( // A Subscriber is used to subscribe the given key on an etcd cluster. Subscriber struct { endpoints []string - key string exclusive bool - exactMatch bool key string + exactMatch bool items *container } ) @@ -39,7 +37,7 @@ func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscrib } sub.items = newContainer(sub.exclusive) - if err := internal.GetRegistry().Monitor(endpoints, key, sub.items, sub.exactMatch); err != nil { + if err := internal.GetRegistry().Monitor(endpoints, key, sub.exactMatch, sub.items); err != nil { return nil, err } @@ -51,16 +49,16 @@ func (s *Subscriber) AddListener(listener func()) { s.items.addListener(listener) } +// Close closes the subscriber. +func (s *Subscriber) Close() { + internal.GetRegistry().Unmonitor(s.endpoints, s.key, s.exactMatch, s.items) +} + // Values returns all the subscription values. func (s *Subscriber) Values() []string { return s.items.getValues() } -// Close s. -func (s *Subscriber) Close() { - internal.GetRegistry().Unmonitor(s.endpoints, s.key, s.items) -} - // Exclusive means that key value can only be 1:1, // which means later added value will remove the keys associated with the same value previously. func Exclusive() SubOption { @@ -92,7 +90,7 @@ func WithSubEtcdTLS(certFile, certKeyFile, caFile string, insecureSkipVerify boo type container struct { exclusive bool - values map[string]*collection.Set + values map[string][]string mapping map[string]string snapshot atomic.Value dirty *syncx.AtomicBool @@ -103,7 +101,7 @@ type container struct { func newContainer(exclusive bool) *container { return &container{ exclusive: exclusive, - values: make(map[string]*collection.Set), + values: make(map[string][]string), mapping: make(map[string]string), dirty: syncx.ForAtomicBool(true), } @@ -125,21 +123,15 @@ func (c *container) addKv(key, value string) ([]string, bool) { defer c.lock.Unlock() c.dirty.Set(true) - if c.values[value] == nil { - c.values[value] = collection.NewSet() - } - keys := c.values[value].KeysStr() + keys := c.values[value] previous := append([]string(nil), keys...) early := len(keys) > 0 if c.exclusive && early { for _, each := range keys { c.doRemoveKey(each) } - if c.values[value] == nil { - c.values[value] = collection.NewSet() - } } - c.values[value].AddStr(key) + c.values[value] = append(c.values[value], key) c.mapping[key] = value if early { @@ -162,12 +154,18 @@ func (c *container) doRemoveKey(key string) { } delete(c.mapping, key) - if c.values[server] == nil { - return - } - c.values[server].Remove(key) + keys := c.values[server] + remain := keys[:0] - if c.values[server].Count() == 0 { + for _, k := range keys { + if k != key { + remain = append(remain, k) + } + } + + if len(remain) > 0 { + c.values[server] = remain + } else { delete(c.values, server) } } diff --git a/core/discov/subscriber_test.go b/core/discov/subscriber_test.go index 1f760979..b8762afa 100644 --- a/core/discov/subscriber_test.go +++ b/core/discov/subscriber_test.go @@ -214,18 +214,6 @@ func TestSubscriber(t *testing.T) { assert.Equal(t, int32(1), atomic.LoadInt32(&count)) } -func TestSubscriberClos(t *testing.T) { - l := newContainer(false) - sub := &Subscriber{ - endpoints: []string{"localhost:2379"}, - key: "foo", - items: l, - } - _ = internal.GetRegistry().Monitor(sub.endpoints, sub.key, l, false) - sub.Close() - assert.Empty(t, sub.items.listeners) -} - func TestWithSubEtcdAccount(t *testing.T) { endpoints := []string{"localhost:2379"} user := stringx.Rand() @@ -237,3 +225,28 @@ func TestWithSubEtcdAccount(t *testing.T) { assert.Equal(t, user, account.User) assert.Equal(t, "bar", account.Pass) } + +func TestWithExactMatch(t *testing.T) { + sub := new(Subscriber) + WithExactMatch()(sub) + sub.items = newContainer(sub.exclusive) + var count int32 + sub.AddListener(func() { + atomic.AddInt32(&count, 1) + }) + sub.items.notifyChange() + assert.Empty(t, sub.Values()) + assert.Equal(t, int32(1), atomic.LoadInt32(&count)) +} + +func TestSubscriberClose(t *testing.T) { + l := newContainer(false) + sub := &Subscriber{ + endpoints: []string{"localhost:12379"}, + key: "foo", + items: l, + } + assert.NotPanics(t, func() { + sub.Close() + }) +} diff --git a/zrpc/resolver/internal/discovbuilder.go b/zrpc/resolver/internal/discovbuilder.go index 5a91ee73..1aa5d8f3 100644 --- a/zrpc/resolver/internal/discovbuilder.go +++ b/zrpc/resolver/internal/discovbuilder.go @@ -38,9 +38,24 @@ func (b *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ sub.AddListener(update) update() - return &nopResolver{cc: cc, closeFunc: func() { sub.Close() }}, nil + return &discovResolver{ + cc: cc, + sub: sub, + }, nil } func (b *discovBuilder) Scheme() string { return DiscovScheme } + +type discovResolver struct { + cc resolver.ClientConn + sub *discov.Subscriber +} + +func (r *discovResolver) Close() { + r.sub.Close() +} + +func (r *discovResolver) ResolveNow(_ resolver.ResolveNowOptions) { +} diff --git a/zrpc/resolver/internal/kubebuilder.go b/zrpc/resolver/internal/kubebuilder.go index 5eecec46..c2498865 100644 --- a/zrpc/resolver/internal/kubebuilder.go +++ b/zrpc/resolver/internal/kubebuilder.go @@ -28,6 +28,10 @@ type kubeResolver struct { stopCh chan struct{} } +func (r *kubeResolver) Close() { + close(r.stopCh) +} + func (r *kubeResolver) ResolveNow(_ resolver.ResolveNowOptions) {} func (r *kubeResolver) start() { @@ -36,10 +40,6 @@ func (r *kubeResolver) start() { }) } -func (r *kubeResolver) Close() { - close(r.stopCh) -} - type kubeBuilder struct{} func (b *kubeBuilder) Build(target resolver.Target, cc resolver.ClientConn, diff --git a/zrpc/resolver/internal/resolver.go b/zrpc/resolver/internal/resolver.go index e04d65d8..7868eca8 100644 --- a/zrpc/resolver/internal/resolver.go +++ b/zrpc/resolver/internal/resolver.go @@ -37,14 +37,10 @@ func register() { } type nopResolver struct { - cc resolver.ClientConn - closeFunc func() + cc resolver.ClientConn } func (r *nopResolver) Close() { - if r.closeFunc != nil { - r.closeFunc() - } } func (r *nopResolver) ResolveNow(_ resolver.ResolveNowOptions) { diff --git a/zrpc/resolver/internal/resolver_test.go b/zrpc/resolver/internal/resolver_test.go index 99934678..7dd10ee7 100644 --- a/zrpc/resolver/internal/resolver_test.go +++ b/zrpc/resolver/internal/resolver_test.go @@ -1,7 +1,6 @@ package internal import ( - "github.com/zeromicro/go-zero/core/discov" "testing" "github.com/stretchr/testify/assert" @@ -19,20 +18,6 @@ func TestNopResolver(t *testing.T) { }) } -func TestNopResolver_Close(t *testing.T) { - var isChanged bool - r := nopResolver{} - r.Close() - assert.False(t, isChanged) - r = nopResolver{ - closeFunc: func() { - isChanged = true - }, - } - r.Close() - assert.True(t, isChanged) -} - type mockedClientConn struct { state resolver.State err error