From 30eddaf990b406747b601241ca5be225769c2037 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Thu, 8 Jun 2023 21:49:39 -0700 Subject: [PATCH] Fix state dict loading behavior for `Standardize` transform (#1875) Summary: Addresses https://github.com/pytorch/botorch/issues/1874 Pull Request resolved: https://github.com/pytorch/botorch/pull/1875 Reviewed By: SebastianAment Differential Revision: D46547461 Pulled By: Balandat fbshipit-source-id: 17e925d33996b062bdd20dd7663fc09e58011b02 --- botorch/acquisition/objective.py | 1 + botorch/models/transforms/outcome.py | 37 ++++++++++++++----- .../multi_objective/test_objective.py | 1 + test/models/transforms/test_outcome.py | 21 +++++++++++ 4 files changed, 50 insertions(+), 10 deletions(-) diff --git a/botorch/acquisition/objective.py b/botorch/acquisition/objective.py index 4f67751f3b..832e2840f8 100644 --- a/botorch/acquisition/objective.py +++ b/botorch/acquisition/objective.py @@ -261,6 +261,7 @@ def __init__(self, Y_mean: Tensor, Y_std: Tensor) -> None: self.outcome_transform.means = Y_mean.unsqueeze(0) self.outcome_transform.stdvs = Y_std_unsqueezed self.outcome_transform._stdvs_sq = Y_std_unsqueezed.pow(2) + self.outcome_transform._is_trained = torch.tensor(True) self.outcome_transform.eval() def evaluate(self, Y: Tensor) -> Tensor: diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index 07937511d6..4745ee9734 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -22,9 +22,11 @@ from __future__ import annotations +import warnings + from abc import ABC, abstractmethod from collections import OrderedDict -from typing import List, Optional, Tuple, Union +from typing import Any, List, Mapping, Optional, Tuple, Union import torch from botorch.models.transforms.utils import ( @@ -245,14 +247,28 @@ def __init__( standardization (if lower, only de-mean the data). """ super().__init__() - self.register_buffer("means", None) - self.register_buffer("stdvs", None) - self.register_buffer("_stdvs_sq", None) + self.register_buffer("means", torch.zeros(*batch_shape, 1, m)) + self.register_buffer("stdvs", torch.ones(*batch_shape, 1, m)) + self.register_buffer("_stdvs_sq", torch.ones(*batch_shape, 1, m)) + self.register_buffer("_is_trained", torch.tensor(False)) self._outputs = normalize_indices(outputs, d=m) self._m = m self._batch_shape = batch_shape self._min_stdv = min_stdv + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True + ) -> None: + r"""Custom logic for loading the state dict.""" + if "_is_trained" not in state_dict: + warnings.warn( + "Key '_is_trained' not found in state_dict. Setting to True. " + "In a future release, this will result in an error.", + DeprecationWarning, + ) + state_dict = {**state_dict, "_is_trained": torch.tensor(True)} + super().load_state_dict(state_dict, strict=strict) + def forward( self, Y: Tensor, Yvar: Optional[Tensor] = None ) -> Tuple[Tensor, Optional[Tensor]]: @@ -295,6 +311,7 @@ def forward( self.means = means self.stdvs = stdvs self._stdvs_sq = stdvs.pow(2) + self._is_trained = torch.tensor(True) Y_tf = (Y - self.means) / self.stdvs Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None @@ -325,10 +342,10 @@ def subset_output(self, idcs: List[int]) -> OutcomeTransform: batch_shape=self._batch_shape, min_stdv=self._min_stdv, ) - if self.means is not None: - new_tf.means = self.means[..., nlzd_idcs] - new_tf.stdvs = self.stdvs[..., nlzd_idcs] - new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs] + new_tf.means = self.means[..., nlzd_idcs] + new_tf.stdvs = self.stdvs[..., nlzd_idcs] + new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs] + new_tf._is_trained = self._is_trained if not self.training: new_tf.eval() return new_tf @@ -349,7 +366,7 @@ def untransform( - The un-standardized outcome observations. - The un-standardized observation noise (if applicable). """ - if self.means is None: + if not self._is_trained: raise RuntimeError( "`Standardize` transforms must be called on outcome data " "(e.g. `transform(Y)`) before calling `untransform`, since " @@ -382,7 +399,7 @@ def untransform_posterior( "Standardize does not yet support output selection for " "untransform_posterior" ) - if self.means is None: + if not self._is_trained: raise RuntimeError( "`Standardize` transforms must be called on outcome data " "(e.g. `transform(Y)`) before calling `untransform_posterior`, since " diff --git a/test/acquisition/multi_objective/test_objective.py b/test/acquisition/multi_objective/test_objective.py index 8b52b32841..02643d34ca 100644 --- a/test/acquisition/multi_objective/test_objective.py +++ b/test/acquisition/multi_objective/test_objective.py @@ -188,6 +188,7 @@ def test_unstandardize_mo_objective(self): tf.means = Y_mean tf.stdvs = Y_std tf._stdvs_sq = Y_std.pow(2) + tf._is_trained = torch.tensor(True) tf.eval() expected_posterior = tf.untransform_posterior(mock_posterior) self.assertTrue( diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py index c57add2ad1..44dadf79ef 100644 --- a/test/models/transforms/test_outcome.py +++ b/test/models/transforms/test_outcome.py @@ -342,6 +342,27 @@ def test_standardize(self): with self.assertRaises(NotImplementedError): tf.untransform_posterior(None) + def test_standardize_state_dict(self): + for m in (1, 2): + with self.subTest(m=2): + transform = Standardize(m=m) + self.assertFalse(transform._is_trained) + self.assertTrue(transform.training) + Y = torch.rand(2, m) + transform(Y) + state_dict = transform.state_dict() + new_transform = Standardize(m=m) + self.assertFalse(new_transform._is_trained) + new_transform.load_state_dict(state_dict) + self.assertTrue(new_transform._is_trained) + # test deprecation error when loading state dict without _is_trained + state_dict.pop("_is_trained") + with self.assertWarnsRegex( + DeprecationWarning, + "Key '_is_trained' not found in state_dict. Setting to True.", + ): + new_transform.load_state_dict(state_dict) + def test_log(self): ms = (1, 2) batch_shapes = (torch.Size(), torch.Size([2]))