Merge pull request #275 from WenjieDu/dev
Simplifying val_set, renaming X_intact, and adding unit tests for the visual package
WenjieDu authored Dec 18, 2023
2 parents a796dc2 + 34e258a commit b6adbac
Showing 71 changed files with 1,301 additions and 947 deletions.
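The headline change, as a hedged before/after sketch of the validation-set API (key names taken from pypots/data/checking.py below; the exact pre-change keys are inferred from the old X_intact naming):

# Before (assumed old convention): callers pre-computed the masked artifacts, e.g.
#   val_set = {"X": X, "X_intact": X_intact, "indicating_mask": indicating_mask}
# After this commit: only the two series are needed; BaseDataset derives
# missing_mask and indicating_mask internally.
val_set = {"X": X, "X_ori": X_ori}  # X_ori: the series before artificial masking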
8 changes: 6 additions & 2 deletions .github/workflows/testing_ci.yml
@@ -49,7 +49,6 @@ jobs:
run: |
which python
which pip
pip install --upgrade pip
pip install torch==${{ steps.determine_pytorch_ver.outputs.value }} -f https://download.pytorch.org/whl/cpu
python -c "import torch; print('PyTorch:', torch.__version__)"
@@ -58,6 +57,10 @@ jobs:
pip install -r requirements.txt
pip install torch-geometric torch-scatter torch-sparse -f "https://data.pyg.org/whl/torch-${{ steps.determine_pytorch_ver.outputs.value }}+cpu.html"
pip install pypots[dev]
python_site_path=`python -c "import site; print(site.getsitepackages()[0])"`
echo "python site-packages path: $python_site_path"
rm -rf $python_site_path/pypots
python -c "import shutil;import site;shutil.copytree('pypots',site.getsitepackages()[0]+'/pypots')"
- name: Fetch the test environment details
run: |
@@ -66,7 +69,8 @@
- name: Test with pytest
run: |
rm -rf tests/__pycache__
python tests/global_test_config.py
rm -rf tests/__pycache__ && rm -rf tests/*/__pycache__
python -m pytest -rA tests/*/* -n auto --cov=pypots --dist=loadgroup --cov-config=.coveragerc
- name: Generate the LCOV report
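The three new lines in the install step replace the pip-installed pypots package with the local source tree, so the later pytest run (with --cov=pypots) imports and measures the checked-out code rather than the release. The same idea as a standalone Python sketch, assuming it runs from the repository root:

# hypothetical helper, equivalent to the inline workflow commands above
import os
import shutil
import site

site_pkgs = site.getsitepackages()[0]
target = os.path.join(site_pkgs, "pypots")
shutil.rmtree(target, ignore_errors=True)  # drop the installed release copy
shutil.copytree("pypots", target)          # substitute the local dev source
print(f"copied ./pypots -> {target}")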
3 changes: 2 additions & 1 deletion environment-dev.yml
@@ -16,8 +16,9 @@ dependencies:
#- conda-forge::pandas <2.0.0
#- conda-forge::h5py
#- conda-forge::tensorboard
#- conda-forge::pygrinder >=0.2
#- conda-forge::pygrinder >=0.4
#- conda-forge::tsdb >=0.2
#- conda-forge::matplotlib
#- pytorch::pytorch >=1.10.0
## Below we install the latest pypots because we need pypots-cli in it for development.
## PyPOTS itself already includes all basic dependencies.
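Presumably the commented-out pygrinder floor moves from 0.2 to 0.4 because pypots/data/base.py below starts importing fill_and_get_mask_torch from pygrinder.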
2 changes: 1 addition & 1 deletion pypots/classification/brits/data.py
@@ -45,4 +45,4 @@ def __init__(
return_labels: bool = True,
file_type: str = "h5py",
):
super().__init__(data, return_labels, file_type)
super().__init__(data, False, return_labels, file_type)
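This is the recurring pattern across the classification and clustering datasets in this commit: BaseDataset (pypots/data/base.py below) now takes return_X_ori as a required positional argument, and these datasets pass False because they never compare against the original, unmasked series. The equivalent keyword-style call, as a readability sketch rather than the committed code:

super().__init__(
    data,
    return_X_ori=False,  # classification datasets never need X_ori
    return_labels=return_labels,
    file_type=file_type,
)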
6 changes: 3 additions & 3 deletions pypots/classification/brits/model.py
@@ -149,7 +149,7 @@ def __init__(
self.optimizer = optimizer
self.optimizer.init_optimizer(self.model.parameters())

def _assemble_input_for_training(self, data: dict) -> dict:
def _assemble_input_for_training(self, data: list) -> dict:
# fetch data
(
indices,
@@ -179,10 +179,10 @@ def _assemble_input_for_training(self, data: dict) -> dict:
}
return inputs

def _assemble_input_for_validating(self, data: dict) -> dict:
def _assemble_input_for_validating(self, data: list) -> dict:
return self._assemble_input_for_training(data)

def _assemble_input_for_testing(self, data: dict) -> dict:
def _assemble_input_for_testing(self, data: list) -> dict:
# fetch data
(
indices,
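The dict → list annotation fix (repeated below for GRU-D and Raindrop) reflects what actually arrives here: torch's default collate_fn batches the per-sample lists built by the dataset's _fetch_data_from_* methods into a list of tensors in the same order. A minimal sketch of the contract, assuming the plain BaseDataset with labels from this diff:

# The DataLoader yields a list whose order mirrors the dataset's sample list,
# e.g. [indices, X, missing_mask, y] for BaseDataset with return_labels=True.
for data in train_dataloader:
    inputs = self._assemble_input_for_training(data)  # list in, dict out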
2 changes: 1 addition & 1 deletion pypots/classification/grud/data.py
@@ -48,7 +48,7 @@ def __init__(
return_labels: bool = True,
file_type: str = "h5py",
):
super().__init__(data, return_labels, file_type)
super().__init__(data, False, return_labels, file_type)
self.locf = LOCF()

if not isinstance(self.data, str): # data from array
6 changes: 3 additions & 3 deletions pypots/classification/grud/model.py
@@ -133,7 +133,7 @@ def __init__(
self.optimizer = optimizer
self.optimizer.init_optimizer(self.model.parameters())

def _assemble_input_for_training(self, data: dict) -> dict:
def _assemble_input_for_training(self, data: list) -> dict:
# fetch data
(
indices,
@@ -157,10 +157,10 @@ def _assemble_input_for_training(self, data: dict) -> dict:
}
return inputs

def _assemble_input_for_validating(self, data: dict) -> dict:
def _assemble_input_for_validating(self, data: list) -> dict:
return self._assemble_input_for_training(data)

def _assemble_input_for_testing(self, data: dict) -> dict:
def _assemble_input_for_testing(self, data: list) -> dict:
(
indices,
X,
6 changes: 3 additions & 3 deletions pypots/classification/raindrop/model.py
@@ -173,7 +173,7 @@ def __init__(
self.optimizer = optimizer
self.optimizer.init_optimizer(self.model.parameters())

def _assemble_input_for_training(self, data: dict) -> dict:
def _assemble_input_for_training(self, data: list) -> dict:
# fetch data
(
indices,
@@ -199,10 +199,10 @@ def _assemble_input_for_training(self, data: dict) -> dict:
}
return inputs

def _assemble_input_for_validating(self, data: dict) -> dict:
def _assemble_input_for_validating(self, data: list) -> dict:
return self._assemble_input_for_training(data)

def _assemble_input_for_testing(self, data: dict) -> dict:
def _assemble_input_for_testing(self, data: list) -> dict:
(
indices,
X,
2 changes: 1 addition & 1 deletion pypots/clustering/crli/data.py
@@ -44,7 +44,7 @@ def __init__(
return_labels: bool = True,
file_type: str = "h5py",
):
super().__init__(data, return_labels, file_type)
super().__init__(data, False, return_labels, file_type)

def _fetch_data_from_array(self, idx: int) -> Iterable:
return super()._fetch_data_from_array(idx)
2 changes: 1 addition & 1 deletion pypots/clustering/vader/data.py
@@ -44,7 +44,7 @@ def __init__(
return_labels: bool = True,
file_type: str = "h5py",
):
super().__init__(data, return_labels, file_type)
super().__init__(data, False, return_labels, file_type)

def _fetch_data_from_array(self, idx: int) -> Iterable:
return super()._fetch_data_from_array(idx)
112 changes: 64 additions & 48 deletions pypots/data/base.py
@@ -11,8 +11,11 @@
import h5py
import numpy as np
import torch
from pygrinder import fill_and_get_mask_torch
from torch.utils.data import Dataset

from .utils import turn_data_into_specified_dtype

# Currently we only support h5 files
SUPPORTED_DATASET_FILE_TYPE = ["h5py"]
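The newly imported turn_data_into_specified_dtype centralizes the list/ndarray/tensor conversions that _check_input used to inline (visible as the removed block further down). A minimal sketch of what the helper in pypots/data/utils presumably does; the real implementation may differ:

import numpy as np
import torch

def turn_data_into_specified_dtype(data, dtype: str = "tensor"):
    # accept list / np.ndarray / torch.Tensor and return the requested container
    if isinstance(data, torch.Tensor):
        return data if dtype == "tensor" else data.numpy()
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data) if dtype == "tensor" else data
    if isinstance(data, list):
        return torch.tensor(data) if dtype == "tensor" else np.asarray(data)
    raise TypeError(f"data should be list/np.ndarray/torch.Tensor, but got {type(data)}")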

@@ -48,21 +51,23 @@ class BaseDataset(Dataset):
def __init__(
self,
data: Union[dict, str],
return_labels: bool = True,
return_X_ori: bool,
return_labels: bool,
file_type: str = "h5py",
):
super().__init__()
# types and shapes have already been checked when X and y were input into the model
# So they are safe to use here. No need to check again.

self.data = data
self.return_X_ori = return_X_ori
self.return_labels = return_labels

if isinstance(self.data, str): # data from file
# check if the given file type is supported
assert (
file_type in SUPPORTED_DATASET_FILE_TYPE
), f"file_type should be one of {SUPPORTED_DATASET_FILE_TYPE}, but got {file_type}"

self.file_type = file_type

# open the file handle
@@ -74,8 +79,23 @@ def __init__(

else: # data from array
X = data["X"]
X_ori = None if "X_ori" not in data.keys() else data["X_ori"]
y = None if "y" not in data.keys() else data["y"]
self.X, self.y = self._check_input(X, y)
self.X, self.X_ori, self.y = self._check_array_input(X, X_ori, y)

if self.X_ori is not None and self.return_X_ori:
# Only when X_ori is given and fixed do we fill the missing values in X here in advance.
# Otherwise, we may need the original X with missing values to generate X_ori, e.g. in DatasetForSAITS.
self.X, self.missing_mask = fill_and_get_mask_torch(self.X)

self.X_ori, X_ori_missing_mask = fill_and_get_mask_torch(self.X_ori)
indicating_mask = X_ori_missing_mask - self.missing_mask
self.indicating_mask = indicating_mask.to(torch.float32)
else:
self.missing_mask = None
self.indicating_mask = None
# if return_X_ori is False, set X_ori to None as well
self.X_ori = None

self.n_samples, self.n_steps, self.n_features = self._get_data_sizes()
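The mask arithmetic above is worth unpacking: fill_and_get_mask_torch returns the NaN-filled tensor plus a mask that is 1 where a value is observed and 0 where it is missing (the convention implied by the (~torch.isnan(X)) code it replaces). Subtracting the two masks therefore leaves 1 exactly where a value exists in X_ori but was artificially removed from X. A toy example:

import torch
from pygrinder import fill_and_get_mask_torch  # hence the pygrinder >= 0.4 bump

X_ori = torch.tensor([[1.0, 2.0, float("nan")]])       # last value originally missing
X = torch.tensor([[1.0, float("nan"), float("nan")]])  # 2.0 artificially masked out

X, missing_mask = fill_and_get_mask_torch(X)        # missing_mask -> [[1., 0., 0.]]
X_ori, ori_mask = fill_and_get_mask_torch(X_ori)    # ori_mask     -> [[1., 1., 0.]]
indicating_mask = ori_mask - missing_mask           # [[0., 1., 0.]]: the artificial miss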

@@ -112,12 +132,14 @@ def __len__(self) -> int:
return self.n_samples

@staticmethod
def _check_input(
def _check_array_input(
X: Union[np.ndarray, torch.Tensor, list],
X_ori: Union[np.ndarray, torch.Tensor, list],
y: Optional[Union[np.ndarray, torch.Tensor, list]] = None,
out_dtype: str = "tensor",
) -> Tuple[
Union[np.ndarray, torch.Tensor, list],
Union[np.ndarray, torch.Tensor],
Union[np.ndarray, torch.Tensor],
Optional[Union[np.ndarray, torch.Tensor, list]],
]:
"""Check value type and shape of input X and y
@@ -127,6 +149,10 @@ def _check_input(
X :
Time-series data that must have a shape like [n_samples, expected_n_steps, expected_n_features].
X_ori :
If X contains artificial missingness, X_ori is the original X without the artificial missing values.
It must have the same shape as X. Values originally missing in X_ori should be left as NaN.
y :
Labels of time-series samples (X) that must have a shape like [n_samples] or [n_samples, n_classes].
@@ -137,6 +163,8 @@ def _check_input(
-------
X :
X_ori :
y :
"""
@@ -145,55 +173,29 @@ def _check_input(
"ndarray",
], f'out_dtype should be "tensor" or "ndarray", but got {out_dtype}'

is_list = isinstance(X, list)
is_array = isinstance(X, np.ndarray)
is_tensor = isinstance(X, torch.Tensor)
assert is_tensor or is_array or is_list, TypeError(
"X should be an instance of list/np.ndarray/torch.Tensor, "
f"but got {type(X)}"
)

# convert the data type if in need
if out_dtype == "tensor":
if is_list:
X = torch.tensor(X)
elif is_array:
X = torch.from_numpy(X)
else: # is tensor
pass
else: # out_dtype is ndarray
# convert to np.ndarray first for shape check
if is_list:
X = np.asarray(X)
elif is_tensor:
X = X.numpy()
else: # is ndarray
pass
# change the data type of X
X = turn_data_into_specified_dtype(X, out_dtype)
X = X.to(torch.float32)

# check the shape of X here
X_shape = X.shape
assert len(X_shape) == 3, (
f"input should have 3 dimensions [n_samples, seq_len, n_features],"
f"but got shape={X_shape}"
f"but got X: {X_shape}"
)

if X_ori is not None:
X_ori = turn_data_into_specified_dtype(X_ori, out_dtype)
X_ori = X_ori.to(torch.float32)
assert (
X_shape == X_ori.shape
), f"X and X_ori must have matched shape, but got X: f{X.shape} and X_ori: {X_ori.shape}"
if y is not None:
assert len(X) == len(y), (
f"lengths of X and y must match, " f"but got f{len(X)} and {len(y)}"
)
if isinstance(y, torch.Tensor):
y = y if out_dtype == "tensor" else y.numpy()
elif isinstance(y, list):
y = torch.tensor(y) if out_dtype == "tensor" else np.asarray(y)
elif isinstance(y, np.ndarray):
y = torch.from_numpy(y) if out_dtype == "tensor" else y
else:
raise TypeError(
"y should be an instance of list/np.ndarray/torch.Tensor, "
f"but got {type(y)}"
)
y = turn_data_into_specified_dtype(y, out_dtype)

return X, y
return X, X_ori, y

@abstractmethod
def _fetch_data_from_array(self, idx: int) -> Iterable:
@@ -210,15 +212,24 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
The collated data sample, a list including all necessary sample info.
"""

X = self.X[idx].to(torch.float32)
missing_mask = (~torch.isnan(X)).to(torch.float32)
X = torch.nan_to_num(X)
if self.X_ori is None:
X = self.X[idx]
X, missing_mask = fill_and_get_mask_torch(X)
else:
X = self.X[idx]
missing_mask = self.missing_mask[idx]

sample = [
torch.tensor(idx),
X,
missing_mask,
]

if self.X_ori is not None and self.return_X_ori:
X_ori = self.X_ori[idx]
indicating_mask = self.indicating_mask[idx]
sample.extend([X_ori, indicating_mask])

if self.y is not None and self.return_labels:
sample.append(self.y[idx].to(torch.long))

@@ -286,14 +297,19 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
self.file_handle = self._open_file_handle()

X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32)
missing_mask = (~torch.isnan(X)).to(torch.float32)
X = torch.nan_to_num(X)
X, missing_mask = fill_and_get_mask_torch(X)
sample = [
torch.tensor(idx),
X,
missing_mask,
]

if "X_ori" in self.file_handle.keys() and self.return_X_ori:
X_ori = torch.from_numpy(self.file_handle["X_ori"][idx]).to(torch.float32)
X_ori, X_ori_missing_mask = fill_and_get_mask_torch(X_ori)
indicating_mask = (X_ori_missing_mask - missing_mask).to(torch.float32)
sample.extend([X_ori, indicating_mask])

# if the dataset has labels and is for training, then fetch it from the file
if "y" in self.file_handle.keys() and self.return_labels:
sample.append(torch.tensor(self.file_handle["y"][idx]).to(torch.long))
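Both fetch paths now build a variable-length sample list: [idx, X, missing_mask], extended with [X_ori, indicating_mask] when return_X_ori is on, and with y when labels are returned, so the _assemble_input_for_* methods must unpack in exactly that order. A hedged consumption sketch:

from torch.utils.data import DataLoader

# assuming `dataset` is a concrete BaseDataset subclass constructed with
# return_X_ori=True and return_labels=False
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for indices, X, missing_mask, X_ori, indicating_mask in loader:
    ...  # e.g. score imputation only where indicating_mask == 1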
21 changes: 21 additions & 0 deletions pypots/data/checking.py
@@ -0,0 +1,21 @@
"""
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from typing import Union

import h5py


def check_X_ori_in_val_set(val_set: Union[str, dict]) -> bool:
if isinstance(val_set, str):
with h5py.File(val_set, "r") as f:
return "X_ori" in f.keys()
elif isinstance(val_set, dict):
return "X_ori" in val_set.keys()
else:
raise TypeError("val_set must be a str or a Python dictionary.")
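
A usage sketch for the new helper; presumably the trainers call it to decide whether the validation stage can score imputation accuracy (the exact call sites are outside this excerpt):

from pypots.data.checking import check_X_ori_in_val_set

val_set = {"X": X_val, "X_ori": X_val_ori}  # or a path to an h5 file
if check_X_ori_in_val_set(val_set):
    ...  # evaluate imputation error on the artificially-masked entries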