[Feature] Max Value Writer #1622

Merged
merged 22 commits on Oct 18, 2023
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -43,6 +43,7 @@ We also give users the ability to compose a replay buffer using the following components:
Writer
RoundRobinWriter
TensorDictRoundRobinWriter
TensorDictMaxValueWriter

The choice of storage has a strong impact on replay buffer sampling latency, especially in distributed reinforcement learning settings with large data volumes.
:class:`LazyMemmapStorage` is highly advised in distributed settings with shared storage, both for the lower serialisation cost of MemmapTensors and for the ability to specify file storage locations, which improves recovery from node failures.
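As a rough illustration (not part of this change; the ``scratch_dir`` path below is an arbitrary placeholder), a memmap-backed buffer might be assembled as follows:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
>>> # Memmap-backed storage: tensors live on disk instead of being serialised
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(10_000, scratch_dir="/tmp/rb"),
...     batch_size=16,
... )
>>> rb.extend(TensorDict({"obs": torch.randn(32, 4)}, batch_size=[32]))
>>> sample = rb.sample()  # reads directly from the memmap-backed storage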
58 changes: 57 additions & 1 deletion test/test_rb.py
@@ -36,7 +36,10 @@
ListStorage,
TensorStorage,
)
from torchrl.data.replay_buffers.writers import RoundRobinWriter
from torchrl.data.replay_buffers.writers import (
RoundRobinWriter,
TensorDictMaxValueWriter,
)
from torchrl.envs.transforms.transforms import (
BinarizeReward,
CatFrames,
@@ -1109,6 +1112,59 @@ def test_load_state_dict(self, storage_in, storage_out, init_out):
assert (s.exclude("index") == 1).all()


@pytest.mark.parametrize("size", [20, 25, 30])
@pytest.mark.parametrize("batch_size", [1, 10, 15])
@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)])
def test_max_value_writer(size, batch_size, reward_ranges):
rb = TensorDictReplayBuffer(
storage=LazyTensorStorage(size),
sampler=SamplerWithoutReplacement(),
batch_size=batch_size,
writer=TensorDictMaxValueWriter(rank_key="key"),
)

max_reward1, max_reward2, max_reward3 = reward_ranges

td = TensorDict(
{
"key": torch.clamp_max(torch.rand(size), max=max_reward1),
"obs": torch.tensor(torch.rand(size)),
},
batch_size=size,
)
rb.extend(td)
sample = rb.sample()
assert (sample.get("key") <= max_reward1).all()
assert (0 <= sample.get("key")).all()
assert len(sample.get("index").unique()) == len(sample.get("index"))

td = TensorDict(
{
"key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2),
"obs": torch.tensor(torch.rand(size)),
},
batch_size=size,
)
rb.extend(td)
sample = rb.sample()
assert (sample.get("key") <= max_reward2).all()
assert (max_reward1 <= sample.get("key")).all()
assert len(sample.get("index").unique()) == len(sample.get("index"))

td = TensorDict(
{
"key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3),
"obs": torch.tensor(torch.rand(size)),
},
batch_size=size,
)
rb.extend(td)
sample = rb.sample()
assert (sample.get("key") <= max_reward3).all()
assert (max_reward2 <= sample.get("key")).all()
assert len(sample.get("index").unique()) == len(sample.get("index"))


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
@@ -14,6 +14,7 @@
ReplayBuffer,
RoundRobinWriter,
Storage,
TensorDictMaxValueWriter,
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
TensorDictRoundRobinWriter,
7 changes: 6 additions & 1 deletion torchrl/data/replay_buffers/__init__.py
@@ -23,4 +23,9 @@
Storage,
TensorStorage,
)
from .writers import RoundRobinWriter, TensorDictRoundRobinWriter, Writer
from .writers import (
RoundRobinWriter,
TensorDictMaxValueWriter,
TensorDictRoundRobinWriter,
Writer,
)
111 changes: 111 additions & 0 deletions torchrl/data/replay_buffers/writers.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import heapq
from abc import ABC, abstractmethod
from typing import Any, Dict, Sequence

@@ -92,3 +93,113 @@ def extend(self, data: Sequence) -> torch.Tensor:
data["index"] = index
self._storage[index] = data
return index


class TensorDictMaxValueWriter(Writer):
"""A Writer class for composable replay buffers that keeps the top elements based on some ranking key.

If ``rank_key`` is not provided, it defaults to ``("next", "reward")``.

Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

>>> rb = TensorDictReplayBuffer(
... storage=LazyTensorStorage(1),
... sampler=SamplerWithoutReplacement(),
... batch_size=1,
... writer=TensorDictMaxValueWriter(rank_key="key"),
... )
>>> td = TensorDict({
... "key": torch.tensor(range(10)),
... "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
9

>>> td = TensorDict({
... "key": torch.tensor(range(10, 20)),
... "obs": torch.tensor(range(10, 20))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19

>>> td = TensorDict({
... "key": torch.tensor(range(10)),
... "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
"""

def __init__(self, rank_key=None, **kw) -> None:
super().__init__(**kw)
self._cursor = 0
self._current_top_values = []
self._rank_key = rank_key
if self._rank_key is None:
self._rank_key = ("next", "reward")

def get_insert_index(self, data: Any) -> int:
"""Returns the index where the data should be inserted, or None if it should not be inserted."""
ret = None
rank_data = data.get("_data").get(self._rank_key)

if rank_data is None:
raise ValueError(f"Rank key {self._rank_key} not found in data.")

# Sum the rank key, in case it is a whole trajectory. The key must be
# checked for presence before reducing it.
rank_data = rank_data.sum().item()
Contributor:

Is this safe? Maybe we should document the expected shapes for this class, e.g.

[B, T]

but not

[B1, B2, T]

Another option is to check the number of dimensions of the ranking key OR the name of the last dim of the input tensordict (which should be "time").

Not raising any exception and just doing a plain sum could lead to surprising results, I think.

Contributor Author (@albertbou92, Oct 18, 2023):

I added the first option. Since the ranking value has to be a single float, we only allow data of shape [] and [T] for the add method, and [B] and [B, T] for the extend method. If the data has a time dimension, we sum along it. If too many dimensions are provided, an error is raised.

I did not go for checking the dimension names because it seemed too restrictive. I don't think the time dimension is always labelled.

Contributor:

Not always, but mostly: if you get your data from env.rollout or a collector, it will be. If you then store that data in a replay buffer, it will keep the tag. But if you reshape or do other operations, it can go away.

# If the buffer is not full, add the data
if len(self._current_top_values) < self._storage.max_size:
ret = self._cursor
self._cursor = (self._cursor + 1) % self._storage.max_size

# Add new reward to the heap
heapq.heappush(self._current_top_values, (rank_data, ret))

# If the buffer is full, check if the new data is better than the worst data in the buffer
elif rank_data > self._current_top_values[0][0]:
# Retrieve the position of the smallest value
min_sample = heapq.heappop(self._current_top_values)
ret = min_sample[1]

# Add new reward to the heap
heapq.heappush(self._current_top_values, (rank_data, ret))

return ret
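For context (not part of the diff): heapq maintains a min-heap, so self._current_top_values[0] is always the lowest-ranked element currently kept. That makes the "is the new sample better than the worst stored one?" test O(1), with O(log N) pushes and pops. A toy illustration:

import heapq

heap = []
for value, index in [(0.3, 0), (0.9, 1), (0.1, 2)]:
    heapq.heappush(heap, (value, index))
print(heap[0])  # (0.1, 2): the smallest rank value sits at the root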

def add(self, data: Any) -> int:
"""Inserts a single element of data at an appropriate index, and returns that index."""
index = self.get_insert_index(data)
if index is not None:
data.set("index", index)
self._storage[index] = data
return index

def extend(self, data: Sequence) -> None:
"""Inserts a series of data points at appropriate indices."""
data_to_replace = {}
for i, sample in enumerate(data):
index = self.get_insert_index(sample)
if index is not None:
data_to_replace[index] = i

# Replace the data in the storage all at once
keys = list(data_to_replace.keys())
if len(keys) > 0:
values = list(data_to_replace.values())
data.get("index")[values].copy_(torch.tensor(keys))
self._storage[keys] = data[values]

def _empty(self) -> None:
self._cursor = 0
self._current_top_values = []
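A small end-to-end sketch of the trajectory case discussed in the review above (the "reward" key and the sizes are illustrative choices, not prescribed by the PR):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictMaxValueWriter, TensorDictReplayBuffer

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(2),
    batch_size=1,
    writer=TensorDictMaxValueWriter(rank_key="reward"),
)
# Four trajectories of length 5; extend() ranks each one by reward.sum().
td = TensorDict(
    {"reward": torch.arange(20, dtype=torch.float).view(4, 5)},
    batch_size=[4],
)
rb.extend(td)  # only the two trajectories with the largest sums are kept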