reactor rpc (#179)

* reactor rpc generation

* update flag

* update command

* update command

* update unit test

* delete test file

* optimize code

* update doc

* update gen pb

* rename target dir

* update mysql data type convert rule

* add done flag

* optimize req/reply parameter

* optimize req/reply parameter

* remove waste code

* remove duplicate parameter

* format code

* format code

* optimize naming

* reactor rpcv2 to rpc

* remove new line

* format code

* rename underline to snake

* reactor getParentPackage

* remove debug log

* reactor background
This commit is contained in:
Keson 2020-11-05 14:12:47 +08:00 committed by GitHub
parent c9ec22d5f4
commit e76f44a35b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
95 changed files with 2708 additions and 3301 deletions

View File

@ -10,29 +10,20 @@ import (
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util"
goctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/project"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func getParentPackage(dir string) (string, error) {
p, err := project.Prepare(dir, false)
abs, err := filepath.Abs(dir)
if err != nil {
return "", err
}
if len(p.GoMod.Path) > 0 {
goModePath := filepath.Clean(filepath.Dir(p.GoMod.Path))
absPath, err := filepath.Abs(dir)
if err != nil {
return "", err
}
parent := filepath.Clean(goctlutil.JoinPackages(p.GoMod.Module, absPath[len(goModePath):]))
parent = strings.ReplaceAll(parent, "\\", "/")
parent = strings.ReplaceAll(parent, `\`, "/")
return parent, nil
projectCtx, err := ctx.Prepare(abs)
if err != nil {
return "", err
}
return p.Package, nil
return filepath.ToSlash(filepath.Join(projectCtx.Path, strings.TrimPrefix(projectCtx.WorkDir, projectCtx.Dir))), nil
}
func writeIndent(writer io.Writer, indent int) {

View File

@ -19,7 +19,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/configgen"
"github.com/tal-tech/go-zero/tools/goctl/docker"
model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command"
rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/command"
rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/cli"
"github.com/tal-tech/go-zero/tools/goctl/tpl"
"github.com/urfave/cli"
)
@ -211,7 +211,7 @@ var (
Flags: []cli.Flag{
cli.BoolFlag{
Name: "idea",
Usage: "whether the command execution environment is from idea plugin. [option]",
Usage: "whether the command execution environment is from idea plugin. [optional]",
},
},
Action: rpc.RpcNew,
@ -226,7 +226,7 @@ var (
},
cli.BoolFlag{
Name: "idea",
Usage: "whether the command execution environment is from idea plugin. [option]",
Usage: "whether the command execution environment is from idea plugin. [optional]",
},
},
Action: rpc.RpcTemplate,
@ -239,17 +239,17 @@ var (
Name: "src, s",
Usage: "the file path of the proto source file",
},
cli.StringFlag{
Name: "dir, d",
Usage: `the target path of the code,default path is "${pwd}". [option]`,
cli.StringSliceFlag{
Name: "proto_path, I",
Usage: `native command of protoc, specify the directory in which to search for imports. [optional]`,
},
cli.StringFlag{
Name: "service, srv",
Usage: `the name of rpc service. [option]`,
Name: "dir, d",
Usage: `the target path of the code`,
},
cli.BoolFlag{
Name: "idea",
Usage: "whether the command execution environment is from idea plugin. [option]",
Usage: "whether the command execution environment is from idea plugin. [optional]",
},
},
Action: rpc.Rpc,
@ -313,7 +313,7 @@ var (
},
cli.StringFlag{
Name: "style",
Usage: "the file naming style, lower|camel|underline,default is lower",
Usage: "the file naming style, lower|camel|snake, default is lower",
},
cli.BoolFlag{
Name: "idea",

View File

@ -40,7 +40,7 @@ func MysqlDDL(ctx *cli.Context) error {
}
switch namingStyle {
case gen.NamingLower, gen.NamingCamel, gen.NamingUnderline:
case gen.NamingLower, gen.NamingCamel, gen.NamingSnake:
case "":
namingStyle = gen.NamingLower
default:
@ -87,7 +87,7 @@ func MyDataSource(ctx *cli.Context) error {
}
switch namingStyle {
case gen.NamingLower, gen.NamingCamel, gen.NamingUnderline:
case gen.NamingLower, gen.NamingCamel, gen.NamingSnake:
case "":
namingStyle = gen.NamingLower
default:

View File

@ -8,30 +8,36 @@ import (
var (
commonMysqlDataTypeMap = map[string]string{
// For consistency, all integer types are converted to int64
"tinyint": "int64",
"smallint": "int64",
"mediumint": "int64",
"int": "int64",
"integer": "int64",
"bigint": "int64",
"float": "float64",
"double": "float64",
"decimal": "float64",
"date": "time.Time",
"time": "string",
"year": "int64",
"datetime": "time.Time",
"timestamp": "time.Time",
// number
"bool": "int64",
"boolean": "int64",
"tinyint": "int64",
"smallint": "int64",
"mediumint": "int64",
"int": "int64",
"integer": "int64",
"bigint": "int64",
"float": "float64",
"double": "float64",
"decimal": "float64",
// date&time
"date": "time.Time",
"datetime": "time.Time",
"timestamp": "time.Time",
"time": "string",
"year": "int64",
// string
"char": "string",
"varchar": "string",
"tinyblob": "string",
"binary": "string",
"varbinary": "string",
"tinytext": "string",
"blob": "string",
"text": "string",
"mediumblob": "string",
"mediumtext": "string",
"longblob": "string",
"longtext": "string",
"enum": "string",
"set": "string",
"json": "string",
}
)

View File

@ -19,7 +19,7 @@ const (
createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case
NamingLower = "lower"
NamingCamel = "camel"
NamingUnderline = "underline"
NamingSnake = "snake"
)
type (
@ -81,7 +81,7 @@ func (g *defaultGenerator) Start(withCache bool) error {
switch g.namingStyle {
case NamingCamel:
name = fmt.Sprintf("%sModel.go", tn.ToCamel())
case NamingUnderline:
case NamingSnake:
name = fmt.Sprintf("%s_model.go", tn.ToSnake())
}
filename := filepath.Join(dirAbs, name)

View File

@ -1,6 +1,8 @@
package gen
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
@ -8,27 +10,55 @@ import (
)
var (
source = "CREATE TABLE `test_user_info` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `nanosecond` bigint NOT NULL DEFAULT '0',\n `data` varchar(255) DEFAULT '',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `nanosecond_unique` (`nanosecond`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;"
source = "CREATE TABLE `test_user_info` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `nanosecond` bigint NOT NULL DEFAULT '0',\n `data` varchar(255) DEFAULT '',\n `content` json DEFAULT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `nanosecond_unique` (`nanosecond`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;"
)
func TestCacheModel(t *testing.T) {
logx.Disable()
_ = Clean()
g := NewDefaultGenerator(source, "./testmodel/cache", NamingLower)
dir, _ := filepath.Abs("./testmodel")
cacheDir := filepath.Join(dir, "cache")
noCacheDir := filepath.Join(dir, "nocache")
defer func() {
_ = os.RemoveAll(dir)
}()
g := NewDefaultGenerator(source, cacheDir, NamingLower)
err := g.Start(true)
assert.Nil(t, err)
g = NewDefaultGenerator(source, "./testmodel/nocache", NamingLower)
assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(cacheDir, "testuserinfomodel.go"))
return err == nil
}())
g = NewDefaultGenerator(source, noCacheDir, NamingLower)
err = g.Start(false)
assert.Nil(t, err)
assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(noCacheDir, "testuserinfomodel.go"))
return err == nil
}())
}
func TestNamingModel(t *testing.T) {
logx.Disable()
_ = Clean()
g := NewDefaultGenerator(source, "./testmodel/camel", NamingCamel)
dir, _ := filepath.Abs("./testmodel")
camelDir := filepath.Join(dir, "camel")
snakeDir := filepath.Join(dir, "snake")
defer func() {
_ = os.RemoveAll(dir)
}()
g := NewDefaultGenerator(source, camelDir, NamingCamel)
err := g.Start(true)
assert.Nil(t, err)
g = NewDefaultGenerator(source, "./testmodel/snake", NamingUnderline)
assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go"))
return err == nil
}())
g = NewDefaultGenerator(source, snakeDir, NamingSnake)
err = g.Start(true)
assert.Nil(t, err)
assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(snakeDir, "test_user_info_model.go"))
return err == nil
}())
}

View File

@ -1,125 +0,0 @@
package cache
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/sqlc"
"github.com/tal-tech/go-zero/core/stores/sqlx"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
)
var (
testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{})
testUserInfoRows = strings.Join(testUserInfoFieldNames, ",")
testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",")
testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
cacheTestUserInfoIdPrefix = "cache#TestUserInfo#id#"
cacheTestUserInfoNanosecondPrefix = "cache#TestUserInfo#nanosecond#"
)
type (
TestUserInfoModel struct {
sqlc.CachedConn
table string
}
TestUserInfo struct {
Id int64 `db:"id"`
Nanosecond int64 `db:"nanosecond"`
Data string `db:"data"`
CreateTime time.Time `db:"create_time"`
UpdateTime time.Time `db:"update_time"`
}
)
func NewTestUserInfoModel(conn sqlx.SqlConn, c cache.CacheConf) *TestUserInfoModel {
return &TestUserInfoModel{
CachedConn: sqlc.NewConn(conn, c),
table: "test_user_info",
}
}
func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) {
testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet)
return conn.Exec(query, data.Nanosecond, data.Data)
}, testUserInfoNanosecondKey)
return ret, err
}
func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) {
testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
var resp TestUserInfo
err := m.QueryRow(&resp, testUserInfoIdKey, func(conn sqlx.SqlConn, v interface{}) error {
query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
return conn.QueryRow(v, query, id)
})
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) {
testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, nanosecond)
var resp TestUserInfo
err := m.QueryRowIndex(&resp, testUserInfoNanosecondKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table)
if err := conn.QueryRow(&resp, query, nanosecond); err != nil {
return nil, err
}
return resp.Id, nil
}, m.queryPrimary)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *TestUserInfoModel) Update(data TestUserInfo) error {
testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, data.Id)
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder)
return conn.Exec(query, data.Nanosecond, data.Data, data.Id)
}, testUserInfoIdKey)
return err
}
func (m *TestUserInfoModel) Delete(id int64) error {
data, err := m.FindOne(id)
if err != nil {
return err
}
testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
_, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("delete from %s where id = ?", m.table)
return conn.Exec(query, id)
}, testUserInfoNanosecondKey, testUserInfoIdKey)
return err
}
func (m *TestUserInfoModel) formatPrimary(primary interface{}) string {
return fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, primary)
}
func (m *TestUserInfoModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
return conn.QueryRow(v, query, primary)
}

View File

@ -1,5 +0,0 @@
package cache
import "github.com/tal-tech/go-zero/core/stores/sqlx"
var ErrNotFound = sqlx.ErrNotFound

View File

@ -1,125 +0,0 @@
package camel
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/sqlc"
"github.com/tal-tech/go-zero/core/stores/sqlx"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
)
var (
testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{})
testUserInfoRows = strings.Join(testUserInfoFieldNames, ",")
testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",")
testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
cacheTestUserInfoIdPrefix = "cache#TestUserInfo#id#"
cacheTestUserInfoNanosecondPrefix = "cache#TestUserInfo#nanosecond#"
)
type (
TestUserInfoModel struct {
sqlc.CachedConn
table string
}
TestUserInfo struct {
Id int64 `db:"id"`
Nanosecond int64 `db:"nanosecond"`
Data string `db:"data"`
CreateTime time.Time `db:"create_time"`
UpdateTime time.Time `db:"update_time"`
}
)
func NewTestUserInfoModel(conn sqlx.SqlConn, c cache.CacheConf) *TestUserInfoModel {
return &TestUserInfoModel{
CachedConn: sqlc.NewConn(conn, c),
table: "test_user_info",
}
}
func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) {
testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet)
return conn.Exec(query, data.Nanosecond, data.Data)
}, testUserInfoNanosecondKey)
return ret, err
}
func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) {
testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
var resp TestUserInfo
err := m.QueryRow(&resp, testUserInfoIdKey, func(conn sqlx.SqlConn, v interface{}) error {
query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
return conn.QueryRow(v, query, id)
})
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) {
testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, nanosecond)
var resp TestUserInfo
err := m.QueryRowIndex(&resp, testUserInfoNanosecondKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table)
if err := conn.QueryRow(&resp, query, nanosecond); err != nil {
return nil, err
}
return resp.Id, nil
}, m.queryPrimary)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *TestUserInfoModel) Update(data TestUserInfo) error {
testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, data.Id)
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder)
return conn.Exec(query, data.Nanosecond, data.Data, data.Id)
}, testUserInfoIdKey)
return err
}
func (m *TestUserInfoModel) Delete(id int64) error {
data, err := m.FindOne(id)
if err != nil {
return err
}
testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
_, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("delete from %s where id = ?", m.table)
return conn.Exec(query, id)
}, testUserInfoIdKey, testUserInfoNanosecondKey)
return err
}
func (m *TestUserInfoModel) formatPrimary(primary interface{}) string {
return fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, primary)
}
func (m *TestUserInfoModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
return conn.QueryRow(v, query, primary)
}

View File

@ -1,5 +0,0 @@
package camel
import "github.com/tal-tech/go-zero/core/stores/sqlx"
var ErrNotFound = sqlx.ErrNotFound

View File

@ -1,88 +0,0 @@
package nocache
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/tal-tech/go-zero/core/stores/sqlc"
"github.com/tal-tech/go-zero/core/stores/sqlx"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
)
var (
testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{})
testUserInfoRows = strings.Join(testUserInfoFieldNames, ",")
testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",")
testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
)
type (
TestUserInfoModel struct {
conn sqlx.SqlConn
table string
}
TestUserInfo struct {
Id int64 `db:"id"`
Nanosecond int64 `db:"nanosecond"`
Data string `db:"data"`
CreateTime time.Time `db:"create_time"`
UpdateTime time.Time `db:"update_time"`
}
)
func NewTestUserInfoModel(conn sqlx.SqlConn) *TestUserInfoModel {
return &TestUserInfoModel{
conn: conn,
table: "test_user_info",
}
}
func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) {
query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet)
ret, err := m.conn.Exec(query, data.Nanosecond, data.Data)
return ret, err
}
func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) {
query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
var resp TestUserInfo
err := m.conn.QueryRow(&resp, query, id)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) {
var resp TestUserInfo
query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table)
err := m.conn.QueryRow(&resp, query, nanosecond)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *TestUserInfoModel) Update(data TestUserInfo) error {
query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder)
_, err := m.conn.Exec(query, data.Nanosecond, data.Data, data.Id)
return err
}
func (m *TestUserInfoModel) Delete(id int64) error {
query := fmt.Sprintf("delete from %s where id = ?", m.table)
_, err := m.conn.Exec(query, id)
return err
}

View File

@ -1,5 +0,0 @@
package nocache
import "github.com/tal-tech/go-zero/core/stores/sqlx"
var ErrNotFound = sqlx.ErrNotFound

View File

@ -1,125 +0,0 @@
package snake
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/sqlc"
"github.com/tal-tech/go-zero/core/stores/sqlx"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
)
var (
testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{})
testUserInfoRows = strings.Join(testUserInfoFieldNames, ",")
testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",")
testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
cacheTestUserInfoIdPrefix = "cache#TestUserInfo#id#"
cacheTestUserInfoNanosecondPrefix = "cache#TestUserInfo#nanosecond#"
)
type (
TestUserInfoModel struct {
sqlc.CachedConn
table string
}
TestUserInfo struct {
Id int64 `db:"id"`
Nanosecond int64 `db:"nanosecond"`
Data string `db:"data"`
CreateTime time.Time `db:"create_time"`
UpdateTime time.Time `db:"update_time"`
}
)
func NewTestUserInfoModel(conn sqlx.SqlConn, c cache.CacheConf) *TestUserInfoModel {
return &TestUserInfoModel{
CachedConn: sqlc.NewConn(conn, c),
table: "test_user_info",
}
}
func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) {
testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet)
return conn.Exec(query, data.Nanosecond, data.Data)
}, testUserInfoNanosecondKey)
return ret, err
}
func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) {
testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
var resp TestUserInfo
err := m.QueryRow(&resp, testUserInfoIdKey, func(conn sqlx.SqlConn, v interface{}) error {
query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
return conn.QueryRow(v, query, id)
})
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) {
testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, nanosecond)
var resp TestUserInfo
err := m.QueryRowIndex(&resp, testUserInfoNanosecondKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table)
if err := conn.QueryRow(&resp, query, nanosecond); err != nil {
return nil, err
}
return resp.Id, nil
}, m.queryPrimary)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *TestUserInfoModel) Update(data TestUserInfo) error {
testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, data.Id)
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder)
return conn.Exec(query, data.Nanosecond, data.Data, data.Id)
}, testUserInfoIdKey)
return err
}
func (m *TestUserInfoModel) Delete(id int64) error {
data, err := m.FindOne(id)
if err != nil {
return err
}
testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id)
testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond)
_, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("delete from %s where id = ?", m.table)
return conn.Exec(query, id)
}, testUserInfoIdKey, testUserInfoNanosecondKey)
return err
}
func (m *TestUserInfoModel) formatPrimary(primary interface{}) string {
return fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, primary)
}
func (m *TestUserInfoModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table)
return conn.QueryRow(v, query, primary)
}

View File

@ -1,5 +0,0 @@
package snake
import "github.com/tal-tech/go-zero/core/stores/sqlx"
var ErrNotFound = sqlx.ErrNotFound

View File

@ -1,20 +0,0 @@
# Change log
## 2020-10-19
* 增加template
## 2020-09-10
* rpc greet服务一键生成
* 修复相对路径生成rpc服务package引入错误bug
* 移除`--shared`参数
## 2020-08-29
* 新增支持windows生成
## 2020-08-27
* 新增支持rpc模板生成
* 新增支持rpc服务生成

View File

@ -7,9 +7,8 @@ Goctl Rpc是`goctl`脚手架下的一个rpc服务代码生成模块支持prot
* 简单易用
* 快速提升开发效率
* 出错率低
* 支持基于main proto作为相对路径的import
* 支持map、enum类型
* 支持any类型
* 贴近protoc
## 快速开始
@ -19,44 +18,41 @@ Goctl Rpc是`goctl`脚手架下的一个rpc服务代码生成模块支持prot
如生成greet rpc服务
```shell script
```Bash
goctl rpc new greet
```
执行后代码结构如下:
```golang
└── greet
├── etc
│   └── greet.yaml
├── go.mod
├── go.sum
├── greet
│   ├── greet.go
│   ├── greet_mock.go
│   └── types.go
├── greet.go
├── greet.proto
├── internal
│   ├── config
│   │   └── config.go
│   ├── logic
│   │   └── pinglogic.go
│   ├── server
│   │   └── greetserver.go
│   └── svc
│   └── servicecontext.go
└── pb
└── greet.pb.go
.
├── etc // 配置文件
│   └── greet.yaml
├── go.mod
├── greet // client call
│   └── greet.go
├── greet.go // main entry
├── greet.proto
└── internal
├── config // 配置声明
│   └── config.go
├── greet // pb.go
│   └── greet.pb.go
├── logic // logic
│   └── pinglogic.go
├── server // pb invoker
│   └── greetserver.go
└── svc // resource dependency
└── servicecontext.go
```
rpc一键生成常见问题解决见 <a href="#常见问题解决">常见问题解决</a>
rpc一键生成常见问题解决<a href="#常见问题解决">常见问题解决</a>
### 方式二通过指定proto生成rpc服务
* 生成proto模板
```shell script
```Bash
goctl rpc template -o=user.proto
```
@ -87,35 +83,10 @@ rpc一键生成常见问题解决见 <a href="#常见问题解决">常见问题
* 生成rpc服务代码
```shell script
```Bash
goctl rpc proto -src=user.proto
```
代码tree
```Plain Text
user
├── etc
│   └── user.json
├── internal
│   ├── config
│   │   └── config.go
│   ├── handler
│   │   ├── loginhandler.go
│   ├── logic
│   │   └── loginlogic.go
│   └── svc
│   └── servicecontext.go
├── pb
│   └── user.pb.go
├── shared
│   ├── mockusermodel.go
│   ├── types.go
│   └── usermodel.go
├── user.go
└── user.proto
```
## 准备工作
* 安装了go环境
@ -126,11 +97,11 @@ rpc一键生成常见问题解决见 <a href="#常见问题解决">常见问题
### rpc服务生成用法
```shell script
```Bash
goctl rpc proto -h
```
```shell script
```Bash
NAME:
goctl rpc proto - generate rpc from proto
@ -139,35 +110,22 @@ USAGE:
OPTIONS:
--src value, -s value the file path of the proto source file
--dir value, -d value the target path of the code,default path is "${pwd}". [option]
--service value, --srv value the name of rpc service. [option]
--idea whether the command execution environment is from idea plugin. [option]
--proto_path value, -I value native command of protoc,specify the directory in which to search for imports. [optional]
--dir value, -d value the target path of the code,default path is "${pwd}". [optional]
--idea whether the command execution environment is from idea plugin. [optional]
```
### 参数说明
* --src 必填proto数据源目前暂时支持单个proto文件生成这里不支持不建议外部依赖
* --dir 非必填默认为proto文件所在目录生成代码的目标目录
* --service 服务名称非必填默认为proto文件所在目录名称但是如果proto所在目录为一下结构
```shell script
user
├── cmd
│   └── rpc
│   └── user.proto
```
则服务名称亦为user而非proto所在文件夹名称了这里推荐使用这种结构可以方便在同一个服务名下建立不同类型的服务(api、rpc、mq等),便于代码管理与维护。
> 注意这里的shared文件夹名称将会是代码中的package名称。
* --idea 非必填是否为idea插件中执行保留字段终端执行可以忽略
* --src 必填proto数据源目前暂时支持单个proto文件生成
* --proto_path 可选protoc原生子命令用于指定proto import从何处查找可指定多个路径,如`goctl rpc -I={path1} -I={path2} ...`,在没有import时可不填。当前proto路径不用指定已经内置`-I`的详细用法请参考`protoc -h`
* --dir 可选默认为proto文件所在目录生成代码的目标目录
* --idea 可选是否为idea插件中执行终端执行可以忽略
### 开发人员需要做什么
关注业务代码编写将重复性、与业务无关的工作交给goctl生成好rpc服务代码后人员仅需要修改
关注业务代码编写将重复性、与业务无关的工作交给goctl生成好rpc服务代码后开发人员仅需要修改
* 服务中的配置文件编写(etc/xx.json、internal/config/config.go)
* 服务中业务逻辑编写(internal/logic/xxlogic.go)
@ -193,69 +151,54 @@ OPTIONS:
的标识,请注意不要将也写业务性代码写在里面。
## any和import支持
* 支持any类型声明
* 支持import其他proto文件
## proto import
* 对于rpc中的requestType和returnType必须在main proto文件定义对于proto中的message可以像protoc一样import其他proto文件。
any类型固定import为`google/protobuf/any.proto`,且从${GOPATH}/src中查找proto的import支持main proto的相对路径的import且与proto文件对应的pb.go文件必须在proto目录中能被找到。不支持工程外的其他proto文件import。
proto示例:
> ⚠️注意: 不支持proto嵌套import被import的proto文件不支持import。
### import书写格式
import书写格式
```golang
// @{package_of_pb}
import {proto_omport}
```
@{package_of_pb}pb文件的真实import目录。
{proto_omport}proto import
demo中的
```golang
// @greet/base
import "base/base.proto";
```
工程目录结构如下
```
greet
│   ├── base
│   │   ├── base.pb.go
│   │   └── base.proto
│   ├── demo.proto
│   ├── go.mod
│   └── go.sum
```
demo
```golang
### 错误import
```proto
syntax = "proto3";
import "google/protobuf/any.proto";
// @greet/base
import "base/base.proto";
package stream;
package greet;
enum Gender{
UNKNOWN = 0;
MAN = 1;
WOMAN = 2;
import "base/common.proto"
message Request {
string ping = 1;
}
message StreamResp{
string name = 2;
Gender gender = 3;
google.protobuf.Any details = 5;
base.StreamReq req = 6;
message Response {
string pong = 1;
}
service StreamGreeter {
rpc greet(base.StreamReq) returns (StreamResp);
service Greet {
rpc Ping(base.In) returns(base.Out);// request和return 不支持import
}
```
### 正确import
```proto
syntax = "proto3";
package greet;
import "base/common.proto"
message Request {
base.In in = 1;// 支持import
}
message Response {
base.Out out = 2;// 支持import
}
service Greet {
rpc Ping(Request) returns(Response);
}
```
## 常见问题解决(go mod工程)

View File

@ -1,108 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: base.proto
package base
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type IdRequest struct {
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *IdRequest) Reset() { *m = IdRequest{} }
func (m *IdRequest) String() string { return proto.CompactTextString(m) }
func (*IdRequest) ProtoMessage() {}
func (*IdRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_db1b6b0986796150, []int{0}
}
func (m *IdRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_IdRequest.Unmarshal(m, b)
}
func (m *IdRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_IdRequest.Marshal(b, m, deterministic)
}
func (m *IdRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_IdRequest.Merge(m, src)
}
func (m *IdRequest) XXX_Size() int {
return xxx_messageInfo_IdRequest.Size(m)
}
func (m *IdRequest) XXX_DiscardUnknown() {
xxx_messageInfo_IdRequest.DiscardUnknown(m)
}
var xxx_messageInfo_IdRequest proto.InternalMessageInfo
func (m *IdRequest) GetId() string {
if m != nil {
return m.Id
}
return ""
}
type EmptyResponse struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *EmptyResponse) Reset() { *m = EmptyResponse{} }
func (m *EmptyResponse) String() string { return proto.CompactTextString(m) }
func (*EmptyResponse) ProtoMessage() {}
func (*EmptyResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_db1b6b0986796150, []int{1}
}
func (m *EmptyResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_EmptyResponse.Unmarshal(m, b)
}
func (m *EmptyResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_EmptyResponse.Marshal(b, m, deterministic)
}
func (m *EmptyResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_EmptyResponse.Merge(m, src)
}
func (m *EmptyResponse) XXX_Size() int {
return xxx_messageInfo_EmptyResponse.Size(m)
}
func (m *EmptyResponse) XXX_DiscardUnknown() {
xxx_messageInfo_EmptyResponse.DiscardUnknown(m)
}
var xxx_messageInfo_EmptyResponse proto.InternalMessageInfo
func init() {
proto.RegisterType((*IdRequest)(nil), "base.IdRequest")
proto.RegisterType((*EmptyResponse)(nil), "base.EmptyResponse")
}
func init() { proto.RegisterFile("base.proto", fileDescriptor_db1b6b0986796150) }
var fileDescriptor_db1b6b0986796150 = []byte{
// 91 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4a, 0x4a, 0x2c, 0x4e,
0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x01, 0xb1, 0x95, 0xa4, 0xb9, 0x38, 0x3d, 0x53,
0x82, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0xf8, 0xb8, 0x98, 0x32, 0x53, 0x24, 0x18, 0x15,
0x18, 0x35, 0x38, 0x83, 0x98, 0x32, 0x53, 0x94, 0xf8, 0xb9, 0x78, 0x5d, 0x73, 0x0b, 0x4a, 0x2a,
0x83, 0x52, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x93, 0xd8, 0xc0, 0x5a, 0x8d, 0x01, 0x01, 0x00,
0x00, 0xff, 0xff, 0xe1, 0x39, 0x3c, 0x22, 0x48, 0x00, 0x00, 0x00,
}

View File

@ -1,11 +0,0 @@
syntax = "proto3";
package base;
message IdRequest {
string id = 1;
}
message EmptyResponse {
}

View File

@ -0,0 +1,67 @@
package cli
import (
"errors"
"fmt"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
"github.com/urfave/cli"
)
// Rpc is to generate rpc service code from a proto file by specifying a proto file using flag src,
// you can specify a target folder for code generation, when the proto file has import, you can specify
// the import search directory through the proto_path command, for specific usage, please refer to protoc -h
func Rpc(c *cli.Context) error {
src := c.String("src")
out := c.String("dir")
protoImportPath := c.StringSlice("proto_path")
if len(src) == 0 {
return errors.New("the proto source can not be nil")
}
if len(out) == 0 {
return errors.New("the target directory can not be nil")
}
g := generator.NewDefaultRpcGenerator()
return g.Generate(src, out, protoImportPath)
}
// RpcNew is to generate rpc greet service, this greet service can speed
// up your understanding of the zrpc service structure
func RpcNew(c *cli.Context) error {
name := c.Args().First()
ext := filepath.Ext(name)
if len(ext) > 0 {
return fmt.Errorf("unexpected ext: %s", ext)
}
protoName := name + ".proto"
filename := filepath.Join(".", name, protoName)
src, err := filepath.Abs(filename)
if err != nil {
return err
}
err = generator.ProtoTmpl(src)
if err != nil {
return err
}
workDir := filepath.Dir(src)
_, err = execx.Run("go mod init "+name, workDir)
if err != nil {
return err
}
g := generator.NewDefaultRpcGenerator()
return g.Generate(src, filepath.Dir(src), nil)
}
func RpcTemplate(c *cli.Context) error {
name := c.Args().First()
if len(name) == 0 {
name = "greet.proto"
}
return generator.ProtoTmpl(name)
}

View File

@ -1,60 +0,0 @@
package command
import (
"fmt"
"os"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/ctx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/gen"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli"
)
func Rpc(c *cli.Context) error {
rpcCtx := ctx.MustCreateRpcContextFromCli(c)
generator := gen.NewDefaultRpcGenerator(rpcCtx)
rpcCtx.Must(generator.Generate())
return nil
}
func RpcTemplate(c *cli.Context) error {
out := c.String("out")
idea := c.Bool("idea")
generator := gen.NewRpcTemplate(out, idea)
generator.MustGenerate(true)
return nil
}
func RpcNew(c *cli.Context) error {
idea := c.Bool("idea")
arg := c.Args().First()
if len(arg) == 0 {
arg = "greet"
}
abs, err := filepath.Abs(arg)
if err != nil {
return err
}
_, err = os.Stat(abs)
if err != nil {
if !os.IsNotExist(err) {
return err
}
err = util.MkdirIfNotExist(abs)
if err != nil {
return err
}
}
dir := filepath.Base(filepath.Clean(abs))
protoSrc := filepath.Join(abs, fmt.Sprintf("%v.proto", dir))
templateGenerator := gen.NewRpcTemplate(protoSrc, idea)
templateGenerator.MustGenerate(false)
rpcCtx := ctx.MustCreateRpcContext(protoSrc, "", "", idea)
generator := gen.NewDefaultRpcGenerator(rpcCtx)
rpcCtx.Must(generator.Generate())
return nil
}

View File

@ -1,99 +0,0 @@
package ctx
import (
"fmt"
"path/filepath"
"runtime"
"strings"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/project"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
"github.com/tal-tech/go-zero/tools/goctl/vars"
"github.com/urfave/cli"
)
const (
flagSrc = "src"
flagDir = "dir"
flagService = "service"
flagIdea = "idea"
)
type RpcContext struct {
ProjectPath string
ProjectName stringx.String
ServiceName stringx.String
CurrentPath string
Module string
ProtoFileSrc string
ProtoSource string
TargetDir string
IsInGoEnv bool
console.Console
}
func MustCreateRpcContext(protoSrc, targetDir, serviceName string, idea bool) *RpcContext {
log := console.NewConsole(idea)
if stringx.From(protoSrc).IsEmptyOrSpace() {
log.Fatalln("expected proto source, but nothing found")
}
srcFp, err := filepath.Abs(protoSrc)
log.Must(err)
if !util.FileExists(srcFp) {
log.Fatalln("%s is not exists", srcFp)
}
current := filepath.Dir(srcFp)
if stringx.From(targetDir).IsEmptyOrSpace() {
targetDir = current
}
targetDirFp, err := filepath.Abs(targetDir)
log.Must(err)
if stringx.From(serviceName).IsEmptyOrSpace() {
serviceName = getServiceFromRpcStructure(targetDirFp)
}
serviceNameString := stringx.From(serviceName)
if serviceNameString.IsEmptyOrSpace() {
log.Fatalln("service name not found")
}
info, err := project.Prepare(targetDir, true)
log.Must(err)
return &RpcContext{
ProjectPath: info.Path,
ProjectName: stringx.From(info.Name),
ServiceName: serviceNameString,
CurrentPath: current,
Module: info.GoMod.Module,
ProtoFileSrc: srcFp,
ProtoSource: filepath.Base(srcFp),
TargetDir: targetDirFp,
IsInGoEnv: info.IsInGoEnv,
Console: log,
}
}
func MustCreateRpcContextFromCli(ctx *cli.Context) *RpcContext {
os := runtime.GOOS
switch os {
case vars.OsMac, vars.OsLinux, vars.OsWindows:
default:
logx.Must(fmt.Errorf("unexpected os: %s", os))
}
protoSrc := ctx.String(flagSrc)
targetDir := ctx.String(flagDir)
serviceName := ctx.String(flagService)
idea := ctx.Bool(flagIdea)
return MustCreateRpcContext(protoSrc, targetDir, serviceName, idea)
}
func getServiceFromRpcStructure(targetDir string) string {
targetDir = filepath.Clean(targetDir)
suffix := filepath.Join("cmd", "rpc")
return filepath.Base(strings.TrimSuffix(targetDir, suffix))
}

View File

@ -6,7 +6,9 @@ import (
"fmt"
"os/exec"
"runtime"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars"
)
@ -24,17 +26,17 @@ func Run(arg string, dir string) (string, error) {
if len(dir) > 0 {
cmd.Dir = dir
}
dtsout := new(bytes.Buffer)
stdout := new(bytes.Buffer)
stderr := new(bytes.Buffer)
cmd.Stdout = dtsout
cmd.Stdout = stdout
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
if stderr.Len() > 0 {
return "", errors.New(stderr.String())
return "", errors.New(strings.TrimSuffix(stderr.String(), util.NL))
}
return "", err
}
return dtsout.String(), nil
return strings.TrimSuffix(stdout.String(), util.NL), nil
}

View File

@ -1,93 +0,0 @@
package gen
import (
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/rpc/ctx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
)
const (
dirTarget = "dirTarget"
dirConfig = "config"
dirEtc = "etc"
dirSvc = "svc"
dirServer = "server"
dirLogic = "logic"
dirPb = "pb"
dirInternal = "internal"
fileConfig = "config.go"
fileServiceContext = "servicecontext.go"
)
type defaultRpcGenerator struct {
dirM map[string]string
Ctx *ctx.RpcContext
ast *parser.PbAst
}
func NewDefaultRpcGenerator(ctx *ctx.RpcContext) *defaultRpcGenerator {
return &defaultRpcGenerator{
Ctx: ctx,
}
}
func (g *defaultRpcGenerator) Generate() (err error) {
g.Ctx.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/zero-doc/blob/main/doc/goctl-rpc.md」")
g.Ctx.Warning("-> generating rpc code ...")
defer func() {
if err == nil {
g.Ctx.MarkDone()
}
}()
err = g.createDir()
if err != nil {
return
}
err = g.initGoMod()
if err != nil {
return
}
err = g.genEtc()
if err != nil {
return
}
err = g.genPb()
if err != nil {
return
}
err = g.genConfig()
if err != nil {
return
}
err = g.genSvc()
if err != nil {
return
}
err = g.genLogic()
if err != nil {
return
}
err = g.genHandler()
if err != nil {
return
}
err = g.genMain()
if err != nil {
return
}
err = g.genCall()
if err != nil {
return
}
return
}

View File

@ -1,224 +0,0 @@
package gen
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const (
typesFilename = "types.go"
callTemplateText = `{{.head}}
//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
package {{.filePackage}}
import (
"context"
{{.package}}
"github.com/tal-tech/go-zero/core/jsonx"
"github.com/tal-tech/go-zero/zrpc"
)
type (
{{.serviceName}} interface {
{{.interface}}
}
default{{.serviceName}} struct {
cli zrpc.Client
}
)
func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
return &default{{.serviceName}}{
cli: cli,
}
}
{{.functions}}
`
callTemplateTypes = `{{.head}}
package {{.filePackage}}
import "errors"
var errJsonConvert = errors.New("json convert error")
{{.const}}
{{.types}}
`
callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
callFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequestName}}) (*{{.pbResponse}}, error) {
var request {{.pbRequest}}
bts, err := jsonx.Marshal(in)
if err != nil {
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &request)
if err != nil {
return nil, errJsonConvert
}
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
resp, err := client.{{.method}}(ctx, &request)
if err != nil{
return nil, err
}
var ret {{.pbResponse}}
bts, err = jsonx.Marshal(resp)
if err != nil{
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &ret)
if err != nil{
return nil, errJsonConvert
}
return &ret, nil
}
`
)
func (g *defaultRpcGenerator) genCall() error {
file := g.ast
if len(file.Service) == 0 {
return nil
}
if len(file.Service) > 1 {
return fmt.Errorf("we recommend only one service in a proto, currently %d", len(file.Service))
}
typeCode, err := file.GenTypesCode()
if err != nil {
return err
}
constLit, err := file.GenEnumCode()
if err != nil {
return err
}
service := file.Service[0]
callPath := filepath.Join(g.dirM[dirTarget], service.Name.Lower())
if err = util.MkdirIfNotExist(callPath); err != nil {
return err
}
filename := filepath.Join(callPath, typesFilename)
head := util.GetHead(g.Ctx.ProtoSource)
text, err := util.LoadTemplate(category, callTypesTemplateFile, callTemplateTypes)
if err != nil {
return err
}
err = util.With("types").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head,
"const": constLit,
"filePackage": service.Name.Lower(),
"serviceName": g.Ctx.ServiceName.Title(),
"lowerStartServiceName": g.Ctx.ServiceName.UnTitle(),
"types": typeCode,
}, filename, true)
if err != nil {
return err
}
filename = filepath.Join(callPath, fmt.Sprintf("%s.go", service.Name.Lower()))
functions, importList, err := g.genFunction(service)
if err != nil {
return err
}
iFunctions, err := g.getInterfaceFuncs(service)
if err != nil {
return err
}
text, err = util.LoadTemplate(category, callTemplateFile, callTemplateText)
if err != nil {
return err
}
err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"name": service.Name.Lower(),
"head": head,
"filePackage": service.Name.Lower(),
"package": strings.Join(importList, util.NL),
"serviceName": service.Name.Title(),
"functions": strings.Join(functions, util.NL),
"interface": strings.Join(iFunctions, util.NL),
}, filename, true)
return err
}
func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string, []string, error) {
file := g.ast
pkgName := file.Package
functions := make([]string, 0)
imports := collection.NewSet()
imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb)))
for _, method := range service.Funcs {
imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
text, err := util.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate)
if err != nil {
return nil, nil, err
}
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
"rpcServiceName": service.Name.Title(),
"method": method.Name.Title(),
"package": pkgName,
"pbRequestName": method.ParameterIn.Name,
"pbRequest": method.ParameterIn.Expression,
"pbResponse": method.ParameterOut.Name,
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
})
if err != nil {
return nil, nil, err
}
functions = append(functions, buffer.String())
}
return functions, imports.KeysStr(), nil
}
func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) {
functions := make([]string, 0)
for _, method := range service.Funcs {
text, err := util.LoadTemplate(category, callInterfaceFunctionTemplateFile, callInterfaceFunctionTemplate)
if err != nil {
return nil, err
}
buffer, err := util.With("interfaceFn").Parse(text).Execute(
map[string]interface{}{
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
"method": method.Name.Title(),
"pbRequest": method.ParameterIn.Name,
"pbResponse": method.ParameterOut.Name,
})
if err != nil {
return nil, err
}
functions = append(functions, buffer.String())
}
return functions, nil
}

View File

@ -1,54 +0,0 @@
package gen
import (
"path/filepath"
"runtime"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars"
)
// target
// ├── etc
// ├── internal
// │   ├── config
// │   ├── handler
// │   ├── logic
// │   ├── pb
// │   └── svc
func (g *defaultRpcGenerator) createDir() error {
ctx := g.Ctx
m := make(map[string]string)
m[dirTarget] = ctx.TargetDir
m[dirEtc] = filepath.Join(ctx.TargetDir, dirEtc)
m[dirInternal] = filepath.Join(ctx.TargetDir, dirInternal)
m[dirConfig] = filepath.Join(ctx.TargetDir, dirInternal, dirConfig)
m[dirServer] = filepath.Join(ctx.TargetDir, dirInternal, dirServer)
m[dirLogic] = filepath.Join(ctx.TargetDir, dirInternal, dirLogic)
m[dirPb] = filepath.Join(ctx.TargetDir, dirPb)
m[dirSvc] = filepath.Join(ctx.TargetDir, dirInternal, dirSvc)
for _, d := range m {
err := util.MkdirIfNotExist(d)
if err != nil {
return err
}
}
g.dirM = m
return nil
}
func (g *defaultRpcGenerator) mustGetPackage(dir string) string {
target := g.dirM[dir]
projectPath := g.Ctx.ProjectPath
relativePath := strings.TrimPrefix(target, projectPath)
os := runtime.GOOS
switch os {
case vars.OsWindows:
relativePath = filepath.ToSlash(relativePath)
case vars.OsMac, vars.OsLinux:
default:
g.Ctx.Fatalln("unexpected os: %s", os)
}
return g.Ctx.Module + relativePath
}

View File

@ -1,109 +0,0 @@
package gen
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const (
logicTemplate = `package logic
import (
"context"
{{.imports}}
"github.com/tal-tech/go-zero/core/logx"
)
type {{.logicName}} struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
}
func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
return &{{.logicName}}{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
}
}
{{.functions}}
`
logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
// todo: add your logic here and delete this line
return &{{.responseType}}{}, nil
}
`
)
func (g *defaultRpcGenerator) genLogic() error {
logicPath := g.dirM[dirLogic]
protoPkg := g.ast.Package
service := g.ast.Service
for _, item := range service {
for _, method := range item.Funcs {
logicName := fmt.Sprintf("%slogic.go", method.Name.Lower())
filename := filepath.Join(logicPath, logicName)
functions, importList, err := g.genLogicFunction(protoPkg, method)
if err != nil {
return err
}
imports := collection.NewSet()
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
imports.AddStr(svcImport)
imports.AddStr(importList...)
text, err := util.LoadTemplate(category, logicTemplateFileFile, logicTemplate)
if err != nil {
return err
}
err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"functions": functions,
"imports": strings.Join(imports.KeysStr(), util.NL),
}, filename, false)
if err != nil {
return err
}
}
}
return nil
}
func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parser.Func) (string, []string, error) {
var functions = make([]string, 0)
var imports = collection.NewSet()
if method.ParameterIn.Package == packageName || method.ParameterOut.Package == packageName {
imports.AddStr(fmt.Sprintf(`%v "%v"`, packageName, g.mustGetPackage(dirPb)))
}
imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
if err != nil {
return "", nil, err
}
buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(),
"request": method.ParameterIn.StarExpression,
"response": method.ParameterOut.StarExpression,
"responseType": method.ParameterOut.Expression,
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
})
if err != nil {
return "", nil, err
}
functions = append(functions, buffer.String())
return strings.Join(functions, util.NL), imports.KeysStr(), nil
}

View File

@ -1,85 +0,0 @@
package gen
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const mainTemplate = `{{.head}}
package main
import (
"flag"
"fmt"
{{.imports}}
"github.com/tal-tech/go-zero/core/conf"
"github.com/tal-tech/go-zero/zrpc"
"google.golang.org/grpc"
)
var configFile = flag.String("f", "etc/{{.serviceName}}.yaml", "the config file")
func main() {
flag.Parse()
var c config.Config
conf.MustLoad(*configFile, &c)
ctx := svc.NewServiceContext(c)
{{.srv}}
s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
{{.registers}}
})
defer s.Stop()
fmt.Printf("Starting rpc server at %s...\n", c.ListenOn)
s.Start()
}
`
func (g *defaultRpcGenerator) genMain() error {
mainPath := g.dirM[dirTarget]
file := g.ast
pkg := file.Package
fileName := filepath.Join(mainPath, fmt.Sprintf("%v.go", g.Ctx.ServiceName.Lower()))
imports := make([]string, 0)
pbImport := fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb))
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
remoteImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirServer))
configImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig))
imports = append(imports, configImport, pbImport, remoteImport, svcImport)
srv, registers := g.genServer(pkg, file.Service)
head := util.GetHead(g.Ctx.ProtoSource)
text, err := util.LoadTemplate(category, mainTemplateFile, mainTemplate)
if err != nil {
return err
}
return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head,
"package": pkg,
"serviceName": g.Ctx.ServiceName.Lower(),
"srv": srv,
"registers": registers,
"imports": strings.Join(imports, util.NL),
}, fileName, false)
}
func (g *defaultRpcGenerator) genServer(pkg string, list []*parser.RpcService) (string, string) {
list1 := make([]string, 0)
list2 := make([]string, 0)
for _, item := range list {
name := item.Name.UnTitle()
list1 = append(list1, fmt.Sprintf("%sSrv := server.New%sServer(ctx)", name, item.Name.Title()))
list2 = append(list2, fmt.Sprintf("%s.Register%sServer(grpcServer, %sSrv)", pkg, item.Name.Title(), name))
}
return strings.Join(list1, util.NL), strings.Join(list2, util.NL)
}

View File

@ -1,82 +0,0 @@
package gen
import (
"bytes"
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
)
const (
protocCmd = "protoc"
grpcPluginCmd = "--go_out=plugins=grpc"
)
func (g *defaultRpcGenerator) genPb() error {
pbPath := g.dirM[dirPb]
// deprecated: containsAny will be removed in the feature
imports, containsAny, err := parser.ParseImport(g.Ctx.ProtoFileSrc)
if err != nil {
return err
}
err = g.protocGenGo(pbPath, imports)
if err != nil {
return err
}
ast, err := parser.Transfer(g.Ctx.ProtoFileSrc, pbPath, imports, g.Ctx.Console)
if err != nil {
return err
}
ast.ContainsAny = containsAny
if len(ast.Service) == 0 {
return fmt.Errorf("service not found")
}
g.ast = ast
return nil
}
func (g *defaultRpcGenerator) protocGenGo(target string, imports []*parser.Import) error {
dir := filepath.Dir(g.Ctx.ProtoFileSrc)
// cmd join,see the document of proto generating class @https://developers.google.com/protocol-buffers/docs/proto3#generating
// template: protoc -I=${import_path} -I=${other_import_path} -I=${...} --go_out=plugins=grpc,M${pb_package_kv}, M${...} :${target_dir}
// eg: protoc -I=${GOPATH}/src -I=. example.proto --go_out=plugins=grpc,Mbase/base.proto=github.com/go-zero/base.proto:.
// note: the external import out of the project which are found in ${GOPATH}/src so far.
buffer := new(bytes.Buffer)
buffer.WriteString(protocCmd + " ")
targetImportFiltered := collection.NewSet()
for _, item := range imports {
buffer.WriteString(fmt.Sprintf("-I=%s ", item.OriginalDir))
if len(item.BridgeImport) == 0 {
continue
}
targetImportFiltered.AddStr(item.BridgeImport)
}
buffer.WriteString("-I=${GOPATH}/src ")
buffer.WriteString(fmt.Sprintf("-I=%s %s ", dir, g.Ctx.ProtoFileSrc))
buffer.WriteString(grpcPluginCmd)
if targetImportFiltered.Count() > 0 {
buffer.WriteString(fmt.Sprintf(",%v", strings.Join(targetImportFiltered.KeysStr(), ",")))
}
buffer.WriteString(":" + target)
g.Ctx.Debug("-> " + buffer.String())
stdout, err := execx.Run(buffer.String(), "")
if err != nil {
return err
}
if len(stdout) > 0 {
g.Ctx.Info(stdout)
}
return nil
}

View File

@ -1,116 +0,0 @@
package gen
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const (
serverTemplate = `{{.head}}
package server
import (
"context"
{{.imports}}
)
type {{.types}}
func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
return &{{.server}}Server{
svcCtx: svcCtx,
}
}
{{.funcs}}
`
functionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
l := logic.New{{.logicName}}(ctx,s.svcCtx)
return l.{{.method}}(in)
}
`
typeFmt = `%sServer struct {
svcCtx *svc.ServiceContext
}`
)
func (g *defaultRpcGenerator) genHandler() error {
serverPath := g.dirM[dirServer]
file := g.ast
logicImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic))
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
imports := collection.NewSet()
imports.AddStr(logicImport, svcImport)
head := util.GetHead(g.Ctx.ProtoSource)
for _, service := range file.Service {
filename := fmt.Sprintf("%vserver.go", service.Name.Lower())
serverFile := filepath.Join(serverPath, filename)
funcList, importList, err := g.genFunctions(service)
if err != nil {
return err
}
imports.AddStr(importList...)
text, err := util.LoadTemplate(category, serverTemplateFile, serverTemplate)
if err != nil {
return err
}
err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head,
"types": fmt.Sprintf(typeFmt, service.Name.Title()),
"server": service.Name.Title(),
"imports": strings.Join(imports.KeysStr(), util.NL),
"funcs": strings.Join(funcList, util.NL),
}, serverFile, true)
if err != nil {
return err
}
}
return nil
}
func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string, []string, error) {
file := g.ast
pkg := file.Package
var functionList []string
imports := collection.NewSet()
for _, method := range service.Funcs {
if method.ParameterIn.Package == pkg || method.ParameterOut.Package == pkg {
imports.AddStr(fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb)))
}
imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
text, err := util.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)
if err != nil {
return nil, nil, err
}
buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
"server": service.Name.Title(),
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(),
"package": pkg,
"request": method.ParameterIn.StarExpression,
"response": method.ParameterOut.StarExpression,
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
})
if err != nil {
return nil, nil, err
}
functionList = append(functionList, buffer.String())
}
return functionList, imports.KeysStr(), nil
}

View File

@ -1,22 +0,0 @@
package gen
import (
"fmt"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
)
func (g *defaultRpcGenerator) initGoMod() error {
if !g.Ctx.IsInGoEnv {
projectDir := g.dirM[dirTarget]
cmd := fmt.Sprintf("go mod init %s", g.Ctx.ProjectName.Source())
output, err := execx.Run(fmt.Sprintf(cmd), projectDir)
if err != nil {
logx.Error(err)
return err
}
g.Ctx.Info(output)
}
return nil
}

View File

@ -1,35 +0,0 @@
package base
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
)
func TestParseImport(t *testing.T) {
src, _ := filepath.Abs("./test.proto")
base, _ := filepath.Abs("./base.proto")
imports, containsAny, err := parser.ParseImport(src)
assert.Nil(t, err)
assert.Equal(t, true, containsAny)
assert.Equal(t, 1, len(imports))
assert.Equal(t, "github.com/tal-tech/go-zero/tools/goctl/rpc", imports[0].PbImportName)
assert.Equal(t, base, imports[0].OriginalProtoPath)
}
func TestTransfer(t *testing.T) {
src, _ := filepath.Abs("./test.proto")
abs, _ := filepath.Abs("./test")
imports, _, _ := parser.ParseImport(src)
proto, err := parser.Transfer(src, abs, imports, console.NewConsole(false))
assert.Nil(t, err)
assert.Equal(t, 1, len(proto.Service))
assert.Equal(t, "Greeter", proto.Service[0].Name.Source())
assert.Equal(t, 5, len(proto.Structure))
data, ok := proto.Structure["map"]
assert.Equal(t, true, ok)
assert.Equal(t, "M", data.Field[0].Name.Source())
}

View File

@ -0,0 +1,75 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: common.proto
package common
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type User struct {
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *User) Reset() { *m = User{} }
func (m *User) String() string { return proto.CompactTextString(m) }
func (*User) ProtoMessage() {}
func (*User) Descriptor() ([]byte, []int) {
return fileDescriptor_555bd8c177793206, []int{0}
}
func (m *User) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_User.Unmarshal(m, b)
}
func (m *User) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_User.Marshal(b, m, deterministic)
}
func (m *User) XXX_Merge(src proto.Message) {
xxx_messageInfo_User.Merge(m, src)
}
func (m *User) XXX_Size() int {
return xxx_messageInfo_User.Size(m)
}
func (m *User) XXX_DiscardUnknown() {
xxx_messageInfo_User.DiscardUnknown(m)
}
var xxx_messageInfo_User proto.InternalMessageInfo
func (m *User) GetName() string {
if m != nil {
return m.Name
}
return ""
}
func init() {
proto.RegisterType((*User)(nil), "common.User")
}
func init() { proto.RegisterFile("common.proto", fileDescriptor_555bd8c177793206) }
var fileDescriptor_555bd8c177793206 = []byte{
// 72 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x49, 0xce, 0xcf, 0xcd,
0xcd, 0xcf, 0xd3, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x83, 0xf0, 0x94, 0xa4, 0xb8, 0x58,
0x42, 0x8b, 0x53, 0x8b, 0x84, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53, 0x25, 0x18, 0x15, 0x18,
0x35, 0x38, 0x83, 0xc0, 0xec, 0x24, 0x36, 0xb0, 0x52, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff,
0x2c, 0x6d, 0x58, 0x59, 0x3a, 0x00, 0x00, 0x00,
}

View File

@ -0,0 +1,7 @@
syntax = "proto3";
package common;
message User {
string name = 1;
}

View File

@ -0,0 +1,33 @@
package generator
import (
"os/exec"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
)
type defaultGenerator struct {
log console.Console
}
func NewDefaultGenerator() *defaultGenerator {
log := console.NewColorConsole()
return &defaultGenerator{
log: log,
}
}
func (g *defaultGenerator) Prepare() error {
_, err := exec.LookPath("go")
if err != nil {
return err
}
_, err = exec.LookPath("protoc")
if err != nil {
return err
}
_, err = exec.LookPath("protoc-gen-go")
return err
}

View File

@ -0,0 +1,11 @@
package generator
import (
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
func formatFilename(filename string) string {
return strings.ToLower(stringx.From(filename).ToCamel())
}

View File

@ -0,0 +1,98 @@
package generator
import (
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
type RpcGenerator struct {
g Generator
}
func NewDefaultRpcGenerator() *RpcGenerator {
return NewRpcGenerator(NewDefaultGenerator())
}
func NewRpcGenerator(g Generator) *RpcGenerator {
return &RpcGenerator{
g: g,
}
}
func (g *RpcGenerator) Generate(src, target string, protoImportPath []string) error {
abs, err := filepath.Abs(target)
if err != nil {
return err
}
err = util.MkdirIfNotExist(abs)
if err != nil {
return err
}
err = g.g.Prepare()
if err != nil {
return err
}
projectCtx, err := ctx.Prepare(abs)
if err != nil {
return err
}
p := parser.NewDefaultProtoParser()
proto, err := p.Parse(src)
if err != nil {
return err
}
dirCtx, err := mkdir(projectCtx, proto)
if err != nil {
return err
}
err = g.g.GenEtc(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenPb(dirCtx, protoImportPath, proto)
if err != nil {
return err
}
err = g.g.GenConfig(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenSvc(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenLogic(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenServer(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenMain(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenCall(dirCtx, proto)
console.NewColorConsole().MarkDone()
return err
}

View File

@ -0,0 +1,104 @@
package generator
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
)
func TestRpcGenerateCaseNilImport(t *testing.T) {
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_stream.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
}
func TestRpcGenerateCaseOption(t *testing.T) {
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_option.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
}
func TestRpcGenerateCaseWordOption(t *testing.T) {
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_word_option.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
}
// test keyword go
func TestRpcGenerateCaseGoOption(t *testing.T) {
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_go_option.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
}
func TestRpcGenerateCaseImport(t *testing.T) {
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_import.proto", abs, []string{"./base"})
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.True(t, func() bool {
return strings.Contains(err.Error(), "package base is not in GOROOT")
}())
}
}

View File

@ -0,0 +1,154 @@
package generator
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
callTemplateText = `{{.head}}
//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
package {{.filePackage}}
import (
"context"
{{.package}}
"github.com/tal-tech/go-zero/zrpc"
)
type (
{{.alias}}
{{.serviceName}} interface {
{{.interface}}
}
default{{.serviceName}} struct {
cli zrpc.Client
}
)
func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
return &default{{.serviceName}}{
cli: cli,
}
}
{{.functions}}
`
callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
callFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) {
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
return client.{{.method}}(ctx, in)
}
`
)
func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error {
dir := ctx.GetCall()
service := proto.Service
head := util.GetHead(proto.Name)
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", formatFilename(service.Name)))
functions, err := g.genFunction(proto.PbPackage, service)
if err != nil {
return err
}
iFunctions, err := g.getInterfaceFuncs(service)
if err != nil {
return err
}
text, err := util.LoadTemplate(category, callTemplateFile, callTemplateText)
if err != nil {
return err
}
var alias []string
for _, item := range service.RPC {
alias = append(alias, fmt.Sprintf("%s = %s", parser.CamelCase(item.RequestType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.RequestType))))
alias = append(alias, fmt.Sprintf("%s = %s", parser.CamelCase(item.ReturnsType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.ReturnsType))))
}
err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"name": formatFilename(service.Name),
"alias": strings.Join(alias, util.NL),
"head": head,
"filePackage": formatFilename(service.Name),
"package": fmt.Sprintf(`"%s"`, ctx.GetPb().Package),
"serviceName": parser.CamelCase(service.Name),
"functions": strings.Join(functions, util.NL),
"interface": strings.Join(iFunctions, util.NL),
}, filename, true)
return err
}
func (g *defaultGenerator) genFunction(goPackage string, service parser.Service) ([]string, error) {
functions := make([]string, 0)
for _, rpc := range service.RPC {
text, err := util.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate)
if err != nil {
return nil, err
}
comment := parser.GetComment(rpc.Doc())
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
"rpcServiceName": stringx.From(service.Name).Title(),
"method": stringx.From(rpc.Name).Title(),
"package": goPackage,
"pbRequest": parser.CamelCase(rpc.RequestType),
"pbResponse": parser.CamelCase(rpc.ReturnsType),
"hasComment": len(comment) > 0,
"comment": comment,
})
if err != nil {
return nil, err
}
functions = append(functions, buffer.String())
}
return functions, nil
}
func (g *defaultGenerator) getInterfaceFuncs(service parser.Service) ([]string, error) {
functions := make([]string, 0)
for _, rpc := range service.RPC {
text, err := util.LoadTemplate(category, callInterfaceFunctionTemplateFile, callInterfaceFunctionTemplate)
if err != nil {
return nil, err
}
comment := parser.GetComment(rpc.Doc())
buffer, err := util.With("interfaceFn").Parse(text).Execute(
map[string]interface{}{
"hasComment": len(comment) > 0,
"comment": comment,
"method": stringx.From(rpc.Name).Title(),
"pbRequest": parser.CamelCase(rpc.RequestType),
"pbResponse": parser.CamelCase(rpc.ReturnsType),
})
if err != nil {
return nil, err
}
functions = append(functions, buffer.String())
}
return functions, nil
}

View File

@ -0,0 +1,44 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateCall(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenCall(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -1,10 +1,11 @@
package gen
package generator
import (
"io/ioutil"
"os"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
@ -17,9 +18,9 @@ type Config struct {
}
`
func (g *defaultRpcGenerator) genConfig() error {
configPath := g.dirM[dirConfig]
fileName := filepath.Join(configPath, fileConfig)
func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto) error {
dir := ctx.GetConfig()
fileName := filepath.Join(dir.Filename, formatFilename("config")+".go")
if util.FileExists(fileName) {
return nil
}

View File

@ -0,0 +1,48 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateConfig(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenConfig(dirCtx, proto)
assert.Nil(t, err)
// test file exists
err = g.GenConfig(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -0,0 +1,15 @@
package generator
import "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
type Generator interface {
Prepare() error
GenMain(ctx DirContext, proto parser.Proto) error
GenCall(ctx DirContext, proto parser.Proto) error
GenEtc(ctx DirContext, proto parser.Proto) error
GenConfig(ctx DirContext, proto parser.Proto) error
GenLogic(ctx DirContext, proto parser.Proto) error
GenServer(ctx DirContext, proto parser.Proto) error
GenSvc(ctx DirContext, proto parser.Proto) error
GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error
}

View File

@ -1,9 +1,10 @@
package gen
package generator
import (
"fmt"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
@ -15,12 +16,10 @@ Etcd:
Key: {{.serviceName}}.rpc
`
func (g *defaultRpcGenerator) genEtc() error {
etdDir := g.dirM[dirEtc]
fileName := filepath.Join(etdDir, fmt.Sprintf("%v.yaml", g.Ctx.ServiceName.Lower()))
if util.FileExists(fileName) {
return nil
}
func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto) error {
dir := ctx.GetEtc()
serviceNameLower := formatFilename(ctx.GetMain().Base)
fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.yaml", serviceNameLower))
text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate)
if err != nil {
@ -28,6 +27,6 @@ func (g *defaultRpcGenerator) genEtc() error {
}
return util.With("etc").Parse(text).SaveTo(map[string]interface{}{
"serviceName": g.Ctx.ServiceName.Lower(),
"serviceName": serviceNameLower,
}, fileName, false)
}

View File

@ -0,0 +1,45 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateEtc(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenEtc(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -0,0 +1,101 @@
package generator
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
logicTemplate = `package logic
import (
"context"
{{.imports}}
"github.com/tal-tech/go-zero/core/logx"
)
type {{.logicName}} struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
}
func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
return &{{.logicName}}{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
}
}
{{.functions}}
`
logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
// todo: add your logic here and delete this line
return &{{.responseType}}{}, nil
}
`
)
func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto) error {
dir := ctx.GetLogic()
for _, rpc := range proto.Service.RPC {
filename := filepath.Join(dir.Filename, formatFilename(rpc.Name+"_logic")+".go")
functions, err := g.genLogicFunction(proto.PbPackage, rpc)
if err != nil {
return err
}
imports := collection.NewSet()
imports.AddStr(fmt.Sprintf(`"%v"`, ctx.GetSvc().Package))
imports.AddStr(fmt.Sprintf(`"%v"`, ctx.GetPb().Package))
text, err := util.LoadTemplate(category, logicTemplateFileFile, logicTemplate)
if err != nil {
return err
}
err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).Title()),
"functions": functions,
"imports": strings.Join(imports.KeysStr(), util.NL),
}, filename, false)
if err != nil {
return err
}
}
return nil
}
func (g *defaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) (string, error) {
var functions = make([]string, 0)
text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
if err != nil {
return "", err
}
logicName := stringx.From(rpc.Name + "_logic").ToCamel()
comment := parser.GetComment(rpc.Doc())
buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
"logicName": logicName,
"method": parser.CamelCase(rpc.Name),
"request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
"response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
"responseType": fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
"hasComment": len(comment) > 0,
"comment": comment,
})
if err != nil {
return "", err
}
functions = append(functions, buffer.String())
return strings.Join(functions, util.NL), nil
}

View File

@ -0,0 +1,44 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateLogic(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenLogic(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -0,0 +1,70 @@
package generator
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const mainTemplate = `{{.head}}
package main
import (
"flag"
"fmt"
{{.imports}}
"github.com/tal-tech/go-zero/core/conf"
"github.com/tal-tech/go-zero/zrpc"
"google.golang.org/grpc"
)
var configFile = flag.String("f", "etc/{{.serviceName}}.yaml", "the config file")
func main() {
flag.Parse()
var c config.Config
conf.MustLoad(*configFile, &c)
ctx := svc.NewServiceContext(c)
srv := server.New{{.service}}Server(ctx)
s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
{{.pkg}}.Register{{.service}}Server(grpcServer, srv)
})
defer s.Stop()
fmt.Printf("Starting rpc server at %s...\n", c.ListenOn)
s.Start()
}
`
func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error {
dir := ctx.GetMain()
serviceNameLower := formatFilename(ctx.GetMain().Base)
fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.go", serviceNameLower))
imports := make([]string, 0)
pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
remoteImport := fmt.Sprintf(`"%v"`, ctx.GetServer().Package)
configImport := fmt.Sprintf(`"%v"`, ctx.GetConfig().Package)
imports = append(imports, configImport, pbImport, remoteImport, svcImport)
head := util.GetHead(proto.Name)
text, err := util.LoadTemplate(category, mainTemplateFile, mainTemplate)
if err != nil {
return err
}
return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head,
"serviceName": serviceNameLower,
"imports": strings.Join(imports, util.NL),
"pkg": proto.PbPackage,
"service": parser.CamelCase(proto.Service.Name),
}, fileName, false)
}

View File

@ -0,0 +1,45 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateMain(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenMain(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -0,0 +1,31 @@
package generator
import (
"bytes"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
)
func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error {
dir := ctx.GetPb()
cw := new(bytes.Buffer)
base := filepath.Dir(proto.Src)
cw.WriteString("protoc ")
for _, ip := range protoImportPath {
cw.WriteString(" -I=" + ip)
}
cw.WriteString(" -I=" + base)
cw.WriteString(" " + proto.Name)
if strings.Contains(proto.GoPackage, "/") {
cw.WriteString(" --go_out=plugins=grpc:" + ctx.GetInternal().Filename)
} else {
cw.WriteString(" --go_out=plugins=grpc:" + dir.Filename)
}
command := cw.String()
g.log.Debug(command)
_, err := execx.Run(command, "")
return err
}

View File

@ -0,0 +1,184 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateCaseNilImport(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
//_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
if err := g.Prepare(); err == nil {
targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go")
err = g.GenPb(dirCtx, nil, proto)
assert.Nil(t, err)
assert.True(t, func() bool {
return util.FileExists(targetPb)
}())
}
}
func TestGenerateCaseImport(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
if err := g.Prepare(); err == nil {
err = g.GenPb(dirCtx, nil, proto)
assert.Nil(t, err)
targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go")
assert.True(t, func() bool {
return util.FileExists(targetPb)
}())
}
}
func TestGenerateCasePathOption(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_option.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
if err := g.Prepare(); err == nil {
err = g.GenPb(dirCtx, nil, proto)
assert.Nil(t, err)
targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_option.pb.go")
assert.True(t, func() bool {
return util.FileExists(targetPb)
}())
}
}
func TestGenerateCaseWordOption(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_word_option.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
if err := g.Prepare(); err == nil {
err = g.GenPb(dirCtx, nil, proto)
assert.Nil(t, err)
targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_word_option.pb.go")
assert.True(t, func() bool {
return util.FileExists(targetPb)
}())
}
}
// test keyword go
func TestGenerateCaseGoOption(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_go_option.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
if err := g.Prepare(); err == nil {
err = g.GenPb(dirCtx, nil, proto)
assert.Nil(t, err)
targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_go_option.pb.go")
assert.True(t, func() bool {
return util.FileExists(targetPb)
}())
}
}

View File

@ -0,0 +1,102 @@
package generator
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
serverTemplate = `{{.head}}
package server
import (
"context"
{{.imports}}
)
type {{.server}}Server struct {
svcCtx *svc.ServiceContext
}
func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
return &{{.server}}Server{
svcCtx: svcCtx,
}
}
{{.funcs}}
`
functionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
l := logic.New{{.logicName}}(ctx,s.svcCtx)
return l.{{.method}}(in)
}
`
)
func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error {
dir := ctx.GetServer()
logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package)
svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
imports := collection.NewSet()
imports.AddStr(logicImport, svcImport, pbImport)
head := util.GetHead(proto.Name)
service := proto.Service
serverFile := filepath.Join(dir.Filename, formatFilename(service.Name+"_server")+".go")
funcList, err := g.genFunctions(proto.PbPackage, service)
if err != nil {
return err
}
text, err := util.LoadTemplate(category, serverTemplateFile, serverTemplate)
if err != nil {
return err
}
err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head,
"server": stringx.From(service.Name).Title(),
"imports": strings.Join(imports.KeysStr(), util.NL),
"funcs": strings.Join(funcList, util.NL),
}, serverFile, true)
return err
}
func (g *defaultGenerator) genFunctions(goPackage string, service parser.Service) ([]string, error) {
var functionList []string
for _, rpc := range service.RPC {
text, err := util.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)
if err != nil {
return nil, err
}
comment := parser.GetComment(rpc.Doc())
buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
"server": stringx.From(service.Name).Title(),
"logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).Title()),
"method": parser.CamelCase(rpc.Name),
"request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
"response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
"hasComment": len(comment) > 0,
"comment": comment,
})
if err != nil {
return nil, err
}
functionList = append(functionList, buffer.String())
}
return functionList, nil
}

View File

@ -0,0 +1,45 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateServer(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenServer(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -1,9 +1,10 @@
package gen
package generator
import (
"fmt"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
@ -22,15 +23,15 @@ func NewServiceContext(c config.Config) *ServiceContext {
}
`
func (g *defaultRpcGenerator) genSvc() error {
svcPath := g.dirM[dirSvc]
fileName := filepath.Join(svcPath, fileServiceContext)
func (g *defaultGenerator) GenSvc(ctx DirContext, _ parser.Proto) error {
dir := ctx.GetSvc()
fileName := filepath.Join(dir.Filename, formatFilename("service_context")+".go")
text, err := util.LoadTemplate(category, svcTemplateFile, svcTemplate)
if err != nil {
return err
}
return util.With("svc").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"imports": fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig)),
"imports": fmt.Sprintf(`"%v"`, ctx.GetConfig().Package),
}, fileName, false)
}

View File

@ -0,0 +1,40 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateSvc(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.GenSvc(dirCtx, proto)
assert.Nil(t, err)
}

View File

@ -0,0 +1,152 @@
package generator
import (
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
wd = "wd"
etc = "etc"
internal = "internal"
config = "config"
logic = "logic"
server = "server"
svc = "svc"
pb = "pb"
call = "call"
)
type (
DirContext interface {
GetCall() Dir
GetEtc() Dir
GetInternal() Dir
GetConfig() Dir
GetLogic() Dir
GetServer() Dir
GetSvc() Dir
GetPb() Dir
GetMain() Dir
}
Dir struct {
Base string
Filename string
Package string
}
defaultDirContext struct {
inner map[string]Dir
}
)
func mkdir(ctx *ctx.ProjectContext, proto parser.Proto) (DirContext, error) {
inner := make(map[string]Dir)
etcDir := filepath.Join(ctx.WorkDir, "etc")
internalDir := filepath.Join(ctx.WorkDir, "internal")
configDir := filepath.Join(internalDir, "config")
logicDir := filepath.Join(internalDir, "logic")
serverDir := filepath.Join(internalDir, "server")
svcDir := filepath.Join(internalDir, "svc")
pbDir := filepath.Join(internalDir, proto.GoPackage)
callDir := filepath.Join(ctx.WorkDir, strings.ToLower(stringx.From(proto.Service.Name).ToCamel()))
inner[wd] = Dir{
Filename: ctx.WorkDir,
Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(ctx.WorkDir, ctx.Dir))),
Base: filepath.Base(ctx.WorkDir),
}
inner[etc] = Dir{
Filename: etcDir,
Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(etcDir, ctx.Dir))),
Base: filepath.Base(etcDir),
}
inner[internal] = Dir{
Filename: internalDir,
Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(internalDir, ctx.Dir))),
Base: filepath.Base(internalDir),
}
inner[config] = Dir{
Filename: configDir,
Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(configDir, ctx.Dir))),
Base: filepath.Base(configDir),
}
inner[logic] = Dir{
Filename: logicDir,
Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(logicDir, ctx.Dir))),
Base: filepath.Base(logicDir),
}
inner[server] = Dir{
Filename: serverDir,
Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(serverDir, ctx.Dir))),
Base: filepath.Base(serverDir),
}
inner[svc] = Dir{
Filename: svcDir,
Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(svcDir, ctx.Dir))),
Base: filepath.Base(svcDir),
}
inner[pb] = Dir{
Filename: pbDir,
Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(pbDir, ctx.Dir))),
Base: filepath.Base(pbDir),
}
inner[call] = Dir{
Filename: callDir,
Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(callDir, ctx.Dir))),
Base: filepath.Base(callDir),
}
for _, v := range inner {
err := util.MkdirIfNotExist(v.Filename)
if err != nil {
return nil, err
}
}
return &defaultDirContext{
inner: inner,
}, nil
}
func (d *defaultDirContext) GetCall() Dir {
return d.inner[call]
}
func (d *defaultDirContext) GetEtc() Dir {
return d.inner[etc]
}
func (d *defaultDirContext) GetInternal() Dir {
return d.inner[internal]
}
func (d *defaultDirContext) GetConfig() Dir {
return d.inner[config]
}
func (d *defaultDirContext) GetLogic() Dir {
return d.inner[logic]
}
func (d *defaultDirContext) GetServer() Dir {
return d.inner[server]
}
func (d *defaultDirContext) GetSvc() Dir {
return d.inner[svc]
}
func (d *defaultDirContext) GetPb() Dir {
return d.inner[pb]
}
func (d *defaultDirContext) GetMain() Dir {
return d.inner[wd]
}
func (d *Dir) Valid() bool {
return len(d.Filename) > 0 && len(d.Package) > 0
}

View File

@ -0,0 +1,130 @@
package generator
import (
"go/build"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestMkDirInGoPath(t *testing.T) {
dft := build.Default
gp := dft.GOPATH
if len(gp) == 0 {
return
}
projectName := stringx.Rand()
dir := filepath.Join(gp, "src", projectName)
err := util.MkdirIfNotExist(dir)
if err != nil {
return
}
defer func() {
_ = os.RemoveAll(dir)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
internal := filepath.Join(dir, "internal")
assert.True(t, true, func() bool {
return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package
}())
assert.True(t, true, func() bool {
return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package
}())
assert.True(t, true, func() bool {
return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package
}())
}
func TestMkDirInGoMod(t *testing.T) {
dft := build.Default
gp := dft.GOPATH
if len(gp) == 0 {
return
}
projectName := stringx.Rand()
dir := filepath.Join(gp, "src", projectName)
err := util.MkdirIfNotExist(dir)
if err != nil {
return
}
_, err = execx.Run("go mod init "+projectName, dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(dir)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
internal := filepath.Join(dir, "internal")
assert.True(t, true, func() bool {
return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package
}())
assert.True(t, true, func() bool {
return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package
}())
assert.True(t, true, func() bool {
return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package
}())
assert.True(t, true, func() bool {
return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package
}())
}

View File

@ -1,12 +1,10 @@
package gen
package generator
import (
"path/filepath"
"strings"
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
@ -27,33 +25,23 @@ service {{.serviceName}} {
}
`
type rpcTemplate struct {
out string
console.Console
}
func NewRpcTemplate(out string, idea bool) *rpcTemplate {
return &rpcTemplate{
out: out,
Console: console.NewConsole(idea),
}
}
func (r *rpcTemplate) MustGenerate(showState bool) {
r.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/zero-doc/blob/main/doc/goctl-rpc.md」")
r.Info("-> generating template...")
protoFilename := filepath.Base(r.out)
func ProtoTmpl(out string) error {
protoFilename := filepath.Base(out)
serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename)))
text, err := util.LoadTemplate(category, rpcTemplateFile, rpcTemplateText)
r.Must(err)
if err != nil {
return err
}
dir := filepath.Dir(out)
err = util.MkdirIfNotExist(dir)
if err != nil {
return err
}
err = util.With("t").Parse(text).SaveTo(map[string]string{
"package": serviceName.UnTitle(),
"serviceName": serviceName.Title(),
}, r.out, false)
r.Must(err)
if showState {
r.Success("Done.")
}
}, out, false)
return err
}

View File

@ -0,0 +1,21 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func TestProtoTmpl(t *testing.T) {
out, err := filepath.Abs("./test/test.proto")
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(filepath.Dir(out))
}()
err = ProtoTmpl(out)
assert.Nil(t, err)
_, err = os.Stat(out)
assert.Nil(t, err)
}

View File

@ -1,4 +1,4 @@
package gen
package generator
import (
"fmt"
@ -10,7 +10,6 @@ import (
const (
category = "rpc"
callTemplateFile = "call.tpl"
callTypesTemplateFile = "call-types.tpl"
callInterfaceFunctionTemplateFile = "call-interface-func.tpl"
callFunctionTemplateFile = "call-func.tpl"
configTemplateFileFile = "config.tpl"
@ -26,7 +25,6 @@ const (
var templates = map[string]string{
callTemplateFile: callTemplateText,
callTypesTemplateFile: callTemplateTypes,
callInterfaceFunctionTemplateFile: callInterfaceFunctionTemplate,
callFunctionTemplateFile: callFunctionTemplate,
configTemplateFileFile: configTemplate,

View File

@ -1,4 +1,4 @@
package gen
package generator
import (
"io/ioutil"
@ -90,3 +90,7 @@ func TestUpdate(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, mainTemplate, string(data))
}
func TestGetCategory(t *testing.T) {
assert.Equal(t, category, GetCategory())
}

View File

@ -0,0 +1,25 @@
// test proto
syntax = "proto3";
package test;
option go_package = "go";
import "test_base.proto";
message TestMessage{
base.CommonReq req = 1;
}
message TestReq{}
message TestReply{
base.CommonReply reply = 2;
}
enum TestEnum {
unknown = 0;
male = 1;
female = 2;
}
service TestService{
rpc TestRpc (TestReq)returns(TestReply);
}

View File

@ -0,0 +1,12 @@
// test proto
syntax = "proto3";
package base;
message CommonReq {
string in = 1;
}
message CommonReply {
string out = 1;
}

View File

@ -0,0 +1,18 @@
// test proto
syntax = "proto3";
package stream;
option go_package="go";
message StreamReq {
string name = 1;
}
message StreamResp {
string greet = 1;
}
service StreamGreeter {
rpc greet(StreamReq) returns (StreamResp);
}

View File

@ -0,0 +1,18 @@
// test proto
syntax = "proto3";
package greet;
import "base/common.proto";
message In {
string name = 1;
common.User user = 2;
}
message Out {
string greet = 1;
}
service StreamGreeter {
rpc greet(In) returns (Out);
}

View File

@ -0,0 +1,18 @@
// test proto
syntax = "proto3";
package stream;
option go_package="github.com/tal-tech/go-zero";
message StreamReq {
string name = 1;
}
message StreamResp {
string greet = 1;
}
service StreamGreeter {
rpc greet(StreamReq) returns (StreamResp);
}

View File

@ -0,0 +1,16 @@
// test proto
syntax = "proto3";
package stream;
message StreamReq {
string name = 1;
}
message StreamResp {
string greet = 1;
}
service StreamGreeter {
rpc greet(StreamReq) returns (StreamResp);
}

View File

@ -0,0 +1,18 @@
// test proto
syntax = "proto3";
package stream;
option go_package="user";
message StreamReq {
string name = 1;
}
message StreamResp {
string greet = 1;
}
service StreamGreeter {
rpc greet(StreamReq) returns (StreamResp);
}

View File

@ -0,0 +1,10 @@
package parser
import "github.com/emicklei/proto"
func GetComment(comment *proto.Comment) string {
if comment == nil {
return ""
}
return comment.Message()
}

View File

@ -0,0 +1,7 @@
package parser
import "github.com/emicklei/proto"
type Import struct {
*proto.Import
}

View File

@ -0,0 +1,7 @@
package parser
import pr "github.com/emicklei/proto"
type Message struct {
*pr.Message
}

View File

@ -0,0 +1,7 @@
package parser
import "github.com/emicklei/proto"
type Option struct {
*proto.Option
}

View File

@ -0,0 +1,7 @@
package parser
import "github.com/emicklei/proto"
type Package struct {
*proto.Package
}

View File

@ -1,46 +1,170 @@
package parser
import (
"errors"
"fmt"
"go/token"
"os"
"path/filepath"
"strings"
"unicode"
"unicode/utf8"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/emicklei/proto"
)
func Transfer(proto, target string, externalImport []*Import, console console.Console) (*PbAst, error) {
messageM := make(map[string]lang.PlaceholderType)
enumM := make(map[string]*Enum)
protoAst, err := parseProto(proto, messageM, enumM)
if err != nil {
return nil, err
}
for _, item := range externalImport {
err = checkImport(item.OriginalProtoPath)
if err != nil {
return nil, err
}
innerAst, err := parseProto(item.OriginalProtoPath, protoAst.Message, protoAst.Enum)
if err != nil {
return nil, err
}
for k, v := range innerAst.Message {
protoAst.Message[k] = v
}
for k, v := range innerAst.Enum {
protoAst.Enum[k] = v
}
}
protoAst.Import = externalImport
protoAst.PbSrc = filepath.Join(target, strings.TrimSuffix(filepath.Base(proto), ".proto")+".pb.go")
return transfer(protoAst, console)
type (
defaultProtoParser struct{}
)
func NewDefaultProtoParser() *defaultProtoParser {
return &defaultProtoParser{}
}
func transfer(proto *Proto, console console.Console) (*PbAst, error) {
parser := MustNewAstParser(proto, console)
parse, err := parser.Parse()
func (p *defaultProtoParser) Parse(src string) (Proto, error) {
var ret Proto
abs, err := filepath.Abs(src)
if err != nil {
return nil, err
return Proto{}, err
}
return parse, nil
r, err := os.Open(abs)
if err != nil {
return ret, err
}
defer r.Close()
parser := proto.NewParser(r)
set, err := parser.Parse()
if err != nil {
return ret, err
}
var serviceList []Service
proto.Walk(
set,
proto.WithImport(func(i *proto.Import) {
ret.Import = append(ret.Import, Import{Import: i})
}),
proto.WithMessage(func(message *proto.Message) {
ret.Message = append(ret.Message, Message{Message: message})
}),
proto.WithPackage(func(p *proto.Package) {
ret.Package = Package{Package: p}
}),
proto.WithService(func(service *proto.Service) {
serv := Service{Service: service}
elements := service.Elements
for _, el := range elements {
v, _ := el.(*proto.RPC)
if v == nil {
continue
}
serv.RPC = append(serv.RPC, &RPC{RPC: v})
}
serviceList = append(serviceList, serv)
}),
proto.WithOption(func(option *proto.Option) {
if option.Name == "go_package" {
ret.GoPackage = option.Constant.Source
}
}),
)
if len(serviceList) == 0 {
return ret, errors.New("rpc service not found")
}
if len(serviceList) > 1 {
return ret, errors.New("only one service expected")
}
service := serviceList[0]
name := filepath.Base(abs)
for _, rpc := range service.RPC {
if strings.Contains(rpc.RequestType, ".") {
return ret, fmt.Errorf("line %v:%v, request type must defined in %s", rpc.Position.Line, rpc.Position.Column, name)
}
if strings.Contains(rpc.ReturnsType, ".") {
return ret, fmt.Errorf("line %v:%v, returns type must defined in %s", rpc.Position.Line, rpc.Position.Column, name)
}
}
if len(ret.GoPackage) == 0 {
ret.GoPackage = ret.Package.Name
}
ret.PbPackage = GoSanitized(filepath.Base(ret.GoPackage))
ret.Src = abs
ret.Name = name
ret.Service = service
return ret, nil
}
// see google.golang.org/protobuf@v1.25.0/internal/strs/strings.go:71
func GoSanitized(s string) string {
// Sanitize the input to the set of valid characters,
// which must be '_' or be in the Unicode L or N categories.
s = strings.Map(func(r rune) rune {
if unicode.IsLetter(r) || unicode.IsDigit(r) {
return r
}
return '_'
}, s)
// Prepend '_' in the event of a Go keyword conflict or if
// the identifier is invalid (does not start in the Unicode L category).
r, _ := utf8.DecodeRuneInString(s)
if token.Lookup(s).IsKeyword() || !unicode.IsLetter(r) {
return "_" + s
}
return s
}
// copy from github.com/golang/protobuf@v1.4.2/protoc-gen-go/generator/generator.go:2648
func CamelCase(s string) string {
if s == "" {
return ""
}
t := make([]byte, 0, 32)
i := 0
if s[0] == '_' {
// Need a capital letter; drop the '_'.
t = append(t, 'X')
i++
}
// Invariant: if the next letter is lower case, it must be converted
// to upper case.
// That is, we process a word at a time, where words are marked by _ or
// upper case letter. Digits are treated as words.
for ; i < len(s); i++ {
c := s[i]
if c == '_' && i+1 < len(s) && isASCIILower(s[i+1]) {
continue // Skip the underscore in s.
}
if isASCIIDigit(c) {
t = append(t, c)
continue
}
// Assume we have a letter now - if not, it's a bogus identifier.
// The next word is a sequence of characters that must start upper case.
if isASCIILower(c) {
c ^= ' ' // Make it a capital letter.
}
t = append(t, c) // Guaranteed not lower case.
// Accept lower case sequence that follows.
for i+1 < len(s) && isASCIILower(s[i+1]) {
i++
t = append(t, s[i])
}
}
return string(t)
}
func isASCIILower(c byte) bool {
return 'a' <= c && c <= 'z'
}
// Is c an ASCII digit?
func isASCIIDigit(c byte) bool {
return '0' <= c && c <= '9'
}

View File

@ -0,0 +1,78 @@
package parser
import (
"sort"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestDefaultProtoParse(t *testing.T) {
p := NewDefaultProtoParser()
data, err := p.Parse("./test.proto")
assert.Nil(t, err)
assert.Equal(t, "base.proto", func() string {
ip := data.Import[0]
return ip.Filename
}())
assert.Equal(t, "test", data.Package.Name)
assert.Equal(t, true, data.GoPackage == "go")
assert.Equal(t, true, data.PbPackage == "_go")
assert.Equal(t, []string{"TestMessage", "TestReply", "TestReq"}, func() []string {
var list []string
for _, item := range data.Message {
list = append(list, item.Name)
}
sort.Strings(list)
return list
}())
assert.Equal(t, true, func() bool {
s := data.Service
if s.Name != "TestService" {
return false
}
rpcOne := s.RPC[0]
return rpcOne.Name == "TestRpcOne" && rpcOne.RequestType == "TestReq" && rpcOne.ReturnsType == "TestReply"
}())
}
func TestDefaultProtoParseCaseInvalidRequestType(t *testing.T) {
p := NewDefaultProtoParser()
_, err := p.Parse("./test_invalid_request.proto")
assert.True(t, true, func() bool {
return strings.Contains(err.Error(), "request type must defined in")
}())
}
func TestDefaultProtoParseCaseInvalidResponseType(t *testing.T) {
p := NewDefaultProtoParser()
_, err := p.Parse("./test_invalid_response.proto")
assert.True(t, true, func() bool {
return strings.Contains(err.Error(), "response type must defined in")
}())
}
func TestDefaultProtoParseError(t *testing.T) {
p := NewDefaultProtoParser()
_, err := p.Parse("./nil.proto")
assert.NotNil(t, err)
}
func TestDefaultProtoParse_Option(t *testing.T) {
p := NewDefaultProtoParser()
data, err := p.Parse("./test_option.proto")
assert.Nil(t, err)
assert.Equal(t, "github.com/tal-tech/go-zero", data.GoPackage)
assert.Equal(t, "go_zero", data.PbPackage)
}
func TestDefaultProtoParse_Option2(t *testing.T) {
p := NewDefaultProtoParser()
data, err := p.Parse("./test_option2.proto")
assert.Nil(t, err)
assert.Equal(t, "stream", data.GoPackage)
assert.Equal(t, "stream", data.PbPackage)
}

View File

@ -1,643 +0,0 @@
package parser
import (
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
"sort"
"strings"
"github.com/tal-tech/go-zero/core/lang"
sx "github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
flagStar = "*"
flagDot = "."
suffixServer = "Server"
referenceContext = "context"
unknownPrefix = "XXX_"
ignoreJsonTagExpression = `json:"-"`
)
var (
errorParseError = errors.New("pb parse error")
typeTemplate = `type (
{{.types}}
)`
structTemplate = `{{if .type}}type {{end}}{{.name}} struct {
{{.fields}}
}`
fieldTemplate = `{{if .hasDoc}}{{.doc}}
{{end}}{{.name}} {{.type}} {{.tag}}{{if .hasComment}}{{.comment}}{{end}}`
anyTypeTemplate = "Any struct {\n\tTypeUrl string `json:\"typeUrl\"`\n\tValue []byte `json:\"value\"`\n}"
objectM = make(map[string]*Struct)
)
type (
astParser struct {
filterStruct map[string]lang.PlaceholderType
filterEnum map[string]*Enum
console.Console
fileSet *token.FileSet
proto *Proto
}
Field struct {
Name stringx.String
Type Type
JsonTag string
Document []string
Comment []string
}
Struct struct {
Name stringx.String
Document []string
Comment []string
Field []*Field
}
ConstLit struct {
Name stringx.String
Document []string
Comment []string
Lit []*Lit
}
Lit struct {
Key string
Value int
}
Type struct {
// eg:context.Context
Expression string
// eg: *context.Context
StarExpression string
// Invoke Type Expression
InvokeTypeExpression string
// eg:context
Package string
// eg:Context
Name string
}
Func struct {
Name stringx.String
ParameterIn Type
ParameterOut Type
Document []string
}
RpcService struct {
Name stringx.String
Funcs []*Func
}
// parsing for rpc
PbAst struct {
// deprecated: containsAny will be removed in the feature
ContainsAny bool
Imports map[string]string
Structure map[string]*Struct
Service []*RpcService
*Proto
}
)
func MustNewAstParser(proto *Proto, log console.Console) *astParser {
return &astParser{
filterStruct: proto.Message,
filterEnum: proto.Enum,
Console: log,
fileSet: token.NewFileSet(),
proto: proto,
}
}
func (a *astParser) Parse() (*PbAst, error) {
var pbAst PbAst
pbAst.ContainsAny = a.proto.ContainsAny
pbAst.Proto = a.proto
pbAst.Structure = make(map[string]*Struct)
pbAst.Imports = make(map[string]string)
structure, imports, services, err := a.parse(a.proto.PbSrc)
if err != nil {
return nil, err
}
dependencyStructure, err := a.parseExternalDependency()
if err != nil {
return nil, err
}
for k, v := range structure {
pbAst.Structure[k] = v
}
for k, v := range dependencyStructure {
pbAst.Structure[k] = v
}
for key, path := range imports {
pbAst.Imports[key] = path
}
pbAst.Service = append(pbAst.Service, services...)
return &pbAst, nil
}
func (a *astParser) parse(pbSrc string) (structure map[string]*Struct, imports map[string]string, services []*RpcService, retErr error) {
structure = make(map[string]*Struct)
imports = make(map[string]string)
data, err := ioutil.ReadFile(pbSrc)
if err != nil {
retErr = err
return
}
fSet := a.fileSet
f, err := parser.ParseFile(fSet, "", data, parser.ParseComments)
if err != nil {
retErr = err
return
}
commentMap := ast.NewCommentMap(fSet, f, f.Comments)
f.Comments = commentMap.Filter(f).Comments()
strucs, function := a.mustScope(f.Scope, a.mustGetIndentName(f.Name))
for k, v := range strucs {
if v == nil {
continue
}
structure[k] = v
}
importList := f.Imports
for _, item := range importList {
name := a.mustGetIndentName(item.Name)
if item.Path != nil {
imports[name] = item.Path.Value
}
}
services = append(services, function...)
return
}
func (a *astParser) parseExternalDependency() (map[string]*Struct, error) {
m := make(map[string]*Struct)
for _, impo := range a.proto.Import {
ret, _, _, err := a.parse(impo.OriginalPbPath)
if err != nil {
return nil, err
}
for k, v := range ret {
m[k] = v
}
}
return m, nil
}
func (a *astParser) mustScope(scope *ast.Scope, sourcePackage string) (map[string]*Struct, []*RpcService) {
if scope == nil {
return nil, nil
}
objects := scope.Objects
structs := make(map[string]*Struct)
serviceList := make([]*RpcService, 0)
for name, obj := range objects {
decl := obj.Decl
if decl == nil {
continue
}
typeSpec, ok := decl.(*ast.TypeSpec)
if !ok {
continue
}
tp := typeSpec.Type
switch v := tp.(type) {
case *ast.StructType:
st, err := a.parseObject(name, v, sourcePackage)
a.Must(err)
structs[st.Name.Lower()] = st
case *ast.InterfaceType:
if !strings.HasSuffix(name, suffixServer) {
continue
}
list := a.mustServerFunctions(v, sourcePackage)
serviceList = append(serviceList, &RpcService{
Name: stringx.From(strings.TrimSuffix(name, suffixServer)),
Funcs: list,
})
}
}
targetStruct := make(map[string]*Struct)
for st := range a.filterStruct {
lower := strings.ToLower(st)
targetStruct[lower] = structs[lower]
}
return targetStruct, serviceList
}
func (a *astParser) mustServerFunctions(v *ast.InterfaceType, sourcePackage string) []*Func {
funcs := make([]*Func, 0)
methodObject := v.Methods
if methodObject == nil {
return nil
}
for _, method := range methodObject.List {
var item Func
name := a.mustGetIndentName(method.Names[0])
doc := a.parseCommentOrDoc(method.Doc)
item.Name = stringx.From(name)
item.Document = doc
types := method.Type
if types == nil {
funcs = append(funcs, &item)
continue
}
v, ok := types.(*ast.FuncType)
if !ok {
continue
}
params := v.Params
if params != nil {
inList, err := a.parseFields(params.List, true, sourcePackage)
a.Must(err)
for _, data := range inList {
if data.Type.Package == referenceContext {
continue
}
item.ParameterIn = data.Type
break
}
}
results := v.Results
if results != nil {
outList, err := a.parseFields(results.List, true, sourcePackage)
a.Must(err)
for _, data := range outList {
if data.Type.Package == referenceContext {
continue
}
item.ParameterOut = data.Type
break
}
}
funcs = append(funcs, &item)
}
return funcs
}
func (a *astParser) getFieldType(v string, sourcePackage string) Type {
var pkg, name, expression, starExpression, invokeTypeExpression string
if strings.Contains(v, ".") {
starExpression = v
if strings.Contains(v, "*") {
leftIndex := strings.Index(v, "*")
rightIndex := strings.Index(v, ".")
if leftIndex >= 0 {
invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
} else {
invokeTypeExpression = v[rightIndex+1:]
}
} else {
if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
leftIndex := strings.Index(v, "]")
rightIndex := strings.Index(v, ".")
invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
} else {
rightIndex := strings.Index(v, ".")
invokeTypeExpression = v[rightIndex+1:]
}
}
} else {
expression = strings.TrimPrefix(v, flagStar)
switch v {
case "double", "float", "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32", "sfixed64",
"bool", "string", "bytes":
invokeTypeExpression = v
break
default:
name = expression
invokeTypeExpression = v
if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
starExpression = strings.ReplaceAll(v, flagStar, flagStar+sourcePackage+".")
} else {
starExpression = fmt.Sprintf("*%v.%v", sourcePackage, name)
invokeTypeExpression = v
}
}
}
expression = strings.TrimPrefix(starExpression, flagStar)
index := strings.LastIndex(expression, flagDot)
if index > 0 {
pkg = expression[0:index]
name = expression[index+1:]
} else {
pkg = sourcePackage
}
return Type{
Expression: expression,
StarExpression: starExpression,
InvokeTypeExpression: invokeTypeExpression,
Package: pkg,
Name: name,
}
}
func (a *astParser) parseObject(structName string, tp *ast.StructType, sourcePackage string) (*Struct, error) {
if data, ok := objectM[structName]; ok {
return data, nil
}
var st Struct
st.Name = stringx.From(structName)
if tp == nil {
return &st, nil
}
fields := tp.Fields
if fields == nil {
objectM[structName] = &st
return &st, nil
}
fieldList := fields.List
members, err := a.parseFields(fieldList, false, sourcePackage)
if err != nil {
return nil, err
}
for _, m := range members {
var field Field
field.Name = m.Name
field.Type = m.Type
field.JsonTag = m.JsonTag
field.Document = m.Document
field.Comment = m.Comment
st.Field = append(st.Field, &field)
}
objectM[structName] = &st
return &st, nil
}
func (a *astParser) parseFields(fields []*ast.Field, onlyType bool, sourcePackage string) ([]*Field, error) {
ret := make([]*Field, 0)
for _, field := range fields {
var item Field
tag := a.parseTag(field.Tag)
if tag == "" && !onlyType {
continue
}
if tag == ignoreJsonTagExpression {
continue
}
item.JsonTag = tag
name := a.parseName(field.Names)
if strings.HasPrefix(name, unknownPrefix) {
continue
}
item.Name = stringx.From(name)
typeName, err := a.parseType(field.Type)
if err != nil {
return nil, err
}
item.Type = a.getFieldType(typeName, sourcePackage)
if onlyType {
ret = append(ret, &item)
continue
}
docs := a.parseCommentOrDoc(field.Doc)
comments := a.parseCommentOrDoc(field.Comment)
item.Document = docs
item.Comment = comments
isInline := name == ""
if isInline {
return nil, a.wrapError(field.Pos(), "unexpected inline type:%s", name)
}
ret = append(ret, &item)
}
return ret, nil
}
func (a *astParser) parseTag(basicLit *ast.BasicLit) string {
if basicLit == nil {
return ""
}
value := basicLit.Value
splits := strings.Split(value, " ")
if len(splits) == 1 {
return fmt.Sprintf("`%s`", strings.ReplaceAll(splits[0], "`", ""))
} else {
return fmt.Sprintf("`%s`", strings.ReplaceAll(splits[1], "`", ""))
}
}
// returns
// resp1:type's string expression,like int、string、[]int64、map[string]User、*User
// resp2:error
func (a *astParser) parseType(expr ast.Expr) (string, error) {
if expr == nil {
return "", errorParseError
}
switch v := expr.(type) {
case *ast.StarExpr:
stringExpr, err := a.parseType(v.X)
if err != nil {
return "", err
}
e := fmt.Sprintf("*%s", stringExpr)
return e, nil
case *ast.Ident:
return a.mustGetIndentName(v), nil
case *ast.MapType:
keyStringExpr, err := a.parseType(v.Key)
if err != nil {
return "", err
}
valueStringExpr, err := a.parseType(v.Value)
if err != nil {
return "", err
}
e := fmt.Sprintf("map[%s]%s", keyStringExpr, valueStringExpr)
return e, nil
case *ast.ArrayType:
stringExpr, err := a.parseType(v.Elt)
if err != nil {
return "", err
}
e := fmt.Sprintf("[]%s", stringExpr)
return e, nil
case *ast.InterfaceType:
return "interface{}", nil
case *ast.SelectorExpr:
join := make([]string, 0)
xIdent, ok := v.X.(*ast.Ident)
xIndentName := a.mustGetIndentName(xIdent)
if ok {
join = append(join, xIndentName)
}
sel := v.Sel
join = append(join, a.mustGetIndentName(sel))
return strings.Join(join, "."), nil
case *ast.ChanType:
return "", a.wrapError(v.Pos(), "unexpected type 'chan'")
case *ast.FuncType:
return "", a.wrapError(v.Pos(), "unexpected type 'func'")
case *ast.StructType:
return "", a.wrapError(v.Pos(), "unexpected inline struct type")
default:
return "", a.wrapError(v.Pos(), "unexpected type '%v'", v)
}
}
func (a *astParser) parseName(names []*ast.Ident) string {
if len(names) == 0 {
return ""
}
name := names[0]
return a.mustGetIndentName(name)
}
func (a *astParser) parseCommentOrDoc(cg *ast.CommentGroup) []string {
if cg == nil {
return nil
}
comments := make([]string, 0)
for _, comment := range cg.List {
if comment == nil {
continue
}
text := strings.TrimSpace(comment.Text)
if text == "" {
continue
}
comments = append(comments, text)
}
return comments
}
func (a *astParser) mustGetIndentName(ident *ast.Ident) string {
if ident == nil {
return ""
}
return ident.Name
}
func (a *astParser) wrapError(pos token.Pos, format string, arg ...interface{}) error {
file := a.fileSet.Position(pos)
return fmt.Errorf("line %v: %s", file.Line, fmt.Sprintf(format, arg...))
}
func (f *Func) GetDoc() string {
return strings.Join(f.Document, util.NL)
}
func (f *Func) HaveDoc() bool {
return len(f.Document) > 0
}
func (a *PbAst) GenEnumCode() (string, error) {
var element []string
for _, item := range a.Enum {
code, err := item.GenEnumCode()
if err != nil {
return "", err
}
element = append(element, code)
}
return strings.Join(element, util.NL), nil
}
func (a *PbAst) GenTypesCode() (string, error) {
types := make([]string, 0)
sts := make([]*Struct, 0)
for _, item := range a.Structure {
sts = append(sts, item)
}
sort.Slice(sts, func(i, j int) bool {
return sts[i].Name.Source() < sts[j].Name.Source()
})
for _, s := range sts {
structCode, err := s.genCode(false)
if err != nil {
return "", err
}
if structCode == "" {
continue
}
types = append(types, structCode)
}
types = append(types, a.genAnyCode())
for _, item := range a.Enum {
typeCode, err := item.GenEnumTypeCode()
if err != nil {
return "", err
}
types = append(types, typeCode)
}
buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{
"types": strings.Join(types, util.NL+util.NL),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}
func (a *PbAst) genAnyCode() string {
if !a.ContainsAny {
return ""
}
return anyTypeTemplate
}
func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
fields := make([]string, 0)
for _, f := range s.Field {
var comment, doc string
if len(f.Comment) > 0 {
comment = f.Comment[0]
}
doc = strings.Join(f.Document, util.NL)
buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{
"name": f.Name.Title(),
"type": f.Type.InvokeTypeExpression,
"tag": f.JsonTag,
"hasDoc": len(f.Document) > 0,
"doc": doc,
"hasComment": len(f.Comment) > 0,
"comment": comment,
})
if err != nil {
return "", err
}
fields = append(fields, buffer.String())
}
buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{
"type": containsTypeStatement,
"name": s.Name.Title(),
"fields": strings.Join(fields, util.NL),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}

View File

@ -1,295 +1,12 @@
package parser
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/emicklei/proto"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
AnyImport = "google/protobuf/any.proto"
)
var (
enumTypeTemplate = `{{.name}} int32`
enumTemplate = `const (
{{.element}}
)`
enumFiledTemplate = `{{.key}} {{.name}} = {{.value}}`
)
type (
MessageField struct {
Type string
Name stringx.String
}
Message struct {
Name stringx.String
Element []*MessageField
*proto.Message
}
Enum struct {
Name stringx.String
Element []*EnumField
*proto.Enum
}
EnumField struct {
Key string
Value int
}
Proto struct {
Package string
Import []*Import
PbSrc string
// deprecated: containsAny will be removed in the feature
ContainsAny bool
Message map[string]lang.PlaceholderType
Enum map[string]*Enum
}
Import struct {
ProtoImportName string
PbImportName string
OriginalDir string
OriginalProtoPath string
OriginalPbPath string
BridgeImport string
exists bool
//xx.proto
protoName string
// xx.pb.go
pbName string
}
)
func checkImport(src string) error {
r, err := os.Open(src)
if err != nil {
return err
}
defer r.Close()
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return err
}
var base = filepath.Base(src)
proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) {
if err != nil {
return
}
err = fmt.Errorf("%v:%v the external proto cannot import other proto files", base, i.Position.Line)
}))
if err != nil {
return err
}
return nil
}
func ParseImport(src string) ([]*Import, bool, error) {
bridgeImportM := make(map[string]string)
r, err := os.Open(src)
if err != nil {
return nil, false, err
}
defer r.Close()
workDir := filepath.Dir(src)
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return nil, false, err
}
protoImportSet := collection.NewSet()
var containsAny bool
proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) {
if i.Filename == AnyImport {
containsAny = true
return
}
protoImportSet.AddStr(i.Filename)
if i.Comment != nil {
lines := i.Comment.Lines
for _, line := range lines {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "@") {
continue
}
line = strings.TrimPrefix(line, "@")
bridgeImportM[i.Filename] = line
}
}
}))
var importList []*Import
for _, item := range protoImportSet.KeysStr() {
pb := strings.TrimSuffix(filepath.Base(item), filepath.Ext(item)) + ".pb.go"
var pbImportName, brideImport string
if v, ok := bridgeImportM[item]; ok {
pbImportName = v
brideImport = "M" + item + "=" + v
} else {
pbImportName = item
}
var impo = Import{
ProtoImportName: item,
PbImportName: pbImportName,
BridgeImport: brideImport,
}
protoSource := filepath.Join(workDir, item)
pbSource := filepath.Join(filepath.Dir(protoSource), pb)
if util.FileExists(protoSource) && util.FileExists(pbSource) {
impo.OriginalProtoPath = protoSource
impo.OriginalPbPath = pbSource
impo.OriginalDir = filepath.Dir(protoSource)
impo.exists = true
impo.protoName = filepath.Base(item)
impo.pbName = pb
} else {
return nil, false, fmt.Errorf("「%v」: import must be found in the relative directory of 「%v」", item, filepath.Base(src))
}
importList = append(importList, &impo)
}
return importList, containsAny, nil
}
func parseProto(src string, messageM map[string]lang.PlaceholderType, enumM map[string]*Enum) (*Proto, error) {
if !filepath.IsAbs(src) {
return nil, fmt.Errorf("expected absolute path,but found: %v", src)
}
r, err := os.Open(src)
if err != nil {
return nil, err
}
defer r.Close()
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return nil, err
}
// xx.proto
fileBase := filepath.Base(src)
var resp Proto
proto.Walk(parseRet, proto.WithPackage(func(p *proto.Package) {
if err != nil {
return
}
if len(resp.Package) != 0 {
err = fmt.Errorf("%v:%v duplicate package「%v」", fileBase, p.Position.Line, p.Name)
}
if len(p.Name) == 0 {
err = errors.New("package not found")
}
resp.Package = p.Name
}), proto.WithMessage(func(message *proto.Message) {
if err != nil {
return
}
for _, item := range message.Elements {
switch item.(type) {
case *proto.NormalField, *proto.MapField, *proto.Comment:
continue
default:
err = fmt.Errorf("%v: unsupport inline declaration", fileBase)
return
}
}
name := stringx.From(message.Name)
if _, ok := messageM[name.Lower()]; ok {
err = fmt.Errorf("%v:%v duplicate message 「%v」", fileBase, message.Position.Line, message.Name)
return
}
messageM[name.Lower()] = lang.Placeholder
}), proto.WithEnum(func(enum *proto.Enum) {
if err != nil {
return
}
var node Enum
node.Enum = enum
node.Name = stringx.From(enum.Name)
for _, item := range enum.Elements {
v, ok := item.(*proto.EnumField)
if !ok {
continue
}
node.Element = append(node.Element, &EnumField{
Key: v.Name,
Value: v.Integer,
})
}
if _, ok := enumM[node.Name.Lower()]; ok {
err = fmt.Errorf("%v:%v duplicate enum 「%v」", fileBase, node.Position.Line, node.Name.Source())
return
}
lower := stringx.From(enum.Name).Lower()
enumM[lower] = &node
}))
if err != nil {
return nil, err
}
resp.Message = messageM
resp.Enum = enumM
return &resp, nil
}
func (e *Enum) GenEnumCode() (string, error) {
var element []string
for _, item := range e.Element {
code, err := item.GenEnumFieldCode(e.Name.Source())
if err != nil {
return "", err
}
element = append(element, code)
}
buffer, err := util.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{
"element": strings.Join(element, util.NL),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}
func (e *Enum) GenEnumTypeCode() (string, error) {
buffer, err := util.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{
"name": e.Name.Source(),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}
func (e *EnumField) GenEnumFieldCode(parentName string) (string, error) {
buffer, err := util.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{
"key": e.Key,
"name": parentName,
"value": e.Value,
})
if err != nil {
return "", err
}
return buffer.String(), nil
type Proto struct {
Src string
Name string
Package Package
PbPackage string
GoPackage string
Import []Import
Message []Message
Service Service
}

View File

@ -0,0 +1,7 @@
package parser
import "github.com/emicklei/proto"
type RPC struct {
*proto.RPC
}

View File

@ -0,0 +1,8 @@
package parser
import "github.com/emicklei/proto"
type Service struct {
*proto.Service
RPC []*RPC
}

View File

@ -0,0 +1,20 @@
syntax = "proto3";
package test;
option go_package = "go";
import "base.proto";
message TestMessage{}
message TestReq{}
message TestReply{}
enum TestEnum {
unknown = 0;
male = 1;
female = 2;
}
service TestService{
rpc TestRpcOne (TestReq)returns(TestReply);
}

View File

@ -0,0 +1,13 @@
syntax = "proto3";
package test;
option go_package = "go";
import "base.proto";
message Reply{}
service TestService{
rpc TestRpcTwo (base.Req)returns(Reply);
}

View File

@ -0,0 +1,13 @@
syntax = "proto3";
package test;
option go_package = "go";
import "base.proto";
message Req{}
service TestService{
rpc TestRpcTwo (Req)returns(base.Reply);
}

View File

@ -0,0 +1,10 @@
syntax = "proto3";
package stream;
option go_package="github.com/tal-tech/go-zero";
message placeholder{}
service greet{
rpc hello(placeholder)returns(placeholder);
}

View File

@ -0,0 +1,9 @@
syntax = "proto3";
package stream;
message placeholder{}
service greet{
rpc hello(placeholder)returns(placeholder);
}

View File

@ -1,28 +0,0 @@
syntax = "proto3";
// protoc -I=${GOPATH}/src -I=. test.proto --go_out=plugins=grpc,Mbase.proto=github.com/tal-tech/go-zero/tools/goctl/rpc:./test
package test;
// @github.com/tal-tech/go-zero/tools/goctl/rpc
import "base.proto";
import "google/protobuf/any.proto";
message request {
string name = 1;
}
enum Gender{
UNKNOWN = 0;
MALE = 1;
FEMALE = 2;
}
message response {
string greet = 1;
google.protobuf.Any data = 2;
}
message map {
map<string, string> m = 1;
}
service Greeter {
rpc greet(request) returns (response);
rpc idRequest(base.IdRequest)returns(base.EmptyResponse);
}

View File

@ -1,331 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: test.proto
// protoc -I=${GOPATH}/src -I=. test.proto --go_out=plugins=grpc,Mbase.proto=github.com/tal-tech/go-zero/tools/goctl/rpc:./test
package test
import (
context "context"
fmt "fmt"
proto "github.com/golang/protobuf/proto"
rpc "github.com/tal-tech/go-zero/tools/goctl/rpc"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
anypb "google.golang.org/protobuf/types/known/anypb"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type Gender int32
const (
Gender_UNKNOWN Gender = 0
Gender_MALE Gender = 1
Gender_FEMALE Gender = 2
)
var Gender_name = map[int32]string{
0: "UNKNOWN",
1: "MALE",
2: "FEMALE",
}
var Gender_value = map[string]int32{
"UNKNOWN": 0,
"MALE": 1,
"FEMALE": 2,
}
func (x Gender) String() string {
return proto.EnumName(Gender_name, int32(x))
}
func (Gender) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_c161fcfdc0c3ff1e, []int{0}
}
type Request struct {
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Request) Reset() { *m = Request{} }
func (m *Request) String() string { return proto.CompactTextString(m) }
func (*Request) ProtoMessage() {}
func (*Request) Descriptor() ([]byte, []int) {
return fileDescriptor_c161fcfdc0c3ff1e, []int{0}
}
func (m *Request) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Request.Unmarshal(m, b)
}
func (m *Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Request.Marshal(b, m, deterministic)
}
func (m *Request) XXX_Merge(src proto.Message) {
xxx_messageInfo_Request.Merge(m, src)
}
func (m *Request) XXX_Size() int {
return xxx_messageInfo_Request.Size(m)
}
func (m *Request) XXX_DiscardUnknown() {
xxx_messageInfo_Request.DiscardUnknown(m)
}
var xxx_messageInfo_Request proto.InternalMessageInfo
func (m *Request) GetName() string {
if m != nil {
return m.Name
}
return ""
}
type Response struct {
Greet string `protobuf:"bytes,1,opt,name=greet,proto3" json:"greet,omitempty"`
Data *anypb.Any `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Response) Reset() { *m = Response{} }
func (m *Response) String() string { return proto.CompactTextString(m) }
func (*Response) ProtoMessage() {}
func (*Response) Descriptor() ([]byte, []int) {
return fileDescriptor_c161fcfdc0c3ff1e, []int{1}
}
func (m *Response) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Response.Unmarshal(m, b)
}
func (m *Response) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Response.Marshal(b, m, deterministic)
}
func (m *Response) XXX_Merge(src proto.Message) {
xxx_messageInfo_Response.Merge(m, src)
}
func (m *Response) XXX_Size() int {
return xxx_messageInfo_Response.Size(m)
}
func (m *Response) XXX_DiscardUnknown() {
xxx_messageInfo_Response.DiscardUnknown(m)
}
var xxx_messageInfo_Response proto.InternalMessageInfo
func (m *Response) GetGreet() string {
if m != nil {
return m.Greet
}
return ""
}
func (m *Response) GetData() *anypb.Any {
if m != nil {
return m.Data
}
return nil
}
type Map struct {
M map[string]string `protobuf:"bytes,1,rep,name=m,proto3" json:"m,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Map) Reset() { *m = Map{} }
func (m *Map) String() string { return proto.CompactTextString(m) }
func (*Map) ProtoMessage() {}
func (*Map) Descriptor() ([]byte, []int) {
return fileDescriptor_c161fcfdc0c3ff1e, []int{2}
}
func (m *Map) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Map.Unmarshal(m, b)
}
func (m *Map) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Map.Marshal(b, m, deterministic)
}
func (m *Map) XXX_Merge(src proto.Message) {
xxx_messageInfo_Map.Merge(m, src)
}
func (m *Map) XXX_Size() int {
return xxx_messageInfo_Map.Size(m)
}
func (m *Map) XXX_DiscardUnknown() {
xxx_messageInfo_Map.DiscardUnknown(m)
}
var xxx_messageInfo_Map proto.InternalMessageInfo
func (m *Map) GetM() map[string]string {
if m != nil {
return m.M
}
return nil
}
func init() {
proto.RegisterEnum("test.Gender", Gender_name, Gender_value)
proto.RegisterType((*Request)(nil), "test.request")
proto.RegisterType((*Response)(nil), "test.response")
proto.RegisterType((*Map)(nil), "test.map")
proto.RegisterMapType((map[string]string)(nil), "test.map.MEntry")
}
func init() { proto.RegisterFile("test.proto", fileDescriptor_c161fcfdc0c3ff1e) }
var fileDescriptor_c161fcfdc0c3ff1e = []byte{
// 301 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x34, 0x90, 0x4f, 0x4b, 0xc3, 0x40,
0x10, 0xc5, 0xdd, 0x36, 0xa6, 0xed, 0x14, 0x35, 0x8c, 0x3d, 0xd4, 0x80, 0x52, 0x7a, 0x90, 0xa0,
0xb0, 0xc5, 0xea, 0x41, 0xbc, 0xf5, 0x10, 0x8b, 0x7f, 0x5a, 0x21, 0x20, 0x1e, 0x3c, 0x6d, 0xc9,
0x58, 0xc4, 0xee, 0x26, 0x6e, 0xb6, 0x42, 0xbe, 0xbd, 0x64, 0x77, 0x73, 0x7b, 0xbf, 0xd9, 0x59,
0xde, 0x9b, 0x07, 0x60, 0xa8, 0x32, 0xbc, 0xd4, 0x85, 0x29, 0x30, 0x68, 0x74, 0x0c, 0x1b, 0x51,
0x91, 0x9b, 0xc4, 0x67, 0xdb, 0xa2, 0xd8, 0xee, 0x68, 0x66, 0x69, 0xb3, 0xff, 0x9a, 0x09, 0x55,
0xbb, 0xa7, 0xe9, 0x39, 0xf4, 0x34, 0xfd, 0xee, 0xa9, 0x32, 0x88, 0x10, 0x28, 0x21, 0x69, 0xcc,
0x26, 0x2c, 0x19, 0x64, 0x56, 0x4f, 0x9f, 0xa1, 0xaf, 0xa9, 0x2a, 0x0b, 0x55, 0x11, 0x8e, 0xe0,
0x70, 0xab, 0x89, 0x8c, 0x5f, 0x70, 0x80, 0x09, 0x04, 0xb9, 0x30, 0x62, 0xdc, 0x99, 0xb0, 0x64,
0x38, 0x1f, 0x71, 0x67, 0xc5, 0x5b, 0x2b, 0xbe, 0x50, 0x75, 0x66, 0x37, 0xa6, 0x9f, 0xd0, 0x95,
0xa2, 0xc4, 0x0b, 0x60, 0x72, 0xcc, 0x26, 0xdd, 0x64, 0x38, 0x8f, 0xb8, 0x8d, 0x2d, 0x45, 0xc9,
0x57, 0xa9, 0x32, 0xba, 0xce, 0x98, 0x8c, 0xef, 0x20, 0x74, 0x80, 0x11, 0x74, 0x7f, 0xa8, 0xf6,
0x76, 0x8d, 0x6c, 0x22, 0xfc, 0x89, 0xdd, 0x9e, 0xac, 0xdb, 0x20, 0x73, 0xf0, 0xd0, 0xb9, 0x67,
0x57, 0xd7, 0x10, 0x2e, 0x49, 0xe5, 0xa4, 0x71, 0x08, 0xbd, 0xf7, 0xf5, 0xcb, 0xfa, 0xed, 0x63,
0x1d, 0x1d, 0x60, 0x1f, 0x82, 0xd5, 0xe2, 0x35, 0x8d, 0x18, 0x02, 0x84, 0x8f, 0xa9, 0xd5, 0x9d,
0x79, 0x0e, 0xbd, 0x65, 0x13, 0x9e, 0x34, 0x5e, 0xfa, 0xa3, 0xf0, 0xc8, 0x65, 0xf1, 0x65, 0xc4,
0xc7, 0x2d, 0xfa, 0xe3, 0x6f, 0x60, 0xf0, 0x9d, 0x67, 0xbe, 0xa9, 0x13, 0x6e, 0xcb, 0x7d, 0x6a,
0x07, 0xf1, 0xa9, 0x1b, 0xa4, 0xb2, 0x34, 0x75, 0xe6, 0xbf, 0x6c, 0x42, 0xdb, 0xc1, 0xed, 0x7f,
0x00, 0x00, 0x00, 0xff, 0xff, 0x48, 0x7a, 0x3b, 0x55, 0x9c, 0x01, 0x00, 0x00,
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// GreeterClient is the client API for Greeter service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type GreeterClient interface {
Greet(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error)
IdRequest(ctx context.Context, in *rpc.IdRequest, opts ...grpc.CallOption) (*rpc.EmptyResponse, error)
}
type greeterClient struct {
cc *grpc.ClientConn
}
func NewGreeterClient(cc *grpc.ClientConn) GreeterClient {
return &greeterClient{cc}
}
func (c *greeterClient) Greet(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) {
out := new(Response)
err := c.cc.Invoke(ctx, "/test.Greeter/greet", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *greeterClient) IdRequest(ctx context.Context, in *rpc.IdRequest, opts ...grpc.CallOption) (*rpc.EmptyResponse, error) {
out := new(rpc.EmptyResponse)
err := c.cc.Invoke(ctx, "/test.Greeter/idRequest", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// GreeterServer is the server API for Greeter service.
type GreeterServer interface {
Greet(context.Context, *Request) (*Response, error)
IdRequest(context.Context, *rpc.IdRequest) (*rpc.EmptyResponse, error)
}
// UnimplementedGreeterServer can be embedded to have forward compatible implementations.
type UnimplementedGreeterServer struct {
}
func (*UnimplementedGreeterServer) Greet(ctx context.Context, req *Request) (*Response, error) {
return nil, status.Errorf(codes.Unimplemented, "method Greet not implemented")
}
func (*UnimplementedGreeterServer) IdRequest(ctx context.Context, req *rpc.IdRequest) (*rpc.EmptyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method IdRequest not implemented")
}
func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) {
s.RegisterService(&_Greeter_serviceDesc, srv)
}
func _Greeter_Greet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Request)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(GreeterServer).Greet(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/test.Greeter/Greet",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(GreeterServer).Greet(ctx, req.(*Request))
}
return interceptor(ctx, in, info, handler)
}
func _Greeter_IdRequest_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(rpc.IdRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(GreeterServer).IdRequest(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/test.Greeter/IdRequest",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(GreeterServer).IdRequest(ctx, req.(*rpc.IdRequest))
}
return interceptor(ctx, in, info, handler)
}
var _Greeter_serviceDesc = grpc.ServiceDesc{
ServiceName: "test.Greeter",
HandlerType: (*GreeterServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "greet",
Handler: _Greeter_Greet_Handler,
},
{
MethodName: "idRequest",
Handler: _Greeter_IdRequest_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "test.proto",
}

View File

@ -7,7 +7,7 @@ import (
"github.com/tal-tech/go-zero/core/errorx"
"github.com/tal-tech/go-zero/tools/goctl/api/gogen"
modelgen "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/gen"
rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli"
)

View File

@ -0,0 +1,53 @@
package ctx
import (
"errors"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
)
var moduleCheckErr = errors.New("the work directory must be found in the go mod or the $GOPATH")
type (
ProjectContext struct {
WorkDir string
// Name is the root name of the project
// eg: go-zero、greet
Name string
// Path identifies which module a project belongs to, which is module value if it's a go mod project,
// or else it is the root name of the project, eg: github.com/tal-tech/go-zero、greet
Path string
// Dir is the path of the project, eg: /Users/keson/goland/go/go-zero、/Users/keson/go/src/greet
Dir string
}
)
// Prepare checks the project which module belongs to,and returns the path and module.
// workDir parameter is the directory of the source of generating code,
// where can be found the project path and the project module,
func Prepare(workDir string) (*ProjectContext, error) {
ctx, err := background(workDir)
if err == nil {
return ctx, nil
}
name := filepath.Base(workDir)
_, err = execx.Run("go mod init "+name, workDir)
if err != nil {
return nil, err
}
return background(workDir)
}
func background(workDir string) (*ProjectContext, error) {
isGoMod, err := IsGoMod(workDir)
if err != nil {
return nil, err
}
if isGoMod {
return projectFromGoMod(workDir)
}
return projectFromGoPath(workDir)
}

View File

@ -0,0 +1,22 @@
package ctx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBackground(t *testing.T) {
workDir := "."
ctx, err := Prepare(workDir)
assert.Nil(t, err)
assert.True(t, true, func() bool {
return len(ctx.Dir) != 0 && len(ctx.Path) != 0
}())
}
func TestBackgroundNilWorkDir(t *testing.T) {
workDir := ""
_, err := Prepare(workDir)
assert.NotNil(t, err)
}

View File

@ -0,0 +1,47 @@
package ctx
import (
"errors"
"os"
"path/filepath"
"github.com/tal-tech/go-zero/core/jsonx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
)
type Module struct {
Path string
Main bool
Dir string
GoMod string
GoVersion string
}
// projectFromGoMod is used to find the go module and project file path
// the workDir flag specifies which folder we need to detect based on
// only valid for go mod project
func projectFromGoMod(workDir string) (*ProjectContext, error) {
if len(workDir) == 0 {
return nil, errors.New("the work directory is not found")
}
if _, err := os.Stat(workDir); err != nil {
return nil, err
}
data, err := execx.Run("go list -json -m", workDir)
if err != nil {
return nil, err
}
var m Module
err = jsonx.Unmarshal([]byte(data), &m)
if err != nil {
return nil, err
}
var ret ProjectContext
ret.WorkDir = workDir
ret.Name = filepath.Base(m.Dir)
ret.Dir = m.Dir
ret.Path = m.Path
return &ret, nil
}

View File

@ -0,0 +1,38 @@
package ctx
import (
"go/build"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
func TestProjectFromGoMod(t *testing.T) {
dft := build.Default
gp := dft.GOPATH
if len(gp) == 0 {
return
}
projectName := stringx.Rand()
dir := filepath.Join(gp, "src", projectName)
err := util.MkdirIfNotExist(dir)
if err != nil {
return
}
_, err = execx.Run("go mod init "+projectName, dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(dir)
}()
ctx, err := projectFromGoMod(dir)
assert.Nil(t, err)
assert.Equal(t, projectName, ctx.Path)
assert.Equal(t, dir, ctx.Dir)
}

View File

@ -0,0 +1,47 @@
package ctx
import (
"errors"
"go/build"
"os"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
// projectFromGoPath is used to find the main module and project file path
// the workDir flag specifies which folder we need to detect based on
// only valid for go mod project
func projectFromGoPath(workDir string) (*ProjectContext, error) {
if len(workDir) == 0 {
return nil, errors.New("the work directory is not found")
}
if _, err := os.Stat(workDir); err != nil {
return nil, err
}
buildContext := build.Default
goPath := buildContext.GOPATH
goSrc := filepath.Join(goPath, "src")
if !util.FileExists(goSrc) {
return nil, moduleCheckErr
}
wd, err := filepath.Abs(workDir)
if err != nil {
return nil, err
}
if !strings.HasPrefix(wd, goSrc) {
return nil, moduleCheckErr
}
projectName := strings.TrimPrefix(wd, goSrc+string(filepath.Separator))
return &ProjectContext{
WorkDir: workDir,
Name: projectName,
Path: projectName,
Dir: filepath.Join(goSrc, projectName),
}, nil
}

View File

@ -0,0 +1,54 @@
package ctx
import (
"go/build"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
func TestProjectFromGoPath(t *testing.T) {
dft := build.Default
gp := dft.GOPATH
if len(gp) == 0 {
return
}
projectName := stringx.Rand()
dir := filepath.Join(gp, "src", projectName)
err := util.MkdirIfNotExist(dir)
if err != nil {
return
}
defer func() {
_ = os.RemoveAll(dir)
}()
ctx, err := projectFromGoPath(dir)
assert.Nil(t, err)
assert.Equal(t, dir, ctx.Dir)
assert.Equal(t, projectName, ctx.Path)
}
func TestProjectFromGoPathNotInGoSrc(t *testing.T) {
dft := build.Default
gp := dft.GOPATH
if len(gp) == 0 {
return
}
projectName := stringx.Rand()
dir := filepath.Join(gp, "src", projectName)
err := util.MkdirIfNotExist(dir)
if err != nil {
return
}
defer func() {
_ = os.RemoveAll(dir)
}()
_, err = projectFromGoPath("testPath")
assert.NotNil(t, err)
}

View File

@ -0,0 +1,32 @@
package ctx
import (
"errors"
"os"
"github.com/tal-tech/go-zero/core/jsonx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
)
// IsGoMod is used to determine whether workDir is a go module project through command `go list -json -m`
func IsGoMod(workDir string) (bool, error) {
if len(workDir) == 0 {
return false, errors.New("the work directory is not found")
}
if _, err := os.Stat(workDir); err != nil {
return false, err
}
data, err := execx.Run("go list -json -m", workDir)
if err != nil {
return false, nil
}
var m Module
err = jsonx.Unmarshal([]byte(data), &m)
if err != nil {
return false, err
}
return len(m.GoMod) > 0, nil
}

View File

@ -0,0 +1,67 @@
package ctx
import (
"go/build"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
func TestIsGoMod(t *testing.T) {
// create mod project
dft := build.Default
gp := dft.GOPATH
if len(gp) == 0 {
return
}
projectName := stringx.Rand()
dir := filepath.Join(gp, "src", projectName)
err := util.MkdirIfNotExist(dir)
if err != nil {
return
}
_, err = execx.Run("go mod init "+projectName, dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(dir)
}()
isGoMod, err := IsGoMod(dir)
assert.Nil(t, err)
assert.Equal(t, true, isGoMod)
}
func TestIsGoModNot(t *testing.T) {
dft := build.Default
gp := dft.GOPATH
if len(gp) == 0 {
return
}
projectName := stringx.Rand()
dir := filepath.Join(gp, "src", projectName)
err := util.MkdirIfNotExist(dir)
if err != nil {
return
}
defer func() {
_ = os.RemoveAll(dir)
}()
isGoMod, err := IsGoMod(dir)
assert.Nil(t, err)
assert.Equal(t, false, isGoMod)
}
func TestIsGoModWorkDirIsNil(t *testing.T) {
_, err := IsGoMod("")
assert.Equal(t, err.Error(), func() string {
return "the work directory is not found"
}())
}

View File

@ -1,141 +0,0 @@
package project
import (
"errors"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
)
const (
constGo = "go"
constProtoC = "protoc"
constGoMod = "go env GOMOD"
constGoPath = "go env GOPATH"
constProtoCGenGo = "protoc-gen-go"
)
type (
Project struct {
Path string // Project path name
Name string // Project name
Package string // The service related package
// true-> project in go path or project init with go mod,or else->false
IsInGoEnv bool
GoMod GoMod
}
GoMod struct {
Module string // The gomod module name
Path string // The gomod related path
}
)
func Prepare(projectDir string, checkGrpcEnv bool) (*Project, error) {
_, err := exec.LookPath(constGo)
if err != nil {
return nil, fmt.Errorf("please install go first,reference documents:「https://golang.org/doc/install」")
}
if checkGrpcEnv {
_, err = exec.LookPath(constProtoC)
if err != nil {
return nil, fmt.Errorf("please install protoc first,reference documents:「https://github.com/golang/protobuf」")
}
_, err = exec.LookPath(constProtoCGenGo)
if err != nil {
return nil, fmt.Errorf("please install plugin protoc-gen-go first,reference documents:「https://github.com/golang/protobuf」")
}
}
var (
goMod, module string
goPath string
name, path string
pkg string
)
ret, err := execx.Run(constGoMod, projectDir)
if err != nil {
return nil, err
}
goMod = strings.TrimSpace(ret)
if goMod == os.DevNull {
goMod = ""
}
ret, err = execx.Run(constGoPath, "")
if err != nil {
return nil, err
}
goPath = strings.TrimSpace(ret)
src := filepath.Join(goPath, "src")
var isInGoEnv = true
if len(goMod) > 0 {
path = filepath.Dir(goMod)
name = filepath.Base(path)
data, err := ioutil.ReadFile(goMod)
if err != nil {
return nil, err
}
module, err = matchModule(data)
if err != nil {
return nil, err
}
} else {
pwd, err := filepath.Abs(projectDir)
if err != nil {
return nil, err
}
if !strings.HasPrefix(pwd, src) {
name = filepath.Clean(filepath.Base(pwd))
path = projectDir
pkg = name
isInGoEnv = false
} else {
r := strings.TrimPrefix(pwd, src+string(filepath.Separator))
name = filepath.Dir(r)
if name == "." {
name = r
}
path = filepath.Join(src, name)
pkg = r
}
module = name
}
return &Project{
Name: name,
Path: path,
Package: strings.ReplaceAll(pkg, `\`, "/"),
IsInGoEnv: isInGoEnv,
GoMod: GoMod{
Module: module,
Path: goMod,
},
}, nil
}
func matchModule(data []byte) (string, error) {
text := string(data)
re := regexp.MustCompile(`(?m)^\s*module\s+[a-z0-9_/\-.]+$`)
matches := re.FindAllString(text, -1)
if len(matches) == 1 {
target := matches[0]
index := strings.Index(target, "module")
return strings.TrimSpace(target[index+6:]), nil
}
return "", errors.New("module not matched")
}