diff --git a/python/paddle/pir_utils.py b/python/paddle/pir_utils.py
index 28d261b0155fc..f16d411262a22 100644
--- a/python/paddle/pir_utils.py
+++ b/python/paddle/pir_utils.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
+from functools import wraps
+
 import paddle
@@ -95,3 +97,13 @@ def _switch_to_old_ir(self):
                 "IrGuard._switch_to_old_ir only work when paddle.framework.in_pir_mode() is false, \
                 please set FLAGS_enable_pir_api = false"
             )
+
+
+def test_with_pir_api(func):
+    @wraps(func)
+    def impl(*args, **kwargs):
+        func(*args, **kwargs)
+        with IrGuard():
+            func(*args, **kwargs)
+
+    return impl
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 5a60e6884b890..467c7f7ab88f1 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -1226,7 +1226,7 @@ def maximum(x, y, name=None):
             Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
             [5. , 3. , inf.])
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.maximum(x, y)
     else:
         return _elementwise_op(LayerHelper('elementwise_max', **locals()))
diff --git a/test/legacy_test/test_maximum_op.py b/test/legacy_test/test_maximum_op.py
index 818bdb65fee68..a0e660112bd03 100644
--- a/test/legacy_test/test_maximum_op.py
+++ b/test/legacy_test/test_maximum_op.py
@@ -18,6 +18,7 @@
 import paddle
 from paddle.base import core
+from paddle.pir_utils import test_with_pir_api
 
 
 class ApiMaximumTest(unittest.TestCase):
@@ -39,6 +40,7 @@ def setUp(self):
         self.np_expected3 = np.maximum(self.input_a, self.input_c)
         self.np_expected4 = np.maximum(self.input_b, self.input_c)
 
+    @test_with_pir_api
     def test_static_api(self):
         paddle.enable_static()
         with paddle.static.program_guard(
@@ -119,3 +121,7 @@ def test_dynamic_api(self):
         res = paddle.maximum(b, c)
         res = res.numpy()
         np.testing.assert_allclose(res, self.np_expected4, rtol=1e-05)
+
+
+if __name__ == '__main__':
+    unittest.main()
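
For context, the test_with_pir_api decorator added in python/paddle/pir_utils.py runs the decorated test body twice: once as-is (legacy/old IR static graph) and once inside IrGuard(), i.e. with the PIR API enabled, so a single test exercises both code paths. Below is a minimal usage sketch; the test class name, shapes, and values are illustrative only and not part of the patch.

import unittest

import numpy as np

import paddle
from paddle.pir_utils import test_with_pir_api


class ExampleMaximumTest(unittest.TestCase):  # hypothetical test class, for illustration only
    @test_with_pir_api
    def test_static_maximum(self):
        # The decorated body executes twice: first under the default (old IR)
        # static graph, then again inside IrGuard() with the PIR API enabled.
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data('x', shape=[3], dtype='float32')
            y = paddle.static.data('y', shape=[3], dtype='float32')
            out = paddle.maximum(x, y)
            exe = paddle.static.Executor(paddle.CPUPlace())
            (res,) = exe.run(
                feed={
                    'x': np.array([1.0, 5.0, 2.0], dtype='float32'),
                    'y': np.array([4.0, 3.0, 2.0], dtype='float32'),
                },
                fetch_list=[out],
            )
            np.testing.assert_allclose(res, [4.0, 5.0, 2.0], rtol=1e-05)


if __name__ == '__main__':
    unittest.main()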