Skip to content

Commit

Permalink
[Test] Add extra test for val_check_interval in distributed scenario (#…
Browse files Browse the repository at this point in the history
…7863)

* add extra test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add computation

* Update docs/source/common/trainer.rst

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update docs/source/common/trainer.rst

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update tests/trainer/test_dataloaders.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* use tmpdir

* update on comments

* update

* Update tests/callbacks/test_progress_bar.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
4 people authored and Borda committed Jun 8, 2021
1 parent 1e0b4f6 commit 3e94b05
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
18 changes: 18 additions & 0 deletions docs/source/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,24 @@ Can specify as float or int.
trainer = Trainer(val_check_interval=1000)


.. code-block::
# Here is the computation to estimate the total number of batches seen within an epoch.
# Find the total number of train batches
total_train_batches = total_train_samples // (train_batch_size * world_size)
# Compute how many times we will call validation during the training loop
val_check_batch = max(1, int(total_train_batches * val_check_interval))
val_checks_per_epoch = total_train_batches / val_check_batch
# Find the total number of validation batches
total_val_batches = total_val_samples // (val_batch_size * world_size)
# Total number of batches run
total_fit_batches = total_train_batches + total_val_batches
weights_save_path
^^^^^^^^^^^^^^^^^

Expand Down
48 changes: 47 additions & 1 deletion tests/callbacks/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@

import pytest
import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.progress import tqdm
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf


@pytest.mark.parametrize(
Expand Down Expand Up @@ -533,3 +535,47 @@ def test_progress_bar_can_be_pickled():
pickle.dumps(bar)
trainer.predict(model)
pickle.dumps(bar)


@RunIf(min_gpus=2, special=True)
@pytest.mark.parametrize([
"total_train_samples",
"train_batch_size",
"total_val_samples",
"val_batch_size",
"val_check_interval",
], [
(8, 4, 2, 1, 0.2),
(8, 4, 2, 1, 0.5),
])
def test_progress_bar_max_val_check_interval(
total_train_samples, train_batch_size, total_val_samples, val_batch_size, val_check_interval, tmpdir
):

world_size = 2

train_data = DataLoader(RandomDataset(32, total_train_samples), batch_size=train_batch_size)
val_data = DataLoader(RandomDataset(32, total_val_samples), batch_size=val_batch_size)

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
max_epochs=1,
weights_summary=None,
val_check_interval=val_check_interval,
gpus=world_size,
accelerator="ddp",
)
trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)

total_train_batches = total_train_samples // (train_batch_size * world_size)
val_check_batch = max(1, int(total_train_batches * val_check_interval))
assert trainer.val_check_batch == val_check_batch
val_checks_per_epoch = total_train_batches / val_check_batch
total_val_batches = total_val_samples // (val_batch_size * world_size)
assert trainer.progress_bar_callback.total_train_batches == total_train_batches
assert trainer.progress_bar_callback.total_val_batches == total_val_batches
total_val_batches = total_val_batches * val_checks_per_epoch
if trainer.is_global_zero:
assert trainer.progress_bar_callback.main_progress_bar.total == total_train_batches + total_val_batches

0 comments on commit 3e94b05

Please sign in to comment.