ORTModule memory improvement #18924

Merged
merged 17 commits into from
Jan 16, 2024

@pengwa pengwa commented Dec 25, 2023

Dependency

#19007

ORTModule memory efficient gradient management

Previously I tried to solve the coarse-grained gradient accumulation/update problem in ORTModule with #8979, but that solution was never fully validated with DDP or with user hooks registered on the gradient accumulation of torch parameters.

This PR addresses the problem with an approach similar to PR 8979, i.e. it triggers gradient accumulation as soon as ORT has computed the grad. Instead of using an AccumulateGrad op, this time it uses the ONNX operator PythonOp, which internally calls param.backward(grad) so that all related hooks are handled correctly.
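
As a rough illustration of the idea (not the code in this PR), the sketch below shows in plain PyTorch that calling `backward` on a leaf parameter with an externally computed gradient both accumulates into `param.grad` and runs hooks registered on the parameter. The parameter, the hook, and the hand-written gradient tensors are hypothetical stand-ins for what ORT would hand back.

```python
import torch

# Hypothetical stand-in for a model parameter managed by the training framework.
param = torch.nn.Parameter(torch.zeros(3))

# A user hook of the kind that custom code may register on a parameter.
# It only records the incoming gradient and returns None, so the gradient
# itself is left unchanged.
hook_calls = []
param.register_hook(lambda grad: hook_calls.append(grad.clone()))

# Assumption: these gradients were computed outside PyTorch autograd
# (in the PR's setting, by ORT), one per backward step.
external_grads = [
    torch.tensor([1.0, 2.0, 3.0]),
    torch.tensor([0.5, 0.5, 0.5]),
]

for grad in external_grads:
    # backward() on the leaf with an explicit gradient runs the registered
    # hooks and accumulates the result into param.grad.
    param.backward(grad)

print(param.grad)       # accumulated sum of both gradients
print(len(hook_calls))  # number of times the hook fired
```

Because accumulation goes through PyTorch's own path for the parameter, each gradient can be handed over as soon as it is available rather than being held until the whole backward pass finishes.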

Design

See the details at:

https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

Convergence Validation:

[image: convergence validation results]

The differences are mostly on the order of 0.000x, and sometimes 0.00x, which may come from the different order in which gradients are applied before and after this change (on DeepSpeed ZeRO stage 2).
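
For intuition on why ordering alone can move the numbers, here is a small generic sketch (plain Python floats, not taken from the experiment above). It shows that floating-point addition is not associative, so summing the same values in a different order can give slightly different results. Whether this is the actual cause of the differences reported here is an assumption.

```python
# Three hypothetical gradient contributions to the same weight.
g1, g2, g3 = 0.1, 0.2, 0.3

# Same values, two different accumulation orders.
left_to_right = (g1 + g2) + g3
right_to_left = g1 + (g2 + g3)

# The two sums differ in the last bits even though the inputs are identical.
print(left_to_right == right_to_left)
print(abs(left_to_right - right_to_left))
```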

TODO

Consolidate this logic with the similar logic for Stage3.

(cherry picked from commit 76640be)
(cherry picked from commit cd607d5)
(cherry picked from commit 333f235)
(cherry picked from commit be122e3)
@pengwa pengwa added the training issues related to ONNX Runtime training; typically submitted using template label Dec 25, 2023
@pengwa pengwa requested a review from a team as a code owner December 26, 2023 01:45
snnn previously approved these changes Jan 1, 2024

@snnn snnn left a comment

The yaml file change looks good to me. I didn't look at the other parts.

@pengwa pengwa merged commit 1150b1f into main Jan 16, 2024
121 of 128 checks passed
@pengwa pengwa deleted the pengwa/mem_improvement branch January 16, 2024 00:57

pengwa commented Jan 16, 2024

Thank you @askhade, @snnn!

Comment on lines +44 to +49
Args:
    exported_model (ModelProto): The exported model.
    named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The full parameter map.

Returns:
    tuple[bool, ModelProto]: A tuple of bool and ModelProto. The bool indicates whether the model is modified.

For future reference: there is no need to include type information in docstrings, as it is already in the function signature.

4 participants