chore: add more tests (#3288)

This commit is contained in:
Kevin Wan 2023-05-27 21:49:11 +08:00 committed by GitHub
parent 0217044900
commit cd0f3726ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 305 additions and 111 deletions

View File

@ -54,27 +54,39 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
}
valueField := reflect.Indirect(v).Field(i)
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
result[key] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
result[key] = valueField.Addr().Interface()
valueData, err := getValueInterface(valueField)
if err != nil {
return nil, err
}
result[key] = valueData
}
return result, nil
}
func getValueInterface(value reflect.Value) (any, error) {
switch value.Kind() {
case reflect.Ptr:
if !value.CanInterface() {
return nil, ErrNotReadableValue
}
if value.IsNil() {
baseValueType := mapping.Deref(value.Type())
value.Set(reflect.New(baseValueType))
}
return value.Interface(), nil
default:
if !value.CanAddr() || !value.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
return value.Addr().Interface(), nil
}
}
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) {
fields := unwrapFields(v)
if strict && len(columns) < len(fields) {
@ -88,24 +100,18 @@ func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([
values := make([]any, len(columns))
if len(taggedMap) == 0 {
if len(fields) < len(values) {
return nil, ErrNotMatchDestination
}
for i := 0; i < len(values); i++ {
valueField := fields[i]
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
values[i] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
values[i] = valueField.Addr().Interface()
valueData, err := getValueInterface(valueField)
if err != nil {
return nil, err
}
values[i] = valueData
}
} else {
for i, column := range columns {
@ -152,11 +158,11 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error {
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
if rve.CanSet() {
return scanner.Scan(v)
if !rve.CanSet() {
return ErrNotSettable
}
return ErrNotSettable
return scanner.Scan(v)
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
@ -183,69 +189,66 @@ func unmarshalRows(v any, scanner rowsScanner, strict bool) error {
rt := reflect.TypeOf(v)
rte := rt.Elem()
rve := rv.Elem()
if !rve.CanSet() {
return ErrNotSettable
}
switch rte.Kind() {
case reflect.Slice:
if rve.CanSet() {
ptr := rte.Elem().Kind() == reflect.Ptr
appendFn := func(item reflect.Value) {
if ptr {
rve.Set(reflect.Append(rve, item))
} else {
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
}
ptr := rte.Elem().Kind() == reflect.Ptr
appendFn := func(item reflect.Value) {
if ptr {
rve.Set(reflect.Append(rve, item))
} else {
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
}
fillFn := func(value any) error {
if rve.CanSet() {
if err := scanner.Scan(value); err != nil {
return err
}
appendFn(reflect.ValueOf(value))
return nil
}
return ErrNotSettable
}
fillFn := func(value any) error {
if err := scanner.Scan(value); err != nil {
return err
}
base := mapping.Deref(rte.Elem())
switch base.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
for scanner.Next() {
value := reflect.New(base)
if err := fillFn(value.Interface()); err != nil {
return err
}
appendFn(reflect.ValueOf(value))
return nil
}
base := mapping.Deref(rte.Elem())
switch base.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
for scanner.Next() {
value := reflect.New(base)
if err := fillFn(value.Interface()); err != nil {
return err
}
case reflect.Struct:
columns, err := scanner.Columns()
}
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
return err
}
for scanner.Next() {
value := reflect.New(base)
values, err := mapStructFieldsIntoSlice(value, columns, strict)
if err != nil {
return err
}
for scanner.Next() {
value := reflect.New(base)
values, err := mapStructFieldsIntoSlice(value, columns, strict)
if err != nil {
return err
}
if err := scanner.Scan(values...); err != nil {
return err
}
appendFn(value)
if err := scanner.Scan(values...); err != nil {
return err
}
default:
return ErrUnsupportedValueType
}
return nil
appendFn(value)
}
default:
return ErrUnsupportedValueType
}
return ErrNotSettable
return nil
default:
return ErrUnsupportedValueType
}
@ -257,6 +260,10 @@ func unwrapFields(v reflect.Value) []reflect.Value {
for i := 0; i < indirect.NumField(); i++ {
child := indirect.Field(i)
if !child.CanSet() {
continue
}
if child.Kind() == reflect.Ptr && child.IsNil() {
baseValueType := mapping.Deref(child.Type())
child.Set(reflect.New(baseValueType))

View File

@ -22,6 +22,18 @@ func TestUnmarshalRowBool(t *testing.T) {
}, "select value from users where user=?", "anyone"))
assert.True(t, value)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value struct {
Value bool `db:"value"`
}
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select value from users where user=?", "anyone"))
})
}
func TestUnmarshalRowBoolNotSettable(t *testing.T) {
@ -207,12 +219,12 @@ func TestUnmarshalRowString(t *testing.T) {
}
func TestUnmarshalRowStruct(t *testing.T) {
value := new(struct {
Name string
Age int
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Name string
Age int
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@ -222,15 +234,58 @@ func TestUnmarshalRowStruct(t *testing.T) {
assert.Equal(t, "liao", value.Name)
assert.Equal(t, 5, value.Age)
})
}
func TestUnmarshalRowStructWithTags(t *testing.T) {
value := new(struct {
Age int `db:"age"`
Name string `db:"name"`
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Name string
Age int
})
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, &mockedScanner{
colErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), errAny)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Name string
age *int
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
type myString chan int
var value myString
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
})
}
func TestUnmarshalRowStructWithTags(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Age int `db:"age"`
Name string `db:"name"`
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@ -240,6 +295,51 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
assert.Equal(t, "liao", value.Name)
assert.Equal(t, 5, value.Age)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
age *int `db:"age"`
Name string `db:"name"`
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value struct {
Age *int `db:"age"`
Name *string `db:"name"`
}
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, "liao", *value.Name)
assert.Equal(t, 5, *value.Age)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Age int `db:"age"`
Name string
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, 5, value.Age)
})
}
func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
@ -270,6 +370,42 @@ func TestUnmarshalRowsBool(t *testing.T) {
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []bool
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(value, rows, true)
}, "select value from users where user=?", "anyone"))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value struct {
value []bool `db:"value"`
}
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []bool
errAny := errors.New("any")
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
scanErr: errAny,
next: 1,
}, true)
}, "select value from users where user=?", "anyone"), errAny)
})
}
func TestUnmarshalRowsInt(t *testing.T) {
@ -679,25 +815,25 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
}
func TestUnmarshalRowsStruct(t *testing.T) {
expect := []struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []struct {
Name string
Age int64
}
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []struct {
Name string
Age int64
}
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@ -709,6 +845,56 @@ func TestUnmarshalRowsStruct(t *testing.T) {
assert.Equal(t, each.Age, value[i].Age)
}
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value []struct {
Name string
Age int64
}
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
colErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), errAny)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value []struct {
Name string
Age int64
}
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
cols: []string{"name", "age"},
scanErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), errAny)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value []chan int
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
cols: []string{"name", "age"},
scanErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), ErrUnsupportedValueType)
})
}
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
@ -1163,6 +1349,7 @@ func TestAnonymousStructPrError(t *testing.T) {
}
type mockedScanner struct {
cols []string
colErr error
scanErr error
err error
@ -1170,7 +1357,7 @@ type mockedScanner struct {
}
func (m *mockedScanner) Columns() ([]string, error) {
return nil, m.colErr
return m.cols, m.colErr
}
func (m *mockedScanner) Err() error {