diff --git a/core/discov/internal/registry.go b/core/discov/internal/registry.go index e1dce491..bc26a588 100644 --- a/core/discov/internal/registry.go +++ b/core/discov/internal/registry.go @@ -100,6 +100,8 @@ type cluster struct { done chan lang.PlaceholderType lock sync.RWMutex exactMatch bool + watchCtx map[string]context.CancelFunc + watchFlag map[string]bool } func newCluster(endpoints []string) *cluster { @@ -110,6 +112,8 @@ func newCluster(endpoints []string) *cluster { listeners: make(map[string][]UpdateListener), watchGroup: threading.NewRoutineGroup(), done: make(chan lang.PlaceholderType), + watchCtx: make(map[string]context.CancelFunc), + watchFlag: make(map[string]bool), } } @@ -272,15 +276,17 @@ func (c *cluster) monitor(key string, l UpdateListener, exactMatch bool) error { c.exactMatch = exactMatch c.lock.Unlock() - cli, err := c.getClient() - if err != nil { - return err - } + if !c.watchFlag[key] { + cli, err := c.getClient() + if err != nil { + return err + } - rev := c.load(cli, key) - c.watchGroup.Run(func() { - c.watch(cli, key, rev) - }) + rev := c.load(cli, key) + c.watchGroup.Run(func() { + c.watch(cli, key, rev) + }) + } return nil } @@ -314,6 +320,7 @@ func (c *cluster) reload(cli EtcdClient) { for k := range c.listeners { keys = append(keys, k) } + c.clearWatch() c.lock.Unlock() for _, key := range keys { @@ -326,8 +333,9 @@ func (c *cluster) reload(cli EtcdClient) { } func (c *cluster) watch(cli EtcdClient, key string, rev int64) { + ctx := c.addWatch(key, cli) for { - err := c.watchStream(cli, key, rev) + err := c.watchStream(cli, key, rev, ctx) if err == nil { return } @@ -342,7 +350,7 @@ func (c *cluster) watch(cli EtcdClient, key string, rev int64) { } } -func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error { +func (c *cluster) watchStream(cli EtcdClient, key string, rev int64, ctx context.Context) error { var ( rch clientv3.WatchChan ops []clientv3.OpOption @@ -356,7 +364,7 @@ func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error { ops = append(ops, clientv3.WithRev(rev+1)) } - rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), watchKey, ops...) + rch = cli.Watch(clientv3.WithRequireLeader(ctx), watchKey, ops...) for { select { @@ -374,6 +382,8 @@ func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error { c.handleWatchEvents(key, wresp.Events) case <-c.done: return nil + case <-ctx.Done(): + return nil } } } @@ -386,6 +396,23 @@ func (c *cluster) watchConnState(cli EtcdClient) { watcher.watch(cli.ActiveConnection()) } +func (c *cluster) addWatch(key string, cli EtcdClient) context.Context { + ctx, cancel := context.WithCancel(cli.Ctx()) + c.lock.Lock() + c.watchCtx[key] = cancel + c.watchFlag[key] = true + c.lock.Unlock() + return ctx +} + +func (c *cluster) clearWatch() { + for _, cancel := range c.watchCtx { + cancel() + } + c.watchCtx = make(map[string]context.CancelFunc) + c.watchFlag = make(map[string]bool) +} + // DialClient dials an etcd cluster with given endpoints. func DialClient(endpoints []string) (EtcdClient, error) { cfg := clientv3.Config{ diff --git a/core/discov/internal/registry_test.go b/core/discov/internal/registry_test.go index 3c78a2b2..e25b9757 100644 --- a/core/discov/internal/registry_test.go +++ b/core/discov/internal/registry_test.go @@ -160,8 +160,10 @@ func TestCluster_Watch(t *testing.T) { var wg sync.WaitGroup wg.Add(1) c := &cluster{ - listeners: make(map[string][]UpdateListener), values: make(map[string]map[string]string), + listeners: make(map[string][]UpdateListener), + watchCtx: make(map[string]context.CancelFunc), + watchFlag: make(map[string]bool), } listener := NewMockUpdateListener(ctrl) c.listeners["any"] = []UpdateListener{listener} @@ -211,7 +213,7 @@ func TestClusterWatch_RespFailures(t *testing.T) { ch := make(chan clientv3.WatchResponse) cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() - c := new(cluster) + c := newCluster([]string{}) c.done = make(chan lang.PlaceholderType) go func() { ch <- resp @@ -231,7 +233,7 @@ func TestClusterWatch_CloseChan(t *testing.T) { ch := make(chan clientv3.WatchResponse) cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() - c := new(cluster) + c := newCluster([]string{}) c.done = make(chan lang.PlaceholderType) go func() { close(ch) @@ -240,6 +242,37 @@ func TestClusterWatch_CloseChan(t *testing.T) { c.watch(cli, "any", 0) } +func TestClusterWatch_CtxCancel(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cli := NewMockEtcdClient(ctrl) + restore := setMockClient(cli) + defer restore() + ch := make(chan clientv3.WatchResponse) + cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() + ctx, cancelFunc := context.WithCancel(context.Background()) + cli.EXPECT().Ctx().Return(ctx).AnyTimes() + c := newCluster([]string{}) + c.done = make(chan lang.PlaceholderType) + go func() { + cancelFunc() + close(ch) + }() + c.watch(cli, "any", 0) +} + +func TestCluster_ClearWatch(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := &cluster{ + watchCtx: map[string]context.CancelFunc{"foo": cancel}, + watchFlag: map[string]bool{"foo": true}, + } + c.clearWatch() + assert.Equal(t, ctx.Err(), context.Canceled) + assert.Equal(t, 0, len(c.watchCtx)) + assert.Equal(t, 0, len(c.watchFlag)) +} + func TestValueOnlyContext(t *testing.T) { ctx := contextx.ValueOnlyFrom(context.Background()) ctx.Done() @@ -286,6 +319,8 @@ func TestRegistry_Monitor(t *testing.T) { "bar": "baz", }, }, + watchCtx: map[string]context.CancelFunc{}, + watchFlag: map[string]bool{}, }, } GetRegistry().lock.Unlock() @@ -293,15 +328,11 @@ func TestRegistry_Monitor(t *testing.T) { } func TestRegistry_Unmonitor(t *testing.T) { - svr, err := mockserver.StartMockServers(1) - assert.NoError(t, err) - svr.StartAt(0) - - endpoints := []string{svr.Servers[0].Address} + l := new(mockListener) GetRegistry().lock.Lock() GetRegistry().clusters = map[string]*cluster{ getClusterKey(endpoints): { - listeners: map[string][]UpdateListener{}, + listeners: map[string][]UpdateListener{"foo": {l}}, values: map[string]map[string]string{ "foo": { "bar": "baz", diff --git a/core/discov/subscriber.go b/core/discov/subscriber.go index cbf1c69d..3aeed274 100644 --- a/core/discov/subscriber.go +++ b/core/discov/subscriber.go @@ -20,6 +20,7 @@ type ( key string exclusive bool exactMatch bool + key string items *container } ) @@ -31,6 +32,7 @@ type ( func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) { sub := &Subscriber{ endpoints: endpoints, + key: key, } for _, opt := range opts { opt(sub) diff --git a/core/discov/subscriber_test.go b/core/discov/subscriber_test.go index 6dce7cec..1f760979 100644 --- a/core/discov/subscriber_test.go +++ b/core/discov/subscriber_test.go @@ -214,6 +214,18 @@ func TestSubscriber(t *testing.T) { assert.Equal(t, int32(1), atomic.LoadInt32(&count)) } +func TestSubscriberClos(t *testing.T) { + l := newContainer(false) + sub := &Subscriber{ + endpoints: []string{"localhost:2379"}, + key: "foo", + items: l, + } + _ = internal.GetRegistry().Monitor(sub.endpoints, sub.key, l, false) + sub.Close() + assert.Empty(t, sub.items.listeners) +} + func TestWithSubEtcdAccount(t *testing.T) { endpoints := []string{"localhost:2379"} user := stringx.Rand() diff --git a/zrpc/resolver/internal/resolver_test.go b/zrpc/resolver/internal/resolver_test.go index 38799726..99934678 100644 --- a/zrpc/resolver/internal/resolver_test.go +++ b/zrpc/resolver/internal/resolver_test.go @@ -1,6 +1,7 @@ package internal import ( + "github.com/zeromicro/go-zero/core/discov" "testing" "github.com/stretchr/testify/assert"