Skip to content

Commit

Permalink
pnnx support dynamic slice indexes (#5299)
Browse files Browse the repository at this point in the history
* pnnx handle two operands add/sub/rsub variant

* fuse dynamic slice indexes, wip

* pnnx sliceindexes

* reset device may change non-dtype input numeric 5 to 6

* print inf as float

* preserve dtype for generation op

* pnnx convert torch.masked_select

* test masked_select

* test negative slice
  • Loading branch information
nihui authored Jan 25, 2024
1 parent ff17c17 commit 40958d3
Show file tree
Hide file tree
Showing 43 changed files with 1,864 additions and 799 deletions.
9 changes: 5 additions & 4 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_lgamma.cpp
pass_level2/torch_logsumexp.cpp
pass_level2/torch_lt.cpp
pass_level2/torch_masked_select.cpp
pass_level2/torch_matmul.cpp
pass_level2/torch_max.cpp
pass_level2/torch_mean.cpp
Expand Down Expand Up @@ -386,6 +387,8 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/convert_torch_tensor_split.cpp
pass_ncnn/convert_torch_unbind.cpp
pass_ncnn/convert_Tensor_select.cpp
pass_ncnn/convert_Tensor_slice.cpp
pass_ncnn/convert_Tensor_slice_copy.cpp
pass_ncnn/eliminate_output.cpp
pass_ncnn/expand_expression.cpp
pass_ncnn/fuse_convert_shufflechannel_slice.cpp
Expand Down Expand Up @@ -535,8 +538,6 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/Tensor_contiguous.cpp
pass_ncnn/Tensor_reshape.cpp
pass_ncnn/Tensor_repeat.cpp
pass_ncnn/Tensor_slice.cpp
pass_ncnn/Tensor_slice_copy.cpp
pass_ncnn/Tensor_view.cpp
pass_ncnn/torch_addmm.cpp
pass_ncnn/torch_amax.cpp
Expand Down Expand Up @@ -659,8 +660,8 @@ else()
target_link_libraries(pnnx PRIVATE ${TORCH_LIBRARIES} pthread dl)
endif()

#set_target_properties(pnnx PROPERTIES COMPILE_FLAGS -fsanitize=address)
#set_target_properties(pnnx PROPERTIES LINK_FLAGS -fsanitize=address)
# set_target_properties(pnnx PROPERTIES COMPILE_FLAGS -fsanitize=address)
# set_target_properties(pnnx PROPERTIES LINK_FLAGS -fsanitize=address)

if(APPLE)
set_target_properties(pnnx PROPERTIES INSTALL_RPATH "@executable_path/")
Expand Down
234 changes: 196 additions & 38 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,43 @@ std::string Parameter::encode_to_string(const Parameter& param)
return std::string();
}

bool Operator::has_param(const std::string& key) const
{
return params.find(key) != params.end();
}

bool Operator::has_attr(const std::string& key) const
{
return attrs.find(key) != attrs.end();
}

bool Operator::has_input(const std::string& key) const
{
return std::find(inputnames.begin(), inputnames.end(), key) != inputnames.end();
}

Operand* Operator::named_input(const std::string& key)
{
for (size_t i = 0; i < inputnames.size(); i++)
{
if (inputnames[i] == key)
return inputs[i];
}

return 0;
}

const Operand* Operator::named_input(const std::string& key) const
{
for (size_t i = 0; i < inputnames.size(); i++)
{
if (inputnames[i] == key)
return inputs[i];
}

return 0;
}

Graph::Graph()
{
}
Expand Down Expand Up @@ -1339,7 +1376,7 @@ static std::string expand_expression(const Operator* op)
if (t == "floor") unaryop = "torch.floor";
if (t == "log") unaryop = "torch.log";
if (t == "log10") unaryop = "torch.log10";
if (t == "neg") unaryop = "torch.neg";
if (t == "neg") unaryop = "-";
if (t == "reciprocal") unaryop = "torch.reciprocal";
if (t == "round") unaryop = "torch.round";
if (t == "rsqrt") unaryop = "torch.rsqrt";
Expand Down Expand Up @@ -1472,13 +1509,13 @@ static std::string expand_expression(const Operator* op)

static std::string make_slice_expression(const Operator* op)
{
for (size_t j = 0; j < op->inputnames.size(); j++)
{
fprintf(stderr, "make_slice_expression %s %s\n", op->inputnames[j].c_str(), op->inputs[j]->name.c_str());
}
// for (size_t j = 0; j < op->inputnames.size(); j++)
// {
// fprintf(stderr, "make_slice_expression %s %s\n", op->inputnames[j].c_str(), op->inputs[j]->name.c_str());
// }

std::vector<int> dims;
if (op->params.find("dims") != op->params.end())
if (op->has_param("dims"))
{
dims = op->params.at("dims").ai;
}
Expand All @@ -1487,66 +1524,158 @@ static std::string make_slice_expression(const Operator* op)
dims.push_back(op->params.at("dim").i);
}

std::string r;
std::string pr;
std::string nr;

int last_dim = -1;
const int ndim = (int)dims.size();
for (int i = 0; i < ndim; i++)
{
int dim = dims[i];
std::string& r = dim < 0 ? nr : pr;

for (int j = last_dim + 1; j < dim; j++)
{
r += ":,";
}
last_dim = dim;

if (op->params.find("starts") != op->params.end())
bool is_select = false;
if (op->has_param("select"))
{
int select = op->params.at("select").i;
if (select != INT_MAX)
{
r += std::to_string(select);
is_select = true;
}
}
if (op->has_param("selects"))
{
std::vector<int> selects = op->params.at("selects").ai;
int select = selects[i];
if (select != INT_MAX)
{
r += std::to_string(select);
is_select = true;
}
}
if (op->has_input("select"))
{
r += std::string("v_") + sanitize_identifier(op->named_input("select")->name);
is_select = true;
}
if (op->has_input("selects"))
{
// must be pnnx.SliceIndexes
const Operator* op_sliceindexes = op->named_input("selects")->producer;
const std::string& index = op_sliceindexes->params.at("indexes").as[i];
if (index[0] == '@')
{
int selecti = std::stoi(index.substr(1));
r += std::string("v_") + sanitize_identifier(op_sliceindexes->inputs[selecti]->name);
is_select = true;
}
else
{
int select = std::stoi(index);
if (select != INT_MAX)
{
r += std::to_string(select);
is_select = true;
}
}
}

if (is_select)
{
if (i + 1 != ndim)
r += ',';
continue;
}

if (op->has_param("start"))
{
int start = op->params.at("start").i;
if (start != 0)
r += std::to_string(start);
}
else if (op->has_param("starts"))
{
std::vector<int> starts = op->params.at("starts").ai;
int start = starts[i];

if (start != 0)
r += std::to_string(start);
}
else
else if (op->has_input("start"))
{
r += std::string("v_") + sanitize_identifier(op->named_input("start")->name);
}
else // if (op->has_input("starts"))
{
fprintf(stderr, "find start\n");
// find start
for (size_t j = 0; j < op->inputnames.size(); j++)
// must be pnnx.SliceIndexes
const Operator* op_sliceindexes = op->named_input("starts")->producer;
const std::string& index = op_sliceindexes->params.at("indexes").as[i];
if (index[0] == '@')
{
if (op->inputnames[j] == "start")
{
r += std::string("v_") + sanitize_identifier(op->inputs[j]->name);

fprintf(stderr, "find start %s\n", op->inputs[j]->name.c_str());
break;
}
int starti = std::stoi(index.substr(1));
r += std::string("v_") + sanitize_identifier(op_sliceindexes->inputs[starti]->name);
}
else
{
int start = std::stoi(index);
if (start != 0)
r += std::to_string(start);
}
}

r += ':';

if (op->params.find("ends") != op->params.end())
if (op->has_param("end"))
{
int end = op->params.at("end").i;
if (end != INT_MAX)
r += std::to_string(end);
}
else if (op->has_param("ends"))
{
std::vector<int> ends = op->params.at("ends").ai;
int end = ends[i];
if (end != INT_MAX)
r += std::to_string(end);
}
else
else if (op->has_input("end"))
{
r += std::string("v_") + sanitize_identifier(op->named_input("end")->name);
}
else // if (op->has_input("ends"))
{
// find end
for (size_t j = 0; j < op->inputnames.size(); j++)
// must be pnnx.SliceIndexes
const Operator* op_sliceindexes = op->named_input("ends")->producer;
const std::string& index = op_sliceindexes->params.at("indexes").as[i];
if (index[0] == '@')
{
if (op->inputnames[j] == "end")
{
r += std::string("v_") + sanitize_identifier(op->inputs[j]->name);
break;
}
int endi = std::stoi(index.substr(1));
r += std::string("v_") + sanitize_identifier(op_sliceindexes->inputs[endi]->name);
}
else
{
int end = std::stoi(index);
if (end != INT_MAX)
r += std::to_string(end);
}
}

if (op->params.find("steps") != op->params.end())
if (op->has_param("step"))
{
int step = op->params.at("step").i;
if (step != 1)
{
r += ':';
r += std::to_string(step);
}
}
else if (op->has_param("steps"))
{
std::vector<int> steps = op->params.at("steps").ai;
int step = steps[i];
Expand All @@ -1556,16 +1685,29 @@ static std::string make_slice_expression(const Operator* op)
r += std::to_string(step);
}
}
else
else if (op->has_input("step"))
{
r += ':';
r += std::string("v_") + sanitize_identifier(op->named_input("step")->name);
}
else // if (op->has_input("steps"))
{
// find step
for (size_t j = 0; j < op->inputnames.size(); j++)
// must be pnnx.SliceIndexes
const Operator* op_sliceindexes = op->named_input("steps")->producer;
const std::string& index = op_sliceindexes->params.at("indexes").as[i];
if (index[0] == '@')
{
if (op->inputnames[j] == "step")
int stepi = std::stoi(index.substr(1));
r += ':';
r += std::string("v_") + sanitize_identifier(op_sliceindexes->inputs[stepi]->name);
}
else
{
int step = std::stoi(index);
if (step != 1)
{
r += ':';
r += std::string("v_") + sanitize_identifier(op->inputs[j]->name);
break;
r += std::to_string(step);
}
}
}
Expand All @@ -1574,7 +1716,13 @@ static std::string make_slice_expression(const Operator* op)
r += ',';
}

return r;
if (!pr.empty() && !nr.empty())
return pr + "...," + nr;

if (pr.empty() && !nr.empty())
return std::string("...,") + nr;

return pr + nr;
}

static std::string make_index_expression(const Operator* op)
Expand Down Expand Up @@ -1932,6 +2080,9 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
if (op->type == "pnnx.Input" || op->type == "pnnx.Output")
continue;

if (op->type == "pnnx.SliceIndexes")
continue;

fprintf(pyfp, " ");

if (op->type == "pnnx.Expression")
Expand Down Expand Up @@ -2415,7 +2566,14 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
}
else
{
fprintf(pyfp, "\'%s\'", param.s.c_str());
if (param.s == "inf" || param.s == "-inf")
{
fprintf(pyfp, "float(\'%s\')", param.s.c_str());
}
else
{
fprintf(pyfp, "\'%s\'", param.s.c_str());
}
}
}
if (param.type == 5)
Expand Down
6 changes: 6 additions & 0 deletions tools/pnnx/src/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ class Operand
class Operator
{
public:
bool has_param(const std::string& key) const;
bool has_attr(const std::string& key) const;
bool has_input(const std::string& key) const;
Operand* named_input(const std::string& key);
const Operand* named_input(const std::string& key) const;

std::vector<Operand*> inputs;
std::vector<Operand*> outputs;

Expand Down
4 changes: 2 additions & 2 deletions tools/pnnx/src/pass_level0.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Grap
{
inline_block(g, module_operators);

constant_unpooling(g);

reset_device(g, device);

flatten_input(g);

constant_unpooling(g);

if (!input_tensors.empty())
{
shape_inference(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants, foldable_constants_zippath);
Expand Down
Loading

0 comments on commit 40958d3

Please sign in to comment.