feat: support breaker with sql statements (#3936)

This commit is contained in:
Kevin Wan 2024-02-25 11:24:44 +08:00 committed by GitHub
parent 914bcdcf2b
commit 0dfaf135dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 177 additions and 114 deletions

View File

@ -59,7 +59,7 @@ func GetBreaker(name string) Breaker {
// NoBreakerFor disables the circuit breaker for the given name.
func NoBreakerFor(name string) {
lock.Lock()
breakers[name] = newNopBreaker()
breakers[name] = NopBreaker()
lock.Unlock()
}

View File

@ -4,7 +4,8 @@ const nopBreakerName = "nopBreaker"
type nopBreaker struct{}
func newNopBreaker() Breaker {
// NopBreaker returns a breaker that never trigger breaker circuit.
func NopBreaker() Breaker {
return nopBreaker{}
}

View File

@ -8,7 +8,7 @@ import (
)
func TestNopBreaker(t *testing.T) {
b := newNopBreaker()
b := NopBreaker()
assert.Equal(t, nopBreakerName, b.Name())
p, err := b.Allow()
assert.Nil(t, err)

View File

@ -42,21 +42,6 @@ type (
// SqlOption defines the method to customize a sql connection.
SqlOption func(*commonSqlConn)
// StmtSession interface represents a session that can be used to execute statements.
StmtSession interface {
Close() error
Exec(args ...any) (sql.Result, error)
ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
QueryRow(v any, args ...any) error
QueryRowCtx(ctx context.Context, v any, args ...any) error
QueryRowPartial(v any, args ...any) error
QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
QueryRows(v any, args ...any) error
QueryRowsCtx(ctx context.Context, v any, args ...any) error
QueryRowsPartial(v any, args ...any) error
QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
}
// thread-safe
// Because CORBA doesn't support PREPARE, so we need to combine the
// query arguments into one string and do underlying query without arguments
@ -65,7 +50,7 @@ type (
onError func(context.Context, error)
beginTx beginnable
brk breaker.Breaker
accept func(error) bool
accept breaker.Acceptable
}
connProvider func() (*sql.DB, error)
@ -76,18 +61,6 @@ type (
Query(query string, args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
statement struct {
query string
stmt *sql.Stmt
}
stmtConn interface {
Exec(args ...any) (sql.Result, error)
ExecContext(ctx context.Context, args ...any) (sql.Result, error)
Query(args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
}
)
// NewSqlConn returns a SqlConn with given driver name and datasource.
@ -189,8 +162,10 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
}
stmt = statement{
query: query,
stmt: st,
query: query,
stmt: st,
brk: db.brk,
accept: db.acceptable,
}
return nil
}, db.acceptable)
@ -311,7 +286,7 @@ func (db *commonSqlConn) acceptable(err error) bool {
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
q string, args ...any) (err error) {
var qerr error
var scanFailed bool
err = db.brk.DoWithAcceptable(func() error {
conn, err := db.connProv()
if err != nil {
@ -320,11 +295,14 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
}
return query(ctx, conn, func(rows *sql.Rows) error {
qerr = scanner(rows)
return qerr
e := scanner(rows)
if e != nil {
scanFailed = true
}
return e
}, q, args...)
}, func(err error) bool {
return errors.Is(err, qerr) || db.acceptable(err)
return scanFailed || db.acceptable(err)
})
if errors.Is(err, breaker.ErrServiceUnavailable) {
metricReqErr.Inc("queryRows", "breaker")
@ -333,83 +311,6 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
return
}
func (s statement) Close() error {
return s.stmt.Close()
}
func (s statement) Exec(args ...any) (sql.Result, error) {
return s.ExecCtx(context.Background(), args...)
}
func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
ctx, span := startSpan(ctx, "Exec")
defer func() {
endSpan(span, err)
}()
return execStmt(ctx, s.stmt, s.query, args...)
}
func (s statement) QueryRow(v any, args ...any) error {
return s.QueryRowCtx(context.Background(), v, args...)
}
func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRow")
defer func() {
endSpan(span, err)
}()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, s.query, args...)
}
func (s statement) QueryRowPartial(v any, args ...any) error {
return s.QueryRowPartialCtx(context.Background(), v, args...)
}
func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRowPartial")
defer func() {
endSpan(span, err)
}()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, s.query, args...)
}
func (s statement) QueryRows(v any, args ...any) error {
return s.QueryRowsCtx(context.Background(), v, args...)
}
func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRows")
defer func() {
endSpan(span, err)
}()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, s.query, args...)
}
func (s statement) QueryRowsPartial(v any, args ...any) error {
return s.QueryRowsPartialCtx(context.Background(), v, args...)
}
func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRowsPartial")
defer func() {
endSpan(span, err)
}()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, s.query, args...)
}
// WithAcceptable returns a SqlOption that setting the acceptable function.
// acceptable is the func to check if the error can be accepted.
func WithAcceptable(acceptable func(err error) bool) SqlOption {

View File

@ -156,6 +156,7 @@ func TestStatement(t *testing.T) {
st := statement{
query: "foo",
stmt: stmt,
brk: breaker.NopBreaker(),
}
assert.NoError(t, st.Close())
})

View File

@ -5,6 +5,7 @@ import (
"database/sql"
"time"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/timex"
@ -18,6 +19,137 @@ var (
logSlowSql = syncx.ForAtomicBool(true)
)
type (
// StmtSession interface represents a session that can be used to execute statements.
StmtSession interface {
Close() error
Exec(args ...any) (sql.Result, error)
ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
QueryRow(v any, args ...any) error
QueryRowCtx(ctx context.Context, v any, args ...any) error
QueryRowPartial(v any, args ...any) error
QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
QueryRows(v any, args ...any) error
QueryRowsCtx(ctx context.Context, v any, args ...any) error
QueryRowsPartial(v any, args ...any) error
QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
}
statement struct {
query string
stmt *sql.Stmt
brk breaker.Breaker
accept breaker.Acceptable
}
stmtConn interface {
Exec(args ...any) (sql.Result, error)
ExecContext(ctx context.Context, args ...any) (sql.Result, error)
Query(args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
}
)
func (s statement) Close() error {
return s.stmt.Close()
}
func (s statement) Exec(args ...any) (sql.Result, error) {
return s.ExecCtx(context.Background(), args...)
}
func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
ctx, span := startSpan(ctx, "Exec")
defer func() {
endSpan(span, err)
}()
err = s.brk.DoWithAcceptable(func() error {
result, err = execStmt(ctx, s.stmt, s.query, args...)
return err
}, func(err error) bool {
return s.accept(err)
})
return
}
func (s statement) QueryRow(v any, args ...any) error {
return s.QueryRowCtx(context.Background(), v, args...)
}
func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRow")
defer func() {
endSpan(span, err)
}()
return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
return unmarshalRow(v, scanner, true)
}, v, args...)
}
func (s statement) QueryRowPartial(v any, args ...any) error {
return s.QueryRowPartialCtx(context.Background(), v, args...)
}
func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRowPartial")
defer func() {
endSpan(span, err)
}()
return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
return unmarshalRow(v, scanner, false)
}, v, args...)
}
func (s statement) QueryRows(v any, args ...any) error {
return s.QueryRowsCtx(context.Background(), v, args...)
}
func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRows")
defer func() {
endSpan(span, err)
}()
return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
return unmarshalRows(v, scanner, true)
}, v, args...)
}
func (s statement) QueryRowsPartial(v any, args ...any) error {
return s.QueryRowsPartialCtx(context.Background(), v, args...)
}
func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRowsPartial")
defer func() {
endSpan(span, err)
}()
return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
return unmarshalRows(v, scanner, false)
}, v, args...)
}
func (s statement) queryRows(ctx context.Context, scanFn func(any, rowsScanner) error,
v any, args ...any) error {
var scanFailed bool
return s.brk.DoWithAcceptable(func() error {
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
err := scanFn(v, rows)
if err != nil {
scanFailed = true
}
return err
}, s.query, args...)
}, func(err error) bool {
return scanFailed || s.accept(err)
})
}
// DisableLog disables logging of sql statements, includes info and slow logs.
func DisableLog() {
logSql.Set(false)

View File

@ -7,7 +7,10 @@ import (
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/stores/dbtest"
)
var errMockedPlaceholder = errors.New("placeholder")
@ -219,6 +222,28 @@ func TestNilGuard(t *testing.T) {
assert.Equal(t, nilGuard{}, guard)
}
func TestStmtScanFailed(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any")
conn := NewSqlConnFromDB(db)
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
var val struct {
Foo int
Bar string
}
for i := 0; i < 1000; i++ {
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(row)
err := stmt.QueryRow(&val)
assert.Error(t, err)
assert.NotErrorIs(t, err, breaker.ErrServiceUnavailable)
}
})
}
type mockedSessionConn struct {
lastInsertId int64
rowsAffected int64

View File

@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"fmt"
"github.com/zeromicro/go-zero/core/breaker"
)
type (
@ -75,6 +77,7 @@ func (t txSession) PrepareCtx(ctx context.Context, q string) (stmtSession StmtSe
return statement{
query: q,
stmt: stmt,
brk: breaker.NopBreaker(),
}, nil
}