mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-01-23 09:00:20 +08:00
chore: avoid nested WithCodeResponseWriter (#3406)
This commit is contained in:
parent
e8c1e6e09b
commit
13cdbdc98b
@ -28,7 +28,7 @@ func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handle
|
||||
return
|
||||
}
|
||||
|
||||
cw := &response.WithCodeResponseWriter{Writer: w}
|
||||
cw := response.NewWithCodeResponseWriter(w)
|
||||
defer func() {
|
||||
if cw.Code < http.StatusInternalServerError {
|
||||
promise.Accept()
|
||||
|
@ -36,14 +36,11 @@ func LogHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
timer := utils.NewElapsedTimer()
|
||||
logs := new(internal.LogCollector)
|
||||
lrw := response.WithCodeResponseWriter{
|
||||
Writer: w,
|
||||
Code: http.StatusOK,
|
||||
}
|
||||
lrw := response.NewWithCodeResponseWriter(w)
|
||||
|
||||
var dup io.ReadCloser
|
||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||
next.ServeHTTP(&lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
|
||||
next.ServeHTTP(lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
|
||||
r.Body = dup
|
||||
logBrief(r, lrw.Code, timer, logs)
|
||||
})
|
||||
@ -54,7 +51,8 @@ type detailLoggedResponseWriter struct {
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter, buf *bytes.Buffer) *detailLoggedResponseWriter {
|
||||
func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter,
|
||||
buf *bytes.Buffer) *detailLoggedResponseWriter {
|
||||
return &detailLoggedResponseWriter{
|
||||
writer: writer,
|
||||
buf: buf,
|
||||
@ -93,10 +91,8 @@ func DetailedLogHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
timer := utils.NewElapsedTimer()
|
||||
var buf bytes.Buffer
|
||||
lrw := newDetailLoggedResponseWriter(&response.WithCodeResponseWriter{
|
||||
Writer: w,
|
||||
Code: http.StatusOK,
|
||||
}, &buf)
|
||||
rw := response.NewWithCodeResponseWriter(w)
|
||||
lrw := newDetailLoggedResponseWriter(rw, &buf)
|
||||
|
||||
var dup io.ReadCloser
|
||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||
|
@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -88,18 +89,23 @@ func TestLogHandlerSlow(t *testing.T) {
|
||||
func TestDetailedLogHandler_Hijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
writer := &detailLoggedResponseWriter{
|
||||
writer: &response.WithCodeResponseWriter{
|
||||
Writer: resp,
|
||||
},
|
||||
writer: response.NewWithCodeResponseWriter(resp),
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
_, _, _ = writer.Hijack()
|
||||
})
|
||||
|
||||
writer = &detailLoggedResponseWriter{
|
||||
writer: &response.WithCodeResponseWriter{
|
||||
Writer: mockedHijackable{resp},
|
||||
},
|
||||
writer: response.NewWithCodeResponseWriter(resp),
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
_, _, _ = writer.Hijack()
|
||||
})
|
||||
|
||||
writer = &detailLoggedResponseWriter{
|
||||
writer: response.NewWithCodeResponseWriter(mockedHijackable{
|
||||
ResponseRecorder: resp,
|
||||
}),
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
_, _, _ = writer.Hijack()
|
||||
@ -133,6 +139,13 @@ func TestWrapStatusCodeWithColor(t *testing.T) {
|
||||
assert.Equal(t, "503", wrapStatusCode(http.StatusServiceUnavailable))
|
||||
}
|
||||
|
||||
func TestDumpRequest(t *testing.T) {
|
||||
const errMsg = "error"
|
||||
r := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||
r.Body = mockedReadCloser{errMsg: errMsg}
|
||||
assert.Equal(t, errMsg, dumpRequest(r))
|
||||
}
|
||||
|
||||
func BenchmarkLogHandler(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
@ -146,3 +159,15 @@ func BenchmarkLogHandler(b *testing.B) {
|
||||
handler.ServeHTTP(resp, req)
|
||||
}
|
||||
}
|
||||
|
||||
type mockedReadCloser struct {
|
||||
errMsg string
|
||||
}
|
||||
|
||||
func (m mockedReadCloser) Read(p []byte) (n int, err error) {
|
||||
return 0, errors.New(m.errMsg)
|
||||
}
|
||||
|
||||
func (m mockedReadCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ func PrometheusHandler(path, method string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := timex.Now()
|
||||
cw := &response.WithCodeResponseWriter{Writer: w}
|
||||
cw := response.NewWithCodeResponseWriter(w)
|
||||
defer func() {
|
||||
metricServerReqDur.Observe(timex.Since(startTime).Milliseconds(), path, method)
|
||||
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code), method)
|
||||
|
@ -41,7 +41,7 @@ func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Hand
|
||||
return
|
||||
}
|
||||
|
||||
cw := &response.WithCodeResponseWriter{Writer: w}
|
||||
cw := response.NewWithCodeResponseWriter(w)
|
||||
defer func() {
|
||||
if cw.Code == http.StatusServiceUnavailable {
|
||||
promise.Fail()
|
||||
|
@ -70,6 +70,7 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w: w,
|
||||
h: make(http.Header),
|
||||
req: r,
|
||||
code: http.StatusOK,
|
||||
}
|
||||
panicChan := make(chan any, 1)
|
||||
go func() {
|
||||
@ -91,10 +92,12 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
for k, vv := range tw.h {
|
||||
dst[k] = vv
|
||||
}
|
||||
if !tw.wroteHeader {
|
||||
tw.code = http.StatusOK
|
||||
}
|
||||
|
||||
// We don't need to write header 200, because it's written by default.
|
||||
// If we write it again, it will cause a warning: `http: superfluous response.WriteHeader call`.
|
||||
if tw.code != http.StatusOK {
|
||||
w.WriteHeader(tw.code)
|
||||
}
|
||||
w.Write(tw.wbuf.Bytes())
|
||||
case <-ctx.Done():
|
||||
tw.mu.Lock()
|
||||
|
@ -100,6 +100,18 @@ func TestWithinTimeout(t *testing.T) {
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
func TestWithinTimeoutBadCode(t *testing.T) {
|
||||
timeoutHandler := TimeoutHandler(time.Second)
|
||||
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, resp.Code)
|
||||
}
|
||||
|
||||
func TestWithTimeoutTimedout(t *testing.T) {
|
||||
timeoutHandler := TimeoutHandler(time.Millisecond)
|
||||
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@ -208,9 +220,7 @@ func TestTimeoutHijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
writer := &timeoutWriter{
|
||||
w: &response.WithCodeResponseWriter{
|
||||
Writer: resp,
|
||||
},
|
||||
w: response.NewWithCodeResponseWriter(resp),
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
@ -218,9 +228,7 @@ func TestTimeoutHijack(t *testing.T) {
|
||||
})
|
||||
|
||||
writer = &timeoutWriter{
|
||||
w: &response.WithCodeResponseWriter{
|
||||
Writer: mockedHijackable{resp},
|
||||
},
|
||||
w: response.NewWithCodeResponseWriter(mockedHijackable{resp}),
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
@ -274,9 +282,7 @@ func TestTimeoutWriter_Hijack(t *testing.T) {
|
||||
func TestTimeoutWroteTwice(t *testing.T) {
|
||||
c := logtest.NewCollector(t)
|
||||
writer := &timeoutWriter{
|
||||
w: &response.WithCodeResponseWriter{
|
||||
Writer: httptest.NewRecorder(),
|
||||
},
|
||||
w: response.NewWithCodeResponseWriter(httptest.NewRecorder()),
|
||||
h: make(http.Header),
|
||||
req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ func TraceHandler(serviceName, path string, opts ...TraceOption) func(http.Handl
|
||||
// convenient for tracking error messages
|
||||
propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
|
||||
|
||||
trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK}
|
||||
trw := response.NewWithCodeResponseWriter(w)
|
||||
next.ServeHTTP(trw, r.WithContext(spanCtx))
|
||||
|
||||
span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...)
|
||||
|
@ -13,6 +13,20 @@ type WithCodeResponseWriter struct {
|
||||
Code int
|
||||
}
|
||||
|
||||
// NewWithCodeResponseWriter returns a WithCodeResponseWriter.
|
||||
// If writer is already a WithCodeResponseWriter, it returns writer directly.
|
||||
func NewWithCodeResponseWriter(writer http.ResponseWriter) *WithCodeResponseWriter {
|
||||
switch w := writer.(type) {
|
||||
case *WithCodeResponseWriter:
|
||||
return w
|
||||
default:
|
||||
return &WithCodeResponseWriter{
|
||||
Writer: writer,
|
||||
Code: http.StatusOK,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush flushes the response writer.
|
||||
func (w *WithCodeResponseWriter) Flush() {
|
||||
if flusher, ok := w.Writer.(http.Flusher); ok {
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
func TestWithCodeResponseWriter(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cw := &WithCodeResponseWriter{Writer: w}
|
||||
cw := NewWithCodeResponseWriter(w)
|
||||
|
||||
cw.Header().Set("X-Test", "test")
|
||||
cw.WriteHeader(http.StatusServiceUnavailable)
|
||||
@ -34,9 +34,7 @@ func TestWithCodeResponseWriter(t *testing.T) {
|
||||
|
||||
func TestWithCodeResponseWriter_Hijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
writer := &WithCodeResponseWriter{
|
||||
Writer: resp,
|
||||
}
|
||||
writer := NewWithCodeResponseWriter(NewWithCodeResponseWriter(resp))
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user