mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-02-02 16:28:39 +08:00
feat: support http->http in gateway (#4605)
This commit is contained in:
parent
c71829c8de
commit
d415ba39e2
@ -12,14 +12,22 @@ type (
|
||||
Upstreams []Upstream
|
||||
}
|
||||
|
||||
// HttpClientConf is the configuration for an HTTP client.
|
||||
HttpClientConf struct {
|
||||
Target string
|
||||
Prefix string `json:",optional"`
|
||||
Timeout int64 `json:",default=3000"`
|
||||
}
|
||||
|
||||
// RouteMapping is a mapping between a gateway route and an upstream rpc method.
|
||||
RouteMapping struct {
|
||||
// Method is the HTTP method, like GET, POST, PUT, DELETE.
|
||||
Method string
|
||||
// Path is the HTTP path.
|
||||
Path string
|
||||
// RpcPath is the gRPC rpc method, with format of package.service/method
|
||||
RpcPath string
|
||||
// RpcPath is the gRPC rpc method, with format of package.service/method, optional.
|
||||
// If the mapping is for HTTP, it's not necessary.
|
||||
RpcPath string `json:",optional"`
|
||||
}
|
||||
|
||||
// Upstream is the configuration for an upstream.
|
||||
@ -27,12 +35,14 @@ type (
|
||||
// Name is the name of the upstream.
|
||||
Name string `json:",optional"`
|
||||
// Grpc is the target of the upstream.
|
||||
Grpc zrpc.RpcClientConf
|
||||
Grpc *zrpc.RpcClientConf `json:",optional"`
|
||||
// Http is the target of the upstream.
|
||||
Http *HttpClientConf `json:",optional=!grpc"`
|
||||
// ProtoSets is the file list of proto set, like [hello.pb].
|
||||
// if your proto file import another proto file, you need to write multi-file slice,
|
||||
// like [hello.pb, common.pb].
|
||||
ProtoSets []string `json:",optional"`
|
||||
// Mappings is the mapping between gateway routes and Upstream rpc methods.
|
||||
// Mappings is the mapping between gateway routes and Upstream methods.
|
||||
// Keep it blank if annotations are added in rpc methods.
|
||||
Mappings []RouteMapping `json:",optional"`
|
||||
}
|
||||
|
@ -3,22 +3,29 @@ package gateway
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fullstorydev/grpcurl"
|
||||
"github.com/golang/protobuf/jsonpb"
|
||||
"github.com/jhump/protoreflect/grpcreflect"
|
||||
"github.com/zeromicro/go-zero/core/logc"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/mr"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
"github.com/zeromicro/go-zero/gateway/internal"
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
"github.com/zeromicro/go-zero/rest/httpc"
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
"github.com/zeromicro/go-zero/zrpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
)
|
||||
|
||||
const defaultHttpScheme = "http"
|
||||
|
||||
type (
|
||||
// Server is a gateway server.
|
||||
Server struct {
|
||||
@ -83,52 +90,11 @@ func (s *Server) build() error {
|
||||
source <- up
|
||||
}
|
||||
}, func(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) {
|
||||
var cli zrpc.Client
|
||||
if s.dialer != nil {
|
||||
cli = s.dialer(up.Grpc)
|
||||
} else {
|
||||
cli = zrpc.MustNewClient(up.Grpc)
|
||||
}
|
||||
s.conns = append(s.conns, cli)
|
||||
|
||||
source, err := s.createDescriptorSource(cli, up)
|
||||
if err != nil {
|
||||
cancel(fmt.Errorf("%s: %w", up.Name, err))
|
||||
return
|
||||
}
|
||||
|
||||
methods, err := internal.GetMethods(source)
|
||||
if err != nil {
|
||||
cancel(fmt.Errorf("%s: %w", up.Name, err))
|
||||
return
|
||||
}
|
||||
|
||||
resolver := grpcurl.AnyResolverFromDescriptorSource(source)
|
||||
for _, m := range methods {
|
||||
if len(m.HttpMethod) > 0 && len(m.HttpPath) > 0 {
|
||||
writer.Write(rest.Route{
|
||||
Method: m.HttpMethod,
|
||||
Path: m.HttpPath,
|
||||
Handler: s.buildHandler(source, resolver, cli, m.RpcPath),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
methodSet := make(map[string]struct{})
|
||||
for _, m := range methods {
|
||||
methodSet[m.RpcPath] = struct{}{}
|
||||
}
|
||||
for _, m := range up.Mappings {
|
||||
if _, ok := methodSet[m.RpcPath]; !ok {
|
||||
cancel(fmt.Errorf("%s: rpc method %s not found", up.Name, m.RpcPath))
|
||||
return
|
||||
}
|
||||
|
||||
writer.Write(rest.Route{
|
||||
Method: strings.ToUpper(m.Method),
|
||||
Path: m.Path,
|
||||
Handler: s.buildHandler(source, resolver, cli, m.RpcPath),
|
||||
})
|
||||
// up.Grpc and up.Http are exclusive
|
||||
if up.Grpc != nil {
|
||||
s.buildGrpcRoute(up, writer, cancel)
|
||||
} else if up.Http != nil {
|
||||
s.buildHttpRoute(up, writer)
|
||||
}
|
||||
}, func(pipe <-chan rest.Route, cancel func(error)) {
|
||||
for route := range pipe {
|
||||
@ -137,7 +103,7 @@ func (s *Server) build() error {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver,
|
||||
func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver,
|
||||
cli zrpc.Client, rpcPath string) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
parser, err := internal.NewRequestParser(r, resolver)
|
||||
@ -160,31 +126,119 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) {
|
||||
var source grpcurl.DescriptorSource
|
||||
var err error
|
||||
|
||||
if len(up.ProtoSets) > 0 {
|
||||
source, err = grpcurl.DescriptorSourceFromProtoSets(up.ProtoSets...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func (s *Server) buildGrpcRoute(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) {
|
||||
var cli zrpc.Client
|
||||
if s.dialer != nil {
|
||||
cli = s.dialer(*up.Grpc)
|
||||
} else {
|
||||
client := grpcreflect.NewClientAuto(context.Background(), cli.Conn())
|
||||
source = grpcurl.DescriptorSourceFromServer(context.Background(), client)
|
||||
cli = zrpc.MustNewClient(*up.Grpc)
|
||||
}
|
||||
s.conns = append(s.conns, cli)
|
||||
|
||||
source, err := createDescriptorSource(cli, up)
|
||||
if err != nil {
|
||||
cancel(fmt.Errorf("%s: %w", up.Name, err))
|
||||
return
|
||||
}
|
||||
|
||||
return source, nil
|
||||
methods, err := internal.GetMethods(source)
|
||||
if err != nil {
|
||||
cancel(fmt.Errorf("%s: %w", up.Name, err))
|
||||
return
|
||||
}
|
||||
|
||||
resolver := grpcurl.AnyResolverFromDescriptorSource(source)
|
||||
for _, m := range methods {
|
||||
if len(m.HttpMethod) > 0 && len(m.HttpPath) > 0 {
|
||||
writer.Write(rest.Route{
|
||||
Method: m.HttpMethod,
|
||||
Path: m.HttpPath,
|
||||
Handler: s.buildGrpcHandler(source, resolver, cli, m.RpcPath),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
methodSet := make(map[string]struct{})
|
||||
for _, m := range methods {
|
||||
methodSet[m.RpcPath] = struct{}{}
|
||||
}
|
||||
for _, m := range up.Mappings {
|
||||
if _, ok := methodSet[m.RpcPath]; !ok {
|
||||
cancel(fmt.Errorf("%s: rpc method %s not found", up.Name, m.RpcPath))
|
||||
return
|
||||
}
|
||||
|
||||
writer.Write(rest.Route{
|
||||
Method: strings.ToUpper(m.Method),
|
||||
Path: m.Path,
|
||||
Handler: s.buildGrpcHandler(source, resolver, cli, m.RpcPath),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(httpx.ContentType, httpx.JsonContentType)
|
||||
req, err := buildRequestWithNewTarget(r, target)
|
||||
if err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if target.Timeout > 0 {
|
||||
timeout := time.Duration(target.Timeout) * time.Millisecond
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
}
|
||||
|
||||
resp, err := httpc.DoRequest(req)
|
||||
if err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
if _, err = io.Copy(w, resp.Body); err != nil {
|
||||
// log the error with original request info
|
||||
logc.Error(r.Context(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) buildHttpRoute(up Upstream, writer mr.Writer[rest.Route]) {
|
||||
for _, m := range up.Mappings {
|
||||
writer.Write(rest.Route{
|
||||
Method: strings.ToUpper(m.Method),
|
||||
Path: m.Path,
|
||||
Handler: s.buildHttpHandler(up.Http),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) ensureUpstreamNames() error {
|
||||
for i := 0; i < len(s.upstreams); i++ {
|
||||
target, err := s.upstreams[i].Grpc.BuildTarget()
|
||||
if err != nil {
|
||||
return err
|
||||
if len(s.upstreams[i].Name) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
s.upstreams[i].Name = target
|
||||
if s.upstreams[i].Grpc != nil {
|
||||
target, err := s.upstreams[i].Grpc.BuildTarget()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.upstreams[i].Name = target
|
||||
} else if s.upstreams[i].Http != nil {
|
||||
s.upstreams[i].Name = s.upstreams[i].Http.Target
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -207,6 +261,50 @@ func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server)
|
||||
}
|
||||
}
|
||||
|
||||
func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.Request, error) {
|
||||
u := *r.URL
|
||||
u.Host = target.Target
|
||||
if len(u.Scheme) == 0 {
|
||||
u.Scheme = defaultHttpScheme
|
||||
}
|
||||
|
||||
if len(target.Prefix) > 0 {
|
||||
var err error
|
||||
u.Path, err = url.JoinPath(target.Prefix, u.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Request{
|
||||
Method: r.Method,
|
||||
URL: &u,
|
||||
Header: r.Header.Clone(),
|
||||
Proto: r.Proto,
|
||||
ProtoMajor: r.ProtoMajor,
|
||||
ProtoMinor: r.ProtoMinor,
|
||||
ContentLength: r.ContentLength,
|
||||
Body: io.NopCloser(r.Body),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) {
|
||||
var source grpcurl.DescriptorSource
|
||||
var err error
|
||||
|
||||
if len(up.ProtoSets) > 0 {
|
||||
source, err = grpcurl.DescriptorSourceFromProtoSets(up.ProtoSets...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
client := grpcreflect.NewClientAuto(context.Background(), cli.Conn())
|
||||
source = grpcurl.DescriptorSourceFromServer(context.Background(), client)
|
||||
}
|
||||
|
||||
return source, nil
|
||||
}
|
||||
|
||||
// withDialer sets a dialer to create a gRPC client.
|
||||
func withDialer(dialer func(conf zrpc.RpcClientConf) zrpc.Client) func(*Server) {
|
||||
return func(s *Server) {
|
||||
|
@ -2,9 +2,12 @@ package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -65,7 +68,7 @@ func TestMustNewServer(t *testing.T) {
|
||||
RpcPath: "mock.DepositService/Deposit",
|
||||
},
|
||||
},
|
||||
Grpc: zrpc.RpcClientConf{
|
||||
Grpc: &zrpc.RpcClientConf{
|
||||
Endpoints: []string{"foo"},
|
||||
Timeout: 1000,
|
||||
Middlewares: zrpc.ClientMiddlewaresConf{
|
||||
@ -98,7 +101,7 @@ func TestServer_ensureUpstreamNames(t *testing.T) {
|
||||
var s = Server{
|
||||
upstreams: []Upstream{
|
||||
{
|
||||
Grpc: zrpc.RpcClientConf{
|
||||
Grpc: &zrpc.RpcClientConf{
|
||||
Target: "target",
|
||||
},
|
||||
},
|
||||
@ -113,7 +116,7 @@ func TestServer_ensureUpstreamNames_badEtcd(t *testing.T) {
|
||||
var s = Server{
|
||||
upstreams: []Upstream{
|
||||
{
|
||||
Grpc: zrpc.RpcClientConf{
|
||||
Grpc: &zrpc.RpcClientConf{
|
||||
Etcd: discov.EtcdConf{},
|
||||
},
|
||||
},
|
||||
@ -125,3 +128,193 @@ func TestServer_ensureUpstreamNames_badEtcd(t *testing.T) {
|
||||
s.Start()
|
||||
})
|
||||
}
|
||||
|
||||
func TestHttpToHttp(t *testing.T) {
|
||||
server := startTestServer(t)
|
||||
defer server.Close()
|
||||
|
||||
var c GatewayConf
|
||||
assert.NoError(t, conf.FillDefault(&c))
|
||||
c.DevServer.Host = "localhost"
|
||||
c.Host = "localhost"
|
||||
c.Port = 18882
|
||||
|
||||
s := MustNewServer(c)
|
||||
s.upstreams = []Upstream{
|
||||
{
|
||||
Name: "test",
|
||||
Mappings: []RouteMapping{
|
||||
{
|
||||
Method: "get",
|
||||
Path: "/api/ping",
|
||||
},
|
||||
},
|
||||
Http: &HttpClientConf{
|
||||
Target: "localhost:45678",
|
||||
Timeout: 3000,
|
||||
},
|
||||
},
|
||||
{
|
||||
Mappings: []RouteMapping{
|
||||
{
|
||||
Method: "get",
|
||||
Path: "/ping",
|
||||
},
|
||||
},
|
||||
Http: &HttpClientConf{
|
||||
Target: "localhost:45678",
|
||||
Prefix: "/api",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
defer s.Stop()
|
||||
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
|
||||
t.Run("/api/ping", func(t *testing.T) {
|
||||
resp, err := httpc.Do(context.Background(), http.MethodGet,
|
||||
"http://localhost:18882/api/ping", nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, "pong", string(body))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("/ping", func(t *testing.T) {
|
||||
resp, err := httpc.Do(context.Background(), http.MethodGet,
|
||||
"http://localhost:18882/ping", nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, "pong", string(body))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no upstream", func(t *testing.T) {
|
||||
resp, err := httpc.Do(context.Background(), http.MethodGet,
|
||||
"http://localhost:18882/ping/bad", nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHttpToHttpBadUpstream(t *testing.T) {
|
||||
var c GatewayConf
|
||||
assert.NoError(t, conf.FillDefault(&c))
|
||||
c.DevServer.Host = "localhost"
|
||||
c.Host = "localhost"
|
||||
c.Port = 18883
|
||||
|
||||
s := MustNewServer(c)
|
||||
s.upstreams = []Upstream{
|
||||
{
|
||||
Mappings: []RouteMapping{
|
||||
{
|
||||
Method: "get",
|
||||
Path: "/api/ping",
|
||||
},
|
||||
},
|
||||
Http: &HttpClientConf{
|
||||
Target: "localhost:45678",
|
||||
Prefix: "\x7f/api",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
defer s.Stop()
|
||||
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
|
||||
t.Run("/api/ping", func(t *testing.T) {
|
||||
resp, err := httpc.Do(context.Background(), http.MethodGet,
|
||||
"http://localhost:18883/api/ping", nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHttpToHttpBadWriter(t *testing.T) {
|
||||
t.Run("bad url", func(t *testing.T) {
|
||||
handler := new(Server).buildHttpHandler(&HttpClientConf{
|
||||
Target: "http://example.com",
|
||||
Timeout: 3000,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(&badResponseWriter{w},
|
||||
httptest.NewRequest(http.MethodGet, "http://localhost:18884", nil))
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
|
||||
t.Run("bad url", func(t *testing.T) {
|
||||
var c GatewayConf
|
||||
assert.NoError(t, conf.FillDefault(&c))
|
||||
c.DevServer.Host = "localhost"
|
||||
c.Host = "localhost"
|
||||
c.Port = 18884
|
||||
|
||||
s := MustNewServer(c)
|
||||
s.upstreams = []Upstream{
|
||||
{
|
||||
Mappings: []RouteMapping{
|
||||
{
|
||||
Method: "get",
|
||||
Path: "/api/ping",
|
||||
},
|
||||
},
|
||||
Http: &HttpClientConf{
|
||||
Target: "localhost:45678",
|
||||
Prefix: "\x7f/api",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
defer s.Stop()
|
||||
|
||||
handler := new(Server).buildHttpHandler(&HttpClientConf{
|
||||
Target: "localhost:18884",
|
||||
Timeout: 3000,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(&badResponseWriter{w},
|
||||
httptest.NewRequest(http.MethodGet, "http://localhost:18884/api/ping", nil))
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
// Handler function for the root route
|
||||
func pingHandler(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("pong"))
|
||||
}
|
||||
|
||||
func startTestServer(t *testing.T) *http.Server {
|
||||
http.HandleFunc("/api/ping", pingHandler)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: ":45678",
|
||||
Handler: http.DefaultServeMux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Errorf("failed to start server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
type badResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (w *badResponseWriter) Write([]byte) (int, error) {
|
||||
return 0, errors.New("bad writer")
|
||||
}
|
||||
|
@ -6,15 +6,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
func TestMetricsInterceptor(t *testing.T) {
|
||||
c := gomock.NewController(t)
|
||||
defer c.Finish()
|
||||
|
||||
logx.Disable()
|
||||
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -183,7 +183,6 @@ func request(r *http.Request, cli client) (*http.Response, error) {
|
||||
for i := len(respHandlers) - 1; i >= 0; i-- {
|
||||
respHandlers[i](resp, err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
|
@ -617,7 +617,7 @@ func NewScanner(filename string, src interface{}) (*Scanner, error) {
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return nil, fmt.Errorf("filename: %s,missing input", filename)
|
||||
return nil, fmt.Errorf("filename: %s, missing input", filename)
|
||||
}
|
||||
|
||||
var runeList []rune
|
||||
|
Loading…
Reference in New Issue
Block a user