Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

approximate qPI using MVNXPB #1684

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions botorch/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
NoisyExpectedImprovement,
PosteriorMean,
ProbabilityOfImprovement,
qAnalyticProbabilityOfImprovement,
UpperConfidenceBound,
)
from botorch.acquisition.cost_aware import (
Expand Down Expand Up @@ -77,6 +78,7 @@
"ProbabilityOfImprovement",
"ProximalAcquisitionFunction",
"UpperConfidenceBound",
"qAnalyticProbabilityOfImprovement",
"qExpectedImprovement",
"qKnowledgeGradient",
"MaxValueBase",
Expand Down
67 changes: 67 additions & 0 deletions botorch/acquisition/analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from contextlib import nullcontext
from copy import deepcopy

from typing import Dict, Optional, Tuple, Union

import torch
Expand All @@ -27,6 +28,7 @@
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import Model
from botorch.utils.constants import get_constants_like
from botorch.utils.probability import MVNXPB
from botorch.utils.probability.utils import (
log_ndtr as log_Phi,
log_phi,
Expand All @@ -37,6 +39,7 @@
from botorch.utils.safe_math import log1mexp, logmeanexp
from botorch.utils.transforms import convert_to_target_pre_hook, t_batch_mode_transform
from torch import Tensor
from torch.nn.functional import pad

_sqrt_2pi = math.sqrt(2 * math.pi)
# the following two numbers are needed for _log_ei_helper
Expand Down Expand Up @@ -231,6 +234,70 @@ def forward(self, X: Tensor) -> Tensor:
return Phi(u)


class qAnalyticProbabilityOfImprovement(AnalyticAcquisitionFunction):
    r"""Approximate batch Probability of Improvement for a single outcome.

    The quantity of interest is
    `approx-qPI(X) = P(max Y >= best_f) = 1 - P(Y < best_f)`, with `Y ~ f(X)`
    and `X = (x_1,...,x_q)`. The term `P(Y < best_f)` is a multivariate normal
    box probability `P(a <= Y <= b)` and is estimated with MVNXPB, a bivariate
    conditioning algorithm; see [Trinh2015bivariate]_.
    """

    def __init__(
        self,
        model: Model,
        best_f: Union[float, Tensor],
        posterior_transform: Optional[PosteriorTransform] = None,
        maximize: bool = True,
        **kwargs,
    ) -> None:
        """Set up the analytic qPI approximation.

        Args:
            model: A fitted single-outcome model.
            best_f: The best function value observed so far (assumed
                noiseless), given either as a scalar or as a `b`-dim Tensor
                (batch mode).
            posterior_transform: A PosteriorTransform. Required for
                multi-output models, in which case it must map the
                multi-output posterior to a single-output posterior.
            maximize: If True, treat the problem as a maximization problem;
                otherwise improvement means falling below `best_f`.
            **kwargs: Forwarded unchanged to the parent class constructor.
        """
        super().__init__(model=model, posterior_transform=posterior_transform, **kwargs)
        self.maximize = maximize
        # Plain numbers are wrapped so that `best_f` can live in a buffer.
        threshold = best_f if torch.is_tensor(best_f) else torch.tensor(best_f)
        self.register_buffer("best_f", threshold)

    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        """Evaluate the approximate qPI on the candidate set `X`.

        Args:
            X: A `batch_shape x q x d`-dim Tensor of t-batches with `q`
                `d`-dim design points each.

        Returns:
            A `batch_shape'`-dim Tensor of approximate Probability of
            Improvement values at the given design points `X`, where
            `batch_shape'` is the broadcasted batch shape of model and
            input `X`.
        """
        # Note: this re-assigns the buffer, so the cast to X's dtype / device
        # persists on the module after the call.
        self.best_f = self.best_f.to(X)
        mvn = self.model.posterior(
            X=X, posterior_transform=self.posterior_transform
        ).distribution
        covariance = mvn.covariance_matrix

        # Center the threshold on the posterior mean and add a trailing dim
        # that `pad` then widens to the (lower, upper) pair.
        centered = (self.best_f.unsqueeze(-1) - mvn.mean).unsqueeze(-1)
        if self.maximize:
            # No improvement <=> every Y_i <= best_f: box (-inf, best_f - mean].
            pad_widths, fill_value = (1, 0), -float("inf")
        else:
            # No improvement <=> every Y_i >= best_f: box [best_f - mean, inf).
            pad_widths, fill_value = (0, 1), float("inf")
        bounds = pad(centered, pad=pad_widths, value=fill_value)

        # 1 - P(no improvement over best_f), computed as -expm1 of the
        # solver's result.
        no_improvement = MVNXPB(covariance_matrix=covariance, bounds=bounds).solve()
        return -no_improvement.expm1()


class ExpectedImprovement(AnalyticAcquisitionFunction):
r"""Single-outcome Expected Improvement (analytic).

Expand Down
156 changes: 156 additions & 0 deletions test/acquisition/test_analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math

import torch
from botorch.acquisition import qAnalyticProbabilityOfImprovement
from botorch.acquisition.analytic import (
_compute_log_prob_feas,
_ei_helper,
Expand Down Expand Up @@ -362,6 +363,161 @@ def test_probability_of_improvement_batch(self):
LogProbabilityOfImprovement(model=mm2, best_f=0.0)


class TestqAnalyticProbabilityOfImprovement(BotorchTestCase):
    """Tests `qAnalyticProbabilityOfImprovement` against mocked posteriors.

    All mocked posteriors use identity covariance, so the expected values
    can be checked by hand from the standard normal CDF.
    """

    def _mock_model(self, mean, cov, multitask=False):
        """Return a `MockModel` whose posterior is an MVN with these moments."""
        mvn_type = MultitaskMultivariateNormal if multitask else MultivariateNormal
        mvn = mvn_type(mean=mean, covariance_matrix=cov)
        return MockModel(GPyTorchPosterior(mvn))

    def _assert_pi(self, acqf, X, expected, tkwargs):
        """Evaluate `acqf` at `X` and compare with `expected` (atol=1e-4)."""
        pi = acqf(X)
        pi_expected = torch.tensor(expected, **tkwargs)
        self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

    def test_q_analytic_probability_of_improvement(self):
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}

            # q = 1, zero mean, unit variance
            mm = self._mock_model(
                mean=torch.zeros(1, **tkwargs), cov=torch.eye(n=1, **tkwargs)
            )

            # default (maximize=True): P(Y > 1.96)
            acqf = qAnalyticProbabilityOfImprovement(model=mm, best_f=1.96)
            self._assert_pi(acqf, torch.rand(1, 2, **tkwargs), 0.0250, tkwargs)

            # same model with maximize=False (minimization): P(Y < 1.96)
            acqf = qAnalyticProbabilityOfImprovement(
                model=mm, best_f=1.96, maximize=False
            )
            self._assert_pi(acqf, torch.rand(1, 2, **tkwargs), 0.9750, tkwargs)

            # scalarized posterior transform on a single-output posterior
            mm = self._mock_model(
                mean=torch.ones(1, **tkwargs), cov=torch.eye(n=1, **tkwargs)
            )
            transform = ScalarizedPosteriorTransform(torch.tensor([0.5], **tkwargs))
            acqf = qAnalyticProbabilityOfImprovement(
                model=mm, best_f=0.0, posterior_transform=transform
            )
            self._assert_pi(acqf, torch.rand(1, 2, **tkwargs), 0.8413, tkwargs)

            # scalarized posterior transform on a two-output posterior
            mm = self._mock_model(
                mean=torch.ones(1, 2, **tkwargs),
                cov=torch.eye(n=2, **tkwargs).unsqueeze(0),
                multitask=True,
            )
            transform = ScalarizedPosteriorTransform(torch.ones(2, **tkwargs))
            acqf = qAnalyticProbabilityOfImprovement(
                model=mm, best_f=0.0, posterior_transform=transform
            )
            self._assert_pi(acqf, torch.rand(1, 1, **tkwargs), 0.9214, tkwargs)

            # q = 2 with independent standard normals: 1 - 0.975 ** 2
            mm = self._mock_model(
                mean=torch.zeros(2, **tkwargs), cov=torch.eye(n=2, **tkwargs)
            )
            acqf = qAnalyticProbabilityOfImprovement(model=mm, best_f=1.96)
            self._assert_pi(acqf, torch.zeros(2, 2, **tkwargs), 0.049375, tkwargs)

    def test_batch_q_analytic_probability_of_improvement(self):
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}

            # batch of two posteriors with different means
            mm = self._mock_model(
                mean=torch.tensor([[0.0], [1.0]], **tkwargs),
                cov=torch.eye(n=1, **tkwargs).unsqueeze(0).repeat(2, 1, 1),
            )
            acqf = qAnalyticProbabilityOfImprovement(model=mm, best_f=0)
            self._assert_pi(
                acqf, torch.rand(2, 1, 1, **tkwargs), [0.5, 0.8413], tkwargs
            )

            # batched model together with one best_f value per batch
            mm = self._mock_model(
                mean=torch.zeros(2, 1, **tkwargs),
                cov=torch.eye(n=1, **tkwargs).unsqueeze(0).repeat(2, 1, 1),
            )
            best_f = torch.tensor([0.0, -1.0], **tkwargs)
            acqf = qAnalyticProbabilityOfImprovement(model=mm, best_f=best_f)
            # the expected tensor carries an extra leading dim; the allclose
            # comparison broadcasts it against the result
            self._assert_pi(
                acqf, torch.rand(2, 1, 1, **tkwargs), [[0.5, 0.8413]], tkwargs
            )

            # batched model with a scalarization of a single output
            mm = self._mock_model(
                mean=torch.tensor([[0.0], [1.0]], **tkwargs),
                cov=torch.eye(n=1, **tkwargs).unsqueeze(0).repeat(2, 1, 1),
            )
            transform = ScalarizedPosteriorTransform(torch.tensor([0.5], **tkwargs))
            acqf = qAnalyticProbabilityOfImprovement(
                model=mm, best_f=0.0, posterior_transform=transform
            )
            self._assert_pi(
                acqf, torch.rand(2, 1, 2, **tkwargs), [0.5, 0.8413], tkwargs
            )

            # batched model with a scalarization of two outputs
            mm = self._mock_model(
                mean=torch.tensor([[[1.0, 1.0]], [[0.0, 1.0]]], **tkwargs),
                cov=torch.eye(n=2, **tkwargs).unsqueeze(0).repeat(2, 1, 1),
                multitask=True,
            )
            transform = ScalarizedPosteriorTransform(torch.ones(2, **tkwargs))
            acqf = qAnalyticProbabilityOfImprovement(
                model=mm, best_f=0.0, posterior_transform=transform
            )
            self._assert_pi(
                acqf, torch.rand(2, 1, 2, **tkwargs), [0.9214, 0.7602], tkwargs
            )

            # an MC objective is not an accepted posterior transform
            with self.assertRaises(UnsupportedError):
                qAnalyticProbabilityOfImprovement(
                    model=mm, best_f=0.0, posterior_transform=IdentityMCObjective()
                )


class TestUpperConfidenceBound(BotorchTestCase):
def test_upper_confidence_bound(self):
for dtype in (torch.float, torch.double):
Expand Down