Skip to content

Commit

Permalink
Support nested CTEs in the column pruner.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Jan 21, 2025
1 parent 4294701 commit 5f94aee
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 4 deletions.
82 changes: 82 additions & 0 deletions metricflow/sql/optimizer/cte_alias_to_cte_node_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations

import logging
from typing import Dict

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat

from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlSelectStatementNode

logger = logging.getLogger(__name__)


class SqlCteAliasMappingLookup:
    """A mutable lookup that stores the CTE-alias mapping at a given node.

    In cases with nested CTEs in a SELECT, it's possible that a CTE defined in an inner SELECT has an alias that is
    the same as a CTE defined in an outer SELECT. e.g.

        # outer_cte
        WITH cte_0 AS (
            SELECT 1 AS col_0
        )
        # outer_select
        SELECT col_0
        FROM (
            # inner_cte
            WITH cte_0 AS (
                SELECT 2 AS col_0
            )
            # inner_select
            SELECT col_0 FROM cte_0
        )
        ...

    In this case, `outer_cte` and `inner_cte` both have the same alias `cte_0`. When `cte_0` is referenced from
    `inner_select`, it is referring to the `inner_cte`. For column pruning, it is necessary to figure out which CTE
    a given alias is referencing, so this class helps to keep track of that mapping.
    """

    def __init__(self) -> None:  # noqa: D107
        # Keyed by the SELECT node; the value is the CTE-alias mapping recorded for that node.
        self._select_node_to_cte_alias_mapping: Dict[SqlSelectStatementNode, SqlCteAliasMapping] = {}

    def cte_alias_mapping_exists(self, select_node: SqlSelectStatementNode) -> bool:
        """Returns true if the CTE-alias mapping for the given node has been recorded."""
        return select_node in self._select_node_to_cte_alias_mapping

    def add_cte_alias_mapping(
        self,
        select_node: SqlSelectStatementNode,
        cte_alias_mapping: SqlCteAliasMapping,
    ) -> None:
        """Associate the given CTE-alias mapping with the given node.

        Args:
            select_node: The SELECT node that the mapping applies to.
            cte_alias_mapping: The CTE-alias mapping to record for `select_node`.

        Raises:
            RuntimeError: If a mapping has already been added for `select_node`. Existing entries are never
                overwritten.
        """
        if self.cte_alias_mapping_exists(select_node):
            raise RuntimeError(
                str(
                    LazyFormat(
                        "`select_node` has already been added",
                        select_node=select_node,
                        current_mapping=self._select_node_to_cte_alias_mapping,
                    )
                )
            )

        self._select_node_to_cte_alias_mapping[select_node] = cte_alias_mapping

    def get_cte_alias_mapping(self, select_node: SqlSelectStatementNode) -> SqlCteAliasMapping:
        """Return the CTE-alias mapping for the given node.

        Args:
            select_node: The SELECT node to look up.

        Returns:
            The mapping that was previously recorded via `add_cte_alias_mapping`.

        Raises:
            RuntimeError: If a mapping was not previously added for `select_node`.
        """
        cte_alias_mapping = self._select_node_to_cte_alias_mapping.get(select_node)
        if cte_alias_mapping is None:
            raise RuntimeError(
                str(LazyFormat("CTE alias mapping does not exist for the given `select_node`", select_node=select_node))
            )
        return cte_alias_mapping
33 changes: 30 additions & 3 deletions metricflow/sql/optimizer/required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from metricflow_semantics.sql.sql_exprs import SqlExpressionTreeLineage
from typing_extensions import override

from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
Expand Down Expand Up @@ -70,15 +71,18 @@ def __init__(self, start_node: SqlPlanNode, required_column_aliases_in_start_nod

# Helps lookup the CTE node associated with a given CTE alias. A member variable is needed as any node in the
# SQL DAG can reference a CTE.
self._cte_node_lookup = SqlCteAliasMappingLookup()
start_node_as_select_node = start_node.as_select_node

self._current_cte_alias_mapping = SqlCteAliasMapping()
start_node_as_select_node = start_node.as_select_node

if start_node_as_select_node is not None:
self._current_cte_alias_mapping = SqlCteAliasMapping.create(
{cte_source.cte_alias: cte_source for cte_source in start_node_as_select_node.cte_sources}
)
self._cte_node_lookup.add_cte_alias_mapping(
select_node=start_node_as_select_node,
cte_alias_mapping=self._current_cte_alias_mapping,
)

def _search_for_expressions(
self, select_node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...]
Expand Down Expand Up @@ -127,13 +131,34 @@ def _visit_parents(self, node: SqlPlanNode) -> None:
def _tag_potential_cte_node(self, table_name: str, column_aliases: Set[str]) -> None:
    """Tag the given column aliases as required in the CTE that `table_name` refers to, if it refers to one.

    A table reference in SQL may actually be a reference to a CTE. The name is resolved through the current
    CTE-alias mapping; when it does not resolve to a CTE node, this method does nothing.
    """
    referenced_cte_node = self._current_cte_alias_mapping.get_cte_node_for_alias(table_name)
    if referenced_cte_node is None:
        return

    self._current_required_column_alias_mapping.add_aliases(referenced_cte_node, column_aliases)
    # `visit_cte_node` will handle propagating the required aliases to all CTEs that this CTE node depends on.
    referenced_cte_node.accept(self)

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
"""Based on required column aliases for this SELECT, figure out required column aliases in parents."""
# If this SELECT node defines any CTEs, it should override ones that were defined in the outer SELECT in case
# of CTE alias collisions.
if not self._cte_node_lookup.cte_alias_mapping_exists(node):
self._cte_node_lookup.add_cte_alias_mapping(
select_node=node,
cte_alias_mapping=self._current_cte_alias_mapping.merge(
SqlCteAliasMapping.create({cte_node.cte_alias: cte_node for cte_node in node.cte_sources})
),
)

previous_cte_alias_mapping = self._current_cte_alias_mapping
self._current_cte_alias_mapping = self._cte_node_lookup.get_cte_alias_mapping(node)
logger.debug(
LazyFormat(
"Starting visit of SELECT statement node with CTE alias mapping",
node=node,
current_cte_alias_mapping=self._current_cte_alias_mapping,
)
)

initial_required_column_aliases_in_this_node = self._current_required_column_alias_mapping.get_aliases(node)

# If this SELECT statement uses DISTINCT, all columns are required as removing them would change the meaning of
Expand Down Expand Up @@ -191,8 +216,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
self._current_required_column_alias_mapping.add_alias(
node=node_to_retain_all_columns, column_alias=select_column.column_alias
)

# TODO: TBD - may be necessary to mark columns in all visible CTEs since a string can reference anything.
self._visit_parents(node)
self._current_cte_alias_mapping = previous_cte_alias_mapping
return

# Create a mapping from the source alias to the column aliases needed from the corresponding source.
Expand Down Expand Up @@ -255,6 +281,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:

# Visit recursively.
self._visit_parents(node)
self._current_cte_alias_mapping = previous_cte_alias_mapping
return

def visit_table_node(self, node: SqlTableNode) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests_metricflow/sql/optimizer/check_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,5 @@ def assert_optimizer_result_snapshot_equal(
mf_test_configuration=mf_test_configuration,
snapshot_id="result",
snapshot_str=snapshot_str,
expectation_description=expectation_description
expectation_description=expectation_description,
)
137 changes: 137 additions & 0 deletions tests_metricflow/sql/optimizer/test_cte_column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest
from _pytest.fixtures import FixtureRequest
from metricflow_semantics.formatting.formatting_helpers import mf_dedent
from metricflow_semantics.sql.sql_exprs import (
SqlColumnReference,
SqlColumnReferenceExpression,
Expand Down Expand Up @@ -328,3 +329,139 @@ def test_multi_child_pruning(
sql_plan_renderer=sql_plan_renderer,
select_statement=select_statement,
)


def _nested_cte_test_column_ref(table_alias: str, column_name: str) -> SqlColumnReferenceExpression:
    """Return an expression referencing `column_name` in the source aliased as `table_alias`."""
    return SqlColumnReferenceExpression.create(
        col_ref=SqlColumnReference(table_alias=table_alias, column_name=column_name)
    )


def _nested_cte_test_select_column(table_alias: str, column_name: str, column_alias: str) -> SqlSelectColumn:
    """Return a SELECT column that exposes `table_alias.column_name` as `column_alias`."""
    return SqlSelectColumn(
        expr=_nested_cte_test_column_ref(table_alias=table_alias, column_name=column_name),
        column_alias=column_alias,
    )


def _nested_cte_test_cte_source() -> SqlCteNode:
    """Return a new CTE node aliased `cte_source` that selects `col_0` and `col_1` from `test_schema.test_table`."""
    return SqlCteNode.create(
        cte_alias="cte_source",
        select_statement=SqlSelectStatementNode.create(
            description="CTE source",
            select_columns=(
                _nested_cte_test_select_column(
                    table_alias="test_table_alias", column_name="col_0", column_alias="cte_source__col_0"
                ),
                _nested_cte_test_select_column(
                    table_alias="test_table_alias", column_name="col_1", column_alias="cte_source__col_1"
                ),
            ),
            from_source=SqlTableNode.create(sql_table=SqlTable(schema_name="test_schema", table_name="test_table")),
            from_source_alias="test_table_alias",
        ),
    )


def test_common_cte_aliases_in_nested_query(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    column_pruner: SqlColumnPrunerOptimizer,
    sql_plan_renderer: DefaultSqlPlanRenderer,
) -> None:
    """Test the case where a CTE defined in the top-level SELECT has the same name as a CTE in a sub-query."""
    # Two separate CTE nodes that share the alias `cte_source`: one is attached to the top-level SELECT and the
    # other to the sub-query in the FROM clause. They are built in the same order as they are attached below.
    top_level_select_ctes = (_nested_cte_test_cte_source(),)
    from_sub_query_ctes = (_nested_cte_test_cte_source(),)

    top_level_select = SqlSelectStatementNode.create(
        description="top_level_select",
        select_columns=(
            _nested_cte_test_select_column(
                table_alias="from_source_alias", column_name="from_source__col_0", column_alias="top_level__col_0"
            ),
            _nested_cte_test_select_column(
                table_alias="right_source_alias", column_name="right_source__col_1", column_alias="top_level__col_1"
            ),
        ),
        # This sub-query redefines `cte_source`, so its reference to `cte_source` should resolve to its own CTE.
        from_source=SqlSelectStatementNode.create(
            description="from_sub_query",
            select_columns=(
                _nested_cte_test_select_column(
                    table_alias="from_source_alias",
                    column_name="cte_source__col_0",
                    column_alias="from_source__col_0",
                ),
            ),
            from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source")),
            from_source_alias="from_source_alias",
            cte_sources=from_sub_query_ctes,
        ),
        from_source_alias="from_source_alias",
        join_descs=(
            SqlJoinDescription(
                # This sub-query defines no CTEs of its own and also selects from a table named `cte_source`.
                right_source=SqlSelectStatementNode.create(
                    description="joined_sub_query",
                    select_columns=(
                        _nested_cte_test_select_column(
                            table_alias="from_source_alias",
                            column_name="cte_source__col_1",
                            column_alias="right_source__col_1",
                        ),
                    ),
                    from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source")),
                    from_source_alias="from_source_alias",
                ),
                right_source_alias="right_source_alias",
                on_condition=SqlComparisonExpression.create(
                    left_expr=_nested_cte_test_column_ref(
                        table_alias="from_source_alias", column_name="from_source__col_0"
                    ),
                    comparison=SqlComparison.EQUALS,
                    right_expr=_nested_cte_test_column_ref(
                        table_alias="right_source_alias", column_name="right_source__col_1"
                    ),
                ),
                join_type=SqlJoinType.INNER,
            ),
        ),
        cte_sources=top_level_select_ctes,
    )

    assert_optimizer_result_snapshot_equal(
        request=request,
        mf_test_configuration=mf_test_configuration,
        optimizer=column_pruner,
        sql_plan_renderer=sql_plan_renderer,
        select_statement=top_level_select,
        expectation_description=mf_dedent(
            """
            In the `from_sub_query`, there is a reference to `cte_source__col_0` in a CTE named `cte_source`. Since
            `from_sub_query` redefines `cte_source`, the column pruner should retain that column in the CTE defined
            in `from_sub_query` but remove the column from the CTE defined in `top_level_select`.
            """
        ),
    )

0 comments on commit 5f94aee

Please sign in to comment.