
【PaddlePaddle Hackathon 77】Add squeeze op (#874)
Adds the squeeze operator; Paddle model support can be added in a follow-up.

For details, see the Hackathon task "Add a squeeze operator to the neural network compiler CINN". As requested in the last livestream, the main implementation is placed in the contrib folder, and unit tests are added.
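
For reference, the unit tests added below pin down the expected shape rule: each axis listed in axes must refer to a dimension of size 1 and is removed, and an empty axes list removes every size-1 dimension. A minimal illustrative sketch of that rule in isolation (the helper name SqueezeShape is hypothetical; this is not the CINN shape-inference code):

#include <set>
#include <vector>

// Drop the listed size-1 dimensions, or every size-1 dimension when axes is empty.
std::vector<int> SqueezeShape(const std::vector<int>& shape, const std::vector<int>& axes) {
  std::set<int> drop(axes.begin(), axes.end());
  std::vector<int> out;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    bool remove = axes.empty() ? (shape[i] == 1) : (drop.count(i) > 0 && shape[i] == 1);
    if (!remove) out.push_back(shape[i]);
  }
  return out;
}

// SqueezeShape({4, 1, 7, 1}, {1})    -> {4, 7, 1}
// SqueezeShape({4, 1, 7, 1}, {1, 3}) -> {4, 7}
// SqueezeShape({4, 1, 7, 1}, {})     -> {4, 7}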
zrr1999 authored Sep 1, 2022
1 parent af52b7a commit 9e1a1b6
Showing 10 changed files with 556 additions and 1 deletion.
8 changes: 8 additions & 0 deletions cinn/frontend/net_builder.cc
@@ -95,6 +95,14 @@ Variable NetBuilder::Cast(const Variable& operand, const std::string& dtype) {
  return instr.GetOutput(0);
}

Variable NetBuilder::Squeeze(const Variable& operand, const std::vector<int>& axes) {
  Instruction instr("squeeze", {operand});
  instr.SetAttr("axes", axes);
  InferShape(instr);
  AppendInstruction(instr);
  return instr.GetOutput(0);
}

Variable NetBuilder::Conv2d(const Variable& a,
                            const Variable& b,
                            const std::vector<int>& strides,
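Mirroring the unit tests in net_builder_test.cc further down, the new builder method is invoked like the existing NetBuilder ops; a trimmed usage sketch (types from cinn/frontend, shapes taken from the tests):

NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Float(32), {4, 1, 7, 1}, "In");
Variable output = builder.Squeeze(input, {1, 3});  // squeezed shape: {4, 7}
auto program = builder.Build();
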
5 changes: 5 additions & 0 deletions cinn/frontend/net_builder.h
@@ -97,6 +97,11 @@ class NetBuilder : public BaseBuilder {
   */
  Variable Cast(const Variable& operand, const std::string& dtype);

  /**
   * Squeeze the operand Variable along the given axes; an empty axes list
   * removes every dimension of size 1.
   */
  Variable Squeeze(const Variable& operand, const std::vector<int>& axes);

  /**
   * The convolution2D layer calculates the output based on the input, filter
   * and strides, paddings, dilations, groups parameters.
160 changes: 160 additions & 0 deletions cinn/frontend/net_builder_test.cc
@@ -293,5 +293,165 @@ TEST(net_build, program_execute_cast) {
  }
}

TEST(net_build, program_execute_squeeze_case1) {
  const int B = 4;
  const int C = 1;
  const int H = 7;
  const int W = 1;

  NetBuilder builder("net_builder");
  Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "In");
  Variable output = builder.Squeeze(input, {1});
  auto program = builder.Build();

  Target target = common::DefaultHostTarget();

  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
  auto scope = BuildScope(target, graph);
  hlir::framework::GraphCompiler gc(target, scope, graph);
  auto runtime_program = gc.Build();

  scope->Var<hlir::framework::Tensor>(std::string(input.id()));
  scope->Var<hlir::framework::Tensor>(std::string(output->id));

  auto input_tensor = scope->GetTensor(std::string(input.id()));
  SetRandData<float>(input_tensor, target);
  float* input_data = input_tensor->mutable_data<float>(target);

  runtime_program->Execute();

  auto output_tensor = scope->GetTensor(std::string(output->id));
  const std::vector<int>& output_shape = output_tensor->shape().data();
  EXPECT_EQ(output_shape.size(), 3UL);
  EXPECT_EQ(output_shape[0], B);
  EXPECT_EQ(output_shape[1], H);
  EXPECT_EQ(output_shape[2], W);

  float* output_data = output_tensor->mutable_data<float>(target);
  VLOG(6) << "Visualize output_data";
  for (int b = 0; b < B; ++b) {
    for (int c = 0; c < C; ++c) {
      VLOG(6) << "b = " << b << ", c = " << c;
      for (int h = 0; h < H; ++h) {
        std::string line;
        for (int w = 0; w < W; ++w) {
          int index = w + W * (h + H * (c + C * b));
          float in_data = input_data[index];
          float out_data = output_data[index];
          line += (std::to_string(out_data) + ", ");
          EXPECT_EQ(in_data, out_data);
        }
        VLOG(6) << line;
      }
    }
  }
}

TEST(net_build, program_execute_squeeze_case2) {
  const int B = 4;
  const int C = 1;
  const int H = 7;
  const int W = 1;

  NetBuilder builder("net_builder");
  Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "In");
  Variable output = builder.Squeeze(input, {1, 3});
  auto program = builder.Build();

  Target target = common::DefaultHostTarget();

  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
  auto scope = BuildScope(target, graph);
  hlir::framework::GraphCompiler gc(target, scope, graph);
  auto runtime_program = gc.Build();

  scope->Var<hlir::framework::Tensor>(std::string(input.id()));
  scope->Var<hlir::framework::Tensor>(std::string(output->id));

  auto input_tensor = scope->GetTensor(std::string(input.id()));
  SetRandData<float>(input_tensor, target);
  float* input_data = input_tensor->mutable_data<float>(target);

  runtime_program->Execute();

  auto output_tensor = scope->GetTensor(std::string(output->id));
  const std::vector<int>& output_shape = output_tensor->shape().data();
  EXPECT_EQ(output_shape.size(), 2UL);
  EXPECT_EQ(output_shape[0], B);
  EXPECT_EQ(output_shape[1], H);

  float* output_data = output_tensor->mutable_data<float>(target);
  VLOG(6) << "Visualize output_data";
  for (int b = 0; b < B; ++b) {
    for (int c = 0; c < C; ++c) {
      VLOG(6) << "b = " << b << ", c = " << c;
      for (int h = 0; h < H; ++h) {
        std::string line;
        for (int w = 0; w < W; ++w) {
          int index = w + W * (h + H * (c + C * b));
          float in_data = input_data[index];
          float out_data = output_data[index];
          line += (std::to_string(out_data) + ", ");
          EXPECT_EQ(in_data, out_data);
        }
        VLOG(6) << line;
      }
    }
  }
}

TEST(net_build, program_execute_squeeze_case3) {
  const int B = 4;
  const int C = 1;
  const int H = 7;
  const int W = 1;

  NetBuilder builder("net_builder");
  Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "In");
  Variable output = builder.Squeeze(input, {});
  auto program = builder.Build();

  Target target = common::DefaultHostTarget();

  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
  auto scope = BuildScope(target, graph);
  hlir::framework::GraphCompiler gc(target, scope, graph);
  auto runtime_program = gc.Build();

  scope->Var<hlir::framework::Tensor>(std::string(input.id()));
  scope->Var<hlir::framework::Tensor>(std::string(output->id));

  auto input_tensor = scope->GetTensor(std::string(input.id()));
  SetRandData<float>(input_tensor, target);
  float* input_data = input_tensor->mutable_data<float>(target);

  runtime_program->Execute();

  auto output_tensor = scope->GetTensor(std::string(output->id));
  const std::vector<int>& output_shape = output_tensor->shape().data();
  EXPECT_EQ(output_shape.size(), 2UL);
  EXPECT_EQ(output_shape[0], B);
  EXPECT_EQ(output_shape[1], H);

  float* output_data = output_tensor->mutable_data<float>(target);
  VLOG(6) << "Visualize output_data";
  for (int b = 0; b < B; ++b) {
    for (int c = 0; c < C; ++c) {
      VLOG(6) << "b = " << b << ", c = " << c;
      for (int h = 0; h < H; ++h) {
        std::string line;
        for (int w = 0; w < W; ++w) {
          int index = w + W * (h + H * (c + C * b));
          float in_data = input_data[index];
          float out_data = output_data[index];
          line += (std::to_string(out_data) + ", ");
          EXPECT_EQ(in_data, out_data);
        }
        VLOG(6) << line;
      }
    }
  }
}

} // namespace frontend
} // namespace cinn
4 changes: 3 additions & 1 deletion cinn/hlir/op/contrib/CMakeLists.txt
@@ -2,8 +2,10 @@ core_gather_headers()

gather_srcs(cinnapi_src SRCS
    cast.cc
    squeeze.cc
    clip.cc
)

cc_test(test_cast SRCS cast_test.cc DEPS cinncore)
cc_test(test_squeeze SRCS squeeze_test.cc DEPS cinncore)
cc_test(test_clip SRCS clip_test.cc DEPS cinncore)
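
Once registered here, the new test is built like the other contrib op tests; it should then be runnable through the test_squeeze target (for example via ctest), though the exact commands depend on the local build setup.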
