DeepSpeed is slower than FSDP #5047
This issue turned out to have several cascading causes, but it was resolved by a combination of ZeRO-3 leaf modules and the stage3_param_persistence_threshold setting.
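For reference, a minimal sketch of what a ZeRO-3 configuration touching these knobs might look like; the values below are illustrative placeholders, not the settings actually used in this thread:

```python
# Illustrative ZeRO-3 config as a Python dict (values are placeholders).
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        # Parameters with fewer elements than this stay materialized
        # ("persistent") on every rank instead of being gathered on demand.
        "stage3_param_persistence_threshold": 100_000,
        # How far ahead (in elements) ZeRO-3 prefetches parameters.
        "stage3_prefetch_bucket_size": 50_000_000,
    },
}
```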
Thank you, @halilakin, for sharing your experience! We apologize for not responding earlier. I am curious about your performance evaluation, as it could inform future improvements, but the link to the FSDP results seems to be broken. Could you please update the link and let us know which model and configurations you used?
Thanks for offering help here @tohtana. I am back to this task (renamed it accordingly) and still trying to figure out why I can't match PyTorch FSDP's performance with ZeRO3. There are several issues, but here is the top one: ZeRO3 launches too many all_gathers.

Here is the forward pass of a 6-layer transformer model with FSDP, and here is the same forward with ZeRO3 (trace screenshots). You can see that ZeRO3 is slower, probably because it launches many small all_gathers instead of one all_gather per layer like FSDP does. At least, that was my theory. I tried to limit the number of all_gathers by setting a transformer layer as a leaf node, but it doesn't seem to have made things faster. I am also not sure this is the right solution, since leaf nodes are a very recent addition to ZeRO3. How did people solve this issue before?

Another issue is that I tried increasing allgather_bucket_size, but that doesn't seem to be doing anything. I also tried increasing prefetch_bucket_size, hoping it would prefetch layers ahead of time, but that also didn't help. The fact that every transformer layer takes equal time with ZeRO3 in the last picture tells me that we are making the CPU wait instead of letting it run ahead, but I am not sure what's causing that. cc: @stas00
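For anyone landing here later, recent DeepSpeed versions expose the leaf node feature through deepspeed.utils.set_z3_leaf_modules. A rough sketch of how it can be applied; the Block class here is a hypothetical stand-in for the 6-layer transformer mentioned above, not the actual model:

```python
import torch.nn as nn
from deepspeed.utils import set_z3_leaf_modules

class Block(nn.Module):
    """Hypothetical transformer layer standing in for the reporter's model."""
    def __init__(self, d=1024, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        return x + self.ffn(x + attn_out)

model = nn.Sequential(*[Block() for _ in range(6)])

# Mark each Block as a ZeRO-3 leaf: DeepSpeed then fetches the whole block's
# parameters together instead of hooking every Linear/attention submodule.
set_z3_leaf_modules(model, [Block])
```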
@halilakin Thank you for sharing your detailed investigation! We appreciate your kind help. It seems that synchronization is not working as expected. We had anticipated that the ZeRO-3 leaf module feature would enhance performance in certain scenarios, but it appears insufficient to address this inefficiency. We will work on resolving the issue and provide updates as we make progress.
Thanks for the quick response @tohtana. I've stopped using the leaf module feature for now, until I fully understand all the knobs and finish reading the code. I will update the thread with more information, but it's quite possible that I am not setting all the parameters correctly.
In my experience and setup, I found FSDP to be approximately 10% faster than ZeRO. This appears to be because FSDP consolidates the parameters of an entire layer into a single all_gather, rather than initiating separate all_gathers for components such as MHA and FFN as ZeRO does, which hurts communication throughput. I have attempted to reduce the number of all_gathers ZeRO launches using two methods: setting a transformer layer as a ZeRO-3 leaf node, and increasing prefetch_bucket_size.
I also tried adjusting the allgather_bucket_size parameter, but it doesn't seem to have any effect. If there are other ways to make ZeRO aggregate many nn.Modules and initiate a single all_gather for all their parameters, please let me know.
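For contrast, this is roughly how FSDP achieves the one-all_gather-per-layer behavior described above; a sketch reusing the hypothetical Block class from the earlier snippet and assuming torch.distributed is already initialized (e.g. via torchrun):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Each wrapped Block becomes one flattened parameter group, so the forward
# pass issues a single all_gather per transformer layer rather than one per
# Linear or attention submodule.
fsdp_model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy({Block}))
```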
Hi @halilakin,
Thanks @tohtana. Let me test the leaf module thoroughly in combination with other flags today and update the thread.
I was able to resolve the fragmented all_gathers by using the leaf module feature and setting it before DeepSpeed initialization, so that the hooks only attach to the transformer blocks. Previously, I was setting it after initialization, following the original example. I can attribute a significant portion of the remaining performance gap to the post-step all-gathering of persistent parameters here, as you can also see in the trace. What's the purpose of this expensive all_gather? I thought persistent parameters are not partitioned. What's the reason to all_gather them at every step?
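In code, the ordering that worked might look like the following sketch; model, Block, and ds_config are the hypothetical stand-ins from the earlier snippets, with the optimizer config omitted for brevity:

```python
import deepspeed
from deepspeed.utils import set_z3_leaf_modules

# Mark leaf modules *before* initialization so DeepSpeed attaches its gather
# hooks at the Block level only, instead of on every submodule.
set_z3_leaf_modules(model, [Block])

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```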
Hi @halilakin, I think the idle time between layers was caused by the overhead associated with handling the hooks that launch allgather/reduce_scatter, and some other operations. Unfortunately, I'm unable to view the details of the profiled timeline (for some reason, downloading or zooming in was not possible). If you could provide more details, I might have better insight into the backward pass. We are actively working to enhance the performance of ZeRO. Your feedback is valuable to our efforts.
I was able to get ZeRO to 95% of the MFU I got with FSDP for my network, primarily by using leaf nodes and tuning the ZeRO-3 parameters better. The remaining gap appears to be due to ZeRO partitioning everything in my model. There are parts of my model that shouldn't be partitioned at all, since they are compute-heavy but memory-light, but there seems to be no straightforward way to exclude parts of a network from partitioning with ZeRO.
Hi @halilakin, really nice chatting with you yesterday. Given that we are now clear about all the collective communication calls and performance is OK, I will close this issue for now. I have already noted the feature you requested. Thx a ton.
Thanks for the help @GuanhuaWang! |
Describe the bug
I am still familiarizing myself with DeepSpeed, so here is a n00b question. I wrapped my model with DeepSpeed and am seeing good ZeRO2 performance. However, when I switch to ZeRO3, the all_gathers are not overlapping with compute and are very fragmented, even though the default params look reasonable. How can I learn more about why the all_gathers are so fragmented, and how can I make them less granular?
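One common way to investigate this (not specific to this thread): capture a training step with the PyTorch profiler and count the all_gather kernels per layer in the exported timeline. The engine and batch names below are assumed to come from a standard DeepSpeed training loop:

```python
from torch.profiler import profile, ProfilerActivity

# Trace one training step; open the resulting timeline in chrome://tracing
# or Perfetto and look for c10d::all_gather / ncclAllGather events.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    loss = engine(batch).mean()  # 'engine' and 'batch' are assumed to exist
    engine.backward(loss)
    engine.step()

prof.export_chrome_trace("zero3_step.json")
```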