diff --git a/botorch/posteriors/fully_bayesian.py b/botorch/posteriors/fully_bayesian.py index 3084a743d2..9b36553f36 100644 --- a/botorch/posteriors/fully_bayesian.py +++ b/botorch/posteriors/fully_bayesian.py @@ -3,8 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations -from functools import lru_cache from typing import Callable, Optional, Tuple import torch @@ -54,6 +54,30 @@ def batched_bisect( return center +def _quantile(posterior: FullyBayesianPosterior, value: Tensor) -> Tensor: + r"""Compute the posterior quantiles for the mixture of models.""" + if value.numel() > 1: + return torch.stack( + [_quantile(posterior=posterior, value=v) for v in value], dim=0 + ) + if value <= 0 or value >= 1: + raise ValueError("value is expected to be in the range (0, 1).") + dist = torch.distributions.Normal( + loc=posterior.mean, scale=posterior.variance.sqrt() + ) + if posterior.mean.shape[MCMC_DIM] == 1: # Analytical solution + return dist.icdf(value).squeeze(MCMC_DIM) + icdf_val = dist.icdf(value) + low = icdf_val.min(dim=MCMC_DIM).values - TOL + high = icdf_val.max(dim=MCMC_DIM).values + TOL + bounds = torch.cat((low.unsqueeze(0), high.unsqueeze(0)), dim=0) + return batched_bisect( + f=lambda x: dist.cdf(x.unsqueeze(MCMC_DIM)).mean(dim=MCMC_DIM), + target=value.item(), + bounds=bounds, + ) + + class FullyBayesianPosterior(GPyTorchPosterior): r"""A posterior for a fully Bayesian model. @@ -82,30 +106,6 @@ def __init__(self, distribution: MultivariateNormal) -> None: self._mixture_mean: Optional[Tensor] = None self._mixture_variance: Optional[Tensor] = None - # using lru_cache on methods can cause memory leaks. See flake8 B019 - # So we define a function here instead, to be called by self.quantile - @lru_cache - def _quantile(value: Tensor) -> Tensor: - r"""Compute the posterior quantiles for the mixture of models.""" - if value.numel() > 1: - return torch.stack([self.quantile(v) for v in value], dim=0) - if value <= 0 or value >= 1: - raise ValueError("value is expected to be in the range (0, 1).") - dist = torch.distributions.Normal(loc=self.mean, scale=self.variance.sqrt()) - if self.mean.shape[MCMC_DIM] == 1: # Analytical solution - return dist.icdf(value).squeeze(MCMC_DIM) - icdf_val = dist.icdf(value) - low = icdf_val.min(dim=MCMC_DIM).values - TOL - high = icdf_val.max(dim=MCMC_DIM).values + TOL - bounds = torch.cat((low.unsqueeze(0), high.unsqueeze(0)), dim=0) - return batched_bisect( - f=lambda x: dist.cdf(x.unsqueeze(MCMC_DIM)).mean(dim=MCMC_DIM), - target=value.item(), - bounds=bounds, - ) - - self._quantile = _quantile - @property def mixture_mean(self) -> Tensor: r"""The posterior mean for the mixture of models.""" @@ -126,7 +126,7 @@ def mixture_variance(self) -> Tensor: def quantile(self, value: Tensor) -> Tensor: r"""Compute the posterior quantiles for the mixture of models.""" - return self._quantile(value) + return _quantile(posterior=self, value=value) @property def batch_range(self) -> Tuple[int, int]: