Skip to content

Commit

Permalink
feat: add pickle_dump() and pickle_load(), and update save_dict_into_…
Browse files Browse the repository at this point in the history
…h5();
  • Loading branch information
WenjieDu committed Dec 15, 2023
1 parent 99f1ec2 commit 096738c
Showing 1 changed file with 96 additions and 9 deletions.
105 changes: 96 additions & 9 deletions pypots/data/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@


import os
from typing import Optional

import dill as pickle
import h5py

from ..utils.file import create_dir_if_not_exist
from ..utils.file import extract_parent_dir, create_dir_if_not_exist
from ..utils.logging import logger


def save_dict_into_h5(
data_dict: dict,
saving_dir: str,
saving_name: str = "datasets.h5",
saving_path: str,
file_name: Optional[str] = None,
) -> None:
"""Save the given data (in a dictionary) into the given h5 file.
Expand All @@ -26,11 +28,14 @@ def save_dict_into_h5(
data_dict : dict,
The data to be saved, should be a Python dictionary.
saving_dir : str,
The h5 file to save the data.
saving_path : str,
If `file_name` is not given, the given path should be a path to a file with ".h5" suffix.
If `file_name` is given, the given path should be a path to a directory.
If parent directories don't exist, they will be created.
saving_name : str, optional (default="datasets.h5")
The final name of the saved h5 file.
file_name : str, optional (default=None)
The name of the H5 file to be saved and should be with ".h5" suffix.
It's optional. If not set, `saving_path` should be a path to a file with ".h5" suffix.
"""

Expand All @@ -42,9 +47,91 @@ def save_set(handle, name, data):
else:
handle.create_dataset(name, data=data)

create_dir_if_not_exist(saving_dir)
saving_path = os.path.join(saving_dir, saving_name)
# check typing
assert isinstance(
data_dict, dict
), f"`data_dict` should be a Python dictionary, but got {type(data_dict)}."
assert isinstance(
saving_path, str
), f"`saving_path` should be a string, but got {type(saving_path)}."

if file_name is None: # if file_name is not given
# check suffix
if not saving_path.endswith(".h5") or saving_path.endswith(".hdf5"):
logger.warning(
f"‼️ `saving_path` should end with '.h5' or '.hdf5', but got {saving_path}. "
f"PyPOTS will automatically append '.h5' to the given `saving_path`."
)
else: # if file_name is given
# check typing
assert isinstance(
file_name, str
), f"`file_name` should be a string, but got {type(file_name)}."
# check suffix
if not file_name.endswith(".h5") or file_name.endswith(".hdf5"):
logger.warning(
f"‼️ `file_name` should end with '.h5' or '.hdf5', but got {file_name}. "
f"PyPOTS will automatically append '.h5' to the given `file_name`."
)
# organize the saving path
saving_path = os.path.join(saving_path, file_name)

# create the parent folders if not exist
create_dir_if_not_exist(extract_parent_dir(saving_path))

# create the h5 file handle and save the data
with h5py.File(saving_path, "w") as hf:
for k, v in data_dict.items():
save_set(hf, k, v)

logger.info(f"Successfully saved the given data into {saving_path}.")


def pickle_dump(data: object, path: str) -> Optional[str]:
"""Pickle the given object.
Parameters
----------
data:
The object to be pickled.
path:
Saving path.
Returns
-------
`path` if succeed else None
"""
try:
with open(path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
except pickle.PicklingError:
logger.error("❌ Pickling failed. No cache data saved.")
return None
logger.info(f"Successfully saved to {path}")
return path


def pickle_load(path: str) -> object:
"""Load pickled object from file.
Parameters
----------
path :
Local path of the pickled object.
Returns
-------
Object
Pickled object.
"""
try:
with open(path, "rb") as f:
data = pickle.load(f)
except pickle.UnpicklingError as e:
logger.error(
"❌ Data file corrupted. Operation aborted. See info below:\n" f"{e}"
)
return data

0 comments on commit 096738c

Please sign in to comment.