Skip to content

Commit

Permalink
【PIR API adaptor No.222】Migrate paddle.nn.SyncBatchNorm into pir (#59077)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyewww authored Nov 23, 2023
1 parent 91b829c commit e1f3e75
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/paddle/nn/layer/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,7 +1623,7 @@ def forward(self, x):

# train mode: use mini-batch stats, eval mode: use global stats
# use_global_stats only support False in sync_batch_norm
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
sync_batch_norm_out, _, _, _, _, _ = _C_ops.sync_batch_norm_(
x,
self._mean,
Expand Down
2 changes: 1 addition & 1 deletion test/dygraph_to_static/test_convert_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def test_class_api(self):
bn = paddle.nn.SyncBatchNorm(2)
paddle.jit.to_static(bn)
self.assertNotIn("_jst.IfElse", bn.forward.code)
self.assertIn("if in_dynamic_mode()", bn.forward.code)
self.assertIn("if in_dynamic_or_pir_mode()", bn.forward.code)

@test_ast_only
@test_legacy_and_pir_api
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
rank_attention,
shuffle_batch,
)
from paddle.pir_utils import test_with_pir_api
from paddle.tensor import random


Expand Down Expand Up @@ -275,6 +276,7 @@ def test_type():

self.assertRaises(TypeError, test_type)

@test_with_pir_api
def test_SyncBatchNorm(self):
if core.is_compiled_with_cuda():
with self.static_graph():
Expand Down
13 changes: 10 additions & 3 deletions test/legacy_test/test_sync_batch_norm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@

import paddle
from paddle import base, nn
from paddle.base import Program, core, program_guard
from paddle.base import core
from paddle.base.framework import in_dygraph_mode
from paddle.pir_utils import test_with_pir_api

_set_use_system_allocator(True)

Expand Down Expand Up @@ -364,7 +365,9 @@ def test_errors(self):
return

cleanup = enable_static()
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
my_sync_batch_norm = paddle.nn.SyncBatchNorm(10)
x1 = base.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], base.CUDAPlace(0)
Expand All @@ -382,11 +385,14 @@ def test_errors(self):


class TestConvertSyncBatchNorm(unittest.TestCase):
@test_with_pir_api
def test_convert(self):
if not core.is_compiled_with_cuda():
return

with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
compare_model = paddle.nn.Sequential(
paddle.nn.Conv2D(3, 5, 3),
paddle.nn.BatchNorm2D(5),
Expand All @@ -410,6 +416,7 @@ def test_convert(self):


class TestConvertSyncBatchNormCast1(unittest.TestCase):
@test_with_pir_api
def test_convert(self):
if not core.is_compiled_with_cuda():
return
Expand Down

0 comments on commit e1f3e75

Please sign in to comment.