Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
batchnorm fall back.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Jun 28, 2018
1 parent 1b69e4b commit 97b4cdd
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 24 deletions.
28 changes: 4 additions & 24 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,37 +453,17 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 5);
CHECK_EQ(out_attrs->size(), 3);
DispatchMode wanted_mode;
#if MXNET_USE_MKLDNN == 1
if (dev_mask == mshadow::cpu::kDevMask)
wanted_mode = DispatchMode::kFComputeEx;
else
#endif
wanted_mode = DispatchMode::kFCompute;
for (int& v : *in_attrs) {
if (v == - 1) v = kDefaultStorage;
}
return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, wanted_mode);
return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
in_attrs, out_attrs);
}

static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs,
const int dev_mask,
DispatchMode *dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DispatchMode wanted_mode;
#if MXNET_USE_MKLDNN == 1
if (dev_mask == mshadow::cpu::kDevMask)
wanted_mode = DispatchMode::kFComputeEx;
else
#endif
wanted_mode = DispatchMode::kFCompute;
for (int& v : *in_attrs) {
if (v == - 1) v = kDefaultStorage;
}
return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, wanted_mode);
return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
in_attrs, out_attrs);
}

std::vector<nnvm::NodeEntry> BatchNormGrad(const nnvm::NodePtr& n,
Expand Down
7 changes: 7 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,13 @@ class OpCheck {
const std::vector<mxnet::NDArray> &outputs_);
};

bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
const int dev_mask,
bool support_mkldnn,
DispatchMode *dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);

#define MKLDNN_OPCHECK_INIT(backward, num_checks, inputs, outputs) \
static bool debug = dmlc::GetEnv("MXNET_MKLDNN_DEBUG", false); \
OpCheck check(backward, num_checks); \
Expand Down
29 changes: 29 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <atomic>
#include "./mkldnn_base-inl.h"
#include "./mkldnn_ops-inl.h"
#include "../../operator_common.h"

namespace mxnet {

Expand Down Expand Up @@ -506,6 +507,34 @@ void OpCheck::Run(mxnet::FCompute fn, const nnvm::NodeAttrs &attrs,
}
}

bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
const int dev_mask,
bool support_mkldnn,
DispatchMode *dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
for (int& v : *in_attrs)
if (v == - 1) v = kDefaultStorage;

DispatchMode wanted_mode;
#if MXNET_USE_MKLDNN == 1
if (dev_mask == mshadow::cpu::kDevMask && support_mkldnn)
wanted_mode = DispatchMode::kFComputeEx;
else
#endif
wanted_mode = DispatchMode::kFCompute;

bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
dispatched = op::storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, wanted_mode);
}
if (!dispatched) {
dispatched = op::dispatch_fallback(out_attrs, dispatch_mode);
}
return dispatched;
}

} // namespace mxnet

#endif

0 comments on commit 97b4cdd

Please sign in to comment.