
[hybrid parallel] Optimize pipeline memory #34230

Merged

Conversation

@wangxicoding (Contributor) commented on Jul 19, 2021

PR types

Function optimization

PR changes

Others

Describe

Optimize the GPU memory usage of pipeline parallelism.
Before this change, memory usage grew with the global batch size; after it, memory no longer grows with gbs and stays constant.

Original PR: #34214. That PR had a bug: when the backward pass uses a forward send var (e.g., with recompute), the var could be garbage-collected early, so the tensor was null by the time the backward pass read it.
To fix this, this PR goes back to the nop_op from the original #34086 to hold the forward send var, letting GC release it immediately after its last use. Alternatively, a topological-dependency check could be added on the C++ side, but that is more engineering effort.
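To make the nop_op idea concrete, here is a minimal C++ sketch of the reference-counting intuition behind last-use garbage collection; the types and names (LastUseGC, pending_readers, free_tensor) are illustrative, not Paddle's actual GC API.

```cpp
// Minimal sketch (not Paddle's real GC) of why registering one more
// reader op extends a variable's lifetime under last-use collection.
#include <functional>
#include <string>
#include <unordered_map>

struct LastUseGC {
  // Number of ops that still need to read each variable.
  std::unordered_map<std::string, int> pending_readers;
  std::function<void(const std::string&)> free_tensor;

  // Called when an op that reads `var` finishes executing.
  void OnOpDone(const std::string& var) {
    if (--pending_readers[var] == 0) free_tensor(var);
  }
};

// Without a holder, the async send is the last reader of the forward
// send var, so GC may free the tensor while the transfer (or a
// recompute backward) still needs it. Adding a nop_op as one more
// reader means the var is freed only after the nop_op runs, i.e. at
// a point where all real consumers are done.
```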

The final memory-management scheme for Pipeline send variables:

  • Forward send vars: held by a nop_op and released automatically through the GC mechanism.
    See PR [hybrid performance] Optimize pipeline send wait #34086.
  • Backward send vars: released manually by the section_worker executor, based on the analyzed topological dependency (see the sketch after this list).
    The dependency is shown in the figure below: once the current stage completes the forward recv of an FB (forward-backward) round, the backward sends of the two preceding FB rounds must also have completed, so the backward send vars can be freed at that point. (TODO: for coding convenience, this PR performs the release after Forward finishes; it could be moved to right after the forward recv.)

[figure: topological dependency between forward recv and backward send across pipeline stages]
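Below is a minimal C++ sketch of the manual-release rule described above; MicroBatchScope, FreeVar, and ReleaseBackwardSendVars are hypothetical names, and the real logic lives in Paddle's section worker.

```cpp
// Sketch of the topology-based release of backward send vars
// (hypothetical names; not the actual SectionWorker code).
#include <string>
#include <vector>

struct MicroBatchScope {
  // Vars produced by this round's backward send_v2 / partial_send ops.
  std::vector<std::string> backward_send_vars;
};

void FreeVar(const std::string& name) {
  // Placeholder: in Paddle this would drop the tensor's memory holder.
  (void)name;
}

// Rule from the dependency analysis: once this stage finishes the
// forward recv of round `i`, the backward sends of the two preceding
// rounds are known to be complete, so round i-1's buffers are safe to
// release. (The PR itself releases after Forward finishes, per the
// TODO above.)
void ReleaseBackwardSendVars(std::vector<MicroBatchScope>& scopes, int i) {
  if (i < 1) return;
  auto& done = scopes[i - 1].backward_send_vars;
  for (const auto& name : done) FreeVar(name);
  done.clear();  // avoid double-freeing on later rounds
}
```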

Test

Single machine with 8 V100 32GB GPUs.
Model: gpt2-medium-en (345M parameters), pipeline_stage=8, micro_batch=4.

| gbs  | GPU id | develop (MB) | PR (MB)   | memory change (MB) |
|------|--------|--------------|-----------|--------------------|
| 32   | 0      | 24402        | 24406     | +4                 |
| 32   | 1      | 21376        | 21380     | +4                 |
| 32   | 7      | 7830         | 7834      | +4                 |
| 64   | 0      | 24660        | 24664     | +4                 |
| 64   | 1      | 21634        | 21380     | -254               |
| 64   | 7      | 7830         | 7834      | +4                 |
| 256  | 0      | 24660        | unchanged |                    |
| 256  | 1      | 22408        | unchanged | -1028              |
| 256  | 7      | 8168         | unchanged | -334               |
| 1024 | 0      | 24660        | unchanged |                    |
| 1024 | 1      | 25504        | unchanged | -4124              |
| 1024 | 7      | 11770        | unchanged | -3936              |
| 2048 | 0      | 24600        | unchanged |                    |
| 2048 | 1      | 29632        | unchanged | -8252              |
| 2048 | 7      | 15710        | unchanged | -7876              |
| 3072 | 0      | OOM          | unchanged |                    |
| 3072 | 1      | OOM          | unchanged |                    |
| 3072 | 7      | OOM          | unchanged |                    |

("unchanged": the PR-side memory stays at its gbs=64 value.)

Test results

The PR greatly reduces GPU memory usage: memory no longer grows with the global batch size, so in theory gbs can be increased without bound.

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@sneaxiy (Collaborator) left a comment:

Almost LGTM except for one suggestion.

```cpp
if (schedule_mode_ != 1) return;

for (auto &op : ops_) {
  if (!op->HasAttr("pipeline_send_var")) continue;
  // ... (reads the attribute value; removed after the review below)
}
```
@sneaxiy (Collaborator) commented on Jul 19, 2021:

Do we have another way to distinguish whether the input of an op (send_v2 or partial_send) is the variable to send? I mean, setting a variable name as an attribute is discouraged. I would prefer adding a bool attribute to indicate this case.

@wangxicoding (Contributor, Author) replied:

Done. Removed the pipeline_send_var attr; the backward send vars are now found directly.
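A rough sketch of what finding the backward send vars directly could look like: scan the program for backward-role send ops and collect their inputs, rather than tagging variable names as attributes. The Op struct here is simplified and hypothetical.

```cpp
// Simplified sketch: identify backward send vars by op type and role
// instead of a string attribute (Op fields are illustrative).
#include <string>
#include <vector>

struct Op {
  std::string type;                 // e.g. "send_v2", "partial_send"
  bool is_backward;                 // derived from the op-role attribute
  std::vector<std::string> inputs;  // names of the op's input variables
};

std::vector<std::string> CollectBackwardSendVars(const std::vector<Op>& ops) {
  std::vector<std::string> vars;
  for (const auto& op : ops) {
    if (!op.is_backward) continue;
    if (op.type != "send_v2" && op.type != "partial_send") continue;
    // A send op's inputs are exactly the variables being sent.
    vars.insert(vars.end(), op.inputs.begin(), op.inputs.end());
  }
  return vars;
}
```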

@sandyhouse previously approved these changes on Jul 19, 2021 and left a comment:

LGTM

@sandyhouse left a comment:

LGTM

@sneaxiy (Collaborator) left a comment:

LGTM now.

@wangxicoding wangxicoding merged commit a74208c into PaddlePaddle:develop Jul 20, 2021
@wangxicoding wangxicoding deleted the optimize_pipeline_memory1 branch July 20, 2021 01:00