From e78dac5754dad5e2ada0c4ecb84668b56f8d750d Mon Sep 17 00:00:00 2001 From: likzn <1020193211@qq.com> Date: Fri, 6 May 2022 15:24:56 +0800 Subject: [PATCH 1/6] cherry pick #34133 to release-5.3 Signed-off-by: ti-srebot --- bindinfo/handle.go | 3 +- cmd/explaintest/r/cte.result | 193 ++++++++++++++++++++++++++++++++ cmd/explaintest/t/cte.test | 91 +++++++++++++++ parser/ast/dml.go | 56 ++++++++- parser/parser.go | 3 + parser/parser.y | 3 + planner/core/preprocess.go | 47 +++++++- planner/core/preprocess_test.go | 66 +++++++++++ 8 files changed, 451 insertions(+), 11 deletions(-) diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 8d22b85156f06..836e42a445002 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -863,8 +863,7 @@ func GenerateBindSQL(ctx context.Context, stmtNode ast.StmtNode, planHint string withIdx := strings.Index(bindSQL, "WITH") restoreCtx := format.NewRestoreCtx(format.RestoreStringSingleQuotes|format.RestoreSpacesAroundBinaryOperation|format.RestoreStringWithoutCharset|format.RestoreNameBackQuotes, &withSb) restoreCtx.DefaultDB = defaultDB - err := n.With.Restore(restoreCtx) - if err != nil { + if err := n.With.Restore(restoreCtx); err != nil { logutil.BgLogger().Debug("[sql-bind] restore SQL failed", zap.Error(err)) return "" } diff --git a/cmd/explaintest/r/cte.result b/cmd/explaintest/r/cte.result index 4f3d979c45001..e902c0d9317c3 100644 --- a/cmd/explaintest/r/cte.result +++ b/cmd/explaintest/r/cte.result @@ -607,3 +607,196 @@ c1 c1 c1 1 1 1 2 2 2 3 3 3 +<<<<<<< HEAD +======= +// Test CTE as inner side of Apply +drop table if exists t1, t2; +create table t1(c1 int, c2 int); +insert into t1 values(2, 1); +insert into t1 values(2, 2); +create table t2(c1 int, c2 int); +insert into t2 values(1, 1); +insert into t2 values(3, 2); +explain select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +id estRows task access object operator info +Projection_18 10000.00 root test.t1.c1, test.t1.c2 +└─Apply_20 10000.00 root CARTESIAN inner join, other cond:or(and(gt(test.t1.c1, Column#8), if(ne(Column#9, 0), NULL, 1)), or(eq(Column#10, 0), if(isnull(test.t1.c1), NULL, 0))) + ├─TableReader_22(Build) 10000.00 root data:TableFullScan_21 + │ └─TableFullScan_21 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo + └─HashAgg_23(Probe) 1.00 root funcs:max(Column#13)->Column#8, funcs:sum(Column#14)->Column#9, funcs:count(1)->Column#10 + └─Projection_27 10.00 root test.t2.c1, cast(isnull(test.t2.c1), decimal(20,0) BINARY)->Column#14 + └─CTEFullScan_25 10.00 root CTE:cte1 data:CTE_0 +CTE_0 10.00 root Non-Recursive CTE +└─Projection_13(Seed Part) 10.00 root test.t2.c1 + └─TableReader_16 10.00 root data:Selection_15 + └─Selection_15 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2) + └─TableFullScan_14 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo +select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +c1 c2 +2 1 +// Test semi apply. +insert into t1 values(2, 3); +explain select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +id estRows task access object operator info +Apply_17 10000.00 root CARTESIAN semi join +├─TableReader_19(Build) 10000.00 root data:TableFullScan_18 +│ └─TableFullScan_18 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─CTEFullScan_20(Probe) 10.00 root CTE:cte1 data:CTE_0 +CTE_0 10.00 root Non-Recursive CTE +└─Projection_11(Seed Part) 10.00 root test.t2.c1 + └─TableReader_14 10.00 root data:Selection_13 + └─Selection_13 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2) + └─TableFullScan_12 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo +select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +c1 c2 +2 1 +2 2 +// Same as above, but test recursive cte. +explain select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1); +id estRows task access object operator info +Projection_26 10000.00 root test.t1.c1, test.t1.c2 +└─Apply_28 10000.00 root CARTESIAN inner join, other cond:or(and(gt(test.t1.c1, Column#14), if(ne(Column#15, 0), NULL, 1)), or(eq(Column#16, 0), if(isnull(test.t1.c1), NULL, 0))) + ├─TableReader_30(Build) 10000.00 root data:TableFullScan_29 + │ └─TableFullScan_29 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo + └─HashAgg_31(Probe) 1.00 root funcs:max(Column#19)->Column#14, funcs:sum(Column#20)->Column#15, funcs:count(1)->Column#16 + └─Projection_35 20.00 root test.t2.c1, cast(isnull(test.t2.c1), decimal(20,0) BINARY)->Column#20 + └─CTEFullScan_33 20.00 root CTE:cte1 data:CTE_0 +CTE_0 20.00 root Recursive CTE, limit(offset:0, count:1) +├─Projection_19(Seed Part) 10.00 root test.t2.c1 +│ └─TableReader_22 10.00 root data:Selection_21 +│ └─Selection_21 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2) +│ └─TableFullScan_20 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo +└─Projection_23(Recursive Part) 10.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1 + └─CTETable_24 10.00 root Scan on CTE_0 +select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1); +c1 c2 +2 1 +2 3 +explain select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1); +id estRows task access object operator info +Apply_25 10000.00 root CARTESIAN semi join +├─TableReader_27(Build) 10000.00 root data:TableFullScan_26 +│ └─TableFullScan_26 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─CTEFullScan_28(Probe) 20.00 root CTE:cte1 data:CTE_0 +CTE_0 20.00 root Recursive CTE, limit(offset:0, count:10) +├─Projection_17(Seed Part) 10.00 root test.t2.c1 +│ └─TableReader_20 10.00 root data:Selection_19 +│ └─Selection_19 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2) +│ └─TableFullScan_18 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo +└─Projection_21(Recursive Part) 10.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1 + └─CTETable_22 10.00 root Scan on CTE_0 +select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1); +c1 c2 +2 1 +2 2 +// Test correlated col is in recursive part. +explain select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +id estRows task access object operator info +Projection_24 10000.00 root test.t1.c1, test.t1.c2 +└─Apply_26 10000.00 root CARTESIAN inner join, other cond:or(and(gt(test.t1.c1, Column#18), if(ne(Column#19, 0), NULL, 1)), or(eq(Column#20, 0), if(isnull(test.t1.c1), NULL, 0))) + ├─TableReader_28(Build) 10000.00 root data:TableFullScan_27 + │ └─TableFullScan_27 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo + └─HashAgg_29(Probe) 1.00 root funcs:max(Column#23)->Column#18, funcs:sum(Column#24)->Column#19, funcs:count(1)->Column#20 + └─Projection_33 18000.00 root test.t2.c1, cast(isnull(test.t2.c1), decimal(20,0) BINARY)->Column#24 + └─CTEFullScan_31 18000.00 root CTE:cte1 data:CTE_0 +CTE_0 18000.00 root Recursive CTE +├─TableReader_19(Seed Part) 10000.00 root data:TableFullScan_18 +│ └─TableFullScan_18 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo +└─Projection_20(Recursive Part) 8000.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1, cast(plus(test.t2.c2, 1), int(11))->test.t2.c2 + └─Selection_21 8000.00 root eq(test.t2.c2, test.t1.c2) + └─CTETable_22 10000.00 root Scan on CTE_0 +select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +c1 c2 +explain select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +id estRows task access object operator info +Apply_23 10000.00 root CARTESIAN semi join +├─TableReader_25(Build) 10000.00 root data:TableFullScan_24 +│ └─TableFullScan_24 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─CTEFullScan_26(Probe) 18000.00 root CTE:cte1 data:CTE_0 +CTE_0 18000.00 root Recursive CTE +├─TableReader_17(Seed Part) 10000.00 root data:TableFullScan_16 +│ └─TableFullScan_16 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo +└─Projection_18(Recursive Part) 8000.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1, cast(plus(test.t2.c2, 1), int(11))->test.t2.c2 + └─Selection_19 8000.00 root eq(test.t2.c2, test.t1.c2) + └─CTETable_20 10000.00 root Scan on CTE_0 +select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +c1 c2 +2 1 +2 2 +2 3 +use test; +drop table if exists t1, t2; +drop view if exists v1; +create table t1 (a int); +insert into t1 values (0), (1), (2), (3), (4); +create table t2 (a int); +insert into t2 values (1), (2), (3), (4), (5); +drop view if exists v1,v2; +create view v1 as with t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create view v2 as with recursive t1 as ( select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create database if not exists test1; +use test1; +select * from test.v1; +a +5 +4 +3 +2 +select * from test.v2; +a +6 +5 +4 +3 +use test; +drop table if exists t ,t1, t2; +create table t(a int); +insert into t values (0); +create table t1 (b int); +insert into t1 values (0); +create table t2 (c int); +insert into t2 values (0); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; +b c +0 0 +use test; +drop table if exists t11111; +create table t11111 (d int); +insert into t11111 values (123), (223), (323); +drop view if exists v1; +create view v1 as WITH t123 AS (WITH t11111 AS ( SELECT * FROM t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, t2 ) SELECT * FROM t11111; +use test1; +select * from test.v1; +d +123 +223 +323 +use test; +drop table if exists t1; +create table t1 (a int); +insert into t1 values (1); +drop view if exists v1; +create view v1 as SELECT (WITH qn AS (SELECT 10*a as a FROM t1),qn2 AS (SELECT 3*a AS b FROM qn) SELECT * from qn2 LIMIT 1) FROM t1; +use test1; +select * from test.v1; +name_exp_1 +30 +use test; +drop table if exists t1,t2; +create table t1 (a int); +insert into t1 values (0), (1); +create table t2 (b int); +insert into t2 values (4), (5); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t1) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; +a b +0 5 +0 4 +1 5 +1 4 +>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) diff --git a/cmd/explaintest/t/cte.test b/cmd/explaintest/t/cte.test index b5fda97071cc8..4c62c44ca4438 100644 --- a/cmd/explaintest/t/cte.test +++ b/cmd/explaintest/t/cte.test @@ -226,3 +226,94 @@ create table tpk1(c1 int primary key); insert into tpk1 values(1), (2), (3); explain with cte1 as (select c1 from tpk) select /*+ merge_join(dt1, dt2) */ * from tpk1 dt1 inner join cte1 dt2 inner join cte1 dt3 on dt1.c1 = dt2.c1 and dt2.c1 = dt3.c1; with cte1 as (select c1 from tpk) select /*+ merge_join(dt1, dt2) */ * from tpk1 dt1 inner join cte1 dt2 inner join cte1 dt3 on dt1.c1 = dt2.c1 and dt2.c1 = dt3.c1; +<<<<<<< HEAD +======= +#case 34 +--echo // Test CTE as inner side of Apply +drop table if exists t1, t2; +create table t1(c1 int, c2 int); +insert into t1 values(2, 1); +insert into t1 values(2, 2); +create table t2(c1 int, c2 int); +insert into t2 values(1, 1); +insert into t2 values(3, 2); +explain select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); + +--echo // Test semi apply. +insert into t1 values(2, 3); +explain select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); + +--echo // Same as above, but test recursive cte. +explain select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1); +select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1); + +explain select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1); +select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1); + +--echo // Test correlated col is in recursive part. +explain select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); + +explain select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +# Some cases to Test Create View With CTE and checkout Database +# With name is the same as the table name +use test; +drop table if exists t1, t2; +drop view if exists v1; +create table t1 (a int); +insert into t1 values (0), (1), (2), (3), (4); +create table t2 (a int); +insert into t2 values (1), (2), (3), (4), (5); +drop view if exists v1,v2; +create view v1 as with t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create view v2 as with recursive t1 as ( select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create database if not exists test1; +use test1; +select * from test.v1; +select * from test.v2; +# case +use test; +drop table if exists t ,t1, t2; +create table t(a int); +insert into t values (0); +create table t1 (b int); +insert into t1 values (0); +create table t2 (c int); +insert into t2 values (0); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; +# case +use test; +drop table if exists t11111; +create table t11111 (d int); +insert into t11111 values (123), (223), (323); +drop view if exists v1; +create view v1 as WITH t123 AS (WITH t11111 AS ( SELECT * FROM t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, t2 ) SELECT * FROM t11111; +use test1; +select * from test.v1; +# case +use test; +drop table if exists t1; +create table t1 (a int); +insert into t1 values (1); +drop view if exists v1; +create view v1 as SELECT (WITH qn AS (SELECT 10*a as a FROM t1),qn2 AS (SELECT 3*a AS b FROM qn) SELECT * from qn2 LIMIT 1) FROM t1; +use test1; +select * from test.v1; +# case +use test; +drop table if exists t1,t2; +create table t1 (a int); +insert into t1 values (0), (1); +create table t2 (b int); +insert into t2 values (4), (5); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t1) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; +>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) diff --git a/parser/ast/dml.go b/parser/ast/dml.go index 2349602da70ed..1b7a29b01b0be 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -1046,6 +1046,51 @@ type CommonTableExpression struct { Name model.CIStr Query *SubqueryExpr ColNameList []model.CIStr + IsRecursive bool +} + +// Restore implements Node interface +func (c *CommonTableExpression) Restore(ctx *format.RestoreCtx) error { + ctx.WriteName(c.Name.String()) + if c.IsRecursive { + // If the CTE is recursive, we should make it visible for the CTE's query. + // Otherwise, we should put it to stack after building the CTE's query. + ctx.RecordCTEName(c.Name.L) + } + if len(c.ColNameList) > 0 { + ctx.WritePlain(" (") + for j, name := range c.ColNameList { + if j != 0 { + ctx.WritePlain(", ") + } + ctx.WriteName(name.String()) + } + ctx.WritePlain(")") + } + ctx.WriteKeyWord(" AS ") + err := c.Query.Restore(ctx) + if err != nil { + return err + } + if !c.IsRecursive { + ctx.RecordCTEName(c.Name.L) + } + return nil +} + +// Accept implements Node interface +func (c *CommonTableExpression) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(c) + if skipChildren { + return v.Leave(newNode) + } + + node, ok := c.Query.Accept(v) + if !ok { + return c, false + } + c.Query = node.(*SubqueryExpr) + return v.Leave(c) } type WithClause struct { @@ -1113,6 +1158,7 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error { if i != 0 { ctx.WritePlain(", ") } +<<<<<<< HEAD ctx.WriteName(cte.Name.String()) if n.IsRecursive { // If the CTE is recursive, we should make it visible for the CTE's query. @@ -1137,6 +1183,11 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error { if !n.IsRecursive { ctx.CTENames = append(ctx.CTENames, cte.Name.L) } +======= + if err := cte.Restore(ctx); err != nil { + return err + } +>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) } ctx.WritePlain(" ") return nil @@ -1149,11 +1200,9 @@ func (n *WithClause) Accept(v Visitor) (Node, bool) { } for _, cte := range n.CTEs { - node, ok := cte.Query.Accept(v) - if !ok { + if _, ok := cte.Accept(v); !ok { return n, false } - cte.Query = node.(*SubqueryExpr) } return v.Leave(n) } @@ -1177,6 +1226,7 @@ func (n *SelectStmt) Restore(ctx *format.RestoreCtx) error { }() } if !n.WithBeforeBraces && n.With != nil { + defer ctx.RestoreCTEFunc()() err := n.With.Restore(ctx) if err != nil { return err diff --git a/parser/parser.go b/parser/parser.go index 82442bf2a19d3..dfc478f0065fb 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -16776,6 +16776,9 @@ yynewstate: { ws := yyS[yypt-0].item.(*ast.WithClause) ws.IsRecursive = true + for _, cte := range ws.CTEs { + cte.IsRecursive = true + } parser.yyVAL.item = ws } case 1506: diff --git a/parser/parser.y b/parser/parser.y index da9359a3e73ad..78894f442c81f 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -8441,6 +8441,9 @@ WithClause: { ws := $3.(*ast.WithClause) ws.IsRecursive = true + for _, cte := range ws.CTEs { + cte.IsRecursive = true + } $$ = ws } diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 440c96bf915e8..a42663fbb6c8c 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -106,7 +106,16 @@ func TryAddExtraLimit(ctx sessionctx.Context, node ast.StmtNode) ast.StmtNode { // Preprocess resolves table names of the node, and checks some statements validation. // preprocessReturn used to extract the infoschema for the tableName and the timestamp from the asof clause. func Preprocess(ctx sessionctx.Context, node ast.Node, preprocessOpt ...PreprocessOpt) error { +<<<<<<< HEAD v := preprocessor{ctx: ctx, tableAliasInJoin: make([]map[string]interface{}, 0), withName: make(map[string]interface{})} +======= + v := preprocessor{ + ctx: ctx, + tableAliasInJoin: make([]map[string]interface{}, 0), + preprocessWith: &preprocessWith{cteCanUsed: make([]string, 0), cteBeforeOffset: make([]int, 0)}, + staleReadProcessor: staleread.NewStaleReadProcessor(ctx), + } +>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) for _, optFn := range preprocessOpt { optFn(&v) } @@ -159,6 +168,12 @@ type PreprocessExecuteISUpdate struct { Node ast.Node } +// preprocessWith is used to record info from WITH statements like CTE name. +type preprocessWith struct { + cteCanUsed []string + cteBeforeOffset []int +} + // preprocessor is an ast.Visitor that preprocess // ast Nodes parsed from parser. type preprocessor struct { @@ -170,7 +185,7 @@ type preprocessor struct { // tableAliasInJoin is a stack that keeps the table alias names for joins. // len(tableAliasInJoin) may bigger than 1 because the left/right child of join may be subquery that contains `JOIN` tableAliasInJoin []map[string]interface{} - withName map[string]interface{} + preprocessWith *preprocessWith // values that may be returned *PreprocessorReturn @@ -295,9 +310,12 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { } case *ast.GroupByClause: p.checkGroupBy(node) - case *ast.WithClause: - for _, cte := range node.CTEs { - p.withName[cte.Name.L] = struct{}{} + case *ast.CommonTableExpression, *ast.SubqueryExpr: + with := p.preprocessWith + beforeOffset := len(with.cteCanUsed) + with.cteBeforeOffset = append(with.cteBeforeOffset, beforeOffset) + if cteNode, exist := node.(*ast.CommonTableExpression); exist && cteNode.IsRecursive { + with.cteCanUsed = append(with.cteCanUsed, cteNode.Name.L) } default: p.flag &= ^parentIsJoin @@ -525,6 +543,19 @@ func (p *preprocessor) Leave(in ast.Node) (out ast.Node, ok bool) { if x.Kind == ast.BRIEKindRestore { p.flag &= ^inCreateOrDropTable } + case *ast.CommonTableExpression, *ast.SubqueryExpr: + with := p.preprocessWith + lenWithCteBeforeOffset := len(with.cteBeforeOffset) + if lenWithCteBeforeOffset < 1 { + p.err = ErrInternal.GenWithStack("len(cteBeforeOffset) is less than one.Maybe it was deleted in somewhere.Should match in Enter and Leave") + break + } + beforeOffset := with.cteBeforeOffset[lenWithCteBeforeOffset-1] + with.cteBeforeOffset = with.cteBeforeOffset[:lenWithCteBeforeOffset-1] + with.cteCanUsed = with.cteCanUsed[:beforeOffset] + if cteNode, exist := x.(*ast.CommonTableExpression); exist { + with.cteCanUsed = append(with.cteCanUsed, cteNode.Name.L) + } } return in, p.err == nil @@ -1367,8 +1398,11 @@ func (p *preprocessor) stmtType() string { func (p *preprocessor) handleTableName(tn *ast.TableName) { if tn.Schema.L == "" { - if _, ok := p.withName[tn.Name.L]; ok { - return + + for _, cte := range p.preprocessWith.cteCanUsed { + if cte == tn.Name.L { + return + } } currentDB := p.ctx.GetSessionVars().CurrentDB @@ -1376,6 +1410,7 @@ func (p *preprocessor) handleTableName(tn *ast.TableName) { p.err = errors.Trace(ErrNoDB) return } + tn.Schema = model.NewCIStr(currentDB) } diff --git a/planner/core/preprocess_test.go b/planner/core/preprocess_test.go index ff25091b4eeb8..5726fc0a8020c 100644 --- a/planner/core/preprocess_test.go +++ b/planner/core/preprocess_test.go @@ -15,7 +15,12 @@ package core_test import ( +<<<<<<< HEAD "context" +======= + "strings" + "testing" +>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) . "github.com/pingcap/check" "github.com/pingcap/errors" @@ -26,6 +31,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/format" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" @@ -364,3 +370,63 @@ func (s *testValidatorSuite) TestDropGlobalTempTable(c *C) { s.runSQL(c, "drop global temporary table temp, ltemp1", false, core.ErrDropTableOnTemporaryTable) s.runSQL(c, "drop global temporary table test2.temp2, temp1", false, nil) } + +func TestPreprocessCTE(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t, t1, t2;") + tk.MustExec("create table t (c int);insert into t values (1), (2), (3), (4), (5);") + tk.MustExec("create table t1 (a int);insert into t1 values (0), (1), (2), (3), (4);") + tk.MustExec("create table t2 (b int);insert into t2 values (1), (2), (3), (4), (5);") + tk.MustExec("create table t11111 (d int);insert into t11111 values (1), (2), (3), (4), (5);") + tk.MustExec("drop table if exists tbl_1;\nCREATE TABLE `tbl_1` (\n `col_2` char(65) CHARACTER SET utf8 COLLATE utf8_bin DEFAULT NULL,\n `col_3` int(11) NOT NULL\n);") + testCases := []struct { + before string + after string + }{ + { + "create view v1 as WITH t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1;", + "CREATE ALGORITHM = UNDEFINED DEFINER = CURRENT_USER SQL SECURITY DEFINER VIEW `test`.`v1` AS WITH `t1` AS (SELECT `a` FROM `test`.`t2` WHERE `t2`.`a`=3 UNION SELECT `t2`.`a`+1 FROM (`test`.`t1`) JOIN `test`.`t2` WHERE `t1`.`a`=`t2`.`a`) SELECT * FROM `t1`", + }, + { + "WITH t1 AS ( SELECT(WITH t1 AS ( WITH qn AS ( SELECT 10 * a AS a FROM t1 ) SELECT 10 * a AS a FROM qn ) SELECT * FROM t1 LIMIT 1 ) FROM t2 WHERE t2.b = 3 UNION SELECT t2.b + 1 FROM t1, t2 WHERE t1.a = t2.b) SELECT * FROM t1", + "WITH `t1` AS (SELECT (WITH `t1` AS (WITH `qn` AS (SELECT 10*`a` AS `a` FROM `test`.`t1`) SELECT 10*`a` AS `a` FROM `qn`) SELECT * FROM `t1` LIMIT 1) FROM `test`.`t2` WHERE `t2`.`b`=3 UNION SELECT `t2`.`b`+1 FROM (`test`.`t1`) JOIN `test`.`t2` WHERE `t1`.`a`=`t2`.`b`) SELECT * FROM `t1`", + }, + { + "with recursive cte_8932 (col_34891,col_34892) AS ( with recursive cte_8932 (col_34893,col_34894,col_34895) AS ( with tbl_1 (col_34896,col_34897,col_34898,col_34899) AS ( select 1, \"2\",3,col_3 from tbl_1 ) select cte_as_8958.col_34896,cte_as_8958.col_34898,cte_as_8958.col_34899 from tbl_1 as cte_as_8958 UNION DISTINCT select col_34893 + 1,concat(col_34894, 1),col_34895 + 1 from cte_8932 where col_34893 < 5 ) select cte_as_8959.col_34893,cte_as_8959.col_34895 from cte_8932 as cte_as_8959 ) select * from cte_8932 as cte_as_8960 order by cte_as_8960.col_34891,cte_as_8960.col_34892;", + "WITH RECURSIVE `cte_8932` (`col_34891`, `col_34892`) AS (WITH RECURSIVE `cte_8932` (`col_34893`, `col_34894`, `col_34895`) AS (WITH `tbl_1` (`col_34896`, `col_34897`, `col_34898`, `col_34899`) AS (SELECT 1,_UTF8MB4'2',3,`col_3` FROM `test`.`tbl_1`) SELECT `cte_as_8958`.`col_34896`,`cte_as_8958`.`col_34898`,`cte_as_8958`.`col_34899` FROM `tbl_1` AS `cte_as_8958` UNION SELECT `col_34893`+1,CONCAT(`col_34894`, 1),`col_34895`+1 FROM `cte_8932` WHERE `col_34893`<5) SELECT `cte_as_8959`.`col_34893`,`cte_as_8959`.`col_34895` FROM `cte_8932` AS `cte_as_8959`) SELECT * FROM `cte_8932` AS `cte_as_8960` ORDER BY `cte_as_8960`.`col_34891`,`cte_as_8960`.`col_34892`", + }, + { + "with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1;", + "WITH `t1` AS (WITH `t11` AS (SELECT * FROM `test`.`t`) SELECT * FROM (`test`.`t1`) JOIN `test`.`t2`) SELECT * FROM `t1`", + }, + { + "with t1 as (with t1 as (select * from t) select * from t1, t2) select * from t1;", + "WITH `t1` AS (WITH `t1` AS (SELECT * FROM `test`.`t`) SELECT * FROM (`t1`) JOIN `test`.`t2`) SELECT * FROM `t1`", + }, + { + "WITH t1 AS ( WITH t1 AS ( SELECT * FROM t ) SELECT ( WITH t2 AS ( SELECT * FROM t ) SELECT * FROM t limit 1 ) FROM t1, t2 ) \n\nSELECT\n* \nFROM\n\tt1;", + "WITH `t1` AS (WITH `t1` AS (SELECT * FROM `test`.`t`) SELECT (WITH `t2` AS (SELECT * FROM `test`.`t`) SELECT * FROM `test`.`t` LIMIT 1) FROM (`t1`) JOIN `test`.`t2`) SELECT * FROM `t1`", + }, + { + "WITH t123 AS (WITH t11111 AS ( SELECT * FROM test.t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, test.t2 ) SELECT * FROM t11111;", + "WITH `t123` AS (WITH `t11111` AS (SELECT * FROM `test`.`t1`) SELECT (WITH `t2` AS (SELECT (WITH `t23` AS (SELECT * FROM `t11111`) SELECT * FROM `t23` LIMIT 1) FROM `t11111`) SELECT * FROM `t2` LIMIT 1) FROM (`t11111`) JOIN `test`.`t2`) SELECT * FROM `test`.`t11111`", + }, + } + for _, tc := range testCases { + stmts, warnings, err := parser.New().ParseSQL(tc.before) + require.Len(t, warnings, 0) + require.NoError(t, err) + require.Len(t, stmts, 1) + + err = core.Preprocess(tk.Session(), stmts[0]) + require.NoError(t, err) + + var rs strings.Builder + err = stmts[0].Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &rs)) + require.NoError(t, err) + require.Equal(t, tc.after, rs.String()) + } +} From 1fae9149294a3c1af1b5b2e66568e56341f55928 Mon Sep 17 00:00:00 2001 From: xhe Date: Wed, 10 Aug 2022 17:16:19 +0800 Subject: [PATCH 2/6] *: fix CI Signed-off-by: xhe --- cmd/explaintest/r/cte.result | 3 --- cmd/explaintest/t/cte.test | 3 --- parser/ast/dml.go | 27 --------------------------- planner/core/preprocess.go | 11 +++-------- planner/core/preprocess_test.go | 3 --- 5 files changed, 3 insertions(+), 44 deletions(-) diff --git a/cmd/explaintest/r/cte.result b/cmd/explaintest/r/cte.result index e902c0d9317c3..1ac88226b7cec 100644 --- a/cmd/explaintest/r/cte.result +++ b/cmd/explaintest/r/cte.result @@ -607,8 +607,6 @@ c1 c1 c1 1 1 1 2 2 2 3 3 3 -<<<<<<< HEAD -======= // Test CTE as inner side of Apply drop table if exists t1, t2; create table t1(c1 int, c2 int); @@ -799,4 +797,3 @@ a b 0 4 1 5 1 4 ->>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) diff --git a/cmd/explaintest/t/cte.test b/cmd/explaintest/t/cte.test index 4c62c44ca4438..0c1623caa81dd 100644 --- a/cmd/explaintest/t/cte.test +++ b/cmd/explaintest/t/cte.test @@ -226,8 +226,6 @@ create table tpk1(c1 int primary key); insert into tpk1 values(1), (2), (3); explain with cte1 as (select c1 from tpk) select /*+ merge_join(dt1, dt2) */ * from tpk1 dt1 inner join cte1 dt2 inner join cte1 dt3 on dt1.c1 = dt2.c1 and dt2.c1 = dt3.c1; with cte1 as (select c1 from tpk) select /*+ merge_join(dt1, dt2) */ * from tpk1 dt1 inner join cte1 dt2 inner join cte1 dt3 on dt1.c1 = dt2.c1 and dt2.c1 = dt3.c1; -<<<<<<< HEAD -======= #case 34 --echo // Test CTE as inner side of Apply drop table if exists t1, t2; @@ -316,4 +314,3 @@ drop view if exists v1; create view v1 as with t1 as (with t11 as (select * from t1) select * from t1, t2) select * from t1; use test1; select * from test.v1; ->>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) diff --git a/parser/ast/dml.go b/parser/ast/dml.go index 1b7a29b01b0be..1c237aa903716 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -1158,36 +1158,9 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error { if i != 0 { ctx.WritePlain(", ") } -<<<<<<< HEAD - ctx.WriteName(cte.Name.String()) - if n.IsRecursive { - // If the CTE is recursive, we should make it visible for the CTE's query. - // Otherwise, we should put it to stack after building the CTE's query. - ctx.CTENames = append(ctx.CTENames, cte.Name.L) - } - if len(cte.ColNameList) > 0 { - ctx.WritePlain(" (") - for j, name := range cte.ColNameList { - if j != 0 { - ctx.WritePlain(", ") - } - ctx.WriteName(name.String()) - } - ctx.WritePlain(")") - } - ctx.WriteKeyWord(" AS ") - err := cte.Query.Restore(ctx) - if err != nil { - return err - } - if !n.IsRecursive { - ctx.CTENames = append(ctx.CTENames, cte.Name.L) - } -======= if err := cte.Restore(ctx); err != nil { return err } ->>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) } ctx.WritePlain(" ") return nil diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index a42663fbb6c8c..140d912e903ea 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -106,16 +106,11 @@ func TryAddExtraLimit(ctx sessionctx.Context, node ast.StmtNode) ast.StmtNode { // Preprocess resolves table names of the node, and checks some statements validation. // preprocessReturn used to extract the infoschema for the tableName and the timestamp from the asof clause. func Preprocess(ctx sessionctx.Context, node ast.Node, preprocessOpt ...PreprocessOpt) error { -<<<<<<< HEAD - v := preprocessor{ctx: ctx, tableAliasInJoin: make([]map[string]interface{}, 0), withName: make(map[string]interface{})} -======= v := preprocessor{ - ctx: ctx, - tableAliasInJoin: make([]map[string]interface{}, 0), - preprocessWith: &preprocessWith{cteCanUsed: make([]string, 0), cteBeforeOffset: make([]int, 0)}, - staleReadProcessor: staleread.NewStaleReadProcessor(ctx), + ctx: ctx, + tableAliasInJoin: make([]map[string]interface{}, 0), + preprocessWith: &preprocessWith{cteCanUsed: make([]string, 0), cteBeforeOffset: make([]int, 0)}, } ->>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) for _, optFn := range preprocessOpt { optFn(&v) } diff --git a/planner/core/preprocess_test.go b/planner/core/preprocess_test.go index 5726fc0a8020c..7ad9e94617f09 100644 --- a/planner/core/preprocess_test.go +++ b/planner/core/preprocess_test.go @@ -15,12 +15,9 @@ package core_test import ( -<<<<<<< HEAD "context" -======= "strings" "testing" ->>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) . "github.com/pingcap/check" "github.com/pingcap/errors" From 4718e991723ec45e8ea099b7c533507d1b889d00 Mon Sep 17 00:00:00 2001 From: xhe Date: Wed, 10 Aug 2022 17:45:53 +0800 Subject: [PATCH 3/6] *: fix CI Signed-off-by: xhe --- parser/ast/dml.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parser/ast/dml.go b/parser/ast/dml.go index 1c237aa903716..2378a925c2d5b 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -1055,7 +1055,7 @@ func (c *CommonTableExpression) Restore(ctx *format.RestoreCtx) error { if c.IsRecursive { // If the CTE is recursive, we should make it visible for the CTE's query. // Otherwise, we should put it to stack after building the CTE's query. - ctx.RecordCTEName(c.Name.L) + ctx.CTENames = append(ctx.CTENames, c.Name.L) } if len(c.ColNameList) > 0 { ctx.WritePlain(" (") @@ -1073,7 +1073,7 @@ func (c *CommonTableExpression) Restore(ctx *format.RestoreCtx) error { return err } if !c.IsRecursive { - ctx.RecordCTEName(c.Name.L) + ctx.CTENames = append(ctx.CTENames, c.Name.L) } return nil } From 2a3f21062d48692499ee44ed17466418743a48f3 Mon Sep 17 00:00:00 2001 From: xhe Date: Wed, 10 Aug 2022 17:54:16 +0800 Subject: [PATCH 4/6] *: fix CI Signed-off-by: xhe --- parser/ast/dml.go | 29 +++++++---------------------- parser/format/format.go | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/parser/ast/dml.go b/parser/ast/dml.go index 2378a925c2d5b..bd8e1a34554ef 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -1055,7 +1055,7 @@ func (c *CommonTableExpression) Restore(ctx *format.RestoreCtx) error { if c.IsRecursive { // If the CTE is recursive, we should make it visible for the CTE's query. // Otherwise, we should put it to stack after building the CTE's query. - ctx.CTENames = append(ctx.CTENames, c.Name.L) + ctx.RecordCTEName(c.Name.L) } if len(c.ColNameList) > 0 { ctx.WritePlain(" (") @@ -1073,7 +1073,7 @@ func (c *CommonTableExpression) Restore(ctx *format.RestoreCtx) error { return err } if !c.IsRecursive { - ctx.CTENames = append(ctx.CTENames, c.Name.L) + ctx.RecordCTEName(c.Name.L) } return nil } @@ -1183,10 +1183,7 @@ func (n *WithClause) Accept(v Visitor) (Node, bool) { // Restore implements Node interface. func (n *SelectStmt) Restore(ctx *format.RestoreCtx) error { if n.WithBeforeBraces { - l := len(ctx.CTENames) - defer func() { - ctx.CTENames = ctx.CTENames[:l] - }() + defer ctx.RestoreCtx()() err := n.With.Restore(ctx) if err != nil { return err @@ -1535,10 +1532,7 @@ type SetOprSelectList struct { // Restore implements Node interface. func (n *SetOprSelectList) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - l := len(ctx.CTENames) - defer func() { - ctx.CTENames = ctx.CTENames[:l] - }() + defer ctx.RestoreCtx()() if err := n.With.Restore(ctx); err != nil { return errors.Annotate(err, "An error occurred while restore SetOprSelectList.With") } @@ -1639,10 +1633,7 @@ func (*SetOprStmt) resultSet() {} // Restore implements Node interface. func (n *SetOprStmt) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - l := len(ctx.CTENames) - defer func() { - ctx.CTENames = ctx.CTENames[:l] - }() + defer ctx.RestoreCtx()() if err := n.With.Restore(ctx); err != nil { return errors.Annotate(err, "An error occurred while restore UnionStmt.With") } @@ -2224,10 +2215,7 @@ type DeleteStmt struct { // Restore implements Node interface. func (n *DeleteStmt) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - l := len(ctx.CTENames) - defer func() { - ctx.CTENames = ctx.CTENames[:l] - }() + defer ctx.RestoreCtx()() err := n.With.Restore(ctx) if err != nil { return err @@ -2388,10 +2376,7 @@ type UpdateStmt struct { // Restore implements Node interface. func (n *UpdateStmt) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - l := len(ctx.CTENames) - defer func() { - ctx.CTENames = ctx.CTENames[:l] - }() + defer ctx.RestoreCtx()() err := n.With.Restore(ctx) if err != nil { return err diff --git a/parser/format/format.go b/parser/format/format.go index ef003d6a78d6d..e763d6a8946d6 100644 --- a/parser/format/format.go +++ b/parser/format/format.go @@ -387,3 +387,20 @@ func (ctx *RestoreCtx) WritePlain(plainText string) { func (ctx *RestoreCtx) WritePlainf(format string, a ...interface{}) { fmt.Fprintf(ctx.In, format, a...) } + +// RecordCTEName records the CTE name. +func (c *RestoreCtx) RecordCTEName(nameL string) { + c.CTENames = append(c.CTENames, nameL) +} + +// RestoreCTEFunc is used to restore CTE. +func (c *RestoreCtx) RestoreCTEFunc() func() { + l := len(c.CTENames) + return func() { + if l == 0 { + c.CTENames = nil + } else { + c.CTENames = c.CTENames[:l] + } + } +} From a8967f3742b5695ddcc47736a01ca5e3807ff89a Mon Sep 17 00:00:00 2001 From: xhe Date: Wed, 10 Aug 2022 17:57:28 +0800 Subject: [PATCH 5/6] *: fix CI Signed-off-by: xhe --- parser/ast/dml.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/parser/ast/dml.go b/parser/ast/dml.go index bd8e1a34554ef..1df92c0d54661 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -1183,7 +1183,7 @@ func (n *WithClause) Accept(v Visitor) (Node, bool) { // Restore implements Node interface. func (n *SelectStmt) Restore(ctx *format.RestoreCtx) error { if n.WithBeforeBraces { - defer ctx.RestoreCtx()() + defer ctx.RestoreCTEFunc()() err := n.With.Restore(ctx) if err != nil { return err @@ -1532,7 +1532,7 @@ type SetOprSelectList struct { // Restore implements Node interface. func (n *SetOprSelectList) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - defer ctx.RestoreCtx()() + defer ctx.RestoreCTEFunc()() if err := n.With.Restore(ctx); err != nil { return errors.Annotate(err, "An error occurred while restore SetOprSelectList.With") } @@ -1633,7 +1633,7 @@ func (*SetOprStmt) resultSet() {} // Restore implements Node interface. func (n *SetOprStmt) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - defer ctx.RestoreCtx()() + defer ctx.RestoreCTEFunc()() if err := n.With.Restore(ctx); err != nil { return errors.Annotate(err, "An error occurred while restore UnionStmt.With") } @@ -2215,7 +2215,7 @@ type DeleteStmt struct { // Restore implements Node interface. func (n *DeleteStmt) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - defer ctx.RestoreCtx()() + defer ctx.RestoreCTEFunc()() err := n.With.Restore(ctx) if err != nil { return err @@ -2376,7 +2376,7 @@ type UpdateStmt struct { // Restore implements Node interface. func (n *UpdateStmt) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - defer ctx.RestoreCtx()() + defer ctx.RestoreCTEFunc()() err := n.With.Restore(ctx) if err != nil { return err From eba9ce9914166ed221718b1270909c506a83bb25 Mon Sep 17 00:00:00 2001 From: xhe Date: Wed, 10 Aug 2022 17:59:17 +0800 Subject: [PATCH 6/6] *: fix CI Signed-off-by: xhe --- planner/core/preprocess_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/planner/core/preprocess_test.go b/planner/core/preprocess_test.go index 7ad9e94617f09..6d162c733372d 100644 --- a/planner/core/preprocess_test.go +++ b/planner/core/preprocess_test.go @@ -35,8 +35,10 @@ import ( "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/testleak" + "github.com/stretchr/testify/require" ) var _ = Suite(&testValidatorSuite{})