feat: support http->http in gateway (#4605)

This commit is contained in:
Kevin Wan 2025-01-27 20:00:58 +08:00 committed by GitHub
parent c71829c8de
commit d415ba39e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 372 additions and 76 deletions

View File

@ -12,14 +12,22 @@ type (
Upstreams []Upstream
}
// HttpClientConf is the configuration for an HTTP client.
HttpClientConf struct {
Target string
Prefix string `json:",optional"`
Timeout int64 `json:",default=3000"`
}
// RouteMapping is a mapping between a gateway route and an upstream rpc method.
RouteMapping struct {
// Method is the HTTP method, like GET, POST, PUT, DELETE.
Method string
// Path is the HTTP path.
Path string
// RpcPath is the gRPC rpc method, with format of package.service/method
RpcPath string
// RpcPath is the gRPC rpc method, with format of package.service/method, optional.
// If the mapping is for HTTP, it's not necessary.
RpcPath string `json:",optional"`
}
// Upstream is the configuration for an upstream.
@ -27,12 +35,14 @@ type (
// Name is the name of the upstream.
Name string `json:",optional"`
// Grpc is the target of the upstream.
Grpc zrpc.RpcClientConf
Grpc *zrpc.RpcClientConf `json:",optional"`
// Http is the target of the upstream.
Http *HttpClientConf `json:",optional=!grpc"`
// ProtoSets is the file list of proto set, like [hello.pb].
// if your proto file import another proto file, you need to write multi-file slice,
// like [hello.pb, common.pb].
ProtoSets []string `json:",optional"`
// Mappings is the mapping between gateway routes and Upstream rpc methods.
// Mappings is the mapping between gateway routes and Upstream methods.
// Keep it blank if annotations are added in rpc methods.
Mappings []RouteMapping `json:",optional"`
}

View File

@ -3,22 +3,29 @@ package gateway
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/fullstorydev/grpcurl"
"github.com/golang/protobuf/jsonpb"
"github.com/jhump/protoreflect/grpcreflect"
"github.com/zeromicro/go-zero/core/logc"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mr"
"github.com/zeromicro/go-zero/core/threading"
"github.com/zeromicro/go-zero/gateway/internal"
"github.com/zeromicro/go-zero/rest"
"github.com/zeromicro/go-zero/rest/httpc"
"github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/zrpc"
"google.golang.org/grpc/codes"
)
const defaultHttpScheme = "http"
type (
// Server is a gateway server.
Server struct {
@ -83,52 +90,11 @@ func (s *Server) build() error {
source <- up
}
}, func(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) {
var cli zrpc.Client
if s.dialer != nil {
cli = s.dialer(up.Grpc)
} else {
cli = zrpc.MustNewClient(up.Grpc)
}
s.conns = append(s.conns, cli)
source, err := s.createDescriptorSource(cli, up)
if err != nil {
cancel(fmt.Errorf("%s: %w", up.Name, err))
return
}
methods, err := internal.GetMethods(source)
if err != nil {
cancel(fmt.Errorf("%s: %w", up.Name, err))
return
}
resolver := grpcurl.AnyResolverFromDescriptorSource(source)
for _, m := range methods {
if len(m.HttpMethod) > 0 && len(m.HttpPath) > 0 {
writer.Write(rest.Route{
Method: m.HttpMethod,
Path: m.HttpPath,
Handler: s.buildHandler(source, resolver, cli, m.RpcPath),
})
}
}
methodSet := make(map[string]struct{})
for _, m := range methods {
methodSet[m.RpcPath] = struct{}{}
}
for _, m := range up.Mappings {
if _, ok := methodSet[m.RpcPath]; !ok {
cancel(fmt.Errorf("%s: rpc method %s not found", up.Name, m.RpcPath))
return
}
writer.Write(rest.Route{
Method: strings.ToUpper(m.Method),
Path: m.Path,
Handler: s.buildHandler(source, resolver, cli, m.RpcPath),
})
// up.Grpc and up.Http are exclusive
if up.Grpc != nil {
s.buildGrpcRoute(up, writer, cancel)
} else if up.Http != nil {
s.buildHttpRoute(up, writer)
}
}, func(pipe <-chan rest.Route, cancel func(error)) {
for route := range pipe {
@ -137,7 +103,7 @@ func (s *Server) build() error {
})
}
func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver,
func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver,
cli zrpc.Client, rpcPath string) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
parser, err := internal.NewRequestParser(r, resolver)
@ -160,31 +126,119 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
}
}
func (s *Server) createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) {
var source grpcurl.DescriptorSource
var err error
if len(up.ProtoSets) > 0 {
source, err = grpcurl.DescriptorSourceFromProtoSets(up.ProtoSets...)
if err != nil {
return nil, err
}
func (s *Server) buildGrpcRoute(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) {
var cli zrpc.Client
if s.dialer != nil {
cli = s.dialer(*up.Grpc)
} else {
client := grpcreflect.NewClientAuto(context.Background(), cli.Conn())
source = grpcurl.DescriptorSourceFromServer(context.Background(), client)
cli = zrpc.MustNewClient(*up.Grpc)
}
s.conns = append(s.conns, cli)
source, err := createDescriptorSource(cli, up)
if err != nil {
cancel(fmt.Errorf("%s: %w", up.Name, err))
return
}
return source, nil
methods, err := internal.GetMethods(source)
if err != nil {
cancel(fmt.Errorf("%s: %w", up.Name, err))
return
}
resolver := grpcurl.AnyResolverFromDescriptorSource(source)
for _, m := range methods {
if len(m.HttpMethod) > 0 && len(m.HttpPath) > 0 {
writer.Write(rest.Route{
Method: m.HttpMethod,
Path: m.HttpPath,
Handler: s.buildGrpcHandler(source, resolver, cli, m.RpcPath),
})
}
}
methodSet := make(map[string]struct{})
for _, m := range methods {
methodSet[m.RpcPath] = struct{}{}
}
for _, m := range up.Mappings {
if _, ok := methodSet[m.RpcPath]; !ok {
cancel(fmt.Errorf("%s: rpc method %s not found", up.Name, m.RpcPath))
return
}
writer.Write(rest.Route{
Method: strings.ToUpper(m.Method),
Path: m.Path,
Handler: s.buildGrpcHandler(source, resolver, cli, m.RpcPath),
})
}
}
func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(httpx.ContentType, httpx.JsonContentType)
req, err := buildRequestWithNewTarget(r, target)
if err != nil {
httpx.ErrorCtx(r.Context(), w, err)
return
}
if target.Timeout > 0 {
timeout := time.Duration(target.Timeout) * time.Millisecond
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
req = req.WithContext(ctx)
}
resp, err := httpc.DoRequest(req)
if err != nil {
httpx.ErrorCtx(r.Context(), w, err)
return
}
defer resp.Body.Close()
for key, values := range resp.Header {
for _, value := range values {
w.Header().Add(key, value)
}
}
w.WriteHeader(resp.StatusCode)
if _, err = io.Copy(w, resp.Body); err != nil {
// log the error with original request info
logc.Error(r.Context(), err)
}
}
}
func (s *Server) buildHttpRoute(up Upstream, writer mr.Writer[rest.Route]) {
for _, m := range up.Mappings {
writer.Write(rest.Route{
Method: strings.ToUpper(m.Method),
Path: m.Path,
Handler: s.buildHttpHandler(up.Http),
})
}
}
func (s *Server) ensureUpstreamNames() error {
for i := 0; i < len(s.upstreams); i++ {
target, err := s.upstreams[i].Grpc.BuildTarget()
if err != nil {
return err
if len(s.upstreams[i].Name) > 0 {
continue
}
s.upstreams[i].Name = target
if s.upstreams[i].Grpc != nil {
target, err := s.upstreams[i].Grpc.BuildTarget()
if err != nil {
return err
}
s.upstreams[i].Name = target
} else if s.upstreams[i].Http != nil {
s.upstreams[i].Name = s.upstreams[i].Http.Target
}
}
return nil
@ -207,6 +261,50 @@ func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server)
}
}
func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.Request, error) {
u := *r.URL
u.Host = target.Target
if len(u.Scheme) == 0 {
u.Scheme = defaultHttpScheme
}
if len(target.Prefix) > 0 {
var err error
u.Path, err = url.JoinPath(target.Prefix, u.Path)
if err != nil {
return nil, err
}
}
return &http.Request{
Method: r.Method,
URL: &u,
Header: r.Header.Clone(),
Proto: r.Proto,
ProtoMajor: r.ProtoMajor,
ProtoMinor: r.ProtoMinor,
ContentLength: r.ContentLength,
Body: io.NopCloser(r.Body),
}, nil
}
func createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) {
var source grpcurl.DescriptorSource
var err error
if len(up.ProtoSets) > 0 {
source, err = grpcurl.DescriptorSourceFromProtoSets(up.ProtoSets...)
if err != nil {
return nil, err
}
} else {
client := grpcreflect.NewClientAuto(context.Background(), cli.Conn())
source = grpcurl.DescriptorSourceFromServer(context.Background(), client)
}
return source, nil
}
// withDialer sets a dialer to create a gRPC client.
func withDialer(dialer func(conf zrpc.RpcClientConf) zrpc.Client) func(*Server) {
return func(s *Server) {

View File

@ -2,9 +2,12 @@ package gateway
import (
"context"
"errors"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
@ -65,7 +68,7 @@ func TestMustNewServer(t *testing.T) {
RpcPath: "mock.DepositService/Deposit",
},
},
Grpc: zrpc.RpcClientConf{
Grpc: &zrpc.RpcClientConf{
Endpoints: []string{"foo"},
Timeout: 1000,
Middlewares: zrpc.ClientMiddlewaresConf{
@ -98,7 +101,7 @@ func TestServer_ensureUpstreamNames(t *testing.T) {
var s = Server{
upstreams: []Upstream{
{
Grpc: zrpc.RpcClientConf{
Grpc: &zrpc.RpcClientConf{
Target: "target",
},
},
@ -113,7 +116,7 @@ func TestServer_ensureUpstreamNames_badEtcd(t *testing.T) {
var s = Server{
upstreams: []Upstream{
{
Grpc: zrpc.RpcClientConf{
Grpc: &zrpc.RpcClientConf{
Etcd: discov.EtcdConf{},
},
},
@ -125,3 +128,193 @@ func TestServer_ensureUpstreamNames_badEtcd(t *testing.T) {
s.Start()
})
}
func TestHttpToHttp(t *testing.T) {
server := startTestServer(t)
defer server.Close()
var c GatewayConf
assert.NoError(t, conf.FillDefault(&c))
c.DevServer.Host = "localhost"
c.Host = "localhost"
c.Port = 18882
s := MustNewServer(c)
s.upstreams = []Upstream{
{
Name: "test",
Mappings: []RouteMapping{
{
Method: "get",
Path: "/api/ping",
},
},
Http: &HttpClientConf{
Target: "localhost:45678",
Timeout: 3000,
},
},
{
Mappings: []RouteMapping{
{
Method: "get",
Path: "/ping",
},
},
Http: &HttpClientConf{
Target: "localhost:45678",
Prefix: "/api",
},
},
}
go s.Start()
defer s.Stop()
time.Sleep(time.Millisecond * 200)
t.Run("/api/ping", func(t *testing.T) {
resp, err := httpc.Do(context.Background(), http.MethodGet,
"http://localhost:18882/api/ping", nil)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
if assert.NoError(t, err) {
assert.Equal(t, "pong", string(body))
}
})
t.Run("/ping", func(t *testing.T) {
resp, err := httpc.Do(context.Background(), http.MethodGet,
"http://localhost:18882/ping", nil)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
if assert.NoError(t, err) {
assert.Equal(t, "pong", string(body))
}
})
t.Run("no upstream", func(t *testing.T) {
resp, err := httpc.Do(context.Background(), http.MethodGet,
"http://localhost:18882/ping/bad", nil)
assert.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
})
}
func TestHttpToHttpBadUpstream(t *testing.T) {
var c GatewayConf
assert.NoError(t, conf.FillDefault(&c))
c.DevServer.Host = "localhost"
c.Host = "localhost"
c.Port = 18883
s := MustNewServer(c)
s.upstreams = []Upstream{
{
Mappings: []RouteMapping{
{
Method: "get",
Path: "/api/ping",
},
},
Http: &HttpClientConf{
Target: "localhost:45678",
Prefix: "\x7f/api",
},
},
}
go s.Start()
defer s.Stop()
time.Sleep(time.Millisecond * 200)
t.Run("/api/ping", func(t *testing.T) {
resp, err := httpc.Do(context.Background(), http.MethodGet,
"http://localhost:18883/api/ping", nil)
assert.NoError(t, err)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
}
func TestHttpToHttpBadWriter(t *testing.T) {
t.Run("bad url", func(t *testing.T) {
handler := new(Server).buildHttpHandler(&HttpClientConf{
Target: "http://example.com",
Timeout: 3000,
})
w := httptest.NewRecorder()
handler.ServeHTTP(&badResponseWriter{w},
httptest.NewRequest(http.MethodGet, "http://localhost:18884", nil))
assert.Equal(t, http.StatusBadRequest, w.Code)
})
t.Run("bad url", func(t *testing.T) {
var c GatewayConf
assert.NoError(t, conf.FillDefault(&c))
c.DevServer.Host = "localhost"
c.Host = "localhost"
c.Port = 18884
s := MustNewServer(c)
s.upstreams = []Upstream{
{
Mappings: []RouteMapping{
{
Method: "get",
Path: "/api/ping",
},
},
Http: &HttpClientConf{
Target: "localhost:45678",
Prefix: "\x7f/api",
},
},
}
go s.Start()
defer s.Stop()
handler := new(Server).buildHttpHandler(&HttpClientConf{
Target: "localhost:18884",
Timeout: 3000,
})
w := httptest.NewRecorder()
handler.ServeHTTP(&badResponseWriter{w},
httptest.NewRequest(http.MethodGet, "http://localhost:18884/api/ping", nil))
assert.Equal(t, http.StatusBadRequest, w.Code)
})
}
// Handler function for the root route
func pingHandler(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("pong"))
}
func startTestServer(t *testing.T) *http.Server {
http.HandleFunc("/api/ping", pingHandler)
server := &http.Server{
Addr: ":45678",
Handler: http.DefaultServeMux,
}
go func() {
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
t.Errorf("failed to start server: %v", err)
}
}()
return server
}
type badResponseWriter struct {
http.ResponseWriter
}
func (w *badResponseWriter) Write([]byte) (int, error) {
return 0, errors.New("bad writer")
}

View File

@ -6,15 +6,11 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
)
func TestMetricsInterceptor(t *testing.T) {
c := gomock.NewController(t)
defer c.Finish()
logx.Disable()
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@ -183,7 +183,6 @@ func request(r *http.Request, cli client) (*http.Response, error) {
for i := len(respHandlers) - 1; i >= 0; i-- {
respHandlers[i](resp, err)
}
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())

View File

@ -617,7 +617,7 @@ func NewScanner(filename string, src interface{}) (*Scanner, error) {
}
if len(data) == 0 {
return nil, fmt.Errorf("filename: %s,missing input", filename)
return nil, fmt.Errorf("filename: %s, missing input", filename)
}
var runeList []rune