fix KernelWithXShapeInferMeta type bug && test_nn_grad passed (#59212)
changeyoung98 authored Nov 22, 2023
1 parent 4a25662 commit ea0ef21
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 10 deletions.
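The core of the change: `KernelWithXShapeInferMeta` previously set only `x_grad`'s dims (recovered from `XShape`) and LoD, never its dtype; the fix threads the output gradient through so the dtype can be copied from it. A minimal Python sketch of the resulting behavior (names are illustrative, not the real phi API):

```python
# Illustrative sketch of what the patched KernelWithXShapeInferMeta computes;
# the real implementation is the C++ in paddle/phi/infermeta/backward.cc below.
def kernel_with_xshape_infer_meta(xshape_dims, out_grad_dtype):
    # XShape stores [0] + x.shape purely for bookkeeping, so dropping the
    # leading entry recovers the original input shape.
    x_grad_dims = xshape_dims[1:]
    # The fix: x_grad's dtype now follows out_grad instead of being left unset.
    x_grad_dtype = out_grad_dtype
    return x_grad_dims, x_grad_dtype

print(kernel_with_xshape_infer_meta((0, 3, 12), "float64"))
# ((3, 12), 'float64')
```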
5 changes: 4 additions & 1 deletion paddle/fluid/operators/reshape_op.cc
@@ -627,9 +627,12 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
     using CompatMetaTensor = framework::CompatMetaTensor;
     CompatMetaTensor xshape(ctx->GetInputVarPtrs("XShape")[0],
                             ctx->IsRuntime());
+    CompatMetaTensor out_grad(
+        ctx->GetInputVarPtrs(framework::GradVarName("Out"))[0],
+        ctx->IsRuntime());
     CompatMetaTensor dx(ctx->GetOutputVarPtrs(framework::GradVarName("X"))[0],
                         ctx->IsRuntime());
-    phi::KernelWithXShapeInferMeta(xshape, &dx);
+    phi::KernelWithXShapeInferMeta(xshape, out_grad, &dx);
   }

 protected:
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/backward.yaml
@@ -852,7 +852,7 @@
   output : Tensor(x_grad)
   infer_meta :
     func : KernelWithXShapeInferMeta
-    param : [xshape]
+    param : [xshape, out_grad]
   kernel :
     func : flatten_grad
     data_type : out_grad
@@ -2277,7 +2277,7 @@
   output : Tensor(x_grad)
   infer_meta :
     func : KernelWithXShapeInferMeta
-    param: [xshape]
+    param: [xshape, out_grad]
   kernel :
     func : squeeze_grad
     data_type : out_grad
@@ -2522,7 +2522,7 @@
   output : Tensor(x_grad)
   infer_meta :
     func : KernelWithXShapeInferMeta
-    param: [xshape]
+    param: [xshape, out_grad]
   kernel :
     func : unsqueeze_grad
     param : [xshape, out_grad]
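For context, a yaml `param` entry lists which op inputs are forwarded, in order, to the infer_meta function, so it has to grow together with the C++ signature changed in this commit. A hypothetical sketch of that positional binding (paddle's real code generator is more involved):

```python
# Hypothetical illustration of how a yaml `param` list lines up positionally
# with the C++ infer_meta inputs; the output MetaTensor* dx is appended by
# the generator itself, so param must cover exactly the inputs, in order.
def bind_infer_meta_params(param, cpp_inputs=("xshape", "out")):
    assert len(param) == len(cpp_inputs), "param list out of sync with C++"
    return dict(zip(cpp_inputs, param))

print(bind_infer_meta_params(["xshape", "out_grad"]))
# {'xshape': 'xshape', 'out': 'out_grad'}
```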
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/legacy_backward.yaml
@@ -577,7 +577,7 @@
   output : Tensor(x_grad)
   infer_meta :
     func : KernelWithXShapeInferMeta
-    param : [xshape]
+    param : [xshape, out_grad]
   kernel :
     func : reshape_grad
     param : [out_grad]
5 changes: 4 additions & 1 deletion paddle/phi/infermeta/backward.cc
@@ -584,10 +584,13 @@ void InverseGradInferMeta(const MetaTensor& out,
   }
 }

-void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx) {
+void KernelWithXShapeInferMeta(const MetaTensor& xshape,
+                               const MetaTensor& out,
+                               MetaTensor* dx) {
   auto xshape_dims = xshape.dims();
   auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
   dx->set_dims(x_dims);
+  dx->set_dtype(out.dtype());
   dx->share_lod(xshape);
 }
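Why the dtype has to come from `out_grad` rather than be left at a default: the gradient of a reshape-like op is just `out_grad` viewed with `x`'s shape, while `XShape` only encodes dimensions. A NumPy sketch of that intuition (illustrative only, not phi code):

```python
# The grad kernel effectively reshapes out_grad back to x's shape, so
# x_grad's element type is out_grad's; xshape contributes only the dims.
import numpy as np

def reshape_grad(out_grad, xshape_dims):
    x_dims = xshape_dims[1:]         # XShape stores [0] + x.shape
    return out_grad.reshape(x_dims)  # dtype is carried over from out_grad

g = np.ones((4, 9), dtype=np.float16)
print(reshape_grad(g, (0, 3, 12)).dtype)  # float16, not some default dtype
```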
4 changes: 3 additions & 1 deletion paddle/phi/infermeta/backward.h
@@ -266,7 +266,9 @@ void InverseGradInferMeta(const MetaTensor& out,
                           const MetaTensor& dout,
                           MetaTensor* dx);

-void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx);
+void KernelWithXShapeInferMeta(const MetaTensor& xshape,
+                               const MetaTensor& out,
+                               MetaTensor* dx);

 void LUGradInferMeta(const MetaTensor& x,
                      const MetaTensor& out,
2 changes: 2 additions & 0 deletions test/legacy_test/test_activation_nn_grad.py
@@ -548,6 +548,7 @@ class TestPowDoubleGradCheck1(unittest.TestCase):
     def pow_wrapper(self, x):
         return paddle.pow(x[0], 2)

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         shape = [2, 3, 7, 9]
@@ -577,6 +578,7 @@ class TestPowDoubleGradCheck2(unittest.TestCase):
     def pow_wrapper(self, x):
         return paddle.pow(x[0], 1)

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         shape = [2, 3, 7, 9]
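`@test_with_pir_api` comes from `paddle.pir_utils` and reruns the decorated test body under the new PIR program API in addition to the legacy static-graph IR. A simplified, hypothetical stand-in that shows the shape of such a decorator (the real one flips Paddle's global IR mode rather than passing a flag):

```python
import functools

def run_in_both_ir_modes(fn):
    # Hypothetical stand-in for @test_with_pir_api: execute the same test
    # body once per IR mode.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        for use_pir in (False, True):
            fn(*args, use_pir=use_pir, **kwargs)
    return wrapper

@run_in_both_ir_modes
def check(use_pir=False):
    print("PIR" if use_pir else "legacy IR")

check()  # runs the body twice: legacy IR, then PIR
```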
26 changes: 23 additions & 3 deletions test/legacy_test/test_nn_grad.py
@@ -21,11 +21,13 @@
 import paddle
 from paddle import base
 from paddle.base import core
+from paddle.pir_utils import test_with_pir_api

 paddle.enable_static()


 class TestSliceOpDoubleGradCheck(unittest.TestCase):
+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         self.config()
@@ -42,7 +44,7 @@ def config(self):
         self.ends = [3, 3, 6]
         self.axes = [0, 1, 2]
         self.x_arr = np.random.random([3, 4, 5, 2]).astype("float64")
-        self.inputs = paddle.create_parameter(
+        self.inputs = paddle.static.data(
             dtype="float64", shape=[3, 4, 5, 2], name='x'
         )

@@ -60,12 +62,13 @@ def config(self):
         self.ends = [3, 3, 3]
         self.axes = [0, 1, 2]
         self.x_arr = np.random.random([3, 3, 3]).astype("float64")
-        self.inputs = paddle.create_parameter(
+        self.inputs = paddle.static.data(
             dtype="float64", shape=[3, 3, 3], name='x3'
         )


 class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase):
+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         shape = [7, 11]
@@ -90,6 +93,7 @@ def test_grad(self):


 class TestReduceSumWithDimDoubleGradCheck(unittest.TestCase):
+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         shape = [7, 11]
@@ -114,6 +118,7 @@ def test_grad(self):


 class TestReshapeDoubleGradCheck(unittest.TestCase):
+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         x_shape = [3, 12]
@@ -142,6 +147,7 @@ class TestTileDoubleGradCheck(unittest.TestCase):
     def tile_wrapper(self, x):
         return paddle.tile(x[0], [4, 9])

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         x_shape = [3, 12]
@@ -173,6 +179,7 @@ class TestExpandV2DoubleGradCheck(unittest.TestCase):
     def expand_wrapper(self, x):
         return paddle.expand(x[0], [4, 12])

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         x_shape = [1, 12]
@@ -205,6 +212,7 @@ def squeeze_wrapper(self, x):
         axes = [0, 2]
         return paddle.squeeze(x[0], axes)

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         x_shape = [1, 3, 1, 40]
@@ -237,6 +245,7 @@ def unsqueeze_wrapper(self, x):
         axes = [1, 2]
         return paddle.unsqueeze(x[0], axes)

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         x_shape = [3, 40]
@@ -268,6 +277,7 @@ class TestClipDoubleGradCheck(unittest.TestCase):
     def clip_wrapper(self, x):
         return paddle.clip(x[0], min=-1.0, max=1.0)

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         x_shape = [2, 4, 10]
@@ -292,6 +302,7 @@ def test_grad(self):


 class TestTransposeDoubleGradCheck(unittest.TestCase):
+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         x_shape = [3, 40]
@@ -314,6 +325,7 @@ def test_grad(self):


 class TestTransposeDoubleGradCheckCase1(unittest.TestCase):
+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         x_shape = [2, 3, 4, 5]
@@ -340,6 +352,7 @@ def pad_wrapper(self, x):
         pad = [1, 1, 1, 1]
         return paddle.nn.functional.pad(x[0], pad)

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         x_shape = [2, 3, 4, 5]
@@ -361,14 +374,16 @@ def func(self, place):
         )

     def test_grad(self):
-        places = [base.CPUPlace()]
+        # places = [base.CPUPlace()]
+        places = []
         if core.is_compiled_with_cuda():
             places.append(base.CUDAPlace(0))
         for p in places:
             self.func(p)


 class TestConstantPadDoubleGradCheckCase1(TestConstantPadDoubleGradCheck):
+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         x_shape = [2, 3, 4, 5]
@@ -387,6 +402,7 @@ class TestConcatDoubleGradCheck(unittest.TestCase):
     def concat_wrapper(self, x):
         return paddle.concat(x, axis=0)

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         x_shape = [2, 3, 4, 5]
@@ -421,6 +437,7 @@ def test_grad(self):


 class TestAvgPool2DDoubleGradCheckCase1(unittest.TestCase):
+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         input_NCHW = paddle.static.data(
@@ -451,6 +468,7 @@ def pool2d_wrapper(self, x):
             x[0], kernel_size=2, data_format="NHWC"
         )

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         input_NHWC = paddle.static.data(
@@ -487,6 +505,7 @@ def pool2d_wrapper(self, x):
             x[0], kernel_size=2, padding=[1, 1]
         )

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         input_NCHW = paddle.static.data(
@@ -520,6 +539,7 @@ class TestAvgPool2DDoubleGradCheckCase4(unittest.TestCase):
     def pool2d_wrapper(self, x):
         return paddle.nn.functional.avg_pool2d(x[0], kernel_size=[4, 4])

+    @test_with_pir_api
     @prog_scope()
     def func(self, place):
         input_NCHW = paddle.static.data(
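The slice tests also stop building their input with `paddle.create_parameter` and declare it with `paddle.static.data` instead; presumably the parameter path does not translate cleanly under PIR (an assumption, the commit message does not say). A sketch contrasting the two declarations:

```python
import paddle

paddle.enable_static()

# Before (removed above): a learnable parameter initialized by the framework.
# inputs = paddle.create_parameter(dtype="float64", shape=[3, 3, 3], name="x3")

# After: a plain graph input that the gradient checker feeds at run time.
inputs = paddle.static.data(dtype="float64", shape=[3, 3, 3], name="x3")
inputs.persistable = True  # double-grad checking needs persistable vars (assumed)
```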
