From 962b36d7452d83928de31182f70cba88860c2092 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sun, 26 May 2024 12:52:05 +0800 Subject: [PATCH] fix: log concurrency problems after calling WithXXX methods (#4164) --- core/logx/richlogger.go | 36 ++++++++++++++++++------ core/logx/richlogger_test.go | 48 ++++++++++++++++++++++++++++++++ core/mapping/unmarshaler_test.go | 2 ++ 3 files changed, 78 insertions(+), 8 deletions(-) diff --git a/core/logx/richlogger.go b/core/logx/richlogger.go index 185a9699..482852e4 100644 --- a/core/logx/richlogger.go +++ b/core/logx/richlogger.go @@ -141,23 +141,43 @@ func (l *richLogger) WithCallerSkip(skip int) Logger { return l } - l.callerSkip = skip - return l + return &richLogger{ + ctx: l.ctx, + callerSkip: skip, + fields: l.fields, + } } func (l *richLogger) WithContext(ctx context.Context) Logger { - l.ctx = ctx - return l + return &richLogger{ + ctx: ctx, + callerSkip: l.callerSkip, + fields: l.fields, + } } func (l *richLogger) WithDuration(duration time.Duration) Logger { - l.fields = append(l.fields, Field(durationKey, timex.ReprOfDuration(duration))) - return l + fields := append(l.fields, Field(durationKey, timex.ReprOfDuration(duration))) + + return &richLogger{ + ctx: l.ctx, + callerSkip: l.callerSkip, + fields: fields, + } } func (l *richLogger) WithFields(fields ...LogField) Logger { - l.fields = append(l.fields, fields...) - return l + if len(fields) == 0 { + return l + } + + f := append(l.fields, fields...) + + return &richLogger{ + ctx: l.ctx, + callerSkip: l.callerSkip, + fields: f, + } } func (l *richLogger) buildFields(fields ...LogField) []LogField { diff --git a/core/logx/richlogger_test.go b/core/logx/richlogger_test.go index 11f7dfcb..52194d43 100644 --- a/core/logx/richlogger_test.go +++ b/core/logx/richlogger_test.go @@ -287,6 +287,54 @@ func TestLogWithCallerSkip(t *testing.T) { assert.True(t, w.Contains(fmt.Sprintf("%s:%d", file, line+1))) } +func TestLogWithCallerSkipCopy(t *testing.T) { + log1 := WithCallerSkip(2) + log2 := log1.WithCallerSkip(3) + log3 := log2.WithCallerSkip(-1) + assert.Equal(t, 2, log1.(*richLogger).callerSkip) + assert.Equal(t, 3, log2.(*richLogger).callerSkip) + assert.Equal(t, 3, log3.(*richLogger).callerSkip) +} + +func TestLogWithContextCopy(t *testing.T) { + c1 := context.Background() + c2 := context.WithValue(context.Background(), "foo", "bar") + log1 := WithContext(c1) + log2 := log1.WithContext(c2) + assert.Equal(t, c1, log1.(*richLogger).ctx) + assert.Equal(t, c2, log2.(*richLogger).ctx) +} + +func TestLogWithDurationCopy(t *testing.T) { + log1 := WithContext(context.Background()) + log2 := log1.WithDuration(time.Second) + assert.Empty(t, log1.(*richLogger).fields) + assert.Equal(t, 1, len(log2.(*richLogger).fields)) + + var w mockWriter + old := writer.Swap(&w) + defer writer.Store(old) + log2.Info("hello") + assert.Contains(t, w.String(), `"duration":"1000.0ms"`) +} + +func TestLogWithFieldsCopy(t *testing.T) { + log1 := WithContext(context.Background()) + log2 := log1.WithFields(Field("foo", "bar")) + log3 := log1.WithFields() + assert.Empty(t, log1.(*richLogger).fields) + assert.Equal(t, 1, len(log2.(*richLogger).fields)) + assert.Equal(t, log1, log3) + assert.Empty(t, log3.(*richLogger).fields) + + var w mockWriter + old := writer.Swap(&w) + defer writer.Store(old) + + log2.Info("hello") + assert.Contains(t, w.String(), `"foo":"bar"`) +} + func TestLoggerWithFields(t *testing.T) { w := new(mockWriter) old := writer.Swap(w) diff --git a/core/mapping/unmarshaler_test.go b/core/mapping/unmarshaler_test.go index 2e16df83..229d9c41 100644 --- a/core/mapping/unmarshaler_test.go +++ b/core/mapping/unmarshaler_test.go @@ -261,6 +261,7 @@ func TestUnmarshalInt(t *testing.T) { Int64FromStr int64 `key:"int64str,string"` DefaultInt int64 `key:"defaultint,default=11"` Optional int `key:"optional,optional"` + IntOptDef int `key:"intopt,optional,default=6"` } m := map[string]any{ "int": 1, @@ -289,6 +290,7 @@ func TestUnmarshalInt(t *testing.T) { ast.Equal(int64(9), in.Int64) ast.Equal(int64(10), in.Int64FromStr) ast.Equal(int64(11), in.DefaultInt) + ast.Equal(6, in.IntOptDef) } }