From d3793f45c8e3cc85ba06e5fc5363e5022f8ecad5 Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Wed, 1 Nov 2023 12:22:47 +0000 Subject: [PATCH] fix: `ak.num` should always return a useful (non-unknown length) type (#2785) * fix: `ak.num` should always return a useful (non-unknown length) type * test: add test * feat: always return ArrayLike for `ak.num` * test: update test to reflect ak.num behavior --- src/awkward/operations/ak_num.py | 5 +++-- tests/test_2785_ak_num_typetracer_axis_0.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) create mode 100644 tests/test_2785_ak_num_typetracer_axis_0.py diff --git a/src/awkward/operations/ak_num.py b/src/awkward/operations/ak_num.py index 7352ecccd2..3aa5f8e2f6 100644 --- a/src/awkward/operations/ak_num.py +++ b/src/awkward/operations/ak_num.py @@ -87,10 +87,11 @@ def _impl(array, axis, highlevel, behavior): raise TypeError(f"'axis' must be an integer, not {axis!r}") if maybe_posaxis(layout, axis, 1) == 0: + index_nplike = layout.backend.index_nplike if isinstance(layout, ak.record.Record): - return 1 + return index_nplike.asarray(index_nplike.shape_item_as_index(1)) else: - return layout.length + return index_nplike.asarray(index_nplike.shape_item_as_index(layout.length)) def action(layout, depth, **kwargs): posaxis = maybe_posaxis(layout, axis, depth) diff --git a/tests/test_2785_ak_num_typetracer_axis_0.py b/tests/test_2785_ak_num_typetracer_axis_0.py new file mode 100644 index 0000000000..bbccb2df8e --- /dev/null +++ b/tests/test_2785_ak_num_typetracer_axis_0.py @@ -0,0 +1,16 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE + + +import awkward as ak +from awkward._nplikes.typetracer import TypeTracerArray + + +def test_unknown_length(): + array = ak.typetracer.typetracer_from_form(ak.forms.NumpyForm("int64")) + assert isinstance(ak.num(array, axis=0), TypeTracerArray) + + +def test_known_length(): + array = ak.Array([0, 1, 2, 3], backend="typetracer") + # This is now the new behavior - always return typetracers + assert isinstance(ak.num(array, axis=0), TypeTracerArray)