Skip to content

Commit

Permalink
pass X to OutcomeTransform (pytorch#2663)
Browse files Browse the repository at this point in the history
Summary:

This enables using outcome transforms with behavior that depends on X. For example, this enables implementing a stratified standardize transform, where the the standardization is performing for distinct values of an input dimension.

Differential Revision: D67724473
  • Loading branch information
sdaulton authored and facebook-github-bot committed Dec 31, 2024
1 parent 466da73 commit 8a6fb1a
Show file tree
Hide file tree
Showing 12 changed files with 85 additions and 46 deletions.
2 changes: 1 addition & 1 deletion botorch/acquisition/analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,7 @@ def _get_noiseless_fantasy_model(
# Not transforming Yvar because 1e-7 is already close to 0 and it is a
# relative, not absolute, value.
Y_fantasized, _ = outcome_transform(
Y_fantasized.unsqueeze(-1), Yvar.unsqueeze(-1)
Y_fantasized.unsqueeze(-1), Yvar.unsqueeze(-1), X=batch_X_observed
)
Y_fantasized = Y_fantasized.squeeze(-1)
input_transform = getattr(model, "input_transform", None)
Expand Down
4 changes: 2 additions & 2 deletions botorch/models/approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def posterior(

posterior = GPyTorchPosterior(distribution=dist)
if hasattr(self, "outcome_transform"):
posterior = self.outcome_transform.untransform_posterior(posterior)
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
if posterior_transform is not None:
posterior = posterior_transform(posterior)
return posterior
Expand Down Expand Up @@ -397,7 +397,7 @@ def __init__(
UserInputWarning,
stacklevel=3,
)
train_Y, _ = outcome_transform(train_Y)
train_Y, _ = outcome_transform(train_Y, X=transformed_X)
self._validate_tensor_args(X=transformed_X, Y=train_Y)
validate_input_scaling(train_X=transformed_X, train_Y=train_Y)
if train_Y.shape[-1] != num_outputs:
Expand Down
2 changes: 1 addition & 1 deletion botorch/models/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def posterior(
# `posterior` (as is done in GP models). This is more general since it works
# even if the transform doesn't support `untransform_posterior`.
if hasattr(self, "outcome_transform"):
values, _ = self.outcome_transform.untransform(values)
values, _ = self.outcome_transform.untransform(values, X=X)
if output_indices is not None:
values = values[..., output_indices]
posterior = EnsemblePosterior(values=values)
Expand Down
4 changes: 3 additions & 1 deletion botorch/models/fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,9 @@ def __init__(
X=train_X, input_transform=input_transform
)
if outcome_transform is not None:
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
train_Y, train_Yvar = outcome_transform(
train_Y, train_Yvar, X=transformed_X
)
self._validate_tensor_args(X=transformed_X, Y=train_Y)
validate_input_scaling(
train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar
Expand Down
4 changes: 3 additions & 1 deletion botorch/models/fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ def __init__(
)
if outcome_transform is not None:
outcome_transform.train() # Ensure we learn parameters here on init
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
train_Y, train_Yvar = outcome_transform(
train_Y, train_Yvar, X=transformed_X
)
if train_Yvar is not None: # Clamp after transforming
train_Yvar = train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL)

Expand Down
4 changes: 3 additions & 1 deletion botorch/models/gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def __init__(
X=train_X, input_transform=input_transform
)
if outcome_transform is not None:
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
train_Y, train_Yvar = outcome_transform(
train_Y, train_Yvar, X=transformed_X
)
# Validate again after applying the transforms
self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
ignore_X_dims = getattr(self, "_ignore_X_dims_scaling_check", None)
Expand Down
10 changes: 5 additions & 5 deletions botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def posterior(
mvn = self.likelihood(mvn, X)
posterior = GPyTorchPosterior(distribution=mvn)
if hasattr(self, "outcome_transform"):
posterior = self.outcome_transform.untransform_posterior(posterior)
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
if posterior_transform is not None:
return posterior_transform(posterior)
return posterior
Expand Down Expand Up @@ -244,7 +244,7 @@ def condition_on_observations(
# (unless we've already trasnformed if BatchedMultiOutputGPyTorchModel)
if not isinstance(self, BatchedMultiOutputGPyTorchModel):
# `noise` is assumed to already be outcome-transformed.
Y, _ = self.outcome_transform(Y=Y, Yvar=Yvar)
Y, _ = self.outcome_transform(Y=Y, Yvar=Yvar, X=X)
# Validate using strict=False, since we cannot tell if Y has an explicit
# output dimension. Do not check shapes when fantasizing as they are
# not expected to match.
Expand Down Expand Up @@ -467,7 +467,7 @@ def posterior(

posterior = GPyTorchPosterior(distribution=mvn)
if hasattr(self, "outcome_transform"):
posterior = self.outcome_transform.untransform_posterior(posterior)
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
if posterior_transform is not None:
return posterior_transform(posterior)
return posterior
Expand Down Expand Up @@ -511,7 +511,7 @@ def condition_on_observations(
if hasattr(self, "outcome_transform"):
# We need to apply transforms before shifting batch indices around.
# `noise` is assumed to already be outcome-transformed.
Y, _ = self.outcome_transform(Y)
Y, _ = self.outcome_transform(Y, X=X)
# Do not check shapes when fantasizing as they are not expected to match.
if fantasize_flag.off():
self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)
Expand Down Expand Up @@ -924,7 +924,7 @@ def posterior(
)
posterior = GPyTorchPosterior(distribution=mtmvn)
if hasattr(self, "outcome_transform"):
posterior = self.outcome_transform.untransform_posterior(posterior)
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
if posterior_transform is not None:
return posterior_transform(posterior)
return posterior
Expand Down
12 changes: 6 additions & 6 deletions botorch/models/higher_order_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _return_to_output_shape(self, tsr: Tensor) -> Tensor:
return out

def forward(
self, Y: Tensor, Yvar: Tensor | None = None
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
) -> tuple[Tensor, Tensor | None]:
Y = self._squeeze_to_single_output(Y)
if Yvar is not None:
Expand All @@ -107,7 +107,7 @@ def forward(
return Y_out, Yvar_out

def untransform(
self, Y: Tensor, Yvar: Tensor | None = None
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
) -> tuple[Tensor, Tensor | None]:
Y = self._squeeze_to_single_output(Y)
if Yvar is not None:
Expand All @@ -121,7 +121,7 @@ def untransform(
return Y, Yvar

def untransform_posterior(
self, posterior: HigherOrderGPPosterior
self, posterior: HigherOrderGPPosterior, X: Tensor | None = None
) -> TransformedPosterior:
# TODO: return a HigherOrderGPPosterior once rescaling constant
# muls * LinearOperators won't force a dense decomposition rather than a
Expand Down Expand Up @@ -227,7 +227,7 @@ def __init__(
output_shape=train_Y.shape[-num_output_dims:],
batch_shape=batch_shape,
)
train_Y, _ = outcome_transform(train_Y)
train_Y, _ = outcome_transform(train_Y, X=train_X)

self._aug_batch_shape = batch_shape
self._num_dimensions = num_output_dims + 1
Expand Down Expand Up @@ -416,7 +416,7 @@ def condition_on_observations(
"""
if hasattr(self, "outcome_transform"):
# we need to apply transforms before shifting batch indices around
Y, noise = self.outcome_transform(Y=Y, Yvar=noise)
Y, noise = self.outcome_transform(Y=Y, Yvar=noise, X=X)
# Do not check shapes when fantasizing as they are not expected to match.
if fantasize_flag.off():
self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)
Expand Down Expand Up @@ -540,7 +540,7 @@ def posterior(
num_outputs=self._num_outputs,
)
if hasattr(self, "outcome_transform"):
posterior = self.outcome_transform.untransform_posterior(posterior)
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
return posterior

def make_posterior_variances(
Expand Down
9 changes: 6 additions & 3 deletions botorch/models/latent_kronecker_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
self._use_min = use_min

def forward(
self, Y: Tensor, Yvar: Tensor | None = None
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
) -> tuple[Tensor, Tensor | None]:
r"""Standardize outcomes.
Expand All @@ -93,6 +93,7 @@ def forward(
Y: A `batch_shape x n x m`-dim tensor of training targets.
Yvar: A `batch_shape x n x m`-dim tensor of observation noises
associated with the training targets (if applicable).
X: A `batch_shape x n x d`-dim tensor of training inputs (if applicable).
Returns:
A two-tuple with the transformed outcomes:
Expand Down Expand Up @@ -240,7 +241,9 @@ def __init__(
outcome_transform = MinMaxStandardize(batch_shape=batch_shape)
if outcome_transform is not None:
# transform outputs once and keep the results
train_Y = outcome_transform(train_Y.unsqueeze(-1))[0].squeeze(-1)
train_Y = outcome_transform(train_Y.unsqueeze(-1), X=transformed_X)[
0
].squeeze(-1)

ExactGP.__init__(
self,
Expand Down Expand Up @@ -506,7 +509,7 @@ def _rsample_from_base_samples(
)
# samples.shape = (*sample_shape, *broadcast_shape, n_test_x, n_t)
if hasattr(self, "outcome_transform") and self.outcome_transform is not None:
samples, _ = self.outcome_transform.untransform(samples)
samples, _ = self.outcome_transform.untransform(samples, X=X)
return samples

def condition_on_observations(
Expand Down
2 changes: 1 addition & 1 deletion botorch/models/model_list_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def condition_on_observations(
else:
noise_i = torch.cat([noise[..., k] for k in range(i, j)], dim=-1)
if hasattr(model, "outcome_transform"):
y_i, noise_i = model.outcome_transform(y_i, noise_i)
y_i, noise_i = model.outcome_transform(y_i, noise_i, X=X_i)
if noise_i is not None:
noise_i = noise_i.squeeze(0)
targets.append(y_i)
Expand Down
8 changes: 5 additions & 3 deletions botorch/models/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ def __init__(
if outcome_transform == DEFAULT:
outcome_transform = Standardize(m=1, batch_shape=train_X.shape[:-2])
if outcome_transform is not None:
train_Y, train_Yvar = outcome_transform(Y=train_Y, Yvar=train_Yvar)
train_Y, train_Yvar = outcome_transform(
Y=train_Y, Yvar=train_Yvar, X=transformed_X
)

# squeeze output dim
train_Y = train_Y.squeeze(-1)
Expand Down Expand Up @@ -464,7 +466,7 @@ def __init__(
X=train_X, input_transform=input_transform
)
if outcome_transform is not None:
train_Y, _ = outcome_transform(train_Y)
train_Y, _ = outcome_transform(train_Y, X=transformed_X)

self._validate_tensor_args(X=transformed_X, Y=train_Y)
self._num_outputs = train_Y.shape[-1]
Expand Down Expand Up @@ -772,7 +774,7 @@ def posterior(
)

if hasattr(self, "outcome_transform"):
posterior = self.outcome_transform.untransform_posterior(posterior)
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
return posterior

def train(self, val=True, *args, **kwargs):
Expand Down
Loading

0 comments on commit 8a6fb1a

Please sign in to comment.