BC-breaking: Change how ModelListGP.posterior evaluates sub-models (#1854)

Summary:
Pull Request resolved: #1854

Prior to this change, `ModelListGP.posterior` would call its sub-models directly and work with the returned MVNs, skipping any processing in the individual `posterior` methods. This was particularly consequential for `MultiTaskGP` (and its subclasses), where skipping `posterior` changed the expected shape of the `X` required by `posterior`.

With this change, `ModelListGP.posterior` evaluates the `posterior` method of each of its sub-models and combines the resulting MVNs into a single MVN where applicable. For `MultiTaskGP`, the `posterior` can now be evaluated with the same inputs regardless of whether the model is wrapped in a `ModelListGP`.
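As an illustration, here is a minimal sketch with toy data, mirroring the unit test added below (shapes and values are illustrative only, not part of this commit):

import torch
from botorch.models import ModelListGP
from botorch.models.multitask import MultiTaskGP

# Toy data: 10 points in 1d, with the task index appended as the last column.
train_x_raw = torch.rand(10, 1)
train_y = torch.rand(10, 1)
task_idx = torch.cat([torch.ones(5, 1), torch.zeros(5, 1)], dim=0)
train_x = torch.cat([train_x_raw, task_idx], dim=-1)

model = MultiTaskGP(train_X=train_x, train_Y=train_y, task_feature=-1, output_tasks=[0])
model_list = ModelListGP(model)

# After this change, both posteriors accept the same task-feature-free X.
with torch.no_grad():
    mean = model.posterior(train_x_raw).mean
    list_mean = model_list.posterior(train_x_raw).mean
assert torch.allclose(mean, list_mean)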

Differential Revision: D46364187

fbshipit-source-id: 8e0ee449070607224fa6cd8510ed973a0c0da0c7
saitcakmak authored and facebook-github-bot committed Jun 1, 2023
1 parent d8ab55b commit be5019b
Showing 2 changed files with 42 additions and 67 deletions.
87 changes: 20 additions & 67 deletions botorch/models/gpytorch.py
@@ -17,7 +17,7 @@
 import warnings
 from abc import ABC
 from copy import deepcopy
-from typing import Any, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
 from botorch.acquisition.objective import PosteriorTransform
@@ -610,74 +610,27 @@ def posterior(
             hasattr(mod, "outcome_transform") and (not mod.outcome_transform._is_linear)
             for mod in self.models
         )
-        if returns_untransformed:
-            return ModelList.posterior(
-                self,
-                X,
-                output_indices,
-                observation_noise,
-                posterior_transform,
-                **kwargs,
-            )
-        self.eval()  # make sure model is in eval mode
-        # input transforms are applied at `posterior` in `eval` mode, and at
-        # `model.forward()` at the training time
-        transformed_X = self.transform_inputs(X)
-        mvn_gen: Iterator
-        with gpt_posterior_settings():
-            # only compute what's necessary
-            if output_indices is not None:
-                mvns = [self.models[i](transformed_X[i]) for i in output_indices]
-                if observation_noise is not False:
-                    if isinstance(observation_noise, Tensor):
-                        lh_kwargs = [
-                            {"noise": observation_noise[..., i]}
-                            for i, lh in enumerate(self.likelihood.likelihoods)
-                        ]
-                    else:
-                        lh_kwargs = [
-                            {"noise": lh.noise.mean().expand(t_X.shape[:-1])}
-                            if isinstance(lh, FixedNoiseGaussianLikelihood)
-                            else {}
-                            for t_X, lh in zip(
-                                transformed_X, self.likelihood.likelihoods
-                            )
-                        ]
-                    mvns = [
-                        self.likelihood_i(i, mvn, transformed_X[i], **lkws)
-                        for i, mvn, lkws in zip(output_indices, mvns, lh_kwargs)
-                    ]
-                mvn_gen = zip(output_indices, mvns)
-            else:
-                mvns = self(*transformed_X)
-                if observation_noise is not False:
-                    mvnX = [(mvn, transformed_X[i]) for i, mvn in enumerate(mvns)]
-                    if torch.is_tensor(observation_noise):
-                        mvns = self.likelihood(*mvnX, noise=observation_noise)
-                    else:
-                        mvns = self.likelihood(*mvnX)
-                mvn_gen = enumerate(mvns)
-        # apply output transforms of individual models if present
-        mvns = []
-        for i, mvn in mvn_gen:
-            if hasattr(self.models[i], "outcome_transform"):
-                oct = self.models[i].outcome_transform
-                tf_mvn = oct.untransform_posterior(GPyTorchPosterior(mvn)).distribution
-            else:
-                tf_mvn = mvn
-            mvns.append(tf_mvn)
-        # return result as a GPyTorchPosteriors/FullyBayesianPosterior
-        mvn = (
-            mvns[0]
-            if len(mvns) == 1
-            else MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
-        )
-        if any(is_fully_bayesian(m) for m in self.models):
-            # mixing fully Bayesian and other GP models is currently not supported
-            posterior = FullyBayesianPosterior(distribution=mvn)
-        else:
-            posterior = GPyTorchPosterior(distribution=mvn)
+        # NOTE: We're not passing in the posterior transform here. We'll apply it later.
+        posterior = ModelList.posterior(
+            self,
+            X=X,
+            output_indices=output_indices,
+            observation_noise=observation_noise,
+            **kwargs,
+        )
+        if not returns_untransformed:
+            # Return the result as a GPyTorchPosterior/FullyBayesianPosterior.
+            mvns = [p.distribution for p in posterior.posteriors]
+            mvn = (
+                mvns[0]
+                if len(mvns) == 1
+                else MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
+            )
+            if any(is_fully_bayesian(m) for m in self.models):
+                # Mixing fully Bayesian and other GP models is currently not supported.
+                posterior = FullyBayesianPosterior(distribution=mvn)
+            else:
+                posterior = GPyTorchPosterior(distribution=mvn)
         if posterior_transform is not None:
             return posterior_transform(posterior)
         return posterior
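For context on the combination step above: when there are multiple sub-model posteriors, the new code merges their MVNs via gpytorch's MultitaskMultivariateNormal.from_independent_mvns. A minimal standalone sketch with toy distributions (not part of this diff):

import torch
from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal

# Two independent single-output posteriors over the same 3 test points (toy values).
mvn_a = MultivariateNormal(torch.zeros(3), torch.eye(3))
mvn_b = MultivariateNormal(torch.ones(3), 2 * torch.eye(3))

# Merged into one multitask MVN, as the len(mvns) > 1 branch above does.
joint = MultitaskMultivariateNormal.from_independent_mvns(mvns=[mvn_a, mvn_b])
print(joint.mean.shape)  # torch.Size([3, 2]): 3 points x 2 outputs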
22 changes: 22 additions & 0 deletions test/models/test_model_list_gp_regression.py
@@ -14,6 +14,7 @@
 from botorch.fit import fit_gpytorch_mll
 from botorch.models import ModelListGP
 from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
+from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms.input import Normalize
 from botorch.models.transforms.outcome import ChainedOutcomeTransform, Log, Standardize
 from botorch.posteriors import GPyTorchPosterior, PosteriorList, TransformedPosterior
@@ -281,6 +282,27 @@ def test_ModelListGP_single(self):
         self.assertIsInstance(posterior, GPyTorchPosterior)
         self.assertIsInstance(posterior.distribution, MultivariateNormal)
 
+    def test_ModelListGP_multi_task(self):
+        tkwargs = {"device": self.device, "dtype": torch.float}
+        train_x_raw, train_y = _get_random_data(
+            batch_shape=torch.Size(), m=1, n=10, **tkwargs
+        )
+        task_idx = torch.cat(
+            [torch.ones(5, 1, **tkwargs), torch.zeros(5, 1, **tkwargs)], dim=0
+        )
+        train_x = torch.cat([train_x_raw, task_idx], dim=-1)
+        model = MultiTaskGP(
+            train_X=train_x,
+            train_Y=train_y,
+            task_feature=-1,
+            output_tasks=[0],
+        )
+        model_list_gp = ModelListGP(model)
+        with torch.no_grad():
+            model_mean = model.posterior(train_x_raw).mean
+            model_list_gp_mean = model_list_gp.posterior(train_x_raw).mean
+        self.assertAllClose(model_mean, model_list_gp_mean)
+
     def test_transform_revert_train_inputs(self):
         tkwargs = {"device": self.device, "dtype": torch.float}
         model_list = _get_model(use_intf=True, **tkwargs)
