Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MKLDNN] enable MaxPooling with full pooling convention (#16860)
Browse files Browse the repository at this point in the history
* [MKLDNN] enable MaxPooling for full pooling convention

* Run CI

* Fix UT

* Add comment

* Run CI
  • Loading branch information
ZhennanQin authored and pengzhao-intel committed Dec 16, 2019
1 parent 897f4fa commit 52c9a45
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
20 changes: 19 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ class MKLDNNPoolingBwd {
const mkldnn::pooling_backward::primitive_desc &GetPd();
};

inline int GetPaddingSizeFull(dim_t x, int padl, int padr, int k, int s) {
if ((x + padl + padr - k) % s != 0) {
return (padr + s - ((x + padl + padr - k) % s));
} else {
return padr;
}
}

inline bool SupportMKLDNNPooling(const PoolingParam &param) {
return param.kernel.ndim() == 2 &&
(param.pool_type == pool_enum::kMaxPooling ||
Expand All @@ -105,7 +113,17 @@ inline bool SupportMKLDNNPooling(const PoolingParam &param,
if (param.pooling_convention == pool_enum::kValid) {
return true;
} else {
// currently, only max-pooling is supported for full convention
if (param.pool_type == pool_enum::kAvgPooling) {
CHECK_EQ(dshape.ndim(), 4);
// mkldnn works differently when padding is asymmetric, so let's skip this case.
if (param.pad[0] == GetPaddingSizeFull(dshape[2], param.pad[0], param.pad[0], param.kernel[0],
param.stride[0]) &&
param.pad[1] == GetPaddingSizeFull(dshape[3], param.pad[1], param.pad[1], param.kernel[1],
param.stride[1])) {
return true;
}
return false;
}
return param.pool_type == pool_enum::kMaxPooling;
}
}
Expand Down
7 changes: 0 additions & 7 deletions src/operator/nn/mkldnn/mkldnn_pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,6 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
}
}

static inline int GetPaddingSizeFull(dim_t x, int padl, int padr, int k, int s) {
if ((x + padl + padr - k) % s != 0) {
return (padr + s - ((x + padl + padr - k) % s));
} else {
return padr;
}
}

mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
const PoolingParam &param, const bool is_train, const mkldnn::memory::desc &data_md,
Expand Down
5 changes: 2 additions & 3 deletions src/operator/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,8 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
return;
}


if (SupportMKLDNN(inputs[0]) &&
SupportMKLDNNPooling(param, inputs[0].shape())) {
if (SupportMKLDNN(inputs[0])
&& SupportMKLDNNPooling(param, inputs[0].shape())) {
if (MKLDNNRequireWorkspace(param)) {
CHECK_GT(outputs.size(), 1U);
workspace = &outputs[1];
Expand Down

0 comments on commit 52c9a45

Please sign in to comment.