From 2030ba06ba10e361d83d9f3807e24c552c08d8c0 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Mon, 27 Feb 2023 21:14:33 -0800 Subject: [PATCH] Allow tensor input for integer_indices in Round transform Summary: This simplifies storage in Ax. Without this change, `integer_indices = integer_indices or []` raises an error with tensors. Differential Revision: D43649436 fbshipit-source-id: f4195026911d18d0b61a6f1adde9d7b60d442cd7 --- botorch/models/transforms/input.py | 8 ++++---- test/models/transforms/test_input.py | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index 09310163b5..cb536ca3e8 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -28,7 +28,7 @@ from gpytorch import Module as GPyTorchModule from gpytorch.constraints import GreaterThan from gpytorch.priors import Prior -from torch import nn, Tensor +from torch import LongTensor, nn, Tensor from torch.distributions import Kumaraswamy from torch.nn import Module, ModuleDict from torch.nn.functional import one_hot @@ -708,7 +708,7 @@ class Round(InputTransform, Module): def __init__( self, - integer_indices: Optional[List[int]] = None, + integer_indices: Union[List[int], LongTensor, None] = None, categorical_features: Optional[Dict[int, int]] = None, transform_on_train: bool = True, transform_on_eval: bool = True, @@ -747,9 +747,9 @@ def __init__( self.transform_on_train = transform_on_train self.transform_on_eval = transform_on_eval self.transform_on_fantasize = transform_on_fantasize - integer_indices = integer_indices or [] + integer_indices = integer_indices if integer_indices is not None else [] self.register_buffer( - "integer_indices", torch.tensor(integer_indices, dtype=torch.long) + "integer_indices", torch.as_tensor(integer_indices, dtype=torch.long) ) self.categorical_features = categorical_features or {} self.approximate = approximate diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py index 1893f3d982..98226dc442 100644 --- a/test/models/transforms/test_input.py +++ b/test/models/transforms/test_input.py @@ -595,6 +595,13 @@ def test_round_transform(self): self.assertFalse(round_tf.approximate) self.assertEqual(round_tf.tau, 1e-3) + # With tensor indices. + round_tf = Round( + integer_indices=torch.tensor(int_idcs, dtype=dtype, device=self.device), + categorical_features=categorical_feats, + ) + self.assertEqual(round_tf.integer_indices.tolist(), int_idcs) + # basic usage for batch_shape, approx, categorical_features in itertools.product( (torch.Size(), torch.Size([3])),