chore: avoid nested WithCodeResponseWriter (#3406)

This commit is contained in:
Kevin Wan 2023-07-11 23:59:43 +08:00 committed by GitHub
parent e8c1e6e09b
commit 13cdbdc98b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 81 additions and 39 deletions

View File

@ -28,7 +28,7 @@ func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handle
return
}
cw := &response.WithCodeResponseWriter{Writer: w}
cw := response.NewWithCodeResponseWriter(w)
defer func() {
if cw.Code < http.StatusInternalServerError {
promise.Accept()

View File

@ -36,14 +36,11 @@ func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer()
logs := new(internal.LogCollector)
lrw := response.WithCodeResponseWriter{
Writer: w,
Code: http.StatusOK,
}
lrw := response.NewWithCodeResponseWriter(w)
var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body)
next.ServeHTTP(&lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
next.ServeHTTP(lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
r.Body = dup
logBrief(r, lrw.Code, timer, logs)
})
@ -54,7 +51,8 @@ type detailLoggedResponseWriter struct {
buf *bytes.Buffer
}
func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter, buf *bytes.Buffer) *detailLoggedResponseWriter {
func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter,
buf *bytes.Buffer) *detailLoggedResponseWriter {
return &detailLoggedResponseWriter{
writer: writer,
buf: buf,
@ -93,10 +91,8 @@ func DetailedLogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer()
var buf bytes.Buffer
lrw := newDetailLoggedResponseWriter(&response.WithCodeResponseWriter{
Writer: w,
Code: http.StatusOK,
}, &buf)
rw := response.NewWithCodeResponseWriter(w)
lrw := newDetailLoggedResponseWriter(rw, &buf)
var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body)

View File

@ -2,6 +2,7 @@ package handler
import (
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
@ -88,18 +89,23 @@ func TestLogHandlerSlow(t *testing.T) {
func TestDetailedLogHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := &detailLoggedResponseWriter{
writer: &response.WithCodeResponseWriter{
Writer: resp,
},
writer: response.NewWithCodeResponseWriter(resp),
}
assert.NotPanics(t, func() {
_, _, _ = writer.Hijack()
})
writer = &detailLoggedResponseWriter{
writer: &response.WithCodeResponseWriter{
Writer: mockedHijackable{resp},
},
writer: response.NewWithCodeResponseWriter(resp),
}
assert.NotPanics(t, func() {
_, _, _ = writer.Hijack()
})
writer = &detailLoggedResponseWriter{
writer: response.NewWithCodeResponseWriter(mockedHijackable{
ResponseRecorder: resp,
}),
}
assert.NotPanics(t, func() {
_, _, _ = writer.Hijack()
@ -133,6 +139,13 @@ func TestWrapStatusCodeWithColor(t *testing.T) {
assert.Equal(t, "503", wrapStatusCode(http.StatusServiceUnavailable))
}
func TestDumpRequest(t *testing.T) {
const errMsg = "error"
r := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
r.Body = mockedReadCloser{errMsg: errMsg}
assert.Equal(t, errMsg, dumpRequest(r))
}
func BenchmarkLogHandler(b *testing.B) {
b.ReportAllocs()
@ -146,3 +159,15 @@ func BenchmarkLogHandler(b *testing.B) {
handler.ServeHTTP(resp, req)
}
}
type mockedReadCloser struct {
errMsg string
}
func (m mockedReadCloser) Read(p []byte) (n int, err error) {
return 0, errors.New(m.errMsg)
}
func (m mockedReadCloser) Close() error {
return nil
}

View File

@ -35,7 +35,7 @@ func PrometheusHandler(path, method string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
startTime := timex.Now()
cw := &response.WithCodeResponseWriter{Writer: w}
cw := response.NewWithCodeResponseWriter(w)
defer func() {
metricServerReqDur.Observe(timex.Since(startTime).Milliseconds(), path, method)
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code), method)

View File

@ -41,7 +41,7 @@ func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Hand
return
}
cw := &response.WithCodeResponseWriter{Writer: w}
cw := response.NewWithCodeResponseWriter(w)
defer func() {
if cw.Code == http.StatusServiceUnavailable {
promise.Fail()

View File

@ -67,9 +67,10 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(ctx)
done := make(chan struct{})
tw := &timeoutWriter{
w: w,
h: make(http.Header),
req: r,
w: w,
h: make(http.Header),
req: r,
code: http.StatusOK,
}
panicChan := make(chan any, 1)
go func() {
@ -91,10 +92,12 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for k, vv := range tw.h {
dst[k] = vv
}
if !tw.wroteHeader {
tw.code = http.StatusOK
// We don't need to write header 200, because it's written by default.
// If we write it again, it will cause a warning: `http: superfluous response.WriteHeader call`.
if tw.code != http.StatusOK {
w.WriteHeader(tw.code)
}
w.WriteHeader(tw.code)
w.Write(tw.wbuf.Bytes())
case <-ctx.Done():
tw.mu.Lock()

View File

@ -100,6 +100,18 @@ func TestWithinTimeout(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.Code)
}
func TestWithinTimeoutBadCode(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Second)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusInternalServerError, resp.Code)
}
func TestWithTimeoutTimedout(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Millisecond)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -208,9 +220,7 @@ func TestTimeoutHijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := &timeoutWriter{
w: &response.WithCodeResponseWriter{
Writer: resp,
},
w: response.NewWithCodeResponseWriter(resp),
}
assert.NotPanics(t, func() {
@ -218,9 +228,7 @@ func TestTimeoutHijack(t *testing.T) {
})
writer = &timeoutWriter{
w: &response.WithCodeResponseWriter{
Writer: mockedHijackable{resp},
},
w: response.NewWithCodeResponseWriter(mockedHijackable{resp}),
}
assert.NotPanics(t, func() {
@ -274,9 +282,7 @@ func TestTimeoutWriter_Hijack(t *testing.T) {
func TestTimeoutWroteTwice(t *testing.T) {
c := logtest.NewCollector(t)
writer := &timeoutWriter{
w: &response.WithCodeResponseWriter{
Writer: httptest.NewRecorder(),
},
w: response.NewWithCodeResponseWriter(httptest.NewRecorder()),
h: make(http.Header),
req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
}

View File

@ -60,7 +60,7 @@ func TraceHandler(serviceName, path string, opts ...TraceOption) func(http.Handl
// convenient for tracking error messages
propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK}
trw := response.NewWithCodeResponseWriter(w)
next.ServeHTTP(trw, r.WithContext(spanCtx))
span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...)

View File

@ -13,6 +13,20 @@ type WithCodeResponseWriter struct {
Code int
}
// NewWithCodeResponseWriter returns a WithCodeResponseWriter.
// If writer is already a WithCodeResponseWriter, it returns writer directly.
func NewWithCodeResponseWriter(writer http.ResponseWriter) *WithCodeResponseWriter {
switch w := writer.(type) {
case *WithCodeResponseWriter:
return w
default:
return &WithCodeResponseWriter{
Writer: writer,
Code: http.StatusOK,
}
}
}
// Flush flushes the response writer.
func (w *WithCodeResponseWriter) Flush() {
if flusher, ok := w.Writer.(http.Flusher); ok {

View File

@ -11,7 +11,7 @@ import (
func TestWithCodeResponseWriter(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cw := &WithCodeResponseWriter{Writer: w}
cw := NewWithCodeResponseWriter(w)
cw.Header().Set("X-Test", "test")
cw.WriteHeader(http.StatusServiceUnavailable)
@ -34,9 +34,7 @@ func TestWithCodeResponseWriter(t *testing.T) {
func TestWithCodeResponseWriter_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := &WithCodeResponseWriter{
Writer: resp,
}
writer := NewWithCodeResponseWriter(NewWithCodeResponseWriter(resp))
assert.NotPanics(t, func() {
writer.Hijack()
})