feat: use data module for paired dataset (#45)
* feat: use data module for paired dataset

This enables using the feature extractors of the data modules, e.g., for extracting spectra (see the usage sketch below).

* fix: truncated validation
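
A rough sketch of what this change enables (not part of the commit: the spectrum extractor, FD numbers, and keyword arguments below are illustrative assumptions; check the RulDataModule documentation for the exact signature):

import numpy as np
import rul_datasets

# Hypothetical feature extractor: turn each window into its magnitude spectrum.
def spectrum(windows: np.ndarray) -> np.ndarray:
    return np.abs(np.fft.rfft(windows, axis=1))

source = rul_datasets.RulDataModule(
    rul_datasets.CmapssReader(fd=3),
    batch_size=32,
    feature_extractor=spectrum,
    window_size=10,
)
target = rul_datasets.RulDataModule(
    rul_datasets.CmapssReader(fd=1, percent_broken=0.8),
    batch_size=32,
    feature_extractor=spectrum,
    window_size=10,
)

# PairedRulDataset now receives the data modules themselves instead of their
# readers, so pairs are drawn from the extracted features.
paired = rul_datasets.core.PairedRulDataset([source, target], "dev", 1000, 1)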
tilman151 authored Dec 7, 2023
1 parent 0219026 commit 481b5dc
Showing 7 changed files with 95 additions and 88 deletions.
43 changes: 20 additions & 23 deletions rul_datasets/adaption.py
@@ -69,8 +69,8 @@ def __init__(
         self.batch_size = source.batch_size
         self.inductive = inductive
 
-        self.target_truncated = deepcopy(self.target.reader)
-        self.target_truncated.truncate_val = True
+        self.target_truncated = deepcopy(self.target)
+        self.target_truncated.reader.truncate_val = True
 
         self._check_compatibility()
 
@@ -85,7 +85,7 @@ def __init__(
 
     def _check_compatibility(self):
         self.source.check_compatibility(self.target)
-        self.target.reader.check_compatibility(self.target_truncated)
+        self.target.reader.check_compatibility(self.target_truncated.reader)
         if self.source.reader.fd == self.target.reader.fd:
             raise ValueError(
                 f"FD of source and target has to be different for "
@@ -463,58 +463,55 @@ def __init__(
         self.min_distance = min_distance
         self.distance_mode = distance_mode
 
-        self.target_loader = self.target.reader
-        self.source_loader = self.source.reader
-
         self._check_compatibility()
 
         self.save_hyperparameters(
             {
-                "fd_source": self.source_loader.fd,
-                "fd_target": self.target_loader.fd,
+                "fd_source": self.source.reader.fd,
+                "fd_target": self.target.reader.fd,
                 "num_samples": self.num_samples,
                 "batch_size": self.batch_size,
-                "window_size": self.source_loader.window_size,
-                "max_rul": self.source_loader.max_rul,
+                "window_size": self.source.reader.window_size,
+                "max_rul": self.source.reader.max_rul,
                 "min_distance": self.min_distance,
-                "percent_broken": self.target_loader.percent_broken,
-                "percent_fail_runs": self.target_loader.percent_fail_runs,
-                "truncate_target_val": self.target_loader.truncate_val,
+                "percent_broken": self.target.reader.percent_broken,
+                "percent_fail_runs": self.target.reader.percent_fail_runs,
+                "truncate_target_val": self.target.reader.truncate_val,
                 "distance_mode": self.distance_mode,
             }
         )
 
     def _check_compatibility(self):
         self.source.check_compatibility(self.target)
-        if self.source_loader.fd == self.target_loader.fd:
+        if self.source.reader.fd == self.target.reader.fd:
             raise ValueError(
                 f"FD of source and target has to be different for "
-                f"domain adaption, but is {self.source_loader.fd} bot times."
+                f"domain adaption, but is {self.source.reader.fd} both times."
             )
         if (
-            self.target_loader.percent_broken is None
-            or self.target_loader.percent_broken == 1.0
+            self.target.reader.percent_broken is None
+            or self.target.reader.percent_broken == 1.0
         ):
             raise ValueError(
                 "Target data needs a percent_broken smaller than 1 for pre-training."
             )
         if (
-            self.source_loader.percent_broken is not None
-            and self.source_loader.percent_broken < 1.0
+            self.source.reader.percent_broken is not None
+            and self.source.reader.percent_broken < 1.0
        ):
             raise ValueError(
                 "Source data cannot have a percent_broken smaller than 1, "
                 "otherwise it would not be failed, labeled data."
             )
-        if not self.target_loader.truncate_val:
+        if not self.target.reader.truncate_val:
             warnings.warn(
                 "Validation data of unfailed runs is not truncated. "
                 "The validation metrics will not be valid."
             )
 
     def prepare_data(self, *args, **kwargs):
-        self.source_loader.prepare_data()
-        self.target_loader.prepare_data()
+        self.source.reader.prepare_data()
+        self.target.reader.prepare_data()
 
     def setup(self, stage: Optional[str] = None):
         self.source.setup(stage)
@@ -539,7 +536,7 @@ def _get_paired_dataset(self, split: str) -> PairedRulDataset:
         min_distance = 1 if split == "val" else self.min_distance
         num_samples = 50000 if split == "val" else self.num_samples
         paired = PairedRulDataset(
-            [self.source_loader, self.target_loader],
+            [self.source, self.target],
             split,
             num_samples,
             min_distance,
46 changes: 23 additions & 23 deletions rul_datasets/baseline.py
@@ -133,80 +133,80 @@ def __init__(
     ):
         super().__init__()
 
-        self.failed_loader = failed_data_module.reader
-        self.unfailed_loader = unfailed_data_module.reader
+        self.failed = failed_data_module
+        self.unfailed = unfailed_data_module
         self.num_samples = num_samples
         self.batch_size = failed_data_module.batch_size
         self.min_distance = min_distance
         self.distance_mode = distance_mode
-        self.window_size = self.unfailed_loader.window_size
+        self.window_size = self.unfailed.reader.window_size
         self.source = unfailed_data_module
 
         self._check_loaders()
 
         self.save_hyperparameters(
             {
-                "fd_source": self.unfailed_loader.fd,
+                "fd_source": self.unfailed.reader.fd,
                 "num_samples": self.num_samples,
                 "batch_size": self.batch_size,
                 "window_size": self.window_size,
-                "max_rul": self.unfailed_loader.max_rul,
+                "max_rul": self.unfailed.reader.max_rul,
                 "min_distance": self.min_distance,
-                "percent_broken": self.unfailed_loader.percent_broken,
-                "percent_fail_runs": self.failed_loader.percent_fail_runs,
-                "truncate_val": self.unfailed_loader.truncate_val,
+                "percent_broken": self.unfailed.reader.percent_broken,
+                "percent_fail_runs": self.failed.reader.percent_fail_runs,
+                "truncate_val": self.unfailed.reader.truncate_val,
                 "distance_mode": self.distance_mode,
             }
         )
 
     def _check_loaders(self):
-        self.failed_loader.check_compatibility(self.unfailed_loader)
-        if not self.failed_loader.fd == self.unfailed_loader.fd:
+        self.failed.reader.check_compatibility(self.unfailed.reader)
+        if not self.failed.reader.fd == self.unfailed.reader.fd:
             raise ValueError("Failed and unfailed data need to come from the same FD.")
-        if self.failed_loader.percent_fail_runs is None or isinstance(
-            self.failed_loader.percent_fail_runs, float
+        if self.failed.reader.percent_fail_runs is None or isinstance(
+            self.failed.reader.percent_fail_runs, float
         ):
             raise ValueError(
                 "Failed data needs list of failed runs "
                 "for pre-training but uses a float or is None."
             )
-        if self.unfailed_loader.percent_fail_runs is None or isinstance(
-            self.unfailed_loader.percent_fail_runs, float
+        if self.unfailed.reader.percent_fail_runs is None or isinstance(
+            self.unfailed.reader.percent_fail_runs, float
        ):
             raise ValueError(
                 "Unfailed data needs list of failed runs "
                 "for pre-training but uses a float or is None."
             )
-        if set(self.failed_loader.percent_fail_runs).intersection(
-            self.unfailed_loader.percent_fail_runs
+        if set(self.failed.reader.percent_fail_runs).intersection(
+            self.unfailed.reader.percent_fail_runs
        ):
             raise ValueError(
                 "Runs of failed and unfailed data overlap. "
                 "Please use mututally exclusive sets of runs."
             )
         if (
-            self.unfailed_loader.percent_broken is None
-            or self.unfailed_loader.percent_broken == 1.0
+            self.unfailed.reader.percent_broken is None
+            or self.unfailed.reader.percent_broken == 1.0
        ):
             raise ValueError(
                 "Unfailed data needs a percent_broken smaller than 1 for pre-training."
             )
         if (
-            self.failed_loader.percent_broken is not None
-            and self.failed_loader.percent_broken < 1.0
+            self.failed.reader.percent_broken is not None
+            and self.failed.reader.percent_broken < 1.0
        ):
             raise ValueError(
                 "Failed data cannot have a percent_broken smaller than 1, "
                 "otherwise it would not be failed data."
             )
-        if not self.unfailed_loader.truncate_val:
+        if not self.unfailed.reader.truncate_val:
             warnings.warn(
                 "Validation data of unfailed runs is not truncated. "
                 "The validation metrics will not be valid."
             )
 
     def prepare_data(self, *args, **kwargs):
-        self.unfailed_loader.prepare_data()
+        self.unfailed.reader.prepare_data()
 
     def setup(self, stage: Optional[str] = None):
         self.source.setup(stage)
@@ -229,7 +229,7 @@ def _get_paired_dataset(self, split: str) -> PairedRulDataset:
         min_distance = 1 if split == "val" else self.min_distance
         num_samples = 25000 if split == "val" else self.num_samples
         paired = PairedRulDataset(
-            [self.unfailed_loader, self.failed_loader],
+            [self.unfailed, self.failed],
             split,
             num_samples,
             min_distance,
35 changes: 19 additions & 16 deletions rul_datasets/core.py
@@ -377,7 +377,7 @@ class PairedRulDataset(IterableDataset):
 
     def __init__(
         self,
-        readers: List[AbstractReader],
+        dms: List[RulDataModule],
         split: str,
         num_samples: int,
         min_distance: int,
@@ -386,19 +386,19 @@ def __init__(
     ):
         super().__init__()
 
-        self.readers = readers
+        self.dms = dms
         self.split = split
         self.min_distance = min_distance
         self.num_samples = num_samples
         self.deterministic = deterministic
         self.mode = mode
 
-        for reader in self.readers:
-            reader.check_compatibility(self.readers[0])
+        for dm in self.dms:
+            dm.check_compatibility(self.dms[0])
 
         self._run_domain_idx: np.ndarray
-        self._features: List[np.ndarray]
-        self._labels: List[np.ndarray]
+        self._features: List[torch.Tensor]
+        self._labels: List[torch.Tensor]
         self._prepare_datasets()
 
         self._max_rul = self._get_max_rul()
@@ -412,22 +412,25 @@ def __init__(
             self._get_pair_func = self._get_labeled_pair_idx
 
     def _get_max_rul(self):
-        max_ruls = [reader.max_rul for reader in self.readers]
-        if any(m is None for m in max_ruls):
+        max_ruls = [dm.reader.max_rul for dm in self.dms]
+        if all(m is None for m in max_ruls):
+            max_rul = 1e10
+        elif any(m is None for m in max_ruls):
             raise ValueError(
-                "PairedRulDataset needs a set max_rul for all readers "
-                "but at least one of them has is None."
+                "PairedRulDataset needs a set max_rul for all or none of the readers "
+                "but at least one and not all of them has None."
             )
-        max_rul = max(max_ruls)
+        else:
+            max_rul = max(max_ruls)
 
         return max_rul
 
     def _prepare_datasets(self):
         run_domain_idx = []
         features = []
         labels = []
-        for domain_idx, reader in enumerate(self.readers):
-            run_features, run_labels = reader.load_split(self.split)
+        for domain_idx, dm in enumerate(self.dms):
+            run_features, run_labels = dm.load_split(self.split)
             for feat, lab in zip(run_features, run_labels):
                 if len(feat) > self.min_distance:
                     run_domain_idx.append(domain_idx)
@@ -530,14 +533,14 @@ def _get_labeled_pair_idx(self) -> Tuple[int, int, int, int, int]:
 
     def _build_pair(
         self,
-        run: np.ndarray,
+        run: torch.Tensor,
         anchor_idx: int,
         query_idx: int,
         distance: int,
         domain_label: int,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        anchors = utils.feature_to_tensor(run[anchor_idx], torch.float)
-        queries = utils.feature_to_tensor(run[query_idx], torch.float)
+        anchors = run[anchor_idx]
+        queries = run[query_idx]
         domain_tensor = torch.tensor(domain_label, dtype=torch.float)
         distances = torch.tensor(distance, dtype=torch.float) / self._max_rul
         distances = torch.clamp_max(distances, max=1)  # max distance is max_rul
2 changes: 1 addition & 1 deletion tests/test_adaption.py
@@ -131,7 +131,7 @@ def test_test_dataloader(self):
 
     def test_truncated_loader(self):
         self.assertIsNot(self.dataset.target.reader, self.dataset.target_truncated)
-        self.assertTrue(self.dataset.target_truncated.truncate_val)
+        self.assertTrue(self.dataset.target_truncated.reader.truncate_val)
 
     def test_hparams(self):
         expected_hparams = {
4 changes: 2 additions & 2 deletions tests/test_baseline.py
@@ -128,8 +128,8 @@ def test_both_source_datasets_used(self):
         )
         for split in ["dev", "val"]:
             with self.subTest(split):
-                num_broken_runs = len(dataset.unfailed_loader.load_split(split)[0])
-                num_fail_runs = len(dataset.failed_loader.load_split(split)[0])
+                num_broken_runs = len(dataset.unfailed.reader.load_split(split)[0])
+                num_fail_runs = len(dataset.failed.reader.load_split(split)[0])
                 paired_dataset = dataset._get_paired_dataset(split)
                 self.assertEqual(
                     num_broken_runs + num_fail_runs, len(paired_dataset._features)