diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst
index cd2b71a0922..98d2d40cd5c 100644
--- a/docs/source/reference/data.rst
+++ b/docs/source/reference/data.rst
@@ -43,6 +43,7 @@ We also give users the ability to compose a replay buffer using the following co
     Writer
     RoundRobinWriter
     TensorDictRoundRobinWriter
+    TensorDictMaxValueWriter
 
 Storage choice is very influential on replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes.
 :class:`LazyMemmapStorage` is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations for improved node failure recovery.
diff --git a/test/test_rb.py b/test/test_rb.py
index 8e894f45c3e..0b465c0b424 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -38,7 +38,10 @@
     ListStorage,
     TensorStorage,
 )
-from torchrl.data.replay_buffers.writers import RoundRobinWriter
+from torchrl.data.replay_buffers.writers import (
+    RoundRobinWriter,
+    TensorDictMaxValueWriter,
+)
 from torchrl.envs.transforms.transforms import (
     BinarizeReward,
     CatFrames,
@@ -1209,6 +1212,65 @@ def test_load_state_dict(self, storage_in, storage_out, init_out):
     assert (s.exclude("index") == 1).all()
 
 
+@pytest.mark.parametrize("size", [20, 25, 30])
+@pytest.mark.parametrize("batch_size", [1, 10, 15])
+@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)])
+def test_max_value_writer(size, batch_size, reward_ranges):
+    rb = TensorDictReplayBuffer(
+        storage=LazyTensorStorage(size),
+        sampler=SamplerWithoutReplacement(),
+        batch_size=batch_size,
+        writer=TensorDictMaxValueWriter(rank_key="key"),
+    )
+
+    max_reward1, max_reward2, max_reward3 = reward_ranges
+
+    td = TensorDict(
+        {
+            "key": torch.clamp_max(torch.rand(size), max=max_reward1),
+            "obs": torch.rand(size),
+        },
+        batch_size=size,
+        device="cpu",
+    )
+    rb.extend(td)
+    sample = rb.sample()
+    assert (sample.get("key") <= max_reward1).all()
+    assert (0 <= sample.get("key")).all()
+    assert len(sample.get("index").unique()) == len(sample.get("index"))
+
+    td = TensorDict(
+        {
+            "key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2),
+            "obs": torch.rand(size),
+        },
+        batch_size=size,
+        device="cpu",
+    )
+    rb.extend(td)
+    sample = rb.sample()
+    assert (sample.get("key") <= max_reward2).all()
+    assert (max_reward1 <= sample.get("key")).all()
+    assert len(sample.get("index").unique()) == len(sample.get("index"))
+
+    td = TensorDict(
+        {
+            "key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3),
+            "obs": torch.rand(size),
+        },
+        batch_size=size,
+        device="cpu",
+    )
+
+    for sample in td:
+        rb.add(sample)
+
+    sample = rb.sample()
+    assert (sample.get("key") <= max_reward3).all()
+    assert (max_reward2 <= sample.get("key")).all()
+    assert len(sample.get("index").unique()) == len(sample.get("index"))
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index 4c90146ac7f..9a12749b482 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -14,6 +14,7 @@
     ReplayBuffer,
     RoundRobinWriter,
     Storage,
+    TensorDictMaxValueWriter,
     TensorDictPrioritizedReplayBuffer,
     TensorDictReplayBuffer,
     TensorDictRoundRobinWriter,
diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py
index e27dd8572d8..6be80e26c1f 100644
--- a/torchrl/data/replay_buffers/__init__.py
+++ b/torchrl/data/replay_buffers/__init__.py
@@ -23,4 +23,9 @@
     Storage,
     TensorStorage,
 )
-from .writers import RoundRobinWriter, TensorDictRoundRobinWriter, Writer
+from .writers import (
+    RoundRobinWriter,
+    TensorDictMaxValueWriter,
+    TensorDictRoundRobinWriter,
+    Writer,
+)
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 5d21d202eae..cfc6c90bb2c 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -718,12 +718,13 @@ def add(self, data: TensorDictBase) -> int:
             data_add = data
         index = super()._add(data_add)
-        if is_tensor_collection(data_add):
-            data_add.set("index", index)
+        if index is not None:
+            if is_tensor_collection(data_add):
+                data_add.set("index", index)
 
-        # priority = self._get_priority(data)
-        # if priority:
-        self.update_tensordict_priority(data_add)
+            # priority = self._get_priority(data)
+            # if priority:
+            self.update_tensordict_priority(data_add)
         return index
 
     def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index 49244262f4e..8a71c5927a1 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import heapq
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Sequence
 
@@ -92,3 +93,128 @@ def extend(self, data: Sequence) -> torch.Tensor:
         data["index"] = index
         self._storage[index] = data
         return index
+
+
+class TensorDictMaxValueWriter(Writer):
+    """A Writer class for composable replay buffers that keeps the top elements based on some ranking key.
+
+    If rank_key is not provided, the key will be ``("next", "reward")``.
+
+    Examples:
+        >>> import torch
+        >>> from tensordict import TensorDict
+        >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
+        >>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+        >>> rb = TensorDictReplayBuffer(
+        ...     storage=LazyTensorStorage(1),
+        ...     sampler=SamplerWithoutReplacement(),
+        ...     batch_size=1,
+        ...     writer=TensorDictMaxValueWriter(rank_key="key"),
+        ... )
+        >>> td = TensorDict({
+        ...     "key": torch.tensor(range(10)),
+        ...     "obs": torch.tensor(range(10))
+        ... }, batch_size=10)
+        >>> rb.extend(td)
+        >>> print(rb.sample().get("obs").item())
+        9
+        >>> td = TensorDict({
+        ...     "key": torch.tensor(range(10, 20)),
+        ...     "obs": torch.tensor(range(10, 20))
+        ... }, batch_size=10)
+        >>> rb.extend(td)
+        >>> print(rb.sample().get("obs").item())
+        19
+        >>> td = TensorDict({
+        ...     "key": torch.tensor(range(10)),
+        ...     "obs": torch.tensor(range(10))
+        ... }, batch_size=10)
+        >>> rb.extend(td)
+        >>> print(rb.sample().get("obs").item())
+        19
+    """
+
+    def __init__(self, rank_key=None, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self._cursor = 0
+        self._current_top_values = []
+        self._rank_key = rank_key
+        if self._rank_key is None:
+            self._rank_key = ("next", "reward")
+
+    def get_insert_index(self, data: Any) -> int:
+        """Returns the index where the data should be inserted, or ``None`` if it should not be inserted."""
+        if data.batch_dims > 1:
+            raise RuntimeError(
+                "Expected input tensordict to have no more than 1 dimension, got "
+                f"tensordict.batch_size = {data.batch_size}"
+            )
+
+        ret = None
+        rank_data = data.get(("_data", self._rank_key))
+
+        if rank_data is None:
+            raise KeyError(f"Rank key {self._rank_key} not found in data.")
+
+        # If there is a time dimension, sum the rank key along it.
+        rank_data = rank_data.sum(-1).item()
+
+        # If the buffer is not full, add the data
+        if len(self._current_top_values) < self._storage.max_size:
+
+            ret = self._cursor
+            self._cursor = (self._cursor + 1) % self._storage.max_size
+
+            # Add new reward to the heap
+            heapq.heappush(self._current_top_values, (rank_data, ret))
+
+        # If the buffer is full, check if the new data is better than the worst data in the buffer
+        elif rank_data > self._current_top_values[0][0]:
+
+            # Retrieve the position of the smallest value
+            min_sample = heapq.heappop(self._current_top_values)
+            ret = min_sample[1]
+
+            # Add new reward to the heap
+            heapq.heappush(self._current_top_values, (rank_data, ret))
+
+        return ret
+
+    def add(self, data: Any) -> int:
+        """Inserts a single element of data at an appropriate index, and returns that index.
+
+        The data passed to this module should be structured as :obj:`[]` or :obj:`[T]`, where
+        :obj:`T` is the time dimension. If the data is a trajectory, the rank key will be summed
+        over the time dimension.
+        """
+        index = self.get_insert_index(data)
+        if index is not None:
+            data.set("index", index)
+            self._storage[index] = data
+        return index
+
+    def extend(self, data: Sequence) -> None:
+        """Inserts a series of data points at appropriate indices.
+
+        The data passed to this module should be structured as :obj:`[B]` or :obj:`[B, T]`, where :obj:`B` is
+        the batch size and :obj:`T` the time dimension. If the data is a trajectory, the rank key will be
+        summed over the time dimension.
+        """
+        data_to_replace = {}
+        for i, sample in enumerate(data):
+            index = self.get_insert_index(sample)
+            if index is not None:
+                data_to_replace[index] = i
+
+        # Replace the data in the storage all at once
+        if len(data_to_replace) > 0:
+            keys, values = zip(*data_to_replace.items())
+            index = data.get("index")
+            values = list(values)
+            keys = index[values] = torch.tensor(keys, dtype=index.dtype)
+            data.set("index", index)
+            self._storage[keys] = data[values]
+
+    def _empty(self) -> None:
+        self._cursor = 0
+        self._current_top_values = []
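
For context, a minimal usage sketch of the new writer, adapted from the docstring example above. It is not part of the patch; the storage size of 4 and the "episode_reward" key name are illustrative, and any scalar-valued key present in the stored tensordicts could be passed as rank_key instead:

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictMaxValueWriter, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

    # Buffer that retains only the 4 entries with the highest "episode_reward" seen so far.
    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(4),
        sampler=SamplerWithoutReplacement(),
        batch_size=4,
        writer=TensorDictMaxValueWriter(rank_key="episode_reward"),
    )

    td = TensorDict(
        {
            "episode_reward": torch.arange(10, dtype=torch.float32),
            "obs": torch.arange(10),
        },
        batch_size=[10],
    )
    rb.extend(td)

    # Only the four highest-reward entries (obs 6, 7, 8, 9) should remain in the buffer.
    print(rb.sample().get("obs").sort().values)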