From 6e7a3469521a11ea00b2e348addffc603582ceca Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Fri, 21 Oct 2022 12:27:55 -0700 Subject: [PATCH] Adjust PairwiseGP ScaleKernel prior (#1460) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1460 Updating the prior of PairwiseGP's output scale prior. Additionally, also make sure it must be used, better initialization of the inferred utility values, and replaced `_batch_chol_inv` with `torch.cholesky_inverse`. TLDR is that we were previously using an significantly restrictive prior on the output scale theta (note theta = 1/sigma^2 where sigma is the probit noise on the function value), this prevent us from accommodating comparison errors outside range of the green line. Reviewed By: Balandat Differential Revision: D40136741 fbshipit-source-id: 7f48b162ef529416f5acbb6045d92b0e28b62255 --- botorch/models/pairwise_gp.py | 65 +++++++++++++++++++-------------- test/models/test_pairwise_gp.py | 11 +++++- 2 files changed, 47 insertions(+), 29 deletions(-) diff --git a/botorch/models/pairwise_gp.py b/botorch/models/pairwise_gp.py index 04b0c8fc1d..5d701b5a47 100644 --- a/botorch/models/pairwise_gp.py +++ b/botorch/models/pairwise_gp.py @@ -37,7 +37,7 @@ from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.posteriors.posterior import Posterior from gpytorch import settings -from gpytorch.constraints import GreaterThan +from gpytorch.constraints import GreaterThan, Interval from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels.rbf_kernel import RBFKernel from gpytorch.kernels.scale_kernel import ScaleKernel @@ -147,7 +147,7 @@ def __init__( # Set optional parameters # Explicitly set jitter for numerical stability in psd_safe_cholesky - self._jitter = kwargs.get("jitter", 1e-5) + self._jitter = kwargs.get("jitter", 1e-6) # Stopping creteria in scipy.optimize.fsolve used to find f_map in _update() # If None, set to 1e-6 by default in _update self._xtol = kwargs.get("xtol") @@ -170,6 +170,7 @@ def __init__( # estimates away from scale value that would make Phi(f(x)) saturate # at 0 or 1 if covar_module is None: + os_lb, os_ub = 1e-2, 1e2 ls_prior = GammaPrior(1.2, 0.5) ls_prior_mode = (ls_prior.concentration - 1) / ls_prior.rate covar_module = ScaleKernel( @@ -181,9 +182,16 @@ def __init__( lower_bound=1e-4, transform=None, initial_value=ls_prior_mode ), ), - outputscale_prior=SmoothedBoxPrior(a=1, b=4), + outputscale_prior=SmoothedBoxPrior(a=os_lb, b=os_ub), + # make sure we won't get extreme values for the output scale + outputscale_constraint=Interval( + lower_bound=os_lb * 0.5, + upper_bound=os_ub * 2.0, + initial_value=1.0, + ), ) - + if not isinstance(covar_module, ScaleKernel): + raise UnsupportedError("PairwiseGP must be used with a ScaleKernel.") self.covar_module = covar_module self._x0 = None # will store temporary results for warm-starting @@ -225,6 +233,16 @@ def __deepcopy__(self, memo) -> PairwiseGP: self.__deepcopy__ = dcp return new_model + def _scaled_psd_safe_cholesky( + self, M: Tensor, jitter: Optional[float] = None + ) -> Tensor: + r"""scale M by 1/outputscale before cholesky for better numerical stability""" + scale = self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1) + M = M / scale + chol = psd_safe_cholesky(M, jitter=jitter) + chol = chol * scale.sqrt() + return chol + def _has_no_data(self): r"""Return true if the model does not have both datapoints and comparisons""" return ( @@ -238,24 +256,6 @@ def _calc_covar(self, X1: Tensor, X2: Tensor) -> Union[Tensor, LinearOperator]: covar = self.covar_module(X1, X2) return covar.to_dense() - def _batch_chol_inv(self, mat_chol: Tensor) -> Tensor: - r"""Wrapper to perform (batched) cholesky inverse""" - # TODO: get rid of this once cholesky_inverse supports batch mode - batch_eye = torch.eye( - mat_chol.shape[-1], - dtype=self.datapoints.dtype, - device=self.datapoints.device, - ) - - if len(mat_chol.shape) == 2: - mat_inv = torch.cholesky_inverse(mat_chol) - elif len(mat_chol.shape) > 2 and (mat_chol.shape[-1] == mat_chol.shape[-2]): - batch_eye = batch_eye.repeat(*(mat_chol.shape[:-2]), 1, 1) - chol_inv = torch.linalg.solve_triangular(mat_chol, batch_eye, upper=False) - mat_inv = chol_inv.transpose(-1, -2) @ chol_inv - - return mat_inv - def _update_covar(self, datapoints: Tensor) -> None: r"""Update values derived from the data and hyperparameters @@ -265,8 +265,10 @@ def _update_covar(self, datapoints: Tensor) -> None: datapoints: (Transformed) datapoints for finding f_max """ self.covar = self._calc_covar(datapoints, datapoints) - self.covar_chol = psd_safe_cholesky(self.covar, jitter=self._jitter) - self.covar_inv = self._batch_chol_inv(self.covar_chol) + self.covar_chol = self._scaled_psd_safe_cholesky( + self.covar, jitter=self._jitter + ) + self.covar_inv = torch.cholesky_inverse(self.covar_chol) def _prior_mean(self, X: Tensor) -> Union[Tensor, LinearOperator]: r"""Return point prediction using prior only @@ -417,7 +419,17 @@ def _update(self, datapoints: Tensor, **kwargs) -> None: # warm start init_x0_size = self.batch_shape + torch.Size([self.n]) if self._x0 is None or torch.Size(self._x0.shape) != init_x0_size: - x0 = np.random.rand(*init_x0_size) + sqrt_scale = ( + self.covar_module.outputscale.sqrt() + .unsqueeze(-1) + .detach() + .cpu() + .numpy() + ) + # initialize x0 using std normal but clip by 3 std to keep it bounded + x0 = np.random.standard_normal(init_x0_size).clip(min=-3, max=3) + # scale x0 to be on roughly the right scale + x0 = x0 * sqrt_scale else: x0 = self._x0 @@ -755,7 +767,6 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal: 2. Prior predictions (prior mode) 3. Predictive posterior (eval mode) """ - # Training mode: optimizing if self.training: if self._has_no_data(): @@ -839,7 +850,7 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal: # output_covar is sometimes non-PSD # perform a cholesky decomposition to check and amend covariance_matrix=RootLinearOperator( - psd_safe_cholesky(output_covar, jitter=self._jitter) + self._scaled_psd_safe_cholesky(output_covar, jitter=self._jitter) ), ) return post diff --git a/test/models/test_pairwise_gp.py b/test/models/test_pairwise_gp.py index 86407cd0f5..eb2da00289 100644 --- a/test/models/test_pairwise_gp.py +++ b/test/models/test_pairwise_gp.py @@ -105,9 +105,16 @@ def test_pairwise_gp(self): self.assertEqual(model.num_outputs, 1) self.assertEqual(model.batch_shape, batch_shape) + # test not using a ScaleKernel + with self.assertRaisesRegex(RuntimeError, "used with a ScaleKernel"): + PairwiseGP(**model_kwargs, covar_module=LinearKernel()) + # test custom models - custom_m = PairwiseGP(**model_kwargs, covar_module=LinearKernel()) - self.assertIsInstance(custom_m.covar_module, LinearKernel) + custom_m = PairwiseGP( + **model_kwargs, covar_module=ScaleKernel(LinearKernel()) + ) + self.assertIsInstance(custom_m.covar_module, ScaleKernel) + self.assertIsInstance(custom_m.covar_module.base_kernel, LinearKernel) # prior prediction prior_m = PairwiseGP(None, None).to(**tkwargs)