Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

types: fix unexpected NOT_NULL flags #19029

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/explaintest/r/explain.result
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ drop view if exists v;
create view v as select cast(replace(substring_index(substring_index("",',',1),':',-1),'"','') as CHAR(32)) as event_id;
desc v;
Field Type Null Key Default Extra
event_id varchar(32) YES NULL
event_id varchar(32) NO NULL
29 changes: 15 additions & 14 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -613,20 +613,21 @@ HashJoin_7 8002.00 root right outer join, equal:[eq(test.t.nb, test.t.nb)]
└─TableFullScan_12 10000.00 cop[tikv] table:tb keep order:false, stats:pseudo
explain select ifnull(t.a, 1) in (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t;
id estRows task access object operator info
Projection_12 10000.00 root Column#14
└─Apply_14 10000.00 root CARTESIAN left outer semi join, other cond:eq(ifnull(test.t.a, 1), Column#13)
├─TableReader_16(Build) 10000.00 root data:TableFullScan_15
│ └─TableFullScan_15 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
└─HashAgg_19(Probe) 1.00 root funcs:count(Column#15)->Column#13
└─HashJoin_20 9.99 root inner join, equal:[eq(test.t.a, test.t.a)]
├─HashAgg_30(Build) 7.99 root group by:test.t.a, funcs:count(Column#16)->Column#15, funcs:firstrow(test.t.a)->test.t.a
│ └─TableReader_31 7.99 root data:HashAgg_25
│ └─HashAgg_25 7.99 cop[tikv] group by:test.t.a, funcs:count(1)->Column#16
│ └─Selection_29 9.99 cop[tikv] eq(test.t.a, test.t.a), not(isnull(test.t.a))
│ └─TableFullScan_28 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
└─TableReader_24(Probe) 9.99 root data:Selection_23
└─Selection_23 9.99 cop[tikv] eq(test.t.a, test.t.a), not(isnull(test.t.a))
└─TableFullScan_22 10000.00 cop[tikv] table:s keep order:false, stats:pseudo
Projection_13 10000.00 root Column#14
└─Apply_15 10000.00 root left outer semi join, equal:[eq(Column#15, Column#13)]
├─Projection_16(Build) 10000.00 root test.t.a, ifnull(test.t.a, 1)->Column#15
│ └─TableReader_18 10000.00 root data:TableFullScan_17
│ └─TableFullScan_17 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
└─HashAgg_21(Probe) 1.00 root funcs:count(Column#17)->Column#13
└─HashJoin_22 9.99 root inner join, equal:[eq(test.t.a, test.t.a)]
├─HashAgg_32(Build) 7.99 root group by:test.t.a, funcs:count(Column#18)->Column#17, funcs:firstrow(test.t.a)->test.t.a
│ └─TableReader_33 7.99 root data:HashAgg_27
│ └─HashAgg_27 7.99 cop[tikv] group by:test.t.a, funcs:count(1)->Column#18
│ └─Selection_31 9.99 cop[tikv] eq(test.t.a, test.t.a), not(isnull(test.t.a))
│ └─TableFullScan_30 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
└─TableReader_26(Probe) 9.99 root data:Selection_25
└─Selection_25 9.99 cop[tikv] eq(test.t.a, test.t.a), not(isnull(test.t.a))
└─TableFullScan_24 10000.00 cop[tikv] table:s keep order:false, stats:pseudo
drop table if exists t;
create table t(a int);
explain select * from t where _tidb_rowid = 0;
Expand Down
6 changes: 6 additions & 0 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
fieldTp.Flen, fieldTp.Decimal = 0, types.UnspecifiedLength
types.SetBinChsClnFlag(fieldTp)
}
// it's hard to distinguish not null flag, because
// 1. the args[0] may be null
// 2. no ELSE part
// and MySQL 8.0 also does not provide any not null flag in CASE WHEN
fieldTp.Flag &^= mysql.NotNullFlag

argTps := make([]types.EvalType, 0, l)
for i := 0; i < l-1; i += 2 {
if args[i], err = wrapWithIsTrue(ctx, true, args[i], false); err != nil {
Expand Down
2 changes: 2 additions & 0 deletions expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -2027,6 +2027,8 @@ func (c *sysDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres
return nil, err
}
bf.tp.Flen, bf.tp.Decimal = 19, 0
// Illegal parameters have been filtered out in the parser, so the result is always not null.
bf.tp.Flag |= mysql.NotNullFlag

var sig builtinFunc
if len(args) == 1 {
Expand Down
41 changes: 34 additions & 7 deletions expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package expression

import (
"fmt"
"sync"

"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
Expand Down Expand Up @@ -56,7 +57,9 @@ func NewNull() *Constant {

// Constant stands for a constant value.
type Constant struct {
Value types.Datum
Value types.Datum
// once protects the changes of RetType
lock sync.Mutex
RetType *types.FieldType
// DeferredExpr holds deferred function in PlanCache cached plan.
// it's only used to represent non-deterministic functions(see expression.DeferredFunctions)
Expand All @@ -70,6 +73,19 @@ type Constant struct {
collationInfo
}

// Clone implements Expression interface.
func (c *Constant) Clone() Expression {
con := &Constant{
Value: c.Value,
RetType: c.RetType,
DeferredExpr: c.DeferredExpr,
ParamMarker: c.ParamMarker,
hashcode: c.hashcode,
collationInfo: c.collationInfo,
}
return con
}

// ParamMarker indicates param provided by COM_STMT_EXECUTE.
type ParamMarker struct {
ctx sessionctx.Context
Expand Down Expand Up @@ -98,12 +114,6 @@ func (c *Constant) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("%q", c)), nil
}

// Clone implements Expression interface.
func (c *Constant) Clone() Expression {
con := *c
return &con
}

// GetType implements Expression interface.
func (c *Constant) GetType() *types.FieldType {
if c.ParamMarker != nil {
Expand All @@ -114,6 +124,23 @@ func (c *Constant) GetType() *types.FieldType {
types.DefaultParamTypeForValue(dt.GetValue(), tp)
return tp
}
if !c.Value.IsNull() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can check !c.Value.IsNull() && c.RetType.Flag&mysql.NotNullFlag == 0.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we can remove the sync.Once.

// c.once.Do(func() {
// c.RetType = c.RetType.Clone()
// c.RetType.Flag |= mysql.NotNullFlag
// })
// tp := c.RetType.Clone()
// tp.Flag |= mysql.NotNullFlag
// return tp

c.lock.Lock()
c.RetType.Flag |= mysql.NotNullFlag
c.lock.Unlock()
} else {
c.lock.Lock()
c.RetType.Flag &^= mysql.NotNullFlag
c.lock.Unlock()
}
return c.RetType
}

Expand Down
13 changes: 13 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2875,6 +2875,19 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) {
result.Check(testkit.Rows("<nil> 4"))
result = tk.MustQuery("select * from t where b = case when a is null then 4 when a = 'str5' then 7 else 9 end")
result.Check(testkit.Rows("<nil> 4"))
result = tk.MustQuery(`SELECT -Max(+23) * -+Cast(--10 AS SIGNED) * -CASE
WHEN 0 > 85 THEN NULL
WHEN NOT
CASE +55
WHEN +( +82 ) + -89 * -69 THEN +Count(-88)
WHEN +CASE 57
WHEN +89 THEN -89 * Count(*)
WHEN 17 THEN NULL
END THEN ( -10 )
END IS NULL THEN NULL
ELSE 83 + 48
END AS col0; `)
result.Check(testkit.Rows("-30130"))
// test warnings
tk.MustQuery("select case when b=0 then 1 else 1/b end from t")
tk.MustQuery("show warnings").Check(testkit.Rows())
Expand Down
2 changes: 1 addition & 1 deletion expression/simple_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func ParseSimpleExprCastWithTableInfo(ctx sessionctx.Context, exprStr string, ta
if err != nil {
return nil, err
}
e = BuildCastFunction(ctx, e, targetFt)
e = BuildCastFunction(ctx, e, targetFt.Clone())
return e, nil
}

Expand Down
Loading