Skip to content

Commit

Permalink
cherry pick pingcap#34133 to release-5.4
Browse files Browse the repository at this point in the history
Signed-off-by: ti-srebot <ti-srebot@pingcap.com>
  • Loading branch information
likzn authored and ti-srebot committed May 6, 2022
1 parent cd60925 commit 2276240
Show file tree
Hide file tree
Showing 8 changed files with 300 additions and 11 deletions.
3 changes: 1 addition & 2 deletions bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -863,8 +863,7 @@ func GenerateBindSQL(ctx context.Context, stmtNode ast.StmtNode, planHint string
withIdx := strings.Index(bindSQL, "WITH")
restoreCtx := format.NewRestoreCtx(format.RestoreStringSingleQuotes|format.RestoreSpacesAroundBinaryOperation|format.RestoreStringWithoutCharset|format.RestoreNameBackQuotes, &withSb)
restoreCtx.DefaultDB = defaultDB
err := n.With.Restore(restoreCtx)
if err != nil {
if err := n.With.Restore(restoreCtx); err != nil {
logutil.BgLogger().Debug("[sql-bind] restore SQL failed", zap.Error(err))
return ""
}
Expand Down
75 changes: 75 additions & 0 deletions cmd/explaintest/r/cte.result
Original file line number Diff line number Diff line change
Expand Up @@ -722,3 +722,78 @@ c1 c2
2 1
2 2
2 3
use test;
drop table if exists t1, t2;
drop view if exists v1;
create table t1 (a int);
insert into t1 values (0), (1), (2), (3), (4);
create table t2 (a int);
insert into t2 values (1), (2), (3), (4), (5);
drop view if exists v1,v2;
create view v1 as with t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc;
create view v2 as with recursive t1 as ( select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc;
create database if not exists test1;
use test1;
select * from test.v1;
a
5
4
3
2
select * from test.v2;
a
6
5
4
3
use test;
drop table if exists t ,t1, t2;
create table t(a int);
insert into t values (0);
create table t1 (b int);
insert into t1 values (0);
create table t2 (c int);
insert into t2 values (0);
drop view if exists v1;
create view v1 as with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1;
use test1;
select * from test.v1;
b c
0 0
use test;
drop table if exists t11111;
create table t11111 (d int);
insert into t11111 values (123), (223), (323);
drop view if exists v1;
create view v1 as WITH t123 AS (WITH t11111 AS ( SELECT * FROM t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, t2 ) SELECT * FROM t11111;
use test1;
select * from test.v1;
d
123
223
323
use test;
drop table if exists t1;
create table t1 (a int);
insert into t1 values (1);
drop view if exists v1;
create view v1 as SELECT (WITH qn AS (SELECT 10*a as a FROM t1),qn2 AS (SELECT 3*a AS b FROM qn) SELECT * from qn2 LIMIT 1) FROM t1;
use test1;
select * from test.v1;
name_exp_1
30
use test;
drop table if exists t1,t2;
create table t1 (a int);
insert into t1 values (0), (1);
create table t2 (b int);
insert into t2 values (4), (5);
drop view if exists v1;
create view v1 as with t1 as (with t11 as (select * from t1) select * from t1, t2) select * from t1;
use test1;
select * from test.v1;
a b
0 5
0 4
1 5
1 4
58 changes: 58 additions & 0 deletions cmd/explaintest/t/cte.test
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,61 @@ select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 un

explain select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);
select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);
# Some cases to Test Create View With CTE and checkout Database
# With name is the same as the table name
use test;
drop table if exists t1, t2;
drop view if exists v1;
create table t1 (a int);
insert into t1 values (0), (1), (2), (3), (4);
create table t2 (a int);
insert into t2 values (1), (2), (3), (4), (5);
drop view if exists v1,v2;
create view v1 as with t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc;
create view v2 as with recursive t1 as ( select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc;
create database if not exists test1;
use test1;
select * from test.v1;
select * from test.v2;
# case
use test;
drop table if exists t ,t1, t2;
create table t(a int);
insert into t values (0);
create table t1 (b int);
insert into t1 values (0);
create table t2 (c int);
insert into t2 values (0);
drop view if exists v1;
create view v1 as with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1;
use test1;
select * from test.v1;
# case
use test;
drop table if exists t11111;
create table t11111 (d int);
insert into t11111 values (123), (223), (323);
drop view if exists v1;
create view v1 as WITH t123 AS (WITH t11111 AS ( SELECT * FROM t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, t2 ) SELECT * FROM t11111;
use test1;
select * from test.v1;
# case
use test;
drop table if exists t1;
create table t1 (a int);
insert into t1 values (1);
drop view if exists v1;
create view v1 as SELECT (WITH qn AS (SELECT 10*a as a FROM t1),qn2 AS (SELECT 3*a AS b FROM qn) SELECT * from qn2 LIMIT 1) FROM t1;
use test1;
select * from test.v1;
# case
use test;
drop table if exists t1,t2;
create table t1 (a int);
insert into t1 values (0), (1);
create table t2 (b int);
insert into t2 values (4), (5);
drop view if exists v1;
create view v1 as with t1 as (with t11 as (select * from t1) select * from t1, t2) select * from t1;
use test1;
select * from test.v1;
56 changes: 53 additions & 3 deletions parser/ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,51 @@ type CommonTableExpression struct {
Name model.CIStr
Query *SubqueryExpr
ColNameList []model.CIStr
IsRecursive bool
}

// Restore implements Node interface
func (c *CommonTableExpression) Restore(ctx *format.RestoreCtx) error {
ctx.WriteName(c.Name.String())
if c.IsRecursive {
// If the CTE is recursive, we should make it visible for the CTE's query.
// Otherwise, we should put it to stack after building the CTE's query.
ctx.RecordCTEName(c.Name.L)
}
if len(c.ColNameList) > 0 {
ctx.WritePlain(" (")
for j, name := range c.ColNameList {
if j != 0 {
ctx.WritePlain(", ")
}
ctx.WriteName(name.String())
}
ctx.WritePlain(")")
}
ctx.WriteKeyWord(" AS ")
err := c.Query.Restore(ctx)
if err != nil {
return err
}
if !c.IsRecursive {
ctx.RecordCTEName(c.Name.L)
}
return nil
}

// Accept implements Node interface
func (c *CommonTableExpression) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(c)
if skipChildren {
return v.Leave(newNode)
}

node, ok := c.Query.Accept(v)
if !ok {
return c, false
}
c.Query = node.(*SubqueryExpr)
return v.Leave(c)
}

type WithClause struct {
Expand Down Expand Up @@ -1115,6 +1160,7 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error {
if i != 0 {
ctx.WritePlain(", ")
}
<<<<<<< HEAD
ctx.WriteName(cte.Name.String())
if n.IsRecursive {
// If the CTE is recursive, we should make it visible for the CTE's query.
Expand All @@ -1139,6 +1185,11 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error {
if !n.IsRecursive {
ctx.CTENames = append(ctx.CTENames, cte.Name.L)
}
=======
if err := cte.Restore(ctx); err != nil {
return err
}
>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133)
}
ctx.WritePlain(" ")
return nil
Expand All @@ -1151,11 +1202,9 @@ func (n *WithClause) Accept(v Visitor) (Node, bool) {
}

for _, cte := range n.CTEs {
node, ok := cte.Query.Accept(v)
if !ok {
if _, ok := cte.Accept(v); !ok {
return n, false
}
cte.Query = node.(*SubqueryExpr)
}
return v.Leave(n)
}
Expand All @@ -1179,6 +1228,7 @@ func (n *SelectStmt) Restore(ctx *format.RestoreCtx) error {
}()
}
if !n.WithBeforeBraces && n.With != nil {
defer ctx.RestoreCTEFunc()()
err := n.With.Restore(ctx)
if err != nil {
return err
Expand Down
3 changes: 3 additions & 0 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -16659,6 +16659,9 @@ yynewstate:
{
ws := yyS[yypt-0].item.(*ast.WithClause)
ws.IsRecursive = true
for _, cte := range ws.CTEs {
cte.IsRecursive = true
}
parser.yyVAL.item = ws
}
case 1493:
Expand Down
3 changes: 3 additions & 0 deletions parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -8343,6 +8343,9 @@ WithClause:
{
ws := $3.(*ast.WithClause)
ws.IsRecursive = true
for _, cte := range ws.CTEs {
cte.IsRecursive = true
}
$$ = ws
}

Expand Down
47 changes: 41 additions & 6 deletions planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,16 @@ func TryAddExtraLimit(ctx sessionctx.Context, node ast.StmtNode) ast.StmtNode {
// Preprocess resolves table names of the node, and checks some statements' validation.
// preprocessReturn used to extract the infoschema for the tableName and the timestamp from the asof clause.
func Preprocess(ctx sessionctx.Context, node ast.Node, preprocessOpt ...PreprocessOpt) error {
<<<<<<< HEAD
v := preprocessor{ctx: ctx, tableAliasInJoin: make([]map[string]interface{}, 0), withName: make(map[string]interface{})}
=======
v := preprocessor{
ctx: ctx,
tableAliasInJoin: make([]map[string]interface{}, 0),
preprocessWith: &preprocessWith{cteCanUsed: make([]string, 0), cteBeforeOffset: make([]int, 0)},
staleReadProcessor: staleread.NewStaleReadProcessor(ctx),
}
>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133)
for _, optFn := range preprocessOpt {
optFn(&v)
}
Expand Down Expand Up @@ -170,6 +179,12 @@ type PreprocessExecuteISUpdate struct {
Node ast.Node
}

// preprocessWith is used to record info from WITH statements like CTE name.
type preprocessWith struct {
cteCanUsed []string
cteBeforeOffset []int
}

// preprocessor is an ast.Visitor that preprocess
// ast Nodes parsed from parser.
type preprocessor struct {
Expand All @@ -181,7 +196,7 @@ type preprocessor struct {
// tableAliasInJoin is a stack that keeps the table alias names for joins.
// len(tableAliasInJoin) may bigger than 1 because the left/right child of join may be subquery that contains `JOIN`
tableAliasInJoin []map[string]interface{}
withName map[string]interface{}
preprocessWith *preprocessWith

// values that may be returned
*PreprocessorReturn
Expand Down Expand Up @@ -309,9 +324,12 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
}
case *ast.GroupByClause:
p.checkGroupBy(node)
case *ast.WithClause:
for _, cte := range node.CTEs {
p.withName[cte.Name.L] = struct{}{}
case *ast.CommonTableExpression, *ast.SubqueryExpr:
with := p.preprocessWith
beforeOffset := len(with.cteCanUsed)
with.cteBeforeOffset = append(with.cteBeforeOffset, beforeOffset)
if cteNode, exist := node.(*ast.CommonTableExpression); exist && cteNode.IsRecursive {
with.cteCanUsed = append(with.cteCanUsed, cteNode.Name.L)
}
case *ast.BeginStmt:
// If the begin statement was like following:
Expand Down Expand Up @@ -556,6 +574,19 @@ func (p *preprocessor) Leave(in ast.Node) (out ast.Node, ok bool) {
if x.Kind == ast.BRIEKindRestore {
p.flag &= ^inCreateOrDropTable
}
case *ast.CommonTableExpression, *ast.SubqueryExpr:
with := p.preprocessWith
lenWithCteBeforeOffset := len(with.cteBeforeOffset)
if lenWithCteBeforeOffset < 1 {
p.err = ErrInternal.GenWithStack("len(cteBeforeOffset) is less than one.Maybe it was deleted in somewhere.Should match in Enter and Leave")
break
}
beforeOffset := with.cteBeforeOffset[lenWithCteBeforeOffset-1]
with.cteBeforeOffset = with.cteBeforeOffset[:lenWithCteBeforeOffset-1]
with.cteCanUsed = with.cteCanUsed[:beforeOffset]
if cteNode, exist := x.(*ast.CommonTableExpression); exist {
with.cteCanUsed = append(with.cteCanUsed, cteNode.Name.L)
}
}

return in, p.err == nil
Expand Down Expand Up @@ -1398,15 +1429,19 @@ func (p *preprocessor) stmtType() string {

func (p *preprocessor) handleTableName(tn *ast.TableName) {
if tn.Schema.L == "" {
if _, ok := p.withName[tn.Name.L]; ok {
return

for _, cte := range p.preprocessWith.cteCanUsed {
if cte == tn.Name.L {
return
}
}

currentDB := p.ctx.GetSessionVars().CurrentDB
if currentDB == "" {
p.err = errors.Trace(ErrNoDB)
return
}

tn.Schema = model.NewCIStr(currentDB)
}

Expand Down
Loading

0 comments on commit 2276240

Please sign in to comment.