Skip to content

Commit

Permalink
chore: reorganize code structure initially
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Wang <wangbo1995ee@163.com>
  • Loading branch information
bowang007 committed Mar 30, 2021
1 parent cfc68ce commit 965a67a
Show file tree
Hide file tree
Showing 7 changed files with 423 additions and 371 deletions.
4 changes: 4 additions & 0 deletions core/partitioning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ config_setting(
cc_library(
name = "partitioning",
hdrs = [
"SegmentedBlock.h",
"shape_analysis.h",
"partitioning.h",
],
srcs = [
"SegmentedBlock.cpp",
"shape_analysis.cpp",
"partitioning.cpp",
],
deps = [
Expand Down
95 changes: 95 additions & 0 deletions core/partitioning/SegmentedBlock.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#include "SegmentedBlock.h"

namespace trtorch {
namespace core {
namespace partitioning {

// Returns the Value in `graph` that corresponds to `old_value` from the
// original graph, creating it on first use:
//   - constants are cloned directly into `graph` (prepended to its block);
//   - any other Value becomes a new graph input.
// `old_to_new` memoizes the correspondence so each original Value maps to
// exactly one Value in the new graph.
torch::jit::Value* getOrAddInputForValue(
    torch::jit::Value* old_value,
    std::shared_ptr<torch::jit::Graph>& graph,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) {
  if (old_to_new.count(old_value) == 0) {
    auto node = old_value->node();

    if (node->kind() == torch::jit::prim::Constant) {
      auto new_const = graph->createClone(node, {nullptr});
      graph->block()->prependNode(new_const);
      // Memoize the cloned constant: without this entry, every later
      // reference to the same constant Value cloned a duplicate Constant
      // node into the graph.
      old_to_new[old_value] = new_const->output();
      return new_const->output();
    }
    auto new_value = graph->block()->addInput();
    old_to_new[old_value] = new_value;
    new_value->copyMetadata(old_value);
    // mapping from new graph input Values to original graph values
    old_to_new[new_value] = old_value;
    return new_value;
  } else {
    return old_to_new[old_value];
  }
}

// Clones `node` into `graph`, rewriting its inputs to Values that live in
// `graph` (adding graph inputs / cloned constants on demand) and recording
// the output correspondence in `old_to_new`. Returns the cloned node.
torch::jit::Node* cloneNode(
    torch::jit::Node* node,
    std::shared_ptr<torch::jit::Graph>& graph,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) {
  // Resolver handed to createClone: maps each source-graph input Value to
  // its counterpart in the target graph.
  auto resolve = [&](torch::jit::Value* v) { return getOrAddInputForValue(v, graph, old_to_new); };

  torch::jit::Node* cloned = graph->block()->appendNode(graph->createClone(node, resolve));

  // Pair up original outputs with the clone's outputs so later nodes can
  // find them through old_to_new.
  auto src_outputs = node->outputs();
  auto dst_outputs = cloned->outputs();
  for (size_t idx = 0; idx < src_outputs.size(); ++idx) {
    old_to_new[src_outputs[idx]] = dst_outputs[idx];
  }
  return cloned;
}

// Partitions the nodes of `g` into contiguous SegmentedBlocks that run
// either in TensorRT (supported ops, runs of at least min_block_size) or
// in Torch (unsupported ops, forced-fallback ops, and TensorRT runs too
// short to stand alone). prim::Constant nodes are skipped; they are cloned
// into blocks on demand by getOrAddInputForValue.
std::vector<SegmentedBlock> segment_graph(
    std::shared_ptr<torch::jit::Graph> g,
    const conversion::TorchFallback& fallback_info) {
  auto min_block_size = fallback_info.min_block_size;
  std::unordered_set<std::string> forced_fallback_operators(
      fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end());

  auto nodes = g->block()->nodes();
  std::vector<SegmentedBlock> segmented_blocks;

  // Accumulate runs of same-target nodes; flush a run into a block when the
  // target switches (or when a TensorRT run becomes long enough to commit
  // the preceding Torch run).
  std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
  for (const auto n : nodes) {
    if (n->kind() == torch::jit::prim::Constant)
      continue;

    std::string node_string(n->kind().toQualString());
    if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
      tensorrt_nodes.push_back(n);
      // Once the TensorRT run is viable, the pending Torch run is final.
      if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) {
        segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
        pytorch_nodes.clear();
      }
    } else {
      if (tensorrt_nodes.size() >= min_block_size) {
        segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
      } else {
        // Too short for a TensorRT block: fold the run back into Torch.
        pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
      }
      tensorrt_nodes.clear();
      pytorch_nodes.push_back(n);
    }
  }

  // if there is any kTorch nodes left, then either the last nodes are kTorch or last nodes are kTensorRT but num <
  // min_block_size
  if (!pytorch_nodes.empty()) {
    pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
    segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
  } else if (!tensorrt_nodes.empty()) {
    // Guard against a graph with no segmentable nodes (e.g. constants only):
    // previously an empty kTensorRT block was emitted unconditionally here.
    segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
  }

  // Return by value: NRVO/move applies; `return std::move(local)` would
  // inhibit copy elision.
  return segmented_blocks;
}

} // namespace partitioning
} // namespace core
} // namespace trtorch
142 changes: 142 additions & 0 deletions core/partitioning/SegmentedBlock.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#pragma once

#include <vector>

#include "core/conversion/conversion.h"
#include "torch/csrc/jit/ir/ir.h"

namespace trtorch {
namespace core {
namespace partitioning {

// Returns the Value in `graph` corresponding to `old_value`, creating a new
// graph input (or cloning a prim::Constant into the graph) on first use and
// recording the correspondence in `old_to_new`.
torch::jit::Value* getOrAddInputForValue(
torch::jit::Value* old_value,
std::shared_ptr<torch::jit::Graph>& graph,
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new);

// Clones `node` into `graph`, resolving its inputs via getOrAddInputForValue
// and recording each output's clone in `old_to_new`. Returns the cloned node.
torch::jit::Node* cloneNode(
torch::jit::Node* node,
std::shared_ptr<torch::jit::Graph>& graph,
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new);

// A contiguous segment of a TorchScript graph destined for one execution
// target (Torch or TensorRT). Owns a private sub-graph (g_) built by cloning
// the segment's nodes; old_to_new_ maps original-graph Values to their
// clones in g_ (and new graph inputs back to their originals).
struct SegmentedBlock {
 public:
  enum SegmentedBlockTarget {
    kTorch,
    kTensorRT,
  };

  SegmentedBlock() = default;

  SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}

  // Builds the sub-graph by cloning `nodes` (in order) into a fresh graph,
  // then records the original Values behind the new graph inputs.
  SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes)
      : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
    for (auto& node : nodes) {
      nodes_.push_back(node);
      appendNode(node);
    }
    registerInputs();
  }

  SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}

  enum SegmentedBlockTarget target() {
    return target_;
  }

  // Clones `n` into this block's graph, adding graph inputs as needed.
  void appendNode(torch::jit::Node* n) {
    cloneNode(n, g_, old_to_new_);
  }

  // For each input of the cloned graph, records the original-graph Value it
  // stands for (old_to_new_ maps new inputs back to originals).
  void registerInputs() {
    for (auto& value : g_->inputs()) {
      inputs_.push_back(old_to_new_[value]);
    }
  }

  // Registers an output on the cloned graph; `raw_output` is the Value in
  // the original graph, which is translated through old_to_new_.
  void registerOutput(torch::jit::Value* raw_output) {
    outputs_.push_back(raw_output);
    g_->registerOutput(old_to_new_[raw_output]);
  }

  torch::jit::Block* block() {
    return g_->block();
  }

  c10::ArrayRef<torch::jit::Value*> inputs() {
    return g_->inputs();
  }

  // Erases input i from both the raw-input bookkeeping and the graph.
  void eraseInput(size_t i) {
    inputs_.erase(inputs_.begin() + i);
    g_->eraseInput(i);
  }

  c10::ArrayRef<torch::jit::Value*> outputs() {
    return g_->outputs();
  }

  // Erases output i from both the raw-output bookkeeping and the graph.
  void eraseOutput(size_t i) {
    outputs_.erase(outputs_.begin() + i);
    g_->eraseOutput(i);
  }

  // Values in the ORIGINAL graph corresponding to this block's inputs.
  const std::vector<torch::jit::Value*>& raw_inputs() const {
    return inputs_;
  }

  // Values in the ORIGINAL graph corresponding to this block's outputs.
  const std::vector<torch::jit::Value*>& raw_outputs() const {
    return outputs_;
  }

  // Nodes of the ORIGINAL graph that make up this segment.
  const std::vector<torch::jit::Node*>& raw_nodes() const {
    return nodes_;
  }

  // True if `input` (an original-graph Value) already has a counterpart in
  // this block's graph.
  bool contain_raw_value(torch::jit::Value* input) {
    return old_to_new_.count(input);
  }

  torch::jit::graph_node_list nodes() {
    return g_->nodes();
  }

  // Const ref: the shapes are only copied, never mutated.
  void register_inshape(const std::vector<nvinfer1::Dims>& in_shape) {
    in_shape_ = in_shape;
  }

  const std::vector<nvinfer1::Dims>& in_shape() const {
    return in_shape_;
  }

  std::shared_ptr<torch::jit::Graph>& g() {
    return g_;
  }

  void update_graph(std::shared_ptr<torch::jit::Graph> new_g) {
    g_ = new_g;
  }

  void update_target(SegmentedBlockTarget new_target) {
    target_ = new_target;
  }

 private:
  // Default to kTorch so a default-constructed block never carries an
  // indeterminate target (previously uninitialized under `= default`).
  SegmentedBlockTarget target_ = kTorch;
  std::vector<nvinfer1::Dims> in_shape_;
  std::vector<torch::jit::Value*> inputs_;
  std::vector<torch::jit::Value*> outputs_;
  std::vector<torch::jit::Node*> nodes_;
  std::shared_ptr<torch::jit::Graph> g_;
  std::string trt_engine_; // renamed from trt_engine for member-naming consistency (private; no external users)
  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
};

// Partitions the nodes of `g` into contiguous SegmentedBlocks targeting
// Torch or TensorRT, honoring fallback_info's min_block_size and
// forced-fallback operator list.
std::vector<SegmentedBlock> segment_graph(
std::shared_ptr<torch::jit::Graph> g,
const conversion::TorchFallback& fallback_info);

} // namespace partitioning
} // namespace core
} // namespace trtorch
Loading

0 comments on commit 965a67a

Please sign in to comment.