Skip to content

Commit

Permalink
Update the remaining models to use new default covar & likelihood mod…
Browse files Browse the repository at this point in the history
…ules

Summary:
X-link: pytorch/botorch#2507

Updates the default covar & likelihood modules of BoTorch models. See pytorch/botorch#2451 for details on the new defaults.

For models that utilize a composite kernel, such as multi-fidelity/task/context, this change only affects the base kernel.

Exceptions / Models that do not utilize the new modules:
- Fully-bayesian models.
- Pairwise GP.
- Fidelity kernels for MF models.
- (likelihood only) Any model that utilizes a likelihood other than `GaussianLikelihood` (e.g., `MultiTaskGaussianLikelihood`).

Reviewed By: esantorella

Differential Revision: D62196414
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Sep 5, 2024
1 parent db48a77 commit 5b8f61c
Showing 1 changed file with 14 additions and 19 deletions.
33 changes: 14 additions & 19 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@
from botorch.models.multitask import MultiTaskGP
from botorch.utils.types import DEFAULT
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.kernels.rbf_kernel import RBFKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.priors.torch_priors import GammaPrior
from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior


class ModelRegistryTest(TestCase):
Expand Down Expand Up @@ -452,12 +453,7 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
]
)

lengthscale_priors = [
GammaPrior(6.0, 3.0),
GammaPrior(3.0, 6.0),
]

for surrogate, lengthscale_prior in zip(surrogates, lengthscale_priors):
for surrogate, default_model in zip(surrogates, (False, True)):
constructor = Models.SAAS_MTGP if use_saas else Models.ST_MTGP
mtgp = constructor(
experiment=exp,
Expand All @@ -468,26 +464,25 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
self.assertIsInstance(mtgp, TorchModelBridge)
self.assertIsInstance(mtgp.model, BoTorchModel)
self.assertEqual(mtgp.model.acquisition_class, Acquisition)

self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP)
models = mtgp.model.surrogate.model.models

for i in range(len(models)):
for model in mtgp.model.surrogate.model.models:
self.assertIsInstance(
models[i],
model,
SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP,
)
if use_saas is False:
self.assertIsInstance(models[i].covar_module, ScaleKernel)
base_kernel = models[i].covar_module.base_kernel
if use_saas is False and default_model is False:
self.assertIsInstance(model.covar_module, ScaleKernel)
base_kernel = model.covar_module.base_kernel
self.assertIsInstance(base_kernel, MaternKernel)
self.assertEqual(
base_kernel.lengthscale_prior.concentration,
lengthscale_prior.concentration,
base_kernel.lengthscale_prior.concentration, 6.0
)
self.assertEqual(
base_kernel.lengthscale_prior.rate,
lengthscale_prior.rate,
self.assertEqual(base_kernel.lengthscale_prior.rate, 3.0)
elif use_saas is False:
self.assertIsInstance(model.covar_module, RBFKernel)
self.assertIsInstance(
model.covar_module.lengthscale_prior, LogNormalPrior
)

gr = mtgp.gen(
Expand Down

0 comments on commit 5b8f61c

Please sign in to comment.