Commit 531c6a9

Merge 4bf5221 into ff56d20

ItsMrLin authored Mar 22, 2023
2 parents ff56d20 + 4bf5221 commit 531c6a9
Showing 5 changed files with 362 additions and 59 deletions.
199 changes: 141 additions & 58 deletions botorch/models/pairwise_gp.py
@@ -34,6 +34,7 @@
)
from botorch.models.model import FantasizeMixin, Model
from botorch.models.transforms.input import InputTransform
from botorch.models.utils.assorted import consolidate_duplicates
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from gpytorch import settings
@@ -53,6 +54,38 @@
from torch.nn.modules.module import _IncompatibleKeys


# Helper functions
def _check_strict_input(
    inputs: List[Tensor], t_inputs: List[Tensor], target_or_inputs: str
) -> None:
for input_, t_input in zip(inputs, t_inputs or (None,)):
for attr in {"shape", "dtype", "device"}:
expected_attr = getattr(t_input, attr, None)
found_attr = getattr(input_, attr, None)
if expected_attr != found_attr:
msg = (
"Cannot modify {attr} of {t_or_i} "
"(expected {e_attr}, found {f_attr})."
)
msg = msg.format(
attr=attr,
e_attr=expected_attr,
f_attr=found_attr,
t_or_i=target_or_inputs,
)
raise RuntimeError(msg)
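
A quick sketch of what this helper enforces, with hypothetical tensors and a direct call to the module-private function: replacing training inputs with a tensor whose shape, dtype, or device differs raises a RuntimeError.

import torch

t_inputs = [torch.rand(4, 2)]    # existing training inputs
new_inputs = [torch.rand(4, 3)]  # mismatched shape

try:
    _check_strict_input(new_inputs, t_inputs, "inputs")
except RuntimeError as e:
    print(e)  # Cannot modify shape of inputs (expected torch.Size([4, 2]), found torch.Size([4, 3])).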


def _scaled_psd_safe_cholesky(
M: Tensor, scale: Union[float, Tensor], jitter: Optional[float] = None
) -> Tensor:
r"""scale M by 1/outputscale before cholesky for better numerical stability"""
M = M / scale
chol = psd_safe_cholesky(M, jitter=jitter)
chol = chol * scale.sqrt()
return chol


# Why we subclass GP even though it provides no functionality:
# if this subclassing is removed, we get the following GPyTorch error:
# "RuntimeError: All MarginalLogLikelihood objects must be given a GP object as
@@ -77,7 +110,6 @@ class PairwiseGP(Model, GP, FantasizeMixin):
In the example below, the user/decision maker has stated that they prefer
the first item over the second item and the third item over the second item,
generating comparisons [0, 1] and [2, 1].
Example:
>>> from botorch.models import PairwiseGP
>>> import torch
@@ -87,8 +119,8 @@ class PairwiseGP(Model, GP, FantasizeMixin):
"""

_buffer_names = [
"datapoints",
"comparisons",
"consolidated_datapoints",
"consolidated_comparisons",
"D",
"DT",
"utility",
@@ -97,6 +129,9 @@
"hlcov_eye",
"covar",
"covar_inv",
"unconsolidated_datapoints",
"unconsolidated_comparisons",
"consolidated_indices",
]

def __init__(
@@ -121,6 +156,20 @@ def __init__(
"""
super().__init__()

# Set optional parameters
# Explicitly set jitter for numerical stability in psd_safe_cholesky
self._jitter = kwargs.get("jitter", 1e-6)
        # Stopping criterion for scipy.optimize.fsolve, used to find f_map in _update()
# If None, set to 1e-6 by default in _update
self._xtol = kwargs.get("xtol")
        # rtol and atol tolerances for consolidate_duplicates
self._consolidate_rtol = kwargs.get("consolidate_rtol", 0)
self._consolidate_atol = kwargs.get("consolidate_atol", 1e-4)
# The maximum number of calls to the function in scipy.optimize.fsolve
# If None, set to 100 by default in _update
        # If zero, fsolve uses its own default of 100*(N+1)
self._maxfev = kwargs.get("maxfev")
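
As a construction-time sketch (hypothetical data; the keyword names are exactly those read above via kwargs):

import torch
from botorch.models import PairwiseGP

datapoints = torch.rand(5, 2, dtype=torch.float64)
comparisons = torch.tensor([[0, 1], [2, 3], [4, 1]])

model = PairwiseGP(
    datapoints,
    comparisons,
    jitter=1e-6,            # diagonal jitter for psd_safe_cholesky
    xtol=1e-6,              # fsolve stopping criterion for the f_map search
    consolidate_atol=1e-4,  # absolute tolerance for duplicate detection
    maxfev=100,             # cap on fsolve function evaluations
)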

if input_transform is not None:
input_transform.to(datapoints)
# input transformation is applied in set_train_data
@@ -142,24 +191,18 @@ def __init__(

self.pred_cov_fac_need_update = True
self.dim = None
self.unconsolidated_datapoints = None
self.unconsolidated_comparisons = None
self.consolidated_datapoints = None
self.consolidated_comparisons = None
self.consolidated_indices = None

# See set_train_data for additional compatibility variables.
        # Note that the datapoints here are not transformed even if input_transform
# is not None to avoid double transformation during model fitting.
# self.transform_inputs is called in `forward`
self.set_train_data(datapoints, comparisons, update_model=False)

# Set optional parameters
# Explicitly set jitter for numerical stability in psd_safe_cholesky
self._jitter = kwargs.get("jitter", 1e-6)
        # Stopping criterion for scipy.optimize.fsolve, used to find f_map in _update()
# If None, set to 1e-6 by default in _update
self._xtol = kwargs.get("xtol")
# The maximum number of calls to the function in scipy.optimize.fsolve
# If None, set to 100 by default in _update
        # If zero, fsolve uses its own default of 100*(N+1)
self._maxfev = kwargs.get("maxfev")

# Set hyperparameters
        # Do not set the batch_shape explicitly so mean_module can operate in both modes;
        # once the fsolve used in _update can run in batch mode, we should explicitly set
@@ -202,22 +245,26 @@ def __init__(
if self.datapoints is not None and self.comparisons is not None:
self.to(dtype=self.datapoints.dtype, device=self.datapoints.device)
# Find f_map for initial parameters with transformed datapoints
transformed_dp = self.transform_inputs(datapoints)
transformed_dp = self.transform_inputs(self.datapoints)
self._update(transformed_dp)

self.to(self.datapoints)

def __deepcopy__(self, memo) -> PairwiseGP:
attrs = (
"datapoints",
"comparisons",
"consolidated_datapoints",
"consolidated_comparisons",
"covar",
"covar_inv",
"covar_chol",
"likelihood_hess",
"utility",
"hlcov_eye",
"unconsolidated_datapoints",
"unconsolidated_comparisons",
"consolidated_indices",
)

if any(getattr(self, attr) is not None for attr in attrs):
# Temporarily remove non-leaf tensors so that pytorch allows deepcopy
old_attr = {}
@@ -237,16 +284,6 @@ def __deepcopy__(self, memo) -> PairwiseGP:
self.__deepcopy__ = dcp
return new_model

def _scaled_psd_safe_cholesky(
self, M: Tensor, jitter: Optional[float] = None
) -> Tensor:
r"""scale M by 1/outputscale before cholesky for better numerical stability"""
scale = self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1)
M = M / scale
chol = psd_safe_cholesky(M, jitter=jitter)
chol = chol * scale.sqrt()
return chol

def _has_no_data(self):
r"""Return true if the model does not have both datapoints and comparisons"""
return (
@@ -269,8 +306,10 @@ def _update_covar(self, datapoints: Tensor) -> None:
datapoints: (Transformed) datapoints for finding f_max
"""
self.covar = self._calc_covar(datapoints, datapoints)
self.covar_chol = self._scaled_psd_safe_cholesky(
self.covar, jitter=self._jitter
self.covar_chol = _scaled_psd_safe_cholesky(
M=self.covar,
scale=self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1),
jitter=self._jitter,
)
self.covar_inv = torch.cholesky_inverse(self.covar_chol)
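
Because the scaled factor is multiplied back by scale.sqrt(), covar_chol is a true Cholesky factor of covar, so torch.cholesky_inverse recovers the exact inverse. A standalone sketch of that identity with a hypothetical matrix:

import torch

M = torch.tensor([[2.0, 1.0], [1.0, 2.0]])
L = torch.linalg.cholesky(M)
assert torch.allclose(torch.cholesky_inverse(L), torch.inverse(M), atol=1e-6)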

@@ -587,25 +626,61 @@ def _util_newton_updates(self, dp, x0, max_iter=1, xtol=None) -> Tensor:

return x

def _check_strict_input(self, inputs, t_inputs, target_or_inputs):
for input_, t_input in zip(inputs, t_inputs or (None,)):
for attr in {"shape", "dtype", "device"}:
expected_attr = getattr(t_input, attr, None)
found_attr = getattr(input_, attr, None)
if expected_attr != found_attr:
msg = (
"Cannot modify {attr} of {t_or_i} "
"(expected {e_attr}, found {f_attr})."
)
msg = msg.format(
attr=attr,
e_attr=expected_attr,
f_attr=found_attr,
t_or_i=target_or_inputs,
)
raise RuntimeError(msg)
def _consolidate_duplicates(
self, datapoints: Tensor, comparisons: Tensor
) -> Tuple[Tensor, Tensor]:
"""Consolidate and cache datapoints and comparisons"""
# check if consolidated datapoints/comparisons are cached
if (
(datapoints is not self.unconsolidated_datapoints)
or (comparisons is not self.unconsolidated_comparisons)
or (self.consolidated_datapoints is None)
or (self.consolidated_comparisons is None)
):
self.unconsolidated_datapoints, self.unconsolidated_comparisons = (
datapoints,
comparisons,
)

if len(datapoints.shape) > 2 or len(comparisons.shape) > 2:
# Do not perform consolidation in batch mode as block design
# cannot be guaranteed
self.consolidated_datapoints = datapoints
self.consolidated_comparisons = comparisons
self.consolidated_indices = None
else:
(
self.consolidated_datapoints,
self.consolidated_comparisons,
self.consolidated_indices,
) = consolidate_duplicates(
datapoints,
comparisons,
rtol=self._consolidate_rtol,
atol=self._consolidate_atol,
)

return self.consolidated_datapoints, self.consolidated_comparisons
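
For intuition, consolidation deduplicates rows of datapoints that match within the given tolerances and remaps comparison indices onto the surviving rows. The sketch below is hypothetical and assumes the three-value return convention visible in this diff (consolidated datapoints, remapped comparisons, and a per-row index map):

import torch
from botorch.models.utils import consolidate_duplicates

# rows 0 and 2 agree within atol=1e-4, so they consolidate to one row
datapoints = torch.tensor([[0.1, 0.2], [0.5, 0.9], [0.1, 0.2]])
comparisons = torch.tensor([[0, 1], [2, 1]])

cons_dp, cons_comp, cons_idx = consolidate_duplicates(
    datapoints, comparisons, rtol=0, atol=1e-4
)
# expected: cons_dp has two rows, both comparisons now reference row 0,
# and cons_idx maps original rows to consolidated ones, e.g. [0, 1, 0]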

# ============== public APIs ==============
@property
def datapoints(self) -> Tensor:
r"""Alias for consolidated datapoints"""
return self.consolidated_datapoints

@property
def comparisons(self) -> Tensor:
r"""Alias for consolidated comparisons"""
return self.consolidated_comparisons

@property
def unconsolidated_utility(self) -> Tensor:
r"""Utility of the unconsolidated datapoints"""
if self.consolidated_indices is None:
# self.consolidated_indices is None in batch mode
return self.utility
else:
return self.utility[self.consolidated_indices]
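
Continuing the hypothetical example above: indexing the consolidated utilities with consolidated_indices broadcasts them back onto the original rows, so duplicated inputs report identical utility values.

utility = torch.tensor([0.3, -0.1])              # one value per consolidated row
consolidated_indices = torch.tensor([0, 1, 0])
unconsolidated = utility[consolidated_indices]   # tensor([0.3000, -0.1000, 0.3000])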

@property
def num_outputs(self) -> int:
@@ -655,30 +730,33 @@ def set_train_data(
# following gpytorch.models.exact_gp.set_train_data
if datapoints is not None:
if torch.is_tensor(datapoints):
inputs = (datapoints,)
inputs = [datapoints]

inputs = tuple(
input_.unsqueeze(-1) if input_.ndimension() == 1 else input_
for input_ in inputs
)
if strict:
self._check_strict_input(inputs, self.train_inputs, "inputs")
_check_strict_input(inputs, self.train_inputs, "inputs")

self.datapoints = inputs[0]
datapoints = inputs[0]
# Compatibility variables with fit_gpytorch_*
# alias for datapoints ("train_inputs")
self.train_inputs = inputs

if comparisons is not None:
if strict:
self._check_strict_input([comparisons], [self.train_targets], "targets")
_check_strict_input([comparisons], [self.train_targets], "targets")

# convert to long so that it can be used as index and
# compatible with Tensor.scatter_
self.comparisons = comparisons.long()
comparisons = comparisons.long()
# Compatibility variables with fit_gpytorch_*
# alias for comparisons ("train_targets" here)
self.train_targets = self.comparisons
self.train_targets = comparisons

# self.datapoints and self.comparisons are being updated here
self._consolidate_duplicates(datapoints, comparisons)

# Compatibility variables with optimize_acqf
self._dtype = self.datapoints.dtype
Expand All @@ -704,7 +782,7 @@ def set_train_data(
self.DT = self.D.transpose(-1, -2)

if update_model:
transformed_dp = self.transform_inputs(datapoints)
transformed_dp = self.transform_inputs(self.datapoints)
self._update(transformed_dp)

self.to(self.datapoints)
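
As a usage sketch (hypothetical tensors; the strict and update_model flags as read above), appending a new comparison to a model that currently holds three datapoints and refitting the MAP estimate in one call might look like:

new_dp = torch.cat(
    [model.unconsolidated_datapoints, torch.rand(1, 2, dtype=torch.float64)]
)
# row index 3 refers to the newly appended datapoint
new_comp = torch.cat([model.unconsolidated_comparisons, torch.tensor([[3, 1]])])
model.set_train_data(new_dp, new_comp, strict=False, update_model=True)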
@@ -780,18 +858,19 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
"or call .set_train_data() to add training data."
)

if datapoints is not self.datapoints:
if datapoints is not self.unconsolidated_datapoints:
raise RuntimeError("Must train on training data")

transformed_dp = self.transform_inputs(datapoints)

        # We pass the untransformed datapoints into set_train_data
        # because self.datapoints is kept as the untransformed datapoints;
        # self.transform_inputs will be called inside before calling _update()
self.set_train_data(datapoints, self.comparisons, update_model=True)

transformed_dp = self.transform_inputs(self.datapoints)

hl = self.likelihood_hess
covar = self.covar

# Apply matrix inversion lemma on eq. in page 27 of [Brochu2010tutorial]_
# (A + B)^-1 = A^-1 - A^-1 @ (I + BA^-1)^-1 @ BA^-1
# where A = covar_inv, B = hl
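
The identity in this comment can be checked numerically, with generic PSD matrices standing in for covar_inv and the likelihood Hessian:

import torch

torch.manual_seed(0)
X, Y = torch.randn(3, 3), torch.randn(3, 3)
A = X @ X.T + torch.eye(3)  # stands in for covar_inv
B = Y @ Y.T                 # stands in for hl

A_inv = torch.inverse(A)
lhs = torch.inverse(A + B)
rhs = A_inv - A_inv @ torch.inverse(torch.eye(3) + B @ A_inv) @ (B @ A_inv)
assert torch.allclose(lhs, rhs, atol=1e-5)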
@@ -854,7 +933,11 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
# output_covar is sometimes non-PSD
# perform a cholesky decomposition to check and amend
covariance_matrix=RootLinearOperator(
self._scaled_psd_safe_cholesky(output_covar, jitter=self._jitter)
_scaled_psd_safe_cholesky(
output_covar,
scale=self.covar_module.outputscale.unsqueeze(-1).unsqueeze(-1),
jitter=self._jitter,
)
),
)
return post
@@ -975,7 +1058,7 @@ def forward(self, post: Posterior, comp: Tensor) -> Tensor:

model = self.model
likelihood = self.likelihood
if comp is not model.comparisons:
if comp is not model.unconsolidated_comparisons:
raise RuntimeError("Must train on training data")

f_map = post.mean.squeeze(-1)
4 changes: 4 additions & 0 deletions botorch/models/utils/__init__.py
@@ -10,6 +10,8 @@
check_min_max_scaling,
check_no_nans,
check_standardization,
consolidate_duplicates,
detect_duplicates,
fantasize,
gpt_posterior_settings,
mod_batch_shape,
@@ -32,4 +34,6 @@
"multioutput_to_batch_mode_transform",
"mod_batch_shape",
"validate_input_scaling",
"detect_duplicates",
"consolidate_duplicates",
]