From 586eacd62f3b563e44ac9d337a38738335d48286 Mon Sep 17 00:00:00 2001 From: Lynn Date: Thu, 26 Sep 2019 12:18:20 +0800 Subject: [PATCH] ddl: fix the set's default value where `create table` (#12267) --- ddl/db_test.go | 55 +++++++++++++++++++++++++++++++- ddl/ddl.go | 3 ++ ddl/ddl_api.go | 65 +++++++++++++++++++++++++++++++++++--- planner/core/preprocess.go | 2 +- 4 files changed, 118 insertions(+), 7 deletions(-) diff --git a/ddl/db_test.go b/ddl/db_test.go index d997dd49c52a4..90c49d6c163a5 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -125,7 +125,7 @@ func assertErrorCode(c *C, tk *testkit.TestKit, sql string, errCode int) { originErr := errors.Cause(err) tErr, ok := originErr.(*terror.Error) c.Assert(ok, IsTrue, Commentf("err: %T", originErr)) - c.Assert(tErr.ToSQLError().Code, DeepEquals, uint16(errCode), Commentf("MySQL code:%v", tErr.ToSQLError())) + c.Assert(tErr.ToSQLError().Code, DeepEquals, uint16(errCode), Commentf("MySQL code:%v, err %v", tErr.ToSQLError().Code, tErr.ToSQLError())) } func (s *testDBSuite) testErrorCode(c *C, sql string, errCode int) { @@ -2149,6 +2149,59 @@ func (s *testDBSuite) TestCreateTable(c *C) { c.Assert(err, NotNil) } +func (s *testDBSuite) TestCreateTableWithSetCol(c *C) { + s.tk = testkit.NewTestKitWithInit(c, s.store) + s.tk.MustExec("create table t_set (a int, b set('e') default '');") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` int(11) DEFAULT NULL,\n" + + " `b` set('e') DEFAULT ''\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('a', 'b', 'c', 'd') default 'a,C,c');") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('a','b','c','d') DEFAULT 'a,c'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + // It's for failure cases. + // The type of default value is string. + s.tk.MustExec("drop table t_set") + failedSQL := "create table t_set (a set('1', '4', '10') default '3');" + assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault) + failedSQL = "create table t_set (a set('1', '4', '10') default '1,4,11');" + assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault) + failedSQL = "create table t_set (a set('1', '4', '10') default '1 ,4');" + assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault) + // The type of default value is int. + failedSQL = "create table t_set (a set('1', '4', '10') default 0);" + assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault) + failedSQL = "create table t_set (a set('1', '4', '10') default 8);" + assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault) + + // The type of default value is int. + // It's for successful cases + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 1);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '1'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 2);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '4'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 3);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '1,4'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 15);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '1,4,10,21'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("insert into t_set value()") + s.tk.MustQuery("select * from t_set").Check(testkit.Rows("1,4,10,21")) +} + func (s *testDBSuite) TestTableForeignKey(c *C) { s.tk = testkit.NewTestKit(c, s.store) s.tk.MustExec("use test") diff --git a/ddl/ddl.go b/ddl/ddl.go index db702067aa5a1..8af5e6c0b1840 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -176,6 +176,8 @@ var ( ErrConflictingDeclarations = terror.ClassDDL.New(codeConflictingDeclarations, "Conflicting declarations: 'CHARACTER SET %s' and 'CHARACTER SET %s'") // ErrPrimaryCantHaveNull returns All parts of a PRIMARY KEY must be NOT NULL; if you need NULL in a key, use UNIQUE instead ErrPrimaryCantHaveNull = terror.ClassDDL.New(codePrimaryCantHaveNull, mysql.MySQLErrName[mysql.ErrPrimaryCantHaveNull]) + // ErrInvalidDefaultValue returns for invalid default value for columns. + ErrInvalidDefaultValue = terror.ClassDDL.New(codeInvalidDefaultValue, mysql.MySQLErrName[mysql.ErrInvalidDefault]) // ErrNotAllowedTypeInPartition returns not allowed type error when creating table partiton with unsupport expression type. ErrNotAllowedTypeInPartition = terror.ClassDDL.New(codeErrFieldTypeNotAllowedAsPartitionField, mysql.MySQLErrName[mysql.ErrFieldTypeNotAllowedAsPartitionField]) @@ -711,6 +713,7 @@ func init() { codeUnknownCollation: mysql.ErrUnknownCollation, codeCollationCharsetMismatch: mysql.ErrCollationCharsetMismatch, codeConflictingDeclarations: mysql.ErrConflictingDeclarations, + codeInvalidDefaultValue: mysql.ErrInvalidDefault, } terror.ErrClassToMySQLCodes[terror.ClassDDL] = ddlMySQLErrCodes } diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 6c55317e14b44..cede8e7b42703 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -577,8 +577,8 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o return col, constraints, nil } -func getDefaultValue(ctx sessionctx.Context, c *ast.ColumnOption, t *types.FieldType) (interface{}, error) { - tp, fsp := t.Tp, t.Decimal +func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOption) (interface{}, error) { + tp, fsp := col.FieldType.Tp, col.FieldType.Decimal if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime { vd, err := expression.GetTimeValue(ctx, c.Expr, tp, fsp) value := vd.GetValue() @@ -620,7 +620,10 @@ func getDefaultValue(ctx sessionctx.Context, c *ast.ColumnOption, t *types.Field return v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx) } - if tp == mysql.TypeBit { + switch tp { + case mysql.TypeSet: + return setSetDefaultValue(v, col) + case mysql.TypeBit: if v.Kind() == types.KindInt64 || v.Kind() == types.KindUint64 { // For BIT fields, convert int into BinaryLiteral. return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), nil @@ -630,6 +633,58 @@ func getDefaultValue(ctx sessionctx.Context, c *ast.ColumnOption, t *types.Field return v.ToString() } +// setSetDefaultValue sets the default value for the set type. See https://dev.mysql.com/doc/refman/5.7/en/set.html. +func setSetDefaultValue(v types.Datum, col *table.Column) (string, error) { + if v.Kind() == types.KindInt64 { + setCnt := len(col.Elems) + maxLimit := int64(1< maxLimit { + return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + setVal, err := types.ParseSetValue(col.Elems, uint64(val)) + if err != nil { + return "", errors.Trace(err) + } + v.SetMysqlSet(setVal) + return v.ToString() + } + + str, err := v.ToString() + if err != nil { + return "", errors.Trace(err) + } + if str == "" { + return str, nil + } + + valMap := make(map[string]struct{}, len(col.Elems)) + dVals := strings.Split(strings.ToLower(str), ",") + for _, dv := range dVals { + valMap[dv] = struct{}{} + } + var existCnt int + for dv := range valMap { + for i := range col.Elems { + e := strings.ToLower(col.Elems[i]) + if e == dv { + existCnt++ + break + } + } + } + if existCnt != len(valMap) { + return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + setVal, err := types.ParseSetName(col.Elems, str) + if err != nil { + return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + v.SetMysqlSet(setVal) + + return v.ToString() +} + func removeOnUpdateNowFlag(c *table.Column) { // For timestamp Col, if it is set null or default value, // OnUpdateNowFlag should be removed. @@ -1904,9 +1959,9 @@ func modifiable(origin *types.FieldType, to *types.FieldType) error { func setDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (bool, error) { hasDefaultValue := false - value, err := getDefaultValue(ctx, option, &col.FieldType) + value, err := getDefaultValue(ctx, col, option) if err != nil { - return hasDefaultValue, ErrColumnBadNull.GenWithStack("invalid default value - %s", err) + return hasDefaultValue, errors.Trace(err) } if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil { diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index d157550d36ebe..370ae2be041cc 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -499,7 +499,7 @@ func checkColumn(colDef *ast.ColumnDef) error { if len(tp.Elems) > mysql.MaxTypeSetMembers { return types.ErrTooBigSet.GenWithStack("Too many strings for column %s and SET", colDef.Name.Name.O) } - // Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html . + // Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html. for _, str := range colDef.Tp.Elems { if strings.Contains(str, ",") { return types.ErrIllegalValueForType.GenWithStackByArgs(types.TypeStr(tp.Tp), str)