-
-
Notifications
You must be signed in to change notification settings - Fork 123
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: refactor pypots.data.saving package and add testing cases;
- Loading branch information
Showing
4 changed files
with
115 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
""" | ||
Data saving utilities. | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
from .h5 import save_dict_into_h5 | ||
from .pickle import pickle_dump, pickle_load | ||
|
||
__all__ = [ | ||
"save_dict_into_h5", | ||
"pickle_dump", | ||
"pickle_load", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
""" | ||
Data saving utilities with pickle. | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
import pickle | ||
from typing import Optional | ||
|
||
from pypots.utils.logging import logger | ||
|
||
|
||
def pickle_dump(data: object, path: str) -> Optional[str]: | ||
"""Pickle the given object. | ||
Parameters | ||
---------- | ||
data: | ||
The object to be pickled. | ||
path: | ||
Saving path. | ||
Returns | ||
------- | ||
`path` if succeed else None | ||
""" | ||
try: | ||
with open(path, "wb") as f: | ||
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) | ||
except pickle.PicklingError: | ||
logger.error("❌ Pickling failed. No cache data saved.") | ||
return None | ||
logger.info(f"Successfully saved to {path}") | ||
return path | ||
|
||
|
||
def pickle_load(path: str) -> object: | ||
"""Load pickled object from file. | ||
Parameters | ||
---------- | ||
path : | ||
Local path of the pickled object. | ||
Returns | ||
------- | ||
Object | ||
Pickled object. | ||
""" | ||
try: | ||
with open(path, "rb") as f: | ||
data = pickle.load(f) | ||
except pickle.UnpicklingError as e: | ||
logger.error( | ||
"❌ Data file corrupted. Operation aborted. See info below:\n" f"{e}" | ||
) | ||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
""" | ||
Test cases for data saving utils. | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
|
||
import unittest | ||
|
||
import pytest | ||
|
||
from pypots.data.saving import save_dict_into_h5, pickle_dump, pickle_load | ||
from pypots.utils.logging import logger | ||
|
||
|
||
class TestLazyLoadingClasses(unittest.TestCase): | ||
logger.info("Running tests for data saving utils...") | ||
|
||
data_to_save = { | ||
"a": 1, | ||
"b": 2, | ||
"c": { | ||
"d": 0, | ||
}, | ||
} | ||
|
||
@pytest.mark.xdist_group(name="data-saving-h5") | ||
def test_0_save_dict_into_h5(self): | ||
save_dict_into_h5(self.data_to_save, "tests/data/saving_with_h5.h5") | ||
|
||
@pytest.mark.xdist_group(name="data-saving-pickle") | ||
def test_0_pickle_dump_load(self): | ||
pickle_dump(self.data_to_save, "tests/data/saving_with_pickle.pkl") | ||
loaded_data = pickle_load("tests/data/saving_with_pickle.pkl") | ||
assert loaded_data == self.data_to_save |