From d439b638aa313f09c9efc6386381b7223c3f7ad2 Mon Sep 17 00:00:00 2001 From: Alisha Date: Wed, 4 Sep 2024 08:05:51 -0700 Subject: [PATCH 1/3] update mmsql sql parser with comments --- .../query-builder2/tsql/query-qualifier.go | 90 ++++++++++++------- 1 file changed, 56 insertions(+), 34 deletions(-) diff --git a/worker/pkg/query-builder2/tsql/query-qualifier.go b/worker/pkg/query-builder2/tsql/query-qualifier.go index 3727f6d2bc..235051a722 100644 --- a/worker/pkg/query-builder2/tsql/query-qualifier.go +++ b/worker/pkg/query-builder2/tsql/query-qualifier.go @@ -26,9 +26,9 @@ func (l *tSqlErrorListener) SyntaxError(recognizer antlr.Recognizer, offendingSy type tsqlListener struct { *parser.BaseTSqlParserListener - currentTable string - inSearchCondition bool - sqlStack []string + currentTable string // stores most recent table found in from clause + inSearchCondition bool // tracks when we enter where clause in the tree + sqlStack []string // rebuilds new sql string Errors []string } @@ -36,15 +36,18 @@ func newTSqlListener() *tsqlListener { return &tsqlListener{} } -func (l *tsqlListener) SqlString() string { +// builds sql string +func (l *tsqlListener) sqlString() string { return strings.TrimSpace(strings.Join(l.sqlStack, "")) } -func (l *tsqlListener) Push(str string) { +// adds string to sql stack +func (l *tsqlListener) push(str string) { l.sqlStack = append(l.sqlStack, str) } -func (l *tsqlListener) Pop() string { +// removes last element in sql stack +func (l *tsqlListener) pop() string { if len(l.sqlStack) < 1 { l.Errors = append(l.Errors, "stack is empty unable to pop") return "" @@ -56,68 +59,84 @@ func (l *tsqlListener) Pop() string { return result } -// EnterSearch_condition is called when production search_condition is entered. +// creates new tree token with given text +func (l *tsqlListener) setToken(startToken, stopToken antlr.Token, text string) *antlr.CommonToken { + sourcePair := startToken.GetSource() + tokenType := startToken.GetTokenType() + + startIndex := startToken.GetStart() + stopIndex := stopToken.GetStop() + channel := startToken.GetChannel() + + newToken := antlr.NewCommonToken(sourcePair, tokenType, channel, startIndex, stopIndex) + newToken.SetText(text) + return newToken +} + +/* + the following are parser events that are activated by walking the parsed tree + renaming these functions will break the parser + available listeners can be found at https://github.com/nucleuscloud/go-antlrv4-parser/blob/main/tsql/tsqlparser_base_listener.go +*/ + +// parser listener event when we enter where clause func (l *tsqlListener) EnterSearch_condition(ctx *parser.Search_conditionContext) { l.inSearchCondition = true } -// ExitSearch_condition is called when production search_condition is exited. +// parser listener event when we exit where clause func (l *tsqlListener) ExitSearch_condition(ctx *parser.Search_conditionContext) { l.inSearchCondition = false } -// EnterSelect_statement is called when production select_statement is entered. +// parser listener event when we enter select statement func (l *tsqlListener) EnterSelect_statement(ctx *parser.Select_statementContext) { + // important so we don't process select columns l.inSearchCondition = false } -// sets current table +// sets current table found in from clause func (l *tsqlListener) EnterTable_sources(ctx *parser.Table_sourcesContext) { table := ctx.GetText() l.currentTable = qualifyTableName(table) } -// EnterTable_alias is called when production table_alias is entered. +// sets current table if alias found func (l *tsqlListener) EnterTable_alias(ctx *parser.Table_aliasContext) { l.currentTable = ctx.GetText() } +// rebuilds sql string from tree +// adds terminal node text to sql stack and adds appropriate spacing func (l *tsqlListener) VisitTerminal(node antlr.TerminalNode) { if node.GetSymbol().GetTokenType() != antlr.TokenEOF { text := node.GetText() if text == "," { - l.Pop() - l.Push(text) - l.Push(" ") + // add space after commas + l.pop() + l.push(text) + l.push(" ") } else if text == "." { - l.Pop() - l.Push(text) + // remove space before periods + // should be table.column not table . column + l.pop() + l.push(text) } else { - l.Push(text) - l.Push(" ") + // add space after each node text + l.push(text) + l.push(" ") } } } -func (l *tsqlListener) SetToken(startToken, stopToken antlr.Token, text string) *antlr.CommonToken { - sourcePair := startToken.GetSource() - tokenType := startToken.GetTokenType() - - startIndex := startToken.GetStart() - stopIndex := stopToken.GetStop() - channel := startToken.GetChannel() - - newToken := antlr.NewCommonToken(sourcePair, tokenType, channel, startIndex, stopIndex) - newToken.SetText(text) - return newToken -} - // update table name and add qualifiers func (l *tsqlListener) EnterFull_table_name(ctx *parser.Full_table_nameContext) { if !l.inSearchCondition { + // ignore any table names not in where clause return } - newToken := l.SetToken(ctx.GetStart(), ctx.GetStop(), ensureQuoted(l.currentTable)) + // creates new token with table name + newToken := l.setToken(ctx.GetStart(), ctx.GetStop(), ensureQuoted(l.currentTable)) ctx.RemoveLastChild() ctx.AddTokenNode(newToken) } @@ -126,6 +145,7 @@ func (l *tsqlListener) EnterFull_table_name(ctx *parser.Full_table_nameContext) // add table name if missing func (l *tsqlListener) EnterFull_column_name(ctx *parser.Full_column_nameContext) { if !l.inSearchCondition { + // ignore any table names not in where clause return } @@ -136,7 +156,7 @@ func (l *tsqlListener) EnterFull_column_name(ctx *parser.Full_column_nameContext text = parseColumnName(ctx.GetText()) } - newToken := l.SetToken(ctx.GetStart(), ctx.GetStop(), text) + newToken := l.setToken(ctx.GetStart(), ctx.GetStop(), text) ctx.RemoveLastChild() ctx.AddTokenNode(newToken) } @@ -150,11 +170,13 @@ func QualifyWhereCondition(sql string) (string, error) { // create the parser p := parser.NewTSqlParser(tokens) + // add error listener errorListener := newTSqlErrorListener() p.AddErrorListener(errorListener) listener := newTSqlListener() tree := p.Tsql_file() + // walk tree and listen to events antlr.ParseTreeWalkerDefault.Walk(listener, tree) if len(errorListener.Errors) > 0 { @@ -164,7 +186,7 @@ func QualifyWhereCondition(sql string) (string, error) { return "", fmt.Errorf("SQL building errors: %s", strings.Join(listener.Errors, "; ")) } - return listener.SqlString(), nil + return listener.sqlString(), nil } func parseColumnName(colText string) string { From 8bd1c9811aabeb492ba328f2187798a62edf0833 Mon Sep 17 00:00:00 2001 From: Alisha Date: Wed, 4 Sep 2024 09:40:11 -0700 Subject: [PATCH 2/3] clean up tsql parser --- .../query-builder2/tsql/query-qualifier.go | 158 +++++++++++------- 1 file changed, 102 insertions(+), 56 deletions(-) diff --git a/worker/pkg/query-builder2/tsql/query-qualifier.go b/worker/pkg/query-builder2/tsql/query-qualifier.go index 235051a722..d70cad406d 100644 --- a/worker/pkg/query-builder2/tsql/query-qualifier.go +++ b/worker/pkg/query-builder2/tsql/query-qualifier.go @@ -5,9 +5,75 @@ import ( "strings" "github.com/antlr4-go/antlr/v4" - parser "github.com/nucleuscloud/go-antlrv4-parser/tsql" + tsqlparser "github.com/nucleuscloud/go-antlrv4-parser/tsql" ) +/* + Updates columns names in where clause to be fully qualified + ex: SELECT * FROM users WHERE name = 'John' becomes SELECT * FROM users WHERE "users"."name" = 'John' + + To view query tree use + tree.ToStringTree(parser.RuleNames, parser) + + Example query tree for SELECT * FROM users WHERE name = 'John' + (tsql_file + (batch + (sql_clauses + (dml_clause + (select_statement_standalone + (select_statement + (query_expression + (query_specification + SELECT + (select_list + (select_list_elem + (asterisk *))) + FROM + (table_sources + (table_source + (table_source_item + (full_table_name + (id_ users))))) + WHERE + (search_condition + (predicate + (expression + (full_column_name + (id_ + (keyword name)))) + (comparison_operator =) + (expression + (primitive_expression + (primitive_constant 'John')))))))))))) ) +*/ +func QualifyWhereCondition(sql string) (string, error) { + inputStream := antlr.NewInputStream(sql) + + // create the lexer + lexer := tsqlparser.NewTSqlLexer(inputStream) + tokens := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel) + + // create the parser + parser := tsqlparser.NewTSqlParser(tokens) + // add error listener + errorListener := newTSqlErrorListener() + parser.AddErrorListener(errorListener) + + listener := newTSqlListener() + tree := parser.Tsql_file() + // walk tree and listen to events + antlr.ParseTreeWalkerDefault.Walk(listener, tree) + + if len(errorListener.Errors) > 0 { + return "", fmt.Errorf("SQL parsing errors: %s", strings.Join(errorListener.Errors, "; ")) + } + if len(listener.Errors) > 0 { + return "", fmt.Errorf("SQL building errors: %s", strings.Join(listener.Errors, "; ")) + } + + return listener.sqlString(), nil +} + type tSqlErrorListener struct { *antlr.DefaultErrorListener Errors []string @@ -25,7 +91,7 @@ func (l *tSqlErrorListener) SyntaxError(recognizer antlr.Recognizer, offendingSy } type tsqlListener struct { - *parser.BaseTSqlParserListener + *tsqlparser.BaseTSqlParserListener currentTable string // stores most recent table found in from clause inSearchCondition bool // tracks when we enter where clause in the tree sqlStack []string // rebuilds new sql string @@ -73,6 +139,31 @@ func (l *tsqlListener) setToken(startToken, stopToken antlr.Token, text string) return newToken } +func (l *tsqlListener) addNodeText(node antlr.TerminalNode) { + if node.GetSymbol().GetTokenType() != antlr.TokenEOF { + text := node.GetText() + if text == "," { + // add space after commas + l.pop() + l.push(text) + l.push(" ") + } else if text == "." { + // remove space before periods + // should be table.column not table . column + l.pop() + l.push(text) + } else { + // add space after each node text + l.push(text) + l.push(" ") + } + } +} + +func isTableTokenSet(ctx *tsqlparser.Full_column_nameContext) bool { + return ctx.Full_table_name() != nil && ctx.Full_table_name().GetText() != "" +} + /* the following are parser events that are activated by walking the parsed tree renaming these functions will break the parser @@ -80,57 +171,40 @@ func (l *tsqlListener) setToken(startToken, stopToken antlr.Token, text string) */ // parser listener event when we enter where clause -func (l *tsqlListener) EnterSearch_condition(ctx *parser.Search_conditionContext) { +func (l *tsqlListener) EnterSearch_condition(ctx *tsqlparser.Search_conditionContext) { l.inSearchCondition = true } // parser listener event when we exit where clause -func (l *tsqlListener) ExitSearch_condition(ctx *parser.Search_conditionContext) { +func (l *tsqlListener) ExitSearch_condition(ctx *tsqlparser.Search_conditionContext) { l.inSearchCondition = false } // parser listener event when we enter select statement -func (l *tsqlListener) EnterSelect_statement(ctx *parser.Select_statementContext) { +func (l *tsqlListener) EnterSelect_statement(ctx *tsqlparser.Select_statementContext) { // important so we don't process select columns l.inSearchCondition = false } // sets current table found in from clause -func (l *tsqlListener) EnterTable_sources(ctx *parser.Table_sourcesContext) { +func (l *tsqlListener) EnterTable_sources(ctx *tsqlparser.Table_sourcesContext) { table := ctx.GetText() l.currentTable = qualifyTableName(table) } // sets current table if alias found -func (l *tsqlListener) EnterTable_alias(ctx *parser.Table_aliasContext) { +func (l *tsqlListener) EnterTable_alias(ctx *tsqlparser.Table_aliasContext) { l.currentTable = ctx.GetText() } // rebuilds sql string from tree // adds terminal node text to sql stack and adds appropriate spacing func (l *tsqlListener) VisitTerminal(node antlr.TerminalNode) { - if node.GetSymbol().GetTokenType() != antlr.TokenEOF { - text := node.GetText() - if text == "," { - // add space after commas - l.pop() - l.push(text) - l.push(" ") - } else if text == "." { - // remove space before periods - // should be table.column not table . column - l.pop() - l.push(text) - } else { - // add space after each node text - l.push(text) - l.push(" ") - } - } + l.addNodeText(node) } // update table name and add qualifiers -func (l *tsqlListener) EnterFull_table_name(ctx *parser.Full_table_nameContext) { +func (l *tsqlListener) EnterFull_table_name(ctx *tsqlparser.Full_table_nameContext) { if !l.inSearchCondition { // ignore any table names not in where clause return @@ -143,14 +217,14 @@ func (l *tsqlListener) EnterFull_table_name(ctx *parser.Full_table_nameContext) // updates column name // add table name if missing -func (l *tsqlListener) EnterFull_column_name(ctx *parser.Full_column_nameContext) { +func (l *tsqlListener) EnterFull_column_name(ctx *tsqlparser.Full_column_nameContext) { if !l.inSearchCondition { // ignore any table names not in where clause return } var text string - if ctx.Full_table_name() == nil || ctx.Full_table_name().GetText() == "" { + if !isTableTokenSet(ctx) { text = fmt.Sprintf("%s.%s", ensureQuoted(l.currentTable), parseColumnName(ctx.GetText())) } else { text = parseColumnName(ctx.GetText()) @@ -161,34 +235,6 @@ func (l *tsqlListener) EnterFull_column_name(ctx *parser.Full_column_nameContext ctx.AddTokenNode(newToken) } -func QualifyWhereCondition(sql string) (string, error) { - inputStream := antlr.NewInputStream(sql) - - // create the lexer - lexer := parser.NewTSqlLexer(inputStream) - tokens := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel) - - // create the parser - p := parser.NewTSqlParser(tokens) - // add error listener - errorListener := newTSqlErrorListener() - p.AddErrorListener(errorListener) - - listener := newTSqlListener() - tree := p.Tsql_file() - // walk tree and listen to events - antlr.ParseTreeWalkerDefault.Walk(listener, tree) - - if len(errorListener.Errors) > 0 { - return "", fmt.Errorf("SQL parsing errors: %s", strings.Join(errorListener.Errors, "; ")) - } - if len(listener.Errors) > 0 { - return "", fmt.Errorf("SQL building errors: %s", strings.Join(listener.Errors, "; ")) - } - - return listener.sqlString(), nil -} - func parseColumnName(colText string) string { split := strings.Split(colText, ".") if len(split) == 1 { From 388c4762d2e23b78c93738337dd8c4d974a9cb8c Mon Sep 17 00:00:00 2001 From: Alisha Date: Wed, 4 Sep 2024 09:52:52 -0700 Subject: [PATCH 3/3] fix lint --- worker/pkg/query-builder2/tsql/query-qualifier.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/worker/pkg/query-builder2/tsql/query-qualifier.go b/worker/pkg/query-builder2/tsql/query-qualifier.go index d70cad406d..956e0f6f7c 100644 --- a/worker/pkg/query-builder2/tsql/query-qualifier.go +++ b/worker/pkg/query-builder2/tsql/query-qualifier.go @@ -113,16 +113,13 @@ func (l *tsqlListener) push(str string) { } // removes last element in sql stack -func (l *tsqlListener) pop() string { +func (l *tsqlListener) pop() { if len(l.sqlStack) < 1 { l.Errors = append(l.Errors, "stack is empty unable to pop") - return "" + return } - result := l.sqlStack[len(l.sqlStack)-1] l.sqlStack = l.sqlStack[:len(l.sqlStack)-1] - - return result } // creates new tree token with given text