diff --git a/python/mxnet/ndarray/sparse_ndarray.py b/python/mxnet/ndarray/sparse_ndarray.py
index 9248a43e9627..6c8bb44b0ae3 100644
--- a/python/mxnet/ndarray/sparse_ndarray.py
+++ b/python/mxnet/ndarray/sparse_ndarray.py
@@ -545,7 +545,7 @@ def __setitem__(self, key, value):
         ----------
         key : slice
             The indexing key.
-        value : NDArray or numpy.ndarray
+        value : scalar, NDArray or numpy.ndarray
             The value to set.

         Examples
@@ -568,6 +568,12 @@ def __setitem__(self, key, value):
         array([[ 1.,  1.,  1.],
                [ 1.,  1.,  1.],
                [ 1.,  1.,  1.]], dtype=float32)
+        >>> # assign scalar to RowSparseNDArray
+        >>> x[:] = 7
+        >>> x.asnumpy()
+        array([[ 7.,  7.,  7.],
+               [ 7.,  7.,  7.],
+               [ 7.,  7.,  7.]], dtype=float32)
         """
         if not self.writable:
             raise ValueError('Failed to assign to a readonly RowSparseNDArray')
@@ -580,8 +586,7 @@ def __setitem__(self, key, value):
                 if value.handle is not self.handle:
                     value.copyto(self)
             elif isinstance(value, numeric_types):
-                raise ValueError("Assigning numeric types to RowSparseNDArray " \
-                                 "is not implemented yet.")
+                _internal._set_value(float(value), out=self)
             elif isinstance(value, (np.ndarray, np.generic)):
                 warnings.warn('Assigning non-NDArray object to RowSparseNDArray is not efficient',
                               RuntimeWarning)
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 0d2968626d79..9f5b7bab820a 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -295,9 +295,21 @@ void SetValueOp(const real_t &rhs, NDArray *out) {
   switch (ret.ctx().dev_mask()) {
     case cpu::kDevMask: {
       Engine::Get()->PushSync([rhs, ret](RunContext ctx) {
-          CHECK(ret.storage_type() == kDefaultStorage);
-          TBlob tmp = ret.data();
-          ndarray::Eval<cpu>(rhs, &tmp, ctx);
+          auto ret_stype = ret.storage_type();
+          mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+          if (ret_stype == kRowSparseStorage) {
+            NDArray out = ret;
+            // indices
+            nnvm::dim_t nnr = ret.shape()[0];
+            out.CheckAndAlloc({mshadow::Shape1(nnr)});
+            op::PopulateFullIdxRspImpl(s, &out);
+            // data
+            TBlob tmp = out.data();
+            ndarray::Eval<cpu>(rhs, &tmp, ctx);
+          } else {
+            TBlob tmp = ret.data();
+            ndarray::Eval<cpu>(rhs, &tmp, ctx);
+          }
         }, ret.ctx(), {}, {ret.var()},
         FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
       break;
@@ -305,8 +317,21 @@ void SetValueOp(const real_t &rhs, NDArray *out) {
 #if MXNET_USE_CUDA
     case gpu::kDevMask: {
       Engine::Get()->PushSync([rhs, ret](RunContext ctx) {
-          TBlob tmp = ret.data();
-          ndarray::Eval<gpu>(rhs, &tmp, ctx);
+          auto ret_stype = ret.storage_type();
+          mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+          if (ret_stype == kRowSparseStorage) {
+            NDArray out = ret;
+            // indices
+            nnvm::dim_t nnr = ret.shape()[0];
+            out.CheckAndAlloc({mshadow::Shape1(nnr)});
+            op::PopulateFullIdxRspImpl(s, &out);
+            // data
+            TBlob tmp = out.data();
+            ndarray::Eval<gpu>(rhs, &tmp, ctx);
+          } else {
+            TBlob tmp = ret.data();
+            ndarray::Eval<gpu>(rhs, &tmp, ctx);
+          }
           // Wait GPU kernel to complete
           ctx.get_stream<gpu>()->Wait();
         }, ret.ctx(), {}, {ret.var()},
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index cbce588549a1..6c4fb01978d4 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -102,7 +102,8 @@ def test_sparse_nd_setitem():
     def check_sparse_nd_setitem(stype, shape, dst):
         x = mx.nd.zeros(shape=shape, stype=stype)
         x[:] = dst
-        dst_nd = mx.nd.array(dst) if isinstance(dst, (np.ndarray, np.generic)) else dst
+        dst_nd = mx.nd.zeros(shape=shape)
+        dst_nd[:] = dst
         assert same(x.asnumpy(), dst_nd.asnumpy())

     shape = rand_shape_2d()
@@ -112,6 +113,10 @@ def check_sparse_nd_setitem(stype, shape, dst):
         check_sparse_nd_setitem(stype, shape, rand_ndarray(shape, stype))
         # numpy assignment
         check_sparse_nd_setitem(stype, shape, np.ones(shape))
+        if stype == 'row_sparse':
+            # scalar assignment
+            check_sparse_nd_setitem(stype, shape, 0)
+            check_sparse_nd_setitem(stype, shape, 1)


 def test_sparse_nd_slice():
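
For reference, a minimal sketch of the user-facing behavior this patch enables, assuming an MXNet build that includes the change; the shape and value mirror the docstring example above:

```python
import mxnet as mx

# Scalar assignment to a RowSparseNDArray, newly supported by this patch.
# Under the hood, SetValueOp allocates one row index per row (filling the
# index array via op::PopulateFullIdxRspImpl) and then writes the scalar
# into the data blob.
x = mx.nd.zeros(shape=(3, 3), stype='row_sparse')
x[:] = 7
print(x.asnumpy())
# array([[ 7.,  7.,  7.],
#        [ 7.,  7.,  7.],
#        [ 7.,  7.,  7.]], dtype=float32)
```

Note that assigning a scalar densifies the representation: every row becomes non-zero, so the result stores `nnr == shape[0]` row indices, as seen in the `SetValueOp` change above.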