Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 24, 2024
1 parent 003be12 commit 8787abc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
17 changes: 13 additions & 4 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,12 +829,14 @@ def _index_tensordict(
raise RuntimeError(
f"indexing a tensordict with td.batch_dims==0 is not permitted. Got index {index}."
)
if names is None:
names = self._get_names_idx(index)
if new_batch_size is not None:
batch_size = new_batch_size
else:
batch_size = _getitem_batch_size(batch_size, index)

if names is None:
names = self._get_names_idx(index)

source = {}
for key, item in self.items():
if isinstance(item, TensorDict):
Expand Down Expand Up @@ -1371,14 +1373,20 @@ def is_boolean(idx):
# this will convert a [None, :, :, 0, None, 0] in [None, 0, 1, None, 3]
count = 0
idx_to_take = []
no_more_tensors = False
for _idx in idx_names:
if _idx is None:
idx_to_take.append(None)
elif _is_number(_idx):
count += 1
elif isinstance(_idx, (torch.Tensor, np.ndarray)):
idx_to_take.extend([count] * _idx.ndim)
count += 1
if not no_more_tensors:
idx_to_take.extend([count] * _idx.ndim)
count += 1
no_more_tensors = True
else:
# skip this one
count += 1
else:
idx_to_take.append(count)
count += 1
Expand All @@ -1392,6 +1400,7 @@ def names(self, value):
self._rename_subtds(value)
self._erase_names()
return
value = list(value)
num_none = sum(v is None for v in value)
if num_none:
num_none -= 1
Expand Down
6 changes: 4 additions & 2 deletions tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ def _gather_tensor(tensor, dest=None):
return out

if out is None:
names = input.names if input._has_names() else None

if len(index.shape) == input.ndim and input._has_names():
names = input.names
else:
names = None
return TensorDict(
{
key: _gather_tensor(value)
Expand Down

0 comments on commit 8787abc

Please sign in to comment.