Merge pull request #1225 from pytorch/fix_collection_partitioning
fix: fix the error where collection inputs were segmented into a TRT subgraph
narendasan authored Aug 3, 2022
2 parents 9bce034 + 6d0b1d3 commit 253b3c7
Showing 1 changed file with 17 additions and 3 deletions.
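
For context, the bug concerns TorchScript graphs whose top-level inputs or outputs are collections (tuples or lists of Tensors) rather than plain Tensors. A minimal sketch of such a module follows; the module and its names are illustrative and are not taken from the Torch-TensorRT test suite.

    # Illustrative only: a module whose forward consumes and produces a tuple of
    # Tensors. Its TorchScript graph therefore has non-Tensor (collection) inputs
    # and outputs, which the partitioner must keep in a PyTorch fallback segment
    # rather than a TensorRT segment.
    from typing import Tuple

    import torch
    import torch.nn as nn

    class TupleIOModule(nn.Module):
        def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
            a, b = xs
            return a + b, a - b

    scripted = torch.jit.script(TupleIOModule())
    print(scripted.graph)  # the graph input is a Tuple[Tensor, Tensor], not a Tensor
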
core/partitioning/partitioning.cpp: 20 changes (17 additions, 3 deletions)
@@ -90,14 +90,26 @@ std::vector<torch::jit::Node*> getDependencyNodes(
   return stk;
 }
 
-void find_nontensor_output_nodes(
+// check whether the graph's inputs and outputs are Tensors when collection support is enabled; if they are not,
+// fall back the related nodes
+void fallback_graph_nontensor_in_out(
     torch::jit::Block* block,
     std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
+  // fallback nodes that produce entire graph's nonTensor output
   for (auto i : block->outputs()) {
     if (!isTensor(i)) {
       global_fallback_nodes.insert({i->node(), FallbackNodeType::kNON_TENSOR});
     }
   }
+
+  // fallback nodes that consume entire graph's nonTensor input
+  for (auto i : block->inputs()) {
+    if (!isTensor(i)) {
+      for (auto use : i->uses()) {
+        global_fallback_nodes.insert({use.user, FallbackNodeType::kNON_TENSOR});
+      }
+    }
+  }
 }
 
 void find_all_fallback_nodes(
@@ -202,6 +214,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
       }
     }
   }
+
   std::for_each(segmented_blocks.begin(), segmented_blocks.end(), [](SegmentedBlock& seg_block) {
     torch::jit::EliminateDeadCode(seg_block.g());
   });
@@ -440,8 +453,9 @@ PartitionedGraph Partition(
     const PartitionInfo& partition_info,
     std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
   LOG_DEBUG(partition_info);
-  // if there is nonTensor output for the entire graph, fallback the node that produces this nonTensor output
-  find_nontensor_output_nodes(block, global_fallback_nodes);
+  // if there is nonTensor input/output for the entire graph, fallback the node that consumes/produces this nonTensor
+  // output
+  fallback_graph_nontensor_in_out(block, global_fallback_nodes);
 
   // segment lowering global graph into blocks
   LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
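
With the added loop over block->inputs(), any node that consumes a non-Tensor graph input (for example a prim::TupleUnpack fed by a tuple argument) is marked FallbackNodeType::kNON_TENSOR, so it stays in a PyTorch segment instead of being pulled into a TensorRT engine. A hedged usage sketch follows; the input_signature argument and its nesting are assumptions about the collection-input API this fix supports, so check the release documentation before relying on them.

    # Hedged sketch only: compile a module with a tuple (collection) input through
    # the TorchScript frontend. The input_signature layout here is an assumption.
    from typing import Tuple

    import torch
    import torch.nn as nn
    import torch_tensorrt

    class TupleIOModule(nn.Module):
        def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
            a, b = xs
            return a + b, a - b

    scripted = torch.jit.script(TupleIOModule().eval().cuda())
    trt_mod = torch_tensorrt.ts.compile(
        scripted,
        # assumed layout: one entry per forward() argument, with the tuple argument
        # expressed as a tuple of Input specs
        input_signature=((torch_tensorrt.Input(shape=[2, 3]), torch_tensorrt.Input(shape=[2, 3])),),
    )
    # After this fix, the tuple-handling nodes fall back to PyTorch and only the
    # Tensor arithmetic is offered to TensorRT.
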
