(goctl)feature/model config (#4062)

Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
This commit is contained in:
kesonan 2024-04-10 23:01:59 +08:00 committed by GitHub
parent 682460c1c8
commit 2a7ada993b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 574 additions and 36 deletions

View File

@ -13,6 +13,7 @@ import (
"github.com/withfig/autocomplete-tools/integrations/cobra"
"github.com/zeromicro/go-zero/tools/goctl/api"
"github.com/zeromicro/go-zero/tools/goctl/bug"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/docker"
"github.com/zeromicro/go-zero/tools/goctl/env"
"github.com/zeromicro/go-zero/tools/goctl/gateway"
@ -113,7 +114,7 @@ func init() {
rootCmd.SetUsageTemplate(usageTpl)
rootCmd.AddCommand(api.Cmd, bug.Cmd, docker.Cmd, kube.Cmd, env.Cmd, gateway.Cmd, model.Cmd)
rootCmd.AddCommand(migrate.Cmd, quickstart.Cmd, rpc.Cmd, tpl.Cmd, upgrade.Cmd)
rootCmd.AddCommand(migrate.Cmd, quickstart.Cmd, rpc.Cmd, tpl.Cmd, upgrade.Cmd, config.Cmd)
rootCmd.Command.AddCommand(cobracompletefig.CreateCompletionSpecCommand())
rootCmd.MustInit()
}

59
tools/goctl/config/cmd.go Normal file
View File

@ -0,0 +1,59 @@
package config
import (
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/zeromicro/go-zero/tools/goctl/internal/cobrax"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
var (
// Cmd describes a bug command.
Cmd = cobrax.NewCommand("config")
initCmd = cobrax.NewCommand("init", cobrax.WithRunE(runConfigInit))
cleanCmd = cobrax.NewCommand("clean", cobrax.WithRunE(runConfigClean))
)
func init() {
Cmd.AddCommand(initCmd, cleanCmd)
}
func runConfigInit(*cobra.Command, []string) error {
wd, err := os.Getwd()
if err != nil {
return err
}
cfgFile, err := getConfigPath(wd)
if err != nil {
return err
}
if pathx.FileExists(cfgFile) {
fmt.Printf("%s already exists, path: %s\n", configFile, cfgFile)
return nil
}
err = os.WriteFile(cfgFile, defaultConfig, 0644)
if err != nil {
return err
}
fmt.Printf("%s generated in %s\n", configFile, cfgFile)
return nil
}
func runConfigClean(*cobra.Command, []string) error {
wd, err := os.Getwd()
if err != nil {
return err
}
cfgFile, err := getConfigPath(wd)
if err != nil {
return err
}
return pathx.RemoveIfExist(cfgFile)
}

View File

@ -1,25 +1,71 @@
package config
import (
_ "embed"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/zeromicro/go-zero/tools/goctl/util/ctx"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
"gopkg.in/yaml.v2"
)
// DefaultFormat defines a default naming style
const DefaultFormat = "gozero"
const (
// DefaultFormat defines a default naming style
DefaultFormat = "gozero"
configFile = "goctl.yaml"
)
var (
//go:embed default.yaml
defaultConfig []byte
ExternalConfig *External
)
// Config defines the file naming style
type Config struct {
// NamingFormat is used to define the naming format of the generated file name.
// just like time formatting, you can specify the formatting style through the
// two format characters go, and zero. for example: snake format you can
// define as go_zero, camel case format you can it is defined as goZero,
// and even split characters can be specified, such as go#zero. in theory,
// any combination can be used, but the prerequisite must meet the naming conventions
// of each operating system file name.
// Note: NamingFormat is based on snake or camel string
NamingFormat string `yaml:"namingFormat"`
}
type (
Config struct {
// NamingFormat is used to define the naming format of the generated file name.
// just like time formatting, you can specify the formatting style through the
// two format characters go, and zero. for example: snake format you can
// define as go_zero, camel case format you can it is defined as goZero,
// and even split characters can be specified, such as go#zero. in theory,
// any combination can be used, but the prerequisite must meet the naming conventions
// of each operating system file name.
// Note: NamingFormat is based on snake or camel string
NamingFormat string `yaml:"namingFormat"`
}
External struct {
// Model is the configuration for the model code generation.
Model Model `yaml:"model,omitempty"`
}
// Model defines the configuration for the model code generation.
Model struct {
// TypesMap: custom Data Type Mapping Table.
TypesMap map[string]ModelTypeMapOption `yaml:"types_map,omitempty" `
}
// ModelTypeMapOption custom Type Options.
ModelTypeMapOption struct {
// Type: valid when not using UnsignedType and NullType.
Type string `yaml:"type"`
// UnsignedType: valid when not using NullType.
UnsignedType string `yaml:"unsigned_type,omitempty"`
// NullType: priority use.
NullType string `yaml:"null_type,omitempty"`
// Pkg defines the package of the custom type.
Pkg string `yaml:"pkg,omitempty"`
}
)
// NewConfig creates an instance for Config
func NewConfig(format string) (*Config, error) {
@ -31,6 +77,58 @@ func NewConfig(format string) (*Config, error) {
return cfg, err
}
func init() {
var cfg External
err := loadConfig(&cfg)
if err != nil {
fmt.Println(err.Error())
} else {
ExternalConfig = &cfg
}
}
func loadConfig(cfg *External) error {
wd, err := os.Getwd()
if err != nil {
return err
}
cfgFile, err := getConfigPath(wd)
if err != nil {
return err
}
var content []byte
if pathx.FileExists(cfgFile) {
content, err = os.ReadFile(cfgFile)
if err != nil {
return err
}
}
if len(content) == 0 {
content = append(content, defaultConfig...)
}
return yaml.Unmarshal(content, cfg)
}
// getConfigPath returns the configuration file path, but not create the file.
func getConfigPath(workDir string) (string, error) {
abs, err := filepath.Abs(workDir)
if err != nil {
return "", err
}
err = pathx.MkdirIfNotExist(abs)
if err != nil {
return "", err
}
projectCtx, err := ctx.Prepare(abs)
if err != nil {
return "", err
}
return filepath.Join(projectCtx.Dir, configFile), nil
}
func validate(cfg *Config) error {
if len(strings.TrimSpace(cfg.NamingFormat)) == 0 {
return errors.New("missing namingFormat")

View File

@ -0,0 +1,157 @@
model:
types_map:
bigint:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
dec:
null_type: sql.NullFloat64
type: float64
decimal:
null_type: sql.NullFloat64
type: float64
double:
null_type: sql.NullFloat64
type: float64
float:
null_type: sql.NullFloat64
type: float64
float4:
null_type: sql.NullFloat64
type: float64
float8:
null_type: sql.NullFloat64
type: float64
int:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
int1:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
int2:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
int3:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
int4:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
int8:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
integer:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
mediumint:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
middleint:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
smallint:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
tinyint:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
date:
null_type: sql.NullTime
type: time.Time
datetime:
null_type: sql.NullTime
type: time.Time
timestamp:
null_type: sql.NullTime
type: time.Time
time:
null_type: sql.NullString
type: string
year:
null_type: sql.NullInt64
type: int64
unsigned_type: uint64
bit:
null_type: sql.NullByte
type: byte
unsigned_type: byte
bool:
null_type: sql.NullBool
type: bool
boolean:
null_type: sql.NullBool
type: bool
char:
null_type: sql.NullString
type: string
varchar:
null_type: sql.NullString
type: string
nvarchar:
null_type: sql.NullString
type: string
nchar:
null_type: sql.NullString
type: string
character:
null_type: sql.NullString
type: string
longvarchar:
null_type: sql.NullString
type: string
linestring:
null_type: sql.NullString
type: string
multilinestring:
null_type: sql.NullString
type: string
binary:
null_type: sql.NullString
type: string
varbinary:
null_type: sql.NullString
type: string
tinytext:
null_type: sql.NullString
type: string
text:
null_type: sql.NullString
type: string
mediumtext:
null_type: sql.NullString
type: string
longtext:
null_type: sql.NullString
type: string
enum:
null_type: sql.NullString
type: string
set:
null_type: sql.NullString
type: string
json:
null_type: sql.NullString
type: string
blob:
null_type: sql.NullString
type: string
longblob:
null_type: sql.NullString
type: string
mediumblob:
null_type: sql.NullString
type: string
tinyblob:
null_type: sql.NullString
type: string

View File

@ -274,6 +274,14 @@
},
"upgrade": {
"short": "Upgrade goctl to latest version"
},
"config": {
"init": {
"short": "Initialize goctl config file"
},
"clean": {
"short": "Clean goctl config file"
}
}
},
"global": {

View File

@ -5,6 +5,8 @@ import (
"strings"
"github.com/zeromicro/ddl-parser/parser"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/pkg/env"
)
var unsignedTypeMap = map[string]string{
@ -73,6 +75,63 @@ var commonMysqlDataTypeMapInt = map[int]string{
parser.Boolean: "bool",
}
var commonMysqlDataTypeMap = map[int]string{
// number
parser.Bit: "bit",
parser.TinyInt: "tinyint",
parser.SmallInt: "smallint",
parser.MediumInt: "mediumint",
parser.Int: "int",
parser.MiddleInt: "middleint",
parser.Int1: "int1",
parser.Int2: "int2",
parser.Int3: "int3",
parser.Int4: "int4",
parser.Int8: "int8",
parser.Integer: "integer",
parser.BigInt: "bigint",
parser.Float: "float",
parser.Float4: "float4",
parser.Float8: "float8",
parser.Double: "double",
parser.Decimal: "decimal",
parser.Dec: "dec",
parser.Fixed: "fixed",
parser.Numeric: "numeric",
parser.Real: "real",
// date&time
parser.Date: "date",
parser.DateTime: "datetime",
parser.Timestamp: "timestamp",
parser.Time: "time",
parser.Year: "year",
// string
parser.Char: "char",
parser.VarChar: "varchar",
parser.NVarChar: "nvarchar",
parser.NChar: "nchar",
parser.Character: "character",
parser.LongVarChar: "longvarchar",
parser.LineString: "linestring",
parser.MultiLineString: "multilinestring",
parser.Binary: "binary",
parser.VarBinary: "varbinary",
parser.TinyText: "tinytext",
parser.Text: "text",
parser.MediumText: "mediumtext",
parser.LongText: "longtext",
parser.Enum: "enum",
parser.Set: "set",
parser.Json: "json",
parser.Blob: "blob",
parser.LongBlob: "longblob",
parser.MediumBlob: "mediumblob",
parser.TinyBlob: "tinyblob",
// bool
parser.Bool: "bool",
parser.Boolean: "boolean",
}
var commonMysqlDataTypeMapString = map[string]string{
// For consistency, all integer types are converted to int64
// bool
@ -144,28 +203,79 @@ var commonMysqlDataTypeMapString = map[string]string{
}
// ConvertDataType converts mysql column type into golang type
func ConvertDataType(dataBaseType int, isDefaultNull, unsigned, strict bool) (string, error) {
tp, ok := commonMysqlDataTypeMapInt[dataBaseType]
if !ok {
return "", fmt.Errorf("unsupported database type: %v", dataBaseType)
func ConvertDataType(dataBaseType int, isDefaultNull, unsigned, strict bool) (string, string, error) {
if env.UseExperimental() {
tp, ok := commonMysqlDataTypeMap[dataBaseType]
if !ok {
return "", "", fmt.Errorf("unsupported database type: %v", dataBaseType)
}
goType, thirdPkg, _, err := ConvertStringDataType(tp, isDefaultNull, unsigned, strict)
return goType, thirdPkg, err
}
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), nil
// the following are the old version compatibility code.
tp, ok := commonMysqlDataTypeMapInt[dataBaseType]
if !ok {
return "", "", fmt.Errorf("unsupported database type: %v", dataBaseType)
}
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), "", nil
}
// ConvertStringDataType converts mysql column type into golang type
func ConvertStringDataType(dataBaseType string, isDefaultNull, unsigned, strict bool) (
goType string, isPQArray bool, err error) {
goType string, thirdPkg string, isPQArray bool, err error) {
if env.UseExperimental() {
customTp, thirdImport := convertDatatypeWithConfig(dataBaseType, isDefaultNull, unsigned)
if len(customTp) != 0 {
return customTp, thirdImport, false, nil
}
tp, ok := commonMysqlDataTypeMapString[strings.ToLower(dataBaseType)]
if !ok {
return "", "", false, fmt.Errorf("unsupported database type: %s", dataBaseType)
}
if strings.HasPrefix(dataBaseType, "_") {
return tp, "", true, nil
}
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), "", false, nil
}
// the following are the old version compatibility code.
tp, ok := commonMysqlDataTypeMapString[strings.ToLower(dataBaseType)]
if !ok {
return "", false, fmt.Errorf("unsupported database type: %s", dataBaseType)
return "", "", false, fmt.Errorf("unsupported database type: %s", dataBaseType)
}
if strings.HasPrefix(dataBaseType, "_") {
return tp, true, nil
return tp, "", true, nil
}
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), false, nil
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), "", false, nil
}
func convertDatatypeWithConfig(dataBaseType string, isDefaultNull, unsigned bool) (string, string) {
if config.ExternalConfig == nil {
return "", ""
}
opt, ok := config.ExternalConfig.Model.TypesMap[strings.ToLower(dataBaseType)]
if !ok || (len(opt.Type) == 0 && len(opt.UnsignedType) == 0 && len(opt.NullType) == 0) {
return "", ""
}
if isDefaultNull {
if len(opt.NullType) != 0 {
return opt.NullType, opt.Pkg
}
} else if unsigned {
if len(opt.UnsignedType) != 0 {
return opt.UnsignedType, opt.Pkg
}
}
return opt.Type, opt.Pkg
}
func mayConvertNullType(goDataType string, isDefaultNull, unsigned, strict bool) string {

View File

@ -8,23 +8,102 @@ import (
)
func TestConvertDataType(t *testing.T) {
v, err := ConvertDataType(parser.TinyInt, false, false, true)
v, _, err := ConvertDataType(parser.TinyInt, false, false, true)
assert.Nil(t, err)
assert.Equal(t, "int64", v)
v, err = ConvertDataType(parser.TinyInt, false, true, true)
v, _, err = ConvertDataType(parser.TinyInt, false, true, true)
assert.Nil(t, err)
assert.Equal(t, "uint64", v)
v, err = ConvertDataType(parser.TinyInt, true, false, true)
v, _, err = ConvertDataType(parser.TinyInt, true, false, true)
assert.Nil(t, err)
assert.Equal(t, "sql.NullInt64", v)
v, err = ConvertDataType(parser.Timestamp, false, false, true)
v, _, err = ConvertDataType(parser.Timestamp, false, false, true)
assert.Nil(t, err)
assert.Equal(t, "time.Time", v)
v, err = ConvertDataType(parser.Timestamp, true, false, true)
v, _, err = ConvertDataType(parser.Timestamp, true, false, true)
assert.Nil(t, err)
assert.Equal(t, "sql.NullTime", v)
v, _, err = ConvertDataType(parser.Decimal, false, false, true)
assert.Nil(t, err)
assert.Equal(t, "float64", v)
}
func TestConvertStringDataType(t *testing.T) {
type (
input struct {
dataType string
isDefaultNull bool
unsigned bool
strict bool
}
result struct {
goType string
thirdPkg string
isPQArray bool
}
)
var testData = []struct {
input input
want result
}{
{
input: input{
dataType: "bigint",
isDefaultNull: false,
unsigned: false,
strict: false,
},
want: result{
goType: "int64",
},
},
{
input: input{
dataType: "bigint",
isDefaultNull: true,
unsigned: false,
strict: false,
},
want: result{
goType: "sql.NullInt64",
},
},
{
input: input{
dataType: "bigint",
isDefaultNull: false,
unsigned: true,
strict: false,
},
want: result{
goType: "uint64",
},
},
{
input: input{
dataType: "_int2",
isDefaultNull: false,
unsigned: false,
strict: false,
},
want: result{
goType: "pq.Int64Array",
isPQArray: true,
},
},
}
for _, data := range testData {
tp, thirdPkg, isPQArray, err := ConvertStringDataType(data.input.dataType, data.input.isDefaultNull, data.input.unsigned, data.input.strict)
assert.NoError(t, err)
assert.Equal(t, data.want, result{
goType: tp,
thirdPkg: thirdPkg,
isPQArray: isPQArray,
})
}
}

View File

@ -9,7 +9,7 @@ CREATE TABLE `user`
`mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',
`gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公\r开',
`nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',
`type` tinyint(1) COLLATE utf8mb4_general_ci DEFAULT 0 COMMENT '用户类型',
`type` tinyint(1) COLLATE utf8mb4_general_ci DEFAULT 0 COMMENT '用户类型',
`create_time` timestamp NULL,
`update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
@ -22,12 +22,13 @@ CREATE TABLE `user`
CREATE TABLE `student`
(
`type` bigint NOT NULL,
`class` varchar(255) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
`name` varchar(255) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
`age` tinyint DEFAULT NULL,
`type` bigint NOT NULL,
`class` varchar(255) NOT NULL DEFAULT '',
`name` varchar(255) NOT NULL DEFAULT '',
`age` tinyint DEFAULT NULL,
`score` float(10, 0
) DEFAULT NULL,
`amount` decimal DEFAULT NULL,
`create_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
`update_time` timestamp NULL DEFAULT NULL,
`delete_time` timestamp NULL DEFAULT NULL ON UPDATE CURRENT_TIMESTAMP,

View File

@ -1,12 +1,27 @@
package gen
import (
"fmt"
"strings"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/template"
"github.com/zeromicro/go-zero/tools/goctl/util"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
func genImports(table Table, withCache, timeImport bool) (string, error) {
var thirdImports []string
var m = map[string]struct{}{}
for _, c := range table.Fields {
if len(c.ThirdPkg) > 0 {
if _, ok := m[c.ThirdPkg]; ok {
continue
}
m[c.ThirdPkg] = struct{}{}
thirdImports = append(thirdImports, fmt.Sprintf("%q", c.ThirdPkg))
}
}
if withCache {
text, err := pathx.LoadTemplate(category, importsTemplateFile, template.Imports)
if err != nil {
@ -17,6 +32,7 @@ func genImports(table Table, withCache, timeImport bool) (string, error) {
"time": timeImport,
"containsPQ": table.ContainsPQ,
"data": table,
"third": strings.Join(thirdImports, "\n"),
})
if err != nil {
return "", err
@ -34,6 +50,7 @@ func genImports(table Table, withCache, timeImport bool) (string, error) {
"time": timeImport,
"containsPQ": table.ContainsPQ,
"data": table,
"third": strings.Join(thirdImports, "\n"),
})
if err != nil {
return "", err

View File

@ -38,6 +38,7 @@ type (
Field struct {
NameOriginal string
Name stringx.String
ThirdPkg string
DataType string
Comment string
SeqInIndex int
@ -219,7 +220,7 @@ func convertColumns(columns []*parser.Column, primaryColumn string, strict bool)
}
}
dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned(), strict)
dataType, thirdPkg, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned(), strict)
if err != nil {
return Primary{}, nil, err
}
@ -236,6 +237,7 @@ func convertColumns(columns []*parser.Column, primaryColumn string, strict bool)
var field Field
field.Name = stringx.From(column.Name)
field.ThirdPkg = thirdPkg
field.DataType = dataType
field.Comment = util.TrimNewLine(comment)
@ -267,7 +269,7 @@ func (t *Table) ContainsTime() bool {
func ConvertDataType(table *model.Table, strict bool) (*Table, error) {
isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
isPrimaryUnsigned := strings.Contains(table.PrimaryKey.DbColumn.ColumnType, "unsigned")
primaryDataType, containsPQ, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned, strict)
primaryDataType, thirdPkg, containsPQ, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned, strict)
if err != nil {
return nil, err
}
@ -285,6 +287,7 @@ func ConvertDataType(table *model.Table, strict bool) (*Table, error) {
reply.PrimaryKey = Primary{
Field: Field{
Name: stringx.From(table.PrimaryKey.Name),
ThirdPkg: thirdPkg,
DataType: primaryDataType,
Comment: table.PrimaryKey.Comment,
SeqInIndex: seqInIndex,
@ -351,7 +354,7 @@ func getTableFields(table *model.Table, strict bool) (map[string]*Field, error)
for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
isPrimaryUnsigned := strings.Contains(each.ColumnType, "unsigned")
dt, containsPQ, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned, strict)
dt, thirdPkg, containsPQ, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned, strict)
if err != nil {
return nil, err
}
@ -363,6 +366,7 @@ func getTableFields(table *model.Table, strict bool) (map[string]*Field, error)
field := &Field{
NameOriginal: each.Name,
Name: stringx.From(each.Name),
ThirdPkg: thirdPkg,
DataType: dt,
Comment: each.Comment,
SeqInIndex: columnSeqInIndex,

View File

@ -9,4 +9,6 @@ import (
"github.com/zeromicro/go-zero/core/stores/builder"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/core/stringx"
{{.third}}
)

View File

@ -11,4 +11,6 @@ import (
"github.com/zeromicro/go-zero/core/stores/sqlc"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/core/stringx"
{{.third}}
)