add abstract fully bayesian GP (pytorch#2696)
Summary:

See title: this makes the fully Bayesian single-task GP an abstract base class, which enables supporting different fully Bayesian models.

Differential Revision: D68529434
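
For illustration, a minimal sketch (not part of this commit) of what the abstract base class enables: a custom fully Bayesian variant implements the now-abstract `num_mcmc_samples` property and inherits the shared GP machinery. The subclass name is hypothetical, and a real variant would also supply its own Pyro model; only the abstract contract is shown.

```python
# Illustrative sketch only (not in this commit): a hypothetical fully Bayesian
# GP built on the new abstract base class. A real subclass would also provide
# its own Pyro model; this shows just the abstract `num_mcmc_samples` contract.
from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP


class MyFullyBayesianGP(FullyBayesianSingleTaskGP):
    """Hypothetical variant with a non-SAAS parameter prior."""

    @property
    def num_mcmc_samples(self) -> int:
        r"""Number of MCMC samples in the model."""
        self._check_if_fitted()
        # Mirrors the SAAS implementation in this diff; a variant with a
        # different parameterization would read the count from its own modules.
        return len(self.covar_module.outputscale)
```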
sdaulton authored and facebook-github-bot committed Jan 23, 2025
1 parent 0df5521 commit b07d239
Showing 2 changed files with 74 additions and 51 deletions.
botorch/acquisition/joint_entropy_search.py (4 changes: 2 additions & 2 deletions)
@@ -29,7 +29,7 @@
 from botorch import settings
 from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
 from botorch.acquisition.objective import PosteriorTransform
-from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
+from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP
 from botorch.models.model import Model
 from botorch.models.utils import check_no_nans, fantasize as fantasize_flag
 from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL
@@ -124,7 +124,7 @@ def __init__(
         # and the optimal outputs have shapes num_optima x [num_models if FB] x 1 x 1
         # The third dimension equaling 1 is required to get one optimum per model,
         # which raises a BotorchTensorDimensionWarning.
-        if isinstance(model, SaasFullyBayesianSingleTaskGP):
+        if isinstance(model, FullyBayesianSingleTaskGP):
            raise NotImplementedError(FULLY_BAYESIAN_ERROR_MSG)
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore")
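
A small hedged check of the acquisition-side effect: because the SAAS model is now a subclass of the abstract base, the single `isinstance` guard above covers it along with any future fully Bayesian variant.

```python
# Minimal check (assumes botorch at this revision): the broadened guard in
# qJointEntropySearch catches the SAAS model via the new abstract base class.
from botorch.models.fully_bayesian import (
    FullyBayesianSingleTaskGP,
    SaasFullyBayesianSingleTaskGP,
)

assert issubclass(SaasFullyBayesianSingleTaskGP, FullyBayesianSingleTaskGP)
```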
botorch/models/fully_bayesian.py (121 changes: 72 additions & 49 deletions)
@@ -31,7 +31,7 @@
 """

 import math
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from typing import Any

@@ -311,14 +311,13 @@ def load_mcmc_samples(
         return mean_module, covar_module, likelihood


-class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel):
-    r"""A fully Bayesian single-task GP model with the SAAS prior.
+class FullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel, ABC):
+    r"""An abstract fully Bayesian single-task GP model.

     This model assumes that the inputs have been normalized to [0, 1]^d and that
     the output has been standardized to have zero mean and unit variance. You can
     either normalize and standardize the data before constructing the model or use
-    an `input_transform` and `outcome_transform`. The SAAS model [Eriksson2021saasbo]_
-    with a Matern-5/2 kernel is used by default.
+    an `input_transform` and `outcome_transform`.

     You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it
     isn't compatible with `fit_gpytorch_mll`.
@@ -412,17 +411,9 @@ def _check_if_fitted(self):
             )

     @property
-    def median_lengthscale(self) -> Tensor:
-        r"""Median lengthscales across the MCMC samples."""
-        self._check_if_fitted()
-        lengthscale = self.covar_module.base_kernel.lengthscale.clone()
-        return lengthscale.median(0).values.squeeze(0)
-
-    @property
+    @abstractmethod
     def num_mcmc_samples(self) -> int:
         r"""Number of MCMC samples in the model."""
-        self._check_if_fitted()
-        return len(self.covar_module.outputscale)

     @property
     def batch_shape(self) -> torch.Size:
@@ -459,41 +450,6 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
             self.likelihood,
         ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)

-    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
-        r"""Custom logic for loading the state dict.
-
-        The standard approach of calling `load_state_dict` currently doesn't play well
-        with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module`
-        and `likelihood` aren't initialized until the model has been fitted. The reason
-        for this is that we don't know the number of MCMC samples until NUTS is called.
-        Given the state dict, we can initialize a new model with some dummy samples and
-        then load the state dict into this model. This currently only works for a
-        `SaasPyroModel` and supporting more Pyro models likely requires moving the model
-        construction logic into the Pyro model itself.
-        """
-
-        if not isinstance(self.pyro_model, SaasPyroModel):
-            raise NotImplementedError("load_state_dict only works for SaasPyroModel")
-        raw_mean = state_dict["mean_module.raw_constant"]
-        num_mcmc_samples = len(raw_mean)
-        dim = self.pyro_model.train_X.shape[-1]
-        tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
-        # Load some dummy samples
-        mcmc_samples = {
-            "mean": torch.ones(num_mcmc_samples, **tkwargs),
-            "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
-            "outputscale": torch.ones(num_mcmc_samples, **tkwargs),
-        }
-        if self.pyro_model.train_Yvar is None:
-            mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
-        (
-            self.mean_module,
-            self.covar_module,
-            self.likelihood,
-        ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
-        # Load the actual samples from the state dict
-        super().load_state_dict(state_dict=state_dict, strict=strict)
-
     def forward(self, X: Tensor) -> MultivariateNormal:
         """
         Unlike in other classes' `forward` methods, there is no `if self.training`
@@ -579,3 +535,70 @@ def condition_on_observations(
             X = X.repeat(*(Y.shape[:-2] + (1, 1)))

         return super().condition_on_observations(X, Y, **kwargs)
+
+
+class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
+    r"""A fully Bayesian single-task GP model with the SAAS prior.
+
+    This model assumes that the inputs have been normalized to [0, 1]^d and that
+    the output has been standardized to have zero mean and unit variance. You can
+    either normalize and standardize the data before constructing the model or use
+    an `input_transform` and `outcome_transform`. The SAAS model [Eriksson2021saasbo]_
+    with a Matern-5/2 kernel is used by default.
+
+    You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it
+    isn't compatible with `fit_gpytorch_mll`.
+
+    Example:
+        >>> saas_gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
+        >>> fit_fully_bayesian_model_nuts(saas_gp)
+        >>> posterior = saas_gp.posterior(test_X)
+    """
+
+    @property
+    def num_mcmc_samples(self) -> int:
+        r"""Number of MCMC samples in the model."""
+        self._check_if_fitted()
+        return len(self.covar_module.outputscale)
+
+    @property
+    def median_lengthscale(self) -> Tensor:
+        r"""Median lengthscales across the MCMC samples."""
+        self._check_if_fitted()
+        lengthscale = self.covar_module.base_kernel.lengthscale.clone()
+        return lengthscale.median(0).values.squeeze(0)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        r"""Custom logic for loading the state dict.
+
+        The standard approach of calling `load_state_dict` currently doesn't play well
+        with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module`
+        and `likelihood` aren't initialized until the model has been fitted. The reason
+        for this is that we don't know the number of MCMC samples until NUTS is called.
+        Given the state dict, we can initialize a new model with some dummy samples and
+        then load the state dict into this model. This currently only works for a
+        `SaasPyroModel` and supporting more Pyro models likely requires moving the model
+        construction logic into the Pyro model itself.
+        """
+
+        if not isinstance(self.pyro_model, SaasPyroModel):
+            raise NotImplementedError("load_state_dict only works for SaasPyroModel")
+        raw_mean = state_dict["mean_module.raw_constant"]
+        num_mcmc_samples = len(raw_mean)
+        dim = self.pyro_model.train_X.shape[-1]
+        tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
+        # Load some dummy samples
+        mcmc_samples = {
+            "mean": torch.ones(num_mcmc_samples, **tkwargs),
+            "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
+            "outputscale": torch.ones(num_mcmc_samples, **tkwargs),
+        }
+        if self.pyro_model.train_Yvar is None:
+            mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
+        (
+            self.mean_module,
+            self.covar_module,
+            self.likelihood,
+        ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
+        # Load the actual samples from the state dict
+        super().load_state_dict(state_dict=state_dict, strict=strict)
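
As a usage sketch of the relocated `load_state_dict` logic: fit the SAAS model with NUTS, then round-trip its MCMC-derived modules through a state dict. The data shapes and NUTS settings are illustrative assumptions, not prescribed by this diff.

```python
# Hedged usage sketch: data shapes and NUTS settings below are illustrative.
import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

train_X = torch.rand(20, 4, dtype=torch.float64)   # inputs assumed in [0, 1]^d
train_Y = torch.randn(20, 1, dtype=torch.float64)  # outputs assumed standardized

gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
fit_fully_bayesian_model_nuts(gp, warmup_steps=32, num_samples=16, thinning=4)

# Round-trip through a state dict: the fresh model is constructed on the same
# training data so its pyro_model has `train_X` set; dummy MCMC samples are
# materialized and then overwritten by the saved ones (see the method above).
state_dict = gp.state_dict()
gp_reloaded = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
gp_reloaded.load_state_dict(state_dict)
assert gp_reloaded.num_mcmc_samples == gp.num_mcmc_samples
```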
