From 4e4bd6bf41f0b266ad49a36dfe54183f2b7d5e89 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Fri, 8 Dec 2023 09:40:14 +0100 Subject: [PATCH] fix: normalize pair distances only if max_rul is set (#47) * fix: normalize distances only if max_rul is set * fix: linting issues --- rul_datasets/core.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/rul_datasets/core.py b/rul_datasets/core.py index ca33cd9..b6c5ce0 100644 --- a/rul_datasets/core.py +++ b/rul_datasets/core.py @@ -1,7 +1,7 @@ """Basic data modules for experiments involving only a single subset of any RUL dataset. """ -from typing import Dict, List, Optional, Tuple, Any, Callable +from typing import Dict, List, Optional, Tuple, Any, Callable, cast, Union import numpy as np import pytorch_lightning as pl @@ -411,17 +411,17 @@ def __init__( elif mode == "labeled": self._get_pair_func = self._get_labeled_pair_idx - def _get_max_rul(self): + def _get_max_rul(self) -> Optional[int]: max_ruls = [dm.reader.max_rul for dm in self.dms] if all(m is None for m in max_ruls): - max_rul = 1e10 + max_rul = None elif any(m is None for m in max_ruls): raise ValueError( "PairedRulDataset needs a set max_rul for all or none of the readers " "but at least one and not all of them has None." ) else: - max_rul = max(max_ruls) + max_rul = max(cast(List[int], max_ruls)) return max_rul @@ -470,7 +470,7 @@ def __next__(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tens else: raise StopIteration - def _get_pair_idx(self) -> Tuple[int, int, int, int, int]: + def _get_pair_idx(self) -> Tuple[int, int, int, Union[int, float], int]: chosen_run_idx = self._rng.integers(0, len(self._features)) domain_label = self._run_domain_idx[chosen_run_idx] chosen_run = self._features[chosen_run_idx] @@ -480,7 +480,7 @@ def _get_pair_idx(self) -> Tuple[int, int, int, int, int]: low=0, high=run_length - self.min_distance, ) - end_idx = min(run_length, anchor_idx + self._max_rul) + end_idx = min(run_length, anchor_idx + (self._max_rul or 999999)) query_idx = self._rng.integers( low=anchor_idx + self.min_distance, high=end_idx, @@ -489,7 +489,7 @@ def _get_pair_idx(self) -> Tuple[int, int, int, int, int]: return chosen_run_idx, anchor_idx, query_idx, distance, domain_label - def _get_pair_idx_piecewise(self) -> Tuple[int, int, int, int, int]: + def _get_pair_idx_piecewise(self) -> Tuple[int, int, int, Union[int, float], int]: chosen_run_idx = self._rng.integers(0, len(self._features)) domain_label = self._run_domain_idx[chosen_run_idx] chosen_run = self._features[chosen_run_idx] @@ -511,7 +511,7 @@ def _get_pair_idx_piecewise(self) -> Tuple[int, int, int, int, int]: return chosen_run_idx, anchor_idx, query_idx, distance, domain_label - def _get_labeled_pair_idx(self) -> Tuple[int, int, int, int, int]: + def _get_labeled_pair_idx(self) -> Tuple[int, int, int, Union[int, float], int]: chosen_run_idx = self._rng.integers(0, len(self._features)) domain_label = self._run_domain_idx[chosen_run_idx] chosen_run = self._features[chosen_run_idx] @@ -527,7 +527,7 @@ def _get_labeled_pair_idx(self) -> Tuple[int, int, int, int, int]: high=run_length, ) # RUL label difference is negative time step difference - distance = int(chosen_labels[anchor_idx] - chosen_labels[query_idx]) + distance = (chosen_labels[anchor_idx] - chosen_labels[query_idx]).item() return chosen_run_idx, anchor_idx, query_idx, distance, domain_label @@ -536,13 +536,15 @@ def _build_pair( run: torch.Tensor, anchor_idx: int, query_idx: int, - distance: int, + distance: Union[int, float], domain_label: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: anchors = run[anchor_idx] queries = run[query_idx] domain_tensor = torch.tensor(domain_label, dtype=torch.float) - distances = torch.tensor(distance, dtype=torch.float) / self._max_rul - distances = torch.clamp_max(distances, max=1) # max distance is max_rul + distances = torch.tensor(distance, dtype=torch.float) + if self._max_rul is not None: # normalize only if max_rul is set + distances /= self._max_rul + distances = torch.clamp_max(distances, max=1) return anchors, queries, distances, domain_tensor