diff --git a/botorch/models/pairwise_gp.py b/botorch/models/pairwise_gp.py
index 5be462451f..56b1e4c164 100644
--- a/botorch/models/pairwise_gp.py
+++ b/botorch/models/pairwise_gp.py
@@ -34,6 +34,7 @@
 )
 from botorch.models.model import FantasizeMixin, Model
 from botorch.models.transforms.input import InputTransform
+from botorch.models.utils.assorted import consolidate_duplicates
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.posteriors.posterior import Posterior
 from gpytorch import settings
@@ -53,6 +54,36 @@
 from torch.nn.modules.module import _IncompatibleKeys


+# Helper functions
+def _check_strict_input(inputs, t_inputs, target_or_inputs):
+    for input_, t_input in zip(inputs, t_inputs or (None,)):
+        for attr in {"shape", "dtype", "device"}:
+            expected_attr = getattr(t_input, attr, None)
+            found_attr = getattr(input_, attr, None)
+            if expected_attr != found_attr:
+                msg = (
+                    "Cannot modify {attr} of {t_or_i} "
+                    "(expected {e_attr}, found {f_attr})."
+                )
+                msg = msg.format(
+                    attr=attr,
+                    e_attr=expected_attr,
+                    f_attr=found_attr,
+                    t_or_i=target_or_inputs,
+                )
+                raise RuntimeError(msg)
+
+
+def _scaled_psd_safe_cholesky(
+    M: Tensor, scale: Union[float, Tensor], jitter: Optional[float] = None
+) -> Tensor:
+    r"""Scale M by 1/outputscale before Cholesky for better numerical stability."""
+    M = M / scale
+    chol = psd_safe_cholesky(M, jitter=jitter)
+    chol = chol * scale.sqrt()
+    return chol
+
+
 # Why we subclass GP even though it provides no functionality:
 # if this subclassing is removed, we get the following GPyTorch error:
 # "RuntimeError: All MarginalLogLikelihood objects must be given a GP object as
@@ -77,7 +108,6 @@ class PairwiseGP(Model, GP, FantasizeMixin):
     In the example below, the user/decision maker has stated that they prefer
     the first item over the second item and the third item over the second item,
     generating comparisons [0, 1] and [2, 1].
-
     Example:
         >>> from botorch.models import PairwiseGP
         >>> import torch
@@ -87,8 +117,8 @@ class PairwiseGP(Model, GP, FantasizeMixin):
     """

     _buffer_names = [
-        "datapoints",
-        "comparisons",
+        "consolidated_datapoints",
+        "consolidated_comparisons",
         "D",
         "DT",
         "utility",
@@ -97,6 +127,9 @@ class PairwiseGP(Model, GP, FantasizeMixin):
         "hlcov_eye",
         "covar",
         "covar_inv",
+        "unconsolidated_datapoints",
+        "unconsolidated_comparisons",
+        "consolidated_indices",
     ]

     def __init__(
@@ -121,6 +154,20 @@ def __init__(
         """
         super().__init__()

+        # Set optional parameters
+        # Explicitly set jitter for numerical stability in psd_safe_cholesky
+        self._jitter = kwargs.get("jitter", 1e-6)
+        # Stopping criteria in scipy.optimize.fsolve used to find f_map in _update()
+        # If None, set to 1e-6 by default in _update
+        self._xtol = kwargs.get("xtol")
+        # rtol and atol for consolidate_duplicates
+        self._consolidate_rtol = kwargs.get("consolidate_rtol", 1e-4)
+        self._consolidate_atol = kwargs.get("consolidate_atol", 1e-6)
+        # The maximum number of calls to the function in scipy.optimize.fsolve
+        # If None, set to 100 by default in _update
+        # If zero, then 100 * (N + 1) is used by default by fsolve
+        self._maxfev = kwargs.get("maxfev")
+
         if input_transform is not None:
             input_transform.to(datapoints)
             # input transformation is applied in set_train_data
@@ -142,6 +189,9 @@ def __init__(
         self.pred_cov_fac_need_update = True
         self.dim = None
+        self.consolidated_datapoints = None
+        self.consolidated_comparisons = None
+        self.consolidated_indices = None

         # See set_train_data for additional compatibility variables.
         # Note that the datapoints here are not transformed even if input_transform
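# --- Illustrative usage (editor's sketch, not part of the patch) ---
# The optional kwargs wired up above can be passed straight to the PairwiseGP
# constructor; the tolerance values below are made up for demonstration.
import torch
from botorch.models import PairwiseGP

datapoints = torch.rand(4, 2, dtype=torch.float64)
comparisons = torch.tensor([[0, 1], [2, 3]])  # item 0 > item 1, item 2 > item 3
model = PairwiseGP(
    datapoints,
    comparisons,
    jitter=1e-5,            # diagonal jitter used in psd_safe_cholesky
    xtol=1e-6,              # fsolve stopping tolerance when finding f_map
    maxfev=100,             # max function evaluations in fsolve
    consolidate_rtol=1e-3,  # relative tolerance for duplicate detection
    consolidate_atol=1e-5,  # absolute tolerance for duplicate detection
)
# --- end sketch ---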
@@ -149,17 +199,6 @@ def __init__(
         # self.transform_inputs is called in `forward`
         self.set_train_data(datapoints, comparisons, update_model=False)

-        # Set optional parameters
-        # Explicitly set jitter for numerical stability in psd_safe_cholesky
-        self._jitter = kwargs.get("jitter", 1e-6)
-        # Stopping criteria in scipy.optimize.fsolve used to find f_map in _update()
-        # If None, set to 1e-6 by default in _update
-        self._xtol = kwargs.get("xtol")
-        # The maximum number of calls to the function in scipy.optimize.fsolve
-        # If None, set to 100 by default in _update
-        # If zero, then 100 * (N + 1) is used by default by fsolve
-        self._maxfev = kwargs.get("maxfev")
-
         # Set hyperparameters
         # Do not set the batch_shape explicitly so mean_module can operate in both modes
         # once fsolve used in _update can run in batch mode, we should explicitly set
@@ -202,22 +241,26 @@ def __init__(
         if self.datapoints is not None and self.comparisons is not None:
             self.to(dtype=self.datapoints.dtype, device=self.datapoints.device)
             # Find f_map for initial parameters with transformed datapoints
-            transformed_dp = self.transform_inputs(datapoints)
+            transformed_dp = self.transform_inputs(self.datapoints)
             self._update(transformed_dp)

         self.to(self.datapoints)

     def __deepcopy__(self, memo) -> PairwiseGP:
         attrs = (
-            "datapoints",
-            "comparisons",
+            "consolidated_datapoints",
+            "consolidated_comparisons",
             "covar",
             "covar_inv",
             "covar_chol",
             "likelihood_hess",
             "utility",
             "hlcov_eye",
+            "unconsolidated_datapoints",
+            "unconsolidated_comparisons",
+            "consolidated_indices",
         )
+
         if any(getattr(self, attr) is not None for attr in attrs):
             # Temporarily remove non-leaf tensors so that pytorch allows deepcopy
             old_attr = {}
@@ -237,16 +280,6 @@ def __deepcopy__(self, memo) -> PairwiseGP:
             self.__deepcopy__ = dcp
         return new_model

-    def _scaled_psd_safe_cholesky(
-        self, M: Tensor, jitter: Optional[float] = None
-    ) -> Tensor:
-        r"""scale M by 1/outputscale before cholesky for better numerical stability"""
-        scale = self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1)
-        M = M / scale
-        chol = psd_safe_cholesky(M, jitter=jitter)
-        chol = chol * scale.sqrt()
-        return chol
-
     def _has_no_data(self):
         r"""Return true if the model does not have both datapoints and comparisons"""
         return (
@@ -269,8 +302,10 @@ def _update_covar(self, datapoints: Tensor) -> None:
             datapoints: (Transformed) datapoints for finding f_max
         """
         self.covar = self._calc_covar(datapoints, datapoints)
-        self.covar_chol = self._scaled_psd_safe_cholesky(
-            self.covar, jitter=self._jitter
+        self.covar_chol = _scaled_psd_safe_cholesky(
+            M=self.covar,
+            scale=self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1),
+            jitter=self._jitter,
         )
         self.covar_inv = torch.cholesky_inverse(self.covar_chol)
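# --- Illustrative check (editor's sketch, not part of the patch) ---
# The identity behind _scaled_psd_safe_cholesky: if L = chol(M / s), then
# (sqrt(s) * L) @ (sqrt(s) * L).T = s * (M / s) = M, so the returned factor is
# still a root of M, while the decomposition (and any jitter it adds) operates
# on the outputscale-normalized matrix. torch.linalg.cholesky stands in for
# psd_safe_cholesky here:
import torch

torch.manual_seed(0)
A = torch.randn(5, 5, dtype=torch.float64)
M = A @ A.T + 1e-3 * torch.eye(5, dtype=torch.float64)  # SPD matrix
scale = torch.tensor(100.0, dtype=torch.float64)        # stand-in outputscale

chol = torch.linalg.cholesky(M / scale) * scale.sqrt()
assert torch.allclose(chol @ chol.T, M)
# --- end sketch ---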
@@ -587,25 +622,60 @@ def _util_newton_updates(self, dp, x0, max_iter=1, xtol=None) -> Tensor:

         return x

-    def _check_strict_input(self, inputs, t_inputs, target_or_inputs):
-        for input_, t_input in zip(inputs, t_inputs or (None,)):
-            for attr in {"shape", "dtype", "device"}:
-                expected_attr = getattr(t_input, attr, None)
-                found_attr = getattr(input_, attr, None)
-                if expected_attr != found_attr:
-                    msg = (
-                        "Cannot modify {attr} of {t_or_i} "
-                        "(expected {e_attr}, found {f_attr})."
-                    )
-                    msg = msg.format(
-                        attr=attr,
-                        e_attr=expected_attr,
-                        f_attr=found_attr,
-                        t_or_i=target_or_inputs,
-                    )
-                    raise RuntimeError(msg)
+    def _consolidate_duplicates(
+        self,
+        datapoints: Tensor,
+        comparisons: Tensor,
+        rtol: Optional[float] = None,
+        atol: Optional[float] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """Consolidate and cache datapoints and comparisons"""
+        # check if consolidated datapoints/comparisons are cached
+        if (
+            datapoints is not self.unconsolidated_datapoints
+            or comparisons is not self.unconsolidated_comparisons
+            or self.consolidated_datapoints is None
+            or self.consolidated_comparisons is None
+        ):
+            if len(datapoints.shape) > 2 or len(comparisons.shape) > 2:
+                # Do not perform consolidation in batch mode as block design
+                # cannot be guaranteed
+                self.consolidated_datapoints = datapoints
+                self.consolidated_comparisons = comparisons
+                self.consolidated_indices = None
+            else:
+                (
+                    self.consolidated_datapoints,
+                    self.consolidated_comparisons,
+                    self.consolidated_indices,
+                ) = consolidate_duplicates(
+                    datapoints,
+                    comparisons,
+                    rtol=rtol,
+                    atol=atol,
+                )
+
+        return self.consolidated_datapoints, self.consolidated_comparisons

     # ============== public APIs ==============
+    @property
+    def datapoints(self) -> Tensor:
+        r"""Alias for consolidated datapoints"""
+        return self.consolidated_datapoints
+
+    @property
+    def comparisons(self) -> Tensor:
+        r"""Alias for consolidated comparisons"""
+        return self.consolidated_comparisons
+
+    @property
+    def unconsolidated_utility(self) -> Tensor:
+        r"""Utility of the unconsolidated datapoints"""
+        if self.consolidated_indices is None:
+            # self.consolidated_indices is None in batch mode
+            return self.utility
+        else:
+            return self.utility[self.consolidated_indices]

     @property
     def num_outputs(self) -> int:
@@ -652,6 +722,11 @@ def set_train_data(
         if datapoints is None or comparisons is None:
             return

+        self.unconsolidated_datapoints, self.unconsolidated_comparisons = (
+            datapoints,
+            comparisons,
+        )
+
         # following gpytorch.models.exact_gp.set_train_data
         if datapoints is not None:
             if torch.is_tensor(datapoints):
@@ -662,23 +737,31 @@ def set_train_data(
                 for input_ in inputs
             )
             if strict:
-                self._check_strict_input(inputs, self.train_inputs, "inputs")
+                _check_strict_input(inputs, self.train_inputs, "inputs")

-            self.datapoints = inputs[0]
+            datapoints = inputs[0]
             # Compatibility variables with fit_gpytorch_*
             # alias for datapoints ("train_inputs")
             self.train_inputs = inputs

         if comparisons is not None:
             if strict:
-                self._check_strict_input([comparisons], [self.train_targets], "targets")
+                _check_strict_input([comparisons], [self.train_targets], "targets")

             # convert to long so that it can be used as index and
             # compatible with Tensor.scatter_
-            self.comparisons = comparisons.long()
+            comparisons = comparisons.long()
             # Compatibility variables with fit_gpytorch_*
             # alias for comparisons ("train_targets" here)
-            self.train_targets = self.comparisons
+            self.train_targets = comparisons
+
+        # self.datapoints and self.comparisons are being updated here
+        self._consolidate_duplicates(
+            datapoints,
+            comparisons,
+            rtol=self._consolidate_rtol,
+            atol=self._consolidate_atol,
+        )

         # Compatibility variables with optimize_acqf
         self._dtype = self.datapoints.dtype
@@ -704,7 +787,7 @@ def set_train_data(
             self.DT = self.D.transpose(-1, -2)

         if update_model:
-            transformed_dp = self.transform_inputs(datapoints)
+            transformed_dp = self.transform_inputs(self.datapoints)
             self._update(transformed_dp)

         self.to(self.datapoints)
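# --- Illustrative example (editor's sketch, not part of the patch) ---
# With the properties and set_train_data changes above, `model.datapoints` and
# `model.comparisons` refer to the consolidated training data, while the
# tensors originally passed in remain reachable unchanged:
import torch
from botorch.models import PairwiseGP

X = torch.tensor([[1.0, 2.0], [3.0, 4.0], [1.0, 2.0]], dtype=torch.float64)
comp = torch.tensor([[0, 1], [2, 1]])  # row 2 duplicates row 0
model = PairwiseGP(X, comp)

print(model.unconsolidated_datapoints.shape)  # torch.Size([3, 2])
print(model.datapoints.shape)                 # torch.Size([2, 2]) after dedup
print(model.comparisons)                      # tensor([[0, 1], [0, 1]])
print(model.unconsolidated_utility.shape)     # torch.Size([3]); utility has 2 entries
# --- end sketch ---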
@@ -780,18 +863,19 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
                 "or call .set_train_data() to add training data."
             )

-        if datapoints is not self.datapoints:
+        if datapoints is not self.unconsolidated_datapoints:
             raise RuntimeError("Must train on training data")

-        transformed_dp = self.transform_inputs(datapoints)
-
         # We pass in the untransformed datapoints into set_train_data
         # as we will be setting self.datapoints as the untransformed datapoints
         # self.transform_inputs will be called inside before calling _update()
         self.set_train_data(datapoints, self.comparisons, update_model=True)

+        transformed_dp = self.transform_inputs(self.datapoints)
+
         hl = self.likelihood_hess
         covar = self.covar
+
         # Apply matrix inversion lemma on eq. in page 27 of [Brochu2010tutorial]_
         # (A + B)^-1 = A^-1 - A^-1 @ (I + BA^-1)^-1 @ BA^-1
         # where A = covar_inv, B = hl
@@ -854,7 +938,11 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
             # output_covar is sometimes non-PSD
             # perform a cholesky decomposition to check and amend
             covariance_matrix=RootLinearOperator(
-                self._scaled_psd_safe_cholesky(output_covar, jitter=self._jitter)
+                _scaled_psd_safe_cholesky(
+                    output_covar,
+                    scale=self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1),
+                    jitter=self._jitter,
+                )
             ),
         )
         return post
@@ -975,7 +1063,7 @@ def forward(self, post: Posterior, comp: Tensor) -> Tensor:
         model = self.model
         likelihood = self.likelihood

-        if comp is not model.comparisons:
+        if comp is not model.unconsolidated_comparisons:
             raise RuntimeError("Must train on training data")

         f_map = post.mean.squeeze(-1)
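# --- Illustrative check (editor's sketch, not part of the patch) ---
# The lemma used in `forward` above: with the matrix A = covar_inv and B = hl,
# (A + B)^-1 can be formed from covar = A^-1 without inverting A + B directly.
# A quick numerical verification of the identity
# (A + B)^-1 = A^-1 - A^-1 @ (I + B @ A^-1)^-1 @ B @ A^-1:
import torch

torch.manual_seed(0)
n = 4
A = torch.randn(n, n, dtype=torch.float64)
A = A @ A.T + n * torch.eye(n, dtype=torch.float64)  # well-conditioned SPD
B = torch.randn(n, n, dtype=torch.float64)
B = B @ B.T                                          # PSD

A_inv = torch.linalg.inv(A)
eye = torch.eye(n, dtype=torch.float64)
lhs = torch.linalg.inv(A + B)
rhs = A_inv - A_inv @ torch.linalg.inv(eye + B @ A_inv) @ B @ A_inv
assert torch.allclose(lhs, rhs)
# --- end sketch ---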
diff --git a/botorch/models/utils/__init__.py b/botorch/models/utils/__init__.py
index bff9ecd999..746e8ac216 100644
--- a/botorch/models/utils/__init__.py
+++ b/botorch/models/utils/__init__.py
@@ -10,6 +10,8 @@
     check_min_max_scaling,
     check_no_nans,
     check_standardization,
+    consolidate_duplicates,
+    detect_duplicates,
     fantasize,
     gpt_posterior_settings,
     mod_batch_shape,
@@ -32,4 +34,6 @@
     "multioutput_to_batch_mode_transform",
     "mod_batch_shape",
     "validate_input_scaling",
+    "detect_duplicates",
+    "consolidate_duplicates",
 ]
diff --git a/botorch/models/utils/assorted.py b/botorch/models/utils/assorted.py
index 27e2ff97d4..0662ad108b 100644
--- a/botorch/models/utils/assorted.py
+++ b/botorch/models/utils/assorted.py
@@ -10,7 +10,7 @@
 import warnings
 from contextlib import contextmanager, ExitStack
-from typing import List, Optional, Tuple
+from typing import Iterator, List, Optional, Tuple

 import torch
 from botorch import settings
@@ -282,6 +282,87 @@ def gpt_posterior_settings():
         yield


+def detect_duplicates(
+    X: Tensor,
+    rtol: Optional[float] = 1e-5,
+    atol: Optional[float] = 1e-8,
+) -> Iterator[Tuple[int, int]]:
+    """Returns an iterator over index pairs `(duplicate index, original index)` for
+    all duplicate entries of `X`. Supports 2-d Tensors only.
+    """
+    if len(X.shape) != 2:
+        raise ValueError("X must have 2 dimensions.")
+
+    tols = atol if atol is not None else 1e-8
+    if rtol:
+        rval = X.abs().max(dim=-1, keepdim=True).values
+        tols = tols + rtol * rval.max(rval.transpose(-1, -2))
+
+    n = X.shape[-2]
+    dist = torch.full((n, n), float("inf"), device=X.device, dtype=X.dtype)
+    dist[torch.triu_indices(n, n, offset=1).unbind()] = torch.nn.functional.pdist(
+        X, p=float("inf")
+    )
+    return (
+        (i, int(j))
+        # pyre-fixme[19]: Expected 1 positional argument.
+        for diff, j, i in zip(*(dist - tols).min(dim=-2), range(n))
+        if diff < 0
+    )
+
+
+def consolidate_duplicates(
+    X: Tensor, Y: Tensor, rtol: Optional[float] = None, atol: Optional[float] = None
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """Drop duplicated Xs and update the indices tensor Y accordingly.
+    Supports 2-d Tensors only, as block design cannot be guaranteed in batch mode.
+
+    Returns:
+        the consolidated X
+        the consolidated Y (e.g., pairwise comparison indices)
+        a 1-d tensor of length X.shape[-2], representing the new index of each
+        original item in X
+    """
+    if len(X.shape) != 2:
+        raise ValueError("X must have 2 dimensions.")
+
+    n = X.shape[-2]
+    dup_map = dict(detect_duplicates(X=X, rtol=rtol, atol=atol))
+
+    # Handle edge cases conservatively:
+    # if an item is in both the duplicate set and the kept set, do not remove it
+    common_set = set(dup_map.keys()).intersection(dup_map.values())
+    for k in list(dup_map.keys()):
+        if k in common_set or dup_map[k] in common_set:
+            del dup_map[k]
+
+    if dup_map:
+        dup_indices, kept_indices = zip(*dup_map.items())
+
+        unique_indices = sorted(set(range(n)) - set(dup_indices))
+
+        # After dropping the duplicates,
+        # the kept ones' indices may also change by being shifted up
+        new_idx_map = dict(zip(unique_indices, range(len(unique_indices))))
+        new_indices_for_dup = (new_idx_map[idx] for idx in kept_indices)
+        new_idx_map.update(dict(zip(dup_indices, new_indices_for_dup)))
+        consolidated_X = X[list(unique_indices), :]
+        consolidated_Y = torch.tensor(
+            [[new_idx_map[item.item()] for item in row] for row in Y.unbind()],
+            dtype=torch.long,
+        )
+        return (
+            consolidated_X,
+            consolidated_Y,
+            torch.arange(n, device=Y.device, dtype=torch.long).apply_(
+                lambda x: new_idx_map[x]
+            ),
+        )
+    else:
+        return X, Y, torch.arange(n, device=Y.device, dtype=Y.dtype)
+
+
 class fantasize(_Flag):
     r"""A flag denoting whether we are currently in a `fantasize` context."""
     _state: bool = False
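# --- Illustrative example (editor's sketch, not part of the patch) ---
# How detect_duplicates works: torch.nn.functional.pdist returns the condensed
# (upper-triangular, row-major) pairwise inf-norm distances, which are scattered
# into an n x n matrix via triu_indices; the column-wise min then finds, for
# each point i, its nearest preceding point j, and (i, j) is yielded whenever
# that distance falls below the atol/rtol tolerance.
import torch
from botorch.models.utils import consolidate_duplicates, detect_duplicates

X = torch.tensor(
    [
        [1.0, 2.0, 3.0],
        [2.0, 3.0, 4.0],
        [1.0, 2.0, 3.0],  # duplicate of row 0
    ]
)
Y = torch.tensor([[0, 1], [2, 1]])

print(list(detect_duplicates(X=X)))      # [(2, 0)]
print(consolidate_duplicates(X=X, Y=Y))  # rows 0-1 kept; Y becomes [[0, 1], [0, 1]]
# --- end sketch ---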
diff --git a/test/models/test_pairwise_gp.py b/test/models/test_pairwise_gp.py
index 4c52214c79..cf91a96be9 100644
--- a/test/models/test_pairwise_gp.py
+++ b/test/models/test_pairwise_gp.py
@@ -175,6 +175,27 @@ def test_pairwise_gp(self):
             with self.assertRaises(RuntimeError):
                 model.set_train_data(train_X, changed_train_comp, strict=True)

+            # Test consolidation
+            i1, i2 = train_X.shape[-2], train_X.shape[-2] + 1
+            dup_comp = torch.cat(
+                [train_comp, torch.tensor([[i1, i2]]).expand(*batch_shape, 1, 2)],
+                dim=-2,
+            )
+            dup_X = torch.cat([train_X, train_X[..., :2, :]], dim=-2)
+            model = PairwiseGP(datapoints=dup_X, comparisons=dup_comp)
+            self.assertTrue(dup_X is model.unconsolidated_datapoints)
+            self.assertTrue(dup_comp is model.unconsolidated_comparisons)
+            if batch_shape:
+                self.assertTrue(dup_X is model.consolidated_datapoints)
+                self.assertTrue(dup_comp is model.consolidated_comparisons)
+                self.assertTrue(model.utility is model.unconsolidated_utility)
+            else:
+                self.assertFalse(torch.equal(dup_X, model.consolidated_datapoints))
+                self.assertFalse(torch.equal(dup_comp, model.consolidated_comparisons))
+                self.assertFalse(
+                    torch.equal(model.utility, model.unconsolidated_utility)
+                )
+
     def test_condition_on_observations(self):
         for batch_shape, dtype, likelihood_cls in itertools.product(
             (torch.Size(), torch.Size([2])),
diff --git a/test/models/utils/test_assorted.py b/test/models/utils/test_assorted.py
index 493e360f93..898f1b2ed0 100644
--- a/test/models/utils/test_assorted.py
+++ b/test/models/utils/test_assorted.py
@@ -19,6 +19,8 @@
     multioutput_to_batch_mode_transform,
     validate_input_scaling,
 )
+
+from botorch.models.utils.assorted import consolidate_duplicates, detect_duplicates
 from botorch.utils.testing import BotorchTestCase
 from gpytorch import settings as gpt_settings
@@ -226,3 +228,53 @@ def test_fantasize(self):
             with fantasize(False):
                 self.assertFalse(fantasize.on())
                 self.assertTrue(fantasize.off())
+
+
+class TestConsolidation(BotorchTestCase):
+    def test_consolidation(self):
+        X = torch.tensor(
+            [
+                [1.0, 2.0, 3.0],
+                [2.0, 3.0, 4.0],
+                [1.0, 2.0, 3.0],
+                [3.0, 4.0, 5.0],
+            ]
+        )
+        no_dup_X = torch.tensor(
+            [
+                [1.0, 2.0, 3.0],
+                [2.0, 3.0, 4.0],
+                [3.0, 4.0, 5.0],
+                [4.0, 5.0, 6.0],
+            ]
+        )
+        Y = torch.tensor([[0, 1], [2, 3]])
+        expected_X = torch.tensor(
+            [
+                [1.0, 2.0, 3.0],
+                [2.0, 3.0, 4.0],
+                [3.0, 4.0, 5.0],
+            ]
+        )
+        expected_Y = torch.tensor([[0, 1], [0, 2]])
+
+        # deduped case
+        consolidated_X, consolidated_Y, new_indices = consolidate_duplicates(X=X, Y=Y)
+        self.assertTrue(torch.equal(consolidated_X, expected_X))
+        self.assertTrue(torch.equal(consolidated_Y, expected_Y))
+        self.assertTrue(torch.equal(new_indices, torch.tensor([0, 1, 0, 2])))
+
+        # not deduped case
+        consolidated_X, consolidated_Y, new_indices = consolidate_duplicates(
+            X=no_dup_X, Y=Y
+        )
+        self.assertTrue(torch.equal(consolidated_X, no_dup_X))
+        self.assertTrue(torch.equal(consolidated_Y, Y))
+        self.assertTrue(torch.equal(new_indices, torch.tensor([0, 1, 2, 3])))
+
+        # test batch shape
+        with self.assertRaises(ValueError):
+            consolidate_duplicates(X=X.repeat(2, 1, 1), Y=Y.repeat(2, 1, 1))
+
+        with self.assertRaises(ValueError):
+            detect_duplicates(X=X.repeat(2, 1, 1))