From aeda3828913e62d9360fd306035a6af073462de3 Mon Sep 17 00:00:00 2001 From: Daniel Jiang Date: Tue, 21 Feb 2023 12:33:45 -0800 Subject: [PATCH] approximate qPI using MVNXPB (#1684) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1684 This is work by jiayuewan during his internship. I am simply moving it to OSS. Reviewed By: Balandat Differential Revision: D43337388 fbshipit-source-id: ad36066ab16d322d9deeed99110b090a273065d2 --- botorch/acquisition/__init__.py | 2 + botorch/acquisition/analytic.py | 67 +++++++++++++ test/acquisition/test_analytic.py | 156 ++++++++++++++++++++++++++++++ 3 files changed, 225 insertions(+) diff --git a/botorch/acquisition/__init__.py b/botorch/acquisition/__init__.py index b110d09530..f7f23f0729 100644 --- a/botorch/acquisition/__init__.py +++ b/botorch/acquisition/__init__.py @@ -19,6 +19,7 @@ NoisyExpectedImprovement, PosteriorMean, ProbabilityOfImprovement, + qAnalyticProbabilityOfImprovement, UpperConfidenceBound, ) from botorch.acquisition.cost_aware import ( @@ -77,6 +78,7 @@ "ProbabilityOfImprovement", "ProximalAcquisitionFunction", "UpperConfidenceBound", + "qAnalyticProbabilityOfImprovement", "qExpectedImprovement", "qKnowledgeGradient", "MaxValueBase", diff --git a/botorch/acquisition/analytic.py b/botorch/acquisition/analytic.py index 041f83c1a7..74165978a1 100644 --- a/botorch/acquisition/analytic.py +++ b/botorch/acquisition/analytic.py @@ -17,6 +17,7 @@ from contextlib import nullcontext from copy import deepcopy + from typing import Dict, Optional, Tuple, Union import torch @@ -27,6 +28,7 @@ from botorch.models.gpytorch import GPyTorchModel from botorch.models.model import Model from botorch.utils.constants import get_constants_like +from botorch.utils.probability import MVNXPB from botorch.utils.probability.utils import ( log_ndtr as log_Phi, log_phi, @@ -37,6 +39,7 @@ from botorch.utils.safe_math import log1mexp, logmeanexp from botorch.utils.transforms import convert_to_target_pre_hook, 
t_batch_mode_transform from torch import Tensor +from torch.nn.functional import pad _sqrt_2pi = math.sqrt(2 * math.pi) # the following two numbers are needed for _log_ei_helper @@ -231,6 +234,70 @@ def forward(self, X: Tensor) -> Tensor: return Phi(u) +class qAnalyticProbabilityOfImprovement(AnalyticAcquisitionFunction): + r"""Approximate, single-outcome batch Probability of Improvement using MVNXPB. + + This implementation uses MVNXPB, a bivariate conditioning algorithm for + approximating P(a <= Y <= b) for multivariate normal Y. + See [Trinh2015bivariate]_. This (analytic) approximate q-PI is given by + `approx-qPI(X) = P(max Y >= best_f) = 1 - P(Y < best_f), Y ~ f(X), + X = (x_1,...,x_q)`, where `P(Y < best_f)` is estimated using MVNXPB. + """ + + def __init__( + self, + model: Model, + best_f: Union[float, Tensor], + posterior_transform: Optional[PosteriorTransform] = None, + maximize: bool = True, + **kwargs, + ) -> None: + """qPI using an analytic approximation. + + Args: + model: A fitted single-outcome model. + best_f: Either a scalar or a `b`-dim Tensor (batch mode) representing + the best function value observed so far (assumed noiseless). + posterior_transform: A PosteriorTransform. If using a multi-output model, + a PosteriorTransform that transforms the multi-output posterior into a + single-output posterior is required. + maximize: If True, consider the problem a maximization problem. + """ + super().__init__(model=model, posterior_transform=posterior_transform, **kwargs) + self.maximize = maximize + if not torch.is_tensor(best_f): + best_f = torch.tensor(best_f) + self.register_buffer("best_f", best_f) + + @t_batch_mode_transform() + def forward(self, X: Tensor) -> Tensor: + """Evaluate approximate qPI on the candidate set X. 
+ + Args: + X: A `batch_shape x q x d`-dim Tensor of t-batches with `q` `d`-dim design + points each + + Returns: + A `batch_shape'`-dim Tensor of approximate Probability of Improvement values + at the given design points `X`, where `batch_shape'` is the broadcasted + batch shape of model and input `X`. + """ + self.best_f = self.best_f.to(X) + posterior = self.model.posterior( + X=X, posterior_transform=self.posterior_transform + ) + + covariance = posterior.distribution.covariance_matrix + bounds = pad( + (self.best_f.unsqueeze(-1) - posterior.distribution.mean).unsqueeze(-1), + pad=(1, 0) if self.maximize else (0, 1), + value=-float("inf") if self.maximize else float("inf"), + ) + # 1 - P(no improvement over best_f) + solver = MVNXPB(covariance_matrix=covariance, bounds=bounds) + return -solver.solve().expm1() + + class ExpectedImprovement(AnalyticAcquisitionFunction): r"""Single-outcome Expected Improvement (analytic). diff --git a/test/acquisition/test_analytic.py b/test/acquisition/test_analytic.py index 03bae7a5cb..d9e7926ba0 100644 --- a/test/acquisition/test_analytic.py +++ b/test/acquisition/test_analytic.py @@ -7,6 +7,7 @@ import math import torch +from botorch.acquisition import qAnalyticProbabilityOfImprovement from botorch.acquisition.analytic import ( _compute_log_prob_feas, _ei_helper, @@ -362,6 +363,161 @@ def test_probability_of_improvement_batch(self): LogProbabilityOfImprovement(model=mm2, best_f=0.0) +class TestqAnalyticProbabilityOfImprovement(BotorchTestCase): + def test_q_analytic_probability_of_improvement(self): + for dtype in (torch.float, torch.double): + mean = torch.zeros(1, device=self.device, dtype=dtype) + cov = torch.eye(n=1, device=self.device, dtype=dtype) + mvn = MultivariateNormal(mean=mean, covariance_matrix=cov) + posterior = GPyTorchPosterior(mvn) + mm = MockModel(posterior) + + # basic test + module = qAnalyticProbabilityOfImprovement(model=mm, best_f=1.96) + X = torch.rand(1, 2, device=self.device, dtype=dtype) + pi = 
module(X) + pi_expected = torch.tensor(0.0250, device=self.device, dtype=dtype) + self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4)) + + # basic test, minimize + module = qAnalyticProbabilityOfImprovement( + model=mm, best_f=1.96, maximize=False + ) + X = torch.rand(1, 2, device=self.device, dtype=dtype) + pi = module(X) + pi_expected = torch.tensor(0.9750, device=self.device, dtype=dtype) + self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4)) + + # basic test, posterior transform (single-output) + mean = torch.ones(1, device=self.device, dtype=dtype) + cov = torch.eye(n=1, device=self.device, dtype=dtype) + mvn = MultivariateNormal(mean=mean, covariance_matrix=cov) + posterior = GPyTorchPosterior(mvn) + mm = MockModel(posterior) + weights = torch.tensor([0.5], device=self.device, dtype=dtype) + transform = ScalarizedPosteriorTransform(weights) + module = qAnalyticProbabilityOfImprovement( + model=mm, best_f=0.0, posterior_transform=transform + ) + X = torch.rand(1, 2, device=self.device, dtype=dtype) + pi = module(X) + pi_expected = torch.tensor(0.8413, device=self.device, dtype=dtype) + self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4)) + + # basic test, posterior transform (multi-output) + mean = torch.ones(1, 2, device=self.device, dtype=dtype) + cov = torch.eye(n=2, device=self.device, dtype=dtype).unsqueeze(0) + mvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov) + posterior = GPyTorchPosterior(mvn) + mm = MockModel(posterior) + weights = torch.ones(2, device=self.device, dtype=dtype) + transform = ScalarizedPosteriorTransform(weights) + module = qAnalyticProbabilityOfImprovement( + model=mm, best_f=0.0, posterior_transform=transform + ) + X = torch.rand(1, 1, device=self.device, dtype=dtype) + pi = module(X) + pi_expected = torch.tensor(0.9214, device=self.device, dtype=dtype) + self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4)) + + # basic test, q = 2 + mean = torch.zeros(2, device=self.device, dtype=dtype) 
+ cov = torch.eye(n=2, device=self.device, dtype=dtype) + mvn = MultivariateNormal(mean=mean, covariance_matrix=cov) + posterior = GPyTorchPosterior(mvn) + mm = MockModel(posterior) + module = qAnalyticProbabilityOfImprovement(model=mm, best_f=1.96) + X = torch.zeros(2, 2, device=self.device, dtype=dtype) + pi = module(X) + pi_expected = torch.tensor(0.049375, device=self.device, dtype=dtype) + self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4)) + + def test_batch_q_analytic_probability_of_improvement(self): + for dtype in (torch.float, torch.double): + # test batch mode + mean = torch.tensor([[0.0], [1.0]], device=self.device, dtype=dtype) + cov = ( + torch.eye(n=1, device=self.device, dtype=dtype) + .unsqueeze(0) + .repeat(2, 1, 1) + ) + mvn = MultivariateNormal(mean=mean, covariance_matrix=cov) + posterior = GPyTorchPosterior(mvn) + mm = MockModel(posterior) + module = qAnalyticProbabilityOfImprovement(model=mm, best_f=0) + X = torch.rand(2, 1, 1, device=self.device, dtype=dtype) + pi = module(X) + pi_expected = torch.tensor([0.5, 0.8413], device=self.device, dtype=dtype) + self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4)) + + # test batched model and best_f values + mean = torch.zeros(2, 1, device=self.device, dtype=dtype) + cov = ( + torch.eye(n=1, device=self.device, dtype=dtype) + .unsqueeze(0) + .repeat(2, 1, 1) + ) + mvn = MultivariateNormal(mean=mean, covariance_matrix=cov) + posterior = GPyTorchPosterior(mvn) + mm = MockModel(posterior) + best_f = torch.tensor([0.0, -1.0], device=self.device, dtype=dtype) + module = qAnalyticProbabilityOfImprovement(model=mm, best_f=best_f) + X = torch.rand(2, 1, 1, device=self.device, dtype=dtype) + pi = module(X) + pi_expected = torch.tensor([[0.5, 0.8413]], device=self.device, dtype=dtype) + self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4)) + + # test batched model, output transform (single output) + mean = torch.tensor([[0.0], [1.0]], device=self.device, dtype=dtype) + cov = ( + 
torch.eye(n=1, device=self.device, dtype=dtype) + .unsqueeze(0) + .repeat(2, 1, 1) + ) + mvn = MultivariateNormal(mean=mean, covariance_matrix=cov) + posterior = GPyTorchPosterior(mvn) + mm = MockModel(posterior) + weights = torch.tensor([0.5], device=self.device, dtype=dtype) + transform = ScalarizedPosteriorTransform(weights) + module = qAnalyticProbabilityOfImprovement( + model=mm, best_f=0.0, posterior_transform=transform + ) + X = torch.rand(2, 1, 2, device=self.device, dtype=dtype) + pi = module(X) + pi_expected = torch.tensor([0.5, 0.8413], device=self.device, dtype=dtype) + self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4)) + + # test batched model, output transform (multiple output) + mean = torch.tensor( + [[[1.0, 1.0]], [[0.0, 1.0]]], device=self.device, dtype=dtype + ) + cov = ( + torch.eye(n=2, device=self.device, dtype=dtype) + .unsqueeze(0) + .repeat(2, 1, 1) + ) + mvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov) + posterior = GPyTorchPosterior(mvn) + mm = MockModel(posterior) + weights = torch.ones(2, device=self.device, dtype=dtype) + transform = ScalarizedPosteriorTransform(weights) + module = qAnalyticProbabilityOfImprovement( + model=mm, best_f=0.0, posterior_transform=transform + ) + X = torch.rand(2, 1, 2, device=self.device, dtype=dtype) + pi = module(X) + pi_expected = torch.tensor( + [0.9214, 0.7602], device=self.device, dtype=dtype + ) + self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4)) + + # test bad posterior transform class + with self.assertRaises(UnsupportedError): + qAnalyticProbabilityOfImprovement( + model=mm, best_f=0.0, posterior_transform=IdentityMCObjective() + ) + + class TestUpperConfidenceBound(BotorchTestCase): def test_upper_confidence_bound(self): for dtype in (torch.float, torch.double):