Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate model conversion code #2431

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions botorch/models/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@

from __future__ import annotations

import warnings
from copy import deepcopy
from typing import Dict, Optional, Set, Tuple

import torch
from botorch.exceptions import UnsupportedError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gp_regression import HeteroskedasticSingleTaskGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
Expand All @@ -24,7 +26,14 @@
from botorch.models.transforms.outcome import OutcomeTransform
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from torch import Tensor
from torch.nn import Module
from torch.nn import Module, ModuleList

DEPRECATION_MESSAGE = (
"Model converter code is deprecated and will be removed in v0.13 release. "
"Its correct behavior is dependent on some assumptions about model priors "
"that do not always hold. Use it at your own risk! See "
"https://github.com/cornellius-gp/gpytorch/issues/2550."
)


def _get_module(module: Module, name: str) -> Module:
Expand All @@ -49,15 +58,25 @@ def _get_module(module: Module, name: str) -> Module:
return current


def _check_compatibility(models: ModelListGP) -> None:
"""Check if a ModelListGP can be converted."""
def _check_compatibility(models: ModuleList) -> None:
"""Check if the submodels of a ModelListGP are compatible with the converter."""
# Check that all submodules are of the same type.
for modn, mod in models[0].named_modules():
mcls = mod.__class__
if not all(isinstance(_get_module(m, modn), mcls) for m in models[1:]):
raise UnsupportedError(
"Sub-modules must be of the same type across models."
)
if "prior" in modn and len(mod.state_dict()) == 0:
warnings.warn(
"Model converter cannot verify compatibility of GPyTorch priors "
"that do not register their parameters as buffers. If the prior "
"is different than the default prior set by the model constructor "
"this may not work correctly. Use it at your own risk! See "
"https://github.com/cornellius-gp/gpytorch/issues/2550.",
BotorchWarning,
stacklevel=3,
)

# Check that each model is a BatchedMultiOutputGPyTorchModel.
if not all(isinstance(m, BatchedMultiOutputGPyTorchModel) for m in models):
Expand Down Expand Up @@ -128,6 +147,7 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
>>> list_gp = ModelListGP(gp1, gp2)
>>> batch_gp = model_list_to_batched(list_gp)
"""
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
was_training = model_list.training
model_list.train()
models = model_list.models
Expand Down Expand Up @@ -260,6 +280,7 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
>>> batch_gp = SingleTaskGP(train_X, train_Y)
>>> list_gp = batched_to_model_list(batch_gp)
"""
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
was_training = batch_model.training
batch_model.train()
# TODO: Add support for HeteroskedasticSingleTaskGP.
Expand Down Expand Up @@ -363,6 +384,7 @@ def batched_multi_output_to_single_output(
>>> batch_mo_gp = SingleTaskGP(train_X, train_Y)
>>> batch_so_gp = batched_multioutput_to_single_output(batch_gp)
"""
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
was_training = batch_mo_model.training
batch_mo_model.train()
# TODO: Add support for HeteroskedasticSingleTaskGP.
Expand Down
35 changes: 32 additions & 3 deletions test/models/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch
from botorch.exceptions import UnsupportedError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models import (
HeteroskedasticSingleTaskGP,
ModelListGP,
Expand All @@ -16,6 +17,7 @@
from botorch.models.converter import (
batched_multi_output_to_single_output,
batched_to_model_list,
DEPRECATION_MESSAGE,
model_list_to_batched,
)
from botorch.models.transforms.input import AppendFeatures, Normalize
Expand All @@ -25,6 +27,7 @@
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.priors import LogNormalPrior


class TestConverters(BotorchTestCase):
Expand All @@ -41,7 +44,8 @@ def test_batched_to_model_list(self):
self.assertIsInstance(list_gp.models[0].likelihood, GaussianLikelihood)
# test observed noise
batch_gp = SingleTaskGP(train_X, train_Y, torch.rand_like(train_Y))
list_gp = batched_to_model_list(batch_gp)
with self.assertWarnsRegex(DeprecationWarning, DEPRECATION_MESSAGE):
list_gp = batched_to_model_list(batch_gp)
self.assertIsInstance(list_gp, ModelListGP)
self.assertIsInstance(
list_gp.models[0].likelihood, FixedNoiseGaussianLikelihood
Expand Down Expand Up @@ -108,7 +112,8 @@ def test_model_list_to_batched(self):
self.assertIsInstance(batch_gp, SingleTaskGP)
self.assertIsInstance(batch_gp.likelihood, GaussianLikelihood)
# test degenerate (single model)
batch_gp = model_list_to_batched(ModelListGP(gp1))
with self.assertWarnsRegex(DeprecationWarning, DEPRECATION_MESSAGE):
batch_gp = model_list_to_batched(ModelListGP(gp1))
self.assertEqual(batch_gp._num_outputs, 1)
# test mixing different likelihoods
gp2 = SingleTaskGP(train_X, train_Y1, torch.ones_like(train_Y1))
Expand Down Expand Up @@ -240,6 +245,27 @@ def test_model_list_to_batched(self):
with self.assertRaises(UnsupportedError):
model_list_to_batched(list_gp)

def test_model_list_to_batched_with_different_prior(self) -> None:
# The goal is to test priors that don't have their parameters
# recorded in the state dict.
train_X = torch.rand(10, 2, device=self.device, dtype=torch.double)
gp1 = SingleTaskGP(
train_X=train_X,
train_Y=train_X.sum(dim=-1, keepdim=True),
covar_module=RBFKernel(
ard_num_dims=2, lengthscale_prior=LogNormalPrior(3.0, 6.0)
),
)
gp2 = SingleTaskGP(
train_X=train_X,
train_Y=train_X.max(dim=-1, keepdim=True).values,
covar_module=RBFKernel(
ard_num_dims=2, lengthscale_prior=LogNormalPrior(2.0, 4.0)
),
)
with self.assertWarnsRegex(BotorchWarning, "Model converter cannot verify"):
model_list_to_batched(ModelListGP(gp1, gp2))

def test_roundtrip(self):
for dtype in (torch.float, torch.double):
train_X = torch.rand(10, 2, device=self.device, dtype=dtype)
Expand Down Expand Up @@ -288,7 +314,10 @@ def test_batched_multi_output_to_single_output(self):
dim=1,
)
batched_mo_model = SingleTaskGP(train_X, train_Y)
batched_so_model = batched_multi_output_to_single_output(batched_mo_model)
with self.assertWarnsRegex(DeprecationWarning, DEPRECATION_MESSAGE):
batched_so_model = batched_multi_output_to_single_output(
batched_mo_model
)
self.assertIsInstance(batched_so_model, SingleTaskGP)
self.assertEqual(batched_so_model.num_outputs, 1)
# test non-batched models
Expand Down
Loading