[Performance] Faster split, chunk and unbind #563

Merged · 3 commits · Nov 21, 2023
111 changes: 86 additions & 25 deletions tensordict/_td.py
@@ -15,7 +15,7 @@
from numbers import Number
from pathlib import Path
from textwrap import indent
from typing import Any, Callable, Iterable, Iterator, Sequence
from typing import Any, Callable, Iterable, Iterator, List, Sequence
from warnings import warn

import numpy as np
@@ -585,8 +585,14 @@ def _convert_to_tensordict(self, dict_value: dict[str, Any]) -> T:
_is_memmap=self._is_memmap,
)

def _index_tensordict(self, index: IndexType) -> T:
def _index_tensordict(
self,
index: IndexType,
new_batch_size: torch.Size | None = None,
names: List[str] | None = None,
) -> T:
batch_size = self.batch_size
batch_dims = len(batch_size)
if (
not batch_size
and index is not None
@@ -595,10 +601,24 @@ def _index_tensordict(self, index: IndexType) -> T:
raise RuntimeError(
f"indexing a tensordict with td.batch_dims==0 is not permitted. Got index {index}."
)
names = self._get_names_idx(index)
batch_size = _getitem_batch_size(batch_size, index)
if names is None:
names = self._get_names_idx(index)
if new_batch_size is not None:
batch_size = new_batch_size
else:
batch_size = _getitem_batch_size(batch_size, index)
source = {}
for key, item in self.items():
if isinstance(item, TensorDict):
# this is the simplest case, we can pre-compute the batch size easily
new_batch_size = batch_size + item.batch_size[batch_dims:]
source[key] = item._index_tensordict(
index, new_batch_size=new_batch_size
)
else:
source[key] = _get_item(item, index)
return TensorDict(
source={key: _get_item(item, index) for key, item in self.items()},
source=source,
batch_size=batch_size,
device=self.device,
names=names,
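The core of this change: `_index_tensordict` can now receive a precomputed `new_batch_size` and `names`, so callers that apply many indices of the same shape (`split`, `chunk`, `unbind`) resolve the resulting batch size and dimension names once instead of once per sub-tensordict, and nested tensordicts reuse the parent's result by appending their own trailing dims. A minimal standalone sketch of that idea, not the tensordict API itself (`index_nested` and its arguments are hypothetical):

```python
import torch

def index_nested(data: dict, index, batch_size: torch.Size, new_batch_size=None):
    """Sketch: index a nested dict of tensors, trusting a caller-supplied
    batch size instead of recomputing it from `index` on every call."""
    if new_batch_size is None:
        # slow path: derive the resulting batch size from the index itself
        new_batch_size = torch.zeros(batch_size)[index].shape
    out = {}
    for key, item in data.items():
        if isinstance(item, dict):
            # nested case: reuse the batch size computed by the caller
            out[key] = index_nested(item, index, batch_size, new_batch_size)
        else:
            out[key] = item[index]
    return out
```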
@@ -650,54 +670,90 @@ def unbind(self, dim: int) -> tuple[T, ...]:
names = copy(self.names)
names = [name for i, name in enumerate(names) if i != dim]
out = []
unbind_self_dict = {key: tensor.unbind(dim) for key, tensor in self.items()}
# unbind_self_dict = {key: tensor.unbind(dim) for key, tensor in self.items()}
prefix = (slice(None),) * dim
for _idx in range(self.batch_size[dim]):
td = TensorDict(
{key: tensor[_idx] for key, tensor in unbind_self_dict.items()},
batch_size=batch_size,
_run_checks=False,
device=self.device,
_is_memmap=False,
_is_shared=False,
names=names,
)
_idx = prefix + (_idx,)
td = self._index_tensordict(_idx, new_batch_size=batch_size, names=names)
# td = TensorDict(
# {key: tensor[_idx] for key, tensor in unbind_self_dict.items()},
# batch_size=batch_size,
# _run_checks=False,
# device=self.device,
# _is_memmap=False,
# _is_shared=False,
# names=names,
# )
out.append(td)
if self.is_shared():
out[-1].share_memory_()
td._is_shared = True
elif self.is_memmap():
out[-1].memmap_()
td._is_memmap = True
return tuple(out)
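Rather than materializing `tensor.unbind(dim)` for every leaf up front, `unbind` now builds one basic index per slice, `(slice(None),) * dim + (i,)`, routes it through `_index_tensordict` with the already-known batch size and names, and simply flips the `_is_shared` / `_is_memmap` flags instead of re-sharing or re-mapping each result. The same indexing idiom on a plain tensor, as a quick sanity check (illustrative only):

```python
import torch

def unbind_via_indexing(t: torch.Tensor, dim: int):
    # Build one basic index per position along `dim`; like `t.unbind(dim)`,
    # each result is a view on the original storage.
    prefix = (slice(None),) * dim
    return tuple(t[prefix + (i,)] for i in range(t.shape[dim]))

x = torch.arange(24).reshape(2, 3, 4)
assert all(torch.equal(a, b) for a, b in zip(unbind_via_indexing(x, 1), x.unbind(1)))
```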

def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
# we must use slices to keep the storage of the tensors
WRONG_TYPE = "split(): argument 'split_size' must be int or list of ints"
batch_size = self.batch_size
batch_sizes = []
batch_dims = len(batch_size)
if dim < 0:
dim = len(batch_size) + dim
if dim >= batch_dims or dim < 0:
raise IndexError(
f"Dimension out of range (expected to be in range of [-{self.batch_dims}, {self.batch_dims - 1}], but got {dim})"
)
max_size = batch_size[dim]
if isinstance(split_size, int):
idx0 = 0
idx1 = split_size
idx1 = min(max_size, split_size)
split_sizes = [slice(idx0, idx1)]
while idx1 < batch_size[dim]:
batch_sizes.append(
torch.Size(
tuple(
d if i != dim else idx1 - idx0 for i, d in enumerate(batch_size)
)
)
)
while idx1 < max_size:
idx0 = idx1
idx1 += split_size
idx1 = min(max_size, idx1 + split_size)
split_sizes.append(slice(idx0, idx1))
batch_sizes.append(
torch.Size(
tuple(
d if i != dim else idx1 - idx0
for i, d in enumerate(batch_size)
)
)
)
elif isinstance(split_size, (list, tuple)):
if len(split_size) == 0:
raise RuntimeError("Insufficient number of elements in split_size.")
try:
idx0 = 0
idx1 = split_size[0]
split_sizes = [slice(idx0, idx1)]
batch_sizes.append(
torch.Size(
tuple(
d if i != dim else idx1 - idx0
for i, d in enumerate(batch_size)
)
)
)
for idx in split_size[1:]:
idx0 = idx1
idx1 += idx
idx1 = min(max_size, idx1 + idx)
split_sizes.append(slice(idx0, idx1))
batch_sizes.append(
torch.Size(
tuple(
d if i != dim else idx1 - idx0
for i, d in enumerate(batch_size)
)
)
)
except TypeError:
raise TypeError(WRONG_TYPE)

@@ -708,7 +764,11 @@ def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
else:
raise TypeError(WRONG_TYPE)
index = (slice(None),) * dim
return tuple(self[index + (ss,)] for ss in split_sizes)
names = self.names
return tuple(
self._index_tensordict(index + (ss,), new_batch_size=bs, names=names)
for ss, bs in zip(split_sizes, batch_sizes)
)
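With the slice boundaries and the matching batch sizes computed in one pass, each split is produced by a single slice-based call to `_index_tensordict`; the end point is clamped with `min(max_size, ...)`, and slicing keeps every leaf a view on the original storage rather than a copy. The boundary computation, checked against `torch.split` on a plain tensor (a minimal sketch, not the tensordict code):

```python
import torch

def split_slices(length: int, split_size: int):
    # Contiguous slices of width `split_size` over a dimension of size `length`;
    # the last slice is clamped, so it may be shorter than the others.
    idx0, idx1 = 0, min(length, split_size)
    slices = [slice(idx0, idx1)]
    while idx1 < length:
        idx0, idx1 = idx1, min(length, idx1 + split_size)
        slices.append(slice(idx0, idx1))
    return slices

x = torch.arange(10)
pieces = tuple(x[s] for s in split_slices(10, 4))  # sizes 4, 4, 2
assert all(torch.equal(a, b) for a, b in zip(pieces, x.split(4)))
```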

def memmap_like(self, prefix: str | None = None) -> T:
def save_metadata(data: TensorDictBase, filepath, metadata=None):
@@ -2301,10 +2361,11 @@ def _create_nested_str(self, key):
# return self.to_tensordict()._apply_nest(*args, **kwargs)
_convert_to_tensordict = TensorDict._convert_to_tensordict

def _get_names_idx(self, *args, **kwargs):
raise NotImplementedError
_get_names_idx = TensorDict._get_names_idx

def _index_tensordict(self, index):
def _index_tensordict(self, index, new_batch_size=None, names=None):
# we ignore the names and new_batch_size which are only provided for
# efficiency purposes
return self._get_sub_tensordict(index)

def _remove_batch_dim(self, *args, **kwargs):
24 changes: 10 additions & 14 deletions tensordict/base.py
@@ -17,6 +17,7 @@
Callable,
Generator,
Iterator,
List,
Optional,
OrderedDict,
overload,
@@ -526,19 +527,9 @@ def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
raise ValueError(
f"chunks must be a strictly positive integer, got {chunks}."
)
indices = []
_idx_start = 0
if chunks > 1:
interval = _idx_end = self.batch_size[dim] // chunks
else:
interval = _idx_end = self.batch_size[dim]
for c in range(chunks):
indices.append(slice(_idx_start, _idx_end))
_idx_start = _idx_end
_idx_end = _idx_end + interval if c < chunks - 2 else self.batch_size[dim]
if dim < 0:
dim = len(self.batch_size) + dim
return tuple(self[(*[slice(None) for _ in range(dim)], idx)] for idx in indices)
# fall back on split, using upper rounding
split_size = -(self.batch_size[dim] // -chunks)
return self.split(split_size, dim=dim)
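`chunk` now simply delegates to `split` with a per-chunk size obtained by ceiling division, written as `-(n // -chunks)`; rounding up means the last chunk may be smaller, which matches `torch.chunk`. A quick worked example of the rounding trick:

```python
# Ceiling division via floor division on negated operands (positive ints).
n, chunks = 10, 3
split_size = -(n // -chunks)   # -(10 // -3) = -(-4) = 4
# split(4) along a dim of size 10 yields pieces of size 4, 4, 2,
# matching torch.chunk along that dim with chunks=3.
```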

@overload
def unsqueeze(self, dim: int) -> T:
@@ -3625,7 +3616,12 @@ def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
return self

@abc.abstractmethod
def _index_tensordict(self, index: IndexType) -> T:
def _index_tensordict(
self,
index: IndexType,
new_batch_size: torch.Size | None = None,
names: List[str] | None = None,
) -> T:
...

# Locking functionality
6 changes: 1 addition & 5 deletions test/test_tensordict.py
@@ -2750,11 +2750,7 @@ def test_split(self, td_name, device, performer, dim):

for idx, split_td in enumerate(tds):
expected_split_dim_size = 1 if idx == rep else 2
expected_batch_size = [
expected_split_dim_size if dim_idx == dim else dim_size
for (dim_idx, dim_size) in enumerate(td.batch_size)
]

expected_batch_size = tensorsplit[idx].shape
# Test each split_td has the expected batch_size
assert split_td.batch_size == torch.Size(expected_batch_size)

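The expected batch size of each split is now read off a reference split instead of being rebuilt by hand; assuming `tensorsplit` holds the result of splitting a plain tensor of shape `td.batch_size` with the same arguments (the names below are illustrative, not the test's actual setup):

```python
import torch

ref = torch.zeros(5, 4)                # stand-in for a tensor with td.batch_size
tensorsplit = ref.split(2, dim=0)      # same split arguments as the tensordict
expected_shapes = [t.shape for t in tensorsplit]  # [(2, 4), (2, 4), (1, 4)]
```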