mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-01-23 09:00:20 +08:00
chore: refactor errors to use errors.Is (#3654)
This commit is contained in:
parent
81ae7d36b5
commit
42e0a6f90c
@ -30,7 +30,7 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
|
|||||||
assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error {
|
assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error {
|
||||||
return errDummy
|
return errDummy
|
||||||
}, func(err error) bool {
|
}, func(err error) bool {
|
||||||
return err == nil || err == errDummy
|
return err == nil || errors.Is(err, errDummy)
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
verify(t, func() bool {
|
verify(t, func() bool {
|
||||||
@ -45,12 +45,12 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
|
|||||||
}, func(err error) bool {
|
}, func(err error) bool {
|
||||||
return err == nil
|
return err == nil
|
||||||
})
|
})
|
||||||
assert.True(t, err == errDummy || err == ErrServiceUnavailable)
|
assert.True(t, errors.Is(err, errDummy) || errors.Is(err, ErrServiceUnavailable))
|
||||||
}
|
}
|
||||||
verify(t, func() bool {
|
verify(t, func() bool {
|
||||||
return ErrServiceUnavailable == Do("another", func() error {
|
return errors.Is(Do("another", func() error {
|
||||||
return nil
|
return nil
|
||||||
})
|
}), ErrServiceUnavailable)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,12 +75,12 @@ func TestBreakersFallback(t *testing.T) {
|
|||||||
}, func(err error) error {
|
}, func(err error) error {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
assert.True(t, err == nil || err == errDummy)
|
assert.True(t, err == nil || errors.Is(err, errDummy))
|
||||||
}
|
}
|
||||||
verify(t, func() bool {
|
verify(t, func() bool {
|
||||||
return ErrServiceUnavailable == Do("fallback", func() error {
|
return errors.Is(Do("fallback", func() error {
|
||||||
return nil
|
return nil
|
||||||
})
|
}), ErrServiceUnavailable)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,12 +94,12 @@ func TestBreakersAcceptableFallback(t *testing.T) {
|
|||||||
}, func(err error) bool {
|
}, func(err error) bool {
|
||||||
return err == nil
|
return err == nil
|
||||||
})
|
})
|
||||||
assert.True(t, err == nil || err == errDummy)
|
assert.True(t, err == nil || errors.Is(err, errDummy))
|
||||||
}
|
}
|
||||||
verify(t, func() bool {
|
verify(t, func() bool {
|
||||||
return ErrServiceUnavailable == Do("acceptablefallback", func() error {
|
return errors.Is(Do("acceptablefallback", func() error {
|
||||||
return nil
|
return nil
|
||||||
})
|
}), ErrServiceUnavailable)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,10 +69,10 @@ func (t *Tree) Add(route string, item any) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err := add(t.root, route[1:], item)
|
err := add(t.root, route[1:], item)
|
||||||
switch err {
|
switch {
|
||||||
case errDupItem:
|
case errors.Is(err, errDupItem):
|
||||||
return duplicatedItem(route)
|
return duplicatedItem(route)
|
||||||
case errDupSlash:
|
case errors.Is(err, errDupSlash):
|
||||||
return duplicatedSlash(route)
|
return duplicatedSlash(route)
|
||||||
default:
|
default:
|
||||||
return err
|
return err
|
||||||
|
8
core/stores/cache/cachenode.go
vendored
8
core/stores/cache/cachenode.go
vendored
@ -96,7 +96,7 @@ func (c cacheNode) Get(key string, val any) error {
|
|||||||
// GetCtx gets the cache with key and fills into v.
|
// GetCtx gets the cache with key and fills into v.
|
||||||
func (c cacheNode) GetCtx(ctx context.Context, key string, val any) error {
|
func (c cacheNode) GetCtx(ctx context.Context, key string, val any) error {
|
||||||
err := c.doGetCache(ctx, key, val)
|
err := c.doGetCache(ctx, key, val)
|
||||||
if err == errPlaceholder {
|
if errors.Is(err, errPlaceholder) {
|
||||||
return c.errNotFound
|
return c.errNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -210,16 +210,16 @@ func (c cacheNode) doTake(ctx context.Context, v any, key string,
|
|||||||
logger := logx.WithContext(ctx)
|
logger := logx.WithContext(ctx)
|
||||||
val, fresh, err := c.barrier.DoEx(key, func() (any, error) {
|
val, fresh, err := c.barrier.DoEx(key, func() (any, error) {
|
||||||
if err := c.doGetCache(ctx, key, v); err != nil {
|
if err := c.doGetCache(ctx, key, v); err != nil {
|
||||||
if err == errPlaceholder {
|
if errors.Is(err, errPlaceholder) {
|
||||||
return nil, c.errNotFound
|
return nil, c.errNotFound
|
||||||
} else if err != c.errNotFound {
|
} else if !errors.Is(err, c.errNotFound) {
|
||||||
// why we just return the error instead of query from db,
|
// why we just return the error instead of query from db,
|
||||||
// because we don't allow the disaster pass to the dbs.
|
// because we don't allow the disaster pass to the dbs.
|
||||||
// fail fast, in case we bring down the dbs.
|
// fail fast, in case we bring down the dbs.
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = query(v); err == c.errNotFound {
|
if err = query(v); errors.Is(err, c.errNotFound) {
|
||||||
if err = c.setCacheWithNotFound(ctx, key); err != nil {
|
if err = c.setCacheWithNotFound(ctx, key); err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package mon
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/breaker"
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
@ -562,11 +563,19 @@ func (p keepablePromise) keep(err error) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func acceptable(err error) bool {
|
func acceptable(err error) bool {
|
||||||
return err == nil || err == mongo.ErrNoDocuments || err == mongo.ErrNilValue ||
|
return err == nil ||
|
||||||
err == mongo.ErrNilDocument || err == mongo.ErrNilCursor || err == mongo.ErrEmptySlice ||
|
errors.Is(err, mongo.ErrNoDocuments) ||
|
||||||
|
errors.Is(err, mongo.ErrNilValue) ||
|
||||||
|
errors.Is(err, mongo.ErrNilDocument) ||
|
||||||
|
errors.Is(err, mongo.ErrNilCursor) ||
|
||||||
|
errors.Is(err, mongo.ErrEmptySlice) ||
|
||||||
// session errors
|
// session errors
|
||||||
err == session.ErrSessionEnded || err == session.ErrNoTransactStarted ||
|
errors.Is(err, session.ErrSessionEnded) ||
|
||||||
err == session.ErrTransactInProgress || err == session.ErrAbortAfterCommit ||
|
errors.Is(err, session.ErrNoTransactStarted) ||
|
||||||
err == session.ErrAbortTwice || err == session.ErrCommitAfterAbort ||
|
errors.Is(err, session.ErrTransactInProgress) ||
|
||||||
err == session.ErrUnackWCUnsupported || err == session.ErrSnapshotTransaction
|
errors.Is(err, session.ErrAbortAfterCommit) ||
|
||||||
|
errors.Is(err, session.ErrAbortTwice) ||
|
||||||
|
errors.Is(err, session.ErrCommitAfterAbort) ||
|
||||||
|
errors.Is(err, session.ErrUnackWCUnsupported) ||
|
||||||
|
errors.Is(err, session.ErrSnapshotTransaction)
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package mon
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/trace"
|
"github.com/zeromicro/go-zero/core/trace"
|
||||||
"go.mongodb.org/mongo-driver/mongo"
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
@ -23,8 +24,8 @@ func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span
|
|||||||
func endSpan(span oteltrace.Span, err error) {
|
func endSpan(span oteltrace.Span, err error) {
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
if err == nil || err == mongo.ErrNoDocuments ||
|
if err == nil || errors.Is(err, mongo.ErrNoDocuments) ||
|
||||||
err == mongo.ErrNilValue || err == mongo.ErrNilDocument {
|
errors.Is(err, mongo.ErrNilValue) || errors.Is(err, mongo.ErrNilDocument) {
|
||||||
span.SetStatus(codes.Ok, "")
|
span.SetStatus(codes.Ok, "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -2849,7 +2849,7 @@ func withHook(hook red.Hook) Option {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func acceptable(err error) bool {
|
func acceptable(err error) bool {
|
||||||
return err == nil || err == red.Nil || err == context.Canceled
|
return err == nil || err == red.Nil || errors.Is(err, context.Canceled)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRedis(r *Redis) (RedisNode, error) {
|
func getRedis(r *Redis) (RedisNode, error) {
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
package sqlx
|
package sqlx
|
||||||
|
|
||||||
import "github.com/go-sql-driver/mysql"
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/go-sql-driver/mysql"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
mysqlDriverName = "mysql"
|
mysqlDriverName = "mysql"
|
||||||
@ -18,7 +22,8 @@ func mysqlAcceptable(err error) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
myerr, ok := err.(*mysql.MySQLError)
|
var myerr *mysql.MySQLError
|
||||||
|
ok := errors.As(err, &myerr)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@ func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) {
|
|||||||
|
|
||||||
var found bool
|
var found bool
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
if tryOnDuplicateEntryError(t, nil) == breaker.ErrServiceUnavailable {
|
if errors.Is(tryOnDuplicateEntryError(t, nil), breaker.ErrServiceUnavailable) {
|
||||||
found = true
|
found = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package sqlx
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/breaker"
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
@ -157,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
|
|||||||
result, err = exec(ctx, conn, q, args...)
|
result, err = exec(ctx, conn, q, args...)
|
||||||
return err
|
return err
|
||||||
}, db.acceptable)
|
}, db.acceptable)
|
||||||
if err == breaker.ErrServiceUnavailable {
|
if errors.Is(err, breaker.ErrServiceUnavailable) {
|
||||||
metricReqErr.Inc("Exec", "breaker")
|
metricReqErr.Inc("Exec", "breaker")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,7 +194,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}, db.acceptable)
|
}, db.acceptable)
|
||||||
if err == breaker.ErrServiceUnavailable {
|
if errors.Is(err, breaker.ErrServiceUnavailable) {
|
||||||
metricReqErr.Inc("Prepare", "breaker")
|
metricReqErr.Inc("Prepare", "breaker")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -283,7 +284,7 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
|
|||||||
err = db.brk.DoWithAcceptable(func() error {
|
err = db.brk.DoWithAcceptable(func() error {
|
||||||
return transact(ctx, db, db.beginTx, fn)
|
return transact(ctx, db, db.beginTx, fn)
|
||||||
}, db.acceptable)
|
}, db.acceptable)
|
||||||
if err == breaker.ErrServiceUnavailable {
|
if errors.Is(err, breaker.ErrServiceUnavailable) {
|
||||||
metricReqErr.Inc("Transact", "breaker")
|
metricReqErr.Inc("Transact", "breaker")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -291,11 +292,13 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *commonSqlConn) acceptable(err error) bool {
|
func (db *commonSqlConn) acceptable(err error) bool {
|
||||||
if err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled {
|
if err == nil || errors.Is(err, sql.ErrNoRows) || errors.Is(err, sql.ErrTxDone) ||
|
||||||
|
errors.Is(err, context.Canceled) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := err.(acceptableError); ok {
|
var e acceptableError
|
||||||
|
if errors.As(err, &e) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -321,9 +324,9 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
|||||||
return qerr
|
return qerr
|
||||||
}, q, args...)
|
}, q, args...)
|
||||||
}, func(err error) bool {
|
}, func(err error) bool {
|
||||||
return qerr == err || db.acceptable(err)
|
return errors.Is(err, qerr) || db.acceptable(err)
|
||||||
})
|
})
|
||||||
if err == breaker.ErrServiceUnavailable {
|
if errors.Is(err, breaker.ErrServiceUnavailable) {
|
||||||
metricReqErr.Inc("queryRows", "breaker")
|
metricReqErr.Inc("queryRows", "breaker")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,7 +143,7 @@ func logInstanceError(ctx context.Context, datasource string, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func logSqlError(ctx context.Context, stmt string, err error) {
|
func logSqlError(ctx context.Context, stmt string, err error) {
|
||||||
if err != nil && err != ErrNotFound {
|
if err != nil && !errors.Is(err, ErrNotFound) {
|
||||||
logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
|
logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ func TestLockedCallDoErr(t *testing.T) {
|
|||||||
v, err := g.Do("key", func() (any, error) {
|
v, err := g.Do("key", func() (any, error) {
|
||||||
return nil, someErr
|
return nil, someErr
|
||||||
})
|
})
|
||||||
if err != someErr {
|
if !errors.Is(err, someErr) {
|
||||||
t.Errorf("Do error = %v; want someErr", err)
|
t.Errorf("Do error = %v; want someErr", err)
|
||||||
}
|
}
|
||||||
if v != nil {
|
if v != nil {
|
||||||
|
@ -28,7 +28,7 @@ func TestExclusiveCallDoErr(t *testing.T) {
|
|||||||
v, err := g.Do("key", func() (any, error) {
|
v, err := g.Do("key", func() (any, error) {
|
||||||
return nil, someErr
|
return nil, someErr
|
||||||
})
|
})
|
||||||
if err != someErr {
|
if !errors.Is(err, someErr) {
|
||||||
t.Errorf("Do error = %v; want someErr", err)
|
t.Errorf("Do error = %v; want someErr", err)
|
||||||
}
|
}
|
||||||
if v != nil {
|
if v != nil {
|
||||||
|
@ -3,6 +3,7 @@ package httpx
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
@ -141,10 +142,10 @@ func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, a
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
e, ok := body.(error)
|
switch v := body.(type) {
|
||||||
if ok {
|
case error:
|
||||||
http.Error(w, e.Error(), code)
|
http.Error(w, v.Error(), code)
|
||||||
} else {
|
default:
|
||||||
writeJson(w, code, body)
|
writeJson(w, code, body)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -162,7 +163,7 @@ func doWriteJson(w http.ResponseWriter, code int, v any) error {
|
|||||||
if n, err := w.Write(bs); err != nil {
|
if n, err := w.Write(bs); err != nil {
|
||||||
// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
|
// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
|
||||||
// so it's ignored here.
|
// so it's ignored here.
|
||||||
if err != http.ErrHandlerTimeout {
|
if !errors.Is(err, http.ErrHandlerTimeout) {
|
||||||
return fmt.Errorf("write response failed, error: %w", err)
|
return fmt.Errorf("write response failed, error: %w", err)
|
||||||
}
|
}
|
||||||
} else if n < len(bs) {
|
} else if n < len(bs) {
|
||||||
|
@ -2,6 +2,7 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
@ -49,7 +50,7 @@ func start(host string, port int, handler http.Handler, run func(svr *http.Serve
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
defer func() {
|
defer func() {
|
||||||
if err == http.ErrServerClosed {
|
if errors.Is(err, http.ErrServerClosed) {
|
||||||
waitForCalled()
|
waitForCalled()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -2,6 +2,7 @@ package rest
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
"time"
|
"time"
|
||||||
@ -307,7 +308,7 @@ func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
|
|||||||
|
|
||||||
func handleError(err error) {
|
func handleError(err error) {
|
||||||
// ErrServerClosed means the server is closed manually
|
// ErrServerClosed means the server is closed manually
|
||||||
if err == nil || err == http.ErrServerClosed {
|
if err == nil || errors.Is(err, http.ErrServerClosed) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,7 +56,8 @@ func init() {
|
|||||||
pgDatasourceCmdFlags.StringVar(&command.VarStringHome, "home")
|
pgDatasourceCmdFlags.StringVar(&command.VarStringHome, "home")
|
||||||
pgDatasourceCmdFlags.StringVar(&command.VarStringRemote, "remote")
|
pgDatasourceCmdFlags.StringVar(&command.VarStringRemote, "remote")
|
||||||
pgDatasourceCmdFlags.StringVar(&command.VarStringBranch, "branch")
|
pgDatasourceCmdFlags.StringVar(&command.VarStringBranch, "branch")
|
||||||
pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns, "ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
|
pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns,
|
||||||
|
"ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
|
||||||
|
|
||||||
mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t")
|
mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t")
|
||||||
mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c")
|
mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c")
|
||||||
@ -68,7 +69,8 @@ func init() {
|
|||||||
mongoCmdFlags.StringVar(&mongo.VarStringBranch, "branch")
|
mongoCmdFlags.StringVar(&mongo.VarStringBranch, "branch")
|
||||||
|
|
||||||
mysqlCmd.PersistentFlags().BoolVar(&command.VarBoolStrict, "strict")
|
mysqlCmd.PersistentFlags().BoolVar(&command.VarBoolStrict, "strict")
|
||||||
mysqlCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns, "ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
|
mysqlCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns,
|
||||||
|
"ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
|
||||||
|
|
||||||
mysqlCmd.AddCommand(datasourceCmd, ddlCmd)
|
mysqlCmd.AddCommand(datasourceCmd, ddlCmd)
|
||||||
pgCmd.AddCommand(pgDatasourceCmd)
|
pgCmd.AddCommand(pgDatasourceCmd)
|
||||||
|
@ -8,7 +8,8 @@ import (
|
|||||||
// Acceptable checks if given error is acceptable.
|
// Acceptable checks if given error is acceptable.
|
||||||
func Acceptable(err error) bool {
|
func Acceptable(err error) bool {
|
||||||
switch status.Code(err) {
|
switch status.Code(err) {
|
||||||
case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss, codes.Unimplemented:
|
case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss,
|
||||||
|
codes.Unimplemented, codes.ResourceExhausted:
|
||||||
return false
|
return false
|
||||||
default:
|
default:
|
||||||
return true
|
return true
|
||||||
|
@ -2,10 +2,13 @@ package serverinterceptors
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/breaker"
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
"github.com/zeromicro/go-zero/zrpc/internal/codes"
|
"github.com/zeromicro/go-zero/zrpc/internal/codes"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
gcodes "google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StreamBreakerInterceptor is an interceptor that acts as a circuit breaker.
|
// StreamBreakerInterceptor is an interceptor that acts as a circuit breaker.
|
||||||
@ -26,6 +29,9 @@ func UnaryBreakerInterceptor(ctx context.Context, req any, info *grpc.UnaryServe
|
|||||||
resp, err = handler(ctx, req)
|
resp, err = handler(ctx, req)
|
||||||
return err
|
return err
|
||||||
}, codes.Acceptable)
|
}, codes.Acceptable)
|
||||||
|
if errors.Is(err, breaker.ErrServiceUnavailable) {
|
||||||
|
err = status.Error(gcodes.Unavailable, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
@ -2,11 +2,14 @@ package serverinterceptors
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/load"
|
"github.com/zeromicro/go-zero/core/load"
|
||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const serviceType = "rpc"
|
const serviceType = "rpc"
|
||||||
@ -28,11 +31,12 @@ func UnarySheddingInterceptor(shedder load.Shedder, metrics *stat.Metrics) grpc.
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
metrics.AddDrop()
|
metrics.AddDrop()
|
||||||
sheddingStat.IncrementDrop()
|
sheddingStat.IncrementDrop()
|
||||||
|
err = status.Error(codes.ResourceExhausted, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err == context.DeadlineExceeded {
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
promise.Fail()
|
promise.Fail()
|
||||||
} else {
|
} else {
|
||||||
sheddingStat.IncrementPass()
|
sheddingStat.IncrementPass()
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/load"
|
"github.com/zeromicro/go-zero/core/load"
|
||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUnarySheddingInterceptor(t *testing.T) {
|
func TestUnarySheddingInterceptor(t *testing.T) {
|
||||||
@ -33,7 +35,7 @@ func TestUnarySheddingInterceptor(t *testing.T) {
|
|||||||
name: "reject",
|
name: "reject",
|
||||||
allow: false,
|
allow: false,
|
||||||
handleErr: nil,
|
handleErr: nil,
|
||||||
expect: load.ErrServiceOverloaded,
|
expect: status.Error(codes.ResourceExhausted, load.ErrServiceOverloaded.Error()),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package serverinterceptors
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
@ -49,9 +50,9 @@ func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor
|
|||||||
return resp, err
|
return resp, err
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
err := ctx.Err()
|
err := ctx.Err()
|
||||||
if err == context.Canceled {
|
if errors.Is(err, context.Canceled) {
|
||||||
err = status.Error(codes.Canceled, err.Error())
|
err = status.Error(codes.Canceled, err.Error())
|
||||||
} else if err == context.DeadlineExceeded {
|
} else if errors.Is(err, context.DeadlineExceeded) {
|
||||||
err = status.Error(codes.DeadlineExceeded, err.Error())
|
err = status.Error(codes.DeadlineExceeded, err.Error())
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
|
Loading…
Reference in New Issue
Block a user