Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 30, 2023
1 parent 6c27bdb commit bc360aa
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 1 deletion.
33 changes: 33 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
UnsqueezeTransform,
VecNorm,
)
from torch import multiprocessing as mp

OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
_has_tv = importlib.util.find_spec("torchvision") is not None
Expand Down Expand Up @@ -1285,6 +1286,38 @@ def test_max_value_writer(size, batch_size, reward_ranges, device):
assert (sample.get("key") != 0).all()


class TestMultiProc:
    """Check that a replay buffer handed to a child process stays in sync.

    The parent and the worker both extend the same buffer and each side
    asserts on the length and content written by the other.
    """

    @staticmethod
    def worker(rb, q0, q1):
        """Child-process body: extend ``rb``, then verify the parent's write.

        Args:
            rb: replay buffer received from the parent process.
            q0: queue used to send messages to the parent.
            q1: queue used to receive messages from the parent.
        """
        td = TensorDict({"a": torch.ones(10)}, [10])
        rb.extend(td)
        q0.put("extended")
        # Block until the parent has performed its second extend.
        extended = q1.get(timeout=5)
        assert extended == "extended"
        assert len(rb) == 21, len(rb)
        assert (rb["_data", "a"][:9] == 2).all()
        q0.put("finish")

    def test_multiproc_rb(self):
        """Extend the buffer from both processes and check the shared state."""
        rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21))
        td = TensorDict({"a": torch.zeros(10)}, [10])
        rb.extend(td)
        q0 = mp.Queue(1)
        q1 = mp.Queue(1)
        proc = mp.Process(target=self.worker, args=(rb, q0, q1))
        proc.start()
        try:
            extended = q0.get(timeout=100)
            assert extended == "extended"
            # The worker's 10 items must be visible from this process.
            assert len(rb) == 20
            assert (rb["_data", "a"][10:20] == 1).all()
            td = TensorDict({"a": torch.zeros(10) + 2}, [10])
            rb.extend(td)
            q1.put("extended")
            finish = q0.get(timeout=5)
            assert finish == "finish"
        finally:
            # Always reap the child, even when an assertion or a queue
            # timeout above fails, so no process outlives the test.
            proc.join(timeout=10)
            if proc.is_alive():
                proc.terminate()
                proc.join()


if __name__ == "__main__":
    # Forward any command-line arguments argparse does not recognise to pytest.
    _, extra_args = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *extra_args])
33 changes: 33 additions & 0 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE
from torchrl.data.replay_buffers.utils import INT_CLASSES
from torch import multiprocessing as mp
from multiprocessing.context import get_spawning_popen

try:
from torchsnapshot.serialization import tensor_from_memoryview
Expand Down Expand Up @@ -259,6 +261,37 @@ def __init__(self, storage, max_size=None, device="cpu"):
)
self._storage = storage


@property
def _len(self):
    """Current length, read from a lazily created shared ``mp.Value``."""
    shared = self.__dict__.get("_len_value")
    if shared is None:
        # First access: create the shared integer counter, starting at 0.
        shared = mp.Value("i", 0)
        self._len_value = shared
    return shared.value

@_len.setter
def _len(self, value):
    # Store ``value`` in the shared counter, creating the counter on first use.
    shared = self.__dict__.get("_len_value")
    if shared is None:
        shared = mp.Value("i", 0)
        self._len_value = shared
    shared.value = value

def __getstate__(self):
    """Return the picklable state of this object.

    When the object is not being pickled for a spawned child process
    (``get_spawning_popen()`` returns ``None``), the shared ``_len_value``
    is replaced by its plain integer value under the ``"len__context"``
    key. Otherwise the shared value is kept in the state so that it is
    passed on as-is.

    Returns:
        dict: a shallow copy of ``__dict__`` adjusted as described above.
    """
    # Read the length first: the property lazily creates ``_len_value``, so
    # the copy taken below always contains it (also on the spawn path).
    length = self._len
    state = dict(self.__dict__)
    if get_spawning_popen() is None:
        state.pop("_len_value", None)
        state["len__context"] = length
    return state

def __setstate__(self, state):
    """Restore the instance from ``state``.

    If ``state`` carries a ``"len__context"`` integer, a new ``mp.Value``
    holding it is stored under ``"_len_value"`` before the instance
    ``__dict__`` is updated.
    """
    saved_length = state.pop("len__context", None)
    if saved_length is not None:
        state["_len_value"] = mp.Value("i", saved_length)
    self.__dict__.update(state)


def state_dict(self) -> Dict[str, Any]:
_storage = self._storage
if isinstance(_storage, torch.Tensor):
Expand Down
32 changes: 31 additions & 1 deletion torchrl/data/replay_buffers/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

import heapq
from abc import ABC, abstractmethod
from copy import copy
from typing import Any, Dict, Sequence

import numpy as np
import torch

from .storages import Storage

from torch import multiprocessing as mp
from multiprocessing.context import get_spawning_popen

class Writer(ABC):
"""A ReplayBuffer base Writer class."""
Expand Down Expand Up @@ -42,6 +44,34 @@ def state_dict(self) -> Dict[str, Any]:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Accept ``state_dict`` without reading it; always returns ``None``."""
    return None

@property
def _cursor(self):
    """Current cursor, read from a lazily created shared ``mp.Value``."""
    shared = self.__dict__.get("_cursor_value")
    if shared is None:
        # First access: create the shared integer cursor, starting at 0.
        shared = mp.Value("i", 0)
        self._cursor_value = shared
    return shared.value

@_cursor.setter
def _cursor(self, value):
    # Store ``value`` in the shared cursor, creating the cursor on first use.
    shared = self.__dict__.get("_cursor_value")
    if shared is None:
        shared = mp.Value("i", 0)
        self._cursor_value = shared
    shared.value = value

def __getstate__(self):
    """Return the picklable state of this writer.

    When the writer is not being pickled for a spawned child process
    (``get_spawning_popen()`` returns ``None``), the shared
    ``_cursor_value`` is replaced by its plain integer value under the
    ``"cursor__context"`` key. Otherwise the shared value is kept in the
    state so that it is passed on as-is.

    Returns:
        dict: a shallow copy of ``__dict__`` adjusted as described above.
    """
    # Read the cursor first: the property lazily creates ``_cursor_value``,
    # so the copy taken below always contains it (also on the spawn path).
    cursor = self._cursor
    state = copy(self.__dict__)
    if get_spawning_popen() is None:
        state.pop("_cursor_value", None)
        state["cursor__context"] = cursor
    return state

def __setstate__(self, state):
    """Restore the writer from ``state``.

    If ``state`` carries a ``"cursor__context"`` integer, a new ``mp.Value``
    holding it is stored under ``"_cursor_value"`` before the instance
    ``__dict__`` is updated.
    """
    saved_cursor = state.pop("cursor__context", None)
    if saved_cursor is not None:
        state["_cursor_value"] = mp.Value("i", saved_cursor)
    self.__dict__.update(state)

class RoundRobinWriter(Writer):
"""A RoundRobin Writer class for composable replay buffers."""
Expand Down

0 comments on commit bc360aa

Please sign in to comment.