From 55e2c7ee83007ddbdc9d7f72fb603765aaf4f282 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Thu, 18 May 2023 23:11:32 +0800 Subject: [PATCH] chore: add more tests (#3258) --- core/stores/sqlx/sqlconn_test.go | 182 +++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) diff --git a/core/stores/sqlx/sqlconn_test.go b/core/stores/sqlx/sqlconn_test.go index cf1a148e..bbe15a92 100644 --- a/core/stores/sqlx/sqlconn_test.go +++ b/core/stores/sqlx/sqlconn_test.go @@ -2,13 +2,16 @@ package sqlx import ( "database/sql" + "errors" "io" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/trace/tracetest" + "github.com/zeromicro/go-zero/internal/dbtest" ) const mockedDatasource = "sqlmock" @@ -54,6 +57,185 @@ func TestSqlConn(t *testing.T) { assert.Equal(t, 14, len(me.GetSpans())) } +func TestSqlConn_RawDB(t *testing.T) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(rows) + conn := NewSqlConnFromDB(db) + var val string + assert.NoError(t, conn.QueryRow(&val, "any")) + assert.Equal(t, "bar", val) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(rows) + conn := NewSqlConnFromDB(db) + var val string + assert.NoError(t, conn.QueryRowPartial(&val, "any")) + assert.Equal(t, "bar", val) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(rows) + conn := NewSqlConnFromDB(db) + var vals []string + assert.NoError(t, conn.QueryRows(&vals, "any")) + assert.ElementsMatch(t, []string{"foo", "bar"}, vals) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(rows) + conn := NewSqlConnFromDB(db) + var vals []string + assert.NoError(t, conn.QueryRowsPartial(&vals, "any")) + assert.ElementsMatch(t, []string{"foo", "bar"}, vals) + }) +} + +func TestSqlConn_Errors(t *testing.T) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + conn := NewSqlConnFromDB(db) + conn.(*commonSqlConn).connProv = func() (*sql.DB, error) { + return nil, errors.New("error") + } + _, err := conn.Prepare("any") + assert.Error(t, err) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectExec("any").WillReturnError(breaker.ErrServiceUnavailable) + conn := NewSqlConnFromDB(db) + _, err := conn.Exec("any") + assert.Error(t, err) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any").WillReturnError(breaker.ErrServiceUnavailable) + conn := NewSqlConnFromDB(db) + _, err := conn.Prepare("any") + assert.Error(t, err) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectRollback() + conn := NewSqlConnFromDB(db) + err := conn.Transact(func(session Session) error { + return breaker.ErrServiceUnavailable + }) + assert.Equal(t, breaker.ErrServiceUnavailable, err) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectQuery("any").WillReturnError(breaker.ErrServiceUnavailable) + conn := NewSqlConnFromDB(db) + var vals []string + err := conn.QueryRows(&vals, "any") + assert.Equal(t, breaker.ErrServiceUnavailable, err) + }) +} + +func TestStatement(t *testing.T) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any").WillBeClosed() + + conn := NewSqlConnFromDB(db) + stmt, err := conn.Prepare("any") + assert.NoError(t, err) + assert.NoError(t, stmt.Close()) + }) + + dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any").WillBeClosed() + + stmt, err := tx.Prepare("any") + assert.NoError(t, err) + st := statement{ + query: "foo", + stmt: stmt, + } + assert.NoError(t, st.Close()) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any") + mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3)) + + conn := NewSqlConnFromDB(db) + stmt, err := conn.Prepare("any") + assert.NoError(t, err) + res, err := stmt.Exec() + assert.NoError(t, err) + lastInsertID, err := res.LastInsertId() + assert.NoError(t, err) + assert.Equal(t, int64(2), lastInsertID) + rowsAffected, err := res.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(3), rowsAffected) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any") + row := sqlmock.NewRows([]string{"foo"}).AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(row) + + conn := NewSqlConnFromDB(db) + stmt, err := conn.Prepare("any") + assert.NoError(t, err) + + var val string + err = stmt.QueryRow(&val) + assert.NoError(t, err) + assert.Equal(t, "bar", val) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any") + row := sqlmock.NewRows([]string{"foo"}).AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(row) + + conn := NewSqlConnFromDB(db) + stmt, err := conn.Prepare("any") + assert.NoError(t, err) + + var val string + err = stmt.QueryRowPartial(&val) + assert.NoError(t, err) + assert.Equal(t, "bar", val) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any") + rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(rows) + + conn := NewSqlConnFromDB(db) + stmt, err := conn.Prepare("any") + assert.NoError(t, err) + + var vals []string + assert.NoError(t, stmt.QueryRows(&vals)) + assert.ElementsMatch(t, []string{"foo", "bar"}, vals) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any") + rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(rows) + + conn := NewSqlConnFromDB(db) + stmt, err := conn.Prepare("any") + assert.NoError(t, err) + + var vals []string + assert.NoError(t, stmt.QueryRowsPartial(&vals)) + assert.ElementsMatch(t, []string{"foo", "bar"}, vals) + }) +} + func buildConn() (mock sqlmock.Sqlmock, err error) { _, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) { var db *sql.DB