Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix incorrect result of logical operators when isTrue/ isFalse function is pushed down #15926

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5384,6 +5384,12 @@ func (s *testSuite1) TestIssue15718(c *C) {
tk.MustExec("create table tt(a decimal(10, 0), b varchar(1), c time);")
tk.MustExec("insert into tt values(0, '2', null), (7, null, '1122'), (NULL, 'w', null), (NULL, '2', '3344'), (NULL, NULL, '0'), (7, 'f', '33');")
tk.MustQuery("select a and b as d, a or c as e from tt;").Check(testkit.Rows("0 <nil>", "<nil> 1", "0 <nil>", "<nil> 1", "<nil> <nil>", "0 1"))

tk.MustExec("drop table if exists tt;")
tk.MustExec("create table tt(a decimal(10, 0), b varchar(1), c time);")
tk.MustExec("insert into tt values(0, '2', '123'), (7, null, '1122'), (null, 'w', null);")
tk.MustQuery("select a and b as d, a, b from tt order by d limit 1;").Check(testkit.Rows("<nil> 7 <nil>"))
tk.MustQuery("select b or c as d, b, c from tt order by d limit 1;").Check(testkit.Rows("<nil> w <nil>"))
}

func (s *testSuite1) TestIssue15767(c *C) {
Expand Down
30 changes: 30 additions & 0 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,36 @@ func newBaseBuiltinCastFunc(builtinFunc baseBuiltinFunc, inUnion bool) baseBuilt
}
}

// baseBuiltinIsTrueOrFalseFunc will be contained in every struct that implement isTrue/ isFalse builtinFunc.
type baseBuiltinIsTrueOrFalseFunc struct {
baseBuiltinFunc

// keepNull indicates how this function treats a null input parameter.
// If keepNull is true and the input parameter is null, the function will return null.
// If keepNull is false, the null input parameter will be cast to 0.
keepNull bool
}

// metadata returns the metadata of cast functions
func (b *baseBuiltinIsTrueOrFalseFunc) metadata() proto.Message {
args := &tipb.IsTrueOrFalseMetadata{
KeepNull: b.keepNull,
}
return args
}

func (b *baseBuiltinIsTrueOrFalseFunc) cloneFrom(from *baseBuiltinIsTrueOrFalseFunc) {
b.baseBuiltinFunc.cloneFrom(&from.baseBuiltinFunc)
b.keepNull = from.keepNull
}

func newBaseBuiltinIsTrueOrFalseFunc(builtinFunc baseBuiltinFunc, keepNull bool) baseBuiltinIsTrueOrFalseFunc {
return baseBuiltinIsTrueOrFalseFunc{
baseBuiltinFunc: builtinFunc,
keepNull: keepNull,
}
}

// vecBuiltinFunc contains all vectorized methods for a builtin function.
type vecBuiltinFunc interface {
// vectorized returns if this builtin function itself supports vectorized evaluation.
Expand Down
56 changes: 25 additions & 31 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,35 +418,35 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
argTp = types.ETReal
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, argTp)
bf := newBaseBuiltinIsTrueOrFalseFunc(newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, argTp), c.keepNull)
bf.tp.Flen = 1

var sig builtinFunc
switch c.op {
case opcode.IsTruth:
switch argTp {
case types.ETReal:
sig = &builtinRealIsTrueSig{bf, c.keepNull}
sig = &builtinRealIsTrueSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_RealIsTrue)
case types.ETDecimal:
sig = &builtinDecimalIsTrueSig{bf, c.keepNull}
sig = &builtinDecimalIsTrueSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsTrue)
case types.ETInt:
sig = &builtinIntIsTrueSig{bf, c.keepNull}
sig = &builtinIntIsTrueSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue)
default:
return nil, errors.Errorf("unexpected types.EvalType %v", argTp)
}
case opcode.IsFalsity:
switch argTp {
case types.ETReal:
sig = &builtinRealIsFalseSig{bf, c.keepNull}
sig = &builtinRealIsFalseSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_RealIsFalse)
case types.ETDecimal:
sig = &builtinDecimalIsFalseSig{bf, c.keepNull}
sig = &builtinDecimalIsFalseSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsFalse)
case types.ETInt:
sig = &builtinIntIsFalseSig{bf, c.keepNull}
sig = &builtinIntIsFalseSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse)
default:
return nil, errors.Errorf("unexpected types.EvalType %v", argTp)
Expand All @@ -456,13 +456,12 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
}

type builtinRealIsTrueSig struct {
baseBuiltinFunc
keepNull bool
baseBuiltinIsTrueOrFalseFunc
}

func (b *builtinRealIsTrueSig) Clone() builtinFunc {
newSig := &builtinRealIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig := &builtinRealIsTrueSig{}
newSig.cloneFrom(&b.baseBuiltinIsTrueOrFalseFunc)
return newSig
}

Expand All @@ -481,13 +480,12 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
}

type builtinDecimalIsTrueSig struct {
baseBuiltinFunc
keepNull bool
baseBuiltinIsTrueOrFalseFunc
}

func (b *builtinDecimalIsTrueSig) Clone() builtinFunc {
newSig := &builtinDecimalIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig := &builtinDecimalIsTrueSig{}
newSig.cloneFrom(&b.baseBuiltinIsTrueOrFalseFunc)
return newSig
}

Expand All @@ -506,13 +504,12 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
}

type builtinIntIsTrueSig struct {
baseBuiltinFunc
keepNull bool
baseBuiltinIsTrueOrFalseFunc
}

func (b *builtinIntIsTrueSig) Clone() builtinFunc {
newSig := &builtinIntIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig := &builtinIntIsTrueSig{}
newSig.cloneFrom(&b.baseBuiltinIsTrueOrFalseFunc)
return newSig
}

Expand All @@ -531,13 +528,12 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
}

type builtinRealIsFalseSig struct {
baseBuiltinFunc
keepNull bool
baseBuiltinIsTrueOrFalseFunc
}

func (b *builtinRealIsFalseSig) Clone() builtinFunc {
newSig := &builtinRealIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig := &builtinRealIsFalseSig{}
newSig.cloneFrom(&b.baseBuiltinIsTrueOrFalseFunc)
return newSig
}

Expand All @@ -556,13 +552,12 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
}

type builtinDecimalIsFalseSig struct {
baseBuiltinFunc
keepNull bool
baseBuiltinIsTrueOrFalseFunc
}

func (b *builtinDecimalIsFalseSig) Clone() builtinFunc {
newSig := &builtinDecimalIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig := &builtinDecimalIsFalseSig{}
newSig.cloneFrom(&b.baseBuiltinIsTrueOrFalseFunc)
return newSig
}

Expand All @@ -581,13 +576,12 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
}

type builtinIntIsFalseSig struct {
baseBuiltinFunc
keepNull bool
baseBuiltinIsTrueOrFalseFunc
}

func (b *builtinIntIsFalseSig) Clone() builtinFunc {
newSig := &builtinIntIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig := &builtinIntIsFalseSig{}
newSig.cloneFrom(&b.baseBuiltinIsTrueOrFalseFunc)
return newSig
}

Expand Down
Loading