Skip to content

Commit

Permalink
Put input transforms into train mode before converting models (#1283)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1283

Fixes #1273

During model construction, input transforms should be in `train` mode (so that they only apply if `transform_on_train` is true).
Having the input transforms in eval mode leads to buggy behavior due to `transformed_X` getting transformed when it shouldn't.

Reviewed By: Balandat

Differential Revision: D37542474

fbshipit-source-id: 7986829612f857995997036f9c48cbeb56d75ceb
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jul 5, 2022
1 parent 7ce7c6d commit a675968
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
6 changes: 6 additions & 0 deletions botorch/models/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorchModel

# construct the batched GP model
input_transform = getattr(models[0], "input_transform", None)
if input_transform is not None:
input_transform.train()
batch_gp = models[0].__class__(input_transform=input_transform, **kwargs)
adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
batch_state_dict=batch_gp.state_dict(), input_transform=input_transform
Expand Down Expand Up @@ -220,6 +222,8 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
"Conversion of MixedSingleTaskGP is currently not supported."
)
input_transform = getattr(batch_model, "input_transform", None)
if input_transform is not None:
input_transform.train()
outcome_transform = getattr(batch_model, "outcome_transform", None)
batch_sd = batch_model.state_dict()

Expand Down Expand Up @@ -324,6 +328,8 @@ def batched_multi_output_to_single_output(
"Conversion of models with custom likelihoods is currently unsupported."
)
input_transform = getattr(batch_mo_model, "input_transform", None)
if input_transform is not None:
input_transform.train()
batch_sd = batch_mo_model.state_dict()

# TODO: add support for outcome transforms.
Expand Down
39 changes: 34 additions & 5 deletions test/models/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
batched_to_model_list,
model_list_to_batched,
)
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.input import AppendFeatures, Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.utils.testing import BotorchTestCase
from gpytorch.likelihoods import GaussianLikelihood
Expand Down Expand Up @@ -80,6 +80,16 @@ def test_batched_to_model_list(self):
expected_octf.__getattr__(attr_name),
)
)
# test with AppendFeatures
input_tf = AppendFeatures(
feature_set=torch.rand(2, 1, device=self.device, dtype=dtype)
)
batch_gp = SingleTaskGP(
train_X, train_Y, outcome_transform=octf, input_transform=input_tf
).eval()
list_gp = batched_to_model_list(batch_gp)
self.assertIsInstance(list_gp, ModelListGP)
self.assertIsInstance(list_gp.models[0].input_transform, AppendFeatures)

def test_model_list_to_batched(self):
for dtype in (torch.float, torch.double):
Expand Down Expand Up @@ -167,6 +177,16 @@ def test_model_list_to_batched(self):
self.assertTrue(
torch.equal(batch_gp.input_transform.bounds, input_tf.bounds)
)
# test with AppendFeatures
input_tf3 = AppendFeatures(
feature_set=torch.rand(2, 1, device=self.device, dtype=dtype)
)
gp1_ = SingleTaskGP(train_X, train_Y1, input_transform=input_tf3)
gp2_ = SingleTaskGP(train_X, train_Y2, input_transform=input_tf3)
list_gp = ModelListGP(gp1_, gp2_).eval()
batch_gp = model_list_to_batched(list_gp)
self.assertIsInstance(batch_gp, SingleTaskGP)
self.assertIsInstance(batch_gp.input_transform, AppendFeatures)
# test different input transforms
input_tf2 = Normalize(
d=2,
Expand All @@ -177,7 +197,7 @@ def test_model_list_to_batched(self):
gp1_ = SingleTaskGP(train_X, train_Y1, input_transform=input_tf)
gp2_ = SingleTaskGP(train_X, train_Y2, input_transform=input_tf2)
list_gp = ModelListGP(gp1_, gp2_)
with self.assertRaises(UnsupportedError):
with self.assertRaisesRegex(UnsupportedError, "have the same"):
model_list_to_batched(list_gp)

# test batched input transform
Expand Down Expand Up @@ -292,17 +312,26 @@ def test_batched_multi_output_to_single_output(self):
self.assertTrue(
torch.equal(batch_so_model.input_transform.bounds, input_tf.bounds)
)
# test with AppendFeatures
input_tf = AppendFeatures(
feature_set=torch.rand(2, 1, device=self.device, dtype=dtype)
)
batched_mo_model = SingleTaskGP(
train_X, train_Y, input_transform=input_tf
).eval()
batch_so_model = batched_multi_output_to_single_output(batched_mo_model)
self.assertIsInstance(batch_so_model.input_transform, AppendFeatures)

# test batched input transform
input_tf2 = Normalize(
input_tf = Normalize(
d=2,
bounds=torch.tensor(
[[-1.0, -1.0], [1.0, 1.0]], device=self.device, dtype=dtype
),
batch_shape=torch.Size([2]),
)
batched_mo_model = SingleTaskGP(train_X, train_Y, input_transform=input_tf2)
batched_so_model = batched_multi_output_to_single_output(batched_mo_model)
batched_mo_model = SingleTaskGP(train_X, train_Y, input_transform=input_tf)
batch_so_model = batched_multi_output_to_single_output(batched_mo_model)
self.assertIsInstance(batch_so_model.input_transform, Normalize)
self.assertTrue(
torch.equal(batch_so_model.input_transform.bounds, input_tf.bounds)
Expand Down

0 comments on commit a675968

Please sign in to comment.