Goctl rpc patch (#117)

* remove mock generation

* add: proto project import

* update document

* remove mock generation

* add: proto project import

* update document

* remove NL

* update document

* optimize code

* add test

* add test
This commit is contained in:
Keson 2020-10-10 16:19:46 +08:00 committed by GitHub
parent c32759d735
commit 0a9c427443
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1394 additions and 230 deletions

View File

@ -7,6 +7,9 @@ Goctl Rpc是`goctl`脚手架下的一个rpc服务代码生成模块支持prot
* 简单易用 * 简单易用
* 快速提升开发效率 * 快速提升开发效率
* 出错率低 * 出错率低
* 支持基于main proto作为相对路径的import
* 支持map、enum类型
* 支持any类型
## 快速开始 ## 快速开始
@ -111,14 +114,12 @@ rpc一键生成常见问题解决见 <a href="#常见问题解决">常见问题
│   └── usermodel.go │   └── usermodel.go
├── user.go ├── user.go
└── user.proto └── user.proto
``` ```
## 准备工作 ## 准备工作
* 安装了go环境 * 安装了go环境
* 安装了protoc&protoc-gen-go并且已经设置环境变量 * 安装了protoc&protoc-gen-go并且已经设置环境变量
* mockgen(可选,将移除)
* 更多问题请见 <a href="#注意事项">注意事项</a> * 更多问题请见 <a href="#注意事项">注意事项</a>
## 用法 ## 用法
@ -140,7 +141,6 @@ OPTIONS:
--src value, -s value the file path of the proto source file --src value, -s value the file path of the proto source file
--dir value, -d value the target path of the code,default path is "${pwd}". [option] --dir value, -d value the target path of the code,default path is "${pwd}". [option]
--service value, --srv value the name of rpc service. [option] --service value, --srv value the name of rpc service. [option]
--shared[已废弃] value the dir of the shared file,default path is "${pwd}/shared. [option]"
--idea whether the command execution environment is from idea plugin. [option] --idea whether the command execution environment is from idea plugin. [option]
``` ```
@ -159,13 +159,13 @@ OPTIONS:
``` ```
则服务名称亦为user而非proto所在文件夹名称了这里推荐使用这种结构可以方便在同一个服务名下建立不同类型的服务(api、rpc、mq等),便于代码管理与维护。 则服务名称亦为user而非proto所在文件夹名称了这里推荐使用这种结构可以方便在同一个服务名下建立不同类型的服务(api、rpc、mq等),便于代码管理与维护。
* --shared[⚠️已废弃] 非必填,默认为$dir(xxx.proto)/sharedrpc client逻辑代码存放目录。
> 注意这里的shared文件夹名称将会是代码中的package名称。 > 注意这里的shared文件夹名称将会是代码中的package名称。
* --idea 非必填是否为idea插件中执行保留字段终端执行可以忽略 * --idea 非必填是否为idea插件中执行保留字段终端执行可以忽略
## 开发人员需要做什么
### 开发人员需要做什么
关注业务代码编写将重复性、与业务无关的工作交给goctl生成好rpc服务代码后开饭人员仅需要修改 关注业务代码编写将重复性、与业务无关的工作交给goctl生成好rpc服务代码后开饭人员仅需要修改
@ -173,14 +173,11 @@ OPTIONS:
* 服务中业务逻辑编写(internal/logic/xxlogic.go) * 服务中业务逻辑编写(internal/logic/xxlogic.go)
* 服务中资源上下文的编写(internal/svc/servicecontext.go) * 服务中资源上下文的编写(internal/svc/servicecontext.go)
## 扩展
对于需要进行rpc mock的开发人员在安装了`mockgen`工具的前提下可以在rpc的shared文件中生成好对应的mock文件。 ### 注意事项
## 注意事项
* `google.golang.org/grpc`需要降级到v1.26.0,且protoc-gen-go版本不能高于v1.3.2see [https://github.com/grpc/grpc-go/issues/3347](https://github.com/grpc/grpc-go/issues/3347))即 * `google.golang.org/grpc`需要降级到v1.26.0,且protoc-gen-go版本不能高于v1.3.2see [https://github.com/grpc/grpc-go/issues/3347](https://github.com/grpc/grpc-go/issues/3347))即
```shell script ```shell script
replace google.golang.org/grpc => google.golang.org/grpc v1.26.0 replace google.golang.org/grpc => google.golang.org/grpc v1.26.0
``` ```
@ -189,12 +186,76 @@ OPTIONS:
* proto不支持外部依赖包引入message不支持inline * proto不支持外部依赖包引入message不支持inline
* 目前main文件、shared文件、handler文件会被强制覆盖而和开发人员手动需要编写的则不会覆盖生成这一类在代码头部均有 * 目前main文件、shared文件、handler文件会被强制覆盖而和开发人员手动需要编写的则不会覆盖生成这一类在代码头部均有
```shell script ```shell script
// Code generated by goctl. DO NOT EDIT! // Code generated by goctl. DO NOT EDIT!
// Source: xxx.proto // Source: xxx.proto
``` ```
的标识,请注意不要将也写业务性代码写在里面。 的标识,请注意不要将也写业务性代码写在里面。
## any和import支持
* 支持any类型声明
* 支持import其他proto文件
any类型固定import为`google/protobuf/any.proto`,且从${GOPATH}/src中查找proto的import支持main proto的相对路径的import且与proto文件对应的pb.go文件必须在proto目录中能被找到。不支持工程外的其他proto文件import。
> ⚠️注意: 不支持proto嵌套import被import的proto文件不支持import。
### import书写格式
import书写格式
```golang
// @{package_of_pb}
import {proto_omport}
```
@{package_of_pb}pb文件的真实import目录。
{proto_omport}proto import
demo中的
```golang
// @greet/base
import "base/base.proto";
```
工程目录结构如下
```
greet
│   ├── base
│   │   ├── base.pb.go
│   │   └── base.proto
│   ├── demo.proto
│   ├── go.mod
│   └── go.sum
```
demo
```golang
syntax = "proto3";
import "google/protobuf/any.proto";
// @greet/base
import "base/base.proto";
package stream;
enum Gender{
UNKNOWN = 0;
MAN = 1;
WOMAN = 2;
}
message StreamResp{
string name = 2;
Gender gender = 3;
google.protobuf.Any details = 5;
base.StreamReq req = 6;
}
service StreamGreeter {
rpc greet(base.StreamReq) returns (StreamResp);
}
```
## 常见问题解决(go mod工程) ## 常见问题解决(go mod工程)

4
go.mod
View File

@ -9,7 +9,7 @@ require (
github.com/alicebob/miniredis v2.5.0+incompatible github.com/alicebob/miniredis v2.5.0+incompatible
github.com/dchest/siphash v1.2.1 github.com/dchest/siphash v1.2.1
github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/dsymonds/gotoc v0.0.0-20160928043926-5aebcfc91819 github.com/emicklei/proto v1.9.0
github.com/fatih/color v1.9.0 // indirect github.com/fatih/color v1.9.0 // indirect
github.com/frankban/quicktest v1.7.2 // indirect github.com/frankban/quicktest v1.7.2 // indirect
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8
@ -56,7 +56,7 @@ require (
golang.org/x/tools v0.0.0-20200410132612-ae9902aceb98 // indirect golang.org/x/tools v0.0.0-20200410132612-ae9902aceb98 // indirect
google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f // indirect google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f // indirect
google.golang.org/grpc v1.29.1 google.golang.org/grpc v1.29.1
google.golang.org/protobuf v1.25.0 // indirect google.golang.org/protobuf v1.25.0
gopkg.in/cheggaaa/pb.v1 v1.0.28 gopkg.in/cheggaaa/pb.v1 v1.0.28
gopkg.in/h2non/gock.v1 v1.0.15 gopkg.in/h2non/gock.v1 v1.0.15
gopkg.in/yaml.v2 v2.2.8 gopkg.in/yaml.v2 v2.2.8

4
go.sum
View File

@ -50,10 +50,10 @@ github.com/dchest/siphash v1.2.1 h1:4cLinnzVJDKxTCl9B01807Yiy+W7ZzVHj/KIroQRvT4=
github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dsymonds/gotoc v0.0.0-20160928043926-5aebcfc91819 h1:9778zj477h/VauD8kHbOtbytW2KGQefJ/wUGE5w+mzw=
github.com/dsymonds/gotoc v0.0.0-20160928043926-5aebcfc91819/go.mod h1:MvzMVHq8BH2Ji/o8TGDocVA70byvLrAgFTxkEnmjO4Y=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4 h1:qk/FSDDxo05wdJH28W+p5yivv7LuLYLRXPPD8KQCtZs= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4 h1:qk/FSDDxo05wdJH28W+p5yivv7LuLYLRXPPD8KQCtZs=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/emicklei/proto v1.9.0 h1:l0QiNT6Qs7Yj0Mb4X6dnWBQer4ebei2BFcgQLbGqUDc=
github.com/emicklei/proto v1.9.0/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=

View File

@ -7,6 +7,9 @@ Goctl Rpc是`goctl`脚手架下的一个rpc服务代码生成模块支持prot
* 简单易用 * 简单易用
* 快速提升开发效率 * 快速提升开发效率
* 出错率低 * 出错率低
* 支持基于main proto作为相对路径的import
* 支持map、enum类型
* 支持any类型
## 快速开始 ## 快速开始
@ -117,7 +120,6 @@ rpc一键生成常见问题解决见 <a href="#常见问题解决">常见问题
* 安装了go环境 * 安装了go环境
* 安装了protoc&protoc-gen-go并且已经设置环境变量 * 安装了protoc&protoc-gen-go并且已经设置环境变量
* mockgen(可选,将移除)
* 更多问题请见 <a href="#注意事项">注意事项</a> * 更多问题请见 <a href="#注意事项">注意事项</a>
## 用法 ## 用法
@ -139,7 +141,6 @@ OPTIONS:
--src value, -s value the file path of the proto source file --src value, -s value the file path of the proto source file
--dir value, -d value the target path of the code,default path is "${pwd}". [option] --dir value, -d value the target path of the code,default path is "${pwd}". [option]
--service value, --srv value the name of rpc service. [option] --service value, --srv value the name of rpc service. [option]
--shared[已废弃] value the dir of the shared file,default path is "${pwd}/shared. [option]"
--idea whether the command execution environment is from idea plugin. [option] --idea whether the command execution environment is from idea plugin. [option]
``` ```
@ -158,12 +159,12 @@ OPTIONS:
``` ```
则服务名称亦为user而非proto所在文件夹名称了这里推荐使用这种结构可以方便在同一个服务名下建立不同类型的服务(api、rpc、mq等),便于代码管理与维护。 则服务名称亦为user而非proto所在文件夹名称了这里推荐使用这种结构可以方便在同一个服务名下建立不同类型的服务(api、rpc、mq等),便于代码管理与维护。
* --shared[⚠️已废弃] 非必填,默认为$dir(xxx.proto)/sharedrpc client逻辑代码存放目录。
> 注意这里的shared文件夹名称将会是代码中的package名称。 > 注意这里的shared文件夹名称将会是代码中的package名称。
* --idea 非必填是否为idea插件中执行保留字段终端执行可以忽略 * --idea 非必填是否为idea插件中执行保留字段终端执行可以忽略
### 开发人员需要做什么 ### 开发人员需要做什么
关注业务代码编写将重复性、与业务无关的工作交给goctl生成好rpc服务代码后开饭人员仅需要修改 关注业务代码编写将重复性、与业务无关的工作交给goctl生成好rpc服务代码后开饭人员仅需要修改
@ -172,9 +173,6 @@ OPTIONS:
* 服务中业务逻辑编写(internal/logic/xxlogic.go) * 服务中业务逻辑编写(internal/logic/xxlogic.go)
* 服务中资源上下文的编写(internal/svc/servicecontext.go) * 服务中资源上下文的编写(internal/svc/servicecontext.go)
## 扩展
对于需要进行rpc mock的开发人员在安装了`mockgen`工具的前提下可以在rpc的shared文件中生成好对应的mock文件。
### 注意事项 ### 注意事项
@ -195,6 +193,70 @@ OPTIONS:
的标识,请注意不要将也写业务性代码写在里面。 的标识,请注意不要将也写业务性代码写在里面。
## any和import支持
* 支持any类型声明
* 支持import其他proto文件
any类型固定import为`google/protobuf/any.proto`,且从${GOPATH}/src中查找proto的import支持main proto的相对路径的import且与proto文件对应的pb.go文件必须在proto目录中能被找到。不支持工程外的其他proto文件import。
> ⚠️注意: 不支持proto嵌套import被import的proto文件不支持import。
### import书写格式
import书写格式
```golang
// @{package_of_pb}
import {proto_omport}
```
@{package_of_pb}pb文件的真实import目录。
{proto_omport}proto import
demo中的
```golang
// @greet/base
import "base/base.proto";
```
工程目录结构如下
```
greet
│   ├── base
│   │   ├── base.pb.go
│   │   └── base.proto
│   ├── demo.proto
│   ├── go.mod
│   └── go.sum
```
demo
```golang
syntax = "proto3";
import "google/protobuf/any.proto";
// @greet/base
import "base/base.proto";
package stream;
enum Gender{
UNKNOWN = 0;
MAN = 1;
WOMAN = 2;
}
message StreamResp{
string name = 2;
Gender gender = 3;
google.protobuf.Any details = 5;
base.StreamReq req = 6;
}
service StreamGreeter {
rpc greet(base.StreamReq) returns (StreamResp);
}
```
## 常见问题解决(go mod工程) ## 常见问题解决(go mod工程)
* 错误一: * 错误一:

108
tools/goctl/rpc/base.pb.go Normal file
View File

@ -0,0 +1,108 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: base.proto
package base
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type IdRequest struct {
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *IdRequest) Reset() { *m = IdRequest{} }
func (m *IdRequest) String() string { return proto.CompactTextString(m) }
func (*IdRequest) ProtoMessage() {}
func (*IdRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_db1b6b0986796150, []int{0}
}
func (m *IdRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_IdRequest.Unmarshal(m, b)
}
func (m *IdRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_IdRequest.Marshal(b, m, deterministic)
}
func (m *IdRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_IdRequest.Merge(m, src)
}
func (m *IdRequest) XXX_Size() int {
return xxx_messageInfo_IdRequest.Size(m)
}
func (m *IdRequest) XXX_DiscardUnknown() {
xxx_messageInfo_IdRequest.DiscardUnknown(m)
}
var xxx_messageInfo_IdRequest proto.InternalMessageInfo
func (m *IdRequest) GetId() string {
if m != nil {
return m.Id
}
return ""
}
type EmptyResponse struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *EmptyResponse) Reset() { *m = EmptyResponse{} }
func (m *EmptyResponse) String() string { return proto.CompactTextString(m) }
func (*EmptyResponse) ProtoMessage() {}
func (*EmptyResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_db1b6b0986796150, []int{1}
}
func (m *EmptyResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_EmptyResponse.Unmarshal(m, b)
}
func (m *EmptyResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_EmptyResponse.Marshal(b, m, deterministic)
}
func (m *EmptyResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_EmptyResponse.Merge(m, src)
}
func (m *EmptyResponse) XXX_Size() int {
return xxx_messageInfo_EmptyResponse.Size(m)
}
func (m *EmptyResponse) XXX_DiscardUnknown() {
xxx_messageInfo_EmptyResponse.DiscardUnknown(m)
}
var xxx_messageInfo_EmptyResponse proto.InternalMessageInfo
func init() {
proto.RegisterType((*IdRequest)(nil), "base.IdRequest")
proto.RegisterType((*EmptyResponse)(nil), "base.EmptyResponse")
}
func init() { proto.RegisterFile("base.proto", fileDescriptor_db1b6b0986796150) }
var fileDescriptor_db1b6b0986796150 = []byte{
// 91 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4a, 0x4a, 0x2c, 0x4e,
0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x01, 0xb1, 0x95, 0xa4, 0xb9, 0x38, 0x3d, 0x53,
0x82, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0xf8, 0xb8, 0x98, 0x32, 0x53, 0x24, 0x18, 0x15,
0x18, 0x35, 0x38, 0x83, 0x98, 0x32, 0x53, 0x94, 0xf8, 0xb9, 0x78, 0x5d, 0x73, 0x0b, 0x4a, 0x2a,
0x83, 0x52, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x93, 0xd8, 0xc0, 0x5a, 0x8d, 0x01, 0x01, 0x00,
0x00, 0xff, 0xff, 0xe1, 0x39, 0x3c, 0x22, 0x48, 0x00, 0x00, 0x00,
}

View File

@ -0,0 +1,11 @@
syntax = "proto3";
package base;
message IdRequest {
string id = 1;
}
message EmptyResponse {
}

View File

@ -11,6 +11,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/project" "github.com/tal-tech/go-zero/tools/goctl/util/project"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
"github.com/tal-tech/go-zero/tools/goctl/vars"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@ -58,7 +59,7 @@ func MustCreateRpcContext(protoSrc, targetDir, serviceName string, idea bool) *R
} }
serviceNameString := stringx.From(serviceName) serviceNameString := stringx.From(serviceName)
if serviceNameString.IsEmptyOrSpace() { if serviceNameString.IsEmptyOrSpace() {
log.Fatalln("service name is not found") log.Fatalln("service name not found")
} }
info, err := project.Prepare(targetDir, true) info, err := project.Prepare(targetDir, true)
@ -80,7 +81,7 @@ func MustCreateRpcContext(protoSrc, targetDir, serviceName string, idea bool) *R
func MustCreateRpcContextFromCli(ctx *cli.Context) *RpcContext { func MustCreateRpcContextFromCli(ctx *cli.Context) *RpcContext {
os := runtime.GOOS os := runtime.GOOS
switch os { switch os {
case "darwin", "linux", "windows": case vars.OsMac, vars.OsLinux, vars.OsWindows:
default: default:
logx.Must(fmt.Errorf("unexpected os: %s", os)) logx.Must(fmt.Errorf("unexpected os: %s", os))
} }

View File

@ -6,15 +6,17 @@ import (
"fmt" "fmt"
"os/exec" "os/exec"
"runtime" "runtime"
"github.com/tal-tech/go-zero/tools/goctl/vars"
) )
func Run(arg string, dir string) (string, error) { func Run(arg string, dir string) (string, error) {
goos := runtime.GOOS goos := runtime.GOOS
var cmd *exec.Cmd var cmd *exec.Cmd
switch goos { switch goos {
case "darwin", "linux": case vars.OsMac, vars.OsLinux:
cmd = exec.Command("sh", "-c", arg) cmd = exec.Command("sh", "-c", arg)
case "windows": case vars.OsWindows:
cmd = exec.Command("cmd.exe", "/c", arg) cmd = exec.Command("cmd.exe", "/c", arg)
default: default:
return "", fmt.Errorf("unexpected os: %v", goos) return "", fmt.Errorf("unexpected os: %v", goos)

View File

@ -1,6 +1,7 @@
package gen package gen
import ( import (
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/rpc/ctx" "github.com/tal-tech/go-zero/tools/goctl/rpc/ctx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
) )
@ -31,10 +32,11 @@ func NewDefaultRpcGenerator(ctx *ctx.RpcContext) *defaultRpcGenerator {
} }
func (g *defaultRpcGenerator) Generate() (err error) { func (g *defaultRpcGenerator) Generate() (err error) {
g.Ctx.Info("generating code...") g.Ctx.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/go-zero/blob/master/doc/goctl-rpc.md」")
g.Ctx.Warning("-> generating rpc code ...")
defer func() { defer func() {
if err == nil { if err == nil {
g.Ctx.Success("Done.") g.Ctx.MarkDone()
} }
}() }()
err = g.createDir() err = g.createDir()

View File

@ -2,17 +2,16 @@ package gen
import ( import (
"fmt" "fmt"
"os"
"os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
const ( const (
typesFilename = "types.go"
callTemplateText = `{{.head}} callTemplateText = `{{.head}}
//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE //go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
@ -54,14 +53,17 @@ import "errors"
var errJsonConvert = errors.New("json convert error") var errJsonConvert = errors.New("json convert error")
{{.const}}
{{.types}} {{.types}}
` `
callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}} callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)` {{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
callFunctionTemplate = ` callFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}} {{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) { func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequestName}}) (*{{.pbResponse}}, error) {
var request {{.package}}.{{.pbRequest}} var request {{.pbRequest}}
bts, err := jsonx.Marshal(in) bts, err := jsonx.Marshal(in)
if err != nil { if err != nil {
return nil, errJsonConvert return nil, errJsonConvert
@ -108,21 +110,23 @@ func (g *defaultRpcGenerator) genCall() error {
return err return err
} }
constLit, err := file.GenEnumCode()
if err != nil {
return err
}
service := file.Service[0] service := file.Service[0]
callPath := filepath.Join(g.dirM[dirTarget], service.Name.Lower()) callPath := filepath.Join(g.dirM[dirTarget], service.Name.Lower())
if err = util.MkdirIfNotExist(callPath); err != nil { if err = util.MkdirIfNotExist(callPath); err != nil {
return err return err
} }
pbPkg := file.Package filename := filepath.Join(callPath, typesFilename)
remotePackage := fmt.Sprintf(`%v "%v"`, pbPkg, g.mustGetPackage(dirPb))
filename := filepath.Join(callPath, "types.go")
head := util.GetHead(g.Ctx.ProtoSource) head := util.GetHead(g.Ctx.ProtoSource)
err = util.With("types").GoFmt(true).Parse(callTemplateTypes).SaveTo(map[string]interface{}{ err = util.With("types").GoFmt(true).Parse(callTemplateTypes).SaveTo(map[string]interface{}{
"head": head, "head": head,
"const": constLit,
"filePackage": service.Name.Lower(), "filePackage": service.Name.Lower(),
"pbPkg": pbPkg,
"serviceName": g.Ctx.ServiceName.Title(), "serviceName": g.Ctx.ServiceName.Title(),
"lowerStartServiceName": g.Ctx.ServiceName.UnTitle(), "lowerStartServiceName": g.Ctx.ServiceName.UnTitle(),
"types": typeCode, "types": typeCode,
@ -131,10 +135,8 @@ func (g *defaultRpcGenerator) genCall() error {
return err return err
} }
_, err = exec.LookPath("mockgen")
mockGenInstalled := err == nil
filename = filepath.Join(callPath, fmt.Sprintf("%s.go", service.Name.Lower())) filename = filepath.Join(callPath, fmt.Sprintf("%s.go", service.Name.Lower()))
functions, err := g.getFuncs(service) functions, importList, err := g.genFunction(service)
if err != nil { if err != nil {
return err return err
} }
@ -144,72 +146,56 @@ func (g *defaultRpcGenerator) genCall() error {
return err return err
} }
mockFile := filepath.Join(callPath, fmt.Sprintf("%s_mock.go", service.Name.Lower()))
_ = os.Remove(mockFile)
err = util.With("shared").GoFmt(true).Parse(callTemplateText).SaveTo(map[string]interface{}{ err = util.With("shared").GoFmt(true).Parse(callTemplateText).SaveTo(map[string]interface{}{
"name": service.Name.Lower(), "name": service.Name.Lower(),
"head": head, "head": head,
"filePackage": service.Name.Lower(), "filePackage": service.Name.Lower(),
"pbPkg": pbPkg, "package": strings.Join(importList, util.NL),
"package": remotePackage,
"serviceName": service.Name.Title(), "serviceName": service.Name.Title(),
"functions": strings.Join(functions, "\n"), "functions": strings.Join(functions, util.NL),
"interface": strings.Join(iFunctions, "\n"), "interface": strings.Join(iFunctions, util.NL),
}, filename, true) }, filename, true)
if err != nil { return err
return err
}
// if mockgen is already installed, it will generate code of gomock for shared files
// Deprecated: it will be removed
if mockGenInstalled && g.Ctx.IsInGoEnv {
_, _ = execx.Run(fmt.Sprintf("go generate %s", filename), "")
}
return nil
} }
func (g *defaultRpcGenerator) getFuncs(service *parser.RpcService) ([]string, error) { func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string, []string, error) {
file := g.ast file := g.ast
pkgName := file.Package pkgName := file.Package
functions := make([]string, 0) functions := make([]string, 0)
imports := collection.NewSet()
imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb)))
for _, method := range service.Funcs { for _, method := range service.Funcs {
var comment string imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
if len(method.Document) > 0 {
comment = method.Document[0]
}
buffer, err := util.With("sharedFn").Parse(callFunctionTemplate).Execute(map[string]interface{}{ buffer, err := util.With("sharedFn").Parse(callFunctionTemplate).Execute(map[string]interface{}{
"rpcServiceName": service.Name.Title(), "rpcServiceName": service.Name.Title(),
"method": method.Name.Title(), "method": method.Name.Title(),
"package": pkgName, "package": pkgName,
"pbRequest": method.InType, "pbRequestName": method.ParameterIn.Name,
"pbResponse": method.OutType, "pbRequest": method.ParameterIn.Expression,
"hasComment": len(method.Document) > 0, "pbResponse": method.ParameterOut.Name,
"comment": comment, "hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
}) })
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
functions = append(functions, buffer.String()) functions = append(functions, buffer.String())
} }
return functions, nil return functions, imports.KeysStr(), nil
} }
func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) { func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) {
functions := make([]string, 0) functions := make([]string, 0)
for _, method := range service.Funcs { for _, method := range service.Funcs {
var comment string
if len(method.Document) > 0 {
comment = method.Document[0]
}
buffer, err := util.With("interfaceFn").Parse(callInterfaceFunctionTemplate).Execute( buffer, err := util.With("interfaceFn").Parse(callInterfaceFunctionTemplate).Execute(
map[string]interface{}{ map[string]interface{}{
"hasComment": len(method.Document) > 0, "hasComment": method.HaveDoc(),
"comment": comment, "comment": method.GetDoc(),
"method": method.Name.Title(), "method": method.Name.Title(),
"pbRequest": method.InType, "pbRequest": method.ParameterIn.Name,
"pbResponse": method.OutType, "pbResponse": method.ParameterOut.Name,
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars"
) )
// target // target
@ -43,9 +44,9 @@ func (g *defaultRpcGenerator) mustGetPackage(dir string) string {
relativePath := strings.TrimPrefix(target, projectPath) relativePath := strings.TrimPrefix(target, projectPath)
os := runtime.GOOS os := runtime.GOOS
switch os { switch os {
case "windows": case vars.OsWindows:
relativePath = filepath.ToSlash(relativePath) relativePath = filepath.ToSlash(relativePath)
case "darwin", "linux": case vars.OsMac, vars.OsLinux:
default: default:
g.Ctx.Fatalln("unexpected os: %s", os) g.Ctx.Fatalln("unexpected os: %s", os)
} }

View File

@ -37,10 +37,10 @@ func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logic
{{.functions}} {{.functions}}
` `
logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}} logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
func (l *{{.logicName}}) {{.method}} (in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) { func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
// todo: add your logic here and delete this line // todo: add your logic here and delete this line
return &{{.package}}.{{.response}}{}, nil return &{{.responseType}}{}, nil
} }
` `
) )
@ -53,18 +53,18 @@ func (g *defaultRpcGenerator) genLogic() error {
for _, method := range item.Funcs { for _, method := range item.Funcs {
logicName := fmt.Sprintf("%slogic.go", method.Name.Lower()) logicName := fmt.Sprintf("%slogic.go", method.Name.Lower())
filename := filepath.Join(logicPath, logicName) filename := filepath.Join(logicPath, logicName)
functions, err := genLogicFunction(protoPkg, method) functions, importList, err := g.genLogicFunction(protoPkg, method)
if err != nil { if err != nil {
return err return err
} }
imports := collection.NewSet() imports := collection.NewSet()
pbImport := fmt.Sprintf(`%v "%v"`, protoPkg, g.mustGetPackage(dirPb))
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc)) svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
imports.AddStr(pbImport, svcImport) imports.AddStr(svcImport)
imports.AddStr(importList...)
err = util.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{ err = util.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()), "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"functions": functions, "functions": functions,
"imports": strings.Join(imports.KeysStr(), "\n"), "imports": strings.Join(imports.KeysStr(), util.NL),
}, filename, false) }, filename, false)
if err != nil { if err != nil {
return err return err
@ -74,20 +74,26 @@ func (g *defaultRpcGenerator) genLogic() error {
return nil return nil
} }
func genLogicFunction(packageName string, method *parser.Func) (string, error) { func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parser.Func) (string, []string, error) {
var functions = make([]string, 0) var functions = make([]string, 0)
var imports = collection.NewSet()
if method.ParameterIn.Package == packageName || method.ParameterOut.Package == packageName {
imports.AddStr(fmt.Sprintf(`%v "%v"`, packageName, g.mustGetPackage(dirPb)))
}
imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
buffer, err := util.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{ buffer, err := util.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()), "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(), "method": method.Name.Title(),
"package": packageName, "request": method.ParameterIn.StarExpression,
"request": method.InType, "response": method.ParameterOut.StarExpression,
"response": method.OutType, "responseType": method.ParameterOut.Expression,
"hasComment": len(method.Document) > 0, "hasComment": method.HaveDoc(),
"comment": strings.Join(method.Document, "\n"), "comment": method.GetDoc(),
}) })
if err != nil { if err != nil {
return "", err return "", nil, err
} }
functions = append(functions, buffer.String()) functions = append(functions, buffer.String())
return strings.Join(functions, "\n"), nil return strings.Join(functions, util.NL), imports.KeysStr(), nil
} }

View File

@ -65,7 +65,7 @@ func (g *defaultRpcGenerator) genMain() error {
"serviceName": g.Ctx.ServiceName.Lower(), "serviceName": g.Ctx.ServiceName.Lower(),
"srv": srv, "srv": srv,
"registers": registers, "registers": registers,
"imports": strings.Join(imports, "\n"), "imports": strings.Join(imports, util.NL),
}, fileName, true) }, fileName, true)
} }
@ -77,5 +77,5 @@ func (g *defaultRpcGenerator) genServer(pkg string, list []*parser.RpcService) (
list1 = append(list1, fmt.Sprintf("%sSrv := server.New%sServer(ctx)", name, item.Name.Title())) list1 = append(list1, fmt.Sprintf("%sSrv := server.New%sServer(ctx)", name, item.Name.Title()))
list2 = append(list2, fmt.Sprintf("%s.Register%sServer(grpcServer, %sSrv)", pkg, item.Name.Title(), name)) list2 = append(list2, fmt.Sprintf("%s.Register%sServer(grpcServer, %sSrv)", pkg, item.Name.Title(), name))
} }
return strings.Join(list1, "\n"), strings.Join(list2, "\n") return strings.Join(list1, util.NL), strings.Join(list2, util.NL)
} }

View File

@ -1,68 +1,37 @@
package gen package gen
import ( import (
"errors" "bytes"
"fmt" "fmt"
"io/ioutil"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/dsymonds/gotoc/parser" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx" "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
astParser "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" )
const (
protocCmd = "protoc"
grpcPluginCmd = "--go_out=plugins=grpc"
) )
func (g *defaultRpcGenerator) genPb() error { func (g *defaultRpcGenerator) genPb() error {
importPath, filename := filepath.Split(g.Ctx.ProtoFileSrc)
tree, err := parser.ParseFiles([]string{filename}, []string{importPath})
if err != nil {
return err
}
if len(tree.Files) == 0 {
return errors.New("proto ast parse failed")
}
file := tree.Files[0]
if len(file.Package) == 0 {
return errors.New("expected package, but nothing found")
}
targetStruct := make(map[string]lang.PlaceholderType)
for _, item := range file.Messages {
if len(item.Messages) > 0 {
return fmt.Errorf(`line %v: unexpected inner message near: "%v""`, item.Messages[0].Position.Line, item.Messages[0].Name)
}
name := stringx.From(item.Name)
if _, ok := targetStruct[name.Lower()]; ok {
return fmt.Errorf("line %v: duplicate %v", item.Position.Line, name)
}
targetStruct[name.Lower()] = lang.Placeholder
}
pbPath := g.dirM[dirPb] pbPath := g.dirM[dirPb]
protoFileName := filepath.Base(g.Ctx.ProtoFileSrc) imports, containsAny, err := parser.ParseImport(g.Ctx.ProtoFileSrc)
err = g.protocGenGo(pbPath)
if err != nil { if err != nil {
return err return err
} }
pbGo := strings.TrimSuffix(protoFileName, ".proto") + ".pb.go" err = g.protocGenGo(pbPath, imports)
pbFile := filepath.Join(pbPath, pbGo)
bts, err := ioutil.ReadFile(pbFile)
if err != nil { if err != nil {
return err return err
} }
ast, err := parser.Transfer(g.Ctx.ProtoFileSrc, pbPath, imports, g.Ctx.Console)
aspParser := astParser.NewAstParser(bts, targetStruct, g.Ctx.Console)
ast, err := aspParser.Parse()
if err != nil { if err != nil {
return err return err
} }
ast.ContainsAny = containsAny
if len(ast.Service) == 0 { if len(ast.Service) == 0 {
return fmt.Errorf("service not found") return fmt.Errorf("service not found")
@ -71,10 +40,35 @@ func (g *defaultRpcGenerator) genPb() error {
return nil return nil
} }
func (g *defaultRpcGenerator) protocGenGo(target string) error { func (g *defaultRpcGenerator) protocGenGo(target string, imports []*parser.Import) error {
src := filepath.Dir(g.Ctx.ProtoFileSrc) dir := filepath.Dir(g.Ctx.ProtoFileSrc)
sh := fmt.Sprintf(`protoc -I=%s --go_out=plugins=grpc:%s %s`, src, target, g.Ctx.ProtoFileSrc) // cmd join,see the document of proto generating class @https://developers.google.com/protocol-buffers/docs/proto3#generating
stdout, err := execx.Run(sh, "") // template: protoc -I=${import_path} -I=${other_import_path} -I=${...} --go_out=plugins=grpc,M${pb_package_kv}, M${...} :${target_dir}
// eg: protoc -I=${GOPATH}/src -I=. example.proto --go_out=plugins=grpc,Mbase/base.proto=github.com/go-zero/base.proto:.
// note: the external import out of the project which are found in ${GOPATH}/src so far.
buffer := new(bytes.Buffer)
buffer.WriteString(protocCmd + " ")
targetImportFiltered := collection.NewSet()
for _, item := range imports {
buffer.WriteString(fmt.Sprintf("-I=%s ", item.OriginalDir))
if len(item.BridgeImport) == 0 {
continue
}
targetImportFiltered.AddStr(item.BridgeImport)
}
buffer.WriteString("-I=${GOPATH}/src ")
buffer.WriteString(fmt.Sprintf("-I=%s %s ", dir, g.Ctx.ProtoFileSrc))
buffer.WriteString(grpcPluginCmd)
if targetImportFiltered.Count() > 0 {
buffer.WriteString(fmt.Sprintf(",%v", strings.Join(targetImportFiltered.KeysStr(), ",")))
}
buffer.WriteString(":" + target)
g.Ctx.Debug("-> " + buffer.String())
stdout, err := execx.Run(buffer.String(), "")
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,6 +5,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
@ -32,7 +33,7 @@ func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
` `
functionTemplate = ` functionTemplate = `
{{if .hasComment}}{{.comment}}{{end}} {{if .hasComment}}{{.comment}}{{end}}
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) { func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
l := logic.New{{.logicName}}(ctx,s.svcCtx) l := logic.New{{.logicName}}(ctx,s.svcCtx)
return l.{{.method}}(in) return l.{{.method}}(in)
} }
@ -45,29 +46,26 @@ func (s *{{.server}}Server) {{.method}} (ctx context.Context, in *{{.package}}.{
func (g *defaultRpcGenerator) genHandler() error { func (g *defaultRpcGenerator) genHandler() error {
serverPath := g.dirM[dirServer] serverPath := g.dirM[dirServer]
file := g.ast file := g.ast
pkg := file.Package
pbImport := fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb))
logicImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic)) logicImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic))
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc)) svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
imports := []string{ imports := collection.NewSet()
pbImport, imports.AddStr(logicImport, svcImport)
logicImport,
svcImport,
}
head := util.GetHead(g.Ctx.ProtoSource) head := util.GetHead(g.Ctx.ProtoSource)
for _, service := range file.Service { for _, service := range file.Service {
filename := fmt.Sprintf("%vserver.go", service.Name.Lower()) filename := fmt.Sprintf("%vserver.go", service.Name.Lower())
serverFile := filepath.Join(serverPath, filename) serverFile := filepath.Join(serverPath, filename)
funcList, err := g.genFunctions(service) funcList, importList, err := g.genFunctions(service)
if err != nil { if err != nil {
return err return err
} }
imports.AddStr(importList...)
err = util.With("server").GoFmt(true).Parse(serverTemplate).SaveTo(map[string]interface{}{ err = util.With("server").GoFmt(true).Parse(serverTemplate).SaveTo(map[string]interface{}{
"head": head, "head": head,
"types": fmt.Sprintf(typeFmt, service.Name.Title()), "types": fmt.Sprintf(typeFmt, service.Name.Title()),
"server": service.Name.Title(), "server": service.Name.Title(),
"imports": strings.Join(imports, "\n\t"), "imports": strings.Join(imports.KeysStr(), util.NL),
"funcs": strings.Join(funcList, "\n"), "funcs": strings.Join(funcList, util.NL),
}, serverFile, true) }, serverFile, true)
if err != nil { if err != nil {
return err return err
@ -76,25 +74,31 @@ func (g *defaultRpcGenerator) genHandler() error {
return nil return nil
} }
func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string, error) { func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string, []string, error) {
file := g.ast file := g.ast
pkg := file.Package pkg := file.Package
var functionList []string var functionList []string
imports := collection.NewSet()
for _, method := range service.Funcs { for _, method := range service.Funcs {
if method.ParameterIn.Package == pkg || method.ParameterOut.Package == pkg {
imports.AddStr(fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb)))
}
imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
buffer, err := util.With("func").Parse(functionTemplate).Execute(map[string]interface{}{ buffer, err := util.With("func").Parse(functionTemplate).Execute(map[string]interface{}{
"server": service.Name.Title(), "server": service.Name.Title(),
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()), "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(), "method": method.Name.Title(),
"package": pkg, "package": pkg,
"request": method.InType, "request": method.ParameterIn.StarExpression,
"response": method.OutType, "response": method.ParameterOut.StarExpression,
"hasComment": len(method.Document), "hasComment": method.HaveDoc(),
"comment": strings.Join(method.Document, "\n"), "comment": method.GetDoc(),
}) })
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
functionList = append(functionList, buffer.String()) functionList = append(functionList, buffer.String())
} }
return functionList, nil return functionList, imports.KeysStr(), nil
} }

View File

@ -39,6 +39,8 @@ func NewRpcTemplate(out string, idea bool) *rpcTemplate {
} }
func (r *rpcTemplate) MustGenerate(showState bool) { func (r *rpcTemplate) MustGenerate(showState bool) {
r.Info("查看rpc生成请移步至「https://github.com/tal-tech/go-zero/blob/master/doc/goctl-rpc.md」")
r.Info("generating template...")
protoFilename := filepath.Base(r.out) protoFilename := filepath.Base(r.out)
serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename))) serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename)))
err := util.With("t").Parse(rpcTemplateText).SaveTo(map[string]string{ err := util.With("t").Parse(rpcTemplateText).SaveTo(map[string]string{

View File

@ -0,0 +1,35 @@
package base
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
)
func TestParseImport(t *testing.T) {
src, _ := filepath.Abs("./test.proto")
base, _ := filepath.Abs("./base.proto")
imports, containsAny, err := parser.ParseImport(src)
assert.Nil(t, err)
assert.Equal(t, true, containsAny)
assert.Equal(t, 1, len(imports))
assert.Equal(t, "github.com/tal-tech/go-zero/tools/goctl/rpc", imports[0].PbImportName)
assert.Equal(t, base, imports[0].OriginalProtoPath)
}
func TestTransfer(t *testing.T) {
src, _ := filepath.Abs("./test.proto")
abs, _ := filepath.Abs("./test")
imports, _, _ := parser.ParseImport(src)
proto, err := parser.Transfer(src, abs, imports, console.NewConsole(false))
assert.Nil(t, err)
assert.Equal(t, 1, len(proto.Service))
assert.Equal(t, "Greeter", proto.Service[0].Name.Source())
assert.Equal(t, 5, len(proto.Structure))
data, ok := proto.Structure["map"]
assert.Equal(t, true, ok)
assert.Equal(t, "M", data.Field[0].Name.Source())
}

View File

@ -0,0 +1,46 @@
package parser
import (
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
)
func Transfer(proto, target string, externalImport []*Import, console console.Console) (*PbAst, error) {
messageM := make(map[string]lang.PlaceholderType)
enumM := make(map[string]*Enum)
protoAst, err := parseProto(proto, messageM, enumM)
if err != nil {
return nil, err
}
for _, item := range externalImport {
err = checkImport(item.OriginalProtoPath)
if err != nil {
return nil, err
}
innerAst, err := parseProto(item.OriginalProtoPath, protoAst.Message, protoAst.Enum)
if err != nil {
return nil, err
}
for k, v := range innerAst.Message {
protoAst.Message[k] = v
}
for k, v := range innerAst.Enum {
protoAst.Enum[k] = v
}
}
protoAst.Import = externalImport
protoAst.PbSrc = filepath.Join(target, strings.TrimSuffix(filepath.Base(proto), ".proto")+".pb.go")
return transfer(protoAst, console)
}
func transfer(proto *Proto, console console.Console) (*PbAst, error) {
parser := MustNewAstParser(proto, console)
parse, err := parser.Parse()
if err != nil {
return nil, err
}
return parse, nil
}

View File

@ -6,6 +6,7 @@ import (
"go/ast" "go/ast"
"go/parser" "go/parser"
"go/token" "go/token"
"io/ioutil"
"sort" "sort"
"strings" "strings"
@ -18,8 +19,9 @@ import (
const ( const (
flagStar = "*" flagStar = "*"
flagDot = "."
suffixServer = "Server" suffixServer = "Server"
referenceContext = "context." referenceContext = "context"
unknownPrefix = "XXX_" unknownPrefix = "XXX_"
ignoreJsonTagExpression = `json:"-"` ignoreJsonTagExpression = `json:"-"`
) )
@ -34,19 +36,23 @@ var (
}` }`
fieldTemplate = `{{if .hasDoc}}{{.doc}} fieldTemplate = `{{if .hasDoc}}{{.doc}}
{{end}}{{.name}} {{.type}} {{.tag}}{{if .hasComment}}{{.comment}}{{end}}` {{end}}{{.name}} {{.type}} {{.tag}}{{if .hasComment}}{{.comment}}{{end}}`
anyTypeTemplate = "Any struct {\n\tTypeUrl string `json:\"typeUrl\"`\n\tValue []byte `json:\"value\"`\n}"
objectM = make(map[string]*Struct) objectM = make(map[string]*Struct)
) )
type ( type (
astParser struct { astParser struct {
golang []byte
filterStruct map[string]lang.PlaceholderType filterStruct map[string]lang.PlaceholderType
filterEnum map[string]*Enum
console.Console console.Console
fileSet *token.FileSet fileSet *token.FileSet
proto *Proto
} }
Field struct { Field struct {
Name stringx.String Name stringx.String
TypeName string Type Type
JsonTag string JsonTag string
Document []string Document []string
Comment []string Comment []string
@ -57,13 +63,33 @@ type (
Comment []string Comment []string
Field []*Field Field []*Field
} }
ConstLit struct {
Name stringx.String
Document []string
Comment []string
Lit []*Lit
}
Lit struct {
Key string
Value int
}
Type struct {
// eg:context.Context
Expression string
// eg: *context.Context
StarExpression string
// Invoke Type Expression
InvokeTypeExpression string
// eg:context
Package string
// eg:Context
Name string
}
Func struct { Func struct {
Name stringx.String Name stringx.String
InType string ParameterIn Type
InTypeName string // remove *Context,such as LoginRequest、UserRequest ParameterOut Type
OutTypeName string // remove *Context Document []string
OutType string
Document []string
} }
RpcService struct { RpcService struct {
Name stringx.String Name stringx.String
@ -71,54 +97,98 @@ type (
} }
// parsing for rpc // parsing for rpc
PbAst struct { PbAst struct {
Package string ContainsAny bool
// external reference Imports map[string]string
Imports map[string]string Structure map[string]*Struct
Strcuts map[string]*Struct Service []*RpcService
// rpc server's functions,not all functions *Proto
Service []*RpcService
} }
) )
func NewAstParser(golang []byte, filterStruct map[string]lang.PlaceholderType, log console.Console) *astParser { func MustNewAstParser(proto *Proto, log console.Console) *astParser {
return &astParser{ return &astParser{
golang: golang, filterStruct: proto.Message,
filterStruct: filterStruct, filterEnum: proto.Enum,
Console: log, Console: log,
fileSet: token.NewFileSet(), fileSet: token.NewFileSet(),
proto: proto,
} }
} }
func (a *astParser) Parse() (*PbAst, error) { func (a *astParser) Parse() (*PbAst, error) {
fSet := a.fileSet var pbAst PbAst
f, err := parser.ParseFile(fSet, "", a.golang, parser.ParseComments) pbAst.ContainsAny = a.proto.ContainsAny
pbAst.Proto = a.proto
pbAst.Structure = make(map[string]*Struct)
pbAst.Imports = make(map[string]string)
structure, imports, services, err := a.parse(a.proto.PbSrc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dependencyStructure, err := a.parseExternalDependency()
commentMap := ast.NewCommentMap(fSet, f, f.Comments) if err != nil {
f.Comments = commentMap.Filter(f).Comments() return nil, err
var pbAst PbAst
pbAst.Package = a.mustGetIndentName(f.Name)
imports := make(map[string]string)
for _, item := range f.Imports {
if item == nil {
continue
}
if item.Path == nil {
continue
}
key := a.mustGetIndentName(item.Name)
value := item.Path.Value
imports[key] = value
} }
structs, funcs := a.mustScope(f.Scope) for k, v := range structure {
pbAst.Imports = imports pbAst.Structure[k] = v
pbAst.Strcuts = structs }
pbAst.Service = funcs for k, v := range dependencyStructure {
pbAst.Structure[k] = v
}
for key, path := range imports {
pbAst.Imports[key] = path
}
pbAst.Service = append(pbAst.Service, services...)
return &pbAst, nil return &pbAst, nil
} }
func (a *astParser) mustScope(scope *ast.Scope) (map[string]*Struct, []*RpcService) { func (a *astParser) parse(pbSrc string) (structure map[string]*Struct, imports map[string]string, services []*RpcService, retErr error) {
structure = make(map[string]*Struct)
imports = make(map[string]string)
data, err := ioutil.ReadFile(pbSrc)
if err != nil {
retErr = err
return
}
fSet := a.fileSet
f, err := parser.ParseFile(fSet, "", data, parser.ParseComments)
if err != nil {
retErr = err
return
}
commentMap := ast.NewCommentMap(fSet, f, f.Comments)
f.Comments = commentMap.Filter(f).Comments()
strucs, function := a.mustScope(f.Scope, a.mustGetIndentName(f.Name))
for k, v := range strucs {
if v == nil {
continue
}
structure[k] = v
}
importList := f.Imports
for _, item := range importList {
name := a.mustGetIndentName(item.Name)
if item.Path != nil {
imports[name] = item.Path.Value
}
}
services = append(services, function...)
return
}
func (a *astParser) parseExternalDependency() (map[string]*Struct, error) {
m := make(map[string]*Struct)
for _, impo := range a.proto.Import {
ret, _, _, err := a.parse(impo.OriginalPbPath)
if err != nil {
return nil, err
}
for k, v := range ret {
m[k] = v
}
}
return m, nil
}
func (a *astParser) mustScope(scope *ast.Scope, sourcePackage string) (map[string]*Struct, []*RpcService) {
if scope == nil { if scope == nil {
return nil, nil return nil, nil
} }
@ -140,7 +210,7 @@ func (a *astParser) mustScope(scope *ast.Scope) (map[string]*Struct, []*RpcServi
switch v := tp.(type) { switch v := tp.(type) {
case *ast.StructType: case *ast.StructType:
st, err := a.parseObject(name, v) st, err := a.parseObject(name, v, sourcePackage)
a.Must(err) a.Must(err)
structs[st.Name.Lower()] = st structs[st.Name.Lower()] = st
@ -148,7 +218,7 @@ func (a *astParser) mustScope(scope *ast.Scope) (map[string]*Struct, []*RpcServi
if !strings.HasSuffix(name, suffixServer) { if !strings.HasSuffix(name, suffixServer) {
continue continue
} }
list := a.mustServerFunctions(v) list := a.mustServerFunctions(v, sourcePackage)
serviceList = append(serviceList, &RpcService{ serviceList = append(serviceList, &RpcService{
Name: stringx.From(strings.TrimSuffix(name, suffixServer)), Name: stringx.From(strings.TrimSuffix(name, suffixServer)),
Funcs: list, Funcs: list,
@ -163,7 +233,7 @@ func (a *astParser) mustScope(scope *ast.Scope) (map[string]*Struct, []*RpcServi
return targetStruct, serviceList return targetStruct, serviceList
} }
func (a *astParser) mustServerFunctions(v *ast.InterfaceType) []*Func { func (a *astParser) mustServerFunctions(v *ast.InterfaceType, sourcePackage string) []*Func {
funcs := make([]*Func, 0) funcs := make([]*Func, 0)
methodObject := v.Methods methodObject := v.Methods
if methodObject == nil { if methodObject == nil {
@ -187,31 +257,27 @@ func (a *astParser) mustServerFunctions(v *ast.InterfaceType) []*Func {
} }
params := v.Params params := v.Params
if params != nil { if params != nil {
inList, err := a.parseFields(params.List, true) inList, err := a.parseFields(params.List, true, sourcePackage)
a.Must(err) a.Must(err)
for _, data := range inList { for _, data := range inList {
if strings.HasPrefix(data.TypeName, referenceContext) { if data.Type.Package == referenceContext {
continue continue
} }
// currently,does not support external references item.ParameterIn = data.Type
item.InTypeName = data.TypeName
item.InType = strings.TrimPrefix(data.TypeName, flagStar)
break break
} }
} }
results := v.Results results := v.Results
if results != nil { if results != nil {
outList, err := a.parseFields(results.List, true) outList, err := a.parseFields(results.List, true, sourcePackage)
a.Must(err) a.Must(err)
for _, data := range outList { for _, data := range outList {
if strings.HasPrefix(data.TypeName, referenceContext) { if data.Type.Package == referenceContext {
continue continue
} }
// currently,does not support external references item.ParameterOut = data.Type
item.OutTypeName = data.TypeName
item.OutType = strings.TrimPrefix(data.TypeName, flagStar)
break break
} }
} }
@ -220,7 +286,67 @@ func (a *astParser) mustServerFunctions(v *ast.InterfaceType) []*Func {
return funcs return funcs
} }
func (a *astParser) parseObject(structName string, tp *ast.StructType) (*Struct, error) { func (a *astParser) getFieldType(v string, sourcePackage string) Type {
var pkg, name, expression, starExpression, invokeTypeExpression string
if strings.Contains(v, ".") {
starExpression = v
if strings.Contains(v, "*") {
leftIndex := strings.Index(v, "*")
rightIndex := strings.Index(v, ".")
if leftIndex >= 0 {
invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
} else {
invokeTypeExpression = v[rightIndex+1:]
}
} else {
if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
leftIndex := strings.Index(v, "]")
rightIndex := strings.Index(v, ".")
invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
} else {
rightIndex := strings.Index(v, ".")
invokeTypeExpression = v[rightIndex+1:]
}
}
} else {
expression = strings.TrimPrefix(v, flagStar)
switch v {
case "double", "float", "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32", "sfixed64",
"bool", "string", "bytes":
invokeTypeExpression = v
break
default:
name = expression
invokeTypeExpression = v
if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
starExpression = strings.ReplaceAll(v, flagStar, flagStar+sourcePackage+".")
} else {
starExpression = fmt.Sprintf("*%v.%v", sourcePackage, name)
invokeTypeExpression = v
}
}
}
expression = strings.TrimPrefix(starExpression, flagStar)
index := strings.LastIndex(expression, flagDot)
if index > 0 {
pkg = expression[0:index]
name = expression[index+1:]
} else {
pkg = sourcePackage
}
return Type{
Expression: expression,
StarExpression: starExpression,
InvokeTypeExpression: invokeTypeExpression,
Package: pkg,
Name: name,
}
}
func (a *astParser) parseObject(structName string, tp *ast.StructType, sourcePackage string) (*Struct, error) {
if data, ok := objectM[structName]; ok { if data, ok := objectM[structName]; ok {
return data, nil return data, nil
} }
@ -237,7 +363,7 @@ func (a *astParser) parseObject(structName string, tp *ast.StructType) (*Struct,
} }
fieldList := fields.List fieldList := fields.List
members, err := a.parseFields(fieldList, false) members, err := a.parseFields(fieldList, false, sourcePackage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -245,7 +371,7 @@ func (a *astParser) parseObject(structName string, tp *ast.StructType) (*Struct,
for _, m := range members { for _, m := range members {
var field Field var field Field
field.Name = m.Name field.Name = m.Name
field.TypeName = m.TypeName field.Type = m.Type
field.JsonTag = m.JsonTag field.JsonTag = m.JsonTag
field.Document = m.Document field.Document = m.Document
field.Comment = m.Comment field.Comment = m.Comment
@ -255,7 +381,7 @@ func (a *astParser) parseObject(structName string, tp *ast.StructType) (*Struct,
return &st, nil return &st, nil
} }
func (a *astParser) parseFields(fields []*ast.Field, onlyType bool) ([]*Field, error) { func (a *astParser) parseFields(fields []*ast.Field, onlyType bool, sourcePackage string) ([]*Field, error) {
ret := make([]*Field, 0) ret := make([]*Field, 0)
for _, field := range fields { for _, field := range fields {
var item Field var item Field
@ -278,7 +404,7 @@ func (a *astParser) parseFields(fields []*ast.Field, onlyType bool) ([]*Field, e
return nil, err return nil, err
} }
item.TypeName = typeName item.Type = a.getFieldType(typeName, sourcePackage)
if onlyType { if onlyType {
ret = append(ret, &item) ret = append(ret, &item)
continue continue
@ -414,10 +540,30 @@ func (a *astParser) wrapError(pos token.Pos, format string, arg ...interface{})
return fmt.Errorf("line %v: %s", file.Line, fmt.Sprintf(format, arg...)) return fmt.Errorf("line %v: %s", file.Line, fmt.Sprintf(format, arg...))
} }
func (f *Func) GetDoc() string {
return strings.Join(f.Document, util.NL)
}
func (f *Func) HaveDoc() bool {
return len(f.Document) > 0
}
func (a *PbAst) GenEnumCode() (string, error) {
var element []string
for _, item := range a.Enum {
code, err := item.GenEnumCode()
if err != nil {
return "", err
}
element = append(element, code)
}
return strings.Join(element, util.NL), nil
}
func (a *PbAst) GenTypesCode() (string, error) { func (a *PbAst) GenTypesCode() (string, error) {
types := make([]string, 0) types := make([]string, 0)
sts := make([]*Struct, 0) sts := make([]*Struct, 0)
for _, item := range a.Strcuts { for _, item := range a.Structure {
sts = append(sts, item) sts = append(sts, item)
} }
sort.Slice(sts, func(i, j int) bool { sort.Slice(sts, func(i, j int) bool {
@ -434,8 +580,17 @@ func (a *PbAst) GenTypesCode() (string, error) {
} }
types = append(types, structCode) types = append(types, structCode)
} }
types = append(types, a.genAnyCode())
for _, item := range a.Enum {
typeCode, err := item.GenEnumTypeCode()
if err != nil {
return "", err
}
types = append(types, typeCode)
}
buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{ buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{
"types": strings.Join(types, "\n\n"), "types": strings.Join(types, util.NL+util.NL),
}) })
if err != nil { if err != nil {
return "", err return "", err
@ -444,6 +599,13 @@ func (a *PbAst) GenTypesCode() (string, error) {
return buffer.String(), nil return buffer.String(), nil
} }
func (a *PbAst) genAnyCode() string {
if !a.ContainsAny {
return ""
}
return anyTypeTemplate
}
func (s *Struct) genCode(containsTypeStatement bool) (string, error) { func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
fields := make([]string, 0) fields := make([]string, 0)
for _, f := range s.Field { for _, f := range s.Field {
@ -451,10 +613,10 @@ func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
if len(f.Comment) > 0 { if len(f.Comment) > 0 {
comment = f.Comment[0] comment = f.Comment[0]
} }
doc = strings.Join(f.Document, "\n") doc = strings.Join(f.Document, util.NL)
buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{ buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{
"name": f.Name.Title(), "name": f.Name.Title(),
"type": f.TypeName, "type": f.Type.InvokeTypeExpression,
"tag": f.JsonTag, "tag": f.JsonTag,
"hasDoc": len(f.Document) > 0, "hasDoc": len(f.Document) > 0,
"doc": doc, "doc": doc,
@ -470,7 +632,7 @@ func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{ buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{
"type": containsTypeStatement, "type": containsTypeStatement,
"name": s.Name.Title(), "name": s.Name.Title(),
"fields": strings.Join(fields, "\n"), "fields": strings.Join(fields, util.NL),
}) })
if err != nil { if err != nil {
return "", err return "", err

View File

@ -0,0 +1,294 @@
package parser
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/emicklei/proto"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
AnyImport = "google/protobuf/any.proto"
)
var (
enumTypeTemplate = `{{.name}} int32`
enumTemplate = `const (
{{.element}}
)`
enumFiledTemplate = `{{.key}} {{.name}} = {{.value}}`
)
type (
MessageField struct {
Type string
Name stringx.String
}
Message struct {
Name stringx.String
Element []*MessageField
*proto.Message
}
Enum struct {
Name stringx.String
Element []*EnumField
*proto.Enum
}
EnumField struct {
Key string
Value int
}
Proto struct {
Package string
Import []*Import
PbSrc string
ContainsAny bool
Message map[string]lang.PlaceholderType
Enum map[string]*Enum
}
Import struct {
ProtoImportName string
PbImportName string
OriginalDir string
OriginalProtoPath string
OriginalPbPath string
BridgeImport string
exists bool
//xx.proto
protoName string
// xx.pb.go
pbName string
}
)
func checkImport(src string) error {
r, err := os.Open(src)
if err != nil {
return err
}
defer r.Close()
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return err
}
var base = filepath.Base(src)
proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) {
if err != nil {
return
}
err = fmt.Errorf("%v:%v the external proto cannot import other proto files", base, i.Position.Line)
}))
if err != nil {
return err
}
return nil
}
func ParseImport(src string) ([]*Import, bool, error) {
bridgeImportM := make(map[string]string)
r, err := os.Open(src)
if err != nil {
return nil, false, err
}
defer r.Close()
workDir := filepath.Dir(src)
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return nil, false, err
}
protoImportSet := collection.NewSet()
var containsAny bool
proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) {
if i.Filename == AnyImport {
containsAny = true
return
}
protoImportSet.AddStr(i.Filename)
if i.Comment != nil {
lines := i.Comment.Lines
for _, line := range lines {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "@") {
continue
}
line = strings.TrimPrefix(line, "@")
bridgeImportM[i.Filename] = line
}
}
}))
var importList []*Import
for _, item := range protoImportSet.KeysStr() {
pb := strings.TrimSuffix(filepath.Base(item), filepath.Ext(item)) + ".pb.go"
var pbImportName, brideImport string
if v, ok := bridgeImportM[item]; ok {
pbImportName = v
brideImport = "M" + item + "=" + v
} else {
pbImportName = item
}
var impo = Import{
ProtoImportName: item,
PbImportName: pbImportName,
BridgeImport: brideImport,
}
protoSource := filepath.Join(workDir, item)
pbSource := filepath.Join(filepath.Dir(protoSource), pb)
if util.FileExists(protoSource) && util.FileExists(pbSource) {
impo.OriginalProtoPath = protoSource
impo.OriginalPbPath = pbSource
impo.OriginalDir = filepath.Dir(protoSource)
impo.exists = true
impo.protoName = filepath.Base(item)
impo.pbName = pb
} else {
return nil, false, fmt.Errorf("「%v」: import must be found in the relative directory of 「%v」", item, filepath.Base(src))
}
importList = append(importList, &impo)
}
return importList, containsAny, nil
}
func parseProto(src string, messageM map[string]lang.PlaceholderType, enumM map[string]*Enum) (*Proto, error) {
if !filepath.IsAbs(src) {
return nil, fmt.Errorf("expected absolute path,but found: %v", src)
}
r, err := os.Open(src)
if err != nil {
return nil, err
}
defer r.Close()
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return nil, err
}
// xx.proto
fileBase := filepath.Base(src)
var resp Proto
proto.Walk(parseRet, proto.WithPackage(func(p *proto.Package) {
if err != nil {
return
}
if len(resp.Package) != 0 {
err = fmt.Errorf("%v:%v duplicate package「%v」", fileBase, p.Position.Line, p.Name)
}
if len(p.Name) == 0 {
err = errors.New("package not found")
}
resp.Package = p.Name
}), proto.WithMessage(func(message *proto.Message) {
if err != nil {
return
}
for _, item := range message.Elements {
switch item.(type) {
case *proto.NormalField, *proto.MapField, *proto.Comment:
continue
default:
err = fmt.Errorf("%v: unsupport inline declaration", fileBase)
return
}
}
name := stringx.From(message.Name)
if _, ok := messageM[name.Lower()]; ok {
err = fmt.Errorf("%v:%v duplicate message 「%v」", fileBase, message.Position.Line, message.Name)
return
}
messageM[name.Lower()] = lang.Placeholder
}), proto.WithEnum(func(enum *proto.Enum) {
if err != nil {
return
}
var node Enum
node.Enum = enum
node.Name = stringx.From(enum.Name)
for _, item := range enum.Elements {
v, ok := item.(*proto.EnumField)
if !ok {
continue
}
node.Element = append(node.Element, &EnumField{
Key: v.Name,
Value: v.Integer,
})
}
if _, ok := enumM[node.Name.Lower()]; ok {
err = fmt.Errorf("%v:%v duplicate enum 「%v」", fileBase, node.Position.Line, node.Name.Source())
return
}
lower := stringx.From(enum.Name).Lower()
enumM[lower] = &node
}))
if err != nil {
return nil, err
}
resp.Message = messageM
resp.Enum = enumM
return &resp, nil
}
func (e *Enum) GenEnumCode() (string, error) {
var element []string
for _, item := range e.Element {
code, err := item.GenEnumFieldCode(e.Name.Source())
if err != nil {
return "", err
}
element = append(element, code)
}
buffer, err := util.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{
"element": strings.Join(element, util.NL),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}
func (e *Enum) GenEnumTypeCode() (string, error) {
buffer, err := util.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{
"name": e.Name.Source(),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}
func (e *EnumField) GenEnumFieldCode(parentName string) (string, error) {
buffer, err := util.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{
"key": e.Key,
"name": parentName,
"value": e.Value,
})
if err != nil {
return "", err
}
return buffer.String(), nil
}

View File

@ -0,0 +1,28 @@
syntax = "proto3";
// protoc -I=${GOPATH}/src -I=. test.proto --go_out=plugins=grpc,Mbase.proto=github.com/tal-tech/go-zero/tools/goctl/rpc:./test
package test;
// @github.com/tal-tech/go-zero/tools/goctl/rpc
import "base.proto";
import "google/protobuf/any.proto";
message request {
string name = 1;
}
enum Gender{
UNKNOWN = 0;
MALE = 1;
FEMALE = 2;
}
message response {
string greet = 1;
google.protobuf.Any data = 2;
}
message map {
map<string, string> m = 1;
}
service Greeter {
rpc greet(request) returns (response);
rpc idRequest(base.IdRequest)returns(base.EmptyResponse);
}

View File

@ -0,0 +1,331 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: test.proto
// protoc -I=${GOPATH}/src -I=. test.proto --go_out=plugins=grpc,Mbase.proto=github.com/tal-tech/go-zero/tools/goctl/rpc:./test
package test
import (
context "context"
fmt "fmt"
proto "github.com/golang/protobuf/proto"
rpc "github.com/tal-tech/go-zero/tools/goctl/rpc"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
anypb "google.golang.org/protobuf/types/known/anypb"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type Gender int32
const (
Gender_UNKNOWN Gender = 0
Gender_MALE Gender = 1
Gender_FEMALE Gender = 2
)
var Gender_name = map[int32]string{
0: "UNKNOWN",
1: "MALE",
2: "FEMALE",
}
var Gender_value = map[string]int32{
"UNKNOWN": 0,
"MALE": 1,
"FEMALE": 2,
}
func (x Gender) String() string {
return proto.EnumName(Gender_name, int32(x))
}
func (Gender) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_c161fcfdc0c3ff1e, []int{0}
}
type Request struct {
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Request) Reset() { *m = Request{} }
func (m *Request) String() string { return proto.CompactTextString(m) }
func (*Request) ProtoMessage() {}
func (*Request) Descriptor() ([]byte, []int) {
return fileDescriptor_c161fcfdc0c3ff1e, []int{0}
}
func (m *Request) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Request.Unmarshal(m, b)
}
func (m *Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Request.Marshal(b, m, deterministic)
}
func (m *Request) XXX_Merge(src proto.Message) {
xxx_messageInfo_Request.Merge(m, src)
}
func (m *Request) XXX_Size() int {
return xxx_messageInfo_Request.Size(m)
}
func (m *Request) XXX_DiscardUnknown() {
xxx_messageInfo_Request.DiscardUnknown(m)
}
var xxx_messageInfo_Request proto.InternalMessageInfo
func (m *Request) GetName() string {
if m != nil {
return m.Name
}
return ""
}
type Response struct {
Greet string `protobuf:"bytes,1,opt,name=greet,proto3" json:"greet,omitempty"`
Data *anypb.Any `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Response) Reset() { *m = Response{} }
func (m *Response) String() string { return proto.CompactTextString(m) }
func (*Response) ProtoMessage() {}
func (*Response) Descriptor() ([]byte, []int) {
return fileDescriptor_c161fcfdc0c3ff1e, []int{1}
}
func (m *Response) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Response.Unmarshal(m, b)
}
func (m *Response) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Response.Marshal(b, m, deterministic)
}
func (m *Response) XXX_Merge(src proto.Message) {
xxx_messageInfo_Response.Merge(m, src)
}
func (m *Response) XXX_Size() int {
return xxx_messageInfo_Response.Size(m)
}
func (m *Response) XXX_DiscardUnknown() {
xxx_messageInfo_Response.DiscardUnknown(m)
}
var xxx_messageInfo_Response proto.InternalMessageInfo
func (m *Response) GetGreet() string {
if m != nil {
return m.Greet
}
return ""
}
func (m *Response) GetData() *anypb.Any {
if m != nil {
return m.Data
}
return nil
}
type Map struct {
M map[string]string `protobuf:"bytes,1,rep,name=m,proto3" json:"m,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Map) Reset() { *m = Map{} }
func (m *Map) String() string { return proto.CompactTextString(m) }
func (*Map) ProtoMessage() {}
func (*Map) Descriptor() ([]byte, []int) {
return fileDescriptor_c161fcfdc0c3ff1e, []int{2}
}
func (m *Map) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Map.Unmarshal(m, b)
}
func (m *Map) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Map.Marshal(b, m, deterministic)
}
func (m *Map) XXX_Merge(src proto.Message) {
xxx_messageInfo_Map.Merge(m, src)
}
func (m *Map) XXX_Size() int {
return xxx_messageInfo_Map.Size(m)
}
func (m *Map) XXX_DiscardUnknown() {
xxx_messageInfo_Map.DiscardUnknown(m)
}
var xxx_messageInfo_Map proto.InternalMessageInfo
func (m *Map) GetM() map[string]string {
if m != nil {
return m.M
}
return nil
}
func init() {
proto.RegisterEnum("test.Gender", Gender_name, Gender_value)
proto.RegisterType((*Request)(nil), "test.request")
proto.RegisterType((*Response)(nil), "test.response")
proto.RegisterType((*Map)(nil), "test.map")
proto.RegisterMapType((map[string]string)(nil), "test.map.MEntry")
}
func init() { proto.RegisterFile("test.proto", fileDescriptor_c161fcfdc0c3ff1e) }
var fileDescriptor_c161fcfdc0c3ff1e = []byte{
// 301 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x34, 0x90, 0x4f, 0x4b, 0xc3, 0x40,
0x10, 0xc5, 0xdd, 0x36, 0xa6, 0xed, 0x14, 0x35, 0x8c, 0x3d, 0xd4, 0x80, 0x52, 0x7a, 0x90, 0xa0,
0xb0, 0xc5, 0xea, 0x41, 0xbc, 0xf5, 0x10, 0x8b, 0x7f, 0x5a, 0x21, 0x20, 0x1e, 0x3c, 0x6d, 0xc9,
0x58, 0xc4, 0xee, 0x26, 0x6e, 0xb6, 0x42, 0xbe, 0xbd, 0x64, 0x77, 0x73, 0x7b, 0xbf, 0xd9, 0x59,
0xde, 0x9b, 0x07, 0x60, 0xa8, 0x32, 0xbc, 0xd4, 0x85, 0x29, 0x30, 0x68, 0x74, 0x0c, 0x1b, 0x51,
0x91, 0x9b, 0xc4, 0x67, 0xdb, 0xa2, 0xd8, 0xee, 0x68, 0x66, 0x69, 0xb3, 0xff, 0x9a, 0x09, 0x55,
0xbb, 0xa7, 0xe9, 0x39, 0xf4, 0x34, 0xfd, 0xee, 0xa9, 0x32, 0x88, 0x10, 0x28, 0x21, 0x69, 0xcc,
0x26, 0x2c, 0x19, 0x64, 0x56, 0x4f, 0x9f, 0xa1, 0xaf, 0xa9, 0x2a, 0x0b, 0x55, 0x11, 0x8e, 0xe0,
0x70, 0xab, 0x89, 0x8c, 0x5f, 0x70, 0x80, 0x09, 0x04, 0xb9, 0x30, 0x62, 0xdc, 0x99, 0xb0, 0x64,
0x38, 0x1f, 0x71, 0x67, 0xc5, 0x5b, 0x2b, 0xbe, 0x50, 0x75, 0x66, 0x37, 0xa6, 0x9f, 0xd0, 0x95,
0xa2, 0xc4, 0x0b, 0x60, 0x72, 0xcc, 0x26, 0xdd, 0x64, 0x38, 0x8f, 0xb8, 0x8d, 0x2d, 0x45, 0xc9,
0x57, 0xa9, 0x32, 0xba, 0xce, 0x98, 0x8c, 0xef, 0x20, 0x74, 0x80, 0x11, 0x74, 0x7f, 0xa8, 0xf6,
0x76, 0x8d, 0x6c, 0x22, 0xfc, 0x89, 0xdd, 0x9e, 0xac, 0xdb, 0x20, 0x73, 0xf0, 0xd0, 0xb9, 0x67,
0x57, 0xd7, 0x10, 0x2e, 0x49, 0xe5, 0xa4, 0x71, 0x08, 0xbd, 0xf7, 0xf5, 0xcb, 0xfa, 0xed, 0x63,
0x1d, 0x1d, 0x60, 0x1f, 0x82, 0xd5, 0xe2, 0x35, 0x8d, 0x18, 0x02, 0x84, 0x8f, 0xa9, 0xd5, 0x9d,
0x79, 0x0e, 0xbd, 0x65, 0x13, 0x9e, 0x34, 0x5e, 0xfa, 0xa3, 0xf0, 0xc8, 0x65, 0xf1, 0x65, 0xc4,
0xc7, 0x2d, 0xfa, 0xe3, 0x6f, 0x60, 0xf0, 0x9d, 0x67, 0xbe, 0xa9, 0x13, 0x6e, 0xcb, 0x7d, 0x6a,
0x07, 0xf1, 0xa9, 0x1b, 0xa4, 0xb2, 0x34, 0x75, 0xe6, 0xbf, 0x6c, 0x42, 0xdb, 0xc1, 0xed, 0x7f,
0x00, 0x00, 0x00, 0xff, 0xff, 0x48, 0x7a, 0x3b, 0x55, 0x9c, 0x01, 0x00, 0x00,
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// GreeterClient is the client API for Greeter service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type GreeterClient interface {
Greet(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error)
IdRequest(ctx context.Context, in *rpc.IdRequest, opts ...grpc.CallOption) (*rpc.EmptyResponse, error)
}
type greeterClient struct {
cc *grpc.ClientConn
}
func NewGreeterClient(cc *grpc.ClientConn) GreeterClient {
return &greeterClient{cc}
}
func (c *greeterClient) Greet(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) {
out := new(Response)
err := c.cc.Invoke(ctx, "/test.Greeter/greet", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *greeterClient) IdRequest(ctx context.Context, in *rpc.IdRequest, opts ...grpc.CallOption) (*rpc.EmptyResponse, error) {
out := new(rpc.EmptyResponse)
err := c.cc.Invoke(ctx, "/test.Greeter/idRequest", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// GreeterServer is the server API for Greeter service.
type GreeterServer interface {
Greet(context.Context, *Request) (*Response, error)
IdRequest(context.Context, *rpc.IdRequest) (*rpc.EmptyResponse, error)
}
// UnimplementedGreeterServer can be embedded to have forward compatible implementations.
type UnimplementedGreeterServer struct {
}
func (*UnimplementedGreeterServer) Greet(ctx context.Context, req *Request) (*Response, error) {
return nil, status.Errorf(codes.Unimplemented, "method Greet not implemented")
}
func (*UnimplementedGreeterServer) IdRequest(ctx context.Context, req *rpc.IdRequest) (*rpc.EmptyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method IdRequest not implemented")
}
func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) {
s.RegisterService(&_Greeter_serviceDesc, srv)
}
func _Greeter_Greet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Request)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(GreeterServer).Greet(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/test.Greeter/Greet",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(GreeterServer).Greet(ctx, req.(*Request))
}
return interceptor(ctx, in, info, handler)
}
func _Greeter_IdRequest_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(rpc.IdRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(GreeterServer).IdRequest(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/test.Greeter/IdRequest",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(GreeterServer).IdRequest(ctx, req.(*rpc.IdRequest))
}
return interceptor(ctx, in, info, handler)
}
var _Greeter_serviceDesc = grpc.ServiceDesc{
ServiceName: "test.Greeter",
HandlerType: (*GreeterServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "greet",
Handler: _Greeter_Greet_Handler,
},
{
MethodName: "idRequest",
Handler: _Greeter_IdRequest_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "test.proto",
}

View File

@ -11,9 +11,11 @@ type (
Console interface { Console interface {
Success(format string, a ...interface{}) Success(format string, a ...interface{})
Info(format string, a ...interface{}) Info(format string, a ...interface{})
Debug(format string, a ...interface{})
Warning(format string, a ...interface{}) Warning(format string, a ...interface{})
Error(format string, a ...interface{}) Error(format string, a ...interface{})
Fatalln(format string, a ...interface{}) Fatalln(format string, a ...interface{})
MarkDone()
Must(err error) Must(err error)
} }
colorConsole struct { colorConsole struct {
@ -39,6 +41,11 @@ func (c *colorConsole) Info(format string, a ...interface{}) {
fmt.Println(msg) fmt.Println(msg)
} }
func (c *colorConsole) Debug(format string, a ...interface{}) {
msg := fmt.Sprintf(format, a...)
fmt.Println(aurora.Blue(msg))
}
func (c *colorConsole) Success(format string, a ...interface{}) { func (c *colorConsole) Success(format string, a ...interface{}) {
msg := fmt.Sprintf(format, a...) msg := fmt.Sprintf(format, a...)
fmt.Println(aurora.Green(msg)) fmt.Println(aurora.Green(msg))
@ -59,6 +66,10 @@ func (c *colorConsole) Fatalln(format string, a ...interface{}) {
os.Exit(1) os.Exit(1)
} }
func (c *colorConsole) MarkDone() {
c.Success("Done.")
}
func (c *colorConsole) Must(err error) { func (c *colorConsole) Must(err error) {
if err != nil { if err != nil {
c.Fatalln("%+v", err) c.Fatalln("%+v", err)
@ -74,6 +85,11 @@ func (i *ideaConsole) Info(format string, a ...interface{}) {
fmt.Println(msg) fmt.Println(msg)
} }
func (i *ideaConsole) Debug(format string, a ...interface{}) {
msg := fmt.Sprintf(format, a...)
fmt.Println(aurora.Blue(msg))
}
func (i *ideaConsole) Success(format string, a ...interface{}) { func (i *ideaConsole) Success(format string, a ...interface{}) {
msg := fmt.Sprintf(format, a...) msg := fmt.Sprintf(format, a...)
fmt.Println("[SUCCESS]: ", msg) fmt.Println("[SUCCESS]: ", msg)
@ -94,6 +110,10 @@ func (i *ideaConsole) Fatalln(format string, a ...interface{}) {
os.Exit(1) os.Exit(1)
} }
func (i *ideaConsole) MarkDone() {
i.Success("Done.")
}
func (i *ideaConsole) Must(err error) { func (i *ideaConsole) Must(err error) {
if err != nil { if err != nil {
i.Fatalln("%+v", err) i.Fatalln("%+v", err)

View File

@ -10,6 +10,10 @@ import (
"github.com/logrusorgru/aurora" "github.com/logrusorgru/aurora"
) )
const (
NL = "\n"
)
func CreateIfNotExist(file string) (*os.File, error) { func CreateIfNotExist(file string) (*os.File, error) {
_, err := os.Stat(file) _, err := os.Stat(file)
if !os.IsNotExist(err) { if !os.IsNotExist(err) {

View File

@ -1,6 +1,7 @@
package project package project
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
@ -38,18 +39,18 @@ type (
func Prepare(projectDir string, checkGrpcEnv bool) (*Project, error) { func Prepare(projectDir string, checkGrpcEnv bool) (*Project, error) {
_, err := exec.LookPath(constGo) _, err := exec.LookPath(constGo)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("please install go first,reference documents:「https://golang.org/doc/install」")
} }
if checkGrpcEnv { if checkGrpcEnv {
_, err = exec.LookPath(constProtoC) _, err = exec.LookPath(constProtoC)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("please install protoc first,reference documents:「https://github.com/golang/protobuf」")
} }
_, err = exec.LookPath(constProtoCGenGo) _, err = exec.LookPath(constProtoCGenGo)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("please install plugin protoc-gen-go first,reference documents:「https://github.com/golang/protobuf」")
} }
} }

View File

@ -3,4 +3,7 @@ package vars
const ( const (
ProjectName = "zero" ProjectName = "zero"
ProjectOpenSourceUrl = "github.com/tal-tech/go-zero" ProjectOpenSourceUrl = "github.com/tal-tech/go-zero"
OsWindows = "windows"
OsMac = "darwin"
OsLinux = "linux"
) )