diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py
index 90994d3fac..611944a3d3 100644
--- a/botorch/models/transforms/input.py
+++ b/botorch/models/transforms/input.py
@@ -368,14 +368,26 @@ def __init__(
             torch.broadcast_shapes(coefficient.shape, offset.shape)
 
         self._d = d
-        self.register_buffer("coefficient", coefficient)
-        self.register_buffer("offset", offset)
+        self.register_buffer("_coefficient", coefficient)
+        self.register_buffer("_offset", offset)
         self.batch_shape = batch_shape
         self.transform_on_train = transform_on_train
         self.transform_on_eval = transform_on_eval
         self.transform_on_fantasize = transform_on_fantasize
         self.reverse = reverse
 
+    @property
+    def coefficient(self) -> Tensor:
+        r"""The tensor of linear coefficients."""
+        coeff = self._coefficient
+        return coeff if self.learn_coefficients and self.training else coeff.detach()
+
+    @property
+    def offset(self) -> Tensor:
+        r"""The tensor of offset coefficients."""
+        offset = self._offset
+        return offset if self.learn_coefficients and self.training else offset.detach()
+
     @property
     def learn_coefficients(self) -> bool:
         return getattr(self, "_learn_coefficients", False)
@@ -459,8 +471,8 @@ def _check_shape(self, X: Tensor) -> None:
 
     def _to(self, X: Tensor) -> None:
         r"""Makes coefficient and offset have same device and dtype as X."""
-        self.coefficient = self.coefficient.to(X)
-        self.offset = self.offset.to(X)
+        self._coefficient = self.coefficient.to(X)
+        self._offset = self.offset.to(X)
 
     def _update_coefficients(self, X: Tensor) -> None:
         r"""Updates affine coefficients. Implemented by subclasses,
@@ -569,9 +581,9 @@ def _update_coefficients(self, X) -> None:
         # Aggregate mins and ranges over extra batch and marginal dims
         batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
         reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)
-        self.offset = torch.amin(X, dim=reduce_dims).unsqueeze(-2)
-        self.coefficient = torch.amax(X, dim=reduce_dims).unsqueeze(-2) - self.offset
-        self.coefficient.clamp_(min=self.min_range)
+        self._offset = torch.amin(X, dim=reduce_dims).unsqueeze(-2)
+        self._coefficient = torch.amax(X, dim=reduce_dims).unsqueeze(-2) - self.offset
+        self._coefficient.clamp_(min=self.min_range)
 
 
 class InputStandardize(AffineInputTransform):
@@ -641,11 +653,11 @@ def _update_coefficients(self, X: Tensor) -> None:
         # Aggregate means and standard deviations over extra batch and marginal dims
         batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
         reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)
-        coefficient, self.offset = (
+        coefficient, self._offset = (
             values.unsqueeze(-2)
             for values in torch.std_mean(X, dim=reduce_dims, unbiased=True)
         )
-        self.coefficient = coefficient.clamp_(min=self.min_std)
+        self._coefficient = coefficient.clamp_(min=self.min_std)
 
 
 class Round(InputTransform, Module):
diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py
index 168fa6f5ef..456a8b47ef 100644
--- a/test/models/transforms/test_input.py
+++ b/test/models/transforms/test_input.py
@@ -165,6 +165,20 @@ def test_normalize(self):
             self.assertTrue(
                 torch.equal(nlz.mins, bounds[..., 1:2, :] - bounds[..., 0:1, :])
             )
+            # with grad
+            bounds.requires_grad = True
+            bounds = bounds * 2
+            self.assertIsNotNone(bounds.grad_fn)
+            nlz = Normalize(d=2, bounds=bounds)
+            # Set learn_coefficients=True for testing.
+            nlz.learn_coefficients = True
+            # We have grad in train mode.
+            self.assertIsNotNone(nlz.coefficient.grad_fn)
+            self.assertIsNotNone(nlz.offset.grad_fn)
+            # Grad is detached in eval mode.
+            nlz.eval()
+            self.assertIsNone(nlz.coefficient.grad_fn)
+            self.assertIsNone(nlz.offset.grad_fn)
 
             # basic init, provided indices
             with self.assertRaises(ValueError):
@@ -326,6 +340,18 @@ def test_normalize(self):
             nlz10 = Normalize(d=3, batch_shape=batch_shape, indices=[0, 2])
             self.assertFalse(nlz9.equals(nlz10))
 
+            # test with grad
+            nlz = Normalize(d=1)
+            X.requires_grad = True
+            X = X * 2
+            self.assertIsNotNone(X.grad_fn)
+            nlz(X)
+            self.assertIsNotNone(nlz.coefficient.grad_fn)
+            self.assertIsNotNone(nlz.offset.grad_fn)
+            nlz.eval()
+            self.assertIsNone(nlz.coefficient.grad_fn)
+            self.assertIsNone(nlz.offset.grad_fn)
+
     def test_standardize(self):
         for dtype in (torch.float, torch.double):
             # basic init
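
A minimal sketch (not part of the patch) of the behavior the new `coefficient`/`offset`
properties provide, mirroring the second added test: a Normalize transform that learns
its coefficients from data keeps the autograd graph while in train mode and returns
detached views once `.eval()` is called.

    import torch
    from botorch.models.transforms.input import Normalize

    # Non-leaf input tensor, so downstream results carry a grad_fn.
    X = torch.rand(10, 1, requires_grad=True) * 2

    nlz = Normalize(d=1)
    nlz(X)  # train mode: min/max are computed from X and keep the graph
    assert nlz.coefficient.grad_fn is not None
    assert nlz.offset.grad_fn is not None

    nlz.eval()  # eval mode: the properties return detached tensors
    assert nlz.coefficient.grad_fn is None
    assert nlz.offset.grad_fn is None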