diff --git a/kis/faas.go b/kis/faas.go index 96f8757..5199405 100644 --- a/kis/faas.go +++ b/kis/faas.go @@ -26,6 +26,9 @@ type FaaSDesc struct { var globalFaaSSerialize = &DefaultFaasSerialize{} func NewFaaSDesc(fnName string, f FaaS) (*FaaSDesc, error) { + + var serializeImpl FaasSerialize + funcValue := reflect.ValueOf(f) funcType := funcValue.Type() @@ -35,32 +38,43 @@ func NewFaaSDesc(fnName string, f FaaS) (*FaaSDesc, error) { argsType := make([]reflect.Type, funcType.NumIn()) fullName := runtime.FuncForPC(funcValue.Pointer()).Name() - containsKisflowCtx := false + containsKisFlow := false + containsCtx := false for i := 0; i < funcType.NumIn(); i++ { paramType := funcType.In(i) if isFlowType(paramType) { - containsKisflowCtx = true + containsKisFlow = true + } else if isContextType(paramType) { + containsCtx = true + } else { + itemType := paramType.Elem() + // 如果切片元素是指针类型,则获取指针所指向的类型 + if itemType.Kind() == reflect.Ptr { + itemType = itemType.Elem() + } + // Check if f implements FaasSerialize interface + if isFaasSerialize(itemType) { + serializeImpl = reflect.New(itemType).Interface().(FaasSerialize) + } else { + serializeImpl = globalFaaSSerialize // Use global default implementation + } + } argsType[i] = paramType } - if !containsKisflowCtx { + if !containsKisFlow { return nil, errors.New("function parameters must have Kisflow context") } + if !containsCtx { + return nil, errors.New("function parameters must have context") + } if funcType.NumOut() != 1 || funcType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { return nil, errors.New("function must have exactly one return value of type error") } - // Check if f implements FaasSerialize interface - var serializeImpl FaasSerialize - if ser, ok := f.(FaasSerialize); ok { - serializeImpl = ser - } else { - serializeImpl = globalFaaSSerialize // Use global default implementation - } - return &FaaSDesc{ FnName: fnName, f: f, diff --git a/kis/serialize.go b/kis/serialize.go index f441e76..5c5a583 100644 --- a/kis/serialize.go +++ b/kis/serialize.go @@ -9,3 +9,9 @@ type FaasSerialize interface { DecodeParam(common.KisRowArr, reflect.Type) (reflect.Value, error) EncodeParam(interface{}) (common.KisRowArr, error) } + +var serializeInterfaceType = reflect.TypeOf((*FaasSerialize)(nil)).Elem() + +func isFaasSerialize(paramType reflect.Type) bool { + return paramType.Implements(serializeInterfaceType) +} diff --git a/kis/serialize_json.go b/kis/serialize_default.go similarity index 61% rename from kis/serialize_json.go rename to kis/serialize_default.go index 74f9585..836198e 100644 --- a/kis/serialize_json.go +++ b/kis/serialize_default.go @@ -10,6 +10,7 @@ import ( type DefaultFaasSerialize struct { } +// DecodeParam 用于将 KisRowArr 反序列化为指定类型的值。 func (f DefaultFaasSerialize) DecodeParam(arr common.KisRowArr, r reflect.Type) (reflect.Value, error) { // 确保传入的类型是一个切片 if r.Kind() != reflect.Slice { @@ -23,8 +24,8 @@ func (f DefaultFaasSerialize) DecodeParam(arr common.KisRowArr, r reflect.Type) // 先尝试断言为结构体或指针 elem, err = decodeStruct(row, r.Elem()) if err != nil { - // 如果失败,则尝试直接反序列化为字符串 - elem, err = decodeString(row) + // 如果失败,则尝试直接反序列化字符串 + elem, err = decodeString(row, r.Elem()) if err != nil { // 如果还失败,则尝试先序列化为 JSON 再反序列化 elem, err = decodeJSON(row, r.Elem()) @@ -42,27 +43,49 @@ func (f DefaultFaasSerialize) DecodeParam(arr common.KisRowArr, r reflect.Type) // 尝试断言为结构体或指针 func decodeStruct(row common.KisRow, elemType reflect.Type) (reflect.Value, error) { - elem := reflect.New(elemType).Elem() - - // 如果元素是一个结构体或指针类型,则尝试断言 - if structElem, ok := row.(reflect.Value); ok && structElem.Type().AssignableTo(elemType) { - elem.Set(structElem) - return elem, nil + // 检查 row 是否为结构体或结构体指针类型 + rowType := reflect.TypeOf(row) + if rowType == nil { + return reflect.Value{}, fmt.Errorf("row is nil pointer") + } + if rowType.Kind() != reflect.Struct && rowType.Kind() != reflect.Ptr { + return reflect.Value{}, fmt.Errorf("row must be a struct or struct pointer type") } - return reflect.Value{}, fmt.Errorf("not a struct or pointer") + // 如果 row 是指针类型,则获取它指向的类型 + if rowType.Kind() == reflect.Ptr { + if reflect.ValueOf(row).IsNil() { + return reflect.Value{}, fmt.Errorf("row is nil pointer") + } + row = reflect.ValueOf(row).Elem().Interface() // 解引用 + rowType = reflect.TypeOf(row) + } + + // 检查是否可以将 row 断言为 elemType + if !rowType.AssignableTo(elemType) { + return reflect.Value{}, fmt.Errorf("row type cannot be asserted to elemType") + } + + // 将 row 转换为 reflect.Value 并返回 + return reflect.ValueOf(row), nil } // 尝试直接反序列化字符串 -func decodeString(row common.KisRow) (reflect.Value, error) { - if str, ok := row.(string); ok { - var intValue int - if _, err := fmt.Sscanf(str, "%d", &intValue); err == nil { - return reflect.ValueOf(intValue), nil - } +func decodeString(row common.KisRow, elemType reflect.Type) (reflect.Value, error) { + str, ok := row.(string) + if !ok { + return reflect.Value{}, fmt.Errorf("not a string") } - return reflect.Value{}, fmt.Errorf("not a string ") + // 创建一个新的结构体实例,用于存储反序列化后的值。 + elem := reflect.New(elemType).Elem() + + // 尝试将字符串反序列化为结构体。 + if err := json.Unmarshal([]byte(str), elem.Addr().Interface()); err != nil { + return reflect.Value{}, fmt.Errorf("failed to unmarshal string to struct: %v", err) + } + + return elem, nil } // 尝试先序列化为 JSON 再反序列化 diff --git a/test/kis_auto_inject_param_test.go b/test/kis_auto_inject_param_test.go index 8da2932..e4d3ebc 100644 --- a/test/kis_auto_inject_param_test.go +++ b/test/kis_auto_inject_param_test.go @@ -45,23 +45,22 @@ func TestAutoInjectParam(t *testing.T) { } // 3. 提交原始数据 - _ = flow1.CommitRow(proto.StuScores{ - StuId: 100, - Score1: 1, - Score2: 2, - Score3: 3, + _ = flow1.CommitRow(&faas.AvgStuScoreIn{ + proto.StuScores{ + StuId: 100, + Score1: 1, + Score2: 2, + Score3: 3, + }, }) - _ = flow1.CommitRow(proto.StuScores{ - StuId: 101, - Score1: 11, - Score2: 22, - Score3: 33, - }) - _ = flow1.CommitRow(proto.StuScores{ - StuId: 102, - Score1: 111, - Score2: 222, - Score3: 333, + _ = flow1.CommitRow(`{"stu_id":101}`) + _ = flow1.CommitRow(faas.AvgStuScoreIn{ + proto.StuScores{ + StuId: 100, + Score1: 1, + Score2: 2, + Score3: 3, + }, }) // 4. 执行flow1