mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-01-23 09:00:20 +08:00
chore: avoid nested WithCodeResponseWriter (#3406)
This commit is contained in:
parent
e8c1e6e09b
commit
13cdbdc98b
@ -28,7 +28,7 @@ func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handle
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cw := &response.WithCodeResponseWriter{Writer: w}
|
cw := response.NewWithCodeResponseWriter(w)
|
||||||
defer func() {
|
defer func() {
|
||||||
if cw.Code < http.StatusInternalServerError {
|
if cw.Code < http.StatusInternalServerError {
|
||||||
promise.Accept()
|
promise.Accept()
|
||||||
|
@ -36,14 +36,11 @@ func LogHandler(next http.Handler) http.Handler {
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
timer := utils.NewElapsedTimer()
|
timer := utils.NewElapsedTimer()
|
||||||
logs := new(internal.LogCollector)
|
logs := new(internal.LogCollector)
|
||||||
lrw := response.WithCodeResponseWriter{
|
lrw := response.NewWithCodeResponseWriter(w)
|
||||||
Writer: w,
|
|
||||||
Code: http.StatusOK,
|
|
||||||
}
|
|
||||||
|
|
||||||
var dup io.ReadCloser
|
var dup io.ReadCloser
|
||||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||||
next.ServeHTTP(&lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
|
next.ServeHTTP(lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
|
||||||
r.Body = dup
|
r.Body = dup
|
||||||
logBrief(r, lrw.Code, timer, logs)
|
logBrief(r, lrw.Code, timer, logs)
|
||||||
})
|
})
|
||||||
@ -54,7 +51,8 @@ type detailLoggedResponseWriter struct {
|
|||||||
buf *bytes.Buffer
|
buf *bytes.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter, buf *bytes.Buffer) *detailLoggedResponseWriter {
|
func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter,
|
||||||
|
buf *bytes.Buffer) *detailLoggedResponseWriter {
|
||||||
return &detailLoggedResponseWriter{
|
return &detailLoggedResponseWriter{
|
||||||
writer: writer,
|
writer: writer,
|
||||||
buf: buf,
|
buf: buf,
|
||||||
@ -93,10 +91,8 @@ func DetailedLogHandler(next http.Handler) http.Handler {
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
timer := utils.NewElapsedTimer()
|
timer := utils.NewElapsedTimer()
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
lrw := newDetailLoggedResponseWriter(&response.WithCodeResponseWriter{
|
rw := response.NewWithCodeResponseWriter(w)
|
||||||
Writer: w,
|
lrw := newDetailLoggedResponseWriter(rw, &buf)
|
||||||
Code: http.StatusOK,
|
|
||||||
}, &buf)
|
|
||||||
|
|
||||||
var dup io.ReadCloser
|
var dup io.ReadCloser
|
||||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||||
|
@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -88,18 +89,23 @@ func TestLogHandlerSlow(t *testing.T) {
|
|||||||
func TestDetailedLogHandler_Hijack(t *testing.T) {
|
func TestDetailedLogHandler_Hijack(t *testing.T) {
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
writer := &detailLoggedResponseWriter{
|
writer := &detailLoggedResponseWriter{
|
||||||
writer: &response.WithCodeResponseWriter{
|
writer: response.NewWithCodeResponseWriter(resp),
|
||||||
Writer: resp,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
_, _, _ = writer.Hijack()
|
_, _, _ = writer.Hijack()
|
||||||
})
|
})
|
||||||
|
|
||||||
writer = &detailLoggedResponseWriter{
|
writer = &detailLoggedResponseWriter{
|
||||||
writer: &response.WithCodeResponseWriter{
|
writer: response.NewWithCodeResponseWriter(resp),
|
||||||
Writer: mockedHijackable{resp},
|
}
|
||||||
},
|
assert.NotPanics(t, func() {
|
||||||
|
_, _, _ = writer.Hijack()
|
||||||
|
})
|
||||||
|
|
||||||
|
writer = &detailLoggedResponseWriter{
|
||||||
|
writer: response.NewWithCodeResponseWriter(mockedHijackable{
|
||||||
|
ResponseRecorder: resp,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
_, _, _ = writer.Hijack()
|
_, _, _ = writer.Hijack()
|
||||||
@ -133,6 +139,13 @@ func TestWrapStatusCodeWithColor(t *testing.T) {
|
|||||||
assert.Equal(t, "503", wrapStatusCode(http.StatusServiceUnavailable))
|
assert.Equal(t, "503", wrapStatusCode(http.StatusServiceUnavailable))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDumpRequest(t *testing.T) {
|
||||||
|
const errMsg = "error"
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||||
|
r.Body = mockedReadCloser{errMsg: errMsg}
|
||||||
|
assert.Equal(t, errMsg, dumpRequest(r))
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkLogHandler(b *testing.B) {
|
func BenchmarkLogHandler(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
|
||||||
@ -146,3 +159,15 @@ func BenchmarkLogHandler(b *testing.B) {
|
|||||||
handler.ServeHTTP(resp, req)
|
handler.ServeHTTP(resp, req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockedReadCloser struct {
|
||||||
|
errMsg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockedReadCloser) Read(p []byte) (n int, err error) {
|
||||||
|
return 0, errors.New(m.errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockedReadCloser) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -35,7 +35,7 @@ func PrometheusHandler(path, method string) func(http.Handler) http.Handler {
|
|||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
startTime := timex.Now()
|
startTime := timex.Now()
|
||||||
cw := &response.WithCodeResponseWriter{Writer: w}
|
cw := response.NewWithCodeResponseWriter(w)
|
||||||
defer func() {
|
defer func() {
|
||||||
metricServerReqDur.Observe(timex.Since(startTime).Milliseconds(), path, method)
|
metricServerReqDur.Observe(timex.Since(startTime).Milliseconds(), path, method)
|
||||||
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code), method)
|
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code), method)
|
||||||
|
@ -41,7 +41,7 @@ func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Hand
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cw := &response.WithCodeResponseWriter{Writer: w}
|
cw := response.NewWithCodeResponseWriter(w)
|
||||||
defer func() {
|
defer func() {
|
||||||
if cw.Code == http.StatusServiceUnavailable {
|
if cw.Code == http.StatusServiceUnavailable {
|
||||||
promise.Fail()
|
promise.Fail()
|
||||||
|
@ -67,9 +67,10 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
tw := &timeoutWriter{
|
tw := &timeoutWriter{
|
||||||
w: w,
|
w: w,
|
||||||
h: make(http.Header),
|
h: make(http.Header),
|
||||||
req: r,
|
req: r,
|
||||||
|
code: http.StatusOK,
|
||||||
}
|
}
|
||||||
panicChan := make(chan any, 1)
|
panicChan := make(chan any, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@ -91,10 +92,12 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
for k, vv := range tw.h {
|
for k, vv := range tw.h {
|
||||||
dst[k] = vv
|
dst[k] = vv
|
||||||
}
|
}
|
||||||
if !tw.wroteHeader {
|
|
||||||
tw.code = http.StatusOK
|
// We don't need to write header 200, because it's written by default.
|
||||||
|
// If we write it again, it will cause a warning: `http: superfluous response.WriteHeader call`.
|
||||||
|
if tw.code != http.StatusOK {
|
||||||
|
w.WriteHeader(tw.code)
|
||||||
}
|
}
|
||||||
w.WriteHeader(tw.code)
|
|
||||||
w.Write(tw.wbuf.Bytes())
|
w.Write(tw.wbuf.Bytes())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
tw.mu.Lock()
|
tw.mu.Lock()
|
||||||
|
@ -100,6 +100,18 @@ func TestWithinTimeout(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, resp.Code)
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithinTimeoutBadCode(t *testing.T) {
|
||||||
|
timeoutHandler := TimeoutHandler(time.Second)
|
||||||
|
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(resp, req)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
func TestWithTimeoutTimedout(t *testing.T) {
|
func TestWithTimeoutTimedout(t *testing.T) {
|
||||||
timeoutHandler := TimeoutHandler(time.Millisecond)
|
timeoutHandler := TimeoutHandler(time.Millisecond)
|
||||||
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -208,9 +220,7 @@ func TestTimeoutHijack(t *testing.T) {
|
|||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
|
|
||||||
writer := &timeoutWriter{
|
writer := &timeoutWriter{
|
||||||
w: &response.WithCodeResponseWriter{
|
w: response.NewWithCodeResponseWriter(resp),
|
||||||
Writer: resp,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
@ -218,9 +228,7 @@ func TestTimeoutHijack(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
writer = &timeoutWriter{
|
writer = &timeoutWriter{
|
||||||
w: &response.WithCodeResponseWriter{
|
w: response.NewWithCodeResponseWriter(mockedHijackable{resp}),
|
||||||
Writer: mockedHijackable{resp},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
@ -274,9 +282,7 @@ func TestTimeoutWriter_Hijack(t *testing.T) {
|
|||||||
func TestTimeoutWroteTwice(t *testing.T) {
|
func TestTimeoutWroteTwice(t *testing.T) {
|
||||||
c := logtest.NewCollector(t)
|
c := logtest.NewCollector(t)
|
||||||
writer := &timeoutWriter{
|
writer := &timeoutWriter{
|
||||||
w: &response.WithCodeResponseWriter{
|
w: response.NewWithCodeResponseWriter(httptest.NewRecorder()),
|
||||||
Writer: httptest.NewRecorder(),
|
|
||||||
},
|
|
||||||
h: make(http.Header),
|
h: make(http.Header),
|
||||||
req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
|
req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
|
||||||
}
|
}
|
||||||
|
@ -60,7 +60,7 @@ func TraceHandler(serviceName, path string, opts ...TraceOption) func(http.Handl
|
|||||||
// convenient for tracking error messages
|
// convenient for tracking error messages
|
||||||
propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
|
propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
|
||||||
|
|
||||||
trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK}
|
trw := response.NewWithCodeResponseWriter(w)
|
||||||
next.ServeHTTP(trw, r.WithContext(spanCtx))
|
next.ServeHTTP(trw, r.WithContext(spanCtx))
|
||||||
|
|
||||||
span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...)
|
span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...)
|
||||||
|
@ -13,6 +13,20 @@ type WithCodeResponseWriter struct {
|
|||||||
Code int
|
Code int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewWithCodeResponseWriter returns a WithCodeResponseWriter.
|
||||||
|
// If writer is already a WithCodeResponseWriter, it returns writer directly.
|
||||||
|
func NewWithCodeResponseWriter(writer http.ResponseWriter) *WithCodeResponseWriter {
|
||||||
|
switch w := writer.(type) {
|
||||||
|
case *WithCodeResponseWriter:
|
||||||
|
return w
|
||||||
|
default:
|
||||||
|
return &WithCodeResponseWriter{
|
||||||
|
Writer: writer,
|
||||||
|
Code: http.StatusOK,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Flush flushes the response writer.
|
// Flush flushes the response writer.
|
||||||
func (w *WithCodeResponseWriter) Flush() {
|
func (w *WithCodeResponseWriter) Flush() {
|
||||||
if flusher, ok := w.Writer.(http.Flusher); ok {
|
if flusher, ok := w.Writer.(http.Flusher); ok {
|
||||||
|
@ -11,7 +11,7 @@ import (
|
|||||||
func TestWithCodeResponseWriter(t *testing.T) {
|
func TestWithCodeResponseWriter(t *testing.T) {
|
||||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
cw := &WithCodeResponseWriter{Writer: w}
|
cw := NewWithCodeResponseWriter(w)
|
||||||
|
|
||||||
cw.Header().Set("X-Test", "test")
|
cw.Header().Set("X-Test", "test")
|
||||||
cw.WriteHeader(http.StatusServiceUnavailable)
|
cw.WriteHeader(http.StatusServiceUnavailable)
|
||||||
@ -34,9 +34,7 @@ func TestWithCodeResponseWriter(t *testing.T) {
|
|||||||
|
|
||||||
func TestWithCodeResponseWriter_Hijack(t *testing.T) {
|
func TestWithCodeResponseWriter_Hijack(t *testing.T) {
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
writer := &WithCodeResponseWriter{
|
writer := NewWithCodeResponseWriter(NewWithCodeResponseWriter(resp))
|
||||||
Writer: resp,
|
|
||||||
}
|
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
writer.Hijack()
|
writer.Hijack()
|
||||||
})
|
})
|
||||||
|
Loading…
Reference in New Issue
Block a user