Skip to content

Commit

Permalink
Added multi-axis advanced indexing support (#343)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeymezher authored May 15, 2020
1 parent 792f49f commit 38ed16f
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 15 deletions.
111 changes: 97 additions & 14 deletions sparse/_coo/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def getitem(x, index):
from .core import COO

# If string, this is an index into an np.void

# Custom dtype.
if isinstance(index, str):
data = x.data[index]
Expand Down Expand Up @@ -86,6 +87,7 @@ def getitem(x, index):
i = 0

sorted = adv_idx is None or adv_idx.pos == 0
adv_idx_added = False
for ind in index:
# Nothing is added to shape or coords if the index is an integer.
if isinstance(ind, Integral):
Expand All @@ -100,8 +102,10 @@ def getitem(x, index):
sorted = False
# Add the index and shape for the advanced index.
elif isinstance(ind, np.ndarray):
shape.append(adv_idx.length)
coords.append(adv_idx.idx)
if not adv_idx_added:
shape.append(adv_idx.length)
coords.append(adv_idx.idx)
adv_idx_added = True
i += 1
# Add a dimension for None.
elif ind is None:
Expand Down Expand Up @@ -141,20 +145,37 @@ def _mask(coords, indices, shape):

if len(adv_idx) != 0:
if len(adv_idx) != 1:
raise IndexError(
"Only indices with at most one iterable index are supported."

# Ensure if multiple advanced indices are passed, all are of the same length
# Also check each advanced index to ensure each is only a one-dimensional iterable
adv_ix_len = len(adv_idx[0])
for ai in adv_idx:
if len(ai) != adv_ix_len:
raise IndexError(
"shape mismatch: indexing arrays could not be broadcast together. Ensure all indexing arrays are of the same length."
)
if ai.ndim != 1:
raise IndexError("Only one-dimensional iterable indices supported.")

mask, aidxs = _compute_multi_axis_multi_mask(
coords,
_ind_ar_from_indices(indices),
np.array(adv_idx, dtype=np.intp),
np.array(adv_idx_pos, dtype=np.intp),
)
return mask, _AdvIdxInfo(aidxs, adv_idx_pos, adv_ix_len)

adv_idx = adv_idx[0]
adv_idx_pos = adv_idx_pos[0]
else:
adv_idx = adv_idx[0]
adv_idx_pos = adv_idx_pos[0]

if adv_idx.ndim != 1:
raise IndexError("Only one-dimensional iterable indices supported.")
if adv_idx.ndim != 1:
raise IndexError("Only one-dimensional iterable indices supported.")

mask, aidxs = _compute_multi_mask(
coords, _ind_ar_from_indices(indices), adv_idx, adv_idx_pos
)
return mask, _AdvIdxInfo(aidxs, adv_idx_pos, len(adv_idx))
mask, aidxs = _compute_multi_mask(
coords, _ind_ar_from_indices(indices), adv_idx, adv_idx_pos
)
return mask, _AdvIdxInfo(aidxs, adv_idx_pos, len(adv_idx))

mask, is_slice = _compute_mask(coords, _ind_ar_from_indices(indices))

Expand Down Expand Up @@ -276,6 +297,68 @@ def _separate_adv_indices(indices):
return new_idx, adv_idx, adv_idx_pos


@numba.jit(nopython=True, nogil=True)
def _compute_multi_axis_multi_mask(
coords, indices, adv_idx, adv_idx_pos
): # pragma: no cover
"""
Computes a mask with the advanced index, and also returns the advanced index
dimension.
Parameters
----------
coords : np.ndarray
Coordinates of the input array.
indices : np.ndarray
The indices in slice format.
adv_idx : np.ndarray
List of advanced indices.
adv_idx_pos : np.ndarray
The position of the advanced indices.
Returns
-------
mask : np.ndarray
The mask.
aidxs : np.ndarray
The advanced array index.
"""
n_adv_idx = len(adv_idx_pos)
mask = numba.typed.List.empty_list(numba.types.intp)
a_indices = numba.typed.List.empty_list(numba.types.intp)
full_idx = np.empty((len(indices) + len(adv_idx_pos), 3), dtype=np.intp)

# Get location of non-advanced indices
if len(indices) != 0:
ixx = 0
for ix in range(coords.shape[0]):
isin = False
for ax in adv_idx_pos:
if ix == ax:
isin = True
break
if not isin:
full_idx[ix] = indices[ixx]
ixx += 1

for i in range(len(adv_idx[0])):
for ii in range(n_adv_idx):
full_idx[adv_idx_pos[ii]] = [adv_idx[ii][i], adv_idx[ii][i] + 1, 1]

partial_mask, is_slice = _compute_mask(coords, full_idx)
if is_slice:
slice_mask = numba.typed.List.empty_list(numba.types.intp)
for j in range(partial_mask[0], partial_mask[1]):
slice_mask.append(j)
partial_mask = array_from_list_intp(slice_mask)

for j in range(len(partial_mask)):
mask.append(partial_mask[j])
a_indices.append(i)

return array_from_list_intp(mask), array_from_list_intp(a_indices)


@numba.jit(nopython=True, nogil=True)
def _compute_multi_mask(coords, indices, adv_idx, adv_idx_pos): # pragma: no cover
"""
Expand All @@ -288,9 +371,9 @@ def _compute_multi_mask(coords, indices, adv_idx, adv_idx_pos): # pragma: no co
Coordinates of the input array.
indices : np.ndarray
The indices in slice format.
adv_idx : int
adv_idx : list(int)
The advanced index.
adv_idx_pos : int
adv_idx_pos : list(int)
The position of the advanced index.
Returns
Expand Down
5 changes: 4 additions & 1 deletion sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,6 +1283,10 @@ def test_gt():
(1, Ellipsis, None),
(1, 1, 1, Ellipsis),
(Ellipsis, 1, None),
# With multi-axis advanced indexing
([0, 1],) * 2,
([0, 1], [0, 2]),
([0, 0, 0], [0, 1, 2], [1, 2, 1]),
# Pathological - Slices larger than array
(slice(None, 1000)),
(slice(None), slice(None, 1000)),
Expand Down Expand Up @@ -1355,7 +1359,6 @@ def test_custom_dtype_slicing():
0.5,
[0.5],
{"potato": "kartoffel"},
([0, 1],) * 2,
([[0, 1]],),
],
)
Expand Down

0 comments on commit 38ed16f

Please sign in to comment.