From bac446aa69dfa70b0fbf3cb0fdde7d5c963c13b9 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 30 Jun 2022 09:28:02 -0700 Subject: [PATCH] Put input transforms into train mode before converting models (#1283) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1283 Fixes #1273 During model construction, input transforms should be in `train` mode (so that they only apply if `transform_on_train` is true). Having the input transforms in eval mode leads to buggy behavior due to `transformed_X` getting transformed when it shouldn't. Differential Revision: D37542474 fbshipit-source-id: f4278294de5d83d967f3d21c312370e562cf372c --- botorch/models/converter.py | 6 ++++++ test/models/test_converter.py | 39 ++++++++++++++++++++++++++++++----- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/botorch/models/converter.py b/botorch/models/converter.py index d1f5ae030f..b57992a880 100644 --- a/botorch/models/converter.py +++ b/botorch/models/converter.py @@ -152,6 +152,8 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch # construct the batched GP model input_transform = getattr(models[0], "input_transform", None) + if input_transform is not None: + input_transform.train() batch_gp = models[0].__class__(input_transform=input_transform, **kwargs) adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys( batch_state_dict=batch_gp.state_dict(), input_transform=input_transform @@ -220,6 +222,8 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model "Conversion of MixedSingleTaskGP is currently not supported." ) input_transform = getattr(batch_model, "input_transform", None) + if input_transform is not None: + input_transform.train() outcome_transform = getattr(batch_model, "outcome_transform", None) batch_sd = batch_model.state_dict() @@ -324,6 +328,8 @@ def batched_multi_output_to_single_output( "Conversion of models with custom likelihoods is currently unsupported." ) input_transform = getattr(batch_mo_model, "input_transform", None) + if input_transform is not None: + input_transform.train() batch_sd = batch_mo_model.state_dict() # TODO: add support for outcome transforms. diff --git a/test/models/test_converter.py b/test/models/test_converter.py index 5fbf6cca2d..87a2822646 100644 --- a/test/models/test_converter.py +++ b/test/models/test_converter.py @@ -19,7 +19,7 @@ batched_to_model_list, model_list_to_batched, ) -from botorch.models.transforms.input import Normalize +from botorch.models.transforms.input import AppendFeatures, Normalize from botorch.models.transforms.outcome import Standardize from botorch.utils.testing import BotorchTestCase from gpytorch.likelihoods import GaussianLikelihood @@ -80,6 +80,16 @@ def test_batched_to_model_list(self): expected_octf.__getattr__(attr_name), ) ) + # test with AppendFeatures + input_tf = AppendFeatures( + feature_set=torch.rand(2, 1, device=self.device, dtype=dtype) + ) + batch_gp = SingleTaskGP( + train_X, train_Y, outcome_transform=octf, input_transform=input_tf + ).eval() + list_gp = batched_to_model_list(batch_gp) + self.assertIsInstance(list_gp, ModelListGP) + self.assertIsInstance(list_gp.models[0].input_transform, AppendFeatures) def test_model_list_to_batched(self): for dtype in (torch.float, torch.double): @@ -167,6 +177,16 @@ def test_model_list_to_batched(self): self.assertTrue( torch.equal(batch_gp.input_transform.bounds, input_tf.bounds) ) + # test with AppendFeatures + input_tf3 = AppendFeatures( + feature_set=torch.rand(2, 1, device=self.device, dtype=dtype) + ) + gp1_ = SingleTaskGP(train_X, train_Y1, input_transform=input_tf3) + gp2_ = SingleTaskGP(train_X, train_Y2, input_transform=input_tf3) + list_gp = ModelListGP(gp1_, gp2_).eval() + batch_gp = model_list_to_batched(list_gp) + self.assertIsInstance(batch_gp, SingleTaskGP) + self.assertIsInstance(batch_gp.input_transform, AppendFeatures) # test different input transforms input_tf2 = Normalize( d=2, @@ -177,7 +197,7 @@ def test_model_list_to_batched(self): gp1_ = SingleTaskGP(train_X, train_Y1, input_transform=input_tf) gp2_ = SingleTaskGP(train_X, train_Y2, input_transform=input_tf2) list_gp = ModelListGP(gp1_, gp2_) - with self.assertRaises(UnsupportedError): + with self.assertRaisesRegex(UnsupportedError, "have the same"): model_list_to_batched(list_gp) # test batched input transform @@ -292,17 +312,26 @@ def test_batched_multi_output_to_single_output(self): self.assertTrue( torch.equal(batch_so_model.input_transform.bounds, input_tf.bounds) ) + # test with AppendFeatures + input_tf = AppendFeatures( + feature_set=torch.rand(2, 1, device=self.device, dtype=dtype) + ) + batched_mo_model = SingleTaskGP( + train_X, train_Y, input_transform=input_tf + ).eval() + batch_so_model = batched_multi_output_to_single_output(batched_mo_model) + self.assertIsInstance(batch_so_model.input_transform, AppendFeatures) # test batched input transform - input_tf2 = Normalize( + input_tf = Normalize( d=2, bounds=torch.tensor( [[-1.0, -1.0], [1.0, 1.0]], device=self.device, dtype=dtype ), batch_shape=torch.Size([2]), ) - batched_mo_model = SingleTaskGP(train_X, train_Y, input_transform=input_tf2) - batched_so_model = batched_multi_output_to_single_output(batched_mo_model) + batched_mo_model = SingleTaskGP(train_X, train_Y, input_transform=input_tf) + batch_so_model = batched_multi_output_to_single_output(batched_mo_model) self.assertIsInstance(batch_so_model.input_transform, Normalize) self.assertTrue( torch.equal(batch_so_model.input_transform.bounds, input_tf.bounds)