[mkldnn-v1.0] Add MKL-DNN int8 activation&pooling&flatten (#16425)
* Add mkldnn quantized activation/pooling/flatten

* int8 flatten
wuxun-zhang authored and pengzhao-intel committed Oct 11, 2019
1 parent 922b616 commit bc91f5b
Showing 6 changed files with 13 additions and 10 deletions.
5 changes: 5 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -91,6 +91,11 @@ void MKLDNNPoolingFwd::Execute(const NDArray &in_data,

   if (this->with_workspace_) {
     auto engine = CpuEngine::Get()->get_engine();
+
+    if (workspace == nullptr) {
+      LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
+    }
+
     auto ws = std::make_shared<mkldnn::memory>((*(this->fwd_pd_)).workspace_desc(),
                                                engine, workspace->GetMKLDNNData()->get_data_handle());
     args[MKLDNN_ARG_WORKSPACE] = *ws;
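For context, the workspace this hunk guards is the side buffer a max-pooling forward primitive emits so the backward pass can recover which element was the maximum. A minimal standalone sketch of that pattern against the MKL-DNN v1.0 C++ API follows; the shapes and names are illustrative assumptions, not code from this commit.

#include <mkldnn.hpp>

// Sketch: create and bind the workspace a max-pooling forward primitive needs
// (MKL-DNN v1.0 API). Illustrative shapes; error handling omitted.
void max_pool_with_workspace() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream strm(eng);

  mkldnn::memory::desc src_md({1, 8, 4, 4}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nchw);
  mkldnn::memory::desc dst_md({1, 8, 2, 2}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nchw);

  // forward_training is what makes workspace_desc() non-empty for max pooling.
  mkldnn::pooling_forward::desc desc(mkldnn::prop_kind::forward_training,
                                     mkldnn::algorithm::pooling_max,
                                     src_md, dst_md,
                                     /*strides=*/{2, 2}, /*kernel=*/{2, 2},
                                     /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0});
  mkldnn::pooling_forward::primitive_desc pd(desc, eng);

  mkldnn::memory src(src_md, eng);
  mkldnn::memory dst(dst_md, eng);
  mkldnn::memory ws(pd.workspace_desc(), eng);  // holds the argmax indices

  mkldnn::pooling_forward(pd).execute(strm, {{MKLDNN_ARG_SRC, src},
                                             {MKLDNN_ARG_DST, dst},
                                             {MKLDNN_ARG_WORKSPACE, ws}});
  strm.wait();
}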
2 changes: 1 addition & 1 deletion src/operator/quantization/mkldnn/mkldnn_quantized_act.cc
@@ -22,7 +22,7 @@
* \brief MKLDNN(Quantized) Activation operator based on subgraph
* /author Zhiyuan Huang
*/
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100

#include "../../nn/mkldnn/mkldnn_act-inl.h"
#include "../quantization_utils.h"
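Every guard in this commit flips from 1 to 100 for the same reason: MXNet encodes the MKL-DNN integration version in the MXNET_USE_MKLDNN macro, with 1 denoting the legacy 0.x integration and 100 denoting the v1.0 rewrite this branch targets. An illustrative (not verbatim) use of the pattern:

// Illustration only: selecting the MKL-DNN code path at compile time.
#if MXNET_USE_MKLDNN == 100
  // Compiled against the MKL-DNN v1.0 C++ API (argument-map execution, etc.).
#elif MXNET_USE_MKLDNN == 1
  // Compiled against the legacy MKL-DNN 0.x API.
#endif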
2 changes: 1 addition & 1 deletion (MKL-DNN quantized flatten source)
@@ -23,7 +23,7 @@
* \brief
*/

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include "../../nn/mkldnn/mkldnn_flatten-inl.h"
#include "../quantization_utils.h"

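The quantized flatten path behind this guard is conceptually a pure reshape: the int8 payload is untouched and the min/max calibration scalars pass straight through. A self-contained sketch of that idea, using hypothetical types rather than MXNet's NDArray API:

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container standing in for a quantized tensor plus its
// calibration range; not an MXNet type.
struct QuantizedTensor {
  std::vector<int8_t> data;
  std::vector<int64_t> shape;
  float min_range;
  float max_range;
};

QuantizedTensor QuantizedFlatten(const QuantizedTensor& in) {
  int64_t batch = in.shape.empty() ? 0 : in.shape[0];
  int64_t rest = 1;
  for (std::size_t i = 1; i < in.shape.size(); ++i) rest *= in.shape[i];

  QuantizedTensor out;
  out.data = in.data;            // no requantization: int8 values are copied as-is
  out.shape = {batch, rest};     // only the logical shape changes
  out.min_range = in.min_range;  // calibration range is forwarded unchanged
  out.max_range = in.max_range;
  return out;
}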
6 changes: 2 additions & 4 deletions src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
@@ -23,7 +23,7 @@
* \author Tao Lv, Xinyu Chen
*/

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100

#include "../../nn/mkldnn/mkldnn_pooling-inl.h"

@@ -38,9 +38,7 @@ static void MKLDNNQuantizedPoolingForward(const nnvm::NodeAttrs& attrs, const Op
|| in_data[0].dtype() == mshadow::kInt8)
<< "mkldnn_quantized_pooling op only supports uint8 and int8 as input type";
const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
-  auto fwd = GetPoolingFwd(param, ctx.is_train, in_data[0], out_data[0]);
-  fwd.SetNewMem(in_data[0], out_data[0], req[0]);
-  fwd.Execute(out_data[0]);
+  MKLDNNPoolingCompute(ctx, param, in_data[0], req[0], out_data[0], nullptr);
out_data[1].data().dptr<float>()[0] = in_data[1].data().dptr<float>()[0];
out_data[2].data().dptr<float>()[0] = in_data[2].data().dptr<float>()[0];
}
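The last two statements of this hunk copy the input's min/max calibration scalars to the outputs. That is safe because pooling never produces a value outside the input's range: max pooling selects existing int8 values and average pooling stays inside them, so the input's dequantization scale still describes the output. A small sketch under the (assumed) symmetric int8 convention, scale = max(|min|, |max|) / 127:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch: dequantize with a symmetric int8 scale (assumed convention).
float Dequantize(int8_t q, float min_range, float max_range) {
  float scale = std::max(std::fabs(min_range), std::fabs(max_range)) / 127.0f;
  return q * scale;
}

// Max pooling over int8 returns one of the existing quantized values, so the
// (min_range, max_range) pair that described the input still describes the
// output and can simply be forwarded, exactly as the operator above does.
int8_t MaxPoolWindow(const std::vector<int8_t>& window) {
  return *std::max_element(window.begin(), window.end());
}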
2 changes: 1 addition & 1 deletion src/operator/quantization/quantized_activation.cc
@@ -68,7 +68,7 @@ inline static bool QuantizedActivationStorageType(const nnvm::NodeAttrs &attrs,
CHECK_EQ(in_attrs->size(), 3);

*dispatch_mode = DispatchMode::kFCompute;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
const ActivationParam &param = nnvm::get<ActivationParam>(attrs.parsed);
if (dev_mask == mshadow::cpu::kDevMask && param.act_type == activation::kReLU) {
*dispatch_mode = DispatchMode::kFComputeEx;
6 changes: 3 additions & 3 deletions src/operator/quantization/quantized_pooling.cc
@@ -23,7 +23,7 @@
*/
#include <mxnet/op_attr_types.h>
#include "../nn/pooling-inl.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include "../nn/mkldnn/mkldnn_pooling-inl.h"
#endif

@@ -98,7 +98,7 @@ bool QuantizedPoolingType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_type->size(), 3U);
CHECK_EQ(out_type->size(), 3U);
if (param.pool_type == pool_enum::kMaxPooling || param.pool_type == pool_enum::kAvgPooling) {
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
TYPE_ASSIGN_CHECK(*out_type, 0, (*in_type)[0]);
#else
TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8);
@@ -122,7 +122,7 @@ inline static bool QuantizedPoolingStorageType(const nnvm::NodeAttrs &attrs,
CHECK_EQ(in_attrs->size(), 3);

*dispatch_mode = DispatchMode::kFCompute;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNPooling(param)) {
*dispatch_mode = DispatchMode::kFComputeEx;
