Move consolidate_duplicates to BoTorch and consolidate duplicates in PairwiseGP

Summary:
# Context
One problem for GP models is that evaluating points that are close together is likely to trigger numerical issues resulting from a non-PSD covariance matrix. The problem is particularly pronounced, and hard to bypass, during optimization (either BOPE or preferential BO), as we need to repeatedly compare points to the incumbent.

To improve preference learning stability, we can automatically consolidate identical (or numerically similar) points into a single point. For example, training data `datapoints = [[1, 2], [3, 4], [1, 2], [5, 6]]` and `comparisons = [[0, 1], [2, 3]]` will be turned into the consolidated `datapoints = [[1, 2], [3, 4], [5, 6]]` and `comparisons = [[0, 1], [0, 2]]`. This shouldn't lead to any changes in model fitting, as the likelihood remains the same.
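A minimal sketch of this consolidation, using the `consolidate_duplicates` helper upstreamed in this diff (the call signature mirrors how `PairwiseGP` invokes it below; the `rtol`/`atol` values are the new `PairwiseGP` defaults and are passed explicitly here):

```python
import torch

from botorch.models.utils import consolidate_duplicates

datapoints = torch.tensor([[1.0, 2.0], [3.0, 4.0], [1.0, 2.0], [5.0, 6.0]])
comparisons = torch.tensor([[0, 1], [2, 3]])

# Rows 0 and 2 of `datapoints` are duplicates, so they are merged and the
# comparison indices are remapped onto the consolidated rows.
consolidated_dp, consolidated_cmp, consolidated_idx = consolidate_duplicates(
    datapoints, comparisons, rtol=1e-4, atol=1e-6
)
# Expected:
#   consolidated_dp  -> [[1, 2], [3, 4], [5, 6]]
#   consolidated_cmp -> [[0, 1], [0, 2]]
#   consolidated_idx -> mapping from original rows to consolidated rows, i.e. [0, 1, 0, 2]
```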

# Code changes
To implement this, the following changes are made:
- Upstreamed the `consolidate_duplicates` and related helper functions from `Ax` to `BoTorch`.
- Implicitly replaced `datapoints` and `comparisons` in `PairwiseGP` with the consolidated ones.
- Added `unconsolidated_datapoints`, `unconsolidated_comparisons`, and `unconsolidated_utility` in case the user would like to access the original data and the corresponding utility directly from the model (see the usage sketch below).
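
A minimal usage sketch of the new accessors on `PairwiseGP`; the shapes in the comments follow from the toy data above and are illustrative, not the output of a verified run:

```python
import torch

from botorch.models import PairwiseGP

# Same toy data as above: rows 0 and 2 are duplicates.
datapoints = torch.tensor([[1.0, 2.0], [3.0, 4.0], [1.0, 2.0], [5.0, 6.0]])
comparisons = torch.tensor([[0, 1], [2, 3]])

model = PairwiseGP(datapoints, comparisons)

# `datapoints`/`comparisons` now alias the consolidated tensors used for fitting.
model.datapoints                  # expected shape: 3 x 2
model.comparisons                 # expected: [[0, 1], [0, 2]]

# The original data and the utility mapped back to it stay accessible.
model.unconsolidated_datapoints   # shape: 4 x 2
model.unconsolidated_comparisons  # [[0, 1], [2, 3]]
model.unconsolidated_utility      # one utility value per original row (4 values)
```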

Differential Revision: D44126864

fbshipit-source-id: cae0420824a072282fac2d31dda59a3a042e291f
ItsMrLin authored and facebook-github-bot committed Mar 20, 2023
1 parent e4e2f8a commit f6e6fbc
Showing 5 changed files with 304 additions and 58 deletions.
202 changes: 145 additions & 57 deletions botorch/models/pairwise_gp.py
@@ -34,6 +34,7 @@
)
from botorch.models.model import FantasizeMixin, Model
from botorch.models.transforms.input import InputTransform
from botorch.models.utils.assorted import consolidate_duplicates
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from gpytorch import settings
@@ -53,6 +54,36 @@
from torch.nn.modules.module import _IncompatibleKeys


# Helper functions
def _check_strict_input(inputs, t_inputs, target_or_inputs):
for input_, t_input in zip(inputs, t_inputs or (None,)):
for attr in {"shape", "dtype", "device"}:
expected_attr = getattr(t_input, attr, None)
found_attr = getattr(input_, attr, None)
if expected_attr != found_attr:
msg = (
"Cannot modify {attr} of {t_or_i} "
"(expected {e_attr}, found {f_attr})."
)
msg = msg.format(
attr=attr,
e_attr=expected_attr,
f_attr=found_attr,
t_or_i=target_or_inputs,
)
raise RuntimeError(msg)


def _scaled_psd_safe_cholesky(
M: Tensor, scale: Union[float, Tensor], jitter: Optional[float] = None
) -> Tensor:
r"""scale M by 1/outputscale before cholesky for better numerical stability"""
M = M / scale
chol = psd_safe_cholesky(M, jitter=jitter)
chol = chol * scale.sqrt()
return chol


# Why we subclass GP even though it provides no functionality:
# if this subclassing is removed, we get the following GPyTorch error:
# "RuntimeError: All MarginalLogLikelihood objects must be given a GP object as
@@ -77,7 +108,6 @@ class PairwiseGP(Model, GP, FantasizeMixin):
In the example below, the user/decision maker has stated that they prefer
the first item over the second item and the third item over the second item,
generating comparisons [0, 1] and [2, 1].
Example:
>>> from botorch.models import PairwiseGP
>>> import torch
@@ -87,8 +117,8 @@ class PairwiseGP(Model, GP, FantasizeMixin):
"""

_buffer_names = [
"datapoints",
"comparisons",
"consolidated_datapoints",
"consolidated_comparisons",
"D",
"DT",
"utility",
@@ -97,6 +127,9 @@
"hlcov_eye",
"covar",
"covar_inv",
"unconsolidated_datapoints",
"unconsolidated_comparisons",
"consolidated_indices",
]

def __init__(
@@ -121,6 +154,20 @@ def __init__(
"""
super().__init__()

# Set optional parameters
# Explicitly set jitter for numerical stability in psd_safe_cholesky
self._jitter = kwargs.get("jitter", 1e-6)
# Stopping criteria in scipy.optimize.fsolve used to find f_map in _update()
# If None, set to 1e-6 by default in _update
self._xtol = kwargs.get("xtol")
# atol rtol for consolidate_duplicates
self._consolidate_rtol = kwargs.get("consolidate_rtol", 1e-4)
self._consolidate_atol = kwargs.get("consolidate_atol", 1e-6)
# The maximum number of calls to the function in scipy.optimize.fsolve
# If None, set to 100 by default in _update
# If zero, then 100*(N+1) is used by default by fsolve;
self._maxfev = kwargs.get("maxfev")

if input_transform is not None:
input_transform.to(datapoints)
# input transformation is applied in set_train_data
@@ -142,24 +189,16 @@ def __init__(

self.pred_cov_fac_need_update = True
self.dim = None
self.consolidated_datapoints = None
self.consolidated_comparisons = None
self.consolidated_indices = None

# See set_train_data for additional compatibility variables.
# Note that the datapoints here are not transformed even if input_transform
# is not None to avoid double transformation during model fitting.
# self.transform_inputs is called in `forward`
self.set_train_data(datapoints, comparisons, update_model=False)

# Set optional parameters
# Explicitly set jitter for numerical stability in psd_safe_cholesky
self._jitter = kwargs.get("jitter", 1e-6)
# Stopping criteria in scipy.optimize.fsolve used to find f_map in _update()
# If None, set to 1e-6 by default in _update
self._xtol = kwargs.get("xtol")
# The maximum number of calls to the function in scipy.optimize.fsolve
# If None, set to 100 by default in _update
# If zero, then 100*(N+1) is used by default by fsolve;
self._maxfev = kwargs.get("maxfev")

# Set hyperparameters
# Do not set the batch_shape explicitly so mean_module can operate in both mode
# once fsolve used in _update can run in batch mode, we should explicitly set
@@ -202,22 +241,26 @@ def __init__(
if self.datapoints is not None and self.comparisons is not None:
self.to(dtype=self.datapoints.dtype, device=self.datapoints.device)
# Find f_map for initial parameters with transformed datapoints
transformed_dp = self.transform_inputs(datapoints)
transformed_dp = self.transform_inputs(self.datapoints)
self._update(transformed_dp)

self.to(self.datapoints)

def __deepcopy__(self, memo) -> PairwiseGP:
attrs = (
"datapoints",
"comparisons",
"consolidated_datapoints",
"consolidated_comparisons",
"covar",
"covar_inv",
"covar_chol",
"likelihood_hess",
"utility",
"hlcov_eye",
"unconsolidated_datapoints",
"unconsolidated_comparisons",
"consolidated_indices",
)

if any(getattr(self, attr) is not None for attr in attrs):
# Temporarily remove non-leaf tensors so that pytorch allows deepcopy
old_attr = {}
@@ -237,16 +280,6 @@ def __deepcopy__(self, memo) -> PairwiseGP:
self.__deepcopy__ = dcp
return new_model

def _scaled_psd_safe_cholesky(
self, M: Tensor, jitter: Optional[float] = None
) -> Tensor:
r"""scale M by 1/outputscale before cholesky for better numerical stability"""
scale = self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1)
M = M / scale
chol = psd_safe_cholesky(M, jitter=jitter)
chol = chol * scale.sqrt()
return chol

def _has_no_data(self):
r"""Return true if the model does not have both datapoints and comparisons"""
return (
@@ -269,8 +302,10 @@ def _update_covar(self, datapoints: Tensor) -> None:
datapoints: (Transformed) datapoints for finding f_max
"""
self.covar = self._calc_covar(datapoints, datapoints)
self.covar_chol = self._scaled_psd_safe_cholesky(
self.covar, jitter=self._jitter
self.covar_chol = _scaled_psd_safe_cholesky(
M=self.covar,
scale=self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1),
jitter=self._jitter,
)
self.covar_inv = torch.cholesky_inverse(self.covar_chol)

@@ -587,25 +622,60 @@ def _util_newton_updates(self, dp, x0, max_iter=1, xtol=None) -> Tensor:

return x

def _check_strict_input(self, inputs, t_inputs, target_or_inputs):
for input_, t_input in zip(inputs, t_inputs or (None,)):
for attr in {"shape", "dtype", "device"}:
expected_attr = getattr(t_input, attr, None)
found_attr = getattr(input_, attr, None)
if expected_attr != found_attr:
msg = (
"Cannot modify {attr} of {t_or_i} "
"(expected {e_attr}, found {f_attr})."
)
msg = msg.format(
attr=attr,
e_attr=expected_attr,
f_attr=found_attr,
t_or_i=target_or_inputs,
)
raise RuntimeError(msg)
def _consolidate_duplicates(
self,
datapoints: Tensor,
comparisons: Tensor,
rtol: Optional[float] = None,
atol: Optional[float] = None,
) -> Tuple[Tensor, Tensor]:
"""Consolidate and cache datapoints and comparisons"""
# check if consolidated datapoints/comparisons are cached
if (
datapoints is not self.unconsolidated_datapoints
or comparisons is not self.unconsolidated_comparisons
or self.consolidated_datapoints is None
or self.consolidated_comparisons is None
):
if len(datapoints.shape) > 2 or len(comparisons.shape) > 2:
# Do not perform consolidation in batch mode as block design
# cannot be guaranteed
self.consolidated_datapoints = datapoints
self.consolidated_comparisons = comparisons
self.consolidated_indices = None
else:
(
self.consolidated_datapoints,
self.consolidated_comparisons,
self.consolidated_indices,
) = consolidate_duplicates(
datapoints,
comparisons,
rtol=rtol,
atol=atol,
)

return self.consolidated_datapoints, self.consolidated_comparisons

# ============== public APIs ==============
@property
def datapoints(self) -> Tensor:
r"""Alias for consolidated datapoints"""
return self.consolidated_datapoints

@property
def comparisons(self) -> Tensor:
r"""Alias for consolidated comparisons"""
return self.consolidated_comparisons

@property
def unconsolidated_utility(self) -> Tensor:
r"""Utility of the unconsolidated datapoints"""
if self.consolidated_indices is None:
# self.consolidated_indices is None in batch mode
return self.utility
else:
return self.utility[self.consolidated_indices]

@property
def num_outputs(self) -> int:
@@ -652,6 +722,11 @@ def set_train_data(
if datapoints is None or comparisons is None:
return

self.unconsolidated_datapoints, self.unconsolidated_comparisons = (
datapoints,
comparisons,
)

# following gpytorch.models.exact_gp.set_train_data
if datapoints is not None:
if torch.is_tensor(datapoints):
@@ -662,23 +737,31 @@ def set_train_data(
for input_ in inputs
)
if strict:
self._check_strict_input(inputs, self.train_inputs, "inputs")
_check_strict_input(inputs, self.train_inputs, "inputs")

self.datapoints = inputs[0]
datapoints = inputs[0]
# Compatibility variables with fit_gpytorch_*
# alias for datapoints ("train_inputs")
self.train_inputs = inputs

if comparisons is not None:
if strict:
self._check_strict_input([comparisons], [self.train_targets], "targets")
_check_strict_input([comparisons], [self.train_targets], "targets")

# convert to long so that it can be used as index and
# compatible with Tensor.scatter_
self.comparisons = comparisons.long()
comparisons = comparisons.long()
# Compatibility variables with fit_gpytorch_*
# alias for comparisons ("train_targets" here)
self.train_targets = self.comparisons
self.train_targets = comparisons

# self.datapoints and self.comparisons are being updated here
self._consolidate_duplicates(
datapoints,
comparisons,
rtol=self._consolidate_rtol,
atol=self._consolidate_atol,
)

# Compatibility variables with optimize_acqf
self._dtype = self.datapoints.dtype
@@ -704,7 +787,7 @@ def set_train_data(
self.DT = self.D.transpose(-1, -2)

if update_model:
transformed_dp = self.transform_inputs(datapoints)
transformed_dp = self.transform_inputs(self.datapoints)
self._update(transformed_dp)

self.to(self.datapoints)
@@ -780,18 +863,19 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
"or call .set_train_data() to add training data."
)

if datapoints is not self.datapoints:
if datapoints is not self.unconsolidated_datapoints:
raise RuntimeError("Must train on training data")

transformed_dp = self.transform_inputs(datapoints)

# We pass in the untransformed datapoints into set_train_data
# as we will be setting self.datapoints as the untransformed datapoints
# self.transform_inputs will be called inside before calling _update()
self.set_train_data(datapoints, self.comparisons, update_model=True)

transformed_dp = self.transform_inputs(self.datapoints)

hl = self.likelihood_hess
covar = self.covar

# Apply matrix inversion lemma on eq. in page 27 of [Brochu2010tutorial]_
# (A + B)^-1 = A^-1 - A^-1 @ (I + BA^-1)^-1 @ BA^-1
# where A = covar_inv, B = hl
@@ -854,7 +938,11 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
# output_covar is sometimes non-PSD
# perform a cholesky decomposition to check and amend
covariance_matrix=RootLinearOperator(
self._scaled_psd_safe_cholesky(output_covar, jitter=self._jitter)
_scaled_psd_safe_cholesky(
output_covar,
scale=self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1),
jitter=self._jitter,
)
),
)
return post
@@ -975,7 +1063,7 @@ def forward(self, post: Posterior, comp: Tensor) -> Tensor:

model = self.model
likelihood = self.likelihood
if comp is not model.comparisons:
if comp is not model.unconsolidated_comparisons:
raise RuntimeError("Must train on training data")

f_map = post.mean.squeeze(-1)
4 changes: 4 additions & 0 deletions botorch/models/utils/__init__.py
@@ -10,6 +10,8 @@
check_min_max_scaling,
check_no_nans,
check_standardization,
consolidate_duplicates,
detect_duplicates,
fantasize,
gpt_posterior_settings,
mod_batch_shape,
@@ -32,4 +34,6 @@
"multioutput_to_batch_mode_transform",
"mod_batch_shape",
"validate_input_scaling",
"detect_duplicates",
"consolidate_duplicates",
]
