fix: Workaround functools cache recursion limits for python 3.12+ #1257

Draft · wants to merge 3 commits into base: main
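
This PR works around a Python 3.12+ limitation: the C-implemented functools.lru_cache/functools.cache consumes C stack frames when a cached function recurses, and sys.setrecursionlimit no longer raises that C limit (see cpython#112215 and cpython#112282). Vendoring a pure-Python version of the functools cache helpers keeps recursion through the cache in Python frames, which do honor the recursion limit. Below is a minimal sketch of the idea, assuming hashable positional arguments only; the names are illustrative, not the vendored module's actual API.

import functools

def cache(func):
    """Pure-Python memoization: recursive calls stay in Python frames."""
    memo = {}

    @functools.wraps(func)
    def wrapper(*args):
        if args not in memo:
            memo[args] = func(*args)
        return memo[args]

    return wrapper

@cache
def depth(n: int) -> int:
    # Under the C-implemented lru_cache on 3.12+, deep recursion like this
    # can exhaust the C stack; the pure-Python wrapper avoids that.
    return 0 if n == 0 else depth(n - 1) + 1

assert depth(400) == 400
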
5 changes: 3 additions & 2 deletions bigframes/core/compile/compiler.py
@@ -18,6 +18,8 @@
import io
import typing

# Modified version of the functools cache methods to work around https://github.com/python/cpython/issues/112215
import bigframes_vendored.cpython.functools as vendored_functools
import bigframes_vendored.ibis.backends.bigquery as ibis_bigquery
import bigframes_vendored.ibis.expr.api as ibis_api
import bigframes_vendored.ibis.expr.types as ibis_types
@@ -112,8 +114,7 @@ def compile_ordered_ir(self, node: nodes.BigFrameNode) -> compiled.OrderedIR:
def compile_unordered_ir(self, node: nodes.BigFrameNode) -> compiled.UnorderedIR:
return typing.cast(compiled.UnorderedIR, self.compile_node(node, False))

# TODO: Remove cache once deriving the schema no longer requires compilation (and therefore compilation happens only for execution)
@functools.lru_cache(maxsize=5000)
@vendored_functools.lru_cache(maxsize=5000)
def compile_node(
self, node: nodes.BigFrameNode, ordered: bool = True
) -> compiled.UnorderedIR | compiled.OrderedIR:
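
A note on the decorator swap above: lru_cache applied to a method keys its cache on (self, node, ordered), so nodes must be hashable (BigFrameNode caches its hash for exactly this reason, as the next file shows), and maxsize=5000 bounds memory by evicting least-recently-used entries. A stdlib-based illustration of the keying, assuming the vendored decorator behaves the same way:

import functools

class Compiler:
    @functools.lru_cache(maxsize=5000)
    def compile_node(self, node: tuple, ordered: bool = True) -> str:
        # The instance and both arguments together form the cache key.
        return f"compiled({node!r}, ordered={ordered})"

c = Compiler()
assert c.compile_node((1, 2)) is c.compile_node((1, 2))  # second call is a cache hit

One trade-off of caching on a method is that the cache keeps a strong reference to self for as long as its entries survive.
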
68 changes: 44 additions & 24 deletions bigframes/core/nodes.py
@@ -17,11 +17,12 @@
import abc
import dataclasses
import datetime
import functools
import itertools
import typing
from typing import Callable, cast, Iterable, Mapping, Optional, Sequence, Tuple

# Modified version of the functools cache methods to work around https://github.com/python/cpython/issues/112215
import bigframes_vendored.cpython.functools as vendored_functools
import google.cloud.bigquery as bq

import bigframes.core.expression as ex
@@ -101,7 +102,8 @@ def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
"""The variables defined in this node (as opposed to by child nodes)."""
...

@functools.cached_property
@property
@vendored_functools.cache
def session(self):
sessions = []
for child in self.child_nodes:
@@ -118,7 +120,7 @@ def _validate(self):
"""Validate the local data in the node."""
return

@functools.cache
@vendored_functools.cache
def validate_tree(self) -> bool:
for child in self.child_nodes:
child.validate_tree()
@@ -147,9 +149,8 @@ def __eq__(self, other) -> bool:
return self._as_tuple() == other._as_tuple()

# BigFrameNode trees can be very deep, so it's important to avoid recalculating the hash from scratch
# Each subclass of BigFrameNode should use this property to implement __hash__
# The default dataclass-generated __hash__ method is not cached
@functools.cached_property
@vendored_functools.cached_property
def _cached_hash(self):
return hash(self._as_tuple())
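
The _cached_hash pattern above makes hashing a deep tree O(depth) the first time and O(1) afterwards, because each parent's hash reuses its children's cached hashes. A self-contained sketch of the pattern (not the bigframes implementation itself):

import dataclasses
from typing import Optional, Tuple

@dataclasses.dataclass(frozen=True, eq=False)
class Node:
    payload: int
    child: Optional["Node"] = None

    def _as_tuple(self) -> Tuple:
        return (self.payload, self.child)

    def __eq__(self, other) -> bool:
        return isinstance(other, Node) and self._as_tuple() == other._as_tuple()

    def __hash__(self) -> int:
        # Compute once and stash it on the instance (object.__setattr__
        # because the dataclass is frozen), then reuse it.
        cached = self.__dict__.get("_cached_hash")
        if cached is None:
            cached = hash(self._as_tuple())
            object.__setattr__(self, "_cached_hash", cached)
        return cached

deep = Node(0)
for i in range(1, 200):
    deep = Node(i, deep)
hash(deep)  # first call walks the chain; repeated calls are O(1)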

@@ -209,23 +210,27 @@ def explicitly_ordered(self) -> bool:
"""
...

@functools.cached_property
@property
@vendored_functools.cache
def total_variables(self) -> int:
return self.variables_introduced + sum(
map(lambda x: x.total_variables, self.child_nodes)
)

@functools.cached_property
@property
@vendored_functools.cache
def total_relational_ops(self) -> int:
return self.relation_ops_created + sum(
map(lambda x: x.total_relational_ops, self.child_nodes)
)

@functools.cached_property
@property
@vendored_functools.cache
def total_joins(self) -> int:
return int(self.joins) + sum(map(lambda x: x.total_joins, self.child_nodes))

@functools.cached_property
@property
@vendored_functools.cache
def schema(self) -> schemata.ArraySchema:
# TODO: Make schema just a view on fields
return schemata.ArraySchema(
@@ -264,7 +269,8 @@ def defines_namespace(self) -> bool:
"""
return False

@functools.cached_property
@property
@vendored_functools.cache
def defined_variables(self) -> set[str]:
"""Full set of variables defined in the namespace, even if not selected."""
self_defined_variables = set(self.schema.names)
@@ -277,7 +283,8 @@ def get_type(self, id: bfet_ids.ColumnId) -> bigframes.dtypes.Dtype:
def get_type(self, id: bfet_ids.ColumnId) -> bigframes.dtypes.Dtype:
return self._dtype_lookup[id]

@functools.cached_property
@property
@vendored_functools.cache
def _dtype_lookup(self):
return {field.id: field.dtype for field in self.fields}
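
Throughout this file, functools.cached_property is replaced by @property stacked over the vendored cache. cached_property stores the computed value as an instance attribute, whereas this stacking memoizes the underlying function keyed on self — presumably so every cache on the recursive hot path goes through the recursion-safe vendored implementation. A sketch of the equivalent stacking using the stdlib cache:

import functools

class Example:
    def __init__(self, values):
        self.values = tuple(values)

    @property
    @functools.cache
    def total(self) -> int:
        # Memoized keyed on self: computed once per instance. Unlike
        # cached_property, the cache holds a strong reference to self.
        return sum(self.values)

e = Example([1, 2, 3])
assert e.total == 6  # second access hits the cache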

@@ -419,7 +426,8 @@ def explicitly_ordered(self) -> bool:
def fields(self) -> Iterable[Field]:
return itertools.chain(self.left_child.fields, self.right_child.fields)

@functools.cached_property
@property
@vendored_functools.cache
def variables_introduced(self) -> int:
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
return OVERHEAD_VARIABLES
@@ -511,7 +519,8 @@ def fields(self) -> Iterable[Field]:
for id, field in zip(self.output_ids, self.children[0].fields)
)

@functools.cached_property
@property
@vendored_functools.cache
def variables_introduced(self) -> int:
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
return len(self.schema.items) + OVERHEAD_VARIABLES
@@ -579,11 +588,13 @@ def order_ambiguous(self) -> bool:
def explicitly_ordered(self) -> bool:
return True

@functools.cached_property
@property
@vendored_functools.cache
def fields(self) -> Iterable[Field]:
return (Field(self.output_id, next(iter(self.start.fields)).dtype),)

@functools.cached_property
@property
@vendored_functools.cache
def variables_introduced(self) -> int:
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
return len(self.schema.items) + OVERHEAD_VARIABLES
@@ -840,7 +851,8 @@ def order_ambiguous(self) -> bool:
def explicitly_ordered(self) -> bool:
return self.source.ordering is not None

@functools.cached_property
@property
@vendored_functools.cache
def variables_introduced(self) -> int:
return len(self.scan_list.items) + 1

@@ -902,7 +914,8 @@ def fields(self) -> Iterable[Field]:
def relation_ops_created(self) -> int:
return 2

@functools.cached_property
@property
@vendored_functools.cache
def variables_introduced(self) -> int:
return 1

@@ -1068,7 +1081,8 @@ def _validate(self):
if ref.id not in set(self.child.ids):
raise ValueError(f"Reference to column not in child: {ref.id}")

@functools.cached_property
@property
@vendored_functools.cache
def fields(self) -> Iterable[Field]:
return tuple(
Field(output, self.child.get_type(ref.id))
@@ -1139,7 +1153,8 @@ def _validate(self):
# Cannot assign to existing variables - append only!
assert all(name not in self.child.schema.names for _, name in self.assignments)

@functools.cached_property
@property
@vendored_functools.cache
def added_fields(self) -> Tuple[Field, ...]:
input_types = self.child._dtype_lookup
return tuple(
@@ -1252,7 +1267,8 @@ def row_preserving(self) -> bool:
def non_local(self) -> bool:
return True

@functools.cached_property
@property
@vendored_functools.cache
def fields(self) -> Iterable[Field]:
by_items = (
Field(ref.id, self.child.get_type(ref.id)) for ref in self.by_column_ids
@@ -1358,7 +1374,8 @@ def relation_ops_created(self) -> int:
def row_count(self) -> Optional[int]:
return self.child.row_count

@functools.cached_property
@property
@vendored_functools.cache
def added_field(self) -> Field:
input_type = self.child.get_type(self.column_name.id)
new_item_dtype = self.op.output_type(input_type)
@@ -1459,7 +1476,8 @@ def fields(self) -> Iterable[Field]:
def relation_ops_created(self) -> int:
return 3

@functools.cached_property
@property
@vendored_functools.cache
def variables_introduced(self) -> int:
return len(self.column_ids) + 1

@@ -1489,6 +1507,8 @@ def remap_refs(


# Tree operators


def top_down(
root: BigFrameNode,
transform: Callable[[BigFrameNode], BigFrameNode],
@@ -1507,7 +1527,7 @@ def top_down_internal(root: BigFrameNode) -> BigFrameNode:

if memoize:
# MUST reassign to the same name or caching won't work recursively
top_down_internal = functools.cache(top_down_internal)
top_down_internal = vendored_functools.cache(top_down_internal)

result = top_down_internal(root)
if validate:
@@ -1533,7 +1553,7 @@ def bottom_up_internal(root: BigFrameNode) -> BigFrameNode:

if memoize:
# MUST reassign to the same name or caching won't work recursively
bottom_up_internal = functools.cache(bottom_up_internal)
bottom_up_internal = vendored_functools.cache(bottom_up_internal)

result = bottom_up_internal(root)
if validate:
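
The "MUST reassign to the same name" comments above are load-bearing: the recursive call inside each helper resolves its own name at call time, so rebinding that name to the cached wrapper is what routes the inner recursive calls through the cache. A minimal demonstration, with the stdlib cache standing in for the vendored one:

import functools

def fib(n: int) -> int:
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)  # "fib" is looked up at call time

fib = functools.cache(fib)  # rebinding makes the recursive calls hit the cache too
assert fib(30) == 832040
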
7 changes: 0 additions & 7 deletions tests/system/small/test_dataframe.py
@@ -4885,13 +4885,6 @@ def test_df_dot_operator_series(
)


# TODO(tswast): We may be able to re-enable this test after we break large
# queries up in https://github.com/googleapis/python-bigquery-dataframes/pull/427
@pytest.mark.skipif(
sys.version_info >= (3, 12),
# See: https://github.com/python/cpython/issues/112282
reason="setrecursionlimit has no effect on the Python C stack since Python 3.12.",
)
def test_recursion_limit(scalars_df_index):
scalars_df_index = scalars_df_index[["int64_too", "int64_col", "float64_col"]]
for i in range(400):
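
Deleting the skip re-enables test_recursion_limit on Python 3.12+: now that the caches on the recursive compile path are pure Python, raising the interpreter's recursion limit is effective again (sys.setrecursionlimit governs Python frames, not the C stack; see cpython#112282). A minimal illustration of the distinction, using a 400-deep chain like the one the test builds:

import sys

sys.setrecursionlimit(3000)  # effective for pure-Python recursion, even on 3.12+

def chain_length(node) -> int:
    # Walks a linked structure like the 400-step chain built in the test.
    return 0 if node is None else 1 + chain_length(node[1])

node = None
for i in range(400):
    node = (i, node)

assert chain_length(node) == 400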