Fixed condition_on_observations in fully Bayesian models #2151

Closed · wants to merge 3 commits
41 changes: 37 additions & 4 deletions botorch/models/fully_bayesian.py
@@ -498,9 +498,8 @@ def forward(self, X: Tensor) -> MultivariateNormal:
        rest of this method will not run.
        """
        self._check_if_fitted()
-        x = X.unsqueeze(MCMC_DIM)
-        mean_x = self.mean_module(x)
-        covar_x = self.covar_module(x)
+        mean_x = self.mean_module(X)
+        covar_x = self.covar_module(X)
        return MultivariateNormal(mean_x, covar_x)

    # pyre-ignore[14]: Inconsistent override
@@ -534,11 +533,45 @@ def posterior(
        """
        self._check_if_fitted()
        posterior = super().posterior(
-            X=X,
+            X=X.unsqueeze(MCMC_DIM),
            output_indices=output_indices,
            observation_noise=observation_noise,
            posterior_transform=posterior_transform,
            **kwargs,
        )
        posterior = GaussianMixturePosterior(distribution=posterior.distribution)
        return posterior

    def condition_on_observations(
        self, X: Tensor, Y: Tensor, **kwargs: Any
    ) -> BatchedMultiOutputGPyTorchModel:
        """Conditions on additional observations for a fully Bayesian model
        (either identical across models or unique per model).

        Args:
            X: A `batch_shape x num_samples x d`-dim Tensor, where `d` is
                the dimension of the feature space and `batch_shape` is the
                number of sampled models.
            Y: A `batch_shape x num_samples x 1`-dim Tensor of observed outcomes,
                where `batch_shape` is the number of sampled models.

        Returns:
            BatchedMultiOutputGPyTorchModel: A fully Bayesian model conditioned on
                the given observations. The returned model has `batch_shape` copies
                of the training data in case of identical observations (and
                `batch_shape` training datasets otherwise).
        """
        if X.ndim == 2 and Y.ndim == 2:
            # To avoid an error in GPyTorch when inferring the batch dimension,
            # we add the explicit batch shape here. The result is that the
            # conditioned model will have `batch_shape` copies of the training data.
            X = X.repeat(self.batch_shape + (1, 1))
            Y = Y.repeat(self.batch_shape + (1, 1))

        elif X.ndim < Y.ndim:
            # We need to duplicate the training data to enable correct batch
            # size inference in GPyTorch.
            X = X.repeat(*(Y.shape[:-2] + (1, 1)))

        return super().condition_on_observations(X, Y, **kwargs)
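
A usage sketch of the new method (not part of the PR; the data, shapes, and NUTS settings below are illustrative assumptions) exercising the two branches above on a fitted SaasFullyBayesianSingleTaskGP:

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

train_X = torch.rand(10, 4, dtype=torch.double)
train_Y = torch.sin(train_X[:, :1])
model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=8, disable_progbar=True
)
num_models = model.batch_shape[0]  # number of retained MCMC samples
model.posterior(train_X)  # forward pass before conditioning, as in the new test

# Shared observations (n x d / n x 1): copied to every sampled model.
new_X = torch.rand(3, 4, dtype=torch.double)
new_Y = torch.sin(new_X[:, :1])
cond_model = model.condition_on_observations(new_X, new_Y)

# Per-model observations (num_models x n x d / num_models x n x 1).
new_X_batched = torch.rand(num_models, 3, 4, dtype=torch.double)
new_Y_batched = torch.sin(new_X_batched[..., :1])
cond_model = cond_model.condition_on_observations(new_X_batched, new_Y_batched)

# posterior() unsqueezes the test points at MCMC_DIM, so the posterior mean has
# shape num_models x n_test x 1 and is wrapped in a GaussianMixturePosterior.
test_X = torch.rand(5, 4, dtype=torch.double)
print(cond_model.posterior(test_X).mean.shape)

This mirrors the shape assertions in the new unit test further below.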
7 changes: 5 additions & 2 deletions botorch/models/gpytorch.py
@@ -226,7 +226,8 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
            >>> new_Y = torch.sin(new_X[:, 0]) + torch.cos(new_X[:, 1])
            >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
        """
-        Yvar = kwargs.get("noise", None)
+        Yvar = kwargs.pop("noise", None)
+
        if hasattr(self, "outcome_transform"):
            # pass the transformed data to get_fantasy_model below
            # (unless we've already transformed if BatchedMultiOutputGPyTorchModel)
@@ -242,6 +243,7 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
                kwargs.update({"noise": Yvar.squeeze(-1)})
        # get_fantasy_model will properly copy any existing outcome transforms
        # (since it deepcopies the original model)
+
        return self.get_fantasy_model(inputs=X, targets=Y, **kwargs)


@@ -492,7 +494,8 @@ def condition_on_observations(
        fantasy_model._input_batch_shape = fantasy_model.train_targets.shape[
            : (-1 if self._num_outputs == 1 else -2)
        ]
-        fantasy_model._aug_batch_shape = fantasy_model.train_targets.shape[:-1]
+        if not self._is_fully_bayesian:
+            fantasy_model._aug_batch_shape = fantasy_model.train_targets.shape[:-1]
        return fantasy_model

    def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
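
As a point of reference for the `noise` handling above: observed measurement noise for the new points is passed via the `noise` keyword, which is now popped from `kwargs`, squeezed to match the targets, and re-inserted before the call to `get_fantasy_model`. A minimal sketch of that code path (shapes and fit settings are assumptions, mirroring the fixed-noise case in the new test):

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

train_X = torch.rand(10, 4, dtype=torch.double)
train_Y = torch.sin(train_X[:, :1])
train_Yvar = 0.1 * torch.ones_like(train_Y)  # known observation noise
model = SaasFullyBayesianSingleTaskGP(
    train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar
)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=8, disable_progbar=True
)
model.posterior(train_X)  # forward pass before conditioning

# One set of new observations per sampled model, with matching noise.
num_models = model.batch_shape[0]
new_X = torch.rand(num_models, 3, 4, dtype=torch.double)
new_Y = torch.sin(new_X[..., :1])
new_Yvar = 0.1 * torch.ones_like(new_Y)
cond_model = model.condition_on_observations(new_X, new_Y, noise=new_Yvar)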
116 changes: 116 additions & 0 deletions test/models/test_fully_bayesian.py
@@ -124,6 +124,18 @@ def _get_unnormalized_data(self, infer_noise: bool, **tkwargs):
            )
        return train_X, train_Y, train_Yvar, test_X

    def _get_unnormalized_condition_data(
        self, num_models: int, num_cond: int, infer_noise: bool, **tkwargs
    ):
        with torch.random.fork_rng():
            torch.manual_seed(0)
            cond_X = 5 + 5 * torch.rand(num_models, num_cond, 4, **tkwargs)
            cond_Y = 10 + torch.sin(cond_X[..., :1])
            cond_Yvar = (
                None if infer_noise else 0.1 * torch.ones(cond_Y.shape, **tkwargs)
            )
        return cond_X, cond_Y, cond_Yvar

    def _get_mcmc_samples(
        self, num_samples: int, dim: int, infer_noise: bool, **tkwargs
    ):
@@ -656,6 +668,110 @@ def test_custom_pyro_model(self):
                atol=5e-4,
            )

    def test_condition_on_observation(self):

Reviewer comment (Contributor): Thanks for adding these tests. For completeness, can you add a test that calls model.fantasize and a test that evaluates some acquisition function (e.g., qLogEI) with the fantasized model? I just want to make sure everything works e2e including the sampler related code inside the acquisition functions.

        # The following conditioned data shapes should work (the output column
        # describes the training data shape after conditioning; a batch shape in
        # the output is required by gpytorch):
        # X: num_models x n x d, Y: num_models x n x d --> num_models x n x d
        # X: n x d, Y: n x d --> num_models x n x d
        # X: n x d, Y: num_models x n x d --> num_models x n x d
        num_models = 3
        num_cond = 2
        for infer_noise, dtype in itertools.product(
            (True, False), (torch.float, torch.double)
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data(
                infer_noise=infer_noise, **tkwargs
            )
            num_train, num_dims = train_X.shape
            # condition on different observations per model to obtain num_models
            # sets of training data
            cond_X, cond_Y, cond_Yvar = self._get_unnormalized_condition_data(
                num_models=num_models,
                num_cond=num_cond,
                infer_noise=infer_noise,
                **tkwargs
            )
            model = SaasFullyBayesianSingleTaskGP(
                train_X=train_X,
                train_Y=train_Y,
                train_Yvar=train_Yvar,
            )
            mcmc_samples = self._get_mcmc_samples(
                num_samples=num_models,
                dim=train_X.shape[-1],
                infer_noise=infer_noise,
                **tkwargs
            )
            model.load_mcmc_samples(mcmc_samples)

            # need a forward pass before conditioning
            model.posterior(train_X)
            cond_model = model.condition_on_observations(
                cond_X, cond_Y, noise=cond_Yvar
            )
            posterior = cond_model.posterior(test_X)
            self.assertEqual(
                posterior.mean.shape, torch.Size([num_models, len(test_X), 1])
            )

            # since the data is not equal for the conditioned points, a batch size
            # is added to the training data
            self.assertEqual(
                cond_model.train_inputs[0].shape,
                torch.Size([num_models, num_train + num_cond, num_dims]),
            )

            # the batch shape of the conditioned model is added during conditioning
            self.assertEqual(cond_model.batch_shape, torch.Size([num_models]))

            # condition on identical sets of data (i.e., one set) for all models,
            # i.e., with no batch shape. This infers the batch shape.
            cond_X_nobatch, cond_Y_nobatch = cond_X[0], cond_Y[0]
            model = SaasFullyBayesianSingleTaskGP(
                train_X=train_X,
                train_Y=train_Y,
                train_Yvar=train_Yvar,
            )
            mcmc_samples = self._get_mcmc_samples(
                num_samples=num_models,
                dim=train_X.shape[-1],
                infer_noise=infer_noise,
                **tkwargs
            )
            model.load_mcmc_samples(mcmc_samples)

            # conditioning without a batch size - the resulting conditioned model
            # will still have a batch size
            model.posterior(train_X)
            cond_model = model.condition_on_observations(
                cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
            )
            self.assertEqual(
                cond_model.train_inputs[0].shape,
                torch.Size([num_models, num_train + num_cond, num_dims]),
            )

            # test repeated conditioning
            repeat_cond_X = cond_X + 5
            repeat_cond_model = cond_model.condition_on_observations(
                repeat_cond_X, cond_Y, noise=cond_Yvar
            )
            self.assertEqual(
                repeat_cond_model.train_inputs[0].shape,
                torch.Size([num_models, num_train + 2 * num_cond, num_dims]),
            )

            # test repeated conditioning without a batch size
            repeat_cond_X_nobatch = cond_X_nobatch + 10
            repeat_cond_model2 = repeat_cond_model.condition_on_observations(
                repeat_cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
            )
            self.assertEqual(
                repeat_cond_model2.train_inputs[0].shape,
                torch.Size([num_models, num_train + 3 * num_cond, num_dims]),
            )

    def test_bisect(self):
        def f(x):
            return 1 + x
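
The end-to-end test requested in the review comment above (calling model.fantasize and evaluating an acquisition function with the fantasized model) is not part of this diff. A rough sketch of what it might look like follows; the qLogExpectedImprovement import path, SobolQMCNormalSampler usage, and all shapes and fit settings are assumptions, and whether this runs cleanly end to end is exactly what such a test would verify:

import torch
from botorch.acquisition.logei import qLogExpectedImprovement
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.sampling.normal import SobolQMCNormalSampler

train_X = torch.rand(10, 4, dtype=torch.double)
train_Y = torch.sin(train_X[:, :1])
model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=8, disable_progbar=True
)

# fantasize draws posterior samples at the new points and conditions on them,
# which exercises the condition_on_observations path fixed in this PR.
fantasy_X = torch.rand(2, 4, dtype=torch.double)
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([4]))
fantasy_model = model.fantasize(fantasy_X, sampler=sampler)

# Evaluate a sample-based acquisition function with the fantasized model.
acqf = qLogExpectedImprovement(model=fantasy_model, best_f=train_Y.max())
candidates = torch.rand(5, 1, 4, dtype=torch.double)  # batch x q x d
values = acqf(candidates)
print(values.shape)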