Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[mkldnn-v1.0]set fc weight layout as mkldnn v0.2x did (#16593)
Browse files Browse the repository at this point in the history
* set fc weight layout as mkldnn v0.2x did

* fix lint
  • Loading branch information
rongzha1 authored and pengzhao-intel committed Oct 25, 2019
1 parent 29f2e32 commit c23568f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
14 changes: 14 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,20 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int dtype = -1
return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
}

inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray &arr, int dtype = -1) {
  // Build an MKL-DNN memory descriptor for a fully-connected weight array.
  // A 2-D weight is pinned to the plain row-major "ab" layout (matching the
  // behaviour of mkldnn v0.2x) instead of letting the library pick one;
  // any other rank falls back to format_tag::any.
  const int effective_dtype = (dtype == -1) ? arr.dtype() : dtype;
  const int ndim = arr.shape().ndim();
  mkldnn::memory::dims dims(ndim);
  for (int i = 0; i < ndim; ++i) {
    dims[i] = arr.shape()[i];
  }
  // for batch 256 alexnet benchmark test
  const auto format = (ndim == 2) ? mkldnn::memory::format_tag::ab
                                  : mkldnn::memory::format_tag::any;
  return mkldnn::memory::desc{dims, get_mkldnn_type(effective_dtype), format};
}

inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
int num_groups,
bool quantized = false) {
Expand Down
9 changes: 4 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
const NDArray &data, const NDArray &weight, const NDArray *bias,
const mkldnn::memory::desc &out_md) {
auto data_md = GetMemDesc(data);
auto weight_md = GetMemDesc(weight);
auto weight_md = GetFCWeightDesc(weight);
auto engine = CpuEngine::Get()->get_engine();
auto propagation =
is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
Expand Down Expand Up @@ -101,7 +101,7 @@ inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
const NDArray &data, const NDArray &weight, const NDArray &output,
mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
auto weight_md = GetMemDesc(weight);
auto weight_md = GetFCWeightDesc(weight);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
Expand All @@ -112,7 +112,7 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
const NDArray &data, const NDArray &weight, const NDArray *bias,
const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
auto weight_md = GetMemDesc(weight);
auto weight_md = GetFCWeightDesc(weight);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
if (bias) {
Expand Down Expand Up @@ -208,8 +208,7 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
} else {
weight_mem = weight.GetMKLDNNData();
if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) {
// TODO(rongzha1): rm following line for ut:test_contrib_rnn, need debug
// weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc());
weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc());
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
}
}
Expand Down

0 comments on commit c23568f

Please sign in to comment.