Commit: amend
vmoens committed Dec 1, 2023
1 parent bc360aa · commit 5d0ad99
Showing 3 changed files with 78 additions and 43 deletions.
2 changes: 1 addition & 1 deletion test/test_rb.py
@@ -18,6 +18,7 @@
from packaging.version import parse
from tensordict import is_tensorclass, tensorclass
from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase
from torch import multiprocessing as mp
from torchrl.data import (
PrioritizedReplayBuffer,
RemoteTensorDictReplayBuffer,
@@ -63,7 +64,6 @@
UnsqueezeTransform,
VecNorm,
)
from torch import multiprocessing as mp

OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
_has_tv = importlib.util.find_spec("torchvision") is not None
38 changes: 29 additions & 9 deletions torchrl/data/replay_buffers/storages.py
@@ -8,18 +8,18 @@
import warnings
from collections import OrderedDict
from copy import copy
from multiprocessing.context import get_spawning_popen
from typing import Any, Dict, Sequence, Union

import torch
from tensordict import is_tensorclass
from tensordict.memmap import MemmapTensor, MemoryMappedTensor
from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase
from tensordict.utils import expand_right
from torch import multiprocessing as mp

from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE
from torchrl.data.replay_buffers.utils import INT_CLASSES
from torch import multiprocessing as mp
from multiprocessing.context import get_spawning_popen

try:
from torchsnapshot.serialization import tensor_from_memoryview
@@ -261,19 +261,18 @@ def __init__(self, storage, max_size=None, device="cpu"):
)
self._storage = storage


@property
def _len(self):
_len_value = self.__dict__.get('_len_value', None)
_len_value = self.__dict__.get("_len_value", None)
if _len_value is None:
_len_value = self._len_value = mp.Value('i', 0)
_len_value = self._len_value = mp.Value("i", 0)
return _len_value.value

@_len.setter
def _len(self, value):
_len_value = self.__dict__.get('_len_value', None)
_len_value = self.__dict__.get("_len_value", None)
if _len_value is None:
_len_value = self._len_value = mp.Value('i', 0)
_len_value = self._len_value = mp.Value("i", 0)
_len_value.value = value

def __getstate__(self):
@@ -282,16 +281,37 @@ def __getstate__(self):
len = self._len
del state["_len_value"]
state["len__context"] = len
elif not self.initialized:
# check that the storage is initialized
raise RuntimeError(
f"Cannot share a storage of type {type(self)} between processed if "
f"it has not been initialized yet. Populate the buffer with "
f"some data in the main process before passing it to the other "
f"subprocesses (or create the buffer explicitely with a TensorStorage)."
)
else:
# check that the content is shared, otherwise tell the user we can't help
storage = self._storage
STORAGE_ERR = "The storage must be place in shared memory or memmapped before being shared between processes."
if is_tensor_collection(storage):
if not storage.is_memmap() and not storage.is_shared():
raise RuntimeError(STORAGE_ERR)
else:
if (
not isinstance(storage, MemoryMappedTensor)
and not storage.is_shared()
):
raise RuntimeError(STORAGE_ERR)

return state

def __setstate__(self, state):
len = state.pop("len__context", None)
if len is not None:
_len_value = mp.Value('i', len)
_len_value = mp.Value("i", len)
state["_len_value"] = _len_value
self.__dict__.update(state)


def state_dict(self) -> Dict[str, Any]:
_storage = self._storage
if isinstance(_storage, torch.Tensor):
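
The storages.py changes combine two ideas: the buffer length is kept in an mp.Value("i", ...) so that parent and child processes read and write the same counter, and __getstate__ only lets that value travel when the object is pickled as part of a process spawn (get_spawning_popen() is not None); a regular pickle falls back to a plain integer. Below is a minimal, self-contained sketch of that pattern — SharedLen, worker and the surrounding script are illustrative names, not torchrl code.

```python
# Illustrative sketch (not torchrl code) of the shared-length pattern above.
import pickle
from copy import copy
from multiprocessing.context import get_spawning_popen

from torch import multiprocessing as mp


class SharedLen:
    """An int attribute backed by mp.Value so child processes see updates."""

    def __init__(self):
        self._len_value = mp.Value("i", 0)

    @property
    def _len(self):
        return self._len_value.value

    @_len.setter
    def _len(self, value):
        self._len_value.value = value

    def __getstate__(self):
        state = copy(self.__dict__)
        if get_spawning_popen() is None:
            # Regular pickling: mp.Value cannot be pickled outside of a
            # process spawn, so store the plain integer instead.
            state["len__context"] = self._len
            del state["_len_value"]
        return state

    def __setstate__(self, state):
        length = state.pop("len__context", None)
        if length is not None:
            state["_len_value"] = mp.Value("i", length)
        self.__dict__.update(state)


def worker(shared):
    # Runs in a spawned child: the mp.Value was shipped over at spawn time,
    # so this update is visible to the parent.
    shared._len = 42


if __name__ == "__main__":
    obj = SharedLen()
    obj._len = 7
    # A plain pickle round-trip goes through the len__context fallback.
    clone = pickle.loads(pickle.dumps(obj))
    assert clone._len == 7

    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=worker, args=(obj,))
    proc.start()
    proc.join()
    assert obj._len == 42
```

The new guard in __getstate__ applies the same logic to the data itself: tensors only remain visible to other processes if they are memory-mapped or in shared memory (for tensordict-based storages, e.g. after memmap_() or share_memory_()), hence the RuntimeError for uninitialized or process-local storages.
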
81 changes: 48 additions & 33 deletions torchrl/data/replay_buffers/writers.py
@@ -6,14 +6,15 @@
import heapq
from abc import ABC, abstractmethod
from copy import copy
from multiprocessing.context import get_spawning_popen
from typing import Any, Dict, Sequence

import numpy as np
import torch
from torch import multiprocessing as mp

from .storages import Storage
from torch import multiprocessing as mp
from multiprocessing.context import get_spawning_popen


class Writer(ABC):
"""A ReplayBuffer base Writer class."""
@@ -44,34 +45,6 @@ def state_dict(self) -> Dict[str, Any]:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
return

@property
def _cursor(self):
_cursor_value = self.__dict__.get('_cursor_value', None)
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value('i', 0)
return _cursor_value.value

@_cursor.setter
def _cursor(self, value):
_cursor_value = self.__dict__.get('_cursor_value', None)
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value('i', 0)
_cursor_value.value = value

def __getstate__(self):
state = copy(self.__dict__)
if get_spawning_popen() is None:
cursor = self._cursor
del state["_cursor_value"]
state["cursor__context"] = cursor
return state

def __setstate__(self, state):
cursor = state.pop("cursor__context", None)
if cursor is not None:
_cursor_value = mp.Value('i', cursor)
state["_cursor_value"] = _cursor_value
self.__dict__.update(state)

class RoundRobinWriter(Writer):
"""A RoundRobin Writer class for composable replay buffers."""
@@ -82,14 +55,17 @@ def __init__(self, **kw) -> None:

def add(self, data: Any) -> int:
ret = self._cursor
self._storage[self._cursor] = data
_cursor = self._cursor
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (self._cursor + 1) % self._storage.max_size
self._storage[_cursor] = data
return ret

def extend(self, data: Sequence) -> torch.Tensor:
cur_size = self._cursor
batch_size = len(data)
index = np.arange(cur_size, batch_size + cur_size) % self._storage.max_size
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % self._storage.max_size
self._storage[index] = data
return index
@@ -103,21 +79,52 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def _empty(self):
self._cursor = 0

@property
def _cursor(self):
_cursor_value = self.__dict__.get("_cursor_value", None)
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
return _cursor_value.value

@_cursor.setter
def _cursor(self, value):
_cursor_value = self.__dict__.get("_cursor_value", None)
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
_cursor_value.value = value

def __getstate__(self):
state = copy(self.__dict__)
if get_spawning_popen() is None:
cursor = self._cursor
del state["_cursor_value"]
state["cursor__context"] = cursor
return state

def __setstate__(self, state):
cursor = state.pop("cursor__context", None)
if cursor is not None:
_cursor_value = mp.Value("i", cursor)
state["_cursor_value"] = _cursor_value
self.__dict__.update(state)


class TensorDictRoundRobinWriter(RoundRobinWriter):
"""A RoundRobin Writer class for composable, tensordict-based replay buffers."""

def add(self, data: Any) -> int:
ret = self._cursor
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (ret + 1) % self._storage.max_size
data["index"] = ret
self._storage[self._cursor] = data
self._cursor = (self._cursor + 1) % self._storage.max_size
self._storage[ret] = data
return ret

def extend(self, data: Sequence) -> torch.Tensor:
cur_size = self._cursor
batch_size = len(data)
index = np.arange(cur_size, batch_size + cur_size) % self._storage.max_size
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % self._storage.max_size
# storage must convert the data to the appropriate format if needed
data["index"] = index
@@ -250,3 +257,11 @@ def extend(self, data: Sequence) -> None:
def _empty(self) -> None:
self._cursor = 0
self._current_top_values = []

def __getstate__(self):
if get_spawning_popen() is not None:
raise RuntimeError(
f"Writers of type {type(self)} cannot be shared between processed."
)
state = copy(self.__dict__)
return state
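
The writer changes move the shared-cursor machinery from the abstract Writer base class down to RoundRobinWriter and, as the inline comments note, advance the cursor before touching the storage so that each worker reserves its slot(s) up front; the writer in the last hunk instead refuses to be sent to a spawned process at all, since its heap of top values cannot be shared. The sketch below illustrates the reserve-then-write ordering with a made-up MiniRoundRobinWriter; the explicit mp.Lock around the cursor update and the mp.Array storage are assumptions of this sketch, not part of torchrl.

```python
# Illustrative sketch (not torchrl code) of reserve-then-write slot handling.
from torch import multiprocessing as mp


class MiniRoundRobinWriter:
    def __init__(self, max_size):
        self.max_size = max_size
        self._cursor = mp.Value("i", 0)
        self._lock = mp.Lock()
        # Shared flat storage, one float per slot.
        self._storage = mp.Array("d", max_size)

    def add(self, value):
        # Reserve a slot first: once the cursor has moved on, no other
        # worker will be handed the same index.
        with self._lock:
            index = self._cursor.value
            self._cursor.value = (index + 1) % self.max_size
        # Only then write the payload into the reserved slot.
        self._storage[index] = value
        return index


def worker(writer, values):
    for v in values:
        writer.add(v)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    writer = MiniRoundRobinWriter(max_size=64)
    procs = [
        ctx.Process(target=worker, args=(writer, [float(i)] * 8))
        for i in range(4)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # 4 workers x 8 writes = 32 increments, and the buffer has not wrapped.
    assert writer._cursor.value == 32
```

Reserving first means a slow write never leaves the cursor pointing at a slot another worker could also claim; each payload write then only touches the index it was handed.
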
