diff --git a/rest/config.go b/rest/config.go index 0ac31eba..48738001 100644 --- a/rest/config.go +++ b/rest/config.go @@ -7,6 +7,21 @@ import ( ) type ( + // MiddlewaresConf is the config of middlewares. + MiddlewaresConf struct { + Trace bool `json:",default=true"` + Log bool `json:",default=true"` + Prometheus bool `json:",default=true"` + MaxConns bool `json:",default=true"` + Breaker bool `json:",default=true"` + Shedding bool `json:",default=true"` + Timeout bool `json:",default=true"` + Recover bool `json:",default=true"` + Metrics bool `json:",default=true"` + MaxBytes bool `json:",default=true"` + Gunzip bool `json:",default=true"` + } + // A PrivateKeyConf is a private key config. PrivateKeyConf struct { Fingerprint string @@ -40,5 +55,7 @@ type ( Timeout int64 `json:",default=3000"` CpuThreshold int64 `json:",default=900,range=[0:1000]"` Signature SignatureConf `json:",optional"` + // There are default values for all the items in Middlewares. + Middlewares MiddlewaresConf } ) diff --git a/rest/engine.go b/rest/engine.go index 18a2f618..8beb5ad0 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -88,19 +88,7 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta route Route, verifier func(chain.Chain) chain.Chain) error { chn := ng.chain if chn == nil { - chn = chain.New( - handler.TracingHandler(ng.conf.Name, route.Path), - ng.getLogHandler(), - handler.PrometheusHandler(route.Path), - handler.MaxConns(ng.conf.MaxConns), - handler.BreakerHandler(route.Method, route.Path, metrics), - handler.SheddingHandler(ng.getShedder(fr.priority), metrics), - handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)), - handler.RecoverHandler, - handler.MetricHandler(metrics), - handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)), - handler.GunzipHandler, - ) + chn = ng.buildChainWithNativeMiddlewares(fr, route, metrics) } chn = ng.appendAuthHandler(fr, chn, verifier) @@ -125,6 +113,47 @@ func (ng *engine) bindRoutes(router httpx.Router) error { return nil } +func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route, + metrics *stat.Metrics) chain.Chain { + chn := chain.New() + + if ng.conf.Middlewares.Trace { + chn = chn.Append(handler.TracingHandler(ng.conf.Name, route.Path)) + } + if ng.conf.Middlewares.Log { + chn = chn.Append(ng.getLogHandler()) + } + if ng.conf.Middlewares.Prometheus { + chn = chn.Append(handler.PrometheusHandler(route.Path)) + } + if ng.conf.Middlewares.MaxConns { + chn = chn.Append(handler.MaxConnsHandler(ng.conf.MaxConns)) + } + if ng.conf.Middlewares.Breaker { + chn = chn.Append(handler.BreakerHandler(route.Method, route.Path, metrics)) + } + if ng.conf.Middlewares.Shedding { + chn = chn.Append(handler.SheddingHandler(ng.getShedder(fr.priority), metrics)) + } + if ng.conf.Middlewares.Timeout { + chn = chn.Append(handler.TimeoutHandler(ng.checkedTimeout(fr.timeout))) + } + if ng.conf.Middlewares.Recover { + chn = chn.Append(handler.RecoverHandler) + } + if ng.conf.Middlewares.Metrics { + chn = chn.Append(handler.MetricHandler(metrics)) + } + if ng.conf.Middlewares.MaxBytes { + chn = chn.Append(handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes))) + } + if ng.conf.Middlewares.Gunzip { + chn = chn.Append(handler.GunzipHandler) + } + + return chn +} + func (ng *engine) checkedMaxBytes(bytes int64) int64 { if bytes > 0 { return bytes diff --git a/rest/engine_test.go b/rest/engine_test.go index f80c434d..0ebed8fa 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -18,10 +18,14 @@ func TestNewEngine(t *testing.T) { yamls := []string{ `Name: foo Port: 54321 +Middlewares: + Log: false `, `Name: foo Port: 54321 CpuThreshold: 500 +Middlewares: + Log: false `, `Name: foo Port: 54321 diff --git a/rest/handler/maxconnshandler.go b/rest/handler/maxconnshandler.go index c062cb7c..4bbccadd 100644 --- a/rest/handler/maxconnshandler.go +++ b/rest/handler/maxconnshandler.go @@ -8,8 +8,8 @@ import ( "github.com/zeromicro/go-zero/rest/internal" ) -// MaxConns returns a middleware that limit the concurrent connections. -func MaxConns(n int) func(http.Handler) http.Handler { +// MaxConnsHandler returns a middleware that limit the concurrent connections. +func MaxConnsHandler(n int) func(http.Handler) http.Handler { if n <= 0 { return func(next http.Handler) http.Handler { return next diff --git a/rest/handler/maxconnshandler_test.go b/rest/handler/maxconnshandler_test.go index 7f857c0c..2e483436 100644 --- a/rest/handler/maxconnshandler_test.go +++ b/rest/handler/maxconnshandler_test.go @@ -24,7 +24,7 @@ func TestMaxConnsHandler(t *testing.T) { done := make(chan lang.PlaceholderType) defer close(done) - maxConns := MaxConns(conns) + maxConns := MaxConnsHandler(conns) handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { waitGroup.Done() <-done @@ -54,7 +54,7 @@ func TestWithoutMaxConnsHandler(t *testing.T) { done := make(chan lang.PlaceholderType) defer close(done) - maxConns := MaxConns(0) + maxConns := MaxConnsHandler(0) handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { val := r.Header.Get(key) if val == value {