Skip to content

Commit

Permalink
[query] Let-Bind Input Collection Parameter In hl.nd.array.
Browse files Browse the repository at this point in the history
Fixes: hail-is#14559
`hl.nd.array`s constructed from stream pipelines can cause out of memory
exceptions owing to a limitation in the Python CSE algorithm that
does not eliminate partially redundant expressions in if-expressions.
Explicitly `let`-binding the input collection prevents it from being
evaluated twice: once for the flattened data stream and once for the
original shape.
  • Loading branch information
ehigham committed Jun 5, 2024
1 parent d9d85d5 commit 0056b39
Showing 1 changed file with 47 additions and 55 deletions.
102 changes: 47 additions & 55 deletions hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import builtins
import functools
import itertools
import operator
from typing import Any, Callable, Iterable, Optional, TypeVar, Union

Expand Down Expand Up @@ -5172,64 +5173,55 @@ def list_shape(x):
else:
return []

def deep_flatten(es):
    """Return a flat list of the leaves of arbitrarily nested lists/tuples.

    Only ``list`` and ``tuple`` values are descended into; every other
    value (including strings) is kept as a single leaf. Leaves appear in
    depth-first, left-to-right order.
    """
    flat = []
    for item in es:
        if not isinstance(item, (list, builtins.tuple)):
            flat.append(item)
            continue
        # Nested container: splice its flattened leaves in place.
        flat += deep_flatten(item)
    return flat

def check_arrays_uniform(nested_arr, shape_list, ndim):
    """Build an expression testing that ``nested_arr`` matches ``shape_list``.

    The length of ``nested_arr`` is compared with ``shape_list[-ndim]``.
    When ``ndim`` is greater than 1, every inner element is checked
    recursively against the remaining trailing dimensions and the results
    are combined with ``&``.
    """
    length_matches = hl.len(nested_arr) == shape_list[-ndim]
    if ndim != 1:
        # Recurse into each inner array with one fewer dimension to check.
        inner_ok = hl.all(lambda inner: check_arrays_uniform(inner, shape_list, ndim - 1), nested_arr)
        return length_matches & inner_ok
    return length_matches

if isinstance(collection, Expression):
if isinstance(collection, ArrayNumericExpression):
data_expr = collection
shape_expr = to_expr(tuple([hl.int64(hl.len(collection))]), ttuple(tint64))
ndim = 1
elif isinstance(collection, NumericExpression):
data_expr = array([collection])
shape_expr = hl.tuple([])
ndim = 0
elif isinstance(collection, ArrayExpression):
recursive_type = collection.dtype
ndim = 0
while isinstance(recursive_type, (tarray, tndarray)):
recursive_type = recursive_type._element_type
ndim += 1

data_expr = collection
for i in builtins.range(ndim - 1):
data_expr = hl.flatten(data_expr)

nested_collection = collection
shape_list = []
for i in builtins.range(ndim):
shape_list.append(hl.int64(hl.len(nested_collection)))
nested_collection = nested_collection[0]

shape_expr = (
hl.case()
.when(check_arrays_uniform(collection, shape_list, ndim), hl.tuple(shape_list))
.or_error("inner dimensions do not match")
)
def deep_flatten(xs: Iterable) -> list:
    """Flatten arbitrarily nested iterables into one flat list, depth first.

    Parameters
    ----------
    xs : Iterable
        A (possibly nested) iterable of values.

    Returns
    -------
    list
        The leaves of ``xs`` in left-to-right, depth-first order.

    Notes
    -----
    ``str`` and ``bytes`` values are kept as leaves. A one-character string
    iterates to itself, so descending into strings would recurse until
    ``RecursionError``.
    """
    flat = []
    for x in xs:
        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
            flat.extend(deep_flatten(x))
        else:
            flat.append(x)
    return flat

if isinstance(collection, NumericExpression):
data_expr = array([collection])
shape_expr = hl.tuple([])
ndim = 0
elif isinstance(collection, ArrayExpression):
recursive_type = collection.dtype
ndim = 0

while isinstance(recursive_type, (tarray, tndarray)):
recursive_type = recursive_type._element_type
ndim += 1

def flatten_assert_shape(shape):
def go(dim):
return lambda xs: (
hl.case()
.when(
hl.len(xs) == shape[dim],
xs if dim == ndim - 1 else xs.flatmap(go(dim + 1)),
)
.or_error(f"dimension {dim} did not match")
)

else:
raise ValueError(f"{collection} cannot be converted into an ndarray")
return go(0)

# Let-bind the input collection once (`arr`) so that both the shape tuple
# and the flattened data are derived from the same bound value instead of
# evaluating `collection` twice.
shape_expr, data_expr = hl.bind(
    lambda arr: hl.bind(
        # Flatten the let-bound `arr`; no `data` variable is assigned on
        # this branch, so referencing it here would fail at call time.
        lambda shape: (shape, flatten_assert_shape(shape)(arr)),
        # Shape: length of `arr`, then of `arr[0]`, `arr[0][0]`, ... for
        # each of the `ndim` nesting levels.
        hl.tuple(
            hl.int64(hl.len(dim))
            for dim in itertools.accumulate(
                builtins.range(ndim - 1),
                lambda xs, _: xs[0],
                initial=arr,
            )
        ),
    ),
    collection,
)

elif isinstance(collection, Expression):
raise ValueError(f"{collection} cannot be converted into an ndarray")
elif isinstance(collection, np.ndarray):
return hl.literal(collection)
else:
if isinstance(collection, np.ndarray):
return hl.literal(collection)
elif isinstance(collection, (list, builtins.tuple)):
if isinstance(collection, Iterable):
shape = list_shape(collection)
data = deep_flatten(collection)
else:
Expand Down

0 comments on commit 0056b39

Please sign in to comment.