From ed3ed2dacc8a3c0c068ccc7510b742423b683ff0 Mon Sep 17 00:00:00 2001 From: Lu Qi <61354321+MarioLulab@users.noreply.github.com> Date: Thu, 9 Nov 2023 14:57:23 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PIR=20API=20adaptor=20No.139=E3=80=91?= =?UTF-8?q?=20Migrate=20logsumexp=20into=20pir=20(#58843)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/tensor/math.py | 2 +- test/legacy_test/test_logsumexp.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 01e01d6b449cb..badf35446001b 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2626,7 +2626,7 @@ def logsumexp(x, axis=None, keepdim=False, name=None): """ reduce_all, axis = _get_reduce_axis(axis, x) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.logsumexp(x, axis, keepdim, reduce_all) else: check_variable_and_dtype( diff --git a/test/legacy_test/test_logsumexp.py b/test/legacy_test/test_logsumexp.py index 67cc953c5fdc5..90c40e08860fc 100644 --- a/test/legacy_test/test_logsumexp.py +++ b/test/legacy_test/test_logsumexp.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api def ref_logsumexp(x, axis=None, keepdim=False, reduce_all=False): @@ -87,7 +88,7 @@ def set_attrs_addition(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -95,6 +96,7 @@ def test_check_grad(self): ['Out'], user_defined_grads=self.user_defined_grads, user_defined_grad_outputs=self.user_defined_grad_outputs, + check_pir=True, ) def calc_grad(self): @@ -212,11 +214,11 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out') + self.check_grad_with_place(place, ['X'], 'Out', check_pir=True) def set_attrs(self): pass @@ -258,6 +260,7 @@ def api_case(self, axis=None, keepdim=False): np.testing.assert_allclose(out.numpy(), out_ref, rtol=1e-05) paddle.enable_static() + @test_with_pir_api def test_api(self): self.api_case() self.api_case(2)