diff --git a/rest/handler/loghandler.go b/rest/handler/loghandler.go index 68552c4c..179a72f9 100644 --- a/rest/handler/loghandler.go +++ b/rest/handler/loghandler.go @@ -3,7 +3,6 @@ package handler import ( "bufio" "bytes" - "context" "errors" "fmt" "io" @@ -44,7 +43,7 @@ func LogHandler(next http.Handler) http.Handler { var dup io.ReadCloser r.Body, dup = iox.DupReadCloser(r.Body) - next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs))) + next.ServeHTTP(&lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs))) r.Body = dup logBrief(r, lrw.Code, timer, logs) }) @@ -102,7 +101,7 @@ func DetailedLogHandler(next http.Handler) http.Handler { var dup io.ReadCloser r.Body, dup = iox.DupReadCloser(r.Body) logs := new(internal.LogCollector) - next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs))) + next.ServeHTTP(lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs))) r.Body = dup logDetails(r, lrw, timer, logs) }) diff --git a/rest/handler/loghandler_test.go b/rest/handler/loghandler_test.go index 2a562576..61156e76 100644 --- a/rest/handler/loghandler_test.go +++ b/rest/handler/loghandler_test.go @@ -22,7 +22,7 @@ func TestLogHandler(t *testing.T) { for _, logHandler := range handlers { req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything") + internal.LogCollectorFromContext(r.Context()).Append("anything") w.Header().Set("X-Test", "test") w.WriteHeader(http.StatusServiceUnavailable) _, err := w.Write([]byte("content")) @@ -49,7 +49,7 @@ func TestLogHandlerVeryLong(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "http://localhost", &buf) handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything") + internal.LogCollectorFromContext(r.Context()).Append("anything") _, _ = io.Copy(io.Discard, r.Body) w.Header().Set("X-Test", "test") w.WriteHeader(http.StatusServiceUnavailable) diff --git a/rest/internal/log.go b/rest/internal/log.go index 0e921afb..59e66194 100644 --- a/rest/internal/log.go +++ b/rest/internal/log.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "fmt" "net/http" "sync" @@ -10,13 +11,32 @@ import ( "github.com/zeromicro/go-zero/rest/httpx" ) -// LogContext is a context key. -var LogContext = contextKey("request_logs") +// logContextKey is a context key. +var logContextKey = contextKey("request_logs") -// A LogCollector is used to collect logs. -type LogCollector struct { - Messages []string - lock sync.Mutex +type ( + // LogCollector is used to collect logs. + LogCollector struct { + Messages []string + lock sync.Mutex + } + + contextKey string +) + +// WithLogCollector returns a new context with LogCollector. +func WithLogCollector(ctx context.Context, lc *LogCollector) context.Context { + return context.WithValue(ctx, logContextKey, lc) +} + +// LogCollectorFromContext returns LogCollector from ctx. +func LogCollectorFromContext(ctx context.Context) *LogCollector { + val := ctx.Value(logContextKey) + if val == nil { + return nil + } + + return val.(*LogCollector) } // Append appends msg into log context. @@ -73,9 +93,9 @@ func Infof(r *http.Request, format string, v ...any) { } func appendLog(r *http.Request, message string) { - logs := r.Context().Value(LogContext) + logs := LogCollectorFromContext(r.Context()) if logs != nil { - logs.(*LogCollector).Append(message) + logs.Append(message) } } @@ -90,9 +110,3 @@ func formatf(r *http.Request, format string, v ...any) string { func formatWithReq(r *http.Request, v string) string { return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v) } - -type contextKey string - -func (c contextKey) String() string { - return "rest/internal context key " + string(c) -} diff --git a/rest/internal/log_test.go b/rest/internal/log_test.go index 79e08af6..7c46493f 100644 --- a/rest/internal/log_test.go +++ b/rest/internal/log_test.go @@ -14,7 +14,7 @@ import ( func TestInfo(t *testing.T) { collector := new(LogCollector) req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) - req = req.WithContext(context.WithValue(req.Context(), LogContext, collector)) + req = req.WithContext(WithLogCollector(req.Context(), collector)) Info(req, "first") Infof(req, "second %s", "third") val := collector.Flush() @@ -35,7 +35,10 @@ func TestError(t *testing.T) { assert.True(t, strings.Contains(val, "third")) } -func TestContextKey_String(t *testing.T) { - val := contextKey("foo") - assert.True(t, strings.Contains(val.String(), "foo")) +func TestLogCollectorContext(t *testing.T) { + ctx := context.Background() + assert.Nil(t, LogCollectorFromContext(ctx)) + collector := new(LogCollector) + ctx = WithLogCollector(ctx, collector) + assert.Equal(t, collector, LogCollectorFromContext(ctx)) }