ORTModule memory improvement #18924
Conversation
The YAML file change looks good to me. I didn't look at the other parts.
Args:
    exported_model (ModelProto): The exported model.
    named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The full parameter map.

Returns:
    tuple[bool, ModelProto]: A tuple of bool and ModelProto. The bool indicates whether the model is modified.
For future reference: no need to include type information in docstrings, since the types are already in the function signature.
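For illustration, a minimal sketch of that style applied to a hypothetical function (the name `post_process_model` and its body are invented for this example, not taken from the PR); the types stay in the signature and the docstring only describes meaning:

```python
from typing import Dict, Optional, Tuple

import torch
from onnx import ModelProto


def post_process_model(
    exported_model: ModelProto,
    named_params: Optional[Dict[str, torch.nn.parameter.Parameter]] = None,
) -> Tuple[bool, ModelProto]:
    """Post-process the exported model.

    Args:
        exported_model: The exported model.
        named_params: The full parameter map.

    Returns:
        A tuple whose bool indicates whether the model was modified, together with
        the (possibly modified) model.
    """
    # Hypothetical body, for the example only.
    return False, exported_model
```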
Dependency
#19007
ORTModule memory efficient gradient management
Previously I tried to solve the coarse-grained gradient accumulation/update problem in ORTModule with #8979, but that solution was not fully validated with DDP, or when there are user hooks on the gradient accumulation of torch parameters.
This PR addresses the problem with a similar approach to PR #8979, i.e. triggering gradient accumulation as soon as ORT has computed the grad, but instead of using an AccumulateGrad op, this time it uses an ONNX PythonOp operator that internally calls param.backward(grad), which handles all related hooks correctly.
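As a rough illustration of that idea (a minimal sketch under assumptions, not the PR's actual code; the helper `accumulate_gradient` and the demo hook are invented here), routing an externally computed gradient through `param.backward(grad)` lets autograd's own accumulation path update `param.grad` and fire any hooks registered on the parameter:

```python
import torch


def accumulate_gradient(param: torch.nn.Parameter, grad: torch.Tensor) -> None:
    # Illustrative helper: feed the externally computed gradient back through
    # autograd so that param.grad is accumulated natively and gradient hooks fire.
    param.backward(grad)


if __name__ == "__main__":
    p = torch.nn.Parameter(torch.zeros(3))

    def log_hook(g):
        print("gradient hook fired")  # DDP/ZeRO-style hooks would run here
        return g  # leave the gradient unchanged

    p.register_hook(log_hook)

    accumulate_gradient(p, torch.ones(3))
    accumulate_gradient(p, torch.ones(3))
    print(p.grad)  # tensor([2., 2., 2.]) -- accumulated across the two calls
```

Per the description above, a call of this shape would sit inside the PythonOp, so that PyTorch rather than ORT performs the accumulation and DDP/user hooks keep observing it.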
Design
See the details at
https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ
Convergence Validation:
Differences are mostly at the 0.000x level, occasionally 0.00x, which may come from the gradient apply happening in a different order before vs. after this change (on DeepSpeed ZeRO stage 2).
TODO
Consolidate this logic with the similar logic in ZeRO Stage 3.