Skip to content

Commit

Permalink
Move NaN/Inf detection to a separate utilities file (#6834)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 8, 2021
1 parent 90e37ba commit 851f9e3
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 16 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added utils for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))


- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))

Expand Down Expand Up @@ -113,6 +115,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))


- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))


Expand Down
11 changes: 9 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.warnings import WarningCache

Expand Down Expand Up @@ -636,7 +637,7 @@ def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list:

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
self.trainer.detect_nan_tensors(opt_closure_result.loss)
self._check_finite(opt_closure_result.loss)

# track all the outputs across all steps
batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
Expand Down Expand Up @@ -678,7 +679,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
self.trainer.detect_nan_tensors(result.loss)
self._check_finite(result.loss)

else:
self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...")
Expand All @@ -689,6 +690,12 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,

return result

def _check_finite(self, loss: torch.Tensor) -> None:
if not torch.isfinite(loss).all():
raise ValueError(f'The loss returned in `training_step` is {loss}.')
model = self.trainer.lightning_module
detect_nan_parameters(model)

def backward(self, result, optimizer, opt_idx, *args, **kwargs):
self.trainer.dev_debugger.track_event("backward_call")

Expand Down
33 changes: 20 additions & 13 deletions pytorch_lightning/trainer/training_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,42 @@
from torch import Tensor

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients

EPSILON = 1e-6
EPSILON_FP16 = 1e-5
log = logging.getLogger(__name__)


class TrainerTrainingTricksMixin(ABC):
"""
TODO: Remove this class in v1.5.
Use the NaN utilities from ``pytorch_lightning.utilities.finite_checks`` instead.
"""

# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
lightning_module: LightningModule

def print_nan_gradients(self) -> None:
rank_zero_deprecation(
"Internal: TrainerTrainingTricksMixin.print_nan_gradients is deprecated in v1.3"
" and will be removed in v1.5."
" Use `pytorch_lightning.utilities.finite_checks.print_nan_gradients` instead."
)
model = self.lightning_module
for param in model.parameters():
if (param.grad is not None) and torch.isnan(param.grad.float()).any():
log.info(param, param.grad)
print_nan_gradients(model)

def detect_nan_tensors(self, loss: Tensor) -> None:
model = self.lightning_module

rank_zero_deprecation(
"Internal: TrainerTrainingTricksMixin.detect_nan_tensors is deprecated in v1.3"
" and will be removed in v1.5."
" Use `pytorch_lightning.utilities.finite_checks.detect_nan_parameters` instead."
)
# check if loss is nan
if not torch.isfinite(loss).all():
raise ValueError('The loss returned in `training_step` is nan or inf.')
# check if a network weight is nan
for name, param in model.named_parameters():
if not torch.isfinite(param).all():
self.print_nan_gradients()
raise ValueError(
f'Detected nan and/or inf values in `{name}`.'
' Check your forward pass for numerically unstable operations.'
)
model = self.lightning_module
detect_nan_parameters(model)
39 changes: 39 additions & 0 deletions pytorch_lightning/utilities/finite_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions to detect NaN/Inf values. """

import logging

import torch
import torch.nn as nn

log = logging.getLogger(__name__)


def print_nan_gradients(model: nn.Module) -> None:
""" Iterates over model parameters and prints out parameter + gradient information if NaN. """
for param in model.parameters():
if (param.grad is not None) and torch.isnan(param.grad.float()).any():
log.info(param, param.grad)


def detect_nan_parameters(model: nn.Module) -> None:
""" Iterates over model parameters and prints gradients if any parameter is not finite. """
for name, param in model.named_parameters():
if not torch.isfinite(param).all():
print_nan_gradients(model)
raise ValueError(
f'Detected nan and/or inf values in `{name}`.'
' Check your forward pass for numerically unstable operations.'
)
13 changes: 13 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from unittest import mock

import pytest
import torch
from torch import optim

from pytorch_lightning import Callback, Trainer
Expand Down Expand Up @@ -218,3 +219,15 @@ def test_v1_5_0_profiler_output_filename(tmpdir, cls):
profiler = cls(output_filename=filepath)
assert profiler.dirpath == tmpdir
assert profiler.filename == "test"


def test_v1_5_0_trainer_training_trick_mixin(tmpdir):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False)
trainer.fit(model)
with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
trainer.print_nan_gradients()

dummy_loss = torch.tensor(1.0)
with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
trainer.detect_nan_tensors(dummy_loss)
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
terminate_on_nan=True,
)

with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is nan or inf.*"):
with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is.*"):
trainer.fit(model)
assert trainer.global_step == model.test_step_inf_loss

Expand Down

0 comments on commit 851f9e3

Please sign in to comment.