From bd4107b35e3698f70c04e50b559e7b51facc7251 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Tue, 28 Mar 2023 10:18:02 +0000
Subject: [PATCH] element_wise_add_fp16_test

---
 .../unittests/test_elementwise_add_op.py | 37 ++++++++++++++-----
 1 file changed, 27 insertions(+), 10 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
index 10c7df64d6104..a10c6d186205a 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
@@ -143,15 +143,29 @@ def init_dtype(self):