diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h index 22e9abd156a3..08d91af6fbb3 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -89,6 +89,14 @@ class MKLDNNPoolingBwd { const mkldnn::pooling_backward::primitive_desc &GetPd(); }; +inline int GetPaddingSizeFull(dim_t x, int padl, int padr, int k, int s) { + if ((x + padl + padr - k) % s != 0) { + return (padr + s - ((x + padl + padr - k) % s)); + } else { + return padr; + } +} + inline bool SupportMKLDNNPooling(const PoolingParam ¶m) { return param.kernel.ndim() == 2 && (param.pool_type == pool_enum::kMaxPooling || @@ -105,7 +113,17 @@ inline bool SupportMKLDNNPooling(const PoolingParam ¶m, if (param.pooling_convention == pool_enum::kValid) { return true; } else { - // currently, only max-pooling is supported for full convention + if (param.pool_type == pool_enum::kAvgPooling) { + CHECK_EQ(dshape.ndim(), 4); + // mkldnn works differently when padding is asymmetric, so let's skip this case. + if (param.pad[0] == GetPaddingSizeFull(dshape[2], param.pad[0], param.pad[0], param.kernel[0], + param.stride[0]) && + param.pad[1] == GetPaddingSizeFull(dshape[3], param.pad[1], param.pad[1], param.kernel[1], + param.stride[1])) { + return true; + } + return false; + } return param.pool_type == pool_enum::kMaxPooling; } } diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index 6eda2aa33b34..d2f79700051a 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -127,13 +127,6 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam ¶m) { } } -static inline int GetPaddingSizeFull(dim_t x, int padl, int padr, int k, int s) { - if ((x + padl + padr - k) % s != 0) { - return (padr + s - ((x + padl + padr - k) % s)); - } else { - return padr; - } -} mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc( const PoolingParam ¶m, const bool is_train, const mkldnn::memory::desc &data_md, diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 485fc1345dfd..f998c33d16a8 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -278,9 +278,8 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx, return; } - - if (SupportMKLDNN(inputs[0]) && - SupportMKLDNNPooling(param, inputs[0].shape())) { + if (SupportMKLDNN(inputs[0]) + && SupportMKLDNNPooling(param, inputs[0].shape())) { if (MKLDNNRequireWorkspace(param)) { CHECK_GT(outputs.size(), 1U); workspace = &outputs[1];