Fix a bug in _fit_multioutput_independent that failed mll comparison
Summary: The model of the repacked mll was in `eval` mode while the `repacked_mll` itself was in `train` mode, so the losses of `mll` and `repacked_mll` evaluated to different values, which failed the loss comparison during model fitting. Thankfully codecov caught it!

Differential Revision: D40477774

fbshipit-source-id: 8e678a6f603a90ea3b46f13a12e236ecbee1d88d
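For context, here is a minimal sketch of why the mode mismatch matters (not part of this commit; the model setup below is assumed purely for illustration): an exact GP returns its prior distribution in `train` mode and its posterior in `eval` mode, so the same MLL evaluates to different values depending on the model's mode.

```python
import torch
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Hypothetical training data, assumed only for illustration.
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

mll.train()  # puts both the mll and its model in train mode
loss_train = -mll(model(train_X), model.train_targets)

model.eval()  # model in eval mode while the mll is still in train mode
with torch.no_grad():
    loss_eval = -mll(model(train_X), model.train_targets)

# The two losses generally differ: in train mode the model returns the GP
# prior, in eval mode it returns the posterior. This kind of discrepancy is
# what made the `mll` vs. `repacked_mll` comparison fail.
print(loss_train.item(), loss_eval.item())
```

Calling `.train()` on the unpacked model before repacking (the one-line fix below) removes this mismatch.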
saitcakmak authored and facebook-github-bot committed Oct 18, 2022
1 parent ab50b85 commit 6c9ed71
Showing 2 changed files with 14 additions and 4 deletions.
botorch/fit.py: 1 addition & 1 deletion
@@ -338,7 +338,7 @@ def _fit_multioutput_independent(
     unpacked_mll = fit_gpytorch_mll(unpacked_mll, **kwargs)
 
     # Repackage submodels and copy over state_dict
-    repacked_model = model_list_to_batched(unpacked_mll.model)
+    repacked_model = model_list_to_batched(unpacked_mll.model.train())
     repacked_mll = type(mll)(repacked_model.likelihood, repacked_model)
     with state_rollback_ctx(mll, device=device("cpu")) as ckpt:
         mll.load_state_dict(repacked_mll.state_dict())
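The chained call works because `torch.nn.Module.train()` recursively switches a module back into training mode and returns the module itself; a minimal, generic illustration (plain `torch.nn` module, unrelated to the code above):

```python
import torch

# Module.train() toggles training mode on the module and its submodules
# and returns the module itself, so it can be chained inline.
module = torch.nn.Linear(2, 1).eval()
assert not module.training
assert module.train() is module  # returns self
assert module.training  # back in train mode
```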
test/test_fit.py: 13 additions & 3 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
+import warnings
 from contextlib import nullcontext
 from copy import deepcopy
 from itertools import product
@@ -16,7 +17,7 @@
 import torch
 from botorch import fit
 from botorch.exceptions.errors import ModelFittingError, UnsupportedError
-from botorch.exceptions.warnings import OptimizationWarning
+from botorch.exceptions.warnings import BotorchWarning, OptimizationWarning
 from botorch.fit import fit_gpytorch_mll
 from botorch.models import FixedNoiseGP, HeteroskedasticSingleTaskGP, SingleTaskGP
 from botorch.models.converter import batched_to_model_list
@@ -412,7 +413,10 @@ def _test_main(self, mll, ckpt):
             return
 
         optimizer = MockOptimizer()
-        with state_rollback_ctx(mll, checkpoint=ckpt), debug(True):
+        with state_rollback_ctx(mll, checkpoint=ckpt), debug(
+            True
+        ), warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always", BotorchWarning)
             try:
                 fit._fit_multioutput_independent(
                     mll,
@@ -425,6 +429,7 @@
             except Exception:
                 pass  # exception handling tested separately
             else:
+                self.assertEqual(len(ws), 0)  # Model repacking did not fail.
                 self.assertFalse(mll.training)
                 self.assertEqual(optimizer.call_count, mll.model.num_outputs)
                 self.assertTrue(
@@ -519,8 +524,13 @@ def test_fit_with_converter(self):
         with mock.patch(
             f"{fit_gpytorch_mll.__module__}.batched_to_model_list",
             wraps=batched_to_model_list,
-        ) as wrapped_converter:
+        ) as wrapped_converter, warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always", BotorchWarning)
             fit_gpytorch_mll(mll)
+            # Check that MLL repacking succeeded.
+            self.assertFalse(
+                any("Training loss of repacked model" in str(w.message) for w in ws)
+            )
         wrapped_converter.assert_called_once()
         self.assertFalse(torch.allclose(intf.mins, torch.zeros(1, 2, **tkwargs)))
         self.assertFalse(torch.allclose(intf.ranges, torch.ones(1, 2, **tkwargs)))
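The test changes rely on the standard `warnings.catch_warnings(record=True)` pattern to assert that no repacking warning was emitted; a minimal, generic illustration (the warning text is copied from the assertion above, the rest is assumed):

```python
import warnings

with warnings.catch_warnings(record=True) as ws:
    # "always" ensures every matching warning is recorded, even repeats.
    warnings.simplefilter("always", UserWarning)
    warnings.warn("Training loss of repacked model ...", UserWarning)

# Each entry in ws is a WarningMessage; the warning instance is in .message.
assert any("Training loss of repacked model" in str(w.message) for w in ws)
```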
