diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py
index 6cdcc2bd4aefdb..4a192fd48c84b6 100644
--- a/python/paddle/nn/layer/norm.py
+++ b/python/paddle/nn/layer/norm.py
@@ -1623,7 +1623,7 @@ def forward(self, x):
 
         # train mode: use mini-batch stats, eval mode: use global stats
         # use_global_stats only support False in sync_batch_norm
-        if in_dynamic_mode():
+        if in_dynamic_or_pir_mode():
             sync_batch_norm_out, _, _, _, _, _ = _C_ops.sync_batch_norm_(
                 x,
                 self._mean,
diff --git a/test/dygraph_to_static/test_convert_call.py b/test/dygraph_to_static/test_convert_call.py
index 357f1d8d592660..a66b39d9db7efb 100644
--- a/test/dygraph_to_static/test_convert_call.py
+++ b/test/dygraph_to_static/test_convert_call.py
@@ -293,7 +293,7 @@ def test_class_api(self):
         bn = paddle.nn.SyncBatchNorm(2)
         paddle.jit.to_static(bn)
         self.assertNotIn("_jst.IfElse", bn.forward.code)
-        self.assertIn("if in_dynamic_mode()", bn.forward.code)
+        self.assertIn("if in_dynamic_or_pir_mode()", bn.forward.code)
 
     @test_ast_only
     @test_legacy_and_pir_api
diff --git a/test/legacy_test/test_layers.py b/test/legacy_test/test_layers.py
index 8c8c52ed1abf2c..88c6243862a217 100644
--- a/test/legacy_test/test_layers.py
+++ b/test/legacy_test/test_layers.py
@@ -34,6 +34,7 @@
     rank_attention,
     shuffle_batch,
 )
+from paddle.pir_utils import test_with_pir_api
 from paddle.tensor import random
 
 
@@ -275,6 +276,7 @@ def test_type():
 
         self.assertRaises(TypeError, test_type)
 
+    @test_with_pir_api
     def test_SyncBatchNorm(self):
         if core.is_compiled_with_cuda():
             with self.static_graph():
diff --git a/test/legacy_test/test_sync_batch_norm_op.py b/test/legacy_test/test_sync_batch_norm_op.py
index 0375ee7c527763..17daa24996b4ff 100644
--- a/test/legacy_test/test_sync_batch_norm_op.py
+++ b/test/legacy_test/test_sync_batch_norm_op.py
@@ -30,8 +30,9 @@
 import paddle
 from paddle import base, nn
-from paddle.base import Program, core, program_guard
+from paddle.base import core
 from paddle.base.framework import in_dygraph_mode
+from paddle.pir_utils import test_with_pir_api
 
 _set_use_system_allocator(True)
 
 
@@ -364,7 +365,9 @@ def test_errors(self):
             return
 
         cleanup = enable_static()
-        with program_guard(Program(), Program()):
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
             my_sync_batch_norm = paddle.nn.SyncBatchNorm(10)
             x1 = base.create_lod_tensor(
                 np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], base.CUDAPlace(0)
@@ -382,11 +385,14 @@
 
 
 class TestConvertSyncBatchNorm(unittest.TestCase):
+    @test_with_pir_api
     def test_convert(self):
         if not core.is_compiled_with_cuda():
             return
 
-        with program_guard(Program(), Program()):
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
             compare_model = paddle.nn.Sequential(
                 paddle.nn.Conv2D(3, 5, 3),
                 paddle.nn.BatchNorm2D(5),
@@ -410,6 +416,7 @@
 
 
 class TestConvertSyncBatchNormCast1(unittest.TestCase):
+    @test_with_pir_api
     def test_convert(self):
         if not core.is_compiled_with_cuda():
             return
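
The core change swaps `in_dynamic_mode()` for `in_dynamic_or_pir_mode()` in `SyncBatchNorm.forward`, so the fused `_C_ops.sync_batch_norm_` path is also taken under PIR; the tests are updated to run under the PIR API (`test_with_pir_api`) and to use `paddle.static.program_guard`/`paddle.static.Program` in place of the removed `paddle.base` imports. Below is a minimal sketch (not part of the patch) of what the updated `test_convert_call.py` assertion verifies, mirroring the test above; no GPU execution is involved, since only the transcribed source is inspected:

```python
import paddle

# After this change, paddle.jit.to_static keeps the dispatch branch as
# plain Python (it is no longer rewritten into a static-graph `_jst.IfElse`),
# and the transcribed source tests `in_dynamic_or_pir_mode()` rather than
# `in_dynamic_mode()`.
bn = paddle.nn.SyncBatchNorm(2)
paddle.jit.to_static(bn)  # converts bn.forward so its code can be inspected

assert "_jst.IfElse" not in bn.forward.code
assert "if in_dynamic_or_pir_mode()" in bn.forward.code
```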