Skip to content

Commit

Permalink
Compute transform parameters exclusively to train split
Browse files Browse the repository at this point in the history
  • Loading branch information
mauricekraus authored Oct 13, 2023
1 parent a923fda commit 7543f27
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchchronos/datasets/ucr_uea.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class UCRUEADataset(Dataset):
ds_name (str): Name of the dataset.
path (Path): Path to the dataset storage.
split (DatasetSplit, optional): Dataset split to use. Defaults to None.
transform (Transform, optional): Transform to apply to the data. Defaults to None.
transform (Transform, optional): Transform to apply to the data. Defaults to None. Parameters are only computed for the train split.
torchchronos_cache (bool, optional): Whether to cache the data in torchchronos format. Defaults to True.
This allows for much faster loading of the data once it has been loaded once.
It is recommended to leave this as True unless space is an issue.
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
np.savez(tc_cache_path, xs=self.xs, ys=self.ys)

self.xs = torch.tensor(self.xs, dtype=torch.float32).transpose(1, 2)
if self.transform is not None:
if self.transform is not None and split == DatasetSplit.TRAIN: # we base our transform parameters on the train set
self.transform = self.transform.fit(self.xs)

# It doesnt matter if test or train, but test is usually smaller
Expand Down

0 comments on commit 7543f27

Please sign in to comment.