From 93e9e30dc437d6159eaaa97c1f3fcca1859ef5f5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 17:28:04 +0100 Subject: [PATCH] [Feature] Span slice indices on the left and on the right (#2107) --- test/test_rb.py | 53 +++++++++++++++++++ torchrl/data/replay_buffers/samplers.py | 70 ++++++++++++++++++++++--- 2 files changed, 116 insertions(+), 7 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index 10d71d87f89..9582738617c 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -2175,6 +2175,59 @@ def test_slice_sampler_without_replacement( done_recon = info[("next", "truncated")] | info[("next", "terminated")] assert done_recon.view(num_slices, -1)[:, -1].all() + def test_slice_sampler_left_right(self): + torch.manual_seed(0) + data = TensorDict( + {"obs": torch.arange(1, 11).repeat(10), "eps": torch.arange(100) // 10 + 1}, + [100], + ) + + for N in (2, 4): + rb = TensorDictReplayBuffer( + sampler=SliceSampler(num_slices=10, traj_key="eps", span=(N, N)), + batch_size=50, + storage=LazyMemmapStorage(100), + ) + rb.extend(data) + + for _ in range(10): + sample = rb.sample() + sample = split_trajectories(sample) + assert (sample["next", "truncated"].squeeze(-1).sum(-1) == 1).all() + assert ((sample["obs"] == 0).sum(-1) <= N).all(), sample["obs"] + assert ((sample["eps"] == 0).sum(-1) <= N).all() + for i in range(sample.shape[0]): + curr_eps = sample[i]["eps"] + curr_eps = curr_eps[curr_eps != 0] + assert curr_eps.unique().numel() == 1 + + def test_slice_sampler_left_right_ndim(self): + torch.manual_seed(0) + data = TensorDict( + {"obs": torch.arange(1, 11).repeat(12), "eps": torch.arange(120) // 10 + 1}, + [120], + ) + data = data.reshape(4, 30) + + for N in (2, 4): + rb = TensorDictReplayBuffer( + sampler=SliceSampler(num_slices=10, traj_key="eps", span=(N, N)), + batch_size=50, + storage=LazyMemmapStorage(100, ndim=2), + ) + rb.extend(data) + + for _ in range(10): + sample = rb.sample() + sample = split_trajectories(sample) + assert 
(sample["next", "truncated"].squeeze(-1).sum(-1) <= 1).all() + assert ((sample["obs"] == 0).sum(-1) <= N).all(), sample["obs"] + assert ((sample["eps"] == 0).sum(-1) <= N).all() + for i in range(sample.shape[0]): + curr_eps = sample[i]["eps"] + curr_eps = curr_eps[curr_eps != 0] + assert curr_eps.unique().numel() == 1 + def test_slicesampler_strictlength(self): torch.manual_seed(0) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index b81f790fbfa..bbd86655dc6 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -674,6 +674,13 @@ class SliceSampler(Sampler): the :meth:`~sample` method will be compiled with :func:`~torch.compile`. Keyword arguments can also be passed to torch.compile with this arg. Defaults to ``False``. + span (bool, int, Tuple[bool | int, bool | int], optional): if provided, the sampled + trajectory will span across the left and/or the right. This means that possibly + fewer elements will be provided than what was required. A boolean value means + that at least one element will be sampled per trajectory. An integer `i` means + that at least `slice_len - i` samples will be gathered for each sampled trajectory. + Using tuples allows a fine grained control over the span on the left (beginning + of the stored trajectory) and on the right (end of the stored trajectory). .. 
note:: To recover the trajectory splits in the storage, :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first @@ -753,6 +760,7 @@ def __init__( truncated_key: NestedKey | None = ("next", "truncated"), strict_length: bool = True, compile: bool | dict = False, + span: bool | Tuple[bool | int, bool | int] = False, ): self.num_slices = num_slices self.slice_len = slice_len @@ -763,6 +771,11 @@ def __init__( self._fetch_traj = True self.strict_length = strict_length self._cache = {} + + if isinstance(span, bool): + span = (span, span) + self.span = span + if trajectories is not None: if traj_key is not None or end_key: raise RuntimeError( @@ -916,6 +929,7 @@ def _end_to_start_stop(end, length): return start_idx, stop_idx, lengths def _start_to_end(self, st: torch.Tensor, length: int): + arange = torch.arange(length, device=st.device, dtype=st.dtype) ndims = st.shape[-1] - 1 if st.ndim else 0 if ndims: @@ -1128,14 +1142,55 @@ def _get_index( storage_length: int, traj_idx: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, dict]: + # end_point is the last possible index for start + last_indexable_start = lengths[traj_idx] - seq_length + 1 + if not self.span[1]: + end_point = last_indexable_start + elif self.span[1] is True: + end_point = lengths[traj_idx] + 1 + else: + span_left = self.span[1] + if span_left >= seq_length: + raise ValueError( + "The right and left span must be strictly lower than the sequence length" + ) + end_point = lengths[traj_idx] - span_left + + if not self.span[0]: + start_point = 0 + elif self.span[0] is True: + start_point = -seq_length + 1 + else: + span_right = self.span[0] + if span_right >= seq_length: + raise ValueError( + "The right and left span must be strictly lower than the sequence length" + ) + start_point = -span_right + relative_starts = ( - ( - torch.rand(num_slices, device=lengths.device) - * (lengths[traj_idx] - seq_length + 1) - ) - .floor() - .to(start_idx.dtype) - ) + torch.rand(num_slices, 
device=lengths.device) * (end_point - start_point) ).floor().to(start_idx.dtype) + start_point + + if self.span[0]: + out_of_traj = relative_starts < 0 + if out_of_traj.any(): + # a negative start means sampling fewer elements + seq_length = torch.where( + ~out_of_traj, seq_length, seq_length + relative_starts + ) + relative_starts = torch.where(~out_of_traj, relative_starts, 0) + if self.span[1]: + out_of_traj = relative_starts + seq_length > lengths[traj_idx] + if out_of_traj.any(): + # a slice overshooting the trajectory end means sampling fewer elements + # print('seq_length before', seq_length) + # print('relative_starts', relative_starts) + seq_length = torch.minimum( + seq_length, lengths[traj_idx] - relative_starts + ) + # print('seq_length after', seq_length) + starts = torch.cat( [ (start_idx[traj_idx, 0] + relative_starts).unsqueeze(1), @@ -1143,6 +1198,7 @@ def _get_index( ], 1, ) + index = self._tensor_slices_from_startend(seq_length, starts, storage_length) if self.truncated_key is not None: truncated_key = self.truncated_key