improve early stopping verbose logging #6811

Merged
20 commits merged on May 3, 2021
Changes from 17 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -139,6 +139,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added warning when missing `Callback` and using `resume_from_checkpoint` ([#7254](https://github.com/PyTorchLightning/pytorch-lightning/pull/7254))


- Improved verbose logging for `EarlyStopping` callback ([#6811](https://github.com/PyTorchLightning/pytorch-lightning/pull/6811))


### Changed


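For context on the changelog entry above: the improved messages are only emitted when the callback's existing `verbose` flag is enabled. A minimal usage sketch (the monitored metric name and thresholds below are placeholders, not part of this diff):

```python
# Minimal sketch of opting into the improved logging; "val_loss", min_delta and
# patience values are illustrative placeholders.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor="val_loss", min_delta=0.01, patience=3, verbose=True)
trainer = Trainer(callbacks=[early_stop])
```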
24 changes: 22 additions & 2 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -24,6 +24,7 @@
import numpy as np
import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -196,8 +197,8 @@ def _run_early_stopping_check(self, trainer) -> None:
        trainer.should_stop = trainer.should_stop or should_stop
        if should_stop:
            self.stopped_epoch = trainer.current_epoch
-        if reason:
-            log.info(f"[{trainer.global_rank}] {reason}")
+        if reason and self.verbose:
+            self._log_info(trainer, reason)

    def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
        should_stop = False
@@ -224,6 +225,7 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
            )
        elif self.monitor_op(current - self.min_delta, self.best_score):
            should_stop = False
            reason = self._improvement_message(current)
            self.best_score = current
            self.wait_count = 0
        else:
@@ -236,3 +238,21 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
            )

        return should_stop, reason

    def _improvement_message(self, current: torch.Tensor) -> str:
        """ Formats a log message that informs the user about an improvement in the monitored score. """
        if torch.isfinite(self.best_score):
            msg = (
                f"Metric {self.monitor} improved by {abs(self.best_score - current):.3f} >="
                f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
            )
        else:
            msg = f"Metric {self.monitor} improved. New best score: {current:.3f}"
        return msg

    @staticmethod
    def _log_info(trainer: Optional["pl.Trainer"], message: str) -> None:
        if trainer is not None and trainer.world_size > 1:
            log.info(f"[{trainer.global_rank}] {message}")
        else:
            log.info(message)
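As a rough, standalone illustration of the message format produced by the new `_improvement_message` helper (the function name, metric name, and score values below are made up for the example, not library code):

```python
# Standalone sketch mirroring the f-string in _improvement_message above.
def improvement_message(monitor: str, best_score: float, current: float, min_delta: float) -> str:
    # Same format string as the helper added in this PR.
    return (
        f"Metric {monitor} improved by {abs(best_score - current):.3f} >="
        f" min_delta = {abs(min_delta)}. New best score: {current:.3f}"
    )

print(improvement_message("val_loss", 0.573, 0.453, 0.01))
# Metric val_loss improved by 0.120 >= min_delta = 0.01. New best score: 0.453
```

On a multi-process run (`trainer.world_size > 1`), `_log_info` prefixes the message with the emitting process's `global_rank`, e.g. `[0] Metric val_loss improved ...`; on a single process the message is logged unchanged.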