Fix warning in test_normalize (#1876)
Summary:
Return the proper type of `indices` in `get_init_args()` for `Normalize`. Previously, returning the wrong type caused the following warning:

```
botorch/models/transforms/input.py:362: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
    indices = torch.tensor(indices, dtype=torch.long)
```
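
For context (not part of this commit), a minimal sketch of why switching from `torch.tensor` to `torch.as_tensor` silences this warning; the `idx` variable below is purely illustrative:

```python
import torch

idx = torch.tensor([0, 2], dtype=torch.long)

# torch.tensor() always copies its input; when that input is already a tensor,
# PyTorch emits the UserWarning quoted above:
#     indices = torch.tensor(idx, dtype=torch.long)  # -> UserWarning

# torch.as_tensor() reuses the existing tensor when dtype and device already
# match, and behaves like torch.tensor() for plain Python lists:
indices = torch.as_tensor(idx, dtype=torch.long)  # no copy, no warning
indices_from_list = torch.as_tensor([0, 2], dtype=torch.long)
```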

Pull Request resolved: #1876

Reviewed By: SebastianAment

Differential Revision: D46547569

Pulled By: Balandat

fbshipit-source-id: 6f7f3e15d5d80851e68e9e3b60575b807c8c24a4
Balandat authored and facebook-github-bot committed Jun 8, 2023
1 parent 6d329a8 commit f00057b
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions botorch/models/transforms/input.py
```diff
--- a/botorch/models/transforms/input.py
+++ b/botorch/models/transforms/input.py
@@ -324,7 +324,7 @@ def __init__(
         d: int,
         coefficient: Tensor,
         offset: Tensor,
-        indices: Optional[List[int]] = None,
+        indices: Optional[Union[List[int], Tensor]] = None,
         batch_shape: torch.Size = torch.Size(),  # noqa: B008
         transform_on_train: bool = True,
         transform_on_eval: bool = True,
@@ -342,7 +342,8 @@ def __init__(
             offset: Tensor of offset coefficients, shape must to be
                 broadcastable with `(batch_shape x n x d)`-dim input tensors.
             indices: The indices of the inputs to transform. If omitted,
-                take all dimensions of the inputs into account.
+                take all dimensions of the inputs into account. Either a list of ints
+                or a Tensor of type `torch.long`.
             batch_shape: The batch shape of the inputs (assuming input tensors
                 of shape `batch_shape x n x d`). If provided, perform individual
                 transformation per batch, otherwise uses a single transformation.
@@ -359,7 +360,9 @@ def __init__(
         if (indices is not None) and (len(indices) == 0):
            raise ValueError("`indices` list is empty!")
         if (indices is not None) and (len(indices) > 0):
-            indices = torch.tensor(indices, dtype=torch.long)
+            indices = torch.as_tensor(
+                indices, dtype=torch.long, device=coefficient.device
+            )
             if len(indices) > d:
                 raise ValueError("Can provide at most `d` indices!")
             if (indices > d - 1).any():
@@ -498,7 +501,7 @@ class Normalize(AffineInputTransform):
     def __init__(
         self,
         d: int,
-        indices: Optional[List[int]] = None,
+        indices: Optional[Union[List[int], Tensor]] = None,
         bounds: Optional[Tensor] = None,
         batch_shape: torch.Size = torch.Size(),  # noqa: B008
         transform_on_train: bool = True,
@@ -626,7 +629,7 @@ class InputStandardize(AffineInputTransform):
     def __init__(
         self,
         d: int,
-        indices: Optional[List[int]] = None,
+        indices: Optional[Union[List[int], Tensor]] = None,
         batch_shape: torch.Size = torch.Size(),  # noqa: B008
         transform_on_train: bool = True,
         transform_on_eval: bool = True,
```
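
With this change, `Normalize` (and the other affine input transforms) accept `indices` either as a list of ints or as a long tensor. A minimal usage sketch, assuming the public `Normalize` API shown in the diff; shapes and values are illustrative:

```python
import torch
from botorch.models.transforms.input import Normalize

# `indices` can now be passed as a long tensor instead of a list of ints
# without triggering the copy-construct warning in AffineInputTransform.__init__.
tf = Normalize(d=3, indices=torch.tensor([0, 2], dtype=torch.long))

X = torch.rand(5, 3)
X_normalized = tf(X)  # only columns 0 and 2 are min-max normalized
```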
