Skip to content

Commit 965a67a

Browse files
committed
chore: reorganize code structure initially
Signed-off-by: Bo Wang <wangbo1995ee@163.com>
1 parent cfc68ce commit 965a67a

File tree

7 files changed

+423
-371
lines changed

7 files changed

+423
-371
lines changed

core/partitioning/BUILD

+4
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@ config_setting(
1010
cc_library(
1111
name = "partitioning",
1212
hdrs = [
13+
"SegmentedBlock.h",
14+
"shape_analysis.h",
1315
"partitioning.h",
1416
],
1517
srcs = [
18+
"SegmentedBlock.cpp",
19+
"shape_analysis.cpp",
1620
"partitioning.cpp",
1721
],
1822
deps = [

core/partitioning/SegmentedBlock.cpp

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#include "SegmentedBlock.h"
2+
3+
namespace trtorch {
4+
namespace core {
5+
namespace partitioning {
6+
7+
// Returns the Value in `graph` that corresponds to `old_value` from the
// original graph, creating it on demand.
//
// - For prim::Constant producers, the constant node is cloned into `graph`
//   (prepended so it dominates all uses) and its output is returned.
// - For any other value, a fresh graph input is added.
//
// `old_to_new` records correspondences in BOTH directions: original Value ->
// new Value, and (for graph inputs only) new input Value -> original Value,
// so callers can map block inputs back to the source graph.
torch::jit::Value* getOrAddInputForValue(
    torch::jit::Value* old_value,
    std::shared_ptr<torch::jit::Graph>& graph,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) {
  if (old_to_new.count(old_value) == 0) {
    auto node = old_value->node();

    if (node->kind() == torch::jit::prim::Constant) {
      auto new_const = graph->createClone(node, {nullptr});
      graph->block()->prependNode(new_const);
      // BUG FIX: cache the cloned constant. The original returned without
      // recording it, so every subsequent use of the same constant cloned
      // it into the graph again.
      old_to_new[old_value] = new_const->output();
      return new_const->output();
    }
    auto new_value = graph->block()->addInput();
    old_to_new[old_value] = new_value;
    new_value->copyMetadata(old_value);
    // mapping from new graph input Values to original graph values
    old_to_new[new_value] = old_value;
    return new_value;
  } else {
    return old_to_new[old_value];
  }
}
29+
30+
// Clones `node` into `graph`, remapping each of its inputs through
// `old_to_new` (new graph inputs / constant clones are created on demand by
// getOrAddInputForValue). The outputs of the clone are recorded in
// `old_to_new` so downstream nodes can resolve them. Returns the clone.
torch::jit::Node* cloneNode(
    torch::jit::Node* node,
    std::shared_ptr<torch::jit::Graph>& graph,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) {
  // Value-resolution callback handed to createClone: maps an input Value of
  // the original node to its counterpart in the destination graph.
  auto resolve_input = [&](torch::jit::Value* v) { return getOrAddInputForValue(v, graph, old_to_new); };

  // Clone the node (metadata + remapped inputs) and append it to the graph.
  auto* cloned = graph->block()->appendNode(graph->createClone(node, resolve_input));

  // Register each original-output -> cloned-output correspondence.
  auto old_outputs = node->outputs();
  auto new_outputs = cloned->outputs();
  for (size_t idx = 0; idx < old_outputs.size(); ++idx) {
    old_to_new[old_outputs[idx]] = new_outputs[idx];
  }
  return cloned;
}
46+
47+
std::vector<SegmentedBlock> segment_graph(
48+
std::shared_ptr<torch::jit::Graph> g,
49+
const conversion::TorchFallback& fallback_info) {
50+
auto min_block_size = fallback_info.min_block_size;
51+
std::unordered_set<std::string> forced_fallback_operators(
52+
fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end());
53+
54+
auto nodes = g->block()->nodes();
55+
std::vector<SegmentedBlock> segmented_blocks;
56+
57+
// segment the nodes
58+
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
59+
for (const auto n : nodes) {
60+
if (n->kind() == torch::jit::prim::Constant)
61+
continue;
62+
63+
std::string node_string(n->kind().toQualString());
64+
if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
65+
tensorrt_nodes.push_back(n);
66+
if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) {
67+
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
68+
pytorch_nodes.clear();
69+
}
70+
} else {
71+
if (tensorrt_nodes.size() >= min_block_size) {
72+
segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
73+
} else {
74+
pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
75+
}
76+
tensorrt_nodes.clear();
77+
pytorch_nodes.push_back(n);
78+
}
79+
}
80+
81+
// if there is any kTorch nodes left, then either the last nodes are kTorch or last nodes are kTensorRT but num <
82+
// min_block_size
83+
if (!pytorch_nodes.empty()) {
84+
pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
85+
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
86+
} else {
87+
segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
88+
}
89+
90+
return std::move(segmented_blocks);
91+
}
92+
93+
} // namespace partitioning
94+
} // namespace core
95+
} // namespace trtorch

core/partitioning/SegmentedBlock.h

+142
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
#include "core/conversion/conversion.h"
6+
#include "torch/csrc/jit/ir/ir.h"
7+
8+
namespace trtorch {
9+
namespace core {
10+
namespace partitioning {
11+
12+
// Returns the Value in `graph` corresponding to `old_value` from the original
// graph, adding a new graph input (or a clone of a prim::Constant producer)
// on demand; `old_to_new` records original->new correspondences, and for
// graph inputs also the reverse new-input->original mapping.
torch::jit::Value* getOrAddInputForValue(
    torch::jit::Value* old_value,
    std::shared_ptr<torch::jit::Graph>& graph,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new);

// Clones `node` into `graph`, resolving its inputs through `old_to_new` (via
// getOrAddInputForValue) and recording old-output -> new-output mappings.
// Returns the cloned node.
torch::jit::Node* cloneNode(
    torch::jit::Node* node,
    std::shared_ptr<torch::jit::Graph>& graph,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new);
21+
22+
struct SegmentedBlock {
23+
public:
24+
enum SegmentedBlockTarget {
25+
kTorch,
26+
kTensorRT,
27+
};
28+
29+
SegmentedBlock() = default;
30+
31+
SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
32+
33+
SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes)
34+
: target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
35+
for (auto& node : nodes) {
36+
nodes_.push_back(node);
37+
appendNode(node);
38+
}
39+
registerInputs();
40+
}
41+
42+
SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
43+
44+
enum SegmentedBlockTarget target() {
45+
return target_;
46+
}
47+
48+
void appendNode(torch::jit::Node* n) {
49+
cloneNode(n, g_, old_to_new_);
50+
}
51+
52+
void registerInputs() {
53+
for (auto& value : g_->inputs()) {
54+
inputs_.push_back(old_to_new_[value]);
55+
}
56+
}
57+
58+
void registerOutput(torch::jit::Value* raw_output) {
59+
outputs_.push_back(raw_output);
60+
g_->registerOutput(old_to_new_[raw_output]);
61+
}
62+
63+
torch::jit::Block* block() {
64+
return g_->block();
65+
}
66+
67+
c10::ArrayRef<torch::jit::Value*> inputs() {
68+
return g_->inputs();
69+
}
70+
71+
void eraseInput(size_t i) {
72+
inputs_.erase(inputs_.begin() + i);
73+
g_->eraseInput(i);
74+
}
75+
76+
c10::ArrayRef<torch::jit::Value*> outputs() {
77+
return g_->outputs();
78+
}
79+
80+
void eraseOutput(size_t i) {
81+
outputs_.erase(outputs_.begin() + i);
82+
g_->eraseOutput(i);
83+
}
84+
85+
const std::vector<torch::jit::Value*>& raw_inputs() const {
86+
return inputs_;
87+
}
88+
89+
const std::vector<torch::jit::Value*>& raw_outputs() const {
90+
return outputs_;
91+
}
92+
93+
const std::vector<torch::jit::Node*>& raw_nodes() const {
94+
return nodes_;
95+
}
96+
97+
bool contain_raw_value(torch::jit::Value* input) {
98+
return old_to_new_.count(input);
99+
}
100+
101+
torch::jit::graph_node_list nodes() {
102+
return g_->nodes();
103+
}
104+
105+
void register_inshape(std::vector<nvinfer1::Dims>& in_shape) {
106+
in_shape_ = in_shape;
107+
}
108+
109+
const std::vector<nvinfer1::Dims>& in_shape() const {
110+
return in_shape_;
111+
}
112+
113+
std::shared_ptr<torch::jit::Graph>& g() {
114+
return g_;
115+
}
116+
117+
void update_graph(std::shared_ptr<torch::jit::Graph> new_g) {
118+
g_ = new_g;
119+
}
120+
121+
void update_target(SegmentedBlockTarget new_target) {
122+
target_ = new_target;
123+
}
124+
125+
private:
126+
SegmentedBlockTarget target_;
127+
std::vector<nvinfer1::Dims> in_shape_;
128+
std::vector<torch::jit::Value*> inputs_;
129+
std::vector<torch::jit::Value*> outputs_;
130+
std::vector<torch::jit::Node*> nodes_;
131+
std::shared_ptr<torch::jit::Graph> g_;
132+
std::string trt_engine;
133+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
134+
};
135+
136+
// Splits graph `g` into an ordered sequence of contiguous blocks, each
// targeted at TensorRT (convertible ops) or Torch (fallback ops), honoring
// fallback_info's min_block_size and forced_fallback_operators.
std::vector<SegmentedBlock> segment_graph(
    std::shared_ptr<torch::jit::Graph> g,
    const conversion::TorchFallback& fallback_info);
139+
140+
} // namespace partitioning
141+
} // namespace core
142+
} // namespace trtorch

0 commit comments

Comments
 (0)