From f98c9246b28207a49fae5a8912432224c96b7086 Mon Sep 17 00:00:00 2001 From: kingxt Date: Sat, 20 Feb 2021 19:50:03 +0800 Subject: [PATCH] Code optimized (#493) --- tools/goctl/api/apigen/gen.go | 1 + tools/goctl/api/dartgen/gen.go | 1 + tools/goctl/api/dartgen/util.go | 18 +- tools/goctl/api/docgen/gen.go | 1 + tools/goctl/api/gogen/gen_test.go | 1 + tools/goctl/api/gogen/genconfig.go | 2 +- tools/goctl/api/gogen/genhandlers.go | 2 +- tools/goctl/api/gogen/genlogic.go | 2 +- tools/goctl/api/gogen/genmain.go | 4 +- tools/goctl/api/gogen/genroutes.go | 8 +- tools/goctl/api/gogen/gensvc.go | 2 +- tools/goctl/api/gogen/util.go | 4 - tools/goctl/api/parser/g4/ast/api.go | 221 +++++++++++---------- tools/goctl/api/parser/g4/ast/apiparser.go | 179 ++++++++++------- tools/goctl/api/parser/parser.go | 66 +++--- tools/goctl/api/util/tag.go | 58 ------ tools/goctl/goctl.go | 4 +- tools/goctl/model/sql/gen/gen.go | 74 +++++-- tools/goctl/model/sql/parser/parser.go | 116 ++++++----- tools/goctl/util/console/console.go | 10 +- tools/goctl/util/ctx/context.go | 2 + tools/goctl/util/ctx/gomod.go | 2 + tools/goctl/util/file.go | 4 +- tools/goctl/util/format/format.go | 1 + tools/goctl/util/name/naming.go | 8 + tools/goctl/util/stringx/string.go | 17 +- tools/goctl/util/templatex.go | 21 +- tools/goctl/vars/settings.go | 15 +- 28 files changed, 472 insertions(+), 372 deletions(-) delete mode 100644 tools/goctl/api/util/tag.go diff --git a/tools/goctl/api/apigen/gen.go b/tools/goctl/api/apigen/gen.go index 2b16aebe..35c4218d 100644 --- a/tools/goctl/api/apigen/gen.go +++ b/tools/goctl/api/apigen/gen.go @@ -39,6 +39,7 @@ service {{.serviceName}} { } ` +// ApiCommand create api template file func ApiCommand(c *cli.Context) error { apiFile := c.String("o") if len(apiFile) == 0 { diff --git a/tools/goctl/api/dartgen/gen.go b/tools/goctl/api/dartgen/gen.go index dc5af11a..f2bbc8dc 100644 --- a/tools/goctl/api/dartgen/gen.go +++ b/tools/goctl/api/dartgen/gen.go @@ -9,6 +9,7 @@ import ( "github.com/urfave/cli" ) +// DartCommand create dart network request code func DartCommand(c *cli.Context) error { apiFile := c.String("api") dir := c.String("dir") diff --git a/tools/goctl/api/dartgen/util.go b/tools/goctl/api/dartgen/util.go index dab2d416..9c83bd6b 100644 --- a/tools/goctl/api/dartgen/util.go +++ b/tools/goctl/api/dartgen/util.go @@ -2,9 +2,9 @@ package dartgen import ( "os" - "reflect" "strings" + "github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/util" ) @@ -32,10 +32,18 @@ func pathToFuncName(path string) string { return util.ToLower(camel[:1]) + camel[1:] } -func tagGet(tag, k string) (reflect.Value, error) { - v, _ := util.TagLookup(tag, k) - out := strings.Split(v, ",")[0] - return reflect.ValueOf(out), nil +func tagGet(tag, k string) string { + tags, err := spec.Parse(tag) + if err != nil { + panic(k + " not exist") + } + + v, err := tags.Get(k) + if err != nil { + panic(k + " value not exist") + } + + return v.Name } func isDirectType(s string) bool { diff --git a/tools/goctl/api/docgen/gen.go b/tools/goctl/api/docgen/gen.go index 840ab8a8..b3c6e250 100644 --- a/tools/goctl/api/docgen/gen.go +++ b/tools/goctl/api/docgen/gen.go @@ -12,6 +12,7 @@ import ( "github.com/urfave/cli" ) +// DocCommand generate markdown doc file func DocCommand(c *cli.Context) error { dir := c.String("dir") if len(dir) == 0 { diff --git a/tools/goctl/api/gogen/gen_test.go b/tools/goctl/api/gogen/gen_test.go index db8ad3cd..434de035 100644 --- a/tools/goctl/api/gogen/gen_test.go +++ b/tools/goctl/api/gogen/gen_test.go @@ -160,6 +160,7 @@ type Response struct { @server( jwt: Auth + signature: true ) service A-api { @handler GreetHandler diff --git a/tools/goctl/api/gogen/genconfig.go b/tools/goctl/api/gogen/genconfig.go index 16c38366..369edec5 100644 --- a/tools/goctl/api/gogen/genconfig.go +++ b/tools/goctl/api/gogen/genconfig.go @@ -40,7 +40,7 @@ func genConfig(dir string, cfg *config.Config, api *spec.ApiSpec) error { for _, item := range authNames { auths = append(auths, fmt.Sprintf("%s %s", item, jwtTemplate)) } - var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl) + var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL) return genFile(fileGenConfig{ dir: dir, diff --git a/tools/goctl/api/gogen/genhandlers.go b/tools/goctl/api/gogen/genhandlers.go index 26549d7c..5a3a49bd 100644 --- a/tools/goctl/api/gogen/genhandlers.go +++ b/tools/goctl/api/gogen/genhandlers.go @@ -109,7 +109,7 @@ func genHandlerImports(group spec.Group, route spec.Route, parentPkg string) str if len(route.RequestTypeName()) > 0 { imports = append(imports, fmt.Sprintf("\"%s\"\n", util.JoinPackages(parentPkg, typesDir))) } - imports = append(imports, fmt.Sprintf("\"%s/rest/httpx\"", vars.ProjectOpenSourceUrl)) + imports = append(imports, fmt.Sprintf("\"%s/rest/httpx\"", vars.ProjectOpenSourceURL)) return strings.Join(imports, "\n\t") } diff --git a/tools/goctl/api/gogen/genlogic.go b/tools/goctl/api/gogen/genlogic.go index 77c45bc0..09a0c543 100644 --- a/tools/goctl/api/gogen/genlogic.go +++ b/tools/goctl/api/gogen/genlogic.go @@ -122,6 +122,6 @@ func genLogicImports(route spec.Route, parentPkg string) string { if len(route.ResponseTypeName()) > 0 || len(route.RequestTypeName()) > 0 { imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir))) } - imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceUrl)) + imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL)) return strings.Join(imports, "\n\t") } diff --git a/tools/goctl/api/gogen/genmain.go b/tools/goctl/api/gogen/genmain.go index 96ca06d8..269c25a3 100644 --- a/tools/goctl/api/gogen/genmain.go +++ b/tools/goctl/api/gogen/genmain.go @@ -74,7 +74,7 @@ func genMainImports(parentPkg string) string { imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, configDir))) imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, handlerDir))) imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, contextDir))) - imports = append(imports, fmt.Sprintf("\"%s/core/conf\"", vars.ProjectOpenSourceUrl)) - imports = append(imports, fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl)) + imports = append(imports, fmt.Sprintf("\"%s/core/conf\"", vars.ProjectOpenSourceURL)) + imports = append(imports, fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL)) return strings.Join(imports, "\n\t") } diff --git a/tools/goctl/api/gogen/genroutes.go b/tools/goctl/api/gogen/genroutes.go index 8d311652..d097eb87 100644 --- a/tools/goctl/api/gogen/genroutes.go +++ b/tools/goctl/api/gogen/genroutes.go @@ -89,7 +89,7 @@ func genRoutes(dir string, cfg *config.Config, api *spec.ApiSpec) error { } var signature string if g.signatureEnabled { - signature = fmt.Sprintf("\n rest.WithSignature(serverCtx.Config.%s.Signature),", g.authName) + signature = "\n rest.WithSignature(serverCtx.Config.Signature)," } var routes string @@ -163,7 +163,7 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string { imports := importSet.KeysStr() sort.Strings(imports) projectSection := strings.Join(imports, "\n\t") - depSection := fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl) + depSection := fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL) return fmt.Sprintf("%s\n\n\t%s", projectSection, depSection) } @@ -196,6 +196,10 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) { groupedRoutes.authName = jwt groupedRoutes.jwtEnabled = true } + signature := g.GetAnnotation("signature") + if signature == "true" { + groupedRoutes.signatureEnabled = true + } middleware := g.GetAnnotation("middleware") if len(middleware) > 0 { for _, item := range strings.Split(middleware, ",") { diff --git a/tools/goctl/api/gogen/gensvc.go b/tools/goctl/api/gogen/gensvc.go index e4adab37..43adbc42 100644 --- a/tools/goctl/api/gogen/gensvc.go +++ b/tools/goctl/api/gogen/gensvc.go @@ -64,7 +64,7 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\"" if len(middlewareStr) > 0 { configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\"" - configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceUrl) + configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceURL) } return genFile(fileGenConfig{ diff --git a/tools/goctl/api/gogen/util.go b/tools/goctl/api/gogen/util.go index b327ed26..106688b1 100644 --- a/tools/goctl/api/gogen/util.go +++ b/tools/goctl/api/gogen/util.go @@ -94,10 +94,6 @@ func getAuths(api *spec.ApiSpec) []string { if len(jwt) > 0 { authNames.Add(jwt) } - signature := g.GetAnnotation("signature") - if len(signature) > 0 { - authNames.Add(signature) - } } return authNames.KeysStr() } diff --git a/tools/goctl/api/parser/g4/ast/api.go b/tools/goctl/api/parser/g4/ast/api.go index 4760f842..9c12f91d 100644 --- a/tools/goctl/api/parser/g4/ast/api.go +++ b/tools/goctl/api/parser/g4/ast/api.go @@ -22,12 +22,6 @@ type Api struct { } func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} { - defer func() { - if p := recover(); p != nil { - panic(fmt.Errorf("%+v", p)) - } - }() - var final Api final.importM = map[string]PlaceHolder{} final.typeM = map[string]PlaceHolder{} @@ -36,109 +30,128 @@ func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} { final.routeM = map[string]PlaceHolder{} for _, each := range ctx.AllSpec() { root := each.Accept(v).(*Api) - if root.Syntax != nil { - if final.Syntax != nil { - v.panic(root.Syntax.Syntax, fmt.Sprintf("mutiple syntax declaration")) - } - - final.Syntax = root.Syntax - } - - for _, imp := range root.Import { - if _, ok := final.importM[imp.Value.Text()]; ok { - v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text())) - } - - final.importM[imp.Value.Text()] = Holder - final.Import = append(final.Import, imp) - } - - if root.Info != nil { - infoM := map[string]PlaceHolder{} - if final.Info != nil { - v.panic(root.Info.Info, fmt.Sprintf("mutiple info declaration")) - } - - for _, value := range root.Info.Kvs { - if _, ok := infoM[value.Key.Text()]; ok { - v.panic(value.Key, fmt.Sprintf("duplicate key '%s'", value.Key.Text())) - } - infoM[value.Key.Text()] = Holder - } - - final.Info = root.Info - } - - for _, tp := range root.Type { - if _, ok := final.typeM[tp.NameExpr().Text()]; ok { - v.panic(tp.NameExpr(), fmt.Sprintf("duplicate type '%s'", tp.NameExpr().Text())) - } - - final.typeM[tp.NameExpr().Text()] = Holder - final.Type = append(final.Type, tp) - } - - for _, service := range root.Service { - if _, ok := final.serviceM[service.ServiceApi.Name.Text()]; !ok && len(final.serviceM) > 0 { - v.panic(service.ServiceApi.Name, fmt.Sprintf("mutiple service declaration")) - } - - if service.AtServer != nil { - atServerM := map[string]PlaceHolder{} - for _, kv := range service.AtServer.Kv { - if _, ok := atServerM[kv.Key.Text()]; ok { - v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text())) - } - - atServerM[kv.Key.Text()] = Holder - } - } - - for _, route := range service.ServiceApi.ServiceRoute { - uniqueRoute := fmt.Sprintf("%s %s", route.Route.Method.Text(), route.Route.Path.Text()) - if _, ok := final.routeM[uniqueRoute]; ok { - v.panic(route.Route.Method, fmt.Sprintf("duplicate route '%s'", uniqueRoute)) - } - - final.routeM[uniqueRoute] = Holder - var handlerExpr Expr - if route.AtServer != nil { - atServerM := map[string]PlaceHolder{} - for _, kv := range route.AtServer.Kv { - if _, ok := atServerM[kv.Key.Text()]; ok { - v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text())) - } - atServerM[kv.Key.Text()] = Holder - if kv.Key.Text() == "handler" { - handlerExpr = kv.Value - } - } - } - - if route.AtHandler != nil { - handlerExpr = route.AtHandler.Name - } - - if handlerExpr == nil { - v.panic(route.Route.Method, fmt.Sprintf("mismtached handler")) - } - - if handlerExpr.Text() == "" { - v.panic(handlerExpr, fmt.Sprintf("mismtached handler")) - } - - if _, ok := final.handlerM[handlerExpr.Text()]; ok { - v.panic(handlerExpr, fmt.Sprintf("duplicate handler '%s'", handlerExpr.Text())) - } - final.handlerM[handlerExpr.Text()] = Holder - } - final.Service = append(final.Service, service) - } + v.acceptSyntax(root, &final) + v.accetpImport(root, &final) + v.acceptInfo(root, &final) + v.acceptType(root, &final) + v.acceptService(root, &final) } return &final } +func (v *ApiVisitor) acceptService(root *Api, final *Api) { + for _, service := range root.Service { + if _, ok := final.serviceM[service.ServiceApi.Name.Text()]; !ok && len(final.serviceM) > 0 { + v.panic(service.ServiceApi.Name, fmt.Sprintf("mutiple service declaration")) + } + v.duplicateServerItemCheck(service) + + for _, route := range service.ServiceApi.ServiceRoute { + uniqueRoute := fmt.Sprintf("%s %s", route.Route.Method.Text(), route.Route.Path.Text()) + if _, ok := final.routeM[uniqueRoute]; ok { + v.panic(route.Route.Method, fmt.Sprintf("duplicate route '%s'", uniqueRoute)) + } + + final.routeM[uniqueRoute] = Holder + var handlerExpr Expr + if route.AtServer != nil { + atServerM := map[string]PlaceHolder{} + for _, kv := range route.AtServer.Kv { + if _, ok := atServerM[kv.Key.Text()]; ok { + v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text())) + } + atServerM[kv.Key.Text()] = Holder + if kv.Key.Text() == "handler" { + handlerExpr = kv.Value + } + } + } + + if route.AtHandler != nil { + handlerExpr = route.AtHandler.Name + } + + if handlerExpr == nil { + v.panic(route.Route.Method, fmt.Sprintf("mismtached handler")) + } + + if handlerExpr.Text() == "" { + v.panic(handlerExpr, fmt.Sprintf("mismtached handler")) + } + + if _, ok := final.handlerM[handlerExpr.Text()]; ok { + v.panic(handlerExpr, fmt.Sprintf("duplicate handler '%s'", handlerExpr.Text())) + } + final.handlerM[handlerExpr.Text()] = Holder + } + final.Service = append(final.Service, service) + } +} + +func (v *ApiVisitor) duplicateServerItemCheck(service *Service) { + if service.AtServer != nil { + atServerM := map[string]PlaceHolder{} + for _, kv := range service.AtServer.Kv { + if _, ok := atServerM[kv.Key.Text()]; ok { + v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text())) + } + + atServerM[kv.Key.Text()] = Holder + } + } +} + +func (v *ApiVisitor) acceptType(root *Api, final *Api) { + for _, tp := range root.Type { + if _, ok := final.typeM[tp.NameExpr().Text()]; ok { + v.panic(tp.NameExpr(), fmt.Sprintf("duplicate type '%s'", tp.NameExpr().Text())) + } + + final.typeM[tp.NameExpr().Text()] = Holder + final.Type = append(final.Type, tp) + } +} + +func (v *ApiVisitor) acceptInfo(root *Api, final *Api) { + if root.Info != nil { + infoM := map[string]PlaceHolder{} + if final.Info != nil { + v.panic(root.Info.Info, fmt.Sprintf("mutiple info declaration")) + } + + for _, value := range root.Info.Kvs { + if _, ok := infoM[value.Key.Text()]; ok { + v.panic(value.Key, fmt.Sprintf("duplicate key '%s'", value.Key.Text())) + } + infoM[value.Key.Text()] = Holder + } + + final.Info = root.Info + } +} + +func (v *ApiVisitor) accetpImport(root *Api, final *Api) { + for _, imp := range root.Import { + if _, ok := final.importM[imp.Value.Text()]; ok { + v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text())) + } + + final.importM[imp.Value.Text()] = Holder + final.Import = append(final.Import, imp) + } +} + +func (v *ApiVisitor) acceptSyntax(root *Api, final *Api) { + if root.Syntax != nil { + if final.Syntax != nil { + v.panic(root.Syntax.Syntax, fmt.Sprintf("mutiple syntax declaration")) + } + + final.Syntax = root.Syntax + } +} + func (v *ApiVisitor) VisitSpec(ctx *api.SpecContext) interface{} { var root Api if ctx.SyntaxLit() != nil { diff --git a/tools/goctl/api/parser/g4/ast/apiparser.go b/tools/goctl/api/parser/g4/ast/apiparser.go index 3c05f550..3b0042eb 100644 --- a/tools/goctl/api/parser/g4/ast/apiparser.go +++ b/tools/goctl/api/parser/g4/ast/apiparser.go @@ -156,28 +156,9 @@ func (p *Parser) invoke(linePrefix, content string) (v *Api, err error) { } func (p *Parser) valid(mainApi *Api, nestedApi *Api) error { - if len(nestedApi.Import) > 0 { - importToken := nestedApi.Import[0].Import - return fmt.Errorf("%s line %d:%d the nested api does not support import", - nestedApi.LinePrefix, importToken.Line(), importToken.Column()) - } - - if mainApi.Syntax != nil && nestedApi.Syntax != nil { - if mainApi.Syntax.Version.Text() != nestedApi.Syntax.Version.Text() { - syntaxToken := nestedApi.Syntax.Syntax - return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'", - nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), mainApi.Syntax.Version.Text(), nestedApi.Syntax.Version.Text()) - } - } - - if len(mainApi.Service) > 0 { - mainService := mainApi.Service[0] - for _, service := range nestedApi.Service { - if mainService.ServiceApi.Name.Text() != service.ServiceApi.Name.Text() { - return fmt.Errorf("%s multiple service name declaration, expecting service name '%s', but found '%s'", - nestedApi.LinePrefix, mainService.ServiceApi.Name.Text(), service.ServiceApi.Name.Text()) - } - } + err := p.nestedApiCheck(mainApi, nestedApi) + if err != nil { + return err } mainHandlerMap := make(map[string]PlaceHolder) @@ -218,6 +199,23 @@ func (p *Parser) valid(mainApi *Api, nestedApi *Api) error { } // duplicate route check + err = p.duplicateRouteCheck(nestedApi, mainHandlerMap, mainRouteMap) + if err != nil { + return err + } + + // duplicate type check + for _, each := range nestedApi.Type { + if _, ok := mainTypeMap[each.NameExpr().Text()]; ok { + return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'", + nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text()) + } + } + + return nil +} + +func (p *Parser) duplicateRouteCheck(nestedApi *Api, mainHandlerMap map[string]PlaceHolder, mainRouteMap map[string]PlaceHolder) error { for _, each := range nestedApi.Service { for _, r := range each.ServiceApi.ServiceRoute { handler := r.GetHandler() @@ -237,12 +235,31 @@ func (p *Parser) valid(mainApi *Api, nestedApi *Api) error { } } } + return nil +} - // duplicate type check - for _, each := range nestedApi.Type { - if _, ok := mainTypeMap[each.NameExpr().Text()]; ok { - return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'", - nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text()) +func (p *Parser) nestedApiCheck(mainApi *Api, nestedApi *Api) error { + if len(nestedApi.Import) > 0 { + importToken := nestedApi.Import[0].Import + return fmt.Errorf("%s line %d:%d the nested api does not support import", + nestedApi.LinePrefix, importToken.Line(), importToken.Column()) + } + + if mainApi.Syntax != nil && nestedApi.Syntax != nil { + if mainApi.Syntax.Version.Text() != nestedApi.Syntax.Version.Text() { + syntaxToken := nestedApi.Syntax.Syntax + return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'", + nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), mainApi.Syntax.Version.Text(), nestedApi.Syntax.Version.Text()) + } + } + + if len(mainApi.Service) > 0 { + mainService := mainApi.Service[0] + for _, service := range nestedApi.Service { + if mainService.ServiceApi.Name.Text() != service.ServiceApi.Name.Text() { + return fmt.Errorf("%s multiple service name declaration, expecting service name '%s', but found '%s'", + nestedApi.LinePrefix, mainService.ServiceApi.Name.Text(), service.ServiceApi.Name.Text()) + } } } return nil @@ -276,56 +293,80 @@ func (p *Parser) checkTypeDeclaration(apiList []*Api) error { for _, apiItem := range apiList { linePrefix := apiItem.LinePrefix - for _, each := range apiItem.Type { - tp, ok := each.(*TypeStruct) - if !ok { - continue + err := p.checkTypes(apiItem, linePrefix, types) + if err != nil { + return err + } + + err = p.checkServices(apiItem, types, linePrefix) + if err != nil { + return err + } + } + return nil +} + +func (p *Parser) checkServices(apiItem *Api, types map[string]TypeExpr, linePrefix string) error { + for _, service := range apiItem.Service { + for _, each := range service.ServiceApi.ServiceRoute { + route := each.Route + err := p.checkRequestBody(route, types, linePrefix) + if err != nil { + return err } - for _, member := range tp.Fields { - err := p.checkType(linePrefix, types, member.DataType) - if err != nil { - return err + if route.Reply != nil && route.Reply.Name.IsNotNil() && route.Reply.Name.Expr().IsNotNil() { + reply := route.Reply.Name + var structName string + switch tp := reply.(type) { + case *Literal: + structName = tp.Literal.Text() + case *Array: + switch innerTp := tp.Literal.(type) { + case *Literal: + structName = innerTp.Literal.Text() + case *Pointer: + structName = innerTp.Name.Text() + } + } + + if api.IsBasicType(structName) { + continue + } + + _, ok := types[structName] + if !ok { + return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context", + linePrefix, route.Reply.Name.Expr().Line(), route.Reply.Name.Expr().Column(), structName) } } } + } + return nil +} - for _, service := range apiItem.Service { - for _, each := range service.ServiceApi.ServiceRoute { - route := each.Route - if route.Req != nil && route.Req.Name.IsNotNil() && route.Req.Name.Expr().IsNotNil() { - _, ok := types[route.Req.Name.Expr().Text()] - if !ok { - return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context", - linePrefix, route.Req.Name.Expr().Line(), route.Req.Name.Expr().Column(), route.Req.Name.Expr().Text()) - } - } +func (p *Parser) checkRequestBody(route *Route, types map[string]TypeExpr, linePrefix string) error { + if route.Req != nil && route.Req.Name.IsNotNil() && route.Req.Name.Expr().IsNotNil() { + _, ok := types[route.Req.Name.Expr().Text()] + if !ok { + return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context", + linePrefix, route.Req.Name.Expr().Line(), route.Req.Name.Expr().Column(), route.Req.Name.Expr().Text()) + } + } + return nil +} - if route.Reply != nil && route.Reply.Name.IsNotNil() && route.Reply.Name.Expr().IsNotNil() { - reply := route.Reply.Name - var structName string - switch tp := reply.(type) { - case *Literal: - structName = tp.Literal.Text() - case *Array: - switch innerTp := tp.Literal.(type) { - case *Literal: - structName = innerTp.Literal.Text() - case *Pointer: - structName = innerTp.Name.Text() - } - } +func (p *Parser) checkTypes(apiItem *Api, linePrefix string, types map[string]TypeExpr) error { + for _, each := range apiItem.Type { + tp, ok := each.(*TypeStruct) + if !ok { + continue + } - if api.IsBasicType(structName) { - continue - } - - _, ok := types[structName] - if !ok { - return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context", - linePrefix, route.Reply.Name.Expr().Line(), route.Reply.Name.Expr().Column(), structName) - } - } + for _, member := range tp.Fields { + err := p.checkType(linePrefix, types, member.DataType) + if err != nil { + return err } } } diff --git a/tools/goctl/api/parser/parser.go b/tools/goctl/api/parser/parser.go index 3940e1d8..7fc41793 100644 --- a/tools/goctl/api/parser/parser.go +++ b/tools/goctl/api/parser/parser.go @@ -213,13 +213,7 @@ func (p parser) fillService() error { var groups []spec.Group for _, item := range p.ast.Service { var group spec.Group - if item.AtServer != nil { - var properties = make(map[string]string, 0) - for _, kv := range item.AtServer.Kv { - properties[kv.Key.Text()] = kv.Value.Text() - } - group.Annotation.Properties = properties - } + p.fillAtServer(item, &group) for _, astRoute := range item.ServiceApi.ServiceRoute { route := spec.Route{ @@ -231,25 +225,9 @@ func (p parser) fillService() error { route.Handler = astRoute.AtHandler.Name.Text() } - if astRoute.AtServer != nil { - var properties = make(map[string]string, 0) - for _, kv := range astRoute.AtServer.Kv { - properties[kv.Key.Text()] = kv.Value.Text() - } - route.Annotation.Properties = properties - if len(route.Handler) == 0 { - route.Handler = properties["handler"] - } - if len(route.Handler) == 0 { - return fmt.Errorf("missing handler annotation for %q", route.Path) - } - - for _, char := range route.Handler { - if !unicode.IsDigit(char) && !unicode.IsLetter(char) { - return fmt.Errorf("route [%s] handler [%s] invalid, handler name should only contains letter or digit", - route.Path, route.Handler) - } - } + err := p.fillRouteAtServer(astRoute, &route) + if err != nil { + return err } if astRoute.Route.Req != nil { @@ -269,7 +247,7 @@ func (p parser) fillService() error { } } - err := p.fillRouteType(&route) + err = p.fillRouteType(&route) if err != nil { return err } @@ -289,6 +267,40 @@ func (p parser) fillService() error { return nil } +func (p parser) fillRouteAtServer(astRoute *ast.ServiceRoute, route *spec.Route) error { + if astRoute.AtServer != nil { + var properties = make(map[string]string, 0) + for _, kv := range astRoute.AtServer.Kv { + properties[kv.Key.Text()] = kv.Value.Text() + } + route.Annotation.Properties = properties + if len(route.Handler) == 0 { + route.Handler = properties["handler"] + } + if len(route.Handler) == 0 { + return fmt.Errorf("missing handler annotation for %q", route.Path) + } + + for _, char := range route.Handler { + if !unicode.IsDigit(char) && !unicode.IsLetter(char) { + return fmt.Errorf("route [%s] handler [%s] invalid, handler name should only contains letter or digit", + route.Path, route.Handler) + } + } + } + return nil +} + +func (p parser) fillAtServer(item *ast.Service, group *spec.Group) { + if item.AtServer != nil { + var properties = make(map[string]string, 0) + for _, kv := range item.AtServer.Kv { + properties[kv.Key.Text()] = kv.Value.Text() + } + group.Annotation.Properties = properties + } +} + func (p parser) fillRouteType(route *spec.Route) error { if route.RequestType != nil { switch route.RequestType.(type) { diff --git a/tools/goctl/api/util/tag.go b/tools/goctl/api/util/tag.go deleted file mode 100644 index 5c6798bb..00000000 --- a/tools/goctl/api/util/tag.go +++ /dev/null @@ -1,58 +0,0 @@ -package util - -import ( - "strconv" - "strings" -) - -func TagLookup(tag, key string) (value string, ok bool) { - tag = strings.Replace(tag, "`", "", -1) - for tag != "" { - // Skip leading space. - i := 0 - for i < len(tag) && tag[i] == ' ' { - i++ - } - tag = tag[i:] - if tag == "" { - break - } - - // Scan to colon. A space, a quote or a control character is a syntax error. - // Strictly speaking, control chars include the range [0x7f, 0x9f], not just - // [0x00, 0x1f], but in practice, we ignore the multi-byte control characters - // as it is simpler to inspect the tag's bytes than the tag's runes. - i = 0 - for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f { - i++ - } - if i == 0 || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' { - break - } - name := string(tag[:i]) - tag = tag[i+1:] - - // Scan quoted string to find value. - i = 1 - for i < len(tag) && tag[i] != '"' { - if tag[i] == '\\' { - i++ - } - i++ - } - if i >= len(tag) { - break - } - qvalue := string(tag[:i+1]) - tag = tag[i+1:] - - if key == name { - value, err := strconv.Unquote(qvalue) - if err != nil { - break - } - return value, true - } - } - return "", false -} diff --git a/tools/goctl/goctl.go b/tools/goctl/goctl.go index 66b0219d..3f9130b4 100644 --- a/tools/goctl/goctl.go +++ b/tools/goctl/goctl.go @@ -28,7 +28,7 @@ import ( ) var ( - BuildVersion = "1.1.5" + buildVersion = "1.1.5" commands = []cli.Command{ { Name: "upgrade", @@ -510,7 +510,7 @@ func main() { app := cli.NewApp() app.Usage = "a cli tool to generate code" - app.Version = fmt.Sprintf("%s %s/%s", BuildVersion, runtime.GOOS, runtime.GOARCH) + app.Version = fmt.Sprintf("%s %s/%s", buildVersion, runtime.GOOS, runtime.GOARCH) app.Commands = commands // cli already print error messages if err := app.Run(os.Args); err != nil { diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index f29d5992..2973a50e 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -1,6 +1,7 @@ package gen import ( + "bytes" "fmt" "io/ioutil" "os" @@ -31,7 +32,20 @@ type ( pkg string cfg *config.Config } + Option func(generator *defaultGenerator) + + code struct { + importsCode string + varsCode string + typesCode string + newCode string + insertCode string + findCode []string + updateCode string + deleteCode string + cacheExtra string + } ) func NewDefaultGenerator(dir string, cfg *config.Config, opt ...Option) (*defaultGenerator, error) { @@ -186,15 +200,6 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er return "", fmt.Errorf("table %s: missing primary key", in.Name.Source()) } - text, err := util.LoadTemplate(category, modelTemplateFile, template.Model) - if err != nil { - return "", err - } - - t := util.With("model"). - Parse(text). - GoFmt(true) - m, err := genCacheKeys(in) if err != nil { return "", err @@ -261,18 +266,19 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er return "", err } - output, err := t.Execute(map[string]interface{}{ - "pkg": g.pkg, - "imports": importsCode, - "vars": varsCode, - "types": typesCode, - "new": newCode, - "insert": insertCode, - "find": strings.Join(findCode, "\n"), - "update": updateCode, - "delete": deleteCode, - "extraMethod": ret.cacheExtra, - }) + code := &code{ + importsCode: importsCode, + varsCode: varsCode, + typesCode: typesCode, + newCode: newCode, + insertCode: insertCode, + findCode: findCode, + updateCode: updateCode, + deleteCode: deleteCode, + cacheExtra: ret.cacheExtra, + } + + output, err := g.executeModel(code) if err != nil { return "", err } @@ -280,6 +286,32 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er return output.String(), nil } +func (g *defaultGenerator) executeModel(code *code) (*bytes.Buffer, error) { + text, err := util.LoadTemplate(category, modelTemplateFile, template.Model) + if err != nil { + return nil, err + } + t := util.With("model"). + Parse(text). + GoFmt(true) + output, err := t.Execute(map[string]interface{}{ + "pkg": g.pkg, + "imports": code.importsCode, + "vars": code.varsCode, + "types": code.typesCode, + "new": code.newCode, + "insert": code.insertCode, + "find": strings.Join(code.findCode, "\n"), + "update": code.updateCode, + "delete": code.deleteCode, + "extraMethod": code.cacheExtra, + }) + if err != nil { + return nil, err + } + return output, nil +} + func wrapWithRawString(v string) string { if v == "`" { return v diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go index 85bd9bfd..6b5338f2 100644 --- a/tools/goctl/model/sql/parser/parser.go +++ b/tools/goctl/model/sql/parser/parser.go @@ -68,6 +68,71 @@ func Parse(ddl string) (*Table, error) { columns := tableSpec.Columns indexes := tableSpec.Indexes + keyMap, err := getIndexKeyType(indexes) + if err != nil { + return nil, err + } + + fields, primaryKey, err := convertFileds(columns, keyMap) + if err != nil { + return nil, err + } + + return &Table{ + Name: stringx.From(tableName), + PrimaryKey: primaryKey, + Fields: fields, + }, nil +} + +func convertFileds(columns []*sqlparser.ColumnDefinition, keyMap map[string]KeyType) ([]Field, Primary, error) { + var fields []Field + var primaryKey Primary + for _, column := range columns { + if column == nil { + continue + } + var comment string + if column.Type.Comment != nil { + comment = string(column.Type.Comment.Val) + } + var isDefaultNull = true + if column.Type.NotNull { + isDefaultNull = false + } else { + if column.Type.Default == nil { + isDefaultNull = false + } else if string(column.Type.Default.Val) != "null" { + isDefaultNull = false + } + } + dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull) + if err != nil { + return nil, primaryKey, err + } + + var field Field + field.Name = stringx.From(column.Name.String()) + field.DataBaseType = column.Type.Type + field.DataType = dataType + field.Comment = comment + key, ok := keyMap[column.Name.String()] + if ok { + field.IsPrimaryKey = key == primary + field.IsUniqueKey = key == unique + if field.IsPrimaryKey { + primaryKey.Field = field + if column.Type.Autoincrement { + primaryKey.AutoIncrement = true + } + } + } + fields = append(fields, field) + } + return fields, primaryKey, nil +} + +func getIndexKeyType(indexes []*sqlparser.IndexDefinition) (map[string]KeyType, error) { keyMap := make(map[string]KeyType) for _, index := range indexes { info := index.Info @@ -101,56 +166,7 @@ func Parse(ddl string) (*Table, error) { keyMap[columnName] = normal } } - - var fields []Field - var primaryKey Primary - for _, column := range columns { - if column == nil { - continue - } - var comment string - if column.Type.Comment != nil { - comment = string(column.Type.Comment.Val) - } - var isDefaultNull = true - if column.Type.NotNull { - isDefaultNull = false - } else { - if column.Type.Default == nil { - isDefaultNull = false - } else if string(column.Type.Default.Val) != "null" { - isDefaultNull = false - } - } - dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull) - if err != nil { - return nil, err - } - - var field Field - field.Name = stringx.From(column.Name.String()) - field.DataBaseType = column.Type.Type - field.DataType = dataType - field.Comment = comment - key, ok := keyMap[column.Name.String()] - if ok { - field.IsPrimaryKey = key == primary - field.IsUniqueKey = key == unique - if field.IsPrimaryKey { - primaryKey.Field = field - if column.Type.Autoincrement { - primaryKey.AutoIncrement = true - } - } - } - fields = append(fields, field) - } - - return &Table{ - Name: stringx.From(tableName), - PrimaryKey: primaryKey, - Fields: fields, - }, nil + return keyMap, nil } func (t *Table) ContainsTime() bool { diff --git a/tools/goctl/util/console/console.go b/tools/goctl/util/console/console.go index e2fb09d5..a6cb2069 100644 --- a/tools/goctl/util/console/console.go +++ b/tools/goctl/util/console/console.go @@ -8,6 +8,9 @@ import ( ) type ( + // Console wraps from the fmt.Sprintf, + // by default, it implemented the colorConsole to provide the colorful output to the consle + // and the ideaConsole to output with prefix for the plugin of intellij Console interface { Success(format string, a ...interface{}) Info(format string, a ...interface{}) @@ -25,6 +28,7 @@ type ( } ) +// NewConsole returns a instance of Console func NewConsole(idea bool) Console { if idea { return NewIdeaConsole() @@ -32,7 +36,8 @@ func NewConsole(idea bool) Console { return NewColorConsole() } -func NewColorConsole() *colorConsole { +// NewColorConsole returns a instance of colorConsole +func NewColorConsole() Console { return &colorConsole{} } @@ -76,7 +81,8 @@ func (c *colorConsole) Must(err error) { } } -func NewIdeaConsole() *ideaConsole { +// NewIdeaConsole returns a instace of ideaConsole +func NewIdeaConsole() Console { return &ideaConsole{} } diff --git a/tools/goctl/util/ctx/context.go b/tools/goctl/util/ctx/context.go index 68ebc2c4..aab06e5b 100644 --- a/tools/goctl/util/ctx/context.go +++ b/tools/goctl/util/ctx/context.go @@ -9,6 +9,8 @@ import ( var errModuleCheck = errors.New("the work directory must be found in the go mod or the $GOPATH") +// ProjectContext is a structure for the project, +// which contains WorkDir, Name, Path and Dir type ProjectContext struct { WorkDir string // Name is the root name of the project diff --git a/tools/goctl/util/ctx/gomod.go b/tools/goctl/util/ctx/gomod.go index c073e449..db9eeba8 100644 --- a/tools/goctl/util/ctx/gomod.go +++ b/tools/goctl/util/ctx/gomod.go @@ -9,6 +9,8 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" ) +// Module contains the relative data of go module, +// which is the result of the command go list type Module struct { Path string Main bool diff --git a/tools/goctl/util/file.go b/tools/goctl/util/file.go index cadb5980..11f777b1 100644 --- a/tools/goctl/util/file.go +++ b/tools/goctl/util/file.go @@ -10,9 +10,7 @@ import ( "github.com/logrusorgru/aurora" ) -const ( - NL = "\n" -) +const NL = "\n" func CreateIfNotExist(file string) (*os.File, error) { _, err := os.Stat(file) diff --git a/tools/goctl/util/format/format.go b/tools/goctl/util/format/format.go index 51ef965c..039f7e81 100644 --- a/tools/goctl/util/format/format.go +++ b/tools/goctl/util/format/format.go @@ -18,6 +18,7 @@ const ( upper ) +// ErrNamingFormat defines an error for unknown fomat var ErrNamingFormat = errors.New("unsupported format") type ( diff --git a/tools/goctl/util/name/naming.go b/tools/goctl/util/name/naming.go index 30cd67a8..61baf5c3 100644 --- a/tools/goctl/util/name/naming.go +++ b/tools/goctl/util/name/naming.go @@ -1,3 +1,5 @@ +// Package name provides methods to verify naming style and format naming style +// See the method IsNamingValid, FormatFilename package name import ( @@ -6,11 +8,15 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) +// NamingStyle the type of string type NamingStyle = string const ( + // NamingLower defines the lower spell case NamingLower NamingStyle = "lower" + // NamingCamel defines the camel spell case NamingCamel NamingStyle = "camel" + // NamingSnake defines the snake spell case NamingSnake NamingStyle = "snake" ) @@ -29,6 +35,8 @@ func IsNamingValid(namingStyle string) (NamingStyle, bool) { } } +// FormatFilename converts the filename string to the target +// naming style by calling method of stringx func FormatFilename(filename string, style NamingStyle) string { switch style { case NamingCamel: diff --git a/tools/goctl/util/stringx/string.go b/tools/goctl/util/stringx/string.go index e93fd9d1..001e4bdd 100644 --- a/tools/goctl/util/stringx/string.go +++ b/tools/goctl/util/stringx/string.go @@ -6,14 +6,17 @@ import ( "unicode" ) +// String provides for coverting the source text into other spell case,like lower,snake,camel type String struct { source string } +// From converts the input text to String and returns it func From(data string) String { return String{source: data} } +// IsEmptyOrSpace returns true if the length of the string value is 0 after call strings.TrimSpace, or else returns false func (s String) IsEmptyOrSpace() bool { if len(s.source) == 0 { return true @@ -24,18 +27,22 @@ func (s String) IsEmptyOrSpace() bool { return false } +// Lower calls the strings.ToLower func (s String) Lower() string { return strings.ToLower(s.source) } +// ReplaceAll calls the strings.ReplaceAll func (s String) ReplaceAll(old, new string) string { return strings.ReplaceAll(s.source, old, new) } +//Source returns the source string value func (s String) Source() string { return s.source } +// Title calls the strings.Title func (s String) Title() string { if s.IsEmptyOrSpace() { return s.source @@ -43,7 +50,7 @@ func (s String) Title() string { return strings.Title(s.source) } -// snake->camel(upper start) +// ToCamel converts the input text into camel case func (s String) ToCamel() string { list := s.splitBy(func(r rune) bool { return r == '_' @@ -55,7 +62,7 @@ func (s String) ToCamel() string { return strings.Join(target, "") } -// camel->snake +// ToSnake converts the input text into snake case func (s String) ToSnake() string { list := s.splitBy(unicode.IsUpper, false) var target []string @@ -65,7 +72,7 @@ func (s String) ToSnake() string { return strings.Join(target, "_") } -// return original string if rune is not letter at index 0 +// Untitle return the original string if rune is not letter at index 0 func (s String) Untitle() string { if s.IsEmptyOrSpace() { return s.source @@ -77,10 +84,6 @@ func (s String) Untitle() string { return string(unicode.ToLower(r)) + s.source[1:] } -func (s String) Upper() string { - return strings.ToUpper(s.source) -} - // it will not ignore spaces func (s String) splitBy(fn func(r rune) bool, remove bool) []string { if s.IsEmptyOrSpace() { diff --git a/tools/goctl/util/templatex.go b/tools/goctl/util/templatex.go index 461b2aca..3feae08a 100644 --- a/tools/goctl/util/templatex.go +++ b/tools/goctl/util/templatex.go @@ -9,29 +9,35 @@ import ( const regularPerm = 0666 -type defaultTemplate struct { +// DefaultTemplate is a tool to provides the text/template operations +type DefaultTemplate struct { name string text string goFmt bool savePath string } -func With(name string) *defaultTemplate { - return &defaultTemplate{ +// With returns a instace of DefaultTemplate +func With(name string) *DefaultTemplate { + return &DefaultTemplate{ name: name, } } -func (t *defaultTemplate) Parse(text string) *defaultTemplate { + +// Parse accepts a source template and returns DefaultTemplate +func (t *DefaultTemplate) Parse(text string) *DefaultTemplate { t.text = text return t } -func (t *defaultTemplate) GoFmt(format bool) *defaultTemplate { +// GoFmt sets the value to goFmt and marks the generated codes will be formated or not +func (t *DefaultTemplate) GoFmt(format bool) *DefaultTemplate { t.goFmt = format return t } -func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error { +// SaveTo writes the codes to the target path +func (t *DefaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error { if FileExists(path) && !forceUpdate { return nil } @@ -44,7 +50,8 @@ func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool return ioutil.WriteFile(path, output.Bytes(), regularPerm) } -func (t *defaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) { +// Execute returns the codes after the template executed +func (t *DefaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) { tem, err := template.New(t.name).Parse(t.text) if err != nil { return nil, err diff --git a/tools/goctl/vars/settings.go b/tools/goctl/vars/settings.go index 9b11c60e..fdc18fc3 100644 --- a/tools/goctl/vars/settings.go +++ b/tools/goctl/vars/settings.go @@ -1,9 +1,14 @@ package vars const ( - ProjectName = "zero" - ProjectOpenSourceUrl = "github.com/tal-tech/go-zero" - OsWindows = "windows" - OsMac = "darwin" - OsLinux = "linux" + // ProjectName the const value of zero + ProjectName = "zero" + // ProjectOpenSourceURL the githb url of go-zero + ProjectOpenSourceURL = "github.com/tal-tech/go-zero" + // OsWindows windows os + OsWindows = "windows" + // OsMac mac os + OsMac = "darwin" + // OsLinux linux os + OsLinux = "linux" )