From 6c9ed7188eca9b88ce34af5bf94825dd237edd75 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 18 Oct 2022 10:01:48 -0700 Subject: [PATCH] Fix a bug in `_fit_multioutput_independent` that failed mll comparison Summary: The model of the repacked mll was in `eval` mode while the `repacked_mll` itself was in `train` mode, leading to the loss of `mll` and `repacked_mll` evaluating differently, thus failing the model fitting. Thankfully codecov caught it! Differential Revision: D40477774 fbshipit-source-id: 8e678a6f603a90ea3b46f13a12e236ecbee1d88d --- botorch/fit.py | 2 +- test/test_fit.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/botorch/fit.py b/botorch/fit.py index 43d4aecd9d..7eec93a895 100644 --- a/botorch/fit.py +++ b/botorch/fit.py @@ -338,7 +338,7 @@ def _fit_multioutput_independent( unpacked_mll = fit_gpytorch_mll(unpacked_mll, **kwargs) # Repackage submodels and copy over state_dict - repacked_model = model_list_to_batched(unpacked_mll.model) + repacked_model = model_list_to_batched(unpacked_mll.model.train()) repacked_mll = type(mll)(repacked_model.likelihood, repacked_model) with state_rollback_ctx(mll, device=device("cpu")) as ckpt: mll.load_state_dict(repacked_mll.state_dict()) diff --git a/test/test_fit.py b/test/test_fit.py index a673244b86..04d8f630bc 100644 --- a/test/test_fit.py +++ b/test/test_fit.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import math +import warnings from contextlib import nullcontext from copy import deepcopy from itertools import product @@ -16,7 +17,7 @@ import torch from botorch import fit from botorch.exceptions.errors import ModelFittingError, UnsupportedError -from botorch.exceptions.warnings import OptimizationWarning +from botorch.exceptions.warnings import BotorchWarning, OptimizationWarning from botorch.fit import fit_gpytorch_mll from botorch.models import FixedNoiseGP, HeteroskedasticSingleTaskGP, SingleTaskGP from botorch.models.converter import batched_to_model_list @@ -412,7 +413,10 @@ def _test_main(self, mll, ckpt): return optimizer = MockOptimizer() - with state_rollback_ctx(mll, checkpoint=ckpt), debug(True): + with state_rollback_ctx(mll, checkpoint=ckpt), debug( + True + ), warnings.catch_warnings(record=True) as ws: + warnings.simplefilter("always", BotorchWarning) try: fit._fit_multioutput_independent( mll, @@ -425,6 +429,7 @@ def _test_main(self, mll, ckpt): except Exception: pass # exception handling tested separately else: + self.assertEqual(len(ws), 0) # Model repacking did not fail. self.assertFalse(mll.training) self.assertEqual(optimizer.call_count, mll.model.num_outputs) self.assertTrue( @@ -519,8 +524,13 @@ def test_fit_with_converter(self): with mock.patch( f"{fit_gpytorch_mll.__module__}.batched_to_model_list", wraps=batched_to_model_list, - ) as wrapped_converter: + ) as wrapped_converter, warnings.catch_warnings(record=True) as ws: + warnings.simplefilter("always", BotorchWarning) fit_gpytorch_mll(mll) + # Check that MLL repacking succeeded. + self.assertFalse( + any("Training loss of repacked model" in str(w.message) for w in ws) + ) wrapped_converter.assert_called_once() self.assertFalse(torch.allclose(intf.mins, torch.zeros(1, 2, **tkwargs))) self.assertFalse(torch.allclose(intf.ranges, torch.ones(1, 2, **tkwargs)))