Assorted PairwiseGP stability improvements (#1755)
Summary:
Pull Request resolved: #1755

Main changes include:

- Prior update: updated the lengthscale prior (now GammaPrior(concentration=2.4, rate=2.7)) for a better model fit and better numerical stability; the prior's mode, (2.4 - 1) / 2.7 ≈ 0.52, is used as the initial lengthscale value.
- Utility heuristic initialization: previously, we initialized the latent utility (i.e., the latent function values) randomly, which could lead to extreme likelihood values and unnecessarily long optimization times. We now initialize the utility weights with a heuristic based on each datapoint's comparison-winning count; see the sketch after this list.
- Ensuring covariance is PSD: despite the numerical instability of working on the logit/probit scale, the covariance matrix of the training datapoints should, at a minimum, be positive semi-definite (PSD) by definition (e.g., when using a scaled RBF kernel). If this assumption does not hold, accumulated error leads to many other undesirable consequences downstream. To resolve this, we check the covariance matrix and add jitter as needed to guarantee its PSD-ness; see the usage sketch below.
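
As a rough illustration of the winning-count heuristic, here is a minimal, self-contained sketch (simplified and unbatched relative to the actual PairwiseGP._update code in this diff; init_utility and noise_scale are illustrative names, not part of the botorch API):

    import numpy as np

    def init_utility(D: np.ndarray, noise_scale: float = 0.05) -> np.ndarray:
        # D is the (n_comparisons x n_points) comparison matrix with +1 in the
        # winner's column and -1 in the loser's column of each row, so its
        # column sums count wins minus losses per datapoint.
        win_count = D.sum(axis=-2)
        # standardize so the initialization is on a unit scale
        x0 = (win_count - win_count.mean()) / max(win_count.std(), 1e-6)
        # add a small random perturbation to avoid getting stuck at odd
        # starting points
        return x0 + noise_scale * np.random.standard_normal(x0.shape)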

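For the PSD fix, a minimal usage sketch of the new helper (assuming this version of botorch; _ensure_psd_with_jitter is a private helper introduced in this diff):

    import torch
    from botorch.models.pairwise_gp import _ensure_psd_with_jitter

    # An all-ones matrix is PSD but singular (rank 1), so a plain Cholesky
    # factorization can fail; the helper retries with diagonal jitter that
    # grows tenfold on each attempt (1e-8, then 1e-7, ...) and raises
    # NotPSDError only if max_tries attempts are exhausted.
    M = torch.ones(3, 3, dtype=torch.float64)
    M_psd = _ensure_psd_with_jitter(M, jitter=1e-8, max_tries=3)
    torch.linalg.cholesky(M_psd)  # succeeds once jitter has been added
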
Differential Revision: D44137937

fbshipit-source-id: a556e07aca80e4ba6ce67250fdcd744c40eae2a2
ItsMrLin authored and facebook-github-bot committed Mar 22, 2023
1 parent e8d3f85 commit adf4ba0
Showing 2 changed files with 80 additions and 8 deletions.
63 changes: 56 additions & 7 deletions botorch/models/pairwise_gp.py
@@ -49,6 +49,7 @@
from gpytorch.priors.torch_priors import GammaPrior
from linear_operator.operators import LinearOperator, RootLinearOperator
from linear_operator.utils.cholesky import psd_safe_cholesky
from linear_operator.utils.errors import NotPSDError
from scipy import optimize
from torch import float32, float64, Tensor
from torch.nn.modules.module import _IncompatibleKeys
@@ -86,6 +87,34 @@ def _scaled_psd_safe_cholesky(
    return chol


def _ensure_psd_with_jitter(
    M: Tensor,
    scale: Union[float, Tensor] = 1.0,
    jitter: float = 1e-8,
    max_tries: int = 3,
) -> Tensor:
    r"""Ensure that M / scale is PSD by adding diagonal jitter that grows
    tenfold on each retry, raising NotPSDError if the matrix is still not
    PSD after max_tries attempts."""
    scaled_M = M / scale
    new_jitter = 0
    for i in range(max_tries):
        scaled_M = scaled_M + new_jitter * torch.diag_embed(
            torch.ones(
                scaled_M.shape[:-1], device=scaled_M.device, dtype=scaled_M.dtype
            )
        )
        _, info = torch.linalg.cholesky_ex(scaled_M)
        is_psd = (info == 0).all()
        if is_psd:
            break
        else:
            # increment so that the total jitter added so far is jitter * 10**i
            new_jitter = jitter * (10**i) - new_jitter
    if not is_psd:
        raise NotPSDError(
            "Matrix not positive definite after repeatedly adding jitter "
            f"up to {jitter * (10**i):.1e}."
        )
    return scaled_M * scale


# Why we subclass GP even though it provides no functionality:
# if this subclassing is removed, we get the following GPyTorch error:
# "RuntimeError: All MarginalLogLikelihood objects must be given a GP object as
@@ -218,7 +247,7 @@ def __init__(
        # at 0 or 1
        if covar_module is None:
            os_lb, os_ub = 1e-2, 1e2
            ls_prior = GammaPrior(1.2, 0.5)
            ls_prior = GammaPrior(concentration=2.4, rate=2.7)
            ls_prior_mode = (ls_prior.concentration - 1) / ls_prior.rate
            covar_module = ScaleKernel(
                RBFKernel(
@@ -228,6 +257,7 @@
                    lengthscale_constraint=GreaterThan(
                        lower_bound=1e-4, transform=None, initial_value=ls_prior_mode
                    ),
                    dtype=torch.float64,
                ),
                outputscale_prior=SmoothedBoxPrior(a=os_lb, b=os_ub),
                # make sure we won't get extreme values for the output scale
@@ -236,6 +266,7 @@
                    upper_bound=os_ub * 2.0,
                    initial_value=1.0,
                ),
                dtype=torch.float64,
            )
        if not isinstance(covar_module, ScaleKernel):
            raise UnsupportedError("PairwiseGP must be used with a ScaleKernel.")
@@ -294,8 +325,16 @@ def _has_no_data(self):

    def _calc_covar(self, X1: Tensor, X2: Tensor) -> Union[Tensor, LinearOperator]:
        r"""Calculate the covariance matrix given two sets of datapoints"""
        covar = self.covar_module(X1, X2)
        return covar.to_dense()
        covar = self.covar_module(X1, X2).to_dense()
        # make sure covar is PSD when it is a covariance matrix
        # (i.e., when X1 and X2 are the same set of datapoints)
        if X1 is X2:
            scale = self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1).detach()
            covar = _ensure_psd_with_jitter(
                M=covar,
                scale=scale,
                jitter=self._jitter,
            )
        return covar

    def _update_covar(self, datapoints: Tensor) -> None:
        r"""Update values derived from the data and hyperparameters
@@ -306,9 +345,10 @@
            datapoints: (Transformed) datapoints for finding f_max
        """
        self.covar = self._calc_covar(datapoints, datapoints)
        scale = self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1).detach()
        self.covar_chol = _scaled_psd_safe_cholesky(
            M=self.covar,
            scale=self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1),
            scale=scale,
            jitter=self._jitter,
        )
        self.covar_inv = torch.cholesky_inverse(self.covar_chol)
@@ -469,8 +509,16 @@ def _update(self, datapoints: Tensor, **kwargs) -> None:
                .cpu()
                .numpy()
            )
            # initialize x0 using std normal but clip by 3 std to keep it bounded
            x0 = np.random.standard_normal(init_x0_size).clip(min=-3, max=3)
            # Heuristic initialization using winning counts, with a perturbation
            # to avoid extreme or improbable likelihood values
            win_count = self.D.sum(dim=-2).detach().cpu().numpy()
            wc_mean, wc_std = (
                win_count.mean(axis=-1, keepdims=True),
                win_count.std(axis=-1, keepdims=True).clip(min=1e-6),
            )
            x0 = (win_count - wc_mean) / wc_std
            # add a small random perturbation in case we get stuck at strange
            # init values
            x0 = x0 + 0.05 * np.random.standard_normal(init_x0_size)
            # scale x0 to be on roughly the right scale
            x0 = x0 * sqrt_scale
        else:
@@ -928,14 +976,15 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:

        output_mean, output_covar = pred_mean, pred_covar

        scale = self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1).detach()
        post = MultivariateNormal(
            mean=output_mean,
            # output_covar is sometimes non-PSD;
            # perform a Cholesky decomposition to check and amend
            covariance_matrix=RootLinearOperator(
                _scaled_psd_safe_cholesky(
                    output_covar,
                    scale=self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1),
                    scale=scale,
                    jitter=self._jitter,
                )
            ),
25 changes: 24 additions & 1 deletion test/models/test_pairwise_gp.py
@@ -15,7 +15,11 @@
    PairwiseLogitLikelihood,
    PairwiseProbitLikelihood,
)
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
from botorch.models.pairwise_gp import (
    _ensure_psd_with_jitter,
    PairwiseGP,
    PairwiseLaplaceMarginalLogLikelihood,
)
from botorch.models.transforms.input import Normalize
from botorch.posteriors import GPyTorchPosterior
from botorch.sampling.pairwise_samplers import PairwiseSobolQMCNormalSampler
@@ -24,6 +28,7 @@
from gpytorch.kernels.linear_kernel import LinearKernel
from gpytorch.means import ConstantMean
from gpytorch.priors import GammaPrior, SmoothedBoxPrior
from linear_operator.utils.errors import NotPSDError


class TestPairwiseGP(BotorchTestCase):
@@ -346,3 +351,21 @@ def test_load_state_dict(self):
        _ = model.load_state_dict(sd)
        for buffer_name in model._buffer_names:
            self.assertIsNone(model.get_buffer(buffer_name))

    def test_helper_functions(self):
        for batch_shape, dtype in itertools.product(
            (torch.Size(), torch.Size([2])), (torch.float, torch.double)
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            # M is borderline PSD
            M = torch.ones((*batch_shape, 2, 2), **tkwargs)
            with self.assertRaises(torch._C._LinAlgError):
                torch.linalg.cholesky(M)
            # This should work fine
            _ensure_psd_with_jitter(M)

            bad_M = torch.tensor([[1.0, 2.0], [2.0, 1.0]], **tkwargs).expand(
                (*batch_shape, 2, 2)
            )
            with self.assertRaises(NotPSDError):
                _ensure_psd_with_jitter(bad_M)
