Skip to content

Commit

Permalink
fixup! Add new dataflow plan nodes for custom offset windows
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Jan 22, 2025
1 parent 54edeb1 commit cd61ca8
Showing 1 changed file with 31 additions and 18 deletions.
49 changes: 31 additions & 18 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import DUNDER
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
Expand Down Expand Up @@ -2099,26 +2098,30 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit

# Build columns that get start and end of the custom grain period.
# Ex: FIRST_VALUE(ds) OVER (PARTITION BY fiscal_quarter ORDER BY ds) AS ds__fiscal_quarter__first_value
new_select_columns: Tuple[SqlSelectColumn, ...] = tuple()
new_select_columns: Tuple[SqlSelectColumn, ...] = ()
bounds_columns: Tuple[SqlSelectColumn, ...] = ()
bounds_instances: Tuple[TimeDimensionInstance, ...] = ()
custom_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=custom_grain_column_name
)
base_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=base_grain_column_name
)
for window_func in (SqlWindowFunction.FIRST_VALUE, SqlWindowFunction.LAST_VALUE):
bounds_instance = custom_grain_instance.with_new_spec(
new_spec=custom_grain_instance.spec.with_window_functions((window_func,)),
column_association_resolver=self._column_association_resolver,
)
select_column = SqlSelectColumn(
expr=SqlWindowFunctionExpression.create(
sql_function=window_func,
sql_function_args=(base_column_expr,),
partition_by_args=(custom_column_expr,),
order_by_args=(SqlWindowOrderByArgument(base_column_expr),),
),
column_alias=self._column_association_resolver.resolve_spec(
custom_grain_instance.spec.with_window_function(window_func)
).column_name,
column_alias=bounds_instance.associated_column.column_name,
)
bounds_instances += (bounds_instance,)
bounds_columns += (select_column,)
new_select_columns += (select_column,)

Expand All @@ -2132,7 +2135,7 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit
order_by_args=(SqlWindowOrderByArgument(base_column_expr),),
),
column_alias=self._column_association_resolver.resolve_spec(
base_grain_instance.spec.with_window_function(SqlWindowFunction.ROW_NUMBER)
base_grain_instance.spec.with_window_functions((SqlWindowFunction.ROW_NUMBER,))
).column_name,
)
new_select_columns += (row_number_column,)
Expand Down Expand Up @@ -2169,20 +2172,30 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit
custom_grain_column = SqlSelectColumn.from_table_and_column_names(
column_name=custom_grain_column_name, table_alias=unique_rows_alias
)
first_value_offset_column, last_value_offset_column = tuple(
SqlSelectColumn(
expr=SqlWindowFunctionExpression.create(
sql_function=SqlWindowFunction.LEAD,
sql_function_args=(
bounds_column.ref_with_new_table_alias(unique_rows_alias),
SqlIntegerExpression.create(node.offset_window.count),
offset_bounds_columns: Tuple[SqlSelectColumn, ...] = ()
for i in range(len(bounds_columns)):
bounds_instance = bounds_instances[i]
bounds_column = bounds_columns[i]
offset_bounds_instance = bounds_instance.with_new_spec(
bounds_instance.spec.with_window_functions(
(bounds_instance.spec.window_functions + (SqlWindowFunction.LEAD,))
),
column_association_resolver=self._column_association_resolver,
)
offset_bounds_columns += (
SqlSelectColumn(
expr=SqlWindowFunctionExpression.create(
sql_function=SqlWindowFunction.LEAD,
sql_function_args=(
bounds_column.ref_with_new_table_alias(unique_rows_alias),
SqlIntegerExpression.create(node.offset_window.count),
),
order_by_args=(SqlWindowOrderByArgument(custom_grain_column.expr),),
),
order_by_args=(SqlWindowOrderByArgument(custom_grain_column.expr),),
column_alias=offset_bounds_instance.associated_column.column_name,
),
column_alias=f"{bounds_column.column_alias}{DUNDER}offset",
)
for bounds_column in bounds_columns
)
first_value_offset_column, last_value_offset_column = offset_bounds_columns
offset_bounds_subquery_alias = self._next_unique_table_alias()
offset_bounds_subquery = SqlSelectStatementNode.create(
description="Offset Custom Granularity Bounds",
Expand Down Expand Up @@ -2213,7 +2226,7 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit
)
# LEAD isn't quite accurate here, but this will differentiate the offset instance (and column) from the original one.
offset_base_column_name = self._column_association_resolver.resolve_spec(
base_grain_instance.spec.with_window_function(SqlWindowFunction.LEAD)
base_grain_instance.spec.with_window_functions((SqlWindowFunction.LEAD,))
).column_name
offset_base_column = SqlSelectColumn(
expr=SqlCaseExpression.create(
Expand Down

0 comments on commit cd61ca8

Please sign in to comment.