mirror of
https://github.com/aceld/kis-flow.git
synced 2025-01-23 07:30:23 +08:00
bugfix: add ctx check & serializeImpl assign
This commit is contained in:
parent
856a5db8bb
commit
5679e29b9f
36
kis/faas.go
36
kis/faas.go
@ -26,6 +26,9 @@ type FaaSDesc struct {
|
|||||||
var globalFaaSSerialize = &DefaultFaasSerialize{}
|
var globalFaaSSerialize = &DefaultFaasSerialize{}
|
||||||
|
|
||||||
func NewFaaSDesc(fnName string, f FaaS) (*FaaSDesc, error) {
|
func NewFaaSDesc(fnName string, f FaaS) (*FaaSDesc, error) {
|
||||||
|
|
||||||
|
var serializeImpl FaasSerialize
|
||||||
|
|
||||||
funcValue := reflect.ValueOf(f)
|
funcValue := reflect.ValueOf(f)
|
||||||
funcType := funcValue.Type()
|
funcType := funcValue.Type()
|
||||||
|
|
||||||
@ -35,32 +38,43 @@ func NewFaaSDesc(fnName string, f FaaS) (*FaaSDesc, error) {
|
|||||||
|
|
||||||
argsType := make([]reflect.Type, funcType.NumIn())
|
argsType := make([]reflect.Type, funcType.NumIn())
|
||||||
fullName := runtime.FuncForPC(funcValue.Pointer()).Name()
|
fullName := runtime.FuncForPC(funcValue.Pointer()).Name()
|
||||||
containsKisflowCtx := false
|
containsKisFlow := false
|
||||||
|
containsCtx := false
|
||||||
|
|
||||||
for i := 0; i < funcType.NumIn(); i++ {
|
for i := 0; i < funcType.NumIn(); i++ {
|
||||||
paramType := funcType.In(i)
|
paramType := funcType.In(i)
|
||||||
if isFlowType(paramType) {
|
if isFlowType(paramType) {
|
||||||
containsKisflowCtx = true
|
containsKisFlow = true
|
||||||
|
} else if isContextType(paramType) {
|
||||||
|
containsCtx = true
|
||||||
|
} else {
|
||||||
|
itemType := paramType.Elem()
|
||||||
|
// 如果切片元素是指针类型,则获取指针所指向的类型
|
||||||
|
if itemType.Kind() == reflect.Ptr {
|
||||||
|
itemType = itemType.Elem()
|
||||||
|
}
|
||||||
|
// Check if f implements FaasSerialize interface
|
||||||
|
if isFaasSerialize(itemType) {
|
||||||
|
serializeImpl = reflect.New(itemType).Interface().(FaasSerialize)
|
||||||
|
} else {
|
||||||
|
serializeImpl = globalFaaSSerialize // Use global default implementation
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
argsType[i] = paramType
|
argsType[i] = paramType
|
||||||
}
|
}
|
||||||
|
|
||||||
if !containsKisflowCtx {
|
if !containsKisFlow {
|
||||||
return nil, errors.New("function parameters must have Kisflow context")
|
return nil, errors.New("function parameters must have Kisflow context")
|
||||||
}
|
}
|
||||||
|
if !containsCtx {
|
||||||
|
return nil, errors.New("function parameters must have context")
|
||||||
|
}
|
||||||
|
|
||||||
if funcType.NumOut() != 1 || funcType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
|
if funcType.NumOut() != 1 || funcType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
|
||||||
return nil, errors.New("function must have exactly one return value of type error")
|
return nil, errors.New("function must have exactly one return value of type error")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if f implements FaasSerialize interface
|
|
||||||
var serializeImpl FaasSerialize
|
|
||||||
if ser, ok := f.(FaasSerialize); ok {
|
|
||||||
serializeImpl = ser
|
|
||||||
} else {
|
|
||||||
serializeImpl = globalFaaSSerialize // Use global default implementation
|
|
||||||
}
|
|
||||||
|
|
||||||
return &FaaSDesc{
|
return &FaaSDesc{
|
||||||
FnName: fnName,
|
FnName: fnName,
|
||||||
f: f,
|
f: f,
|
||||||
|
@ -9,3 +9,9 @@ type FaasSerialize interface {
|
|||||||
DecodeParam(common.KisRowArr, reflect.Type) (reflect.Value, error)
|
DecodeParam(common.KisRowArr, reflect.Type) (reflect.Value, error)
|
||||||
EncodeParam(interface{}) (common.KisRowArr, error)
|
EncodeParam(interface{}) (common.KisRowArr, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var serializeInterfaceType = reflect.TypeOf((*FaasSerialize)(nil)).Elem()
|
||||||
|
|
||||||
|
func isFaasSerialize(paramType reflect.Type) bool {
|
||||||
|
return paramType.Implements(serializeInterfaceType)
|
||||||
|
}
|
||||||
|
@ -10,33 +10,40 @@ import (
|
|||||||
type DefaultFaasSerialize struct {
|
type DefaultFaasSerialize struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DecodeParam 用于将 KisRowArr 反序列化为指定类型的值。
|
||||||
func (f DefaultFaasSerialize) DecodeParam(arr common.KisRowArr, r reflect.Type) (reflect.Value, error) {
|
func (f DefaultFaasSerialize) DecodeParam(arr common.KisRowArr, r reflect.Type) (reflect.Value, error) {
|
||||||
// 确保传入的类型是一个切片
|
// 确保传入的类型是一个切片。
|
||||||
if r.Kind() != reflect.Slice {
|
if r.Kind() != reflect.Slice {
|
||||||
return reflect.Value{}, fmt.Errorf("r must be a slice")
|
return reflect.Value{}, fmt.Errorf("r must be a slice")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 创建一个新的切片,类型为传入的类型。
|
||||||
slice := reflect.MakeSlice(r, 0, len(arr))
|
slice := reflect.MakeSlice(r, 0, len(arr))
|
||||||
|
|
||||||
|
// 遍历 KisRowArr 中的每个元素。
|
||||||
for _, row := range arr {
|
for _, row := range arr {
|
||||||
var elem reflect.Value
|
var elem reflect.Value
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// 先尝试断言为结构体或指针
|
// 使用 switch 语句检查 row 的类型,然后调用相应的解码函数。
|
||||||
elem, err = decodeStruct(row, r.Elem())
|
switch row := row.(type) {
|
||||||
if err != nil {
|
case reflect.Value:
|
||||||
// 如果失败,则尝试直接反序列化为字符串
|
elem, err = decodeStruct(row, r.Elem())
|
||||||
elem, err = decodeString(row)
|
case string:
|
||||||
if err != nil {
|
elem, err = decodeString(row, r.Elem())
|
||||||
// 如果还失败,则尝试先序列化为 JSON 再反序列化
|
default:
|
||||||
elem, err = decodeJSON(row, r.Elem())
|
elem, err = decodeJSON(row, r.Elem())
|
||||||
if err != nil {
|
|
||||||
return reflect.Value{}, fmt.Errorf("failed to decode row: %v ", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 处理解码错误。
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, fmt.Errorf("failed to decode row: %v", err)
|
||||||
|
}
|
||||||
|
// 将该值附加到新的切片中。
|
||||||
slice = reflect.Append(slice, elem)
|
slice = reflect.Append(slice, elem)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 返回最终的切片。
|
||||||
return slice, nil
|
return slice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,15 +61,21 @@ func decodeStruct(row common.KisRow, elemType reflect.Type) (reflect.Value, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 尝试直接反序列化字符串
|
// 尝试直接反序列化字符串
|
||||||
func decodeString(row common.KisRow) (reflect.Value, error) {
|
func decodeString(row common.KisRow, elemType reflect.Type) (reflect.Value, error) {
|
||||||
if str, ok := row.(string); ok {
|
str, ok := row.(string)
|
||||||
var intValue int
|
if !ok {
|
||||||
if _, err := fmt.Sscanf(str, "%d", &intValue); err == nil {
|
return reflect.Value{}, fmt.Errorf("not a string")
|
||||||
return reflect.ValueOf(intValue), nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return reflect.Value{}, fmt.Errorf("not a string ")
|
// 创建一个新的结构体实例,用于存储反序列化后的值。
|
||||||
|
elem := reflect.New(elemType).Elem()
|
||||||
|
|
||||||
|
// 尝试将字符串反序列化为结构体。
|
||||||
|
if err := json.Unmarshal([]byte(str), elem.Addr().Interface()); err != nil {
|
||||||
|
return reflect.Value{}, fmt.Errorf("failed to unmarshal string to struct: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return elem, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 尝试先序列化为 JSON 再反序列化
|
// 尝试先序列化为 JSON 再反序列化
|
Loading…
Reference in New Issue
Block a user