For SqlColumnPrunerOptimizer - add support for CTEs in sub-queries #1613

Merged
merged 4 commits into from
Jan 22, 2025
6 changes: 6 additions & 0 deletions metricflow/sql/optimizer/column_pruner.py
@@ -5,6 +5,7 @@
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from typing_extensions import override

from metricflow.sql.optimizer.cte_mapping_lookup_builder import SqlCteAliasMappingLookupBuilderVisitor
from metricflow.sql.optimizer.required_column_aliases import SqlMapRequiredColumnAliasesVisitor
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
@@ -119,8 +120,13 @@ def optimize(self, node: SqlPlanNode) -> SqlPlanNode: # noqa: D102
)
return node

cte_alias_mapping_builder = SqlCteAliasMappingLookupBuilderVisitor()
node.accept(cte_alias_mapping_builder)
cte_alias_mapping_lookup = cte_alias_mapping_builder.cte_alias_mapping_lookup

map_required_column_aliases_visitor = SqlMapRequiredColumnAliasesVisitor(
start_node=node,
cte_alias_mapping_lookup=cte_alias_mapping_lookup,
required_column_aliases_in_start_node=frozenset(
[select_column.column_alias for select_column in required_select_columns]
),
82 changes: 82 additions & 0 deletions metricflow/sql/optimizer/cte_alias_to_cte_node_mapping.py
@@ -0,0 +1,82 @@
from __future__ import annotations

import logging
from typing import Dict

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat

from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlSelectStatementNode

logger = logging.getLogger(__name__)


class SqlCteAliasMappingLookup:
    """A mutable lookup that stores the CTE-alias mapping at a given node.

    In cases with nested CTEs in a SELECT, it's possible that a CTE defined in an inner SELECT has an alias that is the
Contributor

Why would they be defined with the same alias? Shouldn't we use something similar to _next_unique_table_alias() to build CTE aliases? Or is the problem that both the DataflowToSqlQueryPlanConverter and the optimizer won't have access to that function to ensure unique aliases?

Contributor Author

If something like _next_unique_table_alias is used to generate the CTE aliases, then there would be no issue. However, you would have to remember to do that in DataflowToSqlQueryPlanConverter, so handling this case in the optimizer is more defensive.

Contributor

Gotcha. I definitely was planning to use something similar to _next_unique_table_alias() in DataflowToSqlQueryPlanConverter for CTE aliases. If you have a strong opinion about handling this defensively, that's fine and we can leave it here. Just might be nice to simplify the logic if it's not necessary!

Contributor Author

My aim is to maintain the contract: if the optimizer is fed valid SQL, then it should output valid SQL. Otherwise, it becomes harder to use and turns into a trap.

Contributor

Ohhh ok, I think I was misunderstanding. I was imagining the CTEs would always be bubbled up to the top level, even if they were defined when visiting a subquery. I didn't realize that you could have a CTE that lives inside a subquery. In that case, your logic makes sense.

Contributor Author

Yeah, I was thinking about your use case and thought a CTE that lives inside a subquery would be a better fit, since it is only needed locally. There's flexibility though, so let me merge this so that you can try it out.

Contributor

Yeah, either should be fine from an execution standpoint, but keeping the CTEs at the top level might be most readable!

    same as a CTE defined in an outer SELECT. e.g.

        # outer_cte
        WITH cte_0 AS (
            SELECT 1 AS col_0
        )

        # outer_select
        SELECT col_0
        FROM (

            # inner_cte
            WITH cte_0 AS (
                SELECT 2 AS col_0
            )
            # inner_select
            SELECT col_0 FROM cte_0
        )
        ...

    In this case, `outer_cte` and `inner_cte` both have the same alias `cte_0`. When `cte_0` is referenced from
    `inner_select`, it is referring to the `inner_cte`. For column pruning, it is necessary to figure out which CTE
    a given alias is referencing, so this class helps to keep track of that mapping.
    """

    def __init__(self) -> None:  # noqa: D107
        self._select_node_to_cte_alias_mapping: Dict[SqlSelectStatementNode, SqlCteAliasMapping] = {}

    def cte_alias_mapping_exists(self, select_node: SqlSelectStatementNode) -> bool:
        """Returns true if the CTE-alias mapping for the given node has been recorded."""
        return select_node in self._select_node_to_cte_alias_mapping

    def add_cte_alias_mapping(
        self,
        select_node: SqlSelectStatementNode,
        cte_alias_mapping: SqlCteAliasMapping,
    ) -> None:
        """Associate the given CTE-alias mapping with the given node.

        Raises an exception if a mapping already exists.
        """
        if select_node in self._select_node_to_cte_alias_mapping:
Contributor

Is there any scenario where the select_node won't be unique in the dataflow plan at this point, and a valid plan will error?

Contributor Author

There's a call to cte_alias_mapping_exists before it's added, so it shouldn't throw an error. But a non-unique select_node also means that the generated SQL plan is no longer a tree?

Contributor Author

Actually, it looks like some caching in DataflowNodeToSqlCteVisitor might cause select_node to be non-unique. It doesn't cause issues at the moment, but let me take a look at how to avoid that.

            raise RuntimeError(
                str(
                    LazyFormat(
                        "`select_node` node has already been added,",
                        # child_select_node=child_select_node,
                        select_node=select_node,
                        current_mapping=self._select_node_to_cte_alias_mapping,
                    )
                )
            )

        self._select_node_to_cte_alias_mapping[select_node] = cte_alias_mapping

    def get_cte_alias_mapping(self, select_node: SqlSelectStatementNode) -> SqlCteAliasMapping:
        """Return the CTE-alias mapping for the given node.

        Raises an exception if a mapping was not previously added for the given node.
        """
        cte_alias_mapping = self._select_node_to_cte_alias_mapping.get(select_node)
        if cte_alias_mapping is None:
            raise RuntimeError(
                str(LazyFormat("CTE alias mapping does not exist for the given `select_node`", select_node=select_node))
            )
        return cte_alias_mapping
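
The shadowing rule described in the docstring above (a CTE alias defined in an inner SELECT overrides the same alias from an outer SELECT) can be illustrated with a small standalone sketch. This is not the MetricFlow implementation: `ToyAliasMapping` and its `merge` / `get` methods are hypothetical stand-ins for `SqlCteAliasMapping`, and the CTE definitions are represented by plain string labels.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass(frozen=True)
class ToyAliasMapping:
    """Hypothetical stand-in for `SqlCteAliasMapping`: CTE alias -> label of the CTE definition."""

    alias_to_cte: Dict[str, str] = field(default_factory=dict)

    def merge(self, inner: ToyAliasMapping) -> ToyAliasMapping:
        # Entries from `inner` win on alias collisions, so an inner CTE shadows an outer one.
        return ToyAliasMapping({**self.alias_to_cte, **inner.alias_to_cte})

    def get(self, alias: str) -> Optional[str]:
        return self.alias_to_cte.get(alias)


# Scope visible from the outer SELECT.
outer_scope = ToyAliasMapping({"cte_0": "outer_cte", "cte_1": "other_outer_cte"})
# Scope visible from the inner SELECT, which redefines `cte_0`.
inner_scope = outer_scope.merge(ToyAliasMapping({"cte_0": "inner_cte"}))

print(inner_scope.get("cte_0"))  # inner_cte
print(inner_scope.get("cte_1"))  # other_outer_cte
print(outer_scope.get("cte_0"))  # outer_cte
```

Because `merge` returns a new mapping, the outer scope is left untouched, which is why each SELECT node can be associated with its own mapping in a lookup.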
94 changes: 94 additions & 0 deletions metricflow/sql/optimizer/cte_mapping_lookup_builder.py
@@ -0,0 +1,94 @@
from __future__ import annotations

import logging
from contextlib import contextmanager
from typing import Iterator

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from typing_extensions import override

from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup
from metricflow.sql.sql_plan import (
    SqlCreateTableAsNode,
    SqlCteAliasMapping,
    SqlCteNode,
    SqlPlanNode,
    SqlPlanNodeVisitor,
    SqlSelectQueryFromClauseNode,
    SqlSelectStatementNode,
    SqlTableNode,
)

logger = logging.getLogger(__name__)


class SqlCteAliasMappingLookupBuilderVisitor(SqlPlanNodeVisitor[None]):
    """Traverses the SQL plan and builds the associated `SqlCteAliasMappingLookup`.

    Please see `SqlCteAliasMappingLookup` for more details.
    """

    def __init__(self) -> None:  # noqa: D107
        self._current_cte_alias_mapping = SqlCteAliasMapping()
        self._cte_alias_mapping_lookup = SqlCteAliasMappingLookup()

    @contextmanager
    def _save_current_cte_alias_mapping(self) -> Iterator[None]:
        previous_cte_alias_mapping = self._current_cte_alias_mapping
        yield
        self._current_cte_alias_mapping = previous_cte_alias_mapping

    def _default_handler(self, node: SqlPlanNode) -> None:
        """Default recursive handler to visit the parents of the given node."""
        for parent_node in node.parent_nodes:
            with self._save_current_cte_alias_mapping():
                parent_node.accept(self)
        return

    @override
    def visit_cte_node(self, node: SqlCteNode) -> None:
        return self._default_handler(node)

    @override
    def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
        """Based on required column aliases for this SELECT, figure out required column aliases in parents."""
        logger.debug(
            LazyFormat(
                "Starting visit of SELECT statement node with CTE alias mapping",
                node=node,
                current_cte_alias_mapping=self._current_cte_alias_mapping,
            )
        )

        if self._cte_alias_mapping_lookup.cte_alias_mapping_exists(node):
            return self._default_handler(node)

        # Record that can see the CTEs defined in this SELECT node and outer SELECT statements.
        # CTEs defined in this select node should override ones that were defined in the outer SELECT in case
        # of CTE alias collisions.
        self._current_cte_alias_mapping = self._current_cte_alias_mapping.merge(
            SqlCteAliasMapping.create({cte_node.cte_alias: cte_node for cte_node in node.cte_sources})
        )
        self._cte_alias_mapping_lookup.add_cte_alias_mapping(
            select_node=node,
            cte_alias_mapping=self._current_cte_alias_mapping,
        )

        return self._default_handler(node)

    @override
    def visit_table_node(self, node: SqlTableNode) -> None:
        self._default_handler(node)

    @override
    def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None:
        self._default_handler(node)

    @override
    def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> None:  # noqa: D102
        self._default_handler(node)

    @property
    def cte_alias_mapping_lookup(self) -> SqlCteAliasMappingLookup:
        """Returns the lookup created after traversal."""
        return self._cte_alias_mapping_lookup
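
The visitor above saves the current mapping before descending into each parent node and restores it afterwards, so CTEs defined in one sub-query do not leak into a sibling. A minimal sketch of that save/restore pattern follows. `ScopeTracker` and `saved_scope` are hypothetical names, and the `try` / `finally` is a variation added here for illustration (the context manager in this PR restores the mapping only when the nested visit returns normally).

```python
from contextlib import contextmanager
from typing import FrozenSet, Iterator


class ScopeTracker:
    """Hypothetical tracker of which CTE aliases are visible while walking nested SELECTs."""

    def __init__(self) -> None:
        self.visible_aliases: FrozenSet[str] = frozenset()

    @contextmanager
    def saved_scope(self) -> Iterator[None]:
        previous = self.visible_aliases
        try:
            yield
        finally:
            # Restore the outer scope even if the nested visit raises.
            self.visible_aliases = previous


tracker = ScopeTracker()
tracker.visible_aliases = frozenset({"cte_0"})

with tracker.saved_scope():
    # Inside a sub-query that defines an additional CTE.
    tracker.visible_aliases = tracker.visible_aliases | {"cte_1"}
    inside = tracker.visible_aliases

after = tracker.visible_aliases
print(sorted(inside))  # ['cte_0', 'cte_1']
print(sorted(after))   # ['cte_0']
```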
35 changes: 19 additions & 16 deletions metricflow/sql/optimizer/required_column_aliases.py
@@ -8,6 +8,7 @@
from metricflow_semantics.sql.sql_exprs import SqlExpressionTreeLineage
from typing_extensions import override

from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
@@ -56,7 +57,12 @@ class SqlMapRequiredColumnAliasesVisitor(SqlPlanNodeVisitor[None]):
) source_0
"""

def __init__(self, start_node: SqlPlanNode, required_column_aliases_in_start_node: FrozenSet[str]) -> None:
def __init__(
self,
start_node: SqlPlanNode,
required_column_aliases_in_start_node: FrozenSet[str],
cte_alias_mapping_lookup: SqlCteAliasMappingLookup,
) -> None:
"""Initializer.

Args:
@@ -70,15 +76,7 @@ def __init__(self, start_node: SqlPlanNode, required_column_aliases_in_start_nod

# Helps lookup the CTE node associated with a given CTE alias. A member variable is needed as any node in the
# SQL DAG can reference a CTE.
start_node_as_select_node = start_node.as_select_node

self._current_cte_alias_mapping = SqlCteAliasMapping()
start_node_as_select_node = start_node.as_select_node

if start_node_as_select_node is not None:
self._current_cte_alias_mapping = SqlCteAliasMapping.create(
{cte_source.cte_alias: cte_source for cte_source in start_node_as_select_node.cte_sources}
)
self._cte_node_lookup = cte_alias_mapping_lookup

def _search_for_expressions(
self, select_node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...]
@@ -124,16 +122,21 @@ def _visit_parents(self, node: SqlPlanNode) -> None:
parent_node.accept(self)
return

def _tag_potential_cte_node(self, table_name: str, column_aliases: Set[str]) -> None:
def _tag_potential_cte_node(
self, cte_alias_mapping: SqlCteAliasMapping, table_name: str, column_aliases: Set[str]
) -> None:
"""A reference to a SQL table might be a CTE. If so, tag the appropriate aliases in the CTEs."""
cte_node = self._current_cte_alias_mapping.get_cte_node_for_alias(table_name)
cte_node = cte_alias_mapping.get_cte_node_for_alias(table_name)

if cte_node is not None:
self._current_required_column_alias_mapping.add_aliases(cte_node, column_aliases)
# `visit_cte_node` will handle propagating the required aliases to all CTEs that this CTE node depends on.
cte_node.accept(self)

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
"""Based on required column aliases for this SELECT, figure out required column aliases in parents."""
cte_alias_mapping = self._cte_node_lookup.get_cte_alias_mapping(node)

initial_required_column_aliases_in_this_node = self._current_required_column_alias_mapping.get_aliases(node)

# If this SELECT statement uses DISTINCT, all columns are required as removing them would change the meaning of
@@ -184,14 +187,12 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
nodes_to_retain_all_columns.append(join_desc.right_source)

for node_to_retain_all_columns in nodes_to_retain_all_columns:
nearest_select_columns = node_to_retain_all_columns.nearest_select_columns(
self._current_cte_alias_mapping
)
nearest_select_columns = node_to_retain_all_columns.nearest_select_columns(cte_alias_mapping)
for select_column in nearest_select_columns or ():
self._current_required_column_alias_mapping.add_alias(
node=node_to_retain_all_columns, column_alias=select_column.column_alias
)

# TODO: TBD - may be necessary to mark columns in all visible CTEs since a string can reference anything.
self._visit_parents(node)
return

@@ -216,6 +217,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
from_source_as_sql_table_node = node.from_source.as_sql_table_node
if from_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
cte_alias_mapping=cte_alias_mapping,
table_name=from_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
)
@@ -228,6 +230,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
right_source_as_sql_table_node = join_desc.right_source.as_sql_table_node
if right_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
cte_alias_mapping=cte_alias_mapping,
table_name=right_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
)
@@ -0,0 +1,81 @@
test_name: test_common_cte_aliases_in_nested_query
test_filename: test_cte_column_pruner.py
docstring:
  Test the case where a CTE defined in the top-level SELECT has the same name as a CTE in a sub-query .
expectation_description:
  In the `from_sub_query`, there is a reference to `cte_source__col_0` in a CTE named `cte_source`. Since
  `from_sub_query` redefines `cte_source`, the column pruner should retain that column in the CTE defined
  in `from_sub_query` but remove the column from the CTE defined in `top_level_select`.
---
optimizer:
  SqlColumnPrunerOptimizer

sql_before_optimizing:
  -- top_level_select
  WITH cte_source AS (
    -- CTE source
    SELECT
      test_table_alias.col_0 AS cte_source__col_0
      , test_table_alias.col_1 AS cte_source__col_1
    FROM test_schema.test_table test_table_alias
  )

  SELECT
    from_source_alias.from_source__col_0 AS top_level__col_0
    , right_source_alias.right_source__col_1 AS top_level__col_1
  FROM (
    -- from_sub_query
    WITH cte_source AS (
      -- CTE source
      SELECT
        test_table_alias.col_0 AS cte_source__col_0
        , test_table_alias.col_1 AS cte_source__col_1
      FROM test_schema.test_table test_table_alias
    )

    SELECT
      from_source_alias.cte_source__col_0 AS from_source__col_0
    FROM cte_source from_source_alias
  ) from_source_alias
  INNER JOIN (
    -- joined_sub_query
    SELECT
      from_source_alias.cte_source__col_1 AS right_source__col_1
    FROM cte_source from_source_alias
  ) right_source_alias
  ON
    from_source_alias.from_source__col_0 = right_source_alias.right_source__col_1

sql_after_optimizing:
  -- top_level_select
  WITH cte_source AS (
    -- CTE source
    SELECT
      test_table_alias.col_1 AS cte_source__col_1
    FROM test_schema.test_table test_table_alias
  )

  SELECT
    from_source_alias.from_source__col_0 AS top_level__col_0
    , right_source_alias.right_source__col_1 AS top_level__col_1
  FROM (
    -- from_sub_query
    WITH cte_source AS (
      -- CTE source
      SELECT
        test_table_alias.col_0 AS cte_source__col_0
      FROM test_schema.test_table test_table_alias
    )

    SELECT
      from_source_alias.cte_source__col_0 AS from_source__col_0
    FROM cte_source from_source_alias
  ) from_source_alias
  INNER JOIN (
    -- joined_sub_query
    SELECT
      from_source_alias.cte_source__col_1 AS right_source__col_1
    FROM cte_source from_source_alias
  ) right_source_alias
  ON
    from_source_alias.from_source__col_0 = right_source_alias.right_source__col_1
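
The expectation in this snapshot can be reproduced with a toy two-pass model: first record which CTE definitions each SELECT can see, then tag required columns on the CTE definition that each table reference actually resolves to. This is only a sketch under simplifying assumptions. `ToyCte`, `ToySelect`, `build_lookup`, and `collect_required_columns` are hypothetical names, each SELECT reads from at most one table, and the real visitors in this PR handle joins, DISTINCT, and string expressions that are ignored here.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so nodes can be dict keys.
class ToyCte:
    alias: str
    columns: Tuple[str, ...]


@dataclass(eq=False)
class ToySelect:
    """Hypothetical SELECT that reads `used_columns` from `from_table` and may contain sub-queries."""

    name: str
    ctes: List[ToyCte] = field(default_factory=list)
    from_table: Optional[str] = None
    used_columns: Tuple[str, ...] = ()
    sub_selects: List[ToySelect] = field(default_factory=list)


def build_lookup(select: ToySelect, visible: Dict[str, ToyCte], lookup: Dict[ToySelect, Dict[str, ToyCte]]) -> None:
    """Pass 1: record the CTEs visible from each SELECT; inner definitions shadow outer ones."""
    scope = {**visible, **{cte.alias: cte for cte in select.ctes}}
    lookup[select] = scope
    for sub_select in select.sub_selects:
        build_lookup(sub_select, scope, lookup)


def collect_required_columns(
    select: ToySelect, lookup: Dict[ToySelect, Dict[str, ToyCte]], required: Dict[ToyCte, Set[str]]
) -> None:
    """Pass 2: tag the columns used by each SELECT on the CTE definition its table name resolves to."""
    cte = lookup[select].get(select.from_table) if select.from_table is not None else None
    if cte is not None:
        required.setdefault(cte, set()).update(select.used_columns)
    for sub_select in select.sub_selects:
        collect_required_columns(sub_select, lookup, required)


all_columns = ("cte_source__col_0", "cte_source__col_1")
outer_cte = ToyCte("cte_source", all_columns)
inner_cte = ToyCte("cte_source", all_columns)

from_sub_query = ToySelect(
    "from_sub_query", ctes=[inner_cte], from_table="cte_source", used_columns=("cte_source__col_0",)
)
joined_sub_query = ToySelect("joined_sub_query", from_table="cte_source", used_columns=("cte_source__col_1",))
top_level_select = ToySelect("top_level_select", ctes=[outer_cte], sub_selects=[from_sub_query, joined_sub_query])

lookup: Dict[ToySelect, Dict[str, ToyCte]] = {}
build_lookup(top_level_select, {}, lookup)
required: Dict[ToyCte, Set[str]] = {}
collect_required_columns(top_level_select, lookup, required)

for label, cte in (("outer", outer_cte), ("inner", inner_cte)):
    print(label, [column for column in cte.columns if column in required.get(cte, set())])
# outer ['cte_source__col_1']
# inner ['cte_source__col_0']
```

Keying both dictionaries by node identity rather than by alias is what lets the two `cte_source` definitions be pruned independently.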
3 changes: 3 additions & 0 deletions tests_metricflow/sql/optimizer/check_optimizer.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from typing import Optional

from _pytest.fixtures import FixtureRequest
from metricflow_semantics.mf_logging.formatting import indent
@@ -21,6 +22,7 @@ def assert_optimizer_result_snapshot_equal(
optimizer: SqlPlanOptimizer,
sql_plan_renderer: SqlPlanRenderer,
select_statement: SqlSelectStatementNode,
expectation_description: Optional[str] = None,
) -> None:
"""Helper to assert that the SQL snapshot of the optimizer result is the same as the stored one."""
sql_before_optimizing = sql_plan_renderer.render_sql_plan(SqlPlan(select_statement)).sql
@@ -58,4 +60,5 @@ def assert_optimizer_result_snapshot_equal(
mf_test_configuration=mf_test_configuration,
snapshot_id="result",
snapshot_str=snapshot_str,
expectation_description=expectation_description,
)