Code optimized (#493)

This commit is contained in:
kingxt 2021-02-20 19:50:03 +08:00 committed by GitHub
parent 059027bc9d
commit f98c9246b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 472 additions and 372 deletions

View File

@ -39,6 +39,7 @@ service {{.serviceName}} {
} }
` `
// ApiCommand create api template file
func ApiCommand(c *cli.Context) error { func ApiCommand(c *cli.Context) error {
apiFile := c.String("o") apiFile := c.String("o")
if len(apiFile) == 0 { if len(apiFile) == 0 {

View File

@ -9,6 +9,7 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
) )
// DartCommand create dart network request code
func DartCommand(c *cli.Context) error { func DartCommand(c *cli.Context) error {
apiFile := c.String("api") apiFile := c.String("api")
dir := c.String("dir") dir := c.String("dir")

View File

@ -2,9 +2,9 @@ package dartgen
import ( import (
"os" "os"
"reflect"
"strings" "strings"
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
) )
@ -32,10 +32,18 @@ func pathToFuncName(path string) string {
return util.ToLower(camel[:1]) + camel[1:] return util.ToLower(camel[:1]) + camel[1:]
} }
func tagGet(tag, k string) (reflect.Value, error) { func tagGet(tag, k string) string {
v, _ := util.TagLookup(tag, k) tags, err := spec.Parse(tag)
out := strings.Split(v, ",")[0] if err != nil {
return reflect.ValueOf(out), nil panic(k + " not exist")
}
v, err := tags.Get(k)
if err != nil {
panic(k + " value not exist")
}
return v.Name
} }
func isDirectType(s string) bool { func isDirectType(s string) bool {

View File

@ -12,6 +12,7 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
) )
// DocCommand generate markdown doc file
func DocCommand(c *cli.Context) error { func DocCommand(c *cli.Context) error {
dir := c.String("dir") dir := c.String("dir")
if len(dir) == 0 { if len(dir) == 0 {

View File

@ -160,6 +160,7 @@ type Response struct {
@server( @server(
jwt: Auth jwt: Auth
signature: true
) )
service A-api { service A-api {
@handler GreetHandler @handler GreetHandler

View File

@ -40,7 +40,7 @@ func genConfig(dir string, cfg *config.Config, api *spec.ApiSpec) error {
for _, item := range authNames { for _, item := range authNames {
auths = append(auths, fmt.Sprintf("%s %s", item, jwtTemplate)) auths = append(auths, fmt.Sprintf("%s %s", item, jwtTemplate))
} }
var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl) var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL)
return genFile(fileGenConfig{ return genFile(fileGenConfig{
dir: dir, dir: dir,

View File

@ -109,7 +109,7 @@ func genHandlerImports(group spec.Group, route spec.Route, parentPkg string) str
if len(route.RequestTypeName()) > 0 { if len(route.RequestTypeName()) > 0 {
imports = append(imports, fmt.Sprintf("\"%s\"\n", util.JoinPackages(parentPkg, typesDir))) imports = append(imports, fmt.Sprintf("\"%s\"\n", util.JoinPackages(parentPkg, typesDir)))
} }
imports = append(imports, fmt.Sprintf("\"%s/rest/httpx\"", vars.ProjectOpenSourceUrl)) imports = append(imports, fmt.Sprintf("\"%s/rest/httpx\"", vars.ProjectOpenSourceURL))
return strings.Join(imports, "\n\t") return strings.Join(imports, "\n\t")
} }

View File

@ -122,6 +122,6 @@ func genLogicImports(route spec.Route, parentPkg string) string {
if len(route.ResponseTypeName()) > 0 || len(route.RequestTypeName()) > 0 { if len(route.ResponseTypeName()) > 0 || len(route.RequestTypeName()) > 0 {
imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir))) imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir)))
} }
imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceUrl)) imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL))
return strings.Join(imports, "\n\t") return strings.Join(imports, "\n\t")
} }

View File

@ -74,7 +74,7 @@ func genMainImports(parentPkg string) string {
imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, configDir))) imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, configDir)))
imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, handlerDir))) imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, handlerDir)))
imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, contextDir))) imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, contextDir)))
imports = append(imports, fmt.Sprintf("\"%s/core/conf\"", vars.ProjectOpenSourceUrl)) imports = append(imports, fmt.Sprintf("\"%s/core/conf\"", vars.ProjectOpenSourceURL))
imports = append(imports, fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl)) imports = append(imports, fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL))
return strings.Join(imports, "\n\t") return strings.Join(imports, "\n\t")
} }

View File

@ -89,7 +89,7 @@ func genRoutes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
} }
var signature string var signature string
if g.signatureEnabled { if g.signatureEnabled {
signature = fmt.Sprintf("\n rest.WithSignature(serverCtx.Config.%s.Signature),", g.authName) signature = "\n rest.WithSignature(serverCtx.Config.Signature),"
} }
var routes string var routes string
@ -163,7 +163,7 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
imports := importSet.KeysStr() imports := importSet.KeysStr()
sort.Strings(imports) sort.Strings(imports)
projectSection := strings.Join(imports, "\n\t") projectSection := strings.Join(imports, "\n\t")
depSection := fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl) depSection := fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL)
return fmt.Sprintf("%s\n\n\t%s", projectSection, depSection) return fmt.Sprintf("%s\n\n\t%s", projectSection, depSection)
} }
@ -196,6 +196,10 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
groupedRoutes.authName = jwt groupedRoutes.authName = jwt
groupedRoutes.jwtEnabled = true groupedRoutes.jwtEnabled = true
} }
signature := g.GetAnnotation("signature")
if signature == "true" {
groupedRoutes.signatureEnabled = true
}
middleware := g.GetAnnotation("middleware") middleware := g.GetAnnotation("middleware")
if len(middleware) > 0 { if len(middleware) > 0 {
for _, item := range strings.Split(middleware, ",") { for _, item := range strings.Split(middleware, ",") {

View File

@ -64,7 +64,7 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\"" var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
if len(middlewareStr) > 0 { if len(middlewareStr) > 0 {
configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\"" configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\""
configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceUrl) configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceURL)
} }
return genFile(fileGenConfig{ return genFile(fileGenConfig{

View File

@ -94,10 +94,6 @@ func getAuths(api *spec.ApiSpec) []string {
if len(jwt) > 0 { if len(jwt) > 0 {
authNames.Add(jwt) authNames.Add(jwt)
} }
signature := g.GetAnnotation("signature")
if len(signature) > 0 {
authNames.Add(signature)
}
} }
return authNames.KeysStr() return authNames.KeysStr()
} }

View File

@ -22,12 +22,6 @@ type Api struct {
} }
func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} { func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
defer func() {
if p := recover(); p != nil {
panic(fmt.Errorf("%+v", p))
}
}()
var final Api var final Api
final.importM = map[string]PlaceHolder{} final.importM = map[string]PlaceHolder{}
final.typeM = map[string]PlaceHolder{} final.typeM = map[string]PlaceHolder{}
@ -36,109 +30,128 @@ func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
final.routeM = map[string]PlaceHolder{} final.routeM = map[string]PlaceHolder{}
for _, each := range ctx.AllSpec() { for _, each := range ctx.AllSpec() {
root := each.Accept(v).(*Api) root := each.Accept(v).(*Api)
if root.Syntax != nil { v.acceptSyntax(root, &final)
if final.Syntax != nil { v.accetpImport(root, &final)
v.panic(root.Syntax.Syntax, fmt.Sprintf("mutiple syntax declaration")) v.acceptInfo(root, &final)
} v.acceptType(root, &final)
v.acceptService(root, &final)
final.Syntax = root.Syntax
}
for _, imp := range root.Import {
if _, ok := final.importM[imp.Value.Text()]; ok {
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))
}
final.importM[imp.Value.Text()] = Holder
final.Import = append(final.Import, imp)
}
if root.Info != nil {
infoM := map[string]PlaceHolder{}
if final.Info != nil {
v.panic(root.Info.Info, fmt.Sprintf("mutiple info declaration"))
}
for _, value := range root.Info.Kvs {
if _, ok := infoM[value.Key.Text()]; ok {
v.panic(value.Key, fmt.Sprintf("duplicate key '%s'", value.Key.Text()))
}
infoM[value.Key.Text()] = Holder
}
final.Info = root.Info
}
for _, tp := range root.Type {
if _, ok := final.typeM[tp.NameExpr().Text()]; ok {
v.panic(tp.NameExpr(), fmt.Sprintf("duplicate type '%s'", tp.NameExpr().Text()))
}
final.typeM[tp.NameExpr().Text()] = Holder
final.Type = append(final.Type, tp)
}
for _, service := range root.Service {
if _, ok := final.serviceM[service.ServiceApi.Name.Text()]; !ok && len(final.serviceM) > 0 {
v.panic(service.ServiceApi.Name, fmt.Sprintf("mutiple service declaration"))
}
if service.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range service.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
atServerM[kv.Key.Text()] = Holder
}
}
for _, route := range service.ServiceApi.ServiceRoute {
uniqueRoute := fmt.Sprintf("%s %s", route.Route.Method.Text(), route.Route.Path.Text())
if _, ok := final.routeM[uniqueRoute]; ok {
v.panic(route.Route.Method, fmt.Sprintf("duplicate route '%s'", uniqueRoute))
}
final.routeM[uniqueRoute] = Holder
var handlerExpr Expr
if route.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range route.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
atServerM[kv.Key.Text()] = Holder
if kv.Key.Text() == "handler" {
handlerExpr = kv.Value
}
}
}
if route.AtHandler != nil {
handlerExpr = route.AtHandler.Name
}
if handlerExpr == nil {
v.panic(route.Route.Method, fmt.Sprintf("mismtached handler"))
}
if handlerExpr.Text() == "" {
v.panic(handlerExpr, fmt.Sprintf("mismtached handler"))
}
if _, ok := final.handlerM[handlerExpr.Text()]; ok {
v.panic(handlerExpr, fmt.Sprintf("duplicate handler '%s'", handlerExpr.Text()))
}
final.handlerM[handlerExpr.Text()] = Holder
}
final.Service = append(final.Service, service)
}
} }
return &final return &final
} }
func (v *ApiVisitor) acceptService(root *Api, final *Api) {
for _, service := range root.Service {
if _, ok := final.serviceM[service.ServiceApi.Name.Text()]; !ok && len(final.serviceM) > 0 {
v.panic(service.ServiceApi.Name, fmt.Sprintf("mutiple service declaration"))
}
v.duplicateServerItemCheck(service)
for _, route := range service.ServiceApi.ServiceRoute {
uniqueRoute := fmt.Sprintf("%s %s", route.Route.Method.Text(), route.Route.Path.Text())
if _, ok := final.routeM[uniqueRoute]; ok {
v.panic(route.Route.Method, fmt.Sprintf("duplicate route '%s'", uniqueRoute))
}
final.routeM[uniqueRoute] = Holder
var handlerExpr Expr
if route.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range route.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
atServerM[kv.Key.Text()] = Holder
if kv.Key.Text() == "handler" {
handlerExpr = kv.Value
}
}
}
if route.AtHandler != nil {
handlerExpr = route.AtHandler.Name
}
if handlerExpr == nil {
v.panic(route.Route.Method, fmt.Sprintf("mismtached handler"))
}
if handlerExpr.Text() == "" {
v.panic(handlerExpr, fmt.Sprintf("mismtached handler"))
}
if _, ok := final.handlerM[handlerExpr.Text()]; ok {
v.panic(handlerExpr, fmt.Sprintf("duplicate handler '%s'", handlerExpr.Text()))
}
final.handlerM[handlerExpr.Text()] = Holder
}
final.Service = append(final.Service, service)
}
}
func (v *ApiVisitor) duplicateServerItemCheck(service *Service) {
if service.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range service.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
atServerM[kv.Key.Text()] = Holder
}
}
}
func (v *ApiVisitor) acceptType(root *Api, final *Api) {
for _, tp := range root.Type {
if _, ok := final.typeM[tp.NameExpr().Text()]; ok {
v.panic(tp.NameExpr(), fmt.Sprintf("duplicate type '%s'", tp.NameExpr().Text()))
}
final.typeM[tp.NameExpr().Text()] = Holder
final.Type = append(final.Type, tp)
}
}
func (v *ApiVisitor) acceptInfo(root *Api, final *Api) {
if root.Info != nil {
infoM := map[string]PlaceHolder{}
if final.Info != nil {
v.panic(root.Info.Info, fmt.Sprintf("mutiple info declaration"))
}
for _, value := range root.Info.Kvs {
if _, ok := infoM[value.Key.Text()]; ok {
v.panic(value.Key, fmt.Sprintf("duplicate key '%s'", value.Key.Text()))
}
infoM[value.Key.Text()] = Holder
}
final.Info = root.Info
}
}
func (v *ApiVisitor) accetpImport(root *Api, final *Api) {
for _, imp := range root.Import {
if _, ok := final.importM[imp.Value.Text()]; ok {
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))
}
final.importM[imp.Value.Text()] = Holder
final.Import = append(final.Import, imp)
}
}
func (v *ApiVisitor) acceptSyntax(root *Api, final *Api) {
if root.Syntax != nil {
if final.Syntax != nil {
v.panic(root.Syntax.Syntax, fmt.Sprintf("mutiple syntax declaration"))
}
final.Syntax = root.Syntax
}
}
func (v *ApiVisitor) VisitSpec(ctx *api.SpecContext) interface{} { func (v *ApiVisitor) VisitSpec(ctx *api.SpecContext) interface{} {
var root Api var root Api
if ctx.SyntaxLit() != nil { if ctx.SyntaxLit() != nil {

View File

@ -156,28 +156,9 @@ func (p *Parser) invoke(linePrefix, content string) (v *Api, err error) {
} }
func (p *Parser) valid(mainApi *Api, nestedApi *Api) error { func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
if len(nestedApi.Import) > 0 { err := p.nestedApiCheck(mainApi, nestedApi)
importToken := nestedApi.Import[0].Import if err != nil {
return fmt.Errorf("%s line %d:%d the nested api does not support import", return err
nestedApi.LinePrefix, importToken.Line(), importToken.Column())
}
if mainApi.Syntax != nil && nestedApi.Syntax != nil {
if mainApi.Syntax.Version.Text() != nestedApi.Syntax.Version.Text() {
syntaxToken := nestedApi.Syntax.Syntax
return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'",
nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), mainApi.Syntax.Version.Text(), nestedApi.Syntax.Version.Text())
}
}
if len(mainApi.Service) > 0 {
mainService := mainApi.Service[0]
for _, service := range nestedApi.Service {
if mainService.ServiceApi.Name.Text() != service.ServiceApi.Name.Text() {
return fmt.Errorf("%s multiple service name declaration, expecting service name '%s', but found '%s'",
nestedApi.LinePrefix, mainService.ServiceApi.Name.Text(), service.ServiceApi.Name.Text())
}
}
} }
mainHandlerMap := make(map[string]PlaceHolder) mainHandlerMap := make(map[string]PlaceHolder)
@ -218,6 +199,23 @@ func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
} }
// duplicate route check // duplicate route check
err = p.duplicateRouteCheck(nestedApi, mainHandlerMap, mainRouteMap)
if err != nil {
return err
}
// duplicate type check
for _, each := range nestedApi.Type {
if _, ok := mainTypeMap[each.NameExpr().Text()]; ok {
return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'",
nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text())
}
}
return nil
}
func (p *Parser) duplicateRouteCheck(nestedApi *Api, mainHandlerMap map[string]PlaceHolder, mainRouteMap map[string]PlaceHolder) error {
for _, each := range nestedApi.Service { for _, each := range nestedApi.Service {
for _, r := range each.ServiceApi.ServiceRoute { for _, r := range each.ServiceApi.ServiceRoute {
handler := r.GetHandler() handler := r.GetHandler()
@ -237,12 +235,31 @@ func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
} }
} }
} }
return nil
}
// duplicate type check func (p *Parser) nestedApiCheck(mainApi *Api, nestedApi *Api) error {
for _, each := range nestedApi.Type { if len(nestedApi.Import) > 0 {
if _, ok := mainTypeMap[each.NameExpr().Text()]; ok { importToken := nestedApi.Import[0].Import
return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'", return fmt.Errorf("%s line %d:%d the nested api does not support import",
nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text()) nestedApi.LinePrefix, importToken.Line(), importToken.Column())
}
if mainApi.Syntax != nil && nestedApi.Syntax != nil {
if mainApi.Syntax.Version.Text() != nestedApi.Syntax.Version.Text() {
syntaxToken := nestedApi.Syntax.Syntax
return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'",
nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), mainApi.Syntax.Version.Text(), nestedApi.Syntax.Version.Text())
}
}
if len(mainApi.Service) > 0 {
mainService := mainApi.Service[0]
for _, service := range nestedApi.Service {
if mainService.ServiceApi.Name.Text() != service.ServiceApi.Name.Text() {
return fmt.Errorf("%s multiple service name declaration, expecting service name '%s', but found '%s'",
nestedApi.LinePrefix, mainService.ServiceApi.Name.Text(), service.ServiceApi.Name.Text())
}
} }
} }
return nil return nil
@ -276,56 +293,80 @@ func (p *Parser) checkTypeDeclaration(apiList []*Api) error {
for _, apiItem := range apiList { for _, apiItem := range apiList {
linePrefix := apiItem.LinePrefix linePrefix := apiItem.LinePrefix
for _, each := range apiItem.Type { err := p.checkTypes(apiItem, linePrefix, types)
tp, ok := each.(*TypeStruct) if err != nil {
if !ok { return err
continue }
err = p.checkServices(apiItem, types, linePrefix)
if err != nil {
return err
}
}
return nil
}
func (p *Parser) checkServices(apiItem *Api, types map[string]TypeExpr, linePrefix string) error {
for _, service := range apiItem.Service {
for _, each := range service.ServiceApi.ServiceRoute {
route := each.Route
err := p.checkRequestBody(route, types, linePrefix)
if err != nil {
return err
} }
for _, member := range tp.Fields { if route.Reply != nil && route.Reply.Name.IsNotNil() && route.Reply.Name.Expr().IsNotNil() {
err := p.checkType(linePrefix, types, member.DataType) reply := route.Reply.Name
if err != nil { var structName string
return err switch tp := reply.(type) {
case *Literal:
structName = tp.Literal.Text()
case *Array:
switch innerTp := tp.Literal.(type) {
case *Literal:
structName = innerTp.Literal.Text()
case *Pointer:
structName = innerTp.Name.Text()
}
}
if api.IsBasicType(structName) {
continue
}
_, ok := types[structName]
if !ok {
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, route.Reply.Name.Expr().Line(), route.Reply.Name.Expr().Column(), structName)
} }
} }
} }
}
return nil
}
for _, service := range apiItem.Service { func (p *Parser) checkRequestBody(route *Route, types map[string]TypeExpr, linePrefix string) error {
for _, each := range service.ServiceApi.ServiceRoute { if route.Req != nil && route.Req.Name.IsNotNil() && route.Req.Name.Expr().IsNotNil() {
route := each.Route _, ok := types[route.Req.Name.Expr().Text()]
if route.Req != nil && route.Req.Name.IsNotNil() && route.Req.Name.Expr().IsNotNil() { if !ok {
_, ok := types[route.Req.Name.Expr().Text()] return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
if !ok { linePrefix, route.Req.Name.Expr().Line(), route.Req.Name.Expr().Column(), route.Req.Name.Expr().Text())
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context", }
linePrefix, route.Req.Name.Expr().Line(), route.Req.Name.Expr().Column(), route.Req.Name.Expr().Text()) }
} return nil
} }
if route.Reply != nil && route.Reply.Name.IsNotNil() && route.Reply.Name.Expr().IsNotNil() { func (p *Parser) checkTypes(apiItem *Api, linePrefix string, types map[string]TypeExpr) error {
reply := route.Reply.Name for _, each := range apiItem.Type {
var structName string tp, ok := each.(*TypeStruct)
switch tp := reply.(type) { if !ok {
case *Literal: continue
structName = tp.Literal.Text() }
case *Array:
switch innerTp := tp.Literal.(type) {
case *Literal:
structName = innerTp.Literal.Text()
case *Pointer:
structName = innerTp.Name.Text()
}
}
if api.IsBasicType(structName) { for _, member := range tp.Fields {
continue err := p.checkType(linePrefix, types, member.DataType)
} if err != nil {
return err
_, ok := types[structName]
if !ok {
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, route.Reply.Name.Expr().Line(), route.Reply.Name.Expr().Column(), structName)
}
}
} }
} }
} }

View File

@ -213,13 +213,7 @@ func (p parser) fillService() error {
var groups []spec.Group var groups []spec.Group
for _, item := range p.ast.Service { for _, item := range p.ast.Service {
var group spec.Group var group spec.Group
if item.AtServer != nil { p.fillAtServer(item, &group)
var properties = make(map[string]string, 0)
for _, kv := range item.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}
group.Annotation.Properties = properties
}
for _, astRoute := range item.ServiceApi.ServiceRoute { for _, astRoute := range item.ServiceApi.ServiceRoute {
route := spec.Route{ route := spec.Route{
@ -231,25 +225,9 @@ func (p parser) fillService() error {
route.Handler = astRoute.AtHandler.Name.Text() route.Handler = astRoute.AtHandler.Name.Text()
} }
if astRoute.AtServer != nil { err := p.fillRouteAtServer(astRoute, &route)
var properties = make(map[string]string, 0) if err != nil {
for _, kv := range astRoute.AtServer.Kv { return err
properties[kv.Key.Text()] = kv.Value.Text()
}
route.Annotation.Properties = properties
if len(route.Handler) == 0 {
route.Handler = properties["handler"]
}
if len(route.Handler) == 0 {
return fmt.Errorf("missing handler annotation for %q", route.Path)
}
for _, char := range route.Handler {
if !unicode.IsDigit(char) && !unicode.IsLetter(char) {
return fmt.Errorf("route [%s] handler [%s] invalid, handler name should only contains letter or digit",
route.Path, route.Handler)
}
}
} }
if astRoute.Route.Req != nil { if astRoute.Route.Req != nil {
@ -269,7 +247,7 @@ func (p parser) fillService() error {
} }
} }
err := p.fillRouteType(&route) err = p.fillRouteType(&route)
if err != nil { if err != nil {
return err return err
} }
@ -289,6 +267,40 @@ func (p parser) fillService() error {
return nil return nil
} }
func (p parser) fillRouteAtServer(astRoute *ast.ServiceRoute, route *spec.Route) error {
if astRoute.AtServer != nil {
var properties = make(map[string]string, 0)
for _, kv := range astRoute.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}
route.Annotation.Properties = properties
if len(route.Handler) == 0 {
route.Handler = properties["handler"]
}
if len(route.Handler) == 0 {
return fmt.Errorf("missing handler annotation for %q", route.Path)
}
for _, char := range route.Handler {
if !unicode.IsDigit(char) && !unicode.IsLetter(char) {
return fmt.Errorf("route [%s] handler [%s] invalid, handler name should only contains letter or digit",
route.Path, route.Handler)
}
}
}
return nil
}
func (p parser) fillAtServer(item *ast.Service, group *spec.Group) {
if item.AtServer != nil {
var properties = make(map[string]string, 0)
for _, kv := range item.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}
group.Annotation.Properties = properties
}
}
func (p parser) fillRouteType(route *spec.Route) error { func (p parser) fillRouteType(route *spec.Route) error {
if route.RequestType != nil { if route.RequestType != nil {
switch route.RequestType.(type) { switch route.RequestType.(type) {

View File

@ -1,58 +0,0 @@
package util
import (
"strconv"
"strings"
)
func TagLookup(tag, key string) (value string, ok bool) {
tag = strings.Replace(tag, "`", "", -1)
for tag != "" {
// Skip leading space.
i := 0
for i < len(tag) && tag[i] == ' ' {
i++
}
tag = tag[i:]
if tag == "" {
break
}
// Scan to colon. A space, a quote or a control character is a syntax error.
// Strictly speaking, control chars include the range [0x7f, 0x9f], not just
// [0x00, 0x1f], but in practice, we ignore the multi-byte control characters
// as it is simpler to inspect the tag's bytes than the tag's runes.
i = 0
for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f {
i++
}
if i == 0 || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' {
break
}
name := string(tag[:i])
tag = tag[i+1:]
// Scan quoted string to find value.
i = 1
for i < len(tag) && tag[i] != '"' {
if tag[i] == '\\' {
i++
}
i++
}
if i >= len(tag) {
break
}
qvalue := string(tag[:i+1])
tag = tag[i+1:]
if key == name {
value, err := strconv.Unquote(qvalue)
if err != nil {
break
}
return value, true
}
}
return "", false
}

View File

@ -28,7 +28,7 @@ import (
) )
var ( var (
BuildVersion = "1.1.5" buildVersion = "1.1.5"
commands = []cli.Command{ commands = []cli.Command{
{ {
Name: "upgrade", Name: "upgrade",
@ -510,7 +510,7 @@ func main() {
app := cli.NewApp() app := cli.NewApp()
app.Usage = "a cli tool to generate code" app.Usage = "a cli tool to generate code"
app.Version = fmt.Sprintf("%s %s/%s", BuildVersion, runtime.GOOS, runtime.GOARCH) app.Version = fmt.Sprintf("%s %s/%s", buildVersion, runtime.GOOS, runtime.GOARCH)
app.Commands = commands app.Commands = commands
// cli already print error messages // cli already print error messages
if err := app.Run(os.Args); err != nil { if err := app.Run(os.Args); err != nil {

View File

@ -1,6 +1,7 @@
package gen package gen
import ( import (
"bytes"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -31,7 +32,20 @@ type (
pkg string pkg string
cfg *config.Config cfg *config.Config
} }
Option func(generator *defaultGenerator) Option func(generator *defaultGenerator)
code struct {
importsCode string
varsCode string
typesCode string
newCode string
insertCode string
findCode []string
updateCode string
deleteCode string
cacheExtra string
}
) )
func NewDefaultGenerator(dir string, cfg *config.Config, opt ...Option) (*defaultGenerator, error) { func NewDefaultGenerator(dir string, cfg *config.Config, opt ...Option) (*defaultGenerator, error) {
@ -186,15 +200,6 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return "", fmt.Errorf("table %s: missing primary key", in.Name.Source()) return "", fmt.Errorf("table %s: missing primary key", in.Name.Source())
} }
text, err := util.LoadTemplate(category, modelTemplateFile, template.Model)
if err != nil {
return "", err
}
t := util.With("model").
Parse(text).
GoFmt(true)
m, err := genCacheKeys(in) m, err := genCacheKeys(in)
if err != nil { if err != nil {
return "", err return "", err
@ -261,18 +266,19 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return "", err return "", err
} }
output, err := t.Execute(map[string]interface{}{ code := &code{
"pkg": g.pkg, importsCode: importsCode,
"imports": importsCode, varsCode: varsCode,
"vars": varsCode, typesCode: typesCode,
"types": typesCode, newCode: newCode,
"new": newCode, insertCode: insertCode,
"insert": insertCode, findCode: findCode,
"find": strings.Join(findCode, "\n"), updateCode: updateCode,
"update": updateCode, deleteCode: deleteCode,
"delete": deleteCode, cacheExtra: ret.cacheExtra,
"extraMethod": ret.cacheExtra, }
})
output, err := g.executeModel(code)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -280,6 +286,32 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return output.String(), nil return output.String(), nil
} }
func (g *defaultGenerator) executeModel(code *code) (*bytes.Buffer, error) {
text, err := util.LoadTemplate(category, modelTemplateFile, template.Model)
if err != nil {
return nil, err
}
t := util.With("model").
Parse(text).
GoFmt(true)
output, err := t.Execute(map[string]interface{}{
"pkg": g.pkg,
"imports": code.importsCode,
"vars": code.varsCode,
"types": code.typesCode,
"new": code.newCode,
"insert": code.insertCode,
"find": strings.Join(code.findCode, "\n"),
"update": code.updateCode,
"delete": code.deleteCode,
"extraMethod": code.cacheExtra,
})
if err != nil {
return nil, err
}
return output, nil
}
func wrapWithRawString(v string) string { func wrapWithRawString(v string) string {
if v == "`" { if v == "`" {
return v return v

View File

@ -68,6 +68,71 @@ func Parse(ddl string) (*Table, error) {
columns := tableSpec.Columns columns := tableSpec.Columns
indexes := tableSpec.Indexes indexes := tableSpec.Indexes
keyMap, err := getIndexKeyType(indexes)
if err != nil {
return nil, err
}
fields, primaryKey, err := convertFileds(columns, keyMap)
if err != nil {
return nil, err
}
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
Fields: fields,
}, nil
}
func convertFileds(columns []*sqlparser.ColumnDefinition, keyMap map[string]KeyType) ([]Field, Primary, error) {
var fields []Field
var primaryKey Primary
for _, column := range columns {
if column == nil {
continue
}
var comment string
if column.Type.Comment != nil {
comment = string(column.Type.Comment.Val)
}
var isDefaultNull = true
if column.Type.NotNull {
isDefaultNull = false
} else {
if column.Type.Default == nil {
isDefaultNull = false
} else if string(column.Type.Default.Val) != "null" {
isDefaultNull = false
}
}
dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull)
if err != nil {
return nil, primaryKey, err
}
var field Field
field.Name = stringx.From(column.Name.String())
field.DataBaseType = column.Type.Type
field.DataType = dataType
field.Comment = comment
key, ok := keyMap[column.Name.String()]
if ok {
field.IsPrimaryKey = key == primary
field.IsUniqueKey = key == unique
if field.IsPrimaryKey {
primaryKey.Field = field
if column.Type.Autoincrement {
primaryKey.AutoIncrement = true
}
}
}
fields = append(fields, field)
}
return fields, primaryKey, nil
}
func getIndexKeyType(indexes []*sqlparser.IndexDefinition) (map[string]KeyType, error) {
keyMap := make(map[string]KeyType) keyMap := make(map[string]KeyType)
for _, index := range indexes { for _, index := range indexes {
info := index.Info info := index.Info
@ -101,56 +166,7 @@ func Parse(ddl string) (*Table, error) {
keyMap[columnName] = normal keyMap[columnName] = normal
} }
} }
return keyMap, nil
var fields []Field
var primaryKey Primary
for _, column := range columns {
if column == nil {
continue
}
var comment string
if column.Type.Comment != nil {
comment = string(column.Type.Comment.Val)
}
var isDefaultNull = true
if column.Type.NotNull {
isDefaultNull = false
} else {
if column.Type.Default == nil {
isDefaultNull = false
} else if string(column.Type.Default.Val) != "null" {
isDefaultNull = false
}
}
dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull)
if err != nil {
return nil, err
}
var field Field
field.Name = stringx.From(column.Name.String())
field.DataBaseType = column.Type.Type
field.DataType = dataType
field.Comment = comment
key, ok := keyMap[column.Name.String()]
if ok {
field.IsPrimaryKey = key == primary
field.IsUniqueKey = key == unique
if field.IsPrimaryKey {
primaryKey.Field = field
if column.Type.Autoincrement {
primaryKey.AutoIncrement = true
}
}
}
fields = append(fields, field)
}
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
Fields: fields,
}, nil
} }
func (t *Table) ContainsTime() bool { func (t *Table) ContainsTime() bool {

View File

@ -8,6 +8,9 @@ import (
) )
type ( type (
// Console wraps from the fmt.Sprintf,
// by default, it implemented the colorConsole to provide the colorful output to the consle
// and the ideaConsole to output with prefix for the plugin of intellij
Console interface { Console interface {
Success(format string, a ...interface{}) Success(format string, a ...interface{})
Info(format string, a ...interface{}) Info(format string, a ...interface{})
@ -25,6 +28,7 @@ type (
} }
) )
// NewConsole returns a instance of Console
func NewConsole(idea bool) Console { func NewConsole(idea bool) Console {
if idea { if idea {
return NewIdeaConsole() return NewIdeaConsole()
@ -32,7 +36,8 @@ func NewConsole(idea bool) Console {
return NewColorConsole() return NewColorConsole()
} }
func NewColorConsole() *colorConsole { // NewColorConsole returns a instance of colorConsole
func NewColorConsole() Console {
return &colorConsole{} return &colorConsole{}
} }
@ -76,7 +81,8 @@ func (c *colorConsole) Must(err error) {
} }
} }
func NewIdeaConsole() *ideaConsole { // NewIdeaConsole returns a instace of ideaConsole
func NewIdeaConsole() Console {
return &ideaConsole{} return &ideaConsole{}
} }

View File

@ -9,6 +9,8 @@ import (
var errModuleCheck = errors.New("the work directory must be found in the go mod or the $GOPATH") var errModuleCheck = errors.New("the work directory must be found in the go mod or the $GOPATH")
// ProjectContext is a structure for the project,
// which contains WorkDir, Name, Path and Dir
type ProjectContext struct { type ProjectContext struct {
WorkDir string WorkDir string
// Name is the root name of the project // Name is the root name of the project

View File

@ -9,6 +9,8 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx" "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
) )
// Module contains the relative data of go module,
// which is the result of the command go list
type Module struct { type Module struct {
Path string Path string
Main bool Main bool

View File

@ -10,9 +10,7 @@ import (
"github.com/logrusorgru/aurora" "github.com/logrusorgru/aurora"
) )
const ( const NL = "\n"
NL = "\n"
)
func CreateIfNotExist(file string) (*os.File, error) { func CreateIfNotExist(file string) (*os.File, error) {
_, err := os.Stat(file) _, err := os.Stat(file)

View File

@ -18,6 +18,7 @@ const (
upper upper
) )
// ErrNamingFormat defines an error for unknown fomat
var ErrNamingFormat = errors.New("unsupported format") var ErrNamingFormat = errors.New("unsupported format")
type ( type (

View File

@ -1,3 +1,5 @@
// Package name provides methods to verify naming style and format naming style
// See the method IsNamingValid, FormatFilename
package name package name
import ( import (
@ -6,11 +8,15 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
// NamingStyle the type of string
type NamingStyle = string type NamingStyle = string
const ( const (
// NamingLower defines the lower spell case
NamingLower NamingStyle = "lower" NamingLower NamingStyle = "lower"
// NamingCamel defines the camel spell case
NamingCamel NamingStyle = "camel" NamingCamel NamingStyle = "camel"
// NamingSnake defines the snake spell case
NamingSnake NamingStyle = "snake" NamingSnake NamingStyle = "snake"
) )
@ -29,6 +35,8 @@ func IsNamingValid(namingStyle string) (NamingStyle, bool) {
} }
} }
// FormatFilename converts the filename string to the target
// naming style by calling method of stringx
func FormatFilename(filename string, style NamingStyle) string { func FormatFilename(filename string, style NamingStyle) string {
switch style { switch style {
case NamingCamel: case NamingCamel:

View File

@ -6,14 +6,17 @@ import (
"unicode" "unicode"
) )
// String provides for coverting the source text into other spell case,like lower,snake,camel
type String struct { type String struct {
source string source string
} }
// From converts the input text to String and returns it
func From(data string) String { func From(data string) String {
return String{source: data} return String{source: data}
} }
// IsEmptyOrSpace returns true if the length of the string value is 0 after call strings.TrimSpace, or else returns false
func (s String) IsEmptyOrSpace() bool { func (s String) IsEmptyOrSpace() bool {
if len(s.source) == 0 { if len(s.source) == 0 {
return true return true
@ -24,18 +27,22 @@ func (s String) IsEmptyOrSpace() bool {
return false return false
} }
// Lower calls the strings.ToLower
func (s String) Lower() string { func (s String) Lower() string {
return strings.ToLower(s.source) return strings.ToLower(s.source)
} }
// ReplaceAll calls the strings.ReplaceAll
func (s String) ReplaceAll(old, new string) string { func (s String) ReplaceAll(old, new string) string {
return strings.ReplaceAll(s.source, old, new) return strings.ReplaceAll(s.source, old, new)
} }
//Source returns the source string value
func (s String) Source() string { func (s String) Source() string {
return s.source return s.source
} }
// Title calls the strings.Title
func (s String) Title() string { func (s String) Title() string {
if s.IsEmptyOrSpace() { if s.IsEmptyOrSpace() {
return s.source return s.source
@ -43,7 +50,7 @@ func (s String) Title() string {
return strings.Title(s.source) return strings.Title(s.source)
} }
// snake->camel(upper start) // ToCamel converts the input text into camel case
func (s String) ToCamel() string { func (s String) ToCamel() string {
list := s.splitBy(func(r rune) bool { list := s.splitBy(func(r rune) bool {
return r == '_' return r == '_'
@ -55,7 +62,7 @@ func (s String) ToCamel() string {
return strings.Join(target, "") return strings.Join(target, "")
} }
// camel->snake // ToSnake converts the input text into snake case
func (s String) ToSnake() string { func (s String) ToSnake() string {
list := s.splitBy(unicode.IsUpper, false) list := s.splitBy(unicode.IsUpper, false)
var target []string var target []string
@ -65,7 +72,7 @@ func (s String) ToSnake() string {
return strings.Join(target, "_") return strings.Join(target, "_")
} }
// return original string if rune is not letter at index 0 // Untitle return the original string if rune is not letter at index 0
func (s String) Untitle() string { func (s String) Untitle() string {
if s.IsEmptyOrSpace() { if s.IsEmptyOrSpace() {
return s.source return s.source
@ -77,10 +84,6 @@ func (s String) Untitle() string {
return string(unicode.ToLower(r)) + s.source[1:] return string(unicode.ToLower(r)) + s.source[1:]
} }
func (s String) Upper() string {
return strings.ToUpper(s.source)
}
// it will not ignore spaces // it will not ignore spaces
func (s String) splitBy(fn func(r rune) bool, remove bool) []string { func (s String) splitBy(fn func(r rune) bool, remove bool) []string {
if s.IsEmptyOrSpace() { if s.IsEmptyOrSpace() {

View File

@ -9,29 +9,35 @@ import (
const regularPerm = 0666 const regularPerm = 0666
type defaultTemplate struct { // DefaultTemplate is a tool to provides the text/template operations
type DefaultTemplate struct {
name string name string
text string text string
goFmt bool goFmt bool
savePath string savePath string
} }
func With(name string) *defaultTemplate { // With returns a instace of DefaultTemplate
return &defaultTemplate{ func With(name string) *DefaultTemplate {
return &DefaultTemplate{
name: name, name: name,
} }
} }
func (t *defaultTemplate) Parse(text string) *defaultTemplate {
// Parse accepts a source template and returns DefaultTemplate
func (t *DefaultTemplate) Parse(text string) *DefaultTemplate {
t.text = text t.text = text
return t return t
} }
func (t *defaultTemplate) GoFmt(format bool) *defaultTemplate { // GoFmt sets the value to goFmt and marks the generated codes will be formated or not
func (t *DefaultTemplate) GoFmt(format bool) *DefaultTemplate {
t.goFmt = format t.goFmt = format
return t return t
} }
func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error { // SaveTo writes the codes to the target path
func (t *DefaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error {
if FileExists(path) && !forceUpdate { if FileExists(path) && !forceUpdate {
return nil return nil
} }
@ -44,7 +50,8 @@ func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool
return ioutil.WriteFile(path, output.Bytes(), regularPerm) return ioutil.WriteFile(path, output.Bytes(), regularPerm)
} }
func (t *defaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) { // Execute returns the codes after the template executed
func (t *DefaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) {
tem, err := template.New(t.name).Parse(t.text) tem, err := template.New(t.name).Parse(t.text)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -1,9 +1,14 @@
package vars package vars
const ( const (
ProjectName = "zero" // ProjectName the const value of zero
ProjectOpenSourceUrl = "github.com/tal-tech/go-zero" ProjectName = "zero"
OsWindows = "windows" // ProjectOpenSourceURL the githb url of go-zero
OsMac = "darwin" ProjectOpenSourceURL = "github.com/tal-tech/go-zero"
OsLinux = "linux" // OsWindows windows os
OsWindows = "windows"
// OsMac mac os
OsMac = "darwin"
// OsLinux linux os
OsLinux = "linux"
) )