From 16c9b03be05436fb64c84e09b507332cb05fca43 Mon Sep 17 00:00:00 2001
From: Vadim Bereznyuk
Date: Wed, 25 Dec 2019 17:22:10 +0300
Subject: [PATCH 1/3] fix percent_checks

---
 pytorch_lightning/trainer/data_loading.py | 30 ++++++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 84413697948d5..434b96a340e7b 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -48,15 +48,30 @@ def init_train_dataloader(self, model):
         if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset):
             self.num_training_batches = float('inf')
         else:
+            if not 0. <= self.train_percent_check <= 1.:
+                raise ValueError(f"train_percent_check must lie in the range [0.0, 1.0], but got "
+                                 f"{self.train_percent_check:.3f}.")
+
             self.num_training_batches = len(self.get_train_dataloader())
             self.num_training_batches = int(self.num_training_batches * self.train_percent_check)
+            self.num_training_batches = max(1, self.num_training_batches)

         # determine when to check validation
         # if int passed in, val checks that often
         # otherwise, it checks in [0, 1.0] % range of a training epoch
         if isinstance(self.val_check_interval, int):
             self.val_check_batch = self.val_check_interval
+            if self.val_check_batch > self.num_training_batches:
+                raise ValueError(
+                    f"val_check_interval ({self.val_check_interval}) must be less than or equal to "
+                    f"the number of training batches ({self.num_training_batches}). "
+                    f"If you want to disable validation, set val_percent_check to 0.0 instead.")
         else:
+            if not 0. <= self.val_check_interval <= 1.:
+                raise ValueError(f"val_check_interval must lie in the range [0.0, 1.0], but got "
+                                 f"{self.val_check_interval:.3f}. If you want to disable "
+                                 f"validation, set val_percent_check to 0.0 instead.")
+
             self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
             self.val_check_batch = max(1, self.val_check_batch)
@@ -89,13 +104,18 @@ def init_val_dataloader(self, model):
         :return:
         """
         self.get_val_dataloaders = model.val_dataloader
+        self.num_val_batches = 0

         # determine number of validation batches
         # val datasets could be none, 1 or 2+
         if self.get_val_dataloaders() is not None:
+            if not 0. <= self.val_percent_check <= 1.:
+                raise ValueError(f"val_percent_check must lie in the range [0.0, 1.0], but got "
+                                 f"{self.val_percent_check:.3f}. If you want to disable "
+                                 f"validation, set it to 0.0.")
+
             self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
             self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
-            self.num_val_batches = max(1, self.num_val_batches)

         on_ddp = self.use_ddp or self.use_ddp2
         if on_ddp and self.get_val_dataloaders() is not None:
@@ -134,6 +154,10 @@ def init_test_dataloader(self, model):

         # determine number of test batches
         if self.get_test_dataloaders() is not None:
+            if not 0. <= self.test_percent_check <= 1.:
+                raise ValueError(f"test_percent_check must lie in the range [0.0, 1.0], but got "
+                                 f"{self.test_percent_check:.3f}.")
+
             len_sum = sum(len(dataloader) for dataloader in self.get_test_dataloaders())
             self.num_test_batches = len_sum
             self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
@@ -208,6 +232,10 @@ def determine_data_use_amount(self, train_percent_check, val_percent_check,
         self.val_percent_check = val_percent_check
         self.test_percent_check = test_percent_check
         if overfit_pct > 0:
+            if overfit_pct > 1:
+                raise ValueError(f"overfit_pct must not be greater than 1.0, but got "
+                                 f"{overfit_pct:.3f}.")
+
             self.train_percent_check = overfit_pct
             self.val_percent_check = overfit_pct
             self.test_percent_check = overfit_pct
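A note on the logic PATCH 1/3 guards: `val_check_interval` is interpreted two ways, an int meaning "run validation every N training batches" and a float meaning a fraction of the training epoch. A standalone sketch of that dispatch with made-up numbers; `resolve_val_check_batch` is a hypothetical helper for illustration, not the Trainer code itself:

# Hypothetical helper mirroring the int-vs-float dispatch in the hunk above.
num_training_batches = 100  # made-up epoch length

def resolve_val_check_batch(val_check_interval):
    if isinstance(val_check_interval, int):
        # int: validate every N batches; N cannot exceed the epoch length
        if val_check_interval > num_training_batches:
            raise ValueError("val_check_interval must be <= the number of training batches")
        return val_check_interval
    # float: a fraction of the epoch, so it must lie in [0.0, 1.0]
    if not 0. <= val_check_interval <= 1.:
        raise ValueError("val_check_interval must lie in the range [0.0, 1.0]")
    return max(1, int(num_training_batches * val_check_interval))

print(resolve_val_check_batch(25))    # 25 -- every 25 batches
print(resolve_val_check_batch(0.25))  # 25 -- a quarter of the epoch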
From 2b11894f8774f5931ed4105e1d0d88ba4e36ca03 Mon Sep 17 00:00:00 2001
From: Vadim Bereznyuk
Date: Fri, 27 Dec 2019 11:41:36 +0300
Subject: [PATCH 2/3] Added _percent_range_check

---
 pytorch_lightning/trainer/data_loading.py | 35 +++++++++++------------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 434b96a340e7b..2970e5b715327 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -36,6 +36,15 @@ def __init__(self):
         self.shown_warnings = None
         self.val_check_interval = None

+    def _percent_range_check(self, name):
+        value = getattr(self, name)
+        msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
+        if name == "val_check_interval":
+            msg += " If you want to disable validation, set `val_percent_check` to 0.0 instead."
+
+        if not 0. <= value <= 1.:
+            raise ValueError(msg)
+
     def init_train_dataloader(self, model):
         """
         Dataloaders are provided by the model
@@ -48,9 +57,7 @@ def init_train_dataloader(self, model):
         if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset):
             self.num_training_batches = float('inf')
         else:
-            if not 0. <= self.train_percent_check <= 1.:
-                raise ValueError(f"train_percent_check must lie in the range [0.0, 1.0], but got "
-                                 f"{self.train_percent_check:.3f}.")
+            self._percent_range_check('train_percent_check')

             self.num_training_batches = len(self.get_train_dataloader())
             self.num_training_batches = int(self.num_training_batches * self.train_percent_check)
@@ -63,14 +70,11 @@ def init_train_dataloader(self, model):
             self.val_check_batch = self.val_check_interval
             if self.val_check_batch > self.num_training_batches:
                 raise ValueError(
-                    f"val_check_interval ({self.val_check_interval}) must be less than or equal to "
-                    f"the number of training batches ({self.num_training_batches}). "
-                    f"If you want to disable validation, set val_percent_check to 0.0 instead.")
+                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
+                    f"to the number of training batches ({self.num_training_batches}). "
+                    f"If you want to disable validation, set `val_percent_check` to 0.0 instead.")
         else:
-            if not 0. <= self.val_check_interval <= 1.:
-                raise ValueError(f"val_check_interval must lie in the range [0.0, 1.0], but got "
-                                 f"{self.val_check_interval:.3f}. If you want to disable "
-                                 f"validation, set val_percent_check to 0.0 instead.")
+            self._percent_range_check('val_check_interval')

             self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
             self.val_check_batch = max(1, self.val_check_batch)
@@ -109,10 +113,7 @@ def init_val_dataloader(self, model):
         # determine number of validation batches
         # val datasets could be none, 1 or 2+
         if self.get_val_dataloaders() is not None:
-            if not 0. <= self.val_percent_check <= 1.:
-                raise ValueError(f"val_percent_check must lie in the range [0.0, 1.0], but got "
-                                 f"{self.val_percent_check:.3f}. If you want to disable "
-                                 f"validation, set it to 0.0.")
+            self._percent_range_check('val_percent_check')

             self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
             self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
@@ -154,9 +155,7 @@ def init_test_dataloader(self, model):
         # determine number of test batches
         if self.get_test_dataloaders() is not None:
-            if not 0. <= self.test_percent_check <= 1.:
-                raise ValueError(f"test_percent_check must lie in the range [0.0, 1.0], but got "
-                                 f"{self.test_percent_check:.3f}.")
+            self._percent_range_check('test_percent_check')

             len_sum = sum(len(dataloader) for dataloader in self.get_test_dataloaders())
             self.num_test_batches = len_sum
             self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
@@ -233,7 +232,7 @@ def determine_data_use_amount(self, train_percent_check, val_percent_check,
         self.test_percent_check = test_percent_check
         if overfit_pct > 0:
             if overfit_pct > 1:
-                raise ValueError(f"overfit_pct must not be greater than 1.0, but got "
+                raise ValueError(f"`overfit_pct` must not be greater than 1.0, but got "
                                  f"{overfit_pct:.3f}.")

             self.train_percent_check = overfit_pct
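The refactor in PATCH 2/3 is easiest to see in isolation. Below is a minimal sketch of the new helper's behaviour; `Flags` is a hypothetical stand-in for the Trainer, and the method body is copied from the hunk above:

# Minimal sketch; `Flags` is hypothetical, the check mirrors _percent_range_check.
class Flags:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def _percent_range_check(self, name):
        value = getattr(self, name)
        msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
        if name == "val_check_interval":
            msg += " If you want to disable validation, set `val_percent_check` to 0.0 instead."

        if not 0. <= value <= 1.:
            raise ValueError(msg)

Flags(test_percent_check=0.25)._percent_range_check('test_percent_check')  # in range: silent
try:
    Flags(val_check_interval=2.0)._percent_range_check('val_check_interval')
except ValueError as err:
    print(err)  # `val_check_interval` must lie in the range [0.0, 1.0], but got 2.000. ...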
From ed75e7a648a4322c79eb0c094534ffa486c3e96d Mon Sep 17 00:00:00 2001
From: Vadim Bereznyuk
Date: Fri, 27 Dec 2019 12:30:47 +0300
Subject: [PATCH 3/3] remove max

---
 pytorch_lightning/trainer/data_loading.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 2970e5b715327..33bd99fcabc94 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -61,7 +61,6 @@ def init_train_dataloader(self, model):

             self.num_training_batches = len(self.get_train_dataloader())
             self.num_training_batches = int(self.num_training_batches * self.train_percent_check)
-            self.num_training_batches = max(1, self.num_training_batches)

         # determine when to check validation
         # if int passed in, val checks that often
@@ -160,7 +159,6 @@ def init_test_dataloader(self, model):
             len_sum = sum(len(dataloader) for dataloader in self.get_test_dataloaders())
             self.num_test_batches = len_sum
             self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
-            self.num_test_batches = max(1, self.num_test_batches)

         on_ddp = self.use_ddp or self.use_ddp2
         if on_ddp and self.get_test_dataloaders() is not None:
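PATCH 3/3 extends to the train and test counts what PATCH 1/3 already did for validation: with the max(1, ...) clamp gone, int() truncation can yield zero batches, so a percentage of 0.0 (or one small enough to round down) disables the loop entirely, consistent with the error messages' advice to set `val_percent_check` to 0.0. Illustrative arithmetic with a made-up epoch length:

# Made-up numbers: batch counts after PATCH 3/3, with no max(1, ...) clamp.
num_batches = 100
for pct in (0.0, 0.003, 0.5, 1.0):
    print(pct, int(num_batches * pct))  # 0.0 -> 0, 0.003 -> 0, 0.5 -> 50, 1.0 -> 100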