// Source: go-zero/tools/goctl/model/sql/parser/parser.go
// (394 lines, 9.6 KiB, Go; copied from the repository's web file view)

package parser
import (
"fmt"
2021-03-01 17:29:07 +08:00
"sort"
"strings"
2021-03-01 17:29:07 +08:00
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/converter"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
2021-05-12 12:28:23 +08:00
"github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
2021-03-01 17:29:07 +08:00
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
"github.com/xwb1989/sqlparser"
)
2020-09-03 14:00:09 +08:00
// timeImport is the Go data type string that ContainsTime looks for among a
// table's fields.
const (
	timeImport = "time.Time"
)
type (
// Table describes a mysql table
Table struct {
2021-03-01 17:29:07 +08:00
Name stringx.String
PrimaryKey Primary
UniqueIndex map[string][]*Field
NormalIndex map[string][]*Field
Fields []*Field
}
2020-09-03 14:00:09 +08:00
// Primary describes a primary key
Primary struct {
Field
AutoIncrement bool
}
2020-09-03 14:00:09 +08:00
// Field describes a table field
Field struct {
2021-03-01 17:29:07 +08:00
Name stringx.String
DataBaseType string
DataType string
Comment string
SeqInIndex int
OrdinalPosition int
}
2020-09-03 14:00:09 +08:00
// KeyType types alias of int
KeyType int
)
// Parse parses ddl into golang structure
func Parse(ddl string) (*Table, error) {
stmt, err := sqlparser.ParseStrictDDL(ddl)
if err != nil {
return nil, err
}
2020-08-19 16:10:43 +08:00
ddlStmt, ok := stmt.(*sqlparser.DDL)
if !ok {
2021-02-09 13:50:21 +08:00
return nil, errUnsupportDDL
}
2020-08-19 16:10:43 +08:00
action := ddlStmt.Action
if action != sqlparser.CreateStr {
return nil, fmt.Errorf("expected [CREATE] action,but found: %s", action)
}
2020-08-19 16:10:43 +08:00
tableName := ddlStmt.NewName.Name.String()
tableSpec := ddlStmt.TableSpec
if tableSpec == nil {
2021-02-09 13:50:21 +08:00
return nil, errTableBodyNotFound
}
columns := tableSpec.Columns
indexes := tableSpec.Indexes
2021-03-01 17:29:07 +08:00
primaryColumn, uniqueKeyMap, normalKeyMap, err := convertIndexes(indexes)
2021-02-20 19:50:03 +08:00
if err != nil {
return nil, err
}
2020-08-19 16:10:43 +08:00
2021-03-01 21:14:07 +08:00
primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
2021-02-20 19:50:03 +08:00
if err != nil {
return nil, err
}
2021-03-01 21:14:07 +08:00
var fields []*Field
for _, e := range fieldM {
fields = append(fields, e)
}
2021-03-01 17:29:07 +08:00
var (
uniqueIndex = make(map[string][]*Field)
normalIndex = make(map[string][]*Field)
)
2021-03-01 21:14:07 +08:00
2021-03-01 17:29:07 +08:00
for indexName, each := range uniqueKeyMap {
for _, columnName := range each {
uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
}
}
for indexName, each := range normalKeyMap {
for _, columnName := range each {
normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
}
}
checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
UniqueIndex: uniqueIndex,
NormalIndex: normalIndex,
Fields: fields,
}, nil
}
func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
2021-03-01 21:14:07 +08:00
log := console.NewColorConsole()
uniqueSet := collection.NewSet()
for k, i := range uniqueIndex {
var list []string
for _, e := range i {
list = append(list, e.Name.Source())
}
joinRet := strings.Join(list, ",")
if uniqueSet.Contains(joinRet) {
log.Warning("table %s: duplicate unique index %s", tableName, joinRet)
delete(uniqueIndex, k)
continue
}
uniqueSet.AddStr(joinRet)
}
normalIndexSet := collection.NewSet()
for k, i := range normalIndex {
var list []string
for _, e := range i {
list = append(list, e.Name.Source())
}
joinRet := strings.Join(list, ",")
if normalIndexSet.Contains(joinRet) {
log.Warning("table %s: duplicate index %s", tableName, joinRet)
delete(normalIndex, k)
continue
}
normalIndexSet.Add(joinRet)
}
2021-02-20 19:50:03 +08:00
}
2021-03-01 21:14:07 +08:00
func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
2021-03-01 17:29:07 +08:00
var (
primaryKey Primary
fieldM = make(map[string]*Field)
)
for _, column := range columns {
if column == nil {
continue
}
2021-03-01 17:29:07 +08:00
var comment string
if column.Type.Comment != nil {
comment = string(column.Type.Comment.Val)
}
2021-03-01 17:29:07 +08:00
2021-04-15 19:49:17 +08:00
isDefaultNull := true
if column.Type.NotNull {
isDefaultNull = false
} else {
if column.Type.Default == nil {
isDefaultNull = false
} else if string(column.Type.Default.Val) != "null" {
isDefaultNull = false
}
}
2021-03-01 17:29:07 +08:00
dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull)
if err != nil {
2021-03-01 21:14:07 +08:00
return Primary{}, nil, err
}
2020-08-19 16:10:43 +08:00
var field Field
field.Name = stringx.From(column.Name.String())
field.DataBaseType = column.Type.Type
field.DataType = dataType
2021-05-12 12:28:23 +08:00
field.Comment = util.TrimNewLine(comment)
2021-03-01 17:29:07 +08:00
if field.Name.Source() == primaryColumn {
primaryKey = Primary{
Field: field,
AutoIncrement: bool(column.Type.Autoincrement),
}
}
2021-03-01 17:29:07 +08:00
fieldM[field.Name.Source()] = &field
}
2021-03-01 21:14:07 +08:00
return primaryKey, fieldM, nil
2021-02-20 19:50:03 +08:00
}
2020-08-19 16:10:43 +08:00
2021-03-01 17:29:07 +08:00
func convertIndexes(indexes []*sqlparser.IndexDefinition) (string, map[string][]string, map[string][]string, error) {
var primaryColumn string
uniqueKeyMap := make(map[string][]string)
normalKeyMap := make(map[string][]string)
isCreateTimeOrUpdateTime := func(name string) bool {
camelColumnName := stringx.From(name).ToCamel()
// by default, createTime|updateTime findOne is not used.
return camelColumnName == "CreateTime" || camelColumnName == "UpdateTime"
}
2021-02-20 19:50:03 +08:00
for _, index := range indexes {
info := index.Info
if info == nil {
continue
}
2021-03-01 17:29:07 +08:00
indexName := index.Info.Name.String()
2021-02-20 19:50:03 +08:00
if info.Primary {
if len(index.Columns) > 1 {
2021-03-01 17:29:07 +08:00
return "", nil, nil, errPrimaryKey
}
columnName := index.Columns[0].Column.String()
if isCreateTimeOrUpdateTime(columnName) {
continue
2021-02-20 19:50:03 +08:00
}
2021-03-01 17:29:07 +08:00
primaryColumn = columnName
2021-02-20 19:50:03 +08:00
continue
2021-03-01 17:29:07 +08:00
} else if info.Unique {
for _, each := range index.Columns {
columnName := each.Column.String()
if isCreateTimeOrUpdateTime(columnName) {
break
}
uniqueKeyMap[indexName] = append(uniqueKeyMap[indexName], columnName)
}
2021-02-20 19:50:03 +08:00
} else if info.Spatial {
2021-03-01 17:29:07 +08:00
// do nothing
2021-02-20 19:50:03 +08:00
} else {
2021-03-01 17:29:07 +08:00
for _, each := range index.Columns {
columnName := each.Column.String()
if isCreateTimeOrUpdateTime(columnName) {
break
}
normalKeyMap[indexName] = append(normalKeyMap[indexName], each.Column.String())
}
2021-02-20 19:50:03 +08:00
}
}
2021-03-01 17:29:07 +08:00
return primaryColumn, uniqueKeyMap, normalKeyMap, nil
}
2021-03-01 17:29:07 +08:00
// ContainsTime returns true if contains golang type time.Time
func (t *Table) ContainsTime() bool {
for _, item := range t.Fields {
if item.DataType == timeImport {
return true
}
}
return false
}
2021-03-01 17:29:07 +08:00
// ConvertDataType converts mysql data type into golang data type
func ConvertDataType(table *model.Table) (*Table, error) {
isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
primaryDataType, err := converter.ConvertDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull)
if err != nil {
return nil, err
}
2021-03-01 17:29:07 +08:00
var reply Table
reply.UniqueIndex = map[string][]*Field{}
reply.NormalIndex = map[string][]*Field{}
reply.Name = stringx.From(table.Table)
seqInIndex := 0
if table.PrimaryKey.Index != nil {
seqInIndex = table.PrimaryKey.Index.SeqInIndex
}
2021-03-01 17:29:07 +08:00
reply.PrimaryKey = Primary{
Field: Field{
Name: stringx.From(table.PrimaryKey.Name),
DataBaseType: table.PrimaryKey.DataType,
DataType: primaryDataType,
Comment: table.PrimaryKey.Comment,
SeqInIndex: seqInIndex,
OrdinalPosition: table.PrimaryKey.OrdinalPosition,
},
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
}
fieldM, err := getTableFields(table)
if err != nil {
return nil, err
}
2021-03-01 17:29:07 +08:00
for _, each := range fieldM {
reply.Fields = append(reply.Fields, each)
}
2021-03-01 17:29:07 +08:00
sort.Slice(reply.Fields, func(i, j int) bool {
return reply.Fields[i].OrdinalPosition < reply.Fields[j].OrdinalPosition
})
2021-03-01 17:29:07 +08:00
uniqueIndexSet := collection.NewSet()
log := console.NewColorConsole()
for indexName, each := range table.UniqueIndex {
sort.Slice(each, func(i, j int) bool {
if each[i].Index != nil {
return each[i].Index.SeqInIndex < each[j].Index.SeqInIndex
}
2021-03-01 17:29:07 +08:00
return false
})
if len(each) == 1 {
one := each[0]
if one.Name == table.PrimaryKey.Name {
2021-03-01 21:14:07 +08:00
log.Warning("table %s: duplicate unique index with primary key, %s", table.Table, one.Name)
2021-03-01 17:29:07 +08:00
continue
}
}
2021-03-01 17:29:07 +08:00
var list []*Field
var uniqueJoin []string
for _, c := range each {
list = append(list, fieldM[c.Name])
uniqueJoin = append(uniqueJoin, c.Name)
}
uniqueKey := strings.Join(uniqueJoin, ",")
if uniqueIndexSet.Contains(uniqueKey) {
2021-03-01 21:14:07 +08:00
log.Warning("table %s: duplicate unique index, %s", table.Table, uniqueKey)
2021-03-01 17:29:07 +08:00
continue
}
2021-03-01 21:14:07 +08:00
uniqueIndexSet.AddStr(uniqueKey)
2021-03-01 17:29:07 +08:00
reply.UniqueIndex[indexName] = list
}
2021-03-01 21:14:07 +08:00
normalIndexSet := collection.NewSet()
2021-03-01 17:29:07 +08:00
for indexName, each := range table.NormalIndex {
var list []*Field
2021-03-01 21:14:07 +08:00
var normalJoin []string
2021-03-01 17:29:07 +08:00
for _, c := range each {
list = append(list, fieldM[c.Name])
2021-03-01 21:14:07 +08:00
normalJoin = append(normalJoin, c.Name)
}
normalKey := strings.Join(normalJoin, ",")
if normalIndexSet.Contains(normalKey) {
log.Warning("table %s: duplicate index, %s", table.Table, normalKey)
continue
2021-03-01 17:29:07 +08:00
}
2021-03-01 21:14:07 +08:00
normalIndexSet.AddStr(normalKey)
2021-03-01 17:29:07 +08:00
sort.Slice(list, func(i, j int) bool {
return list[i].SeqInIndex < list[j].SeqInIndex
})
reply.NormalIndex[indexName] = list
}
return &reply, nil
}
// getTableFields builds one Field per entry of table.Columns, keyed by column
// name. A later column with the same name replaces an earlier one. It returns
// the first error reported by converter.ConvertDataType.
func getTableFields(table *model.Table) (map[string]*Field, error) {
	fields := make(map[string]*Field)
	for _, column := range table.Columns {
		// Treat the Go type as nullable when the column is nullable and its
		// reported default is nil.
		nullable := column.IsNullAble == "YES" && column.ColumnDefault == nil
		goType, err := converter.ConvertDataType(column.DataType, nullable)
		if err != nil {
			return nil, err
		}

		var seq int
		if idx := column.Index; idx != nil {
			seq = idx.SeqInIndex
		}

		fields[column.Name] = &Field{
			Name:            stringx.From(column.Name),
			DataBaseType:    column.DataType,
			DataType:        goType,
			Comment:         column.Comment,
			SeqInIndex:      seq,
			OrdinalPosition: column.OrdinalPosition,
		}
	}

	return fields, nil
}