Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: preprocessor add CTE recursive check when handleTableName (#34133) #34415

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -961,8 +961,7 @@ func GenerateBindSQL(ctx context.Context, stmtNode ast.StmtNode, planHint string
withIdx := strings.Index(bindSQL, "WITH")
restoreCtx := format.NewRestoreCtx(format.RestoreStringSingleQuotes|format.RestoreSpacesAroundBinaryOperation|format.RestoreStringWithoutCharset|format.RestoreNameBackQuotes, &withSb)
restoreCtx.DefaultDB = defaultDB
err := n.With.Restore(restoreCtx)
if err != nil {
if err := n.With.Restore(restoreCtx); err != nil {
logutil.BgLogger().Debug("[sql-bind] restore SQL failed", zap.Error(err))
return ""
}
Expand Down
75 changes: 75 additions & 0 deletions cmd/explaintest/r/cte.result
Original file line number Diff line number Diff line change
Expand Up @@ -730,3 +730,78 @@ c1 c2
2 1
2 2
2 3
use test;
drop table if exists t1, t2;
drop view if exists v1;
create table t1 (a int);
insert into t1 values (0), (1), (2), (3), (4);
create table t2 (a int);
insert into t2 values (1), (2), (3), (4), (5);
drop view if exists v1,v2;
create view v1 as with t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc;
create view v2 as with recursive t1 as ( select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc;
create database if not exists test1;
use test1;
select * from test.v1;
a
5
4
3
2
select * from test.v2;
a
6
5
4
3
use test;
drop table if exists t ,t1, t2;
create table t(a int);
insert into t values (0);
create table t1 (b int);
insert into t1 values (0);
create table t2 (c int);
insert into t2 values (0);
drop view if exists v1;
create view v1 as with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1;
use test1;
select * from test.v1;
b c
0 0
use test;
drop table if exists t11111;
create table t11111 (d int);
insert into t11111 values (123), (223), (323);
drop view if exists v1;
create view v1 as WITH t123 AS (WITH t11111 AS ( SELECT * FROM t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, t2 ) SELECT * FROM t11111;
use test1;
select * from test.v1;
d
123
223
323
use test;
drop table if exists t1;
create table t1 (a int);
insert into t1 values (1);
drop view if exists v1;
create view v1 as SELECT (WITH qn AS (SELECT 10*a as a FROM t1),qn2 AS (SELECT 3*a AS b FROM qn) SELECT * from qn2 LIMIT 1) FROM t1;
use test1;
select * from test.v1;
name_exp_1
30
use test;
drop table if exists t1,t2;
create table t1 (a int);
insert into t1 values (0), (1);
create table t2 (b int);
insert into t2 values (4), (5);
drop view if exists v1;
create view v1 as with t1 as (with t11 as (select * from t1) select * from t1, t2) select * from t1;
use test1;
select * from test.v1;
a b
0 5
0 4
1 5
1 4
58 changes: 58 additions & 0 deletions cmd/explaintest/t/cte.test
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,61 @@ select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 un

explain select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);
select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);
# Some cases to Test Create View With CTE and checkout Database
# With name is the same as the table name
use test;
drop table if exists t1, t2;
drop view if exists v1;
create table t1 (a int);
insert into t1 values (0), (1), (2), (3), (4);
create table t2 (a int);
insert into t2 values (1), (2), (3), (4), (5);
drop view if exists v1,v2;
create view v1 as with t1 as (select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc;
create view v2 as with recursive t1 as ( select a from t2 where t2.a=3 union select t2.a+1 from t1,t2 where t1.a=t2.a) select * from t1 order by a desc;
create database if not exists test1;
use test1;
select * from test.v1;
select * from test.v2;
# case
use test;
drop table if exists t ,t1, t2;
create table t(a int);
insert into t values (0);
create table t1 (b int);
insert into t1 values (0);
create table t2 (c int);
insert into t2 values (0);
drop view if exists v1;
create view v1 as with t1 as (with t11 as (select * from t) select * from t1, t2) select * from t1;
use test1;
select * from test.v1;
# case
use test;
drop table if exists t11111;
create table t11111 (d int);
insert into t11111 values (123), (223), (323);
drop view if exists v1;
create view v1 as WITH t123 AS (WITH t11111 AS ( SELECT * FROM t1 ) SELECT ( WITH t2 AS ( SELECT ( WITH t23 AS ( SELECT * FROM t11111 ) SELECT * FROM t23 LIMIT 1 ) FROM t11111 ) SELECT * FROM t2 LIMIT 1 ) FROM t11111, t2 ) SELECT * FROM t11111;
use test1;
select * from test.v1;
# case
use test;
drop table if exists t1;
create table t1 (a int);
insert into t1 values (1);
drop view if exists v1;
create view v1 as SELECT (WITH qn AS (SELECT 10*a as a FROM t1),qn2 AS (SELECT 3*a AS b FROM qn) SELECT * from qn2 LIMIT 1) FROM t1;
use test1;
select * from test.v1;
# case
use test;
drop table if exists t1,t2;
create table t1 (a int);
insert into t1 values (0), (1);
create table t2 (b int);
insert into t2 values (4), (5);
drop view if exists v1;
create view v1 as with t1 as (with t11 as (select * from t1) select * from t1, t2) select * from t1;
use test1;
select * from test.v1;
56 changes: 53 additions & 3 deletions parser/ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,51 @@ type CommonTableExpression struct {
Name model.CIStr
Query *SubqueryExpr
ColNameList []model.CIStr
IsRecursive bool
}

// Restore implements Node interface
func (c *CommonTableExpression) Restore(ctx *format.RestoreCtx) error {
ctx.WriteName(c.Name.String())
if c.IsRecursive {
// If the CTE is recursive, we should make it visible for the CTE's query.
// Otherwise, we should put it to stack after building the CTE's query.
ctx.RecordCTEName(c.Name.L)
}
if len(c.ColNameList) > 0 {
ctx.WritePlain(" (")
for j, name := range c.ColNameList {
if j != 0 {
ctx.WritePlain(", ")
}
ctx.WriteName(name.String())
}
ctx.WritePlain(")")
}
ctx.WriteKeyWord(" AS ")
err := c.Query.Restore(ctx)
if err != nil {
return err
}
if !c.IsRecursive {
ctx.RecordCTEName(c.Name.L)
}
return nil
}

// Accept implements Node interface
func (c *CommonTableExpression) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(c)
if skipChildren {
return v.Leave(newNode)
}

node, ok := c.Query.Accept(v)
if !ok {
return c, false
}
c.Query = node.(*SubqueryExpr)
return v.Leave(c)
}

type WithClause struct {
Expand Down Expand Up @@ -1117,6 +1162,7 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error {
if i != 0 {
ctx.WritePlain(", ")
}
<<<<<<< HEAD
ctx.WriteName(cte.Name.String())
if n.IsRecursive {
// If the CTE is recursive, we should make it visible for the CTE's query.
Expand All @@ -1141,6 +1187,11 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error {
if !n.IsRecursive {
ctx.CTENames = append(ctx.CTENames, cte.Name.L)
}
=======
if err := cte.Restore(ctx); err != nil {
return err
}
>>>>>>> fa5e19010... planner: `preprocessor` add CTE recursive check when `handleTableName` (#34133)
}
ctx.WritePlain(" ")
return nil
Expand All @@ -1153,11 +1204,9 @@ func (n *WithClause) Accept(v Visitor) (Node, bool) {
}

for _, cte := range n.CTEs {
node, ok := cte.Query.Accept(v)
if !ok {
if _, ok := cte.Accept(v); !ok {
return n, false
}
cte.Query = node.(*SubqueryExpr)
}
return v.Leave(n)
}
Expand All @@ -1181,6 +1230,7 @@ func (n *SelectStmt) Restore(ctx *format.RestoreCtx) error {
}()
}
if !n.WithBeforeBraces && n.With != nil {
defer ctx.RestoreCTEFunc()()
err := n.With.Restore(ctx)
if err != nil {
return err
Expand Down
3 changes: 3 additions & 0 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -16698,6 +16698,9 @@ yynewstate:
{
ws := yyS[yypt-0].item.(*ast.WithClause)
ws.IsRecursive = true
for _, cte := range ws.CTEs {
cte.IsRecursive = true
}
parser.yyVAL.item = ws
}
case 1497:
Expand Down
3 changes: 3 additions & 0 deletions parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -8363,6 +8363,9 @@ WithClause:
{
ws := $3.(*ast.WithClause)
ws.IsRecursive = true
for _, cte := range ws.CTEs {
cte.IsRecursive = true
}
$$ = ws
}

Expand Down
40 changes: 33 additions & 7 deletions planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func Preprocess(ctx sessionctx.Context, node ast.Node, preprocessOpt ...Preproce
v := preprocessor{
ctx: ctx,
tableAliasInJoin: make([]map[string]interface{}, 0),
withName: make(map[string]interface{}),
preprocessWith: &preprocessWith{cteCanUsed: make([]string, 0), cteBeforeOffset: make([]int, 0)},
staleReadProcessor: staleread.NewStaleReadProcessor(ctx),
}
for _, optFn := range preprocessOpt {
Expand Down Expand Up @@ -176,6 +176,12 @@ type PreprocessExecuteISUpdate struct {
Node ast.Node
}

// preprocessWith is used to record info from WITH statements like CTE name.
type preprocessWith struct {
cteCanUsed []string
cteBeforeOffset []int
}

// preprocessor is an ast.Visitor that preprocess
// ast Nodes parsed from parser.
type preprocessor struct {
Expand All @@ -187,7 +193,7 @@ type preprocessor struct {
// tableAliasInJoin is a stack that keeps the table alias names for joins.
// len(tableAliasInJoin) may bigger than 1 because the left/right child of join may be subquery that contains `JOIN`
tableAliasInJoin []map[string]interface{}
withName map[string]interface{}
preprocessWith *preprocessWith

staleReadProcessor staleread.Processor

Expand Down Expand Up @@ -317,9 +323,12 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
}
case *ast.GroupByClause:
p.checkGroupBy(node)
case *ast.WithClause:
for _, cte := range node.CTEs {
p.withName[cte.Name.L] = struct{}{}
case *ast.CommonTableExpression, *ast.SubqueryExpr:
with := p.preprocessWith
beforeOffset := len(with.cteCanUsed)
with.cteBeforeOffset = append(with.cteBeforeOffset, beforeOffset)
if cteNode, exist := node.(*ast.CommonTableExpression); exist && cteNode.IsRecursive {
with.cteCanUsed = append(with.cteCanUsed, cteNode.Name.L)
}
case *ast.BeginStmt:
// If the begin statement was like following:
Expand Down Expand Up @@ -564,6 +573,19 @@ func (p *preprocessor) Leave(in ast.Node) (out ast.Node, ok bool) {
if x.Kind == ast.BRIEKindRestore {
p.flag &= ^inCreateOrDropTable
}
case *ast.CommonTableExpression, *ast.SubqueryExpr:
with := p.preprocessWith
lenWithCteBeforeOffset := len(with.cteBeforeOffset)
if lenWithCteBeforeOffset < 1 {
p.err = ErrInternal.GenWithStack("len(cteBeforeOffset) is less than one.Maybe it was deleted in somewhere.Should match in Enter and Leave")
break
}
beforeOffset := with.cteBeforeOffset[lenWithCteBeforeOffset-1]
with.cteBeforeOffset = with.cteBeforeOffset[:lenWithCteBeforeOffset-1]
with.cteCanUsed = with.cteCanUsed[:beforeOffset]
if cteNode, exist := x.(*ast.CommonTableExpression); exist {
with.cteCanUsed = append(with.cteCanUsed, cteNode.Name.L)
}
}

return in, p.err == nil
Expand Down Expand Up @@ -1422,15 +1444,19 @@ func (p *preprocessor) stmtType() string {

func (p *preprocessor) handleTableName(tn *ast.TableName) {
if tn.Schema.L == "" {
if _, ok := p.withName[tn.Name.L]; ok {
return

for _, cte := range p.preprocessWith.cteCanUsed {
if cte == tn.Name.L {
return
}
}

currentDB := p.ctx.GetSessionVars().CurrentDB
if currentDB == "" {
p.err = errors.Trace(ErrNoDB)
return
}

tn.Schema = model.NewCIStr(currentDB)
}

Expand Down
Loading