Skip to content

Commit

Permalink
【PIR API adaptor No.50、57、67】Migrate some ops into pir (PaddlePaddle#…
Browse files Browse the repository at this point in the history
  • Loading branch information
longranger2 authored Oct 31, 2023
1 parent 7c7cb19 commit 8937e68
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
6 changes: 3 additions & 3 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4271,7 +4271,7 @@ def cumprod(x, dim=None, dtype=None, name=None):
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast(x, dtype)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.cumprod(x, dim)
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -4981,7 +4981,7 @@ def digamma(x, name=None):
[ nan , 5.32286835]])
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.digamma(x)
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -5337,7 +5337,7 @@ def erfinv(x, name=None):
[ 0. , 0.47693631, -inf. ])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.erfinv(x)
else:
check_variable_and_dtype(
Expand Down
5 changes: 3 additions & 2 deletions test/legacy_test/test_cumprod_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_check_output(self):
for dim in range(-len(self.shape), len(self.shape)):
for zero_num in self.zero_nums:
self.prepare_inputs_outputs_attrs(dim, zero_num)
self.check_output()
self.check_output(check_pir=True)

# test backward.
def test_check_grad(self):
Expand All @@ -133,13 +133,14 @@ def test_check_grad(self):
self.prepare_inputs_outputs_attrs(dim, zero_num)
self.init_grad_input_output(dim)
if self.dtype == np.float64:
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)
else:
self.check_grad(
['X'],
'Out',
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out],
check_pir=True,
)


Expand Down
12 changes: 7 additions & 5 deletions test/legacy_test/test_digamma_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,18 @@ def init_dtype_type(self):
self.dtype = np.float64

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)


class TestDigammaOpFp32(TestDigammaOp):
def init_dtype_type(self):
self.dtype = np.float32

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)


class TestDigammaFP16Op(TestDigammaOp):
Expand Down Expand Up @@ -87,10 +87,12 @@ def init_dtype_type(self):

def test_check_output(self):
# bfloat16 needs to set the parameter place
self.check_output_with_place(core.CUDAPlace(0))
self.check_output_with_place(core.CUDAPlace(0), check_pir=True)

def test_check_grad_normal(self):
self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out')
self.check_grad_with_place(
core.CUDAPlace(0), ['X'], 'Out', check_pir=True
)


class TestDigammaAPI(unittest.TestCase):
Expand Down
11 changes: 4 additions & 7 deletions test/legacy_test/test_erfinv_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,15 @@ def init_dtype(self):
self.dtype = np.float64

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(
['X'],
'Out',
user_defined_grads=[self.gradient],
user_defined_grad_outputs=self.grad_out,
check_pir=True,
)


Expand Down Expand Up @@ -143,15 +144,11 @@ def setUp(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['X'],
'Out',
)
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)


if __name__ == "__main__":
Expand Down

0 comments on commit 8937e68

Please sign in to comment.