Add strict flag (#2248)

Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
This commit is contained in:
anqiansong 2022-08-28 18:55:52 +08:00 committed by GitHub
parent a1466e1707
commit f70805ee60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 126 additions and 57 deletions

View File

@ -2,6 +2,7 @@ package model
import (
"github.com/spf13/cobra"
"github.com/zeromicro/go-zero/tools/goctl/model/mongo"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/command"
)
@ -77,6 +78,7 @@ func init() {
pgDatasourceCmd.Flags().StringVarP(&command.VarStringDir, "dir", "d", "", "The target dir")
pgDatasourceCmd.Flags().StringVar(&command.VarStringStyle, "style", "", "The file naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]")
pgDatasourceCmd.Flags().BoolVar(&command.VarBoolIdea, "idea", false, "For idea plugin [optional]")
pgDatasourceCmd.Flags().BoolVar(&command.VarBoolStrict, "strict", false, "Generate model in strict mode")
pgDatasourceCmd.Flags().StringVar(&command.VarStringHome, "home", "", "The goctl home path of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority")
pgDatasourceCmd.Flags().StringVar(&command.VarStringRemote, "remote", "", "The remote git repo of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority\n\tThe git repo directory must be consistent with the https://github.com/zeromicro/go-zero-template directory structure")
pgDatasourceCmd.Flags().StringVar(&command.VarStringBranch, "branch", "", "The branch of the remote repo, it does work with --remote")
@ -90,6 +92,8 @@ func init() {
mongoCmd.Flags().StringVar(&mongo.VarStringRemote, "remote", "", "The remote git repo of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority\nThe git repo directory must be consistent with the https://github.com/zeromicro/go-zero-template directory structure")
mongoCmd.Flags().StringVar(&mongo.VarStringBranch, "branch", "", "The branch of the remote repo, it does work with --remote")
mysqlCmd.PersistentFlags().BoolVar(&command.VarBoolStrict, "strict", false, "Generate model in strict mode")
mysqlCmd.AddCommand(datasourceCmd)
mysqlCmd.AddCommand(ddlCmd)
pgCmd.AddCommand(pgDatasourceCmd)

View File

@ -10,6 +10,7 @@ import (
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/postgres"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/command/migrationnotes"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/gen"
@ -47,6 +48,8 @@ var (
VarStringRemote string
// VarStringBranch describes the git branch of the repository.
VarStringBranch string
// VarBoolStrict describes whether the strict mode is enabled.
VarBoolStrict bool
)
var errNotMatched = errors.New("sql not matched")
@ -77,7 +80,16 @@ func MysqlDDL(_ *cobra.Command, _ []string) error {
return err
}
return fromDDL(src, dir, cfg, cache, idea, database)
arg := ddlArg{
src: src,
dir: dir,
cfg: cfg,
cache: cache,
idea: idea,
database: database,
strict: VarBoolStrict,
}
return fromDDL(arg)
}
// MySqlDataSource generates model code from datasource
@ -108,7 +120,16 @@ func MySqlDataSource(_ *cobra.Command, _ []string) error {
return err
}
return fromMysqlDataSource(url, dir, patterns, cfg, cache, idea)
arg := dataSourceArg{
url: url,
dir: dir,
tablePat: patterns,
cfg: cfg,
cache: cache,
idea: idea,
strict: VarBoolStrict,
}
return fromMysqlDataSource(arg)
}
type pattern map[string]struct{}
@ -180,12 +201,20 @@ func PostgreSqlDataSource(_ *cobra.Command, _ []string) error {
return err
}
return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea)
return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea, VarBoolStrict)
}
func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database string) error {
log := console.NewConsole(idea)
src = strings.TrimSpace(src)
type ddlArg struct {
src, dir string
cfg *config.Config
cache, idea bool
database string
strict bool
}
func fromDDL(arg ddlArg) error {
log := console.NewConsole(arg.idea)
src := strings.TrimSpace(arg.src)
if len(src) == 0 {
return errors.New("expected path or path globbing patterns, but nothing found")
}
@ -199,13 +228,13 @@ func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database str
return errNotMatched
}
generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log))
generator, err := gen.NewDefaultGenerator(arg.dir, arg.cfg, gen.WithConsoleOption(log))
if err != nil {
return err
}
for _, file := range files {
err = generator.StartFromDDL(file, cache, database)
err = generator.StartFromDDL(file, arg.cache, arg.strict, arg.database)
if err != nil {
return err
}
@ -214,25 +243,33 @@ func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database str
return nil
}
func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config, cache, idea bool) error {
log := console.NewConsole(idea)
if len(url) == 0 {
type dataSourceArg struct {
url, dir string
tablePat pattern
cfg *config.Config
cache, idea bool
strict bool
}
func fromMysqlDataSource(arg dataSourceArg) error {
log := console.NewConsole(arg.idea)
if len(arg.url) == 0 {
log.Error("%v", "expected data source of mysql, but nothing found")
return nil
}
if len(tablePat) == 0 {
if len(arg.tablePat) == 0 {
log.Error("%v", "expected table or table globbing patterns, but nothing found")
return nil
}
dsn, err := mysql.ParseDSN(url)
dsn, err := mysql.ParseDSN(arg.url)
if err != nil {
return err
}
logx.Disable()
databaseSource := strings.TrimSuffix(url, "/"+dsn.DBName) + "/information_schema"
databaseSource := strings.TrimSuffix(arg.url, "/"+dsn.DBName) + "/information_schema"
db := sqlx.NewMysql(databaseSource)
im := model.NewInformationSchemaModel(db)
@ -243,7 +280,7 @@ func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config,
matchTables := make(map[string]*model.Table)
for _, item := range tables {
if !tablePat.Match(item) {
if !arg.tablePat.Match(item) {
continue
}
@ -264,15 +301,15 @@ func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config,
return errors.New("no tables matched")
}
generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log))
generator, err := gen.NewDefaultGenerator(arg.dir, arg.cfg, gen.WithConsoleOption(log))
if err != nil {
return err
}
return generator.StartFromInformationSchema(matchTables, cache)
return generator.StartFromInformationSchema(matchTables, arg.cache, arg.strict)
}
func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Config, cache, idea bool) error {
func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Config, cache, idea, strict bool) error {
log := console.NewConsole(idea)
if len(url) == 0 {
log.Error("%v", "expected data source of postgresql, but nothing found")
@ -324,5 +361,5 @@ func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Conf
return err
}
return generator.StartFromInformationSchema(matchTables, cache)
return generator.StartFromInformationSchema(matchTables, cache, strict)
}

View File

@ -10,6 +10,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/gen"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
@ -27,12 +28,25 @@ func TestFromDDl(t *testing.T) {
err := gen.Clean()
assert.Nil(t, err)
err = fromDDL("./user.sql", pathx.MustTempDir(), cfg, true, false, "go_zero")
err = fromDDL(ddlArg{
src: "./user.sql",
dir: pathx.MustTempDir(),
cfg: cfg,
cache: true,
database: "go-zero",
strict: false,
})
assert.Equal(t, errNotMatched, err)
// case dir is not exists
unknownDir := filepath.Join(pathx.MustTempDir(), "test", "user.sql")
err = fromDDL(unknownDir, pathx.MustTempDir(), cfg, true, false, "go_zero")
err = fromDDL(ddlArg{
src: unknownDir,
dir: pathx.MustTempDir(),
cfg: cfg,
cache: true,
database: "go_zero",
})
assert.True(t, func() bool {
switch err.(type) {
case *os.PathError:
@ -43,7 +57,12 @@ func TestFromDDl(t *testing.T) {
}())
// case empty src
err = fromDDL("", pathx.MustTempDir(), cfg, true, false, "go_zero")
err = fromDDL(ddlArg{
dir: pathx.MustTempDir(),
cfg: cfg,
cache: true,
database: "go_zero",
})
if err != nil {
assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error())
}
@ -75,7 +94,13 @@ func TestFromDDl(t *testing.T) {
filename := filepath.Join(tempDir, "usermodel.go")
fromDDL := func(db string) {
err = fromDDL(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false, db)
err = fromDDL(ddlArg{
src: filepath.Join(tempDir, "user*.sql"),
dir: tempDir,
cfg: cfg,
cache: true,
database: db,
})
assert.Nil(t, err)
_, err = os.Stat(filename)

View File

@ -132,28 +132,28 @@ var commonMysqlDataTypeMapString = map[string]string{
}
// ConvertDataType converts mysql column type into golang type
func ConvertDataType(dataBaseType int, isDefaultNull, unsigned bool) (string, error) {
func ConvertDataType(dataBaseType int, isDefaultNull, unsigned, strict bool) (string, error) {
tp, ok := commonMysqlDataTypeMapInt[dataBaseType]
if !ok {
return "", fmt.Errorf("unsupported database type: %v", dataBaseType)
}
return mayConvertNullType(tp, isDefaultNull, unsigned), nil
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), nil
}
// ConvertStringDataType converts mysql column type into golang type
func ConvertStringDataType(dataBaseType string, isDefaultNull, unsigned bool) (string, error) {
func ConvertStringDataType(dataBaseType string, isDefaultNull, unsigned, strict bool) (string, error) {
tp, ok := commonMysqlDataTypeMapString[strings.ToLower(dataBaseType)]
if !ok {
return "", fmt.Errorf("unsupported database type: %s", dataBaseType)
}
return mayConvertNullType(tp, isDefaultNull, unsigned), nil
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), nil
}
func mayConvertNullType(goDataType string, isDefaultNull, unsigned bool) string {
func mayConvertNullType(goDataType string, isDefaultNull, unsigned, strict bool) string {
if !isDefaultNull {
if unsigned {
if unsigned && strict {
ret, ok := unsignedTypeMap[goDataType]
if ok {
return ret

View File

@ -8,23 +8,23 @@ import (
)
func TestConvertDataType(t *testing.T) {
v, err := ConvertDataType(parser.TinyInt, false, false)
v, err := ConvertDataType(parser.TinyInt, false, false, true)
assert.Nil(t, err)
assert.Equal(t, "int64", v)
v, err = ConvertDataType(parser.TinyInt, false, true)
v, err = ConvertDataType(parser.TinyInt, false, true, true)
assert.Nil(t, err)
assert.Equal(t, "uint64", v)
v, err = ConvertDataType(parser.TinyInt, true, false)
v, err = ConvertDataType(parser.TinyInt, true, false, true)
assert.Nil(t, err)
assert.Equal(t, "sql.NullInt64", v)
v, err = ConvertDataType(parser.Timestamp, false, false)
v, err = ConvertDataType(parser.Timestamp, false, false, true)
assert.Nil(t, err)
assert.Equal(t, "time.Time", v)
v, err = ConvertDataType(parser.Timestamp, true, false)
v, err = ConvertDataType(parser.Timestamp, true, false, true)
assert.Nil(t, err)
assert.Equal(t, "sql.NullTime", v)
}

View File

@ -102,8 +102,8 @@ func newDefaultOption() Option {
}
}
func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, database string) error {
modelList, err := g.genFromDDL(filename, withCache, database)
func (g *defaultGenerator) StartFromDDL(filename string, withCache, strict bool, database string) error {
modelList, err := g.genFromDDL(filename, withCache, strict, database)
if err != nil {
return err
}
@ -111,10 +111,10 @@ func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, databas
return g.createFile(modelList)
}
func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache bool) error {
func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache, strict bool) error {
m := make(map[string]*codeTuple)
for _, each := range tables {
table, err := parser.ConvertDataType(each)
table, err := parser.ConvertDataType(each, strict)
if err != nil {
return err
}
@ -201,11 +201,11 @@ func (g *defaultGenerator) createFile(modelList map[string]*codeTuple) error {
}
// ret1: key-table name,value-code
func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) (
func (g *defaultGenerator) genFromDDL(filename string, withCache, strict bool, database string) (
map[string]*codeTuple, error,
) {
m := make(map[string]*codeTuple)
tables, err := parser.Parse(filename, database)
tables, err := parser.Parse(filename, database, strict)
if err != nil {
return nil, err
}

View File

@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/builderx"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/parser"
@ -40,7 +41,7 @@ func TestCacheModel(t *testing.T) {
})
assert.Nil(t, err)
err = g.StartFromDDL(sqlFile, true, "go_zero")
err = g.StartFromDDL(sqlFile, true, false, "go_zero")
assert.Nil(t, err)
assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go"))
@ -51,7 +52,7 @@ func TestCacheModel(t *testing.T) {
})
assert.Nil(t, err)
err = g.StartFromDDL(sqlFile, false, "go_zero")
err = g.StartFromDDL(sqlFile, false, false, "go_zero")
assert.Nil(t, err)
assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go"))
@ -78,7 +79,7 @@ func TestNamingModel(t *testing.T) {
})
assert.Nil(t, err)
err = g.StartFromDDL(sqlFile, true, "go_zero")
err = g.StartFromDDL(sqlFile, true, false, "go_zero")
assert.Nil(t, err)
assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go"))
@ -89,7 +90,7 @@ func TestNamingModel(t *testing.T) {
})
assert.Nil(t, err)
err = g.StartFromDDL(sqlFile, true, "go_zero")
err = g.StartFromDDL(sqlFile, true, false, "go_zero")
assert.Nil(t, err)
assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go"))
@ -186,7 +187,7 @@ func Test_genPublicModel(t *testing.T) {
})
require.NoError(t, err)
tables, err := parser.Parse(modelFilename, "")
tables, err := parser.Parse(modelFilename, "", false)
require.Equal(t, 1, len(tables))
code, err := g.genModelCustom(*tables[0], false)

View File

@ -8,6 +8,7 @@ import (
"github.com/zeromicro/ddl-parser/parser"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/converter"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/model"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/util"
@ -61,7 +62,7 @@ func parseNameOriginal(ts []*parser.Table) (nameOriginals [][]string) {
}
// Parse parses ddl into golang structure
func Parse(filename, database string) ([]*Table, error) {
func Parse(filename, database string, strict bool) ([]*Table, error) {
p := parser.NewParser()
tables, err := p.From(filename)
if err != nil {
@ -124,7 +125,7 @@ func Parse(filename, database string) ([]*Table, error) {
return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
}
primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
primaryKey, fieldM, err := convertColumns(columns, primaryColumn, strict)
if err != nil {
return nil, err
}
@ -190,7 +191,7 @@ func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string
}
}
func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, map[string]*Field, error) {
func convertColumns(columns []*parser.Column, primaryColumn string, strict bool) (Primary, map[string]*Field, error) {
var (
primaryKey Primary
fieldM = make(map[string]*Field)
@ -219,7 +220,7 @@ func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, ma
}
}
dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned())
dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned(), strict)
if err != nil {
return Primary{}, nil, err
}
@ -264,10 +265,10 @@ func (t *Table) ContainsTime() bool {
}
// ConvertDataType converts mysql data type into golang data type
func ConvertDataType(table *model.Table) (*Table, error) {
func ConvertDataType(table *model.Table, strict bool) (*Table, error) {
isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
isPrimaryUnsigned := strings.Contains(table.PrimaryKey.DbColumn.ColumnType, "unsigned")
primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned)
primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned, strict)
if err != nil {
return nil, err
}
@ -292,7 +293,7 @@ func ConvertDataType(table *model.Table) (*Table, error) {
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
}
fieldM, err := getTableFields(table)
fieldM, err := getTableFields(table, strict)
if err != nil {
return nil, err
}
@ -342,12 +343,12 @@ func ConvertDataType(table *model.Table) (*Table, error) {
return &reply, nil
}
func getTableFields(table *model.Table) (map[string]*Field, error) {
func getTableFields(table *model.Table, strict bool) (map[string]*Field, error) {
fieldM := make(map[string]*Field)
for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
isPrimaryUnsigned := strings.Contains(each.ColumnType, "unsigned")
dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned)
dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned, strict)
if err != nil {
return nil, err
}

View File

@ -7,6 +7,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/model"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/util"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
@ -17,7 +18,7 @@ func TestParsePlainText(t *testing.T) {
err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777)
assert.Nil(t, err)
_, err = Parse(sqlFile, "go_zero")
_, err = Parse(sqlFile, "go_zero", false)
assert.NotNil(t, err)
}
@ -26,7 +27,7 @@ func TestParseSelect(t *testing.T) {
err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777)
assert.Nil(t, err)
tables, err := Parse(sqlFile, "go_zero")
tables, err := Parse(sqlFile, "go_zero", false)
assert.Nil(t, err)
assert.Equal(t, 0, len(tables))
}
@ -39,7 +40,7 @@ func TestParseCreateTable(t *testing.T) {
err := ioutil.WriteFile(sqlFile, []byte(user), 0o777)
assert.Nil(t, err)
tables, err := Parse(sqlFile, "go_zero")
tables, err := Parse(sqlFile, "go_zero", false)
assert.Equal(t, 1, len(tables))
table := tables[0]
assert.Nil(t, err)