feat: support **struct in mapping (#2784)

* feat: support **struct in mapping

* chore: fix test failure
This commit is contained in:
Kevin Wan 2023-01-12 20:45:32 +08:00 committed by GitHub
parent 367afb544c
commit 4d7fa08b0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 179 additions and 116 deletions

View File

@ -77,7 +77,7 @@ func (u *Unmarshaler) Unmarshal(i interface{}, v interface{}) error {
return errValueNotSettable return errValueNotSettable
} }
elemType := valueType.Elem() elemType := Deref(valueType)
switch iv := i.(type) { switch iv := i.(type) {
case map[string]interface{}: case map[string]interface{}:
if elemType.Kind() != reflect.Struct { if elemType.Kind() != reflect.Struct {
@ -818,15 +818,22 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v interface{}, f
return err return err
} }
rte := reflect.TypeOf(v).Elem() valueType := reflect.TypeOf(v)
if rte.Kind() != reflect.Struct { baseType := Deref(valueType)
if baseType.Kind() != reflect.Struct {
return errValueNotStruct return errValueNotStruct
} }
rve := rv.Elem() valElem := rv.Elem()
numFields := rte.NumField() if valElem.Kind() == reflect.Ptr {
target := reflect.New(baseType).Elem()
SetValue(valueType.Elem(), valElem, target)
valElem = target
}
numFields := baseType.NumField()
for i := 0; i < numFields; i++ { for i := 0; i < numFields; i++ {
if err := u.processField(rte.Field(i), rve.Field(i), m, fullName); err != nil { if err := u.processField(baseType.Field(i), valElem.Field(i), m, fullName); err != nil {
return err return err
} }
} }

View File

@ -3,6 +3,7 @@ package mapping
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"os"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
@ -3388,7 +3389,8 @@ func TestUnmarshal_EnvString(t *testing.T) {
envName = "TEST_NAME_STRING" envName = "TEST_NAME_STRING"
envVal = "this is a name" envVal = "this is a name"
) )
t.Setenv(envName, envVal) os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@ -3405,7 +3407,8 @@ func TestUnmarshal_EnvStringOverwrite(t *testing.T) {
envName = "TEST_NAME_STRING" envName = "TEST_NAME_STRING"
envVal = "this is a name" envVal = "this is a name"
) )
t.Setenv(envName, envVal) os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
if assert.NoError(t, UnmarshalKey(map[string]interface{}{ if assert.NoError(t, UnmarshalKey(map[string]interface{}{
@ -3420,8 +3423,12 @@ func TestUnmarshal_EnvInt(t *testing.T) {
Age int `key:"age,env=TEST_NAME_INT"` Age int `key:"age,env=TEST_NAME_INT"`
} }
const envName = "TEST_NAME_INT" const (
t.Setenv(envName, "123") envName = "TEST_NAME_INT"
envVal = "123"
)
os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@ -3434,8 +3441,12 @@ func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
Age int `key:"age,env=TEST_NAME_INT"` Age int `key:"age,env=TEST_NAME_INT"`
} }
const envName = "TEST_NAME_INT" const (
t.Setenv(envName, "123") envName = "TEST_NAME_INT"
envVal = "123"
)
os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
if assert.NoError(t, UnmarshalKey(map[string]interface{}{ if assert.NoError(t, UnmarshalKey(map[string]interface{}{
@ -3450,8 +3461,12 @@ func TestUnmarshal_EnvFloat(t *testing.T) {
Age float32 `key:"name,env=TEST_NAME_FLOAT"` Age float32 `key:"name,env=TEST_NAME_FLOAT"`
} }
const envName = "TEST_NAME_FLOAT" const (
t.Setenv(envName, "123.45") envName = "TEST_NAME_FLOAT"
envVal = "123.45"
)
os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@ -3464,8 +3479,12 @@ func TestUnmarshal_EnvFloatOverwrite(t *testing.T) {
Age float32 `key:"age,env=TEST_NAME_FLOAT"` Age float32 `key:"age,env=TEST_NAME_FLOAT"`
} }
const envName = "TEST_NAME_FLOAT" const (
t.Setenv(envName, "123.45") envName = "TEST_NAME_FLOAT"
envVal = "123.45"
)
os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
if assert.NoError(t, UnmarshalKey(map[string]interface{}{ if assert.NoError(t, UnmarshalKey(map[string]interface{}{
@ -3480,8 +3499,12 @@ func TestUnmarshal_EnvBoolTrue(t *testing.T) {
Enable bool `key:"enable,env=TEST_NAME_BOOL_TRUE"` Enable bool `key:"enable,env=TEST_NAME_BOOL_TRUE"`
} }
const envName = "TEST_NAME_BOOL_TRUE" const (
t.Setenv(envName, "true") envName = "TEST_NAME_BOOL_TRUE"
envVal = "true"
)
os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@ -3494,8 +3517,12 @@ func TestUnmarshal_EnvBoolFalse(t *testing.T) {
Enable bool `key:"enable,env=TEST_NAME_BOOL_FALSE"` Enable bool `key:"enable,env=TEST_NAME_BOOL_FALSE"`
} }
const envName = "TEST_NAME_BOOL_FALSE" const (
t.Setenv(envName, "false") envName = "TEST_NAME_BOOL_FALSE"
envVal = "false"
)
os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@ -3508,8 +3535,12 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
Enable bool `key:"enable,env=TEST_NAME_BOOL_BAD"` Enable bool `key:"enable,env=TEST_NAME_BOOL_BAD"`
} }
const envName = "TEST_NAME_BOOL_BAD" const (
t.Setenv(envName, "bad") envName = "TEST_NAME_BOOL_BAD"
envVal = "bad"
)
os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
assert.Error(t, UnmarshalKey(emptyMap, &v)) assert.Error(t, UnmarshalKey(emptyMap, &v))
@ -3520,8 +3551,12 @@ func TestUnmarshal_EnvDuration(t *testing.T) {
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"` Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
} }
const envName = "TEST_NAME_DURATION" const (
t.Setenv(envName, "1s") envName = "TEST_NAME_DURATION"
envVal = "1s"
)
os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@ -3534,8 +3569,12 @@ func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
Duration time.Duration `key:"duration,env=TEST_NAME_BAD_DURATION"` Duration time.Duration `key:"duration,env=TEST_NAME_BAD_DURATION"`
} }
const envName = "TEST_NAME_BAD_DURATION" const (
t.Setenv(envName, "bad") envName = "TEST_NAME_BAD_DURATION"
envVal = "bad"
)
os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
assert.Error(t, UnmarshalKey(emptyMap, &v)) assert.Error(t, UnmarshalKey(emptyMap, &v))
@ -3550,7 +3589,8 @@ func TestUnmarshal_EnvWithOptions(t *testing.T) {
envName = "TEST_NAME_ENV_OPTIONS_MATCH" envName = "TEST_NAME_ENV_OPTIONS_MATCH"
envVal = "123" envVal = "123"
) )
t.Setenv(envName, envVal) os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@ -3567,7 +3607,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueBool(t *testing.T) {
envName = "TEST_NAME_ENV_OPTIONS_BOOL" envName = "TEST_NAME_ENV_OPTIONS_BOOL"
envVal = "false" envVal = "false"
) )
t.Setenv(envName, envVal) os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
assert.Error(t, UnmarshalKey(emptyMap, &v)) assert.Error(t, UnmarshalKey(emptyMap, &v))
@ -3582,7 +3623,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueDuration(t *testing.T) {
envName = "TEST_NAME_ENV_OPTIONS_DURATION" envName = "TEST_NAME_ENV_OPTIONS_DURATION"
envVal = "4s" envVal = "4s"
) )
t.Setenv(envName, envVal) os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
assert.Error(t, UnmarshalKey(emptyMap, &v)) assert.Error(t, UnmarshalKey(emptyMap, &v))
@ -3597,7 +3639,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueNumber(t *testing.T) {
envName = "TEST_NAME_ENV_OPTIONS_AGE" envName = "TEST_NAME_ENV_OPTIONS_AGE"
envVal = "30" envVal = "30"
) )
t.Setenv(envName, envVal) os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
assert.Error(t, UnmarshalKey(emptyMap, &v)) assert.Error(t, UnmarshalKey(emptyMap, &v))
@ -3612,7 +3655,8 @@ func TestUnmarshal_EnvWithOptionsWrongValueString(t *testing.T) {
envName = "TEST_NAME_ENV_OPTIONS_STRING" envName = "TEST_NAME_ENV_OPTIONS_STRING"
envVal = "this is a name" envVal = "this is a name"
) )
t.Setenv(envName, envVal) os.Setenv(envName, envVal)
defer os.Unsetenv(envName)
var v Value var v Value
assert.Error(t, UnmarshalKey(emptyMap, &v)) assert.Error(t, UnmarshalKey(emptyMap, &v))
@ -4115,6 +4159,20 @@ func TestUnmarshalNestedPtr(t *testing.T) {
} }
} }
func TestUnmarshalStructPtrOfPtr(t *testing.T) {
type inner struct {
Int int `key:"int"`
}
m := map[string]interface{}{
"int": 1,
}
in := new(inner)
if assert.NoError(t, UnmarshalKey(m, &in)) {
assert.Equal(t, 1, in.Int)
}
}
func BenchmarkDefaultValue(b *testing.B) { func BenchmarkDefaultValue(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var a struct { var a struct {

View File

@ -118,7 +118,7 @@ func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route
chn := chain.New() chn := chain.New()
if ng.conf.Middlewares.Trace { if ng.conf.Middlewares.Trace {
chn = chn.Append(handler.TracingHandler(ng.conf.Name, chn = chn.Append(handler.TraceHandler(ng.conf.Name,
route.Path, route.Path,
handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths))) handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)))
} }
@ -204,7 +204,7 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
func (ng *engine) notFoundHandler(next http.Handler) http.Handler { func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
chn := chain.New( chn := chain.New(
handler.TracingHandler(ng.conf.Name, handler.TraceHandler(ng.conf.Name,
"", "",
handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)), handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)),
ng.getLogHandler(), ng.getLogHandler(),

View File

@ -0,0 +1,78 @@
package handler
import (
"net/http"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/trace"
"github.com/zeromicro/go-zero/rest/internal/response"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
oteltrace "go.opentelemetry.io/otel/trace"
)
type (
// TraceOption defines the method to customize an traceOptions.
TraceOption func(options *traceOptions)
// traceOptions is TraceHandler options.
traceOptions struct {
traceIgnorePaths []string
}
)
// TraceHandler return a middleware that process the opentelemetry.
func TraceHandler(serviceName, path string, opts ...TraceOption) func(http.Handler) http.Handler {
var options traceOptions
for _, opt := range opts {
opt(&options)
}
ignorePaths := collection.NewSet()
ignorePaths.AddStr(options.traceIgnorePaths...)
return func(next http.Handler) http.Handler {
tracer := otel.Tracer(trace.TraceName)
propagator := otel.GetTextMapPropagator()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
spanName := path
if len(spanName) == 0 {
spanName = r.URL.Path
}
if ignorePaths.Contains(spanName) {
next.ServeHTTP(w, r)
return
}
ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
spanCtx, span := tracer.Start(
ctx,
spanName,
oteltrace.WithSpanKind(oteltrace.SpanKindServer),
oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(
serviceName, spanName, r)...),
)
defer span.End()
// convenient for tracking error messages
propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK}
next.ServeHTTP(trw, r.WithContext(spanCtx))
span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...)
span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(
trw.Code, oteltrace.SpanKindServer))
})
}
}
// WithTraceIgnorePaths specifies the traceIgnorePaths option for TraceHandler.
func WithTraceIgnorePaths(traceIgnorePaths []string) TraceOption {
return func(options *traceOptions) {
options.traceIgnorePaths = append(options.traceIgnorePaths, traceIgnorePaths...)
}
}

View File

@ -27,7 +27,7 @@ func TestOtelHandler(t *testing.T) {
for _, test := range []string{"", "bar"} { for _, test := range []string{"", "bar"} {
t.Run(test, func(t *testing.T) { t.Run(test, func(t *testing.T) {
h := chain.New(TracingHandler("foo", test)).Then( h := chain.New(TraceHandler("foo", test)).Then(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span := trace.SpanFromContext(r.Context()) span := trace.SpanFromContext(r.Context())
assert.True(t, span.SpanContext().IsValid()) assert.True(t, span.SpanContext().IsValid())
@ -65,7 +65,7 @@ func TestDontTracingSpan(t *testing.T) {
for _, test := range []string{"", "bar", "foo"} { for _, test := range []string{"", "bar", "foo"} {
t.Run(test, func(t *testing.T) { t.Run(test, func(t *testing.T) {
h := chain.New(TracingHandler("foo", test, WithTraceIgnorePaths([]string{"bar"}))).Then( h := chain.New(TraceHandler("foo", test, WithTraceIgnorePaths([]string{"bar"}))).Then(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span := trace.SpanFromContext(r.Context()) span := trace.SpanFromContext(r.Context())
spanCtx := span.SpanContext() spanCtx := span.SpanContext()
@ -110,7 +110,7 @@ func TestTraceResponseWriter(t *testing.T) {
for _, test := range []int{0, 200, 300, 400, 401, 500, 503} { for _, test := range []int{0, 200, 300, 400, 401, 500, 503} {
t.Run(strconv.Itoa(test), func(t *testing.T) { t.Run(strconv.Itoa(test), func(t *testing.T) {
h := chain.New(TracingHandler("foo", "bar")).Then( h := chain.New(TraceHandler("foo", "bar")).Then(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span := trace.SpanFromContext(r.Context()) span := trace.SpanFromContext(r.Context())
spanCtx := span.SpanContext() spanCtx := span.SpanContext()

View File

@ -1,80 +0,0 @@
package handler
import (
"net/http"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/trace"
"github.com/zeromicro/go-zero/rest/internal/response"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
oteltrace "go.opentelemetry.io/otel/trace"
)
type (
// TracingOption defines the method to customize an tracingOptions.
TracingOption func(options *tracingOptions)
// tracingOptions is TracingHandler options.
tracingOptions struct {
traceIgnorePaths []string
}
)
// TracingHandler return a middleware that process the opentelemetry.
func TracingHandler(serviceName, path string, opts ...TracingOption) func(http.Handler) http.Handler {
var tracingOpts tracingOptions
for _, opt := range opts {
opt(&tracingOpts)
}
ignorePaths := collection.NewSet()
ignorePaths.AddStr(tracingOpts.traceIgnorePaths...)
traceHandler := func(checkIgnore bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
tracer := otel.Tracer(trace.TraceName)
propagator := otel.GetTextMapPropagator()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
spanName := path
if len(spanName) == 0 {
spanName = r.URL.Path
}
if checkIgnore && ignorePaths.Contains(spanName) {
next.ServeHTTP(w, r)
return
}
ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
spanCtx, span := tracer.Start(
ctx,
spanName,
oteltrace.WithSpanKind(oteltrace.SpanKindServer),
oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(
serviceName, spanName, r)...),
)
defer span.End()
// convenient for tracking error messages
propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK}
next.ServeHTTP(trw, r.WithContext(spanCtx))
span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...)
span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(trw.Code, oteltrace.SpanKindServer))
})
}
}
checkIgnore := ignorePaths.Count() > 0
return traceHandler(checkIgnore)
}
// WithTraceIgnorePaths specifies the traceIgnorePaths option for TracingHandler.
func WithTraceIgnorePaths(traceIgnorePaths []string) TracingOption {
return func(options *tracingOptions) {
options.traceIgnorePaths = append(options.traceIgnorePaths, traceIgnorePaths...)
}
}