diff --git a/zrpc/internal/rpcpubserver.go b/zrpc/internal/rpcpubserver.go index 3efde008..70b48132 100644 --- a/zrpc/internal/rpcpubserver.go +++ b/zrpc/internal/rpcpubserver.go @@ -14,7 +14,7 @@ const ( ) // NewRpcPubServer returns a Server. -func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, middlewares ServerMiddlewaresConf, +func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, opts ...ServerOption) (Server, error) { registerEtcd := func() error { pubListenOn := figureOutListenOn(listenOn) @@ -34,7 +34,7 @@ func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, middlewares ServerMi } server := keepAliveServer{ registerEtcd: registerEtcd, - Server: NewRpcServer(listenOn, middlewares, opts...), + Server: NewRpcServer(listenOn, opts...), } return server, nil diff --git a/zrpc/internal/rpcpubserver_test.go b/zrpc/internal/rpcpubserver_test.go index 9c9f397b..cc36e465 100644 --- a/zrpc/internal/rpcpubserver_test.go +++ b/zrpc/internal/rpcpubserver_test.go @@ -13,7 +13,7 @@ func TestNewRpcPubServer(t *testing.T) { User: "user", Pass: "pass", ID: 10, - }, "", ServerMiddlewaresConf{}) + }, "") assert.NoError(t, err) assert.NotPanics(t, func() { s.Start(nil) diff --git a/zrpc/internal/rpcserver.go b/zrpc/internal/rpcserver.go index ed4aa235..c1302f03 100644 --- a/zrpc/internal/rpcserver.go +++ b/zrpc/internal/rpcserver.go @@ -5,9 +5,7 @@ import ( "net" "github.com/zeromicro/go-zero/core/proc" - "github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/internal/health" - "github.com/zeromicro/go-zero/zrpc/internal/serverinterceptors" "google.golang.org/grpc" "google.golang.org/grpc/health/grpc_health_v1" ) @@ -19,38 +17,31 @@ type ( ServerOption func(options *rpcServerOptions) rpcServerOptions struct { - metrics *stat.Metrics - health bool + health bool } rpcServer struct { *baseRpcServer name string - middlewares ServerMiddlewaresConf healthManager health.Probe } ) // NewRpcServer returns a Server. -func NewRpcServer(addr string, middlewares ServerMiddlewaresConf, opts ...ServerOption) Server { +func NewRpcServer(addr string, opts ...ServerOption) Server { var options rpcServerOptions for _, opt := range opts { opt(&options) } - if options.metrics == nil { - options.metrics = stat.NewMetrics(addr) - } return &rpcServer{ baseRpcServer: newBaseRpcServer(addr, &options), - middlewares: middlewares, healthManager: health.NewHealthManager(fmt.Sprintf("%s-%s", probeNamePrefix, addr)), } } func (s *rpcServer) SetName(name string) { s.name = name - s.baseRpcServer.SetName(name) } func (s *rpcServer) Start(register RegisterFn) error { @@ -59,8 +50,8 @@ func (s *rpcServer) Start(register RegisterFn) error { return err } - unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.buildUnaryInterceptors()...) - streamInterceptorOption := grpc.ChainStreamInterceptor(s.buildStreamInterceptors()...) + unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.unaryInterceptors...) + streamInterceptorOption := grpc.ChainStreamInterceptor(s.streamInterceptors...) options := append(s.options, unaryInterceptorOption, streamInterceptorOption) server := grpc.NewServer(options...) @@ -87,52 +78,6 @@ func (s *rpcServer) Start(register RegisterFn) error { return server.Serve(lis) } -func (s *rpcServer) buildStreamInterceptors() []grpc.StreamServerInterceptor { - var interceptors []grpc.StreamServerInterceptor - - if s.middlewares.Trace { - interceptors = append(interceptors, serverinterceptors.StreamTracingInterceptor) - } - if s.middlewares.Recover { - interceptors = append(interceptors, serverinterceptors.StreamRecoverInterceptor) - } - if s.middlewares.Breaker { - interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor) - } - - return append(interceptors, s.streamInterceptors...) -} - -func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor { - var interceptors []grpc.UnaryServerInterceptor - - if s.middlewares.Trace { - interceptors = append(interceptors, serverinterceptors.UnaryTracingInterceptor) - } - if s.middlewares.Recover { - interceptors = append(interceptors, serverinterceptors.UnaryRecoverInterceptor) - } - if s.middlewares.Stat { - interceptors = append(interceptors, - serverinterceptors.UnaryStatInterceptor(s.metrics, s.middlewares.StatConf)) - } - if s.middlewares.Prometheus { - interceptors = append(interceptors, serverinterceptors.UnaryPrometheusInterceptor) - } - if s.middlewares.Breaker { - interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor) - } - - return append(interceptors, s.unaryInterceptors...) -} - -// WithMetrics returns a func that sets metrics to a Server. -func WithMetrics(metrics *stat.Metrics) ServerOption { - return func(options *rpcServerOptions) { - options.metrics = metrics - } -} - // WithRpcHealth returns a func that sets rpc health switch to a Server. func WithRpcHealth(health bool) ServerOption { return func(options *rpcServerOptions) { diff --git a/zrpc/internal/rpcserver_test.go b/zrpc/internal/rpcserver_test.go index 7af40110..696dae68 100644 --- a/zrpc/internal/rpcserver_test.go +++ b/zrpc/internal/rpcserver_test.go @@ -1,27 +1,18 @@ package internal import ( - "context" "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/proc" - "github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/internal/mock" "google.golang.org/grpc" ) func TestRpcServer(t *testing.T) { - metrics := stat.NewMetrics("foo") - server := NewRpcServer("localhost:54321", ServerMiddlewaresConf{ - Trace: true, - Recover: true, - Stat: true, - Prometheus: true, - Breaker: true, - }, WithMetrics(metrics), WithRpcHealth(true)) + server := NewRpcServer("localhost:54321", WithRpcHealth(true)) server.SetName("mock") var wg, wgDone sync.WaitGroup var grpcServer *grpc.Server @@ -52,13 +43,7 @@ func TestRpcServer(t *testing.T) { } func TestRpcServer_WithBadAddress(t *testing.T) { - server := NewRpcServer("localhost:111111", ServerMiddlewaresConf{ - Trace: true, - Recover: true, - Stat: true, - Prometheus: true, - Breaker: true, - }, WithRpcHealth(true)) + server := NewRpcServer("localhost:111111", WithRpcHealth(true)) server.SetName("mock") err := server.Start(func(server *grpc.Server) { mock.RegisterDepositServiceServer(server, new(mock.DepositServer)) @@ -67,115 +52,3 @@ func TestRpcServer_WithBadAddress(t *testing.T) { proc.WrapUp() } - -func TestRpcServer_buildUnaryInterceptor(t *testing.T) { - tests := []struct { - name string - r *rpcServer - len int - }{ - { - name: "empty", - r: &rpcServer{ - baseRpcServer: &baseRpcServer{}, - }, - len: 0, - }, - { - name: "custom", - r: &rpcServer{ - baseRpcServer: &baseRpcServer{ - unaryInterceptors: []grpc.UnaryServerInterceptor{ - func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, - handler grpc.UnaryHandler) (interface{}, error) { - return nil, nil - }, - }, - }, - }, - len: 1, - }, - { - name: "middleware", - r: &rpcServer{ - baseRpcServer: &baseRpcServer{ - unaryInterceptors: []grpc.UnaryServerInterceptor{ - func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, - handler grpc.UnaryHandler) (interface{}, error) { - return nil, nil - }, - }, - }, - middlewares: ServerMiddlewaresConf{ - Trace: true, - Recover: true, - Stat: true, - Prometheus: true, - Breaker: true, - }, - }, - len: 6, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - assert.Equal(t, test.len, len(test.r.buildUnaryInterceptors())) - }) - } -} - -func TestRpcServer_buildStreamInterceptor(t *testing.T) { - tests := []struct { - name string - r *rpcServer - len int - }{ - { - name: "empty", - r: &rpcServer{ - baseRpcServer: &baseRpcServer{}, - }, - len: 0, - }, - { - name: "custom", - r: &rpcServer{ - baseRpcServer: &baseRpcServer{ - streamInterceptors: []grpc.StreamServerInterceptor{ - func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, - handler grpc.StreamHandler) error { - return nil - }, - }, - }, - }, - len: 1, - }, - { - name: "middleware", - r: &rpcServer{ - baseRpcServer: &baseRpcServer{ - streamInterceptors: []grpc.StreamServerInterceptor{ - func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, - handler grpc.StreamHandler) error { - return nil - }, - }, - }, - middlewares: ServerMiddlewaresConf{ - Trace: true, - Recover: true, - Breaker: true, - }, - }, - len: 4, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - assert.Equal(t, test.len, len(test.r.buildStreamInterceptors())) - }) - } -} diff --git a/zrpc/internal/server.go b/zrpc/internal/server.go index ad405b34..fc9eea0c 100644 --- a/zrpc/internal/server.go +++ b/zrpc/internal/server.go @@ -3,7 +3,6 @@ package internal import ( "time" - "github.com/zeromicro/go-zero/core/stat" "google.golang.org/grpc" "google.golang.org/grpc/health" "google.golang.org/grpc/keepalive" @@ -27,7 +26,6 @@ type ( baseRpcServer struct { address string health *health.Server - metrics *stat.Metrics options []grpc.ServerOption streamInterceptors []grpc.StreamServerInterceptor unaryInterceptors []grpc.UnaryServerInterceptor @@ -42,7 +40,6 @@ func newBaseRpcServer(address string, rpcServerOpts *rpcServerOptions) *baseRpcS return &baseRpcServer{ address: address, health: h, - metrics: rpcServerOpts.metrics, options: []grpc.ServerOption{grpc.KeepaliveParams(keepalive.ServerParameters{ MaxConnectionIdle: defaultConnectionIdleDuration, })}, @@ -60,7 +57,3 @@ func (s *baseRpcServer) AddStreamInterceptors(interceptors ...grpc.StreamServerI func (s *baseRpcServer) AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) { s.unaryInterceptors = append(s.unaryInterceptors, interceptors...) } - -func (s *baseRpcServer) SetName(name string) { - s.metrics.SetName(name) -} diff --git a/zrpc/internal/server_test.go b/zrpc/internal/server_test.go index fa48f7b4..a2fc926e 100644 --- a/zrpc/internal/server_test.go +++ b/zrpc/internal/server_test.go @@ -5,23 +5,18 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/stat" "google.golang.org/grpc" ) func TestBaseRpcServer_AddOptions(t *testing.T) { - metrics := stat.NewMetrics("foo") - server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics}) - server.SetName("bar") + server := newBaseRpcServer("foo", &rpcServerOptions{}) var opt grpc.EmptyServerOption server.AddOptions(opt) assert.Contains(t, server.options, opt) } func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) { - metrics := stat.NewMetrics("foo") - server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics}) - server.SetName("bar") + server := newBaseRpcServer("foo", &rpcServerOptions{}) var vals []int f := func(_ any, _ grpc.ServerStream, _ *grpc.StreamServerInfo, _ grpc.StreamHandler) error { vals = append(vals, 1) @@ -35,9 +30,7 @@ func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) { } func TestBaseRpcServer_AddUnaryInterceptors(t *testing.T) { - metrics := stat.NewMetrics("foo") - server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics}) - server.SetName("bar") + server := newBaseRpcServer("foo", &rpcServerOptions{}) var vals []int f := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( resp any, err error) { diff --git a/zrpc/server.go b/zrpc/server.go index d891c8e6..813fc358 100644 --- a/zrpc/server.go +++ b/zrpc/server.go @@ -36,21 +36,23 @@ func NewServer(c RpcServerConf, register internal.RegisterFn) (*RpcServer, error var server internal.Server metrics := stat.NewMetrics(c.ListenOn) serverOptions := []internal.ServerOption{ - internal.WithMetrics(metrics), internal.WithRpcHealth(c.Health), } if c.HasEtcd() { - server, err = internal.NewRpcPubServer(c.Etcd, c.ListenOn, c.Middlewares, serverOptions...) + server, err = internal.NewRpcPubServer(c.Etcd, c.ListenOn, serverOptions...) if err != nil { return nil, err } } else { - server = internal.NewRpcServer(c.ListenOn, c.Middlewares, serverOptions...) + server = internal.NewRpcServer(c.ListenOn, serverOptions...) } server.SetName(c.Name) - if err = setupInterceptors(server, c, metrics); err != nil { + metrics.SetName(c.Name) + setupStreamInterceptors(server, c) + setupUnaryInterceptors(server, c, metrics) + if err = setupAuthInterceptors(server, c); err != nil { return nil, err } @@ -108,6 +110,9 @@ func SetServerSlowThreshold(threshold time.Duration) { } func setupAuthInterceptors(svr internal.Server, c RpcServerConf) error { + if !c.Auth { + return nil + } rds, err := redis.NewRedis(c.Redis.RedisConf) if err != nil { return err @@ -124,22 +129,40 @@ func setupAuthInterceptors(svr internal.Server, c RpcServerConf) error { return nil } -func setupInterceptors(svr internal.Server, c RpcServerConf, metrics *stat.Metrics) error { +func setupStreamInterceptors(svr internal.Server, c RpcServerConf) { + if c.Middlewares.Trace { + svr.AddStreamInterceptors(serverinterceptors.StreamTracingInterceptor) + } + if c.Middlewares.Recover { + svr.AddStreamInterceptors(serverinterceptors.StreamRecoverInterceptor) + } + if c.Middlewares.Breaker { + svr.AddStreamInterceptors(serverinterceptors.StreamBreakerInterceptor) + } +} + +func setupUnaryInterceptors(svr internal.Server, c RpcServerConf, metrics *stat.Metrics) { + if c.Middlewares.Trace { + svr.AddUnaryInterceptors(serverinterceptors.UnaryTracingInterceptor) + } + if c.Middlewares.Recover { + svr.AddUnaryInterceptors(serverinterceptors.UnaryRecoverInterceptor) + } + if c.Middlewares.Stat { + svr.AddUnaryInterceptors(serverinterceptors.UnaryStatInterceptor(metrics, c.Middlewares.StatConf)) + } + if c.Middlewares.Prometheus { + svr.AddUnaryInterceptors(serverinterceptors.UnaryPrometheusInterceptor) + } + if c.Middlewares.Breaker { + svr.AddUnaryInterceptors(serverinterceptors.UnaryBreakerInterceptor) + } if c.CpuThreshold > 0 { shedder := load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold)) svr.AddUnaryInterceptors(serverinterceptors.UnarySheddingInterceptor(shedder, metrics)) } - if c.Timeout > 0 { svr.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor( time.Duration(c.Timeout)*time.Millisecond, c.MethodTimeouts...)) } - - if c.Auth { - if err := setupAuthInterceptors(svr, c); err != nil { - return err - } - } - - return nil } diff --git a/zrpc/server_test.go b/zrpc/server_test.go index e99f224f..e42e379a 100644 --- a/zrpc/server_test.go +++ b/zrpc/server_test.go @@ -1,6 +1,7 @@ package zrpc import ( + "context" "testing" "time" @@ -16,47 +17,6 @@ import ( "google.golang.org/grpc" ) -func TestServer_setupInterceptors(t *testing.T) { - rds, err := miniredis.Run() - assert.NoError(t, err) - defer rds.Close() - - server := new(mockedServer) - conf := RpcServerConf{ - Auth: true, - Redis: redis.RedisKeyConf{ - RedisConf: redis.RedisConf{ - Host: rds.Addr(), - Type: redis.NodeType, - }, - Key: "foo", - }, - CpuThreshold: 10, - Timeout: 100, - Middlewares: ServerMiddlewaresConf{ - Trace: true, - Recover: true, - Stat: true, - Prometheus: true, - Breaker: true, - }, - MethodTimeouts: []MethodTimeoutConf{ - { - FullMethod: "/foo", - Timeout: 5 * time.Second, - }, - }, - } - err = setupInterceptors(server, conf, new(stat.Metrics)) - assert.Nil(t, err) - assert.Equal(t, 3, len(server.unaryInterceptors)) - assert.Equal(t, 1, len(server.streamInterceptors)) - - rds.SetError("mock error") - err = setupInterceptors(server, conf, new(stat.Metrics)) - assert.Error(t, err) -} - func TestServer(t *testing.T) { DontLogContentForMethod("foo") SetServerSlowThreshold(time.Second) @@ -198,3 +158,153 @@ func (m *mockedServer) SetName(_ string) { func (m *mockedServer) Start(_ internal.RegisterFn) error { return nil } + +func Test_setupUnaryInterceptors(t *testing.T) { + tests := []struct { + name string + r *mockedServer + conf RpcServerConf + len int + }{ + { + name: "empty", + r: &mockedServer{}, + len: 0, + }, + { + name: "custom", + r: &mockedServer{ + unaryInterceptors: []grpc.UnaryServerInterceptor{ + func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (interface{}, error) { + return nil, nil + }, + }, + }, + len: 1, + }, + { + name: "middleware", + r: &mockedServer{}, + conf: RpcServerConf{ + Middlewares: ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }, + }, + len: 5, + }, + { + name: "internal middleware", + r: &mockedServer{}, + conf: RpcServerConf{ + CpuThreshold: 900, + Timeout: 100, + Middlewares: ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }, + }, + len: 7, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + metrics := stat.NewMetrics("abc") + setupUnaryInterceptors(test.r, test.conf, metrics) + assert.Equal(t, test.len, len(test.r.unaryInterceptors)) + }) + } +} + +func Test_setupStreamInterceptors(t *testing.T) { + tests := []struct { + name string + r *mockedServer + conf RpcServerConf + len int + }{ + { + name: "empty", + r: &mockedServer{}, + len: 0, + }, + { + name: "custom", + r: &mockedServer{ + streamInterceptors: []grpc.StreamServerInterceptor{ + func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + return handler(srv, ss) + }, + }, + }, + len: 1, + }, + { + name: "middleware", + r: &mockedServer{}, + conf: RpcServerConf{ + Middlewares: ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }, + }, + len: 3, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + setupStreamInterceptors(test.r, test.conf) + assert.Equal(t, test.len, len(test.r.streamInterceptors)) + }) + } +} + +func Test_setupAuthInterceptors(t *testing.T) { + t.Run("no need set auth", func(t *testing.T) { + s := &mockedServer{} + err := setupAuthInterceptors(s, RpcServerConf{ + Auth: false, + Redis: redis.RedisKeyConf{}, + }) + assert.NoError(t, err) + }) + + t.Run("redis error", func(t *testing.T) { + s := &mockedServer{} + err := setupAuthInterceptors(s, RpcServerConf{ + Auth: true, + Redis: redis.RedisKeyConf{}, + }) + assert.Error(t, err) + }) + + t.Run("works", func(t *testing.T) { + rds := miniredis.RunT(t) + s := &mockedServer{} + err := setupAuthInterceptors(s, RpcServerConf{ + Auth: true, + Redis: redis.RedisKeyConf{ + RedisConf: redis.RedisConf{ + Host: rds.Addr(), + Type: redis.NodeType, + }, + Key: "foo", + }, + }) + assert.NoError(t, err) + assert.Equal(t, 1, len(s.unaryInterceptors)) + assert.Equal(t, 1, len(s.streamInterceptors)) + }) +}