[Bug] Memory not freed after a variational GP model is discarded #1788

Closed
songlei00 opened this issue Apr 8, 2023 · 9 comments

@songlei00

🐛 Bug

Iteratively creating a variational GP (SingleTaskVariationalGP) eventually runs out of GPU memory. I found a similar problem in #1585, which uses an exact GP (SingleTaskGP); calling gc.collect() solves the problem there but is useless for mine.

In the code below I call torch.cuda.empty_cache() and gc.collect() on every iteration, and the code only creates the SingleTaskVariationalGP model without doing any forward computation. However, memory usage still increases.

To reproduce

import torch
from botorch.models import SingleTaskVariationalGP
from botorch.models.utils.inducing_point_allocators import GreedyVarianceReduction 
from gpytorch.mlls import VariationalELBO
import time
import gc

def create_gp_forward(device):
    # create the training data
    train_X = torch.rand(4096, 5).to(device)
    train_Y = torch.sin(train_X).sum(dim=1, keepdim=True).to(device)

    # creating the gp model
    gp = SingleTaskVariationalGP(
        train_X, train_Y, 
        learn_inducing_points=False,
        inducing_point_allocator=GreedyVarianceReduction(),
    )
    mll = VariationalELBO(gp.likelihood, gp.model, num_data=len(train_X))

    # gp(train_X)

    print(torch.cuda.memory_summary(device=device).split("\n")[7])

if __name__ == "__main__":
    for _ in range(10000):
        create_gp_forward(device = "cuda")
        time.sleep(2)
        torch.cuda.empty_cache()
        gc.collect()

**Stack trace/error message**

| Allocated memory      |  447602 KB |  509638 KB |     902 MB |  476950 KB |
| Allocated memory      |     870 MB |     930 MB |    1805 MB |     935 MB |
| Allocated memory      |    1303 MB |    1363 MB |    2708 MB |    1405 MB |
| Allocated memory      |    1736 MB |    1796 MB |    3611 MB |    1875 MB |
| Allocated memory      |    2169 MB |    2229 MB |    4514 MB |    2345 MB |
| Allocated memory      |    2602 MB |    2662 MB |    5417 MB |    2815 MB |
| Allocated memory      |    3035 MB |    3095 MB |    6320 MB |    3285 MB |
| Allocated memory      |    3468 MB |    3528 MB |    7223 MB |    3755 MB |
| Allocated memory      |    3901 MB |    3961 MB |    8125 MB |    4224 MB |
| Allocated memory      |    4333 MB |    4394 MB |    9028 MB |    4694 MB |
| Allocated memory      |    4766 MB |    4827 MB |    9931 MB |    5164 MB |
...
CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 5.79 GiB total capacity; 4.78 GiB already allocated; 34.06 MiB free; 4.78 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Expected Behavior

The variables created in create_gp_forward are local, so their memory should be released when the function returns and GPU memory usage should stay constant across iterations.

System information

  • botorch 0.8.2
  • gpytorch 1.9.1
  • torch 1.12.0
  • Ubuntu 18.04
@songlei00 songlei00 added the bug Something isn't working label Apr 8, 2023
@songlei00
Author

I further found that the out-of-memory problem is caused by the allocate_inducing_points function of the GreedyVarianceReduction class. If I pass a Tensor for the inducing_points argument of SingleTaskVariationalGP, no GreedyVarianceReduction is created and the out-of-memory problem disappears.

For example, with the following construction the problem does not occur:

gp = SingleTaskVariationalGP(
    train_X, train_Y, 
    inducing_points=train_X[: 50], 
)

The following code, which calls the allocate_inducing_points function directly, reproduces the out-of-memory problem:

import torch
from botorch.models import SingleTaskVariationalGP
from botorch.models.utils.inducing_point_allocators import GreedyVarianceReduction 
from gpytorch.mlls import VariationalELBO
from gpytorch.kernels import ScaleKernel, MaternKernel
import time
import gc

def create_gp_forward(device):
    # create the training data
    train_X = torch.rand(4096, 5).to(device)
    train_Y = torch.sin(train_X).sum(dim=1, keepdim=True).to(device)

    kernel = ScaleKernel(MaternKernel())
    allocator = GreedyVarianceReduction()
    allocator.allocate_inducing_points(train_X, kernel, 30, train_X.shape[: -2])

    print(torch.cuda.memory_summary(device=device).split("\n")[7])

if __name__ == "__main__":
    for _ in range(10000):
        create_gp_forward(device = "cuda")
        time.sleep(2)
        torch.cuda.empty_cache()
        gc.collect()

I checked the source code of the GreedyVarianceReduction class, but I have no idea what causes the out-of-memory behavior.

@Balandat
Contributor

Balandat commented Apr 8, 2023

Thanks for the report and for the repro. This seems rather odd - @henrymoss, do you have any idea what could be going on here? Is there any graph being built in the background that shouldn't be?

@songlei00 what happens if you wrap the create_gp_forward method in a with torch.no_grad() context?
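
Something like the following sketch, which just reuses the create_gp_forward defined in the repro above:

import torch

# create_gp_forward is the repro function defined earlier in this issue;
# running it under torch.no_grad() prevents any autograd graph from being built.
with torch.no_grad():
    create_gp_forward(device="cuda")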

cc also @esantorella re memory leak issues.

@songlei00
Author

Thanks for your response.

Wrapping the create_gp_forward method in torch.no_grad() helps: memory usage stays small and constant. So I guess the out-of-memory problem is caused by graphs that are built but never released.

I further tried adding a torch.no_grad() context inside the source of the allocate_inducing_points function of the GreedyVarianceReduction class. If I wrap L77,

train_train_kernel = covar_module(inputs).evaluate_kernel()

in torch.no_grad(), the out-of-memory problem goes away. I think this is because calling covar_module builds a graph that is never released.
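
Concretely, the change is just the following (a sketch with illustrative stand-ins for inputs and covar_module, since this is not the actual BoTorch source):

import torch
from gpytorch.kernels import MaternKernel, ScaleKernel

# Stand-ins for the `inputs` and `covar_module` that allocate_inducing_points
# receives; the only point here is the no_grad wrapper around evaluate_kernel.
inputs = torch.rand(4096, 5)
covar_module = ScaleKernel(MaternKernel())

with torch.no_grad():
    train_train_kernel = covar_module(inputs).evaluate_kernel()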

However, if I use the following code to repeatedly run covar_module(inputs).evaluate_kernel() on its own, memory stays at a normal level.

for _ in range(10000):
    train_X = torch.rand(4096, 5).to(device)
    kernel = ScaleKernel(MaternKernel()).to(device)
    train = kernel(train_X).evaluate_kernel()

It seems that the graph created by covar_module(inputs).evaluate_kernel() and the subsequent code, i.e., the _pivoted_cholesky_init call at L82 of botorch/models/utils/inducing_point_allocators.py, jointly cause the out-of-memory problem. This is very strange, because _pivoted_cholesky_init only performs ordinary torch operations to compute its result. I am very curious about the root cause.

@songlei00
Author

songlei00 commented Apr 8, 2023

I copied the minimal source code from the allocate_inducing_points method of the GreedyVarianceReduction class to reproduce the problem and show how to fix it, but I don't know why torch.no_grad() works or what the root cause of the leak is. The code is as follows:

import torch
from gpytorch.kernels import ScaleKernel, MaternKernel
from botorch.models.utils.inducing_point_allocators import _pivoted_cholesky_init, UnitQualityFunction
import time
import gc

def bug_version(device):
    train_X = torch.rand(4096, 5).to(device)

    kernel = ScaleKernel(MaternKernel()).to(device)
    train_train_kernel = kernel(train_X).evaluate_kernel()
    quality_function = UnitQualityFunction()
    quality_scores = quality_function(train_X)
    inducing_points = _pivoted_cholesky_init(
        train_inputs=train_X,
        kernel_matrix=train_train_kernel,
        max_length=50,
        quality_scores=quality_scores
    )

    print(torch.cuda.memory_summary(device=device).split("\n")[7])

def fixed_version(device):
    train_X = torch.rand(4096, 5).to(device)

    kernel = ScaleKernel(MaternKernel()).to(device)
    with torch.no_grad():
        train_train_kernel = kernel(train_X).evaluate_kernel()
    quality_function = UnitQualityFunction()
    quality_scores = quality_function(train_X)
    inducing_points = _pivoted_cholesky_init(
        train_inputs=train_X,
        kernel_matrix=train_train_kernel,
        max_length=50,
        quality_scores=quality_scores
    )

    print(torch.cuda.memory_summary(device=device).split("\n")[7])

if __name__ == "__main__":
    for _ in range(10000):
        bug_version(device = "cuda")
        # fixed_version(device = "cuda")
        time.sleep(2)
        torch.cuda.empty_cache()
        gc.collect()

When I repeatedly call the bug_version function, memory keeps increasing until I hit the out-of-memory error. The output of torch.cuda.memory_summary is as follows:

| Allocated memory      |  199085 KB |  262388 KB |  335013 KB |  135928 KB |
| Allocated memory      |  332535 KB |  395838 KB |  670026 KB |  337491 KB |
| Allocated memory      |  465985 KB |  529288 KB |     981 MB |  539054 KB |
| Allocated memory      |  599435 KB |  662738 KB |    1308 MB |  740617 KB |
| Allocated memory      |  732885 KB |     777 MB |    1635 MB |     920 MB |
| Allocated memory      |     846 MB |     907 MB |    1962 MB |    1116 MB |
| Allocated memory      |     976 MB |    1038 MB |    2290 MB |    1313 MB |
| Allocated memory      |    1106 MB |    1168 MB |    2617 MB |    1510 MB |
| Allocated memory      |    1236 MB |    1298 MB |    2944 MB |    1707 MB |
| Allocated memory      |    1367 MB |    1429 MB |    3271 MB |    1904 MB |
...

When I repeatedly call the fixed_version function, which uses torch.no_grad(), memory stays at a small, constant value. The output of torch.cuda.memory_summary is as follows:

| Allocated memory      |   65636 KB |  196852 KB |  269477 KB |  203841 KB |
| Allocated memory      |   65636 KB |  196852 KB |  538954 KB |  473318 KB |
| Allocated memory      |   65636 KB |  196852 KB |     789 MB |  742795 KB |
| Allocated memory      |   65636 KB |  196852 KB |    1052 MB |     988 MB |
| Allocated memory      |   65636 KB |  196852 KB |    1315 MB |    1251 MB |
| Allocated memory      |   65636 KB |  196852 KB |    1578 MB |    1514 MB |
| Allocated memory      |   65636 KB |  196852 KB |    1842 MB |    1778 MB |
...

@Balandat
Contributor

Balandat commented Apr 9, 2023

Hmm this seems to be some kind of issue with the graph for the parameter of the gpytorch kernel not getting garbage collected. Potentially this could be a bug with the caching in gpytorch. cc @gpleiss, @jacobrgardner

@esantorella esantorella self-assigned this Apr 10, 2023
@gpleiss
Contributor

gpleiss commented May 26, 2023

It's probably a caching issue somewhere... I could try to look at this but I probably won't have time until later this summer. @esantorella if you want to dive into the caching hell that would be much appreciated!

@esantorella
Member

esantorella commented Jun 12, 2023

I haven't fully figured out what's going on here, but I have an even smaller repro. Thanks so much @songlei00 for doing so much debugging work here -- everything you said was spot on and made digging into this a lot easier!

Here's a repro:

import torch
from gpytorch.kernels import ScaleKernel, MaternKernel
import gc
from memory_profiler import profile


def f(device="cpu") -> None:
    gc.collect()
    n = 4096 * 2
    train_X = torch.rand(n, 5).to(device)
    # without `ScaleKernel`, no leak
    kernel = ScaleKernel(MaternKernel()).to(device)
    # if this is instead `kernel(train_X).evaluate_kernel().tensor`, no leak
    train_train_kernel = kernel(train_X).evaluate_kernel()

    cis = torch.zeros((2, n), device=device, dtype=train_train_kernel.dtype)
    cis[0, :] = train_train_kernel[0, :]
    # if `cis[0, :]` is replaced with `train_train_kernel[0, :]`, no leak
    result = train_train_kernel[0, 0] * cis[0, :]
    # without this line, no leak
    cis[1, :] = result


@profile
def main() -> None:
    f()
    f()
    f()
    f()
    f()
    f()


if __name__ == "__main__":
    main()

Memory profiler output, indicating that memory increases linearly with each step:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    22    183.2 MiB    183.2 MiB           1   @profile
    23                                         def main() -> None:
    24   1211.5 MiB   1028.4 MiB           1       f()
    25   1724.2 MiB    512.7 MiB           1       f()
    26   2237.1 MiB    512.9 MiB           1       f()
    27   2750.2 MiB    513.1 MiB           1       f()
    28   3262.3 MiB    512.0 MiB           1       f()
    29   3774.9 MiB    512.7 MiB           1       f()

Digging into gc.get_objects() to see what's still in memory after f exits, it appears that there's at least one 8192 x 8192 tensor that gc is still tracking. gc doesn't see any Python object referring to that tensor, so whatever keeps it alive is presumably held outside of Python (e.g., by autograd), which is why it can't be garbage-collected. That tensor is allocated during evaluate_kernel, and it does not have the same numerical values as train_train_kernel, or anything close.

As @songlei00 said, reproducing the leak requires both a LinearOperator and some of the logic in _pivoted_cholesky_init, which I've pared down to the most relevant three lines. It is indeed puzzling that this happens since these are just regular operations on torch tensors, except that we index into train_train_kernel. It also seems to require ScaleKernel (using MaternKernel alone gets rid of the leak).
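
The gc.get_objects() digging was roughly along these lines (an illustrative sketch with a made-up helper name, not the exact script):

import gc

import torch


def report_live_tensors(min_numel: int = 10_000_000) -> None:
    # Scan everything the garbage collector tracks and print any large tensor
    # that is still alive; called right after f() returns, this is how the
    # lingering 8192 x 8192 tensor shows up.
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.numel() >= min_numel:
                print(type(obj).__name__, tuple(obj.shape), obj.dtype)
        except Exception:
            # Some tracked objects raise on attribute access; skip them.
            continue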

@esantorella
Member

This is an issue with tensors being inappropriately kept alive that can be reproduced with torch alone, so it isn't a GPyTorch or LinearOperator issue. Many thanks to @SebastianAment for helping debug this! When the kernel matrix has gradient information attached and is passed to _pivoted_cholesky_init, memory tends to leak. Here's a function that leaks memory:

import torch
from botorch.models.utils.inducing_point_allocators import _pivoted_cholesky_init

def f():
    n = 4096

    # Part of `ScaleKernel.forward`,
    # which gets called in `evaluate_kernel`
    orig_output = torch.rand((n, n))
    # `outputscales = torch.zeros((1, 1), requires_grad=True)` will produce the same behavior
    outputscales = torch.nn.Parameter(torch.zeros((1, 1)))
    kernel = orig_output.mul(outputscales)

    # leaks memory unless we pass in `kernel.detach()` rather than `kernel`
    _pivoted_cholesky_init(
        train_inputs=torch.zeros((n, 2)),
        kernel_matrix=kernel,
        max_length=3,
        quality_scores=torch.ones(n),
    )   

Here's another, with pure PyTorch, using only the most relevant parts of _pivoted_cholesky_init:

import torch

def g():
    n = 4096

    # Part of `ScaleKernel.forward`,
    # which gets called in `evaluate_kernel`
    orig_output = torch.zeros((n, n))
    # alternately outputscales = torch.zeros((1, 1), requires_grad=True)
    outputscales = torch.nn.Parameter(torch.zeros((1, 1)))
    kernel = orig_output.mul(outputscales)

    # Part of `_pivoted_cholesky_init`
    cis = torch.zeros((2, n))
    cis[0, :] = kernel[0, :]
    cis[1, :] = kernel[0, 0] * cis[0, :]

This has something to do with autograd and with the self-referential in-place operations in the last two lines of the code above. After running f or g, a 4096 x 4096 tensor with values equal to orig_output persists in memory (as revealed by gc.get_objects()). I think a fix for this would be to add a line kernel = kernel.detach() in _pivoted_cholesky_init.
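
A minimal sketch of that fix, applied to the standalone g above rather than to _pivoted_cholesky_init itself:

import torch

def g_fixed():
    n = 4096

    orig_output = torch.zeros((n, n))
    outputscales = torch.nn.Parameter(torch.zeros((1, 1)))
    # Detaching before the in-place writes below means no autograd graph
    # (and no saved reference to orig_output) is kept alive after g_fixed returns.
    kernel = orig_output.mul(outputscales).detach()

    cis = torch.zeros((2, n))
    cis[0, :] = kernel[0, :]
    cis[1, :] = kernel[0, 0] * cis[0, :]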

@Balandat
Contributor

Awesome, great investigating!

I think a fix for this would be to add a line kernel = kernel.detach()

I assume we'd never want to backprop through _pivoted_cholesky_init? If that's the case then calling kernel.detach() makes sense to me.

If we do want to backprop, we could instead rewrite the code to avoid the kinds of in-place operations that cause the issue.
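
For example (a sketch of the out-of-place variant, again using the standalone g from above; illustrative only, not the actual patch):

import torch

def g_out_of_place():
    n = 4096

    orig_output = torch.zeros((n, n))
    outputscales = torch.nn.Parameter(torch.zeros((1, 1)))
    kernel = orig_output.mul(outputscales)

    # Build the rows functionally instead of writing into a preallocated buffer,
    # so autograd never records self-referential in-place updates.
    row0 = kernel[0, :]
    row1 = kernel[0, 0] * row0
    cis = torch.stack([row0, row1])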

esantorella added a commit to esantorella/botorch that referenced this issue Jun 17, 2023
Summary:
Fixes pytorch#1788.
## Motivation

`allocate_inducing_points` leaks memory when passed a `kernel_matrix` with `requires_grad=True`. The memory leak happens due to a specific pattern of in-place torch operations in `_pivoted_cholesky_init`;  see [this comment](pytorch#1788 (comment)) for more explanation. There is no need for `allocate_inducing_points` to support a `kernel_matrix` with `requires_grad=True`, because the output of `allocate_inducing_points` is not differentiable anyway (thanks to in-place operations).

[x] make `_pivoted_cholesky_init` raise an `UnsupportedError` when passed a `kernel_matrix` with `requires_grad=True`. That is mildly BC-breaking, but I think that is okay since the alternative is a memory leak.
[x] Evaluate kernels with `torch.no_grad()` where they are only used to be passed to `_pivoted_cholesky_init`
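
A rough sketch of the guard described in the first item (assuming `UnsupportedError` from `botorch.exceptions`; the exact wording and placement in the actual patch may differ):

import torch
from botorch.exceptions.errors import UnsupportedError

def _check_kernel_matrix_is_detached(kernel_matrix: torch.Tensor) -> None:
    # Hypothetical helper illustrating the check: `_pivoted_cholesky_init` is
    # not differentiable, and a kernel matrix carrying gradient information
    # only leaks memory, so refuse it up front.
    if kernel_matrix.requires_grad:
        raise UnsupportedError(
            "`_pivoted_cholesky_init` does not support a kernel_matrix with "
            "requires_grad=True; detach it or evaluate the kernel under "
            "torch.no_grad()."
        )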

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: pytorch#1890

Test Plan:
[x] Unit test for memory leak
[x] Unit test for UnsupportedError

Reviewed By: saitcakmak, Balandat

Differential Revision: D46803080

Pulled By: esantorella

fbshipit-source-id: 1fb9c6500d4246a3740a9fce4bda290043f8ac3b