go-zero/tools/goctl/api/gogen/genlogic.go

161 lines
3.9 KiB
Go
Raw Normal View History

2020-07-29 17:11:41 +08:00
package gogen
import (
"fmt"
"path"
"strconv"
2020-07-29 17:11:41 +08:00
"strings"
"github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
2020-08-08 16:40:10 +08:00
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util/pathx"
2020-08-08 16:40:10 +08:00
"github.com/tal-tech/go-zero/tools/goctl/vars"
2020-07-29 17:11:41 +08:00
)
// logicTemplate is the built-in text/template source for a generated logic
// file. Template fields:
//   - pkgName:      Go package name of the generated file
//   - imports:      body of the import block (first line placed after a tab)
//   - logic:        name of the generated logic struct
//   - function:     name of the handler method on that struct
//   - request:      parameter list of the method (may be empty)
//   - responseType: result list of the method
//   - returnString: placeholder return statement
//
// Everything in the raw string is emitted verbatim, so it must contain only
// valid Go source.
const logicTemplate = `package {{.pkgName}}

import (
	{{.imports}}
)

type {{.logic}} struct {
	logx.Logger
	ctx    context.Context
	svcCtx *svc.ServiceContext
}

func New{{.logic}}(ctx context.Context, svcCtx *svc.ServiceContext) {{.logic}} {
	return {{.logic}}{
		Logger: logx.WithContext(ctx),
		ctx:    ctx,
		svcCtx: svcCtx,
	}
}

func (l *{{.logic}}) {{.function}}({{.request}}) {{.responseType}} {
	// todo: add your logic here and delete this line

	{{.returnString}}
}
`
func genLogic(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
2020-07-29 17:11:41 +08:00
for _, g := range api.Service.Groups {
for _, r := range g.Routes {
err := genLogicByRoute(dir, rootPkg, cfg, g, r)
2020-07-29 17:11:41 +08:00
if err != nil {
return err
}
}
}
return nil
}
func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
logic := getLogicName(route)
goFile, err := format.FileNamingFormat(cfg.NamingFormat, logic)
if err != nil {
return err
2020-07-29 17:11:41 +08:00
}
2020-09-03 14:00:09 +08:00
imports := genLogicImports(route, rootPkg)
2020-09-03 14:00:09 +08:00
var responseString string
var returnString string
var requestString string
if len(route.ResponseTypeName()) > 0 {
resp := responseGoTypeName(route, typesPacket)
responseString = "(resp " + resp + ", err error)"
returnString = "return"
2020-07-29 17:11:41 +08:00
} else {
responseString = "error"
returnString = "return nil"
}
if len(route.RequestTypeName()) > 0 {
requestString = "req " + requestGoTypeName(route, typesPacket)
2020-07-29 17:11:41 +08:00
}
subDir := getLogicFolderPath(group, route)
2021-01-09 00:17:23 +08:00
return genFile(fileGenConfig{
dir: dir,
subdir: subDir,
2021-01-09 00:17:23 +08:00
filename: goFile + ".go",
templateName: "logicTemplate",
category: category,
templateFile: logicTemplateFile,
builtinTemplate: logicTemplate,
data: map[string]string{
"pkgName": subDir[strings.LastIndex(subDir, "/")+1:],
2021-01-09 00:17:23 +08:00
"imports": imports,
"logic": strings.Title(logic),
"function": strings.Title(strings.TrimSuffix(logic, "Logic")),
"responseType": responseString,
"returnString": returnString,
"request": requestString,
},
2020-07-29 17:11:41 +08:00
})
}
func getLogicFolderPath(group spec.Group, route spec.Route) string {
folder := route.GetAnnotation(groupProperty)
if len(folder) == 0 {
folder = group.GetAnnotation(groupProperty)
if len(folder) == 0 {
2020-07-29 17:11:41 +08:00
return logicDir
}
}
folder = strings.TrimPrefix(folder, "/")
folder = strings.TrimSuffix(folder, "/")
return path.Join(logicDir, folder)
}
func genLogicImports(route spec.Route, parentPkg string) string {
var imports []string
imports = append(imports, `"context"`+"\n")
2020-08-27 14:40:05 +08:00
imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, contextDir)))
if shallImportTypesPackage(route) {
2020-08-27 14:40:05 +08:00
imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir)))
2020-07-29 17:11:41 +08:00
}
2021-02-20 19:50:03 +08:00
imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL))
2020-07-29 17:11:41 +08:00
return strings.Join(imports, "\n\t")
}
// onlyPrimitiveTypes reports whether the type expression val is built only
// from types accepted by api.IsBasicType. The expression is tokenized on
// '[', ']' and ' '. The "map" keyword and integer tokens (array lengths such
// as the 5 in [5]int) are skipped. Every remaining token must be a basic
// type. An expression with no tokens at all, such as "", yields true.
func onlyPrimitiveTypes(val string) bool {
	isDelimiter := func(r rune) bool {
		switch r {
		case '[', ']', ' ':
			return true
		default:
			return false
		}
	}

	for _, token := range strings.FieldsFunc(val, isDelimiter) {
		if token == "map" {
			continue
		}
		// Array dimensions are numbers, not types.
		if _, convErr := strconv.Atoi(token); convErr == nil {
			continue
		}
		if !api.IsBasicType(token) {
			return false
		}
	}
	return true
}
// shallImportTypesPackage reports whether a generated logic file for route
// must import the project's types package. This is always the case when the
// route has a request type. Otherwise it is needed only when the route has a
// response type that is not composed solely of basic types (as decided by
// onlyPrimitiveTypes).
func shallImportTypesPackage(route spec.Route) bool {
	if len(route.RequestTypeName()) > 0 {
		return true
	}

	responseType := route.ResponseTypeName()
	return len(responseType) > 0 && !onlyPrimitiveTypes(responseType)
}