Don't use functools.lru_cache on methods (pytorch#1650)
Summary:
Pull Request resolved: pytorch#1650

Wrapping instance methods with functools.lru_cache can cause memory leaks: the
cache holds strong references to self, so instances are never freed. For the
full explanation and fix, see https://rednafi.github.io/reflections/dont-wrap-instance-methods-with-functoolslru_cache-decorator-in-python.html
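
Editorial illustration, not part of the commit: a minimal sketch of the leak
described in the linked post, using the hypothetical names Leaky and compute.

    import gc
    import weakref
    from functools import lru_cache

    class Leaky:
        @lru_cache(maxsize=None)  # flake8 B019 flags exactly this pattern
        def compute(self, x: int) -> int:
            return x * 2

    obj = Leaky()
    obj.compute(1)  # the cache key (obj, 1) now strongly references obj
    ref = weakref.ref(obj)
    del obj
    gc.collect()
    print(ref() is None)  # False: the class-level cache keeps obj alive

Because the cache is attached to the function object on the class, it outlives
any particular instance, and every (self, args) key pins its instance for the
life of the process.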

Reviewed By: Balandat

Differential Revision: D42980747

fbshipit-source-id: 225edb7e43c363d1f2b4580bec77488542b1a7c3
esantorella authored and facebook-github-bot committed Feb 3, 2023
1 parent b3d3074 commit 7e1350b
Showing 1 changed file with 39 additions and 26 deletions.

botorch/posteriors/fully_bayesian.py
@@ -5,7 +5,7 @@
 
 
 from functools import lru_cache
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple
 
 import torch
 from botorch.posteriors.gpytorch import GPyTorchPosterior
@@ -79,41 +79,54 @@ def __init__(self, distribution: MultivariateNormal) -> None:
             else distribution.variance.unsqueeze(-1)
         )
 
+        self._mixture_mean: Optional[Tensor] = None
+        self._mixture_variance: Optional[Tensor] = None
+
+        # using lru_cache on methods can cause memory leaks. See flake8 B019
+        # So we define a function here instead, to be called by self.quantile
+        @lru_cache
+        def _quantile(value: Tensor) -> Tensor:
+            r"""Compute the posterior quantiles for the mixture of models."""
+            if value.numel() > 1:
+                return torch.stack([self.quantile(v) for v in value], dim=0)
+            if value <= 0 or value >= 1:
+                raise ValueError("value is expected to be in the range (0, 1).")
+            dist = torch.distributions.Normal(loc=self.mean, scale=self.variance.sqrt())
+            if self.mean.shape[MCMC_DIM] == 1:  # Analytical solution
+                return dist.icdf(value).squeeze(MCMC_DIM)
+            icdf_val = dist.icdf(value)
+            low = icdf_val.min(dim=MCMC_DIM).values - TOL
+            high = icdf_val.max(dim=MCMC_DIM).values + TOL
+            bounds = torch.cat((low.unsqueeze(0), high.unsqueeze(0)), dim=0)
+            return batched_bisect(
+                f=lambda x: dist.cdf(x.unsqueeze(MCMC_DIM)).mean(dim=MCMC_DIM),
+                target=value.item(),
+                bounds=bounds,
+            )
+
+        self._quantile = _quantile
+
     @property
-    @lru_cache(maxsize=None)
     def mixture_mean(self) -> Tensor:
         r"""The posterior mean for the mixture of models."""
-        return self._mean.mean(dim=MCMC_DIM)
+        if self._mixture_mean is None:
+            self._mixture_mean = self._mean.mean(dim=MCMC_DIM)
+        return self._mixture_mean
 
     @property
-    @lru_cache(maxsize=None)
     def mixture_variance(self) -> Tensor:
         r"""The posterior variance for the mixture of models."""
-        num_mcmc_samples = self.mean.shape[MCMC_DIM]
-        t1 = self._variance.sum(dim=MCMC_DIM) / num_mcmc_samples
-        t2 = self._mean.pow(2).sum(dim=MCMC_DIM) / num_mcmc_samples
-        t3 = -(self._mean.sum(dim=MCMC_DIM) / num_mcmc_samples).pow(2)
-        return t1 + t2 + t3
+        if self._mixture_variance is None:
+            num_mcmc_samples = self.mean.shape[MCMC_DIM]
+            t1 = self._variance.sum(dim=MCMC_DIM) / num_mcmc_samples
+            t2 = self._mean.pow(2).sum(dim=MCMC_DIM) / num_mcmc_samples
+            t3 = -(self._mean.sum(dim=MCMC_DIM) / num_mcmc_samples).pow(2)
+            self._mixture_variance = t1 + t2 + t3
+        return self._mixture_variance
 
-    @lru_cache(maxsize=None)
     def quantile(self, value: Tensor) -> Tensor:
         r"""Compute the posterior quantiles for the mixture of models."""
-        if value.numel() > 1:
-            return torch.stack([self.quantile(v) for v in value], dim=0)
-        if value <= 0 or value >= 1:
-            raise ValueError("value is expected to be in the range (0, 1).")
-        dist = torch.distributions.Normal(loc=self.mean, scale=self.variance.sqrt())
-        if self.mean.shape[MCMC_DIM] == 1:  # Analytical solution
-            return dist.icdf(value).squeeze(MCMC_DIM)
-        icdf_val = dist.icdf(value)
-        low = icdf_val.min(dim=MCMC_DIM).values - TOL
-        high = icdf_val.max(dim=MCMC_DIM).values + TOL
-        bounds = torch.cat((low.unsqueeze(0), high.unsqueeze(0)), dim=0)
-        return batched_bisect(
-            f=lambda x: dist.cdf(x.unsqueeze(MCMC_DIM)).mean(dim=MCMC_DIM),
-            target=value.item(),
-            bounds=bounds,
-        )
+        return self._quantile(value)
 
     @property
     def batch_range(self) -> Tuple[int, int]:
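
For reference, a condensed, hypothetical sketch of the two caching patterns
the diff adopts (MixturePosterior and its members are illustrative, not
BoTorch API):

    from functools import lru_cache
    from typing import Callable, List, Optional

    class MixturePosterior:
        def __init__(self, samples: List[float]) -> None:
            self._samples = samples
            # Pattern 1: a None sentinel on the instance replaces
            # @lru_cache on the property; the cached value is freed
            # together with the instance.
            self._mixture_mean: Optional[float] = None

            # Pattern 2: lru_cache decorates a local closure rather than
            # the method, so the cache is owned by this instance; the
            # resulting self -> _quantile -> self reference cycle is
            # still collectable by the garbage collector.
            @lru_cache
            def _quantile(q: float) -> float:
                ordered = sorted(self._samples)
                return ordered[round(q * (len(ordered) - 1))]

            self._quantile: Callable[[float], float] = _quantile

        @property
        def mixture_mean(self) -> float:
            if self._mixture_mean is None:  # computed at most once
                self._mixture_mean = sum(self._samples) / len(self._samples)
            return self._mixture_mean

        def quantile(self, q: float) -> float:
            return self._quantile(q)

Here MixturePosterior([1.0, 3.0]).quantile(0.5) is computed once per distinct
q, and dropping the instance drops its cache with it, which is exactly the
property the commit is after.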
