custom autograd func memory refinement #8993

Merged: 11 commits into master from pengwa/pythonop_mem on Sep 9, 2021

Conversation

@pengwa (Contributor) commented on Sep 8, 2021

Description: custom autograd func memory refinement

In the PythonOp glue code, we run the PyTorch code on a torch tensor constructed (via DLPack) from the upstream ORTValue, and a tensor constructed this way is a leaf node in the autograd graph. During the forward run, edges therefore get connected to this leaf tensor, which receives an AccumulateGrad gradient function. The AccumulateGrad function holds a reference to the leaf variable, so the leaf variable cannot be released until the AccumulateGrad function is destroyed, which only happens after PythonOpGrad completes and calls unregister_grad_fn. This extends the lifetime of the variable considerably.
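
A minimal sketch of the situation (illustrative only, not code from this PR; `buf` stands in for the data coming from the upstream ORTValue):

```python
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

buf = torch.randn(4)                    # stands in for the upstream ORTValue's data
leaf = from_dlpack(to_dlpack(buf))      # a tensor rebuilt via DLPack is a graph leaf
leaf.requires_grad_(True)

out = leaf * 2                          # forward op wires an edge back to the leaf
acc = out.grad_fn.next_functions[0][0]  # the AccumulateGrad node created for `leaf`
print(type(acc).__name__)               # AccumulateGrad
print(acc.variable is leaf)             # True: the node holds a reference to `leaf`
```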

The change in this PR: after THPFunction_apply completes, we cut the edge connecting to the leaf variable, so the AccumulateGrad gradient function is released immediately.
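
The lifetime effect can be observed from Python with a weak reference (a sketch, assuming a PyTorch build recent enough to keep the Python tensor object alive while C++ references to it remain):

```python
import weakref
import torch

leaf = torch.randn(4, requires_grad=True)  # stands in for the DLPack-built tensor
out = (leaf * 2).sum()
leaf_ref = weakref.ref(leaf)

del leaf
print(leaf_ref() is None)  # False: the AccumulateGrad edge still pins the tensor

del out                    # dropping the graph releases AccumulateGrad...
print(leaf_ref() is None)  # True: ...and the leaf can finally be freed
```

Cutting the edge right after THPFunction_apply achieves the effect of the second `del` eagerly, without waiting for unregister_grad_fn.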

PT:

MEM_STAT - ====== 99 before forward pass ====== MA 9344.0005 MB Max_MA 9696.001 MB CPU Virtual Memory: used = 69.16 GB, percent = 15.7%
MEM_STAT - ====== 99 after forward pass ====== MA 9344.0005 MB Max_MA 9696.001 MB CPU Virtual Memory: used = 69.23 GB, percent = 15.7%
MEM_STAT - ====== 99 after loss ====== MA 9344.0005 MB Max_MA 9696.001 MB CPU Virtual Memory: used = 69.23 GB, percent = 15.7%
MEM_STAT - ====== 99 after backward pass ====== MA 9344.0005 MB Max_MA 9696.001 MB CPU Virtual Memory: used = 69.34 GB, percent = 15.7%

ORT (master):

MEM_STAT - ====== 99 before forward pass ====== MA 9344.0005 MB Max_MA 12544.001 MB CPU Virtual Memory: used = 70.94 GB, percent = 16.1%
MEM_STAT - ====== 99 after forward pass ====== MA 9696.0005 MB Max_MA 12544.001 MB CPU Virtual Memory: used = 70.66 GB, percent = 16.0%
MEM_STAT - ====== 99 after loss ====== MA 9696.0005 MB Max_MA 12544.001 MB CPU Virtual Memory: used = 70.66 GB, percent = 16.0%
MEM_STAT - ====== 99 after backward pass ====== MA 9344.0005 MB Max_MA 12544.001 MB CPU Virtual Memory: used = 70.71 GB, percent = 16.0%

ORT (This PR):

MEM_STAT - ====== 99 before forward pass ====== MA 9344.0005 MB Max_MA 12544.001 MB CPU Virtual Memory: used = 70.1 GB, percent = 15.9%
MEM_STAT - ====== 99 after forward pass ====== MA 9344.0005 MB Max_MA 12544.001 MB CPU Virtual Memory: used = 70.18 GB, percent = 15.9%
MEM_STAT - ====== 99 after loss ====== MA 9344.0005 MB Max_MA 12544.001 MB CPU Virtual Memory: used = 70.18 GB, percent = 15.9%
MEM_STAT - ====== 99 after backward pass ====== MA 9344.0005 MB Max_MA 12544.001 MB CPU Virtual Memory: used = 70.23 GB, percent = 15.9%

MA (Memory Allocated) with ORT dropped from 9696 MB to 9344 MB, now in parity with PyTorch.
Max_MA (Max Memory Allocated) with ORT is still higher than with PyTorch. This is because ORTModuleFunction has a coarse-grained gradient accumulation operation, which makes some of the earlier generated gradients live longer than they do in PyTorch runs.
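
For context, MEM_STAT lines of this shape can be produced by a small helper like the following (hypothetical; the actual logging script is not part of this PR):

```python
import psutil  # assumed available, for the CPU virtual-memory numbers
import torch

def mem_stat(tag: str) -> None:
    ma = torch.cuda.memory_allocated() / 2**20          # MA, in MB
    max_ma = torch.cuda.max_memory_allocated() / 2**20  # Max_MA, in MB
    vm = psutil.virtual_memory()
    print(f"MEM_STAT - ====== {tag} ====== MA {ma} MB Max_MA {max_ma} MB "
          f"CPU Virtual Memory: used = {vm.used / 2**30:.2f} GB, "
          f"percent = {vm.percent}%")

# Called at each step of the loop, e.g. mem_stat("99 before forward pass").
```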

Motivation and Context

  • Why is this change required? What problem does it solve?
  • If it fixes an open issue, please link to the issue here.

@pengwa requested a review from @wschin on September 8, 2021 03:38
@pengwa added the training and release:1.9 labels on Sep 8, 2021
Update orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc

Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>
@wschin (Contributor) left a comment:

Thanks for the hard work. Looks super great.

@pengwa merged commit d209fe2 into master on Sep 9, 2021
@pengwa deleted the pengwa/pythonop_mem branch on September 9, 2021 10:37
wangyems pushed a commit that referenced this pull request Sep 9, 2021
* Release torch tensor referenced by torch gradient graph (created in PythonOp)

* Update orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc

* refine with comments

Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>
wangyems added a commit that referenced this pull request Sep 9, 2021
* fast reduction for reducemean (#8976)

* Adding preprocessor checks for torch version during torch cpp extensions compilation (#8989)

* custom autograd func memory refinement  (#8993)

* Release torch tensor referenced by torch gradient graph (created in PythonOp)

* Update orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/torch_interop_utils/torch_interop_utils.cc

* refine with comments

Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>

* Fix issues in TensorRT EP (#8996)

* fix big engine load issue and add cuda_cpu_alloc

* remove redundancy

* fix minor issues

* [js/web] fix karma launch with chrome headless (#8998)

* Update Nuget Package Pipeline to CUDA11.4 and TensorRT8 on Windows (#9000)

* Update to CUDA11.4 and TensorRT-8.0.3.4

* update trt pool, remove cudnn from setup_env_gpu.bat

* revert pool

* test gpu package pipeline on t4

* back out changes

* back out changes

Co-authored-by: George Wu <jywu@microsoft.com>

* Fix fuzz testing build blocking release. (#9008)

* add model local function support (#8540)

* updates for picking pnnx commit

* add tests filter to c# tests

* plus test fixes

* fix versioning for contrib ops

* fix tests

* test filter for optional ops

* more versioning related updates

* fix test

* fix layernorm spec

* more updates

* update docs

* add more test filters

* more filters

* update binary size threshold

* update docs

* draft - enable model local function

* enable model local functions in ORT

* update to latest rel onnx commit

* plus tests

* plus more updates

* plus updates

* test updates

* Fix for nested functions + shape inference

* plus bug fix and updates per review

* plus fixes per review

* plus test updates

* plus updates per review

* plus fixes

* fix a test

Co-authored-by: Vincent Wang <wangwchpku@outlook.com>
Co-authored-by: baijumeswani <bmeswani@microsoft.com>
Co-authored-by: pengwa <pengwa@microsoft.com>
Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>
Co-authored-by: stevenlix <38092805+stevenlix@users.noreply.github.com>
Co-authored-by: Yulong Wang <yulongw@microsoft.com>
Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
Co-authored-by: George Wu <jywu@microsoft.com>
Co-authored-by: Pranav Sharma <prs@microsoft.com>
Co-authored-by: Ashwini Khade <askhade@microsoft.com>