diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 773597a7f7eec..a6a7108d8d7b3 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -863,8 +863,7 @@ func GenerateBindSQL(ctx context.Context, stmtNode ast.StmtNode, planHint string withIdx := strings.Index(bindSQL, "WITH") restoreCtx := format.NewRestoreCtx(format.RestoreStringSingleQuotes|format.RestoreSpacesAroundBinaryOperation|format.RestoreStringWithoutCharset|format.RestoreNameBackQuotes, &withSb) restoreCtx.DefaultDB = defaultDB - err := n.With.Restore(restoreCtx) - if err != nil { + if err := n.With.Restore(restoreCtx); err != nil { logutil.BgLogger().Debug("[sql-bind] restore SQL failed", zap.Error(err)) return "" } diff --git a/cmd/explaintest/r/cte.result b/cmd/explaintest/r/cte.result index 077e62cf33cd4..6ca47324c6c4f 100644 --- a/cmd/explaintest/r/cte.result +++ b/cmd/explaintest/r/cte.result @@ -722,3 +722,78 @@ c1 c2 2 1 2 2 2 3 +use test; +drop table if exists t1, t2; +drop view if exists v1; +create table t1 (a int); +insert into t1 values (0), (1), (2), (3), (4); +create table t2 (a int); +insert into t2 values (1), (2), (3), (4), (5); +drop view if exists v1,v2; +create view v1 as with t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create view v2 as with recursive t1 as ( select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create database if not exists test1; +use test1; +select * from test.v1; +a +5 +4 +3 +2 +select * from test.v2; +a +6 +5 +4 +3 +use test; +drop table if exists t ,t1, t2; +create table t(a int); +insert into t values (0); +create table t1 (b int); +insert into t1 values (0); +create table t2 (c int); +insert into t2 values (0); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; +b c +0 0 +use test; +drop table if exists t11111; +create table t11111 (d int); +insert into t11111 values (123), (223), (323); +drop view if exists v1; +create view v1 as WITH t123 AS (WITH t11111 AS ( SELECT * FROM t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, t2 ) SELECT * FROM t11111; +use test1; +select * from test.v1; +d +123 +223 +323 +use test; +drop table if exists t1; +create table t1 (a int); +insert into t1 values (1); +drop view if exists v1; +create view v1 as SELECT (WITH qn AS (SELECT 10*a as a FROM t1),qn2 AS (SELECT 3*a AS b FROM qn) SELECT * from qn2 LIMIT 1) FROM t1; +use test1; +select * from test.v1; +name_exp_1 +30 +use test; +drop table if exists t1,t2; +create table t1 (a int); +insert into t1 values (0), (1); +create table t2 (b int); +insert into t2 values (4), (5); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t1) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; +a b +0 5 +0 4 +1 5 +1 4 diff --git a/cmd/explaintest/t/cte.test b/cmd/explaintest/t/cte.test index df9ee6c3b8d06..0c1623caa81dd 100644 --- a/cmd/explaintest/t/cte.test +++ b/cmd/explaintest/t/cte.test @@ -256,3 +256,61 @@ select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 un explain select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +# Some cases to Test Create View With CTE and checkout Database +# With name is the same as the table name +use test; +drop table if exists t1, t2; +drop view if exists v1; +create table t1 (a int); +insert into t1 values (0), (1), (2), (3), (4); +create table t2 (a int); +insert into t2 values (1), (2), (3), (4), (5); +drop view if exists v1,v2; +create view v1 as with t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create view v2 as with recursive t1 as ( select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc; +create database if not exists test1; +use test1; +select * from test.v1; +select * from test.v2; +# case +use test; +drop table if exists t ,t1, t2; +create table t(a int); +insert into t values (0); +create table t1 (b int); +insert into t1 values (0); +create table t2 (c int); +insert into t2 values (0); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; +# case +use test; +drop table if exists t11111; +create table t11111 (d int); +insert into t11111 values (123), (223), (323); +drop view if exists v1; +create view v1 as WITH t123 AS (WITH t11111 AS ( SELECT * FROM t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, t2 ) SELECT * FROM t11111; +use test1; +select * from test.v1; +# case +use test; +drop table if exists t1; +create table t1 (a int); +insert into t1 values (1); +drop view if exists v1; +create view v1 as SELECT (WITH qn AS (SELECT 10*a as a FROM t1),qn2 AS (SELECT 3*a AS b FROM qn) SELECT * from qn2 LIMIT 1) FROM t1; +use test1; +select * from test.v1; +# case +use test; +drop table if exists t1,t2; +create table t1 (a int); +insert into t1 values (0), (1); +create table t2 (b int); +insert into t2 values (4), (5); +drop view if exists v1; +create view v1 as with t1 as (with t11 as (select * from t1) select * from t1, t2) select * from t1; +use test1; +select * from test.v1; diff --git a/parser/ast/dml.go b/parser/ast/dml.go index cd29e293ba875..3f546fdc010af 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -1046,6 +1046,51 @@ type CommonTableExpression struct { Name model.CIStr Query *SubqueryExpr ColNameList []model.CIStr + IsRecursive bool +} + +// Restore implements Node interface +func (c *CommonTableExpression) Restore(ctx *format.RestoreCtx) error { + ctx.WriteName(c.Name.String()) + if c.IsRecursive { + // If the CTE is recursive, we should make it visible for the CTE's query. + // Otherwise, we should put it to stack after building the CTE's query. + ctx.RecordCTEName(c.Name.L) + } + if len(c.ColNameList) > 0 { + ctx.WritePlain(" (") + for j, name := range c.ColNameList { + if j != 0 { + ctx.WritePlain(", ") + } + ctx.WriteName(name.String()) + } + ctx.WritePlain(")") + } + ctx.WriteKeyWord(" AS ") + err := c.Query.Restore(ctx) + if err != nil { + return err + } + if !c.IsRecursive { + ctx.RecordCTEName(c.Name.L) + } + return nil +} + +// Accept implements Node interface +func (c *CommonTableExpression) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(c) + if skipChildren { + return v.Leave(newNode) + } + + node, ok := c.Query.Accept(v) + if !ok { + return c, false + } + c.Query = node.(*SubqueryExpr) + return v.Leave(c) } type WithClause struct { @@ -1115,6 +1160,7 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error { if i != 0 { ctx.WritePlain(", ") } +<<<<<<< HEAD ctx.WriteName(cte.Name.String()) if n.IsRecursive { // If the CTE is recursive, we should make it visible for the CTE's query. @@ -1139,6 +1185,11 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error { if !n.IsRecursive { ctx.CTENames = append(ctx.CTENames, cte.Name.L) } +======= + if err := cte.Restore(ctx); err != nil { + return err + } +>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) } ctx.WritePlain(" ") return nil @@ -1151,11 +1202,9 @@ func (n *WithClause) Accept(v Visitor) (Node, bool) { } for _, cte := range n.CTEs { - node, ok := cte.Query.Accept(v) - if !ok { + if _, ok := cte.Accept(v); !ok { return n, false } - cte.Query = node.(*SubqueryExpr) } return v.Leave(n) } @@ -1179,6 +1228,7 @@ func (n *SelectStmt) Restore(ctx *format.RestoreCtx) error { }() } if !n.WithBeforeBraces && n.With != nil { + defer ctx.RestoreCTEFunc()() err := n.With.Restore(ctx) if err != nil { return err diff --git a/parser/parser.go b/parser/parser.go index a19f687391070..7f8e2be0c1984 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -16659,6 +16659,9 @@ yynewstate: { ws := yyS[yypt-0].item.(*ast.WithClause) ws.IsRecursive = true + for _, cte := range ws.CTEs { + cte.IsRecursive = true + } parser.yyVAL.item = ws } case 1493: diff --git a/parser/parser.y b/parser/parser.y index 0610450883ad1..a531b7139f8ff 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -8343,6 +8343,9 @@ WithClause: { ws := $3.(*ast.WithClause) ws.IsRecursive = true + for _, cte := range ws.CTEs { + cte.IsRecursive = true + } $$ = ws } diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 7eb2a3c041c52..1162631af6225 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -112,7 +112,16 @@ func TryAddExtraLimit(ctx sessionctx.Context, node ast.StmtNode) ast.StmtNode { // Preprocess resolves table names of the node, and checks some statements' validation. // preprocessReturn used to extract the infoschema for the tableName and the timestamp from the asof clause. func Preprocess(ctx sessionctx.Context, node ast.Node, preprocessOpt ...PreprocessOpt) error { +<<<<<<< HEAD v := preprocessor{ctx: ctx, tableAliasInJoin: make([]map[string]interface{}, 0), withName: make(map[string]interface{})} +======= + v := preprocessor{ + ctx: ctx, + tableAliasInJoin: make([]map[string]interface{}, 0), + preprocessWith: &preprocessWith{cteCanUsed: make([]string, 0), cteBeforeOffset: make([]int, 0)}, + staleReadProcessor: staleread.NewStaleReadProcessor(ctx), + } +>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) for _, optFn := range preprocessOpt { optFn(&v) } @@ -170,6 +179,12 @@ type PreprocessExecuteISUpdate struct { Node ast.Node } +// preprocessWith is used to record info from WITH statements like CTE name. +type preprocessWith struct { + cteCanUsed []string + cteBeforeOffset []int +} + // preprocessor is an ast.Visitor that preprocess // ast Nodes parsed from parser. type preprocessor struct { @@ -181,7 +196,7 @@ type preprocessor struct { // tableAliasInJoin is a stack that keeps the table alias names for joins. // len(tableAliasInJoin) may bigger than 1 because the left/right child of join may be subquery that contains `JOIN` tableAliasInJoin []map[string]interface{} - withName map[string]interface{} + preprocessWith *preprocessWith // values that may be returned *PreprocessorReturn @@ -309,9 +324,12 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { } case *ast.GroupByClause: p.checkGroupBy(node) - case *ast.WithClause: - for _, cte := range node.CTEs { - p.withName[cte.Name.L] = struct{}{} + case *ast.CommonTableExpression, *ast.SubqueryExpr: + with := p.preprocessWith + beforeOffset := len(with.cteCanUsed) + with.cteBeforeOffset = append(with.cteBeforeOffset, beforeOffset) + if cteNode, exist := node.(*ast.CommonTableExpression); exist && cteNode.IsRecursive { + with.cteCanUsed = append(with.cteCanUsed, cteNode.Name.L) } case *ast.BeginStmt: // If the begin statement was like following: @@ -556,6 +574,19 @@ func (p *preprocessor) Leave(in ast.Node) (out ast.Node, ok bool) { if x.Kind == ast.BRIEKindRestore { p.flag &= ^inCreateOrDropTable } + case *ast.CommonTableExpression, *ast.SubqueryExpr: + with := p.preprocessWith + lenWithCteBeforeOffset := len(with.cteBeforeOffset) + if lenWithCteBeforeOffset < 1 { + p.err = ErrInternal.GenWithStack("len(cteBeforeOffset) is less than one.Maybe it was deleted in somewhere.Should match in Enter and Leave") + break + } + beforeOffset := with.cteBeforeOffset[lenWithCteBeforeOffset-1] + with.cteBeforeOffset = with.cteBeforeOffset[:lenWithCteBeforeOffset-1] + with.cteCanUsed = with.cteCanUsed[:beforeOffset] + if cteNode, exist := x.(*ast.CommonTableExpression); exist { + with.cteCanUsed = append(with.cteCanUsed, cteNode.Name.L) + } } return in, p.err == nil @@ -1398,8 +1429,11 @@ func (p *preprocessor) stmtType() string { func (p *preprocessor) handleTableName(tn *ast.TableName) { if tn.Schema.L == "" { - if _, ok := p.withName[tn.Name.L]; ok { - return + + for _, cte := range p.preprocessWith.cteCanUsed { + if cte == tn.Name.L { + return + } } currentDB := p.ctx.GetSessionVars().CurrentDB @@ -1407,6 +1441,7 @@ func (p *preprocessor) handleTableName(tn *ast.TableName) { p.err = errors.Trace(ErrNoDB) return } + tn.Schema = model.NewCIStr(currentDB) } diff --git a/planner/core/preprocess_test.go b/planner/core/preprocess_test.go index ff25091b4eeb8..5726fc0a8020c 100644 --- a/planner/core/preprocess_test.go +++ b/planner/core/preprocess_test.go @@ -15,7 +15,12 @@ package core_test import ( +<<<<<<< HEAD "context" +======= + "strings" + "testing" +>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133) . "github.com/pingcap/check" "github.com/pingcap/errors" @@ -26,6 +31,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/format" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" @@ -364,3 +370,63 @@ func (s *testValidatorSuite) TestDropGlobalTempTable(c *C) { s.runSQL(c, "drop global temporary table temp, ltemp1", false, core.ErrDropTableOnTemporaryTable) s.runSQL(c, "drop global temporary table test2.temp2, temp1", false, nil) } + +func TestPreprocessCTE(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t, t1, t2;") + tk.MustExec("create table t (c int);insert into t values (1), (2), (3), (4), (5);") + tk.MustExec("create table t1 (a int);insert into t1 values (0), (1), (2), (3), (4);") + tk.MustExec("create table t2 (b int);insert into t2 values (1), (2), (3), (4), (5);") + tk.MustExec("create table t11111 (d int);insert into t11111 values (1), (2), (3), (4), (5);") + tk.MustExec("drop table if exists tbl_1;\nCREATE TABLE `tbl_1` (\n `col_2` char(65) CHARACTER SET utf8 COLLATE utf8_bin DEFAULT NULL,\n `col_3` int(11) NOT NULL\n);") + testCases := []struct { + before string + after string + }{ + { + "create view v1 as WITH t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1;", + "CREATE ALGORITHM = UNDEFINED DEFINER = CURRENT_USER SQL SECURITY DEFINER VIEW `test`.`v1` AS WITH `t1` AS (SELECT `a` FROM `test`.`t2` WHERE `t2`.`a`=3 UNION SELECT `t2`.`a`+1 FROM (`test`.`t1`) JOIN `test`.`t2` WHERE `t1`.`a`=`t2`.`a`) SELECT * FROM `t1`", + }, + { + "WITH t1 AS ( SELECT(WITH t1 AS ( WITH qn AS ( SELECT 10 * a AS a FROM t1 ) SELECT 10 * a AS a FROM qn ) SELECT * FROM t1 LIMIT 1 ) FROM t2 WHERE t2.b = 3 UNION SELECT t2.b + 1 FROM t1, t2 WHERE t1.a = t2.b) SELECT * FROM t1", + "WITH `t1` AS (SELECT (WITH `t1` AS (WITH `qn` AS (SELECT 10*`a` AS `a` FROM `test`.`t1`) SELECT 10*`a` AS `a` FROM `qn`) SELECT * FROM `t1` LIMIT 1) FROM `test`.`t2` WHERE `t2`.`b`=3 UNION SELECT `t2`.`b`+1 FROM (`test`.`t1`) JOIN `test`.`t2` WHERE `t1`.`a`=`t2`.`b`) SELECT * FROM `t1`", + }, + { + "with recursive cte_8932 (col_34891,col_34892) AS ( with recursive cte_8932 (col_34893,col_34894,col_34895) AS ( with tbl_1 (col_34896,col_34897,col_34898,col_34899) AS ( select 1, \"2\",3,col_3 from tbl_1 ) select cte_as_8958.col_34896,cte_as_8958.col_34898,cte_as_8958.col_34899 from tbl_1 as cte_as_8958 UNION DISTINCT select col_34893 + 1,concat(col_34894, 1),col_34895 + 1 from cte_8932 where col_34893 < 5 ) select cte_as_8959.col_34893,cte_as_8959.col_34895 from cte_8932 as cte_as_8959 ) select * from cte_8932 as cte_as_8960 order by cte_as_8960.col_34891,cte_as_8960.col_34892;", + "WITH RECURSIVE `cte_8932` (`col_34891`, `col_34892`) AS (WITH RECURSIVE `cte_8932` (`col_34893`, `col_34894`, `col_34895`) AS (WITH `tbl_1` (`col_34896`, `col_34897`, `col_34898`, `col_34899`) AS (SELECT 1,_UTF8MB4'2',3,`col_3` FROM `test`.`tbl_1`) SELECT `cte_as_8958`.`col_34896`,`cte_as_8958`.`col_34898`,`cte_as_8958`.`col_34899` FROM `tbl_1` AS `cte_as_8958` UNION SELECT `col_34893`+1,CONCAT(`col_34894`, 1),`col_34895`+1 FROM `cte_8932` WHERE `col_34893`<5) SELECT `cte_as_8959`.`col_34893`,`cte_as_8959`.`col_34895` FROM `cte_8932` AS `cte_as_8959`) SELECT * FROM `cte_8932` AS `cte_as_8960` ORDER BY `cte_as_8960`.`col_34891`,`cte_as_8960`.`col_34892`", + }, + { + "with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1;", + "WITH `t1` AS (WITH `t11` AS (SELECT * FROM `test`.`t`) SELECT * FROM (`test`.`t1`) JOIN `test`.`t2`) SELECT * FROM `t1`", + }, + { + "with t1 as (with t1 as (select * from t) select * from t1, t2) select * from t1;", + "WITH `t1` AS (WITH `t1` AS (SELECT * FROM `test`.`t`) SELECT * FROM (`t1`) JOIN `test`.`t2`) SELECT * FROM `t1`", + }, + { + "WITH t1 AS ( WITH t1 AS ( SELECT * FROM t ) SELECT ( WITH t2 AS ( SELECT * FROM t ) SELECT * FROM t limit 1 ) FROM t1, t2 ) \n\nSELECT\n* \nFROM\n\tt1;", + "WITH `t1` AS (WITH `t1` AS (SELECT * FROM `test`.`t`) SELECT (WITH `t2` AS (SELECT * FROM `test`.`t`) SELECT * FROM `test`.`t` LIMIT 1) FROM (`t1`) JOIN `test`.`t2`) SELECT * FROM `t1`", + }, + { + "WITH t123 AS (WITH t11111 AS ( SELECT * FROM test.t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, test.t2 ) SELECT * FROM t11111;", + "WITH `t123` AS (WITH `t11111` AS (SELECT * FROM `test`.`t1`) SELECT (WITH `t2` AS (SELECT (WITH `t23` AS (SELECT * FROM `t11111`) SELECT * FROM `t23` LIMIT 1) FROM `t11111`) SELECT * FROM `t2` LIMIT 1) FROM (`t11111`) JOIN `test`.`t2`) SELECT * FROM `test`.`t11111`", + }, + } + for _, tc := range testCases { + stmts, warnings, err := parser.New().ParseSQL(tc.before) + require.Len(t, warnings, 0) + require.NoError(t, err) + require.Len(t, stmts, 1) + + err = core.Preprocess(tk.Session(), stmts[0]) + require.NoError(t, err) + + var rs strings.Builder + err = stmts[0].Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &rs)) + require.NoError(t, err) + require.Equal(t, tc.after, rs.String()) + } +}