Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Jul 20, 2023
1 parent 9f29a17 commit 8012e1d
Show file tree
Hide file tree
Showing 16 changed files with 508 additions and 8 deletions.
5 changes: 5 additions & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ set(pnnx_pass_level2_SRCS
pass_level2/Tensor_copy.cpp
pass_level2/Tensor_expand.cpp
pass_level2/Tensor_expand_as.cpp
pass_level2/Tensor_fill.cpp
pass_level2/Tensor_index.cpp
pass_level2/Tensor_index_put.cpp
pass_level2/Tensor_masked_fill.cpp
pass_level2/Tensor_new_empty.cpp
pass_level2/Tensor_new_ones.cpp
Expand All @@ -189,6 +191,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/Tensor_reshape.cpp
pass_level2/Tensor_select.cpp
pass_level2/Tensor_slice.cpp
pass_level2/Tensor_type_as.cpp
pass_level2/Tensor_view.cpp
pass_level2/torch_addmm.cpp
pass_level2/torch_amax.cpp
Expand Down Expand Up @@ -252,6 +255,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_sum.cpp
pass_level2/torch_permute.cpp
pass_level2/torch_tensor_split.cpp
pass_level2/torch_topk.cpp
pass_level2/torch_transpose.cpp
pass_level2/torch_unbind.cpp
pass_level2/torch_unsqueeze.cpp
Expand Down Expand Up @@ -320,6 +324,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/eliminate_noop_slice.cpp
pass_level5/eliminate_noop_view_reshape.cpp
pass_level5/eliminate_reshape_shape_expression.cpp
pass_level5/eliminate_type_as.cpp
pass_level5/eval_expression.cpp
pass_level5/fold_constants.cpp
pass_level5/fuse_adjacent_reshape.cpp
Expand Down
29 changes: 26 additions & 3 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,7 @@ static std::string expand_expression(const Operator* op)
{
std::string binaryop;
if (t == "atan2") binaryop = "torch.atan2";
if (t == "fmod") binaryop = "torch.fmod";
if (t == "pow") binaryop = "torch.pow";

std::string a = exprstack.top();
Expand All @@ -1311,14 +1312,15 @@ static std::string expand_expression(const Operator* op)
std::string r = binaryop + "(" + a + ", " + b + ")";
exprstack.push(r);
}
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift")
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "fmod" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift")
{
std::string binaryop;
if (t == "add") binaryop = "+";
if (t == "sub") binaryop = "-";
if (t == "mul") binaryop = "*";
if (t == "div") binaryop = "/";
if (t == "floor_divide") binaryop = "//";
if (t == "fmod") binaryop = "%";
if (t == "and") binaryop = "&";
if (t == "or") binaryop = "|";
if (t == "xor") binaryop = "^";
Expand Down Expand Up @@ -2154,9 +2156,30 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
{
fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());

for (size_t i = 1; i < op->inputs.size(); i++)
if (op->inputnames.size() == op->inputs.size())
{
fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
for (size_t i = 1; i < op->inputs.size(); i++)
{
if (!op->inputnames[i].empty())
continue;

fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
}

for (size_t i = 1; i < op->inputs.size(); i++)
{
if (op->inputnames[i].empty())
continue;

fprintf(pyfp, "%s=v_%s, ", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str());
}
}
else
{
for (size_t i = 1; i < op->inputs.size(); i++)
{
fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
}
}
}
else
Expand Down
3 changes: 2 additions & 1 deletion tools/pnnx/src/pass_level0/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ static bool value_link_input(const torch::jit::Value* v, const std::vector<torch
|| optype == "aten::empty_like"
|| optype == "aten::full_like"
|| optype == "aten::ones_like"
|| optype == "aten::zeros_like")
|| optype == "aten::zeros_like"
|| optype == "aten::_shape_as_tensor")
return false;
}

Expand Down
4 changes: 2 additions & 2 deletions tools/pnnx/src/pass_level1/nn_Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ class Linear : public FuseModulePass

op->params["in_features"] = weight.size(1);
op->params["out_features"] = weight.size(0);
op->params["bias"] = mod.hasattr("bias");
op->params["bias"] = mod.hasattr("bias") && mod.attr("bias").isTensor();

op->attrs["weight"] = weight;
if (mod.hasattr("bias"))
if (mod.hasattr("bias") && mod.attr("bias").isTensor())
{
op->attrs["bias"] = mod.attr("bias").toTensor();
}
Expand Down
41 changes: 41 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_fill.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_fill : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 value
aten::fill op_0 2 1 input value out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.fill";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_fill, 20)

} // namespace pnnx
43 changes: 43 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_index_put.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_index_put : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 indices
pnnx.Input input_2 0 1 values
prim::Constant op_0 0 1 accumulate value=%accumulate
aten::index_put op_1 4 1 input indices values accumulate out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.index_put";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_index_put, 20)

} // namespace pnnx
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level2/Tensor_masked_fill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Tensor_masked_fill : public GraphRewriterPass
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 mask
pnnx.Input input_2 0 1 value
aten::masked_fill op_1 3 1 input mask value out
aten::masked_fill op_0 3 1 input mask value out
pnnx.Output output 1 0 out
)PNNXIR";
}
Expand Down
41 changes: 41 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_type_as.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_type_as : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 other
aten::type_as op_0 2 1 input other out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.type_as";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_type_as, 20)

} // namespace pnnx
44 changes: 44 additions & 0 deletions tools/pnnx/src/pass_level2/torch_topk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_topk : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 7
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 k
pnnx.Input input_2 0 1 dim
pnnx.Input input_3 0 1 largest
pnnx.Input input_4 0 1 sorted
aten::topk op_0 5 2 input k dim largest sorted values indices
pnnx.Output output 2 0 values indices
)PNNXIR";
}

const char* type_str() const
{
return "torch.topk";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_topk, 20)

} // namespace pnnx
3 changes: 3 additions & 0 deletions tools/pnnx/src/pass_level3/fuse_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ static bool operand_maybe_tensor(const Operand* operand)
if (op->type == "aten::atan2"
|| op->type == "aten::div"
|| op->type == "aten::floor_divide"
|| op->type == "aten::fmod"
|| op->type == "aten::mul"
|| op->type == "aten::pow"
|| op->type == "aten::remainder")
Expand Down Expand Up @@ -404,6 +405,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
else if (op->type == "aten::atan2"
|| op->type == "aten::div"
|| op->type == "aten::floor_divide"
|| op->type == "aten::fmod"
|| op->type == "aten::mul"
|| op->type == "aten::pow"
|| op->type == "aten::remainder")
Expand Down Expand Up @@ -562,6 +564,7 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constan
|| op->type == "aten::exp"
|| op->type == "aten::floor"
|| op->type == "aten::floor_divide"
|| op->type == "aten::fmod"
|| op->type == "aten::log"
|| op->type == "aten::log10"
|| op->type == "aten::mul"
Expand Down
2 changes: 2 additions & 0 deletions tools/pnnx/src/pass_level5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "pass_level5/eliminate_noop_slice.h"
#include "pass_level5/eliminate_noop_view_reshape.h"
#include "pass_level5/eliminate_reshape_shape_expression.h"
#include "pass_level5/eliminate_type_as.h"
#include "pass_level5/eval_expression.h"
#include "pass_level5/fuse_adjacent_reshape.h"
#include "pass_level5/fuse_channel_shuffle.h"
Expand Down Expand Up @@ -112,6 +113,7 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons
eliminate_noop_cat(g);

eliminate_dropout(g);
eliminate_type_as(g);

eliminate_noop_upsample(g);

Expand Down
Loading

0 comments on commit 8012e1d

Please sign in to comment.