Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix percent_checks #649

Merged
merged 3 commits into from
Jan 5, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ def __init__(self):
self.shown_warnings = None
self.val_check_interval = None

def _percent_range_check(self, name):
    """
    Validate that the attribute called ``name`` holds a fraction in [0.0, 1.0].

    The attribute is looked up on ``self`` by its name. Nothing is returned
    when the value lies inside the closed interval.

    :param name: name of the attribute on ``self`` to validate
    :raises ValueError: if the value falls outside [0.0, 1.0]; for
        ``val_check_interval`` the message additionally hints at
        ``val_percent_check``
    :return:
    """
    current = getattr(self, name)

    # Assemble the message up front, before the range comparison runs.
    parts = [f"`{name}` must lie in the range [0.0, 1.0], but got {current:.3f}."]
    if name == "val_check_interval":
        parts.append(" If you want to disable validation set `val_percent_check` to 0.0 instead.")
    error_text = "".join(parts)

    # Chained comparison: anything that is not provably inside the interval
    # (including a float NaN) falls through to the raise below.
    if 0. <= current <= 1.:
        return

    raise ValueError(error_text)

def init_train_dataloader(self, model):
"""
Dataloaders are provided by the model
Expand All @@ -48,6 +57,8 @@ def init_train_dataloader(self, model):
if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset):
self.num_training_batches = float('inf')
else:
self._percent_range_check('train_percent_check')

self.num_training_batches = len(self.get_train_dataloader())
self.num_training_batches = int(self.num_training_batches * self.train_percent_check)

Expand All @@ -56,7 +67,14 @@ def init_train_dataloader(self, model):
# otherwise, it checks in [0, 1.0] % range of a training epoch
if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
if self.val_check_batch > self.num_training_batches:
raise ValueError(
f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please note: PEP 8 prefers having the whitespace at the beginning of the next line rather than at the end of the previous line.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kuynzereb have you checked this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I haven't fixed it yet, and the PR has already been merged. But I got your point. From now on I will not use trailing spaces in multiline strings :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, if you feel like you have nothing to do, you may check all the strings in the package... ;]

f"to the number of the training batches ({self.num_training_batches}). "
f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
else:
self._percent_range_check('val_check_interval')

self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)

Expand Down Expand Up @@ -89,13 +107,15 @@ def init_val_dataloader(self, model):
:return:
"""
self.get_val_dataloaders = model.val_dataloader
self.num_val_batches = 0

# determine number of validation batches
# val datasets could be none, 1 or 2+
if self.get_val_dataloaders() is not None:
self._percent_range_check('val_percent_check')

self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
self.num_val_batches = max(1, self.num_val_batches)

on_ddp = self.use_ddp or self.use_ddp2
if on_ddp and self.get_val_dataloaders() is not None:
Expand Down Expand Up @@ -134,10 +154,11 @@ def init_test_dataloader(self, model):

# determine number of test batches
if self.get_test_dataloaders() is not None:
self._percent_range_check('test_percent_check')

len_sum = sum(len(dataloader) for dataloader in self.get_test_dataloaders())
self.num_test_batches = len_sum
self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
self.num_test_batches = max(1, self.num_test_batches)

on_ddp = self.use_ddp or self.use_ddp2
if on_ddp and self.get_test_dataloaders() is not None:
Expand Down Expand Up @@ -208,6 +229,10 @@ def determine_data_use_amount(self, train_percent_check, val_percent_check,
self.val_percent_check = val_percent_check
self.test_percent_check = test_percent_check
if overfit_pct > 0:
if overfit_pct > 1:
raise ValueError(f"`overfit_pct` must be not greater than 1.0, but got "
f"{overfit_pct:.3f}.")

self.train_percent_check = overfit_pct
self.val_percent_check = overfit_pct
self.test_percent_check = overfit_pct