Skip to content

Commit

Permalink
*: support default expression value for sequence (#14589)
Browse files Browse the repository at this point in the history
  • Loading branch information
AilinKid authored Feb 13, 2020
1 parent 114405e commit 007c0e6
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 23 deletions.
2 changes: 2 additions & 0 deletions ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ var (
ErrUnknownSequence = terror.ClassDDL.New(mysql.ErrUnknownSequence, mysql.MySQLErrName[mysql.ErrUnknownSequence])
// ErrSequenceUnsupportedTableOption returns when unsupported table option exists in sequence.
ErrSequenceUnsupportedTableOption = terror.ClassDDL.New(mysql.ErrSequenceUnsupportedTableOption, mysql.MySQLErrName[mysql.ErrSequenceUnsupportedTableOption])
// ErrColumnTypeUnsupportedNextValue is returned when sequence next value is assigned to unsupported column type.
ErrColumnTypeUnsupportedNextValue = terror.ClassDDL.New(mysql.ErrColumnTypeUnsupportedNextValue, mysql.MySQLErrName[mysql.ErrColumnTypeUnsupportedNextValue])
)

// DDL is responsible for updating schema in data store and maintaining in-memory InfoSchema cache.
Expand Down
76 changes: 59 additions & 17 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,13 @@ func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value in
return hasDefaultValue, value, nil
}

func checkSequenceDefaultValue(col *table.Column) error {
if mysql.IsIntegerType(col.Tp) {
return nil
}
return ErrColumnTypeUnsupportedNextValue.GenWithStackByArgs(col.ColumnInfo.Name.O)
}

func convertTimestampDefaultValToUTC(ctx sessionctx.Context, defaultVal interface{}, col *table.Column) (interface{}, error) {
if defaultVal == nil || col.Tp != mysql.TypeTimestamp {
return defaultVal, nil
Expand Down Expand Up @@ -638,7 +645,10 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o
return col, constraints, nil
}

func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOption) (interface{}, error) {
// getDefault value will get the default value for column.
// 1: get the expr restored string for the column which uses sequence next value as default value.
// 2: get specific default value for the other column.
func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOption) (interface{}, bool, error) {
tp, fsp := col.FieldType.Tp, col.FieldType.Decimal
if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime {
switch x := c.Expr.(type) {
Expand All @@ -651,35 +661,45 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOpt
}
}
if defaultFsp != fsp {
return nil, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
return nil, false, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
}
}
vd, err := expression.GetTimeValue(ctx, c.Expr, tp, int8(fsp))
value := vd.GetValue()
if err != nil {
return nil, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
return nil, false, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}

// Value is nil means `default null`.
if value == nil {
return nil, nil
return nil, false, nil
}

// If value is types.Time, convert it to string.
if vv, ok := value.(types.Time); ok {
return vv.String(), nil
return vv.String(), false, nil
}

return value, nil
return value, false, nil
}
// handle default next value of sequence. (keep the expr string)
str, isSeqExpr, err := tryToGetSequenceDefaultValue(c)
if err != nil {
return nil, false, errors.Trace(err)
}
if isSeqExpr {
return str, true, nil
}

// evaluate the non-sequence expr to a certain value.
v, err := expression.EvalAstExpr(ctx, c.Expr)
if err != nil {
return nil, errors.Trace(err)
return nil, false, errors.Trace(err)
}

if v.IsNull() {
return nil, nil
return nil, false, nil
}

if v.Kind() == types.KindBinaryLiteral || v.Kind() == types.KindMysqlBit {
Expand All @@ -689,31 +709,47 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOpt
tp == mysql.TypeJSON {
// For BinaryLiteral / string fields, when getting default value we cast the value into BinaryLiteral{}, thus we return
// its raw string content here.
return v.GetBinaryLiteral().ToString(), nil
return v.GetBinaryLiteral().ToString(), false, nil
}
// For other kind of fields (e.g. INT), we supply its integer as string value.
value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx)
if err != nil {
return nil, err
return nil, false, err
}
return strconv.FormatUint(value, 10), nil
return strconv.FormatUint(value, 10), false, nil
}

switch tp {
case mysql.TypeSet:
return setSetDefaultValue(v, col)
val, err := setSetDefaultValue(v, col)
return val, false, err
case mysql.TypeDuration:
if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType); err != nil {
return "", errors.Trace(err)
return "", false, errors.Trace(err)
}
case mysql.TypeBit:
if v.Kind() == types.KindInt64 || v.Kind() == types.KindUint64 {
// For BIT fields, convert int into BinaryLiteral.
return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), nil
return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), false, nil
}
}

return v.ToString()
val, err := v.ToString()
return val, false, err
}

func tryToGetSequenceDefaultValue(c *ast.ColumnOption) (expr string, isExpr bool, err error) {
if f, ok := c.Expr.(*ast.FuncCallExpr); ok && f.FnName.L == ast.NextVal {
var sb strings.Builder
restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes |
format.RestoreSpacesAroundBinaryOperation
restoreCtx := format.NewRestoreCtx(restoreFlags, &sb)
if err := c.Expr.Restore(restoreCtx); err != nil {
return "", true, err
}
return sb.String(), true, nil
}
return "", false, nil
}

// setSetDefaultValue sets the default value for the set type. See https://dev.mysql.com/doc/refman/5.7/en/set.html.
Expand Down Expand Up @@ -816,7 +852,7 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue
return nil
}

if c.GetDefaultValue() != nil {
if c.GetDefaultValue() != nil && !c.DefaultIsExpr {
if _, err := table.GetColDefaultValue(ctx, c.ToInfo()); err != nil {
return types.ErrInvalidDefault.GenWithStackByArgs(c.Name)
}
Expand Down Expand Up @@ -2602,10 +2638,16 @@ func modifiable(origin *types.FieldType, to *types.FieldType) error {

func setDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (bool, error) {
hasDefaultValue := false
value, err := getDefaultValue(ctx, col, option)
value, isSeqExpr, err := getDefaultValue(ctx, col, option)
if err != nil {
return hasDefaultValue, errors.Trace(err)
}
if isSeqExpr {
if err := checkSequenceDefaultValue(col); err != nil {
return false, errors.Trace(err)
}
col.DefaultIsExpr = isSeqExpr
}

if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil {
return hasDefaultValue, errors.Trace(err)
Expand Down
28 changes: 28 additions & 0 deletions ddl/sequence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ func (s *testSequenceSuite) TestDropSequence(c *C) {
func (s *testSequenceSuite) TestShowCreateSequence(c *C) {
s.tk = testkit.NewTestKit(c, s.store)
s.tk.MustExec("use test")
s.tk.MustExec("drop table if exists t")
s.tk.MustExec("drop sequence if exists seq")
s.tk.MustExec("create table t(a int)")
s.tk.MustExec("create sequence seq")

Expand Down Expand Up @@ -235,3 +237,29 @@ func (s *testSequenceSuite) TestShowCreateSequence(c *C) {
s.tk.MustExec("drop sequence if exists seq")
s.tk.MustExec(showString)
}

func (s *testSequenceSuite) TestSequenceAsDefaultValue(c *C) {
s.tk = testkit.NewTestKit(c, s.store)
s.tk.MustExec("use test")
s.tk.MustExec("create sequence seq")

// test the use sequence's nextval as default.
s.tk.MustExec("create table t1 (a int default next value for seq)")
s.tk.MustGetErrMsg("create table t2 (a char(1) default next value for seq)", "[ddl:8228]Unsupported sequence default value for column type 'a'")

s.tk.MustExec("create table t3 (a int default nextval(seq))")

s.tk.MustExec("create table t4 (a int)")
s.tk.MustExec("alter table t4 alter column a set default (next value for seq)")
s.tk.MustExec("alter table t4 alter column a set default (nextval(seq))")

s.tk.MustExec("create table t5 (a char(1))")
s.tk.MustGetErrMsg("alter table t5 alter column a set default (next value for seq)", "[ddl:8228]Unsupported sequence default value for column type 'a'")

s.tk.MustGetErrMsg("alter table t5 alter column a set default (nextval(seq))", "[ddl:8228]Unsupported sequence default value for column type 'a'")

s.tk.MustGetErrMsg("alter table t5 add column b char(1) default next value for seq", "[ddl:8228]Unsupported sequence default value for column type 'b'")

s.tk.MustExec("alter table t5 add column b int default nextval(seq)")

}
5 changes: 2 additions & 3 deletions executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,15 +507,14 @@ func (e *InsertValues) getRowInPlace(ctx context.Context, vals []types.Datum, ro

// getColDefaultValue gets the column default value.
func (e *InsertValues) getColDefaultValue(idx int, col *table.Column) (d types.Datum, err error) {
if e.colDefaultVals != nil && e.colDefaultVals[idx].valid {
if !col.DefaultIsExpr && e.colDefaultVals != nil && e.colDefaultVals[idx].valid {
return e.colDefaultVals[idx].val, nil
}

defaultVal, err := table.GetColDefaultValue(e.ctx, col.ToInfo())
if err != nil {
return types.Datum{}, err
}
if initialized := e.lazilyInitColDefaultValBuf(); initialized {
if initialized := e.lazilyInitColDefaultValBuf(); initialized && !col.DefaultIsExpr {
e.colDefaultVals[idx].val = defaultVal
e.colDefaultVals[idx].valid = true
}
Expand Down
1 change: 0 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ github.com/pingcap/log v0.0.0-20200117041106-d28c14d3b1cd h1:CV3VsP3Z02MVtdpTMfE
github.com/pingcap/log v0.0.0-20200117041106-d28c14d3b1cd/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8=
github.com/pingcap/parser v0.0.0-20200212063918-0829643f461c h1:QbFj6Ng/PvHeQNN7aPWpulXIzoo+j/J8odEM7ERUt7g=
github.com/pingcap/parser v0.0.0-20200212063918-0829643f461c/go.mod h1:9v0Edh8IbgjGYW2ArJr19E+bvL8zKahsFp+ixWeId+4=

github.com/pingcap/pd v1.1.0-beta.0.20200106144140-f5a7aa985497 h1:FzLErYtcXnSxtC469OuVDlgBbh0trJZzNxw0mNKzyls=
github.com/pingcap/pd v1.1.0-beta.0.20200106144140-f5a7aa985497/go.mod h1:cfT/xu4Zz+Tkq95QrLgEBZ9ikRcgzy4alHqqoaTftqI=
github.com/pingcap/sysutil v0.0.0-20191216090214-5f9620d22b3b h1:EEyo/SCRswLGuSk+7SB86Ak1p8bS6HL1Mi4Dhyuv6zg=
Expand Down
18 changes: 17 additions & 1 deletion planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
}
switch v := inNode.(type) {
case *ast.AggregateFuncExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause,
*ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr, *ast.WindowFuncExpr:
*ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr, *ast.WindowFuncExpr, *ast.TableNameExpr:
case *driver.ValueExpr:
value := &expression.Constant{Value: v.Datum, RetType: &v.Type}
er.ctxStackAppend(value, types.EmptyName)
Expand All @@ -921,6 +921,8 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
er.disableFoldCounter--
}
case *ast.TableName:
er.toTable(v)
case *ast.ColumnName:
er.toColumn(v)
case *ast.UnaryOperationExpr:
Expand Down Expand Up @@ -1491,6 +1493,20 @@ func (er *expressionRewriter) funcCallToExpression(v *ast.FuncCallExpr) {
}
}

// Now TableName in expression only used by sequence function like nextval(seq).
// The function arg should be evaluated as a table name rather than normal column name like mysql does.
func (er *expressionRewriter) toTable(v *ast.TableName) {
fullName := v.Name.L
if len(v.Schema.L) != 0 {
fullName = v.Schema.L + "." + fullName
}
val := &expression.Constant{
Value: types.NewDatum(fullName),
RetType: types.NewFieldType(mysql.TypeString),
}
er.ctxStackAppend(val, types.EmptyName)
}

func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
idx, err := expression.FindFieldName(er.names, v)
if err != nil {
Expand Down
26 changes: 25 additions & 1 deletion table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"time"
"unicode/utf8"

"github.com/pingcap/parser"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/pingcap/parser/model"
Expand Down Expand Up @@ -371,7 +372,30 @@ func GetColOriginDefaultValue(ctx sessionctx.Context, col *model.ColumnInfo) (ty

// GetColDefaultValue gets default value of the column.
func GetColDefaultValue(ctx sessionctx.Context, col *model.ColumnInfo) (types.Datum, error) {
return getColDefaultValue(ctx, col, col.GetDefaultValue())
defaultValue := col.GetDefaultValue()
if !col.DefaultIsExpr {
return getColDefaultValue(ctx, col, defaultValue)
}
return getColDefaultExprValue(ctx, col, defaultValue.(string))
}

func getColDefaultExprValue(ctx sessionctx.Context, col *model.ColumnInfo, defaultValue string) (types.Datum, error) {
var defaultExpr ast.ExprNode
expr := fmt.Sprintf("select %s", defaultValue)
stmts, _, err := parser.New().Parse(expr, "", "")
if err == nil {
defaultExpr = stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr
}
d, err := expression.EvalAstExpr(ctx, defaultExpr)
if err != nil {
return types.Datum{}, err
}
// Check the evaluated data type by cast.
value, err := CastValue(ctx, types.NewDatum(d), col)
if err != nil {
return types.Datum{}, err
}
return value, nil
}

func getColDefaultValue(ctx sessionctx.Context, col *model.ColumnInfo, defaultVal interface{}) (types.Datum, error) {
Expand Down

0 comments on commit 007c0e6

Please sign in to comment.