From ea0ef217156c861e403e0cd7a448f1afef25193d Mon Sep 17 00:00:00 2001 From: Chen Zhiyang <1792266893@qq.com> Date: Wed, 22 Nov 2023 19:48:26 +0800 Subject: [PATCH] fix KernelWithXShapeInferMeta type bug && test_nn_grad passed (#59212) --- paddle/fluid/operators/reshape_op.cc | 5 +++- paddle/phi/api/yaml/backward.yaml | 6 ++--- paddle/phi/api/yaml/legacy_backward.yaml | 2 +- paddle/phi/infermeta/backward.cc | 5 +++- paddle/phi/infermeta/backward.h | 4 +++- test/legacy_test/test_activation_nn_grad.py | 2 ++ test/legacy_test/test_nn_grad.py | 26 ++++++++++++++++++--- 7 files changed, 40 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 26db1962a4c56..3a57b6da5642a 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -627,9 +627,12 @@ class Reshape2GradOp : public framework::OperatorWithKernel { using CompatMetaTensor = framework::CompatMetaTensor; CompatMetaTensor xshape(ctx->GetInputVarPtrs("XShape")[0], ctx->IsRuntime()); + CompatMetaTensor out_grad( + ctx->GetInputVarPtrs(framework::GradVarName("Out"))[0], + ctx->IsRuntime()); CompatMetaTensor dx(ctx->GetOutputVarPtrs(framework::GradVarName("X"))[0], ctx->IsRuntime()); - phi::KernelWithXShapeInferMeta(xshape, &dx); + phi::KernelWithXShapeInferMeta(xshape, out_grad, &dx); } protected: diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 7ca259e7e0ea1..af49d76cfb2ff 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -852,7 +852,7 @@ output : Tensor(x_grad) infer_meta : func : KernelWithXShapeInferMeta - param : [xshape] + param : [xshape, out_grad] kernel : func : flatten_grad data_type : out_grad @@ -2277,7 +2277,7 @@ output : Tensor(x_grad) infer_meta : func : KernelWithXShapeInferMeta - param: [xshape] + param: [xshape, out_grad] kernel : func : squeeze_grad data_type : out_grad @@ -2522,7 +2522,7 @@ output : Tensor(x_grad) infer_meta : func : KernelWithXShapeInferMeta - param: [xshape] + param: [xshape, out_grad] kernel : func : unsqueeze_grad param : [xshape, out_grad] diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 067eb83f2c646..81f434ca6eb3e 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -577,7 +577,7 @@ output : Tensor(x_grad) infer_meta : func : KernelWithXShapeInferMeta - param : [xshape] + param : [xshape, out_grad] kernel : func : reshape_grad param : [out_grad] diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index e7a4e16fb912c..a3eb7ce8c906b 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -584,10 +584,13 @@ void InverseGradInferMeta(const MetaTensor& out, } } -void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx) { +void KernelWithXShapeInferMeta(const MetaTensor& xshape, + const MetaTensor& out, + MetaTensor* dx) { auto xshape_dims = xshape.dims(); auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size()); dx->set_dims(x_dims); + dx->set_dtype(out.dtype()); dx->share_lod(xshape); } diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 85d70286226a7..c1d79f2378926 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -266,7 +266,9 @@ void InverseGradInferMeta(const MetaTensor& out, const MetaTensor& dout, MetaTensor* dx); -void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx); +void KernelWithXShapeInferMeta(const MetaTensor& xshape, + const MetaTensor& out, + MetaTensor* dx); void LUGradInferMeta(const MetaTensor& x, const MetaTensor& out, diff --git a/test/legacy_test/test_activation_nn_grad.py b/test/legacy_test/test_activation_nn_grad.py index 8203206d1c77c..3f86c97e589a2 100644 --- a/test/legacy_test/test_activation_nn_grad.py +++ b/test/legacy_test/test_activation_nn_grad.py @@ -548,6 +548,7 @@ class TestPowDoubleGradCheck1(unittest.TestCase): def pow_wrapper(self, x): return paddle.pow(x[0], 2) + @test_with_pir_api @prog_scope() def func(self, place): shape = [2, 3, 7, 9] @@ -577,6 +578,7 @@ class TestPowDoubleGradCheck2(unittest.TestCase): def pow_wrapper(self, x): return paddle.pow(x[0], 1) + @test_with_pir_api @prog_scope() def func(self, place): shape = [2, 3, 7, 9] diff --git a/test/legacy_test/test_nn_grad.py b/test/legacy_test/test_nn_grad.py index 592f4d8c0c922..8554bf8e326a1 100644 --- a/test/legacy_test/test_nn_grad.py +++ b/test/legacy_test/test_nn_grad.py @@ -21,11 +21,13 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() class TestSliceOpDoubleGradCheck(unittest.TestCase): + @test_with_pir_api @prog_scope() def func(self, place): self.config() @@ -42,7 +44,7 @@ def config(self): self.ends = [3, 3, 6] self.axes = [0, 1, 2] self.x_arr = np.random.random([3, 4, 5, 2]).astype("float64") - self.inputs = paddle.create_parameter( + self.inputs = paddle.static.data( dtype="float64", shape=[3, 4, 5, 2], name='x' ) @@ -60,12 +62,13 @@ def config(self): self.ends = [3, 3, 3] self.axes = [0, 1, 2] self.x_arr = np.random.random([3, 3, 3]).astype("float64") - self.inputs = paddle.create_parameter( + self.inputs = paddle.static.data( dtype="float64", shape=[3, 3, 3], name='x3' ) class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase): + @test_with_pir_api @prog_scope() def func(self, place): shape = [7, 11] @@ -90,6 +93,7 @@ def test_grad(self): class TestReduceSumWithDimDoubleGradCheck(unittest.TestCase): + @test_with_pir_api @prog_scope() def func(self, place): shape = [7, 11] @@ -114,6 +118,7 @@ def test_grad(self): class TestReshapeDoubleGradCheck(unittest.TestCase): + @test_with_pir_api @prog_scope() def func(self, place): x_shape = [3, 12] @@ -142,6 +147,7 @@ class TestTileDoubleGradCheck(unittest.TestCase): def tile_wrapper(self, x): return paddle.tile(x[0], [4, 9]) + @test_with_pir_api @prog_scope() def func(self, place): x_shape = [3, 12] @@ -173,6 +179,7 @@ class TestExpandV2DoubleGradCheck(unittest.TestCase): def expand_wrapper(self, x): return paddle.expand(x[0], [4, 12]) + @test_with_pir_api @prog_scope() def func(self, place): x_shape = [1, 12] @@ -205,6 +212,7 @@ def squeeze_wrapper(self, x): axes = [0, 2] return paddle.squeeze(x[0], axes) + @test_with_pir_api @prog_scope() def func(self, place): x_shape = [1, 3, 1, 40] @@ -237,6 +245,7 @@ def unsqueeze_wrapper(self, x): axes = [1, 2] return paddle.unsqueeze(x[0], axes) + @test_with_pir_api @prog_scope() def func(self, place): x_shape = [3, 40] @@ -268,6 +277,7 @@ class TestClipDoubleGradCheck(unittest.TestCase): def clip_wrapper(self, x): return paddle.clip(x[0], min=-1.0, max=1.0) + @test_with_pir_api @prog_scope() def func(self, place): x_shape = [2, 4, 10] @@ -292,6 +302,7 @@ def test_grad(self): class TestTransposeDoubleGradCheck(unittest.TestCase): + @test_with_pir_api @prog_scope() def func(self, place): x_shape = [3, 40] @@ -314,6 +325,7 @@ def test_grad(self): class TestTransposeDoubleGradCheckCase1(unittest.TestCase): + @test_with_pir_api @prog_scope() def func(self, place): x_shape = [2, 3, 4, 5] @@ -340,6 +352,7 @@ def pad_wrapper(self, x): pad = [1, 1, 1, 1] return paddle.nn.functional.pad(x[0], pad) + @test_with_pir_api @prog_scope() def func(self, place): x_shape = [2, 3, 4, 5] @@ -361,7 +374,8 @@ def func(self, place): ) def test_grad(self): - places = [base.CPUPlace()] + # places = [base.CPUPlace()] + places = [] if core.is_compiled_with_cuda(): places.append(base.CUDAPlace(0)) for p in places: @@ -369,6 +383,7 @@ def test_grad(self): class TestConstantPadDoubleGradCheckCase1(TestConstantPadDoubleGradCheck): + @test_with_pir_api @prog_scope() def func(self, place): x_shape = [2, 3, 4, 5] @@ -387,6 +402,7 @@ class TestConcatDoubleGradCheck(unittest.TestCase): def concat_wrapper(self, x): return paddle.concat(x, axis=0) + @test_with_pir_api @prog_scope() def func(self, place): x_shape = [2, 3, 4, 5] @@ -421,6 +437,7 @@ def test_grad(self): class TestAvgPool2DDoubleGradCheckCase1(unittest.TestCase): + @test_with_pir_api @prog_scope() def func(self, place): input_NCHW = paddle.static.data( @@ -451,6 +468,7 @@ def pool2d_wrapper(self, x): x[0], kernel_size=2, data_format="NHWC" ) + @test_with_pir_api @prog_scope() def func(self, place): input_NHWC = paddle.static.data( @@ -487,6 +505,7 @@ def pool2d_wrapper(self, x): x[0], kernel_size=2, padding=[1, 1] ) + @test_with_pir_api @prog_scope() def func(self, place): input_NCHW = paddle.static.data( @@ -520,6 +539,7 @@ class TestAvgPool2DDoubleGradCheckCase4(unittest.TestCase): def pool2d_wrapper(self, x): return paddle.nn.functional.avg_pool2d(x[0], kernel_size=[4, 4]) + @test_with_pir_api @prog_scope() def func(self, place): input_NCHW = paddle.static.data(