Training multiple models #7018
base: master
Conversation
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
…/zero_multi_models
I am missing context on what this PR is doing (other than that it tries to do something with respect to training multiple models).
But don't you need new tests?
@stas00, thanks for the feedback. I have updated the OP with some background from earlier discussions. I will work on converting the gist code into unit tests.
That's much better after your OP expansion, Tunji. The gists look good; please ping me once they are tests, and I'd be happy to review again. That's a very important feature for quite a few users. Thank you for working on it.
Support training multiple models, such as in HF
Here is an update on supporting multiple DS engines with a single loss.backward(). The main message is that I think we can support this. First, some context. The backward pass in ZeRO is complicated because the optimizations/features require special handling of gradients, such as partitioning gradients across data-parallel ranks, overlapping gradient reduction with the backward computation, and loss scaling for mixed precision.
So, we created engine.backward(loss) as a wrapper function to give us fine-grained control over backward, as sketched below.
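A minimal sketch of the wrapper shape (illustrative only; the real implementation carries more state for the gradient-handling features above):

```python
# Illustrative sketch of the engine.backward(loss) wrapper. The names
# backward_prologue/backward_epilogue follow the discussion above; the
# bodies are placeholders, not the actual DeepSpeed implementation.
def backward(self, loss):
    self.backward_prologue()   # prepare special gradient handling (e.g., scaling state)
    loss.backward()            # run the normal autograd pass
    self.backward_epilogue()   # finalize gradient handling (e.g., reduction/partitioning)
```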
As demonstrated by @muellerzr, this approach breaks down when the loss originates from multiple DS engines. Our proposed solution is to use backward hooks on the module to launch backward_prologue() and backward_epilogue(). Specifically, (1) backward_epilogue() is launched by backward hooks that fire once gradient computation completes, and (2) backward_prologue() is launched by a pre-backward hook that fires before gradients flow into the module. A sketch of this wiring follows below.
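One way such hooks could be wired up with standard PyTorch module hooks (an assumed mechanism for illustration; the actual PR may attach hooks differently, e.g., per parameter):

```python
# Assumed helper for illustration: attach prologue/epilogue to a module's
# backward pass via standard PyTorch full-backward hooks.
def attach_backward_hooks(engine):
    def pre_backward(module, grad_output):
        # Fires before gradients propagate into the module's backward graph.
        engine.backward_prologue()

    def post_backward(module, grad_input, grad_output):
        # Fires once gradients w.r.t. the module's inputs have been computed.
        engine.backward_epilogue()

    engine.module.register_full_backward_pre_hook(pre_backward)  # PyTorch >= 2.0
    engine.module.register_full_backward_hook(post_backward)
```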
We plan for this solution to preserve backward compatibility, i.e., engine.backward() will remain correct for single-engine scenarios.
The current status is that (1) is completed, while (2) is in progress. To unblock e2e testing of multi-engine scenarios (since there are probably other issues), we have temporarily added engine._backward_prologue(). You can try this out via the following artifacts.
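For reference, a hypothetical multi-engine step using the temporary API (the engine names, model wiring, and the exact engine._backward_prologue() signature are assumptions for illustration):

```python
# Hypothetical multi-engine step: a single loss spanning two DeepSpeed engines.
hidden = engine_a(batch)                 # forward through the first engine
loss = loss_fn(engine_b(hidden), labels)

# Manually run each engine's backward prologue while the hook-based
# prologue (item (2) above) is in progress; the loss argument is a guess.
engine_a._backward_prologue(loss)
engine_b._backward_prologue(loss)

loss.backward()                          # single autograd pass across both engines

engine_a.step()
engine_b.step()
```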