[Feature] pickle-free RB checkpointing #1733

Merged 7 commits on Dec 6, 2023

docs/source/reference/data.rst: 69 additions & 0 deletions

@@ -87,6 +87,49 @@ write onto the storage. The following code snippet exemplifies this feature:
... assert (rb["_data", "a"][:10] == 0).all() # data from main process
... assert (rb["_data", "a"][10:20] == 1).all() # data from remote process

Sharing replay buffers across processes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Replay buffers can be shared between processes as long as their components are
sharable. This feature allows multiple processes to collect data and populate a
shared replay buffer collaboratively, rather than centralizing the data on the
main process, which can incur some data-transmission overhead.

Sharable storages include :class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage`
or any subclass of :class:`~torchrl.data.replay_buffers.storages.TensorStorage`
as long as they are instantiated and their content is stored as memory-mapped
tensors. Stateful writers such as :class:`~torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter`
are currently not sharable, and the same goes for stateful samplers such as
:class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.

A shared replay buffer can be read and extended on any process that has access
to it, as the following example shows:

>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>> import torch
>>> from torch import multiprocessing as mp
>>> from tensordict import TensorDict
>>>
>>> def worker(rb):
... td = TensorDict({"a": torch.ones(10)}, [10])
... # Extends the shared replay buffer on a subprocess
... rb.extend(td)
>>>
>>> if __name__ == "__main__":
... rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21))
... td = TensorDict({"a": torch.zeros(10)}, [10])
... # Extends the replay buffer on the main process
... rb.extend(td)
...
... proc = mp.Process(target=worker, args=(rb,))
... proc.start()
... proc.join()
... # Checks that the length of the buffer equates the length of both
... # extensions (local and remote)
... assert len(rb) == 20


Storing trajectories
~~~~~~~~~~~~~~~~~~~~
@@ -131,6 +174,32 @@ can be used:
device=None,
is_shared=False)

Checkpointing Replay Buffers
----------------------------

Each component of the replay buffer can potentially be stateful and, as such,
requires a dedicated way of being serialized.
Our replay buffer provides two separate APIs for saving its state on disk:
:meth:`~torchrl.data.ReplayBuffer.dumps` and :meth:`~torchrl.data.ReplayBuffer.loads` save the
data of each component except transforms (storage, writer, sampler) using memory-mapped
tensors and JSON files for the metadata. This works across all classes except
:class:`~torchrl.data.replay_buffers.storages.ListStorage`, whose content
cannot be anticipated (and as such does not comply with memory-mapped data
structures such as those found in the tensordict library).
This API guarantees that a buffer that is saved and then loaded back will be in
the exact same state, whether we look at the state of its sampler (e.g., priority trees),
its writer (e.g., the max-writer heaps) or its storage.
Under the hood, :meth:`~torchrl.data.ReplayBuffer.dumps` simply calls the public
`dumps` method of each of its components, each in a dedicated folder (except transforms,
which we do not assume to be serializable using memory-mapped tensors in general).
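
As an illustration, a save/restore round-trip with these methods could look like
the sketch below (the buffer contents, capacity and temporary directory are
illustrative only):

>>> import tempfile
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>>
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(100))
>>> rb.extend(TensorDict({"a": torch.arange(10.0)}, [10]))
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     # Saves the storage, sampler and writer states to disk
...     rb.dumps(tmpdir)
...     # A buffer built with a matching storage can restore that state
...     rb_load = TensorDictReplayBuffer(storage=LazyMemmapStorage(100))
...     rb_load.loads(tmpdir)
...     assert len(rb_load) == len(rb)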

Whenever saving data using :meth:`~torchrl.data.ReplayBuffer.dumps` is not possible, an
alternative is to use :meth:`~torchrl.data.ReplayBuffer.state_dict`, which returns a data
structure that can be saved using :func:`torch.save` and loaded using :func:`torch.load`
before calling :meth:`~torchrl.data.ReplayBuffer.load_state_dict`. The drawback
of this method is that it struggles with large data structures, which are common
when working with replay buffers.
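
A minimal sketch of this alternative could look as follows (the checkpoint file
name and buffer construction are illustrative):

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>>
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(100))
>>> rb.extend(TensorDict({"a": torch.zeros(10)}, [10]))
>>> # The whole state is gathered in a single (potentially large) object
>>> torch.save(rb.state_dict(), "rb_checkpoint.pt")
>>> rb_load = TensorDictReplayBuffer(storage=LazyMemmapStorage(100))
>>> rb_load.load_state_dict(torch.load("rb_checkpoint.pt"))
>>> assert len(rb_load) == len(rb)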

Datasets
--------
