mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-02-02 16:28:39 +08:00
chore: add more tests (#3288)
This commit is contained in:
parent
0217044900
commit
cd0f3726ed
@ -54,27 +54,39 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
|
||||
}
|
||||
|
||||
valueField := reflect.Indirect(v).Field(i)
|
||||
switch valueField.Kind() {
|
||||
case reflect.Ptr:
|
||||
if !valueField.CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
if valueField.IsNil() {
|
||||
baseValueType := mapping.Deref(valueField.Type())
|
||||
valueField.Set(reflect.New(baseValueType))
|
||||
}
|
||||
result[key] = valueField.Interface()
|
||||
default:
|
||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
result[key] = valueField.Addr().Interface()
|
||||
valueData, err := getValueInterface(valueField)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result[key] = valueData
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func getValueInterface(value reflect.Value) (any, error) {
|
||||
switch value.Kind() {
|
||||
case reflect.Ptr:
|
||||
if !value.CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
|
||||
if value.IsNil() {
|
||||
baseValueType := mapping.Deref(value.Type())
|
||||
value.Set(reflect.New(baseValueType))
|
||||
}
|
||||
|
||||
return value.Interface(), nil
|
||||
default:
|
||||
if !value.CanAddr() || !value.Addr().CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
|
||||
return value.Addr().Interface(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) {
|
||||
fields := unwrapFields(v)
|
||||
if strict && len(columns) < len(fields) {
|
||||
@ -88,24 +100,18 @@ func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([
|
||||
|
||||
values := make([]any, len(columns))
|
||||
if len(taggedMap) == 0 {
|
||||
if len(fields) < len(values) {
|
||||
return nil, ErrNotMatchDestination
|
||||
}
|
||||
|
||||
for i := 0; i < len(values); i++ {
|
||||
valueField := fields[i]
|
||||
switch valueField.Kind() {
|
||||
case reflect.Ptr:
|
||||
if !valueField.CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
if valueField.IsNil() {
|
||||
baseValueType := mapping.Deref(valueField.Type())
|
||||
valueField.Set(reflect.New(baseValueType))
|
||||
}
|
||||
values[i] = valueField.Interface()
|
||||
default:
|
||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
values[i] = valueField.Addr().Interface()
|
||||
valueData, err := getValueInterface(valueField)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values[i] = valueData
|
||||
}
|
||||
} else {
|
||||
for i, column := range columns {
|
||||
@ -152,11 +158,11 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error {
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64,
|
||||
reflect.String:
|
||||
if rve.CanSet() {
|
||||
return scanner.Scan(v)
|
||||
if !rve.CanSet() {
|
||||
return ErrNotSettable
|
||||
}
|
||||
|
||||
return ErrNotSettable
|
||||
return scanner.Scan(v)
|
||||
case reflect.Struct:
|
||||
columns, err := scanner.Columns()
|
||||
if err != nil {
|
||||
@ -183,69 +189,66 @@ func unmarshalRows(v any, scanner rowsScanner, strict bool) error {
|
||||
rt := reflect.TypeOf(v)
|
||||
rte := rt.Elem()
|
||||
rve := rv.Elem()
|
||||
if !rve.CanSet() {
|
||||
return ErrNotSettable
|
||||
}
|
||||
|
||||
switch rte.Kind() {
|
||||
case reflect.Slice:
|
||||
if rve.CanSet() {
|
||||
ptr := rte.Elem().Kind() == reflect.Ptr
|
||||
appendFn := func(item reflect.Value) {
|
||||
if ptr {
|
||||
rve.Set(reflect.Append(rve, item))
|
||||
} else {
|
||||
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
|
||||
}
|
||||
ptr := rte.Elem().Kind() == reflect.Ptr
|
||||
appendFn := func(item reflect.Value) {
|
||||
if ptr {
|
||||
rve.Set(reflect.Append(rve, item))
|
||||
} else {
|
||||
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
|
||||
}
|
||||
fillFn := func(value any) error {
|
||||
if rve.CanSet() {
|
||||
if err := scanner.Scan(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
appendFn(reflect.ValueOf(value))
|
||||
return nil
|
||||
}
|
||||
return ErrNotSettable
|
||||
}
|
||||
fillFn := func(value any) error {
|
||||
if err := scanner.Scan(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
base := mapping.Deref(rte.Elem())
|
||||
switch base.Kind() {
|
||||
case reflect.Bool,
|
||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64,
|
||||
reflect.String:
|
||||
for scanner.Next() {
|
||||
value := reflect.New(base)
|
||||
if err := fillFn(value.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
appendFn(reflect.ValueOf(value))
|
||||
return nil
|
||||
}
|
||||
|
||||
base := mapping.Deref(rte.Elem())
|
||||
switch base.Kind() {
|
||||
case reflect.Bool,
|
||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64,
|
||||
reflect.String:
|
||||
for scanner.Next() {
|
||||
value := reflect.New(base)
|
||||
if err := fillFn(value.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Struct:
|
||||
columns, err := scanner.Columns()
|
||||
}
|
||||
case reflect.Struct:
|
||||
columns, err := scanner.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for scanner.Next() {
|
||||
value := reflect.New(base)
|
||||
values, err := mapStructFieldsIntoSlice(value, columns, strict)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for scanner.Next() {
|
||||
value := reflect.New(base)
|
||||
values, err := mapStructFieldsIntoSlice(value, columns, strict)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := scanner.Scan(values...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
appendFn(value)
|
||||
if err := scanner.Scan(values...); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
return nil
|
||||
appendFn(value)
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
return ErrNotSettable
|
||||
return nil
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
@ -257,6 +260,10 @@ func unwrapFields(v reflect.Value) []reflect.Value {
|
||||
|
||||
for i := 0; i < indirect.NumField(); i++ {
|
||||
child := indirect.Field(i)
|
||||
if !child.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
if child.Kind() == reflect.Ptr && child.IsNil() {
|
||||
baseValueType := mapping.Deref(child.Type())
|
||||
child.Set(reflect.New(baseValueType))
|
||||
|
@ -22,6 +22,18 @@ func TestUnmarshalRowBool(t *testing.T) {
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.True(t, value)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value struct {
|
||||
Value bool `db:"value"`
|
||||
}
|
||||
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowBoolNotSettable(t *testing.T) {
|
||||
@ -207,12 +219,12 @@ func TestUnmarshalRowString(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStruct(t *testing.T) {
|
||||
value := new(struct {
|
||||
Name string
|
||||
Age int
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
Name string
|
||||
Age int
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@ -222,15 +234,58 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
||||
assert.Equal(t, "liao", value.Name)
|
||||
assert.Equal(t, 5, value.Age)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
value := new(struct {
|
||||
Age int `db:"age"`
|
||||
Name string `db:"name"`
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
Name string
|
||||
Age int
|
||||
})
|
||||
|
||||
errAny := errors.New("any error")
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, &mockedScanner{
|
||||
colErr: errAny,
|
||||
next: 1,
|
||||
}, true)
|
||||
}, "select name, age from users where user=?", "anyone"), errAny)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
Name string
|
||||
age *int
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
type myString chan int
|
||||
var value myString
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
Age int `db:"age"`
|
||||
Name string `db:"name"`
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@ -240,6 +295,51 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
assert.Equal(t, "liao", value.Name)
|
||||
assert.Equal(t, 5, value.Age)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
age *int `db:"age"`
|
||||
Name string `db:"name"`
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var value struct {
|
||||
Age *int `db:"age"`
|
||||
Name *string `db:"name"`
|
||||
}
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
assert.Equal(t, "liao", *value.Name)
|
||||
assert.Equal(t, 5, *value.Age)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
Age int `db:"age"`
|
||||
Name string
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
assert.Equal(t, 5, value.Age)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
||||
@ -270,6 +370,42 @@ func TestUnmarshalRowsBool(t *testing.T) {
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []bool
|
||||
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value struct {
|
||||
value []bool `db:"value"`
|
||||
}
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []bool
|
||||
errAny := errors.New("any")
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, &mockedScanner{
|
||||
scanErr: errAny,
|
||||
next: 1,
|
||||
}, true)
|
||||
}, "select value from users where user=?", "anyone"), errAny)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt(t *testing.T) {
|
||||
@ -679,25 +815,25 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStruct(t *testing.T) {
|
||||
expect := []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@ -709,6 +845,56 @@ func TestUnmarshalRowsStruct(t *testing.T) {
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
}
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var value []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}
|
||||
|
||||
errAny := errors.New("any error")
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, &mockedScanner{
|
||||
colErr: errAny,
|
||||
next: 1,
|
||||
}, true)
|
||||
}, "select name, age from users where user=?", "anyone"), errAny)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var value []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}
|
||||
|
||||
errAny := errors.New("any error")
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, &mockedScanner{
|
||||
cols: []string{"name", "age"},
|
||||
scanErr: errAny,
|
||||
next: 1,
|
||||
}, true)
|
||||
}, "select name, age from users where user=?", "anyone"), errAny)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var value []chan int
|
||||
|
||||
errAny := errors.New("any error")
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, &mockedScanner{
|
||||
cols: []string{"name", "age"},
|
||||
scanErr: errAny,
|
||||
next: 1,
|
||||
}, true)
|
||||
}, "select name, age from users where user=?", "anyone"), ErrUnsupportedValueType)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
||||
@ -1163,6 +1349,7 @@ func TestAnonymousStructPrError(t *testing.T) {
|
||||
}
|
||||
|
||||
type mockedScanner struct {
|
||||
cols []string
|
||||
colErr error
|
||||
scanErr error
|
||||
err error
|
||||
@ -1170,7 +1357,7 @@ type mockedScanner struct {
|
||||
}
|
||||
|
||||
func (m *mockedScanner) Columns() ([]string, error) {
|
||||
return nil, m.colErr
|
||||
return m.cols, m.colErr
|
||||
}
|
||||
|
||||
func (m *mockedScanner) Err() error {
|
||||
|
Loading…
Reference in New Issue
Block a user