mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-02-02 16:28:39 +08:00
feat: verify RpcPath on startup (#2159)
* feat: verify RpcPath on startup * feat: support http header Grpc-Timeout
This commit is contained in:
parent
b206dd28a3
commit
557383fbbf
@ -21,8 +21,8 @@ type (
|
||||
Method string
|
||||
// Path is the HTTP path.
|
||||
Path string
|
||||
// Rpc is the gRPC rpc method, with format of package.service/method
|
||||
Rpc string
|
||||
// RpcPath is the gRPC rpc method, with format of package.service/method
|
||||
RpcPath string
|
||||
}
|
||||
|
||||
// upstream is the configuration for upstream.
|
||||
|
34
gateway/internal/descriptorsource.go
Normal file
34
gateway/internal/descriptorsource.go
Normal file
@ -0,0 +1,34 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/fullstorydev/grpcurl"
|
||||
"github.com/jhump/protoreflect/desc"
|
||||
)
|
||||
|
||||
// GetMethods returns all methods of the given grpcurl.DescriptorSource.
|
||||
func GetMethods(source grpcurl.DescriptorSource) ([]string, error) {
|
||||
svcs, err := source.ListServices()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var methods []string
|
||||
for _, svc := range svcs {
|
||||
d, err := source.FindSymbol(svc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch val := d.(type) {
|
||||
case *desc.ServiceDescriptor:
|
||||
svcMethods := val.GetMethods()
|
||||
for _, method := range svcMethods {
|
||||
methods = append(methods, fmt.Sprintf("%s/%s", svc, method.GetName()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return methods, nil
|
||||
}
|
29
gateway/internal/descriptorsource_test.go
Normal file
29
gateway/internal/descriptorsource_test.go
Normal file
@ -0,0 +1,29 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/fullstorydev/grpcurl"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/hash"
|
||||
)
|
||||
|
||||
const b64pb = `CpgBCgtoZWxsby5wcm90bxIFaGVsbG8iHQoHUmVxdWVzdBISCgRwaW5nGAEgASgJUgRwaW5nIh4KCFJlc3BvbnNlEhIKBHBvbmcYASABKAlSBHBvbmcyMAoFSGVsbG8SJwoEUGluZxIOLmhlbGxvLlJlcXVlc3QaDy5oZWxsby5SZXNwb25zZUIJWgcuL2hlbGxvYgZwcm90bzM=`
|
||||
|
||||
func TestGetMethods(t *testing.T) {
|
||||
tmpfile, err := ioutil.TempFile(os.TempDir(), hash.Md5Hex([]byte(b64pb)))
|
||||
assert.Nil(t, err)
|
||||
b, err := base64.StdEncoding.DecodeString(b64pb)
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, ioutil.WriteFile(tmpfile.Name(), b, os.ModeTemporary))
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
source, err := grpcurl.DescriptorSourceFromProtoSets(tmpfile.Name())
|
||||
assert.Nil(t, err)
|
||||
methods, err := GetMethods(source)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"hello.Hello/Ping"}, methods)
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package gateway
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -11,7 +11,8 @@ const (
|
||||
metadataPrefix = "gateway-"
|
||||
)
|
||||
|
||||
func buildHeaders(header http.Header) []string {
|
||||
// BuildHeaders builds the headers for the gateway from HTTP headers.
|
||||
func BuildHeaders(header http.Header) []string {
|
||||
var headers []string
|
||||
|
||||
for k, v := range header {
|
@ -1,4 +1,4 @@
|
||||
package gateway
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
@ -10,12 +10,12 @@ import (
|
||||
func TestBuildHeadersNoValue(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Add("a", "b")
|
||||
assert.Nil(t, buildHeaders(req.Header))
|
||||
assert.Nil(t, BuildHeaders(req.Header))
|
||||
}
|
||||
|
||||
func TestBuildHeadersWithValues(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Add("grpc-metadata-a", "b")
|
||||
req.Header.Add("grpc-metadata-b", "b")
|
||||
assert.EqualValues(t, []string{"gateway-A:b", "gateway-B:b"}, buildHeaders(req.Header))
|
||||
assert.EqualValues(t, []string{"gateway-A:b", "gateway-B:b"}, BuildHeaders(req.Header))
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package gateway
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -11,17 +11,8 @@ import (
|
||||
"github.com/zeromicro/go-zero/rest/pathvar"
|
||||
)
|
||||
|
||||
func buildJsonRequestParser(m map[string]interface{}, resolver jsonpb.AnyResolver) (
|
||||
grpcurl.RequestParser, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return grpcurl.NewJSONRequestParser(&buf, resolver), nil
|
||||
}
|
||||
|
||||
func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) {
|
||||
// NewRequestParser creates a new request parser from the given http.Request and resolver.
|
||||
func NewRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) {
|
||||
vars := pathvar.Vars(r)
|
||||
params, err := httpx.GetFormValues(r)
|
||||
if err != nil {
|
||||
@ -50,3 +41,13 @@ func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.Req
|
||||
|
||||
return buildJsonRequestParser(m, resolver)
|
||||
}
|
||||
|
||||
func buildJsonRequestParser(m map[string]interface{}, resolver jsonpb.AnyResolver) (
|
||||
grpcurl.RequestParser, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return grpcurl.NewJSONRequestParser(&buf, resolver), nil
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package gateway
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
@ -11,7 +11,7 @@ import (
|
||||
|
||||
func TestNewRequestParserNoVar(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
parser, err := newRequestParser(req, nil)
|
||||
parser, err := NewRequestParser(req, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, parser)
|
||||
}
|
||||
@ -19,14 +19,14 @@ func TestNewRequestParserNoVar(t *testing.T) {
|
||||
func TestNewRequestParserWithVars(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req = pathvar.WithVars(req, map[string]string{"a": "b"})
|
||||
parser, err := newRequestParser(req, nil)
|
||||
parser, err := NewRequestParser(req, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, parser)
|
||||
}
|
||||
|
||||
func TestNewRequestParserNoVarWithBody(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"}`))
|
||||
parser, err := newRequestParser(req, nil)
|
||||
parser, err := NewRequestParser(req, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, parser)
|
||||
}
|
||||
@ -34,7 +34,7 @@ func TestNewRequestParserNoVarWithBody(t *testing.T) {
|
||||
func TestNewRequestParserWithVarsWithBody(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"}`))
|
||||
req = pathvar.WithVars(req, map[string]string{"c": "d"})
|
||||
parser, err := newRequestParser(req, nil)
|
||||
parser, err := NewRequestParser(req, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, parser)
|
||||
}
|
||||
@ -42,14 +42,14 @@ func TestNewRequestParserWithVarsWithBody(t *testing.T) {
|
||||
func TestNewRequestParserWithVarsWithWrongBody(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", strings.NewReader(`{"a": "b"`))
|
||||
req = pathvar.WithVars(req, map[string]string{"c": "d"})
|
||||
parser, err := newRequestParser(req, nil)
|
||||
parser, err := NewRequestParser(req, nil)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, parser)
|
||||
}
|
||||
|
||||
func TestNewRequestParserWithForm(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/val?a=b", nil)
|
||||
parser, err := newRequestParser(req, nil)
|
||||
parser, err := NewRequestParser(req, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, parser)
|
||||
}
|
19
gateway/internal/timeout.go
Normal file
19
gateway/internal/timeout.go
Normal file
@ -0,0 +1,19 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const grpcTimeoutHeader = "Grpc-Timeout"
|
||||
|
||||
// GetTimeout returns the timeout from the header, if not set, returns the default timeout.
|
||||
func GetTimeout(header http.Header, defaultTimeout time.Duration) time.Duration {
|
||||
if timeout := header.Get(grpcTimeoutHeader); len(timeout) > 0 {
|
||||
if t, err := time.ParseDuration(timeout); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
return defaultTimeout
|
||||
}
|
22
gateway/internal/timeout_test.go
Normal file
22
gateway/internal/timeout_test.go
Normal file
@ -0,0 +1,22 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetTimeout(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set(grpcTimeoutHeader, "1s")
|
||||
timeout := GetTimeout(req.Header, time.Second*5)
|
||||
assert.Equal(t, time.Second, timeout)
|
||||
}
|
||||
|
||||
func TestGetTimeoutDefault(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
timeout := GetTimeout(req.Header, time.Second*5)
|
||||
assert.Equal(t, time.Second*5, timeout)
|
||||
}
|
@ -35,7 +35,7 @@ Upstreams:
|
||||
Mapping:
|
||||
- Method: get
|
||||
Path: /pingHello/:ping
|
||||
Rpc: hello.Hello/Ping
|
||||
RpcPath: hello.Hello/Ping
|
||||
- Grpc:
|
||||
Endpoints:
|
||||
- localhost:8081
|
||||
@ -43,7 +43,7 @@ Upstreams:
|
||||
Mapping:
|
||||
- Method: post
|
||||
Path: /pingWorld
|
||||
Rpc: world.World/Ping
|
||||
RpcPath: world.World/Ping
|
||||
```
|
||||
|
||||
## Generate ProtoSet files
|
||||
|
@ -2,6 +2,7 @@ package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@ -11,6 +12,7 @@ import (
|
||||
"github.com/jhump/protoreflect/grpcreflect"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/mr"
|
||||
"github.com/zeromicro/go-zero/gateway/internal"
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
"github.com/zeromicro/go-zero/zrpc"
|
||||
@ -58,8 +60,23 @@ func (s *Server) build() error {
|
||||
return
|
||||
}
|
||||
|
||||
methods, err := internal.GetMethods(source)
|
||||
if err != nil {
|
||||
cancel(err)
|
||||
return
|
||||
}
|
||||
|
||||
methodSet := make(map[string]struct{})
|
||||
for _, m := range methods {
|
||||
methodSet[m] = struct{}{}
|
||||
}
|
||||
resolver := grpcurl.AnyResolverFromDescriptorSource(source)
|
||||
for _, m := range up.Mapping {
|
||||
if _, ok := methodSet[m.RpcPath]; !ok {
|
||||
cancel(fmt.Errorf("rpc method %s not found", m.RpcPath))
|
||||
return
|
||||
}
|
||||
|
||||
writer.Write(rest.Route{
|
||||
Method: strings.ToUpper(m.Method),
|
||||
Path: m.Path,
|
||||
@ -82,15 +99,16 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
|
||||
Formatter: grpcurl.NewJSONFormatter(true,
|
||||
grpcurl.AnyResolverFromDescriptorSource(source)),
|
||||
}
|
||||
parser, err := newRequestParser(r, resolver)
|
||||
parser, err := internal.NewRequestParser(r, resolver)
|
||||
if err != nil {
|
||||
httpx.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, can := context.WithTimeout(r.Context(), s.timeout)
|
||||
timeout := internal.GetTimeout(r.Header, s.timeout)
|
||||
ctx, can := context.WithTimeout(r.Context(), timeout)
|
||||
defer can()
|
||||
if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), m.Rpc, buildHeaders(r.Header),
|
||||
if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), m.RpcPath, internal.BuildHeaders(r.Header),
|
||||
handler, parser.Next); err != nil {
|
||||
httpx.Error(w, err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user