diff --git a/tools/goctl/api/gogen/util.go b/tools/goctl/api/gogen/util.go index 6550f297..9423d7e6 100644 --- a/tools/goctl/api/gogen/util.go +++ b/tools/goctl/api/gogen/util.go @@ -10,29 +10,20 @@ import ( "github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/util" - goctlutil "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/project" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" ) func getParentPackage(dir string) (string, error) { - p, err := project.Prepare(dir, false) + abs, err := filepath.Abs(dir) if err != nil { return "", err } - if len(p.GoMod.Path) > 0 { - goModePath := filepath.Clean(filepath.Dir(p.GoMod.Path)) - absPath, err := filepath.Abs(dir) - if err != nil { - return "", err - } - parent := filepath.Clean(goctlutil.JoinPackages(p.GoMod.Module, absPath[len(goModePath):])) - parent = strings.ReplaceAll(parent, "\\", "/") - parent = strings.ReplaceAll(parent, `\`, "/") - return parent, nil + projectCtx, err := ctx.Prepare(abs) + if err != nil { + return "", err } - - return p.Package, nil + return filepath.ToSlash(filepath.Join(projectCtx.Path, strings.TrimPrefix(projectCtx.WorkDir, projectCtx.Dir))), nil } func writeIndent(writer io.Writer, indent int) { diff --git a/tools/goctl/goctl.go b/tools/goctl/goctl.go index cddc274e..b4e96f1f 100644 --- a/tools/goctl/goctl.go +++ b/tools/goctl/goctl.go @@ -19,7 +19,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/configgen" "github.com/tal-tech/go-zero/tools/goctl/docker" model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command" - rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/command" + rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/cli" "github.com/tal-tech/go-zero/tools/goctl/tpl" "github.com/urfave/cli" ) @@ -211,7 +211,7 @@ var ( Flags: []cli.Flag{ cli.BoolFlag{ Name: "idea", - Usage: "whether the command execution environment is from idea plugin. [option]", + Usage: "whether the command execution environment is from idea plugin. [optional]", }, }, Action: rpc.RpcNew, @@ -226,7 +226,7 @@ var ( }, cli.BoolFlag{ Name: "idea", - Usage: "whether the command execution environment is from idea plugin. [option]", + Usage: "whether the command execution environment is from idea plugin. [optional]", }, }, Action: rpc.RpcTemplate, @@ -239,17 +239,17 @@ var ( Name: "src, s", Usage: "the file path of the proto source file", }, - cli.StringFlag{ - Name: "dir, d", - Usage: `the target path of the code,default path is "${pwd}". [option]`, + cli.StringSliceFlag{ + Name: "proto_path, I", + Usage: `native command of protoc, specify the directory in which to search for imports. [optional]`, }, cli.StringFlag{ - Name: "service, srv", - Usage: `the name of rpc service. [option]`, + Name: "dir, d", + Usage: `the target path of the code`, }, cli.BoolFlag{ Name: "idea", - Usage: "whether the command execution environment is from idea plugin. [option]", + Usage: "whether the command execution environment is from idea plugin. [optional]", }, }, Action: rpc.Rpc, @@ -313,7 +313,7 @@ var ( }, cli.StringFlag{ Name: "style", - Usage: "the file naming style, lower|camel|underline,default is lower", + Usage: "the file naming style, lower|camel|snake, default is lower", }, cli.BoolFlag{ Name: "idea", diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go index b12ebada..fc377ff9 100644 --- a/tools/goctl/model/sql/command/command.go +++ b/tools/goctl/model/sql/command/command.go @@ -40,7 +40,7 @@ func MysqlDDL(ctx *cli.Context) error { } switch namingStyle { - case gen.NamingLower, gen.NamingCamel, gen.NamingUnderline: + case gen.NamingLower, gen.NamingCamel, gen.NamingSnake: case "": namingStyle = gen.NamingLower default: @@ -87,7 +87,7 @@ func MyDataSource(ctx *cli.Context) error { } switch namingStyle { - case gen.NamingLower, gen.NamingCamel, gen.NamingUnderline: + case gen.NamingLower, gen.NamingCamel, gen.NamingSnake: case "": namingStyle = gen.NamingLower default: diff --git a/tools/goctl/model/sql/converter/types.go b/tools/goctl/model/sql/converter/types.go index c6f93105..a1b1a175 100644 --- a/tools/goctl/model/sql/converter/types.go +++ b/tools/goctl/model/sql/converter/types.go @@ -8,30 +8,36 @@ import ( var ( commonMysqlDataTypeMap = map[string]string{ // For consistency, all integer types are converted to int64 - "tinyint": "int64", - "smallint": "int64", - "mediumint": "int64", - "int": "int64", - "integer": "int64", - "bigint": "int64", - "float": "float64", - "double": "float64", - "decimal": "float64", - "date": "time.Time", - "time": "string", - "year": "int64", - "datetime": "time.Time", - "timestamp": "time.Time", + // number + "bool": "int64", + "boolean": "int64", + "tinyint": "int64", + "smallint": "int64", + "mediumint": "int64", + "int": "int64", + "integer": "int64", + "bigint": "int64", + "float": "float64", + "double": "float64", + "decimal": "float64", + // date&time + "date": "time.Time", + "datetime": "time.Time", + "timestamp": "time.Time", + "time": "string", + "year": "int64", + // string "char": "string", "varchar": "string", - "tinyblob": "string", + "binary": "string", + "varbinary": "string", "tinytext": "string", - "blob": "string", "text": "string", - "mediumblob": "string", "mediumtext": "string", - "longblob": "string", "longtext": "string", + "enum": "string", + "set": "string", + "json": "string", } ) diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index 9bb0d304..752fe0ee 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -19,7 +19,7 @@ const ( createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case NamingLower = "lower" NamingCamel = "camel" - NamingUnderline = "underline" + NamingSnake = "snake" ) type ( @@ -81,7 +81,7 @@ func (g *defaultGenerator) Start(withCache bool) error { switch g.namingStyle { case NamingCamel: name = fmt.Sprintf("%sModel.go", tn.ToCamel()) - case NamingUnderline: + case NamingSnake: name = fmt.Sprintf("%s_model.go", tn.ToSnake()) } filename := filepath.Join(dirAbs, name) diff --git a/tools/goctl/model/sql/gen/gen_test.go b/tools/goctl/model/sql/gen/gen_test.go index 7e615816..656b6e50 100644 --- a/tools/goctl/model/sql/gen/gen_test.go +++ b/tools/goctl/model/sql/gen/gen_test.go @@ -1,6 +1,8 @@ package gen import ( + "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -8,27 +10,55 @@ import ( ) var ( - source = "CREATE TABLE `test_user_info` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `nanosecond` bigint NOT NULL DEFAULT '0',\n `data` varchar(255) DEFAULT '',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `nanosecond_unique` (`nanosecond`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;" + source = "CREATE TABLE `test_user_info` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `nanosecond` bigint NOT NULL DEFAULT '0',\n `data` varchar(255) DEFAULT '',\n `content` json DEFAULT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `nanosecond_unique` (`nanosecond`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;" ) func TestCacheModel(t *testing.T) { logx.Disable() _ = Clean() - g := NewDefaultGenerator(source, "./testmodel/cache", NamingLower) + dir, _ := filepath.Abs("./testmodel") + cacheDir := filepath.Join(dir, "cache") + noCacheDir := filepath.Join(dir, "nocache") + defer func() { + _ = os.RemoveAll(dir) + }() + g := NewDefaultGenerator(source, cacheDir, NamingLower) err := g.Start(true) assert.Nil(t, err) - g = NewDefaultGenerator(source, "./testmodel/nocache", NamingLower) + assert.True(t, func() bool { + _, err := os.Stat(filepath.Join(cacheDir, "testuserinfomodel.go")) + return err == nil + }()) + g = NewDefaultGenerator(source, noCacheDir, NamingLower) err = g.Start(false) assert.Nil(t, err) + assert.True(t, func() bool { + _, err := os.Stat(filepath.Join(noCacheDir, "testuserinfomodel.go")) + return err == nil + }()) } func TestNamingModel(t *testing.T) { logx.Disable() _ = Clean() - g := NewDefaultGenerator(source, "./testmodel/camel", NamingCamel) + dir, _ := filepath.Abs("./testmodel") + camelDir := filepath.Join(dir, "camel") + snakeDir := filepath.Join(dir, "snake") + defer func() { + _ = os.RemoveAll(dir) + }() + g := NewDefaultGenerator(source, camelDir, NamingCamel) err := g.Start(true) assert.Nil(t, err) - g = NewDefaultGenerator(source, "./testmodel/snake", NamingUnderline) + assert.True(t, func() bool { + _, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go")) + return err == nil + }()) + g = NewDefaultGenerator(source, snakeDir, NamingSnake) err = g.Start(true) assert.Nil(t, err) + assert.True(t, func() bool { + _, err := os.Stat(filepath.Join(snakeDir, "test_user_info_model.go")) + return err == nil + }()) } diff --git a/tools/goctl/model/sql/gen/testmodel/cache/testuserinfomodel.go b/tools/goctl/model/sql/gen/testmodel/cache/testuserinfomodel.go deleted file mode 100755 index e657ce19..00000000 --- a/tools/goctl/model/sql/gen/testmodel/cache/testuserinfomodel.go +++ /dev/null @@ -1,125 +0,0 @@ -package cache - -import ( - "database/sql" - "fmt" - "strings" - "time" - - "github.com/tal-tech/go-zero/core/stores/cache" - "github.com/tal-tech/go-zero/core/stores/sqlc" - "github.com/tal-tech/go-zero/core/stores/sqlx" - "github.com/tal-tech/go-zero/core/stringx" - "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx" -) - -var ( - testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{}) - testUserInfoRows = strings.Join(testUserInfoFieldNames, ",") - testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",") - testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?" - - cacheTestUserInfoIdPrefix = "cache#TestUserInfo#id#" - cacheTestUserInfoNanosecondPrefix = "cache#TestUserInfo#nanosecond#" -) - -type ( - TestUserInfoModel struct { - sqlc.CachedConn - table string - } - - TestUserInfo struct { - Id int64 `db:"id"` - Nanosecond int64 `db:"nanosecond"` - Data string `db:"data"` - CreateTime time.Time `db:"create_time"` - UpdateTime time.Time `db:"update_time"` - } -) - -func NewTestUserInfoModel(conn sqlx.SqlConn, c cache.CacheConf) *TestUserInfoModel { - return &TestUserInfoModel{ - CachedConn: sqlc.NewConn(conn, c), - table: "test_user_info", - } -} - -func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) { - testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond) - ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet) - return conn.Exec(query, data.Nanosecond, data.Data) - }, testUserInfoNanosecondKey) - return ret, err -} - -func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) { - testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id) - var resp TestUserInfo - err := m.QueryRow(&resp, testUserInfoIdKey, func(conn sqlx.SqlConn, v interface{}) error { - query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table) - return conn.QueryRow(v, query, id) - }) - switch err { - case nil: - return &resp, nil - case sqlc.ErrNotFound: - return nil, ErrNotFound - default: - return nil, err - } -} - -func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) { - testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, nanosecond) - var resp TestUserInfo - err := m.QueryRowIndex(&resp, testUserInfoNanosecondKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) { - query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table) - if err := conn.QueryRow(&resp, query, nanosecond); err != nil { - return nil, err - } - return resp.Id, nil - }, m.queryPrimary) - switch err { - case nil: - return &resp, nil - case sqlc.ErrNotFound: - return nil, ErrNotFound - default: - return nil, err - } -} - -func (m *TestUserInfoModel) Update(data TestUserInfo) error { - testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, data.Id) - _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder) - return conn.Exec(query, data.Nanosecond, data.Data, data.Id) - }, testUserInfoIdKey) - return err -} - -func (m *TestUserInfoModel) Delete(id int64) error { - data, err := m.FindOne(id) - if err != nil { - return err - } - - testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond) - testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id) - _, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := fmt.Sprintf("delete from %s where id = ?", m.table) - return conn.Exec(query, id) - }, testUserInfoNanosecondKey, testUserInfoIdKey) - return err -} - -func (m *TestUserInfoModel) formatPrimary(primary interface{}) string { - return fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, primary) -} - -func (m *TestUserInfoModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error { - query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table) - return conn.QueryRow(v, query, primary) -} diff --git a/tools/goctl/model/sql/gen/testmodel/cache/vars.go b/tools/goctl/model/sql/gen/testmodel/cache/vars.go deleted file mode 100644 index 14e43396..00000000 --- a/tools/goctl/model/sql/gen/testmodel/cache/vars.go +++ /dev/null @@ -1,5 +0,0 @@ -package cache - -import "github.com/tal-tech/go-zero/core/stores/sqlx" - -var ErrNotFound = sqlx.ErrNotFound diff --git a/tools/goctl/model/sql/gen/testmodel/camel/TestUserInfoModel.go b/tools/goctl/model/sql/gen/testmodel/camel/TestUserInfoModel.go deleted file mode 100755 index 3fc2da64..00000000 --- a/tools/goctl/model/sql/gen/testmodel/camel/TestUserInfoModel.go +++ /dev/null @@ -1,125 +0,0 @@ -package camel - -import ( - "database/sql" - "fmt" - "strings" - "time" - - "github.com/tal-tech/go-zero/core/stores/cache" - "github.com/tal-tech/go-zero/core/stores/sqlc" - "github.com/tal-tech/go-zero/core/stores/sqlx" - "github.com/tal-tech/go-zero/core/stringx" - "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx" -) - -var ( - testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{}) - testUserInfoRows = strings.Join(testUserInfoFieldNames, ",") - testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",") - testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?" - - cacheTestUserInfoIdPrefix = "cache#TestUserInfo#id#" - cacheTestUserInfoNanosecondPrefix = "cache#TestUserInfo#nanosecond#" -) - -type ( - TestUserInfoModel struct { - sqlc.CachedConn - table string - } - - TestUserInfo struct { - Id int64 `db:"id"` - Nanosecond int64 `db:"nanosecond"` - Data string `db:"data"` - CreateTime time.Time `db:"create_time"` - UpdateTime time.Time `db:"update_time"` - } -) - -func NewTestUserInfoModel(conn sqlx.SqlConn, c cache.CacheConf) *TestUserInfoModel { - return &TestUserInfoModel{ - CachedConn: sqlc.NewConn(conn, c), - table: "test_user_info", - } -} - -func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) { - testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond) - ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet) - return conn.Exec(query, data.Nanosecond, data.Data) - }, testUserInfoNanosecondKey) - return ret, err -} - -func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) { - testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id) - var resp TestUserInfo - err := m.QueryRow(&resp, testUserInfoIdKey, func(conn sqlx.SqlConn, v interface{}) error { - query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table) - return conn.QueryRow(v, query, id) - }) - switch err { - case nil: - return &resp, nil - case sqlc.ErrNotFound: - return nil, ErrNotFound - default: - return nil, err - } -} - -func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) { - testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, nanosecond) - var resp TestUserInfo - err := m.QueryRowIndex(&resp, testUserInfoNanosecondKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) { - query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table) - if err := conn.QueryRow(&resp, query, nanosecond); err != nil { - return nil, err - } - return resp.Id, nil - }, m.queryPrimary) - switch err { - case nil: - return &resp, nil - case sqlc.ErrNotFound: - return nil, ErrNotFound - default: - return nil, err - } -} - -func (m *TestUserInfoModel) Update(data TestUserInfo) error { - testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, data.Id) - _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder) - return conn.Exec(query, data.Nanosecond, data.Data, data.Id) - }, testUserInfoIdKey) - return err -} - -func (m *TestUserInfoModel) Delete(id int64) error { - data, err := m.FindOne(id) - if err != nil { - return err - } - - testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond) - testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id) - _, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := fmt.Sprintf("delete from %s where id = ?", m.table) - return conn.Exec(query, id) - }, testUserInfoIdKey, testUserInfoNanosecondKey) - return err -} - -func (m *TestUserInfoModel) formatPrimary(primary interface{}) string { - return fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, primary) -} - -func (m *TestUserInfoModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error { - query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table) - return conn.QueryRow(v, query, primary) -} diff --git a/tools/goctl/model/sql/gen/testmodel/camel/vars.go b/tools/goctl/model/sql/gen/testmodel/camel/vars.go deleted file mode 100644 index 9e6eb09e..00000000 --- a/tools/goctl/model/sql/gen/testmodel/camel/vars.go +++ /dev/null @@ -1,5 +0,0 @@ -package camel - -import "github.com/tal-tech/go-zero/core/stores/sqlx" - -var ErrNotFound = sqlx.ErrNotFound diff --git a/tools/goctl/model/sql/gen/testmodel/nocache/testuserinfomodel.go b/tools/goctl/model/sql/gen/testmodel/nocache/testuserinfomodel.go deleted file mode 100755 index ee5f3e42..00000000 --- a/tools/goctl/model/sql/gen/testmodel/nocache/testuserinfomodel.go +++ /dev/null @@ -1,88 +0,0 @@ -package nocache - -import ( - "database/sql" - "fmt" - "strings" - "time" - - "github.com/tal-tech/go-zero/core/stores/sqlc" - "github.com/tal-tech/go-zero/core/stores/sqlx" - "github.com/tal-tech/go-zero/core/stringx" - "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx" -) - -var ( - testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{}) - testUserInfoRows = strings.Join(testUserInfoFieldNames, ",") - testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",") - testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?" -) - -type ( - TestUserInfoModel struct { - conn sqlx.SqlConn - table string - } - - TestUserInfo struct { - Id int64 `db:"id"` - Nanosecond int64 `db:"nanosecond"` - Data string `db:"data"` - CreateTime time.Time `db:"create_time"` - UpdateTime time.Time `db:"update_time"` - } -) - -func NewTestUserInfoModel(conn sqlx.SqlConn) *TestUserInfoModel { - return &TestUserInfoModel{ - conn: conn, - table: "test_user_info", - } -} - -func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) { - query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet) - ret, err := m.conn.Exec(query, data.Nanosecond, data.Data) - return ret, err -} - -func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) { - query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table) - var resp TestUserInfo - err := m.conn.QueryRow(&resp, query, id) - switch err { - case nil: - return &resp, nil - case sqlc.ErrNotFound: - return nil, ErrNotFound - default: - return nil, err - } -} - -func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) { - var resp TestUserInfo - query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table) - err := m.conn.QueryRow(&resp, query, nanosecond) - switch err { - case nil: - return &resp, nil - case sqlc.ErrNotFound: - return nil, ErrNotFound - default: - return nil, err - } -} - -func (m *TestUserInfoModel) Update(data TestUserInfo) error { - query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder) - _, err := m.conn.Exec(query, data.Nanosecond, data.Data, data.Id) - return err -} - -func (m *TestUserInfoModel) Delete(id int64) error { - query := fmt.Sprintf("delete from %s where id = ?", m.table) - _, err := m.conn.Exec(query, id) - return err -} diff --git a/tools/goctl/model/sql/gen/testmodel/nocache/vars.go b/tools/goctl/model/sql/gen/testmodel/nocache/vars.go deleted file mode 100644 index c6c2f592..00000000 --- a/tools/goctl/model/sql/gen/testmodel/nocache/vars.go +++ /dev/null @@ -1,5 +0,0 @@ -package nocache - -import "github.com/tal-tech/go-zero/core/stores/sqlx" - -var ErrNotFound = sqlx.ErrNotFound diff --git a/tools/goctl/model/sql/gen/testmodel/snake/test_user_info_model.go b/tools/goctl/model/sql/gen/testmodel/snake/test_user_info_model.go deleted file mode 100755 index 54de8dee..00000000 --- a/tools/goctl/model/sql/gen/testmodel/snake/test_user_info_model.go +++ /dev/null @@ -1,125 +0,0 @@ -package snake - -import ( - "database/sql" - "fmt" - "strings" - "time" - - "github.com/tal-tech/go-zero/core/stores/cache" - "github.com/tal-tech/go-zero/core/stores/sqlc" - "github.com/tal-tech/go-zero/core/stores/sqlx" - "github.com/tal-tech/go-zero/core/stringx" - "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx" -) - -var ( - testUserInfoFieldNames = builderx.FieldNames(&TestUserInfo{}) - testUserInfoRows = strings.Join(testUserInfoFieldNames, ",") - testUserInfoRowsExpectAutoSet = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), ",") - testUserInfoRowsWithPlaceHolder = strings.Join(stringx.Remove(testUserInfoFieldNames, "id", "create_time", "update_time"), "=?,") + "=?" - - cacheTestUserInfoIdPrefix = "cache#TestUserInfo#id#" - cacheTestUserInfoNanosecondPrefix = "cache#TestUserInfo#nanosecond#" -) - -type ( - TestUserInfoModel struct { - sqlc.CachedConn - table string - } - - TestUserInfo struct { - Id int64 `db:"id"` - Nanosecond int64 `db:"nanosecond"` - Data string `db:"data"` - CreateTime time.Time `db:"create_time"` - UpdateTime time.Time `db:"update_time"` - } -) - -func NewTestUserInfoModel(conn sqlx.SqlConn, c cache.CacheConf) *TestUserInfoModel { - return &TestUserInfoModel{ - CachedConn: sqlc.NewConn(conn, c), - table: "test_user_info", - } -} - -func (m *TestUserInfoModel) Insert(data TestUserInfo) (sql.Result, error) { - testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond) - ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, testUserInfoRowsExpectAutoSet) - return conn.Exec(query, data.Nanosecond, data.Data) - }, testUserInfoNanosecondKey) - return ret, err -} - -func (m *TestUserInfoModel) FindOne(id int64) (*TestUserInfo, error) { - testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id) - var resp TestUserInfo - err := m.QueryRow(&resp, testUserInfoIdKey, func(conn sqlx.SqlConn, v interface{}) error { - query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table) - return conn.QueryRow(v, query, id) - }) - switch err { - case nil: - return &resp, nil - case sqlc.ErrNotFound: - return nil, ErrNotFound - default: - return nil, err - } -} - -func (m *TestUserInfoModel) FindOneByNanosecond(nanosecond int64) (*TestUserInfo, error) { - testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, nanosecond) - var resp TestUserInfo - err := m.QueryRowIndex(&resp, testUserInfoNanosecondKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) { - query := fmt.Sprintf("select %s from %s where nanosecond = ? limit 1", testUserInfoRows, m.table) - if err := conn.QueryRow(&resp, query, nanosecond); err != nil { - return nil, err - } - return resp.Id, nil - }, m.queryPrimary) - switch err { - case nil: - return &resp, nil - case sqlc.ErrNotFound: - return nil, ErrNotFound - default: - return nil, err - } -} - -func (m *TestUserInfoModel) Update(data TestUserInfo) error { - testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, data.Id) - _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := fmt.Sprintf("update %s set %s where id = ?", m.table, testUserInfoRowsWithPlaceHolder) - return conn.Exec(query, data.Nanosecond, data.Data, data.Id) - }, testUserInfoIdKey) - return err -} - -func (m *TestUserInfoModel) Delete(id int64) error { - data, err := m.FindOne(id) - if err != nil { - return err - } - - testUserInfoIdKey := fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, id) - testUserInfoNanosecondKey := fmt.Sprintf("%s%v", cacheTestUserInfoNanosecondPrefix, data.Nanosecond) - _, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := fmt.Sprintf("delete from %s where id = ?", m.table) - return conn.Exec(query, id) - }, testUserInfoIdKey, testUserInfoNanosecondKey) - return err -} - -func (m *TestUserInfoModel) formatPrimary(primary interface{}) string { - return fmt.Sprintf("%s%v", cacheTestUserInfoIdPrefix, primary) -} - -func (m *TestUserInfoModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error { - query := fmt.Sprintf("select %s from %s where id = ? limit 1", testUserInfoRows, m.table) - return conn.QueryRow(v, query, primary) -} diff --git a/tools/goctl/model/sql/gen/testmodel/snake/vars.go b/tools/goctl/model/sql/gen/testmodel/snake/vars.go deleted file mode 100644 index e18aa209..00000000 --- a/tools/goctl/model/sql/gen/testmodel/snake/vars.go +++ /dev/null @@ -1,5 +0,0 @@ -package snake - -import "github.com/tal-tech/go-zero/core/stores/sqlx" - -var ErrNotFound = sqlx.ErrNotFound diff --git a/tools/goctl/rpc/CHANGELOG.md b/tools/goctl/rpc/CHANGELOG.md deleted file mode 100644 index ee4429eb..00000000 --- a/tools/goctl/rpc/CHANGELOG.md +++ /dev/null @@ -1,20 +0,0 @@ -# Change log - -## 2020-10-19 - -* 增加template - -## 2020-09-10 - -* rpc greet服务一键生成 -* 修复相对路径生成rpc服务package引入错误bug -* 移除`--shared`参数 - -## 2020-08-29 - -* 新增支持windows生成 - -## 2020-08-27 - -* 新增支持rpc模板生成 -* 新增支持rpc服务生成 diff --git a/tools/goctl/rpc/README.md b/tools/goctl/rpc/README.md index 7c5c2d50..d9a35e4f 100644 --- a/tools/goctl/rpc/README.md +++ b/tools/goctl/rpc/README.md @@ -7,9 +7,8 @@ Goctl Rpc是`goctl`脚手架下的一个rpc服务代码生成模块,支持prot * 简单易用 * 快速提升开发效率 * 出错率低 -* 支持基于main proto作为相对路径的import -* 支持map、enum类型 -* 支持any类型 +* 贴近protoc + ## 快速开始 @@ -19,44 +18,41 @@ Goctl Rpc是`goctl`脚手架下的一个rpc服务代码生成模块,支持prot 如生成greet rpc服务: - ```shell script + ```Bash goctl rpc new greet ``` 执行后代码结构如下: ```golang - └── greet - ├── etc - │   └── greet.yaml - ├── go.mod - ├── go.sum - ├── greet - │   ├── greet.go - │   ├── greet_mock.go - │   └── types.go - ├── greet.go - ├── greet.proto - ├── internal - │   ├── config - │   │   └── config.go - │   ├── logic - │   │   └── pinglogic.go - │   ├── server - │   │   └── greetserver.go - │   └── svc - │   └── servicecontext.go - └── pb - └── greet.pb.go +. +├── etc // 配置文件 +│   └── greet.yaml +├── go.mod +├── greet // client call +│   └── greet.go +├── greet.go // main entry +├── greet.proto +└── internal + ├── config // 配置声明 + │   └── config.go + ├── greet // pb.go + │   └── greet.pb.go + ├── logic // logic + │   └── pinglogic.go + ├── server // pb invoker + │   └── greetserver.go + └── svc // resource dependency + └── servicecontext.go ``` -rpc一键生成常见问题解决见 常见问题解决 +rpc一键生成常见问题解决,见 常见问题解决 ### 方式二:通过指定proto生成rpc服务 * 生成proto模板 - ```shell script + ```Bash goctl rpc template -o=user.proto ``` @@ -87,35 +83,10 @@ rpc一键生成常见问题解决见 常见问题 * 生成rpc服务代码 - ```shell script + ```Bash goctl rpc proto -src=user.proto ``` - 代码tree - - ```Plain Text - user - ├── etc - │   └── user.json - ├── internal - │   ├── config - │   │   └── config.go - │   ├── handler - │   │   ├── loginhandler.go - │   ├── logic - │   │   └── loginlogic.go - │   └── svc - │   └── servicecontext.go - ├── pb - │   └── user.pb.go - ├── shared - │   ├── mockusermodel.go - │   ├── types.go - │   └── usermodel.go - ├── user.go - └── user.proto - ``` - ## 准备工作 * 安装了go环境 @@ -126,11 +97,11 @@ rpc一键生成常见问题解决见 常见问题 ### rpc服务生成用法 -```shell script +```Bash goctl rpc proto -h ``` -```shell script +```Bash NAME: goctl rpc proto - generate rpc from proto @@ -139,35 +110,22 @@ USAGE: OPTIONS: --src value, -s value the file path of the proto source file - --dir value, -d value the target path of the code,default path is "${pwd}". [option] - --service value, --srv value the name of rpc service. [option] - --idea whether the command execution environment is from idea plugin. [option] - + --proto_path value, -I value native command of protoc,specify the directory in which to search for imports. [optional] + --dir value, -d value the target path of the code,default path is "${pwd}". [optional] + --idea whether the command execution environment is from idea plugin. [optional] ``` ### 参数说明 -* --src 必填,proto数据源,目前暂时支持单个proto文件生成,这里不支持(不建议)外部依赖 -* --dir 非必填,默认为proto文件所在目录,生成代码的目标目录 -* --service 服务名称,非必填,默认为proto文件所在目录名称,但是,如果proto所在目录为一下结构: - - ```shell script - user - ├── cmd - │   └── rpc - │   └── user.proto - ``` - - 则服务名称亦为user,而非proto所在文件夹名称了,这里推荐使用这种结构,可以方便在同一个服务名下建立不同类型的服务(api、rpc、mq等),便于代码管理与维护。 - - > 注意:这里的shared文件夹名称将会是代码中的package名称。 - -* --idea 非必填,是否为idea插件中执行,保留字段,终端执行可以忽略 +* --src 必填,proto数据源,目前暂时支持单个proto文件生成 +* --proto_path 可选,protoc原生子命令,用于指定proto import从何处查找,可指定多个路径,如`goctl rpc -I={path1} -I={path2} ...`,在没有import时可不填。当前proto路径不用指定,已经内置,`-I`的详细用法请参考`protoc -h` +* --dir 可选,默认为proto文件所在目录,生成代码的目标目录 +* --idea 可选,是否为idea插件中执行,终端执行可以忽略 ### 开发人员需要做什么 -关注业务代码编写,将重复性、与业务无关的工作交给goctl,生成好rpc服务代码后,开饭人员仅需要修改 +关注业务代码编写,将重复性、与业务无关的工作交给goctl,生成好rpc服务代码后,开发人员仅需要修改 * 服务中的配置文件编写(etc/xx.json、internal/config/config.go) * 服务中业务逻辑编写(internal/logic/xxlogic.go) @@ -193,69 +151,54 @@ OPTIONS: 的标识,请注意不要将也写业务性代码写在里面。 -## any和import支持 -* 支持any类型声明 -* 支持import其他proto文件 +## proto import +* 对于rpc中的requestType和returnType必须在main proto文件定义,对于proto中的message可以像protoc一样import其他proto文件。 - any类型固定import为`google/protobuf/any.proto`,且从${GOPATH}/src中查找,proto的import支持main proto的相对路径的import,且与proto文件对应的pb.go文件必须在proto目录中能被找到。不支持工程外的其他proto文件import。 +proto示例: -> ⚠️注意: 不支持proto嵌套import,即:被import的proto文件不支持import。 - -### import书写格式 -import书写格式 -```golang -// @{package_of_pb} -import {proto_omport} -``` -@{package_of_pb}:pb文件的真实import目录。 -{proto_omport}:proto import - - -如:demo中的 - -```golang -// @greet/base -import "base/base.proto"; -``` - -工程目录结构如下 -``` -greet -│   ├── base -│   │   ├── base.pb.go -│   │   └── base.proto -│   ├── demo.proto -│   ├── go.mod -│   └── go.sum -``` - -demo -```golang +### 错误import +```proto syntax = "proto3"; -import "google/protobuf/any.proto"; -// @greet/base -import "base/base.proto"; -package stream; +package greet; -enum Gender{ - UNKNOWN = 0; - MAN = 1; - WOMAN = 2; +import "base/common.proto" + +message Request { + string ping = 1; } -message StreamResp{ - string name = 2; - Gender gender = 3; - google.protobuf.Any details = 5; - base.StreamReq req = 6; +message Response { + string pong = 1; } -service StreamGreeter { - rpc greet(base.StreamReq) returns (StreamResp); + +service Greet { + rpc Ping(base.In) returns(base.Out);// request和return 不支持import } + ``` +### 正确import +```proto +syntax = "proto3"; + +package greet; + +import "base/common.proto" + +message Request { + base.In in = 1;// 支持import +} + +message Response { + base.Out out = 2;// 支持import +} + +service Greet { + rpc Ping(Request) returns(Response); +} +``` ## 常见问题解决(go mod工程) diff --git a/tools/goctl/rpc/base.pb.go b/tools/goctl/rpc/base.pb.go deleted file mode 100644 index 0f2653f9..00000000 --- a/tools/goctl/rpc/base.pb.go +++ /dev/null @@ -1,108 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// source: base.proto - -package base - -import ( - fmt "fmt" - proto "github.com/golang/protobuf/proto" - math "math" -) - -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the proto package it is being compiled against. -// A compilation error at this line likely means your copy of the -// proto package needs to be updated. -const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package - -type IdRequest struct { - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *IdRequest) Reset() { *m = IdRequest{} } -func (m *IdRequest) String() string { return proto.CompactTextString(m) } -func (*IdRequest) ProtoMessage() {} -func (*IdRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_db1b6b0986796150, []int{0} -} - -func (m *IdRequest) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_IdRequest.Unmarshal(m, b) -} -func (m *IdRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_IdRequest.Marshal(b, m, deterministic) -} -func (m *IdRequest) XXX_Merge(src proto.Message) { - xxx_messageInfo_IdRequest.Merge(m, src) -} -func (m *IdRequest) XXX_Size() int { - return xxx_messageInfo_IdRequest.Size(m) -} -func (m *IdRequest) XXX_DiscardUnknown() { - xxx_messageInfo_IdRequest.DiscardUnknown(m) -} - -var xxx_messageInfo_IdRequest proto.InternalMessageInfo - -func (m *IdRequest) GetId() string { - if m != nil { - return m.Id - } - return "" -} - -type EmptyResponse struct { - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *EmptyResponse) Reset() { *m = EmptyResponse{} } -func (m *EmptyResponse) String() string { return proto.CompactTextString(m) } -func (*EmptyResponse) ProtoMessage() {} -func (*EmptyResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_db1b6b0986796150, []int{1} -} - -func (m *EmptyResponse) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_EmptyResponse.Unmarshal(m, b) -} -func (m *EmptyResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_EmptyResponse.Marshal(b, m, deterministic) -} -func (m *EmptyResponse) XXX_Merge(src proto.Message) { - xxx_messageInfo_EmptyResponse.Merge(m, src) -} -func (m *EmptyResponse) XXX_Size() int { - return xxx_messageInfo_EmptyResponse.Size(m) -} -func (m *EmptyResponse) XXX_DiscardUnknown() { - xxx_messageInfo_EmptyResponse.DiscardUnknown(m) -} - -var xxx_messageInfo_EmptyResponse proto.InternalMessageInfo - -func init() { - proto.RegisterType((*IdRequest)(nil), "base.IdRequest") - proto.RegisterType((*EmptyResponse)(nil), "base.EmptyResponse") -} - -func init() { proto.RegisterFile("base.proto", fileDescriptor_db1b6b0986796150) } - -var fileDescriptor_db1b6b0986796150 = []byte{ - // 91 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4a, 0x4a, 0x2c, 0x4e, - 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x01, 0xb1, 0x95, 0xa4, 0xb9, 0x38, 0x3d, 0x53, - 0x82, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0xf8, 0xb8, 0x98, 0x32, 0x53, 0x24, 0x18, 0x15, - 0x18, 0x35, 0x38, 0x83, 0x98, 0x32, 0x53, 0x94, 0xf8, 0xb9, 0x78, 0x5d, 0x73, 0x0b, 0x4a, 0x2a, - 0x83, 0x52, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x93, 0xd8, 0xc0, 0x5a, 0x8d, 0x01, 0x01, 0x00, - 0x00, 0xff, 0xff, 0xe1, 0x39, 0x3c, 0x22, 0x48, 0x00, 0x00, 0x00, -} diff --git a/tools/goctl/rpc/base.proto b/tools/goctl/rpc/base.proto deleted file mode 100644 index 501f12f7..00000000 --- a/tools/goctl/rpc/base.proto +++ /dev/null @@ -1,11 +0,0 @@ -syntax = "proto3"; - -package base; - -message IdRequest { - string id = 1; -} - -message EmptyResponse { - -} diff --git a/tools/goctl/rpc/cli/cli.go b/tools/goctl/rpc/cli/cli.go new file mode 100644 index 00000000..a76dde22 --- /dev/null +++ b/tools/goctl/rpc/cli/cli.go @@ -0,0 +1,67 @@ +package cli + +import ( + "errors" + "fmt" + "path/filepath" + + "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" + "github.com/tal-tech/go-zero/tools/goctl/rpc/generator" + "github.com/urfave/cli" +) + +// Rpc is to generate rpc service code from a proto file by specifying a proto file using flag src, +// you can specify a target folder for code generation, when the proto file has import, you can specify +// the import search directory through the proto_path command, for specific usage, please refer to protoc -h +func Rpc(c *cli.Context) error { + src := c.String("src") + out := c.String("dir") + protoImportPath := c.StringSlice("proto_path") + if len(src) == 0 { + return errors.New("the proto source can not be nil") + } + if len(out) == 0 { + return errors.New("the target directory can not be nil") + } + g := generator.NewDefaultRpcGenerator() + return g.Generate(src, out, protoImportPath) +} + +// RpcNew is to generate rpc greet service, this greet service can speed +// up your understanding of the zrpc service structure +func RpcNew(c *cli.Context) error { + name := c.Args().First() + ext := filepath.Ext(name) + if len(ext) > 0 { + return fmt.Errorf("unexpected ext: %s", ext) + } + + protoName := name + ".proto" + filename := filepath.Join(".", name, protoName) + src, err := filepath.Abs(filename) + if err != nil { + return err + } + + err = generator.ProtoTmpl(src) + if err != nil { + return err + } + + workDir := filepath.Dir(src) + _, err = execx.Run("go mod init "+name, workDir) + if err != nil { + return err + } + + g := generator.NewDefaultRpcGenerator() + return g.Generate(src, filepath.Dir(src), nil) +} + +func RpcTemplate(c *cli.Context) error { + name := c.Args().First() + if len(name) == 0 { + name = "greet.proto" + } + return generator.ProtoTmpl(name) +} diff --git a/tools/goctl/rpc/command/command.go b/tools/goctl/rpc/command/command.go deleted file mode 100644 index fde97dad..00000000 --- a/tools/goctl/rpc/command/command.go +++ /dev/null @@ -1,60 +0,0 @@ -package command - -import ( - "fmt" - "os" - "path/filepath" - - "github.com/tal-tech/go-zero/tools/goctl/rpc/ctx" - "github.com/tal-tech/go-zero/tools/goctl/rpc/gen" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/urfave/cli" -) - -func Rpc(c *cli.Context) error { - rpcCtx := ctx.MustCreateRpcContextFromCli(c) - generator := gen.NewDefaultRpcGenerator(rpcCtx) - rpcCtx.Must(generator.Generate()) - return nil -} - -func RpcTemplate(c *cli.Context) error { - out := c.String("out") - idea := c.Bool("idea") - generator := gen.NewRpcTemplate(out, idea) - generator.MustGenerate(true) - return nil -} - -func RpcNew(c *cli.Context) error { - idea := c.Bool("idea") - arg := c.Args().First() - if len(arg) == 0 { - arg = "greet" - } - abs, err := filepath.Abs(arg) - if err != nil { - return err - } - _, err = os.Stat(abs) - if err != nil { - if !os.IsNotExist(err) { - return err - } - err = util.MkdirIfNotExist(abs) - if err != nil { - return err - } - } - - dir := filepath.Base(filepath.Clean(abs)) - - protoSrc := filepath.Join(abs, fmt.Sprintf("%v.proto", dir)) - templateGenerator := gen.NewRpcTemplate(protoSrc, idea) - templateGenerator.MustGenerate(false) - - rpcCtx := ctx.MustCreateRpcContext(protoSrc, "", "", idea) - generator := gen.NewDefaultRpcGenerator(rpcCtx) - rpcCtx.Must(generator.Generate()) - return nil -} diff --git a/tools/goctl/rpc/ctx/ctx.go b/tools/goctl/rpc/ctx/ctx.go deleted file mode 100644 index 5facee42..00000000 --- a/tools/goctl/rpc/ctx/ctx.go +++ /dev/null @@ -1,99 +0,0 @@ -package ctx - -import ( - "fmt" - "path/filepath" - "runtime" - "strings" - - "github.com/tal-tech/go-zero/core/logx" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/console" - "github.com/tal-tech/go-zero/tools/goctl/util/project" - "github.com/tal-tech/go-zero/tools/goctl/util/stringx" - "github.com/tal-tech/go-zero/tools/goctl/vars" - "github.com/urfave/cli" -) - -const ( - flagSrc = "src" - flagDir = "dir" - flagService = "service" - flagIdea = "idea" -) - -type RpcContext struct { - ProjectPath string - ProjectName stringx.String - ServiceName stringx.String - CurrentPath string - Module string - ProtoFileSrc string - ProtoSource string - TargetDir string - IsInGoEnv bool - console.Console -} - -func MustCreateRpcContext(protoSrc, targetDir, serviceName string, idea bool) *RpcContext { - log := console.NewConsole(idea) - - if stringx.From(protoSrc).IsEmptyOrSpace() { - log.Fatalln("expected proto source, but nothing found") - } - srcFp, err := filepath.Abs(protoSrc) - log.Must(err) - - if !util.FileExists(srcFp) { - log.Fatalln("%s is not exists", srcFp) - } - current := filepath.Dir(srcFp) - if stringx.From(targetDir).IsEmptyOrSpace() { - targetDir = current - } - targetDirFp, err := filepath.Abs(targetDir) - log.Must(err) - - if stringx.From(serviceName).IsEmptyOrSpace() { - serviceName = getServiceFromRpcStructure(targetDirFp) - } - serviceNameString := stringx.From(serviceName) - if serviceNameString.IsEmptyOrSpace() { - log.Fatalln("service name not found") - } - - info, err := project.Prepare(targetDir, true) - log.Must(err) - - return &RpcContext{ - ProjectPath: info.Path, - ProjectName: stringx.From(info.Name), - ServiceName: serviceNameString, - CurrentPath: current, - Module: info.GoMod.Module, - ProtoFileSrc: srcFp, - ProtoSource: filepath.Base(srcFp), - TargetDir: targetDirFp, - IsInGoEnv: info.IsInGoEnv, - Console: log, - } -} -func MustCreateRpcContextFromCli(ctx *cli.Context) *RpcContext { - os := runtime.GOOS - switch os { - case vars.OsMac, vars.OsLinux, vars.OsWindows: - default: - logx.Must(fmt.Errorf("unexpected os: %s", os)) - } - protoSrc := ctx.String(flagSrc) - targetDir := ctx.String(flagDir) - serviceName := ctx.String(flagService) - idea := ctx.Bool(flagIdea) - return MustCreateRpcContext(protoSrc, targetDir, serviceName, idea) -} - -func getServiceFromRpcStructure(targetDir string) string { - targetDir = filepath.Clean(targetDir) - suffix := filepath.Join("cmd", "rpc") - return filepath.Base(strings.TrimSuffix(targetDir, suffix)) -} diff --git a/tools/goctl/rpc/execx/execx.go b/tools/goctl/rpc/execx/execx.go index 039a94c7..9b7e482f 100644 --- a/tools/goctl/rpc/execx/execx.go +++ b/tools/goctl/rpc/execx/execx.go @@ -6,7 +6,9 @@ import ( "fmt" "os/exec" "runtime" + "strings" + "github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/vars" ) @@ -24,17 +26,17 @@ func Run(arg string, dir string) (string, error) { if len(dir) > 0 { cmd.Dir = dir } - dtsout := new(bytes.Buffer) + stdout := new(bytes.Buffer) stderr := new(bytes.Buffer) - cmd.Stdout = dtsout + cmd.Stdout = stdout cmd.Stderr = stderr err := cmd.Run() if err != nil { if stderr.Len() > 0 { - return "", errors.New(stderr.String()) + return "", errors.New(strings.TrimSuffix(stderr.String(), util.NL)) } return "", err } - return dtsout.String(), nil + return strings.TrimSuffix(stdout.String(), util.NL), nil } diff --git a/tools/goctl/rpc/gen/gen.go b/tools/goctl/rpc/gen/gen.go deleted file mode 100644 index 172d9013..00000000 --- a/tools/goctl/rpc/gen/gen.go +++ /dev/null @@ -1,93 +0,0 @@ -package gen - -import ( - "github.com/logrusorgru/aurora" - "github.com/tal-tech/go-zero/tools/goctl/rpc/ctx" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" -) - -const ( - dirTarget = "dirTarget" - dirConfig = "config" - dirEtc = "etc" - dirSvc = "svc" - dirServer = "server" - dirLogic = "logic" - dirPb = "pb" - dirInternal = "internal" - fileConfig = "config.go" - fileServiceContext = "servicecontext.go" -) - -type defaultRpcGenerator struct { - dirM map[string]string - Ctx *ctx.RpcContext - ast *parser.PbAst -} - -func NewDefaultRpcGenerator(ctx *ctx.RpcContext) *defaultRpcGenerator { - return &defaultRpcGenerator{ - Ctx: ctx, - } -} - -func (g *defaultRpcGenerator) Generate() (err error) { - g.Ctx.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/zero-doc/blob/main/doc/goctl-rpc.md」") - g.Ctx.Warning("-> generating rpc code ...") - defer func() { - if err == nil { - g.Ctx.MarkDone() - } - }() - err = g.createDir() - if err != nil { - return - } - - err = g.initGoMod() - if err != nil { - return - } - - err = g.genEtc() - if err != nil { - return - } - - err = g.genPb() - if err != nil { - return - } - - err = g.genConfig() - if err != nil { - return - } - - err = g.genSvc() - if err != nil { - return - } - - err = g.genLogic() - if err != nil { - return - } - - err = g.genHandler() - if err != nil { - return - } - - err = g.genMain() - if err != nil { - return - } - - err = g.genCall() - if err != nil { - return - } - - return -} diff --git a/tools/goctl/rpc/gen/gencall.go b/tools/goctl/rpc/gen/gencall.go deleted file mode 100644 index d8a62339..00000000 --- a/tools/goctl/rpc/gen/gencall.go +++ /dev/null @@ -1,224 +0,0 @@ -package gen - -import ( - "fmt" - "path/filepath" - "strings" - - "github.com/tal-tech/go-zero/core/collection" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" -) - -const ( - typesFilename = "types.go" - callTemplateText = `{{.head}} - -//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE - -package {{.filePackage}} - -import ( - "context" - - {{.package}} - - "github.com/tal-tech/go-zero/core/jsonx" - "github.com/tal-tech/go-zero/zrpc" -) - -type ( - {{.serviceName}} interface { - {{.interface}} - } - - default{{.serviceName}} struct { - cli zrpc.Client - } -) - -func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} { - return &default{{.serviceName}}{ - cli: cli, - } -} - -{{.functions}} -` - callTemplateTypes = `{{.head}} - -package {{.filePackage}} - -import "errors" - -var errJsonConvert = errors.New("json convert error") - -{{.const}} - -{{.types}} -` - callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}} -{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)` - - callFunctionTemplate = ` -{{if .hasComment}}{{.comment}}{{end}} -func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequestName}}) (*{{.pbResponse}}, error) { - var request {{.pbRequest}} - bts, err := jsonx.Marshal(in) - if err != nil { - return nil, errJsonConvert - } - - err = jsonx.Unmarshal(bts, &request) - if err != nil { - return nil, errJsonConvert - } - - client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn()) - resp, err := client.{{.method}}(ctx, &request) - if err != nil{ - return nil, err - } - - var ret {{.pbResponse}} - bts, err = jsonx.Marshal(resp) - if err != nil{ - return nil, errJsonConvert - } - - err = jsonx.Unmarshal(bts, &ret) - if err != nil{ - return nil, errJsonConvert - } - - return &ret, nil -} -` -) - -func (g *defaultRpcGenerator) genCall() error { - file := g.ast - if len(file.Service) == 0 { - return nil - } - if len(file.Service) > 1 { - return fmt.Errorf("we recommend only one service in a proto, currently %d", len(file.Service)) - } - - typeCode, err := file.GenTypesCode() - if err != nil { - return err - } - - constLit, err := file.GenEnumCode() - if err != nil { - return err - } - - service := file.Service[0] - callPath := filepath.Join(g.dirM[dirTarget], service.Name.Lower()) - if err = util.MkdirIfNotExist(callPath); err != nil { - return err - } - - filename := filepath.Join(callPath, typesFilename) - head := util.GetHead(g.Ctx.ProtoSource) - text, err := util.LoadTemplate(category, callTypesTemplateFile, callTemplateTypes) - if err != nil { - return err - } - err = util.With("types").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ - "head": head, - "const": constLit, - "filePackage": service.Name.Lower(), - "serviceName": g.Ctx.ServiceName.Title(), - "lowerStartServiceName": g.Ctx.ServiceName.UnTitle(), - "types": typeCode, - }, filename, true) - if err != nil { - return err - } - - filename = filepath.Join(callPath, fmt.Sprintf("%s.go", service.Name.Lower())) - functions, importList, err := g.genFunction(service) - if err != nil { - return err - } - - iFunctions, err := g.getInterfaceFuncs(service) - if err != nil { - return err - } - text, err = util.LoadTemplate(category, callTemplateFile, callTemplateText) - if err != nil { - return err - } - err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ - "name": service.Name.Lower(), - "head": head, - "filePackage": service.Name.Lower(), - "package": strings.Join(importList, util.NL), - "serviceName": service.Name.Title(), - "functions": strings.Join(functions, util.NL), - "interface": strings.Join(iFunctions, util.NL), - }, filename, true) - return err -} - -func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string, []string, error) { - file := g.ast - pkgName := file.Package - functions := make([]string, 0) - imports := collection.NewSet() - imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb))) - for _, method := range service.Funcs { - imports.AddStr(g.ast.Imports[method.ParameterIn.Package]) - text, err := util.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate) - if err != nil { - return nil, nil, err - } - buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{ - "rpcServiceName": service.Name.Title(), - "method": method.Name.Title(), - "package": pkgName, - "pbRequestName": method.ParameterIn.Name, - "pbRequest": method.ParameterIn.Expression, - "pbResponse": method.ParameterOut.Name, - "hasComment": method.HaveDoc(), - "comment": method.GetDoc(), - }) - if err != nil { - return nil, nil, err - } - - functions = append(functions, buffer.String()) - } - return functions, imports.KeysStr(), nil -} - -func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) { - functions := make([]string, 0) - - for _, method := range service.Funcs { - text, err := util.LoadTemplate(category, callInterfaceFunctionTemplateFile, callInterfaceFunctionTemplate) - if err != nil { - return nil, err - } - - buffer, err := util.With("interfaceFn").Parse(text).Execute( - map[string]interface{}{ - "hasComment": method.HaveDoc(), - "comment": method.GetDoc(), - "method": method.Name.Title(), - "pbRequest": method.ParameterIn.Name, - "pbResponse": method.ParameterOut.Name, - }) - if err != nil { - return nil, err - } - - functions = append(functions, buffer.String()) - } - - return functions, nil -} diff --git a/tools/goctl/rpc/gen/gendir.go b/tools/goctl/rpc/gen/gendir.go deleted file mode 100644 index bb0bac22..00000000 --- a/tools/goctl/rpc/gen/gendir.go +++ /dev/null @@ -1,54 +0,0 @@ -package gen - -import ( - "path/filepath" - "runtime" - "strings" - - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/vars" -) - -// target -// ├── etc -// ├── internal -// │   ├── config -// │   ├── handler -// │   ├── logic -// │   ├── pb -// │   └── svc -func (g *defaultRpcGenerator) createDir() error { - ctx := g.Ctx - m := make(map[string]string) - m[dirTarget] = ctx.TargetDir - m[dirEtc] = filepath.Join(ctx.TargetDir, dirEtc) - m[dirInternal] = filepath.Join(ctx.TargetDir, dirInternal) - m[dirConfig] = filepath.Join(ctx.TargetDir, dirInternal, dirConfig) - m[dirServer] = filepath.Join(ctx.TargetDir, dirInternal, dirServer) - m[dirLogic] = filepath.Join(ctx.TargetDir, dirInternal, dirLogic) - m[dirPb] = filepath.Join(ctx.TargetDir, dirPb) - m[dirSvc] = filepath.Join(ctx.TargetDir, dirInternal, dirSvc) - for _, d := range m { - err := util.MkdirIfNotExist(d) - if err != nil { - return err - } - } - g.dirM = m - return nil -} - -func (g *defaultRpcGenerator) mustGetPackage(dir string) string { - target := g.dirM[dir] - projectPath := g.Ctx.ProjectPath - relativePath := strings.TrimPrefix(target, projectPath) - os := runtime.GOOS - switch os { - case vars.OsWindows: - relativePath = filepath.ToSlash(relativePath) - case vars.OsMac, vars.OsLinux: - default: - g.Ctx.Fatalln("unexpected os: %s", os) - } - return g.Ctx.Module + relativePath -} diff --git a/tools/goctl/rpc/gen/genlogic.go b/tools/goctl/rpc/gen/genlogic.go deleted file mode 100644 index 969205ce..00000000 --- a/tools/goctl/rpc/gen/genlogic.go +++ /dev/null @@ -1,109 +0,0 @@ -package gen - -import ( - "fmt" - "path/filepath" - "strings" - - "github.com/tal-tech/go-zero/core/collection" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" -) - -const ( - logicTemplate = `package logic - -import ( - "context" - - {{.imports}} - - "github.com/tal-tech/go-zero/core/logx" -) - -type {{.logicName}} struct { - ctx context.Context - svcCtx *svc.ServiceContext - logx.Logger -} - -func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} { - return &{{.logicName}}{ - ctx: ctx, - svcCtx: svcCtx, - Logger: logx.WithContext(ctx), - } -} -{{.functions}} -` - logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}} -func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) { - // todo: add your logic here and delete this line - - return &{{.responseType}}{}, nil -} -` -) - -func (g *defaultRpcGenerator) genLogic() error { - logicPath := g.dirM[dirLogic] - protoPkg := g.ast.Package - service := g.ast.Service - for _, item := range service { - for _, method := range item.Funcs { - logicName := fmt.Sprintf("%slogic.go", method.Name.Lower()) - filename := filepath.Join(logicPath, logicName) - functions, importList, err := g.genLogicFunction(protoPkg, method) - if err != nil { - return err - } - imports := collection.NewSet() - svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc)) - imports.AddStr(svcImport) - imports.AddStr(importList...) - text, err := util.LoadTemplate(category, logicTemplateFileFile, logicTemplate) - if err != nil { - return err - } - err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ - "logicName": fmt.Sprintf("%sLogic", method.Name.Title()), - "functions": functions, - "imports": strings.Join(imports.KeysStr(), util.NL), - }, filename, false) - if err != nil { - return err - } - } - } - return nil -} - -func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parser.Func) (string, []string, error) { - var functions = make([]string, 0) - var imports = collection.NewSet() - if method.ParameterIn.Package == packageName || method.ParameterOut.Package == packageName { - imports.AddStr(fmt.Sprintf(`%v "%v"`, packageName, g.mustGetPackage(dirPb))) - } - imports.AddStr(g.ast.Imports[method.ParameterIn.Package]) - imports.AddStr(g.ast.Imports[method.ParameterOut.Package]) - text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate) - if err != nil { - return "", nil, err - } - - buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{ - "logicName": fmt.Sprintf("%sLogic", method.Name.Title()), - "method": method.Name.Title(), - "request": method.ParameterIn.StarExpression, - "response": method.ParameterOut.StarExpression, - "responseType": method.ParameterOut.Expression, - "hasComment": method.HaveDoc(), - "comment": method.GetDoc(), - }) - if err != nil { - return "", nil, err - } - - functions = append(functions, buffer.String()) - return strings.Join(functions, util.NL), imports.KeysStr(), nil -} diff --git a/tools/goctl/rpc/gen/genmain.go b/tools/goctl/rpc/gen/genmain.go deleted file mode 100644 index 653e4da5..00000000 --- a/tools/goctl/rpc/gen/genmain.go +++ /dev/null @@ -1,85 +0,0 @@ -package gen - -import ( - "fmt" - "path/filepath" - "strings" - - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" -) - -const mainTemplate = `{{.head}} - -package main - -import ( - "flag" - "fmt" - - {{.imports}} - - "github.com/tal-tech/go-zero/core/conf" - "github.com/tal-tech/go-zero/zrpc" - "google.golang.org/grpc" -) - -var configFile = flag.String("f", "etc/{{.serviceName}}.yaml", "the config file") - -func main() { - flag.Parse() - - var c config.Config - conf.MustLoad(*configFile, &c) - ctx := svc.NewServiceContext(c) - {{.srv}} - - s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) { - {{.registers}} - }) - defer s.Stop() - - fmt.Printf("Starting rpc server at %s...\n", c.ListenOn) - s.Start() -} -` - -func (g *defaultRpcGenerator) genMain() error { - mainPath := g.dirM[dirTarget] - file := g.ast - pkg := file.Package - - fileName := filepath.Join(mainPath, fmt.Sprintf("%v.go", g.Ctx.ServiceName.Lower())) - imports := make([]string, 0) - pbImport := fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb)) - svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc)) - remoteImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirServer)) - configImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig)) - imports = append(imports, configImport, pbImport, remoteImport, svcImport) - srv, registers := g.genServer(pkg, file.Service) - head := util.GetHead(g.Ctx.ProtoSource) - text, err := util.LoadTemplate(category, mainTemplateFile, mainTemplate) - if err != nil { - return err - } - - return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ - "head": head, - "package": pkg, - "serviceName": g.Ctx.ServiceName.Lower(), - "srv": srv, - "registers": registers, - "imports": strings.Join(imports, util.NL), - }, fileName, false) -} - -func (g *defaultRpcGenerator) genServer(pkg string, list []*parser.RpcService) (string, string) { - list1 := make([]string, 0) - list2 := make([]string, 0) - for _, item := range list { - name := item.Name.UnTitle() - list1 = append(list1, fmt.Sprintf("%sSrv := server.New%sServer(ctx)", name, item.Name.Title())) - list2 = append(list2, fmt.Sprintf("%s.Register%sServer(grpcServer, %sSrv)", pkg, item.Name.Title(), name)) - } - return strings.Join(list1, util.NL), strings.Join(list2, util.NL) -} diff --git a/tools/goctl/rpc/gen/genpb.go b/tools/goctl/rpc/gen/genpb.go deleted file mode 100644 index be772e34..00000000 --- a/tools/goctl/rpc/gen/genpb.go +++ /dev/null @@ -1,82 +0,0 @@ -package gen - -import ( - "bytes" - "fmt" - "path/filepath" - "strings" - - "github.com/tal-tech/go-zero/core/collection" - "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" -) - -const ( - protocCmd = "protoc" - grpcPluginCmd = "--go_out=plugins=grpc" -) - -func (g *defaultRpcGenerator) genPb() error { - pbPath := g.dirM[dirPb] - // deprecated: containsAny will be removed in the feature - imports, containsAny, err := parser.ParseImport(g.Ctx.ProtoFileSrc) - if err != nil { - return err - } - - err = g.protocGenGo(pbPath, imports) - if err != nil { - return err - } - ast, err := parser.Transfer(g.Ctx.ProtoFileSrc, pbPath, imports, g.Ctx.Console) - if err != nil { - return err - } - ast.ContainsAny = containsAny - - if len(ast.Service) == 0 { - return fmt.Errorf("service not found") - } - g.ast = ast - return nil -} - -func (g *defaultRpcGenerator) protocGenGo(target string, imports []*parser.Import) error { - dir := filepath.Dir(g.Ctx.ProtoFileSrc) - // cmd join,see the document of proto generating class @https://developers.google.com/protocol-buffers/docs/proto3#generating - // template: protoc -I=${import_path} -I=${other_import_path} -I=${...} --go_out=plugins=grpc,M${pb_package_kv}, M${...} :${target_dir} - // eg: protoc -I=${GOPATH}/src -I=. example.proto --go_out=plugins=grpc,Mbase/base.proto=github.com/go-zero/base.proto:. - // note: the external import out of the project which are found in ${GOPATH}/src so far. - - buffer := new(bytes.Buffer) - buffer.WriteString(protocCmd + " ") - targetImportFiltered := collection.NewSet() - - for _, item := range imports { - buffer.WriteString(fmt.Sprintf("-I=%s ", item.OriginalDir)) - if len(item.BridgeImport) == 0 { - continue - } - targetImportFiltered.AddStr(item.BridgeImport) - - } - buffer.WriteString("-I=${GOPATH}/src ") - buffer.WriteString(fmt.Sprintf("-I=%s %s ", dir, g.Ctx.ProtoFileSrc)) - - buffer.WriteString(grpcPluginCmd) - if targetImportFiltered.Count() > 0 { - buffer.WriteString(fmt.Sprintf(",%v", strings.Join(targetImportFiltered.KeysStr(), ","))) - } - buffer.WriteString(":" + target) - g.Ctx.Debug("-> " + buffer.String()) - stdout, err := execx.Run(buffer.String(), "") - if err != nil { - return err - } - - if len(stdout) > 0 { - g.Ctx.Info(stdout) - } - - return nil -} diff --git a/tools/goctl/rpc/gen/genserver.go b/tools/goctl/rpc/gen/genserver.go deleted file mode 100644 index a9ef3418..00000000 --- a/tools/goctl/rpc/gen/genserver.go +++ /dev/null @@ -1,116 +0,0 @@ -package gen - -import ( - "fmt" - "path/filepath" - "strings" - - "github.com/tal-tech/go-zero/core/collection" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" -) - -const ( - serverTemplate = `{{.head}} - -package server - -import ( - "context" - - {{.imports}} -) - -type {{.types}} - -func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server { - return &{{.server}}Server{ - svcCtx: svcCtx, - } -} - -{{.funcs}} -` - functionTemplate = ` -{{if .hasComment}}{{.comment}}{{end}} -func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) { - l := logic.New{{.logicName}}(ctx,s.svcCtx) - return l.{{.method}}(in) -} -` - typeFmt = `%sServer struct { - svcCtx *svc.ServiceContext - }` -) - -func (g *defaultRpcGenerator) genHandler() error { - serverPath := g.dirM[dirServer] - file := g.ast - logicImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic)) - svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc)) - imports := collection.NewSet() - imports.AddStr(logicImport, svcImport) - - head := util.GetHead(g.Ctx.ProtoSource) - for _, service := range file.Service { - filename := fmt.Sprintf("%vserver.go", service.Name.Lower()) - serverFile := filepath.Join(serverPath, filename) - funcList, importList, err := g.genFunctions(service) - if err != nil { - return err - } - - imports.AddStr(importList...) - text, err := util.LoadTemplate(category, serverTemplateFile, serverTemplate) - if err != nil { - return err - } - - err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ - "head": head, - "types": fmt.Sprintf(typeFmt, service.Name.Title()), - "server": service.Name.Title(), - "imports": strings.Join(imports.KeysStr(), util.NL), - "funcs": strings.Join(funcList, util.NL), - }, serverFile, true) - if err != nil { - return err - } - } - return nil -} - -func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string, []string, error) { - file := g.ast - pkg := file.Package - var functionList []string - imports := collection.NewSet() - for _, method := range service.Funcs { - if method.ParameterIn.Package == pkg || method.ParameterOut.Package == pkg { - imports.AddStr(fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb))) - } - imports.AddStr(g.ast.Imports[method.ParameterIn.Package]) - imports.AddStr(g.ast.Imports[method.ParameterOut.Package]) - text, err := util.LoadTemplate(category, serverFuncTemplateFile, functionTemplate) - if err != nil { - return nil, nil, err - } - - buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{ - "server": service.Name.Title(), - "logicName": fmt.Sprintf("%sLogic", method.Name.Title()), - "method": method.Name.Title(), - "package": pkg, - "request": method.ParameterIn.StarExpression, - "response": method.ParameterOut.StarExpression, - "hasComment": method.HaveDoc(), - "comment": method.GetDoc(), - }) - if err != nil { - return nil, nil, err - } - - functionList = append(functionList, buffer.String()) - } - return functionList, imports.KeysStr(), nil -} diff --git a/tools/goctl/rpc/gen/gomod.go b/tools/goctl/rpc/gen/gomod.go deleted file mode 100644 index 070b4d9e..00000000 --- a/tools/goctl/rpc/gen/gomod.go +++ /dev/null @@ -1,22 +0,0 @@ -package gen - -import ( - "fmt" - - "github.com/tal-tech/go-zero/core/logx" - "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" -) - -func (g *defaultRpcGenerator) initGoMod() error { - if !g.Ctx.IsInGoEnv { - projectDir := g.dirM[dirTarget] - cmd := fmt.Sprintf("go mod init %s", g.Ctx.ProjectName.Source()) - output, err := execx.Run(fmt.Sprintf(cmd), projectDir) - if err != nil { - logx.Error(err) - return err - } - g.Ctx.Info(output) - } - return nil -} diff --git a/tools/goctl/rpc/gen_test.go b/tools/goctl/rpc/gen_test.go deleted file mode 100644 index e32e99b2..00000000 --- a/tools/goctl/rpc/gen_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package base - -import ( - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util/console" -) - -func TestParseImport(t *testing.T) { - src, _ := filepath.Abs("./test.proto") - base, _ := filepath.Abs("./base.proto") - imports, containsAny, err := parser.ParseImport(src) - assert.Nil(t, err) - assert.Equal(t, true, containsAny) - assert.Equal(t, 1, len(imports)) - assert.Equal(t, "github.com/tal-tech/go-zero/tools/goctl/rpc", imports[0].PbImportName) - assert.Equal(t, base, imports[0].OriginalProtoPath) -} - -func TestTransfer(t *testing.T) { - src, _ := filepath.Abs("./test.proto") - abs, _ := filepath.Abs("./test") - imports, _, _ := parser.ParseImport(src) - proto, err := parser.Transfer(src, abs, imports, console.NewConsole(false)) - assert.Nil(t, err) - assert.Equal(t, 1, len(proto.Service)) - assert.Equal(t, "Greeter", proto.Service[0].Name.Source()) - assert.Equal(t, 5, len(proto.Structure)) - data, ok := proto.Structure["map"] - assert.Equal(t, true, ok) - assert.Equal(t, "M", data.Field[0].Name.Source()) -} diff --git a/tools/goctl/rpc/generator/base/common.pb.go b/tools/goctl/rpc/generator/base/common.pb.go new file mode 100644 index 00000000..455529fe --- /dev/null +++ b/tools/goctl/rpc/generator/base/common.pb.go @@ -0,0 +1,75 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: common.proto + +package common + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type User struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *User) Reset() { *m = User{} } +func (m *User) String() string { return proto.CompactTextString(m) } +func (*User) ProtoMessage() {} +func (*User) Descriptor() ([]byte, []int) { + return fileDescriptor_555bd8c177793206, []int{0} +} + +func (m *User) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_User.Unmarshal(m, b) +} +func (m *User) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_User.Marshal(b, m, deterministic) +} +func (m *User) XXX_Merge(src proto.Message) { + xxx_messageInfo_User.Merge(m, src) +} +func (m *User) XXX_Size() int { + return xxx_messageInfo_User.Size(m) +} +func (m *User) XXX_DiscardUnknown() { + xxx_messageInfo_User.DiscardUnknown(m) +} + +var xxx_messageInfo_User proto.InternalMessageInfo + +func (m *User) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func init() { + proto.RegisterType((*User)(nil), "common.User") +} + +func init() { proto.RegisterFile("common.proto", fileDescriptor_555bd8c177793206) } + +var fileDescriptor_555bd8c177793206 = []byte{ + // 72 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x49, 0xce, 0xcf, 0xcd, + 0xcd, 0xcf, 0xd3, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x83, 0xf0, 0x94, 0xa4, 0xb8, 0x58, + 0x42, 0x8b, 0x53, 0x8b, 0x84, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53, 0x25, 0x18, 0x15, 0x18, + 0x35, 0x38, 0x83, 0xc0, 0xec, 0x24, 0x36, 0xb0, 0x52, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff, + 0x2c, 0x6d, 0x58, 0x59, 0x3a, 0x00, 0x00, 0x00, +} diff --git a/tools/goctl/rpc/generator/base/common.proto b/tools/goctl/rpc/generator/base/common.proto new file mode 100644 index 00000000..42c9107e --- /dev/null +++ b/tools/goctl/rpc/generator/base/common.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package common; + +message User { + string name = 1; +} diff --git a/tools/goctl/rpc/generator/defaultgenerator.go b/tools/goctl/rpc/generator/defaultgenerator.go new file mode 100644 index 00000000..eb5853d1 --- /dev/null +++ b/tools/goctl/rpc/generator/defaultgenerator.go @@ -0,0 +1,33 @@ +package generator + +import ( + "os/exec" + + "github.com/tal-tech/go-zero/tools/goctl/util/console" +) + +type defaultGenerator struct { + log console.Console +} + +func NewDefaultGenerator() *defaultGenerator { + log := console.NewColorConsole() + return &defaultGenerator{ + log: log, + } +} + +func (g *defaultGenerator) Prepare() error { + _, err := exec.LookPath("go") + if err != nil { + return err + } + + _, err = exec.LookPath("protoc") + if err != nil { + return err + } + + _, err = exec.LookPath("protoc-gen-go") + return err +} diff --git a/tools/goctl/rpc/generator/filename.go b/tools/goctl/rpc/generator/filename.go new file mode 100644 index 00000000..89617ded --- /dev/null +++ b/tools/goctl/rpc/generator/filename.go @@ -0,0 +1,11 @@ +package generator + +import ( + "strings" + + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" +) + +func formatFilename(filename string) string { + return strings.ToLower(stringx.From(filename).ToCamel()) +} diff --git a/tools/goctl/rpc/generator/gen.go b/tools/goctl/rpc/generator/gen.go new file mode 100644 index 00000000..d09c6394 --- /dev/null +++ b/tools/goctl/rpc/generator/gen.go @@ -0,0 +1,98 @@ +package generator + +import ( + "path/filepath" + + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/console" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" +) + +type RpcGenerator struct { + g Generator +} + +func NewDefaultRpcGenerator() *RpcGenerator { + return NewRpcGenerator(NewDefaultGenerator()) +} + +func NewRpcGenerator(g Generator) *RpcGenerator { + return &RpcGenerator{ + g: g, + } +} + +func (g *RpcGenerator) Generate(src, target string, protoImportPath []string) error { + abs, err := filepath.Abs(target) + if err != nil { + return err + } + + err = util.MkdirIfNotExist(abs) + if err != nil { + return err + } + + err = g.g.Prepare() + if err != nil { + return err + } + + projectCtx, err := ctx.Prepare(abs) + if err != nil { + return err + } + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse(src) + if err != nil { + return err + } + + dirCtx, err := mkdir(projectCtx, proto) + if err != nil { + return err + } + + err = g.g.GenEtc(dirCtx, proto) + if err != nil { + return err + } + + err = g.g.GenPb(dirCtx, protoImportPath, proto) + if err != nil { + return err + } + + err = g.g.GenConfig(dirCtx, proto) + if err != nil { + return err + } + + err = g.g.GenSvc(dirCtx, proto) + if err != nil { + return err + } + + err = g.g.GenLogic(dirCtx, proto) + if err != nil { + return err + } + + err = g.g.GenServer(dirCtx, proto) + if err != nil { + return err + } + + err = g.g.GenMain(dirCtx, proto) + if err != nil { + return err + } + + err = g.g.GenCall(dirCtx, proto) + + console.NewColorConsole().MarkDone() + + return err +} diff --git a/tools/goctl/rpc/generator/gen_test.go b/tools/goctl/rpc/generator/gen_test.go new file mode 100644 index 00000000..25c2126d --- /dev/null +++ b/tools/goctl/rpc/generator/gen_test.go @@ -0,0 +1,104 @@ +package generator + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" +) + +func TestRpcGenerateCaseNilImport(t *testing.T) { + dispatcher := NewDefaultGenerator() + if err := dispatcher.Prepare(); err == nil { + g := NewRpcGenerator(dispatcher) + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + err = g.Generate("./test_stream.proto", abs, nil) + defer func() { + _ = os.RemoveAll(abs) + }() + assert.Nil(t, err) + + _, err = execx.Run("go test "+abs, abs) + assert.Nil(t, err) + } +} + +func TestRpcGenerateCaseOption(t *testing.T) { + dispatcher := NewDefaultGenerator() + if err := dispatcher.Prepare(); err == nil { + g := NewRpcGenerator(dispatcher) + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + err = g.Generate("./test_option.proto", abs, nil) + defer func() { + _ = os.RemoveAll(abs) + }() + assert.Nil(t, err) + + _, err = execx.Run("go test "+abs, abs) + assert.Nil(t, err) + } +} + +func TestRpcGenerateCaseWordOption(t *testing.T) { + dispatcher := NewDefaultGenerator() + if err := dispatcher.Prepare(); err == nil { + g := NewRpcGenerator(dispatcher) + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + err = g.Generate("./test_word_option.proto", abs, nil) + defer func() { + _ = os.RemoveAll(abs) + }() + assert.Nil(t, err) + + _, err = execx.Run("go test "+abs, abs) + assert.Nil(t, err) + } +} + +// test keyword go +func TestRpcGenerateCaseGoOption(t *testing.T) { + dispatcher := NewDefaultGenerator() + if err := dispatcher.Prepare(); err == nil { + g := NewRpcGenerator(dispatcher) + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + err = g.Generate("./test_go_option.proto", abs, nil) + defer func() { + _ = os.RemoveAll(abs) + }() + assert.Nil(t, err) + + _, err = execx.Run("go test "+abs, abs) + assert.Nil(t, err) + } +} + +func TestRpcGenerateCaseImport(t *testing.T) { + dispatcher := NewDefaultGenerator() + if err := dispatcher.Prepare(); err == nil { + g := NewRpcGenerator(dispatcher) + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + err = g.Generate("./test_import.proto", abs, []string{"./base"}) + defer func() { + _ = os.RemoveAll(abs) + }() + assert.Nil(t, err) + + _, err = execx.Run("go test "+abs, abs) + assert.True(t, func() bool { + return strings.Contains(err.Error(), "package base is not in GOROOT") + }()) + } +} diff --git a/tools/goctl/rpc/generator/gencall.go b/tools/goctl/rpc/generator/gencall.go new file mode 100644 index 00000000..4afc1f04 --- /dev/null +++ b/tools/goctl/rpc/generator/gencall.go @@ -0,0 +1,154 @@ +package generator + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" +) + +const ( + callTemplateText = `{{.head}} + +//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE + +package {{.filePackage}} + +import ( + "context" + + {{.package}} + + "github.com/tal-tech/go-zero/zrpc" +) + +type ( + {{.alias}} + + {{.serviceName}} interface { + {{.interface}} + } + + default{{.serviceName}} struct { + cli zrpc.Client + } +) + +func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} { + return &default{{.serviceName}}{ + cli: cli, + } +} + +{{.functions}} +` + + callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}} +{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)` + + callFunctionTemplate = ` +{{if .hasComment}}{{.comment}}{{end}} +func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) { + client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn()) + return client.{{.method}}(ctx, in) +} +` +) + +func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error { + dir := ctx.GetCall() + service := proto.Service + head := util.GetHead(proto.Name) + + filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", formatFilename(service.Name))) + functions, err := g.genFunction(proto.PbPackage, service) + if err != nil { + return err + } + + iFunctions, err := g.getInterfaceFuncs(service) + if err != nil { + return err + } + + text, err := util.LoadTemplate(category, callTemplateFile, callTemplateText) + if err != nil { + return err + } + + var alias []string + for _, item := range service.RPC { + alias = append(alias, fmt.Sprintf("%s = %s", parser.CamelCase(item.RequestType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.RequestType)))) + alias = append(alias, fmt.Sprintf("%s = %s", parser.CamelCase(item.ReturnsType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.ReturnsType)))) + } + + err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ + "name": formatFilename(service.Name), + "alias": strings.Join(alias, util.NL), + "head": head, + "filePackage": formatFilename(service.Name), + "package": fmt.Sprintf(`"%s"`, ctx.GetPb().Package), + "serviceName": parser.CamelCase(service.Name), + "functions": strings.Join(functions, util.NL), + "interface": strings.Join(iFunctions, util.NL), + }, filename, true) + return err +} + +func (g *defaultGenerator) genFunction(goPackage string, service parser.Service) ([]string, error) { + functions := make([]string, 0) + for _, rpc := range service.RPC { + text, err := util.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate) + if err != nil { + return nil, err + } + + comment := parser.GetComment(rpc.Doc()) + buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{ + "rpcServiceName": stringx.From(service.Name).Title(), + "method": stringx.From(rpc.Name).Title(), + "package": goPackage, + "pbRequest": parser.CamelCase(rpc.RequestType), + "pbResponse": parser.CamelCase(rpc.ReturnsType), + "hasComment": len(comment) > 0, + "comment": comment, + }) + if err != nil { + return nil, err + } + + functions = append(functions, buffer.String()) + } + return functions, nil +} + +func (g *defaultGenerator) getInterfaceFuncs(service parser.Service) ([]string, error) { + functions := make([]string, 0) + + for _, rpc := range service.RPC { + text, err := util.LoadTemplate(category, callInterfaceFunctionTemplateFile, callInterfaceFunctionTemplate) + if err != nil { + return nil, err + } + + comment := parser.GetComment(rpc.Doc()) + buffer, err := util.With("interfaceFn").Parse(text).Execute( + map[string]interface{}{ + "hasComment": len(comment) > 0, + "comment": comment, + "method": stringx.From(rpc.Name).Title(), + "pbRequest": parser.CamelCase(rpc.RequestType), + "pbResponse": parser.CamelCase(rpc.ReturnsType), + }) + if err != nil { + return nil, err + } + + functions = append(functions, buffer.String()) + } + + return functions, nil +} diff --git a/tools/goctl/rpc/generator/gencall_test.go b/tools/goctl/rpc/generator/gencall_test.go new file mode 100644 index 00000000..6ae0cefe --- /dev/null +++ b/tools/goctl/rpc/generator/gencall_test.go @@ -0,0 +1,44 @@ +package generator + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" +) + +func TestGenerateCall(t *testing.T) { + _ = Clean() + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_stream.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + err = g.Prepare() + if err != nil { + return + } + err = g.GenCall(dirCtx, proto) + assert.Nil(t, err) +} diff --git a/tools/goctl/rpc/gen/genconfig.go b/tools/goctl/rpc/generator/genconfig.go similarity index 64% rename from tools/goctl/rpc/gen/genconfig.go rename to tools/goctl/rpc/generator/genconfig.go index 91daf5c9..c9cd4bc4 100644 --- a/tools/goctl/rpc/gen/genconfig.go +++ b/tools/goctl/rpc/generator/genconfig.go @@ -1,10 +1,11 @@ -package gen +package generator import ( "io/ioutil" "os" "path/filepath" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/util" ) @@ -17,9 +18,9 @@ type Config struct { } ` -func (g *defaultRpcGenerator) genConfig() error { - configPath := g.dirM[dirConfig] - fileName := filepath.Join(configPath, fileConfig) +func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto) error { + dir := ctx.GetConfig() + fileName := filepath.Join(dir.Filename, formatFilename("config")+".go") if util.FileExists(fileName) { return nil } diff --git a/tools/goctl/rpc/generator/genconfig_test.go b/tools/goctl/rpc/generator/genconfig_test.go new file mode 100644 index 00000000..39006b10 --- /dev/null +++ b/tools/goctl/rpc/generator/genconfig_test.go @@ -0,0 +1,48 @@ +package generator + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" +) + +func TestGenerateConfig(t *testing.T) { + _ = Clean() + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_stream.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + err = g.Prepare() + if err != nil { + return + } + err = g.GenConfig(dirCtx, proto) + assert.Nil(t, err) + + // test file exists + err = g.GenConfig(dirCtx, proto) + assert.Nil(t, err) +} diff --git a/tools/goctl/rpc/generator/generator.go b/tools/goctl/rpc/generator/generator.go new file mode 100644 index 00000000..a46ed42c --- /dev/null +++ b/tools/goctl/rpc/generator/generator.go @@ -0,0 +1,15 @@ +package generator + +import "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + +type Generator interface { + Prepare() error + GenMain(ctx DirContext, proto parser.Proto) error + GenCall(ctx DirContext, proto parser.Proto) error + GenEtc(ctx DirContext, proto parser.Proto) error + GenConfig(ctx DirContext, proto parser.Proto) error + GenLogic(ctx DirContext, proto parser.Proto) error + GenServer(ctx DirContext, proto parser.Proto) error + GenSvc(ctx DirContext, proto parser.Proto) error + GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error +} diff --git a/tools/goctl/rpc/gen/genetc.go b/tools/goctl/rpc/generator/genetc.go similarity index 55% rename from tools/goctl/rpc/gen/genetc.go rename to tools/goctl/rpc/generator/genetc.go index c2e64b53..ddd36fba 100644 --- a/tools/goctl/rpc/gen/genetc.go +++ b/tools/goctl/rpc/generator/genetc.go @@ -1,9 +1,10 @@ -package gen +package generator import ( "fmt" "path/filepath" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/util" ) @@ -15,12 +16,10 @@ Etcd: Key: {{.serviceName}}.rpc ` -func (g *defaultRpcGenerator) genEtc() error { - etdDir := g.dirM[dirEtc] - fileName := filepath.Join(etdDir, fmt.Sprintf("%v.yaml", g.Ctx.ServiceName.Lower())) - if util.FileExists(fileName) { - return nil - } +func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto) error { + dir := ctx.GetEtc() + serviceNameLower := formatFilename(ctx.GetMain().Base) + fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.yaml", serviceNameLower)) text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate) if err != nil { @@ -28,6 +27,6 @@ func (g *defaultRpcGenerator) genEtc() error { } return util.With("etc").Parse(text).SaveTo(map[string]interface{}{ - "serviceName": g.Ctx.ServiceName.Lower(), + "serviceName": serviceNameLower, }, fileName, false) } diff --git a/tools/goctl/rpc/generator/genetc_test.go b/tools/goctl/rpc/generator/genetc_test.go new file mode 100644 index 00000000..457cfed4 --- /dev/null +++ b/tools/goctl/rpc/generator/genetc_test.go @@ -0,0 +1,45 @@ +package generator + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" +) + +func TestGenerateEtc(t *testing.T) { + _ = Clean() + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_stream.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + err = g.Prepare() + if err != nil { + return + } + + err = g.GenEtc(dirCtx, proto) + assert.Nil(t, err) +} diff --git a/tools/goctl/rpc/generator/genlogic.go b/tools/goctl/rpc/generator/genlogic.go new file mode 100644 index 00000000..a48c9092 --- /dev/null +++ b/tools/goctl/rpc/generator/genlogic.go @@ -0,0 +1,101 @@ +package generator + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/tal-tech/go-zero/core/collection" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" +) + +const ( + logicTemplate = `package logic + +import ( + "context" + + {{.imports}} + + "github.com/tal-tech/go-zero/core/logx" +) + +type {{.logicName}} struct { + ctx context.Context + svcCtx *svc.ServiceContext + logx.Logger +} + +func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} { + return &{{.logicName}}{ + ctx: ctx, + svcCtx: svcCtx, + Logger: logx.WithContext(ctx), + } +} +{{.functions}} +` + logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}} +func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) { + // todo: add your logic here and delete this line + + return &{{.responseType}}{}, nil +} +` +) + +func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto) error { + dir := ctx.GetLogic() + for _, rpc := range proto.Service.RPC { + filename := filepath.Join(dir.Filename, formatFilename(rpc.Name+"_logic")+".go") + functions, err := g.genLogicFunction(proto.PbPackage, rpc) + if err != nil { + return err + } + + imports := collection.NewSet() + imports.AddStr(fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)) + imports.AddStr(fmt.Sprintf(`"%v"`, ctx.GetPb().Package)) + text, err := util.LoadTemplate(category, logicTemplateFileFile, logicTemplate) + if err != nil { + return err + } + err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ + "logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).Title()), + "functions": functions, + "imports": strings.Join(imports.KeysStr(), util.NL), + }, filename, false) + if err != nil { + return err + } + } + return nil +} + +func (g *defaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) (string, error) { + var functions = make([]string, 0) + text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate) + if err != nil { + return "", err + } + + logicName := stringx.From(rpc.Name + "_logic").ToCamel() + comment := parser.GetComment(rpc.Doc()) + buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{ + "logicName": logicName, + "method": parser.CamelCase(rpc.Name), + "request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)), + "response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)), + "responseType": fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)), + "hasComment": len(comment) > 0, + "comment": comment, + }) + if err != nil { + return "", err + } + + functions = append(functions, buffer.String()) + return strings.Join(functions, util.NL), nil +} diff --git a/tools/goctl/rpc/generator/genlogic_test.go b/tools/goctl/rpc/generator/genlogic_test.go new file mode 100644 index 00000000..681c89d7 --- /dev/null +++ b/tools/goctl/rpc/generator/genlogic_test.go @@ -0,0 +1,44 @@ +package generator + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" +) + +func TestGenerateLogic(t *testing.T) { + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_stream.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + err = g.Prepare() + if err != nil { + return + } + + err = g.GenLogic(dirCtx, proto) + assert.Nil(t, err) +} diff --git a/tools/goctl/rpc/generator/genmain.go b/tools/goctl/rpc/generator/genmain.go new file mode 100644 index 00000000..eed6a7d2 --- /dev/null +++ b/tools/goctl/rpc/generator/genmain.go @@ -0,0 +1,70 @@ +package generator + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" +) + +const mainTemplate = `{{.head}} + +package main + +import ( + "flag" + "fmt" + + {{.imports}} + + "github.com/tal-tech/go-zero/core/conf" + "github.com/tal-tech/go-zero/zrpc" + "google.golang.org/grpc" +) + +var configFile = flag.String("f", "etc/{{.serviceName}}.yaml", "the config file") + +func main() { + flag.Parse() + + var c config.Config + conf.MustLoad(*configFile, &c) + ctx := svc.NewServiceContext(c) + srv := server.New{{.service}}Server(ctx) + + s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) { + {{.pkg}}.Register{{.service}}Server(grpcServer, srv) + }) + defer s.Stop() + + fmt.Printf("Starting rpc server at %s...\n", c.ListenOn) + s.Start() +} +` + +func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error { + dir := ctx.GetMain() + serviceNameLower := formatFilename(ctx.GetMain().Base) + fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.go", serviceNameLower)) + imports := make([]string, 0) + pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package) + svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package) + remoteImport := fmt.Sprintf(`"%v"`, ctx.GetServer().Package) + configImport := fmt.Sprintf(`"%v"`, ctx.GetConfig().Package) + imports = append(imports, configImport, pbImport, remoteImport, svcImport) + head := util.GetHead(proto.Name) + text, err := util.LoadTemplate(category, mainTemplateFile, mainTemplate) + if err != nil { + return err + } + + return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ + "head": head, + "serviceName": serviceNameLower, + "imports": strings.Join(imports, util.NL), + "pkg": proto.PbPackage, + "service": parser.CamelCase(proto.Service.Name), + }, fileName, false) +} diff --git a/tools/goctl/rpc/generator/genmain_test.go b/tools/goctl/rpc/generator/genmain_test.go new file mode 100644 index 00000000..aed65b02 --- /dev/null +++ b/tools/goctl/rpc/generator/genmain_test.go @@ -0,0 +1,45 @@ +package generator + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" +) + +func TestGenerateMain(t *testing.T) { + _ = Clean() + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_stream.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + err = g.Prepare() + if err != nil { + return + } + + err = g.GenMain(dirCtx, proto) + assert.Nil(t, err) +} diff --git a/tools/goctl/rpc/generator/genpb.go b/tools/goctl/rpc/generator/genpb.go new file mode 100644 index 00000000..4327ae52 --- /dev/null +++ b/tools/goctl/rpc/generator/genpb.go @@ -0,0 +1,31 @@ +package generator + +import ( + "bytes" + "path/filepath" + "strings" + + "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" +) + +func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error { + dir := ctx.GetPb() + cw := new(bytes.Buffer) + base := filepath.Dir(proto.Src) + cw.WriteString("protoc ") + for _, ip := range protoImportPath { + cw.WriteString(" -I=" + ip) + } + cw.WriteString(" -I=" + base) + cw.WriteString(" " + proto.Name) + if strings.Contains(proto.GoPackage, "/") { + cw.WriteString(" --go_out=plugins=grpc:" + ctx.GetInternal().Filename) + } else { + cw.WriteString(" --go_out=plugins=grpc:" + dir.Filename) + } + command := cw.String() + g.log.Debug(command) + _, err := execx.Run(command, "") + return err +} diff --git a/tools/goctl/rpc/generator/genpb_test.go b/tools/goctl/rpc/generator/genpb_test.go new file mode 100644 index 00000000..1c423072 --- /dev/null +++ b/tools/goctl/rpc/generator/genpb_test.go @@ -0,0 +1,184 @@ +package generator + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" +) + +func TestGenerateCaseNilImport(t *testing.T) { + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + //_ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_stream.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + if err := g.Prepare(); err == nil { + targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go") + err = g.GenPb(dirCtx, nil, proto) + assert.Nil(t, err) + assert.True(t, func() bool { + return util.FileExists(targetPb) + }()) + } +} + +func TestGenerateCaseImport(t *testing.T) { + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_stream.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + if err := g.Prepare(); err == nil { + err = g.GenPb(dirCtx, nil, proto) + assert.Nil(t, err) + + targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go") + assert.True(t, func() bool { + return util.FileExists(targetPb) + }()) + } +} + +func TestGenerateCasePathOption(t *testing.T) { + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_option.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + if err := g.Prepare(); err == nil { + err = g.GenPb(dirCtx, nil, proto) + assert.Nil(t, err) + + targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_option.pb.go") + assert.True(t, func() bool { + return util.FileExists(targetPb) + }()) + } +} + +func TestGenerateCaseWordOption(t *testing.T) { + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_word_option.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + if err := g.Prepare(); err == nil { + + err = g.GenPb(dirCtx, nil, proto) + assert.Nil(t, err) + + targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_word_option.pb.go") + assert.True(t, func() bool { + return util.FileExists(targetPb) + }()) + } +} + +// test keyword go +func TestGenerateCaseGoOption(t *testing.T) { + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_go_option.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + if err := g.Prepare(); err == nil { + + err = g.GenPb(dirCtx, nil, proto) + assert.Nil(t, err) + + targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_go_option.pb.go") + assert.True(t, func() bool { + return util.FileExists(targetPb) + }()) + } +} diff --git a/tools/goctl/rpc/generator/genserver.go b/tools/goctl/rpc/generator/genserver.go new file mode 100644 index 00000000..6cc8e820 --- /dev/null +++ b/tools/goctl/rpc/generator/genserver.go @@ -0,0 +1,102 @@ +package generator + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/tal-tech/go-zero/core/collection" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" +) + +const ( + serverTemplate = `{{.head}} + +package server + +import ( + "context" + + {{.imports}} +) + +type {{.server}}Server struct { + svcCtx *svc.ServiceContext +} + +func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server { + return &{{.server}}Server{ + svcCtx: svcCtx, + } +} + +{{.funcs}} +` + functionTemplate = ` +{{if .hasComment}}{{.comment}}{{end}} +func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) { + l := logic.New{{.logicName}}(ctx,s.svcCtx) + return l.{{.method}}(in) +} +` +) + +func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error { + dir := ctx.GetServer() + logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package) + svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package) + pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package) + + imports := collection.NewSet() + imports.AddStr(logicImport, svcImport, pbImport) + + head := util.GetHead(proto.Name) + service := proto.Service + serverFile := filepath.Join(dir.Filename, formatFilename(service.Name+"_server")+".go") + funcList, err := g.genFunctions(proto.PbPackage, service) + if err != nil { + return err + } + + text, err := util.LoadTemplate(category, serverTemplateFile, serverTemplate) + if err != nil { + return err + } + + err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ + "head": head, + "server": stringx.From(service.Name).Title(), + "imports": strings.Join(imports.KeysStr(), util.NL), + "funcs": strings.Join(funcList, util.NL), + }, serverFile, true) + return err +} + +func (g *defaultGenerator) genFunctions(goPackage string, service parser.Service) ([]string, error) { + var functionList []string + for _, rpc := range service.RPC { + text, err := util.LoadTemplate(category, serverFuncTemplateFile, functionTemplate) + if err != nil { + return nil, err + } + + comment := parser.GetComment(rpc.Doc()) + buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{ + "server": stringx.From(service.Name).Title(), + "logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).Title()), + "method": parser.CamelCase(rpc.Name), + "request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)), + "response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)), + "hasComment": len(comment) > 0, + "comment": comment, + }) + if err != nil { + return nil, err + } + + functionList = append(functionList, buffer.String()) + } + return functionList, nil +} diff --git a/tools/goctl/rpc/generator/genserver_test.go b/tools/goctl/rpc/generator/genserver_test.go new file mode 100644 index 00000000..e5f1e3f6 --- /dev/null +++ b/tools/goctl/rpc/generator/genserver_test.go @@ -0,0 +1,45 @@ +package generator + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" +) + +func TestGenerateServer(t *testing.T) { + _ = Clean() + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_stream.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + err = g.Prepare() + if err != nil { + return + } + + err = g.GenServer(dirCtx, proto) + assert.Nil(t, err) +} diff --git a/tools/goctl/rpc/gen/gensvc.go b/tools/goctl/rpc/generator/gensvc.go similarity index 61% rename from tools/goctl/rpc/gen/gensvc.go rename to tools/goctl/rpc/generator/gensvc.go index a1dcb3d1..86df41b4 100644 --- a/tools/goctl/rpc/gen/gensvc.go +++ b/tools/goctl/rpc/generator/gensvc.go @@ -1,9 +1,10 @@ -package gen +package generator import ( "fmt" "path/filepath" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/util" ) @@ -22,15 +23,15 @@ func NewServiceContext(c config.Config) *ServiceContext { } ` -func (g *defaultRpcGenerator) genSvc() error { - svcPath := g.dirM[dirSvc] - fileName := filepath.Join(svcPath, fileServiceContext) +func (g *defaultGenerator) GenSvc(ctx DirContext, _ parser.Proto) error { + dir := ctx.GetSvc() + fileName := filepath.Join(dir.Filename, formatFilename("service_context")+".go") text, err := util.LoadTemplate(category, svcTemplateFile, svcTemplate) if err != nil { return err } return util.With("svc").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ - "imports": fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig)), + "imports": fmt.Sprintf(`"%v"`, ctx.GetConfig().Package), }, fileName, false) } diff --git a/tools/goctl/rpc/generator/gensvc_test.go b/tools/goctl/rpc/generator/gensvc_test.go new file mode 100644 index 00000000..6bf43e0d --- /dev/null +++ b/tools/goctl/rpc/generator/gensvc_test.go @@ -0,0 +1,40 @@ +package generator + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" +) + +func TestGenerateSvc(t *testing.T) { + _ = Clean() + project := "stream" + abs, err := filepath.Abs("./test") + assert.Nil(t, err) + + dir := filepath.Join(abs, project) + err = util.MkdirIfNotExist(dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(abs) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test_stream.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + + g := NewDefaultGenerator() + err = g.GenSvc(dirCtx, proto) + assert.Nil(t, err) +} diff --git a/tools/goctl/rpc/generator/mkdir.go b/tools/goctl/rpc/generator/mkdir.go new file mode 100644 index 00000000..e0cb6d30 --- /dev/null +++ b/tools/goctl/rpc/generator/mkdir.go @@ -0,0 +1,152 @@ +package generator + +import ( + "path/filepath" + "strings" + + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" +) + +const ( + wd = "wd" + etc = "etc" + internal = "internal" + config = "config" + logic = "logic" + server = "server" + svc = "svc" + pb = "pb" + call = "call" +) + +type ( + DirContext interface { + GetCall() Dir + GetEtc() Dir + GetInternal() Dir + GetConfig() Dir + GetLogic() Dir + GetServer() Dir + GetSvc() Dir + GetPb() Dir + GetMain() Dir + } + + Dir struct { + Base string + Filename string + Package string + } + defaultDirContext struct { + inner map[string]Dir + } +) + +func mkdir(ctx *ctx.ProjectContext, proto parser.Proto) (DirContext, error) { + inner := make(map[string]Dir) + etcDir := filepath.Join(ctx.WorkDir, "etc") + internalDir := filepath.Join(ctx.WorkDir, "internal") + configDir := filepath.Join(internalDir, "config") + logicDir := filepath.Join(internalDir, "logic") + serverDir := filepath.Join(internalDir, "server") + svcDir := filepath.Join(internalDir, "svc") + pbDir := filepath.Join(internalDir, proto.GoPackage) + callDir := filepath.Join(ctx.WorkDir, strings.ToLower(stringx.From(proto.Service.Name).ToCamel())) + inner[wd] = Dir{ + Filename: ctx.WorkDir, + Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(ctx.WorkDir, ctx.Dir))), + Base: filepath.Base(ctx.WorkDir), + } + inner[etc] = Dir{ + Filename: etcDir, + Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(etcDir, ctx.Dir))), + Base: filepath.Base(etcDir), + } + inner[internal] = Dir{ + Filename: internalDir, + Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(internalDir, ctx.Dir))), + Base: filepath.Base(internalDir), + } + inner[config] = Dir{ + Filename: configDir, + Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(configDir, ctx.Dir))), + Base: filepath.Base(configDir), + } + inner[logic] = Dir{ + Filename: logicDir, + Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(logicDir, ctx.Dir))), + Base: filepath.Base(logicDir), + } + inner[server] = Dir{ + Filename: serverDir, + Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(serverDir, ctx.Dir))), + Base: filepath.Base(serverDir), + } + inner[svc] = Dir{ + Filename: svcDir, + Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(svcDir, ctx.Dir))), + Base: filepath.Base(svcDir), + } + inner[pb] = Dir{ + Filename: pbDir, + Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(pbDir, ctx.Dir))), + Base: filepath.Base(pbDir), + } + inner[call] = Dir{ + Filename: callDir, + Package: filepath.ToSlash(filepath.Join(ctx.Path, strings.TrimPrefix(callDir, ctx.Dir))), + Base: filepath.Base(callDir), + } + for _, v := range inner { + err := util.MkdirIfNotExist(v.Filename) + if err != nil { + return nil, err + } + } + return &defaultDirContext{ + inner: inner, + }, nil +} + +func (d *defaultDirContext) GetCall() Dir { + return d.inner[call] +} + +func (d *defaultDirContext) GetEtc() Dir { + return d.inner[etc] +} + +func (d *defaultDirContext) GetInternal() Dir { + return d.inner[internal] +} + +func (d *defaultDirContext) GetConfig() Dir { + return d.inner[config] +} + +func (d *defaultDirContext) GetLogic() Dir { + return d.inner[logic] +} + +func (d *defaultDirContext) GetServer() Dir { + return d.inner[server] +} + +func (d *defaultDirContext) GetSvc() Dir { + return d.inner[svc] +} + +func (d *defaultDirContext) GetPb() Dir { + return d.inner[pb] +} + +func (d *defaultDirContext) GetMain() Dir { + return d.inner[wd] +} + +func (d *Dir) Valid() bool { + return len(d.Filename) > 0 && len(d.Package) > 0 +} diff --git a/tools/goctl/rpc/generator/mkdir_test.go b/tools/goctl/rpc/generator/mkdir_test.go new file mode 100644 index 00000000..ae4f1f3a --- /dev/null +++ b/tools/goctl/rpc/generator/mkdir_test.go @@ -0,0 +1,130 @@ +package generator + +import ( + "go/build" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/stringx" + "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" + "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/ctx" +) + +func TestMkDirInGoPath(t *testing.T) { + dft := build.Default + gp := dft.GOPATH + if len(gp) == 0 { + return + } + projectName := stringx.Rand() + dir := filepath.Join(gp, "src", projectName) + err := util.MkdirIfNotExist(dir) + if err != nil { + return + } + + defer func() { + _ = os.RemoveAll(dir) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + internal := filepath.Join(dir, "internal") + assert.True(t, true, func() bool { + return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package + }()) + assert.True(t, true, func() bool { + return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package + }()) + assert.True(t, true, func() bool { + return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package + }()) +} + +func TestMkDirInGoMod(t *testing.T) { + dft := build.Default + gp := dft.GOPATH + if len(gp) == 0 { + return + } + projectName := stringx.Rand() + dir := filepath.Join(gp, "src", projectName) + err := util.MkdirIfNotExist(dir) + if err != nil { + return + } + + _, err = execx.Run("go mod init "+projectName, dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(dir) + }() + + projectCtx, err := ctx.Prepare(dir) + assert.Nil(t, err) + + p := parser.NewDefaultProtoParser() + proto, err := p.Parse("./test.proto") + assert.Nil(t, err) + + dirCtx, err := mkdir(projectCtx, proto) + assert.Nil(t, err) + internal := filepath.Join(dir, "internal") + assert.True(t, true, func() bool { + return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package + }()) + assert.True(t, true, func() bool { + return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package + }()) + assert.True(t, true, func() bool { + return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package + }()) + assert.True(t, true, func() bool { + return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package + }()) +} diff --git a/tools/goctl/rpc/gen/rpctemplate.go b/tools/goctl/rpc/generator/prototmpl.go similarity index 51% rename from tools/goctl/rpc/gen/rpctemplate.go rename to tools/goctl/rpc/generator/prototmpl.go index daf9d016..a4afd242 100644 --- a/tools/goctl/rpc/gen/rpctemplate.go +++ b/tools/goctl/rpc/generator/prototmpl.go @@ -1,12 +1,10 @@ -package gen +package generator import ( "path/filepath" "strings" - "github.com/logrusorgru/aurora" "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) @@ -27,33 +25,23 @@ service {{.serviceName}} { } ` -type rpcTemplate struct { - out string - console.Console -} - -func NewRpcTemplate(out string, idea bool) *rpcTemplate { - return &rpcTemplate{ - out: out, - Console: console.NewConsole(idea), - } -} - -func (r *rpcTemplate) MustGenerate(showState bool) { - r.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/zero-doc/blob/main/doc/goctl-rpc.md」") - r.Info("-> generating template...") - protoFilename := filepath.Base(r.out) +func ProtoTmpl(out string) error { + protoFilename := filepath.Base(out) serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename))) text, err := util.LoadTemplate(category, rpcTemplateFile, rpcTemplateText) - r.Must(err) + if err != nil { + return err + } + + dir := filepath.Dir(out) + err = util.MkdirIfNotExist(dir) + if err != nil { + return err + } err = util.With("t").Parse(text).SaveTo(map[string]string{ "package": serviceName.UnTitle(), "serviceName": serviceName.Title(), - }, r.out, false) - r.Must(err) - - if showState { - r.Success("Done.") - } + }, out, false) + return err } diff --git a/tools/goctl/rpc/generator/prototmpl_test.go b/tools/goctl/rpc/generator/prototmpl_test.go new file mode 100644 index 00000000..627ec6f4 --- /dev/null +++ b/tools/goctl/rpc/generator/prototmpl_test.go @@ -0,0 +1,21 @@ +package generator + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProtoTmpl(t *testing.T) { + out, err := filepath.Abs("./test/test.proto") + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(filepath.Dir(out)) + }() + err = ProtoTmpl(out) + assert.Nil(t, err) + _, err = os.Stat(out) + assert.Nil(t, err) +} diff --git a/tools/goctl/rpc/gen/template.go b/tools/goctl/rpc/generator/template.go similarity index 94% rename from tools/goctl/rpc/gen/template.go rename to tools/goctl/rpc/generator/template.go index ccbf453f..1816f900 100644 --- a/tools/goctl/rpc/gen/template.go +++ b/tools/goctl/rpc/generator/template.go @@ -1,4 +1,4 @@ -package gen +package generator import ( "fmt" @@ -10,7 +10,6 @@ import ( const ( category = "rpc" callTemplateFile = "call.tpl" - callTypesTemplateFile = "call-types.tpl" callInterfaceFunctionTemplateFile = "call-interface-func.tpl" callFunctionTemplateFile = "call-func.tpl" configTemplateFileFile = "config.tpl" @@ -26,7 +25,6 @@ const ( var templates = map[string]string{ callTemplateFile: callTemplateText, - callTypesTemplateFile: callTemplateTypes, callInterfaceFunctionTemplateFile: callInterfaceFunctionTemplate, callFunctionTemplateFile: callFunctionTemplate, configTemplateFileFile: configTemplate, diff --git a/tools/goctl/rpc/gen/template_test.go b/tools/goctl/rpc/generator/template_test.go similarity index 95% rename from tools/goctl/rpc/gen/template_test.go rename to tools/goctl/rpc/generator/template_test.go index b3f21d8d..89200090 100644 --- a/tools/goctl/rpc/gen/template_test.go +++ b/tools/goctl/rpc/generator/template_test.go @@ -1,4 +1,4 @@ -package gen +package generator import ( "io/ioutil" @@ -90,3 +90,7 @@ func TestUpdate(t *testing.T) { assert.Nil(t, err) assert.Equal(t, mainTemplate, string(data)) } + +func TestGetCategory(t *testing.T) { + assert.Equal(t, category, GetCategory()) +} diff --git a/tools/goctl/rpc/generator/test.proto b/tools/goctl/rpc/generator/test.proto new file mode 100644 index 00000000..1856d6d8 --- /dev/null +++ b/tools/goctl/rpc/generator/test.proto @@ -0,0 +1,25 @@ +// test proto +syntax = "proto3"; + +package test; +option go_package = "go"; + +import "test_base.proto"; + +message TestMessage{ + base.CommonReq req = 1; +} +message TestReq{} +message TestReply{ + base.CommonReply reply = 2; +} + +enum TestEnum { + unknown = 0; + male = 1; + female = 2; +} + +service TestService{ + rpc TestRpc (TestReq)returns(TestReply); +} \ No newline at end of file diff --git a/tools/goctl/rpc/generator/test_base.proto b/tools/goctl/rpc/generator/test_base.proto new file mode 100644 index 00000000..36e0ca5d --- /dev/null +++ b/tools/goctl/rpc/generator/test_base.proto @@ -0,0 +1,12 @@ +// test proto +syntax = "proto3"; + +package base; + +message CommonReq { + string in = 1; +} + +message CommonReply { + string out = 1; +} diff --git a/tools/goctl/rpc/generator/test_go_option.proto b/tools/goctl/rpc/generator/test_go_option.proto new file mode 100644 index 00000000..9c4397b7 --- /dev/null +++ b/tools/goctl/rpc/generator/test_go_option.proto @@ -0,0 +1,18 @@ +// test proto +syntax = "proto3"; + +package stream; + +option go_package="go"; + +message StreamReq { + string name = 1; +} + +message StreamResp { + string greet = 1; +} + +service StreamGreeter { + rpc greet(StreamReq) returns (StreamResp); +} \ No newline at end of file diff --git a/tools/goctl/rpc/generator/test_import.proto b/tools/goctl/rpc/generator/test_import.proto new file mode 100644 index 00000000..4d727aee --- /dev/null +++ b/tools/goctl/rpc/generator/test_import.proto @@ -0,0 +1,18 @@ +// test proto +syntax = "proto3"; + +package greet; +import "base/common.proto"; + +message In { + string name = 1; + common.User user = 2; +} + +message Out { + string greet = 1; +} + +service StreamGreeter { + rpc greet(In) returns (Out); +} \ No newline at end of file diff --git a/tools/goctl/rpc/generator/test_option.proto b/tools/goctl/rpc/generator/test_option.proto new file mode 100644 index 00000000..30df31ab --- /dev/null +++ b/tools/goctl/rpc/generator/test_option.proto @@ -0,0 +1,18 @@ +// test proto +syntax = "proto3"; + +package stream; + +option go_package="github.com/tal-tech/go-zero"; + +message StreamReq { + string name = 1; +} + +message StreamResp { + string greet = 1; +} + +service StreamGreeter { + rpc greet(StreamReq) returns (StreamResp); +} \ No newline at end of file diff --git a/tools/goctl/rpc/generator/test_stream.proto b/tools/goctl/rpc/generator/test_stream.proto new file mode 100644 index 00000000..5006d77d --- /dev/null +++ b/tools/goctl/rpc/generator/test_stream.proto @@ -0,0 +1,16 @@ +// test proto +syntax = "proto3"; + +package stream; + +message StreamReq { + string name = 1; +} + +message StreamResp { + string greet = 1; +} + +service StreamGreeter { + rpc greet(StreamReq) returns (StreamResp); +} \ No newline at end of file diff --git a/tools/goctl/rpc/generator/test_word_option.proto b/tools/goctl/rpc/generator/test_word_option.proto new file mode 100644 index 00000000..4f7753a5 --- /dev/null +++ b/tools/goctl/rpc/generator/test_word_option.proto @@ -0,0 +1,18 @@ +// test proto +syntax = "proto3"; + +package stream; + +option go_package="user"; + +message StreamReq { + string name = 1; +} + +message StreamResp { + string greet = 1; +} + +service StreamGreeter { + rpc greet(StreamReq) returns (StreamResp); +} \ No newline at end of file diff --git a/tools/goctl/rpc/parser/comment.go b/tools/goctl/rpc/parser/comment.go new file mode 100644 index 00000000..ad7bccb8 --- /dev/null +++ b/tools/goctl/rpc/parser/comment.go @@ -0,0 +1,10 @@ +package parser + +import "github.com/emicklei/proto" + +func GetComment(comment *proto.Comment) string { + if comment == nil { + return "" + } + return comment.Message() +} diff --git a/tools/goctl/rpc/parser/import.go b/tools/goctl/rpc/parser/import.go new file mode 100644 index 00000000..5095d52e --- /dev/null +++ b/tools/goctl/rpc/parser/import.go @@ -0,0 +1,7 @@ +package parser + +import "github.com/emicklei/proto" + +type Import struct { + *proto.Import +} diff --git a/tools/goctl/rpc/parser/message.go b/tools/goctl/rpc/parser/message.go new file mode 100644 index 00000000..5f5a7196 --- /dev/null +++ b/tools/goctl/rpc/parser/message.go @@ -0,0 +1,7 @@ +package parser + +import pr "github.com/emicklei/proto" + +type Message struct { + *pr.Message +} diff --git a/tools/goctl/rpc/parser/option.go b/tools/goctl/rpc/parser/option.go new file mode 100644 index 00000000..06f34cf4 --- /dev/null +++ b/tools/goctl/rpc/parser/option.go @@ -0,0 +1,7 @@ +package parser + +import "github.com/emicklei/proto" + +type Option struct { + *proto.Option +} diff --git a/tools/goctl/rpc/parser/package.go b/tools/goctl/rpc/parser/package.go new file mode 100644 index 00000000..6a7abf88 --- /dev/null +++ b/tools/goctl/rpc/parser/package.go @@ -0,0 +1,7 @@ +package parser + +import "github.com/emicklei/proto" + +type Package struct { + *proto.Package +} diff --git a/tools/goctl/rpc/parser/parser.go b/tools/goctl/rpc/parser/parser.go index 16dec9fd..764f647d 100644 --- a/tools/goctl/rpc/parser/parser.go +++ b/tools/goctl/rpc/parser/parser.go @@ -1,46 +1,170 @@ package parser import ( + "errors" + "fmt" + "go/token" + "os" "path/filepath" "strings" + "unicode" + "unicode/utf8" - "github.com/tal-tech/go-zero/core/lang" - "github.com/tal-tech/go-zero/tools/goctl/util/console" + "github.com/emicklei/proto" ) -func Transfer(proto, target string, externalImport []*Import, console console.Console) (*PbAst, error) { - messageM := make(map[string]lang.PlaceholderType) - enumM := make(map[string]*Enum) - protoAst, err := parseProto(proto, messageM, enumM) - if err != nil { - return nil, err - } - for _, item := range externalImport { - err = checkImport(item.OriginalProtoPath) - if err != nil { - return nil, err - } - innerAst, err := parseProto(item.OriginalProtoPath, protoAst.Message, protoAst.Enum) - if err != nil { - return nil, err - } - for k, v := range innerAst.Message { - protoAst.Message[k] = v - } - for k, v := range innerAst.Enum { - protoAst.Enum[k] = v - } - } - protoAst.Import = externalImport - protoAst.PbSrc = filepath.Join(target, strings.TrimSuffix(filepath.Base(proto), ".proto")+".pb.go") - return transfer(protoAst, console) +type ( + defaultProtoParser struct{} +) + +func NewDefaultProtoParser() *defaultProtoParser { + return &defaultProtoParser{} } -func transfer(proto *Proto, console console.Console) (*PbAst, error) { - parser := MustNewAstParser(proto, console) - parse, err := parser.Parse() +func (p *defaultProtoParser) Parse(src string) (Proto, error) { + var ret Proto + + abs, err := filepath.Abs(src) if err != nil { - return nil, err + return Proto{}, err } - return parse, nil + + r, err := os.Open(abs) + if err != nil { + return ret, err + } + defer r.Close() + + parser := proto.NewParser(r) + set, err := parser.Parse() + if err != nil { + return ret, err + } + + var serviceList []Service + proto.Walk( + set, + proto.WithImport(func(i *proto.Import) { + ret.Import = append(ret.Import, Import{Import: i}) + }), + proto.WithMessage(func(message *proto.Message) { + ret.Message = append(ret.Message, Message{Message: message}) + }), + proto.WithPackage(func(p *proto.Package) { + ret.Package = Package{Package: p} + }), + proto.WithService(func(service *proto.Service) { + serv := Service{Service: service} + elements := service.Elements + for _, el := range elements { + v, _ := el.(*proto.RPC) + if v == nil { + continue + } + serv.RPC = append(serv.RPC, &RPC{RPC: v}) + } + + serviceList = append(serviceList, serv) + }), + proto.WithOption(func(option *proto.Option) { + if option.Name == "go_package" { + ret.GoPackage = option.Constant.Source + } + }), + ) + if len(serviceList) == 0 { + return ret, errors.New("rpc service not found") + } + + if len(serviceList) > 1 { + return ret, errors.New("only one service expected") + } + service := serviceList[0] + name := filepath.Base(abs) + + for _, rpc := range service.RPC { + if strings.Contains(rpc.RequestType, ".") { + return ret, fmt.Errorf("line %v:%v, request type must defined in %s", rpc.Position.Line, rpc.Position.Column, name) + } + if strings.Contains(rpc.ReturnsType, ".") { + return ret, fmt.Errorf("line %v:%v, returns type must defined in %s", rpc.Position.Line, rpc.Position.Column, name) + } + } + if len(ret.GoPackage) == 0 { + ret.GoPackage = ret.Package.Name + } + ret.PbPackage = GoSanitized(filepath.Base(ret.GoPackage)) + ret.Src = abs + ret.Name = name + ret.Service = service + + return ret, nil +} + +// see google.golang.org/protobuf@v1.25.0/internal/strs/strings.go:71 +func GoSanitized(s string) string { + // Sanitize the input to the set of valid characters, + // which must be '_' or be in the Unicode L or N categories. + s = strings.Map(func(r rune) rune { + if unicode.IsLetter(r) || unicode.IsDigit(r) { + return r + } + return '_' + }, s) + + // Prepend '_' in the event of a Go keyword conflict or if + // the identifier is invalid (does not start in the Unicode L category). + r, _ := utf8.DecodeRuneInString(s) + if token.Lookup(s).IsKeyword() || !unicode.IsLetter(r) { + return "_" + s + } + return s +} + +// copy from github.com/golang/protobuf@v1.4.2/protoc-gen-go/generator/generator.go:2648 +func CamelCase(s string) string { + if s == "" { + return "" + } + t := make([]byte, 0, 32) + i := 0 + if s[0] == '_' { + // Need a capital letter; drop the '_'. + t = append(t, 'X') + i++ + } + // Invariant: if the next letter is lower case, it must be converted + // to upper case. + // That is, we process a word at a time, where words are marked by _ or + // upper case letter. Digits are treated as words. + for ; i < len(s); i++ { + c := s[i] + if c == '_' && i+1 < len(s) && isASCIILower(s[i+1]) { + continue // Skip the underscore in s. + } + if isASCIIDigit(c) { + t = append(t, c) + continue + } + // Assume we have a letter now - if not, it's a bogus identifier. + // The next word is a sequence of characters that must start upper case. + if isASCIILower(c) { + c ^= ' ' // Make it a capital letter. + } + t = append(t, c) // Guaranteed not lower case. + // Accept lower case sequence that follows. + for i+1 < len(s) && isASCIILower(s[i+1]) { + i++ + t = append(t, s[i]) + } + } + return string(t) +} +func isASCIILower(c byte) bool { + return 'a' <= c && c <= 'z' +} + +// Is c an ASCII digit? +func isASCIIDigit(c byte) bool { + return '0' <= c && c <= '9' } diff --git a/tools/goctl/rpc/parser/parser_test.go b/tools/goctl/rpc/parser/parser_test.go new file mode 100644 index 00000000..2674a8d4 --- /dev/null +++ b/tools/goctl/rpc/parser/parser_test.go @@ -0,0 +1,78 @@ +package parser + +import ( + "sort" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultProtoParse(t *testing.T) { + p := NewDefaultProtoParser() + data, err := p.Parse("./test.proto") + assert.Nil(t, err) + assert.Equal(t, "base.proto", func() string { + ip := data.Import[0] + return ip.Filename + }()) + assert.Equal(t, "test", data.Package.Name) + assert.Equal(t, true, data.GoPackage == "go") + assert.Equal(t, true, data.PbPackage == "_go") + assert.Equal(t, []string{"TestMessage", "TestReply", "TestReq"}, func() []string { + var list []string + for _, item := range data.Message { + list = append(list, item.Name) + } + sort.Strings(list) + return list + }()) + + assert.Equal(t, true, func() bool { + s := data.Service + if s.Name != "TestService" { + return false + } + rpcOne := s.RPC[0] + + return rpcOne.Name == "TestRpcOne" && rpcOne.RequestType == "TestReq" && rpcOne.ReturnsType == "TestReply" + }()) +} + +func TestDefaultProtoParseCaseInvalidRequestType(t *testing.T) { + p := NewDefaultProtoParser() + _, err := p.Parse("./test_invalid_request.proto") + assert.True(t, true, func() bool { + return strings.Contains(err.Error(), "request type must defined in") + }()) +} + +func TestDefaultProtoParseCaseInvalidResponseType(t *testing.T) { + p := NewDefaultProtoParser() + _, err := p.Parse("./test_invalid_response.proto") + assert.True(t, true, func() bool { + return strings.Contains(err.Error(), "response type must defined in") + }()) +} + +func TestDefaultProtoParseError(t *testing.T) { + p := NewDefaultProtoParser() + _, err := p.Parse("./nil.proto") + assert.NotNil(t, err) +} + +func TestDefaultProtoParse_Option(t *testing.T) { + p := NewDefaultProtoParser() + data, err := p.Parse("./test_option.proto") + assert.Nil(t, err) + assert.Equal(t, "github.com/tal-tech/go-zero", data.GoPackage) + assert.Equal(t, "go_zero", data.PbPackage) +} + +func TestDefaultProtoParse_Option2(t *testing.T) { + p := NewDefaultProtoParser() + data, err := p.Parse("./test_option2.proto") + assert.Nil(t, err) + assert.Equal(t, "stream", data.GoPackage) + assert.Equal(t, "stream", data.PbPackage) +} diff --git a/tools/goctl/rpc/parser/pbast.go b/tools/goctl/rpc/parser/pbast.go deleted file mode 100644 index 7bed50a6..00000000 --- a/tools/goctl/rpc/parser/pbast.go +++ /dev/null @@ -1,643 +0,0 @@ -package parser - -import ( - "errors" - "fmt" - "go/ast" - "go/parser" - "go/token" - "io/ioutil" - "sort" - "strings" - - "github.com/tal-tech/go-zero/core/lang" - sx "github.com/tal-tech/go-zero/core/stringx" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/console" - "github.com/tal-tech/go-zero/tools/goctl/util/stringx" -) - -const ( - flagStar = "*" - flagDot = "." - suffixServer = "Server" - referenceContext = "context" - unknownPrefix = "XXX_" - ignoreJsonTagExpression = `json:"-"` -) - -var ( - errorParseError = errors.New("pb parse error") - typeTemplate = `type ( - {{.types}} -)` - structTemplate = `{{if .type}}type {{end}}{{.name}} struct { - {{.fields}} -}` - fieldTemplate = `{{if .hasDoc}}{{.doc}} -{{end}}{{.name}} {{.type}} {{.tag}}{{if .hasComment}}{{.comment}}{{end}}` - - anyTypeTemplate = "Any struct {\n\tTypeUrl string `json:\"typeUrl\"`\n\tValue []byte `json:\"value\"`\n}" - - objectM = make(map[string]*Struct) -) - -type ( - astParser struct { - filterStruct map[string]lang.PlaceholderType - filterEnum map[string]*Enum - console.Console - fileSet *token.FileSet - proto *Proto - } - Field struct { - Name stringx.String - Type Type - JsonTag string - Document []string - Comment []string - } - Struct struct { - Name stringx.String - Document []string - Comment []string - Field []*Field - } - ConstLit struct { - Name stringx.String - Document []string - Comment []string - Lit []*Lit - } - Lit struct { - Key string - Value int - } - Type struct { - // eg:context.Context - Expression string - // eg: *context.Context - StarExpression string - // Invoke Type Expression - InvokeTypeExpression string - // eg:context - Package string - // eg:Context - Name string - } - Func struct { - Name stringx.String - ParameterIn Type - ParameterOut Type - Document []string - } - RpcService struct { - Name stringx.String - Funcs []*Func - } - // parsing for rpc - PbAst struct { - // deprecated: containsAny will be removed in the feature - ContainsAny bool - Imports map[string]string - Structure map[string]*Struct - Service []*RpcService - *Proto - } -) - -func MustNewAstParser(proto *Proto, log console.Console) *astParser { - return &astParser{ - filterStruct: proto.Message, - filterEnum: proto.Enum, - Console: log, - fileSet: token.NewFileSet(), - proto: proto, - } -} -func (a *astParser) Parse() (*PbAst, error) { - var pbAst PbAst - pbAst.ContainsAny = a.proto.ContainsAny - pbAst.Proto = a.proto - pbAst.Structure = make(map[string]*Struct) - pbAst.Imports = make(map[string]string) - structure, imports, services, err := a.parse(a.proto.PbSrc) - if err != nil { - return nil, err - } - dependencyStructure, err := a.parseExternalDependency() - if err != nil { - return nil, err - } - for k, v := range structure { - pbAst.Structure[k] = v - } - for k, v := range dependencyStructure { - pbAst.Structure[k] = v - } - for key, path := range imports { - pbAst.Imports[key] = path - } - pbAst.Service = append(pbAst.Service, services...) - return &pbAst, nil -} - -func (a *astParser) parse(pbSrc string) (structure map[string]*Struct, imports map[string]string, services []*RpcService, retErr error) { - structure = make(map[string]*Struct) - imports = make(map[string]string) - data, err := ioutil.ReadFile(pbSrc) - if err != nil { - retErr = err - return - } - fSet := a.fileSet - f, err := parser.ParseFile(fSet, "", data, parser.ParseComments) - if err != nil { - retErr = err - return - } - commentMap := ast.NewCommentMap(fSet, f, f.Comments) - f.Comments = commentMap.Filter(f).Comments() - strucs, function := a.mustScope(f.Scope, a.mustGetIndentName(f.Name)) - for k, v := range strucs { - if v == nil { - continue - } - structure[k] = v - } - importList := f.Imports - for _, item := range importList { - name := a.mustGetIndentName(item.Name) - if item.Path != nil { - imports[name] = item.Path.Value - } - } - services = append(services, function...) - return -} -func (a *astParser) parseExternalDependency() (map[string]*Struct, error) { - m := make(map[string]*Struct) - for _, impo := range a.proto.Import { - ret, _, _, err := a.parse(impo.OriginalPbPath) - if err != nil { - return nil, err - } - for k, v := range ret { - m[k] = v - } - } - return m, nil -} - -func (a *astParser) mustScope(scope *ast.Scope, sourcePackage string) (map[string]*Struct, []*RpcService) { - if scope == nil { - return nil, nil - } - - objects := scope.Objects - structs := make(map[string]*Struct) - serviceList := make([]*RpcService, 0) - for name, obj := range objects { - decl := obj.Decl - if decl == nil { - continue - } - typeSpec, ok := decl.(*ast.TypeSpec) - if !ok { - continue - } - tp := typeSpec.Type - - switch v := tp.(type) { - - case *ast.StructType: - st, err := a.parseObject(name, v, sourcePackage) - a.Must(err) - structs[st.Name.Lower()] = st - - case *ast.InterfaceType: - if !strings.HasSuffix(name, suffixServer) { - continue - } - list := a.mustServerFunctions(v, sourcePackage) - serviceList = append(serviceList, &RpcService{ - Name: stringx.From(strings.TrimSuffix(name, suffixServer)), - Funcs: list, - }) - } - } - targetStruct := make(map[string]*Struct) - for st := range a.filterStruct { - lower := strings.ToLower(st) - targetStruct[lower] = structs[lower] - } - return targetStruct, serviceList -} - -func (a *astParser) mustServerFunctions(v *ast.InterfaceType, sourcePackage string) []*Func { - funcs := make([]*Func, 0) - methodObject := v.Methods - if methodObject == nil { - return nil - } - - for _, method := range methodObject.List { - var item Func - name := a.mustGetIndentName(method.Names[0]) - doc := a.parseCommentOrDoc(method.Doc) - item.Name = stringx.From(name) - item.Document = doc - types := method.Type - if types == nil { - funcs = append(funcs, &item) - continue - } - v, ok := types.(*ast.FuncType) - if !ok { - continue - } - params := v.Params - if params != nil { - inList, err := a.parseFields(params.List, true, sourcePackage) - a.Must(err) - - for _, data := range inList { - if data.Type.Package == referenceContext { - continue - } - item.ParameterIn = data.Type - break - } - } - results := v.Results - if results != nil { - outList, err := a.parseFields(results.List, true, sourcePackage) - a.Must(err) - - for _, data := range outList { - if data.Type.Package == referenceContext { - continue - } - item.ParameterOut = data.Type - break - } - } - funcs = append(funcs, &item) - } - return funcs -} - -func (a *astParser) getFieldType(v string, sourcePackage string) Type { - var pkg, name, expression, starExpression, invokeTypeExpression string - - if strings.Contains(v, ".") { - starExpression = v - if strings.Contains(v, "*") { - leftIndex := strings.Index(v, "*") - rightIndex := strings.Index(v, ".") - if leftIndex >= 0 { - invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:] - } else { - invokeTypeExpression = v[rightIndex+1:] - } - } else { - if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") { - leftIndex := strings.Index(v, "]") - rightIndex := strings.Index(v, ".") - invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:] - } else { - rightIndex := strings.Index(v, ".") - invokeTypeExpression = v[rightIndex+1:] - } - } - } else { - expression = strings.TrimPrefix(v, flagStar) - switch v { - case "double", "float", "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32", "sfixed64", - "bool", "string", "bytes": - invokeTypeExpression = v - break - default: - name = expression - invokeTypeExpression = v - if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") { - starExpression = strings.ReplaceAll(v, flagStar, flagStar+sourcePackage+".") - } else { - starExpression = fmt.Sprintf("*%v.%v", sourcePackage, name) - invokeTypeExpression = v - } - - } - } - expression = strings.TrimPrefix(starExpression, flagStar) - index := strings.LastIndex(expression, flagDot) - if index > 0 { - pkg = expression[0:index] - name = expression[index+1:] - } else { - pkg = sourcePackage - } - - return Type{ - Expression: expression, - StarExpression: starExpression, - InvokeTypeExpression: invokeTypeExpression, - Package: pkg, - Name: name, - } -} - -func (a *astParser) parseObject(structName string, tp *ast.StructType, sourcePackage string) (*Struct, error) { - if data, ok := objectM[structName]; ok { - return data, nil - } - var st Struct - st.Name = stringx.From(structName) - if tp == nil { - return &st, nil - } - - fields := tp.Fields - if fields == nil { - objectM[structName] = &st - return &st, nil - } - - fieldList := fields.List - members, err := a.parseFields(fieldList, false, sourcePackage) - if err != nil { - return nil, err - } - - for _, m := range members { - var field Field - field.Name = m.Name - field.Type = m.Type - field.JsonTag = m.JsonTag - field.Document = m.Document - field.Comment = m.Comment - st.Field = append(st.Field, &field) - } - objectM[structName] = &st - return &st, nil -} - -func (a *astParser) parseFields(fields []*ast.Field, onlyType bool, sourcePackage string) ([]*Field, error) { - ret := make([]*Field, 0) - for _, field := range fields { - var item Field - tag := a.parseTag(field.Tag) - if tag == "" && !onlyType { - continue - } - if tag == ignoreJsonTagExpression { - continue - } - - item.JsonTag = tag - name := a.parseName(field.Names) - if strings.HasPrefix(name, unknownPrefix) { - continue - } - item.Name = stringx.From(name) - typeName, err := a.parseType(field.Type) - if err != nil { - return nil, err - } - - item.Type = a.getFieldType(typeName, sourcePackage) - if onlyType { - ret = append(ret, &item) - continue - } - docs := a.parseCommentOrDoc(field.Doc) - comments := a.parseCommentOrDoc(field.Comment) - - item.Document = docs - item.Comment = comments - - isInline := name == "" - if isInline { - return nil, a.wrapError(field.Pos(), "unexpected inline type:%s", name) - } - - ret = append(ret, &item) - - } - return ret, nil -} - -func (a *astParser) parseTag(basicLit *ast.BasicLit) string { - if basicLit == nil { - return "" - } - value := basicLit.Value - splits := strings.Split(value, " ") - if len(splits) == 1 { - return fmt.Sprintf("`%s`", strings.ReplaceAll(splits[0], "`", "")) - } else { - return fmt.Sprintf("`%s`", strings.ReplaceAll(splits[1], "`", "")) - } -} - -// returns -// resp1:type's string expression,like int、string、[]int64、map[string]User、*User -// resp2:error -func (a *astParser) parseType(expr ast.Expr) (string, error) { - if expr == nil { - return "", errorParseError - } - - switch v := expr.(type) { - case *ast.StarExpr: - stringExpr, err := a.parseType(v.X) - if err != nil { - return "", err - } - - e := fmt.Sprintf("*%s", stringExpr) - return e, nil - - case *ast.Ident: - return a.mustGetIndentName(v), nil - case *ast.MapType: - keyStringExpr, err := a.parseType(v.Key) - if err != nil { - return "", err - } - - valueStringExpr, err := a.parseType(v.Value) - if err != nil { - return "", err - } - - e := fmt.Sprintf("map[%s]%s", keyStringExpr, valueStringExpr) - return e, nil - case *ast.ArrayType: - stringExpr, err := a.parseType(v.Elt) - if err != nil { - return "", err - } - - e := fmt.Sprintf("[]%s", stringExpr) - return e, nil - case *ast.InterfaceType: - return "interface{}", nil - case *ast.SelectorExpr: - join := make([]string, 0) - xIdent, ok := v.X.(*ast.Ident) - xIndentName := a.mustGetIndentName(xIdent) - if ok { - join = append(join, xIndentName) - } - sel := v.Sel - join = append(join, a.mustGetIndentName(sel)) - return strings.Join(join, "."), nil - case *ast.ChanType: - return "", a.wrapError(v.Pos(), "unexpected type 'chan'") - case *ast.FuncType: - return "", a.wrapError(v.Pos(), "unexpected type 'func'") - case *ast.StructType: - return "", a.wrapError(v.Pos(), "unexpected inline struct type") - default: - return "", a.wrapError(v.Pos(), "unexpected type '%v'", v) - } -} -func (a *astParser) parseName(names []*ast.Ident) string { - if len(names) == 0 { - return "" - } - name := names[0] - return a.mustGetIndentName(name) -} - -func (a *astParser) parseCommentOrDoc(cg *ast.CommentGroup) []string { - if cg == nil { - return nil - } - comments := make([]string, 0) - for _, comment := range cg.List { - if comment == nil { - continue - } - text := strings.TrimSpace(comment.Text) - if text == "" { - continue - } - comments = append(comments, text) - } - return comments -} - -func (a *astParser) mustGetIndentName(ident *ast.Ident) string { - if ident == nil { - return "" - } - return ident.Name -} - -func (a *astParser) wrapError(pos token.Pos, format string, arg ...interface{}) error { - file := a.fileSet.Position(pos) - return fmt.Errorf("line %v: %s", file.Line, fmt.Sprintf(format, arg...)) -} - -func (f *Func) GetDoc() string { - return strings.Join(f.Document, util.NL) -} - -func (f *Func) HaveDoc() bool { - return len(f.Document) > 0 -} - -func (a *PbAst) GenEnumCode() (string, error) { - var element []string - for _, item := range a.Enum { - code, err := item.GenEnumCode() - if err != nil { - return "", err - } - element = append(element, code) - } - return strings.Join(element, util.NL), nil -} - -func (a *PbAst) GenTypesCode() (string, error) { - types := make([]string, 0) - sts := make([]*Struct, 0) - for _, item := range a.Structure { - sts = append(sts, item) - } - sort.Slice(sts, func(i, j int) bool { - return sts[i].Name.Source() < sts[j].Name.Source() - }) - for _, s := range sts { - structCode, err := s.genCode(false) - if err != nil { - return "", err - } - - if structCode == "" { - continue - } - types = append(types, structCode) - } - types = append(types, a.genAnyCode()) - for _, item := range a.Enum { - typeCode, err := item.GenEnumTypeCode() - if err != nil { - return "", err - } - types = append(types, typeCode) - } - - buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{ - "types": strings.Join(types, util.NL+util.NL), - }) - if err != nil { - return "", err - } - - return buffer.String(), nil -} - -func (a *PbAst) genAnyCode() string { - if !a.ContainsAny { - return "" - } - return anyTypeTemplate -} - -func (s *Struct) genCode(containsTypeStatement bool) (string, error) { - fields := make([]string, 0) - for _, f := range s.Field { - var comment, doc string - if len(f.Comment) > 0 { - comment = f.Comment[0] - } - doc = strings.Join(f.Document, util.NL) - buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{ - "name": f.Name.Title(), - "type": f.Type.InvokeTypeExpression, - "tag": f.JsonTag, - "hasDoc": len(f.Document) > 0, - "doc": doc, - "hasComment": len(f.Comment) > 0, - "comment": comment, - }) - if err != nil { - return "", err - } - - fields = append(fields, buffer.String()) - } - buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{ - "type": containsTypeStatement, - "name": s.Name.Title(), - "fields": strings.Join(fields, util.NL), - }) - if err != nil { - return "", err - } - - return buffer.String(), nil -} diff --git a/tools/goctl/rpc/parser/proto.go b/tools/goctl/rpc/parser/proto.go index cf5dbea7..4e23b8e5 100644 --- a/tools/goctl/rpc/parser/proto.go +++ b/tools/goctl/rpc/parser/proto.go @@ -1,295 +1,12 @@ package parser -import ( - "errors" - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/emicklei/proto" - "github.com/tal-tech/go-zero/core/collection" - "github.com/tal-tech/go-zero/core/lang" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/stringx" -) - -const ( - AnyImport = "google/protobuf/any.proto" -) - -var ( - enumTypeTemplate = `{{.name}} int32` - enumTemplate = `const ( - {{.element}} -)` - enumFiledTemplate = `{{.key}} {{.name}} = {{.value}}` -) - -type ( - MessageField struct { - Type string - Name stringx.String - } - Message struct { - Name stringx.String - Element []*MessageField - *proto.Message - } - Enum struct { - Name stringx.String - Element []*EnumField - *proto.Enum - } - EnumField struct { - Key string - Value int - } - - Proto struct { - Package string - Import []*Import - PbSrc string - // deprecated: containsAny will be removed in the feature - ContainsAny bool - Message map[string]lang.PlaceholderType - Enum map[string]*Enum - } - Import struct { - ProtoImportName string - PbImportName string - OriginalDir string - OriginalProtoPath string - OriginalPbPath string - BridgeImport string - exists bool - //xx.proto - protoName string - // xx.pb.go - pbName string - } -) - -func checkImport(src string) error { - r, err := os.Open(src) - if err != nil { - return err - } - defer r.Close() - - parser := proto.NewParser(r) - parseRet, err := parser.Parse() - if err != nil { - return err - } - var base = filepath.Base(src) - proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) { - if err != nil { - return - } - err = fmt.Errorf("%v:%v the external proto cannot import other proto files", base, i.Position.Line) - })) - if err != nil { - return err - } - return nil -} -func ParseImport(src string) ([]*Import, bool, error) { - bridgeImportM := make(map[string]string) - r, err := os.Open(src) - if err != nil { - return nil, false, err - } - defer r.Close() - - workDir := filepath.Dir(src) - parser := proto.NewParser(r) - parseRet, err := parser.Parse() - if err != nil { - return nil, false, err - } - protoImportSet := collection.NewSet() - var containsAny bool - proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) { - if i.Filename == AnyImport { - containsAny = true - return - } - protoImportSet.AddStr(i.Filename) - if i.Comment != nil { - lines := i.Comment.Lines - for _, line := range lines { - line = strings.TrimSpace(line) - if !strings.HasPrefix(line, "@") { - continue - } - line = strings.TrimPrefix(line, "@") - bridgeImportM[i.Filename] = line - } - } - })) - var importList []*Import - - for _, item := range protoImportSet.KeysStr() { - pb := strings.TrimSuffix(filepath.Base(item), filepath.Ext(item)) + ".pb.go" - var pbImportName, brideImport string - if v, ok := bridgeImportM[item]; ok { - pbImportName = v - brideImport = "M" + item + "=" + v - } else { - pbImportName = item - } - var impo = Import{ - ProtoImportName: item, - PbImportName: pbImportName, - BridgeImport: brideImport, - } - protoSource := filepath.Join(workDir, item) - pbSource := filepath.Join(filepath.Dir(protoSource), pb) - if util.FileExists(protoSource) && util.FileExists(pbSource) { - impo.OriginalProtoPath = protoSource - impo.OriginalPbPath = pbSource - impo.OriginalDir = filepath.Dir(protoSource) - impo.exists = true - impo.protoName = filepath.Base(item) - impo.pbName = pb - } else { - return nil, false, fmt.Errorf("「%v」: import must be found in the relative directory of 「%v」", item, filepath.Base(src)) - } - importList = append(importList, &impo) - } - - return importList, containsAny, nil -} - -func parseProto(src string, messageM map[string]lang.PlaceholderType, enumM map[string]*Enum) (*Proto, error) { - if !filepath.IsAbs(src) { - return nil, fmt.Errorf("expected absolute path,but found: %v", src) - } - - r, err := os.Open(src) - if err != nil { - return nil, err - } - defer r.Close() - - parser := proto.NewParser(r) - parseRet, err := parser.Parse() - if err != nil { - return nil, err - } - - // xx.proto - fileBase := filepath.Base(src) - var resp Proto - - proto.Walk(parseRet, proto.WithPackage(func(p *proto.Package) { - if err != nil { - return - } - - if len(resp.Package) != 0 { - err = fmt.Errorf("%v:%v duplicate package「%v」", fileBase, p.Position.Line, p.Name) - } - - if len(p.Name) == 0 { - err = errors.New("package not found") - } - - resp.Package = p.Name - }), proto.WithMessage(func(message *proto.Message) { - if err != nil { - return - } - - for _, item := range message.Elements { - switch item.(type) { - case *proto.NormalField, *proto.MapField, *proto.Comment: - continue - default: - err = fmt.Errorf("%v: unsupport inline declaration", fileBase) - return - } - } - name := stringx.From(message.Name) - if _, ok := messageM[name.Lower()]; ok { - err = fmt.Errorf("%v:%v duplicate message 「%v」", fileBase, message.Position.Line, message.Name) - return - } - - messageM[name.Lower()] = lang.Placeholder - }), proto.WithEnum(func(enum *proto.Enum) { - if err != nil { - return - } - - var node Enum - node.Enum = enum - node.Name = stringx.From(enum.Name) - for _, item := range enum.Elements { - v, ok := item.(*proto.EnumField) - if !ok { - continue - } - node.Element = append(node.Element, &EnumField{ - Key: v.Name, - Value: v.Integer, - }) - } - if _, ok := enumM[node.Name.Lower()]; ok { - err = fmt.Errorf("%v:%v duplicate enum 「%v」", fileBase, node.Position.Line, node.Name.Source()) - return - } - - lower := stringx.From(enum.Name).Lower() - enumM[lower] = &node - })) - - if err != nil { - return nil, err - } - resp.Message = messageM - resp.Enum = enumM - - return &resp, nil -} - -func (e *Enum) GenEnumCode() (string, error) { - var element []string - for _, item := range e.Element { - code, err := item.GenEnumFieldCode(e.Name.Source()) - if err != nil { - return "", err - } - element = append(element, code) - } - buffer, err := util.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{ - "element": strings.Join(element, util.NL), - }) - if err != nil { - return "", err - } - return buffer.String(), nil -} - -func (e *Enum) GenEnumTypeCode() (string, error) { - buffer, err := util.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{ - "name": e.Name.Source(), - }) - if err != nil { - return "", err - } - return buffer.String(), nil -} - -func (e *EnumField) GenEnumFieldCode(parentName string) (string, error) { - buffer, err := util.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{ - "key": e.Key, - "name": parentName, - "value": e.Value, - }) - if err != nil { - return "", err - } - return buffer.String(), nil +type Proto struct { + Src string + Name string + Package Package + PbPackage string + GoPackage string + Import []Import + Message []Message + Service Service } diff --git a/tools/goctl/rpc/parser/rpc.go b/tools/goctl/rpc/parser/rpc.go new file mode 100644 index 00000000..bf4cd8ed --- /dev/null +++ b/tools/goctl/rpc/parser/rpc.go @@ -0,0 +1,7 @@ +package parser + +import "github.com/emicklei/proto" + +type RPC struct { + *proto.RPC +} diff --git a/tools/goctl/rpc/parser/service.go b/tools/goctl/rpc/parser/service.go new file mode 100644 index 00000000..9ace8dfb --- /dev/null +++ b/tools/goctl/rpc/parser/service.go @@ -0,0 +1,8 @@ +package parser + +import "github.com/emicklei/proto" + +type Service struct { + *proto.Service + RPC []*RPC +} diff --git a/tools/goctl/rpc/parser/test.proto b/tools/goctl/rpc/parser/test.proto new file mode 100644 index 00000000..45643864 --- /dev/null +++ b/tools/goctl/rpc/parser/test.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package test; +option go_package = "go"; + +import "base.proto"; + +message TestMessage{} +message TestReq{} +message TestReply{} + +enum TestEnum { + unknown = 0; + male = 1; + female = 2; +} + +service TestService{ + rpc TestRpcOne (TestReq)returns(TestReply); +} \ No newline at end of file diff --git a/tools/goctl/rpc/parser/test_invalid_request.proto b/tools/goctl/rpc/parser/test_invalid_request.proto new file mode 100644 index 00000000..cdb1e438 --- /dev/null +++ b/tools/goctl/rpc/parser/test_invalid_request.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package test; +option go_package = "go"; + +import "base.proto"; + +message Reply{} + + +service TestService{ + rpc TestRpcTwo (base.Req)returns(Reply); +} \ No newline at end of file diff --git a/tools/goctl/rpc/parser/test_invalid_response.proto b/tools/goctl/rpc/parser/test_invalid_response.proto new file mode 100644 index 00000000..2adf5d1d --- /dev/null +++ b/tools/goctl/rpc/parser/test_invalid_response.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package test; +option go_package = "go"; + +import "base.proto"; + +message Req{} + + +service TestService{ + rpc TestRpcTwo (Req)returns(base.Reply); +} \ No newline at end of file diff --git a/tools/goctl/rpc/parser/test_option.proto b/tools/goctl/rpc/parser/test_option.proto new file mode 100644 index 00000000..9953136a --- /dev/null +++ b/tools/goctl/rpc/parser/test_option.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package stream; + +option go_package="github.com/tal-tech/go-zero"; + +message placeholder{} +service greet{ + rpc hello(placeholder)returns(placeholder); +} \ No newline at end of file diff --git a/tools/goctl/rpc/parser/test_option2.proto b/tools/goctl/rpc/parser/test_option2.proto new file mode 100644 index 00000000..aa680cdc --- /dev/null +++ b/tools/goctl/rpc/parser/test_option2.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package stream; + + +message placeholder{} +service greet{ + rpc hello(placeholder)returns(placeholder); +} \ No newline at end of file diff --git a/tools/goctl/rpc/test.proto b/tools/goctl/rpc/test.proto deleted file mode 100644 index d98ec884..00000000 --- a/tools/goctl/rpc/test.proto +++ /dev/null @@ -1,28 +0,0 @@ -syntax = "proto3"; -// protoc -I=${GOPATH}/src -I=. test.proto --go_out=plugins=grpc,Mbase.proto=github.com/tal-tech/go-zero/tools/goctl/rpc:./test -package test; -// @github.com/tal-tech/go-zero/tools/goctl/rpc -import "base.proto"; -import "google/protobuf/any.proto"; - -message request { - string name = 1; -} -enum Gender{ - UNKNOWN = 0; - MALE = 1; - FEMALE = 2; -} -message response { - string greet = 1; - google.protobuf.Any data = 2; - -} -message map { - map m = 1; -} - -service Greeter { - rpc greet(request) returns (response); - rpc idRequest(base.IdRequest)returns(base.EmptyResponse); -} \ No newline at end of file diff --git a/tools/goctl/rpc/test/test.pb.go b/tools/goctl/rpc/test/test.pb.go deleted file mode 100644 index 298ec08d..00000000 --- a/tools/goctl/rpc/test/test.pb.go +++ /dev/null @@ -1,331 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// source: test.proto - -// protoc -I=${GOPATH}/src -I=. test.proto --go_out=plugins=grpc,Mbase.proto=github.com/tal-tech/go-zero/tools/goctl/rpc:./test - -package test - -import ( - context "context" - fmt "fmt" - proto "github.com/golang/protobuf/proto" - rpc "github.com/tal-tech/go-zero/tools/goctl/rpc" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" - anypb "google.golang.org/protobuf/types/known/anypb" - math "math" -) - -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the proto package it is being compiled against. -// A compilation error at this line likely means your copy of the -// proto package needs to be updated. -const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package - -type Gender int32 - -const ( - Gender_UNKNOWN Gender = 0 - Gender_MALE Gender = 1 - Gender_FEMALE Gender = 2 -) - -var Gender_name = map[int32]string{ - 0: "UNKNOWN", - 1: "MALE", - 2: "FEMALE", -} - -var Gender_value = map[string]int32{ - "UNKNOWN": 0, - "MALE": 1, - "FEMALE": 2, -} - -func (x Gender) String() string { - return proto.EnumName(Gender_name, int32(x)) -} - -func (Gender) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_c161fcfdc0c3ff1e, []int{0} -} - -type Request struct { - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *Request) Reset() { *m = Request{} } -func (m *Request) String() string { return proto.CompactTextString(m) } -func (*Request) ProtoMessage() {} -func (*Request) Descriptor() ([]byte, []int) { - return fileDescriptor_c161fcfdc0c3ff1e, []int{0} -} - -func (m *Request) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_Request.Unmarshal(m, b) -} -func (m *Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_Request.Marshal(b, m, deterministic) -} -func (m *Request) XXX_Merge(src proto.Message) { - xxx_messageInfo_Request.Merge(m, src) -} -func (m *Request) XXX_Size() int { - return xxx_messageInfo_Request.Size(m) -} -func (m *Request) XXX_DiscardUnknown() { - xxx_messageInfo_Request.DiscardUnknown(m) -} - -var xxx_messageInfo_Request proto.InternalMessageInfo - -func (m *Request) GetName() string { - if m != nil { - return m.Name - } - return "" -} - -type Response struct { - Greet string `protobuf:"bytes,1,opt,name=greet,proto3" json:"greet,omitempty"` - Data *anypb.Any `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *Response) Reset() { *m = Response{} } -func (m *Response) String() string { return proto.CompactTextString(m) } -func (*Response) ProtoMessage() {} -func (*Response) Descriptor() ([]byte, []int) { - return fileDescriptor_c161fcfdc0c3ff1e, []int{1} -} - -func (m *Response) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_Response.Unmarshal(m, b) -} -func (m *Response) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_Response.Marshal(b, m, deterministic) -} -func (m *Response) XXX_Merge(src proto.Message) { - xxx_messageInfo_Response.Merge(m, src) -} -func (m *Response) XXX_Size() int { - return xxx_messageInfo_Response.Size(m) -} -func (m *Response) XXX_DiscardUnknown() { - xxx_messageInfo_Response.DiscardUnknown(m) -} - -var xxx_messageInfo_Response proto.InternalMessageInfo - -func (m *Response) GetGreet() string { - if m != nil { - return m.Greet - } - return "" -} - -func (m *Response) GetData() *anypb.Any { - if m != nil { - return m.Data - } - return nil -} - -type Map struct { - M map[string]string `protobuf:"bytes,1,rep,name=m,proto3" json:"m,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *Map) Reset() { *m = Map{} } -func (m *Map) String() string { return proto.CompactTextString(m) } -func (*Map) ProtoMessage() {} -func (*Map) Descriptor() ([]byte, []int) { - return fileDescriptor_c161fcfdc0c3ff1e, []int{2} -} - -func (m *Map) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_Map.Unmarshal(m, b) -} -func (m *Map) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_Map.Marshal(b, m, deterministic) -} -func (m *Map) XXX_Merge(src proto.Message) { - xxx_messageInfo_Map.Merge(m, src) -} -func (m *Map) XXX_Size() int { - return xxx_messageInfo_Map.Size(m) -} -func (m *Map) XXX_DiscardUnknown() { - xxx_messageInfo_Map.DiscardUnknown(m) -} - -var xxx_messageInfo_Map proto.InternalMessageInfo - -func (m *Map) GetM() map[string]string { - if m != nil { - return m.M - } - return nil -} - -func init() { - proto.RegisterEnum("test.Gender", Gender_name, Gender_value) - proto.RegisterType((*Request)(nil), "test.request") - proto.RegisterType((*Response)(nil), "test.response") - proto.RegisterType((*Map)(nil), "test.map") - proto.RegisterMapType((map[string]string)(nil), "test.map.MEntry") -} - -func init() { proto.RegisterFile("test.proto", fileDescriptor_c161fcfdc0c3ff1e) } - -var fileDescriptor_c161fcfdc0c3ff1e = []byte{ - // 301 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x34, 0x90, 0x4f, 0x4b, 0xc3, 0x40, - 0x10, 0xc5, 0xdd, 0x36, 0xa6, 0xed, 0x14, 0x35, 0x8c, 0x3d, 0xd4, 0x80, 0x52, 0x7a, 0x90, 0xa0, - 0xb0, 0xc5, 0xea, 0x41, 0xbc, 0xf5, 0x10, 0x8b, 0x7f, 0x5a, 0x21, 0x20, 0x1e, 0x3c, 0x6d, 0xc9, - 0x58, 0xc4, 0xee, 0x26, 0x6e, 0xb6, 0x42, 0xbe, 0xbd, 0x64, 0x77, 0x73, 0x7b, 0xbf, 0xd9, 0x59, - 0xde, 0x9b, 0x07, 0x60, 0xa8, 0x32, 0xbc, 0xd4, 0x85, 0x29, 0x30, 0x68, 0x74, 0x0c, 0x1b, 0x51, - 0x91, 0x9b, 0xc4, 0x67, 0xdb, 0xa2, 0xd8, 0xee, 0x68, 0x66, 0x69, 0xb3, 0xff, 0x9a, 0x09, 0x55, - 0xbb, 0xa7, 0xe9, 0x39, 0xf4, 0x34, 0xfd, 0xee, 0xa9, 0x32, 0x88, 0x10, 0x28, 0x21, 0x69, 0xcc, - 0x26, 0x2c, 0x19, 0x64, 0x56, 0x4f, 0x9f, 0xa1, 0xaf, 0xa9, 0x2a, 0x0b, 0x55, 0x11, 0x8e, 0xe0, - 0x70, 0xab, 0x89, 0x8c, 0x5f, 0x70, 0x80, 0x09, 0x04, 0xb9, 0x30, 0x62, 0xdc, 0x99, 0xb0, 0x64, - 0x38, 0x1f, 0x71, 0x67, 0xc5, 0x5b, 0x2b, 0xbe, 0x50, 0x75, 0x66, 0x37, 0xa6, 0x9f, 0xd0, 0x95, - 0xa2, 0xc4, 0x0b, 0x60, 0x72, 0xcc, 0x26, 0xdd, 0x64, 0x38, 0x8f, 0xb8, 0x8d, 0x2d, 0x45, 0xc9, - 0x57, 0xa9, 0x32, 0xba, 0xce, 0x98, 0x8c, 0xef, 0x20, 0x74, 0x80, 0x11, 0x74, 0x7f, 0xa8, 0xf6, - 0x76, 0x8d, 0x6c, 0x22, 0xfc, 0x89, 0xdd, 0x9e, 0xac, 0xdb, 0x20, 0x73, 0xf0, 0xd0, 0xb9, 0x67, - 0x57, 0xd7, 0x10, 0x2e, 0x49, 0xe5, 0xa4, 0x71, 0x08, 0xbd, 0xf7, 0xf5, 0xcb, 0xfa, 0xed, 0x63, - 0x1d, 0x1d, 0x60, 0x1f, 0x82, 0xd5, 0xe2, 0x35, 0x8d, 0x18, 0x02, 0x84, 0x8f, 0xa9, 0xd5, 0x9d, - 0x79, 0x0e, 0xbd, 0x65, 0x13, 0x9e, 0x34, 0x5e, 0xfa, 0xa3, 0xf0, 0xc8, 0x65, 0xf1, 0x65, 0xc4, - 0xc7, 0x2d, 0xfa, 0xe3, 0x6f, 0x60, 0xf0, 0x9d, 0x67, 0xbe, 0xa9, 0x13, 0x6e, 0xcb, 0x7d, 0x6a, - 0x07, 0xf1, 0xa9, 0x1b, 0xa4, 0xb2, 0x34, 0x75, 0xe6, 0xbf, 0x6c, 0x42, 0xdb, 0xc1, 0xed, 0x7f, - 0x00, 0x00, 0x00, 0xff, 0xff, 0x48, 0x7a, 0x3b, 0x55, 0x9c, 0x01, 0x00, 0x00, -} - -// Reference imports to suppress errors if they are not otherwise used. -var _ context.Context -var _ grpc.ClientConn - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -const _ = grpc.SupportPackageIsVersion4 - -// GreeterClient is the client API for Greeter service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. -type GreeterClient interface { - Greet(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) - IdRequest(ctx context.Context, in *rpc.IdRequest, opts ...grpc.CallOption) (*rpc.EmptyResponse, error) -} - -type greeterClient struct { - cc *grpc.ClientConn -} - -func NewGreeterClient(cc *grpc.ClientConn) GreeterClient { - return &greeterClient{cc} -} - -func (c *greeterClient) Greet(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) { - out := new(Response) - err := c.cc.Invoke(ctx, "/test.Greeter/greet", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *greeterClient) IdRequest(ctx context.Context, in *rpc.IdRequest, opts ...grpc.CallOption) (*rpc.EmptyResponse, error) { - out := new(rpc.EmptyResponse) - err := c.cc.Invoke(ctx, "/test.Greeter/idRequest", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -// GreeterServer is the server API for Greeter service. -type GreeterServer interface { - Greet(context.Context, *Request) (*Response, error) - IdRequest(context.Context, *rpc.IdRequest) (*rpc.EmptyResponse, error) -} - -// UnimplementedGreeterServer can be embedded to have forward compatible implementations. -type UnimplementedGreeterServer struct { -} - -func (*UnimplementedGreeterServer) Greet(ctx context.Context, req *Request) (*Response, error) { - return nil, status.Errorf(codes.Unimplemented, "method Greet not implemented") -} -func (*UnimplementedGreeterServer) IdRequest(ctx context.Context, req *rpc.IdRequest) (*rpc.EmptyResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method IdRequest not implemented") -} - -func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) { - s.RegisterService(&_Greeter_serviceDesc, srv) -} - -func _Greeter_Greet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(Request) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(GreeterServer).Greet(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/test.Greeter/Greet", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(GreeterServer).Greet(ctx, req.(*Request)) - } - return interceptor(ctx, in, info, handler) -} - -func _Greeter_IdRequest_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(rpc.IdRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(GreeterServer).IdRequest(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/test.Greeter/IdRequest", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(GreeterServer).IdRequest(ctx, req.(*rpc.IdRequest)) - } - return interceptor(ctx, in, info, handler) -} - -var _Greeter_serviceDesc = grpc.ServiceDesc{ - ServiceName: "test.Greeter", - HandlerType: (*GreeterServer)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "greet", - Handler: _Greeter_Greet_Handler, - }, - { - MethodName: "idRequest", - Handler: _Greeter_IdRequest_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "test.proto", -} diff --git a/tools/goctl/tpl/templates.go b/tools/goctl/tpl/templates.go index e10d54da..492d85e9 100644 --- a/tools/goctl/tpl/templates.go +++ b/tools/goctl/tpl/templates.go @@ -7,7 +7,7 @@ import ( "github.com/tal-tech/go-zero/core/errorx" "github.com/tal-tech/go-zero/tools/goctl/api/gogen" modelgen "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen" - rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/gen" + rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/generator" "github.com/tal-tech/go-zero/tools/goctl/util" "github.com/urfave/cli" ) diff --git a/tools/goctl/util/ctx/context.go b/tools/goctl/util/ctx/context.go new file mode 100644 index 00000000..c41b1cf8 --- /dev/null +++ b/tools/goctl/util/ctx/context.go @@ -0,0 +1,53 @@ +package ctx + +import ( + "errors" + "path/filepath" + + "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" +) + +var moduleCheckErr = errors.New("the work directory must be found in the go mod or the $GOPATH") + +type ( + ProjectContext struct { + WorkDir string + // Name is the root name of the project + // eg: go-zero、greet + Name string + // Path identifies which module a project belongs to, which is module value if it's a go mod project, + // or else it is the root name of the project, eg: github.com/tal-tech/go-zero、greet + Path string + // Dir is the path of the project, eg: /Users/keson/goland/go/go-zero、/Users/keson/go/src/greet + Dir string + } +) + +// Prepare checks the project which module belongs to,and returns the path and module. +// workDir parameter is the directory of the source of generating code, +// where can be found the project path and the project module, +func Prepare(workDir string) (*ProjectContext, error) { + ctx, err := background(workDir) + if err == nil { + return ctx, nil + } + + name := filepath.Base(workDir) + _, err = execx.Run("go mod init "+name, workDir) + if err != nil { + return nil, err + } + return background(workDir) +} + +func background(workDir string) (*ProjectContext, error) { + isGoMod, err := IsGoMod(workDir) + if err != nil { + return nil, err + } + + if isGoMod { + return projectFromGoMod(workDir) + } + return projectFromGoPath(workDir) +} diff --git a/tools/goctl/util/ctx/context_test.go b/tools/goctl/util/ctx/context_test.go new file mode 100644 index 00000000..abc53af4 --- /dev/null +++ b/tools/goctl/util/ctx/context_test.go @@ -0,0 +1,22 @@ +package ctx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBackground(t *testing.T) { + workDir := "." + ctx, err := Prepare(workDir) + assert.Nil(t, err) + assert.True(t, true, func() bool { + return len(ctx.Dir) != 0 && len(ctx.Path) != 0 + }()) +} + +func TestBackgroundNilWorkDir(t *testing.T) { + workDir := "" + _, err := Prepare(workDir) + assert.NotNil(t, err) +} diff --git a/tools/goctl/util/ctx/gomod.go b/tools/goctl/util/ctx/gomod.go new file mode 100644 index 00000000..c073e449 --- /dev/null +++ b/tools/goctl/util/ctx/gomod.go @@ -0,0 +1,47 @@ +package ctx + +import ( + "errors" + "os" + "path/filepath" + + "github.com/tal-tech/go-zero/core/jsonx" + "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" +) + +type Module struct { + Path string + Main bool + Dir string + GoMod string + GoVersion string +} + +// projectFromGoMod is used to find the go module and project file path +// the workDir flag specifies which folder we need to detect based on +// only valid for go mod project +func projectFromGoMod(workDir string) (*ProjectContext, error) { + if len(workDir) == 0 { + return nil, errors.New("the work directory is not found") + } + if _, err := os.Stat(workDir); err != nil { + return nil, err + } + + data, err := execx.Run("go list -json -m", workDir) + if err != nil { + return nil, err + } + + var m Module + err = jsonx.Unmarshal([]byte(data), &m) + if err != nil { + return nil, err + } + var ret ProjectContext + ret.WorkDir = workDir + ret.Name = filepath.Base(m.Dir) + ret.Dir = m.Dir + ret.Path = m.Path + return &ret, nil +} diff --git a/tools/goctl/util/ctx/gomod_test.go b/tools/goctl/util/ctx/gomod_test.go new file mode 100644 index 00000000..bd86642b --- /dev/null +++ b/tools/goctl/util/ctx/gomod_test.go @@ -0,0 +1,38 @@ +package ctx + +import ( + "go/build" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/stringx" + "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" + "github.com/tal-tech/go-zero/tools/goctl/util" +) + +func TestProjectFromGoMod(t *testing.T) { + dft := build.Default + gp := dft.GOPATH + if len(gp) == 0 { + return + } + projectName := stringx.Rand() + dir := filepath.Join(gp, "src", projectName) + err := util.MkdirIfNotExist(dir) + if err != nil { + return + } + + _, err = execx.Run("go mod init "+projectName, dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(dir) + }() + + ctx, err := projectFromGoMod(dir) + assert.Nil(t, err) + assert.Equal(t, projectName, ctx.Path) + assert.Equal(t, dir, ctx.Dir) +} diff --git a/tools/goctl/util/ctx/gopath.go b/tools/goctl/util/ctx/gopath.go new file mode 100644 index 00000000..ec44ef2b --- /dev/null +++ b/tools/goctl/util/ctx/gopath.go @@ -0,0 +1,47 @@ +package ctx + +import ( + "errors" + "go/build" + "os" + "path/filepath" + "strings" + + "github.com/tal-tech/go-zero/tools/goctl/util" +) + +// projectFromGoPath is used to find the main module and project file path +// the workDir flag specifies which folder we need to detect based on +// only valid for go mod project +func projectFromGoPath(workDir string) (*ProjectContext, error) { + if len(workDir) == 0 { + return nil, errors.New("the work directory is not found") + } + if _, err := os.Stat(workDir); err != nil { + return nil, err + } + + buildContext := build.Default + goPath := buildContext.GOPATH + goSrc := filepath.Join(goPath, "src") + if !util.FileExists(goSrc) { + return nil, moduleCheckErr + } + + wd, err := filepath.Abs(workDir) + if err != nil { + return nil, err + } + + if !strings.HasPrefix(wd, goSrc) { + return nil, moduleCheckErr + } + + projectName := strings.TrimPrefix(wd, goSrc+string(filepath.Separator)) + return &ProjectContext{ + WorkDir: workDir, + Name: projectName, + Path: projectName, + Dir: filepath.Join(goSrc, projectName), + }, nil +} diff --git a/tools/goctl/util/ctx/gopath_test.go b/tools/goctl/util/ctx/gopath_test.go new file mode 100644 index 00000000..1f0f0780 --- /dev/null +++ b/tools/goctl/util/ctx/gopath_test.go @@ -0,0 +1,54 @@ +package ctx + +import ( + "go/build" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/stringx" + "github.com/tal-tech/go-zero/tools/goctl/util" +) + +func TestProjectFromGoPath(t *testing.T) { + dft := build.Default + gp := dft.GOPATH + if len(gp) == 0 { + return + } + projectName := stringx.Rand() + dir := filepath.Join(gp, "src", projectName) + err := util.MkdirIfNotExist(dir) + if err != nil { + return + } + defer func() { + _ = os.RemoveAll(dir) + }() + + ctx, err := projectFromGoPath(dir) + assert.Nil(t, err) + assert.Equal(t, dir, ctx.Dir) + assert.Equal(t, projectName, ctx.Path) +} + +func TestProjectFromGoPathNotInGoSrc(t *testing.T) { + dft := build.Default + gp := dft.GOPATH + if len(gp) == 0 { + return + } + projectName := stringx.Rand() + dir := filepath.Join(gp, "src", projectName) + err := util.MkdirIfNotExist(dir) + if err != nil { + return + } + defer func() { + _ = os.RemoveAll(dir) + }() + + _, err = projectFromGoPath("testPath") + assert.NotNil(t, err) +} diff --git a/tools/goctl/util/ctx/modcheck.go b/tools/goctl/util/ctx/modcheck.go new file mode 100644 index 00000000..0db491e4 --- /dev/null +++ b/tools/goctl/util/ctx/modcheck.go @@ -0,0 +1,32 @@ +package ctx + +import ( + "errors" + "os" + + "github.com/tal-tech/go-zero/core/jsonx" + "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" +) + +// IsGoMod is used to determine whether workDir is a go module project through command `go list -json -m` +func IsGoMod(workDir string) (bool, error) { + if len(workDir) == 0 { + return false, errors.New("the work directory is not found") + } + if _, err := os.Stat(workDir); err != nil { + return false, err + } + + data, err := execx.Run("go list -json -m", workDir) + if err != nil { + return false, nil + } + + var m Module + err = jsonx.Unmarshal([]byte(data), &m) + if err != nil { + return false, err + } + + return len(m.GoMod) > 0, nil +} diff --git a/tools/goctl/util/ctx/modcheck_test.go b/tools/goctl/util/ctx/modcheck_test.go new file mode 100644 index 00000000..b651af99 --- /dev/null +++ b/tools/goctl/util/ctx/modcheck_test.go @@ -0,0 +1,67 @@ +package ctx + +import ( + "go/build" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/stringx" + "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" + "github.com/tal-tech/go-zero/tools/goctl/util" +) + +func TestIsGoMod(t *testing.T) { + // create mod project + dft := build.Default + gp := dft.GOPATH + if len(gp) == 0 { + return + } + projectName := stringx.Rand() + dir := filepath.Join(gp, "src", projectName) + err := util.MkdirIfNotExist(dir) + if err != nil { + return + } + + _, err = execx.Run("go mod init "+projectName, dir) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(dir) + }() + + isGoMod, err := IsGoMod(dir) + assert.Nil(t, err) + assert.Equal(t, true, isGoMod) +} + +func TestIsGoModNot(t *testing.T) { + dft := build.Default + gp := dft.GOPATH + if len(gp) == 0 { + return + } + projectName := stringx.Rand() + dir := filepath.Join(gp, "src", projectName) + err := util.MkdirIfNotExist(dir) + if err != nil { + return + } + + defer func() { + _ = os.RemoveAll(dir) + }() + + isGoMod, err := IsGoMod(dir) + assert.Nil(t, err) + assert.Equal(t, false, isGoMod) +} + +func TestIsGoModWorkDirIsNil(t *testing.T) { + _, err := IsGoMod("") + assert.Equal(t, err.Error(), func() string { + return "the work directory is not found" + }()) +} diff --git a/tools/goctl/util/project/project.go b/tools/goctl/util/project/project.go deleted file mode 100644 index 0ebb61b8..00000000 --- a/tools/goctl/util/project/project.go +++ /dev/null @@ -1,141 +0,0 @@ -package project - -import ( - "errors" - "fmt" - "io/ioutil" - "os" - "os/exec" - "path/filepath" - "regexp" - "strings" - - "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" -) - -const ( - constGo = "go" - constProtoC = "protoc" - constGoMod = "go env GOMOD" - constGoPath = "go env GOPATH" - constProtoCGenGo = "protoc-gen-go" -) - -type ( - Project struct { - Path string // Project path name - Name string // Project name - Package string // The service related package - // true-> project in go path or project init with go mod,or else->false - IsInGoEnv bool - GoMod GoMod - } - - GoMod struct { - Module string // The gomod module name - Path string // The gomod related path - } -) - -func Prepare(projectDir string, checkGrpcEnv bool) (*Project, error) { - _, err := exec.LookPath(constGo) - if err != nil { - return nil, fmt.Errorf("please install go first,reference documents:「https://golang.org/doc/install」") - } - - if checkGrpcEnv { - _, err = exec.LookPath(constProtoC) - if err != nil { - return nil, fmt.Errorf("please install protoc first,reference documents:「https://github.com/golang/protobuf」") - } - - _, err = exec.LookPath(constProtoCGenGo) - if err != nil { - return nil, fmt.Errorf("please install plugin protoc-gen-go first,reference documents:「https://github.com/golang/protobuf」") - } - } - - var ( - goMod, module string - goPath string - name, path string - pkg string - ) - - ret, err := execx.Run(constGoMod, projectDir) - if err != nil { - return nil, err - } - - goMod = strings.TrimSpace(ret) - if goMod == os.DevNull { - goMod = "" - } - - ret, err = execx.Run(constGoPath, "") - if err != nil { - return nil, err - } - - goPath = strings.TrimSpace(ret) - src := filepath.Join(goPath, "src") - var isInGoEnv = true - if len(goMod) > 0 { - path = filepath.Dir(goMod) - name = filepath.Base(path) - data, err := ioutil.ReadFile(goMod) - if err != nil { - return nil, err - } - - module, err = matchModule(data) - if err != nil { - return nil, err - } - } else { - pwd, err := filepath.Abs(projectDir) - if err != nil { - return nil, err - } - - if !strings.HasPrefix(pwd, src) { - name = filepath.Clean(filepath.Base(pwd)) - path = projectDir - pkg = name - isInGoEnv = false - } else { - r := strings.TrimPrefix(pwd, src+string(filepath.Separator)) - name = filepath.Dir(r) - if name == "." { - name = r - } - path = filepath.Join(src, name) - pkg = r - } - module = name - } - - return &Project{ - Name: name, - Path: path, - Package: strings.ReplaceAll(pkg, `\`, "/"), - IsInGoEnv: isInGoEnv, - GoMod: GoMod{ - Module: module, - Path: goMod, - }, - }, nil -} - -func matchModule(data []byte) (string, error) { - text := string(data) - re := regexp.MustCompile(`(?m)^\s*module\s+[a-z0-9_/\-.]+$`) - matches := re.FindAllString(text, -1) - if len(matches) == 1 { - target := matches[0] - index := strings.Index(target, "module") - return strings.TrimSpace(target[index+6:]), nil - } - - return "", errors.New("module not matched") -}