bugfix: add ctx check & serializeImpl assign

This commit is contained in:
huchenhao 2024-03-21 10:13:32 +08:00
parent 856a5db8bb
commit 5679e29b9f
3 changed files with 64 additions and 31 deletions

View File

@ -26,6 +26,9 @@ type FaaSDesc struct {
var globalFaaSSerialize = &DefaultFaasSerialize{} var globalFaaSSerialize = &DefaultFaasSerialize{}
func NewFaaSDesc(fnName string, f FaaS) (*FaaSDesc, error) { func NewFaaSDesc(fnName string, f FaaS) (*FaaSDesc, error) {
var serializeImpl FaasSerialize
funcValue := reflect.ValueOf(f) funcValue := reflect.ValueOf(f)
funcType := funcValue.Type() funcType := funcValue.Type()
@ -35,32 +38,43 @@ func NewFaaSDesc(fnName string, f FaaS) (*FaaSDesc, error) {
argsType := make([]reflect.Type, funcType.NumIn()) argsType := make([]reflect.Type, funcType.NumIn())
fullName := runtime.FuncForPC(funcValue.Pointer()).Name() fullName := runtime.FuncForPC(funcValue.Pointer()).Name()
containsKisflowCtx := false containsKisFlow := false
containsCtx := false
for i := 0; i < funcType.NumIn(); i++ { for i := 0; i < funcType.NumIn(); i++ {
paramType := funcType.In(i) paramType := funcType.In(i)
if isFlowType(paramType) { if isFlowType(paramType) {
containsKisflowCtx = true containsKisFlow = true
} else if isContextType(paramType) {
containsCtx = true
} else {
itemType := paramType.Elem()
// 如果切片元素是指针类型,则获取指针所指向的类型
if itemType.Kind() == reflect.Ptr {
itemType = itemType.Elem()
}
// Check if f implements FaasSerialize interface
if isFaasSerialize(itemType) {
serializeImpl = reflect.New(itemType).Interface().(FaasSerialize)
} else {
serializeImpl = globalFaaSSerialize // Use global default implementation
}
} }
argsType[i] = paramType argsType[i] = paramType
} }
if !containsKisflowCtx { if !containsKisFlow {
return nil, errors.New("function parameters must have Kisflow context") return nil, errors.New("function parameters must have Kisflow context")
} }
if !containsCtx {
return nil, errors.New("function parameters must have context")
}
if funcType.NumOut() != 1 || funcType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { if funcType.NumOut() != 1 || funcType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
return nil, errors.New("function must have exactly one return value of type error") return nil, errors.New("function must have exactly one return value of type error")
} }
// Check if f implements FaasSerialize interface
var serializeImpl FaasSerialize
if ser, ok := f.(FaasSerialize); ok {
serializeImpl = ser
} else {
serializeImpl = globalFaaSSerialize // Use global default implementation
}
return &FaaSDesc{ return &FaaSDesc{
FnName: fnName, FnName: fnName,
f: f, f: f,

View File

@ -9,3 +9,9 @@ type FaasSerialize interface {
DecodeParam(common.KisRowArr, reflect.Type) (reflect.Value, error) DecodeParam(common.KisRowArr, reflect.Type) (reflect.Value, error)
EncodeParam(interface{}) (common.KisRowArr, error) EncodeParam(interface{}) (common.KisRowArr, error)
} }
var serializeInterfaceType = reflect.TypeOf((*FaasSerialize)(nil)).Elem()
func isFaasSerialize(paramType reflect.Type) bool {
return paramType.Implements(serializeInterfaceType)
}

View File

@ -10,33 +10,40 @@ import (
type DefaultFaasSerialize struct { type DefaultFaasSerialize struct {
} }
// DecodeParam 用于将 KisRowArr 反序列化为指定类型的值。
func (f DefaultFaasSerialize) DecodeParam(arr common.KisRowArr, r reflect.Type) (reflect.Value, error) { func (f DefaultFaasSerialize) DecodeParam(arr common.KisRowArr, r reflect.Type) (reflect.Value, error) {
// 确保传入的类型是一个切片 // 确保传入的类型是一个切片
if r.Kind() != reflect.Slice { if r.Kind() != reflect.Slice {
return reflect.Value{}, fmt.Errorf("r must be a slice") return reflect.Value{}, fmt.Errorf("r must be a slice")
} }
// 创建一个新的切片,类型为传入的类型。
slice := reflect.MakeSlice(r, 0, len(arr)) slice := reflect.MakeSlice(r, 0, len(arr))
// 遍历 KisRowArr 中的每个元素。
for _, row := range arr { for _, row := range arr {
var elem reflect.Value var elem reflect.Value
var err error var err error
// 先尝试断言为结构体或指针 // 使用 switch 语句检查 row 的类型,然后调用相应的解码函数。
elem, err = decodeStruct(row, r.Elem()) switch row := row.(type) {
if err != nil { case reflect.Value:
// 如果失败,则尝试直接反序列化为字符串 elem, err = decodeStruct(row, r.Elem())
elem, err = decodeString(row) case string:
if err != nil { elem, err = decodeString(row, r.Elem())
// 如果还失败,则尝试先序列化为 JSON 再反序列化 default:
elem, err = decodeJSON(row, r.Elem()) elem, err = decodeJSON(row, r.Elem())
if err != nil {
return reflect.Value{}, fmt.Errorf("failed to decode row: %v ", err)
}
}
} }
// 处理解码错误。
if err != nil {
return reflect.Value{}, fmt.Errorf("failed to decode row: %v", err)
}
// 将该值附加到新的切片中。
slice = reflect.Append(slice, elem) slice = reflect.Append(slice, elem)
} }
// 返回最终的切片。
return slice, nil return slice, nil
} }
@ -54,15 +61,21 @@ func decodeStruct(row common.KisRow, elemType reflect.Type) (reflect.Value, erro
} }
// 尝试直接反序列化字符串 // 尝试直接反序列化字符串
func decodeString(row common.KisRow) (reflect.Value, error) { func decodeString(row common.KisRow, elemType reflect.Type) (reflect.Value, error) {
if str, ok := row.(string); ok { str, ok := row.(string)
var intValue int if !ok {
if _, err := fmt.Sscanf(str, "%d", &intValue); err == nil { return reflect.Value{}, fmt.Errorf("not a string")
return reflect.ValueOf(intValue), nil
}
} }
return reflect.Value{}, fmt.Errorf("not a string ") // 创建一个新的结构体实例,用于存储反序列化后的值。
elem := reflect.New(elemType).Elem()
// 尝试将字符串反序列化为结构体。
if err := json.Unmarshal([]byte(str), elem.Addr().Interface()); err != nil {
return reflect.Value{}, fmt.Errorf("failed to unmarshal string to struct: %v", err)
}
return elem, nil
} }
// 尝试先序列化为 JSON 再反序列化 // 尝试先序列化为 JSON 再反序列化