Skip to content

Commit

Permalink
ddl: fix the set's default value where create table (#12267)
Browse files Browse the repository at this point in the history
  • Loading branch information
zimulala authored Sep 26, 2019
1 parent dab72fb commit 20bdf44
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 13 deletions.
53 changes: 53 additions & 0 deletions ddl/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1832,6 +1832,59 @@ func (s *testDBSuite1) TestCreateTable(c *C) {
c.Assert(err.Error(), Equals, "[types:1291]Column 'a' has duplicated value 'B' in ENUM")
}

func (s *testDBSuite2) TestCreateTableWithSetCol(c *C) {
s.tk = testkit.NewTestKitWithInit(c, s.store)
s.tk.MustExec("create table t_set (a int, b set('e') default '');")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` int(11) DEFAULT NULL,\n" +
" `b` set('e') DEFAULT ''\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('a', 'b', 'c', 'd') default 'a,C,c');")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('a','b','c','d') DEFAULT 'a,c'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))

// It's for failure cases.
// The type of default value is string.
s.tk.MustExec("drop table t_set")
failedSQL := "create table t_set (a set('1', '4', '10') default '3');"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)
failedSQL = "create table t_set (a set('1', '4', '10') default '1,4,11');"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)
failedSQL = "create table t_set (a set('1', '4', '10') default '1 ,4');"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)
// The type of default value is int.
failedSQL = "create table t_set (a set('1', '4', '10') default 0);"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)
failedSQL = "create table t_set (a set('1', '4', '10') default 8);"
s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault)

// The type of default value is int.
// It's for successful cases
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 1);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '1'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 2);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '4'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 3);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '1,4'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 15);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '1,4,10,21'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("insert into t_set value()")
s.tk.MustQuery("select * from t_set").Check(testkit.Rows("1,4,10,21"))
}

func (s *testDBSuite2) TestTableForeignKey(c *C) {
s.tk = testkit.NewTestKit(c, s.store)
s.tk.MustExec("use test")
Expand Down
74 changes: 63 additions & 11 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,8 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o
return col, constraints, nil
}

func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption, t *types.FieldType) (interface{}, error) {
tp, fsp := t.Tp, t.Decimal
func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOption) (interface{}, error) {
tp, fsp := col.FieldType.Tp, col.FieldType.Decimal
if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime {
switch x := c.Expr.(type) {
case *ast.FuncCallExpr:
Expand All @@ -633,14 +633,14 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption
}
}
if defaultFsp != fsp {
return nil, ErrInvalidDefaultValue.GenWithStackByArgs(colName)
return nil, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
}
}
vd, err := expression.GetTimeValue(ctx, c.Expr, tp, int8(fsp))
value := vd.GetValue()
if err != nil {
return nil, ErrInvalidDefaultValue.GenWithStackByArgs(colName)
return nil, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}

// Value is nil means `default null`.
Expand Down Expand Up @@ -681,14 +681,14 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption
return strconv.FormatUint(value, 10), nil
}

if tp == mysql.TypeDuration {
var err error
if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, t); err != nil {
switch tp {
case mysql.TypeSet:
return setSetDefaultValue(v, col)
case mysql.TypeDuration:
if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType); err != nil {
return "", errors.Trace(err)
}
}

if tp == mysql.TypeBit {
case mysql.TypeBit:
if v.Kind() == types.KindInt64 || v.Kind() == types.KindUint64 {
// For BIT fields, convert int into BinaryLiteral.
return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), nil
Expand All @@ -698,6 +698,58 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption
return v.ToString()
}

// setSetDefaultValue sets the default value for the set type. See https://dev.mysql.com/doc/refman/5.7/en/set.html.
func setSetDefaultValue(v types.Datum, col *table.Column) (string, error) {
if v.Kind() == types.KindInt64 {
setCnt := len(col.Elems)
maxLimit := int64(1<<uint(setCnt) - 1)
val := v.GetInt64()
if val < 1 || val > maxLimit {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
setVal, err := types.ParseSetValue(col.Elems, uint64(val))
if err != nil {
return "", errors.Trace(err)
}
v.SetMysqlSet(setVal)
return v.ToString()
}

str, err := v.ToString()
if err != nil {
return "", errors.Trace(err)
}
if str == "" {
return str, nil
}

valMap := make(map[string]struct{}, len(col.Elems))
dVals := strings.Split(strings.ToLower(str), ",")
for _, dv := range dVals {
valMap[dv] = struct{}{}
}
var existCnt int
for dv := range valMap {
for i := range col.Elems {
e := strings.ToLower(col.Elems[i])
if e == dv {
existCnt++
break
}
}
}
if existCnt != len(valMap) {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
setVal, err := types.ParseSetName(col.Elems, str)
if err != nil {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
v.SetMysqlSet(setVal)

return v.ToString()
}

func removeOnUpdateNowFlag(c *table.Column) {
// For timestamp Col, if it is set null or default value,
// OnUpdateNowFlag should be removed.
Expand Down Expand Up @@ -2491,7 +2543,7 @@ func modifiable(origin *types.FieldType, to *types.FieldType) error {

func setDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (bool, error) {
hasDefaultValue := false
value, err := getDefaultValue(ctx, col.Name.L, option, &col.FieldType)
value, err := getDefaultValue(ctx, col, option)
if err != nil {
return hasDefaultValue, errors.Trace(err)
}
Expand Down
2 changes: 1 addition & 1 deletion executor/seqtest/seq_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ func (s *seqTestSuite) TestShow(c *C) {
"c4|varchar(6)|YES||1|",
"c5|varchar(6)|YES||'C6'|",
"c6|enum('s','m','l','xl')|YES||xl|",
"c7|set('a','b','c','d')|YES||a,c,c|",
"c7|set('a','b','c','d')|YES||a,c|",
"c8|datetime|YES||CURRENT_TIMESTAMP|DEFAULT_GENERATED on update CURRENT_TIMESTAMP",
"c9|year(4)|YES||2014|",
))
Expand Down
2 changes: 1 addition & 1 deletion planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ func checkColumn(colDef *ast.ColumnDef) error {
if len(tp.Elems) > mysql.MaxTypeSetMembers {
return types.ErrTooBigSet.GenWithStack("Too many strings for column %s and SET", colDef.Name.Name.O)
}
// Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html .
// Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html.
for _, str := range colDef.Tp.Elems {
if strings.Contains(str, ",") {
return types.ErrIllegalValueForType.GenWithStackByArgs(types.TypeStr(tp.Tp), str)
Expand Down

0 comments on commit 20bdf44

Please sign in to comment.