From bc360aa0bd82e8e52d33305a09d8d803f6f59493 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 30 Nov 2023 16:53:50 +0000 Subject: [PATCH 1/3] init --- test/test_rb.py | 33 +++++++++++++++++++++++++ torchrl/data/replay_buffers/storages.py | 33 +++++++++++++++++++++++++ torchrl/data/replay_buffers/writers.py | 32 +++++++++++++++++++++++- 3 files changed, 97 insertions(+), 1 deletion(-) diff --git a/test/test_rb.py b/test/test_rb.py index c68c623300b..19e36d30558 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -63,6 +63,7 @@ UnsqueezeTransform, VecNorm, ) +from torch import multiprocessing as mp OLD_TORCH = parse(torch.__version__) < parse("2.0.0") _has_tv = importlib.util.find_spec("torchvision") is not None @@ -1285,6 +1286,38 @@ def test_max_value_writer(size, batch_size, reward_ranges, device): assert (sample.get("key") != 0).all() +class TestMultiProc: + @staticmethod + def worker(rb, q0, q1): + td = TensorDict({"a": torch.ones(10)}, [10]) + rb.extend(td) + q0.put("extended") + extended = q1.get(timeout=5) + assert extended == "extended" + assert len(rb) == 21, len(rb) + assert (rb["_data", "a"][:9] == 2).all() + q0.put("finish") + + def test_multiproc_rb(self): + rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21)) + td = TensorDict({"a": torch.zeros(10)}, [10]) + rb.extend(td) + q0 = mp.Queue(1) + q1 = mp.Queue(1) + proc = mp.Process(target=self.worker, args=(rb, q0, q1)) + proc.start() + extended = q0.get(timeout=100) + assert extended == "extended" + assert len(rb) == 20 + assert (rb["_data", "a"][10:20] == 1).all() + td = TensorDict({"a": torch.zeros(10) + 2}, [10]) + rb.extend(td) + q1.put("extended") + finish = q0.get(timeout=5) + assert finish == "finish" + proc.join() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 
9c8417b9c97..439b2a1cd47 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -18,6 +18,8 @@ from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE from torchrl.data.replay_buffers.utils import INT_CLASSES +from torch import multiprocessing as mp +from multiprocessing.context import get_spawning_popen try: from torchsnapshot.serialization import tensor_from_memoryview @@ -259,6 +261,37 @@ def __init__(self, storage, max_size=None, device="cpu"): ) self._storage = storage + + @property + def _len(self): + _len_value = self.__dict__.get('_len_value', None) + if _len_value is None: + _len_value = self._len_value = mp.Value('i', 0) + return _len_value.value + + @_len.setter + def _len(self, value): + _len_value = self.__dict__.get('_len_value', None) + if _len_value is None: + _len_value = self._len_value = mp.Value('i', 0) + _len_value.value = value + + def __getstate__(self): + state = copy(self.__dict__) + if get_spawning_popen() is None: + len = self._len + del state["_len_value"] + state["len__context"] = len + return state + + def __setstate__(self, state): + len = state.pop("len__context", None) + if len is not None: + _len_value = mp.Value('i', len) + state["_len_value"] = _len_value + self.__dict__.update(state) + + def state_dict(self) -> Dict[str, Any]: _storage = self._storage if isinstance(_storage, torch.Tensor): diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index cf78e0a0d99..b56699cc9f0 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -5,13 +5,15 @@ import heapq from abc import ABC, abstractmethod +from copy import copy from typing import Any, Dict, Sequence import numpy as np import torch from .storages import Storage - +from torch import multiprocessing as mp +from multiprocessing.context import get_spawning_popen class Writer(ABC): """A ReplayBuffer base Writer class.""" @@ -42,6 +44,34 @@ def 
state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: return + @property + def _cursor(self): + _cursor_value = self.__dict__.get('_cursor_value', None) + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value('i', 0) + return _cursor_value.value + + @_cursor.setter + def _cursor(self, value): + _cursor_value = self.__dict__.get('_cursor_value', None) + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value('i', 0) + _cursor_value.value = value + + def __getstate__(self): + state = copy(self.__dict__) + if get_spawning_popen() is None: + cursor = self._cursor + del state["_cursor_value"] + state["cursor__context"] = cursor + return state + + def __setstate__(self, state): + cursor = state.pop("cursor__context", None) + if cursor is not None: + _cursor_value = mp.Value('i', cursor) + state["_cursor_value"] = _cursor_value + self.__dict__.update(state) class RoundRobinWriter(Writer): """A RoundRobin Writer class for composable replay buffers.""" From 5d0ad99626d86420bb9758a8b44cd0cb9126703e Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Dec 2023 09:20:51 +0000 Subject: [PATCH 2/3] amend --- test/test_rb.py | 2 +- torchrl/data/replay_buffers/storages.py | 38 +++++++++--- torchrl/data/replay_buffers/writers.py | 81 +++++++++++++++---------- 3 files changed, 78 insertions(+), 43 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index 19e36d30558..44bb3b258ea 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -18,6 +18,7 @@ from packaging.version import parse from tensordict import is_tensorclass, tensorclass from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase +from torch import multiprocessing as mp from torchrl.data import ( PrioritizedReplayBuffer, RemoteTensorDictReplayBuffer, @@ -63,7 +64,6 @@ UnsqueezeTransform, VecNorm, ) -from torch import multiprocessing as mp OLD_TORCH = parse(torch.__version__) < parse("2.0.0") _has_tv = 
importlib.util.find_spec("torchvision") is not None diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 439b2a1cd47..9d2674036e4 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -8,6 +8,7 @@ import warnings from collections import OrderedDict from copy import copy +from multiprocessing.context import get_spawning_popen from typing import Any, Dict, Sequence, Union import torch @@ -15,11 +16,10 @@ from tensordict.memmap import MemmapTensor, MemoryMappedTensor from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.utils import expand_right +from torch import multiprocessing as mp from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE from torchrl.data.replay_buffers.utils import INT_CLASSES -from torch import multiprocessing as mp -from multiprocessing.context import get_spawning_popen try: from torchsnapshot.serialization import tensor_from_memoryview @@ -261,19 +261,18 @@ def __init__(self, storage, max_size=None, device="cpu"): ) self._storage = storage - @property def _len(self): - _len_value = self.__dict__.get('_len_value', None) + _len_value = self.__dict__.get("_len_value", None) if _len_value is None: - _len_value = self._len_value = mp.Value('i', 0) + _len_value = self._len_value = mp.Value("i", 0) return _len_value.value @_len.setter def _len(self, value): - _len_value = self.__dict__.get('_len_value', None) + _len_value = self.__dict__.get("_len_value", None) if _len_value is None: - _len_value = self._len_value = mp.Value('i', 0) + _len_value = self._len_value = mp.Value("i", 0) _len_value.value = value def __getstate__(self): @@ -282,16 +281,37 @@ def __getstate__(self): len = self._len del state["_len_value"] state["len__context"] = len + elif not self.initialized: + # check that the storage is initialized + raise RuntimeError( + f"Cannot share a storage of type {type(self)} between processed if " + 
f"it has not been initialized yet. Populate the buffer with " + f"some data in the main process before passing it to the other " + f"subprocesses (or create the buffer explicitely with a TensorStorage)." + ) + else: + # check that the content is shared, otherwise tell the user we can't help + storage = self._storage + STORAGE_ERR = "The storage must be place in shared memory or memmapped before being shared between processes." + if is_tensor_collection(storage): + if not storage.is_memmap() and not storage.is_shared(): + raise RuntimeError(STORAGE_ERR) + else: + if ( + not isinstance(storage, MemoryMappedTensor) + and not storage.is_shared() + ): + raise RuntimeError(STORAGE_ERR) + return state def __setstate__(self, state): len = state.pop("len__context", None) if len is not None: - _len_value = mp.Value('i', len) + _len_value = mp.Value("i", len) state["_len_value"] = _len_value self.__dict__.update(state) - def state_dict(self) -> Dict[str, Any]: _storage = self._storage if isinstance(_storage, torch.Tensor): diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index b56699cc9f0..f171fb2a9ff 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -6,14 +6,15 @@ import heapq from abc import ABC, abstractmethod from copy import copy +from multiprocessing.context import get_spawning_popen from typing import Any, Dict, Sequence import numpy as np import torch +from torch import multiprocessing as mp from .storages import Storage -from torch import multiprocessing as mp -from multiprocessing.context import get_spawning_popen + class Writer(ABC): """A ReplayBuffer base Writer class.""" @@ -44,34 +45,6 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: return - @property - def _cursor(self): - _cursor_value = self.__dict__.get('_cursor_value', None) - if _cursor_value is None: - _cursor_value = self._cursor_value = mp.Value('i', 0) 
- return _cursor_value.value - - @_cursor.setter - def _cursor(self, value): - _cursor_value = self.__dict__.get('_cursor_value', None) - if _cursor_value is None: - _cursor_value = self._cursor_value = mp.Value('i', 0) - _cursor_value.value = value - - def __getstate__(self): - state = copy(self.__dict__) - if get_spawning_popen() is None: - cursor = self._cursor - del state["_cursor_value"] - state["cursor__context"] = cursor - return state - - def __setstate__(self, state): - cursor = state.pop("cursor__context", None) - if cursor is not None: - _cursor_value = mp.Value('i', cursor) - state["_cursor_value"] = _cursor_value - self.__dict__.update(state) class RoundRobinWriter(Writer): """A RoundRobin Writer class for composable replay buffers.""" @@ -82,14 +55,17 @@ def __init__(self, **kw) -> None: def add(self, data: Any) -> int: ret = self._cursor - self._storage[self._cursor] = data + _cursor = self._cursor + # we need to update the cursor first to avoid race conditions between workers self._cursor = (self._cursor + 1) % self._storage.max_size + self._storage[_cursor] = data return ret def extend(self, data: Sequence) -> torch.Tensor: cur_size = self._cursor batch_size = len(data) index = np.arange(cur_size, batch_size + cur_size) % self._storage.max_size + # we need to update the cursor first to avoid race conditions between workers self._cursor = (batch_size + cur_size) % self._storage.max_size self._storage[index] = data return index @@ -103,21 +79,52 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def _empty(self): self._cursor = 0 + @property + def _cursor(self): + _cursor_value = self.__dict__.get("_cursor_value", None) + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value("i", 0) + return _cursor_value.value + + @_cursor.setter + def _cursor(self, value): + _cursor_value = self.__dict__.get("_cursor_value", None) + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value("i", 0) + 
_cursor_value.value = value + + def __getstate__(self): + state = copy(self.__dict__) + if get_spawning_popen() is None: + cursor = self._cursor + del state["_cursor_value"] + state["cursor__context"] = cursor + return state + + def __setstate__(self, state): + cursor = state.pop("cursor__context", None) + if cursor is not None: + _cursor_value = mp.Value("i", cursor) + state["_cursor_value"] = _cursor_value + self.__dict__.update(state) + class TensorDictRoundRobinWriter(RoundRobinWriter): """A RoundRobin Writer class for composable, tensordict-based replay buffers.""" def add(self, data: Any) -> int: ret = self._cursor + # we need to update the cursor first to avoid race conditions between workers + self._cursor = (ret + 1) % self._storage.max_size data["index"] = ret - self._storage[self._cursor] = data - self._cursor = (self._cursor + 1) % self._storage.max_size + self._storage[ret] = data return ret def extend(self, data: Sequence) -> torch.Tensor: cur_size = self._cursor batch_size = len(data) index = np.arange(cur_size, batch_size + cur_size) % self._storage.max_size + # we need to update the cursor first to avoid race conditions between workers self._cursor = (batch_size + cur_size) % self._storage.max_size # storage must convert the data to the appropriate format if needed data["index"] = index @@ -250,3 +257,11 @@ def extend(self, data: Sequence) -> None: def _empty(self) -> None: self._cursor = 0 self._current_top_values = [] + + def __getstate__(self): + if get_spawning_popen() is not None: + raise RuntimeError( + f"Writers of type {type(self)} cannot be shared between processed." 
+ ) + state = copy(self.__dict__) + return state From ab463f0ae166c2d8335a0ad51c41c2f84c31aafb Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Dec 2023 10:44:39 +0000 Subject: [PATCH 3/3] amend --- docs/source/reference/data.rst | 28 +++++++++++ test/test_rb.py | 66 +++++++++++++++++++------ torchrl/data/replay_buffers/storages.py | 8 +++ 3 files changed, 87 insertions(+), 15 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 98d2d40cd5c..38f1d22d3fd 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -60,6 +60,34 @@ The following mean sampling latency improvements over using ListStorage were fou | :class:`LazyMemmapStorage` | 3.44x | +-------------------------------+-----------+ +Replay buffers with a shared storage and regular (RoundRobin) writers can also +be shared between processes on a single node. This allows each worker to read and +write onto the storage. The following code snippet exemplifies this feature: + + >>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage + >>> import torch + >>> from torch import multiprocessing as mp + >>> from tensordict import TensorDict + >>> + >>> def worker(rb): + ... # Updates the replay buffer with new data + ... td = TensorDict({"a": torch.ones(10)}, [10]) + ... rb.extend(td) + ... + >>> if __name__ == "__main__": + ... rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21)) + ... td = TensorDict({"a": torch.zeros(10)}, [10]) + ... rb.extend(td) + ... + ... proc = mp.Process(target=worker, args=(rb,)) + ... proc.start() + ... proc.join() + ... # the replay buffer now has a length of 20, since the worker updated it + ... assert len(rb) == 20 + ... assert (rb["_data", "a"][:10] == 0).all() # data from main process + ...
assert (rb["_data", "a"][10:20] == 1).all() # data from remote process + + Storing trajectories ~~~~~~~~~~~~~~~~~~~~ diff --git a/test/test_rb.py b/test/test_rb.py index 44bb3b258ea..961ad13b49e 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -42,6 +42,7 @@ from torchrl.data.replay_buffers.writers import ( RoundRobinWriter, TensorDictMaxValueWriter, + TensorDictRoundRobinWriter, ) from torchrl.envs.transforms.transforms import ( BinarizeReward, @@ -1289,7 +1290,7 @@ def test_max_value_writer(size, batch_size, reward_ranges, device): class TestMultiProc: @staticmethod def worker(rb, q0, q1): - td = TensorDict({"a": torch.ones(10)}, [10]) + td = TensorDict({"a": torch.ones(10), "next": {"reward": torch.ones(10)}}, [10]) rb.extend(td) q0.put("extended") extended = q1.get(timeout=5) @@ -1298,24 +1299,59 @@ def worker(rb, q0, q1): assert (rb["_data", "a"][:9] == 2).all() q0.put("finish") - def test_multiproc_rb(self): - rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21)) - td = TensorDict({"a": torch.zeros(10)}, [10]) - rb.extend(td) + def exec_multiproc_rb( + self, + storage_type=LazyMemmapStorage, + init=True, + writer_type=TensorDictRoundRobinWriter, + ): + rb = TensorDictReplayBuffer(storage=storage_type(21), writer=writer_type()) + if init: + td = TensorDict( + {"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, [10] + ) + rb.extend(td) q0 = mp.Queue(1) q1 = mp.Queue(1) proc = mp.Process(target=self.worker, args=(rb, q0, q1)) proc.start() - extended = q0.get(timeout=100) - assert extended == "extended" - assert len(rb) == 20 - assert (rb["_data", "a"][10:20] == 1).all() - td = TensorDict({"a": torch.zeros(10) + 2}, [10]) - rb.extend(td) - q1.put("extended") - finish = q0.get(timeout=5) - assert finish == "finish" - proc.join() + try: + extended = q0.get(timeout=100) + assert extended == "extended" + assert len(rb) == 20 + assert (rb["_data", "a"][10:20] == 1).all() + td = TensorDict({"a": torch.zeros(10) + 2}, [10]) + rb.extend(td) + 
q1.put("extended") + finish = q0.get(timeout=5) + assert finish == "finish" + finally: + proc.join() + + def test_multiproc_rb(self): + return self.exec_multiproc_rb() + + def test_error_list(self): + # list storage cannot be shared + with pytest.raises(RuntimeError, match="Cannot share a storage of type"): + self.exec_multiproc_rb(storage_type=ListStorage) + + def test_error_nonshared(self): + # non shared tensor storage cannot be shared + with pytest.raises( + RuntimeError, match="The storage must be place in shared memory" + ): + self.exec_multiproc_rb(storage_type=LazyTensorStorage) + + def test_error_maxwriter(self): + # TensorDictMaxValueWriter cannot be shared + with pytest.raises(RuntimeError, match="cannot be shared between processed"): + self.exec_multiproc_rb(writer_type=TensorDictMaxValueWriter) + + def test_error_noninit(self): + # list storage cannot be shared + with pytest.raises(RuntimeError, match="it has not been initialized yet"): + self.exec_multiproc_rb(init=False) if __name__ == "__main__": diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 9d2674036e4..e82f79b4774 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -168,6 +168,14 @@ def load_state_dict(self, state_dict): def _empty(self): self._storage = [] + def __getstate__(self): + if get_spawning_popen() is not None: + raise RuntimeError( + f"Cannot share a storage of type {type(self)} between processes." + ) + state = copy(self.__dict__) + return state + class TensorStorage(Storage): """A storage for tensors and tensordicts.