From 8e8031b0db4750c4d7a821ccbcacc6549a9faa11 Mon Sep 17 00:00:00 2001
From: Trevor Bergeron
Date: Thu, 28 Mar 2024 00:52:37 +0000
Subject: [PATCH] refactor: define planning_complexity tree property

---
Reviewer notes (text in this section is not part of the commit):
- planning_complexity is total_variables * total_relational_ops; each total is
  the node's own contribution (variables_introduced / relation_ops_created)
  plus the sum of the same total over child_nodes.
- is_raw_variable is added next to each of the two is_bijective definitions
  shown below: False beside the one returning False, True beside the one
  returning True.
- NOTE(review): line structure and context-line indentation were reconstructed
  from a whitespace-collapsed copy; run `git apply --check` before applying.
- NOTE(review): the From: header carries no e-mail address in this copy;
  restore it if `git am` rejects the patch.

 bigframes/core/expression.py |   8 +++
 bigframes/core/nodes.py      | 133 +++++++++++++++++++++++++++++++++++
 2 files changed, 141 insertions(+)

diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py
index 8c3f52d22b..86731afe9c 100644
--- a/bigframes/core/expression.py
+++ b/bigframes/core/expression.py
@@ -108,6 +108,10 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
     def is_bijective(self) -> bool:
         return False
 
+    @property
+    def is_raw_variable(self) -> bool:
+        return False
+
 
 @dataclasses.dataclass(frozen=True)
 class ScalarConstantExpression(Expression):
@@ -173,6 +177,10 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
     def is_bijective(self) -> bool:
         return True
 
+    @property
+    def is_raw_variable(self) -> bool:
+        return True
+
 
 @dataclasses.dataclass(frozen=True)
 class OpExpression(Expression):
diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py
index 5ebd2a5997..8d984a0612 100644
--- a/bigframes/core/nodes.py
+++ b/bigframes/core/nodes.py
@@ -38,6 +38,9 @@
 import bigframes.core.ordering as orderings
 import bigframes.session
 
+# A fixed number of variables to assume for overhead on some operations
+OVERHEAD_VARIABLES = 5
+
 
 @dataclass(frozen=True)
 class BigFrameNode:
@@ -107,6 +110,38 @@ def roots(self) -> typing.Set[BigFrameNode]:
     def schema(self) -> schemata.ArraySchema:
         ...
 
+    @property
+    @abc.abstractmethod
+    def variables_introduced(self) -> int:
+        """
+        Defines the number of variables generated by the current node. Used to estimate query planning complexity.
+        """
+        ...
+
+    @property
+    def relation_ops_created(self) -> int:
+        """
+        Defines the number of relational ops generated by the current node. Used to estimate query planning complexity.
+        """
+        return 1
+
+    @functools.cached_property
+    def total_variables(self) -> int:
+        return self.variables_introduced + sum(
+            child.total_variables for child in self.child_nodes
+        )
+
+    @functools.cached_property
+    def total_relational_ops(self) -> int:
+        return self.relation_ops_created + sum(
+            child.total_relational_ops for child in self.child_nodes
+        )
+
+    @property
+    def planning_complexity(self) -> int:
+        """Heuristic measure of planning complexity. Used to determine when to decompose overly complex computations."""
+        return self.total_variables * self.total_relational_ops
+
 
 @dataclass(frozen=True)
 class UnaryNode(BigFrameNode):
@@ -165,6 +200,10 @@ def join_mapping_to_schema_item(mapping: JoinColumnMapping):
         )
         return schemata.ArraySchema(items)
 
+    @functools.cached_property
+    def variables_introduced(self) -> int:
+        return OVERHEAD_VARIABLES
+
 
 @dataclass(frozen=True)
 class ConcatNode(BigFrameNode):
@@ -193,6 +232,11 @@ def schema(self) -> schemata.ArraySchema:
         )
         return schemata.ArraySchema(items)
 
+    @functools.cached_property
+    def variables_introduced(self) -> int:
+        """Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
+        return OVERHEAD_VARIABLES
+
 
 # Input Nodex
 @dataclass(frozen=True)
@@ -216,6 +260,11 @@ def roots(self) -> typing.Set[BigFrameNode]:
     def schema(self) -> schemata.ArraySchema:
         return self.data_schema
 
+    @functools.cached_property
+    def variables_introduced(self) -> int:
+        """Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
+        return len(self.schema.items) + 1
+
 
 # TODO: Refactor to take raw gbq object reference
 @dataclass(frozen=True)
@@ -252,6 +301,15 @@ def schema(self) -> schemata.ArraySchema:
         )
         return schemata.ArraySchema(items)
 
+    @functools.cached_property
+    def variables_introduced(self) -> int:
+        return len(self.columns) + len(self.hidden_ordering_columns)
+
+    @property
+    def relation_ops_created(self) -> int:
+        # Assume worst case, where readgbq actually has a baked-in analytic operation to generate the index
+        return 2
+
 
 # Unary nodes
 @dataclass(frozen=True)
@@ -275,6 +333,10 @@ def schema(self) -> schemata.ArraySchema:
             schemata.SchemaItem(self.col_id, bigframes.dtypes.INT_DTYPE)
         )
 
+    @functools.cached_property
+    def variables_introduced(self) -> int:
+        return 1
+
 
 @dataclass(frozen=True)
 class FilterNode(UnaryNode):
@@ -287,6 +349,10 @@ def row_preserving(self) -> bool:
     def __hash__(self):
         return self._node_hash
 
+    @property
+    def variables_introduced(self) -> int:
+        return 1
+
 
 @dataclass(frozen=True)
 class OrderByNode(UnaryNode):
@@ -304,6 +370,15 @@ def __post_init__(self):
     def __hash__(self):
         return self._node_hash
 
+    @property
+    def variables_introduced(self) -> int:
+        return 0
+
+    @property
+    def relation_ops_created(self) -> int:
+        # Doesn't directly create any relational operations
+        return 0
+
 
 @dataclass(frozen=True)
 class ReversedNode(UnaryNode):
@@ -313,6 +388,15 @@ class ReversedNode(UnaryNode):
     def __hash__(self):
         return self._node_hash
 
+    @property
+    def variables_introduced(self) -> int:
+        return 0
+
+    @property
+    def relation_ops_created(self) -> int:
+        # Doesn't directly create any relational operations
+        return 0
+
 
 @dataclass(frozen=True)
 class ProjectionNode(UnaryNode):
@@ -332,6 +416,12 @@ def schema(self) -> schemata.ArraySchema:
         )
         return schemata.ArraySchema(items)
 
+    @property
+    def variables_introduced(self) -> int:
+        # ignore passthrough expressions
+        new_vars = sum(1 for i in self.assignments if not i[0].is_raw_variable)
+        return new_vars
+
 
 # TODO: Merge RowCount into Aggregate Node?
 # Row count can be compute from table metadata sometimes, so it is a bit special.
@@ -351,6 +441,11 @@ def schema(self) -> schemata.ArraySchema:
             (schemata.SchemaItem("count", bigframes.dtypes.INT_DTYPE),)
         )
 
+    @property
+    def variables_introduced(self) -> int:
+        # the single "count" column is the only variable produced
+        return 1
+
 
 @dataclass(frozen=True)
 class AggregateNode(UnaryNode):
@@ -388,6 +483,10 @@ def schema(self) -> schemata.ArraySchema:
         )
         return schemata.ArraySchema(tuple([*by_items, *agg_items]))
 
+    @property
+    def variables_introduced(self) -> int:
+        return len(self.aggregations) + len(self.by_column_ids)
+
 
 @dataclass(frozen=True)
 class WindowOpNode(UnaryNode):
@@ -421,12 +520,31 @@ def schema(self) -> schemata.ArraySchema:
             schemata.SchemaItem(self.output_name, new_item_dtype)
         )
 
+    @property
+    def variables_introduced(self) -> int:
+        return 1
+
+    @property
+    def relation_ops_created(self) -> int:
+        # Assume that, if not reprojecting, there is a sequence of window operations sharing the same window
+        return 0 if self.skip_reproject_unsafe else 2
+
 
+# TODO: Remove this op
 @dataclass(frozen=True)
 class ReprojectOpNode(UnaryNode):
     def __hash__(self):
         return self._node_hash
 
+    @property
+    def variables_introduced(self) -> int:
+        return 0
+
+    @property
+    def relation_ops_created(self) -> int:
+        # This op is not a real transformation, just a hint to the sql generator
+        return 0
+
 
 @dataclass(frozen=True)
 class UnpivotNode(UnaryNode):
@@ -498,6 +616,17 @@ def infer_dtype(
         ]
         return schemata.ArraySchema((*index_items, *value_items, *passthrough_items))
 
+    @property
+    def variables_introduced(self) -> int:
+        return (
+            len(self.schema.items) - len(self.passthrough_columns) + OVERHEAD_VARIABLES
+        )
+
+    @property
+    def relation_ops_created(self) -> int:
+        # Unpivot is essentially a cross join and a projection.
+        return 2
+
 
 @dataclass(frozen=True)
 class RandomSampleNode(UnaryNode):
@@ -513,3 +642,7 @@ def row_preserving(self) -> bool:
 
     def __hash__(self):
         return self._node_hash
+
+    @property
+    def variables_introduced(self) -> int:
+        return 1