Detach coefficient and offset in AffineTransform in eval mode (#1642)
Summary:
Pull Request resolved: #1642

See #1635. These should retain grad only when the coefficients are being learned and the transform is in train mode (a behavior sketch follows below).

Differential Revision: D42700421

fbshipit-source-id: 907f72a6f585f55f1496e23f58bc463eac063860
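
A minimal sketch of the intended behavior, mirroring the new test below and assuming the `Normalize` transform with its `learn_coefficients` setter from `botorch.models.transforms.input` (tensor values are illustrative):

import torch
from botorch.models.transforms.input import Normalize

# Bounds that carry gradient history (grad_fn), as in the new test.
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], requires_grad=True) * 2
nlz = Normalize(d=2, bounds=bounds)
nlz.learn_coefficients = True  # enable learning so train mode keeps the graph

# Train mode (the default): the properties return the buffers with their graph.
assert nlz.coefficient.grad_fn is not None
assert nlz.offset.grad_fn is not None

# Eval mode: the properties return detached tensors, so no grad is retained.
nlz.eval()
assert nlz.coefficient.grad_fn is None
assert nlz.offset.grad_fn is None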
saitcakmak authored and facebook-github-bot committed Jan 30, 2023
1 parent 72e872a commit 816425e
Showing 2 changed files with 47 additions and 9 deletions.
30 changes: 21 additions & 9 deletions botorch/models/transforms/input.py
@@ -368,14 +368,26 @@ def __init__(
         torch.broadcast_shapes(coefficient.shape, offset.shape)
 
         self._d = d
-        self.register_buffer("coefficient", coefficient)
-        self.register_buffer("offset", offset)
+        self.register_buffer("_coefficient", coefficient)
+        self.register_buffer("_offset", offset)
         self.batch_shape = batch_shape
         self.transform_on_train = transform_on_train
         self.transform_on_eval = transform_on_eval
         self.transform_on_fantasize = transform_on_fantasize
         self.reverse = reverse
 
+    @property
+    def coefficient(self) -> Tensor:
+        r"""The tensor of linear coefficients."""
+        coeff = self._coefficient
+        return coeff if self.learn_coefficients and self.training else coeff.detach()
+
+    @property
+    def offset(self) -> Tensor:
+        r"""The tensor of offset coefficients."""
+        offset = self._offset
+        return offset if self.learn_coefficients and self.training else offset.detach()
+
     @property
     def learn_coefficients(self) -> bool:
         return getattr(self, "_learn_coefficients", False)
@@ -459,8 +471,8 @@ def _check_shape(self, X: Tensor) -> None:
 
     def _to(self, X: Tensor) -> None:
         r"""Makes coefficient and offset have same device and dtype as X."""
-        self.coefficient = self.coefficient.to(X)
-        self.offset = self.offset.to(X)
+        self._coefficient = self.coefficient.to(X)
+        self._offset = self.offset.to(X)
 
     def _update_coefficients(self, X: Tensor) -> None:
         r"""Updates affine coefficients. Implemented by subclasses,
@@ -569,9 +581,9 @@ def _update_coefficients(self, X) -> None:
         # Aggregate mins and ranges over extra batch and marginal dims
         batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
         reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)
-        self.offset = torch.amin(X, dim=reduce_dims).unsqueeze(-2)
-        self.coefficient = torch.amax(X, dim=reduce_dims).unsqueeze(-2) - self.offset
-        self.coefficient.clamp_(min=self.min_range)
+        self._offset = torch.amin(X, dim=reduce_dims).unsqueeze(-2)
+        self._coefficient = torch.amax(X, dim=reduce_dims).unsqueeze(-2) - self.offset
+        self._coefficient.clamp_(min=self.min_range)
 
 
 class InputStandardize(AffineInputTransform):
@@ -641,11 +653,11 @@ def _update_coefficients(self, X: Tensor) -> None:
         # Aggregate means and standard deviations over extra batch and marginal dims
         batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
         reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)
-        coefficient, self.offset = (
+        coefficient, self._offset = (
             values.unsqueeze(-2)
             for values in torch.std_mean(X, dim=reduce_dims, unbiased=True)
         )
-        self.coefficient = coefficient.clamp_(min=self.min_std)
+        self._coefficient = coefficient.clamp_(min=self.min_std)
 
 
 class Round(InputTransform, Module):
26 changes: 26 additions & 0 deletions test/models/transforms/test_input.py
@@ -165,6 +165,20 @@ def test_normalize(self):
            self.assertTrue(
                torch.equal(nlz.mins, bounds[..., 1:2, :] - bounds[..., 0:1, :])
            )
+            # with grad
+            bounds.requires_grad = True
+            bounds = bounds * 2
+            self.assertIsNotNone(bounds.grad_fn)
+            nlz = Normalize(d=2, bounds=bounds)
+            # Set learn_coefficients=True for testing.
+            nlz.learn_coefficients = True
+            # We have grad in train mode.
+            self.assertIsNotNone(nlz.coefficient.grad_fn)
+            self.assertIsNotNone(nlz.offset.grad_fn)
+            # Grad is detached in eval mode.
+            nlz.eval()
+            self.assertIsNone(nlz.coefficient.grad_fn)
+            self.assertIsNone(nlz.offset.grad_fn)
 
             # basic init, provided indices
             with self.assertRaises(ValueError):
@@ -326,6 +340,18 @@ def test_normalize(self):
             nlz10 = Normalize(d=3, batch_shape=batch_shape, indices=[0, 2])
             self.assertFalse(nlz9.equals(nlz10))
 
+            # test with grad
+            nlz = Normalize(d=1)
+            X.requires_grad = True
+            X = X * 2
+            self.assertIsNotNone(X.grad_fn)
+            nlz(X)
+            self.assertIsNotNone(nlz.coefficient.grad_fn)
+            self.assertIsNotNone(nlz.offset.grad_fn)
+            nlz.eval()
+            self.assertIsNone(nlz.coefficient.grad_fn)
+            self.assertIsNone(nlz.offset.grad_fn)
+
     def test_standardize(self):
         for dtype in (torch.float, torch.double):
             # basic init
