From 5d0ad99626d86420bb9758a8b44cd0cb9126703e Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 1 Dec 2023 09:20:51 +0000
Subject: [PATCH] amend

---
 test/test_rb.py                          |  2 +-
 torchrl/data/replay_buffers/storages.py  | 38 +++++++++---
 torchrl/data/replay_buffers/writers.py   | 81 +++++++++++++++----------
 3 files changed, 78 insertions(+), 43 deletions(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index 19e36d30558..44bb3b258ea 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -18,6 +18,7 @@
 from packaging.version import parse
 from tensordict import is_tensorclass, tensorclass
 from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase
+from torch import multiprocessing as mp
 from torchrl.data import (
     PrioritizedReplayBuffer,
     RemoteTensorDictReplayBuffer,
@@ -63,7 +64,6 @@
     UnsqueezeTransform,
     VecNorm,
 )
-from torch import multiprocessing as mp
 
 OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
 _has_tv = importlib.util.find_spec("torchvision") is not None
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 439b2a1cd47..9d2674036e4 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -8,6 +8,7 @@
 import warnings
 from collections import OrderedDict
 from copy import copy
+from multiprocessing.context import get_spawning_popen
 from typing import Any, Dict, Sequence, Union
 
 import torch
@@ -15,11 +16,10 @@
 from tensordict.memmap import MemmapTensor, MemoryMappedTensor
 from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase
 from tensordict.utils import expand_right
+from torch import multiprocessing as mp
 
 from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE
 from torchrl.data.replay_buffers.utils import INT_CLASSES
-from torch import multiprocessing as mp
-from multiprocessing.context import get_spawning_popen
 
 try:
     from torchsnapshot.serialization import tensor_from_memoryview
@@ -261,19 +261,18 @@ def __init__(self, storage, max_size=None, device="cpu"):
             )
         self._storage = storage
 
-
     @property
     def _len(self):
-        _len_value = self.__dict__.get('_len_value', None)
+        _len_value = self.__dict__.get("_len_value", None)
         if _len_value is None:
-            _len_value = self._len_value = mp.Value('i', 0)
+            _len_value = self._len_value = mp.Value("i", 0)
         return _len_value.value
 
     @_len.setter
     def _len(self, value):
-        _len_value = self.__dict__.get('_len_value', None)
+        _len_value = self.__dict__.get("_len_value", None)
         if _len_value is None:
-            _len_value = self._len_value = mp.Value('i', 0)
+            _len_value = self._len_value = mp.Value("i", 0)
         _len_value.value = value
 
     def __getstate__(self):
@@ -282,16 +281,37 @@ def __getstate__(self):
             len = self._len
             del state["_len_value"]
             state["len__context"] = len
+        elif not self.initialized:
+            # check that the storage is initialized
+            raise RuntimeError(
+                f"Cannot share a storage of type {type(self)} between processes if "
+                f"it has not been initialized yet. Populate the buffer with "
+                f"some data in the main process before passing it to the other "
+                f"subprocesses (or create the buffer explicitly with a TensorStorage)."
+            )
+        else:
+            # check that the content is shared, otherwise tell the user we can't help
+            storage = self._storage
+            STORAGE_ERR = "The storage must be placed in shared memory or memmapped before being shared between processes."
+            if is_tensor_collection(storage):
+                if not storage.is_memmap() and not storage.is_shared():
+                    raise RuntimeError(STORAGE_ERR)
+            else:
+                if (
+                    not isinstance(storage, MemoryMappedTensor)
+                    and not storage.is_shared()
+                ):
+                    raise RuntimeError(STORAGE_ERR)
+
         return state
 
     def __setstate__(self, state):
         len = state.pop("len__context", None)
         if len is not None:
-            _len_value = mp.Value('i', len)
+            _len_value = mp.Value("i", len)
             state["_len_value"] = _len_value
         self.__dict__.update(state)
 
-
     def state_dict(self) -> Dict[str, Any]:
         _storage = self._storage
         if isinstance(_storage, torch.Tensor):
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index b56699cc9f0..f171fb2a9ff 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -6,14 +6,15 @@
 import heapq
 from abc import ABC, abstractmethod
 from copy import copy
+from multiprocessing.context import get_spawning_popen
 from typing import Any, Dict, Sequence
 
 import numpy as np
 import torch
+from torch import multiprocessing as mp
 
 from .storages import Storage
-from torch import multiprocessing as mp
-from multiprocessing.context import get_spawning_popen
+
 
 class Writer(ABC):
     """A ReplayBuffer base Writer class."""
@@ -44,34 +45,6 @@ def state_dict(self) -> Dict[str, Any]:
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         return
 
-    @property
-    def _cursor(self):
-        _cursor_value = self.__dict__.get('_cursor_value', None)
-        if _cursor_value is None:
-            _cursor_value = self._cursor_value = mp.Value('i', 0)
-        return _cursor_value.value
-
-    @_cursor.setter
-    def _cursor(self, value):
-        _cursor_value = self.__dict__.get('_cursor_value', None)
-        if _cursor_value is None:
-            _cursor_value = self._cursor_value = mp.Value('i', 0)
-        _cursor_value.value = value
-
-    def __getstate__(self):
-        state = copy(self.__dict__)
-        if get_spawning_popen() is None:
-            cursor = self._cursor
-            del state["_cursor_value"]
-            state["cursor__context"] = cursor
-        return state
-
-    def __setstate__(self, state):
-        cursor = state.pop("cursor__context", None)
-        if cursor is not None:
-            _cursor_value = mp.Value('i', cursor)
-            state["_cursor_value"] = _cursor_value
-        self.__dict__.update(state)
-
 
 class RoundRobinWriter(Writer):
     """A RoundRobin Writer class for composable replay buffers."""
@@ -82,14 +55,17 @@ def __init__(self, **kw) -> None:
 
     def add(self, data: Any) -> int:
         ret = self._cursor
-        self._storage[self._cursor] = data
+        _cursor = self._cursor
+        # we need to update the cursor first to avoid race conditions between workers
         self._cursor = (self._cursor + 1) % self._storage.max_size
+        self._storage[_cursor] = data
         return ret
 
     def extend(self, data: Sequence) -> torch.Tensor:
         cur_size = self._cursor
         batch_size = len(data)
         index = np.arange(cur_size, batch_size + cur_size) % self._storage.max_size
+        # we need to update the cursor first to avoid race conditions between workers
         self._cursor = (batch_size + cur_size) % self._storage.max_size
         self._storage[index] = data
         return index
@@ -103,21 +79,52 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
     def _empty(self):
         self._cursor = 0
 
+    @property
+    def _cursor(self):
+        _cursor_value = self.__dict__.get("_cursor_value", None)
+        if _cursor_value is None:
+            _cursor_value = self._cursor_value = mp.Value("i", 0)
+        return _cursor_value.value
+
+    @_cursor.setter
+    def _cursor(self, value):
+        _cursor_value = self.__dict__.get("_cursor_value", None)
+        if _cursor_value is None:
+            _cursor_value = self._cursor_value = mp.Value("i", 0)
+        _cursor_value.value = value
+
+    def __getstate__(self):
+        state = copy(self.__dict__)
+        if get_spawning_popen() is None:
+            cursor = self._cursor
+            del state["_cursor_value"]
+            state["cursor__context"] = cursor
+        return state
+
+    def __setstate__(self, state):
+        cursor = state.pop("cursor__context", None)
+        if cursor is not None:
+            _cursor_value = mp.Value("i", cursor)
+            state["_cursor_value"] = _cursor_value
+        self.__dict__.update(state)
+
 
 class TensorDictRoundRobinWriter(RoundRobinWriter):
     """A RoundRobin Writer class for composable, tensordict-based replay buffers."""
 
     def add(self, data: Any) -> int:
         ret = self._cursor
+        # we need to update the cursor first to avoid race conditions between workers
+        self._cursor = (ret + 1) % self._storage.max_size
         data["index"] = ret
-        self._storage[self._cursor] = data
-        self._cursor = (self._cursor + 1) % self._storage.max_size
+        self._storage[ret] = data
         return ret
 
     def extend(self, data: Sequence) -> torch.Tensor:
         cur_size = self._cursor
         batch_size = len(data)
         index = np.arange(cur_size, batch_size + cur_size) % self._storage.max_size
+        # we need to update the cursor first to avoid race conditions between workers
         self._cursor = (batch_size + cur_size) % self._storage.max_size
         # storage must convert the data to the appropriate format if needed
         data["index"] = index
@@ -250,3 +257,11 @@ def extend(self, data: Sequence) -> None:
     def _empty(self) -> None:
         self._cursor = 0
         self._current_top_values = []
+
+    def __getstate__(self):
+        if get_spawning_popen() is not None:
+            raise RuntimeError(
+                f"Writers of type {type(self)} cannot be shared between processed."
+            )
+        state = copy(self.__dict__)
+        return state
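
Usage sketch (not part of the patch): the snippet below illustrates what the changes above are meant to enable, namely sending a replay buffer to a spawned worker so that the writer's cursor and the storage's length, both backed by mp.Value, stay in sync between the parent and the worker. The buffer size, tensor shapes and the worker function are illustrative assumptions rather than code from the repository.

import torch
from tensordict import TensorDict
from torch import multiprocessing as mp
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer


def worker(rb):
    # Writes made here advance the shared cursor and length, so the parent
    # process sees them as well.
    rb.extend(TensorDict({"obs": torch.randn(5, 4)}, batch_size=[5]))


if __name__ == "__main__":
    rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(100))
    # Populate (and thereby initialize) the memmap storage in the main process
    # first; per the patch, TensorStorage.__getstate__ raises if an
    # uninitialized storage is sent to a subprocess.
    rb.extend(TensorDict({"obs": torch.randn(10, 4)}, batch_size=[10]))

    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=worker, args=(rb,))
    proc.start()
    proc.join()
    print(len(rb))  # expected to print 15, reflecting the worker's writes

Note the ordering the patch comments call out: the writers advance the shared cursor before touching the storage, so a slot is claimed before the (possibly slow) storage write rather than after it.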