diff --git a/core/stores/mon/options.go b/core/stores/mon/options.go index 4097328f..14a6666a 100644 --- a/core/stores/mon/options.go +++ b/core/stores/mon/options.go @@ -1,9 +1,12 @@ package mon import ( + "reflect" "time" "github.com/zeromicro/go-zero/core/syncx" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsoncodec" mopt "go.mongodb.org/mongo-driver/mongo/options" ) @@ -20,6 +23,13 @@ type ( // Option defines the method to customize a mongo model. Option func(opts *options) + + // RegisterType A struct store With custom type and Encoder/Decoder + RegisterType struct { + ValueType reflect.Type + Encoder bsoncodec.ValueEncoder + Decoder bsoncodec.ValueDecoder + } ) // DisableLog disables logging of mongo commands, includes info and slow logs. @@ -50,3 +60,15 @@ func WithTimeout(timeout time.Duration) Option { opts.SetTimeout(timeout) } } + +// WithRegistry set the Registry to convert custom type to mongo primitive type more easily. +func WithRegistry(registerType ...RegisterType) Option { + return func(opts *options) { + registry := bson.NewRegistry() + for _, v := range registerType { + registry.RegisterTypeEncoder(v.ValueType, v.Encoder) + registry.RegisterTypeDecoder(v.ValueType, v.Decoder) + } + opts.SetRegistry(registry) + } +} diff --git a/core/stores/mon/options_test.go b/core/stores/mon/options_test.go index ee944f94..f5a858ad 100644 --- a/core/stores/mon/options_test.go +++ b/core/stores/mon/options_test.go @@ -1,10 +1,14 @@ package mon import ( + "fmt" + "reflect" "testing" "time" "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/bsonrw" mopt "go.mongodb.org/mongo-driver/mongo/options" ) @@ -51,3 +55,56 @@ func TestDisableInfoLog(t *testing.T) { assert.False(t, logMon.True()) assert.True(t, logSlowMon.True()) } + +func TestWithRegistryForTimestampRegisterType(t *testing.T) { + opts := mopt.Client() + + // mongoDateTimeEncoder allow user convert time.Time to primitive.DateTime. + var mongoDateTimeEncoder bsoncodec.ValueEncoderFunc = func(ect bsoncodec.EncodeContext, w bsonrw.ValueWriter, value reflect.Value) error { + // Use reflect, determine if it can be converted to time.Time. + dec, ok := value.Interface().(time.Time) + if !ok { + return fmt.Errorf("value %v to encode is not of type time.Time", value) + } + return w.WriteDateTime(dec.Unix()) + } + + // mongoDateTimeEncoder allow user convert primitive.DateTime to time.Time. + var mongoDateTimeDecoder bsoncodec.ValueDecoderFunc = func(ect bsoncodec.DecodeContext, r bsonrw.ValueReader, value reflect.Value) error { + primTime, err := r.ReadDateTime() + if err != nil { + return fmt.Errorf("error reading primitive.DateTime from ValueReader: %v", err) + } + value.Set(reflect.ValueOf(time.Unix(primTime, 0))) + return nil + } + + registerType := []RegisterType{ + { + ValueType: reflect.TypeOf(time.Time{}), + Encoder: mongoDateTimeEncoder, + Decoder: mongoDateTimeDecoder, + }, + } + WithRegistry(registerType...)(opts) + + for _, v := range registerType { + // Validate Encoder + enc, err := opts.Registry.LookupEncoder(v.ValueType) + if err != nil { + t.Fatal(err) + } + if assert.ObjectsAreEqual(v.Encoder, enc) { + t.Errorf("Encoder got from Registry: %v, but want: %v", enc, v.Encoder) + } + + // Validate Decoder + dec, err := opts.Registry.LookupDecoder(v.ValueType) + if err != nil { + t.Fatal(err) + } + if assert.ObjectsAreEqual(v.Decoder, dec) { + t.Errorf("Decoder got from Registry: %v, but want: %v", dec, v.Decoder) + } + } +}