Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
yury-intel committed Jul 30, 2021
1 parent efbe4a4 commit e6af0f8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 20 deletions.
49 changes: 29 additions & 20 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,12 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;

const int simd_w = mayiuse(cpu::x64::avx512_common) ? 16 : 8;
if (group != 1 && (((getParentEdgeAt(0)->getDims()[1] / group) % simd_w != 0)
|| ((getChildEdgeAt(0)->getDims()[1] / group) % simd_w != 0))) {
enforceRef = true;
}

size_t inputsNumber = getOriginalInputsNumber();
InferenceEngine::LayerConfig config;
config.dynBatchSupport = false;
Expand All @@ -986,19 +992,20 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
config.outConfs[0].inPlace = -1;

impl_desc_type impl_type;
// if (mayiuse(cpu::x64::avx512_common)) {
// impl_type = impl_desc_type::jit_avx512;
// } else if (mayiuse(cpu::x64::avx2)) {
// impl_type = impl_desc_type::jit_avx2;
// } else if (mayiuse(cpu::x64::sse41)) {
// impl_type = impl_desc_type::jit_sse42;
// } else {
// impl_type = impl_desc_type::ref;
// }
impl_type = impl_desc_type::ref;

if (false && mayiuse(cpu::x64::sse41)) {
// optimzed implementation
if (enforceRef) {
impl_type = impl_desc_type::ref;
} else if (mayiuse(cpu::x64::avx512_common)) {
impl_type = impl_desc_type::jit_avx512;
} else if (mayiuse(cpu::x64::avx2)) {
impl_type = impl_desc_type::jit_avx2;
} else if (mayiuse(cpu::x64::sse41)) {
impl_type = impl_desc_type::jit_sse42;
} else {
impl_type = impl_desc_type::ref;
}

if (!enforceRef && mayiuse(cpu::x64::sse41)) {
// optimized implementation
auto dataFormat = memory::format_tag::nhwc;
auto offFormat = memory::format_tag::nchw;
auto weiFormat = group > 1 ? mayiuse(avx512_common) ? memory::format_tag::gOIhw16i16o : memory::format_tag::gOIhw8i8o
Expand Down Expand Up @@ -1097,13 +1104,15 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() {

jcp.nthr = dnnl_get_max_threads();

// if (mayiuse(cpu::x64::avx512_common)) {
// def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx512_common>(jcp));
// } else if (mayiuse(cpu::x64::avx2)) {
// def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx2>(jcp));
// } else if (mayiuse(cpu::x64::sse41)) {
// def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::sse41>(jcp));
// }
if (enforceRef) {
return;
} else if (mayiuse(cpu::x64::avx512_common)) {
def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx512_common>(jcp));
} else if (mayiuse(cpu::x64::avx2)) {
def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx2>(jcp));
} else if (mayiuse(cpu::x64::sse41)) {
def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::sse41>(jcp));
}

if (def_conv_kernel)
def_conv_kernel->create_ker();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class MKLDNNDeformableConvolutionNode : public MKLDNNNode {
bool canBeInPlace() const override {
return false;
}
bool enforceRef = false;

InferenceEngine::Precision getRuntimePrecision() const override;

Expand Down

0 comments on commit e6af0f8

Please sign in to comment.