diff --git a/core/breaker/breakers_test.go b/core/breaker/breakers_test.go index 74aa5b4c..ad1e62fb 100644 --- a/core/breaker/breakers_test.go +++ b/core/breaker/breakers_test.go @@ -30,7 +30,7 @@ func TestBreakersDoWithAcceptable(t *testing.T) { assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error { return errDummy }, func(err error) bool { - return err == nil || err == errDummy + return err == nil || errors.Is(err, errDummy) })) } verify(t, func() bool { @@ -45,12 +45,12 @@ func TestBreakersDoWithAcceptable(t *testing.T) { }, func(err error) bool { return err == nil }) - assert.True(t, err == errDummy || err == ErrServiceUnavailable) + assert.True(t, errors.Is(err, errDummy) || errors.Is(err, ErrServiceUnavailable)) } verify(t, func() bool { - return ErrServiceUnavailable == Do("another", func() error { + return errors.Is(Do("another", func() error { return nil - }) + }), ErrServiceUnavailable) }) } @@ -75,12 +75,12 @@ func TestBreakersFallback(t *testing.T) { }, func(err error) error { return nil }) - assert.True(t, err == nil || err == errDummy) + assert.True(t, err == nil || errors.Is(err, errDummy)) } verify(t, func() bool { - return ErrServiceUnavailable == Do("fallback", func() error { + return errors.Is(Do("fallback", func() error { return nil - }) + }), ErrServiceUnavailable) }) } @@ -94,12 +94,12 @@ func TestBreakersAcceptableFallback(t *testing.T) { }, func(err error) bool { return err == nil }) - assert.True(t, err == nil || err == errDummy) + assert.True(t, err == nil || errors.Is(err, errDummy)) } verify(t, func() bool { - return ErrServiceUnavailable == Do("acceptablefallback", func() error { + return errors.Is(Do("acceptablefallback", func() error { return nil - }) + }), ErrServiceUnavailable) }) } diff --git a/core/search/tree.go b/core/search/tree.go index 960f3f3f..c386660c 100644 --- a/core/search/tree.go +++ b/core/search/tree.go @@ -69,10 +69,10 @@ func (t *Tree) Add(route string, item any) error { } err := add(t.root, route[1:], item) - switch err { - case errDupItem: + switch { + case errors.Is(err, errDupItem): return duplicatedItem(route) - case errDupSlash: + case errors.Is(err, errDupSlash): return duplicatedSlash(route) default: return err diff --git a/core/stores/cache/cachenode.go b/core/stores/cache/cachenode.go index 3c0ccc42..312e0196 100644 --- a/core/stores/cache/cachenode.go +++ b/core/stores/cache/cachenode.go @@ -96,7 +96,7 @@ func (c cacheNode) Get(key string, val any) error { // GetCtx gets the cache with key and fills into v. func (c cacheNode) GetCtx(ctx context.Context, key string, val any) error { err := c.doGetCache(ctx, key, val) - if err == errPlaceholder { + if errors.Is(err, errPlaceholder) { return c.errNotFound } @@ -210,16 +210,16 @@ func (c cacheNode) doTake(ctx context.Context, v any, key string, logger := logx.WithContext(ctx) val, fresh, err := c.barrier.DoEx(key, func() (any, error) { if err := c.doGetCache(ctx, key, v); err != nil { - if err == errPlaceholder { + if errors.Is(err, errPlaceholder) { return nil, c.errNotFound - } else if err != c.errNotFound { + } else if !errors.Is(err, c.errNotFound) { // why we just return the error instead of query from db, // because we don't allow the disaster pass to the dbs. // fail fast, in case we bring down the dbs. return nil, err } - if err = query(v); err == c.errNotFound { + if err = query(v); errors.Is(err, c.errNotFound) { if err = c.setCacheWithNotFound(ctx, key); err != nil { logger.Error(err) } diff --git a/core/stores/mon/collection.go b/core/stores/mon/collection.go index 1c445939..e052dcc7 100644 --- a/core/stores/mon/collection.go +++ b/core/stores/mon/collection.go @@ -3,6 +3,7 @@ package mon import ( "context" "encoding/json" + "errors" "time" "github.com/zeromicro/go-zero/core/breaker" @@ -562,11 +563,19 @@ func (p keepablePromise) keep(err error) error { } func acceptable(err error) bool { - return err == nil || err == mongo.ErrNoDocuments || err == mongo.ErrNilValue || - err == mongo.ErrNilDocument || err == mongo.ErrNilCursor || err == mongo.ErrEmptySlice || + return err == nil || + errors.Is(err, mongo.ErrNoDocuments) || + errors.Is(err, mongo.ErrNilValue) || + errors.Is(err, mongo.ErrNilDocument) || + errors.Is(err, mongo.ErrNilCursor) || + errors.Is(err, mongo.ErrEmptySlice) || // session errors - err == session.ErrSessionEnded || err == session.ErrNoTransactStarted || - err == session.ErrTransactInProgress || err == session.ErrAbortAfterCommit || - err == session.ErrAbortTwice || err == session.ErrCommitAfterAbort || - err == session.ErrUnackWCUnsupported || err == session.ErrSnapshotTransaction + errors.Is(err, session.ErrSessionEnded) || + errors.Is(err, session.ErrNoTransactStarted) || + errors.Is(err, session.ErrTransactInProgress) || + errors.Is(err, session.ErrAbortAfterCommit) || + errors.Is(err, session.ErrAbortTwice) || + errors.Is(err, session.ErrCommitAfterAbort) || + errors.Is(err, session.ErrUnackWCUnsupported) || + errors.Is(err, session.ErrSnapshotTransaction) } diff --git a/core/stores/mon/trace.go b/core/stores/mon/trace.go index 1c9d6061..99bf96b6 100644 --- a/core/stores/mon/trace.go +++ b/core/stores/mon/trace.go @@ -2,6 +2,7 @@ package mon import ( "context" + "errors" "github.com/zeromicro/go-zero/core/trace" "go.mongodb.org/mongo-driver/mongo" @@ -23,8 +24,8 @@ func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span func endSpan(span oteltrace.Span, err error) { defer span.End() - if err == nil || err == mongo.ErrNoDocuments || - err == mongo.ErrNilValue || err == mongo.ErrNilDocument { + if err == nil || errors.Is(err, mongo.ErrNoDocuments) || + errors.Is(err, mongo.ErrNilValue) || errors.Is(err, mongo.ErrNilDocument) { span.SetStatus(codes.Ok, "") return } diff --git a/core/stores/redis/redis.go b/core/stores/redis/redis.go index 82a44ef3..c5549e3d 100644 --- a/core/stores/redis/redis.go +++ b/core/stores/redis/redis.go @@ -2849,7 +2849,7 @@ func withHook(hook red.Hook) Option { } func acceptable(err error) bool { - return err == nil || err == red.Nil || err == context.Canceled + return err == nil || err == red.Nil || errors.Is(err, context.Canceled) } func getRedis(r *Redis) (RedisNode, error) { diff --git a/core/stores/sqlx/mysql.go b/core/stores/sqlx/mysql.go index 6f0f731d..3c026921 100644 --- a/core/stores/sqlx/mysql.go +++ b/core/stores/sqlx/mysql.go @@ -1,6 +1,10 @@ package sqlx -import "github.com/go-sql-driver/mysql" +import ( + "errors" + + "github.com/go-sql-driver/mysql" +) const ( mysqlDriverName = "mysql" @@ -18,7 +22,8 @@ func mysqlAcceptable(err error) bool { return true } - myerr, ok := err.(*mysql.MySQLError) + var myerr *mysql.MySQLError + ok := errors.As(err, &myerr) if !ok { return false } diff --git a/core/stores/sqlx/mysql_test.go b/core/stores/sqlx/mysql_test.go index 32d730fe..68698e4b 100644 --- a/core/stores/sqlx/mysql_test.go +++ b/core/stores/sqlx/mysql_test.go @@ -28,7 +28,7 @@ func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) { var found bool for i := 0; i < 100; i++ { - if tryOnDuplicateEntryError(t, nil) == breaker.ErrServiceUnavailable { + if errors.Is(tryOnDuplicateEntryError(t, nil), breaker.ErrServiceUnavailable) { found = true } } diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index 62b5936c..9603af5b 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -3,6 +3,7 @@ package sqlx import ( "context" "database/sql" + "errors" "github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/logx" @@ -157,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) ( result, err = exec(ctx, conn, q, args...) return err }, db.acceptable) - if err == breaker.ErrServiceUnavailable { + if errors.Is(err, breaker.ErrServiceUnavailable) { metricReqErr.Inc("Exec", "breaker") } @@ -193,7 +194,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm } return nil }, db.acceptable) - if err == breaker.ErrServiceUnavailable { + if errors.Is(err, breaker.ErrServiceUnavailable) { metricReqErr.Inc("Prepare", "breaker") } @@ -283,7 +284,7 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex err = db.brk.DoWithAcceptable(func() error { return transact(ctx, db, db.beginTx, fn) }, db.acceptable) - if err == breaker.ErrServiceUnavailable { + if errors.Is(err, breaker.ErrServiceUnavailable) { metricReqErr.Inc("Transact", "breaker") } @@ -291,11 +292,13 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex } func (db *commonSqlConn) acceptable(err error) bool { - if err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled { + if err == nil || errors.Is(err, sql.ErrNoRows) || errors.Is(err, sql.ErrTxDone) || + errors.Is(err, context.Canceled) { return true } - if _, ok := err.(acceptableError); ok { + var e acceptableError + if errors.As(err, &e) { return true } @@ -321,9 +324,9 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) return qerr }, q, args...) }, func(err error) bool { - return qerr == err || db.acceptable(err) + return errors.Is(err, qerr) || db.acceptable(err) }) - if err == breaker.ErrServiceUnavailable { + if errors.Is(err, breaker.ErrServiceUnavailable) { metricReqErr.Inc("queryRows", "breaker") } diff --git a/core/stores/sqlx/utils.go b/core/stores/sqlx/utils.go index c5944517..b3dd1337 100644 --- a/core/stores/sqlx/utils.go +++ b/core/stores/sqlx/utils.go @@ -143,7 +143,7 @@ func logInstanceError(ctx context.Context, datasource string, err error) { } func logSqlError(ctx context.Context, stmt string, err error) { - if err != nil && err != ErrNotFound { + if err != nil && !errors.Is(err, ErrNotFound) { logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error()) } } diff --git a/core/syncx/lockedcalls_test.go b/core/syncx/lockedcalls_test.go index 3ae0debc..93f6fe04 100644 --- a/core/syncx/lockedcalls_test.go +++ b/core/syncx/lockedcalls_test.go @@ -27,7 +27,7 @@ func TestLockedCallDoErr(t *testing.T) { v, err := g.Do("key", func() (any, error) { return nil, someErr }) - if err != someErr { + if !errors.Is(err, someErr) { t.Errorf("Do error = %v; want someErr", err) } if v != nil { diff --git a/core/syncx/singleflight_test.go b/core/syncx/singleflight_test.go index ba68f9d1..591c2736 100644 --- a/core/syncx/singleflight_test.go +++ b/core/syncx/singleflight_test.go @@ -28,7 +28,7 @@ func TestExclusiveCallDoErr(t *testing.T) { v, err := g.Do("key", func() (any, error) { return nil, someErr }) - if err != someErr { + if !errors.Is(err, someErr) { t.Errorf("Do error = %v; want someErr", err) } if v != nil { diff --git a/rest/httpx/responses.go b/rest/httpx/responses.go index 0461c788..edd133a4 100644 --- a/rest/httpx/responses.go +++ b/rest/httpx/responses.go @@ -3,6 +3,7 @@ package httpx import ( "context" "encoding/json" + "errors" "fmt" "net/http" "sync" @@ -141,10 +142,10 @@ func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, a return } - e, ok := body.(error) - if ok { - http.Error(w, e.Error(), code) - } else { + switch v := body.(type) { + case error: + http.Error(w, v.Error(), code) + default: writeJson(w, code, body) } } @@ -162,7 +163,7 @@ func doWriteJson(w http.ResponseWriter, code int, v any) error { if n, err := w.Write(bs); err != nil { // http.ErrHandlerTimeout has been handled by http.TimeoutHandler, // so it's ignored here. - if err != http.ErrHandlerTimeout { + if !errors.Is(err, http.ErrHandlerTimeout) { return fmt.Errorf("write response failed, error: %w", err) } } else if n < len(bs) { diff --git a/rest/internal/starter.go b/rest/internal/starter.go index 6ac180d4..08cca832 100644 --- a/rest/internal/starter.go +++ b/rest/internal/starter.go @@ -2,6 +2,7 @@ package internal import ( "context" + "errors" "fmt" "net/http" @@ -49,7 +50,7 @@ func start(host string, port int, handler http.Handler, run func(svr *http.Serve } }) defer func() { - if err == http.ErrServerClosed { + if errors.Is(err, http.ErrServerClosed) { waitForCalled() } }() diff --git a/rest/server.go b/rest/server.go index 9583ea52..bbf6ff39 100644 --- a/rest/server.go +++ b/rest/server.go @@ -2,6 +2,7 @@ package rest import ( "crypto/tls" + "errors" "net/http" "path" "time" @@ -307,7 +308,7 @@ func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { func handleError(err error) { // ErrServerClosed means the server is closed manually - if err == nil || err == http.ErrServerClosed { + if err == nil || errors.Is(err, http.ErrServerClosed) { return } diff --git a/tools/goctl/model/cmd.go b/tools/goctl/model/cmd.go index 0528b3b5..32f5a843 100644 --- a/tools/goctl/model/cmd.go +++ b/tools/goctl/model/cmd.go @@ -56,7 +56,8 @@ func init() { pgDatasourceCmdFlags.StringVar(&command.VarStringHome, "home") pgDatasourceCmdFlags.StringVar(&command.VarStringRemote, "remote") pgDatasourceCmdFlags.StringVar(&command.VarStringBranch, "branch") - pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns, "ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"}) + pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns, + "ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"}) mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t") mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c") @@ -68,7 +69,8 @@ func init() { mongoCmdFlags.StringVar(&mongo.VarStringBranch, "branch") mysqlCmd.PersistentFlags().BoolVar(&command.VarBoolStrict, "strict") - mysqlCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns, "ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"}) + mysqlCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns, + "ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"}) mysqlCmd.AddCommand(datasourceCmd, ddlCmd) pgCmd.AddCommand(pgDatasourceCmd) diff --git a/zrpc/internal/codes/accept.go b/zrpc/internal/codes/accept.go index 8ecf292a..f11e29c2 100644 --- a/zrpc/internal/codes/accept.go +++ b/zrpc/internal/codes/accept.go @@ -8,7 +8,8 @@ import ( // Acceptable checks if given error is acceptable. func Acceptable(err error) bool { switch status.Code(err) { - case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss, codes.Unimplemented: + case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss, + codes.Unimplemented, codes.ResourceExhausted: return false default: return true diff --git a/zrpc/internal/serverinterceptors/breakerinterceptor.go b/zrpc/internal/serverinterceptors/breakerinterceptor.go index 0298658f..79d8c68c 100644 --- a/zrpc/internal/serverinterceptors/breakerinterceptor.go +++ b/zrpc/internal/serverinterceptors/breakerinterceptor.go @@ -2,10 +2,13 @@ package serverinterceptors import ( "context" + "errors" "github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/zrpc/internal/codes" "google.golang.org/grpc" + gcodes "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // StreamBreakerInterceptor is an interceptor that acts as a circuit breaker. @@ -26,6 +29,9 @@ func UnaryBreakerInterceptor(ctx context.Context, req any, info *grpc.UnaryServe resp, err = handler(ctx, req) return err }, codes.Acceptable) + if errors.Is(err, breaker.ErrServiceUnavailable) { + err = status.Error(gcodes.Unavailable, err.Error()) + } return resp, err } diff --git a/zrpc/internal/serverinterceptors/sheddinginterceptor.go b/zrpc/internal/serverinterceptors/sheddinginterceptor.go index 2b8ac1f1..4885795d 100644 --- a/zrpc/internal/serverinterceptors/sheddinginterceptor.go +++ b/zrpc/internal/serverinterceptors/sheddinginterceptor.go @@ -2,11 +2,14 @@ package serverinterceptors import ( "context" + "errors" "sync" "github.com/zeromicro/go-zero/core/load" "github.com/zeromicro/go-zero/core/stat" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) const serviceType = "rpc" @@ -28,11 +31,12 @@ func UnarySheddingInterceptor(shedder load.Shedder, metrics *stat.Metrics) grpc. if err != nil { metrics.AddDrop() sheddingStat.IncrementDrop() + err = status.Error(codes.ResourceExhausted, err.Error()) return } defer func() { - if err == context.DeadlineExceeded { + if errors.Is(err, context.DeadlineExceeded) { promise.Fail() } else { sheddingStat.IncrementPass() diff --git a/zrpc/internal/serverinterceptors/sheddinginterceptor_test.go b/zrpc/internal/serverinterceptors/sheddinginterceptor_test.go index 7b8afde6..bf8fc0a1 100644 --- a/zrpc/internal/serverinterceptors/sheddinginterceptor_test.go +++ b/zrpc/internal/serverinterceptors/sheddinginterceptor_test.go @@ -8,6 +8,8 @@ import ( "github.com/zeromicro/go-zero/core/load" "github.com/zeromicro/go-zero/core/stat" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func TestUnarySheddingInterceptor(t *testing.T) { @@ -33,7 +35,7 @@ func TestUnarySheddingInterceptor(t *testing.T) { name: "reject", allow: false, handleErr: nil, - expect: load.ErrServiceOverloaded, + expect: status.Error(codes.ResourceExhausted, load.ErrServiceOverloaded.Error()), }, } diff --git a/zrpc/internal/serverinterceptors/timeoutinterceptor.go b/zrpc/internal/serverinterceptors/timeoutinterceptor.go index 27cded03..fb652909 100644 --- a/zrpc/internal/serverinterceptors/timeoutinterceptor.go +++ b/zrpc/internal/serverinterceptors/timeoutinterceptor.go @@ -2,6 +2,7 @@ package serverinterceptors import ( "context" + "errors" "fmt" "runtime/debug" "strings" @@ -49,9 +50,9 @@ func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor return resp, err case <-ctx.Done(): err := ctx.Err() - if err == context.Canceled { + if errors.Is(err, context.Canceled) { err = status.Error(codes.Canceled, err.Error()) - } else if err == context.DeadlineExceeded { + } else if errors.Is(err, context.DeadlineExceeded) { err = status.Error(codes.DeadlineExceeded, err.Error()) } return nil, err