[Feature] LazyStackStorage
ghstack-source-id: e9c031470aa0bdafbb2b26c73c06b25685a128e5
Pull Request resolved: #2723
vmoens committed Jan 30, 2025
1 parent 280297a commit fe3f00c
Showing 6 changed files with 113 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -148,6 +148,7 @@ using the following components:
LazyMemmapStorage
LazyTensorStorage
ListStorage
LazyStackStorage
ListStorageCheckpointer
NestedStorageCheckpointer
PrioritizedSampler
26 changes: 26 additions & 0 deletions test/test_rb.py
@@ -81,6 +81,7 @@

from torchrl.data.replay_buffers.storages import (
LazyMemmapStorage,
LazyStackStorage,
LazyTensorStorage,
ListStorage,
StorageEnsemble,
@@ -1116,6 +1117,31 @@ def test_storage_inplace_writing_ndim(self, storage_type):
assert (rb[:, 10:20] == 0).all()
assert len(rb) == 100

@pytest.mark.parametrize("max_size", [1000, None])
@pytest.mark.parametrize("stack_dim", [-1, 0])
def test_lazy_stack_storage(self, max_size, stack_dim):
# Create an instance of LazyStackStorage with given parameters
storage = LazyStackStorage(max_size=max_size, stack_dim=stack_dim)
# Create a ReplayBuffer using the created storage
rb = ReplayBuffer(storage=storage)
# Generate some random data to add to the buffer
torch.manual_seed(0)
data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!")
data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!")
# Add the data to the buffer
rb.add(data0)
rb.add(data1)
# Sample from the buffer
sample = rb.sample(10)
# Check that the sampled data has the correct shape and type
assert isinstance(sample, LazyStackedTensorDict)
assert sample["b"].shape[0] == 10
assert all(isinstance(item, str) for item in sample["c"])
# Densify the sample: heterogeneously shaped entries are stacked into a nested (jagged) tensor
sample = sample.densify(layout=torch.jagged)
assert isinstance(sample["a"], torch.Tensor)
assert sample["a"].shape[0] == 10


@pytest.mark.parametrize("max_size", [1000])
@pytest.mark.parametrize("shape", [[3, 4]])
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
@@ -23,6 +23,7 @@
H5StorageCheckpointer,
ImmutableDatasetWriter,
LazyMemmapStorage,
LazyStackStorage,
LazyTensorStorage,
ListStorage,
ListStorageCheckpointer,
1 change: 1 addition & 0 deletions torchrl/data/replay_buffers/__init__.py
@@ -32,6 +32,7 @@
)
from .storages import (
LazyMemmapStorage,
LazyStackStorage,
LazyTensorStorage,
ListStorage,
Storage,
6 changes: 5 additions & 1 deletion torchrl/data/replay_buffers/samplers.py
@@ -1049,6 +1049,10 @@ def __init__(
self._cache["stop-and-length"] = vals

else:
if traj_key is not None:
self._fetch_traj = True
elif end_key is not None:
self._fetch_traj = False
if end_key is None:
end_key = ("next", "done")
if traj_key is None:
@@ -1331,7 +1335,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
if start_idx.shape[1] != storage.ndim:
raise RuntimeError(
f"Expected the end-of-trajectory signal to be "
f"{storage.ndim}-dimensional. Got a {start_idx.shape[1]} tensor "
f"{storage.ndim}-dimensional. Got a tensor with shape[1]={start_idx.shape[1]} "
"instead."
)
seq_length, num_slices = self._adjusted_batch_size(batch_size)
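Note on the samplers.py change above: with this fix, SliceSampler decides how to recover trajectories from whichever of traj_key or end_key the caller actually passed, instead of silently falling back to the defaults. Below is a minimal sketch of the intended usage, assuming the public SliceSampler/TensorDictReplayBuffer API from torchrl.data; the episode layout and key names are illustrative and not part of this commit.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, SliceSampler, TensorDictReplayBuffer

# Two flattened 5-step episodes, each terminated by a done flag on its last step.
done = torch.zeros(10, 1, dtype=torch.bool)
done[4] = done[9] = True
data = TensorDict(
    {"obs": torch.arange(10, dtype=torch.float32).unsqueeze(-1), ("next", "done"): done},
    batch_size=[10],
)

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(10),
    # Only end_key is given, so the sampler should rely on the done signal
    # (with this commit, _fetch_traj is set to False in that case).
    sampler=SliceSampler(num_slices=2, end_key=("next", "done")),
    batch_size=6,
)
rb.extend(data)
sample = rb.sample()  # expected: 2 slices of 3 consecutive steps each

Passing traj_key instead (and no end_key) should flip the sampler to trajectory-id based slicing, which is exactly the branch the added lines select.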
79 changes: 79 additions & 0 deletions torchrl/data/replay_buffers/storages.py
@@ -297,7 +297,15 @@ def set(
def get(self, index: Union[int, Sequence[int], slice]) -> Any:
if isinstance(index, (INT_CLASSES, slice)):
return self._storage[index]
elif isinstance(index, tuple):
if len(index) > 1:
raise RuntimeError(
f"{type(self).__name__} can only be indexed with one-length tuples."
)
return self.get(index[0])
else:
if isinstance(index, torch.Tensor) and index.device.type != "cpu":
index = index.cpu().tolist()
return [self._storage[i] for i in index]

def __len__(self):
@@ -353,6 +361,77 @@ def contains(self, item):
raise NotImplementedError(f"type {type(item)} is not supported yet.")


class LazyStackStorage(ListStorage):
"""A ListStorage that returns LazyStackTensorDict instances.
This storage allows for heterougeneous structures to be indexed as a single `TensorDict` representation.
It uses :class:`~tensordict.LazyStackedTensorDict` which operates on non-contiguous lists of tensordicts,
lazily stacking items when queried.
This means that this storage is going to be fast to sample but data access may be slow (as it requires a stack).
Tensors of heterogeneous shapes can also be stored within the storage and stacked together.
Because the storage is represented as a list, the number of tensors to store in memory will grow linearly with
the size of the buffer.
If possible, nested tensors can also be created via :meth:`~tensordict.LazyStackedTensorDict.densify`
(see :mod:`~torch.nested`).

Args:
max_size (int, optional): the maximum number of elements stored in the storage.
If not provided, an unlimited storage is created.

Keyword Args:
compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile`,
at the cost of not being executable in multiprocessed settings.
stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `-1`.

Examples:
>>> import torch
>>> from torchrl.data import ReplayBuffer, LazyStackStorage
>>> from tensordict import TensorDict
>>> _ = torch.manual_seed(0)
>>> rb = ReplayBuffer(storage=LazyStackStorage(max_size=1000, stack_dim=-1))
>>> data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!")
>>> data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!")
>>> _ = rb.add(data0)
>>> _ = rb.add(data1)
>>> rb.sample(10)
LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([10, -1]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
c: NonTensorStack(
['another string!', 'another string!', 'another st...,
batch_size=torch.Size([10]),
device=None)},
exclusive_fields={
},
batch_size=torch.Size([10]),
device=None,
is_shared=False,
stack_dim=0)
"""

def __init__(
self,
max_size: int | None = None,
*,
compilable: bool = False,
stack_dim: int = -1,
):
super().__init__(max_size=max_size, compilable=compilable)
self.stack_dim = stack_dim

def get(self, index: Union[int, Sequence[int], slice]) -> Any:
out = super().get(index=index)
if isinstance(out, list):
stack_dim = self.stack_dim
if stack_dim < 0:
stack_dim = out[0].ndim + 1 + stack_dim
out = LazyStackedTensorDict(*out, stack_dim=stack_dim)
return out
return out


class TensorStorage(Storage):
"""A storage for tensors and tensordicts.
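As a complement to the docstring example above, here is a short usage sketch of the densify path that the LazyStackStorage docstring mentions. It mirrors the new test in test_rb.py rather than adding anything beyond this commit.

import torch
from tensordict import TensorDict
from torchrl.data import LazyStackStorage, ReplayBuffer

torch.manual_seed(0)
rb = ReplayBuffer(storage=LazyStackStorage(max_size=1000))

# Entries whose "a" tensors have different lengths (10 vs 11) can live side by side.
rb.add(TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!"))
rb.add(TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!"))

sample = rb.sample(10)                       # LazyStackedTensorDict; stacking is deferred
dense = sample.densify(layout=torch.jagged)  # heterogeneous "a" becomes a jagged nested tensor
assert isinstance(dense["a"], torch.Tensor)
assert dense["a"].shape[0] == 10

The jagged layout keeps the ragged row lengths, so downstream code that accepts nested tensors can consume the sample without padding.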
