Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add disable attr to subgraph property #15926

Merged
merged 1 commit into from
Aug 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/tutorials/c++/subgraphAPI.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ class SgProperty : public SubgraphProperty {
};
```
`SetAttr` is optional and developer can define their own attributes to control property behavior.
There're 2 built-in attributes that used by MXNet executor.
There're some built-in attributes that used by MXNet executor.

`property_name` : std::string, name of this property.
`property_name` : std::string, name of this property, used for diagnose.

`disable` : bool, whther to disable this property.

`inference_only` : bool, apply this property only for inference. Property will be skiped when need_grad=True. Default `false` if this attribute isn't defined.

Expand Down
8 changes: 8 additions & 0 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,14 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name,
auto backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name);
const auto& subgraph_prop_list = backend->GetSubgraphProperties();
for (auto property : subgraph_prop_list) {
if (property->HasAttr("disable") && property->GetAttr<bool>("disable") == true) {
auto full_name = property->HasAttr("property_name")
? property->GetAttr<std::string>("property_name")
: std::string();
LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name
<< " is disabled.";
continue;
}
nnvm::Graph g = Symbol2Graph(*s);
property->SetAttr("graph", g);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
Expand Down
9 changes: 7 additions & 2 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1659,10 +1659,15 @@ static bool SubgraphBackendCheck(const op::SubgraphBackendPtr& backend,
static bool SubgraphPropertyCheck(const std::string& backend_name,
const op::SubgraphPropertyPtr& prop, bool need_grad,
bool verbose = false) {
auto full_name =
prop->HasAttr("property_name") ? prop->GetAttr<std::string>("property_name") : std::string();
if (prop->HasAttr("disable") && prop->GetAttr<bool>("disable") == true) {
LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name
<< " is disabled.";
return false;
}
if (prop->HasAttr("inference_only") && prop->GetAttr<bool>("inference_only") == true) {
if (need_grad) {
auto full_name = prop->HasAttr("property_name") ? prop->GetAttr<std::string>("property_name")
: std::string();
if (verbose) {
LOG(INFO) << "skip partitioning graph with subgraph property " << full_name
<< " from backend " << backend_name << " as it requires `grad_req=null`.";
Expand Down
7 changes: 3 additions & 4 deletions src/operator/subgraph/mkldnn/mkldnn_conv_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,12 @@ class SgMKLDNNConvProperty : public SubgraphProperty {
}
static SubgraphPropertyPtr Create() {
static const std::string &name = "MKLDNN convolution optimization pass";
if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_CONV_OPT", 0)) {
LOG(INFO) << name << " is disabled.";
return nullptr;
}
auto property = std::make_shared<SgMKLDNNConvProperty>();
property->SetAttr<std::string>("property_name", name);
property->SetAttr<bool>("inference_only", true);
if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_CONV_OPT", 0)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to handle other MXNET_DISABLE_MKLDNN_XXX environments?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just tested all env variables listed here, and no crash is found.

property->SetAttr<bool>("disable", true);
}
return property;
}
nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
Expand Down
7 changes: 3 additions & 4 deletions src/operator/subgraph/mkldnn/mkldnn_fc_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,12 @@ class SgMKLDNNFCProperty : public SubgraphProperty {

static SubgraphPropertyPtr Create() {
static const std::string &name = "MKLDNN FullyConnected optimization pass";
if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FC_OPT", 0)) {
LOG(INFO) << name << " is disabled.";
return nullptr;
}
auto property = std::make_shared<SgMKLDNNFCProperty>();
property->SetAttr<std::string>("property_name", name);
property->SetAttr<bool>("inference_only", true);
if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FC_OPT", 0)) {
property->SetAttr<bool>("disable", true);
}
return property;
}

Expand Down
11 changes: 7 additions & 4 deletions src/operator/subgraph/subgraph_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class SubgraphPropertyEntry {

template<typename T>
SubgraphPropertyEntry set_attr(const std::string& name, const T value) const {
entry_->SetAttr<T>(name, value);
if (entry_) entry_->SetAttr<T>(name, value);
return *this;
}

Expand Down Expand Up @@ -403,9 +403,12 @@ class SubgraphBackend {
}
}

SubgraphPropertyPtr& RegisterSubgraphProperty(const SubgraphPropertyPtr prop) {
prop_ptr_.push_back(prop);
return prop_ptr_.back();
SubgraphPropertyPtr RegisterSubgraphProperty(SubgraphPropertyPtr prop) {
if (prop) {
prop_ptr_.push_back(prop);
return prop_ptr_.back();
}
return prop;
}

const std::string& GetName() const { return name_; }
Expand Down