fix incomplete progress bar when refresh_rate > num batches #4577

Merged 7 commits on Nov 23, 2020

CHANGELOG.md (1 change: 1 addition & 0 deletions)

@@ -100,6 +100,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed checkpoint hparams dict casting when omegaconf is available ([#4770](https://github.com/PyTorchLightning/pytorch-lightning/pull/4770))
 
+- Fixed incomplete progress bars when total batches not divisible by refresh rate ([#4577](https://github.com/PyTorchLightning/pytorch-lightning/pull/4577))
 
 ## [1.0.7] - 2020-11-17
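
The changelog line above refers to the undershoot in the old update rule: a bar only advanced when the batch index was divisible by the refresh rate, so the final partial chunk (or, when refresh_rate exceeds the number of batches, any progress at all) was never drawn. A minimal sketch of that pre-fix rule, using a hypothetical helper `old_deltas` that is not part of the PR:

```python
# Hypothetical sketch of the pre-fix update rule for a single bar:
# an update of `refresh_rate` fires only when batch_idx % refresh_rate == 0.
def old_deltas(total_batches, refresh_rate):
    return [refresh_rate for batch_idx in range(1, total_batches + 1)
            if batch_idx % refresh_rate == 0]

assert old_deltas(5, 3) == [3]      # bar stops at 3/5 -> looks incomplete
assert old_deltas(2, 3) == []       # refresh_rate > num batches: bar never moves
assert old_deltas(6, 3) == [3, 3]   # divisible case was already correct
```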

pytorch_lightning/callbacks/progress.py (28 changes: 21 additions & 7 deletions)

@@ -334,21 +334,22 @@ def on_epoch_start(self, trainer, pl_module):
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0:
-            self.main_progress_bar.update(self.refresh_rate)
+        if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
+            self._update_bar(self.main_progress_bar)
             self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
 
     def on_validation_start(self, trainer, pl_module):
         super().on_validation_start(trainer, pl_module)
         if not trainer.running_sanity_check:
+            self._update_bar(self.main_progress_bar)  # fill up remaining
             self.val_progress_bar = self.init_validation_tqdm()
             self.val_progress_bar.total = convert_inf(self.total_val_batches)
 
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0:
-            self.val_progress_bar.update(self.refresh_rate)
-            self.main_progress_bar.update(self.refresh_rate)
+        if self._should_update(self.val_batch_idx, self.total_val_batches):
+            self._update_bar(self.val_progress_bar)
+            self._update_bar(self.main_progress_bar)
 
     def on_validation_end(self, trainer, pl_module):
         super().on_validation_end(trainer, pl_module)

@@ -366,13 +367,26 @@ def on_test_start(self, trainer, pl_module):
 
     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0:
-            self.test_progress_bar.update(self.refresh_rate)
+        if self._should_update(self.test_batch_idx, self.total_test_batches):
+            self._update_bar(self.test_progress_bar)
 
     def on_test_end(self, trainer, pl_module):
         super().on_test_end(trainer, pl_module)
         self.test_progress_bar.close()
+
+    def _should_update(self, current, total):
+        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
+
+    def _update_bar(self, bar):
+        """ Updates the bar by the refresh rate without overshooting. """
+        if bar.total is not None:
+            delta = min(self.refresh_rate, bar.total - bar.n)
+        else:
+            # infinite / unknown size
+            delta = self.refresh_rate
+        if delta > 0:
+            bar.update(delta)
 
 
 def convert_inf(x):
     """ The tqdm doesn't support inf values. We have to convert it to None. """
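
As a standalone illustration of the helpers added above, the sketch below mirrors `_should_update` and `_update_bar` for a single bar with a known total, assuming the bar is enabled. `new_deltas` is a hypothetical name introduced here; the expected sequences line up with the single-bar test cases further down.

```python
# Hypothetical standalone sketch mirroring _should_update/_update_bar for one bar
# with a known total; assumes the bar is enabled.
def new_deltas(total, refresh_rate):
    n, deltas = 0, []                                          # n plays the role of bar.n
    for current in range(1, total + 1):
        if current % refresh_rate == 0 or current == total:    # _should_update
            delta = min(refresh_rate, total - n)               # _update_bar clamps the step
            if delta > 0:
                deltas.append(delta)
                n += delta
    return deltas

assert new_deltas(5, 3) == [3, 2]     # fills to exactly 5/5, no overshoot
assert new_deltas(2, 3) == [2]        # refresh_rate > num batches now completes too
assert new_deltas(3, 1) == [1, 1, 1]  # refresh_rate of 1 updates every batch
```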

tests/callbacks/test_progress_bar.py (78 changes: 77 additions & 1 deletion)

@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from unittest.mock import Mock, call
 
 import pytest
 from unittest import mock
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import EvalModelTemplate
+from tests.base import EvalModelTemplate, BoringModel
 
 
 @pytest.mark.parametrize('callbacks,refresh_rate', [

@@ -252,3 +254,77 @@ def test_progress_bar_warning_on_colab(tmpdir):
     )
 
     assert trainer.progress_bar_callback.refresh_rate == 19
+
+
+class MockedUpdateProgressBars(ProgressBar):
+    """ Mocks the update method once bars get initialized. """
+
+    def _mock_bar_update(self, bar):
+        bar.update = Mock(wraps=bar.update)
+        return bar
+
+    def init_train_tqdm(self):
+        bar = super().init_train_tqdm()
+        return self._mock_bar_update(bar)
+
+    def init_validation_tqdm(self):
+        bar = super().init_validation_tqdm()
+        return self._mock_bar_update(bar)
+
+    def init_test_tqdm(self):
+        bar = super().init_test_tqdm()
+        return self._mock_bar_update(bar)
+
+
+@pytest.mark.parametrize("train_batches,val_batches,refresh_rate,train_deltas,val_deltas", [
+    [2, 3, 1, [1, 1, 1, 1, 1], [1, 1, 1]],
+    [0, 0, 3, [], []],
+    [1, 0, 3, [1], []],
+    [1, 1, 3, [2], [1]],
+    [5, 0, 3, [3, 2], []],
+    [5, 2, 3, [3, 3, 1], [2]],
+    [5, 2, 6, [6, 1], [2]],
+])
+def test_main_progress_bar_update_amount(tmpdir, train_batches, val_batches, refresh_rate, train_deltas, val_deltas):
+    """
+    Test that the main progress bar updates by the correct amount together with the val progress bar. At the end of
+    the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh rate.
+    """
+    model = BoringModel()
+    progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=train_batches,
+        limit_val_batches=val_batches,
+        callbacks=[progress_bar],
+        logger=False,
+        checkpoint_callback=False,
+    )
+    trainer.fit(model)
+    progress_bar.main_progress_bar.update.assert_has_calls([call(delta) for delta in train_deltas])
+    if val_batches > 0:
+        progress_bar.val_progress_bar.update.assert_has_calls([call(delta) for delta in val_deltas])
+
+
+@pytest.mark.parametrize("test_batches,refresh_rate,test_deltas", [
+    [1, 3, [1]],
+    [3, 1, [1, 1, 1]],
+    [5, 3, [3, 2]],
+])
+def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, test_deltas):
+    """
+    Test that the test progress bar updates by the correct amount.
+    """
+    model = BoringModel()
+    progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_test_batches=test_batches,
+        callbacks=[progress_bar],
+        logger=False,
+        checkpoint_callback=False,
+    )
+    trainer.test(model)
+    progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas])
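
For readers verifying the delta lists in the parametrizations above, here is a hedged walk-through of the `[5, 2, 3, [3, 3, 1], [2]]` case. The `update` helper is a hypothetical stand-in for `_update_bar`, and the loop structure follows the hook order in the callback: train batches update the main bar, `on_validation_start` fills the remaining train progress, and each validation update advances both bars.

```python
# Hypothetical walk-through (not part of the test suite) of the
# [5, 2, 3, [3, 3, 1], [2]] case above.
refresh_rate = 3
train_batches, val_batches = 5, 2
main_total = train_batches + val_batches   # the main bar counts train + val batches

def update(n, total, deltas):
    # stand-in for _update_bar: clamp the delta so the bar never overshoots its total
    delta = min(refresh_rate, total - n)
    if delta > 0:
        deltas.append(delta)
        n += delta
    return n

main_n, main_deltas = 0, []
val_n, val_deltas = 0, []

# on_train_batch_end: update when batch_idx % refresh_rate == 0 or batch_idx == main_total
for i in range(1, train_batches + 1):
    if i % refresh_rate == 0 or i == main_total:
        main_n = update(main_n, main_total, main_deltas)   # fires at batch 3 -> delta 3

# on_validation_start: fill up the remaining train progress
main_n = update(main_n, main_total, main_deltas)           # delta 3 (n: 3 -> 6)

# on_validation_batch_end: update both bars when batch_idx % refresh_rate == 0 or == val total
for i in range(1, val_batches + 1):
    if i % refresh_rate == 0 or i == val_batches:
        val_n = update(val_n, val_batches, val_deltas)     # fires at batch 2 -> delta 2
        main_n = update(main_n, main_total, main_deltas)   # delta 1 (n: 6 -> 7)

assert main_deltas == [3, 3, 1]
assert val_deltas == [2]
```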