[Feature] Shared replay buffers #1724

Merged · 3 commits · Dec 1, 2023
28 changes: 28 additions & 0 deletions docs/source/reference/data.rst
@@ -60,6 +60,34 @@ The following mean sampling latency improvements over using ListStorage were found
| :class:`LazyMemmapStorage` | 3.44x |
+-------------------------------+-----------+

Replay buffers with a shared storage and regular (RoundRobin) writers can also
be shared between processes on a single node. This allows each worker to read
from and write to the storage. The following code snippet illustrates this feature:

>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>> import torch
>>> from torch import multiprocessing as mp
>>> from tensordict import TensorDict
>>>
>>> def worker(rb):
... # Updates the replay buffer with new data
... td = TensorDict({"a": torch.ones(10)}, [10])
... rb.extend(td)
...
>>> if __name__ == "__main__":
... rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21))
... td = TensorDict({"a": torch.zeros(10)}, [10])
... rb.extend(td)
...
... proc = mp.Process(target=worker, args=(rb,))
... proc.start()
... proc.join()
... # the replay buffer now has a length of 20, since the worker updated it
... assert len(rb) == 20
... assert (rb["_data", "a"][:10] == 0).all() # data from main process
... assert (rb["_data", "a"][10:20] == 1).all() # data from remote process
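
Reading from a worker process follows the same pattern. The sketch below reuses the
imports above; the ``reader`` helper is illustrative only and not part of the library:

>>> def reader(rb):
...     # the shared storage and the shared buffer length are visible from the worker
...     assert len(rb) == 10
...     assert (rb["_data", "a"][:10] == 0).all()
...
>>> if __name__ == "__main__":
...     rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21))
...     rb.extend(TensorDict({"a": torch.zeros(10)}, [10]))
...     proc = mp.Process(target=reader, args=(rb,))
...     proc.start()
...     proc.join()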


Storing trajectories
~~~~~~~~~~~~~~~~~~~~

69 changes: 69 additions & 0 deletions test/test_rb.py
@@ -18,6 +18,7 @@
from packaging.version import parse
from tensordict import is_tensorclass, tensorclass
from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase
from torch import multiprocessing as mp
from torchrl.data import (
PrioritizedReplayBuffer,
RemoteTensorDictReplayBuffer,
@@ -41,6 +42,7 @@
from torchrl.data.replay_buffers.writers import (
RoundRobinWriter,
TensorDictMaxValueWriter,
TensorDictRoundRobinWriter,
)
from torchrl.envs.transforms.transforms import (
BinarizeReward,
@@ -1285,6 +1287,73 @@ def test_max_value_writer(size, batch_size, reward_ranges, device):
assert (sample.get("key") != 0).all()


class TestMultiProc:
@staticmethod
def worker(rb, q0, q1):
td = TensorDict({"a": torch.ones(10), "next": {"reward": torch.ones(10)}}, [10])
rb.extend(td)
q0.put("extended")
extended = q1.get(timeout=5)
assert extended == "extended"
assert len(rb) == 21, len(rb)
assert (rb["_data", "a"][:9] == 2).all()
q0.put("finish")

def exec_multiproc_rb(
self,
storage_type=LazyMemmapStorage,
init=True,
writer_type=TensorDictRoundRobinWriter,
):
rb = TensorDictReplayBuffer(storage=storage_type(21), writer=writer_type())
if init:
td = TensorDict(
{"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, [10]
)
rb.extend(td)
q0 = mp.Queue(1)
q1 = mp.Queue(1)
proc = mp.Process(target=self.worker, args=(rb, q0, q1))
proc.start()
try:
extended = q0.get(timeout=100)
assert extended == "extended"
assert len(rb) == 20
assert (rb["_data", "a"][10:20] == 1).all()
td = TensorDict({"a": torch.zeros(10) + 2}, [10])
rb.extend(td)
q1.put("extended")
finish = q0.get(timeout=5)
assert finish == "finish"
finally:
proc.join()

def test_multiproc_rb(self):
return self.exec_multiproc_rb()

def test_error_list(self):
# list storage cannot be shared
with pytest.raises(RuntimeError, match="Cannot share a storage of type"):
self.exec_multiproc_rb(storage_type=ListStorage)

def test_error_nonshared(self):
# a tensor storage that is not in shared memory cannot be shared between processes
with pytest.raises(
RuntimeError, match="The storage must be place in shared memory"
):
self.exec_multiproc_rb(storage_type=LazyTensorStorage)

def test_error_maxwriter(self):
# TensorDictMaxValueWriter cannot be shared
with pytest.raises(RuntimeError, match="cannot be shared between processed"):
self.exec_multiproc_rb(writer_type=TensorDictMaxValueWriter)

def test_error_noninit(self):
# an uninitialized storage cannot be shared
with pytest.raises(RuntimeError, match="it has not been initialized yet"):
self.exec_multiproc_rb(init=False)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
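
Both the storage and writer changes below rely on the same pickling pattern: keep the
mutable counter (the storage length, the writer cursor) in a multiprocessing Value so
every process sees the same integer, swap it for a plain int on ordinary pickling
(e.g. torch.save), and keep the shared value when the object is pickled for a spawned
child. A minimal, self-contained sketch of that pattern (illustrative only; the
SharedCounter class is not part of torchrl):

from copy import copy
from multiprocessing.context import get_spawning_popen

from torch import multiprocessing as mp


class SharedCounter:
    """Illustrative only: an integer that stays consistent across spawned processes."""

    @property
    def value(self) -> int:
        # lazily create the shared integer on first access
        v = self.__dict__.get("_value", None)
        if v is None:
            v = self._value = mp.Value("i", 0)
        return v.value

    @value.setter
    def value(self, new: int) -> None:
        v = self.__dict__.get("_value", None)
        if v is None:
            v = self._value = mp.Value("i", 0)
        v.value = new

    def __getstate__(self):
        plain = self.value  # guarantees the mp.Value exists before copying the dict
        state = copy(self.__dict__)
        if get_spawning_popen() is None:
            # ordinary pickling: replace the mp.Value with a plain int
            del state["_value"]
            state["value__context"] = plain
        # when a child process is being spawned, keep the mp.Value so both sides share it
        return state

    def __setstate__(self, state):
        plain = state.pop("value__context", None)
        if plain is not None:
            state["_value"] = mp.Value("i", plain)
        self.__dict__.update(state)

Under a spawn start method, a parent and a child handed the same instance observe each
other's updates to value; a plain pickle round-trip outside of spawning degrades it to
an ordinary integer so the object can still be serialized normally.
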
61 changes: 61 additions & 0 deletions torchrl/data/replay_buffers/storages.py
@@ -8,13 +8,15 @@
import warnings
from collections import OrderedDict
from copy import copy
from multiprocessing.context import get_spawning_popen
from typing import Any, Dict, Sequence, Union

import torch
from tensordict import is_tensorclass
from tensordict.memmap import MemmapTensor, MemoryMappedTensor
from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase
from tensordict.utils import expand_right
from torch import multiprocessing as mp

from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE
from torchrl.data.replay_buffers.utils import INT_CLASSES
@@ -166,6 +168,14 @@ def load_state_dict(self, state_dict):
def _empty(self):
self._storage = []

def __getstate__(self):
if get_spawning_popen() is not None:
raise RuntimeError(
f"Cannot share a storage of type {type(self)} between processes."
)
state = copy(self.__dict__)
return state


class TensorStorage(Storage):
"""A storage for tensors and tensordicts.
@@ -259,6 +269,57 @@ def __init__(self, storage, max_size=None, device="cpu"):
)
self._storage = storage

@property
def _len(self):
_len_value = self.__dict__.get("_len_value", None)
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
return _len_value.value

@_len.setter
def _len(self, value):
_len_value = self.__dict__.get("_len_value", None)
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
_len_value.value = value

def __getstate__(self):
state = copy(self.__dict__)
if get_spawning_popen() is None:
length = self._len
del state["_len_value"]
state["len__context"] = length
elif not self.initialized:
# check that the storage is initialized
raise RuntimeError(
f"Cannot share a storage of type {type(self)} between processed if "
f"it has not been initialized yet. Populate the buffer with "
f"some data in the main process before passing it to the other "
f"subprocesses (or create the buffer explicitely with a TensorStorage)."
)
else:
# check that the content is shared, otherwise tell the user we can't help
storage = self._storage
STORAGE_ERR = "The storage must be place in shared memory or memmapped before being shared between processes."
if is_tensor_collection(storage):
if not storage.is_memmap() and not storage.is_shared():
raise RuntimeError(STORAGE_ERR)
else:
if (
not isinstance(storage, MemoryMappedTensor)
and not storage.is_shared()
):
raise RuntimeError(STORAGE_ERR)

return state

def __setstate__(self, state):
length = state.pop("len__context", None)
if length is not None:
_len_value = mp.Value("i", length)
state["_len_value"] = _len_value
self.__dict__.update(state)

def state_dict(self) -> Dict[str, Any]:
_storage = self._storage
if isinstance(_storage, torch.Tensor):
51 changes: 48 additions & 3 deletions torchrl/data/replay_buffers/writers.py
@@ -5,10 +5,13 @@

import heapq
from abc import ABC, abstractmethod
from copy import copy
from multiprocessing.context import get_spawning_popen
from typing import Any, Dict, Sequence

import numpy as np
import torch
from torch import multiprocessing as mp

from .storages import Storage

@@ -52,14 +55,17 @@ def __init__(self, **kw) -> None:

def add(self, data: Any) -> int:
ret = self._cursor
_cursor = self._cursor
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (self._cursor + 1) % self._storage.max_size
self._storage[_cursor] = data
return ret

def extend(self, data: Sequence) -> torch.Tensor:
cur_size = self._cursor
batch_size = len(data)
index = np.arange(cur_size, batch_size + cur_size) % self._storage.max_size
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % self._storage.max_size
self._storage[index] = data
return index
@@ -73,21 +79,52 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def _empty(self):
self._cursor = 0

@property
def _cursor(self):
_cursor_value = self.__dict__.get("_cursor_value", None)
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
return _cursor_value.value

@_cursor.setter
def _cursor(self, value):
_cursor_value = self.__dict__.get("_cursor_value", None)
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
_cursor_value.value = value

def __getstate__(self):
state = copy(self.__dict__)
if get_spawning_popen() is None:
cursor = self._cursor
del state["_cursor_value"]
state["cursor__context"] = cursor
return state

def __setstate__(self, state):
cursor = state.pop("cursor__context", None)
if cursor is not None:
_cursor_value = mp.Value("i", cursor)
state["_cursor_value"] = _cursor_value
self.__dict__.update(state)


class TensorDictRoundRobinWriter(RoundRobinWriter):
"""A RoundRobin Writer class for composable, tensordict-based replay buffers."""

def add(self, data: Any) -> int:
ret = self._cursor
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (ret + 1) % self._storage.max_size
data["index"] = ret
self._storage[ret] = data
return ret

def extend(self, data: Sequence) -> torch.Tensor:
cur_size = self._cursor
batch_size = len(data)
index = np.arange(cur_size, batch_size + cur_size) % self._storage.max_size
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % self._storage.max_size
# storage must convert the data to the appropriate format if needed
data["index"] = index
@@ -220,3 +257,11 @@ def extend(self, data: Sequence) -> None:
def _empty(self) -> None:
self._cursor = 0
self._current_top_values = []

def __getstate__(self):
if get_spawning_popen() is not None:
raise RuntimeError(
f"Writers of type {type(self)} cannot be shared between processed."
)
state = copy(self.__dict__)
return state
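
For completeness, a hypothetical user-facing sketch of the guards added above
(mirroring test_error_list): sending a buffer backed by a non-shareable storage to a
spawned worker now fails eagerly at pickling time instead of silently desynchronizing.
The child helper below is illustrative only.

from torch import multiprocessing as mp
from torchrl.data import ListStorage, TensorDictReplayBuffer


def child(rb):
    pass  # never reached: pickling the buffer for the child fails first


if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # the guards are exercised by spawn-style pickling
    rb = TensorDictReplayBuffer(storage=ListStorage(21))
    proc = ctx.Process(target=child, args=(rb,))
    try:
        proc.start()
    except RuntimeError as err:
        print(err)  # Cannot share a storage of type ... between processes.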