Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] ORTModule memory refinement #8979

Closed
wants to merge 14 commits into from
Closed

[WIP] ORTModule memory refinement #8979

wants to merge 14 commits into from

Conversation

pengwa
Copy link
Contributor

@pengwa pengwa commented Sep 7, 2021

Description: memory refinement

On a simplified models constructed by stacking autograd.Function instances. So most of computations runs on PyTorch kernels, this is a good baseline to compare PyTorch with ORT. From memory profiling, we see ORT takes 16% more memory than PyTorch runs, obviously there are some bugs.

ORTModuleFunction holds all the calculated gradients until backward completed, BUT Pytorch will accumulate the calculated gradients into param.grad immediately once gradient computations comes to any of the leaf node (AccumulateGrad function).

The commit (29eebc2) in this PR, tries to 1). use similar idea of #8993 (cut off the connection between ORTModuleFunction and its inputs' AccumulateGrad gradient function. then PyTorch will not not accumulate the ORTModuleFunction backward outputs into param.grad. 2). we do the gradient in-place update (into param.grad) on the ONNX graph.

image

With the changes, for some cases the memory consumptions are in parity between ORT and PyTorch. More detailed benchmarks come later.

TODO: I might missed requirements of DDP onto torch grad accumulator 's post hook. Not sure whether change 2 benefits real models before investing more. So currently multiple GPU run might be failed using this branch.

Benchmark

command: python bench.py --batch 1024 --hidden 8194 --layer 12 --tag test --ort

PT:

MEM_STAT - ====== 98 before forward pass ====== MA 9348.5493 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 22.81 GB, percent = 2.6%
MEM_STAT - ====== 98 after forward pass ====== MA 9348.5493 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 22.81 GB, percent = 2.6%
MEM_STAT - ====== 98 after loss ====== MA 9348.5493 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 22.81 GB, percent = 2.6%
MEM_STAT - ====== 98 after backward pass ====== MA 9348.5493 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 22.81 GB, percent = 2.6%
MEM_STAT - ====== 99 before forward pass ====== MA 9348.5493 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 22.81 GB, percent = 2.6%
MEM_STAT - ====== 99 after forward pass ====== MA 9348.5493 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 22.81 GB, percent = 2.6%
MEM_STAT - ====== 99 after loss ====== MA 9348.5493 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 22.81 GB, percent = 2.6%
MEM_STAT - ====== 99 after backward pass ====== MA 9348.5493 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 22.81 GB, percent = 2.6%

ORT (master):

MEM_STAT - ====== 98 before forward pass ====== MA 9348.5493 MB Max_MA 12550.0869 MB CPU Virtual Memory: used = 23.13 GB, percent = 2.6%
MEM_STAT - ====== 98 after forward pass ====== MA 9348.5493 MB Max_MA 12550.0869 MB CPU Virtual Memory: used = 23.13 GB, percent = 2.6%
MEM_STAT - ====== 98 after loss ====== MA 9348.5493 MB Max_MA 12550.0869 MB CPU Virtual Memory: used = 23.13 GB, percent = 2.6%
MEM_STAT - ====== 98 after backward pass ====== MA 9348.5493 MB Max_MA 12550.0869 MB CPU Virtual Memory: used = 23.13 GB, percent = 2.6%
MEM_STAT - ====== 99 before forward pass ====== MA 9348.5493 MB Max_MA 12550.0869 MB CPU Virtual Memory: used = 23.13 GB, percent = 2.6%
MEM_STAT - ====== 99 after forward pass ====== MA 9348.5493 MB Max_MA 12550.0869 MB CPU Virtual Memory: used = 23.13 GB, percent = 2.6%
MEM_STAT - ====== 99 after loss ====== MA 9348.5493 MB Max_MA 12550.0869 MB CPU Virtual Memory: used = 23.13 GB, percent = 2.6%
MEM_STAT - ====== 99 after backward pass ====== MA 9348.5493 MB Max_MA 12550.0869 MB CPU Virtual Memory: used = 23.13 GB, percent = 2.6%

ORT (this PR):

MEM_STAT - ====== 98 before forward pass ====== MA 9316.5415 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 23.14 GB, percent = 2.6%
MEM_STAT - ====== 98 after forward pass ====== MA 9316.5415 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 23.14 GB, percent = 2.6%
MEM_STAT - ====== 98 after loss ====== MA 9316.5415 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 23.14 GB, percent = 2.6%
MEM_STAT - ====== 98 after backward pass ====== MA 9316.5415 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 23.14 GB, percent = 2.6%
MEM_STAT - ====== 99 before forward pass ====== MA 9316.5415 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 23.14 GB, percent = 2.6%
MEM_STAT - ====== 99 after forward pass ====== MA 9316.5415 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 23.14 GB, percent = 2.6%
MEM_STAT - ====== 99 after loss ====== MA 9316.5415 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 23.14 GB, percent = 2.6%
MEM_STAT - ====== 99 after backward pass ====== MA 9316.5415 MB Max_MA 9700.6987 MB CPU Virtual Memory: used = 23.14 GB, percent = 2.6%

Conclusion:

With this PR, the memory allocation (MA) and max memory allocation (MAX_MA) are aligned with PyTorch runs.

Motivation and Context

  • Why is this change required? What problem does it solve?
  • If it fixes an open issue, please link to the issue here.

@pengwa pengwa added component:ortmodule training issues related to ONNX Runtime training; typically submitted using template labels Sep 7, 2021
@pengwa pengwa changed the title custom autograd func memory refinement [WIP] custom autograd func memory refinement Sep 7, 2021
@pengwa
Copy link
Contributor Author

pengwa commented Sep 7, 2021

I might missed requirements of DDP onto torch grad accumulator 's post hook. Not sure whether change 2 benefits real models before investing more.

@pengwa pengwa marked this pull request as draft September 8, 2021 00:53
@pengwa pengwa changed the title [WIP] custom autograd func memory refinement [WIP] ORTModule memory refinement Sep 8, 2021
@pengwa pengwa closed this Sep 13, 2021
@pengwa pengwa reopened this Nov 3, 2021
@garymm garymm removed the request for review from a team February 11, 2022 01:33
@stale
Copy link

stale bot commented Apr 16, 2022

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@stale stale bot added the stale issues that have not been addressed in a while; categorized by a bot label Apr 16, 2022
@stale stale bot removed the stale issues that have not been addressed in a while; categorized by a bot label Aug 12, 2022
@pengwa pengwa closed this Dec 30, 2022
@pengwa pengwa deleted the pengwa/custom_fnc_mem branch April 11, 2023 11:37
pengwa added a commit that referenced this pull request Jan 16, 2024
## Dependency

#19007

## ORTModule memory efficient gradient management

Previously I have tried to solve the coarsed-grained gradient
accumulation/update problem in ORTModule with
#8979, while that
resolution somehow is not fully validated with DDP or there is user
hooks on the gradient accumulation on torch parameter.

This PR is addressing the problem in the similar approach as PR 8979,
e.g. trigger gradient accumulation once ORT computed the grad, but
instead of use a AccumulateGrad op, this time with a ONNX operator
PythonOp, internally it will call param.backward(grad), which will help
handle all related hooks correctly.


## Design

Check the details from


https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

## Convergence Validation:


![image](https://github.com/microsoft/onnxruntime/assets/10530022/ccf3a213-e815-4b23-b759-165033b2d9fe)

differences are on mostly 0.000x, sometimes 0.00x, which may comes from
the different order gradient apply happens before or after this change
(on deepspeed zero stage 2)


## TODO

Consolidate the logic with Stage3's similar logic.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
training issues related to ONNX Runtime training; typically submitted using template
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants