Skip to content

Commit

Permalink
Enum types for model fields (volatiletech#1032)
Browse files Browse the repository at this point in the history
  • Loading branch information
optiman authored Jan 19, 2022
1 parent 95140c7 commit dd5f181
Show file tree
Hide file tree
Showing 21 changed files with 5,223 additions and 99 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ not to pass them through the command line or environment variables:
| debug | false |
| add-global-variants | false |
| add-panic-variants | false |
| add-enum-types | false |
| enum-null-prefix | "Null" |
| no-context | false |
| no-hooks | false |
| no-tests | false |
Expand All @@ -393,6 +395,7 @@ not to pass them through the command line or environment variables:
output = "my_models"
wipe = true
no-tests = true
add-enum-types = true

[psql]
dbname = "dbname"
Expand Down Expand Up @@ -441,6 +444,7 @@ Flags:
--add-panic-variants Enable generation for panic variants
--add-soft-deletes Enable soft deletion by updating deleted_at timestamp
--add-enum-types Enable generation of types for enums
--enum-null-prefix Name prefix of nullable enum types (default "Null")
-c, --config string Filename of config file to override default lookup
-d, --debug Debug mode prints stack traces on error
-h, --help help for sqlboiler
Expand Down
17 changes: 16 additions & 1 deletion boilingcore/boilingcore.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import (
"strings"

"github.com/friendsofgo/errors"
"github.com/volatiletech/strmangle"

"github.com/volatiletech/sqlboiler/v4/drivers"
"github.com/volatiletech/sqlboiler/v4/importers"
boiltemplates "github.com/volatiletech/sqlboiler/v4/templates"
"github.com/volatiletech/strmangle"
)

var (
Expand Down Expand Up @@ -90,6 +91,10 @@ func New(config *Config) (*State, error) {
return nil, errors.Wrap(err, "unable to merge imports from driver")
}

if s.Config.AddEnumTypes {
s.mergeEnumImports()
}

if !s.Config.NoContext {
s.Config.Imports.All.Standard = append(s.Config.Imports.All.Standard, `"context"`)
s.Config.Imports.Test.Standard = append(s.Config.Imports.Test.Standard, `"context"`)
Expand Down Expand Up @@ -134,6 +139,7 @@ func (s *State) Run() error {
AddPanic: s.Config.AddPanic,
AddSoftDeletes: s.Config.AddSoftDeletes,
AddEnumTypes: s.Config.AddEnumTypes,
EnumNullPrefix: s.Config.EnumNullPrefix,
NoContext: s.Config.NoContext,
NoHooks: s.Config.NoHooks,
NoAutoTimestamps: s.Config.NoAutoTimestamps,
Expand Down Expand Up @@ -424,6 +430,15 @@ func (s *State) mergeDriverImports() error {
return nil
}

// mergeEnumImports merges imports for nullable enum types
// into the current configuration's imports if tables returned
// from the driver have nullable enum columns.
func (s *State) mergeEnumImports() {
if drivers.TablesHaveNullableEnums(s.Tables) {
s.Config.Imports = importers.Merge(s.Config.Imports, importers.NullableEnumImports())
}
}

// processTypeReplacements checks the config for type replacements
// and performs them.
func (s *State) processTypeReplacements() error {
Expand Down
2 changes: 2 additions & 0 deletions boilingcore/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/spf13/cast"

"github.com/volatiletech/sqlboiler/v4/drivers"
"github.com/volatiletech/sqlboiler/v4/importers"
)
Expand All @@ -24,6 +25,7 @@ type Config struct {
AddPanic bool `toml:"add_panic,omitempty" json:"add_panic,omitempty"`
AddSoftDeletes bool `toml:"add_soft_deletes,omitempty" json:"add_soft_deletes,omitempty"`
AddEnumTypes bool `toml:"add_enum_types,omitempty" json:"add_enum_types,omitempty"`
EnumNullPrefix string `toml:"enum_null_prefix,omitempty" json:"enum_null_prefix,omitempty"`
NoContext bool `toml:"no_context,omitempty" json:"no_context,omitempty"`
NoTests bool `toml:"no_tests,omitempty" json:"no_tests,omitempty"`
NoHooks bool `toml:"no_hooks,omitempty" json:"no_hooks,omitempty"`
Expand Down
16 changes: 9 additions & 7 deletions boilingcore/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type templateData struct {
AddPanic bool
AddSoftDeletes bool
AddEnumTypes bool
EnumNullPrefix string
NoContext bool
NoHooks bool
NoAutoTimestamps bool
Expand Down Expand Up @@ -314,11 +315,12 @@ var templateFunctions = template.FuncMap{
},

// dbdrivers ops
"filterColumnsByAuto": drivers.FilterColumnsByAuto,
"filterColumnsByDefault": drivers.FilterColumnsByDefault,
"filterColumnsByEnum": drivers.FilterColumnsByEnum,
"sqlColDefinitions": drivers.SQLColDefinitions,
"columnNames": drivers.ColumnNames,
"columnDBTypes": drivers.ColumnDBTypes,
"getTable": drivers.GetTable,
"filterColumnsByAuto": drivers.FilterColumnsByAuto,
"filterColumnsByDefault": drivers.FilterColumnsByDefault,
"filterColumnsByEnum": drivers.FilterColumnsByEnum,
"sqlColDefinitions": drivers.SQLColDefinitions,
"columnNames": drivers.ColumnNames,
"columnDBTypes": drivers.ColumnDBTypes,
"getTable": drivers.GetTable,
"tablesHaveNullableEnums": drivers.TablesHaveNullableEnums,
}
24 changes: 19 additions & 5 deletions drivers/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ import (

// These constants are used in the config map passed into the driver
const (
ConfigBlacklist = "blacklist"
ConfigWhitelist = "whitelist"
ConfigSchema = "schema"
ConfigBlacklist = "blacklist"
ConfigWhitelist = "whitelist"
ConfigSchema = "schema"
ConfigAddEnumTypes = "add-enum-types"
ConfigEnumNullPrefix = "enum-null-prefix"

ConfigUser = "user"
ConfigPass = "pass"
Expand Down Expand Up @@ -77,6 +79,11 @@ type Constructor interface {
TranslateColumnType(Column) Column
}

type TableColumnTypeTranslator interface {
// TranslateTableColumnType takes a Database column type and table name and returns a go column type.
TranslateTableColumnType(c Column, tableName string) Column
}

// Tables returns the metadata for all tables, minus the tables
// specified in the blacklist.
func Tables(c Constructor, schema string, whitelist, blacklist []string) ([]Table, error) {
Expand All @@ -99,8 +106,15 @@ func Tables(c Constructor, schema string, whitelist, blacklist []string) ([]Tabl
return nil, errors.Wrapf(err, "unable to fetch table column info (%s)", name)
}

for i, col := range t.Columns {
t.Columns[i] = c.TranslateColumnType(col)
tr, ok := c.(TableColumnTypeTranslator)
if ok {
for i, col := range t.Columns {
t.Columns[i] = tr.TranslateTableColumnType(col, name)
}
} else {
for i, col := range t.Columns {
t.Columns[i] = c.TranslateColumnType(col)
}
}

if t.PKey, err = c.PrimaryKeyInfo(schema, name); err != nil {
Expand Down
34 changes: 27 additions & 7 deletions drivers/sqlboiler-mysql/driver/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/go-sql-driver/mysql"
"github.com/volatiletech/sqlboiler/v4/drivers"
"github.com/volatiletech/sqlboiler/v4/importers"
"github.com/volatiletech/strmangle"
)

//go:embed override
Expand All @@ -32,10 +33,11 @@ func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) {
// MySQLDriver holds the database connection string and a handle
// to the database connection.
type MySQLDriver struct {
connStr string
conn *sql.DB

tinyIntAsInt bool
connStr string
conn *sql.DB
addEnumTypes bool
enumNullPrefix string
tinyIntAsInt bool
}

// Templates that should be added/overridden
Expand Down Expand Up @@ -89,6 +91,8 @@ func (m *MySQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, e
}
}

m.addEnumTypes, _ = config[drivers.ConfigAddEnumTypes].(bool)
m.enumNullPrefix = strmangle.TitleCase(config.DefaultString(drivers.ConfigEnumNullPrefix, "Null"))
m.connStr = MySQLBuildQueryString(user, pass, dbname, host, port, sslmode)
m.conn, err = sql.Open("mysql", m.connStr)
if err != nil {
Expand Down Expand Up @@ -368,7 +372,15 @@ func (m *MySQLDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.Foreig
// TranslateColumnType converts mysql database types to Go types, for example
// "varchar" to "string" and "bigint" to "int64". It returns this parsed data
// as a Column object.
func (m *MySQLDriver) TranslateColumnType(c drivers.Column) drivers.Column {
// Deprecated: for MySQL enum types to be created properly TranslateTableColumnType method should be used instead.
func (m *MySQLDriver) TranslateColumnType(drivers.Column) drivers.Column {
panic("TranslateTableColumnType should be called")
}

// TranslateTableColumnType converts mysql database types to Go types, for example
// "varchar" to "string" and "bigint" to "int64". It returns this parsed data
// as a Column object.
func (m *MySQLDriver) TranslateTableColumnType(c drivers.Column, tableName string) drivers.Column {
unsigned := strings.Contains(c.FullDBType, "unsigned")
if c.Nullable {
switch c.DBType {
Expand Down Expand Up @@ -420,7 +432,11 @@ func (m *MySQLDriver) TranslateColumnType(c drivers.Column) drivers.Column {
case "json":
c.Type = "null.JSON"
default:
c.Type = "null.String"
if len(strmangle.ParseEnumVals(c.DBType)) > 0 && m.addEnumTypes {
c.Type = strmangle.TitleCase(tableName) + m.enumNullPrefix + strmangle.TitleCase(c.Name)
} else {
c.Type = "null.String"
}
}
} else {
switch c.DBType {
Expand Down Expand Up @@ -472,7 +488,11 @@ func (m *MySQLDriver) TranslateColumnType(c drivers.Column) drivers.Column {
case "json":
c.Type = "types.JSON"
default:
c.Type = "string"
if len(strmangle.ParseEnumVals(c.DBType)) > 0 && m.addEnumTypes {
c.Type = strmangle.TitleCase(tableName) + strmangle.TitleCase(c.Name)
} else {
c.Type = "string"
}
}
}

Expand Down
Loading

0 comments on commit dd5f181

Please sign in to comment.