Skip to content

Commit

Permalink
Update type hints for multiple dataloaders in .fit() and .test() (#1723)
Browse files Browse the repository at this point in the history
* update typehints

* change log
  • Loading branch information
Adrian Wälchli authored May 4, 2020
1 parent 0cd5e64 commit d28b145
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added transfer learning example (for a binary classification task in computer vision) ([#1564](https://github.com/PyTorchLightning/pytorch-lightning/pull/1564))

- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)).

### Changed

- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
Expand Down
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def fit(
self,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[DataLoader] = None
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
):
r"""
Runs the full optimization routine.
Expand Down Expand Up @@ -913,7 +913,11 @@ def run_pretrain_routine(self, model: LightningModule):
# CORE TRAINING LOOP
self.train()

def test(self, model: Optional[LightningModule] = None, test_dataloaders: Optional[DataLoader] = None):
def test(
self,
model: Optional[LightningModule] = None,
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
):
r"""
Separates from fit to make sure you never run on your test set until you want to.
Expand Down
6 changes: 3 additions & 3 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class CurrentTestModel(
trainer.fit(model)
trainer.test()

# verify there are 2 val loaders
# verify there are 2 test loaders
assert len(trainer.test_dataloaders) == 2, \
'Multiple test_dataloaders not initiated properly'

Expand All @@ -125,7 +125,7 @@ class CurrentTestModel(
trainer.test()


def test_train_dataloaders_passed_to_fit(tmpdir):
def test_train_dataloader_passed_to_fit(tmpdir):
"""Verify that train dataloader can be passed to fit """

class CurrentTestModel(LightTrainDataloader, TestModelBase):
Expand Down Expand Up @@ -175,7 +175,7 @@ class CurrentTestModel(


def test_all_dataloaders_passed_to_fit(tmpdir):
"""Verify train, val & test dataloader can be passed to fit """
"""Verify train, val & test dataloader(s) can be passed to fit and test method"""

class CurrentTestModel(
LightTrainDataloader,
Expand Down

0 comments on commit d28b145

Please sign in to comment.