refactor: define planning_complexity tree property #538

Closed
wants to merge 1 commit into from
8 changes: 8 additions & 0 deletions bigframes/core/expression.py
@@ -108,6 +108,10 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
def is_bijective(self) -> bool:
return False

@property
def is_raw_variable(self) -> bool:
return False


@dataclasses.dataclass(frozen=True)
class ScalarConstantExpression(Expression):
@@ -173,6 +177,10 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
def is_bijective(self) -> bool:
return True

@property
def is_raw_variable(self) -> bool:
return True


@dataclasses.dataclass(frozen=True)
class OpExpression(Expression):
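
Taken together, is_raw_variable marks an expression that is nothing more than a reference to an existing column: the base Expression defaults it to False, and the variable-reference expression class covered by the second hunk overrides it to True. ProjectionNode in nodes.py below uses this flag to skip simple passthrough columns when counting newly introduced variables.
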
133 changes: 133 additions & 0 deletions bigframes/core/nodes.py
@@ -38,6 +38,9 @@
import bigframes.core.ordering as orderings
import bigframes.session

# A fixed number of variables to assume for overhead on some operations
OVERHEAD_VARIABLES = 5


@dataclass(frozen=True)
class BigFrameNode:
@@ -107,6 +110,38 @@ def roots(self) -> typing.Set[BigFrameNode]:
def schema(self) -> schemata.ArraySchema:
...

@property
@abc.abstractmethod
def variables_introduced(self) -> int:
"""
Defines the number of variables generated by the current node. Used to estimate query planning complexity.
"""
...

@property
def relation_ops_created(self) -> int:
"""
Defines the number of relational ops generated by the current node. Used to estimate query planning complexity.
"""
return 1

@functools.cached_property
def total_variables(self) -> int:
return self.variables_introduced + sum(
map(lambda x: x.total_variables, self.child_nodes)
)

@functools.cached_property
def total_relational_ops(self) -> int:
return self.relation_ops_created + sum(
map(lambda x: x.total_relational_ops, self.child_nodes)
)

@property
def planning_complexity(self) -> int:
"""Heuristic measure of planning complexity. Used to determine when to decompose overly complex computations."""
return self.total_variables * self.total_relational_ops
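
For intuition, here is a minimal, self-contained sketch (a toy stand-in, not the real BigFrameNode classes; all numbers are illustrative only) of how these three properties compose over a small node tree:

from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ToyNode:
    variables_introduced: int
    relation_ops_created: int
    children: Tuple["ToyNode", ...] = ()

    @property
    def total_variables(self) -> int:
        # Variables introduced by this node plus all of its descendants.
        return self.variables_introduced + sum(c.total_variables for c in self.children)

    @property
    def total_relational_ops(self) -> int:
        # Relational ops created by this node plus all of its descendants.
        return self.relation_ops_created + sum(c.total_relational_ops for c in self.children)

    @property
    def planning_complexity(self) -> int:
        return self.total_variables * self.total_relational_ops


# A read (5 variables, 2 ops) feeding a projection (1 variable, 1 op) feeding a filter (1 variable, 1 op):
read = ToyNode(variables_introduced=5, relation_ops_created=2)
project = ToyNode(variables_introduced=1, relation_ops_created=1, children=(read,))
filter_ = ToyNode(variables_introduced=1, relation_ops_created=1, children=(project,))
assert filter_.planning_complexity == (5 + 1 + 1) * (2 + 1 + 1)  # 7 * 4 == 28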


@dataclass(frozen=True)
class UnaryNode(BigFrameNode):
@@ -165,6 +200,10 @@ def join_mapping_to_schema_item(mapping: JoinColumnMapping):
)
return schemata.ArraySchema(items)

@functools.cached_property
def variables_introduced(self) -> int:
return OVERHEAD_VARIABLES


@dataclass(frozen=True)
class ConcatNode(BigFrameNode):
@@ -193,6 +232,11 @@ def schema(self) -> schemata.ArraySchema:
)
return schemata.ArraySchema(items)

@functools.cached_property
def variables_introduced(self) -> int:
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
return OVERHEAD_VARIABLES


# Input Nodes
@dataclass(frozen=True)
@@ -216,6 +260,11 @@ def roots(self) -> typing.Set[BigFrameNode]:
def schema(self) -> schemata.ArraySchema:
return self.data_schema

@functools.cached_property
def variables_introduced(self) -> int:
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
return len(self.schema.items) + 1


# TODO: Refactor to take raw gbq object reference
@dataclass(frozen=True)
@@ -252,6 +301,15 @@ def schema(self) -> schemata.ArraySchema:
)
return schemata.ArraySchema(items)

@functools.cached_property
def variables_introduced(self) -> int:
return len(self.columns) + len(self.hidden_ordering_columns)

@property
def relation_ops_created(self) -> int:
# Assume worst case, where read_gbq has a baked-in analytic operation to generate the index
return 2


# Unary nodes
@dataclass(frozen=True)
@@ -275,6 +333,10 @@ def schema(self) -> schemata.ArraySchema:
schemata.SchemaItem(self.col_id, bigframes.dtypes.INT_DTYPE)
)

@functools.cached_property
def variables_introduced(self) -> int:
return 1


@dataclass(frozen=True)
class FilterNode(UnaryNode):
@@ -287,6 +349,10 @@ def row_preserving(self) -> bool:
def __hash__(self):
return self._node_hash

@property
def variables_introduced(self) -> int:
return 1


@dataclass(frozen=True)
class OrderByNode(UnaryNode):
@@ -304,6 +370,15 @@ def __post_init__(self):
def __hash__(self):
return self._node_hash

@property
def variables_introduced(self) -> int:
return 0

@property
def relation_ops_created(self) -> int:
# Doesn't directly create any relational operations
return 0


@dataclass(frozen=True)
class ReversedNode(UnaryNode):
@@ -313,6 +388,15 @@ class ReversedNode(UnaryNode):
def __hash__(self):
return self._node_hash

@property
def variables_introduced(self) -> int:
return 0

@property
def relation_ops_created(self) -> int:
# Doesn't directly create any relational operations
return 0


@dataclass(frozen=True)
class ProjectionNode(UnaryNode):
@@ -332,6 +416,12 @@ def schema(self) -> schemata.ArraySchema:
)
return schemata.ArraySchema(items)

@property
def variables_introduced(self) -> int:
# ignore passthrough expressions
new_vars = sum(1 for i in self.assignments if not i[0].is_raw_variable)
return new_vars
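
For example, with hypothetical columns a and b, a projection whose assignments are (a, "a"), (b, "b"), (a + b, "c") introduces just one variable: the first two assignments are raw variable references (is_raw_variable is True), so they count as passthrough rather than new variables.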


# TODO: Merge RowCount into Aggregate Node?
# Row count can be computed from table metadata sometimes, so it is a bit special.
@@ -351,6 +441,11 @@ def schema(self) -> schemata.ArraySchema:
(schemata.SchemaItem("count", bigframes.dtypes.INT_DTYPE),)
)

@property
def variables_introduced(self) -> int:
# Only the single count column is introduced
return 1


@dataclass(frozen=True)
class AggregateNode(UnaryNode):
@@ -388,6 +483,10 @@ def schema(self) -> schemata.ArraySchema:
)
return schemata.ArraySchema(tuple([*by_items, *agg_items]))

@property
def variables_introduced(self) -> int:
return len(self.aggregations) + len(self.by_column_ids)


@dataclass(frozen=True)
class WindowOpNode(UnaryNode):
@@ -421,12 +520,31 @@ def schema(self) -> schemata.ArraySchema:
schemata.SchemaItem(self.output_name, new_item_dtype)
)

@property
def variables_introduced(self) -> int:
return 1

@property
def relation_ops_created(self) -> int:
# Assume that if not reprojecting, there is a sequence of window operations sharing the same window
return 0 if self.skip_reproject_unsafe else 2


# TODO: Remove this op
@dataclass(frozen=True)
class ReprojectOpNode(UnaryNode):
def __hash__(self):
return self._node_hash

@property
def variables_introduced(self) -> int:
return 0

@property
def relation_ops_created(self) -> int:
# This op is not a real transformation, just a hint to the SQL generator
return 0


@dataclass(frozen=True)
class UnpivotNode(UnaryNode):
@@ -498,6 +616,17 @@ def infer_dtype(
]
return schemata.ArraySchema((*index_items, *value_items, *passthrough_items))

@property
def variables_introduced(self) -> int:
return (
len(self.schema.items) - len(self.passthrough_columns) + OVERHEAD_VARIABLES
)

@property
def relation_ops_created(self) -> int:
# Unpivot is essentially a cross join and a projection.
return 2


@dataclass(frozen=True)
class RandomSampleNode(UnaryNode):
@@ -513,3 +642,7 @@ def row_preserving(self) -> bool:

def __hash__(self):
return self._node_hash

@property
def variables_introduced(self) -> int:
return 1