mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-02-03 00:38:40 +08:00
print entire sql statements in logx if necessary (#704)
This commit is contained in:
parent
73906f996d
commit
aaa39e17a3
@ -56,7 +56,8 @@ type (
|
||||
}
|
||||
|
||||
statement struct {
|
||||
stmt *sql.Stmt
|
||||
query string
|
||||
stmt *sql.Stmt
|
||||
}
|
||||
|
||||
stmtConn interface {
|
||||
@ -111,7 +112,8 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
|
||||
}
|
||||
|
||||
stmt = statement{
|
||||
stmt: st,
|
||||
query: query,
|
||||
stmt: st,
|
||||
}
|
||||
return nil
|
||||
}, db.acceptable)
|
||||
@ -181,29 +183,29 @@ func (s statement) Close() error {
|
||||
}
|
||||
|
||||
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
|
||||
return execStmt(s.stmt, args...)
|
||||
return execStmt(s.stmt, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, args...)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, args...)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, args...)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, args...)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
@ -12,10 +11,14 @@ import (
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
|
||||
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := formatForPrint(q, args)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
||||
} else {
|
||||
@ -28,11 +31,15 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
return result, err
|
||||
}
|
||||
|
||||
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
||||
func execStmt(conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := fmt.Sprint(args...)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
@ -46,10 +53,14 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
||||
}
|
||||
|
||||
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := fmt.Sprint(args...)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
||||
} else {
|
||||
@ -64,8 +75,12 @@ func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...in
|
||||
return scanner(rows)
|
||||
}
|
||||
|
||||
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, args ...interface{}) error {
|
||||
stmt := fmt.Sprint(args...)
|
||||
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(args...)
|
||||
duration := timex.Since(startTime)
|
||||
|
@ -14,6 +14,7 @@ var errMockedPlaceholder = errors.New("placeholder")
|
||||
func TestStmt_exec(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
args []interface{}
|
||||
delay bool
|
||||
hasError bool
|
||||
@ -23,18 +24,28 @@ func TestStmt_exec(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
lastInsertId: 1,
|
||||
rowsAffected: 2,
|
||||
},
|
||||
{
|
||||
name: "exec error",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
hasError: true,
|
||||
err: errors.New("exec"),
|
||||
},
|
||||
{
|
||||
name: "exec more args error",
|
||||
query: "select user from users where id=? and name=?",
|
||||
args: []interface{}{1},
|
||||
hasError: true,
|
||||
err: errors.New("exec"),
|
||||
},
|
||||
{
|
||||
name: "slowcall",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
delay: true,
|
||||
lastInsertId: 1,
|
||||
@ -51,7 +62,7 @@ func TestStmt_exec(t *testing.T) {
|
||||
rowsAffected: test.rowsAffected,
|
||||
err: test.err,
|
||||
delay: test.delay,
|
||||
}, "select user from users where id=?", args...)
|
||||
}, test.query, args...)
|
||||
},
|
||||
func(args ...interface{}) (sql.Result, error) {
|
||||
return execStmt(&mockedStmtConn{
|
||||
@ -59,7 +70,7 @@ func TestStmt_exec(t *testing.T) {
|
||||
rowsAffected: test.rowsAffected,
|
||||
err: test.err,
|
||||
delay: test.delay,
|
||||
}, args...)
|
||||
}, test.query, args...)
|
||||
},
|
||||
}
|
||||
|
||||
@ -89,23 +100,34 @@ func TestStmt_exec(t *testing.T) {
|
||||
func TestStmt_query(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
args []interface{}
|
||||
delay bool
|
||||
hasError bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
args: []interface{}{1},
|
||||
name: "normal",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
},
|
||||
{
|
||||
name: "query error",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
hasError: true,
|
||||
err: errors.New("exec"),
|
||||
},
|
||||
{
|
||||
name: "query more args error",
|
||||
query: "select user from users where id=? and name=?",
|
||||
args: []interface{}{1},
|
||||
hasError: true,
|
||||
err: errors.New("exec"),
|
||||
},
|
||||
{
|
||||
name: "slowcall",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
delay: true,
|
||||
},
|
||||
@ -120,7 +142,7 @@ func TestStmt_query(t *testing.T) {
|
||||
delay: test.delay,
|
||||
}, func(rows *sql.Rows) error {
|
||||
return nil
|
||||
}, "select user from users where id=?", args...)
|
||||
}, test.query, args...)
|
||||
},
|
||||
func(args ...interface{}) error {
|
||||
return queryStmt(&mockedStmtConn{
|
||||
@ -128,7 +150,7 @@ func TestStmt_query(t *testing.T) {
|
||||
delay: test.delay,
|
||||
}, func(rows *sql.Rows) error {
|
||||
return nil
|
||||
}, args...)
|
||||
}, test.query, args...)
|
||||
},
|
||||
}
|
||||
|
||||
@ -143,7 +165,7 @@ func TestStmt_query(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, errMockedPlaceholder, err)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package sqlx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
@ -45,24 +46,6 @@ func escape(input string) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func formatForPrint(query string, args ...interface{}) string {
|
||||
if len(args) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
var vals []string
|
||||
for _, arg := range args {
|
||||
vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteByte('[')
|
||||
b.WriteString(strings.Join(vals, ", "))
|
||||
b.WriteByte(']')
|
||||
|
||||
return strings.Join([]string{query, b.String()}, " ")
|
||||
}
|
||||
|
||||
func format(query string, args ...interface{}) (string, error) {
|
||||
numArgs := len(args)
|
||||
if numArgs == 0 {
|
||||
@ -72,36 +55,50 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
var b strings.Builder
|
||||
argIndex := 0
|
||||
|
||||
for _, ch := range query {
|
||||
if ch == '?' {
|
||||
bytes := len(query)
|
||||
for i := 0; i < bytes; i++ {
|
||||
ch := query[i]
|
||||
switch ch {
|
||||
case '?':
|
||||
if argIndex >= numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
||||
}
|
||||
|
||||
arg := args[argIndex]
|
||||
writeValue(&b, args[argIndex])
|
||||
argIndex++
|
||||
|
||||
switch v := arg.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
b.WriteByte('1')
|
||||
} else {
|
||||
b.WriteByte('0')
|
||||
case '$':
|
||||
var j int
|
||||
for j = i + 1; j < bytes; j++ {
|
||||
char := query[j]
|
||||
if char < '0' || '9' < char {
|
||||
break
|
||||
}
|
||||
case string:
|
||||
b.WriteByte('\'')
|
||||
b.WriteString(escape(v))
|
||||
b.WriteByte('\'')
|
||||
default:
|
||||
b.WriteString(mapping.Repr(v))
|
||||
}
|
||||
} else {
|
||||
b.WriteRune(ch)
|
||||
if j > i+1 {
|
||||
index, err := strconv.Atoi(query[i+1 : j])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// index starts from 1 for pg
|
||||
if index > argIndex {
|
||||
argIndex = index
|
||||
}
|
||||
index--
|
||||
if index < 0 || numArgs <= index {
|
||||
return "", fmt.Errorf("error: wrong index %d in sql", index)
|
||||
}
|
||||
|
||||
writeValue(&b, args[index])
|
||||
i = j - 1
|
||||
}
|
||||
default:
|
||||
b.WriteByte(ch)
|
||||
}
|
||||
}
|
||||
|
||||
if argIndex < numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
|
||||
return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
@ -117,3 +114,20 @@ func logSqlError(stmt string, err error) {
|
||||
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func writeValue(buf *strings.Builder, arg interface{}) {
|
||||
switch v := arg.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
buf.WriteByte('1')
|
||||
} else {
|
||||
buf.WriteByte('0')
|
||||
}
|
||||
case string:
|
||||
buf.WriteByte('\'')
|
||||
buf.WriteString(escape(v))
|
||||
buf.WriteByte('\'')
|
||||
default:
|
||||
buf.WriteString(mapping.Repr(v))
|
||||
}
|
||||
}
|
||||
|
@ -29,30 +29,63 @@ func TestDesensitize_WithoutAccount(t *testing.T) {
|
||||
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
|
||||
}
|
||||
|
||||
func TestFormatForPrint(t *testing.T) {
|
||||
func TestFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
args []interface{}
|
||||
expect string
|
||||
hasErr bool
|
||||
}{
|
||||
{
|
||||
name: "no args",
|
||||
query: "select user, name from table where id=?",
|
||||
expect: `select user, name from table where id=?`,
|
||||
name: "mysql normal",
|
||||
query: "select name, age from users where bool=? and phone=?",
|
||||
args: []interface{}{true, "133"},
|
||||
expect: "select name, age from users where bool=1 and phone='133'",
|
||||
},
|
||||
{
|
||||
name: "one arg",
|
||||
query: "select user, name from table where id=?",
|
||||
args: []interface{}{"kevin"},
|
||||
expect: `select user, name from table where id=? ["kevin"]`,
|
||||
name: "mysql normal",
|
||||
query: "select name, age from users where bool=? and phone=?",
|
||||
args: []interface{}{false, "133"},
|
||||
expect: "select name, age from users where bool=0 and phone='133'",
|
||||
},
|
||||
{
|
||||
name: "pg normal",
|
||||
query: "select name, age from users where bool=$1 and phone=$2",
|
||||
args: []interface{}{true, "133"},
|
||||
expect: "select name, age from users where bool=1 and phone='133'",
|
||||
},
|
||||
{
|
||||
name: "pg normal reverse",
|
||||
query: "select name, age from users where bool=$2 and phone=$1",
|
||||
args: []interface{}{"133", false},
|
||||
expect: "select name, age from users where bool=0 and phone='133'",
|
||||
},
|
||||
{
|
||||
name: "pg error not number",
|
||||
query: "select name, age from users where bool=$a and phone=$1",
|
||||
args: []interface{}{"133", false},
|
||||
hasErr: true,
|
||||
},
|
||||
{
|
||||
name: "pg error more args",
|
||||
query: "select name, age from users where bool=$2 and phone=$1 and nickname=$3",
|
||||
args: []interface{}{"133", false},
|
||||
hasErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
actual := formatForPrint(test.query, test.args...)
|
||||
assert.Equal(t, test.expect, actual)
|
||||
t.Parallel()
|
||||
|
||||
actual, err := format(test.query, test.args...)
|
||||
if test.hasErr {
|
||||
assert.NotNil(t, err)
|
||||
} else {
|
||||
assert.Equal(t, test.expect, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user