From a40048158049ba0c00bc95a8c52e1dbebe505b53 Mon Sep 17 00:00:00 2001
From: Sait Cakmak
Date: Mon, 17 Oct 2022 16:46:02 -0700
Subject: [PATCH] Fix input transform bug when sequentially training a
 `BatchedMultiOutputModel` (#1454)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/1454

This fixes a bug where the input transform was left in `train` mode while the
model itself was in `eval` mode. As a result, the learnable bounds of input
transforms such as `Normalize` were incorrectly updated during model training,
effectively disabling the transforms. This is a serious bug: the model was
trained on normalized inputs but then used without any normalization. It has
been around for a while, since https://github.com/pytorch/botorch/pull/1283
:'(.

Differential Revision: D40453200

fbshipit-source-id: 0d60ac5f23efe06e200713555dfa109b612d9290
---
 botorch/models/converter.py | 19 +++++++++----------
 test/test_fit.py            | 24 ++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/botorch/models/converter.py b/botorch/models/converter.py
index 309adb00e0..04e292bf3d 100644
--- a/botorch/models/converter.py
+++ b/botorch/models/converter.py
@@ -127,6 +127,8 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
         >>> list_gp = ModelListGP(gp1, gp2)
         >>> batch_gp = model_list_to_batched(list_gp)
     """
+    was_training = model_list.training
+    model_list.train()
     models = model_list.models
     _check_compatibility(models)
 
@@ -159,9 +161,6 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
 
     # construct the batched GP model
     input_transform = getattr(models[0], "input_transform", None)
-    if input_transform is not None:
-        input_transform.train()
-
     batch_gp = models[0].__class__(input_transform=input_transform, **kwargs)
     adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
         batch_state_dict=batch_gp.state_dict(), input_transform=input_transform
@@ -201,7 +200,7 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
     # load the state dict into the new model
     batch_gp.load_state_dict(batch_state_dict)
 
-    return batch_gp
+    return batch_gp.train(mode=was_training)
 
 
 def _batched_kernel(kernel, batch_length: int):
@@ -260,6 +259,8 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
         >>> batch_gp = SingleTaskGP(train_X, train_Y)
         >>> list_gp = batched_to_model_list(batch_gp)
     """
+    was_training = batch_model.training
+    batch_model.train()
     # TODO: Add support for HeteroskedasticSingleTaskGP.
     if isinstance(batch_model, HeteroskedasticSingleTaskGP):
         raise NotImplementedError(
@@ -270,8 +271,6 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
             "Conversion of MixedSingleTaskGP is currently not supported."
         )
     input_transform = getattr(batch_model, "input_transform", None)
-    if input_transform is not None:
-        input_transform.train()
     outcome_transform = getattr(batch_model, "outcome_transform", None)
     batch_sd = batch_model.state_dict()
 
@@ -326,7 +325,7 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
         model.load_state_dict(sd)
         models.append(model)
 
-    return ModelListGP(*models)
+    return ModelListGP(*models).train(mode=was_training)
 
 
 def batched_multi_output_to_single_output(
@@ -363,6 +362,8 @@ def batched_multi_output_to_single_output(
         >>> batch_mo_gp = SingleTaskGP(train_X, train_Y)
         >>> batch_so_gp = batched_multioutput_to_single_output(batch_gp)
     """
+    was_training = batch_mo_model.training
+    batch_mo_model.train()
     # TODO: Add support for HeteroskedasticSingleTaskGP.
     if isinstance(batch_mo_model, HeteroskedasticSingleTaskGP):
         raise NotImplementedError(
@@ -376,8 +377,6 @@ def batched_multi_output_to_single_output(
             "Conversion of models with custom likelihoods is currently unsupported."
         )
     input_transform = getattr(batch_mo_model, "input_transform", None)
-    if input_transform is not None:
-        input_transform.train()
     batch_sd = batch_mo_model.state_dict()
 
     # TODO: add support for outcome transforms.
@@ -400,7 +399,7 @@ def batched_multi_output_to_single_output(
         input_transform=input_transform, **kwargs
     )
     single_outcome_model.load_state_dict(batch_sd)
-    return single_outcome_model
+    return single_outcome_model.train(mode=was_training)
 
 
 def _get_adjusted_batch_keys(
diff --git a/test/test_fit.py b/test/test_fit.py
index 937c99ca75..a673244b86 100644
--- a/test/test_fit.py
+++ b/test/test_fit.py
@@ -9,6 +9,7 @@
 from copy import deepcopy
 from itertools import product
 from typing import Iterable, Optional
+from unittest import mock
 from unittest.mock import MagicMock, patch
 from warnings import catch_warnings, warn, WarningMessage
 
@@ -16,7 +17,9 @@
 from botorch import fit
 from botorch.exceptions.errors import ModelFittingError, UnsupportedError
 from botorch.exceptions.warnings import OptimizationWarning
+from botorch.fit import fit_gpytorch_mll
 from botorch.models import FixedNoiseGP, HeteroskedasticSingleTaskGP, SingleTaskGP
+from botorch.models.converter import batched_to_model_list
 from botorch.models.transforms.input import Normalize
 from botorch.models.transforms.outcome import Standardize
 from botorch.optim.utils import (
@@ -500,3 +503,24 @@ def mock_fit_gpytorch_mll(*args, **kwargs):
 
             self.assertEqual(converter.call_count, 1)
             self.assertTrue(any(str(exception) in str(w.message) for w in ws))
+
+
+class TestFitOther(BotorchTestCase):
+    def test_fit_with_converter(self):
+        # Check that sequential optimization using converter does not
+        # break input transforms.
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            X = torch.rand(5, 2, **tkwargs) * 10
+            Y = X**2
+            intf = Normalize(2)
+            model = SingleTaskGP(X, Y, input_transform=intf)
+            mll = ExactMarginalLogLikelihood(model.likelihood, model)
+            with mock.patch(
+                f"{fit_gpytorch_mll.__module__}.batched_to_model_list",
+                wraps=batched_to_model_list,
+            ) as wrapped_converter:
+                fit_gpytorch_mll(mll)
+            wrapped_converter.assert_called_once()
+            self.assertFalse(torch.allclose(intf.mins, torch.zeros(1, 2, **tkwargs)))
+            self.assertFalse(torch.allclose(intf.ranges, torch.ones(1, 2, **tkwargs)))
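
A minimal repro sketch of the failure mode fixed above, distilled from the new `test_fit_with_converter` test (single dtype and device; the particular training data is illustrative). With the converter fix, the `Normalize` transform keeps the bounds it learned from the training data instead of collapsing to an identity transform:

    import torch
    from botorch.fit import fit_gpytorch_mll
    from botorch.models import SingleTaskGP
    from botorch.models.transforms.input import Normalize
    from gpytorch.mlls import ExactMarginalLogLikelihood

    # A model with two outcomes takes the sequential fitting path, which
    # converts the model with batched_to_model_list before fitting.
    X = torch.rand(5, 2) * 10
    Y = X**2
    intf = Normalize(2)
    model = SingleTaskGP(X, Y, input_transform=intf)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)

    # With this patch, the transform retains the bounds learned from X;
    # without it, mins collapsed to 0 and ranges to 1 (no-op normalization).
    assert not torch.allclose(intf.mins, torch.zeros(1, 2))
    assert not torch.allclose(intf.ranges, torch.ones(1, 2))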