feat: add middlewares config for rest (#2765)

* feat: add middlewares config for rest

* chore: disable logs in tests

* chore: enable verbose in tests
This commit is contained in:
Kevin Wan 2023-01-08 16:41:53 +08:00 committed by GitHub
parent f4502171ea
commit ade6f9ee46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 67 additions and 17 deletions

View File

@ -7,6 +7,21 @@ import (
) )
type ( type (
// MiddlewaresConf is the config of middlewares.
MiddlewaresConf struct {
Trace bool `json:",default=true"`
Log bool `json:",default=true"`
Prometheus bool `json:",default=true"`
MaxConns bool `json:",default=true"`
Breaker bool `json:",default=true"`
Shedding bool `json:",default=true"`
Timeout bool `json:",default=true"`
Recover bool `json:",default=true"`
Metrics bool `json:",default=true"`
MaxBytes bool `json:",default=true"`
Gunzip bool `json:",default=true"`
}
// A PrivateKeyConf is a private key config. // A PrivateKeyConf is a private key config.
PrivateKeyConf struct { PrivateKeyConf struct {
Fingerprint string Fingerprint string
@ -40,5 +55,7 @@ type (
Timeout int64 `json:",default=3000"` Timeout int64 `json:",default=3000"`
CpuThreshold int64 `json:",default=900,range=[0:1000]"` CpuThreshold int64 `json:",default=900,range=[0:1000]"`
Signature SignatureConf `json:",optional"` Signature SignatureConf `json:",optional"`
// There are default values for all the items in Middlewares.
Middlewares MiddlewaresConf
} }
) )

View File

@ -88,19 +88,7 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
route Route, verifier func(chain.Chain) chain.Chain) error { route Route, verifier func(chain.Chain) chain.Chain) error {
chn := ng.chain chn := ng.chain
if chn == nil { if chn == nil {
chn = chain.New( chn = ng.buildChainWithNativeMiddlewares(fr, route, metrics)
handler.TracingHandler(ng.conf.Name, route.Path),
ng.getLogHandler(),
handler.PrometheusHandler(route.Path),
handler.MaxConns(ng.conf.MaxConns),
handler.BreakerHandler(route.Method, route.Path, metrics),
handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
handler.RecoverHandler,
handler.MetricHandler(metrics),
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
handler.GunzipHandler,
)
} }
chn = ng.appendAuthHandler(fr, chn, verifier) chn = ng.appendAuthHandler(fr, chn, verifier)
@ -125,6 +113,47 @@ func (ng *engine) bindRoutes(router httpx.Router) error {
return nil return nil
} }
func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route,
metrics *stat.Metrics) chain.Chain {
chn := chain.New()
if ng.conf.Middlewares.Trace {
chn = chn.Append(handler.TracingHandler(ng.conf.Name, route.Path))
}
if ng.conf.Middlewares.Log {
chn = chn.Append(ng.getLogHandler())
}
if ng.conf.Middlewares.Prometheus {
chn = chn.Append(handler.PrometheusHandler(route.Path))
}
if ng.conf.Middlewares.MaxConns {
chn = chn.Append(handler.MaxConnsHandler(ng.conf.MaxConns))
}
if ng.conf.Middlewares.Breaker {
chn = chn.Append(handler.BreakerHandler(route.Method, route.Path, metrics))
}
if ng.conf.Middlewares.Shedding {
chn = chn.Append(handler.SheddingHandler(ng.getShedder(fr.priority), metrics))
}
if ng.conf.Middlewares.Timeout {
chn = chn.Append(handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)))
}
if ng.conf.Middlewares.Recover {
chn = chn.Append(handler.RecoverHandler)
}
if ng.conf.Middlewares.Metrics {
chn = chn.Append(handler.MetricHandler(metrics))
}
if ng.conf.Middlewares.MaxBytes {
chn = chn.Append(handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)))
}
if ng.conf.Middlewares.Gunzip {
chn = chn.Append(handler.GunzipHandler)
}
return chn
}
func (ng *engine) checkedMaxBytes(bytes int64) int64 { func (ng *engine) checkedMaxBytes(bytes int64) int64 {
if bytes > 0 { if bytes > 0 {
return bytes return bytes

View File

@ -18,10 +18,14 @@ func TestNewEngine(t *testing.T) {
yamls := []string{ yamls := []string{
`Name: foo `Name: foo
Port: 54321 Port: 54321
Middlewares:
Log: false
`, `,
`Name: foo `Name: foo
Port: 54321 Port: 54321
CpuThreshold: 500 CpuThreshold: 500
Middlewares:
Log: false
`, `,
`Name: foo `Name: foo
Port: 54321 Port: 54321

View File

@ -8,8 +8,8 @@ import (
"github.com/zeromicro/go-zero/rest/internal" "github.com/zeromicro/go-zero/rest/internal"
) )
// MaxConns returns a middleware that limit the concurrent connections. // MaxConnsHandler returns a middleware that limit the concurrent connections.
func MaxConns(n int) func(http.Handler) http.Handler { func MaxConnsHandler(n int) func(http.Handler) http.Handler {
if n <= 0 { if n <= 0 {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return next return next

View File

@ -24,7 +24,7 @@ func TestMaxConnsHandler(t *testing.T) {
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
defer close(done) defer close(done)
maxConns := MaxConns(conns) maxConns := MaxConnsHandler(conns)
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
waitGroup.Done() waitGroup.Done()
<-done <-done
@ -54,7 +54,7 @@ func TestWithoutMaxConnsHandler(t *testing.T) {
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
defer close(done) defer close(done)
maxConns := MaxConns(0) maxConns := MaxConnsHandler(0)
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
val := r.Header.Get(key) val := r.Header.Get(key)
if val == value { if val == value {