From 97b4cdded5e999af4d5fbc18dc4921e9864f1455 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 28 Jun 2018 13:32:44 -0700 Subject: [PATCH] batchnorm fall back. --- src/operator/nn/batch_norm.cc | 28 ++++------------------- src/operator/nn/mkldnn/mkldnn_base-inl.h | 7 ++++++ src/operator/nn/mkldnn/mkldnn_base.cc | 29 ++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 457f536d7fa0..dcd2e39e23ee 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -453,18 +453,8 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 5); CHECK_EQ(out_attrs->size(), 3); - DispatchMode wanted_mode; -#if MXNET_USE_MKLDNN == 1 - if (dev_mask == mshadow::cpu::kDevMask) - wanted_mode = DispatchMode::kFComputeEx; - else -#endif - wanted_mode = DispatchMode::kFCompute; - for (int& v : *in_attrs) { - if (v == - 1) v = kDefaultStorage; - } - return storage_type_assign(out_attrs, mxnet::kDefaultStorage, - dispatch_mode, wanted_mode); + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, + in_attrs, out_attrs); } static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs, @@ -472,18 +462,8 @@ static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs, DispatchMode *dispatch_mode, std::vector *in_attrs, std::vector *out_attrs) { - DispatchMode wanted_mode; -#if MXNET_USE_MKLDNN == 1 - if (dev_mask == mshadow::cpu::kDevMask) - wanted_mode = DispatchMode::kFComputeEx; - else -#endif - wanted_mode = DispatchMode::kFCompute; - for (int& v : *in_attrs) { - if (v == - 1) v = kDefaultStorage; - } - return storage_type_assign(out_attrs, mxnet::kDefaultStorage, - dispatch_mode, wanted_mode); + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, + in_attrs, out_attrs); } std::vector BatchNormGrad(const nnvm::NodePtr& n, diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index c6e7f9bdefdc..f77d113dd1d7 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -492,6 +492,13 @@ class OpCheck { const std::vector &outputs_); }; +bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + bool support_mkldnn, + DispatchMode *dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs); + #define MKLDNN_OPCHECK_INIT(backward, num_checks, inputs, outputs) \ static bool debug = dmlc::GetEnv("MXNET_MKLDNN_DEBUG", false); \ OpCheck check(backward, num_checks); \ diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 858f8e3261f2..2c8dea895823 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -22,6 +22,7 @@ #include #include "./mkldnn_base-inl.h" #include "./mkldnn_ops-inl.h" +#include "../../operator_common.h" namespace mxnet { @@ -506,6 +507,34 @@ void OpCheck::Run(mxnet::FCompute fn, const nnvm::NodeAttrs &attrs, } } +bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + bool support_mkldnn, + DispatchMode *dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + for (int& v : *in_attrs) + if (v == - 1) v = kDefaultStorage; + + DispatchMode wanted_mode; +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask && support_mkldnn) + wanted_mode = DispatchMode::kFComputeEx; + else +#endif + wanted_mode = DispatchMode::kFCompute; + + bool dispatched = false; + if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { + dispatched = op::storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, wanted_mode); + } + if (!dispatched) { + dispatched = op::dispatch_fallback(out_attrs, dispatch_mode); + } + return dispatched; +} + } // namespace mxnet #endif