fix(parsedquery): resolve conflict when aliases and tables have the same name #21535

Closed
224 changes: 204 additions & 20 deletions superset/sql_parse.py
@@ -18,12 +18,13 @@
import re
from dataclasses import dataclass
from enum import Enum
from typing import Any, cast, Iterator, List, Optional, Set, Tuple
from typing import Any, cast, Iterator, List, NamedTuple, Optional, Set, Tuple
from urllib import parse

import sqlparse
from sqlalchemy import and_
from sqlparse.sql import (
Function,
Identifier,
IdentifierList,
Parenthesis,
@@ -59,7 +60,6 @@

logger = logging.getLogger(__name__)


# TODO: Workaround for https://github.com/andialbrecht/sqlparse/issues/652.
sqlparse.keywords.SQL_REGEX.insert(
0,
@@ -187,6 +187,14 @@ def __eq__(self, __o: object) -> bool:


class ParsedQuery:
class TableTuple(NamedTuple):
table: Table
level: int

class AliasTuple(NamedTuple):
name: str
level: int
Comment on lines +190 to +196
Member

Nit: while I'm having a hard time coming up with a better name for these, I'm not a big fan of Hungarian notation. An alternative could be TableWithLevel and AliasWithLevel to clarify the contents rather than the type.

Contributor

As a fellow hater of Hungarian notation, I'd like to stress I'm not the original author :D I'll change these to something nicer.


def __init__(self, sql_statement: str, strip_comments: bool = False):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
@@ -203,13 +211,32 @@ def __init__(self, sql_statement: str, strip_comments: bool = False):

@property
def tables(self) -> Set[Table]:
def deepest_alias_level(alias_name: str) -> int:
alias_entries = [a for a in alias_name_tuples if str(a.name) == alias_name]
if len(alias_entries) > 0:
return max(alias_entries, key=lambda a: a.level).level
return -1

if not self._tables:
for statement in self._parsed:
self._extract_from_token(statement)

self._tables = {
table for table in self._tables if str(table) not in self._alias_names
}
result = self._extract_from_token(statement)
if result:
(table_tuples, alias_name_tuples) = result
Member

Continuing my naming rant... 😄 Here we could probably just call them tables and aliases, and then later extract the table_names/table_levels from them, etc.
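
For reference, the naming suggested in this thread, together with the level check from `deepest_alias_level`, could look roughly like the sketch below. `Table` here is a simplified stand-in rather than the real class from `superset/sql_parse.py`, `real_tables` is a hypothetical helper name, and the level values in the usage lines are invented for illustration, not traced from the parser.

```python
from dataclasses import dataclass
from typing import Iterable, NamedTuple, Set


@dataclass(frozen=True)
class Table:
    """Simplified stand-in for superset.sql_parse.Table (assumption)."""

    table: str


class TableWithLevel(NamedTuple):
    table: Table
    level: int


class AliasWithLevel(NamedTuple):
    name: str
    level: int


def deepest_alias_level(aliases: Iterable[AliasWithLevel], name: str) -> int:
    """Return the deepest nesting level at which `name` is an alias, or -1."""
    return max((a.level for a in aliases if a.name == name), default=-1)


def real_tables(
    tables: Set[TableWithLevel], aliases: Set[AliasWithLevel]
) -> Set[Table]:
    """Keep tables whose name is not an alias, or that sit deeper than
    every alias of the same name (hypothetical helper)."""
    alias_names = {a.name for a in aliases}
    return {
        t.table
        for t in tables
        if t.table.table not in alias_names
        or deepest_alias_level(aliases, t.table.table) < t.level
    }


# Illustrative levels only: "t1" is referenced deeper than its alias,
# while "a" only ever appears as an alias of a subquery.
tables = {TableWithLevel(Table("t1"), 2), TableWithLevel(Table("a"), 0)}
aliases = {AliasWithLevel("t1", 1), AliasWithLevel("a", 1)}
print(real_tables(tables, aliases))
```

With these names the fields say what they hold (`table`, `name`, `level`) and the type suffix is no longer needed.
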

alias_names: Set[str] = {
str(alias.name) for alias in alias_name_tuples
}
tables_not_in_alias: Set[Table] = {
t.table
for t in table_tuples
if (
t.table.table not in alias_names
or deepest_alias_level(t.table.table) < t.level
)
}
self._tables = self._tables.union(tables_not_in_alias)
self._alias_names = self._alias_names.union(
{a.name for a in alias_name_tuples}
)
return self._tables

@property
@@ -326,7 +353,26 @@ def get_table(tlist: TokenList) -> Optional[Table]:
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))

def _process_tokenlist(self, token_list: TokenList) -> None:
@staticmethod
def _join_result(
result: Optional[Tuple[Set[TableTuple], Set[AliasTuple]]],
current_tables: Set[TableTuple],
current_alias_names: Set[AliasTuple],
) -> Tuple[Set[TableTuple], Set[AliasTuple]]:
if result:
(new_tables, new_alias_names) = result or (set(), set())
return (
current_tables.union(new_tables),
current_alias_names.union(new_alias_names),
)

def _process_tokenlist(
self,
token_list: TokenList,
level: int,
tables: Set[TableTuple],
alias_names: Set[AliasTuple],
) -> Optional[Tuple[Set[TableTuple], Set[AliasTuple]]]:
"""
Add table names to table set

@@ -336,17 +382,22 @@ def _process_tokenlist(self, token_list: TokenList) -> None:
if "(" not in str(token_list):
table = self.get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return
tables.add(self.TableTuple(table, level))
return tables, alias_names

# store aliases
if token_list.has_alias():
self._alias_names.add(token_list.get_alias())
alias = token_list.get_alias()
alias_names.add(self.AliasTuple(alias, level + 1))

# some aliases are not parsed properly
if token_list.tokens[0].ttype == Name:
self._alias_names.add(token_list.tokens[0].value)
self._extract_from_token(token_list)
alias_name = token_list.tokens[0].value

alias_names.add(self.AliasTuple(alias_name, level + 1))
return ParsedQuery._join_result(
self._extract_from_token(token_list), tables, alias_names
)

def as_create_table(
self,
@@ -374,7 +425,121 @@ def as_create_table(
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
return exec_sql

def _extract_from_token(self, token: Token) -> None:
@staticmethod
def _extract_with_identifiers(tokens: List[Token]) -> List[Token]:
non_empty_tokens = list(filter(lambda t: not t.is_whitespace, tokens))
with_identifiers = list()
i = 0
while i < len(non_empty_tokens):
# Foo(a, b) as (select a,b from bar)
if (
i + 2 <= len(non_empty_tokens)
and isinstance(non_empty_tokens[i], Function)
and non_empty_tokens[i + 1].value.lower() == "as"
and isinstance(non_empty_tokens[i + 2], Parenthesis)
):
with_identifiers.append(TokenList(non_empty_tokens[i : i + 3]))
i = i + 2
# q as (select * from foo)
elif isinstance(non_empty_tokens[i], Identifier):
with_identifiers.append(non_empty_tokens[i])
i = i + 1
return with_identifiers

def _extract_from_token_with_with_block( # pylint: disable=C0103
self,
token: Token,
level: int,
) -> Optional[Tuple[Set[TableTuple], Set[AliasTuple]]]:
select_token, with_identifiers = self._extract_with_and_select_tokens(token)

tables: Set[ParsedQuery.TableTuple] = set()
alias_names: Set[ParsedQuery.AliasTuple] = set()
for with_identifier in with_identifiers:
alias_names, tables = self._with_identifier_table_and_alias(
alias_names, level, tables, with_identifier
)

token_tables_and_alias = self._extract_from_token(select_token, level)

if token_tables_and_alias:
(select_tables, select_alias_names) = token_tables_and_alias
select_tables = self._tables_that_are_not_alias_names(
select_tables, alias_names
)
alias_names = {
a
for a in alias_names.union(select_alias_names)
if a.name not in {str(t.table) for t in tables}
}

tables = tables.union(
self._tables_that_are_not_alias_names(select_tables, alias_names)
)

return tables, alias_names

def _extract_with_and_select_tokens(
self, token: Token
) -> Tuple[Token, List[Token]]:
with_identifiers_token = token.token_next(0, skip_ws=True)[1]
with_token_position = token.token_index(with_identifiers_token) + 1
select_token = TokenList(token.tokens[with_token_position:])
with_identifiers = self._extract_with_identifiers(with_identifiers_token.tokens)
return select_token, with_identifiers

def _with_identifier_table_and_alias( # pylint: disable=invalid-name
self,
alias_names: Set[AliasTuple],
level: int,
tables: Set[TableTuple],
with_identifier: Token,
) -> Tuple[Set[AliasTuple], Set[TableTuple]]:
if with_identifier.is_group:
result = self._process_tokenlist(
TokenList(with_identifier.tokens), level, set(), set()
)
if result:
(with_tables, with_alias_names) = result
with_tables = self._tables_that_are_not_alias_names(
with_tables, alias_names
)
alias_names = self._alias_names_that_are_not_tables(
with_tables, alias_names
)
with_alias_names = self._alias_names_that_are_not_tables(
with_tables, with_alias_names
)
(tables, alias_names) = self._join_result(
(with_tables, with_alias_names), tables, alias_names
)
return alias_names, tables

@staticmethod
def _alias_names_that_are_not_tables( # pylint: disable=C0103
tables: Set[TableTuple], alias_names: Set[AliasTuple]
) -> Set[AliasTuple]:
def is_alias_not_a_table(alias: ParsedQuery.AliasTuple) -> bool:
return not any(
t
for t in tables
if alias.name == str(t.table) and alias.level > t.level
)

return {a for a in alias_names if is_alias_not_a_table(a)}

@staticmethod
def _tables_that_are_not_alias_names( # pylint: disable=C0103
tables: Set[TableTuple], alias_names: Set[AliasTuple]
) -> Set[TableTuple]:
def is_table_an_alias_name(table: ParsedQuery.TableTuple) -> bool:
return any(a for a in alias_names if str(table.table) == a.name)

return {t for t in tables if not is_table_an_alias_name(t)}

def _extract_from_token(
self, token: Token, level: int = 0
) -> Optional[Tuple[Set[TableTuple], Set[AliasTuple]]]:
"""
<Identifier> store a list of subtokens and <IdentifierList> store lists of
subtoken list.
@@ -387,15 +552,23 @@ def _extract_from_token(self, token: Token) -> None:
:param token: instance of Token or child class, e.g. TokenList, to be processed
"""
if not hasattr(token, "tokens"):
return
return None

tables: Set[ParsedQuery.TableTuple] = set()
alias_names: Set[ParsedQuery.AliasTuple] = set()

table_name_preceding_token = False

for item in token.tokens:
if token.tokens[0].value.lower() == "with":
return self._extract_from_token_with_with_block(token, level)

for item in token.tokens: # pylint: disable=too-many-nested-blocks
if item.is_group and (
not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
):
self._extract_from_token(item)
[tables, alias_names] = ParsedQuery._join_result(
self._extract_from_token(item, level + 1), tables, alias_names
)

if item.ttype in Keyword and (
item.normalized in PRECEDES_TABLE_NAME
@@ -409,14 +582,25 @@ def _extract_from_token(self, token: Token) -> None:
continue
if table_name_preceding_token:
if isinstance(item, Identifier):
self._process_tokenlist(item)
identifier_tables_and_alias = self._process_tokenlist(
item, level, tables, alias_names
)
if identifier_tables_and_alias:
(tables, alias_names) = identifier_tables_and_alias
elif isinstance(item, IdentifierList):
for token2 in item.get_identifiers():
if isinstance(token2, TokenList):
self._process_tokenlist(token2)
token_list_tables_and_alias = self._process_tokenlist(
token2, level, tables, alias_names
)
if token_list_tables_and_alias:
(tables, alias_names) = token_list_tables_and_alias
elif isinstance(item, IdentifierList):
if any(not self._is_identifier(token2) for token2 in item.tokens):
self._extract_from_token(item)
[tables, alias_names] = ParsedQuery._join_result(
self._extract_from_token(item, level + 1), tables, alias_names
)
return tables, alias_names

def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
"""Returns the query with the specified limit.
43 changes: 43 additions & 0 deletions tests/unit_tests/sql_parse_tests.py
@@ -153,6 +153,39 @@ def test_extract_tables_subselect() -> None:
)


def test_extract_tables_subselect_with_table_names_as_aliases() -> None:
assert (
extract_tables(
"""SELECT t1.*
FROM (
SELECT *
FROM t1
WHERE day_of_week = 'Friday'
) t1
WHERE t1.resolution = 'NONE'
"""
)
== {Table("t1")}
)

assert (
extract_tables(
"""SELECT t1.*
FROM (
SELECT *
FROM t2
WHERE day_of_week = 'Friday'
) t1
UNION (
SELECT *
FROM t1) t2
WHERE t1.resolution = 'NONE'
"""
)
== {Table("t2"), Table("t1")}
)


def test_extract_tables_select_in_expression() -> None:
"""
Test that parser works with ``SELECT``s used as expressions.
@@ -525,6 +558,16 @@ def test_extract_tables_reusing_aliases() -> None:
with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from q1) a
"""
)
== {Table("q2"), Table("src")}
)
Comment on lines +561 to +564
Member

@villebro villebro May 9, 2023

Are we really sure that the q2 reference in the q1 CTE will in fact reference the table q2 in the database? I ran the following test on my devenv, which has the examples data and the Superset metadata in the same database, and

with q1 as ( select distinct gender from birth_names),
birth_names as ( select * from ab_role)
select * from (select gender from q1) a

produces the following error:

sqlite error: no such column: gender

This would indicate to me that it was looking for gender in ab_role, not the actual birth_names table. So it appears to me that only src is in fact referenced here, right?

Contributor

I just tried something similar but it's definitely referencing the database table, not the alias which came after it:

postgres=# CREATE TABLE hodor(age INT, name VARCHAR);
CREATE TABLE
postgres=# INSERT INTO hodor(age, name) VALUES (35, 'Hodor');
INSERT 0 1
postgres=# CREATE TABLE potatoes(count INT);
CREATE TABLE
postgres=# INSERT INTO potatoes(count) VALUES (100);
INSERT 0 1
postgres=# WITH
  result AS (SELECT * FROM hodor),
  hodor AS (SELECT * FROM potatoes)
SELECT * FROM result;
 age | name  
-----+-------
  35 | Hodor
(1 row)

Contributor

Not sure what the point of the extra CTE in the final SELECT was here, but it didn't change the result for me either:

postgres=# WITH
  result AS (SELECT * FROM hodor),
  hodor AS (SELECT * FROM potatoes)
SELECT * FROM (SELECT * FROM result) final;
 age | name  
-----+-------
  35 | Hodor
(1 row)

postgres=# WITH
  result AS (SELECT * FROM hodor),
  hodor AS (SELECT * FROM potatoes)
SELECT * FROM (SELECT age FROM result) final;
 age 
-----
  35
(1 row)

Contributor

@giftig giftig May 9, 2023

Having said that, we're using different databases. It wouldn't surprise me at all if postgres and sqlite handled CTEs slightly differently. If that's the case this endeavour is slightly doomed though; we'd need a lot more logic here to handle various different databases.

Arguably it's impossible to maintain this without somehow asking the database for the table list instead of trying to work it out ourselves. Unfortunately for the db my team really cares about, trino, that's not readily possible either.

Contributor

@giftig giftig May 9, 2023

I just tried it with sqlite and indeed it behaves differently:

sqlite> CREATE TABLE hodor(age INT, name VARCHAR);
sqlite> INSERT INTO hodor(age, name) VALUES (35, 'Hodor');
sqlite> CREATE TABLE potatoes(count INT);
sqlite> INSERT INTO potatoes(count) VALUES (100);
sqlite> WITH
   ...>   result AS (SELECT * FROM hodor),
   ...>   hodor AS (SELECT * FROM potatoes)
   ...> SELECT * FROM result;
100
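
The SQLite session above can be re-run without a shell through Python's standard `sqlite3` module. This is only a reproduction sketch: `run_cte_probe` is a made-up helper name, and the rows that come back depend on the SQLite build linked into your Python, so no expected output is given here.

```python
import sqlite3


def run_cte_probe() -> list:
    """Recreate the hodor/potatoes tables in memory and run the CTE query."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(
            """
            CREATE TABLE hodor(age INT, name VARCHAR);
            INSERT INTO hodor(age, name) VALUES (35, 'Hodor');
            CREATE TABLE potatoes(count INT);
            INSERT INTO potatoes(count) VALUES (100);
            """
        )
        # `hodor` is both a physical table and a CTE defined after `result`.
        return conn.execute(
            """
            WITH
              result AS (SELECT * FROM hodor),
              hodor AS (SELECT * FROM potatoes)
            SELECT * FROM result
            """
        ).fetchall()
    finally:
        conn.close()


rows = run_cte_probe()
print(rows)
```

A single row holding `100` means `result` read from the later `hodor` CTE, as in the session above; a row holding `35` and `'Hodor'` means it read from the physical table, as Postgres did.
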

Member

Ok good to know, this is likely one of those cases where different dbs behave differently. To support both use cases we could probably add a flag on BaseEngineSpec that specifies which variant the db in question supports. Something like cte_forward_alias_reference that we'd set to False for SQLite and True for Postgres (we should probably do some additional testing to see which is the more common flavor before setting that on BaseEngineSpec).
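
To make the two variants concrete, here is a minimal sketch of how a per-engine boolean could change which names count as physical tables. Nothing in it is Superset code: `physical_tables` and its `later_ctes_visible` parameter are hypothetical, the CTE bodies are given as pre-extracted sets of referenced names instead of parsed SQL, and the sketch assumes a CTE never sees itself (no `WITH RECURSIVE`).

```python
from typing import List, Set, Tuple


def physical_tables(
    ctes: List[Tuple[str, Set[str]]],
    final_refs: Set[str],
    later_ctes_visible: bool,
) -> Set[str]:
    """Return the referenced names that resolve to database tables.

    ctes: (cte_name, names referenced in its body), in definition order.
    final_refs: names referenced by the final SELECT.
    later_ctes_visible: True if a CTE body may resolve a name against a
        CTE defined after it (the behavior in the SQLite session);
        False if only earlier CTEs are visible (the Postgres session).
    """
    names = [name for name, _ in ctes]
    tables: Set[str] = set()
    for position, (name, refs) in enumerate(ctes):
        visible = set(names) if later_ctes_visible else set(names[:position])
        # Assumption: a non-recursive CTE does not see itself.
        visible.discard(name)
        tables |= refs - visible
    # The final SELECT sees every CTE in the WITH clause.
    tables |= final_refs - set(names)
    return tables


# The q1/q2/src query from the test under discussion.
ctes = [("q1", {"q2"}), ("q2", {"src"})]
print(physical_tables(ctes, {"q1"}, later_ctes_visible=False))
print(physical_tables(ctes, {"q1"}, later_ctes_visible=True))
```

With `later_ctes_visible=False` the sketch returns `q2` and `src`, matching the assertion in this test; with `True` it returns only `src`, which is the reading raised at the top of this thread.
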

assert (
extract_tables(
"""
with q1 as ( select key from src where key = '5'),
src as ( select key from src where key = '5')
select * from (select key from q1) a
"""
)
== {Table("src")}