From 86915fe373a7e247a67d6af3a6fb870c0756e928 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 29 Mar 2023 11:01:43 +0100
Subject: [PATCH 1/6] init

---
 torchrl/data/replay_buffers/replay_buffers.py | 426 +++++++++++++++---
 torchrl/data/replay_buffers/samplers.py       |   4 +
 2 files changed, 374 insertions(+), 56 deletions(-)

diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 0a20dc6dff7..1e824817b92 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -82,32 +82,80 @@ def decorated_fun(self, *args, **kwargs):
 class ReplayBuffer:
     """A generic, composable replay buffer class.
 
+    All arguments are keyword-only arguments.
+
     Args:
         storage (Storage, optional): the storage to be used. If none is provided
-            a default ListStorage with max_size of 1_000 will be created.
-        sampler (Sampler, optional): the sampler to be used. If none is provided
-            a default RandomSampler() will be used.
+            a default :class:`torchrl.data.replay_buffers.ListStorage` with
+            ``max_size`` of ``1_000`` will be created.
+        sampler (Sampler, optional): the sampler to be used. If none is provided,
+            a default :class:`torchrl.data.replay_buffers.RandomSampler`
+            will be used.
         writer (Writer, optional): the writer to be used. If none is provided
-            a default RoundRobinWriter() will be used.
+            a default :class:`torchrl.data.replay_buffers.RoundRobinWriter`
+            will be used.
         collate_fn (callable, optional): merges a list of samples to form a
             mini-batch of Tensor(s)/outputs. Used when using batched
-            loading from a map-style dataset.
+            loading from a map-style dataset. The default value will be decided
+            based on the storage type.
         pin_memory (bool): whether pin_memory() should be called on the rb
             samples.
         prefetch (int, optional): number of next batches to be prefetched
-            using multithreading.
-        transform (Transform, optional): Transform to be executed when sample() is called.
-            To chain transforms use the :obj:`Compose` class.
+            using multithreading. Defaults to None (no prefetching).
+        transform (Transform, optional): Transform to be executed when
+            sample() is called.
+            To chain transforms use the :class:`torchrl.envs.Compose` class.
             Transforms should be used with :class:`tensordict.TensorDict`
             content. If used with other structures, the transforms should be
-            encoded with a `"data"` leading key that will be used to
+            encoded with a ``"data"`` leading key that will be used to
             construct a tensordict from the non-tensordict content.
-        batch_size (int, optional): the batch size to be used when sample() is called.
+        batch_size (int, optional): the batch size to be used when sample() is
+            called.
+
+            .. note::
+              The batch-size can be specified at construction time via the
+              ``batch_size`` argument, or at sampling time. The former should
+              be preferred whenever the batch-size is consistent across the
+              experiment. If the batch-size is likely to change, it can be
+              passed to the :meth:`~.sample` method. This option is
+              incompatible with prefetching (since this requires the
+              batch-size to be known in advance) as well as with samplers
+              that have a ``drop_last`` argument.
+
+    Examples:
+        >>> import torch
+        >>>
+        >>> from torchrl.data import ReplayBuffer, ListStorage
+        >>>
+        >>> torch.manual_seed(0)
+        >>> rb = ReplayBuffer(
+        ...     storage=ListStorage(max_size=1000),
+        ...     batch_size=5,
+        ... )
+        >>> # populate the replay buffer
+        >>> data = range(10)
+        >>> rb.extend(data)
+        >>> # sample will return as many elements as specified in the constructor
+        >>> sample = rb.sample()
+        >>> print(sample)
+        tensor([4, 9, 3, 0, 3])
+        >>> # Passing the batch-size to the sample method overrides the one in the constructor
+        >>> sample = rb.sample(batch_size=3)
+        >>> print(sample)
+        tensor([9, 7, 3])
+        >>> # one can sample using the ``sample`` method or iterate over the buffer
+        >>> for i, batch in enumerate(rb):
+        ...     print(i, batch)
+        ...     if i == 3:
+        ...         break
+        0 tensor([7, 3, 1, 6, 6])
+        1 tensor([9, 8, 6, 6, 8])
+        2 tensor([4, 3, 6, 9, 1])
+        3 tensor([4, 4, 1, 9, 9])
     """
 
     def __init__(
         self,
+        *,
        storage: Optional[Storage] = None,
        sampler: Optional[Sampler] = None,
        writer: Optional[Writer] = None,
@@ -147,10 +195,21 @@ def __init__(
             transform.eval()
         self._transform = transform
 
-        if batch_size is None:
-            warnings.warn(
-                "Constructing replay buffer without specifying behaviour is no longer "
-                "recommended, and will be deprecated in the future."
+        if batch_size is None and prefetch:
+            raise ValueError(
+                "Dynamic batch-size specification is incompatible "
+                "with multithreaded sampling. "
+                "When using prefetch, the batch-size must be specified in "
+                "advance. "
+            )
+        if (
+            batch_size is None
+            and hasattr(self._sampler, "drop_last")
+            and self._sampler.drop_last
+        ):
+            raise ValueError(
+                "Samplers with drop_last=True must work with a predictable batch-size. "
+                "Please pass the batch-size to the ReplayBuffer constructor."
             )
         self._batch_size = batch_size
 
@@ -247,6 +306,7 @@ def update_priority(
     def _sample(self, batch_size: int) -> Tuple[Any, dict]:
         with self._replay_lock:
             index, info = self._sampler.sample(self._storage, batch_size)
+            info["index"] = index
             data = self._storage[index]
         if not isinstance(index, INT_CLASSES):
             data = self._collate_fn(data)
@@ -279,17 +339,26 @@ def sample(
             A batch of data selected in the replay buffer.
             A tuple containing this batch and info if return_info flag is set to True.
         """
-        if batch_size is not None:
+        if (
+            batch_size is not None
+            and self._batch_size is not None
+            and batch_size != self._batch_size
+        ):
             warnings.warn(
-                "batch_size argument in sample has been deprecated. Set the batch_size "
-                "when constructing the replay buffer instead."
+                f"Got conflicting batch_sizes in constructor ({self._batch_size}) "
+                f"and `sample` ({batch_size}). Refer to the ReplayBuffer documentation "
+                "for a proper usage of the batch-size arguments. "
+                "The batch-size provided to the sample method "
+                "will prevail."
             )
-        elif self._batch_size is not None:
+        elif batch_size is None and self._batch_size is not None:
             batch_size = self._batch_size
-        else:
+        elif batch_size is None:
            raise RuntimeError(
                "batch_size not specified. You can specify the batch_size when "
-                "constructing the replay buffer"
+                "constructing the replay buffer, or pass it to the sample method. "
+                "Refer to the ReplayBuffer documentation "
+                "for a proper usage of the batch-size arguments."
            )
         if not self._prefetch:
             ret = self._sample(batch_size)
@@ -336,9 +405,12 @@ def insert_transform(self, index: int, transform: "Transform") -> None:  # noqa-
         self._transform.insert(index, transform)
 
     def __iter__(self):
+        if self._sampler.ran_out:
+            self._sampler.ran_out = False
         if self._batch_size is None:
             raise RuntimeError(
-                "batch_size was not specified during construction of the replay buffer"
+                "Cannot iterate over the replay buffer. "
" + "Batch_size was not specified during construction of the replay buffer." ) while not self._sampler.ran_out: data = self.sample() @@ -348,6 +420,8 @@ def __iter__(self): class PrioritizedReplayBuffer(ReplayBuffer): """Prioritized replay buffer. + All arguments are keyword-only arguments. + Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." @@ -359,22 +433,75 @@ class PrioritizedReplayBuffer(ReplayBuffer): beta (float): importance sampling negative exponent. eps (float): delta added to the priorities to ensure that the buffer does not contain null priorities. - dtype (torch.dtype): type of the data. Can be torch.float or torch.double. storage (Storage, optional): the storage to be used. If none is provided - a default ListStorage with max_size of 1_000 will be created. + a default :class:`torchrl.data.replay_buffers.ListStorage` with + ``max_size`` of ``1_000`` will be created. collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used when using batched - loading from a map-style dataset. + loading from a map-style dataset. The default value will be decided + based on the storage type. pin_memory (bool): whether pin_memory() should be called on the rb samples. prefetch (int, optional): number of next batches to be prefetched - using multithreading. - transform (Transform, optional): Transform to be executed when sample() is called. - To chain transforms use the :obj:`Compose` class. + using multithreading. Defaults to None (no prefetching). + transform (Transform, optional): Transform to be executed when + sample() is called. + To chain transforms use the :class:`torchrl.envs.Compose` class. + Transforms should be used with :class:`tensordict.TensorDict` + content. If used with other structures, the transforms should be + encoded with a ``"data"`` leading key that will be used to + construct a tensordict from the non-tensordict content. + batch_size (int, optional): the batch size to be used when sample() is + called. + .. note:: + The batch-size can be specified at construction time via the + ``batch_size`` argument, or at sampling time. The former should + be preferred whenever the batch-size is consistent across the + experiment. If the batch-size is likely to change, it can be + passed to the :meth:`~.sample` method. This option is + incompatible with prefetching (since this requires to know the + batch-size in advance) as well as with samplers that have a + ``drop_last`` argument. + + .. note:: + Generic prioritized replay buffers (ie. non-tensordict backed) require + calling :meth:`~.sample` with the ``return_info`` argument set to + ``True`` to have access to the indices, and hence update the priority. + Using :class:`tensordict.TensorDict` and the related + :class:`torchrl.data.TensorDictPrioritizedReplayBuffer` simplifies this + process. 
+
+    Examples:
+        >>> import torch
+        >>>
+        >>> from torchrl.data import ListStorage, PrioritizedReplayBuffer
+        >>>
+        >>> torch.manual_seed(0)
+        >>>
+        >>> rb = PrioritizedReplayBuffer(alpha=0.7, beta=0.9, storage=ListStorage(10))
+        >>> data = range(10)
+        >>> rb.extend(data)
+        >>> sample = rb.sample(3)
+        >>> print(sample)
+        tensor([1, 0, 1])
+        >>> # get the info to find what the indices are
+        >>> sample, info = rb.sample(5, return_info=True)
+        >>> print(sample, info)
+        tensor([2, 7, 4, 3, 5]) {'_weight': array([1., 1., 1., 1., 1.], dtype=float32), 'index': array([2, 7, 4, 3, 5])}
+        >>> # update priority
+        >>> priority = torch.ones(5) * 5
+        >>> rb.update_priority(info["index"], priority)
+        >>> # and now a new sample, the weights should be updated
+        >>> sample, info = rb.sample(5, return_info=True)
+        >>> print(sample, info)
+        tensor([2, 5, 2, 2, 5]) {'_weight': array([0.36278465, 0.36278465, 0.36278465, 0.36278465, 0.36278465],
+              dtype=float32), 'index': array([2, 5, 2, 2, 5])}
+
     """
 
     def __init__(
         self,
+        *,
         alpha: float,
         beta: float,
         eps: float = 1e-8,
@@ -401,15 +528,114 @@ def __init__(
 
 
 class TensorDictReplayBuffer(ReplayBuffer):
-    """TensorDict-specific wrapper around the ReplayBuffer class.
+    """TensorDict-specific wrapper around the :class:`torchrl.data.ReplayBuffer` class.
+
+    All arguments are keyword-only arguments.
 
     Args:
-        priority_key (str): the key at which priority is assumed to be stored
-            within TensorDicts added to this ReplayBuffer.
+        storage (Storage, optional): the storage to be used. If none is provided
+            a default :class:`torchrl.data.replay_buffers.ListStorage` with
+            ``max_size`` of ``1_000`` will be created.
+        sampler (Sampler, optional): the sampler to be used. If none is provided
+            a default :class:`torchrl.data.replay_buffers.RandomSampler`
+            will be used.
+        writer (Writer, optional): the writer to be used. If none is provided
+            a default :class:`torchrl.data.replay_buffers.RoundRobinWriter`
+            will be used.
+        collate_fn (callable, optional): merges a list of samples to form a
+            mini-batch of Tensor(s)/outputs. Used when using batched
+            loading from a map-style dataset. The default value will be decided
+            based on the storage type.
+        pin_memory (bool): whether pin_memory() should be called on the rb
+            samples.
+        prefetch (int, optional): number of next batches to be prefetched
+            using multithreading. Defaults to None (no prefetching).
+        transform (Transform, optional): Transform to be executed when
+            sample() is called.
+            To chain transforms use the :class:`torchrl.envs.Compose` class.
+            Transforms should be used with :class:`tensordict.TensorDict`
+            content. If used with other structures, the transforms should be
+            encoded with a ``"data"`` leading key that will be used to
+            construct a tensordict from the non-tensordict content.
+        batch_size (int, optional): the batch size to be used when sample() is
+            called.
+
+            .. note::
+              The batch-size can be specified at construction time via the
+              ``batch_size`` argument, or at sampling time. The former should
+              be preferred whenever the batch-size is consistent across the
+              experiment. If the batch-size is likely to change, it can be
+              passed to the :meth:`~.sample` method. This option is
+              incompatible with prefetching (since this requires the
+              batch-size to be known in advance) as well as with samplers
+              that have a ``drop_last`` argument.
+        priority_key (str, optional): the key at which priority is assumed to
+            be stored within TensorDicts added to this ReplayBuffer.
+            This is to be used when the sampler is of type
+            :class:`torchrl.data.PrioritizedSampler`.
+            Defaults to ``"td_error"``.
+
+    Examples:
+        >>> import torch
+        >>>
+        >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
+        >>> from tensordict import TensorDict
+        >>>
+        >>> torch.manual_seed(0)
+        >>>
+        >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10), batch_size=5)
+        >>> data = TensorDict({"a": torch.ones(10, 3), ("b", "c"): torch.zeros(10, 1, 1)}, [10])
+        >>> rb.extend(data)
+        >>> sample = rb.sample(3)
+        >>> # samples keep track of the index
+        >>> print(sample)
+        TensorDict(
+            fields={
+                a: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                b: TensorDict(
+                    fields={
+                        c: Tensor(shape=torch.Size([3, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+                    batch_size=torch.Size([3]),
+                    device=cpu,
+                    is_shared=False),
+                index: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.int32, is_shared=False)},
+            batch_size=torch.Size([3]),
+            device=cpu,
+            is_shared=False)
+        >>> # we can iterate over the buffer
+        >>> for i, data in enumerate(rb):
+        ...     print(i, data)
+        ...     if i == 2:
+        ...         break
+        0 TensorDict(
+            fields={
+                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                b: TensorDict(
+                    fields={
+                        c: Tensor(shape=torch.Size([5, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+                    batch_size=torch.Size([5]),
+                    device=cpu,
+                    is_shared=False),
+                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int32, is_shared=False)},
+            batch_size=torch.Size([5]),
+            device=cpu,
+            is_shared=False)
+        1 TensorDict(
+            fields={
+                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                b: TensorDict(
+                    fields={
+                        c: Tensor(shape=torch.Size([5, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+                    batch_size=torch.Size([5]),
+                    device=cpu,
+                    is_shared=False),
+                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int32, is_shared=False)},
+            batch_size=torch.Size([5]),
+            device=cpu,
+            is_shared=False)
+
     """
 
-    def __init__(self, *args, priority_key: str = "td_error", **kw) -> None:
-        super().__init__(*args, **kw)
+    def __init__(self, *, priority_key: str = "td_error", **kw) -> None:
+        super().__init__(**kw)
         self.priority_key = priority_key
 
     def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
@@ -498,8 +724,8 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
     def sample(
         self,
         batch_size: Optional[int] = None,
-        include_info: bool = False,
         return_info: bool = False,
+        include_info: bool = None,
     ) -> TensorDictBase:
         """Samples a batch of data from the replay buffer.
 
         Args:
             batch_size (int, optional): size of data to be collected. If none
                 is provided, this method will sample a batch-size as indicated
                 by the sampler.
-            include_info (bool): whether to add info to the returned tensordict.
             return_info (bool): whether to return info. If True, the result
                 is a tuple (data, info). If False, the result is the data.
 
         Returns:
             A tensordict containing a batch of data selected in the replay
             buffer. A tuple containing this tensordict and info if return_info flag is set to True.
         """
+        if include_info is not None:
+            warnings.warn(
+                "include_info is going to be deprecated soon. "
+                "The default behaviour has changed to `include_info=True` "
+                "to avoid bugs linked to wrongly preassigned values in the "
+                "output tensordict."
+            )
+
         data, info = super().sample(batch_size, return_info=True)
-        if include_info:
+        if include_info in (True, None):
             for k, v in info.items():
-                data.set(k, torch.tensor(v, device=data.device), inplace=True)
+                data.set(k, torch.tensor(v, device=data.device))
 
         if "_batch_size" in data.keys():
             # we need to reset the batch-size
             shape = data.pop("_batch_size")
@@ -539,40 +772,119 @@
 class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
-    """TensorDict-specific wrapper around the PrioritizedReplayBuffer class.
+    """TensorDict-specific wrapper around the :class:`torchrl.data.PrioritizedReplayBuffer` class.
 
-    This class returns tensordicts with a new key "index" that represents
+    All arguments are keyword-only arguments.
+
+    This class returns tensordicts with a new key ``"index"`` that represents
     the index of each element in the replay buffer. It also provides the
-    'update_tensordict_priority' method that only requires for the
+    :meth:`~.update_tensordict_priority` method that only requires the
     tensordict to be passed to it with its new priority value.
 
     Args:
-        alpha (float): exponent α determines how much prioritization is
-            used, with α = 0 corresponding to the uniform case.
+        alpha (float): exponent α determines how much prioritization is used,
+            with α = 0 corresponding to the uniform case.
         beta (float): importance sampling negative exponent.
-        priority_key (str, optional): key where the priority value can be
-            found in the stored tensordicts. Default is :obj:`"td_error"`
-        eps (float, optional): delta added to the priorities to ensure that the
-            buffer does not contain null priorities.
-        dtype (torch.dtype): type of the data. Can be torch.float or torch.double.
+        eps (float): delta added to the priorities to ensure that the buffer
+            does not contain null priorities.
         storage (Storage, optional): the storage to be used. If none is provided
-            a default ListStorage with max_size of 1_000 will be created.
+            a default :class:`torchrl.data.replay_buffers.ListStorage` with
+            ``max_size`` of ``1_000`` will be created.
         collate_fn (callable, optional): merges a list of samples to form a
-            mini-batch of Tensor(s)/outputs. Used when using batched loading
-            from a map-style dataset.
+            mini-batch of Tensor(s)/outputs. Used when using batched
+            loading from a map-style dataset. The default value will be decided
+            based on the storage type.
-        pin_memory (bool, optional): whether pin_memory() should be called on
-            the rb samples. Default is ``False``.
+        pin_memory (bool): whether pin_memory() should be called on the rb
+            samples.
         prefetch (int, optional): number of next batches to be prefetched
-            using multithreading.
-        transform (Transform, optional): Transform to be executed when sample() is called.
-            To chain transforms use the :obj:`Compose` class.
+            using multithreading. Defaults to None (no prefetching).
+        transform (Transform, optional): Transform to be executed when
+            sample() is called.
+            To chain transforms use the :class:`torchrl.envs.Compose` class.
+            Transforms should be used with :class:`tensordict.TensorDict`
+            content. If used with other structures, the transforms should be
+            encoded with a ``"data"`` leading key that will be used to
+            construct a tensordict from the non-tensordict content.
+        batch_size (int, optional): the batch size to be used when sample() is
+            called.
+
+            .. note::
+              The batch-size can be specified at construction time via the
+              ``batch_size`` argument, or at sampling time. The former should
+              be preferred whenever the batch-size is consistent across the
+              experiment. If the batch-size is likely to change, it can be
+              passed to the :meth:`~.sample` method. This option is
+              incompatible with prefetching (since this requires the
+              batch-size to be known in advance) as well as with samplers
+              that have a ``drop_last`` argument.
+        priority_key (str, optional): the key at which priority is assumed to
+            be stored within TensorDicts added to this ReplayBuffer.
+            This is to be used when the sampler is of type
+            :class:`torchrl.data.PrioritizedSampler`.
+            Defaults to ``"td_error"``.
         reduction (str, optional): the reduction method for multidimensional
             tensordicts (ie stored trajectories). Can be one of "max", "min",
             "median" or "mean".
+
+    Examples:
+        >>> import torch
+        >>>
+        >>> from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer
+        >>> from tensordict import TensorDict
+        >>>
+        >>> torch.manual_seed(0)
+        >>>
+        >>> rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, storage=LazyTensorStorage(10), batch_size=5)
+        >>> data = TensorDict({"a": torch.ones(10, 3), ("b", "c"): torch.zeros(10, 3, 1)}, [10])
+        >>> rb.extend(data)
+        >>> print("len of rb", len(rb))
+        len of rb 10
+        >>> sample = rb.sample(5)
+        >>> print(sample)
+        TensorDict(
+            fields={
+                _weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
+                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                b: TensorDict(
+                    fields={
+                        c: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+                    batch_size=torch.Size([5]),
+                    device=cpu,
+                    is_shared=False),
+                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False)},
+            batch_size=torch.Size([5]),
+            device=cpu,
+            is_shared=False)
+        >>> print("index", sample["index"])
+        index tensor([9, 5, 2, 2, 7])
+        >>> # give a high priority to these samples...
+        >>> sample.set("td_error", 100*torch.ones(sample.shape))
+        >>> # and update priority
+        >>> rb.update_tensordict_priority(sample)
+        >>> # the new sample should have a high overlap with the previous one
+        >>> sample = rb.sample(5)
+        >>> print(sample)
+        TensorDict(
+            fields={
+                _weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
+                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                b: TensorDict(
+                    fields={
+                        c: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+                    batch_size=torch.Size([5]),
+                    device=cpu,
+                    is_shared=False),
+                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False)},
+            batch_size=torch.Size([5]),
+            device=cpu,
+            is_shared=False)
+        >>> print("index", sample["index"])
+        index tensor([2, 5, 5, 9, 7])
+
     """
 
     def __init__(
         self,
+        *,
         alpha: float,
         beta: float,
         priority_key: str = "td_error",
@@ -612,10 +924,12 @@ def __init__(self, *args, **kwargs):
     def sample(
         self,
         batch_size: Optional[int] = None,
-        include_info: bool = False,
+        include_info: bool = None,
         return_info: bool = False,
     ) -> TensorDictBase:
-        return super().sample(batch_size, include_info, return_info)
+        return super().sample(
+            batch_size=batch_size, include_info=include_info, return_info=return_info
+        )
 
     def add(self, data: TensorDictBase) -> int:
         return super().add(data)
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 9fd0fab8af4..564b1197c2c 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -137,6 +137,10 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
     def ran_out(self):
         return self._ran_out
 
+    @ran_out.setter
+    def ran_out(self, value):
+        self._ran_out = value
+
 
 class PrioritizedSampler(Sampler):
     """Prioritized sampler for replay buffer.
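[Editor's note — usage sketch, not part of the patch series] Patch 1 makes the buffer
constructors keyword-only and lets the batch-size be bound at construction time. A minimal
sketch of the resulting API (names follow the docstring examples above; this is an
illustration, not the patch's own code):

    import torch
    from torchrl.data import ListStorage, ReplayBuffer

    torch.manual_seed(0)
    # keyword-only construction; batch_size fixed once for the experiment
    rb = ReplayBuffer(storage=ListStorage(max_size=100), batch_size=4)
    rb.extend(range(20))
    batch = rb.sample()                # uses the construction-time batch_size (4)
    smaller = rb.sample(batch_size=2)  # overrides it; warns if the two conflict

Binding the batch-size early is what allows prefetching and drop_last samplers, which both
need to know the sample shape in advance.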
From d23af8b1760a0879af6592922db350dc23771d7d Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 29 Mar 2023 11:06:51 +0100
Subject: [PATCH 2/6] tests

---
 test/test_rb.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/test/test_rb.py b/test/test_rb.py
index 6c829ca5668..8d75d287236 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -580,6 +580,54 @@ def test_index(self, rbtype, storage, size, prefetch):
         assert b
 
 
+def test_multi_loops():
+    """Tests that one can iterate multiple times over a buffer without repetition."""
+    rb = ReplayBuffer(
+        batch_size=5, storage=ListStorage(10), sampler=SamplerWithoutReplacement()
+    )
+    rb.extend(torch.zeros(10))
+    for i, d in enumerate(rb):  # noqa: B007
+        assert (d == 0).all()
+    assert i == 1
+    for i, d in enumerate(rb):  # noqa: B007
+        assert (d == 0).all()
+    assert i == 1
+
+
+def test_batch_errors():
+    """Tests error messages related to batch-size."""
+    rb = ReplayBuffer(
+        storage=ListStorage(10), sampler=SamplerWithoutReplacement(drop_last=False)
+    )
+    rb.extend(torch.zeros(10))
+    rb.sample(3)  # that works
+    with pytest.raises(
+        RuntimeError,
+        match="Cannot iterate over the replay buffer. Batch_size was not specified",
+    ):
+        for _ in rb:
+            pass
+    with pytest.raises(RuntimeError, match="batch_size not specified"):
+        rb.sample()
+    with pytest.raises(ValueError, match="Samplers with drop_last=True"):
+        ReplayBuffer(
+            storage=ListStorage(10), sampler=SamplerWithoutReplacement(drop_last=True)
+        )
+    # that works
+    ReplayBuffer(
+        storage=ListStorage(10),
+    )
+    rb = ReplayBuffer(
+        storage=ListStorage(10),
+        sampler=SamplerWithoutReplacement(drop_last=False),
+        batch_size=3,
+    )
+    rb.extend(torch.zeros(10))
+    for _ in rb:
+        pass
+    rb.sample()
+
+
 @pytest.mark.parametrize("priority_key", ["pk", "td_error"])
 @pytest.mark.parametrize("contiguous", [True, False])
 @pytest.mark.parametrize("device", get_available_devices())
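[Editor's note — usage sketch, not part of the patch series] test_multi_loops exercises the
new ``ran_out`` setter from patch 1: ``__iter__`` now resets the sampler, so a buffer backed
by a without-replacement sampler can be swept repeatedly. A standalone sketch of the same
behaviour (the import path for SamplerWithoutReplacement is assumed from the test module):

    import torch
    from torchrl.data import ListStorage, ReplayBuffer
    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

    rb = ReplayBuffer(
        storage=ListStorage(10),
        sampler=SamplerWithoutReplacement(),
        batch_size=5,
    )
    rb.extend(torch.zeros(10))
    for sweep in range(2):        # two full passes; each yields 10 / 5 = 2 batches
        for batch in rb:          # iteration stops when the sampler runs out
            assert (batch == 0).all()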
From dec5c56ba3a7b6742fd233e8f8d1ac16655ca54c Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 29 Mar 2023 13:16:20 +0100
Subject: [PATCH 3/6] tests

---
 test/test_trainer.py    | 18 +++++++------
 test/test_transforms.py | 56 ++++++++++++++++++++---------------------
 2 files changed, 38 insertions(+), 36 deletions(-)

diff --git a/test/test_trainer.py b/test/test_trainer.py
index 1251d4edd48..533fd4f0b0d 100644
--- a/test/test_trainer.py
+++ b/test/test_trainer.py
@@ -203,7 +203,9 @@ def test_rb_trainer(self, prioritized):
         S = 100
         storage = ListStorage(S)
         if prioritized:
-            replay_buffer = TensorDictPrioritizedReplayBuffer(1.1, 0.9, storage=storage)
+            replay_buffer = TensorDictPrioritizedReplayBuffer(
+                alpha=1.1, beta=0.9, storage=storage
+            )
         else:
             replay_buffer = TensorDictReplayBuffer(storage=storage)
 
@@ -260,8 +262,8 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type):
 
         if prioritized:
             replay_buffer = TensorDictPrioritizedReplayBuffer(
-                1.1,
-                0.9,
+                alpha=1.1,
+                beta=0.9,
                 storage=storage,
             )
         else:
@@ -293,7 +295,7 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type):
         trainer2 = mocking_trainer()
         if prioritized:
             replay_buffer2 = TensorDictPrioritizedReplayBuffer(
-                1.1, 0.9, storage=storage
+                alpha=1.1, beta=0.9, storage=storage
             )
         else:
             replay_buffer2 = TensorDictReplayBuffer(storage=storage)
@@ -398,8 +400,8 @@ def make_storage():
         storage = make_storage()
         if prioritized:
             replay_buffer = TensorDictPrioritizedReplayBuffer(
-                1.1,
-                0.9,
+                alpha=1.1,
+                beta=0.9,
                 storage=storage,
             )
         else:
@@ -430,8 +432,8 @@ def make_storage():
         storage2 = make_storage()
         if prioritized:
             replay_buffer2 = TensorDictPrioritizedReplayBuffer(
-                1.1,
-                0.9,
+                alpha=1.1,
+                beta=0.9,
                 storage=storage2,
             )
         else:
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 0b7d9391e6a..b28cb9a758a 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -281,7 +281,7 @@ def test_transform_rb(self):
         batch = [20]
         torch.manual_seed(0)
         br = BinarizeReward()
-        rb = ReplayBuffer(LazyTensorStorage(20))
+        rb = ReplayBuffer(storage=LazyTensorStorage(20))
         rb.append_transform(br)
         reward = torch.randn(*batch, 1, device=device)
         misc = torch.randn(*batch, 1, device=device)
@@ -419,7 +419,7 @@ def test_transform_rb(self):
         key_tensors = [key1_tensor, key2_tensor]
         td = TensorDict(dict(zip(keys, key_tensors)), batch_size, device=device)
         cat_frames = CatFrames(N=N, in_keys=keys, dim=dim)
-        rb = ReplayBuffer(LazyTensorStorage(20))
+        rb = ReplayBuffer(storage=LazyTensorStorage(20))
         rb.append_transform(cat_frames)
         rb.extend(td)
         with pytest.raises(
@@ -651,7 +651,7 @@ def test_transform_rb(self, model, device):
             out_keys=out_keys,
             tensor_pixels_keys=tensor_pixels_key,
         )
-        rb = ReplayBuffer(LazyTensorStorage(20))
+        rb = ReplayBuffer(storage=LazyTensorStorage(20))
         rb.append_transform(r3m)
         td = TensorDict({"pixels": torch.randint(255, (10, 244, 244, 3))}, [10])
         rb.extend(td)
@@ -1027,7 +1027,7 @@ def test_transform_env(self):
 
     def test_transform_rb(self):
         transform = StepCounter(10)
-        rb = ReplayBuffer(LazyTensorStorage(20))
+        rb = ReplayBuffer(storage=LazyTensorStorage(20))
         td = TensorDict({"a": torch.randn(10)}, [10])
         rb.extend(td)
         rb.append_transform(transform)
@@ -1345,7 +1345,7 @@ def test_transform_rb(self):
             dim=-1,
             del_keys=True,
         )
-        rb = ReplayBuffer(LazyTensorStorage(20))
+        rb = ReplayBuffer(storage=LazyTensorStorage(20))
         rb.append_transform(ct)
         td = (
             TensorDict(
@@ -1525,7 +1525,7 @@ def test_transform_rb(
             batch,
         )
         td.set("dont touch", dont_touch.clone())
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(cc)
         rb.extend(td)
         td = rb.sample(10)
@@ -1668,7 +1668,7 @@ def test_transform_model(self, include_forward):
 
     @pytest.mark.parametrize("include_forward", [True, False])
     def test_transform_rb(self, include_forward):
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         t = DiscreteActionProjection(7, 10, include_forward=include_forward)
         rb.append_transform(t)
         td = TensorDict(
@@ -1863,7 +1863,7 @@ def test_transform_model(self, dtype_fixture):  # noqa: F811
     def test_transform_rb(
         self,
     ):
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         t = DoubleToFloat(in_keys=["observation"], in_keys_inv=["action"])
         rb.append_transform(t)
         td = TensorDict(
@@ -2029,7 +2029,7 @@ def test_transform_model(self):
 
     def test_transform_rb(self):
         t = ExcludeTransform("a")
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
         td = TensorDict(
             {
@@ -2193,7 +2193,7 @@ def test_transform_model(self):
 
     def test_transform_rb(self):
         t = SelectTransform("b", "c")
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
         td = TensorDict(
             {
@@ -2377,7 +2377,7 @@ def test_transform_model(self, out_keys):
     def test_transform_rb(self, out_keys):
         t = FlattenObservation(-3, -1, out_keys=out_keys)
         td = TensorDict({"pixels": torch.randint(255, (10, 10, 3))}, []).expand(10)
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
         rb.extend(td)
         td = rb.sample(2)
@@ -2480,7 +2480,7 @@ def test_transform_model(self):
 
     def test_transform_rb(self):
         t = FrameSkipTransform(2)
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
         tensordict = TensorDict({"a": torch.zeros(10)}, [10])
         rb.extend(tensordict)
@@ -2678,7 +2678,7 @@ def test_transform_model(self, out_keys):
     @pytest.mark.parametrize("out_keys", [None, ["stuff"]])
     def test_transform_rb(self, out_keys):
         td = TensorDict({"pixels": torch.rand(3, 12, 12)}, []).expand(3)
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(GrayScale(out_keys=out_keys))
         rb.extend(td)
         r = rb.sample(3)
@@ -2751,7 +2751,7 @@ def test_transform_model(self):
 
     def test_transform_rb(self):
         t = NoopResetEnv()
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
         td = TensorDict({}, [10])
         rb.extend(td)
@@ -3025,7 +3025,7 @@ def test_transform_rb(self):
                 standard_normal=standard_normal,
             )
         )
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
 
         obs = torch.randn(7)
@@ -3449,7 +3449,7 @@ def test_transform_model(self):
 
     def test_transform_rb(self):
         t = Resize(20, 21, in_keys=["pixels"])
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
         td = TensorDict({"pixels": torch.randn(3, 32, 32)}, []).expand(10)
         rb.extend(td)
@@ -3527,7 +3527,7 @@ def test_transform_model(self):
 
     def test_transform_rb(self):
         t = RewardClipping(-0.1, 0.1)
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         td = TensorDict({"reward": torch.randn(10)}, []).expand(10)
         rb.append_transform(t)
         rb.extend(td)
@@ -3677,7 +3677,7 @@ def test_transform_rb(self, standard_normal):
         loc = 0.5
         scale = 1.5
         t = RewardScaling(0.5, 1.5, standard_normal=standard_normal)
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         reward = torch.randn(10)
         td = TensorDict({"reward": reward}, []).expand(10)
         rb.append_transform(t)
@@ -3768,7 +3768,7 @@ def test_transform_rb(
         self,
     ):
         t = RewardSum()
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         reward = torch.randn(10)
         td = TensorDict({("next", "reward"): reward}, []).expand(10)
         rb.append_transform(t)
@@ -4102,7 +4102,7 @@ def test_transform_rb(self, out_keys, unsqueeze_dim):
             out_keys=out_keys,
             allow_positive_dim=True,
         )
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
         td = TensorDict(
             {"observation": TensorDict({"stuff": torch.randn(3, 4)}, [3, 4])}, []
@@ -4349,7 +4349,7 @@ def test_transform_rb(self, out_keys):
             out_keys=out_keys,
             allow_positive_dim=True,
         )
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
         td = TensorDict(
             {"observation": TensorDict({"stuff": torch.randn(3, 1, 4)}, [3, 1, 4])}, []
@@ -4544,7 +4544,7 @@ def test_transform_model(self, out_keys):
     @pytest.mark.parametrize("out_keys", [None, ["stuff"]])
     def test_transform_rb(self, out_keys):
         t = ToTensorImage(in_keys=["pixels"], out_keys=out_keys)
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
         td = TensorDict({"pixels": torch.randint(255, (21, 22, 3))}, [])
         rb.extend(td.expand(10))
@@ -4587,7 +4587,7 @@ def test_transform_model(self):
     def test_transform_rb(self):
         batch_size = (2,)
         t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([*batch_size, 3]))
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
         td = TensorDict({"a": torch.zeros(())}, [])
         rb.extend(td.expand(10))
@@ -4882,7 +4882,7 @@ def test_transform_rb(self):
             in_keys=["observation"],
             T=3,
         )
-        rb = ReplayBuffer(LazyTensorStorage(20))
+        rb = ReplayBuffer(storage=LazyTensorStorage(20))
         rb.append_transform(t)
         rb.extend(td)
         with pytest.raises(
@@ -5010,7 +5010,7 @@ def test_transform_rb(self):
         action_dim = 5
         batch_size = (2,)
         t = gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=batch_size)
-        rb = ReplayBuffer(LazyTensorStorage(10))
+        rb = ReplayBuffer(storage=LazyTensorStorage(10))
         rb.append_transform(t)
         td = TensorDict({"a": torch.zeros(())}, [])
         rb.extend(td.expand(10))
@@ -5158,7 +5158,7 @@ def test_transform_rb(self, model, device):
             out_keys=out_keys,
             tensor_pixels_keys=tensor_pixels_key,
         )
-        rb = ReplayBuffer(LazyTensorStorage(20))
+        rb = ReplayBuffer(storage=LazyTensorStorage(20))
         rb.append_transform(vip)
         td = TensorDict({"pixels": torch.randint(255, (10, 244, 244, 3))}, [10])
         rb.extend(td)
@@ -6583,7 +6583,7 @@ def test_transform_rb(self, create_copy, inverse):
         else:
             t = RenameTransform(["a"], ["b"], ["a"], ["b"], create_copy=create_copy)
         tensordict = TensorDict({"b": torch.randn(())}, []).expand(10)
-        rb = ReplayBuffer(LazyTensorStorage(20))
+        rb = ReplayBuffer(storage=LazyTensorStorage(20))
         rb.append_transform(t)
         rb.extend(tensordict)
         assert "a" in rb._storage._storage.keys()
@@ -6679,7 +6679,7 @@ def test_transform_model(self):
     def test_transform_rb(self):
         batch = [1]
         device = "cpu"
-        rb = ReplayBuffer(LazyTensorStorage(20))
+        rb = ReplayBuffer(storage=LazyTensorStorage(20))
         rb.append_transform(InitTracker())
         reward = torch.randn(*batch, 1, device=device)
         misc = torch.randn(*batch, 1, device=device)

From 76120983cde50ca84e3d52071c81f184e83ed518 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 29 Mar 2023 13:18:29 +0100
Subject: [PATCH 4/6] tests

---
 test/test_rb_distributed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py
index 252913500f3..7443601c76d 100644
--- a/test/test_rb_distributed.py
+++ b/test/test_rb_distributed.py
@@ -53,7 +53,7 @@ def sample_from_buffer_remotely_returns_correct_tensordict_test(rank, name, worl
     _, inserted = _add_random_tensor_dict_to_buffer(buffer)
     sampled = _sample_from_buffer(buffer, 1)
     assert type(sampled) is type(inserted) is TensorDict
-    assert (sampled == inserted)["a"].item()
+    assert (sampled["a"] == inserted["a"]).all()
 
 
 @pytest.mark.skipif(

From c471b96a31d2ade23d50566c82674f5c4409e3eb Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 29 Mar 2023 14:03:05 +0100
Subject: [PATCH 5/6] fix examples

---
 torchrl/trainers/helpers/replay_buffer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchrl/trainers/helpers/replay_buffer.py b/torchrl/trainers/helpers/replay_buffer.py
index 4f9c48bf4b9..229a22cbe8e 100644
--- a/torchrl/trainers/helpers/replay_buffer.py
+++ b/torchrl/trainers/helpers/replay_buffer.py
@@ -35,6 +35,7 @@ def make_replay_buffer(
         sampler=sampler,
         pin_memory=device != torch.device("cpu"),
         prefetch=cfg.buffer_prefetch,
+        batch_size=cfg.batch_size,
     )
     return buffer
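[Editor's note — usage sketch, not part of the patch series] Patch 5 forwards the trainer
config's batch size to the buffer helper, so downstream trainers can call ``sample()`` with
no argument. A hedged sketch of the call site after the change (the helper's exact signature
and the config fields are assumptions inferred from the diff):

    from torchrl.trainers.helpers.replay_buffer import make_replay_buffer

    # cfg is assumed to be a hydra/omegaconf config carrying at least
    # buffer_size, buffer_prefetch, prb and (after this patch) batch_size.
    buffer = make_replay_buffer(device, cfg)
    batch = buffer.sample()  # no explicit batch-size needed anymore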
From 8984654d4bbaf1d833955af32116dace1e6359e4 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 29 Mar 2023 14:39:33 +0100
Subject: [PATCH 6/6] fix examples

---
 examples/discrete_sac/discrete_sac.py |  6 +++---
 examples/iql/iql_online.py            | 13 +++++++++----
 examples/td3/td3.py                   | 14 ++++++++++----
 3 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/examples/discrete_sac/discrete_sac.py b/examples/discrete_sac/discrete_sac.py
index 6fc101ff533..987571747f6 100644
--- a/examples/discrete_sac/discrete_sac.py
+++ b/examples/discrete_sac/discrete_sac.py
@@ -44,7 +44,7 @@ def make_replay_buffer(
     batch_size=256,
     buffer_scratch_dir="/tmp/",
     device="cpu",
-    make_replay_buffer=3,
+    prefetch=3,
 ):
     if prb:
         replay_buffer = TensorDictPrioritizedReplayBuffer(
@@ -52,7 +52,7 @@ def make_replay_buffer(
             beta=0.5,
             pin_memory=False,
             batch_size=batch_size,
-            prefetch=make_replay_buffer,
+            prefetch=prefetch,
             storage=LazyMemmapStorage(
                 buffer_size,
                 scratch_dir=buffer_scratch_dir,
@@ -63,7 +63,7 @@ def make_replay_buffer(
         replay_buffer = TensorDictReplayBuffer(
             pin_memory=False,
             batch_size=batch_size,
-            prefetch=make_replay_buffer,
+            prefetch=prefetch,
             storage=LazyMemmapStorage(
                 buffer_size,
                 scratch_dir=buffer_scratch_dir,
diff --git a/examples/iql/iql_online.py b/examples/iql/iql_online.py
index 4dcc5bea747..1512f471f10 100644
--- a/examples/iql/iql_online.py
+++ b/examples/iql/iql_online.py
@@ -36,33 +36,36 @@ def env_maker(env_name, frame_skip=1, device="cpu", from_pixels=False):
 
 def make_replay_buffer(
+    batch_size,
     prb=False,
     buffer_size=1000000,
     buffer_scratch_dir="/tmp/",
     device="cpu",
-    make_replay_buffer=3,
+    prefetch=3,
 ):
     if prb:
         replay_buffer = TensorDictPrioritizedReplayBuffer(
             alpha=0.7,
             beta=0.5,
             pin_memory=False,
-            prefetch=make_replay_buffer,
+            prefetch=prefetch,
             storage=LazyMemmapStorage(
                 buffer_size,
                 scratch_dir=buffer_scratch_dir,
                 device=device,
             ),
+            batch_size=batch_size,
         )
     else:
         replay_buffer = TensorDictReplayBuffer(
             pin_memory=False,
-            prefetch=make_replay_buffer,
+            prefetch=prefetch,
             storage=LazyMemmapStorage(
                 buffer_size,
                 scratch_dir=buffer_scratch_dir,
                 device=device,
             ),
+            batch_size=batch_size,
         )
     return replay_buffer
@@ -218,7 +221,9 @@ def env_factory(num_workers):
     collector.set_seed(cfg.seed)
 
     # Make Replay Buffer
-    replay_buffer = make_replay_buffer(buffer_size=cfg.buffer_size, device=device)
+    replay_buffer = make_replay_buffer(
+        buffer_size=cfg.buffer_size, device=device, batch_size=cfg.batch_size
+    )
 
     # Optimizers
     params = list(loss_module.parameters())
diff --git a/examples/td3/td3.py b/examples/td3/td3.py
index 659da599240..a285c29acef 100644
--- a/examples/td3/td3.py
+++ b/examples/td3/td3.py
@@ -60,33 +60,36 @@ def apply_env_transforms(env, reward_scaling=1.0):
 
 def make_replay_buffer(
+    batch_size,
     prb=False,
     buffer_size=1000000,
     buffer_scratch_dir="/tmp/",
     device="cpu",
-    make_replay_buffer=3,
+    prefetch=3,
 ):
     if prb:
         replay_buffer = TensorDictPrioritizedReplayBuffer(
             alpha=0.7,
             beta=0.5,
             pin_memory=False,
-            prefetch=make_replay_buffer,
+            prefetch=prefetch,
             storage=LazyMemmapStorage(
                 buffer_size,
                 scratch_dir=buffer_scratch_dir,
                 device=device,
             ),
+            batch_size=batch_size,
         )
     else:
         replay_buffer = TensorDictReplayBuffer(
             pin_memory=False,
-            prefetch=make_replay_buffer,
+            prefetch=prefetch,
             storage=LazyMemmapStorage(
                 buffer_size,
                 scratch_dir=buffer_scratch_dir,
                 device=device,
             ),
+            batch_size=batch_size,
        )
    return replay_buffer
@@ -239,7 +242,10 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     # Make Replay Buffer
     replay_buffer = make_replay_buffer(
-        prb=cfg.prb, buffer_size=cfg.buffer_size, device=device
+        batch_size=cfg.batch_size,
+        prb=cfg.prb,
+        buffer_size=cfg.buffer_size,
+        device=device,
     )
 
     # Optimizers
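[Editor's note — usage sketch, not part of the patch series] Patch 6 threads ``batch_size``
through the example scripts so the buffers are constructed with their sample size and
``sample()`` takes no argument. A standalone sketch of the resulting behaviour, inlining the
essentials of the updated helpers (storage size, batch size and the "obs" key are arbitrary
illustrations):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

    buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(1000),
        prefetch=3,              # prefetching is allowed because batch_size is fixed
        batch_size=256,
    )
    data = TensorDict({"obs": torch.randn(300, 4)}, [300])
    buffer.extend(data)
    sample = buffer.sample()     # the batch-size is bound at construction
    assert sample.shape == torch.Size([256])

Note the interplay with patch 1: passing ``prefetch`` without ``batch_size`` would now raise
a ValueError, which is why every example helper here gained a ``batch_size`` parameter.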