[Dy2St][AMP] Support local enable the auto_cast #58977

Merged
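
No PR description is shown here, so the snippet below is a hypothetical usage sketch inferred from the tests in this diff: after this change, `paddle.amp.auto_cast` can be entered locally inside a `to_static`-decorated `forward`, and the enclosed ops record their own AMP options instead of hitting the old `NotImplementedError`. The layer, shapes, and input are illustrative only.

```python
import paddle


class Net(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self._fc = paddle.nn.Linear(10, 10)

    @paddle.jit.to_static(full_graph=True)
    def forward(self, x):
        # Ops created inside this block should carry AmpOptions with
        # enable=False; ops outside it keep the global AMP options.
        with paddle.amp.auto_cast(False):
            x = self._fc(x)
        return x + 1


net = Net()
out = net(paddle.randn([10, 10]))
```
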
python/paddle/base/framework.py (19 additions, 10 deletions)
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections
import copy
import functools
@@ -26,6 +28,7 @@
import warnings
from collections.abc import Iterable
from types import FunctionType, MethodType
from typing import TYPE_CHECKING

import numpy as np

@@ -39,6 +42,9 @@
from .variable_index import _getitem_static, _setitem_impl_, _setitem_static
from .wrapped_decorator import signature_safe_contextmanager, wrap_decorator

if TYPE_CHECKING:
from paddle.static.amp.fp16_utils import AmpOptions

__all__ = []

EMPTY_VAR_NAME = core.kEmptyVarName()
@@ -2935,8 +2941,11 @@ def __init__(
# attr for static graph mode cuda graph
self._cuda_graph_attr = _current_cuda_graph_mode

# attr for OP should cast in AMP mode
self._should_auto_cast: bool = True
# attr for OP AMP mode
# using dynamic import to avoid cyclic dependency
from paddle.static.amp.fp16_utils import DEFAULT_AMP_OPTIONS

self._amp_options: AmpOptions = DEFAULT_AMP_OPTIONS

op_maker = core.op_proto_and_checker_maker

@@ -3695,24 +3704,24 @@ def dist_attr(self, dist_attr):
"""
self.desc.dist_attr = dist_attr

def set_auto_cast(self, auto_cast):
def set_amp_options(self, amp_options):
"""
Set auto cast attribute of this Operator.

Args:
auto_cast(bool): True if this Operator should cast in AMP mode.
amp_options (AmpOptions): AmpOptions of this Operator.
"""
self._should_auto_cast = auto_cast
self._amp_options = amp_options

@property
def should_auto_cast(self):
def amp_options(self):
"""
Get auto cast attribute of this Operator.

Returns:
bool: True if this Operator should cast in AMP mode.
AmpOptions: AmpOptions of this Operator.
"""
return self._should_auto_cast
return self._amp_options


@signature_safe_contextmanager
@@ -6985,7 +6994,7 @@ def _copy_data_info_from(self, other, pruned_origin_block_id_map=None):
if other_var.stop_gradient:
var.stop_gradient = True

def _copy_operator_info_from(self, other: "Program"):
def _copy_operator_info_from(self, other: Program):
"""
Copy the information of Operator information from other program.

@@ -7001,7 +7010,7 @@ def _copy_operator_info_from(self, other: "Program"):
)
for dst_block, src_block in zip(self.blocks, other.blocks):
for dst_op, src_op in zip(dst_block.ops, src_block.ops):
dst_op.set_auto_cast(src_op.should_auto_cast)
dst_op.set_amp_options(src_op.amp_options)

def list_vars(self):
"""
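
A minimal sketch of the per-operator API introduced above, assuming a plain static-graph Program; the tiny network is illustrative and not part of this PR. It shows that each Operator now exposes a full AmpOptions value rather than a bare bool, and that a cloned Program is expected to carry the same per-op options via _copy_operator_info_from, which the test in this PR also checks.

```python
import paddle

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data("x", [10, 10], "float32")
    y = paddle.static.nn.fc(x, 10)

# Every Operator defaults to DEFAULT_AMP_OPTIONS until amp records are applied.
for op in main.global_block().ops:
    print(op.type, op.amp_options.enable, op.amp_options.level)

# Program.clone() copies the per-op options through _copy_operator_info_from.
cloned = main.clone()
assert all(
    dst.amp_options == src.amp_options
    for dst, src in zip(cloned.global_block().ops, main.global_block().ops)
)
```
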
python/paddle/jit/dy2static/convert_operators.py (5 additions, 2 deletions)
@@ -15,6 +15,7 @@
from __future__ import annotations

import re
import warnings
from contextlib import contextmanager

import paddle
@@ -823,8 +824,10 @@ def convert_auto_cast(
):
from .program_translator import ProgramTranslator

if enable:
raise NotImplementedError("Does not support local switching on amp now")
warnings.warn(
"paddle.amp.auto_cast is an experimental feature in auto parallel. "
"It will have no effect in normal dy2static."
)

amp_records = ProgramTranslator.get_instance()._amp_records
main_program = paddle.static.default_main_program()
python/paddle/static/amp/fp16_utils.py (21 additions, 8 deletions)
@@ -53,6 +53,16 @@ class AmpOptions:
use_promote: bool


DEFAULT_AMP_OPTIONS = AmpOptions(
enable=True,
custom_white_list=None,
custom_black_list=None,
level='O1',
dtype='float16',
use_promote=True,
)


def _rename_arg(op, old_name, new_name):
"""
If an op has old_name input and output, rename these input
@@ -609,28 +619,31 @@ def map_block(block, fn, parent_op=None):
map_block(sub_block, fn, op)


def prepare_op_should_auto_cast(
def prepare_op_amp_options(
program: paddle.static.Program,
amp_records: dict[int, list[tuple[AmpOptions, int, int]]],
global_amp_options: AmpOptions,
):
amp_enable_op_map: dict[paddle.static.Operator, bool] = {}
op_amp_options_map: dict[paddle.static.Operator, AmpOptions] = {}

def fill_amp_enable_op_map(block, parent_op):
block_idx = block.idx
ops = block.ops
for op in ops:
# The top level should be FP16
current_op_amp_options = amp_enable_op_map.get(parent_op, True)
# Set the default options to global_amp_options if the op has no parent op.
current_op_amp_options = op_amp_options_map.get(
parent_op, global_amp_options
)
if block_idx in amp_records:
for amp_options, start, end in amp_records[block_idx]:
if op.idx in range(start, end):
current_op_amp_options = amp_options.enable
current_op_amp_options = amp_options
break
amp_enable_op_map[op] = current_op_amp_options
op_amp_options_map[op] = current_op_amp_options

map_block(program.global_block(), fill_amp_enable_op_map)
for op, enable in amp_enable_op_map.items():
op.set_auto_cast(enable)
for op, amp_options in op_amp_options_map.items():
op.set_amp_options(amp_options)


def cast_model_to_fp16(
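
For context, a rough sketch of how prepare_op_amp_options could be driven, based only on the signature above; the amp_records values are made up for illustration (block 0, ops 2 through 4 wrapped by a local auto_cast(False)).

```python
import paddle
from paddle.static.amp.fp16_utils import (
    DEFAULT_AMP_OPTIONS,
    AmpOptions,
    prepare_op_amp_options,
)

paddle.enable_static()
program = paddle.static.default_main_program()

# One record per local auto_cast block: block index -> (options, start, end),
# where [start, end) is the op-index range covered by the block.
local_options = AmpOptions(
    enable=False,
    custom_white_list=None,
    custom_black_list=None,
    level='O1',
    dtype='float16',
    use_promote=True,
)
amp_records = {0: [(local_options, 2, 5)]}  # illustrative indices

prepare_op_amp_options(program, amp_records, DEFAULT_AMP_OPTIONS)

# Ops 2..4 of the global block are now disabled for AMP; everything else
# inherits the enclosing op's options, falling back to the global default
# at the top level.
for op in program.global_block().ops:
    print(op.idx, op.type, op.amp_options.enable)
```
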
test/dygraph_to_static/test_local_cast.py (68 additions, 7 deletions)
@@ -18,7 +18,21 @@

import paddle
from paddle.jit.dy2static.program_translator import ProgramTranslator
from paddle.static.amp.fp16_utils import prepare_op_should_auto_cast
from paddle.static.amp.fp16_utils import (
DEFAULT_AMP_OPTIONS,
AmpOptions,
prepare_op_amp_options,
)

GLOBAL_ENABLE_AMP_OPTIONS = DEFAULT_AMP_OPTIONS
GLOBAL_DISABLE_AMP_OPTIONS = AmpOptions(
enable=False,
custom_black_list=DEFAULT_AMP_OPTIONS.custom_black_list,
custom_white_list=DEFAULT_AMP_OPTIONS.custom_white_list,
level=DEFAULT_AMP_OPTIONS.level,
dtype=DEFAULT_AMP_OPTIONS.dtype,
use_promote=DEFAULT_AMP_OPTIONS.use_promote,
)


class LocalAutoCastLayer1(paddle.nn.Layer):
@@ -62,24 +76,47 @@ def forward(self, x):
return x + 1


class LocalAutoCastLayer3(paddle.nn.Layer):
def __init__(self):
super().__init__()
self._fc = paddle.nn.Linear(10, 10)

@paddle.jit.to_static(full_graph=True)
def forward(self, x):
with paddle.amp.auto_cast(True):
x = x.astype("float32")
x = self._fc(x)
y = self._fc(x) * 2
if x[0][0] > 1:
x = x + y
else:
x = x - y
x = x * 2

return x + 1


class TestLocalCast(unittest.TestCase):
def get_auto_cast_ops_info_from_program(self, program):
auto_cast_ops_info = []
for block in program.blocks:
current_block_should_auto_cast = []
auto_cast_ops_info.append(current_block_should_auto_cast)
for op in block.ops:
current_block_should_auto_cast.append(op.should_auto_cast)
current_block_should_auto_cast.append(op.amp_options.enable)
return auto_cast_ops_info

def should_auto_cast_for_each_ops(self, layer, input):
def should_auto_cast_for_each_ops(self, layer, input, global_amp_options):
concrete_program, _ = layer.forward.get_concrete_program(input)
program = concrete_program.main_program
prepare_op_should_auto_cast(
program, ProgramTranslator.get_instance()._amp_records
prepare_op_amp_options(
program,
ProgramTranslator.get_instance()._amp_records,
global_amp_options,
)
auto_cast_ops_info = self.get_auto_cast_ops_info_from_program(program)
paddle.enable_static()
# Ensure the cloned program has the same auto_cast ops info
cloned_program = program.clone()
paddle.disable_static()
cloned_auto_cast_ops_info = self.get_auto_cast_ops_info_from_program(
@@ -103,7 +140,9 @@ def test_should_auto_cast_1(self):
# All else branch in auto_cast(False) block
[False, False, False],
] # fmt: skip
actual = self.should_auto_cast_for_each_ops(layer, input)
actual = self.should_auto_cast_for_each_ops(
layer, input, GLOBAL_ENABLE_AMP_OPTIONS
)
self.assertEqual(expected, actual)

def test_should_auto_cast_2(self):
@@ -120,7 +159,29 @@ def test_should_auto_cast_2(self):
# All else branch out of auto_cast(False) block
[True, True, True],
] # fmt: skip
actual = self.should_auto_cast_for_each_ops(layer, input)
actual = self.should_auto_cast_for_each_ops(
layer, input, GLOBAL_ENABLE_AMP_OPTIONS
)
self.assertEqual(expected, actual)

def test_should_auto_cast_3(self):
layer = LocalAutoCastLayer3()
input = paddle.randn([10, 10])
expected = [
# There are part of ops in auto_cast(True) block
[
True, True, True, True, True, True,
False, False, False, False, False, False, False, False, False, False,
],
# All if branch out of auto_cast(True) block
[False, False],
# All else branch out of auto_cast(True) block
[False, False, False],
] # fmt: skip
actual = self.should_auto_cast_for_each_ops(
layer, input, GLOBAL_DISABLE_AMP_OPTIONS
)

self.assertEqual(expected, actual)

