Skip to content

Commit

Permalink
[Prim] support amp O1 in prim (#52598)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer authored Apr 9, 2023
1 parent b60f48c commit 58d5af0
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 2 deletions.
2 changes: 2 additions & 0 deletions python/paddle/jit/dy2static/partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ def _create_amp_program(self, is_infer_mode=False):
amp_program, self._amp_list
)
if is_infer_mode:
if self._hooker:
amp_program = self._hooker.after_infer(amp_program)
return amp_program
else:
train_amp_program = self._append_backward_desc(amp_program)
Expand Down
3 changes: 1 addition & 2 deletions python/paddle/jit/dy2static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import warnings
import weakref

from paddle.amp.auto_cast import _in_amp_guard
from paddle.fluid import _non_static_mode, core, framework
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.base import (
Expand Down Expand Up @@ -1228,7 +1227,7 @@ def _build_once(self, cache_key):
partial_program = partial_program_from(
concrete_program, cache_key.class_instance is not None
)
if core._is_fwd_prim_enabled() and not _in_amp_guard():
if core._is_fwd_prim_enabled():
partial_program.set_hooker(
PrimHooker(concrete_program.main_program)
)
Expand Down
104 changes: 104 additions & 0 deletions test/prim/process/test_prim_amp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.fluid import core, framework
from paddle.nn import BatchNorm

np.random.seed(2023)


class PrimeNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
self.bn = BatchNorm(4, act="relu")

def forward(self, x):
y = self.conv(x)
out = self.bn(y)
res = F.max_pool2d(out, kernel_size=2, stride=2, padding=0)
return res


class TestPrimAMPO1(unittest.TestCase):
"""
Test PrimeNet with @to_static + prim v.s Dygraph in AMPO1.
"""

def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
self.x.stop_gradient = False

def train(self, use_prim):
core._set_prim_all_enabled(use_prim)
paddle.seed(2022)
net = PrimeNet()
sgd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=net.parameters()
)

if use_prim:
net = paddle.jit.to_static(net, build_strategy=False)
with paddle.amp.auto_cast(level='O1'):
out = net(self.x)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_grad()
return loss

def test_amp_01(self):
if not isinstance(framework._current_expected_place(), core.CPUPlace):
expected = self.train(False)
actual = self.train(True)
np.testing.assert_allclose(
expected,
actual,
rtol=1e-3,
atol=1e-3,
)

def test_amp_O1_infer(self):
if not isinstance(framework._current_expected_place(), core.CPUPlace):
net = PrimeNet()
core._set_prim_all_enabled(False)
net.eval()
static_net = paddle.jit.to_static(net, build_strategy=False)
res = static_net(self.x)

# set prim all enabled
core._set_prim_all_enabled(True)
net.eval()
static_net = paddle.jit.to_static(net, build_strategy=False)
with paddle.amp.auto_cast(level='O1'):
res_amp = static_net(self.x)

np.testing.assert_allclose(
res,
res_amp,
rtol=1e-3,
atol=1e-3,
)


if __name__ == '__main__':
unittest.main()

0 comments on commit 58d5af0

Please sign in to comment.