diff --git a/expression/builtin_control.go b/expression/builtin_control.go index 54c37c7774037..cfba5f3bb32fd 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -540,12 +540,10 @@ func (b *builtinIfIntSig) evalInt(row chunk.Row) (ret int64, isNull bool, err er if err != nil { return 0, true, err } - arg1, isNull1, err := b.args[1].EvalInt(b.ctx, row) - if (!isNull0 && arg0 != 0) || err != nil { - return arg1, isNull1, err + if !isNull0 && arg0 != 0 { + return b.args[1].EvalInt(b.ctx, row) } - arg2, isNull2, err := b.args[2].EvalInt(b.ctx, row) - return arg2, isNull2, err + return b.args[2].EvalInt(b.ctx, row) } type builtinIfRealSig struct { @@ -563,12 +561,10 @@ func (b *builtinIfRealSig) evalReal(row chunk.Row) (ret float64, isNull bool, er if err != nil { return 0, true, err } - arg1, isNull1, err := b.args[1].EvalReal(b.ctx, row) - if (!isNull0 && arg0 != 0) || err != nil { - return arg1, isNull1, err + if !isNull0 && arg0 != 0 { + return b.args[1].EvalReal(b.ctx, row) } - arg2, isNull2, err := b.args[2].EvalReal(b.ctx, row) - return arg2, isNull2, err + return b.args[2].EvalReal(b.ctx, row) } type builtinIfDecimalSig struct { @@ -586,12 +582,10 @@ func (b *builtinIfDecimalSig) evalDecimal(row chunk.Row) (ret *types.MyDecimal, if err != nil { return nil, true, err } - arg1, isNull1, err := b.args[1].EvalDecimal(b.ctx, row) - if (!isNull0 && arg0 != 0) || err != nil { - return arg1, isNull1, err + if !isNull0 && arg0 != 0 { + return b.args[1].EvalDecimal(b.ctx, row) } - arg2, isNull2, err := b.args[2].EvalDecimal(b.ctx, row) - return arg2, isNull2, err + return b.args[2].EvalDecimal(b.ctx, row) } type builtinIfStringSig struct { @@ -609,12 +603,10 @@ func (b *builtinIfStringSig) evalString(row chunk.Row) (ret string, isNull bool, if err != nil { return "", true, err } - arg1, isNull1, err := b.args[1].EvalString(b.ctx, row) - if (!isNull0 && arg0 != 0) || err != nil { - return arg1, isNull1, err + if !isNull0 && arg0 != 0 { + return b.args[1].EvalString(b.ctx, row) } - arg2, isNull2, err := b.args[2].EvalString(b.ctx, row) - return arg2, isNull2, err + return b.args[2].EvalString(b.ctx, row) } type builtinIfTimeSig struct { @@ -632,12 +624,10 @@ func (b *builtinIfTimeSig) evalTime(row chunk.Row) (ret types.Time, isNull bool, if err != nil { return ret, true, err } - arg1, isNull1, err := b.args[1].EvalTime(b.ctx, row) - if (!isNull0 && arg0 != 0) || err != nil { - return arg1, isNull1, err + if !isNull0 && arg0 != 0 { + return b.args[1].EvalTime(b.ctx, row) } - arg2, isNull2, err := b.args[2].EvalTime(b.ctx, row) - return arg2, isNull2, err + return b.args[2].EvalTime(b.ctx, row) } type builtinIfDurationSig struct { @@ -655,12 +645,10 @@ func (b *builtinIfDurationSig) evalDuration(row chunk.Row) (ret types.Duration, if err != nil { return ret, true, err } - arg1, isNull1, err := b.args[1].EvalDuration(b.ctx, row) - if (!isNull0 && arg0 != 0) || err != nil { - return arg1, isNull1, err + if !isNull0 && arg0 != 0 { + return b.args[1].EvalDuration(b.ctx, row) } - arg2, isNull2, err := b.args[2].EvalDuration(b.ctx, row) - return arg2, isNull2, err + return b.args[2].EvalDuration(b.ctx, row) } type builtinIfJSONSig struct { @@ -678,21 +666,10 @@ func (b *builtinIfJSONSig) evalJSON(row chunk.Row) (ret json.BinaryJSON, isNull if err != nil { return ret, true, err } - arg1, isNull1, err := b.args[1].EvalJSON(b.ctx, row) - if err != nil { - return ret, true, err - } - arg2, isNull2, err := b.args[2].EvalJSON(b.ctx, row) - if err != nil { - return ret, true, err - } - switch { - case isNull0 || arg0 == 0: - ret, isNull = arg2, isNull2 - case arg0 != 0: - ret, isNull = arg1, isNull1 + if !isNull0 && arg0 != 0 { + return b.args[1].EvalJSON(b.ctx, row) } - return + return b.args[2].EvalJSON(b.ctx, row) } type ifNullFunctionClass struct { diff --git a/expression/constant_fold.go b/expression/constant_fold.go index 83065bae9302f..8045838d594ef 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -56,12 +56,8 @@ func ifFoldHandler(expr *ScalarFunction) (Expression, bool) { } return foldConstant(args[2]) } - var isDeferred, isDeferredConst bool - expr.GetArgs()[1], isDeferred = foldConstant(args[1]) - isDeferredConst = isDeferredConst || isDeferred - expr.GetArgs()[2], isDeferred = foldConstant(args[2]) - isDeferredConst = isDeferredConst || isDeferred - return expr, isDeferredConst + // if the condition is not const, which branch is unknown to run, so directly return. + return expr, false } func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) { @@ -76,18 +72,17 @@ func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) { } return constArg, isDeferred } - var isDeferredConst bool - expr.GetArgs()[1], isDeferredConst = foldConstant(args[1]) - return expr, isDeferredConst + // if the condition is not const, which branch is unknown to run, so directly return. + return expr, false } func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { args, l := expr.GetArgs(), len(expr.GetArgs()) - var isDeferred, isDeferredConst, hasNonConstCondition bool + var isDeferred, isDeferredConst bool for i := 0; i < l-1; i += 2 { expr.GetArgs()[i], isDeferred = foldConstant(args[i]) isDeferredConst = isDeferredConst || isDeferred - if _, isConst := expr.GetArgs()[i].(*Constant); isConst && !hasNonConstCondition { + if _, isConst := expr.GetArgs()[i].(*Constant); isConst { // If the condition is const and true, and the previous conditions // has no expr, then the folded execution body is returned, otherwise // the arguments of the casewhen are folded and replaced. @@ -105,20 +100,14 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst } } else { - hasNonConstCondition = true + // for no-const, here should return directly, because the following branches are unknown to be run or not + return expr, false } - expr.GetArgs()[i+1], isDeferred = foldConstant(args[i+1]) - isDeferredConst = isDeferredConst || isDeferred - } - - if l%2 == 0 { - return expr, isDeferredConst } - // If the number of arguments in casewhen is odd, and the previous conditions - // is const and false, then the folded else execution body is returned. otherwise + // is false, then the folded else execution body is returned. otherwise // the execution body of the else are folded and replaced. - if !hasNonConstCondition { + if l%2 == 1 { foldedExpr, isDeferred := foldConstant(args[l-1]) isDeferredConst = isDeferredConst || isDeferred if _, isConst := foldedExpr.(*Constant); isConst { @@ -127,10 +116,6 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { } return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst } - - expr.GetArgs()[l-1], isDeferred = foldConstant(args[l-1]) - isDeferredConst = isDeferredConst || isDeferred - return expr, isDeferredConst } diff --git a/expression/function_traits.go b/expression/function_traits.go index d64b762b752ee..0da50400e9647 100644 --- a/expression/function_traits.go +++ b/expression/function_traits.go @@ -56,6 +56,15 @@ var DisableFoldFunctions = map[string]struct{}{ ast.Benchmark: {}, } +// TryFoldFunctions stores functions which try to fold constant in child scope functions if without errors/warnings, +// otherwise, the child functions do not fold constant. +// Note: the function itself should fold constant. +var TryFoldFunctions = map[string]struct{}{ + ast.If: {}, + ast.Ifnull: {}, + ast.Case: {}, +} + // IllegalFunctions4GeneratedColumns stores functions that is illegal for generated columns. // See https://github.com/mysql/mysql-server/blob/5.7/mysql-test/suite/gcol/inc/gcol_blocked_sql_funcs_main.inc for details var IllegalFunctions4GeneratedColumns = map[string]struct{}{ diff --git a/expression/integration_test.go b/expression/integration_test.go index 20664a1abd034..2d7219bae49c1 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2808,6 +2808,24 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) { tk.MustQuery("select ifnull(b, b/0) from t") tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select case when 1 then 1 else 1/0 end") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery(" select if(1,1,1/0)") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select ifnull(1, 1/0)") + tk.MustQuery("show warnings").Check(testkit.Rows()) + + tk.MustExec("delete from t") + tk.MustExec("insert t values ('str2', 0)") + tk.MustQuery("select case when b < 1 then 1 else 1/0 end from t") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select case when b < 1 then 1 when 1/0 then b else 1/0 end from t") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select if(b < 1 , 1, 1/0) from t") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select ifnull(b, 1/0) from t") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select case 2.0 when 2.0 then 3.0 when 3.0 then 2.0 end").Check(testkit.Rows("3.0")) tk.MustQuery("select case 2.0 when 3.0 then 2.0 when 4.0 then 3.0 else 5.0 end").Check(testkit.Rows("5.0")) tk.MustQuery("select case cast('2011-01-01' as date) when cast('2011-01-01' as date) then cast('2011-02-02' as date) end").Check(testkit.Rows("2011-02-02")) diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 1ae33d80c941b..80cbdbf706d43 100755 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -174,7 +174,9 @@ func typeInferForNull(args []Expression) { } // newFunctionImpl creates a new scalar function or constant. -func newFunctionImpl(ctx sessionctx.Context, fold bool, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { +// fold: 1 means folding constants, while 0 means not, +// -1 means try to fold constants if without errors/warnings, otherwise not. +func newFunctionImpl(ctx sessionctx.Context, fold int, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { if retType == nil { return nil, errors.Errorf("RetType cannot be nil for ScalarFunction.") } @@ -210,20 +212,36 @@ func newFunctionImpl(ctx sessionctx.Context, fold bool, funcName string, retType RetType: retType, Function: f, } - if fold { + if fold == 1 { return FoldConstant(sf), nil + } else if fold == -1 { + // try to fold constants, and return the original function if errors/warnings occur + sc := ctx.GetSessionVars().StmtCtx + beforeWarns := sc.WarningCount() + newSf := FoldConstant(sf) + afterWarns := sc.WarningCount() + if afterWarns > beforeWarns { + sc.TruncateWarnings(int(beforeWarns)) + return sf, nil + } + return newSf, nil } return sf, nil } // NewFunction creates a new scalar function or constant via a constant folding. func NewFunction(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { - return newFunctionImpl(ctx, true, funcName, retType, args...) + return newFunctionImpl(ctx, 1, funcName, retType, args...) } // NewFunctionBase creates a new scalar function with no constant folding. func NewFunctionBase(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { - return newFunctionImpl(ctx, false, funcName, retType, args...) + return newFunctionImpl(ctx, 0, funcName, retType, args...) +} + +// NewFunctionTryFold creates a new scalar function with trying constant folding. +func NewFunctionTryFold(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) { + return newFunctionImpl(ctx, -1, funcName, retType, args...) } // NewFunctionInternal is similar to NewFunction, but do not returns error, should only be used internally. diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 81be3c70351df..0bd6352ad37de 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -154,6 +154,7 @@ func (b *PlanBuilder) getExpressionRewriter(ctx context.Context, p LogicalPlan) rewriter.preprocess = nil rewriter.insertPlan = nil rewriter.disableFoldCounter = 0 + rewriter.tryFoldCounter = 0 rewriter.ctxStack = rewriter.ctxStack[:0] rewriter.ctxNameStk = rewriter.ctxNameStk[:0] rewriter.ctx = ctx @@ -226,6 +227,7 @@ type expressionRewriter struct { // leaving the scope(enable again), the counter will -1. // NOTE: This value can be changed during expression rewritten. disableFoldCounter int + tryFoldCounter int } func (er *expressionRewriter) ctxStackLen() int { @@ -401,6 +403,16 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok { er.disableFoldCounter++ } + if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok { + er.tryFoldCounter++ + } + case *ast.CaseExpr: + if _, ok := expression.DisableFoldFunctions["case"]; ok { + er.disableFoldCounter++ + } + if _, ok := expression.TryFoldFunctions["case"]; ok { + er.tryFoldCounter++ + } case *ast.SetCollationExpr: // Do nothing default: @@ -944,6 +956,9 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok case *ast.VariableExpr: er.rewriteVariable(v) case *ast.FuncCallExpr: + if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok { + er.tryFoldCounter-- + } er.funcCallToExpression(v) if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok { er.disableFoldCounter-- @@ -959,7 +974,13 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok case *ast.BetweenExpr: er.betweenToExpression(v) case *ast.CaseExpr: + if _, ok := expression.TryFoldFunctions["case"]; ok { + er.tryFoldCounter-- + } er.caseToExpression(v) + if _, ok := expression.DisableFoldFunctions["case"]; ok { + er.disableFoldCounter-- + } case *ast.FuncCastExpr: arg := er.ctxStack[len(er.ctxStack)-1] er.err = expression.CheckArgsNotMultiColumnRow(arg) @@ -1053,6 +1074,9 @@ func (er *expressionRewriter) newFunction(funcName string, retType *types.FieldT if er.disableFoldCounter > 0 { return expression.NewFunctionBase(er.sctx, funcName, retType, args...) } + if er.tryFoldCounter > 0 { + return expression.NewFunctionTryFold(er.sctx, funcName, retType, args...) + } return expression.NewFunction(er.sctx, funcName, retType, args...) }