Skip to content

Commit

Permalink
Merge pull request #268 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Updating package `pypots.data.saving`
  • Loading branch information
WenjieDu authored Dec 15, 2023
2 parents e2dc211 + aecc151 commit d457629
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 50 deletions.
7 changes: 7 additions & 0 deletions pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ def _setup_device(self, device: Union[None, str, torch.device, list]) -> None:
), "You are trying to use CUDA for model training, but CUDA is not available in your environment."

def _setup_path(self, saving_path) -> None:
MODEL_NO_NEED_TO_SAVE = [
"LOCF",
]
# if the model does not need to be saved (e.g. LOCF), then skip the following steps
if self.__class__.__name__ in MODEL_NO_NEED_TO_SAVE:
return

if isinstance(saving_path, str):
# get the current time to append to saving_path,
# so you can use the same saving_path to run multiple times
Expand Down
50 changes: 0 additions & 50 deletions pypots/data/saving.py

This file was deleted.

15 changes: 15 additions & 0 deletions pypots/data/saving/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Data saving utilities.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from .h5 import save_dict_into_h5
from .pickle import pickle_dump, pickle_load

# Names re-exported from the `.h5` and `.pickle` submodules imported above.
__all__ = [
    "save_dict_into_h5",
    "pickle_dump",
    "pickle_load",
]
86 changes: 86 additions & 0 deletions pypots/data/saving/h5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""
Data saving utilities with HDF5.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


import os
from typing import Optional

import h5py

from pypots.utils.file import extract_parent_dir, create_dir_if_not_exist
from pypots.utils.logging import logger


def save_dict_into_h5(
    data_dict: dict,
    saving_path: str,
    file_name: Optional[str] = None,
) -> None:
    """Save the given data (in a dictionary) into the given h5 file.

    Nested dictionaries are written as HDF5 groups; every other value is
    written as a dataset under its key.

    Parameters
    ----------
    data_dict : dict,
        The data to be saved, should be a Python dictionary.

    saving_path : str,
        If `file_name` is not given, the given path should be a path to a file
        with ".h5" or ".hdf5" suffix.
        If `file_name` is given, the given path should be a path to a directory.
        If parent directories don't exist, they will be created.

    file_name : str, optional (default=None)
        The name of the H5 file to be saved and should be with ".h5" or ".hdf5" suffix.
        It's optional. If not set, `saving_path` should be a path to a file
        with ".h5" or ".hdf5" suffix.

    Raises
    ------
    AssertionError
        If `data_dict` is not a dictionary, or `saving_path` / `file_name`
        is not a string.

    Notes
    -----
    When the target file name ends with neither ".h5" nor ".hdf5", a warning
    is logged and ".h5" is appended to it, as the warning message states.
    """

    def save_set(handle, name, data):
        # A dict becomes a group that is filled recursively; anything else is
        # handed to h5py as the content of a dataset.
        if isinstance(data, dict):
            single_set_handle = handle.create_group(name)
            for key, value in data.items():
                save_set(single_set_handle, key, value)
        else:
            handle.create_dataset(name, data=data)

    # `str.endswith` accepts a tuple, which avoids the precedence trap of
    # `not a.endswith(x) or a.endswith(y)`.
    valid_suffixes = (".h5", ".hdf5")

    # check typing
    assert isinstance(
        data_dict, dict
    ), f"`data_dict` should be a Python dictionary, but got {type(data_dict)}."
    assert isinstance(
        saving_path, str
    ), f"`saving_path` should be a string, but got {type(saving_path)}."

    if file_name is None:  # if file_name is not given
        # check suffix
        if not saving_path.endswith(valid_suffixes):
            logger.warning(
                f"‼️ `saving_path` should end with '.h5' or '.hdf5', but got {saving_path}. "
                f"PyPOTS will automatically append '.h5' to the given `saving_path`."
            )
            # actually do what the warning promises
            saving_path = saving_path + ".h5"
    else:  # if file_name is given
        # check typing
        assert isinstance(
            file_name, str
        ), f"`file_name` should be a string, but got {type(file_name)}."
        # check suffix
        if not file_name.endswith(valid_suffixes):
            logger.warning(
                f"‼️ `file_name` should end with '.h5' or '.hdf5', but got {file_name}. "
                f"PyPOTS will automatically append '.h5' to the given `file_name`."
            )
            # actually do what the warning promises
            file_name = file_name + ".h5"
        # organize the saving path
        saving_path = os.path.join(saving_path, file_name)

    # create the parent folders if not exist
    create_dir_if_not_exist(extract_parent_dir(saving_path))

    # create the h5 file handle and save the data
    with h5py.File(saving_path, "w") as hf:
        for k, v in data_dict.items():
            save_set(hf, k, v)

    logger.info(f"Successfully saved the given data into {saving_path}.")
61 changes: 61 additions & 0 deletions pypots/data/saving/pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Data saving utilities with pickle.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import pickle
from typing import Optional

from pypots.utils.logging import logger


def pickle_dump(data: object, path: str) -> Optional[str]:
    """Serialize an object with pickle and write it to a file.

    Parameters
    ----------
    data:
        The object to be pickled.

    path:
        Destination file path; the file is opened in binary write mode.

    Returns
    -------
    The given `path` when the dump succeeds, or None when pickling raises
    `pickle.PicklingError`.
    """
    try:
        with open(path, "wb") as out_file:
            pickle.dump(data, out_file, protocol=pickle.HIGHEST_PROTOCOL)
    except pickle.PicklingError:
        logger.error("❌ Pickling failed. No cache data saved.")
        saved_to = None
    else:
        logger.info(f"Successfully saved to {path}")
        saved_to = path
    return saved_to


def pickle_load(path: str) -> object:
    """Load pickled object from file.

    Parameters
    ----------
    path :
        Local path of the pickled object.

    Returns
    -------
    Object
        The unpickled object, or None if the file content cannot be
        unpickled (`pickle.UnpicklingError`); the error is logged in that case.

    Notes
    -----
    Unpickling can execute arbitrary code, so only load files that come from
    a trusted source.
    """
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except pickle.UnpicklingError as e:
        logger.error(
            "❌ Data file corrupted. Operation aborted. See info below:\n" f"{e}"
        )
        # Return explicitly: previously the function fell through to
        # `return data` with `data` unbound, raising UnboundLocalError.
        return None
36 changes: 36 additions & 0 deletions tests/data/saving.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Test cases for data saving utils.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


import unittest

import pytest

from pypots.data.saving import save_dict_into_h5, pickle_dump, pickle_load
from pypots.utils.logging import logger


class TestLazyLoadingClasses(unittest.TestCase):
    """Checks for the data saving helpers (HDF5 saving and pickle dump/load)."""

    logger.info("Running tests for data saving utils...")

    # A small dictionary with one nested level, shared by both tests.
    data_to_save = dict(a=1, b=2, c=dict(d=0))

    @pytest.mark.xdist_group(name="data-saving-h5")
    def test_0_save_dict_into_h5(self):
        # Only verifies that saving completes without raising.
        h5_path = "tests/data/saving_with_h5.h5"
        save_dict_into_h5(self.data_to_save, h5_path)

    @pytest.mark.xdist_group(name="data-saving-pickle")
    def test_0_pickle_dump_load(self):
        # Dump and reload the dictionary, then compare with the source data.
        pickle_path = "tests/data/saving_with_pickle.pkl"
        pickle_dump(self.data_to_save, pickle_path)
        assert pickle_load(pickle_path) == self.data_to_save

0 comments on commit d457629

Please sign in to comment.