add abstract fully bayesian GP #2696

Closed. Wants to merge 1 commit.
4 changes: 2 additions & 2 deletions botorch/acquisition/joint_entropy_search.py
@@ -29,7 +29,7 @@
 from botorch import settings
 from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
 from botorch.acquisition.objective import PosteriorTransform
-from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
+from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP
 from botorch.models.model import Model
 from botorch.models.utils import check_no_nans, fantasize as fantasize_flag
 from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL
@@ -124,7 +124,7 @@ def __init__(
         # and the optimal outputs have shapes num_optima x [num_models if FB] x 1 x 1
         # The third dimension equaling 1 is required to get one optimum per model,
         # which raises a BotorchTensorDimensionWarning.
-        if isinstance(model, SaasFullyBayesianSingleTaskGP):
+        if isinstance(model, FullyBayesianSingleTaskGP):
             raise NotImplementedError(FULLY_BAYESIAN_ERROR_MSG)
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore")
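Switching the guard from `SaasFullyBayesianSingleTaskGP` to the new abstract base class means every fully Bayesian single-task model, SAAS included, keeps hitting the same `NotImplementedError`. A minimal sketch of why, assuming the class names from this diff (the training tensors are hypothetical):

```python
# Illustrative sketch only, not part of the PR: SAAS models are instances of
# the new abstract base, so the broadened isinstance check still rejects them.
import torch

from botorch.models.fully_bayesian import (
    FullyBayesianSingleTaskGP,
    SaasFullyBayesianSingleTaskGP,
)

train_X = torch.rand(10, 3, dtype=torch.float64)  # hypothetical data
train_Y = torch.randn(10, 1, dtype=torch.float64)
model = SaasFullyBayesianSingleTaskGP(train_X, train_Y)

# The subclass relationship that the guard in qJointEntropySearch relies on:
assert isinstance(model, FullyBayesianSingleTaskGP)
```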
121 changes: 72 additions & 49 deletions botorch/models/fully_bayesian.py
@@ -31,7 +31,7 @@
"""

import math
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any

@@ -311,14 +311,13 @@ def load_mcmc_samples(
         return mean_module, covar_module, likelihood
 
 
-class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel):
-    r"""A fully Bayesian single-task GP model with the SAAS prior.
+class FullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel, ABC):
+    r"""An abstract fully Bayesian single-task GP model.
 
     This model assumes that the inputs have been normalized to [0, 1]^d and that
     the output has been standardized to have zero mean and unit variance. You can
     either normalize and standardize the data before constructing the model or use
-    an `input_transform` and `outcome_transform`. The SAAS model [Eriksson2021saasbo]_
-    with a Matern-5/2 kernel is used by default.
+    an `input_transform` and `outcome_transform`.
 
     You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it
     isn't compatible with `fit_gpytorch_mll`.
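Adding `ABC` to the bases, together with the abstract `num_mcmc_samples` property in the next hunk, makes the base class uninstantiable on its own. A quick sketch of the expected behavior, assuming the names in this diff:

```python
# Sketch: Python refuses to instantiate an ABC with unimplemented abstract
# members, so only concrete subclasses (e.g. the SAAS model) can be built.
import torch

from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP

train_X = torch.rand(10, 3, dtype=torch.float64)  # hypothetical data
train_Y = torch.randn(10, 1, dtype=torch.float64)

try:
    FullyBayesianSingleTaskGP(train_X, train_Y)
except TypeError as err:
    print(err)  # "Can't instantiate abstract class ..."
```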
@@ -412,17 +411,9 @@ def _check_if_fitted(self):
             )
 
     @property
-    def median_lengthscale(self) -> Tensor:
-        r"""Median lengthscales across the MCMC samples."""
-        self._check_if_fitted()
-        lengthscale = self.covar_module.base_kernel.lengthscale.clone()
-        return lengthscale.median(0).values.squeeze(0)
-
-    @property
+    @abstractmethod
     def num_mcmc_samples(self) -> int:
         r"""Number of MCMC samples in the model."""
-        self._check_if_fitted()
-        return len(self.covar_module.outputscale)
 
     @property
     def batch_shape(self) -> torch.Size:
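With `num_mcmc_samples` now abstract, each concrete model has to say how many posterior samples it carries. A hypothetical subclass might look like the sketch below; the class name is made up, and a real implementation would also wire up its own Pyro model:

```python
# Hypothetical concrete subclass, for illustration only.
from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP


class MyFullyBayesianGP(FullyBayesianSingleTaskGP):
    """Toy subclass; not part of this PR."""

    @property
    def num_mcmc_samples(self) -> int:
        r"""Number of MCMC samples in the model."""
        self._check_if_fitted()
        # Assumes a kernel exposing `outputscale`, as in the SAAS subclass.
        return len(self.covar_module.outputscale)
```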
@@ -459,41 +450,6 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
             self.likelihood,
         ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
 
-    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
-        r"""Custom logic for loading the state dict.
-
-        The standard approach of calling `load_state_dict` currently doesn't play well
-        with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module`
-        and `likelihood` aren't initialized until the model has been fitted. The reason
-        for this is that we don't know the number of MCMC samples until NUTS is called.
-
-        Given the state dict, we can initialize a new model with some dummy samples and
-        then load the state dict into this model. This currently only works for a
-        `SaasPyroModel` and supporting more Pyro models likely requires moving the model
-        construction logic into the Pyro model itself.
-        """
-        if not isinstance(self.pyro_model, SaasPyroModel):
-            raise NotImplementedError("load_state_dict only works for SaasPyroModel")
-
-        raw_mean = state_dict["mean_module.raw_constant"]
-        num_mcmc_samples = len(raw_mean)
-        dim = self.pyro_model.train_X.shape[-1]
-        tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
-        # Load some dummy samples
-        mcmc_samples = {
-            "mean": torch.ones(num_mcmc_samples, **tkwargs),
-            "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
-            "outputscale": torch.ones(num_mcmc_samples, **tkwargs),
-        }
-        if self.pyro_model.train_Yvar is None:
-            mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
-        (
-            self.mean_module,
-            self.covar_module,
-            self.likelihood,
-        ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
-        # Load the actual samples from the state dict
-        super().load_state_dict(state_dict=state_dict, strict=strict)
-
     def forward(self, X: Tensor) -> MultivariateNormal:
         """
         Unlike in other classes' `forward` methods, there is no `if self.training`
@@ -579,3 +535,70 @@ def condition_on_observations(
             X = X.repeat(*(Y.shape[:-2] + (1, 1)))
 
         return super().condition_on_observations(X, Y, **kwargs)
+
+
+class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
+    r"""A fully Bayesian single-task GP model with the SAAS prior.
+
+    This model assumes that the inputs have been normalized to [0, 1]^d and that
+    the output has been standardized to have zero mean and unit variance. You can
+    either normalize and standardize the data before constructing the model or use
+    an `input_transform` and `outcome_transform`. The SAAS model [Eriksson2021saasbo]_
+    with a Matern-5/2 kernel is used by default.
+
+    You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it
+    isn't compatible with `fit_gpytorch_mll`.
+
+    Example:
+        >>> saas_gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
+        >>> fit_fully_bayesian_model_nuts(saas_gp)
+        >>> posterior = saas_gp.posterior(test_X)
+    """
+
+    @property
+    def num_mcmc_samples(self) -> int:
+        r"""Number of MCMC samples in the model."""
+        self._check_if_fitted()
+        return len(self.covar_module.outputscale)
+
+    @property
+    def median_lengthscale(self) -> Tensor:
+        r"""Median lengthscales across the MCMC samples."""
+        self._check_if_fitted()
+        lengthscale = self.covar_module.base_kernel.lengthscale.clone()
+        return lengthscale.median(0).values.squeeze(0)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        r"""Custom logic for loading the state dict.
+
+        The standard approach of calling `load_state_dict` currently doesn't play well
+        with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module`
+        and `likelihood` aren't initialized until the model has been fitted. The reason
+        for this is that we don't know the number of MCMC samples until NUTS is called.
+
+        Given the state dict, we can initialize a new model with some dummy samples and
+        then load the state dict into this model. This currently only works for a
+        `SaasPyroModel` and supporting more Pyro models likely requires moving the model
+        construction logic into the Pyro model itself.
+        """
+        if not isinstance(self.pyro_model, SaasPyroModel):
+            raise NotImplementedError("load_state_dict only works for SaasPyroModel")
+
+        raw_mean = state_dict["mean_module.raw_constant"]
+        num_mcmc_samples = len(raw_mean)
+        dim = self.pyro_model.train_X.shape[-1]
+        tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
+        # Load some dummy samples
+        mcmc_samples = {
+            "mean": torch.ones(num_mcmc_samples, **tkwargs),
+            "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
+            "outputscale": torch.ones(num_mcmc_samples, **tkwargs),
+        }
+        if self.pyro_model.train_Yvar is None:
+            mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
+        (
+            self.mean_module,
+            self.covar_module,
+            self.likelihood,
+        ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
+        # Load the actual samples from the state dict
+        super().load_state_dict(state_dict=state_dict, strict=strict)
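For context, a hedged end-to-end sketch of the SAAS workflow this refactor preserves, including the state-dict round trip enabled by the custom `load_state_dict` (data shapes and NUTS settings are illustrative, chosen small to run quickly):

```python
import torch

from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

# Illustrative data: inputs in [0, 1]^d, outputs roughly standardized.
train_X = torch.rand(20, 4, dtype=torch.float64)
train_Y = torch.randn(20, 1, dtype=torch.float64)

gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
# Deliberately tiny NUTS budget for the sketch; the defaults are much larger.
fit_fully_bayesian_model_nuts(
    gp, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
)

# Round trip: a fresh model has no mean/covar/likelihood modules until fitted,
# so load_state_dict first builds them from dummy MCMC samples and then loads
# the real values from the state dict.
gp2 = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
gp2.load_state_dict(gp.state_dict())
posterior = gp2.posterior(torch.rand(5, 4, dtype=torch.float64))
```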