[Paddle-ASP] Support sharding training for NVIDIA's ASP (2:4 sparsity) functionality #37725
Conversation
sync repo
✅ This PR's description meets the template requirements!
Thanks for your contribution!
Please add UT
@@ -150,7 +155,8 @@ def prune_model(main_program=None,
                 n=2,
                 m=4,
                 mask_algo='mask_1d',
-                with_mask=True):
+                with_mask=True,
+                sharding=False):
Could we let users pass in a place directly? Wouldn't that be easier to understand?
It would be easy to understand, but we would need to give users extra documentation.
Besides, even with that documentation, there is still the risk that the place is frequently set incorrectly, which then leads to bugs.
My view is that keeping the place inside prune_model makes the code more robust.
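A hypothetical sketch of the design choice discussed above: when sharding=True, prune_model resolves the place from each rank's own environment instead of asking the caller to pass one in. The helper name _resolve_place and the environment-variable fallback are illustrative only, not the PR's actual implementation.

```python
import os
import paddle

def _resolve_place(sharding=False):
    # Illustrative helper, not the real ASP code path.
    if sharding:
        # Under sharding, each worker should prune on its own GPU; deriving the
        # device id from the launcher-provided environment avoids every rank
        # piling its pruning work onto one device (consistent with the GPU:0
        # OOM described in this PR).
        return paddle.CUDAPlace(int(os.environ.get("FLAGS_selected_gpus", "0")))
    # Non-sharding case: keep the previous default behavior.
    return paddle.CUDAPlace(0)
```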
Sorry to inform you that 13083b1's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Added tests for optimizer compatibility and modified the prune_model API.
LGTM.
Please update the PR title so that it is easier to search for and manage your own work later.
Done, thanks
LGTM for sharding
LG API
LGTM
PR types
Bug fixes
PR changes
Others
Describe
NVIDIA has implemented 2:4 sparsity (ASP) support in PaddlePaddle, including fleet distributed training. However, when training with the sharding strategy (PaddlePaddle's model-parallel paradigm), GPU:0 always runs out of memory while the other GPUs appear normal.
After this fix, developers should pass sharding=True when calling sparsity.prune_model() under the sharding strategy; otherwise the APIs behave exactly as before.
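A minimal usage sketch, under these assumptions: the paddle.static.sparsity module path, fleet collective mode, illustrative sharding_configs values, a placeholder build_model() standing in for the user's network, and one plausible ordering of fleet.distributed_optimizer and sparsity.decorate (check the ASP documentation for the recommended order).

```python
import os
import paddle
import paddle.distributed.fleet as fleet
from paddle.static import sparsity

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True  # enable the sharding (model-parallel) strategy
strategy.sharding_configs = {"sharding_degree": 2, "segment_broadcast_MB": 32}  # illustrative values

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    loss = build_model()  # placeholder network returning a scalar loss
    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer = sparsity.decorate(optimizer)  # wrap minimize() to insert 2:4 mask handling
    optimizer.minimize(loss, startup_prog)

place = paddle.CUDAPlace(int(os.environ.get("FLAGS_selected_gpus", "0")))
exe = paddle.static.Executor(place)
exe.run(startup_prog)

# New in this PR: sharding=True tells ASP that the program was built with the
# sharding strategy, so each rank prunes on its own device rather than GPU:0.
sparsity.prune_model(main_prog, mask_algo='mask_1d', with_mask=True, sharding=True)
```

Each process is expected to be started through paddle.distributed.launch (or fleetrun), which sets FLAGS_selected_gpus so that every rank binds to its own device.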