From b07772d396999c5540819c1b430eae4f948f94df Mon Sep 17 00:00:00 2001
From: Sam Daulton
Date: Fri, 24 Jan 2025 14:18:47 -0800
Subject: [PATCH] add abstract fully Bayesian GP (#2696)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/2696

See title. This enables supporting different fully Bayesian models.

Reviewed By: saitcakmak

Differential Revision: D68529434

fbshipit-source-id: b739112924e3eeb061794c001c6c99c45bfb07fe
---
 botorch/acquisition/joint_entropy_search.py |   4 +-
 botorch/models/fully_bayesian.py            | 121 ++++++++++++--------
 2 files changed, 74 insertions(+), 51 deletions(-)

diff --git a/botorch/acquisition/joint_entropy_search.py b/botorch/acquisition/joint_entropy_search.py
index 5db9d32019..a1336a6db1 100644
--- a/botorch/acquisition/joint_entropy_search.py
+++ b/botorch/acquisition/joint_entropy_search.py
@@ -29,7 +29,7 @@
 from botorch import settings
 from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
 from botorch.acquisition.objective import PosteriorTransform
-from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
+from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP
 from botorch.models.model import Model
 from botorch.models.utils import check_no_nans, fantasize as fantasize_flag
 from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL
@@ -124,7 +124,7 @@ def __init__(
         # and the optimal outputs have shapes num_optima x [num_models if FB] x 1 x 1
         # The third dimension equaling 1 is required to get one optimum per model,
         # which raises a BotorchTensorDimensionWarning.
-        if isinstance(model, SaasFullyBayesianSingleTaskGP):
+        if isinstance(model, FullyBayesianSingleTaskGP):
             raise NotImplementedError(FULLY_BAYESIAN_ERROR_MSG)
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore")
diff --git a/botorch/models/fully_bayesian.py b/botorch/models/fully_bayesian.py
index dbe4740d05..cd06432609 100644
--- a/botorch/models/fully_bayesian.py
+++ b/botorch/models/fully_bayesian.py
@@ -31,7 +31,7 @@
 """

 import math
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from typing import Any

@@ -311,14 +311,13 @@ def load_mcmc_samples(
         return mean_module, covar_module, likelihood


-class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel):
-    r"""A fully Bayesian single-task GP model with the SAAS prior.
+class FullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel, ABC):
+    r"""An abstract fully Bayesian single-task GP model.

     This model assumes that the inputs have been normalized to [0, 1]^d and that
     the output has been standardized to have zero mean and unit variance. You can
     either normalize and standardize the data before constructing the model or use
-    an `input_transform` and `outcome_transform`. The SAAS model [Eriksson2021saasbo]_
-    with a Matern-5/2 kernel is used by default.
+    an `input_transform` and `outcome_transform`.

     You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it
     isn't compatible with `fit_gpytorch_mll`.
@@ -412,17 +411,9 @@ def _check_if_fitted(self):
             )

     @property
-    def median_lengthscale(self) -> Tensor:
-        r"""Median lengthscales across the MCMC samples."""
-        self._check_if_fitted()
-        lengthscale = self.covar_module.base_kernel.lengthscale.clone()
-        return lengthscale.median(0).values.squeeze(0)
-
-    @property
+    @abstractmethod
     def num_mcmc_samples(self) -> int:
         r"""Number of MCMC samples in the model."""
-        self._check_if_fitted()
-        return len(self.covar_module.outputscale)

     @property
     def batch_shape(self) -> torch.Size:
@@ -459,41 +450,6 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
             self.likelihood,
         ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)

-    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
-        r"""Custom logic for loading the state dict.
-
-        The standard approach of calling `load_state_dict` currently doesn't play well
-        with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module`
-        and `likelihood` aren't initialized until the model has been fitted. The reason
-        for this is that we don't know the number of MCMC samples until NUTS is called.
-        Given the state dict, we can initialize a new model with some dummy samples and
-        then load the state dict into this model. This currently only works for a
-        `SaasPyroModel` and supporting more Pyro models likely requires moving the model
-        construction logic into the Pyro model itself.
-        """
-
-        if not isinstance(self.pyro_model, SaasPyroModel):
-            raise NotImplementedError("load_state_dict only works for SaasPyroModel")
-        raw_mean = state_dict["mean_module.raw_constant"]
-        num_mcmc_samples = len(raw_mean)
-        dim = self.pyro_model.train_X.shape[-1]
-        tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
-        # Load some dummy samples
-        mcmc_samples = {
-            "mean": torch.ones(num_mcmc_samples, **tkwargs),
-            "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
-            "outputscale": torch.ones(num_mcmc_samples, **tkwargs),
-        }
-        if self.pyro_model.train_Yvar is None:
-            mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
-        (
-            self.mean_module,
-            self.covar_module,
-            self.likelihood,
-        ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
-        # Load the actual samples from the state dict
-        super().load_state_dict(state_dict=state_dict, strict=strict)
-
     def forward(self, X: Tensor) -> MultivariateNormal:
         """
         Unlike in other classes' `forward` methods, there is no `if self.training`
@@ -579,3 +535,70 @@ def condition_on_observations(
             X = X.repeat(*(Y.shape[:-2] + (1, 1)))

         return super().condition_on_observations(X, Y, **kwargs)
+
+
+class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
+    r"""A fully Bayesian single-task GP model with the SAAS prior.
+
+    This model assumes that the inputs have been normalized to [0, 1]^d and that
+    the output has been standardized to have zero mean and unit variance. You can
+    either normalize and standardize the data before constructing the model or use
+    an `input_transform` and `outcome_transform`. The SAAS model [Eriksson2021saasbo]_
+    with a Matern-5/2 kernel is used by default.
+
+    You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it
+    isn't compatible with `fit_gpytorch_mll`.
+
+    Example:
+        >>> saas_gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
+        >>> fit_fully_bayesian_model_nuts(saas_gp)
+        >>> posterior = saas_gp.posterior(test_X)
+    """
+
+    @property
+    def num_mcmc_samples(self) -> int:
+        r"""Number of MCMC samples in the model."""
+        self._check_if_fitted()
+        return len(self.covar_module.outputscale)
+
+    @property
+    def median_lengthscale(self) -> Tensor:
+        r"""Median lengthscales across the MCMC samples."""
+        self._check_if_fitted()
+        lengthscale = self.covar_module.base_kernel.lengthscale.clone()
+        return lengthscale.median(0).values.squeeze(0)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        r"""Custom logic for loading the state dict.
+
+        The standard approach of calling `load_state_dict` currently doesn't play well
+        with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module`
+        and `likelihood` aren't initialized until the model has been fitted. The reason
+        for this is that we don't know the number of MCMC samples until NUTS is called.
+        Given the state dict, we can initialize a new model with some dummy samples and
+        then load the state dict into this model. This currently only works for a
+        `SaasPyroModel` and supporting more Pyro models likely requires moving the model
+        construction logic into the Pyro model itself.
+        """
+
+        if not isinstance(self.pyro_model, SaasPyroModel):
+            raise NotImplementedError("load_state_dict only works for SaasPyroModel")
+        raw_mean = state_dict["mean_module.raw_constant"]
+        num_mcmc_samples = len(raw_mean)
+        dim = self.pyro_model.train_X.shape[-1]
+        tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
+        # Load some dummy samples
+        mcmc_samples = {
+            "mean": torch.ones(num_mcmc_samples, **tkwargs),
+            "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
+            "outputscale": torch.ones(num_mcmc_samples, **tkwargs),
+        }
+        if self.pyro_model.train_Yvar is None:
+            mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
+        (
+            self.mean_module,
+            self.covar_module,
+            self.likelihood,
+        ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
+        # Load the actual samples from the state dict
+        super().load_state_dict(state_dict=state_dict, strict=strict)
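
Usage sketch (not part of the patch). This is a minimal illustration of the class hierarchy after the change: the concrete SAAS model now subclasses the abstract `FullyBayesianSingleTaskGP`, so downstream code (such as the `qJointEntropySearch` guard above) can dispatch on the base class. The tensors `train_X`/`train_Y` and the reduced NUTS settings are hypothetical placeholders; the entry points used (`SaasFullyBayesianSingleTaskGP`, `fit_fully_bayesian_model_nuts`) are the ones named in the docstrings above.

    import torch

    from botorch.fit import fit_fully_bayesian_model_nuts
    from botorch.models.fully_bayesian import (
        FullyBayesianSingleTaskGP,
        SaasFullyBayesianSingleTaskGP,
    )

    # Hypothetical data: inputs normalized to [0, 1]^d, outcomes standardized.
    train_X = torch.rand(10, 3, dtype=torch.float64)
    train_Y = torch.randn(10, 1, dtype=torch.float64)

    gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
    # NUTS-based fitting; this model is not compatible with `fit_gpytorch_mll`.
    # Small warmup/sample counts keep the sketch fast; the defaults are larger.
    fit_fully_bayesian_model_nuts(gp, warmup_steps=32, num_samples=32, thinning=16)

    # The SAAS model is an instance of the new abstract base class, so generic
    # code can check against FullyBayesianSingleTaskGP instead of the subclass.
    assert isinstance(gp, FullyBayesianSingleTaskGP)
    print(gp.num_mcmc_samples)    # number of retained MCMC samples
    print(gp.median_lengthscale)  # SAAS-specific diagnostic, shape (d,)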