Skip to content

Commit

Permalink
csr slice bug fix (apache#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin authored and Olivier committed Jun 13, 2017
1 parent af1a06a commit 85d59ce
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/mxnet/sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def _slice(self, start, stop):
assert(stype == 'csr'), "_slice for " + str(stype) + " not implemented yet"
warnings.warn('slicing SparseNDArray is not efficient', RuntimeWarning)
shape = list(self.shape)
stop = shape[0] if stop is None else stop
start = 0 if start is None else start
shape[0] = stop - start
handle = _new_alloc_handle(self.storage_type, tuple(shape), self.context,
True, self.dtype, self.aux_types)
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def check_sparse_nd_csr_slice(shape):
start = rnd.randint(0, shape[0] - 1)
end = rnd.randint(start + 1, shape[0])
assert same(A[start:end].asnumpy(), A2[start:end])
assert same(A[start:].asnumpy(), A2[start:])
assert same(A[:end].asnumpy(), A2[:end])

shape = (rnd.randint(2, 10), rnd.randint(1, 10))
check_sparse_nd_csr_slice(shape)
Expand Down

0 comments on commit 85d59ce

Please sign in to comment.