From 3de13c1bcc956e481035360a93b5522d492c40f3 Mon Sep 17 00:00:00 2001 From: Haibin Xie Date: Fri, 23 Aug 2019 14:51:18 +0800 Subject: [PATCH 1/3] ast: change the order of visiting select stmt --- ast/dml.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ast/dml.go b/ast/dml.go index 0fd5a5304..9970d80c6 100755 --- a/ast/dml.go +++ b/ast/dml.go @@ -764,6 +764,8 @@ type SelectStmt struct { IsAfterUnionDistinct bool // IsInBraces indicates whether it's a stmt in brace. IsInBraces bool + // QueryBlockOffset indicates the order of this SelectStmt if counted from left to right in the sql text. + QueryBlockOffset int } // Restore implements Node interface. @@ -906,6 +908,14 @@ func (n *SelectStmt) Accept(v Visitor) (Node, bool) { n.TableHints = newHints } + if n.Fields != nil { + node, ok := n.Fields.Accept(v) + if !ok { + return n, false + } + n.Fields = node.(*FieldList) + } + if n.From != nil { node, ok := n.From.Accept(v) if !ok { @@ -922,14 +932,6 @@ func (n *SelectStmt) Accept(v Visitor) (Node, bool) { n.Where = node.(ExprNode) } - if n.Fields != nil { - node, ok := n.Fields.Accept(v) - if !ok { - return n, false - } - n.Fields = node.(*FieldList) - } - if n.GroupBy != nil { node, ok := n.GroupBy.Accept(v) if !ok { From 52dcfa05765507d038877c2abda30199fec90a41 Mon Sep 17 00:00:00 2001 From: Haibin Xie Date: Mon, 26 Aug 2019 18:49:09 +0800 Subject: [PATCH 2/3] address comments --- parser.go | 9 ++++++--- parser.y | 3 +++ parser_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ yy_parser.go | 17 ++++++++++++----- 4 files changed, 61 insertions(+), 8 deletions(-) diff --git a/parser.go b/parser.go index 2efe4ff36..02464aaa9 100644 --- a/parser.go +++ b/parser.go @@ -12539,9 +12539,10 @@ yynewstate: case 1080: { st := &ast.SelectStmt{ - SelectStmtOpts: yyS[yypt-1].item.(*ast.SelectStmtOpts), - Distinct: yyS[yypt-1].item.(*ast.SelectStmtOpts).Distinct, - Fields: yyS[yypt-0].item.(*ast.FieldList), + SelectStmtOpts: yyS[yypt-1].item.(*ast.SelectStmtOpts), + Distinct: yyS[yypt-1].item.(*ast.SelectStmtOpts).Distinct, + Fields: yyS[yypt-0].item.(*ast.FieldList), + QueryBlockOffset: parser.queryBlockOffset(), } if st.SelectStmtOpts.TableHints != nil { st.TableHints = st.SelectStmtOpts.TableHints @@ -14580,6 +14581,7 @@ yynewstate: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) + parser.blockOffset = 0 } } case 1521: @@ -14590,6 +14592,7 @@ yynewstate: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) + parser.blockOffset = 0 } } case 1522: diff --git a/parser.y b/parser.y index 3f1eac45c..dc06d87f1 100644 --- a/parser.y +++ b/parser.y @@ -5738,6 +5738,7 @@ SelectStmtBasic: SelectStmtOpts: $2.(*ast.SelectStmtOpts), Distinct: $2.(*ast.SelectStmtOpts).Distinct, Fields: $3.(*ast.FieldList), + QueryBlockOffset: parser.queryBlockOffset(), } if st.SelectStmtOpts.TableHints != nil { st.TableHints = st.SelectStmtOpts.TableHints @@ -8102,6 +8103,7 @@ StatementList: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) + parser.blockOffset = 0 } } | StatementList ';' Statement @@ -8112,6 +8114,7 @@ StatementList: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) + parser.blockOffset = 0 } } diff --git a/parser_test.go b/parser_test.go index e39a8e370..7b5d0d9e2 100644 --- a/parser_test.go +++ b/parser_test.go @@ -4290,3 +4290,43 @@ func (checker *nodeTextCleaner) Enter(in ast.Node) (out ast.Node, skipChildren b func (checker *nodeTextCleaner) Leave(in ast.Node) (out ast.Node, ok bool) { return in, true } + +type queryBlockOffsetChecker struct { + curOffset int + mismatch bool +} + +func (checker *queryBlockOffsetChecker) Enter(in ast.Node) (ast.Node, bool) { + sel, ok := in.(*ast.SelectStmt) + if !ok { + return in, false + } + checker.curOffset++ + if sel.QueryBlockOffset != checker.curOffset { + checker.mismatch = true + } + return in, false +} + +func (checker *queryBlockOffsetChecker) Leave(in ast.Node) (out ast.Node, ok bool) { + return in, true +} + +func (s *testParserSuite) SelectStmtOffset(c *C) { + parser := parser.New() + sqls := []string{ + "select * from t; select * from t", + "select a, (select count(*) from t t1 where t1.b > t.a) from t where b > (select b from t t2 where t2.b = t.a limit 1)", + "select count(*) from t t1 where t1.a < (select count(*) from t t2 where t1.a > t2.a)", + } + checker := &queryBlockOffsetChecker{} + for _, sql := range sqls { + stmts, _, err := parser.Parse(sql, "", "") + c.Assert(err, IsNil) + for _, stmt := range stmts { + checker.curOffset = 0 + stmt.Accept(checker) + c.Assert(checker.mismatch, IsFalse) + } + } +} diff --git a/yy_parser.go b/yy_parser.go index 46e61000a..77bb7e6e3 100644 --- a/yy_parser.go +++ b/yy_parser.go @@ -91,11 +91,12 @@ func TrimComment(txt string) string { // Parser represents a parser instance. Some temporary objects are stored in it to reduce object allocation during Parse function. type Parser struct { - charset string - collation string - result []ast.StmtNode - src string - lexer Scanner + charset string + collation string + result []ast.StmtNode + src string + lexer Scanner + blockOffset int // the following fields are used by yyParse to reduce allocation. cache []yySymType @@ -134,6 +135,7 @@ func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode parser.collation = collation parser.src = sql parser.result = parser.result[:0] + parser.blockOffset = 0 var l yyLexer parser.lexer.reset(sql) @@ -217,6 +219,11 @@ func (parser *Parser) endOffset(v *yySymType) int { return offset } +func (parser *Parser) queryBlockOffset() int { + parser.blockOffset++ + return parser.blockOffset +} + func toInt(l yyLexer, lval *yySymType, str string) int { n, err := strconv.ParseUint(str, 10, 64) if err != nil { From e142743e52c07735c39a9dc3d61fefb784e44693 Mon Sep 17 00:00:00 2001 From: Haibin Xie Date: Wed, 28 Aug 2019 12:04:59 +0800 Subject: [PATCH 3/3] Revert "address comments" This reverts commit 52dcfa05765507d038877c2abda30199fec90a41. --- parser.go | 9 +++------ parser.y | 3 --- parser_test.go | 40 ---------------------------------------- yy_parser.go | 17 +++++------------ 4 files changed, 8 insertions(+), 61 deletions(-) diff --git a/parser.go b/parser.go index 02464aaa9..2efe4ff36 100644 --- a/parser.go +++ b/parser.go @@ -12539,10 +12539,9 @@ yynewstate: case 1080: { st := &ast.SelectStmt{ - SelectStmtOpts: yyS[yypt-1].item.(*ast.SelectStmtOpts), - Distinct: yyS[yypt-1].item.(*ast.SelectStmtOpts).Distinct, - Fields: yyS[yypt-0].item.(*ast.FieldList), - QueryBlockOffset: parser.queryBlockOffset(), + SelectStmtOpts: yyS[yypt-1].item.(*ast.SelectStmtOpts), + Distinct: yyS[yypt-1].item.(*ast.SelectStmtOpts).Distinct, + Fields: yyS[yypt-0].item.(*ast.FieldList), } if st.SelectStmtOpts.TableHints != nil { st.TableHints = st.SelectStmtOpts.TableHints @@ -14581,7 +14580,6 @@ yynewstate: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) - parser.blockOffset = 0 } } case 1521: @@ -14592,7 +14590,6 @@ yynewstate: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) - parser.blockOffset = 0 } } case 1522: diff --git a/parser.y b/parser.y index dc06d87f1..3f1eac45c 100644 --- a/parser.y +++ b/parser.y @@ -5738,7 +5738,6 @@ SelectStmtBasic: SelectStmtOpts: $2.(*ast.SelectStmtOpts), Distinct: $2.(*ast.SelectStmtOpts).Distinct, Fields: $3.(*ast.FieldList), - QueryBlockOffset: parser.queryBlockOffset(), } if st.SelectStmtOpts.TableHints != nil { st.TableHints = st.SelectStmtOpts.TableHints @@ -8103,7 +8102,6 @@ StatementList: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) - parser.blockOffset = 0 } } | StatementList ';' Statement @@ -8114,7 +8112,6 @@ StatementList: s.SetText(lexer.stmtText()) } parser.result = append(parser.result, s) - parser.blockOffset = 0 } } diff --git a/parser_test.go b/parser_test.go index 7b5d0d9e2..e39a8e370 100644 --- a/parser_test.go +++ b/parser_test.go @@ -4290,43 +4290,3 @@ func (checker *nodeTextCleaner) Enter(in ast.Node) (out ast.Node, skipChildren b func (checker *nodeTextCleaner) Leave(in ast.Node) (out ast.Node, ok bool) { return in, true } - -type queryBlockOffsetChecker struct { - curOffset int - mismatch bool -} - -func (checker *queryBlockOffsetChecker) Enter(in ast.Node) (ast.Node, bool) { - sel, ok := in.(*ast.SelectStmt) - if !ok { - return in, false - } - checker.curOffset++ - if sel.QueryBlockOffset != checker.curOffset { - checker.mismatch = true - } - return in, false -} - -func (checker *queryBlockOffsetChecker) Leave(in ast.Node) (out ast.Node, ok bool) { - return in, true -} - -func (s *testParserSuite) SelectStmtOffset(c *C) { - parser := parser.New() - sqls := []string{ - "select * from t; select * from t", - "select a, (select count(*) from t t1 where t1.b > t.a) from t where b > (select b from t t2 where t2.b = t.a limit 1)", - "select count(*) from t t1 where t1.a < (select count(*) from t t2 where t1.a > t2.a)", - } - checker := &queryBlockOffsetChecker{} - for _, sql := range sqls { - stmts, _, err := parser.Parse(sql, "", "") - c.Assert(err, IsNil) - for _, stmt := range stmts { - checker.curOffset = 0 - stmt.Accept(checker) - c.Assert(checker.mismatch, IsFalse) - } - } -} diff --git a/yy_parser.go b/yy_parser.go index 77bb7e6e3..46e61000a 100644 --- a/yy_parser.go +++ b/yy_parser.go @@ -91,12 +91,11 @@ func TrimComment(txt string) string { // Parser represents a parser instance. Some temporary objects are stored in it to reduce object allocation during Parse function. type Parser struct { - charset string - collation string - result []ast.StmtNode - src string - lexer Scanner - blockOffset int + charset string + collation string + result []ast.StmtNode + src string + lexer Scanner // the following fields are used by yyParse to reduce allocation. cache []yySymType @@ -135,7 +134,6 @@ func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode parser.collation = collation parser.src = sql parser.result = parser.result[:0] - parser.blockOffset = 0 var l yyLexer parser.lexer.reset(sql) @@ -219,11 +217,6 @@ func (parser *Parser) endOffset(v *yySymType) int { return offset } -func (parser *Parser) queryBlockOffset() int { - parser.blockOffset++ - return parser.blockOffset -} - func toInt(l yyLexer, lval *yySymType, str string) int { n, err := strconv.ParseUint(str, 10, 64) if err != nil {