From 0241cf950aa96eacfd12f542a54871231a1c75cb Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Wed, 23 Oct 2019 22:23:33 +0800 Subject: [PATCH 1/2] set fc weight layout as mkldnn v0.2x did --- src/operator/nn/mkldnn/mkldnn_base-inl.h | 14 ++++++++++++++ src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 9 ++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 0fa99e1ed3f4..ef5ecc02dae4 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -274,6 +274,20 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int dtype = -1 return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any}; } +inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray &arr, int dtype = -1) { + int ndim = arr.shape().ndim(); + mkldnn::memory::dims dims(ndim); + dtype = (dtype == -1) ? arr.dtype() : dtype; + for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i]; + auto format = mkldnn::memory::format_tag::any; + // for batch 256 alexnet benchmark test + if (dims.size() == 2) { + format = mkldnn::memory::format_tag::ab; + } + + return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), format}; +} + inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr, int num_groups, bool quantized = false) { diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index e5d4a80bdfcf..1e7f879c5322 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -37,7 +37,7 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl( const NDArray &data, const NDArray &weight, const NDArray *bias, const mkldnn::memory::desc &out_md) { auto data_md = GetMemDesc(data); - auto weight_md = GetMemDesc(weight); + auto weight_md = GetFCWeightDesc(weight); auto engine = CpuEngine::Get()->get_engine(); auto propagation = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; @@ -101,7 +101,7 @@ inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData( const NDArray &data, const NDArray &weight, const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) { auto data_md = GetMemDesc(data); - auto weight_md = GetMemDesc(weight); + auto weight_md = GetFCWeightDesc(weight); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md); @@ -112,7 +112,7 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei const NDArray &data, const NDArray &weight, const NDArray *bias, const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) { auto data_md = GetMemDesc(data); - auto weight_md = GetMemDesc(weight); + auto weight_md = GetFCWeightDesc(weight); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); if (bias) { @@ -208,8 +208,7 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param, } else { weight_mem = weight.GetMKLDNNData(); if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) { - // TODO(rongzha1): rm following line for ut:test_contrib_rnn, need debug - // weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc()); + weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc()); weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1); } } From 937d7b946d42a8f6b736ff2f18495c26c2874dbb Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Wed, 23 Oct 2019 22:46:05 +0800 Subject: [PATCH 2/2] fix lint --- src/operator/nn/mkldnn/mkldnn_base-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index ef5ecc02dae4..0f371d174e40 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -280,7 +280,7 @@ inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray &arr, int dtype dtype = (dtype == -1) ? arr.dtype() : dtype; for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i]; auto format = mkldnn::memory::format_tag::any; - // for batch 256 alexnet benchmark test + // for batch 256 alexnet benchmark test if (dims.size() == 2) { format = mkldnn::memory::format_tag::ab; }