Skip to content

Commit

Permalink
[MKLDNN] NDArray reorder in C API and deconv (apache#16265)
Browse files Browse the repository at this point in the history
* layout conversion in MXNDArrayGetData

* deconv: layout conversion for grad

* fix lint complain
  • Loading branch information
TaoLv authored and drivanov committed Sep 26, 2019
1 parent d3ce70b commit 72c8b88
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
7 changes: 7 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,13 @@ int MXNDArrayGetData(NDArrayHandle handle,
void **out_pdata) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
#if MXNET_USE_MKLDNN == 1
NDArray temp = *arr;
if (arr->IsMKLDNNData()) {
temp = arr->Reorder2Default();
arr = &temp;
}
#endif
if (!arr->is_none()) {
*out_pdata = arr->data().dptr_;
} else {
Expand Down
11 changes: 7 additions & 4 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -559,10 +559,13 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs,
Stream<cpu> *s = ctx.get_stream<cpu>();
Tensor<cpu, 1, DType> gbias =
in_grad[deconv::kBias].data().get<cpu, 1, DType>(s);
// If there is bias, the out grad has already been converted to the default
// format, so this shouldn't cause any performance issues.
Tensor<cpu, 4, DType> grad =
inputs[deconv::kOut].data().get<cpu, 4, DType>(s);

NDArray temp = inputs[deconv::kOut];
if (temp.IsMKLDNNData()) {
temp = temp.Reorder2Default();
}

Tensor<cpu, 4, DType> grad = temp.data().get<cpu, 4, DType>(s);
Assign(gbias, req[deconv::kBias],
mshadow::expr::sumall_except_dim<1>(grad));
}
Expand Down

0 comments on commit 72c8b88

Please sign in to comment.