Skip to content

Commit

Permalink
Fix (apache#16781)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrendx committed Nov 15, 2019
1 parent 73f8a22 commit ade0682
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions src/executor/pointwise_fusion_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,7 @@ namespace {
auto node = nnvm::Node::Create();
subgraph_sym.outputs = subgraph.outputs;
node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
std::ostringstream name_oss;
// the name of the new node will be the concatenation of all the node names in the subgraph
DFSVisit(subgraph.outputs, [&name_oss](const nnvm::NodePtr n) {
if (n->op() != nullptr)
name_oss << n->op()->name << "_";
});
auto subgraph_name = name_oss.str();
subgraph_name.pop_back();
node->attrs.name = subgraph_name;
node->attrs.name = "FusedOp";
node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
node->attrs.op = Op::Get("_FusedOp");
Expand Down Expand Up @@ -152,16 +144,16 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& sub
auto it = node->control_deps.begin();
static auto& is_fusion = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
std::vector<nnvm::NodePtr> new_control_deps;
while (it != node->control_deps.end()) {
// Use the first control dependency to get the inferattr helper
if (it != node->control_deps.end()) {
if (subgraph_set.count(it->get())) {
new_control_deps.push_back(*it);
} else {
if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
uint32_t node_id = subgraph_node->control_deps.size();
subgraph_node->control_deps.push_back(*it);
auto helper_node = op::MakeNode("_FusedOpOutHelper",
subgraph_node->attrs.name + "_"
+ node->attrs.name + "_outhelper",
"FusedOp_" + node->attrs.name + "_outhelper",
nullptr,
nullptr,
nullptr);
Expand All @@ -180,6 +172,17 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& sub
}
});

std::ostringstream name_oss;
// the name of the new node will be the concatenation of all the node names in the subgraph
DFSVisit(subgraph.outputs, [&name_oss](const nnvm::NodePtr n) {
if (n->op() != nullptr) {
name_oss << n->op()->name << "_";
}
});
auto subgraph_name = name_oss.str();
subgraph_name.pop_back();
subgraph_node->attrs.name = subgraph_name;

const auto& index = subgraph.indexed_graph();
DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::NodePtr& node) {
for (auto &e : node->control_deps) {
Expand Down

0 comments on commit ade0682

Please sign in to comment.