
Move NaN/Inf detection to a separate utilities file #6834

Merged · 17 commits · Apr 8, 2021
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added utils for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))


- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))

@@ -104,6 +106,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))


- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))


11 changes: 9 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -29,6 +29,7 @@
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.nan import detect_nan_parameters
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.warnings import WarningCache

@@ -636,7 +637,7 @@ def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list:

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
- self.trainer.detect_nan_tensors(opt_closure_result.loss)
+ self._check_nan(opt_closure_result.loss)

# track all the outputs across all steps
batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
@@ -678,7 +679,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
- self.trainer.detect_nan_tensors(result.loss)
+ self._check_nan(result.loss)

else:
self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...")
@@ -689,6 +690,12 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,

return result

def _check_nan(self, loss: torch.Tensor) -> None:
if not torch.isfinite(loss).all():
raise ValueError('The loss returned in `training_step` is nan or inf.')
model = self.trainer.lightning_module
detect_nan_parameters(model)

def backward(self, result, optimizer, opt_idx, *args, **kwargs):
self.trainer.dev_debugger.track_event("backward_call")

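For context, the relocated check only runs when the existing `terminate_on_nan` flag is enabled on the Trainer (the diff above reads it as `self.trainer.terminate_on_nan`). A minimal sketch of how a user opts in; the commented `fit()` call, `max_epochs` value, and the `model`/`dm` names are illustrative placeholders:

```python
import pytorch_lightning as pl

# With terminate_on_nan=True, the training loop calls _check_nan() on every
# returned loss, so a NaN/Inf loss or non-finite parameter raises a ValueError.
trainer = pl.Trainer(terminate_on_nan=True, max_epochs=1)
# trainer.fit(model, datamodule=dm)  # 'model' and 'dm' are placeholders
```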
12 changes: 12 additions & 0 deletions pytorch_lightning/trainer/training_tricks.py
@@ -19,6 +19,7 @@
from torch import Tensor

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_deprecation

EPSILON = 1e-6
EPSILON_FP16 = 1e-5
@@ -32,12 +33,23 @@ class TrainerTrainingTricksMixin(ABC):
lightning_module: LightningModule

def print_nan_gradients(self) -> None:
rank_zero_deprecation(
"Internal: TrainerTrainingTricksMixin.print_nan_gradients is deprecated in v1.3"
" and will be removed in v1.5."
" Use `pytorch_lightning.utilities.nan.print_nan_gradients` instead."
)

model = self.lightning_module
for param in model.parameters():
if (param.grad is not None) and torch.isnan(param.grad.float()).any():
log.info(param, param.grad)

def detect_nan_tensors(self, loss: Tensor) -> None:
rank_zero_deprecation(
"Internal: TrainerTrainingTricksMixin.detect_nan_tensors is deprecated in v1.3"
" and will be removed in v1.5."
" Use `pytorch_lightning.utilities.nan.detect_nan_parameters` instead."
)
model = self.lightning_module

# check if loss is nan
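The deprecation messages above redirect callers to the new utilities module. A hedged migration sketch; `trainer` is assumed to be an already-constructed `Trainer` with a model attached, and note that the loss check from `detect_nan_tensors` now lives in the training loop's `_check_nan` rather than in the utility:

```python
from pytorch_lightning.utilities.nan import detect_nan_parameters, print_nan_gradients

# Before (deprecated in v1.3, slated for removal in v1.5):
# trainer.print_nan_gradients()
# trainer.detect_nan_tensors(loss)

# After: call the standalone utilities on the LightningModule directly.
print_nan_gradients(trainer.lightning_module)    # logs any parameters whose gradients contain NaN
detect_nan_parameters(trainer.lightning_module)  # raises ValueError if any parameter is NaN/Inf
```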
39 changes: 39 additions & 0 deletions pytorch_lightning/utilities/nan.py
@@ -0,0 +1,39 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions to detect NaN/Inf values. """

import logging

import torch
import torch.nn as nn

log = logging.getLogger(__name__)


def print_nan_gradients(model: nn.Module) -> None:
Contributor:

What do you think of these improvements? (my code)

https://github.com/jpuigcerver/PyLaia/blob/master/laia/utils/checks.py

• Optional exception
• Include non-finite percentages

Contributor:

It could be useful for some users with unstable losses. Worth considering.

Contributor (Author):

I think it'd be useful! Maybe we can add it in a follow-up PR? This one is mainly moving code around for parity.

Contributor:

If you prefer, sure.

""" Iterates over model parameters and prints out parameter + gradient information if NaN. """
for param in model.parameters():
if (param.grad is not None) and torch.isnan(param.grad.float()).any():
log.info(param, param.grad)


def detect_nan_parameters(model: nn.Module) -> None:
""" Iterates over model parameters and prints gradients if any parameter is not finite. """
for name, param in model.named_parameters():
if not torch.isfinite(param).all():
print_nan_gradients(model)
raise ValueError(
f'Detected nan and/or inf values in `{name}`.'
' Check your forward pass for numerically unstable operations.'
)
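
The new module can also be used standalone, outside the Trainer. A small self-contained sketch; the `nn.Linear` model and the injected NaN are purely illustrative:

```python
import torch
import torch.nn as nn

from pytorch_lightning.utilities.nan import detect_nan_parameters, print_nan_gradients

model = nn.Linear(4, 2)
with torch.no_grad():
    model.weight[0, 0] = float("nan")  # inject a NaN to trigger detection

print_nan_gradients(model)    # no gradients yet, so nothing is logged
detect_nan_parameters(model)  # raises ValueError: Detected nan and/or inf values in `weight`. ...
```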