Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 26, 2025
1 parent bf9f4b4 commit d7b479c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 60 deletions.
36 changes: 6 additions & 30 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
_is_shared,
_KEY_ERROR,
_LOCK_ERROR,
_maybe_correct_neg_dim,
_mismatch_keys,
_NON_STR_KEY_ERR,
_NON_STR_KEY_TUPLE_ERR,
Expand Down Expand Up @@ -878,8 +879,7 @@ def all(self, dim: int = None) -> bool | TensorDictBase:
"smaller than tensordict.batch_dims"
)
if dim is not None:
if dim < 0:
dim = self.batch_dims + dim
dim = _maybe_correct_neg_dim(dim, self.batch_size)

names = None
if self._has_names():
Expand All @@ -901,8 +901,7 @@ def any(self, dim: int = None) -> bool | TensorDictBase:
"smaller than tensordict.batch_dims"
)
if dim is not None:
if dim < 0:
dim = self.batch_dims + dim
dim = _maybe_correct_neg_dim(dim, self.batch_size)

names = None
if self._has_names():
Expand Down Expand Up @@ -980,14 +979,7 @@ def proc_dim(dim, batch_dims, tuple_ok=True):
for _d in proc_dim(d, batch_dims, tuple_ok=False)
)
return dim
if dim >= batch_dims or dim < -batch_dims:
raise RuntimeError(
"dim must be greater than or equal to -tensordict.batch_dims and "
"smaller than tensordict.batch_dims"
)
if dim < 0:
return (batch_dims + dim,)
return (dim,)
return (_maybe_correct_neg_dim(dim, None, batch_dims),)

dim_needs_proc = (dim is not NO_DEFAULT) and (dim not in ("feature",))
if dim_needs_proc:
Expand Down Expand Up @@ -1724,13 +1716,7 @@ def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBas
WRONG_TYPE = "split(): argument 'split_size' must be int or list of ints"
batch_size = self.batch_size
batch_sizes = []
batch_dims = len(batch_size)
if dim < 0:
dim = len(batch_size) + dim
if dim >= batch_dims or dim < 0:
raise IndexError(
f"Dimension out of range (expected to be in range of [-{self.batch_dims}, {self.batch_dims - 1}], but got {dim})"
)
dim = _maybe_correct_neg_dim(dim, batch_size)
max_size = batch_size[dim]
if isinstance(split_size, int):
idx0 = 0
Expand Down Expand Up @@ -2005,17 +1991,7 @@ def _squeeze(tensor):
propagate_lock=True,
)
# make the dim positive
if dim < 0:
newdim = self.batch_dims + dim
else:
newdim = dim

if (newdim >= self.batch_dims) or (newdim < 0):
raise RuntimeError(
f"squeezing is allowed for dims comprised between "
f"`-td.batch_dims` and `td.batch_dims - 1` only. Got "
f"dim={dim} with a batch size of {self.batch_size}."
)
newdim = _maybe_correct_neg_dim(dim, batch_size)
if batch_size[dim] != 1:
return self
batch_size = list(batch_size)
Expand Down
33 changes: 6 additions & 27 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
_KEY_ERROR,
_lock_warn,
_make_dtype_promotion,
_maybe_correct_neg_dim,
_parse_to,
_pass_through,
_pass_through_cls,
Expand Down Expand Up @@ -3343,13 +3344,7 @@ def unbind(self, dim: int) -> tuple[T, ...]:
tensor([4, 5, 6, 7])

"""
batch_dims = self.batch_dims
if dim < -batch_dims or dim >= batch_dims:
raise RuntimeError(
f"the dimension provided ({dim}) is beyond the tensordict dimensions ({self.ndim})."
)
if dim < 0:
dim = batch_dims + dim
dim = _maybe_correct_neg_dim(dim, self.batch_size)
results = self._unbind(dim)
if self._is_memmap or self._is_shared:
for result in results:
Expand Down Expand Up @@ -7406,12 +7401,7 @@ def unflatten(self, dim, unflattened_size):
>>> td_unflat = td_flat.unflatten(0, [3, 4])
>>> assert (td == td_unflat).all()
"""
if dim < 0:
dim = self.ndim + dim
if dim < 0:
raise ValueError(
f"Incompatible dim {dim} for tensordict with shape {self.shape}."
)
dim = _maybe_correct_neg_dim(dim, self.batch_size)

def unflatten(tensor):
return torch.unflatten(
Expand Down Expand Up @@ -8956,11 +8946,7 @@ def _map(
iterable: bool,
):
num_workers = pool._processes
dim_orig = dim
if dim < 0:
dim = self.ndim + dim
if dim < 0 or dim >= self.ndim:
raise ValueError(f"Got incompatible dimension {dim_orig}")
dim = _maybe_correct_neg_dim(dim, self.batch_size)

self_split = _split_tensordict(
self,
Expand Down Expand Up @@ -9588,18 +9574,11 @@ def softmax(self, dim: int, dtype: torch.dtype | None = None): # noqa: D417

"""
if isinstance(dim, int):
if dim < 0:
new_dim = self.ndim + dim
else:
new_dim = dim
dim = _maybe_correct_neg_dim(dim, self.batch_size)
else:
raise ValueError(f"Expected dim of type int, got {type(dim)}.")
if (new_dim < 0) or (new_dim >= self.ndim):
raise ValueError(
f"The dimension {dim} is incompatible with a tensordict with batch_size {self.batch_size}."
)
return self._fast_apply(
lambda x: torch.softmax(x, dim=new_dim, dtype=dtype),
lambda x: torch.softmax(x, dim=dim, dtype=dtype),
)

def log10(self) -> T:
Expand Down
21 changes: 21 additions & 0 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2923,3 +2923,24 @@ def set_mode(self, type: Any | None) -> None:
cm = self._lock if not is_compiling() else nullcontext()
with cm:
self._mode = type


def _maybe_correct_neg_dim(
dim: int, shape: torch.Size | None, ndim: int | None = None
) -> int:
"""Corrects neg dim to pos."""
if ndim is None:
ndim = len(shape)
if dim < 0:
new_dim = ndim + dim
else:
new_dim = dim
if new_dim < 0 or new_dim >= ndim:
if shape is not None:
raise IndexError(
f"Incompatible dim {new_dim} for tensordict with shape {shape}."
)
raise IndexError(
f"Incompatible dim {new_dim} for tensordict with batch dims {ndim}."
)
return new_dim
6 changes: 3 additions & 3 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2566,7 +2566,7 @@ def test_split_with_empty_tensordict(self):
def test_split_with_invalid_arguments(self):
td = TensorDict({"a": torch.zeros(2, 1)}, [])
# Test empty batch size
with pytest.raises(IndexError, match="Dimension out of range"):
with pytest.raises(IndexError, match="Incompatible dim"):
td.split(1, 0)

td = TensorDict({}, [3, 2])
Expand All @@ -2587,9 +2587,9 @@ def test_split_with_invalid_arguments(self):
td.split([1, 1], 0)

# Test invalid dimension input
with pytest.raises(IndexError, match="Dimension out of range"):
with pytest.raises(IndexError, match="Incompatible dim"):
td.split(1, 2)
with pytest.raises(IndexError, match="Dimension out of range"):
with pytest.raises(IndexError, match="Incompatible dim"):
td.split(1, -3)

def test_split_with_negative_dim(self):
Expand Down

0 comments on commit d7b479c

Please sign in to comment.