feat: provide a way to create Type Provider

This commit is contained in:
dapeng 2024-11-23 12:41:38 +08:00
parent 9bc57473d7
commit d3b63932d6
5 changed files with 229 additions and 132 deletions

View File

@ -7,6 +7,5 @@ import (
// Priest gorm的priest
func Priest(cemetery gone.Cemetery) error {
cemetery.Bury(NewLogger())
cemetery.Bury(NewGorm())
return nil
return ProviderPriest(cemetery)
}

View File

@ -4,119 +4,94 @@ import (
"github.com/gone-io/gone"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"reflect"
"time"
)
func NewGorm() (gone.Goner, gone.GonerOption, gone.GonerOption) {
return &iGorm{}, gone.GonerId("gorm"), gone.Provide(&gorm.DB{})
}
type iGorm struct {
gone.Flag
db *gorm.DB
dial gorm.Dialector `gone:"*"`
logger logger.Interface `gone:"*"`
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool `gone:"config,gorm.skip-default-transaction"`
// FullSaveAssociations full save associations
FullSaveAssociations bool `gone:"config,gorm.full-save-associations"`
// DryRun generate sql without execute
DryRun bool `gone:"config,dry-run"`
// PrepareStmt executes the given query in cached statement
PrepareStmt bool `gone:"config,gorm.prepare-stmt"`
// DisableAutomaticPing
DisableAutomaticPing bool `gone:"config,gorm.disable-automatic-ping"`
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool `gone:"config,gorm.disable-foreign-key-constraint-when-migrating"`
// IgnoreRelationshipsWhenMigrating
IgnoreRelationshipsWhenMigrating bool `gone:"config,gorm.ignore-relationships-when-migrating"`
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool `gone:"config,gorm.disable-nested-transaction"`
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool `gone:"config,gorm.allow-global-update"`
// QueryFields executes the SQL query with all fields of the table
QueryFields bool `gone:"config,gorm.query-fields"`
// CreateBatchSize default create batch size
CreateBatchSize int `gone:"config,gorm.create-batch-size"`
// TranslateError enabling error translation
TranslateError bool `gone:"config,gorm.translate-error"`
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool `gone:"config,gorm.propagate-unscoped"`
MaxIdle int `gone:"config,gorm.pool.max-idle"`
MaxOpen int `gone:"config,gorm.pool.max-open"`
ConnMaxLifetime *time.Duration `gone:"config,gorm.pool.conn-max-lifetime"`
}
func (s *iGorm) Start(gone.Cemetery) (err error) {
s.db, err = gorm.Open(s.dial, &gorm.Config{
SkipDefaultTransaction: s.SkipDefaultTransaction,
FullSaveAssociations: s.FullSaveAssociations,
Logger: s.logger,
DryRun: s.DryRun,
PrepareStmt: s.PrepareStmt,
DisableAutomaticPing: s.DisableAutomaticPing,
DisableForeignKeyConstraintWhenMigrating: s.DisableForeignKeyConstraintWhenMigrating,
IgnoreRelationshipsWhenMigrating: s.IgnoreRelationshipsWhenMigrating,
DisableNestedTransaction: s.DisableNestedTransaction,
AllowGlobalUpdate: s.AllowGlobalUpdate,
QueryFields: s.QueryFields,
CreateBatchSize: s.CreateBatchSize,
TranslateError: s.TranslateError,
PropagateUnscoped: s.PropagateUnscoped,
})
if err != nil {
return gone.ToError(err)
}
db, err := s.db.DB()
if err != nil {
return gone.ToError(err)
}
if s.MaxIdle > 0 {
db.SetMaxIdleConns(s.MaxIdle)
}
if s.MaxOpen > 0 {
db.SetMaxOpenConns(s.MaxOpen)
}
if s.ConnMaxLifetime != nil {
db.SetConnMaxLifetime(*s.ConnMaxLifetime)
}
return gone.ToError(db.Ping())
}
func (s *iGorm) Stop(gone.Cemetery) error {
db, err := s.db.DB()
if err != nil {
return gone.ToError(err)
}
return gone.ToError(db.Close())
}
var dbType = reflect.TypeOf(new(gorm.DB))
func (s *iGorm) Suck(conf string, v reflect.Value) gone.SuckError {
if v.Type() == dbType {
v.Set(reflect.ValueOf(s.db))
return nil
} else {
return gone.NewInnerError("only support *gorm.DB", gone.InjectError)
}
func ProviderPriest(cemetery gone.Cemetery) error {
return gone.NewProviderPriest(func(tagConf string, s struct {
dial gorm.Dialector `gone:"*"`
logger logger.Interface `gone:"*"`
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool `gone:"config,gorm.skip-default-transaction"`
// FullSaveAssociations full save associations
FullSaveAssociations bool `gone:"config,gorm.full-save-associations"`
// DryRun generate sql without execute
DryRun bool `gone:"config,dry-run"`
// PrepareStmt executes the given query in cached statement
PrepareStmt bool `gone:"config,gorm.prepare-stmt"`
// DisableAutomaticPing
DisableAutomaticPing bool `gone:"config,gorm.disable-automatic-ping"`
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool `gone:"config,gorm.disable-foreign-key-constraint-when-migrating"`
// IgnoreRelationshipsWhenMigrating
IgnoreRelationshipsWhenMigrating bool `gone:"config,gorm.ignore-relationships-when-migrating"`
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool `gone:"config,gorm.disable-nested-transaction"`
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool `gone:"config,gorm.allow-global-update"`
// QueryFields executes the SQL query with all fields of the table
QueryFields bool `gone:"config,gorm.query-fields"`
// CreateBatchSize default create batch size
CreateBatchSize int `gone:"config,gorm.create-batch-size"`
// TranslateError enabling error translation
TranslateError bool `gone:"config,gorm.translate-error"`
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool `gone:"config,gorm.propagate-unscoped"`
MaxIdle int `gone:"config,gorm.pool.max-idle"`
MaxOpen int `gone:"config,gorm.pool.max-open"`
ConnMaxLifetime *time.Duration `gone:"config,gorm.pool.conn-max-lifetime"`
}) (*gorm.DB, error) {
g, err := gorm.Open(s.dial, &gorm.Config{
SkipDefaultTransaction: s.SkipDefaultTransaction,
FullSaveAssociations: s.FullSaveAssociations,
Logger: s.logger,
DryRun: s.DryRun,
PrepareStmt: s.PrepareStmt,
DisableAutomaticPing: s.DisableAutomaticPing,
DisableForeignKeyConstraintWhenMigrating: s.DisableForeignKeyConstraintWhenMigrating,
IgnoreRelationshipsWhenMigrating: s.IgnoreRelationshipsWhenMigrating,
DisableNestedTransaction: s.DisableNestedTransaction,
AllowGlobalUpdate: s.AllowGlobalUpdate,
QueryFields: s.QueryFields,
CreateBatchSize: s.CreateBatchSize,
TranslateError: s.TranslateError,
PropagateUnscoped: s.PropagateUnscoped,
})
if err != nil {
return nil, gone.ToError(err)
}
db, err := g.DB()
if err != nil {
return nil, gone.ToError(err)
}
if s.MaxIdle > 0 {
db.SetMaxIdleConns(s.MaxIdle)
}
if s.MaxOpen > 0 {
db.SetMaxOpenConns(s.MaxOpen)
}
if s.ConnMaxLifetime != nil {
db.SetConnMaxLifetime(*s.ConnMaxLifetime)
}
return g, nil
})(cemetery)
}

View File

@ -5,7 +5,6 @@ import (
"github.com/gone-io/gone"
"reflect"
"strconv"
"strings"
"xorm.io/xorm"
)
@ -40,24 +39,8 @@ type provider struct {
var xormInterface = gone.GetInterfaceType(new(gone.XormEngine))
var xormInterfaceSlice = gone.GetInterfaceType(new([]gone.XormEngine))
func confMap(conf string) map[string]string {
conf = strings.TrimSpace(conf)
specs := strings.Split(conf, ",")
m := make(map[string]string)
for _, spec := range specs {
spec = strings.TrimSpace(spec)
pairs := strings.Split(spec, "=")
if len(pairs) == 1 {
m[pairs[0]] = ""
} else if len(pairs) > 1 {
m[pairs[0]] = pairs[1]
}
}
return m
}
func (p *provider) Suck(conf string, v reflect.Value) gone.SuckError {
m := confMap(conf)
m := gone.TagStringParse(conf)
clusterName := m[clusterKey]
if clusterName == "" {
clusterName = defaultCluster

77
help.go
View File

@ -280,3 +280,80 @@ example:
func (p *Preparer) Test(fn any) {
p.testKit().AfterStart(fn).Run()
}
type provider[T any] struct {
Flag
cemetery Cemetery `gone:"*"`
create func(tagConf string) (T, error)
}
func (p *provider[T]) Suck(conf string, v reflect.Value, field reflect.StructField) error {
obj, err := p.create(conf)
if err != nil {
return ToError(err)
}
v.Set(reflect.ValueOf(obj))
return nil
}
// TagStringParse parse tag string to map
// example: "a=1,b=2" -> map[string]string{"a":"1","b":"2"}
func TagStringParse(conf string) map[string]string {
conf = strings.TrimSpace(conf)
specs := strings.Split(conf, ",")
m := make(map[string]string)
for _, spec := range specs {
spec = strings.TrimSpace(spec)
pairs := strings.Split(spec, "=")
if len(pairs) == 1 {
m[pairs[0]] = ""
} else if len(pairs) > 1 {
m[pairs[0]] = pairs[1]
}
}
return m
}
// NewProviderPriest create a provider priest function for goner from a function like: `func(tagConf string, injectableStructParam struct{}) (provideType T, err error)`
// example:
// ```go
// type MyGoner struct {}
//
// func NewMyGoner(tagConf string, param struct{
// depGoner1 MyGoner1 `gone:"*"` // inject dep
// depGoner2 MyGoner2 `gone:"*"` // inject dep
// configStr string `gone:"config,my.config.str"` // inject config from config file
// }) (MyGoner, error) {
//
// // do something
// return MyGoner{}, nil
// }
//
// var priest = NewProviderPriest(NewMyGoner)
// ```
func NewProviderPriest[T any, P any](fn func(tagConf string, param P) (T, error)) Priest {
p := provider[T]{}
p.create = func(tagConf string) (T, error) {
args, err := p.cemetery.InjectFuncParameters(fn, func(pt reflect.Type, i int) any {
if i == 0 {
return tagConf
}
return nil
}, nil)
if err != nil {
return *new(T), err
}
results := reflect.ValueOf(fn).Call(args)
if results[1].IsNil() {
return results[0].Interface().(T), nil
}
return *new(T), ToError(results[1].Interface())
}
return func(cemetery Cemetery) error {
cemetery.Bury(&p, Provide(*new(T)))
return nil
}
}

View File

@ -39,3 +39,66 @@ func TestTestInjectByProvider(t *testing.T) {
return nil
})
}
type testBird struct {
Name string
}
func (b *testBird) Fly() {
println(b.Name + " flying")
}
type testCat struct {
Flag
Name string
}
func (*testCat) Meow() {
println("meow")
}
func TestNewProvider(t *testing.T) {
t.Run("provide struct", func(t *testing.T) {
RunTest(func(p struct {
blackBird testBird `gone:"*,black"`
grayBird testBird `gone:"*,gray"`
}) {
assert.Equal(t, p.blackBird.Name, "black")
assert.Equal(t, p.grayBird.Name, "gray")
}, func(cemetery Cemetery) error {
cemetery.Bury(&testCat{
Name: "cat",
})
priest := NewProviderPriest(func(tagConf string, in struct {
cat *testCat `gone:"*"`
}) (testBird, error) {
assert.Equal(t, in.cat.Name, "cat")
return testBird{Name: tagConf}, nil
})
return priest(cemetery)
})
})
t.Run("provide struct pointer", func(t *testing.T) {
RunTest(func(p struct {
blackBird *testBird `gone:"*,black"`
grayBird *testBird `gone:"*,gray"`
}) {
assert.Equal(t, p.blackBird.Name, "black")
assert.Equal(t, p.grayBird.Name, "gray")
}, func(cemetery Cemetery) error {
cemetery.Bury(&testCat{
Name: "cat",
})
priest := NewProviderPriest(func(tagConf string, in struct {
cat *testCat `gone:"*"`
}) (*testBird, error) {
assert.Equal(t, in.cat.Name, "cat")
return &testBird{Name: tagConf}, nil
})
return priest(cemetery)
})
})
}