Skip to content

Commit

Permalink
Support If/Loop block
Browse files Browse the repository at this point in the history
  • Loading branch information
sen.li committed Jun 27, 2024
1 parent 4f310c4 commit 2841063
Show file tree
Hide file tree
Showing 21 changed files with 1,503 additions and 1,064 deletions.
5 changes: 4 additions & 1 deletion tools/pnnx/Releasenotes
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,7 @@ dev.1.0.21.20240619
1. Support export sub_model

dev.1.0.22.20240620
1. Support load input tensor to export
1. Support load input tensor to export

dev.1.0.23.20240627
1. Support If/Loop block
12 changes: 8 additions & 4 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,11 @@ set(pnnx_pass_level6_SRCS
pass_level6/trans_Stack2Unsqueeze.cpp
pass_level6/trans_ReshapeAs2Reshape.cpp
pass_level6/trans_TensorTypeAs2TensorTo.cpp
pass_level6/fold_Loop.cpp
)

set(pnnx_pass_sub_model_SRCS
pass_sub_model/fold_Loop.cpp
pass_sub_model/fold_If.cpp
)

# set(pnnx_pass_ncnn_SRCS
Expand Down Expand Up @@ -603,7 +607,7 @@ set(torch2pnnx_SRCS

pass_level0.cpp
pass_level1.cpp

# pass_level1_class.cpp
${pnnx_pass_level0_SRCS}
${pnnx_pass_level1_SRCS}

Expand Down Expand Up @@ -691,13 +695,13 @@ set(pnnx_SRCS
pass_level4.cpp
pass_level5.cpp
pass_level6.cpp

pass_sub_model.cpp
${pnnx_pass_level2_SRCS}
${pnnx_pass_level3_SRCS}
${pnnx_pass_level4_SRCS}
${pnnx_pass_level5_SRCS}
${pnnx_pass_level6_SRCS}

${pnnx_pass_sub_model_SRCS}
# pass_ncnn.cpp
# save_ncnn.cpp
# ${pnnx_pass_ncnn_SRCS}
Expand Down
241 changes: 233 additions & 8 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2891,16 +2891,19 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
fprintf(pyfp, " self.infer_flag = infer_flag\n");
for (const Operator* op : ops)
{
if(op->type == "pnnx.Loop")
if(op->type == "pnnx.Loop" || op->type == "pnnx.If")
{
std::string op_name = op->name;

std::string subModelBinPath = save_dir + "/" + op_name + ".pnnx.bin";
std::string subModelInferPath = save_dir + "/" + op_name + "_pnnx_infer.py";
fprintf(pyfp, " %s = load_module('%s')\n", (op_name + "_Mod").c_str(), subModelInferPath.c_str());
fprintf(pyfp, " %s = getattr(%s, 'Model')\n", (op_name + "_Cls").c_str(), (op_name + "_Mod").c_str());
fprintf(pyfp, " %s = %s('%s', True)\n", ("self." + op_name + "_Obj").c_str(), (op_name + "_Cls").c_str(), subModelBinPath.c_str());
fprintf(pyfp, " %s.eval()\n", ("self." + op_name + "_Obj").c_str());
std::vector<std::string> block_names = op->params.at("block_names").as;
for(auto op_name: block_names)
{
std::string subModelBinPath = save_dir + "/" + op_name + ".pnnx.bin";
std::string subModelInferPath = save_dir + "/" + op_name + "_pnnx_infer.py";
fprintf(pyfp, " %s = load_module('%s')\n", (op_name + "_Mod").c_str(), subModelInferPath.c_str());
fprintf(pyfp, " %s = getattr(%s, 'Model')\n", (op_name + "_Cls").c_str(), (op_name + "_Mod").c_str());
fprintf(pyfp, " %s = %s('%s', True)\n", ("self." + op_name + "_Obj").c_str(), (op_name + "_Cls").c_str(), subModelBinPath.c_str());
fprintf(pyfp, " %s.eval()\n", ("self." + op_name + "_Obj").c_str());
}
continue;
}
if (op->type.substr(0, 3) != "nn." && op->type.substr(0, 16) != "torchvision.ops.")
Expand Down Expand Up @@ -3270,6 +3273,7 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
{
std::string condition_expr = op->params.at("condition").s;
int iter_num = op->params.at("iter_num").i;
// std::vector<std::string> block_names = op->params.at("block_names").as;
std::string op_name = op->name;
std::vector<Operand*> inputs = op->inputs;
std::vector<Operand*> outputs = op->outputs;
Expand Down Expand Up @@ -3306,6 +3310,64 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
continue;
}

if(op->type == "pnnx.If")
{
std::string op_name = op->name;
std::vector<Operand*> inputs = op->inputs;
std::vector<Operand*> outputs = op->outputs;
std::vector<std::string> block_names = op->params.at("block_names").as;
std::unordered_map<std::string, std::string> block_input_indexes_map;
for(auto block_name: block_names)
{
std::string block_input_indexes_name = block_name + "_input_indexes";
std::vector<int> block_input_indexes = op->params.at(block_input_indexes_name).ai;
std::string input_list = "";
int index = 0;
int index_num = block_input_indexes.size();
for(auto input_index: block_input_indexes)
{
std::string cur_input_name = sanitize_identifier(op->inputs[input_index]->name);
input_list = input_list + "v_" + cur_input_name;
if (index + 1 != index_num)
{
input_list = input_list + ", ";
}
}
block_input_indexes_map[block_name] = input_list;
}

std::string output_list = "";
std::string real_input_list = "";
for(int index = 0; index < outputs.size(); index++)
{
std::string cur_output_name = sanitize_identifier(outputs[index]->name);
std::string cur_real_input_name = sanitize_identifier(inputs[index + 1]->name);
output_list = output_list + "v_" + cur_output_name;
real_input_list = real_input_list + "v_" + cur_real_input_name;
if (index + 1 != outputs.size())
{
output_list = output_list + ", ";
real_input_list = real_input_list + ", ";
}

}
std::string condition = "v_" + sanitize_identifier(op->inputs[0]->name);
fprintf(pyfp, "if(%s):\n", condition.c_str());
fprintf(pyfp, " %s = %s(%s)\n", output_list.c_str(), ("self." + block_names[0] + "_Obj").c_str(), block_input_indexes_map[block_names[0]].c_str());
if(block_names.size() == 2)
{
fprintf(pyfp, " else:\n");
fprintf(pyfp, " %s = %s(%s)\n", output_list.c_str(), ("self." + block_names[1] + "_Obj").c_str(), block_input_indexes_map[block_names[1]].c_str());
}
else
{
fprintf(pyfp, " else:\n");
fprintf(pyfp, " %s = %s\n", output_list.c_str(), real_input_list.c_str());
}
continue;
}


if (op->type == "pnnx.Expression")
{
// expr
Expand Down Expand Up @@ -4444,6 +4506,34 @@ Operator* Graph::new_operator(const std::string& type, const std::string& name)
return op;
}

Operator* Graph::new_constant_operator(const std::string& type, const std::string& name)
{
// get last input index
int last_input_index = -1;
for(auto op: this->ops)
{
if(op->type == "pnnx.Input")
{
last_input_index++;
}
}

if(last_input_index == -1)
{
Operator* op = new Operator;
op->type = type;
op->name = name;
ops.push_back(op);
return op;
}
else
{
std::string last_input_op_name = "pnnx_input_" + std::to_string(last_input_index);
Operator* last_input_op = this->get_operator(last_input_op_name);
return this->new_operator_after(type, name, last_input_op);
}
}

Operator* Graph::new_operator_before(const std::string& type, const std::string& name, const Operator* cur)
{
Operator* op = new Operator;
Expand Down Expand Up @@ -4781,4 +4871,139 @@ int Graph::extract_sub_graph(const std::vector<std::string>& start_nodes, const
return 1;
}


std::string MainGraph::get_pnnx_graph_name()
{
return name;
}
void MainGraph::create_main_graph(std::string& name)
{
this->name = name;
this->main_graph = std::make_shared<Graph>();
}

std::shared_ptr<pnnx::Graph> MainGraph::get_main_graph()
{
return this->main_graph;
}

void MainGraph::insert_sub_graph(std::string& name, std::shared_ptr<pnnx::MainGraph>& sub_graph, Operator* op, int init_input_num)
{
this->sub_graph_map[name] = sub_graph;
std::vector<int> init_index = {};
for(auto i = 0; i < init_input_num; i++)
{
init_index.push_back(i);
}
std::unordered_map<std::string, std::vector<int>> op_graph_input_indexes = {{name, init_index}};
op_2_graph[op->name] = op_graph_input_indexes;
}

void MainGraph::set_base_graph(std::shared_ptr<pnnx::MainGraph>& base_graph)
{
this->base_graph = base_graph;
}

std::shared_ptr<pnnx::MainGraph> MainGraph::get_base_graph()
{
return this->base_graph;
}

std::shared_ptr<pnnx::MainGraph> MainGraph::get_sub_graph(std::string& name)
{
return this->sub_graph_map[name];
}
Operator* MainGraph::set_sub_graph_new_input(const std::string& sub_graph_name, const std::string& operand_name, Operand* r1)
{
auto sub_graph = this->sub_graph_map[sub_graph_name];
auto sub_main_graph = sub_graph->get_main_graph();
// create new input
int input_index = 0;
for(auto op: sub_main_graph->ops)
{
if(op->type == "pnnx.Input")
{
input_index++;
}
}
Operator* new_input_op;
std::string new_input_op_name = "pnnx_input_" + std::to_string(input_index);
if(input_index == 0)
{
new_input_op = sub_main_graph->new_operator("pnnx.Input", new_input_op_name);
}
else
{
std::string last_input_op_name = "pnnx_input_" + std::to_string(input_index-1);
Operator* last_input_op = sub_main_graph->get_operator(last_input_op_name);
new_input_op = sub_main_graph->new_operator_after("pnnx.Input", new_input_op_name, last_input_op);

}
Operand* r2 = sub_main_graph->new_operand(operand_name);
r2->producer = new_input_op;
r2->params = r1->params;
r2->type = r1->type;
r2->shape = r1->shape;
new_input_op->outputs.push_back(r2);
return new_input_op;
}

void MainGraph::set_op_new_input(std::string& sub_graph_name2, Operator* new_input_op)
{
for (auto& it = op_2_graph.begin(); it != op_2_graph.end(); ++it)
{
auto& op_2_graph_input_list = it->second;
for (auto& it2 = op_2_graph_input_list.begin(); it2 != op_2_graph_input_list.end(); ++it2)
{
if(it2->first == sub_graph_name2)
{
// get op

Operator* src_op = main_graph->get_operator(it->first);

new_input_op->outputs[0]->consumers.push_back(src_op);
src_op->inputs.push_back(new_input_op->outputs[0]);
int last_src_input_index = src_op->inputs.size() -1;
it2->second.push_back(last_src_input_index);
}
}
}
}

Operator* MainGraph::get_base_op(std::string& sub_op_block_name)
{
auto base_main_graph = base_graph->get_main_graph();
for(auto op: base_main_graph->ops)
{
if(op->name == sub_op_block_name)
{
return op;
}
}
return 0;
}

MainGraph::MainGraph()
{
}


MainGraph::~MainGraph()
{

}

MainGraph::MainGraph(const MainGraph& /*rhs*/)
{
}

MainGraph& MainGraph::operator=(const MainGraph& /*rhs*/)
{
return *this;
}





} // namespace pnnx
38 changes: 36 additions & 2 deletions tools/pnnx/src/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#ifndef PNNX_IR_H
#define PNNX_IR_H

#include <memory>
#include <limits.h>
#include <complex>
#include <initializer_list>
Expand All @@ -23,7 +23,8 @@
#include <set>
#include <string>
#include <vector>

#include <unordered_map>
#include <queue>
#if BUILD_TORCH2PNNX
namespace torch {
namespace jit {
Expand Down Expand Up @@ -331,6 +332,8 @@ class Graph

Operator* new_operator(const std::string& type, const std::string& name);

Operator* new_constant_operator(const std::string& type, const std::string& name);

Operator* new_operator_before(const std::string& type, const std::string& name, const Operator* cur);

Operator* new_operator_after(const std::string& type, const std::string& name, const Operator* cur);
Expand Down Expand Up @@ -361,6 +364,37 @@ class Graph
Graph& operator=(const Graph& rhs);
};

class MainGraph
{
public:
MainGraph();
~MainGraph();
std::string get_pnnx_graph_name();
void create_main_graph(std::string& name);
std::shared_ptr<pnnx::Graph> get_main_graph();
void insert_sub_graph(std::string& name, std::shared_ptr<pnnx::MainGraph>& sub_graph, Operator* op, int init_input_num = 0);
void set_base_graph(std::shared_ptr<pnnx::MainGraph>& base_graph);
std::shared_ptr<pnnx::MainGraph> get_base_graph();
std::shared_ptr<pnnx::MainGraph> get_sub_graph(std::string& name);
Operator* set_sub_graph_new_input(const std::string& sub_graph_name, const std::string& operand_name, Operand* r1);
void set_op_new_input(std::string& sub_graph_name2, Operator* new_input_op);
Operator* get_base_op(std::string& sub_op_block_name);
std::unordered_map<std::string, std::unordered_map<std::string, std::vector<int>>> op_2_graph;
std::unordered_map<std::string, std::shared_ptr<pnnx::MainGraph>> sub_graph_map;
std::string name;
std::vector<std::string> effective_sub_model_name;
private:

std::shared_ptr<pnnx::Graph> main_graph;


std::shared_ptr<pnnx::MainGraph> base_graph;

MainGraph(const MainGraph& rhs);
MainGraph& operator=(const MainGraph& rhs);
};


} // namespace pnnx

#endif // PNNX_IR_H
Loading

0 comments on commit 2841063

Please sign in to comment.