#include "SegmentedBlock.h"

namespace trtorch {
namespace core {
namespace partitioning {
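
// Return the Value in the new graph that corresponds to old_value, creating
// it on first use: constants are cloned directly into the new graph, while
// any other Value becomes a fresh block input. The map is also populated in
// reverse (new input -> old value) so the original Value can be recovered
// from a graph input later.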
torch::jit::Value* getOrAddInputForValue(
    torch::jit::Value* old_value,
    std::shared_ptr<torch::jit::Graph>& graph,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) {
  if (old_to_new.count(old_value) == 0) {
    auto node = old_value->node();
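
    // Constants are cheap to duplicate, so clone them into the new graph
    // rather than threading them through as block inputs.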
    if (node->kind() == torch::jit::prim::Constant) {
      auto new_const = graph->createClone(node, {nullptr});
      graph->block()->prependNode(new_const);
      return new_const->output();
    }
    auto new_value = graph->block()->addInput();
    old_to_new[old_value] = new_value;
    new_value->copyMetadata(old_value);
    // mapping from new graph input Values to original graph values
    old_to_new[new_value] = old_value;
    return new_value;
  } else {
    return old_to_new[old_value];
  }
}
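
// Clone node into graph, remapping its inputs through getOrAddInputForValue
// and recording how the original outputs map to the cloned outputs.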
torch::jit::Node* cloneNode(
    torch::jit::Node* node,
    std::shared_ptr<torch::jit::Graph>& graph,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) {
  auto* block = graph->block();
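  // Resolve each input of the original node to its counterpart in the new graph.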
  auto env = [&](torch::jit::Value* v) { return getOrAddInputForValue(v, graph, old_to_new); };

  // create a node in the new graph from the metadata of the original node,
  // with its inputs remapped through env
  auto new_node = block->appendNode(graph->createClone(node, env));
  for (size_t i = 0; i < node->outputs().size(); ++i) {
    auto oo = node->outputs()[i];
    auto no = new_node->outputs()[i];
    old_to_new[oo] = no;
  }
  return new_node;
}
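
// Partition the graph into an ordered list of TensorRT-convertible and
// Torch-fallback segments. A run of contiguous supported nodes becomes a
// kTensorRT block only once it reaches min_block_size; shorter runs are
// folded into the surrounding kTorch block.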
std::vector<SegmentedBlock> segment_graph(
    std::shared_ptr<torch::jit::Graph> g,
    const conversion::TorchFallback& fallback_info) {
  auto min_block_size = fallback_info.min_block_size;
  std::unordered_set<std::string> forced_fallback_operators(
      fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end());

  auto nodes = g->block()->nodes();
  std::vector<SegmentedBlock> segmented_blocks;
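
  // Single pass over the graph: supported nodes accumulate in tensorrt_nodes
  // and unsupported (or force-fallback) nodes accumulate in pytorch_nodes;
  // whenever a run ends, it is flushed into a segment of the matching kind.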
  std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
  for (const auto n : nodes) {
    if (n->kind() == torch::jit::prim::Constant)
      continue;

    std::string node_string(n->kind().toQualString());
    if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
      tensorrt_nodes.push_back(n);
      // once the TensorRT run is long enough to stand alone, flush any
      // accumulated Torch nodes into a kTorch segment
      if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) {
        segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
        pytorch_nodes.clear();
      }
    } else {
      // the TensorRT run ends here; emit it as a kTensorRT segment if it is
      // long enough, otherwise fold it into the pending Torch nodes
      if (tensorrt_nodes.size() >= min_block_size) {
        segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
      } else {
        pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
      }
      tensorrt_nodes.clear();
      pytorch_nodes.push_back(n);
    }
  }

  // any leftover nodes form the final segment: either the graph ends in
  // Torch-fallback nodes, or it ends in a TensorRT run shorter than
  // min_block_size, which is folded into the kTorch segment
  if (!pytorch_nodes.empty()) {
    pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
    segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
  } else {
    segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
  }

  return segmented_blocks;
}

} // namespace partitioning
} // namespace core
} // namespace trtorch