Skip to content

Commit

Permalink
Merge pull request #817 from cyfwry/811
Browse files Browse the repository at this point in the history
Fix the bug that fallback does not support more than one output
  • Loading branch information
narendasan authored Jan 31, 2022
2 parents 68dd005 + a874e35 commit 726b031
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,21 @@ GraphAndMapping ConstructFallbackGraph(
}
}

for (auto& output : block->outputs()) {
if (old_to_new_g.count(output)) {
new_g->registerOutput(old_to_new_g[output]);
if (block->outputs().size() > 1) {
std::vector<torch::jit::Value*> fallback_graph_vector;
for (auto& output : block->outputs()) {
if (old_to_new_g.count(output)) {
fallback_graph_vector.push_back(old_to_new_g[output]);
}
}
torch::jit::ArrayRef<torch::jit::Value*> fallback_graph_outputs(fallback_graph_vector);
auto return_tuple_node = new_g->createTuple(fallback_graph_outputs);
new_g->block()->appendNode(return_tuple_node);
// Set the output as the produced tuple
new_g->registerOutput(return_tuple_node->outputs()[0]);
} else {
if (old_to_new_g.count(block->outputs()[0])) {
new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
}
}
return {new_g, old_to_new_g};
Expand Down

0 comments on commit 726b031

Please sign in to comment.