Skip to content

Commit

Permalink
Fix percent_checks (#649)
Browse files Browse the repository at this point in the history
* fix percent_checks

* Added _percent_range_check

* remove max
  • Loading branch information
kuynzereb authored and williamFalcon committed Jan 5, 2020
1 parent 9ac91ad commit 7824b5c
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ def __init__(self):
self.shown_warnings = None
self.val_check_interval = None

def _percent_range_check(self, name):
value = getattr(self, name)
msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
if name == "val_check_interval":
msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."

if not 0. <= value <= 1.:
raise ValueError(msg)

def init_train_dataloader(self, model):
"""
Dataloaders are provided by the model
Expand All @@ -48,6 +57,8 @@ def init_train_dataloader(self, model):
if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset):
self.num_training_batches = float('inf')
else:
self._percent_range_check('train_percent_check')

self.num_training_batches = len(self.get_train_dataloader())
self.num_training_batches = int(self.num_training_batches * self.train_percent_check)

Expand All @@ -56,7 +67,14 @@ def init_train_dataloader(self, model):
# otherwise, it checks in [0, 1.0] % range of a training epoch
if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
if self.val_check_batch > self.num_training_batches:
raise ValueError(
f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
f"to the number of the training batches ({self.num_training_batches}). "
f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
else:
self._percent_range_check('val_check_interval')

self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)

Expand Down Expand Up @@ -89,13 +107,15 @@ def init_val_dataloader(self, model):
:return:
"""
self.get_val_dataloaders = model.val_dataloader
self.num_val_batches = 0

# determine number of validation batches
# val datasets could be none, 1 or 2+
if self.get_val_dataloaders() is not None:
self._percent_range_check('val_percent_check')

self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
self.num_val_batches = max(1, self.num_val_batches)

on_ddp = self.use_ddp or self.use_ddp2
if on_ddp and self.get_val_dataloaders() is not None:
Expand Down Expand Up @@ -134,10 +154,11 @@ def init_test_dataloader(self, model):

# determine number of test batches
if self.get_test_dataloaders() is not None:
self._percent_range_check('test_percent_check')

len_sum = sum(len(dataloader) for dataloader in self.get_test_dataloaders())
self.num_test_batches = len_sum
self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
self.num_test_batches = max(1, self.num_test_batches)

on_ddp = self.use_ddp or self.use_ddp2
if on_ddp and self.get_test_dataloaders() is not None:
Expand Down Expand Up @@ -208,6 +229,10 @@ def determine_data_use_amount(self, train_percent_check, val_percent_check,
self.val_percent_check = val_percent_check
self.test_percent_check = test_percent_check
if overfit_pct > 0:
if overfit_pct > 1:
raise ValueError(f"`overfit_pct` must be not greater than 1.0, but got "
f"{overfit_pct:.3f}.")

self.train_percent_check = overfit_pct
self.val_percent_check = overfit_pct
self.test_percent_check = overfit_pct

0 comments on commit 7824b5c

Please sign in to comment.