Skip to content

Commit

Permalink
Merge pull request #274 from WenjieDu/temp_test_branch
Browse files Browse the repository at this point in the history
Renaming X_intact into X_ori, and adding matplotlib as a dependency
  • Loading branch information
WenjieDu authored Dec 18, 2023
2 parents d679d73 + a28824c commit 34e258a
Show file tree
Hide file tree
Showing 50 changed files with 228 additions and 249 deletions.
3 changes: 2 additions & 1 deletion environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ dependencies:
#- conda-forge::pandas <2.0.0
#- conda-forge::h5py
#- conda-forge::tensorboard
#- conda-forge::pygrinder >=0.2
#- conda-forge::pygrinder >=0.4
#- conda-forge::tsdb >=0.2
#- conda-forge::matplotlib
#- pytorch::pytorch >=1.10.0
## Below we install the latest pypots because we need pypots-cli in it for development.
## PyPOTS itself already includes all basic dependencies.
Expand Down
66 changes: 31 additions & 35 deletions pypots/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class BaseDataset(Dataset):
def __init__(
self,
data: Union[dict, str],
return_X_intact: bool,
return_X_ori: bool,
return_labels: bool,
file_type: str = "h5py",
):
Expand All @@ -60,7 +60,7 @@ def __init__(
# So they are safe to use here. No need to check again.

self.data = data
self.return_X_intact = return_X_intact
self.return_X_ori = return_X_ori
self.return_labels = return_labels

if isinstance(self.data, str): # data from file
Expand All @@ -79,25 +79,23 @@ def __init__(

else: # data from array
X = data["X"]
X_intact = None if "X_intact" not in data.keys() else data["X_intact"]
X_ori = None if "X_ori" not in data.keys() else data["X_ori"]
y = None if "y" not in data.keys() else data["y"]
self.X, self.X_intact, self.y = self._check_array_input(X, X_intact, y)
self.X, self.X_ori, self.y = self._check_array_input(X, X_ori, y)

if self.X_intact is not None and self.return_X_intact:
# Only when X_intact is given and fixed, we fill the missing values in X here in advance.
# Otherwise, we may need original X with missing values to generate X_intact, e.g. in DatasetForSAITS.
if self.X_ori is not None and self.return_X_ori:
# Only when X_ori is given and fixed, we fill the missing values in X here in advance.
# Otherwise, we may need original X with missing values to generate X_ori, e.g. in DatasetForSAITS.
self.X, self.missing_mask = fill_and_get_mask_torch(self.X)

self.X_intact, X_intact_missing_mask = fill_and_get_mask_torch(
self.X_intact
)
indicating_mask = X_intact_missing_mask - self.missing_mask
self.X_ori, X_ori_missing_mask = fill_and_get_mask_torch(self.X_ori)
indicating_mask = X_ori_missing_mask - self.missing_mask
self.indicating_mask = indicating_mask.to(torch.float32)
else:
self.missing_mask = None
self.indicating_mask = None
# if return_X_intact is false, set X_intact to None as well
self.X_intact = None
# if return_X_ori is false, set X_ori to None as well
self.X_ori = None

self.n_samples, self.n_steps, self.n_features = self._get_data_sizes()

Expand Down Expand Up @@ -136,7 +134,7 @@ def __len__(self) -> int:
@staticmethod
def _check_array_input(
X: Union[np.ndarray, torch.Tensor, list],
X_intact: Union[np.ndarray, torch.Tensor, list],
X_ori: Union[np.ndarray, torch.Tensor, list],
y: Optional[Union[np.ndarray, torch.Tensor, list]] = None,
out_dtype: str = "tensor",
) -> Tuple[
Expand All @@ -151,9 +149,9 @@ def _check_array_input(
X :
Time-series data that must have a shape like [n_samples, expected_n_steps, expected_n_features].
X_intact :
If X is with artificial missingness, X_intact is the original X without artificial missing values.
It must have the same shape as X. If X_intact is with original missing values, should be left as NaN.
X_ori :
If X is with artificial missingness, X_ori is the original X without artificial missing values.
It must have the same shape as X. If X_ori is with original missing values, should be left as NaN.
y :
Labels of time-series samples (X) that must have a shape like [n_samples] or [n_samples, n_classes].
Expand All @@ -165,7 +163,7 @@ def _check_array_input(
-------
X :
X_intact :
X_ori :
y :
Expand All @@ -185,19 +183,19 @@ def _check_array_input(
f"input should have 3 dimensions [n_samples, seq_len, n_features],"
f"but got X: {X_shape}"
)
if X_intact is not None:
X_intact = turn_data_into_specified_dtype(X_intact, out_dtype)
X_intact = X_intact.to(torch.float32)
if X_ori is not None:
X_ori = turn_data_into_specified_dtype(X_ori, out_dtype)
X_ori = X_ori.to(torch.float32)
assert (
X_shape == X_intact.shape
), f"X and X_intact must have matched shape, but got X: f{X.shape} and X_intact: {X_intact.shape}"
X_shape == X_ori.shape
), f"X and X_ori must have matched shape, but got X: f{X.shape} and X_ori: {X_ori.shape}"
if y is not None:
assert len(X) == len(y), (
f"lengths of X and y must match, " f"but got f{len(X)} and {len(y)}"
)
y = turn_data_into_specified_dtype(y, out_dtype)

return X, X_intact, y
return X, X_ori, y

@abstractmethod
def _fetch_data_from_array(self, idx: int) -> Iterable:
Expand All @@ -214,7 +212,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
The collated data sample, a list including all necessary sample info.
"""

if self.X_intact is None:
if self.X_ori is None:
X = self.X[idx]
X, missing_mask = fill_and_get_mask_torch(X)
else:
Expand All @@ -227,10 +225,10 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
missing_mask,
]

if self.X_intact is not None and self.return_X_intact:
X_intact = self.X_intact[idx]
if self.X_ori is not None and self.return_X_ori:
X_ori = self.X_ori[idx]
indicating_mask = self.indicating_mask[idx]
sample.extend([X_intact, indicating_mask])
sample.extend([X_ori, indicating_mask])

if self.y is not None and self.return_labels:
sample.append(self.y[idx].to(torch.long))
Expand Down Expand Up @@ -306,13 +304,11 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
missing_mask,
]

if "X_intact" in self.file_handle.keys() and self.return_X_intact:
X_intact = torch.from_numpy(self.file_handle["X_intact"][idx]).to(
torch.float32
)
X_intact, X_intact_missing_mask = fill_and_get_mask_torch(X_intact)
indicating_mask = (X_intact_missing_mask - missing_mask).to(torch.float32)
sample.extend([X_intact, indicating_mask])
if "X_ori" in self.file_handle.keys() and self.return_X_ori:
X_ori = torch.from_numpy(self.file_handle["X_ori"][idx]).to(torch.float32)
X_ori, X_ori_missing_mask = fill_and_get_mask_torch(X_ori)
indicating_mask = (X_ori_missing_mask - missing_mask).to(torch.float32)
sample.extend([X_ori, indicating_mask])

# if the dataset has labels and is for training, then fetch it from the file
if "y" in self.file_handle.keys() and self.return_labels:
Expand Down
6 changes: 3 additions & 3 deletions pypots/data/checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import h5py


def check_x_intact_in_val_set(val_set: Union[str, dict]) -> bool:
def check_X_ori_in_val_set(val_set: Union[str, dict]) -> bool:
if isinstance(val_set, str):
with h5py.File(val_set, "r") as f:
return "X_intact" in f.keys()
return "X_ori" in f.keys()
elif isinstance(val_set, dict):
return "X_intact" in val_set.keys()
return "X_ori" in val_set.keys()
else:
raise TypeError("val_set must be a str or a Python dictionary.")
24 changes: 12 additions & 12 deletions pypots/data/generating.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,21 +303,21 @@ def gene_random_walk(

if missing_rate > 0:
# mask values in the validation set as ground truth
val_X_intact = val_X
val_X_ori = val_X
val_X = mcar(val_X, missing_rate)

# mask values in the test set as ground truth
test_X_intact = test_X
test_X_ori = test_X
test_X = mcar(test_X, 0.3)

data["val_X"] = val_X
data["val_X_intact"] = val_X_intact
data["val_X_ori"] = val_X_ori

# test_X is for model input
data["test_X"] = test_X
# test_X_intact is for error calc, not for model input, hence mustn't have NaNs
data["test_X_intact"] = np.nan_to_num(test_X_intact)
data["test_X_indicating_mask"] = ~np.isnan(test_X_intact) ^ ~np.isnan(test_X)
# test_X_ori is for error calc, not for model input, hence mustn't have NaNs
data["test_X_ori"] = np.nan_to_num(test_X_ori)
data["test_X_indicating_mask"] = ~np.isnan(test_X_ori) ^ ~np.isnan(test_X)

return data

Expand Down Expand Up @@ -410,19 +410,19 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1):

if artificially_missing_rate > 0:
# mask values in the validation set as ground truth
val_X_intact = val_X
val_X_ori = val_X
val_X = mcar(val_X, artificially_missing_rate)
# mask values in the test set as ground truth
test_X_intact = test_X
test_X_ori = test_X
test_X = mcar(test_X, artificially_missing_rate)

data["val_X"] = val_X
data["val_X_intact"] = val_X_intact
data["val_X_ori"] = val_X_ori

# test_X is for model input
data["test_X"] = test_X
# test_X_intact is for error calc, not for model input, hence mustn't have NaNs
data["test_X_intact"] = np.nan_to_num(test_X_intact)
data["test_X_indicating_mask"] = ~np.isnan(test_X_intact) ^ ~np.isnan(test_X)
# test_X_ori is for error calc, not for model input, hence mustn't have NaNs
data["test_X_ori"] = np.nan_to_num(test_X_ori)
data["test_X_indicating_mask"] = ~np.isnan(test_X_ori) ^ ~np.isnan(test_X)

return data
2 changes: 1 addition & 1 deletion pypots/imputation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def _train_model(
imputation_mse = (
calc_mse(
results["imputed_data"],
inputs["X_intact"],
inputs["X_ori"],
inputs["indicating_mask"],
)
.sum()
Expand Down
20 changes: 9 additions & 11 deletions pypots/imputation/brits/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ class DatasetForBRITS(BaseDataset):
def __init__(
self,
data: Union[dict, str],
return_X_intact: bool,
return_X_ori: bool,
return_labels: bool,
file_type: str = "h5py",
):
super().__init__(data, return_X_intact, return_labels, file_type)
super().__init__(data, return_X_ori, return_labels, file_type)

if not isinstance(self.data, str):
# calculate all delta here.
if self.X_intact is None:
if self.X_ori is None:
forward_X, forward_missing_mask = fill_and_get_mask_torch(self.X)
else:
forward_missing_mask = self.missing_mask
Expand Down Expand Up @@ -116,8 +116,8 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
self.processed_data["backward"]["delta"][idx],
]

if self.X_intact is not None and self.return_X_intact:
sample.extend([self.X_intact[idx], self.indicating_mask[idx]])
if self.X_ori is not None and self.return_X_ori:
sample.extend([self.X_ori[idx], self.indicating_mask[idx]])

if self.y is not None and self.return_labels:
sample.append(self.y[idx].to(torch.long))
Expand Down Expand Up @@ -169,12 +169,10 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
backward["deltas"],
]

if "X_intact" in self.file_handle.keys() and self.return_X_intact:
X_intact = torch.from_numpy(self.file_handle["X_intact"][idx]).to(
torch.float32
)
X_intact, indicating_mask = fill_and_get_mask_torch(X_intact)
sample.extend([X_intact, indicating_mask])
if "X_ori" in self.file_handle.keys() and self.return_X_ori:
X_ori = torch.from_numpy(self.file_handle["X_ori"][idx]).to(torch.float32)
X_ori, indicating_mask = fill_and_get_mask_torch(X_ori)
sample.extend([X_ori, indicating_mask])

# if the dataset has labels and is for training, then fetch it from the file
if "y" in self.file_handle.keys() and self.return_labels:
Expand Down
18 changes: 8 additions & 10 deletions pypots/imputation/brits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .data import DatasetForBRITS
from .modules import _BRITS
from ..base import BaseNNImputer
from ...data.checking import check_x_intact_in_val_set
from ...data.checking import check_X_ori_in_val_set
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger
Expand Down Expand Up @@ -165,7 +165,7 @@ def _assemble_input_for_validating(self, data: list) -> dict:
back_X,
back_missing_mask,
back_deltas,
X_intact,
X_ori,
indicating_mask,
) = self._send_data_to_given_device(data)

Expand All @@ -182,7 +182,7 @@ def _assemble_input_for_validating(self, data: list) -> dict:
"missing_mask": back_missing_mask,
"deltas": back_deltas,
},
"X_intact": X_intact,
"X_ori": X_ori,
"indicating_mask": indicating_mask,
}
return inputs
Expand All @@ -198,7 +198,7 @@ def fit(
) -> None:
# Step 1: wrap the input data with classes Dataset and DataLoader
training_set = DatasetForBRITS(
train_set, return_X_intact=False, return_labels=False, file_type=file_type
train_set, return_X_ori=False, return_labels=False, file_type=file_type
)
training_loader = DataLoader(
training_set,
Expand All @@ -208,12 +208,10 @@ def fit(
)
val_loader = None
if val_set is not None:
if not check_x_intact_in_val_set(val_set):
raise ValueError(
"val_set must contain 'X_intact' for model validation."
)
if not check_X_ori_in_val_set(val_set):
raise ValueError("val_set must contain 'X_ori' for model validation.")
val_set = DatasetForBRITS(
val_set, return_X_intact=True, return_labels=False, file_type=file_type
val_set, return_X_ori=True, return_labels=False, file_type=file_type
)
val_loader = DataLoader(
val_set,
Expand All @@ -237,7 +235,7 @@ def predict(
) -> dict:
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForBRITS(
test_set, return_X_intact=False, return_labels=False, file_type=file_type
test_set, return_X_ori=False, return_labels=False, file_type=file_type
)
test_loader = DataLoader(
test_set,
Expand Down
Loading

0 comments on commit 34e258a

Please sign in to comment.