-
Notifications
You must be signed in to change notification settings - Fork 354
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: reorganize code structure initially
Signed-off-by: Bo Wang <wangbo1995ee@163.com>
- Loading branch information
Showing
7 changed files
with
423 additions
and
371 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
#include "SegmentedBlock.h" | ||
|
||
namespace trtorch { | ||
namespace core { | ||
namespace partitioning { | ||
|
||
torch::jit::Value* getOrAddInputForValue( | ||
torch::jit::Value* old_value, | ||
std::shared_ptr<torch::jit::Graph>& graph, | ||
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) { | ||
if (old_to_new.count(old_value) == 0) { | ||
auto node = old_value->node(); | ||
|
||
if (node->kind() == torch::jit::prim::Constant) { | ||
auto new_const = graph->createClone(node, {nullptr}); | ||
graph->block()->prependNode(new_const); | ||
return new_const->output(); | ||
} | ||
auto new_value = graph->block()->addInput(); | ||
old_to_new[old_value] = new_value; | ||
new_value->copyMetadata(old_value); | ||
// mapping from new graph input Values to original graph values | ||
old_to_new[new_value] = old_value; | ||
return new_value; | ||
} else { | ||
return old_to_new[old_value]; | ||
} | ||
} | ||
|
||
torch::jit::Node* cloneNode( | ||
torch::jit::Node* node, | ||
std::shared_ptr<torch::jit::Graph>& graph, | ||
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) { | ||
auto* block = graph->block(); | ||
auto env = [&](torch::jit::Value* v) { return getOrAddInputForValue(v, graph, old_to_new); }; | ||
|
||
// create node for current graph by using the metadata in node and input Values in env | ||
auto new_node = block->appendNode(graph->createClone(node, env)); | ||
for (size_t i = 0; i < node->outputs().size(); ++i) { | ||
auto oo = node->outputs()[i]; | ||
auto no = new_node->outputs()[i]; | ||
old_to_new[oo] = no; | ||
} | ||
return new_node; | ||
} | ||
|
||
std::vector<SegmentedBlock> segment_graph( | ||
std::shared_ptr<torch::jit::Graph> g, | ||
const conversion::TorchFallback& fallback_info) { | ||
auto min_block_size = fallback_info.min_block_size; | ||
std::unordered_set<std::string> forced_fallback_operators( | ||
fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end()); | ||
|
||
auto nodes = g->block()->nodes(); | ||
std::vector<SegmentedBlock> segmented_blocks; | ||
|
||
// segment the nodes | ||
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes; | ||
for (const auto n : nodes) { | ||
if (n->kind() == torch::jit::prim::Constant) | ||
continue; | ||
|
||
std::string node_string(n->kind().toQualString()); | ||
if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) { | ||
tensorrt_nodes.push_back(n); | ||
if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) { | ||
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes); | ||
pytorch_nodes.clear(); | ||
} | ||
} else { | ||
if (tensorrt_nodes.size() >= min_block_size) { | ||
segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes); | ||
} else { | ||
pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end()); | ||
} | ||
tensorrt_nodes.clear(); | ||
pytorch_nodes.push_back(n); | ||
} | ||
} | ||
|
||
// if there is any kTorch nodes left, then either the last nodes are kTorch or last nodes are kTensorRT but num < | ||
// min_block_size | ||
if (!pytorch_nodes.empty()) { | ||
pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end()); | ||
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes); | ||
} else { | ||
segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes); | ||
} | ||
|
||
return std::move(segmented_blocks); | ||
} | ||
|
||
} // namespace partitioning | ||
} // namespace core | ||
} // namespace trtorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
#pragma once | ||
|
||
#include <vector> | ||
|
||
#include "core/conversion/conversion.h" | ||
#include "torch/csrc/jit/ir/ir.h" | ||
|
||
namespace trtorch { | ||
namespace core { | ||
namespace partitioning { | ||
|
||
torch::jit::Value* getOrAddInputForValue( | ||
torch::jit::Value* old_value, | ||
std::shared_ptr<torch::jit::Graph>& graph, | ||
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new); | ||
|
||
torch::jit::Node* cloneNode( | ||
torch::jit::Node* node, | ||
std::shared_ptr<torch::jit::Graph>& graph, | ||
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new); | ||
|
||
struct SegmentedBlock { | ||
public: | ||
enum SegmentedBlockTarget { | ||
kTorch, | ||
kTensorRT, | ||
}; | ||
|
||
SegmentedBlock() = default; | ||
|
||
SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {} | ||
|
||
SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes) | ||
: target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) { | ||
for (auto& node : nodes) { | ||
nodes_.push_back(node); | ||
appendNode(node); | ||
} | ||
registerInputs(); | ||
} | ||
|
||
SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {} | ||
|
||
enum SegmentedBlockTarget target() { | ||
return target_; | ||
} | ||
|
||
void appendNode(torch::jit::Node* n) { | ||
cloneNode(n, g_, old_to_new_); | ||
} | ||
|
||
void registerInputs() { | ||
for (auto& value : g_->inputs()) { | ||
inputs_.push_back(old_to_new_[value]); | ||
} | ||
} | ||
|
||
void registerOutput(torch::jit::Value* raw_output) { | ||
outputs_.push_back(raw_output); | ||
g_->registerOutput(old_to_new_[raw_output]); | ||
} | ||
|
||
torch::jit::Block* block() { | ||
return g_->block(); | ||
} | ||
|
||
c10::ArrayRef<torch::jit::Value*> inputs() { | ||
return g_->inputs(); | ||
} | ||
|
||
void eraseInput(size_t i) { | ||
inputs_.erase(inputs_.begin() + i); | ||
g_->eraseInput(i); | ||
} | ||
|
||
c10::ArrayRef<torch::jit::Value*> outputs() { | ||
return g_->outputs(); | ||
} | ||
|
||
void eraseOutput(size_t i) { | ||
outputs_.erase(outputs_.begin() + i); | ||
g_->eraseOutput(i); | ||
} | ||
|
||
const std::vector<torch::jit::Value*>& raw_inputs() const { | ||
return inputs_; | ||
} | ||
|
||
const std::vector<torch::jit::Value*>& raw_outputs() const { | ||
return outputs_; | ||
} | ||
|
||
const std::vector<torch::jit::Node*>& raw_nodes() const { | ||
return nodes_; | ||
} | ||
|
||
bool contain_raw_value(torch::jit::Value* input) { | ||
return old_to_new_.count(input); | ||
} | ||
|
||
torch::jit::graph_node_list nodes() { | ||
return g_->nodes(); | ||
} | ||
|
||
void register_inshape(std::vector<nvinfer1::Dims>& in_shape) { | ||
in_shape_ = in_shape; | ||
} | ||
|
||
const std::vector<nvinfer1::Dims>& in_shape() const { | ||
return in_shape_; | ||
} | ||
|
||
std::shared_ptr<torch::jit::Graph>& g() { | ||
return g_; | ||
} | ||
|
||
void update_graph(std::shared_ptr<torch::jit::Graph> new_g) { | ||
g_ = new_g; | ||
} | ||
|
||
void update_target(SegmentedBlockTarget new_target) { | ||
target_ = new_target; | ||
} | ||
|
||
private: | ||
SegmentedBlockTarget target_; | ||
std::vector<nvinfer1::Dims> in_shape_; | ||
std::vector<torch::jit::Value*> inputs_; | ||
std::vector<torch::jit::Value*> outputs_; | ||
std::vector<torch::jit::Node*> nodes_; | ||
std::shared_ptr<torch::jit::Graph> g_; | ||
std::string trt_engine; | ||
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_; | ||
}; | ||
|
||
std::vector<SegmentedBlock> segment_graph( | ||
std::shared_ptr<torch::jit::Graph> g, | ||
const conversion::TorchFallback& fallback_info); | ||
|
||
} // namespace partitioning | ||
} // namespace core | ||
} // namespace trtorch |
Oops, something went wrong.