mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-02-03 00:38:40 +08:00
refactor(rest): keep rest log collector context key private (#3407)
This commit is contained in:
parent
b71453985c
commit
61e562d0c7
@ -3,7 +3,6 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -44,7 +43,7 @@ func LogHandler(next http.Handler) http.Handler {
|
|||||||
|
|
||||||
var dup io.ReadCloser
|
var dup io.ReadCloser
|
||||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||||
next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
|
next.ServeHTTP(&lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
|
||||||
r.Body = dup
|
r.Body = dup
|
||||||
logBrief(r, lrw.Code, timer, logs)
|
logBrief(r, lrw.Code, timer, logs)
|
||||||
})
|
})
|
||||||
@ -102,7 +101,7 @@ func DetailedLogHandler(next http.Handler) http.Handler {
|
|||||||
var dup io.ReadCloser
|
var dup io.ReadCloser
|
||||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||||
logs := new(internal.LogCollector)
|
logs := new(internal.LogCollector)
|
||||||
next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
|
next.ServeHTTP(lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
|
||||||
r.Body = dup
|
r.Body = dup
|
||||||
logDetails(r, lrw, timer, logs)
|
logDetails(r, lrw, timer, logs)
|
||||||
})
|
})
|
||||||
|
@ -22,7 +22,7 @@ func TestLogHandler(t *testing.T) {
|
|||||||
for _, logHandler := range handlers {
|
for _, logHandler := range handlers {
|
||||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||||
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything")
|
internal.LogCollectorFromContext(r.Context()).Append("anything")
|
||||||
w.Header().Set("X-Test", "test")
|
w.Header().Set("X-Test", "test")
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
_, err := w.Write([]byte("content"))
|
_, err := w.Write([]byte("content"))
|
||||||
@ -49,7 +49,7 @@ func TestLogHandlerVeryLong(t *testing.T) {
|
|||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buf)
|
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buf)
|
||||||
handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything")
|
internal.LogCollectorFromContext(r.Context()).Append("anything")
|
||||||
_, _ = io.Copy(io.Discard, r.Body)
|
_, _ = io.Copy(io.Discard, r.Body)
|
||||||
w.Header().Set("X-Test", "test")
|
w.Header().Set("X-Test", "test")
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
@ -2,6 +2,7 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
@ -10,15 +11,34 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LogContext is a context key.
|
// logContextKey is a context key.
|
||||||
var LogContext = contextKey("request_logs")
|
var logContextKey = contextKey("request_logs")
|
||||||
|
|
||||||
// A LogCollector is used to collect logs.
|
type (
|
||||||
type LogCollector struct {
|
// LogCollector is used to collect logs.
|
||||||
|
LogCollector struct {
|
||||||
Messages []string
|
Messages []string
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
contextKey string
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithLogCollector returns a new context with LogCollector.
|
||||||
|
func WithLogCollector(ctx context.Context, lc *LogCollector) context.Context {
|
||||||
|
return context.WithValue(ctx, logContextKey, lc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogCollectorFromContext returns LogCollector from ctx.
|
||||||
|
func LogCollectorFromContext(ctx context.Context) *LogCollector {
|
||||||
|
val := ctx.Value(logContextKey)
|
||||||
|
if val == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return val.(*LogCollector)
|
||||||
|
}
|
||||||
|
|
||||||
// Append appends msg into log context.
|
// Append appends msg into log context.
|
||||||
func (lc *LogCollector) Append(msg string) {
|
func (lc *LogCollector) Append(msg string) {
|
||||||
lc.lock.Lock()
|
lc.lock.Lock()
|
||||||
@ -73,9 +93,9 @@ func Infof(r *http.Request, format string, v ...any) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func appendLog(r *http.Request, message string) {
|
func appendLog(r *http.Request, message string) {
|
||||||
logs := r.Context().Value(LogContext)
|
logs := LogCollectorFromContext(r.Context())
|
||||||
if logs != nil {
|
if logs != nil {
|
||||||
logs.(*LogCollector).Append(message)
|
logs.Append(message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,9 +110,3 @@ func formatf(r *http.Request, format string, v ...any) string {
|
|||||||
func formatWithReq(r *http.Request, v string) string {
|
func formatWithReq(r *http.Request, v string) string {
|
||||||
return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v)
|
return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v)
|
||||||
}
|
}
|
||||||
|
|
||||||
type contextKey string
|
|
||||||
|
|
||||||
func (c contextKey) String() string {
|
|
||||||
return "rest/internal context key " + string(c)
|
|
||||||
}
|
|
||||||
|
@ -14,7 +14,7 @@ import (
|
|||||||
func TestInfo(t *testing.T) {
|
func TestInfo(t *testing.T) {
|
||||||
collector := new(LogCollector)
|
collector := new(LogCollector)
|
||||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||||
req = req.WithContext(context.WithValue(req.Context(), LogContext, collector))
|
req = req.WithContext(WithLogCollector(req.Context(), collector))
|
||||||
Info(req, "first")
|
Info(req, "first")
|
||||||
Infof(req, "second %s", "third")
|
Infof(req, "second %s", "third")
|
||||||
val := collector.Flush()
|
val := collector.Flush()
|
||||||
@ -35,7 +35,10 @@ func TestError(t *testing.T) {
|
|||||||
assert.True(t, strings.Contains(val, "third"))
|
assert.True(t, strings.Contains(val, "third"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextKey_String(t *testing.T) {
|
func TestLogCollectorContext(t *testing.T) {
|
||||||
val := contextKey("foo")
|
ctx := context.Background()
|
||||||
assert.True(t, strings.Contains(val.String(), "foo"))
|
assert.Nil(t, LogCollectorFromContext(ctx))
|
||||||
|
collector := new(LogCollector)
|
||||||
|
ctx = WithLogCollector(ctx, collector)
|
||||||
|
assert.Equal(t, collector, LogCollectorFromContext(ctx))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user