unittest change to fill_constant+elementwise_pow
jiahy0825 committed Feb 26, 2023
1 parent 8b2e4c7 commit 36ae90a
Showing 4 changed files with 29 additions and 21 deletions.
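
This commit updates the prim unittests after the composite backward of tanh switched from a single `pow` op, which carried its exponent as a `factor` attribute, to a `fill_constant` + `elementwise_pow` pair, where the exponent is materialized as a tensor. The composite rule computes grad_x = grad_out * (1 - tanh(x)^2), so the expected op sequence grows from four ops (pow, fill_constant, elementwise_sub, elementwise_mul) to five. Below is a minimal NumPy sketch of the decomposition the updated tests assert; the helper names mirror the Paddle op names in the diff, but the signatures are illustrative stand-ins, not Paddle's actual prim API.

import numpy as np

# Hypothetical NumPy stand-ins for the Paddle ops named in this diff;
# the real ops run inside Paddle's static graph, not in Python.
def fill_constant(shape, value):
    return np.full(shape, value, dtype=np.float32)

def elementwise_pow(x, y):
    return np.power(x, y)

def elementwise_sub(x, y):
    return x - y

def elementwise_mul(x, y):
    return x * y

def tanh_grad_composite(out, grad_out):
    """grad_x = grad_out * (1 - tanh(x)^2), where `out` is tanh(x).

    The five steps below match the op sequence the updated tests assert.
    """
    two = fill_constant(out.shape, 2.0)    # grad_ops[0]: exponent as a tensor
    out2 = elementwise_pow(out, two)       # grad_ops[1]: tanh(x)^2
    one = fill_constant(out.shape, 1.0)    # grad_ops[2]
    tmp = elementwise_sub(one, out2)       # grad_ops[3]: 1 - tanh(x)^2
    return elementwise_mul(grad_out, tmp)  # grad_ops[4]

Feeding the exponent in as a tensor keeps the expansion inside the elementwise op set rather than relying on an attribute-carrying `pow`, which is what the extra `fill_constant` (and the jump from 4 to 5 grad ops) reflects.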
1 change: 1 addition & 0 deletions paddle/fluid/prim/tests/CMakeLists.txt
@@ -30,6 +30,7 @@ cc_test_old(
 operator
 elementwise_mul_op
 elementwise_sub_op
+elementwise_pow_op
 fill_constant_op
 activation_op
 phi_api
43 changes: 24 additions & 19 deletions paddle/fluid/prim/tests/test_static_prim.cc
@@ -194,7 +194,7 @@ TEST(StaticPrim, TanhBackwardComposite) {
                                        target_block,
                                        grad_sub_block));
   ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(4));
+  ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(5));
   ASSERT_EQ(target_block->AllOps()[0]->Type(), "tanh");
   ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
             static_cast<std::size_t>(1));
@@ -204,36 +204,41 @@ TEST(StaticPrim, TanhBackwardComposite) {
   ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out")[0], "b");
   ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out")[0], "b");
 
-  ASSERT_EQ(grad_ops[0]->Type(), "pow");
-  ASSERT_EQ(grad_ops[0]->Inputs().at("X").size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[0]->Inputs().at("X")[0], "b");
-  ASSERT_EQ(PADDLE_GET_CONST(float, grad_ops[0]->GetAttr("factor")),
-            static_cast<float>(2.0));
+  ASSERT_EQ(grad_ops[0]->Type(), "fill_constant");
+  ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[0]->GetAttr("dtype")),
+            static_cast<int>(5));  // ProtoDataType::FP32
   ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
             static_cast<std::size_t>(1));
 
-  ASSERT_EQ(grad_ops[1]->Type(), "fill_constant");
-  ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[1]->GetAttr("dtype")),
-            static_cast<int>(5));  // ProtoDataType::FP32
-  ASSERT_EQ(grad_ops[1]->Outputs().at("Out").size(),
+  ASSERT_EQ(grad_ops[1]->Type(), "elementwise_pow");
+  ASSERT_EQ(grad_ops[1]->Inputs().at("X").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[1]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[1]->Inputs().at("X")[0], "b");
+  ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
             static_cast<std::size_t>(1));
 
-  ASSERT_EQ(grad_ops[2]->Type(), "elementwise_sub");
-  ASSERT_EQ(grad_ops[2]->Inputs().at("X").size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[2]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[2]->Inputs().at("X")[0],
-            grad_ops[1]->Outputs().at("Out")[0]);
+  ASSERT_EQ(grad_ops[2]->Type(), "fill_constant");
+  ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[2]->GetAttr("dtype")),
+            static_cast<int>(5));  // ProtoDataType::FP32
   ASSERT_EQ(grad_ops[2]->Outputs().at("Out").size(),
             static_cast<std::size_t>(1));
 
-  ASSERT_EQ(grad_ops[3]->Type(), "elementwise_mul");
+  ASSERT_EQ(grad_ops[3]->Type(), "elementwise_sub");
   ASSERT_EQ(grad_ops[3]->Inputs().at("X").size(), static_cast<std::size_t>(1));
   ASSERT_EQ(grad_ops[3]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[3]->Inputs().at("Y")[0],
+  ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0],
             grad_ops[2]->Outputs().at("Out")[0]);
-  ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0], "b@GRAD");
   ASSERT_EQ(grad_ops[3]->Outputs().at("Out").size(),
             static_cast<std::size_t>(1));
+
+  ASSERT_EQ(grad_ops[4]->Type(), "elementwise_mul");
+  ASSERT_EQ(grad_ops[4]->Inputs().at("X").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[4]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[4]->Inputs().at("Y")[0],
+            grad_ops[3]->Outputs().at("Out")[0]);
+  ASSERT_EQ(grad_ops[4]->Inputs().at("X")[0], "b@GRAD");
+  ASSERT_EQ(grad_ops[4]->Outputs().at("Out").size(),
+            static_cast<std::size_t>(1));
 }
 
 TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
@@ -376,7 +381,7 @@ TEST(StaticPrim, TestFlags) {
 USE_OP_ITSELF(fill_constant);
 USE_OP_ITSELF(tanh);
 USE_OP_ITSELF(tanh_grad);
-USE_OP_ITSELF(pow);
 USE_OP_ITSELF(elementwise_mul);
 USE_OP_ITSELF(elementwise_sub);
+USE_OP_ITSELF(elementwise_pow);
 USE_OP_ITSELF(scale);
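
As a quick sanity check that the sequence asserted above really encodes the tanh derivative, one can compare the composite formula against a finite-difference estimate. This is a standalone NumPy illustration, not part of the commit:

import numpy as np

x = np.linspace(-2.0, 2.0, 9)
out = np.tanh(x)
grad_out = np.ones_like(x)

# Composite result: grad_out * (1 - tanh(x)^2)
composite = grad_out * (1.0 - out ** 2)

# Central finite-difference reference for d(tanh)/dx
eps = 1e-5
numeric = (np.tanh(x + eps) - np.tanh(x - eps)) / (2.0 * eps)

assert np.allclose(composite, numeric, atol=1e-8)
print("composite tanh gradient matches finite differences")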
@@ -42,7 +42,8 @@
         set(),
         tuple(),
         (
-            'pow',
+            'fill_constant',
+            'elementwise_pow',
             'fill_constant',
             'elementwise_sub',
             'elementwise_mul',
@@ -28,7 +28,8 @@ def setUp(self):
         self.grad_sub_block = tuple()
         self.desired_ops = 'tanh_grad'
         self.desired_ops_no_skip = (
-            'pow',
+            'fill_constant',
+            'elementwise_pow',
             'fill_constant',
             'elementwise_sub',
             'elementwise_mul',
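
Both Python tests keep the expected expansion as a flat, ordered tuple of op types. A small sketch of the comparison they imply, with a hypothetical harness; the real tests pull the op list from the generated grad block:

desired_ops_no_skip = (
    'fill_constant',
    'elementwise_pow',
    'fill_constant',
    'elementwise_sub',
    'elementwise_mul',
)

def check_grad_op_types(actual_op_types, desired):
    # Order matters: the composite expansion must emit exactly this sequence.
    assert tuple(actual_op_types) == tuple(desired), (actual_op_types, desired)

# With the sequence the C++ assertions above walk through:
check_grad_op_types(
    ['fill_constant', 'elementwise_pow', 'fill_constant',
     'elementwise_sub', 'elementwise_mul'],
    desired_ops_no_skip,
)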
