From bc360aa0bd82e8e52d33305a09d8d803f6f59493 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 30 Nov 2023 16:53:50 +0000
Subject: [PATCH] init

---
 test/test_rb.py                         | 33 +++++++++++++++++++++++++
 torchrl/data/replay_buffers/storages.py | 33 +++++++++++++++++++++++++
 torchrl/data/replay_buffers/writers.py  | 32 +++++++++++++++++++++++-
 3 files changed, 97 insertions(+), 1 deletion(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index c68c623300b..19e36d30558 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -63,6 +63,7 @@
     UnsqueezeTransform,
     VecNorm,
 )
+from torch import multiprocessing as mp
 
 OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
 _has_tv = importlib.util.find_spec("torchvision") is not None
@@ -1285,6 +1286,38 @@ def test_max_value_writer(size, batch_size, reward_ranges, device):
     assert (sample.get("key") != 0).all()
 
 
+class TestMultiProc:
+    @staticmethod
+    def worker(rb, q0, q1):
+        td = TensorDict({"a": torch.ones(10)}, [10])
+        rb.extend(td)
+        q0.put("extended")
+        extended = q1.get(timeout=5)
+        assert extended == "extended"
+        assert len(rb) == 21, len(rb)
+        assert (rb["_data", "a"][:9] == 2).all()
+        q0.put("finish")
+
+    def test_multiproc_rb(self):
+        rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21))
+        td = TensorDict({"a": torch.zeros(10)}, [10])
+        rb.extend(td)
+        q0 = mp.Queue(1)
+        q1 = mp.Queue(1)
+        proc = mp.Process(target=self.worker, args=(rb, q0, q1))
+        proc.start()
+        extended = q0.get(timeout=100)
+        assert extended == "extended"
+        assert len(rb) == 20
+        assert (rb["_data", "a"][10:20] == 1).all()
+        td = TensorDict({"a": torch.zeros(10) + 2}, [10])
+        rb.extend(td)
+        q1.put("extended")
+        finish = q0.get(timeout=5)
+        assert finish == "finish"
+        proc.join()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 9c8417b9c97..439b2a1cd47 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -18,6 +18,8 @@
 from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE
 from torchrl.data.replay_buffers.utils import INT_CLASSES
+from torch import multiprocessing as mp
+from multiprocessing.context import get_spawning_popen
 
 try:
     from torchsnapshot.serialization import tensor_from_memoryview
 
@@ -259,6 +261,37 @@ def __init__(self, storage, max_size=None, device="cpu"):
         )
         self._storage = storage
+
+    @property
+    def _len(self):
+        _len_value = self.__dict__.get('_len_value', None)
+        if _len_value is None:
+            _len_value = self._len_value = mp.Value('i', 0)
+        return _len_value.value
+
+    @_len.setter
+    def _len(self, value):
+        _len_value = self.__dict__.get('_len_value', None)
+        if _len_value is None:
+            _len_value = self._len_value = mp.Value('i', 0)
+        _len_value.value = value
+
+    def __getstate__(self):
+        state = copy(self.__dict__)
+        if get_spawning_popen() is None:
+            len = self._len
+            del state["_len_value"]
+            state["len__context"] = len
+        return state
+
+    def __setstate__(self, state):
+        len = state.pop("len__context", None)
+        if len is not None:
+            _len_value = mp.Value('i', len)
+            state["_len_value"] = _len_value
+        self.__dict__.update(state)
+
+
 
     def state_dict(self) -> Dict[str, Any]:
         _storage = self._storage
         if isinstance(_storage, torch.Tensor):
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index cf78e0a0d99..b56699cc9f0 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -5,13 +5,15 @@
 
 import heapq
 from abc import ABC, abstractmethod
+from copy import copy
 from typing import Any, Dict, Sequence
 
 import numpy as np
 import torch
 
 from .storages import Storage
-
+from torch import multiprocessing as mp
+from multiprocessing.context import get_spawning_popen
 
 class Writer(ABC):
     """A ReplayBuffer base Writer class."""
@@ -42,6 +44,34 @@ def state_dict(self) -> Dict[str, Any]:
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         return
 
+    @property
+    def _cursor(self):
+        _cursor_value = self.__dict__.get('_cursor_value', None)
+        if _cursor_value is None:
+            _cursor_value = self._cursor_value = mp.Value('i', 0)
+        return _cursor_value.value
+
+    @_cursor.setter
+    def _cursor(self, value):
+        _cursor_value = self.__dict__.get('_cursor_value', None)
+        if _cursor_value is None:
+            _cursor_value = self._cursor_value = mp.Value('i', 0)
+        _cursor_value.value = value
+
+    def __getstate__(self):
+        state = copy(self.__dict__)
+        if get_spawning_popen() is None:
+            cursor = self._cursor
+            del state["_cursor_value"]
+            state["cursor__context"] = cursor
+        return state
+
+    def __setstate__(self, state):
+        cursor = state.pop("cursor__context", None)
+        if cursor is not None:
+            _cursor_value = mp.Value('i', cursor)
+            state["_cursor_value"] = _cursor_value
+        self.__dict__.update(state)
 
 class RoundRobinWriter(Writer):
     """A RoundRobin Writer class for composable replay buffers."""