mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-01-23 09:00:20 +08:00
chore: add unit tests (#1615)
* test: add more tests * test: add more tests
This commit is contained in:
parent
60760b52ab
commit
3b7ca86e4f
@ -90,14 +90,14 @@ func TestParseFullMethod(t *testing.T) {
|
||||
semconv.RPCMethodKey.String("theMethod"),
|
||||
},
|
||||
}, {
|
||||
fullMethod: "/pkg.srv",
|
||||
name: "pkg.srv",
|
||||
fullMethod: "/pkg.svr",
|
||||
name: "pkg.svr",
|
||||
attr: []attribute.KeyValue(nil),
|
||||
}, {
|
||||
fullMethod: "/pkg.srv/",
|
||||
name: "pkg.srv/",
|
||||
fullMethod: "/pkg.svr/",
|
||||
name: "pkg.svr/",
|
||||
attr: []attribute.KeyValue{
|
||||
semconv.RPCServiceKey.String("pkg.srv"),
|
||||
semconv.RPCServiceKey.String("pkg.svr"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -35,16 +35,16 @@ type engine struct {
|
||||
}
|
||||
|
||||
func newEngine(c RestConf) *engine {
|
||||
srv := &engine{
|
||||
svr := &engine{
|
||||
conf: c,
|
||||
}
|
||||
if c.CpuThreshold > 0 {
|
||||
srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
|
||||
srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
|
||||
svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
|
||||
svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
|
||||
(c.CpuThreshold + topCpuUsage) >> 1))
|
||||
}
|
||||
|
||||
return srv
|
||||
return svr
|
||||
}
|
||||
|
||||
func (ng *engine) addRoutes(r featuredRoutes) {
|
||||
@ -238,9 +238,9 @@ func (ng *engine) start(router httpx.Router) error {
|
||||
}
|
||||
|
||||
return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
|
||||
ng.conf.KeyFile, router, func(srv *http.Server) {
|
||||
ng.conf.KeyFile, router, func(svr *http.Server) {
|
||||
if ng.tlsConfig != nil {
|
||||
srv.TLSConfig = ng.tlsConfig
|
||||
svr.TLSConfig = ng.tlsConfig
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -36,3 +36,8 @@ func TestError(t *testing.T) {
|
||||
assert.True(t, strings.Contains(val, "third"))
|
||||
assert.True(t, strings.Contains(val, "\n"))
|
||||
}
|
||||
|
||||
func TestContextKey_String(t *testing.T) {
|
||||
val := contextKey("foo")
|
||||
assert.True(t, strings.Contains(val.String(), "foo"))
|
||||
}
|
||||
|
@ -151,6 +151,8 @@ func TestContentSecurity(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
encrypted := test.mode != "0"
|
||||
assert.Equal(t, encrypted, header.Encrypted())
|
||||
assert.Equal(t, test.code, VerifySignature(r, header, time.Minute))
|
||||
})
|
||||
}
|
||||
|
@ -10,25 +10,25 @@ import (
|
||||
)
|
||||
|
||||
// StartOption defines the method to customize http.Server.
|
||||
type StartOption func(srv *http.Server)
|
||||
type StartOption func(svr *http.Server)
|
||||
|
||||
// StartHttp starts a http server.
|
||||
func StartHttp(host string, port int, handler http.Handler, opts ...StartOption) error {
|
||||
return start(host, port, handler, func(srv *http.Server) error {
|
||||
return srv.ListenAndServe()
|
||||
return start(host, port, handler, func(svr *http.Server) error {
|
||||
return svr.ListenAndServe()
|
||||
}, opts...)
|
||||
}
|
||||
|
||||
// StartHttps starts a https server.
|
||||
func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler,
|
||||
opts ...StartOption) error {
|
||||
return start(host, port, handler, func(srv *http.Server) error {
|
||||
return start(host, port, handler, func(svr *http.Server) error {
|
||||
// certFile and keyFile are set in buildHttpsServer
|
||||
return srv.ListenAndServeTLS(certFile, keyFile)
|
||||
return svr.ListenAndServeTLS(certFile, keyFile)
|
||||
}, opts...)
|
||||
}
|
||||
|
||||
func start(host string, port int, handler http.Handler, run func(srv *http.Server) error,
|
||||
func start(host string, port int, handler http.Handler, run func(svr *http.Server) error,
|
||||
opts ...StartOption) (err error) {
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", host, port),
|
||||
|
33
rest/internal/starter_test.go
Normal file
33
rest/internal/starter_test.go
Normal file
@ -0,0 +1,33 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStartHttp(t *testing.T) {
|
||||
svr := httptest.NewUnstartedServer(http.NotFoundHandler())
|
||||
fields := strings.Split(svr.Listener.Addr().String(), ":")
|
||||
port, err := strconv.Atoi(fields[1])
|
||||
assert.Nil(t, err)
|
||||
err = StartHttp(fields[0], port, http.NotFoundHandler(), func(svr *http.Server) {
|
||||
svr.IdleTimeout = 0
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestStartHttps(t *testing.T) {
|
||||
svr := httptest.NewTLSServer(http.NotFoundHandler())
|
||||
fields := strings.Split(svr.Listener.Addr().String(), ":")
|
||||
port, err := strconv.Atoi(fields[1])
|
||||
assert.Nil(t, err)
|
||||
err = StartHttps(fields[0], port, "", "", http.NotFoundHandler(), func(svr *http.Server) {
|
||||
svr.IdleTimeout = 0
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
@ -225,22 +225,22 @@ func WithTimeout(timeout time.Duration) RouteOption {
|
||||
|
||||
// WithTLSConfig returns a RunOption that with given tls config.
|
||||
func WithTLSConfig(cfg *tls.Config) RunOption {
|
||||
return func(srv *Server) {
|
||||
srv.ngin.setTlsConfig(cfg)
|
||||
return func(svr *Server) {
|
||||
svr.ngin.setTlsConfig(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
|
||||
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
||||
return func(srv *Server) {
|
||||
srv.ngin.setUnauthorizedCallback(callback)
|
||||
return func(svr *Server) {
|
||||
svr.ngin.setUnauthorizedCallback(callback)
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnsignedCallback returns a RunOption that with given unsigned callback set.
|
||||
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
|
||||
return func(srv *Server) {
|
||||
srv.ngin.setUnsignedCallback(callback)
|
||||
return func(svr *Server) {
|
||||
svr.ngin.setUnsignedCallback(callback)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -56,22 +56,22 @@ Port: 54321
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
var srv *Server
|
||||
var svr *Server
|
||||
var err error
|
||||
if test.fail {
|
||||
_, err = NewServer(test.c, test.opts...)
|
||||
assert.NotNil(t, err)
|
||||
continue
|
||||
} else {
|
||||
srv = MustNewServer(test.c, test.opts...)
|
||||
svr = MustNewServer(test.c, test.opts...)
|
||||
}
|
||||
|
||||
srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
|
||||
svr.Use(ToMiddleware(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}))
|
||||
srv.AddRoute(Route{
|
||||
svr.AddRoute(Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: nil,
|
||||
@ -89,8 +89,8 @@ Port: 54321
|
||||
}
|
||||
}()
|
||||
|
||||
srv.Start()
|
||||
srv.Stop()
|
||||
svr.Start()
|
||||
svr.Stop()
|
||||
}()
|
||||
}
|
||||
}
|
||||
@ -290,9 +290,9 @@ Port: 54321
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
srv, err := NewServer(testCase.c, testCase.opts...)
|
||||
svr, err := NewServer(testCase.c, testCase.opts...)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, srv.ngin.tlsConfig, testCase.res)
|
||||
assert.Equal(t, svr.ngin.tlsConfig, testCase.res)
|
||||
}
|
||||
}
|
||||
|
||||
@ -304,11 +304,11 @@ Port: 54321
|
||||
var cnf RestConf
|
||||
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
|
||||
rt := router.NewRouter()
|
||||
srv, err := NewServer(cnf, WithRouter(rt))
|
||||
svr, err := NewServer(cnf, WithRouter(rt))
|
||||
assert.Nil(t, err)
|
||||
|
||||
opt := WithCors("local")
|
||||
opt(srv)
|
||||
opt(svr)
|
||||
}
|
||||
|
||||
func TestWithCustomCors(t *testing.T) {
|
||||
@ -319,7 +319,7 @@ Port: 54321
|
||||
var cnf RestConf
|
||||
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
|
||||
rt := router.NewRouter()
|
||||
srv, err := NewServer(cnf, WithRouter(rt))
|
||||
svr, err := NewServer(cnf, WithRouter(rt))
|
||||
assert.Nil(t, err)
|
||||
|
||||
opt := WithCustomCors(func(header http.Header) {
|
||||
@ -327,5 +327,5 @@ Port: 54321
|
||||
}, func(w http.ResponseWriter) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}, "local")
|
||||
opt(srv)
|
||||
opt(svr)
|
||||
}
|
||||
|
@ -36,10 +36,10 @@ func main() {
|
||||
var c config.Config
|
||||
conf.MustLoad(*configFile, &c)
|
||||
ctx := svc.NewServiceContext(c)
|
||||
srv := server.New{{.serviceNew}}Server(ctx)
|
||||
svr := server.New{{.serviceNew}}Server(ctx)
|
||||
|
||||
s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
|
||||
{{.pkg}}.Register{{.service}}Server(grpcServer, srv)
|
||||
{{.pkg}}.Register{{.service}}Server(grpcServer, svr)
|
||||
|
||||
if c.Mode == service.DevMode || c.Mode == service.TestMode {
|
||||
reflection.Register(grpcServer)
|
||||
|
@ -23,7 +23,7 @@ func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) {
|
||||
server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics})
|
||||
server.SetName("bar")
|
||||
var vals []int
|
||||
f := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
f := func(_ interface{}, _ grpc.ServerStream, _ *grpc.StreamServerInfo, _ grpc.StreamHandler) error {
|
||||
vals = append(vals, 1)
|
||||
return nil
|
||||
}
|
||||
|
@ -9,13 +9,13 @@ import (
|
||||
|
||||
// StreamAuthorizeInterceptor returns a func that uses given authenticator in processing stream requests.
|
||||
func StreamAuthorizeInterceptor(authenticator *auth.Authenticator) grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
return func(svr interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) error {
|
||||
if err := authenticator.Authenticate(stream.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return handler(srv, stream)
|
||||
return handler(svr, stream)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -65,7 +65,7 @@ func TestStreamAuthorizeInterceptor(t *testing.T) {
|
||||
})
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
stream := mockedStream{ctx: ctx}
|
||||
err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
err = interceptor(nil, stream, nil, func(_ interface{}, _ grpc.ServerStream) error {
|
||||
return nil
|
||||
})
|
||||
if test.hasError {
|
||||
|
@ -9,11 +9,11 @@ import (
|
||||
)
|
||||
|
||||
// StreamBreakerInterceptor is an interceptor that acts as a circuit breaker.
|
||||
func StreamBreakerInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
func StreamBreakerInterceptor(svr interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) (err error) {
|
||||
breakerName := info.FullMethod
|
||||
return breaker.DoWithAcceptable(breakerName, func() error {
|
||||
return handler(srv, stream)
|
||||
return handler(svr, stream)
|
||||
}, codes.Acceptable)
|
||||
}
|
||||
|
||||
|
@ -13,8 +13,7 @@ import (
|
||||
func TestStreamBreakerInterceptor(t *testing.T) {
|
||||
err := StreamBreakerInterceptor(nil, nil, &grpc.StreamServerInfo{
|
||||
FullMethod: "any",
|
||||
}, func(
|
||||
srv interface{}, stream grpc.ServerStream) error {
|
||||
}, func(_ interface{}, _ grpc.ServerStream) error {
|
||||
return status.New(codes.DeadlineExceeded, "any").Err()
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
@ -23,7 +22,7 @@ func TestStreamBreakerInterceptor(t *testing.T) {
|
||||
func TestUnaryBreakerInterceptor(t *testing.T) {
|
||||
_, err := UnaryBreakerInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "any",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
}, func(_ context.Context, _ interface{}) (interface{}, error) {
|
||||
return nil, status.New(codes.DeadlineExceeded, "any").Err()
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
@ -11,17 +11,17 @@ import (
|
||||
)
|
||||
|
||||
// StreamCrashInterceptor catches panics in processing stream requests and recovers.
|
||||
func StreamCrashInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
func StreamCrashInterceptor(svr interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) (err error) {
|
||||
defer handleCrash(func(r interface{}) {
|
||||
err = toPanicError(r)
|
||||
})
|
||||
|
||||
return handler(srv, stream)
|
||||
return handler(svr, stream)
|
||||
}
|
||||
|
||||
// UnaryCrashInterceptor catches panics in processing unary requests and recovers.
|
||||
func UnaryCrashInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
func UnaryCrashInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
defer handleCrash(func(r interface{}) {
|
||||
err = toPanicError(r)
|
||||
|
@ -15,7 +15,7 @@ func init() {
|
||||
|
||||
func TestStreamCrashInterceptor(t *testing.T) {
|
||||
err := StreamCrashInterceptor(nil, nil, nil, func(
|
||||
srv interface{}, stream grpc.ServerStream) error {
|
||||
svr interface{}, stream grpc.ServerStream) error {
|
||||
panic("mock panic")
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
@ -41,12 +41,12 @@ func UnaryTracingInterceptor(ctx context.Context, req interface{}, info *grpc.Un
|
||||
}
|
||||
|
||||
// StreamTracingInterceptor returns a grpc.StreamServerInterceptor for opentelemetry.
|
||||
func StreamTracingInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
func StreamTracingInterceptor(svr interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) error {
|
||||
ctx, span := startSpan(ss.Context(), info.FullMethod)
|
||||
defer span.End()
|
||||
|
||||
if err := handler(srv, wrapServerStream(ctx, ss)); err != nil {
|
||||
if err := handler(svr, wrapServerStream(ctx, ss)); err != nil {
|
||||
s, ok := status.FromError(err)
|
||||
if ok {
|
||||
span.SetStatus(codes.Error, s.Message())
|
||||
|
@ -101,7 +101,7 @@ func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
|
||||
stream := mockedServerStream{ctx: ctx}
|
||||
err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
|
||||
FullMethod: "/foo",
|
||||
}, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
}, func(svr interface{}, stream grpc.ServerStream) error {
|
||||
defer wg.Done()
|
||||
atomic.AddInt32(&run, 1)
|
||||
return nil
|
||||
@ -138,7 +138,7 @@ func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
|
||||
stream := mockedServerStream{ctx: ctx}
|
||||
err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
|
||||
FullMethod: "/foo",
|
||||
}, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
}, func(svr interface{}, stream grpc.ServerStream) error {
|
||||
defer wg.Done()
|
||||
return test.err
|
||||
})
|
||||
@ -175,7 +175,7 @@ func TestStreamTracingInterceptor_WithError(t *testing.T) {
|
||||
stream := mockedServerStream{ctx: ctx}
|
||||
err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
|
||||
FullMethod: "/foo",
|
||||
}, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
}, func(svr interface{}, stream grpc.ServerStream) error {
|
||||
defer wg.Done()
|
||||
return test.err
|
||||
})
|
||||
|
@ -36,7 +36,7 @@ func TestServer_setupInterceptors(t *testing.T) {
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
SetServerSlowThreshold(time.Second)
|
||||
srv := MustNewServer(RpcServerConf{
|
||||
svr := MustNewServer(RpcServerConf{
|
||||
ServiceConf: service.ServiceConf{
|
||||
Log: logx.LogConf{
|
||||
ServiceName: "foo",
|
||||
@ -52,11 +52,11 @@ func TestServer(t *testing.T) {
|
||||
CpuThreshold: 0,
|
||||
}, func(server *grpc.Server) {
|
||||
})
|
||||
srv.AddOptions(grpc.ConnectionTimeout(time.Hour))
|
||||
srv.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor)
|
||||
srv.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor)
|
||||
go srv.Start()
|
||||
srv.Stop()
|
||||
svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
|
||||
svr.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor)
|
||||
svr.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor)
|
||||
go svr.Start()
|
||||
svr.Stop()
|
||||
}
|
||||
|
||||
func TestServerError(t *testing.T) {
|
||||
@ -79,7 +79,7 @@ func TestServerError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_HasEtcd(t *testing.T) {
|
||||
srv := MustNewServer(RpcServerConf{
|
||||
svr := MustNewServer(RpcServerConf{
|
||||
ServiceConf: service.ServiceConf{
|
||||
Log: logx.LogConf{
|
||||
ServiceName: "foo",
|
||||
@ -94,15 +94,15 @@ func TestServer_HasEtcd(t *testing.T) {
|
||||
Redis: redis.RedisKeyConf{},
|
||||
}, func(server *grpc.Server) {
|
||||
})
|
||||
srv.AddOptions(grpc.ConnectionTimeout(time.Hour))
|
||||
srv.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor)
|
||||
srv.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor)
|
||||
go srv.Start()
|
||||
srv.Stop()
|
||||
svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
|
||||
svr.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor)
|
||||
svr.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor)
|
||||
go svr.Start()
|
||||
svr.Stop()
|
||||
}
|
||||
|
||||
func TestServer_StartFailed(t *testing.T) {
|
||||
srv := MustNewServer(RpcServerConf{
|
||||
svr := MustNewServer(RpcServerConf{
|
||||
ServiceConf: service.ServiceConf{
|
||||
Log: logx.LogConf{
|
||||
ServiceName: "foo",
|
||||
@ -113,7 +113,7 @@ func TestServer_StartFailed(t *testing.T) {
|
||||
}, func(server *grpc.Server) {
|
||||
})
|
||||
|
||||
assert.Panics(t, srv.Start)
|
||||
assert.Panics(t, svr.Start)
|
||||
}
|
||||
|
||||
type mockedServer struct {
|
||||
|
Loading…
Reference in New Issue
Block a user