Skip to content

Commit

Permalink
feat: refactor pypots.data.saving package and add testing cases;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Dec 15, 2023
1 parent f1a2a17 commit aecc151
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 54 deletions.
15 changes: 15 additions & 0 deletions pypots/data/saving/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Data saving utilities.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from .h5 import save_dict_into_h5
from .pickle import pickle_dump, pickle_load

__all__ = [
"save_dict_into_h5",
"pickle_dump",
"pickle_load",
]
57 changes: 3 additions & 54 deletions pypots/data/saving.py → pypots/data/saving/h5.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Data saving utilities.
Data saving utilities with HDF5.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
Expand All @@ -9,11 +9,10 @@
import os
from typing import Optional

import dill as pickle
import h5py

from ..utils.file import extract_parent_dir, create_dir_if_not_exist
from ..utils.logging import logger
from pypots.utils.file import extract_parent_dir, create_dir_if_not_exist
from pypots.utils.logging import logger


def save_dict_into_h5(
Expand Down Expand Up @@ -85,53 +84,3 @@ def save_set(handle, name, data):
save_set(hf, k, v)

logger.info(f"Successfully saved the given data into {saving_path}.")


def pickle_dump(data: object, path: str) -> Optional[str]:
"""Pickle the given object.
Parameters
----------
data:
The object to be pickled.
path:
Saving path.
Returns
-------
`path` if succeed else None
"""
try:
with open(path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
except pickle.PicklingError:
logger.error("❌ Pickling failed. No cache data saved.")
return None
logger.info(f"Successfully saved to {path}")
return path


def pickle_load(path: str) -> object:
"""Load pickled object from file.
Parameters
----------
path :
Local path of the pickled object.
Returns
-------
Object
Pickled object.
"""
try:
with open(path, "rb") as f:
data = pickle.load(f)
except pickle.UnpicklingError as e:
logger.error(
"❌ Data file corrupted. Operation aborted. See info below:\n" f"{e}"
)
return data
61 changes: 61 additions & 0 deletions pypots/data/saving/pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Data saving utilities with pickle.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import pickle
from typing import Optional

from pypots.utils.logging import logger


def pickle_dump(data: object, path: str) -> Optional[str]:
"""Pickle the given object.
Parameters
----------
data:
The object to be pickled.
path:
Saving path.
Returns
-------
`path` if succeed else None
"""
try:
with open(path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
except pickle.PicklingError:
logger.error("❌ Pickling failed. No cache data saved.")
return None
logger.info(f"Successfully saved to {path}")
return path


def pickle_load(path: str) -> object:
"""Load pickled object from file.
Parameters
----------
path :
Local path of the pickled object.
Returns
-------
Object
Pickled object.
"""
try:
with open(path, "rb") as f:
data = pickle.load(f)
except pickle.UnpicklingError as e:
logger.error(
"❌ Data file corrupted. Operation aborted. See info below:\n" f"{e}"
)
return data
36 changes: 36 additions & 0 deletions tests/data/saving.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Test cases for data saving utils.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


import unittest

import pytest

from pypots.data.saving import save_dict_into_h5, pickle_dump, pickle_load
from pypots.utils.logging import logger


class TestLazyLoadingClasses(unittest.TestCase):
logger.info("Running tests for data saving utils...")

data_to_save = {
"a": 1,
"b": 2,
"c": {
"d": 0,
},
}

@pytest.mark.xdist_group(name="data-saving-h5")
def test_0_save_dict_into_h5(self):
save_dict_into_h5(self.data_to_save, "tests/data/saving_with_h5.h5")

@pytest.mark.xdist_group(name="data-saving-pickle")
def test_0_pickle_dump_load(self):
pickle_dump(self.data_to_save, "tests/data/saving_with_pickle.pkl")
loaded_data = pickle_load("tests/data/saving_with_pickle.pkl")
assert loaded_data == self.data_to_save

0 comments on commit aecc151

Please sign in to comment.