Skip to content

Commit

Permalink
feat: add load_dict_from_h5();
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Dec 19, 2023
1 parent 4725ab8 commit 1929e17
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pypots/data/saving/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from .h5 import save_dict_into_h5
from .h5 import save_dict_into_h5, load_dict_from_h5
from .pickle import pickle_dump, pickle_load

# Public names of this package: the h5 and pickle helpers imported above.
__all__ = [
"save_dict_into_h5",
"load_dict_from_h5",
"pickle_dump",
"pickle_load",
]
53 changes: 51 additions & 2 deletions pypots/data/saving/h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@


import os
from datetime import datetime
from typing import Optional

import h5py
import yaml

from pypots.utils.file import extract_parent_dir, create_dir_if_not_exist
from pypots.utils.logging import logger
from ...utils.file import extract_parent_dir, create_dir_if_not_exist
from ...utils.logging import logger


def save_dict_into_h5(
Expand Down Expand Up @@ -84,3 +86,50 @@ def save_set(handle, name, data):
save_set(hf, k, v)

logger.info(f"Successfully saved the given data into {saving_path}.")


def load_dict_from_h5(
    file_path: str,
) -> dict:
    """Load the data from the given h5 file and return as a Python dictionary.

    Every group in the file becomes a nested dictionary and every dataset
    becomes a value read in full with ``item[()]``. A dataset that carries a
    ``_type_`` attribute is post-processed: ``"datetime"`` values are passed
    through ``datetime.fromtimestamp`` (element-wise when the value is
    iterable) and ``"yaml"`` values are parsed with ``yaml.safe_load``.

    Parameters
    ----------
    file_path : str,
        The path to the h5 file.

    Returns
    -------
    data : dict,
        The data loaded from the given h5 file.

    Raises
    ------
    AssertionError
        If ``file_path`` is not a string or does not point to an existing path.
    """
    assert isinstance(
        file_path, str
    ), f"`file_path` should be a string, but got {type(file_path)}."
    assert os.path.exists(file_path), "`file_path` does not exist."

    def read_type_tag(attrs):
        # The tag may come back as a numpy object (fixed-length string), or as
        # plain ``bytes``/``str`` (variable-length string), depending on how it
        # was written and on the h5py version. Only numpy objects have
        # ``astype``, so the other two cases are handled explicitly.
        raw = attrs["_type_"]
        if hasattr(raw, "astype"):
            return raw.astype(str)
        if isinstance(raw, bytes):
            return raw.decode()
        return raw

    def load_set(handle, datadict):
        for key, item in handle.items():
            if isinstance(item, h5py.Group):
                # Recurse into the group with a fresh dict for its members.
                datadict[key] = load_set(item, {})
            elif isinstance(item, h5py.Dataset):
                value = item[()]
                if "_type_" in item.attrs:
                    type_tag = read_type_tag(item.attrs)
                    if type_tag == "datetime":
                        if hasattr(value, "__iter__"):
                            value = [datetime.fromtimestamp(ts) for ts in value]
                        else:
                            value = datetime.fromtimestamp(value)
                    elif type_tag == "yaml":
                        # String datasets are read as bytes by recent h5py but
                        # may already be str; decode only when needed.
                        text = value.decode() if isinstance(value, bytes) else value
                        value = yaml.safe_load(text)
                datadict[key] = value
            # Any other kind of HDF5 object is ignored.

        return datadict

    with h5py.File(file_path, "r") as hf:
        data = load_set(hf, {})

    return data

0 comments on commit 1929e17

Please sign in to comment.