Commit

amend

vmoens committed Dec 1, 2023
1 parent 7247450 commit 49f5f51
Showing 2 changed files with 70 additions and 77 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -189,6 +189,7 @@ Here's an example:


     D4RLExperienceReplay
+    MinariExperienceReplay
     OpenMLExperienceReplay
 
 TensorSpec
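For context, the class added to this autosummary is constructed like the other dataset replays listed around it. A minimal usage sketch, assuming the MinariExperienceReplay signature as of this commit; "pen-human-v1" is only an illustrative Minari dataset id:

    # Hypothetical usage sketch; not part of this commit.
    from torchrl.data.datasets import MinariExperienceReplay

    replay = MinariExperienceReplay(dataset_id="pen-human-v1", batch_size=32)
    batch = replay.sample()  # a TensorDict of transitions, including ("next", ...) entries
    print(batch)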
146 changes: 69 additions & 77 deletions torchrl/data/datasets/minari_data.py
@@ -201,54 +201,47 @@ def _download_and_preproc(self):
             minari.download_dataset(dataset_id=self.dataset_id)
             parent_dir = Path(tmpdir) / self.dataset_id / "data"
 
-            h5files = []
-            for filename in os.listdir(parent_dir):
-                if filename.endswith(".hdf5"):
-                    file_path = parent_dir / filename
-                    h5files.append(file_path)
-
             td_data = TensorDict({}, [])
             total_steps = 0
             print("first read through data to create data structure...")
-            with tqdm.tqdm(h5files) as pbar:
-                for h5file in pbar:
-                    pbar.set_description(f"reading h5 {h5file}")
-                    h5_data = PersistentTensorDict.from_h5(h5file)
-                    # Get the total number of steps for the dataset
-                    total_steps += sum(
-                        h5_data[episode, "actions"].shape[0]
-                        for episode in h5_data.keys()
-                    )
-                    # populate the tensordict
-                    for key, episode in h5_data.items():
-                        for key, val in episode.items():
-                            match = _NAME_MATCH[key]
-                            if key in ("observations", "state", "infos"):
-                                if not val.shape:
-                                    # Data is ambiguous, skipping
-                                    continue
-                                # unique_shapes = defaultdict([])
-                                # for subkey, subval in val.items():
-                                #     unique_shapes[subval.shape[0]].append(subkey)
-                                # if not len(unique_shapes) == 2:
-                                #     raise RuntimeError("Unique shapes in a sub-tensordict can only be of length 2.")
-                                # val_td = val.to_tensordict()
-                                # min_shape = min(*unique_shapes)  # can only be found at root
-                                # max_shape = min_shape + 1
-                                # val_td = val_td.select(*unique_shapes[min_shape])
-                                # print("key - val", key, val)
-                                # print("episode", episode)
-                                td_data.set(("next", match), torch.zeros_like(val)[0])
-                                td_data.set(match, torch.zeros_like(val)[0])
-                            if key not in ("terminations", "truncations", "rewards"):
-                                td_data.set(match, torch.zeros_like(val)[0])
-                            else:
-                                td_data.set(
-                                    ("next", match),
-                                    torch.zeros_like(val)[0].unsqueeze(-1),
-                                )
-                        break
-                    h5_data.close()
+            h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5")
+            # Get the total number of steps for the dataset
+            total_steps += sum(
+                h5_data[episode, "actions"].shape[0]
+                for episode in h5_data.keys()
+            )
+            # populate the tensordict
+            episode_dict = {}
+            for key, episode in h5_data.items():
+                episode_num = int(key[len("episode_"):])
+                episode_dict[episode_num] = key
+                for key, val in episode.items():
+                    match = _NAME_MATCH[key]
+                    if key in ("observations", "state", "infos"):
+                        if not val.shape:
+                            # Data is ambiguous, skipping
+                            continue
+                        # unique_shapes = defaultdict([])
+                        # for subkey, subval in val.items():
+                        #     unique_shapes[subval.shape[0]].append(subkey)
+                        # if not len(unique_shapes) == 2:
+                        #     raise RuntimeError("Unique shapes in a sub-tensordict can only be of length 2.")
+                        # val_td = val.to_tensordict()
+                        # min_shape = min(*unique_shapes)  # can only be found at root
+                        # max_shape = min_shape + 1
+                        # val_td = val_td.select(*unique_shapes[min_shape])
+                        # print("key - val", key, val)
+                        # print("episode", episode)
+                        td_data.set(("next", match), torch.zeros_like(val)[0])
+                        td_data.set(match, torch.zeros_like(val)[0])
+                    if key not in ("terminations", "truncations", "rewards"):
+                        td_data.set(match, torch.zeros_like(val)[0])
+                    else:
+                        td_data.set(
+                            ("next", match),
+                            torch.zeros_like(val)[0].unsqueeze(-1),
+                        )
+                break

             # give it the proper size
             td_data = td_data.expand(total_steps)
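The episode_dict bookkeeping added above exists because HDF5 episode keys sort lexicographically: "episode_10" would come before "episode_2". Mapping the integer suffix to each key (done in the outer loop, since the inner loop rebinds key to field names) lets the second pass below visit episodes in numeric order. A standalone sketch with hypothetical keys:

    # Standalone sketch; the keys are hypothetical, not read from a real dataset.
    keys = [f"episode_{i}" for i in range(12)]

    print(sorted(keys)[:4])
    # ['episode_0', 'episode_1', 'episode_10', 'episode_11']  <- lexicographic, wrong order

    episode_dict = {int(k[len("episode_"):]): k for k in keys}
    print([episode_dict[n] for n in sorted(episode_dict)][:4])
    # ['episode_0', 'episode_1', 'episode_2', 'episode_3']  <- numeric order, as intended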
@@ -260,38 +260,37 @@ def _download_and_preproc(self):
             print("Reading data")
             index = 0
             with tqdm.tqdm(total=total_steps) as pbar:
-                for h5file in h5files:
-                    h5_data = PersistentTensorDict.from_h5(h5file)
-                    # TODO: sort episodes
-                    # iterate over episodes and populate the tensordict
-                    for key, episode in h5_data.items():
-                        for key, val in episode.items():
-                            match = _NAME_MATCH[key]
-                            if key in (
-                                "observations",
-                                "state",
-                                "infos",
-                            ):
-                                if not val.shape:
-                                    # Data is ambiguous, skipping
-                                    continue
-                                steps = val.shape[0] - 1
-                                td_data["next", match][index : (index + steps)] = val[
-                                    1:
-                                ]
-                                td_data[match][index : (index + steps)] = val[:-1]
-                            elif key not in ("terminations", "truncations", "rewards"):
-                                steps = val.shape[0]
-                                td_data[match][index : (index + val.shape[0])] = val
-                            else:
-                                steps = val.shape[0]
-                                td_data[("next", match)][
-                                    index : (index + val.shape[0])
-                                ] = val.unsqueeze(-1)
-                        pbar.update(steps)
-                        pbar.set_description(f"index={index} - h5 {h5file}")
-                        index += steps
-                    h5_data.close()
+                # iterate over episodes and populate the tensordict
+                for episode_num in sorted(episode_dict):
+                    key = episode_dict[episode_num]
+                    episode = h5_data.get(key)
+                    for key, val in episode.items():
+                        match = _NAME_MATCH[key]
+                        if key in (
+                            "observations",
+                            "state",
+                            "infos",
+                        ):
+                            if not val.shape:
+                                # Data is ambiguous, skipping
+                                continue
+                            steps = val.shape[0] - 1
+                            td_data["next", match][index : (index + steps)] = val[
+                                1:
+                            ]
+                            td_data[match][index : (index + steps)] = val[:-1]
+                        elif key not in ("terminations", "truncations", "rewards"):
+                            steps = val.shape[0]
+                            td_data[match][index : (index + val.shape[0])] = val
+                        else:
+                            steps = val.shape[0]
+                            td_data[("next", match)][
+                                index : (index + val.shape[0])
+                            ] = val.unsqueeze(-1)
+                    pbar.update(steps)
+                    pbar.set_description(f"index={index} - episode num {episode_num}")
+                    index += steps
+            h5_data.close()
             # Add a "done" entry
             with td_data.unlock_():
                 td_data["next", "done"] = MemoryMappedTensor.from_tensor(
