优化hgorm.dao方法

This commit is contained in:
孟帅
2023-06-17 17:51:47 +08:00
parent 9b89402bf6
commit aff0ff3af5
15 changed files with 69 additions and 134 deletions

View File

@@ -8,7 +8,6 @@ package hgorm
// dao.
import (
"context"
"errors"
"fmt"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gerror"
@@ -25,56 +24,41 @@ type daoInstance interface {
// Join 关联表属性
type Join struct {
Dao interface{} // 关联表dao实例
Dao daoInstance // 关联表dao实例
Alias string // 别名
fields map[string]*gdb.TableField // 表字段列表
}
// GenJoinOnRelation 生成关联表关联条件
func GenJoinOnRelation(masterTable, masterField, joinTable, alias, onField string) []string {
return []string{
joinTable,
alias,
fmt.Sprintf("`%s`.`%s` = `%s`.`%s`", alias, onField, masterTable, masterField),
}
relation := fmt.Sprintf("`%s`.`%s` = `%s`.`%s`", alias, onField, masterTable, masterField)
return []string{joinTable, alias, relation}
}
// GenJoinSelect 生成关联表select
// 这里会将实体中的字段驼峰转为下划线于数据库进行匹配,意味着数据库字段必须全部是小写字母+下划线的格式
func GenJoinSelect(ctx context.Context, entity interface{}, masterDao interface{}, joins []*Join) (allFields string, err error) {
func GenJoinSelect(ctx context.Context, entity interface{}, dao daoInstance, joins []*Join) (allFields string, err error) {
var tmpFields []string
md, ok := masterDao.(daoInstance)
if !ok {
err = errors.New("masterDao unimplemented interface format.daoInstance")
return
}
if len(joins) == 0 {
err = errors.New("JoinFields joins len = 0")
err = gerror.New("JoinFields joins len = 0")
return
}
for _, v := range joins {
jd, ok := v.Dao.(daoInstance)
if !ok {
err = errors.New("joins index unimplemented interface format.daoInstance")
return
}
v.fields, err = jd.Ctx(ctx).TableFields(jd.Table())
v.fields, err = v.Dao.Ctx(ctx).TableFields(v.Dao.Table())
if err != nil {
return
}
}
masterFields, err := md.Ctx(ctx).TableFields(md.Table())
masterFields, err := dao.Ctx(ctx).TableFields(dao.Table())
if err != nil {
return
}
entityFields, err := convert.GetEntityFieldTags(entity)
if err != nil {
return "", err
return
}
if len(entityFields) == 0 {
@@ -104,42 +88,7 @@ func GenJoinSelect(ctx context.Context, entity interface{}, masterDao interface{
// 主表
originalField := gstr.CaseSnakeFirstUpper(field)
if _, ok := masterFields[originalField]; ok {
tmpFields = append(tmpFields, fmt.Sprintf("`%s`.`%s`", md.Table(), originalField))
continue
}
}
return gstr.Implode(",", convert.UniqueSlice(tmpFields)), nil
}
// GenSelect 生成select
// 这里会将实体中的字段驼峰转为下划线于数据库进行匹配,意味着数据库字段必须全部是小写字母+下划线的格式
func GenSelect(ctx context.Context, entity interface{}, dao interface{}) (allFields string, err error) {
var tmpFields []string
md, ok := dao.(daoInstance)
if !ok {
err = errors.New("dao unimplemented interface format.daoInstance")
return
}
fields, err := md.Ctx(ctx).TableFields(md.Table())
if err != nil {
return
}
entityFields, err := convert.GetEntityFieldTags(entity)
if err != nil {
return "", err
}
if len(entityFields) == 0 {
return "*", nil
}
for _, field := range entityFields {
originalField := gstr.CaseSnakeFirstUpper(field)
if _, ok := fields[originalField]; ok {
tmpFields = append(tmpFields, fmt.Sprintf("`%s`", originalField))
tmpFields = append(tmpFields, fmt.Sprintf("`%s`.`%s`", dao.Table(), originalField))
continue
}
}
@@ -153,7 +102,7 @@ func GetPkField(ctx context.Context, dao daoInstance) (string, error) {
return "", err
}
if len(fields) == 0 {
return "", errors.New("field not found")
return "", gerror.New("field not found")
}
for _, field := range fields {
@@ -161,23 +110,18 @@ func GetPkField(ctx context.Context, dao daoInstance) (string, error) {
return field.Name, nil
}
}
return "", errors.New("no primary key")
return "", gerror.New("no primary key")
}
// IsUnique 是否唯一
func IsUnique(ctx context.Context, dao interface{}, where g.Map, message string, pkId ...interface{}) error {
d, ok := dao.(daoInstance)
if !ok {
return errors.New("IsUnique dao unimplemented interface format.daoInstance")
}
func IsUnique(ctx context.Context, dao daoInstance, where g.Map, message string, pkId ...interface{}) error {
if len(where) == 0 {
return errors.New("where condition cannot be empty")
return gerror.New("where condition cannot be empty")
}
m := d.Ctx(ctx).Where(where)
m := dao.Ctx(ctx).Where(where)
if len(pkId) > 0 {
field, err := GetPkField(ctx, d)
field, err := GetPkField(ctx, dao)
if err != nil {
return err
}
@@ -192,32 +136,28 @@ func IsUnique(ctx context.Context, dao interface{}, where g.Map, message string,
if count > 0 {
if message == "" {
for k := range where {
message = fmt.Sprintf("in the table%s, %v not uniqued", d.Table(), where[k])
message = fmt.Sprintf("in the table%s, %v not uniqued", dao.Table(), where[k])
break
}
}
return errors.New(message)
return gerror.New(message)
}
return nil
}
// GenSubTree 生成下级关系树
func GenSubTree(ctx context.Context, dao interface{}, oldPid int64) (newPid int64, newLevel int, subTree string, err error) {
func GenSubTree(ctx context.Context, dao daoInstance, oldPid int64) (newPid int64, newLevel int, subTree string, err error) {
// 顶级树
if oldPid <= 0 {
return 0, 1, "", nil
}
d, ok := dao.(daoInstance)
if !ok {
return 0, 0, "", errors.New("GenTree dao unimplemented interface format.daoInstance")
}
field, err := GetPkField(ctx, d)
field, err := GetPkField(ctx, dao)
if err != nil {
return 0, 0, "", err
}
models, err := d.Ctx(ctx).Where(field, oldPid).One()
models, err := dao.Ctx(ctx).Where(field, oldPid).One()
if err != nil {
return 0, 0, "", err
}