diff --git a/tensordict/_td.py b/tensordict/_td.py
index 77f806ee2..0ec6616eb 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -57,6 +57,7 @@
     _is_shared,
     _KEY_ERROR,
     _LOCK_ERROR,
+    _maybe_correct_neg_dim,
     _mismatch_keys,
     _NON_STR_KEY_ERR,
     _NON_STR_KEY_TUPLE_ERR,
@@ -878,8 +879,7 @@ def all(self, dim: int = None) -> bool | TensorDictBase:
                 "smaller than tensordict.batch_dims"
             )
         if dim is not None:
-            if dim < 0:
-                dim = self.batch_dims + dim
+            dim = _maybe_correct_neg_dim(dim, self.batch_size)
 
             names = None
             if self._has_names():
@@ -901,8 +901,7 @@ def any(self, dim: int = None) -> bool | TensorDictBase:
                 "smaller than tensordict.batch_dims"
             )
         if dim is not None:
-            if dim < 0:
-                dim = self.batch_dims + dim
+            dim = _maybe_correct_neg_dim(dim, self.batch_size)
 
             names = None
             if self._has_names():
@@ -980,14 +979,7 @@ def proc_dim(dim, batch_dims, tuple_ok=True):
                         for _d in proc_dim(d, batch_dims, tuple_ok=False)
                     )
                 return dim
-            if dim >= batch_dims or dim < -batch_dims:
-                raise RuntimeError(
-                    "dim must be greater than or equal to -tensordict.batch_dims and "
-                    "smaller than tensordict.batch_dims"
-                )
-            if dim < 0:
-                return (batch_dims + dim,)
-            return (dim,)
+            return (_maybe_correct_neg_dim(dim, None, batch_dims),)
 
         dim_needs_proc = (dim is not NO_DEFAULT) and (dim not in ("feature",))
         if dim_needs_proc:
@@ -1724,13 +1716,7 @@ def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBas
         WRONG_TYPE = "split(): argument 'split_size' must be int or list of ints"
         batch_size = self.batch_size
         batch_sizes = []
-        batch_dims = len(batch_size)
-        if dim < 0:
-            dim = len(batch_size) + dim
-        if dim >= batch_dims or dim < 0:
-            raise IndexError(
-                f"Dimension out of range (expected to be in range of [-{self.batch_dims}, {self.batch_dims - 1}], but got {dim})"
-            )
+        dim = _maybe_correct_neg_dim(dim, batch_size)
         max_size = batch_size[dim]
         if isinstance(split_size, int):
             idx0 = 0
@@ -2005,17 +1991,7 @@ def _squeeze(tensor):
                 propagate_lock=True,
             )
         # make the dim positive
-        if dim < 0:
-            newdim = self.batch_dims + dim
-        else:
-            newdim = dim
-
-        if (newdim >= self.batch_dims) or (newdim < 0):
-            raise RuntimeError(
-                f"squeezing is allowed for dims comprised between "
-                f"`-td.batch_dims` and `td.batch_dims - 1` only. Got "
-                f"dim={dim} with a batch size of {self.batch_size}."
-            )
+        newdim = _maybe_correct_neg_dim(dim, batch_size)
         if batch_size[dim] != 1:
             return self
         batch_size = list(batch_size)
diff --git a/tensordict/base.py b/tensordict/base.py
index 763876ca1..83ab269bc 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -65,6 +65,7 @@
     _KEY_ERROR,
     _lock_warn,
     _make_dtype_promotion,
+    _maybe_correct_neg_dim,
     _parse_to,
     _pass_through,
     _pass_through_cls,
@@ -3343,13 +3344,7 @@ def unbind(self, dim: int) -> tuple[T, ...]:
             tensor([4, 5, 6, 7])
 
         """
-        batch_dims = self.batch_dims
-        if dim < -batch_dims or dim >= batch_dims:
-            raise RuntimeError(
-                f"the dimension provided ({dim}) is beyond the tensordict dimensions ({self.ndim})."
-            )
-        if dim < 0:
-            dim = batch_dims + dim
+        dim = _maybe_correct_neg_dim(dim, self.batch_size)
         results = self._unbind(dim)
         if self._is_memmap or self._is_shared:
             for result in results:
@@ -7406,12 +7401,7 @@ def unflatten(self, dim, unflattened_size):
             >>> td_unflat = td_flat.unflatten(0, [3, 4])
             >>> assert (td == td_unflat).all()
         """
-        if dim < 0:
-            dim = self.ndim + dim
-            if dim < 0:
-                raise ValueError(
-                    f"Incompatible dim {dim} for tensordict with shape {self.shape}."
-                )
+        dim = _maybe_correct_neg_dim(dim, self.batch_size)
 
         def unflatten(tensor):
             return torch.unflatten(
@@ -8956,11 +8946,7 @@ def _map(
         iterable: bool,
     ):
         num_workers = pool._processes
-        dim_orig = dim
-        if dim < 0:
-            dim = self.ndim + dim
-        if dim < 0 or dim >= self.ndim:
-            raise ValueError(f"Got incompatible dimension {dim_orig}")
+        dim = _maybe_correct_neg_dim(dim, self.batch_size)
 
         self_split = _split_tensordict(
             self,
@@ -9588,18 +9574,11 @@ def softmax(self, dim: int, dtype: torch.dtype | None = None):  # noqa: D417
 
         """
         if isinstance(dim, int):
-            if dim < 0:
-                new_dim = self.ndim + dim
-            else:
-                new_dim = dim
+            dim = _maybe_correct_neg_dim(dim, self.batch_size)
         else:
             raise ValueError(f"Expected dim of type int, got {type(dim)}.")
-        if (new_dim < 0) or (new_dim >= self.ndim):
-            raise ValueError(
-                f"The dimension {dim} is incompatible with a tensordict with batch_size {self.batch_size}."
-            )
         return self._fast_apply(
-            lambda x: torch.softmax(x, dim=new_dim, dtype=dtype),
+            lambda x: torch.softmax(x, dim=dim, dtype=dtype),
         )
 
     def log10(self) -> T:
diff --git a/tensordict/utils.py b/tensordict/utils.py
index dc2f0a769..c2fd31ff4 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -2923,3 +2923,24 @@ def set_mode(self, type: Any | None) -> None:
         cm = self._lock if not is_compiling() else nullcontext()
         with cm:
             self._mode = type
+
+
+def _maybe_correct_neg_dim(
+    dim: int, shape: torch.Size | None, ndim: int | None = None
+) -> int:
+    """Converts a negative dim to its positive equivalent and validates the range."""
+    if ndim is None:
+        ndim = len(shape)
+    if dim < 0:
+        new_dim = ndim + dim
+    else:
+        new_dim = dim
+    if new_dim < 0 or new_dim >= ndim:
+        if shape is not None:
+            raise IndexError(
+                f"Incompatible dim {new_dim} for tensordict with shape {shape}."
+            )
+        raise IndexError(
+            f"Incompatible dim {new_dim} for tensordict with batch dims {ndim}."
+        )
+    return new_dim
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 838ad35d0..cac9cab76 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -2566,7 +2566,7 @@ def test_split_with_empty_tensordict(self):
     def test_split_with_invalid_arguments(self):
         td = TensorDict({"a": torch.zeros(2, 1)}, [])
         # Test empty batch size
-        with pytest.raises(IndexError, match="Dimension out of range"):
+        with pytest.raises(IndexError, match="Incompatible dim"):
             td.split(1, 0)
 
         td = TensorDict({}, [3, 2])
@@ -2587,9 +2587,9 @@ def test_split_with_invalid_arguments(self):
             td.split([1, 1], 0)
 
         # Test invalid dimension input
-        with pytest.raises(IndexError, match="Dimension out of range"):
+        with pytest.raises(IndexError, match="Incompatible dim"):
             td.split(1, 2)
-        with pytest.raises(IndexError, match="Dimension out of range"):
+        with pytest.raises(IndexError, match="Incompatible dim"):
            td.split(1, -3)
 
     def test_split_with_negative_dim(self):
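
For reviewers, a minimal sketch of the semantics the new helper centralizes, assuming the patch above is applied (`_maybe_correct_neg_dim` is a private helper in `tensordict/utils.py`, so code outside the library should not rely on it):

```python
import torch

from tensordict.utils import _maybe_correct_neg_dim

# In-range dims pass through; negative dims are shifted by the number of batch dims.
assert _maybe_correct_neg_dim(0, torch.Size([3, 4])) == 0
assert _maybe_correct_neg_dim(-1, torch.Size([3, 4])) == 1

# When no shape is at hand (e.g. the `proc_dim` call site), ndim can be passed instead.
assert _maybe_correct_neg_dim(-2, None, ndim=3) == 1

# Out-of-range dims raise IndexError, which the updated tests now match on.
try:
    _maybe_correct_neg_dim(2, torch.Size([3, 4]))
except IndexError as err:
    print(err)  # Incompatible dim 2 for tensordict with shape torch.Size([3, 4])
```

Note that the consolidated helper raises `IndexError` everywhere, whereas the replaced call sites variously raised `RuntimeError`, `ValueError`, or `IndexError`; the test updates in `test_split_with_invalid_arguments` reflect the new unified message.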