
[AMP] support master_grad for amp training #52235

Merged
merged 5 commits on Apr 10, 2023

Conversation

@zhangting2020 (Contributor) commented Mar 28, 2023

PR types

New features

PR changes

APIs

Describe

support master_grad for amp training

Background: when training in AMP O2 mode, some models overflow during computation because gradients become extremely large or extremely small. This PR adds a master_grad mechanism as one way to handle this scenario.

Key implementation points (a minimal usage sketch follows this list)
(1) master_grad is disabled by default; it only takes effect at the O2 level and when the user explicitly sets it to True.

(2) The user sets master_grad through the amp.decorate API. The framework also records the parameters of the decorated models and saves them in AMP's global state AMPGlobalState.

(3) In auto_cast, a master_grad_hook is registered into backward_final_hook.

  • backward_final_hook performs some extra processing at the very end of the whole backward pass, so master_grad_hook is triggered at that stage.
  • master_grad_hook iterates over the parameters of the decorated models recorded above, replaces the low-precision parameter gradients with FP32 ones, and releases the original FP16 gradients. When the hook finishes, the already_register_backward_final_hook flag is reset to False, because all hooks are cleared once backward_final_hook completes; to trigger master_grad_hook again in the next iteration, it must be re-registered every iteration.

(4) As described in the decorate API documentation: once enabled, the weight gradients obtained after backward finishes will be of FP32 type. This guarantees that when users access param.grad in their own custom computations they get this FP32 gradient, which preserves training accuracy.
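
A minimal dygraph usage sketch based on the points above (the Linear model, SGD optimizer and loop body are illustrative placeholders; decorate, auto_cast and GradScaler are existing paddle.amp APIs, and master_grad is the flag added by this PR):

import paddle

model = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

# master_grad only takes effect together with level='O2' (point (1) above).
model, opt = paddle.amp.decorate(
    models=model, optimizers=opt, level='O2', master_grad=True
)
scaler = paddle.amp.GradScaler()

for _ in range(2):
    x = paddle.rand([2, 4])
    # auto_cast registers the master_grad hook into backward_final_hook
    # every iteration (points (2) and (3) above).
    with paddle.amp.auto_cast(level='O2'):
        loss = model(x).mean()
    scaler.scale(loss).backward()
    # By the time backward returns, the hook has replaced the FP16 grads
    # with FP32 ones (point (4) above).
    print(model.weight.grad.dtype)  # expected: paddle.float32
    scaler.step(opt)
    scaler.update()
    opt.clear_grad()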

Other considerations
(1) Why not handle master_grad inside the backward API?
Regular FP32 training does not need this step, so to keep unrelated code out of the regular training path, the functionality is kept inside the AMP module.

(2) Does master_grad affect performance?
When it is not enabled there is no performance impact. When it is enabled, casting all parameter gradients introduces extra overhead, and subsequent operations such as GradClip and WeightDecay also take some additional time because they now compute with FP32 gradients.


@paddle-bot (bot) commented Mar 28, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@zhangting2020 zhangting2020 force-pushed the master_grad branch 2 times, most recently from de0ea78 to 3e8c205 on March 28, 2023 11:48
@zhangting2020 zhangting2020 changed the title support master_grad for amp training [AMP] support master_grad for amp training Mar 30, 2023
auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0);
for (auto& tensor : tensor_list) {
VLOG(6) << "set master_grad for tensor: " << tensor.name();
PADDLE_ENFORCE(

Contributor:
PADDLE_ENFORCE_EQ

Contributor Author:
done

egr::egr_utils_api::IsLeafTensor(tensor),
paddle::platform::errors::Fatal("Only leaf Tensor can be set grad."));
paddle::Tensor* grad = egr::EagerUtils::mutable_grad(tensor);
PADDLE_ENFORCE(grad != nullptr,

Contributor:
PADDLE_ENFORCE_NE

Contributor Author:
done

PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0);

Contributor:
What is this tensor_list?

Contributor Author:
A comment has been added.

"The data type of source tensor and destination tensor "
"should be equal, Otherwise, the calculation results "
"will be incorrect."));

Contributor:
Please add a comment here noting that gradient accumulation across different data types is supported.

Contributor Author:
Added at line 200 below.

# master_grad_hook will run at the end of backward.
# Since backward_final_hook will be cleared once they have been
# done, we should register the hook every step.
if (

Contributor:
I have a question about this implementation: is the FP16 grad used in each training iteration newly allocated? Between the FP16 grad and the master_grad, which one's GPU memory stays alive? And which grad does clear_grad clear?

Contributor Author:
(1) In normal training, the gradient produced by each iteration's backward pass is also newly allocated. Gradient accumulation, for example, needs both the current gradient and the previous one, so there is never only a single gradient buffer. Below is the detailed unit-test log of a regular training run, i.e. master_grad=False and no gradient accumulation.

  • Iteration 1: grad_y is linear.w.grad, Ptr: 0x7fed21d1a370
I0406 05:21:41.665776 34683 nodes.cc:25587] Finish AD API GRAD: matmul_grad
I0406 05:21:41.665818 34683 nodes.cc:25610] { Input: [ 
( grad_out , [{Name: None, Initialized: 1, Ptr: 0x7fed21d197e0 TensorInfo: [ Type: DenseTensor, Dtype: float16, Place: Place(gpu:0), Shape: 2, 4 ], ADInfo:[ Grad: [ {Name: None, Initialized: 0, Ptr: 0 TensorInfo: [ Unknown ], ADInfo:[ None ]} ],  GradNode: [ None ], StopGradient: [ 0 ] ]}]),  
( x , [{Name: None, Initialized: 1, Ptr: 0x7fed12549cd0 TensorInfo: [ Type: DenseTensor, Dtype: float16, Place: Place(gpu:0), Shape: 2, 2 ], ADInfo:[ Grad: [ {Name: None, Initialized: 0, Ptr: 0 TensorInfo: [ Unknown ], ADInfo:[ None ]} ],  GradNode: [ None ], StopGradient: [ 1 ] ]}]),  
( y , [{Name: None, Initialized: 1, Ptr: 0x7fed0f5cbca0 TensorInfo: [ Type: DenseTensor, Dtype: float16, Place: Place(gpu:0), Shape: 2, 4 ], ADInfo:[ Grad: [ {Name: None, Initialized: 0, Ptr: 0 TensorInfo: [ Unknown ], ADInfo:[ None ]} ],  GradNode: [ BackwardOutMeta: [  {SlotSize: [1]: SlotID: 0, StopGradients: 0, , Edges[ { NULL Edge } ]}  ], BackwardInMeta: [  {SlotSize: [SlotID: 0, StopGradients: 0, , Edges[ { NULL Edge } ]]:  ] ], StopGradient: [ 0 ] ]}]), ],  
 Output: [ 
 ( grad_x , [{Name: None, Initialized: 0, Ptr: 0 TensorInfo: [ Unknown ], ADInfo:[ None ]}]),  
 ( grad_y , [{Name: None, Initialized: 1, Ptr: 0x7fed21d1a370 TensorInfo: [ Type: DenseTensor, Dtype: float16, Place: Place(gpu:0), Shape: 2, 4 ], ADInfo:[ Grad: [ {Name: None, Initialized: 0, Ptr: 0 TensorInfo: [ Unknown ], ADInfo:[ None ]} ],  GradNode: [ None ], StopGradient: [ 0 ] ]}]), ] } 
  • Iteration 2: grad_y is linear.w.grad, Ptr: 0x7fed16a79380
I0406 05:21:41.684028 34683 nodes.cc:25587] Finish AD API GRAD: matmul_grad
I0406 05:21:41.684104 34683 nodes.cc:25610] { Input: [ 
( grad_out , [{Name: None, Initialized: 1, Ptr: 0x7fed1252b6c0 TensorInfo: [ Type: DenseTensor, Dtype: float16, Place: Place(gpu:0), Shape: 2, 4 ], ADInfo:[ Grad: [ {Name: None, Initialized: 0, Ptr: 0 TensorInfo: [ Unknown ], ADInfo:[ None ]} ],  GradNode: [ None ], StopGradient: [ 0 ] ]}]),  
( x , [{Name: None, Initialized: 1, Ptr: 0x7fed21d36fe0 TensorInfo: [ Type: DenseTensor, Dtype: float16, Place: Place(gpu:0), Shape: 2, 2 ], ADInfo:[ Grad: [ {Name: None, Initialized: 0, Ptr: 0 TensorInfo: [ Unknown ], ADInfo:[ None ]} ],  GradNode: [ None ], StopGradient: [ 1 ] ]}]),  
( y , [{Name: None, Initialized: 1, Ptr: 0x7fed0f5cbca0 TensorInfo: [ Type: DenseTensor, Dtype: float16, Place: Place(gpu:0), Shape: 2, 4 ], ADInfo:[ Grad: [ {Name: None, Initialized: 1, Ptr: 0x7fed21d1a370 TensorInfo: [ Type: DenseTensor, Dtype: float16, Place: Place(gpu:0), Shape: 2, 4 ], ADInfo:[ Grad: [ {Name: None, Initialized: 0, Ptr: 0 TensorInfo: [ Unknown ], ADInfo:[ None ]} ],  GradNode: [ None ], StopGradient: [ 0 ] ]} ],  GradNode: [ BackwardOutMeta: [  {SlotSize: [1]: SlotID: 0, StopGradients: 0, , Edges[ { NULL Edge } ]}  ], BackwardInMeta: [  {SlotSize: [SlotID: 0, StopGradients: 0, , Edges[ { NULL Edge } ]]:  ] ], StopGradient: [ 0 ] ]}]), ],  
 Output: [ 
 ( grad_x , [{Name: None, Initialized: 0, Ptr: 0 TensorInfo: [ Unknown ], ADInfo:[ None ]}]),  
 ( grad_y , [{Name: None, Initialized: 1, Ptr: 0x7fed16a79380 TensorInfo: [ Type: DenseTensor, Dtype: float16, Place: Place(gpu:0), Shape: 2, 4 ], ADInfo:[ Grad: [ {Name: None, Initialized: 0, Ptr: 0 TensorInfo: [ Unknown ], ADInfo:[ None ]} ],  GradNode: [ None ], StopGradient: [ 0 ] ]}]), ] } 

(2) The FP16 grad no longer needs to be kept after set_master_grad; from then on param.grad holds only a single FP32 gradient. Therefore the param.grad that clear_grad clears is also the FP32 gradient.
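
A quick illustrative check of this behaviour (a sketch, not part of the PR's unit tests; the Linear/SGD setup is arbitrary, with shapes mirroring the unit-test log above):

import paddle

model = paddle.nn.Linear(2, 4)
opt = paddle.optimizer.SGD(parameters=model.parameters())
model, opt = paddle.amp.decorate(
    models=model, optimizers=opt, level='O2', master_grad=True
)

with paddle.amp.auto_cast(level='O2'):
    loss = model(paddle.rand([2, 2])).mean()
loss.backward()
# The backward-final hook has already replaced the FP16 gradient with FP32;
# param.grad is the only gradient kept, so clear_grad() clears that FP32 copy.
print(model.weight.grad.dtype)  # expected: paddle.float32
opt.clear_grad()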

# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

Contributor:
Is there an existing unit-test file for eager AMP? Suggestion: either put this into the existing dygraph AMP unit tests, or create a dedicated unit test for master_grad, i.e. rename the test file to test_amp_master_grad.py.

Contributor Author:
The directory and file name have been updated.

if (i + 1) % accumulate_batchs_num == 0:
scaler.step(opt)
scaler.update()
opt.clear_grad()

Contributor:
You could also use the newly added operator statistics interface to check whether the types and counts of the executed operators match expectations.

Contributor Author:
done
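
For reference, a hedged sketch of that suggestion: it assumes the operator-stats helpers live under paddle.amp.debugging (exact helper names may differ across Paddle versions), and the model/optimizer setup is illustrative.

import paddle

model = paddle.nn.Linear(2, 4)
opt = paddle.optimizer.SGD(parameters=model.parameters())
model, opt = paddle.amp.decorate(
    models=model, optimizers=opt, level='O2', master_grad=True
)
scaler = paddle.amp.GradScaler()

# Collect operator statistics for one training step.
paddle.amp.debugging.enable_operator_stats_collection()
with paddle.amp.auto_cast(level='O2'):
    loss = model(paddle.rand([2, 2])).mean()
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
opt.clear_grad()
# Disabling collection prints the per-dtype operator counts, which a test can
# compare against the expected numbers of FP16 ops and master-grad casts.
paddle.amp.debugging.disable_operator_stats_collection()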

@zhangting2020 zhangting2020 force-pushed the master_grad branch 3 times, most recently from d2794a3 to 0e1973c on April 6, 2023 04:47

@Xreki (Contributor) left a comment:
LGTM. Great work ~

@@ -717,6 +759,8 @@ def decorate(
master_weight(bool, optional): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
master_grad(bool, optional): For level='O2', whether to use FP32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If it is enabled, the weight

Contributor:
Please reconsider the wording "If it is enabled"; it would be better to state it more directly, e.g. "If master_grad is xxx".

@sunzhongkai588 (Contributor) left a comment:
LGTM, the details can be fixed in a follow-up.

@zhangting2020 zhangting2020 merged commit 4970dd6 into PaddlePaddle:develop Apr 10, 2023