diff --git a/core/compiler.cpp b/core/compiler.cpp
index b7ab5dd1c8..fc1cc66aee 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -240,7 +240,7 @@ GraphAndMapping ConstructFallbackGraph(
   }
 
   for (auto& seg_block : segmented_blocks) {
-    LOG_INFO(*seg_block.g() << "(GraphInSegmentedBlock)\n");
+    LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
     std::ostringstream trt_engine_id;
     trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
@@ -372,15 +372,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
 
-  // // GPU default WS size : 1 GB
-  // // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
-  // auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
-  // auto device_spec = cfg.convert_info.engine_settings.device;
-  // auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-  // if (workspace_size == 0) {
-  //   cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
-  // }
-
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
 
@@ -391,14 +382,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
   torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
 
-  // // GPU default WS size : 1 GB
-  // // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
-  // auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
   auto device_spec = cfg.convert_info.engine_settings.device;
   auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-  // if (workspace_size == 0) {
-  //   cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
-  // }
 
   for (const torch::jit::Method& method : mod.get_methods()) {
     if (method.name().compare("forward") == 0) {
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index 476d6fcfba..8fcd29f7a8 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -98,9 +98,13 @@ std::vector<torch::jit::Node*> getDependencyNodes(
   return stk;
 }
 
-void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+void find_all_fallback_nodes(
+    std::unordered_map<torch::jit::Node*, int>& initial_fallback_nodes,
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
+  // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
+  // global_fallback_nodes are the fallback nodes that we maintain globally
   std::queue<torch::jit::Node*> q;
-  for (auto& node : fallback_nodes) {
+  for (auto& node : initial_fallback_nodes) {
     q.push(node.first);
   }
 
@@ -111,7 +115,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
     // for every node that produces this fallback node's NonTensor input, they should fallback too
     for (auto input : cur_node->inputs()) {
       if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
-          fallback_nodes.insert({input->node(), 4}).second) {
+          global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
        q.push(input->node());
       }
     }
@@ -120,7 +124,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
       if (!isTensor(output)) {
         for (auto use : output->uses()) {
           auto node = use.user;
-          if (node->kind() != torch::jit::prim::Constant && fallback_nodes.insert({node, 4}).second) {
+          if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
             q.push(node);
           }
         }
@@ -225,12 +229,14 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
 
 bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
   if (fallback_nodes.count(n)) {
-    if (fallback_nodes.at(n) == 0) {
+    if (fallback_nodes.at(n) == FallbackNodeType::kUNSUPPORTED) {
       LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 1) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kOPERATOR_FALLBACK) {
       LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 2) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMODULE_FALLBACK) {
       LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMIN_BLOCK_FALLBACK) {
+      LOG_GRAPH("Node fallback to Torch because of min_block_size" << util::node_info(n));
     } else {
       LOG_GRAPH(
           "Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
@@ -267,39 +273,91 @@ void get_fallback_nodes(
     // If the op is not supported by the conversion phase it should run in PyTorch
     if (!conversion::OpSupported(n)) {
-      fallback_nodes.insert({n, 0});
+      fallback_nodes.insert({n, FallbackNodeType::kUNSUPPORTED});
     }
 
     // If the user specifies the op to run in Torch it should run in PyTorch
     if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) {
-      fallback_nodes.insert({n, 1});
+      fallback_nodes.insert({n, FallbackNodeType::kOPERATOR_FALLBACK});
     }
 
     // If the user specifies the module containing this op to run in torch it should run in PyTorch
     const auto to_compile_sym = c10::Symbol::attr("to_compile");
     if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
-      fallback_nodes.insert({n, 2});
+      fallback_nodes.insert({n, FallbackNodeType::kMODULE_FALLBACK});
     }
   }
   return;
 }
 
+std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
+    torch::jit::Block* block,
+    const std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+    size_t min_block_size) {
+  auto nodes = block->nodes();
+  std::vector<torch::jit::Node*> cur_trt_nodes;
+  std::vector<torch::jit::Node*> min_block_fallback_nodes;
+  for (const auto n : nodes) {
+    if (n->kind() == torch::jit::prim::Constant)
+      continue;
+
+    // check if current node fallback or not
+    if (!global_fallback_nodes.count(n)) {
+      // if this node is not in fallback nodes, then it's in trt segments
+      cur_trt_nodes.push_back(n);
+    } else {
+      if (cur_trt_nodes.size() < min_block_size) {
+        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+      }
+      cur_trt_nodes.clear();
+    }
+  }
+  if (cur_trt_nodes.size() < min_block_size) {
+    min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+  }
+  return min_block_fallback_nodes;
+}
+
+void find_min_block_size_fallback_nodes(
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+    size_t min_block_size) {
+  // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
+  auto min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+  std::unordered_map<torch::jit::Node*, int> initial_fallback_nodes;
+
+  // keep fallback until all segments meet the min_block_size requirement
+  while (!min_block_fallback_nodes.empty()) {
+    for (const auto i : min_block_fallback_nodes) {
+      initial_fallback_nodes.insert({i, FallbackNodeType::kMIN_BLOCK_FALLBACK});
+    }
+    global_fallback_nodes.insert(initial_fallback_nodes.begin(), initial_fallback_nodes.end());
+    // find the fallback nodes because of dependency with min_block_size caused fallback nodes
+    find_all_fallback_nodes(initial_fallback_nodes, global_fallback_nodes);
+    // keep traverse the graph until there is no node fallback because of min_block_size
+    min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+  }
+}
+
 PartitionedGraph segment_graph(
     torch::jit::Block* block,
     const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
   auto min_block_size = partition_info.min_block_size;
   std::unordered_set<std::string> forced_fallback_ops(
       partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
 
   // get the initial fallback nodes (nodes that are unsupported or forced fallback)
-  get_fallback_nodes(block, forced_fallback_ops, fallback_nodes);
+  get_fallback_nodes(block, forced_fallback_ops, global_fallback_nodes);
 
   // For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this
   // input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node
   // that produces this input should also fallback
   // TODO: don't need to fallback the TensorList related nodes once the collection feature is supported
-  find_all_fallback_nodes(fallback_nodes);
+  find_all_fallback_nodes(global_fallback_nodes, global_fallback_nodes);
+
+  // find all fallback nodes because of the min_block_size requirement
+  find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size);
 
   auto nodes = block->nodes();
 
@@ -313,7 +371,7 @@ PartitionedGraph segment_graph(
       continue;
     }
 
-    if (check_node_fallback(n, fallback_nodes)) {
+    if (check_node_fallback(n, global_fallback_nodes)) {
       in_prog_trt_blk_nodes.push_back(n);
 
       // If there is an active PyTorch block and we have passed the threshold for a valid TRT
@@ -379,11 +437,11 @@ PartitionedGraph Partition(
     torch::jit::Block* block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
     const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
   LOG_DEBUG(partition_info);
   // segment lowering global graph into blocks
   LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
-  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, fallback_nodes);
+  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);
 
   // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h
index fce88134b7..df64f582a4 100644
--- a/core/partitioning/partitioning.h
+++ b/core/partitioning/partitioning.h
@@ -16,6 +16,20 @@ namespace partitioning {
 
 typedef std::vector<SegmentedBlock> PartitionedGraph;
 
+enum FallbackNodeType {
+  /// Node is not supported by TensorRT
+  kUNSUPPORTED,
+  /// Node is explicitly forced to fallback to Pytorch due to operator fallback
+  kOPERATOR_FALLBACK,
+  /// Node is explicitly forced to fallback to Pytorch due to module fallback
+  kMODULE_FALLBACK,
+  /// This node is in a TRT segment which does not satisfy min_block_size
+  /// and hence is forced to fallback.
+  kMIN_BLOCK_FALLBACK,
+  /// This node produces/consumes non-tensor inputs
+  kNON_TENSOR,
+};
+
 PartitionedGraph segment_graph(
     torch::jit::Block* block,
     const PartitionInfo& partition_info,
diff --git a/tests/core/partitioning/test_segmentation.cpp b/tests/core/partitioning/test_segmentation.cpp
index bf32bcf918..bb6aa086fb 100644
--- a/tests/core/partitioning/test_segmentation.cpp
+++ b/tests/core/partitioning/test_segmentation.cpp
@@ -120,6 +120,46 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
   ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4}}));
 }
 
+TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) {
+  const auto graph = R"IR(
+        graph(%0 : Tensor,
+              %1 : Tensor,
+              %2 : Tensor):
+          %3 : int[] = prim::Constant[value=[-1, 5]]()
+          %4 : int[] = prim::Constant[value=[-1]]()
+          %5 : int = prim::Constant[value=2]()
+          %6 : int = prim::Constant[value=4]()
+          %7 : int = prim::Constant[value=5]()
+          %8 : int = prim::Constant[value=0]()
+          %9 : bool = prim::Constant[value=0]()
+          %10 : NoneType = prim::Constant()
+          %11 : int = prim::Constant[value=1]()
+          %12: Tensor = aten::reshape(%1, %4)
+          %13: Tensor = aten::reshape(%2, %3)
+          %14: Tensor = aten::reshape(%1, %3)
+          %15 : Tensor = aten::to(%12, %6, %9, %9, %10)
+          %16 : int = aten::size(%1, %8)
+          %17 : int[] = prim::ListConstruct(%16, %6, %5, %7)
+          %18 : Tensor = aten::index_add_(%14, %8, %15, %13, %11)
+          %20 : Tensor = aten::reshape(%18, %17)
+          return (%20))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
+  partition_info.enabled = true;
+  partition_info.min_block_size = 3;
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
+  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
+      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
+  ASSERT_TRUE(
+      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1));
+  ASSERT_TRUE(
+      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1));
+  ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2, 3}, {4, 5, 6, 7}}));
+}
+
 TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
   const auto graph = R"IR(
         graph(%0 : Tensor,
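
Note (illustration only, not part of the patch): the new traverse_nodes_for_min_block_size / find_min_block_size_fallback_nodes pair above runs a small fixed-point loop: mark every contiguous TRT run shorter than min_block_size as fallback, propagate the NonTensor dependencies of the newly marked nodes via find_all_fallback_nodes, and re-scan until no undersized run remains. The standalone C++ sketch below shows that loop on a plain vector of integer node ids; the simplified node model, the helper name runs_below_min_block_size, and the example values are assumptions made for illustration and are not part of the Torch-TensorRT API.

#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

// Collect the nodes that sit in a contiguous "TRT run" shorter than min_block_size.
// A run is broken by any node that is already marked as fallback.
std::vector<int> runs_below_min_block_size(
    const std::vector<int>& nodes,
    const std::set<int>& fallback,
    size_t min_block_size) {
  std::vector<int> cur_run, too_small;
  auto flush = [&]() {
    if (cur_run.size() < min_block_size) {
      too_small.insert(too_small.end(), cur_run.begin(), cur_run.end());
    }
    cur_run.clear();
  };
  for (int n : nodes) {
    if (fallback.count(n)) {
      flush(); // a fallback node ends the current TRT run
    } else {
      cur_run.push_back(n); // node would be placed in a TensorRT segment
    }
  }
  flush(); // handle the trailing run
  return too_small;
}

int main() {
  std::vector<int> nodes = {0, 1, 2, 3, 4, 5, 6, 7};
  std::set<int> fallback = {3}; // node 3 already falls back (e.g. unsupported op)
  size_t min_block_size = 4;

  // Fixed point: keep marking undersized runs until every remaining TRT run is big enough.
  auto too_small = runs_below_min_block_size(nodes, fallback, min_block_size);
  while (!too_small.empty()) {
    fallback.insert(too_small.begin(), too_small.end());
    // The real pass calls find_all_fallback_nodes(...) at this point so that NonTensor
    // producers/consumers of the newly marked nodes fall back as well before re-scanning.
    too_small = runs_below_min_block_size(nodes, fallback, min_block_size);
  }

  for (int n : fallback) {
    std::cout << "node " << n << " runs in Torch\n"; // prints nodes 0 1 2 3
  }
}

The re-scan is needed because marking whole undersized runs cannot by itself create new undersized runs; it is the dependency propagation in find_all_fallback_nodes that can knock nodes out of a previously valid run, which appears to be why find_min_block_size_fallback_nodes loops until traverse_nodes_for_min_block_size comes back empty.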