Pull out "fantasize" function so it isn't so widely inherited -- for …
Browse files Browse the repository at this point in the history
…discussion! (#1462)

Summary:

## Motivation

The `Model` base class has a `fantasize` method, but it isn't actually possible to `fantasize` from all models. For some models, `fantasize` fails with a `NotImplementedError` even though the docstring implies it should work. This is confusing (e.g. #1459). Another disadvantage is that codecov doesn't require tests for the many `fantasize` overrides that exist, only for the base one -- which can't be called directly, since `Model` is abstract.

This PR removes the `fantasize` method from the `Model` base class. Instead, `fantasize` lives in a `FantasizeMixin`, which is inherited only by the few classes that actually use and/or test `fantasize`.
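For concreteness, a minimal sketch of the resulting pattern (the class names other than `Model` and `FantasizeMixin` are hypothetical):

```python
from botorch.models.model import FantasizeMixin, Model


class PlainModel(Model):
    # Hypothetical model that opts out: it simply has no `fantasize`
    # attribute, so calling it raises AttributeError instead of
    # advertising support and then raising NotImplementedError.
    ...


class FantasyCapableModel(Model, FantasizeMixin):
    # Hypothetical model that opts in: it must implement
    # `condition_on_observations`, `posterior`, and `transform_inputs`,
    # and in exchange inherits a working `fantasize` from the mixin.
    ...
```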

Pros:
- Decreases the "surface area" of BoTorch and prevents misuse of methods that we didn't really intend to exist
- Allows codecov to surface where we aren't actually testing methods (I expect it to fail)

Cons:
- `fantasize` still exists but doesn't work in `HeteroskedasticSingleTaskGP` because it inherits from `SingleTaskGP` (a further refactor can fix that; see the sketch after this list)
- Could remove `fantasize` from classes where it _should_ exist (despite not being used or tested within BoTorch)
- Somewhat more verbose
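To illustrate the first con, a minimal sketch of the post-change behavior (the training data here is made up):

```python
import torch
from botorch.models import HeteroskedasticSingleTaskGP

train_X = torch.rand(10, 2)
train_Y = torch.randn(10, 1)
train_Yvar = torch.full_like(train_Y, 0.01)  # observed noise levels

model = HeteroskedasticSingleTaskGP(train_X, train_Y, train_Yvar)
# `fantasize` exists on the class (it would otherwise be inherited from
# SingleTaskGP), but this PR overrides it to fail fast:
model.fantasize(None, None)  # raises NotImplementedError
```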

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #1462

Test Plan:
- [x] Codecov
- [x] Unit tests should pass
- [x] No new Pyre errors introduced

Reviewed By: saitcakmak

Differential Revision: D40783106

Pulled By: esantorella

fbshipit-source-id: 91cd7ee47efee1d89ae7f65af1ed94a4d88bdbe6
esantorella authored and facebook-github-bot committed Oct 31, 2022
1 parent b103f0a commit ef7d39e
Showing 13 changed files with 149 additions and 105 deletions.
6 changes: 0 additions & 6 deletions botorch/models/approximate_gp.py
@@ -37,7 +37,6 @@
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.models.utils import validate_input_scaling
 from botorch.posteriors.gpytorch import GPyTorchPosterior
-from botorch.sampling import MCSampler
 from gpytorch.constraints import GreaterThan
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.kernels import Kernel, MaternKernel, ScaleKernel
@@ -143,11 +142,6 @@ def forward(self, X, *args, **kwargs) -> MultivariateNormal:
         X = self.transform_inputs(X)
         return self.model(X)
 
-    def fantasize(self, X, sampler=MCSampler, observation_noise=True, *args, **kwargs):
-        raise NotImplementedError(
-            "Fantasization of approximate GPs has not been implemented yet."
-        )
-
 
 class _SingleTaskVariationalGP(ApproximateGP):
     """
14 changes: 2 additions & 12 deletions botorch/models/fully_bayesian.py
@@ -33,17 +33,16 @@
 
 import math
 from abc import abstractmethod
-from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Tuple
 
 import pyro
 import torch
 from botorch.acquisition.objective import PosteriorTransform
-from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
+from botorch.models.gp_regression import SingleTaskGP
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.models.utils import validate_input_scaling
 from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM
-from botorch.sampling.samplers import MCSampler
 from gpytorch.constraints import GreaterThan
 from gpytorch.distributions.multivariate_normal import MultivariateNormal
 from gpytorch.kernels import MaternKernel, ScaleKernel
@@ -418,15 +417,6 @@ def _aug_batch_shape(self) -> torch.Size:
         aug_batch_shape += torch.Size([self.num_outputs])
         return aug_batch_shape
 
-    def fantasize(
-        self,
-        X: Tensor,
-        sampler: MCSampler,
-        observation_noise: Union[bool, Tensor] = True,
-        **kwargs: Any,
-    ) -> FixedNoiseGP:
-        raise NotImplementedError("Fantasize is not implemented!")
-
     def train(self, mode: bool = True) -> None:
         r"""Puts the model in `train` mode."""
         super().train(mode=mode)
12 changes: 1 addition & 11 deletions botorch/models/fully_bayesian_multitask.py
@@ -8,7 +8,7 @@
 """
 
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
 
 import pyro
 import torch
@@ -24,7 +24,6 @@
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM
-from botorch.sampling.samplers import MCSampler
 from botorch.utils.datasets import SupervisedDataset
 from gpytorch.distributions.multivariate_normal import MultivariateNormal
 from gpytorch.kernels import MaternKernel
@@ -300,15 +299,6 @@ def batch_shape(self) -> torch.Size:
         self._check_if_fitted()
         return torch.Size([self.num_mcmc_samples])
 
-    def fantasize(
-        self,
-        X: Tensor,
-        sampler: MCSampler,
-        observation_noise: Union[bool, Tensor] = True,
-        **kwargs: Any,
-    ) -> FixedNoiseMultiTaskGP:
-        raise NotImplementedError("Fantasize is not implemented!")
-
     def _check_if_fitted(self):
         r"""Raise an exception if the model hasn't been fitted."""
         if self.covar_module is None:
24 changes: 17 additions & 7 deletions botorch/models/gp_regression.py
@@ -30,11 +30,12 @@
 
 from __future__ import annotations
 
-from typing import Any, List, Optional, Union
+from typing import Any, List, NoReturn, Optional, Union
 
 import torch
 from botorch import settings
 from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
+from botorch.models.model import FantasizeMixin
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import Log, OutcomeTransform
 from botorch.models.utils import fantasize as fantasize_flag, validate_input_scaling
@@ -63,7 +64,7 @@
 
 MIN_INFERRED_NOISE_LEVEL = 1e-4
 
 
-class SingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP):
+class SingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin):
     r"""A single-task exact GP model.
 
     A single-task exact GP using relatively strong priors on the Kernel
@@ -139,7 +140,9 @@ def __init__(
             )
         else:
             self._is_custom_likelihood = True
-        ExactGP.__init__(self, train_X, train_Y, likelihood)
+        ExactGP.__init__(
+            self, train_inputs=train_X, train_targets=train_Y, likelihood=likelihood
+        )
         if mean_module is None:
             mean_module = ConstantMean(batch_shape=self._aug_batch_shape)
         self.mean_module = mean_module
@@ -333,6 +336,8 @@ def fantasize(
         )
 
     def forward(self, x: Tensor) -> MultivariateNormal:
+        # TODO: reduce redundancy with the 'forward' method of
+        # SingleTaskGP, which is identical
         if self.training:
             x = self.transform_inputs(x)
         mean_x = self.mean_module(x)
@@ -432,10 +437,15 @@ def __init__(
         self.outcome_transform = outcome_transform
         self.to(train_X)
 
-    def condition_on_observations(
-        self, X: Tensor, Y: Tensor, **kwargs: Any
-    ) -> HeteroskedasticSingleTaskGP:
+    # TODO: HeteroskedasticSingleTaskGP should not be a subclass of
+    # SingleTaskGP because it can't function the way a SingleTaskGP does
+    # pyre-fixme[15]: Inconsistent override
+    def condition_on_observations(self, *_, **__) -> NoReturn:
         raise NotImplementedError
 
+    def fantasize(self, *_, **__) -> NoReturn:
+        raise NotImplementedError
+
-    def subset_output(self, idcs: List[int]) -> HeteroskedasticSingleTaskGP:
+    # pyre-fixme[15]: Inconsistent override
+    def subset_output(self, idcs) -> NoReturn:
         raise NotImplementedError
3 changes: 2 additions & 1 deletion botorch/models/higher_order_gp.py
@@ -21,6 +21,7 @@
 import torch
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
+from botorch.models.model import FantasizeMixin
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform, Standardize
 from botorch.models.utils import gpt_posterior_settings
@@ -138,7 +139,7 @@ def untransform_posterior(
         )
 
 
-class HigherOrderGP(BatchedMultiOutputGPyTorchModel, ExactGP):
+class HigherOrderGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin):
     r"""
     A model for high-dimensional output regression.
144 changes: 107 additions & 37 deletions botorch/models/model.py
@@ -16,7 +16,17 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from copy import deepcopy
-from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Hashable,
+    List,
+    Mapping,
+    Optional,
+    TypeVar,
+    Union,
+)
 
 import numpy as np
 import torch
@@ -30,6 +40,8 @@
 from torch import Tensor
 from torch.nn import Module, ModuleList
 
+TFantasizeMixin = TypeVar("TFantasizeMixin", bound="FantasizeMixin")
+
 
 class Model(Module, ABC):
     r"""Abstract base class for BoTorch models.
@@ -138,42 +150,6 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
             f"`condition_on_observations` not implemented for {self.__class__.__name__}"
         )
 
-    def fantasize(
-        self,
-        X: Tensor,
-        sampler: MCSampler,
-        observation_noise: bool = True,
-        **kwargs: Any,
-    ) -> Model:
-        r"""Construct a fantasy model.
-
-        Constructs a fantasy model in the following fashion:
-        (1) compute the model posterior at `X` (including observation noise if
-            `observation_noise=True`).
-        (2) sample from this posterior (using `sampler`) to generate "fake"
-            observations.
-        (3) condition the model on the new fake observations.
-
-        Args:
-            X: A `batch_shape x n' x d`-dim Tensor, where `d` is the dimension of
-                the feature space, `n'` is the number of points per batch, and
-                `batch_shape` is the batch shape (must be compatible with the
-                batch shape of the model).
-            sampler: The sampler used for sampling from the posterior at `X`.
-            observation_noise: If True, include observation noise.
-
-        Returns:
-            The constructed fantasy model.
-        """
-        propagate_grads = kwargs.pop("propagate_grads", False)
-        with fantasize_flag():
-            with settings.propagate_grads(propagate_grads):
-                post_X = self.posterior(X, observation_noise=observation_noise)
-                Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
-            return self.condition_on_observations(
-                X=self.transform_inputs(X), Y=Y_fantasized, **kwargs
-            )
-
     @classmethod
     def construct_inputs(
         cls,
@@ -252,6 +228,100 @@ def train(self, mode: bool = True) -> Model:
         return super().train(mode=mode)
 
 
+class FantasizeMixin(ABC):
+    """
+    Mixin to add a `fantasize` method to a `Model`.
+
+    Example:
+        class BaseModel:
+            def __init__(self, ...):
+            def condition_on_observations(self, ...):
+            def posterior(self, ...):
+            def transform_inputs(self, ...):
+
+        class ModelThatCanFantasize(BaseModel, FantasizeMixin):
+            def __init__(self, args):
+                super().__init__(args)
+
+        model = ModelThatCanFantasize(...)
+        model.fantasize(X)
+    """
+
+    @abstractmethod
+    def condition_on_observations(
+        self: TFantasizeMixin, X: Tensor, Y: Tensor, **kwargs: Any
+    ) -> TFantasizeMixin:
+        """
+        Classes that inherit from `FantasizeMixin` must implement
+        a `condition_on_observations` method.
+        """
+
+    @abstractmethod
+    def posterior(
+        self,
+        X: Tensor,
+        *args,
+        observation_noise: bool = False,
+        **kwargs: Any,
+    ) -> Posterior:
+        """
+        Classes that inherit from `FantasizeMixin` must implement
+        a `posterior` method.
+        """

+    @abstractmethod
+    def transform_inputs(
+        self,
+        X: Tensor,
+        input_transform: Optional[Module] = None,
+    ) -> Tensor:
+        """
+        Classes that inherit from `FantasizeMixin` must implement
+        a `transform_inputs` method.
+        """
+
+    # When Python 3.11 arrives we can start annotating return types like
+    # this as 'Self', but at this point the verbose 'T...' syntax is needed.
+    def fantasize(
+        self: TFantasizeMixin,
+        # TODO: see if any of these can be imported only if TYPE_CHECKING
+        X: Tensor,
+        sampler: MCSampler,
+        observation_noise: bool = True,
+        **kwargs: Any,
+    ) -> TFantasizeMixin:
+        r"""Construct a fantasy model.
+
+        Constructs a fantasy model in the following fashion:
+        (1) compute the model posterior at `X` (including observation noise if
+            `observation_noise=True`).
+        (2) sample from this posterior (using `sampler`) to generate "fake"
+            observations.
+        (3) condition the model on the new fake observations.
+
+        Args:
+            X: A `batch_shape x n' x d`-dim Tensor, where `d` is the dimension of
+                the feature space, `n'` is the number of points per batch, and
+                `batch_shape` is the batch shape (must be compatible with the
+                batch shape of the model).
+            sampler: The sampler used for sampling from the posterior at `X`.
+            observation_noise: If True, include observation noise.
+            kwargs: Will be passed to `model.condition_on_observations`
+
+        Returns:
+            The constructed fantasy model.
+        """
+        propagate_grads = kwargs.pop("propagate_grads", False)
+        with fantasize_flag():
+            with settings.propagate_grads(propagate_grads):
+                post_X = self.posterior(X, observation_noise=observation_noise)
+                Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
+            return self.condition_on_observations(
+                X=self.transform_inputs(X), Y=Y_fantasized, **kwargs
+            )
+
+
 class ModelList(Model):
     r"""A multi-output Model represented by a list of independent models.
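For reference, a usage sketch of the `fantasize` flow the mixin provides (the sampler choice and tensor shapes below are illustrative, not mandated by this diff):

```python
import torch
from botorch.models import SingleTaskGP
from botorch.sampling.samplers import SobolQMCNormalSampler

model = SingleTaskGP(torch.rand(8, 2), torch.randn(8, 1))
sampler = SobolQMCNormalSampler(num_samples=4)
X_new = torch.rand(3, 2)

# (1) compute the posterior at X_new (with observation noise),
# (2) draw 4 fantasy observations with the sampler,
# (3) condition on them -> a model with a fantasy batch dimension of 4.
fantasy_model = model.fantasize(X_new, sampler=sampler)
```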
3 changes: 2 additions & 1 deletion botorch/models/model_list_gp_regression.py
@@ -15,11 +15,12 @@
 
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.models.gpytorch import GPyTorchModel, ModelListGPyTorchModel
+from botorch.models.model import FantasizeMixin
 from gpytorch.models import IndependentModelList
 from torch import Tensor
 
 
-class ModelListGP(IndependentModelList, ModelListGPyTorchModel):
+class ModelListGP(IndependentModelList, ModelListGPyTorchModel, FantasizeMixin):
     r"""A multi-output GP model with independent GPs for the outputs.
 
     This model supports different-shaped training inputs for each of its
12 changes: 8 additions & 4 deletions botorch/models/pairwise_gp.py
@@ -32,7 +32,7 @@
     PairwiseLikelihood,
     PairwiseProbitLikelihood,
 )
-from botorch.models.model import Model
+from botorch.models.model import FantasizeMixin, Model
 from botorch.models.transforms.input import InputTransform
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.posteriors.posterior import Posterior
@@ -44,7 +44,6 @@
 from gpytorch.means.constant_mean import ConstantMean
 from gpytorch.mlls import MarginalLogLikelihood
 from gpytorch.models.gp import GP
-from gpytorch.module import Module
 from gpytorch.priors.smoothed_box_prior import SmoothedBoxPrior
 from gpytorch.priors.torch_priors import GammaPrior
 from linear_operator.operators import LinearOperator, RootLinearOperator
@@ -54,7 +53,12 @@
 from torch.nn.modules.module import _IncompatibleKeys
 
 
-class PairwiseGP(Model, GP):
+# Why we subclass GP even though it provides no functionality:
+# if this subclassing is removed, we get the following GPyTorch error:
+# "RuntimeError: All MarginalLogLikelihood objects must be given a GP object as
+# a model. If you are using a more complicated model involving a GP, pass the
+# underlying GP object as the model, not a full PyTorch module."
+class PairwiseGP(Model, GP, FantasizeMixin):
     r"""Probit GP for preference learning with Laplace approximation
 
     A probit-likelihood GP that learns via pairwise comparison data, using a
@@ -100,7 +104,7 @@ def __init__(
         datapoints: Tensor,
         comparisons: Tensor,
         likelihood: Optional[PairwiseLikelihood] = None,
-        covar_module: Optional[Module] = None,
+        covar_module: Optional[ScaleKernel] = None,
         input_transform: Optional[InputTransform] = None,
         **kwargs,
     ) -> None: