Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] Fix imperative basic indexing in numpy (#16902) (#16919)
Browse files Browse the repository at this point in the history
* fix bug

add test case

fix

Update test_numpy_ndarray.py

* revise function name
  • Loading branch information
ptrendx authored Nov 27, 2019
1 parent 121739a commit a3b0aa4
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 39 deletions.
81 changes: 52 additions & 29 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,26 +847,32 @@ def _basic_indexing_slice_is_contiguous(slc_key, shape):
"""Whether indexing with the given key results in a contiguous array.
The rule is: From right to left, if in an axis, a slice produces a
proper subset, no later axis can produce a proper subset or use
a step different from 1.
proper subset, the later slice must have <=1 elements.
The ``slc_key`` sequence must have the same length as ``shape`` and
only contain `slice` objects.
"""
assert len(slc_key) == len(shape)
subset = False
is_subset = False
total_sliced_elements = np.prod([_get_slice_len(slc, n)
for slc, n in zip(slc_key, shape)])
if total_sliced_elements in (0, 1):
return True
for idx, n in zip(reversed(slc_key), reversed(shape)):
start, stop, step = idx.indices(n)
if step > 0:
num = int(np.ceil(max(stop - start, 0) / step))
else:
num = int(np.ceil(min(stop - start, 0) / step))

if num != 1 and (subset or step != 1):
_, _, step = idx.indices(n)
num_elements = _get_slice_len(idx, n)
if num_elements == 0:
return True
elif num_elements > 1 and (step > 1 or step < 0):
# We do not support the case of reverse slicing of multiple elements and
# forward slicing of #elements > 1 and step > 1
return False
if num != n:
subset = True

elif is_subset:
if num_elements > 1:
return False
else:
if num_elements < n:
is_subset = True
return True
# pylint: enable=invalid-name

Expand All @@ -875,30 +881,27 @@ def _basic_indexing_sliced_shape(slc_key, shape):
"""Return the shape after slicing with the given key."""
assert len(slc_key) == len(shape)
sliced_shape = []
for idx, n in zip(slc_key, shape):
start, stop, step = idx.indices(n)
if step > 0:
num = int(np.ceil(max(stop - start, 0) / step))
else:
num = int(np.ceil(min(stop - start, 0) / step))
sliced_shape.append(num)

for slc, n in zip(slc_key, shape):
num_elements = _get_slice_len(slc, n)
sliced_shape.append(num_elements)
return tuple(sliced_shape)

# pylint: disable=invalid-name
@staticmethod
def _basic_indexing_contiguous_flat_begin_end(slc_key, shape):
"""Return the flat indices of begin and end for contiguous slicing."""
assert len(slc_key) == len(shape)
begin, end, _ = slc_key[0].indices(shape[0])
flat_begin, flat_end = begin, end - 1
for idx, n in zip(slc_key[1:], shape[1:]):
flat_begin, flat_end = 0, 0
for slc, n in zip(slc_key, shape):
flat_begin *= n
flat_end *= n
begin, end, _ = idx.indices(n)
flat_begin += begin
flat_end += end - 1

begin, _, _ = slc.indices(n)
num_elements = _get_slice_len(slc, n)
if num_elements == 0:
return 0, 0
else:
flat_begin += begin
flat_end += begin + num_elements - 1
return flat_begin, flat_end + 1
# pylint: enable=invalid-name

Expand Down Expand Up @@ -1062,7 +1065,7 @@ def _get_nd_basic_indexing(self, key):
for ax in new_axes: # pylint: disable=invalid-name
final_shape.insert(ax, 1)

if final_shape == []:
if len(final_shape) == 0:
# Override for single element indexing
final_shape = [1]
return sliced.reshape(final_shape)
Expand Down Expand Up @@ -3108,6 +3111,26 @@ def _get_dim_size(start, stop, step):
return dim_size


def _get_slice_len(slc, seq_length):
"""Given a python slice object and the length of the sequence, calculate the number of elements
in the slice.
Parameters
----------
slc : py_slice
The slice object
seq_length : int
The length of the object you are going to apply the slice on
Returns
-------
ret : int
Total number of elements in the slice
"""
start, stop, step = slc.indices(seq_length)
return max(0, (stop - start + (step - (1 if step > 0 else -1))) // step)


def _get_broadcast_shape(shape1, shape2):
"""Given two shapes that are not identical, find the shape
that both input shapes can broadcast to."""
Expand Down
5 changes: 2 additions & 3 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,9 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
NDArray NDArray::SliceWithRecord(index_t begin, index_t end) {
NDArray ret = this->Slice(begin, end);
if (!Imperative::Get()->is_recording()) return ret;
// fake a slice_axis op
// fake a slice op
nnvm::NodeAttrs attrs;
attrs.op = nnvm::Op::Get("slice_axis");
attrs.dict.insert({"axis", "0"});
attrs.op = nnvm::Op::Get("slice");
attrs.dict.insert({"begin", std::to_string(begin)});
attrs.dict.insert({"end", std::to_string(end)});
attrs.op->attr_parser(&attrs);
Expand Down
4 changes: 4 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ static inline bool SupportStorageMKLDNN(int stype) {

static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
int ndim = shape.ndim();
if (ndim == 0 || shape.Size() == 0) {
// MKLDNN currently does not support 0-dim Tensor and 0-size Tensor
return false;
}
return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
}

Expand Down
20 changes: 13 additions & 7 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,13 +642,18 @@ def test_getitem(np_array, index):
)
np_indexed_array = np_array[np_index]
mx_np_array = np.array(np_array, dtype=np_array.dtype)
try:
mx_indexed_array = mx_np_array[index]
except Exception as e:
print('Failed with index = {}'.format(index))
raise e
mx_indexed_array = mx_indexed_array.asnumpy()
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)
for autograd in [True, False]:
try:
if autograd:
with mx.autograd.record():
mx_indexed_array = mx_np_array[index]
else:
mx_indexed_array = mx_np_array[index]
except Exception as e:
print('Failed with index = {}'.format(index))
raise e
mx_indexed_array = mx_indexed_array.asnumpy()
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)

def test_setitem(np_array, index):
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
Expand Down Expand Up @@ -768,6 +773,7 @@ def test_setitem_autograd(np_array, index):
np_int(slice(1, 5), np.int32),
np_int(slice(1, 5), np.int64),
slice(1, 5, 2),
slice(1, 2, 2),
np_int(slice(1, 5, 2), np.int32),
np_int(slice(1, 5, 2), np.int64),
slice(7, 0, -1),
Expand Down

0 comments on commit a3b0aa4

Please sign in to comment.