Skip to content

Commit 93f8151

Browse files
SunRunAwayeurekaka
authored andcommitted
expression: fix incorrect result of logical operators (#12173) (#12811)
1 parent b11578f commit 93f8151

File tree

5 files changed

+181
-20
lines changed

5 files changed

+181
-20
lines changed

expression/builtin.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -540,8 +540,8 @@ var funcs = map[string]functionClass{
540540
ast.Xor: &bitXorFunctionClass{baseFunctionClass{ast.Xor, 2, 2}},
541541
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
542542
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 2, -1}},
543-
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
544-
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
543+
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, false},
544+
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity, false},
545545
ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 3, 3}},
546546
ast.Regexp: &regexpFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}},
547547
ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}},

expression/builtin_op.go

+61-12
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"fmt"
1818
"math"
1919

20+
"github.com/pingcap/errors"
2021
"github.com/pingcap/parser/mysql"
2122
"github.com/pingcap/parser/opcode"
2223
"github.com/pingcap/tidb/sessionctx"
@@ -64,6 +65,15 @@ func (c *logicAndFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
6465
if err != nil {
6566
return nil, err
6667
}
68+
args[0], err = wrapWithIsTrue(ctx, true, args[0])
69+
if err != nil {
70+
return nil, errors.Trace(err)
71+
}
72+
args[1], err = wrapWithIsTrue(ctx, true, args[1])
73+
if err != nil {
74+
return nil, errors.Trace(err)
75+
}
76+
6777
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
6878
sig := &builtinLogicAndSig{bf}
6979
sig.setPbCode(tipb.ScalarFuncSig_LogicalAnd)
@@ -105,6 +115,15 @@ func (c *logicOrFunctionClass) getFunction(ctx sessionctx.Context, args []Expres
105115
if err != nil {
106116
return nil, err
107117
}
118+
args[0], err = wrapWithIsTrue(ctx, true, args[0])
119+
if err != nil {
120+
return nil, errors.Trace(err)
121+
}
122+
args[1], err = wrapWithIsTrue(ctx, true, args[1])
123+
if err != nil {
124+
return nil, errors.Trace(err)
125+
}
126+
108127
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
109128
bf.tp.Flen = 1
110129
sig := &builtinLogicOrSig{bf}
@@ -152,6 +171,7 @@ func (c *logicXorFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
152171
if err != nil {
153172
return nil, err
154173
}
174+
155175
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
156176
sig := &builtinLogicXorSig{bf}
157177
sig.setPbCode(tipb.ScalarFuncSig_LogicalXor)
@@ -375,6 +395,11 @@ func (b *builtinRightShiftSig) evalInt(row chunk.Row) (int64, bool, error) {
375395
type isTrueOrFalseFunctionClass struct {
376396
baseFunctionClass
377397
op opcode.Op
398+
399+
// keepNull indicates how this function treats a null input parameter.
400+
// If keepNull is true and the input parameter is null, the function will return null.
401+
// If keepNull is false, the null input parameter will be cast to 0.
402+
keepNull bool
378403
}
379404

380405
func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
@@ -395,25 +420,25 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
395420
case opcode.IsTruth:
396421
switch argTp {
397422
case types.ETReal:
398-
sig = &builtinRealIsTrueSig{bf}
423+
sig = &builtinRealIsTrueSig{bf, c.keepNull}
399424
sig.setPbCode(tipb.ScalarFuncSig_RealIsTrue)
400425
case types.ETDecimal:
401-
sig = &builtinDecimalIsTrueSig{bf}
426+
sig = &builtinDecimalIsTrueSig{bf, c.keepNull}
402427
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsTrue)
403428
case types.ETInt:
404-
sig = &builtinIntIsTrueSig{bf}
429+
sig = &builtinIntIsTrueSig{bf, c.keepNull}
405430
sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue)
406431
}
407432
case opcode.IsFalsity:
408433
switch argTp {
409434
case types.ETReal:
410-
sig = &builtinRealIsFalseSig{bf}
435+
sig = &builtinRealIsFalseSig{bf, c.keepNull}
411436
sig.setPbCode(tipb.ScalarFuncSig_RealIsFalse)
412437
case types.ETDecimal:
413-
sig = &builtinDecimalIsFalseSig{bf}
438+
sig = &builtinDecimalIsFalseSig{bf, c.keepNull}
414439
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsFalse)
415440
case types.ETInt:
416-
sig = &builtinIntIsFalseSig{bf}
441+
sig = &builtinIntIsFalseSig{bf, c.keepNull}
417442
sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse)
418443
}
419444
}
@@ -422,10 +447,11 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
422447

423448
type builtinRealIsTrueSig struct {
424449
baseBuiltinFunc
450+
keepNull bool
425451
}
426452

427453
func (b *builtinRealIsTrueSig) Clone() builtinFunc {
428-
newSig := &builtinRealIsTrueSig{}
454+
newSig := &builtinRealIsTrueSig{keepNull: b.keepNull}
429455
newSig.cloneFrom(&b.baseBuiltinFunc)
430456
return newSig
431457
}
@@ -435,6 +461,9 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
435461
if err != nil {
436462
return 0, true, err
437463
}
464+
if b.keepNull && isNull {
465+
return 0, true, nil
466+
}
438467
if isNull || input == 0 {
439468
return 0, false, nil
440469
}
@@ -443,10 +472,11 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
443472

444473
type builtinDecimalIsTrueSig struct {
445474
baseBuiltinFunc
475+
keepNull bool
446476
}
447477

448478
func (b *builtinDecimalIsTrueSig) Clone() builtinFunc {
449-
newSig := &builtinDecimalIsTrueSig{}
479+
newSig := &builtinDecimalIsTrueSig{keepNull: b.keepNull}
450480
newSig.cloneFrom(&b.baseBuiltinFunc)
451481
return newSig
452482
}
@@ -456,6 +486,9 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
456486
if err != nil {
457487
return 0, true, err
458488
}
489+
if b.keepNull && isNull {
490+
return 0, true, nil
491+
}
459492
if isNull || input.IsZero() {
460493
return 0, false, nil
461494
}
@@ -464,10 +497,11 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
464497

465498
type builtinIntIsTrueSig struct {
466499
baseBuiltinFunc
500+
keepNull bool
467501
}
468502

469503
func (b *builtinIntIsTrueSig) Clone() builtinFunc {
470-
newSig := &builtinIntIsTrueSig{}
504+
newSig := &builtinIntIsTrueSig{keepNull: b.keepNull}
471505
newSig.cloneFrom(&b.baseBuiltinFunc)
472506
return newSig
473507
}
@@ -477,6 +511,9 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
477511
if err != nil {
478512
return 0, true, err
479513
}
514+
if b.keepNull && isNull {
515+
return 0, true, nil
516+
}
480517
if isNull || input == 0 {
481518
return 0, false, nil
482519
}
@@ -485,10 +522,11 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
485522

486523
type builtinRealIsFalseSig struct {
487524
baseBuiltinFunc
525+
keepNull bool
488526
}
489527

490528
func (b *builtinRealIsFalseSig) Clone() builtinFunc {
491-
newSig := &builtinRealIsFalseSig{}
529+
newSig := &builtinRealIsFalseSig{keepNull: b.keepNull}
492530
newSig.cloneFrom(&b.baseBuiltinFunc)
493531
return newSig
494532
}
@@ -498,6 +536,9 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
498536
if err != nil {
499537
return 0, true, err
500538
}
539+
if b.keepNull && isNull {
540+
return 0, true, nil
541+
}
501542
if isNull || input != 0 {
502543
return 0, false, nil
503544
}
@@ -506,10 +547,11 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
506547

507548
type builtinDecimalIsFalseSig struct {
508549
baseBuiltinFunc
550+
keepNull bool
509551
}
510552

511553
func (b *builtinDecimalIsFalseSig) Clone() builtinFunc {
512-
newSig := &builtinDecimalIsFalseSig{}
554+
newSig := &builtinDecimalIsFalseSig{keepNull: b.keepNull}
513555
newSig.cloneFrom(&b.baseBuiltinFunc)
514556
return newSig
515557
}
@@ -519,6 +561,9 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
519561
if err != nil {
520562
return 0, true, err
521563
}
564+
if b.keepNull && isNull {
565+
return 0, true, nil
566+
}
522567
if isNull || !input.IsZero() {
523568
return 0, false, nil
524569
}
@@ -527,10 +572,11 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
527572

528573
type builtinIntIsFalseSig struct {
529574
baseBuiltinFunc
575+
keepNull bool
530576
}
531577

532578
func (b *builtinIntIsFalseSig) Clone() builtinFunc {
533-
newSig := &builtinIntIsFalseSig{}
579+
newSig := &builtinIntIsFalseSig{keepNull: b.keepNull}
534580
newSig.cloneFrom(&b.baseBuiltinFunc)
535581
return newSig
536582
}
@@ -540,6 +586,9 @@ func (b *builtinIntIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
540586
if err != nil {
541587
return 0, true, err
542588
}
589+
if b.keepNull && isNull {
590+
return 0, true, nil
591+
}
543592
if isNull || input != 0 {
544593
return 0, false, nil
545594
}

expression/builtin_op_test.go

+90
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,21 @@ func (s *testEvaluatorSuite) TestLogicAnd(c *C) {
8686
{[]interface{}{0, 1}, 0, false, false},
8787
{[]interface{}{0, 0}, 0, false, false},
8888
{[]interface{}{2, -1}, 1, false, false},
89+
{[]interface{}{"a", "0"}, 0, false, false},
8990
{[]interface{}{"a", "1"}, 0, false, false},
91+
{[]interface{}{"1a", "0"}, 0, false, false},
9092
{[]interface{}{"1a", "1"}, 1, false, false},
9193
{[]interface{}{0, nil}, 0, false, false},
9294
{[]interface{}{nil, 0}, 0, false, false},
9395
{[]interface{}{nil, 1}, 0, true, false},
96+
{[]interface{}{0.001, 0}, 0, false, false},
97+
{[]interface{}{0.001, 1}, 1, false, false},
98+
{[]interface{}{nil, 0.000}, 0, false, false},
99+
{[]interface{}{nil, 0.001}, 0, true, false},
100+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 0, false, false},
101+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
102+
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, false, false},
103+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false},
94104

95105
{[]interface{}{errors.New("must error"), 1}, 0, false, true},
96106
}
@@ -300,11 +310,26 @@ func (s *testEvaluatorSuite) TestLogicOr(c *C) {
300310
{[]interface{}{0, 1}, 1, false, false},
301311
{[]interface{}{0, 0}, 0, false, false},
302312
{[]interface{}{2, -1}, 1, false, false},
313+
{[]interface{}{"a", "0"}, 0, false, false},
303314
{[]interface{}{"a", "1"}, 1, false, false},
315+
{[]interface{}{"1a", "0"}, 1, false, false},
304316
{[]interface{}{"1a", "1"}, 1, false, false},
317+
// casting string to real depends on #10498, which will not be cherry-picked.
318+
// {[]interface{}{"0.0a", 0}, 0, false, false},
319+
// {[]interface{}{"0.0001a", 0}, 1, false, false},
305320
{[]interface{}{1, nil}, 1, false, false},
306321
{[]interface{}{nil, 1}, 1, false, false},
307322
{[]interface{}{nil, 0}, 0, true, false},
323+
{[]interface{}{0.000, 0}, 0, false, false},
324+
{[]interface{}{0.001, 0}, 1, false, false},
325+
{[]interface{}{nil, 0.000}, 0, true, false},
326+
{[]interface{}{nil, 0.001}, 1, false, false},
327+
{[]interface{}{types.NewDecFromStringForTest("0.000000"), 0}, 0, false, false},
328+
{[]interface{}{types.NewDecFromStringForTest("0.000000"), 1}, 1, false, false},
329+
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false},
330+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 1, false, false},
331+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
332+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 1, false, false},
308333

309334
{[]interface{}{errors.New("must error"), 1}, 0, false, true},
310335
}
@@ -541,3 +566,68 @@ func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) {
541566
c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse))
542567
}
543568
}
569+
570+
func (s *testEvaluatorSuite) TestLogicXor(c *C) {
571+
defer testleak.AfterTest(c)()
572+
573+
sc := s.ctx.GetSessionVars().StmtCtx
574+
origin := sc.IgnoreTruncate
575+
defer func() {
576+
sc.IgnoreTruncate = origin
577+
}()
578+
sc.IgnoreTruncate = true
579+
580+
cases := []struct {
581+
args []interface{}
582+
expected int64
583+
isNil bool
584+
getErr bool
585+
}{
586+
{[]interface{}{1, 1}, 0, false, false},
587+
{[]interface{}{1, 0}, 1, false, false},
588+
{[]interface{}{0, 1}, 1, false, false},
589+
{[]interface{}{0, 0}, 0, false, false},
590+
{[]interface{}{2, -1}, 0, false, false},
591+
{[]interface{}{"a", "0"}, 0, false, false},
592+
{[]interface{}{"a", "1"}, 1, false, false},
593+
{[]interface{}{"1a", "0"}, 1, false, false},
594+
{[]interface{}{"1a", "1"}, 0, false, false},
595+
{[]interface{}{0, nil}, 0, true, false},
596+
{[]interface{}{nil, 0}, 0, true, false},
597+
{[]interface{}{nil, 1}, 0, true, false},
598+
{[]interface{}{0.5000, 0.4999}, 1, false, false},
599+
{[]interface{}{0.5000, 1.0}, 0, false, false},
600+
{[]interface{}{0.4999, 1.0}, 1, false, false},
601+
{[]interface{}{nil, 0.000}, 0, true, false},
602+
{[]interface{}{nil, 0.001}, 0, true, false},
603+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0.00001}, 0, false, false},
604+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
605+
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false},
606+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false},
607+
608+
{[]interface{}{errors.New("must error"), 1}, 0, false, true},
609+
}
610+
611+
for _, t := range cases {
612+
f, err := newFunctionForTest(s.ctx, ast.LogicXor, s.primitiveValsToConstants(t.args)...)
613+
c.Assert(err, IsNil)
614+
d, err := f.Eval(chunk.Row{})
615+
if t.getErr {
616+
c.Assert(err, NotNil)
617+
} else {
618+
c.Assert(err, IsNil)
619+
if t.isNil {
620+
c.Assert(d.Kind(), Equals, types.KindNull)
621+
} else {
622+
c.Assert(d.GetInt64(), Equals, t.expected)
623+
}
624+
}
625+
}
626+
627+
// Test incorrect parameter count.
628+
_, err := newFunctionForTest(s.ctx, ast.LogicXor, Zero)
629+
c.Assert(err, NotNil)
630+
631+
_, err = funcs[ast.LogicXor].getFunction(s.ctx, []Expression{Zero, Zero})
632+
c.Assert(err, IsNil)
633+
}

expression/distsql_builtin.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -371,17 +371,17 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
371371
f = &builtinCaseWhenIntSig{base}
372372

373373
case tipb.ScalarFuncSig_IntIsFalse:
374-
f = &builtinIntIsFalseSig{base}
374+
f = &builtinIntIsFalseSig{base, false}
375375
case tipb.ScalarFuncSig_RealIsFalse:
376-
f = &builtinRealIsFalseSig{base}
376+
f = &builtinRealIsFalseSig{base, false}
377377
case tipb.ScalarFuncSig_DecimalIsFalse:
378-
f = &builtinDecimalIsFalseSig{base}
378+
f = &builtinDecimalIsFalseSig{base, false}
379379
case tipb.ScalarFuncSig_IntIsTrue:
380-
f = &builtinIntIsTrueSig{base}
380+
f = &builtinIntIsTrueSig{base, false}
381381
case tipb.ScalarFuncSig_RealIsTrue:
382-
f = &builtinRealIsTrueSig{base}
382+
f = &builtinRealIsTrueSig{base, false}
383383
case tipb.ScalarFuncSig_DecimalIsTrue:
384-
f = &builtinDecimalIsTrueSig{base}
384+
f = &builtinDecimalIsTrueSig{base, false}
385385

386386
case tipb.ScalarFuncSig_IfNullReal:
387387
f = &builtinIfNullRealSig{base}

0 commit comments

Comments
 (0)