[AMP] support master_grad for amp training #52235
@@ -45,6 +45,7 @@ typedef SSIZE_T ssize_t;
 #include "paddle/fluid/pybind/exception.h"
 #include "paddle/fluid/pybind/tensor_py.h"
 #include "paddle/phi/api/ext/op_meta_info.h"
 #include "paddle/phi/api/include/api.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/compat/convert_utils.h"
@@ -1183,6 +1184,37 @@ static PyObject* eager_api__add_backward_final_hook(PyObject* self,
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* eager_api_set_master_grads(PyObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  // tensor_list is a list of model parameters.
  auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0);
Review comment (on the line above): This one.
Reply: A comment has been added.

  for (auto& tensor : tensor_list) {
    VLOG(6) << "set master_grad for tensor: " << tensor.name();
    PADDLE_ENFORCE_EQ(
        egr::egr_utils_api::IsLeafTensor(tensor),
        true,
        paddle::platform::errors::Fatal("Only leaf Tensor can be set grad."));
    paddle::Tensor* grad = egr::EagerUtils::mutable_grad(tensor);
    PADDLE_ENFORCE_NE(grad,
                      nullptr,
                      paddle::platform::errors::Fatal(
                          "Detected NULL grad"
                          "Please check if you have manually cleared"
                          "the grad inside autograd_meta"));
    auto dtype = (*grad).dtype();
    if ((*grad).initialized() &&
        (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::BFLOAT16)) {
      auto master_grad =
          paddle::experimental::cast(*grad, phi::DataType::FLOAT32);
      grad->set_impl(master_grad.impl());
    }
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

PyMethodDef variable_functions[] = {
    // TODO(jiabin): Remove scale when we have final state tests
    {"scale",

@@ -1251,6 +1283,11 @@ PyMethodDef variable_functions[] = {
     (PyCFunction)(void (*)(void))eager_api_reset_saved_tensors_hooks,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    /**amp functions**/
    {"set_master_grads",
     (PyCFunction)(void (*)(void))eager_api_set_master_grads,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    /**sparse functions**/
#if defined(PADDLE_WITH_CUDA)
    {"async_read",
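For readers who prefer Python, here is a rough, illustrative sketch of what the new `set_master_grads` binding does to each leaf parameter after backward. The helper name `master_grads_sketch` is made up for this note; the real binding also enforces the leaf-tensor and initialized checks shown above and mutates each gradient's underlying impl in place rather than returning copies.

```python
import paddle


def master_grads_sketch(parameters):
    """Illustrative only: produce an FP32 version of each low-precision grad."""
    low_precision = (paddle.float16, paddle.bfloat16)
    master_grads = {}
    for param in parameters:
        grad = param.grad
        if grad is None:
            # Mirrors the PADDLE_ENFORCE_NE(grad, nullptr, ...) check above.
            raise RuntimeError(
                "Detected NULL grad. Please check if you have manually "
                "cleared the grad inside autograd_meta."
            )
        if grad.dtype in low_precision:
            # The C++ binding casts the grad and swaps its impl in place;
            # this sketch only materializes the FP32 copy.
            master_grads[param.name] = paddle.cast(grad, paddle.float32)
        else:
            master_grads[param.name] = grad
    return master_grads
```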
@@ -101,6 +101,23 @@ def amp_state():
    return _g_amp_state_


class AMPGlobalState:
    def __init__(self):
        self.model_parameters = []
        self.use_master_grad = False
        self.already_register_final_backward_hook = False

    def __setattr__(self, name, val):
        self.__dict__[name] = val


_amp_global_state = AMPGlobalState()


def amp_global_state():
    return _amp_global_state


# NOTE(zhiqiu): similar as paddle.static.amp.fp16_lists.AutoMixedPrecisionLists._update_list
# The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode.
def _update_list(
@@ -423,6 +440,21 @@ def amp_guard(
        amp_level = AMP_LEVEL.O0
        amp_dtype = "float32"

    # master_grad_hook will run at the end of backward.
    # Since backward_final_hook will be cleared once they have been
    # done, we should register the hook every step.
    if (
        amp_global_state().use_master_grad
        and not amp_global_state().already_register_final_backward_hook
    ):

        def master_grad_hook():
            core.eager.set_master_grads(amp_global_state().model_parameters)
            amp_global_state().already_register_final_backward_hook = False

        core.eager._add_backward_final_hook(master_grad_hook)
        amp_global_state().already_register_final_backward_hook = True

    if tracer:
        # enable auto_cast
        original_amp_level = tracer._amp_level

Review comment (on the `if (` line above): I have a question about this implementation: is the FP16 grad used in each training iteration re-allocated every time? Whose memory stays resident, the FP16 grad or the master_grad?
Reply: (1) In regular training the gradient produced by each iteration's backward pass is newly allocated anyway. Features such as gradient accumulation need both the current gradient and the previous one, so there is never only a single gradient buffer. The following is the detailed unit-test log of a regular training run, i.e. master_grad=False with no gradient accumulation.
(2) The FP16 grad no longer needs to be kept after set_master_grad; afterwards param.grad holds only one FP32 gradient. Therefore the param.grad cleared by clear_grad is also the FP32 gradient.
@@ -491,6 +523,7 @@ def amp_decorate(
    dtype='float16',
    master_weight=None,
    save_dtype=None,
    master_grad=False,
):
    """
    Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.

@@ -604,6 +637,14 @@ def amp_decorate(
    for opt in optimizers:
        _set_multi_precision(opt, use_multi_precision)

    # support master_grad
    if master_grad:
        amp_global_state().use_master_grad = True
        for idx in range(len(models)):
            amp_global_state().model_parameters.extend(
                models[idx].parameters()
            )

    if save_dtype is not None:
        if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']):
            raise ValueError(
@@ -701,6 +742,7 @@ def decorate(
    dtype='float16',
    master_weight=None,
    save_dtype=None,
    master_grad=False,
):
    """
    Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.

@@ -717,6 +759,8 @@ def decorate(
        master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
        save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
            The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
        master_grad(bool, optional): For level='O2', whether to use FP32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If it is enabled, the weight
            gradients will be FP32 dtype after the backpropagation. Default is False.

    Examples:

@@ -766,5 +810,5 @@ def decorate(
            print(output.dtype) # FP16
    """
    return amp_decorate(
        models, optimizers, level, dtype, master_weight, save_dtype
        models, optimizers, level, dtype, master_weight, save_dtype, master_grad
    )
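Putting the Python-side pieces together, here is a minimal usage sketch of the new `master_grad` switch. This is an assumption-laden example rather than part of the PR: it needs a CUDA device with FP16 support, and the layer size and optimizer choice are arbitrary.

```python
import paddle

model = paddle.nn.Linear(2, 4)
opt = paddle.optimizer.AdamW(parameters=model.parameters())
# master_grad=True makes amp_guard register the backward-final hook that
# casts FP16/BF16 parameter gradients to FP32 once backward has finished.
model, opt = paddle.amp.decorate(
    model, optimizers=opt, level='O2', master_grad=True
)
scaler = paddle.amp.GradScaler()

with paddle.amp.auto_cast(level='O2'):
    loss = model(paddle.rand([2, 2])).mean()
scaler.scale(loss).backward()
print(model.weight.grad.dtype)  # expected: paddle.float32 instead of float16
scaler.step(opt)
scaler.update()
opt.clear_grad()
```

As the review discussion above notes, once the hook has run, `param.grad` holds a single FP32 tensor, so `opt.clear_grad()` clears that FP32 gradient.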
@@ -0,0 +1,96 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.fluid import core


class SimpleNet(paddle.nn.Layer):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = paddle.nn.Linear(input_size, output_size)

    def forward(self, x):
        x = self.linear(x)
        return x


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_float16_supported(core.CUDAPlace(0)),
    "core is not complied with CUDA and not support the float16",
)
class TestMasterGrad(unittest.TestCase):
    def check_results(
        self, fp32_grads, op_list, total_steps, accumulate_batchs_num
    ):
        for grad in fp32_grads:
            self.assertEqual(grad.dtype, paddle.float32)
        # fp16 calls
        self.assertEqual(int(op_list['matmul_v2'].split(',')[0]), total_steps)
        self.assertEqual(
            int(op_list['adamw_'].split(',')[0]),
            2 * (total_steps / accumulate_batchs_num),
        )
        self.assertEqual(
            int(op_list['transfer_dtype'].split(',')[0]),
            total_steps + total_steps * 2,
        )

    def run_dygraph(self, total_steps, accumulate_batchs_num):
        model = SimpleNet(2, 4)
        opt = paddle.optimizer.AdamW(parameters=model.parameters())
        model, opt = paddle.amp.decorate(
            model, optimizers=opt, level='O2', master_grad=True
        )
        scaler = paddle.amp.GradScaler()

        paddle.amp.debugging.enable_operator_stats_collection()
        for i in range(total_steps):
            x = np.random.random((2, 2)).astype('float32')
            label = np.random.random((2, 4)).astype('float32')

            with paddle.amp.auto_cast(level='O2'):
                out = model(paddle.to_tensor(x))
                loss = paddle.nn.functional.l1_loss(
                    out, paddle.to_tensor(label)
                )
            scaled = scaler.scale(loss)
            scaled.backward()
            fp32_grads = [model.linear.weight.grad, model.linear.bias.grad]
            if (i + 1) % accumulate_batchs_num == 0:
                scaler.step(opt)
                scaler.update()
                opt.clear_grad()
        paddle.amp.debugging.disable_operator_stats_collection()
        op_list = paddle.fluid.core.get_low_precision_op_list()
        return fp32_grads, op_list

    def test_master_grad(self):
        total_steps = 4
        accumulate_batchs_num = 2
        fp32_grads, op_list = self.run_dygraph(
            total_steps, accumulate_batchs_num
        )
        self.check_results(
            fp32_grads, op_list, total_steps, accumulate_batchs_num
        )


if __name__ == '__main__':
    unittest.main()
Review comment: Please add a comment explaining that gradient accumulation across different data types is supported here.
Reply: Added at line 200 below.