Merge pull request #1195 from pytorch/support_min_block_size
feat: support min_block_size != 1 caused fallback nodes re-segmentation
peri044 authored Jul 25, 2022
2 parents e07687d + 52abece commit 2f896b3
Showing 4 changed files with 129 additions and 32 deletions.
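
For context, `min_block_size` tells the partitioner that any would-be TensorRT segment with fewer operators than the threshold should run in Torch instead. This commit makes the partitioner treat such undersized segments as additional fallback nodes and re-segment (including the non-Tensor dependency propagation) until every remaining TensorRT block satisfies the threshold. The sketch below shows how the core-level API exercised by the new test is driven; it is an illustration only — the tiny IR graph and the include paths are assumptions, not part of this diff.

// Minimal sketch (assumed include paths; toy IR graph) of driving the core partitioner
// with a min_block_size, mirroring the new test in tests/core/partitioning/test_segmentation.cpp.
#include <torch/csrc/jit/ir/irparser.h>

#include <iostream>
#include <memory>
#include <unordered_map>

#include "core/partitioning/partitioning.h"

int main() {
  const auto graph = R"IR(
        graph(%0 : Tensor):
          %1 : Tensor = aten::relu(%0)
          %2 : Tensor = aten::relu(%1)
          return (%2))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
  partition_info.enabled = true;
  partition_info.min_block_size = 3; // TRT segments smaller than this now fall back to Torch

  std::unordered_map<torch::jit::Node*, int> global_fallback_nodes;
  auto segmented_blocks =
      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, global_fallback_nodes);
  std::cout << "number of segments: " << segmented_blocks.size() << std::endl;
  return 0;
}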
17 changes: 1 addition & 16 deletions core/compiler.cpp
@@ -240,7 +240,7 @@ GraphAndMapping ConstructFallbackGraph(
}

for (auto& seg_block : segmented_blocks) {
- LOG_INFO(*seg_block.g() << "(GraphInSegmentedBlock)\n");
+ LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
std::ostringstream trt_engine_id;
trt_engine_id << reinterpret_cast<const int*>(&seg_block);

@@ -372,15 +372,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());

- // // GPU default WS size : 1 GB
- // // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
- // auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
- // auto device_spec = cfg.convert_info.engine_settings.device;
- // auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
- // if (workspace_size == 0) {
- // cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
- // }

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
@@ -391,14 +382,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");

- // // GPU default WS size : 1 GB
- // // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
- // auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
auto device_spec = cfg.convert_info.engine_settings.device;
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
- // if (workspace_size == 0) {
- // cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
- // }

for (const torch::jit::Method& method : mod.get_methods()) {
if (method.name().compare("forward") == 0) {
90 changes: 74 additions & 16 deletions core/partitioning/partitioning.cpp
@@ -98,9 +98,13 @@ std::vector<torch::jit::Node*> getDependencyNodes(
return stk;
}

- void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+ void find_all_fallback_nodes(
+     std::unordered_map<torch::jit::Node*, int>& initial_fallback_nodes,
+     std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
+   // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
+   // global_fallback_nodes are the fallback nodes that we maintain globally
std::queue<torch::jit::Node*> q;
- for (auto& node : fallback_nodes) {
+ for (auto& node : initial_fallback_nodes) {
q.push(node.first);
}

@@ -111,7 +115,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
// for every node that produces this fallback node's NonTensor input, they should fallback too
for (auto input : cur_node->inputs()) {
if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
-       fallback_nodes.insert({input->node(), 4}).second) {
+       global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
q.push(input->node());
}
}
@@ -120,7 +124,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
if (!isTensor(output)) {
for (auto use : output->uses()) {
auto node = use.user;
- if (node->kind() != torch::jit::prim::Constant && fallback_nodes.insert({node, 4}).second) {
+ if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
q.push(node);
}
}
@@ -225,12 +229,14 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {

bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
if (fallback_nodes.count(n)) {
- if (fallback_nodes.at(n) == 0) {
+ if (fallback_nodes.at(n) == FallbackNodeType::kUNSUPPORTED) {
LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
- } else if (fallback_nodes.at(n) == 1) {
+ } else if (fallback_nodes.at(n) == FallbackNodeType::kOPERATOR_FALLBACK) {
LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
- } else if (fallback_nodes.at(n) == 2) {
+ } else if (fallback_nodes.at(n) == FallbackNodeType::kMODULE_FALLBACK) {
LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
+ } else if (fallback_nodes.at(n) == FallbackNodeType::kMIN_BLOCK_FALLBACK) {
+   LOG_GRAPH("Node falls back to Torch because of min_block_size: " << util::node_info(n));
} else {
LOG_GRAPH(
"Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
@@ -267,39 +273,91 @@ void get_fallback_nodes(

// If the op is not supported by the conversion phase it should run in PyTorch
if (!conversion::OpSupported(n)) {
- fallback_nodes.insert({n, 0});
+ fallback_nodes.insert({n, FallbackNodeType::kUNSUPPORTED});
}

// If the user specifies the op to run in Torch it should run in PyTorch
if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) {
- fallback_nodes.insert({n, 1});
+ fallback_nodes.insert({n, FallbackNodeType::kOPERATOR_FALLBACK});
}

// If the user specifies the module containing this op to run in torch it should run in PyTorch
const auto to_compile_sym = c10::Symbol::attr("to_compile");
if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
- fallback_nodes.insert({n, 2});
+ fallback_nodes.insert({n, FallbackNodeType::kMODULE_FALLBACK});
}
}
return;
}

+ std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
+     torch::jit::Block* block,
+     const std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+     size_t min_block_size) {
+   auto nodes = block->nodes();
+   std::vector<torch::jit::Node*> cur_trt_nodes;
+   std::vector<torch::jit::Node*> min_block_fallback_nodes;
+   for (const auto n : nodes) {
+     if (n->kind() == torch::jit::prim::Constant)
+       continue;
+
+     // check whether the current node falls back or not
+     if (!global_fallback_nodes.count(n)) {
+       // if this node is not in the fallback nodes, then it belongs to a TRT segment
+       cur_trt_nodes.push_back(n);
+     } else {
+       if (cur_trt_nodes.size() < min_block_size) {
+         min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+       }
+       cur_trt_nodes.clear();
+     }
+   }
+   if (cur_trt_nodes.size() < min_block_size) {
+     min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+   }
+   return min_block_fallback_nodes;
+ }
+
+ void find_min_block_size_fallback_nodes(
+     torch::jit::Block* block,
+     std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+     size_t min_block_size) {
+   // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
+   auto min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+   std::unordered_map<torch::jit::Node*, int> initial_fallback_nodes;
+
+   // keep falling back until all segments meet the min_block_size requirement
+   while (!min_block_fallback_nodes.empty()) {
+     for (const auto i : min_block_fallback_nodes) {
+       initial_fallback_nodes.insert({i, FallbackNodeType::kMIN_BLOCK_FALLBACK});
+     }
+     global_fallback_nodes.insert(initial_fallback_nodes.begin(), initial_fallback_nodes.end());
+     // find the nodes that must also fall back because they depend on the min_block_size fallback nodes
+     find_all_fallback_nodes(initial_fallback_nodes, global_fallback_nodes);
+     // keep traversing the graph until no more nodes fall back because of min_block_size
+     min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+   }
+ }

PartitionedGraph segment_graph(
torch::jit::Block* block,
const PartitionInfo& partition_info,
-     std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+     std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
auto min_block_size = partition_info.min_block_size;
std::unordered_set<std::string> forced_fallback_ops(
partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());

// get the initial fallback nodes (nodes that are unsupported or forced fallback)
- get_fallback_nodes(block, forced_fallback_ops, fallback_nodes);
+ get_fallback_nodes(block, forced_fallback_ops, global_fallback_nodes);

// For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this
// input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node
// that produces this input should also fallback
// TODO: don't need to fallback the TensorList related nodes once the collection feature is supported
- find_all_fallback_nodes(fallback_nodes);
+ find_all_fallback_nodes(global_fallback_nodes, global_fallback_nodes);
+
+ // find all fallback nodes because of the min_block_size requirement
+ find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size);

auto nodes = block->nodes();

@@ -313,7 +371,7 @@
continue;
}

- if (check_node_fallback(n, fallback_nodes)) {
+ if (check_node_fallback(n, global_fallback_nodes)) {
in_prog_trt_blk_nodes.push_back(n);

// If there is an active PyTorch block and we have passed the threshold for a valid TRT
@@ -379,11 +437,11 @@ PartitionedGraph Partition(
torch::jit::Block* block,
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
const PartitionInfo& partition_info,
- std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+ std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
LOG_DEBUG(partition_info);
// segment lowering global graph into blocks
LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
- PartitionedGraph segmented_blocks = segment_graph(block, partition_info, fallback_nodes);
+ PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);

// It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks

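Taken together, the new helpers implement a simple fixed point: mark every TensorRT run shorter than min_block_size as fallback, propagate the dependencies those new fallback nodes introduce, and re-scan until nothing changes. Below is a minimal, self-contained sketch of that loop on a made-up linear graph (plain C++, not Torch-TensorRT code; the node indices and the initial fallback set are invented, and the non-Tensor dependency propagation is omitted).

// Toy illustration of the min_block_size fixed-point re-segmentation idea.
#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

int main() {
  const std::size_t min_block_size = 3;
  const std::size_t num_nodes = 10;
  std::set<std::size_t> fallback = {4, 7}; // hypothetical nodes that already run in Torch

  bool changed = true;
  while (changed) { // re-segment until no run shrinks below the threshold
    changed = false;
    std::vector<std::size_t> run; // current contiguous TensorRT run
    auto flush = [&]() {
      if (!run.empty() && run.size() < min_block_size) {
        for (auto n : run) {
          changed = fallback.insert(n).second || changed; // undersized run: fall back
        }
      }
      run.clear();
    };
    for (std::size_t n = 0; n < num_nodes; ++n) {
      if (fallback.count(n)) {
        flush();
      } else {
        run.push_back(n);
      }
    }
    flush();
  }

  for (std::size_t n = 0; n < num_nodes; ++n) {
    std::cout << "node " << n << (fallback.count(n) ? " -> Torch\n" : " -> TensorRT\n");
  }
  return 0;
}

With min_block_size = 3, the undersized runs {5, 6} and {8, 9} are folded into Torch, leaving a single TensorRT block {0, 1, 2, 3} — analogous in shape to what the new test below asserts.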
14 changes: 14 additions & 0 deletions core/partitioning/partitioning.h
@@ -16,6 +16,20 @@ namespace partitioning {

typedef std::vector<SegmentedBlock> PartitionedGraph;

+ enum FallbackNodeType {
+   /// Node is not supported by TensorRT
+   kUNSUPPORTED,
+   /// Node is explicitly forced to fall back to PyTorch due to operator fallback
+   kOPERATOR_FALLBACK,
+   /// Node is explicitly forced to fall back to PyTorch due to module fallback
+   kMODULE_FALLBACK,
+   /// This node is in a TRT segment which does not satisfy min_block_size
+   /// and hence is forced to fall back.
+   kMIN_BLOCK_FALLBACK,
+   /// This node produces/consumes non-Tensor values
+   kNON_TENSOR,
+ };

PartitionedGraph segment_graph(
torch::jit::Block* block,
const PartitionInfo& partition_info,
40 changes: 40 additions & 0 deletions tests/core/partitioning/test_segmentation.cpp
@@ -120,6 +120,46 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4}}));
}

+ TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) {
+   const auto graph = R"IR(
+         graph(%0 : Tensor,
+               %1 : Tensor,
+               %2 : Tensor):
+           %3 : int[] = prim::Constant[value=[-1, 5]]()
+           %4 : int[] = prim::Constant[value=[-1]]()
+           %5 : int = prim::Constant[value=2]()
+           %6 : int = prim::Constant[value=4]()
+           %7 : int = prim::Constant[value=5]()
+           %8 : int = prim::Constant[value=0]()
+           %9 : bool = prim::Constant[value=0]()
+           %10 : NoneType = prim::Constant()
+           %11 : int = prim::Constant[value=1]()
+           %12 : Tensor = aten::reshape(%1, %4)
+           %13 : Tensor = aten::reshape(%2, %3)
+           %14 : Tensor = aten::reshape(%1, %3)
+           %15 : Tensor = aten::to(%12, %6, %9, %9, %10)
+           %16 : int = aten::size(%1, %8)
+           %17 : int[] = prim::ListConstruct(%16, %6, %5, %7)
+           %18 : Tensor = aten::index_add_(%14, %8, %15, %13, %11)
+           %20 : Tensor = aten::reshape(%18, %17)
+           return (%20))IR";
+
+   auto g = std::make_shared<torch::jit::Graph>();
+   torch::jit::parseIR(graph, g.get());
+
+   torch_tensorrt::core::partitioning::PartitionInfo partition_info;
+   partition_info.enabled = true;
+   partition_info.min_block_size = 3;
+   std::unordered_map<torch::jit::Node*, int> fallback_nodes;
+   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
+       torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
+   ASSERT_TRUE(
+       checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1));
+   ASSERT_TRUE(
+       checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1));
+   ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2, 3}, {4, 5, 6, 7}}));
+ }

TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
