Skip to content

Commit

Permalink
doc
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Dec 5, 2023
1 parent 602d39c commit f782869
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,45 @@ write onto the storage. The following code snippet examplifies this feature:
... assert (rb["_data", "a"][:10] == 0).all() # data from main process
... assert (rb["_data", "a"][10:20] == 1).all() # data from remote process

Sharing replay buffers across processes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Replay buffers can be shared between processes as long as their components are
sharable. Sharable storages include :class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage`
or any subclass of :class:`~torchrl.data.replay_buffers.storages.TensorStorage`
as long as they are instantiated and their content is stored as memory-mapped
tensors. Stateful writers such as :class:`~torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter`
are currently not sharable, and the same goes for stateful samplers such as
:class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.

A shared replay buffer can be read and extended on any process that has access
to it, as the following example shows:

>>> import pickle
>>>
>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>> import torch
>>> from torch import multiprocessing as mp
>>> from tensordict import TensorDict
>>>
>>> def worker(rb):
... td = TensorDict({"a": torch.ones(10)}, [10])
... # Extends the shared replay buffer on a subprocess
... rb.extend(td)
>>>
>>> if __name__ == "__main__":
... rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21))
... td = TensorDict({"a": torch.zeros(10)}, [10])
.. # extends the replay buffer on the main process
... rb.extend(td)
...
... proc = mp.Process(target=worker, args=(rb,))
... proc.start()
... proc.join()
... # Checks that the length of the buffer equates the length of both
... # extensions (local and remote)
... assert len(rb) == 20


Storing trajectories
~~~~~~~~~~~~~~~~~~~~
Expand Down

0 comments on commit f782869

Please sign in to comment.