diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py index 41ee359d71b..dea584456f7 100644 --- a/hail/python/hail/expr/functions.py +++ b/hail/python/hail/expr/functions.py @@ -1,5 +1,6 @@ import builtins import functools +import itertools import operator from typing import Any, Callable, Iterable, Optional, TypeVar, Union @@ -5172,64 +5173,55 @@ def list_shape(x): else: return [] - def deep_flatten(es): - result = [] - for e in es: - if isinstance(e, (list, builtins.tuple)): - result.extend(deep_flatten(e)) - else: - result.append(e) - - return result - - def check_arrays_uniform(nested_arr, shape_list, ndim): - current_level_correct = hl.len(nested_arr) == shape_list[-ndim] - if ndim == 1: - return current_level_correct - else: - return current_level_correct & ( - hl.all(lambda inner: check_arrays_uniform(inner, shape_list, ndim - 1), nested_arr) - ) - - if isinstance(collection, Expression): - if isinstance(collection, ArrayNumericExpression): - data_expr = collection - shape_expr = to_expr(tuple([hl.int64(hl.len(collection))]), ttuple(tint64)) - ndim = 1 - elif isinstance(collection, NumericExpression): - data_expr = array([collection]) - shape_expr = hl.tuple([]) - ndim = 0 - elif isinstance(collection, ArrayExpression): - recursive_type = collection.dtype - ndim = 0 - while isinstance(recursive_type, (tarray, tndarray)): - recursive_type = recursive_type._element_type - ndim += 1 - - data_expr = collection - for i in builtins.range(ndim - 1): - data_expr = hl.flatten(data_expr) - - nested_collection = collection - shape_list = [] - for i in builtins.range(ndim): - shape_list.append(hl.int64(hl.len(nested_collection))) - nested_collection = nested_collection[0] - - shape_expr = ( - hl.case() - .when(check_arrays_uniform(collection, shape_list, ndim), hl.tuple(shape_list)) - .or_error("inner dimensions do not match") - ) + def deep_flatten(xs: Iterable) -> Iterable: + return [y for x in xs for y in (deep_flatten(x) if isinstance(x, Iterable) else [x])] + + if isinstance(collection, NumericExpression): + data_expr = array([collection]) + shape_expr = hl.tuple([]) + ndim = 0 + elif isinstance(collection, ArrayExpression): + recursive_type = collection.dtype + ndim = 0 + + while isinstance(recursive_type, (tarray, tndarray)): + recursive_type = recursive_type._element_type + ndim += 1 + + def flatten_assert_shape(shape): + def go(dim): + return lambda xs: ( + hl.case() + .when( + hl.len(xs) == shape[dim], + xs if dim == ndim - 1 else xs.flatmap(go(dim + 1)), + ) + .or_error(f"dimension {dim} did not match") + ) - else: - raise ValueError(f"{collection} cannot be converted into an ndarray") + return go(0) + + shape_expr, data_expr = hl.bind( + lambda arr: hl.bind( + lambda shape: (shape, flatten_assert_shape(shape)(data)), + hl.tuple( + hl.int64(hl.len(dim)) + for dim in itertools.accumulate( + builtins.range(ndim - 1), + lambda xs, _: xs[0], + initial=arr, + ) + ), + ), + collection, + ) + elif isinstance(collection, Expression): + raise ValueError(f"{collection} cannot be converted into an ndarray") + elif isinstance(collection, np.ndarray): + return hl.literal(collection) else: - if isinstance(collection, np.ndarray): - return hl.literal(collection) - elif isinstance(collection, (list, builtins.tuple)): + if isinstance(collection, Iterable): shape = list_shape(collection) data = deep_flatten(collection) else: