Skip to content

Commit

Permalink
ddl: fix the enum's default value where create table (#20849)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiongjiwei authored Nov 11, 2020
1 parent 7ab3649 commit e608e4b
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 55 deletions.
29 changes: 29 additions & 0 deletions ddl/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2944,6 +2944,35 @@ func (s *testDBSuite2) TestCreateTableWithSetCol(c *C) {
tk.MustQuery("select * from t_set").Check(testkit.Rows("1,4,10,21"))
}

func (s *testDBSuite2) TestCreateTableWithEnumCol(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
// It's for failure cases.
// The type of default value is string.
tk.MustExec("drop table if exists t_enum")
failedSQL := "create table t_enum (a enum('1', '4', '10') default '3');"
tk.MustGetErrCode(failedSQL, errno.ErrInvalidDefault)
failedSQL = "create table t_enum (a enum('1', '4', '10') default '');"
tk.MustGetErrCode(failedSQL, errno.ErrInvalidDefault)
// The type of default value is int.
failedSQL = "create table t_enum (a enum('1', '4', '10') default 0);"
tk.MustGetErrCode(failedSQL, errno.ErrInvalidDefault)
failedSQL = "create table t_enum (a enum('1', '4', '10') default 8);"
tk.MustGetErrCode(failedSQL, errno.ErrInvalidDefault)

// The type of default value is int.
// It's for successful cases
tk.MustExec("drop table if exists t_enum")
tk.MustExec("create table t_enum (a enum('2', '3', '4') default 2);")
ret := tk.MustQuery("show create table t_enum").Rows()[0][1]
c.Assert(strings.Contains(ret.(string), "`a` enum('2','3','4') DEFAULT '3'"), IsTrue)
tk.MustExec("drop table t_enum")
tk.MustExec("create table t_enum (a enum('a', 'c', 'd') default 2);")
ret = tk.MustQuery("show create table t_enum").Rows()[0][1]
c.Assert(strings.Contains(ret.(string), "`a` enum('a','c','d') DEFAULT 'c'"), IsTrue)
tk.MustExec("insert into t_enum value()")
tk.MustQuery("select * from t_enum").Check(testkit.Rows("c"))
}

func (s *testDBSuite2) TestTableForeignKey(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
53 changes: 32 additions & 21 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,10 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOpt

switch tp {
case mysql.TypeSet:
val, err := setSetDefaultValue(v, col)
val, err := getSetDefaultValue(v, col)
return val, false, err
case mysql.TypeEnum:
val, err := getEnumDefaultValue(v, col)
return val, false, err
case mysql.TypeDuration:
if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType); err != nil {
Expand Down Expand Up @@ -788,8 +791,8 @@ func tryToGetSequenceDefaultValue(c *ast.ColumnOption) (expr string, isExpr bool
return "", false, nil
}

// setSetDefaultValue sets the default value for the set type. See https://dev.mysql.com/doc/refman/5.7/en/set.html.
func setSetDefaultValue(v types.Datum, col *table.Column) (string, error) {
// getSetDefaultValue gets the default value for the set type. See https://dev.mysql.com/doc/refman/5.7/en/set.html.
func getSetDefaultValue(v types.Datum, col *table.Column) (string, error) {
if v.Kind() == types.KindInt64 {
setCnt := len(col.Elems)
maxLimit := int64(1<<uint(setCnt) - 1)
Expand All @@ -812,31 +815,39 @@ func setSetDefaultValue(v types.Datum, col *table.Column) (string, error) {
if str == "" {
return str, nil
}

ctor := collate.GetCollator(col.Collate)
valMap := make(map[string]struct{}, len(col.Elems))
dVals := strings.Split(str, ",")
for _, dv := range dVals {
valMap[string(ctor.Key(dv))] = struct{}{}
setVal, err := types.ParseSetName(col.Elems, str, col.Collate)
if err != nil {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
var existCnt int
for dv := range valMap {
for i := range col.Elems {
e := string(ctor.Key(col.Elems[i]))
if e == dv {
existCnt++
break
}
v.SetMysqlSet(setVal, col.Collate)

return v.ToString()
}

// getEnumDefaultValue gets the default value for the enum type. See https://dev.mysql.com/doc/refman/5.7/en/enum.html.
func getEnumDefaultValue(v types.Datum, col *table.Column) (string, error) {
if v.Kind() == types.KindInt64 {
val := v.GetInt64()
if val < 1 || val > int64(len(col.Elems)) {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
enumVal, err := types.ParseEnumValue(col.Elems, uint64(val))
if err != nil {
return "", errors.Trace(err)
}
v.SetMysqlEnum(enumVal, col.Collate)
return v.ToString()
}
if existCnt != len(valMap) {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)

str, err := v.ToString()
if err != nil {
return "", errors.Trace(err)
}
setVal, err := types.ParseSetName(col.Elems, str, col.Collate)
enumVal, err := types.ParseEnumName(col.Elems, str, col.Collate)
if err != nil {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
v.SetMysqlSet(setVal, col.Collate)
v.SetMysqlEnum(enumVal, col.Collate)

return v.ToString()
}
Expand Down
8 changes: 4 additions & 4 deletions executor/aggfuncs/func_first_row_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ import (

func (s *testSuite) TestMergePartialResult4FirstRow(c *C) {
elems := []string{"a", "b", "c", "d", "e"}
enumA, _ := types.ParseEnumName(elems, "a", mysql.DefaultCollationName)
enumC, _ := types.ParseEnumName(elems, "c", mysql.DefaultCollationName)
enumA, _ := types.ParseEnum(elems, "a", mysql.DefaultCollationName)
enumC, _ := types.ParseEnum(elems, "c", mysql.DefaultCollationName)

setA, _ := types.ParseSetName(elems, "a", mysql.DefaultCollationName)
setAB, _ := types.ParseSetName(elems, "a,b", mysql.DefaultCollationName)
setA, _ := types.ParseSet(elems, "a", mysql.DefaultCollationName)
setAB, _ := types.ParseSet(elems, "a,b", mysql.DefaultCollationName)

tests := []aggTest{
buildAggTester(ast.AggFuncFirstRow, mysql.TypeLonglong, 5, 0, 2, 0),
Expand Down
12 changes: 6 additions & 6 deletions executor/aggfuncs/func_max_min_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ func minUpdateMemDeltaGens(srcChk *chunk.Chunk, dataType *types.FieldType) (memD

func (s *testSuite) TestMergePartialResult4MaxMin(c *C) {
elems := []string{"a", "b", "c", "d", "e"}
enumA, _ := types.ParseEnumName(elems, "a", mysql.DefaultCollationName)
enumC, _ := types.ParseEnumName(elems, "c", mysql.DefaultCollationName)
enumE, _ := types.ParseEnumName(elems, "e", mysql.DefaultCollationName)
enumA, _ := types.ParseEnum(elems, "a", mysql.DefaultCollationName)
enumC, _ := types.ParseEnum(elems, "c", mysql.DefaultCollationName)
enumE, _ := types.ParseEnum(elems, "e", mysql.DefaultCollationName)

setA, _ := types.ParseSetName(elems, "a", mysql.DefaultCollationName) // setA.Value == 1
setAB, _ := types.ParseSetName(elems, "a,b", mysql.DefaultCollationName) // setAB.Value == 3
setAC, _ := types.ParseSetName(elems, "a,c", mysql.DefaultCollationName) // setAC.Value == 5
setA, _ := types.ParseSet(elems, "a", mysql.DefaultCollationName) // setA.Value == 1
setAB, _ := types.ParseSet(elems, "a,b", mysql.DefaultCollationName) // setAB.Value == 3
setAC, _ := types.ParseSet(elems, "a,c", mysql.DefaultCollationName) // setAC.Value == 5

unsignedType := types.NewFieldType(mysql.TypeLonglong)
unsignedType.Flag |= mysql.UnsignedFlag
Expand Down
4 changes: 3 additions & 1 deletion executor/seqtest/seq_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,8 @@ func (s *seqTestSuite) TestShow(c *C) {
c6 enum('s', 'm', 'l', 'xl') default 'xl',
c7 set('a', 'b', 'c', 'd') default 'a,c,c',
c8 datetime default current_timestamp on update current_timestamp,
c9 year default '2014'
c9 year default '2014',
c10 enum('2', '3', '4') default 2
);`)
tk.MustQuery(`show columns from t`).Check(testutil.RowsWithSep("|",
"c0|int(11)|YES||1|",
Expand All @@ -629,6 +630,7 @@ func (s *seqTestSuite) TestShow(c *C) {
"c7|set('a','b','c','d')|YES||a,c|",
"c8|datetime|YES||CURRENT_TIMESTAMP|DEFAULT_GENERATED on update CURRENT_TIMESTAMP",
"c9|year(4)|YES||2014|",
"c10|enum('2','3','4')|YES||3|",
))

// Test if 'show [status|variables]' is sorted by Variable_name (#14542)
Expand Down
12 changes: 6 additions & 6 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -1444,11 +1444,11 @@ func (d *Datum) convertToMysqlEnum(sc *stmtctx.StatementContext, target *FieldTy
)
switch d.k {
case KindString, KindBytes:
e, err = ParseEnumName(target.Elems, d.GetString(), target.Collate)
e, err = ParseEnum(target.Elems, d.GetString(), target.Collate)
case KindMysqlEnum:
e, err = ParseEnumName(target.Elems, d.GetMysqlEnum().Name, target.Collate)
e, err = ParseEnum(target.Elems, d.GetMysqlEnum().Name, target.Collate)
case KindMysqlSet:
e, err = ParseEnumName(target.Elems, d.GetMysqlSet().Name, target.Collate)
e, err = ParseEnum(target.Elems, d.GetMysqlSet().Name, target.Collate)
default:
var uintDatum Datum
uintDatum, err = d.convertToUint(sc, target)
Expand All @@ -1471,11 +1471,11 @@ func (d *Datum) convertToMysqlSet(sc *stmtctx.StatementContext, target *FieldTyp
)
switch d.k {
case KindString, KindBytes:
s, err = ParseSetName(target.Elems, d.GetString(), target.Collate)
s, err = ParseSet(target.Elems, d.GetString(), target.Collate)
case KindMysqlEnum:
s, err = ParseSetName(target.Elems, d.GetMysqlEnum().Name, target.Collate)
s, err = ParseSet(target.Elems, d.GetMysqlEnum().Name, target.Collate)
case KindMysqlSet:
s, err = ParseSetName(target.Elems, d.GetMysqlSet().Name, target.Collate)
s, err = ParseSet(target.Elems, d.GetMysqlSet().Name, target.Collate)
default:
var uintDatum Datum
uintDatum, err = d.convertToUint(sc, target)
Expand Down
18 changes: 13 additions & 5 deletions types/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ func (e Enum) ToNumber() float64 {
return float64(e.Value)
}

// ParseEnum creates a Enum with item name or value.
func ParseEnum(elems []string, name string, collation string) (Enum, error) {
if enumName, err := ParseEnumName(elems, name, collation); err == nil {
return enumName, nil
}
// name doesn't exist, maybe an integer?
if num, err := strconv.ParseUint(name, 0, 64); err == nil {
return ParseEnumValue(elems, num)
}

return Enum{}, errors.Errorf("item %s is not in enum %v", name, elems)
}

// ParseEnumName creates a Enum with item name.
func ParseEnumName(elems []string, name string, collation string) (Enum, error) {
ctor := collate.GetCollator(collation)
Expand All @@ -54,11 +67,6 @@ func ParseEnumName(elems []string, name string, collation string) (Enum, error)
}
}

// name doesn't exist, maybe an integer?
if num, err := strconv.ParseUint(name, 0, 64); err == nil {
return ParseEnumValue(elems, num)
}

return Enum{}, errors.Errorf("item %s is not in enum %v", name, elems)
}

Expand Down
6 changes: 3 additions & 3 deletions types/enum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (s *testEnumSuite) TestEnum(c *C) {
}

for _, t := range tbl {
e, err := ParseEnumName(t.Elems, t.Name, mysql.DefaultCollationName)
e, err := ParseEnum(t.Elems, t.Name, mysql.DefaultCollationName)
if t.Expected == 0 {
c.Assert(err, NotNil)
c.Assert(e.ToNumber(), Equals, float64(0))
Expand All @@ -65,7 +65,7 @@ func (s *testEnumSuite) TestEnum(c *C) {
}

for _, t := range tbl {
e, err := ParseEnumName(t.Elems, t.Name, "utf8_unicode_ci")
e, err := ParseEnum(t.Elems, t.Name, "utf8_unicode_ci")
if t.Expected == 0 {
c.Assert(err, NotNil)
c.Assert(e.ToNumber(), Equals, float64(0))
Expand All @@ -79,7 +79,7 @@ func (s *testEnumSuite) TestEnum(c *C) {
}

for _, t := range citbl {
e, err := ParseEnumName(t.Elems, t.Name, "utf8_general_ci")
e, err := ParseEnum(t.Elems, t.Name, "utf8_general_ci")
if t.Expected == 0 {
c.Assert(err, NotNil)
c.Assert(e.ToNumber(), Equals, float64(0))
Expand Down
18 changes: 13 additions & 5 deletions types/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ func (e Set) Copy() Set {
}
}

// ParseSet creates a Set with name or value.
func ParseSet(elems []string, name string, collation string) (Set, error) {
if setName, err := ParseSetName(elems, name, collation); err == nil {
return setName, nil
}
// name doesn't exist, maybe an integer?
if num, err := strconv.ParseUint(name, 0, 64); err == nil {
return ParseSetValue(elems, num)
}

return Set{}, errors.Errorf("item %s is not in Set %v", name, elems)
}

// ParseSetName creates a Set with name.
func ParseSetName(elems []string, name string, collation string) (Set, error) {
if len(name) == 0 {
Expand Down Expand Up @@ -77,11 +90,6 @@ func ParseSetName(elems []string, name string, collation string) (Set, error) {
return Set{Name: strings.Join(items, ","), Value: value}, nil
}

// name doesn't exist, maybe an integer?
if num, err := strconv.ParseUint(name, 0, 64); err == nil {
return ParseSetValue(elems, num)
}

return Set{}, errors.Errorf("item %s is not in Set %v", name, elems)
}

Expand Down
8 changes: 4 additions & 4 deletions types/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,21 @@ func (s *testSetSuite) TestSet(c *C) {
}

for _, t := range tbl {
e, err := ParseSetName(elems, t.Name, mysql.DefaultCollationName)
e, err := ParseSet(elems, t.Name, mysql.DefaultCollationName)
c.Assert(err, IsNil)
c.Assert(e.ToNumber(), Equals, float64(t.ExpectedValue))
c.Assert(e.String(), Equals, t.ExpectedName)
}

for _, t := range tbl {
e, err := ParseSetName(elems, t.Name, "utf8_unicode_ci")
e, err := ParseSet(elems, t.Name, "utf8_unicode_ci")
c.Assert(err, IsNil)
c.Assert(e.ToNumber(), Equals, float64(t.ExpectedValue))
c.Assert(e.String(), Equals, t.ExpectedName)
}

for _, t := range citbl {
e, err := ParseSetName(elems, t.Name, "utf8_general_ci")
e, err := ParseSet(elems, t.Name, "utf8_general_ci")
c.Assert(err, IsNil)
c.Assert(e.ToNumber(), Equals, float64(t.ExpectedValue))
c.Assert(e.String(), Equals, t.ExpectedName)
Expand Down Expand Up @@ -95,7 +95,7 @@ func (s *testSetSuite) TestSet(c *C) {
"e.f",
}
for _, t := range tblErr {
_, err := ParseSetName(elems, t, mysql.DefaultCollationName)
_, err := ParseSet(elems, t, mysql.DefaultCollationName)
c.Assert(err, NotNil)
}

Expand Down

0 comments on commit e608e4b

Please sign in to comment.