diff --git a/core/errorx/check.go b/core/errorx/check.go new file mode 100644 index 00000000..b6452b9a --- /dev/null +++ b/core/errorx/check.go @@ -0,0 +1,14 @@ +package errorx + +import "errors" + +// In checks if the given err is one of errs. +func In(err error, errs ...error) bool { + for _, each := range errs { + if errors.Is(err, each) { + return true + } + } + + return false +} diff --git a/core/errorx/check_test.go b/core/errorx/check_test.go new file mode 100644 index 00000000..0e7b267f --- /dev/null +++ b/core/errorx/check_test.go @@ -0,0 +1,70 @@ +package errorx + +import ( + "errors" + "testing" +) + +func TestIn(t *testing.T) { + err1 := errors.New("error 1") + err2 := errors.New("error 2") + err3 := errors.New("error 3") + + tests := []struct { + name string + err error + errs []error + want bool + }{ + { + name: "Error matches one of the errors in the list", + err: err1, + errs: []error{err1, err2}, + want: true, + }, + { + name: "Error does not match any errors in the list", + err: err3, + errs: []error{err1, err2}, + want: false, + }, + { + name: "Empty error list", + err: err1, + errs: []error{}, + want: false, + }, + { + name: "Nil error with non-nil list", + err: nil, + errs: []error{err1, err2}, + want: false, + }, + { + name: "Non-nil error with nil in list", + err: err1, + errs: []error{nil, err2}, + want: false, + }, + { + name: "Error matches nil error in the list", + err: nil, + errs: []error{nil, err2}, + want: true, + }, + { + name: "Nil error with empty list", + err: nil, + errs: []error{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := In(tt.err, tt.errs...); got != tt.want { + t.Errorf("In() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/core/limit/tokenlimit.go b/core/limit/tokenlimit.go index 25c7ba0b..5b481dcd 100644 --- a/core/limit/tokenlimit.go +++ b/core/limit/tokenlimit.go @@ -10,6 +10,7 @@ import ( "sync/atomic" "time" + "github.com/zeromicro/go-zero/core/errorx" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stores/redis" xrate "golang.org/x/time/rate" @@ -103,7 +104,7 @@ func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) boo if errors.Is(err, redis.Nil) { return false } - if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + if errorx.In(err, context.DeadlineExceeded, context.Canceled) { logx.Errorf("fail to use rate limiter: %s", err) return false } diff --git a/core/stores/mon/collection.go b/core/stores/mon/collection.go index 5deadbef..e6f7e702 100644 --- a/core/stores/mon/collection.go +++ b/core/stores/mon/collection.go @@ -2,10 +2,10 @@ package mon import ( "context" - "errors" "time" "github.com/zeromicro/go-zero/core/breaker" + "github.com/zeromicro/go-zero/core/errorx" "github.com/zeromicro/go-zero/core/timex" "go.mongodb.org/mongo-driver/mongo" mopt "go.mongodb.org/mongo-driver/mongo/options" @@ -527,19 +527,10 @@ func (p keepablePromise) keep(err error) error { } func acceptable(err error) bool { - return err == nil || - errors.Is(err, mongo.ErrNoDocuments) || - errors.Is(err, mongo.ErrNilValue) || - errors.Is(err, mongo.ErrNilDocument) || - errors.Is(err, mongo.ErrNilCursor) || - errors.Is(err, mongo.ErrEmptySlice) || + return err == nil || errorx.In(err, mongo.ErrNoDocuments, mongo.ErrNilValue, + mongo.ErrNilDocument, mongo.ErrNilCursor, mongo.ErrEmptySlice, // session errors - errors.Is(err, session.ErrSessionEnded) || - errors.Is(err, session.ErrNoTransactStarted) || - errors.Is(err, session.ErrTransactInProgress) || - errors.Is(err, session.ErrAbortAfterCommit) || - errors.Is(err, session.ErrAbortTwice) || - errors.Is(err, session.ErrCommitAfterAbort) || - errors.Is(err, session.ErrUnackWCUnsupported) || - errors.Is(err, session.ErrSnapshotTransaction) + session.ErrSessionEnded, session.ErrNoTransactStarted, session.ErrTransactInProgress, + session.ErrAbortAfterCommit, session.ErrAbortTwice, session.ErrCommitAfterAbort, + session.ErrUnackWCUnsupported, session.ErrSnapshotTransaction) } diff --git a/core/stores/mon/trace.go b/core/stores/mon/trace.go index 99bf96b6..8b7818b2 100644 --- a/core/stores/mon/trace.go +++ b/core/stores/mon/trace.go @@ -2,8 +2,8 @@ package mon import ( "context" - "errors" + "github.com/zeromicro/go-zero/core/errorx" "github.com/zeromicro/go-zero/core/trace" "go.mongodb.org/mongo-driver/mongo" "go.opentelemetry.io/otel/attribute" @@ -24,8 +24,7 @@ func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span func endSpan(span oteltrace.Span, err error) { defer span.End() - if err == nil || errors.Is(err, mongo.ErrNoDocuments) || - errors.Is(err, mongo.ErrNilValue) || errors.Is(err, mongo.ErrNilDocument) { + if err == nil || errorx.In(err, mongo.ErrNoDocuments, mongo.ErrNilValue, mongo.ErrNilDocument) { span.SetStatus(codes.Ok, "") return } diff --git a/core/stores/redis/redis.go b/core/stores/redis/redis.go index adce614a..76265c5a 100644 --- a/core/stores/redis/redis.go +++ b/core/stores/redis/redis.go @@ -2372,7 +2372,7 @@ func withHook(hook red.Hook) Option { } func acceptable(err error) bool { - return err == nil || errors.Is(err, red.Nil) || errors.Is(err, context.Canceled) + return err == nil || errorx.In(err, red.Nil, context.Canceled) } func getRedis(r *Redis) (RedisNode, error) { diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index ff676b45..34a4f386 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/zeromicro/go-zero/core/breaker" + "github.com/zeromicro/go-zero/core/errorx" "github.com/zeromicro/go-zero/core/logx" ) @@ -267,8 +268,7 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex } func (db *commonSqlConn) acceptable(err error) bool { - if err == nil || errors.Is(err, sql.ErrNoRows) || errors.Is(err, sql.ErrTxDone) || - errors.Is(err, context.Canceled) { + if err == nil || errorx.In(err, sql.ErrNoRows, sql.ErrTxDone, context.Canceled) { return true }