Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support range as advanced index for ndarrays #16047

Merged
merged 1 commit into from
Aug 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,17 +1053,15 @@ def _advanced_index_to_array(idx, ax_len, ctx):
if idx.dtype != idx_dtype:
idx = idx.astype(idx_dtype)
return idx.as_in_context(ctx)

elif isinstance(idx, (np.ndarray, list, tuple)):
return array(idx, ctx, idx_dtype)

elif isinstance(idx, integer_types):
return array([idx], ctx, idx_dtype)

elif isinstance(idx, py_slice):
start, stop, step = idx.indices(ax_len)
return arange(start, stop, step, ctx=ctx, dtype=idx_dtype)

elif sys.version_info[0] > 2 and isinstance(idx, range):
eric-haibin-lin marked this conversation as resolved.
Show resolved Hide resolved
return arange(idx.start, idx.stop, idx.step, ctx=ctx, dtype=idx_dtype)
else:
raise RuntimeError('illegal index type {}'.format(type(idx)))

Expand Down Expand Up @@ -2888,6 +2886,7 @@ def _scatter_set_nd(self, value_nd, indices):
lhs=self, rhs=value_nd, indices=indices, shape=self.shape, out=self
)


def indexing_key_expand_implicit_axes(key, shape):
"""Make implicit axes explicit by adding ``slice(None)``.
Examples
Expand Down Expand Up @@ -2984,6 +2983,8 @@ def _is_advanced_index(idx):
return True
elif isinstance(idx, py_slice) or idx is None:
return False
elif sys.version_info[0] > 2 and isinstance(idx, range):
return True
else:
raise RuntimeError('illegal index type {}'.format(type(idx)))

Expand All @@ -2995,7 +2996,8 @@ def get_indexing_dispatch_code(key):
for idx in key:
if isinstance(idx, (NDArray, np.ndarray, list, tuple)):
return _NDARRAY_ADVANCED_INDEXING

elif sys.version_info[0] > 2 and isinstance(idx, range):
return _NDARRAY_ADVANCED_INDEXING
elif not (isinstance(idx, (py_slice, integer_types)) or idx is None):
raise ValueError(
'NDArray does not support slicing with key {} of type {}.'
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def __getitem__(self, key):
key, shape[0]))
return self._at(key)
elif isinstance(key, py_slice):
if (key.step is None or key.step == 1):
if key.start is not None or key.stop is not None:
if key.step is None or key.step == 1:
if key.start is not None or key.stop is not None:
return self._slice(key.start, key.stop)
else:
return self
Expand Down
303 changes: 152 additions & 151 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,6 @@ def test_np_ndarray_copy():
@with_seed()
@use_np
def test_np_ndarray_indexing():
"""
Test all indexing.
"""
def np_int(index, int_type=np.int32):
"""
Helper function for testing indexing that converts slices to slices of ints or None, and tuples to
Expand Down Expand Up @@ -507,156 +504,160 @@ def test_setitem_autograd(np_array, index):

shape = (8, 16, 9, 9)
np_array = _np.arange(_np.prod(_np.array(shape)), dtype='int32').reshape(shape) # native np array

# Test sliced output being ndarray:
index_list = [
# Basic indexing
# Single int as index
0,
np.int32(0),
np.int64(0),
5,
np.int32(5),
np.int64(5),
-1,
np.int32(-1),
np.int64(-1),
# Slicing as index
slice(5),
np_int(slice(5), np.int32),
np_int(slice(5), np.int64),
slice(1, 5),
np_int(slice(1, 5), np.int32),
np_int(slice(1, 5), np.int64),
slice(1, 5, 2),
np_int(slice(1, 5, 2), np.int32),
np_int(slice(1, 5, 2), np.int64),
slice(7, 0, -1),
np_int(slice(7, 0, -1)),
np_int(slice(7, 0, -1), np.int64),
slice(None, 6),
np_int(slice(None, 6)),
np_int(slice(None, 6), np.int64),
slice(None, 6, 3),
np_int(slice(None, 6, 3)),
np_int(slice(None, 6, 3), np.int64),
slice(1, None),
np_int(slice(1, None)),
np_int(slice(1, None), np.int64),
slice(1, None, 3),
np_int(slice(1, None, 3)),
np_int(slice(1, None, 3), np.int64),
slice(None, None, 2),
np_int(slice(None, None, 2)),
np_int(slice(None, None, 2), np.int64),
slice(None, None, -1),
np_int(slice(None, None, -1)),
np_int(slice(None, None, -1), np.int64),
slice(None, None, -2),
np_int(slice(None, None, -2), np.int32),
np_int(slice(None, None, -2), np.int64),
# Multiple ints as indices
(1, 2, 3),
np_int((1, 2, 3)),
np_int((1, 2, 3), np.int64),
(-1, -2, -3),
np_int((-1, -2, -3)),
np_int((-1, -2, -3), np.int64),
(1, 2, 3, 4),
np_int((1, 2, 3, 4)),
np_int((1, 2, 3, 4), np.int64),
(-4, -3, -2, -1),
np_int((-4, -3, -2, -1)),
np_int((-4, -3, -2, -1), np.int64),
# slice(None) as indices
(slice(None), slice(None), 1, 8),
(slice(None), slice(None), -1, 8),
(slice(None), slice(None), 1, -8),
(slice(None), slice(None), -1, -8),
np_int((slice(None), slice(None), 1, 8)),
np_int((slice(None), slice(None), 1, 8), np.int64),
(slice(None), slice(None), 1, 8),
np_int((slice(None), slice(None), -1, -8)),
np_int((slice(None), slice(None), -1, -8), np.int64),
(slice(None), 2, slice(1, 5), 1),
np_int((slice(None), 2, slice(1, 5), 1)),
np_int((slice(None), 2, slice(1, 5), 1), np.int64),
# Mixture of ints and slices as indices
(slice(None, None, -1), 2, slice(1, 5), 1),
np_int((slice(None, None, -1), 2, slice(1, 5), 1)),
np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64),
(slice(None, None, -1), 2, slice(1, 7, 2), 1),
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)),
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64),
(slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)),
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))),
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64),
(slice(1, 8, 2), 1, slice(3, 8), 2),
np_int((slice(1, 8, 2), 1, slice(3, 8), 2)),
np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64),
# Test Ellipsis ('...')
(1, Ellipsis, -1),
(slice(2), Ellipsis, None, 0),
# Test newaxis
None,
(1, None, -2, 3, -4),
(1, slice(2, 5), None),
(slice(None), slice(1, 4), None, slice(2, 3)),
(slice(1, 3), slice(1, 3), slice(1, 3), slice(1, 3), None),
(slice(1, 3), slice(1, 3), None, slice(1, 3), slice(1, 3)),
(None, slice(1, 2), 3, None),
(1, None, 2, 3, None, None, 4),
# Advanced indexing
([1, 2], slice(3, 5), None, None, [3, 4]),
(slice(None), slice(3, 5), None, None, [2, 3], [3, 4]),
(slice(None), slice(3, 5), None, [2, 3], None, [3, 4]),
(None, slice(None), slice(3, 5), [2, 3], None, [3, 4]),
[1],
[1, 2],
[2, 1, 3],
[7, 5, 0, 3, 6, 2, 1],
np.array([6, 3], dtype=np.int32),
np.array([[3, 4], [0, 6]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
np.array([[2], [0], [1]], dtype=np.int32),
np.array([[2], [0], [1]], dtype=np.int64),
np.array([4, 7], dtype=np.int32),
np.array([4, 7], dtype=np.int64),
np.array([[3, 6], [2, 1]], dtype=np.int32),
np.array([[3, 6], [2, 1]], dtype=np.int64),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
(1, [2, 3]),
(1, [2, 3], np.array([[3], [0]], dtype=np.int32)),
(1, [2, 3]),
(1, [2, 3], np.array([[3], [0]], dtype=np.int64)),
(1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)),
(1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)),
(1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)),
(1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)),
(1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)),
(1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)),
(1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)),
(1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)),
(1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)),
(1, [2], np.array([[4]], dtype=np.int32), np.array([[1, 3], [5, 7]], dtype='int64')),
[0],
[0, 1],
[1, 2, 3],
[2, 0, 5, 6],
([1, 1], [2, 3]),
([1], [4], [5]),
([1], [4], [5], [6]),
([[1]], [[2]]),
([[1]], [[2]], [[3]], [[4]]),
(slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)),
([[[[1]]]], [[1]], slice(0, 3), [1, 5]),
([[[[1]]]], 3, slice(0, 3), [1, 3]),
([[[[1]]]], 3, slice(0, 3), 0),
([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)),
([1, 2], slice(3, 5), [2, 3], [3, 4]),
([1, 2], slice(3, 5), (2, 3), [3, 4]),
(),
# Basic indexing
# Single int as index
0,
np.int32(0),
np.int64(0),
5,
np.int32(5),
np.int64(5),
-1,
np.int32(-1),
np.int64(-1),
# Slicing as index
slice(5),
np_int(slice(5), np.int32),
np_int(slice(5), np.int64),
slice(1, 5),
np_int(slice(1, 5), np.int32),
np_int(slice(1, 5), np.int64),
slice(1, 5, 2),
np_int(slice(1, 5, 2), np.int32),
np_int(slice(1, 5, 2), np.int64),
slice(7, 0, -1),
np_int(slice(7, 0, -1)),
np_int(slice(7, 0, -1), np.int64),
slice(None, 6),
np_int(slice(None, 6)),
np_int(slice(None, 6), np.int64),
slice(None, 6, 3),
np_int(slice(None, 6, 3)),
np_int(slice(None, 6, 3), np.int64),
slice(1, None),
np_int(slice(1, None)),
np_int(slice(1, None), np.int64),
slice(1, None, 3),
np_int(slice(1, None, 3)),
np_int(slice(1, None, 3), np.int64),
slice(None, None, 2),
np_int(slice(None, None, 2)),
np_int(slice(None, None, 2), np.int64),
slice(None, None, -1),
np_int(slice(None, None, -1)),
np_int(slice(None, None, -1), np.int64),
slice(None, None, -2),
np_int(slice(None, None, -2), np.int32),
np_int(slice(None, None, -2), np.int64),
# Multiple ints as indices
(1, 2, 3),
np_int((1, 2, 3)),
np_int((1, 2, 3), np.int64),
(-1, -2, -3),
np_int((-1, -2, -3)),
np_int((-1, -2, -3), np.int64),
(1, 2, 3, 4),
np_int((1, 2, 3, 4)),
np_int((1, 2, 3, 4), np.int64),
(-4, -3, -2, -1),
np_int((-4, -3, -2, -1)),
np_int((-4, -3, -2, -1), np.int64),
# slice(None) as indices
(slice(None), slice(None), 1, 8),
(slice(None), slice(None), -1, 8),
(slice(None), slice(None), 1, -8),
(slice(None), slice(None), -1, -8),
np_int((slice(None), slice(None), 1, 8)),
np_int((slice(None), slice(None), 1, 8), np.int64),
(slice(None), slice(None), 1, 8),
np_int((slice(None), slice(None), -1, -8)),
np_int((slice(None), slice(None), -1, -8), np.int64),
(slice(None), 2, slice(1, 5), 1),
np_int((slice(None), 2, slice(1, 5), 1)),
np_int((slice(None), 2, slice(1, 5), 1), np.int64),
# Mixture of ints and slices as indices
(slice(None, None, -1), 2, slice(1, 5), 1),
np_int((slice(None, None, -1), 2, slice(1, 5), 1)),
np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64),
(slice(None, None, -1), 2, slice(1, 7, 2), 1),
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)),
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64),
(slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)),
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))),
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64),
(slice(1, 8, 2), 1, slice(3, 8), 2),
np_int((slice(1, 8, 2), 1, slice(3, 8), 2)),
np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64),
# Test Ellipsis ('...')
(1, Ellipsis, -1),
(slice(2), Ellipsis, None, 0),
# Test newaxis
None,
(1, None, -2, 3, -4),
(1, slice(2, 5), None),
(slice(None), slice(1, 4), None, slice(2, 3)),
(slice(1, 3), slice(1, 3), slice(1, 3), slice(1, 3), None),
(slice(1, 3), slice(1, 3), None, slice(1, 3), slice(1, 3)),
(None, slice(1, 2), 3, None),
(1, None, 2, 3, None, None, 4),
# Advanced indexing
([1, 2], slice(3, 5), None, None, [3, 4]),
(slice(None), slice(3, 5), None, None, [2, 3], [3, 4]),
(slice(None), slice(3, 5), None, [2, 3], None, [3, 4]),
(None, slice(None), slice(3, 5), [2, 3], None, [3, 4]),
[1],
[1, 2],
[2, 1, 3],
[7, 5, 0, 3, 6, 2, 1],
np.array([6, 3], dtype=np.int32),
np.array([[3, 4], [0, 6]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
np.array([[2], [0], [1]], dtype=np.int32),
np.array([[2], [0], [1]], dtype=np.int64),
np.array([4, 7], dtype=np.int32),
np.array([4, 7], dtype=np.int64),
np.array([[3, 6], [2, 1]], dtype=np.int32),
np.array([[3, 6], [2, 1]], dtype=np.int64),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
(1, [2, 3]),
(1, [2, 3], np.array([[3], [0]], dtype=np.int32)),
(1, [2, 3]),
(1, [2, 3], np.array([[3], [0]], dtype=np.int64)),
(1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)),
(1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)),
(1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)),
(1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)),
(1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)),
(1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)),
(1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)),
(1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)),
(1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)),
(1, [2], np.array([[4]], dtype=np.int32), np.array([[1, 3], [5, 7]], dtype='int64')),
[0],
[0, 1],
[1, 2, 3],
[2, 0, 5, 6],
([1, 1], [2, 3]),
([1], [4], [5]),
([1], [4], [5], [6]),
([[1]], [[2]]),
([[1]], [[2]], [[3]], [[4]]),
(slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)),
([[[[1]]]], [[1]], slice(0, 3), [1, 5]),
([[[[1]]]], 3, slice(0, 3), [1, 3]),
([[[[1]]]], 3, slice(0, 3), 0),
([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)),
([1, 2], slice(3, 5), [2, 3], [3, 4]),
([1, 2], slice(3, 5), (2, 3), [3, 4]),
range(4),
range(3, 0, -1),
(range(4,), [1]),
]
for index in index_list:
test_getitem(np_array, index)
Expand Down