patch model&rpc (#207)

* change column to read from information_schema

* reactor generate mode from datasource

* reactor generate mode from datasource

* add primary key check logic

* resolve rebase conflicts

* add naming style

* add filename test case

* resolve rebase conflicts

* reactor test

* add test case

* change shell script to makefile

* update rpc new

* update gen_test.go

* format code

* format code

* update test

* generates alias
This commit is contained in:
Keson 2020-11-18 15:32:53 +08:00 committed by GitHub
parent 71083b5e64
commit 24fb29a356
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 674 additions and 1163 deletions

View File

@ -201,6 +201,10 @@ var (
Name: "new", Name: "new",
Usage: `generate rpc demo service`, Usage: `generate rpc demo service`,
Flags: []cli.Flag{ Flags: []cli.Flag{
cli.StringFlag{
Name: "style",
Usage: "the file naming style, lower|camel|snake,default is lower",
},
cli.BoolFlag{ cli.BoolFlag{
Name: "idea", Name: "idea",
Usage: "whether the command execution environment is from idea plugin. [optional]", Usage: "whether the command execution environment is from idea plugin. [optional]",
@ -235,6 +239,10 @@ var (
Name: "dir, d", Name: "dir, d",
Usage: `the target path of the code`, Usage: `the target path of the code`,
}, },
cli.StringFlag{
Name: "style",
Usage: "the file naming style, lower|camel|snake,default is lower",
},
cli.BoolFlag{ cli.BoolFlag{
Name: "idea", Name: "idea",
Usage: "whether the command execution environment is from idea plugin. [optional]", Usage: "whether the command execution environment is from idea plugin. [optional]",
@ -266,7 +274,7 @@ var (
}, },
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Usage: "the file naming style, lower|camel|underline,default is lower", Usage: "the file naming style, lower|camel|snake,default is lower",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "cache, c", Name: "cache, c",

View File

@ -68,30 +68,3 @@ func FieldNames(in interface{}) []string {
} }
return out return out
} }
func FieldNamesAlias(in interface{}, alias string) []string {
out := make([]string, 0)
v := reflect.ValueOf(in)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
// we only accept structs
if v.Kind() != reflect.Struct {
panic(fmt.Errorf("ToMap only accepts structs; got %T", v))
}
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
// gets us a StructField
fi := typ.Field(i)
tagName := ""
if tagv := fi.Tag.Get(dbTag); tagv != "" {
tagName = tagv
} else {
tagName = fi.Name
}
if len(alias) > 0 {
tagName = alias + "." + tagName
}
out = append(out, tagName)
}
return out
}

View File

@ -17,6 +17,8 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
) )
var errNotMatched = errors.New("sql not matched")
const ( const (
flagSrc = "src" flagSrc = "src"
flagDir = "dir" flagDir = "dir"
@ -33,6 +35,20 @@ func MysqlDDL(ctx *cli.Context) error {
cache := ctx.Bool(flagCache) cache := ctx.Bool(flagCache)
idea := ctx.Bool(flagIdea) idea := ctx.Bool(flagIdea)
namingStyle := strings.TrimSpace(ctx.String(flagStyle)) namingStyle := strings.TrimSpace(ctx.String(flagStyle))
return fromDDl(src, dir, namingStyle, cache, idea)
}
func MyDataSource(ctx *cli.Context) error {
url := strings.TrimSpace(ctx.String(flagUrl))
dir := strings.TrimSpace(ctx.String(flagDir))
cache := ctx.Bool(flagCache)
idea := ctx.Bool(flagIdea)
namingStyle := strings.TrimSpace(ctx.String(flagStyle))
pattern := strings.TrimSpace(ctx.String(flagTable))
return fromDataSource(url, pattern, dir, namingStyle, cache, idea)
}
func fromDDl(src, dir, namingStyle string, cache, idea bool) error {
log := console.NewConsole(idea) log := console.NewConsole(idea)
src = strings.TrimSpace(src) src = strings.TrimSpace(src)
if len(src) == 0 { if len(src) == 0 {
@ -52,29 +68,29 @@ func MysqlDDL(ctx *cli.Context) error {
return err return err
} }
if len(files) == 0 {
return errNotMatched
}
var source []string var source []string
for _, file := range files { for _, file := range files {
data, err := ioutil.ReadFile(file) data, err := ioutil.ReadFile(file)
if err != nil { if err != nil {
return err return err
} }
source = append(source, string(data)) source = append(source, string(data))
} }
generator := gen.NewDefaultGenerator(strings.Join(source, "\n"), dir, namingStyle, gen.WithConsoleOption(log)) generator, err := gen.NewDefaultGenerator(dir, namingStyle, gen.WithConsoleOption(log))
err = generator.Start(cache)
if err != nil { if err != nil {
log.Error("%v", err) return err
}
return nil
} }
func MyDataSource(ctx *cli.Context) error { err = generator.StartFromDDL(strings.Join(source, "\n"), cache)
url := strings.TrimSpace(ctx.String(flagUrl)) return err
dir := strings.TrimSpace(ctx.String(flagDir)) }
cache := ctx.Bool(flagCache)
idea := ctx.Bool(flagIdea) func fromDataSource(url, pattern, dir, namingStyle string, cache, idea bool) error {
namingStyle := strings.TrimSpace(ctx.String(flagStyle))
pattern := strings.TrimSpace(ctx.String(flagTable))
log := console.NewConsole(idea) log := console.NewConsole(idea)
if len(url) == 0 { if len(url) == 0 {
log.Error("%v", "expected data source of mysql, but nothing found") log.Error("%v", "expected data source of mysql, but nothing found")
@ -100,10 +116,8 @@ func MyDataSource(ctx *cli.Context) error {
} }
logx.Disable() logx.Disable()
conn := sqlx.NewMysql(url)
databaseSource := strings.TrimSuffix(url, "/"+cfg.DBName) + "/information_schema" databaseSource := strings.TrimSuffix(url, "/"+cfg.DBName) + "/information_schema"
db := sqlx.NewMysql(databaseSource) db := sqlx.NewMysql(databaseSource)
m := model.NewDDLModel(conn)
im := model.NewInformationSchemaModel(db) im := model.NewInformationSchemaModel(db)
tables, err := im.GetAllTables(cfg.DBName) tables, err := im.GetAllTables(cfg.DBName)
@ -111,7 +125,7 @@ func MyDataSource(ctx *cli.Context) error {
return err return err
} }
var matchTables []string matchTables := make(map[string][]*model.Column)
for _, item := range tables { for _, item := range tables {
match, err := filepath.Match(pattern, item) match, err := filepath.Match(pattern, item)
if err != nil { if err != nil {
@ -121,24 +135,22 @@ func MyDataSource(ctx *cli.Context) error {
if !match { if !match {
continue continue
} }
columns, err := im.FindByTableName(cfg.DBName, item)
matchTables = append(matchTables, item) if err != nil {
return err
} }
matchTables[item] = columns
}
if len(matchTables) == 0 { if len(matchTables) == 0 {
return errors.New("no tables matched") return errors.New("no tables matched")
} }
ddl, err := m.ShowDDL(matchTables...) generator, err := gen.NewDefaultGenerator(dir, namingStyle, gen.WithConsoleOption(log))
if err != nil { if err != nil {
log.Error("%v", err) return err
return nil
} }
generator := gen.NewDefaultGenerator(strings.Join(ddl, "\n"), dir, namingStyle, gen.WithConsoleOption(log)) err = generator.StartFromInformationSchema(cfg.DBName, matchTables, cache)
err = generator.Start(cache) return err
if err != nil {
log.Error("%v", err)
}
return nil
} }

View File

@ -0,0 +1,75 @@
package command
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
var sql = "-- 用户表 --\nCREATE TABLE `user` (\n `id` bigint(10) NOT NULL AUTO_INCREMENT,\n `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `name_index` (`name`),\n UNIQUE KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;\n\n"
func TestFromDDl(t *testing.T) {
err := fromDDl("./user.sql", t.TempDir(), gen.NamingCamel, true, false)
assert.Equal(t, errNotMatched, err)
// case dir is not exists
unknownDir := filepath.Join(t.TempDir(), "test", "user.sql")
err = fromDDl(unknownDir, t.TempDir(), gen.NamingCamel, true, false)
assert.True(t, func() bool {
switch err.(type) {
case *os.PathError:
return true
default:
return false
}
}())
// case empty src
err = fromDDl("", t.TempDir(), gen.NamingCamel, true, false)
if err != nil {
assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error())
}
// case unknown naming style
tmp := filepath.Join(t.TempDir(), "user.sql")
err = fromDDl(tmp, t.TempDir(), "lower1", true, false)
if err != nil {
assert.Equal(t, "unexpected naming style: lower1", err.Error())
}
tempDir := filepath.Join(t.TempDir(), "test")
err = util.MkdirIfNotExist(tempDir)
if err != nil {
return
}
user1Sql := filepath.Join(tempDir, "user1.sql")
user2Sql := filepath.Join(tempDir, "user2.sql")
err = ioutil.WriteFile(user1Sql, []byte(sql), os.ModePerm)
if err != nil {
return
}
err = ioutil.WriteFile(user2Sql, []byte(sql), os.ModePerm)
if err != nil {
return
}
_, err = os.Stat(user1Sql)
assert.Nil(t, err)
_, err = os.Stat(user2Sql)
assert.Nil(t, err)
err = fromDDl(filepath.Join(tempDir, "user*.sql"), tempDir, gen.NamingLower, true, false)
assert.Nil(t, err)
_, err = os.Stat(filepath.Join(tempDir, "usermodel.go"))
assert.Nil(t, err)
}

View File

@ -1,11 +0,0 @@
#!/bin/bash
# generate model with cache from ddl
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/user" -c
# generate model with cache from data source
#user=root
#password=password
#datasource=127.0.0.1:3306
#database=test
#goctl model mysql datasource -url="${user}:${password}@tcp(${datasource})/${database}" -table="*" -dir ./model

View File

@ -0,0 +1,15 @@
#!/bin/bash
# generate model with cache from ddl
fromDDL:
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/user" -c
# generate model with cache from data source
user=root
password=password
datasource=127.0.0.1:3306
database=gozero
fromDataSource:
goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style camel

View File

@ -42,5 +42,6 @@ func genDelete(table Table, withCache bool) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return output.String(), nil return output.String(), nil
} }

View File

@ -15,6 +15,7 @@ func genFields(fields []parser.Field) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
list = append(list, result) list = append(list, result)
} }
return strings.Join(list, "\n"), nil return strings.Join(list, "\n"), nil
@ -43,5 +44,6 @@ func genField(field parser.Field) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return output.String(), nil return output.String(), nil
} }

View File

@ -28,5 +28,6 @@ func genFindOne(table Table, withCache bool) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return output.String(), nil return output.String(), nil
} }

View File

@ -7,6 +7,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
@ -24,7 +25,7 @@ const (
type ( type (
defaultGenerator struct { defaultGenerator struct {
source string //source string
dir string dir string
console.Console console.Console
pkg string pkg string
@ -33,18 +34,30 @@ type (
Option func(generator *defaultGenerator) Option func(generator *defaultGenerator)
) )
func NewDefaultGenerator(source, dir, namingStyle string, opt ...Option) *defaultGenerator { func NewDefaultGenerator(dir, namingStyle string, opt ...Option) (*defaultGenerator, error) {
if dir == "" { if dir == "" {
dir = pwd dir = pwd
} }
generator := &defaultGenerator{source: source, dir: dir, namingStyle: namingStyle} dirAbs, err := filepath.Abs(dir)
if err != nil {
return nil, err
}
dir = dirAbs
pkg := filepath.Base(dirAbs)
err = util.MkdirIfNotExist(dir)
if err != nil {
return nil, err
}
generator := &defaultGenerator{dir: dir, namingStyle: namingStyle, pkg: pkg}
var optionList []Option var optionList []Option
optionList = append(optionList, newDefaultOption()) optionList = append(optionList, newDefaultOption())
optionList = append(optionList, opt...) optionList = append(optionList, opt...)
for _, fn := range optionList { for _, fn := range optionList {
fn(generator) fn(generator)
} }
return generator return generator, nil
} }
func WithConsoleOption(c console.Console) Option { func WithConsoleOption(c console.Console) Option {
@ -59,21 +72,45 @@ func newDefaultOption() Option {
} }
} }
func (g *defaultGenerator) Start(withCache bool) error { func (g *defaultGenerator) StartFromDDL(source string, withCache bool) error {
modelList, err := g.genFromDDL(source, withCache)
if err != nil {
return err
}
return g.createFile(modelList)
}
func (g *defaultGenerator) StartFromInformationSchema(db string, columns map[string][]*model.Column, withCache bool) error {
m := make(map[string]string)
for tableName, column := range columns {
table, err := parser.ConvertColumn(db, tableName, column)
if err != nil {
return err
}
code, err := g.genModel(*table, withCache)
if err != nil {
return err
}
m[table.Name.Source()] = code
}
return g.createFile(m)
}
func (g *defaultGenerator) createFile(modelList map[string]string) error {
dirAbs, err := filepath.Abs(g.dir) dirAbs, err := filepath.Abs(g.dir)
if err != nil { if err != nil {
return err return err
} }
g.dir = dirAbs g.dir = dirAbs
g.pkg = filepath.Base(dirAbs) g.pkg = filepath.Base(dirAbs)
err = util.MkdirIfNotExist(dirAbs) err = util.MkdirIfNotExist(dirAbs)
if err != nil { if err != nil {
return err return err
} }
modelList, err := g.genFromDDL(withCache)
if err != nil {
return err
}
for tableName, code := range modelList { for tableName, code := range modelList {
tn := stringx.From(tableName) tn := stringx.From(tableName)
@ -96,6 +133,9 @@ func (g *defaultGenerator) Start(withCache bool) error {
} }
// generate error file // generate error file
filename := filepath.Join(dirAbs, "vars.go") filename := filepath.Join(dirAbs, "vars.go")
if g.namingStyle == NamingCamel {
filename = filepath.Join(dirAbs, "Vars.go")
}
text, err := util.LoadTemplate(category, errTemplateFile, template.Error) text, err := util.LoadTemplate(category, errTemplateFile, template.Error)
if err != nil { if err != nil {
return err return err
@ -113,8 +153,8 @@ func (g *defaultGenerator) Start(withCache bool) error {
} }
// ret1: key-table name,value-code // ret1: key-table name,value-code
func (g *defaultGenerator) genFromDDL(withCache bool) (map[string]string, error) { func (g *defaultGenerator) genFromDDL(source string, withCache bool) (map[string]string, error) {
ddlList := g.split() ddlList := g.split(source)
m := make(map[string]string) m := make(map[string]string)
for _, ddl := range ddlList { for _, ddl := range ddlList {
table, err := parser.Parse(ddl) table, err := parser.Parse(ddl)
@ -139,10 +179,15 @@ type (
) )
func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) { func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) {
if len(in.PrimaryKey.Name.Source()) == 0 {
return "", fmt.Errorf("table %s: missing primary key", in.Name.Source())
}
text, err := util.LoadTemplate(category, modelTemplateFile, template.Model) text, err := util.LoadTemplate(category, modelTemplateFile, template.Model)
if err != nil { if err != nil {
return "", err return "", err
} }
t := util.With("model"). t := util.With("model").
Parse(text). Parse(text).
GoFmt(true) GoFmt(true)

View File

@ -22,15 +22,19 @@ func TestCacheModel(t *testing.T) {
defer func() { defer func() {
_ = os.RemoveAll(dir) _ = os.RemoveAll(dir)
}() }()
g := NewDefaultGenerator(source, cacheDir, NamingLower) g, err := NewDefaultGenerator(cacheDir, NamingCamel)
err := g.Start(true) assert.Nil(t, err)
err = g.StartFromDDL(source, true)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, func() bool { assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(cacheDir, "testuserinfomodel.go")) _, err := os.Stat(filepath.Join(cacheDir, "TestUserInfoModel.go"))
return err == nil return err == nil
}()) }())
g = NewDefaultGenerator(source, noCacheDir, NamingLower) g, err = NewDefaultGenerator(noCacheDir, NamingLower)
err = g.Start(false) assert.Nil(t, err)
err = g.StartFromDDL(source, false)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, func() bool { assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(noCacheDir, "testuserinfomodel.go")) _, err := os.Stat(filepath.Join(noCacheDir, "testuserinfomodel.go"))
@ -47,15 +51,19 @@ func TestNamingModel(t *testing.T) {
defer func() { defer func() {
_ = os.RemoveAll(dir) _ = os.RemoveAll(dir)
}() }()
g := NewDefaultGenerator(source, camelDir, NamingCamel) g, err := NewDefaultGenerator(camelDir, NamingCamel)
err := g.Start(true) assert.Nil(t, err)
err = g.StartFromDDL(source, true)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, func() bool { assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go")) _, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go"))
return err == nil return err == nil
}()) }())
g = NewDefaultGenerator(source, snakeDir, NamingSnake) g, err = NewDefaultGenerator(snakeDir, NamingSnake)
err = g.Start(true) assert.Nil(t, err)
err = g.StartFromDDL(source, true)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, func() bool { assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(snakeDir, "test_user_info_model.go")) _, err := os.Stat(filepath.Join(snakeDir, "test_user_info_model.go"))

View File

@ -17,7 +17,6 @@ func TestGenCacheKeys(t *testing.T) {
Name: stringx.From("id"), Name: stringx.From("id"),
DataBaseType: "bigint", DataBaseType: "bigint",
DataType: "int64", DataType: "int64",
IsKey: false,
IsPrimaryKey: true, IsPrimaryKey: true,
IsUniqueKey: false, IsUniqueKey: false,
Comment: "自增id", Comment: "自增id",
@ -29,7 +28,6 @@ func TestGenCacheKeys(t *testing.T) {
Name: stringx.From("mobile"), Name: stringx.From("mobile"),
DataBaseType: "varchar", DataBaseType: "varchar",
DataType: "string", DataType: "string",
IsKey: false,
IsPrimaryKey: false, IsPrimaryKey: false,
IsUniqueKey: true, IsUniqueKey: true,
Comment: "手机号", Comment: "手机号",
@ -38,7 +36,6 @@ func TestGenCacheKeys(t *testing.T) {
Name: stringx.From("name"), Name: stringx.From("name"),
DataBaseType: "varchar", DataBaseType: "varchar",
DataType: "string", DataType: "string",
IsKey: false,
IsPrimaryKey: false, IsPrimaryKey: false,
IsUniqueKey: true, IsUniqueKey: true,
Comment: "姓名", Comment: "姓名",
@ -47,7 +44,6 @@ func TestGenCacheKeys(t *testing.T) {
Name: stringx.From("createTime"), Name: stringx.From("createTime"),
DataBaseType: "timestamp", DataBaseType: "timestamp",
DataType: "time.Time", DataType: "time.Time",
IsKey: false,
IsPrimaryKey: false, IsPrimaryKey: false,
IsUniqueKey: false, IsUniqueKey: false,
Comment: "创建时间", Comment: "创建时间",
@ -56,7 +52,6 @@ func TestGenCacheKeys(t *testing.T) {
Name: stringx.From("updateTime"), Name: stringx.From("updateTime"),
DataBaseType: "timestamp", DataBaseType: "timestamp",
DataType: "time.Time", DataType: "time.Time",
IsKey: false,
IsPrimaryKey: false, IsPrimaryKey: false,
IsUniqueKey: false, IsUniqueKey: false,
Comment: "更新时间", Comment: "更新时间",

View File

@ -4,11 +4,10 @@ import (
"regexp" "regexp"
) )
func (g *defaultGenerator) split() []string { func (g *defaultGenerator) split(source string) []string {
reg := regexp.MustCompile(createTableFlag) reg := regexp.MustCompile(createTableFlag)
index := reg.FindAllStringIndex(g.source, -1) index := reg.FindAllStringIndex(source, -1)
list := make([]string, 0) list := make([]string, 0)
source := g.source
for i := len(index) - 1; i >= 0; i-- { for i := len(index) - 1; i >= 0; i-- {
subIndex := index[i] subIndex := index[i]
if len(subIndex) == 0 { if len(subIndex) == 0 {

View File

@ -22,5 +22,6 @@ func genTag(in string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return output.String(), nil return output.String(), nil
} }

View File

@ -27,6 +27,7 @@ func (m *DDLModel) ShowDDL(table ...string) ([]string, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
ddl = append(ddl, resp.DDL) ddl = append(ddl, resp.DDL)
} }
return ddl, nil return ddl, nil

View File

@ -8,6 +8,13 @@ type (
InformationSchemaModel struct { InformationSchemaModel struct {
conn sqlx.SqlConn conn sqlx.SqlConn
} }
Column struct {
Name string `db:"COLUMN_NAME"`
DataType string `db:"DATA_TYPE"`
Key string `db:"COLUMN_KEY"`
Extra string `db:"EXTRA"`
Comment string `db:"COLUMN_COMMENT"`
}
) )
func NewInformationSchemaModel(conn sqlx.SqlConn) *InformationSchemaModel { func NewInformationSchemaModel(conn sqlx.SqlConn) *InformationSchemaModel {
@ -21,5 +28,13 @@ func (m *InformationSchemaModel) GetAllTables(database string) ([]string, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tables, nil return tables, nil
} }
func (m *InformationSchemaModel) FindByTableName(db, table string) ([]*Column, error) {
querySql := `select COLUMN_NAME,DATA_TYPE,COLUMN_KEY,EXTRA,COLUMN_COMMENT from COLUMNS where TABLE_SCHEMA = ? and TABLE_NAME = ?`
var reply []*Column
err := m.conn.QueryRows(&reply, querySql, db, table)
return reply, err
}

View File

@ -2,8 +2,10 @@ package parser
import ( import (
"fmt" "fmt"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/converter" "github.com/tal-tech/go-zero/tools/goctl/model/sql/converter"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
"github.com/xwb1989/sqlparser" "github.com/xwb1989/sqlparser"
) )
@ -34,7 +36,6 @@ type (
Name stringx.String Name stringx.String
DataBaseType string DataBaseType string
DataType string DataType string
IsKey bool
IsPrimaryKey bool IsPrimaryKey bool
IsUniqueKey bool IsUniqueKey bool
Comment string Comment string
@ -123,7 +124,6 @@ func Parse(ddl string) (*Table, error) {
field.Comment = comment field.Comment = comment
key, ok := keyMap[column.Name.String()] key, ok := keyMap[column.Name.String()]
if ok { if ok {
field.IsKey = true
field.IsPrimaryKey = key == primary field.IsPrimaryKey = key == primary
field.IsUniqueKey = key == unique field.IsUniqueKey = key == unique
if field.IsPrimaryKey { if field.IsPrimaryKey {
@ -151,3 +151,62 @@ func (t *Table) ContainsTime() bool {
} }
return false return false
} }
func ConvertColumn(db, table string, in []*model.Column) (*Table, error) {
var reply Table
reply.Name = stringx.From(table)
keyMap := make(map[string][]*model.Column)
for _, column := range in {
keyMap[column.Key] = append(keyMap[column.Key], column)
}
primaryColumns := keyMap["PRI"]
if len(primaryColumns) == 0 {
return nil, fmt.Errorf("database:%s, table %s: missing primary key", db, table)
}
if len(primaryColumns) > 1 {
return nil, fmt.Errorf("database:%s, table %s: only one primary key expected", db, table)
}
primaryColumn := primaryColumns[0]
primaryFt, err := converter.ConvertDataType(primaryColumn.DataType)
if err != nil {
return nil, err
}
primaryField := Field{
Name: stringx.From(primaryColumn.Name),
DataBaseType: primaryColumn.DataType,
DataType: primaryFt,
IsUniqueKey: true,
IsPrimaryKey: true,
Comment: primaryColumn.Comment,
}
reply.PrimaryKey = Primary{
Field: primaryField,
AutoIncrement: strings.Contains(primaryColumn.Extra, "auto_increment"),
}
for key, columns := range keyMap {
for _, item := range columns {
dt, err := converter.ConvertDataType(item.DataType)
if err != nil {
return nil, err
}
f := Field{
Name: stringx.From(item.Name),
DataBaseType: item.DataType,
DataType: dt,
IsPrimaryKey: primaryColumn.Name == item.Name,
Comment: item.Comment,
}
if key == "UNI" {
f.IsUniqueKey = true
}
reply.Fields = append(reply.Fields, f)
}
}
return &reply, nil
}

View File

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
) )
func TestParsePlainText(t *testing.T) { func TestParsePlainText(t *testing.T) {
@ -23,3 +24,58 @@ func TestParseCreateTable(t *testing.T) {
assert.Equal(t, "id", table.PrimaryKey.Name.Source()) assert.Equal(t, "id", table.PrimaryKey.Name.Source())
assert.Equal(t, true, table.ContainsTime()) assert.Equal(t, true, table.ContainsTime())
} }
func TestConvertColumn(t *testing.T) {
_, err := ConvertColumn("user", "user", []*model.Column{
{
Name: "id",
DataType: "bigint",
Key: "",
Extra: "",
Comment: "",
},
})
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "missing primary key")
_, err = ConvertColumn("user", "user", []*model.Column{
{
Name: "id",
DataType: "bigint",
Key: "PRI",
Extra: "",
Comment: "",
},
{
Name: "mobile",
DataType: "varchar",
Key: "PRI",
Extra: "",
Comment: "手机号",
},
})
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "only one primary key expected")
table, err := ConvertColumn("user", "user", []*model.Column{
{
Name: "id",
DataType: "bigint",
Key: "PRI",
Extra: "auto_increment",
Comment: "",
},
{
Name: "mobile",
DataType: "varchar",
Key: "UNI",
Extra: "",
Comment: "手机号",
},
})
assert.Nil(t, err)
assert.True(t, table.PrimaryKey.AutoIncrement && table.PrimaryKey.IsPrimaryKey)
assert.Equal(t, "id", table.PrimaryKey.Name.Source())
assert.Equal(t, "mobile", table.Fields[1].Name.Source())
assert.True(t, table.Fields[1].IsUniqueKey)
}

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/generator" "github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@ -16,6 +15,7 @@ import (
func Rpc(c *cli.Context) error { func Rpc(c *cli.Context) error {
src := c.String("src") src := c.String("src")
out := c.String("dir") out := c.String("dir")
style := c.String("style")
protoImportPath := c.StringSlice("proto_path") protoImportPath := c.StringSlice("proto_path")
if len(src) == 0 { if len(src) == 0 {
return errors.New("missing -src") return errors.New("missing -src")
@ -23,7 +23,13 @@ func Rpc(c *cli.Context) error {
if len(out) == 0 { if len(out) == 0 {
return errors.New("missing -dir") return errors.New("missing -dir")
} }
g := generator.NewDefaultRpcGenerator()
namingStyle, valid := generator.IsNamingValid(style)
if !valid {
return fmt.Errorf("unexpected naming style %s", style)
}
g := generator.NewDefaultRpcGenerator(namingStyle)
return g.Generate(src, out, protoImportPath) return g.Generate(src, out, protoImportPath)
} }
@ -36,6 +42,12 @@ func RpcNew(c *cli.Context) error {
return fmt.Errorf("unexpected ext: %s", ext) return fmt.Errorf("unexpected ext: %s", ext)
} }
style := c.String("style")
namingStyle, valid := generator.IsNamingValid(style)
if !valid {
return fmt.Errorf("expected naming style [lower|camel|snake], but found %s", style)
}
protoName := name + ".proto" protoName := name + ".proto"
filename := filepath.Join(".", name, protoName) filename := filepath.Join(".", name, protoName)
src, err := filepath.Abs(filename) src, err := filepath.Abs(filename)
@ -48,13 +60,7 @@ func RpcNew(c *cli.Context) error {
return err return err
} }
workDir := filepath.Dir(src) g := generator.NewDefaultRpcGenerator(namingStyle)
_, err = execx.Run("go mod init "+name, workDir)
if err != nil {
return err
}
g := generator.NewDefaultRpcGenerator()
return g.Generate(src, filepath.Dir(src), nil) return g.Generate(src, filepath.Dir(src), nil)
} }

View File

@ -1,75 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: common.proto
package common
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type User struct {
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *User) Reset() { *m = User{} }
func (m *User) String() string { return proto.CompactTextString(m) }
func (*User) ProtoMessage() {}
func (*User) Descriptor() ([]byte, []int) {
return fileDescriptor_555bd8c177793206, []int{0}
}
func (m *User) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_User.Unmarshal(m, b)
}
func (m *User) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_User.Marshal(b, m, deterministic)
}
func (m *User) XXX_Merge(src proto.Message) {
xxx_messageInfo_User.Merge(m, src)
}
func (m *User) XXX_Size() int {
return xxx_messageInfo_User.Size(m)
}
func (m *User) XXX_DiscardUnknown() {
xxx_messageInfo_User.DiscardUnknown(m)
}
var xxx_messageInfo_User proto.InternalMessageInfo
func (m *User) GetName() string {
if m != nil {
return m.Name
}
return ""
}
func init() {
proto.RegisterType((*User)(nil), "common.User")
}
func init() { proto.RegisterFile("common.proto", fileDescriptor_555bd8c177793206) }
var fileDescriptor_555bd8c177793206 = []byte{
// 72 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x49, 0xce, 0xcf, 0xcd,
0xcd, 0xcf, 0xd3, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x83, 0xf0, 0x94, 0xa4, 0xb8, 0x58,
0x42, 0x8b, 0x53, 0x8b, 0x84, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53, 0x25, 0x18, 0x15, 0x18,
0x35, 0x38, 0x83, 0xc0, 0xec, 0x24, 0x36, 0xb0, 0x52, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff,
0x2c, 0x6d, 0x58, 0x59, 0x3a, 0x00, 0x00, 0x00,
}

View File

@ -6,6 +6,13 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
func formatFilename(filename string) string { func formatFilename(filename string, style NamingStyle) string {
switch style {
case namingCamel:
return stringx.From(filename).ToCamel()
case namingSnake:
return stringx.From(filename).ToSnake()
default:
return strings.ToLower(stringx.From(filename).ToCamel()) return strings.ToLower(stringx.From(filename).ToCamel())
} }
}

View File

@ -0,0 +1,17 @@
package generator
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFormatFilename(t *testing.T) {
assert.Equal(t, "abc", formatFilename("a_b_c", namingLower))
assert.Equal(t, "ABC", formatFilename("a_b_c", namingCamel))
assert.Equal(t, "a_b_c", formatFilename("a_b_c", namingSnake))
assert.Equal(t, "a", formatFilename("a", namingSnake))
assert.Equal(t, "A", formatFilename("a", namingCamel))
// no flag to convert to snake
assert.Equal(t, "abc", formatFilename("abc", namingSnake))
}

View File

@ -11,15 +11,17 @@ import (
type RpcGenerator struct { type RpcGenerator struct {
g Generator g Generator
style NamingStyle
} }
func NewDefaultRpcGenerator() *RpcGenerator { func NewDefaultRpcGenerator(style NamingStyle) *RpcGenerator {
return NewRpcGenerator(NewDefaultGenerator()) return NewRpcGenerator(NewDefaultGenerator(), style)
} }
func NewRpcGenerator(g Generator) *RpcGenerator { func NewRpcGenerator(g Generator, style NamingStyle) *RpcGenerator {
return &RpcGenerator{ return &RpcGenerator{
g: g, g: g,
style: style,
} }
} }
@ -55,42 +57,42 @@ func (g *RpcGenerator) Generate(src, target string, protoImportPath []string) er
return err return err
} }
err = g.g.GenEtc(dirCtx, proto) err = g.g.GenEtc(dirCtx, proto, g.style)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenPb(dirCtx, protoImportPath, proto) err = g.g.GenPb(dirCtx, protoImportPath, proto, g.style)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenConfig(dirCtx, proto) err = g.g.GenConfig(dirCtx, proto, g.style)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenSvc(dirCtx, proto) err = g.g.GenSvc(dirCtx, proto, g.style)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenLogic(dirCtx, proto) err = g.g.GenLogic(dirCtx, proto, g.style)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenServer(dirCtx, proto) err = g.g.GenServer(dirCtx, proto, g.style)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenMain(dirCtx, proto) err = g.g.GenMain(dirCtx, proto, g.style)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenCall(dirCtx, proto) err = g.g.GenCall(dirCtx, proto, g.style)
console.NewColorConsole().MarkDone() console.NewColorConsole().MarkDone()

View File

@ -1,128 +1,74 @@
package generator package generator
import ( import (
"go/build"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx" "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
) )
func TestRpcGenerateCaseNilImport(t *testing.T) { func TestRpcGenerate(t *testing.T) {
_ = Clean() _ = Clean()
dispatcher := NewDefaultGenerator() dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil { err := dispatcher.Prepare()
g := NewRpcGenerator(dispatcher) if err != nil {
abs, err := filepath.Abs("./test") logx.Error(err)
assert.Nil(t, err) return
}
projectName := stringx.Rand()
g := NewRpcGenerator(dispatcher, namingLower)
err = g.Generate("./test_stream.proto", abs, nil) // case go path
src := filepath.Join(build.Default.GOPATH, "src")
_, err = os.Stat(src)
if err != nil {
return
}
projectDir := filepath.Join(src, projectName)
srcDir := projectDir
defer func() { defer func() {
_ = os.RemoveAll(abs) _ = os.RemoveAll(srcDir)
}() }()
err = g.Generate("./test.proto", projectDir, []string{src})
assert.Nil(t, err) assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir)
_, err = execx.Run("go test "+abs, abs) if err != nil {
assert.Nil(t, err) assert.Contains(t, err.Error(), "not in GOROOT")
}
} }
func TestRpcGenerateCaseOption(t *testing.T) { // case go mod
_ = Clean() workDir := t.TempDir()
dispatcher := NewDefaultGenerator() name := filepath.Base(workDir)
if err := dispatcher.Prepare(); err == nil { _, err = execx.Run("go mod init "+name, workDir)
g := NewRpcGenerator(dispatcher) if err != nil {
abs, err := filepath.Abs("./test") logx.Error(err)
assert.Nil(t, err) return
err = g.Generate("./test_option.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
} }
func TestRpcGenerateCaseWordOption(t *testing.T) { projectDir = filepath.Join(workDir, projectName)
_ = Clean() err = g.Generate("./test.proto", projectDir, []string{src})
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err) assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir)
err = g.Generate("./test_word_option.proto", abs, nil) if err != nil {
defer func() { assert.Contains(t, err.Error(), "not in GOROOT")
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
} }
// test keyword go // case not in go mod and go path
func TestRpcGenerateCaseGoOption(t *testing.T) { err = g.Generate("./test.proto", projectDir, []string{src})
_ = Clean()
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err) assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir)
err = g.Generate("./test_go_option.proto", abs, nil) if err != nil {
defer func() { assert.Contains(t, err.Error(), "not in GOROOT")
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
} }
func TestRpcGenerateCaseImport(t *testing.T) { // invalid directory
_ = Clean() projectDir = filepath.Join(t.TempDir(), ".....")
dispatcher := NewDefaultGenerator() err = g.Generate("./test.proto", projectDir, nil)
if err := dispatcher.Prepare(); err == nil { assert.NotNil(t, err)
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_import.proto", abs, []string{"./base"})
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.True(t, func() bool {
return strings.Contains(err.Error(), "package base is not in GOROOT")
}())
}
}
func TestRpcGenerateCaseServiceRpcNamingSnake(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_service_rpc_naming_snake.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
} }

View File

@ -59,12 +59,12 @@ func (m *default{{.serviceName}}) {{.method}}(ctx context.Context,in *{{.pbReque
` `
) )
func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error { func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error {
dir := ctx.GetCall() dir := ctx.GetCall()
service := proto.Service service := proto.Service
head := util.GetHead(proto.Name) head := util.GetHead(proto.Name)
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", formatFilename(service.Name))) filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", formatFilename(service.Name, namingStyle)))
functions, err := g.genFunction(proto.PbPackage, service) functions, err := g.genFunction(proto.PbPackage, service)
if err != nil { if err != nil {
return err return err
@ -81,13 +81,12 @@ func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error {
} }
var alias = collection.NewSet() var alias = collection.NewSet()
for _, item := range service.RPC { for _, item := range proto.Message {
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(item.RequestType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.RequestType)))) alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(item.Name), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.Name))))
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(item.ReturnsType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.ReturnsType))))
} }
err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"name": formatFilename(service.Name), "name": formatFilename(service.Name, namingStyle),
"alias": strings.Join(alias.KeysStr(), util.NL), "alias": strings.Join(alias.KeysStr(), util.NL),
"head": head, "head": head,
"filePackage": dir.Base, "filePackage": dir.Base,

View File

@ -1,44 +0,0 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateCall(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenCall(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -18,9 +18,9 @@ type Config struct {
} }
` `
func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto) error { func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error {
dir := ctx.GetConfig() dir := ctx.GetConfig()
fileName := filepath.Join(dir.Filename, formatFilename("config")+".go") fileName := filepath.Join(dir.Filename, formatFilename("config", namingStyle)+".go")
if util.FileExists(fileName) { if util.FileExists(fileName) {
return nil return nil
} }

View File

@ -1,48 +0,0 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateConfig(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenConfig(dirCtx, proto)
assert.Nil(t, err)
// test file exists
err = g.GenConfig(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -4,12 +4,12 @@ import "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
type Generator interface { type Generator interface {
Prepare() error Prepare() error
GenMain(ctx DirContext, proto parser.Proto) error GenMain(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
GenCall(ctx DirContext, proto parser.Proto) error GenCall(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
GenEtc(ctx DirContext, proto parser.Proto) error GenEtc(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
GenConfig(ctx DirContext, proto parser.Proto) error GenConfig(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
GenLogic(ctx DirContext, proto parser.Proto) error GenLogic(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
GenServer(ctx DirContext, proto parser.Proto) error GenServer(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
GenSvc(ctx DirContext, proto parser.Proto) error GenSvc(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error
GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, namingStyle NamingStyle) error
} }

View File

@ -3,9 +3,11 @@ package generator
import ( import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
const etcTemplate = `Name: {{.serviceName}}.rpc const etcTemplate = `Name: {{.serviceName}}.rpc
@ -16,9 +18,9 @@ Etcd:
Key: {{.serviceName}}.rpc Key: {{.serviceName}}.rpc
` `
func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto) error { func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error {
dir := ctx.GetEtc() dir := ctx.GetEtc()
serviceNameLower := formatFilename(ctx.GetMain().Base) serviceNameLower := formatFilename(ctx.GetMain().Base, namingStyle)
fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.yaml", serviceNameLower)) fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.yaml", serviceNameLower))
text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate) text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate)
@ -27,6 +29,6 @@ func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto) error {
} }
return util.With("etc").Parse(text).SaveTo(map[string]interface{}{ return util.With("etc").Parse(text).SaveTo(map[string]interface{}{
"serviceName": serviceNameLower, "serviceName": strings.ToLower(stringx.From(ctx.GetMain().Base).ToCamel()),
}, fileName, false) }, fileName, false)
} }

View File

@ -1,45 +0,0 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateEtc(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenEtc(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -46,10 +46,10 @@ func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
` `
) )
func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto) error { func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error {
dir := ctx.GetLogic() dir := ctx.GetLogic()
for _, rpc := range proto.Service.RPC { for _, rpc := range proto.Service.RPC {
filename := filepath.Join(dir.Filename, formatFilename(rpc.Name+"_logic")+".go") filename := filepath.Join(dir.Filename, formatFilename(rpc.Name+"_logic", namingStyle)+".go")
functions, err := g.genLogicFunction(proto.PbPackage, rpc) functions, err := g.genLogicFunction(proto.PbPackage, rpc)
if err != nil { if err != nil {
return err return err

View File

@ -1,44 +0,0 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateLogic(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenLogic(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -45,9 +45,9 @@ func main() {
} }
` `
func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error { func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error {
dir := ctx.GetMain() dir := ctx.GetMain()
serviceNameLower := formatFilename(ctx.GetMain().Base) serviceNameLower := formatFilename(ctx.GetMain().Base, namingStyle)
fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.go", serviceNameLower)) fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.go", serviceNameLower))
imports := make([]string, 0) imports := make([]string, 0)
pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package) pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
@ -63,7 +63,7 @@ func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error {
return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head, "head": head,
"serviceName": serviceNameLower, "serviceName": strings.ToLower(stringx.From(ctx.GetMain().Base).ToCamel()),
"imports": strings.Join(imports, util.NL), "imports": strings.Join(imports, util.NL),
"pkg": proto.PbPackage, "pkg": proto.PbPackage,
"serviceNew": stringx.From(proto.Service.Name).ToCamel(), "serviceNew": stringx.From(proto.Service.Name).ToCamel(),

View File

@ -1,45 +0,0 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateMain(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenMain(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -9,7 +9,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
) )
func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error { func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, namingStyle NamingStyle) error {
dir := ctx.GetPb() dir := ctx.GetPb()
cw := new(bytes.Buffer) cw := new(bytes.Buffer)
base := filepath.Dir(proto.Src) base := filepath.Dir(proto.Src)

View File

@ -1,184 +0,0 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateCaseNilImport(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
//_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
if err := g.Prepare(); err == nil {
targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go")
err = g.GenPb(dirCtx, nil, proto)
assert.Nil(t, err)
assert.True(t, func() bool {
return util.FileExists(targetPb)
}())
}
}
func TestGenerateCaseImport(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
if err := g.Prepare(); err == nil {
err = g.GenPb(dirCtx, nil, proto)
assert.Nil(t, err)
targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go")
assert.True(t, func() bool {
return util.FileExists(targetPb)
}())
}
}
func TestGenerateCasePathOption(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_option.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
if err := g.Prepare(); err == nil {
err = g.GenPb(dirCtx, nil, proto)
assert.Nil(t, err)
targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_option.pb.go")
assert.True(t, func() bool {
return util.FileExists(targetPb)
}())
}
}
func TestGenerateCaseWordOption(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_word_option.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
if err := g.Prepare(); err == nil {
err = g.GenPb(dirCtx, nil, proto)
assert.Nil(t, err)
targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_word_option.pb.go")
assert.True(t, func() bool {
return util.FileExists(targetPb)
}())
}
}
// test keyword go
func TestGenerateCaseGoOption(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_go_option.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
if err := g.Prepare(); err == nil {
err = g.GenPb(dirCtx, nil, proto)
assert.Nil(t, err)
targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_go_option.pb.go")
assert.True(t, func() bool {
return util.FileExists(targetPb)
}())
}
}

View File

@ -43,7 +43,7 @@ func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) (
` `
) )
func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error { func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error {
dir := ctx.GetServer() dir := ctx.GetServer()
logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package) logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package)
svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package) svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
@ -54,7 +54,7 @@ func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error {
head := util.GetHead(proto.Name) head := util.GetHead(proto.Name)
service := proto.Service service := proto.Service
serverFile := filepath.Join(dir.Filename, formatFilename(service.Name+"_server")+".go") serverFile := filepath.Join(dir.Filename, formatFilename(service.Name+"_server", namingStyle)+".go")
funcList, err := g.genFunctions(proto.PbPackage, service) funcList, err := g.genFunctions(proto.PbPackage, service)
if err != nil { if err != nil {
return err return err

View File

@ -1,45 +0,0 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateServer(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenServer(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -23,9 +23,9 @@ func NewServiceContext(c config.Config) *ServiceContext {
} }
` `
func (g *defaultGenerator) GenSvc(ctx DirContext, _ parser.Proto) error { func (g *defaultGenerator) GenSvc(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error {
dir := ctx.GetSvc() dir := ctx.GetSvc()
fileName := filepath.Join(dir.Filename, formatFilename("service_context")+".go") fileName := filepath.Join(dir.Filename, formatFilename("service_context", namingStyle)+".go")
text, err := util.LoadTemplate(category, svcTemplateFile, svcTemplate) text, err := util.LoadTemplate(category, svcTemplateFile, svcTemplate)
if err != nil { if err != nil {
return err return err

View File

@ -1,40 +0,0 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateSvc(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.GenSvc(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -1,130 +0,0 @@
package generator
import (
"go/build"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestMkDirInGoPath(t *testing.T) {
dft := build.Default
gp := dft.GOPATH
if len(gp) == 0 {
return
}
projectName := stringx.Rand()
dir := filepath.Join(gp, "src", projectName)
err := util.MkdirIfNotExist(dir)
if err != nil {
return
}
defer func() {
_ = os.RemoveAll(dir)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
internal := filepath.Join(dir, "internal")
assert.True(t, true, func() bool {
return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package
}())
assert.True(t, true, func() bool {
return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package
}())
assert.True(t, true, func() bool {
return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package
}())
}
func TestMkDirInGoMod(t *testing.T) {
dft := build.Default
gp := dft.GOPATH
if len(gp) == 0 {
return
}
projectName := stringx.Rand()
dir := filepath.Join(gp, "src", projectName)
err := util.MkdirIfNotExist(dir)
if err != nil {
return
}
_, err = execx.Run("go mod init "+projectName, dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(dir)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
internal := filepath.Join(dir, "internal")
assert.True(t, true, func() bool {
return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package
}())
assert.True(t, true, func() bool {
return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package
}())
assert.True(t, true, func() bool {
return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package
}())
}

View File

@ -0,0 +1,24 @@
package generator
type NamingStyle = string
const (
namingLower NamingStyle = "lower"
namingCamel NamingStyle = "camel"
namingSnake NamingStyle = "snake"
)
// IsNamingValid validates whether the namingStyle is valid or not,return
// namingStyle and true if it is valid, or else return empty string
// and false, and it is a valid value even namingStyle is empty string
func IsNamingValid(namingStyle string) (NamingStyle, bool) {
if len(namingStyle) == 0 {
namingStyle = namingLower
}
switch namingStyle {
case namingLower, namingCamel, namingSnake:
return namingStyle, true
default:
return "", false
}
}

View File

@ -0,0 +1,25 @@
package generator
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsNamingValid(t *testing.T) {
style, valid := IsNamingValid("")
assert.True(t, valid)
assert.Equal(t, namingLower, style)
_, valid = IsNamingValid("lower1")
assert.False(t, valid)
_, valid = IsNamingValid("lower")
assert.True(t, valid)
_, valid = IsNamingValid("snake")
assert.True(t, valid)
_, valid = IsNamingValid("camel")
assert.True(t, valid)
}

View File

@ -1,7 +1,6 @@
package generator package generator
import ( import (
"os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -9,13 +8,13 @@ import (
) )
func TestProtoTmpl(t *testing.T) { func TestProtoTmpl(t *testing.T) {
out, err := filepath.Abs("./test/test.proto") _ = Clean()
// exists dir
err := ProtoTmpl(t.TempDir())
assert.Nil(t, err) assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(filepath.Dir(out)) // not exist dir
}() dir := filepath.Join(t.TempDir(), "test")
err = ProtoTmpl(out) err = ProtoTmpl(dir)
assert.Nil(t, err)
_, err = os.Stat(out)
assert.Nil(t, err) assert.Nil(t, err)
} }

View File

@ -2,6 +2,7 @@ package generator
import ( import (
"io/ioutil" "io/ioutil"
"os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -10,87 +11,104 @@ import (
) )
func TestGenTemplates(t *testing.T) { func TestGenTemplates(t *testing.T) {
err := util.InitTemplates(category, templates) _ = Clean()
err := GenTemplates(nil)
assert.Nil(t, err) assert.Nil(t, err)
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, "main.tpl")
data, err := ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), mainTemplate)
} }
func TestRevertTemplate(t *testing.T) { func TestRevertTemplate(t *testing.T) {
name := "main.tpl" _ = Clean()
err := util.InitTemplates(category, templates) err := GenTemplates(nil)
assert.Nil(t, err)
fp, err := util.GetTemplateDir(category)
if err != nil {
return
}
mainTpl := filepath.Join(fp, mainTemplateFile)
data, err := ioutil.ReadFile(mainTpl)
if err != nil {
return
}
assert.Equal(t, templates[mainTemplateFile], string(data))
err = RevertTemplate("test")
if err != nil {
assert.Equal(t, "test: no such file name", err.Error())
}
err = ioutil.WriteFile(mainTpl, []byte("modify"), os.ModePerm)
if err != nil {
return
}
data, err = ioutil.ReadFile(mainTpl)
if err != nil {
return
}
assert.Equal(t, "modify", string(data))
err = RevertTemplate(mainTemplateFile)
assert.Nil(t, err) assert.Nil(t, err)
dir, err := util.GetTemplateDir(category) data, err = ioutil.ReadFile(mainTpl)
assert.Nil(t, err) if err != nil {
return
file := filepath.Join(dir, name) }
data, err := ioutil.ReadFile(file) assert.Equal(t, templates[mainTemplateFile], string(data))
assert.Nil(t, err)
modifyData := string(data) + "modify"
err = util.CreateTemplate(category, name, modifyData)
assert.Nil(t, err)
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), modifyData)
assert.Nil(t, RevertTemplate(name))
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, mainTemplate, string(data))
} }
func TestClean(t *testing.T) { func TestClean(t *testing.T) {
name := "main.tpl" _ = Clean()
err := util.InitTemplates(category, templates) err := GenTemplates(nil)
assert.Nil(t, err)
fp, err := util.GetTemplateDir(category)
if err != nil {
return
}
mainTpl := filepath.Join(fp, mainTemplateFile)
_, err = os.Stat(mainTpl)
assert.Nil(t, err) assert.Nil(t, err)
assert.Nil(t, Clean()) err = Clean()
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err) assert.Nil(t, err)
file := filepath.Join(dir, name) _, err = os.Stat(mainTpl)
_, err = ioutil.ReadFile(file)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
name := "main.tpl" _ = Clean()
err := util.InitTemplates(category, templates) err := GenTemplates(nil)
assert.Nil(t, err)
fp, err := util.GetTemplateDir(category)
if err != nil {
return
}
mainTpl := filepath.Join(fp, mainTemplateFile)
err = ioutil.WriteFile(mainTpl, []byte("modify"), os.ModePerm)
if err != nil {
return
}
data, err := ioutil.ReadFile(mainTpl)
if err != nil {
return
}
assert.Equal(t, "modify", string(data))
err = Update(category)
assert.Nil(t, err) assert.Nil(t, err)
dir, err := util.GetTemplateDir(category) data, err = ioutil.ReadFile(mainTpl)
assert.Nil(t, err) if err != nil {
return
file := filepath.Join(dir, name) }
data, err := ioutil.ReadFile(file) assert.Equal(t, templates[mainTemplateFile], string(data))
assert.Nil(t, err)
modifyData := string(data) + "modify"
err = util.CreateTemplate(category, name, modifyData)
assert.Nil(t, err)
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), modifyData)
assert.Nil(t, Update(category))
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, mainTemplate, string(data))
} }
func TestGetCategory(t *testing.T) { func TestGetCategory(t *testing.T) {
assert.Equal(t, category, GetCategory()) _ = Clean()
result := GetCategory()
assert.Equal(t, category, result)
} }

View File

@ -2,24 +2,61 @@
syntax = "proto3"; syntax = "proto3";
package test; package test;
option go_package = "go";
import "test_base.proto"; import "base/common.proto";
import "google/protobuf/any.proto";
message TestMessage { option go_package = "github.com/test";
base.CommonReq req = 1;
} message Req {
message TestReq {} string in = 1;
message TestReply { common.User user = 2;
base.CommonReply reply = 2; google.protobuf.Any object = 4;
} }
enum TestEnum { message Reply {
string out = 1;
}
message snake_req {}
message snake_reply {}
message CamelReq{}
message CamelReply{}
message EnumMessage {
enum Enum {
unknown = 0; unknown = 0;
male = 1; male = 1;
female = 2; female = 2;
} }
}
service TestService {
rpc TestRpc (TestReq) returns (TestReply); message CommonReply{}
message MapReq{
map<string, string> m = 1;
}
message RepeatedReq{
repeated string id = 1;
}
service Test_Service {
// service
rpc Service (Req) returns (Reply);
// greet service
rpc GreetService (Req) returns (Reply);
// case snake
rpc snake_service (snake_req) returns (snake_reply);
// case camel
rpc CamelService (CamelReq) returns (CamelReply);
// case enum
rpc EnumService (EnumMessage) returns (CommonReply);
// case map
rpc MapService (MapReq) returns (CommonReply);
// case repeated
rpc RepeatedService (RepeatedReq) returns (CommonReply);
} }

View File

@ -1,12 +0,0 @@
// test proto
syntax = "proto3";
package base;
message CommonReq {
string in = 1;
}
message CommonReply {
string out = 1;
}

View File

@ -1,18 +0,0 @@
// test proto
syntax = "proto3";
package stream;
option go_package="go";
message StreamReq {
string name = 1;
}
message StreamResp {
string greet = 1;
}
service StreamGreeter {
rpc greet (StreamReq) returns (StreamResp);
}

View File

@ -1,18 +0,0 @@
// test proto
syntax = "proto3";
package greet;
import "base/common.proto";
message In {
string name = 1;
common.User user = 2;
}
message Out {
string greet = 1;
}
service StreamGreeter {
rpc greet (In) returns (Out);
}

View File

@ -1,18 +0,0 @@
// test proto
syntax = "proto3";
package stream;
option go_package="github.com/tal-tech/go-zero";
message StreamReq {
string name = 1;
}
message StreamResp {
string greet = 1;
}
service StreamGreeter {
rpc greet (StreamReq) returns (StreamResp);
}

View File

@ -1,27 +0,0 @@
// test proto
syntax = "proto3";
package snake_package;
message StreamReq {
string name = 1;
}
message Stream_Resp {
string greet = 1;
}
message lowercase {
string in = 1;
string lower = 2;
}
message CamelCase {
string Camel = 1;
}
service Stream_Greeter {
rpc snake_service(StreamReq) returns (Stream_Resp);
rpc ServiceCamelCase(CamelCase) returns (CamelCase);
rpc servicelowercase(lowercase) returns (lowercase);
}

View File

@ -1,17 +0,0 @@
// test proto
syntax = "proto3";
package stream;
message StreamReq {
string name = 1;
}
message StreamResp {
string greet = 1;
}
service StreamGreeter {
// greet service
rpc greet (StreamReq) returns (StreamResp);
}

View File

@ -1,18 +0,0 @@
// test proto
syntax = "proto3";
package stream;
option go_package="user";
message StreamReq {
string name = 1;
}
message StreamResp {
string greet = 1;
}
service StreamGreeter {
rpc greet(StreamReq) returns (StreamResp);
}

View File

@ -29,9 +29,11 @@ func (s String) IsEmptyOrSpace() bool {
func (s String) Lower() string { func (s String) Lower() string {
return strings.ToLower(s.source) return strings.ToLower(s.source)
} }
func (s String) Upper() string { func (s String) Upper() string {
return strings.ToUpper(s.source) return strings.ToUpper(s.source)
} }
func (s String) Title() string { func (s String) Title() string {
if s.IsEmptyOrSpace() { if s.IsEmptyOrSpace() {
return s.source return s.source