Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: accumulation and suggestion for learning rate finder #1801

Merged
merged 10 commits into from
May 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed saving native AMP scaler state (introduced in [#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561))

- Fixed missing profiler attribute in add_argparse_args() ArgumentParser ([#1794](https://github.com/PyTorchLightning/pytorch-lightning/pull/1794))

- Fixed accumulation parameter and suggestion method for learning rate finder ([#1801](https://github.com/PyTorchLightning/pytorch-lightning/pull/1801))

## [0.7.5] - 2020-04-27

Expand Down
79 changes: 58 additions & 21 deletions pytorch_lightning/trainer/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn


class TrainerLRFinderMixin(ABC):
Expand Down Expand Up @@ -58,7 +59,8 @@ def lr_find(self,
max_lr: float = 1,
num_training: int = 100,
mode: str = 'exponential',
num_accumulation_steps: int = 1):
early_stop_threshold: float = 4.0,
num_accumulation_steps=None):
r"""
lr_find enables the user to do a range test of good initial learning rates,
to reduce the amount of guesswork in picking a good starting learning rate.
Expand All @@ -81,7 +83,12 @@ def lr_find(self,
after each batch. If set to 'exponential', will increase learning
rate exponentially.

num_accumulation_steps: number of batches to calculate loss over.
early_stop_threshold: threshold for stopping the search. If the
loss at any point is larger than early_stop_threshold*best_loss
then the search is stopped. To disable, set to None.

num_accumulation_steps: deprepecated, number of batches to calculate loss over.
Set trainer argument ``accumulate_grad_batches`` instead.

Example::

Expand All @@ -104,6 +111,12 @@ def lr_find(self,
trainer.fit(model)

"""
if num_accumulation_steps is not None:
rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated"
" since v0.7.6 and will be removed in 0.9. Please"
" set trainer argument `accumulate_grad_batches` instead.",
DeprecationWarning)

save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')

self.__lr_finder_dump_params(model)
Expand All @@ -115,7 +128,9 @@ def lr_find(self,
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

# Use special lr logger callback
self.callbacks = [_LRCallback(num_training, progress_bar_refresh_rate=1)]
self.callbacks = [_LRCallback(num_training,
early_stop_threshold,
progress_bar_refresh_rate=1)]

# No logging
self.logger = None
Expand All @@ -127,9 +142,6 @@ def lr_find(self,
if self.progress_bar_callback:
self.progress_bar_callback.disable()

# Accumulation of gradients
self.accumulate_grad_batches = num_accumulation_steps

# Disable standard checkpoint & early stopping
self.checkpoint_callback = False
self.early_stop_callback = None
Expand All @@ -149,7 +161,6 @@ def lr_find(self,
raise MisconfigurationException(
f'`model.configure_optimizers()` returned {len(optimizers)}, but'
' learning rate finder only works with single optimizer')
configure_optimizers = model.configure_optimizers
model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])

# Fit, lr & loss logged in callback
Expand All @@ -164,6 +175,7 @@ def lr_find(self,
# Transfer results from callback to lr finder object
lr_finder.results.update({'lr': self.callbacks[0].lrs,
'loss': self.callbacks[0].losses})
lr_finder._total_batch_idx = self.total_batch_idx # for debug purpose

# Reset model state
self.restore(str(save_path), on_gpu=self.on_gpu)
Expand All @@ -184,7 +196,6 @@ def __lr_finder_dump_params(self, model):
'logger': self.logger,
'max_steps': self.max_steps,
'progress_bar_refresh_rate': self.progress_bar_refresh_rate,
'accumulate_grad_batches': self.accumulate_grad_batches,
'checkpoint_callback': self.checkpoint_callback,
'early_stop_callback': self.early_stop_callback,
'enable_early_stop': self.enable_early_stop,
Expand All @@ -198,7 +209,6 @@ def __lr_finder_restore_params(self, model):
self.callbacks = self.__dumped_params['callbacks']
self.max_steps = self.__dumped_params['max_steps']
self.progress_bar_refresh_rate = self.__dumped_params['progress_bar_refresh_rate']
self.accumulate_grad_batches = self.__dumped_params['accumulate_grad_batches']
self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
self.early_stop_callback = self.__dumped_params['early_stop_callback']
self.enable_early_stop = self.__dumped_params['enable_early_stop']
Expand Down Expand Up @@ -242,6 +252,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
self.num_training = num_training

self.results = {}
self._total_batch_idx = 0 # for debug purpose

def _get_new_optimizer(self, optimizer: torch.optim.Optimizer):
""" Construct a new `configure_optimizers()` method, that has a optimizer
Expand Down Expand Up @@ -298,30 +309,49 @@ def plot(self, suggest: bool = False, show: bool = False):

return fig

def suggestion(self):
def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
""" This will propose a suggestion for choice of initial learning rate
as the point with the steepest negative gradient.

Returns:
lr: suggested initial learning rate to use
skip_begin: how many samples to skip in the beginning. Prevent too naive estimates
skip_end: how many samples to skip in the end. Prevent too optimistic estimates

"""
try:
min_grad = (np.gradient(np.array(self.results["loss"]))).argmin()
self._optimal_idx = min_grad
return self.results["lr"][min_grad]
loss = self.results["loss"][skip_begin:-skip_end]
min_grad = (np.gradient(np.array(loss))).argmin()
self._optimal_idx = min_grad + skip_begin
return self.results["lr"][self._optimal_idx]
except Exception:
log.warning('Failed to compute suggesting for `lr`.'
' There might not be enough points.')
log.exception('Failed to compute suggesting for `lr`. There might not be enough points.')
self._optimal_idx = None


class _LRCallback(Callback):
""" Special callback used by the learning rate finder. This callbacks log
the learning rate before each batch and log the corresponding loss after
each batch. """
def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, beta: float = 0.98):
each batch.

Args:
num_training: number of iterations done by the learning rate finder
early_stop_threshold: threshold for stopping the search. If the
loss at any point is larger than ``early_stop_threshold*best_loss``
then the search is stopped. To disable, set to ``None``.
progress_bar_refresh_rate: rate to refresh the progress bar for
the learning rate finder
beta: smoothing value, the loss being logged is a running average of
loss values logged until now. ``beta`` controls the forget rate i.e.
if ``beta=0`` all past information is ignored.

"""
def __init__(self, num_training: int,
early_stop_threshold: float = 4.0,
progress_bar_refresh_rate: bool = False,
beta: float = 0.98):
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
self.num_training = num_training
self.early_stop_threshold = early_stop_threshold
self.beta = beta
self.losses = []
self.lrs = []
Expand All @@ -332,13 +362,19 @@ def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, b

def on_batch_start(self, trainer, pl_module):
""" Called before each training batch, logs the lr that will be used """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return

if self.progress_bar_refresh_rate and self.progress_bar is None:
self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)

self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])

def on_batch_end(self, trainer, pl_module):
""" Called when the training batch ends, logs the calculated loss """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return

if self.progress_bar:
self.progress_bar.update()

Expand All @@ -350,10 +386,11 @@ def on_batch_end(self, trainer, pl_module):
smoothed_loss = self.avg_loss / (1 - self.beta**current_step)

# Check if we diverging
if current_step > 1 and smoothed_loss > 4 * self.best_loss:
trainer.max_steps = current_step # stop signal
if self.progress_bar:
self.progress_bar.close()
if self.early_stop_threshold is not None:
if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
trainer.max_steps = current_step # stop signal
if self.progress_bar:
self.progress_bar.close()

# Save best loss for diverging checking
if smoothed_loss < self.best_loss or current_step == 1:
Expand Down
56 changes: 51 additions & 5 deletions tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,15 @@ def test_trainer_reset_correctly(tmpdir):


def test_trainer_arg_bool(tmpdir):

""" Test that setting trainer arg to bool works """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)
before_lr = hparams.learning_rate

# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
max_epochs=5,
auto_lr_find=True
)

Expand All @@ -95,7 +95,7 @@ def test_trainer_arg_bool(tmpdir):


def test_trainer_arg_str(tmpdir):

""" Test that setting trainer arg to string works """
hparams = EvalModelTemplate.get_default_hparams()
hparams.__dict__['my_fancy_lr'] = 1.0 # update with non-standard field
model = EvalModelTemplate(hparams)
Expand All @@ -104,7 +104,7 @@ def test_trainer_arg_str(tmpdir):
# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
max_epochs=5,
auto_lr_find='my_fancy_lr'
)

Expand All @@ -115,6 +115,7 @@ def test_trainer_arg_str(tmpdir):


def test_call_to_trainer_method(tmpdir):
""" Test that directly calling the trainer method works """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)
Expand All @@ -123,7 +124,7 @@ def test_call_to_trainer_method(tmpdir):
# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
max_epochs=5,
)

lrfinder = trainer.lr_find(model, mode='linear')
Expand All @@ -133,3 +134,48 @@ def test_call_to_trainer_method(tmpdir):

assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'


def test_accumulation_and_early_stopping(tmpdir):
""" Test that early stopping of learning rate finder works, and that
accumulation also works for this feature """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)

before_lr = hparams.learning_rate
# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
accumulate_grad_batches=2
)

lrfinder = trainer.lr_find(model, early_stop_threshold=None)
after_lr = lrfinder.suggestion()

assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'
assert len(lrfinder.results['lr']) == 100, \
'Early stopping for learning rate finder did not work'
assert lrfinder._total_batch_idx == 100 * 2, \
'Accumulation parameter did not work'


def test_suggestion_parameters_work(tmpdir):
""" Test that default skipping does not alter results in basic case """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)

# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=10,
)

lrfinder = trainer.lr_find(model)
lr1 = lrfinder.suggestion(skip_begin=10) # default
lr2 = lrfinder.suggestion(skip_begin=80) # way too high, should have an impact

assert lr1 != lr2, \
'Skipping parameter did not influence learning rate'