Skip to content

Commit

Permalink
【PIR API adaptor No.139】 Migrate logsumexp into pir (#58843)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarioLulab authored Nov 9, 2023
1 parent d97d215 commit ed3ed2d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2626,7 +2626,7 @@ def logsumexp(x, axis=None, keepdim=False, name=None):
"""
reduce_all, axis = _get_reduce_axis(axis, x)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.logsumexp(x, axis, keepdim, reduce_all)
else:
check_variable_and_dtype(
Expand Down
9 changes: 6 additions & 3 deletions test/legacy_test/test_logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def ref_logsumexp(x, axis=None, keepdim=False, reduce_all=False):
Expand Down Expand Up @@ -87,14 +88,15 @@ def set_attrs_addition(self):
pass

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(
['X'],
['Out'],
user_defined_grads=self.user_defined_grads,
user_defined_grad_outputs=self.user_defined_grad_outputs,
check_pir=True,
)

def calc_grad(self):
Expand Down Expand Up @@ -212,11 +214,11 @@ def setUp(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)

def set_attrs(self):
pass
Expand Down Expand Up @@ -258,6 +260,7 @@ def api_case(self, axis=None, keepdim=False):
np.testing.assert_allclose(out.numpy(), out_ref, rtol=1e-05)
paddle.enable_static()

@test_with_pir_api
def test_api(self):
self.api_case()
self.api_case(2)
Expand Down

0 comments on commit ed3ed2d

Please sign in to comment.