diff --git a/executor/builder.go b/executor/builder.go index 46eb46388330b..711af3c8318c6 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -5153,7 +5153,12 @@ func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) Executor { iterInTbl = storages.IterInTbl producer = storages.Producer } else { + if v.SeedPlan == nil { + b.err = errors.New("cte.seedPlan cannot be nil") + return nil + } // Build seed part. + corCols := plannercore.ExtractOuterApplyCorrelatedCols(v.SeedPlan) seedExec := b.build(v.SeedPlan) if b.err != nil { return nil @@ -5174,10 +5179,15 @@ func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) Executor { storageMap[v.CTE.IDForStorage] = &CTEStorages{ResTbl: resTbl, IterInTbl: iterInTbl} // Build recursive part. - recursiveExec := b.build(v.RecurPlan) - if b.err != nil { - return nil + var recursiveExec Executor + if v.RecurPlan != nil { + recursiveExec = b.build(v.RecurPlan) + if b.err != nil { + return nil + } + corCols = append(corCols, plannercore.ExtractOuterApplyCorrelatedCols(v.RecurPlan)...) } + var sel []int if v.CTE.IsDistinct { sel = make([]int, chkSize) @@ -5186,18 +5196,24 @@ func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) Executor { } } + var corColHashCodes [][]byte + for _, corCol := range corCols { + corColHashCodes = append(corColHashCodes, getCorColHashCode(corCol)) + } + producer = &cteProducer{ - ctx: b.ctx, - seedExec: seedExec, - recursiveExec: recursiveExec, - resTbl: resTbl, - iterInTbl: iterInTbl, - isDistinct: v.CTE.IsDistinct, - sel: sel, - hasLimit: v.CTE.HasLimit, - limitBeg: v.CTE.LimitBeg, - limitEnd: v.CTE.LimitEnd, - isInApply: v.CTE.IsInApply, + ctx: b.ctx, + seedExec: seedExec, + recursiveExec: recursiveExec, + resTbl: resTbl, + iterInTbl: iterInTbl, + isDistinct: v.CTE.IsDistinct, + sel: sel, + hasLimit: v.CTE.HasLimit, + limitBeg: v.CTE.LimitBeg, + limitEnd: v.CTE.LimitEnd, + corCols: corCols, + corColHashCodes: corColHashCodes, } storageMap[v.CTE.IDForStorage].Producer = producer } diff --git a/executor/cte.go b/executor/cte.go index ebea5e553290a..7c3d4fe128567 100644 --- a/executor/cte.go +++ b/executor/cte.go @@ -15,10 +15,12 @@ package executor import ( + "bytes" "context" "github.com/pingcap/errors" "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/chunk" @@ -80,8 +82,11 @@ func (e *CTEExec) Open(ctx context.Context) (err error) { e.producer.resTbl.Lock() defer e.producer.resTbl.Unlock() - if e.producer.isInApply { + if e.producer.checkAndUpdateCorColHashCode() { e.producer.reset() + if err = e.producer.reopenTbls(); err != nil { + return err + } } if !e.producer.opened { if err = e.producer.openProducer(ctx, e); err != nil { @@ -107,6 +112,10 @@ func (e *CTEExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { func (e *CTEExec) Close() (err error) { e.producer.resTbl.Lock() if !e.producer.closed { + // closeProducer() only close seedExec and recursiveExec, will not touch resTbl. + // It means you can still read resTbl after call closeProducer(). + // You can even call all three functions(openProducer/produce/closeProducer) in CTEExec.Next(). + // Separating these three function calls is only to follow the abstraction of the volcano model. err = e.producer.closeProducer() } e.producer.resTbl.Unlock() @@ -154,10 +163,9 @@ type cteProducer struct { memTracker *memory.Tracker diskTracker *disk.Tracker - // isInApply indicates whether CTE is in inner side of Apply - // and should resTbl/iterInTbl be reset for each outer row of Apply. - // Because we reset them when SQL is finished instead of when CTEExec.Close() is called. - isInApply bool + // Correlated Column. + corCols []*expression.CorrelatedColumn + corColHashCodes [][]byte } func (p *cteProducer) openProducer(ctx context.Context, cteExec *CTEExec) (err error) { @@ -223,11 +231,6 @@ func (p *cteProducer) closeProducer() (err error) { } } p.closed = true - if p.isInApply { - if err = p.reopenTbls(); err != nil { - return err - } - } return nil } @@ -656,3 +659,20 @@ func (p *cteProducer) checkHasDup(probeKey uint64, } return false, nil } + +func getCorColHashCode(corCol *expression.CorrelatedColumn) (res []byte) { + return codec.HashCode(res, *corCol.Data) +} + +// Return true if cor col has changed. +func (p *cteProducer) checkAndUpdateCorColHashCode() bool { + var changed bool + for i, corCol := range p.corCols { + newHashCode := getCorColHashCode(corCol) + if !bytes.Equal(newHashCode, p.corColHashCodes[i]) { + changed = true + p.corColHashCodes[i] = newHashCode + } + } + return changed +} diff --git a/executor/cte_test.go b/executor/cte_test.go index 0d34fcd5f177a..fd556e133b517 100644 --- a/executor/cte_test.go +++ b/executor/cte_test.go @@ -499,4 +499,9 @@ func TestCTEShareCorColumn(t *testing.T) { tk.MustQuery("with cte1 as (select t1.c1, (select t2.c2 from t2 where t2.c2 = str_to_date(t1.c2, '%Y-%m-%d')) from t1 inner join t2 on t1.c1 = t2.c1) select /*+ hash_join_build(alias1) */ * from cte1 alias1 inner join cte1 alias2 on alias1.c1 = alias2.c1;").Check(testkit.Rows("1 2020-10-10 1 2020-10-10")) tk.MustQuery("with cte1 as (select t1.c1, (select t2.c2 from t2 where t2.c2 = str_to_date(t1.c2, '%Y-%m-%d')) from t1 inner join t2 on t1.c1 = t2.c1) select /*+ hash_join_build(alias2) */ * from cte1 alias1 inner join cte1 alias2 on alias1.c1 = alias2.c1;").Check(testkit.Rows("1 2020-10-10 1 2020-10-10")) } + + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(a int);") + tk.MustExec("insert into t1 values(1), (2);") + tk.MustQuery("SELECT * FROM t1 dt WHERE EXISTS( WITH RECURSIVE qn AS (SELECT a AS b UNION ALL SELECT b+1 FROM qn WHERE b=0 or b = 1) SELECT * FROM qn dtqn1 where exists (select /*+ NO_DECORRELATE() */ b from qn where dtqn1.b+1));").Check(testkit.Rows("1", "2")) } diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index a4df1f01b47b0..117090b897107 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -8374,3 +8374,32 @@ func TestIssue43645(t *testing.T) { rs := tk.MustQuery("WITH tmp AS (SELECT t2.* FROM t2) select (SELECT tmp.col1 FROM tmp WHERE tmp.id=t1.id ) col1, (SELECT tmp.col2 FROM tmp WHERE tmp.id=t1.id ) col2, (SELECT tmp.col3 FROM tmp WHERE tmp.id=t1.id ) col3 from t1;") rs.Sort().Check(testkit.Rows("a aa aaa", "b bb bbb", "c cc ccc")) } + +func TestIssue45033(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2, t3, t4;") + tk.MustExec("create table t1 (c1 int, c2 int, c3 int, primary key(c1, c2));") + tk.MustExec("create table t2 (c2 int, c1 int, primary key(c2, c1));") + tk.MustExec("create table t3 (c4 int, key(c4));") + tk.MustExec("create table t4 (c2 varchar(20) , test_col varchar(50), gen_col varchar(50) generated always as(concat(test_col,'')) virtual not null, unique key(gen_col));") + tk.MustQuery(`select count(1) + from (select ( case + when count(1) + over( + partition by a.c2) >= 50 then 1 + else 0 + end ) alias1, + b.c2 as alias_col1 + from t1 a + left join (select c2 + from t4 f) k + on k.c2 = a.c2 + inner join t2 b + on b.c1 = a.c3) alias2 + where exists (select 1 + from (select distinct alias3.c4 as c2 + from t3 alias3) alias4 + where alias4.c2 = alias2.alias_col1);`).Check(testkit.Rows("0")) +} diff --git a/planner/core/rule_decorrelate.go b/planner/core/rule_decorrelate.go index 5d5fcef35c3e2..2781ea92fa8f2 100644 --- a/planner/core/rule_decorrelate.go +++ b/planner/core/rule_decorrelate.go @@ -105,6 +105,77 @@ func ExtractCorrelatedCols4PhysicalPlan(p PhysicalPlan) []*expression.Correlated return corCols } +// ExtractOuterApplyCorrelatedCols only extract the correlated columns whose corresponding Apply operator is outside the plan. +// For Plan-1, ExtractOuterApplyCorrelatedCols(CTE-1) will return cor_col_1. +// Plan-1: +// +// Apply_1 +// |_ outerSide +// |_CTEExec(CTE-1) +// +// CTE-1 +// |_Selection(cor_col_1) +// +// For Plan-2, the result of ExtractOuterApplyCorrelatedCols(CTE-2) will not return cor_col_3. +// Because Apply_3 is inside CTE-2. +// Plan-2: +// +// Apply_2 +// |_ outerSide +// |_ Selection(cor_col_2) +// |_CTEExec(CTE-2) +// CTE-2 +// |_ Apply_3 +// |_ outerSide +// |_ innerSide(cor_col_3) +func ExtractOuterApplyCorrelatedCols(p PhysicalPlan) []*expression.CorrelatedColumn { + return extractOuterApplyCorrelatedColsHelper(p, []*expression.Schema{}) +} + +func extractOuterApplyCorrelatedColsHelper(p PhysicalPlan, outerSchemas []*expression.Schema) []*expression.CorrelatedColumn { + if p == nil { + return nil + } + curCorCols := p.ExtractCorrelatedCols() + newCorCols := make([]*expression.CorrelatedColumn, 0, len(curCorCols)) + + // If a corresponding Apply is found inside this PhysicalPlan, ignore it. + for _, corCol := range curCorCols { + var found bool + for _, outerSchema := range outerSchemas { + if outerSchema.ColumnIndex(&corCol.Column) != -1 { + found = true + break + } + } + if !found { + newCorCols = append(newCorCols, corCol) + } + } + + switch v := p.(type) { + case *PhysicalApply: + var outerPlan PhysicalPlan + if v.InnerChildIdx == 0 { + outerPlan = v.Children()[1] + } else { + outerPlan = v.Children()[0] + } + outerSchemas = append(outerSchemas, outerPlan.Schema()) + newCorCols = append(newCorCols, extractOuterApplyCorrelatedColsHelper(v.Children()[0], outerSchemas)...) + newCorCols = append(newCorCols, extractOuterApplyCorrelatedColsHelper(v.Children()[1], outerSchemas)...) + case *PhysicalCTE: + newCorCols = append(newCorCols, extractOuterApplyCorrelatedColsHelper(v.SeedPlan, outerSchemas)...) + newCorCols = append(newCorCols, extractOuterApplyCorrelatedColsHelper(v.RecurPlan, outerSchemas)...) + default: + for _, child := range p.Children() { + newCorCols = append(newCorCols, extractOuterApplyCorrelatedColsHelper(child, outerSchemas)...) + } + } + + return newCorCols +} + // decorrelateSolver tries to convert apply plan to join plan. type decorrelateSolver struct{}