From 2714b62cab6dc15891cd425e1129f25d0002d967 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 4 Dec 2023 10:08:52 +0000
Subject: [PATCH] fixes

---
 test/test_libs.py                    | 71 ++++++++++++---------
 torchrl/data/datasets/minari_data.py | 93 +++++++++++++++++++---------
 2 files changed, 105 insertions(+), 59 deletions(-)

diff --git a/test/test_libs.py b/test/test_libs.py
index 8b54ca1c243..00869c8fa16 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -1964,38 +1964,51 @@ def test_d4rl_iteration(self, task, split_trajs):
         print(f"terminated test after {time.time()-t0}s")


+_MINARI_DATASETS = []
+
+
+def _minari_selected_datasets():
+    if not _has_minari:
+        return
+    global _MINARI_DATASETS
+    import minari
+
+    torch.manual_seed(0)
+
+    keys = list(minari.list_remote_datasets())
+    indices = torch.randperm(len(keys))[:10]
+    keys = [keys[idx] for idx in indices]
+    keys = [
+        key
+        for key in keys
+        if "=0.4" in minari.list_remote_datasets()[key]["minari_version"]
+    ]
+    assert len(keys) > 5
+    _MINARI_DATASETS += keys
+    print("_MINARI_DATASETS", _MINARI_DATASETS)
+
+
+_minari_selected_datasets()
+
+
 @pytest.mark.skipif(not _has_minari, reason="Minari not found")
+@pytest.mark.parametrize("split", [False, True])
+@pytest.mark.parametrize("selected_dataset", _MINARI_DATASETS)
 class TestMinari:
-    @pytest.fixture(scope="class")
-    def selected_datasets(self):
-        torch.manual_seed(0)
-        import minari
-
-        keys = list(minari.list_remote_datasets())
-        indices = torch.randperm(len(keys))[:10]
-        keys = [keys[idx] for idx in indices]
-        keys = [
-            key
-            for key in keys
-            if "=0.4" in minari.list_remote_datasets()[key]["minari_version"]
-        ]
-        assert len(keys) > 5
-        return keys
-
-    @pytest.mark.parametrize("split", [False, True])
-    def test_load(self, selected_datasets, split):
-        for dataset in selected_datasets:
-            print("dataset", dataset)
-            data = MinariExperienceReplay(dataset, batch_size=32, split_trajs=split)
+    def test_load(self, selected_dataset, split):
+        print("dataset", selected_dataset)
+        data = MinariExperienceReplay(
+            selected_dataset, batch_size=32, split_trajs=split
+        )
+        t0 = time.time()
+        for i, sample in enumerate(data):
+            t1 = time.time()
+            print(f"sampling time {1000 * (t1-t0): 4.4f}ms")
+            assert data.metadata["action_space"].is_in(sample["action"])
+            assert data.metadata["observation_space"].is_in(sample["observation"])
             t0 = time.time()
-            for i, sample in enumerate(data):
-                t1 = time.time()
-                print(f"sampling time {1000 * (t1-t0): 4.4f}ms")
-                assert data.metadata["action_space"].is_in(sample["action"])
-                assert data.metadata["observation_space"].is_in(sample["observation"])
-                t0 = time.time()
-                if i == 10:
-                    break
+            if i == 10:
+                break


 @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found")
diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py
index df70145c0e1..945ad9d7320 100644
--- a/torchrl/data/datasets/minari_data.py
+++ b/torchrl/data/datasets/minari_data.py
@@ -18,7 +18,7 @@
 import torch

-from tensordict import MemoryMappedTensor, PersistentTensorDict, TensorDict
+from tensordict import PersistentTensorDict, TensorDict

 from torchrl._utils import KeyDependentDefaultDict
 from torchrl.data.datasets.utils import _get_root_dir
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
@@ -223,19 +223,22 @@ def _download_and_preproc(self):
         total_steps = 0
         print("first read through data to create data structure...")
         h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5")
-        # Get the total number of steps for the dataset
-        total_steps += sum(
-            h5_data[episode, "actions"].shape[0] for episode in h5_data.keys()
-        )
         # populate the tensordict
         episode_dict = {}
         for episode_key, episode in h5_data.items():
             episode_num = int(episode_key[len("episode_") :])
-            episode_dict[episode_num] = episode_key
+            episode_len = episode["actions"].shape[0]
+            episode_dict[episode_num] = (episode_key, episode_len)
+            # Get the total number of steps for the dataset
+            total_steps += episode_len
             for key, val in episode.items():
                 match = _NAME_MATCH[key]
                 if key in ("observations", "state", "infos"):
-                    if not val.shape:
+                    if (
+                        not val.shape
+                    ):  # structureless entries are skipped if empty, otherwise patched to a regular shape
+                        if val.is_empty():
+                            continue
                         val = _patch_info(val)
                     td_data.set(("next", match), torch.zeros_like(val[0]))
                     td_data.set(match, torch.zeros_like(val[0]))
@@ -248,19 +251,26 @@ def _download_and_preproc(self):
                    )

         # give it the proper size
+        td_data["next", "done"] = (
+            td_data["next", "truncated"] | td_data["next", "terminated"]
+        )
+        if "terminated" in td_data.keys():
+            td_data["done"] = td_data["truncated"] | td_data["terminated"]
         td_data = td_data.expand(total_steps)
         # save to designated location
         print(f"creating tensordict data in {self.data_path_root}: ", end="\t")
         td_data = td_data.memmap_like(self.data_path_root)
-        print(td_data)
+        print("tensordict structure:", td_data)

         print(f"Reading data from {max(*episode_dict)} episodes")
         index = 0
         with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar:
             # iterate over episodes and populate the tensordict
             for episode_num in sorted(episode_dict):
-                episode_key = episode_dict[episode_num]
+                episode_key, steps = episode_dict[episode_num]
                 episode = h5_data.get(episode_key)
+                idx = slice(index, (index + steps))
+                data_view = td_data[idx]
                 for key, val in episode.items():
                     match = _NAME_MATCH[key]
                     if key in (
@@ -268,19 +278,41 @@ def _download_and_preproc(self):
                         "state",
                         "infos",
                     ):
-                        if not val.shape:
+                        if not val.shape or steps != val.shape[0] - 1:
+                            if val.is_empty():
+                                continue
                             val = _patch_info(val)
-                        steps = val.shape[0] - 1
-                        td_data["next", match][index : (index + steps)] = val[1:]
-                        td_data[match][index : (index + steps)] = val[:-1]
+                        if steps != val.shape[0] - 1:
+                            raise RuntimeError(
+                                f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}."
+                            )
+                        data_view["next", match].copy_(val[1:])
+                        data_view[match].copy_(val[:-1])
                     elif key not in ("terminations", "truncations", "rewards"):
-                        steps = val.shape[0]
-                        td_data[match][index : (index + val.shape[0])] = val
+                        if steps is None:
+                            steps = val.shape[0]
+                        else:
+                            if steps != val.shape[0]:
+                                raise RuntimeError(
+                                    f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}."
+                                )
+                        data_view[match].copy_(val)
                     else:
-                        steps = val.shape[0]
-                        td_data[("next", match)][
-                            index : (index + val.shape[0])
-                        ] = val.unsqueeze(-1)
+                        if steps is None:
+                            steps = val.shape[0]
+                        else:
+                            if steps != val.shape[0]:
+                                raise RuntimeError(
+                                    f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}."
+                                )
+                        data_view[("next", match)].copy_(val.unsqueeze(-1))
+                data_view["next", "done"].copy_(
+                    data_view["next", "terminated"] | data_view["next", "truncated"]
+                )
+                if "done" in data_view.keys():
+                    data_view["done"].copy_(
+                        data_view["terminated"] | data_view["truncated"]
+                    )
                 if pbar is not None:
                     pbar.update(steps)
                     pbar.set_description(
@@ -289,11 +321,8 @@ def _download_and_preproc(self):
                 index += steps
         h5_data.close()
         # Add a "done" entry
-        with td_data.unlock_():
-            td_data["next", "done"] = MemoryMappedTensor.from_tensor(
-                (td_data["next", "terminated"] | td_data["next", "truncated"])
-            )
-            if self.split_trajs:
+        if self.split_trajs:
+            with td_data.unlock_():
                 from torchrl.objectives.utils import split_trajectories

                 td_data = split_trajectories(td_data).memmap_(self.data_path)
@@ -311,7 +340,7 @@ def _download_and_preproc(self):
         return td_data

     def _make_split(self):
-        from torchrl.objectives.utils import split_trajectories
+        from torchrl.collectors.utils import split_trajectories

         self._load_and_proc_metadata()
         td_data = TensorDict.load_memmap(self.data_path_root)
@@ -396,14 +425,18 @@ def _patch_info(info_td):
     unique_shapes = defaultdict(list)
     for subkey, subval in info_td.items():
         unique_shapes[subval.shape[0]].append(subkey)
-    if not len(unique_shapes) == 2:
-        raise RuntimeError("Unique shapes in a sub-tensordict can only be of length 2.")
+    if len(unique_shapes) == 1:
+        unique_shapes[subval.shape[0] + 1] = []
+    if len(unique_shapes) != 2:
+        raise RuntimeError(
+            f"Unique shapes in a sub-tensordict can only be of length 2, got shapes {unique_shapes}."
+        )
     val_td = info_td.to_tensordict()
     min_shape = min(*unique_shapes)  # can only be found at root
     max_shape = min_shape + 1
-    val_td_sel = val_td.select(*unique_shapes[min_shape]).apply(
-        lambda x: torch.cat([torch.zeros_like(x[:1]), x], 0)
+    val_td_sel = val_td.select(*unique_shapes[min_shape])
+    val_td_sel = val_td_sel.apply(
+        lambda x: torch.cat([torch.zeros_like(x[:1]), x], 0), batch_size=[min_shape + 1]
     )
-    val_td_sel.batch_size = [min_shape + 1]
     val_td_sel.update(val_td.select(*unique_shapes[max_shape]))
     return val_td_sel
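
The core behavioral change above is the "done" convention: the loader now writes "done" as the logical OR of "terminated" and "truncated", both at the root of the tensordict and under "next", instead of patching ("next", "done") after the fact with MemoryMappedTensor.from_tensor. A minimal sketch of that convention, assuming a five-step episode with made-up flag values (not part of the patch):

import torch
from tensordict import TensorDict

steps = 5
td = TensorDict(
    {
        "terminated": torch.zeros(steps, 1, dtype=torch.bool),
        "truncated": torch.zeros(steps, 1, dtype=torch.bool),
        "next": {
            # the episode ends with a termination on the last transition
            "terminated": torch.tensor([[False]] * (steps - 1) + [[True]]),
            "truncated": torch.zeros(steps, 1, dtype=torch.bool),
        },
    },
    batch_size=[steps],
)
# same OR as in the patched _download_and_preproc
td["done"] = td["terminated"] | td["truncated"]
td["next", "done"] = td["next", "terminated"] | td["next", "truncated"]
assert td["next", "done"][-1].item() and not td["next", "done"][:-1].any()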
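A second structural change is how episodes are written into the memory-mapped buffer: instead of fancy-indexing td_data[match][index : index + steps] = val for every key, the loop takes one slice view per episode (data_view = td_data[idx]) and fills it with in-place copy_ calls. This works because slicing a TensorDict slices each of its tensors, and tensor slices share storage with their source. A sketch of the pattern with stand-in names and shapes (illustrative, not from the patch):

import torch
from tensordict import TensorDict

total_steps, steps, index = 10, 4, 3
td_data = TensorDict({"action": torch.zeros(total_steps, 2)}, batch_size=[total_steps])
episode_actions = torch.ones(steps, 2)

data_view = td_data[slice(index, index + steps)]  # a view, no copy
data_view["action"].copy_(episode_actions)  # in-place write lands in td_data
assert (td_data["action"][index : index + steps] == 1).all()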
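Finally, the reworked _patch_info reconciles info entries logged once per step (length steps) with entries logged once per state (length steps + 1): the shorter ones are left-padded with a single zero row so that everything lines up with the observations. The padding rule in isolation, with illustrative lengths and feature sizes:

import torch

steps = 4
per_step = torch.randn(steps, 3)       # logged after every step
per_state = torch.randn(steps + 1, 3)  # logged for every state, incl. the initial one
# same concatenation as the lambda passed to val_td_sel.apply in the patch
padded = torch.cat([torch.zeros_like(per_step[:1]), per_step], 0)
assert padded.shape[0] == per_state.shape[0]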