Skip to content

Commit

Permalink
Update the default SingleTaskGP prior (#2449)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebook/Ax#2610

Pull Request resolved: #2449

See title

Differential Revision: D60080819
  • Loading branch information
Carl Hvarfner authored and facebook-github-bot committed Jul 29, 2024
1 parent 96a71e7 commit f4eeb4e
Show file tree
Hide file tree
Showing 14 changed files with 200 additions and 100 deletions.
26 changes: 16 additions & 10 deletions botorch/models/gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
from botorch.models.transforms.outcome import Log, OutcomeTransform
from botorch.models.utils import validate_input_scaling
from botorch.models.utils.gpytorch_modules import (
get_gaussian_likelihood_with_gamma_prior,
get_matern_kernel_with_gamma_prior,
get_covar_module_with_dim_scaled_prior,
get_gaussian_likelihood_with_lognormal_prior,
MIN_INFERRED_NOISE_LEVEL,
)
from botorch.utils.containers import BotorchContainer
Expand Down Expand Up @@ -174,7 +174,7 @@ def __init__(
)
if likelihood is None:
if train_Yvar is None:
likelihood = get_gaussian_likelihood_with_gamma_prior(
likelihood = get_gaussian_likelihood_with_lognormal_prior(
batch_shape=self._aug_batch_shape
)
else:
Expand All @@ -190,15 +190,21 @@ def __init__(
mean_module = ConstantMean(batch_shape=self._aug_batch_shape)
self.mean_module = mean_module
if covar_module is None:
covar_module = get_matern_kernel_with_gamma_prior(
ard_num_dims=transformed_X.shape[-1],
covar_module = get_covar_module_with_dim_scaled_prior(
dim=transformed_X.shape[-1],
batch_shape=self._aug_batch_shape,
)
self._subset_batch_dict = {
"mean_module.raw_constant": -1,
"covar_module.raw_outputscale": -1,
"covar_module.base_kernel.raw_lengthscale": -3,
}
if hasattr(covar_module, "outputscale"):
self._subset_batch_dict = {
"mean_module.raw_constant": -1,
"covar_module.raw_outputscale": -1,
"covar_module.base_kernel.raw_lengthscale": -3,
}
else:
self._subset_batch_dict = {
"mean_module.raw_constant": -1,
"covar_module.raw_lengthscale": -3,
}
if train_Yvar is None:
self._subset_batch_dict["likelihood.noise_covar.raw_noise"] = -2
self.covar_module: Module = covar_module
Expand Down
7 changes: 3 additions & 4 deletions botorch/utils/gp_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,9 @@ def __init__(
"""
if not isinstance(kernel, ScaleKernel):
base_kernel = kernel
outputscale = torch.tensor(
1.0,
dtype=base_kernel.lengthscale.dtype,
device=base_kernel.lengthscale.device,
outputscale = torch.ones(kernel.batch_shape).to(
dtype=kernel.lengthscale.dtype,
device=kernel.lengthscale.device,
)
else:
base_kernel = kernel.base_kernel
Expand Down
23 changes: 14 additions & 9 deletions test/models/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from botorch.models.transforms.outcome import Standardize
from botorch.utils.test_helpers import SimpleGPyTorchModel
from botorch.utils.testing import BotorchTestCase
from gpytorch.kernels import RBFKernel
from gpytorch.kernels import MaternKernel, RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.priors import LogNormalPrior
Expand Down Expand Up @@ -134,13 +134,17 @@ def test_model_list_to_batched(self):
model_list_to_batched(ModelListGP(gp1, gp2))
# check scalar agreement
gp2 = SingleTaskGP(train_X, train_Y2)
gp2.likelihood.noise_covar.noise_prior.rate.fill_(1.0)

# modified to check the scalar agreement in a parameter that is accessible
# since the error is going to slip through for the non-parametrizable
# priors regardless (like the LogNormal)
gp2.likelihood.noise_covar.raw_noise_constraint.lower_bound.fill_(1e-3)
with self.assertRaises(UnsupportedError):
model_list_to_batched(ModelListGP(gp1, gp2))
# check tensor shape agreement
gp2 = SingleTaskGP(train_X, train_Y2)
gp2.covar_module.raw_outputscale = torch.nn.Parameter(
torch.tensor([0.0], device=self.device, dtype=dtype)
gp2.likelihood.noise_covar.raw_noise = torch.nn.Parameter(
torch.tensor([[0.42]], device=self.device, dtype=dtype)
)
with self.assertRaises(UnsupportedError):
model_list_to_batched(ModelListGP(gp1, gp2))
Expand All @@ -155,14 +159,15 @@ def test_model_list_to_batched(self):
with self.assertRaises(NotImplementedError):
model_list_to_batched(ModelListGP(gp2))
# test non-default kernel
gp1 = SingleTaskGP(train_X, train_Y1, covar_module=RBFKernel())
gp2 = SingleTaskGP(train_X, train_Y2, covar_module=RBFKernel())
gp1 = SingleTaskGP(train_X, train_Y1, covar_module=MaternKernel())
gp2 = SingleTaskGP(train_X, train_Y2, covar_module=MaternKernel())
list_gp = ModelListGP(gp1, gp2)
batch_gp = model_list_to_batched(list_gp)
self.assertEqual(type(batch_gp.covar_module), RBFKernel)
self.assertEqual(type(batch_gp.covar_module), MaternKernel)
# test error when component GPs have different kernel types
gp1 = SingleTaskGP(train_X, train_Y1, covar_module=RBFKernel())
gp2 = SingleTaskGP(train_X, train_Y2)
# added types for both default and non-default kernels for clarity
gp1 = SingleTaskGP(train_X, train_Y1, covar_module=MaternKernel())
gp2 = SingleTaskGP(train_X, train_Y2, covar_module=RBFKernel())
list_gp = ModelListGP(gp1, gp2)
with self.assertRaises(UnsupportedError):
model_list_to_batched(list_gp)
Expand Down
2 changes: 1 addition & 1 deletion test/models/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_FixedSingleSampleModel(self):
post = model.posterior(test_X)
original_output = post.mean + post.variance.sqrt() * w
fss_output = fss_model(test_X)
self.assertTrue(torch.equal(original_output, fss_output))
self.assertAllClose(original_output, fss_output)

self.assertTrue(hasattr(fss_model, "num_outputs"))

Expand Down
12 changes: 6 additions & 6 deletions test/models/test_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from botorch.utils.sampling import manual_seed
from botorch.utils.test_helpers import get_pvar_expected
from botorch.utils.testing import _get_random_data, BotorchTestCase
from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import (
_GaussianLikelihoodBase,
FixedNoiseGaussianLikelihood,
Expand All @@ -33,7 +33,7 @@
from gpytorch.means import ConstantMean, ZeroMean
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.noise_model_added_loss_term import NoiseModelAddedLossTerm
from gpytorch.priors import GammaPrior
from gpytorch.priors import LogNormalPrior


class TestGPRegressionBase(BotorchTestCase):
Expand Down Expand Up @@ -96,10 +96,10 @@ def test_gp(self, double_only: bool = False):

# test init
self.assertIsInstance(model.mean_module, ConstantMean)
self.assertIsInstance(model.covar_module, ScaleKernel)
matern_kernel = model.covar_module.base_kernel
self.assertIsInstance(matern_kernel, MaternKernel)
self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
self.assertIsInstance(model.covar_module, RBFKernel)
rbf_kernel = model.covar_module
self.assertIsInstance(rbf_kernel, RBFKernel)
self.assertIsInstance(rbf_kernel.lengthscale_prior, LogNormalPrior)
if use_octf:
self.assertIsInstance(model.outcome_transform, Standardize)
if use_intf:
Expand Down
10 changes: 4 additions & 6 deletions test/models/test_model_list_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from botorch.sampling.normal import IIDNormalSampler
from botorch.utils.testing import _get_random_data, BotorchTestCase
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import LikelihoodList
from gpytorch.likelihoods.gaussian_likelihood import (
FixedNoiseGaussianLikelihood,
Expand All @@ -34,7 +34,7 @@
from gpytorch.means import ConstantMean
from gpytorch.mlls import SumMarginalLogLikelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.priors import GammaPrior
from gpytorch.priors import LogNormalPrior
from torch import Tensor


Expand Down Expand Up @@ -104,10 +104,8 @@ def _base_test_ModelListGP(
self.assertEqual(model.num_outputs, 2)
for m in model.models:
self.assertIsInstance(m.mean_module, ConstantMean)
self.assertIsInstance(m.covar_module, ScaleKernel)
matern_kernel = m.covar_module.base_kernel
self.assertIsInstance(matern_kernel, MaternKernel)
self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
self.assertIsInstance(m.covar_module, RBFKernel)
self.assertIsInstance(m.covar_module.lengthscale_prior, LogNormalPrior)
if outcome_transform != "None":
self.assertIsInstance(
m.outcome_transform, (Log, Standardize, ChainedOutcomeTransform)
Expand Down
6 changes: 3 additions & 3 deletions test/optim/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def setUp(self) -> None:
self.mlls = {}
with torch.random.fork_rng():
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
train_Y = torch.sin((2 * math.pi) * train_X)
train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
train_X = torch.linspace(0, 1, 30).unsqueeze(-1)
train_Y = torch.sin((6 * math.pi) * train_X)
train_Y = train_Y + 0.01 * torch.randn_like(train_Y)

model = SingleTaskGP(
train_X=train_X,
Expand Down
7 changes: 6 additions & 1 deletion test/optim/utils/test_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch
from botorch import settings
from botorch.models import SingleTaskGP
from botorch.models.utils.gpytorch_modules import get_matern_kernel_with_gamma_prior
from botorch.optim.utils import (
get_data_loader,
get_name_filter,
Expand Down Expand Up @@ -161,7 +162,11 @@ def test_sample_all_priors(self):
for dtype in (torch.float, torch.double):
train_X = torch.rand(3, 5, device=self.device, dtype=dtype)
train_Y = torch.rand(3, 1, device=self.device, dtype=dtype)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
model = SingleTaskGP(
train_X=train_X,
train_Y=train_Y,
covar_module=get_matern_kernel_with_gamma_prior(train_X.shape[-1]),
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
mll.to(device=self.device, dtype=dtype)
original_state_dict = dict(deepcopy(mll.model.state_dict()))
Expand Down
5 changes: 2 additions & 3 deletions test/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from botorch.settings import debug
from botorch.utils.context_managers import module_rollback_ctx, TensorCheckpoint
from botorch.utils.testing import BotorchTestCase
from gpytorch.kernels import MaternKernel
from gpytorch.kernels import RBFKernel
from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO
from linear_operator.utils.errors import NotPSDError

Expand Down Expand Up @@ -136,8 +136,7 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
input_transform=Normalize(d=1),
outcome_transform=Standardize(m=output_dim),
)
self.assertIsInstance(model.covar_module.base_kernel, MaternKernel)
model.covar_module.base_kernel.nu = 2.5
self.assertIsInstance(model.covar_module, RBFKernel)

mll = ExactMarginalLogLikelihood(model.likelihood, model)
for dtype in (torch.float32, torch.float64):
Expand Down
6 changes: 6 additions & 0 deletions test_community/acquisition/test_multi_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

import torch
from botorch.exceptions import UnsupportedError
from botorch.models.utils.gpytorch_modules import (
get_gaussian_likelihood_with_gamma_prior,
get_matern_kernel_with_gamma_prior,
)
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior

from botorch_community.acquisition.augmented_multisource import (
Expand All @@ -24,6 +28,8 @@ def _get_mock_agp(self, batch_shape, dtype):
model_kwargs = {
"train_X": train_X,
"train_Y": train_Y,
"covar_module": get_matern_kernel_with_gamma_prior(train_X.shape[-1]),
"mean_module": get_gaussian_likelihood_with_gamma_prior(),
}
model = SingleTaskAugmentedGP(**model_kwargs)
return model
Expand Down
6 changes: 6 additions & 0 deletions test_community/models/test_gp_regression_multisource.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from botorch.exceptions import InputDataError, OptimizationWarning
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.models.utils.gpytorch_modules import (
get_gaussian_likelihood_with_gamma_prior,
get_matern_kernel_with_gamma_prior,
)
from botorch.posteriors import GPyTorchPosterior
from botorch.sampling import SobolQMCNormalSampler
from botorch.utils.test_helpers import get_pvar_expected
Expand Down Expand Up @@ -65,6 +69,8 @@ def _get_model_and_data(
"train_Yvar": torch.full_like(train_Y, 0.01) if train_Yvar else None,
"outcome_transform": outcome_transform,
"input_transform": input_transform,
"covar_module": get_matern_kernel_with_gamma_prior(d),
"likelihood": get_gaussian_likelihood_with_gamma_prior(d),
}
model = SingleTaskAugmentedGP(**model_kwargs, **extra_model_kwargs)
return model, model_kwargs
Expand Down
25 changes: 17 additions & 8 deletions tutorials/constraint_active_search.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@
" return radius * r * z\n",
"\n",
" def _get_base_point_mask(self, X):\n",
" distance_matrix = self.model.models[0].covar_module.base_kernel.covar_dist(\n",
" distance_matrix = self.model.models[0].covar_module.covar_dist(\n",
" X, self.base_points\n",
" )\n",
" return smooth_mask(distance_matrix, self.punchout_radius)\n",
Expand Down Expand Up @@ -676,9 +676,18 @@
"\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 6))\n",
"h1 = ax.contourf(Xplt.cpu().numpy(), Yplt.cpu().numpy(), Zplt.cpu().numpy(), 20, cmap=\"Blues\", alpha=0.6)\n",
"h1 = ax.contourf(\n",
" Xplt.cpu().numpy(),\n",
" Yplt.cpu().numpy(),\n",
" Zplt.cpu().numpy(),\n",
" 20,\n",
" cmap=\"Blues\",\n",
" alpha=0.6,\n",
")\n",
"fig.colorbar(h1)\n",
"ax.contour(Xplt.cpu().numpy(), Yplt.cpu().numpy(), Zplt.cpu().numpy(), [0.55, 0.75], colors=\"k\")\n",
"ax.contour(\n",
" Xplt.cpu().numpy(), Yplt.cpu().numpy(), Zplt.cpu().numpy(), [0.55, 0.75], colors=\"k\"\n",
")\n",
"\n",
"feasible_inds = (\n",
" identify_samples_which_satisfy_constraints(Y, constraints)\n",
Expand Down Expand Up @@ -715,10 +724,12 @@
],
"metadata": {
"fileHeader": "",
"fileUid": "cb282d47-f143-4c00-9ae1-b631e97daddb",
"isAdHoc": false,
"kernelspec": {
"display_name": "python3",
"display_name": "Python 3",
"language": "python",
"name": "python3"
"name": "bento_kernel_default"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -732,7 +743,5 @@
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
}
16 changes: 9 additions & 7 deletions tutorials/fit_model_with_torch_optimizer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"\n",
"# use a GPU if available\n",
Expand Down Expand Up @@ -190,7 +191,7 @@
" if (epoch + 1) % 10 == 0:\n",
" print(\n",
" f\"Epoch {epoch+1:>3}/{NUM_EPOCHS} - Loss: {loss.item():>4.3f} \"\n",
" f\"lengthscale: {model.covar_module.base_kernel.lengthscale.item():>4.3f} \"\n",
" f\"lengthscale: {model.covar_module.lengthscale.item():>4.3f} \"\n",
" f\"noise: {model.likelihood.noise.item():>4.3f}\"\n",
" )\n",
" optimizer.step()"
Expand All @@ -215,7 +216,7 @@
"outputs": [],
"source": [
"# set model (and likelihood)\n",
"model.eval();"
"model.eval()"
]
},
{
Expand Down Expand Up @@ -309,12 +310,13 @@
}
],
"metadata": {
"fileHeader": "",
"fileUid": "29414a8f-010b-41a8-837d-70d9294b809e",
"isAdHoc": false,
"kernelspec": {
"display_name": "python3",
"display_name": "Python 3",
"language": "python",
"name": "python3"
"name": "bento_kernel_default"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
}
Loading

0 comments on commit f4eeb4e

Please sign in to comment.