Skip to content

Commit

Permalink
Fix state dict loading behavior for Standardize transform (#1875)
Browse files Browse the repository at this point in the history
Summary:
Addresses #1874

Pull Request resolved: #1875

Reviewed By: SebastianAment

Differential Revision: D46547461

Pulled By: Balandat

fbshipit-source-id: 17e925d33996b062bdd20dd7663fc09e58011b02
  • Loading branch information
Balandat authored and facebook-github-bot committed Jun 9, 2023
1 parent f00057b commit 30eddaf
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 10 deletions.
1 change: 1 addition & 0 deletions botorch/acquisition/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def __init__(self, Y_mean: Tensor, Y_std: Tensor) -> None:
self.outcome_transform.means = Y_mean.unsqueeze(0)
self.outcome_transform.stdvs = Y_std_unsqueezed
self.outcome_transform._stdvs_sq = Y_std_unsqueezed.pow(2)
self.outcome_transform._is_trained = torch.tensor(True)
self.outcome_transform.eval()

def evaluate(self, Y: Tensor) -> Tensor:
Expand Down
37 changes: 27 additions & 10 deletions botorch/models/transforms/outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@

from __future__ import annotations

import warnings

from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import List, Optional, Tuple, Union
from typing import Any, List, Mapping, Optional, Tuple, Union

import torch
from botorch.models.transforms.utils import (
Expand Down Expand Up @@ -245,14 +247,28 @@ def __init__(
standardization (if lower, only de-mean the data).
"""
super().__init__()
self.register_buffer("means", None)
self.register_buffer("stdvs", None)
self.register_buffer("_stdvs_sq", None)
self.register_buffer("means", torch.zeros(*batch_shape, 1, m))
self.register_buffer("stdvs", torch.ones(*batch_shape, 1, m))
self.register_buffer("_stdvs_sq", torch.ones(*batch_shape, 1, m))
self.register_buffer("_is_trained", torch.tensor(False))
self._outputs = normalize_indices(outputs, d=m)
self._m = m
self._batch_shape = batch_shape
self._min_stdv = min_stdv

def load_state_dict(
    self, state_dict: Mapping[str, Any], strict: bool = True
) -> Any:
    r"""Load a state dict, tolerating dicts that lack the `_is_trained` key.

    If `_is_trained` is missing from `state_dict`, a `DeprecationWarning` is
    emitted and the entry is filled in as `True` before delegating to the
    parent implementation.

    Args:
        state_dict: Mapping from buffer / parameter names to values.
        strict: Forwarded unchanged to the parent `load_state_dict`.

    Returns:
        Whatever the parent `load_state_dict` returns (for `torch.nn.Module`
        this is the missing / unexpected keys report), so callers can still
        inspect it.
    """
    if "_is_trained" not in state_dict:
        warnings.warn(
            "Key '_is_trained' not found in state_dict. Setting to True. "
            "In a future release, this will result in an error.",
            DeprecationWarning,
            # Attribute the warning to the caller's `load_state_dict` call
            # rather than to this line inside the library.
            stacklevel=2,
        )
        # Build a new dict so the caller's mapping is not mutated.
        state_dict = {**state_dict, "_is_trained": torch.tensor(True)}
    return super().load_state_dict(state_dict, strict=strict)

def forward(
self, Y: Tensor, Yvar: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
Expand Down Expand Up @@ -295,6 +311,7 @@ def forward(
self.means = means
self.stdvs = stdvs
self._stdvs_sq = stdvs.pow(2)
self._is_trained = torch.tensor(True)

Y_tf = (Y - self.means) / self.stdvs
Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None
Expand Down Expand Up @@ -325,10 +342,10 @@ def subset_output(self, idcs: List[int]) -> OutcomeTransform:
batch_shape=self._batch_shape,
min_stdv=self._min_stdv,
)
if self.means is not None:
new_tf.means = self.means[..., nlzd_idcs]
new_tf.stdvs = self.stdvs[..., nlzd_idcs]
new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs]
new_tf.means = self.means[..., nlzd_idcs]
new_tf.stdvs = self.stdvs[..., nlzd_idcs]
new_tf._stdvs_sq = self._stdvs_sq[..., nlzd_idcs]
new_tf._is_trained = self._is_trained
if not self.training:
new_tf.eval()
return new_tf
Expand All @@ -349,7 +366,7 @@ def untransform(
- The un-standardized outcome observations.
- The un-standardized observation noise (if applicable).
"""
if self.means is None:
if not self._is_trained:
raise RuntimeError(
"`Standardize` transforms must be called on outcome data "
"(e.g. `transform(Y)`) before calling `untransform`, since "
Expand Down Expand Up @@ -382,7 +399,7 @@ def untransform_posterior(
"Standardize does not yet support output selection for "
"untransform_posterior"
)
if self.means is None:
if not self._is_trained:
raise RuntimeError(
"`Standardize` transforms must be called on outcome data "
"(e.g. `transform(Y)`) before calling `untransform_posterior`, since "
Expand Down
1 change: 1 addition & 0 deletions test/acquisition/multi_objective/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def test_unstandardize_mo_objective(self):
tf.means = Y_mean
tf.stdvs = Y_std
tf._stdvs_sq = Y_std.pow(2)
tf._is_trained = torch.tensor(True)
tf.eval()
expected_posterior = tf.untransform_posterior(mock_posterior)
self.assertTrue(
Expand Down
21 changes: 21 additions & 0 deletions test/models/transforms/test_outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,27 @@ def test_standardize(self):
with self.assertRaises(NotImplementedError):
tf.untransform_posterior(None)

def test_standardize_state_dict(self):
    """Check that `Standardize` round-trips its trained flag via state dicts.

    For each output dimension `m`, a transform is trained by calling it on
    data, its state dict is loaded into a fresh transform, and `_is_trained`
    is checked before and after. A state dict with `_is_trained` removed is
    then loaded into another fresh transform to check both the deprecation
    warning and that the flag ends up set.
    """
    for m in (1, 2):
        # Label the subtest with the actual loop value so that failures are
        # attributed to the right configuration.
        with self.subTest(m=m):
            transform = Standardize(m=m)
            self.assertFalse(transform._is_trained)
            self.assertTrue(transform.training)
            Y = torch.rand(2, m)
            transform(Y)
            state_dict = transform.state_dict()
            new_transform = Standardize(m=m)
            self.assertFalse(new_transform._is_trained)
            new_transform.load_state_dict(state_dict)
            self.assertTrue(new_transform._is_trained)
            # test deprecation warning when loading state dict without
            # _is_trained; use a fresh (untrained) transform so that the
            # assertion below actually exercises the fallback.
            state_dict.pop("_is_trained")
            legacy_transform = Standardize(m=m)
            self.assertFalse(legacy_transform._is_trained)
            with self.assertWarnsRegex(
                DeprecationWarning,
                "Key '_is_trained' not found in state_dict. Setting to True.",
            ):
                legacy_transform.load_state_dict(state_dict)
            self.assertTrue(legacy_transform._is_trained)

def test_log(self):
ms = (1, 2)
batch_shapes = (torch.Size(), torch.Size([2]))
Expand Down

0 comments on commit 30eddaf

Please sign in to comment.