diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 7d8cc524c817..a7ad8e6c6c98 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -847,26 +847,32 @@ def _basic_indexing_slice_is_contiguous(slc_key, shape): """Whether indexing with the given key results in a contiguous array. The rule is: From right to left, if in an axis, a slice produces a - proper subset, no later axis can produce a proper subset or use - a step different from 1. + proper subset, the later slice must have <=1 elements. The ``slc_key`` sequence must have the same length as ``shape`` and only contain `slice` objects. """ assert len(slc_key) == len(shape) - subset = False + is_subset = False + total_sliced_elements = np.prod([_get_slice_len(slc, n) + for slc, n in zip(slc_key, shape)]) + if total_sliced_elements in (0, 1): + return True for idx, n in zip(reversed(slc_key), reversed(shape)): - start, stop, step = idx.indices(n) - if step > 0: - num = int(np.ceil(max(stop - start, 0) / step)) - else: - num = int(np.ceil(min(stop - start, 0) / step)) - - if num != 1 and (subset or step != 1): + _, _, step = idx.indices(n) + num_elements = _get_slice_len(idx, n) + if num_elements == 0: + return True + elif num_elements > 1 and (step > 1 or step < 0): + # We do not support the case of reverse slicing of multiple elements and + # forward slicing of #elements > 1 and step > 1 return False - if num != n: - subset = True - + elif is_subset: + if num_elements > 1: + return False + else: + if num_elements < n: + is_subset = True return True # pylint: enable=invalid-name @@ -875,14 +881,9 @@ def _basic_indexing_sliced_shape(slc_key, shape): """Return the shape after slicing with the given key.""" assert len(slc_key) == len(shape) sliced_shape = [] - for idx, n in zip(slc_key, shape): - start, stop, step = idx.indices(n) - if step > 0: - num = int(np.ceil(max(stop - start, 0) / step)) - else: - num = int(np.ceil(min(stop - start, 0) / step)) - sliced_shape.append(num) - + for slc, n in zip(slc_key, shape): + num_elements = _get_slice_len(slc, n) + sliced_shape.append(num_elements) return tuple(sliced_shape) # pylint: disable=invalid-name @@ -890,15 +891,17 @@ def _basic_indexing_sliced_shape(slc_key, shape): def _basic_indexing_contiguous_flat_begin_end(slc_key, shape): """Return the flat indices of begin and end for contiguous slicing.""" assert len(slc_key) == len(shape) - begin, end, _ = slc_key[0].indices(shape[0]) - flat_begin, flat_end = begin, end - 1 - for idx, n in zip(slc_key[1:], shape[1:]): + flat_begin, flat_end = 0, 0 + for slc, n in zip(slc_key, shape): flat_begin *= n flat_end *= n - begin, end, _ = idx.indices(n) - flat_begin += begin - flat_end += end - 1 - + begin, _, _ = slc.indices(n) + num_elements = _get_slice_len(slc, n) + if num_elements == 0: + return 0, 0 + else: + flat_begin += begin + flat_end += begin + num_elements - 1 return flat_begin, flat_end + 1 # pylint: enable=invalid-name @@ -1062,7 +1065,7 @@ def _get_nd_basic_indexing(self, key): for ax in new_axes: # pylint: disable=invalid-name final_shape.insert(ax, 1) - if final_shape == []: + if len(final_shape) == 0: # Override for single element indexing final_shape = [1] return sliced.reshape(final_shape) @@ -3108,6 +3111,26 @@ def _get_dim_size(start, stop, step): return dim_size +def _get_slice_len(slc, seq_length): + """Given a python slice object and the length of the sequence, calculate the number of elements + in the slice. + + Parameters + ---------- + slc : py_slice + The slice object + seq_length : int + The length of the object you are going to apply the slice on + + Returns + ------- + ret : int + Total number of elements in the slice + """ + start, stop, step = slc.indices(seq_length) + return max(0, (stop - start + (step - (1 if step > 0 else -1))) // step) + + def _get_broadcast_shape(shape1, shape2): """Given two shapes that are not identical, find the shape that both input shapes can broadcast to.""" diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 6dc6bafa7288..9375bed5a79b 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -293,10 +293,9 @@ NDArray NDArray::Slice(index_t begin, index_t end) const { NDArray NDArray::SliceWithRecord(index_t begin, index_t end) { NDArray ret = this->Slice(begin, end); if (!Imperative::Get()->is_recording()) return ret; - // fake a slice_axis op + // fake a slice op nnvm::NodeAttrs attrs; - attrs.op = nnvm::Op::Get("slice_axis"); - attrs.dict.insert({"axis", "0"}); + attrs.op = nnvm::Op::Get("slice"); attrs.dict.insert({"begin", std::to_string(begin)}); attrs.dict.insert({"end", std::to_string(end)}); attrs.op->attr_parser(&attrs); diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 0f371d174e40..9bfc20cd7bb5 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -125,6 +125,10 @@ static inline bool SupportStorageMKLDNN(int stype) { static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) { int ndim = shape.ndim(); + if (ndim == 0 || shape.Size() == 0) { + // MKLDNN currently does not support 0-dim Tensor and 0-size Tensor + return false; + } return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4); } diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 8e46f03e79bc..9f4e62cac50c 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -642,13 +642,18 @@ def test_getitem(np_array, index): ) np_indexed_array = np_array[np_index] mx_np_array = np.array(np_array, dtype=np_array.dtype) - try: - mx_indexed_array = mx_np_array[index] - except Exception as e: - print('Failed with index = {}'.format(index)) - raise e - mx_indexed_array = mx_indexed_array.asnumpy() - assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index) + for autograd in [True, False]: + try: + if autograd: + with mx.autograd.record(): + mx_indexed_array = mx_np_array[index] + else: + mx_indexed_array = mx_np_array[index] + except Exception as e: + print('Failed with index = {}'.format(index)) + raise e + mx_indexed_array = mx_indexed_array.asnumpy() + assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index) def test_setitem(np_array, index): def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None): @@ -768,6 +773,7 @@ def test_setitem_autograd(np_array, index): np_int(slice(1, 5), np.int32), np_int(slice(1, 5), np.int64), slice(1, 5, 2), + slice(1, 2, 2), np_int(slice(1, 5, 2), np.int32), np_int(slice(1, 5, 2), np.int64), slice(7, 0, -1),