mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-02-02 16:28:39 +08:00
Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
This commit is contained in:
parent
598fda0c97
commit
db87fd3239
@ -29,7 +29,6 @@ func TestRpcGenerate(t *testing.T) {
|
||||
projectName := stringx.Rand()
|
||||
g := NewRPCGenerator(dispatcher, cfg)
|
||||
|
||||
// case go path
|
||||
src := filepath.Join(build.Default.GOPATH, "src")
|
||||
_, err = os.Stat(src)
|
||||
if err != nil {
|
||||
@ -41,45 +40,52 @@ func TestRpcGenerate(t *testing.T) {
|
||||
defer func() {
|
||||
_ = os.RemoveAll(srcDir)
|
||||
}()
|
||||
|
||||
common, err := filepath.Abs(".")
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
assert.True(t, func() bool {
|
||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||
}())
|
||||
}
|
||||
// case go path
|
||||
t.Run("GOPATH", func(t *testing.T) {
|
||||
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
assert.True(t, func() bool {
|
||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||
}())
|
||||
}
|
||||
})
|
||||
|
||||
// case go mod
|
||||
workDir := t.TempDir()
|
||||
name := filepath.Base(workDir)
|
||||
_, err = execx.Run("go mod init "+name, workDir)
|
||||
if err != nil {
|
||||
logx.Error(err)
|
||||
return
|
||||
}
|
||||
t.Run("GOMOD", func(t *testing.T) {
|
||||
workDir := t.TempDir()
|
||||
name := filepath.Base(workDir)
|
||||
_, err = execx.Run("go mod init "+name, workDir)
|
||||
if err != nil {
|
||||
logx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
projectDir = filepath.Join(workDir, projectName)
|
||||
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
assert.True(t, func() bool {
|
||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||
}())
|
||||
}
|
||||
projectDir = filepath.Join(workDir, projectName)
|
||||
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
assert.True(t, func() bool {
|
||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||
}())
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
// case not in go mod and go path
|
||||
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
assert.True(t, func() bool {
|
||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||
}())
|
||||
}
|
||||
t.Run("OTHER", func(t *testing.T) {
|
||||
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
assert.True(t, func() bool {
|
||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||
}())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -49,13 +49,13 @@ func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
|
||||
`
|
||||
|
||||
callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
|
||||
{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
|
||||
{{end}}{{.method}}(ctx context.Context{{if .hasReq}},in *{{.pbRequest}}{{end}}) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error)`
|
||||
|
||||
callFunctionTemplate = `
|
||||
{{if .hasComment}}{{.comment}}{{end}}
|
||||
func (m *default{{.serviceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) {
|
||||
func (m *default{{.serviceName}}) {{.method}}(ctx context.Context{{if .hasReq}},in *{{.pbRequest}}{{end}}) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error) {
|
||||
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
|
||||
return client.{{.method}}(ctx, in)
|
||||
return client.{{.method}}(ctx,{{if .hasReq}} in{{end}})
|
||||
}
|
||||
`
|
||||
)
|
||||
@ -78,7 +78,7 @@ func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf
|
||||
return err
|
||||
}
|
||||
|
||||
iFunctions, err := g.getInterfaceFuncs(service)
|
||||
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -115,6 +115,14 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
|
||||
}
|
||||
|
||||
comment := parser.GetComment(rpc.Doc())
|
||||
var streamServer string
|
||||
if rpc.StreamsRequest && rpc.StreamsReturns {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamClient")
|
||||
} else if rpc.StreamsRequest {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamClient")
|
||||
} else {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamClient")
|
||||
}
|
||||
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
|
||||
"serviceName": stringx.From(service.Name).ToCamel(),
|
||||
"rpcServiceName": parser.CamelCase(service.Name),
|
||||
@ -124,6 +132,9 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
|
||||
"pbResponse": parser.CamelCase(rpc.ReturnsType),
|
||||
"hasComment": len(comment) > 0,
|
||||
"comment": comment,
|
||||
"hasReq": !rpc.StreamsRequest,
|
||||
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
|
||||
"streamBody": streamServer,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -134,7 +145,7 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
|
||||
return functions, nil
|
||||
}
|
||||
|
||||
func (g *DefaultGenerator) getInterfaceFuncs(service parser.Service) ([]string, error) {
|
||||
func (g *DefaultGenerator) getInterfaceFuncs(goPackage string, service parser.Service) ([]string, error) {
|
||||
functions := make([]string, 0)
|
||||
|
||||
for _, rpc := range service.RPC {
|
||||
@ -144,13 +155,25 @@ func (g *DefaultGenerator) getInterfaceFuncs(service parser.Service) ([]string,
|
||||
}
|
||||
|
||||
comment := parser.GetComment(rpc.Doc())
|
||||
var streamServer string
|
||||
if rpc.StreamsRequest && rpc.StreamsReturns {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamClient")
|
||||
} else if rpc.StreamsRequest {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamClient")
|
||||
} else {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamClient")
|
||||
}
|
||||
|
||||
buffer, err := util.With("interfaceFn").Parse(text).Execute(
|
||||
map[string]interface{}{
|
||||
"hasComment": len(comment) > 0,
|
||||
"comment": comment,
|
||||
"method": parser.CamelCase(rpc.Name),
|
||||
"hasReq": !rpc.StreamsRequest,
|
||||
"pbRequest": parser.CamelCase(rpc.RequestType),
|
||||
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
|
||||
"pbResponse": parser.CamelCase(rpc.ReturnsType),
|
||||
"streamBody": streamServer,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -40,10 +40,10 @@ func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logic
|
||||
{{.functions}}
|
||||
`
|
||||
logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
|
||||
func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
|
||||
func (l *{{.logicName}}) {{.method}} ({{if .hasReq}}in {{.request}}{{if .stream}},stream {{.streamBody}}{{end}}{{else}}stream {{.streamBody}}{{end}}) ({{if .hasReply}}{{.response}},{{end}} error) {
|
||||
// todo: add your logic here and delete this line
|
||||
|
||||
return &{{.responseType}}{}, nil
|
||||
return {{if .hasReply}}&{{.responseType}}{},{{end}} nil
|
||||
}
|
||||
`
|
||||
)
|
||||
@ -51,6 +51,7 @@ func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
|
||||
// GenLogic generates the logic file of the rpc service, which corresponds to the RPC definition items in proto.
|
||||
func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
||||
dir := ctx.GetLogic()
|
||||
service := proto.Service.Service.Name
|
||||
for _, rpc := range proto.Service.RPC {
|
||||
logicFilename, err := format.FileNamingFormat(cfg.NamingFormat, rpc.Name+"_logic")
|
||||
if err != nil {
|
||||
@ -58,7 +59,7 @@ func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *con
|
||||
}
|
||||
|
||||
filename := filepath.Join(dir.Filename, logicFilename+".go")
|
||||
functions, err := g.genLogicFunction(proto.PbPackage, rpc)
|
||||
functions, err := g.genLogicFunction(service, proto.PbPackage, rpc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -82,7 +83,7 @@ func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *con
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *DefaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) (string, error) {
|
||||
func (g *DefaultGenerator) genLogicFunction(serviceName string, goPackage string, rpc *parser.RPC) (string, error) {
|
||||
functions := make([]string, 0)
|
||||
text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
|
||||
if err != nil {
|
||||
@ -91,12 +92,24 @@ func (g *DefaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) (
|
||||
|
||||
logicName := stringx.From(rpc.Name + "_logic").ToCamel()
|
||||
comment := parser.GetComment(rpc.Doc())
|
||||
var streamServer string
|
||||
if rpc.StreamsRequest && rpc.StreamsReturns {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_StreamServer")
|
||||
} else if rpc.StreamsRequest {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_ClientStreamServer")
|
||||
} else {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_ServerStreamServer")
|
||||
}
|
||||
buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
|
||||
"logicName": logicName,
|
||||
"method": parser.CamelCase(rpc.Name),
|
||||
"hasReq": !rpc.StreamsRequest,
|
||||
"request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
|
||||
"hasReply": !rpc.StreamsRequest && !rpc.StreamsReturns,
|
||||
"response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
|
||||
"responseType": fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
|
||||
"stream": rpc.StreamsRequest || rpc.StreamsReturns,
|
||||
"streamBody": streamServer,
|
||||
"hasComment": len(comment) > 0,
|
||||
"comment": comment,
|
||||
})
|
||||
|
@ -38,9 +38,9 @@ func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
|
||||
`
|
||||
functionTemplate = `
|
||||
{{if .hasComment}}{{.comment}}{{end}}
|
||||
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
|
||||
l := logic.New{{.logicName}}(ctx,s.svcCtx)
|
||||
return l.{{.method}}(in)
|
||||
func (s *{{.server}}Server) {{.method}} ({{if .notStream}}ctx context.Context,{{if .hasReq}} in {{.request}}{{end}}{{else}}{{if .hasReq}} in {{.request}},{{end}}stream {{.streamBody}}{{end}}) ({{if .notStream}}{{.response}},{{end}}error) {
|
||||
l := logic.New{{.logicName}}({{if .notStream}}ctx,{{else}}stream.Context(),{{end}}s.svcCtx)
|
||||
return l.{{.method}}({{if .hasReq}}in{{if .stream}} ,stream{{end}}{{else}}{{if .stream}}stream{{end}}{{end}})
|
||||
}
|
||||
`
|
||||
)
|
||||
@ -91,6 +91,15 @@ func (g *DefaultGenerator) genFunctions(goPackage string, service parser.Service
|
||||
}
|
||||
|
||||
comment := parser.GetComment(rpc.Doc())
|
||||
var streamServer string
|
||||
if rpc.StreamsRequest && rpc.StreamsReturns {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamServer")
|
||||
} else if rpc.StreamsRequest {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamServer")
|
||||
} else {
|
||||
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamServer")
|
||||
}
|
||||
|
||||
buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
|
||||
"server": stringx.From(service.Name).ToCamel(),
|
||||
"logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel()),
|
||||
@ -99,6 +108,10 @@ func (g *DefaultGenerator) genFunctions(goPackage string, service parser.Service
|
||||
"response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
|
||||
"hasComment": len(comment) > 0,
|
||||
"comment": comment,
|
||||
"hasReq": !rpc.StreamsRequest,
|
||||
"stream": rpc.StreamsRequest || rpc.StreamsReturns,
|
||||
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
|
||||
"streamBody": streamServer,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -59,4 +59,10 @@ service Test_Service {
|
||||
rpc MapService (MapReq) returns (CommonReply);
|
||||
// case repeated
|
||||
rpc RepeatedService (RepeatedReq) returns (CommonReply);
|
||||
// server stream
|
||||
rpc ServerStream (Req) returns (stream Reply);
|
||||
// client stream
|
||||
rpc ClientStream (stream Req) returns (Reply);
|
||||
// stream
|
||||
rpc Stream(stream Req) returns (stream Reply);
|
||||
}
|
16
tools/goctl/rpc/parser/stream.proto
Normal file
16
tools/goctl/rpc/parser/stream.proto
Normal file
@ -0,0 +1,16 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package test;
|
||||
|
||||
message Req{
|
||||
string input = 1;
|
||||
}
|
||||
|
||||
message Reply{
|
||||
string output = 1;
|
||||
}
|
||||
service TestService{
|
||||
rpc ServerStream (Req) returns (stream Reply);
|
||||
rpc ClientStream (stream Req) returns (Reply);
|
||||
rpc Stream (stream Req) returns (stream Reply);
|
||||
}
|
Loading…
Reference in New Issue
Block a user