From 2efc9c48a7379010389558829cd847acf1581e1d Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 10 Sep 2021 12:42:39 +0800 Subject: [PATCH] planner: fix expression rewrite makes between expr infers wrong collation. (#27254) (#27851) --- expression/builtin.go | 9 ++++---- planner/core/expression_rewriter.go | 22 +++++++++++++----- planner/core/expression_rewriter_test.go | 29 ++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/expression/builtin.go b/expression/builtin.go index 0b75d57caea78..fb441a06dedc0 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -90,7 +90,7 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expressi if ctx == nil { return baseBuiltinFunc{}, errors.New("unexpected nil session ctx") } - if err := checkIllegalMixCollation(funcName, args, retType); err != nil { + if err := CheckIllegalMixCollation(funcName, args, retType); err != nil { return baseBuiltinFunc{}, err } derivedCharset, derivedCollate := DeriveCollationFromExprs(ctx, args...) @@ -113,7 +113,8 @@ var ( coerString = []string{"EXPLICIT", "NONE", "IMPLICIT", "SYSCONST", "COERCIBLE", "NUMERIC", "IGNORABLE"} ) -func checkIllegalMixCollation(funcName string, args []Expression, evalType types.EvalType) error { +// CheckIllegalMixCollation check the if the aggregate expression is legal. +func CheckIllegalMixCollation(funcName string, args []Expression, evalType types.EvalType) error { if len(args) < 2 { return nil } @@ -132,7 +133,7 @@ func illegalMixCollationErr(funcName string, args []Expression) error { case 2: return collate.ErrIllegalMix2Collation.GenWithStackByArgs(args[0].GetType().Collate, coerString[args[0].Coercibility()], args[1].GetType().Collate, coerString[args[1].Coercibility()], funcName) case 3: - return collate.ErrIllegalMix3Collation.GenWithStackByArgs(args[0].GetType().Collate, coerString[args[0].Coercibility()], args[1].GetType().Collate, coerString[args[1].Coercibility()], args[0].GetType().Collate, coerString[args[2].Coercibility()], funcName) + return collate.ErrIllegalMix3Collation.GenWithStackByArgs(args[0].GetType().Collate, coerString[args[0].Coercibility()], args[1].GetType().Collate, coerString[args[1].Coercibility()], args[2].GetType().Collate, coerString[args[2].Coercibility()], funcName) default: return collate.ErrIllegalMixCollation.GenWithStackByArgs(funcName) } @@ -170,7 +171,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex } } - if err = checkIllegalMixCollation(funcName, args, retType); err != nil { + if err = CheckIllegalMixCollation(funcName, args, retType); err != nil { return } diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index a481e4864c780..55dbb3f3305d3 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1577,17 +1577,27 @@ func (er *expressionRewriter) betweenToExpression(v *ast.BetweenExpr) { expr, lexp, rexp := er.wrapExpWithCast() - var op string + er.err = expression.CheckIllegalMixCollation("between", []expression.Expression{expr, lexp, rexp}, types.ETInt) + if er.err != nil { + return + } + + dstCharset, dstCollation := expression.DeriveCollationFromExprs(er.sctx, expr, lexp, rexp) + var l, r expression.Expression - l, er.err = er.newFunction(ast.GE, &v.Type, expr, lexp) - if er.err == nil { - r, er.err = er.newFunction(ast.LE, &v.Type, expr, rexp) + l, er.err = expression.NewFunctionBase(er.sctx, ast.GE, &v.Type, expr, lexp) + if er.err != nil { + return } - op = ast.LogicAnd + r, er.err = expression.NewFunctionBase(er.sctx, ast.LE, &v.Type, expr, rexp) if er.err != nil { return } - function, err := er.newFunction(op, &v.Type, l, r) + l.SetCharsetAndCollation(dstCharset, dstCollation) + r.SetCharsetAndCollation(dstCharset, dstCollation) + l = expression.FoldConstant(l) + r = expression.FoldConstant(r) + function, err := er.newFunction(ast.LogicAnd, &v.Type, l, r) if err != nil { er.err = err return diff --git a/planner/core/expression_rewriter_test.go b/planner/core/expression_rewriter_test.go index 66a038342305b..0ad599e80a092 100644 --- a/planner/core/expression_rewriter_test.go +++ b/planner/core/expression_rewriter_test.go @@ -17,16 +17,21 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" "github.com/pingcap/tidb/util/testutil" ) var _ = Suite(&testExpressionRewriterSuite{}) +var _ = SerialSuites(&testExpressionRewriterSuiteSerial{}) type testExpressionRewriterSuite struct { } +type testExpressionRewriterSuiteSerial struct { +} + func (s *testExpressionRewriterSuite) TestIfNullEliminateColName(c *C) { defer testleak.AfterTest(c)() store, dom, err := newStoreWithBootstrap() @@ -405,3 +410,27 @@ func (s *testExpressionRewriterSuite) TestIssue22818(c *C) { tk.MustQuery("select * from t where a between \"23:22:22\" and \"23:22:22\"").Check( testkit.Rows("23:22:22")) } + +func (s *testExpressionRewriterSuiteSerial) TestBetweenExprCollation(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + + tk.MustExec("use test") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(a char(10) charset latin1 collate latin1_bin, c char(10) collate utf8mb4_general_ci);") + tk.MustExec("insert into t1 values ('a', 'B');") + tk.MustExec("insert into t1 values ('c', 'D');") + tk.MustQuery("select * from t1 where a between 'B' and c;").Check(testkit.Rows("c D")) + tk.MustQuery("explain select * from t1 where 'a' between 'g' and 'f';").Check(testkit.Rows("TableDual_6 0.00 root rows:0")) + + tk.MustGetErrMsg("select * from t1 where a between 'B' collate utf8mb4_general_ci and c collate utf8mb4_unicode_ci;", "[expression:1270]Illegal mix of collations (latin1_bin,IMPLICIT), (utf8mb4_general_ci,EXPLICIT), (utf8mb4_unicode_ci,EXPLICIT) for operation 'between'") +}