Skip to content

Commit

Permalink
Raise an exception if check_val_every_n_epoch is not an integer (#6411)
Browse files Browse the repository at this point in the history
* raise an exception if check_val_every_n_epoch is not an integer

* remove unused object

* add type hints

* add return type

* update exception message

* update exception message
  • Loading branch information
kaushikb11 authored Mar 10, 2021
1 parent 615b2f7 commit 74d79e7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,17 @@ class DataConnector(object):
def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node):
def on_trainer_init(
self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool
) -> None:
self.trainer.datamodule = None
self.trainer.prepare_data_per_node = prepare_data_per_node

if not isinstance(check_val_every_n_epoch, int):
raise MisconfigurationException(
f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}"
)

self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
self.trainer._is_data_prepared = False
Expand Down
10 changes: 10 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1828,3 +1828,13 @@ def compare_optimizers():
trainer.max_epochs = 2 # simulate multiple fit calls
trainer.fit(model)
compare_optimizers()


def test_check_val_every_n_epoch_exception(tmpdir):

with pytest.raises(MisconfigurationException, match="should be an integer."):
Trainer(
default_root_dir=tmpdir,
max_epochs=1,
check_val_every_n_epoch=1.2,
)

0 comments on commit 74d79e7

Please sign in to comment.