Training multiple models #7018

Open · wants to merge 13 commits into master

Conversation

tjruwase (Contributor) commented Feb 8, 2025

Support training multiple models, as in HF (Hugging Face) workflows.

Here is an update on supporting multiple DS engines with a single loss.backward(). The main message is that I think we can support this. First, some context. The backward pass in ZeRO is complicated because its optimizations and features require special handling of gradients, such as the following (a toy illustration follows the list):

  1. Gradient partitioning
  2. Overlapping backward and reduction
  3. Upcasting for fp32 grad accumulation
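
To make the interplay concrete, here is a toy illustration, not DeepSpeed's actual implementation (attach_grad_handling, fp32_grads, buckets, and pending_works are invented names), of how per-parameter autograd hooks can combine these three behaviors:

import torch
import torch.distributed as dist

def attach_grad_handling(model, fp32_grads, buckets, pending_works):
    # fp32_grads: dict of pre-zeroed fp32 tensors, one per parameter
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        def hook(grad, name=name):
            fp32_grads[name] += grad.float()        # upcast for fp32 grad accumulation
            buckets[name] = grad.detach().clone()   # stage grad (real ZeRO partitions these)
            pending_works.append(                   # async reduction overlaps with backward
                dist.all_reduce(buckets[name], async_op=True))
            return grad

        param.register_hook(hook)

def drain_reductions(pending_works):
    # epilogue-style cleanup: wait for all outstanding reductions
    for work in pending_works:
        work.wait()
    pending_works.clear()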

So we created engine.backward(loss) as a wrapper function that gives us fine-grained control over the backward pass, as below:

def backward(loss):
    backward_prologue()   # setup logic for special gradient handling
    loss.backward()
    backward_epilogue()   # cleanup/teardown logic

As demonstrated by @muellerzr, this approach breaks down when the loss originates from multiple DS engines. Our proposed solution is to use backward hooks on the module to launch backward_prologue() and backward_epilogue(). Specifically (see the sketch after this list):

  1. A backward pre-hook on engine.module launches backward_prologue() before any module gradient is created.
  2. A backward post-hook on engine.module launches backward_epilogue() after all module gradients are created.
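
A minimal sketch of this hook-based design, assuming the engine exposes backward_prologue() and backward_epilogue(); the exact registration mechanics in the PR may differ:

def install_backward_hooks(engine):
    def pre_hook(module, grad_output):
        engine.backward_prologue()     # before any module gradient is created

    def post_hook(module, grad_input, grad_output):
        engine.backward_epilogue()     # after module gradients are created

    engine.module.register_full_backward_pre_hook(pre_hook)
    engine.module.register_full_backward_hook(post_hook)

With the hooks installed on each engine's module, a plain loss.backward() on a loss combining outputs from several engines triggers each engine's prologue and epilogue, with no call to engine.backward() required.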

We plan for this solution to preserve backward compatibility, i.e., engine.backward() will remain correct for single-engine scenarios.
The current status is that (1) is complete, while (2) is in progress. To unblock e2e testing of multi-engine scenarios, since there are probably other issues, we have temporarily added engine._backward_prologue(). You can try this out via the following artifacts (a hypothetical usage sketch follows the list):

  1. Simple multi-engine test code: https://gist.github.com/tjruwase/f1adccf087b8fa269ffce2ab91c4f1c6#file-multi_engine-py
  2. DS branch: https://github.com/microsoft/DeepSpeed/tree/olruwase/zero_multi_models
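
For orientation, here is a hypothetical sketch of the flow the gist exercises; engine construction is elided, and the signature of the temporary engine._backward_prologue() is an assumption, not the gist's exact code:

loss_a = engine_a(batch_a)      # forward through first DS engine
loss_b = engine_b(batch_b)      # forward through second DS engine
loss = loss_a + loss_b          # single combined loss

engine_a._backward_prologue()   # temporary manual prologue (assumed signature)
engine_b._backward_prologue()   # temporary manual prologue (assumed signature)
loss.backward()                 # one backward pass across both engines

engine_a.step()
engine_b.step()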

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
@tjruwase tjruwase requested a review from tohtana as a code owner February 8, 2025 15:08
@tjruwase tjruwase requested a review from stas00 February 8, 2025 15:08
stas00 (Collaborator) left a comment

I am missing the context of what this PR is doing (other than that it tries to do something w.r.t. training multiple models).

But don't you need new tests?

tjruwase (Contributor, Author) commented Feb 8, 2025

I am missing the context of what this PR is doing (other than that it tries to do something w.r.t. training multiple models).

But don't you need new tests?

@stas00, thanks for the feedback. I have updated the OP with some background from earlier discussions.

I will work on converting the gist code into unit tests.

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
stas00 (Collaborator) commented Feb 10, 2025

That's much better after your OP expansion, Tunji.

The gists look good. Please ping me once they are tests, and I would be happy to review again.

That's a very important feature for quite a few users. Thank you for working on it.
