Skip to content

Commit

Permalink
fix(new-ir): support SyncBatchNorm
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyewww committed Nov 17, 2023
1 parent 7161b06 commit 08a404a
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/paddle/nn/layer/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,7 +1623,7 @@ def forward(self, x):

# train mode: use mini-batch stats, eval mode: use global stats
# use_global_stats only supports False in sync_batch_norm
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
sync_batch_norm_out, _, _, _, _, _ = _C_ops.sync_batch_norm_(
x,
self._mean,
Expand Down
1 change: 1 addition & 0 deletions test/dygraph_to_static/test_convert_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from paddle import base
from paddle.jit.dy2static.convert_call_func import CONVERSION_OPTIONS
from paddle.jit.dy2static.utils import func_to_source_code
from paddle.pir_utils import test_with_pir_api

SEED = 2020
np.random.seed(SEED)
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
rank_attention,
shuffle_batch,
)
from paddle.pir_utils import test_with_pir_api
from paddle.tensor import random


Expand Down Expand Up @@ -275,6 +276,7 @@ def test_type():

self.assertRaises(TypeError, test_type)

@test_with_pir_api
def test_SyncBatchNorm(self):
if core.is_compiled_with_cuda():
with self.static_graph():
Expand Down

0 comments on commit 08a404a

Please sign in to comment.