Skip to content

Commit

Permalink
refactor: define planning_complexity tree property
Browse files Browse the repository at this point in the history
  • Loading branch information
TrevorBergeron committed Mar 28, 2024
1 parent 56cefff commit 8e8031b
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 0 deletions.
8 changes: 8 additions & 0 deletions bigframes/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
def is_bijective(self) -> bool:
return False

@property
def is_raw_variable(self) -> bool:
    """Report whether this expression is a raw variable.

    This implementation always answers ``False``.
    """
    return False


@dataclasses.dataclass(frozen=True)
class ScalarConstantExpression(Expression):
Expand Down Expand Up @@ -173,6 +177,10 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
def is_bijective(self) -> bool:
return True

@property
def is_raw_variable(self) -> bool:
    """Report whether this expression is a raw variable.

    This implementation always answers ``True``.
    """
    return True


@dataclasses.dataclass(frozen=True)
class OpExpression(Expression):
Expand Down
133 changes: 133 additions & 0 deletions bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
import bigframes.core.ordering as orderings
import bigframes.session

# A fixed number of variables to assume for overhead on some operations
OVERHEAD_VARIABLES = 5


@dataclass(frozen=True)
class BigFrameNode:
Expand Down Expand Up @@ -107,6 +110,38 @@ def roots(self) -> typing.Set[BigFrameNode]:
def schema(self) -> schemata.ArraySchema:
...

@property
@abc.abstractmethod
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Used to estimate query planning complexity. Declared abstract, so the
    value is supplied by overriding implementations.
    """
    ...

@property
def relation_ops_created(self) -> int:
    """Number of relational ops generated by this node alone.

    Used to estimate query planning complexity. The default is a single
    relational op.
    """
    return 1

@functools.cached_property
def total_variables(self) -> int:
    """Variables introduced by this node plus the totals of all child nodes."""
    own_count = self.variables_introduced
    children_count = sum(child.total_variables for child in self.child_nodes)
    return own_count + children_count

@functools.cached_property
def total_relational_ops(self) -> int:
    """Relational ops created by this node plus the totals of all child nodes."""
    own_count = self.relation_ops_created
    children_count = sum(child.total_relational_ops for child in self.child_nodes)
    return own_count + children_count

@property
def planning_complexity(self) -> int:
    """Heuristic measure of planning complexity.

    Computed as total variables multiplied by total relational ops. Used to
    determine when to decompose overly complex computations.
    """
    variable_count = self.total_variables
    op_count = self.total_relational_ops
    return variable_count * op_count


@dataclass(frozen=True)
class UnaryNode(BigFrameNode):
Expand Down Expand Up @@ -165,6 +200,10 @@ def join_mapping_to_schema_item(mapping: JoinColumnMapping):
)
return schemata.ArraySchema(items)

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Used to estimate query planning complexity. This node is charged the
    fixed ``OVERHEAD_VARIABLES`` allowance instead of a per-column count.
    """
    # A plain property is enough: the value is a module constant, so
    # memoizing it per instance buys nothing.
    return OVERHEAD_VARIABLES


@dataclass(frozen=True)
class ConcatNode(BigFrameNode):
Expand Down Expand Up @@ -193,6 +232,11 @@ def schema(self) -> schemata.ArraySchema:
)
return schemata.ArraySchema(items)

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Used to estimate query planning complexity. This node is charged the
    fixed ``OVERHEAD_VARIABLES`` allowance instead of a per-column count.
    """
    # A plain property is enough: the value is a module constant, so
    # memoizing it per instance buys nothing.
    return OVERHEAD_VARIABLES


# Input Nodes
@dataclass(frozen=True)
Expand All @@ -216,6 +260,11 @@ def roots(self) -> typing.Set[BigFrameNode]:
def schema(self) -> schemata.ArraySchema:
return self.data_schema

@functools.cached_property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Used to estimate query planning complexity. Counts one variable per
    schema item plus one extra.
    """
    # NOTE(review): the extra variable presumably covers an implicit
    # ordering column -- confirm.
    schema_item_count = len(self.schema.items)
    return schema_item_count + 1


# TODO: Refactor to take raw gbq object reference
@dataclass(frozen=True)
Expand Down Expand Up @@ -252,6 +301,15 @@ def schema(self) -> schemata.ArraySchema:
)
return schemata.ArraySchema(items)

@functools.cached_property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Counts one variable per column plus one per hidden ordering column.
    """
    visible_count = len(self.columns)
    hidden_count = len(self.hidden_ordering_columns)
    return visible_count + hidden_count

@property
def relation_ops_created(self) -> int:
    """Number of relational ops generated by this node alone."""
    # Worst case is assumed: the read may carry a baked-in analytic
    # operation that generates the index, so two ops are counted.
    return 2


# Unary nodes
@dataclass(frozen=True)
Expand All @@ -275,6 +333,10 @@ def schema(self) -> schemata.ArraySchema:
schemata.SchemaItem(self.col_id, bigframes.dtypes.INT_DTYPE)
)

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Used to estimate query planning complexity. This node always counts as
    introducing exactly one variable.
    """
    # A plain property is enough: there is nothing worth caching in a
    # literal constant.
    return 1


@dataclass(frozen=True)
class FilterNode(UnaryNode):
Expand All @@ -287,6 +349,10 @@ def row_preserving(self) -> bool:
def __hash__(self):
    """Return the hash value stored in ``self._node_hash``."""
    node_hash = self._node_hash
    return node_hash

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Used to estimate query planning complexity. This node always counts as
    introducing exactly one variable.
    """
    return 1


@dataclass(frozen=True)
class OrderByNode(UnaryNode):
Expand All @@ -304,6 +370,15 @@ def __post_init__(self):
def __hash__(self):
    """Return the hash value stored in ``self._node_hash``."""
    node_hash = self._node_hash
    return node_hash

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Used to estimate query planning complexity. This node introduces no
    variables.
    """
    return 0

@property
def relation_ops_created(self) -> int:
    """Number of relational ops generated by this node alone."""
    # This node doesn't directly create any relational operations.
    return 0


@dataclass(frozen=True)
class ReversedNode(UnaryNode):
Expand All @@ -313,6 +388,15 @@ class ReversedNode(UnaryNode):
def __hash__(self):
    """Return the hash value stored in ``self._node_hash``."""
    node_hash = self._node_hash
    return node_hash

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Used to estimate query planning complexity. This node introduces no
    variables.
    """
    return 0

@property
def relation_ops_created(self) -> int:
    """Number of relational ops generated by this node alone."""
    # This node doesn't directly create any relational operations.
    return 0


@dataclass(frozen=True)
class ProjectionNode(UnaryNode):
Expand All @@ -332,6 +416,12 @@ def schema(self) -> schemata.ArraySchema:
)
return schemata.ArraySchema(items)

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Only assignments whose first element is not a raw variable are
    counted; passthrough expressions are ignored.
    """
    computed_assignments = [
        assignment
        for assignment in self.assignments
        if not assignment[0].is_raw_variable
    ]
    return len(computed_assignments)


# TODO: Merge RowCount into Aggregate Node?
# Row count can sometimes be computed from table metadata, so it is a bit special.
Expand All @@ -351,6 +441,11 @@ def schema(self) -> schemata.ArraySchema:
(schemata.SchemaItem("count", bigframes.dtypes.INT_DTYPE),)
)

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Used to estimate query planning complexity. This node always counts as
    introducing exactly one variable.
    """
    return 1


@dataclass(frozen=True)
class AggregateNode(UnaryNode):
Expand Down Expand Up @@ -388,6 +483,10 @@ def schema(self) -> schemata.ArraySchema:
)
return schemata.ArraySchema(tuple([*by_items, *agg_items]))

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Counts one variable per aggregation plus one per grouping column id.
    """
    aggregation_count = len(self.aggregations)
    grouping_count = len(self.by_column_ids)
    return aggregation_count + grouping_count


@dataclass(frozen=True)
class WindowOpNode(UnaryNode):
Expand Down Expand Up @@ -421,12 +520,31 @@ def schema(self) -> schemata.ArraySchema:
schemata.SchemaItem(self.output_name, new_item_dtype)
)

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Used to estimate query planning complexity. This node always counts as
    introducing exactly one variable.
    """
    return 1

@property
def relation_ops_created(self) -> int:
    """Number of relational ops generated by this node alone.

    Returns 0 when ``skip_reproject_unsafe`` is truthy, otherwise 2.
    """
    # When not reprojecting, assume this op belongs to a sequence of window
    # operations sharing the same window, so it adds no relational ops.
    if self.skip_reproject_unsafe:
        return 0
    return 2


# TODO: Remove this op
@dataclass(frozen=True)
class ReprojectOpNode(UnaryNode):
    """Unary node that is only a hint to the SQL generator.

    It is not a real transformation, so it contributes neither variables
    nor relational ops to the planning complexity estimate.
    """

    def __hash__(self):
        """Return the hash value stored in ``self._node_hash``."""
        return self._node_hash

    @property
    def variables_introduced(self) -> int:
        """Number of variables generated by this node alone (always zero)."""
        return 0

    @property
    def relation_ops_created(self) -> int:
        """Number of relational ops generated by this node alone (always zero)."""
        # Not a real transformation, just a hint to the SQL generator.
        return 0


@dataclass(frozen=True)
class UnpivotNode(UnaryNode):
Expand Down Expand Up @@ -498,6 +616,17 @@ def infer_dtype(
]
return schemata.ArraySchema((*index_items, *value_items, *passthrough_items))

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Counts every schema item that is not a passthrough column, then adds
    the fixed ``OVERHEAD_VARIABLES`` allowance.
    """
    non_passthrough_count = len(self.schema.items) - len(self.passthrough_columns)
    return non_passthrough_count + OVERHEAD_VARIABLES

@property
def relation_ops_created(self) -> int:
    """Number of relational ops generated by this node alone."""
    # Unpivot is essentially a cross join plus a projection: two ops.
    return 2


@dataclass(frozen=True)
class RandomSampleNode(UnaryNode):
Expand All @@ -513,3 +642,7 @@ def row_preserving(self) -> bool:

def __hash__(self):
    """Return the hash value stored in ``self._node_hash``."""
    node_hash = self._node_hash
    return node_hash

@property
def variables_introduced(self) -> int:
    """Number of variables generated by this node alone.

    Used to estimate query planning complexity. This node always counts as
    introducing exactly one variable.
    """
    return 1

0 comments on commit 8e8031b

Please sign in to comment.