Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NEOS-1395 Improve the readability of TSQL query qualifier #2607

Merged
merged 3 commits into from
Sep 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 144 additions & 79 deletions worker/pkg/query-builder2/tsql/query-qualifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,75 @@
"strings"

"github.com/antlr4-go/antlr/v4"
parser "github.com/nucleuscloud/go-antlrv4-parser/tsql"
tsqlparser "github.com/nucleuscloud/go-antlrv4-parser/tsql"
)

/*
Updates columns names in where clause to be fully qualified
ex: SELECT * FROM users WHERE name = 'John' becomes SELECT * FROM users WHERE "users"."name" = 'John'

To view query tree use
tree.ToStringTree(parser.RuleNames, parser)

Example query tree for SELECT * FROM users WHERE name = 'John'
(tsql_file
(batch
(sql_clauses
(dml_clause
(select_statement_standalone
(select_statement
(query_expression
(query_specification
SELECT
(select_list
(select_list_elem
(asterisk *)))
FROM
(table_sources
(table_source
(table_source_item
(full_table_name
(id_ users)))))
WHERE
(search_condition
(predicate
(expression
(full_column_name
(id_
(keyword name))))
(comparison_operator =)
(expression
(primitive_expression
(primitive_constant 'John')))))))))))) <EOF>)
*/
func QualifyWhereCondition(sql string) (string, error) {
inputStream := antlr.NewInputStream(sql)

// create the lexer
lexer := tsqlparser.NewTSqlLexer(inputStream)
tokens := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)

// create the parser
parser := tsqlparser.NewTSqlParser(tokens)
// add error listener
errorListener := newTSqlErrorListener()
parser.AddErrorListener(errorListener)

listener := newTSqlListener()
tree := parser.Tsql_file()
// walk tree and listen to events
antlr.ParseTreeWalkerDefault.Walk(listener, tree)

if len(errorListener.Errors) > 0 {
return "", fmt.Errorf("SQL parsing errors: %s", strings.Join(errorListener.Errors, "; "))
}
if len(listener.Errors) > 0 {
return "", fmt.Errorf("SQL building errors: %s", strings.Join(listener.Errors, "; "))

Check warning on line 71 in worker/pkg/query-builder2/tsql/query-qualifier.go

View check run for this annotation

Codecov / codecov/patch

worker/pkg/query-builder2/tsql/query-qualifier.go#L71

Added line #L71 was not covered by tests
}

return listener.sqlString(), nil
}

type tSqlErrorListener struct {
*antlr.DefaultErrorListener
Errors []string
Expand All @@ -25,148 +91,147 @@
}

type tsqlListener struct {
*parser.BaseTSqlParserListener
currentTable string
inSearchCondition bool
sqlStack []string
*tsqlparser.BaseTSqlParserListener
currentTable string // stores most recent table found in from clause
inSearchCondition bool // tracks when we enter where clause in the tree
sqlStack []string // rebuilds new sql string
Errors []string
}

func newTSqlListener() *tsqlListener {
return &tsqlListener{}
}

func (l *tsqlListener) SqlString() string {
// builds sql string
func (l *tsqlListener) sqlString() string {
return strings.TrimSpace(strings.Join(l.sqlStack, ""))
}

func (l *tsqlListener) Push(str string) {
// adds string to sql stack
func (l *tsqlListener) push(str string) {
l.sqlStack = append(l.sqlStack, str)
}

func (l *tsqlListener) Pop() string {
// removes last element in sql stack
func (l *tsqlListener) pop() {
if len(l.sqlStack) < 1 {
l.Errors = append(l.Errors, "stack is empty unable to pop")
return ""
return

Check warning on line 119 in worker/pkg/query-builder2/tsql/query-qualifier.go

View check run for this annotation

Codecov / codecov/patch

worker/pkg/query-builder2/tsql/query-qualifier.go#L119

Added line #L119 was not covered by tests
}

result := l.sqlStack[len(l.sqlStack)-1]
l.sqlStack = l.sqlStack[:len(l.sqlStack)-1]
}

// creates new tree token with given text
func (l *tsqlListener) setToken(startToken, stopToken antlr.Token, text string) *antlr.CommonToken {
sourcePair := startToken.GetSource()
tokenType := startToken.GetTokenType()

return result
startIndex := startToken.GetStart()
stopIndex := stopToken.GetStop()
channel := startToken.GetChannel()

newToken := antlr.NewCommonToken(sourcePair, tokenType, channel, startIndex, stopIndex)
newToken.SetText(text)
return newToken
}

// EnterSearch_condition is called when production search_condition is entered.
func (l *tsqlListener) EnterSearch_condition(ctx *parser.Search_conditionContext) {
func (l *tsqlListener) addNodeText(node antlr.TerminalNode) {
if node.GetSymbol().GetTokenType() != antlr.TokenEOF {
text := node.GetText()
if text == "," {
// add space after commas
l.pop()
l.push(text)
l.push(" ")
} else if text == "." {
// remove space before periods
// should be table.column not table . column
l.pop()
l.push(text)
} else {
// add space after each node text
l.push(text)
l.push(" ")
}
}
}

func isTableTokenSet(ctx *tsqlparser.Full_column_nameContext) bool {
return ctx.Full_table_name() != nil && ctx.Full_table_name().GetText() != ""
}

/*
the following are parser events that are activated by walking the parsed tree
renaming these functions will break the parser
available listeners can be found at https://github.com/nucleuscloud/go-antlrv4-parser/blob/main/tsql/tsqlparser_base_listener.go
*/

// parser listener event when we enter where clause
func (l *tsqlListener) EnterSearch_condition(ctx *tsqlparser.Search_conditionContext) {
l.inSearchCondition = true
}

// ExitSearch_condition is called when production search_condition is exited.
func (l *tsqlListener) ExitSearch_condition(ctx *parser.Search_conditionContext) {
// parser listener event when we exit where clause
func (l *tsqlListener) ExitSearch_condition(ctx *tsqlparser.Search_conditionContext) {
l.inSearchCondition = false
}

// EnterSelect_statement is called when production select_statement is entered.
func (l *tsqlListener) EnterSelect_statement(ctx *parser.Select_statementContext) {
// parser listener event when we enter select statement
func (l *tsqlListener) EnterSelect_statement(ctx *tsqlparser.Select_statementContext) {
// important so we don't process select columns
l.inSearchCondition = false
}

// sets current table
func (l *tsqlListener) EnterTable_sources(ctx *parser.Table_sourcesContext) {
// sets current table found in from clause
func (l *tsqlListener) EnterTable_sources(ctx *tsqlparser.Table_sourcesContext) {
table := ctx.GetText()
l.currentTable = qualifyTableName(table)
}

// EnterTable_alias is called when production table_alias is entered.
func (l *tsqlListener) EnterTable_alias(ctx *parser.Table_aliasContext) {
// sets current table if alias found
func (l *tsqlListener) EnterTable_alias(ctx *tsqlparser.Table_aliasContext) {

Check warning on line 193 in worker/pkg/query-builder2/tsql/query-qualifier.go

View check run for this annotation

Codecov / codecov/patch

worker/pkg/query-builder2/tsql/query-qualifier.go#L193

Added line #L193 was not covered by tests
l.currentTable = ctx.GetText()
}

// rebuilds sql string from tree
// adds terminal node text to sql stack and adds appropriate spacing
func (l *tsqlListener) VisitTerminal(node antlr.TerminalNode) {
if node.GetSymbol().GetTokenType() != antlr.TokenEOF {
text := node.GetText()
if text == "," {
l.Pop()
l.Push(text)
l.Push(" ")
} else if text == "." {
l.Pop()
l.Push(text)
} else {
l.Push(text)
l.Push(" ")
}
}
}

func (l *tsqlListener) SetToken(startToken, stopToken antlr.Token, text string) *antlr.CommonToken {
sourcePair := startToken.GetSource()
tokenType := startToken.GetTokenType()

startIndex := startToken.GetStart()
stopIndex := stopToken.GetStop()
channel := startToken.GetChannel()

newToken := antlr.NewCommonToken(sourcePair, tokenType, channel, startIndex, stopIndex)
newToken.SetText(text)
return newToken
l.addNodeText(node)
}

// update table name and add qualifiers
func (l *tsqlListener) EnterFull_table_name(ctx *parser.Full_table_nameContext) {
func (l *tsqlListener) EnterFull_table_name(ctx *tsqlparser.Full_table_nameContext) {
if !l.inSearchCondition {
// ignore any table names not in where clause
return
}
newToken := l.SetToken(ctx.GetStart(), ctx.GetStop(), ensureQuoted(l.currentTable))
// creates new token with table name
newToken := l.setToken(ctx.GetStart(), ctx.GetStop(), ensureQuoted(l.currentTable))
ctx.RemoveLastChild()
ctx.AddTokenNode(newToken)
}

// updates column name
// add table name if missing
func (l *tsqlListener) EnterFull_column_name(ctx *parser.Full_column_nameContext) {
func (l *tsqlListener) EnterFull_column_name(ctx *tsqlparser.Full_column_nameContext) {
if !l.inSearchCondition {
// ignore any table names not in where clause
return
}

var text string
if ctx.Full_table_name() == nil || ctx.Full_table_name().GetText() == "" {
if !isTableTokenSet(ctx) {
text = fmt.Sprintf("%s.%s", ensureQuoted(l.currentTable), parseColumnName(ctx.GetText()))
} else {
text = parseColumnName(ctx.GetText())
}

newToken := l.SetToken(ctx.GetStart(), ctx.GetStop(), text)
newToken := l.setToken(ctx.GetStart(), ctx.GetStop(), text)
ctx.RemoveLastChild()
ctx.AddTokenNode(newToken)
}

func QualifyWhereCondition(sql string) (string, error) {
inputStream := antlr.NewInputStream(sql)

// create the lexer
lexer := parser.NewTSqlLexer(inputStream)
tokens := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)

// create the parser
p := parser.NewTSqlParser(tokens)
errorListener := newTSqlErrorListener()
p.AddErrorListener(errorListener)

listener := newTSqlListener()
tree := p.Tsql_file()
antlr.ParseTreeWalkerDefault.Walk(listener, tree)

if len(errorListener.Errors) > 0 {
return "", fmt.Errorf("SQL parsing errors: %s", strings.Join(errorListener.Errors, "; "))
}
if len(listener.Errors) > 0 {
return "", fmt.Errorf("SQL building errors: %s", strings.Join(listener.Errors, "; "))
}

return listener.SqlString(), nil
}

func parseColumnName(colText string) string {
split := strings.Split(colText, ".")
if len(split) == 1 {
Expand Down