[AMP]Master grad in static graph #53362

Merged 29 commits into PaddlePaddle:develop on May 18, 2023

Conversation

@shaojiewang (Contributor) commented on Apr 26, 2023

PR types

New features

PR changes

Others

Description

Pcard-70458
Enable master grad in static graph mode.

The background and functionality are the same as #52235; this PR implements the feature for the static graph.

Functionality and effect

When training in AMP O2 mode, bf16/fp16 gradients can fall below the smallest representable bf16/fp16 value or exceed the bf16/fp16 representable range. In static graph mode, the gradients are therefore cast to fp32 before check_finite_and_unscale, grad clip, regularization, and the optimizer update, to preserve training accuracy.

Usage

  1. master_grad is disabled by default; it takes effect only at O2 level and only when the user enables it explicitly.
  2. The user enables master_grad through the paddle.static.amp.decorate interface by passing master_grad=True. Once enabled, OptimizerWithMixedPrecision.apply_gradients creates master_grad tensors and inserts cast ops before _check_finite_and_unscale to convert the bf16/fp16 grads into fp32 master_grads; check_finite_and_unscale, grad clip, regularization, and the optimizer all then compute with the fp32 master grads. A minimal usage sketch follows below.
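
A minimal usage sketch of the configuration described above. The network is a placeholder and the exact paddle.static.amp.decorate signature may differ slightly between Paddle versions; treat this as an illustration, not the merged code.

```python
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 16], dtype="float32")
    hidden = paddle.static.nn.fc(x, size=16)
    loss = paddle.mean(hidden)

    optimizer = paddle.optimizer.AdamW(learning_rate=1e-3)
    # AMP O2 with master gradients: grads are cast to fp32 before
    # check_finite_and_unscale, grad clip, regularization and the optimizer.
    optimizer = paddle.static.amp.decorate(
        optimizer,
        level="O2",
        dtype="float16",
        master_grad=True,
    )
    optimizer.minimize(loss)
```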

Impact

When enabled, extra cast ops are inserted into the program, and the gradients argument of check_finite_and_unscale, grad clip, regularization, and the optimizer becomes fp32, so a single step runs somewhat slower.

@paddle-bot (bot) commented on Apr 26, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot (bot) commented on Apr 26, 2023

❌ This PR was not created using the PR template. You can refer to this Demo.
Please use the PR template; it saves our maintainers' time so that more developers can get help.

@paddle-ci-bot (bot) commented on May 4, 2023

Sorry to inform you that the CI runs for 6659cca passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.



@unittest.skipIf(
not core.supports_bfloat16(), "place does not support BF16 evaluation"
@ZzSean (Contributor) commented on May 8, 2023

This seems to check only whether the CPU place supports bf16. For GPU, the check should be core.is_compiled_with_cuda() combined with core.is_bfloat16_supported(core.CUDAPlace(0)).

@shaojiewang (Author) replied

Switched to the GPU-capability check.
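
A sketch of the GPU-aware skip condition being suggested. The test-class name is illustrative and the import path for core may differ from the actual test file.

```python
import unittest

from paddle.framework import core


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "place does not support BF16 evaluation",
)
class TestMasterGradStaticBF16(unittest.TestCase):
    ...
```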

@shaojiewang changed the title from Master grad in static graph to [AMP]Master grad in static graph on May 8, 2023
@shaojiewang requested a review from ZzSean on May 11, 2023
# master gradients
self._already_create_master_grad = set()
self._master_grads = {}
self._master_grad = False
Reviewer (Contributor) commented

Aren't these lines unnecessary? They seem to already be set up in the base class.

@shaojiewang (Author) replied

AdamW.__init__() does not call super().__init__(); is there a particular reason it was left out?

@@ -277,6 +395,6 @@ def run_program(
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}")
# print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}")
Reviewer (Contributor) commented

Should this comment be re-enabled?

@shaojiewang (Author) replied

This test was changed to compare whether the O1 and O2 loss results are equal, so can this print statement be removed?

@shaojiewang (Author) replied

I see other tests still use it, so it has been re-enabled.

@zhangting2020 previously approved these changes on May 12, 2023
# master gradients
self._already_create_master_grad = set()
self._master_grads = {}
self._master_grad = False
Reviewer (Contributor) commented

Please wrap this in a function, e.g. create_master_grad_states.

@shaojiewang (Author) replied

Done.
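
For context, a minimal sketch of what the requested helper might look like, simply wrapping the state shown in the diff above; the method actually merged in the PR may differ.

```python
def create_master_grad_states(self):
    # State for fp32 master gradients used by AMP O2: a map from each
    # low-precision grad name to its fp32 master grad variable, and a flag
    # recording whether master grad is enabled.
    self._already_create_master_grad = set()
    self._master_grads = {}
    self._master_grad = False
```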

if grad.name in self._master_grads:
var = self._master_grads[grad.name]
else:
var_name = grad.name + "_fp32_master"
Reviewer (Contributor) commented

Should grad's data type be checked here, or an assert added?

@shaojiewang (Author) replied

The grad dtype is already checked at the call site of this function; does it need to be checked again here?

@shaojiewang (Author) replied

Added an assert.
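
A sketch of the assert plus master-grad lookup being discussed, assuming grad is a static-graph Variable and reusing the "_fp32_master" naming from the diff above; the helper merged in the PR may differ in details.

```python
from paddle.framework import core


def _create_master_grad(self, grad):
    # Master grads only make sense for low-precision gradients.
    assert grad.dtype in (
        core.VarDesc.VarType.FP16,
        core.VarDesc.VarType.BF16,
    ), f"expected an fp16/bf16 gradient, got {grad.dtype} for {grad.name}"

    if grad.name in self._master_grads:
        var = self._master_grads[grad.name]
    else:
        var_name = grad.name + "_fp32_master"
        # Create an fp32 variable of the same shape to hold the master gradient.
        var = grad.block.create_var(
            name=var_name,
            shape=grad.shape,
            dtype="float32",
            persistable=grad.persistable,
        )
        self._master_grads[grad.name] = var
    return var
```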

Add ops to cast gradient to master gradient

Args:
param_grads(list(tuple(Tensor, Tensor))):
Reviewer (Contributor) commented

Even though documentation is not auto-generated for this function, the format of the parameter and functionality descriptions does not follow the usual conventions.

@shaojiewang (Author) replied

Updated; please check whether it is correct now.
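
A sketch of a more conventionally formatted docstring for this helper; the wording is illustrative, not the text that was merged.

```python
def _append_cast_to_master_grad_op(self, param_grads):
    """
    Create fp32 master gradients and insert cast ops from the fp16/bf16
    gradients to them.

    Args:
        param_grads (list[tuple(Variable, Variable)]): A list of
            (parameter, gradient) pairs produced by backward().

    Returns:
        list[tuple(Variable, Variable)]: The (parameter, master gradient)
            pairs consumed by check_finite_and_unscale, grad clip,
            regularization and the optimizer.
    """
```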

assert isinstance(target_block, framework.Block)
# create
for p, g in param_grads:
if g.name not in self._already_create_master_grad:
Reviewer (Contributor) commented

This could also be checked with if g.name not in self._master_grads.keys(); there is no need to keep a separate self._already_create_master_grad.

@shaojiewang (Author) replied

Switched to checking with if g.name not in self._master_grads.keys().

@@ -1170,9 +1246,10 @@ def apply_gradients(self, params_grads):

# 'optimizer(grad_clip)' or 'set_gradient_clip'
if self._grad_clip is not None:
# create master gradients
params_grads = self._append_cast_to_master_grad_op(params_grads)
Reviewer (Contributor) commented

If there is no _grad_clip, does master_grad still take effect? Are the params here master_weight, i.e. is the param used in the grad_clip computation the master_weight?

My understanding is that master_grad should not be used only inside grad_clip; everything after backward that needs the gradients should use master_grad.

@shaojiewang (Author) replied

Re "If there is no _grad_clip, does master_grad still take effect?": it would not take effect. The placement here is wrong; the creation should be moved outside the if self._grad_clip is not None check. Will fix.

Re "Are the params here master_weight?": params is not master_weight, and grad_clip does not use the params argument. Should this be changed to pass in (master_weight, master_grad) tuples?
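
A sketch of the corrected placement the author describes, moving the master-grad creation out of the grad-clip branch so every downstream consumer sees fp32 grads. Context lines are simplified from the diff above; the flag name self._master_grad follows the state shown earlier and may not match the merged code exactly.

```python
def apply_gradients(self, params_grads):
    # Create fp32 master gradients first, regardless of whether gradient
    # clipping is configured, so that check_finite_and_unscale, grad clip,
    # regularization and the optimizer all operate on fp32 grads.
    if self._master_grad:
        params_grads = self._append_cast_to_master_grad_op(params_grads)

    # 'optimizer(grad_clip)' or 'set_gradient_clip'
    if self._grad_clip is not None:
        params_grads = self._grad_clip(params_grads)
    ...
```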

@@ -791,6 +798,7 @@ def decorate(
use_dynamic_loss_scaling=None,
use_amp_guard=False,
use_promote=False,
use_master_grad=False,
Reviewer (Contributor) commented

Add it after line 792, name the parameter master_grad=False, and add the corresponding documentation for it.

@shaojiewang (Author) replied

Done.

@@ -42,14 +72,18 @@ def _build_optimizer(
beta2=0.836,
epsilon=1e-4,
weight_decay=0.01,
multi_precision=True,
Reviewer (Contributor) commented

Don't add a multi_precision argument here; decorate already supports setting master_weight, and O2 training sets it to True automatically.

@shaojiewang (Author) replied

Removed.

use_promote=use_promote,
master_weight=True,
init_loss_scaling=1,
Reviewer (Contributor) commented

There is no need to set init_loss_scaling either; bfloat16 training sets it to 1 automatically.

@shaojiewang (Author) replied

Removed.

f"The number of optimizers with multi_precison = True is expected to be {expected_num_mp}, but recieved {actual_num_mp}.",
)

def test_amp_fp16_o1(self):
Reviewer (Contributor) commented

If this unit test is meant to test the master_grad feature, the O1 check seems unnecessary?

@shaojiewang (Author) replied

Yes. Deleted.

amp_dtype,
amp_level,
amp_lists,
True,
Reviewer (Contributor) commented

A unit test with grad_clip set to False is also needed.

@shaojiewang (Author) replied

Added.
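
A sketch of what such a test case might look like. The helper self._run and its parameters are hypothetical stand-ins for the test utilities quoted elsewhere in this PR; the assertion mirrors the O1-vs-O2 loss comparison the author describes.

```python
def test_amp_o2_master_grad_without_grad_clip(self):
    # Master grad should also take effect when no grad clip is configured,
    # since the cast to fp32 happens before (and independently of) clipping.
    losses_o1 = self._run(level="O1", use_grad_clip=False)
    losses_o2 = self._run(level="O2", use_grad_clip=False, use_master_grad=True)
    self.assertEqual(losses_o1, losses_o2)
```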

…s unittest; 3.use a function to create master grad states
@shaojiewang requested a review from Xreki on May 16, 2023
@Xreki (Contributor) left a comment

LGTM. Please strengthen the PR description a bit: what the feature is, how it was implemented, and what effect it achieves.

return losses

dtype = "float16"
max_iters = 25
Reviewer (Contributor) commented

There is probably no need to run this many iterations.

@shaojiewang (Author) replied

This unit test checks two things: (1) the O1 loss equals the O2 loss when master grad is enabled, and (2) the O1 loss differs from the O2 loss when master grad is disabled. Both conditions are first satisfied at step 24, so max_iters is set to 25.

seed = 0
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
Reviewer (Contributor) commented

The seeds don't need to be set repeatedly, do they?

@shaojiewang (Author) replied

The intent is to make the two runs of the startup program produce identical results; writing it this way is the simplest.
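
A sketch of the reseeding pattern the author describes, so consecutive startup-program runs initialize parameters identically for the O1/O2 comparison; the helper name reset_seeds is illustrative.

```python
import random

import numpy as np
import paddle


def reset_seeds(seed=0):
    # Reset every RNG source so parameter initialization and any numpy/python
    # randomness match exactly across the two runs being compared.
    paddle.seed(seed)
    np.random.seed(seed)
    random.seed(seed)
```

Calling reset_seeds() immediately before each run of the program gives the O1 and O2 runs identical initial weights.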

@shaojiewang (Author) commented

Re "LGTM. Please strengthen the PR description a bit": the PR description has been updated.

@Xreki merged commit 972581d into PaddlePaddle:develop on May 18, 2023