Skip to content

Commit a9b5427

Browse files
committed
fix #883: fixed Array.values(), zip_array_values and zip_array_items when axes=()
1 parent 3be6872 commit a9b5427

File tree

3 files changed

+68
-12
lines changed

3 files changed

+68
-12
lines changed

doc/source/changes/version_0_33.rst.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,4 @@ Miscellaneous improvements
6161
Fixes
6262
^^^^^
6363

64-
* fixed something (closes :issue:`1`).
64+
* fixed Array.values(), zip_array_values and zip_array_items when axes=() (closes :issue:`883`).

larray/core/array.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3548,6 +3548,9 @@ def values(self, axes=None, ascending=True):
35483548
combined = np.ravel(self.data)
35493549
# combined[::-1] *is* indexable
35503550
return combined if ascending else combined[::-1]
3551+
elif not axes:
3552+
# empty axes list
3553+
return [self]
35513554

35523555
if not isinstance(axes, (tuple, list, AxisCollection)):
35533556
axes = (axes,)
@@ -9722,9 +9725,11 @@ def zip_array_values(values, axes=None, ascending=True):
97229725
97239726
Parameters
97249727
----------
9728+
values : sequence of (scalar or Array)
9729+
Values to iterate on. Scalars are repeated as many times as necessary.
97259730
axes : int, str or Axis or tuple of them, optional
9726-
Axis or axes along which to iterate and in which order. Defaults to None (union of all axes present in
9727-
all arrays, in the order they are found).
9731+
Axis or axes along which to iterate and in which order. All those axes must be compatible (if present) between
9732+
the different values. Defaults to None (union of all axes present in all arrays, in the order they are found).
97289733
ascending : bool, optional
97299734
Whether or not to iterate the axes in ascending order (from start to end). Defaults to True.
97309735
@@ -9758,6 +9763,10 @@ def zip_array_values(values, axes=None, ascending=True):
97589763
2 3
97599764
c c1 c2
97609765
2 3
9766+
9767+
When the axis to iterate on (`c` in this case) is not present in one of the arrays (arr1), that array is repeated
9768+
for each label of that axis:
9769+
97619770
>>> for a1, a2 in zip_array_values((arr1, arr2), arr2.c):
97629771
... print("==")
97639772
... print(a1)
@@ -9774,8 +9783,11 @@ def zip_array_values(values, axes=None, ascending=True):
97749783
a1 2 3
97759784
a a0 a1
97769785
1 3
9786+
9787+
When no `axes` are given, it iterates on the union of all compatible axes (a, b, and c in this case):
9788+
97779789
>>> for a1, a2 in zip_array_values((arr1, arr2)):
9778-
... print("arr1: {}, arr2: {}".format(a1, a2))
9790+
... print(f"arr1: {a1}, arr2: {a2}")
97799791
arr1: 0, arr2: 0
97809792
arr1: 0, arr2: 1
97819793
arr1: 1, arr2: 0
@@ -9794,21 +9806,26 @@ def values_with_expand(value, axes, readonly=True, ascending=True):
97949806
size = axes.size if axes.ndim else 0
97959807
return Repeater(value, size)
97969808

9797-
all_axes = AxisCollection.union(*[get_axes(v) for v in values])
9809+
values_axes = [get_axes(v) for v in values]
9810+
97989811
if axes is None:
9799-
axes = all_axes
9812+
all_iter_axes = values_axes
98009813
else:
98019814
if not isinstance(axes, (tuple, list, AxisCollection)):
98029815
axes = (axes,)
9803-
# transform string axes definitions to objects
9816+
9817+
# transform string axes _definitions_ to objects
98049818
axes = [Axis(axis) if isinstance(axis, str) and '=' in axis else axis
98059819
for axis in axes]
9806-
# transform string axes references to objects
9807-
axes = AxisCollection([axis if isinstance(axis, Axis) else all_axes[axis]
9808-
for axis in axes])
9820+
9821+
# get iter axes for all values and transform string axes _references_ to objects
9822+
all_iter_axes = [AxisCollection([value_axes[axis] for axis in axes if axis in value_axes])
9823+
for value_axes in values_axes]
9824+
9825+
common_iter_axes = AxisCollection.union(*all_iter_axes)
98099826

98109827
# sequence of tuples (of scalar or arrays)
9811-
return SequenceZip([values_with_expand(v, axes, ascending=ascending) for v in values])
9828+
return SequenceZip([values_with_expand(v, common_iter_axes, ascending=ascending) for v in values])
98129829

98139830

98149831
def zip_array_items(values, axes=None, ascending=True):

larray/tests/test_array.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from larray import (Array, LArray, Axis, AxisCollection, LGroup, IGroup,
1818
union, zeros, zeros_like, ndtest, empty, ones, eye, diag, stack,
1919
clip, exp, where, X, mean, isnan, round, read_hdf, read_csv, read_eurostat, read_excel,
20-
from_lists, from_string, open_excel, from_frame, sequence, nan)
20+
from_lists, from_string, open_excel, from_frame, sequence, nan,
21+
zip_array_values, zip_array_items)
2122
from larray.inout.pandas import from_series
2223
from larray.core.axis import _to_ticks, _to_key
2324
from larray.util.misc import LHDFStore
@@ -329,6 +330,10 @@ def test_values():
329330
assert_larray_equal(values[0], arr['b1'])
330331
assert_larray_equal(values[-1], arr['b0'])
331332

333+
values = arr.values(())
334+
res = list(values)
335+
assert_larray_equal(res[0], arr)
336+
332337

333338
def test_items():
334339
arr = ndtest((2, 2))
@@ -5210,6 +5215,40 @@ def test_eq():
52105215
assert_array_equal(ao.eq(ao['c0'], nans_equal=True), a == a['c0'])
52115216

52125217

5218+
def test_zip_array_values():
5219+
arr1 = ndtest((2, 3))
5220+
# b axis intentionally not the same on both arrays
5221+
arr2 = ndtest((2, 2, 2))
5222+
5223+
# 1) no axes => return input arrays themselves
5224+
res = list(zip_array_values((arr1, arr2), ()))
5225+
assert len(res) == 1 and len(res[0]) == 2
5226+
r0_arr1, r0_arr2 = res[0]
5227+
assert_larray_equal(r0_arr1, arr1)
5228+
assert_larray_equal(r0_arr2, arr2)
5229+
5230+
# 2) iterate on an axis not present on one of the arrays => the other array is repeated
5231+
res = list(zip_array_values((arr1, arr2), arr2.c))
5232+
assert len(res) == 2 and all(len(r) == 2 for r in res)
5233+
r0_arr1, r0_arr2 = res[0]
5234+
r1_arr1, r1_arr2 = res[1]
5235+
assert_larray_equal(r0_arr1, arr1)
5236+
assert_larray_equal(r0_arr2, arr2['c0'])
5237+
assert_larray_equal(r1_arr1, arr1)
5238+
assert_larray_equal(r1_arr2, arr2['c1'])
5239+
5240+
5241+
def test_zip_array_items():
5242+
arr1 = ndtest('a=a0,a1;b=b0,b1')
5243+
arr2 = ndtest('a=a0,a1;c=c0,c1')
5244+
res = list(zip_array_items((arr1, arr2), axes=()))
5245+
assert len(res) == 1 and len(res[0]) == 2 and len(res[0][1]) == 2
5246+
r0_k, (r0_arr1, r0_arr2) = res[0]
5247+
assert r0_k == ()
5248+
assert_larray_equal(r0_arr1, arr1)
5249+
assert_larray_equal(r0_arr2, arr2)
5250+
5251+
52135252
if __name__ == "__main__":
52145253
# import doctest
52155254
# import unittest

0 commit comments

Comments
 (0)