From 27a3be0287aaf71a271c26fee8daeb0e8ebf1c5c Mon Sep 17 00:00:00 2001 From: srush Date: Thu, 27 Feb 2020 15:46:47 -0500 Subject: [PATCH] TPU gradient clipping. (#963) * clip * Update pytorch_lightning/trainer/training_tricks.py Co-Authored-By: Jirka Borovec * Update pytorch_lightning/trainer/training_tricks.py Co-Authored-By: Jirka Borovec * pull out epsilon * add fp16 case * Update pytorch_lightning/trainer/training_tricks.py Co-Authored-By: Jirka Borovec Co-authored-by: Jirka Borovec --- pytorch_lightning/trainer/training_tricks.py | 26 +++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 7fa4059afc3e2..6171e487e74d7 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -2,9 +2,13 @@ from abc import ABC, abstractmethod import torch +import math from pytorch_lightning.callbacks import GradientAccumulationScheduler +EPSILON = 1e-6 +EPSILON_FP16 = 1e-5 + class TrainerTrainingTricksMixin(ABC): @@ -19,9 +23,29 @@ def get_model(self): pass def clip_gradients(self): + # this code is a modification of torch.nn.utils.clip_grad_norm_ + # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md if self.gradient_clip_val > 0: model = self.get_model() - torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val) + parameters = model.parameters() + max_norm = float(self.gradient_clip_val) + norm_type = float(2.0) + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + if norm_type == math.inf: + total_norm = max(p.grad.data.abs().max() for p in parameters) + else: + device = parameters[0].device + total_norm = torch.zeros([], device=device if parameters else None) + for p in parameters: + param_norm = p.grad.data.norm(norm_type) ** norm_type + total_norm.add_(param_norm) + total_norm = (total_norm ** (1. / norm_type)) + eps = EPSILON_FP16 if self.precision == 16 else EPSILON + clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps) + for p in parameters: + p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device))) def print_nan_gradients(self): model = self.get_model()