go-zero/rest/handler/maxconnshandler_test.go

74 lines
1.7 KiB
Go
Raw Normal View History

2020-07-29 18:00:04 +08:00
package handler
2020-07-26 17:09:05 +08:00
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/lang"
2020-07-26 17:09:05 +08:00
)
// conns is the number of requests the tests in this file keep blocked in
// flight at the same time; TestMaxConnsHandler also passes it to
// MaxConnsHandler as the connection limit.
const (
	conns = 4
)
func TestMaxConnsHandler(t *testing.T) {
var waitGroup sync.WaitGroup
waitGroup.Add(conns)
done := make(chan lang.PlaceholderType)
defer close(done)
maxConns := MaxConnsHandler(conns)
2020-07-26 17:09:05 +08:00
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
waitGroup.Done()
<-done
}))
for i := 0; i < conns; i++ {
go func() {
2022-10-17 06:30:58 +08:00
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
2020-07-26 17:09:05 +08:00
handler.ServeHTTP(httptest.NewRecorder(), req)
}()
}
waitGroup.Wait()
2022-10-17 06:30:58 +08:00
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
2020-07-26 17:09:05 +08:00
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}
func TestWithoutMaxConnsHandler(t *testing.T) {
const (
key = "block"
value = "1"
)
var waitGroup sync.WaitGroup
waitGroup.Add(conns)
done := make(chan lang.PlaceholderType)
defer close(done)
maxConns := MaxConnsHandler(0)
2020-07-26 17:09:05 +08:00
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
val := r.Header.Get(key)
if val == value {
waitGroup.Done()
<-done
}
}))
for i := 0; i < conns; i++ {
go func() {
2022-10-17 06:30:58 +08:00
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
2020-07-26 17:09:05 +08:00
req.Header.Set(key, value)
handler.ServeHTTP(httptest.NewRecorder(), req)
}()
}
waitGroup.Wait()
2022-10-17 06:30:58 +08:00
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
2020-07-26 17:09:05 +08:00
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}