Skip to content

Commit e848851

Browse files
SunRunAwaysre-bot
authored andcommitted
expression: fix incorrect result of logical operators (#12173) (#12813)
1 parent c72efd1 commit e848851

File tree

5 files changed

+180
-20
lines changed

5 files changed

+180
-20
lines changed

expression/builtin.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -540,8 +540,8 @@ var funcs = map[string]functionClass{
540540
ast.Xor: &bitXorFunctionClass{baseFunctionClass{ast.Xor, 2, 2}},
541541
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
542542
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 2, -1}},
543-
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
544-
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
543+
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, false},
544+
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity, false},
545545
ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 3, 3}},
546546
ast.Regexp: &regexpFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}},
547547
ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}},

expression/builtin_op.go

+60-12
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,15 @@ func (c *logicAndFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
6565
if err != nil {
6666
return nil, errors.Trace(err)
6767
}
68+
args[0], err = wrapWithIsTrue(ctx, true, args[0])
69+
if err != nil {
70+
return nil, errors.Trace(err)
71+
}
72+
args[1], err = wrapWithIsTrue(ctx, true, args[1])
73+
if err != nil {
74+
return nil, errors.Trace(err)
75+
}
76+
6877
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
6978
sig := &builtinLogicAndSig{bf}
7079
sig.setPbCode(tipb.ScalarFuncSig_LogicalAnd)
@@ -106,6 +115,15 @@ func (c *logicOrFunctionClass) getFunction(ctx sessionctx.Context, args []Expres
106115
if err != nil {
107116
return nil, errors.Trace(err)
108117
}
118+
args[0], err = wrapWithIsTrue(ctx, true, args[0])
119+
if err != nil {
120+
return nil, errors.Trace(err)
121+
}
122+
args[1], err = wrapWithIsTrue(ctx, true, args[1])
123+
if err != nil {
124+
return nil, errors.Trace(err)
125+
}
126+
109127
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
110128
bf.tp.Flen = 1
111129
sig := &builtinLogicOrSig{bf}
@@ -153,6 +171,7 @@ func (c *logicXorFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
153171
if err != nil {
154172
return nil, errors.Trace(err)
155173
}
174+
156175
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
157176
sig := &builtinLogicXorSig{bf}
158177
sig.setPbCode(tipb.ScalarFuncSig_LogicalXor)
@@ -376,6 +395,11 @@ func (b *builtinRightShiftSig) evalInt(row chunk.Row) (int64, bool, error) {
376395
type isTrueOrFalseFunctionClass struct {
377396
baseFunctionClass
378397
op opcode.Op
398+
399+
// keepNull indicates how this function treats a null input parameter.
400+
// If keepNull is true and the input parameter is null, the function will return null.
401+
// If keepNull is false, the null input parameter will be cast to 0.
402+
keepNull bool
379403
}
380404

381405
func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
@@ -396,25 +420,25 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
396420
case opcode.IsTruth:
397421
switch argTp {
398422
case types.ETReal:
399-
sig = &builtinRealIsTrueSig{bf}
423+
sig = &builtinRealIsTrueSig{bf, c.keepNull}
400424
sig.setPbCode(tipb.ScalarFuncSig_RealIsTrue)
401425
case types.ETDecimal:
402-
sig = &builtinDecimalIsTrueSig{bf}
426+
sig = &builtinDecimalIsTrueSig{bf, c.keepNull}
403427
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsTrue)
404428
case types.ETInt:
405-
sig = &builtinIntIsTrueSig{bf}
429+
sig = &builtinIntIsTrueSig{bf, c.keepNull}
406430
sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue)
407431
}
408432
case opcode.IsFalsity:
409433
switch argTp {
410434
case types.ETReal:
411-
sig = &builtinRealIsFalseSig{bf}
435+
sig = &builtinRealIsFalseSig{bf, c.keepNull}
412436
sig.setPbCode(tipb.ScalarFuncSig_RealIsFalse)
413437
case types.ETDecimal:
414-
sig = &builtinDecimalIsFalseSig{bf}
438+
sig = &builtinDecimalIsFalseSig{bf, c.keepNull}
415439
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsFalse)
416440
case types.ETInt:
417-
sig = &builtinIntIsFalseSig{bf}
441+
sig = &builtinIntIsFalseSig{bf, c.keepNull}
418442
sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse)
419443
}
420444
}
@@ -423,10 +447,11 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
423447

424448
type builtinRealIsTrueSig struct {
425449
baseBuiltinFunc
450+
keepNull bool
426451
}
427452

428453
func (b *builtinRealIsTrueSig) Clone() builtinFunc {
429-
newSig := &builtinRealIsTrueSig{}
454+
newSig := &builtinRealIsTrueSig{keepNull: b.keepNull}
430455
newSig.cloneFrom(&b.baseBuiltinFunc)
431456
return newSig
432457
}
@@ -436,6 +461,9 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
436461
if err != nil {
437462
return 0, true, errors.Trace(err)
438463
}
464+
if b.keepNull && isNull {
465+
return 0, true, nil
466+
}
439467
if isNull || input == 0 {
440468
return 0, false, nil
441469
}
@@ -444,10 +472,11 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
444472

445473
type builtinDecimalIsTrueSig struct {
446474
baseBuiltinFunc
475+
keepNull bool
447476
}
448477

449478
func (b *builtinDecimalIsTrueSig) Clone() builtinFunc {
450-
newSig := &builtinDecimalIsTrueSig{}
479+
newSig := &builtinDecimalIsTrueSig{keepNull: b.keepNull}
451480
newSig.cloneFrom(&b.baseBuiltinFunc)
452481
return newSig
453482
}
@@ -457,6 +486,9 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
457486
if err != nil {
458487
return 0, true, errors.Trace(err)
459488
}
489+
if b.keepNull && isNull {
490+
return 0, true, nil
491+
}
460492
if isNull || input.IsZero() {
461493
return 0, false, nil
462494
}
@@ -465,10 +497,11 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
465497

466498
type builtinIntIsTrueSig struct {
467499
baseBuiltinFunc
500+
keepNull bool
468501
}
469502

470503
func (b *builtinIntIsTrueSig) Clone() builtinFunc {
471-
newSig := &builtinIntIsTrueSig{}
504+
newSig := &builtinIntIsTrueSig{keepNull: b.keepNull}
472505
newSig.cloneFrom(&b.baseBuiltinFunc)
473506
return newSig
474507
}
@@ -478,6 +511,9 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
478511
if err != nil {
479512
return 0, true, errors.Trace(err)
480513
}
514+
if b.keepNull && isNull {
515+
return 0, true, nil
516+
}
481517
if isNull || input == 0 {
482518
return 0, false, nil
483519
}
@@ -486,10 +522,11 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
486522

487523
type builtinRealIsFalseSig struct {
488524
baseBuiltinFunc
525+
keepNull bool
489526
}
490527

491528
func (b *builtinRealIsFalseSig) Clone() builtinFunc {
492-
newSig := &builtinRealIsFalseSig{}
529+
newSig := &builtinRealIsFalseSig{keepNull: b.keepNull}
493530
newSig.cloneFrom(&b.baseBuiltinFunc)
494531
return newSig
495532
}
@@ -499,6 +536,9 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
499536
if err != nil {
500537
return 0, true, errors.Trace(err)
501538
}
539+
if b.keepNull && isNull {
540+
return 0, true, nil
541+
}
502542
if isNull || input != 0 {
503543
return 0, false, nil
504544
}
@@ -507,10 +547,11 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
507547

508548
type builtinDecimalIsFalseSig struct {
509549
baseBuiltinFunc
550+
keepNull bool
510551
}
511552

512553
func (b *builtinDecimalIsFalseSig) Clone() builtinFunc {
513-
newSig := &builtinDecimalIsFalseSig{}
554+
newSig := &builtinDecimalIsFalseSig{keepNull: b.keepNull}
514555
newSig.cloneFrom(&b.baseBuiltinFunc)
515556
return newSig
516557
}
@@ -520,6 +561,9 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
520561
if err != nil {
521562
return 0, true, errors.Trace(err)
522563
}
564+
if b.keepNull && isNull {
565+
return 0, true, nil
566+
}
523567
if isNull || !input.IsZero() {
524568
return 0, false, nil
525569
}
@@ -528,10 +572,11 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
528572

529573
type builtinIntIsFalseSig struct {
530574
baseBuiltinFunc
575+
keepNull bool
531576
}
532577

533578
func (b *builtinIntIsFalseSig) Clone() builtinFunc {
534-
newSig := &builtinIntIsFalseSig{}
579+
newSig := &builtinIntIsFalseSig{keepNull: b.keepNull}
535580
newSig.cloneFrom(&b.baseBuiltinFunc)
536581
return newSig
537582
}
@@ -541,6 +586,9 @@ func (b *builtinIntIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
541586
if err != nil {
542587
return 0, true, errors.Trace(err)
543588
}
589+
if b.keepNull && isNull {
590+
return 0, true, nil
591+
}
544592
if isNull || input != 0 {
545593
return 0, false, nil
546594
}

expression/builtin_op_test.go

+90
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,21 @@ func (s *testEvaluatorSuite) TestLogicAnd(c *C) {
8686
{[]interface{}{0, 1}, 0, false, false},
8787
{[]interface{}{0, 0}, 0, false, false},
8888
{[]interface{}{2, -1}, 1, false, false},
89+
{[]interface{}{"a", "0"}, 0, false, false},
8990
{[]interface{}{"a", "1"}, 0, false, false},
91+
{[]interface{}{"1a", "0"}, 0, false, false},
9092
{[]interface{}{"1a", "1"}, 1, false, false},
9193
{[]interface{}{0, nil}, 0, false, false},
9294
{[]interface{}{nil, 0}, 0, false, false},
9395
{[]interface{}{nil, 1}, 0, true, false},
96+
{[]interface{}{0.001, 0}, 0, false, false},
97+
{[]interface{}{0.001, 1}, 1, false, false},
98+
{[]interface{}{nil, 0.000}, 0, false, false},
99+
{[]interface{}{nil, 0.001}, 0, true, false},
100+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 0, false, false},
101+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
102+
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, false, false},
103+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false},
94104

95105
{[]interface{}{errors.New("must error"), 1}, 0, false, true},
96106
}
@@ -300,11 +310,26 @@ func (s *testEvaluatorSuite) TestLogicOr(c *C) {
300310
{[]interface{}{0, 1}, 1, false, false},
301311
{[]interface{}{0, 0}, 0, false, false},
302312
{[]interface{}{2, -1}, 1, false, false},
313+
{[]interface{}{"a", "0"}, 0, false, false},
303314
{[]interface{}{"a", "1"}, 1, false, false},
315+
{[]interface{}{"1a", "0"}, 1, false, false},
304316
{[]interface{}{"1a", "1"}, 1, false, false},
317+
// casting string to real depends on #10498, which will not be cherry-picked.
318+
// {[]interface{}{"0.0a", 0}, 0, false, false},
319+
// {[]interface{}{"0.0001a", 0}, 1, false, false},
305320
{[]interface{}{1, nil}, 1, false, false},
306321
{[]interface{}{nil, 1}, 1, false, false},
307322
{[]interface{}{nil, 0}, 0, true, false},
323+
{[]interface{}{0.000, 0}, 0, false, false},
324+
{[]interface{}{0.001, 0}, 1, false, false},
325+
{[]interface{}{nil, 0.000}, 0, true, false},
326+
{[]interface{}{nil, 0.001}, 1, false, false},
327+
{[]interface{}{types.NewDecFromStringForTest("0.000000"), 0}, 0, false, false},
328+
{[]interface{}{types.NewDecFromStringForTest("0.000000"), 1}, 1, false, false},
329+
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false},
330+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 1, false, false},
331+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
332+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 1, false, false},
308333

309334
{[]interface{}{errors.New("must error"), 1}, 0, false, true},
310335
}
@@ -541,3 +566,68 @@ func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) {
541566
c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse))
542567
}
543568
}
569+
570+
func (s *testEvaluatorSuite) TestLogicXor(c *C) {
571+
defer testleak.AfterTest(c)()
572+
573+
sc := s.ctx.GetSessionVars().StmtCtx
574+
origin := sc.IgnoreTruncate
575+
defer func() {
576+
sc.IgnoreTruncate = origin
577+
}()
578+
sc.IgnoreTruncate = true
579+
580+
cases := []struct {
581+
args []interface{}
582+
expected int64
583+
isNil bool
584+
getErr bool
585+
}{
586+
{[]interface{}{1, 1}, 0, false, false},
587+
{[]interface{}{1, 0}, 1, false, false},
588+
{[]interface{}{0, 1}, 1, false, false},
589+
{[]interface{}{0, 0}, 0, false, false},
590+
{[]interface{}{2, -1}, 0, false, false},
591+
{[]interface{}{"a", "0"}, 0, false, false},
592+
{[]interface{}{"a", "1"}, 1, false, false},
593+
{[]interface{}{"1a", "0"}, 1, false, false},
594+
{[]interface{}{"1a", "1"}, 0, false, false},
595+
{[]interface{}{0, nil}, 0, true, false},
596+
{[]interface{}{nil, 0}, 0, true, false},
597+
{[]interface{}{nil, 1}, 0, true, false},
598+
{[]interface{}{0.5000, 0.4999}, 1, false, false},
599+
{[]interface{}{0.5000, 1.0}, 0, false, false},
600+
{[]interface{}{0.4999, 1.0}, 1, false, false},
601+
{[]interface{}{nil, 0.000}, 0, true, false},
602+
{[]interface{}{nil, 0.001}, 0, true, false},
603+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0.00001}, 0, false, false},
604+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
605+
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false},
606+
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false},
607+
608+
{[]interface{}{errors.New("must error"), 1}, 0, false, true},
609+
}
610+
611+
for _, t := range cases {
612+
f, err := newFunctionForTest(s.ctx, ast.LogicXor, s.primitiveValsToConstants(t.args)...)
613+
c.Assert(err, IsNil)
614+
d, err := f.Eval(chunk.Row{})
615+
if t.getErr {
616+
c.Assert(err, NotNil)
617+
} else {
618+
c.Assert(err, IsNil)
619+
if t.isNil {
620+
c.Assert(d.Kind(), Equals, types.KindNull)
621+
} else {
622+
c.Assert(d.GetInt64(), Equals, t.expected)
623+
}
624+
}
625+
}
626+
627+
// Test incorrect parameter count.
628+
_, err := newFunctionForTest(s.ctx, ast.LogicXor, Zero)
629+
c.Assert(err, NotNil)
630+
631+
_, err = funcs[ast.LogicXor].getFunction(s.ctx, []Expression{Zero, Zero})
632+
c.Assert(err, IsNil)
633+
}

expression/distsql_builtin.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -371,17 +371,17 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
371371
f = &builtinCaseWhenIntSig{base}
372372

373373
case tipb.ScalarFuncSig_IntIsFalse:
374-
f = &builtinIntIsFalseSig{base}
374+
f = &builtinIntIsFalseSig{base, false}
375375
case tipb.ScalarFuncSig_RealIsFalse:
376-
f = &builtinRealIsFalseSig{base}
376+
f = &builtinRealIsFalseSig{base, false}
377377
case tipb.ScalarFuncSig_DecimalIsFalse:
378-
f = &builtinDecimalIsFalseSig{base}
378+
f = &builtinDecimalIsFalseSig{base, false}
379379
case tipb.ScalarFuncSig_IntIsTrue:
380-
f = &builtinIntIsTrueSig{base}
380+
f = &builtinIntIsTrueSig{base, false}
381381
case tipb.ScalarFuncSig_RealIsTrue:
382-
f = &builtinRealIsTrueSig{base}
382+
f = &builtinRealIsTrueSig{base, false}
383383
case tipb.ScalarFuncSig_DecimalIsTrue:
384-
f = &builtinDecimalIsTrueSig{base}
384+
f = &builtinDecimalIsTrueSig{base, false}
385385

386386
case tipb.ScalarFuncSig_IfNullReal:
387387
f = &builtinIfNullRealSig{base}

0 commit comments

Comments
 (0)