From 4947cf1f596d82e6e987cc3895da0b352c7da1a6 Mon Sep 17 00:00:00 2001 From: sylzd Date: Thu, 2 Dec 2021 13:11:53 +0800 Subject: [PATCH] expression: fix wrong result of greatest/least(mixed unsigned/signed int) (#30121) --- expression/builtin_compare.go | 28 ++++++++++++++++++++++++++++ expression/builtin_compare_test.go | 10 ++++++++++ expression/integration_test.go | 9 +++++++++ expression/typeinfer_test.go | 7 +++++++ 4 files changed, 54 insertions(+) diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index e4edf23ab719f..609ee26dc09e1 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -478,6 +478,14 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre } switch tp { case types.ETInt: + // adjust unsigned flag + greastInitUnsignedFlag := false + if isEqualsInitUnsignedFlag(greastInitUnsignedFlag, args) { + bf.tp.Flag &= ^mysql.UnsignedFlag + } else { + bf.tp.Flag |= mysql.UnsignedFlag + } + sig = &builtinGreatestIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_GreatestInt) case types.ETReal: @@ -745,6 +753,14 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi } switch tp { case types.ETInt: + // adjust unsigned flag + leastInitUnsignedFlag := true + if isEqualsInitUnsignedFlag(leastInitUnsignedFlag, args) { + bf.tp.Flag |= mysql.UnsignedFlag + } else { + bf.tp.Flag &= ^mysql.UnsignedFlag + } + sig = &builtinLeastIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_LeastInt) case types.ETReal: @@ -2880,3 +2896,15 @@ func CompareJSON(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhs } return int64(json.CompareBinary(arg0, arg1)), false, nil } + +// isEqualsInitUnsignedFlag can adjust unsigned flag for greatest/least function. +// For greatest, returns unsigned result if there is at least one argument is unsigned. +// For least, returns signed result if there is at least one argument is signed. +func isEqualsInitUnsignedFlag(initUnsigned bool, args []Expression) bool { + for _, arg := range args { + if initUnsigned != mysql.HasUnsignedFlag(arg.GetType().Flag) { + return false + } + } + return true +} diff --git a/expression/builtin_compare_test.go b/expression/builtin_compare_test.go index c2e8ecc9fd64e..c16f484c498dd 100644 --- a/expression/builtin_compare_test.go +++ b/expression/builtin_compare_test.go @@ -279,6 +279,8 @@ func TestGreatestLeastFunc(t *testing.T) { sc := ctx.GetSessionVars().StmtCtx originIgnoreTruncate := sc.IgnoreTruncate sc.IgnoreTruncate = true + decG := &types.MyDecimal{} + decL := &types.MyDecimal{} defer func() { sc.IgnoreTruncate = originIgnoreTruncate }() @@ -290,6 +292,14 @@ func TestGreatestLeastFunc(t *testing.T) { isNil bool getErr bool }{ + { + []interface{}{int64(-9223372036854775808), uint64(9223372036854775809)}, + decG.FromUint(9223372036854775809), decL.FromInt(-9223372036854775808), false, false, + }, + { + []interface{}{uint64(9223372036854775808), uint64(9223372036854775809)}, + uint64(9223372036854775809), uint64(9223372036854775808), false, false, + }, { []interface{}{1, 2, 3, 4}, int64(4), int64(1), false, false, diff --git a/expression/integration_test.go b/expression/integration_test.go index 29dbc7e91bcc9..d3c561e3340b4 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -10648,3 +10648,12 @@ func (s *testIntegrationSuite) TestIssue29513(c *C) { tk.MustQuery("select '123' union select cast(a as char) from t;").Sort().Check(testkit.Rows("123", "45678")) tk.MustQuery("select '123' union select cast(a as char(2)) from t;").Sort().Check(testkit.Rows("123", "45")) } + +func (s *testIntegrationSuite) TestIssue30101(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 bigint unsigned, c2 bigint unsigned);") + tk.MustExec("insert into t1 values(9223372036854775808, 9223372036854775809);") + tk.MustQuery("select greatest(c1, c2) from t1;").Sort().Check(testkit.Rows("9223372036854775809")) +} diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 9c526712419ec..4c91489da289e 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -1057,6 +1057,13 @@ func (s *InferTypeSuite) createTestCase4CompareFuncs() []typeInferTestCase { {"interval(c_int_d, c_int_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"interval(c_int_d, c_float_d, c_double_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + + {"greatest(c_bigint_d, c_ubigint_d, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"greatest(c_ubigint_d, c_ubigint_d, c_uint_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, + {"greatest(c_uint_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 11, 0}, + {"least(c_bigint_d, c_ubigint_d, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"least(c_ubigint_d, c_ubigint_d, c_uint_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, + {"least(c_uint_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0}, } }