Commit

support excluded_layers for amp.decorate (PaddlePaddle#52871)
zhangting2020 authored and jjyaoao committed Apr 19, 2023
1 parent c9d87e2 commit e7bbc04
Showing 3 changed files with 263 additions and 38 deletions.
132 changes: 95 additions & 37 deletions python/paddle/amp/auto_cast.py
@@ -199,47 +199,95 @@ def _is_gpu_bfloat16_supported():
return prop[0] >= 8 and cuda_version_check


def need_keep_fp32(layer, dtype):
need_keep_fp32 = False
# Highest priority. Because all the layers except BN will use bfloat16 params in bfloat16 training,
# here we provide an option to keep fp32 params.
if not layer._cast_to_low_precison:
need_keep_fp32 = True
# The BN layers will keep fp32
elif isinstance(
layer,
(
paddle.nn.BatchNorm,
paddle.nn.BatchNorm1D,
paddle.nn.BatchNorm2D,
paddle.nn.BatchNorm3D,
paddle.nn.SyncBatchNorm,
),
):
need_keep_fp32 = True
# layer._dtype is used to set params dtype. BF16 will use bf16 params.
elif (layer._dtype == 'float16') or (
(dtype == 'float16')
and isinstance(
layer,
(
paddle.nn.LayerNorm,
paddle.nn.InstanceNorm1D,
paddle.nn.InstanceNorm2D,
paddle.nn.InstanceNorm3D,
),
)
):
need_keep_fp32 = True

return need_keep_fp32


def set_excluded_layers(models, excluded_layers):
excluded_layers_instances = []
excluded_layers_types = []
error_message = "excluded_layers must be either a nn.Layer instance/type or a list of nn.Layer instances/types."
if excluded_layers is None:
excluded_layers = []
elif isinstance(excluded_layers, paddle.nn.Layer):
excluded_layers_instances = [excluded_layers]
elif isinstance(excluded_layers, type) and issubclass(
excluded_layers, paddle.nn.Layer
):
excluded_layers_types = [excluded_layers]
elif isinstance(excluded_layers, list):
for item in excluded_layers:
if isinstance(item, paddle.nn.Layer):
excluded_layers_instances.append(item)
elif issubclass(item, paddle.nn.Layer):
excluded_layers_types.append(item)
else:
raise TypeError(error_message)
else:
raise TypeError(error_message)

for idx in range(len(excluded_layers_instances)):
for layer in excluded_layers_instances[idx].sublayers(
include_self=True
):
layer._cast_to_low_precison = False
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
if type(layer) in excluded_layers_types:
layer._cast_to_low_precison = False


@dygraph_only
def pure_fp16_initialize(models):
def amp_initialize(models, dtype, excluded_layers):
set_excluded_layers(models, excluded_layers)
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
layer._casted_by_pure_fp16 = True
if (layer._dtype == 'float16') or isinstance(
layer,
(
paddle.nn.BatchNorm,
paddle.nn.BatchNorm1D,
paddle.nn.BatchNorm2D,
paddle.nn.BatchNorm3D,
paddle.nn.LayerNorm,
paddle.nn.SyncBatchNorm,
paddle.nn.InstanceNorm1D,
paddle.nn.InstanceNorm2D,
paddle.nn.InstanceNorm3D,
),
):
if need_keep_fp32(layer, dtype):
continue
if isinstance(
if dtype == "float16" and isinstance(
layer,
(
paddle.incubate.nn.FusedFeedForward,
paddle.incubate.nn.FusedMultiHeadAttention,
),
):
layer._amp_decorate(dtype='float16')
layer._amp_decorate(dtype=dtype)
continue
layer._to_impl(
dtype='float16', include_sublayers=False, floating_only=True
)
return models


@dygraph_only
def pure_bf16_initialize(models):
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
layer._to_impl(
dtype='bfloat16', include_sublayers=False, floating_only=True
dtype=dtype, include_sublayers=False, floating_only=True
)
return models

@@ -522,6 +570,7 @@ def amp_decorate(
master_weight=None,
save_dtype=None,
master_grad=False,
excluded_layers=None,
):
"""
Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), decorate does nothing.
Expand Down Expand Up @@ -590,6 +639,8 @@ def amp_decorate(
raise ValueError(
"level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode."
)
if dtype not in ['float16', 'bfloat16']:
raise ValueError("dtype only supports float16 or bfloat16.")

if level == 'O1':
if optimizers is None:
@@ -609,12 +660,9 @@
raise TypeError(
"models must be either a single model or a list of models."
)
if dtype == 'float16':
models = pure_fp16_initialize(models=models)
elif dtype == 'bfloat16':
models = pure_bf16_initialize(models=models)
else:
raise TypeError("dtype only support float16 or bfloat16.")

# initialize parameters of the model.
amp_initialize(models=models, dtype=dtype, excluded_layers=excluded_layers)

if optimizers is not None:
# check optimizers
@@ -741,6 +789,7 @@ def decorate(
master_weight=None,
save_dtype=None,
master_grad=False,
excluded_layers=None,
):
"""
Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), decorate does nothing.
@@ -757,8 +806,10 @@
master_weight(bool, optional): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, the optimizer will use multi-precision in O2 level. Default is None.
save_dtype(str, optional): The dtype used to save model parameters with `paddle.save` or `paddle.jit.save`; it should be float16, bfloat16, float32, float64 or None.
The save_dtype does not change the dtype of model parameters, it only changes the state_dict dtype. When save_dtype is None, the save dtype is the same as the model dtype. Default is None.
master_grad(bool, optional): For level='O2', whether to use FP32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If it is enabled, the weight
gradients will be FP32 dtype after the backpropagation. Default is False.
master_grad(bool, optional): For level='O2', whether to use float32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If master_grad is enabled, the weight
gradients will be float32 dtype after backpropagation. Default is False, in which case the weight gradients are float16.
excluded_layers(Layer|list of Layer, optional): Specify the layers that should not be decorated. The weights of these layers always keep float32 when level is O2. `excluded_layers` can be specified as
a Layer instance/type or a list of Layer instances/types. Default is None, in which case the weights of the whole model are cast to float16 or bfloat16.
Examples:
@@ -808,5 +859,12 @@ def decorate(
print(output.dtype) # FP16
"""
return amp_decorate(
models, optimizers, level, dtype, master_weight, save_dtype, master_grad
models,
optimizers,
level,
dtype,
master_weight,
save_dtype,
master_grad,
excluded_layers,
)
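
For reference, a minimal usage sketch of the new `excluded_layers` argument added in this commit. The toy model and optimizer here are illustrative, and the snippet assumes it runs on a device with float16 support:

```python
import paddle

# Toy model: keep the Conv2D weights in float32, cast the rest to float16.
model = paddle.nn.Sequential(
    paddle.nn.Conv2D(4, 8, 3, padding=1),
    paddle.nn.BatchNorm2D(8),
    paddle.nn.Linear(8, 8),
)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=model.parameters()
)

model, optimizer = paddle.amp.decorate(
    models=model,
    optimizers=optimizer,
    level='O2',
    dtype='float16',
    excluded_layers=[paddle.nn.Conv2D],  # a type; instances also work, e.g. model[0]
)

print(model[0].weight.dtype)  # paddle.float32 (excluded from casting)
print(model[2].weight.dtype)  # paddle.float16 (decorated)
```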
3 changes: 2 additions & 1 deletion python/paddle/nn/layer/layers.py
@@ -401,7 +401,8 @@ def __init__(self, name_scope=None, dtype="float32"):
self._forward_pre_hooks = collections.OrderedDict()
self._forward_post_hooks = collections.OrderedDict()

self._casted_by_pure_fp16 = False
# only used in AMP Training
self._cast_to_low_precison = True

self._state_dict_hooks = collections.OrderedDict()
# Records orignal functions after @to_static to support to rollback
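To illustrate the per-layer flag introduced here, a small sketch (layer names are illustrative; assumes float16 support on the current device): `paddle.amp.decorate` clears `_cast_to_low_precison` on excluded sublayers via `set_excluded_layers`, and `amp_initialize` then skips those layers when casting weights.

```python
import paddle

linear_a = paddle.nn.Linear(4, 4)
linear_b = paddle.nn.Linear(4, 4)
model = paddle.nn.Sequential(linear_a, linear_b)

model = paddle.amp.decorate(
    models=model, level='O2', dtype='float16', excluded_layers=linear_a
)

print(linear_a._cast_to_low_precison)  # False: excluded, weights stay float32
print(linear_b._cast_to_low_precison)  # True: weights cast to float16
```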
166 changes: 166 additions & 0 deletions test/amp/test_amp_decorate.py
@@ -0,0 +1,166 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
import paddle.nn.functional as F


class ConvBNLayer(paddle.nn.Layer):
def __init__(
self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
):
super().__init__()

self._conv = paddle.nn.Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
bias_attr=None,
)

self._batch_norm = paddle.nn.BatchNorm(num_filters, act=act)

def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)

return y


class Model(paddle.nn.Layer):
def __init__(
self, input_channel, hidden_size, fp16_conv=True, fp16_linear=True
):
super().__init__()
self.conv = ConvBNLayer(input_channel, 8, 3)
self.linear = paddle.nn.Linear(8, hidden_size)
self.layernorm = paddle.nn.Sequential(
paddle.nn.LayerNorm(hidden_size),
paddle.nn.LayerNorm(hidden_size),
)
self.fp16_conv = fp16_conv
self.fp16_linear = fp16_linear

def forward(self, inputs):
with paddle.amp.auto_cast(enable=self.fp16_conv):
if not self.fp16_conv:
inputs = inputs.astype('float32')
x = self.conv(inputs)
with paddle.amp.auto_cast(enable=self.fp16_linear):
if not self.fp16_linear:
x = x.astype('float32')
x = self.linear(x)
x = F.relu(x)
x = self.layernorm(x)
return x


class TestAMPDecorate(unittest.TestCase):
def check_results(self, fp32_layers=[], fp16_layers=[]):
for idx in range(len(fp32_layers)):
for layer in fp32_layers[idx].sublayers(include_self=False):
self.assertEqual(layer.weight.dtype, paddle.float32)
self.assertEqual(layer.bias.dtype, paddle.float32)

for idx in range(len(fp16_layers)):
for layer in fp16_layers[idx].sublayers(include_self=False):
self.assertEqual(layer.weight.dtype, paddle.float16)
self.assertEqual(layer.bias.dtype, paddle.float16)

def test_excluded_layers(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8, fp16_conv=False)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=model.conv,
)
with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32'))
self.check_results(
fp32_layers=[model.conv, model.layernorm],
fp16_layers=[model.linear],
)

def test_excluded_layers_attr_list(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8, fp16_conv=False, fp16_linear=False)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=[model.conv, model.linear],
)

with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32'))

self.check_results(
fp32_layers=[model.conv, model.linear, model.layernorm]
)

def test_excluded_layers_attr_types(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=[paddle.nn.Conv2D, model.linear],
)

with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16'))

self.check_results(
fp32_layers=[model.conv, model.linear, model.layernorm]
)

def test_excluded_layers_attr_none(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=None,
)

with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16'))

self.check_results(
fp32_layers=[model.layernorm, model.conv._batch_norm],
fp16_layers=[model.conv._conv, model.linear],
)


if __name__ == '__main__':
unittest.main()
