[Bug] Memory not freed after a variational GP model is discarded #1788
I further find that the out of memory problem is due to the inducing point allocation (`GreedyVarianceReduction.allocate_inducing_points`). For example, if you use the following code, the out of memory problem disappears:

```python
gp = SingleTaskVariationalGP(
    train_X, train_Y,
    inducing_points=train_X[:50],
)
```

You can use the following code to reproduce the out of memory problem, which calls `allocate_inducing_points`:

```python
import torch
from botorch.models import SingleTaskVariationalGP
from botorch.models.utils.inducing_point_allocators import GreedyVarianceReduction
from gpytorch.mlls import VariationalELBO
from gpytorch.kernels import ScaleKernel, MaternKernel
import time
import gc


def create_gp_forward(device):
    # create the training data
    train_X = torch.rand(4096, 5).to(device)
    train_Y = torch.sin(train_X).sum(dim=1, keepdim=True).to(device)

    kernel = ScaleKernel(MaternKernel())
    allocator = GreedyVarianceReduction()
    allocator.allocate_inducing_points(train_X, kernel, 30, train_X.shape[:-2])

    print(torch.cuda.memory_summary(device=device).split("\n")[7])


if __name__ == "__main__":
    for _ in range(10000):
        create_gp_forward(device="cuda")
        time.sleep(2)
        torch.cuda.empty_cache()
        gc.collect()
```

I also checked the relevant source code.

---

Thanks for the report and for the repro. This seems rather odd - @henrymoss, do you have any idea what could be going on here? Is there any graph being built in the background that shouldn't be? @songlei00 what happens if you wrap the model creation in a `torch.no_grad()` context?

cc also @esantorella re memory leak issues.

---

Thanks for your response. Wrapping the model creation in `torch.no_grad()` works. I further tried to add `torch.no_grad()` around the kernel evaluation only, and this also fixes the out of memory problem. I think it is because calling `evaluate_kernel` builds a computation graph that is never released. However, if I use the following code to iteratively run `evaluate_kernel` on its own:

```python
for _ in range(10000):
    train_X = torch.rand(4096, 5).to(device)
    kernel = ScaleKernel(MaternKernel()).to(device)
    train = kernel(train_X).evaluate_kernel()
```

it seems that the graph created by `evaluate_kernel` alone is released normally, so this cannot be the whole story.
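
As a concrete sketch of the workaround described above (assuming it is applied inside the `create_gp_forward` repro from the first comment; this is not the library's eventual fix):

```python
import torch
from botorch.models.utils.inducing_point_allocators import GreedyVarianceReduction
from gpytorch.kernels import ScaleKernel, MaternKernel


def create_gp_forward_no_grad(device):
    train_X = torch.rand(4096, 5).to(device)
    kernel = ScaleKernel(MaternKernel()).to(device)
    allocator = GreedyVarianceReduction()
    # Running the allocator under torch.no_grad() means the kernel matrix it
    # builds has requires_grad=False, so no autograd graph outlives the call.
    with torch.no_grad():
        allocator.allocate_inducing_points(train_X, kernel, 30, train_X.shape[:-2])
```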

---

I copied the minimal relevant source code from `allocate_inducing_points` to narrow this down:

```python
import torch
from gpytorch.kernels import ScaleKernel, MaternKernel
from botorch.models.utils.inducing_point_allocators import _pivoted_cholesky_init, UnitQualityFunction
import time
import gc


def bug_version(device):
    train_X = torch.rand(4096, 5).to(device)
    kernel = ScaleKernel(MaternKernel()).to(device)

    train_train_kernel = kernel(train_X).evaluate_kernel()

    quality_function = UnitQualityFunction()
    quality_scores = quality_function(train_X)
    inducing_points = _pivoted_cholesky_init(
        train_inputs=train_X,
        kernel_matrix=train_train_kernel,
        max_length=50,
        quality_scores=quality_scores,
    )

    print(torch.cuda.memory_summary(device=device).split("\n")[7])


def fixed_version(device):
    train_X = torch.rand(4096, 5).to(device)
    kernel = ScaleKernel(MaternKernel()).to(device)

    with torch.no_grad():
        train_train_kernel = kernel(train_X).evaluate_kernel()

    quality_function = UnitQualityFunction()
    quality_scores = quality_function(train_X)
    inducing_points = _pivoted_cholesky_init(
        train_inputs=train_X,
        kernel_matrix=train_train_kernel,
        max_length=50,
        quality_scores=quality_scores,
    )

    print(torch.cuda.memory_summary(device=device).split("\n")[7])


if __name__ == "__main__":
    for _ in range(10000):
        bug_version(device="cuda")
        # fixed_version(device="cuda")
        time.sleep(2)
        torch.cuda.empty_cache()
        gc.collect()
```

When I iteratively call `bug_version`, the allocated GPU memory keeps increasing. When I iteratively call `fixed_version`, the memory usage stays constant.

---

Hmm, this seems to be some kind of issue with the graph for the parameter of the gpytorch kernel not getting garbage collected. Potentially this could be a bug with the caching in gpytorch. cc @gpleiss, @jacobrgardner

---

It's probably a caching issue somewhere... I could try to look at this but I probably won't have time until later this summer. @esantorella if you want to dive into the caching hell that would be much appreciated!

---

I haven't fully figured out what's going on here, but I have an even smaller repro. Thanks so much @songlei00 for doing so much debugging work here -- everything you said was spot on and made digging into this a lot easier! Here's the repro:

```python
import torch
from gpytorch.kernels import ScaleKernel, MaternKernel
import gc
from memory_profiler import profile


def f(device="cpu") -> None:
    gc.collect()
    n = 4096 * 2
    train_X = torch.rand(n, 5).to(device)
    # without `ScaleKernel`, no leak
    kernel = ScaleKernel(MaternKernel()).to(device)
    # if this is instead `kernel(train_X).evaluate_kernel().tensor`, no leak
    train_train_kernel = kernel(train_X).evaluate_kernel()

    cis = torch.zeros((2, n), device=device, dtype=train_train_kernel.dtype)
    cis[0, :] = train_train_kernel[0, :]
    # if `cis[0, :]` is replaced with `train_train_kernel[0, :]`, no leak
    result = train_train_kernel[0, 0] * cis[0, :]
    # without this line, no leak
    cis[1, :] = result


@profile
def main() -> None:
    f()
    f()
    f()
    f()
    f()
    f()


if __name__ == "__main__":
    main()
```

The memory profiler output shows that memory usage increases linearly with each call to `f`.

As @songlei00 said, reproducing the leak requires both a `ScaleKernel` and the self-referential in-place writes into `cis`; removing either ingredient (see the comments in the repro) makes the leak disappear.

---

This is an issue with tensors being inappropriately persisted that can be reproduced purely with torch, so it isn't a GPyTorch or LinearOperator issue. Many thanks to @SebastianAment for helping debug this!

When a kernel matrix with `requires_grad=True` is passed to `_pivoted_cholesky_init`, memory leaks:

```python
import torch
from botorch.models.utils.inducing_point_allocators import _pivoted_cholesky_init


def f():
    n = 4096
    # Part of `ScaleKernel.forward`,
    # which gets called in `evaluate_kernel`
    orig_output = torch.rand((n, n))
    # `outputscales = torch.zeros((1, 1), requires_grad=True)` will produce the same behavior
    outputscales = torch.nn.Parameter(torch.zeros((1, 1)))
    kernel = orig_output.mul(outputscales)

    # leaks memory unless we pass in `kernel.detach()` rather than `kernel`
    _pivoted_cholesky_init(
        train_inputs=torch.zeros((n, 2)),
        kernel_matrix=kernel,
        max_length=3,
        quality_scores=torch.ones(n),
    )
```

Here's another repro, with pure PyTorch, using only the most relevant parts of `_pivoted_cholesky_init`:

```python
import torch


def g():
    n = 4096
    # Part of `ScaleKernel.forward`,
    # which gets called in `evaluate_kernel`
    orig_output = torch.zeros((n, n))
    # alternately outputscales = torch.zeros((1, 1), requires_grad=True)
    outputscales = torch.nn.Parameter(torch.zeros((1, 1)))
    kernel = orig_output.mul(outputscales)

    # Part of `_pivoted_cholesky_init`
    cis = torch.zeros((2, n))
    cis[0, :] = kernel[0, :]
    cis[1, :] = kernel[0, 0] * cis[0, :]
```

This has something to do with autograd, and with the self-referential in-place operations happening in the last two lines of the code above. After running `g` in a loop, the large tensors are never freed, even after an explicit `gc.collect()`.
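
For reference, here is a minimal sketch of a variant of `g` that does not leak, based on the `kernel.detach()` observation in the first snippet above (an illustration only, not the fix that was ultimately adopted):

```python
import torch


def g_detached():
    n = 4096
    orig_output = torch.zeros((n, n))
    outputscales = torch.nn.Parameter(torch.zeros((1, 1)))
    # Detach before the in-place writes: autograd records nothing, so no graph
    # keeps `cis` or the kernel matrix alive after the function returns.
    kernel = orig_output.mul(outputscales).detach()

    cis = torch.zeros((2, n))
    cis[0, :] = kernel[0, :]
    cis[1, :] = kernel[0, 0] * cis[0, :]
```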

---

Awesome, great investigating!

I assume we'd never want to backprop through `_pivoted_cholesky_init`? If we do want to backprop, we can also rewrite the code to avoid the kinds of in-place operations that cause the issue.
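
A sketch of what such an in-place-free rewrite could look like, applied to the pure-PyTorch repro from the previous comment (hypothetical illustration, not the actual BoTorch code):

```python
import torch


def g_out_of_place():
    n = 4096
    orig_output = torch.zeros((n, n))
    outputscales = torch.nn.Parameter(torch.zeros((1, 1)))
    kernel = orig_output.mul(outputscales)

    # Build the rows out of place instead of writing into a pre-allocated
    # buffer, so autograd never records self-referential in-place updates.
    row0 = kernel[0, :]
    cis = torch.stack([row0, kernel[0, 0] * row0])
    return cis


if __name__ == "__main__":
    # The result stays differentiable, and all tensors are freed once
    # references drop after the call.
    g_out_of_place().sum().backward()
```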

---

Summary: Fixes pytorch#1788.

## Motivation

`allocate_inducing_points` leaks memory when passed a `kernel_matrix` with `requires_grad=True`. The memory leak happens due to a specific pattern of in-place torch operations in `_pivoted_cholesky_init`; see [this comment](pytorch#1788 (comment)) for more explanation. There is no need for `allocate_inducing_points` to support a `kernel_matrix` with `requires_grad=True`, because the output of `allocate_inducing_points` is not differentiable anyway (thanks to in-place operations).

- [x] Make `_pivoted_cholesky_init` raise an `UnsupportedError` when passed a `kernel_matrix` with `requires_grad=True`. That is mildly BC-breaking, but I think that is okay since the alternative is a memory leak.
- [x] Evaluate kernels with `torch.no_grad()` where they are only used to be passed to `_pivoted_cholesky_init`.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: pytorch#1890

Test Plan:

- [x] Unit test for memory leak
- [x] Unit test for `UnsupportedError`

Reviewed By: saitcakmak, Balandat

Differential Revision: D46803080

Pulled By: esantorella

fbshipit-source-id: 1fb9c6500d4246a3740a9fce4bda290043f8ac3b
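
A minimal sketch of the two changes described in the summary above (the helper name and call site below are assumptions for illustration; this is not the actual diff):

```python
import torch
from botorch.exceptions import UnsupportedError


def _check_kernel_matrix(kernel_matrix: torch.Tensor) -> None:
    # Mirrors the first checklist item: refuse kernel matrices that would
    # make the in-place operations in `_pivoted_cholesky_init` leak memory.
    if kernel_matrix.requires_grad:
        raise UnsupportedError(
            "`_pivoted_cholesky_init` does not support a `kernel_matrix` that "
            "requires grad. Detach it or evaluate it under `torch.no_grad()`."
        )


# Second checklist item, at the call site (hypothetical `covar_module`/`inputs`):
# with torch.no_grad():
#     kernel_matrix = covar_module(inputs).evaluate_kernel()


if __name__ == "__main__":
    _check_kernel_matrix(torch.eye(3))  # fine: no grad required
    # _check_kernel_matrix(torch.eye(3, requires_grad=True))  # raises UnsupportedError
```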

---

🐛 Bug

Iteratively creating the variational GP `SingleTaskVariationalGP` results in out of memory. I found a similar problem in #1585, which uses an exact GP, i.e. `SingleTaskGP`. Using `gc.collect()` solves the problem in #1585 but is useless for my problem. I added `torch.cuda.empty_cache()` and `gc.collect()` to my code, and the code only creates the `SingleTaskVariationalGP` model without doing any forward pass. However, the memory still increases.

To reproduce

** Stack trace/error message **

Expected Behavior

The variables created in `create_gp_forward` are local variables, so when the function finishes, the memory should be released and the memory occupation should stay constant during the loop.

System information