From e0cbc243021ce3c024a46ccc8f91fe3f31ce73b8 Mon Sep 17 00:00:00 2001 From: likzn <1020193211@qq.com> Date: Fri, 6 May 2022 15:24:56 +0800 Subject: [PATCH] cherry pick #34133 to release-6.0 Signed-off-by: ti-srebot --- bindinfo/handle.go | 3 +- cmd/explaintest/r/cte.result | 75 +++++++++++++++++++++++++++++++++ cmd/explaintest/t/cte.test | 58 +++++++++++++++++++++++++ parser/ast/dml.go | 56 ++++++++++++++++++++++-- parser/parser.go | 3 ++ parser/parser.y | 3 ++ planner/core/preprocess.go | 40 +++++++++++++++--- planner/core/preprocess_test.go | 62 +++++++++++++++++++++++++++ 8 files changed, 288 insertions(+), 12 deletions(-) diff --git a/bindinfo/handle.go b/bindinfo/handle.go index da4034b853aa7..6cf836ff0e21f 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -961,8 +961,7 @@ func GenerateBindSQL(ctx context.Context, stmtNode ast.StmtNode, planHint string withIdx := strings.Index(bindSQL, "WITH") restoreCtx := format.NewRestoreCtx(format.RestoreStringSingleQuotes|format.RestoreSpacesAroundBinaryOperation|format.RestoreStringWithoutCharset|format.RestoreNameBackQuotes, &withSb) restoreCtx.DefaultDB = defaultDB - err := n.With.Restore(restoreCtx) - if err != nil { + if err := n.With.Restore(restoreCtx); err != nil { logutil.BgLogger().Debug("[sql-bind] restore SQL failed", zap.Error(err)) return "" } diff --git a/cmd/explaintest/r/cte.result b/cmd/explaintest/r/cte.result index 6c1da3d121e77..3c8737d7400b8 100644 --- a/cmd/explaintest/r/cte.result +++ b/cmd/explaintest/r/cte.result @@ -730,3 +730,78 @@ c1 c2 2 1 2 2 2 3 +use test; +drop table if exists t1, t2; +drop view if exists v1; +create table t1 (a int); +insert into t1 values (0), (1), (2), (3), (4); +create table t2 (a int); +insert into t2 values (1), (2), (3), (4), (5); +drop view if exists v1,v2; +create view v1 as with t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create view v2 as with recursive t1 as ( select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create database if not exists test1; +use test1; +select * from test.v1; +a +5 +4 +3 +2 +select * from test.v2; +a +6 +5 +4 +3 +use test; +drop table if exists t ,t1, t2; +create table t(a int); +insert into t values (0); +create table t1 (b int); +insert into t1 values (0); +create table t2 (c int); +insert into t2 values (0); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; +b c +0 0 +use test; +drop table if exists t11111; +create table t11111 (d int); +insert into t11111 values (123), (223), (323); +drop view if exists v1; +create view v1 as WITH t123 AS (WITH t11111 AS ( SELECT * FROM t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, t2 ) SELECT * FROM t11111; +use test1; +select * from test.v1; +d +123 +223 +323 +use test; +drop table if exists t1; +create table t1 (a int); +insert into t1 values (1); +drop view if exists v1; +create view v1 as SELECT (WITH qn AS (SELECT 10*a as a FROM t1),qn2 AS (SELECT 3*a AS b FROM qn) SELECT * from qn2 LIMIT 1) FROM t1; +use test1; +select * from test.v1; +name_exp_1 +30 +use test; +drop table if exists t1,t2; +create table t1 (a int); +insert into t1 values (0), (1); +create table t2 (b int); +insert into t2 values (4), (5); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t1) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; +a b +0 5 +0 4 +1 5 +1 4 diff --git a/cmd/explaintest/t/cte.test b/cmd/explaintest/t/cte.test index df9ee6c3b8d06..0c1623caa81dd 100644 --- a/cmd/explaintest/t/cte.test +++ b/cmd/explaintest/t/cte.test @@ -256,3 +256,61 @@ select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 un explain select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +# Some cases to Test Create View With CTE and checkout Database +# With name is the same as the table name +use test; +drop table if exists t1, t2; +drop view if exists v1; +create table t1 (a int); +insert into t1 values (0), (1), (2), (3), (4); +create table t2 (a int); +insert into t2 values (1), (2), (3), (4), (5); +drop view if exists v1,v2; +create view v1 as with t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create view v2 as with recursive t1 as ( select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create database if not exists test1; +use test1; +select * from test.v1; +select * from test.v2; +# case +use test; +drop table if exists t ,t1, t2; +create table t(a int); +insert into t values (0); +create table t1 (b int); +insert into t1 values (0); +create table t2 (c int); +insert into t2 values (0); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; +# case +use test; +drop table if exists t11111; +create table t11111 (d int); +insert into t11111 values (123), (223), (323); +drop view if exists v1; +create view v1 as WITH t123 AS (WITH t11111 AS ( SELECT * FROM t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, t2 ) SELECT * FROM t11111; +use test1; +select * from test.v1; +# case +use test; +drop table if exists t1; +create table t1 (a int); +insert into t1 values (1); +drop view if exists v1; +create view v1 as SELECT (WITH qn AS (SELECT 10*a as a FROM t1),qn2 AS (SELECT 3*a AS b FROM qn) SELECT * from qn2 LIMIT 1) FROM t1; +use test1; +select * from test.v1; +# case +use test; +drop table if exists t1,t2; +create table t1 (a int); +insert into t1 values (0), (1); +create table t2 (b int); +insert into t2 values (4), (5); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t1) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; diff --git a/parser/ast/dml.go b/parser/ast/dml.go index 5a8041fb3e939..6f3384f42afd7 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -1048,6 +1048,51 @@ type CommonTableExpression struct { Name model.CIStr Query *SubqueryExpr ColNameList []model.CIStr + IsRecursive bool +} + +// Restore implements Node interface +func (c *CommonTableExpression) Restore(ctx *format.RestoreCtx) error { + ctx.WriteName(c.Name.String()) + if c.IsRecursive { + // If the CTE is recursive, we should make it visible for the CTE's query. + // Otherwise, we should put it to stack after building the CTE's query. + ctx.RecordCTEName(c.Name.L) + } + if len(c.ColNameList) > 0 { + ctx.WritePlain(" (") + for j, name := range c.ColNameList { + if j != 0 { + ctx.WritePlain(", ") + } + ctx.WriteName(name.String()) + } + ctx.WritePlain(")") + } + ctx.WriteKeyWord(" AS ") + err := c.Query.Restore(ctx) + if err != nil { + return err + } + if !c.IsRecursive { + ctx.RecordCTEName(c.Name.L) + } + return nil +} + +// Accept implements Node interface +func (c *CommonTableExpression) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(c) + if skipChildren { + return v.Leave(newNode) + } + + node, ok := c.Query.Accept(v) + if !ok { + return c, false + } + c.Query = node.(*SubqueryExpr) + return v.Leave(c) } type WithClause struct { @@ -1117,6 +1162,7 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error { if i != 0 { ctx.WritePlain(", ") } +<<<<<<< HEAD ctx.WriteName(cte.Name.String()) if n.IsRecursive { // If the CTE is recursive, we should make it visible for the CTE's query. @@ -1141,6 +1187,11 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error { if !n.IsRecursive { ctx.CTENames = append(ctx.CTENames, cte.Name.L) } +======= + if err := cte.Restore(ctx); err != nil { + return err + } +>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) } ctx.WritePlain(" ") return nil @@ -1153,11 +1204,9 @@ func (n *WithClause) Accept(v Visitor) (Node, bool) { } for _, cte := range n.CTEs { - node, ok := cte.Query.Accept(v) - if !ok { + if _, ok := cte.Accept(v); !ok { return n, false } - cte.Query = node.(*SubqueryExpr) } return v.Leave(n) } @@ -1181,6 +1230,7 @@ func (n *SelectStmt) Restore(ctx *format.RestoreCtx) error { }() } if !n.WithBeforeBraces && n.With != nil { + defer ctx.RestoreCTEFunc()() err := n.With.Restore(ctx) if err != nil { return err diff --git a/parser/parser.go b/parser/parser.go index de79b35a51ed7..ddd8618209cf1 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -16698,6 +16698,9 @@ yynewstate: { ws := yyS[yypt-0].item.(*ast.WithClause) ws.IsRecursive = true + for _, cte := range ws.CTEs { + cte.IsRecursive = true + } parser.yyVAL.item = ws } case 1497: diff --git a/parser/parser.y b/parser/parser.y index 14304bf200bf7..983b2f55c65af 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -8363,6 +8363,9 @@ WithClause: { ws := $3.(*ast.WithClause) ws.IsRecursive = true + for _, cte := range ws.CTEs { + cte.IsRecursive = true + } $$ = ws } diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 6ffb6d142f608..7c297a9ebe2f2 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -116,7 +116,7 @@ func Preprocess(ctx sessionctx.Context, node ast.Node, preprocessOpt ...Preproce v := preprocessor{ ctx: ctx, tableAliasInJoin: make([]map[string]interface{}, 0), - withName: make(map[string]interface{}), + preprocessWith: &preprocessWith{cteCanUsed: make([]string, 0), cteBeforeOffset: make([]int, 0)}, staleReadProcessor: staleread.NewStaleReadProcessor(ctx), } for _, optFn := range preprocessOpt { @@ -176,6 +176,12 @@ type PreprocessExecuteISUpdate struct { Node ast.Node } +// preprocessWith is used to record info from WITH statements like CTE name. +type preprocessWith struct { + cteCanUsed []string + cteBeforeOffset []int +} + // preprocessor is an ast.Visitor that preprocess // ast Nodes parsed from parser. type preprocessor struct { @@ -187,7 +193,7 @@ type preprocessor struct { // tableAliasInJoin is a stack that keeps the table alias names for joins. // len(tableAliasInJoin) may bigger than 1 because the left/right child of join may be subquery that contains `JOIN` tableAliasInJoin []map[string]interface{} - withName map[string]interface{} + preprocessWith *preprocessWith staleReadProcessor staleread.Processor @@ -317,9 +323,12 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { } case *ast.GroupByClause: p.checkGroupBy(node) - case *ast.WithClause: - for _, cte := range node.CTEs { - p.withName[cte.Name.L] = struct{}{} + case *ast.CommonTableExpression, *ast.SubqueryExpr: + with := p.preprocessWith + beforeOffset := len(with.cteCanUsed) + with.cteBeforeOffset = append(with.cteBeforeOffset, beforeOffset) + if cteNode, exist := node.(*ast.CommonTableExpression); exist && cteNode.IsRecursive { + with.cteCanUsed = append(with.cteCanUsed, cteNode.Name.L) } case *ast.BeginStmt: // If the begin statement was like following: @@ -564,6 +573,19 @@ func (p *preprocessor) Leave(in ast.Node) (out ast.Node, ok bool) { if x.Kind == ast.BRIEKindRestore { p.flag &= ^inCreateOrDropTable } + case *ast.CommonTableExpression, *ast.SubqueryExpr: + with := p.preprocessWith + lenWithCteBeforeOffset := len(with.cteBeforeOffset) + if lenWithCteBeforeOffset < 1 { + p.err = ErrInternal.GenWithStack("len(cteBeforeOffset) is less than one.Maybe it was deleted in somewhere.Should match in Enter and Leave") + break + } + beforeOffset := with.cteBeforeOffset[lenWithCteBeforeOffset-1] + with.cteBeforeOffset = with.cteBeforeOffset[:lenWithCteBeforeOffset-1] + with.cteCanUsed = with.cteCanUsed[:beforeOffset] + if cteNode, exist := x.(*ast.CommonTableExpression); exist { + with.cteCanUsed = append(with.cteCanUsed, cteNode.Name.L) + } } return in, p.err == nil @@ -1422,8 +1444,11 @@ func (p *preprocessor) stmtType() string { func (p *preprocessor) handleTableName(tn *ast.TableName) { if tn.Schema.L == "" { - if _, ok := p.withName[tn.Name.L]; ok { - return + + for _, cte := range p.preprocessWith.cteCanUsed { + if cte == tn.Name.L { + return + } } currentDB := p.ctx.GetSessionVars().CurrentDB @@ -1431,6 +1456,7 @@ func (p *preprocessor) handleTableName(tn *ast.TableName) { p.err = errors.Trace(ErrNoDB) return } + tn.Schema = model.NewCIStr(currentDB) } diff --git a/planner/core/preprocess_test.go b/planner/core/preprocess_test.go index 0a22f349bd9ed..a01a0df11ae56 100644 --- a/planner/core/preprocess_test.go +++ b/planner/core/preprocess_test.go @@ -15,6 +15,7 @@ package core_test import ( + "strings" "testing" "github.com/pingcap/errors" @@ -22,6 +23,7 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/format" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" @@ -364,3 +366,63 @@ func TestLargeVarcharAutoConv(t *testing.T) { require.True(t, terror.ErrorEqual(warns[i].Err, dbterror.ErrAutoConvert)) } } + +func TestPreprocessCTE(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t, t1, t2;") + tk.MustExec("create table t (c int);insert into t values (1), (2), (3), (4), (5);") + tk.MustExec("create table t1 (a int);insert into t1 values (0), (1), (2), (3), (4);") + tk.MustExec("create table t2 (b int);insert into t2 values (1), (2), (3), (4), (5);") + tk.MustExec("create table t11111 (d int);insert into t11111 values (1), (2), (3), (4), (5);") + tk.MustExec("drop table if exists tbl_1;\nCREATE TABLE `tbl_1` (\n `col_2` char(65) CHARACTER SET utf8 COLLATE utf8_bin DEFAULT NULL,\n `col_3` int(11) NOT NULL\n);") + testCases := []struct { + before string + after string + }{ + { + "create view v1 as WITH t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1;", + "CREATE ALGORITHM = UNDEFINED DEFINER = CURRENT_USER SQL SECURITY DEFINER VIEW `test`.`v1` AS WITH `t1` AS (SELECT `a` FROM `test`.`t2` WHERE `t2`.`a`=3 UNION SELECT `t2`.`a`+1 FROM (`test`.`t1`) JOIN `test`.`t2` WHERE `t1`.`a`=`t2`.`a`) SELECT * FROM `t1`", + }, + { + "WITH t1 AS ( SELECT(WITH t1 AS ( WITH qn AS ( SELECT 10 * a AS a FROM t1 ) SELECT 10 * a AS a FROM qn ) SELECT * FROM t1 LIMIT 1 ) FROM t2 WHERE t2.b = 3 UNION SELECT t2.b + 1 FROM t1, t2 WHERE t1.a = t2.b) SELECT * FROM t1", + "WITH `t1` AS (SELECT (WITH `t1` AS (WITH `qn` AS (SELECT 10*`a` AS `a` FROM `test`.`t1`) SELECT 10*`a` AS `a` FROM `qn`) SELECT * FROM `t1` LIMIT 1) FROM `test`.`t2` WHERE `t2`.`b`=3 UNION SELECT `t2`.`b`+1 FROM (`test`.`t1`) JOIN `test`.`t2` WHERE `t1`.`a`=`t2`.`b`) SELECT * FROM `t1`", + }, + { + "with recursive cte_8932 (col_34891,col_34892) AS ( with recursive cte_8932 (col_34893,col_34894,col_34895) AS ( with tbl_1 (col_34896,col_34897,col_34898,col_34899) AS ( select 1, \"2\",3,col_3 from tbl_1 ) select cte_as_8958.col_34896,cte_as_8958.col_34898,cte_as_8958.col_34899 from tbl_1 as cte_as_8958 UNION DISTINCT select col_34893 + 1,concat(col_34894, 1),col_34895 + 1 from cte_8932 where col_34893 < 5 ) select cte_as_8959.col_34893,cte_as_8959.col_34895 from cte_8932 as cte_as_8959 ) select * from cte_8932 as cte_as_8960 order by cte_as_8960.col_34891,cte_as_8960.col_34892;", + "WITH RECURSIVE `cte_8932` (`col_34891`, `col_34892`) AS (WITH RECURSIVE `cte_8932` (`col_34893`, `col_34894`, `col_34895`) AS (WITH `tbl_1` (`col_34896`, `col_34897`, `col_34898`, `col_34899`) AS (SELECT 1,_UTF8MB4'2',3,`col_3` FROM `test`.`tbl_1`) SELECT `cte_as_8958`.`col_34896`,`cte_as_8958`.`col_34898`,`cte_as_8958`.`col_34899` FROM `tbl_1` AS `cte_as_8958` UNION SELECT `col_34893`+1,CONCAT(`col_34894`, 1),`col_34895`+1 FROM `cte_8932` WHERE `col_34893`<5) SELECT `cte_as_8959`.`col_34893`,`cte_as_8959`.`col_34895` FROM `cte_8932` AS `cte_as_8959`) SELECT * FROM `cte_8932` AS `cte_as_8960` ORDER BY `cte_as_8960`.`col_34891`,`cte_as_8960`.`col_34892`", + }, + { + "with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1;", + "WITH `t1` AS (WITH `t11` AS (SELECT * FROM `test`.`t`) SELECT * FROM (`test`.`t1`) JOIN `test`.`t2`) SELECT * FROM `t1`", + }, + { + "with t1 as (with t1 as (select * from t) select * from t1, t2) select * from t1;", + "WITH `t1` AS (WITH `t1` AS (SELECT * FROM `test`.`t`) SELECT * FROM (`t1`) JOIN `test`.`t2`) SELECT * FROM `t1`", + }, + { + "WITH t1 AS ( WITH t1 AS ( SELECT * FROM t ) SELECT ( WITH t2 AS ( SELECT * FROM t ) SELECT * FROM t limit 1 ) FROM t1, t2 ) \n\nSELECT\n* \nFROM\n\tt1;", + "WITH `t1` AS (WITH `t1` AS (SELECT * FROM `test`.`t`) SELECT (WITH `t2` AS (SELECT * FROM `test`.`t`) SELECT * FROM `test`.`t` LIMIT 1) FROM (`t1`) JOIN `test`.`t2`) SELECT * FROM `t1`", + }, + { + "WITH t123 AS (WITH t11111 AS ( SELECT * FROM test.t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, test.t2 ) SELECT * FROM t11111;", + "WITH `t123` AS (WITH `t11111` AS (SELECT * FROM `test`.`t1`) SELECT (WITH `t2` AS (SELECT (WITH `t23` AS (SELECT * FROM `t11111`) SELECT * FROM `t23` LIMIT 1) FROM `t11111`) SELECT * FROM `t2` LIMIT 1) FROM (`t11111`) JOIN `test`.`t2`) SELECT * FROM `test`.`t11111`", + }, + } + for _, tc := range testCases { + stmts, warnings, err := parser.New().ParseSQL(tc.before) + require.Len(t, warnings, 0) + require.NoError(t, err) + require.Len(t, stmts, 1) + + err = core.Preprocess(tk.Session(), stmts[0]) + require.NoError(t, err) + + var rs strings.Builder + err = stmts[0].Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &rs)) + require.NoError(t, err) + require.Equal(t, tc.after, rs.String()) + } +}