feat: support re-segmentation of fallback nodes caused by min_block_size != 1 #1195

Merged: 4 commits, Jul 25, 2022
core/compiler.cpp (4 changes: 2 additions & 2 deletions)
@@ -240,7 +240,7 @@ GraphAndMapping ConstructFallbackGraph(
}

for (auto& seg_block : segmented_blocks) {
- LOG_INFO(*seg_block.g() << "(GraphInSegmentedBlock)\n");
+ LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
std::ostringstream trt_engine_id;
trt_engine_id << reinterpret_cast<const int*>(&seg_block);

@@ -436,7 +436,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
auto graph_and_mapping =
ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
new_g = graph_and_mapping.first;
LOG_INFO("Segmented Graph: " << *new_g);
LOG_INFO("Graph after Fallback: " << *new_g);

// if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
// module
core/partitioning/partitioning.cpp (78 changes: 68 additions & 10 deletions)
@@ -98,9 +98,13 @@ std::vector<torch::jit::Node*> getDependencyNodes(
return stk;
}

- void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+ void find_all_fallback_nodes(
+ std::unordered_map<torch::jit::Node*, int>& initial_fallback_nodes,
+ std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
// initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
// global_fallback_nodes are the fallback nodes that we maintain globally
std::queue<torch::jit::Node*> q;
- for (auto& node : fallback_nodes) {
+ for (auto& node : initial_fallback_nodes) {
q.push(node.first);
}

@@ -111,7 +115,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
// for every node that produces this fallback node's NonTensor input, they should fallback too
for (auto input : cur_node->inputs()) {
if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
- fallback_nodes.insert({input->node(), 4}).second) {
+ global_fallback_nodes.insert({input->node(), 4}).second) {
Contributor:

nit: Consider an enum rather than using raw values for fallback reason.

Collaborator (Author):

Good idea! You mean we should use an enum rather than raw values like 4 here, right? Thanks!
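
A minimal sketch of what the suggested enum could look like (the name FallbackReason and its enumerator names are hypothetical, not from this PR; the numeric values mirror the raw codes already used in check_node_fallback, and 0 is assumed here for unsupported operators):

// Hypothetical sketch, not part of this PR: gives names to the raw reason codes.
enum FallbackReason : int {
  kUnsupportedOp = 0,       // assumed default when an op has no converter
  kForcedByUser = 1,        // node explicitly set to run in Torch
  kForcedByModule = 2,      // node is inside a module set to run in Torch
  kMinBlockSize = 3,        // segment too small to satisfy min_block_size
  kNonTensorDependency = 4  // pulled in through NonTensor input/output dependencies
};

// The insertion above would then read, for example:
//   global_fallback_nodes.insert({input->node(), kNonTensorDependency});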

q.push(input->node());
}
}
@@ -120,7 +124,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
if (!isTensor(output)) {
for (auto use : output->uses()) {
auto node = use.user;
- if (node->kind() != torch::jit::prim::Constant && fallback_nodes.insert({node, 4}).second) {
+ if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, 4}).second) {
q.push(node);
}
}
@@ -231,6 +235,8 @@ bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::ji
LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
} else if (fallback_nodes.at(n) == 2) {
LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
} else if (fallback_nodes.at(n) == 3) {
LOG_GRAPH("Node fallback to Torch because of min_block_size" << util::node_info(n));
} else {
LOG_GRAPH(
"Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
@@ -284,22 +290,74 @@ void get_fallback_nodes(
return;
}

std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
torch::jit::Block* block,
const std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
size_t min_block_size) {
auto nodes = block->nodes();
std::vector<torch::jit::Node*> cur_trt_nodes;
std::vector<torch::jit::Node*> min_block_fallback_nodes;
for (const auto n : nodes) {
if (n->kind() == torch::jit::prim::Constant)
continue;

// check if current node fallback or not
if (!global_fallback_nodes.count(n)) {
// if this node is not in fallback nodes, then it's in trt segments
cur_trt_nodes.push_back(n);
} else {
if (cur_trt_nodes.size() < min_block_size) {
min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
}
cur_trt_nodes.clear();
}
}
if (cur_trt_nodes.size() < min_block_size) {
min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
}
return min_block_fallback_nodes;
}

void find_min_block_size_fallback_nodes(
torch::jit::Block* block,
std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
size_t min_block_size) {
// first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
auto min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
std::unordered_map<torch::jit::Node*, int> initial_fallback_nodes;

// keep fallback until all segments meet the min_block_size requirement
while (!min_block_fallback_nodes.empty()) {
for (const auto i : min_block_fallback_nodes) {
initial_fallback_nodes.insert({i, 3});
}
global_fallback_nodes.insert(initial_fallback_nodes.begin(), initial_fallback_nodes.end());
// find the fallback nodes because of dependency with min_block_size caused fallback nodes
find_all_fallback_nodes(initial_fallback_nodes, global_fallback_nodes);
// keep traverse the graph until there is no node fallback because of min_block_size
min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
}
}

PartitionedGraph segment_graph(
torch::jit::Block* block,
const PartitionInfo& partition_info,
- std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+ std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
auto min_block_size = partition_info.min_block_size;
std::unordered_set<std::string> forced_fallback_ops(
partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());

// get the initial fallback nodes (nodes that are unsupported or forced fallback)
- get_fallback_nodes(block, forced_fallback_ops, fallback_nodes);
+ get_fallback_nodes(block, forced_fallback_ops, global_fallback_nodes);

// For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this
// input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node
// that produces this input should also fallback
// TODO: don't need to fallback the TensorList related nodes once the collection feature is supported
- find_all_fallback_nodes(fallback_nodes);
+ find_all_fallback_nodes(global_fallback_nodes, global_fallback_nodes);

// find all fallback nodes because of the min_block_size requirement
find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size);

auto nodes = block->nodes();

@@ -313,7 +371,7 @@ PartitionedGraph segment_graph(
continue;
}

- if (check_node_fallback(n, fallback_nodes)) {
+ if (check_node_fallback(n, global_fallback_nodes)) {
in_prog_trt_blk_nodes.push_back(n);

// If there is an active PyTorch block and we have passed the threshold for a valid TRT
@@ -379,11 +437,11 @@ PartitionedGraph Partition(
torch::jit::Block* block,
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
const PartitionInfo& partition_info,
- std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+ std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
LOG_DEBUG(partition_info);
// segment lowering global graph into blocks
LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
- PartitionedGraph segmented_blocks = segment_graph(block, partition_info, fallback_nodes);
+ PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);

// It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks

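As a worked illustration of the re-segmentation loop added above (find_min_block_size_fallback_nodes together with traverse_nodes_for_min_block_size), here is a standalone sketch that is not Torch-TensorRT code: it models nodes as plain indices, ignores NonTensor dependency propagation, and simply keeps folding undersized TRT runs into the fallback set until nothing changes.

// Standalone illustration only (not Torch-TensorRT code). A "node" is just an
// index; fallback[i] == true means node i runs in Torch.
#include <cstddef>
#include <iostream>
#include <vector>

// Simplified analogue of traverse_nodes_for_min_block_size: collect the indices
// of every contiguous TRT run that is shorter than min_block_size.
std::vector<size_t> undersized_trt_nodes(const std::vector<bool>& fallback, size_t min_block_size) {
  std::vector<size_t> result, cur_run;
  auto flush = [&]() {
    if (cur_run.size() < min_block_size) {
      result.insert(result.end(), cur_run.begin(), cur_run.end());
    }
    cur_run.clear();
  };
  for (size_t i = 0; i < fallback.size(); ++i) {
    if (!fallback[i]) {
      cur_run.push_back(i);  // still part of a TRT candidate run
    } else {
      flush();  // a Torch node ends the current TRT run
    }
  }
  flush();  // handle the trailing run
  return result;
}

int main() {
  const size_t min_block_size = 3;
  // Nodes 2 and 5 already fall back (e.g. unsupported ops); the rest are TRT candidates.
  std::vector<bool> fallback = {false, false, true, false, false, true, false, false, false, false};

  // Analogue of the while loop in find_min_block_size_fallback_nodes: keep folding
  // undersized TRT runs into the fallback set until a fixed point is reached.
  // (The real code also re-runs find_all_fallback_nodes to propagate NonTensor deps.)
  while (true) {
    auto undersized = undersized_trt_nodes(fallback, min_block_size);
    if (undersized.empty()) {
      break;
    }
    for (size_t i : undersized) {
      fallback[i] = true;
    }
  }

  for (size_t i = 0; i < fallback.size(); ++i) {
    std::cout << "node " << i << " -> " << (fallback[i] ? "Torch" : "TensorRT") << "\n";
  }
  // Expected outcome: nodes 0-5 run in Torch, nodes 6-9 stay as one TensorRT block of size 4.
  return 0;
}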
tests/core/partitioning/test_segmentation.cpp (40 changes: 40 additions & 0 deletions)
@@ -120,6 +120,46 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4}}));
}

TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor,
%2 : Tensor):
%3 : int[] = prim::Constant[value=[-1, 5]]()
%4 : int[] = prim::Constant[value=[-1]]()
%5 : int = prim::Constant[value=2]()
%6 : int = prim::Constant[value=4]()
%7 : int = prim::Constant[value=5]()
%8 : int = prim::Constant[value=0]()
%9 : bool = prim::Constant[value=0]()
%10 : NoneType = prim::Constant()
%11 : int = prim::Constant[value=1]()
%12: Tensor = aten::reshape(%1, %4)
%13: Tensor = aten::reshape(%2, %3)
%14: Tensor = aten::reshape(%1, %3)
%15 : Tensor = aten::to(%12, %6, %9, %9, %10)
%16 : int = aten::size(%1, %8)
%17 : int[] = prim::ListConstruct(%16, %6, %5, %7)
%18 : Tensor = aten::index_add_(%14, %8, %15, %13, %11)
%20 : Tensor = aten::reshape(%18, %17)
return (%20))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

torch_tensorrt::core::partitioning::PartitionInfo partition_info;
partition_info.enabled = true;
partition_info.min_block_size = 3;
std::unordered_map<torch::jit::Node*, int> fallback_nodes;
std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
ASSERT_TRUE(
checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1));
ASSERT_TRUE(
checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1));
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2, 3}, {4, 5, 6, 7}}));
}

TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,