mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-01-23 09:00:20 +08:00
feat: rest.WithChain to replace builtin middlewares (#2033)
* feat: rest.WithChain to replace builtin middlewares * chore: add comments * chore: refine code
This commit is contained in:
parent
50f16e2892
commit
47c49de94e
109
rest/chain/chain.go
Normal file
109
rest/chain/chain.go
Normal file
@ -0,0 +1,109 @@
|
||||
package chain
|
||||
|
||||
// This is a modified version of https://github.com/justinas/alice
|
||||
// The original code is licensed under the MIT license.
|
||||
// It's modified for couple reasons:
|
||||
// - Added the Chain interface
|
||||
// - Added support for the Chain.Prepend(...) method
|
||||
|
||||
import "net/http"
|
||||
|
||||
type (
|
||||
// Chain defines a chain of middleware.
|
||||
Chain interface {
|
||||
Append(middlewares ...Middleware) Chain
|
||||
Prepend(middlewares ...Middleware) Chain
|
||||
Then(h http.Handler) http.Handler
|
||||
ThenFunc(fn http.HandlerFunc) http.Handler
|
||||
}
|
||||
|
||||
// Middleware is an HTTP middleware.
|
||||
Middleware func(http.Handler) http.Handler
|
||||
|
||||
// chain acts as a list of http.Handler middlewares.
|
||||
// chain is effectively immutable:
|
||||
// once created, it will always hold
|
||||
// the same set of middlewares in the same order.
|
||||
chain struct {
|
||||
middlewares []Middleware
|
||||
}
|
||||
)
|
||||
|
||||
// New creates a new Chain, memorizing the given list of middleware middlewares.
|
||||
// New serves no other function, middlewares are only called upon a call to Then() or ThenFunc().
|
||||
func New(middlewares ...Middleware) Chain {
|
||||
return chain{middlewares: append(([]Middleware)(nil), middlewares...)}
|
||||
}
|
||||
|
||||
// Append extends a chain, adding the specified middlewares as the last ones in the request flow.
|
||||
//
|
||||
// c := chain.New(m1, m2)
|
||||
// c.Append(m3, m4)
|
||||
// // requests in c go m1 -> m2 -> m3 -> m4
|
||||
func (c chain) Append(middlewares ...Middleware) Chain {
|
||||
return chain{middlewares: join(c.middlewares, middlewares)}
|
||||
}
|
||||
|
||||
// Prepend extends a chain by adding the specified chain as the first one in the request flow.
|
||||
//
|
||||
// c := chain.New(m3, m4)
|
||||
// c1 := chain.New(m1, m2)
|
||||
// c.Prepend(c1)
|
||||
// // requests in c go m1 -> m2 -> m3 -> m4
|
||||
func (c chain) Prepend(middlewares ...Middleware) Chain {
|
||||
return chain{middlewares: join(middlewares, c.middlewares)}
|
||||
}
|
||||
|
||||
// Then chains the middleware and returns the final http.Handler.
|
||||
// New(m1, m2, m3).Then(h)
|
||||
// is equivalent to:
|
||||
// m1(m2(m3(h)))
|
||||
// When the request comes in, it will be passed to m1, then m2, then m3
|
||||
// and finally, the given handler
|
||||
// (assuming every middleware calls the following one).
|
||||
//
|
||||
// A chain can be safely reused by calling Then() several times.
|
||||
// stdStack := chain.New(ratelimitHandler, csrfHandler)
|
||||
// indexPipe = stdStack.Then(indexHandler)
|
||||
// authPipe = stdStack.Then(authHandler)
|
||||
// Note that middlewares are called on every call to Then() or ThenFunc()
|
||||
// and thus several instances of the same middleware will be created
|
||||
// when a chain is reused in this way.
|
||||
// For proper middleware, this should cause no problems.
|
||||
//
|
||||
// Then() treats nil as http.DefaultServeMux.
|
||||
func (c chain) Then(h http.Handler) http.Handler {
|
||||
if h == nil {
|
||||
h = http.DefaultServeMux
|
||||
}
|
||||
|
||||
for i := range c.middlewares {
|
||||
h = c.middlewares[len(c.middlewares)-1-i](h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// ThenFunc works identically to Then, but takes
|
||||
// a HandlerFunc instead of a Handler.
|
||||
//
|
||||
// The following two statements are equivalent:
|
||||
// c.Then(http.HandlerFunc(fn))
|
||||
// c.ThenFunc(fn)
|
||||
//
|
||||
// ThenFunc provides all the guarantees of Then.
|
||||
func (c chain) ThenFunc(fn http.HandlerFunc) http.Handler {
|
||||
// This nil check cannot be removed due to the "nil is not nil" common mistake in Go.
|
||||
// Required due to: https://stackoverflow.com/questions/33426977/how-to-golang-check-a-variable-is-nil
|
||||
if fn == nil {
|
||||
return c.Then(nil)
|
||||
}
|
||||
return c.Then(fn)
|
||||
}
|
||||
|
||||
func join(a, b []Middleware) []Middleware {
|
||||
mids := make([]Middleware, 0, len(a)+len(b))
|
||||
mids = append(mids, a...)
|
||||
mids = append(mids, b...)
|
||||
return mids
|
||||
}
|
126
rest/chain/chain_test.go
Normal file
126
rest/chain/chain_test.go
Normal file
@ -0,0 +1,126 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// A constructor for middleware
|
||||
// that writes its own "tag" into the RW and does nothing else.
|
||||
// Useful in checking if a chain is behaving in the right order.
|
||||
func tagMiddleware(tag string) Middleware {
|
||||
return func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(tag))
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
|
||||
// but the best we can do.
|
||||
func funcsEqual(f1, f2 interface{}) bool {
|
||||
val1 := reflect.ValueOf(f1)
|
||||
val2 := reflect.ValueOf(f2)
|
||||
return val1.Pointer() == val2.Pointer()
|
||||
}
|
||||
|
||||
var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("app\n"))
|
||||
})
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
c1 := func(h http.Handler) http.Handler {
|
||||
return nil
|
||||
}
|
||||
|
||||
c2 := func(h http.Handler) http.Handler {
|
||||
return http.StripPrefix("potato", nil)
|
||||
}
|
||||
|
||||
slice := []Middleware{c1, c2}
|
||||
c := New(slice...)
|
||||
for k := range slice {
|
||||
assert.True(t, funcsEqual(c.(chain).middlewares[k], slice[k]),
|
||||
"New does not add constructors correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestThenWorksWithNoMiddleware(t *testing.T) {
|
||||
assert.True(t, funcsEqual(New().Then(testApp), testApp),
|
||||
"Then does not work with no middleware")
|
||||
}
|
||||
|
||||
func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
|
||||
assert.Equal(t, http.DefaultServeMux, New().Then(nil),
|
||||
"Then does not treat nil as DefaultServeMux")
|
||||
}
|
||||
|
||||
func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
|
||||
assert.Equal(t, http.DefaultServeMux, New().ThenFunc(nil),
|
||||
"ThenFunc does not treat nil as DefaultServeMux")
|
||||
}
|
||||
|
||||
func TestThenFuncConstructsHandlerFunc(t *testing.T) {
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
chained := New().ThenFunc(fn)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
chained.ServeHTTP(rec, (*http.Request)(nil))
|
||||
|
||||
assert.Equal(t, reflect.TypeOf((http.HandlerFunc)(nil)), reflect.TypeOf(chained),
|
||||
"ThenFunc does not construct HandlerFunc")
|
||||
}
|
||||
|
||||
func TestThenOrdersHandlersCorrectly(t *testing.T) {
|
||||
t1 := tagMiddleware("t1\n")
|
||||
t2 := tagMiddleware("t2\n")
|
||||
t3 := tagMiddleware("t3\n")
|
||||
|
||||
chained := New(t1, t2, t3).Then(testApp)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
chained.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, "t1\nt2\nt3\napp\n", w.Body.String(),
|
||||
"Then does not order handlers correctly")
|
||||
}
|
||||
|
||||
func TestAppendAddsHandlersCorrectly(t *testing.T) {
|
||||
c := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||
c = c.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||
h := c.Then(testApp)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(),
|
||||
"Append does not add handlers correctly")
|
||||
}
|
||||
|
||||
func TestExtendAddsHandlersCorrectly(t *testing.T) {
|
||||
c := New(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||
c = c.Prepend(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||
h := c.Then(testApp)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(),
|
||||
"Extend does not add handlers in correctly")
|
||||
}
|
@ -8,10 +8,10 @@ import (
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
"github.com/zeromicro/go-zero/core/codec"
|
||||
"github.com/zeromicro/go-zero/core/load"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/rest/chain"
|
||||
"github.com/zeromicro/go-zero/rest/handler"
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
"github.com/zeromicro/go-zero/rest/internal"
|
||||
@ -25,15 +25,15 @@ const topCpuUsage = 1000
|
||||
var ErrSignatureConfig = errors.New("bad config for Signature")
|
||||
|
||||
type engine struct {
|
||||
conf RestConf
|
||||
routes []featuredRoutes
|
||||
unauthorizedCallback handler.UnauthorizedCallback
|
||||
unsignedCallback handler.UnsignedCallback
|
||||
disableDefaultMiddlewares bool
|
||||
middlewares []Middleware
|
||||
shedder load.Shedder
|
||||
priorityShedder load.Shedder
|
||||
tlsConfig *tls.Config
|
||||
conf RestConf
|
||||
routes []featuredRoutes
|
||||
unauthorizedCallback handler.UnauthorizedCallback
|
||||
unsignedCallback handler.UnsignedCallback
|
||||
chain chain.Chain
|
||||
middlewares []Middleware
|
||||
shedder load.Shedder
|
||||
priorityShedder load.Shedder
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
func newEngine(c RestConf) *engine {
|
||||
@ -53,20 +53,20 @@ func (ng *engine) addRoutes(r featuredRoutes) {
|
||||
ng.routes = append(ng.routes, r)
|
||||
}
|
||||
|
||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
|
||||
verifier func(alice.Chain) alice.Chain) alice.Chain {
|
||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
||||
verifier func(chain.Chain) chain.Chain) chain.Chain {
|
||||
if fr.jwt.enabled {
|
||||
if len(fr.jwt.prevSecret) == 0 {
|
||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
||||
chn = chn.Append(handler.Authorize(fr.jwt.secret,
|
||||
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
||||
} else {
|
||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
||||
chn = chn.Append(handler.Authorize(fr.jwt.secret,
|
||||
handler.WithPrevSecret(fr.jwt.prevSecret),
|
||||
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
||||
}
|
||||
}
|
||||
|
||||
return verifier(chain)
|
||||
return verifier(chn)
|
||||
}
|
||||
|
||||
func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
|
||||
@ -85,10 +85,10 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
|
||||
}
|
||||
|
||||
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
||||
route Route, verifier func(chain alice.Chain) alice.Chain) error {
|
||||
var chain alice.Chain
|
||||
if !ng.disableDefaultMiddlewares {
|
||||
chain = alice.New(
|
||||
route Route, verifier func(chain.Chain) chain.Chain) error {
|
||||
chn := ng.chain
|
||||
if chn == nil {
|
||||
chn = chain.New(
|
||||
handler.TracingHandler(ng.conf.Name, route.Path),
|
||||
ng.getLogHandler(),
|
||||
handler.PrometheusHandler(route.Path),
|
||||
@ -103,11 +103,12 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
|
||||
)
|
||||
}
|
||||
|
||||
chn = ng.appendAuthHandler(fr, chn, verifier)
|
||||
|
||||
for _, middleware := range ng.middlewares {
|
||||
chain = chain.Append(convertMiddleware(middleware))
|
||||
chn = chn.Append(convertMiddleware(middleware))
|
||||
}
|
||||
chain = ng.appendAuthHandler(fr, chain, verifier)
|
||||
handle := chain.ThenFunc(route.Handler)
|
||||
handle := chn.ThenFunc(route.Handler)
|
||||
|
||||
return router.Handle(route.Method, route.Path, handle)
|
||||
}
|
||||
@ -171,16 +172,16 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
|
||||
// notFoundHandler returns a middleware that handles 404 not found requests.
|
||||
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
chain := alice.New(
|
||||
chn := chain.New(
|
||||
handler.TracingHandler(ng.conf.Name, ""),
|
||||
ng.getLogHandler(),
|
||||
)
|
||||
|
||||
var h http.Handler
|
||||
if next != nil {
|
||||
h = chain.Then(next)
|
||||
h = chn.Then(next)
|
||||
} else {
|
||||
h = chain.Then(http.NotFoundHandler())
|
||||
h = chn.Then(http.NotFoundHandler())
|
||||
}
|
||||
|
||||
cw := response.NewHeaderOnceResponseWriter(w)
|
||||
@ -218,10 +219,10 @@ func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
|
||||
ng.unsignedCallback = callback
|
||||
}
|
||||
|
||||
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
|
||||
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
|
||||
if !signature.enabled {
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
return chain
|
||||
return func(chn chain.Chain) chain.Chain {
|
||||
return chn
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -230,8 +231,8 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
|
||||
return nil, ErrSignatureConfig
|
||||
}
|
||||
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
return chain
|
||||
return func(chn chain.Chain) chain.Chain {
|
||||
return chn
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -247,14 +248,13 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
|
||||
decrypters[fingerprint] = decrypter
|
||||
}
|
||||
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
return func(chn chain.Chain) chain.Chain {
|
||||
if ng.unsignedCallback != nil {
|
||||
return chain.Append(handler.ContentSecurityHandler(
|
||||
return chn.Append(handler.ContentSecurityHandler(
|
||||
decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
|
||||
}
|
||||
|
||||
return chain.Append(handler.ContentSecurityHandler(
|
||||
decrypters, signature.Expiry, signature.Strict))
|
||||
return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict))
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -229,46 +229,6 @@ func TestEngine_checkedMaxBytes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_checkedChain(t *testing.T) {
|
||||
var called int32
|
||||
middleware1 := func() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&called, 1)
|
||||
next.ServeHTTP(w, r)
|
||||
atomic.AddInt32(&called, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
middleware2 := func() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&called, 1)
|
||||
next.ServeHTTP(w, r)
|
||||
atomic.AddInt32(&called, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
server := MustNewServer(RestConf{}, DisableDefaultMiddlewares())
|
||||
server.Use(ToMiddleware(middleware1()))
|
||||
server.Use(ToMiddleware(middleware2()))
|
||||
server.router = chainRouter{}
|
||||
server.AddRoutes(
|
||||
[]Route{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(_ http.ResponseWriter, _ *http.Request) {
|
||||
atomic.AddInt32(&called, 1)
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
server.ngin.bindRoutes(chainRouter{})
|
||||
assert.Equal(t, int32(5), atomic.LoadInt32(&called))
|
||||
}
|
||||
|
||||
func TestEngine_notFoundHandler(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
@ -374,7 +334,7 @@ type mockedRouter struct{}
|
||||
func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
|
||||
}
|
||||
|
||||
func (m mockedRouter) Handle(_, _ string, _ http.Handler) error {
|
||||
func (m mockedRouter) Handle(_, _ string, handler http.Handler) error {
|
||||
return errors.New("foo")
|
||||
}
|
||||
|
||||
@ -383,19 +343,3 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
|
||||
|
||||
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
|
||||
}
|
||||
|
||||
type chainRouter struct{}
|
||||
|
||||
func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
|
||||
}
|
||||
|
||||
func (c chainRouter) Handle(_, _ string, handler http.Handler) error {
|
||||
handler.ServeHTTP(nil, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c chainRouter) SetNotFoundHandler(_ http.Handler) {
|
||||
}
|
||||
|
||||
func (c chainRouter) SetNotAllowedHandler(_ http.Handler) {
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/rest/chain"
|
||||
"github.com/zeromicro/go-zero/rest/handler"
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
"github.com/zeromicro/go-zero/rest/internal/cors"
|
||||
@ -95,13 +96,6 @@ func (s *Server) Use(middleware Middleware) {
|
||||
s.ngin.use(middleware)
|
||||
}
|
||||
|
||||
// DisableDefaultMiddlewares returns a RunOption that disables the builtin middlewares.
|
||||
func DisableDefaultMiddlewares() RunOption {
|
||||
return func(svr *Server) {
|
||||
svr.ngin.disableDefaultMiddlewares = true
|
||||
}
|
||||
}
|
||||
|
||||
// ToMiddleware converts the given handler to a Middleware.
|
||||
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
||||
return func(handle http.HandlerFunc) http.HandlerFunc {
|
||||
@ -109,6 +103,14 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
||||
}
|
||||
}
|
||||
|
||||
// WithChain returns a RunOption that uses the given chain to replace the default chain.
|
||||
// JWT auth middleware and the middlewares that added by svr.Use() will be appended.
|
||||
func WithChain(chn chain.Chain) RunOption {
|
||||
return func(svr *Server) {
|
||||
svr.ngin.chain = chn
|
||||
}
|
||||
}
|
||||
|
||||
// WithCors returns a func to enable CORS for given origin, or default to all origins (*).
|
||||
func WithCors(origin ...string) RunOption {
|
||||
return func(server *Server) {
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -16,6 +17,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/service"
|
||||
"github.com/zeromicro/go-zero/rest/chain"
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
"github.com/zeromicro/go-zero/rest/router"
|
||||
)
|
||||
@ -435,3 +437,44 @@ func TestValidateSecret(t *testing.T) {
|
||||
validateSecret("short")
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WithChain(t *testing.T) {
|
||||
var called int32
|
||||
middleware1 := func() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&called, 1)
|
||||
next.ServeHTTP(w, r)
|
||||
atomic.AddInt32(&called, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
middleware2 := func() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&called, 1)
|
||||
next.ServeHTTP(w, r)
|
||||
atomic.AddInt32(&called, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2())))
|
||||
server.AddRoutes(
|
||||
[]Route{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(_ http.ResponseWriter, _ *http.Request) {
|
||||
atomic.AddInt32(&called, 1)
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
rt := router.NewRouter()
|
||||
assert.Nil(t, server.ngin.bindRoutes(rt))
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
assert.Nil(t, err)
|
||||
rt.ServeHTTP(httptest.NewRecorder(), req)
|
||||
assert.Equal(t, int32(5), atomic.LoadInt32(&called))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user