diff --git a/tensordict/_td.py b/tensordict/_td.py index 331ed4234..c3b5d85dd 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -829,12 +829,14 @@ def _index_tensordict( raise RuntimeError( f"indexing a tensordict with td.batch_dims==0 is not permitted. Got index {index}." ) - if names is None: - names = self._get_names_idx(index) if new_batch_size is not None: batch_size = new_batch_size else: batch_size = _getitem_batch_size(batch_size, index) + + if names is None: + names = self._get_names_idx(index) + source = {} for key, item in self.items(): if isinstance(item, TensorDict): @@ -1371,14 +1373,20 @@ def is_boolean(idx): # this will convert a [None, :, :, 0, None, 0] in [None, 0, 1, None, 3] count = 0 idx_to_take = [] + no_more_tensors = False for _idx in idx_names: if _idx is None: idx_to_take.append(None) elif _is_number(_idx): count += 1 elif isinstance(_idx, (torch.Tensor, np.ndarray)): - idx_to_take.extend([count] * _idx.ndim) - count += 1 + if not no_more_tensors: + idx_to_take.extend([count] * _idx.ndim) + count += 1 + no_more_tensors = True + else: + # skip this one + count += 1 else: idx_to_take.append(count) count += 1 @@ -1392,6 +1400,7 @@ def names(self, value): self._rename_subtds(value) self._erase_names() return + value = list(value) num_none = sum(v is None for v in value) if num_none: num_none -= 1 diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 3584b1d33..b960ea843 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -96,8 +96,10 @@ def _gather_tensor(tensor, dest=None): return out if out is None: - names = input.names if input._has_names() else None - + if len(index.shape) == input.ndim and input._has_names(): + names = input.names + else: + names = None return TensorDict( { key: _gather_tensor(value)