def load_dict_from_h5(
    file_path: str,
) -> dict:
    """Load the data from the given h5 file and return as a Python dictionary.

    Every group in the file becomes a nested dictionary and every dataset
    becomes its stored value. A dataset that carries a ``_type_`` attribute is
    converted back to a richer Python object:

    - ``"datetime"``: the stored number(s) are passed to
      ``datetime.fromtimestamp``; an iterable value yields a list of
      ``datetime`` objects, a scalar yields a single one.
    - ``"yaml"``: the stored text is parsed with ``yaml.safe_load``.

    Objects that are neither groups nor datasets are ignored.

    Parameters
    ----------
    file_path : str,
        The path to the h5 file.

    Returns
    -------
    data : dict,
        The data loaded from the given h5 file.

    Raises
    ------
    AssertionError
        If ``file_path`` is not a string or does not exist on disk.

    """
    assert isinstance(
        file_path, str
    ), f"`file_path` should be a string, but got {type(file_path)}."
    assert os.path.exists(file_path), "`file_path` does not exist."

    def read_type_tag(item):
        """Return the dataset's ``_type_`` marker as text, or None if absent."""
        if "_type_" not in item.attrs:
            return None
        tag = item.attrs["_type_"]
        # h5py hands back numpy string scalars for fixed-length attributes but
        # plain ``str`` for variable-length ones, so only call ``astype`` when
        # the object actually provides it.
        if hasattr(tag, "astype"):
            return tag.astype(str)
        if isinstance(tag, bytes):
            return tag.decode()
        return tag

    def load_set(handle):
        """Recursively convert an h5py file/group into a plain dictionary."""
        loaded = {}
        for key, item in handle.items():
            if isinstance(item, h5py.Group):
                loaded[key] = load_set(item)
            elif isinstance(item, h5py.Dataset):
                value = item[()]
                type_tag = read_type_tag(item)
                if type_tag == "datetime":
                    if hasattr(value, "__iter__"):
                        value = [datetime.fromtimestamp(ts) for ts in value]
                    else:
                        value = datetime.fromtimestamp(value)
                elif type_tag == "yaml":
                    # String datasets may come back as bytes or as str
                    # depending on how h5py decodes them; accept both.
                    text = value.decode() if isinstance(value, bytes) else value
                    value = yaml.safe_load(text)
                loaded[key] = value
        return loaded

    with h5py.File(file_path, "r") as hf:
        data = load_set(hf)

    return data