Code optimized (#493)

This commit is contained in:
kingxt 2021-02-20 19:50:03 +08:00 committed by GitHub
parent 059027bc9d
commit f98c9246b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 472 additions and 372 deletions

View File

@ -39,6 +39,7 @@ service {{.serviceName}} {
} }
` `
// ApiCommand create api template file
func ApiCommand(c *cli.Context) error { func ApiCommand(c *cli.Context) error {
apiFile := c.String("o") apiFile := c.String("o")
if len(apiFile) == 0 { if len(apiFile) == 0 {

View File

@ -9,6 +9,7 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
) )
// DartCommand create dart network request code
func DartCommand(c *cli.Context) error { func DartCommand(c *cli.Context) error {
apiFile := c.String("api") apiFile := c.String("api")
dir := c.String("dir") dir := c.String("dir")

View File

@ -2,9 +2,9 @@ package dartgen
import ( import (
"os" "os"
"reflect"
"strings" "strings"
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
) )
@ -32,10 +32,18 @@ func pathToFuncName(path string) string {
return util.ToLower(camel[:1]) + camel[1:] return util.ToLower(camel[:1]) + camel[1:]
} }
func tagGet(tag, k string) (reflect.Value, error) { func tagGet(tag, k string) string {
v, _ := util.TagLookup(tag, k) tags, err := spec.Parse(tag)
out := strings.Split(v, ",")[0] if err != nil {
return reflect.ValueOf(out), nil panic(k + " not exist")
}
v, err := tags.Get(k)
if err != nil {
panic(k + " value not exist")
}
return v.Name
} }
func isDirectType(s string) bool { func isDirectType(s string) bool {

View File

@ -12,6 +12,7 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
) )
// DocCommand generate markdown doc file
func DocCommand(c *cli.Context) error { func DocCommand(c *cli.Context) error {
dir := c.String("dir") dir := c.String("dir")
if len(dir) == 0 { if len(dir) == 0 {

View File

@ -160,6 +160,7 @@ type Response struct {
@server( @server(
jwt: Auth jwt: Auth
signature: true
) )
service A-api { service A-api {
@handler GreetHandler @handler GreetHandler

View File

@ -40,7 +40,7 @@ func genConfig(dir string, cfg *config.Config, api *spec.ApiSpec) error {
for _, item := range authNames { for _, item := range authNames {
auths = append(auths, fmt.Sprintf("%s %s", item, jwtTemplate)) auths = append(auths, fmt.Sprintf("%s %s", item, jwtTemplate))
} }
var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl) var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL)
return genFile(fileGenConfig{ return genFile(fileGenConfig{
dir: dir, dir: dir,

View File

@ -109,7 +109,7 @@ func genHandlerImports(group spec.Group, route spec.Route, parentPkg string) str
if len(route.RequestTypeName()) > 0 { if len(route.RequestTypeName()) > 0 {
imports = append(imports, fmt.Sprintf("\"%s\"\n", util.JoinPackages(parentPkg, typesDir))) imports = append(imports, fmt.Sprintf("\"%s\"\n", util.JoinPackages(parentPkg, typesDir)))
} }
imports = append(imports, fmt.Sprintf("\"%s/rest/httpx\"", vars.ProjectOpenSourceUrl)) imports = append(imports, fmt.Sprintf("\"%s/rest/httpx\"", vars.ProjectOpenSourceURL))
return strings.Join(imports, "\n\t") return strings.Join(imports, "\n\t")
} }

View File

@ -122,6 +122,6 @@ func genLogicImports(route spec.Route, parentPkg string) string {
if len(route.ResponseTypeName()) > 0 || len(route.RequestTypeName()) > 0 { if len(route.ResponseTypeName()) > 0 || len(route.RequestTypeName()) > 0 {
imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir))) imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir)))
} }
imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceUrl)) imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL))
return strings.Join(imports, "\n\t") return strings.Join(imports, "\n\t")
} }

View File

@ -74,7 +74,7 @@ func genMainImports(parentPkg string) string {
imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, configDir))) imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, configDir)))
imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, handlerDir))) imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, handlerDir)))
imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, contextDir))) imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, contextDir)))
imports = append(imports, fmt.Sprintf("\"%s/core/conf\"", vars.ProjectOpenSourceUrl)) imports = append(imports, fmt.Sprintf("\"%s/core/conf\"", vars.ProjectOpenSourceURL))
imports = append(imports, fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl)) imports = append(imports, fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL))
return strings.Join(imports, "\n\t") return strings.Join(imports, "\n\t")
} }

View File

@ -89,7 +89,7 @@ func genRoutes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
} }
var signature string var signature string
if g.signatureEnabled { if g.signatureEnabled {
signature = fmt.Sprintf("\n rest.WithSignature(serverCtx.Config.%s.Signature),", g.authName) signature = "\n rest.WithSignature(serverCtx.Config.Signature),"
} }
var routes string var routes string
@ -163,7 +163,7 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
imports := importSet.KeysStr() imports := importSet.KeysStr()
sort.Strings(imports) sort.Strings(imports)
projectSection := strings.Join(imports, "\n\t") projectSection := strings.Join(imports, "\n\t")
depSection := fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl) depSection := fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL)
return fmt.Sprintf("%s\n\n\t%s", projectSection, depSection) return fmt.Sprintf("%s\n\n\t%s", projectSection, depSection)
} }
@ -196,6 +196,10 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
groupedRoutes.authName = jwt groupedRoutes.authName = jwt
groupedRoutes.jwtEnabled = true groupedRoutes.jwtEnabled = true
} }
signature := g.GetAnnotation("signature")
if signature == "true" {
groupedRoutes.signatureEnabled = true
}
middleware := g.GetAnnotation("middleware") middleware := g.GetAnnotation("middleware")
if len(middleware) > 0 { if len(middleware) > 0 {
for _, item := range strings.Split(middleware, ",") { for _, item := range strings.Split(middleware, ",") {

View File

@ -64,7 +64,7 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\"" var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
if len(middlewareStr) > 0 { if len(middlewareStr) > 0 {
configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\"" configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\""
configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceUrl) configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceURL)
} }
return genFile(fileGenConfig{ return genFile(fileGenConfig{

View File

@ -94,10 +94,6 @@ func getAuths(api *spec.ApiSpec) []string {
if len(jwt) > 0 { if len(jwt) > 0 {
authNames.Add(jwt) authNames.Add(jwt)
} }
signature := g.GetAnnotation("signature")
if len(signature) > 0 {
authNames.Add(signature)
}
} }
return authNames.KeysStr() return authNames.KeysStr()
} }

View File

@ -22,12 +22,6 @@ type Api struct {
} }
func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} { func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
defer func() {
if p := recover(); p != nil {
panic(fmt.Errorf("%+v", p))
}
}()
var final Api var final Api
final.importM = map[string]PlaceHolder{} final.importM = map[string]PlaceHolder{}
final.typeM = map[string]PlaceHolder{} final.typeM = map[string]PlaceHolder{}
@ -36,63 +30,22 @@ func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
final.routeM = map[string]PlaceHolder{} final.routeM = map[string]PlaceHolder{}
for _, each := range ctx.AllSpec() { for _, each := range ctx.AllSpec() {
root := each.Accept(v).(*Api) root := each.Accept(v).(*Api)
if root.Syntax != nil { v.acceptSyntax(root, &final)
if final.Syntax != nil { v.accetpImport(root, &final)
v.panic(root.Syntax.Syntax, fmt.Sprintf("mutiple syntax declaration")) v.acceptInfo(root, &final)
v.acceptType(root, &final)
v.acceptService(root, &final)
} }
final.Syntax = root.Syntax return &final
}
for _, imp := range root.Import {
if _, ok := final.importM[imp.Value.Text()]; ok {
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))
}
final.importM[imp.Value.Text()] = Holder
final.Import = append(final.Import, imp)
}
if root.Info != nil {
infoM := map[string]PlaceHolder{}
if final.Info != nil {
v.panic(root.Info.Info, fmt.Sprintf("mutiple info declaration"))
}
for _, value := range root.Info.Kvs {
if _, ok := infoM[value.Key.Text()]; ok {
v.panic(value.Key, fmt.Sprintf("duplicate key '%s'", value.Key.Text()))
}
infoM[value.Key.Text()] = Holder
}
final.Info = root.Info
}
for _, tp := range root.Type {
if _, ok := final.typeM[tp.NameExpr().Text()]; ok {
v.panic(tp.NameExpr(), fmt.Sprintf("duplicate type '%s'", tp.NameExpr().Text()))
}
final.typeM[tp.NameExpr().Text()] = Holder
final.Type = append(final.Type, tp)
} }
func (v *ApiVisitor) acceptService(root *Api, final *Api) {
for _, service := range root.Service { for _, service := range root.Service {
if _, ok := final.serviceM[service.ServiceApi.Name.Text()]; !ok && len(final.serviceM) > 0 { if _, ok := final.serviceM[service.ServiceApi.Name.Text()]; !ok && len(final.serviceM) > 0 {
v.panic(service.ServiceApi.Name, fmt.Sprintf("mutiple service declaration")) v.panic(service.ServiceApi.Name, fmt.Sprintf("mutiple service declaration"))
} }
v.duplicateServerItemCheck(service)
if service.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range service.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
atServerM[kv.Key.Text()] = Holder
}
}
for _, route := range service.ServiceApi.ServiceRoute { for _, route := range service.ServiceApi.ServiceRoute {
uniqueRoute := fmt.Sprintf("%s %s", route.Route.Method.Text(), route.Route.Path.Text()) uniqueRoute := fmt.Sprintf("%s %s", route.Route.Method.Text(), route.Route.Path.Text())
@ -136,7 +89,67 @@ func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
} }
} }
return &final func (v *ApiVisitor) duplicateServerItemCheck(service *Service) {
if service.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range service.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
atServerM[kv.Key.Text()] = Holder
}
}
}
func (v *ApiVisitor) acceptType(root *Api, final *Api) {
for _, tp := range root.Type {
if _, ok := final.typeM[tp.NameExpr().Text()]; ok {
v.panic(tp.NameExpr(), fmt.Sprintf("duplicate type '%s'", tp.NameExpr().Text()))
}
final.typeM[tp.NameExpr().Text()] = Holder
final.Type = append(final.Type, tp)
}
}
func (v *ApiVisitor) acceptInfo(root *Api, final *Api) {
if root.Info != nil {
infoM := map[string]PlaceHolder{}
if final.Info != nil {
v.panic(root.Info.Info, fmt.Sprintf("mutiple info declaration"))
}
for _, value := range root.Info.Kvs {
if _, ok := infoM[value.Key.Text()]; ok {
v.panic(value.Key, fmt.Sprintf("duplicate key '%s'", value.Key.Text()))
}
infoM[value.Key.Text()] = Holder
}
final.Info = root.Info
}
}
func (v *ApiVisitor) accetpImport(root *Api, final *Api) {
for _, imp := range root.Import {
if _, ok := final.importM[imp.Value.Text()]; ok {
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))
}
final.importM[imp.Value.Text()] = Holder
final.Import = append(final.Import, imp)
}
}
func (v *ApiVisitor) acceptSyntax(root *Api, final *Api) {
if root.Syntax != nil {
if final.Syntax != nil {
v.panic(root.Syntax.Syntax, fmt.Sprintf("mutiple syntax declaration"))
}
final.Syntax = root.Syntax
}
} }
func (v *ApiVisitor) VisitSpec(ctx *api.SpecContext) interface{} { func (v *ApiVisitor) VisitSpec(ctx *api.SpecContext) interface{} {

View File

@ -156,28 +156,9 @@ func (p *Parser) invoke(linePrefix, content string) (v *Api, err error) {
} }
func (p *Parser) valid(mainApi *Api, nestedApi *Api) error { func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
if len(nestedApi.Import) > 0 { err := p.nestedApiCheck(mainApi, nestedApi)
importToken := nestedApi.Import[0].Import if err != nil {
return fmt.Errorf("%s line %d:%d the nested api does not support import", return err
nestedApi.LinePrefix, importToken.Line(), importToken.Column())
}
if mainApi.Syntax != nil && nestedApi.Syntax != nil {
if mainApi.Syntax.Version.Text() != nestedApi.Syntax.Version.Text() {
syntaxToken := nestedApi.Syntax.Syntax
return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'",
nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), mainApi.Syntax.Version.Text(), nestedApi.Syntax.Version.Text())
}
}
if len(mainApi.Service) > 0 {
mainService := mainApi.Service[0]
for _, service := range nestedApi.Service {
if mainService.ServiceApi.Name.Text() != service.ServiceApi.Name.Text() {
return fmt.Errorf("%s multiple service name declaration, expecting service name '%s', but found '%s'",
nestedApi.LinePrefix, mainService.ServiceApi.Name.Text(), service.ServiceApi.Name.Text())
}
}
} }
mainHandlerMap := make(map[string]PlaceHolder) mainHandlerMap := make(map[string]PlaceHolder)
@ -218,6 +199,23 @@ func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
} }
// duplicate route check // duplicate route check
err = p.duplicateRouteCheck(nestedApi, mainHandlerMap, mainRouteMap)
if err != nil {
return err
}
// duplicate type check
for _, each := range nestedApi.Type {
if _, ok := mainTypeMap[each.NameExpr().Text()]; ok {
return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'",
nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text())
}
}
return nil
}
func (p *Parser) duplicateRouteCheck(nestedApi *Api, mainHandlerMap map[string]PlaceHolder, mainRouteMap map[string]PlaceHolder) error {
for _, each := range nestedApi.Service { for _, each := range nestedApi.Service {
for _, r := range each.ServiceApi.ServiceRoute { for _, r := range each.ServiceApi.ServiceRoute {
handler := r.GetHandler() handler := r.GetHandler()
@ -237,12 +235,31 @@ func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
} }
} }
} }
return nil
}
// duplicate type check func (p *Parser) nestedApiCheck(mainApi *Api, nestedApi *Api) error {
for _, each := range nestedApi.Type { if len(nestedApi.Import) > 0 {
if _, ok := mainTypeMap[each.NameExpr().Text()]; ok { importToken := nestedApi.Import[0].Import
return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'", return fmt.Errorf("%s line %d:%d the nested api does not support import",
nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text()) nestedApi.LinePrefix, importToken.Line(), importToken.Column())
}
if mainApi.Syntax != nil && nestedApi.Syntax != nil {
if mainApi.Syntax.Version.Text() != nestedApi.Syntax.Version.Text() {
syntaxToken := nestedApi.Syntax.Syntax
return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'",
nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), mainApi.Syntax.Version.Text(), nestedApi.Syntax.Version.Text())
}
}
if len(mainApi.Service) > 0 {
mainService := mainApi.Service[0]
for _, service := range nestedApi.Service {
if mainService.ServiceApi.Name.Text() != service.ServiceApi.Name.Text() {
return fmt.Errorf("%s multiple service name declaration, expecting service name '%s', but found '%s'",
nestedApi.LinePrefix, mainService.ServiceApi.Name.Text(), service.ServiceApi.Name.Text())
}
} }
} }
return nil return nil
@ -276,29 +293,26 @@ func (p *Parser) checkTypeDeclaration(apiList []*Api) error {
for _, apiItem := range apiList { for _, apiItem := range apiList {
linePrefix := apiItem.LinePrefix linePrefix := apiItem.LinePrefix
for _, each := range apiItem.Type { err := p.checkTypes(apiItem, linePrefix, types)
tp, ok := each.(*TypeStruct) if err != nil {
if !ok { return err
continue
} }
for _, member := range tp.Fields { err = p.checkServices(apiItem, types, linePrefix)
err := p.checkType(linePrefix, types, member.DataType)
if err != nil { if err != nil {
return err return err
} }
} }
return nil
} }
func (p *Parser) checkServices(apiItem *Api, types map[string]TypeExpr, linePrefix string) error {
for _, service := range apiItem.Service { for _, service := range apiItem.Service {
for _, each := range service.ServiceApi.ServiceRoute { for _, each := range service.ServiceApi.ServiceRoute {
route := each.Route route := each.Route
if route.Req != nil && route.Req.Name.IsNotNil() && route.Req.Name.Expr().IsNotNil() { err := p.checkRequestBody(route, types, linePrefix)
_, ok := types[route.Req.Name.Expr().Text()] if err != nil {
if !ok { return err
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, route.Req.Name.Expr().Line(), route.Req.Name.Expr().Column(), route.Req.Name.Expr().Text())
}
} }
if route.Reply != nil && route.Reply.Name.IsNotNil() && route.Reply.Name.Expr().IsNotNil() { if route.Reply != nil && route.Reply.Name.IsNotNil() && route.Reply.Name.Expr().IsNotNil() {
@ -328,6 +342,33 @@ func (p *Parser) checkTypeDeclaration(apiList []*Api) error {
} }
} }
} }
return nil
}
func (p *Parser) checkRequestBody(route *Route, types map[string]TypeExpr, linePrefix string) error {
if route.Req != nil && route.Req.Name.IsNotNil() && route.Req.Name.Expr().IsNotNil() {
_, ok := types[route.Req.Name.Expr().Text()]
if !ok {
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, route.Req.Name.Expr().Line(), route.Req.Name.Expr().Column(), route.Req.Name.Expr().Text())
}
}
return nil
}
func (p *Parser) checkTypes(apiItem *Api, linePrefix string, types map[string]TypeExpr) error {
for _, each := range apiItem.Type {
tp, ok := each.(*TypeStruct)
if !ok {
continue
}
for _, member := range tp.Fields {
err := p.checkType(linePrefix, types, member.DataType)
if err != nil {
return err
}
}
} }
return nil return nil
} }

View File

@ -213,13 +213,7 @@ func (p parser) fillService() error {
var groups []spec.Group var groups []spec.Group
for _, item := range p.ast.Service { for _, item := range p.ast.Service {
var group spec.Group var group spec.Group
if item.AtServer != nil { p.fillAtServer(item, &group)
var properties = make(map[string]string, 0)
for _, kv := range item.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}
group.Annotation.Properties = properties
}
for _, astRoute := range item.ServiceApi.ServiceRoute { for _, astRoute := range item.ServiceApi.ServiceRoute {
route := spec.Route{ route := spec.Route{
@ -231,6 +225,49 @@ func (p parser) fillService() error {
route.Handler = astRoute.AtHandler.Name.Text() route.Handler = astRoute.AtHandler.Name.Text()
} }
err := p.fillRouteAtServer(astRoute, &route)
if err != nil {
return err
}
if astRoute.Route.Req != nil {
route.RequestType = p.astTypeToSpec(astRoute.Route.Req.Name)
}
if astRoute.Route.Reply != nil {
route.ResponseType = p.astTypeToSpec(astRoute.Route.Reply.Name)
}
if astRoute.AtDoc != nil {
var properties = make(map[string]string, 0)
for _, kv := range astRoute.AtDoc.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}
route.AtDoc.Properties = properties
if astRoute.AtDoc.LineDoc != nil {
route.AtDoc.Text = astRoute.AtDoc.LineDoc.Text()
}
}
err = p.fillRouteType(&route)
if err != nil {
return err
}
group.Routes = append(group.Routes, route)
name := item.ServiceApi.Name.Text()
if len(p.spec.Service.Name) > 0 && p.spec.Service.Name != name {
return fmt.Errorf("mulit service name defined %s and %s", name, p.spec.Service.Name)
}
p.spec.Service.Name = name
}
groups = append(groups, group)
}
p.spec.Service.Groups = groups
return nil
}
func (p parser) fillRouteAtServer(astRoute *ast.ServiceRoute, route *spec.Route) error {
if astRoute.AtServer != nil { if astRoute.AtServer != nil {
var properties = make(map[string]string, 0) var properties = make(map[string]string, 0)
for _, kv := range astRoute.AtServer.Kv { for _, kv := range astRoute.AtServer.Kv {
@ -251,44 +288,19 @@ func (p parser) fillService() error {
} }
} }
} }
return nil
}
if astRoute.Route.Req != nil { func (p parser) fillAtServer(item *ast.Service, group *spec.Group) {
route.RequestType = p.astTypeToSpec(astRoute.Route.Req.Name) if item.AtServer != nil {
}
if astRoute.Route.Reply != nil {
route.ResponseType = p.astTypeToSpec(astRoute.Route.Reply.Name)
}
if astRoute.AtDoc != nil {
var properties = make(map[string]string, 0) var properties = make(map[string]string, 0)
for _, kv := range astRoute.AtDoc.Kv { for _, kv := range item.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text() properties[kv.Key.Text()] = kv.Value.Text()
} }
route.AtDoc.Properties = properties group.Annotation.Properties = properties
if astRoute.AtDoc.LineDoc != nil {
route.AtDoc.Text = astRoute.AtDoc.LineDoc.Text()
} }
} }
err := p.fillRouteType(&route)
if err != nil {
return err
}
group.Routes = append(group.Routes, route)
name := item.ServiceApi.Name.Text()
if len(p.spec.Service.Name) > 0 && p.spec.Service.Name != name {
return fmt.Errorf("mulit service name defined %s and %s", name, p.spec.Service.Name)
}
p.spec.Service.Name = name
}
groups = append(groups, group)
}
p.spec.Service.Groups = groups
return nil
}
func (p parser) fillRouteType(route *spec.Route) error { func (p parser) fillRouteType(route *spec.Route) error {
if route.RequestType != nil { if route.RequestType != nil {
switch route.RequestType.(type) { switch route.RequestType.(type) {

View File

@ -1,58 +0,0 @@
package util
import (
"strconv"
"strings"
)
func TagLookup(tag, key string) (value string, ok bool) {
tag = strings.Replace(tag, "`", "", -1)
for tag != "" {
// Skip leading space.
i := 0
for i < len(tag) && tag[i] == ' ' {
i++
}
tag = tag[i:]
if tag == "" {
break
}
// Scan to colon. A space, a quote or a control character is a syntax error.
// Strictly speaking, control chars include the range [0x7f, 0x9f], not just
// [0x00, 0x1f], but in practice, we ignore the multi-byte control characters
// as it is simpler to inspect the tag's bytes than the tag's runes.
i = 0
for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f {
i++
}
if i == 0 || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' {
break
}
name := string(tag[:i])
tag = tag[i+1:]
// Scan quoted string to find value.
i = 1
for i < len(tag) && tag[i] != '"' {
if tag[i] == '\\' {
i++
}
i++
}
if i >= len(tag) {
break
}
qvalue := string(tag[:i+1])
tag = tag[i+1:]
if key == name {
value, err := strconv.Unquote(qvalue)
if err != nil {
break
}
return value, true
}
}
return "", false
}

View File

@ -28,7 +28,7 @@ import (
) )
var ( var (
BuildVersion = "1.1.5" buildVersion = "1.1.5"
commands = []cli.Command{ commands = []cli.Command{
{ {
Name: "upgrade", Name: "upgrade",
@ -510,7 +510,7 @@ func main() {
app := cli.NewApp() app := cli.NewApp()
app.Usage = "a cli tool to generate code" app.Usage = "a cli tool to generate code"
app.Version = fmt.Sprintf("%s %s/%s", BuildVersion, runtime.GOOS, runtime.GOARCH) app.Version = fmt.Sprintf("%s %s/%s", buildVersion, runtime.GOOS, runtime.GOARCH)
app.Commands = commands app.Commands = commands
// cli already print error messages // cli already print error messages
if err := app.Run(os.Args); err != nil { if err := app.Run(os.Args); err != nil {

View File

@ -1,6 +1,7 @@
package gen package gen
import ( import (
"bytes"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -31,7 +32,20 @@ type (
pkg string pkg string
cfg *config.Config cfg *config.Config
} }
Option func(generator *defaultGenerator) Option func(generator *defaultGenerator)
code struct {
importsCode string
varsCode string
typesCode string
newCode string
insertCode string
findCode []string
updateCode string
deleteCode string
cacheExtra string
}
) )
func NewDefaultGenerator(dir string, cfg *config.Config, opt ...Option) (*defaultGenerator, error) { func NewDefaultGenerator(dir string, cfg *config.Config, opt ...Option) (*defaultGenerator, error) {
@ -186,15 +200,6 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return "", fmt.Errorf("table %s: missing primary key", in.Name.Source()) return "", fmt.Errorf("table %s: missing primary key", in.Name.Source())
} }
text, err := util.LoadTemplate(category, modelTemplateFile, template.Model)
if err != nil {
return "", err
}
t := util.With("model").
Parse(text).
GoFmt(true)
m, err := genCacheKeys(in) m, err := genCacheKeys(in)
if err != nil { if err != nil {
return "", err return "", err
@ -261,18 +266,19 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return "", err return "", err
} }
output, err := t.Execute(map[string]interface{}{ code := &code{
"pkg": g.pkg, importsCode: importsCode,
"imports": importsCode, varsCode: varsCode,
"vars": varsCode, typesCode: typesCode,
"types": typesCode, newCode: newCode,
"new": newCode, insertCode: insertCode,
"insert": insertCode, findCode: findCode,
"find": strings.Join(findCode, "\n"), updateCode: updateCode,
"update": updateCode, deleteCode: deleteCode,
"delete": deleteCode, cacheExtra: ret.cacheExtra,
"extraMethod": ret.cacheExtra, }
})
output, err := g.executeModel(code)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -280,6 +286,32 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return output.String(), nil return output.String(), nil
} }
func (g *defaultGenerator) executeModel(code *code) (*bytes.Buffer, error) {
text, err := util.LoadTemplate(category, modelTemplateFile, template.Model)
if err != nil {
return nil, err
}
t := util.With("model").
Parse(text).
GoFmt(true)
output, err := t.Execute(map[string]interface{}{
"pkg": g.pkg,
"imports": code.importsCode,
"vars": code.varsCode,
"types": code.typesCode,
"new": code.newCode,
"insert": code.insertCode,
"find": strings.Join(code.findCode, "\n"),
"update": code.updateCode,
"delete": code.deleteCode,
"extraMethod": code.cacheExtra,
})
if err != nil {
return nil, err
}
return output, nil
}
func wrapWithRawString(v string) string { func wrapWithRawString(v string) string {
if v == "`" { if v == "`" {
return v return v

View File

@ -68,6 +68,71 @@ func Parse(ddl string) (*Table, error) {
columns := tableSpec.Columns columns := tableSpec.Columns
indexes := tableSpec.Indexes indexes := tableSpec.Indexes
keyMap, err := getIndexKeyType(indexes)
if err != nil {
return nil, err
}
fields, primaryKey, err := convertFileds(columns, keyMap)
if err != nil {
return nil, err
}
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
Fields: fields,
}, nil
}
func convertFileds(columns []*sqlparser.ColumnDefinition, keyMap map[string]KeyType) ([]Field, Primary, error) {
var fields []Field
var primaryKey Primary
for _, column := range columns {
if column == nil {
continue
}
var comment string
if column.Type.Comment != nil {
comment = string(column.Type.Comment.Val)
}
var isDefaultNull = true
if column.Type.NotNull {
isDefaultNull = false
} else {
if column.Type.Default == nil {
isDefaultNull = false
} else if string(column.Type.Default.Val) != "null" {
isDefaultNull = false
}
}
dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull)
if err != nil {
return nil, primaryKey, err
}
var field Field
field.Name = stringx.From(column.Name.String())
field.DataBaseType = column.Type.Type
field.DataType = dataType
field.Comment = comment
key, ok := keyMap[column.Name.String()]
if ok {
field.IsPrimaryKey = key == primary
field.IsUniqueKey = key == unique
if field.IsPrimaryKey {
primaryKey.Field = field
if column.Type.Autoincrement {
primaryKey.AutoIncrement = true
}
}
}
fields = append(fields, field)
}
return fields, primaryKey, nil
}
func getIndexKeyType(indexes []*sqlparser.IndexDefinition) (map[string]KeyType, error) {
keyMap := make(map[string]KeyType) keyMap := make(map[string]KeyType)
for _, index := range indexes { for _, index := range indexes {
info := index.Info info := index.Info
@ -101,56 +166,7 @@ func Parse(ddl string) (*Table, error) {
keyMap[columnName] = normal keyMap[columnName] = normal
} }
} }
return keyMap, nil
var fields []Field
var primaryKey Primary
for _, column := range columns {
if column == nil {
continue
}
var comment string
if column.Type.Comment != nil {
comment = string(column.Type.Comment.Val)
}
var isDefaultNull = true
if column.Type.NotNull {
isDefaultNull = false
} else {
if column.Type.Default == nil {
isDefaultNull = false
} else if string(column.Type.Default.Val) != "null" {
isDefaultNull = false
}
}
dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull)
if err != nil {
return nil, err
}
var field Field
field.Name = stringx.From(column.Name.String())
field.DataBaseType = column.Type.Type
field.DataType = dataType
field.Comment = comment
key, ok := keyMap[column.Name.String()]
if ok {
field.IsPrimaryKey = key == primary
field.IsUniqueKey = key == unique
if field.IsPrimaryKey {
primaryKey.Field = field
if column.Type.Autoincrement {
primaryKey.AutoIncrement = true
}
}
}
fields = append(fields, field)
}
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
Fields: fields,
}, nil
} }
func (t *Table) ContainsTime() bool { func (t *Table) ContainsTime() bool {

View File

@ -8,6 +8,9 @@ import (
) )
type ( type (
// Console wraps from the fmt.Sprintf,
// by default, it implemented the colorConsole to provide the colorful output to the consle
// and the ideaConsole to output with prefix for the plugin of intellij
Console interface { Console interface {
Success(format string, a ...interface{}) Success(format string, a ...interface{})
Info(format string, a ...interface{}) Info(format string, a ...interface{})
@ -25,6 +28,7 @@ type (
} }
) )
// NewConsole returns a instance of Console
func NewConsole(idea bool) Console { func NewConsole(idea bool) Console {
if idea { if idea {
return NewIdeaConsole() return NewIdeaConsole()
@ -32,7 +36,8 @@ func NewConsole(idea bool) Console {
return NewColorConsole() return NewColorConsole()
} }
func NewColorConsole() *colorConsole { // NewColorConsole returns a instance of colorConsole
func NewColorConsole() Console {
return &colorConsole{} return &colorConsole{}
} }
@ -76,7 +81,8 @@ func (c *colorConsole) Must(err error) {
} }
} }
func NewIdeaConsole() *ideaConsole { // NewIdeaConsole returns a instace of ideaConsole
func NewIdeaConsole() Console {
return &ideaConsole{} return &ideaConsole{}
} }

View File

@ -9,6 +9,8 @@ import (
var errModuleCheck = errors.New("the work directory must be found in the go mod or the $GOPATH") var errModuleCheck = errors.New("the work directory must be found in the go mod or the $GOPATH")
// ProjectContext is a structure for the project,
// which contains WorkDir, Name, Path and Dir
type ProjectContext struct { type ProjectContext struct {
WorkDir string WorkDir string
// Name is the root name of the project // Name is the root name of the project

View File

@ -9,6 +9,8 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx" "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
) )
// Module contains the relative data of go module,
// which is the result of the command go list
type Module struct { type Module struct {
Path string Path string
Main bool Main bool

View File

@ -10,9 +10,7 @@ import (
"github.com/logrusorgru/aurora" "github.com/logrusorgru/aurora"
) )
const ( const NL = "\n"
NL = "\n"
)
func CreateIfNotExist(file string) (*os.File, error) { func CreateIfNotExist(file string) (*os.File, error) {
_, err := os.Stat(file) _, err := os.Stat(file)

View File

@ -18,6 +18,7 @@ const (
upper upper
) )
// ErrNamingFormat defines an error for unknown fomat
var ErrNamingFormat = errors.New("unsupported format") var ErrNamingFormat = errors.New("unsupported format")
type ( type (

View File

@ -1,3 +1,5 @@
// Package name provides methods to verify naming style and format naming style
// See the method IsNamingValid, FormatFilename
package name package name
import ( import (
@ -6,11 +8,15 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
// NamingStyle the type of string
type NamingStyle = string type NamingStyle = string
const ( const (
// NamingLower defines the lower spell case
NamingLower NamingStyle = "lower" NamingLower NamingStyle = "lower"
// NamingCamel defines the camel spell case
NamingCamel NamingStyle = "camel" NamingCamel NamingStyle = "camel"
// NamingSnake defines the snake spell case
NamingSnake NamingStyle = "snake" NamingSnake NamingStyle = "snake"
) )
@ -29,6 +35,8 @@ func IsNamingValid(namingStyle string) (NamingStyle, bool) {
} }
} }
// FormatFilename converts the filename string to the target
// naming style by calling method of stringx
func FormatFilename(filename string, style NamingStyle) string { func FormatFilename(filename string, style NamingStyle) string {
switch style { switch style {
case NamingCamel: case NamingCamel:

View File

@ -6,14 +6,17 @@ import (
"unicode" "unicode"
) )
// String provides for coverting the source text into other spell case,like lower,snake,camel
type String struct { type String struct {
source string source string
} }
// From converts the input text to String and returns it
func From(data string) String { func From(data string) String {
return String{source: data} return String{source: data}
} }
// IsEmptyOrSpace returns true if the length of the string value is 0 after call strings.TrimSpace, or else returns false
func (s String) IsEmptyOrSpace() bool { func (s String) IsEmptyOrSpace() bool {
if len(s.source) == 0 { if len(s.source) == 0 {
return true return true
@ -24,18 +27,22 @@ func (s String) IsEmptyOrSpace() bool {
return false return false
} }
// Lower calls the strings.ToLower
func (s String) Lower() string { func (s String) Lower() string {
return strings.ToLower(s.source) return strings.ToLower(s.source)
} }
// ReplaceAll calls the strings.ReplaceAll
func (s String) ReplaceAll(old, new string) string { func (s String) ReplaceAll(old, new string) string {
return strings.ReplaceAll(s.source, old, new) return strings.ReplaceAll(s.source, old, new)
} }
//Source returns the source string value
func (s String) Source() string { func (s String) Source() string {
return s.source return s.source
} }
// Title calls the strings.Title
func (s String) Title() string { func (s String) Title() string {
if s.IsEmptyOrSpace() { if s.IsEmptyOrSpace() {
return s.source return s.source
@ -43,7 +50,7 @@ func (s String) Title() string {
return strings.Title(s.source) return strings.Title(s.source)
} }
// snake->camel(upper start) // ToCamel converts the input text into camel case
func (s String) ToCamel() string { func (s String) ToCamel() string {
list := s.splitBy(func(r rune) bool { list := s.splitBy(func(r rune) bool {
return r == '_' return r == '_'
@ -55,7 +62,7 @@ func (s String) ToCamel() string {
return strings.Join(target, "") return strings.Join(target, "")
} }
// camel->snake // ToSnake converts the input text into snake case
func (s String) ToSnake() string { func (s String) ToSnake() string {
list := s.splitBy(unicode.IsUpper, false) list := s.splitBy(unicode.IsUpper, false)
var target []string var target []string
@ -65,7 +72,7 @@ func (s String) ToSnake() string {
return strings.Join(target, "_") return strings.Join(target, "_")
} }
// return original string if rune is not letter at index 0 // Untitle return the original string if rune is not letter at index 0
func (s String) Untitle() string { func (s String) Untitle() string {
if s.IsEmptyOrSpace() { if s.IsEmptyOrSpace() {
return s.source return s.source
@ -77,10 +84,6 @@ func (s String) Untitle() string {
return string(unicode.ToLower(r)) + s.source[1:] return string(unicode.ToLower(r)) + s.source[1:]
} }
func (s String) Upper() string {
return strings.ToUpper(s.source)
}
// it will not ignore spaces // it will not ignore spaces
func (s String) splitBy(fn func(r rune) bool, remove bool) []string { func (s String) splitBy(fn func(r rune) bool, remove bool) []string {
if s.IsEmptyOrSpace() { if s.IsEmptyOrSpace() {

View File

@ -9,29 +9,35 @@ import (
const regularPerm = 0666 const regularPerm = 0666
type defaultTemplate struct { // DefaultTemplate is a tool to provides the text/template operations
type DefaultTemplate struct {
name string name string
text string text string
goFmt bool goFmt bool
savePath string savePath string
} }
func With(name string) *defaultTemplate { // With returns a instace of DefaultTemplate
return &defaultTemplate{ func With(name string) *DefaultTemplate {
return &DefaultTemplate{
name: name, name: name,
} }
} }
func (t *defaultTemplate) Parse(text string) *defaultTemplate {
// Parse accepts a source template and returns DefaultTemplate
func (t *DefaultTemplate) Parse(text string) *DefaultTemplate {
t.text = text t.text = text
return t return t
} }
func (t *defaultTemplate) GoFmt(format bool) *defaultTemplate { // GoFmt sets the value to goFmt and marks the generated codes will be formated or not
func (t *DefaultTemplate) GoFmt(format bool) *DefaultTemplate {
t.goFmt = format t.goFmt = format
return t return t
} }
func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error { // SaveTo writes the codes to the target path
func (t *DefaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error {
if FileExists(path) && !forceUpdate { if FileExists(path) && !forceUpdate {
return nil return nil
} }
@ -44,7 +50,8 @@ func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool
return ioutil.WriteFile(path, output.Bytes(), regularPerm) return ioutil.WriteFile(path, output.Bytes(), regularPerm)
} }
func (t *defaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) { // Execute returns the codes after the template executed
func (t *DefaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) {
tem, err := template.New(t.name).Parse(t.text) tem, err := template.New(t.name).Parse(t.text)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -1,9 +1,14 @@
package vars package vars
const ( const (
// ProjectName the const value of zero
ProjectName = "zero" ProjectName = "zero"
ProjectOpenSourceUrl = "github.com/tal-tech/go-zero" // ProjectOpenSourceURL the githb url of go-zero
ProjectOpenSourceURL = "github.com/tal-tech/go-zero"
// OsWindows windows os
OsWindows = "windows" OsWindows = "windows"
// OsMac mac os
OsMac = "darwin" OsMac = "darwin"
// OsLinux linux os
OsLinux = "linux" OsLinux = "linux"
) )