Skip to content

Commit a880bc7

Browse files
draft for sgd rsp rsp (#75)
support sgd(rsp, rsp); support dot(csr, rsp) when rsp is full; add ref to const ndarray params; support sparse embedding with rsp weight; fix lint; modify embedding backward to produce dense grad; remove invalid_rid for rsp->dns; remove previous embedding op changes; pass sparse embedding test; add STORAGE_TYPE_ASSIGN_CHECK; remove backward storage infer
1 parent f98912b commit a880bc7

11 files changed

+365
-224
lines changed

python/mxnet/optimizer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
from .ndarray import NDArray, zeros, clip, sqrt, sign
66
from .ndarray import sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update
7+
from .sparse_ndarray import zeros as sparse_zeros
78
from .random import normal
89

910

@@ -332,7 +333,8 @@ def create_state(self, index, weight):
332333
if self.momentum == 0.0:
333334
return None
334335
else:
335-
return zeros(weight.shape, weight.context, dtype=weight.dtype)
336+
return sparse_zeros(weight.storage_type, weight.shape,
337+
weight.context, dtype=weight.dtype)
336338

337339
def update(self, index, weight, grad, state):
338340
assert(isinstance(weight, NDArray))

python/mxnet/sparse_ndarray.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def to_dense(source):
571571
"""
572572
return ndarray.cast_storage(source, storage_type='default')
573573

574-
def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None):
574+
def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None, **kwargs):
575575
"""Return a new array of given shape and type, filled with zeros.
576576
577577
Parameters
@@ -599,6 +599,8 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None):
599599
>>> mx.sparse_nd.zeros('row_sparse', (1,2), mx.gpu(0), 'float16').asnumpy()
600600
array([[ 0., 0.]], dtype=float16)
601601
"""
602+
if storage_type == 'default':
603+
return ndarray.zeros(shape, ctx, dtype, **kwargs)
602604
if ctx is None:
603605
ctx = Context.default_ctx
604606
dtype = mx_real_t if dtype is None else dtype
@@ -609,7 +611,7 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None):
609611
raise Exception("unknown storage type")
610612
assert(len(aux_types) == len(_STORAGE_AUX_TYPES[storage_type]))
611613
out = _ndarray_cls(_new_alloc_handle(storage_type, shape, ctx, True, dtype, aux_types))
612-
return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out)
614+
return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs)
613615

614616
def _ndarray_cls(handle, writable=True):
615617
stype = _storage_type(handle)

src/operator/operator_common.h

+31
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,19 @@ inline std::string type_string(const int& x) {
110110
return "unknown";
111111
}
112112

113+
/*! \brief get string representation of storage_type */
114+
inline std::string stype_string(const int& x) {
115+
switch (x) {
116+
case kDefaultStorage:
117+
return "default";
118+
case kCSRStorage:
119+
return "csr";
120+
case kRowSparseStorage:
121+
return "row_sparse";
122+
}
123+
return "unknown";
124+
}
125+
113126
/*!
114127
* \brief Assign x to y. Checks for compatibility when y is not empty.
115128
* Allow missing dim in both x and y (as 0).
@@ -186,6 +199,24 @@ inline bool type_assign(int *y, const int& x) {
186199
} \
187200
}
188201

202+
/*!
203+
* \brief Macro that assigns the inferred storage type to type_array[index] if that entry is unknown (-1); otherwise checks that the two are consistent.
204+
* A macro is used so that the error file can be seen more clearly.
* On inconsistency it throws ::mxnet::op::InferTypeError whose message names both the
* provided and the inferred storage type (via stype_string) and which carries index.
205+
* \param type_array the storage type array to store the result
206+
* \param index the index of the entry in the array
207+
* \param type the inferred storage type
208+
*/
209+
#define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type) \
210+
{ \
211+
if (!type_assign(&(type_array)[index], type)) { \
212+
std::ostringstream os; \
213+
os << "Storage type inconsistent, Provided=" \
214+
<< stype_string((type_array)[index]) << ',' \
215+
<< " inferred storage type=" << stype_string(type); \
216+
throw ::mxnet::op::InferTypeError(os.str(), index); \
217+
} \
218+
}
219+
189220
// helper macro to implement bind dispatch
190221
#if MXNET_USE_CUDA
191222
#define DO_BIND_DISPATCH(Method, ...) \

src/operator/optimizer_op-inl.h

+118-41
Original file line numberDiff line numberDiff line change
@@ -112,32 +112,31 @@ struct SGDDnsRspKernel {
112112

113113
template<typename xpu>
114114
inline void SGDUpdateDnsRspImpl(const SGDParam& param,
115-
const OpContext &ctx,
116-
const std::vector<NDArray> &inputs,
117-
const std::vector<OpReqType> &req,
118-
const std::vector<NDArray> &outputs) {
115+
const OpContext &ctx,
116+
const TBlob& weight,
117+
const NDArray& grad,
118+
const OpReqType& req,
119+
TBlob *out) {
119120
using namespace mshadow;
120121
using namespace mshadow::expr;
121122
using namespace mshadow_op;
123+
using namespace mxnet_op;
122124
Stream<xpu>* s = ctx.get_stream<xpu>();
123-
auto &weight = inputs[0];
124-
auto &grad = inputs[1];
125-
auto &out = outputs[0];
126-
CHECK_EQ(weight.storage_type(), kDefaultStorage);
127125
CHECK_EQ(grad.storage_type(), kRowSparseStorage);
128-
if (!grad.storage_initialized()) return;
126+
// if gradients are zeros, no weights are updated
127+
if (!grad.storage_initialized() || req == kNullOp) return;
128+
CHECK_GT(weight.shape_.Size(), 0);
129129

130-
MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
130+
MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
131131
MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
132-
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
133-
auto weight_data = weight.data().FlatTo2D<xpu, DType>(s);
134-
auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D<xpu, IType>(s);
135-
auto grad_val = grad.data().FlatTo2D<xpu, DType>(s);
136-
auto out_data = out.data().FlatTo2D<xpu, DType>(s);
132+
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
133+
auto weight_data = weight.dptr<DType>();
134+
auto grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
135+
auto grad_val = grad.data().dptr<DType>();
137136
auto num_rows = grad.aux_shape(rowsparse::kIdx)[0];
138-
auto width = weight.shape().ProdShape(1, weight.shape().ndim());
139-
mxnet_op::Kernel<SGDDnsRspKernel<req_type>, xpu>::Launch(s, num_rows, width,
140-
out_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_,
137+
auto width = weight.shape_.ProdShape(1, weight.ndim());
138+
Kernel<SGDDnsRspKernel<req_type>, xpu>::Launch(s, num_rows, width,
139+
out->dptr<DType>(), weight_data, grad_idx, grad_val,
141140
static_cast<DType>(param.clip_gradient),
142141
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
143142
static_cast<DType>(param.rescale_grad));
@@ -146,6 +145,29 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
146145
});
147146
}
148147

148+
template<typename xpu>
149+
inline void SGDUpdateRspRspImpl(const SGDParam& param,
150+
const OpContext& ctx,
151+
const NDArray& weight,
152+
const NDArray& grad,
153+
const OpReqType& req,
154+
NDArray *out) {
155+
if (weight.storage_shape()[0] == weight.shape()[0] &&
156+
out->storage_shape()[0] == out->shape()[0]) {
157+
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
158+
// feed in kWriteTo as req for all operators.
159+
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
160+
auto out_req = req;
161+
if (out_req == kWriteTo) out_req = kWriteInplace;
162+
// reuse dns rsp implementation when storage_shape == shape
163+
TBlob out_blob = out->data();
164+
SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, out_req, &out_blob);
165+
} else {
166+
LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when "
167+
<< "weights.values.shape == weights.shape";
168+
}
169+
}
170+
149171
template<typename xpu>
150172
inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
151173
const OpContext &ctx,
@@ -159,7 +181,11 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
159181
auto weight_stype = inputs[0].storage_type();
160182
auto grad_stype = inputs[1].storage_type();
161183
if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage) {
162-
SGDUpdateDnsRspImpl<xpu>(param, ctx, inputs, req, outputs);
184+
TBlob out = outputs[0].data();
185+
SGDUpdateDnsRspImpl<xpu>(param, ctx, inputs[0].data(), inputs[1], req[0], &out);
186+
} else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) {
187+
NDArray out = outputs[0];
188+
SGDUpdateRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
163189
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) {
164190
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDUpdate<xpu>, "SGDUpdate");
165191
}
@@ -262,30 +288,31 @@ struct SGDMomDnsRspDnsKernel {
262288

263289
template<typename xpu>
264290
inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
265-
const OpContext &ctx,
266-
const std::vector<NDArray> &inputs,
267-
const std::vector<OpReqType> &req,
268-
const std::vector<NDArray> &outputs) {
291+
const OpContext& ctx,
292+
const TBlob& weight,
293+
const NDArray& grad,
294+
const TBlob& mom,
295+
const OpReqType& req,
296+
TBlob *out) {
269297
using namespace mxnet_op;
298+
using namespace rowsparse;
270299
Stream<xpu>* s = ctx.get_stream<xpu>();
271-
auto &weight = inputs[0];
272-
auto &grad = inputs[1];
273-
auto &mom = inputs[2];
274-
auto &out = outputs[0];
275-
if (!grad.storage_initialized()) return;
300+
if (!grad.storage_initialized() || req == kNullOp) return;
301+
CHECK_GT(weight.shape_.Size(), 0);
302+
CHECK_GT(mom.shape_.Size(), 0);
276303

277-
MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
278-
MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
279-
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
280-
auto weight_data = weight.data().FlatTo2D<xpu, DType>(s);
281-
auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D<xpu, IType>(s);
282-
auto grad_val = grad.data().FlatTo2D<xpu, DType>(s);
283-
auto mom_data = mom.data().FlatTo2D<xpu, DType>(s);
284-
auto out_data = out.data().FlatTo2D<xpu, DType>(s);
285-
auto num_rows = grad.aux_shape(rowsparse::kIdx)[0];
286-
auto width = weight.shape().ProdShape(1, weight.shape().ndim());
304+
MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
305+
MSHADOW_INT_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
306+
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
307+
auto weight_data = weight.dptr<DType>();
308+
auto grad_idx = grad.aux_data(kIdx).dptr<IType>();
309+
auto grad_val = grad.data().dptr<DType>();
310+
auto mom_data = mom.dptr<DType>();
311+
auto out_data = out->dptr<DType>();
312+
auto num_rows = grad.aux_shape(kIdx)[0];
313+
auto width = weight.shape_.ProdShape(1, weight.ndim());
287314
Kernel<SGDMomDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, width,
288-
out_data.dptr_, mom_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_,
315+
out_data, mom_data, weight_data, grad_idx, grad_val,
289316
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
290317
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
291318
static_cast<DType>(param.rescale_grad));
@@ -294,6 +321,50 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
294321
});
295322
}
296323

324+
template<typename xpu>
325+
inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
326+
const OpContext& ctx,
327+
const NDArray& weight,
328+
const NDArray& grad,
329+
const NDArray& mom,
330+
const OpReqType& req,
331+
NDArray *out) {
332+
using namespace mshadow;
333+
using namespace mshadow::expr;
334+
using namespace mxnet_op;
335+
using namespace rowsparse;
336+
if (weight.storage_shape()[0] == weight.shape()[0] &&
337+
out->storage_shape()[0] == out->shape()[0]) {
338+
Stream<xpu>* s = ctx.get_stream<xpu>();
339+
// fill mom with zero values in order to reuse the sgd mom dns impl
340+
if (!mom.storage_initialized()) {
341+
MSHADOW_REAL_TYPE_SWITCH(mom.dtype(), DType, {
342+
MSHADOW_INT_TYPE_SWITCH(mom.aux_type(kIdx), IType, {
343+
auto num_rows = mom.shape()[0];
344+
mom.CheckAndAlloc({Shape1(num_rows)});
345+
auto mom_idx = mom.aux_data(kIdx).FlatTo1D<xpu, IType>(s);
346+
auto mom_val = mom.data();
347+
// TODO(haibin) this is single-thread execution
348+
Kernel<set_zero, xpu>::Launch(s, mom_val.Size(), mom_val.dptr<DType>());
349+
ASSIGN_DISPATCH(mom_idx, kWriteTo, range<IType>(0, num_rows, 1, 1))
350+
});
351+
});
352+
}
353+
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
354+
// feed in kWriteTo as req for all operators.
355+
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
356+
auto out_req = req;
357+
if (out_req == kWriteTo) out_req = kWriteInplace;
358+
TBlob out_blob = out->data();
359+
// reuse dns rsp implementation when storage_shape == shape
360+
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
361+
mom.data(), out_req, &out_blob);
362+
} else {
363+
LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when "
364+
<< "weights.values.shape == weights.shape";
365+
}
366+
}
367+
297368
template<typename xpu>
298369
inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
299370
const OpContext &ctx,
@@ -305,10 +376,16 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
305376
auto weight_stype = inputs[0].storage_type();
306377
auto grad_stype = inputs[1].storage_type();
307378
auto mom_stype = inputs[2].storage_type();
308-
309379
if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage &&
310380
mom_stype == kDefaultStorage) {
311-
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs, req, outputs);
381+
TBlob out = outputs[0].data();
382+
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
383+
inputs[2].data(), req[0], &out);
384+
} else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
385+
mom_stype == kRowSparseStorage) {
386+
NDArray out = outputs[0];
387+
SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1],
388+
inputs[2], req[0], &out);
312389
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage &&
313390
mom_stype == kDefaultStorage) {
314391
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs,

src/operator/tensor/elemwise_unary_op.h

+3-6
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,8 @@ inline void CastStorageDnsRspImpl(mshadow::Stream<cpu>* s, const TBlob& dns, NDA
324324
struct CastStorageRspDnsKernel {
325325
template<typename DType, typename IType>
326326
MSHADOW_XINLINE static void Map(int i, const index_t width, const IType* idx, const DType *data,
327-
DType* dns, const index_t invalid_rid) {
327+
DType* dns) {
328328
auto rid = idx[i];
329-
// skip invalid rows
330-
if (rid == invalid_rid) return;
331329
auto dns_offset = rid * width;
332330
auto rsp_offset = i * width;
333331
for (size_t col = 0; col < width; col++) {
@@ -356,10 +354,9 @@ void CastStorageRspDnsImpl(mshadow::Stream<xpu>* s, const NDArray& rsp, TBlob* d
356354
auto out_data = dns->FlatTo2D<xpu, DType>(s).dptr_;
357355
auto num_rows = rsp.aux_shape(rowsparse::kIdx).Size();
358356
auto rsp_shape = rsp.shape();
359-
auto invalid_rid = rsp_shape[0];
360357
auto width = rsp_shape.ProdShape(1, rsp_shape.ndim());
361-
mxnet_op::Kernel<CastStorageRspDnsKernel, xpu>::Launch(s, num_rows, width, in_idx, in_data,
362-
out_data, invalid_rid);
358+
mxnet_op::Kernel<CastStorageRspDnsKernel, xpu>::Launch(s, num_rows, width, in_idx,
359+
in_data, out_data);
363360
}
364361
});
365362
});

src/operator/tensor/indexing_op.cc

+17-9
Original file line numberDiff line numberDiff line change
@@ -87,39 +87,47 @@ NNVM_REGISTER_OP(_backward_Embedding)
8787
.set_attr<FCompute>("FCompute<cpu>", EmbeddingOpBackward<cpu>);
8888

8989
NNVM_REGISTER_OP(SparseEmbedding)
90-
.describe(R"code(Maps integer indices to vector representations (embeddings) with sparse weight update
91-
)code" ADD_FILELINE)
90+
.describe(R"doc(Represents words or other sparse inputs by dense continuous vectors.
91+
It assumes that the input is in one-hot form. E.g., for a vocabulary size of 10,000,
92+
each input vector is expected to have dimension 10,000.
93+
The index of the non-zero entry is the index of the word or item it represents.
94+
95+
The corresponding embedding vectors are stored as rows of a matrix.
96+
Hence, mapping an input word to its embedding is implemented as a matrix product.
97+
98+
The gradient of an embedding matrix has the form of gradient vectors that are only
99+
non-zero for words seen in a minibatch.
100+
)doc" ADD_FILELINE)
92101
.set_num_inputs(2)
93102
.set_num_outputs(1)
94103
.set_attr_parser(ParamParser<EmbeddingParam>)
95104
.set_attr<nnvm::FListInputNames>("FListInputNames",
96105
[](const NodeAttrs& attrs) {
97106
return std::vector<std::string>{"data", "weight"};
98107
})
99-
.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape)
108+
.set_attr<nnvm::FInferShape>("FInferShape", SparseEmbeddingShape)
100109
.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType)
110+
.set_attr<nnvm::FInferStorageType>("FInferStorageType", SparseEmbeddingForwardStorageType)
101111
.set_attr<FResourceRequest>("FResourceRequest",
102112
[](const NodeAttrs& attrs) {
103113
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
104114
})
105-
.set_attr<FCompute>("FCompute<cpu>", EmbeddingOpForward<cpu>)
115+
.set_attr<FComputeEx>(FCOMP_EX_CPU, SparseEmbeddingForwardEx<cpu>)
106116
.set_attr<nnvm::FGradient>("FGradient",
107117
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
108118
return MakeNonlossGradNode("_backward_SparseEmbedding", n, ograds,
109119
{n->inputs[0]}, n->attrs.dict);
110120
})
111-
.add_argument("data", "NDArray-or-Symbol", "The input array to the embedding operator.")
121+
.add_argument("data", "NDArray-or-Symbol",
122+
"The input array to the sparse embedding operator.")
112123
.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.")
113124
.add_arguments(EmbeddingParam::__FIELDS__());
114125

115126
NNVM_REGISTER_OP(_backward_SparseEmbedding)
116127
.set_num_inputs(2)
117128
.set_num_outputs(2)
118129
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
119-
.set_attr<nnvm::FInferStorageType>("FInferStorageType", SparseEmbeddingBackwardStorageType)
120-
.set_attr<FComputeEx>("FComputeEx<cpu>", SparseEmbeddingOpBackwardEx<cpu>);
121-
// TODO(haibin) handle dense case
122-
// .set_attr<FCompute>("FCompute<cpu>", EmbeddingOpBackward<cpu>);
130+
.set_attr<FComputeEx>("FComputeEx<cpu>", SparseEmbeddingBackwardEx<cpu>);
123131

124132
NNVM_REGISTER_OP(take)
125133
.describe(R"code(Takes elements from an input array along the given axis.

0 commit comments

Comments
 (0)