From 76f610dffc426b31f5310979cd487409eebc518d Mon Sep 17 00:00:00 2001 From: huchenhao Date: Tue, 19 Mar 2024 19:16:03 +0800 Subject: [PATCH] feat: automatic injection parameters --- kis/faas.go | 82 +++++++++++++++++++++++++++++++ kis/flow.go | 7 +++ kis/pool.go | 43 +++++++++++++++-- kis/router.go | 9 ++-- kis/serialize.go | 11 +++++ kis/serialize_json.go | 110 ++++++++++++++++++++++++++++++++++++++++++ kis/utils.go | 11 +++++ 7 files changed, 262 insertions(+), 11 deletions(-) create mode 100644 kis/faas.go create mode 100644 kis/serialize.go create mode 100644 kis/serialize_json.go create mode 100644 kis/utils.go diff --git a/kis/faas.go b/kis/faas.go new file mode 100644 index 0000000..388cf22 --- /dev/null +++ b/kis/faas.go @@ -0,0 +1,82 @@ +package kis + +import ( + "errors" + "fmt" + "reflect" + "runtime" +) + +// FaaS Function as a Service +// type FaaS func(context.Context, *kisflow, ...interface{}) error +// 这是一个方法类型,会在注入时在方法内判断 +type FaaS interface{} + +type FaaSDesc struct { + FnName string + f interface{} + fName string + ArgsType []reflect.Type + ArgNum int + FuncType reflect.Type + FuncValue reflect.Value + FaasSerialize +} + +var globalFaaSSerialize = &DefaultFaasSerialize{} + +func NewFaaSDesc(fnName string, f FaaS) (*FaaSDesc, error) { + funcValue := reflect.ValueOf(f) + funcType := funcValue.Type() + + if err := validateFuncType(funcType, funcValue); err != nil { + return nil, err + } + + argsType := make([]reflect.Type, funcType.NumIn()) + fullName := runtime.FuncForPC(funcValue.Pointer()).Name() + containsKisflowCtx := false + + for i := 0; i < funcType.NumIn(); i++ { + paramType := funcType.In(i) + fmt.Println(paramType.Kind(), isFlowType(paramType)) + if isFlowType(paramType) { + containsKisflowCtx = true + } + argsType[i] = paramType + } + + if !containsKisflowCtx { + return nil, errors.New("function parameters must have Kisflow context") + } + + if funcType.NumOut() != 1 || funcType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() { + return nil, errors.New("function must have exactly one return value of type error") + } + + // Check if f implements FaasSerialize interface + var serializeImpl FaasSerialize + if ser, ok := f.(FaasSerialize); ok { + serializeImpl = ser + } else { + serializeImpl = globalFaaSSerialize // Use global default implementation + } + + return &FaaSDesc{ + FnName: fnName, + f: f, + fName: fullName, + ArgsType: argsType, + ArgNum: len(argsType), + FuncType: funcType, + FuncValue: funcValue, + FaasSerialize: serializeImpl, + }, nil +} + +func validateFuncType(funcType reflect.Type, funcValue reflect.Value) error { + if funcType.Kind() != reflect.Func { + return fmt.Errorf("provided FaaS type is %s, not a function", funcType.Name()) + } + return nil +} diff --git a/kis/flow.go b/kis/flow.go index 6c631f1..99ce722 100644 --- a/kis/flow.go +++ b/kis/flow.go @@ -4,6 +4,7 @@ import ( "context" "kis-flow/common" "kis-flow/config" + "reflect" "time" ) @@ -49,3 +50,9 @@ type Flow interface { // Fork 得到Flow的一个副本(深拷贝) Fork(ctx context.Context) Flow } + +var flowInterfaceType = reflect.TypeOf((*Flow)(nil)).Elem() + +func isFlowType(paramType reflect.Type) bool { + return paramType.Implements(flowInterfaceType) +} diff --git a/kis/pool.go b/kis/pool.go index b1a9b51..2c926e8 100644 --- a/kis/pool.go +++ b/kis/pool.go @@ -6,12 +6,13 @@ import ( "fmt" "kis-flow/common" "kis-flow/log" + "reflect" "sync" ) var _poolOnce sync.Once -// kisPool 用于管理全部的Function和Flow配置的池子 +// kisPool 用于管理全部的Function和Flow配置的池子 type kisPool struct { fnRouter funcRouter // 全部的Function管理路由 fnLock sync.RWMutex // fnRouter 锁 @@ -76,11 +77,17 @@ func (pool *kisPool) GetFlow(name string) Flow { // FaaS 注册 Function 计算业务逻辑, 通过Function Name 索引及注册 func (pool *kisPool) FaaS(fnName string, f FaaS) { + + faaSDesc, err := NewFaaSDesc(fnName, f) + if err != nil { + panic(err) + } + pool.fnLock.Lock() // 写锁 defer pool.fnLock.Unlock() if _, ok := pool.fnRouter[fnName]; !ok { - pool.fnRouter[fnName] = f + pool.fnRouter[fnName] = faaSDesc } else { errString := fmt.Sprintf("KisPoll FaaS Repeat FuncName=%s", fnName) panic(errString) @@ -91,11 +98,37 @@ func (pool *kisPool) FaaS(fnName string, f FaaS) { // CallFunction 调度 Function func (pool *kisPool) CallFunction(ctx context.Context, fnName string, flow Flow) error { + if funcDesc, ok := pool.fnRouter[fnName]; ok { + params := make([]reflect.Value, 0, funcDesc.ArgNum) + + for _, argType := range funcDesc.ArgsType { + if isFlowType(argType) { + params = append(params, reflect.ValueOf(flow)) + continue + } + if isContextType(argType) { + params = append(params, reflect.ValueOf(ctx)) + continue + } + if argType.Kind() == reflect.Slice { + value, err := funcDesc.FaasSerialize.DecodeParam(flow.Input(), argType) + if err != nil { + return err + } + params = append(params, value) + continue + } + params = append(params, reflect.Zero(argType)) + } + + retValues := funcDesc.FuncValue.Call(params) + ret := retValues[0].Interface() + if ret == nil { + return nil + } + return retValues[0].Interface().(error) - if f, ok := pool.fnRouter[fnName]; ok { - return f(ctx, flow) } - log.Logger().ErrorFX(ctx, "FuncName: %s Can not find in KisPool, Not Added.\n", fnName) return errors.New("FuncName: " + fnName + " Can not find in NsPool, Not Added.") diff --git a/kis/router.go b/kis/router.go index a644e37..6943d18 100644 --- a/kis/router.go +++ b/kis/router.go @@ -8,13 +8,10 @@ import ( /* Function Call */ -// FaaS Function as a Service -type FaaS func(context.Context, Flow) error - // funcRouter // key: Function Name -// value: Function 回调自定义业务 -type funcRouter map[string]FaaS +// value: FaaSDesc 回调自定义业务的描述 +type funcRouter map[string]*FaaSDesc // flowRouter // key: Flow Name @@ -28,7 +25,7 @@ type flowRouter map[string]Flow type ConnInit func(conn Connector) error // connInitRouter -//key: +// key: type connInitRouter map[string]ConnInit /* diff --git a/kis/serialize.go b/kis/serialize.go new file mode 100644 index 0000000..f441e76 --- /dev/null +++ b/kis/serialize.go @@ -0,0 +1,11 @@ +package kis + +import ( + "kis-flow/common" + "reflect" +) + +type FaasSerialize interface { + DecodeParam(common.KisRowArr, reflect.Type) (reflect.Value, error) + EncodeParam(interface{}) (common.KisRowArr, error) +} diff --git a/kis/serialize_json.go b/kis/serialize_json.go new file mode 100644 index 0000000..80900f8 --- /dev/null +++ b/kis/serialize_json.go @@ -0,0 +1,110 @@ +package kis + +import ( + "encoding/json" + "fmt" + "kis-flow/common" + "reflect" +) + +type DefaultFaasSerialize struct { +} + +func (f DefaultFaasSerialize) DecodeParam(arr common.KisRowArr, r reflect.Type) (reflect.Value, error) { + // 确保传入的类型是一个切片 + if r.Kind() != reflect.Slice { + return reflect.Value{}, fmt.Errorf("r must be a slice") + } + + slice := reflect.MakeSlice(r, 0, len(arr)) + + for _, row := range arr { + var elem reflect.Value + var err error + + // 先尝试断言为结构体或指针 + elem, err = decodeStruct(row, r.Elem()) + if err != nil { + // 如果失败,则尝试直接反序列化为字符串 + elem, err = decodeString(row) + if err != nil { + fmt.Println("---+++-", row) + // 如果还失败,则尝试先序列化为 JSON 再反序列化 + elem, err = decodeJSON(row, r.Elem()) + if err != nil { + return reflect.Value{}, fmt.Errorf("failed to decode row: %v", err) + } + } + } + + slice = reflect.Append(slice, elem) + } + + return slice, nil +} + +// 尝试断言为结构体或指针 +func decodeStruct(row common.KisRow, elemType reflect.Type) (reflect.Value, error) { + elem := reflect.New(elemType).Elem() + + // 如果元素是一个结构体或指针类型,则尝试断言 + if structElem, ok := row.(reflect.Value); ok && structElem.Type().AssignableTo(elemType) { + elem.Set(structElem) + return elem, nil + } + + return reflect.Value{}, fmt.Errorf("not a struct or pointer") +} + +// 尝试直接反序列化字符串 +func decodeString(row common.KisRow) (reflect.Value, error) { + if str, ok := row.(string); ok { + var intValue int + if _, err := fmt.Sscanf(str, "%d", &intValue); err == nil { + return reflect.ValueOf(intValue), nil + } + } + + return reflect.Value{}, fmt.Errorf("not a string") +} + +// 尝试先序列化为 JSON 再反序列化 +func decodeJSON(row common.KisRow, elemType reflect.Type) (reflect.Value, error) { + jsonBytes, err := json.Marshal(row) + if err != nil { + return reflect.Value{}, fmt.Errorf("failed to marshal row to JSON: %v", err) + } + + elem := reflect.New(elemType).Interface() + if err := json.Unmarshal(jsonBytes, elem); err != nil { + return reflect.Value{}, fmt.Errorf("failed to unmarshal JSON to element: %v", err) + } + + return reflect.ValueOf(elem).Elem(), nil +} + +func (f DefaultFaasSerialize) EncodeParam(i interface{}) (common.KisRowArr, error) { + var arr common.KisRowArr + + switch reflect.TypeOf(i).Kind() { + case reflect.Slice, reflect.Array: + slice := reflect.ValueOf(i) + for i := 0; i < slice.Len(); i++ { + // 序列化每个元素为 JSON 字符串,并将其添加到切片中。 + jsonBytes, err := json.Marshal(slice.Index(i).Interface()) + if err != nil { + return nil, fmt.Errorf("failed to marshal element to JSON: %v", err) + } + arr = append(arr, string(jsonBytes)) + } + default: + // 如果不是切片或数组类型,则直接序列化整个结构体为 JSON 字符串。 + jsonBytes, err := json.Marshal(i) + if err != nil { + return nil, fmt.Errorf("failed to marshal element to JSON: %v", err) + } + arr = append(arr, string(jsonBytes)) + } + + return arr, nil +} diff --git a/kis/utils.go b/kis/utils.go new file mode 100644 index 0000000..523071d --- /dev/null +++ b/kis/utils.go @@ -0,0 +1,11 @@ +:package kis + +import ( + "reflect" + "strings" +) + +func isContextType(paramType reflect.Type) bool { + typeName := paramType.Name() + return strings.Contains(typeName, "Context") +}