Commit 2714b62

fixes

vmoens committed Dec 4, 2023
1 parent 18c7f10 commit 2714b62
Showing 2 changed files with 105 additions and 59 deletions.
71 changes: 42 additions & 29 deletions test/test_libs.py
@@ -1964,38 +1964,51 @@ def test_d4rl_iteration(self, task, split_trajs):
         print(f"terminated test after {time.time()-t0}s")


+_MINARI_DATASETS = []
+
+
+def _minari_selected_datasets():
+    if not _has_minari:
+        return
+    global _MINARI_DATASETS
+    import minari
+
+    torch.manual_seed(0)
+
+    keys = list(minari.list_remote_datasets())
+    indices = torch.randperm(len(keys))[:10]
+    keys = [keys[idx] for idx in indices]
+    keys = [
+        key
+        for key in keys
+        if "=0.4" in minari.list_remote_datasets()[key]["minari_version"]
+    ]
+    assert len(keys) > 5
+    _MINARI_DATASETS += keys
+    print("_MINARI_DATASETS", _MINARI_DATASETS)
+
+
+_minari_selected_datasets()
+
+
 @pytest.mark.skipif(not _has_minari, reason="Minari not found")
+@pytest.mark.parametrize("split", [False, True])
+@pytest.mark.parametrize("selected_dataset", _MINARI_DATASETS)
 class TestMinari:
-    @pytest.fixture(scope="class")
-    def selected_datasets(self):
-        torch.manual_seed(0)
-        import minari
-
-        keys = list(minari.list_remote_datasets())
-        indices = torch.randperm(len(keys))[:10]
-        keys = [keys[idx] for idx in indices]
-        keys = [
-            key
-            for key in keys
-            if "=0.4" in minari.list_remote_datasets()[key]["minari_version"]
-        ]
-        assert len(keys) > 5
-        return keys
-
-    @pytest.mark.parametrize("split", [False, True])
-    def test_load(self, selected_datasets, split):
-        for dataset in selected_datasets:
-            print("dataset", dataset)
-            data = MinariExperienceReplay(dataset, batch_size=32, split_trajs=split)
-            t0 = time.time()
-            for i, sample in enumerate(data):
-                t1 = time.time()
-                print(f"sampling time {1000 * (t1-t0): 4.4f}ms")
-                assert data.metadata["action_space"].is_in(sample["action"])
-                assert data.metadata["observation_space"].is_in(sample["observation"])
-                t0 = time.time()
-                if i == 10:
-                    break
+    def test_load(self, selected_dataset, split):
+        print("dataset", selected_dataset)
+        data = MinariExperienceReplay(
+            selected_dataset, batch_size=32, split_trajs=split
+        )
+        t0 = time.time()
+        for i, sample in enumerate(data):
+            t1 = time.time()
+            print(f"sampling time {1000 * (t1-t0): 4.4f}ms")
+            assert data.metadata["action_space"].is_in(sample["action"])
+            assert data.metadata["observation_space"].is_in(sample["observation"])
+            t0 = time.time()
+            if i == 10:
+                break


 @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found")
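
Note on the test change: pytest resolves @pytest.mark.parametrize arguments at collection time, before any fixture can run, so a class-scoped fixture cannot supply the dataset list; the commit therefore computes _MINARI_DATASETS once at module import. A minimal sketch of the same pattern, with hypothetical dataset names standing in for the Minari remote query:

import pytest

_DATASETS = []  # must be filled at import time, before test collection


def _select_datasets():
    # stand-in for the minari.list_remote_datasets() filtering above
    _DATASETS.extend(["dataset-a", "dataset-b"])


_select_datasets()


@pytest.mark.parametrize("dataset", _DATASETS)
def test_per_dataset(dataset):
    # each dataset becomes its own test node, so one slow or failing
    # dataset no longer hides the others behind a single looping test
    assert isinstance(dataset, str)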
93 changes: 63 additions & 30 deletions torchrl/data/datasets/minari_data.py
@@ -18,7 +18,7 @@

 import torch

-from tensordict import MemoryMappedTensor, PersistentTensorDict, TensorDict
+from tensordict import PersistentTensorDict, TensorDict
 from torchrl._utils import KeyDependentDefaultDict
 from torchrl.data.datasets.utils import _get_root_dir
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
@@ -223,19 +223,22 @@ def _download_and_preproc(self):
             total_steps = 0
             print("first read through data to create data structure...")
             h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5")
-            # Get the total number of steps for the dataset
-            total_steps += sum(
-                h5_data[episode, "actions"].shape[0] for episode in h5_data.keys()
-            )
             # populate the tensordict
             episode_dict = {}
             for episode_key, episode in h5_data.items():
                 episode_num = int(episode_key[len("episode_") :])
-                episode_dict[episode_num] = episode_key
+                episode_len = episode["actions"].shape[0]
+                episode_dict[episode_num] = (episode_key, episode_len)
+                # Get the total number of steps for the dataset
+                total_steps += episode_len
                 for key, val in episode.items():
                     match = _NAME_MATCH[key]
                     if key in ("observations", "state", "infos"):
-                        if not val.shape:
+                        if (
+                            not val.shape
+                        ):  # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
                             if val.is_empty():
                                 continue
                             val = _patch_info(val)
                         td_data.set(("next", match), torch.zeros_like(val[0]))
                         td_data.set(match, torch.zeros_like(val[0]))
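
The first pass above only assembles a one-step template (torch.zeros_like of each entry's first step); the hunk below then expands it to total_steps and materializes it on disk. A rough sketch of that two-pass allocation, with invented shapes and an invented path:

import torch
from tensordict import TensorDict

# one-step template, as produced by the zeros_like pass
template = TensorDict({"action": torch.zeros(6), "observation": torch.zeros(17)}, [])
total_steps = 1000
td_data = template.expand(total_steps)  # lazy broadcast, no data copied yet
td_data = td_data.memmap_like("/tmp/minari_demo")  # allocate disk-backed storage once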
@@ -248,39 +251,68 @@ def _download_and_preproc(self):
                         )

             # give it the proper size
-            td_data["next", "done"] = (
-                td_data["next", "truncated"] | td_data["next", "terminated"]
-            )
+            if "terminated" in td_data.keys():
+                td_data["done"] = td_data["truncated"] | td_data["terminated"]
             td_data = td_data.expand(total_steps)
             # save to designated location
             print(f"creating tensordict data in {self.data_path_root}: ", end="\t")
             td_data = td_data.memmap_like(self.data_path_root)
-            print(td_data)
+            print("tensordict structure:", td_data)

+            print(f"Reading data from {max(*episode_dict)} episodes")
             index = 0
             with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar:
                 # iterate over episodes and populate the tensordict
                 for episode_num in sorted(episode_dict):
-                    episode_key = episode_dict[episode_num]
+                    episode_key, steps = episode_dict[episode_num]
                     episode = h5_data.get(episode_key)
+                    idx = slice(index, (index + steps))
+                    data_view = td_data[idx]
                     for key, val in episode.items():
                         match = _NAME_MATCH[key]
                         if key in (
                             "observations",
                             "state",
                             "infos",
                         ):
-                            if not val.shape:
+                            if not val.shape or steps != val.shape[0] - 1:
                                 if val.is_empty():
                                     continue
                                 val = _patch_info(val)
-                            steps = val.shape[0] - 1
-                            td_data["next", match][index : (index + steps)] = val[1:]
-                            td_data[match][index : (index + steps)] = val[:-1]
+                            if steps != val.shape[0] - 1:
+                                raise RuntimeError(
+                                    f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}."
+                                )
+                            data_view["next", match].copy_(val[1:])
+                            data_view[match].copy_(val[:-1])
                         elif key not in ("terminations", "truncations", "rewards"):
-                            steps = val.shape[0]
-                            td_data[match][index : (index + val.shape[0])] = val
+                            if steps is None:
+                                steps = val.shape[0]
+                            else:
+                                if steps != val.shape[0]:
+                                    raise RuntimeError(
+                                        f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}."
+                                    )
+                            data_view[match].copy_(val)
                         else:
-                            steps = val.shape[0]
-                            td_data[("next", match)][
-                                index : (index + val.shape[0])
-                            ] = val.unsqueeze(-1)
+                            if steps is None:
+                                steps = val.shape[0]
+                            else:
+                                if steps != val.shape[0]:
+                                    raise RuntimeError(
+                                        f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}."
+                                    )
+                            data_view[("next", match)].copy_(val.unsqueeze(-1))
+                    data_view["next", "done"].copy_(
+                        data_view["next", "terminated"] | data_view["next", "truncated"]
+                    )
+                    if "done" in data_view.keys():
+                        data_view["done"].copy_(
+                            data_view["terminated"] | data_view["truncated"]
+                        )
                     if pbar is not None:
                         pbar.update(steps)
                         pbar.set_description(
@@ -289,11 +321,8 @@ def _download_and_preproc(self):
                     index += steps
             h5_data.close()
             # Add a "done" entry
-            with td_data.unlock_():
-                td_data["next", "done"] = MemoryMappedTensor.from_tensor(
-                    (td_data["next", "terminated"] | td_data["next", "truncated"])
-                )
-                if self.split_trajs:
+            if self.split_trajs:
+                with td_data.unlock_():
                     from torchrl.objectives.utils import split_trajectories

                     td_data = split_trajectories(td_data).memmap_(self.data_path)
@@ -311,7 +340,7 @@ def _download_and_preproc(self):
         return td_data

     def _make_split(self):
-        from torchrl.objectives.utils import split_trajectories
+        from torchrl.collectors.utils import split_trajectories

         self._load_and_proc_metadata()
         td_data = TensorDict.load_memmap(self.data_path_root)
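
The per-episode loop in the hunks above swaps repeated fancy-indexing writes (td_data[match][index : index + steps] = val) for a single slice view per episode plus in-place copy_ calls, and checks each entry's length against the precomputed steps instead of re-deriving it per key. A hedged sketch of the view-and-copy pattern on its own, with made-up shapes and temporary memmap storage:

import torch
from tensordict import TensorDict

total_steps = 8
td = TensorDict(
    {
        "observation": torch.zeros(total_steps, 3),
        "next": TensorDict({"observation": torch.zeros(total_steps, 3)}, [total_steps]),
    },
    [total_steps],
).memmap_()  # disk-backed, like td_data above

index, steps = 0, 5
obs = torch.randn(steps + 1, 3)  # T+1 observations describe T transitions
data_view = td[slice(index, index + steps)]  # one view, storage shared with td
data_view["observation"].copy_(obs[:-1])
data_view["next", "observation"].copy_(obs[1:])
index += steps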
@@ -396,14 +425,18 @@ def _patch_info(info_td):
     unique_shapes = defaultdict(list)
     for subkey, subval in info_td.items():
         unique_shapes[subval.shape[0]].append(subkey)
-    if not len(unique_shapes) == 2:
-        raise RuntimeError("Unique shapes in a sub-tensordict can only be of length 2.")
+    if len(unique_shapes) == 1:
+        unique_shapes[subval.shape[0] + 1] = []
+    if len(unique_shapes) != 2:
+        raise RuntimeError(
+            f"Unique shapes in a sub-tensordict can only be of length 2, got shapes {unique_shapes}."
+        )
     val_td = info_td.to_tensordict()
     min_shape = min(*unique_shapes)  # can only be found at root
     max_shape = min_shape + 1
-    val_td_sel = val_td.select(*unique_shapes[min_shape]).apply(
-        lambda x: torch.cat([torch.zeros_like(x[:1]), x], 0)
-    )
-    val_td_sel.batch_size = [min_shape + 1]
+    val_td_sel = val_td.select(*unique_shapes[min_shape])
+    val_td_sel = val_td_sel.apply(
+        lambda x: torch.cat([torch.zeros_like(x[:1]), x], 0), batch_size=[min_shape + 1]
+    )
     val_td_sel.update(val_td.select(*unique_shapes[max_shape]))
     return val_td_sel
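
For context, _patch_info reconciles Minari "infos" entries logged once per step (T rows) with state-like entries logged once per state (T + 1 rows): the shorter ones get a zero row prepended so every entry shares the T + 1 length. A toy illustration of that padding with a plain tensor and invented values:

import torch

T = 4
info_entry = torch.arange(T, dtype=torch.float32)  # per-step entry: T rows
padded = torch.cat([torch.zeros_like(info_entry[:1]), info_entry], 0)
assert padded.shape[0] == T + 1  # aligned with the T + 1 state-like entries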
