diff --git a/core/breaker/breakers.go b/core/breaker/breakers.go index 1df42d3e..af00e8d0 100644 --- a/core/breaker/breakers.go +++ b/core/breaker/breakers.go @@ -1,6 +1,9 @@ package breaker -import "sync" +import ( + "context" + "sync" +) var ( lock sync.RWMutex @@ -14,6 +17,13 @@ func Do(name string, req func() error) error { }) } +// DoCtx calls Breaker.DoCtx on the Breaker with given name. +func DoCtx(ctx context.Context, name string, req func() error) error { + return do(name, func(b Breaker) error { + return b.DoCtx(ctx, req) + }) +} + // DoWithAcceptable calls Breaker.DoWithAcceptable on the Breaker with given name. func DoWithAcceptable(name string, req func() error, acceptable Acceptable) error { return do(name, func(b Breaker) error { @@ -21,6 +31,14 @@ func DoWithAcceptable(name string, req func() error, acceptable Acceptable) erro }) } +// DoWithAcceptableCtx calls Breaker.DoWithAcceptableCtx on the Breaker with given name. +func DoWithAcceptableCtx(ctx context.Context, name string, req func() error, + acceptable Acceptable) error { + return do(name, func(b Breaker) error { + return b.DoWithAcceptableCtx(ctx, req, acceptable) + }) +} + // DoWithFallback calls Breaker.DoWithFallback on the Breaker with given name. func DoWithFallback(name string, req func() error, fallback Fallback) error { return do(name, func(b Breaker) error { @@ -28,6 +46,13 @@ func DoWithFallback(name string, req func() error, fallback Fallback) error { }) } +// DoWithFallbackCtx calls Breaker.DoWithFallbackCtx on the Breaker with given name. +func DoWithFallbackCtx(ctx context.Context, name string, req func() error, fallback Fallback) error { + return do(name, func(b Breaker) error { + return b.DoWithFallbackCtx(ctx, req, fallback) + }) +} + // DoWithFallbackAcceptable calls Breaker.DoWithFallbackAcceptable on the Breaker with given name. func DoWithFallbackAcceptable(name string, req func() error, fallback Fallback, acceptable Acceptable) error { @@ -36,6 +61,14 @@ func DoWithFallbackAcceptable(name string, req func() error, fallback Fallback, }) } +// DoWithFallbackAcceptableCtx calls Breaker.DoWithFallbackAcceptableCtx on the Breaker with given name. +func DoWithFallbackAcceptableCtx(ctx context.Context, name string, req func() error, + fallback Fallback, acceptable Acceptable) error { + return do(name, func(b Breaker) error { + return b.DoWithFallbackAcceptableCtx(ctx, req, fallback, acceptable) + }) +} + // GetBreaker returns the Breaker with the given name. func GetBreaker(name string) Breaker { lock.RLock() diff --git a/core/breaker/breakers_test.go b/core/breaker/breakers_test.go index ad1e62fb..af7cb8cc 100644 --- a/core/breaker/breakers_test.go +++ b/core/breaker/breakers_test.go @@ -1,6 +1,7 @@ package breaker import ( + "context" "errors" "fmt" "testing" @@ -22,6 +23,9 @@ func TestBreakersDo(t *testing.T) { assert.Equal(t, errDummy, Do("any", func() error { return errDummy })) + assert.Equal(t, errDummy, DoCtx(context.Background(), "any", func() error { + return errDummy + })) } func TestBreakersDoWithAcceptable(t *testing.T) { @@ -38,6 +42,13 @@ func TestBreakersDoWithAcceptable(t *testing.T) { return nil }) == nil }) + verify(t, func() bool { + return DoWithAcceptableCtx(context.Background(), "anyone", func() error { + return nil + }, func(err error) bool { + return true + }) == nil + }) for i := 0; i < 10000; i++ { err := DoWithAcceptable("another", func() error { @@ -76,6 +87,12 @@ func TestBreakersFallback(t *testing.T) { return nil }) assert.True(t, err == nil || errors.Is(err, errDummy)) + err = DoWithFallbackCtx(context.Background(), "fallback", func() error { + return errDummy + }, func(err error) error { + return nil + }) + assert.True(t, err == nil || errors.Is(err, errDummy)) } verify(t, func() bool { return errors.Is(Do("fallback", func() error { @@ -86,7 +103,7 @@ func TestBreakersFallback(t *testing.T) { func TestBreakersAcceptableFallback(t *testing.T) { errDummy := errors.New("any") - for i := 0; i < 10000; i++ { + for i := 0; i < 5000; i++ { err := DoWithFallbackAcceptable("acceptablefallback", func() error { return errDummy }, func(err error) error { @@ -95,6 +112,14 @@ func TestBreakersAcceptableFallback(t *testing.T) { return err == nil }) assert.True(t, err == nil || errors.Is(err, errDummy)) + err = DoWithFallbackAcceptableCtx(context.Background(), "acceptablefallback", func() error { + return errDummy + }, func(err error) error { + return nil + }, func(err error) bool { + return err == nil + }) + assert.True(t, err == nil || errors.Is(err, errDummy)) } verify(t, func() bool { return errors.Is(Do("acceptablefallback", func() error {