mirror of
https://github.com/aceld/kis-flow.git
synced 2025-01-23 07:30:23 +08:00
feat: automatic injection parameters
This commit is contained in:
parent
fa62729026
commit
76f610dffc
82
kis/faas.go
Normal file
82
kis/faas.go
Normal file
@ -0,0 +1,82 @@
|
||||
package kis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// FaaS Function as a Service
|
||||
// type FaaS func(context.Context, *kisflow, ...interface{}) error
|
||||
// 这是一个方法类型,会在注入时在方法内判断
|
||||
type FaaS interface{}
|
||||
|
||||
type FaaSDesc struct {
|
||||
FnName string
|
||||
f interface{}
|
||||
fName string
|
||||
ArgsType []reflect.Type
|
||||
ArgNum int
|
||||
FuncType reflect.Type
|
||||
FuncValue reflect.Value
|
||||
FaasSerialize
|
||||
}
|
||||
|
||||
var globalFaaSSerialize = &DefaultFaasSerialize{}
|
||||
|
||||
func NewFaaSDesc(fnName string, f FaaS) (*FaaSDesc, error) {
|
||||
funcValue := reflect.ValueOf(f)
|
||||
funcType := funcValue.Type()
|
||||
|
||||
if err := validateFuncType(funcType, funcValue); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
argsType := make([]reflect.Type, funcType.NumIn())
|
||||
fullName := runtime.FuncForPC(funcValue.Pointer()).Name()
|
||||
containsKisflowCtx := false
|
||||
|
||||
for i := 0; i < funcType.NumIn(); i++ {
|
||||
paramType := funcType.In(i)
|
||||
fmt.Println(paramType.Kind(), isFlowType(paramType))
|
||||
if isFlowType(paramType) {
|
||||
containsKisflowCtx = true
|
||||
}
|
||||
argsType[i] = paramType
|
||||
}
|
||||
|
||||
if !containsKisflowCtx {
|
||||
return nil, errors.New("function parameters must have Kisflow context")
|
||||
}
|
||||
|
||||
if funcType.NumOut() != 1 || funcType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
|
||||
return nil, errors.New("function must have exactly one return value of type error")
|
||||
}
|
||||
|
||||
// Check if f implements FaasSerialize interface
|
||||
var serializeImpl FaasSerialize
|
||||
if ser, ok := f.(FaasSerialize); ok {
|
||||
serializeImpl = ser
|
||||
} else {
|
||||
serializeImpl = globalFaaSSerialize // Use global default implementation
|
||||
}
|
||||
|
||||
return &FaaSDesc{
|
||||
FnName: fnName,
|
||||
f: f,
|
||||
fName: fullName,
|
||||
ArgsType: argsType,
|
||||
ArgNum: len(argsType),
|
||||
FuncType: funcType,
|
||||
FuncValue: funcValue,
|
||||
FaasSerialize: serializeImpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func validateFuncType(funcType reflect.Type, funcValue reflect.Value) error {
|
||||
if funcType.Kind() != reflect.Func {
|
||||
return fmt.Errorf("provided FaaS type is %s, not a function", funcType.Name())
|
||||
}
|
||||
return nil
|
||||
}
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"kis-flow/common"
|
||||
"kis-flow/config"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -49,3 +50,9 @@ type Flow interface {
|
||||
// Fork 得到Flow的一个副本(深拷贝)
|
||||
Fork(ctx context.Context) Flow
|
||||
}
|
||||
|
||||
var flowInterfaceType = reflect.TypeOf((*Flow)(nil)).Elem()
|
||||
|
||||
func isFlowType(paramType reflect.Type) bool {
|
||||
return paramType.Implements(flowInterfaceType)
|
||||
}
|
||||
|
43
kis/pool.go
43
kis/pool.go
@ -6,12 +6,13 @@ import (
|
||||
"fmt"
|
||||
"kis-flow/common"
|
||||
"kis-flow/log"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var _poolOnce sync.Once
|
||||
|
||||
// kisPool 用于管理全部的Function和Flow配置的池子
|
||||
// kisPool 用于管理全部的Function和Flow配置的池子
|
||||
type kisPool struct {
|
||||
fnRouter funcRouter // 全部的Function管理路由
|
||||
fnLock sync.RWMutex // fnRouter 锁
|
||||
@ -76,11 +77,17 @@ func (pool *kisPool) GetFlow(name string) Flow {
|
||||
|
||||
// FaaS 注册 Function 计算业务逻辑, 通过Function Name 索引及注册
|
||||
func (pool *kisPool) FaaS(fnName string, f FaaS) {
|
||||
|
||||
faaSDesc, err := NewFaaSDesc(fnName, f)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pool.fnLock.Lock() // 写锁
|
||||
defer pool.fnLock.Unlock()
|
||||
|
||||
if _, ok := pool.fnRouter[fnName]; !ok {
|
||||
pool.fnRouter[fnName] = f
|
||||
pool.fnRouter[fnName] = faaSDesc
|
||||
} else {
|
||||
errString := fmt.Sprintf("KisPoll FaaS Repeat FuncName=%s", fnName)
|
||||
panic(errString)
|
||||
@ -91,11 +98,37 @@ func (pool *kisPool) FaaS(fnName string, f FaaS) {
|
||||
|
||||
// CallFunction 调度 Function
|
||||
func (pool *kisPool) CallFunction(ctx context.Context, fnName string, flow Flow) error {
|
||||
if funcDesc, ok := pool.fnRouter[fnName]; ok {
|
||||
params := make([]reflect.Value, 0, funcDesc.ArgNum)
|
||||
|
||||
for _, argType := range funcDesc.ArgsType {
|
||||
if isFlowType(argType) {
|
||||
params = append(params, reflect.ValueOf(flow))
|
||||
continue
|
||||
}
|
||||
if isContextType(argType) {
|
||||
params = append(params, reflect.ValueOf(ctx))
|
||||
continue
|
||||
}
|
||||
if argType.Kind() == reflect.Slice {
|
||||
value, err := funcDesc.FaasSerialize.DecodeParam(flow.Input(), argType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params = append(params, value)
|
||||
continue
|
||||
}
|
||||
params = append(params, reflect.Zero(argType))
|
||||
}
|
||||
|
||||
retValues := funcDesc.FuncValue.Call(params)
|
||||
ret := retValues[0].Interface()
|
||||
if ret == nil {
|
||||
return nil
|
||||
}
|
||||
return retValues[0].Interface().(error)
|
||||
|
||||
if f, ok := pool.fnRouter[fnName]; ok {
|
||||
return f(ctx, flow)
|
||||
}
|
||||
|
||||
log.Logger().ErrorFX(ctx, "FuncName: %s Can not find in KisPool, Not Added.\n", fnName)
|
||||
|
||||
return errors.New("FuncName: " + fnName + " Can not find in NsPool, Not Added.")
|
||||
|
@ -8,13 +8,10 @@ import (
|
||||
/*
|
||||
Function Call
|
||||
*/
|
||||
// FaaS Function as a Service
|
||||
type FaaS func(context.Context, Flow) error
|
||||
|
||||
// funcRouter
|
||||
// key: Function Name
|
||||
// value: Function 回调自定义业务
|
||||
type funcRouter map[string]FaaS
|
||||
// value: FaaSDesc 回调自定义业务的描述
|
||||
type funcRouter map[string]*FaaSDesc
|
||||
|
||||
// flowRouter
|
||||
// key: Flow Name
|
||||
@ -28,7 +25,7 @@ type flowRouter map[string]Flow
|
||||
type ConnInit func(conn Connector) error
|
||||
|
||||
// connInitRouter
|
||||
//key:
|
||||
// key:
|
||||
type connInitRouter map[string]ConnInit
|
||||
|
||||
/*
|
||||
|
11
kis/serialize.go
Normal file
11
kis/serialize.go
Normal file
@ -0,0 +1,11 @@
|
||||
package kis
|
||||
|
||||
import (
|
||||
"kis-flow/common"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type FaasSerialize interface {
|
||||
DecodeParam(common.KisRowArr, reflect.Type) (reflect.Value, error)
|
||||
EncodeParam(interface{}) (common.KisRowArr, error)
|
||||
}
|
110
kis/serialize_json.go
Normal file
110
kis/serialize_json.go
Normal file
@ -0,0 +1,110 @@
|
||||
package kis
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"kis-flow/common"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type DefaultFaasSerialize struct {
|
||||
}
|
||||
|
||||
func (f DefaultFaasSerialize) DecodeParam(arr common.KisRowArr, r reflect.Type) (reflect.Value, error) {
|
||||
// 确保传入的类型是一个切片
|
||||
if r.Kind() != reflect.Slice {
|
||||
return reflect.Value{}, fmt.Errorf("r must be a slice")
|
||||
}
|
||||
|
||||
slice := reflect.MakeSlice(r, 0, len(arr))
|
||||
|
||||
for _, row := range arr {
|
||||
var elem reflect.Value
|
||||
var err error
|
||||
|
||||
// 先尝试断言为结构体或指针
|
||||
elem, err = decodeStruct(row, r.Elem())
|
||||
if err != nil {
|
||||
// 如果失败,则尝试直接反序列化为字符串
|
||||
elem, err = decodeString(row)
|
||||
if err != nil {
|
||||
fmt.Println("---+++-", row)
|
||||
// 如果还失败,则尝试先序列化为 JSON 再反序列化
|
||||
elem, err = decodeJSON(row, r.Elem())
|
||||
if err != nil {
|
||||
return reflect.Value{}, fmt.Errorf("failed to decode row: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slice = reflect.Append(slice, elem)
|
||||
}
|
||||
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
// 尝试断言为结构体或指针
|
||||
func decodeStruct(row common.KisRow, elemType reflect.Type) (reflect.Value, error) {
|
||||
elem := reflect.New(elemType).Elem()
|
||||
|
||||
// 如果元素是一个结构体或指针类型,则尝试断言
|
||||
if structElem, ok := row.(reflect.Value); ok && structElem.Type().AssignableTo(elemType) {
|
||||
elem.Set(structElem)
|
||||
return elem, nil
|
||||
}
|
||||
|
||||
return reflect.Value{}, fmt.Errorf("not a struct or pointer")
|
||||
}
|
||||
|
||||
// 尝试直接反序列化字符串
|
||||
func decodeString(row common.KisRow) (reflect.Value, error) {
|
||||
if str, ok := row.(string); ok {
|
||||
var intValue int
|
||||
if _, err := fmt.Sscanf(str, "%d", &intValue); err == nil {
|
||||
return reflect.ValueOf(intValue), nil
|
||||
}
|
||||
}
|
||||
|
||||
return reflect.Value{}, fmt.Errorf("not a string")
|
||||
}
|
||||
|
||||
// 尝试先序列化为 JSON 再反序列化
|
||||
func decodeJSON(row common.KisRow, elemType reflect.Type) (reflect.Value, error) {
|
||||
jsonBytes, err := json.Marshal(row)
|
||||
if err != nil {
|
||||
return reflect.Value{}, fmt.Errorf("failed to marshal row to JSON: %v", err)
|
||||
}
|
||||
|
||||
elem := reflect.New(elemType).Interface()
|
||||
if err := json.Unmarshal(jsonBytes, elem); err != nil {
|
||||
return reflect.Value{}, fmt.Errorf("failed to unmarshal JSON to element: %v", err)
|
||||
}
|
||||
|
||||
return reflect.ValueOf(elem).Elem(), nil
|
||||
}
|
||||
|
||||
func (f DefaultFaasSerialize) EncodeParam(i interface{}) (common.KisRowArr, error) {
|
||||
var arr common.KisRowArr
|
||||
|
||||
switch reflect.TypeOf(i).Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
slice := reflect.ValueOf(i)
|
||||
for i := 0; i < slice.Len(); i++ {
|
||||
// 序列化每个元素为 JSON 字符串,并将其添加到切片中。
|
||||
jsonBytes, err := json.Marshal(slice.Index(i).Interface())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal element to JSON: %v", err)
|
||||
}
|
||||
arr = append(arr, string(jsonBytes))
|
||||
}
|
||||
default:
|
||||
// 如果不是切片或数组类型,则直接序列化整个结构体为 JSON 字符串。
|
||||
jsonBytes, err := json.Marshal(i)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal element to JSON: %v", err)
|
||||
}
|
||||
arr = append(arr, string(jsonBytes))
|
||||
}
|
||||
|
||||
return arr, nil
|
||||
}
|
11
kis/utils.go
Normal file
11
kis/utils.go
Normal file
@ -0,0 +1,11 @@
|
||||
:package kis
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func isContextType(paramType reflect.Type) bool {
|
||||
typeName := paramType.Name()
|
||||
return strings.Contains(typeName, "Context")
|
||||
}
|
Loading…
Reference in New Issue
Block a user