From 08a404a0b3ab23c511abb6200cc29d9871df9b45 Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Thu, 16 Nov 2023 14:43:26 +0000 Subject: [PATCH] fix(new-ir): support SyncBatchNorm --- python/paddle/nn/layer/norm.py | 2 +- test/dygraph_to_static/test_convert_call.py | 1 + test/legacy_test/test_layers.py | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 6cdcc2bd4aefd..4a192fd48c84b 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -1623,7 +1623,7 @@ def forward(self, x): # train mode: use mini-batch stats, eval mode: use global stats # use_global_stats only support False in sync_batch_norm - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): sync_batch_norm_out, _, _, _, _, _ = _C_ops.sync_batch_norm_( x, self._mean, diff --git a/test/dygraph_to_static/test_convert_call.py b/test/dygraph_to_static/test_convert_call.py index 357f1d8d59266..5da4ad5ca398b 100644 --- a/test/dygraph_to_static/test_convert_call.py +++ b/test/dygraph_to_static/test_convert_call.py @@ -28,6 +28,7 @@ from paddle import base from paddle.jit.dy2static.convert_call_func import CONVERSION_OPTIONS from paddle.jit.dy2static.utils import func_to_source_code +from paddle.pir_utils import test_with_pir_api SEED = 2020 np.random.seed(SEED) diff --git a/test/legacy_test/test_layers.py b/test/legacy_test/test_layers.py index 8c8c52ed1abf2..88c6243862a21 100644 --- a/test/legacy_test/test_layers.py +++ b/test/legacy_test/test_layers.py @@ -34,6 +34,7 @@ rank_attention, shuffle_batch, ) +from paddle.pir_utils import test_with_pir_api from paddle.tensor import random @@ -275,6 +276,7 @@ def test_type(): self.assertRaises(TypeError, test_type) + @test_with_pir_api def test_SyncBatchNorm(self): if core.is_compiled_with_cuda(): with self.static_graph():