From 36384c82ede3f457dfee3b20c6768ddcfc1ecf97 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 16 Aug 2019 10:56:08 +0800 Subject: [PATCH] Add disable attr to subgraph property --- docs/tutorials/c++/subgraphAPI.md | 6 ++++-- src/c_api/c_api_symbolic.cc | 8 ++++++++ src/executor/graph_executor.cc | 9 +++++++-- src/operator/subgraph/mkldnn/mkldnn_conv_property.h | 7 +++---- src/operator/subgraph/mkldnn/mkldnn_fc_property.h | 7 +++---- src/operator/subgraph/subgraph_property.h | 11 +++++++---- 6 files changed, 32 insertions(+), 16 deletions(-) diff --git a/docs/tutorials/c++/subgraphAPI.md b/docs/tutorials/c++/subgraphAPI.md index 7403e2654423..91d24bd0319a 100644 --- a/docs/tutorials/c++/subgraphAPI.md +++ b/docs/tutorials/c++/subgraphAPI.md @@ -105,9 +105,11 @@ class SgProperty : public SubgraphProperty { }; ``` `SetAttr` is optional and developer can define their own attributes to control property behavior. -There're 2 built-in attributes that used by MXNet executor. +There're some built-in attributes that used by MXNet executor. -`property_name` : std::string, name of this property. +`property_name` : std::string, name of this property, used for diagnose. + +`disable` : bool, whther to disable this property. `inference_only` : bool, apply this property only for inference. Property will be skiped when need_grad=True. Default `false` if this attribute isn't defined. diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 020c0d17f0d1..f2f8e87009fe 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -1044,6 +1044,14 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name, auto backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name); const auto& subgraph_prop_list = backend->GetSubgraphProperties(); for (auto property : subgraph_prop_list) { + if (property->HasAttr("disable") && property->GetAttr("disable") == true) { + auto full_name = property->HasAttr("property_name") + ? property->GetAttr("property_name") + : std::string(); + LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name + << " is disabled."; + continue; + } nnvm::Graph g = Symbol2Graph(*s); property->SetAttr("graph", g); g.attrs["subgraph_property"] = std::make_shared(property); diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index c5da0a0cf90d..7bdeac708003 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1659,10 +1659,15 @@ static bool SubgraphBackendCheck(const op::SubgraphBackendPtr& backend, static bool SubgraphPropertyCheck(const std::string& backend_name, const op::SubgraphPropertyPtr& prop, bool need_grad, bool verbose = false) { + auto full_name = + prop->HasAttr("property_name") ? prop->GetAttr("property_name") : std::string(); + if (prop->HasAttr("disable") && prop->GetAttr("disable") == true) { + LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name + << " is disabled."; + return false; + } if (prop->HasAttr("inference_only") && prop->GetAttr("inference_only") == true) { if (need_grad) { - auto full_name = prop->HasAttr("property_name") ? prop->GetAttr("property_name") - : std::string(); if (verbose) { LOG(INFO) << "skip partitioning graph with subgraph property " << full_name << " from backend " << backend_name << " as it requires `grad_req=null`."; diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h index bf278ab75718..42ea9ea67fcf 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h @@ -177,13 +177,12 @@ class SgMKLDNNConvProperty : public SubgraphProperty { } static SubgraphPropertyPtr Create() { static const std::string &name = "MKLDNN convolution optimization pass"; - if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_CONV_OPT", 0)) { - LOG(INFO) << name << " is disabled."; - return nullptr; - } auto property = std::make_shared(); property->SetAttr("property_name", name); property->SetAttr("inference_only", true); + if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_CONV_OPT", 0)) { + property->SetAttr("disable", true); + } return property; } nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h index 28350c2f0e99..caedecc417c0 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h @@ -132,13 +132,12 @@ class SgMKLDNNFCProperty : public SubgraphProperty { static SubgraphPropertyPtr Create() { static const std::string &name = "MKLDNN FullyConnected optimization pass"; - if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FC_OPT", 0)) { - LOG(INFO) << name << " is disabled."; - return nullptr; - } auto property = std::make_shared(); property->SetAttr("property_name", name); property->SetAttr("inference_only", true); + if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FC_OPT", 0)) { + property->SetAttr("disable", true); + } return property; } diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index e2c861ef512a..b8b125fdb5de 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -357,7 +357,7 @@ class SubgraphPropertyEntry { template SubgraphPropertyEntry set_attr(const std::string& name, const T value) const { - entry_->SetAttr(name, value); + if (entry_) entry_->SetAttr(name, value); return *this; } @@ -403,9 +403,12 @@ class SubgraphBackend { } } - SubgraphPropertyPtr& RegisterSubgraphProperty(const SubgraphPropertyPtr prop) { - prop_ptr_.push_back(prop); - return prop_ptr_.back(); + SubgraphPropertyPtr RegisterSubgraphProperty(SubgraphPropertyPtr prop) { + if (prop) { + prop_ptr_.push_back(prop); + return prop_ptr_.back(); + } + return prop; } const std::string& GetName() const { return name_; }