Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow tensor input for integer_indices in Round transform #1709

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from gpytorch import Module as GPyTorchModule
from gpytorch.constraints import GreaterThan
from gpytorch.priors import Prior
from torch import nn, Tensor
from torch import LongTensor, nn, Tensor
from torch.distributions import Kumaraswamy
from torch.nn import Module, ModuleDict
from torch.nn.functional import one_hot
Expand Down Expand Up @@ -708,7 +708,7 @@ class Round(InputTransform, Module):

def __init__(
self,
integer_indices: Optional[List[int]] = None,
integer_indices: Union[List[int], LongTensor, None] = None,
categorical_features: Optional[Dict[int, int]] = None,
transform_on_train: bool = True,
transform_on_eval: bool = True,
Expand Down Expand Up @@ -747,9 +747,9 @@ def __init__(
self.transform_on_train = transform_on_train
self.transform_on_eval = transform_on_eval
self.transform_on_fantasize = transform_on_fantasize
integer_indices = integer_indices or []
integer_indices = integer_indices if integer_indices is not None else []
self.register_buffer(
"integer_indices", torch.tensor(integer_indices, dtype=torch.long)
"integer_indices", torch.as_tensor(integer_indices, dtype=torch.long)
)
self.categorical_features = categorical_features or {}
self.approximate = approximate
Expand Down
7 changes: 7 additions & 0 deletions test/models/transforms/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,13 @@ def test_round_transform(self):
self.assertFalse(round_tf.approximate)
self.assertEqual(round_tf.tau, 1e-3)

# With tensor indices.
round_tf = Round(
integer_indices=torch.tensor(int_idcs, dtype=dtype, device=self.device),
categorical_features=categorical_feats,
)
self.assertEqual(round_tf.integer_indices.tolist(), int_idcs)

# basic usage
for batch_shape, approx, categorical_features in itertools.product(
(torch.Size(), torch.Size([3])),
Expand Down