Skip to content

Commit

Permalink
Follow-ups from D40783106 (fantasize refactor) (#1479)
Browse files Browse the repository at this point in the history
Summary:
see T137147547

Remove inappropriate subclassing of `SingleTaskGP` (and clean up associated to-dos)
[x] Make `HeteroskedasticSingleTaskGP` not a subclass of `SingleTaskGP` since it can't `fantasize`
[x] Make `SaasFullyBayesianSingleTaskGP` not a subclass of `SingleTaskGP` since it can't `fantasize`
[x] Fix downstream problems in multiple dispatched this caused....

Add `fantasize` back to classes that weren't using it, but can call it and used to have it
[x] Make `MultiTaskGP` have a `fantasize` method (debatable! it used to have one, but wasn't used)
[x] Restore `fantasize` method with `NotImplementedError` to `SaasFullyBayesianMultiTaskGP` (as a result of adding `fantasize` to `MultiTaskGP`)
[x] Make `KroneckerMultiTaskGP` have a `fantasize` method (debatable! it used to have one, but wasn't used)

Pull Request resolved: #1479

Differential Revision: D41053234

Pulled By: esantorella

fbshipit-source-id: 6f422916b7ae20fe8235b280e4b917055cfab5d1
  • Loading branch information
esantorella authored and facebook-github-bot committed Nov 7, 2022
1 parent 7563a0b commit dd54803
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 20 deletions.
19 changes: 15 additions & 4 deletions botorch/models/fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import pyro
import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils import validate_input_scaling
Expand All @@ -54,6 +54,7 @@
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.means.constant_mean import ConstantMean
from gpytorch.means.mean import Mean
from gpytorch.models.exact_gp import ExactGP
from torch import Tensor

MIN_INFERRED_NOISE_LEVEL = 1e-6
Expand Down Expand Up @@ -294,7 +295,7 @@ def load_mcmc_samples(
return mean_module, covar_module, likelihood


class SaasFullyBayesianSingleTaskGP(SingleTaskGP):
class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel):
r"""A fully Bayesian single-task GP model with the SAAS prior.
This model assumes that the inputs have been normalized to [0, 1]^d and that
Expand Down Expand Up @@ -364,7 +365,7 @@ def __init__(
train_Yvar = train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL)

X_tf, Y_tf, _ = self._transform_tensor_args(X=train_X, Y=train_Y)
super(SingleTaskGP, self).__init__(
super().__init__(
train_inputs=X_tf, train_targets=Y_tf, likelihood=GaussianLikelihood()
)
self.mean_module = None
Expand Down Expand Up @@ -473,9 +474,19 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
super().load_state_dict(state_dict=state_dict, strict=strict)

def forward(self, X: Tensor) -> MultivariateNormal:
"""
Unlike in other classes' `forward` methods, there is no `if self.training`
block, because it ought to be unreachable: If `self.train()` has been called,
then `self.covar_module` will be None, `check_if_fitted()` will fail, and the
rest of this method will not run.
"""
self._check_if_fitted()
return super().forward(X.unsqueeze(MCMC_DIM))
x = X.unsqueeze(MCMC_DIM)
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return MultivariateNormal(mean_x, covar_x)

# pyre-ignore[14]: Inconsistent override
def posterior(
self,
X: Tensor,
Expand Down
9 changes: 8 additions & 1 deletion botorch/models/fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""


from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, NoReturn, Optional, Tuple

import pyro
import torch
Expand Down Expand Up @@ -299,6 +299,9 @@ def batch_shape(self) -> torch.Size:
self._check_if_fitted()
return torch.Size([self.num_mcmc_samples])

def fantasize(self, *args, **kwargs) -> NoReturn:
raise NotImplementedError("Fantasize is not implemented!")

def _check_if_fitted(self):
r"""Raise an exception if the model hasn't been fitted."""
if self.covar_module is None:
Expand All @@ -321,6 +324,8 @@ def load_mcmc_samples(self, mcmc_samples: Dict[str, Tensor]) -> None:
self.latent_features,
) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)

# pyre-fixme[14]: Inconsistent override of
# BatchedMultiOutputGPyTorchModel.posterior
def posterior(
self,
X: Tensor,
Expand All @@ -345,6 +350,7 @@ def posterior(
posterior = FullyBayesianPosterior(mvn=posterior.mvn)
return posterior

# pyre-fixme[14]: Inconsistent override
def forward(self, X: Tensor) -> MultivariateNormal:
self._check_if_fitted()
X = X.unsqueeze(MCMC_DIM)
Expand Down Expand Up @@ -373,6 +379,7 @@ def forward(self, X: Tensor) -> MultivariateNormal:
return MultivariateNormal(mean_x, covar)

@classmethod
# pyre-fixme[14]: Inconsistent override of `MultiTaskGP.construct_inputs`
def construct_inputs(
cls,
training_data: Dict[str, SupervisedDataset],
Expand Down
22 changes: 15 additions & 7 deletions botorch/models/gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
return new_model


class HeteroskedasticSingleTaskGP(SingleTaskGP):
class HeteroskedasticSingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP):
r"""A single-task exact GP model using a heteroskedastic noise model.
This model differs from `SingleTaskGP` in that noise levels are provided
Expand Down Expand Up @@ -423,7 +423,12 @@ def __init__(
input_transform=input_transform,
)
likelihood = _GaussianLikelihoodBase(HeteroskedasticNoise(noise_model))
super().__init__(
# This is hacky -- this class used to inherit from SingleTaskGP, but it
# shouldn't so this is a quick fix to enable getting rid of that
# inheritance
SingleTaskGP.__init__(
# pyre-fixme[6]: Incompatible parameter type
self,
train_X=train_X,
train_Y=train_Y,
likelihood=likelihood,
Expand All @@ -437,15 +442,18 @@ def __init__(
self.outcome_transform = outcome_transform
self.to(train_X)

# TODO: HeteroskedasticSingleTaskGP should not be a subclass of
# SingleTaskGP because it can't function the way a SingleTaskGP does
# pyre-fixme[15]: Inconsistent override
def condition_on_observations(self, *_, **__) -> NoReturn:
raise NotImplementedError

def fantasize(self, *_, **__) -> NoReturn:
raise NotImplementedError

# pyre-fixme[15]: Inconsistent override
def subset_output(self, idcs) -> NoReturn:
raise NotImplementedError

# pyre-fixme[14]: Inconsistent override
def forward(self, x: Tensor) -> MultivariateNormal:
if self.training:
x = self.transform_inputs(x)
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return MultivariateNormal(mean_x, covar_x)
5 changes: 3 additions & 2 deletions botorch/models/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
from botorch.models.gpytorch import GPyTorchModel, MultiTaskGPyTorchModel
from botorch.models.model import FantasizeMixin
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.posteriors.multitask import MultitaskGPPosterior
Expand Down Expand Up @@ -71,7 +72,7 @@
from torch import Tensor


class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel):
class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel, FantasizeMixin):
r"""Multi-Task GP model using an ICM kernel, inferring observation noise.
Multi-task exact GP that uses a simple ICM kernel. Can be single-output or
Expand Down Expand Up @@ -377,7 +378,7 @@ def __init__(
self.to(train_X)


class KroneckerMultiTaskGP(ExactGP, GPyTorchModel):
class KroneckerMultiTaskGP(ExactGP, GPyTorchModel, FantasizeMixin):
"""Multi-task GP with Kronecker structure, using an ICM kernel.
This model assumes the "block design" case, i.e., it requires that all tasks
Expand Down
10 changes: 10 additions & 0 deletions test/models/test_fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ def test_raises(self):
task_feature=4,
)
train_X, train_Y, train_Yvar, model = self._get_data_and_model(**tkwargs)
sampler = IIDNormalSampler(num_samples=2)
with self.assertRaisesRegex(
NotImplementedError, "Fantasize is not implemented!"
):
model.fantasize(
X=torch.cat(
[torch.rand(5, 4, **tkwargs), torch.ones(5, 1, **tkwargs)], dim=1
),
sampler=sampler,
)

# Make sure an exception is raised if the model has not been fitted
not_fitted_error_msg = (
Expand Down
19 changes: 13 additions & 6 deletions test/models/test_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,23 @@ def _get_model_and_data(
model = HeteroskedasticSingleTaskGP(**model_kwargs)
return model, model_kwargs

def test_custom_init(self):
pass
def test_custom_init(self) -> None:
"""
This test exists because `TestHeteroskedasticSingleTaskGP` inherits from
`TestSingleTaskGP`, which has a `test_custom_init` method that isn't relevant
for `TestHeteroskedasticSingleTaskGP`.
"""

def test_gp(self):
super().test_gp(double_only=True)

def test_fantasize(self) -> None:
"""
This test exists because `TestHeteroskedasticSingleTaskGP` inherits from
`TestSingleTaskGP`, which has a `fantasize` method that isn't relevant
for `TestHeteroskedasticSingleTaskGP`.
"""

def test_heteroskedastic_likelihood(self):
for batch_shape, m, dtype in itertools.product(
(torch.Size(), torch.Size([2])), (1, 2), (torch.float, torch.double)
Expand All @@ -480,10 +491,6 @@ def test_condition_on_observations(self):
with self.assertRaises(NotImplementedError):
super().test_condition_on_observations()

def test_fantasize(self):
with self.assertRaises(NotImplementedError):
super().test_fantasize()

def test_subset_model(self):
with self.assertRaises(NotImplementedError):
super().test_subset_model()
Expand Down

0 comments on commit dd54803

Please sign in to comment.