Fixed condition_on_observations in fully Bayesian models (#2151)
Summary:
## Motivation

Support conditioning on observations in fully Bayesian models, which enables fully Bayesian JES (and possibly KG).

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #2151

Test Plan:
Tests are written to ensure functionality for both inferred and fixed noise. __Note that the `_aug_batch_shape` attribute assignment was removed in `condition_on_observations`.__ For fully Bayesian GPs, this attribute could not be assigned (hence the removal). I could not find a use for it, and all tests pass with it removed.

Other changes are commented throughout and were made to ensure that FBGPs keep a single set of training data. However, conditioning on observations adds a batch dimension to the training data, which GPyTorch needs [here](https://github.com/cornellius-gp/gpytorch/blob/58c033564d28a5537397bc464827783313534e56/gpytorch/models/exact_gp.py#L176) to infer the correct batch dimension.
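
As a usage illustration (not part of the diff), here is a minimal sketch of what this enables. The data sizes and NUTS settings below are placeholders chosen for a quick run, and the expected shapes mirror the new unit tests:

```python
import torch

from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

# Toy data: 10 training points in 4 dimensions, single outcome.
train_X = torch.rand(10, 4, dtype=torch.double)
train_Y = torch.sin(train_X[:, :1])

model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=8, disable_progbar=True
)
num_models = model.batch_shape[0]  # number of retained MCMC samples

# A posterior (forward) pass is needed before conditioning, as in the new tests.
model.posterior(train_X)

# One (unbatched) set of new observations, shared across all sampled models;
# the conditioned model holds `num_models` copies of the augmented training data.
new_X = torch.rand(2, 4, dtype=torch.double)
new_Y = torch.sin(new_X[:, :1])
cond_model = model.condition_on_observations(new_X, new_Y)
assert cond_model.train_inputs[0].shape == torch.Size([num_models, 12, 4])

# Per-model observations are passed with an explicit leading batch dimension.
batch_X = torch.rand(num_models, 2, 4, dtype=torch.double)
batch_Y = torch.sin(batch_X[..., :1])
cond_model = model.condition_on_observations(batch_X, batch_Y)
assert cond_model.train_inputs[0].shape == torch.Size([num_models, 12, 4])
```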

Reviewed By: dme65

Differential Revision: D52256296

Pulled By: saitcakmak

fbshipit-source-id: e340897d76e02c32ef7a981bef8a77c49e030ad1
hvarfner authored and facebook-github-bot committed Dec 27, 2023
1 parent 967535f commit 0c37aac
Showing 3 changed files with 158 additions and 6 deletions.
41 changes: 37 additions & 4 deletions botorch/models/fully_bayesian.py
@@ -498,9 +498,8 @@ def forward(self, X: Tensor) -> MultivariateNormal:
rest of this method will not run.
"""
self._check_if_fitted()
-x = X.unsqueeze(MCMC_DIM)
-mean_x = self.mean_module(x)
-covar_x = self.covar_module(x)
+mean_x = self.mean_module(X)
+covar_x = self.covar_module(X)
return MultivariateNormal(mean_x, covar_x)

# pyre-ignore[14]: Inconsistent override
@@ -534,11 +533,45 @@ def posterior(
"""
self._check_if_fitted()
posterior = super().posterior(
-X=X,
+X=X.unsqueeze(MCMC_DIM),
output_indices=output_indices,
observation_noise=observation_noise,
posterior_transform=posterior_transform,
**kwargs,
)
posterior = GaussianMixturePosterior(distribution=posterior.distribution)
return posterior

def condition_on_observations(
self, X: Tensor, Y: Tensor, **kwargs: Any
) -> BatchedMultiOutputGPyTorchModel:
"""Conditions on additional observations for a Fully Bayesian model (either
identical across models or unique per-model).
Args:
X: A `batch_shape x num_samples x d`-dim Tensor, where `d` is
the dimension of the feature space and `batch_shape` is the number of
sampled models.
Y: A `batch_shape x num_samples x 1`-dim Tensor of observed outcomes,
where `batch_shape` is the number of sampled models.
Returns:
BatchedMultiOutputGPyTorchModel: A fully Bayesian model conditioned on
given observations. The returned model has `batch_shape` copies of the
training data in case of identical observations (and `batch_shape`
training datasets otherwise).
"""
if X.ndim == 2 and Y.ndim == 2:
# To avoid an error in GPyTorch when inferring the batch dimension, we add
# the explicit batch shape here. The result is that the conditioned model
# will have 'batch_shape' copies of the training data.
X = X.repeat(self.batch_shape + (1, 1))
Y = Y.repeat(self.batch_shape + (1, 1))

elif X.ndim < Y.ndim:
# We need to duplicate the training data to enable correct batch
# size inference in gpytorch.
X = X.repeat(*(Y.shape[:-2] + (1, 1)))

return super().condition_on_observations(X, Y, **kwargs)
7 changes: 5 additions & 2 deletions botorch/models/gpytorch.py
@@ -226,7 +226,8 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
>>> new_Y = torch.sin(new_X[:, 0]) + torch.cos(new_X[:, 1])
>>> model = model.condition_on_observations(X=new_X, Y=new_Y)
"""
-Yvar = kwargs.get("noise", None)
+Yvar = kwargs.pop("noise", None)

if hasattr(self, "outcome_transform"):
# pass the transformed data to get_fantasy_model below
# (unless we've already transformed it, as in BatchedMultiOutputGPyTorchModel)
@@ -242,6 +243,7 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
kwargs.update({"noise": Yvar.squeeze(-1)})
# get_fantasy_model will properly copy any existing outcome transforms
# (since it deepcopies the original model)

return self.get_fantasy_model(inputs=X, targets=Y, **kwargs)


@@ -492,7 +494,8 @@ def condition_on_observations(
fantasy_model._input_batch_shape = fantasy_model.train_targets.shape[
: (-1 if self._num_outputs == 1 else -2)
]
-fantasy_model._aug_batch_shape = fantasy_model.train_targets.shape[:-1]
+if not self._is_fully_bayesian:
+    fantasy_model._aug_batch_shape = fantasy_model.train_targets.shape[:-1]
return fantasy_model

def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
116 changes: 116 additions & 0 deletions test/models/test_fully_bayesian.py
@@ -124,6 +124,18 @@ def _get_unnormalized_data(self, infer_noise: bool, **tkwargs):
)
return train_X, train_Y, train_Yvar, test_X

def _get_unnormalized_condition_data(
self, num_models: int, num_cond: int, infer_noise: bool, **tkwargs
):
with torch.random.fork_rng():
torch.manual_seed(0)
cond_X = 5 + 5 * torch.rand(num_models, num_cond, 4, **tkwargs)
cond_Y = 10 + torch.sin(cond_X[..., :1])
cond_Yvar = (
None if infer_noise else 0.1 * torch.ones(cond_Y.shape, **tkwargs)
)
return cond_X, cond_Y, cond_Yvar

def _get_mcmc_samples(
self, num_samples: int, dim: int, infer_noise: bool, **tkwargs
):
@@ -656,6 +668,110 @@ def test_custom_pyro_model(self):
atol=5e-4,
)

def test_condition_on_observation(self):
# The following conditioned data shapes should work (the last column is the
# training data shape after conditioning; gpytorch requires the batch shape
# in the output):
# X: num_models x n x d, Y: num_models x n x 1 --> num_models x n x d
# X: n x d, Y: n x 1 --> num_models x n x d
# X: n x d, Y: num_models x n x 1 --> num_models x n x d
num_models = 3
num_cond = 2
for infer_noise, dtype in itertools.product(
(True, False), (torch.float, torch.double)
):
tkwargs = {"device": self.device, "dtype": dtype}
train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data(
infer_noise=infer_noise, **tkwargs
)
num_train, num_dims = train_X.shape
# condition on different observations per model to obtain num_models sets
# of training data
cond_X, cond_Y, cond_Yvar = self._get_unnormalized_condition_data(
num_models=num_models,
num_cond=num_cond,
infer_noise=infer_noise,
**tkwargs
)
model = SaasFullyBayesianSingleTaskGP(
train_X=train_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
)
mcmc_samples = self._get_mcmc_samples(
num_samples=num_models,
dim=train_X.shape[-1],
infer_noise=infer_noise,
**tkwargs
)
model.load_mcmc_samples(mcmc_samples)

# need to forward pass before conditioning
model.posterior(train_X)
cond_model = model.condition_on_observations(
cond_X, cond_Y, noise=cond_Yvar
)
posterior = cond_model.posterior(test_X)
self.assertEqual(
posterior.mean.shape, torch.Size([num_models, len(test_X), 1])
)

# since the data is not equal for the conditioned points, a batch size
# is added to the training data
self.assertEqual(
cond_model.train_inputs[0].shape,
torch.Size([num_models, num_train + num_cond, num_dims]),
)

# the batch shape of the condition model is added during conditioning
self.assertEqual(cond_model.batch_shape, torch.Size([num_models]))

# condition on identical sets of data (i.e. one set) for all models,
# i.e., with no batch shape. The batch shape is then inferred.
cond_X_nobatch, cond_Y_nobatch = cond_X[0], cond_Y[0]
model = SaasFullyBayesianSingleTaskGP(
train_X=train_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
)
mcmc_samples = self._get_mcmc_samples(
num_samples=num_models,
dim=train_X.shape[-1],
infer_noise=infer_noise,
**tkwargs
)
model.load_mcmc_samples(mcmc_samples)

# conditioning without a batch size - the resulting conditioned model
# will still have a batch size
model.posterior(train_X)
cond_model = model.condition_on_observations(
cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
)
self.assertEqual(
cond_model.train_inputs[0].shape,
torch.Size([num_models, num_train + num_cond, num_dims]),
)

# test repeated conditioning
repeat_cond_X = cond_X + 5
repeat_cond_model = cond_model.condition_on_observations(
repeat_cond_X, cond_Y, noise=cond_Yvar
)
self.assertEqual(
repeat_cond_model.train_inputs[0].shape,
torch.Size([num_models, num_train + 2 * num_cond, num_dims]),
)

# test repeated conditioning without a batch size
repeat_cond_X_nobatch = cond_X_nobatch + 10
repeat_cond_model2 = repeat_cond_model.condition_on_observations(
repeat_cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
)
self.assertEqual(
repeat_cond_model2.train_inputs[0].shape,
torch.Size([num_models, num_train + 3 * num_cond, num_dims]),
)

def test_bisect(self):
def f(x):
return 1 + x