Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Unify implicit alignment #1181

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
19 changes: 1 addition & 18 deletions bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import bigframes.core.expression as ex
import bigframes.core.guid
import bigframes.core.identifiers as ids
import bigframes.core.join_def as join_def
import bigframes.core.local_data as local_data
import bigframes.core.nodes as nodes
from bigframes.core.ordering import OrderingExpression
Expand Down Expand Up @@ -446,7 +445,7 @@ def try_row_join(
other_node, r_mapping = self.prepare_join_names(other)
import bigframes.core.rewrite

result_node = bigframes.core.rewrite.try_join_as_projection(
result_node = bigframes.core.rewrite.try_row_join(
self.node, other_node, conditions
)
if result_node is None:
Expand Down Expand Up @@ -480,22 +479,6 @@ def prepare_join_names(
else:
return other.node, {id: id for id in other.column_ids}

def try_legacy_row_join(
    self,
    other: ArrayValue,
    join_type: join_def.JoinType,
    join_keys: typing.Tuple[join_def.CoalescedColumnMapping, ...],
    mappings: typing.Tuple[join_def.JoinColumnMapping, ...],
) -> typing.Optional[ArrayValue]:
    """Attempt a row join with ``other`` via the legacy projection rewrite.

    Delegates to ``bigframes.core.rewrite.legacy_join_as_projection`` using
    the underlying nodes of both array values.

    Returns:
        A new ``ArrayValue`` wrapping the rewritten node, or ``None`` when
        the rewrite returns ``None``.
    """
    # Imported at call time, matching the other rewrite call sites in this class.
    import bigframes.core.rewrite

    joined_node = bigframes.core.rewrite.legacy_join_as_projection(
        self.node, other.node, join_keys, mappings, join_type
    )
    return None if joined_node is None else ArrayValue(joined_node)

def explode(self, column_ids: typing.Sequence[str]) -> ArrayValue:
assert len(column_ids) > 0
for column_id in column_ids:
Expand Down
78 changes: 11 additions & 67 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2341,9 +2341,7 @@ def join(
# Handle null index, which only supports row join
# This is the canonical way of aligning on null index, so always allow (ignore block_identity_join)
if self.index.nlevels == other.index.nlevels == 0:
result = try_legacy_row_join(self, other, how=how) or try_new_row_join(
self, other
)
result = try_row_join(self, other)
if result is not None:
return result
raise bigframes.exceptions.NullIndexError(
Expand All @@ -2356,9 +2354,7 @@ def join(
and (self.index.nlevels == other.index.nlevels)
and (self.index.dtypes == other.index.dtypes)
):
result = try_legacy_row_join(self, other, how=how) or try_new_row_join(
self, other
)
result = try_row_join(self, other)
if result is not None:
return result

Expand Down Expand Up @@ -2697,9 +2693,11 @@ def is_uniquely_named(self: BlockIndexProperties):
return len(set(self.names)) == len(self.names)


def try_new_row_join(
left: Block, right: Block
def try_row_join(
left: Block,
right: Block,
) -> Optional[Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]]:

join_keys = tuple(
(left_id, right_id)
for left_id, right_id in zip(left.index_columns, right.index_columns)
Expand All @@ -2708,11 +2706,13 @@ def try_new_row_join(
if join_result is None: # did not succeed
return None
combined_expr, (get_column_left, get_column_right) = join_result
# Keep the left index column, and drop the matching right column
index_cols_post_join = [get_column_left[id] for id in left.index_columns]

# Can use either side's index columns, as they match exactly
index_cols_post_join = [get_column_right[id] for id in right.index_columns]
combined_expr = combined_expr.drop_columns(
[get_column_right[id] for id in right.index_columns]
[get_column_left[id] for id in left.index_columns]
)

block = Block(
combined_expr,
index_columns=index_cols_post_join,
Expand All @@ -2725,62 +2725,6 @@ def try_new_row_join(
)


def try_legacy_row_join(
    left: Block,
    right: Block,
    *,
    how="left",
) -> Optional[Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]]:
    """Try to join two blocks row-wise by merging their projections.

    Index columns of ``left`` and ``right`` are paired positionally and each
    pair is coalesced into one freshly generated id; every value column from
    either side is mapped to its own freshly generated id.

    Args:
        left: Block supplying the left-hand columns and the index labels.
        right: Block supplying the right-hand columns.
        how: Join type forwarded as ``join_type`` to the expression-level join.

    Returns:
        ``None`` if the expression-level join returns ``None``. Otherwise a
        pair ``(block, (left_lookup, right_lookup))`` where each lookup maps
        an original value column id to its id in ``block``.
    """

    def _value_mappings(side, column_ids):
        # One mapping per value column, each with a newly generated output id.
        return [
            join_defs.JoinColumnMapping(
                source_table=side,
                source_id=column_id,
                destination_id=guid.generate_guid(),
            )
            for column_id in column_ids
        ]

    lhs_expr = left.expr
    rhs_expr = right.expr

    # Ids are generated for index pairs first, then left values, then right values.
    index_mappings = tuple(
        join_defs.CoalescedColumnMapping(
            left_source_id=lhs_id,
            right_source_id=rhs_id,
            destination_id=guid.generate_guid(),
        )
        for lhs_id, rhs_id in zip(left.index_columns, right.index_columns)
    )
    from_left = _value_mappings(join_defs.JoinSide.LEFT, left.value_columns)
    from_right = _value_mappings(join_defs.JoinSide.RIGHT, right.value_columns)

    joined_expr = lhs_expr.try_legacy_row_join(
        rhs_expr,
        join_type=how,
        join_keys=index_mappings,
        mappings=(*from_left, *from_right),
    )
    if joined_expr is None:
        return None

    left_lookup = dict((m.source_id, m.destination_id) for m in from_left)
    right_lookup = dict((m.source_id, m.destination_id) for m in from_right)
    joined_block = Block(
        joined_expr,
        column_labels=[*left.column_labels, *right.column_labels],
        index_columns=(mapping.destination_id for mapping in index_mappings),
        index_labels=left.index.names,
    )
    return joined_block, (left_lookup, right_lookup)


def join_with_single_row(
left: Block,
single_row_block: Block,
Expand Down
9 changes: 9 additions & 0 deletions bigframes/core/compile/aggregate_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,15 @@ def _(
return _apply_window_if_present(ibis_api.rank(), window) + 1


@compile_nullary_agg.register
def _(
    op: agg_ops.RowNumberOp,
    column: ibis_types.Column,
    window=None,
) -> ibis_types.IntegerValue:
    """Compile ``RowNumberOp`` into an ibis row-number expression.

    The expression is passed through ``_apply_window_if_present`` together
    with ``window``.
    """
    # NOTE(review): `column` is never read in this body — confirm that callers
    # of `compile_nullary_agg` actually pass a column argument for this op.
    row_number_expr = ibis_api.row_number()
    return _apply_window_if_present(row_number_expr, window)


@compile_unary_agg.register
def _(
op: agg_ops.DenseRankOp,
Expand Down
33 changes: 1 addition & 32 deletions bigframes/core/join_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,9 @@
# limitations under the License.
from __future__ import annotations

import dataclasses
import enum
from typing import Literal, NamedTuple


class JoinSide(enum.Enum):
    """Identifies one of the two inputs of a join."""

    LEFT = 0
    RIGHT = 1

    def inverse(self) -> JoinSide:
        """Return the opposite side: ``RIGHT`` for ``LEFT`` and vice versa."""
        return JoinSide.RIGHT if self is JoinSide.LEFT else JoinSide.LEFT


# String literals naming the supported join kinds; used as the annotation for
# `join_type` arguments.
JoinType = Literal["inner", "outer", "left", "right", "cross"]
from typing import NamedTuple


class JoinCondition(NamedTuple):
    """A pair of column ids, one taken from each side of a join.

    NOTE(review): presumably the two columns are compared for equality by the
    join — confirm against the join implementation.
    """

    # Column id on the left input.
    left_id: str
    # Column id on the right input.
    right_id: str


@dataclasses.dataclass(frozen=True)
class JoinColumnMapping:
    """Maps one column from a single side of a join to its id in the join output."""

    # Which join input the column comes from.
    source_table: JoinSide
    # Id of the column within that input.
    source_id: str
    # Id assigned to the column in the joined result.
    destination_id: str


@dataclasses.dataclass(frozen=True)
class CoalescedColumnMapping:
    """Special column mapping used only by the implicit joiner.

    Pairs one column from each side of the join with a single destination id.
    """

    # Id of the contributing column on the left input.
    left_source_id: str
    # Id of the contributing column on the right input.
    right_source_id: str
    # Id assigned to the combined column in the joined result.
    destination_id: str
destination_id: str
63 changes: 46 additions & 17 deletions bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@

COLUMN_SET = frozenset[bfet_ids.ColumnId]

# NOTE(review): module-level flag whose consumer is not visible in this diff;
# presumably toggles masking behaviour in the implicit joiner — confirm where
# it is read and document the effect of setting it to False.
IMPLICIT_JOINER_MASKING = True


@dataclasses.dataclass(frozen=True)
class Field:
Expand Down Expand Up @@ -83,8 +85,12 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
"""Direct children of this node"""
return tuple([])

@functools.cached_property
def projection_base(self) -> BigFrameNode:
    """Return the node on which this node's chain of alignable nodes rests.

    A node that is an instance of one of ``implicit_align.ALIGNABLE_NODES``
    delegates to its child's ``projection_base``; any other node is its own
    base. The result is computed once per node instance and then cached.
    """
    # Only `cached_property` may decorate this method. Stacking `@property`
    # on top would hand the `cached_property` object to `property` as its
    # getter, and that object is not callable.
    #
    # The import is done at call time. NOTE(review): presumably this avoids
    # a circular import between nodes and the rewrite package — confirm.
    import bigframes.core.rewrite.implicit_align

    if isinstance(self, bigframes.core.rewrite.implicit_align.ALIGNABLE_NODES):
        return self.child.projection_base
    return self

@property
Expand Down Expand Up @@ -918,10 +924,6 @@ def row_count(self) -> Optional[int]:
def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
return (self.col_id,)

@property
def projection_base(self) -> BigFrameNode:
return self.child.projection_base

@property
def added_fields(self) -> Tuple[Field, ...]:
return (Field(self.col_id, bigframes.dtypes.INT_DTYPE),)
Expand Down Expand Up @@ -1095,10 +1097,6 @@ def variables_introduced(self) -> int:
def defines_namespace(self) -> bool:
return True

@property
def projection_base(self) -> BigFrameNode:
return self.child.projection_base

@property
def row_count(self) -> Optional[int]:
return self.child.row_count
Expand Down Expand Up @@ -1173,10 +1171,6 @@ def variables_introduced(self) -> int:
def row_count(self) -> Optional[int]:
return self.child.row_count

@property
def projection_base(self) -> BigFrameNode:
return self.child.projection_base

@property
def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
return tuple(id for _, id in self.assignments)
Expand Down Expand Up @@ -1361,10 +1355,6 @@ def fields(self) -> Iterable[Field]:
def variables_introduced(self) -> int:
return 1

@property
def projection_base(self) -> BigFrameNode:
return self.child.projection_base

@property
def added_fields(self) -> Tuple[Field, ...]:
return (self.added_field,)
Expand Down Expand Up @@ -1506,3 +1496,42 @@ def remap_refs(
) -> BigFrameNode:
new_ids = tuple(id.remap_column_refs(mappings) for id in self.column_ids)
return dataclasses.replace(self, column_ids=new_ids) # type: ignore


def top_down(
    root: BigFrameNode,
    transform: Callable[[BigFrameNode], BigFrameNode],
    *,
    memoize=False,
    validate=False,
) -> BigFrameNode:
    """Rewrite a tree by applying ``transform`` to each node, parents first.

    ``transform`` is applied to a node, and the children of the node it
    returns are then rewritten recursively in the same way.

    Args:
        root: Root of the tree to rewrite.
        transform: Function mapping a node to its replacement.
        memoize: If true, cache results per node with ``functools.cache`` so
            a subtree reachable through several parents is rewritten once.
            Nodes must be hashable for this to work.
        validate: If true, call ``validate_tree()`` on the rewritten root.

    Returns:
        The rewritten root node.
    """

    def top_down_internal(root: BigFrameNode) -> BigFrameNode:
        # Recurse with the traversal itself, not the bare `transform`:
        # otherwise only the root and its direct children get rewritten and
        # the memoized wrapper is never used below the root.
        return transform(root).transform_children(top_down_internal)

    if memoize:
        # MUST reassign to the same name or caching won't work recursively
        top_down_internal = functools.cache(top_down_internal)

    result = top_down_internal(root)
    if validate:
        result.validate_tree()
    return result


def bottom_up(
    root: BigFrameNode,
    transform: Callable[[BigFrameNode], BigFrameNode],
    *,
    memoize=False,
    validate=False,
):
    """Rewrite a tree by applying ``transform`` to each node, children first.

    A node's children are rewritten recursively, and ``transform`` is then
    applied to the node carrying those rewritten children.

    Args:
        root: Root of the tree to rewrite.
        transform: Function mapping a node to its replacement.
        memoize: If true, cache results per node with ``functools.cache``.
        validate: If true, call ``validate_tree()`` on the rewritten root.

    Returns:
        The rewritten root node.
    """

    def _rewrite(node: BigFrameNode) -> BigFrameNode:
        with_new_children = node.transform_children(_rewrite)
        return transform(with_new_children)

    if memoize:
        # Rebind the same name so the recursive calls inside `_rewrite`
        # resolve to the cached wrapper as well.
        _rewrite = functools.cache(_rewrite)

    rewritten = _rewrite(root)
    if validate:
        rewritten.validate_tree()
    return rewritten
7 changes: 3 additions & 4 deletions bigframes/core/rewrite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
# limitations under the License.

from bigframes.core.rewrite.identifiers import remap_variables
from bigframes.core.rewrite.implicit_align import try_join_as_projection
from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
from bigframes.core.rewrite.implicit_align import convert_relational_join, try_row_join
from bigframes.core.rewrite.slices import pullup_limit_from_slice, replace_slice_ops

__all__ = [
"legacy_join_as_projection",
"try_join_as_projection",
"try_row_join",
"replace_slice_ops",
"pullup_limit_from_slice",
"remap_variables",
"convert_relational_join",
]
Loading
Loading