Skip to content

Commit

Permalink
fix(new-ir): support SyncBatchNorm
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyewww committed Nov 16, 2023
1 parent 8022b63 commit 1d92238
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/paddle/nn/layer/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,7 +1623,7 @@ def forward(self, x):

# train mode: use mini-batch stats, eval mode: use global stats
# use_global_stats only support False in sync_batch_norm
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
sync_batch_norm_out, _, _, _, _, _ = _C_ops.sync_batch_norm_(
x,
self._mean,
Expand Down
3 changes: 3 additions & 0 deletions test/dygraph_to_static/test_convert_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from paddle import base
from paddle.jit.dy2static.convert_call_func import CONVERSION_OPTIONS
from paddle.jit.dy2static.utils import func_to_source_code
from paddle.pir_utils import test_with_pir_api

SEED = 2020
np.random.seed(SEED)
Expand Down Expand Up @@ -287,13 +288,15 @@ def test_functional_api(self):
self.assertIn("if in_dynamic_or_pir_mode()", func.code)

@test_ast_only
@test_with_pir_api
def test_class_api(self):
bn = paddle.nn.SyncBatchNorm(2)
paddle.jit.to_static(bn)
self.assertNotIn("_jst.IfElse", bn.forward.code)
self.assertIn("if in_dynamic_mode()", bn.forward.code)

@test_ast_only
@test_with_pir_api
def test_class_patch_api(self):
paddle.nn.SyncBatchNorm.forward = forward
bn = paddle.nn.SyncBatchNorm(2)
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
rank_attention,
shuffle_batch,
)
from paddle.pir_utils import test_with_pir_api
from paddle.tensor import random


Expand Down Expand Up @@ -275,6 +276,7 @@ def test_type():

self.assertRaises(TypeError, test_type)

@test_with_pir_api
def test_SyncBatchNorm(self):
if core.is_compiled_with_cuda():
with self.static_graph():
Expand Down

0 comments on commit 1d92238

Please sign in to comment.