diff --git a/tools/goctl/api/gogen/util.go b/tools/goctl/api/gogen/util.go
index 6550f297..9423d7e6 100644
--- a/tools/goctl/api/gogen/util.go
+++ b/tools/goctl/api/gogen/util.go
@@ -10,29 +10,20 @@ import (
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util"
- goctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
- "github.com/tal-tech/go-zero/tools/goctl/util/project"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func getParentPackage(dir string) (string, error) {
- p, err := project.Prepare(dir, false)
+ abs, err := filepath.Abs(dir)
if err != nil {
return "", err
}
- if len(p.GoMod.Path) > 0 {
- goModePath := filepath.Clean(filepath.Dir(p.GoMod.Path))
- absPath, err := filepath.Abs(dir)
- if err != nil {
- return "", err
- }
- parent := filepath.Clean(goctlutil.JoinPackages(p.GoMod.Module, absPath[len(goModePath):]))
- parent = strings.ReplaceAll(parent, "\\", "/")
- parent = strings.ReplaceAll(parent, `\`, "/")
- return parent, nil
+ projectCtx, err := ctx.Prepare(abs)
+ if err != nil {
+ return "", err
}
-
- return p.Package, nil
+ return filepath.ToSlash(filepath.Join(projectCtx.Path, strings.TrimPrefix(projectCtx.WorkDir, projectCtx.Dir))), nil
}
func writeIndent(writer io.Writer, indent int) {
diff --git a/tools/goctl/goctl.go b/tools/goctl/goctl.go
index cddc274e..b4e96f1f 100644
--- a/tools/goctl/goctl.go
+++ b/tools/goctl/goctl.go
@@ -19,7 +19,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/configgen"
"github.com/tal-tech/go-zero/tools/goctl/docker"
model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command"
- rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/command"
+ rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/cli"
"github.com/tal-tech/go-zero/tools/goctl/tpl"
"github.com/urfave/cli"
)
@@ -211,7 +211,7 @@ var (
Flags: []cli.Flag{
cli.BoolFlag{
Name: "idea",
- Usage: "whether the command execution environment is from idea plugin. [option]",
+ Usage: "whether the command execution environment is from idea plugin. [optional]",
},
},
Action: rpc.RpcNew,
@@ -226,7 +226,7 @@ var (
},
cli.BoolFlag{
Name: "idea",
- Usage: "whether the command execution environment is from idea plugin. [option]",
+ Usage: "whether the command execution environment is from idea plugin. [optional]",
},
},
Action: rpc.RpcTemplate,
@@ -239,17 +239,17 @@ var (
Name: "src, s",
Usage: "the file path of the proto source file",
},
- cli.StringFlag{
- Name: "dir, d",
- Usage: `the target path of the code,default path is "${pwd}". [option]`,
+ cli.StringSliceFlag{
+ Name: "proto_path, I",
+ Usage: `native command of protoc, specify the directory in which to search for imports. [optional]`,
},
cli.StringFlag{
- Name: "service, srv",
- Usage: `the name of rpc service. [option]`,
+ Name: "dir, d",
+ Usage: `the target path of the code`,
},
cli.BoolFlag{
Name: "idea",
- Usage: "whether the command execution environment is from idea plugin. [option]",
+ Usage: "whether the command execution environment is from idea plugin. [optional]",
},
},
Action: rpc.Rpc,
@@ -313,7 +313,7 @@ var (
},
cli.StringFlag{
Name: "style",
- Usage: "the file naming style, lower|camel|underline,default is lower",
+ Usage: "the file naming style, lower|camel|snake, default is lower",
},
cli.BoolFlag{
Name: "idea",
diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go
index b12ebada..fc377ff9 100644
--- a/tools/goctl/model/sql/command/command.go
+++ b/tools/goctl/model/sql/command/command.go
@@ -40,7 +40,7 @@ func MysqlDDL(ctx *cli.Context) error {
}
switch namingStyle {
- case gen.NamingLower, gen.NamingCamel, gen.NamingUnderline:
+ case gen.NamingLower, gen.NamingCamel, gen.NamingSnake:
case "":
namingStyle = gen.NamingLower
default:
@@ -87,7 +87,7 @@ func MyDataSource(ctx *cli.Context) error {
}
switch namingStyle {
- case gen.NamingLower, gen.NamingCamel, gen.NamingUnderline:
+ case gen.NamingLower, gen.NamingCamel, gen.NamingSnake:
case "":
namingStyle = gen.NamingLower
default:
diff --git a/tools/goctl/model/sql/converter/types.go b/tools/goctl/model/sql/converter/types.go
index c6f93105..a1b1a175 100644
--- a/tools/goctl/model/sql/converter/types.go
+++ b/tools/goctl/model/sql/converter/types.go
@@ -8,30 +8,36 @@ import (
var (
commonMysqlDataTypeMap = map[string]string{
// For consistency, all integer types are converted to int64
- "tinyint": "int64",
- "smallint": "int64",
- "mediumint": "int64",
- "int": "int64",
- "integer": "int64",
- "bigint": "int64",
- "float": "float64",
- "double": "float64",
- "decimal": "float64",
- "date": "time.Time",
- "time": "string",
- "year": "int64",
- "datetime": "time.Time",
- "timestamp": "time.Time",
+ // number
+ "bool": "int64",
+ "boolean": "int64",
+ "tinyint": "int64",
+ "smallint": "int64",
+ "mediumint": "int64",
+ "int": "int64",
+ "integer": "int64",
+ "bigint": "int64",
+ "float": "float64",
+ "double": "float64",
+ "decimal": "float64",
+ // date&time
+ "date": "time.Time",
+ "datetime": "time.Time",
+ "timestamp": "time.Time",
+ "time": "string",
+ "year": "int64",
+ // string
"char": "string",
"varchar": "string",
- "tinyblob": "string",
+ "binary": "string",
+ "varbinary": "string",
"tinytext": "string",
- "blob": "string",
"text": "string",
- "mediumblob": "string",
"mediumtext": "string",
- "longblob": "string",
"longtext": "string",
+ "enum": "string",
+ "set": "string",
+ "json": "string",
}
)
diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go
index 9bb0d304..752fe0ee 100644
--- a/tools/goctl/model/sql/gen/gen.go
+++ b/tools/goctl/model/sql/gen/gen.go
@@ -19,7 +19,7 @@ const (
createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case
NamingLower = "lower"
NamingCamel = "camel"
- NamingUnderline = "underline"
+ NamingSnake = "snake"
)
type (
@@ -81,7 +81,7 @@ func (g *defaultGenerator) Start(withCache bool) error {
switch g.namingStyle {
case NamingCamel:
name = fmt.Sprintf("%sModel.go", tn.ToCamel())
- case NamingUnderline:
+ case NamingSnake:
name = fmt.Sprintf("%s_model.go", tn.ToSnake())
}
filename := filepath.Join(dirAbs, name)
diff --git a/tools/goctl/model/sql/gen/gen_test.go b/tools/goctl/model/sql/gen/gen_test.go
index 7e615816..656b6e50 100644
--- a/tools/goctl/model/sql/gen/gen_test.go
+++ b/tools/goctl/model/sql/gen/gen_test.go
@@ -1,6 +1,8 @@
package gen
import (
+ "os"
+ "path/filepath"
"testing"
"github.com/stretchr/testify/assert"
@@ -8,27 +10,55 @@ import (
)
var (
- source = "CREATE TABLE `test_user_info` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `nanosecond` bigint NOT NULL DEFAULT '0',\n `data` varchar(255) DEFAULT '',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `nanosecond_unique` (`nanosecond`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;"
+ source = "CREATE TABLE `test_user_info` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `nanosecond` bigint NOT NULL DEFAULT '0',\n `data` varchar(255) DEFAULT '',\n `content` json DEFAULT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `nanosecond_unique` (`nanosecond`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;"
)
func TestCacheModel(t *testing.T) {
logx.Disable()
_ = Clean()
- g := NewDefaultGenerator(source, "./testmodel/cache", NamingLower)
+ dir, _ := filepath.Abs("./testmodel")
+ cacheDir := filepath.Join(dir, "cache")
+ noCacheDir := filepath.Join(dir, "nocache")
+ defer func() {
+ _ = os.RemoveAll(dir)
+ }()
+ g := NewDefaultGenerator(source, cacheDir, NamingLower)
err := g.Start(true)
assert.Nil(t, err)
- g = NewDefaultGenerator(source, "./testmodel/nocache", NamingLower)
+ assert.True(t, func() bool {
+ _, err := os.Stat(filepath.Join(cacheDir, "testuserinfomodel.go"))
+ return err == nil
+ }())
+ g = NewDefaultGenerator(source, noCacheDir, NamingLower)
err = g.Start(false)
assert.Nil(t, err)
+ assert.True(t, func() bool {
+ _, err := os.Stat(filepath.Join(noCacheDir, "testuserinfomodel.go"))
+ return err == nil
+ }())
}
func TestNamingModel(t *testing.T) {
logx.Disable()
_ = Clean()
- g := NewDefaultGenerator(source, "./testmodel/camel", NamingCamel)
+ dir, _ := filepath.Abs("./testmodel")
+ camelDir := filepath.Join(dir, "camel")
+ snakeDir := filepath.Join(dir, "snake")
+ defer func() {
+ _ = os.RemoveAll(dir)
+ }()
+ g := NewDefaultGenerator(source, camelDir, NamingCamel)
err := g.Start(true)
assert.Nil(t, err)
- g = NewDefaultGenerator(source, "./testmodel/snake", NamingUnderline)
+ assert.True(t, func() bool {
+ _, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go"))
+ return err == nil
+ }())
+ g = NewDefaultGenerator(source, snakeDir, NamingSnake)
err = g.Start(true)
assert.Nil(t, err)
+ assert.True(t, func() bool {
+ _, err := os.Stat(filepath.Join(snakeDir, "test_user_info_model.go"))
+ return err == nil
+ }())
}
diff --git a/tools/goctl/model/sql/gen/testmodel/cache/testuserinfomodel.go b/tools/goctl/model/sql/gen/testmodel/cache/testuserinfomodel.go
deleted file mode 100755
index e657ce19..00000000
--- a/tools/goctl/model/sql/gen/testmodel/cache/testuserinfomodel.go
+++ /dev/null
@@ -1,125 +0,0 @@
-package cache
-
-import (
- "database/sql"
- "fmt"
- "strings"
- "time"
-
- "github.com/tal-tech/go-zero/core/stores/cache"
- "github.com/tal-tech/go-zero/core/stores/sqlc"
- "github.com/tal-tech/go-zero/core/stores/sqlx"
- "github.com/tal-tech/go-zero/core/stringx"
- "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
-)
-
-var (
- testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{})
- testUserInfoRows = strings.Join(testUserInfoFieldNames, ",")
- testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",")
- testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
-
- cacheTestUserInfoIdPrefix = "cache#TestUserInfo#id#"
- cacheTestUserInfoNanosecondPrefix = "cache#TestUserInfo#nanosecond#"
-)
-
-type (
- TestUserInfoModel struct {
- sqlc.CachedConn
- table string
- }
-
- TestUserInfo struct {
- Id int64 `db:"id"`
- Nanosecond int64 `db:"nanosecond"`
- Data string `db:"data"`
- CreateTime time.Time `db:"create_time"`
- UpdateTime time.Time `db:"update_time"`
- }
-)
-
-func NewTestUserInfoModel(conn sqlx.SqlConn, c cache.CacheConf) *TestUserInfoModel {
- return &TestUserInfoModel{
- CachedConn: sqlc.NewConn(conn, c),
- table: "test_user_info",
- }
-}
-
-func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) {
- testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
- ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
- query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet)
- return conn.Exec(query, data.Nanosecond, data.Data)
- }, testUserInfoNanosecondKey)
- return ret, err
-}
-
-func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) {
- testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
- var resp TestUserInfo
- err := m.QueryRow(&resp, testUserInfoIdKey, func(conn sqlx.SqlConn, v interface{}) error {
- query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
- return conn.QueryRow(v, query, id)
- })
- switch err {
- case nil:
- return &resp, nil
- case sqlc.ErrNotFound:
- return nil, ErrNotFound
- default:
- return nil, err
- }
-}
-
-func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) {
- testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, nanosecond)
- var resp TestUserInfo
- err := m.QueryRowIndex(&resp, testUserInfoNanosecondKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
- query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table)
- if err := conn.QueryRow(&resp, query, nanosecond); err != nil {
- return nil, err
- }
- return resp.Id, nil
- }, m.queryPrimary)
- switch err {
- case nil:
- return &resp, nil
- case sqlc.ErrNotFound:
- return nil, ErrNotFound
- default:
- return nil, err
- }
-}
-
-func (m *TestUserInfoModel) Update(data TestUserInfo) error {
- testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, data.Id)
- _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
- query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder)
- return conn.Exec(query, data.Nanosecond, data.Data, data.Id)
- }, testUserInfoIdKey)
- return err
-}
-
-func (m *TestUserInfoModel) Delete(id int64) error {
- data, err := m.FindOne(id)
- if err != nil {
- return err
- }
-
- testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
- testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
- _, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
- query := fmt.Sprintf("delete from %s where id = ?", m.table)
- return conn.Exec(query, id)
- }, testUserInfoNanosecondKey, testUserInfoIdKey)
- return err
-}
-
-func (m *TestUserInfoModel) formatPrimary(primary interface{}) string {
- return fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, primary)
-}
-
-func (m *TestUserInfoModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
- query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
- return conn.QueryRow(v, query, primary)
-}
diff --git a/tools/goctl/model/sql/gen/testmodel/cache/vars.go b/tools/goctl/model/sql/gen/testmodel/cache/vars.go
deleted file mode 100644
index 14e43396..00000000
--- a/tools/goctl/model/sql/gen/testmodel/cache/vars.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package cache
-
-import "github.com/tal-tech/go-zero/core/stores/sqlx"
-
-var ErrNotFound = sqlx.ErrNotFound
diff --git a/tools/goctl/model/sql/gen/testmodel/camel/TestUserInfoModel.go b/tools/goctl/model/sql/gen/testmodel/camel/TestUserInfoModel.go
deleted file mode 100755
index 3fc2da64..00000000
--- a/tools/goctl/model/sql/gen/testmodel/camel/TestUserInfoModel.go
+++ /dev/null
@@ -1,125 +0,0 @@
-package camel
-
-import (
- "database/sql"
- "fmt"
- "strings"
- "time"
-
- "github.com/tal-tech/go-zero/core/stores/cache"
- "github.com/tal-tech/go-zero/core/stores/sqlc"
- "github.com/tal-tech/go-zero/core/stores/sqlx"
- "github.com/tal-tech/go-zero/core/stringx"
- "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
-)
-
-var (
- testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{})
- testUserInfoRows = strings.Join(testUserInfoFieldNames, ",")
- testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",")
- testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
-
- cacheTestUserInfoIdPrefix = "cache#TestUserInfo#id#"
- cacheTestUserInfoNanosecondPrefix = "cache#TestUserInfo#nanosecond#"
-)
-
-type (
- TestUserInfoModel struct {
- sqlc.CachedConn
- table string
- }
-
- TestUserInfo struct {
- Id int64 `db:"id"`
- Nanosecond int64 `db:"nanosecond"`
- Data string `db:"data"`
- CreateTime time.Time `db:"create_time"`
- UpdateTime time.Time `db:"update_time"`
- }
-)
-
-func NewTestUserInfoModel(conn sqlx.SqlConn, c cache.CacheConf) *TestUserInfoModel {
- return &TestUserInfoModel{
- CachedConn: sqlc.NewConn(conn, c),
- table: "test_user_info",
- }
-}
-
-func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) {
- testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
- ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
- query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet)
- return conn.Exec(query, data.Nanosecond, data.Data)
- }, testUserInfoNanosecondKey)
- return ret, err
-}
-
-func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) {
- testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
- var resp TestUserInfo
- err := m.QueryRow(&resp, testUserInfoIdKey, func(conn sqlx.SqlConn, v interface{}) error {
- query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
- return conn.QueryRow(v, query, id)
- })
- switch err {
- case nil:
- return &resp, nil
- case sqlc.ErrNotFound:
- return nil, ErrNotFound
- default:
- return nil, err
- }
-}
-
-func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) {
- testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, nanosecond)
- var resp TestUserInfo
- err := m.QueryRowIndex(&resp, testUserInfoNanosecondKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
- query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table)
- if err := conn.QueryRow(&resp, query, nanosecond); err != nil {
- return nil, err
- }
- return resp.Id, nil
- }, m.queryPrimary)
- switch err {
- case nil:
- return &resp, nil
- case sqlc.ErrNotFound:
- return nil, ErrNotFound
- default:
- return nil, err
- }
-}
-
-func (m *TestUserInfoModel) Update(data TestUserInfo) error {
- testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, data.Id)
- _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
- query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder)
- return conn.Exec(query, data.Nanosecond, data.Data, data.Id)
- }, testUserInfoIdKey)
- return err
-}
-
-func (m *TestUserInfoModel) Delete(id int64) error {
- data, err := m.FindOne(id)
- if err != nil {
- return err
- }
-
- testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
- testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
- _, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
- query := fmt.Sprintf("delete from %s where id = ?", m.table)
- return conn.Exec(query, id)
- }, testUserInfoIdKey, testUserInfoNanosecondKey)
- return err
-}
-
-func (m *TestUserInfoModel) formatPrimary(primary interface{}) string {
- return fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, primary)
-}
-
-func (m *TestUserInfoModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
- query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
- return conn.QueryRow(v, query, primary)
-}
diff --git a/tools/goctl/model/sql/gen/testmodel/camel/vars.go b/tools/goctl/model/sql/gen/testmodel/camel/vars.go
deleted file mode 100644
index 9e6eb09e..00000000
--- a/tools/goctl/model/sql/gen/testmodel/camel/vars.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package camel
-
-import "github.com/tal-tech/go-zero/core/stores/sqlx"
-
-var ErrNotFound = sqlx.ErrNotFound
diff --git a/tools/goctl/model/sql/gen/testmodel/nocache/testuserinfomodel.go b/tools/goctl/model/sql/gen/testmodel/nocache/testuserinfomodel.go
deleted file mode 100755
index ee5f3e42..00000000
--- a/tools/goctl/model/sql/gen/testmodel/nocache/testuserinfomodel.go
+++ /dev/null
@@ -1,88 +0,0 @@
-package nocache
-
-import (
- "database/sql"
- "fmt"
- "strings"
- "time"
-
- "github.com/tal-tech/go-zero/core/stores/sqlc"
- "github.com/tal-tech/go-zero/core/stores/sqlx"
- "github.com/tal-tech/go-zero/core/stringx"
- "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
-)
-
-var (
- testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{})
- testUserInfoRows = strings.Join(testUserInfoFieldNames, ",")
- testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",")
- testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
-)
-
-type (
- TestUserInfoModel struct {
- conn sqlx.SqlConn
- table string
- }
-
- TestUserInfo struct {
- Id int64 `db:"id"`
- Nanosecond int64 `db:"nanosecond"`
- Data string `db:"data"`
- CreateTime time.Time `db:"create_time"`
- UpdateTime time.Time `db:"update_time"`
- }
-)
-
-func NewTestUserInfoModel(conn sqlx.SqlConn) *TestUserInfoModel {
- return &TestUserInfoModel{
- conn: conn,
- table: "test_user_info",
- }
-}
-
-func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) {
- query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet)
- ret, err := m.conn.Exec(query, data.Nanosecond, data.Data)
- return ret, err
-}
-
-func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) {
- query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
- var resp TestUserInfo
- err := m.conn.QueryRow(&resp, query, id)
- switch err {
- case nil:
- return &resp, nil
- case sqlc.ErrNotFound:
- return nil, ErrNotFound
- default:
- return nil, err
- }
-}
-
-func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) {
- var resp TestUserInfo
- query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table)
- err := m.conn.QueryRow(&resp, query, nanosecond)
- switch err {
- case nil:
- return &resp, nil
- case sqlc.ErrNotFound:
- return nil, ErrNotFound
- default:
- return nil, err
- }
-}
-
-func (m *TestUserInfoModel) Update(data TestUserInfo) error {
- query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder)
- _, err := m.conn.Exec(query, data.Nanosecond, data.Data, data.Id)
- return err
-}
-
-func (m *TestUserInfoModel) Delete(id int64) error {
- query := fmt.Sprintf("delete from %s where id = ?", m.table)
- _, err := m.conn.Exec(query, id)
- return err
-}
diff --git a/tools/goctl/model/sql/gen/testmodel/nocache/vars.go b/tools/goctl/model/sql/gen/testmodel/nocache/vars.go
deleted file mode 100644
index c6c2f592..00000000
--- a/tools/goctl/model/sql/gen/testmodel/nocache/vars.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package nocache
-
-import "github.com/tal-tech/go-zero/core/stores/sqlx"
-
-var ErrNotFound = sqlx.ErrNotFound
diff --git a/tools/goctl/model/sql/gen/testmodel/snake/test_user_info_model.go b/tools/goctl/model/sql/gen/testmodel/snake/test_user_info_model.go
deleted file mode 100755
index 54de8dee..00000000
--- a/tools/goctl/model/sql/gen/testmodel/snake/test_user_info_model.go
+++ /dev/null
@@ -1,125 +0,0 @@
-package snake
-
-import (
- "database/sql"
- "fmt"
- "strings"
- "time"
-
- "github.com/tal-tech/go-zero/core/stores/cache"
- "github.com/tal-tech/go-zero/core/stores/sqlc"
- "github.com/tal-tech/go-zero/core/stores/sqlx"
- "github.com/tal-tech/go-zero/core/stringx"
- "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
-)
-
-var (
- testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{})
- testUserInfoRows = strings.Join(testUserInfoFieldNames, ",")
- testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",")
- testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
-
- cacheTestUserInfoIdPrefix = "cache#TestUserInfo#id#"
- cacheTestUserInfoNanosecondPrefix = "cache#TestUserInfo#nanosecond#"
-)
-
-type (
- TestUserInfoModel struct {
- sqlc.CachedConn
- table string
- }
-
- TestUserInfo struct {
- Id int64 `db:"id"`
- Nanosecond int64 `db:"nanosecond"`
- Data string `db:"data"`
- CreateTime time.Time `db:"create_time"`
- UpdateTime time.Time `db:"update_time"`
- }
-)
-
-func NewTestUserInfoModel(conn sqlx.SqlConn, c cache.CacheConf) *TestUserInfoModel {
- return &TestUserInfoModel{
- CachedConn: sqlc.NewConn(conn, c),
- table: "test_user_info",
- }
-}
-
-func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) {
- testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
- ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
- query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet)
- return conn.Exec(query, data.Nanosecond, data.Data)
- }, testUserInfoNanosecondKey)
- return ret, err
-}
-
-func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) {
- testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
- var resp TestUserInfo
- err := m.QueryRow(&resp, testUserInfoIdKey, func(conn sqlx.SqlConn, v interface{}) error {
- query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
- return conn.QueryRow(v, query, id)
- })
- switch err {
- case nil:
- return &resp, nil
- case sqlc.ErrNotFound:
- return nil, ErrNotFound
- default:
- return nil, err
- }
-}
-
-func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) {
- testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, nanosecond)
- var resp TestUserInfo
- err := m.QueryRowIndex(&resp, testUserInfoNanosecondKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
- query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table)
- if err := conn.QueryRow(&resp, query, nanosecond); err != nil {
- return nil, err
- }
- return resp.Id, nil
- }, m.queryPrimary)
- switch err {
- case nil:
- return &resp, nil
- case sqlc.ErrNotFound:
- return nil, ErrNotFound
- default:
- return nil, err
- }
-}
-
-func (m *TestUserInfoModel) Update(data TestUserInfo) error {
- testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, data.Id)
- _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
- query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder)
- return conn.Exec(query, data.Nanosecond, data.Data, data.Id)
- }, testUserInfoIdKey)
- return err
-}
-
-func (m *TestUserInfoModel) Delete(id int64) error {
- data, err := m.FindOne(id)
- if err != nil {
- return err
- }
-
- testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
- testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
- _, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
- query := fmt.Sprintf("delete from %s where id = ?", m.table)
- return conn.Exec(query, id)
- }, testUserInfoIdKey, testUserInfoNanosecondKey)
- return err
-}
-
-func (m *TestUserInfoModel) formatPrimary(primary interface{}) string {
- return fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, primary)
-}
-
-func (m *TestUserInfoModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
- query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
- return conn.QueryRow(v, query, primary)
-}
diff --git a/tools/goctl/model/sql/gen/testmodel/snake/vars.go b/tools/goctl/model/sql/gen/testmodel/snake/vars.go
deleted file mode 100644
index e18aa209..00000000
--- a/tools/goctl/model/sql/gen/testmodel/snake/vars.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package snake
-
-import "github.com/tal-tech/go-zero/core/stores/sqlx"
-
-var ErrNotFound = sqlx.ErrNotFound
diff --git a/tools/goctl/rpc/CHANGELOG.md b/tools/goctl/rpc/CHANGELOG.md
deleted file mode 100644
index ee4429eb..00000000
--- a/tools/goctl/rpc/CHANGELOG.md
+++ /dev/null
@@ -1,20 +0,0 @@
-# Change log
-
-## 2020-10-19
-
-* 增加template
-
-## 2020-09-10
-
-* rpc greet服务一键生成
-* 修复相对路径生成rpc服务package引入错误bug
-* 移除`--shared`参数
-
-## 2020-08-29
-
-* 新增支持windows生成
-
-## 2020-08-27
-
-* 新增支持rpc模板生成
-* 新增支持rpc服务生成
diff --git a/tools/goctl/rpc/README.md b/tools/goctl/rpc/README.md
index 7c5c2d50..d9a35e4f 100644
--- a/tools/goctl/rpc/README.md
+++ b/tools/goctl/rpc/README.md
@@ -7,9 +7,8 @@ Goctl Rpc是`goctl`脚手架下的一个rpc服务代码生成模块,支持prot
* 简单易用
* 快速提升开发效率
* 出错率低
-* 支持基于main proto作为相对路径的import
-* 支持map、enum类型
-* 支持any类型
+* 贴近protoc
+
## 快速开始
@@ -19,44 +18,41 @@ Goctl Rpc是`goctl`脚手架下的一个rpc服务代码生成模块,支持prot
如生成greet rpc服务:
- ```shell script
+ ```Bash
goctl rpc new greet
```
执行后代码结构如下:
```golang
- └── greet
- ├── etc
- │ └── greet.yaml
- ├── go.mod
- ├── go.sum
- ├── greet
- │ ├── greet.go
- │ ├── greet_mock.go
- │ └── types.go
- ├── greet.go
- ├── greet.proto
- ├── internal
- │ ├── config
- │ │ └── config.go
- │ ├── logic
- │ │ └── pinglogic.go
- │ ├── server
- │ │ └── greetserver.go
- │ └── svc
- │ └── servicecontext.go
- └── pb
- └── greet.pb.go
+.
+├── etc // 配置文件
+│ └── greet.yaml
+├── go.mod
+├── greet // client call
+│ └── greet.go
+├── greet.go // main entry
+├── greet.proto
+└── internal
+ ├── config // 配置声明
+ │ └── config.go
+ ├── greet // pb.go
+ │ └── greet.pb.go
+ ├── logic // logic
+ │ └── pinglogic.go
+ ├── server // pb invoker
+ │ └── greetserver.go
+ └── svc // resource dependency
+ └── servicecontext.go
```
-rpc一键生成常见问题解决见 常见问题解决
+rpc一键生成常见问题解决,见 常见问题解决
### 方式二:通过指定proto生成rpc服务
* 生成proto模板
- ```shell script
+ ```Bash
goctl rpc template -o=user.proto
```
@@ -87,35 +83,10 @@ rpc一键生成常见问题解决见 常见问题
* 生成rpc服务代码
- ```shell script
+ ```Bash
goctl rpc proto -src=user.proto
```
- 代码tree
-
- ```Plain Text
- user
- ├── etc
- │ └── user.json
- ├── internal
- │ ├── config
- │ │ └── config.go
- │ ├── handler
- │ │ ├── loginhandler.go
- │ ├── logic
- │ │ └── loginlogic.go
- │ └── svc
- │ └── servicecontext.go
- ├── pb
- │ └── user.pb.go
- ├── shared
- │ ├── mockusermodel.go
- │ ├── types.go
- │ └── usermodel.go
- ├── user.go
- └── user.proto
- ```
-
## 准备工作
* 安装了go环境
@@ -126,11 +97,11 @@ rpc一键生成常见问题解决见 常见问题
### rpc服务生成用法
-```shell script
+```Bash
goctl rpc proto -h
```
-```shell script
+```Bash
NAME:
goctl rpc proto - generate rpc from proto
@@ -139,35 +110,22 @@ USAGE:
OPTIONS:
--src value, -s value the file path of the proto source file
- --dir value, -d value the target path of the code,default path is "${pwd}". [option]
- --service value, --srv value the name of rpc service. [option]
- --idea whether the command execution environment is from idea plugin. [option]
-
+ --proto_path value, -I value native command of protoc,specify the directory in which to search for imports. [optional]
+ --dir value, -d value the target path of the code,default path is "${pwd}". [optional]
+ --idea whether the command execution environment is from idea plugin. [optional]
```
### 参数说明
-* --src 必填,proto数据源,目前暂时支持单个proto文件生成,这里不支持(不建议)外部依赖
-* --dir 非必填,默认为proto文件所在目录,生成代码的目标目录
-* --service 服务名称,非必填,默认为proto文件所在目录名称,但是,如果proto所在目录为一下结构:
-
- ```shell script
- user
- ├── cmd
- │ └── rpc
- │ └── user.proto
- ```
-
- 则服务名称亦为user,而非proto所在文件夹名称了,这里推荐使用这种结构,可以方便在同一个服务名下建立不同类型的服务(api、rpc、mq等),便于代码管理与维护。
-
- > 注意:这里的shared文件夹名称将会是代码中的package名称。
-
-* --idea 非必填,是否为idea插件中执行,保留字段,终端执行可以忽略
+* --src 必填,proto数据源,目前暂时支持单个proto文件生成
+* --proto_path 可选,protoc原生子命令,用于指定proto import从何处查找,可指定多个路径,如`goctl rpc -I={path1} -I={path2} ...`,在没有import时可不填。当前proto路径不用指定,已经内置,`-I`的详细用法请参考`protoc -h`
+* --dir 可选,默认为proto文件所在目录,生成代码的目标目录
+* --idea 可选,是否为idea插件中执行,终端执行可以忽略
### 开发人员需要做什么
-关注业务代码编写,将重复性、与业务无关的工作交给goctl,生成好rpc服务代码后,开饭人员仅需要修改
+关注业务代码编写,将重复性、与业务无关的工作交给goctl,生成好rpc服务代码后,开发人员仅需要修改
* 服务中的配置文件编写(etc/xx.json、internal/config/config.go)
* 服务中业务逻辑编写(internal/logic/xxlogic.go)
@@ -193,69 +151,54 @@ OPTIONS:
的标识,请注意不要将也写业务性代码写在里面。
-## any和import支持
-* 支持any类型声明
-* 支持import其他proto文件
+## proto import
+* 对于rpc中的requestType和returnType必须在main proto文件定义,对于proto中的message可以像protoc一样import其他proto文件。
- any类型固定import为`google/protobuf/any.proto`,且从${GOPATH}/src中查找,proto的import支持main proto的相对路径的import,且与proto文件对应的pb.go文件必须在proto目录中能被找到。不支持工程外的其他proto文件import。
+proto示例:
-> ⚠️注意: 不支持proto嵌套import,即:被import的proto文件不支持import。
-
-### import书写格式
-import书写格式
-```golang
-// @{package_of_pb}
-import {proto_omport}
-```
-@{package_of_pb}:pb文件的真实import目录。
-{proto_omport}:proto import
-
-
-如:demo中的
-
-```golang
-// @greet/base
-import "base/base.proto";
-```
-
-工程目录结构如下
-```
-greet
-│ ├── base
-│ │ ├── base.pb.go
-│ │ └── base.proto
-│ ├── demo.proto
-│ ├── go.mod
-│ └── go.sum
-```
-
-demo
-```golang
+### 错误import
+```proto
syntax = "proto3";
-import "google/protobuf/any.proto";
-// @greet/base
-import "base/base.proto";
-package stream;
+package greet;
-enum Gender{
- UNKNOWN = 0;
- MAN = 1;
- WOMAN = 2;
+import "base/common.proto"
+
+message Request {
+ string ping = 1;
}
-message StreamResp{
- string name = 2;
- Gender gender = 3;
- google.protobuf.Any details = 5;
- base.StreamReq req = 6;
+message Response {
+ string pong = 1;
}
-service StreamGreeter {
- rpc greet(base.StreamReq) returns (StreamResp);
+
+service Greet {
+ rpc Ping(base.In) returns(base.Out);// request和return 不支持import
}
+
```
+### 正确import
+```proto
+syntax = "proto3";
+
+package greet;
+
+import "base/common.proto"
+
+message Request {
+ base.In in = 1;// 支持import
+}
+
+message Response {
+ base.Out out = 2;// 支持import
+}
+
+service Greet {
+ rpc Ping(Request) returns(Response);
+}
+```
## 常见问题解决(go mod工程)
diff --git a/tools/goctl/rpc/base.pb.go b/tools/goctl/rpc/base.pb.go
deleted file mode 100644
index 0f2653f9..00000000
--- a/tools/goctl/rpc/base.pb.go
+++ /dev/null
@@ -1,108 +0,0 @@
-// Code generated by protoc-gen-go. DO NOT EDIT.
-// source: base.proto
-
-package base
-
-import (
- fmt "fmt"
- proto "github.com/golang/protobuf/proto"
- math "math"
-)
-
-// Reference imports to suppress errors if they are not otherwise used.
-var _ = proto.Marshal
-var _ = fmt.Errorf
-var _ = math.Inf
-
-// This is a compile-time assertion to ensure that this generated file
-// is compatible with the proto package it is being compiled against.
-// A compilation error at this line likely means your copy of the
-// proto package needs to be updated.
-const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
-
-type IdRequest struct {
- Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
- XXX_NoUnkeyedLiteral struct{} `json:"-"`
- XXX_unrecognized []byte `json:"-"`
- XXX_sizecache int32 `json:"-"`
-}
-
-func (m *IdRequest) Reset() { *m = IdRequest{} }
-func (m *IdRequest) String() string { return proto.CompactTextString(m) }
-func (*IdRequest) ProtoMessage() {}
-func (*IdRequest) Descriptor() ([]byte, []int) {
- return fileDescriptor_db1b6b0986796150, []int{0}
-}
-
-func (m *IdRequest) XXX_Unmarshal(b []byte) error {
- return xxx_messageInfo_IdRequest.Unmarshal(m, b)
-}
-func (m *IdRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
- return xxx_messageInfo_IdRequest.Marshal(b, m, deterministic)
-}
-func (m *IdRequest) XXX_Merge(src proto.Message) {
- xxx_messageInfo_IdRequest.Merge(m, src)
-}
-func (m *IdRequest) XXX_Size() int {
- return xxx_messageInfo_IdRequest.Size(m)
-}
-func (m *IdRequest) XXX_DiscardUnknown() {
- xxx_messageInfo_IdRequest.DiscardUnknown(m)
-}
-
-var xxx_messageInfo_IdRequest proto.InternalMessageInfo
-
-func (m *IdRequest) GetId() string {
- if m != nil {
- return m.Id
- }
- return ""
-}
-
-type EmptyResponse struct {
- XXX_NoUnkeyedLiteral struct{} `json:"-"`
- XXX_unrecognized []byte `json:"-"`
- XXX_sizecache int32 `json:"-"`
-}
-
-func (m *EmptyResponse) Reset() { *m = EmptyResponse{} }
-func (m *EmptyResponse) String() string { return proto.CompactTextString(m) }
-func (*EmptyResponse) ProtoMessage() {}
-func (*EmptyResponse) Descriptor() ([]byte, []int) {
- return fileDescriptor_db1b6b0986796150, []int{1}
-}
-
-func (m *EmptyResponse) XXX_Unmarshal(b []byte) error {
- return xxx_messageInfo_EmptyResponse.Unmarshal(m, b)
-}
-func (m *EmptyResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
- return xxx_messageInfo_EmptyResponse.Marshal(b, m, deterministic)
-}
-func (m *EmptyResponse) XXX_Merge(src proto.Message) {
- xxx_messageInfo_EmptyResponse.Merge(m, src)
-}
-func (m *EmptyResponse) XXX_Size() int {
- return xxx_messageInfo_EmptyResponse.Size(m)
-}
-func (m *EmptyResponse) XXX_DiscardUnknown() {
- xxx_messageInfo_EmptyResponse.DiscardUnknown(m)
-}
-
-var xxx_messageInfo_EmptyResponse proto.InternalMessageInfo
-
-func init() {
- proto.RegisterType((*IdRequest)(nil), "base.IdRequest")
- proto.RegisterType((*EmptyResponse)(nil), "base.EmptyResponse")
-}
-
-func init() { proto.RegisterFile("base.proto", fileDescriptor_db1b6b0986796150) }
-
-var fileDescriptor_db1b6b0986796150 = []byte{
- // 91 bytes of a gzipped FileDescriptorProto
- 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4a, 0x4a, 0x2c, 0x4e,
- 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x01, 0xb1, 0x95, 0xa4, 0xb9, 0x38, 0x3d, 0x53,
- 0x82, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0xf8, 0xb8, 0x98, 0x32, 0x53, 0x24, 0x18, 0x15,
- 0x18, 0x35, 0x38, 0x83, 0x98, 0x32, 0x53, 0x94, 0xf8, 0xb9, 0x78, 0x5d, 0x73, 0x0b, 0x4a, 0x2a,
- 0x83, 0x52, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x93, 0xd8, 0xc0, 0x5a, 0x8d, 0x01, 0x01, 0x00,
- 0x00, 0xff, 0xff, 0xe1, 0x39, 0x3c, 0x22, 0x48, 0x00, 0x00, 0x00,
-}
diff --git a/tools/goctl/rpc/base.proto b/tools/goctl/rpc/base.proto
deleted file mode 100644
index 501f12f7..00000000
--- a/tools/goctl/rpc/base.proto
+++ /dev/null
@@ -1,11 +0,0 @@
-syntax = "proto3";
-
-package base;
-
-message IdRequest {
- string id = 1;
-}
-
-message EmptyResponse {
-
-}
diff --git a/tools/goctl/rpc/cli/cli.go b/tools/goctl/rpc/cli/cli.go
new file mode 100644
index 00000000..a76dde22
--- /dev/null
+++ b/tools/goctl/rpc/cli/cli.go
@@ -0,0 +1,67 @@
+package cli
+
+import (
+ "errors"
+ "fmt"
+ "path/filepath"
+
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
+ "github.com/urfave/cli"
+)
+
+// Rpc is to generate rpc service code from a proto file by specifying a proto file using flag src,
+// you can specify a target folder for code generation, when the proto file has import, you can specify
+// the import search directory through the proto_path command, for specific usage, please refer to protoc -h
+func Rpc(c *cli.Context) error {
+ src := c.String("src")
+ out := c.String("dir")
+ protoImportPath := c.StringSlice("proto_path")
+ if len(src) == 0 {
+ return errors.New("the proto source can not be nil")
+ }
+ if len(out) == 0 {
+ return errors.New("the target directory can not be nil")
+ }
+ g := generator.NewDefaultRpcGenerator()
+ return g.Generate(src, out, protoImportPath)
+}
+
+// RpcNew is to generate rpc greet service, this greet service can speed
+// up your understanding of the zrpc service structure
+func RpcNew(c *cli.Context) error {
+ name := c.Args().First()
+ ext := filepath.Ext(name)
+ if len(ext) > 0 {
+ return fmt.Errorf("unexpected ext: %s", ext)
+ }
+
+ protoName := name + ".proto"
+ filename := filepath.Join(".", name, protoName)
+ src, err := filepath.Abs(filename)
+ if err != nil {
+ return err
+ }
+
+ err = generator.ProtoTmpl(src)
+ if err != nil {
+ return err
+ }
+
+ workDir := filepath.Dir(src)
+ _, err = execx.Run("go mod init "+name, workDir)
+ if err != nil {
+ return err
+ }
+
+ g := generator.NewDefaultRpcGenerator()
+ return g.Generate(src, filepath.Dir(src), nil)
+}
+
+func RpcTemplate(c *cli.Context) error {
+ name := c.Args().First()
+ if len(name) == 0 {
+ name = "greet.proto"
+ }
+ return generator.ProtoTmpl(name)
+}
diff --git a/tools/goctl/rpc/command/command.go b/tools/goctl/rpc/command/command.go
deleted file mode 100644
index fde97dad..00000000
--- a/tools/goctl/rpc/command/command.go
+++ /dev/null
@@ -1,60 +0,0 @@
-package command
-
-import (
- "fmt"
- "os"
- "path/filepath"
-
- "github.com/tal-tech/go-zero/tools/goctl/rpc/ctx"
- "github.com/tal-tech/go-zero/tools/goctl/rpc/gen"
- "github.com/tal-tech/go-zero/tools/goctl/util"
- "github.com/urfave/cli"
-)
-
-func Rpc(c *cli.Context) error {
- rpcCtx := ctx.MustCreateRpcContextFromCli(c)
- generator := gen.NewDefaultRpcGenerator(rpcCtx)
- rpcCtx.Must(generator.Generate())
- return nil
-}
-
-func RpcTemplate(c *cli.Context) error {
- out := c.String("out")
- idea := c.Bool("idea")
- generator := gen.NewRpcTemplate(out, idea)
- generator.MustGenerate(true)
- return nil
-}
-
-func RpcNew(c *cli.Context) error {
- idea := c.Bool("idea")
- arg := c.Args().First()
- if len(arg) == 0 {
- arg = "greet"
- }
- abs, err := filepath.Abs(arg)
- if err != nil {
- return err
- }
- _, err = os.Stat(abs)
- if err != nil {
- if !os.IsNotExist(err) {
- return err
- }
- err = util.MkdirIfNotExist(abs)
- if err != nil {
- return err
- }
- }
-
- dir := filepath.Base(filepath.Clean(abs))
-
- protoSrc := filepath.Join(abs, fmt.Sprintf("%v.proto", dir))
- templateGenerator := gen.NewRpcTemplate(protoSrc, idea)
- templateGenerator.MustGenerate(false)
-
- rpcCtx := ctx.MustCreateRpcContext(protoSrc, "", "", idea)
- generator := gen.NewDefaultRpcGenerator(rpcCtx)
- rpcCtx.Must(generator.Generate())
- return nil
-}
diff --git a/tools/goctl/rpc/ctx/ctx.go b/tools/goctl/rpc/ctx/ctx.go
deleted file mode 100644
index 5facee42..00000000
--- a/tools/goctl/rpc/ctx/ctx.go
+++ /dev/null
@@ -1,99 +0,0 @@
-package ctx
-
-import (
- "fmt"
- "path/filepath"
- "runtime"
- "strings"
-
- "github.com/tal-tech/go-zero/core/logx"
- "github.com/tal-tech/go-zero/tools/goctl/util"
- "github.com/tal-tech/go-zero/tools/goctl/util/console"
- "github.com/tal-tech/go-zero/tools/goctl/util/project"
- "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
- "github.com/tal-tech/go-zero/tools/goctl/vars"
- "github.com/urfave/cli"
-)
-
-const (
- flagSrc = "src"
- flagDir = "dir"
- flagService = "service"
- flagIdea = "idea"
-)
-
-type RpcContext struct {
- ProjectPath string
- ProjectName stringx.String
- ServiceName stringx.String
- CurrentPath string
- Module string
- ProtoFileSrc string
- ProtoSource string
- TargetDir string
- IsInGoEnv bool
- console.Console
-}
-
-func MustCreateRpcContext(protoSrc, targetDir, serviceName string, idea bool) *RpcContext {
- log := console.NewConsole(idea)
-
- if stringx.From(protoSrc).IsEmptyOrSpace() {
- log.Fatalln("expected proto source, but nothing found")
- }
- srcFp, err := filepath.Abs(protoSrc)
- log.Must(err)
-
- if !util.FileExists(srcFp) {
- log.Fatalln("%s is not exists", srcFp)
- }
- current := filepath.Dir(srcFp)
- if stringx.From(targetDir).IsEmptyOrSpace() {
- targetDir = current
- }
- targetDirFp, err := filepath.Abs(targetDir)
- log.Must(err)
-
- if stringx.From(serviceName).IsEmptyOrSpace() {
- serviceName = getServiceFromRpcStructure(targetDirFp)
- }
- serviceNameString := stringx.From(serviceName)
- if serviceNameString.IsEmptyOrSpace() {
- log.Fatalln("service name not found")
- }
-
- info, err := project.Prepare(targetDir, true)
- log.Must(err)
-
- return &RpcContext{
- ProjectPath: info.Path,
- ProjectName: stringx.From(info.Name),
- ServiceName: serviceNameString,
- CurrentPath: current,
- Module: info.GoMod.Module,
- ProtoFileSrc: srcFp,
- ProtoSource: filepath.Base(srcFp),
- TargetDir: targetDirFp,
- IsInGoEnv: info.IsInGoEnv,
- Console: log,
- }
-}
-func MustCreateRpcContextFromCli(ctx *cli.Context) *RpcContext {
- os := runtime.GOOS
- switch os {
- case vars.OsMac, vars.OsLinux, vars.OsWindows:
- default:
- logx.Must(fmt.Errorf("unexpected os: %s", os))
- }
- protoSrc := ctx.String(flagSrc)
- targetDir := ctx.String(flagDir)
- serviceName := ctx.String(flagService)
- idea := ctx.Bool(flagIdea)
- return MustCreateRpcContext(protoSrc, targetDir, serviceName, idea)
-}
-
-func getServiceFromRpcStructure(targetDir string) string {
- targetDir = filepath.Clean(targetDir)
- suffix := filepath.Join("cmd", "rpc")
- return filepath.Base(strings.TrimSuffix(targetDir, suffix))
-}
diff --git a/tools/goctl/rpc/execx/execx.go b/tools/goctl/rpc/execx/execx.go
index 039a94c7..9b7e482f 100644
--- a/tools/goctl/rpc/execx/execx.go
+++ b/tools/goctl/rpc/execx/execx.go
@@ -6,7 +6,9 @@ import (
"fmt"
"os/exec"
"runtime"
+ "strings"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars"
)
@@ -24,17 +26,17 @@ func Run(arg string, dir string) (string, error) {
if len(dir) > 0 {
cmd.Dir = dir
}
- dtsout := new(bytes.Buffer)
+ stdout := new(bytes.Buffer)
stderr := new(bytes.Buffer)
- cmd.Stdout = dtsout
+ cmd.Stdout = stdout
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
if stderr.Len() > 0 {
- return "", errors.New(stderr.String())
+ return "", errors.New(strings.TrimSuffix(stderr.String(), util.NL))
}
return "", err
}
- return dtsout.String(), nil
+ return strings.TrimSuffix(stdout.String(), util.NL), nil
}
diff --git a/tools/goctl/rpc/gen/gen.go b/tools/goctl/rpc/gen/gen.go
deleted file mode 100644
index 172d9013..00000000
--- a/tools/goctl/rpc/gen/gen.go
+++ /dev/null
@@ -1,93 +0,0 @@
-package gen
-
-import (
- "github.com/logrusorgru/aurora"
- "github.com/tal-tech/go-zero/tools/goctl/rpc/ctx"
- "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
-)
-
-const (
- dirTarget = "dirTarget"
- dirConfig = "config"
- dirEtc = "etc"
- dirSvc = "svc"
- dirServer = "server"
- dirLogic = "logic"
- dirPb = "pb"
- dirInternal = "internal"
- fileConfig = "config.go"
- fileServiceContext = "servicecontext.go"
-)
-
-type defaultRpcGenerator struct {
- dirM map[string]string
- Ctx *ctx.RpcContext
- ast *parser.PbAst
-}
-
-func NewDefaultRpcGenerator(ctx *ctx.RpcContext) *defaultRpcGenerator {
- return &defaultRpcGenerator{
- Ctx: ctx,
- }
-}
-
-func (g *defaultRpcGenerator) Generate() (err error) {
- g.Ctx.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/zero-doc/blob/main/doc/goctl-rpc.md」")
- g.Ctx.Warning("-> generating rpc code ...")
- defer func() {
- if err == nil {
- g.Ctx.MarkDone()
- }
- }()
- err = g.createDir()
- if err != nil {
- return
- }
-
- err = g.initGoMod()
- if err != nil {
- return
- }
-
- err = g.genEtc()
- if err != nil {
- return
- }
-
- err = g.genPb()
- if err != nil {
- return
- }
-
- err = g.genConfig()
- if err != nil {
- return
- }
-
- err = g.genSvc()
- if err != nil {
- return
- }
-
- err = g.genLogic()
- if err != nil {
- return
- }
-
- err = g.genHandler()
- if err != nil {
- return
- }
-
- err = g.genMain()
- if err != nil {
- return
- }
-
- err = g.genCall()
- if err != nil {
- return
- }
-
- return
-}
diff --git a/tools/goctl/rpc/gen/gencall.go b/tools/goctl/rpc/gen/gencall.go
deleted file mode 100644
index d8a62339..00000000
--- a/tools/goctl/rpc/gen/gencall.go
+++ /dev/null
@@ -1,224 +0,0 @@
-package gen
-
-import (
- "fmt"
- "path/filepath"
- "strings"
-
- "github.com/tal-tech/go-zero/core/collection"
- "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
- "github.com/tal-tech/go-zero/tools/goctl/util"
-)
-
-const (
- typesFilename = "types.go"
- callTemplateText = `{{.head}}
-
-//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
-
-package {{.filePackage}}
-
-import (
- "context"
-
- {{.package}}
-
- "github.com/tal-tech/go-zero/core/jsonx"
- "github.com/tal-tech/go-zero/zrpc"
-)
-
-type (
- {{.serviceName}} interface {
- {{.interface}}
- }
-
- default{{.serviceName}} struct {
- cli zrpc.Client
- }
-)
-
-func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
- return &default{{.serviceName}}{
- cli: cli,
- }
-}
-
-{{.functions}}
-`
- callTemplateTypes = `{{.head}}
-
-package {{.filePackage}}
-
-import "errors"
-
-var errJsonConvert = errors.New("json convert error")
-
-{{.const}}
-
-{{.types}}
-`
- callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
-{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
-
- callFunctionTemplate = `
-{{if .hasComment}}{{.comment}}{{end}}
-func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequestName}}) (*{{.pbResponse}}, error) {
- var request {{.pbRequest}}
- bts, err := jsonx.Marshal(in)
- if err != nil {
- return nil, errJsonConvert
- }
-
- err = jsonx.Unmarshal(bts, &request)
- if err != nil {
- return nil, errJsonConvert
- }
-
- client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
- resp, err := client.{{.method}}(ctx, &request)
- if err != nil{
- return nil, err
- }
-
- var ret {{.pbResponse}}
- bts, err = jsonx.Marshal(resp)
- if err != nil{
- return nil, errJsonConvert
- }
-
- err = jsonx.Unmarshal(bts, &ret)
- if err != nil{
- return nil, errJsonConvert
- }
-
- return &ret, nil
-}
-`
-)
-
-func (g *defaultRpcGenerator) genCall() error {
- file := g.ast
- if len(file.Service) == 0 {
- return nil
- }
- if len(file.Service) > 1 {
- return fmt.Errorf("we recommend only one service in a proto, currently %d", len(file.Service))
- }
-
- typeCode, err := file.GenTypesCode()
- if err != nil {
- return err
- }
-
- constLit, err := file.GenEnumCode()
- if err != nil {
- return err
- }
-
- service := file.Service[0]
- callPath := filepath.Join(g.dirM[dirTarget], service.Name.Lower())
- if err = util.MkdirIfNotExist(callPath); err != nil {
- return err
- }
-
- filename := filepath.Join(callPath, typesFilename)
- head := util.GetHead(g.Ctx.ProtoSource)
- text, err := util.LoadTemplate(category, callTypesTemplateFile, callTemplateTypes)
- if err != nil {
- return err
- }
- err = util.With("types").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
- "head": head,
- "const": constLit,
- "filePackage": service.Name.Lower(),
- "serviceName": g.Ctx.ServiceName.Title(),
- "lowerStartServiceName": g.Ctx.ServiceName.UnTitle(),
- "types": typeCode,
- }, filename, true)
- if err != nil {
- return err
- }
-
- filename = filepath.Join(callPath, fmt.Sprintf("%s.go", service.Name.Lower()))
- functions, importList, err := g.genFunction(service)
- if err != nil {
- return err
- }
-
- iFunctions, err := g.getInterfaceFuncs(service)
- if err != nil {
- return err
- }
- text, err = util.LoadTemplate(category, callTemplateFile, callTemplateText)
- if err != nil {
- return err
- }
- err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
- "name": service.Name.Lower(),
- "head": head,
- "filePackage": service.Name.Lower(),
- "package": strings.Join(importList, util.NL),
- "serviceName": service.Name.Title(),
- "functions": strings.Join(functions, util.NL),
- "interface": strings.Join(iFunctions, util.NL),
- }, filename, true)
- return err
-}
-
-func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string, []string, error) {
- file := g.ast
- pkgName := file.Package
- functions := make([]string, 0)
- imports := collection.NewSet()
- imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb)))
- for _, method := range service.Funcs {
- imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
- text, err := util.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate)
- if err != nil {
- return nil, nil, err
- }
- buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
- "rpcServiceName": service.Name.Title(),
- "method": method.Name.Title(),
- "package": pkgName,
- "pbRequestName": method.ParameterIn.Name,
- "pbRequest": method.ParameterIn.Expression,
- "pbResponse": method.ParameterOut.Name,
- "hasComment": method.HaveDoc(),
- "comment": method.GetDoc(),
- })
- if err != nil {
- return nil, nil, err
- }
-
- functions = append(functions, buffer.String())
- }
- return functions, imports.KeysStr(), nil
-}
-
-func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) {
- functions := make([]string, 0)
-
- for _, method := range service.Funcs {
- text, err := util.LoadTemplate(category, callInterfaceFunctionTemplateFile, callInterfaceFunctionTemplate)
- if err != nil {
- return nil, err
- }
-
- buffer, err := util.With("interfaceFn").Parse(text).Execute(
- map[string]interface{}{
- "hasComment": method.HaveDoc(),
- "comment": method.GetDoc(),
- "method": method.Name.Title(),
- "pbRequest": method.ParameterIn.Name,
- "pbResponse": method.ParameterOut.Name,
- })
- if err != nil {
- return nil, err
- }
-
- functions = append(functions, buffer.String())
- }
-
- return functions, nil
-}
diff --git a/tools/goctl/rpc/gen/gendir.go b/tools/goctl/rpc/gen/gendir.go
deleted file mode 100644
index bb0bac22..00000000
--- a/tools/goctl/rpc/gen/gendir.go
+++ /dev/null
@@ -1,54 +0,0 @@
-package gen
-
-import (
- "path/filepath"
- "runtime"
- "strings"
-
- "github.com/tal-tech/go-zero/tools/goctl/util"
- "github.com/tal-tech/go-zero/tools/goctl/vars"
-)
-
-// target
-// ├── etc
-// ├── internal
-// │ ├── config
-// │ ├── handler
-// │ ├── logic
-// │ ├── pb
-// │ └── svc
-func (g *defaultRpcGenerator) createDir() error {
- ctx := g.Ctx
- m := make(map[string]string)
- m[dirTarget] = ctx.TargetDir
- m[dirEtc] = filepath.Join(ctx.TargetDir, dirEtc)
- m[dirInternal] = filepath.Join(ctx.TargetDir, dirInternal)
- m[dirConfig] = filepath.Join(ctx.TargetDir, dirInternal, dirConfig)
- m[dirServer] = filepath.Join(ctx.TargetDir, dirInternal, dirServer)
- m[dirLogic] = filepath.Join(ctx.TargetDir, dirInternal, dirLogic)
- m[dirPb] = filepath.Join(ctx.TargetDir, dirPb)
- m[dirSvc] = filepath.Join(ctx.TargetDir, dirInternal, dirSvc)
- for _, d := range m {
- err := util.MkdirIfNotExist(d)
- if err != nil {
- return err
- }
- }
- g.dirM = m
- return nil
-}
-
-func (g *defaultRpcGenerator) mustGetPackage(dir string) string {
- target := g.dirM[dir]
- projectPath := g.Ctx.ProjectPath
- relativePath := strings.TrimPrefix(target, projectPath)
- os := runtime.GOOS
- switch os {
- case vars.OsWindows:
- relativePath = filepath.ToSlash(relativePath)
- case vars.OsMac, vars.OsLinux:
- default:
- g.Ctx.Fatalln("unexpected os: %s", os)
- }
- return g.Ctx.Module + relativePath
-}
diff --git a/tools/goctl/rpc/gen/genlogic.go b/tools/goctl/rpc/gen/genlogic.go
deleted file mode 100644
index 969205ce..00000000
--- a/tools/goctl/rpc/gen/genlogic.go
+++ /dev/null
@@ -1,109 +0,0 @@
-package gen
-
-import (
- "fmt"
- "path/filepath"
- "strings"
-
- "github.com/tal-tech/go-zero/core/collection"
- "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
- "github.com/tal-tech/go-zero/tools/goctl/util"
-)
-
-const (
- logicTemplate = `package logic
-
-import (
- "context"
-
- {{.imports}}
-
- "github.com/tal-tech/go-zero/core/logx"
-)
-
-type {{.logicName}} struct {
- ctx context.Context
- svcCtx *svc.ServiceContext
- logx.Logger
-}
-
-func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
- return &{{.logicName}}{
- ctx: ctx,
- svcCtx: svcCtx,
- Logger: logx.WithContext(ctx),
- }
-}
-{{.functions}}
-`
- logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
-func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
- // todo: add your logic here and delete this line
-
- return &{{.responseType}}{}, nil
-}
-`
-)
-
-func (g *defaultRpcGenerator) genLogic() error {
- logicPath := g.dirM[dirLogic]
- protoPkg := g.ast.Package
- service := g.ast.Service
- for _, item := range service {
- for _, method := range item.Funcs {
- logicName := fmt.Sprintf("%slogic.go", method.Name.Lower())
- filename := filepath.Join(logicPath, logicName)
- functions, importList, err := g.genLogicFunction(protoPkg, method)
- if err != nil {
- return err
- }
- imports := collection.NewSet()
- svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
- imports.AddStr(svcImport)
- imports.AddStr(importList...)
- text, err := util.LoadTemplate(category, logicTemplateFileFile, logicTemplate)
- if err != nil {
- return err
- }
- err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
- "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
- "functions": functions,
- "imports": strings.Join(imports.KeysStr(), util.NL),
- }, filename, false)
- if err != nil {
- return err
- }
- }
- }
- return nil
-}
-
-func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parser.Func) (string, []string, error) {
- var functions = make([]string, 0)
- var imports = collection.NewSet()
- if method.ParameterIn.Package == packageName || method.ParameterOut.Package == packageName {
- imports.AddStr(fmt.Sprintf(`%v "%v"`, packageName, g.mustGetPackage(dirPb)))
- }
- imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
- imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
- text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
- if err != nil {
- return "", nil, err
- }
-
- buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
- "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
- "method": method.Name.Title(),
- "request": method.ParameterIn.StarExpression,
- "response": method.ParameterOut.StarExpression,
- "responseType": method.ParameterOut.Expression,
- "hasComment": method.HaveDoc(),
- "comment": method.GetDoc(),
- })
- if err != nil {
- return "", nil, err
- }
-
- functions = append(functions, buffer.String())
- return strings.Join(functions, util.NL), imports.KeysStr(), nil
-}
diff --git a/tools/goctl/rpc/gen/genmain.go b/tools/goctl/rpc/gen/genmain.go
deleted file mode 100644
index 653e4da5..00000000
--- a/tools/goctl/rpc/gen/genmain.go
+++ /dev/null
@@ -1,85 +0,0 @@
-package gen
-
-import (
- "fmt"
- "path/filepath"
- "strings"
-
- "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
- "github.com/tal-tech/go-zero/tools/goctl/util"
-)
-
-const mainTemplate = `{{.head}}
-
-package main
-
-import (
- "flag"
- "fmt"
-
- {{.imports}}
-
- "github.com/tal-tech/go-zero/core/conf"
- "github.com/tal-tech/go-zero/zrpc"
- "google.golang.org/grpc"
-)
-
-var configFile = flag.String("f", "etc/{{.serviceName}}.yaml", "the config file")
-
-func main() {
- flag.Parse()
-
- var c config.Config
- conf.MustLoad(*configFile, &c)
- ctx := svc.NewServiceContext(c)
- {{.srv}}
-
- s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
- {{.registers}}
- })
- defer s.Stop()
-
- fmt.Printf("Starting rpc server at %s...\n", c.ListenOn)
- s.Start()
-}
-`
-
-func (g *defaultRpcGenerator) genMain() error {
- mainPath := g.dirM[dirTarget]
- file := g.ast
- pkg := file.Package
-
- fileName := filepath.Join(mainPath, fmt.Sprintf("%v.go", g.Ctx.ServiceName.Lower()))
- imports := make([]string, 0)
- pbImport := fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb))
- svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
- remoteImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirServer))
- configImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig))
- imports = append(imports, configImport, pbImport, remoteImport, svcImport)
- srv, registers := g.genServer(pkg, file.Service)
- head := util.GetHead(g.Ctx.ProtoSource)
- text, err := util.LoadTemplate(category, mainTemplateFile, mainTemplate)
- if err != nil {
- return err
- }
-
- return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
- "head": head,
- "package": pkg,
- "serviceName": g.Ctx.ServiceName.Lower(),
- "srv": srv,
- "registers": registers,
- "imports": strings.Join(imports, util.NL),
- }, fileName, false)
-}
-
-func (g *defaultRpcGenerator) genServer(pkg string, list []*parser.RpcService) (string, string) {
- list1 := make([]string, 0)
- list2 := make([]string, 0)
- for _, item := range list {
- name := item.Name.UnTitle()
- list1 = append(list1, fmt.Sprintf("%sSrv := server.New%sServer(ctx)", name, item.Name.Title()))
- list2 = append(list2, fmt.Sprintf("%s.Register%sServer(grpcServer, %sSrv)", pkg, item.Name.Title(), name))
- }
- return strings.Join(list1, util.NL), strings.Join(list2, util.NL)
-}
diff --git a/tools/goctl/rpc/gen/genpb.go b/tools/goctl/rpc/gen/genpb.go
deleted file mode 100644
index be772e34..00000000
--- a/tools/goctl/rpc/gen/genpb.go
+++ /dev/null
@@ -1,82 +0,0 @@
-package gen
-
-import (
- "bytes"
- "fmt"
- "path/filepath"
- "strings"
-
- "github.com/tal-tech/go-zero/core/collection"
- "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
- "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
-)
-
-const (
- protocCmd = "protoc"
- grpcPluginCmd = "--go_out=plugins=grpc"
-)
-
-func (g *defaultRpcGenerator) genPb() error {
- pbPath := g.dirM[dirPb]
- // deprecated: containsAny will be removed in the feature
- imports, containsAny, err := parser.ParseImport(g.Ctx.ProtoFileSrc)
- if err != nil {
- return err
- }
-
- err = g.protocGenGo(pbPath, imports)
- if err != nil {
- return err
- }
- ast, err := parser.Transfer(g.Ctx.ProtoFileSrc, pbPath, imports, g.Ctx.Console)
- if err != nil {
- return err
- }
- ast.ContainsAny = containsAny
-
- if len(ast.Service) == 0 {
- return fmt.Errorf("service not found")
- }
- g.ast = ast
- return nil
-}
-
-func (g *defaultRpcGenerator) protocGenGo(target string, imports []*parser.Import) error {
- dir := filepath.Dir(g.Ctx.ProtoFileSrc)
- // cmd join,see the document of proto generating class @https://developers.google.com/protocol-buffers/docs/proto3#generating
- // template: protoc -I=${import_path} -I=${other_import_path} -I=${...} --go_out=plugins=grpc,M${pb_package_kv}, M${...} :${target_dir}
- // eg: protoc -I=${GOPATH}/src -I=. example.proto --go_out=plugins=grpc,Mbase/base.proto=github.com/go-zero/base.proto:.
- // note: the external import out of the project which are found in ${GOPATH}/src so far.
-
- buffer := new(bytes.Buffer)
- buffer.WriteString(protocCmd + " ")
- targetImportFiltered := collection.NewSet()
-
- for _, item := range imports {
- buffer.WriteString(fmt.Sprintf("-I=%s ", item.OriginalDir))
- if len(item.BridgeImport) == 0 {
- continue
- }
- targetImportFiltered.AddStr(item.BridgeImport)
-
- }
- buffer.WriteString("-I=${GOPATH}/src ")
- buffer.WriteString(fmt.Sprintf("-I=%s %s ", dir, g.Ctx.ProtoFileSrc))
-
- buffer.WriteString(grpcPluginCmd)
- if targetImportFiltered.Count() > 0 {
- buffer.WriteString(fmt.Sprintf(",%v", strings.Join(targetImportFiltered.KeysStr(), ",")))
- }
- buffer.WriteString(":" + target)
- g.Ctx.Debug("-> " + buffer.String())
- stdout, err := execx.Run(buffer.String(), "")
- if err != nil {
- return err
- }
-
- if len(stdout) > 0 {
- g.Ctx.Info(stdout)
- }
-
- return nil
-}
diff --git a/tools/goctl/rpc/gen/genserver.go b/tools/goctl/rpc/gen/genserver.go
deleted file mode 100644
index a9ef3418..00000000
--- a/tools/goctl/rpc/gen/genserver.go
+++ /dev/null
@@ -1,116 +0,0 @@
-package gen
-
-import (
- "fmt"
- "path/filepath"
- "strings"
-
- "github.com/tal-tech/go-zero/core/collection"
- "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
- "github.com/tal-tech/go-zero/tools/goctl/util"
-)
-
-const (
- serverTemplate = `{{.head}}
-
-package server
-
-import (
- "context"
-
- {{.imports}}
-)
-
-type {{.types}}
-
-func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
- return &{{.server}}Server{
- svcCtx: svcCtx,
- }
-}
-
-{{.funcs}}
-`
- functionTemplate = `
-{{if .hasComment}}{{.comment}}{{end}}
-func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
- l := logic.New{{.logicName}}(ctx,s.svcCtx)
- return l.{{.method}}(in)
-}
-`
- typeFmt = `%sServer struct {
- svcCtx *svc.ServiceContext
- }`
-)
-
-func (g *defaultRpcGenerator) genHandler() error {
- serverPath := g.dirM[dirServer]
- file := g.ast
- logicImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic))
- svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
- imports := collection.NewSet()
- imports.AddStr(logicImport, svcImport)
-
- head := util.GetHead(g.Ctx.ProtoSource)
- for _, service := range file.Service {
- filename := fmt.Sprintf("%vserver.go", service.Name.Lower())
- serverFile := filepath.Join(serverPath, filename)
- funcList, importList, err := g.genFunctions(service)
- if err != nil {
- return err
- }
-
- imports.AddStr(importList...)
- text, err := util.LoadTemplate(category, serverTemplateFile, serverTemplate)
- if err != nil {
- return err
- }
-
- err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
- "head": head,
- "types": fmt.Sprintf(typeFmt, service.Name.Title()),
- "server": service.Name.Title(),
- "imports": strings.Join(imports.KeysStr(), util.NL),
- "funcs": strings.Join(funcList, util.NL),
- }, serverFile, true)
- if err != nil {
- return err
- }
- }
- return nil
-}
-
-func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string, []string, error) {
- file := g.ast
- pkg := file.Package
- var functionList []string
- imports := collection.NewSet()
- for _, method := range service.Funcs {
- if method.ParameterIn.Package == pkg || method.ParameterOut.Package == pkg {
- imports.AddStr(fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb)))
- }
- imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
- imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
- text, err := util.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)
- if err != nil {
- return nil, nil, err
- }
-
- buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
- "server": service.Name.Title(),
- "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
- "method": method.Name.Title(),
- "package": pkg,
- "request": method.ParameterIn.StarExpression,
- "response": method.ParameterOut.StarExpression,
- "hasComment": method.HaveDoc(),
- "comment": method.GetDoc(),
- })
- if err != nil {
- return nil, nil, err
- }
-
- functionList = append(functionList, buffer.String())
- }
- return functionList, imports.KeysStr(), nil
-}
diff --git a/tools/goctl/rpc/gen/gomod.go b/tools/goctl/rpc/gen/gomod.go
deleted file mode 100644
index 070b4d9e..00000000
--- a/tools/goctl/rpc/gen/gomod.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package gen
-
-import (
- "fmt"
-
- "github.com/tal-tech/go-zero/core/logx"
- "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
-)
-
-func (g *defaultRpcGenerator) initGoMod() error {
- if !g.Ctx.IsInGoEnv {
- projectDir := g.dirM[dirTarget]
- cmd := fmt.Sprintf("go mod init %s", g.Ctx.ProjectName.Source())
- output, err := execx.Run(fmt.Sprintf(cmd), projectDir)
- if err != nil {
- logx.Error(err)
- return err
- }
- g.Ctx.Info(output)
- }
- return nil
-}
diff --git a/tools/goctl/rpc/gen_test.go b/tools/goctl/rpc/gen_test.go
deleted file mode 100644
index e32e99b2..00000000
--- a/tools/goctl/rpc/gen_test.go
+++ /dev/null
@@ -1,35 +0,0 @@
-package base
-
-import (
- "path/filepath"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
- "github.com/tal-tech/go-zero/tools/goctl/util/console"
-)
-
-func TestParseImport(t *testing.T) {
- src, _ := filepath.Abs("./test.proto")
- base, _ := filepath.Abs("./base.proto")
- imports, containsAny, err := parser.ParseImport(src)
- assert.Nil(t, err)
- assert.Equal(t, true, containsAny)
- assert.Equal(t, 1, len(imports))
- assert.Equal(t, "github.com/tal-tech/go-zero/tools/goctl/rpc", imports[0].PbImportName)
- assert.Equal(t, base, imports[0].OriginalProtoPath)
-}
-
-func TestTransfer(t *testing.T) {
- src, _ := filepath.Abs("./test.proto")
- abs, _ := filepath.Abs("./test")
- imports, _, _ := parser.ParseImport(src)
- proto, err := parser.Transfer(src, abs, imports, console.NewConsole(false))
- assert.Nil(t, err)
- assert.Equal(t, 1, len(proto.Service))
- assert.Equal(t, "Greeter", proto.Service[0].Name.Source())
- assert.Equal(t, 5, len(proto.Structure))
- data, ok := proto.Structure["map"]
- assert.Equal(t, true, ok)
- assert.Equal(t, "M", data.Field[0].Name.Source())
-}
diff --git a/tools/goctl/rpc/generator/base/common.pb.go b/tools/goctl/rpc/generator/base/common.pb.go
new file mode 100644
index 00000000..455529fe
--- /dev/null
+++ b/tools/goctl/rpc/generator/base/common.pb.go
@@ -0,0 +1,75 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: common.proto
+
+package common
+
+import (
+ fmt "fmt"
+ proto "github.com/golang/protobuf/proto"
+ math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
+
+type User struct {
+ Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *User) Reset() { *m = User{} }
+func (m *User) String() string { return proto.CompactTextString(m) }
+func (*User) ProtoMessage() {}
+func (*User) Descriptor() ([]byte, []int) {
+ return fileDescriptor_555bd8c177793206, []int{0}
+}
+
+func (m *User) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_User.Unmarshal(m, b)
+}
+func (m *User) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_User.Marshal(b, m, deterministic)
+}
+func (m *User) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_User.Merge(m, src)
+}
+func (m *User) XXX_Size() int {
+ return xxx_messageInfo_User.Size(m)
+}
+func (m *User) XXX_DiscardUnknown() {
+ xxx_messageInfo_User.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_User proto.InternalMessageInfo
+
+func (m *User) GetName() string {
+ if m != nil {
+ return m.Name
+ }
+ return ""
+}
+
+func init() {
+ proto.RegisterType((*User)(nil), "common.User")
+}
+
+func init() { proto.RegisterFile("common.proto", fileDescriptor_555bd8c177793206) }
+
+var fileDescriptor_555bd8c177793206 = []byte{
+ // 72 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x49, 0xce, 0xcf, 0xcd,
+ 0xcd, 0xcf, 0xd3, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x83, 0xf0, 0x94, 0xa4, 0xb8, 0x58,
+ 0x42, 0x8b, 0x53, 0x8b, 0x84, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53, 0x25, 0x18, 0x15, 0x18,
+ 0x35, 0x38, 0x83, 0xc0, 0xec, 0x24, 0x36, 0xb0, 0x52, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff,
+ 0x2c, 0x6d, 0x58, 0x59, 0x3a, 0x00, 0x00, 0x00,
+}
diff --git a/tools/goctl/rpc/generator/base/common.proto b/tools/goctl/rpc/generator/base/common.proto
new file mode 100644
index 00000000..42c9107e
--- /dev/null
+++ b/tools/goctl/rpc/generator/base/common.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package common;
+
+message User {
+ string name = 1;
+}
diff --git a/tools/goctl/rpc/generator/defaultgenerator.go b/tools/goctl/rpc/generator/defaultgenerator.go
new file mode 100644
index 00000000..eb5853d1
--- /dev/null
+++ b/tools/goctl/rpc/generator/defaultgenerator.go
@@ -0,0 +1,33 @@
+package generator
+
+import (
+ "os/exec"
+
+ "github.com/tal-tech/go-zero/tools/goctl/util/console"
+)
+
+type defaultGenerator struct {
+ log console.Console
+}
+
+func NewDefaultGenerator() *defaultGenerator {
+ log := console.NewColorConsole()
+ return &defaultGenerator{
+ log: log,
+ }
+}
+
+func (g *defaultGenerator) Prepare() error {
+ _, err := exec.LookPath("go")
+ if err != nil {
+ return err
+ }
+
+ _, err = exec.LookPath("protoc")
+ if err != nil {
+ return err
+ }
+
+ _, err = exec.LookPath("protoc-gen-go")
+ return err
+}
diff --git a/tools/goctl/rpc/generator/filename.go b/tools/goctl/rpc/generator/filename.go
new file mode 100644
index 00000000..89617ded
--- /dev/null
+++ b/tools/goctl/rpc/generator/filename.go
@@ -0,0 +1,11 @@
+package generator
+
+import (
+ "strings"
+
+ "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
+)
+
+func formatFilename(filename string) string {
+ return strings.ToLower(stringx.From(filename).ToCamel())
+}
diff --git a/tools/goctl/rpc/generator/gen.go b/tools/goctl/rpc/generator/gen.go
new file mode 100644
index 00000000..d09c6394
--- /dev/null
+++ b/tools/goctl/rpc/generator/gen.go
@@ -0,0 +1,98 @@
+package generator
+
+import (
+ "path/filepath"
+
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/console"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
+)
+
+type RpcGenerator struct {
+ g Generator
+}
+
+func NewDefaultRpcGenerator() *RpcGenerator {
+ return NewRpcGenerator(NewDefaultGenerator())
+}
+
+func NewRpcGenerator(g Generator) *RpcGenerator {
+ return &RpcGenerator{
+ g: g,
+ }
+}
+
+func (g *RpcGenerator) Generate(src, target string, protoImportPath []string) error {
+ abs, err := filepath.Abs(target)
+ if err != nil {
+ return err
+ }
+
+ err = util.MkdirIfNotExist(abs)
+ if err != nil {
+ return err
+ }
+
+ err = g.g.Prepare()
+ if err != nil {
+ return err
+ }
+
+ projectCtx, err := ctx.Prepare(abs)
+ if err != nil {
+ return err
+ }
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse(src)
+ if err != nil {
+ return err
+ }
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ if err != nil {
+ return err
+ }
+
+ err = g.g.GenEtc(dirCtx, proto)
+ if err != nil {
+ return err
+ }
+
+ err = g.g.GenPb(dirCtx, protoImportPath, proto)
+ if err != nil {
+ return err
+ }
+
+ err = g.g.GenConfig(dirCtx, proto)
+ if err != nil {
+ return err
+ }
+
+ err = g.g.GenSvc(dirCtx, proto)
+ if err != nil {
+ return err
+ }
+
+ err = g.g.GenLogic(dirCtx, proto)
+ if err != nil {
+ return err
+ }
+
+ err = g.g.GenServer(dirCtx, proto)
+ if err != nil {
+ return err
+ }
+
+ err = g.g.GenMain(dirCtx, proto)
+ if err != nil {
+ return err
+ }
+
+ err = g.g.GenCall(dirCtx, proto)
+
+ console.NewColorConsole().MarkDone()
+
+ return err
+}
diff --git a/tools/goctl/rpc/generator/gen_test.go b/tools/goctl/rpc/generator/gen_test.go
new file mode 100644
index 00000000..25c2126d
--- /dev/null
+++ b/tools/goctl/rpc/generator/gen_test.go
@@ -0,0 +1,104 @@
+package generator
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
+)
+
+func TestRpcGenerateCaseNilImport(t *testing.T) {
+ dispatcher := NewDefaultGenerator()
+ if err := dispatcher.Prepare(); err == nil {
+ g := NewRpcGenerator(dispatcher)
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ err = g.Generate("./test_stream.proto", abs, nil)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+ assert.Nil(t, err)
+
+ _, err = execx.Run("go test "+abs, abs)
+ assert.Nil(t, err)
+ }
+}
+
+func TestRpcGenerateCaseOption(t *testing.T) {
+ dispatcher := NewDefaultGenerator()
+ if err := dispatcher.Prepare(); err == nil {
+ g := NewRpcGenerator(dispatcher)
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ err = g.Generate("./test_option.proto", abs, nil)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+ assert.Nil(t, err)
+
+ _, err = execx.Run("go test "+abs, abs)
+ assert.Nil(t, err)
+ }
+}
+
+func TestRpcGenerateCaseWordOption(t *testing.T) {
+ dispatcher := NewDefaultGenerator()
+ if err := dispatcher.Prepare(); err == nil {
+ g := NewRpcGenerator(dispatcher)
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ err = g.Generate("./test_word_option.proto", abs, nil)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+ assert.Nil(t, err)
+
+ _, err = execx.Run("go test "+abs, abs)
+ assert.Nil(t, err)
+ }
+}
+
+// test keyword go
+func TestRpcGenerateCaseGoOption(t *testing.T) {
+ dispatcher := NewDefaultGenerator()
+ if err := dispatcher.Prepare(); err == nil {
+ g := NewRpcGenerator(dispatcher)
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ err = g.Generate("./test_go_option.proto", abs, nil)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+ assert.Nil(t, err)
+
+ _, err = execx.Run("go test "+abs, abs)
+ assert.Nil(t, err)
+ }
+}
+
+func TestRpcGenerateCaseImport(t *testing.T) {
+ dispatcher := NewDefaultGenerator()
+ if err := dispatcher.Prepare(); err == nil {
+ g := NewRpcGenerator(dispatcher)
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ err = g.Generate("./test_import.proto", abs, []string{"./base"})
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+ assert.Nil(t, err)
+
+ _, err = execx.Run("go test "+abs, abs)
+ assert.True(t, func() bool {
+ return strings.Contains(err.Error(), "package base is not in GOROOT")
+ }())
+ }
+}
diff --git a/tools/goctl/rpc/generator/gencall.go b/tools/goctl/rpc/generator/gencall.go
new file mode 100644
index 00000000..4afc1f04
--- /dev/null
+++ b/tools/goctl/rpc/generator/gencall.go
@@ -0,0 +1,154 @@
+package generator
+
+import (
+ "fmt"
+ "path/filepath"
+ "strings"
+
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
+)
+
+const (
+ callTemplateText = `{{.head}}
+
+//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
+
+package {{.filePackage}}
+
+import (
+ "context"
+
+ {{.package}}
+
+ "github.com/tal-tech/go-zero/zrpc"
+)
+
+type (
+ {{.alias}}
+
+ {{.serviceName}} interface {
+ {{.interface}}
+ }
+
+ default{{.serviceName}} struct {
+ cli zrpc.Client
+ }
+)
+
+func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
+ return &default{{.serviceName}}{
+ cli: cli,
+ }
+}
+
+{{.functions}}
+`
+
+ callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
+{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
+
+ callFunctionTemplate = `
+{{if .hasComment}}{{.comment}}{{end}}
+func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) {
+ client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
+ return client.{{.method}}(ctx, in)
+}
+`
+)
+
+func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error {
+ dir := ctx.GetCall()
+ service := proto.Service
+ head := util.GetHead(proto.Name)
+
+ filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", formatFilename(service.Name)))
+ functions, err := g.genFunction(proto.PbPackage, service)
+ if err != nil {
+ return err
+ }
+
+ iFunctions, err := g.getInterfaceFuncs(service)
+ if err != nil {
+ return err
+ }
+
+ text, err := util.LoadTemplate(category, callTemplateFile, callTemplateText)
+ if err != nil {
+ return err
+ }
+
+ var alias []string
+ for _, item := range service.RPC {
+ alias = append(alias, fmt.Sprintf("%s = %s", parser.CamelCase(item.RequestType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.RequestType))))
+ alias = append(alias, fmt.Sprintf("%s = %s", parser.CamelCase(item.ReturnsType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.ReturnsType))))
+ }
+
+ err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
+ "name": formatFilename(service.Name),
+ "alias": strings.Join(alias, util.NL),
+ "head": head,
+ "filePackage": formatFilename(service.Name),
+ "package": fmt.Sprintf(`"%s"`, ctx.GetPb().Package),
+ "serviceName": parser.CamelCase(service.Name),
+ "functions": strings.Join(functions, util.NL),
+ "interface": strings.Join(iFunctions, util.NL),
+ }, filename, true)
+ return err
+}
+
+func (g *defaultGenerator) genFunction(goPackage string, service parser.Service) ([]string, error) {
+ functions := make([]string, 0)
+ for _, rpc := range service.RPC {
+ text, err := util.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate)
+ if err != nil {
+ return nil, err
+ }
+
+ comment := parser.GetComment(rpc.Doc())
+ buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
+ "rpcServiceName": stringx.From(service.Name).Title(),
+ "method": stringx.From(rpc.Name).Title(),
+ "package": goPackage,
+ "pbRequest": parser.CamelCase(rpc.RequestType),
+ "pbResponse": parser.CamelCase(rpc.ReturnsType),
+ "hasComment": len(comment) > 0,
+ "comment": comment,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ functions = append(functions, buffer.String())
+ }
+ return functions, nil
+}
+
+func (g *defaultGenerator) getInterfaceFuncs(service parser.Service) ([]string, error) {
+ functions := make([]string, 0)
+
+ for _, rpc := range service.RPC {
+ text, err := util.LoadTemplate(category, callInterfaceFunctionTemplateFile, callInterfaceFunctionTemplate)
+ if err != nil {
+ return nil, err
+ }
+
+ comment := parser.GetComment(rpc.Doc())
+ buffer, err := util.With("interfaceFn").Parse(text).Execute(
+ map[string]interface{}{
+ "hasComment": len(comment) > 0,
+ "comment": comment,
+ "method": stringx.From(rpc.Name).Title(),
+ "pbRequest": parser.CamelCase(rpc.RequestType),
+ "pbResponse": parser.CamelCase(rpc.ReturnsType),
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ functions = append(functions, buffer.String())
+ }
+
+ return functions, nil
+}
diff --git a/tools/goctl/rpc/generator/gencall_test.go b/tools/goctl/rpc/generator/gencall_test.go
new file mode 100644
index 00000000..6ae0cefe
--- /dev/null
+++ b/tools/goctl/rpc/generator/gencall_test.go
@@ -0,0 +1,44 @@
+package generator
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
+)
+
+func TestGenerateCall(t *testing.T) {
+ _ = Clean()
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_stream.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ err = g.Prepare()
+ if err != nil {
+ return
+ }
+ err = g.GenCall(dirCtx, proto)
+ assert.Nil(t, err)
+}
diff --git a/tools/goctl/rpc/gen/genconfig.go b/tools/goctl/rpc/generator/genconfig.go
similarity index 64%
rename from tools/goctl/rpc/gen/genconfig.go
rename to tools/goctl/rpc/generator/genconfig.go
index 91daf5c9..c9cd4bc4 100644
--- a/tools/goctl/rpc/gen/genconfig.go
+++ b/tools/goctl/rpc/generator/genconfig.go
@@ -1,10 +1,11 @@
-package gen
+package generator
import (
"io/ioutil"
"os"
"path/filepath"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
@@ -17,9 +18,9 @@ type Config struct {
}
`
-func (g *defaultRpcGenerator) genConfig() error {
- configPath := g.dirM[dirConfig]
- fileName := filepath.Join(configPath, fileConfig)
+func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto) error {
+ dir := ctx.GetConfig()
+ fileName := filepath.Join(dir.Filename, formatFilename("config")+".go")
if util.FileExists(fileName) {
return nil
}
diff --git a/tools/goctl/rpc/generator/genconfig_test.go b/tools/goctl/rpc/generator/genconfig_test.go
new file mode 100644
index 00000000..39006b10
--- /dev/null
+++ b/tools/goctl/rpc/generator/genconfig_test.go
@@ -0,0 +1,48 @@
+package generator
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
+)
+
+func TestGenerateConfig(t *testing.T) {
+ _ = Clean()
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_stream.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ err = g.Prepare()
+ if err != nil {
+ return
+ }
+ err = g.GenConfig(dirCtx, proto)
+ assert.Nil(t, err)
+
+ // test file exists
+ err = g.GenConfig(dirCtx, proto)
+ assert.Nil(t, err)
+}
diff --git a/tools/goctl/rpc/generator/generator.go b/tools/goctl/rpc/generator/generator.go
new file mode 100644
index 00000000..a46ed42c
--- /dev/null
+++ b/tools/goctl/rpc/generator/generator.go
@@ -0,0 +1,15 @@
+package generator
+
+import "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+
+type Generator interface {
+ Prepare() error
+ GenMain(ctx DirContext, proto parser.Proto) error
+ GenCall(ctx DirContext, proto parser.Proto) error
+ GenEtc(ctx DirContext, proto parser.Proto) error
+ GenConfig(ctx DirContext, proto parser.Proto) error
+ GenLogic(ctx DirContext, proto parser.Proto) error
+ GenServer(ctx DirContext, proto parser.Proto) error
+ GenSvc(ctx DirContext, proto parser.Proto) error
+ GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error
+}
diff --git a/tools/goctl/rpc/gen/genetc.go b/tools/goctl/rpc/generator/genetc.go
similarity index 55%
rename from tools/goctl/rpc/gen/genetc.go
rename to tools/goctl/rpc/generator/genetc.go
index c2e64b53..ddd36fba 100644
--- a/tools/goctl/rpc/gen/genetc.go
+++ b/tools/goctl/rpc/generator/genetc.go
@@ -1,9 +1,10 @@
-package gen
+package generator
import (
"fmt"
"path/filepath"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
@@ -15,12 +16,10 @@ Etcd:
Key: {{.serviceName}}.rpc
`
-func (g *defaultRpcGenerator) genEtc() error {
- etdDir := g.dirM[dirEtc]
- fileName := filepath.Join(etdDir, fmt.Sprintf("%v.yaml", g.Ctx.ServiceName.Lower()))
- if util.FileExists(fileName) {
- return nil
- }
+func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto) error {
+ dir := ctx.GetEtc()
+ serviceNameLower := formatFilename(ctx.GetMain().Base)
+ fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.yaml", serviceNameLower))
text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate)
if err != nil {
@@ -28,6 +27,6 @@ func (g *defaultRpcGenerator) genEtc() error {
}
return util.With("etc").Parse(text).SaveTo(map[string]interface{}{
- "serviceName": g.Ctx.ServiceName.Lower(),
+ "serviceName": serviceNameLower,
}, fileName, false)
}
diff --git a/tools/goctl/rpc/generator/genetc_test.go b/tools/goctl/rpc/generator/genetc_test.go
new file mode 100644
index 00000000..457cfed4
--- /dev/null
+++ b/tools/goctl/rpc/generator/genetc_test.go
@@ -0,0 +1,45 @@
+package generator
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
+)
+
+func TestGenerateEtc(t *testing.T) {
+ _ = Clean()
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_stream.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ err = g.Prepare()
+ if err != nil {
+ return
+ }
+
+ err = g.GenEtc(dirCtx, proto)
+ assert.Nil(t, err)
+}
diff --git a/tools/goctl/rpc/generator/genlogic.go b/tools/goctl/rpc/generator/genlogic.go
new file mode 100644
index 00000000..a48c9092
--- /dev/null
+++ b/tools/goctl/rpc/generator/genlogic.go
@@ -0,0 +1,101 @@
+package generator
+
+import (
+ "fmt"
+ "path/filepath"
+ "strings"
+
+ "github.com/tal-tech/go-zero/core/collection"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
+)
+
+const (
+ logicTemplate = `package logic
+
+import (
+ "context"
+
+ {{.imports}}
+
+ "github.com/tal-tech/go-zero/core/logx"
+)
+
+type {{.logicName}} struct {
+ ctx context.Context
+ svcCtx *svc.ServiceContext
+ logx.Logger
+}
+
+func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
+ return &{{.logicName}}{
+ ctx: ctx,
+ svcCtx: svcCtx,
+ Logger: logx.WithContext(ctx),
+ }
+}
+{{.functions}}
+`
+ logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
+func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
+ // todo: add your logic here and delete this line
+
+ return &{{.responseType}}{}, nil
+}
+`
+)
+
+func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto) error {
+ dir := ctx.GetLogic()
+ for _, rpc := range proto.Service.RPC {
+ filename := filepath.Join(dir.Filename, formatFilename(rpc.Name+"_logic")+".go")
+ functions, err := g.genLogicFunction(proto.PbPackage, rpc)
+ if err != nil {
+ return err
+ }
+
+ imports := collection.NewSet()
+ imports.AddStr(fmt.Sprintf(`"%v"`, ctx.GetSvc().Package))
+ imports.AddStr(fmt.Sprintf(`"%v"`, ctx.GetPb().Package))
+ text, err := util.LoadTemplate(category, logicTemplateFileFile, logicTemplate)
+ if err != nil {
+ return err
+ }
+ err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
+ "logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).Title()),
+ "functions": functions,
+ "imports": strings.Join(imports.KeysStr(), util.NL),
+ }, filename, false)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (g *defaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) (string, error) {
+ var functions = make([]string, 0)
+ text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
+ if err != nil {
+ return "", err
+ }
+
+ logicName := stringx.From(rpc.Name + "_logic").ToCamel()
+ comment := parser.GetComment(rpc.Doc())
+ buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
+ "logicName": logicName,
+ "method": parser.CamelCase(rpc.Name),
+ "request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
+ "response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
+ "responseType": fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
+ "hasComment": len(comment) > 0,
+ "comment": comment,
+ })
+ if err != nil {
+ return "", err
+ }
+
+ functions = append(functions, buffer.String())
+ return strings.Join(functions, util.NL), nil
+}
diff --git a/tools/goctl/rpc/generator/genlogic_test.go b/tools/goctl/rpc/generator/genlogic_test.go
new file mode 100644
index 00000000..681c89d7
--- /dev/null
+++ b/tools/goctl/rpc/generator/genlogic_test.go
@@ -0,0 +1,44 @@
+package generator
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
+)
+
+func TestGenerateLogic(t *testing.T) {
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_stream.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ err = g.Prepare()
+ if err != nil {
+ return
+ }
+
+ err = g.GenLogic(dirCtx, proto)
+ assert.Nil(t, err)
+}
diff --git a/tools/goctl/rpc/generator/genmain.go b/tools/goctl/rpc/generator/genmain.go
new file mode 100644
index 00000000..eed6a7d2
--- /dev/null
+++ b/tools/goctl/rpc/generator/genmain.go
@@ -0,0 +1,70 @@
+package generator
+
+import (
+ "fmt"
+ "path/filepath"
+ "strings"
+
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+)
+
+const mainTemplate = `{{.head}}
+
+package main
+
+import (
+ "flag"
+ "fmt"
+
+ {{.imports}}
+
+ "github.com/tal-tech/go-zero/core/conf"
+ "github.com/tal-tech/go-zero/zrpc"
+ "google.golang.org/grpc"
+)
+
+var configFile = flag.String("f", "etc/{{.serviceName}}.yaml", "the config file")
+
+func main() {
+ flag.Parse()
+
+ var c config.Config
+ conf.MustLoad(*configFile, &c)
+ ctx := svc.NewServiceContext(c)
+ srv := server.New{{.service}}Server(ctx)
+
+ s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
+ {{.pkg}}.Register{{.service}}Server(grpcServer, srv)
+ })
+ defer s.Stop()
+
+ fmt.Printf("Starting rpc server at %s...\n", c.ListenOn)
+ s.Start()
+}
+`
+
+func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error {
+ dir := ctx.GetMain()
+ serviceNameLower := formatFilename(ctx.GetMain().Base)
+ fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.go", serviceNameLower))
+ imports := make([]string, 0)
+ pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
+ svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
+ remoteImport := fmt.Sprintf(`"%v"`, ctx.GetServer().Package)
+ configImport := fmt.Sprintf(`"%v"`, ctx.GetConfig().Package)
+ imports = append(imports, configImport, pbImport, remoteImport, svcImport)
+ head := util.GetHead(proto.Name)
+ text, err := util.LoadTemplate(category, mainTemplateFile, mainTemplate)
+ if err != nil {
+ return err
+ }
+
+ return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
+ "head": head,
+ "serviceName": serviceNameLower,
+ "imports": strings.Join(imports, util.NL),
+ "pkg": proto.PbPackage,
+ "service": parser.CamelCase(proto.Service.Name),
+ }, fileName, false)
+}
diff --git a/tools/goctl/rpc/generator/genmain_test.go b/tools/goctl/rpc/generator/genmain_test.go
new file mode 100644
index 00000000..aed65b02
--- /dev/null
+++ b/tools/goctl/rpc/generator/genmain_test.go
@@ -0,0 +1,45 @@
+package generator
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
+)
+
+func TestGenerateMain(t *testing.T) {
+ _ = Clean()
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_stream.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ err = g.Prepare()
+ if err != nil {
+ return
+ }
+
+ err = g.GenMain(dirCtx, proto)
+ assert.Nil(t, err)
+}
diff --git a/tools/goctl/rpc/generator/genpb.go b/tools/goctl/rpc/generator/genpb.go
new file mode 100644
index 00000000..4327ae52
--- /dev/null
+++ b/tools/goctl/rpc/generator/genpb.go
@@ -0,0 +1,31 @@
+package generator
+
+import (
+ "bytes"
+ "path/filepath"
+ "strings"
+
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+)
+
+func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error {
+ dir := ctx.GetPb()
+ cw := new(bytes.Buffer)
+ base := filepath.Dir(proto.Src)
+ cw.WriteString("protoc ")
+ for _, ip := range protoImportPath {
+ cw.WriteString(" -I=" + ip)
+ }
+ cw.WriteString(" -I=" + base)
+ cw.WriteString(" " + proto.Name)
+ if strings.Contains(proto.GoPackage, "/") {
+ cw.WriteString(" --go_out=plugins=grpc:" + ctx.GetInternal().Filename)
+ } else {
+ cw.WriteString(" --go_out=plugins=grpc:" + dir.Filename)
+ }
+ command := cw.String()
+ g.log.Debug(command)
+ _, err := execx.Run(command, "")
+ return err
+}
diff --git a/tools/goctl/rpc/generator/genpb_test.go b/tools/goctl/rpc/generator/genpb_test.go
new file mode 100644
index 00000000..1c423072
--- /dev/null
+++ b/tools/goctl/rpc/generator/genpb_test.go
@@ -0,0 +1,184 @@
+package generator
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
+)
+
+func TestGenerateCaseNilImport(t *testing.T) {
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ //_ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_stream.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ if err := g.Prepare(); err == nil {
+ targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go")
+ err = g.GenPb(dirCtx, nil, proto)
+ assert.Nil(t, err)
+ assert.True(t, func() bool {
+ return util.FileExists(targetPb)
+ }())
+ }
+}
+
+func TestGenerateCaseImport(t *testing.T) {
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_stream.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ if err := g.Prepare(); err == nil {
+ err = g.GenPb(dirCtx, nil, proto)
+ assert.Nil(t, err)
+
+ targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go")
+ assert.True(t, func() bool {
+ return util.FileExists(targetPb)
+ }())
+ }
+}
+
+func TestGenerateCasePathOption(t *testing.T) {
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_option.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ if err := g.Prepare(); err == nil {
+ err = g.GenPb(dirCtx, nil, proto)
+ assert.Nil(t, err)
+
+ targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_option.pb.go")
+ assert.True(t, func() bool {
+ return util.FileExists(targetPb)
+ }())
+ }
+}
+
+func TestGenerateCaseWordOption(t *testing.T) {
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_word_option.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ if err := g.Prepare(); err == nil {
+
+ err = g.GenPb(dirCtx, nil, proto)
+ assert.Nil(t, err)
+
+ targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_word_option.pb.go")
+ assert.True(t, func() bool {
+ return util.FileExists(targetPb)
+ }())
+ }
+}
+
+// test keyword go
+func TestGenerateCaseGoOption(t *testing.T) {
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_go_option.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ if err := g.Prepare(); err == nil {
+
+ err = g.GenPb(dirCtx, nil, proto)
+ assert.Nil(t, err)
+
+ targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_go_option.pb.go")
+ assert.True(t, func() bool {
+ return util.FileExists(targetPb)
+ }())
+ }
+}
diff --git a/tools/goctl/rpc/generator/genserver.go b/tools/goctl/rpc/generator/genserver.go
new file mode 100644
index 00000000..6cc8e820
--- /dev/null
+++ b/tools/goctl/rpc/generator/genserver.go
@@ -0,0 +1,102 @@
+package generator
+
+import (
+ "fmt"
+ "path/filepath"
+ "strings"
+
+ "github.com/tal-tech/go-zero/core/collection"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
+)
+
+const (
+ serverTemplate = `{{.head}}
+
+package server
+
+import (
+ "context"
+
+ {{.imports}}
+)
+
+type {{.server}}Server struct {
+ svcCtx *svc.ServiceContext
+}
+
+func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
+ return &{{.server}}Server{
+ svcCtx: svcCtx,
+ }
+}
+
+{{.funcs}}
+`
+ functionTemplate = `
+{{if .hasComment}}{{.comment}}{{end}}
+func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
+ l := logic.New{{.logicName}}(ctx,s.svcCtx)
+ return l.{{.method}}(in)
+}
+`
+)
+
+func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error {
+ dir := ctx.GetServer()
+ logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package)
+ svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
+ pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
+
+ imports := collection.NewSet()
+ imports.AddStr(logicImport, svcImport, pbImport)
+
+ head := util.GetHead(proto.Name)
+ service := proto.Service
+ serverFile := filepath.Join(dir.Filename, formatFilename(service.Name+"_server")+".go")
+ funcList, err := g.genFunctions(proto.PbPackage, service)
+ if err != nil {
+ return err
+ }
+
+ text, err := util.LoadTemplate(category, serverTemplateFile, serverTemplate)
+ if err != nil {
+ return err
+ }
+
+ err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
+ "head": head,
+ "server": stringx.From(service.Name).Title(),
+ "imports": strings.Join(imports.KeysStr(), util.NL),
+ "funcs": strings.Join(funcList, util.NL),
+ }, serverFile, true)
+ return err
+}
+
+func (g *defaultGenerator) genFunctions(goPackage string, service parser.Service) ([]string, error) {
+ var functionList []string
+ for _, rpc := range service.RPC {
+ text, err := util.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)
+ if err != nil {
+ return nil, err
+ }
+
+ comment := parser.GetComment(rpc.Doc())
+ buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
+ "server": stringx.From(service.Name).Title(),
+ "logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).Title()),
+ "method": parser.CamelCase(rpc.Name),
+ "request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
+ "response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
+ "hasComment": len(comment) > 0,
+ "comment": comment,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ functionList = append(functionList, buffer.String())
+ }
+ return functionList, nil
+}
diff --git a/tools/goctl/rpc/generator/genserver_test.go b/tools/goctl/rpc/generator/genserver_test.go
new file mode 100644
index 00000000..e5f1e3f6
--- /dev/null
+++ b/tools/goctl/rpc/generator/genserver_test.go
@@ -0,0 +1,45 @@
+package generator
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
+)
+
+func TestGenerateServer(t *testing.T) {
+ _ = Clean()
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_stream.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ err = g.Prepare()
+ if err != nil {
+ return
+ }
+
+ err = g.GenServer(dirCtx, proto)
+ assert.Nil(t, err)
+}
diff --git a/tools/goctl/rpc/gen/gensvc.go b/tools/goctl/rpc/generator/gensvc.go
similarity index 61%
rename from tools/goctl/rpc/gen/gensvc.go
rename to tools/goctl/rpc/generator/gensvc.go
index a1dcb3d1..86df41b4 100644
--- a/tools/goctl/rpc/gen/gensvc.go
+++ b/tools/goctl/rpc/generator/gensvc.go
@@ -1,9 +1,10 @@
-package gen
+package generator
import (
"fmt"
"path/filepath"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
@@ -22,15 +23,15 @@ func NewServiceContext(c config.Config) *ServiceContext {
}
`
-func (g *defaultRpcGenerator) genSvc() error {
- svcPath := g.dirM[dirSvc]
- fileName := filepath.Join(svcPath, fileServiceContext)
+func (g *defaultGenerator) GenSvc(ctx DirContext, _ parser.Proto) error {
+ dir := ctx.GetSvc()
+ fileName := filepath.Join(dir.Filename, formatFilename("service_context")+".go")
text, err := util.LoadTemplate(category, svcTemplateFile, svcTemplate)
if err != nil {
return err
}
return util.With("svc").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
- "imports": fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig)),
+ "imports": fmt.Sprintf(`"%v"`, ctx.GetConfig().Package),
}, fileName, false)
}
diff --git a/tools/goctl/rpc/generator/gensvc_test.go b/tools/goctl/rpc/generator/gensvc_test.go
new file mode 100644
index 00000000..6bf43e0d
--- /dev/null
+++ b/tools/goctl/rpc/generator/gensvc_test.go
@@ -0,0 +1,40 @@
+package generator
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
+)
+
+func TestGenerateSvc(t *testing.T) {
+ _ = Clean()
+ project := "stream"
+ abs, err := filepath.Abs("./test")
+ assert.Nil(t, err)
+
+ dir := filepath.Join(abs, project)
+ err = util.MkdirIfNotExist(dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(abs)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test_stream.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+
+ g := NewDefaultGenerator()
+ err = g.GenSvc(dirCtx, proto)
+ assert.Nil(t, err)
+}
diff --git a/tools/goctl/rpc/generator/mkdir.go b/tools/goctl/rpc/generator/mkdir.go
new file mode 100644
index 00000000..e0cb6d30
--- /dev/null
+++ b/tools/goctl/rpc/generator/mkdir.go
@@ -0,0 +1,152 @@
+package generator
+
+import (
+ "path/filepath"
+ "strings"
+
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
+ "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
+)
+
+const (
+ wd = "wd"
+ etc = "etc"
+ internal = "internal"
+ config = "config"
+ logic = "logic"
+ server = "server"
+ svc = "svc"
+ pb = "pb"
+ call = "call"
+)
+
+type (
+ DirContext interface {
+ GetCall() Dir
+ GetEtc() Dir
+ GetInternal() Dir
+ GetConfig() Dir
+ GetLogic() Dir
+ GetServer() Dir
+ GetSvc() Dir
+ GetPb() Dir
+ GetMain() Dir
+ }
+
+ Dir struct {
+ Base string
+ Filename string
+ Package string
+ }
+ defaultDirContext struct {
+ inner map[string]Dir
+ }
+)
+
+func mkdir(ctx *ctx.ProjectContext, proto parser.Proto) (DirContext, error) {
+ inner := make(map[string]Dir)
+ etcDir := filepath.Join(ctx.WorkDir, "etc")
+ internalDir := filepath.Join(ctx.WorkDir, "internal")
+ configDir := filepath.Join(internalDir, "config")
+ logicDir := filepath.Join(internalDir, "logic")
+ serverDir := filepath.Join(internalDir, "server")
+ svcDir := filepath.Join(internalDir, "svc")
+ pbDir := filepath.Join(internalDir, proto.GoPackage)
+ callDir := filepath.Join(ctx.WorkDir, strings.ToLower(stringx.From(proto.Service.Name).ToCamel()))
+ inner[wd] = Dir{
+ Filename: ctx.WorkDir,
+ Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(ctx.WorkDir, ctx.Dir))),
+ Base: filepath.Base(ctx.WorkDir),
+ }
+ inner[etc] = Dir{
+ Filename: etcDir,
+ Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(etcDir, ctx.Dir))),
+ Base: filepath.Base(etcDir),
+ }
+ inner[internal] = Dir{
+ Filename: internalDir,
+ Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(internalDir, ctx.Dir))),
+ Base: filepath.Base(internalDir),
+ }
+ inner[config] = Dir{
+ Filename: configDir,
+ Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(configDir, ctx.Dir))),
+ Base: filepath.Base(configDir),
+ }
+ inner[logic] = Dir{
+ Filename: logicDir,
+ Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(logicDir, ctx.Dir))),
+ Base: filepath.Base(logicDir),
+ }
+ inner[server] = Dir{
+ Filename: serverDir,
+ Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(serverDir, ctx.Dir))),
+ Base: filepath.Base(serverDir),
+ }
+ inner[svc] = Dir{
+ Filename: svcDir,
+ Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(svcDir, ctx.Dir))),
+ Base: filepath.Base(svcDir),
+ }
+ inner[pb] = Dir{
+ Filename: pbDir,
+ Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(pbDir, ctx.Dir))),
+ Base: filepath.Base(pbDir),
+ }
+ inner[call] = Dir{
+ Filename: callDir,
+ Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(callDir, ctx.Dir))),
+ Base: filepath.Base(callDir),
+ }
+ for _, v := range inner {
+ err := util.MkdirIfNotExist(v.Filename)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return &defaultDirContext{
+ inner: inner,
+ }, nil
+}
+
+func (d *defaultDirContext) GetCall() Dir {
+ return d.inner[call]
+}
+
+func (d *defaultDirContext) GetEtc() Dir {
+ return d.inner[etc]
+}
+
+func (d *defaultDirContext) GetInternal() Dir {
+ return d.inner[internal]
+}
+
+func (d *defaultDirContext) GetConfig() Dir {
+ return d.inner[config]
+}
+
+func (d *defaultDirContext) GetLogic() Dir {
+ return d.inner[logic]
+}
+
+func (d *defaultDirContext) GetServer() Dir {
+ return d.inner[server]
+}
+
+func (d *defaultDirContext) GetSvc() Dir {
+ return d.inner[svc]
+}
+
+func (d *defaultDirContext) GetPb() Dir {
+ return d.inner[pb]
+}
+
+func (d *defaultDirContext) GetMain() Dir {
+ return d.inner[wd]
+}
+
+func (d *Dir) Valid() bool {
+ return len(d.Filename) > 0 && len(d.Package) > 0
+}
diff --git a/tools/goctl/rpc/generator/mkdir_test.go b/tools/goctl/rpc/generator/mkdir_test.go
new file mode 100644
index 00000000..ae4f1f3a
--- /dev/null
+++ b/tools/goctl/rpc/generator/mkdir_test.go
@@ -0,0 +1,130 @@
+package generator
+
+import (
+ "go/build"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/core/stringx"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+ "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
+)
+
+func TestMkDirInGoPath(t *testing.T) {
+ dft := build.Default
+ gp := dft.GOPATH
+ if len(gp) == 0 {
+ return
+ }
+ projectName := stringx.Rand()
+ dir := filepath.Join(gp, "src", projectName)
+ err := util.MkdirIfNotExist(dir)
+ if err != nil {
+ return
+ }
+
+ defer func() {
+ _ = os.RemoveAll(dir)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+ internal := filepath.Join(dir, "internal")
+ assert.True(t, true, func() bool {
+ return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package
+ }())
+ assert.True(t, true, func() bool {
+ return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package
+ }())
+ assert.True(t, true, func() bool {
+ return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package
+ }())
+}
+
+func TestMkDirInGoMod(t *testing.T) {
+ dft := build.Default
+ gp := dft.GOPATH
+ if len(gp) == 0 {
+ return
+ }
+ projectName := stringx.Rand()
+ dir := filepath.Join(gp, "src", projectName)
+ err := util.MkdirIfNotExist(dir)
+ if err != nil {
+ return
+ }
+
+ _, err = execx.Run("go mod init "+projectName, dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(dir)
+ }()
+
+ projectCtx, err := ctx.Prepare(dir)
+ assert.Nil(t, err)
+
+ p := parser.NewDefaultProtoParser()
+ proto, err := p.Parse("./test.proto")
+ assert.Nil(t, err)
+
+ dirCtx, err := mkdir(projectCtx, proto)
+ assert.Nil(t, err)
+ internal := filepath.Join(dir, "internal")
+ assert.True(t, true, func() bool {
+ return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package
+ }())
+ assert.True(t, true, func() bool {
+ return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package
+ }())
+ assert.True(t, true, func() bool {
+ return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package
+ }())
+ assert.True(t, true, func() bool {
+ return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package
+ }())
+}
diff --git a/tools/goctl/rpc/gen/rpctemplate.go b/tools/goctl/rpc/generator/prototmpl.go
similarity index 51%
rename from tools/goctl/rpc/gen/rpctemplate.go
rename to tools/goctl/rpc/generator/prototmpl.go
index daf9d016..a4afd242 100644
--- a/tools/goctl/rpc/gen/rpctemplate.go
+++ b/tools/goctl/rpc/generator/prototmpl.go
@@ -1,12 +1,10 @@
-package gen
+package generator
import (
"path/filepath"
"strings"
- "github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/util"
- "github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
@@ -27,33 +25,23 @@ service {{.serviceName}} {
}
`
-type rpcTemplate struct {
- out string
- console.Console
-}
-
-func NewRpcTemplate(out string, idea bool) *rpcTemplate {
- return &rpcTemplate{
- out: out,
- Console: console.NewConsole(idea),
- }
-}
-
-func (r *rpcTemplate) MustGenerate(showState bool) {
- r.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/zero-doc/blob/main/doc/goctl-rpc.md」")
- r.Info("-> generating template...")
- protoFilename := filepath.Base(r.out)
+func ProtoTmpl(out string) error {
+ protoFilename := filepath.Base(out)
serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename)))
text, err := util.LoadTemplate(category, rpcTemplateFile, rpcTemplateText)
- r.Must(err)
+ if err != nil {
+ return err
+ }
+
+ dir := filepath.Dir(out)
+ err = util.MkdirIfNotExist(dir)
+ if err != nil {
+ return err
+ }
err = util.With("t").Parse(text).SaveTo(map[string]string{
"package": serviceName.UnTitle(),
"serviceName": serviceName.Title(),
- }, r.out, false)
- r.Must(err)
-
- if showState {
- r.Success("Done.")
- }
+ }, out, false)
+ return err
}
diff --git a/tools/goctl/rpc/generator/prototmpl_test.go b/tools/goctl/rpc/generator/prototmpl_test.go
new file mode 100644
index 00000000..627ec6f4
--- /dev/null
+++ b/tools/goctl/rpc/generator/prototmpl_test.go
@@ -0,0 +1,21 @@
+package generator
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestProtoTmpl(t *testing.T) {
+ out, err := filepath.Abs("./test/test.proto")
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(filepath.Dir(out))
+ }()
+ err = ProtoTmpl(out)
+ assert.Nil(t, err)
+ _, err = os.Stat(out)
+ assert.Nil(t, err)
+}
diff --git a/tools/goctl/rpc/gen/template.go b/tools/goctl/rpc/generator/template.go
similarity index 94%
rename from tools/goctl/rpc/gen/template.go
rename to tools/goctl/rpc/generator/template.go
index ccbf453f..1816f900 100644
--- a/tools/goctl/rpc/gen/template.go
+++ b/tools/goctl/rpc/generator/template.go
@@ -1,4 +1,4 @@
-package gen
+package generator
import (
"fmt"
@@ -10,7 +10,6 @@ import (
const (
category = "rpc"
callTemplateFile = "call.tpl"
- callTypesTemplateFile = "call-types.tpl"
callInterfaceFunctionTemplateFile = "call-interface-func.tpl"
callFunctionTemplateFile = "call-func.tpl"
configTemplateFileFile = "config.tpl"
@@ -26,7 +25,6 @@ const (
var templates = map[string]string{
callTemplateFile: callTemplateText,
- callTypesTemplateFile: callTemplateTypes,
callInterfaceFunctionTemplateFile: callInterfaceFunctionTemplate,
callFunctionTemplateFile: callFunctionTemplate,
configTemplateFileFile: configTemplate,
diff --git a/tools/goctl/rpc/gen/template_test.go b/tools/goctl/rpc/generator/template_test.go
similarity index 95%
rename from tools/goctl/rpc/gen/template_test.go
rename to tools/goctl/rpc/generator/template_test.go
index b3f21d8d..89200090 100644
--- a/tools/goctl/rpc/gen/template_test.go
+++ b/tools/goctl/rpc/generator/template_test.go
@@ -1,4 +1,4 @@
-package gen
+package generator
import (
"io/ioutil"
@@ -90,3 +90,7 @@ func TestUpdate(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, mainTemplate, string(data))
}
+
+func TestGetCategory(t *testing.T) {
+ assert.Equal(t, category, GetCategory())
+}
diff --git a/tools/goctl/rpc/generator/test.proto b/tools/goctl/rpc/generator/test.proto
new file mode 100644
index 00000000..1856d6d8
--- /dev/null
+++ b/tools/goctl/rpc/generator/test.proto
@@ -0,0 +1,25 @@
+// test proto
+syntax = "proto3";
+
+package test;
+option go_package = "go";
+
+import "test_base.proto";
+
+message TestMessage{
+ base.CommonReq req = 1;
+}
+message TestReq{}
+message TestReply{
+ base.CommonReply reply = 2;
+}
+
+enum TestEnum {
+ unknown = 0;
+ male = 1;
+ female = 2;
+}
+
+service TestService{
+ rpc TestRpc (TestReq)returns(TestReply);
+}
\ No newline at end of file
diff --git a/tools/goctl/rpc/generator/test_base.proto b/tools/goctl/rpc/generator/test_base.proto
new file mode 100644
index 00000000..36e0ca5d
--- /dev/null
+++ b/tools/goctl/rpc/generator/test_base.proto
@@ -0,0 +1,12 @@
+// test proto
+syntax = "proto3";
+
+package base;
+
+message CommonReq {
+ string in = 1;
+}
+
+message CommonReply {
+ string out = 1;
+}
diff --git a/tools/goctl/rpc/generator/test_go_option.proto b/tools/goctl/rpc/generator/test_go_option.proto
new file mode 100644
index 00000000..9c4397b7
--- /dev/null
+++ b/tools/goctl/rpc/generator/test_go_option.proto
@@ -0,0 +1,18 @@
+// test proto
+syntax = "proto3";
+
+package stream;
+
+option go_package="go";
+
+message StreamReq {
+ string name = 1;
+}
+
+message StreamResp {
+ string greet = 1;
+}
+
+service StreamGreeter {
+ rpc greet(StreamReq) returns (StreamResp);
+}
\ No newline at end of file
diff --git a/tools/goctl/rpc/generator/test_import.proto b/tools/goctl/rpc/generator/test_import.proto
new file mode 100644
index 00000000..4d727aee
--- /dev/null
+++ b/tools/goctl/rpc/generator/test_import.proto
@@ -0,0 +1,18 @@
+// test proto
+syntax = "proto3";
+
+package greet;
+import "base/common.proto";
+
+message In {
+ string name = 1;
+ common.User user = 2;
+}
+
+message Out {
+ string greet = 1;
+}
+
+service StreamGreeter {
+ rpc greet(In) returns (Out);
+}
\ No newline at end of file
diff --git a/tools/goctl/rpc/generator/test_option.proto b/tools/goctl/rpc/generator/test_option.proto
new file mode 100644
index 00000000..30df31ab
--- /dev/null
+++ b/tools/goctl/rpc/generator/test_option.proto
@@ -0,0 +1,18 @@
+// test proto
+syntax = "proto3";
+
+package stream;
+
+option go_package="github.com/tal-tech/go-zero";
+
+message StreamReq {
+ string name = 1;
+}
+
+message StreamResp {
+ string greet = 1;
+}
+
+service StreamGreeter {
+ rpc greet(StreamReq) returns (StreamResp);
+}
\ No newline at end of file
diff --git a/tools/goctl/rpc/generator/test_stream.proto b/tools/goctl/rpc/generator/test_stream.proto
new file mode 100644
index 00000000..5006d77d
--- /dev/null
+++ b/tools/goctl/rpc/generator/test_stream.proto
@@ -0,0 +1,16 @@
+// test proto
+syntax = "proto3";
+
+package stream;
+
+message StreamReq {
+ string name = 1;
+}
+
+message StreamResp {
+ string greet = 1;
+}
+
+service StreamGreeter {
+ rpc greet(StreamReq) returns (StreamResp);
+}
\ No newline at end of file
diff --git a/tools/goctl/rpc/generator/test_word_option.proto b/tools/goctl/rpc/generator/test_word_option.proto
new file mode 100644
index 00000000..4f7753a5
--- /dev/null
+++ b/tools/goctl/rpc/generator/test_word_option.proto
@@ -0,0 +1,18 @@
+// test proto
+syntax = "proto3";
+
+package stream;
+
+option go_package="user";
+
+message StreamReq {
+ string name = 1;
+}
+
+message StreamResp {
+ string greet = 1;
+}
+
+service StreamGreeter {
+ rpc greet(StreamReq) returns (StreamResp);
+}
\ No newline at end of file
diff --git a/tools/goctl/rpc/parser/comment.go b/tools/goctl/rpc/parser/comment.go
new file mode 100644
index 00000000..ad7bccb8
--- /dev/null
+++ b/tools/goctl/rpc/parser/comment.go
@@ -0,0 +1,10 @@
+package parser
+
+import "github.com/emicklei/proto"
+
+func GetComment(comment *proto.Comment) string {
+ if comment == nil {
+ return ""
+ }
+ return comment.Message()
+}
diff --git a/tools/goctl/rpc/parser/import.go b/tools/goctl/rpc/parser/import.go
new file mode 100644
index 00000000..5095d52e
--- /dev/null
+++ b/tools/goctl/rpc/parser/import.go
@@ -0,0 +1,7 @@
+package parser
+
+import "github.com/emicklei/proto"
+
+type Import struct {
+ *proto.Import
+}
diff --git a/tools/goctl/rpc/parser/message.go b/tools/goctl/rpc/parser/message.go
new file mode 100644
index 00000000..5f5a7196
--- /dev/null
+++ b/tools/goctl/rpc/parser/message.go
@@ -0,0 +1,7 @@
+package parser
+
+import pr "github.com/emicklei/proto"
+
+type Message struct {
+ *pr.Message
+}
diff --git a/tools/goctl/rpc/parser/option.go b/tools/goctl/rpc/parser/option.go
new file mode 100644
index 00000000..06f34cf4
--- /dev/null
+++ b/tools/goctl/rpc/parser/option.go
@@ -0,0 +1,7 @@
+package parser
+
+import "github.com/emicklei/proto"
+
+type Option struct {
+ *proto.Option
+}
diff --git a/tools/goctl/rpc/parser/package.go b/tools/goctl/rpc/parser/package.go
new file mode 100644
index 00000000..6a7abf88
--- /dev/null
+++ b/tools/goctl/rpc/parser/package.go
@@ -0,0 +1,7 @@
+package parser
+
+import "github.com/emicklei/proto"
+
+type Package struct {
+ *proto.Package
+}
diff --git a/tools/goctl/rpc/parser/parser.go b/tools/goctl/rpc/parser/parser.go
index 16dec9fd..764f647d 100644
--- a/tools/goctl/rpc/parser/parser.go
+++ b/tools/goctl/rpc/parser/parser.go
@@ -1,46 +1,170 @@
package parser
import (
+ "errors"
+ "fmt"
+ "go/token"
+ "os"
"path/filepath"
"strings"
+ "unicode"
+ "unicode/utf8"
- "github.com/tal-tech/go-zero/core/lang"
- "github.com/tal-tech/go-zero/tools/goctl/util/console"
+ "github.com/emicklei/proto"
)
-func Transfer(proto, target string, externalImport []*Import, console console.Console) (*PbAst, error) {
- messageM := make(map[string]lang.PlaceholderType)
- enumM := make(map[string]*Enum)
- protoAst, err := parseProto(proto, messageM, enumM)
- if err != nil {
- return nil, err
- }
- for _, item := range externalImport {
- err = checkImport(item.OriginalProtoPath)
- if err != nil {
- return nil, err
- }
- innerAst, err := parseProto(item.OriginalProtoPath, protoAst.Message, protoAst.Enum)
- if err != nil {
- return nil, err
- }
- for k, v := range innerAst.Message {
- protoAst.Message[k] = v
- }
- for k, v := range innerAst.Enum {
- protoAst.Enum[k] = v
- }
- }
- protoAst.Import = externalImport
- protoAst.PbSrc = filepath.Join(target, strings.TrimSuffix(filepath.Base(proto), ".proto")+".pb.go")
- return transfer(protoAst, console)
+type (
+ defaultProtoParser struct{}
+)
+
+func NewDefaultProtoParser() *defaultProtoParser {
+ return &defaultProtoParser{}
}
-func transfer(proto *Proto, console console.Console) (*PbAst, error) {
- parser := MustNewAstParser(proto, console)
- parse, err := parser.Parse()
+func (p *defaultProtoParser) Parse(src string) (Proto, error) {
+ var ret Proto
+
+ abs, err := filepath.Abs(src)
if err != nil {
- return nil, err
+ return Proto{}, err
}
- return parse, nil
+
+ r, err := os.Open(abs)
+ if err != nil {
+ return ret, err
+ }
+ defer r.Close()
+
+ parser := proto.NewParser(r)
+ set, err := parser.Parse()
+ if err != nil {
+ return ret, err
+ }
+
+ var serviceList []Service
+ proto.Walk(
+ set,
+ proto.WithImport(func(i *proto.Import) {
+ ret.Import = append(ret.Import, Import{Import: i})
+ }),
+ proto.WithMessage(func(message *proto.Message) {
+ ret.Message = append(ret.Message, Message{Message: message})
+ }),
+ proto.WithPackage(func(p *proto.Package) {
+ ret.Package = Package{Package: p}
+ }),
+ proto.WithService(func(service *proto.Service) {
+ serv := Service{Service: service}
+ elements := service.Elements
+ for _, el := range elements {
+ v, _ := el.(*proto.RPC)
+ if v == nil {
+ continue
+ }
+ serv.RPC = append(serv.RPC, &RPC{RPC: v})
+ }
+
+ serviceList = append(serviceList, serv)
+ }),
+ proto.WithOption(func(option *proto.Option) {
+ if option.Name == "go_package" {
+ ret.GoPackage = option.Constant.Source
+ }
+ }),
+ )
+ if len(serviceList) == 0 {
+ return ret, errors.New("rpc service not found")
+ }
+
+ if len(serviceList) > 1 {
+ return ret, errors.New("only one service expected")
+ }
+ service := serviceList[0]
+ name := filepath.Base(abs)
+
+ for _, rpc := range service.RPC {
+ if strings.Contains(rpc.RequestType, ".") {
+ return ret, fmt.Errorf("line %v:%v, request type must defined in %s", rpc.Position.Line, rpc.Position.Column, name)
+ }
+ if strings.Contains(rpc.ReturnsType, ".") {
+ return ret, fmt.Errorf("line %v:%v, returns type must defined in %s", rpc.Position.Line, rpc.Position.Column, name)
+ }
+ }
+ if len(ret.GoPackage) == 0 {
+ ret.GoPackage = ret.Package.Name
+ }
+ ret.PbPackage = GoSanitized(filepath.Base(ret.GoPackage))
+ ret.Src = abs
+ ret.Name = name
+ ret.Service = service
+
+ return ret, nil
+}
+
+// see google.golang.org/protobuf@v1.25.0/internal/strs/strings.go:71
+func GoSanitized(s string) string {
+ // Sanitize the input to the set of valid characters,
+ // which must be '_' or be in the Unicode L or N categories.
+ s = strings.Map(func(r rune) rune {
+ if unicode.IsLetter(r) || unicode.IsDigit(r) {
+ return r
+ }
+ return '_'
+ }, s)
+
+ // Prepend '_' in the event of a Go keyword conflict or if
+ // the identifier is invalid (does not start in the Unicode L category).
+ r, _ := utf8.DecodeRuneInString(s)
+ if token.Lookup(s).IsKeyword() || !unicode.IsLetter(r) {
+ return "_" + s
+ }
+ return s
+}
+
+// copy from github.com/golang/protobuf@v1.4.2/protoc-gen-go/generator/generator.go:2648
+func CamelCase(s string) string {
+ if s == "" {
+ return ""
+ }
+ t := make([]byte, 0, 32)
+ i := 0
+ if s[0] == '_' {
+ // Need a capital letter; drop the '_'.
+ t = append(t, 'X')
+ i++
+ }
+ // Invariant: if the next letter is lower case, it must be converted
+ // to upper case.
+ // That is, we process a word at a time, where words are marked by _ or
+ // upper case letter. Digits are treated as words.
+ for ; i < len(s); i++ {
+ c := s[i]
+ if c == '_' && i+1 < len(s) && isASCIILower(s[i+1]) {
+ continue // Skip the underscore in s.
+ }
+ if isASCIIDigit(c) {
+ t = append(t, c)
+ continue
+ }
+ // Assume we have a letter now - if not, it's a bogus identifier.
+ // The next word is a sequence of characters that must start upper case.
+ if isASCIILower(c) {
+ c ^= ' ' // Make it a capital letter.
+ }
+ t = append(t, c) // Guaranteed not lower case.
+ // Accept lower case sequence that follows.
+ for i+1 < len(s) && isASCIILower(s[i+1]) {
+ i++
+ t = append(t, s[i])
+ }
+ }
+ return string(t)
+}
+func isASCIILower(c byte) bool {
+ return 'a' <= c && c <= 'z'
+}
+
+// Is c an ASCII digit?
+func isASCIIDigit(c byte) bool {
+ return '0' <= c && c <= '9'
}
diff --git a/tools/goctl/rpc/parser/parser_test.go b/tools/goctl/rpc/parser/parser_test.go
new file mode 100644
index 00000000..2674a8d4
--- /dev/null
+++ b/tools/goctl/rpc/parser/parser_test.go
@@ -0,0 +1,78 @@
+package parser
+
+import (
+ "sort"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestDefaultProtoParse(t *testing.T) {
+ p := NewDefaultProtoParser()
+ data, err := p.Parse("./test.proto")
+ assert.Nil(t, err)
+ assert.Equal(t, "base.proto", func() string {
+ ip := data.Import[0]
+ return ip.Filename
+ }())
+ assert.Equal(t, "test", data.Package.Name)
+ assert.Equal(t, true, data.GoPackage == "go")
+ assert.Equal(t, true, data.PbPackage == "_go")
+ assert.Equal(t, []string{"TestMessage", "TestReply", "TestReq"}, func() []string {
+ var list []string
+ for _, item := range data.Message {
+ list = append(list, item.Name)
+ }
+ sort.Strings(list)
+ return list
+ }())
+
+ assert.Equal(t, true, func() bool {
+ s := data.Service
+ if s.Name != "TestService" {
+ return false
+ }
+ rpcOne := s.RPC[0]
+
+ return rpcOne.Name == "TestRpcOne" && rpcOne.RequestType == "TestReq" && rpcOne.ReturnsType == "TestReply"
+ }())
+}
+
+func TestDefaultProtoParseCaseInvalidRequestType(t *testing.T) {
+ p := NewDefaultProtoParser()
+ _, err := p.Parse("./test_invalid_request.proto")
+ assert.True(t, true, func() bool {
+ return strings.Contains(err.Error(), "request type must defined in")
+ }())
+}
+
+func TestDefaultProtoParseCaseInvalidResponseType(t *testing.T) {
+ p := NewDefaultProtoParser()
+ _, err := p.Parse("./test_invalid_response.proto")
+ assert.True(t, true, func() bool {
+ return strings.Contains(err.Error(), "response type must defined in")
+ }())
+}
+
+func TestDefaultProtoParseError(t *testing.T) {
+ p := NewDefaultProtoParser()
+ _, err := p.Parse("./nil.proto")
+ assert.NotNil(t, err)
+}
+
+func TestDefaultProtoParse_Option(t *testing.T) {
+ p := NewDefaultProtoParser()
+ data, err := p.Parse("./test_option.proto")
+ assert.Nil(t, err)
+ assert.Equal(t, "github.com/tal-tech/go-zero", data.GoPackage)
+ assert.Equal(t, "go_zero", data.PbPackage)
+}
+
+func TestDefaultProtoParse_Option2(t *testing.T) {
+ p := NewDefaultProtoParser()
+ data, err := p.Parse("./test_option2.proto")
+ assert.Nil(t, err)
+ assert.Equal(t, "stream", data.GoPackage)
+ assert.Equal(t, "stream", data.PbPackage)
+}
diff --git a/tools/goctl/rpc/parser/pbast.go b/tools/goctl/rpc/parser/pbast.go
deleted file mode 100644
index 7bed50a6..00000000
--- a/tools/goctl/rpc/parser/pbast.go
+++ /dev/null
@@ -1,643 +0,0 @@
-package parser
-
-import (
- "errors"
- "fmt"
- "go/ast"
- "go/parser"
- "go/token"
- "io/ioutil"
- "sort"
- "strings"
-
- "github.com/tal-tech/go-zero/core/lang"
- sx "github.com/tal-tech/go-zero/core/stringx"
- "github.com/tal-tech/go-zero/tools/goctl/util"
- "github.com/tal-tech/go-zero/tools/goctl/util/console"
- "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
-)
-
-const (
- flagStar = "*"
- flagDot = "."
- suffixServer = "Server"
- referenceContext = "context"
- unknownPrefix = "XXX_"
- ignoreJsonTagExpression = `json:"-"`
-)
-
-var (
- errorParseError = errors.New("pb parse error")
- typeTemplate = `type (
- {{.types}}
-)`
- structTemplate = `{{if .type}}type {{end}}{{.name}} struct {
- {{.fields}}
-}`
- fieldTemplate = `{{if .hasDoc}}{{.doc}}
-{{end}}{{.name}} {{.type}} {{.tag}}{{if .hasComment}}{{.comment}}{{end}}`
-
- anyTypeTemplate = "Any struct {\n\tTypeUrl string `json:\"typeUrl\"`\n\tValue []byte `json:\"value\"`\n}"
-
- objectM = make(map[string]*Struct)
-)
-
-type (
- astParser struct {
- filterStruct map[string]lang.PlaceholderType
- filterEnum map[string]*Enum
- console.Console
- fileSet *token.FileSet
- proto *Proto
- }
- Field struct {
- Name stringx.String
- Type Type
- JsonTag string
- Document []string
- Comment []string
- }
- Struct struct {
- Name stringx.String
- Document []string
- Comment []string
- Field []*Field
- }
- ConstLit struct {
- Name stringx.String
- Document []string
- Comment []string
- Lit []*Lit
- }
- Lit struct {
- Key string
- Value int
- }
- Type struct {
- // eg:context.Context
- Expression string
- // eg: *context.Context
- StarExpression string
- // Invoke Type Expression
- InvokeTypeExpression string
- // eg:context
- Package string
- // eg:Context
- Name string
- }
- Func struct {
- Name stringx.String
- ParameterIn Type
- ParameterOut Type
- Document []string
- }
- RpcService struct {
- Name stringx.String
- Funcs []*Func
- }
- // parsing for rpc
- PbAst struct {
- // deprecated: containsAny will be removed in the feature
- ContainsAny bool
- Imports map[string]string
- Structure map[string]*Struct
- Service []*RpcService
- *Proto
- }
-)
-
-func MustNewAstParser(proto *Proto, log console.Console) *astParser {
- return &astParser{
- filterStruct: proto.Message,
- filterEnum: proto.Enum,
- Console: log,
- fileSet: token.NewFileSet(),
- proto: proto,
- }
-}
-func (a *astParser) Parse() (*PbAst, error) {
- var pbAst PbAst
- pbAst.ContainsAny = a.proto.ContainsAny
- pbAst.Proto = a.proto
- pbAst.Structure = make(map[string]*Struct)
- pbAst.Imports = make(map[string]string)
- structure, imports, services, err := a.parse(a.proto.PbSrc)
- if err != nil {
- return nil, err
- }
- dependencyStructure, err := a.parseExternalDependency()
- if err != nil {
- return nil, err
- }
- for k, v := range structure {
- pbAst.Structure[k] = v
- }
- for k, v := range dependencyStructure {
- pbAst.Structure[k] = v
- }
- for key, path := range imports {
- pbAst.Imports[key] = path
- }
- pbAst.Service = append(pbAst.Service, services...)
- return &pbAst, nil
-}
-
-func (a *astParser) parse(pbSrc string) (structure map[string]*Struct, imports map[string]string, services []*RpcService, retErr error) {
- structure = make(map[string]*Struct)
- imports = make(map[string]string)
- data, err := ioutil.ReadFile(pbSrc)
- if err != nil {
- retErr = err
- return
- }
- fSet := a.fileSet
- f, err := parser.ParseFile(fSet, "", data, parser.ParseComments)
- if err != nil {
- retErr = err
- return
- }
- commentMap := ast.NewCommentMap(fSet, f, f.Comments)
- f.Comments = commentMap.Filter(f).Comments()
- strucs, function := a.mustScope(f.Scope, a.mustGetIndentName(f.Name))
- for k, v := range strucs {
- if v == nil {
- continue
- }
- structure[k] = v
- }
- importList := f.Imports
- for _, item := range importList {
- name := a.mustGetIndentName(item.Name)
- if item.Path != nil {
- imports[name] = item.Path.Value
- }
- }
- services = append(services, function...)
- return
-}
-func (a *astParser) parseExternalDependency() (map[string]*Struct, error) {
- m := make(map[string]*Struct)
- for _, impo := range a.proto.Import {
- ret, _, _, err := a.parse(impo.OriginalPbPath)
- if err != nil {
- return nil, err
- }
- for k, v := range ret {
- m[k] = v
- }
- }
- return m, nil
-}
-
-func (a *astParser) mustScope(scope *ast.Scope, sourcePackage string) (map[string]*Struct, []*RpcService) {
- if scope == nil {
- return nil, nil
- }
-
- objects := scope.Objects
- structs := make(map[string]*Struct)
- serviceList := make([]*RpcService, 0)
- for name, obj := range objects {
- decl := obj.Decl
- if decl == nil {
- continue
- }
- typeSpec, ok := decl.(*ast.TypeSpec)
- if !ok {
- continue
- }
- tp := typeSpec.Type
-
- switch v := tp.(type) {
-
- case *ast.StructType:
- st, err := a.parseObject(name, v, sourcePackage)
- a.Must(err)
- structs[st.Name.Lower()] = st
-
- case *ast.InterfaceType:
- if !strings.HasSuffix(name, suffixServer) {
- continue
- }
- list := a.mustServerFunctions(v, sourcePackage)
- serviceList = append(serviceList, &RpcService{
- Name: stringx.From(strings.TrimSuffix(name, suffixServer)),
- Funcs: list,
- })
- }
- }
- targetStruct := make(map[string]*Struct)
- for st := range a.filterStruct {
- lower := strings.ToLower(st)
- targetStruct[lower] = structs[lower]
- }
- return targetStruct, serviceList
-}
-
-func (a *astParser) mustServerFunctions(v *ast.InterfaceType, sourcePackage string) []*Func {
- funcs := make([]*Func, 0)
- methodObject := v.Methods
- if methodObject == nil {
- return nil
- }
-
- for _, method := range methodObject.List {
- var item Func
- name := a.mustGetIndentName(method.Names[0])
- doc := a.parseCommentOrDoc(method.Doc)
- item.Name = stringx.From(name)
- item.Document = doc
- types := method.Type
- if types == nil {
- funcs = append(funcs, &item)
- continue
- }
- v, ok := types.(*ast.FuncType)
- if !ok {
- continue
- }
- params := v.Params
- if params != nil {
- inList, err := a.parseFields(params.List, true, sourcePackage)
- a.Must(err)
-
- for _, data := range inList {
- if data.Type.Package == referenceContext {
- continue
- }
- item.ParameterIn = data.Type
- break
- }
- }
- results := v.Results
- if results != nil {
- outList, err := a.parseFields(results.List, true, sourcePackage)
- a.Must(err)
-
- for _, data := range outList {
- if data.Type.Package == referenceContext {
- continue
- }
- item.ParameterOut = data.Type
- break
- }
- }
- funcs = append(funcs, &item)
- }
- return funcs
-}
-
-func (a *astParser) getFieldType(v string, sourcePackage string) Type {
- var pkg, name, expression, starExpression, invokeTypeExpression string
-
- if strings.Contains(v, ".") {
- starExpression = v
- if strings.Contains(v, "*") {
- leftIndex := strings.Index(v, "*")
- rightIndex := strings.Index(v, ".")
- if leftIndex >= 0 {
- invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
- } else {
- invokeTypeExpression = v[rightIndex+1:]
- }
- } else {
- if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
- leftIndex := strings.Index(v, "]")
- rightIndex := strings.Index(v, ".")
- invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
- } else {
- rightIndex := strings.Index(v, ".")
- invokeTypeExpression = v[rightIndex+1:]
- }
- }
- } else {
- expression = strings.TrimPrefix(v, flagStar)
- switch v {
- case "double", "float", "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32", "sfixed64",
- "bool", "string", "bytes":
- invokeTypeExpression = v
- break
- default:
- name = expression
- invokeTypeExpression = v
- if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
- starExpression = strings.ReplaceAll(v, flagStar, flagStar+sourcePackage+".")
- } else {
- starExpression = fmt.Sprintf("*%v.%v", sourcePackage, name)
- invokeTypeExpression = v
- }
-
- }
- }
- expression = strings.TrimPrefix(starExpression, flagStar)
- index := strings.LastIndex(expression, flagDot)
- if index > 0 {
- pkg = expression[0:index]
- name = expression[index+1:]
- } else {
- pkg = sourcePackage
- }
-
- return Type{
- Expression: expression,
- StarExpression: starExpression,
- InvokeTypeExpression: invokeTypeExpression,
- Package: pkg,
- Name: name,
- }
-}
-
-func (a *astParser) parseObject(structName string, tp *ast.StructType, sourcePackage string) (*Struct, error) {
- if data, ok := objectM[structName]; ok {
- return data, nil
- }
- var st Struct
- st.Name = stringx.From(structName)
- if tp == nil {
- return &st, nil
- }
-
- fields := tp.Fields
- if fields == nil {
- objectM[structName] = &st
- return &st, nil
- }
-
- fieldList := fields.List
- members, err := a.parseFields(fieldList, false, sourcePackage)
- if err != nil {
- return nil, err
- }
-
- for _, m := range members {
- var field Field
- field.Name = m.Name
- field.Type = m.Type
- field.JsonTag = m.JsonTag
- field.Document = m.Document
- field.Comment = m.Comment
- st.Field = append(st.Field, &field)
- }
- objectM[structName] = &st
- return &st, nil
-}
-
-func (a *astParser) parseFields(fields []*ast.Field, onlyType bool, sourcePackage string) ([]*Field, error) {
- ret := make([]*Field, 0)
- for _, field := range fields {
- var item Field
- tag := a.parseTag(field.Tag)
- if tag == "" && !onlyType {
- continue
- }
- if tag == ignoreJsonTagExpression {
- continue
- }
-
- item.JsonTag = tag
- name := a.parseName(field.Names)
- if strings.HasPrefix(name, unknownPrefix) {
- continue
- }
- item.Name = stringx.From(name)
- typeName, err := a.parseType(field.Type)
- if err != nil {
- return nil, err
- }
-
- item.Type = a.getFieldType(typeName, sourcePackage)
- if onlyType {
- ret = append(ret, &item)
- continue
- }
- docs := a.parseCommentOrDoc(field.Doc)
- comments := a.parseCommentOrDoc(field.Comment)
-
- item.Document = docs
- item.Comment = comments
-
- isInline := name == ""
- if isInline {
- return nil, a.wrapError(field.Pos(), "unexpected inline type:%s", name)
- }
-
- ret = append(ret, &item)
-
- }
- return ret, nil
-}
-
-func (a *astParser) parseTag(basicLit *ast.BasicLit) string {
- if basicLit == nil {
- return ""
- }
- value := basicLit.Value
- splits := strings.Split(value, " ")
- if len(splits) == 1 {
- return fmt.Sprintf("`%s`", strings.ReplaceAll(splits[0], "`", ""))
- } else {
- return fmt.Sprintf("`%s`", strings.ReplaceAll(splits[1], "`", ""))
- }
-}
-
-// returns
-// resp1:type's string expression,like int、string、[]int64、map[string]User、*User
-// resp2:error
-func (a *astParser) parseType(expr ast.Expr) (string, error) {
- if expr == nil {
- return "", errorParseError
- }
-
- switch v := expr.(type) {
- case *ast.StarExpr:
- stringExpr, err := a.parseType(v.X)
- if err != nil {
- return "", err
- }
-
- e := fmt.Sprintf("*%s", stringExpr)
- return e, nil
-
- case *ast.Ident:
- return a.mustGetIndentName(v), nil
- case *ast.MapType:
- keyStringExpr, err := a.parseType(v.Key)
- if err != nil {
- return "", err
- }
-
- valueStringExpr, err := a.parseType(v.Value)
- if err != nil {
- return "", err
- }
-
- e := fmt.Sprintf("map[%s]%s", keyStringExpr, valueStringExpr)
- return e, nil
- case *ast.ArrayType:
- stringExpr, err := a.parseType(v.Elt)
- if err != nil {
- return "", err
- }
-
- e := fmt.Sprintf("[]%s", stringExpr)
- return e, nil
- case *ast.InterfaceType:
- return "interface{}", nil
- case *ast.SelectorExpr:
- join := make([]string, 0)
- xIdent, ok := v.X.(*ast.Ident)
- xIndentName := a.mustGetIndentName(xIdent)
- if ok {
- join = append(join, xIndentName)
- }
- sel := v.Sel
- join = append(join, a.mustGetIndentName(sel))
- return strings.Join(join, "."), nil
- case *ast.ChanType:
- return "", a.wrapError(v.Pos(), "unexpected type 'chan'")
- case *ast.FuncType:
- return "", a.wrapError(v.Pos(), "unexpected type 'func'")
- case *ast.StructType:
- return "", a.wrapError(v.Pos(), "unexpected inline struct type")
- default:
- return "", a.wrapError(v.Pos(), "unexpected type '%v'", v)
- }
-}
-func (a *astParser) parseName(names []*ast.Ident) string {
- if len(names) == 0 {
- return ""
- }
- name := names[0]
- return a.mustGetIndentName(name)
-}
-
-func (a *astParser) parseCommentOrDoc(cg *ast.CommentGroup) []string {
- if cg == nil {
- return nil
- }
- comments := make([]string, 0)
- for _, comment := range cg.List {
- if comment == nil {
- continue
- }
- text := strings.TrimSpace(comment.Text)
- if text == "" {
- continue
- }
- comments = append(comments, text)
- }
- return comments
-}
-
-func (a *astParser) mustGetIndentName(ident *ast.Ident) string {
- if ident == nil {
- return ""
- }
- return ident.Name
-}
-
-func (a *astParser) wrapError(pos token.Pos, format string, arg ...interface{}) error {
- file := a.fileSet.Position(pos)
- return fmt.Errorf("line %v: %s", file.Line, fmt.Sprintf(format, arg...))
-}
-
-func (f *Func) GetDoc() string {
- return strings.Join(f.Document, util.NL)
-}
-
-func (f *Func) HaveDoc() bool {
- return len(f.Document) > 0
-}
-
-func (a *PbAst) GenEnumCode() (string, error) {
- var element []string
- for _, item := range a.Enum {
- code, err := item.GenEnumCode()
- if err != nil {
- return "", err
- }
- element = append(element, code)
- }
- return strings.Join(element, util.NL), nil
-}
-
-func (a *PbAst) GenTypesCode() (string, error) {
- types := make([]string, 0)
- sts := make([]*Struct, 0)
- for _, item := range a.Structure {
- sts = append(sts, item)
- }
- sort.Slice(sts, func(i, j int) bool {
- return sts[i].Name.Source() < sts[j].Name.Source()
- })
- for _, s := range sts {
- structCode, err := s.genCode(false)
- if err != nil {
- return "", err
- }
-
- if structCode == "" {
- continue
- }
- types = append(types, structCode)
- }
- types = append(types, a.genAnyCode())
- for _, item := range a.Enum {
- typeCode, err := item.GenEnumTypeCode()
- if err != nil {
- return "", err
- }
- types = append(types, typeCode)
- }
-
- buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{
- "types": strings.Join(types, util.NL+util.NL),
- })
- if err != nil {
- return "", err
- }
-
- return buffer.String(), nil
-}
-
-func (a *PbAst) genAnyCode() string {
- if !a.ContainsAny {
- return ""
- }
- return anyTypeTemplate
-}
-
-func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
- fields := make([]string, 0)
- for _, f := range s.Field {
- var comment, doc string
- if len(f.Comment) > 0 {
- comment = f.Comment[0]
- }
- doc = strings.Join(f.Document, util.NL)
- buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{
- "name": f.Name.Title(),
- "type": f.Type.InvokeTypeExpression,
- "tag": f.JsonTag,
- "hasDoc": len(f.Document) > 0,
- "doc": doc,
- "hasComment": len(f.Comment) > 0,
- "comment": comment,
- })
- if err != nil {
- return "", err
- }
-
- fields = append(fields, buffer.String())
- }
- buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{
- "type": containsTypeStatement,
- "name": s.Name.Title(),
- "fields": strings.Join(fields, util.NL),
- })
- if err != nil {
- return "", err
- }
-
- return buffer.String(), nil
-}
diff --git a/tools/goctl/rpc/parser/proto.go b/tools/goctl/rpc/parser/proto.go
index cf5dbea7..4e23b8e5 100644
--- a/tools/goctl/rpc/parser/proto.go
+++ b/tools/goctl/rpc/parser/proto.go
@@ -1,295 +1,12 @@
package parser
-import (
- "errors"
- "fmt"
- "os"
- "path/filepath"
- "strings"
-
- "github.com/emicklei/proto"
- "github.com/tal-tech/go-zero/core/collection"
- "github.com/tal-tech/go-zero/core/lang"
- "github.com/tal-tech/go-zero/tools/goctl/util"
- "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
-)
-
-const (
- AnyImport = "google/protobuf/any.proto"
-)
-
-var (
- enumTypeTemplate = `{{.name}} int32`
- enumTemplate = `const (
- {{.element}}
-)`
- enumFiledTemplate = `{{.key}} {{.name}} = {{.value}}`
-)
-
-type (
- MessageField struct {
- Type string
- Name stringx.String
- }
- Message struct {
- Name stringx.String
- Element []*MessageField
- *proto.Message
- }
- Enum struct {
- Name stringx.String
- Element []*EnumField
- *proto.Enum
- }
- EnumField struct {
- Key string
- Value int
- }
-
- Proto struct {
- Package string
- Import []*Import
- PbSrc string
- // deprecated: containsAny will be removed in the feature
- ContainsAny bool
- Message map[string]lang.PlaceholderType
- Enum map[string]*Enum
- }
- Import struct {
- ProtoImportName string
- PbImportName string
- OriginalDir string
- OriginalProtoPath string
- OriginalPbPath string
- BridgeImport string
- exists bool
- //xx.proto
- protoName string
- // xx.pb.go
- pbName string
- }
-)
-
-func checkImport(src string) error {
- r, err := os.Open(src)
- if err != nil {
- return err
- }
- defer r.Close()
-
- parser := proto.NewParser(r)
- parseRet, err := parser.Parse()
- if err != nil {
- return err
- }
- var base = filepath.Base(src)
- proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) {
- if err != nil {
- return
- }
- err = fmt.Errorf("%v:%v the external proto cannot import other proto files", base, i.Position.Line)
- }))
- if err != nil {
- return err
- }
- return nil
-}
-func ParseImport(src string) ([]*Import, bool, error) {
- bridgeImportM := make(map[string]string)
- r, err := os.Open(src)
- if err != nil {
- return nil, false, err
- }
- defer r.Close()
-
- workDir := filepath.Dir(src)
- parser := proto.NewParser(r)
- parseRet, err := parser.Parse()
- if err != nil {
- return nil, false, err
- }
- protoImportSet := collection.NewSet()
- var containsAny bool
- proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) {
- if i.Filename == AnyImport {
- containsAny = true
- return
- }
- protoImportSet.AddStr(i.Filename)
- if i.Comment != nil {
- lines := i.Comment.Lines
- for _, line := range lines {
- line = strings.TrimSpace(line)
- if !strings.HasPrefix(line, "@") {
- continue
- }
- line = strings.TrimPrefix(line, "@")
- bridgeImportM[i.Filename] = line
- }
- }
- }))
- var importList []*Import
-
- for _, item := range protoImportSet.KeysStr() {
- pb := strings.TrimSuffix(filepath.Base(item), filepath.Ext(item)) + ".pb.go"
- var pbImportName, brideImport string
- if v, ok := bridgeImportM[item]; ok {
- pbImportName = v
- brideImport = "M" + item + "=" + v
- } else {
- pbImportName = item
- }
- var impo = Import{
- ProtoImportName: item,
- PbImportName: pbImportName,
- BridgeImport: brideImport,
- }
- protoSource := filepath.Join(workDir, item)
- pbSource := filepath.Join(filepath.Dir(protoSource), pb)
- if util.FileExists(protoSource) && util.FileExists(pbSource) {
- impo.OriginalProtoPath = protoSource
- impo.OriginalPbPath = pbSource
- impo.OriginalDir = filepath.Dir(protoSource)
- impo.exists = true
- impo.protoName = filepath.Base(item)
- impo.pbName = pb
- } else {
- return nil, false, fmt.Errorf("「%v」: import must be found in the relative directory of 「%v」", item, filepath.Base(src))
- }
- importList = append(importList, &impo)
- }
-
- return importList, containsAny, nil
-}
-
-func parseProto(src string, messageM map[string]lang.PlaceholderType, enumM map[string]*Enum) (*Proto, error) {
- if !filepath.IsAbs(src) {
- return nil, fmt.Errorf("expected absolute path,but found: %v", src)
- }
-
- r, err := os.Open(src)
- if err != nil {
- return nil, err
- }
- defer r.Close()
-
- parser := proto.NewParser(r)
- parseRet, err := parser.Parse()
- if err != nil {
- return nil, err
- }
-
- // xx.proto
- fileBase := filepath.Base(src)
- var resp Proto
-
- proto.Walk(parseRet, proto.WithPackage(func(p *proto.Package) {
- if err != nil {
- return
- }
-
- if len(resp.Package) != 0 {
- err = fmt.Errorf("%v:%v duplicate package「%v」", fileBase, p.Position.Line, p.Name)
- }
-
- if len(p.Name) == 0 {
- err = errors.New("package not found")
- }
-
- resp.Package = p.Name
- }), proto.WithMessage(func(message *proto.Message) {
- if err != nil {
- return
- }
-
- for _, item := range message.Elements {
- switch item.(type) {
- case *proto.NormalField, *proto.MapField, *proto.Comment:
- continue
- default:
- err = fmt.Errorf("%v: unsupport inline declaration", fileBase)
- return
- }
- }
- name := stringx.From(message.Name)
- if _, ok := messageM[name.Lower()]; ok {
- err = fmt.Errorf("%v:%v duplicate message 「%v」", fileBase, message.Position.Line, message.Name)
- return
- }
-
- messageM[name.Lower()] = lang.Placeholder
- }), proto.WithEnum(func(enum *proto.Enum) {
- if err != nil {
- return
- }
-
- var node Enum
- node.Enum = enum
- node.Name = stringx.From(enum.Name)
- for _, item := range enum.Elements {
- v, ok := item.(*proto.EnumField)
- if !ok {
- continue
- }
- node.Element = append(node.Element, &EnumField{
- Key: v.Name,
- Value: v.Integer,
- })
- }
- if _, ok := enumM[node.Name.Lower()]; ok {
- err = fmt.Errorf("%v:%v duplicate enum 「%v」", fileBase, node.Position.Line, node.Name.Source())
- return
- }
-
- lower := stringx.From(enum.Name).Lower()
- enumM[lower] = &node
- }))
-
- if err != nil {
- return nil, err
- }
- resp.Message = messageM
- resp.Enum = enumM
-
- return &resp, nil
-}
-
-func (e *Enum) GenEnumCode() (string, error) {
- var element []string
- for _, item := range e.Element {
- code, err := item.GenEnumFieldCode(e.Name.Source())
- if err != nil {
- return "", err
- }
- element = append(element, code)
- }
- buffer, err := util.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{
- "element": strings.Join(element, util.NL),
- })
- if err != nil {
- return "", err
- }
- return buffer.String(), nil
-}
-
-func (e *Enum) GenEnumTypeCode() (string, error) {
- buffer, err := util.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{
- "name": e.Name.Source(),
- })
- if err != nil {
- return "", err
- }
- return buffer.String(), nil
-}
-
-func (e *EnumField) GenEnumFieldCode(parentName string) (string, error) {
- buffer, err := util.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{
- "key": e.Key,
- "name": parentName,
- "value": e.Value,
- })
- if err != nil {
- return "", err
- }
- return buffer.String(), nil
+type Proto struct {
+ Src string
+ Name string
+ Package Package
+ PbPackage string
+ GoPackage string
+ Import []Import
+ Message []Message
+ Service Service
}
diff --git a/tools/goctl/rpc/parser/rpc.go b/tools/goctl/rpc/parser/rpc.go
new file mode 100644
index 00000000..bf4cd8ed
--- /dev/null
+++ b/tools/goctl/rpc/parser/rpc.go
@@ -0,0 +1,7 @@
+package parser
+
+import "github.com/emicklei/proto"
+
+type RPC struct {
+ *proto.RPC
+}
diff --git a/tools/goctl/rpc/parser/service.go b/tools/goctl/rpc/parser/service.go
new file mode 100644
index 00000000..9ace8dfb
--- /dev/null
+++ b/tools/goctl/rpc/parser/service.go
@@ -0,0 +1,8 @@
+package parser
+
+import "github.com/emicklei/proto"
+
+type Service struct {
+ *proto.Service
+ RPC []*RPC
+}
diff --git a/tools/goctl/rpc/parser/test.proto b/tools/goctl/rpc/parser/test.proto
new file mode 100644
index 00000000..45643864
--- /dev/null
+++ b/tools/goctl/rpc/parser/test.proto
@@ -0,0 +1,20 @@
+syntax = "proto3";
+
+package test;
+option go_package = "go";
+
+import "base.proto";
+
+message TestMessage{}
+message TestReq{}
+message TestReply{}
+
+enum TestEnum {
+ unknown = 0;
+ male = 1;
+ female = 2;
+}
+
+service TestService{
+ rpc TestRpcOne (TestReq)returns(TestReply);
+}
\ No newline at end of file
diff --git a/tools/goctl/rpc/parser/test_invalid_request.proto b/tools/goctl/rpc/parser/test_invalid_request.proto
new file mode 100644
index 00000000..cdb1e438
--- /dev/null
+++ b/tools/goctl/rpc/parser/test_invalid_request.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+package test;
+option go_package = "go";
+
+import "base.proto";
+
+message Reply{}
+
+
+service TestService{
+ rpc TestRpcTwo (base.Req)returns(Reply);
+}
\ No newline at end of file
diff --git a/tools/goctl/rpc/parser/test_invalid_response.proto b/tools/goctl/rpc/parser/test_invalid_response.proto
new file mode 100644
index 00000000..2adf5d1d
--- /dev/null
+++ b/tools/goctl/rpc/parser/test_invalid_response.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+package test;
+option go_package = "go";
+
+import "base.proto";
+
+message Req{}
+
+
+service TestService{
+ rpc TestRpcTwo (Req)returns(base.Reply);
+}
\ No newline at end of file
diff --git a/tools/goctl/rpc/parser/test_option.proto b/tools/goctl/rpc/parser/test_option.proto
new file mode 100644
index 00000000..9953136a
--- /dev/null
+++ b/tools/goctl/rpc/parser/test_option.proto
@@ -0,0 +1,10 @@
+syntax = "proto3";
+
+package stream;
+
+option go_package="github.com/tal-tech/go-zero";
+
+message placeholder{}
+service greet{
+ rpc hello(placeholder)returns(placeholder);
+}
\ No newline at end of file
diff --git a/tools/goctl/rpc/parser/test_option2.proto b/tools/goctl/rpc/parser/test_option2.proto
new file mode 100644
index 00000000..aa680cdc
--- /dev/null
+++ b/tools/goctl/rpc/parser/test_option2.proto
@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+package stream;
+
+
+message placeholder{}
+service greet{
+ rpc hello(placeholder)returns(placeholder);
+}
\ No newline at end of file
diff --git a/tools/goctl/rpc/test.proto b/tools/goctl/rpc/test.proto
deleted file mode 100644
index d98ec884..00000000
--- a/tools/goctl/rpc/test.proto
+++ /dev/null
@@ -1,28 +0,0 @@
-syntax = "proto3";
-// protoc -I=${GOPATH}/src -I=. test.proto --go_out=plugins=grpc,Mbase.proto=github.com/tal-tech/go-zero/tools/goctl/rpc:./test
-package test;
-// @github.com/tal-tech/go-zero/tools/goctl/rpc
-import "base.proto";
-import "google/protobuf/any.proto";
-
-message request {
- string name = 1;
-}
-enum Gender{
- UNKNOWN = 0;
- MALE = 1;
- FEMALE = 2;
-}
-message response {
- string greet = 1;
- google.protobuf.Any data = 2;
-
-}
-message map {
- map m = 1;
-}
-
-service Greeter {
- rpc greet(request) returns (response);
- rpc idRequest(base.IdRequest)returns(base.EmptyResponse);
-}
\ No newline at end of file
diff --git a/tools/goctl/rpc/test/test.pb.go b/tools/goctl/rpc/test/test.pb.go
deleted file mode 100644
index 298ec08d..00000000
--- a/tools/goctl/rpc/test/test.pb.go
+++ /dev/null
@@ -1,331 +0,0 @@
-// Code generated by protoc-gen-go. DO NOT EDIT.
-// source: test.proto
-
-// protoc -I=${GOPATH}/src -I=. test.proto --go_out=plugins=grpc,Mbase.proto=github.com/tal-tech/go-zero/tools/goctl/rpc:./test
-
-package test
-
-import (
- context "context"
- fmt "fmt"
- proto "github.com/golang/protobuf/proto"
- rpc "github.com/tal-tech/go-zero/tools/goctl/rpc"
- grpc "google.golang.org/grpc"
- codes "google.golang.org/grpc/codes"
- status "google.golang.org/grpc/status"
- anypb "google.golang.org/protobuf/types/known/anypb"
- math "math"
-)
-
-// Reference imports to suppress errors if they are not otherwise used.
-var _ = proto.Marshal
-var _ = fmt.Errorf
-var _ = math.Inf
-
-// This is a compile-time assertion to ensure that this generated file
-// is compatible with the proto package it is being compiled against.
-// A compilation error at this line likely means your copy of the
-// proto package needs to be updated.
-const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
-
-type Gender int32
-
-const (
- Gender_UNKNOWN Gender = 0
- Gender_MALE Gender = 1
- Gender_FEMALE Gender = 2
-)
-
-var Gender_name = map[int32]string{
- 0: "UNKNOWN",
- 1: "MALE",
- 2: "FEMALE",
-}
-
-var Gender_value = map[string]int32{
- "UNKNOWN": 0,
- "MALE": 1,
- "FEMALE": 2,
-}
-
-func (x Gender) String() string {
- return proto.EnumName(Gender_name, int32(x))
-}
-
-func (Gender) EnumDescriptor() ([]byte, []int) {
- return fileDescriptor_c161fcfdc0c3ff1e, []int{0}
-}
-
-type Request struct {
- Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
- XXX_NoUnkeyedLiteral struct{} `json:"-"`
- XXX_unrecognized []byte `json:"-"`
- XXX_sizecache int32 `json:"-"`
-}
-
-func (m *Request) Reset() { *m = Request{} }
-func (m *Request) String() string { return proto.CompactTextString(m) }
-func (*Request) ProtoMessage() {}
-func (*Request) Descriptor() ([]byte, []int) {
- return fileDescriptor_c161fcfdc0c3ff1e, []int{0}
-}
-
-func (m *Request) XXX_Unmarshal(b []byte) error {
- return xxx_messageInfo_Request.Unmarshal(m, b)
-}
-func (m *Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
- return xxx_messageInfo_Request.Marshal(b, m, deterministic)
-}
-func (m *Request) XXX_Merge(src proto.Message) {
- xxx_messageInfo_Request.Merge(m, src)
-}
-func (m *Request) XXX_Size() int {
- return xxx_messageInfo_Request.Size(m)
-}
-func (m *Request) XXX_DiscardUnknown() {
- xxx_messageInfo_Request.DiscardUnknown(m)
-}
-
-var xxx_messageInfo_Request proto.InternalMessageInfo
-
-func (m *Request) GetName() string {
- if m != nil {
- return m.Name
- }
- return ""
-}
-
-type Response struct {
- Greet string `protobuf:"bytes,1,opt,name=greet,proto3" json:"greet,omitempty"`
- Data *anypb.Any `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"`
- XXX_NoUnkeyedLiteral struct{} `json:"-"`
- XXX_unrecognized []byte `json:"-"`
- XXX_sizecache int32 `json:"-"`
-}
-
-func (m *Response) Reset() { *m = Response{} }
-func (m *Response) String() string { return proto.CompactTextString(m) }
-func (*Response) ProtoMessage() {}
-func (*Response) Descriptor() ([]byte, []int) {
- return fileDescriptor_c161fcfdc0c3ff1e, []int{1}
-}
-
-func (m *Response) XXX_Unmarshal(b []byte) error {
- return xxx_messageInfo_Response.Unmarshal(m, b)
-}
-func (m *Response) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
- return xxx_messageInfo_Response.Marshal(b, m, deterministic)
-}
-func (m *Response) XXX_Merge(src proto.Message) {
- xxx_messageInfo_Response.Merge(m, src)
-}
-func (m *Response) XXX_Size() int {
- return xxx_messageInfo_Response.Size(m)
-}
-func (m *Response) XXX_DiscardUnknown() {
- xxx_messageInfo_Response.DiscardUnknown(m)
-}
-
-var xxx_messageInfo_Response proto.InternalMessageInfo
-
-func (m *Response) GetGreet() string {
- if m != nil {
- return m.Greet
- }
- return ""
-}
-
-func (m *Response) GetData() *anypb.Any {
- if m != nil {
- return m.Data
- }
- return nil
-}
-
-type Map struct {
- M map[string]string `protobuf:"bytes,1,rep,name=m,proto3" json:"m,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
- XXX_NoUnkeyedLiteral struct{} `json:"-"`
- XXX_unrecognized []byte `json:"-"`
- XXX_sizecache int32 `json:"-"`
-}
-
-func (m *Map) Reset() { *m = Map{} }
-func (m *Map) String() string { return proto.CompactTextString(m) }
-func (*Map) ProtoMessage() {}
-func (*Map) Descriptor() ([]byte, []int) {
- return fileDescriptor_c161fcfdc0c3ff1e, []int{2}
-}
-
-func (m *Map) XXX_Unmarshal(b []byte) error {
- return xxx_messageInfo_Map.Unmarshal(m, b)
-}
-func (m *Map) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
- return xxx_messageInfo_Map.Marshal(b, m, deterministic)
-}
-func (m *Map) XXX_Merge(src proto.Message) {
- xxx_messageInfo_Map.Merge(m, src)
-}
-func (m *Map) XXX_Size() int {
- return xxx_messageInfo_Map.Size(m)
-}
-func (m *Map) XXX_DiscardUnknown() {
- xxx_messageInfo_Map.DiscardUnknown(m)
-}
-
-var xxx_messageInfo_Map proto.InternalMessageInfo
-
-func (m *Map) GetM() map[string]string {
- if m != nil {
- return m.M
- }
- return nil
-}
-
-func init() {
- proto.RegisterEnum("test.Gender", Gender_name, Gender_value)
- proto.RegisterType((*Request)(nil), "test.request")
- proto.RegisterType((*Response)(nil), "test.response")
- proto.RegisterType((*Map)(nil), "test.map")
- proto.RegisterMapType((map[string]string)(nil), "test.map.MEntry")
-}
-
-func init() { proto.RegisterFile("test.proto", fileDescriptor_c161fcfdc0c3ff1e) }
-
-var fileDescriptor_c161fcfdc0c3ff1e = []byte{
- // 301 bytes of a gzipped FileDescriptorProto
- 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x34, 0x90, 0x4f, 0x4b, 0xc3, 0x40,
- 0x10, 0xc5, 0xdd, 0x36, 0xa6, 0xed, 0x14, 0x35, 0x8c, 0x3d, 0xd4, 0x80, 0x52, 0x7a, 0x90, 0xa0,
- 0xb0, 0xc5, 0xea, 0x41, 0xbc, 0xf5, 0x10, 0x8b, 0x7f, 0x5a, 0x21, 0x20, 0x1e, 0x3c, 0x6d, 0xc9,
- 0x58, 0xc4, 0xee, 0x26, 0x6e, 0xb6, 0x42, 0xbe, 0xbd, 0x64, 0x77, 0x73, 0x7b, 0xbf, 0xd9, 0x59,
- 0xde, 0x9b, 0x07, 0x60, 0xa8, 0x32, 0xbc, 0xd4, 0x85, 0x29, 0x30, 0x68, 0x74, 0x0c, 0x1b, 0x51,
- 0x91, 0x9b, 0xc4, 0x67, 0xdb, 0xa2, 0xd8, 0xee, 0x68, 0x66, 0x69, 0xb3, 0xff, 0x9a, 0x09, 0x55,
- 0xbb, 0xa7, 0xe9, 0x39, 0xf4, 0x34, 0xfd, 0xee, 0xa9, 0x32, 0x88, 0x10, 0x28, 0x21, 0x69, 0xcc,
- 0x26, 0x2c, 0x19, 0x64, 0x56, 0x4f, 0x9f, 0xa1, 0xaf, 0xa9, 0x2a, 0x0b, 0x55, 0x11, 0x8e, 0xe0,
- 0x70, 0xab, 0x89, 0x8c, 0x5f, 0x70, 0x80, 0x09, 0x04, 0xb9, 0x30, 0x62, 0xdc, 0x99, 0xb0, 0x64,
- 0x38, 0x1f, 0x71, 0x67, 0xc5, 0x5b, 0x2b, 0xbe, 0x50, 0x75, 0x66, 0x37, 0xa6, 0x9f, 0xd0, 0x95,
- 0xa2, 0xc4, 0x0b, 0x60, 0x72, 0xcc, 0x26, 0xdd, 0x64, 0x38, 0x8f, 0xb8, 0x8d, 0x2d, 0x45, 0xc9,
- 0x57, 0xa9, 0x32, 0xba, 0xce, 0x98, 0x8c, 0xef, 0x20, 0x74, 0x80, 0x11, 0x74, 0x7f, 0xa8, 0xf6,
- 0x76, 0x8d, 0x6c, 0x22, 0xfc, 0x89, 0xdd, 0x9e, 0xac, 0xdb, 0x20, 0x73, 0xf0, 0xd0, 0xb9, 0x67,
- 0x57, 0xd7, 0x10, 0x2e, 0x49, 0xe5, 0xa4, 0x71, 0x08, 0xbd, 0xf7, 0xf5, 0xcb, 0xfa, 0xed, 0x63,
- 0x1d, 0x1d, 0x60, 0x1f, 0x82, 0xd5, 0xe2, 0x35, 0x8d, 0x18, 0x02, 0x84, 0x8f, 0xa9, 0xd5, 0x9d,
- 0x79, 0x0e, 0xbd, 0x65, 0x13, 0x9e, 0x34, 0x5e, 0xfa, 0xa3, 0xf0, 0xc8, 0x65, 0xf1, 0x65, 0xc4,
- 0xc7, 0x2d, 0xfa, 0xe3, 0x6f, 0x60, 0xf0, 0x9d, 0x67, 0xbe, 0xa9, 0x13, 0x6e, 0xcb, 0x7d, 0x6a,
- 0x07, 0xf1, 0xa9, 0x1b, 0xa4, 0xb2, 0x34, 0x75, 0xe6, 0xbf, 0x6c, 0x42, 0xdb, 0xc1, 0xed, 0x7f,
- 0x00, 0x00, 0x00, 0xff, 0xff, 0x48, 0x7a, 0x3b, 0x55, 0x9c, 0x01, 0x00, 0x00,
-}
-
-// Reference imports to suppress errors if they are not otherwise used.
-var _ context.Context
-var _ grpc.ClientConn
-
-// This is a compile-time assertion to ensure that this generated file
-// is compatible with the grpc package it is being compiled against.
-const _ = grpc.SupportPackageIsVersion4
-
-// GreeterClient is the client API for Greeter service.
-//
-// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
-type GreeterClient interface {
- Greet(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error)
- IdRequest(ctx context.Context, in *rpc.IdRequest, opts ...grpc.CallOption) (*rpc.EmptyResponse, error)
-}
-
-type greeterClient struct {
- cc *grpc.ClientConn
-}
-
-func NewGreeterClient(cc *grpc.ClientConn) GreeterClient {
- return &greeterClient{cc}
-}
-
-func (c *greeterClient) Greet(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) {
- out := new(Response)
- err := c.cc.Invoke(ctx, "/test.Greeter/greet", in, out, opts...)
- if err != nil {
- return nil, err
- }
- return out, nil
-}
-
-func (c *greeterClient) IdRequest(ctx context.Context, in *rpc.IdRequest, opts ...grpc.CallOption) (*rpc.EmptyResponse, error) {
- out := new(rpc.EmptyResponse)
- err := c.cc.Invoke(ctx, "/test.Greeter/idRequest", in, out, opts...)
- if err != nil {
- return nil, err
- }
- return out, nil
-}
-
-// GreeterServer is the server API for Greeter service.
-type GreeterServer interface {
- Greet(context.Context, *Request) (*Response, error)
- IdRequest(context.Context, *rpc.IdRequest) (*rpc.EmptyResponse, error)
-}
-
-// UnimplementedGreeterServer can be embedded to have forward compatible implementations.
-type UnimplementedGreeterServer struct {
-}
-
-func (*UnimplementedGreeterServer) Greet(ctx context.Context, req *Request) (*Response, error) {
- return nil, status.Errorf(codes.Unimplemented, "method Greet not implemented")
-}
-func (*UnimplementedGreeterServer) IdRequest(ctx context.Context, req *rpc.IdRequest) (*rpc.EmptyResponse, error) {
- return nil, status.Errorf(codes.Unimplemented, "method IdRequest not implemented")
-}
-
-func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) {
- s.RegisterService(&_Greeter_serviceDesc, srv)
-}
-
-func _Greeter_Greet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
- in := new(Request)
- if err := dec(in); err != nil {
- return nil, err
- }
- if interceptor == nil {
- return srv.(GreeterServer).Greet(ctx, in)
- }
- info := &grpc.UnaryServerInfo{
- Server: srv,
- FullMethod: "/test.Greeter/Greet",
- }
- handler := func(ctx context.Context, req interface{}) (interface{}, error) {
- return srv.(GreeterServer).Greet(ctx, req.(*Request))
- }
- return interceptor(ctx, in, info, handler)
-}
-
-func _Greeter_IdRequest_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
- in := new(rpc.IdRequest)
- if err := dec(in); err != nil {
- return nil, err
- }
- if interceptor == nil {
- return srv.(GreeterServer).IdRequest(ctx, in)
- }
- info := &grpc.UnaryServerInfo{
- Server: srv,
- FullMethod: "/test.Greeter/IdRequest",
- }
- handler := func(ctx context.Context, req interface{}) (interface{}, error) {
- return srv.(GreeterServer).IdRequest(ctx, req.(*rpc.IdRequest))
- }
- return interceptor(ctx, in, info, handler)
-}
-
-var _Greeter_serviceDesc = grpc.ServiceDesc{
- ServiceName: "test.Greeter",
- HandlerType: (*GreeterServer)(nil),
- Methods: []grpc.MethodDesc{
- {
- MethodName: "greet",
- Handler: _Greeter_Greet_Handler,
- },
- {
- MethodName: "idRequest",
- Handler: _Greeter_IdRequest_Handler,
- },
- },
- Streams: []grpc.StreamDesc{},
- Metadata: "test.proto",
-}
diff --git a/tools/goctl/tpl/templates.go b/tools/goctl/tpl/templates.go
index e10d54da..492d85e9 100644
--- a/tools/goctl/tpl/templates.go
+++ b/tools/goctl/tpl/templates.go
@@ -7,7 +7,7 @@ import (
"github.com/tal-tech/go-zero/core/errorx"
"github.com/tal-tech/go-zero/tools/goctl/api/gogen"
modelgen "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
- rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/gen"
+ rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli"
)
diff --git a/tools/goctl/util/ctx/context.go b/tools/goctl/util/ctx/context.go
new file mode 100644
index 00000000..c41b1cf8
--- /dev/null
+++ b/tools/goctl/util/ctx/context.go
@@ -0,0 +1,53 @@
+package ctx
+
+import (
+ "errors"
+ "path/filepath"
+
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
+)
+
+var moduleCheckErr = errors.New("the work directory must be found in the go mod or the $GOPATH")
+
+type (
+ ProjectContext struct {
+ WorkDir string
+ // Name is the root name of the project
+ // eg: go-zero、greet
+ Name string
+ // Path identifies which module a project belongs to, which is module value if it's a go mod project,
+ // or else it is the root name of the project, eg: github.com/tal-tech/go-zero、greet
+ Path string
+ // Dir is the path of the project, eg: /Users/keson/goland/go/go-zero、/Users/keson/go/src/greet
+ Dir string
+ }
+)
+
+// Prepare checks the project which module belongs to,and returns the path and module.
+// workDir parameter is the directory of the source of generating code,
+// where can be found the project path and the project module,
+func Prepare(workDir string) (*ProjectContext, error) {
+ ctx, err := background(workDir)
+ if err == nil {
+ return ctx, nil
+ }
+
+ name := filepath.Base(workDir)
+ _, err = execx.Run("go mod init "+name, workDir)
+ if err != nil {
+ return nil, err
+ }
+ return background(workDir)
+}
+
+func background(workDir string) (*ProjectContext, error) {
+ isGoMod, err := IsGoMod(workDir)
+ if err != nil {
+ return nil, err
+ }
+
+ if isGoMod {
+ return projectFromGoMod(workDir)
+ }
+ return projectFromGoPath(workDir)
+}
diff --git a/tools/goctl/util/ctx/context_test.go b/tools/goctl/util/ctx/context_test.go
new file mode 100644
index 00000000..abc53af4
--- /dev/null
+++ b/tools/goctl/util/ctx/context_test.go
@@ -0,0 +1,22 @@
+package ctx
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestBackground(t *testing.T) {
+ workDir := "."
+ ctx, err := Prepare(workDir)
+ assert.Nil(t, err)
+ assert.True(t, true, func() bool {
+ return len(ctx.Dir) != 0 && len(ctx.Path) != 0
+ }())
+}
+
+func TestBackgroundNilWorkDir(t *testing.T) {
+ workDir := ""
+ _, err := Prepare(workDir)
+ assert.NotNil(t, err)
+}
diff --git a/tools/goctl/util/ctx/gomod.go b/tools/goctl/util/ctx/gomod.go
new file mode 100644
index 00000000..c073e449
--- /dev/null
+++ b/tools/goctl/util/ctx/gomod.go
@@ -0,0 +1,47 @@
+package ctx
+
+import (
+ "errors"
+ "os"
+ "path/filepath"
+
+ "github.com/tal-tech/go-zero/core/jsonx"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
+)
+
+type Module struct {
+ Path string
+ Main bool
+ Dir string
+ GoMod string
+ GoVersion string
+}
+
+// projectFromGoMod is used to find the go module and project file path
+// the workDir flag specifies which folder we need to detect based on
+// only valid for go mod project
+func projectFromGoMod(workDir string) (*ProjectContext, error) {
+ if len(workDir) == 0 {
+ return nil, errors.New("the work directory is not found")
+ }
+ if _, err := os.Stat(workDir); err != nil {
+ return nil, err
+ }
+
+ data, err := execx.Run("go list -json -m", workDir)
+ if err != nil {
+ return nil, err
+ }
+
+ var m Module
+ err = jsonx.Unmarshal([]byte(data), &m)
+ if err != nil {
+ return nil, err
+ }
+ var ret ProjectContext
+ ret.WorkDir = workDir
+ ret.Name = filepath.Base(m.Dir)
+ ret.Dir = m.Dir
+ ret.Path = m.Path
+ return &ret, nil
+}
diff --git a/tools/goctl/util/ctx/gomod_test.go b/tools/goctl/util/ctx/gomod_test.go
new file mode 100644
index 00000000..bd86642b
--- /dev/null
+++ b/tools/goctl/util/ctx/gomod_test.go
@@ -0,0 +1,38 @@
+package ctx
+
+import (
+ "go/build"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/core/stringx"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+)
+
+func TestProjectFromGoMod(t *testing.T) {
+ dft := build.Default
+ gp := dft.GOPATH
+ if len(gp) == 0 {
+ return
+ }
+ projectName := stringx.Rand()
+ dir := filepath.Join(gp, "src", projectName)
+ err := util.MkdirIfNotExist(dir)
+ if err != nil {
+ return
+ }
+
+ _, err = execx.Run("go mod init "+projectName, dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(dir)
+ }()
+
+ ctx, err := projectFromGoMod(dir)
+ assert.Nil(t, err)
+ assert.Equal(t, projectName, ctx.Path)
+ assert.Equal(t, dir, ctx.Dir)
+}
diff --git a/tools/goctl/util/ctx/gopath.go b/tools/goctl/util/ctx/gopath.go
new file mode 100644
index 00000000..ec44ef2b
--- /dev/null
+++ b/tools/goctl/util/ctx/gopath.go
@@ -0,0 +1,47 @@
+package ctx
+
+import (
+ "errors"
+ "go/build"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+)
+
+// projectFromGoPath is used to find the main module and project file path
+// the workDir flag specifies which folder we need to detect based on
+// only valid for go mod project
+func projectFromGoPath(workDir string) (*ProjectContext, error) {
+ if len(workDir) == 0 {
+ return nil, errors.New("the work directory is not found")
+ }
+ if _, err := os.Stat(workDir); err != nil {
+ return nil, err
+ }
+
+ buildContext := build.Default
+ goPath := buildContext.GOPATH
+ goSrc := filepath.Join(goPath, "src")
+ if !util.FileExists(goSrc) {
+ return nil, moduleCheckErr
+ }
+
+ wd, err := filepath.Abs(workDir)
+ if err != nil {
+ return nil, err
+ }
+
+ if !strings.HasPrefix(wd, goSrc) {
+ return nil, moduleCheckErr
+ }
+
+ projectName := strings.TrimPrefix(wd, goSrc+string(filepath.Separator))
+ return &ProjectContext{
+ WorkDir: workDir,
+ Name: projectName,
+ Path: projectName,
+ Dir: filepath.Join(goSrc, projectName),
+ }, nil
+}
diff --git a/tools/goctl/util/ctx/gopath_test.go b/tools/goctl/util/ctx/gopath_test.go
new file mode 100644
index 00000000..1f0f0780
--- /dev/null
+++ b/tools/goctl/util/ctx/gopath_test.go
@@ -0,0 +1,54 @@
+package ctx
+
+import (
+ "go/build"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/core/stringx"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+)
+
+func TestProjectFromGoPath(t *testing.T) {
+ dft := build.Default
+ gp := dft.GOPATH
+ if len(gp) == 0 {
+ return
+ }
+ projectName := stringx.Rand()
+ dir := filepath.Join(gp, "src", projectName)
+ err := util.MkdirIfNotExist(dir)
+ if err != nil {
+ return
+ }
+ defer func() {
+ _ = os.RemoveAll(dir)
+ }()
+
+ ctx, err := projectFromGoPath(dir)
+ assert.Nil(t, err)
+ assert.Equal(t, dir, ctx.Dir)
+ assert.Equal(t, projectName, ctx.Path)
+}
+
+func TestProjectFromGoPathNotInGoSrc(t *testing.T) {
+ dft := build.Default
+ gp := dft.GOPATH
+ if len(gp) == 0 {
+ return
+ }
+ projectName := stringx.Rand()
+ dir := filepath.Join(gp, "src", projectName)
+ err := util.MkdirIfNotExist(dir)
+ if err != nil {
+ return
+ }
+ defer func() {
+ _ = os.RemoveAll(dir)
+ }()
+
+ _, err = projectFromGoPath("testPath")
+ assert.NotNil(t, err)
+}
diff --git a/tools/goctl/util/ctx/modcheck.go b/tools/goctl/util/ctx/modcheck.go
new file mode 100644
index 00000000..0db491e4
--- /dev/null
+++ b/tools/goctl/util/ctx/modcheck.go
@@ -0,0 +1,32 @@
+package ctx
+
+import (
+ "errors"
+ "os"
+
+ "github.com/tal-tech/go-zero/core/jsonx"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
+)
+
+// IsGoMod is used to determine whether workDir is a go module project through command `go list -json -m`
+func IsGoMod(workDir string) (bool, error) {
+ if len(workDir) == 0 {
+ return false, errors.New("the work directory is not found")
+ }
+ if _, err := os.Stat(workDir); err != nil {
+ return false, err
+ }
+
+ data, err := execx.Run("go list -json -m", workDir)
+ if err != nil {
+ return false, nil
+ }
+
+ var m Module
+ err = jsonx.Unmarshal([]byte(data), &m)
+ if err != nil {
+ return false, err
+ }
+
+ return len(m.GoMod) > 0, nil
+}
diff --git a/tools/goctl/util/ctx/modcheck_test.go b/tools/goctl/util/ctx/modcheck_test.go
new file mode 100644
index 00000000..b651af99
--- /dev/null
+++ b/tools/goctl/util/ctx/modcheck_test.go
@@ -0,0 +1,67 @@
+package ctx
+
+import (
+ "go/build"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tal-tech/go-zero/core/stringx"
+ "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
+ "github.com/tal-tech/go-zero/tools/goctl/util"
+)
+
+func TestIsGoMod(t *testing.T) {
+ // create mod project
+ dft := build.Default
+ gp := dft.GOPATH
+ if len(gp) == 0 {
+ return
+ }
+ projectName := stringx.Rand()
+ dir := filepath.Join(gp, "src", projectName)
+ err := util.MkdirIfNotExist(dir)
+ if err != nil {
+ return
+ }
+
+ _, err = execx.Run("go mod init "+projectName, dir)
+ assert.Nil(t, err)
+ defer func() {
+ _ = os.RemoveAll(dir)
+ }()
+
+ isGoMod, err := IsGoMod(dir)
+ assert.Nil(t, err)
+ assert.Equal(t, true, isGoMod)
+}
+
+func TestIsGoModNot(t *testing.T) {
+ dft := build.Default
+ gp := dft.GOPATH
+ if len(gp) == 0 {
+ return
+ }
+ projectName := stringx.Rand()
+ dir := filepath.Join(gp, "src", projectName)
+ err := util.MkdirIfNotExist(dir)
+ if err != nil {
+ return
+ }
+
+ defer func() {
+ _ = os.RemoveAll(dir)
+ }()
+
+ isGoMod, err := IsGoMod(dir)
+ assert.Nil(t, err)
+ assert.Equal(t, false, isGoMod)
+}
+
+func TestIsGoModWorkDirIsNil(t *testing.T) {
+ _, err := IsGoMod("")
+ assert.Equal(t, err.Error(), func() string {
+ return "the work directory is not found"
+ }())
+}
diff --git a/tools/goctl/util/project/project.go b/tools/goctl/util/project/project.go
deleted file mode 100644
index 0ebb61b8..00000000
--- a/tools/goctl/util/project/project.go
+++ /dev/null
@@ -1,141 +0,0 @@
-package project
-
-import (
- "errors"
- "fmt"
- "io/ioutil"
- "os"
- "os/exec"
- "path/filepath"
- "regexp"
- "strings"
-
- "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
-)
-
-const (
- constGo = "go"
- constProtoC = "protoc"
- constGoMod = "go env GOMOD"
- constGoPath = "go env GOPATH"
- constProtoCGenGo = "protoc-gen-go"
-)
-
-type (
- Project struct {
- Path string // Project path name
- Name string // Project name
- Package string // The service related package
- // true-> project in go path or project init with go mod,or else->false
- IsInGoEnv bool
- GoMod GoMod
- }
-
- GoMod struct {
- Module string // The gomod module name
- Path string // The gomod related path
- }
-)
-
-func Prepare(projectDir string, checkGrpcEnv bool) (*Project, error) {
- _, err := exec.LookPath(constGo)
- if err != nil {
- return nil, fmt.Errorf("please install go first,reference documents:「https://golang.org/doc/install」")
- }
-
- if checkGrpcEnv {
- _, err = exec.LookPath(constProtoC)
- if err != nil {
- return nil, fmt.Errorf("please install protoc first,reference documents:「https://github.com/golang/protobuf」")
- }
-
- _, err = exec.LookPath(constProtoCGenGo)
- if err != nil {
- return nil, fmt.Errorf("please install plugin protoc-gen-go first,reference documents:「https://github.com/golang/protobuf」")
- }
- }
-
- var (
- goMod, module string
- goPath string
- name, path string
- pkg string
- )
-
- ret, err := execx.Run(constGoMod, projectDir)
- if err != nil {
- return nil, err
- }
-
- goMod = strings.TrimSpace(ret)
- if goMod == os.DevNull {
- goMod = ""
- }
-
- ret, err = execx.Run(constGoPath, "")
- if err != nil {
- return nil, err
- }
-
- goPath = strings.TrimSpace(ret)
- src := filepath.Join(goPath, "src")
- var isInGoEnv = true
- if len(goMod) > 0 {
- path = filepath.Dir(goMod)
- name = filepath.Base(path)
- data, err := ioutil.ReadFile(goMod)
- if err != nil {
- return nil, err
- }
-
- module, err = matchModule(data)
- if err != nil {
- return nil, err
- }
- } else {
- pwd, err := filepath.Abs(projectDir)
- if err != nil {
- return nil, err
- }
-
- if !strings.HasPrefix(pwd, src) {
- name = filepath.Clean(filepath.Base(pwd))
- path = projectDir
- pkg = name
- isInGoEnv = false
- } else {
- r := strings.TrimPrefix(pwd, src+string(filepath.Separator))
- name = filepath.Dir(r)
- if name == "." {
- name = r
- }
- path = filepath.Join(src, name)
- pkg = r
- }
- module = name
- }
-
- return &Project{
- Name: name,
- Path: path,
- Package: strings.ReplaceAll(pkg, `\`, "/"),
- IsInGoEnv: isInGoEnv,
- GoMod: GoMod{
- Module: module,
- Path: goMod,
- },
- }, nil
-}
-
-func matchModule(data []byte) (string, error) {
- text := string(data)
- re := regexp.MustCompile(`(?m)^\s*module\s+[a-z0-9_/\-.]+$`)
- matches := re.FindAllString(text, -1)
- if len(matches) == 1 {
- target := matches[0]
- index := strings.Index(target, "module")
- return strings.TrimSpace(target[index+6:]), nil
- }
-
- return "", errors.New("module not matched")
-}