feat: add degraded only option (#48)
* feat: add degraded only option

* docs: fix lightning cross-ref
tilman151 authored Dec 22, 2023
1 parent 4e4bd6b commit 44340f7
Showing 9 changed files with 143 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docs/index.md
@@ -1,7 +1,7 @@
# RUL Datasets

This library contains a collection of common benchmark datasets for **remaining useful lifetime (RUL)** estimation.
They are provided as [LightningDataModules][pytorch_lightning.core.LightningDataModule] to be readily used in [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/).
They are provided as [LightningDataModules][lightning.pytorch.core.LightningDataModule] to be readily used in [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/).

Currently, four datasets are supported:

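For orientation, here is a minimal sketch (not part of this commit) of how one of these data modules feeds a Lightning `Trainer`. The `DummyRegressor` model and the flattened input size of 30 × 14 (the CMAPSS FD1 defaults) are illustrative assumptions only.

```python
import pytorch_lightning as pl  # same namespace the library itself imports
import torch
import rul_datasets


class DummyRegressor(pl.LightningModule):
    """Hypothetical stand-in model; any LightningModule with matching input works."""

    def __init__(self):
        super().__init__()
        # assumes windows of shape [window_size=30, features=14]
        self.net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(30 * 14, 1))

    def training_step(self, batch, batch_idx):
        features, targets = batch
        preds = self.net(features).squeeze(-1)
        return torch.nn.functional.mse_loss(preds, targets)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


reader = rul_datasets.reader.CmapssReader(fd=1)
dm = rul_datasets.RulDataModule(reader, batch_size=32)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(DummyRegressor(), datamodule=dm)
```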
2 changes: 1 addition & 1 deletion docs/use_cases/libraries.md
@@ -2,7 +2,7 @@

This library was developed to be used in [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/) first and foremost.
Lightning helps writing clean and reproducible deep learning code that can run on most common training hardware.
Datasets are represented by [LightningDataModules][pytorch_lightning.core.LightningDataModule] which give access to data loaders for each data split.
Datasets are represented by [LightningDataModules][lightning.pytorch.core.LightningDataModule] which give access to data loaders for each data split.
The RUL Datasets library implements several data modules that are 100% compatible with Lightning:

```python
# (example truncated in the diff view)
```
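The docs example above is truncated in the diff view; as a stand-alone sketch (assuming CMAPSS FD1 and a batch size of 32), the split-specific data loaders can be accessed like this:

```python
import rul_datasets

reader = rul_datasets.reader.CmapssReader(fd=1)
dm = rul_datasets.RulDataModule(reader, batch_size=32)
dm.prepare_data()  # download and pre-process the dataset if necessary
dm.setup()         # load the splits into memory
features, targets = next(iter(dm.train_dataloader()))
print(features.shape, targets.shape)  # e.g. [32, 30, 14] windows and [32] RUL targets
```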
21 changes: 19 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rul_datasets"
version = "0.0.0"
version = "0.11.3"
description = "A collection of datasets for RUL estimation as Lightning Data Modules."
authors = ["Krokotsch, Tilman <tilman.krokotsch@tu-berlin.de>"]
license = "MIT"
@@ -25,6 +25,7 @@ flake8 = "^5.0.4"
mypy = "^1.0.0"
hydra-core = "^1.1.1"
pytest = "^7.1.3"
pytest-mock = "^3.12.0"

[tool.poetry.group.docs]
optional = true
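The new `pytest-mock` dev dependency provides the `mocker` fixture that the added tests below rely on. A minimal, self-contained sketch of the spy pattern they use (the `Greeter` class is purely illustrative):

```python
class Greeter:
    def greet(self, name):
        return f"hello {name}"


def test_spy_records_calls(mocker):  # `mocker` is the pytest-mock fixture
    greeter = Greeter()
    spy = mocker.spy(greeter, "greet")  # wraps the real method, keeps its behavior

    assert greeter.greet("world") == "hello world"
    spy.assert_called_once_with("world")
```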
4 changes: 2 additions & 2 deletions rul_datasets/adaption.py
@@ -16,7 +16,7 @@

class DomainAdaptionDataModule(pl.LightningDataModule):
"""
A higher-order [data module][pytorch_lightning.core.LightningDataModule] used for
A higher-order [data module][lightning.pytorch.core.LightningDataModule] used for
unsupervised domain adaption of a labeled source to an unlabeled target domain.
The training data of both domains is wrapped in a [AdaptionDataset]
[rul_datasets.adaption.AdaptionDataset] which provides a random sample of the
@@ -217,7 +217,7 @@ def _get_paired_dataset(self) -> PairedRulDataset:

class LatentAlignDataModule(DomainAdaptionDataModule):
"""
A higher-order [data module][pytorch_lightning.core.LightningDataModule] based on
A higher-order [data module][lightning.pytorch.core.LightningDataModule] based on
[DomainAdaptionDataModule][rul_datasets.adaption.DomainAdaptionDataModule].
It is specifically made to work with the latent space alignment approach by Zhang
2 changes: 1 addition & 1 deletion rul_datasets/baseline.py
@@ -13,7 +13,7 @@

class BaselineDataModule(pl.LightningDataModule):
"""
A higher-order [data module][pytorch_lightning.core.LightningDataModule] that
A higher-order [data module][lightning.pytorch.core.LightningDataModule] that
takes a [RulDataModule][rul_datasets.core.RulDataModule]. It provides the
training and validation splits of the sub-dataset selected in the underlying data
module but provides the test splits of all available subsets of the dataset. This
64 changes: 55 additions & 9 deletions rul_datasets/core.py
@@ -1,7 +1,7 @@
"""Basic data modules for experiments involving only a single subset of any RUL
dataset. """

from typing import Dict, List, Optional, Tuple, Any, Callable, cast, Union
from typing import Dict, List, Optional, Tuple, Any, Callable, cast, Union, Literal

import numpy as np
import pytorch_lightning as pl
@@ -14,7 +14,7 @@

class RulDataModule(pl.LightningDataModule):
"""
A [data module][pytorch_lightning.core.LightningDataModule] to provide windowed
A [data module][lightning.pytorch.core.LightningDataModule] to provide windowed
time series features with RUL targets. It exposes the splits of the underlying
dataset for easy usage with PyTorch and PyTorch Lightning.
@@ -52,6 +52,12 @@ class RulDataModule(pl.LightningDataModule):
... feature_extractor=lambda x: np.mean(x, axis=1),
... window_size=10
... )
Only Degraded Validation and Test Samples
>>> import rul_datasets
>>> cmapss = rul_datasets.reader.CmapssReader(fd=1)
>>> dm = rul_datasets.RulDataModule(cmapss, 32, degraded_only=["val", "test"])
"""

_data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]
@@ -62,13 +68,14 @@ def __init__(
batch_size: int,
feature_extractor: Optional[Callable] = None,
window_size: Optional[int] = None,
degraded_only: Optional[List[Literal["dev", "val", "test"]]] = None,
):
"""
Create a new RUL data module from a reader.
This data module exposes a training, validation and test data loader for the
underlying dataset. First, `prepare_data` is called to download and
pre-process the dataset. Afterwards, `setup_data` is called to load all
pre-process the dataset. Afterward, `setup_data` is called to load all
splits into memory.
If a `feature_extractor` is supplied, the data module extracts new features
@@ -85,18 +92,21 @@ def __init__(
`[num_windows, window_size, features]`.
Args:
reader: The dataset reader for the desired dataset, e.g. CmapssLoader.
batch_size: The size of the batches build by the data loaders.
reader: The dataset reader for the desired dataset, e.g., CmapssLoader.
batch_size: The size of the batches built by the data loaders.
feature_extractor: A feature extractor that extracts feature vectors from
windows.
window_size: The new window size to apply after the feature extractor.
degraded_only: Whether to load only degraded samples for the `dev`, `val`,
or `test` split.
"""
super().__init__()

self._reader: AbstractReader = reader
self.batch_size = batch_size
self.feature_extractor = feature_extractor
self.window_size = window_size
self.degraded_only = degraded_only

if (self.feature_extractor is None) and (self.window_size is not None):
raise ValueError(
@@ -111,6 +121,7 @@ def __init__(
str(self.feature_extractor) if self.feature_extractor else None
),
"window_size": self.window_size,
"degraded_only": self.degraded_only,
}
self.save_hyperparameters(hparams)

@@ -217,24 +228,40 @@ def setup(self, stage: Optional[str] = None) -> None:
}

def load_split(
self, split: str, alias: Optional[str] = None
self,
split: str,
alias: Optional[str] = None,
degraded_only: Optional[bool] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Load a split from the underlying reader and apply the feature extractor.
By setting alias, it is possible to load a split aliased as another split,
e.g. load the test split and treat it as the dev split. The data of the split is
loaded but all pre-processing steps of alias are carried out.
e.g., load the test split and treat it as the dev split. The data of the split
is loaded, but all pre-processing steps of alias are carried out.
If `degraded_only` is set, only degraded samples are loaded. This is only
possible if the underlying reader has a `max_rul` set or `norm_rul` is set to
`True`. The `degraded_only` argument takes precedence over the `degraded_only`
of the data module.
Args:
split: The desired split to load.
alias: The split as which the loaded data should be treated.
degraded_only: Whether to only load degraded samples.
Returns:
The feature and target tensors of the split's runs.
"""
features, targets = self.reader.load_split(split, alias)
features, targets = self._apply_feature_extractor_per_run(features, targets)
tensor_features, tensor_targets = utils.to_tensor(features, targets)
if degraded_only is None:
degraded_only = (
self.degraded_only is not None
and (alias or split) in self.degraded_only
)
if degraded_only:
self._filter_out_healthy(tensor_features, tensor_targets)

return tensor_features, tensor_targets
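Not part of the diff, but for illustration of the precedence rule above, calls to the extended `load_split` might look like this (with `dm` an `RulDataModule` whose reader defines `max_rul`):

```python
# Load the validation split and keep only degraded samples, regardless of how
# `degraded_only` was configured on the data module itself.
val_features, val_targets = dm.load_split("val", degraded_only=True)

# Load the test split but pre-process it as if it were the dev split; the
# module-level `degraded_only` setting decides whether healthy samples are kept.
dev_like_features, dev_like_targets = dm.load_split("test", alias="dev")
```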

@@ -269,6 +296,21 @@ def _extract_and_window(

return features, targets

def _filter_out_healthy(self, tensor_features, tensor_targets):
if self.reader.max_rul is not None:
thresh = self.reader.max_rul
elif hasattr(self.reader, "norm_rul") and self.reader.norm_rul:
thresh = 1.0
else:
raise ValueError(
"Cannot filter degraded samples if no max_rul is set and "
"norm_rul is False."
)
for i in range(len(tensor_targets)):
degraded = tensor_targets[i] < thresh
tensor_features[i] = tensor_features[i][degraded]
tensor_targets[i] = tensor_targets[i][degraded]

def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
"""
Create a [data loader][torch.utils.data.DataLoader] for the training split.
@@ -383,6 +425,7 @@ def __init__(
min_distance: int,
deterministic: bool = False,
mode: str = "linear",
degraded_only: bool = False,
):
super().__init__()

@@ -392,6 +435,7 @@ def __init__(
self.num_samples = num_samples
self.deterministic = deterministic
self.mode = mode
self.degraded_only = degraded_only

for dm in self.dms:
dm.check_compatibility(self.dms[0])
@@ -430,7 +474,9 @@ def _prepare_datasets(self):
features = []
labels = []
for domain_idx, dm in enumerate(self.dms):
run_features, run_labels = dm.load_split(self.split)
run_features, run_labels = dm.load_split(
self.split, degraded_only=self.degraded_only
)
for feat, lab in zip(run_features, run_labels):
if len(feat) > self.min_distance:
run_domain_idx.append(domain_idx)
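A toy illustration (not from the library) of the threshold rule in `_filter_out_healthy`: the threshold is `max_rul` when the reader defines one, or `1.0` when RUL values are normalized, and only targets strictly below it count as degraded.

```python
import torch

max_rul = 125
targets = torch.tensor([125.0, 125.0, 124.0, 60.0, 1.0])
degraded = targets < max_rul  # the healthy plateau at max_rul is filtered out
print(degraded)               # tensor([False, False,  True,  True,  True])

# The same option is exposed on PairedRulDataset, which forwards it to load_split,
# e.g. PairedRulDataset([dm], "dev", 1000, 1, degraded_only=True).
```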
2 changes: 1 addition & 1 deletion rul_datasets/ssl.py
@@ -11,7 +11,7 @@

class SemiSupervisedDataModule(pl.LightningDataModule):
"""
A higher-order [data module][pytorch_lightning.core.LightningDataModule] used for
A higher-order [data module][lightning.pytorch.core.LightningDataModule] used for
semi-supervised learning with a labeled data module and an unlabeled one. It
makes sure that both data modules come from the same sub-dataset.
65 changes: 61 additions & 4 deletions tests/test_core.py
@@ -37,6 +37,7 @@ def test_created_correctly(self, mock_loader):
"batch_size": 16,
"window_size": None,
"feature_extractor": None,
"degraded_only": None,
}

@pytest.mark.parametrize("window_size", [2, None])
@@ -53,8 +54,18 @@ def test_created_correctly_with_feature_extractor(self, mock_loader, window_size
"batch_size": 16,
"window_size": window_size,
"feature_extractor": str(fe),
"degraded_only": None,
}

@pytest.mark.parametrize("degraded_only", [None, ["dev"], ["val", "test"]])
def test_created_correctly_with_degraded_only(self, mock_loader, degraded_only):
dataset = core.RulDataModule(
mock_loader, batch_size=16, degraded_only=degraded_only
)

assert "degraded_only" in dataset.hparams
assert dataset.hparams["degraded_only"] == degraded_only

def test_prepare_data(self, mock_loader):
dataset = core.RulDataModule(mock_loader, batch_size=16)
dataset.prepare_data()
@@ -71,6 +82,36 @@ def test_setup(self, mock_loader, mock_runs):
mock_runs = tuple(torch.tensor(np.concatenate(r)) for r in mock_runs)
assert dataset._data == {"dev": mock_runs, "val": mock_runs, "test": mock_runs}

@pytest.mark.parametrize("split", ["dev", "val", "test"])
def test_load_split_degraded_only(self, mock_loader, mocker, split):
mock_loader.max_rul = 125
dataset = core.RulDataModule(mock_loader, batch_size=16, degraded_only=[split])
spy__filter_out_healthy = mocker.spy(dataset, "_filter_out_healthy")

dataset.load_split(split)
spy__filter_out_healthy.assert_called()
spy__filter_out_healthy.reset_mock()
for other_split in {"dev", "val", "test"} - {split}:
dataset.load_split(other_split)
spy__filter_out_healthy.assert_not_called()

@pytest.mark.parametrize("degraded_only", [True, False])
def test_load_split_degraded_only_override(
self, mock_loader, mocker, degraded_only
):
mock_loader.max_rul = 125
dataset = core.RulDataModule(
mock_loader, batch_size=16, degraded_only=None if degraded_only else ["dev"]
)
spy__filter_out_healthy = mocker.spy(dataset, "_filter_out_healthy")

dataset.load_split("dev", degraded_only=degraded_only)

if degraded_only:
spy__filter_out_healthy.assert_called()
else:
spy__filter_out_healthy.assert_not_called()

def test_empty_dataset(self, mock_loader):
"""Should not crash on empty dataset."""
mock_loader.load_split.return_value = [], []
@@ -248,21 +289,23 @@ def test_feature_extractor_no_rewindowing(self, mock_loader):
class DummyRul(reader.AbstractReader):
fd: int = 1
window_size: int = 30
max_rul: int = 125
percent_broken = None
percent_fail_runs = None
truncate_val = False
truncate_degraded_only = False

def __init__(self, length):
def __init__(self, length, norm_rul=False):
self.norm_rul = norm_rul
self.max_rul = None if norm_rul else 125
norm = 125 if norm_rul else 1
self.data = {
"dev": (
[np.zeros((length, self.window_size, 5))],
[np.clip(np.arange(length, 0, step=-1), a_min=None, a_max=125)],
[np.clip(np.arange(length, 0, step=-1), a_min=None, a_max=125) / norm],
),
"val": (
[np.zeros((100, self.window_size, 5))],
[np.clip(np.arange(100, 0, step=-1), a_min=None, a_max=125)],
[np.clip(np.arange(100, 0, step=-1), a_min=None, a_max=125) / norm],
),
}

@@ -357,6 +400,11 @@ def cmapss_normal(length):
return RulDataModule(DummyRul(length), 32)


@pytest.fixture
def cmapss_normed(length):
return RulDataModule(DummyRul(length, norm_rul=True), 32)


@pytest.fixture
def cmapss_short():
return RulDataModule(DummyRulShortRuns(), 32)
@@ -537,3 +585,12 @@ def test_compatability_check(self):
core.PairedRulDataset(dms, "dev", 1000, 1)

assert 2 == mock_check_compat.call_count

@pytest.mark.parametrize("degraded_only", [True, False])
def test_degraded_only(self, degraded_only, cmapss_normal, mocker):
spy_load_split = mocker.spy(cmapss_normal, "load_split")
core.PairedRulDataset(
[cmapss_normal], "dev", 1000, 1, degraded_only=degraded_only
)

spy_load_split.assert_called_with("dev", degraded_only=degraded_only)
