Skip to content

Commit

Permalink
[PIR] No.46 Migrate paddle.nn.functional.pad into pir (#57348)
Browse files Browse the repository at this point in the history
  • Loading branch information
BeingGod authored Sep 26, 2023
1 parent 146489a commit 1f4d4cd
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
'stack',
'poisson',
'gumbel_softmax',
'pad',
'pad3d',
'squeeze',
'unsqueeze',
'tril',
Expand Down Expand Up @@ -104,6 +106,8 @@
'stack',
'poisson',
'gumbel_softmax',
'pad',
'pad3d',
'squeeze',
'unsqueeze',
'tril',
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/nn/functional/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,7 +1658,7 @@ def pad(x, pad, mode='constant', value=0.0, data_format="NCHW", name=None):
paddings = pad
pad_value = value

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
out = _C_ops.pad(x, paddings, float(pad_value))
return out

Expand Down Expand Up @@ -1712,7 +1712,7 @@ def pad(x, pad, mode='constant', value=0.0, data_format="NCHW", name=None):

unsqueezed_dim = []

if isinstance(pad, Variable):
if isinstance(pad, (Variable, pir.OpResult)):
if data_format in ["NCL", "NCHW", "NCDHW"]:
data_format = "NCDHW"
if x_dim == 3:
Expand Down Expand Up @@ -1756,7 +1756,7 @@ def pad(x, pad, mode='constant', value=0.0, data_format="NCHW", name=None):
unsqueezed_dim = [1]
x = unsqueeze(x, axis=unsqueezed_dim)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
if isinstance(pad, Variable):
pad = pad.tolist()
out = _C_ops.pad3d(x, pad, mode, value, data_format)
Expand Down
14 changes: 8 additions & 6 deletions test/legacy_test/test_pad3d_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def setUp(self):
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])

def test_check_output(self):
self.check_output()
self.check_output(check_new_ir=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_new_ir=True)

def get_dtype(self):
return np.float64
Expand Down Expand Up @@ -214,10 +214,12 @@ def get_dtype(self):
return np.float16

def test_check_output(self):
self.check_output(atol=1e-3)
self.check_output(atol=1e-3, check_new_ir=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', max_relative_error=1.5e-3)
self.check_grad(
['X'], 'Out', max_relative_error=1.5e-3, check_new_ir=True
)

cls_name = "{}_{}".format(parent.__name__, "FP16OP")
TestPad3dFp16.__name__ = cls_name
Expand Down Expand Up @@ -251,12 +253,12 @@ def get_dtype(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-2)
self.check_output_with_place(place, atol=1e-2, check_new_ir=True)

def test_check_grad_normal(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=1e-2
place, ['X'], 'Out', max_relative_error=1e-2, check_new_ir=True
)

cls_name = "{}_{}".format(parent.__name__, "BF16OP")
Expand Down
12 changes: 7 additions & 5 deletions test/legacy_test/test_pad_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def get_dtype(self):
return np.float64

def test_check_output(self):
self.check_output()
self.check_output(check_new_ir=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_prim=True)
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)

def initTestCase(self):
self.shape = (16, 16)
Expand Down Expand Up @@ -101,7 +101,7 @@ def get_dtype(self):
return np.float16

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_prim=True)
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)

cls_name = "{}_{}".format(parent.__name__, "Fp16")
TestPadFp16.__name__ = cls_name
Expand Down Expand Up @@ -253,11 +253,13 @@ def initTestCase(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_new_ir=True)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
self.check_grad_with_place(
place, ['X'], 'Out', check_prim=True, check_new_ir=True
)


if __name__ == '__main__':
Expand Down

0 comments on commit 1f4d4cd

Please sign in to comment.