go-zero/rest/handler/maxconnshandler_test.go

80 lines
1.7 KiB
Go

package handler
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/lang"
)
// conns is the number of requests each test parks inside the wrapped handler
// at the same time; TestMaxConnsHandler also passes it to MaxConns as the limit.
const (
	conns = 4
)
// init points the standard library logger at a writer that drops everything,
// so anything written through package log while these tests run does not
// clutter the test output.
func init() {
	discard := ioutil.Discard
	log.SetOutput(discard)
}
// TestMaxConnsHandler verifies that a handler wrapped with MaxConns(conns),
// while conns requests are still in flight inside it, answers one more
// request with 503 Service Unavailable.
//
// Only requests carrying the "block" header park inside the inner handler.
// The probe request carries no header. If the limit ever regresses and lets
// the probe through, the inner handler returns immediately. The recorder then
// reports its default 200, and the assertion fails cleanly. Without this
// gating, an admitted probe would call WaitGroup.Done once too often and panic
// on a negative counter.
func TestMaxConnsHandler(t *testing.T) {
	const (
		key   = "block"
		value = "1"
	)

	// entered counts requests that have reached the inner handler;
	// finished counts worker goroutines that have returned from ServeHTTP.
	var entered, finished sync.WaitGroup
	entered.Add(conns)
	done := make(chan lang.PlaceholderType)
	// Release the parked requests, then wait for their goroutines to exit so
	// none of them outlive this test.
	defer func() {
		close(done)
		finished.Wait()
	}()

	maxConns := MaxConns(conns)
	handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get(key) == value {
			entered.Done()
			<-done
		}
	}))

	for i := 0; i < conns; i++ {
		finished.Add(1)
		go func() {
			defer finished.Done()
			req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
			req.Header.Set(key, value)
			handler.ServeHTTP(httptest.NewRecorder(), req)
		}()
	}

	// All conns worker requests are now parked inside the inner handler.
	entered.Wait()

	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
	resp := httptest.NewRecorder()
	handler.ServeHTTP(resp, req)
	assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}
// TestWithoutMaxConnsHandler verifies that a handler wrapped with MaxConns(0)
// still serves a request with 200 OK while conns other requests are parked
// inside it.
//
// Only requests carrying the "block" header park inside the inner handler.
// The probe request carries no header, so it returns as soon as it reaches
// the inner handler.
func TestWithoutMaxConnsHandler(t *testing.T) {
	const (
		key   = "block"
		value = "1"
	)

	// entered counts requests that have reached the inner handler;
	// finished counts worker goroutines that have returned from ServeHTTP.
	var entered, finished sync.WaitGroup
	entered.Add(conns)
	done := make(chan lang.PlaceholderType)
	// Release the parked requests, then wait for their goroutines to exit so
	// none of them outlive this test.
	defer func() {
		close(done)
		finished.Wait()
	}()

	maxConns := MaxConns(0)
	handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		val := r.Header.Get(key)
		if val == value {
			entered.Done()
			<-done
		}
	}))

	for i := 0; i < conns; i++ {
		finished.Add(1)
		go func() {
			defer finished.Done()
			req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
			req.Header.Set(key, value)
			handler.ServeHTTP(httptest.NewRecorder(), req)
		}()
	}

	// All conns worker requests are now parked inside the inner handler.
	entered.Wait()

	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
	resp := httptest.NewRecorder()
	handler.ServeHTTP(resp, req)
	assert.Equal(t, http.StatusOK, resp.Code)
}