mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-02-03 00:38:40 +08:00
feat: rest.WithChain to replace builtin middlewares (#2033)
* feat: rest.WithChain to replace builtin middlewares * chore: add comments * chore: refine code
This commit is contained in:
parent
50f16e2892
commit
47c49de94e
109
rest/chain/chain.go
Normal file
109
rest/chain/chain.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
package chain
|
||||||
|
|
||||||
|
// This is a modified version of https://github.com/justinas/alice
|
||||||
|
// The original code is licensed under the MIT license.
|
||||||
|
// It's modified for couple reasons:
|
||||||
|
// - Added the Chain interface
|
||||||
|
// - Added support for the Chain.Prepend(...) method
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type (
|
||||||
|
// Chain defines a chain of middleware.
|
||||||
|
Chain interface {
|
||||||
|
Append(middlewares ...Middleware) Chain
|
||||||
|
Prepend(middlewares ...Middleware) Chain
|
||||||
|
Then(h http.Handler) http.Handler
|
||||||
|
ThenFunc(fn http.HandlerFunc) http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware is an HTTP middleware.
|
||||||
|
Middleware func(http.Handler) http.Handler
|
||||||
|
|
||||||
|
// chain acts as a list of http.Handler middlewares.
|
||||||
|
// chain is effectively immutable:
|
||||||
|
// once created, it will always hold
|
||||||
|
// the same set of middlewares in the same order.
|
||||||
|
chain struct {
|
||||||
|
middlewares []Middleware
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// New creates a new Chain, memorizing the given list of middleware middlewares.
|
||||||
|
// New serves no other function, middlewares are only called upon a call to Then() or ThenFunc().
|
||||||
|
func New(middlewares ...Middleware) Chain {
|
||||||
|
return chain{middlewares: append(([]Middleware)(nil), middlewares...)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append extends a chain, adding the specified middlewares as the last ones in the request flow.
|
||||||
|
//
|
||||||
|
// c := chain.New(m1, m2)
|
||||||
|
// c.Append(m3, m4)
|
||||||
|
// // requests in c go m1 -> m2 -> m3 -> m4
|
||||||
|
func (c chain) Append(middlewares ...Middleware) Chain {
|
||||||
|
return chain{middlewares: join(c.middlewares, middlewares)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepend extends a chain by adding the specified chain as the first one in the request flow.
|
||||||
|
//
|
||||||
|
// c := chain.New(m3, m4)
|
||||||
|
// c1 := chain.New(m1, m2)
|
||||||
|
// c.Prepend(c1)
|
||||||
|
// // requests in c go m1 -> m2 -> m3 -> m4
|
||||||
|
func (c chain) Prepend(middlewares ...Middleware) Chain {
|
||||||
|
return chain{middlewares: join(middlewares, c.middlewares)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then chains the middleware and returns the final http.Handler.
|
||||||
|
// New(m1, m2, m3).Then(h)
|
||||||
|
// is equivalent to:
|
||||||
|
// m1(m2(m3(h)))
|
||||||
|
// When the request comes in, it will be passed to m1, then m2, then m3
|
||||||
|
// and finally, the given handler
|
||||||
|
// (assuming every middleware calls the following one).
|
||||||
|
//
|
||||||
|
// A chain can be safely reused by calling Then() several times.
|
||||||
|
// stdStack := chain.New(ratelimitHandler, csrfHandler)
|
||||||
|
// indexPipe = stdStack.Then(indexHandler)
|
||||||
|
// authPipe = stdStack.Then(authHandler)
|
||||||
|
// Note that middlewares are called on every call to Then() or ThenFunc()
|
||||||
|
// and thus several instances of the same middleware will be created
|
||||||
|
// when a chain is reused in this way.
|
||||||
|
// For proper middleware, this should cause no problems.
|
||||||
|
//
|
||||||
|
// Then() treats nil as http.DefaultServeMux.
|
||||||
|
func (c chain) Then(h http.Handler) http.Handler {
|
||||||
|
if h == nil {
|
||||||
|
h = http.DefaultServeMux
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range c.middlewares {
|
||||||
|
h = c.middlewares[len(c.middlewares)-1-i](h)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// ThenFunc works identically to Then, but takes
|
||||||
|
// a HandlerFunc instead of a Handler.
|
||||||
|
//
|
||||||
|
// The following two statements are equivalent:
|
||||||
|
// c.Then(http.HandlerFunc(fn))
|
||||||
|
// c.ThenFunc(fn)
|
||||||
|
//
|
||||||
|
// ThenFunc provides all the guarantees of Then.
|
||||||
|
func (c chain) ThenFunc(fn http.HandlerFunc) http.Handler {
|
||||||
|
// This nil check cannot be removed due to the "nil is not nil" common mistake in Go.
|
||||||
|
// Required due to: https://stackoverflow.com/questions/33426977/how-to-golang-check-a-variable-is-nil
|
||||||
|
if fn == nil {
|
||||||
|
return c.Then(nil)
|
||||||
|
}
|
||||||
|
return c.Then(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func join(a, b []Middleware) []Middleware {
|
||||||
|
mids := make([]Middleware, 0, len(a)+len(b))
|
||||||
|
mids = append(mids, a...)
|
||||||
|
mids = append(mids, b...)
|
||||||
|
return mids
|
||||||
|
}
|
126
rest/chain/chain_test.go
Normal file
126
rest/chain/chain_test.go
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
package chain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A constructor for middleware
|
||||||
|
// that writes its own "tag" into the RW and does nothing else.
|
||||||
|
// Useful in checking if a chain is behaving in the right order.
|
||||||
|
func tagMiddleware(tag string) Middleware {
|
||||||
|
return func(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(tag))
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
|
||||||
|
// but the best we can do.
|
||||||
|
func funcsEqual(f1, f2 interface{}) bool {
|
||||||
|
val1 := reflect.ValueOf(f1)
|
||||||
|
val2 := reflect.ValueOf(f2)
|
||||||
|
return val1.Pointer() == val2.Pointer()
|
||||||
|
}
|
||||||
|
|
||||||
|
var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("app\n"))
|
||||||
|
})
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
c1 := func(h http.Handler) http.Handler {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c2 := func(h http.Handler) http.Handler {
|
||||||
|
return http.StripPrefix("potato", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
slice := []Middleware{c1, c2}
|
||||||
|
c := New(slice...)
|
||||||
|
for k := range slice {
|
||||||
|
assert.True(t, funcsEqual(c.(chain).middlewares[k], slice[k]),
|
||||||
|
"New does not add constructors correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenWorksWithNoMiddleware(t *testing.T) {
|
||||||
|
assert.True(t, funcsEqual(New().Then(testApp), testApp),
|
||||||
|
"Then does not work with no middleware")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
|
||||||
|
assert.Equal(t, http.DefaultServeMux, New().Then(nil),
|
||||||
|
"Then does not treat nil as DefaultServeMux")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
|
||||||
|
assert.Equal(t, http.DefaultServeMux, New().ThenFunc(nil),
|
||||||
|
"ThenFunc does not treat nil as DefaultServeMux")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenFuncConstructsHandlerFunc(t *testing.T) {
|
||||||
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
})
|
||||||
|
chained := New().ThenFunc(fn)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
chained.ServeHTTP(rec, (*http.Request)(nil))
|
||||||
|
|
||||||
|
assert.Equal(t, reflect.TypeOf((http.HandlerFunc)(nil)), reflect.TypeOf(chained),
|
||||||
|
"ThenFunc does not construct HandlerFunc")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenOrdersHandlersCorrectly(t *testing.T) {
|
||||||
|
t1 := tagMiddleware("t1\n")
|
||||||
|
t2 := tagMiddleware("t2\n")
|
||||||
|
t3 := tagMiddleware("t3\n")
|
||||||
|
|
||||||
|
chained := New(t1, t2, t3).Then(testApp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, err := http.NewRequest("GET", "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chained.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, "t1\nt2\nt3\napp\n", w.Body.String(),
|
||||||
|
"Then does not order handlers correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendAddsHandlersCorrectly(t *testing.T) {
|
||||||
|
c := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||||
|
c = c.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||||
|
h := c.Then(testApp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, err := http.NewRequest("GET", "/", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(),
|
||||||
|
"Append does not add handlers correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtendAddsHandlersCorrectly(t *testing.T) {
|
||||||
|
c := New(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||||
|
c = c.Prepend(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||||
|
h := c.Then(testApp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, err := http.NewRequest("GET", "/", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(),
|
||||||
|
"Extend does not add handlers in correctly")
|
||||||
|
}
|
@ -8,10 +8,10 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/justinas/alice"
|
|
||||||
"github.com/zeromicro/go-zero/core/codec"
|
"github.com/zeromicro/go-zero/core/codec"
|
||||||
"github.com/zeromicro/go-zero/core/load"
|
"github.com/zeromicro/go-zero/core/load"
|
||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
|
"github.com/zeromicro/go-zero/rest/chain"
|
||||||
"github.com/zeromicro/go-zero/rest/handler"
|
"github.com/zeromicro/go-zero/rest/handler"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"github.com/zeromicro/go-zero/rest/internal"
|
"github.com/zeromicro/go-zero/rest/internal"
|
||||||
@ -29,7 +29,7 @@ type engine struct {
|
|||||||
routes []featuredRoutes
|
routes []featuredRoutes
|
||||||
unauthorizedCallback handler.UnauthorizedCallback
|
unauthorizedCallback handler.UnauthorizedCallback
|
||||||
unsignedCallback handler.UnsignedCallback
|
unsignedCallback handler.UnsignedCallback
|
||||||
disableDefaultMiddlewares bool
|
chain chain.Chain
|
||||||
middlewares []Middleware
|
middlewares []Middleware
|
||||||
shedder load.Shedder
|
shedder load.Shedder
|
||||||
priorityShedder load.Shedder
|
priorityShedder load.Shedder
|
||||||
@ -53,20 +53,20 @@ func (ng *engine) addRoutes(r featuredRoutes) {
|
|||||||
ng.routes = append(ng.routes, r)
|
ng.routes = append(ng.routes, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
|
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
||||||
verifier func(alice.Chain) alice.Chain) alice.Chain {
|
verifier func(chain.Chain) chain.Chain) chain.Chain {
|
||||||
if fr.jwt.enabled {
|
if fr.jwt.enabled {
|
||||||
if len(fr.jwt.prevSecret) == 0 {
|
if len(fr.jwt.prevSecret) == 0 {
|
||||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
chn = chn.Append(handler.Authorize(fr.jwt.secret,
|
||||||
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
||||||
} else {
|
} else {
|
||||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
chn = chn.Append(handler.Authorize(fr.jwt.secret,
|
||||||
handler.WithPrevSecret(fr.jwt.prevSecret),
|
handler.WithPrevSecret(fr.jwt.prevSecret),
|
||||||
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return verifier(chain)
|
return verifier(chn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
|
func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
|
||||||
@ -85,10 +85,10 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
||||||
route Route, verifier func(chain alice.Chain) alice.Chain) error {
|
route Route, verifier func(chain.Chain) chain.Chain) error {
|
||||||
var chain alice.Chain
|
chn := ng.chain
|
||||||
if !ng.disableDefaultMiddlewares {
|
if chn == nil {
|
||||||
chain = alice.New(
|
chn = chain.New(
|
||||||
handler.TracingHandler(ng.conf.Name, route.Path),
|
handler.TracingHandler(ng.conf.Name, route.Path),
|
||||||
ng.getLogHandler(),
|
ng.getLogHandler(),
|
||||||
handler.PrometheusHandler(route.Path),
|
handler.PrometheusHandler(route.Path),
|
||||||
@ -103,11 +103,12 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
chn = ng.appendAuthHandler(fr, chn, verifier)
|
||||||
|
|
||||||
for _, middleware := range ng.middlewares {
|
for _, middleware := range ng.middlewares {
|
||||||
chain = chain.Append(convertMiddleware(middleware))
|
chn = chn.Append(convertMiddleware(middleware))
|
||||||
}
|
}
|
||||||
chain = ng.appendAuthHandler(fr, chain, verifier)
|
handle := chn.ThenFunc(route.Handler)
|
||||||
handle := chain.ThenFunc(route.Handler)
|
|
||||||
|
|
||||||
return router.Handle(route.Method, route.Path, handle)
|
return router.Handle(route.Method, route.Path, handle)
|
||||||
}
|
}
|
||||||
@ -171,16 +172,16 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
|
|||||||
// notFoundHandler returns a middleware that handles 404 not found requests.
|
// notFoundHandler returns a middleware that handles 404 not found requests.
|
||||||
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
chain := alice.New(
|
chn := chain.New(
|
||||||
handler.TracingHandler(ng.conf.Name, ""),
|
handler.TracingHandler(ng.conf.Name, ""),
|
||||||
ng.getLogHandler(),
|
ng.getLogHandler(),
|
||||||
)
|
)
|
||||||
|
|
||||||
var h http.Handler
|
var h http.Handler
|
||||||
if next != nil {
|
if next != nil {
|
||||||
h = chain.Then(next)
|
h = chn.Then(next)
|
||||||
} else {
|
} else {
|
||||||
h = chain.Then(http.NotFoundHandler())
|
h = chn.Then(http.NotFoundHandler())
|
||||||
}
|
}
|
||||||
|
|
||||||
cw := response.NewHeaderOnceResponseWriter(w)
|
cw := response.NewHeaderOnceResponseWriter(w)
|
||||||
@ -218,10 +219,10 @@ func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
|
|||||||
ng.unsignedCallback = callback
|
ng.unsignedCallback = callback
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
|
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
|
||||||
if !signature.enabled {
|
if !signature.enabled {
|
||||||
return func(chain alice.Chain) alice.Chain {
|
return func(chn chain.Chain) chain.Chain {
|
||||||
return chain
|
return chn
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,8 +231,8 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
|
|||||||
return nil, ErrSignatureConfig
|
return nil, ErrSignatureConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(chain alice.Chain) alice.Chain {
|
return func(chn chain.Chain) chain.Chain {
|
||||||
return chain
|
return chn
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -247,14 +248,13 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
|
|||||||
decrypters[fingerprint] = decrypter
|
decrypters[fingerprint] = decrypter
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(chain alice.Chain) alice.Chain {
|
return func(chn chain.Chain) chain.Chain {
|
||||||
if ng.unsignedCallback != nil {
|
if ng.unsignedCallback != nil {
|
||||||
return chain.Append(handler.ContentSecurityHandler(
|
return chn.Append(handler.ContentSecurityHandler(
|
||||||
decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
|
decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
|
||||||
}
|
}
|
||||||
|
|
||||||
return chain.Append(handler.ContentSecurityHandler(
|
return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict))
|
||||||
decrypters, signature.Expiry, signature.Strict))
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -229,46 +229,6 @@ func TestEngine_checkedMaxBytes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_checkedChain(t *testing.T) {
|
|
||||||
var called int32
|
|
||||||
middleware1 := func() func(http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
atomic.AddInt32(&called, 1)
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
atomic.AddInt32(&called, 1)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
middleware2 := func() func(http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
atomic.AddInt32(&called, 1)
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
atomic.AddInt32(&called, 1)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
server := MustNewServer(RestConf{}, DisableDefaultMiddlewares())
|
|
||||||
server.Use(ToMiddleware(middleware1()))
|
|
||||||
server.Use(ToMiddleware(middleware2()))
|
|
||||||
server.router = chainRouter{}
|
|
||||||
server.AddRoutes(
|
|
||||||
[]Route{
|
|
||||||
{
|
|
||||||
Method: http.MethodGet,
|
|
||||||
Path: "/",
|
|
||||||
Handler: func(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
atomic.AddInt32(&called, 1)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
server.ngin.bindRoutes(chainRouter{})
|
|
||||||
assert.Equal(t, int32(5), atomic.LoadInt32(&called))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEngine_notFoundHandler(t *testing.T) {
|
func TestEngine_notFoundHandler(t *testing.T) {
|
||||||
logx.Disable()
|
logx.Disable()
|
||||||
|
|
||||||
@ -374,7 +334,7 @@ type mockedRouter struct{}
|
|||||||
func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
|
func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m mockedRouter) Handle(_, _ string, _ http.Handler) error {
|
func (m mockedRouter) Handle(_, _ string, handler http.Handler) error {
|
||||||
return errors.New("foo")
|
return errors.New("foo")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -383,19 +343,3 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
|
|||||||
|
|
||||||
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
|
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type chainRouter struct{}
|
|
||||||
|
|
||||||
func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c chainRouter) Handle(_, _ string, handler http.Handler) error {
|
|
||||||
handler.ServeHTTP(nil, nil)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c chainRouter) SetNotFoundHandler(_ http.Handler) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c chainRouter) SetNotAllowedHandler(_ http.Handler) {
|
|
||||||
}
|
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/rest/chain"
|
||||||
"github.com/zeromicro/go-zero/rest/handler"
|
"github.com/zeromicro/go-zero/rest/handler"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"github.com/zeromicro/go-zero/rest/internal/cors"
|
"github.com/zeromicro/go-zero/rest/internal/cors"
|
||||||
@ -95,13 +96,6 @@ func (s *Server) Use(middleware Middleware) {
|
|||||||
s.ngin.use(middleware)
|
s.ngin.use(middleware)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisableDefaultMiddlewares returns a RunOption that disables the builtin middlewares.
|
|
||||||
func DisableDefaultMiddlewares() RunOption {
|
|
||||||
return func(svr *Server) {
|
|
||||||
svr.ngin.disableDefaultMiddlewares = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToMiddleware converts the given handler to a Middleware.
|
// ToMiddleware converts the given handler to a Middleware.
|
||||||
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
||||||
return func(handle http.HandlerFunc) http.HandlerFunc {
|
return func(handle http.HandlerFunc) http.HandlerFunc {
|
||||||
@ -109,6 +103,14 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithChain returns a RunOption that uses the given chain to replace the default chain.
|
||||||
|
// JWT auth middleware and the middlewares that added by svr.Use() will be appended.
|
||||||
|
func WithChain(chn chain.Chain) RunOption {
|
||||||
|
return func(svr *Server) {
|
||||||
|
svr.ngin.chain = chn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithCors returns a func to enable CORS for given origin, or default to all origins (*).
|
// WithCors returns a func to enable CORS for given origin, or default to all origins (*).
|
||||||
func WithCors(origin ...string) RunOption {
|
func WithCors(origin ...string) RunOption {
|
||||||
return func(server *Server) {
|
return func(server *Server) {
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -16,6 +17,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/conf"
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/service"
|
"github.com/zeromicro/go-zero/core/service"
|
||||||
|
"github.com/zeromicro/go-zero/rest/chain"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"github.com/zeromicro/go-zero/rest/router"
|
"github.com/zeromicro/go-zero/rest/router"
|
||||||
)
|
)
|
||||||
@ -435,3 +437,44 @@ func TestValidateSecret(t *testing.T) {
|
|||||||
validateSecret("short")
|
validateSecret("short")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_WithChain(t *testing.T) {
|
||||||
|
var called int32
|
||||||
|
middleware1 := func() func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
middleware2 := func() func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2())))
|
||||||
|
server.AddRoutes(
|
||||||
|
[]Route{
|
||||||
|
{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
rt := router.NewRouter()
|
||||||
|
assert.Nil(t, server.ngin.bindRoutes(rt))
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
rt.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
assert.Equal(t, int32(5), atomic.LoadInt32(&called))
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user