Skip to content

Commit

Permalink
ddl: fix the set's default value where create table (pingcap#12267)
Browse files Browse the repository at this point in the history
  • Loading branch information
zimulala committed Sep 29, 2019
1 parent 89b35b3 commit 586eacd
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 7 deletions.
55 changes: 54 additions & 1 deletion ddl/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func assertErrorCode(c *C, tk *testkit.TestKit, sql string, errCode int) {
originErr := errors.Cause(err)
tErr, ok := originErr.(*terror.Error)
c.Assert(ok, IsTrue, Commentf("err: %T", originErr))
c.Assert(tErr.ToSQLError().Code, DeepEquals, uint16(errCode), Commentf("MySQL code:%v", tErr.ToSQLError()))
c.Assert(tErr.ToSQLError().Code, DeepEquals, uint16(errCode), Commentf("MySQL code:%v, err %v", tErr.ToSQLError().Code, tErr.ToSQLError()))
}

func (s *testDBSuite) testErrorCode(c *C, sql string, errCode int) {
Expand Down Expand Up @@ -2149,6 +2149,59 @@ func (s *testDBSuite) TestCreateTable(c *C) {
c.Assert(err, NotNil)
}

func (s *testDBSuite) TestCreateTableWithSetCol(c *C) {
s.tk = testkit.NewTestKitWithInit(c, s.store)
s.tk.MustExec("create table t_set (a int, b set('e') default '');")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` int(11) DEFAULT NULL,\n" +
" `b` set('e') DEFAULT ''\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('a', 'b', 'c', 'd') default 'a,C,c');")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('a','b','c','d') DEFAULT 'a,c'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))

// It's for failure cases.
// The type of default value is string.
s.tk.MustExec("drop table t_set")
failedSQL := "create table t_set (a set('1', '4', '10') default '3');"
assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault)
failedSQL = "create table t_set (a set('1', '4', '10') default '1,4,11');"
assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault)
failedSQL = "create table t_set (a set('1', '4', '10') default '1 ,4');"
assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault)
// The type of default value is int.
failedSQL = "create table t_set (a set('1', '4', '10') default 0);"
assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault)
failedSQL = "create table t_set (a set('1', '4', '10') default 8);"
assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault)

// The type of default value is int.
// It's for successful cases
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 1);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '1'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 2);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '4'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 3);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '1,4'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("drop table t_set")
s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 15);")
s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" +
" `a` set('1','4','10','21') DEFAULT '1,4,10,21'\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))
s.tk.MustExec("insert into t_set value()")
s.tk.MustQuery("select * from t_set").Check(testkit.Rows("1,4,10,21"))
}

func (s *testDBSuite) TestTableForeignKey(c *C) {
s.tk = testkit.NewTestKit(c, s.store)
s.tk.MustExec("use test")
Expand Down
3 changes: 3 additions & 0 deletions ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ var (
ErrConflictingDeclarations = terror.ClassDDL.New(codeConflictingDeclarations, "Conflicting declarations: 'CHARACTER SET %s' and 'CHARACTER SET %s'")
// ErrPrimaryCantHaveNull returns All parts of a PRIMARY KEY must be NOT NULL; if you need NULL in a key, use UNIQUE instead
ErrPrimaryCantHaveNull = terror.ClassDDL.New(codePrimaryCantHaveNull, mysql.MySQLErrName[mysql.ErrPrimaryCantHaveNull])
// ErrInvalidDefaultValue returns for invalid default value for columns.
ErrInvalidDefaultValue = terror.ClassDDL.New(codeInvalidDefaultValue, mysql.MySQLErrName[mysql.ErrInvalidDefault])

// ErrNotAllowedTypeInPartition returns not allowed type error when creating table partiton with unsupport expression type.
ErrNotAllowedTypeInPartition = terror.ClassDDL.New(codeErrFieldTypeNotAllowedAsPartitionField, mysql.MySQLErrName[mysql.ErrFieldTypeNotAllowedAsPartitionField])
Expand Down Expand Up @@ -711,6 +713,7 @@ func init() {
codeUnknownCollation: mysql.ErrUnknownCollation,
codeCollationCharsetMismatch: mysql.ErrCollationCharsetMismatch,
codeConflictingDeclarations: mysql.ErrConflictingDeclarations,
codeInvalidDefaultValue: mysql.ErrInvalidDefault,
}
terror.ErrClassToMySQLCodes[terror.ClassDDL] = ddlMySQLErrCodes
}
65 changes: 60 additions & 5 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,8 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o
return col, constraints, nil
}

func getDefaultValue(ctx sessionctx.Context, c *ast.ColumnOption, t *types.FieldType) (interface{}, error) {
tp, fsp := t.Tp, t.Decimal
func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOption) (interface{}, error) {
tp, fsp := col.FieldType.Tp, col.FieldType.Decimal
if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime {
vd, err := expression.GetTimeValue(ctx, c.Expr, tp, fsp)
value := vd.GetValue()
Expand Down Expand Up @@ -620,7 +620,10 @@ func getDefaultValue(ctx sessionctx.Context, c *ast.ColumnOption, t *types.Field
return v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx)
}

if tp == mysql.TypeBit {
switch tp {
case mysql.TypeSet:
return setSetDefaultValue(v, col)
case mysql.TypeBit:
if v.Kind() == types.KindInt64 || v.Kind() == types.KindUint64 {
// For BIT fields, convert int into BinaryLiteral.
return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), nil
Expand All @@ -630,6 +633,58 @@ func getDefaultValue(ctx sessionctx.Context, c *ast.ColumnOption, t *types.Field
return v.ToString()
}

// setSetDefaultValue sets the default value for the set type. See https://dev.mysql.com/doc/refman/5.7/en/set.html.
func setSetDefaultValue(v types.Datum, col *table.Column) (string, error) {
if v.Kind() == types.KindInt64 {
setCnt := len(col.Elems)
maxLimit := int64(1<<uint(setCnt) - 1)
val := v.GetInt64()
if val < 1 || val > maxLimit {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
setVal, err := types.ParseSetValue(col.Elems, uint64(val))
if err != nil {
return "", errors.Trace(err)
}
v.SetMysqlSet(setVal)
return v.ToString()
}

str, err := v.ToString()
if err != nil {
return "", errors.Trace(err)
}
if str == "" {
return str, nil
}

valMap := make(map[string]struct{}, len(col.Elems))
dVals := strings.Split(strings.ToLower(str), ",")
for _, dv := range dVals {
valMap[dv] = struct{}{}
}
var existCnt int
for dv := range valMap {
for i := range col.Elems {
e := strings.ToLower(col.Elems[i])
if e == dv {
existCnt++
break
}
}
}
if existCnt != len(valMap) {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
setVal, err := types.ParseSetName(col.Elems, str)
if err != nil {
return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
v.SetMysqlSet(setVal)

return v.ToString()
}

func removeOnUpdateNowFlag(c *table.Column) {
// For timestamp Col, if it is set null or default value,
// OnUpdateNowFlag should be removed.
Expand Down Expand Up @@ -1904,9 +1959,9 @@ func modifiable(origin *types.FieldType, to *types.FieldType) error {

func setDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (bool, error) {
hasDefaultValue := false
value, err := getDefaultValue(ctx, option, &col.FieldType)
value, err := getDefaultValue(ctx, col, option)
if err != nil {
return hasDefaultValue, ErrColumnBadNull.GenWithStack("invalid default value - %s", err)
return hasDefaultValue, errors.Trace(err)
}

if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ func checkColumn(colDef *ast.ColumnDef) error {
if len(tp.Elems) > mysql.MaxTypeSetMembers {
return types.ErrTooBigSet.GenWithStack("Too many strings for column %s and SET", colDef.Name.Name.O)
}
// Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html .
// Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html.
for _, str := range colDef.Tp.Elems {
if strings.Contains(str, ",") {
return types.ErrIllegalValueForType.GenWithStackByArgs(types.TypeStr(tp.Tp), str)
Expand Down

0 comments on commit 586eacd

Please sign in to comment.