diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py
index 12272a5f6b2d4..1bb9b263cf063 100644
--- a/python/paddle/base/framework.py
+++ b/python/paddle/base/framework.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
 import collections
 import copy
 import functools
@@ -26,6 +28,7 @@
 import warnings
 from collections.abc import Iterable
 from types import FunctionType, MethodType
+from typing import TYPE_CHECKING

 import numpy as np

@@ -39,6 +42,9 @@
 from .variable_index import _getitem_static, _setitem_impl_, _setitem_static
 from .wrapped_decorator import signature_safe_contextmanager, wrap_decorator

+if TYPE_CHECKING:
+    from paddle.static.amp.fp16_utils import AmpOptions
+
 __all__ = []

 EMPTY_VAR_NAME = core.kEmptyVarName()
@@ -2935,8 +2941,11 @@ def __init__(
         # attr for static graph mode cuda graph
         self._cuda_graph_attr = _current_cuda_graph_mode

-        # attr for OP should cast in AMP mode
-        self._should_auto_cast: bool = True
+        # attr for OP AMP mode
+        # use a dynamic import to avoid a cyclic dependency
+        from paddle.static.amp.fp16_utils import DEFAULT_AMP_OPTIONS
+
+        self._amp_options: AmpOptions = DEFAULT_AMP_OPTIONS

         op_maker = core.op_proto_and_checker_maker

@@ -3695,24 +3704,24 @@ def dist_attr(self, dist_attr):
         """
         self.desc.dist_attr = dist_attr

-    def set_auto_cast(self, auto_cast):
+    def set_amp_options(self, amp_options):
         """
         Set auto cast attribute of this Operator.

         Args:
-            auto_cast(bool): True if this Operator should cast in AMP mode.
+            amp_options (AmpOptions): The AmpOptions of this Operator.
         """
-        self._should_auto_cast = auto_cast
+        self._amp_options = amp_options

     @property
-    def should_auto_cast(self):
+    def amp_options(self):
         """
         Get auto cast attribute of this Operator.

         Returns:
-            bool: True if this Operator should cast in AMP mode.
+            AmpOptions: The AmpOptions of this Operator.
         """
-        return self._should_auto_cast
+        return self._amp_options


 @signature_safe_contextmanager
@@ -6985,7 +6994,7 @@ def _copy_data_info_from(self, other, pruned_origin_block_id_map=None):
                 if other_var.stop_gradient:
                     var.stop_gradient = True

-    def _copy_operator_info_from(self, other: "Program"):
+    def _copy_operator_info_from(self, other: Program):
         """
         Copy the information of Operator information from other program.

@@ -7001,7 +7010,7 @@ def _copy_operator_info_from(self, other: "Program"):
             )
         for dst_block, src_block in zip(self.blocks, other.blocks):
             for dst_op, src_op in zip(dst_block.ops, src_block.ops):
-                dst_op.set_auto_cast(src_op.should_auto_cast)
+                dst_op.set_amp_options(src_op.amp_options)

     def list_vars(self):
         """
diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py
index 86f863d02089f..d98c9c81df714 100644
--- a/python/paddle/jit/dy2static/convert_operators.py
+++ b/python/paddle/jit/dy2static/convert_operators.py
@@ -15,6 +15,7 @@
 from __future__ import annotations

 import re
+import warnings
 from contextlib import contextmanager

 import paddle
@@ -823,8 +824,10 @@ def convert_auto_cast(
 ):
     from .program_translator import ProgramTranslator

-    if enable:
-        raise NotImplementedError("Does not support local switching on amp now")
+    warnings.warn(
+        "paddle.amp.auto_cast is an experimental feature in auto parallel. "
+        "This will not take effect in normal dy2static."
+    )

     amp_records = ProgramTranslator.get_instance()._amp_records
     main_program = paddle.static.default_main_program()
diff --git a/python/paddle/static/amp/fp16_utils.py b/python/paddle/static/amp/fp16_utils.py
index a0deec7314398..c81be584de6c5 100644
--- a/python/paddle/static/amp/fp16_utils.py
+++ b/python/paddle/static/amp/fp16_utils.py
@@ -53,6 +53,16 @@ class AmpOptions:
     use_promote: bool


+DEFAULT_AMP_OPTIONS = AmpOptions(
+    enable=True,
+    custom_white_list=None,
+    custom_black_list=None,
+    level='O1',
+    dtype='float16',
+    use_promote=True,
+)
+
+
 def _rename_arg(op, old_name, new_name):
     """
     If an op has old_name input and output, rename these input
@@ -609,28 +619,31 @@ def map_block(block, fn, parent_op=None):
         map_block(sub_block, fn, op)


-def prepare_op_should_auto_cast(
+def prepare_op_amp_options(
     program: paddle.static.Program,
     amp_records: dict[int, list[tuple[AmpOptions, int, int]]],
+    global_amp_options: AmpOptions,
 ):
-    amp_enable_op_map: dict[paddle.static.Operator, bool] = {}
+    op_amp_options_map: dict[paddle.static.Operator, AmpOptions] = {}

     def fill_amp_enable_op_map(block, parent_op):
         block_idx = block.idx
         ops = block.ops
         for op in ops:
-            # The top level should be FP16
-            current_op_amp_options = amp_enable_op_map.get(parent_op, True)
+            # Fall back to global_amp_options if the op has no parent op.
+            current_op_amp_options = op_amp_options_map.get(
+                parent_op, global_amp_options
+            )
             if block_idx in amp_records:
                 for amp_options, start, end in amp_records[block_idx]:
                     if op.idx in range(start, end):
-                        current_op_amp_options = amp_options.enable
+                        current_op_amp_options = amp_options
                         break
-            amp_enable_op_map[op] = current_op_amp_options
+            op_amp_options_map[op] = current_op_amp_options

     map_block(program.global_block(), fill_amp_enable_op_map)
-    for op, enable in amp_enable_op_map.items():
-        op.set_auto_cast(enable)
+    for op, amp_options in op_amp_options_map.items():
+        op.set_amp_options(amp_options)


 def cast_model_to_fp16(
diff --git a/test/dygraph_to_static/test_local_cast.py b/test/dygraph_to_static/test_local_cast.py
index 5b388d36af0d0..1fc774d8b7b40 100644
--- a/test/dygraph_to_static/test_local_cast.py
+++ b/test/dygraph_to_static/test_local_cast.py
@@ -18,7 +18,21 @@

 import paddle
 from paddle.jit.dy2static.program_translator import ProgramTranslator
-from paddle.static.amp.fp16_utils import prepare_op_should_auto_cast
+from paddle.static.amp.fp16_utils import (
+    DEFAULT_AMP_OPTIONS,
+    AmpOptions,
+    prepare_op_amp_options,
+)
+
+GLOBAL_ENABLE_AMP_OPTIONS = DEFAULT_AMP_OPTIONS
+GLOBAL_DISABLE_AMP_OPTIONS = AmpOptions(
+    enable=False,
+    custom_black_list=DEFAULT_AMP_OPTIONS.custom_black_list,
+    custom_white_list=DEFAULT_AMP_OPTIONS.custom_white_list,
+    level=DEFAULT_AMP_OPTIONS.level,
+    dtype=DEFAULT_AMP_OPTIONS.dtype,
+    use_promote=DEFAULT_AMP_OPTIONS.use_promote,
+)


 class LocalAutoCastLayer1(paddle.nn.Layer):
@@ -62,6 +76,26 @@ def forward(self, x):
         return x + 1


+class LocalAutoCastLayer3(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self._fc = paddle.nn.Linear(10, 10)
+
+    @paddle.jit.to_static(full_graph=True)
+    def forward(self, x):
+        with paddle.amp.auto_cast(True):
+            x = x.astype("float32")
+            x = self._fc(x)
+            y = self._fc(x) * 2
+        if x[0][0] > 1:
+            x = x + y
+        else:
+            x = x - y
+            x = x * 2
+
+        return x + 1
+
+
 class TestLocalCast(unittest.TestCase):
     def get_auto_cast_ops_info_from_program(self, program):
         auto_cast_ops_info = []
@@ -69,17 +103,20 @@ def get_auto_cast_ops_info_from_program(self, program):
             current_block_should_auto_cast = []
             auto_cast_ops_info.append(current_block_should_auto_cast)
             for op in block.ops:
-                current_block_should_auto_cast.append(op.should_auto_cast)
+                current_block_should_auto_cast.append(op.amp_options.enable)
         return auto_cast_ops_info

-    def should_auto_cast_for_each_ops(self, layer, input):
+    def should_auto_cast_for_each_ops(self, layer, input, global_amp_options):
         concrete_program, _ = layer.forward.get_concrete_program(input)
         program = concrete_program.main_program
-        prepare_op_should_auto_cast(
-            program, ProgramTranslator.get_instance()._amp_records
+        prepare_op_amp_options(
+            program,
+            ProgramTranslator.get_instance()._amp_records,
+            global_amp_options,
         )
         auto_cast_ops_info = self.get_auto_cast_ops_info_from_program(program)
         paddle.enable_static()
+        # Ensure the cloned program has the same auto_cast ops info
         cloned_program = program.clone()
         paddle.disable_static()
         cloned_auto_cast_ops_info = self.get_auto_cast_ops_info_from_program(
@@ -103,7 +140,9 @@ def test_should_auto_cast_1(self):
             # All else branch in auto_cast(False) block
             [False, False, False],
         ]  # fmt: skip
-        actual = self.should_auto_cast_for_each_ops(layer, input)
+        actual = self.should_auto_cast_for_each_ops(
+            layer, input, GLOBAL_ENABLE_AMP_OPTIONS
+        )
         self.assertEqual(expected, actual)

     def test_should_auto_cast_2(self):
@@ -120,7 +159,29 @@ def test_should_auto_cast_2(self):
             # All else branch out of auto_cast(False) block
             [True, True, True],
         ]  # fmt: skip
-        actual = self.should_auto_cast_for_each_ops(layer, input)
+        actual = self.should_auto_cast_for_each_ops(
+            layer, input, GLOBAL_ENABLE_AMP_OPTIONS
+        )
+        self.assertEqual(expected, actual)
+
+    def test_should_auto_cast_3(self):
+        layer = LocalAutoCastLayer3()
+        input = paddle.randn([10, 10])
+        expected = [
+            # Only part of the ops are in the auto_cast(True) block
+            [
+                True, True, True, True, True, True,
+                False, False, False, False, False, False, False, False, False, False,
+            ],
+            # All if branch out of auto_cast(True) block
+            [False, False],
+            # All else branch out of auto_cast(True) block
+            [False, False, False],
+        ]  # fmt: skip
+        actual = self.should_auto_cast_for_each_ops(
+            layer, input, GLOBAL_DISABLE_AMP_OPTIONS
+        )
+        self.assertEqual(expected, actual)
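
Usage sketch (not part of the patch): a minimal illustration of how the per-op AmpOptions introduced above can be inspected after dy2static tracing. DemoLayer is hypothetical; prepare_op_amp_options, DEFAULT_AMP_OPTIONS, and Operator.amp_options are the APIs added in this diff, and ProgramTranslator._amp_records is the store that the convert_auto_cast hook writes to.

import paddle
from paddle.jit.dy2static.program_translator import ProgramTranslator
from paddle.static.amp.fp16_utils import (
    DEFAULT_AMP_OPTIONS,
    prepare_op_amp_options,
)


class DemoLayer(paddle.nn.Layer):  # illustrative layer, not from the patch
    def __init__(self):
        super().__init__()
        self._fc = paddle.nn.Linear(10, 10)

    @paddle.jit.to_static(full_graph=True)
    def forward(self, x):
        with paddle.amp.auto_cast(False):  # recorded into _amp_records
            x = self._fc(x)
        return x + 1


layer = DemoLayer()
concrete_program, _ = layer.forward.get_concrete_program(paddle.randn([10, 10]))
program = concrete_program.main_program

# Attach an AmpOptions to every op: ops inside the local auto_cast region pick up
# the recorded options, everything else falls back to the global options.
prepare_op_amp_options(
    program,
    ProgramTranslator.get_instance()._amp_records,
    DEFAULT_AMP_OPTIONS,
)

for op in program.global_block().ops:
    print(op.type, op.amp_options.enable)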