Skip to content

Commit

Permalink
fix incomplete progress bar when refresh_rate > num batches (#4577)
Browse files Browse the repository at this point in the history
* fix progress bar overshoot

* fix updates for partially incomplete main  progress bar when val loop starts

* add tests

* chlog
  • Loading branch information
awaelchli authored Nov 23, 2020
1 parent 9186abe commit 89e8796
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed checkpoint hparams dict casting when omegaconf is available ([#4770](https://github.com/PyTorchLightning/pytorch-lightning/pull/4770))

- Fixed incomplete progress bars when total batches not divisible by refresh rate ([#4577](https://github.com/PyTorchLightning/pytorch-lightning/pull/4577))

## [1.0.7] - 2020-11-17

Expand Down
28 changes: 21 additions & 7 deletions pytorch_lightning/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,21 +334,22 @@ def on_epoch_start(self, trainer, pl_module):

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0:
self.main_progress_bar.update(self.refresh_rate)
if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
self._update_bar(self.main_progress_bar)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
if not trainer.running_sanity_check:
self._update_bar(self.main_progress_bar) # fill up remaining
self.val_progress_bar = self.init_validation_tqdm()
self.val_progress_bar.total = convert_inf(self.total_val_batches)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0:
self.val_progress_bar.update(self.refresh_rate)
self.main_progress_bar.update(self.refresh_rate)
if self._should_update(self.val_batch_idx, self.total_val_batches):
self._update_bar(self.val_progress_bar)
self._update_bar(self.main_progress_bar)

def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
Expand All @@ -366,13 +367,26 @@ def on_test_start(self, trainer, pl_module):

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0:
self.test_progress_bar.update(self.refresh_rate)
if self._should_update(self.test_batch_idx, self.total_test_batches):
self._update_bar(self.test_progress_bar)

def on_test_end(self, trainer, pl_module):
super().on_test_end(trainer, pl_module)
self.test_progress_bar.close()

def _should_update(self, current, total):
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def _update_bar(self, bar):
""" Updates the bar by the refresh rate without overshooting. """
if bar.total is not None:
delta = min(self.refresh_rate, bar.total - bar.n)
else:
# infinite / unknown size
delta = self.refresh_rate
if delta > 0:
bar.update(delta)


def convert_inf(x):
""" The tqdm doesn't support inf values. We have to convert it to None. """
Expand Down
78 changes: 77 additions & 1 deletion tests/callbacks/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import Mock, call

import pytest
from unittest import mock

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base import EvalModelTemplate, BoringModel


@pytest.mark.parametrize('callbacks,refresh_rate', [
Expand Down Expand Up @@ -252,3 +254,77 @@ def test_progress_bar_warning_on_colab(tmpdir):
)

assert trainer.progress_bar_callback.refresh_rate == 19


class MockedUpdateProgressBars(ProgressBar):
""" Mocks the update method once bars get initializied. """

def _mock_bar_update(self, bar):
bar.update = Mock(wraps=bar.update)
return bar

def init_train_tqdm(self):
bar = super().init_train_tqdm()
return self._mock_bar_update(bar)

def init_validation_tqdm(self):
bar = super().init_validation_tqdm()
return self._mock_bar_update(bar)

def init_test_tqdm(self):
bar = super().init_test_tqdm()
return self._mock_bar_update(bar)


@pytest.mark.parametrize("train_batches,val_batches,refresh_rate,train_deltas,val_deltas", [
[2, 3, 1, [1, 1, 1, 1, 1], [1, 1, 1]],
[0, 0, 3, [], []],
[1, 0, 3, [1], []],
[1, 1, 3, [2], [1]],
[5, 0, 3, [3, 2], []],
[5, 2, 3, [3, 3, 1], [2]],
[5, 2, 6, [6, 1], [2]],
])
def test_main_progress_bar_update_amount(tmpdir, train_batches, val_batches, refresh_rate, train_deltas, val_deltas):
"""
Test that the main progress updates with the correct amount together with the val progress. At the end of
the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh rate.
"""
model = BoringModel()
progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=train_batches,
limit_val_batches=val_batches,
callbacks=[progress_bar],
logger=False,
checkpoint_callback=False,
)
trainer.fit(model)
progress_bar.main_progress_bar.update.assert_has_calls([call(delta) for delta in train_deltas])
if val_batches > 0:
progress_bar.val_progress_bar.update.assert_has_calls([call(delta) for delta in val_deltas])


@pytest.mark.parametrize("test_batches,refresh_rate,test_deltas", [
[1, 3, [1]],
[3, 1, [1, 1, 1]],
[5, 3, [3, 2]],
])
def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, test_deltas):
"""
Test that test progress updates with the correct amount.
"""
model = BoringModel()
progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_test_batches=test_batches,
callbacks=[progress_bar],
logger=False,
checkpoint_callback=False,
)
trainer.test(model)
progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas])

0 comments on commit 89e8796

Please sign in to comment.