diff --git a/executor/cte_test.go b/executor/cte_test.go index 3a9c3a5a7987f..236ea78e1d2ae 100644 --- a/executor/cte_test.go +++ b/executor/cte_test.go @@ -354,3 +354,23 @@ func TestCTEWithLimit(t *testing.T) { rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 union all select c1 + 1 from cte1 limit 4 offset 4) select * from cte1;") rows.Check(testkit.Rows("3", "4", "3", "4")) } + +// https://github.com/pingcap/tidb/issues/33965. +func TestCTEsInView(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + + tk.MustExec("create database if not exists test1;") + tk.MustExec("create table test.t (a int);") + tk.MustExec("create table test1.t (a int);") + tk.MustExec("insert into test.t values (1);") + tk.MustExec("insert into test1.t values (2);") + + tk.MustExec("use test;") + tk.MustExec("create definer='root'@'localhost' view test.v as with tt as (select * from t) select * from tt;") + tk.MustQuery("select * from test.v;").Check(testkit.Rows("1")) + tk.MustExec("use test1;") + tk.MustQuery("select * from test.v;").Check(testkit.Rows("1")) +} diff --git a/parser/ast/dml.go b/parser/ast/dml.go index 2349602da70ed..f57846e3acdaa 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -14,8 +14,6 @@ package ast import ( - "strings" - "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/format" @@ -286,14 +284,7 @@ func (n *TableName) restoreName(ctx *format.RestoreCtx) { ctx.WritePlain(".") } else if ctx.DefaultDB != "" { // Try CTE, for a CTE table name, we shouldn't write the database name. - ok := false - for _, name := range ctx.CTENames { - if strings.EqualFold(name, n.Name.String()) { - ok = true - break - } - } - if !ok { + if !ctx.IsCTETableName(n.Name.L) { ctx.WriteName(ctx.DefaultDB) ctx.WritePlain(".") } @@ -1117,7 +1108,7 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error { if n.IsRecursive { // If the CTE is recursive, we should make it visible for the CTE's query. // Otherwise, we should put it to stack after building the CTE's query. - ctx.CTENames = append(ctx.CTENames, cte.Name.L) + ctx.RecordCTEName(cte.Name.L) } if len(cte.ColNameList) > 0 { ctx.WritePlain(" (") @@ -1135,7 +1126,7 @@ func (n *WithClause) Restore(ctx *format.RestoreCtx) error { return err } if !n.IsRecursive { - ctx.CTENames = append(ctx.CTENames, cte.Name.L) + ctx.RecordCTEName(cte.Name.L) } } ctx.WritePlain(" ") @@ -1161,10 +1152,7 @@ func (n *WithClause) Accept(v Visitor) (Node, bool) { // Restore implements Node interface. func (n *SelectStmt) Restore(ctx *format.RestoreCtx) error { if n.WithBeforeBraces { - l := len(ctx.CTENames) - defer func() { - ctx.CTENames = ctx.CTENames[:l] - }() + defer ctx.RestoreCTEFunc()() err := n.With.Restore(ctx) if err != nil { return err @@ -1512,10 +1500,7 @@ type SetOprSelectList struct { // Restore implements Node interface. func (n *SetOprSelectList) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - l := len(ctx.CTENames) - defer func() { - ctx.CTENames = ctx.CTENames[:l] - }() + defer ctx.RestoreCTEFunc()() if err := n.With.Restore(ctx); err != nil { return errors.Annotate(err, "An error occurred while restore SetOprSelectList.With") } @@ -1616,10 +1601,7 @@ func (*SetOprStmt) resultSet() {} // Restore implements Node interface. func (n *SetOprStmt) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - l := len(ctx.CTENames) - defer func() { - ctx.CTENames = ctx.CTENames[:l] - }() + defer ctx.RestoreCTEFunc()() if err := n.With.Restore(ctx); err != nil { return errors.Annotate(err, "An error occurred while restore UnionStmt.With") } @@ -2201,10 +2183,7 @@ type DeleteStmt struct { // Restore implements Node interface. func (n *DeleteStmt) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - l := len(ctx.CTENames) - defer func() { - ctx.CTENames = ctx.CTENames[:l] - }() + defer ctx.RestoreCTEFunc()() err := n.With.Restore(ctx) if err != nil { return err @@ -2365,10 +2344,7 @@ type UpdateStmt struct { // Restore implements Node interface. func (n *UpdateStmt) Restore(ctx *format.RestoreCtx) error { if n.With != nil { - l := len(ctx.CTENames) - defer func() { - ctx.CTENames = ctx.CTENames[:l] - }() + defer ctx.RestoreCTEFunc()() err := n.With.Restore(ctx) if err != nil { return err diff --git a/parser/ast/expressions.go b/parser/ast/expressions.go index 6a46ab332c831..66f40eb205952 100644 --- a/parser/ast/expressions.go +++ b/parser/ast/expressions.go @@ -512,7 +512,7 @@ type ColumnName struct { // Restore implements Node interface. func (n *ColumnName) Restore(ctx *format.RestoreCtx) error { - if n.Schema.O != "" { + if n.Schema.O != "" && !ctx.IsCTETableName(n.Table.L) { ctx.WriteName(n.Schema.O) ctx.WritePlain(".") } diff --git a/parser/format/format.go b/parser/format/format.go index ef003d6a78d6d..4141b0baf119a 100644 --- a/parser/format/format.go +++ b/parser/format/format.go @@ -305,12 +305,12 @@ type RestoreCtx struct { Flags RestoreFlags In io.Writer DefaultDB string - CTENames []string + CTERestorer } // NewRestoreCtx returns a new `RestoreCtx`. func NewRestoreCtx(flags RestoreFlags, in io.Writer) *RestoreCtx { - return &RestoreCtx{flags, in, "", make([]string, 0)} + return &RestoreCtx{Flags: flags, In: in, DefaultDB: ""} } // WriteKeyWord writes the `keyWord` into writer. @@ -387,3 +387,33 @@ func (ctx *RestoreCtx) WritePlain(plainText string) { func (ctx *RestoreCtx) WritePlainf(format string, a ...interface{}) { fmt.Fprintf(ctx.In, format, a...) } + +// CTERestorer is used by WithClause related nodes restore. +type CTERestorer struct { + CTENames []string +} + +// IsCTETableName returns true if the given tableName comes from CTE. +func (c *CTERestorer) IsCTETableName(nameL string) bool { + for _, n := range c.CTENames { + if n == nameL { + return true + } + } + return false +} + +func (c *CTERestorer) RecordCTEName(nameL string) { + c.CTENames = append(c.CTENames, nameL) +} + +func (c *CTERestorer) RestoreCTEFunc() func() { + l := len(c.CTENames) + return func() { + if l == 0 { + c.CTENames = nil + } else { + c.CTENames = c.CTENames[:l] + } + } +}