Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: fix unexpected NotNullFlag in case when expr ret type #23102

Merged
merged 7 commits into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}

fieldTp := types.AggFieldType(fieldTps)
// Here we turn off NotNullFlag. Because if all when-clauses are false,
// the result of case-when expr is NULL.
types.SetTypeFlag(&fieldTp.Flag, mysql.NotNullFlag, false)
tp := fieldTp.EvalType()

if tp == types.ETInt {
Expand Down
10 changes: 10 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2961,6 +2961,16 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) {
result.Check(testkit.Rows("<nil> 4"))
result = tk.MustQuery("select * from t where b = case when a is null then 4 when a = 'str5' then 7 else 9 end")
result.Check(testkit.Rows("<nil> 4"))

// return type of case when expr should not include NotNullFlag. issue-23036
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1(c1 int not null)")
tk.MustExec("insert into t1 values(1)")
result = tk.MustQuery("select (case when null then c1 end) is null from t1")
result.Check(testkit.Rows("1"))
result = tk.MustQuery("select (case when null then c1 end) is not null from t1")
result.Check(testkit.Rows("0"))

// test warnings
tk.MustQuery("select case when b=0 then 1 else 1/b end from t")
tk.MustQuery("show warnings").Check(testkit.Rows())
Expand Down
7 changes: 4 additions & 3 deletions types/field_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ func AggregateEvalType(fts []*FieldType, flag *uint) EvalType {
}
lft = rft
}
setTypeFlag(flag, mysql.UnsignedFlag, unsigned)
setTypeFlag(flag, mysql.BinaryFlag, !aggregatedEvalType.IsStringKind() || gotBinString)
SetTypeFlag(flag, mysql.UnsignedFlag, unsigned)
SetTypeFlag(flag, mysql.BinaryFlag, !aggregatedEvalType.IsStringKind() || gotBinString)
return aggregatedEvalType
}

Expand All @@ -160,7 +160,8 @@ func mergeEvalType(lhs, rhs EvalType, lft, rft *FieldType, isLHSUnsigned, isRHSU
return ETInt
}

func setTypeFlag(flag *uint, flagItem uint, on bool) {
// SetTypeFlag turns the flagItem on or off.
func SetTypeFlag(flag *uint, flagItem uint, on bool) {
if on {
*flag |= flagItem
} else {
Expand Down