mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-01-23 09:00:20 +08:00
feat: support breaker with sql statements (#3936)
This commit is contained in:
parent
914bcdcf2b
commit
0dfaf135dd
@ -59,7 +59,7 @@ func GetBreaker(name string) Breaker {
|
||||
// NoBreakerFor disables the circuit breaker for the given name.
|
||||
func NoBreakerFor(name string) {
|
||||
lock.Lock()
|
||||
breakers[name] = newNopBreaker()
|
||||
breakers[name] = NopBreaker()
|
||||
lock.Unlock()
|
||||
}
|
||||
|
||||
|
@ -4,7 +4,8 @@ const nopBreakerName = "nopBreaker"
|
||||
|
||||
type nopBreaker struct{}
|
||||
|
||||
func newNopBreaker() Breaker {
|
||||
// NopBreaker returns a breaker that never trigger breaker circuit.
|
||||
func NopBreaker() Breaker {
|
||||
return nopBreaker{}
|
||||
}
|
||||
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestNopBreaker(t *testing.T) {
|
||||
b := newNopBreaker()
|
||||
b := NopBreaker()
|
||||
assert.Equal(t, nopBreakerName, b.Name())
|
||||
p, err := b.Allow()
|
||||
assert.Nil(t, err)
|
||||
|
@ -42,21 +42,6 @@ type (
|
||||
// SqlOption defines the method to customize a sql connection.
|
||||
SqlOption func(*commonSqlConn)
|
||||
|
||||
// StmtSession interface represents a session that can be used to execute statements.
|
||||
StmtSession interface {
|
||||
Close() error
|
||||
Exec(args ...any) (sql.Result, error)
|
||||
ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
|
||||
QueryRow(v any, args ...any) error
|
||||
QueryRowCtx(ctx context.Context, v any, args ...any) error
|
||||
QueryRowPartial(v any, args ...any) error
|
||||
QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
|
||||
QueryRows(v any, args ...any) error
|
||||
QueryRowsCtx(ctx context.Context, v any, args ...any) error
|
||||
QueryRowsPartial(v any, args ...any) error
|
||||
QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
|
||||
}
|
||||
|
||||
// thread-safe
|
||||
// Because CORBA doesn't support PREPARE, so we need to combine the
|
||||
// query arguments into one string and do underlying query without arguments
|
||||
@ -65,7 +50,7 @@ type (
|
||||
onError func(context.Context, error)
|
||||
beginTx beginnable
|
||||
brk breaker.Breaker
|
||||
accept func(error) bool
|
||||
accept breaker.Acceptable
|
||||
}
|
||||
|
||||
connProvider func() (*sql.DB, error)
|
||||
@ -76,18 +61,6 @@ type (
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
statement struct {
|
||||
query string
|
||||
stmt *sql.Stmt
|
||||
}
|
||||
|
||||
stmtConn interface {
|
||||
Exec(args ...any) (sql.Result, error)
|
||||
ExecContext(ctx context.Context, args ...any) (sql.Result, error)
|
||||
Query(args ...any) (*sql.Rows, error)
|
||||
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
)
|
||||
|
||||
// NewSqlConn returns a SqlConn with given driver name and datasource.
|
||||
@ -189,8 +162,10 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
|
||||
}
|
||||
|
||||
stmt = statement{
|
||||
query: query,
|
||||
stmt: st,
|
||||
query: query,
|
||||
stmt: st,
|
||||
brk: db.brk,
|
||||
accept: db.acceptable,
|
||||
}
|
||||
return nil
|
||||
}, db.acceptable)
|
||||
@ -311,7 +286,7 @@ func (db *commonSqlConn) acceptable(err error) bool {
|
||||
|
||||
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
|
||||
q string, args ...any) (err error) {
|
||||
var qerr error
|
||||
var scanFailed bool
|
||||
err = db.brk.DoWithAcceptable(func() error {
|
||||
conn, err := db.connProv()
|
||||
if err != nil {
|
||||
@ -320,11 +295,14 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
||||
}
|
||||
|
||||
return query(ctx, conn, func(rows *sql.Rows) error {
|
||||
qerr = scanner(rows)
|
||||
return qerr
|
||||
e := scanner(rows)
|
||||
if e != nil {
|
||||
scanFailed = true
|
||||
}
|
||||
return e
|
||||
}, q, args...)
|
||||
}, func(err error) bool {
|
||||
return errors.Is(err, qerr) || db.acceptable(err)
|
||||
return scanFailed || db.acceptable(err)
|
||||
})
|
||||
if errors.Is(err, breaker.ErrServiceUnavailable) {
|
||||
metricReqErr.Inc("queryRows", "breaker")
|
||||
@ -333,83 +311,6 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
||||
return
|
||||
}
|
||||
|
||||
func (s statement) Close() error {
|
||||
return s.stmt.Close()
|
||||
}
|
||||
|
||||
func (s statement) Exec(args ...any) (sql.Result, error) {
|
||||
return s.ExecCtx(context.Background(), args...)
|
||||
}
|
||||
|
||||
func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
|
||||
ctx, span := startSpan(ctx, "Exec")
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
}()
|
||||
|
||||
return execStmt(ctx, s.stmt, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRow(v any, args ...any) error {
|
||||
return s.QueryRowCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
|
||||
ctx, span := startSpan(ctx, "QueryRow")
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
}()
|
||||
|
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowPartial(v any, args ...any) error {
|
||||
return s.QueryRowPartialCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
|
||||
ctx, span := startSpan(ctx, "QueryRowPartial")
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
}()
|
||||
|
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRows(v any, args ...any) error {
|
||||
return s.QueryRowsCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
|
||||
ctx, span := startSpan(ctx, "QueryRows")
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
}()
|
||||
|
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsPartial(v any, args ...any) error {
|
||||
return s.QueryRowsPartialCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
|
||||
ctx, span := startSpan(ctx, "QueryRowsPartial")
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
}()
|
||||
|
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
// WithAcceptable returns a SqlOption that setting the acceptable function.
|
||||
// acceptable is the func to check if the error can be accepted.
|
||||
func WithAcceptable(acceptable func(err error) bool) SqlOption {
|
||||
|
@ -156,6 +156,7 @@ func TestStatement(t *testing.T) {
|
||||
st := statement{
|
||||
query: "foo",
|
||||
stmt: stmt,
|
||||
brk: breaker.NopBreaker(),
|
||||
}
|
||||
assert.NoError(t, st.Close())
|
||||
})
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
@ -18,6 +19,137 @@ var (
|
||||
logSlowSql = syncx.ForAtomicBool(true)
|
||||
)
|
||||
|
||||
type (
|
||||
// StmtSession interface represents a session that can be used to execute statements.
|
||||
StmtSession interface {
|
||||
Close() error
|
||||
Exec(args ...any) (sql.Result, error)
|
||||
ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
|
||||
QueryRow(v any, args ...any) error
|
||||
QueryRowCtx(ctx context.Context, v any, args ...any) error
|
||||
QueryRowPartial(v any, args ...any) error
|
||||
QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
|
||||
QueryRows(v any, args ...any) error
|
||||
QueryRowsCtx(ctx context.Context, v any, args ...any) error
|
||||
QueryRowsPartial(v any, args ...any) error
|
||||
QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
|
||||
}
|
||||
|
||||
statement struct {
|
||||
query string
|
||||
stmt *sql.Stmt
|
||||
brk breaker.Breaker
|
||||
accept breaker.Acceptable
|
||||
}
|
||||
|
||||
stmtConn interface {
|
||||
Exec(args ...any) (sql.Result, error)
|
||||
ExecContext(ctx context.Context, args ...any) (sql.Result, error)
|
||||
Query(args ...any) (*sql.Rows, error)
|
||||
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
)
|
||||
|
||||
func (s statement) Close() error {
|
||||
return s.stmt.Close()
|
||||
}
|
||||
|
||||
func (s statement) Exec(args ...any) (sql.Result, error) {
|
||||
return s.ExecCtx(context.Background(), args...)
|
||||
}
|
||||
|
||||
func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
|
||||
ctx, span := startSpan(ctx, "Exec")
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
}()
|
||||
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
result, err = execStmt(ctx, s.stmt, s.query, args...)
|
||||
return err
|
||||
}, func(err error) bool {
|
||||
return s.accept(err)
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s statement) QueryRow(v any, args ...any) error {
|
||||
return s.QueryRowCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
|
||||
ctx, span := startSpan(ctx, "QueryRow")
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
}()
|
||||
|
||||
return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
|
||||
return unmarshalRow(v, scanner, true)
|
||||
}, v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowPartial(v any, args ...any) error {
|
||||
return s.QueryRowPartialCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
|
||||
ctx, span := startSpan(ctx, "QueryRowPartial")
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
}()
|
||||
|
||||
return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
|
||||
return unmarshalRow(v, scanner, false)
|
||||
}, v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRows(v any, args ...any) error {
|
||||
return s.QueryRowsCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
|
||||
ctx, span := startSpan(ctx, "QueryRows")
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
}()
|
||||
|
||||
return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
|
||||
return unmarshalRows(v, scanner, true)
|
||||
}, v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsPartial(v any, args ...any) error {
|
||||
return s.QueryRowsPartialCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
|
||||
ctx, span := startSpan(ctx, "QueryRowsPartial")
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
}()
|
||||
|
||||
return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
|
||||
return unmarshalRows(v, scanner, false)
|
||||
}, v, args...)
|
||||
}
|
||||
|
||||
func (s statement) queryRows(ctx context.Context, scanFn func(any, rowsScanner) error,
|
||||
v any, args ...any) error {
|
||||
var scanFailed bool
|
||||
return s.brk.DoWithAcceptable(func() error {
|
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||||
err := scanFn(v, rows)
|
||||
if err != nil {
|
||||
scanFailed = true
|
||||
}
|
||||
return err
|
||||
}, s.query, args...)
|
||||
}, func(err error) bool {
|
||||
return scanFailed || s.accept(err)
|
||||
})
|
||||
}
|
||||
|
||||
// DisableLog disables logging of sql statements, includes info and slow logs.
|
||||
func DisableLog() {
|
||||
logSql.Set(false)
|
||||
|
@ -7,7 +7,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/stores/dbtest"
|
||||
)
|
||||
|
||||
var errMockedPlaceholder = errors.New("placeholder")
|
||||
@ -219,6 +222,28 @@ func TestNilGuard(t *testing.T) {
|
||||
assert.Equal(t, nilGuard{}, guard)
|
||||
}
|
||||
|
||||
func TestStmtScanFailed(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectPrepare("any")
|
||||
|
||||
conn := NewSqlConnFromDB(db)
|
||||
stmt, err := conn.Prepare("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var val struct {
|
||||
Foo int
|
||||
Bar string
|
||||
}
|
||||
for i := 0; i < 1000; i++ {
|
||||
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
|
||||
mock.ExpectQuery("any").WillReturnRows(row)
|
||||
err := stmt.QueryRow(&val)
|
||||
assert.Error(t, err)
|
||||
assert.NotErrorIs(t, err, breaker.ErrServiceUnavailable)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type mockedSessionConn struct {
|
||||
lastInsertId int64
|
||||
rowsAffected int64
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -75,6 +77,7 @@ func (t txSession) PrepareCtx(ctx context.Context, q string) (stmtSession StmtSe
|
||||
return statement{
|
||||
query: q,
|
||||
stmt: stmt,
|
||||
brk: breaker.NopBreaker(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user