Commit

init

vmoens committed Dec 5, 2023
1 parent 841f8d9 commit f44741b
Showing 5 changed files with 457 additions and 108 deletions.
304 changes: 198 additions & 106 deletions test/test_rb.py
@@ -395,6 +395,114 @@ class TC:
        storage.set(0, data)
        assert storage.get(0).device.type == device_storage.type

@pytest.mark.parametrize("storage_in", ["tensor", "memmap"])
@pytest.mark.parametrize("storage_out", ["tensor", "memmap"])
@pytest.mark.parametrize("init_out", [True, False])
def test_storage_state_dict(self, storage_in, storage_out, init_out):
buffer_size = 100
if storage_in == "memmap":
storage_in = LazyMemmapStorage(buffer_size, device="cpu")
elif storage_in == "tensor":
storage_in = LazyTensorStorage(buffer_size, device="cpu")
if storage_out == "memmap":
storage_out = LazyMemmapStorage(buffer_size, device="cpu")
elif storage_out == "tensor":
storage_out = LazyTensorStorage(buffer_size, device="cpu")

replay_buffer = TensorDictReplayBuffer(
pin_memory=False, prefetch=3, storage=storage_in, batch_size=3
)
# fill replay buffer with random data
transition = TensorDict(
{
"observation": torch.ones(1, 4),
"action": torch.ones(1, 2),
"reward": torch.ones(1, 1),
"dones": torch.ones(1, 1),
"next": {"observation": torch.ones(1, 4)},
},
batch_size=1,
)
for _ in range(3):
replay_buffer.extend(transition)

state_dict = replay_buffer.state_dict()

new_replay_buffer = TensorDictReplayBuffer(
pin_memory=False,
prefetch=3,
storage=storage_out,
batch_size=state_dict["_batch_size"],
)
if init_out:
new_replay_buffer.extend(transition)

new_replay_buffer.load_state_dict(state_dict)
s = new_replay_buffer.sample()
assert (s.exclude("index") == 1).all()
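Outside of pytest, the round-trip this test exercises looks roughly like the following sketch (illustrative only; it assumes the public torchrl.data imports and mirrors the keys that state_dict() is shown to return in replay_buffers.py further down):

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, LazyTensorStorage, TensorDictReplayBuffer

# Build and fill a buffer backed by a tensor storage.
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(100), batch_size=3)
rb.extend(TensorDict({"obs": torch.ones(1, 4)}, batch_size=1))

# state_dict() nests the storage, sampler, writer, transform and batch-size state.
sd = rb.state_dict()

# Restore into a buffer backed by a different storage type (memmap here).
rb2 = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(100), batch_size=sd["_batch_size"]
)
rb2.load_state_dict(sd)
assert (rb2.sample().get("obs") == 1).all()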

@pytest.mark.parametrize("device_data", get_default_devices())
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
@pytest.mark.parametrize("data_type", ["tensor", "tc", "td"])
@pytest.mark.parametrize("isinit", [True, False])
def test_storage_dumps_loads(
self, device_data, storage_type, data_type, isinit, tmpdir
):
dir_rb = tmpdir / "rb"
dir_save = tmpdir / "save"
dir_rb.mkdir()
dir_save.mkdir()
torch.manual_seed(0)

@tensorclass
class TC:
tensor: torch.Tensor
td: TensorDict
text: str

if data_type == "tensor":
data = torch.randint(10, (3,), device=device_data)
elif data_type == "td":
data = TensorDict(
{
"a": torch.randint(10, (3,), device=device_data),
"b": TensorDict(
{"c": torch.randint(10, (3,), device=device_data)},
batch_size=[3],
),
},
batch_size=[3],
device=device_data,
)
elif data_type == "tc":
data = TC(
tensor=torch.randint(10, (3,), device=device_data),
td=TensorDict(
{"c": torch.randint(10, (3,), device=device_data)}, batch_size=[3]
),
text="some text",
batch_size=[3],
device=device_data,
)
else:
raise NotImplementedError
if storage_type in (LazyMemmapStorage,):
storage = storage_type(max_size=10, scratch_dir=dir_rb)
else:
storage = storage_type(max_size=10)
storage.set(range(3), data)
storage.dumps(dir_save)
storage_recover = storage_type(max_size=10)
if isinit:
storage_recover.set(range(3), data.zero_())
storage_recover.loads(dir_save)
if data_type == "tensor":
torch.testing.assert_close(storage._storage, storage_recover._storage)
else:
assert_allclose_td(storage._storage, storage_recover._storage)
if data == "tc":
assert storage._storage.text == storage_recover._storage.text
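In user code, the storage-level round-trip reduces to a few lines. A minimal sketch, assuming the torchrl.data imports and a writable path (the /tmp path is illustrative):

import torch
from torchrl.data import LazyMemmapStorage

# Fill a storage and serialize it to a directory.
storage = LazyMemmapStorage(max_size=10)
storage.set(range(3), torch.randint(10, (3,)))
storage.dumps("/tmp/storage_save")

# A freshly constructed storage can be rehydrated without pre-initialization
# (the isinit=False path of the test above).
recovered = LazyMemmapStorage(max_size=10)
recovered.loads("/tmp/storage_save")
torch.testing.assert_close(storage._storage, recovered._storage)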


@pytest.mark.parametrize("max_size", [1000])
@pytest.mark.parametrize("shape", [[3, 4]])
@@ -1174,125 +1282,109 @@ def test_replay_buffer_iter(size, drop_last):
    assert i == (size - 1) // 3


(Removed: the old class TestStateDict, whose test_load_state_dict body is identical to the new test_storage_state_dict above, and the old module-level test_max_value_writer function, superseded by the parametrized class below.)

@pytest.mark.parametrize("size", [20, 25, 30])
@pytest.mark.parametrize("batch_size", [1, 10, 15])
@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)])
@pytest.mark.parametrize("device", get_default_devices())
class TestMaxValueWriter:
    def test_max_value_writer(self, size, batch_size, reward_ranges, device):
        torch.manual_seed(0)
        rb = TensorDictReplayBuffer(
            storage=LazyTensorStorage(size, device=device),
            sampler=SamplerWithoutReplacement(),
            batch_size=batch_size,
            writer=TensorDictMaxValueWriter(rank_key="key"),
        )

        max_reward1, max_reward2, max_reward3 = reward_ranges

        td = TensorDict(
            {
                "key": torch.clamp_max(torch.rand(size), max=max_reward1),
                "obs": torch.rand(size),
            },
            batch_size=size,
            device=device,
        )
        rb.extend(td)
        sample = rb.sample()
        assert (sample.get("key") <= max_reward1).all()
        assert (0 <= sample.get("key")).all()
        assert len(sample.get("index").unique()) == len(sample.get("index"))

        td = TensorDict(
            {
                "key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2),
                "obs": torch.rand(size),
            },
            batch_size=size,
            device=device,
        )
        rb.extend(td)
        sample = rb.sample()
        assert (sample.get("key") <= max_reward2).all()
        assert (max_reward1 <= sample.get("key")).all()
        assert len(sample.get("index").unique()) == len(sample.get("index"))

        td = TensorDict(
            {
                "key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3),
                "obs": torch.rand(size),
            },
            batch_size=size,
            device=device,
        )

        for sample in td:
            rb.add(sample)

        sample = rb.sample()
        assert (sample.get("key") <= max_reward3).all()
        assert (max_reward2 <= sample.get("key")).all()
        assert len(sample.get("index").unique()) == len(sample.get("index"))

        # Finally, test the case when no obs should be added
        td = TensorDict(
            {
                "key": torch.zeros(size),
                "obs": torch.rand(size),
            },
            batch_size=size,
            device=device,
        )
        rb.extend(td)
        sample = rb.sample()
        assert (sample.get("key") != 0).all()

    def test_max_value_writer_serialize(
        self, size, batch_size, reward_ranges, device, tmpdir
    ):
        rb = TensorDictReplayBuffer(
            storage=LazyTensorStorage(size, device=device),
            sampler=SamplerWithoutReplacement(),
            batch_size=batch_size,
            writer=TensorDictMaxValueWriter(rank_key="key"),
        )

        max_reward1, max_reward2, max_reward3 = reward_ranges

        td = TensorDict(
            {
                "key": torch.clamp_max(torch.rand(size), max=max_reward1),
                "obs": torch.rand(size),
            },
            batch_size=size,
            device=device,
        )
        rb.extend(td)
        rb._writer.dumps(tmpdir)
        other = TensorDictMaxValueWriter(rank_key="key")
        other.loads(tmpdir)
        assert len(rb._writer._current_top_values) == len(other._current_top_values)
        torch.testing.assert_close(
            torch.tensor(rb._writer._current_top_values),
            torch.tensor(other._current_top_values),
        )
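For context, TensorDictMaxValueWriter only writes an entry when its rank_key value beats the lowest value currently stored, so a full buffer converges to the top-ranked items; that is what the assertions above check. A minimal sketch of that behaviour (illustrative; it assumes extend may be called with more items than the storage holds, which the tests here do not exercise directly):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictMaxValueWriter, TensorDictReplayBuffer

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(4),
    writer=TensorDictMaxValueWriter(rank_key="key"),
    batch_size=2,
)
# Extend with keys 0..7 into a size-4 storage: once full, each new item
# displaces the current minimum, so only the four largest keys (4..7)
# should remain.
rb.extend(TensorDict({"key": torch.arange(8.0)}, batch_size=[8]))
assert (rb.sample().get("key") >= 4).all()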


class TestMultiProc:
29 changes: 29 additions & 0 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -4,9 +4,11 @@
# LICENSE file in the root directory of this source tree.

import collections
import json
import threading
import warnings
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
@@ -230,15 +232,42 @@ def state_dict(self) -> Dict[str, Any]:
"_storage": self._storage.state_dict(),
"_sampler": self._sampler.state_dict(),
"_writer": self._writer.state_dict(),
"_transforms": self._transform.state_dict(),
"_batch_size": self._batch_size,
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self._storage.load_state_dict(state_dict["_storage"])
self._sampler.load_state_dict(state_dict["_sampler"])
self._writer.load_state_dict(state_dict["_writer"])
self._transform.load_state_dict(state_dict["_transforms"])
self._batch_size = state_dict["_batch_size"]

def dumps(self, path):
path = Path(path).absolute()
path.mkdir(exist_ok=True)
self._storage.dumps(path / "storage")
self._sampler.dumps(path / "sampler")
self._writer.dumps(path / "writer")
# fall back on state_dict for transforms
transform_sd = self._transform.state_dict()
if transform_sd:
torch.save(transform_sd, path / "transform.t")
with open(path / "buffer_metadata.json", "w") as file:
json.dump({"batch_size": self._batch_size}, file)

def loads(self, path):
path = Path(path).absolute()
self._storage.loads(path / "storage")
self._sampler.loads(path / "sampler")
self._writer.loads(path / "writer")
# fall back on state_dict for transforms
if (path / "transform.t").exists():
self._transform.load_state_dict(torch.load(path / "transform.t"))
with open(path / "buffer_metadata.json", "r") as file:
metadata = json.load(file)
self._batch_size = metadata["batch_size"]
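Taken together, checkpointing a whole buffer to disk and restoring it might look like the sketch below (illustrative; the checkpoint path is made up, and the batch size is recovered from buffer_metadata.json by loads):

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(100), batch_size=8)
rb.extend(TensorDict({"obs": torch.randn(16, 4)}, batch_size=[16]))

# Writes storage/, sampler/ and writer/ subdirectories plus
# buffer_metadata.json (and transform.t, if the transform has state).
rb.dumps("/tmp/rb_checkpoint")

rb2 = TensorDictReplayBuffer(storage=LazyMemmapStorage(100))
rb2.loads("/tmp/rb_checkpoint")
assert rb2.sample().batch_size[0] == 8  # batch size restored from metadata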

    def add(self, data: Any) -> int:
        """Add a single element to the replay buffer.
(Diff for the remaining three changed files not shown.)