@@ -112,32 +112,31 @@ struct SGDDnsRspKernel {
 
 template<typename xpu>
 inline void SGDUpdateDnsRspImpl(const SGDParam& param,
-                                const OpContext &ctx,
-                                const std::vector<NDArray> &inputs,
-                                const std::vector<OpReqType> &req,
-                                const std::vector<NDArray> &outputs) {
+                                const OpContext &ctx,
+                                const TBlob& weight,
+                                const NDArray& grad,
+                                const OpReqType& req,
+                                TBlob *out) {
   using namespace mshadow;
   using namespace mshadow::expr;
   using namespace mshadow_op;
+  using namespace mxnet_op;
   Stream<xpu>* s = ctx.get_stream<xpu>();
-  auto &weight = inputs[0];
-  auto &grad = inputs[1];
-  auto &out = outputs[0];
-  CHECK_EQ(weight.storage_type(), kDefaultStorage);
   CHECK_EQ(grad.storage_type(), kRowSparseStorage);
-  if (!grad.storage_initialized()) return;
+  // if gradients are zeros, no weights are updated
+  if (!grad.storage_initialized() || req == kNullOp) return;
+  CHECK_GT(weight.shape_.Size(), 0);
 
-  MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
     MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        auto weight_data = weight.data().FlatTo2D<xpu, DType>(s);
-        auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D<xpu, IType>(s);
-        auto grad_val = grad.data().FlatTo2D<xpu, DType>(s);
-        auto out_data = out.data().FlatTo2D<xpu, DType>(s);
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        auto weight_data = weight.dptr<DType>();
+        auto grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
+        auto grad_val = grad.data().dptr<DType>();
         auto num_rows = grad.aux_shape(rowsparse::kIdx)[0];
-        auto width = weight.shape().ProdShape(1, weight.shape().ndim());
-        mxnet_op::Kernel<SGDDnsRspKernel<req_type>, xpu>::Launch(s, num_rows, width,
-          out_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_,
+        auto width = weight.shape_.ProdShape(1, weight.ndim());
+        Kernel<SGDDnsRspKernel<req_type>, xpu>::Launch(s, num_rows, width,
+          out->dptr<DType>(), weight_data, grad_idx, grad_val,
          static_cast<DType>(param.clip_gradient),
          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
          static_cast<DType>(param.rescale_grad));
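
Aside: a minimal standalone sketch (plain C++, outside MXNet) of the per-row update that SGDDnsRspKernel is launched to perform above. Only the rows listed in the gradient's kIdx aux array are touched; the clip/weight-decay formula below follows MXNet's usual dense SGD step and is an assumption for illustration, not a copy of the kernel body, and all names are hypothetical.

    // Illustrative sketch only (not the MXNet kernel): apply the SGD step to the
    // rows of a dense weight matrix named by a row-sparse gradient.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    void sgd_dns_rsp_sketch(std::vector<float>* weight, size_t width,
                            const std::vector<int64_t>& grad_idx,   // kIdx aux data
                            const std::vector<float>& grad_val,     // compact gradient rows
                            float clip_gradient, float lr, float wd,
                            float rescale_grad) {
      for (size_t i = 0; i < grad_idx.size(); ++i) {      // one iteration per stored row
        for (size_t j = 0; j < width; ++j) {
          const size_t w_off = grad_idx[i] * width + j;   // offset in the dense weight
          const size_t g_off = i * width + j;             // offset in the compact gradient
          float g = rescale_grad * grad_val[g_off];
          if (clip_gradient >= 0.0f) {
            g = std::max(-clip_gradient, std::min(g, clip_gradient));
          }
          // assumed update: w = (1 - lr*wd) * w - lr * g
          (*weight)[w_off] = (1.f - lr * wd) * (*weight)[w_off] - lr * g;
        }
      }
    }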
@@ -146,6 +145,29 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
   });
 }
 
+template<typename xpu>
+inline void SGDUpdateRspRspImpl(const SGDParam& param,
+                                const OpContext& ctx,
+                                const NDArray& weight,
+                                const NDArray& grad,
+                                const OpReqType& req,
+                                NDArray *out) {
+  if (weight.storage_shape()[0] == weight.shape()[0] &&
+      out->storage_shape()[0] == out->shape()[0]) {
+    // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
+    // feed in kWriteTo as req for all operators.
+    // For sgd we don't want to assign zeros to the output values when req == kWriteTo
+    auto out_req = req;
+    if (out_req == kWriteTo) out_req = kWriteInplace;
+    // reuse dns rsp implementation when storage_shape == shape
+    TBlob out_blob = out->data();
+    SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, out_req, &out_blob);
+  } else {
+    LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when "
+               << "weights.values.shape == weights.shape";
+  }
+}
+
 template<typename xpu>
 inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
                         const OpContext &ctx,
@@ -159,7 +181,11 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
   auto weight_stype = inputs[0].storage_type();
   auto grad_stype = inputs[1].storage_type();
   if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage) {
-    SGDUpdateDnsRspImpl<xpu>(param, ctx, inputs, req, outputs);
+    TBlob out = outputs[0].data();
+    SGDUpdateDnsRspImpl<xpu>(param, ctx, inputs[0].data(), inputs[1], req[0], &out);
+  } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) {
+    NDArray out = outputs[0];
+    SGDUpdateRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
   } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) {
     FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDUpdate<xpu>, "SGDUpdate");
   }
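
Aside: the new SGDUpdateRspRspImpl path only handles row-sparse weights whose values array physically stores every row (storage_shape()[0] == shape()[0]); in that case the values blob has exactly the layout of the dense weight matrix, which is why it can hand out->data() to the dense implementation with the request downgraded from kWriteTo to kWriteInplace. A toy illustration of that precondition follows; the container and names are hypothetical, not MXNet types.

    // Toy row-sparse container, used only to show why storage_shape[0] == shape[0]
    // lets the values blob be treated as a dense matrix.
    #include <cstdint>
    #include <vector>

    struct ToyRowSparse {
      std::vector<int64_t> idx;    // indices of the rows actually stored
      std::vector<float> values;   // stored rows, row-major, idx.size() x width
      int64_t num_rows = 0;        // logical shape[0]
      int64_t width = 0;           // logical shape[1]
    };

    // When every logical row is stored (idx = 0, 1, ..., num_rows-1), 'values'
    // is laid out exactly like the dense num_rows x width matrix, so a dense
    // update can write into it in place.
    bool all_rows_present(const ToyRowSparse& t) {
      return static_cast<int64_t>(t.idx.size()) == t.num_rows;
    }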
@@ -262,30 +288,31 @@ struct SGDMomDnsRspDnsKernel {
 
 template<typename xpu>
 inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
-                                      const OpContext &ctx,
-                                      const std::vector<NDArray> &inputs,
-                                      const std::vector<OpReqType> &req,
-                                      const std::vector<NDArray> &outputs) {
+                                      const OpContext& ctx,
+                                      const TBlob& weight,
+                                      const NDArray& grad,
+                                      const TBlob& mom,
+                                      const OpReqType& req,
+                                      TBlob *out) {
   using namespace mxnet_op;
+  using namespace rowsparse;
   Stream<xpu>* s = ctx.get_stream<xpu>();
-  auto &weight = inputs[0];
-  auto &grad = inputs[1];
-  auto &mom = inputs[2];
-  auto &out = outputs[0];
-  if (!grad.storage_initialized()) return;
+  if (!grad.storage_initialized() || req == kNullOp) return;
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mom.shape_.Size(), 0);
 
-  MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
-    MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        auto weight_data = weight.data().FlatTo2D<xpu, DType>(s);
-        auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D<xpu, IType>(s);
-        auto grad_val = grad.data().FlatTo2D<xpu, DType>(s);
-        auto mom_data = mom.data().FlatTo2D<xpu, DType>(s);
-        auto out_data = out.data().FlatTo2D<xpu, DType>(s);
-        auto num_rows = grad.aux_shape(rowsparse::kIdx)[0];
-        auto width = weight.shape().ProdShape(1, weight.shape().ndim());
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_INT_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        auto weight_data = weight.dptr<DType>();
+        auto grad_idx = grad.aux_data(kIdx).dptr<IType>();
+        auto grad_val = grad.data().dptr<DType>();
+        auto mom_data = mom.dptr<DType>();
+        auto out_data = out->dptr<DType>();
+        auto num_rows = grad.aux_shape(kIdx)[0];
+        auto width = weight.shape_.ProdShape(1, weight.ndim());
         Kernel<SGDMomDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, width,
-          out_data.dptr_, mom_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_,
+          out_data, mom_data, weight_data, grad_idx, grad_val,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
          static_cast<DType>(param.rescale_grad));
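
Aside: as before, a minimal standalone sketch of the per-row update that SGDMomDnsRspDnsKernel is launched to perform above, again touching only the rows named in the gradient's kIdx array. The momentum formula is assumed to follow MXNet's usual sgd_mom_update (v = momentum*v - lr*(g + wd*w); w += v) and is illustrative, not copied from the kernel; all names are hypothetical.

    // Illustrative sketch only (not the MXNet kernel): SGD with momentum applied
    // to the dense weight/momentum rows named by a row-sparse gradient.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    void sgd_mom_dns_rsp_sketch(std::vector<float>* weight, std::vector<float>* mom,
                                size_t width,
                                const std::vector<int64_t>& grad_idx,
                                const std::vector<float>& grad_val,
                                float clip_gradient, float momentum,
                                float lr, float wd, float rescale_grad) {
      for (size_t i = 0; i < grad_idx.size(); ++i) {
        for (size_t j = 0; j < width; ++j) {
          const size_t w_off = grad_idx[i] * width + j;   // offset in weight/momentum
          const size_t g_off = i * width + j;             // offset in compact gradient
          float g = rescale_grad * grad_val[g_off];
          if (clip_gradient >= 0.0f) {
            g = std::max(-clip_gradient, std::min(g, clip_gradient));
          }
          // assumed update: v = momentum*v - lr*(g + wd*w);  w = w + v
          (*mom)[w_off] = momentum * (*mom)[w_off]
                          - lr * (g + wd * (*weight)[w_off]);
          (*weight)[w_off] += (*mom)[w_off];
        }
      }
    }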
@@ -294,6 +321,50 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
   });
 }
 
+template<typename xpu>
+inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
+                                      const OpContext& ctx,
+                                      const NDArray& weight,
+                                      const NDArray& grad,
+                                      const NDArray& mom,
+                                      const OpReqType& req,
+                                      NDArray *out) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  if (weight.storage_shape()[0] == weight.shape()[0] &&
+      out->storage_shape()[0] == out->shape()[0]) {
+    Stream<xpu>* s = ctx.get_stream<xpu>();
+    // fill mom with zero values in order to reuse the sgd mom dns impl
+    if (!mom.storage_initialized()) {
+      MSHADOW_REAL_TYPE_SWITCH(mom.dtype(), DType, {
+        MSHADOW_INT_TYPE_SWITCH(mom.aux_type(kIdx), IType, {
+          auto num_rows = mom.shape()[0];
+          mom.CheckAndAlloc({Shape1(num_rows)});
+          auto mom_idx = mom.aux_data(kIdx).FlatTo1D<xpu, IType>(s);
+          auto mom_val = mom.data();
+          // TODO(haibin) this is single-thread execution
+          Kernel<set_zero, xpu>::Launch(s, mom_val.Size(), mom_val.dptr<DType>());
+          ASSIGN_DISPATCH(mom_idx, kWriteTo, range<IType>(0, num_rows, 1, 1))
+        });
+      });
+    }
+    // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
+    // feed in kWriteTo as req for all operators.
+    // For sgd we don't want to assign zeros to the output values when req == kWriteTo
+    auto out_req = req;
+    if (out_req == kWriteTo) out_req = kWriteInplace;
+    TBlob out_blob = out->data();
+    // reuse dns rsp implementation when storage_shape == shape
+    SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+                                   mom.data(), out_req, &out_blob);
+  } else {
+    LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when "
+               << "weights.values.shape == weights.shape";
+  }
+}
+
 template<typename xpu>
 inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
                            const OpContext &ctx,
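
Aside: when the row-sparse momentum state has never been written, SGDMomUpdateRspRspRspImpl above materializes it as "all rows present, all values zero" (CheckAndAlloc, a set_zero kernel over the values, and a range over the indices) before handing it to the dense implementation. A toy standalone sketch of that initialization, with hypothetical names:

    // Illustrative only: lazily materialize an empty row-sparse momentum state
    // as "indices 0..num_rows-1, all values zero".
    #include <cstdint>
    #include <numeric>
    #include <vector>

    struct ToyRowSparseMom {
      std::vector<int64_t> idx;    // stored row indices
      std::vector<float> values;   // stored rows, row-major
    };

    void init_full_zero_mom(ToyRowSparseMom* mom, int64_t num_rows, int64_t width) {
      if (!mom->idx.empty()) return;                    // already initialized
      mom->idx.resize(num_rows);
      std::iota(mom->idx.begin(), mom->idx.end(), 0);   // row indices 0..num_rows-1
      mom->values.assign(num_rows * width, 0.0f);       // zero momentum values
    }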
@@ -305,10 +376,16 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
   auto weight_stype = inputs[0].storage_type();
   auto grad_stype = inputs[1].storage_type();
   auto mom_stype = inputs[2].storage_type();
-
   if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage &&
       mom_stype == kDefaultStorage) {
-    SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs, req, outputs);
+    TBlob out = outputs[0].data();
+    SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
+                                   inputs[2].data(), req[0], &out);
+  } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
+             mom_stype == kRowSparseStorage) {
+    NDArray out = outputs[0];
+    SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1],
+                                   inputs[2], req[0], &out);
   } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage &&
              mom_stype == kDefaultStorage) {
     FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs,