Skip to content

Commit

Permalink
Adjust PairwiseGP ScaleKernel prior (#1460)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1460

Update the prior on PairwiseGP's output scale.
Additionally, require that PairwiseGP be used with a `ScaleKernel`, improve the initialization of the inferred utility values, and replace `_batch_chol_inv` with `torch.cholesky_inverse`.

TL;DR: we were previously using a significantly restrictive prior on the output scale theta (note theta = 1/sigma^2, where sigma is the probit noise on the function value), which prevented us from accommodating comparison errors outside the range of the green line.

Reviewed By: Balandat

Differential Revision: D40136741

fbshipit-source-id: 7f48b162ef529416f5acbb6045d92b0e28b62255
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed Oct 21, 2022
1 parent 2ea11a6 commit 6e7a346
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 29 deletions.
65 changes: 38 additions & 27 deletions botorch/models/pairwise_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from gpytorch import settings
from gpytorch.constraints import GreaterThan
from gpytorch.constraints import GreaterThan, Interval
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels.rbf_kernel import RBFKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
Expand Down Expand Up @@ -147,7 +147,7 @@ def __init__(

# Set optional parameters
# Explicitly set jitter for numerical stability in psd_safe_cholesky
self._jitter = kwargs.get("jitter", 1e-5)
self._jitter = kwargs.get("jitter", 1e-6)
# Stopping criterion in scipy.optimize.fsolve used to find f_map in _update()
# If None, it is set to 1e-6 by default in _update
self._xtol = kwargs.get("xtol")
Expand All @@ -170,6 +170,7 @@ def __init__(
# estimates away from scale value that would make Phi(f(x)) saturate
# at 0 or 1
if covar_module is None:
os_lb, os_ub = 1e-2, 1e2
ls_prior = GammaPrior(1.2, 0.5)
ls_prior_mode = (ls_prior.concentration - 1) / ls_prior.rate
covar_module = ScaleKernel(
Expand All @@ -181,9 +182,16 @@ def __init__(
lower_bound=1e-4, transform=None, initial_value=ls_prior_mode
),
),
outputscale_prior=SmoothedBoxPrior(a=1, b=4),
outputscale_prior=SmoothedBoxPrior(a=os_lb, b=os_ub),
# make sure we won't get extreme values for the output scale
outputscale_constraint=Interval(
lower_bound=os_lb * 0.5,
upper_bound=os_ub * 2.0,
initial_value=1.0,
),
)

if not isinstance(covar_module, ScaleKernel):
raise UnsupportedError("PairwiseGP must be used with a ScaleKernel.")
self.covar_module = covar_module

self._x0 = None # will store temporary results for warm-starting
Expand Down Expand Up @@ -225,6 +233,16 @@ def __deepcopy__(self, memo) -> PairwiseGP:
self.__deepcopy__ = dcp
return new_model

def _scaled_psd_safe_cholesky(
    self, M: Tensor, jitter: Optional[float] = None
) -> Tensor:
    r"""Cholesky-factorize ``M`` after normalizing it by the kernel output scale.

    ``M`` is divided by ``self.covar_module.outputscale`` (broadcast over the
    trailing two matrix dimensions) before being handed to
    ``psd_safe_cholesky``; the resulting factor is then multiplied by the
    square root of the output scale so that it corresponds to the unscaled
    ``M``. The normalization is done for better numerical stability.

    Args:
        M: The matrix (or batch of matrices) to factorize.
        jitter: Forwarded unchanged to ``psd_safe_cholesky``. Note that it is
            applied to the normalized matrix ``M / outputscale``, not to ``M``.

    Returns:
        The Cholesky factor returned by ``psd_safe_cholesky`` for the
        normalized matrix, rescaled by ``sqrt(outputscale)``.
    """
    # add two trailing singleton dims so the scale broadcasts over matrices
    outputscale = self.covar_module.outputscale[..., None, None]
    normalized_chol = psd_safe_cholesky(M / outputscale, jitter=jitter)
    # undo the normalization: chol(M / s) * sqrt(s) is a factor of M
    return normalized_chol * outputscale.sqrt()

def _has_no_data(self):
r"""Return true if the model does not have both datapoints and comparisons"""
return (
Expand All @@ -238,24 +256,6 @@ def _calc_covar(self, X1: Tensor, X2: Tensor) -> Union[Tensor, LinearOperator]:
covar = self.covar_module(X1, X2)
return covar.to_dense()

def _batch_chol_inv(self, mat_chol: Tensor) -> Tensor:
    r"""Invert a (batch of) matrices given their lower Cholesky factors.

    Delegates to ``torch.cholesky_inverse``, which accepts batched input, so
    no hand-rolled triangular solve against an identity matrix is needed.
    This also avoids two problems of the previous implementation: the result
    was left unbound for inputs that were neither 2-D nor a square batch, and
    the identity matrix took its dtype/device from ``self.datapoints`` rather
    than from ``mat_chol``.

    Args:
        mat_chol: A ``(batch_shape) x n x n`` lower-triangular Cholesky
            factor ``L``.

    Returns:
        A ``(batch_shape) x n x n`` tensor holding the inverse of
        ``L @ L^T``, with the same dtype and device as ``mat_chol``.
    """
    # upper=False is the default, matching the lower-triangular factors
    # that the previous solve_triangular(..., upper=False) path assumed
    return torch.cholesky_inverse(mat_chol)

def _update_covar(self, datapoints: Tensor) -> None:
r"""Update values derived from the data and hyperparameters
Expand All @@ -265,8 +265,10 @@ def _update_covar(self, datapoints: Tensor) -> None:
datapoints: (Transformed) datapoints for finding f_max
"""
self.covar = self._calc_covar(datapoints, datapoints)
self.covar_chol = psd_safe_cholesky(self.covar, jitter=self._jitter)
self.covar_inv = self._batch_chol_inv(self.covar_chol)
self.covar_chol = self._scaled_psd_safe_cholesky(
self.covar, jitter=self._jitter
)
self.covar_inv = torch.cholesky_inverse(self.covar_chol)

def _prior_mean(self, X: Tensor) -> Union[Tensor, LinearOperator]:
r"""Return point prediction using prior only
Expand Down Expand Up @@ -417,7 +419,17 @@ def _update(self, datapoints: Tensor, **kwargs) -> None:
# warm start
init_x0_size = self.batch_shape + torch.Size([self.n])
if self._x0 is None or torch.Size(self._x0.shape) != init_x0_size:
x0 = np.random.rand(*init_x0_size)
sqrt_scale = (
self.covar_module.outputscale.sqrt()
.unsqueeze(-1)
.detach()
.cpu()
.numpy()
)
# initialize x0 using std normal but clip by 3 std to keep it bounded
x0 = np.random.standard_normal(init_x0_size).clip(min=-3, max=3)
# scale x0 to be on roughly the right scale
x0 = x0 * sqrt_scale
else:
x0 = self._x0

Expand Down Expand Up @@ -755,7 +767,6 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
2. Prior predictions (prior mode)
3. Predictive posterior (eval mode)
"""

# Training mode: optimizing
if self.training:
if self._has_no_data():
Expand Down Expand Up @@ -839,7 +850,7 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
# output_covar is sometimes non-PSD
# perform a cholesky decomposition to check and amend
covariance_matrix=RootLinearOperator(
psd_safe_cholesky(output_covar, jitter=self._jitter)
self._scaled_psd_safe_cholesky(output_covar, jitter=self._jitter)
),
)
return post
Expand Down
11 changes: 9 additions & 2 deletions test/models/test_pairwise_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,16 @@ def test_pairwise_gp(self):
self.assertEqual(model.num_outputs, 1)
self.assertEqual(model.batch_shape, batch_shape)

# test not using a ScaleKernel
with self.assertRaisesRegex(RuntimeError, "used with a ScaleKernel"):
PairwiseGP(**model_kwargs, covar_module=LinearKernel())

# test custom models
custom_m = PairwiseGP(**model_kwargs, covar_module=LinearKernel())
self.assertIsInstance(custom_m.covar_module, LinearKernel)
custom_m = PairwiseGP(
**model_kwargs, covar_module=ScaleKernel(LinearKernel())
)
self.assertIsInstance(custom_m.covar_module, ScaleKernel)
self.assertIsInstance(custom_m.covar_module.base_kernel, LinearKernel)

# prior prediction
prior_m = PairwiseGP(None, None).to(**tkwargs)
Expand Down

0 comments on commit 6e7a346

Please sign in to comment.