From 965a67a2b819cb53143f08452f5bd0146c70884a Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Tue, 30 Mar 2021 18:30:16 -0500
Subject: [PATCH] chore: reorganize code structure initially

Signed-off-by: Bo Wang
---
 core/partitioning/BUILD              |   4 +
 core/partitioning/SegmentedBlock.cpp |  95 +++++++++
 core/partitioning/SegmentedBlock.h   | 142 +++++++++++++
 core/partitioning/partitioning.cpp   | 299 +++++++++----------------------
 core/partitioning/partitioning.h     | 127 +-----------
 core/partitioning/shape_analysis.cpp | 112 ++++++++++
 core/partitioning/shape_analysis.h   |  15 ++
 7 files changed, 423 insertions(+), 371 deletions(-)
 create mode 100644 core/partitioning/SegmentedBlock.cpp
 create mode 100644 core/partitioning/SegmentedBlock.h
 create mode 100644 core/partitioning/shape_analysis.cpp
 create mode 100644 core/partitioning/shape_analysis.h

diff --git a/core/partitioning/BUILD b/core/partitioning/BUILD
index 0d8b2006a7..0f21667718 100644
--- a/core/partitioning/BUILD
+++ b/core/partitioning/BUILD
@@ -10,9 +10,13 @@ config_setting(
 cc_library(
     name = "partitioning",
     hdrs = [
+        "SegmentedBlock.h",
+        "shape_analysis.h",
         "partitioning.h",
     ],
     srcs = [
+        "SegmentedBlock.cpp",
+        "shape_analysis.cpp",
         "partitioning.cpp",
     ],
     deps = [
diff --git a/core/partitioning/SegmentedBlock.cpp b/core/partitioning/SegmentedBlock.cpp
new file mode 100644
index 0000000000..015baf1e23
--- /dev/null
+++ b/core/partitioning/SegmentedBlock.cpp
@@ -0,0 +1,95 @@
+#include "SegmentedBlock.h"
+
+namespace trtorch {
+namespace core {
+namespace partitioning {
+
+torch::jit::Value* getOrAddInputForValue(
+    torch::jit::Value* old_value,
+    std::shared_ptr<torch::jit::Graph>& graph,
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) {
+  if (old_to_new.count(old_value) == 0) {
+    auto node = old_value->node();
+
+    if (node->kind() == torch::jit::prim::Constant) {
+      auto new_const = graph->createClone(node, {nullptr});
+      graph->block()->prependNode(new_const);
+      return new_const->output();
+    }
+    auto new_value = graph->block()->addInput();
+    old_to_new[old_value] = new_value;
+    new_value->copyMetadata(old_value);
+    // mapping from new graph input Values to original graph values
+    old_to_new[new_value] = old_value;
+    return new_value;
+  } else {
+    return old_to_new[old_value];
+  }
+}
+
+torch::jit::Node* cloneNode(
+    torch::jit::Node* node,
+    std::shared_ptr<torch::jit::Graph>& graph,
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) {
+  auto* block = graph->block();
+  auto env = [&](torch::jit::Value* v) { return getOrAddInputForValue(v, graph, old_to_new); };
+
+  // create node for current graph by using the metadata in node and input Values in env
+  auto new_node = block->appendNode(graph->createClone(node, env));
+  for (size_t i = 0; i < node->outputs().size(); ++i) {
+    auto oo = node->outputs()[i];
+    auto no = new_node->outputs()[i];
+    old_to_new[oo] = no;
+  }
+  return new_node;
+}
+
+std::vector<SegmentedBlock> segment_graph(
+    std::shared_ptr<torch::jit::Graph> g,
+    const conversion::TorchFallback& fallback_info) {
+  auto min_block_size = fallback_info.min_block_size;
+  std::unordered_set<std::string> forced_fallback_operators(
+      fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end());
+
+  auto nodes = g->block()->nodes();
+  std::vector<SegmentedBlock> segmented_blocks;
+
+  // segment the nodes
+  std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
+  for (const auto n : nodes) {
+    if (n->kind() == torch::jit::prim::Constant)
+      continue;
+
+    std::string node_string(n->kind().toQualString());
+    if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
+      tensorrt_nodes.push_back(n);
+      if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) {
+        segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+        pytorch_nodes.clear();
+      }
+    } else {
+      if (tensorrt_nodes.size() >= min_block_size) {
+        segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+      } else {
+        pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
+      }
+      tensorrt_nodes.clear();
+      pytorch_nodes.push_back(n);
+    }
+  }
+
+  // if there are any kTorch nodes left, then either the last nodes are kTorch, or the last nodes are kTensorRT but
+  // fewer than min_block_size
+  if (!pytorch_nodes.empty()) {
+    pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
+    segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+  } else {
+    segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+  }
+
+  return std::move(segmented_blocks);
+}
+
+} // namespace partitioning
+} // namespace core
+} // namespace trtorch
\ No newline at end of file
diff --git a/core/partitioning/SegmentedBlock.h b/core/partitioning/SegmentedBlock.h
new file mode 100644
index 0000000000..7c25e39acd
--- /dev/null
+++ b/core/partitioning/SegmentedBlock.h
@@ -0,0 +1,142 @@
+#pragma once
+
+#include <vector>
+
+#include "core/conversion/conversion.h"
+#include "torch/csrc/jit/ir/ir.h"
+
+namespace trtorch {
+namespace core {
+namespace partitioning {
+
+torch::jit::Value* getOrAddInputForValue(
+    torch::jit::Value* old_value,
+    std::shared_ptr<torch::jit::Graph>& graph,
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new);
+
+torch::jit::Node* cloneNode(
+    torch::jit::Node* node,
+    std::shared_ptr<torch::jit::Graph>& graph,
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new);
+
+struct SegmentedBlock {
+ public:
+  enum SegmentedBlockTarget {
+    kTorch,
+    kTensorRT,
+  };
+
+  SegmentedBlock() = default;
+
+  SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
+
+  SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes)
+      : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
+    for (auto& node : nodes) {
+      nodes_.push_back(node);
+      appendNode(node);
+    }
+    registerInputs();
+  }
+
+  SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
+
+  enum SegmentedBlockTarget target() {
+    return target_;
+  }
+
+  void appendNode(torch::jit::Node* n) {
+    cloneNode(n, g_, old_to_new_);
+  }
+
+  void registerInputs() {
+    for (auto& value : g_->inputs()) {
+      inputs_.push_back(old_to_new_[value]);
+    }
+  }
+
+  void registerOutput(torch::jit::Value* raw_output) {
+    outputs_.push_back(raw_output);
+    g_->registerOutput(old_to_new_[raw_output]);
+  }
+
+  torch::jit::Block* block() {
+    return g_->block();
+  }
+
+  c10::ArrayRef<torch::jit::Value*> inputs() {
+    return g_->inputs();
+  }
+
+  void eraseInput(size_t i) {
+    inputs_.erase(inputs_.begin() + i);
+    g_->eraseInput(i);
+  }
+
+  c10::ArrayRef<torch::jit::Value*> outputs() {
+    return g_->outputs();
+  }
+
+  void eraseOutput(size_t i) {
+    outputs_.erase(outputs_.begin() + i);
+    g_->eraseOutput(i);
+  }
+
+  const std::vector<torch::jit::Value*>& raw_inputs() const {
+    return inputs_;
+  }
+
+  const std::vector<torch::jit::Value*>& raw_outputs() const {
+    return outputs_;
+  }
+
+  const std::vector<torch::jit::Node*>& raw_nodes() const {
+    return nodes_;
+  }
+
+  bool contain_raw_value(torch::jit::Value* input) {
+    return old_to_new_.count(input);
+  }
+
+  torch::jit::graph_node_list nodes() {
+    return g_->nodes();
+  }
+
+  void register_inshape(std::vector<nvinfer1::Dims>& in_shape) {
+    in_shape_ = in_shape;
+  }
+
+  const std::vector<nvinfer1::Dims>& in_shape() const {
+    return in_shape_;
+  }
+
+  std::shared_ptr<torch::jit::Graph>& g() {
+    return g_;
+  }
+
+  void update_graph(std::shared_ptr<torch::jit::Graph> new_g) {
+    g_ = new_g;
+  }
+
+  void update_target(SegmentedBlockTarget new_target) {
+    target_ = new_target;
+  }
+
+ private:
+  SegmentedBlockTarget target_;
+  std::vector<nvinfer1::Dims> in_shape_;
+  std::vector<torch::jit::Value*> inputs_;
+  std::vector<torch::jit::Value*> outputs_;
+  std::vector<torch::jit::Node*> nodes_;
+  std::shared_ptr<torch::jit::Graph> g_;
+  std::string trt_engine;
+  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
+};
+
+std::vector<SegmentedBlock> segment_graph(
+    std::shared_ptr<torch::jit::Graph> g,
+    const conversion::TorchFallback& fallback_info);
+
+} // namespace partitioning
+} // namespace core
+} // namespace trtorch
\ No newline at end of file
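Note: to make the intent of the two new files concrete, here is a minimal sketch of how segment_graph() and SegmentedBlock are meant to be driven. The driver itself is hypothetical (it is not part of this patch), and the fallback settings are example values only:

    // Hypothetical driver for the new SegmentedBlock API (illustration only).
    #include <iostream>
    #include "core/partitioning/SegmentedBlock.h"

    using namespace trtorch::core;

    void dump_segments(std::shared_ptr<torch::jit::Graph> g) {
      conversion::TorchFallback fallback_info;
      fallback_info.min_block_size = 3;                         // TRT runs shorter than this fold into Torch blocks
      fallback_info.forced_fallback_operators = {"aten::size"}; // example op forced to run in Torch

      auto blocks = partitioning::segment_graph(g, fallback_info);
      for (auto& block : blocks) {
        bool is_trt = block.target() == partitioning::SegmentedBlock::kTensorRT;
        // raw_nodes() still points at the original graph's nodes; g() is the per-segment clone
        std::cout << (is_trt ? "TensorRT" : "Torch") << " segment with " << block.raw_nodes().size() << " nodes\n";
      }
    }
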
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index 9d2c143ce1..3de69c2464 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -1,9 +1,7 @@
 #include "partitioning.h"
+
 #include <queue>
-#include "core/conversion/evaluators/eval_util.h"
-#include "core/lowering/passes/passes.h"
-#include "core/util/prelude.h"
-#include "torch/csrc/jit/api/module.h"
+#include "shape_analysis.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"
 
 namespace trtorch {
 namespace core {
 namespace partitioning {
@@ -21,204 +19,6 @@ struct usage_info {
   std::vector<size_t> tensorrt_use_id;
 };
 
-torch::jit::Value* getOrAddInputForValue(
-    torch::jit::Value* old_value,
-    std::shared_ptr<torch::jit::Graph>& graph,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) {
-  if (old_to_new.count(old_value) == 0) {
-    auto node = old_value->node();
-
-    if (node->kind() == torch::jit::prim::Constant) {
-      auto new_const = graph->createClone(node, {nullptr});
-      graph->block()->prependNode(new_const);
-      return new_const->output();
-    }
-    auto new_value = graph->block()->addInput();
-    old_to_new[old_value] = new_value;
-    new_value->copyMetadata(old_value);
-    // mapping from new graph input Values to original graph values
-    old_to_new[new_value] = old_value;
-    return new_value;
-  } else {
-    return old_to_new[old_value];
-  }
-}
-
-torch::jit::Node* cloneNode(
-    torch::jit::Node* node,
-    std::shared_ptr<torch::jit::Graph>& graph,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) {
-  auto* block = graph->block();
-  auto env = [&](torch::jit::Value* v) { return getOrAddInputForValue(v, graph, old_to_new); };
-
-  // create node for current graph by using the metadata in node and input Values in env
-  auto new_node = block->appendNode(graph->createClone(node, env));
-  for (size_t i = 0; i < node->outputs().size(); ++i) {
-    auto oo = node->outputs()[i];
-    auto no = new_node->outputs()[i];
-    old_to_new[oo] = no;
-  }
-  return new_node;
-}
-
-c10::FunctionSchema getFunctionSchema(std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
-  std::vector<c10::Argument> args;
-  for (auto in : g->inputs()) {
-    args.push_back(c10::Argument(in->debugName(), in->type()));
-  }
-
-  std::vector<c10::Argument> returns;
-  for (auto out : g->outputs()) {
-    returns.push_back(c10::Argument(out->debugName(), out->type()));
-  }
-
-  return c10::FunctionSchema(method_name, method_name, args, returns);
-}
-
-void registerSegmentInOutIValues(
-    SegmentedBlock& seg_block,
-    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
-  // create a module to run the graph
-  auto g = seg_block.g();
-  auto copy_g = g->copy();
-
-  // create tuple for multiple outputs
-  if (seg_block.raw_outputs().size() > 1) {
-    auto new_output_node = copy_g->appendNode(copy_g->createTuple(copy_g->outputs()));
-    for (int idx = copy_g->outputs().size() - 1; idx >= 0; --idx) {
-      copy_g->eraseOutput(idx);
-    }
-    copy_g->registerOutput(new_output_node->outputs()[0]);
-  }
-
-  torch::jit::script::Module cur_mod(c10::QualifiedName("module"));
-
-  auto self = copy_g->insertInput(0, "self_1");
-  self->setType(cur_mod.type());
-
-  auto cur_method = cur_mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), copy_g);
-  auto schema = getFunctionSchema(cur_method->name(), copy_g);
-  cur_mod.type()->addMethod(cur_method);
-  cur_method->setSchema(schema);
-
-  std::vector<torch::jit::IValue> jit_inputs_ivalues;
-
-  // set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments
-  for (auto& input : seg_block.raw_inputs()) {
-    TRTORCH_CHECK(ivalues_maps.count(input), "Could not find mini graph input IValue " << input->debugName());
-    if (input->node()->kind() == torch::jit::prim::Param) {
-      jit_inputs_ivalues.push_back(ivalues_maps[input]);
-    } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
-      jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor());
-    } else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) {
-      jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());
-    } else if (input->type()->isSubtypeOf(torch::jit::BoolType::get())) {
-      jit_inputs_ivalues.push_back(ivalues_maps[input].toBool());
-    } else if (input->type()->kind() == torch::jit::TypeKind::ListType) {
-      jit_inputs_ivalues.push_back(ivalues_maps[input].toList());
-    } else {
-      TRTORCH_CHECK(input->type()->kind() == torch::jit::TypeKind::TupleType, "Input for mini graph is not TupleType.");
-      jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple());
-    }
-  }
-
-  // run segments to get outputs for later segments input shape, and other arguments such as Int
-  std::vector<torch::jit::IValue> jit_results;
-  torch::jit::IValue jit_results_ivalues = cur_mod.forward(jit_inputs_ivalues);
-
-  if (jit_results_ivalues.isTuple()) {
-    auto results = jit_results_ivalues.toTuple()->elements();
-    for (auto r : results) {
-      jit_results.push_back(r);
-    }
-  } else {
-    jit_results.push_back(jit_results_ivalues);
-  }
-
-  size_t idx = 0;
-  for (auto& output : seg_block.raw_outputs()) {
-    ivalues_maps[output] = jit_results[idx++];
-  }
-
-  // set input shape for each segmented block so we wil use it in conversion process
-  std::vector<nvinfer1::Dims> input_shape;
-  for (auto& i : seg_block.raw_inputs()) {
-    if (ivalues_maps[i].isTensor()) {
-      input_shape.push_back(util::toDims(ivalues_maps[i].toTensor().sizes()));
-    }
-  }
-
-  seg_block.register_inshape(input_shape);
-}
-
-std::vector<torch::jit::IValue> generateRandomInputs(std::vector<conversion::InputRange>& input_ranges) {
-  // generate random inputs for running pytorch segments
-  std::vector<torch::jit::IValue> random_inputs;
-  for (auto& input_range : input_ranges) {
-    auto cur_shape = input_range.input_shape;
-    std::vector<int64_t> shape;
-    shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
-    auto in = at::randint(5, shape, {at::kCUDA});
-    random_inputs.push_back(in.clone());
-  }
-  return random_inputs;
-}
-
-void registerSegmentsOutputs(std::vector<SegmentedBlock>& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
-  // find the corresponding raw values in original global graph for this segmented block's inputs/outputs
-  std::set<torch::jit::Value*> input_values;
-  for (auto& seg_block : segmented_blocks) {
-    for (auto& input : seg_block.raw_inputs()) {
-      input_values.insert(input);
-    }
-  }
-
-  for (auto& graph_output : g->outputs()) {
-    input_values.insert(graph_output);
-  }
-
-  // should be careful here because some in-place operations don't return any values, there is no output for this kind
-  // of segment identify the output for each mini-graph by checking if any value in this graph is used later we
-  // shouldn't register nonTensor output for TensorRT segments
-  for (auto& seg_block : segmented_blocks) {
-    for (auto& mini_graph_input : input_values) {
-      if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
-              seg_block.raw_inputs().end() &&
-          seg_block.contain_raw_value(mini_graph_input)) {
-        if (!isTensorOrTensorList(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT)
-          continue;
-        seg_block.registerOutput(mini_graph_input);
-      }
-    }
-    // if no output, then register the last node's output as current graph's output
-    if (seg_block.raw_outputs().empty()) {
-      // for Torch segments, register input as output
-      if (seg_block.target() == SegmentedBlock::kTorch) {
-        seg_block.registerOutput(seg_block.raw_inputs()[0]);
-      } else {
-        // for TensorRT segments, register last nonInput Tensor outputs
-        for (int i = seg_block.raw_nodes().size() - 1; i >= 0; --i) {
-          for (auto node_output : seg_block.raw_nodes()[i]->outputs()) {
-            if (isTensorOrTensorList(node_output))
-              seg_block.registerOutput(node_output);
-          }
-          if (!seg_block.raw_outputs().empty())
-            break;
-        }
-      }
-    }
-  }
-  // erase segments which still have no output
-  segmented_blocks.erase(
-      std::remove_if(
-          segmented_blocks.begin(),
-          segmented_blocks.end(),
-          [](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }),
-      segmented_blocks.end());
-
-  return;
-}
-
 std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
   // using bfs to get the DAG dependency nodes for input value
   std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
@@ -252,7 +52,7 @@ SegmentedBlock injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
   }
   std::vector<torch::jit::Node*> new_block_nodes = getDependencyNodes(nontensor_inputs);
   new_block_nodes.insert(new_block_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
-  return SegmentedBlock(seg_block.target(), new_block_nodes);
+  return std::move(SegmentedBlock(seg_block.target(), new_block_nodes));
 }
 
 void resolveNonTensorInputs(std::vector<SegmentedBlock>& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
@@ -297,53 +97,59 @@ void resolveNonTensorInputs(std::vector<SegmentedBlock>& segmented_blocks, std::
   return;
 }
 
-void construct_segments(
-    std::vector<torch::jit::Node*>& pytorch_nodes,
-    std::vector<torch::jit::Node*>& tensorrt_nodes,
-    std::vector<SegmentedBlock>& segmented_blocks,
-    size_t min_block_size) {
-  // construct segmented blocks according to min_block_size and consecutive nodes
-  if (!tensorrt_nodes.empty()) {
-    if (tensorrt_nodes.size() < min_block_size) {
-      pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
-    } else {
-      if (!pytorch_nodes.empty())
-        segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
-      segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
-      pytorch_nodes.clear();
+void registerSegmentsOutputs(std::vector<SegmentedBlock>& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
+  // find the corresponding raw values in original global graph for this segmented block's inputs/outputs
+  std::set<torch::jit::Value*> input_values;
+  for (auto& seg_block : segmented_blocks) {
+    for (auto& input : seg_block.raw_inputs()) {
+      input_values.insert(input);
     }
-    tensorrt_nodes.clear();
   }
-}
 
-void segment_graph(
-    std::shared_ptr<torch::jit::Graph> g,
-    const conversion::TorchFallback& fallback_info,
-    std::vector<SegmentedBlock>& segmented_blocks) {
-  auto min_block_size = fallback_info.min_block_size;
-  std::unordered_set<std::string> forced_fallback_operators(
-      fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end());
-
-  auto nodes = g->block()->nodes();
-
-  // segment the nodes
-  std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
-  for (const auto n : nodes) {
-    if (n->kind() == torch::jit::prim::Constant)
-      continue;
+  for (auto& graph_output : g->outputs()) {
+    input_values.insert(graph_output);
+  }
 
-    std::string node_string(n->kind().toQualString());
-    if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
-      tensorrt_nodes.push_back(n);
-    } else {
-      construct_segments(pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size);
-      pytorch_nodes.push_back(n);
+  // Be careful here: some in-place operations don't return any values, so there may be no output for this kind of
+  // segment. Identify the output for each mini-graph by checking if any value in this graph is used later; we
+  // shouldn't register nonTensor outputs for TensorRT segments.
+  for (auto& seg_block : segmented_blocks) {
+    for (auto& mini_graph_input : input_values) {
+      if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
+              seg_block.raw_inputs().end() &&
+          seg_block.contain_raw_value(mini_graph_input)) {
+        if (!isTensorOrTensorList(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT)
+          continue;
+        seg_block.registerOutput(mini_graph_input);
+      }
+    }
+    // if no output, then register the last node's output as current graph's output
+    if (seg_block.raw_outputs().empty()) {
+      // for Torch segments, register input as output
+      if (seg_block.target() == SegmentedBlock::kTorch) {
+        seg_block.registerOutput(seg_block.raw_inputs()[0]);
+      } else {
+        // for TensorRT segments, register last nonInput Tensor outputs
+        for (int i = seg_block.raw_nodes().size() - 1; i >= 0; --i) {
+          for (auto node_output : seg_block.raw_nodes()[i]->outputs()) {
+            if (isTensorOrTensorList(node_output))
+              seg_block.registerOutput(node_output);
+          }
+          if (!seg_block.raw_outputs().empty())
+            break;
+        }
+      }
     }
   }
-  construct_segments(pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size);
-  if (!pytorch_nodes.empty()) {
-    segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
-  }
+  // erase segments which still have no output
+  segmented_blocks.erase(
+      std::remove_if(
+          segmented_blocks.begin(),
+          segmented_blocks.end(),
+          [](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }),
+      segmented_blocks.end());
+
+  return;
 }
 
 std::vector<SegmentedBlock> Partition(
@@ -351,8 +157,7 @@ std::vector<SegmentedBlock> Partition(
     std::vector<conversion::InputRange>& input_ranges,
     const conversion::TorchFallback& fallback_info) {
   // segment lowering global graph into blocks
-  std::vector<SegmentedBlock> segmented_blocks;
-  segment_graph(g, fallback_info, segmented_blocks);
+  std::vector<SegmentedBlock> segmented_blocks = segment_graph(g, fallback_info);
 
   // resolve nonTensor inputs/outputs
   resolveNonTensorInputs(segmented_blocks, g);
@@ -370,7 +175,7 @@ std::vector<SegmentedBlock> Partition(
   // register every segment's input shape, and it's running output IValues
   for (auto& seg_block : segmented_blocks) {
     torch::jit::ConstantPooling(seg_block.g());
-    registerSegmentInOutIValues(seg_block, ivalues_maps);
+    getSegmentsOutputByRunning(seg_block, ivalues_maps);
   }
 
   return segmented_blocks;
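Note: after these hunks, Partition() reads as a fixed pipeline — segment_graph(), resolveNonTensorInputs(), registerSegmentsOutputs(), then a shape-recording pass over each block — and segment_graph() now returns its result instead of filling an out-parameter. A condensed sketch of a caller follows (the wrapper is hypothetical; the TorchFallback::enabled field is an assumption about core/conversion/conversion.h):

    // Hypothetical end-to-end caller of the reorganized Partition() (illustration only).
    #include "core/partitioning/partitioning.h"

    using namespace trtorch::core;

    std::vector<partitioning::SegmentedBlock> run_partitioning(
        std::shared_ptr<torch::jit::Graph> g,
        std::vector<conversion::InputRange>& input_ranges) {
      conversion::TorchFallback fallback_info;
      fallback_info.enabled = true; // assumed field: fallback must be requested explicitly
      fallback_info.min_block_size = 1;

      // segment -> resolve nonTensor in/outs -> register outputs -> record shapes
      return partitioning::Partition(g, input_ranges, fallback_info);
    }
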
diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h
index 62717cbbd1..fe43240ada 100644
--- a/core/partitioning/partitioning.h
+++ b/core/partitioning/partitioning.h
@@ -3,136 +3,15 @@
 #include <vector>
 
 #include "core/conversion/conversion.h"
+#include "core/conversion/evaluators/eval_util.h"
+#include "core/partitioning/SegmentedBlock.h"
+#include "core/util/prelude.h"
 #include "torch/csrc/jit/ir/ir.h"
 
 namespace trtorch {
 namespace core {
 namespace partitioning {
 
-torch::jit::Value* getOrAddInputForValue(
-    torch::jit::Value* old_value,
-    std::shared_ptr<torch::jit::Graph>& graph,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new);
-
-torch::jit::Node* cloneNode(
-    torch::jit::Node* node,
-    std::shared_ptr<torch::jit::Graph>& graph,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new);
-
-struct SegmentedBlock {
- public:
-  enum SegmentedBlockTarget {
-    kTorch,
-    kTensorRT,
-  };
-
-  SegmentedBlock() = default;
-
-  SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
-
-  SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes)
-      : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
-    for (auto& node : nodes) {
-      nodes_.push_back(node);
-      appendNode(node);
-    }
-    registerInputs();
-  }
-
-  SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
-
-  enum SegmentedBlockTarget target() {
-    return target_;
-  }
-
-  void appendNode(torch::jit::Node* n) {
-    cloneNode(n, g_, old_to_new_);
-  }
-
-  void registerInputs() {
-    for (auto& value : g_->inputs()) {
-      inputs_.push_back(old_to_new_[value]);
-    }
-  }
-
-  void registerOutput(torch::jit::Value* raw_output) {
-    outputs_.push_back(raw_output);
-    g_->registerOutput(old_to_new_[raw_output]);
-  }
-
-  torch::jit::Block* block() {
-    return g_->block();
-  }
-
-  c10::ArrayRef<torch::jit::Value*> inputs() {
-    return g_->inputs();
-  }
-
-  void eraseInput(size_t i) {
-    inputs_.erase(inputs_.begin() + i);
-    g_->eraseInput(i);
-  }
-
-  c10::ArrayRef<torch::jit::Value*> outputs() {
-    return g_->outputs();
-  }
-
-  void eraseOutput(size_t i) {
-    outputs_.erase(outputs_.begin() + i);
-    g_->eraseOutput(i);
-  }
-
-  const std::vector<torch::jit::Value*>& raw_inputs() const {
-    return inputs_;
-  }
-
-  const std::vector<torch::jit::Value*>& raw_outputs() const {
-    return outputs_;
-  }
-
-  const std::vector<torch::jit::Node*>& raw_nodes() const {
-    return nodes_;
-  }
-
-  bool contain_raw_value(torch::jit::Value* input) {
-    return old_to_new_.count(input);
-  }
-
-  torch::jit::graph_node_list nodes() {
-    return g_->nodes();
-  }
-
-  void register_inshape(std::vector<nvinfer1::Dims>& in_shape) {
-    in_shape_ = in_shape;
-  }
-
-  const std::vector<nvinfer1::Dims>& in_shape() const {
-    return in_shape_;
-  }
-
-  std::shared_ptr<torch::jit::Graph>& g() {
-    return g_;
-  }
-
-  void update_graph(std::shared_ptr<torch::jit::Graph> new_g) {
-    g_ = new_g;
-  }
-
-  void update_target(SegmentedBlockTarget new_target) {
-    target_ = new_target;
-  }
-
- private:
-  SegmentedBlockTarget target_;
-  std::vector<nvinfer1::Dims> in_shape_;
-  std::vector<torch::jit::Value*> inputs_;
-  std::vector<torch::jit::Value*> outputs_;
-  std::vector<torch::jit::Node*> nodes_;
-  std::shared_ptr<torch::jit::Graph> g_;
-  std::string trt_engine;
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
-};
-
 std::vector<SegmentedBlock> Partition(
     std::shared_ptr<torch::jit::Graph> g,
     std::vector<conversion::InputRange>& input_ranges,
diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp
new file mode 100644
index 0000000000..efb9fd6b6b
--- /dev/null
+++ b/core/partitioning/shape_analysis.cpp
@@ -0,0 +1,112 @@
+#include "shape_analysis.h"
+#include "torch/csrc/jit/api/module.h"
+
+namespace trtorch {
+namespace core {
+namespace partitioning {
+
+std::vector<torch::jit::IValue> generateRandomInputs(std::vector<conversion::InputRange>& input_ranges) {
+  // generate random inputs for running pytorch segments
+  std::vector<torch::jit::IValue> random_inputs;
+  for (auto& input_range : input_ranges) {
+    auto cur_shape = input_range.input_shape;
+    std::vector<int64_t> shape;
+    shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
+    auto in = at::randint(5, shape, {at::kCUDA});
+    random_inputs.push_back(in.clone());
+  }
+  return random_inputs;
+}
+
+c10::FunctionSchema getFunctionSchema(std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
+  std::vector<c10::Argument> args;
+  for (auto in : g->inputs()) {
+    args.push_back(c10::Argument(in->debugName(), in->type()));
+  }
+
+  std::vector<c10::Argument> returns;
+  for (auto out : g->outputs()) {
+    returns.push_back(c10::Argument(out->debugName(), out->type()));
+  }
+
+  return c10::FunctionSchema(method_name, method_name, args, returns);
+}
+
+void getSegmentsOutputByRunning(
+    SegmentedBlock& seg_block,
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
+  // create a module to run the graph
+  auto g = seg_block.g();
+  auto copy_g = g->copy();
+
+  // create tuple for multiple outputs
+  if (seg_block.raw_outputs().size() > 1) {
+    auto new_output_node = copy_g->appendNode(copy_g->createTuple(copy_g->outputs()));
+    for (int idx = copy_g->outputs().size() - 1; idx >= 0; --idx) {
+      copy_g->eraseOutput(idx);
+    }
+
+    copy_g->registerOutput(new_output_node->outputs()[0]);
+  }
+
+  torch::jit::script::Module cur_mod(c10::QualifiedName("module"));
+
+  auto self = copy_g->insertInput(0, "self_1");
+  self->setType(cur_mod.type());
+
+  auto cur_method = cur_mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), copy_g);
+  auto schema = getFunctionSchema(cur_method->name(), copy_g);
+  cur_mod.type()->addMethod(cur_method);
+  cur_method->setSchema(schema);
+
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+
+  // set input ivalues, now supports Tensor/Int to pass arguments between different segments
+  for (auto& input : seg_block.raw_inputs()) {
+    TRTORCH_CHECK(ivalues_maps.count(input), "Could not find mini graph input IValue " << input->debugName());
+    if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
+      jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor());
+    } else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) {
+      jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());
+    } else if (input->type()->isSubtypeOf(torch::jit::BoolType::get())) {
+      jit_inputs_ivalues.push_back(ivalues_maps[input].toBool());
+    } else if (input->type()->kind() == torch::jit::TypeKind::ListType) {
+      jit_inputs_ivalues.push_back(ivalues_maps[input].toList());
+    } else {
+      TRTORCH_CHECK(input->type()->kind() == torch::jit::TypeKind::TupleType, "Input for mini graph is not TupleType.");
+      jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple());
+    }
+  }
+
+  // run segments to get outputs for later segments' input shapes, and other arguments such as Int
+  std::vector<torch::jit::IValue> jit_results;
+  torch::jit::IValue jit_results_ivalues = cur_mod.forward(jit_inputs_ivalues);
+
+  if (jit_results_ivalues.isTuple()) {
+    auto results = jit_results_ivalues.toTuple()->elements();
+    for (auto r : results) {
+      jit_results.push_back(r);
+    }
+  } else {
+    jit_results.push_back(jit_results_ivalues);
+  }
+
+  size_t idx = 0;
+  for (auto& output : seg_block.raw_outputs()) {
+    ivalues_maps[output] = jit_results[idx++];
+  }
+
+  // set input shape for each segmented block so we will use it in the conversion process
+  std::vector<nvinfer1::Dims> input_shape;
+  for (auto& i : seg_block.raw_inputs()) {
+    if (ivalues_maps[i].isTensor()) {
+      input_shape.push_back(util::toDims(ivalues_maps[i].toTensor().sizes()));
+    }
+  }
+
+  seg_block.register_inshape(input_shape);
+}
+
+} // namespace partitioning
+} // namespace core
+} // namespace trtorch
diff --git a/core/partitioning/shape_analysis.h b/core/partitioning/shape_analysis.h
new file mode 100644
index 0000000000..8252573430
--- /dev/null
+++ b/core/partitioning/shape_analysis.h
@@ -0,0 +1,15 @@
+#include "SegmentedBlock.h"
+
+namespace trtorch {
+namespace core {
+namespace partitioning {
+
+std::vector<torch::jit::IValue> generateRandomInputs(std::vector<conversion::InputRange>& input_ranges);
+
+void getSegmentsOutputByRunning(
+    SegmentedBlock& seg_block,
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps);
+
+} // namespace partitioning
+} // namespace core
+} // namespace trtorch
\ No newline at end of file
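
Note: the shape_analysis pair infers shapes by execution rather than static analysis — generateRandomInputs() builds random CUDA tensors from each InputRange::input_shape, and getSegmentsOutputByRunning() wraps a segment in a throwaway TorchScript module, runs it, and stores the produced IValues so later segments can consume them. A condensed sketch of that loop follows (the seeding of ivalues_maps mirrors what Partition() does internally and is hypothetical here):

    // Hypothetical standalone use of the new shape_analysis helpers (illustration only).
    #include "core/partitioning/shape_analysis.h"

    using namespace trtorch::core;

    void record_shapes(
        std::vector<partitioning::SegmentedBlock>& blocks,
        std::shared_ptr<torch::jit::Graph> g,
        std::vector<conversion::InputRange>& input_ranges) {
      // one random CUDA tensor per InputRange, shaped by InputRange::input_shape
      std::vector<torch::jit::IValue> random_inputs = partitioning::generateRandomInputs(input_ranges);

      // seed the global graph inputs (assumes one InputRange per graph input)
      std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;
      for (size_t i = 0; i < g->inputs().size(); ++i) {
        ivalues_maps[g->inputs()[i]] = random_inputs[i];
      }

      // run the segments in order; each call records block.in_shape() for conversion
      // and publishes the segment's outputs for later segments to consume
      for (auto& block : blocks) {
        partitioning::getSegmentsOutputByRunning(block, ivalues_maps);
      }
    }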