diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index cd100bea8f..bb1dd41469 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -365,7 +365,7 @@ def __init__( raise ValueError("Elements of `indices` have to be smaller than `d`!") if len(indices.unique()) != len(indices): raise ValueError("Elements of `indices` tensor must be unique!") - self.indices = indices + self.register_buffer("indices", indices) torch.broadcast_shapes(coefficient.shape, offset.shape) self._d = d diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py index 4d53be6afd..1893f3d982 100644 --- a/test/models/transforms/test_input.py +++ b/test/models/transforms/test_input.py @@ -197,7 +197,13 @@ def test_normalize(self): self.assertEqual(nlz.mins.shape, torch.Size([1, 1])) self.assertEqual(nlz.ranges.shape, torch.Size([1, 1])) self.assertEqual(len(nlz.indices), 1) - self.assertTrue((nlz.indices == torch.tensor([0], dtype=torch.long)).all()) + nlz.to(device=self.device) + self.assertTrue( + ( + nlz.indices + == torch.tensor([0], dtype=torch.long, device=self.device) + ).all() + ) # test .to other_dtype = torch.float if dtype == torch.double else torch.double @@ -382,17 +388,25 @@ def test_standardize(self): self.assertEqual(stdz.means.shape, torch.Size([1, 1])) self.assertEqual(stdz.stds.shape, torch.Size([1, 1])) self.assertEqual(len(stdz.indices), 1) + stdz.to(device=self.device) self.assertTrue( - torch.equal(stdz.indices, torch.tensor([0], dtype=torch.long)) + torch.equal( + stdz.indices, + torch.tensor([0], dtype=torch.long, device=self.device), + ) ) stdz = InputStandardize(d=2, indices=[0], batch_shape=torch.Size([3])) + stdz.to(device=self.device) self.assertTrue(stdz.training) self.assertEqual(stdz._d, 2) self.assertEqual(stdz.means.shape, torch.Size([3, 1, 1])) self.assertEqual(stdz.stds.shape, torch.Size([3, 1, 1])) self.assertEqual(len(stdz.indices), 1) self.assertTrue( - torch.equal(stdz.indices, torch.tensor([0], dtype=torch.long)) + torch.equal( + stdz.indices, + torch.tensor([0], device=self.device, dtype=torch.long), + ) ) # test jitter