-
-
Notifications
You must be signed in to change notification settings - Fork 123
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #268 from WenjieDu/dev
Updating package `pypots.data.saving`
- Loading branch information
Showing
6 changed files
with
205 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
""" | ||
Data saving utilities. | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
from .h5 import save_dict_into_h5 | ||
from .pickle import pickle_dump, pickle_load | ||
|
||
__all__ = [ | ||
"save_dict_into_h5", | ||
"pickle_dump", | ||
"pickle_load", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
""" | ||
Data saving utilities with HDF5. | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
|
||
import os | ||
from typing import Optional | ||
|
||
import h5py | ||
|
||
from pypots.utils.file import extract_parent_dir, create_dir_if_not_exist | ||
from pypots.utils.logging import logger | ||
|
||
|
||
def save_dict_into_h5(
    data_dict: dict,
    saving_path: str,
    file_name: Optional[str] = None,
) -> None:
    """Save the given data (in a dictionary) into the given h5 file.

    Nested dictionaries are written as HDF5 groups (recursively); every other
    value is written as a dataset under its key.

    Parameters
    ----------
    data_dict : dict,
        The data to be saved, should be a Python dictionary.

    saving_path : str,
        If `file_name` is not given, the given path should be a path to a file with ".h5" or ".hdf5" suffix.
        If `file_name` is given, the given path should be a path to a directory.
        If parent directories don't exist, they will be created.

    file_name : str, optional (default=None)
        The name of the H5 file to be saved and should be with ".h5" or ".hdf5" suffix.
        It's optional. If not set, `saving_path` should be a path to a file with ".h5" or ".hdf5" suffix.

    Notes
    -----
    If the file name (or `saving_path` when `file_name` is not given) ends with neither
    ".h5" nor ".hdf5", a warning is logged and ".h5" is appended before saving.

    """

    def save_set(handle, name, data):
        # a dict becomes a group whose members are saved recursively,
        # anything else is handed to h5py as a dataset
        if isinstance(data, dict):
            single_set_handle = handle.create_group(name)
            for key, value in data.items():
                save_set(single_set_handle, key, value)
        else:
            handle.create_dataset(name, data=data)

    # check typing
    assert isinstance(
        data_dict, dict
    ), f"`data_dict` should be a Python dictionary, but got {type(data_dict)}."
    assert isinstance(
        saving_path, str
    ), f"`saving_path` should be a string, but got {type(saving_path)}."

    # str.endswith accepts a tuple, which avoids the `not a or b` precedence trap
    valid_suffixes = (".h5", ".hdf5")

    if file_name is None:  # if file_name is not given
        # check suffix
        if not saving_path.endswith(valid_suffixes):
            logger.warning(
                f"‼️ `saving_path` should end with '.h5' or '.hdf5', but got {saving_path}. "
                f"PyPOTS will automatically append '.h5' to the given `saving_path`."
            )
            # do what the warning announces
            saving_path += ".h5"
    else:  # if file_name is given
        # check typing
        assert isinstance(
            file_name, str
        ), f"`file_name` should be a string, but got {type(file_name)}."
        # check suffix
        if not file_name.endswith(valid_suffixes):
            logger.warning(
                f"‼️ `file_name` should end with '.h5' or '.hdf5', but got {file_name}. "
                f"PyPOTS will automatically append '.h5' to the given `file_name`."
            )
            # do what the warning announces
            file_name += ".h5"
        # organize the saving path
        saving_path = os.path.join(saving_path, file_name)

    # create the parent folders if not exist
    create_dir_if_not_exist(extract_parent_dir(saving_path))

    # create the h5 file handle and save the data
    with h5py.File(saving_path, "w") as hf:
        for k, v in data_dict.items():
            save_set(hf, k, v)

    logger.info(f"Successfully saved the given data into {saving_path}.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
""" | ||
Data saving utilities with pickle. | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
import pickle | ||
from typing import Optional | ||
|
||
from pypots.utils.logging import logger | ||
|
||
|
||
def pickle_dump(data: object, path: str) -> Optional[str]:
    """Serialize an object to a file with pickle.

    Parameters
    ----------
    data:
        The object to be pickled.

    path:
        Saving path. The file is opened in binary write mode.

    Returns
    -------
    `path` if the object was pickled successfully, None if pickling raised
    `pickle.PicklingError`.

    """
    try:
        with open(path, "wb") as file_handle:
            pickle.dump(data, file_handle, protocol=pickle.HIGHEST_PROTOCOL)
    except pickle.PicklingError:
        logger.error("❌ Pickling failed. No cache data saved.")
        saved_to = None
    else:
        # only reached when the dump completed without a PicklingError
        logger.info(f"Successfully saved to {path}")
        saved_to = path
    return saved_to
|
||
|
||
def pickle_load(path: str) -> object:
    """Load pickled object from file.

    Parameters
    ----------
    path :
        Local path of the pickled object.

    Returns
    -------
    Object
        The unpickled object, or None if the file content could not be
        unpickled (`pickle.UnpicklingError`); the error is logged in that case.

    """
    # NOTE: unpickling can execute arbitrary code, so only load files from trusted sources.
    try:
        with open(path, "rb") as f:
            data = pickle.load(f)
    except pickle.UnpicklingError as e:
        logger.error(
            "❌ Data file corrupted. Operation aborted. See info below:\n" f"{e}"
        )
        # `data` was never bound on this path; return explicitly instead of
        # falling through to an UnboundLocalError
        return None
    return data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
""" | ||
Test cases for data saving utils. | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
|
||
import unittest | ||
|
||
import pytest | ||
|
||
from pypots.data.saving import save_dict_into_h5, pickle_dump, pickle_load | ||
from pypots.utils.logging import logger | ||
|
||
|
||
class TestLazyLoadingClasses(unittest.TestCase):
    """Test cases for the data saving utilities (HDF5 and pickle)."""

    logger.info("Running tests for data saving utils...")

    # contains a nested dict so that saving a sub-dictionary is covered as well
    data_to_save = {
        "a": 1,
        "b": 2,
        "c": {
            "d": 0,
        },
    }

    @pytest.mark.xdist_group(name="data-saving-h5")
    def test_0_save_dict_into_h5(self):
        import os

        h5_path = "tests/data/saving_with_h5.h5"
        save_dict_into_h5(self.data_to_save, h5_path)
        # the call returns None, so check the side effect instead
        assert os.path.exists(h5_path)

    @pytest.mark.xdist_group(name="data-saving-pickle")
    def test_0_pickle_dump_load(self):
        import os

        pkl_path = "tests/data/saving_with_pickle.pkl"
        # pickle_dump opens the file directly, so make sure the folder exists;
        # the h5 test runs in a different xdist group and cannot be relied on for that
        os.makedirs(os.path.dirname(pkl_path), exist_ok=True)

        saved_path = pickle_dump(self.data_to_save, pkl_path)
        assert saved_path == pkl_path
        loaded_data = pickle_load(pkl_path)
        assert loaded_data == self.data_to_save