chore: add unit tests (#1615)

* test: add more tests

* test: add more tests
This commit is contained in:
Kevin Wan 2022-03-04 17:54:09 +08:00 committed by GitHub
parent 60760b52ab
commit 3b7ca86e4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 108 additions and 69 deletions

View File

@ -90,14 +90,14 @@ func TestParseFullMethod(t *testing.T) {
semconv.RPCMethodKey.String("theMethod"), semconv.RPCMethodKey.String("theMethod"),
}, },
}, { }, {
fullMethod: "/pkg.srv", fullMethod: "/pkg.svr",
name: "pkg.srv", name: "pkg.svr",
attr: []attribute.KeyValue(nil), attr: []attribute.KeyValue(nil),
}, { }, {
fullMethod: "/pkg.srv/", fullMethod: "/pkg.svr/",
name: "pkg.srv/", name: "pkg.svr/",
attr: []attribute.KeyValue{ attr: []attribute.KeyValue{
semconv.RPCServiceKey.String("pkg.srv"), semconv.RPCServiceKey.String("pkg.svr"),
}, },
}, },
} }

View File

@ -35,16 +35,16 @@ type engine struct {
} }
func newEngine(c RestConf) *engine { func newEngine(c RestConf) *engine {
srv := &engine{ svr := &engine{
conf: c, conf: c,
} }
if c.CpuThreshold > 0 { if c.CpuThreshold > 0 {
srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold)) svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold( svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
(c.CpuThreshold + topCpuUsage) >> 1)) (c.CpuThreshold + topCpuUsage) >> 1))
} }
return srv return svr
} }
func (ng *engine) addRoutes(r featuredRoutes) { func (ng *engine) addRoutes(r featuredRoutes) {
@ -238,9 +238,9 @@ func (ng *engine) start(router httpx.Router) error {
} }
return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile, return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
ng.conf.KeyFile, router, func(srv *http.Server) { ng.conf.KeyFile, router, func(svr *http.Server) {
if ng.tlsConfig != nil { if ng.tlsConfig != nil {
srv.TLSConfig = ng.tlsConfig svr.TLSConfig = ng.tlsConfig
} }
}) })
} }

View File

@ -36,3 +36,8 @@ func TestError(t *testing.T) {
assert.True(t, strings.Contains(val, "third")) assert.True(t, strings.Contains(val, "third"))
assert.True(t, strings.Contains(val, "\n")) assert.True(t, strings.Contains(val, "\n"))
} }
func TestContextKey_String(t *testing.T) {
val := contextKey("foo")
assert.True(t, strings.Contains(val.String(), "foo"))
}

View File

@ -151,6 +151,8 @@ func TestContentSecurity(t *testing.T) {
return return
} }
encrypted := test.mode != "0"
assert.Equal(t, encrypted, header.Encrypted())
assert.Equal(t, test.code, VerifySignature(r, header, time.Minute)) assert.Equal(t, test.code, VerifySignature(r, header, time.Minute))
}) })
} }

View File

@ -10,25 +10,25 @@ import (
) )
// StartOption defines the method to customize http.Server. // StartOption defines the method to customize http.Server.
type StartOption func(srv *http.Server) type StartOption func(svr *http.Server)
// StartHttp starts a http server. // StartHttp starts a http server.
func StartHttp(host string, port int, handler http.Handler, opts ...StartOption) error { func StartHttp(host string, port int, handler http.Handler, opts ...StartOption) error {
return start(host, port, handler, func(srv *http.Server) error { return start(host, port, handler, func(svr *http.Server) error {
return srv.ListenAndServe() return svr.ListenAndServe()
}, opts...) }, opts...)
} }
// StartHttps starts a https server. // StartHttps starts a https server.
func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler, func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler,
opts ...StartOption) error { opts ...StartOption) error {
return start(host, port, handler, func(srv *http.Server) error { return start(host, port, handler, func(svr *http.Server) error {
// certFile and keyFile are set in buildHttpsServer // certFile and keyFile are set in buildHttpsServer
return srv.ListenAndServeTLS(certFile, keyFile) return svr.ListenAndServeTLS(certFile, keyFile)
}, opts...) }, opts...)
} }
func start(host string, port int, handler http.Handler, run func(srv *http.Server) error, func start(host string, port int, handler http.Handler, run func(svr *http.Server) error,
opts ...StartOption) (err error) { opts ...StartOption) (err error) {
server := &http.Server{ server := &http.Server{
Addr: fmt.Sprintf("%s:%d", host, port), Addr: fmt.Sprintf("%s:%d", host, port),

View File

@ -0,0 +1,33 @@
package internal
import (
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestStartHttp(t *testing.T) {
svr := httptest.NewUnstartedServer(http.NotFoundHandler())
fields := strings.Split(svr.Listener.Addr().String(), ":")
port, err := strconv.Atoi(fields[1])
assert.Nil(t, err)
err = StartHttp(fields[0], port, http.NotFoundHandler(), func(svr *http.Server) {
svr.IdleTimeout = 0
})
assert.NotNil(t, err)
}
func TestStartHttps(t *testing.T) {
svr := httptest.NewTLSServer(http.NotFoundHandler())
fields := strings.Split(svr.Listener.Addr().String(), ":")
port, err := strconv.Atoi(fields[1])
assert.Nil(t, err)
err = StartHttps(fields[0], port, "", "", http.NotFoundHandler(), func(svr *http.Server) {
svr.IdleTimeout = 0
})
assert.NotNil(t, err)
}

View File

@ -225,22 +225,22 @@ func WithTimeout(timeout time.Duration) RouteOption {
// WithTLSConfig returns a RunOption that with given tls config. // WithTLSConfig returns a RunOption that with given tls config.
func WithTLSConfig(cfg *tls.Config) RunOption { func WithTLSConfig(cfg *tls.Config) RunOption {
return func(srv *Server) { return func(svr *Server) {
srv.ngin.setTlsConfig(cfg) svr.ngin.setTlsConfig(cfg)
} }
} }
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
return func(srv *Server) { return func(svr *Server) {
srv.ngin.setUnauthorizedCallback(callback) svr.ngin.setUnauthorizedCallback(callback)
} }
} }
// WithUnsignedCallback returns a RunOption that with given unsigned callback set. // WithUnsignedCallback returns a RunOption that with given unsigned callback set.
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
return func(srv *Server) { return func(svr *Server) {
srv.ngin.setUnsignedCallback(callback) svr.ngin.setUnsignedCallback(callback)
} }
} }

View File

@ -56,22 +56,22 @@ Port: 54321
} }
for _, test := range tests { for _, test := range tests {
var srv *Server var svr *Server
var err error var err error
if test.fail { if test.fail {
_, err = NewServer(test.c, test.opts...) _, err = NewServer(test.c, test.opts...)
assert.NotNil(t, err) assert.NotNil(t, err)
continue continue
} else { } else {
srv = MustNewServer(test.c, test.opts...) svr = MustNewServer(test.c, test.opts...)
} }
srv.Use(ToMiddleware(func(next http.Handler) http.Handler { svr.Use(ToMiddleware(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
})) }))
srv.AddRoute(Route{ svr.AddRoute(Route{
Method: http.MethodGet, Method: http.MethodGet,
Path: "/", Path: "/",
Handler: nil, Handler: nil,
@ -89,8 +89,8 @@ Port: 54321
} }
}() }()
srv.Start() svr.Start()
srv.Stop() svr.Stop()
}() }()
} }
} }
@ -290,9 +290,9 @@ Port: 54321
} }
for _, testCase := range testCases { for _, testCase := range testCases {
srv, err := NewServer(testCase.c, testCase.opts...) svr, err := NewServer(testCase.c, testCase.opts...)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, srv.ngin.tlsConfig, testCase.res) assert.Equal(t, svr.ngin.tlsConfig, testCase.res)
} }
} }
@ -304,11 +304,11 @@ Port: 54321
var cnf RestConf var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf)) assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
rt := router.NewRouter() rt := router.NewRouter()
srv, err := NewServer(cnf, WithRouter(rt)) svr, err := NewServer(cnf, WithRouter(rt))
assert.Nil(t, err) assert.Nil(t, err)
opt := WithCors("local") opt := WithCors("local")
opt(srv) opt(svr)
} }
func TestWithCustomCors(t *testing.T) { func TestWithCustomCors(t *testing.T) {
@ -319,7 +319,7 @@ Port: 54321
var cnf RestConf var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf)) assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
rt := router.NewRouter() rt := router.NewRouter()
srv, err := NewServer(cnf, WithRouter(rt)) svr, err := NewServer(cnf, WithRouter(rt))
assert.Nil(t, err) assert.Nil(t, err)
opt := WithCustomCors(func(header http.Header) { opt := WithCustomCors(func(header http.Header) {
@ -327,5 +327,5 @@ Port: 54321
}, func(w http.ResponseWriter) { }, func(w http.ResponseWriter) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}, "local") }, "local")
opt(srv) opt(svr)
} }

View File

@ -36,10 +36,10 @@ func main() {
var c config.Config var c config.Config
conf.MustLoad(*configFile, &c) conf.MustLoad(*configFile, &c)
ctx := svc.NewServiceContext(c) ctx := svc.NewServiceContext(c)
srv := server.New{{.serviceNew}}Server(ctx) svr := server.New{{.serviceNew}}Server(ctx)
s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) { s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
{{.pkg}}.Register{{.service}}Server(grpcServer, srv) {{.pkg}}.Register{{.service}}Server(grpcServer, svr)
if c.Mode == service.DevMode || c.Mode == service.TestMode { if c.Mode == service.DevMode || c.Mode == service.TestMode {
reflection.Register(grpcServer) reflection.Register(grpcServer)

View File

@ -23,7 +23,7 @@ func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) {
server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics}) server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics})
server.SetName("bar") server.SetName("bar")
var vals []int var vals []int
f := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { f := func(_ interface{}, _ grpc.ServerStream, _ *grpc.StreamServerInfo, _ grpc.StreamHandler) error {
vals = append(vals, 1) vals = append(vals, 1)
return nil return nil
} }

View File

@ -9,13 +9,13 @@ import (
// StreamAuthorizeInterceptor returns a func that uses given authenticator in processing stream requests. // StreamAuthorizeInterceptor returns a func that uses given authenticator in processing stream requests.
func StreamAuthorizeInterceptor(authenticator *auth.Authenticator) grpc.StreamServerInterceptor { func StreamAuthorizeInterceptor(authenticator *auth.Authenticator) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, return func(svr interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error { handler grpc.StreamHandler) error {
if err := authenticator.Authenticate(stream.Context()); err != nil { if err := authenticator.Authenticate(stream.Context()); err != nil {
return err return err
} }
return handler(srv, stream) return handler(svr, stream)
} }
} }

View File

@ -65,7 +65,7 @@ func TestStreamAuthorizeInterceptor(t *testing.T) {
}) })
ctx := metadata.NewIncomingContext(context.Background(), md) ctx := metadata.NewIncomingContext(context.Background(), md)
stream := mockedStream{ctx: ctx} stream := mockedStream{ctx: ctx}
err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error { err = interceptor(nil, stream, nil, func(_ interface{}, _ grpc.ServerStream) error {
return nil return nil
}) })
if test.hasError { if test.hasError {

View File

@ -9,11 +9,11 @@ import (
) )
// StreamBreakerInterceptor is an interceptor that acts as a circuit breaker. // StreamBreakerInterceptor is an interceptor that acts as a circuit breaker.
func StreamBreakerInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, func StreamBreakerInterceptor(svr interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) (err error) { handler grpc.StreamHandler) (err error) {
breakerName := info.FullMethod breakerName := info.FullMethod
return breaker.DoWithAcceptable(breakerName, func() error { return breaker.DoWithAcceptable(breakerName, func() error {
return handler(srv, stream) return handler(svr, stream)
}, codes.Acceptable) }, codes.Acceptable)
} }

View File

@ -13,8 +13,7 @@ import (
func TestStreamBreakerInterceptor(t *testing.T) { func TestStreamBreakerInterceptor(t *testing.T) {
err := StreamBreakerInterceptor(nil, nil, &grpc.StreamServerInfo{ err := StreamBreakerInterceptor(nil, nil, &grpc.StreamServerInfo{
FullMethod: "any", FullMethod: "any",
}, func( }, func(_ interface{}, _ grpc.ServerStream) error {
srv interface{}, stream grpc.ServerStream) error {
return status.New(codes.DeadlineExceeded, "any").Err() return status.New(codes.DeadlineExceeded, "any").Err()
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
@ -23,7 +22,7 @@ func TestStreamBreakerInterceptor(t *testing.T) {
func TestUnaryBreakerInterceptor(t *testing.T) { func TestUnaryBreakerInterceptor(t *testing.T) {
_, err := UnaryBreakerInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{ _, err := UnaryBreakerInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
FullMethod: "any", FullMethod: "any",
}, func(ctx context.Context, req interface{}) (interface{}, error) { }, func(_ context.Context, _ interface{}) (interface{}, error) {
return nil, status.New(codes.DeadlineExceeded, "any").Err() return nil, status.New(codes.DeadlineExceeded, "any").Err()
}) })
assert.NotNil(t, err) assert.NotNil(t, err)

View File

@ -11,17 +11,17 @@ import (
) )
// StreamCrashInterceptor catches panics in processing stream requests and recovers. // StreamCrashInterceptor catches panics in processing stream requests and recovers.
func StreamCrashInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, func StreamCrashInterceptor(svr interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
handler grpc.StreamHandler) (err error) { handler grpc.StreamHandler) (err error) {
defer handleCrash(func(r interface{}) { defer handleCrash(func(r interface{}) {
err = toPanicError(r) err = toPanicError(r)
}) })
return handler(srv, stream) return handler(svr, stream)
} }
// UnaryCrashInterceptor catches panics in processing unary requests and recovers. // UnaryCrashInterceptor catches panics in processing unary requests and recovers.
func UnaryCrashInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, func UnaryCrashInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (resp interface{}, err error) { handler grpc.UnaryHandler) (resp interface{}, err error) {
defer handleCrash(func(r interface{}) { defer handleCrash(func(r interface{}) {
err = toPanicError(r) err = toPanicError(r)

View File

@ -15,7 +15,7 @@ func init() {
func TestStreamCrashInterceptor(t *testing.T) { func TestStreamCrashInterceptor(t *testing.T) {
err := StreamCrashInterceptor(nil, nil, nil, func( err := StreamCrashInterceptor(nil, nil, nil, func(
srv interface{}, stream grpc.ServerStream) error { svr interface{}, stream grpc.ServerStream) error {
panic("mock panic") panic("mock panic")
}) })
assert.NotNil(t, err) assert.NotNil(t, err)

View File

@ -41,12 +41,12 @@ func UnaryTracingInterceptor(ctx context.Context, req interface{}, info *grpc.Un
} }
// StreamTracingInterceptor returns a grpc.StreamServerInterceptor for opentelemetry. // StreamTracingInterceptor returns a grpc.StreamServerInterceptor for opentelemetry.
func StreamTracingInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, func StreamTracingInterceptor(svr interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error { handler grpc.StreamHandler) error {
ctx, span := startSpan(ss.Context(), info.FullMethod) ctx, span := startSpan(ss.Context(), info.FullMethod)
defer span.End() defer span.End()
if err := handler(srv, wrapServerStream(ctx, ss)); err != nil { if err := handler(svr, wrapServerStream(ctx, ss)); err != nil {
s, ok := status.FromError(err) s, ok := status.FromError(err)
if ok { if ok {
span.SetStatus(codes.Error, s.Message()) span.SetStatus(codes.Error, s.Message())

View File

@ -101,7 +101,7 @@ func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
stream := mockedServerStream{ctx: ctx} stream := mockedServerStream{ctx: ctx}
err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{ err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
FullMethod: "/foo", FullMethod: "/foo",
}, func(srv interface{}, stream grpc.ServerStream) error { }, func(svr interface{}, stream grpc.ServerStream) error {
defer wg.Done() defer wg.Done()
atomic.AddInt32(&run, 1) atomic.AddInt32(&run, 1)
return nil return nil
@ -138,7 +138,7 @@ func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
stream := mockedServerStream{ctx: ctx} stream := mockedServerStream{ctx: ctx}
err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{ err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
FullMethod: "/foo", FullMethod: "/foo",
}, func(srv interface{}, stream grpc.ServerStream) error { }, func(svr interface{}, stream grpc.ServerStream) error {
defer wg.Done() defer wg.Done()
return test.err return test.err
}) })
@ -175,7 +175,7 @@ func TestStreamTracingInterceptor_WithError(t *testing.T) {
stream := mockedServerStream{ctx: ctx} stream := mockedServerStream{ctx: ctx}
err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{ err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
FullMethod: "/foo", FullMethod: "/foo",
}, func(srv interface{}, stream grpc.ServerStream) error { }, func(svr interface{}, stream grpc.ServerStream) error {
defer wg.Done() defer wg.Done()
return test.err return test.err
}) })

View File

@ -36,7 +36,7 @@ func TestServer_setupInterceptors(t *testing.T) {
func TestServer(t *testing.T) { func TestServer(t *testing.T) {
SetServerSlowThreshold(time.Second) SetServerSlowThreshold(time.Second)
srv := MustNewServer(RpcServerConf{ svr := MustNewServer(RpcServerConf{
ServiceConf: service.ServiceConf{ ServiceConf: service.ServiceConf{
Log: logx.LogConf{ Log: logx.LogConf{
ServiceName: "foo", ServiceName: "foo",
@ -52,11 +52,11 @@ func TestServer(t *testing.T) {
CpuThreshold: 0, CpuThreshold: 0,
}, func(server *grpc.Server) { }, func(server *grpc.Server) {
}) })
srv.AddOptions(grpc.ConnectionTimeout(time.Hour)) svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
srv.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor) svr.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor)
srv.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor) svr.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor)
go srv.Start() go svr.Start()
srv.Stop() svr.Stop()
} }
func TestServerError(t *testing.T) { func TestServerError(t *testing.T) {
@ -79,7 +79,7 @@ func TestServerError(t *testing.T) {
} }
func TestServer_HasEtcd(t *testing.T) { func TestServer_HasEtcd(t *testing.T) {
srv := MustNewServer(RpcServerConf{ svr := MustNewServer(RpcServerConf{
ServiceConf: service.ServiceConf{ ServiceConf: service.ServiceConf{
Log: logx.LogConf{ Log: logx.LogConf{
ServiceName: "foo", ServiceName: "foo",
@ -94,15 +94,15 @@ func TestServer_HasEtcd(t *testing.T) {
Redis: redis.RedisKeyConf{}, Redis: redis.RedisKeyConf{},
}, func(server *grpc.Server) { }, func(server *grpc.Server) {
}) })
srv.AddOptions(grpc.ConnectionTimeout(time.Hour)) svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
srv.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor) svr.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor)
srv.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor) svr.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor)
go srv.Start() go svr.Start()
srv.Stop() svr.Stop()
} }
func TestServer_StartFailed(t *testing.T) { func TestServer_StartFailed(t *testing.T) {
srv := MustNewServer(RpcServerConf{ svr := MustNewServer(RpcServerConf{
ServiceConf: service.ServiceConf{ ServiceConf: service.ServiceConf{
Log: logx.LogConf{ Log: logx.LogConf{
ServiceName: "foo", ServiceName: "foo",
@ -113,7 +113,7 @@ func TestServer_StartFailed(t *testing.T) {
}, func(server *grpc.Server) { }, func(server *grpc.Server) {
}) })
assert.Panics(t, srv.Start) assert.Panics(t, svr.Start)
} }
type mockedServer struct { type mockedServer struct {