Skip to content

Commit

Permalink
NEOS-1395 Improve the readability of TSQL query qualifier (#2607)
Browse files Browse the repository at this point in the history
  • Loading branch information
alishakawaguchi authored Sep 4, 2024
1 parent 870dfaa commit e591c93
Showing 1 changed file with 144 additions and 79 deletions.
223 changes: 144 additions & 79 deletions worker/pkg/query-builder2/tsql/query-qualifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,75 @@ import (
"strings"

"github.com/antlr4-go/antlr/v4"
parser "github.com/nucleuscloud/go-antlrv4-parser/tsql"
tsqlparser "github.com/nucleuscloud/go-antlrv4-parser/tsql"
)

// QualifyWhereCondition rewrites a T-SQL statement so that column names in
// its where clause are fully qualified with the table (or table alias) found
// in the from clause.
//
// Example:
//
//	SELECT * FROM users WHERE name = 'John'
//
// becomes
//
//	SELECT * FROM users WHERE "users"."name" = 'John'
//
// It returns an error when the error listener recorded parsing errors or when
// the tree listener recorded errors while rebuilding the statement.
//
// To view the query tree use
//
//	tree.ToStringTree(parser.RuleNames, parser)
//
// Example query tree for SELECT * FROM users WHERE name = 'John':
//
//	(tsql_file
//	  (batch
//	    (sql_clauses
//	      (dml_clause
//	        (select_statement_standalone
//	          (select_statement
//	            (query_expression
//	              (query_specification
//	                SELECT
//	                (select_list
//	                  (select_list_elem
//	                    (asterisk *)))
//	                FROM
//	                (table_sources
//	                  (table_source
//	                    (table_source_item
//	                      (full_table_name
//	                        (id_ users)))))
//	                WHERE
//	                (search_condition
//	                  (predicate
//	                    (expression
//	                      (full_column_name
//	                        (id_
//	                          (keyword name))))
//	                    (comparison_operator =)
//	                    (expression
//	                      (primitive_expression
//	                        (primitive_constant 'John')))))))))))) <EOF>)
func QualifyWhereCondition(sql string) (string, error) {
	inputStream := antlr.NewInputStream(sql)

	// create the lexer
	lexer := tsqlparser.NewTSqlLexer(inputStream)
	tokens := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)

	// create the parser and attach the listener that collects parsing errors
	parser := tsqlparser.NewTSqlParser(tokens)
	errorListener := newTSqlErrorListener()
	parser.AddErrorListener(errorListener)

	tree := parser.Tsql_file()

	// Check parsing errors before walking: a tree built from invalid SQL may
	// hold incomplete nodes, and the rebuilt string would be discarded anyway.
	if len(errorListener.Errors) > 0 {
		return "", fmt.Errorf("SQL parsing errors: %s", strings.Join(errorListener.Errors, "; "))
	}

	// walk tree and listen to events
	listener := newTSqlListener()
	antlr.ParseTreeWalkerDefault.Walk(listener, tree)
	if len(listener.Errors) > 0 {
		return "", fmt.Errorf("SQL building errors: %s", strings.Join(listener.Errors, "; "))
	}

	return listener.sqlString(), nil
}

type tSqlErrorListener struct {
*antlr.DefaultErrorListener
Errors []string
Expand All @@ -25,148 +91,147 @@ func (l *tSqlErrorListener) SyntaxError(recognizer antlr.Recognizer, offendingSy
}

type tsqlListener struct {
*parser.BaseTSqlParserListener
currentTable string
inSearchCondition bool
sqlStack []string
*tsqlparser.BaseTSqlParserListener
currentTable string // stores most recent table found in from clause
inSearchCondition bool // tracks when we enter where clause in the tree
sqlStack []string // rebuilds new sql string
Errors []string
}

// newTSqlListener returns a pointer to a zero-valued tsqlListener.
func newTSqlListener() *tsqlListener {
	return new(tsqlListener)
}

func (l *tsqlListener) SqlString() string {
// builds sql string
func (l *tsqlListener) sqlString() string {
return strings.TrimSpace(strings.Join(l.sqlStack, ""))
}

func (l *tsqlListener) Push(str string) {
// adds string to sql stack
func (l *tsqlListener) push(str string) {
l.sqlStack = append(l.sqlStack, str)
}

func (l *tsqlListener) Pop() string {
// removes last element in sql stack
func (l *tsqlListener) pop() {
if len(l.sqlStack) < 1 {
l.Errors = append(l.Errors, "stack is empty unable to pop")
return ""
return
}

result := l.sqlStack[len(l.sqlStack)-1]
l.sqlStack = l.sqlStack[:len(l.sqlStack)-1]
}

// setToken creates a new tree token carrying the given text.
// The source, token type, channel and start index are copied from startToken;
// the stop index is taken from stopToken, so the new token covers the range
// from the start of startToken to the end of stopToken.
func (l *tsqlListener) setToken(startToken, stopToken antlr.Token, text string) *antlr.CommonToken {
	sourcePair := startToken.GetSource()
	tokenType := startToken.GetTokenType()

	startIndex := startToken.GetStart()
	stopIndex := stopToken.GetStop()
	channel := startToken.GetChannel()

	newToken := antlr.NewCommonToken(sourcePair, tokenType, channel, startIndex, stopIndex)
	newToken.SetText(text)
	return newToken
}

// EnterSearch_condition is called when production search_condition is entered.
func (l *tsqlListener) EnterSearch_condition(ctx *parser.Search_conditionContext) {
func (l *tsqlListener) addNodeText(node antlr.TerminalNode) {
if node.GetSymbol().GetTokenType() != antlr.TokenEOF {
text := node.GetText()
if text == "," {
// add space after commas
l.pop()
l.push(text)
l.push(" ")
} else if text == "." {
// remove space before periods
// should be table.column not table . column
l.pop()
l.push(text)
} else {
// add space after each node text
l.push(text)
l.push(" ")
}
}
}

// isTableTokenSet reports whether the column name context carries a
// non-empty table name.
func isTableTokenSet(ctx *tsqlparser.Full_column_nameContext) bool {
	tableName := ctx.Full_table_name()
	if tableName == nil {
		return false
	}
	return tableName.GetText() != ""
}

/*
the following methods are parser listener events that are triggered while walking the parsed tree
their names must match the generated listener methods; a renamed method will no longer be called by the tree walker
available listeners can be found at https://github.com/nucleuscloud/go-antlrv4-parser/blob/main/tsql/tsqlparser_base_listener.go
*/

// EnterSearch_condition is the listener event fired when the walker enters a
// search condition (the where clause); it marks the listener as being inside it.
func (l *tsqlListener) EnterSearch_condition(ctx *tsqlparser.Search_conditionContext) {
	l.inSearchCondition = true
}

// ExitSearch_condition is called when production search_condition is exited.
func (l *tsqlListener) ExitSearch_condition(ctx *parser.Search_conditionContext) {
// parser listener event when we exit where clause
func (l *tsqlListener) ExitSearch_condition(ctx *tsqlparser.Search_conditionContext) {
l.inSearchCondition = false
}

// EnterSelect_statement is called when production select_statement is entered.
func (l *tsqlListener) EnterSelect_statement(ctx *parser.Select_statementContext) {
// parser listener event when we enter select statement
func (l *tsqlListener) EnterSelect_statement(ctx *tsqlparser.Select_statementContext) {
// important so we don't process select columns
l.inSearchCondition = false
}

// sets current table
func (l *tsqlListener) EnterTable_sources(ctx *parser.Table_sourcesContext) {
// sets current table found in from clause
func (l *tsqlListener) EnterTable_sources(ctx *tsqlparser.Table_sourcesContext) {
table := ctx.GetText()
l.currentTable = qualifyTableName(table)
}

// EnterTable_alias is called when production table_alias is entered.
func (l *tsqlListener) EnterTable_alias(ctx *parser.Table_aliasContext) {
// sets current table if alias found
func (l *tsqlListener) EnterTable_alias(ctx *tsqlparser.Table_aliasContext) {
l.currentTable = ctx.GetText()
}

// VisitTerminal rebuilds the sql string from the tree: it adds each terminal
// node's text to the sql stack with the appropriate spacing.
func (l *tsqlListener) VisitTerminal(node antlr.TerminalNode) {
	l.addNodeText(node)
}

// update table name and add qualifiers
func (l *tsqlListener) EnterFull_table_name(ctx *parser.Full_table_nameContext) {
func (l *tsqlListener) EnterFull_table_name(ctx *tsqlparser.Full_table_nameContext) {
if !l.inSearchCondition {
// ignore any table names not in where clause
return
}
newToken := l.SetToken(ctx.GetStart(), ctx.GetStop(), ensureQuoted(l.currentTable))
// creates new token with table name
newToken := l.setToken(ctx.GetStart(), ctx.GetStop(), ensureQuoted(l.currentTable))
ctx.RemoveLastChild()
ctx.AddTokenNode(newToken)
}

// updates column name
// add table name if missing
func (l *tsqlListener) EnterFull_column_name(ctx *parser.Full_column_nameContext) {
func (l *tsqlListener) EnterFull_column_name(ctx *tsqlparser.Full_column_nameContext) {
if !l.inSearchCondition {
// ignore any table names not in where clause
return
}

var text string
if ctx.Full_table_name() == nil || ctx.Full_table_name().GetText() == "" {
if !isTableTokenSet(ctx) {
text = fmt.Sprintf("%s.%s", ensureQuoted(l.currentTable), parseColumnName(ctx.GetText()))
} else {
text = parseColumnName(ctx.GetText())
}

newToken := l.SetToken(ctx.GetStart(), ctx.GetStop(), text)
newToken := l.setToken(ctx.GetStart(), ctx.GetStop(), text)
ctx.RemoveLastChild()
ctx.AddTokenNode(newToken)
}

// QualifyWhereCondition parses sql with the T-SQL lexer and parser, walks the
// resulting tree with a tsqlListener, and returns the SQL string rebuilt by
// that listener.
// It returns an error if the error listener or the tree listener recorded
// any errors.
// NOTE(review): a function with this same name also appears earlier in this
// listing; this copy looks like the pre-change side of the diff — confirm
// which one is live before editing.
func QualifyWhereCondition(sql string) (string, error) {
inputStream := antlr.NewInputStream(sql)

// create the lexer
lexer := parser.NewTSqlLexer(inputStream)
tokens := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)

// create the parser
p := parser.NewTSqlParser(tokens)
// collect parser errors in errorListener.Errors
errorListener := newTSqlErrorListener()
p.AddErrorListener(errorListener)

listener := newTSqlListener()
tree := p.Tsql_file()
// walk the tree; the listener rebuilds the sql string as it goes
antlr.ParseTreeWalkerDefault.Walk(listener, tree)

// parsing errors are reported before building errors
if len(errorListener.Errors) > 0 {
return "", fmt.Errorf("SQL parsing errors: %s", strings.Join(errorListener.Errors, "; "))
}
if len(listener.Errors) > 0 {
return "", fmt.Errorf("SQL building errors: %s", strings.Join(listener.Errors, "; "))
}

return listener.SqlString(), nil
}

func parseColumnName(colText string) string {
split := strings.Split(colText, ".")
if len(split) == 1 {
Expand Down

0 comments on commit e591c93

Please sign in to comment.