-
Notifications
You must be signed in to change notification settings - Fork 14.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(parsedquery): resolve conflict when aliases and tables have the same name #21535
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,12 +18,13 @@ | |
import re | ||
from dataclasses import dataclass | ||
from enum import Enum | ||
from typing import Any, cast, Iterator, List, Optional, Set, Tuple | ||
from typing import Any, cast, Iterator, List, NamedTuple, Optional, Set, Tuple | ||
from urllib import parse | ||
|
||
import sqlparse | ||
from sqlalchemy import and_ | ||
from sqlparse.sql import ( | ||
Function, | ||
Identifier, | ||
IdentifierList, | ||
Parenthesis, | ||
|
@@ -59,7 +60,6 @@ | |
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
# TODO: Workaround for https://github.com/andialbrecht/sqlparse/issues/652. | ||
sqlparse.keywords.SQL_REGEX.insert( | ||
0, | ||
|
@@ -187,6 +187,14 @@ def __eq__(self, __o: object) -> bool: | |
|
||
|
||
class ParsedQuery: | ||
class TableTuple(NamedTuple):
    """A table reference paired with the nesting level it was recorded at."""

    # The table found while walking the parsed statement.
    table: Table
    # Level value passed down by the extraction helpers; they pass
    # ``level + 1`` when recursing into a nested token group.
    level: int
|
||
class AliasTuple(NamedTuple):
    """An alias name paired with the nesting level it was recorded at."""

    # The alias text as reported by the tokenizer.
    name: str
    # Level value supplied by the extraction helper that recorded the alias.
    level: int
|
||
def __init__(self, sql_statement: str, strip_comments: bool = False): | ||
if strip_comments: | ||
sql_statement = sqlparse.format(sql_statement, strip_comments=True) | ||
|
@@ -203,13 +211,32 @@ def __init__(self, sql_statement: str, strip_comments: bool = False): | |
|
||
@property | ||
def tables(self) -> Set[Table]: | ||
def deepest_alias_level(alias_name: str) -> int: | ||
alias_entries = [a for a in alias_name_tuples if str(a.name) == alias_name] | ||
if len(alias_entries) > 0: | ||
return max(alias_entries, key=lambda a: a.level).level | ||
return -1 | ||
|
||
if not self._tables: | ||
for statement in self._parsed: | ||
self._extract_from_token(statement) | ||
|
||
self._tables = { | ||
table for table in self._tables if str(table) not in self._alias_names | ||
} | ||
result = self._extract_from_token(statement) | ||
if result: | ||
(table_tuples, alias_name_tuples) = result | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Continuing my naming rant... 😄 Here we could probably just call them |
||
alias_names: Set[str] = { | ||
str(alias.name) for alias in alias_name_tuples | ||
} | ||
tables_not_in_alias: Set[Table] = { | ||
t.table | ||
for t in table_tuples | ||
if ( | ||
t.table.table not in alias_names | ||
or deepest_alias_level(t.table.table) < t.level | ||
) | ||
} | ||
self._tables = self._tables.union(tables_not_in_alias) | ||
self._alias_names = self._alias_names.union( | ||
{a.name for a in alias_name_tuples} | ||
) | ||
return self._tables | ||
|
||
@property | ||
|
@@ -326,7 +353,26 @@ def get_table(tlist: TokenList) -> Optional[Table]: | |
def _is_identifier(token: Token) -> bool: | ||
return isinstance(token, (IdentifierList, Identifier)) | ||
|
||
def _process_tokenlist(self, token_list: TokenList) -> None: | ||
@staticmethod | ||
def _join_result( | ||
result: Optional[Tuple[Set[TableTuple], Set[AliasTuple]]], | ||
current_tables: Set[TableTuple], | ||
current_alias_names: Set[AliasTuple], | ||
) -> Tuple[Set[TableTuple], Set[AliasTuple]]: | ||
if result: | ||
(new_tables, new_alias_names) = result or (set(), set()) | ||
return ( | ||
current_tables.union(new_tables), | ||
current_alias_names.union(new_alias_names), | ||
) | ||
|
||
def _process_tokenlist( | ||
self, | ||
token_list: TokenList, | ||
level: int, | ||
tables: Set[TableTuple], | ||
alias_names: Set[AliasTuple], | ||
) -> Optional[Tuple[Set[TableTuple], Set[AliasTuple]]]: | ||
""" | ||
Add table names to table set | ||
|
||
|
@@ -336,17 +382,22 @@ def _process_tokenlist(self, token_list: TokenList) -> None: | |
if "(" not in str(token_list): | ||
table = self.get_table(token_list) | ||
if table and not table.table.startswith(CTE_PREFIX): | ||
self._tables.add(table) | ||
return | ||
tables.add(self.TableTuple(table, level)) | ||
return tables, alias_names | ||
|
||
# store aliases | ||
if token_list.has_alias(): | ||
self._alias_names.add(token_list.get_alias()) | ||
alias = token_list.get_alias() | ||
alias_names.add(self.AliasTuple(alias, level + 1)) | ||
|
||
# some aliases are not parsed properly | ||
if token_list.tokens[0].ttype == Name: | ||
self._alias_names.add(token_list.tokens[0].value) | ||
self._extract_from_token(token_list) | ||
alias_name = token_list.tokens[0].value | ||
|
||
alias_names.add(self.AliasTuple(alias_name, level + 1)) | ||
return ParsedQuery._join_result( | ||
self._extract_from_token(token_list), tables, alias_names | ||
) | ||
|
||
def as_create_table( | ||
self, | ||
|
@@ -374,7 +425,121 @@ def as_create_table( | |
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}" | ||
return exec_sql | ||
|
||
def _extract_from_token(self, token: Token) -> None: | ||
@staticmethod
def _extract_with_identifiers(tokens: List[Token]) -> List[Token]:
    """
    Pick the with-identifiers out of the tokens of a ``WITH`` clause.

    Two shapes are recognised after whitespace tokens are dropped:

    - ``Foo(a, b) as (select a, b from bar)``: a ``Function`` followed by
      ``as`` and a ``Parenthesis``; the three tokens are wrapped in a new
      ``TokenList``
    - ``q as (select * from foo)``: a single ``Identifier``, kept as is

    Any other token is skipped.

    :param tokens: child tokens of the with-identifiers token
    :returns: the recognised with-identifiers, in source order
    """
    non_empty_tokens = [t for t in tokens if not t.is_whitespace]
    with_identifiers: List[Token] = []
    i = 0
    while i < len(non_empty_tokens):
        current = non_empty_tokens[i]
        # Foo(a, b) as (select a,b from bar)
        # The bound must be strict: index ``i + 2`` is read below, so
        # ``i + 2 == len(...)`` would raise IndexError.
        if (
            i + 2 < len(non_empty_tokens)
            and isinstance(current, Function)
            and non_empty_tokens[i + 1].value.lower() == "as"
            and isinstance(non_empty_tokens[i + 2], Parenthesis)
        ):
            with_identifiers.append(TokenList(non_empty_tokens[i : i + 3]))
            # all three tokens of the pattern are consumed
            i += 3
            continue
        # q as (select * from foo)
        if isinstance(current, Identifier):
            with_identifiers.append(current)
        i += 1
    return with_identifiers
|
||
def _extract_from_token_with_with_block(  # pylint: disable=C0103
    self,
    token: Token,
    level: int,
) -> Optional[Tuple[Set[TableTuple], Set[AliasTuple]]]:
    """
    Collect tables and aliases from a token group that starts with ``WITH``.

    The with-identifiers are processed first, one by one. The remaining
    tokens (the select token returned by ``_extract_with_and_select_tokens``)
    are then scanned, and tables found there whose name matches a collected
    alias are dropped.

    :param token: token group whose first token is the ``WITH`` keyword
    :param level: level value recorded on the collected tables and aliases
    :returns: ``(tables, alias_names)`` collected from the whole group
    """
    select_token, with_identifiers = self._extract_with_and_select_tokens(token)

    tables: Set[ParsedQuery.TableTuple] = set()
    alias_names: Set[ParsedQuery.AliasTuple] = set()
    for with_identifier in with_identifiers:
        alias_names, tables = self._with_identifier_table_and_alias(
            alias_names, level, tables, with_identifier
        )

    token_tables_and_alias = self._extract_from_token(select_token, level)
    if not token_tables_and_alias:
        # Nothing was extracted after the WITH clause; keep what we have.
        return tables, alias_names

    select_tables, select_alias_names = token_tables_and_alias
    select_tables = self._tables_that_are_not_alias_names(
        select_tables, alias_names
    )

    # Built once instead of once per alias in the comprehension below.
    with_table_names = {str(t.table) for t in tables}
    alias_names = {
        a
        for a in alias_names.union(select_alias_names)
        if a.name not in with_table_names
    }

    tables = tables.union(
        self._tables_that_are_not_alias_names(select_tables, alias_names)
    )
    return tables, alias_names
|
||
def _extract_with_and_select_tokens(
    self, token: Token
) -> Tuple[Token, List[Token]]:
    """
    Split a ``WITH`` token group into its with-identifiers and the rest.

    :param token: token group starting with the ``WITH`` keyword
    :returns: a ``TokenList`` wrapping every token after the with-identifiers
        token, and the list of with-identifiers extracted from that token
    """
    # token_next returns an (index, token) pair; only the token is needed.
    _, cte_token = token.token_next(0, skip_ws=True)
    rest_start = token.token_index(cte_token) + 1
    remainder = TokenList(token.tokens[rest_start:])
    return remainder, self._extract_with_identifiers(cte_token.tokens)
|
||
def _with_identifier_table_and_alias(  # pylint: disable=invalid-name
    self,
    alias_names: Set[AliasTuple],
    level: int,
    tables: Set[TableTuple],
    with_identifier: Token,
) -> Tuple[Set[AliasTuple], Set[TableTuple]]:
    """
    Fold one with-identifier into the running alias and table sets.

    Note that the return order is ``(alias_names, tables)``, the reverse of
    the ``(tables, alias_names)`` order used by the other extraction helpers.

    :param alias_names: alias tuples collected so far
    :param level: level value passed on to ``_process_tokenlist``
    :param tables: table tuples collected so far
    :param with_identifier: a single with-identifier token
    :returns: the updated ``(alias_names, tables)``; the inputs are returned
        unchanged when the token is not a group or yields no result
    """
    if not with_identifier.is_group:
        return alias_names, tables

    processed = self._process_tokenlist(
        TokenList(with_identifier.tokens), level, set(), set()
    )
    if not processed:
        return alias_names, tables

    found_tables, found_aliases = processed
    # Tables named like an already-known alias are references to that alias.
    found_tables = self._tables_that_are_not_alias_names(found_tables, alias_names)
    alias_names = self._alias_names_that_are_not_tables(found_tables, alias_names)
    found_aliases = self._alias_names_that_are_not_tables(
        found_tables, found_aliases
    )
    merged_tables, merged_aliases = self._join_result(
        (found_tables, found_aliases), tables, alias_names
    )
    return merged_aliases, merged_tables
|
||
@staticmethod | ||
def _alias_names_that_are_not_tables( # pylint: disable=C0103 | ||
tables: Set[TableTuple], alias_names: Set[AliasTuple] | ||
) -> Set[AliasTuple]: | ||
def is_alias_not_a_table(alias: ParsedQuery.AliasTuple) -> bool: | ||
return not any( | ||
t | ||
for t in tables | ||
if alias.name == str(t.table) and alias.level > t.level | ||
) | ||
|
||
return {a for a in alias_names if is_alias_not_a_table(a)} | ||
|
||
@staticmethod | ||
def _tables_that_are_not_alias_names( # pylint: disable=C0103 | ||
tables: Set[TableTuple], alias_names: Set[AliasTuple] | ||
) -> Set[TableTuple]: | ||
def is_table_an_alias_name(table: ParsedQuery.TableTuple) -> bool: | ||
return any(a for a in alias_names if str(table.table) == a.name) | ||
|
||
return {t for t in tables if not is_table_an_alias_name(t)} | ||
|
||
def _extract_from_token( | ||
self, token: Token, level: int = 0 | ||
) -> Optional[Tuple[Set[TableTuple], Set[AliasTuple]]]: | ||
""" | ||
<Identifier> store a list of subtokens and <IdentifierList> store lists of | ||
subtoken list. | ||
|
@@ -387,15 +552,23 @@ def _extract_from_token(self, token: Token) -> None: | |
:param token: instance of Token or child class, e.g. TokenList, to be processed | ||
""" | ||
if not hasattr(token, "tokens"): | ||
return | ||
return None | ||
|
||
tables: Set[ParsedQuery.TableTuple] = set() | ||
alias_names: Set[ParsedQuery.AliasTuple] = set() | ||
|
||
table_name_preceding_token = False | ||
|
||
for item in token.tokens: | ||
if token.tokens[0].value.lower() == "with": | ||
return self._extract_from_token_with_with_block(token, level) | ||
|
||
for item in token.tokens: # pylint: disable=too-many-nested-blocks | ||
if item.is_group and ( | ||
not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis) | ||
): | ||
self._extract_from_token(item) | ||
[tables, alias_names] = ParsedQuery._join_result( | ||
self._extract_from_token(item, level + 1), tables, alias_names | ||
) | ||
|
||
if item.ttype in Keyword and ( | ||
item.normalized in PRECEDES_TABLE_NAME | ||
|
@@ -409,14 +582,25 @@ def _extract_from_token(self, token: Token) -> None: | |
continue | ||
if table_name_preceding_token: | ||
if isinstance(item, Identifier): | ||
self._process_tokenlist(item) | ||
identifier_tables_and_alias = self._process_tokenlist( | ||
item, level, tables, alias_names | ||
) | ||
if identifier_tables_and_alias: | ||
(tables, alias_names) = identifier_tables_and_alias | ||
elif isinstance(item, IdentifierList): | ||
for token2 in item.get_identifiers(): | ||
if isinstance(token2, TokenList): | ||
self._process_tokenlist(token2) | ||
token_list_tables_and_alias = self._process_tokenlist( | ||
token2, level, tables, alias_names | ||
) | ||
if token_list_tables_and_alias: | ||
(tables, alias_names) = token_list_tables_and_alias | ||
elif isinstance(item, IdentifierList): | ||
if any(not self._is_identifier(token2) for token2 in item.tokens): | ||
self._extract_from_token(item) | ||
[tables, alias_names] = ParsedQuery._join_result( | ||
self._extract_from_token(item, level + 1), tables, alias_names | ||
) | ||
return tables, alias_names | ||
|
||
def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str: | ||
"""Returns the query with the specified limit. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -153,6 +153,39 @@ def test_extract_tables_subselect() -> None: | |
) | ||
|
||
|
||
def test_extract_tables_subselect_with_table_names_as_aliases() -> None: | ||
assert ( | ||
extract_tables( | ||
"""SELECT t1.* | ||
FROM ( | ||
SELECT * | ||
FROM t1 | ||
WHERE day_of_week = 'Friday' | ||
) t1 | ||
WHERE t1.resolution = 'NONE' | ||
""" | ||
) | ||
== {Table("t1")} | ||
) | ||
|
||
assert ( | ||
extract_tables( | ||
"""SELECT t1.* | ||
FROM ( | ||
SELECT * | ||
FROM t2 | ||
WHERE day_of_week = 'Friday' | ||
) t1 | ||
UNION ( | ||
SELECT * | ||
FROM t1) t2 | ||
WHERE t1.resolution = 'NONE' | ||
""" | ||
) | ||
== {Table("t2"), Table("t1")} | ||
) | ||
|
||
|
||
def test_extract_tables_select_in_expression() -> None: | ||
""" | ||
Test that parser works with ``SELECT``s used as expressions. | ||
|
@@ -525,6 +558,16 @@ def test_extract_tables_reusing_aliases() -> None: | |
with q1 as ( select key from q2 where key = '5'), | ||
q2 as ( select key from src where key = '5') | ||
select * from (select key from q1) a | ||
""" | ||
) | ||
== {Table("q2"), Table("src")} | ||
) | ||
Comment on lines
+561
to
+564
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are we really sure that the with q1 as ( select distinct gender from birth_names),
birth_names as ( select * from ab_role)
select * from (select gender from q1) a produces the following error:
This would indicate to me that it was looking for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just tried something similar but it's definitely referencing the database table, not the alias which came after it: postgres=# CREATE TABLE hodor(age INT, name VARCHAR);
CREATE TABLE
postgres=# INSERT INTO hodor(age, name) VALUES (35, 'Hodor');
INSERT 0 1
postgres=# CREATE TABLE potatoes(count INT);
CREATE TABLE
postgres=# INSERT INTO potatoes(count) VALUES (100);
INSERT 0 1
postgres=# WITH
result AS (SELECT * FROM hodor),
hodor AS (SELECT * FROM potatoes)
SELECT * FROM result;
age | name
-----+-------
35 | Hodor
(1 row) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure what the point of the extra CTE in the final SELECT was here, but it didn't change the result for me either: postgres=# WITH
result AS (SELECT * FROM hodor),
hodor AS (SELECT * FROM potatoes)
SELECT * FROM (SELECT * FROM result) final;
age | name
-----+-------
35 | Hodor
(1 row)
postgres=# WITH
result AS (SELECT * FROM hodor),
hodor AS (SELECT * FROM potatoes)
SELECT * FROM (SELECT age FROM result) final;
age
-----
35
(1 row) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having said that, we're using different databases. It wouldn't surprise me at all if postgres and sqlite handled CTEs slightly differently. If that's the case this endeavour is slightly doomed though; we'd need a lot more logic here to handle various different databases. Arguably it's impossible to maintain this without somehow asking the database for the table list instead of trying to work it out ourselves. Unfortunately for the db my team really cares about, trino, that's not readily possible either. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just tried it with sqlite and indeed it behaves differently: sqlite> CREATE TABLE hodor(age INT, name VARCHAR);
sqlite> INSERT INTO hodor(age, name) VALUES (35, 'Hodor');
sqlite> CREATE TABLE potatoes(count INT);
sqlite> INSERT INTO potatoes(count) VALUES (100);
sqlite> WITH
result AS (SELECT * FROM hodor),
hodor AS (SELECT * FROM potatoes)
SELECT * FROM result; ...> ...> ...>
100 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok good to know, this is likely one of those cases where different dbs behave differently. To support both use cases we could probably add a flag on |
||
assert ( | ||
extract_tables( | ||
""" | ||
with q1 as ( select key from src where key = '5'), | ||
src as ( select key from src where key = '5') | ||
select * from (select key from q1) a | ||
""" | ||
) | ||
== {Table("src")} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: while I'm having a hard time coming up with a better name for these, I'm not a big fan of Hungarian notation. An alternative could be `TableWithLevel` and `AliasWithLevel`, to clarify the contents rather than the type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a fellow hater of Hungarian notation, I'd like to stress I'm not the original author :D I'll change these to something nicer.