Skip to content

Commit

Permalink
TPU gradient clipping. (#963)
Browse files Browse the repository at this point in the history
* clip

* Update pytorch_lightning/trainer/training_tricks.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/trainer/training_tricks.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* pull out epsilon

* add fp16 case

* Update pytorch_lightning/trainer/training_tricks.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
srush and Borda authored Feb 27, 2020
1 parent b941845 commit 27a3be0
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion pytorch_lightning/trainer/training_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
from abc import ABC, abstractmethod

import torch
import math

from pytorch_lightning.callbacks import GradientAccumulationScheduler

EPSILON = 1e-6
EPSILON_FP16 = 1e-5


class TrainerTrainingTricksMixin(ABC):

Expand All @@ -19,9 +23,29 @@ def get_model(self):
pass

def clip_gradients(self):
# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
if self.gradient_clip_val > 0:
model = self.get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)
parameters = model.parameters()
max_norm = float(self.gradient_clip_val)
norm_type = float(2.0)
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
if norm_type == math.inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
else:
device = parameters[0].device
total_norm = torch.zeros([], device=device if parameters else None)
for p in parameters:
param_norm = p.grad.data.norm(norm_type) ** norm_type
total_norm.add_(param_norm)
total_norm = (total_norm ** (1. / norm_type))
eps = EPSILON_FP16 if self.precision == 16 else EPSILON
clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
for p in parameters:
p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))

def print_nan_gradients(self):
model = self.get_model()
Expand Down

0 comments on commit 27a3be0

Please sign in to comment.