From 815a4f7eed31865efbcf317b51754de218ee33a8 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Thu, 18 Apr 2024 18:00:17 +0800 Subject: [PATCH] feat: support context in breaker methods (#4088) --- core/breaker/breaker.go | 60 +++++++ core/breaker/breaker_test.go | 242 ++++++++++++++++++++++++++++- core/breaker/nopbreaker.go | 23 +++ core/breaker/nopbreaker_test.go | 19 +++ core/stores/mon/collection_test.go | 21 +++ 5 files changed, 361 insertions(+), 4 deletions(-) diff --git a/core/breaker/breaker.go b/core/breaker/breaker.go index a53aca5c..cbcede46 100644 --- a/core/breaker/breaker.go +++ b/core/breaker/breaker.go @@ -1,6 +1,7 @@ package breaker import ( + "context" "errors" "fmt" "strings" @@ -36,12 +37,16 @@ type ( // The caller needs to call promise.Accept() on success, // or call promise.Reject() on failure. Allow() (Promise, error) + // AllowCtx checks if the request is allowed when ctx isn't done. + AllowCtx(ctx context.Context) (Promise, error) // Do runs the given request if the Breaker accepts it. // Do returns an error instantly if the Breaker rejects the request. // If a panic occurs in the request, the Breaker handles it as an error // and causes the same panic again. Do(req func() error) error + // DoCtx runs the given request if the Breaker accepts it when ctx isn't done. + DoCtx(ctx context.Context, req func() error) error // DoWithAcceptable runs the given request if the Breaker accepts it. // DoWithAcceptable returns an error instantly if the Breaker rejects the request. @@ -49,12 +54,16 @@ type ( // and causes the same panic again. // acceptable checks if it's a successful call, even if the error is not nil. DoWithAcceptable(req func() error, acceptable Acceptable) error + // DoWithAcceptableCtx runs the given request if the Breaker accepts it when ctx isn't done. + DoWithAcceptableCtx(ctx context.Context, req func() error, acceptable Acceptable) error // DoWithFallback runs the given request if the Breaker accepts it. // DoWithFallback runs the fallback if the Breaker rejects the request. // If a panic occurs in the request, the Breaker handles it as an error // and causes the same panic again. DoWithFallback(req func() error, fallback Fallback) error + // DoWithFallbackCtx runs the given request if the Breaker accepts it when ctx isn't done. + DoWithFallbackCtx(ctx context.Context, req func() error, fallback Fallback) error // DoWithFallbackAcceptable runs the given request if the Breaker accepts it. // DoWithFallbackAcceptable runs the fallback if the Breaker rejects the request. @@ -62,6 +71,9 @@ type ( // and causes the same panic again. // acceptable checks if it's a successful call, even if the error is not nil. DoWithFallbackAcceptable(req func() error, fallback Fallback, acceptable Acceptable) error + // DoWithFallbackAcceptableCtx runs the given request if the Breaker accepts it when ctx isn't done. + DoWithFallbackAcceptableCtx(ctx context.Context, req func() error, fallback Fallback, + acceptable Acceptable) error } // Fallback is the func to be called if the request is rejected. @@ -118,23 +130,71 @@ func (cb *circuitBreaker) Allow() (Promise, error) { return cb.throttle.allow() } +func (cb *circuitBreaker) AllowCtx(ctx context.Context) (Promise, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + return cb.Allow() + } +} + func (cb *circuitBreaker) Do(req func() error) error { return cb.throttle.doReq(req, nil, defaultAcceptable) } +func (cb *circuitBreaker) DoCtx(ctx context.Context, req func() error) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return cb.Do(req) + } +} + func (cb *circuitBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error { return cb.throttle.doReq(req, nil, acceptable) } +func (cb *circuitBreaker) DoWithAcceptableCtx(ctx context.Context, req func() error, + acceptable Acceptable) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return cb.DoWithAcceptable(req, acceptable) + } +} + func (cb *circuitBreaker) DoWithFallback(req func() error, fallback Fallback) error { return cb.throttle.doReq(req, fallback, defaultAcceptable) } +func (cb *circuitBreaker) DoWithFallbackCtx(ctx context.Context, req func() error, + fallback Fallback) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return cb.DoWithFallback(req, fallback) + } +} + func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback Fallback, acceptable Acceptable) error { return cb.throttle.doReq(req, fallback, acceptable) } +func (cb *circuitBreaker) DoWithFallbackAcceptableCtx(ctx context.Context, req func() error, + fallback Fallback, acceptable Acceptable) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return cb.DoWithFallbackAcceptable(req, fallback, acceptable) + } +} + func (cb *circuitBreaker) Name() string { return cb.name } diff --git a/core/breaker/breaker_test.go b/core/breaker/breaker_test.go index 70a3480c..cbbfb40e 100644 --- a/core/breaker/breaker_test.go +++ b/core/breaker/breaker_test.go @@ -1,11 +1,13 @@ package breaker import ( + "context" "errors" "fmt" "strconv" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/stat" @@ -16,10 +18,242 @@ func init() { } func TestCircuitBreaker_Allow(t *testing.T) { - b := NewBreaker() - assert.True(t, len(b.Name()) > 0) - _, err := b.Allow() - assert.Nil(t, err) + t.Run("allow", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + _, err := b.Allow() + assert.Nil(t, err) + }) + + t.Run("allow with ctx", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + _, err := b.AllowCtx(context.Background()) + assert.Nil(t, err) + }) + + t.Run("allow with ctx timeout", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond) + defer cancel() + time.Sleep(time.Millisecond) + _, err := b.AllowCtx(ctx) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) + + t.Run("allow with ctx cancel", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + cancel() + _, err := b.AllowCtx(ctx) + assert.ErrorIs(t, err, context.Canceled) + }) +} + +func TestCircuitBreaker_Do(t *testing.T) { + t.Run("do", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + err := b.Do(func() error { + return nil + }) + assert.Nil(t, err) + }) + + t.Run("do with ctx", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + err := b.DoCtx(context.Background(), func() error { + return nil + }) + assert.Nil(t, err) + }) + + t.Run("do with ctx timeout", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond) + defer cancel() + time.Sleep(time.Millisecond) + err := b.DoCtx(ctx, func() error { + return nil + }) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) + + t.Run("do with ctx cancel", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + cancel() + err := b.DoCtx(ctx, func() error { + return nil + }) + assert.ErrorIs(t, err, context.Canceled) + }) +} + +func TestCircuitBreaker_DoWithAcceptable(t *testing.T) { + t.Run("doWithAcceptable", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + err := b.DoWithAcceptable(func() error { + return nil + }, func(err error) bool { + return true + }) + assert.Nil(t, err) + }) + + t.Run("doWithAcceptable with ctx", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + err := b.DoWithAcceptableCtx(context.Background(), func() error { + return nil + }, func(err error) bool { + return true + }) + assert.Nil(t, err) + }) + + t.Run("doWithAcceptable with ctx timeout", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond) + defer cancel() + time.Sleep(time.Millisecond) + err := b.DoWithAcceptableCtx(ctx, func() error { + return nil + }, func(err error) bool { + return true + }) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) + + t.Run("doWithAcceptable with ctx cancel", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + cancel() + err := b.DoWithAcceptableCtx(ctx, func() error { + return nil + }, func(err error) bool { + return true + }) + assert.ErrorIs(t, err, context.Canceled) + }) +} + +func TestCircuitBreaker_DoWithFallback(t *testing.T) { + t.Run("doWithFallback", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + err := b.DoWithFallback(func() error { + return nil + }, func(err error) error { + return err + }) + assert.Nil(t, err) + }) + + t.Run("doWithFallback with ctx", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + err := b.DoWithFallbackCtx(context.Background(), func() error { + return nil + }, func(err error) error { + return err + }) + assert.Nil(t, err) + }) + + t.Run("doWithFallback with ctx timeout", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond) + defer cancel() + time.Sleep(time.Millisecond) + err := b.DoWithFallbackCtx(ctx, func() error { + return nil + }, func(err error) error { + return err + }) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) + + t.Run("doWithFallback with ctx cancel", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + cancel() + err := b.DoWithFallbackCtx(ctx, func() error { + return nil + }, func(err error) error { + return err + }) + assert.ErrorIs(t, err, context.Canceled) + }) +} + +func TestCircuitBreaker_DoWithFallbackAcceptable(t *testing.T) { + t.Run("doWithFallbackAcceptable", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + err := b.DoWithFallbackAcceptable(func() error { + return nil + }, func(err error) error { + return err + }, func(err error) bool { + return true + }) + assert.Nil(t, err) + }) + + t.Run("doWithFallbackAcceptable with ctx", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + err := b.DoWithFallbackAcceptableCtx(context.Background(), func() error { + return nil + }, func(err error) error { + return err + }, func(err error) bool { + return true + }) + assert.Nil(t, err) + }) + + t.Run("doWithFallbackAcceptable with ctx timeout", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond) + defer cancel() + time.Sleep(time.Millisecond) + err := b.DoWithFallbackAcceptableCtx(ctx, func() error { + return nil + }, func(err error) error { + return err + }, func(err error) bool { + return true + }) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) + + t.Run("doWithFallbackAcceptable with ctx cancel", func(t *testing.T) { + b := NewBreaker() + assert.True(t, len(b.Name()) > 0) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + cancel() + err := b.DoWithFallbackAcceptableCtx(ctx, func() error { + return nil + }, func(err error) error { + return err + }, func(err error) bool { + return true + }) + assert.ErrorIs(t, err, context.Canceled) + }) } func TestLogReason(t *testing.T) { diff --git a/core/breaker/nopbreaker.go b/core/breaker/nopbreaker.go index baa09801..99ef29f1 100644 --- a/core/breaker/nopbreaker.go +++ b/core/breaker/nopbreaker.go @@ -1,5 +1,7 @@ package breaker +import "context" + const nopBreakerName = "nopBreaker" type nopBreaker struct{} @@ -17,22 +19,43 @@ func (b nopBreaker) Allow() (Promise, error) { return nopPromise{}, nil } +func (b nopBreaker) AllowCtx(_ context.Context) (Promise, error) { + return nopPromise{}, nil +} + func (b nopBreaker) Do(req func() error) error { return req() } +func (b nopBreaker) DoCtx(_ context.Context, req func() error) error { + return req() +} + func (b nopBreaker) DoWithAcceptable(req func() error, _ Acceptable) error { return req() } +func (b nopBreaker) DoWithAcceptableCtx(_ context.Context, req func() error, _ Acceptable) error { + return req() +} + func (b nopBreaker) DoWithFallback(req func() error, _ Fallback) error { return req() } +func (b nopBreaker) DoWithFallbackCtx(_ context.Context, req func() error, _ Fallback) error { + return req() +} + func (b nopBreaker) DoWithFallbackAcceptable(req func() error, _ Fallback, _ Acceptable) error { return req() } +func (b nopBreaker) DoWithFallbackAcceptableCtx(_ context.Context, req func() error, + _ Fallback, _ Acceptable) error { + return req() +} + type nopPromise struct{} func (p nopPromise) Accept() { diff --git a/core/breaker/nopbreaker_test.go b/core/breaker/nopbreaker_test.go index ac26428d..8ced1c22 100644 --- a/core/breaker/nopbreaker_test.go +++ b/core/breaker/nopbreaker_test.go @@ -1,6 +1,7 @@ package breaker import ( + "context" "errors" "testing" @@ -12,6 +13,8 @@ func TestNopBreaker(t *testing.T) { assert.Equal(t, nopBreakerName, b.Name()) p, err := b.Allow() assert.Nil(t, err) + p, err = b.AllowCtx(context.Background()) + assert.Nil(t, err) p.Accept() for i := 0; i < 1000; i++ { p, err := b.Allow() @@ -21,18 +24,34 @@ func TestNopBreaker(t *testing.T) { assert.Nil(t, b.Do(func() error { return nil })) + assert.Nil(t, b.DoCtx(context.Background(), func() error { + return nil + })) assert.Nil(t, b.DoWithAcceptable(func() error { return nil }, defaultAcceptable)) + assert.Nil(t, b.DoWithAcceptableCtx(context.Background(), func() error { + return nil + }, defaultAcceptable)) errDummy := errors.New("any") assert.Equal(t, errDummy, b.DoWithFallback(func() error { return errDummy }, func(err error) error { return nil })) + assert.Equal(t, errDummy, b.DoWithFallbackCtx(context.Background(), func() error { + return errDummy + }, func(err error) error { + return nil + })) assert.Equal(t, errDummy, b.DoWithFallbackAcceptable(func() error { return errDummy }, func(err error) error { return nil }, defaultAcceptable)) + assert.Equal(t, errDummy, b.DoWithFallbackAcceptableCtx(context.Background(), func() error { + return errDummy + }, func(err error) error { + return nil + }, defaultAcceptable)) } diff --git a/core/stores/mon/collection_test.go b/core/stores/mon/collection_test.go index 2adc27f0..4a3f3b56 100644 --- a/core/stores/mon/collection_test.go +++ b/core/stores/mon/collection_test.go @@ -595,19 +595,40 @@ func (d *dropBreaker) Allow() (breaker.Promise, error) { return nil, errDummy } +func (d *dropBreaker) AllowCtx(_ context.Context) (breaker.Promise, error) { + return nil, errDummy +} + func (d *dropBreaker) Do(_ func() error) error { return nil } +func (d *dropBreaker) DoCtx(_ context.Context, _ func() error) error { + return nil +} + func (d *dropBreaker) DoWithAcceptable(_ func() error, _ breaker.Acceptable) error { return errDummy } +func (d *dropBreaker) DoWithAcceptableCtx(_ context.Context, _ func() error, _ breaker.Acceptable) error { + return errDummy +} + func (d *dropBreaker) DoWithFallback(_ func() error, _ breaker.Fallback) error { return nil } +func (d *dropBreaker) DoWithFallbackCtx(_ context.Context, _ func() error, _ breaker.Fallback) error { + return nil +} + func (d *dropBreaker) DoWithFallbackAcceptable(_ func() error, _ breaker.Fallback, _ breaker.Acceptable) error { return nil } + +func (d *dropBreaker) DoWithFallbackAcceptableCtx(_ context.Context, _ func() error, + _ breaker.Fallback, _ breaker.Acceptable) error { + return nil +}